aoiandroid yanchaomars commited on
Commit
02d418e
·
0 Parent(s):

Duplicate from stepfun-ai/Step-Audio-EditX

Browse files

Co-authored-by: chao yan <yanchaomars@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/architechture.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/emotion-eval.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/test.wav filter=lfs diff=lfs merge=lfs -text
CosyVoice-300M-25Hz/FLOW_VERSION ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ /mnt/wby-jfs/models/train/flow_matching/flow_v2_1node_vq0206_dit_v8_fullattn_exp0227_sft_exp0408_stepaudio_sft_exp0616/model_epoch_5_whole.pt
2
+ fae53942e60310eb172b170396202069
CosyVoice-300M-25Hz/campplus.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6ac6a63997761ae2997373e2ee1c47040854b4b759ea41ec48e4e42df0f4d73
3
+ size 28303423
CosyVoice-300M-25Hz/cosyvoice.yaml ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mel_conf:
2
+ num_mels: 80
3
+ n_fft: 1920
4
+ hop_size: 480
5
+ win_size: 1920
6
+ sampling_rate: 24000
7
+ fmin: 0
8
+ fmax: 8000
9
+
10
+
11
+ flow: !new:stepvocoder.cosyvoice2.flow.flow.CausalMaskedDiffWithXvec
12
+ input_size: 512
13
+ output_size: 80
14
+ spk_embed_dim: 192
15
+ output_type: 'mel'
16
+ vocab_size: 5121 # 1024(vq02) + 4096(vq06) + 1(vq02-pad)
17
+ input_embedding: !new:stepvocoder.cosyvoice2.embedding.dual_codebook.DualCodebookEmbedding
18
+ vocab_size: 5121 # 1024(vq02) + 4096(vq06) + 1(vq02-pad)
19
+ input_size: 512
20
+ encoder: !new:stepvocoder.cosyvoice2.transformer.upsample_encoder_v2.UpsampleConformerEncoderV2
21
+ input_size: 512
22
+ output_size: 512
23
+ input_layer: 'linear'
24
+ pre_lookahead_len: 3
25
+ num_blocks: 6
26
+ num_up_blocks: 4
27
+ up_stride: 2
28
+ up_scale_factor: 2
29
+ attention_heads: 8
30
+ pos_enc_layer_type: 'rel_pos_espnet'
31
+ selfattention_layer_type: 'rel_selfattn'
32
+ key_bias: true
33
+ linear_units: 2048
34
+ dropout_rate: 0.1
35
+ positional_dropout_rate: 0.1
36
+ attention_dropout_rate: 0.1
37
+ normalize_before: True
38
+ decoder: !new:stepvocoder.cosyvoice2.flow.flow_matching.CausalConditionalCFM
39
+ inference_cfg_rate: 0.7
40
+ estimator: !new:stepvocoder.cosyvoice2.flow.decoder_dit.DiT
41
+ in_channels: 320
42
+ out_channels: 80
43
+ mlp_ratio: 4.0
44
+ depth: 16
45
+ num_heads: 8
46
+ head_dim: 64
47
+ hidden_size: 512
48
+
49
+
50
+ hift: !new:stepvocoder.cosyvoice2.hifigan.generator.HiFTGenerator
51
+ in_channels: 80
52
+ base_channels: 512
53
+ nb_harmonics: 8
54
+ sampling_rate: 24000
55
+ nsf_alpha: 0.1
56
+ nsf_sigma: 0.003
57
+ nsf_voiced_threshold: 10
58
+ upsample_rates: [8, 5, 3]
59
+ upsample_kernel_sizes: [16, 11, 7]
60
+ istft_params:
61
+ n_fft: 16
62
+ hop_len: 4
63
+ resblock_kernel_sizes: [3, 7, 11]
64
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
65
+ source_resblock_kernel_sizes: [7, 7, 11]
66
+ source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
67
+ lrelu_slope: 0.1
68
+ audio_limit: 0.99
69
+ f0_predictor: !new:stepvocoder.cosyvoice2.hifigan.f0_predictor.ConvRNNF0Predictor
70
+ num_class: 1
71
+ in_channels: 80
72
+ cond_channels: 512
CosyVoice-300M-25Hz/flow.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37f18fcb9c374bb8d8ae229e2f7618b6effaa208609bd0407fc661234125531c
3
+ size 615269316
CosyVoice-300M-25Hz/hift.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3386cc880324d4e98e05987b99107f49e40ed925b8ecc87c1f4939432d429879
3
+ size 83390254
CosyVoice-300M-25Hz/speech_tokenizer_v1.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486
3
+ size 522625011
README.md ADDED
@@ -0,0 +1,911 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Step-Audio-EditX
2
+ <p align="center">
3
+ <img src="assets/logo.png" height=100>
4
+ </p>
5
+
6
+ <div align="center">
7
+ <a href="https://stepaudiollm.github.io/step-audio-editx/"><img src="https://img.shields.io/static/v1?label=Demo%20Page&message=Web&color=green"></a> &ensp;
8
+ <a href="https://arxiv.org/abs/2511.03601"><img src="https://img.shields.io/static/v1?label=Tech%20Report&message=Arxiv&color=red"></a> &ensp;
9
+ <a href="https://huggingface.co/stepfun-ai/Step-Audio-EditX"><img src="https://img.shields.io/static/v1?label=Step-Audio-EditX&message=HuggingFace&color=yellow"></a> &ensp;
10
+ <a href="https://modelscope.cn/models/stepfun-ai/Step-Audio-EditX"><img src="https://img.shields.io/static/v1?label=Step-Audio-EditX&message=ModelScope&color=blue"></a> &ensp;
11
+ <a href="https://huggingface.co/spaces/stepfun-ai/Step-Audio-EditX"><img src="https://img.shields.io/static/v1?label=Space%20Playground&message=HuggingFace&color=yellow"></a> &ensp;
12
+ </div>
13
+
14
+ ## 🔥🔥🔥 News!!!
15
+ * Jan 23, 2026: 🌟 Training and inference for vLLM are now supported. Thanks to the vLLM team!
16
+ * Jan 23, 2026: 💻 We release the GRPO training code.
17
+ * Jan 23, 2026: 🧩 New Model Release: Now supporting more paralinguistic tags.
18
+ * Nov 28, 2025: 🚀 New Model Release: Now supporting **`Japanese`** and **`Korean`** languages.
19
+ * Nov 23, 2025: 📊 [Step-Audio-Edit-Benchmark](https://github.com/stepfun-ai/Step-Audio-Edit-Benchmark) Released!
20
+ * Nov 19, 2025: ⚙️ We release a **new version** of our model, which **supports polyphonic pronunciation control** and improves the performance of emotion, speaking style, and paralinguistic editing.
21
+ * Nov 12, 2025: 📦 We release the **optimized inference code** and **model weights** of **Step-Audio-EditX** ([HuggingFace](https://huggingface.co/stepfun-ai/Step-Audio-EditX); [ModelScope](https://modelscope.cn/models/stepfun-ai/Step-Audio-EditX)) and **Step-Audio-Tokenizer**([HuggingFace](https://huggingface.co/stepfun-ai/Step-Audio-Tokenizer); [ModelScope](https://modelscope.cn/models/stepfun-ai/Step-Audio-Tokenizer))
22
+ * Nov 07, 2025: ✨ [Demo Page](https://stepaudiollm.github.io/step-audio-editx/) ; 🎮 [HF Space Playground](https://huggingface.co/spaces/stepfun-ai/Step-Audio-EditX)
23
+ * Nov 06, 2025: 👋 We release the technical report of [Step-Audio-EditX](https://arxiv.org/abs/2511.03601).
24
+
25
+ ## Introduction
26
+ We are open-sourcing Step-Audio-EditX, a powerful **3B-parameter** LLM-based **Reinforcement Learning** audio model specialized in expressive and iterative audio editing. It excels at editing emotion, speaking style, and paralinguistics, and also features robust zero-shot text-to-speech (TTS) capabilities.
27
+
28
+ ## 📑 Open-source Plan
29
+ - [x] Inference Code
30
+ - [x] Online demo (Gradio)
31
+ - [x] Step-Audio-Edit-Benchmark
32
+ - [x] Model Checkpoints
33
+ - [x] Step-Audio-Tokenizer
34
+ - [x] Step-Audio-EditX
35
+ - [x] Step-Audio-EditX-Int4
36
+ - [ ] Training Code
37
+ - [x] GRPO training
38
+ - [ ] SFT training
39
+ - [ ] PPO training
40
+ - [ ] ⏳ Feature Support Plan
41
+ - [ ] Editing
42
+ - [x] Polyphone pronunciation control
43
+ - [x] More paralinguistic tags ([Cough, Crying, Stress, etc.])
44
+ - [ ] Filler word removal
45
+ - [ ] Other Languages
46
+ - [x] Japanese, Korean
47
+ - [ ] Arabic, French, Russian, Spanish, etc.
48
+
49
+ ## Online demonstration
50
+
51
+ ### StepFun Audio Studio
52
+
53
+ - Both Step-Audio-EditX are available in our [StepFun Audio Studio](https://www.stepfun.com/studio/audio).
54
+ - You will need an API key from the [StepFun Open Platform](https://platform.stepfun.com/).
55
+
56
+ ## WeChat group
57
+
58
+ You can scan the following QR code to join our WeChat group for communication and discussion.
59
+ <div align="center">
60
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/66518fd07d8cb2629a514c18/DRdnp1SN-yxhlNOfy26mE.jpeg" width="200" alt="QR code">
61
+ </div>
62
+
63
+ ## Features
64
+ - **Zero-Shot TTS**
65
+ - Excellent zero-shot TTS cloning for Mandarin, English, Sichuanese, and Cantonese.
66
+ - To use dialect or other languages, just add a **`[Sichuanese]`** / **`[Cantonese]`** / **`[Japanese]`** / **`[Korean]`** tag before your text.
67
+ - 🔥 Polyphone pronunciation control, all you need to do is replace the polyphonic characters with pinyin.
68
+ - **[我也想过过过儿过过的生活]** -> **[我也想guo4guo4guo1儿guo4guo4的生活]**
69
+
70
+
71
+ - **Emotion and Speaking Style Editing**
72
+ - Remarkably effective iterative control over emotions and styles, supporting **dozens** of options for editing.
73
+ - Emotion Editing : [ *Angry*, *Happy*, *Sad*, *Excited*, *Fearful*, *Surprised*, *Disgusted*, etc. ]
74
+ - Speaking Style Editing: [ *Act_coy*, *Older*, *Child*, *Whisper*, *Serious*, *Generous*, *Exaggerated*, etc.]
75
+ - Editing with more emotion and more speaking styles is on the way. **Get Ready!** 🚀
76
+
77
+
78
+ - **Paralinguistic Editing**
79
+ - Precise control over 10 types of paralinguistic features for more natural, human-like, and expressive synthetic audio.
80
+ - Supporting Tags:
81
+ - [ *Breathing*, *Laughter*, *Surprise-oh*, *Confirmation-en*, *Uhm*, *Surprise-ah*, *Surprise-wa*, *Sigh*, *Question-ei*, *Dissatisfaction-hnn* ]
82
+
83
+ - **Available Tags**
84
+ <table>
85
+ <tr>
86
+ <td rowspan="8" style="vertical-align: middle; text-align:center;" align="center">emotion</td>
87
+ <td align="center"><b>happy</b></td>
88
+ <td align="center">Expressing happiness</td>
89
+ <td align="center"><b>angry</b></td>
90
+ <td align="center">Expressing anger</td>
91
+ </tr>
92
+ <tr>
93
+ <td align="center"><b>sad</b></td>
94
+ <td align="center">Expressing sadness</td>
95
+ <td align="center"><b>fear</b></td>
96
+ <td align="center">Expressing fear</td>
97
+ </tr>
98
+ <tr>
99
+ <td align="center"><b>surprised</b></td>
100
+ <td align="center">Expressing surprise</td>
101
+ <td align="center"><b>confusion</b></td>
102
+ <td align="center">Expressing confusion</td>
103
+ </tr>
104
+ <tr>
105
+ <td align="center"><b>empathy</b></td>
106
+ <td align="center">Expressing empathy and understanding</td>
107
+ <td align="center"><b>embarrass</b></td>
108
+ <td align="center">Expressing embarrassment</td>
109
+ </tr>
110
+ <tr>
111
+ <td align="center"><b>excited</b></td>
112
+ <td align="center">Expressing excitement and enthusiasm</td>
113
+ <td align="center"><b>depressed</b></td>
114
+ <td align="center">Expressing a depressed or discouraged mood</td>
115
+ </tr>
116
+ <tr>
117
+ <td align="center"><b>admiration</b></td>
118
+ <td align="center">Expressing admiration or respect</td>
119
+ <td align="center"><b>coldness</b></td>
120
+ <td align="center">Expressing coldness and indifference</td>
121
+ </tr>
122
+ <tr>
123
+ <td align="center"><b>disgusted</b></td>
124
+ <td align="center">Expressing disgust or aversion</td>
125
+ <td align="center"><b>humour</b></td>
126
+ <td align="center">Expressing humor or playfulness</td>
127
+ </tr>
128
+ <tr>
129
+ </tr>
130
+ <tr>
131
+ <td rowspan="17" style="vertical-align: middle; text-align:center;" align="center">speaking style</td>
132
+ <td align="center"><b>serious</b></td>
133
+ <td align="center">Speaking in a serious or solemn manner</td>
134
+ <td align="center"><b>arrogant</b></td>
135
+ <td align="center">Speaking in an arrogant manner</td>
136
+ </tr>
137
+ <tr>
138
+ <td align="center"><b>child</b></td>
139
+ <td align="center">Speaking in a childlike manner</td>
140
+ <td align="center"><b>older</b></td>
141
+ <td align="center">Speaking in an elderly-sounding manner</td>
142
+ </tr>
143
+ <tr>
144
+ <td align="center"><b>girl</b></td>
145
+ <td align="center">Speaking in a light, youthful feminine manner</td>
146
+ <td align="center"><b>pure</b></td>
147
+ <td align="center">Speaking in a pure, innocent manner</td>
148
+ </tr>
149
+ <tr>
150
+ <td align="center"><b>sister</b></td>
151
+ <td align="center">Speaking in a mature, confident feminine manner</td>
152
+ <td align="center"><b>sweet</b></td>
153
+ <td align="center">Speaking in a sweet, lovely manner</td>
154
+ </tr>
155
+ <tr>
156
+ <td align="center"><b>exaggerated</b></td>
157
+ <td align="center">Speaking in an exaggerated, dramatic manner</td>
158
+ <td align="center"><b>ethereal</b></td>
159
+ <td align="center">Speaking in a soft, airy, dreamy manner</td>
160
+ </tr>
161
+ <tr>
162
+ <td align="center"><b>whisper</b></td>
163
+ <td align="center">Speaking in a whispering, very soft manner</td>
164
+ <td align="center"><b>generous</b></td>
165
+ <td align="center">Speaking in a hearty, outgoing, and straight-talking manner</td>
166
+ </tr>
167
+ <tr>
168
+ <td align="center"><b>recite</b></td>
169
+ <td align="center">Speaking in a clear, well-paced, poetry-reading manner</td>
170
+ <td align="center"><b>act_coy</b></td>
171
+ <td align="center">Speaking in a sweet, playful, and endearing manner</td>
172
+ </tr>
173
+ <tr>
174
+ <td align="center"><b>warm</b></td>
175
+ <td align="center">Speaking in a warm, friendly manner</td>
176
+ <td align="center"><b>shy</b></td>
177
+ <td align="center">Speaking in a shy, timid manner</td>
178
+ </tr>
179
+ <tr>
180
+ <td align="center"><b>comfort</b></td>
181
+ <td align="center">Speaking in a comforting, reassuring manner</td>
182
+ <td align="center"><b>authority</b></td>
183
+ <td align="center">Speaking in an authoritative, commanding manner</td>
184
+ </tr>
185
+ <tr>
186
+ <td align="center"><b>chat</b></td>
187
+ <td align="center">Speaking in a casual, conversational manner</td>
188
+ <td align="center"><b>radio</b></td>
189
+ <td align="center">Speaking in a radio-broadcast manner</td>
190
+ </tr>
191
+ <tr>
192
+ <td align="center"><b>soulful</b></td>
193
+ <td align="center">Speaking in a heartfelt, deeply emotional manner</td>
194
+ <td align="center"><b>gentle</b></td>
195
+ <td align="center">Speaking in a gentle, soft manner</td>
196
+ </tr>
197
+ <tr>
198
+ <td align="center"><b>story</b></td>
199
+ <td align="center">Speaking in a narrative, audiobook-style manner</td>
200
+ <td align="center"><b>vivid</b></td>
201
+ <td align="center">Speaking in a lively, expressive manner</td>
202
+ </tr>
203
+ <tr>
204
+ <td align="center"><b>program</b></td>
205
+ <td align="center">Speaking in a show-host/presenter manner</td>
206
+ <td align="center"><b>news</b></td>
207
+ <td align="center">Speaking in a news broadcasting manner</td>
208
+ </tr>
209
+ <tr>
210
+ <td align="center"><b>advertising</b></td>
211
+ <td align="center">Speaking in a polished, high-end commercial voiceover manner</td>
212
+ <td align="center"><b>roar</b></td>
213
+ <td align="center">Speaking in a loud, deep, roaring manner</td>
214
+ </tr>
215
+ <tr>
216
+ <td align="center"><b>murmur</b></td>
217
+ <td align="center">Speaking in a quiet, low manner</td>
218
+ <td align="center"><b>shout</b></td>
219
+ <td align="center">Speaking in a loud, sharp, shouting manner</td>
220
+ </tr>
221
+ <tr>
222
+ <td align="center"><b>deeply</b></td>
223
+ <td align="center">Speaking in a deep and low-pitched tone</td>
224
+ <td align="center"><b>loudly</b></td>
225
+ <td align="center">Speaking in a loud and high-pitched tone</td>
226
+ </tr>
227
+ <tr>
228
+ </tr>
229
+ <tr>
230
+ </tr>
231
+ <tr>
232
+ <td rowspan="11" style="vertical-align: middle; text-align:center;" align="center">paralinguistic</td>
233
+ <td align="center"><b>[sigh]</b></td>
234
+ <td align="center">Sighing sound</td>
235
+ <td align="center"><b>[inhale]</b></td>
236
+ <td align="center">Inhaling sound</td>
237
+ </tr>
238
+
239
+ <tr>
240
+ <td align="center"><b>[laugh]</b></td>
241
+ <td align="center">Laughter sound</td>
242
+ <td align="center"><b>[chuckle]</b></td>
243
+ <td align="center">Chuckling sound</td>
244
+ </tr>
245
+
246
+ <tr>
247
+ <td align="center"><b>[exhale]</b></td>
248
+ <td align="center">Exhaling sound</td>
249
+ <td align="center"><b>[clears throat]</b></td>
250
+ <td align="center">Throat clearing sound</td>
251
+ </tr>
252
+
253
+ <tr>
254
+ <td align="center"><b>[snort]</b></td>
255
+ <td align="center">Snorting sound</td>
256
+ <td align="center"><b>[giggle]</b></td>
257
+ <td align="center">Giggling sound</td>
258
+ </tr>
259
+
260
+ <tr>
261
+ <td align="center"><b>[cough]</b></td>
262
+ <td align="center">Coughing sound</td>
263
+ <td align="center"><b>[breath]</b></td>
264
+ <td align="center">Breathing sound</td>
265
+ </tr>
266
+
267
+ <tr>
268
+ <td align="center"><b>[uhm]</b></td>
269
+ <td align="center">Hesitation sound: "Uhm"</td>
270
+ <td align="center"><b>[Confirmation-en]</b></td>
271
+ <td align="center">Confirming: "En"</td>
272
+ </tr>
273
+
274
+ <tr>
275
+ <td align="center"><b>[Surprise-oh]</b></td>
276
+ <td align="center">Expressing surprise: "Oh"</td>
277
+ <td align="center"><b>[Surprise-ah]</b></td>
278
+ <td align="center">Expressing surprise: "Ah"</td>
279
+ </tr>
280
+
281
+ <tr>
282
+ <td align="center"><b>[Surprise-wa]</b></td>
283
+ <td align="center">Expressing surprise: "Wa"</td>
284
+ <td align="center"><b>[Surprise-yo]</b></td>
285
+ <td align="center">Expressing surprise: "Yo"</td>
286
+ </tr>
287
+
288
+ <tr>
289
+ <td align="center"><b>[Dissatisfaction-hnn]</b></td>
290
+ <td align="center">Dissatisfied sound: "Hnn"</td>
291
+ <td align="center"><b>[Question-ei]</b></td>
292
+ <td align="center">Questioning: "Ei"</td>
293
+ </tr>
294
+
295
+ <tr>
296
+ <td align="center"><b>[Question-ah]</b></td>
297
+ <td align="center">Questioning: "Ah"</td>
298
+ <td align="center"><b>[Question-en]</b></td>
299
+ <td align="center">Questioning: "En"</td>
300
+ </tr>
301
+
302
+ <tr>
303
+ <td align="center"><b>[Question-yi]</b></td>
304
+ <td align="center">Questioning: "Yi"</td>
305
+ <td align="center"><b>[Question-oh]</b></td>
306
+ <td align="center">Questioning: "Oh"</td>
307
+ </tr>
308
+ </table>
309
+
310
+ ## Feature Requests & Wishlist
311
+ 💡 We welcome all ideas for new features! If you'd like to see a feature added to the project, please start a discussion in our [Discussions](https://github.com/stepfun-ai/Step-Audio-EditX/discussions) section.
312
+
313
+ We'll be collecting community feedback here and will incorporate popular suggestions into our future development plans. Thank you for your contribution!
314
+
315
+ ## Demos
316
+
317
+ <table>
318
+ <tr>
319
+ <th style="vertical-align : middle;text-align: center">Task</th>
320
+ <th style="vertical-align : middle;text-align: center">Text</th>
321
+ <th style="vertical-align : middle;text-align: center">Source</th>
322
+ <th style="vertical-align : middle;text-align: center">Edited</th>
323
+ </tr>
324
+
325
+ <tr>
326
+ <td align="center"> Emotion-Fear</td>
327
+ <td align="center"> 我总觉得,有人在跟着我,我能听到奇怪的脚步声。</td>
328
+ <td align="center">
329
+
330
+ [fear_zh_female_prompt.webm](https://github.com/user-attachments/assets/a088c059-032c-423f-81d6-3816ba347ff5)
331
+ </td>
332
+ <td align="center">
333
+
334
+ [fear_zh_female_output.webm](https://github.com/user-attachments/assets/917494ac-5913-4949-8022-46cf55ca05dd)
335
+ </td>
336
+ </tr>
337
+
338
+
339
+ <tr>
340
+ <td align="center"> Style-Whisper</td>
341
+ <td align="center"> 比如在工作间隙,做一些简单的伸展运动,放松一下身体,这样,会让你更有精力。</td>
342
+ <td align="center">
343
+
344
+ [whisper_prompt.webm](https://github.com/user-attachments/assets/ed9e22f1-1bac-417b-913a-5f1db31f35c9)
345
+ </td>
346
+ <td align="center">
347
+
348
+ [whisper_output.webm](https://github.com/user-attachments/assets/e0501050-40db-4d45-b380-8bcc309f0b5f)
349
+ </td>
350
+ </tr>
351
+
352
+ <tr>
353
+ <td align="center"> Style-Act_coy</td>
354
+ <td align="center"> 我今天想喝奶茶,可是不知道喝什么口味,你帮我选一下嘛,你选的都好喝~</td>
355
+ <td align="center">
356
+
357
+ [act_coy_prompt.webm](https://github.com/user-attachments/assets/74d60625-5b3c-4f45-becb-0d3fe7cc4b3f)
358
+ </td>
359
+ <td align="center">
360
+
361
+ [act_coy_output.webm](https://github.com/user-attachments/assets/b2f74577-56c2-4997-afd6-6bf47d15ea51)
362
+ </td>
363
+ </tr>
364
+
365
+
366
+ <tr>
367
+ <td align="center"> Paralinguistics</td>
368
+ <td align="center"> 你这次又忘记带钥匙了 [Dissatisfaction-hnn],真是拿你没办法。</td>
369
+ <td align="center">
370
+
371
+ [paralingustic_prompt.webm](https://github.com/user-attachments/assets/21e831a3-8110-4c64-a157-60e0cf6735f0)
372
+ </td>
373
+ <td align="center">
374
+
375
+ [paralingustic_output.webm](https://github.com/user-attachments/assets/a82f5a40-c6a3-409b-bbe6-271180b20d7b)
376
+ </td>
377
+ </tr>
378
+
379
+
380
+ <tr>
381
+ <td align="center"> Denoising</td>
382
+ <td align="center"> Such legislation was clarified and extended from time to time thereafter. No, the man was not drunk, he wondered how we got tied up with this stranger. Suddenly, my reflexes had gone. It's healthier to cook without sugar.</td>
383
+ <td align="center">
384
+
385
+ [denoising_prompt.webm](https://github.com/user-attachments/assets/70464bf4-ebde-44a3-b2a6-8c292333319b)
386
+ </td>
387
+ <td align="center">
388
+
389
+ [denoising_output.webm](https://github.com/user-attachments/assets/7cd0ae8d-1bf0-40fc-9bcd-f419bd4b2d21)
390
+ </td>
391
+ </tr>
392
+
393
+ <tr>
394
+ <td align="center"> Speed-Faster</td>
395
+ <td align="center"> 上次你说鞋子有点磨脚,我给你买了一双软软的鞋垫。</td>
396
+ <td align="center">
397
+
398
+ [speed_faster_prompt.webm](https://github.com/user-attachments/assets/db46609e-1b98-48d8-99c8-e166cfdfc6e3)
399
+ </td>
400
+ <td align="center">
401
+
402
+ [speed_faster_output.webm](https://github.com/user-attachments/assets/0fbc14ca-dd4a-4362-aadc-afe0629f4c9f)
403
+ </td>
404
+ </tr>
405
+
406
+ </table>
407
+
408
+
409
+ For more examples, see [demo page](https://stepaudiollm.github.io/step-audio-editx/).
410
+
411
+ ## Model Download
412
+
413
+ | Models | 🤗 Hugging Face | ModelScope |
414
+ |-------|-------|-------|
415
+ | Step-Audio-EditX | [stepfun-ai/Step-Audio-EditX](https://huggingface.co/stepfun-ai/Step-Audio-EditX) | [stepfun-ai/Step-Audio-EditX](https://modelscope.cn/models/stepfun-ai/Step-Audio-EditX) |
416
+ | Step-Audio-EditX | [stepfun-ai/Step-Audio-EditX-AWQ-4bit](https://huggingface.co/stepfun-ai/Step-Audio-EditX-AWQ-4bit) | [stepfun-ai/Step-Audio-EditX-AWQ-4bit](https://modelscope.cn/models/stepfun-ai/Step-Audio-EditX-AWQ-4bit) |
417
+ | Step-Audio-Tokenizer | [stepfun-ai/Step-Audio-Tokenizer](https://huggingface.co/stepfun-ai/Step-Audio-Tokenizer) | [stepfun-ai/Step-Audio-Tokenizer](https://modelscope.cn/models/stepfun-ai/Step-Audio-Tokenizer) |
418
+
419
+
420
+ ## Model Usage
421
+ ### 📜 Requirements
422
+ The following table shows the requirements for running Step-Audio-EditX model (batch size = 1):
423
+
424
+ | Model | Parameters | Setting<br/>(sample frequency) | GPU Optimal Memory |
425
+ |------------|------------|--------------------------------|----------------|
426
+ | Step-Audio-EditX | 3B| 41.6Hz | 12 GB |
427
+
428
+ * An NVIDIA GPU with CUDA support is required.
429
+ * The model is tested on a single L40S GPU.
430
+ * 12GB is just a critical value, and 16GB GPU memory shoule be safer.
431
+ * Tested operating system: Linux
432
+
433
+ ### 🔧 Dependencies and Installation
434
+ - Python >= 3.12
435
+ - [PyTorch >= 2.9.1](https://pytorch.org/)
436
+ - [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads)
437
+
438
+ ```bash
439
+ git clone https://github.com/stepfun-ai/Step-Audio-EditX.git
440
+
441
+ cd Step-Audio-EditX
442
+ uv sync --refresh
443
+ source .venv/bin/activate
444
+
445
+ git lfs install
446
+ git clone https://huggingface.co/stepfun-ai/Step-Audio-Tokenizer
447
+ git clone https://huggingface.co/stepfun-ai/Step-Audio-EditX
448
+ git clone https://huggingface.co/stepfun-ai/Step-Audio-EditX-AWQ-4bit/
449
+
450
+ ```
451
+
452
+ After downloading the models, where_you_download_dir should have the following structure:
453
+ ```
454
+ where_you_download_dir
455
+ ├── Step-Audio-Tokenizer
456
+ ├── Step-Audio-EditX
457
+ ```
458
+
459
+ #### Run with Docker
460
+
461
+ You can set up the environment required for running Step-Audio-EditX using the provided Dockerfile.
462
+
463
+ ```bash
464
+ # build docker
465
+ docker build . -t step-audio-editx
466
+
467
+ # run docker
468
+ docker run --rm --gpus all \
469
+ -v /your/code/path:/app \
470
+ -v /your/model/path:/model \
471
+ -p 7860:7860 \
472
+ step-audio-editx
473
+ ```
474
+ #### Local Inference Demo
475
+ > [!TIP]
476
+ > For optimal performance, keep audio under 30 seconds per inference.
477
+
478
+ ```bash
479
+ # zero-shot cloning
480
+ # The path of the generated audio file is output/fear_zh_female_prompt_cloned.wav
481
+ python3 tts_infer.py \
482
+ --model-path where_you_download_dir \
483
+ --tokenizer-path where_you_download_dir \
484
+ --prompt-text "我总觉得,有人在跟着我,我能听到奇怪的脚步声。" \
485
+ --prompt-audio "examples/fear_zh_female_prompt.wav" \
486
+ --generated-text "可惜没有如果,已经发生的事情终究是发生了。" \
487
+ --edit-type "clone" \
488
+ --output-dir ./output
489
+
490
+ python3 tts_infer.py \
491
+ --model-path where_you_download_dir \
492
+ --tokenizer-path where_you_download_dir \
493
+ --prompt-text "His political stance was conservative, and he was particularly close to margaret thatcher." \
494
+ --prompt-audio "examples/zero_shot_en_prompt.wav" \
495
+ --generated-text "Underneath the courtyard is a large underground exhibition room which connects the two buildings. " \
496
+ --edit-type "clone" \
497
+ --output-dir ./output
498
+
499
+ # edit
500
+ # There will be one or multiple wave files corresponding to each edit iteration, for example: output/fear_zh_female_prompt_edited_iter1.wav, output/fear_zh_female_prompt_edited_iter2.wav, ...
501
+ # emotion; fear
502
+ python3 tts_infer.py \
503
+ --model-path where_you_download_dir \
504
+ --tokenizer-path where_you_download_dir \
505
+ --prompt-text "我总觉得,有人在跟着我,我能听到奇怪的脚步声。" \
506
+ --prompt-audio "examples/fear_zh_female_prompt.wav" \
507
+ --edit-type "emotion" \
508
+ --edit-info "fear" \
509
+ --output-dir ./output
510
+
511
+ # emotion; happy
512
+ python3 tts_infer.py \
513
+ --model-path where_you_download_dir \
514
+ --tokenizer-path where_you_download_dir \
515
+ --prompt-text "You know, I just finished that big project and feel so relieved. Everything seems easier and more colorful, what a wonderful feeling!" \
516
+ --prompt-audio "examples/en_happy_prompt.wav" \
517
+ --edit-type "emotion" \
518
+ --edit-info "happy" \
519
+ --output-dir ./output
520
+
521
+ # style; whisper
522
+ # for style whisper, the edit iteration num should be set bigger than 1 to get better results.
523
+ python3 tts_infer.py \
524
+ --model-path where_you_download_dir \
525
+ --tokenizer-path where_you_download_dir \
526
+ --prompt-text "比如在工作间隙,做一些简单的伸展运动,放松一下身体,这样,会让你更有精力." \
527
+ --prompt-audio "examples/whisper_prompt.wav" \
528
+ --edit-type "style" \
529
+ --edit-info "whisper" \
530
+ --output-dir ./output
531
+
532
+ # paraliguistic
533
+ # supported tags, Breathing, Laughter, Surprise-oh, Confirmation-en, Uhm, Surprise-ah, Surprise-wa, Sigh, Question-ei, Dissatisfaction-hnn
534
+ python3 tts_infer.py \
535
+ --model-path where_you_download_dir \
536
+ --tokenizer-path where_you_download_dir \
537
+ --prompt-text "我觉得这个计划大概是可行的,不过还需要再仔细考虑一下。" \
538
+ --prompt-audio "examples/paralingustic_prompt.wav" \
539
+ --generated-text "我觉得这个计划大概是可行的,[Uhm]不过还需要再仔细考虑一下。" \
540
+ --edit-type "paralinguistic" \
541
+ --output-dir ./output
542
+
543
+ # denoise
544
+ # Prompt text is not needed.
545
+ python3 tts_infer.py \
546
+ --model-path where_you_download_dir \
547
+ --tokenizer-path where_you_download_dir \
548
+ --prompt-audio "examples/denoise_prompt.wav"\
549
+ --edit-type "denoise" \
550
+ --output-dir ./output
551
+
552
+ # vad
553
+ # Prompt text is not needed.
554
+ python3 tts_infer.py \
555
+ --model-path where_you_download_dir \
556
+ --tokenizer-path where_you_download_dir \
557
+ --prompt-audio "examples/vad_prompt.wav" \
558
+ --edit-type "vad" \
559
+ --output-dir ./output
560
+
561
+ # speed
562
+ # supported edit-info: faster, slower, more faster, more slower
563
+ python3 tts_infer.py \
564
+ --model-path where_you_download_dir \
565
+ --tokenizer-path where_you_download_dir \
566
+ --prompt-text "上次你说鞋子有点磨脚,我给你买了一双软软的鞋垫。" \
567
+ --prompt-audio "examples/speed_prompt.wav" \
568
+ --edit-type "speed" \
569
+ --edit-info "more faster" \
570
+ --output-dir ./output
571
+
572
+ ```
573
+
574
+
575
+
576
+ #### Launch Web Demo
577
+ Start a local server for online inference.
578
+ Assume you have one GPU with at least 12GB memory available and have already downloaded all the models.
579
+
580
+ ```bash
581
+ # Standard launch
582
+ python app.py --model-path where_you_download_dir --tokenizer-path where_you_download_dir --model-source local
583
+
584
+ # Using pre-quantized AWQ 4-bit models, memory-efficient mode (for limited GPU memory, ~6-8GB usage)
585
+ python app.py \
586
+ --model-path path/to/quantized/model \
587
+ --tokenizer-path where_you_download_dir \
588
+ --model-source local \
589
+ --gpu-memory-utilization 0.1 \
590
+ --enforce-eager \
591
+ --max-num-seqs 1 \
592
+ --cosyvoice-dtype bfloat16 \
593
+ --no-cosyvoice-cuda-graph
594
+
595
+ ```
596
+
597
+ ##### Available Parameters
598
+
599
+ | Parameter | Default | Description |
600
+ |-----------|---------|-------------|
601
+ | `--model-path` | (required) | Path to the model directory |
602
+ | `--model-source` | `auto` | Model source: `auto`, `local`, `modelscope`, `huggingface` |
603
+ | `--gpu-memory-utilization` | `0.5` | GPU memory ratio for vLLM KV cache (0.0-1.0) |
604
+ | `--max-model-len` | `3072` | Maximum sequence length, affects KV cache size |
605
+ | `--enforce-eager` | `True` | Disable vLLM CUDA Graphs (saves ~0.5GB memory) |
606
+ | `--max-num-seqs` | `1` | Maximum concurrent sequences (vLLM default: 256, lower = less memory) |
607
+ | `--dtype` | `bfloat16` | Model dtype: `float16`, `bfloat16` |
608
+ | `--quantization` | `None` | Quantization method: `awq`, `gptq`, `fp8` |
609
+ | `--cosyvoice-dtype` | `bfloat16` | CosyVoice vocoder dtype: `float32`, `bfloat16`, `float16` |
610
+ | `--no-cosyvoice-cuda-graph` | `False` | Disable CosyVoice CUDA Graphs (saves memory) |
611
+ | `--enable-auto-transcribe` | `False` | Enable automatic audio transcription |
612
+
613
+ ##### Memory Usage Guide
614
+
615
+ | Configuration | Estimated GPU Memory | Use Case |
616
+ |--------------|---------------------|----------|
617
+ | Standard (defaults) | ~12-15 GB | Best quality and speed |
618
+ | Memory-efficient | ~6-8 GB | Limited GPU memory, some quality trade-off |
619
+ | AWQ 4-bit quantized | ~8-10 GB | Good balance of quality and memory |
620
+
621
+ ## Training
622
+ Please refer to script/ReadMe.md
623
+
624
+ ### 🔄 Model Quantization (Optional)
625
+
626
+ For users with limited GPU memory, you can create quantized versions of the model to reduce memory requirements:
627
+
628
+ ```bash
629
+ # Create an AWQ 4-bit quantized model
630
+ python quantization/awq_quantize.py --model_path path/to/Step-Audio-EditX
631
+
632
+ # Advanced quantization options
633
+ python quantization/awq_quantize.py
634
+ ```
635
+
636
+ For detailed quantization options and parameters, see [quantization/README.md](quantization/README.md).
637
+
638
+
639
+ ## Technical Details
640
+ <img src="assets/architechture.png" width=900>
641
+ Step-Audio-EditX comprises three primary components:
642
+
643
+ - A dual-codebook audio tokenizer, which converts reference or input audio into discrete tokens.
644
+ - An audio LLM that generates dual-codebook token sequences.
645
+ - An audio decoder, which converts the dual-codebook token sequences predicted by the audio LLM back into audio waveforms using a flow matching approach.
646
+
647
+ Audio-Edit enables iterative control over emotion and speaking style across all voices, leveraging large-margin data during SFT and PPO training.
648
+
649
+ ## Evaluation
650
+
651
+ ### Comparison between Step-Audio-EditX and Closed-Source models.
652
+
653
+ - Step-Audio-EditX demonstrates superior performance over Minimax and Doubao in both zero-shot cloning and emotion control.
654
+ - Emotion editing of Step-Audio-EditX significantly improves the emotion-controlled audio outputs of all three models after just one iteration. With further iterations, their overall performance continues to improve.
655
+
656
+ <div align="center">
657
+ <img src="assets/emotion-eval.png" width=800 >
658
+ </div>
659
+
660
+ ### Generalization on Closed-Source Models.
661
+ - For emotion and speaking style editing, the built-in voices of leading closed-source systems possess considerable in-context capabilities, allowing them to partially convey the emotions in the text. After a single editing round with Step-Audio-EditX, the emotion and style accuracy across all voice models exhibited significant improvement. Further enhancement was observed over the next two iterations, robustly demonstrating our model's strong generalization.
662
+
663
+ - For paralinguistic editing, after editing with Step-Audio-EditX, the performance of paralinguistic reproduction is comparable to that achieved by the built-in voices of closed-source models when synthesizing native paralinguistic content directly. (**sub** means replacement of paralinguistic tags with native words)
664
+
665
+
666
+ <div align="center">
667
+
668
+ <table border="1" cellspacing="0" cellpadding="5" style="border-collapse: collapse; font-family: sans-serif; width: auto;">
669
+ <caption><b>Table: Generalization of Emotion, Speaking Style, and Paralinguistic Editing on Closed-Source Models.</b></caption>
670
+ <thead>
671
+ <tr>
672
+ <th rowspan="2" align="center" style="vertical-align: bottom;">Language</th>
673
+ <th rowspan="2" align="center" style="vertical-align: bottom;">Model</th>
674
+ <th colspan="4" style="border-bottom: 1px solid black;">Emotion &uarr;</th>
675
+ <th colspan="4" style="border-bottom: 1px solid black;">Speaking Style &uarr;</th>
676
+ <th colspan="3" style="border-bottom: 1px solid black; border-left: 1px solid black;">Paralinguistic &uarr;</th>
677
+ </tr>
678
+ <tr>
679
+ <th>Iter<sub>0</sub></th>
680
+ <th>Iter<sub>1</sub></th>
681
+ <th>Iter<sub>2</sub></th>
682
+ <th>Iter<sub>3</sub></th>
683
+ <th style="border-left: 1px solid #ccc;">Iter<sub>0</sub></th>
684
+ <th>Iter<sub>1</sub></th>
685
+ <th>Iter<sub>2</sub></th>
686
+ <th>Iter<sub>3</sub></th>
687
+ <th style="border-left: 1px solid black;">Iter<sub>0</sub></th>
688
+ <th>sub</th>
689
+ <th>Iter<sub>1</sub></th>
690
+ </tr>
691
+ </thead>
692
+ <tbody>
693
+ <tr>
694
+ <td rowspan="4" align="center" style="font-weight: bold; vertical-align: middle;">Chinese</td>
695
+ <td align="left">MiniMax-2.6-hd</td>
696
+ <td align="center">71.6</td>
697
+ <td align="center">78.6</td>
698
+ <td align="center">81.2</td>
699
+ <td align="center"><b>83.4</b></td>
700
+ <td align="center" style="border-left: 1px solid #ccc;">36.7</td>
701
+ <td align="center">58.8</td>
702
+ <td align="center">63.1</td>
703
+ <td align="center"><b>67.3</b></td>
704
+ <td align="center" style="border-left: 1px solid black;">1.73</td>
705
+ <td align="center">2.80</td>
706
+ <td align="center">2.90</td>
707
+ </tr>
708
+ <tr>
709
+ <td align="left">Doubao-Seed-TTS-2.0</td>
710
+ <td align="center">67.4</td>
711
+ <td align="center">77.8</td>
712
+ <td align="center">80.6</td>
713
+ <td align="center"><b>82.8</b></td>
714
+ <td align="center" style="border-left: 1px solid #ccc;">38.2</td>
715
+ <td align="center">60.2</td>
716
+ <td align="center"><b>65.0</b></td>
717
+ <td align="center">64.9</td>
718
+ <td align="center" style="border-left: 1px solid black;">1.67</td>
719
+ <td align="center">2.81</td>
720
+ <td align="center">2.90</td>
721
+ </tr>
722
+ <tr>
723
+ <td align="left">GPT-4o-mini-TTS</td>
724
+ <td align="center">62.6</td>
725
+ <td align="center">76.0</td>
726
+ <td align="center">77.0</td>
727
+ <td align="center"><b>81.8</b></td>
728
+ <td align="center" style="border-left: 1px solid #ccc;">45.9</td>
729
+ <td align="center">64.0</td>
730
+ <td align="center">65.7</td>
731
+ <td align="center"><b>69.7</b></td>
732
+ <td align="center" style="border-left: 1px solid black;">1.71</td>
733
+ <td align="center">2.88</td>
734
+ <td align="center">2.93</td>
735
+ </tr>
736
+ <tr style="border-bottom: 1px solid black;">
737
+ <td align="left">ElevenLabs-v2</td>
738
+ <td align="center">60.4</td>
739
+ <td align="center">74.6</td>
740
+ <td align="center">77.4</td>
741
+ <td align="center"><b>79.2</b></td>
742
+ <td align="center" style="border-left: 1px solid #ccc;">43.8</td>
743
+ <td align="center">63.3</td>
744
+ <td align="center">69.7</td>
745
+ <td align="center"><b>70.8</b></td>
746
+ <td align="center" style="border-left: 1px solid black;">1.70</td>
747
+ <td align="center">2.71</td>
748
+ <td align="center">2.92</td>
749
+ </tr>
750
+ <tr>
751
+ <td rowspan="4" align="center" style="font-weight: bold; vertical-align: middle;">English</td>
752
+ <td align="left">MiniMax-2.6-hd</td>
753
+ <td align="center">55.0</td>
754
+ <td align="center">64.0</td>
755
+ <td align="center">64.2</td>
756
+ <td align="center"><b>66.4</b></td>
757
+ <td align="center" style="border-left: 1px solid #ccc;">51.9</td>
758
+ <td align="center">60.3</td>
759
+ <td align="center">62.3</td>
760
+ <td align="center"><b>64.3</b></td>
761
+ <td align="center" style="border-left: 1px solid black;">1.72</td>
762
+ <td align="center">2.87</td>
763
+ <td align="center">2.88</td>
764
+ </tr>
765
+ <tr>
766
+ <td align="left">Doubao-Seed-TTS-2.0</td>
767
+ <td align="center">53.8</td>
768
+ <td align="center">65.8</td>
769
+ <td align="center">65.8</td>
770
+ <td align="center"><b>66.2</b></td>
771
+ <td align="center" style="border-left: 1px solid #ccc;">47.0</td>
772
+ <td align="center">62.0</td>
773
+ <td align="center"><b>62.7</b></td>
774
+ <td align="center">62.3</td>
775
+ <td align="center" style="border-left: 1px solid black;">1.72</td>
776
+ <td align="center">2.75</td>
777
+ <td align="center">2.92</td>
778
+ </tr>
779
+ <tr>
780
+ <td align="left">GPT-4o-mini-TTS</td>
781
+ <td align="center">56.8</td>
782
+ <td align="center">61.4</td>
783
+ <td align="center">64.8</td>
784
+ <td align="center"><b>65.2</b></td>
785
+ <td align="center" style="border-left: 1px solid #ccc;">52.3</td>
786
+ <td align="center">62.3</td>
787
+ <td align="center">62.4</td>
788
+ <td align="center"><b>63.4</b></td>
789
+ <td align="center" style="border-left: 1px solid black;">1.90</td>
790
+ <td align="center">2.90</td>
791
+ <td align="center">2.88</td>
792
+ </tr>
793
+ <tr style="border-bottom: 1px solid black;">
794
+ <td align="left">ElevenLabs-v2</td>
795
+ <td align="center">51.0</td>
796
+ <td align="center">61.2</td>
797
+ <td align="center">64.0</td>
798
+ <td align="center"><b>65.2</b></td>
799
+ <td align="center" style="border-left: 1px solid #ccc;">51.0</td>
800
+ <td align="center">62.1</td>
801
+ <td align="center">62.6</td>
802
+ <td align="center"><b>64.0</b></td>
803
+ <td align="center" style="border-left: 1px solid black;">1.93</td>
804
+ <td align="center">2.87</td>
805
+ <td align="center">2.88</td>
806
+ </tr>
807
+ <tr>
808
+ <td rowspan="4" align="center" style="font-weight: bold; vertical-align: middle;">Average</td>
809
+ <td align="left">MiniMax-2.6-hd</td>
810
+ <td align="center">63.3</td>
811
+ <td align="center">71.3</td>
812
+ <td align="center">72.7</td>
813
+ <td align="center"><b>74.9</b></td>
814
+ <td align="center" style="border-left: 1px solid #ccc;">44.2</td>
815
+ <td align="center">59.6</td>
816
+ <td align="center">62.7</td>
817
+ <td align="center"><b>65.8</b></td>
818
+ <td align="center" style="border-left: 1px solid black;">1.73</td>
819
+ <td align="center">2.84</td>
820
+ <td align="center">2.89</td>
821
+ </tr>
822
+ <tr>
823
+ <td align="left">Doubao-Seed-TTS-2.0</td>
824
+ <td align="center">60.6</td>
825
+ <td align="center">71.8</td>
826
+ <td align="center">73.2</td>
827
+ <td align="center"><b>74.5</b></td>
828
+ <td align="center" style="border-left: 1px solid #ccc;">42.6</td>
829
+ <td align="center">61.1</td>
830
+ <td align="center"><b>63.9</b></td>
831
+ <td align="center">63.6</td>
832
+ <td align="center" style="border-left: 1px solid black;">1.70</td>
833
+ <td align="center">2.78</td>
834
+ <td align="center">2.91</td>
835
+ </tr>
836
+ <tr>
837
+ <td align="left">GPT-4o-mini-TTS</td>
838
+ <td align="center">59.7</td>
839
+ <td align="center">68.7</td>
840
+ <td align="center">70.9</td>
841
+ <td align="center"><b>73.5</b></td>
842
+ <td align="center" style="border-left: 1px solid #ccc;">49.1</td>
843
+ <td align="center">63.2</td>
844
+ <td align="center">64.1</td>
845
+ <td align="center"><b>66.6</b></td>
846
+ <td align="center" style="border-left: 1px solid black;">1.81</td>
847
+ <td align="center">2.89</td>
848
+ <td align="center">2.90</td>
849
+ </tr>
850
+ <tr>
851
+ <td align="left">ElevenLabs-v2</td>
852
+ <td align="center">55.7</td>
853
+ <td align="center">67.9</td>
854
+ <td align="center">70.7</td>
855
+ <td align="center"><b>72.2</b></td>
856
+ <td align="center" style="border-left: 1px solid #ccc;">47.4</td>
857
+ <td align="center">62.7</td>
858
+ <td align="center">66.1</td>
859
+ <td align="center"><b>67.4</b></td>
860
+ <td align="center" style="border-left: 1px solid black;">1.82</td>
861
+ <td align="center">2.79</td>
862
+ <td align="center">2.90</td>
863
+ </tr>
864
+ </tbody>
865
+ </table>
866
+
867
+ </div>
868
+
869
+
870
+ ## Acknowledgements
871
+
872
+ Part of the code and data for this project comes from:
873
+ * [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
874
+ * [transformers](https://github.com/huggingface/transformers)
875
+ * [FunASR](https://github.com/modelscope/FunASR)
876
+ * [NVSpeech](https://huggingface.co/datasets/amphion/Emilia-NV)
877
+ * [vllm](https://github.com/vllm-project/vllm)
878
+
879
+ Thank you to all the open-source projects for their contributions to this project!
880
+
881
+ ## License Agreement
882
+ + The code in this open-source repository is licensed under the [Apache 2.0](LICENSE) License.
883
+
884
+ ## Citation
885
+
886
+ ```
887
+ @misc{yan2025stepaudioeditxtechnicalreport,
888
+ title={Step-Audio-EditX Technical Report},
889
+ author={Chao Yan and Boyong Wu and Peng Yang and Pengfei Tan and Guoqiang Hu and Yuxin Zhang and Xiangyu and Zhang and Fei Tian and Xuerui Yang and Xiangyu Zhang and Daxin Jiang and Gang Yu},
890
+ year={2025},
891
+ eprint={2511.03601},
892
+ archivePrefix={arXiv},
893
+ primaryClass={cs.CL},
894
+ url={https://arxiv.org/abs/2511.03601},
895
+ }
896
+ ```
897
+
898
+
899
+ ## ⚠️ Usage Disclaimer
900
+ - Do not use this model for any unauthorized activities, including but not limited to:
901
+ - Voice cloning without permission
902
+ - Identity impersonation
903
+ - Fraud
904
+ - Deepfakes or any other illegal purposes
905
+ - Ensure compliance with local laws and regulations, and adhere to ethical guidelines when using this model.
906
+ - The model developers are not responsible for any misuse or abuse of this technology.
907
+
908
+ We advocate for responsible generative AI research and urge the community to uphold safety and ethical standards in AI development and application. If you have any concerns regarding the use of this model, please feel free to contact us.
909
+
910
+ ## Star History
911
+ [![Star History Chart](https://api.star-history.com/svg?repos=stepfun-ai/Step-Audio-EditX&type=Date)](https://star-history.com/#stepfun-ai/Step-Audio-EditX&Date)
assets/architechture.png ADDED

Git LFS Details

  • SHA256: 208ed5db4767e37eda66eec26818aa57cc8965051c7758a8e3e8e5e8af497833
  • Pointer size: 131 Bytes
  • Size of remote file: 211 kB
assets/emotion-eval.png ADDED

Git LFS Details

  • SHA256: 8370c9a1f1e7a513c0f67ec4b8d9ffc89daa91ff298cbe6e137709c28d70a582
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
assets/logo.png ADDED
assets/test.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7d3dc34ac3dd2f7765e61ba5e2023beb0f59cdf2acb14c42fa00fbddd13afa3
3
+ size 192558
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Step1ForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_step1.Step1Config",
7
+ "AutoModelForCausalLM": "modeling_step1.Step1ForCausalLM"
8
+ },
9
+ "model_type": "step1",
10
+ "bos_token_id": 1,
11
+ "pad_token_id": 0,
12
+ "eos_token_id": 3,
13
+ "hidden_size": 3072,
14
+ "intermediate_size": 8192,
15
+ "num_attention_heads": 48,
16
+ "num_attention_groups": 4,
17
+ "num_hidden_layers": 32,
18
+ "max_seq_len": 32768,
19
+ "vocab_size": 74752,
20
+ "rms_norm_eps": 1e-05,
21
+ "tie_word_embeddings": false,
22
+ "torch_dtype": "bfloat16",
23
+ "use_cache": true
24
+ }
configuration_step1.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List, Any, Dict
2
+ from transformers.configuration_utils import PretrainedConfig
3
+
4
+
5
+
6
+ class Step1Config(PretrainedConfig):
7
+ model_type = "step1"
8
+ keys_to_ignore_at_inference = ["past_key_values"]
9
+
10
+ def __init__(
11
+ self,
12
+ hidden_size: int = 5120,
13
+ intermediate_size: int = 13312,
14
+ num_attention_heads: int = 40,
15
+ num_attention_groups: int = 8,
16
+ num_hidden_layers: int = 48,
17
+ max_seq_len: int = 4096,
18
+ vocab_size: int = 65536,
19
+ rms_norm_eps: float = 1e-5,
20
+ bos_token_id: int = 1,
21
+ eos_token_id: int = 3,
22
+ pad_token_id: int = 0,
23
+ **kwargs,
24
+ ) -> None:
25
+ self.hidden_size = hidden_size
26
+ self.intermediate_size = intermediate_size
27
+ self.num_attention_heads = num_attention_heads
28
+ self.num_attention_groups = num_attention_groups
29
+ self.num_hidden_layers = num_hidden_layers
30
+ self.max_seq_len = max_seq_len
31
+ self.vocab_size = vocab_size
32
+ self.rms_norm_eps = rms_norm_eps
33
+ super().__init__(
34
+ bos_token_id=bos_token_id,
35
+ pad_token_id=pad_token_id,
36
+ eos_token_id=eos_token_id,
37
+ **kwargs
38
+ )
39
+
40
+
41
+ __all__ = ["Step1Config"]
model-00001.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb2c6baae3a40ccd19cf21f7a28629666e31c215b5d63b2ed1e04aac6dd08d69
3
+ size 7059446656
model.safetensors.index.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"metadata": {"total_size": 7059412992}, "weight_map": {"model.embed_tokens.weight": "model-00001.safetensors", "model.layers.0.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.0.input_layernorm.weight": "model-00001.safetensors", "model.layers.0.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.0.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.0.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.0.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.0.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.0.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.0.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.1.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.1.input_layernorm.weight": "model-00001.safetensors", "model.layers.1.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.1.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.1.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.1.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.1.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.1.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.1.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.2.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.2.input_layernorm.weight": "model-00001.safetensors", "model.layers.2.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.2.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.2.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.2.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.2.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.2.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.2.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.3.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.3.input_layernorm.weight": "model-00001.safetensors", "model.layers.3.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.3.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.3.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.3.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.3.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.3.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.3.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.4.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.4.input_layernorm.weight": "model-00001.safetensors", "model.layers.4.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.4.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.4.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.4.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.4.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.4.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.4.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.5.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.5.input_layernorm.weight": "model-00001.safetensors", "model.layers.5.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.5.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.5.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.5.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.5.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.5.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.5.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.6.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.6.input_layernorm.weight": "model-00001.safetensors", "model.layers.6.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.6.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.6.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.6.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.6.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.6.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.6.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.7.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.7.input_layernorm.weight": "model-00001.safetensors", "model.layers.7.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.7.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.7.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.7.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.7.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.7.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.7.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.8.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.8.input_layernorm.weight": "model-00001.safetensors", "model.layers.8.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.8.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.8.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.8.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.8.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.8.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.8.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.9.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.9.input_layernorm.weight": "model-00001.safetensors", "model.layers.9.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.9.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.9.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.9.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.9.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.9.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.9.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.10.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.10.input_layernorm.weight": "model-00001.safetensors", "model.layers.10.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.10.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.10.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.10.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.10.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.10.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.10.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.11.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.11.input_layernorm.weight": "model-00001.safetensors", "model.layers.11.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.11.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.11.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.11.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.11.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.11.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.11.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.12.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.12.input_layernorm.weight": "model-00001.safetensors", "model.layers.12.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.12.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.12.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.12.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.12.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.12.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.12.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.13.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.13.input_layernorm.weight": "model-00001.safetensors", "model.layers.13.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.13.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.13.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.13.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.13.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.13.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.13.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.14.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.14.input_layernorm.weight": "model-00001.safetensors", "model.layers.14.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.14.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.14.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.14.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.14.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.14.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.14.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.15.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.15.input_layernorm.weight": "model-00001.safetensors", "model.layers.15.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.15.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.15.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.15.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.15.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.15.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.15.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.16.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.16.input_layernorm.weight": "model-00001.safetensors", "model.layers.16.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.16.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.16.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.16.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.16.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.16.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.16.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.17.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.17.input_layernorm.weight": "model-00001.safetensors", "model.layers.17.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.17.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.17.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.17.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.17.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.17.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.17.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.18.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.18.input_layernorm.weight": "model-00001.safetensors", "model.layers.18.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.18.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.18.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.18.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.18.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.18.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.18.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.19.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.19.input_layernorm.weight": "model-00001.safetensors", "model.layers.19.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.19.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.19.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.19.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.19.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.19.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.19.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.20.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.20.input_layernorm.weight": "model-00001.safetensors", "model.layers.20.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.20.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.20.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.20.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.20.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.20.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.20.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.21.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.21.input_layernorm.weight": "model-00001.safetensors", "model.layers.21.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.21.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.21.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.21.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.21.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.21.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.21.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.22.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.22.input_layernorm.weight": "model-00001.safetensors", "model.layers.22.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.22.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.22.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.22.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.22.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.22.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.22.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.23.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.23.input_layernorm.weight": "model-00001.safetensors", "model.layers.23.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.23.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.23.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.23.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.23.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.23.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.23.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.24.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.24.input_layernorm.weight": "model-00001.safetensors", "model.layers.24.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.24.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.24.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.24.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.24.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.24.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.24.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.25.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.25.input_layernorm.weight": "model-00001.safetensors", "model.layers.25.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.25.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.25.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.25.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.25.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.25.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.25.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.26.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.26.input_layernorm.weight": "model-00001.safetensors", "model.layers.26.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.26.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.26.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.26.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.26.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.26.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.26.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.27.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.27.input_layernorm.weight": "model-00001.safetensors", "model.layers.27.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.27.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.27.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.27.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.27.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.27.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.27.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.28.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.28.input_layernorm.weight": "model-00001.safetensors", "model.layers.28.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.28.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.28.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.28.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.28.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.28.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.28.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.29.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.29.input_layernorm.weight": "model-00001.safetensors", "model.layers.29.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.29.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.29.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.29.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.29.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.29.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.29.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.30.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.30.input_layernorm.weight": "model-00001.safetensors", "model.layers.30.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.30.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.30.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.30.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.30.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.30.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.30.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.31.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.31.input_layernorm.weight": "model-00001.safetensors", "model.layers.31.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.31.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.31.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.31.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.31.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.31.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.31.mlp.up_proj.weight": "model-00001.safetensors", "model.norm.weight": "model-00001.safetensors", "lm_head.weight": "model-00001.safetensors"}}
modeling_step1.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional, Tuple, Union, List
3
+
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+ from transformers.generation import GenerationMixin
8
+
9
+ from transformers.modeling_utils import PreTrainedModel
10
+ from transformers.utils import logging
11
+ from .configuration_step1 import Step1Config
12
+ from transformers.cache_utils import Cache, DynamicCache
13
+ from einops import rearrange
14
+ from transformers.modeling_outputs import (
15
+ BaseModelOutputWithPast,
16
+ CausalLMOutputWithPast,
17
+ )
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ def build_alibi_cache(block_size, n_heads, dtype, device):
23
+ # get slopes
24
+ n = 2 ** math.floor(math.log2(n_heads)) # nearest 2**n to n_heads
25
+ m0 = 2.0 ** (-8.0 / n)
26
+ # 2^(-8/n), 2^(-8*2/n), 2^(-8*3/n), ...
27
+ slopes = torch.pow(m0, torch.arange(1, n + 1))
28
+ if n < n_heads:
29
+ m1 = 2.0 ** (-4.0 / n)
30
+ # 2^(-8/(2n)), 2^(-8*3/(2n)), 2^(-8*5/(2n)), ...
31
+ mm = torch.pow(m1, torch.arange(1, 1 + 2 * (n_heads - n), 2))
32
+ slopes = torch.cat([slopes, mm])
33
+ slopes = slopes.to(device)
34
+
35
+ tril = torch.tril(torch.ones(1, 1, block_size, block_size, device=device))
36
+
37
+ bias_rows = torch.arange(block_size, device=device).view(1, -1)
38
+ bias_cols = torch.arange(block_size, device=device).view(-1, 1)
39
+ bias = -torch.sqrt(bias_cols - bias_rows)
40
+ bias = bias.view(1, block_size, block_size) * slopes.view(-1, 1, 1)
41
+ bias = bias.masked_fill(tril == 0, float("-inf"))
42
+
43
+ return bias.type(dtype)
44
+
45
+
46
+ class StepRMSNorm(torch.nn.Module):
47
+ def __init__(self, hidden_size, eps=1e-5):
48
+ super().__init__()
49
+ self.weight = torch.nn.Parameter(torch.ones(hidden_size))
50
+ self.eps = eps
51
+
52
+ def forward(self, x: torch.Tensor):
53
+ var = x.float().pow(2).mean(-1, keepdim=True)
54
+ x = x * torch.rsqrt(var + self.eps).to(x.dtype)
55
+ x = x * self.weight
56
+ return x
57
+
58
+
59
+ class StepAttention(torch.nn.Module):
60
+ def __init__(self, hidden_size, num_heads, num_groups, layer_idx: int):
61
+ super().__init__()
62
+
63
+ self.num_heads = num_heads
64
+ self.num_groups = num_groups
65
+ self.hidden_size = hidden_size
66
+ self.head_dim = hidden_size // num_heads
67
+
68
+ self.q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
69
+ self.k_proj = torch.nn.Linear(
70
+ hidden_size, num_groups * self.head_dim, bias=False
71
+ )
72
+ self.v_proj = torch.nn.Linear(
73
+ hidden_size, num_groups * self.head_dim, bias=False
74
+ )
75
+ self.o_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
76
+
77
+ self.layer_idx = layer_idx
78
+
79
+ def flash_attn_func(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
80
+ return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
81
+ softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
82
+ return torch.ops.Optimus.fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]
83
+
84
+ def forward(
85
+ self,
86
+ x: torch.Tensor,
87
+ past_key_value: Optional[Cache] = None,
88
+ attention_mask: Optional[torch.Tensor] = None,
89
+ cache_position: Optional[torch.LongTensor] = None,
90
+ ):
91
+
92
+ q: torch.Tensor = self.q_proj(x)
93
+ k: torch.Tensor = self.k_proj(x)
94
+ v: torch.Tensor = self.v_proj(x)
95
+ if past_key_value is not None:
96
+ cache_kwargs = {"cache_position": cache_position}
97
+ k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
98
+
99
+ q = rearrange(q, "b s (h d) -> b s h d", h=self.num_heads)
100
+ k = rearrange(k, "b s (g d) -> b s g d", g=self.num_groups)
101
+ v = rearrange(v, "b s (g d) -> b s g d", g=self.num_groups)
102
+
103
+ try:
104
+ if self.head_dim not in (64, 128):
105
+ raise ValueError("head_dim must be 64 or 128")
106
+ attn_output = self.flash_attn_func(q, k, v)
107
+ attn_output = attn_output.flatten(-2, -1)
108
+ except:
109
+ k = k.repeat_interleave(self.num_heads // self.num_groups, dim=-2)
110
+ v = v.repeat_interleave(self.num_heads // self.num_groups, dim=-2)
111
+
112
+ attention_mask = build_alibi_cache(
113
+ k.size(1), self.num_heads, dtype=q.dtype, device=q.device
114
+ )[:, :, -q.size(1) :, :].contiguous()
115
+
116
+ q = q.transpose(1, 2)
117
+ k = k.transpose(1, 2)
118
+ v = v.transpose(1, 2)
119
+
120
+ attn_output: torch.Tensor = torch.nn.functional.scaled_dot_product_attention(
121
+ q, k, v, attn_mask=attention_mask
122
+ )
123
+
124
+ attn_output = attn_output.transpose(1, 2).flatten(-2, -1)
125
+
126
+ out = self.o_proj(attn_output)
127
+ return out, None # attn weights are not returned
128
+
129
+
130
+ class StepMLP(torch.nn.Module):
131
+ def __init__(self, hidden_size, intermediate_size):
132
+ super().__init__()
133
+ self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
134
+ self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
135
+ self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
136
+
137
+ def forward(self, x):
138
+ gate = self.gate_proj(x)
139
+ up = self.up_proj(x)
140
+ x = torch.nn.functional.silu(gate) * up
141
+ x = self.down_proj(x)
142
+ return x
143
+
144
+
145
+ class StepLayer(torch.nn.Module):
146
+ def __init__(self, config: Step1Config, layer_idx: int):
147
+ super().__init__()
148
+ self.layer_idx = layer_idx
149
+ self.self_attn = StepAttention(
150
+ hidden_size=config.hidden_size,
151
+ num_heads=config.num_attention_heads,
152
+ num_groups=config.num_attention_groups,
153
+ layer_idx=layer_idx,
154
+ )
155
+ self.mlp = StepMLP(
156
+ hidden_size=config.hidden_size,
157
+ intermediate_size=config.intermediate_size,
158
+ )
159
+ self.input_layernorm = StepRMSNorm(
160
+ hidden_size=config.hidden_size, eps=config.rms_norm_eps
161
+ )
162
+ self.post_attention_layernorm = StepRMSNorm(
163
+ hidden_size=config.hidden_size, eps=config.rms_norm_eps
164
+ )
165
+
166
+ def forward(
167
+ self,
168
+ hidden_states: torch.Tensor,
169
+ attention_mask: Optional[torch.Tensor] = None,
170
+ past_key_value: Optional[Cache] = None,
171
+ output_attentions: Optional[bool] = False,
172
+ cache_position: Optional[torch.LongTensor] = None,
173
+ ):
174
+ residual = hidden_states
175
+ hidden_states = self.input_layernorm(hidden_states)
176
+ hidden_states, self_attn_weights = self.self_attn(hidden_states, past_key_value, attention_mask, cache_position)
177
+ hidden_states = residual + hidden_states
178
+
179
+ residual = hidden_states
180
+ hidden_states = self.post_attention_layernorm(hidden_states)
181
+ hidden_states = self.mlp(hidden_states)
182
+ hidden_states = residual + hidden_states
183
+
184
+ outputs = (hidden_states, )
185
+ if output_attentions:
186
+ outputs += (self_attn_weights,)
187
+ return outputs
188
+
189
+
190
+ class StepPreTrainedModel(PreTrainedModel):
191
+ config_class = Step1Config
192
+ base_model_prefix = "model"
193
+ supports_gradient_checkpointing = True
194
+ _no_split_modules = ["StepLayer"]
195
+ _skip_keys_device_placement = ["past_key_values"]
196
+ _supports_cache_class = True
197
+ _supports_static_cache = True
198
+
199
+ def _init_weights(self, module):
200
+ std = self.config.initializer_range
201
+ if isinstance(module, nn.Linear):
202
+ module.weight.data.normal_(mean=0.0, std=std)
203
+ if module.bias is not None:
204
+ module.bias.data.zero_()
205
+ elif isinstance(module, nn.Embedding):
206
+ module.weight.data.normal_(mean=0.0, std=std)
207
+ if module.padding_idx is not None:
208
+ module.weight.data[module.padding_idx].zero_()
209
+
210
+
211
+ class Step1Model(StepPreTrainedModel):
212
+ """
213
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
214
+
215
+ Args:
216
+ config: Step1Config
217
+ """
218
+
219
+ def __init__(self, config: Step1Config):
220
+ super().__init__(config)
221
+ self.config = config
222
+ self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size)
223
+
224
+ self.layers = torch.nn.Sequential(
225
+ *[
226
+ StepLayer(config, layer_idx)
227
+ for layer_idx in range(config.num_hidden_layers)
228
+ ]
229
+ )
230
+
231
+ self.norm = StepRMSNorm(
232
+ hidden_size=config.hidden_size, eps=config.rms_norm_eps
233
+ )
234
+
235
+ # Initialize weights and apply final processing
236
+ self.post_init()
237
+
238
+ def get_input_embeddings(self):
239
+ return self.embed_tokens
240
+
241
+ def set_input_embeddings(self, value):
242
+ self.embed_tokens = value
243
+
244
+ def forward(
245
+ self,
246
+ input_ids: torch.LongTensor = None,
247
+ attention_mask: Optional[torch.Tensor] = None,
248
+ past_key_values: Optional[Cache] = None,
249
+ inputs_embeds: Optional[torch.FloatTensor] = None,
250
+ use_cache: Optional[bool] = None,
251
+ output_attentions: Optional[bool] = None,
252
+ output_hidden_states: Optional[bool] = None,
253
+ return_dict: Optional[bool] = None,
254
+ cache_position: Optional[torch.LongTensor] = None,
255
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
256
+ output_attentions = (
257
+ output_attentions
258
+ if output_attentions is not None
259
+ else self.config.output_attentions
260
+ )
261
+ output_hidden_states = (
262
+ output_hidden_states
263
+ if output_hidden_states is not None
264
+ else self.config.output_hidden_states
265
+ )
266
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
267
+ return_dict = (
268
+ return_dict if return_dict is not None else self.config.use_return_dict
269
+ )
270
+
271
+ if (input_ids is None) ^ (inputs_embeds is not None):
272
+ raise ValueError(
273
+ "You must specify exactly one of input_ids or inputs_embeds"
274
+ )
275
+
276
+ if inputs_embeds is None:
277
+ inputs_embeds = self.embed_tokens(input_ids)
278
+
279
+ if use_cache and past_key_values is None:
280
+ past_key_values = DynamicCache()
281
+
282
+ if cache_position is None:
283
+ past_seen_tokens = (
284
+ past_key_values.get_seq_length() if past_key_values is not None else 0
285
+ )
286
+ cache_position = torch.arange(
287
+ past_seen_tokens,
288
+ past_seen_tokens + inputs_embeds.shape[1],
289
+ device=inputs_embeds.device,
290
+ )
291
+
292
+ causal_mask = attention_mask
293
+
294
+ hidden_states = inputs_embeds
295
+
296
+ # decoder layers
297
+ all_hidden_states = () if output_hidden_states else None
298
+ all_self_attns = () if output_attentions else None
299
+
300
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
301
+ if output_hidden_states:
302
+ all_hidden_states += (hidden_states,)
303
+
304
+ layer_outputs = decoder_layer(
305
+ hidden_states,
306
+ attention_mask=causal_mask,
307
+ past_key_value=past_key_values,
308
+ cache_position=cache_position,
309
+ output_attentions=output_attentions,
310
+ )
311
+
312
+ hidden_states = layer_outputs[0]
313
+
314
+ if output_attentions:
315
+ all_self_attns += (layer_outputs[1],)
316
+
317
+ hidden_states = self.norm(hidden_states)
318
+
319
+ # add hidden states from the last decoder layer
320
+ if output_hidden_states:
321
+ all_hidden_states += (hidden_states,)
322
+
323
+ output = BaseModelOutputWithPast(
324
+ last_hidden_state=hidden_states,
325
+ past_key_values=past_key_values if use_cache else None,
326
+ hidden_states=all_hidden_states,
327
+ attentions=None,
328
+ )
329
+ return output if return_dict else output.to_tuple()
330
+
331
+
332
+ class Step1ForCausalLM(StepPreTrainedModel, GenerationMixin):
333
+ _tied_weights_keys = ["lm_head.weight"]
334
+
335
+ def __init__(self, config):
336
+ super().__init__(config)
337
+ self.model = Step1Model(config)
338
+ self.vocab_size = config.vocab_size
339
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
340
+
341
+ # Initialize weights and apply final processing
342
+ self.post_init()
343
+
344
+ def get_input_embeddings(self):
345
+ return self.model.embed_tokens
346
+
347
+ def set_input_embeddings(self, value):
348
+ self.model.embed_tokens = value
349
+
350
+ def set_decoder(self, decoder):
351
+ self.model = decoder
352
+
353
+ def get_decoder(self):
354
+ return self.model
355
+
356
+ def forward(
357
+ self,
358
+ input_ids: torch.LongTensor = None,
359
+ attention_mask: Optional[torch.Tensor] = None,
360
+ position_ids: Optional[torch.LongTensor] = None,
361
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
362
+ inputs_embeds: Optional[torch.FloatTensor] = None,
363
+ labels: Optional[torch.LongTensor] = None,
364
+ use_cache: Optional[bool] = None,
365
+ output_attentions: Optional[bool] = None,
366
+ output_hidden_states: Optional[bool] = None,
367
+ return_dict: Optional[bool] = None,
368
+ cache_position: Optional[torch.LongTensor] = None,
369
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
370
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
371
+ output_hidden_states = (
372
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
373
+ )
374
+ return_dict = (
375
+ return_dict if return_dict is not None else self.config.use_return_dict
376
+ )
377
+
378
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
379
+ outputs = self.model(
380
+ input_ids=input_ids,
381
+ attention_mask=attention_mask,
382
+ past_key_values=past_key_values,
383
+ inputs_embeds=inputs_embeds,
384
+ use_cache=use_cache,
385
+ output_attentions=output_attentions,
386
+ output_hidden_states=output_hidden_states,
387
+ return_dict=return_dict,
388
+ cache_position=cache_position,
389
+ )
390
+
391
+ hidden_states = outputs[0]
392
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
393
+
394
+ logits = self.lm_head(hidden_states)
395
+
396
+ loss = None
397
+ if labels is not None:
398
+ loss = self.loss_function(
399
+ logits=logits,
400
+ labels=labels,
401
+ vocab_size=self.config.vocab_size,
402
+ )
403
+
404
+ if not return_dict:
405
+ output = (logits,) + outputs[1:]
406
+ return (loss,) + output if loss is not None else output
407
+
408
+ return CausalLMOutputWithPast(
409
+ loss=loss,
410
+ logits=logits,
411
+ past_key_values=outputs.past_key_values,
412
+ hidden_states=outputs.hidden_states,
413
+ attentions=outputs.attentions,
414
+ )
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25e122d9205d035033a9994c4d46a6a1b467a938654e4178fc0e5f4f5d610674
3
+ size 1264044
tokenizer_config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "clean_up_tokenization_spaces": false,
4
+ "eos_token": "</s>",
5
+ "legacy": false,
6
+ "model_max_length": 65536,
7
+ "pad_token": "<unk>",
8
+ "padding_side": "left",
9
+ "sp_model_kwargs": {},
10
+ "tokenizer_class": "LlamaTokenizer",
11
+ "unk_token": "<unk>",
12
+ "use_default_system_prompt": false,
13
+ "chat_template": "{% if messages[0]['role'] == 'system' %}{{ '<s>' }}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}{% set role = 'human' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<|BOT|> ' + role + '\\n' }}{{ message['content'] }}{% if not loop.last or message['role'] != 'assistant' %}{{ '<|EOT|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|BOT|> assistant\\n' }}{% endif %}"
14
+ }
15
+