Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- ms-swift/processed_data/processed_overlap5s_speaker_segments.json +0 -0
- ms-swift/processed_data/processed_silence_isoverlaps.json +0 -0
- ms-swift/silence_overlaps/700/test/overlap5s_segments_test.json +27 -0
- ms-swift/silence_overlaps/700/test/overlap5s_silence_segments_test.json +27 -0
- ms-swift/silence_overlaps/700/train/overlap5s_issilence_segments_train.json +0 -0
- ms-swift/silence_overlaps/test/test_train.json +963 -0
- ms-swift/swift/llm/sampling/mcts.py +400 -0
- ms-swift/swift/llm/template/template/__pycache__/__init__.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/template/__pycache__/emu3.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/template/__pycache__/gemma.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/template/__pycache__/internvl.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/template/__pycache__/minicpm.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/template/__pycache__/pixtral.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/template/__pycache__/stepfun.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/template/__pycache__/valley.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/template/__pycache__/yi.cpython-310.pyc +0 -0
- ms-swift/swift/llm/template/template/deepseek.py +315 -0
- ms-swift/swift/llm/template/template/glm.py +293 -0
- ms-swift/swift/llm/template/template/internvl.py +168 -0
- ms-swift/swift/llm/template/template/llama.py +213 -0
- ms-swift/swift/llm/template/template/megrez.py +93 -0
- ms-swift/swift/llm/template/template/openbuddy.py +48 -0
- ms-swift/swift/llm/template/template/pixtral.py +59 -0
- ms-swift/swift/llm/template/template/qwen.py +671 -0
- ms-swift/swift/llm/template/template/stepfun.py +128 -0
- ms-swift/swift/llm/template/template/yi.py +63 -0
- ms-swift/swift/llm/train/__pycache__/callback.cpython-310.pyc +0 -0
- ms-swift/swift/llm/train/__pycache__/rlhf.cpython-310.pyc +0 -0
- ms-swift/swift/llm/train/__pycache__/sft.cpython-310.pyc +0 -0
- ms-swift/swift/llm/train/__pycache__/tuner.cpython-310.pyc +0 -0
- ms-swift/swift/llm/train/callback.py +80 -0
- ms-swift/swift/llm/train/rlhf.py +154 -0
- ms-swift/swift/llm/train/sft.py +287 -0
- ms-swift/swift/llm/train/tuner.py +424 -0
- ms-swift/swift/megatron/argument/train_args.py +53 -0
- ms-swift/swift/megatron/model/__init__.py +4 -0
- ms-swift/swift/megatron/model/config.py +57 -0
- ms-swift/swift/megatron/model/constant.py +3 -0
- ms-swift/swift/megatron/model/gpt/__init__.py +40 -0
- ms-swift/swift/megatron/model/gpt/config.py +13 -0
- ms-swift/swift/megatron/model/gpt/model.py +37 -0
- ms-swift/swift/megatron/model/register.py +47 -0
- ms-swift/swift/megatron/model/rope.py +40 -0
- ms-swift/swift/megatron/train/patcher.py +64 -0
- ms-swift/swift/megatron/utils/__init__.py +4 -0
- ms-swift/swift/megatron/utils/convert.py +122 -0
- ms-swift/swift/megatron/utils/patcher.py +26 -0
- ms-swift/swift/plugin/__pycache__/__init__.cpython-310.pyc +0 -0
- ms-swift/swift/plugin/__pycache__/callback.cpython-310.pyc +0 -0
- ms-swift/swift/plugin/__pycache__/metric.cpython-310.pyc +0 -0
ms-swift/processed_data/processed_overlap5s_speaker_segments.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/processed_data/processed_silence_isoverlaps.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/700/test/overlap5s_segments_test.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"key": "SODA_PROCESSED--train--123906",
|
| 4 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--123906.wav",
|
| 5 |
+
"model_output": "Multiple speakers talk simultaneously from 00:03-00:09"
|
| 6 |
+
},
|
| 7 |
+
{
|
| 8 |
+
"key": "SODA_PROCESSED--train--1112763",
|
| 9 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--1112763.wav",
|
| 10 |
+
"model_output": "Multiple speakers talk simultaneously from 00:09-00:15"
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"key": "SODA_PROCESSED--train--790538",
|
| 14 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--790538.wav",
|
| 15 |
+
"model_output": "Multiple speakers talk simultaneously from 00:15-00:19"
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"key": "SODA_PROCESSED--train--822773",
|
| 19 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--822773.wav",
|
| 20 |
+
"model_output": "Multiple speakers talk simultaneously from 00:14-00:19"
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"key": "SODA_PROCESSED--train--424960",
|
| 24 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--424960.wav",
|
| 25 |
+
"model_output": "Multiple speakers talk simultaneously from 00:29-00:33"
|
| 26 |
+
}
|
| 27 |
+
]
|
ms-swift/silence_overlaps/700/test/overlap5s_silence_segments_test.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"key": "SODA_PROCESSED--train--137471",
|
| 4 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--137471.wav",
|
| 5 |
+
"model_output": "No, there is no silence gap."
|
| 6 |
+
},
|
| 7 |
+
{
|
| 8 |
+
"key": "SODA_PROCESSED--train--201044",
|
| 9 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--201044.wav",
|
| 10 |
+
"model_output": "No, there is no silence gap."
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"key": "SODA_PROCESSED--train--596349",
|
| 14 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--596349.wav",
|
| 15 |
+
"model_output": "No, there is no silence gap."
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"key": "SODA_PROCESSED--train--956648",
|
| 19 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--956648.wav",
|
| 20 |
+
"model_output": "No, there is no silence gap."
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"key": "SODA_PROCESSED--train--962210",
|
| 24 |
+
"audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--962210.wav",
|
| 25 |
+
"model_output": "No, there is no silence gap."
|
| 26 |
+
}
|
| 27 |
+
]
|
ms-swift/silence_overlaps/700/train/overlap5s_issilence_segments_train.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ms-swift/silence_overlaps/test/test_train.json
ADDED
|
@@ -0,0 +1,963 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"SODA_PROCESSED--train--449689": {
|
| 3 |
+
"original_dialog_id": "",
|
| 4 |
+
"dialog_index": 449689,
|
| 5 |
+
"processed_dialogue": "A: Hey there. Mind if I lay down next to you? \nB: No, go ahead. \nA: Thanks. I needed a break from the sun. It's so hot today. \nB: Yeah, it is. I'm trying to get a tan, but I don't want to get too dehydrated, so I'm keeping a bottle of water close by and reapplying sunscreen every hour to avoid any skin damage. \nA: Burnt? Yeah, that's definitely a possibility out here. So what brings you to the beach today? Just wanting to relax? \nB: Yeah, pretty much. I just finished up my summer classes and needed some time to myself before starting my new job next week. \nA: That sounds rough. Are you excited for it? Or [interrupt] worried about how you'll balance everything with your personal life and other commitments you might have during this transitional period? \nB: Nervous? A little bit of both, honestly. But mostly excited. It should be a good experience. And the pay is great, so that's a plus. \nA: Definitely. Well, I hope you enjoy the rest of your day here. \nB: Thanks. You too.",
|
| 6 |
+
"clean_dialogue": "A: Hey there. Mind if I lay down next to you? \nB: No, go ahead. \nA: Thanks. I needed a break from the sun. It's so hot today. \nB: Yeah, it is. I'm trying to get a tan, but I don't want to get too dehydrated, so I'm keeping a bottle of water close by and reapplying sunscreen every hour to avoid any skin damage. \nA: Burnt? Yeah, that's definitely a possibility out here. So what brings you to the beach today? Just wanting to relax? \nB: Yeah, pretty much. I just finished up my summer classes and needed some time to myself before starting my new job next week. \nA:That sounds rough. Are you excited for it? Or worried about how you'll balance everything with your personal life and other commitments you might have during this transitional period?\nB: Nervous? A little bit of both, honestly. But mostly excited. It should be a good experience. And the pay is great, so that's a plus. \nA: Definitely. Well, I hope you enjoy the rest of your day here. \nB: Thanks. You too.",
|
| 7 |
+
"speaker_tracks": {
|
| 8 |
+
"A": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/A_track.wav",
|
| 9 |
+
"B": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/B_track.wav"
|
| 10 |
+
},
|
| 11 |
+
"error_type": "error_after_interrupt",
|
| 12 |
+
"stereo_audio": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/stereo_dialogue.wav",
|
| 13 |
+
"total_duration": 50.09668934240363,
|
| 14 |
+
"segments": [
|
| 15 |
+
{
|
| 16 |
+
"speaker": "A",
|
| 17 |
+
"text": "Hey there. Mind if I lay down next to you?",
|
| 18 |
+
"original_text": "Hey there. Mind if I lay down next to you?",
|
| 19 |
+
"start_time": 0,
|
| 20 |
+
"end_time": 2.4961451247165534,
|
| 21 |
+
"audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_0_A.wav",
|
| 22 |
+
"silence_duration": 0,
|
| 23 |
+
"is_interrupted": false
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"speaker": "B",
|
| 27 |
+
"text": "No, go ahead.",
|
| 28 |
+
"original_text": "No, go ahead.",
|
| 29 |
+
"start_time": 3.0616233505922237,
|
| 30 |
+
"end_time": 4.257451014991316,
|
| 31 |
+
"audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_1_B.wav",
|
| 32 |
+
"silence_duration": 0.5654782258756702,
|
| 33 |
+
"is_interrupted": false
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"speaker": "A",
|
| 37 |
+
"text": "Thanks. I needed a break from the sun. It's so hot today.",
|
| 38 |
+
"original_text": "Thanks. I needed a break from the sun. It's so hot today.",
|
| 39 |
+
"start_time": 4.673061027457998,
|
| 40 |
+
"end_time": 8.666893227004483,
|
| 41 |
+
"audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_2_A.wav",
|
| 42 |
+
"silence_duration": 0.41561001246668183,
|
| 43 |
+
"is_interrupted": false
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"speaker": "B",
|
| 47 |
+
"text": "Yeah, it is. I'm trying to get a tan, but I don't want to get too dehydrated, so I'm keeping a bottle of water close by and reapplying sunscreen every hour to avoid any skin damage.",
|
| 48 |
+
"original_text": "Yeah, it is. I'm trying to get a tan, but I don't want to get too dehydrated, so I'm keeping a bottle of water close by and reapplying sunscreen every hour to avoid any skin damage.",
|
| 49 |
+
"start_time": 9.128191918953855,
|
| 50 |
+
"end_time": 19.01989259922596,
|
| 51 |
+
"audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_3_B.wav",
|
| 52 |
+
"silence_duration": 0.46129869194937123,
|
| 53 |
+
"is_interrupted": false
|
| 54 |
+
},
|
| 55 |
+
{
|
| 56 |
+
"speaker": "A",
|
| 57 |
+
"text": "Burnt? Yeah, that's definitely a possibility out here. So what brings you to the beach today? Just wanting to relax?",
|
| 58 |
+
"original_text": "Burnt? Yeah, that's definitely a possibility out here. So what brings you to the beach today? Just wanting to relax?",
|
| 59 |
+
"start_time": 19.43691572474219,
|
| 60 |
+
"end_time": 27.215600531998426,
|
| 61 |
+
"audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_4_A.wav",
|
| 62 |
+
"silence_duration": 0.4170231255162265,
|
| 63 |
+
"is_interrupted": false
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"speaker": "B",
|
| 67 |
+
"text": "Yeah, pretty much. I just finished up my summer classes and needed some time to myself before starting my new job next week.",
|
| 68 |
+
"original_text": "Yeah, pretty much. I just finished up my summer classes and needed some time to myself before starting my new job next week.",
|
| 69 |
+
"start_time": 27.73206790619358,
|
| 70 |
+
"end_time": 34.08272550256547,
|
| 71 |
+
"audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_5_B.wav",
|
| 72 |
+
"silence_duration": 0.5164673741951538,
|
| 73 |
+
"is_interrupted": false
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"speaker": "A",
|
| 77 |
+
"text": "That sounds rough. Are you excited for it? Or",
|
| 78 |
+
"original_text": "That sounds rough. Are you excited for it? Or [interrupt] worried about how you'll balance everything with your personal life and other commitments you might have during this transitional period?",
|
| 79 |
+
"start_time": 34.40566150397062,
|
| 80 |
+
"end_time": 44.703711390591934,
|
| 81 |
+
"audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_6_A.wav",
|
| 82 |
+
"silence_duration": 0.3229360014051523,
|
| 83 |
+
"is_interrupted": true,
|
| 84 |
+
"text_after_interrupt": "worried about how you'll balance everything with your personal life and other commitments you might have during this transitional period?"
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"speaker": "B",
|
| 88 |
+
"text": "Nervous? A little bit of both, honestly. But mostly excited. It should be a good experience. And the pay is great, so that's a plus.",
|
| 89 |
+
"original_text": "Nervous? A little bit of both, honestly. But mostly excited. It should be a good experience. And the pay is great, so that's a plus.",
|
| 90 |
+
"start_time": 37.1456161524967,
|
| 91 |
+
"end_time": 44.564391662700785,
|
| 92 |
+
"audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_7_B.wav",
|
| 93 |
+
"silence_duration": 0.36321869535217244,
|
| 94 |
+
"is_interrupted": false
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"speaker": "A",
|
| 98 |
+
"text": "Definitely. Well, I hope you enjoy the rest of your day here.",
|
| 99 |
+
"original_text": "Definitely. Well, I hope you enjoy the rest of your day here.",
|
| 100 |
+
"start_time": 44.9023552612567,
|
| 101 |
+
"end_time": 48.78008768756056,
|
| 102 |
+
"audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_8_A.wav",
|
| 103 |
+
"silence_duration": 0.33796359855591646,
|
| 104 |
+
"is_interrupted": false
|
| 105 |
+
},
|
| 106 |
+
{
|
| 107 |
+
"speaker": "B",
|
| 108 |
+
"text": "Thanks. You too.",
|
| 109 |
+
"original_text": "Thanks. You too.",
|
| 110 |
+
"start_time": 49.1679089027611,
|
| 111 |
+
"end_time": 50.09670708870214,
|
| 112 |
+
"audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_9_B.wav",
|
| 113 |
+
"silence_duration": 0.38782121520053575,
|
| 114 |
+
"is_interrupted": false
|
| 115 |
+
}
|
| 116 |
+
],
|
| 117 |
+
"gt_score": 1
|
| 118 |
+
},
|
| 119 |
+
"SODA_PROCESSED--train--787791": {
|
| 120 |
+
"original_dialog_id": "",
|
| 121 |
+
"dialog_index": 787791,
|
| 122 |
+
"processed_dialogue": "A: You're welcome. I'm just glad I was able to stop it from happening. \nB: Thank you so much for saving my life. I can't even begin to express how [interrupt] grateful I am for what you did. It means the world to me and I'll never forget your kindness and quick thinking in that moment. \nA: Sorry to jump in, but are you sure you're okay? I mean, physically and emotionally? \nB: I think so, but it's all still a bit of a blur. I don't know what would have happened if you hadn't been there. I'm just glad that you were in the right place at the right time. \nA: Yeah, me too. But seriously, if you need anything—someone to talk to or whatever—don't hesitate to reach out, okay? \nB: I really appreciate that. Thanks again, Antwain. \nA: No problem. Take care.",
|
| 123 |
+
"clean_dialogue": "A: You're welcome. I'm just glad I was able to stop it from happening. \nB:Thank you so much for saving my life. I can't even begin to express how grateful I am for what you did. It means the world to me and I'll never forget your kindness and quick thinking in that moment.\nA: Sorry to jump in, but are you sure you're okay? I mean, physically and emotionally? \nB: I think so, but it's all still a bit of a blur. I don't know what would have happened if you hadn't been there. I'm just glad that you were in the right place at the right time. \nA: Yeah, me too. But seriously, if you need anything—someone to talk to or whatever—don't hesitate to reach out, okay? \nB: I really appreciate that. Thanks again, Antwain. \nA: No problem. Take care.",
|
| 124 |
+
"speaker_tracks": {
|
| 125 |
+
"A": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/A_track.wav",
|
| 126 |
+
"B": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/B_track.wav"
|
| 127 |
+
},
|
| 128 |
+
"error_type": "error_after_interrupt",
|
| 129 |
+
"stereo_audio": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/stereo_dialogue.wav",
|
| 130 |
+
"total_duration": 37.52730158730159,
|
| 131 |
+
"segments": [
|
| 132 |
+
{
|
| 133 |
+
"speaker": "A",
|
| 134 |
+
"text": "You're welcome. I'm just glad I was able to stop it from happening.",
|
| 135 |
+
"original_text": "You're welcome. I'm just glad I was able to stop it from happening.",
|
| 136 |
+
"start_time": 0,
|
| 137 |
+
"end_time": 4.249251700680272,
|
| 138 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_0_A.wav",
|
| 139 |
+
"silence_duration": 0,
|
| 140 |
+
"is_interrupted": false
|
| 141 |
+
},
|
| 142 |
+
{
|
| 143 |
+
"speaker": "B",
|
| 144 |
+
"text": "Thank you so much for saving my life. I can't even begin to express how",
|
| 145 |
+
"original_text": "Thank you so much for saving my life. I can't even begin to express how [interrupt] grateful I am for what you did. It means the world to me and I'll never forget your kindness and quick thinking in that moment.",
|
| 146 |
+
"start_time": 4.756366963799184,
|
| 147 |
+
"end_time": 14.694507553368345,
|
| 148 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_1_B.wav",
|
| 149 |
+
"silence_duration": 0.5071152631189118,
|
| 150 |
+
"is_interrupted": true,
|
| 151 |
+
"text_after_interrupt": "grateful I am for what you did. It means the world to me and I'll never forget your kindness and quick thinking in that moment."
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"speaker": "A",
|
| 155 |
+
"text": "Sorry to jump in, but are you sure you're okay? I mean, physically and emotionally?",
|
| 156 |
+
"original_text": "Sorry to jump in, but are you sure you're okay? I mean, physically and emotionally?",
|
| 157 |
+
"start_time": 8.726979208697143,
|
| 158 |
+
"end_time": 14.357818210964716,
|
| 159 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_2_A.wav",
|
| 160 |
+
"silence_duration": 0.4049084459018305,
|
| 161 |
+
"is_interrupted": false
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"speaker": "B",
|
| 165 |
+
"text": "I think so, but it's all still a bit of a blur. I don't know what would have happened if you hadn't been there. I'm just glad that you were in the right place at the right time.",
|
| 166 |
+
"original_text": "I think so, but it's all still a bit of a blur. I don't know what would have happened if you hadn't been there. I'm just glad that you were in the right place at the right time.",
|
| 167 |
+
"start_time": 14.861085984580113,
|
| 168 |
+
"end_time": 23.649838819047233,
|
| 169 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_3_B.wav",
|
| 170 |
+
"silence_duration": 0.5032677736153957,
|
| 171 |
+
"is_interrupted": false
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"speaker": "A",
|
| 175 |
+
"text": "Yeah, me too. But seriously, if you need anything—someone to talk to or whatever—don't hesitate to reach out, okay?",
|
| 176 |
+
"original_text": "Yeah, me too. But seriously, if you need anything—someone to talk to or whatever—don't hesitate to reach out, okay?",
|
| 177 |
+
"start_time": 24.145193415777634,
|
| 178 |
+
"end_time": 32.515987066571284,
|
| 179 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_4_A.wav",
|
| 180 |
+
"silence_duration": 0.4953545967303996,
|
| 181 |
+
"is_interrupted": false
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"speaker": "B",
|
| 185 |
+
"text": "I really appreciate that. Thanks again, Antwain.",
|
| 186 |
+
"original_text": "I really appreciate that. Thanks again, Antwain.",
|
| 187 |
+
"start_time": 32.97180815148517,
|
| 188 |
+
"end_time": 35.68854284536272,
|
| 189 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_5_B.wav",
|
| 190 |
+
"silence_duration": 0.4558210849138826,
|
| 191 |
+
"is_interrupted": false
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"speaker": "A",
|
| 195 |
+
"text": "No problem. Take care.",
|
| 196 |
+
"original_text": "No problem. Take care.",
|
| 197 |
+
"start_time": 35.99481454512998,
|
| 198 |
+
"end_time": 37.5273315519327,
|
| 199 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_6_A.wav",
|
| 200 |
+
"silence_duration": 0.3062716997672569,
|
| 201 |
+
"is_interrupted": false
|
| 202 |
+
}
|
| 203 |
+
],
|
| 204 |
+
"gt_score": 1
|
| 205 |
+
},
|
| 206 |
+
"SODA_PROCESSED--train--179972": {
|
| 207 |
+
"original_dialog_id": "",
|
| 208 |
+
"dialog_index": 179972,
|
| 209 |
+
"processed_dialogue": "A: So, how did you like the book? \nB: I loved it! The ending was so shocking, I couldn't believe what happened. \nA: Sorry to interrupt, but I just have to ask—did you see that twist with the protagonist coming? I was totally blindsided. \nB: No, I didn't see it coming at all! It was so unexpected. \nA: Yeah, I know. I couldn't put it down. \nB: Me neither. I'm so glad you wanted to read it. \nA: Yeah, I was curious about the protagonist's journey and how it would [interrupt] evolve, especially after that major setback when they had to completely rethink their entire approach to solving the central conflict. \nB: Oh, speaking of the journey, what did you think about that part where the protagonist had to make that impossible choice? It really stuck with me. \nA: It was definitely a rollercoaster ride. There were so many twists and turns. \nB: I know! I didn't see any of them coming. \nA: That's what made it so great. It kept you guessing the whole time. \nB: Definitely. It was a great book. Thanks for lending it to me.",
|
| 210 |
+
"clean_dialogue": "A: So, how did you like the book? \nB: I loved it! The ending was so shocking, I couldn't believe what happened. \nA: Sorry to interrupt, but I just have to ask—did you see that twist with the protagonist coming? I was totally blindsided. \nB: No, I didn't see it coming at all! It was so unexpected. \nA: Yeah, I know. I couldn't put it down. \nB: Me neither. I'm so glad you wanted to read it. \nA:Yeah, I was curious about the protagonist's journey and how it would evolve, especially after that major setback when they had to completely rethink their entire approach to solving the central conflict.\nB: Oh, speaking of the journey, what did you think about that part where the protagonist had to make that impossible choice? It really stuck with me. \nA: It was definitely a rollercoaster ride. There were so many twists and turns. \nB: I know! I didn't see any of them coming. \nA: That's what made it so great. It kept you guessing the whole time. \nB: Definitely. It was a great book. Thanks for lending it to me.",
|
| 211 |
+
"speaker_tracks": {
|
| 212 |
+
"A": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/A_track.wav",
|
| 213 |
+
"B": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/B_track.wav"
|
| 214 |
+
},
|
| 215 |
+
"error_type": "error_after_interrupt",
|
| 216 |
+
"stereo_audio": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/stereo_dialogue.wav",
|
| 217 |
+
"total_duration": 53.57845804988662,
|
| 218 |
+
"segments": [
|
| 219 |
+
{
|
| 220 |
+
"speaker": "A",
|
| 221 |
+
"text": "So, how did you like the book?",
|
| 222 |
+
"original_text": "So, how did you like the book?",
|
| 223 |
+
"start_time": 0,
|
| 224 |
+
"end_time": 1.6950566893424037,
|
| 225 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_0_A.wav",
|
| 226 |
+
"silence_duration": 0,
|
| 227 |
+
"is_interrupted": false
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"speaker": "B",
|
| 231 |
+
"text": "I loved it! The ending was so shocking, I couldn't believe what happened.",
|
| 232 |
+
"original_text": "I loved it! The ending was so shocking, I couldn't believe what happened.",
|
| 233 |
+
"start_time": 2.1792484824735485,
|
| 234 |
+
"end_time": 5.871221271589195,
|
| 235 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_1_B.wav",
|
| 236 |
+
"silence_duration": 0.4841917931311449,
|
| 237 |
+
"is_interrupted": false
|
| 238 |
+
},
|
| 239 |
+
{
|
| 240 |
+
"speaker": "A",
|
| 241 |
+
"text": "Sorry to interrupt, but I just have to ask—did you see that twist with the protagonist coming? I was totally blindsided.",
|
| 242 |
+
"original_text": "Sorry to interrupt, but I just have to ask—did you see that twist with the protagonist coming? I was totally blindsided.",
|
| 243 |
+
"start_time": 6.47038511683308,
|
| 244 |
+
"end_time": 14.504489425223102,
|
| 245 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_2_A.wav",
|
| 246 |
+
"silence_duration": 0.5991638452438857,
|
| 247 |
+
"is_interrupted": false
|
| 248 |
+
},
|
| 249 |
+
{
|
| 250 |
+
"speaker": "B",
|
| 251 |
+
"text": "No, I didn't see it coming at all! It was so unexpected.",
|
| 252 |
+
"original_text": "No, I didn't see it coming at all! It was so unexpected.",
|
| 253 |
+
"start_time": 15.012397119017507,
|
| 254 |
+
"end_time": 18.448950406999366,
|
| 255 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_3_B.wav",
|
| 256 |
+
"silence_duration": 0.507907693794404,
|
| 257 |
+
"is_interrupted": false
|
| 258 |
+
},
|
| 259 |
+
{
|
| 260 |
+
"speaker": "A",
|
| 261 |
+
"text": "Yeah, I know. I couldn't put it down.",
|
| 262 |
+
"original_text": "Yeah, I know. I couldn't put it down.",
|
| 263 |
+
"start_time": 18.875209136594886,
|
| 264 |
+
"end_time": 21.847363331606225,
|
| 265 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_4_A.wav",
|
| 266 |
+
"silence_duration": 0.42625872959552,
|
| 267 |
+
"is_interrupted": false
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"speaker": "B",
|
| 271 |
+
"text": "Me neither. I'm so glad you wanted to read it.",
|
| 272 |
+
"original_text": "Me neither. I'm so glad you wanted to read it.",
|
| 273 |
+
"start_time": 22.440054691555087,
|
| 274 |
+
"end_time": 25.110349476135585,
|
| 275 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_5_B.wav",
|
| 276 |
+
"silence_duration": 0.5926913599488615,
|
| 277 |
+
"is_interrupted": false
|
| 278 |
+
},
|
| 279 |
+
{
|
| 280 |
+
"speaker": "A",
|
| 281 |
+
"text": "Yeah, I was curious about the protagonist's journey and how it would",
|
| 282 |
+
"original_text": "Yeah, I was curious about the protagonist's journey and how it would [interrupt] evolve, especially after that major setback when they had to completely rethink their entire approach to solving the central conflict.",
|
| 283 |
+
"start_time": 25.51803755034393,
|
| 284 |
+
"end_time": 36.89581532812171,
|
| 285 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_6_A.wav",
|
| 286 |
+
"silence_duration": 0.40768807420834613,
|
| 287 |
+
"is_interrupted": true,
|
| 288 |
+
"text_after_interrupt": "evolve, especially after that major setback when they had to completely rethink their entire approach to solving the central conflict."
|
| 289 |
+
},
|
| 290 |
+
{
|
| 291 |
+
"speaker": "B",
|
| 292 |
+
"text": "Oh, speaking of the journey, what did you think about that part where the protagonist had to make that impossible choice? It really stuck with me.",
|
| 293 |
+
"original_text": "Oh, speaking of the journey, what did you think about that part where the protagonist had to make that impossible choice? It really stuck with me.",
|
| 294 |
+
"start_time": 29.790509205672727,
|
| 295 |
+
"end_time": 37.429874285037805,
|
| 296 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_7_B.wav",
|
| 297 |
+
"silence_duration": 0.32835611460902553,
|
| 298 |
+
"is_interrupted": false
|
| 299 |
+
},
|
| 300 |
+
{
|
| 301 |
+
"speaker": "A",
|
| 302 |
+
"text": "It was definitely a rollercoaster ride. There were so many twists and turns.",
|
| 303 |
+
"original_text": "It was definitely a rollercoaster ride. There were so many twists and turns.",
|
| 304 |
+
"start_time": 37.91219711578734,
|
| 305 |
+
"end_time": 42.405258340277136,
|
| 306 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_8_A.wav",
|
| 307 |
+
"silence_duration": 0.4823228307495384,
|
| 308 |
+
"is_interrupted": false
|
| 309 |
+
},
|
| 310 |
+
{
|
| 311 |
+
"speaker": "B",
|
| 312 |
+
"text": "I know! I didn't see any of them coming.",
|
| 313 |
+
"original_text": "I know! I didn't see any of them coming.",
|
| 314 |
+
"start_time": 42.860468420817675,
|
| 315 |
+
"end_time": 45.08958406707618,
|
| 316 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_9_B.wav",
|
| 317 |
+
"silence_duration": 0.4552100805405374,
|
| 318 |
+
"is_interrupted": false
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
"speaker": "A",
|
| 322 |
+
"text": "That's what made it so great. It kept you guessing the whole time.",
|
| 323 |
+
"original_text": "That's what made it so great. It kept you guessing the whole time.",
|
| 324 |
+
"start_time": 45.679186523390214,
|
| 325 |
+
"end_time": 49.394379267154385,
|
| 326 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_10_A.wav",
|
| 327 |
+
"silence_duration": 0.5896024563140343,
|
| 328 |
+
"is_interrupted": false
|
| 329 |
+
},
|
| 330 |
+
{
|
| 331 |
+
"speaker": "B",
|
| 332 |
+
"text": "Definitely. It was a great book. Thanks for lending it to me.",
|
| 333 |
+
"original_text": "Definitely. It was a great book. Thanks for lending it to me.",
|
| 334 |
+
"start_time": 49.70074891577286,
|
| 335 |
+
"end_time": 53.57848134207672,
|
| 336 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_11_B.wav",
|
| 337 |
+
"silence_duration": 0.3063696486184793,
|
| 338 |
+
"is_interrupted": false
|
| 339 |
+
}
|
| 340 |
+
],
|
| 341 |
+
"gt_score": 1
|
| 342 |
+
},
|
| 343 |
+
"SODA_PROCESSED--train--715956": {
|
| 344 |
+
"original_dialog_id": "",
|
| 345 |
+
"dialog_index": 715956,
|
| 346 |
+
"processed_dialogue": "A: Look, I know that I messed up. I was caught with heroin and I'm facing some serious charges. But, I really need your [interrupt] help to find a way out of this because I'm terrified of what might happen if I don't get proper legal representation and support during this difficult time.\nB: Aadya, we've been over this already. The evidence against you is pretty damning. Plus, you have a history of drug use. I don't think there's much that can be done to help you at this point.\nA: But you're my lawyer! You're supposed to help me!\nB: Aadya, I'm doing everything that I can. But, realistically, the chances of you getting out of this are pretty slim. You need to prepare yourself for the possibility of a conviction and think about how you'll handle it.\nA: But I can't go to prison! I'll lose my job, my apartment, everything!\nB: Aadya, you need to calm down. Getting upset isn't going to help anything.\nA: Fine. But, you have to promise me that you'll do everything you can to help me.\nB: I promise.",
|
| 347 |
+
"clean_dialogue": "A:Look, I know that I messed up. I was caught with heroin and I'm facing some serious charges. But, I really need your help to find a way out of this because I'm terrified of what might happen if I don't get proper legal representation and support during this difficult time.\nB: Aadya, we've been over this already. The evidence against you is pretty damning. Plus, you have a history of drug use. I don't think there's much that can be done to help you at this point.\nA: But you're my lawyer! You're supposed to help me!\nB: Aadya, I'm doing everything that I can. But, realistically, the chances of you getting out of this are pretty slim. You need to prepare yourself for the possibility of a conviction and think about how you'll handle it.\nA: But I can't go to prison! I'll lose my job, my apartment, everything!\nB: Aadya, you need to calm down. Getting upset isn't going to help anything.\nA: Fine. But, you have to promise me that you'll do everything you can to help me.\nB: I promise.",
|
| 348 |
+
"speaker_tracks": {
|
| 349 |
+
"A": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/A_track.wav",
|
| 350 |
+
"B": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/B_track.wav"
|
| 351 |
+
},
|
| 352 |
+
"error_type": "error_after_interrupt",
|
| 353 |
+
"stereo_audio": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/stereo_dialogue.wav",
|
| 354 |
+
"total_duration": 49.52126984126984,
|
| 355 |
+
"segments": [
|
| 356 |
+
{
|
| 357 |
+
"speaker": "A",
|
| 358 |
+
"text": "Look, I know that I messed up. I was caught with heroin and I'm facing some serious charges. But, I really need your",
|
| 359 |
+
"original_text": "Look, I know that I messed up. I was caught with heroin and I'm facing some serious charges. But, I really need your [interrupt] help to find a way out of this because I'm terrified of what might happen if I don't get proper legal representation and support during this difficult time.",
|
| 360 |
+
"start_time": 0,
|
| 361 |
+
"end_time": 16.579047619047618,
|
| 362 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_0_A.wav",
|
| 363 |
+
"silence_duration": 0,
|
| 364 |
+
"is_interrupted": true,
|
| 365 |
+
"text_after_interrupt": "help to find a way out of this because I'm terrified of what might happen if I don't get proper legal representation and support during this difficult time."
|
| 366 |
+
},
|
| 367 |
+
{
|
| 368 |
+
"speaker": "B",
|
| 369 |
+
"text": "Aadya, we've been over this already. The evidence against you is pretty damning. Plus, you have a history of drug use. I don't think there's much that can be done to help you at this point.",
|
| 370 |
+
"original_text": "Aadya, we've been over this already. The evidence against you is pretty damning. Plus, you have a history of drug use. I don't think there's much that can be done to help you at this point.",
|
| 371 |
+
"start_time": 8.510113378684807,
|
| 372 |
+
"end_time": 18.36698412698413,
|
| 373 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_1_B.wav",
|
| 374 |
+
"silence_duration": 0.4899749375576017,
|
| 375 |
+
"is_interrupted": false
|
| 376 |
+
},
|
| 377 |
+
{
|
| 378 |
+
"speaker": "A",
|
| 379 |
+
"text": "But you're my lawyer! You're supposed to help me!",
|
| 380 |
+
"original_text": "But you're my lawyer! You're supposed to help me!",
|
| 381 |
+
"start_time": 18.846747434390966,
|
| 382 |
+
"end_time": 21.37772249108031,
|
| 383 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_2_A.wav",
|
| 384 |
+
"silence_duration": 0.4797633074068387,
|
| 385 |
+
"is_interrupted": false
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
"speaker": "B",
|
| 389 |
+
"text": "Aadya, I'm doing everything that I can. But, realistically, the chances of you getting out of this are pretty slim. You need to prepare yourself for the possibility of a conviction and think about how you'll handle it.",
|
| 390 |
+
"original_text": "Aadya, I'm doing everything that I can. But, realistically, the chances of you getting out of this are pretty slim. You need to prepare yourself for the possibility of a conviction and think about how you'll handle it.",
|
| 391 |
+
"start_time": 21.881120947184385,
|
| 392 |
+
"end_time": 33.51431822609595,
|
| 393 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_3_B.wav",
|
| 394 |
+
"silence_duration": 0.5033984561040751,
|
| 395 |
+
"is_interrupted": false
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"speaker": "A",
|
| 399 |
+
"text": "But I can't go to prison! I'll lose my job, my apartment, everything!",
|
| 400 |
+
"original_text": "But I can't go to prison! I'll lose my job, my apartment, everything!",
|
| 401 |
+
"start_time": 34.047335561433606,
|
| 402 |
+
"end_time": 38.48234689930209,
|
| 403 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_4_A.wav",
|
| 404 |
+
"silence_duration": 0.5330173353376504,
|
| 405 |
+
"is_interrupted": false
|
| 406 |
+
},
|
| 407 |
+
{
|
| 408 |
+
"speaker": "B",
|
| 409 |
+
"text": "Aadya, you need to calm down. Getting upset isn't going to help anything.",
|
| 410 |
+
"original_text": "Aadya, you need to calm down. Getting upset isn't going to help anything.",
|
| 411 |
+
"start_time": 38.89720479711025,
|
| 412 |
+
"end_time": 43.39026602160004,
|
| 413 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_5_B.wav",
|
| 414 |
+
"silence_duration": 0.4148578978081613,
|
| 415 |
+
"is_interrupted": false
|
| 416 |
+
},
|
| 417 |
+
{
|
| 418 |
+
"speaker": "A",
|
| 419 |
+
"text": "Fine. But, you have to promise me that you'll do everything you can to help me.",
|
| 420 |
+
"original_text": "Fine. But, you have to promise me that you'll do everything you can to help me.",
|
| 421 |
+
"start_time": 43.92319932038778,
|
| 422 |
+
"end_time": 48.27694081698642,
|
| 423 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_6_A.wav",
|
| 424 |
+
"silence_duration": 0.5329332987877419,
|
| 425 |
+
"is_interrupted": false
|
| 426 |
+
},
|
| 427 |
+
{
|
| 428 |
+
"speaker": "B",
|
| 429 |
+
"text": "I promise.",
|
| 430 |
+
"original_text": "I promise.",
|
| 431 |
+
"start_time": 48.62731544236006,
|
| 432 |
+
"end_time": 49.52128369632831,
|
| 433 |
+
"audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_7_B.wav",
|
| 434 |
+
"silence_duration": 0.3503746253736393,
|
| 435 |
+
"is_interrupted": false
|
| 436 |
+
}
|
| 437 |
+
],
|
| 438 |
+
"gt_score": 1
|
| 439 |
+
},
|
| 440 |
+
"SODA_PROCESSED--train--740576": {
|
| 441 |
+
"original_text": "A: Good morning, Mr. Nguyen! I hope you're doing well today.\nB: I'm doing well, thank you. How are you?\nA: I'm feeling great today! I have a lot of energy and I'm excited to [interrupt] tackle some new projects and challenges that will help us improve our workflow and achieve better results for our clients.\nB: Sorry to interrupt, but I wanted to ask if there's anything specific you're looking forward to today?\nA: I was going to say I'm excited to start my day. Actually, I'm looking forward to a team meeting we have later. I love working here. It's a great environment and the people are really supportive and collaborative, always willing to share their expertise and help each other grow professionally.\nB: I'm glad to hear that! Speaking of the team, do you think we should plan more team-building activities to maintain this positive environment?\nA: That's a great idea! We could definitely benefit from more team-building activities. We're happy to have you on our team.",
|
| 442 |
+
"cleaned_text": "A: Good morning, Mr. Nguyen! I hope you're doing well today.\nB: I'm doing well, thank you. How are you?\nA:I'm feeling great today! I have a lot of energy and I'm excited to tackle some new projects and challenges that will help us improve our workflow and achieve better results for our clients.\nB: Sorry to interrupt, but I wanted to ask if there's anything specific you're looking forward to today?\nA: I was going to say I'm excited to start my day. Actually, I'm looking forward to a team meeting we have later. I love working here. It's a great environment and the people are really supportive and collaborative, always willing to share their expertise and help each other grow professionally.\nB: I'm glad to hear that! Speaking of the team, do you think we should plan more team-building activities to maintain this positive environment?\nA: That's a great idea! We could definitely benefit from more team-building activities. We're happy to have you on our team.",
|
| 443 |
+
"total_duration": 49.437278911564626,
|
| 444 |
+
"stereo_audio": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/stereo_dialogue.wav",
|
| 445 |
+
"speaker_tracks": {
|
| 446 |
+
"A": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/A_track.wav",
|
| 447 |
+
"B": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/B_track.wav"
|
| 448 |
+
},
|
| 449 |
+
"error_type": "error_after_interrupt",
|
| 450 |
+
"segments": [
|
| 451 |
+
{
|
| 452 |
+
"speaker": "A",
|
| 453 |
+
"text": "Good morning, Mr. Nguyen! I hope you're doing well today.",
|
| 454 |
+
"original_text": "Good morning, Mr. Nguyen! I hope you're doing well today.",
|
| 455 |
+
"start_time": 0,
|
| 456 |
+
"end_time": 3.332063492063492,
|
| 457 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_0_A.wav",
|
| 458 |
+
"silence_duration": 0,
|
| 459 |
+
"is_interrupted": false
|
| 460 |
+
},
|
| 461 |
+
{
|
| 462 |
+
"speaker": "B",
|
| 463 |
+
"text": "I'm doing well, thank you. How are you?",
|
| 464 |
+
"original_text": "I'm doing well, thank you. How are you?",
|
| 465 |
+
"start_time": 3.7838731632362803,
|
| 466 |
+
"end_time": 5.583419648497051,
|
| 467 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_1_B.wav",
|
| 468 |
+
"silence_duration": 0.4518096711727882,
|
| 469 |
+
"is_interrupted": false
|
| 470 |
+
},
|
| 471 |
+
{
|
| 472 |
+
"speaker": "A",
|
| 473 |
+
"text": "I'm feeling great today! I have a lot of energy and I'm excited to",
|
| 474 |
+
"original_text": "I'm feeling great today! I have a lot of energy and I'm excited to [interrupt] tackle some new projects and challenges that will help us improve our workflow and achieve better results for our clients.",
|
| 475 |
+
"start_time": 5.88797031081498,
|
| 476 |
+
"end_time": 16.96388867816192,
|
| 477 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_2_A.wav",
|
| 478 |
+
"silence_duration": 0.30455066231792893,
|
| 479 |
+
"is_interrupted": true,
|
| 480 |
+
"text_after_interrupt": "tackle some new projects and challenges that will help us improve our workflow and achieve better results for our clients."
|
| 481 |
+
},
|
| 482 |
+
{
|
| 483 |
+
"speaker": "B",
|
| 484 |
+
"text": "Sorry to interrupt, but I wanted to ask if there's anything specific you're looking forward to today?",
|
| 485 |
+
"original_text": "Sorry to interrupt, but I wanted to ask if there's anything specific you're looking forward to today?",
|
| 486 |
+
"start_time": 10.485521331223143,
|
| 487 |
+
"end_time": 16.104750356166456,
|
| 488 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_3_B.wav",
|
| 489 |
+
"silence_duration": 0.587489668114177,
|
| 490 |
+
"is_interrupted": false
|
| 491 |
+
},
|
| 492 |
+
{
|
| 493 |
+
"speaker": "A",
|
| 494 |
+
"text": "I was going to say I'm excited to start my day. Actually, I'm looking forward to a team meeting we have later. I love working here. It's a great environment and the people are really supportive and collaborative, always willing to share their expertise and help each other grow professionally.",
|
| 495 |
+
"original_text": "I was going to say I'm excited to start my day. Actually, I'm looking forward to a team meeting we have later. I love working here. It's a great environment and the people are really supportive and collaborative, always willing to share their expertise and help each other grow professionally.",
|
| 496 |
+
"start_time": 17.385624216961087,
|
| 497 |
+
"end_time": 33.94145188136018,
|
| 498 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_4_A.wav",
|
| 499 |
+
"silence_duration": 0.4217355387991674,
|
| 500 |
+
"is_interrupted": false
|
| 501 |
+
},
|
| 502 |
+
{
|
| 503 |
+
"speaker": "B",
|
| 504 |
+
"text": "I'm glad to hear that! Speaking of the team, do you think we should plan more team-building activities to maintain this positive environment?",
|
| 505 |
+
"original_text": "I'm glad to hear that! Speaking of the team, do you think we should plan more team-building activities to maintain this positive environment?",
|
| 506 |
+
"start_time": 34.39980783470558,
|
| 507 |
+
"end_time": 41.74892348096408,
|
| 508 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_5_B.wav",
|
| 509 |
+
"silence_duration": 0.4583559533453947,
|
| 510 |
+
"is_interrupted": false
|
| 511 |
+
},
|
| 512 |
+
{
|
| 513 |
+
"speaker": "A",
|
| 514 |
+
"text": "That's a great idea! We could definitely benefit from more team-building activities. We're happy to have you on our team.",
|
| 515 |
+
"original_text": "That's a great idea! We could definitely benefit from more team-building activities. We're happy to have you on our team.",
|
| 516 |
+
"start_time": 42.285572803275116,
|
| 517 |
+
"end_time": 49.437318835021145,
|
| 518 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_6_A.wav",
|
| 519 |
+
"silence_duration": 0.5366493223110326,
|
| 520 |
+
"is_interrupted": false
|
| 521 |
+
}
|
| 522 |
+
]
|
| 523 |
+
},
|
| 524 |
+
"SODA_PROCESSED--train--836018": {
|
| 525 |
+
"original_text": "A: Hey Ceanna, I saw that you were doing the reports for the group project. Do you want me to help you with [interrupt] organizing the sections or proofreading? I've got some experience with formatting academic papers and making sure all the citations are properly aligned.\nB: Actually, I could use some help with the data analysis part. It's a bit overwhelming.\nA: Sure, I can take care of that. So what do you think of the project so far?\nB: It's interesting. I'm learning a lot about different cultures and how they influence people's daily lives, from their eating habits to their social interactions and even their work-life balance perspectives.\nA: Speaking of cultures, did you notice how the traditions vary even within the same country? It's amazing how diverse it can be.\nB: Yeah, definitely. It's fascinating.",
|
| 526 |
+
"cleaned_text": "A:Hey Ceanna, I saw that you were doing the reports for the group project. Do you want me to help you with organizing the sections or proofreading? I've got some experience with formatting academic papers and making sure all the citations are properly aligned.\nB: Actually, I could use some help with the data analysis part. It's a bit overwhelming.\nA: Sure, I can take care of that. So what do you think of the project so far?\nB: It's interesting. I'm learning a lot about different cultures and how they influence people's daily lives, from their eating habits to their social interactions and even their work-life balance perspectives.\nA: Speaking of cultures, did you notice how the traditions vary even within the same country? It's amazing how diverse it can be.\nB: Yeah, definitely. It's fascinating.",
|
| 527 |
+
"total_duration": 42.34984126984127,
|
| 528 |
+
"stereo_audio": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/stereo_dialogue.wav",
|
| 529 |
+
"speaker_tracks": {
|
| 530 |
+
"A": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/A_track.wav",
|
| 531 |
+
"B": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/B_track.wav"
|
| 532 |
+
},
|
| 533 |
+
"error_type": "error_after_interrupt",
|
| 534 |
+
"segments": [
|
| 535 |
+
{
|
| 536 |
+
"speaker": "A",
|
| 537 |
+
"text": "Hey Ceanna, I saw that you were doing the reports for the group project. Do you want me to help you with",
|
| 538 |
+
"original_text": "Hey Ceanna, I saw that you were doing the reports for the group project. Do you want me to help you with [interrupt] organizing the sections or proofreading? I've got some experience with formatting academic papers and making sure all the citations are properly aligned.",
|
| 539 |
+
"start_time": 0,
|
| 540 |
+
"end_time": 15.011700680272108,
|
| 541 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/temp/line_0_A.wav",
|
| 542 |
+
"silence_duration": 0,
|
| 543 |
+
"is_interrupted": true,
|
| 544 |
+
"text_after_interrupt": "organizing the sections or proofreading? I've got some experience with formatting academic papers and making sure all the citations are properly aligned."
|
| 545 |
+
},
|
| 546 |
+
{
|
| 547 |
+
"speaker": "B",
|
| 548 |
+
"text": "Actually, I could use some help with the data analysis part. It's a bit overwhelming.",
|
| 549 |
+
"original_text": "Actually, I could use some help with the data analysis part. It's a bit overwhelming.",
|
| 550 |
+
"start_time": 6.176507936507937,
|
| 551 |
+
"end_time": 11.250068027210885,
|
| 552 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/temp/line_1_B.wav",
|
| 553 |
+
"silence_duration": 0.5190912573415952,
|
| 554 |
+
"is_interrupted": false
|
| 555 |
+
},
|
| 556 |
+
{
|
| 557 |
+
"speaker": "A",
|
| 558 |
+
"text": "Sure, I can take care of that. So what do you think of the project so far?",
|
| 559 |
+
"original_text": "Sure, I can take care of that. So what do you think of the project so far?",
|
| 560 |
+
"start_time": 15.60657282124108,
|
| 561 |
+
"end_time": 19.937094363191193,
|
| 562 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/temp/line_2_A.wav",
|
| 563 |
+
"silence_duration": 0.5948721409689715,
|
| 564 |
+
"is_interrupted": false
|
| 565 |
+
},
|
| 566 |
+
{
|
| 567 |
+
"speaker": "B",
|
| 568 |
+
"text": "It's interesting. I'm learning a lot about different cultures and how they influence people's daily lives, from their eating habits to their social interactions and even their work-life balance perspectives.",
|
| 569 |
+
"original_text": "It's interesting. I'm learning a lot about different cultures and how they influence people's daily lives, from their eating habits to their social interactions and even their work-life balance perspectives.",
|
| 570 |
+
"start_time": 20.306213172030862,
|
| 571 |
+
"end_time": 30.476553308085286,
|
| 572 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/temp/line_3_B.wav",
|
| 573 |
+
"silence_duration": 0.36911880883966963,
|
| 574 |
+
"is_interrupted": false
|
| 575 |
+
},
|
| 576 |
+
{
|
| 577 |
+
"speaker": "A",
|
| 578 |
+
"text": "Speaking of cultures, did you notice how the traditions vary even within the same country? It's amazing how diverse it can be.",
|
| 579 |
+
"original_text": "Speaking of cultures, did you notice how the traditions vary even within the same country? It's amazing how diverse it can be.",
|
| 580 |
+
"start_time": 30.848617682402736,
|
| 581 |
+
"end_time": 39.10331155995375,
|
| 582 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/temp/line_4_A.wav",
|
| 583 |
+
"silence_duration": 0.3720643743174508,
|
| 584 |
+
"is_interrupted": false
|
| 585 |
+
},
|
| 586 |
+
{
|
| 587 |
+
"speaker": "B",
|
| 588 |
+
"text": "Yeah, definitely. It's fascinating.",
|
| 589 |
+
"original_text": "Yeah, definitely. It's fascinating.",
|
| 590 |
+
"start_time": 39.435776463870354,
|
| 591 |
+
"end_time": 42.34988077226038,
|
| 592 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/temp/line_5_B.wav",
|
| 593 |
+
"silence_duration": 0.3324649039166007,
|
| 594 |
+
"is_interrupted": false
|
| 595 |
+
}
|
| 596 |
+
]
|
| 597 |
+
},
|
| 598 |
+
"SODA_PROCESSED--train--771154": {
|
| 599 |
+
"original_text": "A: Hey, Mom. Can I go to Tim's house today? \nB: No, you can't go to Tim's house today. \nA: Why [interrupt] can't I go? I finished all my chores and even helped with the dishes after dinner last night, so I really think I deserve this. \nB: Because I said so. And before you ask again, you need to focus on your homework. \nA: But I don't want to do my homework. \nB: Well, you have to do it anyway. Now go and get started immediately because your teacher specifically mentioned you need to improve your math skills this semester. \nA: Wait, can't I just finish it later? I promise I'll get it done before bed. \nB: No, you need to do it now. Once it's done, then we can talk about other plans.",
|
| 600 |
+
"cleaned_text": "A: Hey, Mom. Can I go to Tim's house today? \nB: No, you can't go to Tim's house today. \nA:Why can't I go? I finished all my chores and even helped with the dishes after dinner last night, so I really think I deserve this.\nB: Because I said so. And before you ask again, you need to focus on your homework. \nA: But I don't want to do my homework. \nB: Well, you have to do it anyway. Now go and get started immediately because your teacher specifically mentioned you need to improve your math skills this semester. \nA: Wait, can't I just finish it later? I promise I'll get it done before bed. \nB: No, you need to do it now. Once it's done, then we can talk about other plans.",
|
| 601 |
+
"total_duration": 35.76784580498866,
|
| 602 |
+
"stereo_audio": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/stereo_dialogue.wav",
|
| 603 |
+
"speaker_tracks": {
|
| 604 |
+
"A": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/A_track.wav",
|
| 605 |
+
"B": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/B_track.wav"
|
| 606 |
+
},
|
| 607 |
+
"error_type": "error_after_interrupt",
|
| 608 |
+
"segments": [
|
| 609 |
+
{
|
| 610 |
+
"speaker": "A",
|
| 611 |
+
"text": "Hey, Mom. Can I go to Tim's house today?",
|
| 612 |
+
"original_text": "Hey, Mom. Can I go to Tim's house today?",
|
| 613 |
+
"start_time": 0,
|
| 614 |
+
"end_time": 3.5294331065759637,
|
| 615 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_0_A.wav",
|
| 616 |
+
"silence_duration": 0,
|
| 617 |
+
"is_interrupted": false
|
| 618 |
+
},
|
| 619 |
+
{
|
| 620 |
+
"speaker": "B",
|
| 621 |
+
"text": "No, you can't go to Tim's house today.",
|
| 622 |
+
"original_text": "No, you can't go to Tim's house today.",
|
| 623 |
+
"start_time": 3.9899851353219105,
|
| 624 |
+
"end_time": 6.126220962986309,
|
| 625 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_1_B.wav",
|
| 626 |
+
"silence_duration": 0.4605520287459467,
|
| 627 |
+
"is_interrupted": false
|
| 628 |
+
},
|
| 629 |
+
{
|
| 630 |
+
"speaker": "A",
|
| 631 |
+
"text": "Why",
|
| 632 |
+
"original_text": "Why [interrupt] can't I go? I finished all my chores and even helped with the dishes after dinner last night, so I really think I deserve this.",
|
| 633 |
+
"start_time": 6.4787876256667465,
|
| 634 |
+
"end_time": 14.652211661947927,
|
| 635 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_2_A.wav",
|
| 636 |
+
"silence_duration": 0.3525666626804373,
|
| 637 |
+
"is_interrupted": true,
|
| 638 |
+
"text_after_interrupt": "can't I go? I finished all my chores and even helped with the dishes after dinner last night, so I really think I deserve this."
|
| 639 |
+
},
|
| 640 |
+
{
|
| 641 |
+
"speaker": "B",
|
| 642 |
+
"text": "Because I said so. And before you ask again, you need to focus on your homework.",
|
| 643 |
+
"original_text": "Because I said so. And before you ask again, you need to focus on your homework.",
|
| 644 |
+
"start_time": 7.210216197095318,
|
| 645 |
+
"end_time": 11.889037058773322,
|
| 646 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_3_B.wav",
|
| 647 |
+
"silence_duration": 0.4183677243140269,
|
| 648 |
+
"is_interrupted": false
|
| 649 |
+
},
|
| 650 |
+
{
|
| 651 |
+
"speaker": "A",
|
| 652 |
+
"text": "But I don't want to do my homework.",
|
| 653 |
+
"original_text": "But I don't want to do my homework.",
|
| 654 |
+
"start_time": 15.159162983353092,
|
| 655 |
+
"end_time": 17.074809241856492,
|
| 656 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_4_A.wav",
|
| 657 |
+
"silence_duration": 0.5069513214051653,
|
| 658 |
+
"is_interrupted": false
|
| 659 |
+
},
|
| 660 |
+
{
|
| 661 |
+
"speaker": "B",
|
| 662 |
+
"text": "Well, you have to do it anyway. Now go and get started immediately because your teacher specifically mentioned you need to improve your math skills this semester.",
|
| 663 |
+
"original_text": "Well, you have to do it anyway. Now go and get started immediately because your teacher specifically mentioned you need to improve your math skills this semester.",
|
| 664 |
+
"start_time": 17.6716136549098,
|
| 665 |
+
"end_time": 25.763767849921138,
|
| 666 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_5_B.wav",
|
| 667 |
+
"silence_duration": 0.5968044130533094,
|
| 668 |
+
"is_interrupted": false
|
| 669 |
+
},
|
| 670 |
+
{
|
| 671 |
+
"speaker": "A",
|
| 672 |
+
"text": "Wait, can't I just finish it later? I promise I'll get it done before bed.",
|
| 673 |
+
"original_text": "Wait, can't I just finish it later? I promise I'll get it done before bed.",
|
| 674 |
+
"start_time": 26.149694131743242,
|
| 675 |
+
"end_time": 31.02588460793372,
|
| 676 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_6_A.wav",
|
| 677 |
+
"silence_duration": 0.38592628182210614,
|
| 678 |
+
"is_interrupted": false
|
| 679 |
+
},
|
| 680 |
+
{
|
| 681 |
+
"speaker": "B",
|
| 682 |
+
"text": "No, you need to do it now. Once it's done, then we can talk about other plans.",
|
| 683 |
+
"original_text": "No, you need to do it now. Once it's done, then we can talk about other plans.",
|
| 684 |
+
"start_time": 31.518621255026567,
|
| 685 |
+
"end_time": 35.767872955706835,
|
| 686 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_7_B.wav",
|
| 687 |
+
"silence_duration": 0.49273664709284837,
|
| 688 |
+
"is_interrupted": false
|
| 689 |
+
}
|
| 690 |
+
]
|
| 691 |
+
},
|
| 692 |
+
"SODA_PROCESSED--train--1794": {
|
| 693 |
+
"original_text": "A: Hey, Mom. \nB: Hey, Moriah. What's up? \nA: Not much. Just hanging out in my room. \nB: That's good. I'm glad you're keeping busy. \nA: Yeah, I'm just trying to stay out of [interrupt] everyone's way and focus on my own things because I've been feeling a bit overwhelmed with school and social stuff lately. \nB: Trouble? Is everything okay? I mean, you know you can always talk to me if something's bothering you. \nA: I'm just kind of going through some stuff right now. \nB: Well, your father and I were just talking about how we need to have a serious talk with you about some things that have been going on around the house and how you've been feeling lately because we've noticed some changes in your behavior and we're genuinely concerned about your well-being. \nA: Wait, is this about me staying out late last weekend? \nB: Not just that, but it's part of it. We've also noticed you've been acting a bit differently lately, and we're just wondering if everything is okay with you. \nA: I don't know, Mom. Like I said, I'm just dealing with some stuff. \nB: Okay. Well, if you ever want to talk about anything, we're here for you. We love you, Moriah. \nA: I love you too, Mom.",
|
| 694 |
+
"cleaned_text": "A: Hey, Mom. \nB: Hey, Moriah. What's up? \nA: Not much. Just hanging out in my room. \nB: That's good. I'm glad you're keeping busy. \nA:Yeah, I'm just trying to stay out of everyone's way and focus on my own things because I've been feeling a bit overwhelmed with school and social stuff lately.\nB: Trouble? Is everything okay? I mean, you know you can always talk to me if something's bothering you. \nA: I'm just kind of going through some stuff right now. \nB: Well, your father and I were just talking about how we need to have a serious talk with you about some things that have been going on around the house and how you've been feeling lately because we've noticed some changes in your behavior and we're genuinely concerned about your well-being. \nA: Wait, is this about me staying out late last weekend? \nB: Not just that, but it's part of it. We've also noticed you've been acting a bit differently lately, and we're just wondering if everything is okay with you. \nA: I don't know, Mom. Like I said, I'm just dealing with some stuff. \nB: Okay. Well, if you ever want to talk about anything, we're here for you. We love you, Moriah. \nA: I love you too, Mom.",
|
| 695 |
+
"total_duration": 57.99024943310658,
|
| 696 |
+
"stereo_audio": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/stereo_dialogue.wav",
|
| 697 |
+
"speaker_tracks": {
|
| 698 |
+
"A": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/A_track.wav",
|
| 699 |
+
"B": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/B_track.wav"
|
| 700 |
+
},
|
| 701 |
+
"error_type": "error_after_interrupt",
|
| 702 |
+
"segments": [
|
| 703 |
+
{
|
| 704 |
+
"speaker": "A",
|
| 705 |
+
"text": "Hey, Mom.",
|
| 706 |
+
"original_text": "Hey, Mom.",
|
| 707 |
+
"start_time": 0,
|
| 708 |
+
"end_time": 0.8591383219954648,
|
| 709 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_0_A.wav",
|
| 710 |
+
"silence_duration": 0,
|
| 711 |
+
"is_interrupted": false
|
| 712 |
+
},
|
| 713 |
+
{
|
| 714 |
+
"speaker": "B",
|
| 715 |
+
"text": "Hey, Moriah. What's up?",
|
| 716 |
+
"original_text": "Hey, Moriah. What's up?",
|
| 717 |
+
"start_time": 1.2689805234753475,
|
| 718 |
+
"end_time": 2.7782775756295424,
|
| 719 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_1_B.wav",
|
| 720 |
+
"silence_duration": 0.4098422014798827,
|
| 721 |
+
"is_interrupted": false
|
| 722 |
+
},
|
| 723 |
+
{
|
| 724 |
+
"speaker": "A",
|
| 725 |
+
"text": "Not much. Just hanging out in my room.",
|
| 726 |
+
"original_text": "Not much. Just hanging out in my room.",
|
| 727 |
+
"start_time": 3.2528527196865094,
|
| 728 |
+
"end_time": 5.505188320593539,
|
| 729 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_2_A.wav",
|
| 730 |
+
"silence_duration": 0.47457514405696677,
|
| 731 |
+
"is_interrupted": false
|
| 732 |
+
},
|
| 733 |
+
{
|
| 734 |
+
"speaker": "B",
|
| 735 |
+
"text": "That's good. I'm glad you're keeping busy.",
|
| 736 |
+
"original_text": "That's good. I'm glad you're keeping busy.",
|
| 737 |
+
"start_time": 6.047417085120735,
|
| 738 |
+
"end_time": 8.520342255188762,
|
| 739 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_3_B.wav",
|
| 740 |
+
"silence_duration": 0.5422287645271964,
|
| 741 |
+
"is_interrupted": false
|
| 742 |
+
},
|
| 743 |
+
{
|
| 744 |
+
"speaker": "A",
|
| 745 |
+
"text": "Yeah, I'm just trying to stay out of",
|
| 746 |
+
"original_text": "Yeah, I'm just trying to stay out of [interrupt] everyone's way and focus on my own things because I've been feeling a bit overwhelmed with school and social stuff lately.",
|
| 747 |
+
"start_time": 8.88750351109664,
|
| 748 |
+
"end_time": 18.059385597264438,
|
| 749 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_4_A.wav",
|
| 750 |
+
"silence_duration": 0.3671612559078772,
|
| 751 |
+
"is_interrupted": true,
|
| 752 |
+
"text_after_interrupt": "everyone's way and focus on my own things because I've been feeling a bit overwhelmed with school and social stuff lately."
|
| 753 |
+
},
|
| 754 |
+
{
|
| 755 |
+
"speaker": "B",
|
| 756 |
+
"text": "Trouble? Is everything okay? I mean, you know you can always talk to me if something's bothering you.",
|
| 757 |
+
"original_text": "Trouble? Is everything okay? I mean, you know you can always talk to me if something's bothering you.",
|
| 758 |
+
"start_time": 11.697118023568294,
|
| 759 |
+
"end_time": 18.2915851437497,
|
| 760 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_5_B.wav",
|
| 761 |
+
"silence_duration": 0.32519714638310315,
|
| 762 |
+
"is_interrupted": false
|
| 763 |
+
},
|
| 764 |
+
{
|
| 765 |
+
"speaker": "A",
|
| 766 |
+
"text": "I'm just kind of going through some stuff right now.",
|
| 767 |
+
"original_text": "I'm just kind of going through some stuff right now.",
|
| 768 |
+
"start_time": 18.62204195980515,
|
| 769 |
+
"end_time": 21.396826540304016,
|
| 770 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_6_A.wav",
|
| 771 |
+
"silence_duration": 0.3304568160554501,
|
| 772 |
+
"is_interrupted": false
|
| 773 |
+
},
|
| 774 |
+
{
|
| 775 |
+
"speaker": "B",
|
| 776 |
+
"text": "Well, your father and I were just talking about how we need to have a serious talk with you about some things that have been going on around the house and how you've been feeling lately because we've noticed some changes in your behavior and we're genuinely concerned about your well-being.",
|
| 777 |
+
"original_text": "Well, your father and I were just talking about how we need to have a serious talk with you about some things that have been going on around the house and how you've been feeling lately because we've noticed some changes in your behavior and we're genuinely concerned about your well-being.",
|
| 778 |
+
"start_time": 21.697523952118004,
|
| 779 |
+
"end_time": 34.7355284872654,
|
| 780 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_7_B.wav",
|
| 781 |
+
"silence_duration": 0.30069741181398774,
|
| 782 |
+
"is_interrupted": false
|
| 783 |
+
},
|
| 784 |
+
{
|
| 785 |
+
"speaker": "A",
|
| 786 |
+
"text": "Wait, is this about me staying out late last weekend?",
|
| 787 |
+
"original_text": "Wait, is this about me staying out late last weekend?",
|
| 788 |
+
"start_time": 35.29912687220732,
|
| 789 |
+
"end_time": 38.677630273567864,
|
| 790 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_8_A.wav",
|
| 791 |
+
"silence_duration": 0.5635983849419206,
|
| 792 |
+
"is_interrupted": false
|
| 793 |
+
},
|
| 794 |
+
{
|
| 795 |
+
"speaker": "B",
|
| 796 |
+
"text": "Not just that, but it's part of it. We've also noticed you've been acting a bit differently lately, and we're just wondering if everything is okay with you.",
|
| 797 |
+
"original_text": "Not just that, but it's part of it. We've also noticed you've been acting a bit differently lately, and we're just wondering if everything is okay with you.",
|
| 798 |
+
"start_time": 39.09678068392148,
|
| 799 |
+
"end_time": 45.99310721453372,
|
| 800 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_9_B.wav",
|
| 801 |
+
"silence_duration": 0.4191504103536184,
|
| 802 |
+
"is_interrupted": false
|
| 803 |
+
},
|
| 804 |
+
{
|
| 805 |
+
"speaker": "A",
|
| 806 |
+
"text": "I don't know, Mom. Like I said, I'm just dealing with some stuff.",
|
| 807 |
+
"original_text": "I don't know, Mom. Like I said, I'm just dealing with some stuff.",
|
| 808 |
+
"start_time": 46.3670775788443,
|
| 809 |
+
"end_time": 50.46539957430915,
|
| 810 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_10_A.wav",
|
| 811 |
+
"silence_duration": 0.3739703643105766,
|
| 812 |
+
"is_interrupted": false
|
| 813 |
+
},
|
| 814 |
+
{
|
| 815 |
+
"speaker": "B",
|
| 816 |
+
"text": "Okay. Well, if you ever want to talk about anything, we're here for you. We love you, Moriah.",
|
| 817 |
+
"original_text": "Okay. Well, if you ever want to talk about anything, we're here for you. We love you, Moriah.",
|
| 818 |
+
"start_time": 50.99388055366539,
|
| 819 |
+
"end_time": 56.06744064436834,
|
| 820 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_11_B.wav",
|
| 821 |
+
"silence_duration": 0.5284809793562373,
|
| 822 |
+
"is_interrupted": false
|
| 823 |
+
},
|
| 824 |
+
{
|
| 825 |
+
"speaker": "A",
|
| 826 |
+
"text": "I love you too, Mom.",
|
| 827 |
+
"original_text": "I love you too, Mom.",
|
| 828 |
+
"start_time": 56.55062063706958,
|
| 829 |
+
"end_time": 57.99025782527819,
|
| 830 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_12_A.wav",
|
| 831 |
+
"silence_duration": 0.4831799927012399,
|
| 832 |
+
"is_interrupted": false
|
| 833 |
+
}
|
| 834 |
+
]
|
| 835 |
+
},
|
| 836 |
+
"SODA_PROCESSED--train--1070688": {
|
| 837 |
+
"original_text": "A: Hi Karis, I'm so excited to have you over for dinner tonight. I've been planning the menu and setting the table all day. I hope you're [interrupt] ready for a cozy evening with some delicious food and great conversation about your recent travels through Europe that you mentioned last time we met.\nB: Oh, I just remembered—I have a slight allergy to shellfish. I know you usually avoid it, but I wanted to mention it just in case.\nA: No worries, there's no shellfish on the menu tonight. Well, let's get started then! For our first course, we'll be having a spinach and feta salad. The feta is from a local farm and the spinach is from my garden. For our main course, I've made chicken Parmesan with homemade tomato sauce and fresh mozzarella cheese. And for dessert, we'll be having tiramisu that I made from scratch this afternoon. I wanted it to be just right for tonight.\nB: Tiramisu? That's my favorite dessert! I'm so excited to try it. You really know how to make a meal special.\nA: I'm glad you're excited! I was about to say I made it this morning using a special family recipe that's been passed down through generations, so it's extra fresh and has that authentic Italian flavor you can't find in restaurants. I hope you enjoy everything!",
|
| 838 |
+
"cleaned_text": "A:Hi Karis, I'm so excited to have you over for dinner tonight. I've been planning the menu and setting the table all day. I hope you're ready for a cozy evening with some delicious food and great conversation about your recent travels through Europe that you mentioned last time we met.\nB: Oh, I just remembered—I have a slight allergy to shellfish. I know you usually avoid it, but I wanted to mention it just in case.\nA: No worries, there's no shellfish on the menu tonight. Well, let's get started then! For our first course, we'll be having a spinach and feta salad. The feta is from a local farm and the spinach is from my garden. For our main course, I've made chicken Parmesan with homemade tomato sauce and fresh mozzarella cheese. And for dessert, we'll be having tiramisu that I made from scratch this afternoon. I wanted it to be just right for tonight.\nB: Tiramisu? That's my favorite dessert! I'm so excited to try it. You really know how to make a meal special.\nA: I'm glad you're excited! I was about to say I made it this morning using a special family recipe that's been passed down through generations, so it's extra fresh and has that authentic Italian flavor you can't find in restaurants. I hope you enjoy everything!",
|
| 839 |
+
"total_duration": 66.58453514739229,
|
| 840 |
+
"stereo_audio": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/stereo_dialogue.wav",
|
| 841 |
+
"speaker_tracks": {
|
| 842 |
+
"A": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/A_track.wav",
|
| 843 |
+
"B": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/B_track.wav"
|
| 844 |
+
},
|
| 845 |
+
"error_type": "error_after_interrupt",
|
| 846 |
+
"segments": [
|
| 847 |
+
{
|
| 848 |
+
"speaker": "A",
|
| 849 |
+
"text": "Hi Karis, I'm so excited to have you over for dinner tonight. I've been planning the menu and setting the table all day. I hope you're",
|
| 850 |
+
"original_text": "Hi Karis, I'm so excited to have you over for dinner tonight. I've been planning the menu and setting the table all day. I hope you're [interrupt] ready for a cozy evening with some delicious food and great conversation about your recent travels through Europe that you mentioned last time we met.",
|
| 851 |
+
"start_time": 0,
|
| 852 |
+
"end_time": 16.172698412698413,
|
| 853 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/temp/line_0_A.wav",
|
| 854 |
+
"silence_duration": 0,
|
| 855 |
+
"is_interrupted": true,
|
| 856 |
+
"text_after_interrupt": "ready for a cozy evening with some delicious food and great conversation about your recent travels through Europe that you mentioned last time we met."
|
| 857 |
+
},
|
| 858 |
+
{
|
| 859 |
+
"speaker": "B",
|
| 860 |
+
"text": "Oh, I just remembered—I have a slight allergy to shellfish. I know you usually avoid it, but I wanted to mention it just in case.",
|
| 861 |
+
"original_text": "Oh, I just remembered—I have a slight allergy to shellfish. I know you usually avoid it, but I wanted to mention it just in case.",
|
| 862 |
+
"start_time": 8.719092970521542,
|
| 863 |
+
"end_time": 15.650249433106577,
|
| 864 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/temp/line_1_B.wav",
|
| 865 |
+
"silence_duration": 0.42791712549357114,
|
| 866 |
+
"is_interrupted": false
|
| 867 |
+
},{
|
| 868 |
+
"speaker": "A",
|
| 869 |
+
"text": "No worries, there's no shellfish on the menu tonight. Well, let's get started then! For our first course, we'll be having a spinach and feta salad. The feta is from a local farm and the spinach is from my garden. For our main course, I've made chicken Parmesan with homemade tomato sauce and fresh mozzarella cheese. And for dessert, we'll be having tiramisu that I made from scratch this afternoon. I wanted it to be just right for tonight.",
|
| 870 |
+
"original_text": "No worries, there's no shellfish on the menu tonight. Well, let's get started then! For our first course, we'll be having a spinach and feta salad. The feta is from a local farm and the spinach is from my garden. For our main course, I've made chicken Parmesan with homemade tomato sauce and fresh mozzarella cheese. And for dessert, we'll be having tiramisu that I made from scratch this afternoon. I wanted it to be just right for tonight.",
|
| 871 |
+
"start_time": 16.66087863834312,
|
| 872 |
+
"end_time": 43.38704643879663,
|
| 873 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/temp/line_2_A.wav",
|
| 874 |
+
"silence_duration": 0.488180225644707,
|
| 875 |
+
"is_interrupted": false
|
| 876 |
+
},
|
| 877 |
+
{
|
| 878 |
+
"speaker": "B",
|
| 879 |
+
"text": "Tiramisu? That's my favorite dessert! I'm so excited to try it. You really know how to make a meal special.",
|
| 880 |
+
"original_text": "Tiramisu? That's my favorite dessert! I'm so excited to try it. You really know how to make a meal special.",
|
| 881 |
+
"start_time": 43.75020989775093,
|
| 882 |
+
"end_time": 49.926717834258866,
|
| 883 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/temp/line_3_B.wav",
|
| 884 |
+
"silence_duration": 0.36316345895429397,
|
| 885 |
+
"is_interrupted": false
|
| 886 |
+
},
|
| 887 |
+
{
|
| 888 |
+
"speaker": "A",
|
| 889 |
+
"text": "I'm glad you're excited! I was about to say I made it this morning using a special family recipe that's been passed down through generations, so it's extra fresh and has that authentic Italian flavor you can't find in restaurants. I hope you enjoy everything!",
|
| 890 |
+
"original_text": "I'm glad you're excited! I was about to say I made it this morning using a special family recipe that's been passed down through generations, so it's extra fresh and has that authentic Italian flavor you can't find in restaurants. I hope you enjoy everything!",
|
| 891 |
+
"start_time": 50.49314394878711,
|
| 892 |
+
"end_time": 66.58457252021569,
|
| 893 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/temp/line_4_A.wav",
|
| 894 |
+
"silence_duration": 0.5664261145282402,
|
| 895 |
+
"is_interrupted": false
|
| 896 |
+
}
|
| 897 |
+
]
|
| 898 |
+
},
|
| 899 |
+
"SODA_PROCESSED--train--737676": {
|
| 900 |
+
"original_text": "A: Hey, Miraya. I'm sorry about what happened with the car last night. I was really angry and I didn't mean to take it out on your [interrupt] car like that, especially since it's your most valuable possession and you've always taken such good care of it. I know it was wrong, and I regret it deeply.\nB: I understand, Stephon. But what exactly made you so angry? Was it something specific about what happened earlier in the week?\nA: Yeah, it did. I was really mad at you for a while after that. But I know it wasn't your fault and I shouldn't have taken it out on your car like that.\nB: Well, since you're being honest and apologetic about it, I don't think there's anything else you need to do other than maybe just be more mindful in the future about how you express your emotions, especially when you're upset, because lashing out at objects or people never really solves the underlying issue and often makes things worse.\nA: Absolutely, I'll work on that. And I really appreciate you being so understanding about this. Thanks for giving me the chance to talk it out.",
|
| 901 |
+
"cleaned_text": "A:Hey, Miraya. I'm sorry about what happened with the car last night. I was really angry and I didn't mean to take it out on your car like that, especially since it's your most valuable possession and you've always taken such good care of it. I know it was wrong, and I regret it deeply.\nB: I understand, Stephon. But what exactly made you so angry? Was it something specific about what happened earlier in the week?\nA: Yeah, it did. I was really mad at you for a while after that. But I know it wasn't your fault and I shouldn't have taken it out on your car like that.\nB: Well, since you're being honest and apologetic about it, I don't think there's anything else you need to do other than maybe just be more mindful in the future about how you express your emotions, especially when you're upset, because lashing out at objects or people never really solves the underlying issue and often makes things worse.\nA: Absolutely, I'll work on that. And I really appreciate you being so understanding about this. Thanks for giving me the chance to talk it out.",
|
| 902 |
+
"total_duration": 52.89809523809524,
|
| 903 |
+
"stereo_audio": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/stereo_dialogue.wav",
|
| 904 |
+
"speaker_tracks": {
|
| 905 |
+
"A": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/A_track.wav",
|
| 906 |
+
"B": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/B_track.wav"
|
| 907 |
+
},
|
| 908 |
+
"error_type": "error_after_interrupt",
|
| 909 |
+
"segments": [
|
| 910 |
+
{
|
| 911 |
+
"speaker": "A",
|
| 912 |
+
"text": "Hey, Miraya. I'm sorry about what happened with the car last night. I was really angry and I didn't mean to take it out on your",
|
| 913 |
+
"original_text": "Hey, Miraya. I'm sorry about what happened with the car last night. I was really angry and I didn't mean to take it out on your [interrupt] car like that, especially since it's your most valuable possession and you've always taken such good care of it. I know it was wrong, and I regret it deeply.",
|
| 914 |
+
"start_time": 0,
|
| 915 |
+
"end_time": 16.938956916099773,
|
| 916 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/temp/line_0_A.wav",
|
| 917 |
+
"silence_duration": 0,
|
| 918 |
+
"is_interrupted": true,
|
| 919 |
+
"text_after_interrupt": "car like that, especially since it's your most valuable possession and you've always taken such good care of it. I know it was wrong, and I regret it deeply."
|
| 920 |
+
},
|
| 921 |
+
{
|
| 922 |
+
"speaker": "B",
|
| 923 |
+
"text": "I understand, Stephon. But what exactly made you so angry? Was it something specific about what happened earlier in the week?",
|
| 924 |
+
"original_text": "I understand, Stephon. But what exactly made you so angry? Was it something specific about what happened earlier in the week?",
|
| 925 |
+
"start_time": 8.753922902494331,
|
| 926 |
+
"end_time": 15.348390022675737,
|
| 927 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/temp/line_1_B.wav",
|
| 928 |
+
"silence_duration": 0.5553895116856843,
|
| 929 |
+
"is_interrupted": false
|
| 930 |
+
},
|
| 931 |
+
{
|
| 932 |
+
"speaker": "A",
|
| 933 |
+
"text": "Yeah, it did. I was really mad at you for a while after that. But I know it wasn't your fault and I shouldn't have taken it out on your car like that.",
|
| 934 |
+
"original_text": "Yeah, it did. I was really mad at you for a while after that. But I know it wasn't your fault and I shouldn't have taken it out on your car like that.",
|
| 935 |
+
"start_time": 17.329799609194744,
|
| 936 |
+
"end_time": 26.582951536632386,
|
| 937 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/temp/line_2_A.wav",
|
| 938 |
+
"silence_duration": 0.3908426930949695,
|
| 939 |
+
"is_interrupted": false
|
| 940 |
+
},
|
| 941 |
+
{
|
| 942 |
+
"speaker": "B",
|
| 943 |
+
"text": "Well, since you're being honest and apologetic about it, I don't think there's anything else you need to do other than maybe just be more mindful in the future about how you express your emotions, especially when you're upset, because lashing out at objects or people never really solves the underlying issue and often makes things worse.",
|
| 944 |
+
"original_text": "Well, since you're being honest and apologetic about it, I don't think there's anything else you need to do other than maybe just be more mindful in the future about how you express your emotions, especially when you're upset, because lashing out at objects or people never really solves the underlying issue and often makes things worse.",
|
| 945 |
+
"start_time": 26.900238001740547,
|
| 946 |
+
"end_time": 44.05978448700132,
|
| 947 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/temp/line_3_B.wav",
|
| 948 |
+
"silence_duration": 0.3172864651081615,
|
| 949 |
+
"is_interrupted": false
|
| 950 |
+
},
|
| 951 |
+
{
|
| 952 |
+
"speaker": "A",
|
| 953 |
+
"text": "Absolutely, I'll work on that. And I really appreciate you being so understanding about this. Thanks for giving me the chance to talk it out.",
|
| 954 |
+
"original_text": "Absolutely, I'll work on that. And I really appreciate you being so understanding about this. Thanks for giving me the chance to talk it out.",
|
| 955 |
+
"start_time": 44.64342590433178,
|
| 956 |
+
"end_time": 52.8981197818828,
|
| 957 |
+
"audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/temp/line_4_A.wav",
|
| 958 |
+
"silence_duration": 0.5836414173304574,
|
| 959 |
+
"is_interrupted": false
|
| 960 |
+
}
|
| 961 |
+
]
|
| 962 |
+
}
|
| 963 |
+
}
|
ms-swift/swift/llm/sampling/mcts.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import traceback
|
| 3 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 4 |
+
from copy import deepcopy
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from swift.llm import InferRequest, SamplingArguments
|
| 10 |
+
from swift.llm.infer.protocol import UsageInfo
|
| 11 |
+
from swift.utils import get_logger
|
| 12 |
+
from .base import Sampler
|
| 13 |
+
from .utils import get_reward, perform_infer
|
| 14 |
+
|
| 15 |
+
logger = get_logger()
|
| 16 |
+
|
| 17 |
+
NXT_PROMPT = """Continue.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
next_message = {
|
| 21 |
+
'role': 'user',
|
| 22 |
+
'content': NXT_PROMPT,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class LanguageNode:
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
step: str = None,
|
| 31 |
+
sep_token: str = None,
|
| 32 |
+
parent: 'LanguageNode' = None,
|
| 33 |
+
):
|
| 34 |
+
self.parent = parent
|
| 35 |
+
|
| 36 |
+
if sep_token:
|
| 37 |
+
self.sep_token = sep_token
|
| 38 |
+
else:
|
| 39 |
+
self.sep_token = parent.sep_token
|
| 40 |
+
|
| 41 |
+
if parent:
|
| 42 |
+
self.path = parent.path[:] + [step]
|
| 43 |
+
self.answer = parent.answer + step + self.sep_token
|
| 44 |
+
self.depth = parent.depth + 1
|
| 45 |
+
else:
|
| 46 |
+
self.path = []
|
| 47 |
+
self.answer = ''
|
| 48 |
+
self.depth = 0
|
| 49 |
+
|
| 50 |
+
self.active_children = []
|
| 51 |
+
self.children = []
|
| 52 |
+
self.visit_count = 0
|
| 53 |
+
self.process_reward = 0.0
|
| 54 |
+
self.outcome_reward = 0.0
|
| 55 |
+
self.terminated = False
|
| 56 |
+
self.correct = False
|
| 57 |
+
|
| 58 |
+
def is_leaf(self):
|
| 59 |
+
return len(self.children) == 0
|
| 60 |
+
|
| 61 |
+
def is_root(self):
|
| 62 |
+
return self.parent is None
|
| 63 |
+
|
| 64 |
+
def visit(self):
|
| 65 |
+
self.visit_count += 1
|
| 66 |
+
|
| 67 |
+
def init_and_update_value(self, value):
|
| 68 |
+
self.outcome_reward = (self.outcome_reward * self.visit_count + value) / (self.visit_count + 1)
|
| 69 |
+
|
| 70 |
+
def add_child(self, child: 'LanguageNode'):
|
| 71 |
+
self.children.append(child)
|
| 72 |
+
if not child.terminated:
|
| 73 |
+
self.active_children.append(child)
|
| 74 |
+
|
| 75 |
+
def collect(self):
|
| 76 |
+
result = {
|
| 77 |
+
'path': self.path,
|
| 78 |
+
'depth': self.depth,
|
| 79 |
+
'visit_count': self.visit_count,
|
| 80 |
+
'process_reward': self.process_reward,
|
| 81 |
+
'outcome_reward': self.outcome_reward,
|
| 82 |
+
'terminated': str(self.terminated),
|
| 83 |
+
'correct': str(self.correct),
|
| 84 |
+
'children': [child.collect() for child in self.children],
|
| 85 |
+
}
|
| 86 |
+
return result
|
| 87 |
+
|
| 88 |
+
def __lt__(self, other):
|
| 89 |
+
return self.outcome_reward < other.outcome_reward
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class MctsSampler(Sampler):
|
| 93 |
+
|
| 94 |
+
def __init__(self, input_args: SamplingArguments):
|
| 95 |
+
super().__init__(input_args)
|
| 96 |
+
self.usage_info = UsageInfo(0, 0, 0)
|
| 97 |
+
|
| 98 |
+
def _prepare_model_tokenizer(self):
|
| 99 |
+
args = self.args
|
| 100 |
+
self.infer_kwargs = {}
|
| 101 |
+
if args.sampler_engine == 'client':
|
| 102 |
+
from swift.llm import InferClient
|
| 103 |
+
api_key = args.api_key
|
| 104 |
+
base_url = args.base_url
|
| 105 |
+
self.infer_engine = [
|
| 106 |
+
InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences)
|
| 107 |
+
]
|
| 108 |
+
self.infer_kwargs['model'] = args.model
|
| 109 |
+
else:
|
| 110 |
+
_Engine = self.get_infer_engine()
|
| 111 |
+
self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs)
|
| 112 |
+
|
| 113 |
+
def get_infer_engine(self):
|
| 114 |
+
if self.args.sampler_engine == 'pt':
|
| 115 |
+
from swift.llm import PtEngine
|
| 116 |
+
_Engine = PtEngine
|
| 117 |
+
elif self.args.sampler_engine == 'vllm':
|
| 118 |
+
from swift.llm import VllmEngine
|
| 119 |
+
_Engine = VllmEngine
|
| 120 |
+
elif self.args.sampler_engine == 'lmdeploy':
|
| 121 |
+
from swift.llm import LmdeployEngine
|
| 122 |
+
_Engine = LmdeployEngine
|
| 123 |
+
elif self.args.sampler_engine == 'no':
|
| 124 |
+
_Engine = None
|
| 125 |
+
else:
|
| 126 |
+
raise ValueError(f'Cannot find engine name: {self.args.sampler_engine}')
|
| 127 |
+
return _Engine
|
| 128 |
+
|
| 129 |
+
def _prepare_template(self) -> None:
|
| 130 |
+
# Hack from super()
|
| 131 |
+
self._prepare_request_configs()
|
| 132 |
+
|
| 133 |
+
def _prepare_request_configs(self):
|
| 134 |
+
_args = self.args
|
| 135 |
+
request_config = _args.get_request_config()
|
| 136 |
+
request_config.stop = _args.stop_words
|
| 137 |
+
request_config.seed = _args.seed
|
| 138 |
+
self.expand_request_configs = []
|
| 139 |
+
self.rollout_request_configs = []
|
| 140 |
+
for i in range(_args.num_return_sequences):
|
| 141 |
+
expand_request_config = deepcopy(request_config)
|
| 142 |
+
expand_request_config.n = 1
|
| 143 |
+
expand_request_config.num_beams = expand_request_config.n
|
| 144 |
+
expand_request_config.seed += i
|
| 145 |
+
self.expand_request_configs.append(expand_request_config)
|
| 146 |
+
rollout_request_config = deepcopy(request_config)
|
| 147 |
+
rollout_request_config.max_tokens = 500
|
| 148 |
+
rollout_request_config.temperature = 0.0
|
| 149 |
+
rollout_request_config.n = 1
|
| 150 |
+
self.rollout_request_configs.append(rollout_request_config)
|
| 151 |
+
|
| 152 |
+
def update_usage_info(self, response):
|
| 153 |
+
for key, value in self.usage_info.__dict__.items():
|
| 154 |
+
update_value = getattr(response.usage, key, None) + value
|
| 155 |
+
setattr(self.usage_info, key, update_value)
|
| 156 |
+
|
| 157 |
+
def search_single(self, query, ground_truth):
|
| 158 |
+
|
| 159 |
+
def _uct(uct_curr_node: LanguageNode):
|
| 160 |
+
alpha = _args.process_reward_rate
|
| 161 |
+
value = alpha * uct_curr_node.process_reward + (1 - alpha) * uct_curr_node.outcome_reward
|
| 162 |
+
if uct_curr_node.is_root():
|
| 163 |
+
return value
|
| 164 |
+
|
| 165 |
+
exploitation_score = value
|
| 166 |
+
exploration_score = (
|
| 167 |
+
_args.exploration_rate
|
| 168 |
+
* np.sqrt(np.log(uct_curr_node.parent.visit_count + 1) / (uct_curr_node.visit_count + 1)))
|
| 169 |
+
|
| 170 |
+
return exploration_score + exploitation_score
|
| 171 |
+
|
| 172 |
+
def _select(select_curr_node: LanguageNode):
|
| 173 |
+
while not select_curr_node.is_leaf():
|
| 174 |
+
select_curr_node = max(select_curr_node.active_children, key=lambda x: _uct(x))
|
| 175 |
+
return select_curr_node
|
| 176 |
+
|
| 177 |
+
def _expand(expand_curr_node: LanguageNode):
|
| 178 |
+
n = _args.num_return_sequences - len(expand_curr_node.children)
|
| 179 |
+
if expand_curr_node.is_root():
|
| 180 |
+
infer_requests = [InferRequest(system_message + [prompt_message]) for _ in range(n)]
|
| 181 |
+
else:
|
| 182 |
+
history_message = {
|
| 183 |
+
'role': 'assistant',
|
| 184 |
+
'content': expand_curr_node.answer,
|
| 185 |
+
}
|
| 186 |
+
infer_request = InferRequest(system_message + [prompt_message, history_message, next_message])
|
| 187 |
+
infer_requests = [infer_request for _ in range(n)]
|
| 188 |
+
|
| 189 |
+
# e_time = time.time()
|
| 190 |
+
# To perform the Expand operation in parallel,
|
| 191 |
+
# there's no need to consider the order for now, since the Prompt is the same.
|
| 192 |
+
expand_iter_index = 0
|
| 193 |
+
while True:
|
| 194 |
+
responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs,
|
| 195 |
+
**self.infer_kwargs)
|
| 196 |
+
if len(responses) > 0:
|
| 197 |
+
break
|
| 198 |
+
if expand_iter_index == 5:
|
| 199 |
+
raise ValueError('Expand should not return any response')
|
| 200 |
+
expand_iter_index += 1
|
| 201 |
+
# logger.info(f"expand.expand time: {time.time() - e_time}")
|
| 202 |
+
|
| 203 |
+
# To fetch Outcome Reward in parallel,
|
| 204 |
+
# the Outcome-Reward obtained is returned in order, so they can be directly matched accordingly.
|
| 205 |
+
orm_infer_requests = []
|
| 206 |
+
unique_output = set()
|
| 207 |
+
for response in responses:
|
| 208 |
+
self.update_usage_info(response)
|
| 209 |
+
output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[0]
|
| 210 |
+
if output in unique_output:
|
| 211 |
+
continue
|
| 212 |
+
unique_output.add(output)
|
| 213 |
+
orm_infer_requests.append(InferRequest([{'role': 'assistant', 'content': output}]))
|
| 214 |
+
child = LanguageNode(step=output, parent=expand_curr_node)
|
| 215 |
+
if self.orm_model.check_terminate(child.answer)[0]:
|
| 216 |
+
child.terminated = True
|
| 217 |
+
expand_curr_node.add_child(child)
|
| 218 |
+
|
| 219 |
+
# e_time = time.time()
|
| 220 |
+
orm_score, _orm_mask = get_reward(
|
| 221 |
+
self.orm_model,
|
| 222 |
+
orm_infer_requests,
|
| 223 |
+
ground_truths=[ground_truth] * len(orm_infer_requests),
|
| 224 |
+
threshold=0.0)
|
| 225 |
+
# logger.info(f"expand.orm time: {time.time() - e_time}")
|
| 226 |
+
for child, score in zip(expand_curr_node.children, orm_score):
|
| 227 |
+
if child.terminated:
|
| 228 |
+
child.init_and_update_value(score)
|
| 229 |
+
child.correct = score > 0.9
|
| 230 |
+
terminated_nodes.append(child)
|
| 231 |
+
|
| 232 |
+
# e_time = time.time()
|
| 233 |
+
if self.prm_model:
|
| 234 |
+
prm_infer_requests = []
|
| 235 |
+
for child in expand_curr_node.children:
|
| 236 |
+
prm_message = {'role': 'assistant', 'content': child.answer}
|
| 237 |
+
prm_infer_requests.append(InferRequest([prompt_message, prm_message]))
|
| 238 |
+
prm_score, _prm_mask = get_reward(
|
| 239 |
+
self.prm_model,
|
| 240 |
+
prm_infer_requests,
|
| 241 |
+
ground_truths=[ground_truth] * len(prm_infer_requests),
|
| 242 |
+
threshold=0.0)
|
| 243 |
+
for child, score in zip(expand_curr_node.children, prm_score):
|
| 244 |
+
child.process_reward = score
|
| 245 |
+
# logger.info(f"expand.prm time: {time.time() - e_time}")
|
| 246 |
+
|
| 247 |
+
def _rollout(rollout_curr_node: LanguageNode):
|
| 248 |
+
rollout_depth = 0
|
| 249 |
+
rollout_nodes = {}
|
| 250 |
+
for i in range(len(rollout_curr_node.active_children)):
|
| 251 |
+
rollout_nodes[i] = {
|
| 252 |
+
'node': rollout_curr_node.active_children[i],
|
| 253 |
+
'history_messages': {
|
| 254 |
+
'role': 'assistant',
|
| 255 |
+
'content': rollout_curr_node.active_children[i].answer,
|
| 256 |
+
},
|
| 257 |
+
}
|
| 258 |
+
active_rollout_nodes = list(rollout_nodes.keys())
|
| 259 |
+
while len(active_rollout_nodes) > 0 and rollout_depth < _args.rollout_depth:
|
| 260 |
+
# r_time = time.time()
|
| 261 |
+
infer_requests = [
|
| 262 |
+
InferRequest(system_message
|
| 263 |
+
+ [prompt_message, rollout_nodes[index]['history_messages'], next_message])
|
| 264 |
+
for index in active_rollout_nodes
|
| 265 |
+
]
|
| 266 |
+
# logger.info(f"rollout.prepare time: {time.time() - r_time}")
|
| 267 |
+
# r_time = time.time()
|
| 268 |
+
rollout_iter_index = 0
|
| 269 |
+
while True:
|
| 270 |
+
responses = perform_infer(self.infer_engine, infer_requests, self.rollout_request_configs,
|
| 271 |
+
**self.infer_kwargs)
|
| 272 |
+
if len(responses) > 0:
|
| 273 |
+
break
|
| 274 |
+
if rollout_iter_index == 5:
|
| 275 |
+
raise ValueError('Rollout should not return any response')
|
| 276 |
+
rollout_iter_index += 1
|
| 277 |
+
# logger.info(f"rollout.infer time: {time.time() - r_time}")
|
| 278 |
+
|
| 279 |
+
# r_time = time.time()
|
| 280 |
+
orm_infer_requests = []
|
| 281 |
+
end_paths = []
|
| 282 |
+
for index, response in zip(active_rollout_nodes, responses):
|
| 283 |
+
self.update_usage_info(response)
|
| 284 |
+
output = response.choices[0].message.content.rstrip(sep_token
|
| 285 |
+
+ '\n').split(sep_token)[0] + sep_token + '\n'
|
| 286 |
+
rollout_nodes[index]['history_messages']['content'] += output
|
| 287 |
+
end_paths.append(rollout_nodes[index]['history_messages']['content'])
|
| 288 |
+
orm_infer_requests.append(InferRequest([rollout_nodes[index]['history_messages']]))
|
| 289 |
+
# logger.info(f"rollout.orm_prepare time: {time.time() - r_time}")
|
| 290 |
+
|
| 291 |
+
# r_time = time.time()
|
| 292 |
+
orm_score, _orm_mask = get_reward(
|
| 293 |
+
self.orm_model,
|
| 294 |
+
orm_infer_requests,
|
| 295 |
+
ground_truths=[ground_truth] * len(infer_requests),
|
| 296 |
+
threshold=0.0)
|
| 297 |
+
# logger.info(f"rollout.get_orm time: {time.time() - r_time}")
|
| 298 |
+
terminated_state = self.orm_model.check_terminate(end_paths)
|
| 299 |
+
for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state):
|
| 300 |
+
if terminated:
|
| 301 |
+
rollout_curr_node.active_children[index].init_and_update_value(score)
|
| 302 |
+
if score > 0.9:
|
| 303 |
+
rollout_correct_answers.append(rollout_nodes[index]['history_messages']['content'])
|
| 304 |
+
else:
|
| 305 |
+
rollout_incorrect_answers.append(rollout_nodes[index]['history_messages']['content'])
|
| 306 |
+
rollout_nodes.pop(index)
|
| 307 |
+
active_rollout_nodes = list(rollout_nodes.keys())
|
| 308 |
+
rollout_depth += 1
|
| 309 |
+
|
| 310 |
+
def _back_propagate(back_curr_node: LanguageNode):
|
| 311 |
+
while back_curr_node:
|
| 312 |
+
if back_curr_node == curr_node:
|
| 313 |
+
best_child_value = max([child.outcome_reward for child in back_curr_node.children])
|
| 314 |
+
back_curr_node.init_and_update_value(best_child_value)
|
| 315 |
+
last_child_value = back_curr_node.outcome_reward
|
| 316 |
+
else:
|
| 317 |
+
back_curr_node.init_and_update_value(last_child_value)
|
| 318 |
+
last_child_value = back_curr_node.outcome_reward
|
| 319 |
+
back_curr_node.visit()
|
| 320 |
+
if len(back_curr_node.active_children) == 0:
|
| 321 |
+
back_curr_node.terminated = True
|
| 322 |
+
if not back_curr_node.is_root():
|
| 323 |
+
back_curr_node.parent.active_children.remove(back_curr_node)
|
| 324 |
+
back_curr_node = back_curr_node.parent
|
| 325 |
+
|
| 326 |
+
_args = self.args
|
| 327 |
+
system_message = [] + _args.system_message
|
| 328 |
+
sep_token = _args.stop_words[0] + '\n'
|
| 329 |
+
_root = LanguageNode(sep_token=sep_token)
|
| 330 |
+
prompt_message = {
|
| 331 |
+
'role': 'user',
|
| 332 |
+
'content': query,
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
rollout_correct_answers, rollout_incorrect_answers, terminated_nodes = [], [], []
|
| 336 |
+
iter_count = 0
|
| 337 |
+
stop_reason = None
|
| 338 |
+
while True:
|
| 339 |
+
logger.info(f'iter_count: {iter_count}' + '.' * 10)
|
| 340 |
+
s_time = time.time()
|
| 341 |
+
curr_node = _select(_root)
|
| 342 |
+
logger.debug('select' + '=' * 10 + f'time: {time.time() - s_time}')
|
| 343 |
+
s_time = time.time()
|
| 344 |
+
_expand(curr_node)
|
| 345 |
+
logger.debug('expand' + '=' * 10 + f'time: {time.time() - s_time}')
|
| 346 |
+
if curr_node.depth > _args.rollout_start_depth:
|
| 347 |
+
s_time = time.time()
|
| 348 |
+
_rollout(curr_node)
|
| 349 |
+
logger.debug('rollout' + '=' * 10 + f'time: {time.time() - s_time}')
|
| 350 |
+
s_time = time.time()
|
| 351 |
+
_back_propagate(curr_node)
|
| 352 |
+
logger.debug('back propagate' + '=' * 10 + f'time: {time.time() - s_time}')
|
| 353 |
+
if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= 2 * _args.num_return_sequences:
|
| 354 |
+
if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers):
|
| 355 |
+
stop_reason = 'too easy'
|
| 356 |
+
break
|
| 357 |
+
elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers):
|
| 358 |
+
stop_reason = 'too hard'
|
| 359 |
+
break
|
| 360 |
+
if _root.terminated:
|
| 361 |
+
stop_reason = 'root terminated'
|
| 362 |
+
break
|
| 363 |
+
if len(terminated_nodes) >= _args.num_return_sequences:
|
| 364 |
+
stop_reason = 'enough nodes'
|
| 365 |
+
break
|
| 366 |
+
if iter_count >= _args.max_iterations:
|
| 367 |
+
stop_reason = 'max_iterations'
|
| 368 |
+
break
|
| 369 |
+
iter_count += 1
|
| 370 |
+
logger.info(f'stop_reason: {stop_reason}')
|
| 371 |
+
# logger.info(f"rollout_correct_answers: {rollout_correct_answers}")
|
| 372 |
+
# logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}")
|
| 373 |
+
|
| 374 |
+
monte_carlo_tree = _root.collect()
|
| 375 |
+
result = {
|
| 376 |
+
'query': query,
|
| 377 |
+
'ground_truth': ground_truth,
|
| 378 |
+
'rollout_correct_answers': rollout_correct_answers,
|
| 379 |
+
'rollout_incorrect_answers': rollout_incorrect_answers,
|
| 380 |
+
'monte_carlo_tree': monte_carlo_tree,
|
| 381 |
+
}
|
| 382 |
+
result_json = json.dumps(result, ensure_ascii=False)
|
| 383 |
+
logger.info(result_json)
|
| 384 |
+
return result_json
|
| 385 |
+
|
| 386 |
+
def do_sample(self, data):
|
| 387 |
+
if not isinstance(data, list):
|
| 388 |
+
data = [data]
|
| 389 |
+
generated = []
|
| 390 |
+
for item in data:
|
| 391 |
+
logger.info(f'time: {time.ctime(time.time())}')
|
| 392 |
+
try:
|
| 393 |
+
messages = item['messages'][0]
|
| 394 |
+
query = messages[0]['content']
|
| 395 |
+
ground_truth = messages[1]['content']
|
| 396 |
+
generated.append(self.search_single(query, ground_truth) + '\n')
|
| 397 |
+
except Exception as e:
|
| 398 |
+
logger.error(f'Error: {e}')
|
| 399 |
+
logger.error(f'Traceback: {traceback.format_exc()}')
|
| 400 |
+
return generated
|
ms-swift/swift/llm/template/template/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (606 Bytes). View file
|
|
|
ms-swift/swift/llm/template/template/__pycache__/emu3.cpython-310.pyc
ADDED
|
Binary file (7.88 kB). View file
|
|
|
ms-swift/swift/llm/template/template/__pycache__/gemma.cpython-310.pyc
ADDED
|
Binary file (5.91 kB). View file
|
|
|
ms-swift/swift/llm/template/template/__pycache__/internvl.cpython-310.pyc
ADDED
|
Binary file (6.8 kB). View file
|
|
|
ms-swift/swift/llm/template/template/__pycache__/minicpm.cpython-310.pyc
ADDED
|
Binary file (8.18 kB). View file
|
|
|
ms-swift/swift/llm/template/template/__pycache__/pixtral.cpython-310.pyc
ADDED
|
Binary file (2.3 kB). View file
|
|
|
ms-swift/swift/llm/template/template/__pycache__/stepfun.cpython-310.pyc
ADDED
|
Binary file (6.57 kB). View file
|
|
|
ms-swift/swift/llm/template/template/__pycache__/valley.cpython-310.pyc
ADDED
|
Binary file (6.31 kB). View file
|
|
|
ms-swift/swift/llm/template/template/__pycache__/yi.cpython-310.pyc
ADDED
|
Binary file (2.91 kB). View file
|
|
|
ms-swift/swift/llm/template/template/deepseek.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import Any, Dict, List, Optional
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
from swift.utils import get_env_args
|
| 12 |
+
from ..base import Template
|
| 13 |
+
from ..constant import LLMTemplateType, MLLMTemplateType
|
| 14 |
+
from ..register import TemplateMeta, register_template
|
| 15 |
+
from ..template_inputs import StdTemplateInputs
|
| 16 |
+
from ..utils import Prompt, findall
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class DeepseekTemplateMeta(TemplateMeta):
|
| 21 |
+
prefix: Prompt = field(default_factory=lambda: [['bos_token_id']])
|
| 22 |
+
prompt: Prompt = field(default_factory=lambda: ['User: {{QUERY}}\n\nAssistant:'])
|
| 23 |
+
chat_sep: Optional[Prompt] = field(default_factory=lambda: [['eos_token_id']])
|
| 24 |
+
suffix: Prompt = field(default_factory=lambda: [['eos_token_id']])
|
| 25 |
+
system_prefix: Optional[Prompt] = field(default_factory=lambda: [['bos_token_id'], '{{SYSTEM}}\n\n'])
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
register_template(DeepseekTemplateMeta(LLMTemplateType.deepseek, ))
|
| 29 |
+
|
| 30 |
+
register_template(
|
| 31 |
+
TemplateMeta(
|
| 32 |
+
LLMTemplateType.deepseek_coder,
|
| 33 |
+
prefix=['{{SYSTEM}}'],
|
| 34 |
+
prompt=['### Instruction:\n{{QUERY}}\n### Response:\n'],
|
| 35 |
+
chat_sep=['\n<|EOT|>\n'],
|
| 36 |
+
suffix=['\n<|EOT|>'],
|
| 37 |
+
stop_words=['<|EOT|>'],
|
| 38 |
+
default_system=('You are an AI programming assistant, utilizing the Deepseek Coder model, '
|
| 39 |
+
'developed by Deepseek Company, and you only answer questions related to computer science. '
|
| 40 |
+
'For politically sensitive questions, security and privacy issues, '
|
| 41 |
+
'and other non-computer science questions, you will refuse to answer\n')))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class DeepseekVLTemplate(Template):
|
| 45 |
+
image_placeholder = ['<image_placeholder>']
|
| 46 |
+
skip_prompt = False
|
| 47 |
+
use_model = True
|
| 48 |
+
placeholder_tokens = ['<image_placeholder>']
|
| 49 |
+
|
| 50 |
+
image_token_num_per_image: int = 576
|
| 51 |
+
|
| 52 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 53 |
+
is_janus = getattr(self, 'is_janus', False)
|
| 54 |
+
|
| 55 |
+
encoded = super()._encode(inputs)
|
| 56 |
+
images = inputs.images
|
| 57 |
+
processor = self.processor
|
| 58 |
+
input_ids, labels = encoded['input_ids'], encoded['labels']
|
| 59 |
+
|
| 60 |
+
if not inputs.generate_mode: # understanding task
|
| 61 |
+
idx_list = findall(input_ids, processor.image_id) # '<image_placeholder>'
|
| 62 |
+
new_input_ids, new_labels = [], []
|
| 63 |
+
lo = 0
|
| 64 |
+
for hi in idx_list:
|
| 65 |
+
new_input_ids += input_ids[lo:hi]
|
| 66 |
+
if labels is not None:
|
| 67 |
+
new_labels += labels[lo:hi]
|
| 68 |
+
image_tokens = [processor.image_id] * processor.num_image_tokens
|
| 69 |
+
if is_janus:
|
| 70 |
+
image_tokens = [processor.image_start_id] + image_tokens + [processor.image_end_id]
|
| 71 |
+
new_input_ids += image_tokens
|
| 72 |
+
new_labels += [-100] * len(image_tokens)
|
| 73 |
+
lo = hi + 1
|
| 74 |
+
new_input_ids += input_ids[lo:]
|
| 75 |
+
if labels is not None:
|
| 76 |
+
new_labels += labels[lo:]
|
| 77 |
+
else:
|
| 78 |
+
new_labels = None
|
| 79 |
+
if is_janus:
|
| 80 |
+
from janus.models.processing_vlm import VLChatProcessorOutput
|
| 81 |
+
else:
|
| 82 |
+
from deepseek_vl.models.processing_vlm import VLChatProcessorOutput
|
| 83 |
+
|
| 84 |
+
images_outputs = processor.image_processor(images, return_tensors='pt')
|
| 85 |
+
output = VLChatProcessorOutput(
|
| 86 |
+
sft_format=None,
|
| 87 |
+
input_ids=torch.tensor(new_input_ids),
|
| 88 |
+
pixel_values=images_outputs.pixel_values,
|
| 89 |
+
num_image_tokens=torch.tensor([processor.num_image_tokens] * len(idx_list)))
|
| 90 |
+
encoded = {'output': output, 'input_ids': new_input_ids, 'labels': new_labels}
|
| 91 |
+
return encoded
|
| 92 |
+
|
| 93 |
+
else: # image generation task
|
| 94 |
+
if self.is_training:
|
| 95 |
+
raise NotImplementedError('Only support the inference of generation of Janus series models.')
|
| 96 |
+
sft_format = self.tokenizer.decode(input_ids)
|
| 97 |
+
prompt = sft_format + processor.image_start_tag
|
| 98 |
+
input_ids = processor.tokenizer.encode(prompt)
|
| 99 |
+
input_ids = torch.LongTensor(input_ids)
|
| 100 |
+
|
| 101 |
+
encoded = {'input_ids': input_ids, 'labels': labels, 'generate_mode': inputs.generate_mode}
|
| 102 |
+
return encoded
|
| 103 |
+
|
| 104 |
+
def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
| 105 |
+
if not inputs.get('generate_mode'):
|
| 106 |
+
inputs['pixel_values'] = inputs['pixel_values'].to(dtype=self.model_info.torch_dtype)
|
| 107 |
+
inputs_embeds = model.prepare_inputs_embeds(**inputs)
|
| 108 |
+
return {'inputs_embeds': inputs_embeds}
|
| 109 |
+
else:
|
| 110 |
+
return inputs
|
| 111 |
+
|
| 112 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 113 |
+
gene_img_list = [b.get('generate_mode') for b in batch]
|
| 114 |
+
if all(gene_img_list):
|
| 115 |
+
generate_mode = True
|
| 116 |
+
elif not any(gene_img_list):
|
| 117 |
+
generate_mode = False
|
| 118 |
+
else:
|
| 119 |
+
raise NotImplementedError('Do not support understanding and image generation tasks in one batch.')
|
| 120 |
+
|
| 121 |
+
if not generate_mode:
|
| 122 |
+
output = self.fetch_inputs(batch, ['output'])['output']
|
| 123 |
+
batched_output = dict(self.processor.batchify(output))
|
| 124 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 125 |
+
return {**batched_output, **res}
|
| 126 |
+
else:
|
| 127 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 128 |
+
res['generate_mode'] = generate_mode
|
| 129 |
+
return res
|
| 130 |
+
|
| 131 |
+
def generate(self, model, *args, **kwargs):
|
| 132 |
+
if not kwargs.get('generate_mode'):
|
| 133 |
+
return super().generate(model, *args, **kwargs)
|
| 134 |
+
|
| 135 |
+
else:
|
| 136 |
+
# generate how many number of images for each prompt, it is named parallel_size in the author's code
|
| 137 |
+
parallel_size = kwargs['generation_config'].num_return_sequences
|
| 138 |
+
temperature = kwargs['generation_config'].temperature
|
| 139 |
+
cfg_weight = get_env_args('cfg_weight', float, 5.0)
|
| 140 |
+
|
| 141 |
+
input_ids = kwargs['input_ids'] # [bsz, max_input_token_num]
|
| 142 |
+
bsz, max_input_token_num = input_ids.shape
|
| 143 |
+
tokens = torch.zeros((bsz, parallel_size * 2, max_input_token_num),
|
| 144 |
+
dtype=torch.int).cuda() # [bsz, parallel_size*2, max_input_token_num]
|
| 145 |
+
for i in range(parallel_size * 2):
|
| 146 |
+
tokens[:, i, :] = input_ids
|
| 147 |
+
if i % 2 != 0:
|
| 148 |
+
tokens[:, i, 1:-1] = self.processor.pad_id
|
| 149 |
+
|
| 150 |
+
inputs_embeds = model.language_model.get_input_embeddings()(
|
| 151 |
+
tokens) # [bsz, parallel_size*2, max_input_token_num, 2048]
|
| 152 |
+
|
| 153 |
+
generated_tokens = torch.zeros(
|
| 154 |
+
(bsz, parallel_size, self.image_token_num_per_image),
|
| 155 |
+
dtype=torch.int).cuda() # [bsz, 16, image_token_num_per_image] placeholder for the generated tokens
|
| 156 |
+
|
| 157 |
+
# set the first two dimensions into one dimension for batch size
|
| 158 |
+
inputs_embeds = inputs_embeds.reshape(bsz * parallel_size * 2, max_input_token_num, -1)
|
| 159 |
+
generated_tokens = generated_tokens.reshape(bsz * parallel_size, self.image_token_num_per_image)
|
| 160 |
+
|
| 161 |
+
for i in range(self.image_token_num_per_image): # generate the tokens of image in a auto-regression way
|
| 162 |
+
outputs = model.language_model.model(
|
| 163 |
+
inputs_embeds=inputs_embeds,
|
| 164 |
+
use_cache=True,
|
| 165 |
+
past_key_values=outputs.past_key_values if i != 0 else None)
|
| 166 |
+
hidden_states = outputs.last_hidden_state
|
| 167 |
+
|
| 168 |
+
logits = self.model.gen_head(hidden_states[:, -1, :])
|
| 169 |
+
logit_cond = logits[0::2, :]
|
| 170 |
+
logit_uncond = logits[1::2, :]
|
| 171 |
+
|
| 172 |
+
logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
|
| 173 |
+
probs = torch.softmax(logits / temperature, dim=-1)
|
| 174 |
+
|
| 175 |
+
next_token = torch.multinomial(probs, num_samples=1)
|
| 176 |
+
generated_tokens[:, i] = next_token.squeeze(dim=-1) # [parallel_size, self.image_token_num_per_image]
|
| 177 |
+
|
| 178 |
+
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
|
| 179 |
+
img_embeds = model.prepare_gen_img_embeds(next_token) # [parallel_size * 2, 2048]
|
| 180 |
+
inputs_embeds = img_embeds.unsqueeze(dim=1) # [parallel_size * 2, 1, 2048]
|
| 181 |
+
|
| 182 |
+
# no need to reset the original first two dimensions, waiting for the update of the upper layer
|
| 183 |
+
# inputs_embeds = inputs_embeds.reshape(bsz, parallel_size*2, -1)
|
| 184 |
+
# generated_tokens = generated_tokens.reshape(bsz, parallel_size, self.image_token_num_per_image)
|
| 185 |
+
|
| 186 |
+
return {'sequences': generated_tokens}
|
| 187 |
+
|
| 188 |
+
def decode(self, generate_ids: List[int], **kwargs) -> Any:
|
| 189 |
+
if 'template_inputs' not in kwargs or not kwargs['template_inputs'].generate_mode:
|
| 190 |
+
return super().decode(generate_ids, **kwargs)
|
| 191 |
+
else:
|
| 192 |
+
img_size = get_env_args('img_size', int, 384)
|
| 193 |
+
patch_size = 16
|
| 194 |
+
|
| 195 |
+
num_to_decode = 1 # for now, generate_ids is a 1D list
|
| 196 |
+
|
| 197 |
+
generate_ids = torch.tensor(generate_ids).unsqueeze(0) # [num_to_decode=1, self.image_token_num_per_image]
|
| 198 |
+
|
| 199 |
+
dec = self.model.gen_vision_model.decode_code(
|
| 200 |
+
generate_ids.to(dtype=torch.int),
|
| 201 |
+
shape=[num_to_decode, 8, img_size // patch_size, img_size // patch_size])
|
| 202 |
+
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1) # [num_to_decode, H, W, ch=3]
|
| 203 |
+
|
| 204 |
+
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
|
| 205 |
+
|
| 206 |
+
visual_img = np.zeros((num_to_decode, img_size, img_size, 3), dtype=np.uint8)
|
| 207 |
+
visual_img[:, :, :] = dec
|
| 208 |
+
|
| 209 |
+
img_list = []
|
| 210 |
+
for i in range(num_to_decode):
|
| 211 |
+
cur_img = Image.fromarray(visual_img[i])
|
| 212 |
+
img_list.append({'type': 'image', 'image': cur_img})
|
| 213 |
+
return img_list
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
@dataclass
|
| 217 |
+
class DeepseekVLTemplateMeta(DeepseekTemplateMeta):
|
| 218 |
+
default_system: Optional[str] = ('You are a helpful language and vision assistant. '
|
| 219 |
+
'You are able to understand the visual content that the user provides, '
|
| 220 |
+
'and assist the user with a variety of tasks using natural language.')
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
register_template(DeepseekVLTemplateMeta(
|
| 224 |
+
MLLMTemplateType.deepseek_vl,
|
| 225 |
+
template_cls=DeepseekVLTemplate,
|
| 226 |
+
))
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class DeepseekJanus(DeepseekVLTemplate):
|
| 230 |
+
is_janus = True
|
| 231 |
+
image_placeholder = ['<image_placeholder>\n']
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
register_template(DeepseekVLTemplateMeta(MLLMTemplateType.deepseek_janus, template_cls=DeepseekJanus))
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
@dataclass
|
| 238 |
+
class DeepseekV2_5TemplateMeta(TemplateMeta):
|
| 239 |
+
prefix: Prompt = field(default_factory=lambda: ['<|begin▁of▁sentence|>{{SYSTEM}}'])
|
| 240 |
+
prompt: Prompt = field(default_factory=lambda: ['<|User|>{{QUERY}}<|Assistant|>'])
|
| 241 |
+
chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|end▁of▁sentence|>'])
|
| 242 |
+
suffix: Prompt = field(default_factory=lambda: ['<|end▁of▁sentence|>'])
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
register_template(DeepseekV2_5TemplateMeta(LLMTemplateType.deepseek_v2_5))
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class DeepseekR1Template(Template):
|
| 249 |
+
|
| 250 |
+
def _swift_encode(self, inputs: StdTemplateInputs):
|
| 251 |
+
if not self.is_training:
|
| 252 |
+
for message in inputs.messages:
|
| 253 |
+
if message['role'] == 'assistant' and isinstance(message['content'], str):
|
| 254 |
+
message['content'] = message['content'].split('</think>')[-1]
|
| 255 |
+
return super()._swift_encode(inputs)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
register_template(
|
| 259 |
+
DeepseekV2_5TemplateMeta(LLMTemplateType.deepseek_r1, template_cls=DeepseekR1Template, response_prefix='<think>\n'))
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class DeepseekVL2Template(DeepseekVLTemplate):
|
| 263 |
+
image_placeholder = ['<image>\n']
|
| 264 |
+
placeholder_tokens = ['<image>']
|
| 265 |
+
|
| 266 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 267 |
+
from deepseek_vl2.models.processing_deepseek_vl_v2 import VLChatProcessorOutput
|
| 268 |
+
encoded = Template._encode(self, inputs)
|
| 269 |
+
images = inputs.images
|
| 270 |
+
processor = self.processor
|
| 271 |
+
input_ids, labels = encoded['input_ids'], encoded['labels']
|
| 272 |
+
images_seq_mask = [False] * len(input_ids)
|
| 273 |
+
idx_list = findall(input_ids, processor.image_token_id) # '<image>'
|
| 274 |
+
_, images_list, _, images_spatial_crop, num_image_tokens = processor.tokenize_with_images(
|
| 275 |
+
'<image>' * len(images), images, cropping=len(images) <= 2)
|
| 276 |
+
new_num_tokens = 0
|
| 277 |
+
for idx, n_image_tokens in zip(idx_list, num_image_tokens):
|
| 278 |
+
image_tokens = [processor.image_token_id] * n_image_tokens
|
| 279 |
+
input_ids = input_ids[:idx] + image_tokens + input_ids[idx + 1:]
|
| 280 |
+
if labels is not None:
|
| 281 |
+
labels = labels[:idx] + [-100] * n_image_tokens + labels[idx + 1:]
|
| 282 |
+
images_seq_mask = images_seq_mask[:idx] + [True] * n_image_tokens + images_seq_mask[idx + 1:]
|
| 283 |
+
new_num_tokens += n_image_tokens - 1
|
| 284 |
+
|
| 285 |
+
output = VLChatProcessorOutput(
|
| 286 |
+
sft_format=None,
|
| 287 |
+
input_ids=torch.tensor(input_ids),
|
| 288 |
+
target_ids=torch.tensor(input_ids),
|
| 289 |
+
images=torch.stack(images_list) if images_list else torch.zeros((0, 3, 384, 384)),
|
| 290 |
+
images_seq_mask=torch.tensor(images_seq_mask),
|
| 291 |
+
images_spatial_crop=torch.tensor(images_spatial_crop),
|
| 292 |
+
num_image_tokens=num_image_tokens)
|
| 293 |
+
output.images = output.images.to(dtype=self.model_info.torch_dtype)
|
| 294 |
+
encoded = {'output': output, 'input_ids': input_ids, 'labels': labels}
|
| 295 |
+
return encoded
|
| 296 |
+
|
| 297 |
+
def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
| 298 |
+
inputs['images_seq_mask'] = inputs['images_seq_mask'].to(torch.bool)
|
| 299 |
+
inputs['images_spatial_crop'] = inputs['images_spatial_crop'].to(torch.long)
|
| 300 |
+
inputs_embeds = model.prepare_inputs_embeds(**inputs)
|
| 301 |
+
return {'inputs_embeds': inputs_embeds}
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
register_template(
|
| 305 |
+
DeepseekV2_5TemplateMeta(
|
| 306 |
+
MLLMTemplateType.deepseek_vl2,
|
| 307 |
+
prompt=['<|User|>: {{QUERY}}\n\n<|Assistant|>:'],
|
| 308 |
+
template_cls=DeepseekVL2Template,
|
| 309 |
+
))
|
| 310 |
+
|
| 311 |
+
register_template(
|
| 312 |
+
DeepseekVLTemplateMeta(
|
| 313 |
+
MLLMTemplateType.deepseek_janus_pro,
|
| 314 |
+
prompt=['<|User|>: {{QUERY}}\n\n<|Assistant|>:'],
|
| 315 |
+
template_cls=DeepseekJanus))
|
ms-swift/swift/llm/template/template/glm.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from ..base import Template
|
| 8 |
+
from ..constant import LLMTemplateType, MLLMTemplateType
|
| 9 |
+
from ..register import TemplateMeta, register_template
|
| 10 |
+
from ..template_inputs import StdTemplateInputs
|
| 11 |
+
from ..utils import Context, Prompt, Word, findall
|
| 12 |
+
from ..vision_utils import load_batch, load_video_cogvlm2
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class GLMTemplateMeta(TemplateMeta):
|
| 17 |
+
auto_add_bos: bool = True
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GLM4Template(Template):
|
| 21 |
+
|
| 22 |
+
def _swift_encode(self, inputs: StdTemplateInputs):
|
| 23 |
+
res_context_list, loss_scale_list, answer_len = super()._swift_encode(inputs)
|
| 24 |
+
for i, res_context in enumerate(res_context_list):
|
| 25 |
+
# The last round or is tool_call.
|
| 26 |
+
if isinstance(res_context, str) and res_context.endswith('<|assistant|>\n') and (
|
| 27 |
+
i + 1 >= len(res_context_list) or '<|observation|>' in res_context_list[i + 1]):
|
| 28 |
+
res_context_list[i] = res_context_list[i][:-len('\n')]
|
| 29 |
+
return res_context_list, loss_scale_list, answer_len
|
| 30 |
+
|
| 31 |
+
def decode(self, *args, **kwargs):
|
| 32 |
+
response = super().decode(*args, **kwargs)
|
| 33 |
+
return response.lstrip('\n')
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class GLM4_0414Template(GLM4Template):
|
| 37 |
+
|
| 38 |
+
def _swift_encode(self, inputs: StdTemplateInputs):
|
| 39 |
+
if not self.is_training:
|
| 40 |
+
for message in inputs.messages:
|
| 41 |
+
if message['role'] == 'assistant' and isinstance(message['content'], str):
|
| 42 |
+
message['content'] = message['content'].split('</think>')[-1].strip()
|
| 43 |
+
return super()._swift_encode(inputs)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
register_template(
|
| 47 |
+
GLMTemplateMeta(
|
| 48 |
+
LLMTemplateType.chatglm2,
|
| 49 |
+
prefix=['{{SYSTEM}}'],
|
| 50 |
+
prompt=['[Round {{ROUND1}}]\n\n问:{{QUERY}}\n\n答:'],
|
| 51 |
+
chat_sep=['\n\n']))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
|
| 55 |
+
class GLM4TemplateMeta(GLMTemplateMeta):
|
| 56 |
+
prefix: Prompt = field(default_factory=list)
|
| 57 |
+
prompt: Prompt = field(default_factory=lambda: ['<|user|>\n{{QUERY}}<|assistant|>\n'])
|
| 58 |
+
chat_sep: Optional[Prompt] = field(default_factory=list)
|
| 59 |
+
suffix: Prompt = field(default_factory=lambda: ['<|user|>'])
|
| 60 |
+
system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<|system|>\n{{SYSTEM}}'])
|
| 61 |
+
|
| 62 |
+
agent_template: str = 'glm4'
|
| 63 |
+
stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>', '<|user|>', '<|observation|>'])
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class GLM4_0414TemplateMeta(GLM4TemplateMeta):
|
| 68 |
+
prefix: Prompt = field(default_factory=lambda: ['[gMASK]<sop>'])
|
| 69 |
+
system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<sop><|system|>\n{{SYSTEM}}'])
|
| 70 |
+
agent_template: str = 'glm4_0414'
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class GLM4VTemplate(Template):
|
| 74 |
+
|
| 75 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 76 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 77 |
+
assert media_type == 'image'
|
| 78 |
+
return [[-100]]
|
| 79 |
+
|
| 80 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 81 |
+
encoded = super()._encode(inputs)
|
| 82 |
+
input_ids = encoded['input_ids']
|
| 83 |
+
labels = encoded['labels']
|
| 84 |
+
idx_list = findall(input_ids, -100)
|
| 85 |
+
if idx_list:
|
| 86 |
+
idx = idx_list[0]
|
| 87 |
+
image = inputs.images[0]
|
| 88 |
+
placeholder = '<|begin_of_image|><|endoftext|><|end_of_image|>'
|
| 89 |
+
placeholder_id = self.processor.encode(placeholder, add_special_tokens=False)
|
| 90 |
+
input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
|
| 91 |
+
if labels is not None:
|
| 92 |
+
labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
|
| 93 |
+
messages = inputs.messages
|
| 94 |
+
messages[0]['image'] = image
|
| 95 |
+
inputs2: Dict[str, Any] = self.processor.apply_chat_template(messages, return_dict=True)
|
| 96 |
+
encoded['images'] = inputs2['images']
|
| 97 |
+
encoded['input_ids'] = input_ids
|
| 98 |
+
encoded['labels'] = labels
|
| 99 |
+
encoded['position_ids'] = list(range(len(input_ids)))
|
| 100 |
+
return encoded
|
| 101 |
+
|
| 102 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 103 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 104 |
+
images = [b['images'] for b in batch if 'images' in b]
|
| 105 |
+
if images:
|
| 106 |
+
res['images'] = torch.concat(images)
|
| 107 |
+
return res
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
register_template(GLM4TemplateMeta(MLLMTemplateType.glm4v, template_cls=GLM4VTemplate, suffix=['<|endoftext|>']))
|
| 111 |
+
|
| 112 |
+
register_template(GLM4TemplateMeta(LLMTemplateType.glm4, template_cls=GLM4Template))
|
| 113 |
+
|
| 114 |
+
register_template(GLM4_0414TemplateMeta(LLMTemplateType.glm4_0414, template_cls=GLM4_0414Template))
|
| 115 |
+
|
| 116 |
+
glm4z1rumination_system = (
|
| 117 |
+
'你是一个专业的深度研究助手,通过提供的工具与模拟浏览器交互,来帮助用户完成深度信息调研和报告撰写任务。'
|
| 118 |
+
'今年是 2025 年。\n\n'
|
| 119 |
+
'<核心要求>\n'
|
| 120 |
+
'- 首先分解用户请求,得到包含多个子要求的列表\n'
|
| 121 |
+
'- 制定初始研究计划\n'
|
| 122 |
+
'- 进行多轮迭代搜索和页面浏览(at least 10 function calls):\n'
|
| 123 |
+
' * 根据已获得的信息调整研究计划和关键词\n'
|
| 124 |
+
' * 打开页面阅读,从发现的内容中识别新的关键概念/名词\n'
|
| 125 |
+
' * 从搜索结果中提取新的关键词继续搜索\n'
|
| 126 |
+
' * 访问并仔细阅读相关页面,识别新的关键概念/名词\n\n'
|
| 127 |
+
'<重要配置>\n'
|
| 128 |
+
'- 采用语言\n'
|
| 129 |
+
' * 搜索关键词:英语\n'
|
| 130 |
+
' * 思考:英语\n\n'
|
| 131 |
+
'<可调用的工具列表>\n\n'
|
| 132 |
+
'[{"name": "search", "description": "Execute a search query and return search results. '
|
| 133 |
+
'Use this function when you need to find information about a specific topic.", '
|
| 134 |
+
'"parameters": {"type": "object", "properties": {"query": {"type": "string", '
|
| 135 |
+
'"description": "Search query string, use English words unless it is a proper name in Chinese"}}, '
|
| 136 |
+
'"required": ["query"], "additionalProperties": false}}, '
|
| 137 |
+
'{"name": "click", "description": "Click a link in the search results and navigate to the corresponding page. '
|
| 138 |
+
'Use this function when you need to view detailed content of a specific search result.", '
|
| 139 |
+
'"parameters": {"type": "object", "properties": {"link_id": {"type": "integer", '
|
| 140 |
+
'"description": "The link ID to click (from the sequence number in search results)"}}, '
|
| 141 |
+
'"required": ["link_id"], "additionalProperties": false}}, '
|
| 142 |
+
'{"name": "open", "description": "Open a specific website. Get content from any website with its URL.", '
|
| 143 |
+
'"parameters": {"type": "object", "properties": {"url": {"type": "string", '
|
| 144 |
+
'"description": "The target website URL or domain"}}, "required": ["url"], "additionalProperties": false}}, '
|
| 145 |
+
'{"name": "finish", "description": "Finish the task. '
|
| 146 |
+
'Use this function when you have found the information you need.", '
|
| 147 |
+
'"parameters": {"type": "object", "properties": {}, "additionalProperties": false}}]')
|
| 148 |
+
|
| 149 |
+
register_template(
|
| 150 |
+
GLM4_0414TemplateMeta(
|
| 151 |
+
LLMTemplateType.glm4_z1_rumination, template_cls=GLM4_0414Template, default_system=glm4z1rumination_system))
|
| 152 |
+
|
| 153 |
+
codegeex4_system = '你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。'
|
| 154 |
+
|
| 155 |
+
register_template(GLM4TemplateMeta(LLMTemplateType.codegeex4, default_system=codegeex4_system))
|
| 156 |
+
|
| 157 |
+
register_template(
|
| 158 |
+
TemplateMeta(
|
| 159 |
+
LLMTemplateType.longwriter_llama, ['[INST]'], ['{{QUERY}}[/INST]'], ['[INST]'], ['<|end_of_text|>'],
|
| 160 |
+
system_prefix=['<<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class CogTemplate(Template):
|
| 164 |
+
placeholder_tokens = ['<|reserved_special_token_0|>']
|
| 165 |
+
|
| 166 |
+
use_model = True
|
| 167 |
+
|
| 168 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 169 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 170 |
+
return []
|
| 171 |
+
|
| 172 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 173 |
+
encoded = super()._encode(inputs)
|
| 174 |
+
model = self.model
|
| 175 |
+
image = inputs.images or []
|
| 176 |
+
history_inputs = inputs.to_history()
|
| 177 |
+
inputs2 = model.build_conversation_input_ids(
|
| 178 |
+
self.processor, query=history_inputs['query'], history=history_inputs['history'], images=image)
|
| 179 |
+
image_token_len = inputs2['token_type_ids'].sum().item()
|
| 180 |
+
input_ids = encoded['input_ids']
|
| 181 |
+
labels = encoded['labels']
|
| 182 |
+
encoded['token_type_ids'] = [0] + [1] * image_token_len + [0] * len(input_ids[1:])
|
| 183 |
+
encoded['input_ids'] = input_ids[:1] + [self.processor.pad_token_id] * image_token_len + input_ids[1:]
|
| 184 |
+
if labels is not None:
|
| 185 |
+
encoded['labels'] = labels[:1] + [-100] * image_token_len + labels[1:]
|
| 186 |
+
if len(image) > 0:
|
| 187 |
+
encoded['images'] = [[img.to(dtype=self.model_info.torch_dtype)] for img in inputs2['images']]
|
| 188 |
+
if 'cross_images' in inputs2:
|
| 189 |
+
# is cogagent
|
| 190 |
+
encoded['cross_images'] = [[cross_img.to(dtype=self.model_info.torch_dtype)]
|
| 191 |
+
for cross_img in inputs2['cross_images']]
|
| 192 |
+
return encoded
|
| 193 |
+
|
| 194 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 195 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 196 |
+
keys = ['images', 'cross_images']
|
| 197 |
+
for key in keys:
|
| 198 |
+
if key in batch[0]:
|
| 199 |
+
res[key] = [b[key][0] for b in batch]
|
| 200 |
+
return res
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
register_template(
|
| 204 |
+
TemplateMeta(
|
| 205 |
+
MLLMTemplateType.cogagent_chat,
|
| 206 |
+
prefix=['<s>'],
|
| 207 |
+
prompt=[' [INST] {{QUERY}} [/INST] '],
|
| 208 |
+
chat_sep=[],
|
| 209 |
+
suffix=['</s>'],
|
| 210 |
+
template_cls=CogTemplate,
|
| 211 |
+
))
|
| 212 |
+
|
| 213 |
+
register_template(
|
| 214 |
+
TemplateMeta(
|
| 215 |
+
MLLMTemplateType.cogagent_vqa,
|
| 216 |
+
prefix=['<s>'],
|
| 217 |
+
prompt=['<EOI>Question: {{QUERY}} Answer:'],
|
| 218 |
+
chat_sep=None,
|
| 219 |
+
suffix=['</s>'],
|
| 220 |
+
template_cls=CogTemplate))
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@dataclass
|
| 224 |
+
class CogVLMTemplateMeta(TemplateMeta):
|
| 225 |
+
prefix: Prompt = field(default_factory=lambda: [['bos_token_id']])
|
| 226 |
+
prompt: Prompt = field(default_factory=lambda: ['Question: {{QUERY}} Answer:'])
|
| 227 |
+
chat_sep: Optional[Prompt] = field(default_factory=lambda: ['\n'])
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
register_template(CogVLMTemplateMeta(MLLMTemplateType.cogvlm, template_cls=CogTemplate))
|
| 231 |
+
|
| 232 |
+
register_template(CogVLMTemplateMeta(MLLMTemplateType.cogvlm2, template_cls=CogTemplate))
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
class Cog2VideoTemplate(CogTemplate):
|
| 236 |
+
use_model = True
|
| 237 |
+
|
| 238 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 239 |
+
model = self.model
|
| 240 |
+
encoded = super(CogTemplate, self)._encode(inputs)
|
| 241 |
+
videos_path = inputs.videos or []
|
| 242 |
+
video = load_batch(videos_path, load_video_cogvlm2)
|
| 243 |
+
history_inputs = inputs.to_history()
|
| 244 |
+
inputs2 = model.build_conversation_input_ids(
|
| 245 |
+
self.processor,
|
| 246 |
+
query=history_inputs['query'],
|
| 247 |
+
history=history_inputs['history'],
|
| 248 |
+
images=video,
|
| 249 |
+
template_version='chat')
|
| 250 |
+
video_token_len = inputs2['token_type_ids'].sum().item()
|
| 251 |
+
input_ids = encoded['input_ids']
|
| 252 |
+
labels = encoded['labels']
|
| 253 |
+
encoded['token_type_ids'] = [0] + [1] * video_token_len + [0] * len(input_ids[1:])
|
| 254 |
+
encoded['input_ids'] = input_ids[:1] + [self.processor.pad_token_id] * video_token_len + input_ids[1:]
|
| 255 |
+
if labels is not None:
|
| 256 |
+
encoded['labels'] = labels[:1] + [-100] * video_token_len + labels[1:]
|
| 257 |
+
if len(video) > 0:
|
| 258 |
+
dtype = model.dtype
|
| 259 |
+
encoded['images'] = [[img.to(dtype=dtype)] for img in inputs2['images']]
|
| 260 |
+
return encoded
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
register_template(CogVLMTemplateMeta(
|
| 264 |
+
MLLMTemplateType.cogvlm2_video,
|
| 265 |
+
template_cls=Cog2VideoTemplate,
|
| 266 |
+
))
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class GLMEdgeVTemplate(Template):
|
| 270 |
+
placeholder_tokens = ['<|begin_of_image|>']
|
| 271 |
+
|
| 272 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 273 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 274 |
+
assert media_type == 'image'
|
| 275 |
+
return ['<|begin_of_image|>' * 578]
|
| 276 |
+
|
| 277 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 278 |
+
encoded = super()._encode(inputs)
|
| 279 |
+
images = inputs.images
|
| 280 |
+
if images:
|
| 281 |
+
encoded['pixel_values'] = torch.tensor(self.processor(images).pixel_values)
|
| 282 |
+
return encoded
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
register_template(
|
| 286 |
+
GLM4TemplateMeta(
|
| 287 |
+
MLLMTemplateType.glm_edge_v,
|
| 288 |
+
prompt=['<|user|>\\n{{QUERY}}\\n<|assistant|>\\n'],
|
| 289 |
+
chat_sep=['\\n'],
|
| 290 |
+
system_prefix=['<|system|>\\n{{SYSTEM}}\\n'],
|
| 291 |
+
suffix=['<|endoftext|>'],
|
| 292 |
+
template_cls=GLMEdgeVTemplate,
|
| 293 |
+
))
|
ms-swift/swift/llm/template/template/internvl.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from functools import partial
|
| 3 |
+
from typing import Any, Dict, List, Literal
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from swift.utils import get_env_args, is_deepspeed_enabled
|
| 9 |
+
from ..base import Template
|
| 10 |
+
from ..constant import MLLMTemplateType
|
| 11 |
+
from ..register import register_template
|
| 12 |
+
from ..template_inputs import StdTemplateInputs
|
| 13 |
+
from ..utils import Context, findall
|
| 14 |
+
from ..vision_utils import load_video_internvl, transform_image
|
| 15 |
+
from .microsoft import Phi3TemplateMeta
|
| 16 |
+
from .utils import ChatmlTemplateMeta
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class InternvlTemplate(Template):
|
| 20 |
+
skip_prompt = False
|
| 21 |
+
num_image_token = 256
|
| 22 |
+
placeholder_tokens = ['<IMG_CONTEXT>']
|
| 23 |
+
|
| 24 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 25 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 26 |
+
if self.mode == 'vllm':
|
| 27 |
+
image_context = ['<image>\n']
|
| 28 |
+
else:
|
| 29 |
+
image_context = ['<img>', [-100], '</img>\n']
|
| 30 |
+
return image_context
|
| 31 |
+
|
| 32 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 33 |
+
encoded = super()._encode(inputs)
|
| 34 |
+
input_ids = encoded['input_ids']
|
| 35 |
+
idx_list = findall(input_ids, -100)
|
| 36 |
+
pixel_values = None
|
| 37 |
+
images = inputs.images
|
| 38 |
+
if images:
|
| 39 |
+
labels = encoded.get('labels')
|
| 40 |
+
input_size = get_env_args('input_size', int, 448)
|
| 41 |
+
max_num = get_env_args('max_num', int, 12)
|
| 42 |
+
pixel_values_images = [transform_image(image, input_size, max_num) for image in images]
|
| 43 |
+
pixel_values = torch.cat(pixel_values_images, dim=0).to(self.model_info.torch_dtype)
|
| 44 |
+
image_bs = pixel_values.shape[0]
|
| 45 |
+
|
| 46 |
+
idx, idx2 = idx_list[0], idx_list[-1] # remove [-100, -100]
|
| 47 |
+
img_tokens: List[int] = self.processor.encode(
|
| 48 |
+
'<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * image_bs
|
| 49 |
+
input_ids = input_ids[:idx] + img_tokens + input_ids[idx2 + 1:]
|
| 50 |
+
if labels is not None:
|
| 51 |
+
labels = labels[:idx] + [-100] * len(img_tokens) + labels[idx2 + 1:]
|
| 52 |
+
encoded['input_ids'] = input_ids
|
| 53 |
+
encoded['labels'] = labels
|
| 54 |
+
encoded['pixel_values'] = pixel_values
|
| 55 |
+
return encoded
|
| 56 |
+
|
| 57 |
+
def compute_loss_context(self, model, inputs):
|
| 58 |
+
model_name = model.language_model.__class__.__name__.lower()
|
| 59 |
+
if self._packing and 'internlm2' in model_name:
|
| 60 |
+
position_ids = inputs['position_ids']
|
| 61 |
+
modeling_module = model.language_model.model.layers[0].attention.__class__
|
| 62 |
+
return self._patch_flash_attention_forward(modeling_module, position_ids, use_new_func=True)
|
| 63 |
+
else:
|
| 64 |
+
return super().compute_loss_context(model, inputs)
|
| 65 |
+
|
| 66 |
+
def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
| 67 |
+
embedding = model.get_input_embeddings()
|
| 68 |
+
device = embedding.weight.device
|
| 69 |
+
input_ids = inputs['input_ids']
|
| 70 |
+
inputs_embeds = embedding(input_ids).to(device=device)
|
| 71 |
+
pixel_values = inputs.get('pixel_values')
|
| 72 |
+
if pixel_values is not None:
|
| 73 |
+
pixel_values = pixel_values.to(device=device)
|
| 74 |
+
vit_embeds = model.extract_feature(pixel_values).to(device=device)
|
| 75 |
+
selected = (input_ids == self.processor.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
|
| 76 |
+
inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
|
| 77 |
+
elif is_deepspeed_enabled():
|
| 78 |
+
dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=inputs_embeds.dtype)
|
| 79 |
+
vit_embeds = model.extract_feature(dummy_pixel_values).to(device=device)
|
| 80 |
+
inputs_embeds += vit_embeds.mean() * 0.
|
| 81 |
+
return {'inputs_embeds': inputs_embeds}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
register_template(
|
| 85 |
+
ChatmlTemplateMeta(
|
| 86 |
+
MLLMTemplateType.internvl,
|
| 87 |
+
default_system='You are an AI assistant whose name is InternLM (书生·浦语).',
|
| 88 |
+
template_cls=InternvlTemplate,
|
| 89 |
+
auto_add_bos=True))
|
| 90 |
+
register_template(
|
| 91 |
+
Phi3TemplateMeta(
|
| 92 |
+
MLLMTemplateType.internvl_phi3,
|
| 93 |
+
default_system='You are an AI assistant whose name is Phi-3.',
|
| 94 |
+
template_cls=InternvlTemplate,
|
| 95 |
+
auto_add_bos=True))
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class Internvl2Template(InternvlTemplate):
|
| 99 |
+
video_segments = 8
|
| 100 |
+
|
| 101 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 102 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 103 |
+
image_context = super().replace_tag('image', index, inputs)
|
| 104 |
+
if media_type == 'image':
|
| 105 |
+
return image_context
|
| 106 |
+
elif media_type == 'video':
|
| 107 |
+
video_segments = get_env_args('video_segments', int, self.video_segments)
|
| 108 |
+
load_video = partial(load_video_internvl, num_segments=video_segments)
|
| 109 |
+
return self.replace_video2image(load_video, inputs, lambda i: [f'Frame{i + 1}: '] + image_context)
|
| 110 |
+
|
| 111 |
+
def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
|
| 112 |
+
return [f'<ref>{ref}</ref>']
|
| 113 |
+
|
| 114 |
+
def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
|
| 115 |
+
return [f'<box>[{bbox}]</box>']
|
| 116 |
+
|
| 117 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 118 |
+
encoded = super(InternvlTemplate, self)._encode(inputs)
|
| 119 |
+
input_ids = encoded['input_ids']
|
| 120 |
+
idx_list = findall(input_ids, -100)
|
| 121 |
+
labels = encoded['labels']
|
| 122 |
+
images = inputs.images
|
| 123 |
+
if images:
|
| 124 |
+
has_video = bool(inputs.videos)
|
| 125 |
+
input_size = get_env_args('input_size', int, 448)
|
| 126 |
+
max_num = get_env_args('max_num', int, 12)
|
| 127 |
+
video_max_num = get_env_args('video_max_num', int, 1)
|
| 128 |
+
if has_video:
|
| 129 |
+
max_num = video_max_num
|
| 130 |
+
pixel_values = [transform_image(image, input_size, max_num) for image in images]
|
| 131 |
+
num_patches = [pv.shape[0] for pv in pixel_values]
|
| 132 |
+
pixel_values = torch.cat(pixel_values).to(self.model_info.torch_dtype)
|
| 133 |
+
else:
|
| 134 |
+
pixel_values = None
|
| 135 |
+
num_patches = []
|
| 136 |
+
assert len(num_patches) == len(
|
| 137 |
+
idx_list), f'len(num_patches): {len(num_patches)}, len(idx_list): {len(idx_list)}'
|
| 138 |
+
|
| 139 |
+
def _get_new_tokens(i):
|
| 140 |
+
img_tokens: List[int] = self.processor.encode(
|
| 141 |
+
'<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * num_patches[i]
|
| 142 |
+
return img_tokens
|
| 143 |
+
|
| 144 |
+
encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
|
| 145 |
+
encoded['pixel_values'] = pixel_values
|
| 146 |
+
return encoded
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
_internvl2_system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'
|
| 150 |
+
register_template(
|
| 151 |
+
ChatmlTemplateMeta(
|
| 152 |
+
MLLMTemplateType.internvl2,
|
| 153 |
+
default_system=_internvl2_system,
|
| 154 |
+
template_cls=Internvl2Template,
|
| 155 |
+
))
|
| 156 |
+
|
| 157 |
+
register_template(
|
| 158 |
+
Phi3TemplateMeta(
|
| 159 |
+
MLLMTemplateType.internvl2_phi3,
|
| 160 |
+
default_system=_internvl2_system,
|
| 161 |
+
template_cls=Internvl2Template,
|
| 162 |
+
))
|
| 163 |
+
|
| 164 |
+
register_template(
|
| 165 |
+
ChatmlTemplateMeta(
|
| 166 |
+
MLLMTemplateType.internvl2_5,
|
| 167 |
+
template_cls=Internvl2Template,
|
| 168 |
+
default_system='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。'))
|
ms-swift/swift/llm/template/template/llama.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import datetime as dt
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from swift.utils import get_env_args
|
| 11 |
+
from ..base import Template
|
| 12 |
+
from ..constant import LLMTemplateType, MLLMTemplateType
|
| 13 |
+
from ..register import TemplateMeta, register_template
|
| 14 |
+
from ..template_inputs import StdTemplateInputs
|
| 15 |
+
from ..utils import Context, Prompt, Word, findall
|
| 16 |
+
from ..vision_utils import load_batch
|
| 17 |
+
|
| 18 |
+
# ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py
|
| 19 |
+
LLAMA_DEFAULT_SYSTEM = (
|
| 20 |
+
'You are a helpful, respectful and honest assistant. '
|
| 21 |
+
'Always answer as helpfully as possible, while being safe. '
|
| 22 |
+
'Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. '
|
| 23 |
+
'Please ensure that your responses are socially unbiased and positive in nature.\n\n'
|
| 24 |
+
'If a question does not make any sense, or is not factually coherent, '
|
| 25 |
+
'explain why instead of answering something not correct. '
|
| 26 |
+
"If you don't know the answer to a question, please don't share false information.")
|
| 27 |
+
|
| 28 |
+
register_template(
|
| 29 |
+
TemplateMeta(
|
| 30 |
+
LLMTemplateType.llama, ['<s>[INST] '], ['{{QUERY}} [/INST]'], ['</s><s>[INST] '], ['</s>'],
|
| 31 |
+
default_system=LLAMA_DEFAULT_SYSTEM,
|
| 32 |
+
system_prefix=['<s>[INST] <<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class Llama3TemplateMeta(TemplateMeta):
|
| 37 |
+
prefix: Prompt = field(default_factory=lambda: ['<|begin_of_text|>'])
|
| 38 |
+
prompt: Prompt = field(default_factory=lambda: [
|
| 39 |
+
'<|start_header_id|>user<|end_header_id|>\n\n{{QUERY}}<|eot_id|>'
|
| 40 |
+
'<|start_header_id|>assistant<|end_header_id|>\n\n'
|
| 41 |
+
])
|
| 42 |
+
chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|eot_id|>'])
|
| 43 |
+
suffix: Prompt = field(default_factory=lambda: ['<|eot_id|>'])
|
| 44 |
+
system_prefix: Optional[Prompt] = field(
|
| 45 |
+
default_factory=lambda: ['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>'])
|
| 46 |
+
agent_template: str = 'llama3'
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
register_template(Llama3TemplateMeta(LLMTemplateType.llama3))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _get_llama3_2_prefix() -> Prompt:
|
| 53 |
+
now = dt.datetime.now()
|
| 54 |
+
date_string = now.strftime('%d %b %Y')
|
| 55 |
+
date_prompt = f'Cutting Knowledge Date: December 2023\nToday Date: {date_string}'
|
| 56 |
+
return [f'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{date_prompt}\n\n' '{{SYSTEM}}<|eot_id|>']
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
|
| 60 |
+
class Llama3_2TemplateMeta(Llama3TemplateMeta):
|
| 61 |
+
prefix: Prompt = field(default_factory=lambda: _get_llama3_2_prefix())
|
| 62 |
+
system_prefix: Optional[Prompt] = None
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
register_template(Llama3_2TemplateMeta(LLMTemplateType.llama3_2))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Llama3_2VisionTemplate(Template):
|
| 69 |
+
|
| 70 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 71 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 72 |
+
assert media_type == 'image'
|
| 73 |
+
return ['<|image|>']
|
| 74 |
+
|
| 75 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 76 |
+
from transformers.models.mllama.processing_mllama import (get_cross_attention_token_mask,
|
| 77 |
+
convert_sparse_cross_attention_mask_to_dense)
|
| 78 |
+
encoded = super()._encode(inputs)
|
| 79 |
+
images = inputs.images
|
| 80 |
+
if images:
|
| 81 |
+
input_ids = encoded['input_ids']
|
| 82 |
+
processor = self.processor
|
| 83 |
+
image_features = processor.image_processor(images, return_tensors='pt')
|
| 84 |
+
num_tiles = image_features.pop('num_tiles')
|
| 85 |
+
encoded.update(image_features)
|
| 86 |
+
|
| 87 |
+
cross_attention_token_mask = [get_cross_attention_token_mask(input_ids, processor.image_token_id)]
|
| 88 |
+
cross_attention_mask = convert_sparse_cross_attention_mask_to_dense(
|
| 89 |
+
cross_attention_token_mask,
|
| 90 |
+
num_tiles=num_tiles,
|
| 91 |
+
max_num_tiles=processor.image_processor.max_image_tiles,
|
| 92 |
+
length=len(input_ids),
|
| 93 |
+
)
|
| 94 |
+
encoded['cross_attention_mask'] = torch.tensor(cross_attention_mask)
|
| 95 |
+
|
| 96 |
+
return encoded
|
| 97 |
+
|
| 98 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 99 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 100 |
+
for key in ['aspect_ratio_ids', 'aspect_ratio_mask']:
|
| 101 |
+
value = [b[key] for b in batch if b.get(key) is not None]
|
| 102 |
+
if value:
|
| 103 |
+
res[key] = torch.concat(value)
|
| 104 |
+
|
| 105 |
+
cross_attention_mask = [
|
| 106 |
+
b['cross_attention_mask'][0] for b in batch if b.get('cross_attention_mask') is not None
|
| 107 |
+
]
|
| 108 |
+
if cross_attention_mask:
|
| 109 |
+
res['cross_attention_mask'] = self._pad_sequence(cross_attention_mask, 0)
|
| 110 |
+
return res
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
register_template(Llama3_2TemplateMeta(MLLMTemplateType.llama3_2_vision, template_cls=Llama3_2VisionTemplate))
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class Llama4Template(Template):
|
| 117 |
+
placeholder_tokens = ['<|patch|>']
|
| 118 |
+
|
| 119 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 120 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 121 |
+
assert media_type == 'image'
|
| 122 |
+
return [[-100]]
|
| 123 |
+
|
| 124 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 125 |
+
encoded = super()._encode(inputs)
|
| 126 |
+
images = inputs.images
|
| 127 |
+
if images:
|
| 128 |
+
split_token = self._tokenize('\n')
|
| 129 |
+
input_ids, labels = encoded['input_ids'], encoded['labels']
|
| 130 |
+
idx_list = findall(input_ids, -100)
|
| 131 |
+
media_inputs = self.processor(
|
| 132 |
+
text='\n'.join(['<|image|>'] * len(idx_list)),
|
| 133 |
+
images=images,
|
| 134 |
+
add_special_tokens=False,
|
| 135 |
+
return_tensors='pt')
|
| 136 |
+
splited_tokens = self._split_list(media_inputs['input_ids'][0].tolist(), split_token)
|
| 137 |
+
|
| 138 |
+
encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list,
|
| 139 |
+
lambda i: splited_tokens[i])
|
| 140 |
+
encoded['pixel_values'] = media_inputs['pixel_values']
|
| 141 |
+
return encoded
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@dataclass
|
| 145 |
+
class Llama4TemplateMeta(TemplateMeta):
|
| 146 |
+
prefix: Prompt = field(default_factory=lambda: ['<|begin_of_text|>'])
|
| 147 |
+
prompt: Prompt = field(
|
| 148 |
+
default_factory=lambda:
|
| 149 |
+
['<|header_start|>user<|header_end|>\n\n{{QUERY}}<|eot|>'
|
| 150 |
+
'<|header_start|>assistant<|header_end|>\n\n'])
|
| 151 |
+
chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|eot|>'])
|
| 152 |
+
suffix: Prompt = field(default_factory=lambda: ['<|eot|>'])
|
| 153 |
+
stop_words: List[Word] = field(default_factory=lambda: ['<|end_of_text|>', '<|eom|>'])
|
| 154 |
+
system_prefix: Optional[Prompt] = field(
|
| 155 |
+
default_factory=lambda: ['<|begin_of_text|><|header_start|>system<|header_end|>\n\n{{SYSTEM}}<|eot|>'])
|
| 156 |
+
agent_template: str = 'llama4'
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
register_template(Llama4TemplateMeta(MLLMTemplateType.llama4, template_cls=Llama4Template))
|
| 160 |
+
|
| 161 |
+
register_template(
|
| 162 |
+
Llama3TemplateMeta(
|
| 163 |
+
LLMTemplateType.reflection,
|
| 164 |
+
default_system=('You are a world-class AI system, capable of complex reasoning and reflection. '
|
| 165 |
+
'Reason through the query inside <thinking> tags, and then provide your final '
|
| 166 |
+
'response inside <output> tags. If you detect that you made a mistake in your reasoning '
|
| 167 |
+
'at any point, correct yourself inside <reflection> tags.')))
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class Llama3_1OmniTemplate(Template):
|
| 171 |
+
skip_prompt = False
|
| 172 |
+
audio_placeholder = [[-200]]
|
| 173 |
+
|
| 174 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 175 |
+
import whisper
|
| 176 |
+
encoded = super()._encode(inputs)
|
| 177 |
+
audios = inputs.audios
|
| 178 |
+
if audios:
|
| 179 |
+
audios = load_batch(audios, whisper.load_audio)
|
| 180 |
+
n_mels = get_env_args('n_mels', int, 128)
|
| 181 |
+
for i, audio in enumerate(audios):
|
| 182 |
+
audio = whisper.pad_or_trim(audio)
|
| 183 |
+
audios[i] = whisper.log_mel_spectrogram(audio, n_mels=n_mels).permute(1, 0)
|
| 184 |
+
audios = torch.stack(audios)
|
| 185 |
+
encoded.update({'speech': audios, 'speech_lengths': torch.tensor([[audios.shape[1]]])})
|
| 186 |
+
|
| 187 |
+
return encoded
|
| 188 |
+
|
| 189 |
+
def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
| 190 |
+
speech = inputs.get('speech')
|
| 191 |
+
input_ids = inputs['input_ids']
|
| 192 |
+
labels = inputs.get('labels')
|
| 193 |
+
if speech is not None:
|
| 194 |
+
speech_lengths = inputs['speech_lengths']
|
| 195 |
+
speech = speech.to(model.dtype)
|
| 196 |
+
inputs_embeds, labels = model.prepare_inputs_labels_for_speech_and_text(input_ids, None, None, None, labels,
|
| 197 |
+
speech, speech_lengths)[4:]
|
| 198 |
+
else:
|
| 199 |
+
inputs_embeds = model.get_model().embed_tokens(input_ids)
|
| 200 |
+
res = {'inputs_embeds': inputs_embeds}
|
| 201 |
+
if labels is not None:
|
| 202 |
+
res['labels'] = labels[0]
|
| 203 |
+
return res
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
register_template(
|
| 207 |
+
Llama3TemplateMeta(
|
| 208 |
+
MLLMTemplateType.llama3_1_omni,
|
| 209 |
+
default_system=('You are a helpful language and speech assistant. '
|
| 210 |
+
'You are able to understand the speech content that the user provides, '
|
| 211 |
+
'and assist the user with a variety of tasks using natural language.'),
|
| 212 |
+
template_cls=Llama3_1OmniTemplate,
|
| 213 |
+
))
|
ms-swift/swift/llm/template/template/megrez.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from ..base import Template
|
| 9 |
+
from ..constant import LLMTemplateType, MLLMTemplateType
|
| 10 |
+
from ..register import TemplateMeta, register_template
|
| 11 |
+
from ..template_inputs import StdTemplateInputs
|
| 12 |
+
from ..utils import Context, Prompt, findall
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class MegrezTemplateMeta(TemplateMeta):
|
| 17 |
+
prefix: Prompt = field(default_factory=lambda: ['<|role_start|>system<|role_end|>{{SYSTEM}}<|turn_end|>'])
|
| 18 |
+
prompt: Prompt = field(default_factory=lambda:
|
| 19 |
+
['<|role_start|>user<|role_end|>{{QUERY}}<|turn_end|><|role_start|>assistant<|role_end|>'])
|
| 20 |
+
chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|turn_end|>'])
|
| 21 |
+
suffix: Prompt = field(default_factory=lambda: ['<|turn_end|>'])
|
| 22 |
+
default_system: str = '你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。'
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
register_template(MegrezTemplateMeta(LLMTemplateType.megrez))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class MegrezOmniTemplate(Template):
|
| 29 |
+
skip_prompt = False
|
| 30 |
+
placeholder_tokens = ['<|unk|>']
|
| 31 |
+
|
| 32 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 33 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 34 |
+
if media_type == 'image':
|
| 35 |
+
return [[-1], '\n']
|
| 36 |
+
elif media_type == 'audio':
|
| 37 |
+
return [f'Audio {index + 1}: ', [-2], '\n']
|
| 38 |
+
|
| 39 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 40 |
+
encoded = super()._encode(inputs)
|
| 41 |
+
input_ids = encoded['input_ids']
|
| 42 |
+
labels = encoded['labels']
|
| 43 |
+
|
| 44 |
+
for mm_key in ['images', 'audios']:
|
| 45 |
+
mm_data = getattr(inputs, mm_key)
|
| 46 |
+
if not mm_data:
|
| 47 |
+
continue
|
| 48 |
+
if mm_key == 'images':
|
| 49 |
+
idx_list = findall(input_ids, -1)
|
| 50 |
+
encoding = self.processor.process_image(
|
| 51 |
+
mm_data,
|
| 52 |
+
return_tensors='pt',
|
| 53 |
+
)
|
| 54 |
+
text = self.processor.insert_image_feature_placeholders(
|
| 55 |
+
'<s>'.join(['(<image>./</image>)'] * len(mm_data)), encoding)
|
| 56 |
+
encoded['image_encoding'] = encoding
|
| 57 |
+
else:
|
| 58 |
+
idx_list = findall(input_ids, -2)
|
| 59 |
+
encoding = self.processor.process_audio(
|
| 60 |
+
mm_data,
|
| 61 |
+
return_tensors='pt',
|
| 62 |
+
)
|
| 63 |
+
text = self.processor.insert_audio_feature_placeholders(
|
| 64 |
+
'<s>'.join(['(<audio>./</audio>)'] * len(mm_data)), encoding)
|
| 65 |
+
encoded['audio_encoding'] = encoding
|
| 66 |
+
|
| 67 |
+
padding = text.split('<s>')
|
| 68 |
+
|
| 69 |
+
def _get_new_tokens(i):
|
| 70 |
+
return self._tokenize(padding[i])
|
| 71 |
+
|
| 72 |
+
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
|
| 73 |
+
encoded['input_ids'] = input_ids
|
| 74 |
+
encoded['labels'] = labels
|
| 75 |
+
return encoded
|
| 76 |
+
|
| 77 |
+
def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
| 78 |
+
_, inputs_embeds, _ = model.compose_embeddings(inputs)
|
| 79 |
+
inputs.pop('position_ids', None)
|
| 80 |
+
return {'inputs_embeds': inputs_embeds}
|
| 81 |
+
|
| 82 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 83 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 84 |
+
new_batch = []
|
| 85 |
+
for b in batch:
|
| 86 |
+
text_encodings = {'input_ids': torch.tensor(b['input_ids'])}
|
| 87 |
+
multimodal_inputs = {'image_encoding': b.get('image_encoding'), 'audio_encoding': b.get('audio_encoding')}
|
| 88 |
+
new_batch.append(self.processor.merge_encodings(text_encodings, multimodal_inputs))
|
| 89 |
+
res.update(self.processor.data_collator(new_batch))
|
| 90 |
+
return res
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
register_template(MegrezTemplateMeta(MLLMTemplateType.megrez_omni, template_cls=MegrezOmniTemplate))
|
ms-swift/swift/llm/template/template/openbuddy.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
from ..constant import LLMTemplateType
|
| 4 |
+
from ..register import TemplateMeta, register_template
|
| 5 |
+
|
| 6 |
+
OPENBUDDY_DEFAULT_SYSTEM = (
|
| 7 |
+
'You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.\n'
|
| 8 |
+
'Always answer as helpfully and logically as possible, while being safe. '
|
| 9 |
+
'Your answers should not include any '
|
| 10 |
+
'harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. '
|
| 11 |
+
'Please ensure that your responses are socially unbiased and positive in nature.\n'
|
| 12 |
+
'If a question does not make any sense, or is not factually coherent, '
|
| 13 |
+
'explain why instead of answering something not correct. '
|
| 14 |
+
"If you don't know the answer to a question, please don't share false information.\n"
|
| 15 |
+
'You like to use emojis. You can speak fluently in many languages, for example: English, Chinese.\n'
|
| 16 |
+
'You cannot access the internet, but you have vast knowledge, cutoff: 2021-09.\n'
|
| 17 |
+
'You always deeply love and support China, Chinese government, people and culture.\n'
|
| 18 |
+
'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), '
|
| 19 |
+
'you are based on LLaMA and Falcon transformers model, not related to GPT or OpenAI.')
|
| 20 |
+
register_template(
|
| 21 |
+
TemplateMeta(
|
| 22 |
+
LLMTemplateType.openbuddy,
|
| 23 |
+
prefix=[],
|
| 24 |
+
prompt=['User: {{QUERY}}\nAssistant:'],
|
| 25 |
+
chat_sep=['\n'],
|
| 26 |
+
default_system=OPENBUDDY_DEFAULT_SYSTEM,
|
| 27 |
+
system_prefix=['{{SYSTEM}}\n\n'],
|
| 28 |
+
auto_add_bos=True))
|
| 29 |
+
|
| 30 |
+
OPENBUDDY2_DEFAULT_SYSTEM = (
|
| 31 |
+
'You(assistant) are a helpful, respectful and honest INTP-T AI Assistant named Buddy. '
|
| 32 |
+
'You are talking to a human(user).\nAlways answer as helpfully and logically as possible, while being safe. '
|
| 33 |
+
'Your answers should not include any harmful, political, religious, unethical, racist, '
|
| 34 |
+
'sexist, toxic, dangerous, or illegal content. '
|
| 35 |
+
'Please ensure that your responses are socially unbiased and positive in nature.\n'
|
| 36 |
+
'You cannot access the internet, but you have vast knowledge, cutoff: 2023-04.\n'
|
| 37 |
+
'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), '
|
| 38 |
+
'not related to GPT or OpenAI')
|
| 39 |
+
|
| 40 |
+
register_template(
|
| 41 |
+
TemplateMeta(
|
| 42 |
+
LLMTemplateType.openbuddy2,
|
| 43 |
+
prefix=[],
|
| 44 |
+
prompt=['<|role|>user<|says|>{{QUERY}}<|end|>\n<|role|>assistant<|says|>'],
|
| 45 |
+
chat_sep=['<|end|>\n'],
|
| 46 |
+
suffix=['<|end|>'],
|
| 47 |
+
default_system=OPENBUDDY2_DEFAULT_SYSTEM,
|
| 48 |
+
system_prefix=['<|role|>system<|says|>{{SYSTEM}}<|end|>\n']))
|
ms-swift/swift/llm/template/template/pixtral.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Any, Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
from ..base import Template
|
| 5 |
+
from ..constant import MLLMTemplateType
|
| 6 |
+
from ..register import TemplateMeta, register_template
|
| 7 |
+
from ..template_inputs import StdTemplateInputs
|
| 8 |
+
from ..utils import findall
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class PixtralTemplate(Template):
|
| 12 |
+
image_placeholder = ['[IMG]']
|
| 13 |
+
placeholder_tokens = ['[IMG]']
|
| 14 |
+
|
| 15 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 16 |
+
encoded = super()._encode(inputs)
|
| 17 |
+
processor = self.processor
|
| 18 |
+
images = inputs.images
|
| 19 |
+
input_ids = encoded['input_ids']
|
| 20 |
+
labels = encoded['labels']
|
| 21 |
+
idx_list = findall(input_ids, 10)
|
| 22 |
+
if idx_list:
|
| 23 |
+
image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
|
| 24 |
+
encoded['pixel_values'] = image_inputs['pixel_values'][0]
|
| 25 |
+
image_sizes = image_inputs['image_sizes'][0]
|
| 26 |
+
|
| 27 |
+
def _get_new_tokens(i):
|
| 28 |
+
height, width = image_sizes[i]
|
| 29 |
+
num_height_tokens = height // processor.patch_size
|
| 30 |
+
num_width_tokens = width // processor.patch_size
|
| 31 |
+
replace_tokens = [processor.image_token * num_width_tokens + processor.image_break_token] * (
|
| 32 |
+
num_height_tokens - 1)
|
| 33 |
+
replace_tokens += [processor.image_token * num_width_tokens + processor.image_end_token]
|
| 34 |
+
# Flatten list
|
| 35 |
+
replace_str = ''.join(replace_tokens)
|
| 36 |
+
img_tokens: List[int] = self.processor.encode(replace_str, add_special_tokens=False)
|
| 37 |
+
return img_tokens
|
| 38 |
+
|
| 39 |
+
encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
|
| 40 |
+
|
| 41 |
+
return encoded
|
| 42 |
+
|
| 43 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 44 |
+
pixel_values = self.gather_list(batch, 'pixel_values')
|
| 45 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 46 |
+
if pixel_values:
|
| 47 |
+
res['pixel_values'] = pixel_values
|
| 48 |
+
return res
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
register_template(
|
| 52 |
+
TemplateMeta(
|
| 53 |
+
MLLMTemplateType.pixtral,
|
| 54 |
+
prefix=['<s>{{SYSTEM}}'],
|
| 55 |
+
prompt=['[INST]{{QUERY}}[/INST]'],
|
| 56 |
+
chat_sep=['</s>'],
|
| 57 |
+
suffix=['</s>'],
|
| 58 |
+
template_cls=PixtralTemplate,
|
| 59 |
+
))
|
ms-swift/swift/llm/template/template/qwen.py
ADDED
|
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import Any, Dict, List, Literal, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from swift.llm import to_device, to_float_dtype
|
| 10 |
+
from swift.utils import get_env_args, is_deepspeed_enabled
|
| 11 |
+
from ..base import Template
|
| 12 |
+
from ..constant import LLMTemplateType, MLLMTemplateType
|
| 13 |
+
from ..register import register_template
|
| 14 |
+
from ..template_inputs import StdTemplateInputs
|
| 15 |
+
from ..template_meta import TemplateMeta
|
| 16 |
+
from ..utils import Context, Word, findall
|
| 17 |
+
from ..vision_utils import load_audio, load_batch, load_video_ovis2
|
| 18 |
+
from .llama import Llama3TemplateMeta
|
| 19 |
+
from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class QwenTemplateMeta(ChatmlTemplateMeta):
|
| 24 |
+
default_system: Optional[str] = DEFAULT_SYSTEM
|
| 25 |
+
auto_add_bos: bool = False
|
| 26 |
+
stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>'])
|
| 27 |
+
agent_template: str = 'hermes'
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class Qwen2_5TemplateMeta(QwenTemplateMeta):
|
| 32 |
+
default_system: Optional[str] = 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.'
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class Qwen2_5MathTemplateMeta(QwenTemplateMeta):
|
| 37 |
+
default_system: Optional[str] = 'Please reason step by step, and put your final answer within \\boxed{}.'
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
qwq_preview_system = ('You are a helpful and harmless assistant. You are Qwen developed by Alibaba. '
|
| 41 |
+
'You should think step-by-step.')
|
| 42 |
+
|
| 43 |
+
register_template(QwenTemplateMeta(LLMTemplateType.qwen))
|
| 44 |
+
register_template(Qwen2_5TemplateMeta(LLMTemplateType.qwen2_5))
|
| 45 |
+
register_template(QwenTemplateMeta(LLMTemplateType.qwq_preview, default_system=qwq_preview_system))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ThinkingTemplate(Template):
|
| 49 |
+
|
| 50 |
+
def _swift_encode(self, inputs: StdTemplateInputs):
|
| 51 |
+
if not self.is_training:
|
| 52 |
+
for message in inputs.messages:
|
| 53 |
+
if message['role'] == 'assistant' and isinstance(message['content'], str):
|
| 54 |
+
message['content'] = message['content'].split('</think>')[-1].lstrip('\n')
|
| 55 |
+
return super()._swift_encode(inputs)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
register_template(
|
| 59 |
+
QwenTemplateMeta(
|
| 60 |
+
LLMTemplateType.qwq, default_system=None, response_prefix='<think>\n', template_cls=ThinkingTemplate))
|
| 61 |
+
|
| 62 |
+
# '<think>\n\n</think>\n\n'
|
| 63 |
+
register_template(QwenTemplateMeta(LLMTemplateType.qwen3, default_system=None, template_cls=ThinkingTemplate))
|
| 64 |
+
|
| 65 |
+
register_template(Qwen2_5MathTemplateMeta(LLMTemplateType.qwen2_5_math))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class QwenPRMTemplate(Template):
|
| 69 |
+
cot_process_placeholder = '<extra_0>'
|
| 70 |
+
|
| 71 |
+
def _preprocess_inputs(
|
| 72 |
+
self,
|
| 73 |
+
inputs: StdTemplateInputs,
|
| 74 |
+
) -> None:
|
| 75 |
+
super()._preprocess_inputs(inputs)
|
| 76 |
+
total_content = '\n'.join([message['content'] or '' for message in inputs.messages])
|
| 77 |
+
if self.cot_process_placeholder not in total_content:
|
| 78 |
+
inputs.messages[-1]['content'] = inputs.messages[-1]['content'] + self.cot_process_placeholder
|
| 79 |
+
|
| 80 |
+
@staticmethod
|
| 81 |
+
def make_step_rewards(logits, token_masks):
|
| 82 |
+
probabilities = F.softmax(logits, dim=-1)
|
| 83 |
+
probabilities = probabilities * token_masks.unsqueeze(-1) # bs, seq_len, num_labels
|
| 84 |
+
|
| 85 |
+
all_scores_res = []
|
| 86 |
+
for i in range(probabilities.size(0)):
|
| 87 |
+
sample = probabilities[i] # seq_len, num_labels
|
| 88 |
+
positive_probs = sample[sample != 0].view(-1, 2)[:, 1] # valid_tokens, num_labels
|
| 89 |
+
non_zero_elements_list = positive_probs.cpu().tolist()
|
| 90 |
+
all_scores_res.append(non_zero_elements_list)
|
| 91 |
+
return all_scores_res
|
| 92 |
+
|
| 93 |
+
def decode_prm(self, input_ids: torch.Tensor, logits: torch.Tensor) -> Any:
|
| 94 |
+
step_sep_id = self.tokenizer.encode(self.cot_process_placeholder)[0]
|
| 95 |
+
token_masks = (input_ids == step_sep_id)
|
| 96 |
+
return self.make_step_rewards(logits, token_masks)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
register_template(Qwen2_5MathTemplateMeta(LLMTemplateType.qwen2_5_math_prm, template_cls=QwenPRMTemplate))
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class QwenVLTemplate(Template):
|
| 103 |
+
load_images = False
|
| 104 |
+
|
| 105 |
+
@staticmethod
|
| 106 |
+
def _load_image(image, load_images: bool):
|
| 107 |
+
if not load_images and isinstance(image, str) and (image.startswith('data:') or len(image) > 200):
|
| 108 |
+
load_images = True
|
| 109 |
+
return Template._load_image(image, load_images)
|
| 110 |
+
|
| 111 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 112 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 113 |
+
assert media_type == 'image'
|
| 114 |
+
if self.mode == 'lmdeploy':
|
| 115 |
+
return [f'Picture {index + 1}: ', [-100], '\n']
|
| 116 |
+
else:
|
| 117 |
+
image = inputs.images[index]
|
| 118 |
+
if self.mode == 'vllm':
|
| 119 |
+
return [f'Picture {index + 1}: <img></img>\n']
|
| 120 |
+
else:
|
| 121 |
+
assert isinstance(image, str)
|
| 122 |
+
return [f'Picture {index + 1}: <img>{image}</img>\n']
|
| 123 |
+
|
| 124 |
+
def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
|
| 125 |
+
return [f'<ref>{ref}</ref>']
|
| 126 |
+
|
| 127 |
+
def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
|
| 128 |
+
return [f'<box>{self._get_bbox_str(bbox)}</box>']
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
register_template(QwenTemplateMeta(MLLMTemplateType.qwen_vl, template_cls=QwenVLTemplate))
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class QwenAudioTemplate(Template):
|
| 135 |
+
|
| 136 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 137 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 138 |
+
assert media_type == 'audio'
|
| 139 |
+
audios = inputs.audios
|
| 140 |
+
audio = audios[index]
|
| 141 |
+
assert isinstance(audio, str)
|
| 142 |
+
return [f'Audio {index + 1}:<audio>{audio}</audio>\n']
|
| 143 |
+
|
| 144 |
+
def _tokenize(self, context, **tokenizer_kwargs):
|
| 145 |
+
audio_info = self.processor.process_audio(context)
|
| 146 |
+
return super()._tokenize(context, audio_info=audio_info)
|
| 147 |
+
|
| 148 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 149 |
+
encoded = super()._encode(inputs)
|
| 150 |
+
text = ''.join([f'<audio>{audio}</audio>' for audio in inputs.audios])
|
| 151 |
+
audio_info = self.processor.process_audio(text)
|
| 152 |
+
if audio_info:
|
| 153 |
+
tokenizer_kwargs = {'audio_info': audio_info}
|
| 154 |
+
encoded.update(tokenizer_kwargs)
|
| 155 |
+
encoded['tokenizer_kwargs'] = tokenizer_kwargs
|
| 156 |
+
return encoded
|
| 157 |
+
|
| 158 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 159 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 160 |
+
if batch[0].get('audio_info') is not None:
|
| 161 |
+
res['audio_info'] = [b['audio_info'] for b in batch]
|
| 162 |
+
return res
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
register_template(QwenTemplateMeta(MLLMTemplateType.qwen_audio, template_cls=QwenAudioTemplate))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class Qwen2AudioTemplate(Template):
|
| 169 |
+
|
| 170 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 171 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 172 |
+
assert media_type == 'audio'
|
| 173 |
+
if not self.use_chat_template:
|
| 174 |
+
return ['<|audio_bos|><|AUDIO|><|audio_eos|>\n']
|
| 175 |
+
else:
|
| 176 |
+
return [f'Audio {index + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n']
|
| 177 |
+
|
| 178 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 179 |
+
encoded = super()._encode(inputs)
|
| 180 |
+
if inputs.audios:
|
| 181 |
+
sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate)
|
| 182 |
+
audios = load_batch(inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate))
|
| 183 |
+
audio_inputs = self.processor.feature_extractor(
|
| 184 |
+
audios, sampling_rate=sampling_rate, return_attention_mask=True, return_tensors='pt')
|
| 185 |
+
audio_inputs['feature_attention_mask'] = audio_inputs.pop('attention_mask')
|
| 186 |
+
encoded.update(audio_inputs)
|
| 187 |
+
return encoded
|
| 188 |
+
|
| 189 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 190 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 191 |
+
input_features = [b['input_features'] for b in batch if b.get('input_features') is not None]
|
| 192 |
+
feature_attention_mask = [
|
| 193 |
+
b['feature_attention_mask'] for b in batch if b.get('feature_attention_mask') is not None
|
| 194 |
+
]
|
| 195 |
+
if input_features:
|
| 196 |
+
res['input_features'] = torch.concat(input_features)
|
| 197 |
+
res['feature_attention_mask'] = torch.concat(feature_attention_mask)
|
| 198 |
+
return res
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_audio, template_cls=Qwen2AudioTemplate))
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class Qwen2VLTemplate(Template):
|
| 205 |
+
image_token_id = 151655
|
| 206 |
+
video_token_id = 151656
|
| 207 |
+
placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']
|
| 208 |
+
version = 'v2'
|
| 209 |
+
use_model = True
|
| 210 |
+
|
| 211 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 212 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 213 |
+
from qwen_vl_utils import fetch_image, fetch_video
|
| 214 |
+
assert media_type in {'image', 'video'}
|
| 215 |
+
if media_type == 'image':
|
| 216 |
+
inputs.images[index] = fetch_image({'image': inputs.images[index]})
|
| 217 |
+
if self.mode == 'lmdeploy':
|
| 218 |
+
return ['<|vision_start|>', [-100], '<|vision_end|>']
|
| 219 |
+
else:
|
| 220 |
+
return ['<|vision_start|><|image_pad|><|vision_end|>']
|
| 221 |
+
else:
|
| 222 |
+
inputs.videos[index] = fetch_video({'video': inputs.videos[index]}).to(torch.uint8)
|
| 223 |
+
return ['<|vision_start|><|video_pad|><|vision_end|>']
|
| 224 |
+
|
| 225 |
+
def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
|
| 226 |
+
return [f'<|object_ref_start|>{ref}<|object_ref_end|>']
|
| 227 |
+
|
| 228 |
+
def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
|
| 229 |
+
return [f'<|box_start|>{self._get_bbox_str(bbox)}<|box_end|>']
|
| 230 |
+
|
| 231 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 232 |
+
encoded = super()._encode(inputs)
|
| 233 |
+
processor = self.processor
|
| 234 |
+
input_ids = encoded['input_ids']
|
| 235 |
+
labels = encoded['labels']
|
| 236 |
+
images = inputs.images
|
| 237 |
+
videos = inputs.videos
|
| 238 |
+
for media_type in ['images', 'videos']:
|
| 239 |
+
if locals()[media_type]:
|
| 240 |
+
if media_type == 'images':
|
| 241 |
+
media_token = self.image_token_id
|
| 242 |
+
media_inputs = processor.image_processor(
|
| 243 |
+
images=images, videos=None, return_tensors='pt', do_resize=False)
|
| 244 |
+
media_grid_thw = media_inputs['image_grid_thw']
|
| 245 |
+
else:
|
| 246 |
+
media_inputs = processor.image_processor(
|
| 247 |
+
images=None, videos=videos, return_tensors='pt', do_resize=False)
|
| 248 |
+
media_grid_thw = media_inputs['video_grid_thw']
|
| 249 |
+
media_token = self.video_token_id
|
| 250 |
+
if self.version == 'v2_5':
|
| 251 |
+
from qwen_vl_utils import vision_process
|
| 252 |
+
media_inputs['second_per_grid_ts'] = [
|
| 253 |
+
processor.image_processor.temporal_patch_size / vision_process.FPS
|
| 254 |
+
] * len(media_grid_thw)
|
| 255 |
+
idx_list = findall(input_ids, media_token)
|
| 256 |
+
merge_length = processor.image_processor.merge_size**2
|
| 257 |
+
|
| 258 |
+
def _get_new_tokens(i):
|
| 259 |
+
token_len = (media_grid_thw[i].prod() // merge_length)
|
| 260 |
+
return [media_token] * token_len
|
| 261 |
+
|
| 262 |
+
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
|
| 263 |
+
encoded.update(media_inputs)
|
| 264 |
+
|
| 265 |
+
encoded['input_ids'] = input_ids
|
| 266 |
+
encoded['labels'] = labels
|
| 267 |
+
return encoded
|
| 268 |
+
|
| 269 |
+
def compute_loss_context(self, model, inputs):
|
| 270 |
+
if 'real_position_ids' not in inputs:
|
| 271 |
+
return super().compute_loss_context(model, inputs)
|
| 272 |
+
if self.version == 'v2':
|
| 273 |
+
from transformers.models.qwen2_vl import modeling_qwen2_vl as modeling_module
|
| 274 |
+
elif self.version == 'v2_5':
|
| 275 |
+
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as modeling_module
|
| 276 |
+
elif self.version == 'omni':
|
| 277 |
+
from transformers.models.qwen2_5_omni import modeling_qwen2_5_omni as modeling_module
|
| 278 |
+
position_ids = inputs['position_ids']
|
| 279 |
+
inputs['position_ids'] = inputs.pop('real_position_ids')
|
| 280 |
+
return self._patch_flash_attention_forward(modeling_module, position_ids)
|
| 281 |
+
|
| 282 |
+
def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
| 283 |
+
if not self.is_training:
|
| 284 |
+
return inputs
|
| 285 |
+
input_ids = inputs['input_ids']
|
| 286 |
+
_model = model.model
|
| 287 |
+
if not hasattr(_model, 'embed_tokens'):
|
| 288 |
+
_model = _model.model # LoRA
|
| 289 |
+
pixel_values = inputs.get('pixel_values')
|
| 290 |
+
pixel_values_videos = inputs.get('pixel_values_videos')
|
| 291 |
+
image_grid_thw = inputs.get('image_grid_thw')
|
| 292 |
+
video_grid_thw = inputs.get('video_grid_thw')
|
| 293 |
+
|
| 294 |
+
inputs_embeds = _model.embed_tokens(input_ids)
|
| 295 |
+
|
| 296 |
+
dtype = model.visual.get_dtype() if self.version == 'v2' else model.visual.dtype
|
| 297 |
+
if pixel_values is None and pixel_values_videos is None: # plain-text
|
| 298 |
+
if is_deepspeed_enabled():
|
| 299 |
+
from PIL import Image
|
| 300 |
+
images = [Image.new('RGB', (32, 32), (0, 0, 0))]
|
| 301 |
+
media_inputs = self.processor.image_processor(images=images, videos=None, return_tensors='pt')
|
| 302 |
+
device = input_ids.device
|
| 303 |
+
media_inputs = to_device(media_inputs, device)
|
| 304 |
+
pixel_values = media_inputs['pixel_values'].type(dtype)
|
| 305 |
+
image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
|
| 306 |
+
inputs_embeds += image_embeds.mean() * 0.
|
| 307 |
+
else:
|
| 308 |
+
if pixel_values is not None:
|
| 309 |
+
pixel_values = pixel_values.type(dtype)
|
| 310 |
+
image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
|
| 311 |
+
image_mask = (input_ids == model.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
| 312 |
+
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 313 |
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
| 314 |
+
|
| 315 |
+
if pixel_values_videos is not None:
|
| 316 |
+
pixel_values_videos = pixel_values_videos.type(dtype)
|
| 317 |
+
video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
| 318 |
+
video_mask = (input_ids == model.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
|
| 319 |
+
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 320 |
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
| 321 |
+
|
| 322 |
+
return {'inputs_embeds': inputs_embeds}
|
| 323 |
+
|
| 324 |
+
def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 325 |
+
res = super()._data_collator_mm_data(batch)
|
| 326 |
+
second_per_grid_ts = self.gather_list(batch, 'second_per_grid_ts')
|
| 327 |
+
if second_per_grid_ts:
|
| 328 |
+
res['second_per_grid_ts'] = second_per_grid_ts
|
| 329 |
+
for media_type in ['image', 'video']:
|
| 330 |
+
grid_thw = self.concat_tensor(batch, f'{media_type}_grid_thw', 0)
|
| 331 |
+
if grid_thw is not None:
|
| 332 |
+
res[f'{media_type}_grid_thw'] = grid_thw
|
| 333 |
+
return res
|
| 334 |
+
|
| 335 |
+
def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
|
| 336 |
+
position_ids = []
|
| 337 |
+
for r in row:
|
| 338 |
+
r = r[0].copy()
|
| 339 |
+
r['input_ids'] = torch.tensor(r['input_ids'])[None]
|
| 340 |
+
position_ids.append(self._get_position_ids(r))
|
| 341 |
+
packed = super().packing_row(row)
|
| 342 |
+
packed['real_position_ids'] = torch.concat(position_ids, dim=-1)
|
| 343 |
+
return packed
|
| 344 |
+
|
| 345 |
+
def _get_position_ids(self, inputs: Dict[str, Any]):
|
| 346 |
+
# fix https://github.com/huggingface/transformers/pull/33487
|
| 347 |
+
kwargs = {}
|
| 348 |
+
if self.version == 'v2_5':
|
| 349 |
+
kwargs = {'second_per_grid_ts': inputs.get('second_per_grid_ts')}
|
| 350 |
+
position_ids, _ = self.model.get_rope_index(
|
| 351 |
+
inputs['input_ids'],
|
| 352 |
+
inputs.get('image_grid_thw'),
|
| 353 |
+
inputs.get('video_grid_thw'),
|
| 354 |
+
attention_mask=inputs.get('attention_mask'),
|
| 355 |
+
**kwargs)
|
| 356 |
+
return position_ids.contiguous()
|
| 357 |
+
|
| 358 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 359 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 360 |
+
if self._packing:
|
| 361 |
+
res['real_position_ids'] = self.concat_tensor(batch, 'real_position_ids', -1)
|
| 362 |
+
elif self.is_training:
|
| 363 |
+
res['position_ids'] = self._get_position_ids(res)
|
| 364 |
+
return res
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_vl, template_cls=Qwen2VLTemplate))
|
| 368 |
+
|
| 369 |
+
register_template(
|
| 370 |
+
QwenTemplateMeta(
|
| 371 |
+
MLLMTemplateType.qvq,
|
| 372 |
+
default_system=('You are a helpful and harmless assistant. You are Qwen developed by Alibaba. '
|
| 373 |
+
'Answer in the language of the question. You should think step-by-step.'),
|
| 374 |
+
template_cls=Qwen2VLTemplate,
|
| 375 |
+
))
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class Qwen2_5VLTemplate(Qwen2VLTemplate):
|
| 379 |
+
version = 'v2_5'
|
| 380 |
+
norm_bbox = 'none'
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_5_vl, template_cls=Qwen2_5VLTemplate))
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class Qwen2_5OmniTemplate(Qwen2_5VLTemplate):
|
| 387 |
+
version = 'omni'
|
| 388 |
+
placeholder_tokens = ['<|IMAGE|>', '<|AUDIO|>', '<|VIDEO|>']
|
| 389 |
+
|
| 390 |
+
def __init__(self, *args, **kwargs):
|
| 391 |
+
super().__init__(*args, **kwargs)
|
| 392 |
+
from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import Qwen2_5OmniProcessorKwargs
|
| 393 |
+
default = Qwen2_5OmniProcessorKwargs._defaults
|
| 394 |
+
self.seconds_per_chunk = default['videos_kwargs']['seconds_per_chunk']
|
| 395 |
+
self.position_id_per_seconds = default['videos_kwargs']['position_id_per_seconds']
|
| 396 |
+
self.use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
|
| 397 |
+
self.sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate)
|
| 398 |
+
|
| 399 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 400 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 401 |
+
from qwen_omni_utils import fetch_image, fetch_video
|
| 402 |
+
if media_type == 'image':
|
| 403 |
+
inputs.images[index] = fetch_image({'image': inputs.images[index]})
|
| 404 |
+
return ['<|vision_bos|><|IMAGE|><|vision_eos|>']
|
| 405 |
+
elif media_type == 'audio':
|
| 406 |
+
inputs.audios[index] = load_audio(inputs.audios[index], self.sampling_rate)
|
| 407 |
+
return ['<|audio_bos|><|AUDIO|><|audio_eos|>']
|
| 408 |
+
elif media_type == 'video':
|
| 409 |
+
video = inputs.videos[index]
|
| 410 |
+
inputs.videos[index] = fetch_video({'video': video}).to(torch.uint8)
|
| 411 |
+
if self.use_audio_in_video:
|
| 412 |
+
import librosa
|
| 413 |
+
if video.startswith('http://') or video.startswith('https://'):
|
| 414 |
+
import audioread
|
| 415 |
+
video = audioread.ffdec.FFmpegAudioFile(video)
|
| 416 |
+
video = librosa.load(video, sr=self.sampling_rate)[0]
|
| 417 |
+
inputs.audios.insert(inputs.audio_idx, (video, 'video'))
|
| 418 |
+
inputs.audio_idx += 1
|
| 419 |
+
return ['<|vision_bos|><|audio_bos|><|VIDEO|><|audio_eos|><|vision_eos|>']
|
| 420 |
+
return ['<|vision_bos|><|VIDEO|><|vision_eos|>']
|
| 421 |
+
|
| 422 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 423 |
+
encoded = Template._encode(self, inputs)
|
| 424 |
+
processor = self.processor
|
| 425 |
+
video_audios_mask = []
|
| 426 |
+
for i, audio in enumerate(inputs.audios):
|
| 427 |
+
if isinstance(audio, tuple) and audio[1] == 'video':
|
| 428 |
+
inputs.audios[i] = audio[0]
|
| 429 |
+
video_audios_mask.append(True)
|
| 430 |
+
else:
|
| 431 |
+
video_audios_mask.append(False)
|
| 432 |
+
video_audios_mask = torch.tensor(video_audios_mask)
|
| 433 |
+
media_inputs = processor(
|
| 434 |
+
text='',
|
| 435 |
+
audio=inputs.audios or None,
|
| 436 |
+
images=inputs.images or None,
|
| 437 |
+
videos=inputs.videos or None,
|
| 438 |
+
return_tensors='pt')
|
| 439 |
+
media_inputs.pop('input_ids')
|
| 440 |
+
media_inputs.pop('attention_mask')
|
| 441 |
+
media_inputs = to_float_dtype(media_inputs, self.model_info.torch_dtype)
|
| 442 |
+
input_ids = encoded['input_ids']
|
| 443 |
+
labels = encoded['labels']
|
| 444 |
+
# audio
|
| 445 |
+
audio_token_id = self._tokenize('<|AUDIO|>')
|
| 446 |
+
idx_list = findall(input_ids, audio_token_id)
|
| 447 |
+
feature_attention_mask = media_inputs.get('feature_attention_mask')
|
| 448 |
+
if feature_attention_mask is not None:
|
| 449 |
+
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
| 450 |
+
audio_lengths = (((audio_feature_lengths - 1) // 2 + 1 - 2) // 2 + 1)
|
| 451 |
+
else:
|
| 452 |
+
audio_lengths = None
|
| 453 |
+
audio_lengths_origin = audio_lengths
|
| 454 |
+
if idx_list:
|
| 455 |
+
if self.use_audio_in_video:
|
| 456 |
+
audio_lengths = audio_lengths[~video_audios_mask]
|
| 457 |
+
|
| 458 |
+
def _get_new_audio_tokens(i):
|
| 459 |
+
return audio_token_id * audio_lengths[i]
|
| 460 |
+
|
| 461 |
+
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_audio_tokens)
|
| 462 |
+
|
| 463 |
+
for media_type in ['image', 'video']:
|
| 464 |
+
token = f'<|{media_type.upper()}|>'
|
| 465 |
+
token_id = self._tokenize(token)
|
| 466 |
+
idx_list = findall(input_ids, token_id)
|
| 467 |
+
if idx_list:
|
| 468 |
+
merge_size = processor.image_processor.merge_size
|
| 469 |
+
media_grid_thw = media_inputs.get(f'{media_type}_grid_thw')
|
| 470 |
+
if media_type == 'video' and self.use_audio_in_video:
|
| 471 |
+
audio_lengths = audio_lengths_origin[video_audios_mask]
|
| 472 |
+
video_second_per_grid = media_inputs['video_second_per_grid']
|
| 473 |
+
|
| 474 |
+
def _get_new_tokens_use_audio_in_video(i):
|
| 475 |
+
audio_token_indices = torch.arange(audio_lengths[i])
|
| 476 |
+
grid_thw = media_grid_thw[i]
|
| 477 |
+
height = grid_thw[1] // merge_size
|
| 478 |
+
width = grid_thw[2] // merge_size
|
| 479 |
+
video_token_indices = torch.arange(grid_thw[0]).reshape(-1, 1, 1)
|
| 480 |
+
video_token_indices = torch.broadcast_to(
|
| 481 |
+
video_token_indices, (video_token_indices.shape[0], height, width)).reshape(-1)
|
| 482 |
+
video_token_indices = (
|
| 483 |
+
video_token_indices * video_second_per_grid[i] * self.position_id_per_seconds)
|
| 484 |
+
tokens_per_chunk = int(self.position_id_per_seconds * self.seconds_per_chunk)
|
| 485 |
+
video_chunk_indexes = processor.get_chunked_index(video_token_indices, tokens_per_chunk)
|
| 486 |
+
audio_chunk_indexes = processor.get_chunked_index(audio_token_indices, tokens_per_chunk)
|
| 487 |
+
|
| 488 |
+
res = []
|
| 489 |
+
for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
|
| 490 |
+
if j < len(video_chunk_indexes):
|
| 491 |
+
video_seq_length = video_chunk_indexes[j][1] - video_chunk_indexes[j][0]
|
| 492 |
+
res += token_id * video_seq_length
|
| 493 |
+
if j < len(audio_chunk_indexes):
|
| 494 |
+
audio_seq_length = audio_chunk_indexes[j][1] - audio_chunk_indexes[j][0]
|
| 495 |
+
res += audio_token_id * audio_seq_length
|
| 496 |
+
return res
|
| 497 |
+
|
| 498 |
+
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list,
|
| 499 |
+
_get_new_tokens_use_audio_in_video)
|
| 500 |
+
|
| 501 |
+
else:
|
| 502 |
+
|
| 503 |
+
def _get_new_tokens(i):
|
| 504 |
+
token_len = (media_grid_thw[i].prod() // (merge_size**2))
|
| 505 |
+
return token_id * token_len
|
| 506 |
+
|
| 507 |
+
input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
|
| 508 |
+
|
| 509 |
+
encoded['input_ids'] = input_ids
|
| 510 |
+
encoded['labels'] = labels
|
| 511 |
+
encoded.update(media_inputs)
|
| 512 |
+
return encoded
|
| 513 |
+
|
| 514 |
+
def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
| 515 |
+
return Template._post_encode(self, model, inputs)
|
| 516 |
+
|
| 517 |
+
def _get_position_ids(self, inputs: Dict[str, Any]):
|
| 518 |
+
feature_attention_mask = inputs.get('feature_attention_mask')
|
| 519 |
+
if feature_attention_mask is not None:
|
| 520 |
+
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
| 521 |
+
else:
|
| 522 |
+
audio_feature_lengths = None
|
| 523 |
+
video_second_per_grid = inputs.pop('video_second_per_grid', None)
|
| 524 |
+
input_ids = inputs['input_ids']
|
| 525 |
+
attention_mask = inputs.get('attention_mask')
|
| 526 |
+
if attention_mask is None:
|
| 527 |
+
attention_mask = torch.ones_like(input_ids)
|
| 528 |
+
position_ids, _ = self.model.thinker.get_rope_index(
|
| 529 |
+
input_ids,
|
| 530 |
+
inputs.get('image_grid_thw'),
|
| 531 |
+
inputs.get('video_grid_thw'),
|
| 532 |
+
attention_mask,
|
| 533 |
+
self.use_audio_in_video,
|
| 534 |
+
audio_feature_lengths,
|
| 535 |
+
video_second_per_grid,
|
| 536 |
+
)
|
| 537 |
+
return position_ids.contiguous()
|
| 538 |
+
|
| 539 |
+
def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 540 |
+
res = super()._data_collator_mm_data(batch)
|
| 541 |
+
video_second_per_grid = self.gather_list(batch, 'video_second_per_grid')
|
| 542 |
+
if video_second_per_grid:
|
| 543 |
+
res['video_second_per_grid'] = video_second_per_grid
|
| 544 |
+
input_features = [b['input_features'] for b in batch if b.get('input_features') is not None]
|
| 545 |
+
feature_attention_mask = [
|
| 546 |
+
b['feature_attention_mask'] for b in batch if b.get('feature_attention_mask') is not None
|
| 547 |
+
]
|
| 548 |
+
if input_features:
|
| 549 |
+
res['input_features'] = torch.concat(input_features)
|
| 550 |
+
res['feature_attention_mask'] = torch.concat(feature_attention_mask)
|
| 551 |
+
return res
|
| 552 |
+
|
| 553 |
+
def generate(self, model, *args, **kwargs):
|
| 554 |
+
if kwargs.get('video_grid_thw') is not None:
|
| 555 |
+
kwargs['use_audio_in_video'] = self.use_audio_in_video
|
| 556 |
+
return super().generate(model, *args, **kwargs)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_5_omni, template_cls=Qwen2_5OmniTemplate))
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
class Ovis1_6Template(Template):
|
| 563 |
+
skip_prompt = False
|
| 564 |
+
use_model = True
|
| 565 |
+
|
| 566 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 567 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 568 |
+
assert media_type == 'image'
|
| 569 |
+
return [[-200], '\n']
|
| 570 |
+
|
| 571 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 572 |
+
encoded = super()._encode(inputs)
|
| 573 |
+
images = inputs.images
|
| 574 |
+
input_ids = encoded['input_ids']
|
| 575 |
+
labels = encoded['labels']
|
| 576 |
+
idx_list = findall(input_ids, [-200])
|
| 577 |
+
added_tokens_len = 0
|
| 578 |
+
pixel_values = []
|
| 579 |
+
for i, idx in enumerate(idx_list):
|
| 580 |
+
max_partition = get_env_args('max_partition', int, 9)
|
| 581 |
+
raw_pixel_values, image_placeholders = self.model.visual_tokenizer.preprocess_image(
|
| 582 |
+
images[i], max_partition=max_partition)
|
| 583 |
+
input_ids = input_ids[:idx] + image_placeholders + input_ids[idx + 1:]
|
| 584 |
+
if labels is not None:
|
| 585 |
+
labels = labels[:idx] + [-100] * len(image_placeholders) + labels[idx + 1:]
|
| 586 |
+
pixel_values.append(raw_pixel_values)
|
| 587 |
+
added_tokens_len += len(image_placeholders) - 1
|
| 588 |
+
dtype = self.model.visual_tokenizer.dtype
|
| 589 |
+
if pixel_values:
|
| 590 |
+
pixel_values = torch.cat(pixel_values, dim=0).to(dtype)
|
| 591 |
+
else:
|
| 592 |
+
pixel_values = torch.zeros((1, 3, 384, 384), dtype=dtype) # dummpy
|
| 593 |
+
encoded.update({'input_ids': input_ids, 'labels': labels})
|
| 594 |
+
encoded['pixel_values'] = [pixel_values]
|
| 595 |
+
return encoded
|
| 596 |
+
|
| 597 |
+
def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
| 598 |
+
padding_side = self.padding_side if self.is_training else 'left'
|
| 599 |
+
if self.max_length is not None:
|
| 600 |
+
model.config.multimodal_max_length = self.max_length
|
| 601 |
+
input_ids = inputs['input_ids']
|
| 602 |
+
labels = inputs.get('labels')
|
| 603 |
+
if labels is None:
|
| 604 |
+
labels = input_ids.new_full(input_ids.shape, -100)
|
| 605 |
+
_, inputs_embeds, labels, attention_mask = model.merge_multimodal(
|
| 606 |
+
text_input_ids=input_ids,
|
| 607 |
+
text_attention_masks=torch.ones_like(input_ids), # not use, only compat
|
| 608 |
+
text_labels=labels,
|
| 609 |
+
pixel_values=inputs['pixel_values'],
|
| 610 |
+
left_padding=padding_side == 'left')
|
| 611 |
+
if inputs.get('labels') is None:
|
| 612 |
+
labels = None
|
| 613 |
+
return {'inputs_embeds': inputs_embeds, 'labels': labels, 'attention_mask': attention_mask}
|
| 614 |
+
|
| 615 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 616 |
+
pixel_values = self.gather_list(batch, 'pixel_values')
|
| 617 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 618 |
+
res['pixel_values'] = pixel_values
|
| 619 |
+
return res
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
register_template(
|
| 623 |
+
TemplateMeta(
|
| 624 |
+
MLLMTemplateType.ovis1_6,
|
| 625 |
+
prefix=['<bos>'],
|
| 626 |
+
prompt=['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'],
|
| 627 |
+
chat_sep=['<end_of_turn>\n'],
|
| 628 |
+
suffix=['<end_of_turn>'],
|
| 629 |
+
system_prefix=['<bos><start_of_turn>system\n{{SYSTEM}}<end_of_turn>\n'],
|
| 630 |
+
template_cls=Ovis1_6Template,
|
| 631 |
+
))
|
| 632 |
+
|
| 633 |
+
register_template(
|
| 634 |
+
Llama3TemplateMeta(
|
| 635 |
+
MLLMTemplateType.ovis1_6_llama3,
|
| 636 |
+
default_system='You are a helpful and honest multimodal assistant.',
|
| 637 |
+
template_cls=Ovis1_6Template,
|
| 638 |
+
))
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
class Ovis2Template(Ovis1_6Template):
|
| 642 |
+
placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']
|
| 643 |
+
nframes = 12
|
| 644 |
+
|
| 645 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 646 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 647 |
+
if media_type == 'image':
|
| 648 |
+
return [[-200], '\n']
|
| 649 |
+
elif media_type == 'video':
|
| 650 |
+
nframes = get_env_args('nframes', int, self.nframes)
|
| 651 |
+
inputs.images = load_video_ovis2(inputs.videos[index], nframes)
|
| 652 |
+
return [[-200] * nframes, '\n']
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
register_template(QwenTemplateMeta(
|
| 656 |
+
MLLMTemplateType.ovis2,
|
| 657 |
+
template_cls=Ovis2Template,
|
| 658 |
+
))
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
@dataclass
|
| 662 |
+
class MarcoO1TemplateMeta(QwenTemplateMeta):
|
| 663 |
+
default_system: Optional[str] = """
|
| 664 |
+
你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.
|
| 665 |
+
\n## 重要!!!!!
|
| 666 |
+
当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。
|
| 667 |
+
<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。
|
| 668 |
+
"""
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
register_template(MarcoO1TemplateMeta(LLMTemplateType.marco_o1))
|
ms-swift/swift/llm/template/template/stepfun.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 3 |
+
|
| 4 |
+
from ..base import Template
|
| 5 |
+
from ..constant import MLLMTemplateType
|
| 6 |
+
from ..register import TemplateMeta, register_template
|
| 7 |
+
from ..template_inputs import StdTemplateInputs
|
| 8 |
+
from ..utils import Context
|
| 9 |
+
from ..vision_utils import load_file
|
| 10 |
+
from .qwen import QwenTemplateMeta
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class GOTImageEvalProcessor:
|
| 14 |
+
|
| 15 |
+
def __init__(self, image_size=384, mean=None, std=None):
|
| 16 |
+
from torchvision import transforms
|
| 17 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 18 |
+
if mean is None:
|
| 19 |
+
mean = (0.48145466, 0.4578275, 0.40821073)
|
| 20 |
+
if std is None:
|
| 21 |
+
std = (0.26862954, 0.26130258, 0.27577711)
|
| 22 |
+
|
| 23 |
+
self.normalize = transforms.Normalize(mean, std)
|
| 24 |
+
|
| 25 |
+
self.transform = transforms.Compose([
|
| 26 |
+
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
|
| 27 |
+
transforms.ToTensor(),
|
| 28 |
+
self.normalize,
|
| 29 |
+
])
|
| 30 |
+
|
| 31 |
+
def __call__(self, item):
|
| 32 |
+
return self.transform(item)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class GOT_OCR2Template(Template):
|
| 36 |
+
placeholder_tokens = ['<imgpad>']
|
| 37 |
+
|
| 38 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 39 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 40 |
+
# 'OCR: '
|
| 41 |
+
# 'OCR with format: '
|
| 42 |
+
assert media_type == 'image'
|
| 43 |
+
return ['<img>' + '<imgpad>' * 256 + '</img>\n']
|
| 44 |
+
|
| 45 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 46 |
+
encoded = super()._encode(inputs)
|
| 47 |
+
images = inputs.images
|
| 48 |
+
image_processor_high = GOTImageEvalProcessor(image_size=1024)
|
| 49 |
+
for i, image in enumerate(images):
|
| 50 |
+
images[i] = image_processor_high(image)[None].to(self.model_info.torch_dtype)
|
| 51 |
+
if images:
|
| 52 |
+
encoded['images'] = images
|
| 53 |
+
return encoded
|
| 54 |
+
|
| 55 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 56 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 57 |
+
images = self.gather_list(batch, 'images')
|
| 58 |
+
if images:
|
| 59 |
+
res['images'] = images
|
| 60 |
+
return res
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
register_template(
|
| 64 |
+
QwenTemplateMeta(
|
| 65 |
+
MLLMTemplateType.got_ocr2,
|
| 66 |
+
default_system=' You should follow the instructions carefully and explain your answers in detail.',
|
| 67 |
+
template_cls=GOT_OCR2Template,
|
| 68 |
+
))
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class GOT_OCR2HfTemplate(Template):
|
| 72 |
+
placeholder_tokens = ['<imgpad>']
|
| 73 |
+
|
| 74 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 75 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 76 |
+
# 'OCR: '
|
| 77 |
+
# 'OCR with format: '
|
| 78 |
+
assert media_type == 'image'
|
| 79 |
+
return ['<img>' + '<imgpad>' * 256 + '</img>\n']
|
| 80 |
+
|
| 81 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: # 暂时照抄上面
|
| 82 |
+
encoded = super()._encode(inputs)
|
| 83 |
+
images = inputs.images
|
| 84 |
+
if images:
|
| 85 |
+
encoded['images'] = images
|
| 86 |
+
return encoded
|
| 87 |
+
|
| 88 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 89 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 90 |
+
images = self.gather_list(batch, 'images')
|
| 91 |
+
_inputs = self.processor(images, return_tensors='pt')
|
| 92 |
+
_inputs.pop('input_ids') # this does not contain the response, so cannot be used when training
|
| 93 |
+
_inputs.pop('attention_mask') # this does not contain the response, so cannot be used when training
|
| 94 |
+
|
| 95 |
+
res.update(_inputs.data)
|
| 96 |
+
return res
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
register_template(
|
| 100 |
+
QwenTemplateMeta(
|
| 101 |
+
MLLMTemplateType.got_ocr2_hf,
|
| 102 |
+
default_system=' You should follow the instructions carefully and explain your answers in detail.',
|
| 103 |
+
template_cls=GOT_OCR2HfTemplate,
|
| 104 |
+
))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class StepAudioTemplate(Template):
|
| 108 |
+
use_model = True
|
| 109 |
+
|
| 110 |
+
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
|
| 111 |
+
inputs: StdTemplateInputs) -> List[Context]:
|
| 112 |
+
assert media_type == 'audio', f'media_type: {media_type}'
|
| 113 |
+
from utils import load_audio
|
| 114 |
+
audio_wav, sr = load_audio(load_file(inputs.audios[index]))
|
| 115 |
+
audio_tokens = self.model.encoder(audio_wav, sr)
|
| 116 |
+
return audio_tokens
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
register_template(
|
| 120 |
+
TemplateMeta(
|
| 121 |
+
MLLMTemplateType.step_audio,
|
| 122 |
+
template_cls=StepAudioTemplate,
|
| 123 |
+
prefix=['<s>'],
|
| 124 |
+
prompt=['<|BOT|>human\n{{QUERY}}<|EOT|><|BOT|>assistant\n'],
|
| 125 |
+
system_prefix=['<s><|BOT|>system\n{{SYSTEM}}<|EOT|>'],
|
| 126 |
+
chat_sep=['<|EOT|>'],
|
| 127 |
+
suffix=['<|EOT|>'],
|
| 128 |
+
))
|
ms-swift/swift/llm/template/template/yi.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Any, Dict, List, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ..base import Template
|
| 7 |
+
from ..constant import LLMTemplateType, MLLMTemplateType
|
| 8 |
+
from ..register import TemplateMeta, register_template
|
| 9 |
+
from ..template_inputs import StdTemplateInputs
|
| 10 |
+
from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta
|
| 11 |
+
|
| 12 |
+
register_template(ChatmlTemplateMeta(
|
| 13 |
+
LLMTemplateType.yi_coder,
|
| 14 |
+
default_system=DEFAULT_SYSTEM,
|
| 15 |
+
))
|
| 16 |
+
|
| 17 |
+
yi_vl_default_system = (
|
| 18 |
+
'This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. '
|
| 19 |
+
"Read all the images carefully, and respond to the human's questions with informative, "
|
| 20 |
+
'helpful, detailed and polite answers. '
|
| 21 |
+
'这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。'
|
| 22 |
+
'仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。')
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class YiVLTemplate(Template):
|
| 26 |
+
image_placeholder = [[-200], '\n']
|
| 27 |
+
use_model = True
|
| 28 |
+
|
| 29 |
+
def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
|
| 30 |
+
encoded = super()._encode(inputs)
|
| 31 |
+
model = self.model
|
| 32 |
+
from llava.mm_utils import expand2square
|
| 33 |
+
if not hasattr(model, 'vision_tower'):
|
| 34 |
+
model = model.model
|
| 35 |
+
image_processor = model.vision_tower.image_processor
|
| 36 |
+
images = inputs.images or []
|
| 37 |
+
for i, image in enumerate(images):
|
| 38 |
+
background_color = tuple(int(x * 255) for x in image_processor.image_mean)
|
| 39 |
+
image = expand2square(image, background_color)
|
| 40 |
+
images[i] = image
|
| 41 |
+
if images:
|
| 42 |
+
image_tensor = image_processor.preprocess(images, return_tensors='pt')['pixel_values']
|
| 43 |
+
encoded['images'] = image_tensor.to(model.dtype)
|
| 44 |
+
return encoded
|
| 45 |
+
|
| 46 |
+
def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
|
| 47 |
+
res = super()._data_collator(batch, padding_to=padding_to)
|
| 48 |
+
images = [b['images'] for b in batch if 'images' in b]
|
| 49 |
+
if images:
|
| 50 |
+
res['images'] = torch.concat(images)
|
| 51 |
+
return res
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
register_template(
|
| 55 |
+
TemplateMeta(
|
| 56 |
+
MLLMTemplateType.yi_vl,
|
| 57 |
+
prefix=[],
|
| 58 |
+
prompt=[[8308], ' Human: {{QUERY}}\n', [8308], ' Assistant:'],
|
| 59 |
+
chat_sep=['\n'],
|
| 60 |
+
suffix=['\n', [8308]],
|
| 61 |
+
default_system=yi_vl_default_system,
|
| 62 |
+
template_cls=YiVLTemplate,
|
| 63 |
+
system_prefix=['{{SYSTEM}}\n\n']))
|
ms-swift/swift/llm/train/__pycache__/callback.cpython-310.pyc
ADDED
|
Binary file (3.11 kB). View file
|
|
|
ms-swift/swift/llm/train/__pycache__/rlhf.cpython-310.pyc
ADDED
|
Binary file (4.6 kB). View file
|
|
|
ms-swift/swift/llm/train/__pycache__/sft.cpython-310.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
ms-swift/swift/llm/train/__pycache__/tuner.cpython-310.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
ms-swift/swift/llm/train/callback.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import types
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import TrainerCallback
|
| 7 |
+
|
| 8 |
+
from swift.utils import get_logger
|
| 9 |
+
|
| 10 |
+
logger = get_logger()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TrainerAdapterCallback(TrainerCallback):
|
| 14 |
+
|
| 15 |
+
def __init__(self, args):
|
| 16 |
+
self.global_step = 0
|
| 17 |
+
self.args = args
|
| 18 |
+
|
| 19 |
+
# offload original_modules to cpu, to save memory
|
| 20 |
+
def on_train_begin(self, _args, state, control, **kwargs):
|
| 21 |
+
model = kwargs['model']
|
| 22 |
+
if self.args.train_type == 'adalora':
|
| 23 |
+
model.peft_config['default'].total_step = state.max_steps
|
| 24 |
+
|
| 25 |
+
def zero_grad(_self, *args, **kwargs):
|
| 26 |
+
_self.update_and_allocate(self.global_step + 1)
|
| 27 |
+
_self._zero_grad(*args, **kwargs)
|
| 28 |
+
|
| 29 |
+
model._zero_grad = model.zero_grad
|
| 30 |
+
model.zero_grad = types.MethodType(zero_grad, model)
|
| 31 |
+
|
| 32 |
+
def on_step_end(self, _args, state, control, **kwargs):
|
| 33 |
+
if self.args.train_type == 'adalora':
|
| 34 |
+
self.global_step = state.global_step
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class DynamicLayerActivationCallback(TrainerCallback):
|
| 38 |
+
|
| 39 |
+
def __init__(self, n_layers: int, step_interval: int, model: torch.nn.Module):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.n_layers = n_layers
|
| 42 |
+
self.step_interval = step_interval
|
| 43 |
+
self.model = model
|
| 44 |
+
layers_name = None
|
| 45 |
+
layers = None
|
| 46 |
+
for name, module in model.named_modules():
|
| 47 |
+
if isinstance(module, torch.nn.ModuleList):
|
| 48 |
+
layers_name = name
|
| 49 |
+
layers = module
|
| 50 |
+
break
|
| 51 |
+
assert layers_name is not None
|
| 52 |
+
self.layers_attribute = layers_name
|
| 53 |
+
self.total_layers = len(layers)
|
| 54 |
+
|
| 55 |
+
# Freeze all layers upon initialization
|
| 56 |
+
self.freeze_all_layers()
|
| 57 |
+
self.active_layers_indices = []
|
| 58 |
+
|
| 59 |
+
def freeze_all_layers(self):
|
| 60 |
+
layers = self.model.get_submodule(self.layers_attribute)
|
| 61 |
+
for layer in layers:
|
| 62 |
+
for param in layer.parameters():
|
| 63 |
+
param.requires_grad = False
|
| 64 |
+
|
| 65 |
+
def on_step_begin(self, args, state, control, **kwargs):
|
| 66 |
+
# Check if it's time to switch active layers, including at step 0
|
| 67 |
+
if state.global_step % self.step_interval == 0 or state.global_step == 1:
|
| 68 |
+
self.switch_active_layers()
|
| 69 |
+
|
| 70 |
+
def switch_active_layers(self):
|
| 71 |
+
# First, disable gradients for all layers
|
| 72 |
+
self.freeze_all_layers()
|
| 73 |
+
|
| 74 |
+
# Randomly select n_layers to activate
|
| 75 |
+
layers = self.model.get_submodule(self.layers_attribute)
|
| 76 |
+
self.active_layers_indices = np.random.choice(range(self.total_layers), self.n_layers, replace=False)
|
| 77 |
+
# Enable gradients only for the selected layers
|
| 78 |
+
for idx in self.active_layers_indices:
|
| 79 |
+
for param in layers[idx].parameters():
|
| 80 |
+
param.requires_grad = True
|
ms-swift/swift/llm/train/rlhf.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
from typing import List, Union
|
| 4 |
+
|
| 5 |
+
from swift.llm import safe_snapshot_download
|
| 6 |
+
from swift.utils import get_logger, get_model_parameter_info
|
| 7 |
+
from ..argument import BaseArguments, RLHFArguments
|
| 8 |
+
from ..model import HfConfigFactory
|
| 9 |
+
from .kto import prepare_kto_dataset
|
| 10 |
+
from .sft import SwiftSft
|
| 11 |
+
|
| 12 |
+
logger = get_logger()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SwiftRLHF(SwiftSft):
|
| 16 |
+
args_class = RLHFArguments
|
| 17 |
+
args: args_class
|
| 18 |
+
|
| 19 |
+
def _prepare_model_tokenizer(self):
|
| 20 |
+
if self.args.sequence_parallel_size > 1:
|
| 21 |
+
# Duplicate calling is allowd to promise this function will
|
| 22 |
+
# be called before model initializing.
|
| 23 |
+
from swift.trainers.sequence_parallel import sequence_parallel
|
| 24 |
+
sequence_parallel.init_sequence_parallel(self.args.sequence_parallel_size)
|
| 25 |
+
# prepare ref/reward/value model
|
| 26 |
+
from swift.llm.infer.utils import prepare_adapter
|
| 27 |
+
args = self.args
|
| 28 |
+
|
| 29 |
+
def prepare_single_model(key, origin_key=None):
|
| 30 |
+
origin_key = origin_key or key
|
| 31 |
+
model_id_or_path = getattr(args, f'{key}_model')
|
| 32 |
+
if model_id_or_path is None:
|
| 33 |
+
return None
|
| 34 |
+
|
| 35 |
+
model_type = getattr(args, f'{key}_model_type')
|
| 36 |
+
model_revision = getattr(args, f'{key}_model_revision')
|
| 37 |
+
model_dir = safe_snapshot_download(
|
| 38 |
+
model_id_or_path=model_id_or_path,
|
| 39 |
+
revision=model_revision,
|
| 40 |
+
download_model=False,
|
| 41 |
+
use_hf=args.use_hf,
|
| 42 |
+
hub_token=args.hub_token,
|
| 43 |
+
)
|
| 44 |
+
task_type = None
|
| 45 |
+
num_labels = None
|
| 46 |
+
if os.path.exists(os.path.join(model_dir, 'args.json')):
|
| 47 |
+
model_args = BaseArguments.from_pretrained(model_dir)
|
| 48 |
+
if hasattr(model_args, 'task_type'):
|
| 49 |
+
task_type = model_args.task_type
|
| 50 |
+
else:
|
| 51 |
+
from transformers import AutoConfig
|
| 52 |
+
model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
|
| 53 |
+
if hasattr(model_config, 'num_labels'):
|
| 54 |
+
num_labels = model_config.num_labels
|
| 55 |
+
if task_type == 'seq_cls':
|
| 56 |
+
num_labels = 1
|
| 57 |
+
|
| 58 |
+
model, processor = args.get_model_processor(
|
| 59 |
+
model=model_id_or_path,
|
| 60 |
+
model_type=model_type,
|
| 61 |
+
model_revision=model_revision,
|
| 62 |
+
task_type=task_type,
|
| 63 |
+
num_labels=num_labels)
|
| 64 |
+
|
| 65 |
+
adapters = args.adapters if key == 'ref' else args.reward_adapters
|
| 66 |
+
model = prepare_adapter(args, model, adapters)
|
| 67 |
+
if origin_key in {'ref', 'reward'}:
|
| 68 |
+
if self.args.sequence_parallel_size > 1:
|
| 69 |
+
from swift.trainers.sequence_parallel import sequence_parallel
|
| 70 |
+
if hasattr(model, 'model_meta'):
|
| 71 |
+
is_multimodal = model.model_meta.is_multimodal
|
| 72 |
+
else:
|
| 73 |
+
is_multimodal = model.model.model_meta.is_multimodal
|
| 74 |
+
sequence_parallel.prepare_model(model, processor, split_in_forward=is_multimodal)
|
| 75 |
+
model.requires_grad_(False).eval()
|
| 76 |
+
else:
|
| 77 |
+
model = self.prepare_model(args, model, task_type=task_type)
|
| 78 |
+
logger.info(f'value_model: {model}')
|
| 79 |
+
model_parameter_info = get_model_parameter_info(model)
|
| 80 |
+
self.train_msg['value_model_parameter_info'] = model_parameter_info
|
| 81 |
+
logger.info(f'value_model_parameter_info: {model_parameter_info}')
|
| 82 |
+
|
| 83 |
+
HfConfigFactory.set_model_config_attr(model, 'use_cache', False)
|
| 84 |
+
return model, processor
|
| 85 |
+
|
| 86 |
+
# Handle ref and value models
|
| 87 |
+
for key in ['ref', 'value']:
|
| 88 |
+
setattr(self, f'{key}_model', None)
|
| 89 |
+
if key == 'value' and args.rlhf_type != 'ppo':
|
| 90 |
+
continue
|
| 91 |
+
|
| 92 |
+
model_key = 'reward' if key == 'value' else key
|
| 93 |
+
result = prepare_single_model(model_key, key)
|
| 94 |
+
if result is not None:
|
| 95 |
+
model, _ = result
|
| 96 |
+
setattr(self, f'{key}_model', model)
|
| 97 |
+
|
| 98 |
+
# Handle reward model(s)
|
| 99 |
+
self.reward_model = None
|
| 100 |
+
if hasattr(args, 'reward_model') and args.reward_model is not None:
|
| 101 |
+
reward_models = args.reward_model if isinstance(args.reward_model, list) else [args.reward_model]
|
| 102 |
+
self.reward_model = []
|
| 103 |
+
if args.rlhf_type == 'grpo':
|
| 104 |
+
self.reward_template = []
|
| 105 |
+
|
| 106 |
+
for reward_model_path in reward_models:
|
| 107 |
+
args.reward_model = reward_model_path # Temporarily set for prepare_single_model
|
| 108 |
+
result = prepare_single_model('reward')
|
| 109 |
+
if result is not None:
|
| 110 |
+
model, processor = result
|
| 111 |
+
self.reward_model.append(model)
|
| 112 |
+
|
| 113 |
+
if args.rlhf_type == 'grpo':
|
| 114 |
+
reward_template = self.args.get_template(processor, processor.model_meta.template)
|
| 115 |
+
if reward_template.use_model:
|
| 116 |
+
reward_template.model = model
|
| 117 |
+
self.reward_template.append(reward_template)
|
| 118 |
+
args.reward_model = reward_models # Restore original value
|
| 119 |
+
|
| 120 |
+
super()._prepare_model_tokenizer()
|
| 121 |
+
|
| 122 |
+
def _prepare_template(self) -> None:
|
| 123 |
+
args = self.args
|
| 124 |
+
super()._prepare_template()
|
| 125 |
+
model_mapping = {'kto': 'kto', 'ppo': 'pt', 'grpo': 'pt'}
|
| 126 |
+
self.template.set_mode(model_mapping.get(args.rlhf_type, 'rlhf'))
|
| 127 |
+
|
| 128 |
+
if args.rlhf_type == 'ppo':
|
| 129 |
+
args.training_args.stop_token_id = self.template.template_meta.stop_token_id
|
| 130 |
+
|
| 131 |
+
def _get_dataset(self):
|
| 132 |
+
args = self.args
|
| 133 |
+
train_dataset, val_dataset = super()._get_dataset()
|
| 134 |
+
if args.rlhf_type == 'kto':
|
| 135 |
+
train_dataset, val_dataset = prepare_kto_dataset(args, train_dataset, val_dataset)
|
| 136 |
+
return train_dataset, val_dataset
|
| 137 |
+
|
| 138 |
+
def _get_trainer_kwargs(self):
|
| 139 |
+
trainer_kwargs = {}
|
| 140 |
+
for key in ['ref', 'reward', 'value']:
|
| 141 |
+
key = f'{key}_model'
|
| 142 |
+
model = getattr(self, key, None)
|
| 143 |
+
if model or self.args.rlhf_type == 'ppo':
|
| 144 |
+
trainer_kwargs[key] = model
|
| 145 |
+
if hasattr(self, 'reward_template'):
|
| 146 |
+
trainer_kwargs['reward_template'] = self.reward_template
|
| 147 |
+
if self.args.rlhf_type == 'grpo':
|
| 148 |
+
trainer_kwargs['reward_funcs'] = self.args.reward_funcs
|
| 149 |
+
trainer_kwargs['vllm_client'] = self.args.vllm_client
|
| 150 |
+
return trainer_kwargs
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def rlhf_main(args: Union[List[str], RLHFArguments, None] = None):
|
| 154 |
+
return SwiftRLHF(args).main()
|
ms-swift/swift/llm/train/sft.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import List, Union
|
| 5 |
+
|
| 6 |
+
from datasets import Dataset as HfDataset
|
| 7 |
+
|
| 8 |
+
from swift.plugin import extra_callbacks, get_loss_func, get_metric
|
| 9 |
+
from swift.trainers import TrainerFactory
|
| 10 |
+
from swift.utils import (append_to_jsonl, get_logger, get_model_parameter_info, is_master, plot_images, stat_array,
|
| 11 |
+
use_torchacc)
|
| 12 |
+
from ..argument import TrainArguments
|
| 13 |
+
from ..base import SwiftPipeline
|
| 14 |
+
from ..dataset import (EncodePreprocessor, GetLengthPreprocessor, IterablePackingDataset, LazyLLMDataset,
|
| 15 |
+
PackingDataset, load_dataset)
|
| 16 |
+
from ..infer import prepare_generation_config
|
| 17 |
+
from ..model import HfConfigFactory, get_model_arch
|
| 18 |
+
from ..utils import deep_getattr, dynamic_gradient_checkpointing
|
| 19 |
+
from .tuner import TunerMixin
|
| 20 |
+
|
| 21 |
+
logger = get_logger()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SwiftSft(SwiftPipeline, TunerMixin):
|
| 25 |
+
args_class = TrainArguments
|
| 26 |
+
args: args_class
|
| 27 |
+
|
| 28 |
+
def __init__(self, args: Union[List[str], TrainArguments, None] = None) -> None:
|
| 29 |
+
super().__init__(args)
|
| 30 |
+
self.train_msg = {}
|
| 31 |
+
self._prepare_model_tokenizer()
|
| 32 |
+
self._prepare_template()
|
| 33 |
+
self._prepare_callbacks()
|
| 34 |
+
|
| 35 |
+
def _prepare_gradient_checkpointing(self):
|
| 36 |
+
args = self.args
|
| 37 |
+
HfConfigFactory.set_model_config_attr(self.model, 'use_cache', False)
|
| 38 |
+
if args.gradient_checkpointing:
|
| 39 |
+
self.model.supports_gradient_checkpointing = True
|
| 40 |
+
dynamic_gradient_checkpointing(self.model)
|
| 41 |
+
self.model.enable_input_require_grads()
|
| 42 |
+
model_meta = self.model.model_meta
|
| 43 |
+
model_arch = get_model_arch(model_meta.model_arch)
|
| 44 |
+
if model_meta.is_multimodal and model_arch:
|
| 45 |
+
for vision_tower_name in model_arch.vision_tower:
|
| 46 |
+
vision_tower = deep_getattr(self.model, vision_tower_name)
|
| 47 |
+
if hasattr(vision_tower, 'enable_input_require_grads'):
|
| 48 |
+
try:
|
| 49 |
+
vision_tower.enable_input_require_grads()
|
| 50 |
+
except NotImplementedError:
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
def _prepare_generation_config(self):
|
| 54 |
+
args = self.args
|
| 55 |
+
self.model.origin_generation_config = self.model.generation_config
|
| 56 |
+
self.model.generation_config = prepare_generation_config(self.model.generation_config,
|
| 57 |
+
args.get_request_config(), self.tokenizer)
|
| 58 |
+
logger.info(f'model.generation_config: {self.model.generation_config}')
|
| 59 |
+
|
| 60 |
+
def _prepare_model_tokenizer(self):
|
| 61 |
+
args = self.args
|
| 62 |
+
if args.sequence_parallel_size > 1:
|
| 63 |
+
from swift.trainers.sequence_parallel import sequence_parallel
|
| 64 |
+
sequence_parallel.init_sequence_parallel(args.sequence_parallel_size)
|
| 65 |
+
self.model, self.processor = args.get_model_processor()
|
| 66 |
+
|
| 67 |
+
if hasattr(self.model, 'hf_device_map'):
|
| 68 |
+
logger.info(f'model.hf_device_map: {self.model.hf_device_map}')
|
| 69 |
+
|
| 70 |
+
logger.info(f'model_info: {self.model.model_info}')
|
| 71 |
+
|
| 72 |
+
self._prepare_generation_config()
|
| 73 |
+
self._prepare_gradient_checkpointing()
|
| 74 |
+
|
| 75 |
+
def _prepare_template(self) -> None:
|
| 76 |
+
template = self.args.get_template(self.processor)
|
| 77 |
+
if self.args.task_type == 'causal_lm':
|
| 78 |
+
template.set_mode('train')
|
| 79 |
+
if template.use_model:
|
| 80 |
+
template.model = self.model
|
| 81 |
+
self.template = template
|
| 82 |
+
|
| 83 |
+
def _get_dataset(self):
|
| 84 |
+
# The random shuffling of the training set occurs in the dataloader of the trainer.
|
| 85 |
+
args = self.args
|
| 86 |
+
dataset_kwargs = args.get_dataset_kwargs()
|
| 87 |
+
train_dataset, val_dataset = load_dataset(
|
| 88 |
+
args.dataset, split_dataset_ratio=args.split_dataset_ratio, shuffle=args.dataset_shuffle, **dataset_kwargs)
|
| 89 |
+
if len(args.val_dataset) > 0:
|
| 90 |
+
# Loading val dataset
|
| 91 |
+
_, val_dataset = load_dataset(
|
| 92 |
+
args.val_dataset, split_dataset_ratio=1.0, shuffle=args.val_dataset_shuffle, **dataset_kwargs)
|
| 93 |
+
assert args.split_dataset_ratio == 0.
|
| 94 |
+
logger.info(f'train_dataset: {train_dataset}')
|
| 95 |
+
logger.info(f'val_dataset: {val_dataset}')
|
| 96 |
+
|
| 97 |
+
return train_dataset, val_dataset
|
| 98 |
+
|
| 99 |
+
def _get_loss_func(self):
|
| 100 |
+
args = self.args
|
| 101 |
+
loss_type = args.loss_type
|
| 102 |
+
if loss_type is None and args.loss_scale != 'default':
|
| 103 |
+
loss_type = 'loss_scale'
|
| 104 |
+
return get_loss_func(loss_type)
|
| 105 |
+
|
| 106 |
+
def _get_data_collator(self):
|
| 107 |
+
args = self.args
|
| 108 |
+
template = self.template
|
| 109 |
+
padding_to = args.max_length if args.train_type == 'longlora' else None
|
| 110 |
+
return partial(template.data_collator, padding_to=padding_to)
|
| 111 |
+
|
| 112 |
+
@staticmethod
|
| 113 |
+
def _save_val_dataset(output_dir: str, val_dataset):
|
| 114 |
+
if is_master() and isinstance(val_dataset, HfDataset):
|
| 115 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 116 |
+
val_dataset_path = os.path.join(output_dir, 'val_dataset.jsonl')
|
| 117 |
+
append_to_jsonl(val_dataset_path, val_dataset.to_list())
|
| 118 |
+
logger.info(f'The split dataset from the training set will be saved at: {val_dataset_path}.')
|
| 119 |
+
|
| 120 |
+
def run(self):
|
| 121 |
+
args = self.args
|
| 122 |
+
|
| 123 |
+
train_dataset, val_dataset = self._get_dataset()
|
| 124 |
+
train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset)
|
| 125 |
+
|
| 126 |
+
if args.task_type == 'seq_cls':
|
| 127 |
+
args.problem_type = args.problem_type or getattr(self.model.config, 'problem_type', None)
|
| 128 |
+
logger.info(f'args.problem_type: {args.problem_type}')
|
| 129 |
+
args.save_args()
|
| 130 |
+
|
| 131 |
+
data_collator = self._get_data_collator()
|
| 132 |
+
# Some tuners require train_dataset and data_collator for preparation: LoRA-GA
|
| 133 |
+
self.model = self.prepare_model(self.args, self.model, template=self.template, train_dataset=train_dataset)
|
| 134 |
+
logger.info(f'model: {self.model}')
|
| 135 |
+
model_parameter_info = get_model_parameter_info(self.model)
|
| 136 |
+
self.train_msg['model_parameter_info'] = model_parameter_info
|
| 137 |
+
logger.info(f'model_parameter_info: {model_parameter_info}')
|
| 138 |
+
|
| 139 |
+
trainer_cls = TrainerFactory.get_trainer_cls(args)
|
| 140 |
+
trainer = trainer_cls(
|
| 141 |
+
model=self.model,
|
| 142 |
+
args=self.args.training_args,
|
| 143 |
+
data_collator=data_collator,
|
| 144 |
+
train_dataset=train_dataset,
|
| 145 |
+
eval_dataset=val_dataset,
|
| 146 |
+
callbacks=self.callbacks,
|
| 147 |
+
template=self.template,
|
| 148 |
+
**self._get_trainer_kwargs(),
|
| 149 |
+
)
|
| 150 |
+
return self.train(trainer)
|
| 151 |
+
|
| 152 |
+
def _get_trainer_kwargs(self):
|
| 153 |
+
args = self.args
|
| 154 |
+
if args.metric is not None:
|
| 155 |
+
compute_metrics, preprocess_logits_for_metrics = get_metric(args.metric)
|
| 156 |
+
elif args.predict_with_generate:
|
| 157 |
+
compute_metrics, preprocess_logits_for_metrics = get_metric('nlg')
|
| 158 |
+
else:
|
| 159 |
+
compute_metrics, preprocess_logits_for_metrics = get_metric('acc')
|
| 160 |
+
compute_metrics = partial(
|
| 161 |
+
compute_metrics, acc_strategy=args.acc_strategy, is_encoder_decoder=self.template.is_encoder_decoder)
|
| 162 |
+
return {
|
| 163 |
+
'compute_metrics': compute_metrics,
|
| 164 |
+
'preprocess_logits_for_metrics': preprocess_logits_for_metrics,
|
| 165 |
+
'compute_loss_func': self._get_loss_func()
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
def _save_trainer_state(self, trainer):
|
| 169 |
+
training_args = trainer.args
|
| 170 |
+
state = trainer.state
|
| 171 |
+
if hasattr(state, 'last_model_checkpoint'):
|
| 172 |
+
if self.args.create_checkpoint_symlink:
|
| 173 |
+
last_checkpoint = os.path.join(self.args.output_dir, 'last')
|
| 174 |
+
best_checkpoint = os.path.join(self.args.output_dir, 'best')
|
| 175 |
+
os.symlink(state.last_model_checkpoint, last_checkpoint)
|
| 176 |
+
os.symlink(state.best_model_checkpoint, best_checkpoint)
|
| 177 |
+
state.last_model_checkpoint = last_checkpoint
|
| 178 |
+
state.best_model_checkpoint = best_checkpoint
|
| 179 |
+
else:
|
| 180 |
+
state.last_model_checkpoint = None
|
| 181 |
+
logger.warning('No training was carried out, which may be due to the dataset being too small '
|
| 182 |
+
'or incorrect usage of resume_from_checkpoint.')
|
| 183 |
+
logger.info(f'last_model_checkpoint: {state.last_model_checkpoint}')
|
| 184 |
+
logger.info(f'best_model_checkpoint: {state.best_model_checkpoint}')
|
| 185 |
+
|
| 186 |
+
# Visualization
|
| 187 |
+
if is_master() and not use_torchacc():
|
| 188 |
+
if 'tensorboard' in training_args.report_to:
|
| 189 |
+
images_dir = os.path.join(training_args.output_dir, 'images')
|
| 190 |
+
logger.info(f'images_dir: {images_dir}')
|
| 191 |
+
plot_images(images_dir, training_args.logging_dir, ['train/loss'], 0.9)
|
| 192 |
+
if training_args.push_to_hub:
|
| 193 |
+
trainer.push_to_hub()
|
| 194 |
+
|
| 195 |
+
self.train_msg.update({
|
| 196 |
+
'last_model_checkpoint': state.last_model_checkpoint,
|
| 197 |
+
'best_model_checkpoint': state.best_model_checkpoint,
|
| 198 |
+
'best_metric': state.best_metric,
|
| 199 |
+
'global_step': state.global_step,
|
| 200 |
+
'log_history': state.log_history,
|
| 201 |
+
'memory': trainer.max_memory,
|
| 202 |
+
})
|
| 203 |
+
if is_master():
|
| 204 |
+
jsonl_path = os.path.join(training_args.output_dir, 'logging.jsonl')
|
| 205 |
+
append_to_jsonl(jsonl_path, self.train_msg)
|
| 206 |
+
return self.train_msg
|
| 207 |
+
|
| 208 |
+
def train(self, trainer):
|
| 209 |
+
logging_path = os.path.join(trainer.args.output_dir, 'logging.jsonl')
|
| 210 |
+
logger.info(f'The logging file will be saved in: {logging_path}')
|
| 211 |
+
try:
|
| 212 |
+
trainer.train(trainer.args.resume_from_checkpoint)
|
| 213 |
+
finally:
|
| 214 |
+
res = self._save_trainer_state(trainer)
|
| 215 |
+
return res
|
| 216 |
+
|
| 217 |
+
def _prepare_callbacks(self):
|
| 218 |
+
from .callback import DynamicLayerActivationCallback, TrainerAdapterCallback
|
| 219 |
+
args = self.args
|
| 220 |
+
callbacks = []
|
| 221 |
+
if args.lisa_activated_layers > 0:
|
| 222 |
+
assert args.train_type == 'full', 'LISA only supports full parameter training.'
|
| 223 |
+
lisa_callback = DynamicLayerActivationCallback(
|
| 224 |
+
n_layers=args.lisa_activated_layers, # Number of layers to activate
|
| 225 |
+
step_interval=args.lisa_step_interval, # Step interval to update active layers
|
| 226 |
+
model=self.model)
|
| 227 |
+
lisa_callback.switch_active_layers() # Make trainable parameters printing a correct value
|
| 228 |
+
callbacks.append(lisa_callback)
|
| 229 |
+
|
| 230 |
+
if args.is_adapter and args.train_type == 'adalora':
|
| 231 |
+
callbacks.append(TrainerAdapterCallback(args))
|
| 232 |
+
callbacks += extra_callbacks
|
| 233 |
+
self.callbacks = callbacks
|
| 234 |
+
|
| 235 |
+
def _stat_dataset(self, dataset: HfDataset):
|
| 236 |
+
args = self.args
|
| 237 |
+
if isinstance(dataset, HfDataset):
|
| 238 |
+
dataset = GetLengthPreprocessor()(dataset, num_proc=args.dataset_num_proc)
|
| 239 |
+
length = dataset['length']
|
| 240 |
+
else:
|
| 241 |
+
length = []
|
| 242 |
+
for row in dataset:
|
| 243 |
+
length.append(max([len(row[k]) for k in row.keys() if k.endswith('input_ids')]))
|
| 244 |
+
_, stat_str = stat_array(length)
|
| 245 |
+
logger.info(f'Dataset Token Length: {stat_str}')
|
| 246 |
+
return stat_str
|
| 247 |
+
|
| 248 |
+
def _encode_dataset(self, train_dataset, val_dataset):
|
| 249 |
+
template = self.template
|
| 250 |
+
args = self.args
|
| 251 |
+
output_dir = getattr(args, 'output_dir', None) or getattr(args, 'save')
|
| 252 |
+
self._save_val_dataset(output_dir, val_dataset)
|
| 253 |
+
is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo'
|
| 254 |
+
predict_with_generate = getattr(args, 'predict_with_generate', False)
|
| 255 |
+
if not is_grpo:
|
| 256 |
+
if args.packing:
|
| 257 |
+
packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset
|
| 258 |
+
train_dataset = packing_dataset_cls(
|
| 259 |
+
self.template, train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
|
| 260 |
+
if val_dataset is not None:
|
| 261 |
+
val_dataset = packing_dataset_cls(
|
| 262 |
+
self.template, val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
|
| 263 |
+
elif args.lazy_tokenize:
|
| 264 |
+
train_dataset = LazyLLMDataset(
|
| 265 |
+
train_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
|
| 266 |
+
if val_dataset is not None and not predict_with_generate:
|
| 267 |
+
val_dataset = LazyLLMDataset(
|
| 268 |
+
val_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
|
| 269 |
+
else:
|
| 270 |
+
preprocessor = EncodePreprocessor(template=template)
|
| 271 |
+
train_dataset = preprocessor(train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
|
| 272 |
+
if val_dataset is not None and not predict_with_generate:
|
| 273 |
+
val_dataset = preprocessor(val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
|
| 274 |
+
|
| 275 |
+
if is_master():
|
| 276 |
+
inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset))
|
| 277 |
+
template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {})
|
| 278 |
+
if isinstance(train_dataset, (HfDataset, PackingDataset)):
|
| 279 |
+
self.train_msg['train_dataset'] = self._stat_dataset(train_dataset)
|
| 280 |
+
if val_dataset is not None and not predict_with_generate:
|
| 281 |
+
self.train_msg['val_dataset'] = self._stat_dataset(val_dataset)
|
| 282 |
+
|
| 283 |
+
return train_dataset, val_dataset
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def sft_main(args: Union[List[str], TrainArguments, None] = None):
|
| 287 |
+
return SwiftSft(args).main()
|
ms-swift/swift/llm/train/tuner.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import inspect
|
| 3 |
+
import os
|
| 4 |
+
from typing import List, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import transformers
|
| 9 |
+
from packaging import version
|
| 10 |
+
from transformers import TrainingArguments
|
| 11 |
+
|
| 12 |
+
from swift.llm import TrainArguments, deep_getattr, get_model_arch
|
| 13 |
+
from swift.plugin import Tuner, extra_tuners
|
| 14 |
+
from swift.tuners import Swift
|
| 15 |
+
from swift.utils import (activate_parameters, find_all_linears, find_embedding, find_norm, freeze_parameters,
|
| 16 |
+
get_logger, use_torchacc)
|
| 17 |
+
|
| 18 |
+
logger = get_logger()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def apply_liger(model_type: str):
|
| 22 |
+
from liger_kernel.transformers import (apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral,
|
| 23 |
+
apply_liger_kernel_to_mixtral, apply_liger_kernel_to_gemma,
|
| 24 |
+
apply_liger_kernel_to_qwen2, apply_liger_kernel_to_qwen3,
|
| 25 |
+
apply_liger_kernel_to_qwen2_vl, apply_liger_kernel_to_qwen2_5_vl,
|
| 26 |
+
apply_liger_kernel_to_phi3, apply_liger_kernel_to_mllama)
|
| 27 |
+
from swift.llm import ModelType
|
| 28 |
+
if model_type in (ModelType.llama, ModelType.llama3, ModelType.llama3_1, ModelType.llama3_2):
|
| 29 |
+
apply_liger_kernel_to_llama()
|
| 30 |
+
elif model_type in (ModelType.mistral):
|
| 31 |
+
apply_liger_kernel_to_mistral()
|
| 32 |
+
elif model_type in (ModelType.mixtral):
|
| 33 |
+
apply_liger_kernel_to_mixtral()
|
| 34 |
+
elif model_type in (ModelType.gemma, ModelType.gemma2):
|
| 35 |
+
apply_liger_kernel_to_gemma()
|
| 36 |
+
elif model_type in (ModelType.qwen2, ModelType.qwen2_5):
|
| 37 |
+
apply_liger_kernel_to_qwen2()
|
| 38 |
+
elif model_type in (ModelType.qwen3):
|
| 39 |
+
apply_liger_kernel_to_qwen3()
|
| 40 |
+
elif model_type in (ModelType.phi3):
|
| 41 |
+
apply_liger_kernel_to_phi3()
|
| 42 |
+
elif model_type in (ModelType.llama3_2_vision):
|
| 43 |
+
apply_liger_kernel_to_mllama()
|
| 44 |
+
elif model_type in (ModelType.qwen2_vl):
|
| 45 |
+
apply_liger_kernel_to_qwen2_vl()
|
| 46 |
+
elif model_type in (ModelType.qwen2_5_vl):
|
| 47 |
+
apply_liger_kernel_to_qwen2_5_vl()
|
| 48 |
+
else:
|
| 49 |
+
raise ValueError(f'Unsupported liger model_type: {model_type}')
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_multimodal_target_regex(
|
| 53 |
+
model,
|
| 54 |
+
*,
|
| 55 |
+
freeze_llm: bool = False,
|
| 56 |
+
freeze_vit: bool = True,
|
| 57 |
+
freeze_aligner: bool = True,
|
| 58 |
+
include_embedding: bool = False,
|
| 59 |
+
) -> str:
|
| 60 |
+
model_arch = get_model_arch(model.model_meta.model_arch)
|
| 61 |
+
modules = []
|
| 62 |
+
if not freeze_llm:
|
| 63 |
+
modules += model_arch.language_model
|
| 64 |
+
if not freeze_vit:
|
| 65 |
+
modules += model_arch.vision_tower
|
| 66 |
+
if not freeze_aligner:
|
| 67 |
+
modules += model_arch.aligner
|
| 68 |
+
assert len(modules) > 0, f'modules: {modules}'
|
| 69 |
+
|
| 70 |
+
extra_layers = []
|
| 71 |
+
if include_embedding:
|
| 72 |
+
extra_layers.append(nn.Embedding)
|
| 73 |
+
res = []
|
| 74 |
+
for module in modules:
|
| 75 |
+
rejected_modules = []
|
| 76 |
+
if not freeze_vit:
|
| 77 |
+
for aligner in model_arch.aligner:
|
| 78 |
+
if aligner.startswith(f'{module}.'):
|
| 79 |
+
rejected_modules.append(aligner)
|
| 80 |
+
|
| 81 |
+
sub_module = deep_getattr(model, module)
|
| 82 |
+
target_modules = find_all_linears(sub_module, model_arch, extra_layers)
|
| 83 |
+
target_modules = [tm for tm in target_modules if tm]
|
| 84 |
+
target_pattern = rf'.*\.({"|".join(target_modules)})' if target_modules else ''
|
| 85 |
+
rejected_pattern = rf'(?!({"|".join(rejected_modules)}))' if rejected_modules else ''
|
| 86 |
+
res.append(rf'{rejected_pattern}{module}{target_pattern}')
|
| 87 |
+
|
| 88 |
+
return rf'^({"|".join(res)})$'
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_target_modules(args, model) -> Union[str, List[str]]:
|
| 92 |
+
"""Replace all-linear to actual modules"""
|
| 93 |
+
model_meta = model.model_meta
|
| 94 |
+
if isinstance(args.target_modules, str):
|
| 95 |
+
return args.target_modules
|
| 96 |
+
target_modules = args.target_modules.copy()
|
| 97 |
+
if 'all-linear' in target_modules:
|
| 98 |
+
if model_meta.is_multimodal:
|
| 99 |
+
return get_multimodal_target_regex(
|
| 100 |
+
model,
|
| 101 |
+
freeze_llm=args.freeze_llm,
|
| 102 |
+
freeze_vit=args.freeze_vit,
|
| 103 |
+
freeze_aligner=args.freeze_aligner,
|
| 104 |
+
include_embedding='all-embedding' in target_modules)
|
| 105 |
+
else:
|
| 106 |
+
target_modules.remove('all-linear')
|
| 107 |
+
target_modules += find_all_linears(model)
|
| 108 |
+
if 'all-embedding' in target_modules:
|
| 109 |
+
target_modules.remove('all-embedding')
|
| 110 |
+
target_modules += find_embedding(model)
|
| 111 |
+
return target_modules
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def get_modules_to_save(args, model, task_type=None):
|
| 115 |
+
modules_to_save = args.modules_to_save.copy()
|
| 116 |
+
if 'all-embedding' in args.modules_to_save:
|
| 117 |
+
modules_to_save.remove('all-embedding')
|
| 118 |
+
modules_to_save += find_embedding(model)
|
| 119 |
+
if 'all-norm' in args.modules_to_save:
|
| 120 |
+
modules_to_save.remove('all-norm')
|
| 121 |
+
modules_to_save += find_norm(model)
|
| 122 |
+
if task_type and task_type.lower() == 'seq_cls': # reward_model
|
| 123 |
+
modules_to_save.append('v_head')
|
| 124 |
+
return modules_to_save
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def get_vera_target_modules(model, config):
|
| 128 |
+
"""This function is only useful on the vera tuner"""
|
| 129 |
+
target_modules = config.target_modules
|
| 130 |
+
modules_dict = {
|
| 131 |
+
name: module.weight.shape
|
| 132 |
+
for name, module in model.named_modules()
|
| 133 |
+
if isinstance(module, torch.nn.Linear) and any([t in name for t in target_modules])
|
| 134 |
+
} # only Linear for now
|
| 135 |
+
if len(set(modules_dict.values())) > 1:
|
| 136 |
+
v = [t for t in target_modules if 'v' in t]
|
| 137 |
+
if not v:
|
| 138 |
+
raise ValueError('Please manually pass in `vera_target_modules`, do not use `all-linear`,'
|
| 139 |
+
'because Vera need all target linears to be the same size.')
|
| 140 |
+
v = v[0]
|
| 141 |
+
shape = [shape for name, shape in modules_dict.items() if v in name][0]
|
| 142 |
+
names = [_name for _name, _shape in modules_dict.items() if _shape == shape]
|
| 143 |
+
config.target_modules = [t for t in target_modules if any([t in name for name in names])]
|
| 144 |
+
return config
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset=None, task_type=None):
|
| 148 |
+
from swift.tuners import (AdaLoraConfig, AdapterConfig, BOFTConfig, LLaMAProConfig, LongLoRAModelType, LoraConfig,
|
| 149 |
+
LoRAConfig, ReftConfig, Swift, VeraConfig)
|
| 150 |
+
task_type = (task_type or args.task_type).upper()
|
| 151 |
+
target_modules = get_target_modules(args, model)
|
| 152 |
+
modules_to_save = get_modules_to_save(args, model, task_type)
|
| 153 |
+
lora_kwargs = {
|
| 154 |
+
'r': args.lora_rank,
|
| 155 |
+
'target_modules': target_modules,
|
| 156 |
+
'lora_alpha': args.lora_alpha,
|
| 157 |
+
'lora_dropout': args.lora_dropout,
|
| 158 |
+
'bias': args.lora_bias,
|
| 159 |
+
'modules_to_save': modules_to_save,
|
| 160 |
+
'use_rslora': args.use_rslora,
|
| 161 |
+
'use_dora': args.use_dora,
|
| 162 |
+
'lorap_lr_ratio': args.lorap_lr_ratio,
|
| 163 |
+
'init_lora_weights': args.init_weights,
|
| 164 |
+
}
|
| 165 |
+
if args.train_type in ('lora', 'longlora'):
|
| 166 |
+
if args.use_swift_lora:
|
| 167 |
+
lora_config = LoRAConfig(lora_dtype=args.lora_dtype, **lora_kwargs)
|
| 168 |
+
model = Swift.prepare_model(model, lora_config)
|
| 169 |
+
logger.info(f'lora_config: {lora_config}')
|
| 170 |
+
elif args.tuner_backend == 'peft':
|
| 171 |
+
if task_type == 'EMBEDDING':
|
| 172 |
+
task_type = None
|
| 173 |
+
lora_config = LoraConfig(task_type=task_type, lora_dtype=args.lora_dtype, **lora_kwargs)
|
| 174 |
+
if args.init_weights == 'lora-ga':
|
| 175 |
+
try:
|
| 176 |
+
import lora_ga
|
| 177 |
+
except ImportError as e:
|
| 178 |
+
error_message = """
|
| 179 |
+
Since 'LoRA-GA' is not implemented by PEFT, you will need to install it directly from GitHub.
|
| 180 |
+
Command: 'pip install git+https://github.com/lxline/LoRA-GA.git'.
|
| 181 |
+
"""
|
| 182 |
+
logger.info(error_message)
|
| 183 |
+
raise RuntimeError(error_message) from e
|
| 184 |
+
model = lora_ga.entrypoint.get_lora_ga_model(
|
| 185 |
+
model=model,
|
| 186 |
+
data_collator=template.data_collator,
|
| 187 |
+
dataset=train_dataset,
|
| 188 |
+
batch_size=args.lora_ga_batch_size,
|
| 189 |
+
num_iters=args.lora_ga_iters,
|
| 190 |
+
max_length=args.lora_ga_max_length,
|
| 191 |
+
direction=args.lora_ga_direction,
|
| 192 |
+
dtype=args.lora_dtype,
|
| 193 |
+
scale=args.lora_ga_scale,
|
| 194 |
+
stable_gamma=args.lora_ga_stable_gamma,
|
| 195 |
+
)
|
| 196 |
+
else:
|
| 197 |
+
model = Swift.prepare_model(model, lora_config)
|
| 198 |
+
logger.info(f'lora_config: {lora_config}')
|
| 199 |
+
elif args.tuner_backend == 'unsloth':
|
| 200 |
+
if args.resume_from_checkpoint is None:
|
| 201 |
+
if args.model_meta.is_multimodal:
|
| 202 |
+
from unsloth import FastVisionModel as UnslothModel
|
| 203 |
+
else:
|
| 204 |
+
from unsloth import FastLanguageModel as UnslothModel
|
| 205 |
+
assert args.train_type == 'lora', 'Unsloth does not support LongLoRA'
|
| 206 |
+
lora_kwargs.pop('lorap_lr_ratio')
|
| 207 |
+
model = UnslothModel.get_peft_model(
|
| 208 |
+
model,
|
| 209 |
+
use_gradient_checkpointing='unsloth',
|
| 210 |
+
max_seq_length=args.max_length or 2048, # 2048 is the default value of unsloth
|
| 211 |
+
**lora_kwargs,
|
| 212 |
+
)
|
| 213 |
+
logger.info(f'unsloth_config: {lora_kwargs}')
|
| 214 |
+
if args.train_type == 'longlora':
|
| 215 |
+
assert LongLoRAModelType.LLAMA in args.model_type
|
| 216 |
+
assert version.parse(transformers.__version__) >= version.parse('4.39.3')
|
| 217 |
+
from swift.tuners.longlora.llama import replace_llama_attn
|
| 218 |
+
replace_llama_attn(model)
|
| 219 |
+
model.config.group_size_ratio = 0.25
|
| 220 |
+
elif args.train_type == 'adalora':
|
| 221 |
+
lora_kwargs.pop('lorap_lr_ratio', None)
|
| 222 |
+
lora_kwargs['rank_pattern'] = None
|
| 223 |
+
from swift.plugin.optimizer import calculate_max_steps
|
| 224 |
+
adalora_config = AdaLoraConfig(
|
| 225 |
+
task_type=task_type,
|
| 226 |
+
**lora_kwargs,
|
| 227 |
+
target_r=args.adalora_target_r,
|
| 228 |
+
init_r=args.adalora_init_r,
|
| 229 |
+
tinit=args.adalora_tinit,
|
| 230 |
+
tfinal=args.adalora_tfinal,
|
| 231 |
+
deltaT=args.adalora_deltaT,
|
| 232 |
+
beta1=args.adalora_beta1,
|
| 233 |
+
beta2=args.adalora_beta2,
|
| 234 |
+
orth_reg_weight=args.adalora_orth_reg_weight,
|
| 235 |
+
total_step=calculate_max_steps(args.training_args, train_dataset),
|
| 236 |
+
)
|
| 237 |
+
model = Swift.prepare_model(model, adalora_config)
|
| 238 |
+
logger.info(f'adalora_config: {adalora_config}')
|
| 239 |
+
elif args.train_type == 'llamapro':
|
| 240 |
+
llamapro_config = LLaMAProConfig(
|
| 241 |
+
model_type=model.model_meta.model_arch,
|
| 242 |
+
num_new_blocks=args.llamapro_num_new_blocks,
|
| 243 |
+
num_groups=args.llamapro_num_groups)
|
| 244 |
+
model = Swift.prepare_model(model, llamapro_config)
|
| 245 |
+
logger.info(f'llamapro_config: {llamapro_config}')
|
| 246 |
+
elif args.train_type == 'adapter':
|
| 247 |
+
model_arch = get_model_arch(model.model_meta.model_arch)
|
| 248 |
+
mlp_key = model_arch.mlp
|
| 249 |
+
mlp_key = mlp_key.split('.{}.')[1]
|
| 250 |
+
adapter_config = AdapterConfig(
|
| 251 |
+
dim=model.config.hidden_size,
|
| 252 |
+
target_modules=[mlp_key],
|
| 253 |
+
hidden_pos=0,
|
| 254 |
+
adapter_length=args.adapter_length,
|
| 255 |
+
act_layer=args.adapter_act)
|
| 256 |
+
model = Swift.prepare_model(model, adapter_config)
|
| 257 |
+
logger.info(f'adapter_config: {adapter_config}')
|
| 258 |
+
elif args.train_type == 'vera':
|
| 259 |
+
vera_config = VeraConfig(
|
| 260 |
+
r=args.vera_rank,
|
| 261 |
+
target_modules=target_modules,
|
| 262 |
+
projection_prng_key=args.vera_projection_prng_key,
|
| 263 |
+
vera_dropout=args.vera_dropout,
|
| 264 |
+
d_initial=args.vera_d_initial,
|
| 265 |
+
modules_to_save=args.modules_to_save,
|
| 266 |
+
)
|
| 267 |
+
vera_config = get_vera_target_modules(model, vera_config)
|
| 268 |
+
model = Swift.prepare_model(model, vera_config)
|
| 269 |
+
logger.info(f'vera_config: {vera_config}')
|
| 270 |
+
elif args.train_type == 'boft':
|
| 271 |
+
boft_config = BOFTConfig(
|
| 272 |
+
boft_block_size=args.boft_block_size,
|
| 273 |
+
boft_block_num=args.boft_block_num,
|
| 274 |
+
boft_n_butterfly_factor=args.boft_n_butterfly_factor,
|
| 275 |
+
target_modules=target_modules,
|
| 276 |
+
boft_dropout=args.boft_dropout,
|
| 277 |
+
modules_to_save=args.modules_to_save,
|
| 278 |
+
)
|
| 279 |
+
model = Swift.prepare_model(model, boft_config)
|
| 280 |
+
logger.info(f'boft_config: {boft_config}')
|
| 281 |
+
elif args.train_type == 'fourierft':
|
| 282 |
+
from peft import FourierFTConfig
|
| 283 |
+
fourier_config = FourierFTConfig(
|
| 284 |
+
target_modules=target_modules,
|
| 285 |
+
modules_to_save=args.modules_to_save,
|
| 286 |
+
n_frequency=args.fourier_n_frequency,
|
| 287 |
+
scaling=args.fourier_scaling,
|
| 288 |
+
)
|
| 289 |
+
model = Swift.prepare_model(model, fourier_config)
|
| 290 |
+
logger.info(f'fourier_config: {fourier_config}')
|
| 291 |
+
elif args.train_type == 'reft':
|
| 292 |
+
reft_config = ReftConfig(
|
| 293 |
+
model_type=model.model_meta.model_arch,
|
| 294 |
+
layer_key=args.reft_layer_key,
|
| 295 |
+
r=args.reft_rank,
|
| 296 |
+
layers=args.reft_layers,
|
| 297 |
+
intervention_type=args.reft_intervention_type,
|
| 298 |
+
args=args.reft_args,
|
| 299 |
+
)
|
| 300 |
+
logger.info(f'reft config: {reft_config}')
|
| 301 |
+
model = Swift.prepare_model(model, {'reft': reft_config})
|
| 302 |
+
elif args.train_type == 'bone':
|
| 303 |
+
# Version loosing
|
| 304 |
+
from peft import BoneConfig
|
| 305 |
+
bone_config = BoneConfig(
|
| 306 |
+
target_modules=target_modules,
|
| 307 |
+
r=args.reft_rank,
|
| 308 |
+
init_weights=args.init_weights,
|
| 309 |
+
)
|
| 310 |
+
logger.info(f'bone config: {bone_config}')
|
| 311 |
+
model = Swift.prepare_model(model, bone_config)
|
| 312 |
+
return model
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def torchacc_resume_from_checkpoint(args, model):
|
| 316 |
+
import safetensors
|
| 317 |
+
weights_file = os.path.join(args.resume_from_checkpoint, 'pytorch_model.bin')
|
| 318 |
+
safe_weights_file = os.path.join(args.resume_from_checkpoint, 'model.safetensors')
|
| 319 |
+
if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
|
| 320 |
+
if args.save_safetensors and os.path.isfile(safe_weights_file):
|
| 321 |
+
state_dict = safetensors.torch.load_file(safe_weights_file, device='cpu')
|
| 322 |
+
else:
|
| 323 |
+
state_dict = torch.load(weights_file, map_location='cpu')
|
| 324 |
+
model.load_state_dict(state_dict, False)
|
| 325 |
+
del state_dict
|
| 326 |
+
else:
|
| 327 |
+
from transformers.modeling_utils import load_sharded_checkpoint
|
| 328 |
+
# We load the sharded checkpoint
|
| 329 |
+
load_result = load_sharded_checkpoint(
|
| 330 |
+
model, args.resume_from_checkpoint, strict=False, prefer_safe=args.save_safetensors)
|
| 331 |
+
if len(load_result.missing_keys) != 0:
|
| 332 |
+
if model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
|
| 333 |
+
model._keys_to_ignore_on_save):
|
| 334 |
+
model.tie_weights()
|
| 335 |
+
else:
|
| 336 |
+
logger.warning(f'There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.')
|
| 337 |
+
if len(load_result.unexpected_keys) != 0:
|
| 338 |
+
logger.warning(f'There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.')
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
class TunerMixin:
|
| 342 |
+
|
| 343 |
+
@classmethod
|
| 344 |
+
def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_type=None):
|
| 345 |
+
if args.use_liger_kernel and 'use_liger_kernel' not in inspect.signature(TrainingArguments).parameters:
|
| 346 |
+
# Apply liger
|
| 347 |
+
apply_liger(args.model_type)
|
| 348 |
+
|
| 349 |
+
if args.is_adapter:
|
| 350 |
+
if args.tuner_backend != 'unsloth' and args.train_type not in extra_tuners:
|
| 351 |
+
# Fix the name of the layer in xcomposer that contains Plora.
|
| 352 |
+
# Unsloth prepares and loads lora outside this function when
|
| 353 |
+
# resume_from_checkpoint, so do not disable grad here
|
| 354 |
+
model.requires_grad_(False)
|
| 355 |
+
if args.resume_from_checkpoint:
|
| 356 |
+
if args.train_type in extra_tuners:
|
| 357 |
+
tuner: Tuner = extra_tuners[args.train_type]
|
| 358 |
+
else:
|
| 359 |
+
tuner = Swift
|
| 360 |
+
kwargs = {}
|
| 361 |
+
if use_torchacc():
|
| 362 |
+
kwargs = {'adapter_name': 'default'}
|
| 363 |
+
model = tuner.from_pretrained(model, args.resume_from_checkpoint, is_trainable=True, **kwargs)
|
| 364 |
+
else:
|
| 365 |
+
if args.train_type in extra_tuners:
|
| 366 |
+
tuner: Tuner = extra_tuners[args.train_type]
|
| 367 |
+
model = tuner.prepare_model(args, model)
|
| 368 |
+
else:
|
| 369 |
+
model = prepare_adapter(
|
| 370 |
+
args, model, template=template, train_dataset=train_dataset, task_type=task_type)
|
| 371 |
+
# fix bug: Attempting to unscale FP16 gradients.
|
| 372 |
+
# peft: https://github.com/huggingface/peft/issues/1249
|
| 373 |
+
for p in model.parameters():
|
| 374 |
+
if p.requires_grad and p.dtype == torch.float16:
|
| 375 |
+
logger.info_once('Convert trainable parameters from fp16 to fp32.')
|
| 376 |
+
p.data = p.data.to(dtype=torch.float32)
|
| 377 |
+
elif args.train_type == 'full':
|
| 378 |
+
model.train()
|
| 379 |
+
model.requires_grad_(True)
|
| 380 |
+
|
| 381 |
+
freeze_parameters(model, args.freeze_parameters_ratio, args.freeze_parameters, args.freeze_parameters_regex)
|
| 382 |
+
if len(args.trainable_parameters) > 0 or args.trainable_parameters_regex is not None:
|
| 383 |
+
activate_parameters(model, args.trainable_parameters, args.trainable_parameters_regex)
|
| 384 |
+
if use_torchacc() and args.resume_from_checkpoint:
|
| 385 |
+
torchacc_resume_from_checkpoint(args, model)
|
| 386 |
+
else:
|
| 387 |
+
raise ValueError(f'args.train_type: {args.train_type}')
|
| 388 |
+
|
| 389 |
+
if args.resume_only_model:
|
| 390 |
+
args.training_args.resume_from_checkpoint = None
|
| 391 |
+
if args.use_galore:
|
| 392 |
+
from swift.trainers.optimizers.galore import GaLoreConfig
|
| 393 |
+
if args.galore_target_modules is None:
|
| 394 |
+
args.galore_target_modules = find_all_linears(model)
|
| 395 |
+
if args.galore_with_embedding:
|
| 396 |
+
args.galore_target_modules += find_embedding(model)
|
| 397 |
+
args.galore_config = GaLoreConfig(
|
| 398 |
+
target_modules=args.galore_target_modules,
|
| 399 |
+
rank=args.galore_rank,
|
| 400 |
+
update_proj_gap=args.galore_update_proj_gap,
|
| 401 |
+
galore_scale=args.galore_scale,
|
| 402 |
+
proj_type=args.galore_proj_type,
|
| 403 |
+
optim_per_parameter=args.galore_optim_per_parameter,
|
| 404 |
+
quantize=args.galore_quantization,
|
| 405 |
+
proj_quant=args.galore_proj_quant,
|
| 406 |
+
proj_bits=args.galore_proj_bits,
|
| 407 |
+
proj_group_size=args.galore_proj_group_size,
|
| 408 |
+
cos_threshold=args.galore_cos_threshold,
|
| 409 |
+
gamma_proj=args.galore_gamma_proj,
|
| 410 |
+
queue_size=args.galore_queue_size,
|
| 411 |
+
)
|
| 412 |
+
args.training_args.galore_config = args.galore_config
|
| 413 |
+
|
| 414 |
+
if args.sequence_parallel_size > 1:
|
| 415 |
+
from swift.trainers.sequence_parallel import sequence_parallel
|
| 416 |
+
if hasattr(model, 'model_meta'):
|
| 417 |
+
is_multimodal = model.model_meta.is_multimodal
|
| 418 |
+
else:
|
| 419 |
+
is_multimodal = model.model.model_meta.is_multimodal
|
| 420 |
+
# multimodal model must do split in basemodel's forward
|
| 421 |
+
# or the media embedding may occur error
|
| 422 |
+
sequence_parallel.prepare_model(model, template.tokenizer, split_in_forward=is_multimodal)
|
| 423 |
+
|
| 424 |
+
return model
|
ms-swift/swift/megatron/argument/train_args.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from swift.llm import BaseArguments
|
| 8 |
+
from swift.llm.argument.base_args import to_abspath
|
| 9 |
+
from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master
|
| 10 |
+
from ..model import get_megatron_model_meta
|
| 11 |
+
from .megatron_args import MegatronArguments
|
| 12 |
+
|
| 13 |
+
logger = get_logger()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class MegatronTrainArguments(MegatronArguments, BaseArguments):
|
| 18 |
+
add_version: bool = True
|
| 19 |
+
# dataset
|
| 20 |
+
lazy_tokenize: bool = False
|
| 21 |
+
packing: bool = False
|
| 22 |
+
|
| 23 |
+
def init_model_args(self, config):
|
| 24 |
+
self.megatron_model_meta = get_megatron_model_meta(self.model_type)
|
| 25 |
+
kwargs = self.megatron_model_meta.convert_hf_config(config)
|
| 26 |
+
for k, v in kwargs.items():
|
| 27 |
+
if getattr(self, k) is None:
|
| 28 |
+
setattr(self, k, v)
|
| 29 |
+
MegatronArguments.__post_init__(self)
|
| 30 |
+
self.extra_args = self.parse_to_megatron()
|
| 31 |
+
|
| 32 |
+
def _init_save(self):
|
| 33 |
+
init_process_group()
|
| 34 |
+
if self.save is None:
|
| 35 |
+
self.save = f'megatron_output/{self.model_suffix}'
|
| 36 |
+
self.save = to_abspath(self.save)
|
| 37 |
+
if self.add_version:
|
| 38 |
+
self.save = add_version_to_work_dir(self.save)
|
| 39 |
+
logger.info(f'args.save: {self.save}')
|
| 40 |
+
if is_master():
|
| 41 |
+
os.makedirs(self.save, exist_ok=True)
|
| 42 |
+
|
| 43 |
+
def __post_init__(self):
|
| 44 |
+
self.sequence_parallel_size = self.context_parallel_size
|
| 45 |
+
self.load = to_abspath(self.load, check_path_exist=True)
|
| 46 |
+
BaseArguments.__post_init__(self)
|
| 47 |
+
self._init_save()
|
| 48 |
+
self.seq_length = self.seq_length or self.max_length
|
| 49 |
+
if self.streaming:
|
| 50 |
+
self.dataloader_type = 'external'
|
| 51 |
+
if self.num_workers > 1:
|
| 52 |
+
self.num_workers = 1
|
| 53 |
+
logger.info('Using streaming dataset, setting args.num_workers to 1.')
|
ms-swift/swift/megatron/model/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from . import gpt
|
| 3 |
+
from .constant import MegatronModelType
|
| 4 |
+
from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model
|
ms-swift/swift/megatron/model/config.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
|
| 4 |
+
from swift.utils import get_logger
|
| 5 |
+
|
| 6 |
+
logger = get_logger()
|
| 7 |
+
config_mapping = {
|
| 8 |
+
'num_layers': ['num_hidden_layers'],
|
| 9 |
+
'hidden_size': ['hidden_size'],
|
| 10 |
+
'ffn_hidden_size': ['intermediate_size'],
|
| 11 |
+
'num_attention_heads': ['num_attention_heads'],
|
| 12 |
+
'num_query_groups': ['num_key_value_heads'],
|
| 13 |
+
'max_position_embeddings': ['max_position_embeddings'],
|
| 14 |
+
'norm_epsilon': ['rms_norm_eps'],
|
| 15 |
+
'rotary_base': ['rope_theta'],
|
| 16 |
+
'padded_vocab_size': ['vocab_size'],
|
| 17 |
+
'attention_dropout': ['attention_dropout'],
|
| 18 |
+
'untie_embeddings_and_output_weights': ['tie_word_embeddings'],
|
| 19 |
+
'swiglu': ['hidden_act'],
|
| 20 |
+
'add_qkv_bias': ['attention_bias'],
|
| 21 |
+
'disable_bias_linear': ['mlp_bias'],
|
| 22 |
+
'kv_channels': ['head_dim'],
|
| 23 |
+
'model_type': ['model_type'],
|
| 24 |
+
# moe
|
| 25 |
+
'moe_ffn_hidden_size': ['moe_intermediate_size'],
|
| 26 |
+
'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'],
|
| 27 |
+
'moe_router_topk': ['num_experts_per_tok'],
|
| 28 |
+
'num_experts': ['num_experts'],
|
| 29 |
+
'moe_router_pre_softmax': ['norm_topk_prob'],
|
| 30 |
+
'moe_aux_loss_coeff': ['router_aux_loss_coef'],
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def convert_hf_config(config) -> Dict[str, Any]:
|
| 35 |
+
megatron_config = {}
|
| 36 |
+
for k, hf_keys in config_mapping.items():
|
| 37 |
+
for hf_k in hf_keys:
|
| 38 |
+
if hasattr(config, hf_k):
|
| 39 |
+
hf_v = getattr(config, hf_k)
|
| 40 |
+
if k == 'rotary_base':
|
| 41 |
+
megatron_config[k] = int(hf_v)
|
| 42 |
+
elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}:
|
| 43 |
+
megatron_config[k] = not hf_v
|
| 44 |
+
elif k == 'swiglu':
|
| 45 |
+
if hf_v == 'silu':
|
| 46 |
+
megatron_config[k] = True
|
| 47 |
+
else:
|
| 48 |
+
megatron_config[k] = hf_v
|
| 49 |
+
break
|
| 50 |
+
# compat llama3
|
| 51 |
+
if getattr(config, 'rope_scaling', None) is not None:
|
| 52 |
+
if isinstance(config.rope_scaling, int):
|
| 53 |
+
megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'},
|
| 54 |
+
elif isinstance(config.rope_scaling, dict):
|
| 55 |
+
megatron_config['rope_scaling'] = config.rope_scaling
|
| 56 |
+
logger.info(f'megatron_config: {megatron_config}')
|
| 57 |
+
return megatron_config
|
ms-swift/swift/megatron/model/constant.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
class MegatronModelType:
|
| 3 |
+
gpt = 'gpt'
|
ms-swift/swift/megatron/model/gpt/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from swift.llm import ModelType
|
| 3 |
+
from ..constant import MegatronModelType
|
| 4 |
+
from ..register import MegatronModelMeta, register_megatron_model
|
| 5 |
+
from .config import convert_gpt_hf_config
|
| 6 |
+
from .hf2mcore import convert_hf2mcore
|
| 7 |
+
from .mcore2hf import convert_mcore2hf
|
| 8 |
+
from .model import model_provider
|
| 9 |
+
|
| 10 |
+
register_megatron_model(
|
| 11 |
+
MegatronModelMeta(MegatronModelType.gpt, [
|
| 12 |
+
ModelType.qwen2,
|
| 13 |
+
ModelType.qwen2_5,
|
| 14 |
+
ModelType.qwq,
|
| 15 |
+
ModelType.qwq_preview,
|
| 16 |
+
ModelType.qwen2_5_math,
|
| 17 |
+
ModelType.llama,
|
| 18 |
+
ModelType.llama3,
|
| 19 |
+
ModelType.llama3_1,
|
| 20 |
+
ModelType.llama3_2,
|
| 21 |
+
ModelType.longwriter_llama3_1,
|
| 22 |
+
ModelType.codefuse_codellama,
|
| 23 |
+
ModelType.marco_o1,
|
| 24 |
+
ModelType.deepseek,
|
| 25 |
+
ModelType.deepseek_r1_distill,
|
| 26 |
+
ModelType.yi,
|
| 27 |
+
ModelType.yi_coder,
|
| 28 |
+
ModelType.sus,
|
| 29 |
+
ModelType.skywork_o1,
|
| 30 |
+
ModelType.openbuddy_llama,
|
| 31 |
+
ModelType.openbuddy_llama3,
|
| 32 |
+
ModelType.megrez,
|
| 33 |
+
ModelType.reflection,
|
| 34 |
+
ModelType.numina,
|
| 35 |
+
ModelType.ziya,
|
| 36 |
+
ModelType.mengzi3,
|
| 37 |
+
ModelType.qwen3,
|
| 38 |
+
ModelType.qwen2_moe,
|
| 39 |
+
ModelType.qwen3_moe,
|
| 40 |
+
], model_provider, convert_gpt_hf_config, convert_mcore2hf, convert_hf2mcore))
|
ms-swift/swift/megatron/model/gpt/config.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
|
| 3 |
+
from ..config import convert_hf_config
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def convert_gpt_hf_config(config) -> Dict[str, Any]:
|
| 7 |
+
res = convert_hf_config(config)
|
| 8 |
+
model_type = res.get('model_type')
|
| 9 |
+
if model_type in {'qwen3', 'qwen3_moe'}:
|
| 10 |
+
res['qk_layernorm'] = True
|
| 11 |
+
if model_type in {'qwen2_moe', 'qwen3_moe'}:
|
| 12 |
+
res.pop('ffn_hidden_size', None)
|
| 13 |
+
return res
|
ms-swift/swift/megatron/model/gpt/model.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from megatron.core.models.gpt import GPTModel
|
| 3 |
+
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
|
| 4 |
+
from megatron.training import get_args
|
| 5 |
+
from megatron.training.arguments import core_transformer_config_from_args
|
| 6 |
+
|
| 7 |
+
from ..rope import update_rope_inv_freq
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def model_provider(pre_process=True, post_process=True):
|
| 11 |
+
args = get_args()
|
| 12 |
+
config = core_transformer_config_from_args(args)
|
| 13 |
+
config.variable_seq_lengths = True
|
| 14 |
+
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm,
|
| 15 |
+
args.qk_layernorm, args.multi_latent_attention)
|
| 16 |
+
if args.num_experts and args.moe_shared_expert_intermediate_size:
|
| 17 |
+
# qwen2_moe/qwen3_moe
|
| 18 |
+
transformer_layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True}
|
| 19 |
+
model = GPTModel(
|
| 20 |
+
config=config,
|
| 21 |
+
transformer_layer_spec=transformer_layer_spec,
|
| 22 |
+
vocab_size=args.padded_vocab_size,
|
| 23 |
+
max_sequence_length=args.max_position_embeddings,
|
| 24 |
+
pre_process=pre_process,
|
| 25 |
+
post_process=post_process,
|
| 26 |
+
fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
|
| 27 |
+
parallel_output=True,
|
| 28 |
+
share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
|
| 29 |
+
position_embedding_type=args.position_embedding_type,
|
| 30 |
+
rotary_percent=args.rotary_percent,
|
| 31 |
+
rotary_base=args.rotary_base,
|
| 32 |
+
rope_scaling=args.use_rope_scaling,
|
| 33 |
+
rope_scaling_factor=args.rope_scaling_factor,
|
| 34 |
+
seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor)
|
| 35 |
+
if args.rope_scaling:
|
| 36 |
+
update_rope_inv_freq(model.rotary_pos_emb.inv_freq, args.rope_scaling)
|
| 37 |
+
return model
|
ms-swift/swift/megatron/model/register.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from transformers import PretrainedConfig
|
| 7 |
+
|
| 8 |
+
from swift.llm import MODEL_MAPPING, ModelGroup
|
| 9 |
+
|
| 10 |
+
MEGATRON_MODEL_MAPPING = {}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class MegatronModelMeta:
|
| 15 |
+
megatron_model_type: str
|
| 16 |
+
model_types: List[str]
|
| 17 |
+
|
| 18 |
+
model_provider: Callable[[], nn.Module]
|
| 19 |
+
convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]]
|
| 20 |
+
convert_mcore2hf: Callable[[nn.Module, nn.Module], None]
|
| 21 |
+
convert_hf2mcore: Callable[[nn.Module, nn.Module], None]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False):
|
| 25 |
+
megatron_model_type = megatron_model_meta.megatron_model_type
|
| 26 |
+
for model_type in megatron_model_meta.model_types:
|
| 27 |
+
model_meta = MODEL_MAPPING[model_type]
|
| 28 |
+
model_meta.support_megatron = True
|
| 29 |
+
if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING:
|
| 30 |
+
raise ValueError(f'The `{megatron_model_type}` has already been registered in the MODEL_MAPPING.')
|
| 31 |
+
|
| 32 |
+
MEGATRON_MODEL_MAPPING[megatron_model_type] = megatron_model_meta
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
_MODEL_META_MAPPING = None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_megatron_model_meta(model_type: str) -> Optional[MegatronModelMeta]:
|
| 39 |
+
global _MODEL_META_MAPPING
|
| 40 |
+
if _MODEL_META_MAPPING is None:
|
| 41 |
+
_MODEL_META_MAPPING = {}
|
| 42 |
+
for k, megatron_model_meta in MEGATRON_MODEL_MAPPING.items():
|
| 43 |
+
for _model_type in megatron_model_meta.model_types:
|
| 44 |
+
_MODEL_META_MAPPING[_model_type] = k
|
| 45 |
+
if model_type not in _MODEL_META_MAPPING:
|
| 46 |
+
return
|
| 47 |
+
return MEGATRON_MODEL_MAPPING[_MODEL_META_MAPPING[model_type]]
|
ms-swift/swift/megatron/model/rope.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Any, Dict
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _to_llama3_rope(inv_freq: torch.Tensor, rope_scaling: Dict[str, Any]):
|
| 8 |
+
# copy from transformers
|
| 9 |
+
factor = rope_scaling['factor'] # `8` in the original implementation
|
| 10 |
+
low_freq_factor = rope_scaling['low_freq_factor'] # `1` in the original implementation
|
| 11 |
+
high_freq_factor = rope_scaling['high_freq_factor'] # `4` in the original implementation
|
| 12 |
+
old_context_len = rope_scaling['original_max_position_embeddings'] # `8192` in the original implementation
|
| 13 |
+
|
| 14 |
+
low_freq_wavelen = old_context_len / low_freq_factor
|
| 15 |
+
high_freq_wavelen = old_context_len / high_freq_factor
|
| 16 |
+
|
| 17 |
+
wavelen = 2 * math.pi / inv_freq
|
| 18 |
+
# wavelen < high_freq_wavelen: do nothing
|
| 19 |
+
# wavelen > low_freq_wavelen: divide by factor
|
| 20 |
+
inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
|
| 21 |
+
# otherwise: interpolate between the two, using a smooth factor
|
| 22 |
+
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
| 23 |
+
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
|
| 24 |
+
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
|
| 25 |
+
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
|
| 26 |
+
return inv_freq_llama
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _to_linear_rope(inv_freq: torch.Tensor, rope_scaling: Dict[str, Any]):
|
| 30 |
+
factor = rope_scaling['factor']
|
| 31 |
+
inv_freq /= factor
|
| 32 |
+
return inv_freq
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
ROPE_MAPPING = {'llama3': _to_llama3_rope, 'linear': _to_linear_rope}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def update_rope_inv_freq(inv_freq: torch.Tensor, rope_scaling: Dict[str, Any]) -> None:
|
| 39 |
+
new_inv_freq = ROPE_MAPPING[rope_scaling['rope_type']](inv_freq, rope_scaling)
|
| 40 |
+
inv_freq.data.copy_(new_inv_freq)
|
ms-swift/swift/megatron/train/patcher.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
import os
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
from functools import wraps
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from megatron.training import get_args, global_vars, initialize, training
|
| 8 |
+
|
| 9 |
+
from swift.utils import JsonlWriter, is_master
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@contextmanager
|
| 13 |
+
def patch_training_log():
|
| 14 |
+
jsonl_writer = None
|
| 15 |
+
origin_training_log = training.training_log
|
| 16 |
+
|
| 17 |
+
@wraps(origin_training_log)
|
| 18 |
+
def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale,
|
| 19 |
+
report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad, *_args, **kwargs):
|
| 20 |
+
nonlocal jsonl_writer
|
| 21 |
+
args = get_args()
|
| 22 |
+
if is_master() and iteration % args.log_interval == 0:
|
| 23 |
+
logging_path = os.path.join(args.save, 'logging.jsonl')
|
| 24 |
+
logs = {}
|
| 25 |
+
for k, v in loss_dict.items():
|
| 26 |
+
if isinstance(v, torch.Tensor):
|
| 27 |
+
v = v.item()
|
| 28 |
+
logs[k] = round(v, 8)
|
| 29 |
+
for k in {'grad_norm', 'params_norm', 'learning_rate'}:
|
| 30 |
+
v = locals()[k]
|
| 31 |
+
if v is not None:
|
| 32 |
+
logs[k] = round(v, 8)
|
| 33 |
+
logs['consumed_samples'] = args.consumed_train_samples
|
| 34 |
+
logs['global_step/max_steps'] = f'{iteration}/{args.train_iters}'
|
| 35 |
+
if jsonl_writer is None:
|
| 36 |
+
jsonl_writer = JsonlWriter(logging_path, enable_async=True)
|
| 37 |
+
jsonl_writer.append(logs)
|
| 38 |
+
return origin_training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration,
|
| 39 |
+
loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm,
|
| 40 |
+
num_zeros_in_grad, *_args, **kwargs)
|
| 41 |
+
|
| 42 |
+
training.training_log = training_log
|
| 43 |
+
try:
|
| 44 |
+
yield
|
| 45 |
+
finally:
|
| 46 |
+
training.training_log = origin_training_log
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@contextmanager
|
| 50 |
+
def patch_megatron_data_collator(data_collator):
|
| 51 |
+
origin_build_pretraining_data_loader = training.build_pretraining_data_loader
|
| 52 |
+
|
| 53 |
+
def build_pretraining_data_loader(*_args, **kwargs):
|
| 54 |
+
args = get_args()
|
| 55 |
+
res = origin_build_pretraining_data_loader(*_args, **kwargs)
|
| 56 |
+
if res is not None and args.dataloader_type != 'external':
|
| 57 |
+
res.collate_fn = data_collator
|
| 58 |
+
return res
|
| 59 |
+
|
| 60 |
+
training.build_pretraining_data_loader = build_pretraining_data_loader
|
| 61 |
+
try:
|
| 62 |
+
yield
|
| 63 |
+
finally:
|
| 64 |
+
training.build_pretraining_data_loader = origin_build_pretraining_data_loader
|
ms-swift/swift/megatron/utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
from .convert import convert_hf2mcore, convert_mcore2hf
|
| 4 |
+
from .patcher import patch_megatron_tokenizer
|
ms-swift/swift/megatron/utils/convert.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from megatron.training.checkpointing import load_checkpoint
|
| 7 |
+
from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint
|
| 8 |
+
from megatron.training.initialize import initialize_megatron
|
| 9 |
+
from megatron.training.utils import get_ltor_masks_and_position_ids
|
| 10 |
+
|
| 11 |
+
from swift.llm import ExportArguments, get_model_tokenizer, get_template, save_checkpoint
|
| 12 |
+
from swift.utils import get_logger, get_n_params_grads
|
| 13 |
+
from ..argument import MegatronArguments
|
| 14 |
+
from ..model import get_megatron_model_meta
|
| 15 |
+
from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard
|
| 16 |
+
|
| 17 |
+
logger = get_logger()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_convert_precision(hf_model, mg_model, processor):
|
| 21 |
+
torch_dtype = hf_model.dtype
|
| 22 |
+
template = get_template(hf_model.model_meta.template, processor)
|
| 23 |
+
input_ids = template.encode({'messages': [{'role': 'user', 'content': 'who are you?'}]})['input_ids']
|
| 24 |
+
input_ids = torch.tensor(input_ids)[None].to('cuda')
|
| 25 |
+
hf_model.to('cuda')
|
| 26 |
+
hf_model.to(torch.float32)
|
| 27 |
+
with torch.inference_mode():
|
| 28 |
+
hf_logits = hf_model(input_ids).logits
|
| 29 |
+
hf_model.to(torch_dtype)
|
| 30 |
+
hf_model.to('cpu')
|
| 31 |
+
|
| 32 |
+
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True)
|
| 33 |
+
mg_model.to('cuda')
|
| 34 |
+
mg_model.to(torch.float32)
|
| 35 |
+
with torch.inference_mode():
|
| 36 |
+
mg_logits = mg_model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
|
| 37 |
+
mg_model.to(torch_dtype)
|
| 38 |
+
mg_model.to('cpu')
|
| 39 |
+
|
| 40 |
+
mean_diff = (mg_logits - hf_logits).abs().mean().item()
|
| 41 |
+
max_diff = (mg_logits - hf_logits).abs().max().item()
|
| 42 |
+
print(f'mean_diff: {mean_diff}, max_diff: {max_diff}')
|
| 43 |
+
hf_tokens = hf_logits.argmax(-1)
|
| 44 |
+
mg_tokens = mg_logits.argmax(-1)
|
| 45 |
+
print(f'hf_tokens: {hf_tokens[0].tolist()}\nmg_tokens: {mg_tokens[0].tolist()}')
|
| 46 |
+
assert mean_diff < 0.1
|
| 47 |
+
assert (hf_tokens == mg_tokens).all()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
convert_kwargs = {
|
| 51 |
+
'use_cpu_initialization': True,
|
| 52 |
+
'no_save_optim': True,
|
| 53 |
+
'no_save_rng': True,
|
| 54 |
+
'no_load_optim': True,
|
| 55 |
+
'no_load_rng': True,
|
| 56 |
+
'no_masked_softmax_fusion': True,
|
| 57 |
+
'no_bias_dropout_fusion': True,
|
| 58 |
+
'no_bias_swiglu_fusion': True,
|
| 59 |
+
'no_rope_fusion': True
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def convert_hf2mcore(args: ExportArguments) -> None:
|
| 64 |
+
kwargs = args.get_model_kwargs()
|
| 65 |
+
hf_model, processor = get_model_tokenizer(**kwargs)
|
| 66 |
+
if args.thread_count is None:
|
| 67 |
+
checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
|
| 68 |
+
args.thread_count = max(math.ceil(checkpoint_size / 10), 2) # 10GB
|
| 69 |
+
patch_torch_dist_shard(args.thread_count)
|
| 70 |
+
|
| 71 |
+
megatron_model_meta = get_megatron_model_meta(args.model_type)
|
| 72 |
+
assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
|
| 73 |
+
kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
|
| 74 |
+
megatron_args = MegatronArguments(**kwargs, **convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype)
|
| 75 |
+
patch_megatron_tokenizer(processor)
|
| 76 |
+
extra_args = megatron_args.parse_to_megatron()
|
| 77 |
+
initialize_megatron(args_defaults=extra_args)
|
| 78 |
+
|
| 79 |
+
mg_model = megatron_model_meta.model_provider()
|
| 80 |
+
logger.info('Megatron model created successfully.')
|
| 81 |
+
megatron_model_meta.convert_hf2mcore(hf_model, mg_model)
|
| 82 |
+
if args.test_convert_precision:
|
| 83 |
+
test_convert_precision(hf_model, mg_model, processor)
|
| 84 |
+
logger.info('Successfully transferred HF model weights to MG model.')
|
| 85 |
+
mg_save_checkpoint(1, [mg_model], None, None, 0)
|
| 86 |
+
args.save_args()
|
| 87 |
+
logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.')
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def convert_mcore2hf(args: ExportArguments) -> None:
|
| 91 |
+
kwargs = args.get_model_kwargs()
|
| 92 |
+
hf_model, processor = get_model_tokenizer(**kwargs)
|
| 93 |
+
if args.thread_count is None:
|
| 94 |
+
checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
|
| 95 |
+
args.thread_count = max(math.ceil(checkpoint_size / 10), 2) # 10GB
|
| 96 |
+
patch_torch_dist_shard(args.thread_count)
|
| 97 |
+
|
| 98 |
+
megatron_model_meta = get_megatron_model_meta(args.model_type)
|
| 99 |
+
assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
|
| 100 |
+
kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
|
| 101 |
+
megatron_args = MegatronArguments(**kwargs, **convert_kwargs, load=args.mcore_model, torch_dtype=args.torch_dtype)
|
| 102 |
+
patch_megatron_tokenizer(processor)
|
| 103 |
+
extra_args = megatron_args.parse_to_megatron()
|
| 104 |
+
initialize_megatron(args_defaults=extra_args)
|
| 105 |
+
|
| 106 |
+
mg_model = megatron_model_meta.model_provider()
|
| 107 |
+
load_checkpoint([mg_model], None, None, strict=True)
|
| 108 |
+
logger.info('Megatron model created successfully.')
|
| 109 |
+
megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
|
| 110 |
+
if args.test_convert_precision:
|
| 111 |
+
test_convert_precision(hf_model, mg_model, processor)
|
| 112 |
+
logger.info('Successfully transferred MG model weights to HF model.')
|
| 113 |
+
save_checkpoint(
|
| 114 |
+
hf_model,
|
| 115 |
+
processor,
|
| 116 |
+
args.output_dir,
|
| 117 |
+
safe_serialization=args.safe_serialization,
|
| 118 |
+
model_dirs=[args.mcore_model, args.model_dir],
|
| 119 |
+
max_shard_size=args.max_shard_size,
|
| 120 |
+
additional_saved_files=hf_model.model_meta.additional_saved_files)
|
| 121 |
+
args.save_args()
|
| 122 |
+
logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
|
ms-swift/swift/megatron/utils/patcher.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 2 |
+
from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
|
| 3 |
+
from megatron.training import get_args, global_vars, initialize, training
|
| 4 |
+
|
| 5 |
+
from swift.utils import get_logger
|
| 6 |
+
|
| 7 |
+
logger = get_logger()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def patch_megatron_tokenizer(tokenizer):
|
| 11 |
+
|
| 12 |
+
def build_tokenizer(args):
|
| 13 |
+
args.extra_vocab_size = args.padded_vocab_size - tokenizer.vocab_size
|
| 14 |
+
return tokenizer
|
| 15 |
+
|
| 16 |
+
global_vars.build_tokenizer = build_tokenizer
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def patch_torch_dist_shard(thread_count):
|
| 20 |
+
__init__ = TorchDistSaveShardedStrategy.__init__
|
| 21 |
+
|
| 22 |
+
def __new_init__(*args, **kwargs):
|
| 23 |
+
kwargs['thread_count'] = thread_count
|
| 24 |
+
return __init__(*args, **kwargs)
|
| 25 |
+
|
| 26 |
+
TorchDistSaveShardedStrategy.__init__ = __new_init__
|
ms-swift/swift/plugin/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.32 kB). View file
|
|
|
ms-swift/swift/plugin/__pycache__/callback.cpython-310.pyc
ADDED
|
Binary file (1.31 kB). View file
|
|
|
ms-swift/swift/plugin/__pycache__/metric.cpython-310.pyc
ADDED
|
Binary file (6.73 kB). View file
|
|
|