xjsc0 committed
Commit 61e6f25 · 1 Parent(s): 64ec292
.gitignore ADDED
@@ -0,0 +1,54 @@
+ # Python bytecode files
+ __pycache__/
+ *.py[cod]
+ 
+ # Virtual environments
+ venv/
+ ENV/
+ env/
+ .venv/
+ .ENV/
+ 
+ # IDEs and editors
+ .idea/
+ .vscode/
+ *.iml
+ *.sublime-project
+ *.sublime-workspace
+ 
+ # Jupyter Notebook checkpoints
+ .ipynb_checkpoints/
+ 
+ # Log files
+ *.log
+ 
+ # TensorBoard logs
+ runs/
+ tensorboard_logs/
+ 
+ # Operating system files
+ .DS_Store
+ Thumbs.db
+ 
+ # Coverage and testing tools
+ .coverage
+ nosetests.xml
+ coverage.xml
+ *.cover
+ 
+ # Compiled extension modules
+ *.so
+ *.dylib
+ *.pyd
+ 
+ # Cython debug symbols
+ cython_debug/
+ 
+ # Other custom ignore rules
+ *.bak
+ *.swp
+ 
+ .ruff_cache/
src/YingMusicSinger/config/YingMusic_Singer.yaml ADDED
@@ -0,0 +1,110 @@
+ hydra:
+   run:
+     dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+ 
+ runname: YingMusic_Singer
+ 
+ datasets:
+   name: svs_infer
+   batch_size_per_gpu: 6
+   batch_size_type: sample
+   max_samples: null
+   num_workers: 4
+ 
+ datasets_cfg:
+   filelist_path: /path/to/your/filelist
+   vae_frame_rate: 21.533203125
+   text_num_embeds: 373
+   lrc_align_mode: sentence_level
+ 
+ optim:
+   epochs: null
+   num_updates: 31518
+   learning_rate: 7e-6
+   num_warmup_updates: 60
+   grad_accumulation_steps: 1
+   max_grad_norm: 1.0
+   bnb_optimizer: False
+   max_iter: null
+ 
+ model:
+   name: YingMusic_Singer
+   tokenizer: null
+   tokenizer_path: null
+   is_tts_pretrain: 0
+   melody_input_source: some_pretrain_fuzzdisturb
+   cka_disabled: 0
+   backbone: DiT
+   f0_fn_type: null
+   f0_fn_path: null
+ 
+   arch:
+     dim: 1024
+     depth: 22
+     heads: 16
+     ff_mult: 2
+     text_dim: 512
+     text_mask_padding: False
+     qk_norm: null
+     conv_layers: 4
+     pe_attn_head: null
+     attn_backend: torch
+     attn_mask_enabled: False
+     checkpoint_activations: False
+     guidance_scale_embed_dim: null
+ 
+   mel_spec:
+     n_mel_channels: 64
+     mel_spec_type: vae
+ 
+   vocoder:
+     is_local: True
+     local_path: null
+ 
+   midi_extractor:
+     path: ckpts/model_ckpt_steps_100000_simplified.ckpt
+ 
+   extra_parameters:
+     some_pretrain_fuzzdisturb:
+       dim: 128
+       drop_type: equal_space
+       drop_prob: [1, 9]
+       noise_scale: 0.0
+       blur_kernel: 0
+ 
+ grpo:
+   noise_level: 0.8
+   num_samples: 8
+   upper_clip_epsilon: 0.02
+   lower_clip_epsilon: 0.002
+   beta: 1
+   ppo_epochs: 1
+   num_steps: 32
+   sde_window_range: [1, 16]
+   sde_window_size: 2
+   delet_temp: 10
+   use_cfg_sample: false
+   wer_SDI_weights: [1, 1, 1]
+   reward_config: {"qwen_asr_wer": 0.25, "f0_correlation": 0.25, "qwenfeat": 0.25, "sim_wavlm_large": 0.25}
+   grpo_wanted_loss: ["qwen_asr_wer_reward", "f0_correlation_reward", "qwenfeat_reward", "sim_wavlm_large_reward"]
+   use_guidance_scale_embed: false
+   t_shift: 0.5
+   cfg_strength: null
+   GDPO_batch_norm: false
+   use_egrpo: false
+   egrpo_tau: null
+   egrpo_d: null
+   use_max_group_std_dev: false
+ 
+ ema_kwargs:
+   beta: 0.995
+   update_after_step: 100
+   update_every: 1
+ 
+ ckpts:
+   logger: tensorboard
+   log_samples: False
+   save_per_updates: 100
+   keep_last_n_checkpoints: -1
+   last_per_updates: 100
+   save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}_CKA
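For orientation, here is a minimal sketch of how a Hydra config of this shape is typically consumed. The entry-point module and the `config_path`/`config_name` values are assumptions for illustration, not part of this commit:

import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="src/YingMusicSinger/config", config_name="YingMusic_Singer")
def main(cfg: DictConfig) -> None:
    # Interpolations such as ${model.mel_spec.mel_spec_type} resolve against this cfg.
    print(cfg.optim.learning_rate)  # 7e-6
    print(cfg.model.arch.dim)       # 1024
    print(OmegaConf.to_yaml(cfg, resolve=False))

if __name__ == "__main__":
    main()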
src/YingMusicSinger/config/stable_audio_2_0_vae_20hz_official.json ADDED
@@ -0,0 +1,196 @@
+ {
+     "model_type": "autoencoder",
+     "sample_size": 24576,
+     "sample_rate": 44100,
+     "audio_channels": 2,
+     "model": {
+         "encoder": {
+             "type": "oobleck",
+             "requires_grad": false,
+             "config": {
+                 "in_channels": 2,
+                 "channels": 128,
+                 "c_mults": [1, 2, 4, 8, 16],
+                 "strides": [2, 4, 4, 8, 8],
+                 "latent_dim": 128,
+                 "use_snake": true
+             }
+         },
+         "decoder": {
+             "type": "oobleck",
+             "config": {
+                 "out_channels": 2,
+                 "channels": 128,
+                 "c_mults": [1, 2, 4, 8, 16],
+                 "strides": [2, 4, 4, 8, 8],
+                 "latent_dim": 64,
+                 "use_snake": true,
+                 "final_tanh": false
+             }
+         },
+         "bottleneck": {
+             "type": "vae"
+         },
+         "latent_dim": 64,
+         "downsampling_ratio": 2048,
+         "io_channels": 2
+     },
+     "training": {
+         "learning_rate": 8e-5,
+         "warmup_steps": 0,
+         "use_ema": true,
+         "optimizer_configs": {
+             "autoencoder": {
+                 "optimizer": {
+                     "type": "AdamW",
+                     "config": {
+                         "betas": [0.8, 0.99],
+                         "lr": 1e-4,
+                         "weight_decay": 8e-4
+                     }
+                 },
+                 "scheduler": {
+                     "type": "InverseLR",
+                     "config": {
+                         "inv_gamma": 200000,
+                         "power": 0.5,
+                         "warmup": 0.999
+                     }
+                 }
+             },
+             "discriminator": {
+                 "optimizer": {
+                     "type": "AdamW",
+                     "config": {
+                         "betas": [0.8, 0.99],
+                         "lr": 3e-4,
+                         "weight_decay": 1e-3
+                     }
+                 },
+                 "scheduler": {
+                     "type": "InverseLR",
+                     "config": {
+                         "inv_gamma": 200000,
+                         "power": 0.5,
+                         "warmup": 0.999
+                     }
+                 }
+             }
+         },
+         "loss_configs": {
+             "discriminator": {
+                 "type": "encodec",
+                 "config": {
+                     "filters": 64,
+                     "n_ffts": [2048, 1024, 512, 256, 128],
+                     "hop_lengths": [512, 256, 128, 64, 32],
+                     "win_lengths": [2048, 1024, 512, 256, 128]
+                 },
+                 "weights": {
+                     "adversarial": 0.1,
+                     "feature_matching": 5.0
+                 }
+             },
+             "spectral": {
+                 "type": "mrstft",
+                 "config": {
+                     "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
+                     "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
+                     "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
+                     "perceptual_weighting": true
+                 },
+                 "weights": {
+                     "mrstft": 1.0
+                 }
+             },
+             "time": {
+                 "type": "l1",
+                 "weights": {
+                     "l1": 0.0
+                 }
+             },
+             "bottleneck": {
+                 "type": "kl",
+                 "weights": {
+                     "kl": 1e-4
+                 }
+             }
+         },
+         "demo": {
+             "demo_every": 10000,
+             "demo_dir": "/home/node44_tmpdata3/netease/hkchen/stable-audio-tools-1/stable-audio-tools/outputs/vae_large_fresh_data_demo"
+         }
+     }
+ }
src/YingMusicSinger/utils/checkpoint.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ from vocos import Vocos
+ 
+ from singer.model import Singer
+ 
+ 
+ def load_model(model_cls, model_cfg, ckpt_path, vocab_char_map, device="cuda"):
+     model_arc = model_cfg.model.arch
+     mel_spec_kwargs = model_cfg.model.mel_spec
+     vocab_size = len(vocab_char_map)
+ 
+     backbone = model_cls(
+         **model_arc, text_num_embeds=vocab_size, mel_dim=mel_spec_kwargs.n_mel_channels
+     )
+ 
+     model = Singer(
+         transformer=backbone,
+         mel_spec_kwargs=mel_spec_kwargs,
+         vocab_char_map=vocab_char_map,
+     )
+ 
+     checkpoint = torch.load(ckpt_path, map_location="cpu")
+     if "ema_model_state_dict" in checkpoint:
+         state_dict = checkpoint["ema_model_state_dict"]
+     elif "model_state_dict" in checkpoint:
+         state_dict = checkpoint["model_state_dict"]
+     else:
+         state_dict = checkpoint
+ 
+     # Strip the "module." prefix left by DistributedDataParallel, if present
+     new_state_dict = {}
+     for k, v in state_dict.items():
+         if k.startswith("module."):
+             new_state_dict[k[len("module.") :]] = v
+         else:
+             new_state_dict[k] = v
+ 
+     model.load_state_dict(new_state_dict)
+     model.to(device)
+     model.eval()
+     return model
+ 
+ 
+ def load_vocoder(vocoder_name, is_local, local_path, device="cuda"):
+     if vocoder_name == "vocos":
+         if is_local:
+             vocoder = Vocos.from_hparams(local_path).to(device)
+         else:
+             vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
+     elif vocoder_name == "bigvgan":
+         # Placeholder: BigVGAN would need its own import and loading logic
+         raise NotImplementedError("BigVGAN loading not implemented yet")
+     else:
+         print(
+             f"Warning: Unknown vocoder {vocoder_name}, trying to load from local path if provided"
+         )
+         if is_local:
+             # Fall back to treating the local path as a Vocos config
+             vocoder = Vocos.from_hparams(local_path).to(device)
+         else:
+             raise ValueError(f"Unknown vocoder: {vocoder_name}")
+     return vocoder
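A usage sketch for `load_model`. The `DiT` import location, the config path, and the checkpoint path are assumptions; adjust them to the actual repo layout:

from omegaconf import OmegaConf

from singer.model import DiT  # assumed backbone class matching model.backbone: DiT
from src.YingMusicSinger.utils.common import get_tokenizer

cfg = OmegaConf.load("src/YingMusicSinger/config/YingMusic_Singer.yaml")
vocab_char_map, vocab_size = get_tokenizer(cfg.datasets.name, tokenizer="char")
model = load_model(DiT, cfg, "ckpts/model_last.pt", vocab_char_map, device="cuda")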
src/YingMusicSinger/utils/cnen_tokenizer.py ADDED
@@ -0,0 +1,34 @@
+ import json
+ 
+ 
+ class CNENTokenizer:
+     def __init__(self):
+         with open(
+             "./src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json",
+             "r",
+             encoding="utf-8",
+         ) as file:
+             self.phone2id: dict = json.load(file)["vocab"]
+         # Shift all ids by +1 so that 0 is reserved for <PAD>
+         self.phone2id = {k: int(v) + 1 for (k, v) in self.phone2id.items()}
+ 
+         self.pad_token_id = 0
+         self.phone2id["<PAD>"] = 0
+ 
+         self.punct_token_id = len(self.phone2id)  # punctuation-mark token
+         self.phone2id["<PUNCT>"] = len(self.phone2id)
+ 
+         self.sep_token_id = len(self.phone2id)  # sentence-separation token
+         self.phone2id["<SEP>"] = len(self.phone2id)
+ 
+         self.id2phone = {v: k for (k, v) in self.phone2id.items()}
+         from src.YingMusicSinger.utils.f5_tts.g2p.g2p_generation import chn_eng_g2p
+ 
+         self.tokenizer = chn_eng_g2p
+ 
+     def encode(self, text):
+         phone, token = self.tokenizer(text)
+         token = [x + 1 for x in token]
+         return token
+ 
+     def decode(self, token):
+         return "|".join([self.id2phone[x] for x in token])
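A round-trip sketch for `CNENTokenizer`; the input string is illustrative and the actual ids depend on vocab.json:

tokenizer = CNENTokenizer()
ids = tokenizer.encode("你好 hello")  # phoneme ids, shifted +1 so 0 stays <PAD>
text = tokenizer.decode(ids)          # phonemes re-joined with "|"
print(ids, text, tokenizer.sep_token_id)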
src/YingMusicSinger/utils/common.py ADDED
@@ -0,0 +1,325 @@
+ from __future__ import annotations
+ 
+ import os
+ import random
+ from collections import defaultdict
+ 
+ import jieba
+ import torch
+ import torch.nn.functional as F
+ from pypinyin import Style, lazy_pinyin
+ from torch.nn.utils.rnn import pad_sequence
+ 
+ # seed everything
+ 
+ 
+ def seed_everything(seed=0):
+     random.seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+ 
+ 
+ # helpers
+ 
+ 
+ def exists(v):
+     return v is not None
+ 
+ 
+ def default(v, d):
+     return v if exists(v) else d
+ 
+ 
+ def is_package_available(package_name: str) -> bool:
+     try:
+         import importlib.util
+ 
+         return importlib.util.find_spec(package_name) is not None
+     except Exception:
+         return False
+ 
+ 
+ # tensor helpers
+ 
+ 
+ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:  # noqa: F722 F821
+     if not exists(length):
+         length = t.amax()
+ 
+     seq = torch.arange(length, device=t.device)
+     return seq[None, :] < t[:, None]
+ 
+ 
+ def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):  # noqa: F722 F821
+     max_seq_len = seq_len.max().item()
+     seq = torch.arange(max_seq_len, device=start.device).long()
+     start_mask = seq[None, :] >= start[:, None]
+     end_mask = seq[None, :] < end[:, None]
+     return start_mask & end_mask
+ 
+ 
+ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):  # noqa: F722 F821
+     lengths = (frac_lengths * seq_len).long()
+     max_start = seq_len - lengths
+ 
+     rand = torch.rand_like(frac_lengths)
+     start = (max_start * rand).long().clamp(min=0)
+     end = start + lengths
+ 
+     return mask_from_start_end_indices(seq_len, start, end)
+ 
+ 
+ def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:  # noqa: F722
+     if not exists(mask):
+         return t.mean(dim=1)
+ 
+     t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
+     num = t.sum(dim=1)
+     den = mask.float().sum(dim=1)
+ 
+     return num / den.clamp(min=1.0)
+ 
+ 
+ # simple utf-8 tokenizer, since paper went character based
+ def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:  # noqa: F722
+     list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text]  # ByT5 style
+     text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
+     return text
+ 
+ 
+ # char tokenizer, based on custom dataset's extracted .txt file
+ def list_str_to_idx(
+     text: list[str] | list[list[str]],
+     vocab_char_map: dict[str, int],  # {char: idx}
+     padding_value=-1,
+ ) -> int["b nt"]:  # noqa: F722
+     list_idx_tensors = [
+         torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text
+     ]  # pinyin or char style
+     text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
+     return text
+ 
+ 
+ # Get tokenizer
+ 
+ 
+ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
+     """
+     tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
+                 - "char" for char-wise tokenizer, need .txt vocab_file
+                 - "byte" for utf-8 tokenizer
+                 - "custom" if you're directly passing in a path to the vocab.txt you want to use
+     vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
+                 - if use "char", derived from unfiltered character & symbol counts of custom dataset
+                 - if use "byte", set to 256 (unicode byte range)
+     """
+     if tokenizer in ["pinyin", "char"]:
+         # tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
+         # tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}/vocab.txt")
+         tokenizer_path = (
+             "/ailab-train/speech/zhengjunjie/opt/models/F5-TTS/F5TTS_v1_Base/vocab.txt"
+         )
+         with open(tokenizer_path, "r", encoding="utf-8") as f:
+             vocab_char_map = {}
+             for i, char in enumerate(f):
+                 vocab_char_map[char[:-1]] = i
+         vocab_size = len(vocab_char_map)
+         assert vocab_char_map[" "] == 0, (
+             "make sure space is at idx 0 in vocab.txt, since 0 is used for the unknown char"
+         )
+ 
+     elif tokenizer == "byte":
+         vocab_char_map = None
+         vocab_size = 256
+ 
+     elif tokenizer == "custom":
+         with open(dataset_name, "r", encoding="utf-8") as f:
+             vocab_char_map = {}
+             for i, char in enumerate(f):
+                 vocab_char_map[char[:-1]] = i
+         vocab_size = len(vocab_char_map)
+ 
+     return vocab_char_map, vocab_size
+ 
+ 
+ # convert char to pinyin
+ 
+ 
+ def convert_char_to_pinyin(text_list, polyphone=True, with_tone=True):
+     if with_tone:
+         style = Style.TONE3  # with tone number
+     else:
+         style = Style.NORMAL  # no tone
+ 
+     if jieba.dt.initialized is False:
+         jieba.default_logger.setLevel(50)  # CRITICAL
+         jieba.initialize()
+ 
+     final_text_list = []
+     custom_trans = str.maketrans(
+         {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
+     )  # add custom trans here, to address oov
+ 
+     def is_chinese(c):
+         return (
+             "\u3100" <= c <= "\u9fff"  # common chinese characters
+         )
+ 
+     for text in text_list:
+         char_list = []
+         text = text.translate(custom_trans)
+         for seg in jieba.cut(text):
+             seg_byte_len = len(bytes(seg, "UTF-8"))
+             if seg_byte_len == len(seg):  # if pure alphabets and symbols
+                 if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
+                     char_list.append(" ")
+                 char_list.extend(seg)
+             elif polyphone and seg_byte_len == 3 * len(
+                 seg
+             ):  # if pure east asian characters
+                 seg_ = lazy_pinyin(seg, style=style, tone_sandhi=True)
+                 for i, c in enumerate(seg):
+                     if is_chinese(c):
+                         char_list.append(" ")
+                     char_list.append(seg_[i])
+             else:  # if mixed characters, alphabets and symbols
+                 for c in seg:
+                     if ord(c) < 256:
+                         char_list.extend(c)
+                     elif is_chinese(c):
+                         char_list.append(" ")
+                         char_list.extend(lazy_pinyin(c, style=style, tone_sandhi=True))
+                     else:
+                         char_list.append(c)
+ 
+         if with_tone is False:
+             for idx, item in enumerate(char_list):
+                 char_list[idx] = "__" + item
+ 
+         final_text_list.append(char_list)
+ 
+     return final_text_list
+ 
+ 
+ # filter func for dirty data with many repetitions
+ 
+ 
+ def repetition_found(text, length=2, tolerance=10):
+     pattern_count = defaultdict(int)
+     for i in range(len(text) - length + 1):
+         pattern = text[i : i + length]
+         pattern_count[pattern] += 1
+     for pattern, count in pattern_count.items():
+         if count > tolerance:
+             return True
+     return False
+ 
+ 
+ # get the empirically pruned step for sampling
+ 
+ 
+ def get_epss_timesteps(n, device, dtype):
+     dt = 1 / 32
+     predefined_timesteps = {
+         5: [0, 2, 4, 8, 16, 32],
+         6: [0, 2, 4, 6, 8, 16, 32],
+         7: [0, 2, 4, 6, 8, 16, 24, 32],
+         10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32],
+         12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32],
+         16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32],
+     }
+     t = predefined_timesteps.get(n, [])
+     if not t:
+         return torch.linspace(0, 1, n + 1, device=device, dtype=dtype)
+     return dt * torch.tensor(t, device=device, dtype=dtype)
+ 
+ 
+ def calculate_similarity_matrix_with_mask(
+     vectors: torch.Tensor, valid_mask: torch.Tensor = None
+ ) -> torch.Tensor:
+     if valid_mask is None:
+         valid_mask = torch.ones(
+             vectors.shape[:-1], dtype=torch.bool, device=vectors.device
+         )
+ 
+     if valid_mask.dtype != torch.bool:
+         valid_mask = valid_mask.bool()
+ 
+     vectors = vectors * valid_mask.unsqueeze(-1).float()
+ 
+     vectors_normalized = F.normalize(vectors, p=2, dim=-1, eps=1e-8)
+ 
+     # (B, N, D) * (B, D, N) -> (B, N, N)
+     similarity_matrix = torch.bmm(
+         vectors_normalized, vectors_normalized.transpose(1, 2)
+     )
+ 
+     # (B, N, 1) & (B, 1, N) -> (B, N, N)
+     combined_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)
+ 
+     similarity_matrix.masked_fill_(~combined_mask, 0.0)
+ 
+     return similarity_matrix
+ 
+ 
+ def _center_gram_batch(gram: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+     """Center Gram matrices in batch.
+ 
+     Args:
+         gram: [B, N, N] Gram matrices.
+         mask: [B, N] optional validity mask.
+ 
+     Returns:
+         Centered Gram matrices [B, N, N].
+     """
+     if mask is None:
+         gram = gram - gram.mean(dim=2, keepdim=True)
+         gram = gram - gram.mean(dim=1, keepdim=True)
+         return gram
+     else:
+         mask_float = mask.float()
+         n_valid = mask_float.sum(dim=1, keepdim=True).clamp(min=1.0)
+ 
+         mask_mat = mask_float.unsqueeze(2) * mask_float.unsqueeze(1)  # [B, N, N]
+         gram = gram * mask_mat
+ 
+         row_mean = gram.sum(dim=2, keepdim=True) / n_valid.unsqueeze(2)
+         col_mean = gram.sum(dim=1, keepdim=True) / n_valid.unsqueeze(1)
+         grand_mean = row_mean.sum(dim=1, keepdim=True) / n_valid.unsqueeze(2)
+ 
+         centered = gram - row_mean - col_mean + grand_mean
+         return centered * mask_mat
+ 
+ 
+ def cka_loss(
+     sim_x: torch.Tensor, sim_y: torch.Tensor, valid_mask: torch.Tensor = None
+ ) -> torch.Tensor:
+     """Compute CKA loss between two similarity matrices in batch.
+ 
+     Args:
+         sim_x: [B, N, N] similarity matrix.
+         sim_y: [B, N, N] similarity matrix.
+         valid_mask: [B, N] optional validity mask.
+ 
+     Returns:
+         Scalar CKA loss (1 - mean CKA similarity).
+     """
+     eps = 1e-6
+ 
+     sim_x_c = _center_gram_batch(sim_x, valid_mask)  # [B, N, N]
+     sim_y_c = _center_gram_batch(sim_y, valid_mask)  # [B, N, N]
+ 
+     # HSIC via element-wise product summed over spatial dims
+     hsic = torch.sum(sim_x_c * sim_y_c, dim=(1, 2))  # [B]
+ 
+     norm_x = torch.sqrt(torch.sum(sim_x_c**2, dim=(1, 2)) + eps)  # [B]
+     norm_y = torch.sqrt(torch.sum(sim_y_c**2, dim=(1, 2)) + eps)  # [B]
+ 
+     cka_similarity = hsic / (norm_x * norm_y + eps)  # [B]
+ 
+     return torch.mean(1.0 - cka_similarity)
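A self-contained sketch of how the two CKA helpers chain together, on random features (shapes illustrative; unrelated features give a loss near 1):

import torch

B, N, D = 2, 50, 256
feats_x, feats_y = torch.randn(B, N, D), torch.randn(B, N, D)
mask = torch.ones(B, N, dtype=torch.bool)
mask[:, 40:] = False  # treat the last 10 positions as padding

sim_x = calculate_similarity_matrix_with_mask(feats_x, mask)  # [B, N, N]
sim_y = calculate_similarity_matrix_with_mask(feats_y, mask)  # [B, N, N]
loss = cka_loss(sim_x, sim_y, mask)  # scalar, 1 - mean batched CKA
print(loss.item())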
src/YingMusicSinger/utils/lrc_align.py ADDED
@@ -0,0 +1,61 @@
+ import numpy as np
+ 
+ 
+ def align_lrc_put_to_front(tokenizer, lrc_start_times, lrc_lines, total_lens):
+     lrc_text_list = []
+     lrc_token = np.zeros(total_lens, dtype=np.int64)
+ 
+     token_start = 0
+     for temp in lrc_lines:
+         # for punct in ",。!?、;:,.!?;:":
+         #     one_line_lrc = one_line_lrc.replace(punct, ",")
+         # one_line_lrc = one_line_lrc.strip(",。!?、;:,.!?;: ")
+         for one_line_lrc in temp.split("|"):
+             lrc_text_list.append(one_line_lrc)
+             one_line_token = tokenizer.encode(one_line_lrc)
+             lrc_text_list.append("<SEP>")
+             one_line_token = one_line_token + [tokenizer.phone2id["<SEP>"]]
+ 
+             one_line_token = np.array(one_line_token)
+             assert token_start + len(one_line_token) <= len(lrc_token), (
+                 "The length of lrc_token exceeds that of the vocal latent"
+             )
+             lrc_token[token_start : token_start + len(one_line_token)] = one_line_token
+             token_start = token_start + len(one_line_token)
+     return lrc_token, "".join(lrc_text_list)
+ 
+ 
+ def align_lrc_sentence_level(
+     tokenizer, lrc_start_times, lrc_lines, total_lens, vae_frame_rate
+ ):
+     # BUG: only the prompt and the two segments to be generated carry start
+     # timestamps; the generated content and the prompt contain no <SEP>-like markers.
+     lrc_text_list = []
+     lrc_token = np.zeros(total_lens, dtype=np.int64)
+ 
+     token_start = 0
+     for lrc_start_time, one_line_lrc in zip(lrc_start_times, lrc_lines):
+         one_line_lrc = one_line_lrc.replace("|", " ")
+         for punct in ",。!?、;:,.!?;:":
+             one_line_lrc = one_line_lrc.replace(punct, ",")
+         one_line_lrc = one_line_lrc.strip(",。!?、;:,.!?;: ")
+ 
+         lrc_text_list.append(one_line_lrc)
+         one_line_token = tokenizer.encode(one_line_lrc)
+         lrc_text_list.append("<SEP>")
+         one_line_token = one_line_token + [tokenizer.phone2id["<SEP>"]]
+ 
+         one_line_token = np.array(one_line_token)
+ 
+         timestamp_cal_start_frame = int(lrc_start_time * vae_frame_rate)
+ 
+         # Handle delayed starts: never write before the current cursor
+         timestamp_cal_start_frame = max(timestamp_cal_start_frame, token_start)
+ 
+         assert timestamp_cal_start_frame + len(one_line_token) <= len(lrc_token), (
+             "The length of the lrc_token exceeds that of the vocal latent"
+         )
+         lrc_token[
+             timestamp_cal_start_frame : timestamp_cal_start_frame + len(one_line_token)
+         ] = one_line_token
+         token_start = timestamp_cal_start_frame + len(one_line_token)
+     return lrc_token, "".join(lrc_text_list)
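A sketch of `align_lrc_sentence_level` on made-up lyrics; `vae_frame_rate` mirrors `datasets_cfg.vae_frame_rate` from the config above, and the lyric lines and timestamps are illustrative:

from src.YingMusicSinger.utils.cnen_tokenizer import CNENTokenizer

tokenizer = CNENTokenizer()
vae_frame_rate = 21.533203125
total_lens = int(10.0 * vae_frame_rate)  # 10 s of vocal-latent frames

lrc_token, lrc_text = align_lrc_sentence_level(
    tokenizer,
    lrc_start_times=[0.0, 4.2],             # seconds, one per lyric line
    lrc_lines=["第一句歌词", "第二句歌词"],    # illustrative lyrics
    total_lens=total_lens,
    vae_frame_rate=vae_frame_rate,
)
# Each line's phoneme ids (plus a trailing <SEP>) are written starting at
# int(start_time * vae_frame_rate); untouched frames stay 0 (<PAD>).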
src/YingMusicSinger/utils/mel_spectrogram.py ADDED
@@ -0,0 +1,86 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torchaudio
+ 
+ 
+ class MelodySpectrogram(torch.nn.Module):
+     def __init__(
+         self,
+         n_mel_channels=80,
+         sampling_rate=44100,
+         win_length=2048,
+         hop_length=512,
+         n_fft=None,
+         mel_fmin=0,
+         mel_fmax=None,
+         clamp=1e-5,
+     ):
+         from librosa.filters import mel
+ 
+         super().__init__()
+         n_fft = win_length if n_fft is None else n_fft
+         self.hann_window = {}
+         mel_basis = mel(
+             sr=sampling_rate,
+             n_fft=n_fft,
+             n_mels=n_mel_channels,
+             fmin=mel_fmin,
+             fmax=mel_fmax,
+             htk=True,
+         )
+         mel_basis = torch.from_numpy(mel_basis).float()
+         self.register_buffer("mel_basis", mel_basis)
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.sampling_rate = sampling_rate
+         self.n_mel_channels = n_mel_channels
+         self.clamp = clamp
+ 
+     def _mel_forward(self, audio, keyshift=0, speed=1, center=True):
+         factor = 2 ** (keyshift / 12)
+         n_fft_new = int(np.round(self.n_fft * factor))
+         win_length_new = int(np.round(self.win_length * factor))
+         hop_length_new = int(np.round(self.hop_length * speed))
+ 
+         keyshift_key = str(keyshift) + "_" + str(audio.device)
+         if keyshift_key not in self.hann_window:
+             self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
+                 audio.device
+             )
+ 
+         fft = torch.stft(
+             audio,
+             n_fft=n_fft_new,
+             hop_length=hop_length_new,
+             win_length=win_length_new,
+             window=self.hann_window[keyshift_key],
+             center=center,
+             return_complex=True,
+         )
+         magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
+ 
+         if keyshift != 0:
+             size = self.n_fft // 2 + 1
+             resize = magnitude.size(1)
+             if resize < size:
+                 magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
+             magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
+ 
+         mel_output = torch.matmul(self.mel_basis, magnitude)
+         log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
+         return log_mel_spec
+ 
+     @torch.no_grad()
+     def forward(self, audio, sr, sil_len_to_end=None, keyshift=0, speed=1):
+         # audio, sr = torchaudio.load(audio_path)
+         if sil_len_to_end is not None:
+             # Keep the padding silence on the same device/dtype as the input audio
+             silence = torch.zeros(
+                 audio.shape[0],
+                 int(sr * sil_len_to_end),
+                 device=audio.device,
+                 dtype=audio.dtype,
+             )
+             audio = torch.cat([audio, silence], dim=1)
+         if sr != self.sampling_rate:
+             audio = torchaudio.transforms.Resample(sr, self.sampling_rate)(audio)
+         if audio.shape[0] > 1:
+             audio = torch.mean(audio, dim=0, keepdim=True)
+         audio = audio.to(self.mel_basis.device)
+         return self._mel_forward(audio, keyshift=keyshift, speed=speed)
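A usage sketch for `MelodySpectrogram`; `example.wav` is a placeholder path:

import torchaudio

mel_fn = MelodySpectrogram(n_mel_channels=80, sampling_rate=44100)
audio, sr = torchaudio.load("example.wav")   # [channels, samples], placeholder path
mel = mel_fn(audio, sr, sil_len_to_end=0.5)  # appends 0.5 s of silence before the STFT
print(mel.shape)                             # [1, 80, n_frames] log-mel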