Commit ·
02d418e
0
Parent(s):
Duplicate from stepfun-ai/Step-Audio-EditX
Browse filesCo-authored-by: chao yan <yanchaomars@users.noreply.huggingface.co>
- .gitattributes +38 -0
- CosyVoice-300M-25Hz/FLOW_VERSION +2 -0
- CosyVoice-300M-25Hz/campplus.onnx +3 -0
- CosyVoice-300M-25Hz/cosyvoice.yaml +72 -0
- CosyVoice-300M-25Hz/flow.pt +3 -0
- CosyVoice-300M-25Hz/hift.pt +3 -0
- CosyVoice-300M-25Hz/speech_tokenizer_v1.onnx +3 -0
- README.md +911 -0
- assets/architechture.png +3 -0
- assets/emotion-eval.png +3 -0
- assets/logo.png +0 -0
- assets/test.wav +3 -0
- config.json +24 -0
- configuration_step1.py +41 -0
- model-00001.safetensors +3 -0
- model.safetensors.index.json +1 -0
- modeling_step1.py +414 -0
- tokenizer.model +3 -0
- tokenizer_config.json +15 -0
.gitattributes
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/architechture.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/emotion-eval.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/test.wav filter=lfs diff=lfs merge=lfs -text
|
CosyVoice-300M-25Hz/FLOW_VERSION
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/mnt/wby-jfs/models/train/flow_matching/flow_v2_1node_vq0206_dit_v8_fullattn_exp0227_sft_exp0408_stepaudio_sft_exp0616/model_epoch_5_whole.pt
|
| 2 |
+
fae53942e60310eb172b170396202069
|
CosyVoice-300M-25Hz/campplus.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a6ac6a63997761ae2997373e2ee1c47040854b4b759ea41ec48e4e42df0f4d73
|
| 3 |
+
size 28303423
|
CosyVoice-300M-25Hz/cosyvoice.yaml
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mel_conf:
|
| 2 |
+
num_mels: 80
|
| 3 |
+
n_fft: 1920
|
| 4 |
+
hop_size: 480
|
| 5 |
+
win_size: 1920
|
| 6 |
+
sampling_rate: 24000
|
| 7 |
+
fmin: 0
|
| 8 |
+
fmax: 8000
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
flow: !new:stepvocoder.cosyvoice2.flow.flow.CausalMaskedDiffWithXvec
|
| 12 |
+
input_size: 512
|
| 13 |
+
output_size: 80
|
| 14 |
+
spk_embed_dim: 192
|
| 15 |
+
output_type: 'mel'
|
| 16 |
+
vocab_size: 5121 # 1024(vq02) + 4096(vq06) + 1(vq02-pad)
|
| 17 |
+
input_embedding: !new:stepvocoder.cosyvoice2.embedding.dual_codebook.DualCodebookEmbedding
|
| 18 |
+
vocab_size: 5121 # 1024(vq02) + 4096(vq06) + 1(vq02-pad)
|
| 19 |
+
input_size: 512
|
| 20 |
+
encoder: !new:stepvocoder.cosyvoice2.transformer.upsample_encoder_v2.UpsampleConformerEncoderV2
|
| 21 |
+
input_size: 512
|
| 22 |
+
output_size: 512
|
| 23 |
+
input_layer: 'linear'
|
| 24 |
+
pre_lookahead_len: 3
|
| 25 |
+
num_blocks: 6
|
| 26 |
+
num_up_blocks: 4
|
| 27 |
+
up_stride: 2
|
| 28 |
+
up_scale_factor: 2
|
| 29 |
+
attention_heads: 8
|
| 30 |
+
pos_enc_layer_type: 'rel_pos_espnet'
|
| 31 |
+
selfattention_layer_type: 'rel_selfattn'
|
| 32 |
+
key_bias: true
|
| 33 |
+
linear_units: 2048
|
| 34 |
+
dropout_rate: 0.1
|
| 35 |
+
positional_dropout_rate: 0.1
|
| 36 |
+
attention_dropout_rate: 0.1
|
| 37 |
+
normalize_before: True
|
| 38 |
+
decoder: !new:stepvocoder.cosyvoice2.flow.flow_matching.CausalConditionalCFM
|
| 39 |
+
inference_cfg_rate: 0.7
|
| 40 |
+
estimator: !new:stepvocoder.cosyvoice2.flow.decoder_dit.DiT
|
| 41 |
+
in_channels: 320
|
| 42 |
+
out_channels: 80
|
| 43 |
+
mlp_ratio: 4.0
|
| 44 |
+
depth: 16
|
| 45 |
+
num_heads: 8
|
| 46 |
+
head_dim: 64
|
| 47 |
+
hidden_size: 512
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
hift: !new:stepvocoder.cosyvoice2.hifigan.generator.HiFTGenerator
|
| 51 |
+
in_channels: 80
|
| 52 |
+
base_channels: 512
|
| 53 |
+
nb_harmonics: 8
|
| 54 |
+
sampling_rate: 24000
|
| 55 |
+
nsf_alpha: 0.1
|
| 56 |
+
nsf_sigma: 0.003
|
| 57 |
+
nsf_voiced_threshold: 10
|
| 58 |
+
upsample_rates: [8, 5, 3]
|
| 59 |
+
upsample_kernel_sizes: [16, 11, 7]
|
| 60 |
+
istft_params:
|
| 61 |
+
n_fft: 16
|
| 62 |
+
hop_len: 4
|
| 63 |
+
resblock_kernel_sizes: [3, 7, 11]
|
| 64 |
+
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
| 65 |
+
source_resblock_kernel_sizes: [7, 7, 11]
|
| 66 |
+
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
| 67 |
+
lrelu_slope: 0.1
|
| 68 |
+
audio_limit: 0.99
|
| 69 |
+
f0_predictor: !new:stepvocoder.cosyvoice2.hifigan.f0_predictor.ConvRNNF0Predictor
|
| 70 |
+
num_class: 1
|
| 71 |
+
in_channels: 80
|
| 72 |
+
cond_channels: 512
|
CosyVoice-300M-25Hz/flow.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37f18fcb9c374bb8d8ae229e2f7618b6effaa208609bd0407fc661234125531c
|
| 3 |
+
size 615269316
|
CosyVoice-300M-25Hz/hift.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3386cc880324d4e98e05987b99107f49e40ed925b8ecc87c1f4939432d429879
|
| 3 |
+
size 83390254
|
CosyVoice-300M-25Hz/speech_tokenizer_v1.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486
|
| 3 |
+
size 522625011
|
README.md
ADDED
|
@@ -0,0 +1,911 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Step-Audio-EditX
|
| 2 |
+
<p align="center">
|
| 3 |
+
<img src="assets/logo.png" height=100>
|
| 4 |
+
</p>
|
| 5 |
+
|
| 6 |
+
<div align="center">
|
| 7 |
+
<a href="https://stepaudiollm.github.io/step-audio-editx/"><img src="https://img.shields.io/static/v1?label=Demo%20Page&message=Web&color=green"></a>  
|
| 8 |
+
<a href="https://arxiv.org/abs/2511.03601"><img src="https://img.shields.io/static/v1?label=Tech%20Report&message=Arxiv&color=red"></a>  
|
| 9 |
+
<a href="https://huggingface.co/stepfun-ai/Step-Audio-EditX"><img src="https://img.shields.io/static/v1?label=Step-Audio-EditX&message=HuggingFace&color=yellow"></a>  
|
| 10 |
+
<a href="https://modelscope.cn/models/stepfun-ai/Step-Audio-EditX"><img src="https://img.shields.io/static/v1?label=Step-Audio-EditX&message=ModelScope&color=blue"></a>  
|
| 11 |
+
<a href="https://huggingface.co/spaces/stepfun-ai/Step-Audio-EditX"><img src="https://img.shields.io/static/v1?label=Space%20Playground&message=HuggingFace&color=yellow"></a>  
|
| 12 |
+
</div>
|
| 13 |
+
|
| 14 |
+
## 🔥🔥🔥 News!!!
|
| 15 |
+
* Jan 23, 2026: 🌟 Training and inference for vLLM are now supported. Thanks to the vLLM team!
|
| 16 |
+
* Jan 23, 2026: 💻 We release the GRPO training code.
|
| 17 |
+
* Jan 23, 2026: 🧩 New Model Release: Now supporting more paralinguistic tags.
|
| 18 |
+
* Nov 28, 2025: 🚀 New Model Release: Now supporting **`Japanese`** and **`Korean`** languages.
|
| 19 |
+
* Nov 23, 2025: 📊 [Step-Audio-Edit-Benchmark](https://github.com/stepfun-ai/Step-Audio-Edit-Benchmark) Released!
|
| 20 |
+
* Nov 19, 2025: ⚙️ We release a **new version** of our model, which **supports polyphonic pronunciation control** and improves the performance of emotion, speaking style, and paralinguistic editing.
|
| 21 |
+
* Nov 12, 2025: 📦 We release the **optimized inference code** and **model weights** of **Step-Audio-EditX** ([HuggingFace](https://huggingface.co/stepfun-ai/Step-Audio-EditX); [ModelScope](https://modelscope.cn/models/stepfun-ai/Step-Audio-EditX)) and **Step-Audio-Tokenizer**([HuggingFace](https://huggingface.co/stepfun-ai/Step-Audio-Tokenizer); [ModelScope](https://modelscope.cn/models/stepfun-ai/Step-Audio-Tokenizer))
|
| 22 |
+
* Nov 07, 2025: ✨ [Demo Page](https://stepaudiollm.github.io/step-audio-editx/) ; 🎮 [HF Space Playground](https://huggingface.co/spaces/stepfun-ai/Step-Audio-EditX)
|
| 23 |
+
* Nov 06, 2025: 👋 We release the technical report of [Step-Audio-EditX](https://arxiv.org/abs/2511.03601).
|
| 24 |
+
|
| 25 |
+
## Introduction
|
| 26 |
+
We are open-sourcing Step-Audio-EditX, a powerful **3B-parameter** LLM-based **Reinforcement Learning** audio model specialized in expressive and iterative audio editing. It excels at editing emotion, speaking style, and paralinguistics, and also features robust zero-shot text-to-speech (TTS) capabilities.
|
| 27 |
+
|
| 28 |
+
## 📑 Open-source Plan
|
| 29 |
+
- [x] Inference Code
|
| 30 |
+
- [x] Online demo (Gradio)
|
| 31 |
+
- [x] Step-Audio-Edit-Benchmark
|
| 32 |
+
- [x] Model Checkpoints
|
| 33 |
+
- [x] Step-Audio-Tokenizer
|
| 34 |
+
- [x] Step-Audio-EditX
|
| 35 |
+
- [x] Step-Audio-EditX-Int4
|
| 36 |
+
- [ ] Training Code
|
| 37 |
+
- [x] GRPO training
|
| 38 |
+
- [ ] SFT training
|
| 39 |
+
- [ ] PPO training
|
| 40 |
+
- [ ] ⏳ Feature Support Plan
|
| 41 |
+
- [ ] Editing
|
| 42 |
+
- [x] Polyphone pronunciation control
|
| 43 |
+
- [x] More paralinguistic tags ([Cough, Crying, Stress, etc.])
|
| 44 |
+
- [ ] Filler word removal
|
| 45 |
+
- [ ] Other Languages
|
| 46 |
+
- [x] Japanese, Korean
|
| 47 |
+
- [ ] Arabic, French, Russian, Spanish, etc.
|
| 48 |
+
|
| 49 |
+
## Online demonstration
|
| 50 |
+
|
| 51 |
+
### StepFun Audio Studio
|
| 52 |
+
|
| 53 |
+
- Both Step-Audio-EditX are available in our [StepFun Audio Studio](https://www.stepfun.com/studio/audio).
|
| 54 |
+
- You will need an API key from the [StepFun Open Platform](https://platform.stepfun.com/).
|
| 55 |
+
|
| 56 |
+
## WeChat group
|
| 57 |
+
|
| 58 |
+
You can scan the following QR code to join our WeChat group for communication and discussion.
|
| 59 |
+
<div align="center">
|
| 60 |
+
<img src="https://cdn-uploads.huggingface.co/production/uploads/66518fd07d8cb2629a514c18/DRdnp1SN-yxhlNOfy26mE.jpeg" width="200" alt="QR code">
|
| 61 |
+
</div>
|
| 62 |
+
|
| 63 |
+
## Features
|
| 64 |
+
- **Zero-Shot TTS**
|
| 65 |
+
- Excellent zero-shot TTS cloning for Mandarin, English, Sichuanese, and Cantonese.
|
| 66 |
+
- To use dialect or other languages, just add a **`[Sichuanese]`** / **`[Cantonese]`** / **`[Japanese]`** / **`[Korean]`** tag before your text.
|
| 67 |
+
- 🔥 Polyphone pronunciation control, all you need to do is replace the polyphonic characters with pinyin.
|
| 68 |
+
- **[我也想过过过儿过过的生活]** -> **[我也想guo4guo4guo1儿guo4guo4的生活]**
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
- **Emotion and Speaking Style Editing**
|
| 72 |
+
- Remarkably effective iterative control over emotions and styles, supporting **dozens** of options for editing.
|
| 73 |
+
- Emotion Editing : [ *Angry*, *Happy*, *Sad*, *Excited*, *Fearful*, *Surprised*, *Disgusted*, etc. ]
|
| 74 |
+
- Speaking Style Editing: [ *Act_coy*, *Older*, *Child*, *Whisper*, *Serious*, *Generous*, *Exaggerated*, etc.]
|
| 75 |
+
- Editing with more emotion and more speaking styles is on the way. **Get Ready!** 🚀
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
- **Paralinguistic Editing**
|
| 79 |
+
- Precise control over 10 types of paralinguistic features for more natural, human-like, and expressive synthetic audio.
|
| 80 |
+
- Supporting Tags:
|
| 81 |
+
- [ *Breathing*, *Laughter*, *Surprise-oh*, *Confirmation-en*, *Uhm*, *Surprise-ah*, *Surprise-wa*, *Sigh*, *Question-ei*, *Dissatisfaction-hnn* ]
|
| 82 |
+
|
| 83 |
+
- **Available Tags**
|
| 84 |
+
<table>
|
| 85 |
+
<tr>
|
| 86 |
+
<td rowspan="8" style="vertical-align: middle; text-align:center;" align="center">emotion</td>
|
| 87 |
+
<td align="center"><b>happy</b></td>
|
| 88 |
+
<td align="center">Expressing happiness</td>
|
| 89 |
+
<td align="center"><b>angry</b></td>
|
| 90 |
+
<td align="center">Expressing anger</td>
|
| 91 |
+
</tr>
|
| 92 |
+
<tr>
|
| 93 |
+
<td align="center"><b>sad</b></td>
|
| 94 |
+
<td align="center">Expressing sadness</td>
|
| 95 |
+
<td align="center"><b>fear</b></td>
|
| 96 |
+
<td align="center">Expressing fear</td>
|
| 97 |
+
</tr>
|
| 98 |
+
<tr>
|
| 99 |
+
<td align="center"><b>surprised</b></td>
|
| 100 |
+
<td align="center">Expressing surprise</td>
|
| 101 |
+
<td align="center"><b>confusion</b></td>
|
| 102 |
+
<td align="center">Expressing confusion</td>
|
| 103 |
+
</tr>
|
| 104 |
+
<tr>
|
| 105 |
+
<td align="center"><b>empathy</b></td>
|
| 106 |
+
<td align="center">Expressing empathy and understanding</td>
|
| 107 |
+
<td align="center"><b>embarrass</b></td>
|
| 108 |
+
<td align="center">Expressing embarrassment</td>
|
| 109 |
+
</tr>
|
| 110 |
+
<tr>
|
| 111 |
+
<td align="center"><b>excited</b></td>
|
| 112 |
+
<td align="center">Expressing excitement and enthusiasm</td>
|
| 113 |
+
<td align="center"><b>depressed</b></td>
|
| 114 |
+
<td align="center">Expressing a depressed or discouraged mood</td>
|
| 115 |
+
</tr>
|
| 116 |
+
<tr>
|
| 117 |
+
<td align="center"><b>admiration</b></td>
|
| 118 |
+
<td align="center">Expressing admiration or respect</td>
|
| 119 |
+
<td align="center"><b>coldness</b></td>
|
| 120 |
+
<td align="center">Expressing coldness and indifference</td>
|
| 121 |
+
</tr>
|
| 122 |
+
<tr>
|
| 123 |
+
<td align="center"><b>disgusted</b></td>
|
| 124 |
+
<td align="center">Expressing disgust or aversion</td>
|
| 125 |
+
<td align="center"><b>humour</b></td>
|
| 126 |
+
<td align="center">Expressing humor or playfulness</td>
|
| 127 |
+
</tr>
|
| 128 |
+
<tr>
|
| 129 |
+
</tr>
|
| 130 |
+
<tr>
|
| 131 |
+
<td rowspan="17" style="vertical-align: middle; text-align:center;" align="center">speaking style</td>
|
| 132 |
+
<td align="center"><b>serious</b></td>
|
| 133 |
+
<td align="center">Speaking in a serious or solemn manner</td>
|
| 134 |
+
<td align="center"><b>arrogant</b></td>
|
| 135 |
+
<td align="center">Speaking in an arrogant manner</td>
|
| 136 |
+
</tr>
|
| 137 |
+
<tr>
|
| 138 |
+
<td align="center"><b>child</b></td>
|
| 139 |
+
<td align="center">Speaking in a childlike manner</td>
|
| 140 |
+
<td align="center"><b>older</b></td>
|
| 141 |
+
<td align="center">Speaking in an elderly-sounding manner</td>
|
| 142 |
+
</tr>
|
| 143 |
+
<tr>
|
| 144 |
+
<td align="center"><b>girl</b></td>
|
| 145 |
+
<td align="center">Speaking in a light, youthful feminine manner</td>
|
| 146 |
+
<td align="center"><b>pure</b></td>
|
| 147 |
+
<td align="center">Speaking in a pure, innocent manner</td>
|
| 148 |
+
</tr>
|
| 149 |
+
<tr>
|
| 150 |
+
<td align="center"><b>sister</b></td>
|
| 151 |
+
<td align="center">Speaking in a mature, confident feminine manner</td>
|
| 152 |
+
<td align="center"><b>sweet</b></td>
|
| 153 |
+
<td align="center">Speaking in a sweet, lovely manner</td>
|
| 154 |
+
</tr>
|
| 155 |
+
<tr>
|
| 156 |
+
<td align="center"><b>exaggerated</b></td>
|
| 157 |
+
<td align="center">Speaking in an exaggerated, dramatic manner</td>
|
| 158 |
+
<td align="center"><b>ethereal</b></td>
|
| 159 |
+
<td align="center">Speaking in a soft, airy, dreamy manner</td>
|
| 160 |
+
</tr>
|
| 161 |
+
<tr>
|
| 162 |
+
<td align="center"><b>whisper</b></td>
|
| 163 |
+
<td align="center">Speaking in a whispering, very soft manner</td>
|
| 164 |
+
<td align="center"><b>generous</b></td>
|
| 165 |
+
<td align="center">Speaking in a hearty, outgoing, and straight-talking manner</td>
|
| 166 |
+
</tr>
|
| 167 |
+
<tr>
|
| 168 |
+
<td align="center"><b>recite</b></td>
|
| 169 |
+
<td align="center">Speaking in a clear, well-paced, poetry-reading manner</td>
|
| 170 |
+
<td align="center"><b>act_coy</b></td>
|
| 171 |
+
<td align="center">Speaking in a sweet, playful, and endearing manner</td>
|
| 172 |
+
</tr>
|
| 173 |
+
<tr>
|
| 174 |
+
<td align="center"><b>warm</b></td>
|
| 175 |
+
<td align="center">Speaking in a warm, friendly manner</td>
|
| 176 |
+
<td align="center"><b>shy</b></td>
|
| 177 |
+
<td align="center">Speaking in a shy, timid manner</td>
|
| 178 |
+
</tr>
|
| 179 |
+
<tr>
|
| 180 |
+
<td align="center"><b>comfort</b></td>
|
| 181 |
+
<td align="center">Speaking in a comforting, reassuring manner</td>
|
| 182 |
+
<td align="center"><b>authority</b></td>
|
| 183 |
+
<td align="center">Speaking in an authoritative, commanding manner</td>
|
| 184 |
+
</tr>
|
| 185 |
+
<tr>
|
| 186 |
+
<td align="center"><b>chat</b></td>
|
| 187 |
+
<td align="center">Speaking in a casual, conversational manner</td>
|
| 188 |
+
<td align="center"><b>radio</b></td>
|
| 189 |
+
<td align="center">Speaking in a radio-broadcast manner</td>
|
| 190 |
+
</tr>
|
| 191 |
+
<tr>
|
| 192 |
+
<td align="center"><b>soulful</b></td>
|
| 193 |
+
<td align="center">Speaking in a heartfelt, deeply emotional manner</td>
|
| 194 |
+
<td align="center"><b>gentle</b></td>
|
| 195 |
+
<td align="center">Speaking in a gentle, soft manner</td>
|
| 196 |
+
</tr>
|
| 197 |
+
<tr>
|
| 198 |
+
<td align="center"><b>story</b></td>
|
| 199 |
+
<td align="center">Speaking in a narrative, audiobook-style manner</td>
|
| 200 |
+
<td align="center"><b>vivid</b></td>
|
| 201 |
+
<td align="center">Speaking in a lively, expressive manner</td>
|
| 202 |
+
</tr>
|
| 203 |
+
<tr>
|
| 204 |
+
<td align="center"><b>program</b></td>
|
| 205 |
+
<td align="center">Speaking in a show-host/presenter manner</td>
|
| 206 |
+
<td align="center"><b>news</b></td>
|
| 207 |
+
<td align="center">Speaking in a news broadcasting manner</td>
|
| 208 |
+
</tr>
|
| 209 |
+
<tr>
|
| 210 |
+
<td align="center"><b>advertising</b></td>
|
| 211 |
+
<td align="center">Speaking in a polished, high-end commercial voiceover manner</td>
|
| 212 |
+
<td align="center"><b>roar</b></td>
|
| 213 |
+
<td align="center">Speaking in a loud, deep, roaring manner</td>
|
| 214 |
+
</tr>
|
| 215 |
+
<tr>
|
| 216 |
+
<td align="center"><b>murmur</b></td>
|
| 217 |
+
<td align="center">Speaking in a quiet, low manner</td>
|
| 218 |
+
<td align="center"><b>shout</b></td>
|
| 219 |
+
<td align="center">Speaking in a loud, sharp, shouting manner</td>
|
| 220 |
+
</tr>
|
| 221 |
+
<tr>
|
| 222 |
+
<td align="center"><b>deeply</b></td>
|
| 223 |
+
<td align="center">Speaking in a deep and low-pitched tone</td>
|
| 224 |
+
<td align="center"><b>loudly</b></td>
|
| 225 |
+
<td align="center">Speaking in a loud and high-pitched tone</td>
|
| 226 |
+
</tr>
|
| 227 |
+
<tr>
|
| 228 |
+
</tr>
|
| 229 |
+
<tr>
|
| 230 |
+
</tr>
|
| 231 |
+
<tr>
|
| 232 |
+
<td rowspan="11" style="vertical-align: middle; text-align:center;" align="center">paralinguistic</td>
|
| 233 |
+
<td align="center"><b>[sigh]</b></td>
|
| 234 |
+
<td align="center">Sighing sound</td>
|
| 235 |
+
<td align="center"><b>[inhale]</b></td>
|
| 236 |
+
<td align="center">Inhaling sound</td>
|
| 237 |
+
</tr>
|
| 238 |
+
|
| 239 |
+
<tr>
|
| 240 |
+
<td align="center"><b>[laugh]</b></td>
|
| 241 |
+
<td align="center">Laughter sound</td>
|
| 242 |
+
<td align="center"><b>[chuckle]</b></td>
|
| 243 |
+
<td align="center">Chuckling sound</td>
|
| 244 |
+
</tr>
|
| 245 |
+
|
| 246 |
+
<tr>
|
| 247 |
+
<td align="center"><b>[exhale]</b></td>
|
| 248 |
+
<td align="center">Exhaling sound</td>
|
| 249 |
+
<td align="center"><b>[clears throat]</b></td>
|
| 250 |
+
<td align="center">Throat clearing sound</td>
|
| 251 |
+
</tr>
|
| 252 |
+
|
| 253 |
+
<tr>
|
| 254 |
+
<td align="center"><b>[snort]</b></td>
|
| 255 |
+
<td align="center">Snorting sound</td>
|
| 256 |
+
<td align="center"><b>[giggle]</b></td>
|
| 257 |
+
<td align="center">Giggling sound</td>
|
| 258 |
+
</tr>
|
| 259 |
+
|
| 260 |
+
<tr>
|
| 261 |
+
<td align="center"><b>[cough]</b></td>
|
| 262 |
+
<td align="center">Coughing sound</td>
|
| 263 |
+
<td align="center"><b>[breath]</b></td>
|
| 264 |
+
<td align="center">Breathing sound</td>
|
| 265 |
+
</tr>
|
| 266 |
+
|
| 267 |
+
<tr>
|
| 268 |
+
<td align="center"><b>[uhm]</b></td>
|
| 269 |
+
<td align="center">Hesitation sound: "Uhm"</td>
|
| 270 |
+
<td align="center"><b>[Confirmation-en]</b></td>
|
| 271 |
+
<td align="center">Confirming: "En"</td>
|
| 272 |
+
</tr>
|
| 273 |
+
|
| 274 |
+
<tr>
|
| 275 |
+
<td align="center"><b>[Surprise-oh]</b></td>
|
| 276 |
+
<td align="center">Expressing surprise: "Oh"</td>
|
| 277 |
+
<td align="center"><b>[Surprise-ah]</b></td>
|
| 278 |
+
<td align="center">Expressing surprise: "Ah"</td>
|
| 279 |
+
</tr>
|
| 280 |
+
|
| 281 |
+
<tr>
|
| 282 |
+
<td align="center"><b>[Surprise-wa]</b></td>
|
| 283 |
+
<td align="center">Expressing surprise: "Wa"</td>
|
| 284 |
+
<td align="center"><b>[Surprise-yo]</b></td>
|
| 285 |
+
<td align="center">Expressing surprise: "Yo"</td>
|
| 286 |
+
</tr>
|
| 287 |
+
|
| 288 |
+
<tr>
|
| 289 |
+
<td align="center"><b>[Dissatisfaction-hnn]</b></td>
|
| 290 |
+
<td align="center">Dissatisfied sound: "Hnn"</td>
|
| 291 |
+
<td align="center"><b>[Question-ei]</b></td>
|
| 292 |
+
<td align="center">Questioning: "Ei"</td>
|
| 293 |
+
</tr>
|
| 294 |
+
|
| 295 |
+
<tr>
|
| 296 |
+
<td align="center"><b>[Question-ah]</b></td>
|
| 297 |
+
<td align="center">Questioning: "Ah"</td>
|
| 298 |
+
<td align="center"><b>[Question-en]</b></td>
|
| 299 |
+
<td align="center">Questioning: "En"</td>
|
| 300 |
+
</tr>
|
| 301 |
+
|
| 302 |
+
<tr>
|
| 303 |
+
<td align="center"><b>[Question-yi]</b></td>
|
| 304 |
+
<td align="center">Questioning: "Yi"</td>
|
| 305 |
+
<td align="center"><b>[Question-oh]</b></td>
|
| 306 |
+
<td align="center">Questioning: "Oh"</td>
|
| 307 |
+
</tr>
|
| 308 |
+
</table>
|
| 309 |
+
|
| 310 |
+
## Feature Requests & Wishlist
|
| 311 |
+
💡 We welcome all ideas for new features! If you'd like to see a feature added to the project, please start a discussion in our [Discussions](https://github.com/stepfun-ai/Step-Audio-EditX/discussions) section.
|
| 312 |
+
|
| 313 |
+
We'll be collecting community feedback here and will incorporate popular suggestions into our future development plans. Thank you for your contribution!
|
| 314 |
+
|
| 315 |
+
## Demos
|
| 316 |
+
|
| 317 |
+
<table>
|
| 318 |
+
<tr>
|
| 319 |
+
<th style="vertical-align : middle;text-align: center">Task</th>
|
| 320 |
+
<th style="vertical-align : middle;text-align: center">Text</th>
|
| 321 |
+
<th style="vertical-align : middle;text-align: center">Source</th>
|
| 322 |
+
<th style="vertical-align : middle;text-align: center">Edited</th>
|
| 323 |
+
</tr>
|
| 324 |
+
|
| 325 |
+
<tr>
|
| 326 |
+
<td align="center"> Emotion-Fear</td>
|
| 327 |
+
<td align="center"> 我总觉得,有人在跟着我,我能听到奇怪的脚步声。</td>
|
| 328 |
+
<td align="center">
|
| 329 |
+
|
| 330 |
+
[fear_zh_female_prompt.webm](https://github.com/user-attachments/assets/a088c059-032c-423f-81d6-3816ba347ff5)
|
| 331 |
+
</td>
|
| 332 |
+
<td align="center">
|
| 333 |
+
|
| 334 |
+
[fear_zh_female_output.webm](https://github.com/user-attachments/assets/917494ac-5913-4949-8022-46cf55ca05dd)
|
| 335 |
+
</td>
|
| 336 |
+
</tr>
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
<tr>
|
| 340 |
+
<td align="center"> Style-Whisper</td>
|
| 341 |
+
<td align="center"> 比如在工作间隙,做一些简单的伸展运动,放松一下身体,这样,会让你更有精力。</td>
|
| 342 |
+
<td align="center">
|
| 343 |
+
|
| 344 |
+
[whisper_prompt.webm](https://github.com/user-attachments/assets/ed9e22f1-1bac-417b-913a-5f1db31f35c9)
|
| 345 |
+
</td>
|
| 346 |
+
<td align="center">
|
| 347 |
+
|
| 348 |
+
[whisper_output.webm](https://github.com/user-attachments/assets/e0501050-40db-4d45-b380-8bcc309f0b5f)
|
| 349 |
+
</td>
|
| 350 |
+
</tr>
|
| 351 |
+
|
| 352 |
+
<tr>
|
| 353 |
+
<td align="center"> Style-Act_coy</td>
|
| 354 |
+
<td align="center"> 我今天想喝奶茶,可是不知道喝什么口味,你帮我选一下嘛,你选的都好喝~</td>
|
| 355 |
+
<td align="center">
|
| 356 |
+
|
| 357 |
+
[act_coy_prompt.webm](https://github.com/user-attachments/assets/74d60625-5b3c-4f45-becb-0d3fe7cc4b3f)
|
| 358 |
+
</td>
|
| 359 |
+
<td align="center">
|
| 360 |
+
|
| 361 |
+
[act_coy_output.webm](https://github.com/user-attachments/assets/b2f74577-56c2-4997-afd6-6bf47d15ea51)
|
| 362 |
+
</td>
|
| 363 |
+
</tr>
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
<tr>
|
| 367 |
+
<td align="center"> Paralinguistics</td>
|
| 368 |
+
<td align="center"> 你这次又忘记带钥匙了 [Dissatisfaction-hnn],真是拿你没办法。</td>
|
| 369 |
+
<td align="center">
|
| 370 |
+
|
| 371 |
+
[paralingustic_prompt.webm](https://github.com/user-attachments/assets/21e831a3-8110-4c64-a157-60e0cf6735f0)
|
| 372 |
+
</td>
|
| 373 |
+
<td align="center">
|
| 374 |
+
|
| 375 |
+
[paralingustic_output.webm](https://github.com/user-attachments/assets/a82f5a40-c6a3-409b-bbe6-271180b20d7b)
|
| 376 |
+
</td>
|
| 377 |
+
</tr>
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
<tr>
|
| 381 |
+
<td align="center"> Denoising</td>
|
| 382 |
+
<td align="center"> Such legislation was clarified and extended from time to time thereafter. No, the man was not drunk, he wondered how we got tied up with this stranger. Suddenly, my reflexes had gone. It's healthier to cook without sugar.</td>
|
| 383 |
+
<td align="center">
|
| 384 |
+
|
| 385 |
+
[denoising_prompt.webm](https://github.com/user-attachments/assets/70464bf4-ebde-44a3-b2a6-8c292333319b)
|
| 386 |
+
</td>
|
| 387 |
+
<td align="center">
|
| 388 |
+
|
| 389 |
+
[denoising_output.webm](https://github.com/user-attachments/assets/7cd0ae8d-1bf0-40fc-9bcd-f419bd4b2d21)
|
| 390 |
+
</td>
|
| 391 |
+
</tr>
|
| 392 |
+
|
| 393 |
+
<tr>
|
| 394 |
+
<td align="center"> Speed-Faster</td>
|
| 395 |
+
<td align="center"> 上次你说鞋子有点磨脚,我给你买了一双软软的鞋垫。</td>
|
| 396 |
+
<td align="center">
|
| 397 |
+
|
| 398 |
+
[speed_faster_prompt.webm](https://github.com/user-attachments/assets/db46609e-1b98-48d8-99c8-e166cfdfc6e3)
|
| 399 |
+
</td>
|
| 400 |
+
<td align="center">
|
| 401 |
+
|
| 402 |
+
[speed_faster_output.webm](https://github.com/user-attachments/assets/0fbc14ca-dd4a-4362-aadc-afe0629f4c9f)
|
| 403 |
+
</td>
|
| 404 |
+
</tr>
|
| 405 |
+
|
| 406 |
+
</table>
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
For more examples, see [demo page](https://stepaudiollm.github.io/step-audio-editx/).
|
| 410 |
+
|
| 411 |
+
## Model Download
|
| 412 |
+
|
| 413 |
+
| Models | 🤗 Hugging Face | ModelScope |
|
| 414 |
+
|-------|-------|-------|
|
| 415 |
+
| Step-Audio-EditX | [stepfun-ai/Step-Audio-EditX](https://huggingface.co/stepfun-ai/Step-Audio-EditX) | [stepfun-ai/Step-Audio-EditX](https://modelscope.cn/models/stepfun-ai/Step-Audio-EditX) |
|
| 416 |
+
| Step-Audio-EditX | [stepfun-ai/Step-Audio-EditX-AWQ-4bit](https://huggingface.co/stepfun-ai/Step-Audio-EditX-AWQ-4bit) | [stepfun-ai/Step-Audio-EditX-AWQ-4bit](https://modelscope.cn/models/stepfun-ai/Step-Audio-EditX-AWQ-4bit) |
|
| 417 |
+
| Step-Audio-Tokenizer | [stepfun-ai/Step-Audio-Tokenizer](https://huggingface.co/stepfun-ai/Step-Audio-Tokenizer) | [stepfun-ai/Step-Audio-Tokenizer](https://modelscope.cn/models/stepfun-ai/Step-Audio-Tokenizer) |
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
## Model Usage
|
| 421 |
+
### 📜 Requirements
|
| 422 |
+
The following table shows the requirements for running Step-Audio-EditX model (batch size = 1):
|
| 423 |
+
|
| 424 |
+
| Model | Parameters | Setting<br/>(sample frequency) | GPU Optimal Memory |
|
| 425 |
+
|------------|------------|--------------------------------|----------------|
|
| 426 |
+
| Step-Audio-EditX | 3B| 41.6Hz | 12 GB |
|
| 427 |
+
|
| 428 |
+
* An NVIDIA GPU with CUDA support is required.
|
| 429 |
+
* The model is tested on a single L40S GPU.
|
| 430 |
+
* 12GB is just a critical value, and 16GB GPU memory shoule be safer.
|
| 431 |
+
* Tested operating system: Linux
|
| 432 |
+
|
| 433 |
+
### 🔧 Dependencies and Installation
|
| 434 |
+
- Python >= 3.12
|
| 435 |
+
- [PyTorch >= 2.9.1](https://pytorch.org/)
|
| 436 |
+
- [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads)
|
| 437 |
+
|
| 438 |
+
```bash
|
| 439 |
+
git clone https://github.com/stepfun-ai/Step-Audio-EditX.git
|
| 440 |
+
|
| 441 |
+
cd Step-Audio-EditX
|
| 442 |
+
uv sync --refresh
|
| 443 |
+
source .venv/bin/activate
|
| 444 |
+
|
| 445 |
+
git lfs install
|
| 446 |
+
git clone https://huggingface.co/stepfun-ai/Step-Audio-Tokenizer
|
| 447 |
+
git clone https://huggingface.co/stepfun-ai/Step-Audio-EditX
|
| 448 |
+
git clone https://huggingface.co/stepfun-ai/Step-Audio-EditX-AWQ-4bit/
|
| 449 |
+
|
| 450 |
+
```
|
| 451 |
+
|
| 452 |
+
After downloading the models, where_you_download_dir should have the following structure:
|
| 453 |
+
```
|
| 454 |
+
where_you_download_dir
|
| 455 |
+
├── Step-Audio-Tokenizer
|
| 456 |
+
├── Step-Audio-EditX
|
| 457 |
+
```
|
| 458 |
+
|
| 459 |
+
#### Run with Docker
|
| 460 |
+
|
| 461 |
+
You can set up the environment required for running Step-Audio-EditX using the provided Dockerfile.
|
| 462 |
+
|
| 463 |
+
```bash
|
| 464 |
+
# build docker
|
| 465 |
+
docker build . -t step-audio-editx
|
| 466 |
+
|
| 467 |
+
# run docker
|
| 468 |
+
docker run --rm --gpus all \
|
| 469 |
+
-v /your/code/path:/app \
|
| 470 |
+
-v /your/model/path:/model \
|
| 471 |
+
-p 7860:7860 \
|
| 472 |
+
step-audio-editx
|
| 473 |
+
```
|
| 474 |
+
#### Local Inference Demo
|
| 475 |
+
> [!TIP]
|
| 476 |
+
> For optimal performance, keep audio under 30 seconds per inference.
|
| 477 |
+
|
| 478 |
+
```bash
|
| 479 |
+
# zero-shot cloning
|
| 480 |
+
# The path of the generated audio file is output/fear_zh_female_prompt_cloned.wav
|
| 481 |
+
python3 tts_infer.py \
|
| 482 |
+
--model-path where_you_download_dir \
|
| 483 |
+
--tokenizer-path where_you_download_dir \
|
| 484 |
+
--prompt-text "我总觉得,有人在跟着我,我能听到奇怪的脚步声。" \
|
| 485 |
+
--prompt-audio "examples/fear_zh_female_prompt.wav" \
|
| 486 |
+
--generated-text "可惜没有如果,已经发生的事情终究是发生了。" \
|
| 487 |
+
--edit-type "clone" \
|
| 488 |
+
--output-dir ./output
|
| 489 |
+
|
| 490 |
+
python3 tts_infer.py \
|
| 491 |
+
--model-path where_you_download_dir \
|
| 492 |
+
--tokenizer-path where_you_download_dir \
|
| 493 |
+
--prompt-text "His political stance was conservative, and he was particularly close to margaret thatcher." \
|
| 494 |
+
--prompt-audio "examples/zero_shot_en_prompt.wav" \
|
| 495 |
+
--generated-text "Underneath the courtyard is a large underground exhibition room which connects the two buildings. " \
|
| 496 |
+
--edit-type "clone" \
|
| 497 |
+
--output-dir ./output
|
| 498 |
+
|
| 499 |
+
# edit
|
| 500 |
+
# There will be one or multiple wave files corresponding to each edit iteration, for example: output/fear_zh_female_prompt_edited_iter1.wav, output/fear_zh_female_prompt_edited_iter2.wav, ...
|
| 501 |
+
# emotion; fear
|
| 502 |
+
python3 tts_infer.py \
|
| 503 |
+
--model-path where_you_download_dir \
|
| 504 |
+
--tokenizer-path where_you_download_dir \
|
| 505 |
+
--prompt-text "我总觉得,有人在跟着我,我能听到奇怪的脚步声。" \
|
| 506 |
+
--prompt-audio "examples/fear_zh_female_prompt.wav" \
|
| 507 |
+
--edit-type "emotion" \
|
| 508 |
+
--edit-info "fear" \
|
| 509 |
+
--output-dir ./output
|
| 510 |
+
|
| 511 |
+
# emotion; happy
|
| 512 |
+
python3 tts_infer.py \
|
| 513 |
+
--model-path where_you_download_dir \
|
| 514 |
+
--tokenizer-path where_you_download_dir \
|
| 515 |
+
--prompt-text "You know, I just finished that big project and feel so relieved. Everything seems easier and more colorful, what a wonderful feeling!" \
|
| 516 |
+
--prompt-audio "examples/en_happy_prompt.wav" \
|
| 517 |
+
--edit-type "emotion" \
|
| 518 |
+
--edit-info "happy" \
|
| 519 |
+
--output-dir ./output
|
| 520 |
+
|
| 521 |
+
# style; whisper
|
| 522 |
+
# for style whisper, the edit iteration num should be set bigger than 1 to get better results.
|
| 523 |
+
python3 tts_infer.py \
|
| 524 |
+
--model-path where_you_download_dir \
|
| 525 |
+
--tokenizer-path where_you_download_dir \
|
| 526 |
+
--prompt-text "比如在工作间隙,做一些简单的伸展运动,放松一下身体,这样,会让你更有精力." \
|
| 527 |
+
--prompt-audio "examples/whisper_prompt.wav" \
|
| 528 |
+
--edit-type "style" \
|
| 529 |
+
--edit-info "whisper" \
|
| 530 |
+
--output-dir ./output
|
| 531 |
+
|
| 532 |
+
# paraliguistic
|
| 533 |
+
# supported tags, Breathing, Laughter, Surprise-oh, Confirmation-en, Uhm, Surprise-ah, Surprise-wa, Sigh, Question-ei, Dissatisfaction-hnn
|
| 534 |
+
python3 tts_infer.py \
|
| 535 |
+
--model-path where_you_download_dir \
|
| 536 |
+
--tokenizer-path where_you_download_dir \
|
| 537 |
+
--prompt-text "我觉得这个计划大概是可行的,不过还需要再仔细考虑一下。" \
|
| 538 |
+
--prompt-audio "examples/paralingustic_prompt.wav" \
|
| 539 |
+
--generated-text "我觉得这个计划大概是可行的,[Uhm]不过还需要再仔细考虑一下。" \
|
| 540 |
+
--edit-type "paralinguistic" \
|
| 541 |
+
--output-dir ./output
|
| 542 |
+
|
| 543 |
+
# denoise
|
| 544 |
+
# Prompt text is not needed.
|
| 545 |
+
python3 tts_infer.py \
|
| 546 |
+
--model-path where_you_download_dir \
|
| 547 |
+
--tokenizer-path where_you_download_dir \
|
| 548 |
+
--prompt-audio "examples/denoise_prompt.wav"\
|
| 549 |
+
--edit-type "denoise" \
|
| 550 |
+
--output-dir ./output
|
| 551 |
+
|
| 552 |
+
# vad
|
| 553 |
+
# Prompt text is not needed.
|
| 554 |
+
python3 tts_infer.py \
|
| 555 |
+
--model-path where_you_download_dir \
|
| 556 |
+
--tokenizer-path where_you_download_dir \
|
| 557 |
+
--prompt-audio "examples/vad_prompt.wav" \
|
| 558 |
+
--edit-type "vad" \
|
| 559 |
+
--output-dir ./output
|
| 560 |
+
|
| 561 |
+
# speed
|
| 562 |
+
# supported edit-info: faster, slower, more faster, more slower
|
| 563 |
+
python3 tts_infer.py \
|
| 564 |
+
--model-path where_you_download_dir \
|
| 565 |
+
--tokenizer-path where_you_download_dir \
|
| 566 |
+
--prompt-text "上次你说鞋子有点磨脚,我给你买了一双软软的鞋垫。" \
|
| 567 |
+
--prompt-audio "examples/speed_prompt.wav" \
|
| 568 |
+
--edit-type "speed" \
|
| 569 |
+
--edit-info "more faster" \
|
| 570 |
+
--output-dir ./output
|
| 571 |
+
|
| 572 |
+
```
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
#### Launch Web Demo
|
| 577 |
+
Start a local server for online inference.
|
| 578 |
+
Assume you have one GPU with at least 12GB memory available and have already downloaded all the models.
|
| 579 |
+
|
| 580 |
+
```bash
|
| 581 |
+
# Standard launch
|
| 582 |
+
python app.py --model-path where_you_download_dir --tokenizer-path where_you_download_dir --model-source local
|
| 583 |
+
|
| 584 |
+
# Using pre-quantized AWQ 4-bit models, memory-efficient mode (for limited GPU memory, ~6-8GB usage)
|
| 585 |
+
python app.py \
|
| 586 |
+
--model-path path/to/quantized/model \
|
| 587 |
+
--tokenizer-path where_you_download_dir \
|
| 588 |
+
--model-source local \
|
| 589 |
+
--gpu-memory-utilization 0.1 \
|
| 590 |
+
--enforce-eager \
|
| 591 |
+
--max-num-seqs 1 \
|
| 592 |
+
--cosyvoice-dtype bfloat16 \
|
| 593 |
+
--no-cosyvoice-cuda-graph
|
| 594 |
+
|
| 595 |
+
```
|
| 596 |
+
|
| 597 |
+
##### Available Parameters
|
| 598 |
+
|
| 599 |
+
| Parameter | Default | Description |
|
| 600 |
+
|-----------|---------|-------------|
|
| 601 |
+
| `--model-path` | (required) | Path to the model directory |
|
| 602 |
+
| `--model-source` | `auto` | Model source: `auto`, `local`, `modelscope`, `huggingface` |
|
| 603 |
+
| `--gpu-memory-utilization` | `0.5` | GPU memory ratio for vLLM KV cache (0.0-1.0) |
|
| 604 |
+
| `--max-model-len` | `3072` | Maximum sequence length, affects KV cache size |
|
| 605 |
+
| `--enforce-eager` | `True` | Disable vLLM CUDA Graphs (saves ~0.5GB memory) |
|
| 606 |
+
| `--max-num-seqs` | `1` | Maximum concurrent sequences (vLLM default: 256, lower = less memory) |
|
| 607 |
+
| `--dtype` | `bfloat16` | Model dtype: `float16`, `bfloat16` |
|
| 608 |
+
| `--quantization` | `None` | Quantization method: `awq`, `gptq`, `fp8` |
|
| 609 |
+
| `--cosyvoice-dtype` | `bfloat16` | CosyVoice vocoder dtype: `float32`, `bfloat16`, `float16` |
|
| 610 |
+
| `--no-cosyvoice-cuda-graph` | `False` | Disable CosyVoice CUDA Graphs (saves memory) |
|
| 611 |
+
| `--enable-auto-transcribe` | `False` | Enable automatic audio transcription |
|
| 612 |
+
|
| 613 |
+
##### Memory Usage Guide
|
| 614 |
+
|
| 615 |
+
| Configuration | Estimated GPU Memory | Use Case |
|
| 616 |
+
|--------------|---------------------|----------|
|
| 617 |
+
| Standard (defaults) | ~12-15 GB | Best quality and speed |
|
| 618 |
+
| Memory-efficient | ~6-8 GB | Limited GPU memory, some quality trade-off |
|
| 619 |
+
| AWQ 4-bit quantized | ~8-10 GB | Good balance of quality and memory |
|
| 620 |
+
|
| 621 |
+
## Training
|
| 622 |
+
Please refer to script/ReadMe.md
|
| 623 |
+
|
| 624 |
+
### 🔄 Model Quantization (Optional)
|
| 625 |
+
|
| 626 |
+
For users with limited GPU memory, you can create quantized versions of the model to reduce memory requirements:
|
| 627 |
+
|
| 628 |
+
```bash
|
| 629 |
+
# Create an AWQ 4-bit quantized model
|
| 630 |
+
python quantization/awq_quantize.py --model_path path/to/Step-Audio-EditX
|
| 631 |
+
|
| 632 |
+
# Advanced quantization options
|
| 633 |
+
python quantization/awq_quantize.py
|
| 634 |
+
```
|
| 635 |
+
|
| 636 |
+
For detailed quantization options and parameters, see [quantization/README.md](quantization/README.md).
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
## Technical Details
|
| 640 |
+
<img src="assets/architechture.png" width=900>
|
| 641 |
+
Step-Audio-EditX comprises three primary components:
|
| 642 |
+
|
| 643 |
+
- A dual-codebook audio tokenizer, which converts reference or input audio into discrete tokens.
|
| 644 |
+
- An audio LLM that generates dual-codebook token sequences.
|
| 645 |
+
- An audio decoder, which converts the dual-codebook token sequences predicted by the audio LLM back into audio waveforms using a flow matching approach.
|
| 646 |
+
|
| 647 |
+
Audio-Edit enables iterative control over emotion and speaking style across all voices, leveraging large-margin data during SFT and PPO training.
|
| 648 |
+
|
| 649 |
+
## Evaluation
|
| 650 |
+
|
| 651 |
+
### Comparison between Step-Audio-EditX and Closed-Source models.
|
| 652 |
+
|
| 653 |
+
- Step-Audio-EditX demonstrates superior performance over Minimax and Doubao in both zero-shot cloning and emotion control.
|
| 654 |
+
- Emotion editing of Step-Audio-EditX significantly improves the emotion-controlled audio outputs of all three models after just one iteration. With further iterations, their overall performance continues to improve.
|
| 655 |
+
|
| 656 |
+
<div align="center">
|
| 657 |
+
<img src="assets/emotion-eval.png" width=800 >
|
| 658 |
+
</div>
|
| 659 |
+
|
| 660 |
+
### Generalization on Closed-Source Models.
|
| 661 |
+
- For emotion and speaking style editing, the built-in voices of leading closed-source systems possess considerable in-context capabilities, allowing them to partially convey the emotions in the text. After a single editing round with Step-Audio-EditX, the emotion and style accuracy across all voice models exhibited significant improvement. Further enhancement was observed over the next two iterations, robustly demonstrating our model's strong generalization.
|
| 662 |
+
|
| 663 |
+
- For paralinguistic editing, after editing with Step-Audio-EditX, the performance of paralinguistic reproduction is comparable to that achieved by the built-in voices of closed-source models when synthesizing native paralinguistic content directly. (**sub** means replacement of paralinguistic tags with native words)
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
<div align="center">
|
| 667 |
+
|
| 668 |
+
<table border="1" cellspacing="0" cellpadding="5" style="border-collapse: collapse; font-family: sans-serif; width: auto;">
|
| 669 |
+
<caption><b>Table: Generalization of Emotion, Speaking Style, and Paralinguistic Editing on Closed-Source Models.</b></caption>
|
| 670 |
+
<thead>
|
| 671 |
+
<tr>
|
| 672 |
+
<th rowspan="2" align="center" style="vertical-align: bottom;">Language</th>
|
| 673 |
+
<th rowspan="2" align="center" style="vertical-align: bottom;">Model</th>
|
| 674 |
+
<th colspan="4" style="border-bottom: 1px solid black;">Emotion ↑</th>
|
| 675 |
+
<th colspan="4" style="border-bottom: 1px solid black;">Speaking Style ↑</th>
|
| 676 |
+
<th colspan="3" style="border-bottom: 1px solid black; border-left: 1px solid black;">Paralinguistic ↑</th>
|
| 677 |
+
</tr>
|
| 678 |
+
<tr>
|
| 679 |
+
<th>Iter<sub>0</sub></th>
|
| 680 |
+
<th>Iter<sub>1</sub></th>
|
| 681 |
+
<th>Iter<sub>2</sub></th>
|
| 682 |
+
<th>Iter<sub>3</sub></th>
|
| 683 |
+
<th style="border-left: 1px solid #ccc;">Iter<sub>0</sub></th>
|
| 684 |
+
<th>Iter<sub>1</sub></th>
|
| 685 |
+
<th>Iter<sub>2</sub></th>
|
| 686 |
+
<th>Iter<sub>3</sub></th>
|
| 687 |
+
<th style="border-left: 1px solid black;">Iter<sub>0</sub></th>
|
| 688 |
+
<th>sub</th>
|
| 689 |
+
<th>Iter<sub>1</sub></th>
|
| 690 |
+
</tr>
|
| 691 |
+
</thead>
|
| 692 |
+
<tbody>
|
| 693 |
+
<tr>
|
| 694 |
+
<td rowspan="4" align="center" style="font-weight: bold; vertical-align: middle;">Chinese</td>
|
| 695 |
+
<td align="left">MiniMax-2.6-hd</td>
|
| 696 |
+
<td align="center">71.6</td>
|
| 697 |
+
<td align="center">78.6</td>
|
| 698 |
+
<td align="center">81.2</td>
|
| 699 |
+
<td align="center"><b>83.4</b></td>
|
| 700 |
+
<td align="center" style="border-left: 1px solid #ccc;">36.7</td>
|
| 701 |
+
<td align="center">58.8</td>
|
| 702 |
+
<td align="center">63.1</td>
|
| 703 |
+
<td align="center"><b>67.3</b></td>
|
| 704 |
+
<td align="center" style="border-left: 1px solid black;">1.73</td>
|
| 705 |
+
<td align="center">2.80</td>
|
| 706 |
+
<td align="center">2.90</td>
|
| 707 |
+
</tr>
|
| 708 |
+
<tr>
|
| 709 |
+
<td align="left">Doubao-Seed-TTS-2.0</td>
|
| 710 |
+
<td align="center">67.4</td>
|
| 711 |
+
<td align="center">77.8</td>
|
| 712 |
+
<td align="center">80.6</td>
|
| 713 |
+
<td align="center"><b>82.8</b></td>
|
| 714 |
+
<td align="center" style="border-left: 1px solid #ccc;">38.2</td>
|
| 715 |
+
<td align="center">60.2</td>
|
| 716 |
+
<td align="center"><b>65.0</b></td>
|
| 717 |
+
<td align="center">64.9</td>
|
| 718 |
+
<td align="center" style="border-left: 1px solid black;">1.67</td>
|
| 719 |
+
<td align="center">2.81</td>
|
| 720 |
+
<td align="center">2.90</td>
|
| 721 |
+
</tr>
|
| 722 |
+
<tr>
|
| 723 |
+
<td align="left">GPT-4o-mini-TTS</td>
|
| 724 |
+
<td align="center">62.6</td>
|
| 725 |
+
<td align="center">76.0</td>
|
| 726 |
+
<td align="center">77.0</td>
|
| 727 |
+
<td align="center"><b>81.8</b></td>
|
| 728 |
+
<td align="center" style="border-left: 1px solid #ccc;">45.9</td>
|
| 729 |
+
<td align="center">64.0</td>
|
| 730 |
+
<td align="center">65.7</td>
|
| 731 |
+
<td align="center"><b>69.7</b></td>
|
| 732 |
+
<td align="center" style="border-left: 1px solid black;">1.71</td>
|
| 733 |
+
<td align="center">2.88</td>
|
| 734 |
+
<td align="center">2.93</td>
|
| 735 |
+
</tr>
|
| 736 |
+
<tr style="border-bottom: 1px solid black;">
|
| 737 |
+
<td align="left">ElevenLabs-v2</td>
|
| 738 |
+
<td align="center">60.4</td>
|
| 739 |
+
<td align="center">74.6</td>
|
| 740 |
+
<td align="center">77.4</td>
|
| 741 |
+
<td align="center"><b>79.2</b></td>
|
| 742 |
+
<td align="center" style="border-left: 1px solid #ccc;">43.8</td>
|
| 743 |
+
<td align="center">63.3</td>
|
| 744 |
+
<td align="center">69.7</td>
|
| 745 |
+
<td align="center"><b>70.8</b></td>
|
| 746 |
+
<td align="center" style="border-left: 1px solid black;">1.70</td>
|
| 747 |
+
<td align="center">2.71</td>
|
| 748 |
+
<td align="center">2.92</td>
|
| 749 |
+
</tr>
|
| 750 |
+
<tr>
|
| 751 |
+
<td rowspan="4" align="center" style="font-weight: bold; vertical-align: middle;">English</td>
|
| 752 |
+
<td align="left">MiniMax-2.6-hd</td>
|
| 753 |
+
<td align="center">55.0</td>
|
| 754 |
+
<td align="center">64.0</td>
|
| 755 |
+
<td align="center">64.2</td>
|
| 756 |
+
<td align="center"><b>66.4</b></td>
|
| 757 |
+
<td align="center" style="border-left: 1px solid #ccc;">51.9</td>
|
| 758 |
+
<td align="center">60.3</td>
|
| 759 |
+
<td align="center">62.3</td>
|
| 760 |
+
<td align="center"><b>64.3</b></td>
|
| 761 |
+
<td align="center" style="border-left: 1px solid black;">1.72</td>
|
| 762 |
+
<td align="center">2.87</td>
|
| 763 |
+
<td align="center">2.88</td>
|
| 764 |
+
</tr>
|
| 765 |
+
<tr>
|
| 766 |
+
<td align="left">Doubao-Seed-TTS-2.0</td>
|
| 767 |
+
<td align="center">53.8</td>
|
| 768 |
+
<td align="center">65.8</td>
|
| 769 |
+
<td align="center">65.8</td>
|
| 770 |
+
<td align="center"><b>66.2</b></td>
|
| 771 |
+
<td align="center" style="border-left: 1px solid #ccc;">47.0</td>
|
| 772 |
+
<td align="center">62.0</td>
|
| 773 |
+
<td align="center"><b>62.7</b></td>
|
| 774 |
+
<td align="center">62.3</td>
|
| 775 |
+
<td align="center" style="border-left: 1px solid black;">1.72</td>
|
| 776 |
+
<td align="center">2.75</td>
|
| 777 |
+
<td align="center">2.92</td>
|
| 778 |
+
</tr>
|
| 779 |
+
<tr>
|
| 780 |
+
<td align="left">GPT-4o-mini-TTS</td>
|
| 781 |
+
<td align="center">56.8</td>
|
| 782 |
+
<td align="center">61.4</td>
|
| 783 |
+
<td align="center">64.8</td>
|
| 784 |
+
<td align="center"><b>65.2</b></td>
|
| 785 |
+
<td align="center" style="border-left: 1px solid #ccc;">52.3</td>
|
| 786 |
+
<td align="center">62.3</td>
|
| 787 |
+
<td align="center">62.4</td>
|
| 788 |
+
<td align="center"><b>63.4</b></td>
|
| 789 |
+
<td align="center" style="border-left: 1px solid black;">1.90</td>
|
| 790 |
+
<td align="center">2.90</td>
|
| 791 |
+
<td align="center">2.88</td>
|
| 792 |
+
</tr>
|
| 793 |
+
<tr style="border-bottom: 1px solid black;">
|
| 794 |
+
<td align="left">ElevenLabs-v2</td>
|
| 795 |
+
<td align="center">51.0</td>
|
| 796 |
+
<td align="center">61.2</td>
|
| 797 |
+
<td align="center">64.0</td>
|
| 798 |
+
<td align="center"><b>65.2</b></td>
|
| 799 |
+
<td align="center" style="border-left: 1px solid #ccc;">51.0</td>
|
| 800 |
+
<td align="center">62.1</td>
|
| 801 |
+
<td align="center">62.6</td>
|
| 802 |
+
<td align="center"><b>64.0</b></td>
|
| 803 |
+
<td align="center" style="border-left: 1px solid black;">1.93</td>
|
| 804 |
+
<td align="center">2.87</td>
|
| 805 |
+
<td align="center">2.88</td>
|
| 806 |
+
</tr>
|
| 807 |
+
<tr>
|
| 808 |
+
<td rowspan="4" align="center" style="font-weight: bold; vertical-align: middle;">Average</td>
|
| 809 |
+
<td align="left">MiniMax-2.6-hd</td>
|
| 810 |
+
<td align="center">63.3</td>
|
| 811 |
+
<td align="center">71.3</td>
|
| 812 |
+
<td align="center">72.7</td>
|
| 813 |
+
<td align="center"><b>74.9</b></td>
|
| 814 |
+
<td align="center" style="border-left: 1px solid #ccc;">44.2</td>
|
| 815 |
+
<td align="center">59.6</td>
|
| 816 |
+
<td align="center">62.7</td>
|
| 817 |
+
<td align="center"><b>65.8</b></td>
|
| 818 |
+
<td align="center" style="border-left: 1px solid black;">1.73</td>
|
| 819 |
+
<td align="center">2.84</td>
|
| 820 |
+
<td align="center">2.89</td>
|
| 821 |
+
</tr>
|
| 822 |
+
<tr>
|
| 823 |
+
<td align="left">Doubao-Seed-TTS-2.0</td>
|
| 824 |
+
<td align="center">60.6</td>
|
| 825 |
+
<td align="center">71.8</td>
|
| 826 |
+
<td align="center">73.2</td>
|
| 827 |
+
<td align="center"><b>74.5</b></td>
|
| 828 |
+
<td align="center" style="border-left: 1px solid #ccc;">42.6</td>
|
| 829 |
+
<td align="center">61.1</td>
|
| 830 |
+
<td align="center"><b>63.9</b></td>
|
| 831 |
+
<td align="center">63.6</td>
|
| 832 |
+
<td align="center" style="border-left: 1px solid black;">1.70</td>
|
| 833 |
+
<td align="center">2.78</td>
|
| 834 |
+
<td align="center">2.91</td>
|
| 835 |
+
</tr>
|
| 836 |
+
<tr>
|
| 837 |
+
<td align="left">GPT-4o-mini-TTS</td>
|
| 838 |
+
<td align="center">59.7</td>
|
| 839 |
+
<td align="center">68.7</td>
|
| 840 |
+
<td align="center">70.9</td>
|
| 841 |
+
<td align="center"><b>73.5</b></td>
|
| 842 |
+
<td align="center" style="border-left: 1px solid #ccc;">49.1</td>
|
| 843 |
+
<td align="center">63.2</td>
|
| 844 |
+
<td align="center">64.1</td>
|
| 845 |
+
<td align="center"><b>66.6</b></td>
|
| 846 |
+
<td align="center" style="border-left: 1px solid black;">1.81</td>
|
| 847 |
+
<td align="center">2.89</td>
|
| 848 |
+
<td align="center">2.90</td>
|
| 849 |
+
</tr>
|
| 850 |
+
<tr>
|
| 851 |
+
<td align="left">ElevenLabs-v2</td>
|
| 852 |
+
<td align="center">55.7</td>
|
| 853 |
+
<td align="center">67.9</td>
|
| 854 |
+
<td align="center">70.7</td>
|
| 855 |
+
<td align="center"><b>72.2</b></td>
|
| 856 |
+
<td align="center" style="border-left: 1px solid #ccc;">47.4</td>
|
| 857 |
+
<td align="center">62.7</td>
|
| 858 |
+
<td align="center">66.1</td>
|
| 859 |
+
<td align="center"><b>67.4</b></td>
|
| 860 |
+
<td align="center" style="border-left: 1px solid black;">1.82</td>
|
| 861 |
+
<td align="center">2.79</td>
|
| 862 |
+
<td align="center">2.90</td>
|
| 863 |
+
</tr>
|
| 864 |
+
</tbody>
|
| 865 |
+
</table>
|
| 866 |
+
|
| 867 |
+
</div>
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
## Acknowledgements
|
| 871 |
+
|
| 872 |
+
Part of the code and data for this project comes from:
|
| 873 |
+
* [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
|
| 874 |
+
* [transformers](https://github.com/huggingface/transformers)
|
| 875 |
+
* [FunASR](https://github.com/modelscope/FunASR)
|
| 876 |
+
* [NVSpeech](https://huggingface.co/datasets/amphion/Emilia-NV)
|
| 877 |
+
* [vllm](https://github.com/vllm-project/vllm)
|
| 878 |
+
|
| 879 |
+
Thank you to all the open-source projects for their contributions to this project!
|
| 880 |
+
|
| 881 |
+
## License Agreement
|
| 882 |
+
+ The code in this open-source repository is licensed under the [Apache 2.0](LICENSE) License.
|
| 883 |
+
|
| 884 |
+
## Citation
|
| 885 |
+
|
| 886 |
+
```
|
| 887 |
+
@misc{yan2025stepaudioeditxtechnicalreport,
|
| 888 |
+
title={Step-Audio-EditX Technical Report},
|
| 889 |
+
author={Chao Yan and Boyong Wu and Peng Yang and Pengfei Tan and Guoqiang Hu and Yuxin Zhang and Xiangyu and Zhang and Fei Tian and Xuerui Yang and Xiangyu Zhang and Daxin Jiang and Gang Yu},
|
| 890 |
+
year={2025},
|
| 891 |
+
eprint={2511.03601},
|
| 892 |
+
archivePrefix={arXiv},
|
| 893 |
+
primaryClass={cs.CL},
|
| 894 |
+
url={https://arxiv.org/abs/2511.03601},
|
| 895 |
+
}
|
| 896 |
+
```
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
## ⚠️ Usage Disclaimer
|
| 900 |
+
- Do not use this model for any unauthorized activities, including but not limited to:
|
| 901 |
+
- Voice cloning without permission
|
| 902 |
+
- Identity impersonation
|
| 903 |
+
- Fraud
|
| 904 |
+
- Deepfakes or any other illegal purposes
|
| 905 |
+
- Ensure compliance with local laws and regulations, and adhere to ethical guidelines when using this model.
|
| 906 |
+
- The model developers are not responsible for any misuse or abuse of this technology.
|
| 907 |
+
|
| 908 |
+
We advocate for responsible generative AI research and urge the community to uphold safety and ethical standards in AI development and application. If you have any concerns regarding the use of this model, please feel free to contact us.
|
| 909 |
+
|
| 910 |
+
## Star History
|
| 911 |
+
[](https://star-history.com/#stepfun-ai/Step-Audio-EditX&Date)
|
assets/architechture.png
ADDED
|
Git LFS Details
|
assets/emotion-eval.png
ADDED
|
Git LFS Details
|
assets/logo.png
ADDED
|
assets/test.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d7d3dc34ac3dd2f7765e61ba5e2023beb0f59cdf2acb14c42fa00fbddd13afa3
|
| 3 |
+
size 192558
|
config.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Step1ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_step1.Step1Config",
|
| 7 |
+
"AutoModelForCausalLM": "modeling_step1.Step1ForCausalLM"
|
| 8 |
+
},
|
| 9 |
+
"model_type": "step1",
|
| 10 |
+
"bos_token_id": 1,
|
| 11 |
+
"pad_token_id": 0,
|
| 12 |
+
"eos_token_id": 3,
|
| 13 |
+
"hidden_size": 3072,
|
| 14 |
+
"intermediate_size": 8192,
|
| 15 |
+
"num_attention_heads": 48,
|
| 16 |
+
"num_attention_groups": 4,
|
| 17 |
+
"num_hidden_layers": 32,
|
| 18 |
+
"max_seq_len": 32768,
|
| 19 |
+
"vocab_size": 74752,
|
| 20 |
+
"rms_norm_eps": 1e-05,
|
| 21 |
+
"tie_word_embeddings": false,
|
| 22 |
+
"torch_dtype": "bfloat16",
|
| 23 |
+
"use_cache": true
|
| 24 |
+
}
|
configuration_step1.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, List, Any, Dict
|
| 2 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Step1Config(PretrainedConfig):
|
| 7 |
+
model_type = "step1"
|
| 8 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 9 |
+
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
hidden_size: int = 5120,
|
| 13 |
+
intermediate_size: int = 13312,
|
| 14 |
+
num_attention_heads: int = 40,
|
| 15 |
+
num_attention_groups: int = 8,
|
| 16 |
+
num_hidden_layers: int = 48,
|
| 17 |
+
max_seq_len: int = 4096,
|
| 18 |
+
vocab_size: int = 65536,
|
| 19 |
+
rms_norm_eps: float = 1e-5,
|
| 20 |
+
bos_token_id: int = 1,
|
| 21 |
+
eos_token_id: int = 3,
|
| 22 |
+
pad_token_id: int = 0,
|
| 23 |
+
**kwargs,
|
| 24 |
+
) -> None:
|
| 25 |
+
self.hidden_size = hidden_size
|
| 26 |
+
self.intermediate_size = intermediate_size
|
| 27 |
+
self.num_attention_heads = num_attention_heads
|
| 28 |
+
self.num_attention_groups = num_attention_groups
|
| 29 |
+
self.num_hidden_layers = num_hidden_layers
|
| 30 |
+
self.max_seq_len = max_seq_len
|
| 31 |
+
self.vocab_size = vocab_size
|
| 32 |
+
self.rms_norm_eps = rms_norm_eps
|
| 33 |
+
super().__init__(
|
| 34 |
+
bos_token_id=bos_token_id,
|
| 35 |
+
pad_token_id=pad_token_id,
|
| 36 |
+
eos_token_id=eos_token_id,
|
| 37 |
+
**kwargs
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
__all__ = ["Step1Config"]
|
model-00001.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb2c6baae3a40ccd19cf21f7a28629666e31c215b5d63b2ed1e04aac6dd08d69
|
| 3 |
+
size 7059446656
|
model.safetensors.index.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"metadata": {"total_size": 7059412992}, "weight_map": {"model.embed_tokens.weight": "model-00001.safetensors", "model.layers.0.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.0.input_layernorm.weight": "model-00001.safetensors", "model.layers.0.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.0.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.0.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.0.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.0.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.0.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.0.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.1.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.1.input_layernorm.weight": "model-00001.safetensors", "model.layers.1.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.1.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.1.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.1.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.1.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.1.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.1.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.2.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.2.input_layernorm.weight": "model-00001.safetensors", "model.layers.2.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.2.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.2.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.2.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.2.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.2.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.2.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.3.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.3.input_layernorm.weight": "model-00001.safetensors", "model.layers.3.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.3.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.3.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.3.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.3.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.3.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.3.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.4.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.4.input_layernorm.weight": "model-00001.safetensors", "model.layers.4.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.4.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.4.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.4.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.4.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.4.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.4.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.5.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.5.input_layernorm.weight": "model-00001.safetensors", "model.layers.5.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.5.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.5.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.5.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.5.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.5.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.5.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.6.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.6.input_layernorm.weight": "model-00001.safetensors", "model.layers.6.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.6.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.6.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.6.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.6.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.6.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.6.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.7.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.7.input_layernorm.weight": "model-00001.safetensors", "model.layers.7.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.7.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.7.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.7.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.7.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.7.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.7.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.8.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.8.input_layernorm.weight": "model-00001.safetensors", "model.layers.8.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.8.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.8.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.8.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.8.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.8.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.8.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.9.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.9.input_layernorm.weight": "model-00001.safetensors", "model.layers.9.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.9.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.9.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.9.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.9.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.9.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.9.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.10.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.10.input_layernorm.weight": "model-00001.safetensors", "model.layers.10.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.10.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.10.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.10.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.10.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.10.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.10.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.11.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.11.input_layernorm.weight": "model-00001.safetensors", "model.layers.11.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.11.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.11.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.11.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.11.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.11.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.11.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.12.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.12.input_layernorm.weight": "model-00001.safetensors", "model.layers.12.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.12.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.12.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.12.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.12.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.12.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.12.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.13.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.13.input_layernorm.weight": "model-00001.safetensors", "model.layers.13.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.13.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.13.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.13.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.13.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.13.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.13.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.14.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.14.input_layernorm.weight": "model-00001.safetensors", "model.layers.14.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.14.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.14.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.14.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.14.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.14.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.14.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.15.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.15.input_layernorm.weight": "model-00001.safetensors", "model.layers.15.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.15.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.15.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.15.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.15.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.15.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.15.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.16.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.16.input_layernorm.weight": "model-00001.safetensors", "model.layers.16.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.16.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.16.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.16.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.16.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.16.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.16.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.17.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.17.input_layernorm.weight": "model-00001.safetensors", "model.layers.17.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.17.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.17.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.17.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.17.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.17.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.17.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.18.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.18.input_layernorm.weight": "model-00001.safetensors", "model.layers.18.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.18.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.18.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.18.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.18.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.18.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.18.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.19.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.19.input_layernorm.weight": "model-00001.safetensors", "model.layers.19.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.19.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.19.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.19.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.19.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.19.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.19.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.20.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.20.input_layernorm.weight": "model-00001.safetensors", "model.layers.20.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.20.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.20.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.20.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.20.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.20.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.20.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.21.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.21.input_layernorm.weight": "model-00001.safetensors", "model.layers.21.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.21.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.21.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.21.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.21.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.21.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.21.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.22.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.22.input_layernorm.weight": "model-00001.safetensors", "model.layers.22.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.22.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.22.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.22.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.22.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.22.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.22.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.23.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.23.input_layernorm.weight": "model-00001.safetensors", "model.layers.23.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.23.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.23.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.23.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.23.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.23.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.23.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.24.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.24.input_layernorm.weight": "model-00001.safetensors", "model.layers.24.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.24.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.24.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.24.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.24.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.24.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.24.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.25.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.25.input_layernorm.weight": "model-00001.safetensors", "model.layers.25.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.25.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.25.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.25.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.25.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.25.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.25.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.26.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.26.input_layernorm.weight": "model-00001.safetensors", "model.layers.26.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.26.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.26.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.26.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.26.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.26.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.26.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.27.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.27.input_layernorm.weight": "model-00001.safetensors", "model.layers.27.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.27.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.27.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.27.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.27.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.27.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.27.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.28.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.28.input_layernorm.weight": "model-00001.safetensors", "model.layers.28.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.28.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.28.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.28.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.28.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.28.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.28.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.29.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.29.input_layernorm.weight": "model-00001.safetensors", "model.layers.29.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.29.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.29.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.29.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.29.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.29.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.29.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.30.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.30.input_layernorm.weight": "model-00001.safetensors", "model.layers.30.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.30.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.30.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.30.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.30.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.30.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.30.mlp.up_proj.weight": "model-00001.safetensors", "model.layers.31.self_attn.o_proj.weight": "model-00001.safetensors", "model.layers.31.input_layernorm.weight": "model-00001.safetensors", "model.layers.31.mlp.down_proj.weight": "model-00001.safetensors", "model.layers.31.post_attention_layernorm.weight": "model-00001.safetensors", "model.layers.31.self_attn.q_proj.weight": "model-00001.safetensors", "model.layers.31.self_attn.k_proj.weight": "model-00001.safetensors", "model.layers.31.self_attn.v_proj.weight": "model-00001.safetensors", "model.layers.31.mlp.gate_proj.weight": "model-00001.safetensors", "model.layers.31.mlp.up_proj.weight": "model-00001.safetensors", "model.norm.weight": "model-00001.safetensors", "lm_head.weight": "model-00001.safetensors"}}
|
modeling_step1.py
ADDED
|
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional, Tuple, Union, List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.utils.checkpoint
|
| 6 |
+
from torch import nn
|
| 7 |
+
from transformers.generation import GenerationMixin
|
| 8 |
+
|
| 9 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 10 |
+
from transformers.utils import logging
|
| 11 |
+
from .configuration_step1 import Step1Config
|
| 12 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
from transformers.modeling_outputs import (
|
| 15 |
+
BaseModelOutputWithPast,
|
| 16 |
+
CausalLMOutputWithPast,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
logger = logging.get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def build_alibi_cache(block_size, n_heads, dtype, device):
|
| 23 |
+
# get slopes
|
| 24 |
+
n = 2 ** math.floor(math.log2(n_heads)) # nearest 2**n to n_heads
|
| 25 |
+
m0 = 2.0 ** (-8.0 / n)
|
| 26 |
+
# 2^(-8/n), 2^(-8*2/n), 2^(-8*3/n), ...
|
| 27 |
+
slopes = torch.pow(m0, torch.arange(1, n + 1))
|
| 28 |
+
if n < n_heads:
|
| 29 |
+
m1 = 2.0 ** (-4.0 / n)
|
| 30 |
+
# 2^(-8/(2n)), 2^(-8*3/(2n)), 2^(-8*5/(2n)), ...
|
| 31 |
+
mm = torch.pow(m1, torch.arange(1, 1 + 2 * (n_heads - n), 2))
|
| 32 |
+
slopes = torch.cat([slopes, mm])
|
| 33 |
+
slopes = slopes.to(device)
|
| 34 |
+
|
| 35 |
+
tril = torch.tril(torch.ones(1, 1, block_size, block_size, device=device))
|
| 36 |
+
|
| 37 |
+
bias_rows = torch.arange(block_size, device=device).view(1, -1)
|
| 38 |
+
bias_cols = torch.arange(block_size, device=device).view(-1, 1)
|
| 39 |
+
bias = -torch.sqrt(bias_cols - bias_rows)
|
| 40 |
+
bias = bias.view(1, block_size, block_size) * slopes.view(-1, 1, 1)
|
| 41 |
+
bias = bias.masked_fill(tril == 0, float("-inf"))
|
| 42 |
+
|
| 43 |
+
return bias.type(dtype)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class StepRMSNorm(torch.nn.Module):
|
| 47 |
+
def __init__(self, hidden_size, eps=1e-5):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
| 50 |
+
self.eps = eps
|
| 51 |
+
|
| 52 |
+
def forward(self, x: torch.Tensor):
|
| 53 |
+
var = x.float().pow(2).mean(-1, keepdim=True)
|
| 54 |
+
x = x * torch.rsqrt(var + self.eps).to(x.dtype)
|
| 55 |
+
x = x * self.weight
|
| 56 |
+
return x
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class StepAttention(torch.nn.Module):
|
| 60 |
+
def __init__(self, hidden_size, num_heads, num_groups, layer_idx: int):
|
| 61 |
+
super().__init__()
|
| 62 |
+
|
| 63 |
+
self.num_heads = num_heads
|
| 64 |
+
self.num_groups = num_groups
|
| 65 |
+
self.hidden_size = hidden_size
|
| 66 |
+
self.head_dim = hidden_size // num_heads
|
| 67 |
+
|
| 68 |
+
self.q_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
|
| 69 |
+
self.k_proj = torch.nn.Linear(
|
| 70 |
+
hidden_size, num_groups * self.head_dim, bias=False
|
| 71 |
+
)
|
| 72 |
+
self.v_proj = torch.nn.Linear(
|
| 73 |
+
hidden_size, num_groups * self.head_dim, bias=False
|
| 74 |
+
)
|
| 75 |
+
self.o_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
|
| 76 |
+
|
| 77 |
+
self.layer_idx = layer_idx
|
| 78 |
+
|
| 79 |
+
def flash_attn_func(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
|
| 80 |
+
return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
|
| 81 |
+
softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
|
| 82 |
+
return torch.ops.Optimus.fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]
|
| 83 |
+
|
| 84 |
+
def forward(
|
| 85 |
+
self,
|
| 86 |
+
x: torch.Tensor,
|
| 87 |
+
past_key_value: Optional[Cache] = None,
|
| 88 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 89 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 90 |
+
):
|
| 91 |
+
|
| 92 |
+
q: torch.Tensor = self.q_proj(x)
|
| 93 |
+
k: torch.Tensor = self.k_proj(x)
|
| 94 |
+
v: torch.Tensor = self.v_proj(x)
|
| 95 |
+
if past_key_value is not None:
|
| 96 |
+
cache_kwargs = {"cache_position": cache_position}
|
| 97 |
+
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
|
| 98 |
+
|
| 99 |
+
q = rearrange(q, "b s (h d) -> b s h d", h=self.num_heads)
|
| 100 |
+
k = rearrange(k, "b s (g d) -> b s g d", g=self.num_groups)
|
| 101 |
+
v = rearrange(v, "b s (g d) -> b s g d", g=self.num_groups)
|
| 102 |
+
|
| 103 |
+
try:
|
| 104 |
+
if self.head_dim not in (64, 128):
|
| 105 |
+
raise ValueError("head_dim must be 64 or 128")
|
| 106 |
+
attn_output = self.flash_attn_func(q, k, v)
|
| 107 |
+
attn_output = attn_output.flatten(-2, -1)
|
| 108 |
+
except:
|
| 109 |
+
k = k.repeat_interleave(self.num_heads // self.num_groups, dim=-2)
|
| 110 |
+
v = v.repeat_interleave(self.num_heads // self.num_groups, dim=-2)
|
| 111 |
+
|
| 112 |
+
attention_mask = build_alibi_cache(
|
| 113 |
+
k.size(1), self.num_heads, dtype=q.dtype, device=q.device
|
| 114 |
+
)[:, :, -q.size(1) :, :].contiguous()
|
| 115 |
+
|
| 116 |
+
q = q.transpose(1, 2)
|
| 117 |
+
k = k.transpose(1, 2)
|
| 118 |
+
v = v.transpose(1, 2)
|
| 119 |
+
|
| 120 |
+
attn_output: torch.Tensor = torch.nn.functional.scaled_dot_product_attention(
|
| 121 |
+
q, k, v, attn_mask=attention_mask
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
attn_output = attn_output.transpose(1, 2).flatten(-2, -1)
|
| 125 |
+
|
| 126 |
+
out = self.o_proj(attn_output)
|
| 127 |
+
return out, None # attn weights are not returned
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class StepMLP(torch.nn.Module):
|
| 131 |
+
def __init__(self, hidden_size, intermediate_size):
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 134 |
+
self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
|
| 135 |
+
self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
|
| 136 |
+
|
| 137 |
+
def forward(self, x):
|
| 138 |
+
gate = self.gate_proj(x)
|
| 139 |
+
up = self.up_proj(x)
|
| 140 |
+
x = torch.nn.functional.silu(gate) * up
|
| 141 |
+
x = self.down_proj(x)
|
| 142 |
+
return x
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class StepLayer(torch.nn.Module):
|
| 146 |
+
def __init__(self, config: Step1Config, layer_idx: int):
|
| 147 |
+
super().__init__()
|
| 148 |
+
self.layer_idx = layer_idx
|
| 149 |
+
self.self_attn = StepAttention(
|
| 150 |
+
hidden_size=config.hidden_size,
|
| 151 |
+
num_heads=config.num_attention_heads,
|
| 152 |
+
num_groups=config.num_attention_groups,
|
| 153 |
+
layer_idx=layer_idx,
|
| 154 |
+
)
|
| 155 |
+
self.mlp = StepMLP(
|
| 156 |
+
hidden_size=config.hidden_size,
|
| 157 |
+
intermediate_size=config.intermediate_size,
|
| 158 |
+
)
|
| 159 |
+
self.input_layernorm = StepRMSNorm(
|
| 160 |
+
hidden_size=config.hidden_size, eps=config.rms_norm_eps
|
| 161 |
+
)
|
| 162 |
+
self.post_attention_layernorm = StepRMSNorm(
|
| 163 |
+
hidden_size=config.hidden_size, eps=config.rms_norm_eps
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
def forward(
|
| 167 |
+
self,
|
| 168 |
+
hidden_states: torch.Tensor,
|
| 169 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 170 |
+
past_key_value: Optional[Cache] = None,
|
| 171 |
+
output_attentions: Optional[bool] = False,
|
| 172 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 173 |
+
):
|
| 174 |
+
residual = hidden_states
|
| 175 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 176 |
+
hidden_states, self_attn_weights = self.self_attn(hidden_states, past_key_value, attention_mask, cache_position)
|
| 177 |
+
hidden_states = residual + hidden_states
|
| 178 |
+
|
| 179 |
+
residual = hidden_states
|
| 180 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 181 |
+
hidden_states = self.mlp(hidden_states)
|
| 182 |
+
hidden_states = residual + hidden_states
|
| 183 |
+
|
| 184 |
+
outputs = (hidden_states, )
|
| 185 |
+
if output_attentions:
|
| 186 |
+
outputs += (self_attn_weights,)
|
| 187 |
+
return outputs
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class StepPreTrainedModel(PreTrainedModel):
|
| 191 |
+
config_class = Step1Config
|
| 192 |
+
base_model_prefix = "model"
|
| 193 |
+
supports_gradient_checkpointing = True
|
| 194 |
+
_no_split_modules = ["StepLayer"]
|
| 195 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 196 |
+
_supports_cache_class = True
|
| 197 |
+
_supports_static_cache = True
|
| 198 |
+
|
| 199 |
+
def _init_weights(self, module):
|
| 200 |
+
std = self.config.initializer_range
|
| 201 |
+
if isinstance(module, nn.Linear):
|
| 202 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 203 |
+
if module.bias is not None:
|
| 204 |
+
module.bias.data.zero_()
|
| 205 |
+
elif isinstance(module, nn.Embedding):
|
| 206 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 207 |
+
if module.padding_idx is not None:
|
| 208 |
+
module.weight.data[module.padding_idx].zero_()
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class Step1Model(StepPreTrainedModel):
|
| 212 |
+
"""
|
| 213 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
config: Step1Config
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
def __init__(self, config: Step1Config):
|
| 220 |
+
super().__init__(config)
|
| 221 |
+
self.config = config
|
| 222 |
+
self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size)
|
| 223 |
+
|
| 224 |
+
self.layers = torch.nn.Sequential(
|
| 225 |
+
*[
|
| 226 |
+
StepLayer(config, layer_idx)
|
| 227 |
+
for layer_idx in range(config.num_hidden_layers)
|
| 228 |
+
]
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
self.norm = StepRMSNorm(
|
| 232 |
+
hidden_size=config.hidden_size, eps=config.rms_norm_eps
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
# Initialize weights and apply final processing
|
| 236 |
+
self.post_init()
|
| 237 |
+
|
| 238 |
+
def get_input_embeddings(self):
|
| 239 |
+
return self.embed_tokens
|
| 240 |
+
|
| 241 |
+
def set_input_embeddings(self, value):
|
| 242 |
+
self.embed_tokens = value
|
| 243 |
+
|
| 244 |
+
def forward(
|
| 245 |
+
self,
|
| 246 |
+
input_ids: torch.LongTensor = None,
|
| 247 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 248 |
+
past_key_values: Optional[Cache] = None,
|
| 249 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 250 |
+
use_cache: Optional[bool] = None,
|
| 251 |
+
output_attentions: Optional[bool] = None,
|
| 252 |
+
output_hidden_states: Optional[bool] = None,
|
| 253 |
+
return_dict: Optional[bool] = None,
|
| 254 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 255 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 256 |
+
output_attentions = (
|
| 257 |
+
output_attentions
|
| 258 |
+
if output_attentions is not None
|
| 259 |
+
else self.config.output_attentions
|
| 260 |
+
)
|
| 261 |
+
output_hidden_states = (
|
| 262 |
+
output_hidden_states
|
| 263 |
+
if output_hidden_states is not None
|
| 264 |
+
else self.config.output_hidden_states
|
| 265 |
+
)
|
| 266 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 267 |
+
return_dict = (
|
| 268 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 272 |
+
raise ValueError(
|
| 273 |
+
"You must specify exactly one of input_ids or inputs_embeds"
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
if inputs_embeds is None:
|
| 277 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 278 |
+
|
| 279 |
+
if use_cache and past_key_values is None:
|
| 280 |
+
past_key_values = DynamicCache()
|
| 281 |
+
|
| 282 |
+
if cache_position is None:
|
| 283 |
+
past_seen_tokens = (
|
| 284 |
+
past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 285 |
+
)
|
| 286 |
+
cache_position = torch.arange(
|
| 287 |
+
past_seen_tokens,
|
| 288 |
+
past_seen_tokens + inputs_embeds.shape[1],
|
| 289 |
+
device=inputs_embeds.device,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
causal_mask = attention_mask
|
| 293 |
+
|
| 294 |
+
hidden_states = inputs_embeds
|
| 295 |
+
|
| 296 |
+
# decoder layers
|
| 297 |
+
all_hidden_states = () if output_hidden_states else None
|
| 298 |
+
all_self_attns = () if output_attentions else None
|
| 299 |
+
|
| 300 |
+
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 301 |
+
if output_hidden_states:
|
| 302 |
+
all_hidden_states += (hidden_states,)
|
| 303 |
+
|
| 304 |
+
layer_outputs = decoder_layer(
|
| 305 |
+
hidden_states,
|
| 306 |
+
attention_mask=causal_mask,
|
| 307 |
+
past_key_value=past_key_values,
|
| 308 |
+
cache_position=cache_position,
|
| 309 |
+
output_attentions=output_attentions,
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
hidden_states = layer_outputs[0]
|
| 313 |
+
|
| 314 |
+
if output_attentions:
|
| 315 |
+
all_self_attns += (layer_outputs[1],)
|
| 316 |
+
|
| 317 |
+
hidden_states = self.norm(hidden_states)
|
| 318 |
+
|
| 319 |
+
# add hidden states from the last decoder layer
|
| 320 |
+
if output_hidden_states:
|
| 321 |
+
all_hidden_states += (hidden_states,)
|
| 322 |
+
|
| 323 |
+
output = BaseModelOutputWithPast(
|
| 324 |
+
last_hidden_state=hidden_states,
|
| 325 |
+
past_key_values=past_key_values if use_cache else None,
|
| 326 |
+
hidden_states=all_hidden_states,
|
| 327 |
+
attentions=None,
|
| 328 |
+
)
|
| 329 |
+
return output if return_dict else output.to_tuple()
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class Step1ForCausalLM(StepPreTrainedModel, GenerationMixin):
|
| 333 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 334 |
+
|
| 335 |
+
def __init__(self, config):
|
| 336 |
+
super().__init__(config)
|
| 337 |
+
self.model = Step1Model(config)
|
| 338 |
+
self.vocab_size = config.vocab_size
|
| 339 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 340 |
+
|
| 341 |
+
# Initialize weights and apply final processing
|
| 342 |
+
self.post_init()
|
| 343 |
+
|
| 344 |
+
def get_input_embeddings(self):
|
| 345 |
+
return self.model.embed_tokens
|
| 346 |
+
|
| 347 |
+
def set_input_embeddings(self, value):
|
| 348 |
+
self.model.embed_tokens = value
|
| 349 |
+
|
| 350 |
+
def set_decoder(self, decoder):
|
| 351 |
+
self.model = decoder
|
| 352 |
+
|
| 353 |
+
def get_decoder(self):
|
| 354 |
+
return self.model
|
| 355 |
+
|
| 356 |
+
def forward(
|
| 357 |
+
self,
|
| 358 |
+
input_ids: torch.LongTensor = None,
|
| 359 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 360 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 361 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
| 362 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 363 |
+
labels: Optional[torch.LongTensor] = None,
|
| 364 |
+
use_cache: Optional[bool] = None,
|
| 365 |
+
output_attentions: Optional[bool] = None,
|
| 366 |
+
output_hidden_states: Optional[bool] = None,
|
| 367 |
+
return_dict: Optional[bool] = None,
|
| 368 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 369 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
| 370 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 371 |
+
output_hidden_states = (
|
| 372 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 373 |
+
)
|
| 374 |
+
return_dict = (
|
| 375 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 379 |
+
outputs = self.model(
|
| 380 |
+
input_ids=input_ids,
|
| 381 |
+
attention_mask=attention_mask,
|
| 382 |
+
past_key_values=past_key_values,
|
| 383 |
+
inputs_embeds=inputs_embeds,
|
| 384 |
+
use_cache=use_cache,
|
| 385 |
+
output_attentions=output_attentions,
|
| 386 |
+
output_hidden_states=output_hidden_states,
|
| 387 |
+
return_dict=return_dict,
|
| 388 |
+
cache_position=cache_position,
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
hidden_states = outputs[0]
|
| 392 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 393 |
+
|
| 394 |
+
logits = self.lm_head(hidden_states)
|
| 395 |
+
|
| 396 |
+
loss = None
|
| 397 |
+
if labels is not None:
|
| 398 |
+
loss = self.loss_function(
|
| 399 |
+
logits=logits,
|
| 400 |
+
labels=labels,
|
| 401 |
+
vocab_size=self.config.vocab_size,
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
if not return_dict:
|
| 405 |
+
output = (logits,) + outputs[1:]
|
| 406 |
+
return (loss,) + output if loss is not None else output
|
| 407 |
+
|
| 408 |
+
return CausalLMOutputWithPast(
|
| 409 |
+
loss=loss,
|
| 410 |
+
logits=logits,
|
| 411 |
+
past_key_values=outputs.past_key_values,
|
| 412 |
+
hidden_states=outputs.hidden_states,
|
| 413 |
+
attentions=outputs.attentions,
|
| 414 |
+
)
|
tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:25e122d9205d035033a9994c4d46a6a1b467a938654e4178fc0e5f4f5d610674
|
| 3 |
+
size 1264044
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "<s>",
|
| 3 |
+
"clean_up_tokenization_spaces": false,
|
| 4 |
+
"eos_token": "</s>",
|
| 5 |
+
"legacy": false,
|
| 6 |
+
"model_max_length": 65536,
|
| 7 |
+
"pad_token": "<unk>",
|
| 8 |
+
"padding_side": "left",
|
| 9 |
+
"sp_model_kwargs": {},
|
| 10 |
+
"tokenizer_class": "LlamaTokenizer",
|
| 11 |
+
"unk_token": "<unk>",
|
| 12 |
+
"use_default_system_prompt": false,
|
| 13 |
+
"chat_template": "{% if messages[0]['role'] == 'system' %}{{ '<s>' }}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}{% set role = 'human' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<|BOT|> ' + role + '\\n' }}{{ message['content'] }}{% if not loop.last or message['role'] != 'assistant' %}{{ '<|EOT|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|BOT|> assistant\\n' }}{% endif %}"
|
| 14 |
+
}
|
| 15 |
+
|