Spaces: Running on Zero
- .gitignore +54 -0
- src/YingMusicSinger/config/YingMusic_Singer.yaml +110 -0
- src/YingMusicSinger/config/stable_audio_2_0_vae_20hz_official.json +196 -0
- src/YingMusicSinger/utils/checkpoint.py +64 -0
- src/YingMusicSinger/utils/cnen_tokenizer.py +34 -0
- src/YingMusicSinger/utils/common.py +325 -0
- src/YingMusicSinger/utils/lrc_align.py +61 -0
- src/YingMusicSinger/utils/mel_spectrogram.py +86 -0
.gitignore
ADDED
@@ -0,0 +1,54 @@
# Python bytecode files
__pycache__/
*.py[cod]

# Virtual environment
venv/
ENV/
env/
.venv/
.ENV/
# Python IDEs
.idea/
.vscode/
*.sublime-project
*.sublime-workspace

# Jupyter Notebook checkpoints
.ipynb_checkpoints/

# Data files (data used for training or testing)
*.log

# TensorBoard logs
runs/
tensorboard_logs/

# Operating system files
.DS_Store
Thumbs.db

# PyCharm files
*.iml
.idea/

# Coverage and testing tools
.coverage
nosetests.xml
coverage.xml
*.cover
*.log

# Compiled extension modules
*.so
*.dylib
*.pyd

# Cython debug symbols
cython_debug/

# Other custom ignore rules
*.bak
*.swp

.ruff_cache/
src/YingMusicSinger/config/YingMusic_Singer.yaml
ADDED
@@ -0,0 +1,110 @@
hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}

runname: YingMusic_Singer

datasets:
  name: svs_infer
  batch_size_per_gpu: 6
  batch_size_type: sample
  max_samples: null
  num_workers: 4

datasets_cfg:
  filelist_path: /path/to/your/filelist
  vae_frame_rate: 21.533203125
  text_num_embeds: 373
  lrc_align_mode: sentence_level

optim:
  epochs: null
  num_updates: 31518
  learning_rate: 7e-6
  num_warmup_updates: 60
  grad_accumulation_steps: 1
  max_grad_norm: 1.0
  bnb_optimizer: False
  max_iter: null

model:
  name: YingMusic_Singer
  tokenizer: null
  tokenizer_path: null
  is_tts_pretrain: 0
  melody_input_source: some_pretrain_fuzzdisturb
  cka_disabled: 0
  backbone: DiT
  f0_fn_type: null
  f0_fn_path: null

  arch:
    dim: 1024
    depth: 22
    heads: 16
    ff_mult: 2
    text_dim: 512
    text_mask_padding: False
    qk_norm: null
    conv_layers: 4
    pe_attn_head: null
    attn_backend: torch
    attn_mask_enabled: False
    checkpoint_activations: False
    guidance_scale_embed_dim: null

  mel_spec:
    n_mel_channels: 64
    mel_spec_type: vae

  vocoder:
    is_local: True
    local_path: null

  midi_extractor:
    path: ckpts/model_ckpt_steps_100000_simplified.ckpt

  extra_parameters:
    some_pretrain_fuzzdisturb:
      dim: 128
      drop_type: equal_space
      drop_prob: [1, 9]
      noise_scale: 0.0
      blur_kernel: 0

grpo:
  noise_level: 0.8
  num_samples: 8
  upper_clip_epsilon: 0.02
  lower_clip_epsilon: 0.002
  beta: 1
  ppo_epochs: 1
  num_steps: 32
  sde_window_range: [1, 16]
  sde_window_size: 2
  delet_temp: 10
  use_cfg_sample: false
  wer_SDI_weights: [1, 1, 1]
  reward_config: {"qwen_asr_wer": 0.25, "f0_correlation": 0.25, "qwenfeat": 0.25, "sim_wavlm_large": 0.25}
  grpo_wanted_loss: ["qwen_asr_wer_reward", "f0_correlation_reward", "qwenfeat_reward", "sim_wavlm_large_reward"]
  use_guidance_scale_embed: false
  t_shift: 0.5
  cfg_strength: null
  GDPO_batch_norm: false
  use_egrpo: false
  egrpo_tau: null
  egrpo_d: null
  use_max_group_std_dev: false

ema_kwargs:
  beta: 0.995
  update_after_step: 100
  update_every: 1

ckpts:
  logger: tensorboard
  log_samples: False
  save_per_updates: 100
  keep_last_n_checkpoints: -1
  last_per_updates: 100
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}_CKA
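The ${...} fields in this config are OmegaConf/Hydra interpolations that resolve against the same config tree. A minimal loading sketch outside a Hydra run (assuming omegaconf is installed; the ${now:...} resolver in hydra.run.dir only exists inside a Hydra app, and the nesting shown above is reconstructed, so exact key paths may differ):

from omegaconf import OmegaConf

cfg = OmegaConf.load("src/YingMusicSinger/config/YingMusic_Singer.yaml")

print(cfg.model.arch.dim)               # 1024
print(cfg.datasets_cfg.vae_frame_rate)  # 21.533203125

# dot-list overrides, mirroring Hydra's command-line syntax
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["optim.learning_rate=1e-5"]))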
src/YingMusicSinger/config/stable_audio_2_0_vae_20hz_official.json
ADDED
@@ -0,0 +1,196 @@
{
    "model_type": "autoencoder",
    "sample_size": 24576,
    "sample_rate": 44100,
    "audio_channels": 2,
    "model": {
        "encoder": {
            "type": "oobleck",
            "requires_grad": false,
            "config": {
                "in_channels": 2,
                "channels": 128,
                "c_mults": [1, 2, 4, 8, 16],
                "strides": [2, 4, 4, 8, 8],
                "latent_dim": 128,
                "use_snake": true
            }
        },
        "decoder": {
            "type": "oobleck",
            "config": {
                "out_channels": 2,
                "channels": 128,
                "c_mults": [1, 2, 4, 8, 16],
                "strides": [2, 4, 4, 8, 8],
                "latent_dim": 64,
                "use_snake": true,
                "final_tanh": false
            }
        },
        "bottleneck": {
            "type": "vae"
        },
        "latent_dim": 64,
        "downsampling_ratio": 2048,
        "io_channels": 2
    },
    "training": {
        "learning_rate": 8e-5,
        "warmup_steps": 0,
        "use_ema": true,
        "optimizer_configs": {
            "autoencoder": {
                "optimizer": {
                    "type": "AdamW",
                    "config": {
                        "betas": [0.8, 0.99],
                        "lr": 1e-4,
                        "weight_decay": 8e-4
                    }
                },
                "scheduler": {
                    "type": "InverseLR",
                    "config": {
                        "inv_gamma": 200000,
                        "power": 0.5,
                        "warmup": 0.999
                    }
                }
            },
            "discriminator": {
                "optimizer": {
                    "type": "AdamW",
                    "config": {
                        "betas": [0.8, 0.99],
                        "lr": 3e-4,
                        "weight_decay": 1e-3
                    }
                },
                "scheduler": {
                    "type": "InverseLR",
                    "config": {
                        "inv_gamma": 200000,
                        "power": 0.5,
                        "warmup": 0.999
                    }
                }
            }
        },
        "loss_configs": {
            "discriminator": {
                "type": "encodec",
                "config": {
                    "filters": 64,
                    "n_ffts": [2048, 1024, 512, 256, 128],
                    "hop_lengths": [512, 256, 128, 64, 32],
                    "win_lengths": [2048, 1024, 512, 256, 128]
                },
                "weights": {
                    "adversarial": 0.1,
                    "feature_matching": 5.0
                }
            },
            "spectral": {
                "type": "mrstft",
                "config": {
                    "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
                    "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
                    "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
                    "perceptual_weighting": true
                },
                "weights": {
                    "mrstft": 1.0
                }
            },
            "time": {
                "type": "l1",
                "weights": {
                    "l1": 0.0
                }
            },
            "bottleneck": {
                "type": "kl",
                "weights": {
                    "kl": 1e-4
                }
            }
        },
        "demo": {
            "demo_every": 10000,
            "demo_dir": "/home/node44_tmpdata3/netease/hkchen/stable-audio-tools-1/stable-audio-tools/outputs/vae_large_fresh_data_demo"
        }
    }
}
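The encoder strides multiply out to the downsampling_ratio of 2048, which is what gives this VAE its 20 Hz-class latent rate and ties it to vae_frame_rate in YingMusic_Singer.yaml above. A quick check in plain Python, reading the JSON as-is:

import json

with open("src/YingMusicSinger/config/stable_audio_2_0_vae_20hz_official.json") as f:
    vae_cfg = json.load(f)

strides = vae_cfg["model"]["encoder"]["config"]["strides"]  # [2, 4, 4, 8, 8]
ratio = 1
for s in strides:
    ratio *= s
assert ratio == vae_cfg["model"]["downsampling_ratio"] == 2048

# 44100 / 2048 = 21.533203125 latent frames per second,
# matching datasets_cfg.vae_frame_rate in the YAML config
print(vae_cfg["sample_rate"] / ratio)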
src/YingMusicSinger/utils/checkpoint.py
ADDED
@@ -0,0 +1,64 @@
import torch
from vocos import Vocos  # needed by load_vocoder below

from singer.model import Singer


def load_model(model_cls, model_cfg, ckpt_path, vocab_char_map, device="cuda"):
    model_arc = model_cfg.model.arch
    mel_spec_kwargs = model_cfg.model.mel_spec
    vocab_size = len(vocab_char_map)

    backbone = model_cls(
        **model_arc, text_num_embeds=vocab_size, mel_dim=mel_spec_kwargs.n_mel_channels
    )

    model = Singer(
        transformer=backbone,
        mel_spec_kwargs=mel_spec_kwargs,
        vocab_char_map=vocab_char_map,
    )

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    if "ema_model_state_dict" in checkpoint:
        state_dict = checkpoint["ema_model_state_dict"]
    elif "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # Handle the "module." prefix left by DistributedDataParallel wrapping
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith("module."):
            new_state_dict[k[7:]] = v
        else:
            new_state_dict[k] = v

    model.load_state_dict(new_state_dict)
    model.to(device)
    model.eval()
    return model


def load_vocoder(vocoder_name, is_local, local_path, device="cuda"):
    if vocoder_name == "vocos":
        if is_local:
            vocoder = Vocos.from_hparams(local_path).to(device)
        else:
            vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
    elif vocoder_name == "bigvgan":
        # Placeholder for bigvgan; import and load BigVGAN here once supported
        raise NotImplementedError("BigVGAN loading not implemented yet")
    else:
        # Fallback: warn, then try a Vocos-style local load before giving up
        print(
            f"Warning: Unknown vocoder {vocoder_name}, trying to load from local path if provided"
        )
        if is_local:
            vocoder = Vocos.from_hparams(local_path).to(device)
        else:
            raise ValueError(f"Unknown vocoder: {vocoder_name}")
    return vocoder
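A hypothetical call sequence for these loaders, as a sketch only: the backbone class DiT (named in the config but whose import path is not shown in this diff) and the checkpoint path are placeholders:

from omegaconf import OmegaConf

from src.YingMusicSinger.utils.common import get_tokenizer

model_cfg = OmegaConf.load("src/YingMusicSinger/config/YingMusic_Singer.yaml")
vocab_char_map, vocab_size = get_tokenizer("svs_infer", tokenizer="pinyin")

model = load_model(DiT, model_cfg, "ckpts/model_last.pt", vocab_char_map)  # DiT: placeholder backbone class
vocoder = load_vocoder("vocos", is_local=False, local_path=None)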
src/YingMusicSinger/utils/cnen_tokenizer.py
ADDED
@@ -0,0 +1,34 @@
import json


class CNENTokenizer:
    def __init__(self):
        with open(
            "./src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json",
            "r",
            encoding="utf-8",
        ) as file:
            self.phone2id: dict = json.load(file)["vocab"]
        self.phone2id = {k: int(v) + 1 for (k, v) in self.phone2id.items()}

        self.pad_token_id = 0
        self.phone2id["<PAD>"] = 0

        self.punct_token_id = len(self.phone2id)  # punctuation marks token
        self.phone2id["<PUNCT>"] = len(self.phone2id)

        self.sep_token_id = len(self.phone2id)  # sentence separation token
        self.phone2id["<SEP>"] = len(self.phone2id)

        self.id2phone = {v: k for (k, v) in self.phone2id.items()}
        from src.YingMusicSinger.utils.f5_tts.g2p.g2p_generation import chn_eng_g2p

        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        phone, token = self.tokenizer(text)
        token = [x + 1 for x in token]
        return token

    def decode(self, token):
        return "|".join([self.id2phone[x] for x in token])
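A round-trip sketch (run from the repo root so the relative vocab path resolves; the example string is illustrative):

tokenizer = CNENTokenizer()

ids = tokenizer.encode("你好 hello")  # g2p phones -> ids, shifted +1 so 0 stays <PAD>
print(tokenizer.decode(ids))          # phones joined with "|"
print(tokenizer.pad_token_id, tokenizer.punct_token_id, tokenizer.sep_token_id)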
src/YingMusicSinger/utils/common.py
ADDED
@@ -0,0 +1,325 @@
from __future__ import annotations

import os
import random
from collections import defaultdict

import jieba
import torch
import torch.nn.functional as F
from pypinyin import Style, lazy_pinyin
from torch.nn.utils.rnn import pad_sequence

# seed everything


def seed_everything(seed=0):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# helpers


def exists(v):
    return v is not None


def default(v, d):
    return v if exists(v) else d


def is_package_available(package_name: str) -> bool:
    try:
        import importlib

        package_exists = importlib.util.find_spec(package_name) is not None
        return package_exists
    except Exception:
        return False


# tensor helpers


def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:  # noqa: F722 F821
    if not exists(length):
        length = t.amax()

    seq = torch.arange(length, device=t.device)
    return seq[None, :] < t[:, None]


def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):  # noqa: F722 F821
    max_seq_len = seq_len.max().item()
    seq = torch.arange(max_seq_len, device=start.device).long()
    start_mask = seq[None, :] >= start[:, None]
    end_mask = seq[None, :] < end[:, None]
    return start_mask & end_mask


def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):  # noqa: F722 F821
    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths

    rand = torch.rand_like(frac_lengths)
    start = (max_start * rand).long().clamp(min=0)
    end = start + lengths

    return mask_from_start_end_indices(seq_len, start, end)


def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:  # noqa: F722
    if not exists(mask):
        return t.mean(dim=1)

    t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
    num = t.sum(dim=1)
    den = mask.float().sum(dim=1)

    return num / den.clamp(min=1.0)
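
# Illustrative sketch (added for clarity, not part of the original file): how the
# mask helpers above compose. lens_to_mask gives the padding mask, and
# mask_from_frac_lengths picks one random contiguous span per row covering
# frac_lengths of the valid frames.
def _example_mask_usage():
    lengths = torch.tensor([8, 6])
    pad_mask = lens_to_mask(lengths)  # [2, 8], True where position < length
    span_mask = mask_from_frac_lengths(lengths, torch.tensor([0.5, 0.5]))
    feats = torch.randn(2, 8, 4)
    pooled = maybe_masked_mean(feats, pad_mask)  # [2, 4], mean over valid frames only
    return pad_mask, span_mask, pooled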


# simple utf-8 tokenizer, since paper went character based
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:  # noqa: F722
    list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text]  # ByT5 style
    text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
    return text


# char tokenizer, based on custom dataset's extracted .txt file
def list_str_to_idx(
    text: list[str] | list[list[str]],
    vocab_char_map: dict[str, int],  # {char: idx}
    padding_value=-1,
) -> int["b nt"]:  # noqa: F722
    list_idx_tensors = [
        torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text
    ]  # pinyin or char style
    text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
    return text


# Get tokenizer


def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
    """
    tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
                - "char" for char-wise tokenizer, need .txt vocab_file
                - "byte" for utf-8 tokenizer
                - "custom" if you're directly passing in a path to the vocab.txt you want to use
    vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
                - if use "char", derived from unfiltered character & symbol counts of custom dataset
                - if use "byte", set to 256 (unicode byte range)
    """
    if tokenizer in ["pinyin", "char"]:
        # tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
        # tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}/vocab.txt")
        tokenizer_path = (
            "/ailab-train/speech/zhengjunjie/opt/models/F5-TTS/F5TTS_v1_Base/vocab.txt"
        )
        with open(tokenizer_path, "r", encoding="utf-8") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
                vocab_char_map[char[:-1]] = i
        vocab_size = len(vocab_char_map)
        assert vocab_char_map[" "] == 0, (
            "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
        )

    elif tokenizer == "byte":
        vocab_char_map = None
        vocab_size = 256

    elif tokenizer == "custom":
        with open(dataset_name, "r", encoding="utf-8") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
                vocab_char_map[char[:-1]] = i
        vocab_size = len(vocab_char_map)

    return vocab_char_map, vocab_size
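
# Illustrative sketch (not part of the original file): mapping raw text to padded
# id tensors with the vocab loaded above. The dataset name is a placeholder, and
# the pinyin branch reads the hard-coded vocab.txt path in get_tokenizer.
def _example_tokenize():
    vocab_char_map, vocab_size = get_tokenizer("svs_infer", tokenizer="pinyin")
    texts = convert_char_to_pinyin(["你好 world"])  # defined just below
    return list_str_to_idx(texts, vocab_char_map, padding_value=-1)  # [1, nt] int tensor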


# convert char to pinyin


def convert_char_to_pinyin(text_list, polyphone=True, with_tone=True):
    if with_tone:
        style = Style.TONE3  # with tone number
    else:
        style = Style.NORMAL  # no tone

    if jieba.dt.initialized is False:
        jieba.default_logger.setLevel(50)  # CRITICAL
        jieba.initialize()

    final_text_list = []
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
    )  # add custom trans here, to address oov

    def is_chinese(c):
        return (
            "\u3100" <= c <= "\u9fff"  # common chinese characters
        )

    for text in text_list:
        char_list = []
        text = text.translate(custom_trans)
        for seg in jieba.cut(text):
            seg_byte_len = len(bytes(seg, "UTF-8"))
            if seg_byte_len == len(seg):  # if pure alphabets and symbols
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(
                seg
            ):  # if pure east asian characters
                seg_ = lazy_pinyin(seg, style=style, tone_sandhi=True)
                for i, c in enumerate(seg):
                    if is_chinese(c):
                        char_list.append(" ")
                    char_list.append(seg_[i])
            else:  # if mixed characters, alphabets and symbols
                for c in seg:
                    if ord(c) < 256:
                        char_list.extend(c)
                    elif is_chinese(c):
                        char_list.append(" ")
                        char_list.extend(lazy_pinyin(c, style=style, tone_sandhi=True))
                    else:
                        char_list.append(c)

        if with_tone is False:
            for idx, item in enumerate(char_list):
                char_list[idx] = "__" + item

        final_text_list.append(char_list)

    return final_text_list
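
# Illustrative note (not part of the original file): with polyphone handling and
# TONE3, Chinese typically maps to space-prefixed tone-numbered pinyin while ASCII
# passes through character by character, e.g.
#   convert_char_to_pinyin(["你好 world"])
#   -> [[' ', 'ni3', ' ', 'hao3', ' ', 'w', 'o', 'r', 'l', 'd']]
# (exact output depends on jieba's segmentation of the input)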


# filter func for dirty data with many repetitions


def repetition_found(text, length=2, tolerance=10):
    pattern_count = defaultdict(int)
    for i in range(len(text) - length + 1):
        pattern = text[i : i + length]
        pattern_count[pattern] += 1
    for pattern, count in pattern_count.items():
        if count > tolerance:
            return True
    return False


# get the empirically pruned step for sampling


def get_epss_timesteps(n, device, dtype):
    dt = 1 / 32
    predefined_timesteps = {
        5: [0, 2, 4, 8, 16, 32],
        6: [0, 2, 4, 6, 8, 16, 32],
        7: [0, 2, 4, 6, 8, 16, 24, 32],
        10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
        12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
        16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
    }
    t = predefined_timesteps.get(n, [])
    if not t:
        return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
    return dt * torch.tensor(t, device=device, dtype=dtype)
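
# Illustrative note (not part of the original file): for the predefined step counts
# the schedule is the pruned table scaled by dt = 1/32, front-loading fine steps near
# t = 0; any other n falls back to a uniform grid, e.g.
#   get_epss_timesteps(7, "cpu", torch.float32)
#   -> tensor([0.0000, 0.0625, 0.1250, 0.1875, 0.2500, 0.5000, 0.7500, 1.0000])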


def calculate_similarity_matrix_with_mask(
    vectors: torch.Tensor, valid_mask: torch.Tensor = None
) -> torch.Tensor:
    if valid_mask is None:
        valid_mask = torch.ones(
            vectors.shape[:-1], dtype=torch.bool, device=vectors.device
        )

    if valid_mask.dtype != torch.bool:
        valid_mask = valid_mask.bool()

    vectors = vectors * valid_mask.unsqueeze(-1).float()

    vectors_normalized = F.normalize(vectors, p=2, dim=-1, eps=1e-8)

    # (B, N, D) * (B, D, N) -> (B, N, N)
    similarity_matrix = torch.bmm(
        vectors_normalized, vectors_normalized.transpose(1, 2)
    )

    # (B, N, 1) & (B, 1, N) -> (B, N, N)
    combined_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)

    similarity_matrix.masked_fill_(~combined_mask, 0.0)

    return similarity_matrix


def _center_gram_batch(gram: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """Center Gram matrices in batch.

    Args:
        gram: [B, N, N] Gram matrices.
        mask: [B, N] optional validity mask.

    Returns:
        Centered Gram matrices [B, N, N].
    """
    if mask is None:
        gram = gram - gram.mean(dim=2, keepdim=True)
        gram = gram - gram.mean(dim=1, keepdim=True)
        return gram
    else:
        mask_float = mask.float()
        n_valid = mask_float.sum(dim=1, keepdim=True).clamp(min=1.0)

        mask_mat = mask_float.unsqueeze(2) * mask_float.unsqueeze(1)  # [B, N, N]
        gram = gram * mask_mat

        row_mean = gram.sum(dim=2, keepdim=True) / n_valid.unsqueeze(2)
        col_mean = gram.sum(dim=1, keepdim=True) / n_valid.unsqueeze(1)
        grand_mean = row_mean.sum(dim=1, keepdim=True) / n_valid.unsqueeze(2)

        centered = gram - row_mean - col_mean + grand_mean
        return centered * mask_mat


def cka_loss(
    sim_x: torch.Tensor, sim_y: torch.Tensor, valid_mask: torch.Tensor = None
) -> torch.Tensor:
    """Compute CKA loss between two similarity matrices in batch.

    Args:
        sim_x: [B, N, N] similarity matrix.
        sim_y: [B, N, N] similarity matrix.
        valid_mask: [B, N] optional validity mask.

    Returns:
        Scalar CKA loss (1 - mean CKA similarity).
    """
    eps = 1e-6

    sim_x_c = _center_gram_batch(sim_x, valid_mask)  # [B, N, N]
    sim_y_c = _center_gram_batch(sim_y, valid_mask)  # [B, N, N]

    # HSIC via element-wise product summed over spatial dims
    hsic = torch.sum(sim_x_c * sim_y_c, dim=(1, 2))  # [B]

    norm_x = torch.sqrt(torch.sum(sim_x_c**2, dim=(1, 2)) + eps)  # [B]
    norm_y = torch.sqrt(torch.sum(sim_y_c**2, dim=(1, 2)) + eps)  # [B]

    cka_similarity = hsic / (norm_x * norm_y + eps)  # [B]

    return torch.mean(1.0 - cka_similarity)
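Taken together, the last three functions implement a masked linear-CKA objective: per-item cosine-similarity Gram matrices are double-centered under the validity mask and compared via normalized HSIC. A self-contained sketch with toy shapes:

import torch

from src.YingMusicSinger.utils.common import (
    calculate_similarity_matrix_with_mask,
    cka_loss,
    lens_to_mask,
)

B, N, D = 2, 10, 16
x, y = torch.randn(B, N, D), torch.randn(B, N, D)
mask = lens_to_mask(torch.tensor([10, 7]))  # [B, N]

sim_x = calculate_similarity_matrix_with_mask(x, mask)  # [B, N, N]
sim_y = calculate_similarity_matrix_with_mask(y, mask)
loss = cka_loss(sim_x, sim_y, valid_mask=mask)  # 1 - mean CKA; near 1 for unrelated features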
src/YingMusicSinger/utils/lrc_align.py
ADDED
@@ -0,0 +1,61 @@
import numpy as np


def align_lrc_put_to_front(tokenizer, lrc_start_times, lrc_lines, total_lens):
    lrc_text_list = []
    lrc_token = np.zeros(total_lens, dtype=np.int64)

    token_start = 0
    for temp in lrc_lines:
        # for punct in ",。!?、;:,.!?;:":
        #     one_line_lrc = one_line_lrc.replace(punct, ",")
        # one_line_lrc = one_line_lrc.strip(",。!?、;:,.!?;: ")
        for one_line_lrc in temp.split("|"):
            lrc_text_list.append(one_line_lrc)
            one_line_token = tokenizer.encode(one_line_lrc)
            lrc_text_list.append("<SEP>")
            one_line_token = one_line_token + [tokenizer.phone2id["<SEP>"]]

            one_line_token = np.array(one_line_token)
            assert token_start + len(one_line_token) <= len(lrc_token), (
                "The length of lrc_token exceeds that of the vocal latent"
            )
            lrc_token[token_start : token_start + len(one_line_token)] = one_line_token
            token_start = token_start + len(one_line_token)
    return lrc_token, "".join(lrc_text_list)


def align_lrc_sentence_level(
    tokenizer, lrc_start_times, lrc_lines, total_lens, vae_frame_rate
):
    # BUG: only the prompt and the two segments to be generated have start timestamps;
    # the generated content and the prompt do not contain anything like <SEP>.
    lrc_text_list = []
    lrc_token = np.zeros(total_lens, dtype=np.int64)

    token_start = 0
    for lrc_start_time, one_line_lrc in zip(lrc_start_times, lrc_lines):
        one_line_lrc = one_line_lrc.replace("|", " ")
        for punct in ",。!?、;:,.!?;:":
            one_line_lrc = one_line_lrc.replace(punct, ",")
        one_line_lrc = one_line_lrc.strip(",。!?、;:,.!?;: ")

        lrc_text_list.append(one_line_lrc)
        one_line_token = tokenizer.encode(one_line_lrc)
        lrc_text_list.append("<SEP>")
        one_line_token = one_line_token + [tokenizer.phone2id["<SEP>"]]

        one_line_token = np.array(one_line_token)

        timestamp_cal_start_frame = int(lrc_start_time * vae_frame_rate)

        # Handle the case where the previous line's tokens run past this line's start time
        timestamp_cal_start_frame = max(timestamp_cal_start_frame, token_start)

        assert timestamp_cal_start_frame + len(one_line_token) <= len(lrc_token), (
            "The length of the lrc_token exceeds that of the vocal latent"
        )
        lrc_token[
            timestamp_cal_start_frame : timestamp_cal_start_frame + len(one_line_token)
        ] = one_line_token
        token_start = timestamp_cal_start_frame + len(one_line_token)
    return lrc_token, "".join(lrc_text_list)
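A sketch of the sentence-level mode (timings and lyric lines are illustrative; vae_frame_rate matches the YAML's 21.533203125): each line's phone tokens are written starting at int(start_time * vae_frame_rate), pushed later if the previous line's tokens would overlap.

from src.YingMusicSinger.utils.cnen_tokenizer import CNENTokenizer

tokenizer = CNENTokenizer()
lrc_token, lrc_text = align_lrc_sentence_level(
    tokenizer,
    lrc_start_times=[0.0, 4.2],            # seconds, illustrative
    lrc_lines=["你好世界", "hello world"],
    total_lens=500,                        # vocal latent frames
    vae_frame_rate=21.533203125,
)
# lrc_token: int64 array of length 500, zeros (<PAD>) outside the aligned phone spans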
src/YingMusicSinger/utils/mel_spectrogram.py
ADDED
@@ -0,0 +1,86 @@
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio


class MelodySpectrogram(torch.nn.Module):
    def __init__(
        self,
        n_mel_channels=80,
        sampling_rate=44100,
        win_length=2048,
        hop_length=512,
        n_fft=None,
        mel_fmin=0,
        mel_fmax=None,
        clamp=1e-5,
    ):
        from librosa.filters import mel

        super().__init__()
        n_fft = win_length if n_fft is None else n_fft
        self.hann_window = {}
        mel_basis = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp

    def _mel_forward(self, audio, keyshift=0, speed=1, center=True):
        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(self.n_fft * factor))
        win_length_new = int(np.round(self.win_length * factor))
        hop_length_new = int(np.round(self.hop_length * speed))

        keyshift_key = str(keyshift) + "_" + str(audio.device)
        if keyshift_key not in self.hann_window:
            self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
                audio.device
            )

        fft = torch.stft(
            audio,
            n_fft=n_fft_new,
            hop_length=hop_length_new,
            win_length=win_length_new,
            window=self.hann_window[keyshift_key],
            center=center,
            return_complex=True,
        )
        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))

        if keyshift != 0:
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size:
                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new

        mel_output = torch.matmul(self.mel_basis, magnitude)
        log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
        return log_mel_spec

    @torch.no_grad()
    def forward(self, audio, sr, sil_len_to_end=None, keyshift=0, speed=1):
        # audio, sr = torchaudio.load(audio_path)
        if sil_len_to_end is not None:
            silence = torch.zeros(audio.shape[0], int(sr * sil_len_to_end))
            audio = torch.cat([audio, silence], dim=1)
        if sr != self.sampling_rate:
            audio = torchaudio.transforms.Resample(sr, self.sampling_rate)(audio)
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)
        audio = audio.to(self.mel_basis.device)
        return self._mel_forward(audio, keyshift=keyshift, speed=speed)
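A usage sketch (the audio path is a placeholder): forward pads optional trailing silence, resamples to the configured 44.1 kHz, downmixes to mono, and returns a log-mel of shape [1, n_mel_channels, frames].

import torchaudio

mel_fn = MelodySpectrogram(n_mel_channels=80, sampling_rate=44100)
audio, sr = torchaudio.load("vocal.wav")  # placeholder path; [channels, samples]
log_mel = mel_fn(audio, sr)               # [1, 80, frames]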