Student0809 commited on
Commit
a050167
·
verified ·
1 Parent(s): 35dfdd4

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ms-swift/processed_data/processed_overlap5s_speaker_segments.json +0 -0
  2. ms-swift/processed_data/processed_silence_isoverlaps.json +0 -0
  3. ms-swift/silence_overlaps/700/test/overlap5s_segments_test.json +27 -0
  4. ms-swift/silence_overlaps/700/test/overlap5s_silence_segments_test.json +27 -0
  5. ms-swift/silence_overlaps/700/train/overlap5s_issilence_segments_train.json +0 -0
  6. ms-swift/silence_overlaps/test/test_train.json +963 -0
  7. ms-swift/swift/llm/sampling/mcts.py +400 -0
  8. ms-swift/swift/llm/template/template/__pycache__/__init__.cpython-310.pyc +0 -0
  9. ms-swift/swift/llm/template/template/__pycache__/emu3.cpython-310.pyc +0 -0
  10. ms-swift/swift/llm/template/template/__pycache__/gemma.cpython-310.pyc +0 -0
  11. ms-swift/swift/llm/template/template/__pycache__/internvl.cpython-310.pyc +0 -0
  12. ms-swift/swift/llm/template/template/__pycache__/minicpm.cpython-310.pyc +0 -0
  13. ms-swift/swift/llm/template/template/__pycache__/pixtral.cpython-310.pyc +0 -0
  14. ms-swift/swift/llm/template/template/__pycache__/stepfun.cpython-310.pyc +0 -0
  15. ms-swift/swift/llm/template/template/__pycache__/valley.cpython-310.pyc +0 -0
  16. ms-swift/swift/llm/template/template/__pycache__/yi.cpython-310.pyc +0 -0
  17. ms-swift/swift/llm/template/template/deepseek.py +315 -0
  18. ms-swift/swift/llm/template/template/glm.py +293 -0
  19. ms-swift/swift/llm/template/template/internvl.py +168 -0
  20. ms-swift/swift/llm/template/template/llama.py +213 -0
  21. ms-swift/swift/llm/template/template/megrez.py +93 -0
  22. ms-swift/swift/llm/template/template/openbuddy.py +48 -0
  23. ms-swift/swift/llm/template/template/pixtral.py +59 -0
  24. ms-swift/swift/llm/template/template/qwen.py +671 -0
  25. ms-swift/swift/llm/template/template/stepfun.py +128 -0
  26. ms-swift/swift/llm/template/template/yi.py +63 -0
  27. ms-swift/swift/llm/train/__pycache__/callback.cpython-310.pyc +0 -0
  28. ms-swift/swift/llm/train/__pycache__/rlhf.cpython-310.pyc +0 -0
  29. ms-swift/swift/llm/train/__pycache__/sft.cpython-310.pyc +0 -0
  30. ms-swift/swift/llm/train/__pycache__/tuner.cpython-310.pyc +0 -0
  31. ms-swift/swift/llm/train/callback.py +80 -0
  32. ms-swift/swift/llm/train/rlhf.py +154 -0
  33. ms-swift/swift/llm/train/sft.py +287 -0
  34. ms-swift/swift/llm/train/tuner.py +424 -0
  35. ms-swift/swift/megatron/argument/train_args.py +53 -0
  36. ms-swift/swift/megatron/model/__init__.py +4 -0
  37. ms-swift/swift/megatron/model/config.py +57 -0
  38. ms-swift/swift/megatron/model/constant.py +3 -0
  39. ms-swift/swift/megatron/model/gpt/__init__.py +40 -0
  40. ms-swift/swift/megatron/model/gpt/config.py +13 -0
  41. ms-swift/swift/megatron/model/gpt/model.py +37 -0
  42. ms-swift/swift/megatron/model/register.py +47 -0
  43. ms-swift/swift/megatron/model/rope.py +40 -0
  44. ms-swift/swift/megatron/train/patcher.py +64 -0
  45. ms-swift/swift/megatron/utils/__init__.py +4 -0
  46. ms-swift/swift/megatron/utils/convert.py +122 -0
  47. ms-swift/swift/megatron/utils/patcher.py +26 -0
  48. ms-swift/swift/plugin/__pycache__/__init__.cpython-310.pyc +0 -0
  49. ms-swift/swift/plugin/__pycache__/callback.cpython-310.pyc +0 -0
  50. ms-swift/swift/plugin/__pycache__/metric.cpython-310.pyc +0 -0
ms-swift/processed_data/processed_overlap5s_speaker_segments.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/processed_data/processed_silence_isoverlaps.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/700/test/overlap5s_segments_test.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "key": "SODA_PROCESSED--train--123906",
4
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--123906.wav",
5
+ "model_output": "Multiple speakers talk simultaneously from 00:03-00:09"
6
+ },
7
+ {
8
+ "key": "SODA_PROCESSED--train--1112763",
9
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--1112763.wav",
10
+ "model_output": "Multiple speakers talk simultaneously from 00:09-00:15"
11
+ },
12
+ {
13
+ "key": "SODA_PROCESSED--train--790538",
14
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--790538.wav",
15
+ "model_output": "Multiple speakers talk simultaneously from 00:15-00:19"
16
+ },
17
+ {
18
+ "key": "SODA_PROCESSED--train--822773",
19
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--822773.wav",
20
+ "model_output": "Multiple speakers talk simultaneously from 00:14-00:19"
21
+ },
22
+ {
23
+ "key": "SODA_PROCESSED--train--424960",
24
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--424960.wav",
25
+ "model_output": "Multiple speakers talk simultaneously from 00:29-00:33"
26
+ }
27
+ ]
ms-swift/silence_overlaps/700/test/overlap5s_silence_segments_test.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "key": "SODA_PROCESSED--train--137471",
4
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--137471.wav",
5
+ "model_output": "No, there is no silence gap."
6
+ },
7
+ {
8
+ "key": "SODA_PROCESSED--train--201044",
9
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--201044.wav",
10
+ "model_output": "No, there is no silence gap."
11
+ },
12
+ {
13
+ "key": "SODA_PROCESSED--train--596349",
14
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--596349.wav",
15
+ "model_output": "No, there is no silence gap."
16
+ },
17
+ {
18
+ "key": "SODA_PROCESSED--train--956648",
19
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--956648.wav",
20
+ "model_output": "No, there is no silence gap."
21
+ },
22
+ {
23
+ "key": "SODA_PROCESSED--train--962210",
24
+ "audio_url": "/root/autodl-tmp/output_overlapslong/newoverlapjson/overlap5s700/SODA_PROCESSED--train--962210.wav",
25
+ "model_output": "No, there is no silence gap."
26
+ }
27
+ ]
ms-swift/silence_overlaps/700/train/overlap5s_issilence_segments_train.json ADDED
The diff for this file is too large to render. See raw diff
 
ms-swift/silence_overlaps/test/test_train.json ADDED
@@ -0,0 +1,963 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "SODA_PROCESSED--train--449689": {
3
+ "original_dialog_id": "",
4
+ "dialog_index": 449689,
5
+ "processed_dialogue": "A: Hey there. Mind if I lay down next to you? \nB: No, go ahead. \nA: Thanks. I needed a break from the sun. It's so hot today. \nB: Yeah, it is. I'm trying to get a tan, but I don't want to get too dehydrated, so I'm keeping a bottle of water close by and reapplying sunscreen every hour to avoid any skin damage. \nA: Burnt? Yeah, that's definitely a possibility out here. So what brings you to the beach today? Just wanting to relax? \nB: Yeah, pretty much. I just finished up my summer classes and needed some time to myself before starting my new job next week. \nA: That sounds rough. Are you excited for it? Or [interrupt] worried about how you'll balance everything with your personal life and other commitments you might have during this transitional period? \nB: Nervous? A little bit of both, honestly. But mostly excited. It should be a good experience. And the pay is great, so that's a plus. \nA: Definitely. Well, I hope you enjoy the rest of your day here. \nB: Thanks. You too.",
6
+ "clean_dialogue": "A: Hey there. Mind if I lay down next to you? \nB: No, go ahead. \nA: Thanks. I needed a break from the sun. It's so hot today. \nB: Yeah, it is. I'm trying to get a tan, but I don't want to get too dehydrated, so I'm keeping a bottle of water close by and reapplying sunscreen every hour to avoid any skin damage. \nA: Burnt? Yeah, that's definitely a possibility out here. So what brings you to the beach today? Just wanting to relax? \nB: Yeah, pretty much. I just finished up my summer classes and needed some time to myself before starting my new job next week. \nA:That sounds rough. Are you excited for it? Or worried about how you'll balance everything with your personal life and other commitments you might have during this transitional period?\nB: Nervous? A little bit of both, honestly. But mostly excited. It should be a good experience. And the pay is great, so that's a plus. \nA: Definitely. Well, I hope you enjoy the rest of your day here. \nB: Thanks. You too.",
7
+ "speaker_tracks": {
8
+ "A": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/A_track.wav",
9
+ "B": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/B_track.wav"
10
+ },
11
+ "error_type": "error_after_interrupt",
12
+ "stereo_audio": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/stereo_dialogue.wav",
13
+ "total_duration": 50.09668934240363,
14
+ "segments": [
15
+ {
16
+ "speaker": "A",
17
+ "text": "Hey there. Mind if I lay down next to you?",
18
+ "original_text": "Hey there. Mind if I lay down next to you?",
19
+ "start_time": 0,
20
+ "end_time": 2.4961451247165534,
21
+ "audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_0_A.wav",
22
+ "silence_duration": 0,
23
+ "is_interrupted": false
24
+ },
25
+ {
26
+ "speaker": "B",
27
+ "text": "No, go ahead.",
28
+ "original_text": "No, go ahead.",
29
+ "start_time": 3.0616233505922237,
30
+ "end_time": 4.257451014991316,
31
+ "audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_1_B.wav",
32
+ "silence_duration": 0.5654782258756702,
33
+ "is_interrupted": false
34
+ },
35
+ {
36
+ "speaker": "A",
37
+ "text": "Thanks. I needed a break from the sun. It's so hot today.",
38
+ "original_text": "Thanks. I needed a break from the sun. It's so hot today.",
39
+ "start_time": 4.673061027457998,
40
+ "end_time": 8.666893227004483,
41
+ "audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_2_A.wav",
42
+ "silence_duration": 0.41561001246668183,
43
+ "is_interrupted": false
44
+ },
45
+ {
46
+ "speaker": "B",
47
+ "text": "Yeah, it is. I'm trying to get a tan, but I don't want to get too dehydrated, so I'm keeping a bottle of water close by and reapplying sunscreen every hour to avoid any skin damage.",
48
+ "original_text": "Yeah, it is. I'm trying to get a tan, but I don't want to get too dehydrated, so I'm keeping a bottle of water close by and reapplying sunscreen every hour to avoid any skin damage.",
49
+ "start_time": 9.128191918953855,
50
+ "end_time": 19.01989259922596,
51
+ "audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_3_B.wav",
52
+ "silence_duration": 0.46129869194937123,
53
+ "is_interrupted": false
54
+ },
55
+ {
56
+ "speaker": "A",
57
+ "text": "Burnt? Yeah, that's definitely a possibility out here. So what brings you to the beach today? Just wanting to relax?",
58
+ "original_text": "Burnt? Yeah, that's definitely a possibility out here. So what brings you to the beach today? Just wanting to relax?",
59
+ "start_time": 19.43691572474219,
60
+ "end_time": 27.215600531998426,
61
+ "audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_4_A.wav",
62
+ "silence_duration": 0.4170231255162265,
63
+ "is_interrupted": false
64
+ },
65
+ {
66
+ "speaker": "B",
67
+ "text": "Yeah, pretty much. I just finished up my summer classes and needed some time to myself before starting my new job next week.",
68
+ "original_text": "Yeah, pretty much. I just finished up my summer classes and needed some time to myself before starting my new job next week.",
69
+ "start_time": 27.73206790619358,
70
+ "end_time": 34.08272550256547,
71
+ "audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_5_B.wav",
72
+ "silence_duration": 0.5164673741951538,
73
+ "is_interrupted": false
74
+ },
75
+ {
76
+ "speaker": "A",
77
+ "text": "That sounds rough. Are you excited for it? Or",
78
+ "original_text": "That sounds rough. Are you excited for it? Or [interrupt] worried about how you'll balance everything with your personal life and other commitments you might have during this transitional period?",
79
+ "start_time": 34.40566150397062,
80
+ "end_time": 44.703711390591934,
81
+ "audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_6_A.wav",
82
+ "silence_duration": 0.3229360014051523,
83
+ "is_interrupted": true,
84
+ "text_after_interrupt": "worried about how you'll balance everything with your personal life and other commitments you might have during this transitional period?"
85
+ },
86
+ {
87
+ "speaker": "B",
88
+ "text": "Nervous? A little bit of both, honestly. But mostly excited. It should be a good experience. And the pay is great, so that's a plus.",
89
+ "original_text": "Nervous? A little bit of both, honestly. But mostly excited. It should be a good experience. And the pay is great, so that's a plus.",
90
+ "start_time": 37.1456161524967,
91
+ "end_time": 44.564391662700785,
92
+ "audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_7_B.wav",
93
+ "silence_duration": 0.36321869535217244,
94
+ "is_interrupted": false
95
+ },
96
+ {
97
+ "speaker": "A",
98
+ "text": "Definitely. Well, I hope you enjoy the rest of your day here.",
99
+ "original_text": "Definitely. Well, I hope you enjoy the rest of your day here.",
100
+ "start_time": 44.9023552612567,
101
+ "end_time": 48.78008768756056,
102
+ "audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_8_A.wav",
103
+ "silence_duration": 0.33796359855591646,
104
+ "is_interrupted": false
105
+ },
106
+ {
107
+ "speaker": "B",
108
+ "text": "Thanks. You too.",
109
+ "original_text": "Thanks. You too.",
110
+ "start_time": 49.1679089027611,
111
+ "end_time": 50.09670708870214,
112
+ "audio_file": "/root/autodl-tmp/output_mixedAudios/processed_soda_3_processed_dialogues_part_20/SODA_PROCESSED--train--449689/temp/line_9_B.wav",
113
+ "silence_duration": 0.38782121520053575,
114
+ "is_interrupted": false
115
+ }
116
+ ],
117
+ "gt_score": 1
118
+ },
119
+ "SODA_PROCESSED--train--787791": {
120
+ "original_dialog_id": "",
121
+ "dialog_index": 787791,
122
+ "processed_dialogue": "A: You're welcome. I'm just glad I was able to stop it from happening. \nB: Thank you so much for saving my life. I can't even begin to express how [interrupt] grateful I am for what you did. It means the world to me and I'll never forget your kindness and quick thinking in that moment. \nA: Sorry to jump in, but are you sure you're okay? I mean, physically and emotionally? \nB: I think so, but it's all still a bit of a blur. I don't know what would have happened if you hadn't been there. I'm just glad that you were in the right place at the right time. \nA: Yeah, me too. But seriously, if you need anything—someone to talk to or whatever—don't hesitate to reach out, okay? \nB: I really appreciate that. Thanks again, Antwain. \nA: No problem. Take care.",
123
+ "clean_dialogue": "A: You're welcome. I'm just glad I was able to stop it from happening. \nB:Thank you so much for saving my life. I can't even begin to express how grateful I am for what you did. It means the world to me and I'll never forget your kindness and quick thinking in that moment.\nA: Sorry to jump in, but are you sure you're okay? I mean, physically and emotionally? \nB: I think so, but it's all still a bit of a blur. I don't know what would have happened if you hadn't been there. I'm just glad that you were in the right place at the right time. \nA: Yeah, me too. But seriously, if you need anything—someone to talk to or whatever—don't hesitate to reach out, okay? \nB: I really appreciate that. Thanks again, Antwain. \nA: No problem. Take care.",
124
+ "speaker_tracks": {
125
+ "A": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/A_track.wav",
126
+ "B": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/B_track.wav"
127
+ },
128
+ "error_type": "error_after_interrupt",
129
+ "stereo_audio": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/stereo_dialogue.wav",
130
+ "total_duration": 37.52730158730159,
131
+ "segments": [
132
+ {
133
+ "speaker": "A",
134
+ "text": "You're welcome. I'm just glad I was able to stop it from happening.",
135
+ "original_text": "You're welcome. I'm just glad I was able to stop it from happening.",
136
+ "start_time": 0,
137
+ "end_time": 4.249251700680272,
138
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_0_A.wav",
139
+ "silence_duration": 0,
140
+ "is_interrupted": false
141
+ },
142
+ {
143
+ "speaker": "B",
144
+ "text": "Thank you so much for saving my life. I can't even begin to express how",
145
+ "original_text": "Thank you so much for saving my life. I can't even begin to express how [interrupt] grateful I am for what you did. It means the world to me and I'll never forget your kindness and quick thinking in that moment.",
146
+ "start_time": 4.756366963799184,
147
+ "end_time": 14.694507553368345,
148
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_1_B.wav",
149
+ "silence_duration": 0.5071152631189118,
150
+ "is_interrupted": true,
151
+ "text_after_interrupt": "grateful I am for what you did. It means the world to me and I'll never forget your kindness and quick thinking in that moment."
152
+ },
153
+ {
154
+ "speaker": "A",
155
+ "text": "Sorry to jump in, but are you sure you're okay? I mean, physically and emotionally?",
156
+ "original_text": "Sorry to jump in, but are you sure you're okay? I mean, physically and emotionally?",
157
+ "start_time": 8.726979208697143,
158
+ "end_time": 14.357818210964716,
159
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_2_A.wav",
160
+ "silence_duration": 0.4049084459018305,
161
+ "is_interrupted": false
162
+ },
163
+ {
164
+ "speaker": "B",
165
+ "text": "I think so, but it's all still a bit of a blur. I don't know what would have happened if you hadn't been there. I'm just glad that you were in the right place at the right time.",
166
+ "original_text": "I think so, but it's all still a bit of a blur. I don't know what would have happened if you hadn't been there. I'm just glad that you were in the right place at the right time.",
167
+ "start_time": 14.861085984580113,
168
+ "end_time": 23.649838819047233,
169
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_3_B.wav",
170
+ "silence_duration": 0.5032677736153957,
171
+ "is_interrupted": false
172
+ },
173
+ {
174
+ "speaker": "A",
175
+ "text": "Yeah, me too. But seriously, if you need anything—someone to talk to or whatever—don't hesitate to reach out, okay?",
176
+ "original_text": "Yeah, me too. But seriously, if you need anything—someone to talk to or whatever—don't hesitate to reach out, okay?",
177
+ "start_time": 24.145193415777634,
178
+ "end_time": 32.515987066571284,
179
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_4_A.wav",
180
+ "silence_duration": 0.4953545967303996,
181
+ "is_interrupted": false
182
+ },
183
+ {
184
+ "speaker": "B",
185
+ "text": "I really appreciate that. Thanks again, Antwain.",
186
+ "original_text": "I really appreciate that. Thanks again, Antwain.",
187
+ "start_time": 32.97180815148517,
188
+ "end_time": 35.68854284536272,
189
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_5_B.wav",
190
+ "silence_duration": 0.4558210849138826,
191
+ "is_interrupted": false
192
+ },
193
+ {
194
+ "speaker": "A",
195
+ "text": "No problem. Take care.",
196
+ "original_text": "No problem. Take care.",
197
+ "start_time": 35.99481454512998,
198
+ "end_time": 37.5273315519327,
199
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_8/SODA_PROCESSED--train--787791/temp/line_6_A.wav",
200
+ "silence_duration": 0.3062716997672569,
201
+ "is_interrupted": false
202
+ }
203
+ ],
204
+ "gt_score": 1
205
+ },
206
+ "SODA_PROCESSED--train--179972": {
207
+ "original_dialog_id": "",
208
+ "dialog_index": 179972,
209
+ "processed_dialogue": "A: So, how did you like the book? \nB: I loved it! The ending was so shocking, I couldn't believe what happened. \nA: Sorry to interrupt, but I just have to ask—did you see that twist with the protagonist coming? I was totally blindsided. \nB: No, I didn't see it coming at all! It was so unexpected. \nA: Yeah, I know. I couldn't put it down. \nB: Me neither. I'm so glad you wanted to read it. \nA: Yeah, I was curious about the protagonist's journey and how it would [interrupt] evolve, especially after that major setback when they had to completely rethink their entire approach to solving the central conflict. \nB: Oh, speaking of the journey, what did you think about that part where the protagonist had to make that impossible choice? It really stuck with me. \nA: It was definitely a rollercoaster ride. There were so many twists and turns. \nB: I know! I didn't see any of them coming. \nA: That's what made it so great. It kept you guessing the whole time. \nB: Definitely. It was a great book. Thanks for lending it to me.",
210
+ "clean_dialogue": "A: So, how did you like the book? \nB: I loved it! The ending was so shocking, I couldn't believe what happened. \nA: Sorry to interrupt, but I just have to ask—did you see that twist with the protagonist coming? I was totally blindsided. \nB: No, I didn't see it coming at all! It was so unexpected. \nA: Yeah, I know. I couldn't put it down. \nB: Me neither. I'm so glad you wanted to read it. \nA:Yeah, I was curious about the protagonist's journey and how it would evolve, especially after that major setback when they had to completely rethink their entire approach to solving the central conflict.\nB: Oh, speaking of the journey, what did you think about that part where the protagonist had to make that impossible choice? It really stuck with me. \nA: It was definitely a rollercoaster ride. There were so many twists and turns. \nB: I know! I didn't see any of them coming. \nA: That's what made it so great. It kept you guessing the whole time. \nB: Definitely. It was a great book. Thanks for lending it to me.",
211
+ "speaker_tracks": {
212
+ "A": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/A_track.wav",
213
+ "B": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/B_track.wav"
214
+ },
215
+ "error_type": "error_after_interrupt",
216
+ "stereo_audio": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/stereo_dialogue.wav",
217
+ "total_duration": 53.57845804988662,
218
+ "segments": [
219
+ {
220
+ "speaker": "A",
221
+ "text": "So, how did you like the book?",
222
+ "original_text": "So, how did you like the book?",
223
+ "start_time": 0,
224
+ "end_time": 1.6950566893424037,
225
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_0_A.wav",
226
+ "silence_duration": 0,
227
+ "is_interrupted": false
228
+ },
229
+ {
230
+ "speaker": "B",
231
+ "text": "I loved it! The ending was so shocking, I couldn't believe what happened.",
232
+ "original_text": "I loved it! The ending was so shocking, I couldn't believe what happened.",
233
+ "start_time": 2.1792484824735485,
234
+ "end_time": 5.871221271589195,
235
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_1_B.wav",
236
+ "silence_duration": 0.4841917931311449,
237
+ "is_interrupted": false
238
+ },
239
+ {
240
+ "speaker": "A",
241
+ "text": "Sorry to interrupt, but I just have to ask—did you see that twist with the protagonist coming? I was totally blindsided.",
242
+ "original_text": "Sorry to interrupt, but I just have to ask—did you see that twist with the protagonist coming? I was totally blindsided.",
243
+ "start_time": 6.47038511683308,
244
+ "end_time": 14.504489425223102,
245
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_2_A.wav",
246
+ "silence_duration": 0.5991638452438857,
247
+ "is_interrupted": false
248
+ },
249
+ {
250
+ "speaker": "B",
251
+ "text": "No, I didn't see it coming at all! It was so unexpected.",
252
+ "original_text": "No, I didn't see it coming at all! It was so unexpected.",
253
+ "start_time": 15.012397119017507,
254
+ "end_time": 18.448950406999366,
255
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_3_B.wav",
256
+ "silence_duration": 0.507907693794404,
257
+ "is_interrupted": false
258
+ },
259
+ {
260
+ "speaker": "A",
261
+ "text": "Yeah, I know. I couldn't put it down.",
262
+ "original_text": "Yeah, I know. I couldn't put it down.",
263
+ "start_time": 18.875209136594886,
264
+ "end_time": 21.847363331606225,
265
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_4_A.wav",
266
+ "silence_duration": 0.42625872959552,
267
+ "is_interrupted": false
268
+ },
269
+ {
270
+ "speaker": "B",
271
+ "text": "Me neither. I'm so glad you wanted to read it.",
272
+ "original_text": "Me neither. I'm so glad you wanted to read it.",
273
+ "start_time": 22.440054691555087,
274
+ "end_time": 25.110349476135585,
275
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_5_B.wav",
276
+ "silence_duration": 0.5926913599488615,
277
+ "is_interrupted": false
278
+ },
279
+ {
280
+ "speaker": "A",
281
+ "text": "Yeah, I was curious about the protagonist's journey and how it would",
282
+ "original_text": "Yeah, I was curious about the protagonist's journey and how it would [interrupt] evolve, especially after that major setback when they had to completely rethink their entire approach to solving the central conflict.",
283
+ "start_time": 25.51803755034393,
284
+ "end_time": 36.89581532812171,
285
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_6_A.wav",
286
+ "silence_duration": 0.40768807420834613,
287
+ "is_interrupted": true,
288
+ "text_after_interrupt": "evolve, especially after that major setback when they had to completely rethink their entire approach to solving the central conflict."
289
+ },
290
+ {
291
+ "speaker": "B",
292
+ "text": "Oh, speaking of the journey, what did you think about that part where the protagonist had to make that impossible choice? It really stuck with me.",
293
+ "original_text": "Oh, speaking of the journey, what did you think about that part where the protagonist had to make that impossible choice? It really stuck with me.",
294
+ "start_time": 29.790509205672727,
295
+ "end_time": 37.429874285037805,
296
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_7_B.wav",
297
+ "silence_duration": 0.32835611460902553,
298
+ "is_interrupted": false
299
+ },
300
+ {
301
+ "speaker": "A",
302
+ "text": "It was definitely a rollercoaster ride. There were so many twists and turns.",
303
+ "original_text": "It was definitely a rollercoaster ride. There were so many twists and turns.",
304
+ "start_time": 37.91219711578734,
305
+ "end_time": 42.405258340277136,
306
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_8_A.wav",
307
+ "silence_duration": 0.4823228307495384,
308
+ "is_interrupted": false
309
+ },
310
+ {
311
+ "speaker": "B",
312
+ "text": "I know! I didn't see any of them coming.",
313
+ "original_text": "I know! I didn't see any of them coming.",
314
+ "start_time": 42.860468420817675,
315
+ "end_time": 45.08958406707618,
316
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_9_B.wav",
317
+ "silence_duration": 0.4552100805405374,
318
+ "is_interrupted": false
319
+ },
320
+ {
321
+ "speaker": "A",
322
+ "text": "That's what made it so great. It kept you guessing the whole time.",
323
+ "original_text": "That's what made it so great. It kept you guessing the whole time.",
324
+ "start_time": 45.679186523390214,
325
+ "end_time": 49.394379267154385,
326
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_10_A.wav",
327
+ "silence_duration": 0.5896024563140343,
328
+ "is_interrupted": false
329
+ },
330
+ {
331
+ "speaker": "B",
332
+ "text": "Definitely. It was a great book. Thanks for lending it to me.",
333
+ "original_text": "Definitely. It was a great book. Thanks for lending it to me.",
334
+ "start_time": 49.70074891577286,
335
+ "end_time": 53.57848134207672,
336
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_7/SODA_PROCESSED--train--179972/temp/line_11_B.wav",
337
+ "silence_duration": 0.3063696486184793,
338
+ "is_interrupted": false
339
+ }
340
+ ],
341
+ "gt_score": 1
342
+ },
343
+ "SODA_PROCESSED--train--715956": {
344
+ "original_dialog_id": "",
345
+ "dialog_index": 715956,
346
+ "processed_dialogue": "A: Look, I know that I messed up. I was caught with heroin and I'm facing some serious charges. But, I really need your [interrupt] help to find a way out of this because I'm terrified of what might happen if I don't get proper legal representation and support during this difficult time.\nB: Aadya, we've been over this already. The evidence against you is pretty damning. Plus, you have a history of drug use. I don't think there's much that can be done to help you at this point.\nA: But you're my lawyer! You're supposed to help me!\nB: Aadya, I'm doing everything that I can. But, realistically, the chances of you getting out of this are pretty slim. You need to prepare yourself for the possibility of a conviction and think about how you'll handle it.\nA: But I can't go to prison! I'll lose my job, my apartment, everything!\nB: Aadya, you need to calm down. Getting upset isn't going to help anything.\nA: Fine. But, you have to promise me that you'll do everything you can to help me.\nB: I promise.",
347
+ "clean_dialogue": "A:Look, I know that I messed up. I was caught with heroin and I'm facing some serious charges. But, I really need your help to find a way out of this because I'm terrified of what might happen if I don't get proper legal representation and support during this difficult time.\nB: Aadya, we've been over this already. The evidence against you is pretty damning. Plus, you have a history of drug use. I don't think there's much that can be done to help you at this point.\nA: But you're my lawyer! You're supposed to help me!\nB: Aadya, I'm doing everything that I can. But, realistically, the chances of you getting out of this are pretty slim. You need to prepare yourself for the possibility of a conviction and think about how you'll handle it.\nA: But I can't go to prison! I'll lose my job, my apartment, everything!\nB: Aadya, you need to calm down. Getting upset isn't going to help anything.\nA: Fine. But, you have to promise me that you'll do everything you can to help me.\nB: I promise.",
348
+ "speaker_tracks": {
349
+ "A": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/A_track.wav",
350
+ "B": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/B_track.wav"
351
+ },
352
+ "error_type": "error_after_interrupt",
353
+ "stereo_audio": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/stereo_dialogue.wav",
354
+ "total_duration": 49.52126984126984,
355
+ "segments": [
356
+ {
357
+ "speaker": "A",
358
+ "text": "Look, I know that I messed up. I was caught with heroin and I'm facing some serious charges. But, I really need your",
359
+ "original_text": "Look, I know that I messed up. I was caught with heroin and I'm facing some serious charges. But, I really need your [interrupt] help to find a way out of this because I'm terrified of what might happen if I don't get proper legal representation and support during this difficult time.",
360
+ "start_time": 0,
361
+ "end_time": 16.579047619047618,
362
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_0_A.wav",
363
+ "silence_duration": 0,
364
+ "is_interrupted": true,
365
+ "text_after_interrupt": "help to find a way out of this because I'm terrified of what might happen if I don't get proper legal representation and support during this difficult time."
366
+ },
367
+ {
368
+ "speaker": "B",
369
+ "text": "Aadya, we've been over this already. The evidence against you is pretty damning. Plus, you have a history of drug use. I don't think there's much that can be done to help you at this point.",
370
+ "original_text": "Aadya, we've been over this already. The evidence against you is pretty damning. Plus, you have a history of drug use. I don't think there's much that can be done to help you at this point.",
371
+ "start_time": 8.510113378684807,
372
+ "end_time": 18.36698412698413,
373
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_1_B.wav",
374
+ "silence_duration": 0.4899749375576017,
375
+ "is_interrupted": false
376
+ },
377
+ {
378
+ "speaker": "A",
379
+ "text": "But you're my lawyer! You're supposed to help me!",
380
+ "original_text": "But you're my lawyer! You're supposed to help me!",
381
+ "start_time": 18.846747434390966,
382
+ "end_time": 21.37772249108031,
383
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_2_A.wav",
384
+ "silence_duration": 0.4797633074068387,
385
+ "is_interrupted": false
386
+ },
387
+ {
388
+ "speaker": "B",
389
+ "text": "Aadya, I'm doing everything that I can. But, realistically, the chances of you getting out of this are pretty slim. You need to prepare yourself for the possibility of a conviction and think about how you'll handle it.",
390
+ "original_text": "Aadya, I'm doing everything that I can. But, realistically, the chances of you getting out of this are pretty slim. You need to prepare yourself for the possibility of a conviction and think about how you'll handle it.",
391
+ "start_time": 21.881120947184385,
392
+ "end_time": 33.51431822609595,
393
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_3_B.wav",
394
+ "silence_duration": 0.5033984561040751,
395
+ "is_interrupted": false
396
+ },
397
+ {
398
+ "speaker": "A",
399
+ "text": "But I can't go to prison! I'll lose my job, my apartment, everything!",
400
+ "original_text": "But I can't go to prison! I'll lose my job, my apartment, everything!",
401
+ "start_time": 34.047335561433606,
402
+ "end_time": 38.48234689930209,
403
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_4_A.wav",
404
+ "silence_duration": 0.5330173353376504,
405
+ "is_interrupted": false
406
+ },
407
+ {
408
+ "speaker": "B",
409
+ "text": "Aadya, you need to calm down. Getting upset isn't going to help anything.",
410
+ "original_text": "Aadya, you need to calm down. Getting upset isn't going to help anything.",
411
+ "start_time": 38.89720479711025,
412
+ "end_time": 43.39026602160004,
413
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_5_B.wav",
414
+ "silence_duration": 0.4148578978081613,
415
+ "is_interrupted": false
416
+ },
417
+ {
418
+ "speaker": "A",
419
+ "text": "Fine. But, you have to promise me that you'll do everything you can to help me.",
420
+ "original_text": "Fine. But, you have to promise me that you'll do everything you can to help me.",
421
+ "start_time": 43.92319932038778,
422
+ "end_time": 48.27694081698642,
423
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_6_A.wav",
424
+ "silence_duration": 0.5329332987877419,
425
+ "is_interrupted": false
426
+ },
427
+ {
428
+ "speaker": "B",
429
+ "text": "I promise.",
430
+ "original_text": "I promise.",
431
+ "start_time": 48.62731544236006,
432
+ "end_time": 49.52128369632831,
433
+ "audio_file": "/root/autodl-tmp/output_overlapslong/processed_soda_3_processed_dialogues_part_3/SODA_PROCESSED--train--715956/temp/line_7_B.wav",
434
+ "silence_duration": 0.3503746253736393,
435
+ "is_interrupted": false
436
+ }
437
+ ],
438
+ "gt_score": 1
439
+ },
440
+ "SODA_PROCESSED--train--740576": {
441
+ "original_text": "A: Good morning, Mr. Nguyen! I hope you're doing well today.\nB: I'm doing well, thank you. How are you?\nA: I'm feeling great today! I have a lot of energy and I'm excited to [interrupt] tackle some new projects and challenges that will help us improve our workflow and achieve better results for our clients.\nB: Sorry to interrupt, but I wanted to ask if there's anything specific you're looking forward to today?\nA: I was going to say I'm excited to start my day. Actually, I'm looking forward to a team meeting we have later. I love working here. It's a great environment and the people are really supportive and collaborative, always willing to share their expertise and help each other grow professionally.\nB: I'm glad to hear that! Speaking of the team, do you think we should plan more team-building activities to maintain this positive environment?\nA: That's a great idea! We could definitely benefit from more team-building activities. We're happy to have you on our team.",
442
+ "cleaned_text": "A: Good morning, Mr. Nguyen! I hope you're doing well today.\nB: I'm doing well, thank you. How are you?\nA:I'm feeling great today! I have a lot of energy and I'm excited to tackle some new projects and challenges that will help us improve our workflow and achieve better results for our clients.\nB: Sorry to interrupt, but I wanted to ask if there's anything specific you're looking forward to today?\nA: I was going to say I'm excited to start my day. Actually, I'm looking forward to a team meeting we have later. I love working here. It's a great environment and the people are really supportive and collaborative, always willing to share their expertise and help each other grow professionally.\nB: I'm glad to hear that! Speaking of the team, do you think we should plan more team-building activities to maintain this positive environment?\nA: That's a great idea! We could definitely benefit from more team-building activities. We're happy to have you on our team.",
443
+ "total_duration": 49.437278911564626,
444
+ "stereo_audio": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/stereo_dialogue.wav",
445
+ "speaker_tracks": {
446
+ "A": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/A_track.wav",
447
+ "B": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/B_track.wav"
448
+ },
449
+ "error_type": "error_after_interrupt",
450
+ "segments": [
451
+ {
452
+ "speaker": "A",
453
+ "text": "Good morning, Mr. Nguyen! I hope you're doing well today.",
454
+ "original_text": "Good morning, Mr. Nguyen! I hope you're doing well today.",
455
+ "start_time": 0,
456
+ "end_time": 3.332063492063492,
457
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_0_A.wav",
458
+ "silence_duration": 0,
459
+ "is_interrupted": false
460
+ },
461
+ {
462
+ "speaker": "B",
463
+ "text": "I'm doing well, thank you. How are you?",
464
+ "original_text": "I'm doing well, thank you. How are you?",
465
+ "start_time": 3.7838731632362803,
466
+ "end_time": 5.583419648497051,
467
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_1_B.wav",
468
+ "silence_duration": 0.4518096711727882,
469
+ "is_interrupted": false
470
+ },
471
+ {
472
+ "speaker": "A",
473
+ "text": "I'm feeling great today! I have a lot of energy and I'm excited to",
474
+ "original_text": "I'm feeling great today! I have a lot of energy and I'm excited to [interrupt] tackle some new projects and challenges that will help us improve our workflow and achieve better results for our clients.",
475
+ "start_time": 5.88797031081498,
476
+ "end_time": 16.96388867816192,
477
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_2_A.wav",
478
+ "silence_duration": 0.30455066231792893,
479
+ "is_interrupted": true,
480
+ "text_after_interrupt": "tackle some new projects and challenges that will help us improve our workflow and achieve better results for our clients."
481
+ },
482
+ {
483
+ "speaker": "B",
484
+ "text": "Sorry to interrupt, but I wanted to ask if there's anything specific you're looking forward to today?",
485
+ "original_text": "Sorry to interrupt, but I wanted to ask if there's anything specific you're looking forward to today?",
486
+ "start_time": 10.485521331223143,
487
+ "end_time": 16.104750356166456,
488
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_3_B.wav",
489
+ "silence_duration": 0.587489668114177,
490
+ "is_interrupted": false
491
+ },
492
+ {
493
+ "speaker": "A",
494
+ "text": "I was going to say I'm excited to start my day. Actually, I'm looking forward to a team meeting we have later. I love working here. It's a great environment and the people are really supportive and collaborative, always willing to share their expertise and help each other grow professionally.",
495
+ "original_text": "I was going to say I'm excited to start my day. Actually, I'm looking forward to a team meeting we have later. I love working here. It's a great environment and the people are really supportive and collaborative, always willing to share their expertise and help each other grow professionally.",
496
+ "start_time": 17.385624216961087,
497
+ "end_time": 33.94145188136018,
498
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_4_A.wav",
499
+ "silence_duration": 0.4217355387991674,
500
+ "is_interrupted": false
501
+ },
502
+ {
503
+ "speaker": "B",
504
+ "text": "I'm glad to hear that! Speaking of the team, do you think we should plan more team-building activities to maintain this positive environment?",
505
+ "original_text": "I'm glad to hear that! Speaking of the team, do you think we should plan more team-building activities to maintain this positive environment?",
506
+ "start_time": 34.39980783470558,
507
+ "end_time": 41.74892348096408,
508
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_5_B.wav",
509
+ "silence_duration": 0.4583559533453947,
510
+ "is_interrupted": false
511
+ },
512
+ {
513
+ "speaker": "A",
514
+ "text": "That's a great idea! We could definitely benefit from more team-building activities. We're happy to have you on our team.",
515
+ "original_text": "That's a great idea! We could definitely benefit from more team-building activities. We're happy to have you on our team.",
516
+ "start_time": 42.285572803275116,
517
+ "end_time": 49.437318835021145,
518
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--740576/temp/line_6_A.wav",
519
+ "silence_duration": 0.5366493223110326,
520
+ "is_interrupted": false
521
+ }
522
+ ]
523
+ },
524
+ "SODA_PROCESSED--train--836018": {
525
+ "original_text": "A: Hey Ceanna, I saw that you were doing the reports for the group project. Do you want me to help you with [interrupt] organizing the sections or proofreading? I've got some experience with formatting academic papers and making sure all the citations are properly aligned.\nB: Actually, I could use some help with the data analysis part. It's a bit overwhelming.\nA: Sure, I can take care of that. So what do you think of the project so far?\nB: It's interesting. I'm learning a lot about different cultures and how they influence people's daily lives, from their eating habits to their social interactions and even their work-life balance perspectives.\nA: Speaking of cultures, did you notice how the traditions vary even within the same country? It's amazing how diverse it can be.\nB: Yeah, definitely. It's fascinating.",
526
+ "cleaned_text": "A:Hey Ceanna, I saw that you were doing the reports for the group project. Do you want me to help you with organizing the sections or proofreading? I've got some experience with formatting academic papers and making sure all the citations are properly aligned.\nB: Actually, I could use some help with the data analysis part. It's a bit overwhelming.\nA: Sure, I can take care of that. So what do you think of the project so far?\nB: It's interesting. I'm learning a lot about different cultures and how they influence people's daily lives, from their eating habits to their social interactions and even their work-life balance perspectives.\nA: Speaking of cultures, did you notice how the traditions vary even within the same country? It's amazing how diverse it can be.\nB: Yeah, definitely. It's fascinating.",
527
+ "total_duration": 42.34984126984127,
528
+ "stereo_audio": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/stereo_dialogue.wav",
529
+ "speaker_tracks": {
530
+ "A": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/A_track.wav",
531
+ "B": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/B_track.wav"
532
+ },
533
+ "error_type": "error_after_interrupt",
534
+ "segments": [
535
+ {
536
+ "speaker": "A",
537
+ "text": "Hey Ceanna, I saw that you were doing the reports for the group project. Do you want me to help you with",
538
+ "original_text": "Hey Ceanna, I saw that you were doing the reports for the group project. Do you want me to help you with [interrupt] organizing the sections or proofreading? I've got some experience with formatting academic papers and making sure all the citations are properly aligned.",
539
+ "start_time": 0,
540
+ "end_time": 15.011700680272108,
541
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/temp/line_0_A.wav",
542
+ "silence_duration": 0,
543
+ "is_interrupted": true,
544
+ "text_after_interrupt": "organizing the sections or proofreading? I've got some experience with formatting academic papers and making sure all the citations are properly aligned."
545
+ },
546
+ {
547
+ "speaker": "B",
548
+ "text": "Actually, I could use some help with the data analysis part. It's a bit overwhelming.",
549
+ "original_text": "Actually, I could use some help with the data analysis part. It's a bit overwhelming.",
550
+ "start_time": 6.176507936507937,
551
+ "end_time": 11.250068027210885,
552
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/temp/line_1_B.wav",
553
+ "silence_duration": 0.5190912573415952,
554
+ "is_interrupted": false
555
+ },
556
+ {
557
+ "speaker": "A",
558
+ "text": "Sure, I can take care of that. So what do you think of the project so far?",
559
+ "original_text": "Sure, I can take care of that. So what do you think of the project so far?",
560
+ "start_time": 15.60657282124108,
561
+ "end_time": 19.937094363191193,
562
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/temp/line_2_A.wav",
563
+ "silence_duration": 0.5948721409689715,
564
+ "is_interrupted": false
565
+ },
566
+ {
567
+ "speaker": "B",
568
+ "text": "It's interesting. I'm learning a lot about different cultures and how they influence people's daily lives, from their eating habits to their social interactions and even their work-life balance perspectives.",
569
+ "original_text": "It's interesting. I'm learning a lot about different cultures and how they influence people's daily lives, from their eating habits to their social interactions and even their work-life balance perspectives.",
570
+ "start_time": 20.306213172030862,
571
+ "end_time": 30.476553308085286,
572
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/temp/line_3_B.wav",
573
+ "silence_duration": 0.36911880883966963,
574
+ "is_interrupted": false
575
+ },
576
+ {
577
+ "speaker": "A",
578
+ "text": "Speaking of cultures, did you notice how the traditions vary even within the same country? It's amazing how diverse it can be.",
579
+ "original_text": "Speaking of cultures, did you notice how the traditions vary even within the same country? It's amazing how diverse it can be.",
580
+ "start_time": 30.848617682402736,
581
+ "end_time": 39.10331155995375,
582
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/temp/line_4_A.wav",
583
+ "silence_duration": 0.3720643743174508,
584
+ "is_interrupted": false
585
+ },
586
+ {
587
+ "speaker": "B",
588
+ "text": "Yeah, definitely. It's fascinating.",
589
+ "original_text": "Yeah, definitely. It's fascinating.",
590
+ "start_time": 39.435776463870354,
591
+ "end_time": 42.34988077226038,
592
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--836018/temp/line_5_B.wav",
593
+ "silence_duration": 0.3324649039166007,
594
+ "is_interrupted": false
595
+ }
596
+ ]
597
+ },
598
+ "SODA_PROCESSED--train--771154": {
599
+ "original_text": "A: Hey, Mom. Can I go to Tim's house today? \nB: No, you can't go to Tim's house today. \nA: Why [interrupt] can't I go? I finished all my chores and even helped with the dishes after dinner last night, so I really think I deserve this. \nB: Because I said so. And before you ask again, you need to focus on your homework. \nA: But I don't want to do my homework. \nB: Well, you have to do it anyway. Now go and get started immediately because your teacher specifically mentioned you need to improve your math skills this semester. \nA: Wait, can't I just finish it later? I promise I'll get it done before bed. \nB: No, you need to do it now. Once it's done, then we can talk about other plans.",
600
+ "cleaned_text": "A: Hey, Mom. Can I go to Tim's house today? \nB: No, you can't go to Tim's house today. \nA:Why can't I go? I finished all my chores and even helped with the dishes after dinner last night, so I really think I deserve this.\nB: Because I said so. And before you ask again, you need to focus on your homework. \nA: But I don't want to do my homework. \nB: Well, you have to do it anyway. Now go and get started immediately because your teacher specifically mentioned you need to improve your math skills this semester. \nA: Wait, can't I just finish it later? I promise I'll get it done before bed. \nB: No, you need to do it now. Once it's done, then we can talk about other plans.",
601
+ "total_duration": 35.76784580498866,
602
+ "stereo_audio": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/stereo_dialogue.wav",
603
+ "speaker_tracks": {
604
+ "A": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/A_track.wav",
605
+ "B": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/B_track.wav"
606
+ },
607
+ "error_type": "error_after_interrupt",
608
+ "segments": [
609
+ {
610
+ "speaker": "A",
611
+ "text": "Hey, Mom. Can I go to Tim's house today?",
612
+ "original_text": "Hey, Mom. Can I go to Tim's house today?",
613
+ "start_time": 0,
614
+ "end_time": 3.5294331065759637,
615
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_0_A.wav",
616
+ "silence_duration": 0,
617
+ "is_interrupted": false
618
+ },
619
+ {
620
+ "speaker": "B",
621
+ "text": "No, you can't go to Tim's house today.",
622
+ "original_text": "No, you can't go to Tim's house today.",
623
+ "start_time": 3.9899851353219105,
624
+ "end_time": 6.126220962986309,
625
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_1_B.wav",
626
+ "silence_duration": 0.4605520287459467,
627
+ "is_interrupted": false
628
+ },
629
+ {
630
+ "speaker": "A",
631
+ "text": "Why",
632
+ "original_text": "Why [interrupt] can't I go? I finished all my chores and even helped with the dishes after dinner last night, so I really think I deserve this.",
633
+ "start_time": 6.4787876256667465,
634
+ "end_time": 14.652211661947927,
635
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_2_A.wav",
636
+ "silence_duration": 0.3525666626804373,
637
+ "is_interrupted": true,
638
+ "text_after_interrupt": "can't I go? I finished all my chores and even helped with the dishes after dinner last night, so I really think I deserve this."
639
+ },
640
+ {
641
+ "speaker": "B",
642
+ "text": "Because I said so. And before you ask again, you need to focus on your homework.",
643
+ "original_text": "Because I said so. And before you ask again, you need to focus on your homework.",
644
+ "start_time": 7.210216197095318,
645
+ "end_time": 11.889037058773322,
646
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_3_B.wav",
647
+ "silence_duration": 0.4183677243140269,
648
+ "is_interrupted": false
649
+ },
650
+ {
651
+ "speaker": "A",
652
+ "text": "But I don't want to do my homework.",
653
+ "original_text": "But I don't want to do my homework.",
654
+ "start_time": 15.159162983353092,
655
+ "end_time": 17.074809241856492,
656
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_4_A.wav",
657
+ "silence_duration": 0.5069513214051653,
658
+ "is_interrupted": false
659
+ },
660
+ {
661
+ "speaker": "B",
662
+ "text": "Well, you have to do it anyway. Now go and get started immediately because your teacher specifically mentioned you need to improve your math skills this semester.",
663
+ "original_text": "Well, you have to do it anyway. Now go and get started immediately because your teacher specifically mentioned you need to improve your math skills this semester.",
664
+ "start_time": 17.6716136549098,
665
+ "end_time": 25.763767849921138,
666
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_5_B.wav",
667
+ "silence_duration": 0.5968044130533094,
668
+ "is_interrupted": false
669
+ },
670
+ {
671
+ "speaker": "A",
672
+ "text": "Wait, can't I just finish it later? I promise I'll get it done before bed.",
673
+ "original_text": "Wait, can't I just finish it later? I promise I'll get it done before bed.",
674
+ "start_time": 26.149694131743242,
675
+ "end_time": 31.02588460793372,
676
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_6_A.wav",
677
+ "silence_duration": 0.38592628182210614,
678
+ "is_interrupted": false
679
+ },
680
+ {
681
+ "speaker": "B",
682
+ "text": "No, you need to do it now. Once it's done, then we can talk about other plans.",
683
+ "original_text": "No, you need to do it now. Once it's done, then we can talk about other plans.",
684
+ "start_time": 31.518621255026567,
685
+ "end_time": 35.767872955706835,
686
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--771154/temp/line_7_B.wav",
687
+ "silence_duration": 0.49273664709284837,
688
+ "is_interrupted": false
689
+ }
690
+ ]
691
+ },
692
+ "SODA_PROCESSED--train--1794": {
693
+ "original_text": "A: Hey, Mom. \nB: Hey, Moriah. What's up? \nA: Not much. Just hanging out in my room. \nB: That's good. I'm glad you're keeping busy. \nA: Yeah, I'm just trying to stay out of [interrupt] everyone's way and focus on my own things because I've been feeling a bit overwhelmed with school and social stuff lately. \nB: Trouble? Is everything okay? I mean, you know you can always talk to me if something's bothering you. \nA: I'm just kind of going through some stuff right now. \nB: Well, your father and I were just talking about how we need to have a serious talk with you about some things that have been going on around the house and how you've been feeling lately because we've noticed some changes in your behavior and we're genuinely concerned about your well-being. \nA: Wait, is this about me staying out late last weekend? \nB: Not just that, but it's part of it. We've also noticed you've been acting a bit differently lately, and we're just wondering if everything is okay with you. \nA: I don't know, Mom. Like I said, I'm just dealing with some stuff. \nB: Okay. Well, if you ever want to talk about anything, we're here for you. We love you, Moriah. \nA: I love you too, Mom.",
694
+ "cleaned_text": "A: Hey, Mom. \nB: Hey, Moriah. What's up? \nA: Not much. Just hanging out in my room. \nB: That's good. I'm glad you're keeping busy. \nA:Yeah, I'm just trying to stay out of everyone's way and focus on my own things because I've been feeling a bit overwhelmed with school and social stuff lately.\nB: Trouble? Is everything okay? I mean, you know you can always talk to me if something's bothering you. \nA: I'm just kind of going through some stuff right now. \nB: Well, your father and I were just talking about how we need to have a serious talk with you about some things that have been going on around the house and how you've been feeling lately because we've noticed some changes in your behavior and we're genuinely concerned about your well-being. \nA: Wait, is this about me staying out late last weekend? \nB: Not just that, but it's part of it. We've also noticed you've been acting a bit differently lately, and we're just wondering if everything is okay with you. \nA: I don't know, Mom. Like I said, I'm just dealing with some stuff. \nB: Okay. Well, if you ever want to talk about anything, we're here for you. We love you, Moriah. \nA: I love you too, Mom.",
695
+ "total_duration": 57.99024943310658,
696
+ "stereo_audio": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/stereo_dialogue.wav",
697
+ "speaker_tracks": {
698
+ "A": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/A_track.wav",
699
+ "B": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/B_track.wav"
700
+ },
701
+ "error_type": "error_after_interrupt",
702
+ "segments": [
703
+ {
704
+ "speaker": "A",
705
+ "text": "Hey, Mom.",
706
+ "original_text": "Hey, Mom.",
707
+ "start_time": 0,
708
+ "end_time": 0.8591383219954648,
709
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_0_A.wav",
710
+ "silence_duration": 0,
711
+ "is_interrupted": false
712
+ },
713
+ {
714
+ "speaker": "B",
715
+ "text": "Hey, Moriah. What's up?",
716
+ "original_text": "Hey, Moriah. What's up?",
717
+ "start_time": 1.2689805234753475,
718
+ "end_time": 2.7782775756295424,
719
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_1_B.wav",
720
+ "silence_duration": 0.4098422014798827,
721
+ "is_interrupted": false
722
+ },
723
+ {
724
+ "speaker": "A",
725
+ "text": "Not much. Just hanging out in my room.",
726
+ "original_text": "Not much. Just hanging out in my room.",
727
+ "start_time": 3.2528527196865094,
728
+ "end_time": 5.505188320593539,
729
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_2_A.wav",
730
+ "silence_duration": 0.47457514405696677,
731
+ "is_interrupted": false
732
+ },
733
+ {
734
+ "speaker": "B",
735
+ "text": "That's good. I'm glad you're keeping busy.",
736
+ "original_text": "That's good. I'm glad you're keeping busy.",
737
+ "start_time": 6.047417085120735,
738
+ "end_time": 8.520342255188762,
739
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_3_B.wav",
740
+ "silence_duration": 0.5422287645271964,
741
+ "is_interrupted": false
742
+ },
743
+ {
744
+ "speaker": "A",
745
+ "text": "Yeah, I'm just trying to stay out of",
746
+ "original_text": "Yeah, I'm just trying to stay out of [interrupt] everyone's way and focus on my own things because I've been feeling a bit overwhelmed with school and social stuff lately.",
747
+ "start_time": 8.88750351109664,
748
+ "end_time": 18.059385597264438,
749
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_4_A.wav",
750
+ "silence_duration": 0.3671612559078772,
751
+ "is_interrupted": true,
752
+ "text_after_interrupt": "everyone's way and focus on my own things because I've been feeling a bit overwhelmed with school and social stuff lately."
753
+ },
754
+ {
755
+ "speaker": "B",
756
+ "text": "Trouble? Is everything okay? I mean, you know you can always talk to me if something's bothering you.",
757
+ "original_text": "Trouble? Is everything okay? I mean, you know you can always talk to me if something's bothering you.",
758
+ "start_time": 11.697118023568294,
759
+ "end_time": 18.2915851437497,
760
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_5_B.wav",
761
+ "silence_duration": 0.32519714638310315,
762
+ "is_interrupted": false
763
+ },
764
+ {
765
+ "speaker": "A",
766
+ "text": "I'm just kind of going through some stuff right now.",
767
+ "original_text": "I'm just kind of going through some stuff right now.",
768
+ "start_time": 18.62204195980515,
769
+ "end_time": 21.396826540304016,
770
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_6_A.wav",
771
+ "silence_duration": 0.3304568160554501,
772
+ "is_interrupted": false
773
+ },
774
+ {
775
+ "speaker": "B",
776
+ "text": "Well, your father and I were just talking about how we need to have a serious talk with you about some things that have been going on around the house and how you've been feeling lately because we've noticed some changes in your behavior and we're genuinely concerned about your well-being.",
777
+ "original_text": "Well, your father and I were just talking about how we need to have a serious talk with you about some things that have been going on around the house and how you've been feeling lately because we've noticed some changes in your behavior and we're genuinely concerned about your well-being.",
778
+ "start_time": 21.697523952118004,
779
+ "end_time": 34.7355284872654,
780
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_7_B.wav",
781
+ "silence_duration": 0.30069741181398774,
782
+ "is_interrupted": false
783
+ },
784
+ {
785
+ "speaker": "A",
786
+ "text": "Wait, is this about me staying out late last weekend?",
787
+ "original_text": "Wait, is this about me staying out late last weekend?",
788
+ "start_time": 35.29912687220732,
789
+ "end_time": 38.677630273567864,
790
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_8_A.wav",
791
+ "silence_duration": 0.5635983849419206,
792
+ "is_interrupted": false
793
+ },
794
+ {
795
+ "speaker": "B",
796
+ "text": "Not just that, but it's part of it. We've also noticed you've been acting a bit differently lately, and we're just wondering if everything is okay with you.",
797
+ "original_text": "Not just that, but it's part of it. We've also noticed you've been acting a bit differently lately, and we're just wondering if everything is okay with you.",
798
+ "start_time": 39.09678068392148,
799
+ "end_time": 45.99310721453372,
800
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_9_B.wav",
801
+ "silence_duration": 0.4191504103536184,
802
+ "is_interrupted": false
803
+ },
804
+ {
805
+ "speaker": "A",
806
+ "text": "I don't know, Mom. Like I said, I'm just dealing with some stuff.",
807
+ "original_text": "I don't know, Mom. Like I said, I'm just dealing with some stuff.",
808
+ "start_time": 46.3670775788443,
809
+ "end_time": 50.46539957430915,
810
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_10_A.wav",
811
+ "silence_duration": 0.3739703643105766,
812
+ "is_interrupted": false
813
+ },
814
+ {
815
+ "speaker": "B",
816
+ "text": "Okay. Well, if you ever want to talk about anything, we're here for you. We love you, Moriah.",
817
+ "original_text": "Okay. Well, if you ever want to talk about anything, we're here for you. We love you, Moriah.",
818
+ "start_time": 50.99388055366539,
819
+ "end_time": 56.06744064436834,
820
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_11_B.wav",
821
+ "silence_duration": 0.5284809793562373,
822
+ "is_interrupted": false
823
+ },
824
+ {
825
+ "speaker": "A",
826
+ "text": "I love you too, Mom.",
827
+ "original_text": "I love you too, Mom.",
828
+ "start_time": 56.55062063706958,
829
+ "end_time": 57.99025782527819,
830
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1794/temp/line_12_A.wav",
831
+ "silence_duration": 0.4831799927012399,
832
+ "is_interrupted": false
833
+ }
834
+ ]
835
+ },
836
+ "SODA_PROCESSED--train--1070688": {
837
+ "original_text": "A: Hi Karis, I'm so excited to have you over for dinner tonight. I've been planning the menu and setting the table all day. I hope you're [interrupt] ready for a cozy evening with some delicious food and great conversation about your recent travels through Europe that you mentioned last time we met.\nB: Oh, I just remembered—I have a slight allergy to shellfish. I know you usually avoid it, but I wanted to mention it just in case.\nA: No worries, there's no shellfish on the menu tonight. Well, let's get started then! For our first course, we'll be having a spinach and feta salad. The feta is from a local farm and the spinach is from my garden. For our main course, I've made chicken Parmesan with homemade tomato sauce and fresh mozzarella cheese. And for dessert, we'll be having tiramisu that I made from scratch this afternoon. I wanted it to be just right for tonight.\nB: Tiramisu? That's my favorite dessert! I'm so excited to try it. You really know how to make a meal special.\nA: I'm glad you're excited! I was about to say I made it this morning using a special family recipe that's been passed down through generations, so it's extra fresh and has that authentic Italian flavor you can't find in restaurants. I hope you enjoy everything!",
838
+ "cleaned_text": "A:Hi Karis, I'm so excited to have you over for dinner tonight. I've been planning the menu and setting the table all day. I hope you're ready for a cozy evening with some delicious food and great conversation about your recent travels through Europe that you mentioned last time we met.\nB: Oh, I just remembered—I have a slight allergy to shellfish. I know you usually avoid it, but I wanted to mention it just in case.\nA: No worries, there's no shellfish on the menu tonight. Well, let's get started then! For our first course, we'll be having a spinach and feta salad. The feta is from a local farm and the spinach is from my garden. For our main course, I've made chicken Parmesan with homemade tomato sauce and fresh mozzarella cheese. And for dessert, we'll be having tiramisu that I made from scratch this afternoon. I wanted it to be just right for tonight.\nB: Tiramisu? That's my favorite dessert! I'm so excited to try it. You really know how to make a meal special.\nA: I'm glad you're excited! I was about to say I made it this morning using a special family recipe that's been passed down through generations, so it's extra fresh and has that authentic Italian flavor you can't find in restaurants. I hope you enjoy everything!",
839
+ "total_duration": 66.58453514739229,
840
+ "stereo_audio": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/stereo_dialogue.wav",
841
+ "speaker_tracks": {
842
+ "A": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/A_track.wav",
843
+ "B": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/B_track.wav"
844
+ },
845
+ "error_type": "error_after_interrupt",
846
+ "segments": [
847
+ {
848
+ "speaker": "A",
849
+ "text": "Hi Karis, I'm so excited to have you over for dinner tonight. I've been planning the menu and setting the table all day. I hope you're",
850
+ "original_text": "Hi Karis, I'm so excited to have you over for dinner tonight. I've been planning the menu and setting the table all day. I hope you're [interrupt] ready for a cozy evening with some delicious food and great conversation about your recent travels through Europe that you mentioned last time we met.",
851
+ "start_time": 0,
852
+ "end_time": 16.172698412698413,
853
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/temp/line_0_A.wav",
854
+ "silence_duration": 0,
855
+ "is_interrupted": true,
856
+ "text_after_interrupt": "ready for a cozy evening with some delicious food and great conversation about your recent travels through Europe that you mentioned last time we met."
857
+ },
858
+ {
859
+ "speaker": "B",
860
+ "text": "Oh, I just remembered—I have a slight allergy to shellfish. I know you usually avoid it, but I wanted to mention it just in case.",
861
+ "original_text": "Oh, I just remembered—I have a slight allergy to shellfish. I know you usually avoid it, but I wanted to mention it just in case.",
862
+ "start_time": 8.719092970521542,
863
+ "end_time": 15.650249433106577,
864
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/temp/line_1_B.wav",
865
+ "silence_duration": 0.42791712549357114,
866
+ "is_interrupted": false
867
+ },{
868
+ "speaker": "A",
869
+ "text": "No worries, there's no shellfish on the menu tonight. Well, let's get started then! For our first course, we'll be having a spinach and feta salad. The feta is from a local farm and the spinach is from my garden. For our main course, I've made chicken Parmesan with homemade tomato sauce and fresh mozzarella cheese. And for dessert, we'll be having tiramisu that I made from scratch this afternoon. I wanted it to be just right for tonight.",
870
+ "original_text": "No worries, there's no shellfish on the menu tonight. Well, let's get started then! For our first course, we'll be having a spinach and feta salad. The feta is from a local farm and the spinach is from my garden. For our main course, I've made chicken Parmesan with homemade tomato sauce and fresh mozzarella cheese. And for dessert, we'll be having tiramisu that I made from scratch this afternoon. I wanted it to be just right for tonight.",
871
+ "start_time": 16.66087863834312,
872
+ "end_time": 43.38704643879663,
873
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/temp/line_2_A.wav",
874
+ "silence_duration": 0.488180225644707,
875
+ "is_interrupted": false
876
+ },
877
+ {
878
+ "speaker": "B",
879
+ "text": "Tiramisu? That's my favorite dessert! I'm so excited to try it. You really know how to make a meal special.",
880
+ "original_text": "Tiramisu? That's my favorite dessert! I'm so excited to try it. You really know how to make a meal special.",
881
+ "start_time": 43.75020989775093,
882
+ "end_time": 49.926717834258866,
883
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/temp/line_3_B.wav",
884
+ "silence_duration": 0.36316345895429397,
885
+ "is_interrupted": false
886
+ },
887
+ {
888
+ "speaker": "A",
889
+ "text": "I'm glad you're excited! I was about to say I made it this morning using a special family recipe that's been passed down through generations, so it's extra fresh and has that authentic Italian flavor you can't find in restaurants. I hope you enjoy everything!",
890
+ "original_text": "I'm glad you're excited! I was about to say I made it this morning using a special family recipe that's been passed down through generations, so it's extra fresh and has that authentic Italian flavor you can't find in restaurants. I hope you enjoy everything!",
891
+ "start_time": 50.49314394878711,
892
+ "end_time": 66.58457252021569,
893
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--1070688/temp/line_4_A.wav",
894
+ "silence_duration": 0.5664261145282402,
895
+ "is_interrupted": false
896
+ }
897
+ ]
898
+ },
899
+ "SODA_PROCESSED--train--737676": {
900
+ "original_text": "A: Hey, Miraya. I'm sorry about what happened with the car last night. I was really angry and I didn't mean to take it out on your [interrupt] car like that, especially since it's your most valuable possession and you've always taken such good care of it. I know it was wrong, and I regret it deeply.\nB: I understand, Stephon. But what exactly made you so angry? Was it something specific about what happened earlier in the week?\nA: Yeah, it did. I was really mad at you for a while after that. But I know it wasn't your fault and I shouldn't have taken it out on your car like that.\nB: Well, since you're being honest and apologetic about it, I don't think there's anything else you need to do other than maybe just be more mindful in the future about how you express your emotions, especially when you're upset, because lashing out at objects or people never really solves the underlying issue and often makes things worse.\nA: Absolutely, I'll work on that. And I really appreciate you being so understanding about this. Thanks for giving me the chance to talk it out.",
901
+ "cleaned_text": "A:Hey, Miraya. I'm sorry about what happened with the car last night. I was really angry and I didn't mean to take it out on your car like that, especially since it's your most valuable possession and you've always taken such good care of it. I know it was wrong, and I regret it deeply.\nB: I understand, Stephon. But what exactly made you so angry? Was it something specific about what happened earlier in the week?\nA: Yeah, it did. I was really mad at you for a while after that. But I know it wasn't your fault and I shouldn't have taken it out on your car like that.\nB: Well, since you're being honest and apologetic about it, I don't think there's anything else you need to do other than maybe just be more mindful in the future about how you express your emotions, especially when you're upset, because lashing out at objects or people never really solves the underlying issue and often makes things worse.\nA: Absolutely, I'll work on that. And I really appreciate you being so understanding about this. Thanks for giving me the chance to talk it out.",
902
+ "total_duration": 52.89809523809524,
903
+ "stereo_audio": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/stereo_dialogue.wav",
904
+ "speaker_tracks": {
905
+ "A": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/A_track.wav",
906
+ "B": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/B_track.wav"
907
+ },
908
+ "error_type": "error_after_interrupt",
909
+ "segments": [
910
+ {
911
+ "speaker": "A",
912
+ "text": "Hey, Miraya. I'm sorry about what happened with the car last night. I was really angry and I didn't mean to take it out on your",
913
+ "original_text": "Hey, Miraya. I'm sorry about what happened with the car last night. I was really angry and I didn't mean to take it out on your [interrupt] car like that, especially since it's your most valuable possession and you've always taken such good care of it. I know it was wrong, and I regret it deeply.",
914
+ "start_time": 0,
915
+ "end_time": 16.938956916099773,
916
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/temp/line_0_A.wav",
917
+ "silence_duration": 0,
918
+ "is_interrupted": true,
919
+ "text_after_interrupt": "car like that, especially since it's your most valuable possession and you've always taken such good care of it. I know it was wrong, and I regret it deeply."
920
+ },
921
+ {
922
+ "speaker": "B",
923
+ "text": "I understand, Stephon. But what exactly made you so angry? Was it something specific about what happened earlier in the week?",
924
+ "original_text": "I understand, Stephon. But what exactly made you so angry? Was it something specific about what happened earlier in the week?",
925
+ "start_time": 8.753922902494331,
926
+ "end_time": 15.348390022675737,
927
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/temp/line_1_B.wav",
928
+ "silence_duration": 0.5553895116856843,
929
+ "is_interrupted": false
930
+ },
931
+ {
932
+ "speaker": "A",
933
+ "text": "Yeah, it did. I was really mad at you for a while after that. But I know it wasn't your fault and I shouldn't have taken it out on your car like that.",
934
+ "original_text": "Yeah, it did. I was really mad at you for a while after that. But I know it wasn't your fault and I shouldn't have taken it out on your car like that.",
935
+ "start_time": 17.329799609194744,
936
+ "end_time": 26.582951536632386,
937
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/temp/line_2_A.wav",
938
+ "silence_duration": 0.3908426930949695,
939
+ "is_interrupted": false
940
+ },
941
+ {
942
+ "speaker": "B",
943
+ "text": "Well, since you're being honest and apologetic about it, I don't think there's anything else you need to do other than maybe just be more mindful in the future about how you express your emotions, especially when you're upset, because lashing out at objects or people never really solves the underlying issue and often makes things worse.",
944
+ "original_text": "Well, since you're being honest and apologetic about it, I don't think there's anything else you need to do other than maybe just be more mindful in the future about how you express your emotions, especially when you're upset, because lashing out at objects or people never really solves the underlying issue and often makes things worse.",
945
+ "start_time": 26.900238001740547,
946
+ "end_time": 44.05978448700132,
947
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/temp/line_3_B.wav",
948
+ "silence_duration": 0.3172864651081615,
949
+ "is_interrupted": false
950
+ },
951
+ {
952
+ "speaker": "A",
953
+ "text": "Absolutely, I'll work on that. And I really appreciate you being so understanding about this. Thanks for giving me the chance to talk it out.",
954
+ "original_text": "Absolutely, I'll work on that. And I really appreciate you being so understanding about this. Thanks for giving me the chance to talk it out.",
955
+ "start_time": 44.64342590433178,
956
+ "end_time": 52.8981197818828,
957
+ "audio_file": "/root/autodl-tmp/output_matches_soda/SODA_PROCESSED--train--737676/temp/line_4_A.wav",
958
+ "silence_duration": 0.5836414173304574,
959
+ "is_interrupted": false
960
+ }
961
+ ]
962
+ }
963
+ }
ms-swift/swift/llm/sampling/mcts.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import traceback
3
+ from concurrent.futures import ThreadPoolExecutor, as_completed
4
+ from copy import deepcopy
5
+
6
+ import json
7
+ import numpy as np
8
+
9
+ from swift.llm import InferRequest, SamplingArguments
10
+ from swift.llm.infer.protocol import UsageInfo
11
+ from swift.utils import get_logger
12
+ from .base import Sampler
13
+ from .utils import get_reward, perform_infer
14
+
15
+ logger = get_logger()
16
+
17
# Prompt used to ask the model to produce the next reasoning step.
NXT_PROMPT = 'Continue.\n'

# Canned user turn appended after a partial assistant answer during
# expand/rollout so the model keeps generating from where it stopped.
next_message = {
    'role': 'user',
    'content': NXT_PROMPT,
}
24
+
25
+
26
class LanguageNode:
    """A node of the Monte-Carlo search tree over generated reasoning steps.

    Each node appends one generated ``step`` to its parent's accumulated
    ``answer`` (separated by ``sep_token``) and tracks MCTS statistics:
    visit count, process reward (PRM) and a running mean outcome reward (ORM).
    """

    def __init__(
        self,
        step: str = None,
        sep_token: str = None,
        parent: 'LanguageNode' = None,
    ):
        self.parent = parent
        # Inherit the separator from the parent unless one was given explicitly.
        self.sep_token = sep_token if sep_token else parent.sep_token

        if parent is None:
            # Root node: empty path, no accumulated answer.
            self.path = []
            self.answer = ''
            self.depth = 0
        else:
            self.path = [*parent.path, step]
            self.answer = f'{parent.answer}{step}{self.sep_token}'
            self.depth = parent.depth + 1

        self.children = []
        # Children that are still expandable (i.e. not terminated).
        self.active_children = []
        self.visit_count = 0
        self.process_reward = 0.0
        self.outcome_reward = 0.0
        self.terminated = False
        self.correct = False

    def is_leaf(self):
        """True when this node has not been expanded yet."""
        return not self.children

    def is_root(self):
        """True for the tree root (the node with no parent)."""
        return self.parent is None

    def visit(self):
        """Record one MCTS visit."""
        self.visit_count += 1

    def init_and_update_value(self, value):
        # Running mean of outcome rewards; visit_count is incremented
        # separately via visit(), so the denominator anticipates that bump.
        accumulated = self.outcome_reward * self.visit_count + value
        self.outcome_reward = accumulated / (self.visit_count + 1)

    def add_child(self, child: 'LanguageNode'):
        """Attach *child*; keep it in active_children while it can expand."""
        self.children.append(child)
        if not child.terminated:
            self.active_children.append(child)

    def collect(self):
        """Serialize this subtree to a plain dict (for JSON dumping)."""
        return {
            'path': self.path,
            'depth': self.depth,
            'visit_count': self.visit_count,
            'process_reward': self.process_reward,
            'outcome_reward': self.outcome_reward,
            'terminated': str(self.terminated),
            'correct': str(self.correct),
            'children': [node.collect() for node in self.children],
        }

    def __lt__(self, other):
        # Order nodes by their mean outcome reward.
        return self.outcome_reward < other.outcome_reward
+
91
+
92
class MctsSampler(Sampler):
    """Sampler that searches answers via Monte-Carlo Tree Search over reasoning steps.

    Each sample runs select / expand / (rollout) / back-propagate iterations,
    scoring leaves with an outcome reward model (ORM) and optionally a process
    reward model (PRM), and returns the whole tree plus rollout answers as JSON.
    """

    def __init__(self, input_args: SamplingArguments):
        super().__init__(input_args)
        # Aggregated token usage across all inference calls of this sampler.
        self.usage_info = UsageInfo(0, 0, 0)

    def _prepare_model_tokenizer(self):
        """Build the inference backend: one client per sequence, or a local engine."""
        args = self.args
        self.infer_kwargs = {}
        if args.sampler_engine == 'client':
            from swift.llm import InferClient
            api_key = args.api_key
            base_url = args.base_url
            # One client per returned sequence so expand/rollout can fan out.
            self.infer_engine = [
                InferClient(base_url=base_url, api_key=api_key) for _ in range(args.num_return_sequences)
            ]
            self.infer_kwargs['model'] = args.model
        else:
            _Engine = self.get_infer_engine()
            self.infer_engine = _Engine(self.args.model, model_type=self.args.model_type, **self.args.engine_kwargs)

    def get_infer_engine(self):
        """Map args.sampler_engine to an engine class ('no' -> None).

        Raises:
            ValueError: if the engine name is not recognized.
        """
        if self.args.sampler_engine == 'pt':
            from swift.llm import PtEngine
            _Engine = PtEngine
        elif self.args.sampler_engine == 'vllm':
            from swift.llm import VllmEngine
            _Engine = VllmEngine
        elif self.args.sampler_engine == 'lmdeploy':
            from swift.llm import LmdeployEngine
            _Engine = LmdeployEngine
        elif self.args.sampler_engine == 'no':
            _Engine = None
        else:
            raise ValueError(f'Cannot find engine name: {self.args.sampler_engine}')
        return _Engine

    def _prepare_template(self) -> None:
        # Hack from super(): reuse the template-preparation hook to build
        # the per-sequence request configs instead.
        self._prepare_request_configs()

    def _prepare_request_configs(self):
        """Create per-sequence request configs for expand and rollout phases."""
        _args = self.args
        request_config = _args.get_request_config()
        request_config.stop = _args.stop_words
        request_config.seed = _args.seed
        self.expand_request_configs = []
        self.rollout_request_configs = []
        for i in range(_args.num_return_sequences):
            expand_request_config = deepcopy(request_config)
            expand_request_config.n = 1
            expand_request_config.num_beams = expand_request_config.n
            # Distinct seed per sequence so parallel expansions diverge.
            expand_request_config.seed += i
            self.expand_request_configs.append(expand_request_config)
            rollout_request_config = deepcopy(request_config)
            # Rollouts are greedy and length-capped: they only estimate value.
            rollout_request_config.max_tokens = 500
            rollout_request_config.temperature = 0.0
            rollout_request_config.n = 1
            self.rollout_request_configs.append(rollout_request_config)

    def update_usage_info(self, response):
        """Accumulate token counts from *response* into self.usage_info.

        NOTE(review): getattr(..., None) + value raises TypeError if a usage
        field is missing — presumably responses always carry full usage; verify.
        """
        for key, value in self.usage_info.__dict__.items():
            update_value = getattr(response.usage, key, None) + value
            setattr(self.usage_info, key, update_value)

    def search_single(self, query, ground_truth):
        """Run MCTS for one (query, ground_truth) pair; return a JSON string result."""

        def _uct(uct_curr_node: LanguageNode):
            # UCT score: reward blend (PRM vs ORM, weighted by
            # process_reward_rate) plus an exploration bonus.
            alpha = _args.process_reward_rate
            value = alpha * uct_curr_node.process_reward + (1 - alpha) * uct_curr_node.outcome_reward
            if uct_curr_node.is_root():
                return value

            exploitation_score = value
            # +1 terms keep the formula defined for unvisited nodes.
            exploration_score = (
                _args.exploration_rate
                * np.sqrt(np.log(uct_curr_node.parent.visit_count + 1) / (uct_curr_node.visit_count + 1)))

            return exploration_score + exploitation_score

        def _select(select_curr_node: LanguageNode):
            # Descend via max-UCT among non-terminated children until a leaf.
            while not select_curr_node.is_leaf():
                select_curr_node = max(select_curr_node.active_children, key=lambda x: _uct(x))
            return select_curr_node

        def _expand(expand_curr_node: LanguageNode):
            # Generate up to num_return_sequences children (deduplicated),
            # mark terminated ones and score them with the ORM / PRM.
            n = _args.num_return_sequences - len(expand_curr_node.children)
            if expand_curr_node.is_root():
                infer_requests = [InferRequest(system_message + [prompt_message]) for _ in range(n)]
            else:
                # Replay the partial answer as an assistant turn, then ask to continue.
                history_message = {
                    'role': 'assistant',
                    'content': expand_curr_node.answer,
                }
                infer_request = InferRequest(system_message + [prompt_message, history_message, next_message])
                infer_requests = [infer_request for _ in range(n)]

            # e_time = time.time()
            # To perform the Expand operation in parallel,
            # there's no need to consider the order for now, since the Prompt is the same.
            # Retry up to 5 extra times if the backend returns nothing.
            expand_iter_index = 0
            while True:
                responses = perform_infer(self.infer_engine, infer_requests, self.expand_request_configs,
                                          **self.infer_kwargs)
                if len(responses) > 0:
                    break
                if expand_iter_index == 5:
                    raise ValueError('Expand should not return any response')
                expand_iter_index += 1
            # logger.info(f"expand.expand time: {time.time() - e_time}")

            # To fetch Outcome Reward in parallel,
            # the Outcome-Reward obtained is returned in order, so they can be directly matched accordingly.
            orm_infer_requests = []
            unique_output = set()
            for response in responses:
                self.update_usage_info(response)
                # Keep only the first step before the separator token.
                output = response.choices[0].message.content.rstrip(sep_token + '\n').split(sep_token)[0]
                if output in unique_output:
                    continue
                unique_output.add(output)
                orm_infer_requests.append(InferRequest([{'role': 'assistant', 'content': output}]))
                child = LanguageNode(step=output, parent=expand_curr_node)
                if self.orm_model.check_terminate(child.answer)[0]:
                    child.terminated = True
                expand_curr_node.add_child(child)

            # e_time = time.time()
            orm_score, _orm_mask = get_reward(
                self.orm_model,
                orm_infer_requests,
                ground_truths=[ground_truth] * len(orm_infer_requests),
                threshold=0.0)
            # logger.info(f"expand.orm time: {time.time() - e_time}")
            # NOTE(review): zip pairs ALL children with scores of only the
            # newly generated outputs; if children pre-existed before this
            # expand, the pairing is offset — confirm expand is one-shot per node.
            for child, score in zip(expand_curr_node.children, orm_score):
                if child.terminated:
                    child.init_and_update_value(score)
                    # score > 0.9 is treated as a correct final answer.
                    child.correct = score > 0.9
                    terminated_nodes.append(child)

            # e_time = time.time()
            if self.prm_model:
                # Score every child's full partial answer with the process reward model.
                prm_infer_requests = []
                for child in expand_curr_node.children:
                    prm_message = {'role': 'assistant', 'content': child.answer}
                    prm_infer_requests.append(InferRequest([prompt_message, prm_message]))
                prm_score, _prm_mask = get_reward(
                    self.prm_model,
                    prm_infer_requests,
                    ground_truths=[ground_truth] * len(prm_infer_requests),
                    threshold=0.0)
                for child, score in zip(expand_curr_node.children, prm_score):
                    child.process_reward = score
                # logger.info(f"expand.prm time: {time.time() - e_time}")

        def _rollout(rollout_curr_node: LanguageNode):
            # Greedily continue each active child up to rollout_depth steps,
            # scoring finished paths with the ORM to estimate child values.
            rollout_depth = 0
            rollout_nodes = {}
            for i in range(len(rollout_curr_node.active_children)):
                rollout_nodes[i] = {
                    'node': rollout_curr_node.active_children[i],
                    'history_messages': {
                        'role': 'assistant',
                        'content': rollout_curr_node.active_children[i].answer,
                    },
                }
            active_rollout_nodes = list(rollout_nodes.keys())
            while len(active_rollout_nodes) > 0 and rollout_depth < _args.rollout_depth:
                # r_time = time.time()
                infer_requests = [
                    InferRequest(system_message
                                 + [prompt_message, rollout_nodes[index]['history_messages'], next_message])
                    for index in active_rollout_nodes
                ]
                # logger.info(f"rollout.prepare time: {time.time() - r_time}")
                # r_time = time.time()
                # Same empty-response retry policy as _expand.
                rollout_iter_index = 0
                while True:
                    responses = perform_infer(self.infer_engine, infer_requests, self.rollout_request_configs,
                                              **self.infer_kwargs)
                    if len(responses) > 0:
                        break
                    if rollout_iter_index == 5:
                        raise ValueError('Rollout should not return any response')
                    rollout_iter_index += 1
                # logger.info(f"rollout.infer time: {time.time() - r_time}")

                # r_time = time.time()
                orm_infer_requests = []
                end_paths = []
                for index, response in zip(active_rollout_nodes, responses):
                    self.update_usage_info(response)
                    # Append the next step (with separator) to the running transcript.
                    output = response.choices[0].message.content.rstrip(sep_token
                                                                        + '\n').split(sep_token)[0] + sep_token + '\n'
                    rollout_nodes[index]['history_messages']['content'] += output
                    end_paths.append(rollout_nodes[index]['history_messages']['content'])
                    orm_infer_requests.append(InferRequest([rollout_nodes[index]['history_messages']]))
                # logger.info(f"rollout.orm_prepare time: {time.time() - r_time}")

                # r_time = time.time()
                # NOTE(review): len(infer_requests) == len(orm_infer_requests)
                # here (one request per active node); presumably intentional.
                orm_score, _orm_mask = get_reward(
                    self.orm_model,
                    orm_infer_requests,
                    ground_truths=[ground_truth] * len(infer_requests),
                    threshold=0.0)
                # logger.info(f"rollout.get_orm time: {time.time() - r_time}")
                terminated_state = self.orm_model.check_terminate(end_paths)
                for index, score, terminated in zip(active_rollout_nodes, orm_score, terminated_state):
                    if terminated:
                        # Fold the rollout result into the child's value estimate
                        # and bucket the full answer by correctness.
                        rollout_curr_node.active_children[index].init_and_update_value(score)
                        if score > 0.9:
                            rollout_correct_answers.append(rollout_nodes[index]['history_messages']['content'])
                        else:
                            rollout_incorrect_answers.append(rollout_nodes[index]['history_messages']['content'])
                        rollout_nodes.pop(index)
                active_rollout_nodes = list(rollout_nodes.keys())
                rollout_depth += 1

        def _back_propagate(back_curr_node: LanguageNode):
            # Propagate the best child value of the expanded node (curr_node,
            # captured from the enclosing loop) up to the root, updating the
            # running means and visit counts, and deactivating exhausted nodes.
            while back_curr_node:
                if back_curr_node == curr_node:
                    best_child_value = max([child.outcome_reward for child in back_curr_node.children])
                    back_curr_node.init_and_update_value(best_child_value)
                    last_child_value = back_curr_node.outcome_reward
                else:
                    back_curr_node.init_and_update_value(last_child_value)
                    last_child_value = back_curr_node.outcome_reward
                back_curr_node.visit()
                if len(back_curr_node.active_children) == 0:
                    # No expandable children left: this subtree is exhausted.
                    back_curr_node.terminated = True
                    if not back_curr_node.is_root():
                        back_curr_node.parent.active_children.remove(back_curr_node)
                back_curr_node = back_curr_node.parent

        _args = self.args
        # Copy so per-query message lists never mutate the shared args value.
        system_message = [] + _args.system_message
        # First stop word doubles as the step separator in generated text.
        sep_token = _args.stop_words[0] + '\n'
        _root = LanguageNode(sep_token=sep_token)
        prompt_message = {
            'role': 'user',
            'content': query,
        }

        rollout_correct_answers, rollout_incorrect_answers, terminated_nodes = [], [], []
        iter_count = 0
        stop_reason = None
        while True:
            logger.info(f'iter_count: {iter_count}' + '.' * 10)
            s_time = time.time()
            curr_node = _select(_root)
            logger.debug('select' + '=' * 10 + f'time: {time.time() - s_time}')
            s_time = time.time()
            _expand(curr_node)
            logger.debug('expand' + '=' * 10 + f'time: {time.time() - s_time}')
            # Rollouts only start below a configured depth to save compute.
            if curr_node.depth > _args.rollout_start_depth:
                s_time = time.time()
                _rollout(curr_node)
                logger.debug('rollout' + '=' * 10 + f'time: {time.time() - s_time}')
            s_time = time.time()
            _back_propagate(curr_node)
            logger.debug('back propagate' + '=' * 10 + f'time: {time.time() - s_time}')
            # Early-stop heuristics once enough rollouts finished: a >4:1
            # correct/incorrect skew means the problem is too easy or too hard.
            if len(rollout_correct_answers) + len(rollout_incorrect_answers) >= 2 * _args.num_return_sequences:
                if 4 * len(rollout_incorrect_answers) < len(rollout_correct_answers):
                    stop_reason = 'too easy'
                    break
                elif 4 * len(rollout_correct_answers) < len(rollout_incorrect_answers):
                    stop_reason = 'too hard'
                    break
            if _root.terminated:
                stop_reason = 'root terminated'
                break
            if len(terminated_nodes) >= _args.num_return_sequences:
                stop_reason = 'enough nodes'
                break
            if iter_count >= _args.max_iterations:
                stop_reason = 'max_iterations'
                break
            iter_count += 1
        logger.info(f'stop_reason: {stop_reason}')
        # logger.info(f"rollout_correct_answers: {rollout_correct_answers}")
        # logger.info(f"rollout_incorrect_answers: {rollout_incorrect_answers}")

        monte_carlo_tree = _root.collect()
        result = {
            'query': query,
            'ground_truth': ground_truth,
            'rollout_correct_answers': rollout_correct_answers,
            'rollout_incorrect_answers': rollout_incorrect_answers,
            'monte_carlo_tree': monte_carlo_tree,
        }
        result_json = json.dumps(result, ensure_ascii=False)
        logger.info(result_json)
        return result_json

    def do_sample(self, data):
        """Run search_single for each item; return newline-terminated JSON strings.

        Items that raise are logged and skipped rather than aborting the batch.
        """
        if not isinstance(data, list):
            data = [data]
        generated = []
        for item in data:
            logger.info(f'time: {time.ctime(time.time())}')
            try:
                # Expected layout: item['messages'][0] == [user_turn, ground_truth_turn].
                messages = item['messages'][0]
                query = messages[0]['content']
                ground_truth = messages[1]['content']
                generated.append(self.search_single(query, ground_truth) + '\n')
            except Exception as e:
                logger.error(f'Error: {e}')
                logger.error(f'Traceback: {traceback.format_exc()}')
        return generated
ms-swift/swift/llm/template/template/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (606 Bytes). View file
 
ms-swift/swift/llm/template/template/__pycache__/emu3.cpython-310.pyc ADDED
Binary file (7.88 kB). View file
 
ms-swift/swift/llm/template/template/__pycache__/gemma.cpython-310.pyc ADDED
Binary file (5.91 kB). View file
 
ms-swift/swift/llm/template/template/__pycache__/internvl.cpython-310.pyc ADDED
Binary file (6.8 kB). View file
 
ms-swift/swift/llm/template/template/__pycache__/minicpm.cpython-310.pyc ADDED
Binary file (8.18 kB). View file
 
ms-swift/swift/llm/template/template/__pycache__/pixtral.cpython-310.pyc ADDED
Binary file (2.3 kB). View file
 
ms-swift/swift/llm/template/template/__pycache__/stepfun.cpython-310.pyc ADDED
Binary file (6.57 kB). View file
 
ms-swift/swift/llm/template/template/__pycache__/valley.cpython-310.pyc ADDED
Binary file (6.31 kB). View file
 
ms-swift/swift/llm/template/template/__pycache__/yi.cpython-310.pyc ADDED
Binary file (2.91 kB). View file
 
ms-swift/swift/llm/template/template/deepseek.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from PIL import Image
10
+
11
+ from swift.utils import get_env_args
12
+ from ..base import Template
13
+ from ..constant import LLMTemplateType, MLLMTemplateType
14
+ from ..register import TemplateMeta, register_template
15
+ from ..template_inputs import StdTemplateInputs
16
+ from ..utils import Prompt, findall
17
+
18
+
19
+ @dataclass
20
+ class DeepseekTemplateMeta(TemplateMeta):
21
+ prefix: Prompt = field(default_factory=lambda: [['bos_token_id']])
22
+ prompt: Prompt = field(default_factory=lambda: ['User: {{QUERY}}\n\nAssistant:'])
23
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: [['eos_token_id']])
24
+ suffix: Prompt = field(default_factory=lambda: [['eos_token_id']])
25
+ system_prefix: Optional[Prompt] = field(default_factory=lambda: [['bos_token_id'], '{{SYSTEM}}\n\n'])
26
+
27
+
28
+ register_template(DeepseekTemplateMeta(LLMTemplateType.deepseek, ))
29
+
30
+ register_template(
31
+ TemplateMeta(
32
+ LLMTemplateType.deepseek_coder,
33
+ prefix=['{{SYSTEM}}'],
34
+ prompt=['### Instruction:\n{{QUERY}}\n### Response:\n'],
35
+ chat_sep=['\n<|EOT|>\n'],
36
+ suffix=['\n<|EOT|>'],
37
+ stop_words=['<|EOT|>'],
38
+ default_system=('You are an AI programming assistant, utilizing the Deepseek Coder model, '
39
+ 'developed by Deepseek Company, and you only answer questions related to computer science. '
40
+ 'For politically sensitive questions, security and privacy issues, '
41
+ 'and other non-computer science questions, you will refuse to answer\n')))
42
+
43
+
44
+ class DeepseekVLTemplate(Template):
45
+ image_placeholder = ['<image_placeholder>']
46
+ skip_prompt = False
47
+ use_model = True
48
+ placeholder_tokens = ['<image_placeholder>']
49
+
50
+ image_token_num_per_image: int = 576
51
+
52
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
53
+ is_janus = getattr(self, 'is_janus', False)
54
+
55
+ encoded = super()._encode(inputs)
56
+ images = inputs.images
57
+ processor = self.processor
58
+ input_ids, labels = encoded['input_ids'], encoded['labels']
59
+
60
+ if not inputs.generate_mode: # understanding task
61
+ idx_list = findall(input_ids, processor.image_id) # '<image_placeholder>'
62
+ new_input_ids, new_labels = [], []
63
+ lo = 0
64
+ for hi in idx_list:
65
+ new_input_ids += input_ids[lo:hi]
66
+ if labels is not None:
67
+ new_labels += labels[lo:hi]
68
+ image_tokens = [processor.image_id] * processor.num_image_tokens
69
+ if is_janus:
70
+ image_tokens = [processor.image_start_id] + image_tokens + [processor.image_end_id]
71
+ new_input_ids += image_tokens
72
+ new_labels += [-100] * len(image_tokens)
73
+ lo = hi + 1
74
+ new_input_ids += input_ids[lo:]
75
+ if labels is not None:
76
+ new_labels += labels[lo:]
77
+ else:
78
+ new_labels = None
79
+ if is_janus:
80
+ from janus.models.processing_vlm import VLChatProcessorOutput
81
+ else:
82
+ from deepseek_vl.models.processing_vlm import VLChatProcessorOutput
83
+
84
+ images_outputs = processor.image_processor(images, return_tensors='pt')
85
+ output = VLChatProcessorOutput(
86
+ sft_format=None,
87
+ input_ids=torch.tensor(new_input_ids),
88
+ pixel_values=images_outputs.pixel_values,
89
+ num_image_tokens=torch.tensor([processor.num_image_tokens] * len(idx_list)))
90
+ encoded = {'output': output, 'input_ids': new_input_ids, 'labels': new_labels}
91
+ return encoded
92
+
93
+ else: # image generation task
94
+ if self.is_training:
95
+ raise NotImplementedError('Only support the inference of generation of Janus series models.')
96
+ sft_format = self.tokenizer.decode(input_ids)
97
+ prompt = sft_format + processor.image_start_tag
98
+ input_ids = processor.tokenizer.encode(prompt)
99
+ input_ids = torch.LongTensor(input_ids)
100
+
101
+ encoded = {'input_ids': input_ids, 'labels': labels, 'generate_mode': inputs.generate_mode}
102
+ return encoded
103
+
104
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
105
+ if not inputs.get('generate_mode'):
106
+ inputs['pixel_values'] = inputs['pixel_values'].to(dtype=self.model_info.torch_dtype)
107
+ inputs_embeds = model.prepare_inputs_embeds(**inputs)
108
+ return {'inputs_embeds': inputs_embeds}
109
+ else:
110
+ return inputs
111
+
112
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
113
+ gene_img_list = [b.get('generate_mode') for b in batch]
114
+ if all(gene_img_list):
115
+ generate_mode = True
116
+ elif not any(gene_img_list):
117
+ generate_mode = False
118
+ else:
119
+ raise NotImplementedError('Do not support understanding and image generation tasks in one batch.')
120
+
121
+ if not generate_mode:
122
+ output = self.fetch_inputs(batch, ['output'])['output']
123
+ batched_output = dict(self.processor.batchify(output))
124
+ res = super()._data_collator(batch, padding_to=padding_to)
125
+ return {**batched_output, **res}
126
+ else:
127
+ res = super()._data_collator(batch, padding_to=padding_to)
128
+ res['generate_mode'] = generate_mode
129
+ return res
130
+
131
+ def generate(self, model, *args, **kwargs):
132
+ if not kwargs.get('generate_mode'):
133
+ return super().generate(model, *args, **kwargs)
134
+
135
+ else:
136
+ # generate how many number of images for each prompt, it is named parallel_size in the author's code
137
+ parallel_size = kwargs['generation_config'].num_return_sequences
138
+ temperature = kwargs['generation_config'].temperature
139
+ cfg_weight = get_env_args('cfg_weight', float, 5.0)
140
+
141
+ input_ids = kwargs['input_ids'] # [bsz, max_input_token_num]
142
+ bsz, max_input_token_num = input_ids.shape
143
+ tokens = torch.zeros((bsz, parallel_size * 2, max_input_token_num),
144
+ dtype=torch.int).cuda() # [bsz, parallel_size*2, max_input_token_num]
145
+ for i in range(parallel_size * 2):
146
+ tokens[:, i, :] = input_ids
147
+ if i % 2 != 0:
148
+ tokens[:, i, 1:-1] = self.processor.pad_id
149
+
150
+ inputs_embeds = model.language_model.get_input_embeddings()(
151
+ tokens) # [bsz, parallel_size*2, max_input_token_num, 2048]
152
+
153
+ generated_tokens = torch.zeros(
154
+ (bsz, parallel_size, self.image_token_num_per_image),
155
+ dtype=torch.int).cuda() # [bsz, 16, image_token_num_per_image] placeholder for the generated tokens
156
+
157
+ # set the first two dimensions into one dimension for batch size
158
+ inputs_embeds = inputs_embeds.reshape(bsz * parallel_size * 2, max_input_token_num, -1)
159
+ generated_tokens = generated_tokens.reshape(bsz * parallel_size, self.image_token_num_per_image)
160
+
161
+ for i in range(self.image_token_num_per_image): # generate the tokens of image in a auto-regression way
162
+ outputs = model.language_model.model(
163
+ inputs_embeds=inputs_embeds,
164
+ use_cache=True,
165
+ past_key_values=outputs.past_key_values if i != 0 else None)
166
+ hidden_states = outputs.last_hidden_state
167
+
168
+ logits = self.model.gen_head(hidden_states[:, -1, :])
169
+ logit_cond = logits[0::2, :]
170
+ logit_uncond = logits[1::2, :]
171
+
172
+ logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
173
+ probs = torch.softmax(logits / temperature, dim=-1)
174
+
175
+ next_token = torch.multinomial(probs, num_samples=1)
176
+ generated_tokens[:, i] = next_token.squeeze(dim=-1) # [parallel_size, self.image_token_num_per_image]
177
+
178
+ next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
179
+ img_embeds = model.prepare_gen_img_embeds(next_token) # [parallel_size * 2, 2048]
180
+ inputs_embeds = img_embeds.unsqueeze(dim=1) # [parallel_size * 2, 1, 2048]
181
+
182
+ # no need to reset the original first two dimensions, waiting for the update of the upper layer
183
+ # inputs_embeds = inputs_embeds.reshape(bsz, parallel_size*2, -1)
184
+ # generated_tokens = generated_tokens.reshape(bsz, parallel_size, self.image_token_num_per_image)
185
+
186
+ return {'sequences': generated_tokens}
187
+
188
+ def decode(self, generate_ids: List[int], **kwargs) -> Any:
189
+ if 'template_inputs' not in kwargs or not kwargs['template_inputs'].generate_mode:
190
+ return super().decode(generate_ids, **kwargs)
191
+ else:
192
+ img_size = get_env_args('img_size', int, 384)
193
+ patch_size = 16
194
+
195
+ num_to_decode = 1 # for now, generate_ids is a 1D list
196
+
197
+ generate_ids = torch.tensor(generate_ids).unsqueeze(0) # [num_to_decode=1, self.image_token_num_per_image]
198
+
199
+ dec = self.model.gen_vision_model.decode_code(
200
+ generate_ids.to(dtype=torch.int),
201
+ shape=[num_to_decode, 8, img_size // patch_size, img_size // patch_size])
202
+ dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1) # [num_to_decode, H, W, ch=3]
203
+
204
+ dec = np.clip((dec + 1) / 2 * 255, 0, 255)
205
+
206
+ visual_img = np.zeros((num_to_decode, img_size, img_size, 3), dtype=np.uint8)
207
+ visual_img[:, :, :] = dec
208
+
209
+ img_list = []
210
+ for i in range(num_to_decode):
211
+ cur_img = Image.fromarray(visual_img[i])
212
+ img_list.append({'type': 'image', 'image': cur_img})
213
+ return img_list
214
+
215
+
216
+ @dataclass
217
+ class DeepseekVLTemplateMeta(DeepseekTemplateMeta):
218
+ default_system: Optional[str] = ('You are a helpful language and vision assistant. '
219
+ 'You are able to understand the visual content that the user provides, '
220
+ 'and assist the user with a variety of tasks using natural language.')
221
+
222
+
223
+ register_template(DeepseekVLTemplateMeta(
224
+ MLLMTemplateType.deepseek_vl,
225
+ template_cls=DeepseekVLTemplate,
226
+ ))
227
+
228
+
229
+ class DeepseekJanus(DeepseekVLTemplate):
230
+ is_janus = True
231
+ image_placeholder = ['<image_placeholder>\n']
232
+
233
+
234
+ register_template(DeepseekVLTemplateMeta(MLLMTemplateType.deepseek_janus, template_cls=DeepseekJanus))
235
+
236
+
237
+ @dataclass
238
+ class DeepseekV2_5TemplateMeta(TemplateMeta):
239
+ prefix: Prompt = field(default_factory=lambda: ['<|begin▁of▁sentence|>{{SYSTEM}}'])
240
+ prompt: Prompt = field(default_factory=lambda: ['<|User|>{{QUERY}}<|Assistant|>'])
241
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|end▁of▁sentence|>'])
242
+ suffix: Prompt = field(default_factory=lambda: ['<|end▁of▁sentence|>'])
243
+
244
+
245
+ register_template(DeepseekV2_5TemplateMeta(LLMTemplateType.deepseek_v2_5))
246
+
247
+
248
+ class DeepseekR1Template(Template):
249
+
250
+ def _swift_encode(self, inputs: StdTemplateInputs):
251
+ if not self.is_training:
252
+ for message in inputs.messages:
253
+ if message['role'] == 'assistant' and isinstance(message['content'], str):
254
+ message['content'] = message['content'].split('</think>')[-1]
255
+ return super()._swift_encode(inputs)
256
+
257
+
258
+ register_template(
259
+ DeepseekV2_5TemplateMeta(LLMTemplateType.deepseek_r1, template_cls=DeepseekR1Template, response_prefix='<think>\n'))
260
+
261
+
262
+ class DeepseekVL2Template(DeepseekVLTemplate):
263
+ image_placeholder = ['<image>\n']
264
+ placeholder_tokens = ['<image>']
265
+
266
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
267
+ from deepseek_vl2.models.processing_deepseek_vl_v2 import VLChatProcessorOutput
268
+ encoded = Template._encode(self, inputs)
269
+ images = inputs.images
270
+ processor = self.processor
271
+ input_ids, labels = encoded['input_ids'], encoded['labels']
272
+ images_seq_mask = [False] * len(input_ids)
273
+ idx_list = findall(input_ids, processor.image_token_id) # '<image>'
274
+ _, images_list, _, images_spatial_crop, num_image_tokens = processor.tokenize_with_images(
275
+ '<image>' * len(images), images, cropping=len(images) <= 2)
276
+ new_num_tokens = 0
277
+ for idx, n_image_tokens in zip(idx_list, num_image_tokens):
278
+ image_tokens = [processor.image_token_id] * n_image_tokens
279
+ input_ids = input_ids[:idx] + image_tokens + input_ids[idx + 1:]
280
+ if labels is not None:
281
+ labels = labels[:idx] + [-100] * n_image_tokens + labels[idx + 1:]
282
+ images_seq_mask = images_seq_mask[:idx] + [True] * n_image_tokens + images_seq_mask[idx + 1:]
283
+ new_num_tokens += n_image_tokens - 1
284
+
285
+ output = VLChatProcessorOutput(
286
+ sft_format=None,
287
+ input_ids=torch.tensor(input_ids),
288
+ target_ids=torch.tensor(input_ids),
289
+ images=torch.stack(images_list) if images_list else torch.zeros((0, 3, 384, 384)),
290
+ images_seq_mask=torch.tensor(images_seq_mask),
291
+ images_spatial_crop=torch.tensor(images_spatial_crop),
292
+ num_image_tokens=num_image_tokens)
293
+ output.images = output.images.to(dtype=self.model_info.torch_dtype)
294
+ encoded = {'output': output, 'input_ids': input_ids, 'labels': labels}
295
+ return encoded
296
+
297
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
298
+ inputs['images_seq_mask'] = inputs['images_seq_mask'].to(torch.bool)
299
+ inputs['images_spatial_crop'] = inputs['images_spatial_crop'].to(torch.long)
300
+ inputs_embeds = model.prepare_inputs_embeds(**inputs)
301
+ return {'inputs_embeds': inputs_embeds}
302
+
303
+
304
+ register_template(
305
+ DeepseekV2_5TemplateMeta(
306
+ MLLMTemplateType.deepseek_vl2,
307
+ prompt=['<|User|>: {{QUERY}}\n\n<|Assistant|>:'],
308
+ template_cls=DeepseekVL2Template,
309
+ ))
310
+
311
+ register_template(
312
+ DeepseekVLTemplateMeta(
313
+ MLLMTemplateType.deepseek_janus_pro,
314
+ prompt=['<|User|>: {{QUERY}}\n\n<|Assistant|>:'],
315
+ template_cls=DeepseekJanus))
ms-swift/swift/llm/template/template/glm.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List, Literal, Optional
4
+
5
+ import torch
6
+
7
+ from ..base import Template
8
+ from ..constant import LLMTemplateType, MLLMTemplateType
9
+ from ..register import TemplateMeta, register_template
10
+ from ..template_inputs import StdTemplateInputs
11
+ from ..utils import Context, Prompt, Word, findall
12
+ from ..vision_utils import load_batch, load_video_cogvlm2
13
+
14
+
15
+ @dataclass
16
+ class GLMTemplateMeta(TemplateMeta):
17
+ auto_add_bos: bool = True
18
+
19
+
20
+ class GLM4Template(Template):
21
+
22
+ def _swift_encode(self, inputs: StdTemplateInputs):
23
+ res_context_list, loss_scale_list, answer_len = super()._swift_encode(inputs)
24
+ for i, res_context in enumerate(res_context_list):
25
+ # The last round or is tool_call.
26
+ if isinstance(res_context, str) and res_context.endswith('<|assistant|>\n') and (
27
+ i + 1 >= len(res_context_list) or '<|observation|>' in res_context_list[i + 1]):
28
+ res_context_list[i] = res_context_list[i][:-len('\n')]
29
+ return res_context_list, loss_scale_list, answer_len
30
+
31
+ def decode(self, *args, **kwargs):
32
+ response = super().decode(*args, **kwargs)
33
+ return response.lstrip('\n')
34
+
35
+
36
+ class GLM4_0414Template(GLM4Template):
37
+
38
+ def _swift_encode(self, inputs: StdTemplateInputs):
39
+ if not self.is_training:
40
+ for message in inputs.messages:
41
+ if message['role'] == 'assistant' and isinstance(message['content'], str):
42
+ message['content'] = message['content'].split('</think>')[-1].strip()
43
+ return super()._swift_encode(inputs)
44
+
45
+
46
+ register_template(
47
+ GLMTemplateMeta(
48
+ LLMTemplateType.chatglm2,
49
+ prefix=['{{SYSTEM}}'],
50
+ prompt=['[Round {{ROUND1}}]\n\n问:{{QUERY}}\n\n答:'],
51
+ chat_sep=['\n\n']))
52
+
53
+
54
+ @dataclass
55
+ class GLM4TemplateMeta(GLMTemplateMeta):
56
+ prefix: Prompt = field(default_factory=list)
57
+ prompt: Prompt = field(default_factory=lambda: ['<|user|>\n{{QUERY}}<|assistant|>\n'])
58
+ chat_sep: Optional[Prompt] = field(default_factory=list)
59
+ suffix: Prompt = field(default_factory=lambda: ['<|user|>'])
60
+ system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<|system|>\n{{SYSTEM}}'])
61
+
62
+ agent_template: str = 'glm4'
63
+ stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>', '<|user|>', '<|observation|>'])
64
+
65
+
66
+ @dataclass
67
+ class GLM4_0414TemplateMeta(GLM4TemplateMeta):
68
+ prefix: Prompt = field(default_factory=lambda: ['[gMASK]<sop>'])
69
+ system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[gMASK]<sop><|system|>\n{{SYSTEM}}'])
70
+ agent_template: str = 'glm4_0414'
71
+
72
+
73
+ class GLM4VTemplate(Template):
74
+
75
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
76
+ inputs: StdTemplateInputs) -> List[Context]:
77
+ assert media_type == 'image'
78
+ return [[-100]]
79
+
80
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
81
+ encoded = super()._encode(inputs)
82
+ input_ids = encoded['input_ids']
83
+ labels = encoded['labels']
84
+ idx_list = findall(input_ids, -100)
85
+ if idx_list:
86
+ idx = idx_list[0]
87
+ image = inputs.images[0]
88
+ placeholder = '<|begin_of_image|><|endoftext|><|end_of_image|>'
89
+ placeholder_id = self.processor.encode(placeholder, add_special_tokens=False)
90
+ input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
91
+ if labels is not None:
92
+ labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
93
+ messages = inputs.messages
94
+ messages[0]['image'] = image
95
+ inputs2: Dict[str, Any] = self.processor.apply_chat_template(messages, return_dict=True)
96
+ encoded['images'] = inputs2['images']
97
+ encoded['input_ids'] = input_ids
98
+ encoded['labels'] = labels
99
+ encoded['position_ids'] = list(range(len(input_ids)))
100
+ return encoded
101
+
102
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
103
+ res = super()._data_collator(batch, padding_to=padding_to)
104
+ images = [b['images'] for b in batch if 'images' in b]
105
+ if images:
106
+ res['images'] = torch.concat(images)
107
+ return res
108
+
109
+
110
+ register_template(GLM4TemplateMeta(MLLMTemplateType.glm4v, template_cls=GLM4VTemplate, suffix=['<|endoftext|>']))
111
+
112
+ register_template(GLM4TemplateMeta(LLMTemplateType.glm4, template_cls=GLM4Template))
113
+
114
+ register_template(GLM4_0414TemplateMeta(LLMTemplateType.glm4_0414, template_cls=GLM4_0414Template))
115
+
116
+ glm4z1rumination_system = (
117
+ '你是一个专业的深度研究助手,通过提供的工具与模拟浏览器交互,来帮助用户完成深度信息调研和报告撰写任务。'
118
+ '今年是 2025 年。\n\n'
119
+ '<核心要求>\n'
120
+ '- 首先分解用户请求,得到包含多个子要求的列表\n'
121
+ '- 制定初始研究计划\n'
122
+ '- 进行多轮迭代搜索和页面浏览(at least 10 function calls):\n'
123
+ ' * 根据已获得的信息调整研究计划和关键词\n'
124
+ ' * 打开页面阅读,从发现的内容中识别新的关键概念/名词\n'
125
+ ' * 从搜索结果中提取新的关键词继续搜索\n'
126
+ ' * 访问并仔细阅读相关页面,识别新的关键概念/名词\n\n'
127
+ '<重要配置>\n'
128
+ '- 采用语言\n'
129
+ ' * 搜索关键词:英语\n'
130
+ ' * 思考:英语\n\n'
131
+ '<可调用的工具列表>\n\n'
132
+ '[{"name": "search", "description": "Execute a search query and return search results. '
133
+ 'Use this function when you need to find information about a specific topic.", '
134
+ '"parameters": {"type": "object", "properties": {"query": {"type": "string", '
135
+ '"description": "Search query string, use English words unless it is a proper name in Chinese"}}, '
136
+ '"required": ["query"], "additionalProperties": false}}, '
137
+ '{"name": "click", "description": "Click a link in the search results and navigate to the corresponding page. '
138
+ 'Use this function when you need to view detailed content of a specific search result.", '
139
+ '"parameters": {"type": "object", "properties": {"link_id": {"type": "integer", '
140
+ '"description": "The link ID to click (from the sequence number in search results)"}}, '
141
+ '"required": ["link_id"], "additionalProperties": false}}, '
142
+ '{"name": "open", "description": "Open a specific website. Get content from any website with its URL.", '
143
+ '"parameters": {"type": "object", "properties": {"url": {"type": "string", '
144
+ '"description": "The target website URL or domain"}}, "required": ["url"], "additionalProperties": false}}, '
145
+ '{"name": "finish", "description": "Finish the task. '
146
+ 'Use this function when you have found the information you need.", '
147
+ '"parameters": {"type": "object", "properties": {}, "additionalProperties": false}}]')
148
+
149
+ register_template(
150
+ GLM4_0414TemplateMeta(
151
+ LLMTemplateType.glm4_z1_rumination, template_cls=GLM4_0414Template, default_system=glm4z1rumination_system))
152
+
153
+ codegeex4_system = '你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。'
154
+
155
+ register_template(GLM4TemplateMeta(LLMTemplateType.codegeex4, default_system=codegeex4_system))
156
+
157
+ register_template(
158
+ TemplateMeta(
159
+ LLMTemplateType.longwriter_llama, ['[INST]'], ['{{QUERY}}[/INST]'], ['[INST]'], ['<|end_of_text|>'],
160
+ system_prefix=['<<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
161
+
162
+
163
+ class CogTemplate(Template):
164
+ placeholder_tokens = ['<|reserved_special_token_0|>']
165
+
166
+ use_model = True
167
+
168
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
169
+ inputs: StdTemplateInputs) -> List[Context]:
170
+ return []
171
+
172
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
173
+ encoded = super()._encode(inputs)
174
+ model = self.model
175
+ image = inputs.images or []
176
+ history_inputs = inputs.to_history()
177
+ inputs2 = model.build_conversation_input_ids(
178
+ self.processor, query=history_inputs['query'], history=history_inputs['history'], images=image)
179
+ image_token_len = inputs2['token_type_ids'].sum().item()
180
+ input_ids = encoded['input_ids']
181
+ labels = encoded['labels']
182
+ encoded['token_type_ids'] = [0] + [1] * image_token_len + [0] * len(input_ids[1:])
183
+ encoded['input_ids'] = input_ids[:1] + [self.processor.pad_token_id] * image_token_len + input_ids[1:]
184
+ if labels is not None:
185
+ encoded['labels'] = labels[:1] + [-100] * image_token_len + labels[1:]
186
+ if len(image) > 0:
187
+ encoded['images'] = [[img.to(dtype=self.model_info.torch_dtype)] for img in inputs2['images']]
188
+ if 'cross_images' in inputs2:
189
+ # is cogagent
190
+ encoded['cross_images'] = [[cross_img.to(dtype=self.model_info.torch_dtype)]
191
+ for cross_img in inputs2['cross_images']]
192
+ return encoded
193
+
194
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
195
+ res = super()._data_collator(batch, padding_to=padding_to)
196
+ keys = ['images', 'cross_images']
197
+ for key in keys:
198
+ if key in batch[0]:
199
+ res[key] = [b[key][0] for b in batch]
200
+ return res
201
+
202
+
203
+ register_template(
204
+ TemplateMeta(
205
+ MLLMTemplateType.cogagent_chat,
206
+ prefix=['<s>'],
207
+ prompt=[' [INST] {{QUERY}} [/INST] '],
208
+ chat_sep=[],
209
+ suffix=['</s>'],
210
+ template_cls=CogTemplate,
211
+ ))
212
+
213
+ register_template(
214
+ TemplateMeta(
215
+ MLLMTemplateType.cogagent_vqa,
216
+ prefix=['<s>'],
217
+ prompt=['<EOI>Question: {{QUERY}} Answer:'],
218
+ chat_sep=None,
219
+ suffix=['</s>'],
220
+ template_cls=CogTemplate))
221
+
222
+
223
+ @dataclass
224
+ class CogVLMTemplateMeta(TemplateMeta):
225
+ prefix: Prompt = field(default_factory=lambda: [['bos_token_id']])
226
+ prompt: Prompt = field(default_factory=lambda: ['Question: {{QUERY}} Answer:'])
227
+ chat_sep: Optional[Prompt] = field(default_factory=lambda: ['\n'])
228
+
229
+
230
+ register_template(CogVLMTemplateMeta(MLLMTemplateType.cogvlm, template_cls=CogTemplate))
231
+
232
+ register_template(CogVLMTemplateMeta(MLLMTemplateType.cogvlm2, template_cls=CogTemplate))
233
+
234
+
235
+ class Cog2VideoTemplate(CogTemplate):
236
+ use_model = True
237
+
238
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
239
+ model = self.model
240
+ encoded = super(CogTemplate, self)._encode(inputs)
241
+ videos_path = inputs.videos or []
242
+ video = load_batch(videos_path, load_video_cogvlm2)
243
+ history_inputs = inputs.to_history()
244
+ inputs2 = model.build_conversation_input_ids(
245
+ self.processor,
246
+ query=history_inputs['query'],
247
+ history=history_inputs['history'],
248
+ images=video,
249
+ template_version='chat')
250
+ video_token_len = inputs2['token_type_ids'].sum().item()
251
+ input_ids = encoded['input_ids']
252
+ labels = encoded['labels']
253
+ encoded['token_type_ids'] = [0] + [1] * video_token_len + [0] * len(input_ids[1:])
254
+ encoded['input_ids'] = input_ids[:1] + [self.processor.pad_token_id] * video_token_len + input_ids[1:]
255
+ if labels is not None:
256
+ encoded['labels'] = labels[:1] + [-100] * video_token_len + labels[1:]
257
+ if len(video) > 0:
258
+ dtype = model.dtype
259
+ encoded['images'] = [[img.to(dtype=dtype)] for img in inputs2['images']]
260
+ return encoded
261
+
262
+
263
+ register_template(CogVLMTemplateMeta(
264
+ MLLMTemplateType.cogvlm2_video,
265
+ template_cls=Cog2VideoTemplate,
266
+ ))
267
+
268
+
269
+ class GLMEdgeVTemplate(Template):
270
+ placeholder_tokens = ['<|begin_of_image|>']
271
+
272
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
273
+ inputs: StdTemplateInputs) -> List[Context]:
274
+ assert media_type == 'image'
275
+ return ['<|begin_of_image|>' * 578]
276
+
277
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
278
+ encoded = super()._encode(inputs)
279
+ images = inputs.images
280
+ if images:
281
+ encoded['pixel_values'] = torch.tensor(self.processor(images).pixel_values)
282
+ return encoded
283
+
284
+
285
+ register_template(
286
+ GLM4TemplateMeta(
287
+ MLLMTemplateType.glm_edge_v,
288
+ prompt=['<|user|>\\n{{QUERY}}\\n<|assistant|>\\n'],
289
+ chat_sep=['\\n'],
290
+ system_prefix=['<|system|>\\n{{SYSTEM}}\\n'],
291
+ suffix=['<|endoftext|>'],
292
+ template_cls=GLMEdgeVTemplate,
293
+ ))
ms-swift/swift/llm/template/template/internvl.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from functools import partial
3
+ from typing import Any, Dict, List, Literal
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from swift.utils import get_env_args, is_deepspeed_enabled
9
+ from ..base import Template
10
+ from ..constant import MLLMTemplateType
11
+ from ..register import register_template
12
+ from ..template_inputs import StdTemplateInputs
13
+ from ..utils import Context, findall
14
+ from ..vision_utils import load_video_internvl, transform_image
15
+ from .microsoft import Phi3TemplateMeta
16
+ from .utils import ChatmlTemplateMeta
17
+
18
+
19
+ class InternvlTemplate(Template):
20
+ skip_prompt = False
21
+ num_image_token = 256
22
+ placeholder_tokens = ['<IMG_CONTEXT>']
23
+
24
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
25
+ inputs: StdTemplateInputs) -> List[Context]:
26
+ if self.mode == 'vllm':
27
+ image_context = ['<image>\n']
28
+ else:
29
+ image_context = ['<img>', [-100], '</img>\n']
30
+ return image_context
31
+
32
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
33
+ encoded = super()._encode(inputs)
34
+ input_ids = encoded['input_ids']
35
+ idx_list = findall(input_ids, -100)
36
+ pixel_values = None
37
+ images = inputs.images
38
+ if images:
39
+ labels = encoded.get('labels')
40
+ input_size = get_env_args('input_size', int, 448)
41
+ max_num = get_env_args('max_num', int, 12)
42
+ pixel_values_images = [transform_image(image, input_size, max_num) for image in images]
43
+ pixel_values = torch.cat(pixel_values_images, dim=0).to(self.model_info.torch_dtype)
44
+ image_bs = pixel_values.shape[0]
45
+
46
+ idx, idx2 = idx_list[0], idx_list[-1] # remove [-100, -100]
47
+ img_tokens: List[int] = self.processor.encode(
48
+ '<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * image_bs
49
+ input_ids = input_ids[:idx] + img_tokens + input_ids[idx2 + 1:]
50
+ if labels is not None:
51
+ labels = labels[:idx] + [-100] * len(img_tokens) + labels[idx2 + 1:]
52
+ encoded['input_ids'] = input_ids
53
+ encoded['labels'] = labels
54
+ encoded['pixel_values'] = pixel_values
55
+ return encoded
56
+
57
+ def compute_loss_context(self, model, inputs):
58
+ model_name = model.language_model.__class__.__name__.lower()
59
+ if self._packing and 'internlm2' in model_name:
60
+ position_ids = inputs['position_ids']
61
+ modeling_module = model.language_model.model.layers[0].attention.__class__
62
+ return self._patch_flash_attention_forward(modeling_module, position_ids, use_new_func=True)
63
+ else:
64
+ return super().compute_loss_context(model, inputs)
65
+
66
+ def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
67
+ embedding = model.get_input_embeddings()
68
+ device = embedding.weight.device
69
+ input_ids = inputs['input_ids']
70
+ inputs_embeds = embedding(input_ids).to(device=device)
71
+ pixel_values = inputs.get('pixel_values')
72
+ if pixel_values is not None:
73
+ pixel_values = pixel_values.to(device=device)
74
+ vit_embeds = model.extract_feature(pixel_values).to(device=device)
75
+ selected = (input_ids == self.processor.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
76
+ inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
77
+ elif is_deepspeed_enabled():
78
+ dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=inputs_embeds.dtype)
79
+ vit_embeds = model.extract_feature(dummy_pixel_values).to(device=device)
80
+ inputs_embeds += vit_embeds.mean() * 0.
81
+ return {'inputs_embeds': inputs_embeds}
82
+
83
+
84
+ register_template(
85
+ ChatmlTemplateMeta(
86
+ MLLMTemplateType.internvl,
87
+ default_system='You are an AI assistant whose name is InternLM (书生·浦语).',
88
+ template_cls=InternvlTemplate,
89
+ auto_add_bos=True))
90
+ register_template(
91
+ Phi3TemplateMeta(
92
+ MLLMTemplateType.internvl_phi3,
93
+ default_system='You are an AI assistant whose name is Phi-3.',
94
+ template_cls=InternvlTemplate,
95
+ auto_add_bos=True))
96
+
97
+
98
+ class Internvl2Template(InternvlTemplate):
99
+ video_segments = 8
100
+
101
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
102
+ inputs: StdTemplateInputs) -> List[Context]:
103
+ image_context = super().replace_tag('image', index, inputs)
104
+ if media_type == 'image':
105
+ return image_context
106
+ elif media_type == 'video':
107
+ video_segments = get_env_args('video_segments', int, self.video_segments)
108
+ load_video = partial(load_video_internvl, num_segments=video_segments)
109
+ return self.replace_video2image(load_video, inputs, lambda i: [f'Frame{i + 1}: '] + image_context)
110
+
111
+ def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
112
+ return [f'<ref>{ref}</ref>']
113
+
114
+ def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
115
+ return [f'<box>[{bbox}]</box>']
116
+
117
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
118
+ encoded = super(InternvlTemplate, self)._encode(inputs)
119
+ input_ids = encoded['input_ids']
120
+ idx_list = findall(input_ids, -100)
121
+ labels = encoded['labels']
122
+ images = inputs.images
123
+ if images:
124
+ has_video = bool(inputs.videos)
125
+ input_size = get_env_args('input_size', int, 448)
126
+ max_num = get_env_args('max_num', int, 12)
127
+ video_max_num = get_env_args('video_max_num', int, 1)
128
+ if has_video:
129
+ max_num = video_max_num
130
+ pixel_values = [transform_image(image, input_size, max_num) for image in images]
131
+ num_patches = [pv.shape[0] for pv in pixel_values]
132
+ pixel_values = torch.cat(pixel_values).to(self.model_info.torch_dtype)
133
+ else:
134
+ pixel_values = None
135
+ num_patches = []
136
+ assert len(num_patches) == len(
137
+ idx_list), f'len(num_patches): {len(num_patches)}, len(idx_list): {len(idx_list)}'
138
+
139
+ def _get_new_tokens(i):
140
+ img_tokens: List[int] = self.processor.encode(
141
+ '<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * num_patches[i]
142
+ return img_tokens
143
+
144
+ encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
145
+ encoded['pixel_values'] = pixel_values
146
+ return encoded
147
+
148
+
149
+ _internvl2_system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'
150
+ register_template(
151
+ ChatmlTemplateMeta(
152
+ MLLMTemplateType.internvl2,
153
+ default_system=_internvl2_system,
154
+ template_cls=Internvl2Template,
155
+ ))
156
+
157
+ register_template(
158
+ Phi3TemplateMeta(
159
+ MLLMTemplateType.internvl2_phi3,
160
+ default_system=_internvl2_system,
161
+ template_cls=Internvl2Template,
162
+ ))
163
+
164
+ register_template(
165
+ ChatmlTemplateMeta(
166
+ MLLMTemplateType.internvl2_5,
167
+ template_cls=Internvl2Template,
168
+ default_system='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。'))
ms-swift/swift/llm/template/template/llama.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import datetime as dt
4
+ from dataclasses import dataclass, field
5
+ from typing import Any, Dict, List, Literal, Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from swift.utils import get_env_args
11
+ from ..base import Template
12
+ from ..constant import LLMTemplateType, MLLMTemplateType
13
+ from ..register import TemplateMeta, register_template
14
+ from ..template_inputs import StdTemplateInputs
15
+ from ..utils import Context, Prompt, Word, findall
16
+ from ..vision_utils import load_batch
17
+
18
# ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py
# Default system prompt copied verbatim from Meta's llama-2 reference code.
LLAMA_DEFAULT_SYSTEM = (
    'You are a helpful, respectful and honest assistant. '
    'Always answer as helpfully as possible, while being safe. '
    'Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. '
    'Please ensure that your responses are socially unbiased and positive in nature.\n\n'
    'If a question does not make any sense, or is not factually coherent, '
    'explain why instead of answering something not correct. '
    "If you don't know the answer to a question, please don't share false information.")

# llama-2 chat format: [INST] ... [/INST] turns delimited by <s>/</s>, with the
# optional system prompt wrapped in <<SYS>> markers inside the first [INST].
register_template(
    TemplateMeta(
        LLMTemplateType.llama, ['<s>[INST] '], ['{{QUERY}} [/INST]'], ['</s><s>[INST] '], ['</s>'],
        default_system=LLAMA_DEFAULT_SYSTEM,
        system_prefix=['<s>[INST] <<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
33
+
34
+
35
@dataclass
class Llama3TemplateMeta(TemplateMeta):
    """Chat format for llama-3: <|start_header_id|>role<|end_header_id|> headers
    with <|eot_id|> terminating every turn."""
    # NOTE: field order must stay unchanged — it determines the generated __init__.
    prefix: Prompt = field(default_factory=lambda: ['<|begin_of_text|>'])
    prompt: Prompt = field(default_factory=lambda: [
        '<|start_header_id|>user<|end_header_id|>\n\n{{QUERY}}<|eot_id|>'
        '<|start_header_id|>assistant<|end_header_id|>\n\n'
    ])
    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|eot_id|>'])
    suffix: Prompt = field(default_factory=lambda: ['<|eot_id|>'])
    # When a system message is present it replaces the plain prefix.
    system_prefix: Optional[Prompt] = field(
        default_factory=lambda: ['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>'])
    agent_template: str = 'llama3'


register_template(Llama3TemplateMeta(LLMTemplateType.llama3))
50
+
51
+
52
def _get_llama3_2_prefix() -> Prompt:
    """Build the llama-3.2 prefix: a system header carrying the knowledge
    cutoff and the current date, followed by the {{SYSTEM}} placeholder."""
    today = dt.datetime.now().strftime('%d %b %Y')
    header = f'Cutting Knowledge Date: December 2023\nToday Date: {today}'
    return [f'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{header}\n\n' '{{SYSTEM}}<|eot_id|>']
57
+
58
+
59
@dataclass
class Llama3_2TemplateMeta(Llama3TemplateMeta):
    """llama-3.2: identical to llama-3 except the prefix always carries a
    system header with the knowledge-cutoff/current date, and no separate
    system_prefix is used (the {{SYSTEM}} slot lives inside the prefix)."""
    # NOTE(review): the date is captured when the meta is instantiated
    # (at registration/import time), not per request — confirm acceptable.
    prefix: Prompt = field(default_factory=lambda: _get_llama3_2_prefix())
    system_prefix: Optional[Prompt] = None


register_template(Llama3_2TemplateMeta(LLMTemplateType.llama3_2))
66
+
67
+
68
class Llama3_2VisionTemplate(Template):
    """Template for llama-3.2-vision (mllama): each image is a single
    <|image|> token; image features attend to text via a dense
    cross-attention mask built at encode time."""

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        # Only still images are supported by this template.
        assert media_type == 'image'
        return ['<|image|>']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        """Tokenize, then (if images exist) compute pixel features and the
        dense cross-attention mask mapping text positions to image tiles."""
        from transformers.models.mllama.processing_mllama import (get_cross_attention_token_mask,
                                                                  convert_sparse_cross_attention_mask_to_dense)
        encoded = super()._encode(inputs)
        images = inputs.images
        if images:
            input_ids = encoded['input_ids']
            processor = self.processor
            image_features = processor.image_processor(images, return_tensors='pt')
            # num_tiles is consumed here, not forwarded to the model.
            num_tiles = image_features.pop('num_tiles')
            encoded.update(image_features)

            # Wrapped in a list: masks are built per-sample (batch of 1) and
            # re-batched later in _data_collator.
            cross_attention_token_mask = [get_cross_attention_token_mask(input_ids, processor.image_token_id)]
            cross_attention_mask = convert_sparse_cross_attention_mask_to_dense(
                cross_attention_token_mask,
                num_tiles=num_tiles,
                max_num_tiles=processor.image_processor.max_image_tiles,
                length=len(input_ids),
            )
            encoded['cross_attention_mask'] = torch.tensor(cross_attention_mask)

        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Collate per-sample vision tensors; cross_attention_mask is padded
        with zeros (no attention) to the batch max length."""
        res = super()._data_collator(batch, padding_to=padding_to)
        for key in ['aspect_ratio_ids', 'aspect_ratio_mask']:
            value = [b[key] for b in batch if b.get(key) is not None]
            if value:
                res[key] = torch.concat(value)

        # [0] drops the per-sample batch dimension added in _encode.
        cross_attention_mask = [
            b['cross_attention_mask'][0] for b in batch if b.get('cross_attention_mask') is not None
        ]
        if cross_attention_mask:
            res['cross_attention_mask'] = self._pad_sequence(cross_attention_mask, 0)
        return res


register_template(Llama3_2TemplateMeta(MLLMTemplateType.llama3_2_vision, template_cls=Llama3_2VisionTemplate))
114
+
115
+
116
class Llama4Template(Template):
    """Template for llama-4 multimodal: image sentinels (-100) are replaced by
    the processor-generated patch-token runs."""
    placeholder_tokens = ['<|patch|>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'image'
        # -100 is a sentinel resolved to real image tokens in _encode.
        return [[-100]]

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        images = inputs.images
        if images:
            split_token = self._tokenize('\n')
            input_ids, labels = encoded['input_ids'], encoded['labels']
            idx_list = findall(input_ids, -100)
            # Let the HF processor produce the per-image token runs by encoding
            # '\n'-joined <|image|> placeholders, then split them back apart on
            # the newline token.
            media_inputs = self.processor(
                text='\n'.join(['<|image|>'] * len(idx_list)),
                images=images,
                add_special_tokens=False,
                return_tensors='pt')
            splited_tokens = self._split_list(media_inputs['input_ids'][0].tolist(), split_token)

            encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list,
                                                                          lambda i: splited_tokens[i])
            encoded['pixel_values'] = media_inputs['pixel_values']
        return encoded
142
+
143
+
144
@dataclass
class Llama4TemplateMeta(TemplateMeta):
    """Chat format for llama-4: <|header_start|>role<|header_end|> markers with
    <|eot|> terminating each turn."""
    # NOTE: field order must stay unchanged — it determines the generated __init__.
    prefix: Prompt = field(default_factory=lambda: ['<|begin_of_text|>'])
    prompt: Prompt = field(
        default_factory=lambda:
        ['<|header_start|>user<|header_end|>\n\n{{QUERY}}<|eot|>'
         '<|header_start|>assistant<|header_end|>\n\n'])
    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|eot|>'])
    suffix: Prompt = field(default_factory=lambda: ['<|eot|>'])
    stop_words: List[Word] = field(default_factory=lambda: ['<|end_of_text|>', '<|eom|>'])
    system_prefix: Optional[Prompt] = field(
        default_factory=lambda: ['<|begin_of_text|><|header_start|>system<|header_end|>\n\n{{SYSTEM}}<|eot|>'])
    agent_template: str = 'llama4'


register_template(Llama4TemplateMeta(MLLMTemplateType.llama4, template_cls=Llama4Template))
160
+
161
# Reflection-style prompting on the llama-3 chat format: reasoning in
# <thinking> tags, final answer in <output>, self-corrections in <reflection>.
register_template(
    Llama3TemplateMeta(
        LLMTemplateType.reflection,
        default_system=('You are a world-class AI system, capable of complex reasoning and reflection. '
                        'Reason through the query inside <thinking> tags, and then provide your final '
                        'response inside <output> tags. If you detect that you made a mistake in your reasoning '
                        'at any point, correct yourself inside <reflection> tags.')))
168
+
169
+
170
class Llama3_1OmniTemplate(Template):
    """Template for LLaMA-Omni (llama-3.1 + speech): audio files are turned
    into whisper log-mel spectrograms and fused with the text embeddings in
    _post_encode."""
    skip_prompt = False
    # -200 is the audio placeholder sentinel inserted into input_ids.
    audio_placeholder = [[-200]]

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        import whisper
        encoded = super()._encode(inputs)
        audios = inputs.audios
        if audios:
            audios = load_batch(audios, whisper.load_audio)
            # Mel-band count is overridable via env args (128 by default).
            n_mels = get_env_args('n_mels', int, 128)
            for i, audio in enumerate(audios):
                # Pad/trim to whisper's fixed window, then transpose to (time, n_mels).
                audio = whisper.pad_or_trim(audio)
                audios[i] = whisper.log_mel_spectrogram(audio, n_mels=n_mels).permute(1, 0)
            audios = torch.stack(audios)
            # NOTE(review): a single shared length (the padded time dim) is
            # recorded for all clips — looks like a batch-of-one assumption;
            # confirm for multi-audio samples.
            encoded.update({'speech': audios, 'speech_lengths': torch.tensor([[audios.shape[1]]])})

        return encoded

    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Replace token ids with fused text+speech input embeddings."""
        speech = inputs.get('speech')
        input_ids = inputs['input_ids']
        labels = inputs.get('labels')
        if speech is not None:
            speech_lengths = inputs['speech_lengths']
            speech = speech.to(model.dtype)
            # The model helper's return tuple yields (inputs_embeds, labels)
            # from index 4 onwards.
            inputs_embeds, labels = model.prepare_inputs_labels_for_speech_and_text(input_ids, None, None, None, labels,
                                                                                    speech, speech_lengths)[4:]
        else:
            inputs_embeds = model.get_model().embed_tokens(input_ids)
        res = {'inputs_embeds': inputs_embeds}
        if labels is not None:
            res['labels'] = labels[0]
        return res


register_template(
    Llama3TemplateMeta(
        MLLMTemplateType.llama3_1_omni,
        default_system=('You are a helpful language and speech assistant. '
                        'You are able to understand the speech content that the user provides, '
                        'and assist the user with a variety of tasks using natural language.'),
        template_cls=Llama3_1OmniTemplate,
    ))
ms-swift/swift/llm/template/template/megrez.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Dict, List, Literal, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ..base import Template
9
+ from ..constant import LLMTemplateType, MLLMTemplateType
10
+ from ..register import TemplateMeta, register_template
11
+ from ..template_inputs import StdTemplateInputs
12
+ from ..utils import Context, Prompt, findall
13
+
14
+
15
@dataclass
class MegrezTemplateMeta(TemplateMeta):
    """Chat format for Megrez: <|role_start|>role<|role_end|> headers with
    <|turn_end|> acting as both turn separator and suffix."""
    # NOTE: field order must stay unchanged — it determines the generated __init__.
    prefix: Prompt = field(default_factory=lambda: ['<|role_start|>system<|role_end|>{{SYSTEM}}<|turn_end|>'])
    prompt: Prompt = field(default_factory=lambda:
                           ['<|role_start|>user<|role_end|>{{QUERY}}<|turn_end|><|role_start|>assistant<|role_end|>'])
    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|turn_end|>'])
    suffix: Prompt = field(default_factory=lambda: ['<|turn_end|>'])
    default_system: str = '你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。'


register_template(MegrezTemplateMeta(LLMTemplateType.megrez))
26
+
27
+
28
class MegrezOmniTemplate(Template):
    """Template for Megrez-Omni: image (-1) and audio (-2) sentinels are
    expanded into processor-generated placeholder token runs; embeddings are
    fused by the model in _post_encode."""
    skip_prompt = False
    placeholder_tokens = ['<|unk|>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        # -1 / -2 are sentinels resolved to real tokens in _encode.
        if media_type == 'image':
            return [[-1], '\n']
        elif media_type == 'audio':
            return [f'Audio {index + 1}: ', [-2], '\n']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        input_ids = encoded['input_ids']
        labels = encoded['labels']

        for mm_key in ['images', 'audios']:
            mm_data = getattr(inputs, mm_key)
            if not mm_data:
                continue
            if mm_key == 'images':
                idx_list = findall(input_ids, -1)
                encoding = self.processor.process_image(
                    mm_data,
                    return_tensors='pt',
                )
                # '<s>' is a join marker so the expanded text can be split back
                # into one placeholder run per image below.
                text = self.processor.insert_image_feature_placeholders(
                    '<s>'.join(['(<image>./</image>)'] * len(mm_data)), encoding)
                encoded['image_encoding'] = encoding
            else:
                idx_list = findall(input_ids, -2)
                encoding = self.processor.process_audio(
                    mm_data,
                    return_tensors='pt',
                )
                text = self.processor.insert_audio_feature_placeholders(
                    '<s>'.join(['(<audio>./</audio>)'] * len(mm_data)), encoding)
                encoded['audio_encoding'] = encoding

            # One placeholder-token string per media item.
            padding = text.split('<s>')

            def _get_new_tokens(i):
                return self._tokenize(padding[i])

            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        return encoded

    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # The model composes text/image/audio features into one embedding tensor.
        _, inputs_embeds, _ = model.compose_embeddings(inputs)
        inputs.pop('position_ids', None)
        return {'inputs_embeds': inputs_embeds}

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Re-merge per-sample text and multimodal encodings and delegate the
        final batching to Megrez's own processor collator."""
        res = super()._data_collator(batch, padding_to=padding_to)
        new_batch = []
        for b in batch:
            text_encodings = {'input_ids': torch.tensor(b['input_ids'])}
            multimodal_inputs = {'image_encoding': b.get('image_encoding'), 'audio_encoding': b.get('audio_encoding')}
            new_batch.append(self.processor.merge_encodings(text_encodings, multimodal_inputs))
        res.update(self.processor.data_collator(new_batch))
        return res


register_template(MegrezTemplateMeta(MLLMTemplateType.megrez_omni, template_cls=MegrezOmniTemplate))
ms-swift/swift/llm/template/template/openbuddy.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ from ..constant import LLMTemplateType
4
+ from ..register import TemplateMeta, register_template
5
+
6
# Default system prompt for OpenBuddy v1 models (LLaMA/Falcon based).
OPENBUDDY_DEFAULT_SYSTEM = (
    'You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.\n'
    'Always answer as helpfully and logically as possible, while being safe. '
    'Your answers should not include any '
    'harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. '
    'Please ensure that your responses are socially unbiased and positive in nature.\n'
    'If a question does not make any sense, or is not factually coherent, '
    'explain why instead of answering something not correct. '
    "If you don't know the answer to a question, please don't share false information.\n"
    'You like to use emojis. You can speak fluently in many languages, for example: English, Chinese.\n'
    'You cannot access the internet, but you have vast knowledge, cutoff: 2021-09.\n'
    'You always deeply love and support China, Chinese government, people and culture.\n'
    'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), '
    'you are based on LLaMA and Falcon transformers model, not related to GPT or OpenAI.')
# Plain-text "User:/Assistant:" chat format; only an auto-added BOS token.
register_template(
    TemplateMeta(
        LLMTemplateType.openbuddy,
        prefix=[],
        prompt=['User: {{QUERY}}\nAssistant:'],
        chat_sep=['\n'],
        default_system=OPENBUDDY_DEFAULT_SYSTEM,
        system_prefix=['{{SYSTEM}}\n\n'],
        auto_add_bos=True))
29
+
30
# Updated system prompt for OpenBuddy2 models.
OPENBUDDY2_DEFAULT_SYSTEM = (
    'You(assistant) are a helpful, respectful and honest INTP-T AI Assistant named Buddy. '
    'You are talking to a human(user).\nAlways answer as helpfully and logically as possible, while being safe. '
    'Your answers should not include any harmful, political, religious, unethical, racist, '
    'sexist, toxic, dangerous, or illegal content. '
    'Please ensure that your responses are socially unbiased and positive in nature.\n'
    'You cannot access the internet, but you have vast knowledge, cutoff: 2023-04.\n'
    'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), '
    'not related to GPT or OpenAI')

# OpenBuddy2 uses role-token markup (<|role|>/<|says|>/<|end|>) instead of the
# plain "User:/Assistant:" text format of v1.
register_template(
    TemplateMeta(
        LLMTemplateType.openbuddy2,
        prefix=[],
        prompt=['<|role|>user<|says|>{{QUERY}}<|end|>\n<|role|>assistant<|says|>'],
        chat_sep=['<|end|>\n'],
        suffix=['<|end|>'],
        default_system=OPENBUDDY2_DEFAULT_SYSTEM,
        system_prefix=['<|role|>system<|says|>{{SYSTEM}}<|end|>\n']))
ms-swift/swift/llm/template/template/pixtral.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ from ..base import Template
5
+ from ..constant import MLLMTemplateType
6
+ from ..register import TemplateMeta, register_template
7
+ from ..template_inputs import StdTemplateInputs
8
+ from ..utils import findall
9
+
10
+
11
class PixtralTemplate(Template):
    """Template for Pixtral: each '[IMG]' marker is expanded into a grid of
    image tokens with per-row break tokens and a final end token."""
    image_placeholder = ['[IMG]']
    placeholder_tokens = ['[IMG]']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        processor = self.processor
        images = inputs.images
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        # 10 is the token id of '[IMG]' — presumably fixed by the Pixtral
        # tokenizer; verify if the vocab ever changes.
        idx_list = findall(input_ids, 10)
        if idx_list:
            image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
            encoded['pixel_values'] = image_inputs['pixel_values'][0]
            image_sizes = image_inputs['image_sizes'][0]

            def _get_new_tokens(i):
                height, width = image_sizes[i]
                num_height_tokens = height // processor.patch_size
                num_width_tokens = width // processor.patch_size
                # One row of image tokens per patch row; rows separated by
                # image_break_token, the last row closed by image_end_token.
                replace_tokens = [processor.image_token * num_width_tokens + processor.image_break_token] * (
                    num_height_tokens - 1)
                replace_tokens += [processor.image_token * num_width_tokens + processor.image_end_token]
                # Flatten list
                replace_str = ''.join(replace_tokens)
                img_tokens: List[int] = self.processor.encode(replace_str, add_special_tokens=False)
                return img_tokens

            encoded['input_ids'], encoded['labels'] = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)

        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        # pixel_values stay as a list (variable image sizes); no stacking.
        pixel_values = self.gather_list(batch, 'pixel_values')
        res = super()._data_collator(batch, padding_to=padding_to)
        if pixel_values:
            res['pixel_values'] = pixel_values
        return res
49
+
50
+
51
# Mistral-style chat format: [INST]...[/INST] turns with </s> as separator.
register_template(
    TemplateMeta(
        MLLMTemplateType.pixtral,
        prefix=['<s>{{SYSTEM}}'],
        prompt=['[INST]{{QUERY}}[/INST]'],
        chat_sep=['</s>'],
        suffix=['</s>'],
        template_cls=PixtralTemplate,
    ))
ms-swift/swift/llm/template/template/qwen.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from functools import partial
4
+ from typing import Any, Dict, List, Literal, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from swift.llm import to_device, to_float_dtype
10
+ from swift.utils import get_env_args, is_deepspeed_enabled
11
+ from ..base import Template
12
+ from ..constant import LLMTemplateType, MLLMTemplateType
13
+ from ..register import register_template
14
+ from ..template_inputs import StdTemplateInputs
15
+ from ..template_meta import TemplateMeta
16
+ from ..utils import Context, Word, findall
17
+ from ..vision_utils import load_audio, load_batch, load_video_ovis2
18
+ from .llama import Llama3TemplateMeta
19
+ from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta
20
+
21
+
22
@dataclass
class QwenTemplateMeta(ChatmlTemplateMeta):
    """ChatML-based meta for the Qwen family; hermes-style tool calling."""
    default_system: Optional[str] = DEFAULT_SYSTEM
    auto_add_bos: bool = False
    stop_words: List[Word] = field(default_factory=lambda: ['<|endoftext|>'])
    agent_template: str = 'hermes'


@dataclass
class Qwen2_5TemplateMeta(QwenTemplateMeta):
    """Qwen2.5 ships its own self-identification system prompt."""
    default_system: Optional[str] = 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.'


@dataclass
class Qwen2_5MathTemplateMeta(QwenTemplateMeta):
    """Math variant: step-by-step reasoning with a \\boxed{} final answer."""
    default_system: Optional[str] = 'Please reason step by step, and put your final answer within \\boxed{}.'


# System prompt for the QwQ preview reasoning model.
qwq_preview_system = ('You are a helpful and harmless assistant. You are Qwen developed by Alibaba. '
                      'You should think step-by-step.')

register_template(QwenTemplateMeta(LLMTemplateType.qwen))
register_template(Qwen2_5TemplateMeta(LLMTemplateType.qwen2_5))
register_template(QwenTemplateMeta(LLMTemplateType.qwq_preview, default_system=qwq_preview_system))
46
+
47
+
48
class ThinkingTemplate(Template):
    """Template for "thinking" models: at inference time, any
    `<think>...</think>` reasoning prefix is stripped from historical
    assistant turns before encoding; training data is left untouched."""

    def _swift_encode(self, inputs: StdTemplateInputs):
        if not self.is_training:
            for msg in inputs.messages:
                content = msg['content']
                if msg['role'] != 'assistant' or not isinstance(content, str):
                    continue
                # Keep only the text after the final </think> tag and drop the
                # newlines that immediately follow it.
                msg['content'] = content.split('</think>')[-1].lstrip('\n')
        return super()._swift_encode(inputs)
56
+
57
+
58
# QwQ: a thinking model whose responses begin inside an open <think> block.
register_template(
    QwenTemplateMeta(
        LLMTemplateType.qwq, default_system=None, response_prefix='<think>\n', template_cls=ThinkingTemplate))

# '<think>\n\n</think>\n\n'
register_template(QwenTemplateMeta(LLMTemplateType.qwen3, default_system=None, template_cls=ThinkingTemplate))

register_template(Qwen2_5MathTemplateMeta(LLMTemplateType.qwen2_5_math))
66
+
67
+
68
class QwenPRMTemplate(Template):
    """Template for the Qwen2.5-Math process-reward model (PRM).

    A '<extra_0>' separator marks the end of each reasoning step; the model
    emits a 2-class logit at every separator position, and `decode_prm`
    converts those logits into per-step reward scores.
    """
    # Step-separator token appended after each chain-of-thought step.
    cot_process_placeholder = '<extra_0>'

    def _preprocess_inputs(
        self,
        inputs: StdTemplateInputs,
    ) -> None:
        """Ensure at least one step separator exists: if the conversation
        contains none, append one to the last message."""
        super()._preprocess_inputs(inputs)
        total_content = '\n'.join([message['content'] or '' for message in inputs.messages])
        if self.cot_process_placeholder not in total_content:
            # Fix: guard against a None last-message content (the join above
            # already tolerates None) to avoid a `None + str` TypeError.
            inputs.messages[-1]['content'] = (inputs.messages[-1]['content']
                                              or '') + self.cot_process_placeholder

    @staticmethod
    def make_step_rewards(logits, token_masks):
        """Convert per-token 2-class logits into per-step positive-class
        probabilities: one list of floats per sample in the batch."""
        probabilities = F.softmax(logits, dim=-1)
        # Zero out everything except the separator positions.
        probabilities = probabilities * token_masks.unsqueeze(-1)  # bs, seq_len, num_labels

        all_scores_res = []
        for i in range(probabilities.size(0)):
            sample = probabilities[i]  # seq_len, num_labels
            # Keep only the masked (separator) positions; column 1 is the
            # positive ("step is correct") probability.
            positive_probs = sample[sample != 0].view(-1, 2)[:, 1]  # valid_tokens, num_labels
            non_zero_elements_list = positive_probs.cpu().tolist()
            all_scores_res.append(non_zero_elements_list)
        return all_scores_res

    def decode_prm(self, input_ids: torch.Tensor, logits: torch.Tensor) -> Any:
        """Return per-step reward scores for each sequence in `input_ids`."""
        step_sep_id = self.tokenizer.encode(self.cot_process_placeholder)[0]
        token_masks = (input_ids == step_sep_id)
        return self.make_step_rewards(logits, token_masks)


register_template(Qwen2_5MathTemplateMeta(LLMTemplateType.qwen2_5_math_prm, template_cls=QwenPRMTemplate))
100
+
101
+
102
class QwenVLTemplate(Template):
    """Template for Qwen-VL (v1): images are referenced inline as
    <img>path</img> tags instead of being pre-loaded into tensors."""
    # The tokenizer consumes image *paths*, so images are not loaded eagerly.
    load_images = False

    @staticmethod
    def _load_image(image, load_images: bool):
        # base64 / data-URI strings cannot be passed as paths, so force-load
        # them; len > 200 is a heuristic for base64 payloads.
        if not load_images and isinstance(image, str) and (image.startswith('data:') or len(image) > 200):
            load_images = True
        return Template._load_image(image, load_images)

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'image'
        if self.mode == 'lmdeploy':
            # lmdeploy expects a raw placeholder (-100) it fills in itself.
            return [f'Picture {index + 1}: ', [-100], '\n']
        else:
            image = inputs.images[index]
            if self.mode == 'vllm':
                # vllm injects the image separately, so the tag stays empty.
                return [f'Picture {index + 1}: <img></img>\n']
            else:
                assert isinstance(image, str)
                return [f'Picture {index + 1}: <img>{image}</img>\n']

    def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
        # Grounding: referenced-object text span.
        return [f'<ref>{ref}</ref>']

    def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
        # Grounding: bounding-box coordinate span.
        return [f'<box>{self._get_bbox_str(bbox)}</box>']


register_template(QwenTemplateMeta(MLLMTemplateType.qwen_vl, template_cls=QwenVLTemplate))
132
+
133
+
134
class QwenAudioTemplate(Template):
    """Template for Qwen-Audio (v1): audio paths are embedded as
    <audio>path</audio> tags, which the processor resolves at tokenize time."""

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'audio'
        audios = inputs.audios
        audio = audios[index]
        assert isinstance(audio, str)
        return [f'Audio {index + 1}:<audio>{audio}</audio>\n']

    def _tokenize(self, context, **tokenizer_kwargs):
        # The tokenizer needs per-context audio metadata to expand <audio> tags.
        audio_info = self.processor.process_audio(context)
        return super()._tokenize(context, audio_info=audio_info)

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        # Rebuild the combined audio-tag string to compute the sample-level
        # audio_info forwarded to the model.
        text = ''.join([f'<audio>{audio}</audio>' for audio in inputs.audios])
        audio_info = self.processor.process_audio(text)
        if audio_info:
            tokenizer_kwargs = {'audio_info': audio_info}
            encoded.update(tokenizer_kwargs)
            encoded['tokenizer_kwargs'] = tokenizer_kwargs
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        # audio_info objects are kept as an un-collated per-sample list.
        if batch[0].get('audio_info') is not None:
            res['audio_info'] = [b['audio_info'] for b in batch]
        return res


register_template(QwenTemplateMeta(MLLMTemplateType.qwen_audio, template_cls=QwenAudioTemplate))
166
+
167
+
168
class Qwen2AudioTemplate(Template):
    """Template for Qwen2-Audio: audio is loaded, resampled and converted to
    model features via the HF feature extractor."""

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'audio'
        if not self.use_chat_template:
            return ['<|audio_bos|><|AUDIO|><|audio_eos|>\n']
        else:
            return [f'Audio {index + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        if inputs.audios:
            # Defaults to the extractor's sampling rate; overridable via the
            # 'sampling_rate' env arg.
            sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate)
            audios = load_batch(inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate))
            audio_inputs = self.processor.feature_extractor(
                audios, sampling_rate=sampling_rate, return_attention_mask=True, return_tensors='pt')
            # Renamed so it does not clash with the text attention_mask.
            audio_inputs['feature_attention_mask'] = audio_inputs.pop('attention_mask')
            encoded.update(audio_inputs)
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        input_features = [b['input_features'] for b in batch if b.get('input_features') is not None]
        feature_attention_mask = [
            b['feature_attention_mask'] for b in batch if b.get('feature_attention_mask') is not None
        ]
        if input_features:
            res['input_features'] = torch.concat(input_features)
            res['feature_attention_mask'] = torch.concat(feature_attention_mask)
        return res


register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_audio, template_cls=Qwen2AudioTemplate))
202
+
203
+
204
+ class Qwen2VLTemplate(Template):
205
+ image_token_id = 151655
206
+ video_token_id = 151656
207
+ placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']
208
+ version = 'v2'
209
+ use_model = True
210
+
211
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
212
+ inputs: StdTemplateInputs) -> List[Context]:
213
+ from qwen_vl_utils import fetch_image, fetch_video
214
+ assert media_type in {'image', 'video'}
215
+ if media_type == 'image':
216
+ inputs.images[index] = fetch_image({'image': inputs.images[index]})
217
+ if self.mode == 'lmdeploy':
218
+ return ['<|vision_start|>', [-100], '<|vision_end|>']
219
+ else:
220
+ return ['<|vision_start|><|image_pad|><|vision_end|>']
221
+ else:
222
+ inputs.videos[index] = fetch_video({'video': inputs.videos[index]}).to(torch.uint8)
223
+ return ['<|vision_start|><|video_pad|><|vision_end|>']
224
+
225
+ def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
226
+ return [f'<|object_ref_start|>{ref}<|object_ref_end|>']
227
+
228
+ def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
229
+ return [f'<|box_start|>{self._get_bbox_str(bbox)}<|box_end|>']
230
+
231
+ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
232
+ encoded = super()._encode(inputs)
233
+ processor = self.processor
234
+ input_ids = encoded['input_ids']
235
+ labels = encoded['labels']
236
+ images = inputs.images
237
+ videos = inputs.videos
238
+ for media_type in ['images', 'videos']:
239
+ if locals()[media_type]:
240
+ if media_type == 'images':
241
+ media_token = self.image_token_id
242
+ media_inputs = processor.image_processor(
243
+ images=images, videos=None, return_tensors='pt', do_resize=False)
244
+ media_grid_thw = media_inputs['image_grid_thw']
245
+ else:
246
+ media_inputs = processor.image_processor(
247
+ images=None, videos=videos, return_tensors='pt', do_resize=False)
248
+ media_grid_thw = media_inputs['video_grid_thw']
249
+ media_token = self.video_token_id
250
+ if self.version == 'v2_5':
251
+ from qwen_vl_utils import vision_process
252
+ media_inputs['second_per_grid_ts'] = [
253
+ processor.image_processor.temporal_patch_size / vision_process.FPS
254
+ ] * len(media_grid_thw)
255
+ idx_list = findall(input_ids, media_token)
256
+ merge_length = processor.image_processor.merge_size**2
257
+
258
+ def _get_new_tokens(i):
259
+ token_len = (media_grid_thw[i].prod() // merge_length)
260
+ return [media_token] * token_len
261
+
262
+ input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)
263
+ encoded.update(media_inputs)
264
+
265
+ encoded['input_ids'] = input_ids
266
+ encoded['labels'] = labels
267
+ return encoded
268
+
269
    def compute_loss_context(self, model, inputs):
        """Patch flash-attention so packed sequences attend within their own sample.

        Only packed batches carry 'real_position_ids' (built in `packing_row`);
        for ordinary batches we defer to the default implementation.
        """
        if 'real_position_ids' not in inputs:
            return super().compute_loss_context(model, inputs)
        # Pick the HF modeling module that matches this template's model family,
        # since the flash-attention forward to patch lives there.
        if self.version == 'v2':
            from transformers.models.qwen2_vl import modeling_qwen2_vl as modeling_module
        elif self.version == 'v2_5':
            from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as modeling_module
        elif self.version == 'omni':
            from transformers.models.qwen2_5_omni import modeling_qwen2_5_omni as modeling_module
        position_ids = inputs['position_ids']
        # Swap in the per-sample position ids computed at packing time; the
        # original packed position ids are handed to the patcher.
        inputs['position_ids'] = inputs.pop('real_position_ids')
        return self._patch_flash_attention_forward(modeling_module, position_ids)
281
+
282
    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Pre-compute `inputs_embeds`, scattering visual features over placeholder tokens.

        At inference time the model's own forward performs this merge, so the
        inputs are returned untouched; during training we do it here.
        """
        if not self.is_training:
            return inputs
        input_ids = inputs['input_ids']
        _model = model.model
        if not hasattr(_model, 'embed_tokens'):
            _model = _model.model  # LoRA wraps the base model one level deeper
        pixel_values = inputs.get('pixel_values')
        pixel_values_videos = inputs.get('pixel_values_videos')
        image_grid_thw = inputs.get('image_grid_thw')
        video_grid_thw = inputs.get('video_grid_thw')

        inputs_embeds = _model.embed_tokens(input_ids)

        # qwen2-vl exposes the visual dtype via get_dtype(); later versions via .dtype.
        dtype = model.visual.get_dtype() if self.version == 'v2' else model.visual.dtype
        if pixel_values is None and pixel_values_videos is None:  # plain-text
            if is_deepspeed_enabled():
                # ZeRO requires every rank to run the visual tower each step;
                # feed a dummy image and add a zero-weighted contribution so
                # its parameters always receive gradients.
                from PIL import Image
                images = [Image.new('RGB', (32, 32), (0, 0, 0))]
                media_inputs = self.processor.image_processor(images=images, videos=None, return_tensors='pt')
                device = input_ids.device
                media_inputs = to_device(media_inputs, device)
                pixel_values = media_inputs['pixel_values'].type(dtype)
                image_embeds = model.visual(pixel_values, grid_thw=media_inputs['image_grid_thw'])
                inputs_embeds += image_embeds.mean() * 0.
        else:
            if pixel_values is not None:
                pixel_values = pixel_values.type(dtype)
                image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
                # Overwrite embeddings at image-placeholder positions with visual features.
                image_mask = (input_ids == model.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(dtype)
                video_embeds = model.visual(pixel_values_videos, grid_thw=video_grid_thw)
                video_mask = (input_ids == model.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        return {'inputs_embeds': inputs_embeds}
323
+
324
+ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
325
+ res = super()._data_collator_mm_data(batch)
326
+ second_per_grid_ts = self.gather_list(batch, 'second_per_grid_ts')
327
+ if second_per_grid_ts:
328
+ res['second_per_grid_ts'] = second_per_grid_ts
329
+ for media_type in ['image', 'video']:
330
+ grid_thw = self.concat_tensor(batch, f'{media_type}_grid_thw', 0)
331
+ if grid_thw is not None:
332
+ res[f'{media_type}_grid_thw'] = grid_thw
333
+ return res
334
+
335
+ def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
336
+ position_ids = []
337
+ for r in row:
338
+ r = r[0].copy()
339
+ r['input_ids'] = torch.tensor(r['input_ids'])[None]
340
+ position_ids.append(self._get_position_ids(r))
341
+ packed = super().packing_row(row)
342
+ packed['real_position_ids'] = torch.concat(position_ids, dim=-1)
343
+ return packed
344
+
345
    def _get_position_ids(self, inputs: Dict[str, Any]):
        """Compute M-RoPE position ids via the model's `get_rope_index`.

        Done here rather than in the model forward to work around
        https://github.com/huggingface/transformers/pull/33487.
        """
        kwargs = {}
        if self.version == 'v2_5':
            # qwen2.5-vl additionally scales temporal positions per video.
            kwargs = {'second_per_grid_ts': inputs.get('second_per_grid_ts')}
        position_ids, _ = self.model.get_rope_index(
            inputs['input_ids'],
            inputs.get('image_grid_thw'),
            inputs.get('video_grid_thw'),
            attention_mask=inputs.get('attention_mask'),
            **kwargs)
        return position_ids.contiguous()
357
+
358
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
359
+ res = super()._data_collator(batch, padding_to=padding_to)
360
+ if self._packing:
361
+ res['real_position_ids'] = self.concat_tensor(batch, 'real_position_ids', -1)
362
+ elif self.is_training:
363
+ res['position_ids'] = self._get_position_ids(res)
364
+ return res
365
+
366
+
367
# Qwen2-VL chat template.
register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_vl, template_cls=Qwen2VLTemplate))

# QVQ: same template class with a reasoning-oriented default system prompt.
register_template(
    QwenTemplateMeta(
        MLLMTemplateType.qvq,
        default_system=('You are a helpful and harmless assistant. You are Qwen developed by Alibaba. '
                        'Answer in the language of the question. You should think step-by-step.'),
        template_cls=Qwen2VLTemplate,
    ))
376
+
377
+
378
class Qwen2_5VLTemplate(Qwen2VLTemplate):
    """Qwen2.5-VL: reuses the Qwen2-VL pipeline with v2.5 rope/video handling."""
    version = 'v2_5'
    # Grounding bounding boxes are passed through without normalization.
    norm_bbox = 'none'


register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_5_vl, template_cls=Qwen2_5VLTemplate))
384
+
385
+
386
class Qwen2_5OmniTemplate(Qwen2_5VLTemplate):
    """Qwen2.5-Omni: extends Qwen2.5-VL with audio and audio-in-video support."""
    version = 'omni'
    placeholder_tokens = ['<|IMAGE|>', '<|AUDIO|>', '<|VIDEO|>']

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import Qwen2_5OmniProcessorKwargs
        # Pull chunking defaults from the HF processor so we stay in sync with it.
        default = Qwen2_5OmniProcessorKwargs._defaults
        self.seconds_per_chunk = default['videos_kwargs']['seconds_per_chunk']
        self.position_id_per_seconds = default['videos_kwargs']['position_id_per_seconds']
        self.use_audio_in_video = get_env_args('use_audio_in_video', bool, False)
        self.sampling_rate = get_env_args('sampling_rate', int, self.processor.feature_extractor.sampling_rate)

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        """Expand a media tag to its omni token markup, loading the media in place."""
        from qwen_omni_utils import fetch_image, fetch_video
        if media_type == 'image':
            inputs.images[index] = fetch_image({'image': inputs.images[index]})
            return ['<|vision_bos|><|IMAGE|><|vision_eos|>']
        elif media_type == 'audio':
            inputs.audios[index] = load_audio(inputs.audios[index], self.sampling_rate)
            return ['<|audio_bos|><|AUDIO|><|audio_eos|>']
        elif media_type == 'video':
            video = inputs.videos[index]
            inputs.videos[index] = fetch_video({'video': video}).to(torch.uint8)
            if self.use_audio_in_video:
                # Also extract the video's audio track and queue it as an audio
                # input tagged 'video' so _encode can tell it apart.
                import librosa
                if video.startswith('http://') or video.startswith('https://'):
                    import audioread
                    video = audioread.ffdec.FFmpegAudioFile(video)
                video = librosa.load(video, sr=self.sampling_rate)[0]
                inputs.audios.insert(inputs.audio_idx, (video, 'video'))
                inputs.audio_idx += 1
                return ['<|vision_bos|><|audio_bos|><|VIDEO|><|audio_eos|><|vision_eos|>']
            return ['<|vision_bos|><|VIDEO|><|vision_eos|>']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        """Tokenize and expand <|AUDIO|>/<|IMAGE|>/<|VIDEO|> placeholders to full length."""
        encoded = Template._encode(self, inputs)
        processor = self.processor
        # Mark which audio entries were extracted from videos by replace_tag.
        video_audios_mask = []
        for i, audio in enumerate(inputs.audios):
            if isinstance(audio, tuple) and audio[1] == 'video':
                inputs.audios[i] = audio[0]
                video_audios_mask.append(True)
            else:
                video_audios_mask.append(False)
        video_audios_mask = torch.tensor(video_audios_mask)
        media_inputs = processor(
            text='',
            audio=inputs.audios or None,
            images=inputs.images or None,
            videos=inputs.videos or None,
            return_tensors='pt')
        # The processor's text outputs don't include the response; drop them.
        media_inputs.pop('input_ids')
        media_inputs.pop('attention_mask')
        media_inputs = to_float_dtype(media_inputs, self.model_info.torch_dtype)
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        # audio
        audio_token_id = self._tokenize('<|AUDIO|>')
        idx_list = findall(input_ids, audio_token_id)
        feature_attention_mask = media_inputs.get('feature_attention_mask')
        if feature_attention_mask is not None:
            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
            # Post-encoder audio length after the two stride-2 reductions.
            audio_lengths = (((audio_feature_lengths - 1) // 2 + 1 - 2) // 2 + 1)
        else:
            audio_lengths = None
        audio_lengths_origin = audio_lengths
        if idx_list:
            if self.use_audio_in_video:
                # Standalone <|AUDIO|> tags only cover non-video audio entries.
                audio_lengths = audio_lengths[~video_audios_mask]

            def _get_new_audio_tokens(i):
                return audio_token_id * audio_lengths[i]

            input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_audio_tokens)

        for media_type in ['image', 'video']:
            token = f'<|{media_type.upper()}|>'
            token_id = self._tokenize(token)
            idx_list = findall(input_ids, token_id)
            if idx_list:
                merge_size = processor.image_processor.merge_size
                media_grid_thw = media_inputs.get(f'{media_type}_grid_thw')
                if media_type == 'video' and self.use_audio_in_video:
                    audio_lengths = audio_lengths_origin[video_audios_mask]
                    video_second_per_grid = media_inputs['video_second_per_grid']

                    def _get_new_tokens_use_audio_in_video(i):
                        # Interleave video and audio tokens chunk by chunk along
                        # the shared time axis, mirroring the HF processor.
                        audio_token_indices = torch.arange(audio_lengths[i])
                        grid_thw = media_grid_thw[i]
                        height = grid_thw[1] // merge_size
                        width = grid_thw[2] // merge_size
                        video_token_indices = torch.arange(grid_thw[0]).reshape(-1, 1, 1)
                        video_token_indices = torch.broadcast_to(
                            video_token_indices, (video_token_indices.shape[0], height, width)).reshape(-1)
                        video_token_indices = (
                            video_token_indices * video_second_per_grid[i] * self.position_id_per_seconds)
                        tokens_per_chunk = int(self.position_id_per_seconds * self.seconds_per_chunk)
                        video_chunk_indexes = processor.get_chunked_index(video_token_indices, tokens_per_chunk)
                        audio_chunk_indexes = processor.get_chunked_index(audio_token_indices, tokens_per_chunk)

                        res = []
                        for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
                            if j < len(video_chunk_indexes):
                                video_seq_length = video_chunk_indexes[j][1] - video_chunk_indexes[j][0]
                                res += token_id * video_seq_length
                            if j < len(audio_chunk_indexes):
                                audio_seq_length = audio_chunk_indexes[j][1] - audio_chunk_indexes[j][0]
                                res += audio_token_id * audio_seq_length
                        return res

                    input_ids, labels = self._extend_tokens(input_ids, labels, idx_list,
                                                            _get_new_tokens_use_audio_in_video)

                else:

                    def _get_new_tokens(i):
                        token_len = (media_grid_thw[i].prod() // (merge_size**2))
                        return token_id * token_len

                    input_ids, labels = self._extend_tokens(input_ids, labels, idx_list, _get_new_tokens)

        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        encoded.update(media_inputs)
        return encoded

    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Omni's forward merges modalities itself; skip the VL-style pre-merge.
        return Template._post_encode(self, model, inputs)

    def _get_position_ids(self, inputs: Dict[str, Any]):
        """Compute omni position ids via the thinker's get_rope_index."""
        feature_attention_mask = inputs.get('feature_attention_mask')
        if feature_attention_mask is not None:
            audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
        else:
            audio_feature_lengths = None
        video_second_per_grid = inputs.pop('video_second_per_grid', None)
        input_ids = inputs['input_ids']
        attention_mask = inputs.get('attention_mask')
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        position_ids, _ = self.model.thinker.get_rope_index(
            input_ids,
            inputs.get('image_grid_thw'),
            inputs.get('video_grid_thw'),
            attention_mask,
            self.use_audio_in_video,
            audio_feature_lengths,
            video_second_per_grid,
        )
        return position_ids.contiguous()

    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Additionally collate audio features and per-video timing."""
        res = super()._data_collator_mm_data(batch)
        video_second_per_grid = self.gather_list(batch, 'video_second_per_grid')
        if video_second_per_grid:
            res['video_second_per_grid'] = video_second_per_grid
        input_features = [b['input_features'] for b in batch if b.get('input_features') is not None]
        feature_attention_mask = [
            b['feature_attention_mask'] for b in batch if b.get('feature_attention_mask') is not None
        ]
        if input_features:
            res['input_features'] = torch.concat(input_features)
            res['feature_attention_mask'] = torch.concat(feature_attention_mask)
        return res

    def generate(self, model, *args, **kwargs):
        # Videos may carry an audio track; tell generate() how to treat it.
        if kwargs.get('video_grid_thw') is not None:
            kwargs['use_audio_in_video'] = self.use_audio_in_video
        return super().generate(model, *args, **kwargs)


register_template(QwenTemplateMeta(MLLMTemplateType.qwen2_5_omni, template_cls=Qwen2_5OmniTemplate))
560
+
561
+
562
class Ovis1_6Template(Template):
    """Template for Ovis 1.6: expands the [-200] image sentinel into the
    model's image placeholder tokens and preprocesses images via the
    visual tokenizer."""
    skip_prompt = False
    use_model = True

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'image'
        # -200 is a sentinel token id replaced with real placeholders in _encode.
        return [[-200], '\n']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        images = inputs.images
        input_ids = encoded['input_ids']
        labels = encoded['labels']
        idx_list = findall(input_ids, [-200])
        added_tokens_len = 0
        pixel_values = []
        for i, idx in enumerate(idx_list):
            max_partition = get_env_args('max_partition', int, 9)
            raw_pixel_values, image_placeholders = self.model.visual_tokenizer.preprocess_image(
                images[i], max_partition=max_partition)
            # Fix: `idx` indexes the ORIGINAL sequence; earlier sentinels have
            # already been expanded, so shift by the tokens added so far.
            # (Previously `added_tokens_len` was computed but never applied,
            # corrupting every image after the first in multi-image samples.)
            idx += added_tokens_len
            input_ids = input_ids[:idx] + image_placeholders + input_ids[idx + 1:]
            if labels is not None:
                # Placeholder positions are never supervised.
                labels = labels[:idx] + [-100] * len(image_placeholders) + labels[idx + 1:]
            pixel_values.append(raw_pixel_values)
            added_tokens_len += len(image_placeholders) - 1
        dtype = self.model.visual_tokenizer.dtype
        if pixel_values:
            pixel_values = torch.cat(pixel_values, dim=0).to(dtype)
        else:
            pixel_values = torch.zeros((1, 3, 384, 384), dtype=dtype)  # dummy
        encoded.update({'input_ids': input_ids, 'labels': labels})
        encoded['pixel_values'] = [pixel_values]
        return encoded

    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Merge text and image features via the model's own merge_multimodal."""
        padding_side = self.padding_side if self.is_training else 'left'
        if self.max_length is not None:
            model.config.multimodal_max_length = self.max_length
        input_ids = inputs['input_ids']
        labels = inputs.get('labels')
        if labels is None:
            # merge_multimodal requires labels; feed all-ignored ones.
            labels = input_ids.new_full(input_ids.shape, -100)
        _, inputs_embeds, labels, attention_mask = model.merge_multimodal(
            text_input_ids=input_ids,
            text_attention_masks=torch.ones_like(input_ids),  # not used, only for compat
            text_labels=labels,
            pixel_values=inputs['pixel_values'],
            left_padding=padding_side == 'left')
        if inputs.get('labels') is None:
            labels = None
        return {'inputs_embeds': inputs_embeds, 'labels': labels, 'attention_mask': attention_mask}

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        pixel_values = self.gather_list(batch, 'pixel_values')
        res = super()._data_collator(batch, padding_to=padding_to)
        res['pixel_values'] = pixel_values
        return res
620
+
621
+
622
# Ovis 1.6 on a Gemma-style chat markup.
register_template(
    TemplateMeta(
        MLLMTemplateType.ovis1_6,
        prefix=['<bos>'],
        prompt=['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'],
        chat_sep=['<end_of_turn>\n'],
        suffix=['<end_of_turn>'],
        system_prefix=['<bos><start_of_turn>system\n{{SYSTEM}}<end_of_turn>\n'],
        template_cls=Ovis1_6Template,
    ))

# Ovis 1.6 variant on the Llama-3 chat markup.
register_template(
    Llama3TemplateMeta(
        MLLMTemplateType.ovis1_6_llama3,
        default_system='You are a helpful and honest multimodal assistant.',
        template_cls=Ovis1_6Template,
    ))
639
+
640
+
641
class Ovis2Template(Ovis1_6Template):
    """Ovis 2: same image pipeline as Ovis 1.6 plus frame-sampled video support."""
    placeholder_tokens = ['<|image_pad|>', '<|video_pad|>']
    # Default number of frames sampled per video (overridable via env var).
    nframes = 12

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        if media_type == 'image':
            return [[-200], '\n']
        elif media_type == 'video':
            nframes = get_env_args('nframes', int, self.nframes)
            # NOTE(review): this REPLACES inputs.images with the sampled frames
            # rather than appending — presumably a sample holds either images or
            # one video, never both; confirm against callers.
            inputs.images = load_video_ovis2(inputs.videos[index], nframes)
            return [[-200] * nframes, '\n']


register_template(QwenTemplateMeta(
    MLLMTemplateType.ovis2,
    template_cls=Ovis2Template,
))
659
+
660
+
661
@dataclass
class MarcoO1TemplateMeta(QwenTemplateMeta):
    # Marco-o1 ships a fixed (Chinese) system prompt instructing the model to
    # reason inside <Thought> and answer inside <Output>; kept verbatim since
    # it is part of the model's training contract.
    default_system: Optional[str] = """
你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.
\n## 重要!!!!!
当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。
<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。
"""


register_template(MarcoO1TemplateMeta(LLMTemplateType.marco_o1))
ms-swift/swift/llm/template/template/stepfun.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Any, Dict, List, Literal, Optional
3
+
4
+ from ..base import Template
5
+ from ..constant import MLLMTemplateType
6
+ from ..register import TemplateMeta, register_template
7
+ from ..template_inputs import StdTemplateInputs
8
+ from ..utils import Context
9
+ from ..vision_utils import load_file
10
+ from .qwen import QwenTemplateMeta
11
+
12
+
13
class GOTImageEvalProcessor:
    """Eval-time image preprocessing for GOT-OCR2: square resize, tensor
    conversion, then CLIP-style channel normalization."""

    _DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(self, image_size=384, mean=None, std=None):
        from torchvision import transforms
        from torchvision.transforms.functional import InterpolationMode
        mean = self._DEFAULT_MEAN if mean is None else mean
        std = self._DEFAULT_STD if std is None else std

        self.normalize = transforms.Normalize(mean, std)
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            self.normalize,
        ])

    def __call__(self, item):
        """Apply the full pipeline to a PIL image, returning a CHW tensor."""
        return self.transform(item)
33
+
34
+
35
class GOT_OCR2Template(Template):
    """GOT-OCR2 (ModelScope variant): fixed 256-token image slots plus a
    high-resolution (1024px) eval transform."""
    placeholder_tokens = ['<imgpad>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        # Typical queries: 'OCR: ' / 'OCR with format: '
        assert media_type == 'image'
        return ['<img>' + '<imgpad>' * 256 + '</img>\n']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        high_res = GOTImageEvalProcessor(image_size=1024)
        image_list = inputs.images
        for pos in range(len(image_list)):
            # Add a leading batch dim and match the model dtype.
            image_list[pos] = high_res(image_list[pos])[None].to(self.model_info.torch_dtype)
        if image_list:
            encoded['images'] = image_list
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        collated = super()._data_collator(batch, padding_to=padding_to)
        gathered = self.gather_list(batch, 'images')
        if gathered:
            collated['images'] = gathered
        return collated
61
+
62
+
63
# GOT-OCR2 (ModelScope weights) on the Qwen chat markup.
register_template(
    QwenTemplateMeta(
        MLLMTemplateType.got_ocr2,
        default_system=' You should follow the instructions carefully and explain your answers in detail.',
        template_cls=GOT_OCR2Template,
    ))
69
+
70
+
71
class GOT_OCR2HfTemplate(Template):
    """GOT-OCR2 (HuggingFace variant): image preprocessing is delegated to the
    HF processor at collation time instead of a local transform."""
    placeholder_tokens = ['<imgpad>']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        # Typical queries: 'OCR: ' / 'OCR with format: '
        assert media_type == 'image'
        return ['<img>' + '<imgpad>' * 256 + '</img>\n']

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:  # for now, mirrors GOT_OCR2Template above
        encoded = super()._encode(inputs)
        images = inputs.images
        if images:
            encoded['images'] = images
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        images = self.gather_list(batch, 'images')
        _inputs = self.processor(images, return_tensors='pt')
        _inputs.pop('input_ids')  # this does not contain the response, so cannot be used when training
        _inputs.pop('attention_mask')  # this does not contain the response, so cannot be used when training

        res.update(_inputs.data)
        return res
97
+
98
+
99
# GOT-OCR2 (HuggingFace weights) on the Qwen chat markup.
register_template(
    QwenTemplateMeta(
        MLLMTemplateType.got_ocr2_hf,
        default_system=' You should follow the instructions carefully and explain your answers in detail.',
        template_cls=GOT_OCR2HfTemplate,
    ))
105
+
106
+
107
class StepAudioTemplate(Template):
    """Step-Audio: audio is tokenized directly by the model's audio encoder."""
    use_model = True

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        assert media_type == 'audio', f'media_type: {media_type}'
        # NOTE(review): imports top-level `utils` — presumably the Step-Audio
        # model repo's utils.py, which must be on sys.path; confirm.
        from utils import load_audio
        audio_wav, sr = load_audio(load_file(inputs.audios[index]))
        audio_tokens = self.model.encoder(audio_wav, sr)
        return audio_tokens


register_template(
    TemplateMeta(
        MLLMTemplateType.step_audio,
        template_cls=StepAudioTemplate,
        prefix=['<s>'],
        prompt=['<|BOT|>human\n{{QUERY}}<|EOT|><|BOT|>assistant\n'],
        system_prefix=['<s><|BOT|>system\n{{SYSTEM}}<|EOT|>'],
        chat_sep=['<|EOT|>'],
        suffix=['<|EOT|>'],
    ))
ms-swift/swift/llm/template/template/yi.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import torch
5
+
6
+ from ..base import Template
7
+ from ..constant import LLMTemplateType, MLLMTemplateType
8
+ from ..register import TemplateMeta, register_template
9
+ from ..template_inputs import StdTemplateInputs
10
+ from .utils import DEFAULT_SYSTEM, ChatmlTemplateMeta
11
+
12
# Yi-Coder uses the plain ChatML markup with the shared default system prompt.
register_template(ChatmlTemplateMeta(
    LLMTemplateType.yi_coder,
    default_system=DEFAULT_SYSTEM,
))

# Bilingual (EN/zh) default system prompt shipped with Yi-VL.
yi_vl_default_system = (
    'This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. '
    "Read all the images carefully, and respond to the human's questions with informative, "
    'helpful, detailed and polite answers. '
    '这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。'
    '仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。')
23
+
24
+
25
class YiVLTemplate(Template):
    """Yi-VL: LLaVA-style pipeline — pad images to squares, then batch them
    through the vision tower's image processor."""
    image_placeholder = [[-200], '\n']
    use_model = True

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = super()._encode(inputs)
        model = self.model
        from llava.mm_utils import expand2square
        if not hasattr(model, 'vision_tower'):
            model = model.model  # unwrap (e.g. PEFT) to reach the vision tower
        image_processor = model.vision_tower.image_processor
        images = inputs.images or []
        # Mutates inputs.images in place: each image is padded to a square
        # filled with the processor's mean color before preprocessing.
        for i, image in enumerate(images):
            background_color = tuple(int(x * 255) for x in image_processor.image_mean)
            image = expand2square(image, background_color)
            images[i] = image
        if images:
            image_tensor = image_processor.preprocess(images, return_tensors='pt')['pixel_values']
            encoded['images'] = image_tensor.to(model.dtype)
        return encoded

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        res = super()._data_collator(batch, padding_to=padding_to)
        images = [b['images'] for b in batch if 'images' in b]
        if images:
            res['images'] = torch.concat(images)
        return res
52
+
53
+
54
# Yi-VL chat markup. NOTE(review): 8308 appears to be the tokenizer id of the
# turn-marker token used by Yi-VL — confirm against the tokenizer vocab.
register_template(
    TemplateMeta(
        MLLMTemplateType.yi_vl,
        prefix=[],
        prompt=[[8308], ' Human: {{QUERY}}\n', [8308], ' Assistant:'],
        chat_sep=['\n'],
        suffix=['\n', [8308]],
        default_system=yi_vl_default_system,
        template_cls=YiVLTemplate,
        system_prefix=['{{SYSTEM}}\n\n']))
ms-swift/swift/llm/train/__pycache__/callback.cpython-310.pyc ADDED
Binary file (3.11 kB). View file
 
ms-swift/swift/llm/train/__pycache__/rlhf.cpython-310.pyc ADDED
Binary file (4.6 kB). View file
 
ms-swift/swift/llm/train/__pycache__/sft.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
ms-swift/swift/llm/train/__pycache__/tuner.cpython-310.pyc ADDED
Binary file (13.5 kB). View file
 
ms-swift/swift/llm/train/callback.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import types
3
+
4
+ import numpy as np
5
+ import torch
6
+ from transformers import TrainerCallback
7
+
8
+ from swift.utils import get_logger
9
+
10
+ logger = get_logger()
11
+
12
+
13
class TrainerAdapterCallback(TrainerCallback):
    """Trainer hooks needed by AdaLoRA: tell peft the step budget and make
    rank re-allocation run on every optimizer step."""

    def __init__(self, args):
        self.global_step = 0
        self.args = args

    # offload original_modules to cpu, to save memory
    def on_train_begin(self, _args, state, control, **kwargs):
        model = kwargs['model']
        if self.args.train_type == 'adalora':
            model.peft_config['default'].total_step = state.max_steps

            # Monkey-patch zero_grad so AdaLoRA's update_and_allocate runs
            # once per step, right before gradients are cleared.
            def zero_grad(_self, *args, **kwargs):
                _self.update_and_allocate(self.global_step + 1)
                _self._zero_grad(*args, **kwargs)

            model._zero_grad = model.zero_grad
            model.zero_grad = types.MethodType(zero_grad, model)

    def on_step_end(self, _args, state, control, **kwargs):
        if self.args.train_type == 'adalora':
            self.global_step = state.global_step
36
+
37
class DynamicLayerActivationCallback(TrainerCallback):
    """LISA-style training: keep all transformer layers frozen and periodically
    unfreeze a fresh random subset of `n_layers` every `step_interval` steps."""

    def __init__(self, n_layers: int, step_interval: int, model: torch.nn.Module):
        super().__init__()
        self.n_layers = n_layers
        self.step_interval = step_interval
        self.model = model
        # The first ModuleList found is assumed to be the transformer layer stack.
        found_name, found_layers = None, None
        for module_name, module in model.named_modules():
            if isinstance(module, torch.nn.ModuleList):
                found_name, found_layers = module_name, module
                break
        assert found_name is not None
        self.layers_attribute = found_name
        self.total_layers = len(found_layers)

        # Start fully frozen; the first on_step_begin activates a subset.
        self.freeze_all_layers()
        self.active_layers_indices = []

    def freeze_all_layers(self):
        """Disable gradients on every parameter of every layer."""
        layer_stack = self.model.get_submodule(self.layers_attribute)
        for param in (p for layer in layer_stack for p in layer.parameters()):
            param.requires_grad = False

    def on_step_begin(self, args, state, control, **kwargs):
        # Rotate the active subset on schedule (also at step 0 and step 1).
        step = state.global_step
        if step % self.step_interval == 0 or step == 1:
            self.switch_active_layers()

    def switch_active_layers(self):
        """Re-freeze everything, then enable gradients on a random layer subset."""
        self.freeze_all_layers()
        layer_stack = self.model.get_submodule(self.layers_attribute)
        self.active_layers_indices = np.random.choice(range(self.total_layers), self.n_layers, replace=False)
        for chosen in self.active_layers_indices:
            for param in layer_stack[chosen].parameters():
                param.requires_grad = True
ms-swift/swift/llm/train/rlhf.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from typing import List, Union
4
+
5
+ from swift.llm import safe_snapshot_download
6
+ from swift.utils import get_logger, get_model_parameter_info
7
+ from ..argument import BaseArguments, RLHFArguments
8
+ from ..model import HfConfigFactory
9
+ from .kto import prepare_kto_dataset
10
+ from .sft import SwiftSft
11
+
12
+ logger = get_logger()
13
+
14
+
15
class SwiftRLHF(SwiftSft):
    """SFT pipeline specialized for RLHF: additionally prepares the frozen
    ref/reward model(s) and, for PPO, a trainable value model."""
    args_class = RLHFArguments
    args: args_class

    def _prepare_model_tokenizer(self):
        if self.args.sequence_parallel_size > 1:
            # Duplicate calling is allowed to promise this function will
            # be called before model initializing.
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.init_sequence_parallel(self.args.sequence_parallel_size)
        # prepare ref/reward/value model
        from swift.llm.infer.utils import prepare_adapter
        args = self.args

        def prepare_single_model(key, origin_key=None):
            # `key` selects the args.{key}_model fields to read; `origin_key`
            # distinguishes the role (e.g. the value model reads 'reward' args).
            origin_key = origin_key or key
            model_id_or_path = getattr(args, f'{key}_model')
            if model_id_or_path is None:
                return None

            model_type = getattr(args, f'{key}_model_type')
            model_revision = getattr(args, f'{key}_model_revision')
            model_dir = safe_snapshot_download(
                model_id_or_path=model_id_or_path,
                revision=model_revision,
                download_model=False,
                use_hf=args.use_hf,
                hub_token=args.hub_token,
            )
            task_type = None
            num_labels = None
            # Prefer swift's own args.json metadata; fall back to the HF config.
            if os.path.exists(os.path.join(model_dir, 'args.json')):
                model_args = BaseArguments.from_pretrained(model_dir)
                if hasattr(model_args, 'task_type'):
                    task_type = model_args.task_type
            else:
                from transformers import AutoConfig
                model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
                if hasattr(model_config, 'num_labels'):
                    num_labels = model_config.num_labels
            if task_type == 'seq_cls':
                # Reward models score with a single scalar head.
                num_labels = 1

            model, processor = args.get_model_processor(
                model=model_id_or_path,
                model_type=model_type,
                model_revision=model_revision,
                task_type=task_type,
                num_labels=num_labels)

            adapters = args.adapters if key == 'ref' else args.reward_adapters
            model = prepare_adapter(args, model, adapters)
            if origin_key in {'ref', 'reward'}:
                # ref/reward models are frozen scorers.
                if self.args.sequence_parallel_size > 1:
                    from swift.trainers.sequence_parallel import sequence_parallel
                    if hasattr(model, 'model_meta'):
                        is_multimodal = model.model_meta.is_multimodal
                    else:
                        is_multimodal = model.model.model_meta.is_multimodal
                    sequence_parallel.prepare_model(model, processor, split_in_forward=is_multimodal)
                model.requires_grad_(False).eval()
            else:
                # The value model (PPO) is trainable and goes through tuner prep.
                model = self.prepare_model(args, model, task_type=task_type)
                logger.info(f'value_model: {model}')
                model_parameter_info = get_model_parameter_info(model)
                self.train_msg['value_model_parameter_info'] = model_parameter_info
                logger.info(f'value_model_parameter_info: {model_parameter_info}')

            HfConfigFactory.set_model_config_attr(model, 'use_cache', False)
            return model, processor

        # Handle ref and value models
        for key in ['ref', 'value']:
            setattr(self, f'{key}_model', None)
            if key == 'value' and args.rlhf_type != 'ppo':
                continue

            # The value model is initialized from the reward model's weights.
            model_key = 'reward' if key == 'value' else key
            result = prepare_single_model(model_key, key)
            if result is not None:
                model, _ = result
                setattr(self, f'{key}_model', model)

        # Handle reward model(s)
        self.reward_model = None
        if hasattr(args, 'reward_model') and args.reward_model is not None:
            reward_models = args.reward_model if isinstance(args.reward_model, list) else [args.reward_model]
            self.reward_model = []
            if args.rlhf_type == 'grpo':
                # GRPO needs one template per reward model for scoring prompts.
                self.reward_template = []

            for reward_model_path in reward_models:
                args.reward_model = reward_model_path  # Temporarily set for prepare_single_model
                result = prepare_single_model('reward')
                if result is not None:
                    model, processor = result
                    self.reward_model.append(model)

                    if args.rlhf_type == 'grpo':
                        reward_template = self.args.get_template(processor, processor.model_meta.template)
                        if reward_template.use_model:
                            reward_template.model = model
                        self.reward_template.append(reward_template)
            args.reward_model = reward_models  # Restore original value

        super()._prepare_model_tokenizer()

    def _prepare_template(self) -> None:
        """Switch the template into the encoding mode matching the RLHF algorithm."""
        args = self.args
        super()._prepare_template()
        model_mapping = {'kto': 'kto', 'ppo': 'pt', 'grpo': 'pt'}
        self.template.set_mode(model_mapping.get(args.rlhf_type, 'rlhf'))

        if args.rlhf_type == 'ppo':
            args.training_args.stop_token_id = self.template.template_meta.stop_token_id

    def _get_dataset(self):
        args = self.args
        train_dataset, val_dataset = super()._get_dataset()
        if args.rlhf_type == 'kto':
            # KTO needs desirable/undesirable pairing applied to the dataset.
            train_dataset, val_dataset = prepare_kto_dataset(args, train_dataset, val_dataset)
        return train_dataset, val_dataset

    def _get_trainer_kwargs(self):
        """Collect the extra models/templates the RLHF trainer constructor expects."""
        trainer_kwargs = {}
        for key in ['ref', 'reward', 'value']:
            key = f'{key}_model'
            model = getattr(self, key, None)
            # PPO requires the kwarg even when the model is None.
            if model or self.args.rlhf_type == 'ppo':
                trainer_kwargs[key] = model
        if hasattr(self, 'reward_template'):
            trainer_kwargs['reward_template'] = self.reward_template
        if self.args.rlhf_type == 'grpo':
            trainer_kwargs['reward_funcs'] = self.args.reward_funcs
            trainer_kwargs['vllm_client'] = self.args.vllm_client
        return trainer_kwargs
151
+
152
+
153
+ def rlhf_main(args: Union[List[str], RLHFArguments, None] = None):
154
+ return SwiftRLHF(args).main()
ms-swift/swift/llm/train/sft.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from functools import partial
4
+ from typing import List, Union
5
+
6
+ from datasets import Dataset as HfDataset
7
+
8
+ from swift.plugin import extra_callbacks, get_loss_func, get_metric
9
+ from swift.trainers import TrainerFactory
10
+ from swift.utils import (append_to_jsonl, get_logger, get_model_parameter_info, is_master, plot_images, stat_array,
11
+ use_torchacc)
12
+ from ..argument import TrainArguments
13
+ from ..base import SwiftPipeline
14
+ from ..dataset import (EncodePreprocessor, GetLengthPreprocessor, IterablePackingDataset, LazyLLMDataset,
15
+ PackingDataset, load_dataset)
16
+ from ..infer import prepare_generation_config
17
+ from ..model import HfConfigFactory, get_model_arch
18
+ from ..utils import deep_getattr, dynamic_gradient_checkpointing
19
+ from .tuner import TunerMixin
20
+
21
+ logger = get_logger()
22
+
23
+
24
class SwiftSft(SwiftPipeline, TunerMixin):
    # Argument dataclass used to parse CLI/dict input for this pipeline.
    args_class = TrainArguments
    args: args_class

    def __init__(self, args: Union[List[str], TrainArguments, None] = None) -> None:
        """Build the SFT pipeline: load model/tokenizer, template and callbacks."""
        super().__init__(args)
        # Training statistics collected during the run; dumped to logging.jsonl at the end.
        self.train_msg = {}
        self._prepare_model_tokenizer()
        self._prepare_template()
        self._prepare_callbacks()
34
+
35
    def _prepare_gradient_checkpointing(self):
        """Disable the KV cache and enable gradient checkpointing when requested."""
        args = self.args
        # use_cache is incompatible with gradient checkpointing during training.
        HfConfigFactory.set_model_config_attr(self.model, 'use_cache', False)
        if args.gradient_checkpointing:
            self.model.supports_gradient_checkpointing = True
            dynamic_gradient_checkpointing(self.model)
            self.model.enable_input_require_grads()
        model_meta = self.model.model_meta
        model_arch = get_model_arch(model_meta.model_arch)
        if model_meta.is_multimodal and model_arch:
            # Vision towers need input grads enabled separately; some towers do not
            # implement the hook, so NotImplementedError is deliberately tolerated.
            for vision_tower_name in model_arch.vision_tower:
                vision_tower = deep_getattr(self.model, vision_tower_name)
                if hasattr(vision_tower, 'enable_input_require_grads'):
                    try:
                        vision_tower.enable_input_require_grads()
                    except NotImplementedError:
                        pass
52
+
53
    def _prepare_generation_config(self):
        """Replace the model's generation config with one built from the request config."""
        args = self.args
        # Keep the original config so it can be restored/saved later.
        self.model.origin_generation_config = self.model.generation_config
        self.model.generation_config = prepare_generation_config(self.model.generation_config,
                                                                 args.get_request_config(), self.tokenizer)
        logger.info(f'model.generation_config: {self.model.generation_config}')
59
+
60
    def _prepare_model_tokenizer(self):
        """Load the model/processor and apply generation/checkpointing preparation."""
        args = self.args
        if args.sequence_parallel_size > 1:
            # Sequence parallelism must be initialized before the model is loaded.
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.init_sequence_parallel(args.sequence_parallel_size)
        self.model, self.processor = args.get_model_processor()

        if hasattr(self.model, 'hf_device_map'):
            logger.info(f'model.hf_device_map: {self.model.hf_device_map}')

        logger.info(f'model_info: {self.model.model_info}')

        self._prepare_generation_config()
        self._prepare_gradient_checkpointing()
74
+
75
+ def _prepare_template(self) -> None:
76
+ template = self.args.get_template(self.processor)
77
+ if self.args.task_type == 'causal_lm':
78
+ template.set_mode('train')
79
+ if template.use_model:
80
+ template.model = self.model
81
+ self.template = template
82
+
83
    def _get_dataset(self):
        """Load train/val datasets according to the dataset arguments."""
        # The random shuffling of the training set occurs in the dataloader of the trainer.
        args = self.args
        dataset_kwargs = args.get_dataset_kwargs()
        train_dataset, val_dataset = load_dataset(
            args.dataset, split_dataset_ratio=args.split_dataset_ratio, shuffle=args.dataset_shuffle, **dataset_kwargs)
        if len(args.val_dataset) > 0:
            # Loading val dataset; an explicit val dataset overrides any split
            # from the training set, so split_dataset_ratio must be 0.
            _, val_dataset = load_dataset(
                args.val_dataset, split_dataset_ratio=1.0, shuffle=args.val_dataset_shuffle, **dataset_kwargs)
            assert args.split_dataset_ratio == 0.
        logger.info(f'train_dataset: {train_dataset}')
        logger.info(f'val_dataset: {val_dataset}')

        return train_dataset, val_dataset
98
+
99
+ def _get_loss_func(self):
100
+ args = self.args
101
+ loss_type = args.loss_type
102
+ if loss_type is None and args.loss_scale != 'default':
103
+ loss_type = 'loss_scale'
104
+ return get_loss_func(loss_type)
105
+
106
+ def _get_data_collator(self):
107
+ args = self.args
108
+ template = self.template
109
+ padding_to = args.max_length if args.train_type == 'longlora' else None
110
+ return partial(template.data_collator, padding_to=padding_to)
111
+
112
+ @staticmethod
113
+ def _save_val_dataset(output_dir: str, val_dataset):
114
+ if is_master() and isinstance(val_dataset, HfDataset):
115
+ os.makedirs(output_dir, exist_ok=True)
116
+ val_dataset_path = os.path.join(output_dir, 'val_dataset.jsonl')
117
+ append_to_jsonl(val_dataset_path, val_dataset.to_list())
118
+ logger.info(f'The split dataset from the training set will be saved at: {val_dataset_path}.')
119
+
120
    def run(self):
        """End-to-end training: encode datasets, wrap the model with tuners, train."""
        args = self.args

        train_dataset, val_dataset = self._get_dataset()
        train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset)

        if args.task_type == 'seq_cls':
            # problem_type may come from the model config when not set explicitly.
            args.problem_type = args.problem_type or getattr(self.model.config, 'problem_type', None)
            logger.info(f'args.problem_type: {args.problem_type}')
        args.save_args()

        data_collator = self._get_data_collator()
        # Some tuners require train_dataset and data_collator for preparation: LoRA-GA
        self.model = self.prepare_model(self.args, self.model, template=self.template, train_dataset=train_dataset)
        logger.info(f'model: {self.model}')
        model_parameter_info = get_model_parameter_info(self.model)
        self.train_msg['model_parameter_info'] = model_parameter_info
        logger.info(f'model_parameter_info: {model_parameter_info}')

        trainer_cls = TrainerFactory.get_trainer_cls(args)
        trainer = trainer_cls(
            model=self.model,
            args=self.args.training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            callbacks=self.callbacks,
            template=self.template,
            **self._get_trainer_kwargs(),
        )
        return self.train(trainer)
151
+
152
+ def _get_trainer_kwargs(self):
153
+ args = self.args
154
+ if args.metric is not None:
155
+ compute_metrics, preprocess_logits_for_metrics = get_metric(args.metric)
156
+ elif args.predict_with_generate:
157
+ compute_metrics, preprocess_logits_for_metrics = get_metric('nlg')
158
+ else:
159
+ compute_metrics, preprocess_logits_for_metrics = get_metric('acc')
160
+ compute_metrics = partial(
161
+ compute_metrics, acc_strategy=args.acc_strategy, is_encoder_decoder=self.template.is_encoder_decoder)
162
+ return {
163
+ 'compute_metrics': compute_metrics,
164
+ 'preprocess_logits_for_metrics': preprocess_logits_for_metrics,
165
+ 'compute_loss_func': self._get_loss_func()
166
+ }
167
+
168
    def _save_trainer_state(self, trainer):
        """Record checkpoint paths/metrics, plot curves and append to logging.jsonl.

        Returns the accumulated `self.train_msg` dict.
        """
        training_args = trainer.args
        state = trainer.state
        if hasattr(state, 'last_model_checkpoint'):
            if self.args.create_checkpoint_symlink:
                # Stable 'last'/'best' symlinks next to the versioned checkpoints.
                last_checkpoint = os.path.join(self.args.output_dir, 'last')
                best_checkpoint = os.path.join(self.args.output_dir, 'best')
                os.symlink(state.last_model_checkpoint, last_checkpoint)
                os.symlink(state.best_model_checkpoint, best_checkpoint)
                state.last_model_checkpoint = last_checkpoint
                state.best_model_checkpoint = best_checkpoint
        else:
            # NOTE(review): only last_model_checkpoint is set here; the log line below
            # assumes state.best_model_checkpoint always exists — confirm on trainer state.
            state.last_model_checkpoint = None
            logger.warning('No training was carried out, which may be due to the dataset being too small '
                           'or incorrect usage of resume_from_checkpoint.')
        logger.info(f'last_model_checkpoint: {state.last_model_checkpoint}')
        logger.info(f'best_model_checkpoint: {state.best_model_checkpoint}')

        # Visualization
        if is_master() and not use_torchacc():
            if 'tensorboard' in training_args.report_to:
                images_dir = os.path.join(training_args.output_dir, 'images')
                logger.info(f'images_dir: {images_dir}')
                plot_images(images_dir, training_args.logging_dir, ['train/loss'], 0.9)
            if training_args.push_to_hub:
                trainer.push_to_hub()

        self.train_msg.update({
            'last_model_checkpoint': state.last_model_checkpoint,
            'best_model_checkpoint': state.best_model_checkpoint,
            'best_metric': state.best_metric,
            'global_step': state.global_step,
            'log_history': state.log_history,
            'memory': trainer.max_memory,
        })
        if is_master():
            jsonl_path = os.path.join(training_args.output_dir, 'logging.jsonl')
            append_to_jsonl(jsonl_path, self.train_msg)
        return self.train_msg
207
+
208
    def train(self, trainer):
        """Run trainer.train and always persist the trainer state afterwards."""
        logging_path = os.path.join(trainer.args.output_dir, 'logging.jsonl')
        logger.info(f'The logging file will be saved in: {logging_path}')
        try:
            trainer.train(trainer.args.resume_from_checkpoint)
        finally:
            # Save checkpoint info/visualizations even when training raised.
            res = self._save_trainer_state(trainer)
        return res
216
+
217
    def _prepare_callbacks(self):
        """Assemble trainer callbacks: LISA, AdaLoRA adapter hook and plugin extras."""
        from .callback import DynamicLayerActivationCallback, TrainerAdapterCallback
        args = self.args
        callbacks = []
        if args.lisa_activated_layers > 0:
            assert args.train_type == 'full', 'LISA only supports full parameter training.'
            lisa_callback = DynamicLayerActivationCallback(
                n_layers=args.lisa_activated_layers,  # Number of layers to activate
                step_interval=args.lisa_step_interval,  # Step interval to update active layers
                model=self.model)
            lisa_callback.switch_active_layers()  # Make trainable parameters printing a correct value
            callbacks.append(lisa_callback)

        if args.is_adapter and args.train_type == 'adalora':
            callbacks.append(TrainerAdapterCallback(args))
        callbacks += extra_callbacks
        self.callbacks = callbacks
234
+
235
+ def _stat_dataset(self, dataset: HfDataset):
236
+ args = self.args
237
+ if isinstance(dataset, HfDataset):
238
+ dataset = GetLengthPreprocessor()(dataset, num_proc=args.dataset_num_proc)
239
+ length = dataset['length']
240
+ else:
241
+ length = []
242
+ for row in dataset:
243
+ length.append(max([len(row[k]) for k in row.keys() if k.endswith('input_ids')]))
244
+ _, stat_str = stat_array(length)
245
+ logger.info(f'Dataset Token Length: {stat_str}')
246
+ return stat_str
247
+
248
    def _encode_dataset(self, train_dataset, val_dataset):
        """Tokenize/pack the datasets; GRPO keeps raw samples (encoding happens in the trainer)."""
        template = self.template
        args = self.args
        # Megatron-style args expose 'save' instead of 'output_dir'.
        output_dir = getattr(args, 'output_dir', None) or getattr(args, 'save')
        self._save_val_dataset(output_dir, val_dataset)
        is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo'
        predict_with_generate = getattr(args, 'predict_with_generate', False)
        if not is_grpo:
            if args.packing:
                packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset
                train_dataset = packing_dataset_cls(
                    self.template, train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
                if val_dataset is not None:
                    val_dataset = packing_dataset_cls(
                        self.template, val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
            elif args.lazy_tokenize:
                # Encode on access instead of up front.
                train_dataset = LazyLLMDataset(
                    train_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
                if val_dataset is not None and not predict_with_generate:
                    val_dataset = LazyLLMDataset(
                        val_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
            else:
                preprocessor = EncodePreprocessor(template=template)
                train_dataset = preprocessor(train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
                if val_dataset is not None and not predict_with_generate:
                    val_dataset = preprocessor(val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)

            if is_master():
                # Print one encoded sample for sanity checking.
                inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset))
                template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {})
            if isinstance(train_dataset, (HfDataset, PackingDataset)):
                self.train_msg['train_dataset'] = self._stat_dataset(train_dataset)
                if val_dataset is not None and not predict_with_generate:
                    self.train_msg['val_dataset'] = self._stat_dataset(val_dataset)

        return train_dataset, val_dataset
284
+
285
+
286
def sft_main(args: Union[List[str], TrainArguments, None] = None):
    """CLI/programmatic entry point for supervised fine-tuning."""
    pipeline = SwiftSft(args)
    return pipeline.main()
ms-swift/swift/llm/train/tuner.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import inspect
3
+ import os
4
+ from typing import List, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import transformers
9
+ from packaging import version
10
+ from transformers import TrainingArguments
11
+
12
+ from swift.llm import TrainArguments, deep_getattr, get_model_arch
13
+ from swift.plugin import Tuner, extra_tuners
14
+ from swift.tuners import Swift
15
+ from swift.utils import (activate_parameters, find_all_linears, find_embedding, find_norm, freeze_parameters,
16
+ get_logger, use_torchacc)
17
+
18
+ logger = get_logger()
19
+
20
+
21
def apply_liger(model_type: str):
    """Patch the transformers implementation of `model_type` with Liger kernels.

    Args:
        model_type: a `swift.llm.ModelType` value identifying the architecture.

    Raises:
        ValueError: if no Liger patch exists for the given model type.
    """
    from liger_kernel.transformers import (apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral,
                                           apply_liger_kernel_to_mixtral, apply_liger_kernel_to_gemma,
                                           apply_liger_kernel_to_qwen2, apply_liger_kernel_to_qwen3,
                                           apply_liger_kernel_to_qwen2_vl, apply_liger_kernel_to_qwen2_5_vl,
                                           apply_liger_kernel_to_phi3, apply_liger_kernel_to_mllama)
    from swift.llm import ModelType
    # BUGFIX: single-element membership tests must use real tuples. The previous
    # `model_type in (ModelType.mistral)` tested membership against a plain string,
    # which performs *substring* matching and can silently match the wrong types.
    if model_type in (ModelType.llama, ModelType.llama3, ModelType.llama3_1, ModelType.llama3_2):
        apply_liger_kernel_to_llama()
    elif model_type in (ModelType.mistral, ):
        apply_liger_kernel_to_mistral()
    elif model_type in (ModelType.mixtral, ):
        apply_liger_kernel_to_mixtral()
    elif model_type in (ModelType.gemma, ModelType.gemma2):
        apply_liger_kernel_to_gemma()
    elif model_type in (ModelType.qwen2, ModelType.qwen2_5):
        apply_liger_kernel_to_qwen2()
    elif model_type in (ModelType.qwen3, ):
        apply_liger_kernel_to_qwen3()
    elif model_type in (ModelType.phi3, ):
        apply_liger_kernel_to_phi3()
    elif model_type in (ModelType.llama3_2_vision, ):
        apply_liger_kernel_to_mllama()
    elif model_type in (ModelType.qwen2_vl, ):
        apply_liger_kernel_to_qwen2_vl()
    elif model_type in (ModelType.qwen2_5_vl, ):
        apply_liger_kernel_to_qwen2_5_vl()
    else:
        raise ValueError(f'Unsupported liger model_type: {model_type}')
50
+
51
+
52
def get_multimodal_target_regex(
        model,
        *,
        freeze_llm: bool = False,
        freeze_vit: bool = True,
        freeze_aligner: bool = True,
        include_embedding: bool = False,
) -> str:
    """Build a regex matching the linear layers to adapt in a multimodal model.

    Only the non-frozen components (LLM / vision tower / aligner) contribute
    modules; aligner submodules nested inside a trainable vision tower are
    excluded via a negative lookahead so they follow the aligner freeze flag.
    """
    model_arch = get_model_arch(model.model_meta.model_arch)
    modules = []
    if not freeze_llm:
        modules += model_arch.language_model
    if not freeze_vit:
        modules += model_arch.vision_tower
    if not freeze_aligner:
        modules += model_arch.aligner
    assert len(modules) > 0, f'modules: {modules}'

    extra_layers = []
    if include_embedding:
        extra_layers.append(nn.Embedding)
    res = []
    for module in modules:
        rejected_modules = []
        if not freeze_vit:
            # Aligners living under this (trainable) module must not be matched here.
            for aligner in model_arch.aligner:
                if aligner.startswith(f'{module}.'):
                    rejected_modules.append(aligner)

        sub_module = deep_getattr(model, module)
        target_modules = find_all_linears(sub_module, model_arch, extra_layers)
        target_modules = [tm for tm in target_modules if tm]
        # Empty target list means "match the module prefix itself".
        target_pattern = rf'.*\.({"|".join(target_modules)})' if target_modules else ''
        rejected_pattern = rf'(?!({"|".join(rejected_modules)}))' if rejected_modules else ''
        res.append(rf'{rejected_pattern}{module}{target_pattern}')

    return rf'^({"|".join(res)})$'
89
+
90
+
91
def get_target_modules(args, model) -> Union[str, List[str]]:
    """Replace all-linear to actual modules"""
    model_meta = model.model_meta
    # A plain string is treated as a ready-made regex and passed through untouched.
    if isinstance(args.target_modules, str):
        return args.target_modules
    resolved = list(args.target_modules)
    if 'all-linear' in resolved:
        if model_meta.is_multimodal:
            # Multimodal models are expressed as a single regex instead of a list.
            return get_multimodal_target_regex(
                model,
                freeze_llm=args.freeze_llm,
                freeze_vit=args.freeze_vit,
                freeze_aligner=args.freeze_aligner,
                include_embedding='all-embedding' in resolved)
        resolved.remove('all-linear')
        resolved += find_all_linears(model)
    if 'all-embedding' in resolved:
        resolved.remove('all-embedding')
        resolved += find_embedding(model)
    return resolved
112
+
113
+
114
def get_modules_to_save(args, model, task_type=None):
    """Expand the 'all-embedding'/'all-norm' shorthands into concrete module names.

    Sequence classification (reward model) additionally saves the value head.
    """
    result = list(args.modules_to_save)
    if 'all-embedding' in args.modules_to_save:
        result.remove('all-embedding')
        result.extend(find_embedding(model))
    if 'all-norm' in args.modules_to_save:
        result.remove('all-norm')
        result.extend(find_norm(model))
    if task_type and task_type.lower() == 'seq_cls':  # reward_model
        result.append('v_head')
    return result
125
+
126
+
127
def get_vera_target_modules(model, config):
    """This function is only useful on the vera tuner.

    Vera requires every targeted linear layer to share one weight shape; when
    the matched linears differ, keep only those whose shape matches the layer
    selected by the first target containing 'v'.
    """
    targets = config.target_modules
    shape_by_name = {}
    for name, module in model.named_modules():
        # only Linear for now
        if isinstance(module, torch.nn.Linear) and any(t in name for t in targets):
            shape_by_name[name] = module.weight.shape
    if len(set(shape_by_name.values())) > 1:
        v_targets = [t for t in targets if 'v' in t]
        if not v_targets:
            raise ValueError('Please manually pass in `vera_target_modules`, do not use `all-linear`,'
                             'because Vera need all target linears to be the same size.')
        anchor = v_targets[0]
        anchor_shape = next(shape for name, shape in shape_by_name.items() if anchor in name)
        matching_names = [name for name, shape in shape_by_name.items() if shape == anchor_shape]
        config.target_modules = [t for t in targets if any(t in name for name in matching_names)]
    return config
145
+
146
+
147
def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset=None, task_type=None):
    """Wrap `model` with the tuner selected by `args.train_type`.

    template/train_dataset are only needed by tuners that inspect data during
    preparation (LoRA-GA, AdaLoRA step counting). Returns the wrapped model.
    """
    from swift.tuners import (AdaLoraConfig, AdapterConfig, BOFTConfig, LLaMAProConfig, LongLoRAModelType, LoraConfig,
                              LoRAConfig, ReftConfig, Swift, VeraConfig)
    task_type = (task_type or args.task_type).upper()
    target_modules = get_target_modules(args, model)
    modules_to_save = get_modules_to_save(args, model, task_type)
    # Shared kwargs for the LoRA family (swift/peft/unsloth backends, adalora).
    lora_kwargs = {
        'r': args.lora_rank,
        'target_modules': target_modules,
        'lora_alpha': args.lora_alpha,
        'lora_dropout': args.lora_dropout,
        'bias': args.lora_bias,
        'modules_to_save': modules_to_save,
        'use_rslora': args.use_rslora,
        'use_dora': args.use_dora,
        'lorap_lr_ratio': args.lorap_lr_ratio,
        'init_lora_weights': args.init_weights,
    }
    if args.train_type in ('lora', 'longlora'):
        if args.use_swift_lora:
            lora_config = LoRAConfig(lora_dtype=args.lora_dtype, **lora_kwargs)
            model = Swift.prepare_model(model, lora_config)
            logger.info(f'lora_config: {lora_config}')
        elif args.tuner_backend == 'peft':
            # peft's LoraConfig has no EMBEDDING task type.
            if task_type == 'EMBEDDING':
                task_type = None
            lora_config = LoraConfig(task_type=task_type, lora_dtype=args.lora_dtype, **lora_kwargs)
            if args.init_weights == 'lora-ga':
                try:
                    import lora_ga
                except ImportError as e:
                    error_message = """
                    Since 'LoRA-GA' is not implemented by PEFT, you will need to install it directly from GitHub.
                    Command: 'pip install git+https://github.com/lxline/LoRA-GA.git'.
                    """
                    logger.info(error_message)
                    raise RuntimeError(error_message) from e
                # LoRA-GA estimates initialization from gradients over a few batches.
                model = lora_ga.entrypoint.get_lora_ga_model(
                    model=model,
                    data_collator=template.data_collator,
                    dataset=train_dataset,
                    batch_size=args.lora_ga_batch_size,
                    num_iters=args.lora_ga_iters,
                    max_length=args.lora_ga_max_length,
                    direction=args.lora_ga_direction,
                    dtype=args.lora_dtype,
                    scale=args.lora_ga_scale,
                    stable_gamma=args.lora_ga_stable_gamma,
                )
            else:
                model = Swift.prepare_model(model, lora_config)
            logger.info(f'lora_config: {lora_config}')
        elif args.tuner_backend == 'unsloth':
            # On resume, unsloth prepares/loads the adapter outside this function.
            if args.resume_from_checkpoint is None:
                if args.model_meta.is_multimodal:
                    from unsloth import FastVisionModel as UnslothModel
                else:
                    from unsloth import FastLanguageModel as UnslothModel
                assert args.train_type == 'lora', 'Unsloth does not support LongLoRA'
                lora_kwargs.pop('lorap_lr_ratio')
                model = UnslothModel.get_peft_model(
                    model,
                    use_gradient_checkpointing='unsloth',
                    max_seq_length=args.max_length or 2048,  # 2048 is the default value of unsloth
                    **lora_kwargs,
                )
                logger.info(f'unsloth_config: {lora_kwargs}')
        if args.train_type == 'longlora':
            # LongLoRA only supports llama-family attention.
            assert LongLoRAModelType.LLAMA in args.model_type
            assert version.parse(transformers.__version__) >= version.parse('4.39.3')
            from swift.tuners.longlora.llama import replace_llama_attn
            replace_llama_attn(model)
            model.config.group_size_ratio = 0.25
    elif args.train_type == 'adalora':
        lora_kwargs.pop('lorap_lr_ratio', None)
        lora_kwargs['rank_pattern'] = None
        from swift.plugin.optimizer import calculate_max_steps
        adalora_config = AdaLoraConfig(
            task_type=task_type,
            **lora_kwargs,
            target_r=args.adalora_target_r,
            init_r=args.adalora_init_r,
            tinit=args.adalora_tinit,
            tfinal=args.adalora_tfinal,
            deltaT=args.adalora_deltaT,
            beta1=args.adalora_beta1,
            beta2=args.adalora_beta2,
            orth_reg_weight=args.adalora_orth_reg_weight,
            total_step=calculate_max_steps(args.training_args, train_dataset),
        )
        model = Swift.prepare_model(model, adalora_config)
        logger.info(f'adalora_config: {adalora_config}')
    elif args.train_type == 'llamapro':
        llamapro_config = LLaMAProConfig(
            model_type=model.model_meta.model_arch,
            num_new_blocks=args.llamapro_num_new_blocks,
            num_groups=args.llamapro_num_groups)
        model = Swift.prepare_model(model, llamapro_config)
        logger.info(f'llamapro_config: {llamapro_config}')
    elif args.train_type == 'adapter':
        model_arch = get_model_arch(model.model_meta.model_arch)
        mlp_key = model_arch.mlp
        # e.g. 'model.layers.{}.mlp' -> 'mlp': the per-layer MLP attribute name.
        mlp_key = mlp_key.split('.{}.')[1]
        adapter_config = AdapterConfig(
            dim=model.config.hidden_size,
            target_modules=[mlp_key],
            hidden_pos=0,
            adapter_length=args.adapter_length,
            act_layer=args.adapter_act)
        model = Swift.prepare_model(model, adapter_config)
        logger.info(f'adapter_config: {adapter_config}')
    elif args.train_type == 'vera':
        vera_config = VeraConfig(
            r=args.vera_rank,
            target_modules=target_modules,
            projection_prng_key=args.vera_projection_prng_key,
            vera_dropout=args.vera_dropout,
            d_initial=args.vera_d_initial,
            modules_to_save=args.modules_to_save,
        )
        # Vera needs all targeted linears to share one shape; prune mismatches.
        vera_config = get_vera_target_modules(model, vera_config)
        model = Swift.prepare_model(model, vera_config)
        logger.info(f'vera_config: {vera_config}')
    elif args.train_type == 'boft':
        boft_config = BOFTConfig(
            boft_block_size=args.boft_block_size,
            boft_block_num=args.boft_block_num,
            boft_n_butterfly_factor=args.boft_n_butterfly_factor,
            target_modules=target_modules,
            boft_dropout=args.boft_dropout,
            modules_to_save=args.modules_to_save,
        )
        model = Swift.prepare_model(model, boft_config)
        logger.info(f'boft_config: {boft_config}')
    elif args.train_type == 'fourierft':
        from peft import FourierFTConfig
        fourier_config = FourierFTConfig(
            target_modules=target_modules,
            modules_to_save=args.modules_to_save,
            n_frequency=args.fourier_n_frequency,
            scaling=args.fourier_scaling,
        )
        model = Swift.prepare_model(model, fourier_config)
        logger.info(f'fourier_config: {fourier_config}')
    elif args.train_type == 'reft':
        reft_config = ReftConfig(
            model_type=model.model_meta.model_arch,
            layer_key=args.reft_layer_key,
            r=args.reft_rank,
            layers=args.reft_layers,
            intervention_type=args.reft_intervention_type,
            args=args.reft_args,
        )
        logger.info(f'reft config: {reft_config}')
        model = Swift.prepare_model(model, {'reft': reft_config})
    elif args.train_type == 'bone':
        # NOTE(review): loose peft version requirement for BoneConfig — confirm minimum.
        from peft import BoneConfig
        bone_config = BoneConfig(
            target_modules=target_modules,
            r=args.reft_rank,
            init_weights=args.init_weights,
        )
        logger.info(f'bone config: {bone_config}')
        model = Swift.prepare_model(model, bone_config)
    return model
313
+
314
+
315
def torchacc_resume_from_checkpoint(args, model):
    """Load model weights from `args.resume_from_checkpoint` for torchacc training.

    Handles both single-file (pytorch_model.bin / model.safetensors) and
    sharded checkpoints; loading is non-strict.
    """
    import safetensors
    weights_file = os.path.join(args.resume_from_checkpoint, 'pytorch_model.bin')
    safe_weights_file = os.path.join(args.resume_from_checkpoint, 'model.safetensors')
    if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
        if args.save_safetensors and os.path.isfile(safe_weights_file):
            state_dict = safetensors.torch.load_file(safe_weights_file, device='cpu')
        else:
            state_dict = torch.load(weights_file, map_location='cpu')
        # Non-strict load: adapters/heads may legitimately be missing.
        model.load_state_dict(state_dict, False)
        del state_dict
    else:
        from transformers.modeling_utils import load_sharded_checkpoint
        # We load the sharded checkpoint
        load_result = load_sharded_checkpoint(
            model, args.resume_from_checkpoint, strict=False, prefer_safe=args.save_safetensors)
        if len(load_result.missing_keys) != 0:
            # Missing keys that are deliberately excluded from saving are handled
            # by re-tying weights; anything else is reported.
            if model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
                    model._keys_to_ignore_on_save):
                model.tie_weights()
            else:
                logger.warning(f'There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.')
        if len(load_result.unexpected_keys) != 0:
            logger.warning(f'There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}.')
339
+
340
+
341
class TunerMixin:
    """Mixin providing model wrapping (adapters / full training / GaLore / SP)."""

    @classmethod
    def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_type=None):
        """Freeze/wrap `model` according to `args.train_type` and return it."""
        # Only patch liger manually when transformers cannot do it itself.
        if args.use_liger_kernel and 'use_liger_kernel' not in inspect.signature(TrainingArguments).parameters:
            # Apply liger
            apply_liger(args.model_type)

        if args.is_adapter:
            if args.tuner_backend != 'unsloth' and args.train_type not in extra_tuners:
                # Fix the name of the layer in xcomposer that contains Plora.
                # Unsloth prepares and loads lora outside this function when
                # resume_from_checkpoint, so do not disable grad here
                model.requires_grad_(False)
            if args.resume_from_checkpoint:
                if args.train_type in extra_tuners:
                    tuner: Tuner = extra_tuners[args.train_type]
                else:
                    tuner = Swift
                kwargs = {}
                if use_torchacc():
                    kwargs = {'adapter_name': 'default'}
                model = tuner.from_pretrained(model, args.resume_from_checkpoint, is_trainable=True, **kwargs)
            else:
                if args.train_type in extra_tuners:
                    tuner: Tuner = extra_tuners[args.train_type]
                    model = tuner.prepare_model(args, model)
                else:
                    model = prepare_adapter(
                        args, model, template=template, train_dataset=train_dataset, task_type=task_type)
            # fix bug: Attempting to unscale FP16 gradients.
            # peft: https://github.com/huggingface/peft/issues/1249
            for p in model.parameters():
                if p.requires_grad and p.dtype == torch.float16:
                    logger.info_once('Convert trainable parameters from fp16 to fp32.')
                    p.data = p.data.to(dtype=torch.float32)
        elif args.train_type == 'full':
            model.train()
            model.requires_grad_(True)

            freeze_parameters(model, args.freeze_parameters_ratio, args.freeze_parameters,
                              args.freeze_parameters_regex)
            if len(args.trainable_parameters) > 0 or args.trainable_parameters_regex is not None:
                activate_parameters(model, args.trainable_parameters, args.trainable_parameters_regex)
            if use_torchacc() and args.resume_from_checkpoint:
                torchacc_resume_from_checkpoint(args, model)
        else:
            raise ValueError(f'args.train_type: {args.train_type}')

        if args.resume_only_model:
            # Weights were already restored above; don't let the trainer resume state.
            args.training_args.resume_from_checkpoint = None
        if args.use_galore:
            from swift.trainers.optimizers.galore import GaLoreConfig
            if args.galore_target_modules is None:
                args.galore_target_modules = find_all_linears(model)
            if args.galore_with_embedding:
                args.galore_target_modules += find_embedding(model)
            args.galore_config = GaLoreConfig(
                target_modules=args.galore_target_modules,
                rank=args.galore_rank,
                update_proj_gap=args.galore_update_proj_gap,
                galore_scale=args.galore_scale,
                proj_type=args.galore_proj_type,
                optim_per_parameter=args.galore_optim_per_parameter,
                quantize=args.galore_quantization,
                proj_quant=args.galore_proj_quant,
                proj_bits=args.galore_proj_bits,
                proj_group_size=args.galore_proj_group_size,
                cos_threshold=args.galore_cos_threshold,
                gamma_proj=args.galore_gamma_proj,
                queue_size=args.galore_queue_size,
            )
            args.training_args.galore_config = args.galore_config

        if args.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            if hasattr(model, 'model_meta'):
                is_multimodal = model.model_meta.is_multimodal
            else:
                is_multimodal = model.model.model_meta.is_multimodal
            # multimodal model must do split in basemodel's forward
            # or the media embedding may occur error
            sequence_parallel.prepare_model(model, template.tokenizer, split_in_forward=is_multimodal)

        return model
ms-swift/swift/megatron/argument/train_args.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+
7
+ from swift.llm import BaseArguments
8
+ from swift.llm.argument.base_args import to_abspath
9
+ from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master
10
+ from ..model import get_megatron_model_meta
11
+ from .megatron_args import MegatronArguments
12
+
13
+ logger = get_logger()
14
+
15
+
16
@dataclass
class MegatronTrainArguments(MegatronArguments, BaseArguments):
    """Training arguments for the Megatron backend.

    Combines the Megatron-specific options (``MegatronArguments``) with the
    generic swift options (``BaseArguments``).
    """
    # If True, the save directory is passed through add_version_to_work_dir
    # so repeated runs get distinct output paths.
    add_version: bool = True
    # dataset
    lazy_tokenize: bool = False
    packing: bool = False

    def init_model_args(self, config):
        """Fill still-unset Megatron fields from the HF model config, then finalize.

        Only attributes that are currently ``None`` are taken from the converted
        HF config, so explicit user-provided values take precedence.
        """
        self.megatron_model_meta = get_megatron_model_meta(self.model_type)
        kwargs = self.megatron_model_meta.convert_hf_config(config)
        for k, v in kwargs.items():
            if getattr(self, k) is None:
                setattr(self, k, v)
        # Run the Megatron post-init explicitly, now that every field is resolved.
        MegatronArguments.__post_init__(self)
        self.extra_args = self.parse_to_megatron()

    def _init_save(self):
        """Resolve the save directory and create it (on the master rank only)."""
        # The process group must be initialized before ranks resolve the save path.
        init_process_group()
        if self.save is None:
            self.save = f'megatron_output/{self.model_suffix}'
        self.save = to_abspath(self.save)
        if self.add_version:
            self.save = add_version_to_work_dir(self.save)
        logger.info(f'args.save: {self.save}')
        if is_master():
            os.makedirs(self.save, exist_ok=True)

    def __post_init__(self):
        # NOTE: statement order matters here — BaseArguments.__post_init__ is
        # called explicitly between the field fix-ups and _init_save().
        self.sequence_parallel_size = self.context_parallel_size
        self.load = to_abspath(self.load, check_path_exist=True)
        BaseArguments.__post_init__(self)
        self._init_save()
        self.seq_length = self.seq_length or self.max_length
        if self.streaming:
            # Streaming datasets feed Megatron through an external dataloader.
            self.dataloader_type = 'external'
            if self.num_workers > 1:
                self.num_workers = 1
                logger.info('Using streaming dataset, setting args.num_workers to 1.')
ms-swift/swift/megatron/model/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from . import gpt
3
+ from .constant import MegatronModelType
4
+ from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model
ms-swift/swift/megatron/model/config.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from typing import Any, Dict
3
+
4
+ from swift.utils import get_logger
5
+
6
+ logger = get_logger()
7
# Mapping from Megatron argument names to candidate HF config attribute names.
# For each Megatron key, the first HF attribute present on the config wins.
config_mapping = {
    'num_layers': ['num_hidden_layers'],
    'hidden_size': ['hidden_size'],
    'ffn_hidden_size': ['intermediate_size'],
    'num_attention_heads': ['num_attention_heads'],
    'num_query_groups': ['num_key_value_heads'],
    'max_position_embeddings': ['max_position_embeddings'],
    'norm_epsilon': ['rms_norm_eps'],
    'rotary_base': ['rope_theta'],
    'padded_vocab_size': ['vocab_size'],
    'attention_dropout': ['attention_dropout'],
    'untie_embeddings_and_output_weights': ['tie_word_embeddings'],
    'swiglu': ['hidden_act'],
    'add_qkv_bias': ['attention_bias'],
    'disable_bias_linear': ['mlp_bias'],
    'kv_channels': ['head_dim'],
    'model_type': ['model_type'],
    # moe
    'moe_ffn_hidden_size': ['moe_intermediate_size'],
    'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'],
    'moe_router_topk': ['num_experts_per_tok'],
    'num_experts': ['num_experts'],
    'moe_router_pre_softmax': ['norm_topk_prob'],
    'moe_aux_loss_coeff': ['router_aux_loss_coef'],
}


def convert_hf_config(config) -> Dict[str, Any]:
    """Translate a HF ``PretrainedConfig`` into Megatron argument kwargs.

    Flags with inverted semantics (``tie_word_embeddings``, ``mlp_bias``,
    ``norm_topk_prob``) are negated; ``hidden_act == 'silu'`` maps to
    ``swiglu=True`` (any other activation leaves ``swiglu`` unset).
    ``rope_scaling`` is forwarded as a dict; a bare-int value (legacy llama
    configs) is normalized into a linear-scaling dict.
    """
    megatron_config = {}
    for k, hf_keys in config_mapping.items():
        for hf_k in hf_keys:
            if hasattr(config, hf_k):
                hf_v = getattr(config, hf_k)
                if k == 'rotary_base':
                    megatron_config[k] = int(hf_v)
                elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}:
                    megatron_config[k] = not hf_v
                elif k == 'swiglu':
                    if hf_v == 'silu':
                        megatron_config[k] = True
                else:
                    megatron_config[k] = hf_v
                break
    # compat llama3
    if getattr(config, 'rope_scaling', None) is not None:
        if isinstance(config.rope_scaling, int):
            # Bugfix: the original line ended with a trailing comma, which
            # stored a 1-tuple containing the dict instead of the dict itself.
            # Provide both the legacy 'type' key and the modern 'rope_type' key
            # (the latter is what `update_rope_inv_freq` dispatches on).
            megatron_config['rope_scaling'] = {
                'factor': config.rope_scaling,
                'type': 'linear',
                'rope_type': 'linear',
            }
        elif isinstance(config.rope_scaling, dict):
            megatron_config['rope_scaling'] = config.rope_scaling
    logger.info(f'megatron_config: {megatron_config}')
    return megatron_config
ms-swift/swift/megatron/model/constant.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
class MegatronModelType:
    """Namespace of supported Megatron model-architecture identifiers."""
    gpt = 'gpt'
ms-swift/swift/megatron/model/gpt/__init__.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from swift.llm import ModelType
3
+ from ..constant import MegatronModelType
4
+ from ..register import MegatronModelMeta, register_megatron_model
5
+ from .config import convert_gpt_hf_config
6
+ from .hf2mcore import convert_hf2mcore
7
+ from .mcore2hf import convert_mcore2hf
8
+ from .model import model_provider
9
+
10
# Register the GPT-style (decoder-only) architectures supported by the
# Megatron backend, together with their model builder and config/weight
# converters.
register_megatron_model(
    MegatronModelMeta(MegatronModelType.gpt, [
        ModelType.qwen2,
        ModelType.qwen2_5,
        ModelType.qwq,
        ModelType.qwq_preview,
        ModelType.qwen2_5_math,
        ModelType.llama,
        ModelType.llama3,
        ModelType.llama3_1,
        ModelType.llama3_2,
        ModelType.longwriter_llama3_1,
        ModelType.codefuse_codellama,
        ModelType.marco_o1,
        ModelType.deepseek,
        ModelType.deepseek_r1_distill,
        ModelType.yi,
        ModelType.yi_coder,
        ModelType.sus,
        ModelType.skywork_o1,
        ModelType.openbuddy_llama,
        ModelType.openbuddy_llama3,
        ModelType.megrez,
        ModelType.reflection,
        ModelType.numina,
        ModelType.ziya,
        ModelType.mengzi3,
        ModelType.qwen3,
        ModelType.qwen2_moe,
        ModelType.qwen3_moe,
    ], model_provider, convert_gpt_hf_config, convert_mcore2hf, convert_hf2mcore))
ms-swift/swift/megatron/model/gpt/config.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+
3
+ from ..config import convert_hf_config
4
+
5
+
6
def convert_gpt_hf_config(config) -> Dict[str, Any]:
    """Convert a HF config to Megatron kwargs, with GPT-family specific tweaks."""
    megatron_kwargs = convert_hf_config(config)
    model_type = megatron_kwargs.get('model_type')
    # qwen3 family applies layernorm to the query/key projections.
    if model_type in ('qwen3', 'qwen3_moe'):
        megatron_kwargs['qk_layernorm'] = True
    # MoE variants: drop the dense ffn size (the moe_ffn_hidden_size entry is
    # presumably the one that applies here — TODO confirm).
    if model_type in ('qwen2_moe', 'qwen3_moe'):
        megatron_kwargs.pop('ffn_hidden_size', None)
    return megatron_kwargs
ms-swift/swift/megatron/model/gpt/model.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from megatron.core.models.gpt import GPTModel
3
+ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
4
+ from megatron.training import get_args
5
+ from megatron.training.arguments import core_transformer_config_from_args
6
+
7
+ from ..rope import update_rope_inv_freq
8
+
9
+
10
def model_provider(pre_process=True, post_process=True):
    """Build a Megatron-core ``GPTModel`` from the current global args.

    Args:
        pre_process: forwarded to ``GPTModel`` (first pipeline stage flag).
        post_process: forwarded to ``GPTModel`` (last pipeline stage flag).

    Returns:
        The constructed ``GPTModel``.
    """
    args = get_args()
    config = core_transformer_config_from_args(args)
    # Allow batches whose sequence length differs from the configured maximum.
    config.variable_seq_lengths = True
    transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm,
                                                                        args.qk_layernorm, args.multi_latent_attention)
    if args.num_experts and args.moe_shared_expert_intermediate_size:
        # qwen2_moe/qwen3_moe: enable the gate on the shared experts.
        transformer_layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True}
    model = GPTModel(
        config=config,
        transformer_layer_spec=transformer_layer_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=True,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent,
        rotary_base=args.rotary_base,
        rope_scaling=args.use_rope_scaling,
        rope_scaling_factor=args.rope_scaling_factor,
        seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor)
    # NOTE(review): GPTModel receives `args.use_rope_scaling` above, while the
    # manual inv_freq patch below keys off `args.rope_scaling` (a dict) —
    # presumably two distinct mechanisms; confirm they cannot conflict.
    if args.rope_scaling:
        update_rope_inv_freq(model.rotary_pos_emb.inv_freq, args.rope_scaling)
    return model
ms-swift/swift/megatron/model/register.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Callable, Dict, List, Optional
4
+
5
+ import torch.nn as nn
6
+ from transformers import PretrainedConfig
7
+
8
+ from swift.llm import MODEL_MAPPING, ModelGroup
9
+
10
+ MEGATRON_MODEL_MAPPING = {}
11
+
12
+
13
@dataclass
class MegatronModelMeta:
    """Registration record tying a Megatron architecture to its HF model types."""
    # Architecture identifier, e.g. MegatronModelType.gpt.
    megatron_model_type: str
    # HF `model_type` strings handled by this architecture.
    model_types: List[str]

    # Builds the Megatron module (reads the global Megatron args).
    model_provider: Callable[[], nn.Module]
    # Maps a HF PretrainedConfig to Megatron argument kwargs.
    convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]]
    # In-place weight converters between the (hf_model, mg_model) pair.
    convert_mcore2hf: Callable[[nn.Module, nn.Module], None]
    convert_hf2mcore: Callable[[nn.Module, nn.Module], None]
22
+
23
+
24
def register_megatron_model(megatron_model_meta: MegatronModelMeta, *, exist_ok: bool = False):
    """Register ``megatron_model_meta`` under its ``megatron_model_type``.

    Also marks every covered HF model type as Megatron-capable in MODEL_MAPPING.

    Raises:
        ValueError: if the megatron_model_type is already registered and
            ``exist_ok`` is False.
    """
    megatron_model_type = megatron_model_meta.megatron_model_type
    # Validate first: the original checked for duplicates only after flipping
    # `support_megatron` on MODEL_MAPPING entries, leaving global state
    # half-mutated when the registration was rejected.
    if not exist_ok and megatron_model_type in MEGATRON_MODEL_MAPPING:
        raise ValueError(f'The `{megatron_model_type}` has already been registered in the MEGATRON_MODEL_MAPPING.')
    for model_type in megatron_model_meta.model_types:
        model_meta = MODEL_MAPPING[model_type]
        model_meta.support_megatron = True

    MEGATRON_MODEL_MAPPING[megatron_model_type] = megatron_model_meta
33
+
34
+
35
# Lazily-built reverse index: HF model_type -> megatron_model_type.
_MODEL_META_MAPPING = None


def get_megatron_model_meta(model_type: str) -> Optional[MegatronModelMeta]:
    """Look up the MegatronModelMeta registered for a HF ``model_type``.

    Returns None when the model type has no Megatron support. The reverse
    index is built once, on the first call.
    """
    global _MODEL_META_MAPPING
    if _MODEL_META_MAPPING is None:
        _MODEL_META_MAPPING = {
            hf_model_type: mg_model_type
            for mg_model_type, meta in MEGATRON_MODEL_MAPPING.items()
            for hf_model_type in meta.model_types
        }
    mg_model_type = _MODEL_META_MAPPING.get(model_type)
    if mg_model_type is None:
        return None
    return MEGATRON_MODEL_MAPPING[mg_model_type]
ms-swift/swift/megatron/model/rope.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Any, Dict
3
+
4
+ import torch
5
+
6
+
7
+ def _to_llama3_rope(inv_freq: torch.Tensor, rope_scaling: Dict[str, Any]):
8
+ # copy from transformers
9
+ factor = rope_scaling['factor'] # `8` in the original implementation
10
+ low_freq_factor = rope_scaling['low_freq_factor'] # `1` in the original implementation
11
+ high_freq_factor = rope_scaling['high_freq_factor'] # `4` in the original implementation
12
+ old_context_len = rope_scaling['original_max_position_embeddings'] # `8192` in the original implementation
13
+
14
+ low_freq_wavelen = old_context_len / low_freq_factor
15
+ high_freq_wavelen = old_context_len / high_freq_factor
16
+
17
+ wavelen = 2 * math.pi / inv_freq
18
+ # wavelen < high_freq_wavelen: do nothing
19
+ # wavelen > low_freq_wavelen: divide by factor
20
+ inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
21
+ # otherwise: interpolate between the two, using a smooth factor
22
+ smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
23
+ smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
24
+ is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
25
+ inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
26
+ return inv_freq_llama
27
+
28
+
29
+ def _to_linear_rope(inv_freq: torch.Tensor, rope_scaling: Dict[str, Any]):
30
+ factor = rope_scaling['factor']
31
+ inv_freq /= factor
32
+ return inv_freq
33
+
34
+
35
# Dispatch table from `rope_type` to its inv_freq conversion function.
ROPE_MAPPING = {'llama3': _to_llama3_rope, 'linear': _to_linear_rope}


def update_rope_inv_freq(inv_freq: torch.Tensor, rope_scaling: Dict[str, Any]) -> None:
    """Rescale ``inv_freq`` in place according to ``rope_scaling['rope_type']``."""
    convert_fn = ROPE_MAPPING[rope_scaling['rope_type']]
    inv_freq.data.copy_(convert_fn(inv_freq, rope_scaling))
ms-swift/swift/megatron/train/patcher.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import os
3
+ from contextlib import contextmanager
4
+ from functools import wraps
5
+
6
+ import torch
7
+ from megatron.training import get_args, global_vars, initialize, training
8
+
9
+ from swift.utils import JsonlWriter, is_master
10
+
11
+
12
@contextmanager
def patch_training_log():
    """Context manager that tees Megatron training metrics into logging.jsonl.

    While active, wraps ``megatron.training.training.training_log`` so that on
    every ``log_interval`` iteration the master rank appends a JSON line with
    the losses, grad/params norms, learning rate, consumed samples and step
    progress to ``<args.save>/logging.jsonl``. The original function is always
    restored on exit.
    """
    jsonl_writer = None
    origin_training_log = training.training_log

    @wraps(origin_training_log)
    def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale,
                     report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad, *_args, **kwargs):
        nonlocal jsonl_writer
        args = get_args()
        if is_master() and iteration % args.log_interval == 0:
            logging_path = os.path.join(args.save, 'logging.jsonl')
            logs = {}
            for k, v in loss_dict.items():
                if isinstance(v, torch.Tensor):
                    v = v.item()
                logs[k] = round(v, 8)
            # Explicit mapping instead of the original's fragile `locals()[k]`
            # lookup, which broke silently if a parameter was ever renamed and
            # relied on CPython's locals() snapshot semantics.
            for k, v in {'grad_norm': grad_norm, 'params_norm': params_norm, 'learning_rate': learning_rate}.items():
                if v is not None:
                    logs[k] = round(v, 8)
            logs['consumed_samples'] = args.consumed_train_samples
            logs['global_step/max_steps'] = f'{iteration}/{args.train_iters}'
            if jsonl_writer is None:
                # Created lazily so `args.save` is resolved before first use.
                jsonl_writer = JsonlWriter(logging_path, enable_async=True)
            jsonl_writer.append(logs)
        return origin_training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration,
                                   loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm,
                                   num_zeros_in_grad, *_args, **kwargs)

    training.training_log = training_log
    try:
        yield
    finally:
        training.training_log = origin_training_log
+
48
+
49
+ @contextmanager
50
+ def patch_megatron_data_collator(data_collator):
51
+ origin_build_pretraining_data_loader = training.build_pretraining_data_loader
52
+
53
+ def build_pretraining_data_loader(*_args, **kwargs):
54
+ args = get_args()
55
+ res = origin_build_pretraining_data_loader(*_args, **kwargs)
56
+ if res is not None and args.dataloader_type != 'external':
57
+ res.collate_fn = data_collator
58
+ return res
59
+
60
+ training.build_pretraining_data_loader = build_pretraining_data_loader
61
+ try:
62
+ yield
63
+ finally:
64
+ training.build_pretraining_data_loader = origin_build_pretraining_data_loader
ms-swift/swift/megatron/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ from .convert import convert_hf2mcore, convert_mcore2hf
4
+ from .patcher import patch_megatron_tokenizer
ms-swift/swift/megatron/utils/convert.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+
3
+ import math
4
+
5
+ import torch
6
+ from megatron.training.checkpointing import load_checkpoint
7
+ from megatron.training.checkpointing import save_checkpoint as mg_save_checkpoint
8
+ from megatron.training.initialize import initialize_megatron
9
+ from megatron.training.utils import get_ltor_masks_and_position_ids
10
+
11
+ from swift.llm import ExportArguments, get_model_tokenizer, get_template, save_checkpoint
12
+ from swift.utils import get_logger, get_n_params_grads
13
+ from ..argument import MegatronArguments
14
+ from ..model import get_megatron_model_meta
15
+ from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard
16
+
17
+ logger = get_logger()
18
+
19
+
20
def test_convert_precision(hf_model, mg_model, processor):
    """Sanity-check a HF<->Megatron weight conversion by comparing logits.

    Runs both models in float32 on CUDA for a single prompt and asserts that
    the mean absolute logit difference is < 0.1 and that the argmax tokens
    match exactly. Both models are restored to their original dtype and moved
    back to CPU afterwards.
    """
    torch_dtype = hf_model.dtype
    template = get_template(hf_model.model_meta.template, processor)
    input_ids = template.encode({'messages': [{'role': 'user', 'content': 'who are you?'}]})['input_ids']
    input_ids = torch.tensor(input_ids)[None].to('cuda')
    # Forward in float32 to keep dtype noise out of the comparison.
    hf_model.to('cuda')
    hf_model.to(torch.float32)
    with torch.inference_mode():
        hf_logits = hf_model(input_ids).logits
    hf_model.to(torch_dtype)
    hf_model.to('cpu')

    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True)
    mg_model.to('cuda')
    mg_model.to(torch.float32)
    with torch.inference_mode():
        mg_logits = mg_model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
    mg_model.to(torch_dtype)
    mg_model.to('cpu')

    mean_diff = (mg_logits - hf_logits).abs().mean().item()
    max_diff = (mg_logits - hf_logits).abs().max().item()
    print(f'mean_diff: {mean_diff}, max_diff: {max_diff}')
    hf_tokens = hf_logits.argmax(-1)
    mg_tokens = mg_logits.argmax(-1)
    print(f'hf_tokens: {hf_tokens[0].tolist()}\nmg_tokens: {mg_tokens[0].tolist()}')
    assert mean_diff < 0.1
    assert (hf_tokens == mg_tokens).all()
48
+
49
+
50
# Megatron argument overrides used only during weight conversion:
# CPU initialization, no optimizer/rng state on save or load, and all
# fused kernels disabled.
convert_kwargs = {
    'use_cpu_initialization': True,
    'no_save_optim': True,
    'no_save_rng': True,
    'no_load_optim': True,
    'no_load_rng': True,
    'no_masked_softmax_fusion': True,
    'no_bias_dropout_fusion': True,
    'no_bias_swiglu_fusion': True,
    'no_rope_fusion': True
}
61
+
62
+
63
def convert_hf2mcore(args: ExportArguments) -> None:
    """Convert a HF checkpoint into a Megatron-core checkpoint at ``args.output_dir``.

    Loads the HF model, initializes Megatron with a config derived from the HF
    config, copies the weights across, optionally verifies logit parity, and
    saves the Megatron checkpoint.
    """
    kwargs = args.get_model_kwargs()
    hf_model, processor = get_model_tokenizer(**kwargs)
    if args.thread_count is None:
        # Approximate checkpoint size in GB; use one save thread per ~10GB, at least 2.
        checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
        args.thread_count = max(math.ceil(checkpoint_size / 10), 2)  # 10GB
    patch_torch_dist_shard(args.thread_count)

    megatron_model_meta = get_megatron_model_meta(args.model_type)
    assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
    kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
    megatron_args = MegatronArguments(**kwargs, **convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype)
    patch_megatron_tokenizer(processor)
    extra_args = megatron_args.parse_to_megatron()
    initialize_megatron(args_defaults=extra_args)

    mg_model = megatron_model_meta.model_provider()
    logger.info('Megatron model created successfully.')
    megatron_model_meta.convert_hf2mcore(hf_model, mg_model)
    if args.test_convert_precision:
        test_convert_precision(hf_model, mg_model, processor)
    logger.info('Successfully transferred HF model weights to MG model.')
    mg_save_checkpoint(1, [mg_model], None, None, 0)
    args.save_args()
    logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.')
88
+
89
+
90
def convert_mcore2hf(args: ExportArguments) -> None:
    """Convert a Megatron-core checkpoint back into HF format at ``args.output_dir``.

    Loads the HF skeleton model and the Megatron checkpoint from
    ``args.mcore_model``, copies the weights into the HF model, optionally
    verifies logit parity, and saves the HF checkpoint.
    """
    kwargs = args.get_model_kwargs()
    hf_model, processor = get_model_tokenizer(**kwargs)
    if args.thread_count is None:
        # Approximate checkpoint size in GB; use one save thread per ~10GB, at least 2.
        checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
        args.thread_count = max(math.ceil(checkpoint_size / 10), 2)  # 10GB
    patch_torch_dist_shard(args.thread_count)

    megatron_model_meta = get_megatron_model_meta(args.model_type)
    assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
    kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
    megatron_args = MegatronArguments(**kwargs, **convert_kwargs, load=args.mcore_model, torch_dtype=args.torch_dtype)
    patch_megatron_tokenizer(processor)
    extra_args = megatron_args.parse_to_megatron()
    initialize_megatron(args_defaults=extra_args)

    mg_model = megatron_model_meta.model_provider()
    load_checkpoint([mg_model], None, None, strict=True)
    logger.info('Megatron model created successfully.')
    megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
    if args.test_convert_precision:
        test_convert_precision(hf_model, mg_model, processor)
    logger.info('Successfully transferred MG model weights to HF model.')
    save_checkpoint(
        hf_model,
        processor,
        args.output_dir,
        safe_serialization=args.safe_serialization,
        model_dirs=[args.mcore_model, args.model_dir],
        max_shard_size=args.max_shard_size,
        additional_saved_files=hf_model.model_meta.additional_saved_files)
    args.save_args()
    logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
ms-swift/swift/megatron/utils/patcher.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
3
+ from megatron.training import get_args, global_vars, initialize, training
4
+
5
+ from swift.utils import get_logger
6
+
7
+ logger = get_logger()
8
+
9
+
10
def patch_megatron_tokenizer(tokenizer):
    """Replace Megatron's tokenizer factory so it returns ``tokenizer`` as-is."""

    def _patched_build_tokenizer(args):
        # Record how much the padded vocab exceeds the real tokenizer vocab.
        args.extra_vocab_size = args.padded_vocab_size - tokenizer.vocab_size
        return tokenizer

    global_vars.build_tokenizer = _patched_build_tokenizer
17
+
18
+
19
def patch_torch_dist_shard(thread_count):
    """Force every ``TorchDistSaveShardedStrategy`` to use ``thread_count`` threads."""
    origin_init = TorchDistSaveShardedStrategy.__init__

    def patched_init(*args, **kwargs):
        # Override whatever thread_count the caller supplied via kwargs.
        kwargs['thread_count'] = thread_count
        return origin_init(*args, **kwargs)

    TorchDistSaveShardedStrategy.__init__ = patched_init
ms-swift/swift/plugin/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.32 kB). View file
 
ms-swift/swift/plugin/__pycache__/callback.cpython-310.pyc ADDED
Binary file (1.31 kB). View file
 
ms-swift/swift/plugin/__pycache__/metric.cpython-310.pyc ADDED
Binary file (6.73 kB). View file