Commit 64ec292 committed by xjsc0
Parent: 4ed4cff
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full changeset.
Files changed (50)
  1. .gitattributes +47 -0
  2. src/YingMusicSinger/infer/YingMusicSinger.py +263 -0
  3. src/YingMusicSinger/melody/Gconform.py +298 -0
  4. src/YingMusicSinger/melody/Gconv.py +60 -0
  5. src/YingMusicSinger/melody/SmoothMelody.py +144 -0
  6. src/YingMusicSinger/melody/midi_extractor.py +208 -0
  7. src/YingMusicSinger/models/__init__.py +1 -0
  8. src/YingMusicSinger/models/dit.py +472 -0
  9. src/YingMusicSinger/models/model.py +423 -0
  10. src/YingMusicSinger/models/modules.py +961 -0
  11. src/YingMusicSinger/utils/f5_tts/g2p/g2p/__init__.py +91 -0
  12. src/YingMusicSinger/utils/f5_tts/g2p/g2p/chinese_model_g2p.py +209 -0
  13. src/YingMusicSinger/utils/f5_tts/g2p/g2p/cleaners.py +28 -0
  14. src/YingMusicSinger/utils/f5_tts/g2p/g2p/english.py +202 -0
  15. src/YingMusicSinger/utils/f5_tts/g2p/g2p/french.py +149 -0
  16. src/YingMusicSinger/utils/f5_tts/g2p/g2p/german.py +94 -0
  17. src/YingMusicSinger/utils/f5_tts/g2p/g2p/korean.py +81 -0
  18. src/YingMusicSinger/utils/f5_tts/g2p/g2p/mandarin.py +603 -0
  19. src/YingMusicSinger/utils/f5_tts/g2p/g2p/text_tokenizers.py +82 -0
  20. src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json +372 -0
  21. src/YingMusicSinger/utils/f5_tts/g2p/g2p_generation.py +129 -0
  22. src/YingMusicSinger/utils/f5_tts/g2p/infer_dpo.py +277 -0
  23. src/YingMusicSinger/utils/f5_tts/g2p/sources/bpmf_2_pinyin.txt +41 -0
  24. src/YingMusicSinger/utils/f5_tts/g2p/sources/chinese_lexicon.txt +3 -0
  25. src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/config.json +819 -0
  26. src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/poly_bert_model.onnx +3 -0
  27. src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polychar.txt +159 -0
  28. src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polydict.json +393 -0
  29. src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polydict_r.json +393 -0
  30. src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/vocab.txt +0 -0
  31. src/YingMusicSinger/utils/f5_tts/g2p/sources/pinyin_2_bpmf.txt +429 -0
  32. src/YingMusicSinger/utils/f5_tts/g2p/utils/front_utils.py +18 -0
  33. src/YingMusicSinger/utils/f5_tts/g2p/utils/g2p.py +139 -0
  34. src/YingMusicSinger/utils/f5_tts/g2p/utils/log.py +52 -0
  35. src/YingMusicSinger/utils/f5_tts/g2p/utils/mls_en.json +335 -0
  36. src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/LangSegment.py +1251 -0
  37. src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/__init__.py +24 -0
  38. src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/utils/__init__.py +0 -0
  39. src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/utils/num.py +332 -0
  40. src/YingMusicSinger/utils/stable_audio_tools/__init__.py +0 -0
  41. src/YingMusicSinger/utils/stable_audio_tools/adp.py +1686 -0
  42. src/YingMusicSinger/utils/stable_audio_tools/autoencoders.py +975 -0
  43. src/YingMusicSinger/utils/stable_audio_tools/blocks.py +398 -0
  44. src/YingMusicSinger/utils/stable_audio_tools/bottleneck copy.py +393 -0
  45. src/YingMusicSinger/utils/stable_audio_tools/bottleneck.py +393 -0
  46. src/YingMusicSinger/utils/stable_audio_tools/conditioners.py +664 -0
  47. src/YingMusicSinger/utils/stable_audio_tools/diffusion.py +740 -0
  48. src/YingMusicSinger/utils/stable_audio_tools/dit.py +451 -0
  49. src/YingMusicSinger/utils/stable_audio_tools/factory.py +185 -0
  50. src/YingMusicSinger/utils/stable_audio_tools/pretransforms.py +425 -0
.gitattributes ADDED
@@ -0,0 +1,47 @@
1
+ *.jpg filter=lfs diff=lfs merge=lfs -text
2
+ *.gif filter=lfs diff=lfs merge=lfs -text
3
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
4
+ *.avi filter=lfs diff=lfs merge=lfs -text
5
+ *.dylib filter=lfs diff=lfs merge=lfs -text
6
+ *.npz filter=lfs diff=lfs merge=lfs -text
7
+ *.svg filter=lfs diff=lfs merge=lfs -text
8
+ *.wav filter=lfs diff=lfs merge=lfs -text
9
+ *.m4a filter=lfs diff=lfs merge=lfs -text
10
+ *.zip filter=lfs diff=lfs merge=lfs -text
11
+ *.pth filter=lfs diff=lfs merge=lfs -text
12
+ *.pkl filter=lfs diff=lfs merge=lfs -text
13
+ *.mkv filter=lfs diff=lfs merge=lfs -text
14
+ *.tar filter=lfs diff=lfs merge=lfs -text
15
+ *.docx filter=lfs diff=lfs merge=lfs -text
16
+ *.ppt filter=lfs diff=lfs merge=lfs -text
17
+ *.pptx filter=lfs diff=lfs merge=lfs -text
18
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
19
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
20
+ *.ico filter=lfs diff=lfs merge=lfs -text
21
+ *.flac filter=lfs diff=lfs merge=lfs -text
22
+ *.ogg filter=lfs diff=lfs merge=lfs -text
23
+ *.xls filter=lfs diff=lfs merge=lfs -text
24
+ *.pickle filter=lfs diff=lfs merge=lfs -text
25
+ *.webm filter=lfs diff=lfs merge=lfs -text
26
+ *.doc filter=lfs diff=lfs merge=lfs -text
27
+ *.onnx filter=lfs diff=lfs merge=lfs -text
28
+ *.gz filter=lfs diff=lfs merge=lfs -text
29
+ *.bin filter=lfs diff=lfs merge=lfs -text
30
+ *.h5 filter=lfs diff=lfs merge=lfs -text
31
+ *.png filter=lfs diff=lfs merge=lfs -text
32
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
33
+ *.7z filter=lfs diff=lfs merge=lfs -text
34
+ *.rar filter=lfs diff=lfs merge=lfs -text
35
+ *.pdf filter=lfs diff=lfs merge=lfs -text
36
+ *.xlsx filter=lfs diff=lfs merge=lfs -text
37
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
38
+ *.dll filter=lfs diff=lfs merge=lfs -text
39
+ *.npy filter=lfs diff=lfs merge=lfs -text
40
+ *.bmp filter=lfs diff=lfs merge=lfs -text
41
+ *.mov filter=lfs diff=lfs merge=lfs -text
42
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
43
+ *.pt filter=lfs diff=lfs merge=lfs -text
44
+ *.so filter=lfs diff=lfs merge=lfs -text
45
+ *.hdf5 filter=lfs diff=lfs merge=lfs -text
46
+ *.ttf filter=lfs diff=lfs merge=lfs -text
47
+ src/YingMusicSinger/utils/f5_tts/g2p/sources/chinese_lexicon.txt filter=lfs diff=lfs merge=lfs -text
src/YingMusicSinger/infer/YingMusicSinger.py ADDED
@@ -0,0 +1,263 @@
1
+ import hydra
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchaudio
5
+ from einops import rearrange
6
+ from ema_pytorch import EMA
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+ from omegaconf import OmegaConf
9
+
10
+ from src.YingMusicSinger.melody.midi_extractor import MIDIExtractor
11
+ from src.YingMusicSinger.models.model import Singer
12
+ from src.YingMusicSinger.utils.cnen_tokenizer import CNENTokenizer
13
+ from src.YingMusicSinger.utils.lrc_align import (
14
+ align_lrc_put_to_front,
15
+ align_lrc_sentence_level,
16
+ )
17
+ from src.YingMusicSinger.utils.mel_spectrogram import MelodySpectrogram
18
+ from src.YingMusicSinger.utils.stable_audio_tools.vae_copysyn import StableAudioInfer
19
+
20
+
21
+ class YingMusicSinger(nn.Module, PyTorchModelHubMixin):
22
+ def __init__(
23
+ self,
24
+ model_cfg_path,
25
+ ckpt_path=None,
26
+ vae_config_path=None,
27
+ vae_ckpt_path=None,
28
+ midi_teacher_ckpt_path=None,
29
+ is_distilled=False,
30
+ use_ema=True,
31
+ ):
32
+ super().__init__()
33
+ self.cfg = OmegaConf.load(model_cfg_path)
34
+ model_cls = hydra.utils.get_class(
35
+ f"src.YingMusicSinger.models.{self.cfg.model.backbone}"
36
+ )
37
+ self.melody_input_source = self.cfg.model.melody_input_source
38
+ self.is_tts_pretrain = self.cfg.model.is_tts_pretrain
39
+
40
+ self.model = Singer(
41
+ transformer=model_cls(
42
+ **self.cfg.model.arch,
43
+ text_num_embeds=self.cfg.datasets_cfg.text_num_embeds,
44
+ mel_dim=self.cfg.model.mel_spec.n_mel_channels,
45
+ use_guidance_scale_embed=is_distilled,
46
+ ),
47
+ mel_spec_kwargs=self.cfg.model.mel_spec,
48
+ is_tts_pretrain=self.is_tts_pretrain,
49
+ melody_input_source=self.melody_input_source,
50
+ cka_disabled=self.cfg.model.cka_disabled,
51
+ num_channels=None,
52
+ extra_parameters=self.cfg.extra_parameters,
53
+ distill_stage=1,
54
+ use_guidance_scale_embed=is_distilled,
55
+ )
56
+
57
+ self.vae = StableAudioInfer(
58
+ model_config_path=vae_config_path,
59
+ model_ckpt_path=vae_ckpt_path,
60
+ )
61
+
62
+ self._need_midi = self.melody_input_source in {
63
+ "some_pretrain",
64
+ "some_pretrain_fuzzdisturb",
65
+ "some_pretrain_postprocess_embedding",
66
+ }
67
+ self.midi_teacher = None
68
+ if self._need_midi:
69
+ self.midi_teacher = MIDIExtractor()
70
+ if midi_teacher_ckpt_path is not None:
71
+ self.midi_teacher._load_form_ckpt(midi_teacher_ckpt_path)
72
+ for p in self.midi_teacher.parameters():
73
+ p.requires_grad = False
74
+
75
+ self.melody_spectrogram_extract = MelodySpectrogram()
76
+
77
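+ # Latent frame rate of the Stable Audio VAE: 44100 Hz audio with 2048x downsampling ≈ 21.5 latent frames per second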
+ self.vae_frame_rate = 44100 / 2048
78
+
79
+ if ckpt_path is not None:
80
+ ckpt = torch.load(ckpt_path, map_location="cpu")
81
+ if use_ema:
82
+ ema_model = EMA(self.model, include_online_model=False)
83
+ ema_model.load_state_dict(ckpt["ema_model_state_dict"])
84
+
85
+ self.model = ema_model.ema_model
86
+ else:
87
+ self.model.load_state_dict(ckpt["model_state_dict"])
88
+
89
+ self.cnen_tokenizer = CNENTokenizer()
90
+
91
+ @property
92
+ def device(self):
93
+ return next(self.parameters()).device
94
+
95
+ def prepare_input(
96
+ self,
97
+ ref_audio_path,
98
+ melody_audio_path,
99
+ ref_text,
100
+ target_text,
101
+ sil_len_to_end,
102
+ lrc_align_mode,
103
+ ):
104
+ ref_audio, ref_audio_sr = torchaudio.load(ref_audio_path)
105
+ silence = torch.zeros(ref_audio.shape[0], int(ref_audio_sr * sil_len_to_end))
106
+ ref_wav = torch.cat([ref_audio, silence], dim=1)
107
+ ref_latent = self.vae.encode_audio(ref_wav, in_sr=ref_audio_sr).transpose(
108
+ 1, 2
109
+ ) # [B, T, D]
110
+
111
+ melody_wav, melody_sr = torchaudio.load(melody_audio_path)
112
+ melody_latent = self.vae.encode_audio(melody_wav, in_sr=melody_sr).transpose(
113
+ 1, 2
114
+ ) # [B, T, D]
115
+
116
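+ # Melody-branch input: reference latent followed by the melody latent along the time axis (zeroed below in TTS-pretrain mode)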
+ midi_in = torch.cat([ref_latent, melody_latent], dim=1)
117
+ if self.is_tts_pretrain:
118
+ midi_in = torch.zeros_like(midi_in)
119
+
120
+ ref_latent_len = ref_latent.shape[1]
121
+ total_len = int(ref_latent.shape[1] + melody_latent.shape[1])
122
+
123
+ if self._need_midi:
124
+ ref_mel = self.melody_spectrogram_extract(audio=ref_wav, sr=ref_audio_sr)
125
+ melody_mel = self.melody_spectrogram_extract(audio=melody_wav, sr=melody_sr)
126
+ melody_mel_spec = torch.cat([ref_mel, melody_mel], dim=2)
127
+ else:
128
+ raise NotImplementedError()
129
+
130
+ assert isinstance(ref_text, str) and isinstance(target_text, str)
131
+ text_list = [ref_text] + [target_text]
132
+
133
+ if lrc_align_mode == "put_to_front":
134
+ lrc_token, _ = align_lrc_put_to_front(
135
+ tokenizer=self.cnen_tokenizer,
136
+ lrc_start_times=None,
137
+ lrc_lines=text_list,
138
+ total_lens=total_len,
139
+ )
140
+ elif lrc_align_mode == "sentence_level":
141
+ lrc_token, _ = align_lrc_sentence_level(
142
+ tokenizer=self.cnen_tokenizer,
143
+ lrc_start_times=[0.0, ref_latent_len / self.vae_frame_rate],
144
+ lrc_lines=text_list,
145
+ total_lens=total_len,
146
+ vae_frame_rate=self.vae_frame_rate,
147
+ )
148
+ else:
149
+ raise ValueError(f"Unsupported lrc_align_mode: {lrc_align_mode}")
150
+
151
+ text_tokens = (
152
+ torch.tensor(lrc_token, dtype=torch.int64).unsqueeze(0).to(self.device)
153
+ )
154
+
155
+ midi_p, bound_p = None, None
156
+ if self._need_midi:
157
+ with torch.no_grad():
158
+ midi_p, bound_p = self.midi_teacher(melody_mel_spec.transpose(1, 2))
159
+
160
+ return (
161
+ ref_latent,
162
+ ref_latent_len,
163
+ text_tokens,
164
+ total_len,
165
+ midi_in,
166
+ midi_p,
167
+ bound_p,
168
+ )
169
+
170
+ def forward(
171
+ self,
172
+ ref_audio_path,
173
+ melody_audio_path,
174
+ ref_text,
175
+ target_text,
176
+ lrc_align_mode: str = "sentence_level",
177
+ sil_len_to_end: float = 0.5,
178
+ t_shift: float = 0.5,
179
+ nfe_step: int = 32,
180
+ cfg_strength: float = 3.0,
181
+ seed: int = 666,
182
+ is_tts_pretrain: bool = False,
183
+ ):
184
+ """
185
+ Args:
186
+ ref_audio_path: Path to the reference audio (for timbre)
187
+ melody_audio_path: Path to the melody reference audio (provides target duration and melody information)
188
+ ref_text: Text corresponding to the reference audio
189
+ target_text: Target text to be synthesized
190
+ lrc_align_mode: Lyric alignment mode "sentence_level" | "put_to_front"
191
+ sil_len_to_end: Duration of silence appended to the end of the reference audio (seconds)
192
+ t_shift: Sampling time offset
193
+ nfe_step: ODE sampling steps
194
+ cfg_strength: CFG strength
195
+ seed: Random seed
196
+ is_tts_pretrain: If True, melody is not provided (TTS mode)
197
+ """
198
+ ref_latent, ref_latent_len, text_tokens, total_len, midi_in, midi_p, bound_p = (
199
+ self.prepare_input(
200
+ ref_audio_path=ref_audio_path,
201
+ melody_audio_path=melody_audio_path,
202
+ ref_text=ref_text,
203
+ target_text=target_text,
204
+ sil_len_to_end=sil_len_to_end,
205
+ lrc_align_mode=lrc_align_mode,
206
+ )
207
+ )
208
+
209
+ assert midi_p is not None and bound_p is not None
210
+ with torch.inference_mode():
211
+ generated_latent, _ = self.model.sample(
212
+ cond=ref_latent,
213
+ midi_in=midi_in,
214
+ text=text_tokens,
215
+ duration=total_len,
216
+ steps=nfe_step,
217
+ cfg_strength=cfg_strength,
218
+ sway_sampling_coef=None,
219
+ use_epss=False,
220
+ seed=seed,
221
+ midi_p=midi_p,
222
+ t_shift=t_shift,
223
+ bound_p=bound_p,
224
+ guidance_scale=cfg_strength,
225
+ )
226
+ generated_latent = generated_latent.to(torch.float32)
227
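+ # Drop the reference-prompt frames and keep only the newly generated portion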
+ generated_latent = generated_latent[:, ref_latent_len:, :]
228
+ generated_latent = generated_latent.permute(0, 2, 1) # [B, D, T]
229
+
230
+ generated_audio = self.vae.decode_audio(generated_latent)
231
+ audio = rearrange(generated_audio, "b d n -> d (b n)")
232
+
233
+ audio = audio.to(torch.float32).cpu()
234
+
235
+ return audio, 44100
236
+
237
+
238
+ if __name__ == "__main__":
239
+ # === Export to HuggingFace safetensors (optional) ===
240
+ # model = YingMusicSinger(
241
+ # model_cfg_path="src/YingMusicSinger/config/YingMusic_Singer.yaml",
242
+ # ckpt_path="ckpts/YingMusicSinger_model.pt",
243
+ # vae_config_path="src/YingMusicSinger/config/stable_audio_2_0_vae_20hz_official.json",
244
+ # vae_ckpt_path="ckpts/stable_audio_2_0_vae_20hz_official.ckpt",
245
+ # midi_teacher_ckpt_path="ckpts/model_ckpt_steps_100000_simplified.ckpt",
246
+ # )
247
+ # model.save_pretrained("path/to/save")
248
+
249
+ # === Inference Example ===
250
+ model = YingMusicSinger.from_pretrained("ASLP-lab/YingMusic-Singer")
251
+ model.to("cuda:0")
252
+ model.eval()
253
+
254
+ waveform, sample_rate = model(
255
+ ref_audio_path="path/to/ref_audio", # Timbre reference audio
256
+ melody_audio_path="path/to/melody_audio", # Melody-providing singing clip
257
+ ref_text="oh the reason i hold on", # Lyrics corresponding to ref_audio
258
+ target_text="oldest book broken watch|bare feet in grassy spot", # Modified target lyrics
259
+ seed=42,
260
+ )
261
+
262
+ torchaudio.save("output.wav", waveform, sample_rate=sample_rate)
263
+ print("Saved to output.wav")
src/YingMusicSinger/melody/Gconform.py ADDED
@@ -0,0 +1,298 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
+
7
+ class GLU(nn.Module):
8
+ def __init__(self, dim):
9
+ super().__init__()
10
+ self.dim = dim
11
+
12
+ def forward(self, x):
13
+ out, gate = x.chunk(2, dim=self.dim)
14
+
15
+ return out * gate.sigmoid()
16
+
17
+
18
+ class conform_conv(nn.Module):
19
+ def __init__(
20
+ self, channels: int, kernel_size: int = 31, DropoutL=0.1, bias: bool = True
21
+ ):
22
+ super().__init__()
23
+ self.act2 = nn.SiLU()
24
+ self.act1 = GLU(1)
25
+
26
+ self.pointwise_conv1 = nn.Conv1d(
27
+ channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
28
+ )
29
+
30
+ # This depthwise convolution is symmetric (non-causal): the kernel size is
31
+ # required to be odd so that (kernel_size - 1) // 2 frames of padding can be
32
+ # applied on both sides, keeping the output length equal to the input length.
33
+ # (A causal variant would instead left-pad the input by kernel_size - 1 frames
34
+ # and use no right padding.)
35
+
36
+ assert (kernel_size - 1) % 2 == 0
37
+ padding = (kernel_size - 1) // 2
38
+
39
+ self.depthwise_conv = nn.Conv1d(
40
+ channels,
41
+ channels,
42
+ kernel_size,
43
+ stride=1,
44
+ padding=padding,
45
+ groups=channels,
46
+ bias=bias,
47
+ )
48
+
49
+ self.norm = nn.BatchNorm1d(channels)
50
+
51
+ self.pointwise_conv2 = nn.Conv1d(
52
+ channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
53
+ )
54
+ self.drop = nn.Dropout(DropoutL) if DropoutL > 0.0 else nn.Identity()
55
+
56
+ def forward(self, x):
57
+ x = x.transpose(1, 2)
58
+ x = self.act1(self.pointwise_conv1(x))
59
+ x = self.depthwise_conv(x)
60
+ x = self.norm(x)
61
+ x = self.act2(x)
62
+ x = self.pointwise_conv2(x)
63
+ return self.drop(x).transpose(1, 2)
64
+
65
+
66
+ class Attention(nn.Module):
67
+ def __init__(self, dim, heads=4, dim_head=32, conditiondim=None):
68
+ super().__init__()
69
+ if conditiondim is None:
70
+ conditiondim = dim
71
+
72
+ self.scale = dim_head**-0.5
73
+ self.heads = heads
74
+ hidden_dim = dim_head * heads
75
+ self.to_q = nn.Linear(dim, hidden_dim, bias=False)
76
+ self.to_kv = nn.Linear(conditiondim, hidden_dim * 2, bias=False)
77
+
78
+ self.to_out = nn.Sequential(
79
+ nn.Linear(
80
+ hidden_dim,
81
+ dim,
82
+ ),
83
+ )
84
+
85
+ def forward(self, q, kv=None, mask=None):
86
+ # b, c, h, w = x.shape
87
+ if kv is None:
88
+ kv = q
89
+ # q, kv = map(
90
+ # lambda t: rearrange(t, "b c t -> b t c", ), (q, kv)
91
+ # )
92
+
93
+ q = self.to_q(q)
94
+ k, v = self.to_kv(kv).chunk(2, dim=2)
95
+
96
+ q, k, v = map(
97
+ lambda t: rearrange(t, "b t (h c) -> b h t c", h=self.heads), (q, k, v)
98
+ )
99
+
100
+ if mask is not None:
101
+ mask = mask.unsqueeze(1).unsqueeze(1)
102
+
103
+ with torch.backends.cuda.sdp_kernel(enable_math=False):
104
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
105
+
106
+ out = rearrange(
107
+ out,
108
+ "b h t c -> b t (h c) ",
109
+ h=self.heads,
110
+ )
111
+ return self.to_out(out)
112
+
113
+
114
+ class conform_ffn(nn.Module):
115
+ def __init__(self, dim, DropoutL1: float = 0.1, DropoutL2: float = 0.1):
116
+ super().__init__()
117
+ self.ln1 = nn.Linear(dim, dim * 4)
118
+ self.ln2 = nn.Linear(dim * 4, dim)
119
+ self.drop1 = nn.Dropout(DropoutL1) if DropoutL1 > 0.0 else nn.Identity()
120
+ self.drop2 = nn.Dropout(DropoutL2) if DropoutL2 > 0.0 else nn.Identity()
121
+ self.act = nn.SiLU()
122
+
123
+ def forward(self, x):
124
+ x = self.ln1(x)
125
+ x = self.act(x)
126
+ x = self.drop1(x)
127
+ x = self.ln2(x)
128
+ return self.drop2(x)
129
+
130
+
131
+ class conform_blocke(nn.Module):
132
+ def __init__(
133
+ self,
134
+ dim: int,
135
+ kernel_size: int = 31,
136
+ conv_drop: float = 0.1,
137
+ ffn_latent_drop: float = 0.1,
138
+ ffn_out_drop: float = 0.1,
139
+ attention_drop: float = 0.1,
140
+ attention_heads: int = 4,
141
+ attention_heads_dim: int = 64,
142
+ ):
143
+ super().__init__()
144
+ self.ffn1 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop)
145
+ self.ffn2 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop)
146
+ self.att = Attention(dim, heads=attention_heads, dim_head=attention_heads_dim)
147
+ self.attdrop = (
148
+ nn.Dropout(attention_drop) if attention_drop > 0.0 else nn.Identity()
149
+ )
150
+ self.conv = conform_conv(
151
+ dim,
152
+ kernel_size=kernel_size,
153
+ DropoutL=conv_drop,
154
+ )
155
+ self.norm1 = nn.LayerNorm(dim)
156
+ self.norm2 = nn.LayerNorm(dim)
157
+ self.norm3 = nn.LayerNorm(dim)
158
+ self.norm4 = nn.LayerNorm(dim)
159
+ self.norm5 = nn.LayerNorm(dim)
160
+
161
+ def forward(
162
+ self,
163
+ x,
164
+ mask=None,
165
+ ):
166
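+ # Conformer-style Macaron structure: two feed-forward modules with half-step residuals wrap the attention and convolution modules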
+ x = self.ffn1(self.norm1(x)) * 0.5 + x
167
+
168
+ x = self.attdrop(self.att(self.norm2(x), mask=mask)) + x
169
+ x = self.conv(self.norm3(x)) + x
170
+ x = self.ffn2(self.norm4(x)) * 0.5 + x
171
+ return self.norm5(x)
172
+
173
+ # return x
174
+
175
+
176
+ class Gcf(nn.Module):
177
+ def __init__(
178
+ self,
179
+ dim: int,
180
+ kernel_size: int = 31,
181
+ conv_drop: float = 0.1,
182
+ ffn_latent_drop: float = 0.1,
183
+ ffn_out_drop: float = 0.1,
184
+ attention_drop: float = 0.1,
185
+ attention_heads: int = 4,
186
+ attention_heads_dim: int = 64,
187
+ ):
188
+ super().__init__()
189
+ self.att1 = conform_blocke(
190
+ dim=dim,
191
+ kernel_size=kernel_size,
192
+ conv_drop=conv_drop,
193
+ ffn_latent_drop=ffn_latent_drop,
194
+ ffn_out_drop=ffn_out_drop,
195
+ attention_drop=attention_drop,
196
+ attention_heads=attention_heads,
197
+ attention_heads_dim=attention_heads_dim,
198
+ )
199
+ self.att2 = conform_blocke(
200
+ dim=dim,
201
+ kernel_size=kernel_size,
202
+ conv_drop=conv_drop,
203
+ ffn_latent_drop=ffn_latent_drop,
204
+ ffn_out_drop=ffn_out_drop,
205
+ attention_drop=attention_drop,
206
+ attention_heads=attention_heads,
207
+ attention_heads_dim=attention_heads_dim,
208
+ )
209
+ self.glu1 = nn.Sequential(nn.Linear(dim, dim * 2), GLU(2))
210
+ self.glu2 = nn.Sequential(nn.Linear(dim, dim * 2), GLU(2))
211
+
212
+ def forward(self, midi, bound):
213
+ midi = self.att1(midi)
214
+ bound = self.att2(bound)
215
+ midis = self.glu1(midi)
216
+ bounds = self.glu2(bound)
217
+ return midi + bounds, bound + midis
218
+
219
+
220
+ class Gmidi_conform(nn.Module):
221
+ def __init__(
222
+ self,
223
+ lay: int,
224
+ dim: int,
225
+ indim: int,
226
+ outdim: int,
227
+ use_lay_skip: bool,
228
+ kernel_size: int = 31,
229
+ conv_drop: float = 0.1,
230
+ ffn_latent_drop: float = 0.1,
231
+ ffn_out_drop: float = 0.1,
232
+ attention_drop: float = 0.1,
233
+ attention_heads: int = 4,
234
+ attention_heads_dim: int = 64,
235
+ ):
236
+ super().__init__()
237
+
238
+ self.inln = nn.Linear(indim, dim)
239
+ self.inln1 = nn.Linear(indim, dim)
240
+ self.outln = nn.Linear(dim, outdim)
241
+ self.cutheard = nn.Linear(dim, 1)
242
+ # self.cutheard = nn.Linear(dim, outdim)
243
+ self.lay = lay
244
+ self.use_lay_skip = use_lay_skip
245
+ self.cf_lay = nn.ModuleList(
246
+ [
247
+ Gcf(
248
+ dim=dim,
249
+ kernel_size=kernel_size,
250
+ conv_drop=conv_drop,
251
+ ffn_latent_drop=ffn_latent_drop,
252
+ ffn_out_drop=ffn_out_drop,
253
+ attention_drop=attention_drop,
254
+ attention_heads=attention_heads,
255
+ attention_heads_dim=attention_heads_dim,
256
+ )
257
+ for _ in range(lay)
258
+ ]
259
+ )
260
+ self.att1 = conform_blocke(
261
+ dim=dim,
262
+ kernel_size=kernel_size,
263
+ conv_drop=conv_drop,
264
+ ffn_latent_drop=ffn_latent_drop,
265
+ ffn_out_drop=ffn_out_drop,
266
+ attention_drop=attention_drop,
267
+ attention_heads=attention_heads,
268
+ attention_heads_dim=attention_heads_dim,
269
+ )
270
+ self.att2 = conform_blocke(
271
+ dim=dim,
272
+ kernel_size=kernel_size,
273
+ conv_drop=conv_drop,
274
+ ffn_latent_drop=ffn_latent_drop,
275
+ ffn_out_drop=ffn_out_drop,
276
+ attention_drop=attention_drop,
277
+ attention_heads=attention_heads,
278
+ attention_heads_dim=attention_heads_dim,
279
+ )
280
+
281
+ def forward(self, x, mask=None):
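+ # Two parallel streams: x is refined toward MIDI pitch logits, x1 toward note-boundary
+ # logits; each Gcf layer exchanges gated information between the two streams.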
282
+ x1 = x.clone()
283
+
284
+ x = self.inln(x)
285
+ x1 = self.inln1(x1)
286
+ if mask is not None:
287
+ x = x.masked_fill(~mask.unsqueeze(-1), 0)
288
+ for idx, i in enumerate(self.cf_lay):
289
+ x, x1 = i(x, x1)
290
+
291
+ if mask is not None:
292
+ x = x.masked_fill(~mask.unsqueeze(-1), 0)
293
+ x, x1 = self.att1(x), self.att2(x1)
294
+
295
+ cutprp = self.cutheard(x1)
296
+ midiout = self.outln(x)
297
+
298
+ return midiout, cutprp
src/YingMusicSinger/melody/Gconv.py ADDED
@@ -0,0 +1,60 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ class GLU(nn.Module):
5
+ def __init__(self, dim):
6
+ super().__init__()
7
+ self.dim = dim
8
+
9
+ def forward(self, x):
10
+ out, gate = x.chunk(2, dim=self.dim)
11
+
12
+ return out * gate.sigmoid()
13
+
14
+
15
+ class conform_conv(nn.Module):
16
+ def __init__(
17
+ self, channels: int, kernel_size: int = 31, DropoutL=0.1, bias: bool = True
18
+ ):
19
+ super().__init__()
20
+ self.act2 = nn.SiLU()
21
+ self.act1 = GLU(1)
22
+
23
+ self.pointwise_conv1 = nn.Conv1d(
24
+ channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
25
+ )
26
+
27
+ # This depthwise convolution is symmetric (non-causal): the kernel size is
28
+ # required to be odd so that (kernel_size - 1) // 2 frames of padding can be
29
+ # applied on both sides, keeping the output length equal to the input length.
30
+ # (A causal variant would instead left-pad the input by kernel_size - 1 frames
31
+ # and use no right padding.)
32
+
33
+ assert (kernel_size - 1) % 2 == 0
34
+ padding = (kernel_size - 1) // 2
35
+
36
+ self.depthwise_conv = nn.Conv1d(
37
+ channels,
38
+ channels,
39
+ kernel_size,
40
+ stride=1,
41
+ padding=padding,
42
+ groups=channels,
43
+ bias=bias,
44
+ )
45
+
46
+ self.norm = nn.BatchNorm1d(channels)
47
+
48
+ self.pointwise_conv2 = nn.Conv1d(
49
+ channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
50
+ )
51
+ self.drop = nn.Dropout(DropoutL) if DropoutL > 0.0 else nn.Identity()
52
+
53
+ def forward(self, x):
54
+ x = x.transpose(1, 2)
55
+ x = self.act1(self.pointwise_conv1(x))
56
+ x = self.depthwise_conv(x)
57
+ x = self.norm(x)
58
+ x = self.act2(x)
59
+ x = self.pointwise_conv2(x)
60
+ return self.drop(x).transpose(1, 2)
src/YingMusicSinger/melody/SmoothMelody.py ADDED
@@ -0,0 +1,144 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class MIDIFuzzDisturb(nn.Module):
6
+ """Applies fuzzing perturbations to MIDI latent representations.
7
+
8
+ The raw MIDI teacher model output preserves good prosody but causes
9
+ pronunciation interference. This module mitigates that by applying
10
+ blur, temporal dropout, and noise to the melody latent.
11
+ """
12
+
13
+ def __init__(
14
+ self, dim=128, drop_prob=0.3, noise_scale=0.1, blur_kernel=3, drop_type="random"
15
+ ):
16
+ super().__init__()
17
+ self.blur = None
18
+ self.drop_prob = None
19
+ self.noise_scale = None
20
+ self.dim = dim
21
+ self.drop_type = drop_type
22
+
23
+ assert drop_prob is not None
24
+ assert drop_type is not None
25
+ if drop_type == "random":
26
+ # drop_prob is a float
27
+ if drop_prob != 0:
28
+ self.drop_prob = drop_prob
29
+ elif drop_type == "equal_space":
30
+ # drop_prob is a [drop, keep] list, e.g., [1, 1] means 1 frame drop, 1 frame keep
31
+ self.drop_prob = drop_prob
32
+ else:
33
+ raise ValueError(f"Unknown drop_type: {drop_type}")
34
+
35
+ if noise_scale != 0:
36
+ self.noise_scale = noise_scale
37
+ if blur_kernel != 0:
38
+ assert blur_kernel % 2 == 1, f"blur_kernel {blur_kernel} must be odd"
39
+ self.blur = nn.AvgPool1d(
40
+ kernel_size=blur_kernel, stride=1, padding=blur_kernel // 2
41
+ )
42
+
43
+ def _create_equal_space_mask(self, batch_size, seq_len, device):
44
+ """Create an equally-spaced mask cycling [drop, keep] frames."""
45
+ drop_frames, keep_frames = self.drop_prob
46
+ cycle_len = drop_frames + keep_frames
47
+
48
+ # Pattern: first drop_frames are 0 (drop), next keep_frames are 1 (keep)
49
+ pattern = torch.cat(
50
+ [
51
+ torch.zeros(drop_frames, device=device),
52
+ torch.ones(keep_frames, device=device),
53
+ ]
54
+ )
55
+
56
+ # Repeat pattern to cover the full sequence length
57
+ num_repeats = (seq_len + cycle_len - 1) // cycle_len
58
+ mask = pattern.repeat(num_repeats)[:seq_len] # [T]
59
+
60
+ # Expand to [B, T, 1]
61
+ mask = mask.view(1, seq_len, 1).expand(batch_size, -1, -1)
62
+
63
+ return mask
64
+
65
+ def forward(self, x):
66
+ # x: [B, T, D=128], pre-sigmoid logits
67
+ x = torch.sigmoid(x)
68
+
69
+ assert x.shape[-1] == self.dim, (
70
+ f"MIDIFuzzDisturb: expected dim={self.dim}, got {x.shape[-1]}"
71
+ )
72
+
73
+ if self.blur:
74
+ x = self.blur(x.transpose(1, 2)).transpose(1, 2)
75
+
76
+ if self.drop_prob:
77
+ if self.drop_type == "random":
78
+ time_mask = (
79
+ torch.rand(x.shape[0], x.shape[1], 1, device=x.device)
80
+ > self.drop_prob
81
+ )
82
+ x = x * time_mask.float()
83
+ elif self.drop_type == "equal_space":
84
+ time_mask = self._create_equal_space_mask(
85
+ x.shape[0], x.shape[1], x.device
86
+ )
87
+ x = x * time_mask.float()
88
+ else:
89
+ raise ValueError(f"Unknown drop_type: {self.drop_type}")
90
+
91
+ if self.noise_scale:
92
+ noise = torch.randn_like(x) * self.noise_scale
93
+ x = x + noise
94
+
95
+ return x
96
+
97
+
98
+ class MIDIDigitalEmbedding(nn.Module):
99
+ """Embeds continuous MIDI values into discrete token embeddings.
100
+
101
+ Continuous MIDI values in [0, 127] are quantized at a configurable
102
+ resolution (mark_distinguish_scale) and mapped to learned embeddings.
103
+ """
104
+
105
+ def __init__(self, embed_dim=128, num_classes=128, mark_distinguish_scale=2):
106
+ super().__init__()
107
+
108
+ # num_classes covers the input range [0, 127] plus 2 special tokens
109
+ self.num_classes = num_classes + 2
110
+ self.mark_distinguish_scale = mark_distinguish_scale
111
+ self.embedding_input_num_class = self.num_classes * self.mark_distinguish_scale
112
+ self.embedding = nn.Embedding(self.embedding_input_num_class, embed_dim)
113
+
114
+ def midi_to_class(self, midi_values):
115
+ """Map continuous MIDI values to discrete class indices.
116
+
117
+ Args:
118
+ midi_values: [B, T] continuous MIDI values, roughly in [0, 127]
119
+
120
+ Returns:
121
+ class_indices: [B, T] discrete class indices
122
+ """
123
+ # Round to nearest quantization step
124
+ # e.g., with scale=2: 0->0, 0.3->1, 0.5->1, 0.8->2, 1.0->2, ...
125
+ class_indices = torch.round(midi_values * self.mark_distinguish_scale).long()
126
+
127
+ # Clamp to valid range
128
+ class_indices = torch.clamp(
129
+ class_indices, 0, self.embedding_input_num_class - 1
130
+ )
131
+
132
+ return class_indices
133
+
134
+ def forward(self, midi_values):
135
+ """
136
+ Args:
137
+ midi_values: [B, T] continuous MIDI values
138
+
139
+ Returns:
140
+ embeddings: [B, T, embed_dim] embedding vectors
141
+ """
142
+ class_indices = self.midi_to_class(midi_values)
143
+ embeddings = self.embedding(class_indices)
144
+ return embeddings
src/YingMusicSinger/melody/midi_extractor.py ADDED
@@ -0,0 +1,208 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils.rnn import pad_sequence
5
+
6
+ from src.YingMusicSinger.melody.Gconform import Gmidi_conform
7
+
8
+ # midi decoding utils
9
+
10
+
11
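+ # Decode per-frame pitch from Gaussian-blurred probability bins: restrict to a window of
+ # +/- 3 sigma around the argmax bin, take the probability-weighted mean of the bin values,
+ # and flag frames whose peak probability falls below `threshold` as rests.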
+ def decode_gaussian_blurred_probs(probs, vmin, vmax, deviation, threshold):
12
+ num_bins = int(probs.shape[-1])
13
+ interval = (vmax - vmin) / (num_bins - 1)
14
+ width = int(3 * deviation / interval) # 3 * sigma
15
+ idx = torch.arange(num_bins, device=probs.device)[None, None, :] # [1, 1, N]
16
+ idx_values = idx * interval + vmin
17
+ center = torch.argmax(probs, dim=-1, keepdim=True) # [B, T, 1]
18
+ start = torch.clip(center - width, min=0) # [B, T, 1]
19
+ end = torch.clip(center + width + 1, max=num_bins) # [B, T, 1]
20
+ idx_masks = (idx >= start) & (idx < end) # [B, T, N]
21
+ weights = probs * idx_masks # [B, T, N]
22
+ product_sum = torch.sum(weights * idx_values, dim=2) # [B, T]
23
+ weight_sum = torch.sum(weights, dim=2) # [B, T]
24
+ values = product_sum / (
25
+ weight_sum + (weight_sum == 0)
26
+ ) # avoid dividing by zero, [B, T]
27
+ rest = probs.max(dim=-1)[0] < threshold # [B, T]
28
+ return values, rest
29
+
30
+
31
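+ # Convert frame-wise boundary probabilities into a frame-to-note index map: accumulate the
+ # probabilities, treat integer increments of the rounded cumulative sum as note onsets, and
+ # count onsets cumulatively so every frame is assigned the index of the note it belongs to.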
+ def decode_bounds_to_alignment(bounds, use_diff=True):
32
+ bounds_step = bounds.cumsum(dim=1).round().long()
33
+ if use_diff:
34
+ bounds_inc = (
35
+ torch.diff(
36
+ bounds_step,
37
+ dim=1,
38
+ prepend=torch.full(
39
+ (bounds.shape[0], 1),
40
+ fill_value=-1,
41
+ dtype=bounds_step.dtype,
42
+ device=bounds_step.device,
43
+ ),
44
+ )
45
+ > 0
46
+ )
47
+ else:
48
+ bounds_inc = F.pad(
49
+ (bounds_step[:, 1:] > bounds_step[:, :-1]), [1, 0], value=True
50
+ )
51
+ frame2item = bounds_inc.long().cumsum(dim=1)
52
+ return frame2item
53
+
54
+
55
+ def decode_note_sequence(frame2item, values, masks, threshold=0.5):
56
+ """
57
+
58
+ :param frame2item: [1, 1, 1, 1, 2, 2, 3, 3, 3]
59
+ :param values:
60
+ :param masks:
61
+ :param threshold: minimum ratio of unmasked frames required to be regarded as an unmasked item
62
+ :return: item_values, item_dur, item_masks
63
+ """
64
+ b = frame2item.shape[0]
65
+ space = frame2item.max() + 1
66
+
67
+ item_dur = frame2item.new_zeros(b, space, dtype=frame2item.dtype).scatter_add(
68
+ 1, frame2item, torch.ones_like(frame2item)
69
+ )[:, 1:]
70
+ item_unmasked_dur = frame2item.new_zeros(
71
+ b, space, dtype=frame2item.dtype
72
+ ).scatter_add(1, frame2item, masks.long())[:, 1:]
73
+ item_masks = item_unmasked_dur / item_dur >= threshold
74
+
75
+ values_quant = values.round().long()
76
+ histogram = (
77
+ frame2item.new_zeros(b, space * 128, dtype=frame2item.dtype)
78
+ .scatter_add(
79
+ 1, frame2item * 128 + values_quant, torch.ones_like(frame2item) * masks
80
+ )
81
+ .unflatten(1, [space, 128])[:, 1:, :]
82
+ )
83
+ item_values_center = histogram.float().argmax(dim=2).to(dtype=values.dtype)
84
+ values_center = torch.gather(F.pad(item_values_center, [1, 0]), 1, frame2item)
85
+ values_near_center = (
86
+ masks & (values >= values_center - 0.5) & (values <= values_center + 0.5)
87
+ )
88
+ item_valid_dur = frame2item.new_zeros(b, space, dtype=frame2item.dtype).scatter_add(
89
+ 1, frame2item, values_near_center.long()
90
+ )[:, 1:]
91
+ item_values = values.new_zeros(b, space, dtype=values.dtype).scatter_add(
92
+ 1, frame2item, values * values_near_center
93
+ )[:, 1:] / (item_valid_dur + (item_valid_dur == 0))
94
+
95
+ return item_values, item_dur, item_masks
96
+
97
+
98
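+ # Expand note-level values back to frame level: repeat each value by its duration in frames
+ # (counts_tensor), then re-pad the resulting ragged per-sample sequences into a batch tensor.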
+ def expand_batch_padded(feature_tensor, counts_tensor, padding_value=0.0):
99
+ assert feature_tensor.dim() == 2 and counts_tensor.dim() == 2
100
+
101
+ lengths = torch.sum(counts_tensor, dim=1)
102
+
103
+ feature_tensor = feature_tensor.reshape(-1)
104
+ counts_tensor = counts_tensor.reshape(-1)
105
+ expanded_flat = torch.repeat_interleave(feature_tensor, counts_tensor)
106
+
107
+ ragged_list = torch.split(expanded_flat, lengths.tolist())
108
+
109
+ padded_tensor = pad_sequence(
110
+ ragged_list, batch_first=True, padding_value=padding_value
111
+ )
112
+
113
+ return padded_tensor, lengths
114
+
115
+
116
+ class midi_loss(nn.Module):
117
+ def __init__(self):
118
+ super().__init__()
119
+ self.loss = nn.BCELoss()
120
+
121
+ def forward(self, x, target):
122
+ midiout, cutp = x
123
+ midi_target, cutp_target = target
124
+
125
+ cutploss = self.loss(cutp, cutp_target)
126
+ midiloss = self.loss(midiout, midi_target)
127
+ return midiloss, cutploss
128
+
129
+
130
+ class MIDIExtractor(nn.Module):
131
+ def __init__(self, in_dim=None, out_dim=None):
132
+ super().__init__()
133
+
134
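+ # Conformer-based pitch extractor: consumes 80-dimensional (mel-spectrogram) frames and predicts 128 MIDI pitch bins plus a note-boundary logit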
+ cfg = {
135
+ "attention_drop": 0.1,
136
+ "attention_heads": 8,
137
+ "attention_heads_dim": 64,
138
+ "conv_drop": 0.1,
139
+ "dim": 512,
140
+ "ffn_latent_drop": 0.1,
141
+ "ffn_out_drop": 0.1,
142
+ "kernel_size": 31,
143
+ "lay": 8,
144
+ "use_lay_skip": True,
145
+ "indim": 80,
146
+ "outdim": 128,
147
+ }
148
+ if in_dim is not None:
149
+ cfg["indim"] = in_dim
150
+ if out_dim is not None:
151
+ cfg["outdim"] = out_dim
152
+
153
+ self.midi_conform = Gmidi_conform(**cfg)
154
+
155
+ self.midi_min = 0
156
+ self.midi_max = 127
157
+ self.midi_deviation = 1.0
158
+ self.rest_threshold = 0.1
159
+
160
+ def _load_form_ckpt(self, ckpt_path, device="cpu"):
161
+ from collections import OrderedDict
162
+
163
+ if ckpt_path is None:
164
+ raise ValueError("midi_extractor_path is required")
165
+
166
+ state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
167
+ prefix_in_ckpt = "model.model"
168
+ state_dict = OrderedDict(
169
+ {
170
+ k.replace(f"{prefix_in_ckpt}.", "midi_conform."): v
171
+ for k, v in state_dict.items()
172
+ if k.startswith(f"{prefix_in_ckpt}.")
173
+ }
174
+ )
175
+ self.load_state_dict(state_dict, strict=True)
176
+ # self.to(device)
177
+
178
+ def forward(self, x, mask=None):
179
+ midi, bound = self.midi_conform(x, mask)
180
+
181
+ return midi, bound
182
+
183
+ def postprocess(self, midi, bounds, with_expand=False):
184
+ probs = torch.sigmoid(midi)
185
+
186
+ bound_probs = torch.sigmoid(bounds)
187
+ bound_probs = torch.squeeze(bound_probs, -1)
188
+
189
+ masks = torch.ones_like(bound_probs).bool()
190
+ # Avoid in-place ops on tensors needed for autograd (outputs of SigmoidBackward)
191
+ probs = probs * masks[..., None]
192
+ bound_probs = bound_probs * masks
193
+ unit2note_pred = decode_bounds_to_alignment(bound_probs) * masks
194
+ midi_pred, rest_pred = decode_gaussian_blurred_probs(
195
+ probs,
196
+ vmin=self.midi_min,
197
+ vmax=self.midi_max,
198
+ deviation=self.midi_deviation,
199
+ threshold=self.rest_threshold,
200
+ )
201
+ note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence(
202
+ unit2note_pred, midi_pred, ~rest_pred & masks
203
+ )
204
+ if not with_expand:
205
+ return note_midi_pred, note_dur_pred
206
+
207
+ note_midi_expand, _ = expand_batch_padded(note_midi_pred, note_dur_pred)
208
+ return note_midi_expand, None
src/YingMusicSinger/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .dit import DiT
src/YingMusicSinger/models/dit.py ADDED
@@ -0,0 +1,472 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+ from x_transformers.x_transformers import RotaryEmbedding
16
+
17
+ from src.YingMusicSinger.models.modules import (
18
+ AdaLayerNorm_Final,
19
+ ConvNeXtV2Block,
20
+ ConvPositionEmbedding,
21
+ DiTBlock,
22
+ TimestepGuidanceEmbedding,
23
+ get_pos_embed_indices,
24
+ precompute_freqs_cis,
25
+ )
26
+
27
+
28
+ # Text embedding
29
+
30
+
31
+ class TextEmbedding(nn.Module):
32
+ def __init__(
33
+ self,
34
+ text_num_embeds,
35
+ text_dim,
36
+ mask_padding=False,
37
+ average_upsampling=False,
38
+ conv_layers=0,
39
+ conv_mult=2,
40
+ ):
41
+ super().__init__()
42
+ self.text_embed = nn.Embedding(
43
+ text_num_embeds + 1, text_dim
44
+ ) # index 0 reserved as filler token
45
+
46
+ self.mask_padding = mask_padding
47
+ self.average_upsampling = average_upsampling # ZipVoice-style late average upsampling (after text encoder)
48
+ if average_upsampling:
49
+ assert mask_padding, (
50
+ "text_embedding_average_upsampling requires text_mask_padding to be True"
51
+ )
52
+
53
+ if conv_layers > 0:
54
+ self.extra_modeling = True
55
+ self.precompute_max_pos = 4096 # ~44s of 24kHz audio
56
+ self.register_buffer(
57
+ "freqs_cis",
58
+ precompute_freqs_cis(text_dim, self.precompute_max_pos),
59
+ persistent=False,
60
+ )
61
+ self.text_blocks = nn.Sequential(
62
+ *[
63
+ ConvNeXtV2Block(text_dim, text_dim * conv_mult)
64
+ for _ in range(conv_layers)
65
+ ]
66
+ )
67
+ else:
68
+ self.extra_modeling = False
69
+
70
+ print(
71
+ f"[info] TextEmbedding: mask_padding={mask_padding}, average_upsampling={average_upsampling}, conv_layers={conv_layers}"
72
+ )
73
+
74
+ def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
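+ # ZipVoice-style late upsampling: evenly repeat the valid (non-padding) text embeddings of
+ # each sample so they span its audio length; trailing tokens receive one extra repeat when
+ # the audio length is not an exact multiple of the number of valid tokens.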
75
+ batch, text_len, text_dim = text.shape
76
+
77
+ if audio_mask is None:
78
+ audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
79
+ valid_mask = audio_mask & text_mask
80
+ audio_lens = audio_mask.sum(dim=1) # [batch]
81
+ valid_lens = valid_mask.sum(dim=1) # [batch]
82
+
83
+ upsampled_text = torch.zeros_like(text)
84
+
85
+ for i in range(batch):
86
+ audio_len = audio_lens[i].item()
87
+ valid_len = valid_lens[i].item()
88
+
89
+ if valid_len == 0:
90
+ continue
91
+
92
+ valid_ind = torch.where(valid_mask[i])[0]
93
+ valid_data = text[i, valid_ind, :] # [valid_len, text_dim]
94
+
95
+ base_repeat = audio_len // valid_len
96
+ remainder = audio_len % valid_len
97
+
98
+ indices = []
99
+ for j in range(valid_len):
100
+ repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
101
+ indices.extend([j] * repeat_count)
102
+
103
+ indices = torch.tensor(
104
+ indices[:audio_len], device=text.device, dtype=torch.long
105
+ )
106
+ upsampled = valid_data[indices] # [audio_len, text_dim]
107
+
108
+ upsampled_text[i, :audio_len, :] = upsampled
109
+
110
+ return upsampled_text
111
+
112
+ def forward(
113
+ self,
114
+ text: int["b nt"],
115
+ seq_len,
116
+ drop_text=False,
117
+ audio_mask: bool["b n"] | None = None,
118
+ ): # noqa: F722
119
+ # Text tokens start from 0; shift by 1 so that 0 is never a valid token
120
+ text = text + 1
121
+ # Note: 1 is used as the PAD token
122
+ text = text[
123
+ :, :seq_len
124
+ ] # Truncate if text tokens exceed mel spectrogram length
125
+ batch, text_len = text.shape[0], text.shape[1]
126
+ text = F.pad(text, (0, seq_len - text_len), value=1)
127
+
128
+ if self.mask_padding:
129
+ text_mask = text == 1
130
+ else:
131
+ text_mask = torch.zeros(
132
+ (batch, seq_len), device=text.device, dtype=torch.bool
133
+ )
134
+
135
+ if drop_text: # CFG for text
136
+ text = torch.zeros_like(text)
137
+
138
+ text = self.text_embed(text) # b n -> b n d
139
+
140
+ # Optional extra modeling
141
+ if self.extra_modeling:
142
+ # Sinusoidal positional embedding
143
+ batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long)
144
+ pos_idx = get_pos_embed_indices(
145
+ batch_start, seq_len, max_pos=self.precompute_max_pos
146
+ )
147
+ text_pos_embed = self.freqs_cis[pos_idx]
148
+ text = text + text_pos_embed
149
+
150
+ # ConvNeXtV2 blocks
151
+ if self.mask_padding:
152
+ text = text.masked_fill(
153
+ text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
154
+ )
155
+ for block in self.text_blocks:
156
+ text = block(text)
157
+ text = text.masked_fill(
158
+ text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
159
+ )
160
+ else:
161
+ text = self.text_blocks(text)
162
+
163
+ if self.average_upsampling:
164
+ text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
165
+
166
+ return text, text_mask
167
+
168
+
169
+ # Noised input audio and context mixing embedding
170
+
171
+
172
+ class InputEmbedding(nn.Module):
173
+ def __init__(self, mel_dim, text_dim, out_dim, midi_dim=128):
174
+ super().__init__()
175
+ self.proj = nn.Linear(mel_dim * 2 + text_dim + midi_dim, out_dim)
176
+ self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
177
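+ # Melody feature projection; zero-initialized in DiT.initialize_weights so the melody branch starts as a no-op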
+ self.midi_proj = nn.Linear(128, 128)
178
+
179
+ def forward(
180
+ self,
181
+ x: float["b n d"], # noqa: F722
182
+ cond: float["b n d"], # noqa: F722
183
+ text_embed: float["b n d"], # noqa: F722
184
+ midi,
185
+ drop_audio_cond=False,
186
+ drop_midi=False,
187
+ ):
188
+ if drop_audio_cond: # CFG for conditioning audio
189
+ cond = torch.zeros_like(cond)
190
+
191
+ midi = self.midi_proj(midi)
192
+
193
+ if drop_midi: # CFG for melody
194
+ midi = torch.zeros_like(midi)
195
+
196
+ x = self.proj(torch.cat((x, cond, text_embed, midi), dim=-1))
197
+ x = self.conv_pos_embed(x) + x
198
+ return x
199
+
200
+
201
+ # Transformer backbone using DiT blocks
202
+
203
+
204
+ class DiT(nn.Module):
205
+ def __init__(
206
+ self,
207
+ *,
208
+ dim,
209
+ depth=8,
210
+ heads=8,
211
+ dim_head=64,
212
+ dropout=0.1,
213
+ ff_mult=4,
214
+ mel_dim=100,
215
+ text_num_embeds=256,
216
+ text_dim=None,
217
+ n_f0_bins=512,
218
+ text_mask_padding=True,
219
+ text_embedding_average_upsampling=False,
220
+ qk_norm=None,
221
+ conv_layers=0,
222
+ pe_attn_head=None,
223
+ attn_backend="torch", # "torch" | "flash_attn"
224
+ attn_mask_enabled=False,
225
+ long_skip_connection=False,
226
+ checkpoint_activations=False,
227
+ use_guidance_scale_embed: bool = False,
228
+ guidance_scale_embed_dim: int = 192,
229
+ ):
230
+ super().__init__()
231
+
232
+ self.time_embed = TimestepGuidanceEmbedding(
233
+ dim,
234
+ use_guidance_scale_embed=use_guidance_scale_embed,
235
+ guidance_scale_embed_dim=guidance_scale_embed_dim,
236
+ )
237
+ if text_dim is None:
238
+ text_dim = mel_dim
239
+ self.text_embed_p = TextEmbedding(
240
+ text_num_embeds,
241
+ text_dim,
242
+ mask_padding=text_mask_padding,
243
+ average_upsampling=text_embedding_average_upsampling,
244
+ conv_layers=conv_layers,
245
+ )
246
+ self.text_cond, self.text_uncond = None, None # text cache
247
+ self.input_embed_with_midi = InputEmbedding(mel_dim, text_dim, dim)
248
+
249
+ self.rotary_embed = RotaryEmbedding(dim_head)
250
+ self.use_guidance_scale_embed = use_guidance_scale_embed
251
+
252
+ self.dim = dim
253
+ self.depth = depth
254
+
255
+ self.transformer_blocks = nn.ModuleList(
256
+ [
257
+ DiTBlock(
258
+ dim=dim,
259
+ heads=heads,
260
+ dim_head=dim_head,
261
+ ff_mult=ff_mult,
262
+ dropout=dropout,
263
+ qk_norm=qk_norm,
264
+ pe_attn_head=pe_attn_head,
265
+ attn_backend=attn_backend,
266
+ attn_mask_enabled=attn_mask_enabled,
267
+ )
268
+ for _ in range(depth)
269
+ ]
270
+ )
271
+ self.long_skip_connection = (
272
+ nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
273
+ )
274
+
275
+ self.norm_out = AdaLayerNorm_Final(dim) # Final modulation
276
+ self.proj_out = nn.Linear(dim, mel_dim)
277
+
278
+ self.checkpoint_activations = checkpoint_activations
279
+
280
+ self.initialize_weights()
281
+
282
+ def initialize_weights(self):
283
+ # Zero-out AdaLN layers in DiT blocks
284
+ for block in self.transformer_blocks:
285
+ nn.init.constant_(block.attn_norm.linear.weight, 0)
286
+ nn.init.constant_(block.attn_norm.linear.bias, 0)
287
+
288
+ # Zero-out output layers
289
+ nn.init.constant_(self.norm_out.linear.weight, 0)
290
+ nn.init.constant_(self.norm_out.linear.bias, 0)
291
+ nn.init.constant_(self.proj_out.weight, 0)
292
+ nn.init.constant_(self.proj_out.bias, 0)
293
+
294
+ nn.init.zeros_(self.input_embed_with_midi.midi_proj.weight)
295
+ nn.init.zeros_(self.input_embed_with_midi.midi_proj.bias)
296
+
297
+ def ckpt_wrapper(self, module):
298
+ # Ref: https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
299
+ def ckpt_forward(*inputs):
300
+ outputs = module(*inputs)
301
+ return outputs
302
+
303
+ return ckpt_forward
304
+
305
+ def get_input_embed(
306
+ self,
307
+ x, # b n d
308
+ cond, # b n d
309
+ text, # b nt
310
+ midi, # b n
311
+ drop_audio_cond: bool = False,
312
+ drop_text: bool = False,
313
+ drop_midi: bool = False,
314
+ cache: bool = True,
315
+ audio_mask: bool["b n"] | None = None, # noqa: F722
316
+ ):
317
+ seq_len = x.shape[1]
318
+
319
+ if cache:
320
+ if drop_text:
321
+ if self.text_uncond is None:
322
+ self.text_uncond, _ = self.text_embed_p(
323
+ text, seq_len, drop_text=True, audio_mask=audio_mask
324
+ )
325
+ text_embed = self.text_uncond
326
+ else:
327
+ if self.text_cond is None:
328
+ self.text_cond, _ = self.text_embed_p(
329
+ text, seq_len, drop_text=False, audio_mask=audio_mask
330
+ )
331
+ text_embed = self.text_cond
332
+ else:
333
+ text_embed, text_mask = self.text_embed_p(
334
+ text, seq_len, drop_text=drop_text, audio_mask=audio_mask
335
+ )
336
+
337
+ if midi is None:
338
+ midi = torch.zeros(
339
+ (x.size(0), x.size(1)), device=x.device, dtype=torch.long
340
+ )
341
+
342
+ x = self.input_embed_with_midi(
343
+ x,
344
+ cond,
345
+ text_embed,
346
+ midi,
347
+ drop_audio_cond=drop_audio_cond,
348
+ drop_midi=drop_midi,
349
+ )
350
+
351
+ return x, None
352
+
353
+ def clear_cache(self):
354
+ self.text_cond, self.text_uncond = None, None
355
+
356
+ def forward(
357
+ self,
358
+ x: float["b n d"], # Noised input audio # noqa: F722
359
+ cond: float["b n d"], # Masked conditioning audio # noqa: F722
360
+ text: int["b nt"], # Text tokens # noqa: F722
361
+ time: float["b"] | float[""], # Timestep # noqa: F821 F722
362
+ midi: float["b n"] | None = None, # Melody latent # noqa: F722
363
+ mask: bool["b n"] | None = None, # noqa: F722
364
+ drop_audio_cond: bool = False, # CFG for conditioning audio
365
+ drop_text: bool = False, # CFG for text
366
+ drop_midi: bool = False, # CFG for melody
367
+ cfg_infer: bool = False, # CFG inference: pack cond & uncond forward
368
+ cache: bool = False,
369
+ guidance_scale=None,
370
+ cfg_infer_ids=None, # tuple(bool): (x_cond, x_uncond, x_uncond_cc, x_drop_all_cond)
371
+ ):
372
+ batch, seq_len = x.shape[0], x.shape[1]
373
+ if time.ndim == 0:
374
+ time = time.repeat(batch)
375
+
376
+ # Timestep embedding (with optional distillation guidance scale)
377
+ t = self.time_embed(time, guidance_scale=guidance_scale)
378
+
379
+ if cfg_infer: # Pack cond & uncond forward: b n d -> Kb n d
380
+ x_cond, x_uncond, x_uncond_cc, x_drop_all_cond = None, None, None, None
381
+ if cfg_infer_ids is None or cfg_infer_ids[0]:
382
+ x_cond, _ = self.get_input_embed(
383
+ x,
384
+ cond,
385
+ text,
386
+ midi,
387
+ drop_audio_cond=False,
388
+ drop_text=False,
389
+ drop_midi=False,
390
+ cache=cache,
391
+ audio_mask=mask,
392
+ )
393
+ if cfg_infer_ids is None or cfg_infer_ids[1]:
394
+ x_uncond, _ = self.get_input_embed(
395
+ x,
396
+ cond,
397
+ text,
398
+ midi,
399
+ drop_audio_cond=True,
400
+ drop_text=False,
401
+ drop_midi=False,
402
+ cache=cache,
403
+ audio_mask=mask,
404
+ )
405
+ if cfg_infer_ids is None or cfg_infer_ids[2]:
406
+ x_uncond_cc, _ = self.get_input_embed(
407
+ x,
408
+ cond,
409
+ text,
410
+ midi,
411
+ drop_audio_cond=False,
412
+ drop_text=True,
413
+ drop_midi=True,
414
+ cache=cache,
415
+ audio_mask=mask,
416
+ )
417
+ if cfg_infer_ids is None or cfg_infer_ids[3]:
418
+ x_drop_all_cond, _ = self.get_input_embed(
419
+ x,
420
+ cond,
421
+ text,
422
+ midi,
423
+ drop_audio_cond=True,
424
+ drop_text=True,
425
+ drop_midi=True,
426
+ cache=cache,
427
+ audio_mask=mask,
428
+ )
429
+
430
+ # Concatenate only non-None tensors
431
+ x_list = [
432
+ xi
433
+ for xi in [x_cond, x_uncond, x_uncond_cc, x_drop_all_cond]
434
+ if xi is not None
435
+ ]
436
+ x = torch.cat(x_list, dim=0)
437
+ t = torch.cat([t] * len(x_list), dim=0)
438
+ mask = torch.cat([mask] * len(x_list), dim=0) if mask is not None else None
439
+ else:
440
+ x, text_inner_sim_matrix = self.get_input_embed(
441
+ x,
442
+ cond,
443
+ text,
444
+ midi,
445
+ drop_audio_cond=drop_audio_cond,
446
+ drop_text=drop_text,
447
+ drop_midi=drop_midi,
448
+ cache=cache,
449
+ audio_mask=mask,
450
+ )
451
+
452
+ rope = self.rotary_embed.forward_from_seq_len(seq_len)
453
+
454
+ if self.long_skip_connection is not None:
455
+ residual = x
456
+
457
+ # Mask is all zeros during inference
458
+ for block in self.transformer_blocks:
459
+ if self.checkpoint_activations:
460
+ x = torch.utils.checkpoint.checkpoint(
461
+ self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False
462
+ )
463
+ else:
464
+ x = block(x, t, mask=mask, rope=rope)
465
+
466
+ if self.long_skip_connection is not None:
467
+ x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
468
+
469
+ x = self.norm_out(x, t)
470
+ output = self.proj_out(x)
471
+
472
+ return output, text_inner_sim_matrix if not cfg_infer else None
src/YingMusicSinger/models/model.py ADDED
@@ -0,0 +1,423 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Callable
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ from torchdiffeq import odeint
10
+
11
+ from src.YingMusicSinger.melody.midi_extractor import MIDIExtractor
12
+ from src.YingMusicSinger.utils.common import (
13
+ default,
14
+ exists,
15
+ get_epss_timesteps,
16
+ lens_to_mask,
17
+ )
18
+
19
+
20
+ def interpolation_midi_continuous(midi_p, bound_p, total_len):
21
+ """Temporally interpolate 3D melody latent to match target length."""
22
+ if midi_p.shape[1] != total_len:
23
+ midi = (
24
+ F.interpolate(
25
+ midi_p.clone().detach().transpose(1, 2),
26
+ size=total_len,
27
+ mode="linear",
28
+ align_corners=False,
29
+ )
30
+ .transpose(1, 2)
31
+ .clone()
32
+ .detach()
33
+ )
34
+ if bound_p is not None:
35
+ midi_bound = (
36
+ F.interpolate(
37
+ bound_p.clone().detach().transpose(1, 2),
38
+ size=total_len,
39
+ mode="linear",
40
+ align_corners=False,
41
+ )
42
+ .transpose(1, 2)
43
+ .clone()
44
+ .detach()
45
+ )
46
+ else:
47
+ midi = midi_p.clone().detach()
48
+ if bound_p is not None:
49
+ midi_bound = bound_p.clone().detach()
50
+ if bound_p is not None:
51
+ return midi, midi_bound
52
+ else:
53
+ return midi
54
+
55
+
56
+ def interpolation_midi_continuous_2_dim(midi_p, bound_p, total_len):
57
+ """Temporally interpolate 2D melody latent to match target length."""
58
+ assert len(midi_p.shape) == 2
59
+
60
+ if midi_p.shape[1] != total_len:
61
+ midi = (
62
+ F.interpolate(
63
+ midi_p.unsqueeze(2).clone().detach().transpose(1, 2),
64
+ size=total_len,
65
+ mode="linear",
66
+ align_corners=False,
67
+ )
68
+ .transpose(1, 2)
69
+ .clone()
70
+ .detach()
71
+ )
72
+ if bound_p is not None:
73
+ midi_bound = (
74
+ F.interpolate(
75
+ bound_p.unsqueeze(2).clone().detach().transpose(1, 2),
76
+ size=total_len,
77
+ mode="linear",
78
+ align_corners=False,
79
+ )
80
+ .transpose(1, 2)
81
+ .clone()
82
+ .detach()
83
+ )
84
+ else:
85
+ midi = midi_p.clone().detach()
86
+ if bound_p is not None:
87
+ midi_bound = bound_p.clone().detach()
88
+ if bound_p is not None:
89
+ return midi.squeeze(2), midi_bound.squeeze(2)
90
+ else:
91
+ return midi.squeeze(2)
92
+
93
+
94
+ class Singer(nn.Module):
95
+ def __init__(
96
+ self,
97
+ transformer: nn.Module,
98
+ is_tts_pretrain,
99
+ melody_input_source,
100
+ cka_disabled,
101
+ distill_stage,
102
+ use_guidance_scale_embed,
103
+ sigma=0.0,
104
+ odeint_kwargs: dict = dict(method="euler"),
105
+ audio_drop_prob=0.3,
106
+ cond_drop_prob=0.2,
107
+ num_channels=None,
108
+ mel_spec_module: nn.Module | None = None,
109
+ mel_spec_kwargs: dict = dict(),
110
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
111
+ extra_parameters=None,
112
+ ):
113
+ super().__init__()
114
+
115
+ self.is_tts_pretrain = is_tts_pretrain
116
+
117
+ if distill_stage is None:
118
+ self.distill_stage = 0
119
+ else:
120
+ self.distill_stage = int(distill_stage)
121
+
122
+ self.use_guidance_scale_embed = use_guidance_scale_embed
123
+
124
+ assert melody_input_source in {
125
+ "student_model",
126
+ "some_pretrain",
127
+ "some_pretrain_fuzzdisturb",
128
+ "some_pretrain_postprocess_embedding",
129
+ "none",
130
+ }
131
+ from src.YingMusicSinger.melody.SmoothMelody import MIDIFuzzDisturb
132
+
133
+ if melody_input_source == "some_pretrain_fuzzdisturb":
134
+ self.smoothMelody_MIDIFuzzDisturb = MIDIFuzzDisturb(
135
+ dim=extra_parameters.some_pretrain_fuzzdisturb.dim,
136
+ drop_prob=extra_parameters.some_pretrain_fuzzdisturb.drop_prob,
137
+ noise_scale=extra_parameters.some_pretrain_fuzzdisturb.noise_scale,
138
+ blur_kernel=extra_parameters.some_pretrain_fuzzdisturb.blur_kernel,
139
+ drop_type=extra_parameters.some_pretrain_fuzzdisturb.drop_type,
140
+ )
141
+ from src.YingMusicSinger.melody.SmoothMelody import MIDIDigitalEmbedding
142
+
143
+ if melody_input_source == "some_pretrain_postprocess_embedding":
144
+ self.smoothMelody_MIDIDigitalEmbedding = MIDIDigitalEmbedding(
145
+ embed_dim=extra_parameters.some_pretrain_postprocess_embedding.embed_dim,
146
+ num_classes=extra_parameters.some_pretrain_postprocess_embedding.num_classes,
147
+ mark_distinguish_scale=extra_parameters.some_pretrain_postprocess_embedding.mark_distinguish_scale,
148
+ )
149
+
150
+ self.melody_input_source = melody_input_source
151
+ self.cka_disabled = cka_disabled
152
+
153
+ self.frac_lengths_mask = frac_lengths_mask
154
+
155
+ num_channels = default(num_channels, mel_spec_kwargs.n_mel_channels)
156
+ self.num_channels = num_channels
157
+
158
+ # Classifier-free guidance drop probabilities
159
+ self.audio_drop_prob = audio_drop_prob
160
+ self.cond_drop_prob = cond_drop_prob
161
+
162
+ # Transformer backbone
163
+ self.transformer = transformer
164
+ dim = transformer.dim
165
+ self.dim = dim
166
+
167
+ # Conditional flow matching
168
+ self.sigma = sigma
169
+ self.odeint_kwargs = odeint_kwargs
170
+
171
+ # Melody extractor
172
+ self.midi_extractor = MIDIExtractor(in_dim=num_channels)
173
+
174
+ @property
175
+ def device(self):
176
+ return next(self.parameters()).device
177
+
178
+ @torch.no_grad()
179
+ def sample(
180
+ self,
181
+ cond: float["b n d"] | float["b nw"], # noqa: F722
182
+ text: int["b nt"] | list[str], # noqa: F722
183
+ duration: int | int["b"] | None = None, # noqa: F821
184
+ *,
185
+ midi_in: float["b n d"] | None = None,
186
+ lens: int["b"] | None = None, # noqa: F821
187
+ steps=32,
188
+ cfg_strength=1.0,
189
+ sway_sampling_coef=None,
190
+ seed: int | None = None,
191
+ max_duration=4096, # Maximum total length (including ICL prompt), ~190s
192
+ vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722
193
+ use_epss=False,
194
+ no_ref_audio=False,
195
+ duplicate_test=False,
196
+ t_inter=0.1,
197
+ t_shift=1.0, # Sampling timestep shift (ZipVoice-style)
198
+ guidance_scale=None,
199
+ edit_mask=None,
200
+ midi_p=None,
201
+ bound_p=None,
202
+ enable_melody_control=True,
203
+ ):
204
+ self.eval()
205
+
206
+ assert isinstance(cond, torch.Tensor)
207
+ assert not edit_mask, "edit_mask is not supported in this mode"
208
+ assert not duplicate_test, "duplicate_test is not supported in this mode"
209
+
210
+ if self.melody_input_source == "student_model":
211
+ assert midi_p is None and bound_p is None
212
+ elif self.melody_input_source in {
213
+ "some_pretrain",
214
+ "some_pretrain_fuzzdisturb",
215
+ "some_pretrain_postprocess_embedding",
216
+ }:
217
+ assert midi_p is not None and bound_p is not None
218
+ elif self.melody_input_source == "none":
219
+ assert midi_p is None and bound_p is None
220
+ else:
221
+ raise ValueError(
222
+ f"Unsupported melody_input_source: {self.melody_input_source}"
223
+ )
224
+
225
+ # duration is the total latent sequence length
226
+ assert duration
227
+
228
+ cond = cond.to(next(self.parameters()).dtype)
229
+
230
+ # Extract or interpolate melody representation
231
+ if self.melody_input_source == "student_model":
232
+ midi, midi_bound = self.midi_extractor(midi_in)
233
+
234
+ elif self.melody_input_source == "some_pretrain":
235
+ midi, midi_bound = interpolation_midi_continuous(
236
+ midi_p=midi_p, bound_p=bound_p, total_len=text.shape[1]
237
+ )
238
+ elif self.melody_input_source == "some_pretrain_fuzzdisturb":
239
+ midi, midi_bound = interpolation_midi_continuous(
240
+ midi_p=midi_p, bound_p=bound_p, total_len=text.shape[1]
241
+ )
242
+ midi = self.smoothMelody_MIDIFuzzDisturb(midi)
243
+
244
+ elif self.melody_input_source == "some_pretrain_postprocess_embedding":
245
+ midi_after_postprocess, _ = self.midi_extractor.postprocess(
246
+ midi=midi_p, bounds=bound_p, with_expand=True
247
+ )
248
+ midi = interpolation_midi_continuous_2_dim(
249
+ midi_p=midi_after_postprocess, bound_p=None, total_len=text.shape[1]
250
+ )
251
+ midi = self.smoothMelody_MIDIDigitalEmbedding(midi)
252
+ midi_bound = None
253
+
254
+ elif self.melody_input_source == "none":
255
+ midi = torch.zeros(
256
+ text.shape[0], text.shape[1], 128, dtype=cond.dtype, device=text.device
257
+ )
258
+ midi_bound = None
259
+ else:
260
+ raise NotImplementedError()
261
+
262
+ batch, cond_seq_len, device = *cond.shape[:2], cond.device
263
+ if not exists(lens):
264
+ lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
265
+
266
+ assert isinstance(text, torch.Tensor)
267
+
268
+ cond_mask = lens_to_mask(lens)
269
+
270
+ if edit_mask is not None:
271
+ cond_mask = cond_mask & edit_mask
272
+
273
+ if isinstance(duration, int):
274
+ duration = torch.full((batch,), duration, device=device, dtype=torch.long)
275
+
276
+ # Duration must be at least max(text_len, audio_prompt_len) + 1
277
+ duration = torch.maximum(
278
+ torch.maximum((text != 0).sum(dim=-1), lens) + 1, duration
279
+ )
280
+ duration = duration.clamp(max=max_duration)
281
+
282
+ max_duration = duration.amax()
283
+
284
+ # Duplicate test: interpolate between noise and conditioning
285
+ if duplicate_test:
286
+ test_cond = F.pad(
287
+ cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0
288
+ )
289
+
290
+ # Zero-pad conditioning latent to max_duration
291
+ cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)
292
+
293
+ if no_ref_audio:
294
+ cond = torch.zeros_like(cond)
295
+
296
+ cond_mask = F.pad(
297
+ cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
298
+ )
299
+ cond_mask = cond_mask.unsqueeze(-1)
300
+ step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
301
+
302
+ assert max_duration == midi.shape[1]
303
+
304
+ # Zero out melody in prompt region; optionally disable melody control entirely
305
+ if enable_melody_control:
306
+ midi = torch.where(cond_mask, torch.zeros_like(midi), midi)
307
+ else:
308
+ midi = torch.zeros_like(midi)
309
+
310
+ if self.is_tts_pretrain:
311
+ midi = torch.zeros_like(midi)
312
+
313
+ # For batched inference, explicit mask prevents causal attention fallback
314
+ if batch > 1:
315
+ mask = lens_to_mask(duration)
316
+ else:
317
+ mask = None
318
+
319
+ # ODE velocity function
320
+ def fn(t, x):
321
+ if cfg_strength < 1e-5:
322
+ # No classifier-free guidance
323
+ pred, _ = self.transformer(
324
+ x=x,
325
+ cond=step_cond,
326
+ text=text,
327
+ midi=midi,
328
+ time=t,
329
+ mask=mask,
330
+ drop_audio_cond=False,
331
+ drop_text=False,
332
+ drop_midi=not enable_melody_control,
333
+ cache=False,
334
+ )
335
+ return pred
336
+ else:
337
+ if self.use_guidance_scale_embed:
338
+ # Distilled model with built-in CFG
339
+ assert enable_melody_control
340
+ pred_cfg, _ = self.transformer(
341
+ x=x,
342
+ cond=step_cond,
343
+ text=text,
344
+ midi=midi,
345
+ time=t,
346
+ mask=mask,
347
+ drop_audio_cond=False,
348
+ drop_text=False,
349
+ drop_midi=not enable_melody_control,
350
+ cache=False,
351
+ guidance_scale=torch.tensor([guidance_scale], device=device),
352
+ )
353
+ print(
354
+ f"CFG 参数调节无作用! 蒸馏之后的,输入CFG为 guidance_scale={guidance_scale}"
355
+ )
356
+ return pred_cfg
357
+ else:
358
+ # Standard CFG: cond + uncond forward
359
+ # BUG If enable_melody_control is False, there might be a slight issue here
360
+ assert guidance_scale is not None
361
+ pred_cfg, _ = self.transformer(
362
+ x=x,
363
+ cond=step_cond,
364
+ text=text,
365
+ midi=midi,
366
+ time=t,
367
+ mask=mask,
368
+ cfg_infer=True,
369
+ cache=False,
370
+ cfg_infer_ids=(True, False, False, True),
371
+ )
372
+
373
+ pred, pred_drop_all_cond = torch.chunk(pred_cfg, 2, dim=0)
374
+ return pred + (pred - pred_drop_all_cond) * float(guidance_scale)
375
+
376
+ # Generate initial noise (per-sample seeding for batch consistency)
377
+ y0 = []
378
+ for dur in duration:
379
+ if exists(seed):
380
+ torch.manual_seed(seed)
381
+ y0.append(
382
+ torch.randn(
383
+ dur, self.num_channels, device=self.device, dtype=step_cond.dtype
384
+ )
385
+ )
386
+ y0 = pad_sequence(y0, padding_value=0, batch_first=True)
387
+
388
+ t_start = 0
389
+
390
+ if duplicate_test:
391
+ t_start = t_inter
392
+ y0 = (1 - t_start) * y0 + t_start * test_cond
393
+ steps = int(steps * (1 - t_start))
394
+
395
+ # Build timestep schedule
396
+ assert not use_epss and sway_sampling_coef is None, (
397
+ "Use timestep shift instead of the strategy in F5"
398
+ )
399
+ if t_start == 0 and use_epss:
400
+ # Empirically Pruned Step Sampling for low NFE
401
+ t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
402
+ else:
403
+ t = torch.linspace(
404
+ t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype
405
+ )
406
+
407
+ if sway_sampling_coef is not None:
408
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
409
+
410
+ # Apply timestep shift
411
+ t = t_shift * t / (1 + (t_shift - 1) * t)
412
+
413
+ trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
414
+ self.transformer.clear_cache()
415
+
416
+ sampled = trajectory[-1]
417
+ out = sampled
418
+
419
+ if exists(vocoder):
420
+ out = out.permute(0, 2, 1)
421
+ out = vocoder(out)
422
+
423
+ return out, trajectory
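To make the sampling loop above easier to follow, here is a small illustrative sketch (not part of the commit) of the ZipVoice-style timestep shift and the classifier-free-guidance combination used inside `fn`; placeholder tensors stand in for the transformer outputs.

```python
import torch

steps, t_shift, guidance_scale = 32, 5.0, 2.0

# Timestep schedule: a uniform grid, then shifted so more steps are spent near t = 0 when t_shift > 1.
t = torch.linspace(0, 1, steps + 1)
t = t_shift * t / (1 + (t_shift - 1) * t)

# Classifier-free guidance: push the conditional prediction away from the drop-all-conditions one.
pred_cond = torch.randn(1, 512, 100)      # placeholder for the conditional velocity
pred_uncond = torch.randn(1, 512, 100)    # placeholder for the unconditional velocity
velocity = pred_cond + (pred_cond - pred_uncond) * guidance_scale
```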
src/YingMusicSinger/models/modules.py ADDED
@@ -0,0 +1,961 @@
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import math
13
+ from typing import Optional
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torchaudio
19
+ from librosa.filters import mel as librosa_mel_fn
20
+ from x_transformers.x_transformers import apply_rotary_pos_emb
21
+
22
+ from src.YingMusicSinger.utils.common import is_package_available
23
+
24
+ # raw wav to mel spec
25
+
26
+
27
+ mel_basis_cache = {}
28
+ hann_window_cache = {}
29
+
30
+
31
+ def get_bigvgan_mel_spectrogram(
32
+ waveform,
33
+ n_fft=1024,
34
+ n_mel_channels=100,
35
+ target_sample_rate=24000,
36
+ hop_length=256,
37
+ win_length=1024,
38
+ fmin=0,
39
+ fmax=None,
40
+ center=False,
41
+ ): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
42
+ device = waveform.device
43
+ key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
44
+
45
+ if key not in mel_basis_cache:
46
+ mel = librosa_mel_fn(
47
+ sr=target_sample_rate,
48
+ n_fft=n_fft,
49
+ n_mels=n_mel_channels,
50
+ fmin=fmin,
51
+ fmax=fmax,
52
+ )
53
+ mel_basis_cache[key] = (
54
+ torch.from_numpy(mel).float().to(device)
55
+ ) # TODO: why do they need .float()?
56
+ hann_window_cache[key] = torch.hann_window(win_length).to(device)
57
+
58
+ mel_basis = mel_basis_cache[key]
59
+ hann_window = hann_window_cache[key]
60
+
61
+ padding = (n_fft - hop_length) // 2
62
+ waveform = torch.nn.functional.pad(
63
+ waveform.unsqueeze(1), (padding, padding), mode="reflect"
64
+ ).squeeze(1)
65
+
66
+ spec = torch.stft(
67
+ waveform,
68
+ n_fft,
69
+ hop_length=hop_length,
70
+ win_length=win_length,
71
+ window=hann_window,
72
+ center=center,
73
+ pad_mode="reflect",
74
+ normalized=False,
75
+ onesided=True,
76
+ return_complex=True,
77
+ )
78
+ spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
79
+
80
+ mel_spec = torch.matmul(mel_basis, spec)
81
+ mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
82
+
83
+ return mel_spec
84
+
85
+
86
+ def get_vocos_mel_spectrogram(
87
+ waveform,
88
+ n_fft=1024,
89
+ n_mel_channels=100,
90
+ target_sample_rate=24000,
91
+ hop_length=256,
92
+ win_length=1024,
93
+ ):
94
+ mel_stft = torchaudio.transforms.MelSpectrogram(
95
+ sample_rate=target_sample_rate,
96
+ n_fft=n_fft,
97
+ win_length=win_length,
98
+ hop_length=hop_length,
99
+ n_mels=n_mel_channels,
100
+ power=1,
101
+ center=True,
102
+ normalized=False,
103
+ norm=None,
104
+ ).to(waveform.device)
105
+ if len(waveform.shape) == 3:
106
+ waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
107
+
108
+ assert len(waveform.shape) == 2
109
+
110
+ mel = mel_stft(waveform)
111
+ mel = mel.clamp(min=1e-5).log()
112
+ return mel
113
+
114
+
115
+ class MelSpec(nn.Module):
116
+ def __init__(
117
+ self,
118
+ n_fft=1024,
119
+ hop_length=256,
120
+ win_length=1024,
121
+ n_mel_channels=100,
122
+ target_sample_rate=24_000,
123
+ mel_spec_type="vocos",
124
+ ):
125
+ super().__init__()
126
+ assert mel_spec_type in ["vocos", "bigvgan"], (
127
+ "We only support two mel extraction backends: vocos or bigvgan"
128
+ )
129
+
130
+ self.n_fft = n_fft
131
+ self.hop_length = hop_length
132
+ self.win_length = win_length
133
+ self.n_mel_channels = n_mel_channels
134
+ self.target_sample_rate = target_sample_rate
135
+
136
+ if mel_spec_type == "vocos":
137
+ self.extractor = get_vocos_mel_spectrogram
138
+ elif mel_spec_type == "bigvgan":
139
+ self.extractor = get_bigvgan_mel_spectrogram
140
+
141
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
142
+
143
+ def forward(self, wav):
144
+ if self.dummy.device != wav.device:
145
+ self.to(wav.device)
146
+
147
+ mel = self.extractor(
148
+ waveform=wav,
149
+ n_fft=self.n_fft,
150
+ n_mel_channels=self.n_mel_channels,
151
+ target_sample_rate=self.target_sample_rate,
152
+ hop_length=self.hop_length,
153
+ win_length=self.win_length,
154
+ )
155
+
156
+ return mel
157
+
158
+
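A brief usage sketch of the mel front end above (not part of the commit); the import path assumes this file lands at src/YingMusicSinger/models/modules.py, as in the diff.

```python
import torch
from src.YingMusicSinger.models.modules import MelSpec

mel_spec = MelSpec(n_mel_channels=100, target_sample_rate=24_000, mel_spec_type="vocos")
wav = torch.randn(2, 24_000 * 3)   # 3 s of audio at 24 kHz, shape (b, nw)
mel = mel_spec(wav)                # log-compressed mel, shape (b, n_mel_channels, n_frames)
print(mel.shape)
```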
159
+ # sinusoidal position embedding
160
+
161
+
162
+ class SinusPositionEmbedding(nn.Module):
163
+ def __init__(self, dim):
164
+ super().__init__()
165
+ self.dim = dim
166
+
167
+ def forward(self, x, scale=1000):
168
+ device = x.device
169
+ half_dim = self.dim // 2
170
+ emb = math.log(10000) / (half_dim - 1)
171
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
172
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
173
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
174
+ return emb
175
+
176
+
177
+ # convolutional position embedding
178
+
179
+
180
+ class ConvPositionEmbedding(nn.Module):
181
+ def __init__(self, dim, kernel_size=31, groups=16):
182
+ super().__init__()
183
+ assert kernel_size % 2 != 0
184
+ self.conv1d = nn.Sequential(
185
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
186
+ nn.Mish(),
187
+ nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
188
+ nn.Mish(),
189
+ )
190
+
191
+ def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
192
+ if mask is not None:
193
+ mask = mask[..., None]
194
+ x = x.masked_fill(~mask, 0.0)
195
+
196
+ x = x.permute(0, 2, 1)
197
+ x = self.conv1d(x)
198
+ out = x.permute(0, 2, 1)
199
+
200
+ if mask is not None:
201
+ out = out.masked_fill(~mask, 0.0)
202
+
203
+ return out
204
+
205
+
206
+ # rotary positional embedding related
207
+
208
+
209
+ def precompute_freqs_cis(
210
+ dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0
211
+ ):
212
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
213
+ # has some connection to NTK literature
214
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
215
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
216
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
217
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
218
+ t = torch.arange(end, device=freqs.device) # type: ignore
219
+ freqs = torch.outer(t, freqs).float() # type: ignore
220
+ freqs_cos = torch.cos(freqs) # real part
221
+ freqs_sin = torch.sin(freqs) # imaginary part
222
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
223
+
224
+
225
+ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
226
+ # length = length if isinstance(length, int) else length.max()
227
+ scale = scale * torch.ones_like(
228
+ start, dtype=torch.float32
229
+ ) # in case scale is a scalar
230
+ pos = (
231
+ start.unsqueeze(1)
232
+ + (
233
+ torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0)
234
+ * scale.unsqueeze(1)
235
+ ).long()
236
+ )
237
+ # avoid extra long error.
238
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
239
+ return pos
240
+
241
+
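For reference, a small sketch (not part of the commit) of how the rotary tables above are typically consumed: precompute once, then gather per-position rows for each batch element before rotary embedding is applied in the attention processors.

```python
import torch

from src.YingMusicSinger.models.modules import get_pos_embed_indices, precompute_freqs_cis

head_dim, max_pos = 64, 4096
freqs_cis = precompute_freqs_cis(head_dim, max_pos)               # (max_pos, head_dim): [cos | sin] halves

start = torch.zeros(2, dtype=torch.long)                          # per-sample start offsets
pos = get_pos_embed_indices(start, length=100, max_pos=max_pos)   # (2, 100), clamped to max_pos - 1
rope_freqs = freqs_cis[pos]                                       # (2, 100, head_dim) rotary table slice
print(rope_freqs.shape)
```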
242
+ # Global Response Normalization layer (Instance Normalization ?)
243
+
244
+
245
+ class GRN(nn.Module):
246
+ def __init__(self, dim):
247
+ super().__init__()
248
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
249
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
250
+
251
+ def forward(self, x):
252
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
253
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
254
+ return self.gamma * (x * Nx) + self.beta + x
255
+
256
+
257
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
258
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
259
+
260
+
261
+ class ConvNeXtV2Block(nn.Module):
262
+ def __init__(
263
+ self,
264
+ dim: int,
265
+ intermediate_dim: int,
266
+ dilation: int = 1,
267
+ ):
268
+ super().__init__()
269
+ padding = (dilation * (7 - 1)) // 2
270
+ self.dwconv = nn.Conv1d(
271
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
272
+ ) # depthwise conv
273
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
274
+ self.pwconv1 = nn.Linear(
275
+ dim, intermediate_dim
276
+ ) # pointwise/1x1 convs, implemented with linear layers
277
+ self.act = nn.GELU()
278
+ self.grn = GRN(intermediate_dim)
279
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
280
+
281
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
282
+ residual = x
283
+ x = x.transpose(1, 2) # b n d -> b d n
284
+ x = self.dwconv(x)
285
+ x = x.transpose(1, 2) # b d n -> b n d
286
+ x = self.norm(x)
287
+ x = self.pwconv1(x)
288
+ x = self.act(x)
289
+ x = self.grn(x)
290
+ x = self.pwconv2(x)
291
+ return residual + x
292
+
293
+
294
+ # RMSNorm
295
+
296
+
297
+ class RMSNorm(nn.Module):
298
+ def __init__(self, dim: int, eps: float):
299
+ super().__init__()
300
+ self.eps = eps
301
+ self.weight = nn.Parameter(torch.ones(dim))
302
+ self.native_rms_norm = tuple(int(v) for v in torch.__version__.split(".")[:2]) >= (2, 4)
303
+
304
+ def forward(self, x):
305
+ if self.native_rms_norm:
306
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
307
+ x = x.to(self.weight.dtype)
308
+ x = F.rms_norm(
309
+ x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps
310
+ )
311
+ else:
312
+ variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
313
+ x = x * torch.rsqrt(variance + self.eps)
314
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
315
+ x = x.to(self.weight.dtype)
316
+ x = x * self.weight
317
+
318
+ return x
319
+
320
+
321
+ # AdaLayerNorm
322
+ # return with modulated x for attn input, and params for later mlp modulation
323
+
324
+
325
+ class AdaLayerNorm(nn.Module):
326
+ def __init__(self, dim):
327
+ super().__init__()
328
+
329
+ self.silu = nn.SiLU()
330
+ self.linear = nn.Linear(dim, dim * 6)
331
+
332
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
333
+
334
+ def forward(self, x, emb=None):
335
+ emb = self.linear(self.silu(emb))
336
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(
337
+ emb, 6, dim=1
338
+ )
339
+
340
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
341
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
342
+
343
+
344
+ # AdaLayerNorm for final layer
345
+ # return only with modulated x for attn input, cuz no more mlp modulation
346
+
347
+
348
+ class AdaLayerNorm_Final(nn.Module):
349
+ def __init__(self, dim):
350
+ super().__init__()
351
+
352
+ self.silu = nn.SiLU()
353
+ self.linear = nn.Linear(dim, dim * 2)
354
+
355
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
356
+
357
+ def forward(self, x, emb):
358
+ emb = self.linear(self.silu(emb))
359
+ scale, shift = torch.chunk(emb, 2, dim=1)
360
+
361
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
362
+ return x
363
+
364
+
365
+ # FeedForward
366
+
367
+
368
+ class FeedForward(nn.Module):
369
+ def __init__(
370
+ self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"
371
+ ):
372
+ super().__init__()
373
+ inner_dim = int(dim * mult)
374
+ dim_out = dim_out if dim_out is not None else dim
375
+
376
+ activation = nn.GELU(approximate=approximate)
377
+ project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
378
+ self.ff = nn.Sequential(
379
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
380
+ )
381
+
382
+ def forward(self, x):
383
+ return self.ff(x)
384
+
385
+
386
+ # Attention with possible joint part
387
+ # modified from diffusers/src/diffusers/models/attention_processor.py
388
+
389
+
390
+ class Attention(nn.Module):
391
+ def __init__(
392
+ self,
393
+ processor: JointAttnProcessor | AttnProcessor,
394
+ dim: int,
395
+ heads: int = 8,
396
+ dim_head: int = 64,
397
+ dropout: float = 0.0,
398
+ context_dim: Optional[int] = None, # if not None -> joint attention
399
+ context_pre_only: bool = False,
400
+ qk_norm: Optional[str] = None,
401
+ ):
402
+ super().__init__()
403
+
404
+ if not hasattr(F, "scaled_dot_product_attention"):
405
+ raise ImportError(
406
+ "Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
407
+ )
408
+
409
+ self.processor = processor
410
+
411
+ self.dim = dim
412
+ self.heads = heads
413
+ self.inner_dim = dim_head * heads
414
+ self.dropout = dropout
415
+
416
+ self.context_dim = context_dim
417
+ self.context_pre_only = context_pre_only
418
+
419
+ self.to_q = nn.Linear(dim, self.inner_dim)
420
+ self.to_k = nn.Linear(dim, self.inner_dim)
421
+ self.to_v = nn.Linear(dim, self.inner_dim)
422
+
423
+ if qk_norm is None:
424
+ self.q_norm = None
425
+ self.k_norm = None
426
+ elif qk_norm == "rms_norm":
427
+ self.q_norm = RMSNorm(dim_head, eps=1e-6)
428
+ self.k_norm = RMSNorm(dim_head, eps=1e-6)
429
+ else:
430
+ raise ValueError(f"Unimplemented qk_norm: {qk_norm}")
431
+
432
+ if self.context_dim is not None:
433
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
434
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
435
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
436
+ if qk_norm is None:
437
+ self.c_q_norm = None
438
+ self.c_k_norm = None
439
+ elif qk_norm == "rms_norm":
440
+ self.c_q_norm = RMSNorm(dim_head, eps=1e-6)
441
+ self.c_k_norm = RMSNorm(dim_head, eps=1e-6)
442
+
443
+ self.to_out = nn.ModuleList([])
444
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
445
+ self.to_out.append(nn.Dropout(dropout))
446
+
447
+ if self.context_dim is not None and not self.context_pre_only:
448
+ self.to_out_c = nn.Linear(self.inner_dim, context_dim)
449
+
450
+ def forward(
451
+ self,
452
+ x: float["b n d"], # noised input x
453
+ c: float["b n d"] = None, # context c
454
+ mask: bool["b n"] | None = None,
455
+ rope=None, # rotary position embedding for x
456
+ c_rope=None, # rotary position embedding for c
457
+ ) -> torch.Tensor:
458
+ if c is not None:
459
+ return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
460
+ else:
461
+ return self.processor(self, x, mask=mask, rope=rope)
462
+
463
+
464
+ # Attention processor
465
+
466
+ if is_package_available("flash_attn"):
467
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
468
+ from flash_attn.bert_padding import pad_input, unpad_input
469
+
470
+
471
+ class AttnProcessor:
472
+ def __init__(
473
+ self,
474
+ pe_attn_head: int
475
+ | None = None, # number of attention heads to apply rope to, None for all
476
+ attn_backend: str = "torch", # "torch" or "flash_attn"
477
+ attn_mask_enabled: bool = True,
478
+ ):
479
+ if attn_backend == "flash_attn":
480
+ assert is_package_available("flash_attn"), (
481
+ "Please install flash-attn first."
482
+ )
483
+
484
+ self.pe_attn_head = pe_attn_head
485
+ self.attn_backend = attn_backend
486
+ self.attn_mask_enabled = attn_mask_enabled
487
+
488
+ def __call__(
489
+ self,
490
+ attn: Attention,
491
+ x: float["b n d"], # noised input x
492
+ mask: bool["b n"] | None = None,
493
+ rope=None, # rotary position embedding
494
+ ) -> torch.FloatTensor:
495
+ batch_size = x.shape[0]
496
+
497
+ # `sample` projections
498
+ query = attn.to_q(x)
499
+ key = attn.to_k(x)
500
+ value = attn.to_v(x)
501
+
502
+ # attention
503
+ inner_dim = key.shape[-1]
504
+ head_dim = inner_dim // attn.heads
505
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
506
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
507
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
508
+
509
+ # qk norm
510
+ if attn.q_norm is not None:
511
+ query = attn.q_norm(query)
512
+ if attn.k_norm is not None:
513
+ key = attn.k_norm(key)
514
+
515
+ # apply rotary position embedding
516
+ if rope is not None:
517
+ freqs, xpos_scale = rope
518
+ q_xpos_scale, k_xpos_scale = (
519
+ (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
520
+ )
521
+
522
+ if self.pe_attn_head is not None:
523
+ pn = self.pe_attn_head
524
+ query[:, :pn, :, :] = apply_rotary_pos_emb(
525
+ query[:, :pn, :, :], freqs, q_xpos_scale
526
+ )
527
+ key[:, :pn, :, :] = apply_rotary_pos_emb(
528
+ key[:, :pn, :, :], freqs, k_xpos_scale
529
+ )
530
+ else:
531
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
532
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
533
+
534
+ if self.attn_backend == "torch":
535
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
536
+ if self.attn_mask_enabled and mask is not None:
537
+ attn_mask = mask
538
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
539
+ attn_mask = attn_mask.expand(
540
+ batch_size, attn.heads, query.shape[-2], key.shape[-2]
541
+ )
542
+ else:
543
+ attn_mask = None
544
+ x = F.scaled_dot_product_attention(
545
+ query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
546
+ )
547
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
548
+
549
+ elif self.attn_backend == "flash_attn":
550
+ query = query.transpose(1, 2) # [b, h, n, d] -> [b, n, h, d]
551
+ key = key.transpose(1, 2)
552
+ value = value.transpose(1, 2)
553
+ if self.attn_mask_enabled and mask is not None:
554
+ query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(
555
+ query, mask
556
+ )
557
+ key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
558
+ value, _, _, _, _ = unpad_input(value, mask)
559
+ x = flash_attn_varlen_func(
560
+ query,
561
+ key,
562
+ value,
563
+ q_cu_seqlens,
564
+ k_cu_seqlens,
565
+ q_max_seqlen_in_batch,
566
+ k_max_seqlen_in_batch,
567
+ )
568
+ x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
569
+ x = x.reshape(batch_size, -1, attn.heads * head_dim)
570
+ else:
571
+ x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
572
+ x = x.reshape(batch_size, -1, attn.heads * head_dim)
573
+
574
+ x = x.to(query.dtype)
575
+
576
+ # linear proj
577
+ x = attn.to_out[0](x)
578
+ # dropout
579
+ x = attn.to_out[1](x)
580
+
581
+ if mask is not None:
582
+ mask = mask.unsqueeze(-1)
583
+ x = x.masked_fill(~mask, 0.0)
584
+
585
+ return x
586
+
587
+
588
+ # Joint Attention processor for MM-DiT
589
+ # modified from diffusers/src/diffusers/models/attention_processor.py
590
+
591
+
592
+ class JointAttnProcessor:
593
+ def __init__(self):
594
+ pass
595
+
596
+ def __call__(
597
+ self,
598
+ attn: Attention,
599
+ x: float["b n d"], # noised input x
600
+ c: float["b nt d"] = None, # context c, here text
601
+ mask: bool["b n"] | None = None,
602
+ rope=None, # rotary position embedding for x
603
+ c_rope=None, # rotary position embedding for c
604
+ ) -> torch.FloatTensor:
605
+ residual = x
606
+
607
+ batch_size = c.shape[0]
608
+
609
+ # `sample` projections
610
+ query = attn.to_q(x)
611
+ key = attn.to_k(x)
612
+ value = attn.to_v(x)
613
+
614
+ # `context` projections
615
+ c_query = attn.to_q_c(c)
616
+ c_key = attn.to_k_c(c)
617
+ c_value = attn.to_v_c(c)
618
+
619
+ # attention
620
+ inner_dim = key.shape[-1]
621
+ head_dim = inner_dim // attn.heads
622
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
623
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
624
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
625
+ c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
626
+ c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
627
+ c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
628
+
629
+ # qk norm
630
+ if attn.q_norm is not None:
631
+ query = attn.q_norm(query)
632
+ if attn.k_norm is not None:
633
+ key = attn.k_norm(key)
634
+ if attn.c_q_norm is not None:
635
+ c_query = attn.c_q_norm(c_query)
636
+ if attn.c_k_norm is not None:
637
+ c_key = attn.c_k_norm(c_key)
638
+
639
+ # apply rope for context and noised input independently
640
+ if rope is not None:
641
+ freqs, xpos_scale = rope
642
+ q_xpos_scale, k_xpos_scale = (
643
+ (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
644
+ )
645
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
646
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
647
+ if c_rope is not None:
648
+ freqs, xpos_scale = c_rope
649
+ q_xpos_scale, k_xpos_scale = (
650
+ (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
651
+ )
652
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
653
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
654
+
655
+ # joint attention
656
+ query = torch.cat([query, c_query], dim=2)
657
+ key = torch.cat([key, c_key], dim=2)
658
+ value = torch.cat([value, c_value], dim=2)
659
+
660
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
661
+ if mask is not None:
662
+ attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
663
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
664
+ attn_mask = attn_mask.expand(
665
+ batch_size, attn.heads, query.shape[-2], key.shape[-2]
666
+ )
667
+ else:
668
+ attn_mask = None
669
+
670
+ x = F.scaled_dot_product_attention(
671
+ query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
672
+ )
673
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
674
+ x = x.to(query.dtype)
675
+
676
+ # Split the attention outputs.
677
+ x, c = (
678
+ x[:, : residual.shape[1]],
679
+ x[:, residual.shape[1] :],
680
+ )
681
+
682
+ # linear proj
683
+ x = attn.to_out[0](x)
684
+ # dropout
685
+ x = attn.to_out[1](x)
686
+ if not attn.context_pre_only:
687
+ c = attn.to_out_c(c)
688
+
689
+ if mask is not None:
690
+ mask = mask.unsqueeze(-1)
691
+ x = x.masked_fill(~mask, 0.0)
692
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
693
+
694
+ return x, c
695
+
696
+
697
+ # DiT Block
698
+
699
+
700
+ class DiTBlock(nn.Module):
701
+ def __init__(
702
+ self,
703
+ dim,
704
+ heads,
705
+ dim_head,
706
+ ff_mult=4,
707
+ dropout=0.1,
708
+ qk_norm=None,
709
+ pe_attn_head=None,
710
+ attn_backend="torch", # "torch" or "flash_attn"
711
+ attn_mask_enabled=True,
712
+ ):
713
+ super().__init__()
714
+
715
+ self.attn_norm = AdaLayerNorm(dim)
716
+ self.attn = Attention(
717
+ processor=AttnProcessor(
718
+ pe_attn_head=pe_attn_head,
719
+ attn_backend=attn_backend,
720
+ attn_mask_enabled=attn_mask_enabled,
721
+ ),
722
+ dim=dim,
723
+ heads=heads,
724
+ dim_head=dim_head,
725
+ dropout=dropout,
726
+ qk_norm=qk_norm,
727
+ )
728
+
729
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
730
+ self.ff = FeedForward(
731
+ dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
732
+ )
733
+
734
+ def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
735
+ # pre-norm & modulation for attention input
736
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
737
+
738
+ # attention
739
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
740
+
741
+ # process attention output for input x
742
+ x = x + gate_msa.unsqueeze(1) * attn_output
743
+
744
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
745
+ ff_output = self.ff(norm)
746
+ x = x + gate_mlp.unsqueeze(1) * ff_output
747
+
748
+ return x
749
+
750
+
751
+ # MMDiT Block https://arxiv.org/abs/2403.03206
752
+
753
+
754
+ class MMDiTBlock(nn.Module):
755
+ r"""
756
+ modified from diffusers/src/diffusers/models/attention.py
757
+
758
+ notes.
759
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
760
+ _x: noised input related. (right part)
761
+ context_pre_only: last layer only do prenorm + modulation cuz no more ffn
762
+ """
763
+
764
+ def __init__(
765
+ self,
766
+ dim,
767
+ heads,
768
+ dim_head,
769
+ ff_mult=4,
770
+ dropout=0.1,
771
+ context_dim=None,
772
+ context_pre_only=False,
773
+ qk_norm=None,
774
+ ):
775
+ super().__init__()
776
+ if context_dim is None:
777
+ context_dim = dim
778
+ self.context_pre_only = context_pre_only
779
+
780
+ self.attn_norm_c = (
781
+ AdaLayerNorm_Final(context_dim)
782
+ if context_pre_only
783
+ else AdaLayerNorm(context_dim)
784
+ )
785
+ self.attn_norm_x = AdaLayerNorm(dim)
786
+ self.attn = Attention(
787
+ processor=JointAttnProcessor(),
788
+ dim=dim,
789
+ heads=heads,
790
+ dim_head=dim_head,
791
+ dropout=dropout,
792
+ context_dim=context_dim,
793
+ context_pre_only=context_pre_only,
794
+ qk_norm=qk_norm,
795
+ )
796
+
797
+ if not context_pre_only:
798
+ self.ff_norm_c = nn.LayerNorm(
799
+ context_dim, elementwise_affine=False, eps=1e-6
800
+ )
801
+ self.ff_c = FeedForward(
802
+ dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh"
803
+ )
804
+ else:
805
+ self.ff_norm_c = None
806
+ self.ff_c = None
807
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
808
+ self.ff_x = FeedForward(
809
+ dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
810
+ )
811
+
812
+ def forward(
813
+ self, x, c, t, mask=None, rope=None, c_rope=None
814
+ ): # x: noised input, c: context, t: time embedding
815
+ # pre-norm & modulation for attention input
816
+ if self.context_pre_only:
817
+ norm_c = self.attn_norm_c(c, t)
818
+ else:
819
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(
820
+ c, emb=t
821
+ )
822
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(
823
+ x, emb=t
824
+ )
825
+
826
+ # attention
827
+ x_attn_output, c_attn_output = self.attn(
828
+ x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope
829
+ )
830
+
831
+ # process attention output for context c
832
+ if self.context_pre_only:
833
+ c = None
834
+ else: # if not last layer
835
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
836
+
837
+ norm_c = (
838
+ self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
839
+ )
840
+ c_ff_output = self.ff_c(norm_c)
841
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
842
+
843
+ # process attention output for input x
844
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
845
+
846
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
847
+ x_ff_output = self.ff_x(norm_x)
848
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
849
+
850
+ return c, x
851
+
852
+
853
+ # time step conditioning embedding
854
+
855
+
856
+ # class TimestepEmbedding(nn.Module):
857
+ # def __init__(self, dim, freq_embed_dim=256):
858
+ # super().__init__()
859
+ # self.time_embed = SinusPositionEmbedding(freq_embed_dim)
860
+ # self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
861
+
862
+ # def forward(self, timestep: float["b"]):
863
+ # time_hidden = self.time_embed(timestep)
864
+ # time_hidden = time_hidden.to(timestep.dtype)
865
+ # time = self.time_mlp(time_hidden) # b d
866
+ # return time
867
+
868
+
869
+ def zipvoice_timestep_embedding(timesteps, dim, max_period=10000):
870
+ """Create sinusoidal timestep embeddings.
871
+
872
+ :param timesteps: shape of (N) or (N, T)
873
+ :param dim: the dimension of the output.
874
+ :param max_period: controls the minimum frequency of the embeddings.
875
+ :return: an Tensor of positional embeddings. shape of (N, dim) or (T, N, dim)
876
+ """
877
+ half = dim // 2
878
+ freqs = torch.exp(
879
+ -math.log(max_period)
880
+ * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
881
+ / half
882
+ )
883
+
884
+ if timesteps.dim() == 2:
885
+ timesteps = timesteps.transpose(0, 1) # (N, T) -> (T, N)
886
+
887
+ args = timesteps[..., None].float() * freqs[None]
888
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
889
+ if dim % 2:
890
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
891
+ return embedding
892
+
893
+
894
+ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
895
+ """
896
+ Behaves like a constructor of a modified version of nn.Linear
897
+ that gives an easy way to set the default initial parameter scale.
898
+
899
+ Args:
900
+ Accepts the standard args and kwargs that nn.Linear accepts
901
+ e.g. in_features, out_features, bias=False.
902
+
903
+ initial_scale: you can override this if you want to increase
904
+ or decrease the initial magnitude of the module's output
905
+ (affects the initialization of weight_scale and bias_scale).
906
+ Another option, if you want to do something like this, is
907
+ to re-initialize the parameters.
908
+ """
909
+ ans = nn.Linear(*args, **kwargs)
910
+ with torch.no_grad():
911
+ ans.weight[:] *= initial_scale
912
+ if ans.bias is not None:
913
+ torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
914
+ return ans
915
+
916
+
917
+ # Used during distillation!
918
+ class TimestepGuidanceEmbedding(nn.Module):
919
+ def __init__(
920
+ self,
921
+ dim,
922
+ freq_embed_dim=256,
923
+ use_guidance_scale_embed=False,
924
+ guidance_scale_embed_dim=192,
925
+ ):
926
+ super().__init__()
927
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
928
+ self.time_mlp = nn.Sequential(
929
+ nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
930
+ )
931
+ if use_guidance_scale_embed:
932
+ self.guidance_scale_embed = ScaledLinear(
933
+ guidance_scale_embed_dim,
934
+ freq_embed_dim,
935
+ bias=False,
936
+ initial_scale=0.1,
937
+ )
938
+ self.guidance_scale_embed_dim = guidance_scale_embed_dim
939
+ else:
940
+ self.guidance_scale_embed = None
941
+
942
+ def forward(self, timestep: float["b"], guidance_scale=None):
943
+ # import pdb
944
+
945
+ # pdb.set_trace()
946
+ time_hidden = self.time_embed(timestep)
947
+
948
+ if self.guidance_scale_embed:
949
+ assert guidance_scale is not None
950
+ guidance_scale_emb = self.guidance_scale_embed(
951
+ zipvoice_timestep_embedding(
952
+ guidance_scale, self.guidance_scale_embed_dim
953
+ )
954
+ )
955
+ time_hidden = time_hidden + guidance_scale_emb
956
+ else:
957
+ assert guidance_scale is None
958
+
959
+ time_hidden = time_hidden.to(timestep.dtype)
960
+ time = self.time_mlp(time_hidden) # b d
961
+ return time
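An illustrative sketch (not part of the commit) of the conditioning embedding above; it assumes the module is importable from src/YingMusicSinger/models/modules.py as added in this commit.

```python
import torch
from src.YingMusicSinger.models.modules import TimestepGuidanceEmbedding

# Distilled path: the guidance scale is embedded together with the flow-matching timestep.
embed = TimestepGuidanceEmbedding(dim=1024, use_guidance_scale_embed=True)
t = torch.rand(4)                          # one timestep per batch element
g = torch.full((4,), 2.0)                  # guidance scale baked into the conditioning
cond = embed(t, guidance_scale=g)          # (4, 1024)

# Non-distilled path: guidance_scale must stay None.
plain = TimestepGuidanceEmbedding(dim=1024)
cond_plain = plain(torch.rand(4))          # (4, 1024)
```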
src/YingMusicSinger/utils/f5_tts/g2p/g2p/__init__.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import re
8
+
9
+ from tokenizers import Tokenizer
10
+
11
+ from src.YingMusicSinger.utils.f5_tts.g2p.g2p import cleaners
12
+ from src.YingMusicSinger.utils.f5_tts.g2p.g2p.text_tokenizers import TextTokenizer
13
+
14
+ # import LangSegment
15
+ from src.YingMusicSinger.utils.f5_tts.thirdparty.LangSegment import LangSegment
16
+
17
+
18
+ class PhonemeBpeTokenizer:
19
+ def __init__(
20
+ self, vacab_path="./src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json"
21
+ ):
22
+ self.lang2backend = {
23
+ "zh": "cmn",
24
+ "ja": "ja",
25
+ "en": "en-us",
26
+ "fr": "fr-fr",
27
+ "ko": "ko",
28
+ "de": "de",
29
+ }
30
+ self.text_tokenizers = {}
31
+ self.int_text_tokenizers()
32
+
33
+ with open(vacab_path, "r") as f:
34
+ json_data = f.read()
35
+ data = json.loads(json_data)
36
+ self.vocab = data["vocab"]
37
+ LangSegment.setfilters(["en", "zh", "ja", "ko", "fr", "de"])
38
+
39
+ def int_text_tokenizers(self):
40
+ for key, value in self.lang2backend.items():
41
+ self.text_tokenizers[key] = TextTokenizer(language=value)
42
+
43
+ def tokenize(self, text, sentence, language):
44
+ # 1. convert text to phoneme
45
+ phonemes = []
46
+ if language == "auto":
47
+ seglist = LangSegment.getTexts(text)
48
+ tmp_ph = []
49
+ for seg in seglist:
50
+ tmp_ph.append(
51
+ self._clean_text(
52
+ seg["text"], sentence, seg["lang"], ["cjekfd_cleaners"]
53
+ )
54
+ )
55
+ phonemes = "|_|".join(tmp_ph)
56
+ else:
57
+ phonemes = self._clean_text(text, sentence, language, ["cjekfd_cleaners"])
58
+ # print('clean text: ', phonemes)
59
+
60
+ # 2. tokenize phonemes
61
+ phoneme_tokens = self.phoneme2token(phonemes)
62
+ # print('encode: ', phoneme_tokens)
63
+
64
+ # # 3. decode tokens [optional]
65
+ # decoded_text = self.tokenizer.decode(phoneme_tokens)
66
+ # print('decoded: ', decoded_text)
67
+
68
+ return phonemes, phoneme_tokens
69
+
70
+ def _clean_text(self, text, sentence, language, cleaner_names):
71
+ for name in cleaner_names:
72
+ cleaner = getattr(cleaners, name)
73
+ if not cleaner:
74
+ raise Exception("Unknown cleaner: %s" % name)
75
+ text = cleaner(text, sentence, language, self.text_tokenizers)
76
+ return text
77
+
78
+ def phoneme2token(self, phonemes):
79
+ tokens = []
80
+ if isinstance(phonemes, list):
81
+ for phone in phonemes:
82
+ phone = phone.split("\t")[0]
83
+ phonemes_split = phone.split("|")
84
+ tokens.append(
85
+ [self.vocab[p] for p in phonemes_split if p in self.vocab]
86
+ )
87
+ else:
88
+ phonemes = phonemes.split("\t")[0]
89
+ phonemes_split = phonemes.split("|")
90
+ tokens = [self.vocab[p] for p in phonemes_split if p in self.vocab]
91
+ return tokens
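A hedged usage sketch of the tokenizer above (not part of the commit): it assumes espeak-ng is installed for the phonemizer backends and that the script runs from the repo root so the default vocab.json path resolves.

```python
from src.YingMusicSinger.utils.f5_tts.g2p.g2p import PhonemeBpeTokenizer

tokenizer = PhonemeBpeTokenizer()   # loads vocab.json and one TextTokenizer per supported language
text = "hello world"
phonemes, phoneme_tokens = tokenizer.tokenize(text, sentence=text, language="en")
print(phonemes)         # "|"-separated IPA symbols
print(phoneme_tokens)   # integer ids looked up in vocab.json
```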
src/YingMusicSinger/utils/f5_tts/g2p/g2p/chinese_model_g2p.py ADDED
@@ -0,0 +1,209 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+
9
+ import numpy as np
10
+ import torch
11
+ from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
12
+ from torch.utils.data import DataLoader, Dataset
13
+ from transformers import BertTokenizer
14
+ from transformers.models.bert.modeling_bert import *
15
+
16
+
17
+ class PolyDataset(Dataset):
18
+ def __init__(self, words, labels, word_pad_idx=0, label_pad_idx=-1):
19
+ self.dataset = self.preprocess(words, labels)
20
+ self.word_pad_idx = word_pad_idx
21
+ self.label_pad_idx = label_pad_idx
22
+
23
+ def preprocess(self, origin_sentences, origin_labels):
24
+ """
25
+ Maps tokens and tags to their indices and stores them in the dict data.
26
+ examples:
27
+ word:['[CLS]', '浙', '商', '银', '行', '企', '业', '信', '贷', '部']
28
+ sentence:([101, 3851, 1555, 7213, 6121, 821, 689, 928, 6587, 6956],
29
+ array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
30
+ label:[3, 13, 13, 13, 0, 0, 0, 0, 0]
31
+ """
32
+ data = []
33
+ labels = []
34
+ sentences = []
35
+ # tokenize
36
+ for line in origin_sentences:
37
+ # replace each token by its index
38
+ # we can not use encode_plus because our sentences are aligned to labels in list type
39
+ words = []
40
+ word_lens = []
41
+ for token in line:
42
+ words.append(token)
43
+ word_lens.append(1)
44
+ token_start_idxs = 1 + np.cumsum([0] + word_lens[:-1])
45
+ sentences.append(((words, token_start_idxs), 0))
46
+ ###
47
+ for tag in origin_labels:
48
+ labels.append(tag)
49
+
50
+ for sentence, label in zip(sentences, labels):
51
+ data.append((sentence, label))
52
+ return data
53
+
54
+ def __getitem__(self, idx):
55
+ """sample data to get batch"""
56
+ word = self.dataset[idx][0]
57
+ label = self.dataset[idx][1]
58
+ return [word, label]
59
+
60
+ def __len__(self):
61
+ """get dataset size"""
62
+ return len(self.dataset)
63
+
64
+ def collate_fn(self, batch):
65
+ sentences = [x[0][0] for x in batch]
66
+ ori_sents = [x[0][1] for x in batch]
67
+ labels = [x[1] for x in batch]
68
+ batch_len = len(sentences)
69
+
70
+ # compute length of longest sentence in batch
71
+ max_len = max([len(s[0]) for s in sentences])
72
+ max_label_len = 0
73
+ batch_data = np.ones((batch_len, max_len))
74
+ batch_label_starts = []
75
+
76
+ # padding and aligning
77
+ for j in range(batch_len):
78
+ cur_len = len(sentences[j][0])
79
+ batch_data[j][:cur_len] = sentences[j][0]
80
+ label_start_idx = sentences[j][-1]
81
+ label_starts = np.zeros(max_len)
82
+ label_starts[[idx for idx in label_start_idx if idx < max_len]] = 1
83
+ batch_label_starts.append(label_starts)
84
+ max_label_len = max(int(sum(label_starts)), max_label_len)
85
+
86
+ # padding label
87
+ batch_labels = self.label_pad_idx * np.ones((batch_len, max_label_len))
88
+ batch_pmasks = self.label_pad_idx * np.ones((batch_len, max_label_len))
89
+ for j in range(batch_len):
90
+ cur_tags_len = len(labels[j])
91
+ batch_labels[j][:cur_tags_len] = labels[j]
92
+ batch_pmasks[j][:cur_tags_len] = [
93
+ 1 if item > 0 else 0 for item in labels[j]
94
+ ]
95
+
96
+ # convert data to torch LongTensors
97
+ batch_data = torch.tensor(batch_data, dtype=torch.long)
98
+ batch_label_starts = torch.tensor(batch_label_starts, dtype=torch.long)
99
+ batch_labels = torch.tensor(batch_labels, dtype=torch.long)
100
+ batch_pmasks = torch.tensor(batch_pmasks, dtype=torch.long)
101
+ return [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
102
+
103
+
104
+ class BertPolyPredict:
105
+ def __init__(self, bert_model, jsonr_file, json_file):
106
+ self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)
107
+ with open(jsonr_file, "r", encoding="utf8") as fp:
108
+ self.pron_dict = json.load(fp)
109
+ with open(json_file, "r", encoding="utf8") as fp:
110
+ self.pron_dict_id_2_pinyin = json.load(fp)
111
+ self.num_polyphone = len(self.pron_dict)
112
+ self.device = "cpu"
113
+ self.polydataset = PolyDataset
114
+ options = SessionOptions() # initialize session options
115
+ options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
116
+ print(os.path.join(bert_model, "poly_bert_model.onnx"))
117
+ self.session = InferenceSession(
118
+ os.path.join(bert_model, "poly_bert_model.onnx"),
119
+ sess_options=options,
120
+ providers=[
121
+ "CPUExecutionProvider",
122
+ "CUDAExecutionProvider",
123
+ ], # CPUExecutionProvider #CUDAExecutionProvider
124
+ )
125
+ # self.session.set_providers(['CUDAExecutionProvider', "CPUExecutionProvider"], [ {'device_id': 0}])
126
+
127
+ # disable session.run() fallback mechanism, it prevents for a reset of the execution provider
128
+ self.session.disable_fallback()
129
+
130
+ def predict_process(self, txt_list):
131
+ word_test, label_test, texts_test = self.get_examples_po(txt_list)
132
+ data = self.polydataset(word_test, label_test)
133
+ predict_loader = DataLoader(
134
+ data, batch_size=1, shuffle=False, collate_fn=data.collate_fn
135
+ )
136
+ pred_tags = self.predict_onnx(predict_loader)
137
+ return pred_tags
138
+
139
+ def predict_onnx(self, dev_loader):
140
+ pred_tags = []
141
+ with torch.no_grad():
142
+ for idx, batch_samples in enumerate(dev_loader):
143
+ # [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
144
+ batch_data, batch_label_starts, batch_labels, batch_pmasks, _ = (
145
+ batch_samples
146
+ )
147
+ # shift tensors to GPU if available
148
+ batch_data = batch_data.to(self.device)
149
+ batch_label_starts = batch_label_starts.to(self.device)
150
+ batch_labels = batch_labels.to(self.device)
151
+ batch_pmasks = batch_pmasks.to(self.device)
152
+ batch_data = np.asarray(batch_data, dtype=np.int32)
153
+ batch_pmasks = np.asarray(batch_pmasks, dtype=np.int32)
154
+ # batch_output = self.session.run(output_names=['outputs'], input_feed={"input_ids":batch_data, "input_pmasks": batch_pmasks})[0][0]
155
+ batch_output = self.session.run(
156
+ output_names=["outputs"], input_feed={"input_ids": batch_data}
157
+ )[0]
158
+ label_masks = batch_pmasks == 1
159
+ batch_labels = batch_labels.to("cpu").numpy()
160
+ for i, indices in enumerate(np.argmax(batch_output, axis=2)):
161
+ for j, idx in enumerate(indices):
162
+ if label_masks[i][j]:
163
+ # pred_tag.append(idx)
164
+ pred_tags.append(self.pron_dict_id_2_pinyin[str(idx + 1)])
165
+ return pred_tags
166
+
167
+ def get_examples_po(self, text_list):
168
+ word_list = []
169
+ label_list = []
170
+ sentence_list = []
171
+ id = 0
172
+ for line in [text_list]:
173
+ sentence = line[0]
174
+ words = []
175
+ tokens = line[0]
176
+ index = line[-1]
177
+ front = index
178
+ back = len(tokens) - index - 1
179
+ labels = [0] * front + [1] + [0] * back
180
+ words = ["[CLS]"] + [item for item in sentence]
181
+ words = self.tokenizer.convert_tokens_to_ids(words)
182
+ word_list.append(words)
183
+ label_list.append(labels)
184
+ sentence_list.append(sentence)
185
+
186
+ id += 1
187
+ # mask_list.append(masks)
188
+ assert len(labels) + 1 == len(words), print(
189
+ (
190
+ poly,
191
+ sentence,
192
+ words,
193
+ labels,
194
+ sentence,
195
+ len(sentence),
196
+ len(words),
197
+ len(labels),
198
+ )
199
+ )
200
+ assert len(labels) + 1 == len(words), (
201
+ "Number of labels does not match number of words"
202
+ )
203
+ assert len(labels) == len(sentence), (
204
+ "Number of labels does not match number of sentences"
205
+ )
206
+ assert len(word_list) == len(label_list), (
207
+ "Number of label sentences does not match number of word sentences"
208
+ )
209
+ return word_list, label_list, text_list
src/YingMusicSinger/utils/f5_tts/g2p/g2p/cleaners.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from src.YingMusicSinger.utils.f5_tts.g2p.g2p.english import english_to_ipa
7
+ from src.YingMusicSinger.utils.f5_tts.g2p.g2p.french import french_to_ipa
8
+ from src.YingMusicSinger.utils.f5_tts.g2p.g2p.german import german_to_ipa
9
+ from src.YingMusicSinger.utils.f5_tts.g2p.g2p.korean import korean_to_ipa
10
+ from src.YingMusicSinger.utils.f5_tts.g2p.g2p.mandarin import chinese_to_ipa
11
+
12
+
13
+ def cjekfd_cleaners(text, sentence, language, text_tokenizers):
14
+ if language == "zh":
15
+ return chinese_to_ipa(text, sentence, text_tokenizers["zh"])
16
+ elif language == "ja":
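+ # NOTE: japanese_to_ipa is not imported above, so language == "ja" will raise a NameError unless it is provided elsewhere.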
17
+ return japanese_to_ipa(text, text_tokenizers["ja"])
18
+ elif language == "en":
19
+ return english_to_ipa(text, text_tokenizers["en"])
20
+ elif language == "fr":
21
+ return french_to_ipa(text, text_tokenizers["fr"])
22
+ elif language == "ko":
23
+ return korean_to_ipa(text, text_tokenizers["ko"])
24
+ elif language == "de":
25
+ return german_to_ipa(text, text_tokenizers["de"])
26
+ else:
27
+ raise Exception("Unknown language: %s" % language)
28
+ return None
src/YingMusicSinger/utils/f5_tts/g2p/g2p/english.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ import inflect
9
+
10
+ """
11
+ Text clean time
12
+ """
13
+ _inflect = inflect.engine()
14
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
15
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
16
+ _percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
17
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
18
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
19
+ _fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
20
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
21
+ _number_re = re.compile(r"[0-9]+")
22
+
23
+ # List of (regular expression, replacement) pairs for abbreviations:
24
+ _abbreviations = [
25
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
26
+ for x in [
27
+ ("mrs", "misess"),
28
+ ("mr", "mister"),
29
+ ("dr", "doctor"),
30
+ ("st", "saint"),
31
+ ("co", "company"),
32
+ ("jr", "junior"),
33
+ ("maj", "major"),
34
+ ("gen", "general"),
35
+ ("drs", "doctors"),
36
+ ("rev", "reverend"),
37
+ ("lt", "lieutenant"),
38
+ ("hon", "honorable"),
39
+ ("sgt", "sergeant"),
40
+ ("capt", "captain"),
41
+ ("esq", "esquire"),
42
+ ("ltd", "limited"),
43
+ ("col", "colonel"),
44
+ ("ft", "fort"),
45
+ ("etc", "et cetera"),
46
+ ("btw", "by the way"),
47
+ ]
48
+ ]
49
+
50
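+ # Post-processing map applied to espeak output: merges phone pairs split by the "|" separator (e.g. "t|ɹ" -> "tɹ") and folds rare symbols onto phones present in vocab.json.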
+ _special_map = [
51
+ ("t|ɹ", "tɹ"),
52
+ ("d|ɹ", "dɹ"),
53
+ ("t|s", "ts"),
54
+ ("d|z", "dz"),
55
+ ("ɪ|ɹ", "ɪɹ"),
56
+ ("ɐ", "ɚ"),
57
+ ("ᵻ", "ɪ"),
58
+ ("əl", "l"),
59
+ ("x", "k"),
60
+ ("ɬ", "l"),
61
+ ("ʔ", "t"),
62
+ ("n̩", "n"),
63
+ ("oː|ɹ", "oːɹ"),
64
+ ]
65
+
66
+
67
+ def expand_abbreviations(text):
68
+ for regex, replacement in _abbreviations:
69
+ text = re.sub(regex, replacement, text)
70
+ return text
71
+
72
+
73
+ def _remove_commas(m):
74
+ return m.group(1).replace(",", "")
75
+
76
+
77
+ def _expand_decimal_point(m):
78
+ return m.group(1).replace(".", " point ")
79
+
80
+
81
+ def _expand_percent(m):
82
+ return m.group(1).replace("%", " percent ")
83
+
84
+
85
+ def _expand_dollars(m):
86
+ match = m.group(1)
87
+ parts = match.split(".")
88
+ if len(parts) > 2:
89
+ return " " + match + " dollars " # Unexpected format
90
+ dollars = int(parts[0]) if parts[0] else 0
91
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
92
+ if dollars and cents:
93
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
94
+ cent_unit = "cent" if cents == 1 else "cents"
95
+ return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
96
+ elif dollars:
97
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
98
+ return " %s %s " % (dollars, dollar_unit)
99
+ elif cents:
100
+ cent_unit = "cent" if cents == 1 else "cents"
101
+ return " %s %s " % (cents, cent_unit)
102
+ else:
103
+ return " zero dollars "
104
+
105
+
106
+ def fraction_to_words(numerator, denominator):
107
+ if numerator == 1 and denominator == 2:
108
+ return " one half "
109
+ if numerator == 1 and denominator == 4:
110
+ return " one quarter "
111
+ if denominator == 2:
112
+ return " " + _inflect.number_to_words(numerator) + " halves "
113
+ if denominator == 4:
114
+ return " " + _inflect.number_to_words(numerator) + " quarters "
115
+ return (
116
+ " "
117
+ + _inflect.number_to_words(numerator)
118
+ + " "
119
+ + _inflect.ordinal(_inflect.number_to_words(denominator))
120
+ + " "
121
+ )
122
+
123
+
124
+ def _expand_fraction(m):
125
+ numerator = int(m.group(1))
126
+ denominator = int(m.group(2))
127
+ return fraction_to_words(numerator, denominator)
128
+
129
+
130
+ def _expand_ordinal(m):
131
+ return " " + _inflect.number_to_words(m.group(0)) + " "
132
+
133
+
134
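+ # Integers between 1000 and 3000 are read year-style (e.g. 1999 -> "nineteen ninety nine"); other integers fall back to inflect's standard wording.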
+ def _expand_number(m):
135
+ num = int(m.group(0))
136
+ if num > 1000 and num < 3000:
137
+ if num == 2000:
138
+ return " two thousand "
139
+ elif num > 2000 and num < 2010:
140
+ return " two thousand " + _inflect.number_to_words(num % 100) + " "
141
+ elif num % 100 == 0:
142
+ return " " + _inflect.number_to_words(num // 100) + " hundred "
143
+ else:
144
+ return (
145
+ " "
146
+ + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
147
+ ", ", " "
148
+ )
149
+ + " "
150
+ )
151
+ else:
152
+ return " " + _inflect.number_to_words(num, andword="") + " "
153
+
154
+
155
+ # Normalize numbers pronunciation
156
+ def normalize_numbers(text):
157
+ text = re.sub(_comma_number_re, _remove_commas, text)
158
+ text = re.sub(_pounds_re, r"\1 pounds", text)
159
+ text = re.sub(_dollars_re, _expand_dollars, text)
160
+ text = re.sub(_fraction_re, _expand_fraction, text)
161
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
162
+ text = re.sub(_percent_number_re, _expand_percent, text)
163
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
164
+ text = re.sub(_number_re, _expand_number, text)
165
+ return text
166
+
167
+
168
+ def _english_to_ipa(text):
169
+ # text = unidecode(text).lower()
170
+ text = expand_abbreviations(text)
171
+ text = normalize_numbers(text)
172
+ return text
173
+
174
+
175
+ # special map
176
+ def special_map(text):
177
+ for regex, replacement in _special_map:
178
+ regex = regex.replace("|", r"\|")
179
+ while re.search(r"(^|[_|]){}([_|]|$)".format(regex), text):
180
+ text = re.sub(
181
+ r"(^|[_|]){}([_|]|$)".format(regex), r"\1{}\2".format(replacement), text
182
+ )
183
+ # text = re.sub(r'([,.!?])', r'|\1', text)
184
+ return text
185
+
186
+
187
+ # Add some special operation
188
+ def english_to_ipa(text, text_tokenizer):
189
+ if type(text) == str:
190
+ text = _english_to_ipa(text)
191
+ else:
192
+ text = [_english_to_ipa(t) for t in text]
193
+ phonemes = text_tokenizer(text)
194
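+ # If the phoneme string ends in a phone rather than punctuation, append a trailing word-boundary marker "_".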
+ if phonemes[-1] in "p⁼ʰmftnlkxʃs`ɹaoəɛɪeɑʊŋiuɥwæjː":
195
+ phonemes += "|_"
196
+ if type(text) == str:
197
+ return special_map(phonemes)
198
+ else:
199
+ result_ph = []
200
+ for phone in phonemes:
201
+ result_ph.append(special_map(phone))
202
+ return result_ph
src/YingMusicSinger/utils/f5_tts/g2p/g2p/french.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ """
9
+ Text clean time
10
+ """
11
+ # List of (regular expression, replacement) pairs for abbreviations in french:
12
+ _abbreviations = [
13
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
14
+ for x in [
15
+ ("M", "monsieur"),
16
+ ("Mlle", "mademoiselle"),
17
+ ("Mlles", "mesdemoiselles"),
18
+ ("Mme", "Madame"),
19
+ ("Mmes", "Mesdames"),
20
+ ("N.B", "nota bene"),
21
+ ("M", "monsieur"),
22
+ ("p.c.q", "parce que"),
23
+ ("Pr", "professeur"),
24
+ ("qqch", "quelque chose"),
25
+ ("rdv", "rendez-vous"),
26
+ ("max", "maximum"),
27
+ ("min", "minimum"),
28
+ ("no", "numéro"),
29
+ ("adr", "adresse"),
30
+ ("dr", "docteur"),
31
+ ("st", "saint"),
32
+ ("co", "companie"),
33
+ ("jr", "junior"),
34
+ ("sgt", "sergent"),
35
+ ("capt", "capitain"),
36
+ ("col", "colonel"),
37
+ ("av", "avenue"),
38
+ ("av. J.-C", "avant Jésus-Christ"),
39
+ ("apr. J.-C", "après Jésus-Christ"),
40
+ ("art", "article"),
41
+ ("boul", "boulevard"),
42
+ ("c.-à-d", "c’est-à-dire"),
43
+ ("etc", "et cetera"),
44
+ ("ex", "exemple"),
45
+ ("excl", "exclusivement"),
46
+ ("boul", "boulevard"),
47
+ ]
48
+ ] + [
49
+ (re.compile("\\b%s" % x[0]), x[1])
50
+ for x in [
51
+ ("Mlle", "mademoiselle"),
52
+ ("Mlles", "mesdemoiselles"),
53
+ ("Mme", "Madame"),
54
+ ("Mmes", "Mesdames"),
55
+ ]
56
+ ]
57
+
58
+ rep_map = {
59
+ ":": ",",
60
+ ";": ",",
61
+ ",": ",",
62
+ "。": ".",
63
+ "!": "!",
64
+ "?": "?",
65
+ "\n": ".",
66
+ "·": ",",
67
+ "、": ",",
68
+ "...": ".",
69
+ "…": ".",
70
+ "$": ".",
71
+ "“": "",
72
+ "”": "",
73
+ "‘": "",
74
+ "’": "",
75
+ "(": "",
76
+ ")": "",
77
+ "(": "",
78
+ ")": "",
79
+ "《": "",
80
+ "》": "",
81
+ "【": "",
82
+ "】": "",
83
+ "[": "",
84
+ "]": "",
85
+ "—": "",
86
+ "~": "-",
87
+ "~": "-",
88
+ "「": "",
89
+ "」": "",
90
+ "¿": "",
91
+ "¡": "",
92
+ }
93
+
94
+
95
+ def collapse_whitespace(text):
96
+ # Regular expression matching whitespace:
97
+ _whitespace_re = re.compile(r"\s+")
98
+ return re.sub(_whitespace_re, " ", text).strip()
99
+
100
+
101
+ def remove_punctuation_at_begin(text):
102
+ return re.sub(r"^[,.!?]+", "", text)
103
+
104
+
105
+ def remove_aux_symbols(text):
106
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
107
+ return text
108
+
109
+
110
+ def replace_symbols(text):
111
+ text = text.replace(";", ",")
112
+ text = text.replace("-", " ")
113
+ text = text.replace(":", ",")
114
+ text = text.replace("&", " et ")
115
+ return text
116
+
117
+
118
+ def expand_abbreviations(text):
119
+ for regex, replacement in _abbreviations:
120
+ text = re.sub(regex, replacement, text)
121
+ return text
122
+
123
+
124
+ def replace_punctuation(text):
125
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
126
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
127
+ return replaced_text
128
+
129
+
130
+ def text_normalize(text):
131
+ text = expand_abbreviations(text)
132
+ text = replace_punctuation(text)
133
+ text = replace_symbols(text)
134
+ text = remove_aux_symbols(text)
135
+ text = remove_punctuation_at_begin(text)
136
+ text = collapse_whitespace(text)
137
+ text = re.sub(r"([^\.,!\?\-…])$", r"\1", text)
138
+ return text
139
+
140
+
141
+ def french_to_ipa(text, text_tokenizer):
142
+ if type(text) == str:
143
+ text = text_normalize(text)
144
+ phonemes = text_tokenizer(text)
145
+ return phonemes
146
+ else:
147
+ for i, t in enumerate(text):
148
+ text[i] = text_normalize(t)
149
+ return text_tokenizer(text)
src/YingMusicSinger/utils/f5_tts/g2p/g2p/german.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ """
9
+ Text clean time
10
+ """
11
+ rep_map = {
12
+ ":": ",",
13
+ ";": ",",
14
+ ",": ",",
15
+ "。": ".",
16
+ "!": "!",
17
+ "?": "?",
18
+ "\n": ".",
19
+ "·": ",",
20
+ "、": ",",
21
+ "...": ".",
22
+ "…": ".",
23
+ "$": ".",
24
+ "“": "",
25
+ "”": "",
26
+ "‘": "",
27
+ "’": "",
28
+ "(": "",
29
+ ")": "",
30
+ "(": "",
31
+ ")": "",
32
+ "《": "",
33
+ "》": "",
34
+ "【": "",
35
+ "】": "",
36
+ "[": "",
37
+ "]": "",
38
+ "—": "",
39
+ "~": "-",
40
+ "~": "-",
41
+ "「": "",
42
+ "」": "",
43
+ "¿": "",
44
+ "¡": "",
45
+ }
46
+
47
+
48
+ def collapse_whitespace(text):
49
+ # Regular expression matching whitespace:
50
+ _whitespace_re = re.compile(r"\s+")
51
+ return re.sub(_whitespace_re, " ", text).strip()
52
+
53
+
54
+ def remove_punctuation_at_begin(text):
55
+ return re.sub(r"^[,.!?]+", "", text)
56
+
57
+
58
+ def remove_aux_symbols(text):
59
+ text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
60
+ return text
61
+
62
+
63
+ def replace_symbols(text):
64
+ text = text.replace(";", ",")
65
+ text = text.replace("-", " ")
66
+ text = text.replace(":", ",")
67
+ return text
68
+
69
+
70
+ def replace_punctuation(text):
71
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
72
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
73
+ return replaced_text
74
+
75
+
76
+ def text_normalize(text):
77
+ text = replace_punctuation(text)
78
+ text = replace_symbols(text)
79
+ text = remove_aux_symbols(text)
80
+ text = remove_punctuation_at_begin(text)
81
+ text = collapse_whitespace(text)
82
+ text = re.sub(r"([^\.,!\?\-…])$", r"\1", text)
83
+ return text
84
+
85
+
86
+ def german_to_ipa(text, text_tokenizer):
87
+ if type(text) == str:
88
+ text = text_normalize(text)
89
+ phonemes = text_tokenizer(text)
90
+ return phonemes
91
+ else:
92
+ for i, t in enumerate(text):
93
+ text[i] = text_normalize(t)
94
+ return text_tokenizer(text)
src/YingMusicSinger/utils/f5_tts/g2p/g2p/korean.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ """
9
+ Text clean time
10
+ """
11
+ english_dictionary = {
12
+ "KOREA": "코리아",
13
+ "IDOL": "아이돌",
14
+ "IT": "아이티",
15
+ "IQ": "아이큐",
16
+ "UP": "업",
17
+ "DOWN": "다운",
18
+ "PC": "피씨",
19
+ "CCTV": "씨씨티비",
20
+ "SNS": "에스엔에스",
21
+ "AI": "에이아이",
22
+ "CEO": "씨이오",
23
+ "A": "에이",
24
+ "B": "비",
25
+ "C": "씨",
26
+ "D": "디",
27
+ "E": "이",
28
+ "F": "에프",
29
+ "G": "지",
30
+ "H": "에이치",
31
+ "I": "아이",
32
+ "J": "제이",
33
+ "K": "케이",
34
+ "L": "엘",
35
+ "M": "엠",
36
+ "N": "엔",
37
+ "O": "오",
38
+ "P": "피",
39
+ "Q": "큐",
40
+ "R": "알",
41
+ "S": "에스",
42
+ "T": "티",
43
+ "U": "유",
44
+ "V": "브이",
45
+ "W": "더블유",
46
+ "X": "엑스",
47
+ "Y": "와이",
48
+ "Z": "제트",
49
+ }
50
+
51
+
52
+ def normalize(text):
53
+ text = text.strip()
54
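+ # Remove CJK ideographs (hanja and radicals) before mapping Latin words and letters to their hangul readings.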
+ text = re.sub(
55
+ "[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text
56
+ )
57
+ text = normalize_english(text)
58
+ text = text.lower()
59
+ return text
60
+
61
+
62
+ def normalize_english(text):
63
+ def fn(m):
64
+ word = m.group()
65
+ if word in english_dictionary:
66
+ return english_dictionary.get(word)
67
+ return word
68
+
69
+ text = re.sub("([A-Za-z]+)", fn, text)
70
+ return text
71
+
72
+
73
+ def korean_to_ipa(text, text_tokenizer):
74
+ if type(text) == str:
75
+ text = normalize(text)
76
+ phonemes = text_tokenizer(text)
77
+ return phonemes
78
+ else:
79
+ for i, t in enumerate(text):
80
+ text[i] = normalize(t)
81
+ return text_tokenizer(text)
src/YingMusicSinger/utils/f5_tts/g2p/g2p/mandarin.py ADDED
@@ -0,0 +1,603 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import re
8
+ from typing import List
9
+
10
+ import cn2an
11
+ import jieba
12
+ from pypinyin import BOPOMOFO, lazy_pinyin
13
+
14
+ from src.YingMusicSinger.utils.f5_tts.g2p.g2p.chinese_model_g2p import BertPolyPredict
15
+ from src.YingMusicSinger.utils.f5_tts.g2p.utils.front_utils import *
16
+
17
+ # from g2pw import G2PWConverter
18
+
19
+
20
+ # set blank level, {0:"none",1:"char", 2:"word"}
21
+ BLANK_LEVEL = 0
22
+
23
+ # conv = G2PWConverter(style='pinyin', enable_non_tradional_chinese=True)
24
+ resource_path = r"./src/YingMusicSinger/utils/f5_tts/g2p"
25
+ poly_all_class_path = os.path.join(
26
+ resource_path, "sources", "g2p_chinese_model", "polychar.txt"
27
+ )
28
+ if not os.path.exists(poly_all_class_path):
29
+ print(
30
+ "Incorrect path for polyphonic character class dictionary: {}, please check...".format(
31
+ poly_all_class_path
32
+ )
33
+ )
34
+ exit()
35
+ poly_dict = generate_poly_lexicon(poly_all_class_path)
36
+
37
+ # Set up G2PW model parameters
38
+ g2pw_poly_model_path = os.path.join(resource_path, "sources", "g2p_chinese_model")
39
+ if not os.path.exists(g2pw_poly_model_path):
40
+ print(
41
+ "Incorrect path for g2pw polyphonic character model: {}, please check...".format(
42
+ g2pw_poly_model_path
43
+ )
44
+ )
45
+ exit()
46
+
47
+ json_file_path = os.path.join(
48
+ resource_path, "sources", "g2p_chinese_model", "polydict.json"
49
+ )
50
+ if not os.path.exists(json_file_path):
51
+ print(
52
+ "Incorrect path for g2pw id to pinyin dictionary: {}, please check...".format(
53
+ json_file_path
54
+ )
55
+ )
56
+ exit()
57
+
58
+ jsonr_file_path = os.path.join(
59
+ resource_path, "sources", "g2p_chinese_model", "polydict_r.json"
60
+ )
61
+ if not os.path.exists(jsonr_file_path):
62
+ print(
63
+ "Incorrect path for g2pw pinyin to id dictionary: {}, please check...".format(
64
+ jsonr_file_path
65
+ )
66
+ )
67
+ exit()
68
+
69
+ g2pw_poly_predict = BertPolyPredict(
70
+ g2pw_poly_model_path, jsonr_file_path, json_file_path
71
+ )
72
+
73
+
74
+ """
75
+ Text clean time
76
+ """
77
+ # List of (Latin alphabet, bopomofo) pairs:
78
+ _latin_to_bopomofo = [
79
+ (re.compile("%s" % x[0], re.IGNORECASE), x[1])
80
+ for x in [
81
+ ("a", "ㄟˉ"),
82
+ ("b", "ㄅㄧˋ"),
83
+ ("c", "ㄙㄧˉ"),
84
+ ("d", "ㄉㄧˋ"),
85
+ ("e", "ㄧˋ"),
86
+ ("f", "ㄝˊㄈㄨˋ"),
87
+ ("g", "ㄐㄧˋ"),
88
+ ("h", "ㄝˇㄑㄩˋ"),
89
+ ("i", "ㄞˋ"),
90
+ ("j", "ㄐㄟˋ"),
91
+ ("k", "ㄎㄟˋ"),
92
+ ("l", "ㄝˊㄛˋ"),
93
+ ("m", "ㄝˊㄇㄨˋ"),
94
+ ("n", "ㄣˉ"),
95
+ ("o", "ㄡˉ"),
96
+ ("p", "ㄆㄧˉ"),
97
+ ("q", "ㄎㄧㄡˉ"),
98
+ ("r", "ㄚˋ"),
99
+ ("s", "ㄝˊㄙˋ"),
100
+ ("t", "ㄊㄧˋ"),
101
+ ("u", "ㄧㄡˉ"),
102
+ ("v", "ㄨㄧˉ"),
103
+ ("w", "ㄉㄚˋㄅㄨˋㄌㄧㄡˋ"),
104
+ ("x", "ㄝˉㄎㄨˋㄙˋ"),
105
+ ("y", "ㄨㄞˋ"),
106
+ ("z", "ㄗㄟˋ"),
107
+ ]
108
+ ]
109
+
110
+ # List of (bopomofo, ipa) pairs:
111
+ _bopomofo_to_ipa = [
112
+ (re.compile("%s" % x[0]), x[1])
113
+ for x in [
114
+ ("ㄅㄛ", "p⁼wo"),
115
+ ("ㄆㄛ", "pʰwo"),
116
+ ("ㄇㄛ", "mwo"),
117
+ ("ㄈㄛ", "fwo"),
118
+ ("ㄧㄢ", "|jɛn"),
119
+ ("ㄩㄢ", "|ɥæn"),
120
+ ("ㄧㄣ", "|in"),
121
+ ("ㄩㄣ", "|ɥn"),
122
+ ("ㄧㄥ", "|iŋ"),
123
+ ("ㄨㄥ", "|ʊŋ"),
124
+ ("ㄩㄥ", "|jʊŋ"),
125
+ # Add
126
+ ("ㄧㄚ", "|ia"),
127
+ ("ㄧㄝ", "|iɛ"),
128
+ ("ㄧㄠ", "|iɑʊ"),
129
+ ("ㄧㄡ", "|ioʊ"),
130
+ ("ㄧㄤ", "|iɑŋ"),
131
+ ("ㄨㄚ", "|ua"),
132
+ ("ㄨㄛ", "|uo"),
133
+ ("ㄨㄞ", "|uaɪ"),
134
+ ("ㄨㄟ", "|ueɪ"),
135
+ ("ㄨㄢ", "|uan"),
136
+ ("ㄨㄣ", "|uən"),
137
+ ("ㄨㄤ", "|uɑŋ"),
138
+ ("ㄩㄝ", "|ɥɛ"),
139
+ # End
140
+ ("ㄅ", "p⁼"),
141
+ ("ㄆ", "pʰ"),
142
+ ("ㄇ", "m"),
143
+ ("ㄈ", "f"),
144
+ ("ㄉ", "t⁼"),
145
+ ("ㄊ", "tʰ"),
146
+ ("ㄋ", "n"),
147
+ ("ㄌ", "l"),
148
+ ("ㄍ", "k⁼"),
149
+ ("ㄎ", "kʰ"),
150
+ ("ㄏ", "x"),
151
+ ("ㄐ", "tʃ⁼"),
152
+ ("ㄑ", "tʃʰ"),
153
+ ("ㄒ", "ʃ"),
154
+ ("ㄓ", "ts`⁼"),
155
+ ("ㄔ", "ts`ʰ"),
156
+ ("ㄕ", "s`"),
157
+ ("ㄖ", "ɹ`"),
158
+ ("ㄗ", "ts⁼"),
159
+ ("ㄘ", "tsʰ"),
160
+ ("ㄙ", "|s"),
161
+ ("ㄚ", "|a"),
162
+ ("ㄛ", "|o"),
163
+ ("ㄜ", "|ə"),
164
+ ("ㄝ", "|ɛ"),
165
+ ("ㄞ", "|aɪ"),
166
+ ("ㄟ", "|eɪ"),
167
+ ("ㄠ", "|ɑʊ"),
168
+ ("ㄡ", "|oʊ"),
169
+ ("ㄢ", "|an"),
170
+ ("ㄣ", "|ən"),
171
+ ("ㄤ", "|ɑŋ"),
172
+ ("ㄥ", "|əŋ"),
173
+ ("ㄦ", "əɹ"),
174
+ ("ㄧ", "|i"),
175
+ ("ㄨ", "|u"),
176
+ ("ㄩ", "|ɥ"),
177
+ ("ˉ", "→|"),
178
+ ("ˊ", "↑|"),
179
+ ("ˇ", "↓↑|"),
180
+ ("ˋ", "↓|"),
181
+ ("˙", "|"),
182
+ ]
183
+ ]
184
+ must_not_er_words = {"女儿", "老儿", "男儿", "少儿", "小儿"}
185
+
186
+ word_pinyin_dict = {}
187
+ with open(
188
+ r"src/YingMusicSinger/utils/f5_tts/g2p/sources/chinese_lexicon.txt",
189
+ "r",
190
+ encoding="utf-8",
191
+ ) as fread:
192
+ txt_list = fread.readlines()
193
+ for txt in txt_list:
194
+ word, pinyin = txt.strip().split("\t")
195
+ word_pinyin_dict[word] = pinyin
196
+ fread.close()
197
+
198
+ pinyin_2_bopomofo_dict = {}
199
+ with open(
200
+ r"./src/YingMusicSinger/utils/f5_tts/g2p/sources/pinyin_2_bpmf.txt",
201
+ "r",
202
+ encoding="utf-8",
203
+ ) as fread:
204
+ txt_list = fread.readlines()
205
+ for txt in txt_list:
206
+ pinyin, bopomofo = txt.strip().split("\t")
207
+ pinyin_2_bopomofo_dict[pinyin] = bopomofo
208
+ fread.close()
209
+
210
+ tone_dict = {
211
+ "0": "˙",
212
+ "5": "˙",
213
+ "1": "",
214
+ "2": "ˊ",
215
+ "3": "ˇ",
216
+ "4": "ˋ",
217
+ }
218
+
219
+ bopomofos2pinyin_dict = {}
220
+ with open(
221
+ r"./src/YingMusicSinger/utils/f5_tts/g2p/sources/bpmf_2_pinyin.txt",
222
+ "r",
223
+ encoding="utf-8",
224
+ ) as fread:
225
+ txt_list = fread.readlines()
226
+ for txt in txt_list:
227
+ v, k = txt.strip().split("\t")
228
+ bopomofos2pinyin_dict[k] = v
229
+ fread.close()
230
+
231
+
232
+ def bpmf_to_pinyin(text):
233
+ bopomofo_list = text.split("|")
234
+ pinyin_list = []
235
+ for info in bopomofo_list:
236
+ pinyin = ""
237
+ for c in info:
238
+ if c in bopomofos2pinyin_dict:
239
+ pinyin += bopomofos2pinyin_dict[c]
240
+ if len(pinyin) == 0:
241
+ continue
242
+ if pinyin[-1] not in "01234":
243
+ pinyin += "1"
244
+ if pinyin[:-1] == "ve":
245
+ pinyin = "y" + pinyin
246
+ if pinyin[:-1] == "sh":
247
+ pinyin = pinyin[:-1] + "i" + pinyin[-1]
248
+ if pinyin == "sh":
249
+ pinyin = pinyin[:-1] + "i"
250
+ if pinyin[:-1] == "s":
251
+ pinyin = "si" + pinyin[-1]
252
+ if pinyin[:-1] == "c":
253
+ pinyin = "ci" + pinyin[-1]
254
+ if pinyin[:-1] == "i":
255
+ pinyin = "yi" + pinyin[-1]
256
+ if pinyin[:-1] == "iou":
257
+ pinyin = "you" + pinyin[-1]
258
+ if pinyin[:-1] == "ien":
259
+ pinyin = "yin" + pinyin[-1]
260
+ if "iou" in pinyin and pinyin[-4:-1] == "iou":
261
+ pinyin = pinyin[:-4] + "iu" + pinyin[-1]
262
+ if "uei" in pinyin:
263
+ if pinyin[:-1] == "uei":
264
+ pinyin = "wei" + pinyin[-1]
265
+ elif pinyin[-4:-1] == "uei":
266
+ pinyin = pinyin[:-4] + "ui" + pinyin[-1]
267
+ if "uen" in pinyin and pinyin[-4:-1] == "uen":
268
+ if pinyin[:-1] == "uen":
269
+ pinyin = "wen" + pinyin[-1]
270
+ elif pinyin[-4:-1] == "uei":
271
+ pinyin = pinyin[:-4] + "un" + pinyin[-1]
272
+ if "van" in pinyin and pinyin[-4:-1] == "van":
273
+ if pinyin[:-1] == "van":
274
+ pinyin = "yuan" + pinyin[-1]
275
+ elif pinyin[-4:-1] == "van":
276
+ pinyin = pinyin[:-4] + "uan" + pinyin[-1]
277
+ if "ueng" in pinyin and pinyin[-5:-1] == "ueng":
278
+ pinyin = pinyin[:-5] + "ong" + pinyin[-1]
279
+ if pinyin[:-1] == "veng":
280
+ pinyin = "yong" + pinyin[-1]
281
+ if "veng" in pinyin and pinyin[-5:-1] == "veng":
282
+ pinyin = pinyin[:-5] + "iong" + pinyin[-1]
283
+ if pinyin[:-1] == "ieng":
284
+ pinyin = "ying" + pinyin[-1]
285
+ if pinyin[:-1] == "u":
286
+ pinyin = "wu" + pinyin[-1]
287
+ if pinyin[:-1] == "v":
288
+ pinyin = "yv" + pinyin[-1]
289
+ if pinyin[:-1] == "ing":
290
+ pinyin = "ying" + pinyin[-1]
291
+ if pinyin[:-1] == "z":
292
+ pinyin = "zi" + pinyin[-1]
293
+ if pinyin[:-1] == "zh":
294
+ pinyin = "zhi" + pinyin[-1]
295
+ if pinyin[0] == "u":
296
+ pinyin = "w" + pinyin[1:]
297
+ if pinyin[0] == "i":
298
+ pinyin = "y" + pinyin[1:]
299
+ pinyin = pinyin.replace("ien", "in")
300
+
301
+ pinyin_list.append(pinyin)
302
+ return " ".join(pinyin_list)
303
+
304
+
305
+ # Convert numbers to Chinese pronunciation
306
+ def number_to_chinese(text):
307
+ # numbers = re.findall(r'\d+(?:\.?\d+)?', text)
308
+ # for number in numbers:
309
+ # text = text.replace(number, cn2an.an2cn(number), 1)
310
+ text = cn2an.transform(text, "an2cn")
311
+ return text
312
+
313
+
314
+ def normalization(text):
315
+ text = text.replace(",", ",")
316
+ text = text.replace("。", ".")
317
+ text = text.replace("!", "!")
318
+ text = text.replace("?", "?")
319
+ text = text.replace(";", ";")
320
+ text = text.replace(":", ":")
321
+ text = text.replace("、", ",")
322
+ text = text.replace("‘", "'")
323
+ text = text.replace("’", "'")
324
+ text = text.replace("⋯", "…")
325
+ text = text.replace("···", "…")
326
+ text = text.replace("・・・", "…")
327
+ text = text.replace("...", "…")
328
+ text = re.sub(r"\s+", "", text)
329
+ text = re.sub(r"[^\u4e00-\u9fff\s_,\.\?!;:\'…]", "", text)
330
+ text = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", text)
331
+ return text
332
+
333
+
334
+ def change_tone(bopomofo: str, tone: str) -> str:
335
+ if bopomofo[-1] not in "˙ˊˇˋ":
336
+ bopomofo = bopomofo + tone
337
+ else:
338
+ bopomofo = bopomofo[:-1] + tone
339
+ return bopomofo
340
+
341
+
342
+ def er_sandhi(word: str, bopomofos: List[str]) -> List[str]:
343
+ if len(word) > 1 and word[-1] == "儿" and word not in must_not_er_words:
344
+ bopomofos[-1] = change_tone(bopomofos[-1], "˙")
345
+ return bopomofos
346
+
347
+
348
+ def bu_sandhi(word: str, bopomofos: List[str]) -> List[str]:
349
+ valid_char = set(word)
350
+ if len(valid_char) == 1 and "不" in valid_char:
351
+ pass
352
+ elif word in ["不字"]:
353
+ pass
354
+ elif len(word) == 3 and word[1] == "不" and bopomofos[1][:-1] == "ㄅㄨ":
355
+ bopomofos[1] = bopomofos[1][:-1] + "˙"
356
+ else:
357
+ for i, char in enumerate(word):
358
+ if (
359
+ i + 1 < len(bopomofos)
360
+ and char == "不"
361
+ and i + 1 < len(word)
362
+ and 0 < len(bopomofos[i + 1])
363
+ and bopomofos[i + 1][-1] == "ˋ"
364
+ ):
365
+ bopomofos[i] = bopomofos[i][:-1] + "ˊ"
366
+ return bopomofos
367
+
368
+
369
+ def yi_sandhi(word: str, bopomofos: List[str]) -> List[str]:
370
+ punc = ":,;。?!“”‘’':,;.?!()(){}【】[]-~`、 "
371
+ if word.find("一") != -1 and any(
372
+ [item.isnumeric() for item in word if item != "一"]
373
+ ):
374
+ for i in range(len(word)):
375
+ if (
376
+ i == 0
377
+ and word[0] == "一"
378
+ and len(word) > 1
379
+ and word[1]
380
+ not in [
381
+ "零",
382
+ "一",
383
+ "二",
384
+ "三",
385
+ "四",
386
+ "五",
387
+ "六",
388
+ "七",
389
+ "八",
390
+ "九",
391
+ "十",
392
+ ]
393
+ ):
394
+ if len(bopomofos[0]) > 0 and bopomofos[1][-1] in ["ˋ", "˙"]:
395
+ bopomofos[0] = change_tone(bopomofos[0], "ˊ")
396
+ else:
397
+ bopomofos[0] = change_tone(bopomofos[0], "ˋ")
398
+ elif word[i] == "一":
399
+ bopomofos[i] = change_tone(bopomofos[i], "")
400
+ return bopomofos
401
+ elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
402
+ bopomofos[1] = change_tone(bopomofos[1], "˙")
403
+ elif word.startswith("第一"):
404
+ bopomofos[1] = change_tone(bopomofos[1], "")
405
+ elif word.startswith("一月") or word.startswith("一日") or word.startswith("一号"):
406
+ bopomofos[0] = change_tone(bopomofos[0], "")
407
+ else:
408
+ for i, char in enumerate(word):
409
+ if char == "一" and i + 1 < len(word):
410
+ if (
411
+ len(bopomofos) > i + 1
412
+ and len(bopomofos[i + 1]) > 0
413
+ and bopomofos[i + 1][-1] in {"ˋ"}
414
+ ):
415
+ bopomofos[i] = change_tone(bopomofos[i], "ˊ")
416
+ else:
417
+ if word[i + 1] not in punc:
418
+ bopomofos[i] = change_tone(bopomofos[i], "ˋ")
419
+ else:
420
+ pass
421
+ return bopomofos
422
+
423
+
424
+ def merge_bu(seg: List) -> List:
425
+ new_seg = []
426
+ last_word = ""
427
+ for word in seg:
428
+ if word != "不":
429
+ if last_word == "不":
430
+ word = last_word + word
431
+ new_seg.append(word)
432
+ last_word = word
433
+ return new_seg
434
+
435
+
436
+ def merge_er(seg: List) -> List:
437
+ new_seg = []
438
+ for i, word in enumerate(seg):
439
+ if i - 1 >= 0 and word == "儿":
440
+ new_seg[-1] = new_seg[-1] + seg[i]
441
+ else:
442
+ new_seg.append(word)
443
+ return new_seg
444
+
445
+
446
+ def merge_yi(seg: List) -> List:
447
+ new_seg = []
448
+ # function 1
449
+ for i, word in enumerate(seg):
450
+ if (
451
+ i - 1 >= 0
452
+ and word == "一"
453
+ and i + 1 < len(seg)
454
+ and seg[i - 1] == seg[i + 1]
455
+ ):
456
+ if i - 1 < len(new_seg):
457
+ new_seg[i - 1] = new_seg[i - 1] + "一" + new_seg[i - 1]
458
+ else:
459
+ new_seg.append(word)
460
+ new_seg.append(seg[i + 1])
461
+ else:
462
+ if i - 2 >= 0 and seg[i - 1] == "一" and seg[i - 2] == word:
463
+ continue
464
+ else:
465
+ new_seg.append(word)
466
+ seg = new_seg
467
+ new_seg = []
468
+ isnumeric_flag = False
469
+ for i, word in enumerate(seg):
470
+ if all([item.isnumeric() for item in word]) and not isnumeric_flag:
471
+ isnumeric_flag = True
472
+ new_seg.append(word)
473
+ else:
474
+ new_seg.append(word)
475
+ seg = new_seg
476
+ new_seg = []
477
+ # function 2
478
+ for i, word in enumerate(seg):
479
+ if new_seg and new_seg[-1] == "一":
480
+ new_seg[-1] = new_seg[-1] + word
481
+ else:
482
+ new_seg.append(word)
483
+ return new_seg
484
+
485
+
486
+ # Word Segmentation, and convert Chinese pronunciation to pinyin (bopomofo)
487
+ def chinese_to_bopomofo(text_short, sentence):
488
+ # bopomofos = conv(text_short)
489
+ words = jieba.lcut(text_short, cut_all=False)
490
+ words = merge_yi(words)
491
+ words = merge_bu(words)
492
+ words = merge_er(words)
493
+ text = ""
494
+
495
+ char_index = 0
496
+ for word in words:
497
+ bopomofos = []
498
+ if word in word_pinyin_dict and word not in poly_dict:
499
+ pinyin = word_pinyin_dict[word]
500
+ for py in pinyin.split(" "):
501
+ if py[:-1] in pinyin_2_bopomofo_dict and py[-1] in tone_dict:
502
+ bopomofos.append(
503
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
504
+ )
505
+ if BLANK_LEVEL == 1:
506
+ bopomofos.append("_")
507
+ else:
508
+ bopomofos_lazy = lazy_pinyin(word, BOPOMOFO)
509
+ bopomofos += bopomofos_lazy
510
+ if BLANK_LEVEL == 1:
511
+ bopomofos.append("_")
512
+ else:
513
+ for i in range(len(word)):
514
+ c = word[i]
515
+ if c in poly_dict:
516
+ poly_pinyin = g2pw_poly_predict.predict_process(
517
+ [text_short, char_index + i]
518
+ )[0]
519
+ py = poly_pinyin[2:-1]
520
+ bopomofos.append(
521
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
522
+ )
523
+ if BLANK_LEVEL == 1:
524
+ bopomofos.append("_")
525
+ elif c in word_pinyin_dict:
526
+ py = word_pinyin_dict[c]
527
+ bopomofos.append(
528
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
529
+ )
530
+ if BLANK_LEVEL == 1:
531
+ bopomofos.append("_")
532
+ else:
533
+ bopomofos.append(c)
534
+ if BLANK_LEVEL == 1:
535
+ bopomofos.append("_")
536
+ if BLANK_LEVEL == 2:
537
+ bopomofos.append("_")
538
+ char_index += len(word)
539
+
540
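+ # Third-tone sandhi: when two or three consecutive syllables all carry the dipping tone "ˇ", the leading syllables are raised to the rising tone "ˊ".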
+ if (
541
+ len(word) == 3
542
+ and bopomofos[0][-1] == "ˇ"
543
+ and bopomofos[1][-1] == "ˇ"
544
+ and bopomofos[-1][-1] == "ˇ"
545
+ ):
546
+ bopomofos[0] = bopomofos[0] + "ˊ"
547
+ bopomofos[1] = bopomofos[1] + "ˊ"
548
+ if len(word) == 2 and bopomofos[0][-1] == "ˇ" and bopomofos[-1][-1] == "ˇ":
549
+ bopomofos[0] = bopomofos[0][:-1] + "ˊ"
550
+ bopomofos = bu_sandhi(word, bopomofos)
551
+ bopomofos = yi_sandhi(word, bopomofos)
552
+ bopomofos = er_sandhi(word, bopomofos)
553
+ if not re.search("[\u4e00-\u9fff]", word):
554
+ text += "|" + word
555
+ continue
556
+ for i in range(len(bopomofos)):
557
+ bopomofos[i] = re.sub(r"([\u3105-\u3129])$", r"\1ˉ", bopomofos[i])
558
+ if text != "":
559
+ text += "|"
560
+ text += "|".join(bopomofos)
561
+ return text
562
+
563
+
564
+ # Convert latin pronunciation to pinyin (bopomofo)
565
+ def latin_to_bopomofo(text):
566
+ for regex, replacement in _latin_to_bopomofo:
567
+ text = re.sub(regex, replacement, text)
568
+ return text
569
+
570
+
571
+ # Convert pinyin (bopomofo) to IPA
572
+ def bopomofo_to_ipa(text):
573
+ for regex, replacement in _bopomofo_to_ipa:
574
+ text = re.sub(regex, replacement, text)
575
+ return text
576
+
577
+
578
+ def _chinese_to_ipa(text, sentence):
579
+ text = number_to_chinese(text.strip())
580
+ text = normalization(text)
581
+ text = chinese_to_bopomofo(text, sentence)
582
+ # pinyin = bpmf_to_pinyin(text)
583
+ text = latin_to_bopomofo(text)
584
+ text = bopomofo_to_ipa(text)
585
+ text = re.sub("([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
586
+ text = re.sub("([s][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
587
+ text = re.sub(r"^\||[^\w\s_,\.\?!;:\'…\|→↓↑⁼ʰ`]", "", text)
588
+ text = re.sub(r"([,\.\?!;:\'…])", r"|\1|", text)
589
+ text = re.sub(r"\|+", "|", text)
590
+ text = text.rstrip("|")
591
+ return text
592
+
593
+
594
+ # Convert Chinese to IPA
595
+ def chinese_to_ipa(text, sentence, text_tokenizer):
596
+ # phonemes = text_tokenizer(text.strip())
597
+ if type(text) == str:
598
+ return _chinese_to_ipa(text, sentence)
599
+ else:
600
+ result_ph = []
601
+ for t in text:
602
+ result_ph.append(_chinese_to_ipa(t, sentence))
603
+ return result_ph
src/YingMusicSinger/utils/f5_tts/g2p/g2p/text_tokenizers.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ from typing import List, Union
8
+
9
+ from phonemizer.backend import EspeakBackend
10
+ from phonemizer.backend.espeak.language_switch import LanguageSwitch
11
+ from phonemizer.backend.espeak.words_mismatch import WordMismatch
12
+ from phonemizer.separator import Separator
13
+ from phonemizer.utils import list2str, str2list
14
+
15
+
16
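+ # Thin wrapper around phonemizer's espeak backend: phones are joined with "|" and word boundaries marked "|_|", the separators the rest of the g2p pipeline splits on.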
+ class TextTokenizer:
17
+ """Phonemize Text."""
18
+
19
+ def __init__(
20
+ self,
21
+ language="en-us",
22
+ backend="espeak",
23
+ separator=Separator(word="|_|", syllable="-", phone="|"),
24
+ preserve_punctuation=True,
25
+ with_stress: bool = False,
26
+ tie: Union[bool, str] = False,
27
+ language_switch: LanguageSwitch = "remove-flags",
28
+ words_mismatch: WordMismatch = "ignore",
29
+ ) -> None:
30
+ self.preserve_punctuation_marks = ",.?!;:'…"
31
+ self.backend = EspeakBackend(
32
+ language,
33
+ punctuation_marks=self.preserve_punctuation_marks,
34
+ preserve_punctuation=preserve_punctuation,
35
+ with_stress=with_stress,
36
+ tie=tie,
37
+ language_switch=language_switch,
38
+ words_mismatch=words_mismatch,
39
+ )
40
+
41
+ self.separator = separator
42
+
43
+ # convert chinese punctuation to english punctuation
44
+ def convert_chinese_punctuation(self, text: str) -> str:
45
+ text = text.replace(",", ",")
46
+ text = text.replace("。", ".")
47
+ text = text.replace("!", "!")
48
+ text = text.replace("?", "?")
49
+ text = text.replace(";", ";")
50
+ text = text.replace(":", ":")
51
+ text = text.replace("、", ",")
52
+ text = text.replace("‘", "'")
53
+ text = text.replace("’", "'")
54
+ text = text.replace("⋯", "…")
55
+ text = text.replace("···", "…")
56
+ text = text.replace("・・・", "…")
57
+ text = text.replace("...", "…")
58
+ return text
59
+
60
+ def __call__(self, text, strip=True) -> List[str]:
61
+ text_type = type(text)
62
+ normalized_text = []
63
+ for line in str2list(text):
64
+ line = self.convert_chinese_punctuation(line.strip())
65
+ line = re.sub(r"[^\w\s_,\.\?!;:\'…]", "", line)
66
+ line = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", line)
67
+ line = re.sub(r"\s+", " ", line)
68
+ normalized_text.append(line)
69
+ # print("Normalized test: ", normalized_text[0])
70
+ phonemized = self.backend.phonemize(
71
+ normalized_text, separator=self.separator, strip=strip, njobs=1
72
+ )
73
+ if text_type == str:
74
+ phonemized = re.sub(r"([,\.\?!;:\'…])", r"|\1|", list2str(phonemized))
75
+ phonemized = re.sub(r"\|+", "|", phonemized)
76
+ phonemized = phonemized.rstrip("|")
77
+ else:
78
+ for i in range(len(phonemized)):
79
+ phonemized[i] = re.sub(r"([,\.\?!;:\'…])", r"|\1|", phonemized[i])
80
+ phonemized[i] = re.sub(r"\|+", "|", phonemized[i])
81
+ phonemized[i] = phonemized[i].rstrip("|")
82
+ return phonemized
src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json ADDED
@@ -0,0 +1,372 @@
1
+ {
2
+ "vocab": {
3
+ ",": 0,
4
+ ".": 1,
5
+ "?": 2,
6
+ "!": 3,
7
+ "_": 4,
8
+ "iː": 5,
9
+ "ɪ": 6,
10
+ "ɜː": 7,
11
+ "ɚ": 8,
12
+ "oːɹ": 9,
13
+ "ɔː": 10,
14
+ "ɔːɹ": 11,
15
+ "ɑː": 12,
16
+ "uː": 13,
17
+ "ʊ": 14,
18
+ "ɑːɹ": 15,
19
+ "ʌ": 16,
20
+ "ɛ": 17,
21
+ "æ": 18,
22
+ "eɪ": 19,
23
+ "aɪ": 20,
24
+ "ɔɪ": 21,
25
+ "aʊ": 22,
26
+ "oʊ": 23,
27
+ "ɪɹ": 24,
28
+ "ɛɹ": 25,
29
+ "ʊɹ": 26,
30
+ "p": 27,
31
+ "b": 28,
32
+ "t": 29,
33
+ "d": 30,
34
+ "k": 31,
35
+ "ɡ": 32,
36
+ "f": 33,
37
+ "v": 34,
38
+ "θ": 35,
39
+ "ð": 36,
40
+ "s": 37,
41
+ "z": 38,
42
+ "ʃ": 39,
43
+ "ʒ": 40,
44
+ "h": 41,
45
+ "tʃ": 42,
46
+ "dʒ": 43,
47
+ "m": 44,
48
+ "n": 45,
49
+ "ŋ": 46,
50
+ "j": 47,
51
+ "w": 48,
52
+ "ɹ": 49,
53
+ "l": 50,
54
+ "tɹ": 51,
55
+ "dɹ": 52,
56
+ "ts": 53,
57
+ "dz": 54,
58
+ "i": 55,
59
+ "ɔ": 56,
60
+ "ə": 57,
61
+ "ɾ": 58,
62
+ "iə": 59,
63
+ "r": 60,
64
+ "u": 61,
65
+ "oː": 62,
66
+ "ɛː": 63,
67
+ "ɪː": 64,
68
+ "aɪə": 65,
69
+ "aɪɚ": 66,
70
+ "ɑ̃": 67,
71
+ "ç": 68,
72
+ "ɔ̃": 69,
73
+ "ææ": 70,
74
+ "ɐɐ": 71,
75
+ "ɡʲ": 72,
76
+ "nʲ": 73,
77
+ "iːː": 74,
78
+
79
+ "p⁼": 75,
80
+ "pʰ": 76,
81
+ "t⁼": 77,
82
+ "tʰ": 78,
83
+ "k⁼": 79,
84
+ "kʰ": 80,
85
+ "x": 81,
86
+ "tʃ⁼": 82,
87
+ "tʃʰ": 83,
88
+ "ts`⁼": 84,
89
+ "ts`ʰ": 85,
90
+ "s`": 86,
91
+ "ɹ`": 87,
92
+ "ts⁼": 88,
93
+ "tsʰ": 89,
94
+ "p⁼wo": 90,
95
+ "p⁼wo→": 91,
96
+ "p⁼wo↑": 92,
97
+ "p⁼wo↓↑": 93,
98
+ "p⁼wo↓": 94,
99
+ "pʰwo": 95,
100
+ "pʰwo→": 96,
101
+ "pʰwo↑": 97,
102
+ "pʰwo↓↑": 98,
103
+ "pʰwo↓": 99,
104
+ "mwo": 100,
105
+ "mwo→": 101,
106
+ "mwo↑": 102,
107
+ "mwo↓↑": 103,
108
+ "mwo↓": 104,
109
+ "fwo": 105,
110
+ "fwo→": 106,
111
+ "fwo↑": 107,
112
+ "fwo↓↑": 108,
113
+ "fwo↓": 109,
114
+ "jɛn": 110,
115
+ "jɛn→": 111,
116
+ "jɛn↑": 112,
117
+ "jɛn↓↑": 113,
118
+ "jɛn↓": 114,
119
+ "ɥæn": 115,
120
+ "ɥæn→": 116,
121
+ "ɥæn↑": 117,
122
+ "ɥæn↓↑": 118,
123
+ "ɥæn↓": 119,
124
+ "in": 120,
125
+ "in→": 121,
126
+ "in↑": 122,
127
+ "in↓↑": 123,
128
+ "in↓": 124,
129
+ "ɥn": 125,
130
+ "ɥn→": 126,
131
+ "ɥn↑": 127,
132
+ "ɥn↓↑": 128,
133
+ "ɥn↓": 129,
134
+ "iŋ": 130,
135
+ "iŋ→": 131,
136
+ "iŋ↑": 132,
137
+ "iŋ↓↑": 133,
138
+ "iŋ↓": 134,
139
+ "ʊŋ": 135,
140
+ "ʊŋ→": 136,
141
+ "ʊŋ↑": 137,
142
+ "ʊŋ↓↑": 138,
143
+ "ʊŋ↓": 139,
144
+ "jʊŋ": 140,
145
+ "jʊŋ→": 141,
146
+ "jʊŋ↑": 142,
147
+ "jʊŋ↓↑": 143,
148
+ "jʊŋ↓": 144,
149
+ "ia": 145,
150
+ "ia→": 146,
151
+ "ia↑": 147,
152
+ "ia↓↑": 148,
153
+ "ia↓": 149,
154
+ "iɛ": 150,
155
+ "iɛ→": 151,
156
+ "iɛ↑": 152,
157
+ "iɛ↓↑": 153,
158
+ "iɛ↓": 154,
159
+ "iɑʊ": 155,
160
+ "iɑʊ→": 156,
161
+ "iɑʊ↑": 157,
162
+ "iɑʊ↓↑": 158,
163
+ "iɑʊ↓": 159,
164
+ "ioʊ": 160,
165
+ "ioʊ→": 161,
166
+ "ioʊ↑": 162,
167
+ "ioʊ↓↑": 163,
168
+ "ioʊ↓": 164,
169
+ "iɑŋ": 165,
170
+ "iɑŋ→": 166,
171
+ "iɑŋ↑": 167,
172
+ "iɑŋ↓↑": 168,
173
+ "iɑŋ↓": 169,
174
+ "ua": 170,
175
+ "ua→": 171,
176
+ "ua↑": 172,
177
+ "ua↓↑": 173,
178
+ "ua↓": 174,
179
+ "uo": 175,
180
+ "uo→": 176,
181
+ "uo↑": 177,
182
+ "uo↓↑": 178,
183
+ "uo↓": 179,
184
+ "uaɪ": 180,
185
+ "uaɪ→": 181,
186
+ "uaɪ↑": 182,
187
+ "uaɪ↓↑": 183,
188
+ "uaɪ↓": 184,
189
+ "ueɪ": 185,
190
+ "ueɪ→": 186,
191
+ "ueɪ↑": 187,
192
+ "ueɪ↓↑": 188,
193
+ "ueɪ↓": 189,
194
+ "uan": 190,
195
+ "uan→": 191,
196
+ "uan↑": 192,
197
+ "uan↓↑": 193,
198
+ "uan↓": 194,
199
+ "uən": 195,
200
+ "uən→": 196,
201
+ "uən↑": 197,
202
+ "uən↓↑": 198,
203
+ "uən↓": 199,
204
+ "uɑŋ": 200,
205
+ "uɑŋ→": 201,
206
+ "uɑŋ↑": 202,
207
+ "uɑŋ↓↑": 203,
208
+ "uɑŋ↓": 204,
209
+ "ɥɛ": 205,
210
+ "ɥɛ→": 206,
211
+ "ɥɛ↑": 207,
212
+ "ɥɛ↓↑": 208,
213
+ "ɥɛ↓": 209,
214
+ "a": 210,
215
+ "a→": 211,
216
+ "a↑": 212,
217
+ "a↓↑": 213,
218
+ "a↓": 214,
219
+ "o": 215,
220
+ "o→": 216,
221
+ "o↑": 217,
222
+ "o↓↑": 218,
223
+ "o↓": 219,
224
+ "ə→": 220,
225
+ "ə↑": 221,
226
+ "ə↓↑": 222,
227
+ "ə↓": 223,
228
+ "ɛ→": 224,
229
+ "ɛ↑": 225,
230
+ "ɛ↓↑": 226,
231
+ "ɛ↓": 227,
232
+ "aɪ→": 228,
233
+ "aɪ↑": 229,
234
+ "aɪ↓↑": 230,
235
+ "aɪ↓": 231,
236
+ "eɪ→": 232,
237
+ "eɪ↑": 233,
238
+ "eɪ↓↑": 234,
239
+ "eɪ↓": 235,
240
+ "ɑʊ": 236,
241
+ "ɑʊ→": 237,
242
+ "ɑʊ↑": 238,
243
+ "ɑʊ↓↑": 239,
244
+ "ɑʊ↓": 240,
245
+ "oʊ→": 241,
246
+ "oʊ↑": 242,
247
+ "oʊ↓↑": 243,
248
+ "oʊ↓": 244,
249
+ "an": 245,
250
+ "an→": 246,
251
+ "an↑": 247,
252
+ "an↓↑": 248,
253
+ "an↓": 249,
254
+ "ən": 250,
255
+ "ən→": 251,
256
+ "ən↑": 252,
257
+ "ən↓↑": 253,
258
+ "ən↓": 254,
259
+ "ɑŋ": 255,
260
+ "ɑŋ→": 256,
261
+ "ɑŋ↑": 257,
262
+ "ɑŋ↓↑": 258,
263
+ "ɑŋ↓": 259,
264
+ "əŋ": 260,
265
+ "əŋ→": 261,
266
+ "əŋ↑": 262,
267
+ "əŋ↓↑": 263,
268
+ "əŋ↓": 264,
269
+ "əɹ": 265,
270
+ "əɹ→": 266,
271
+ "əɹ↑": 267,
272
+ "əɹ↓↑": 268,
273
+ "əɹ↓": 269,
274
+ "i→": 270,
275
+ "i↑": 271,
276
+ "i↓↑": 272,
277
+ "i↓": 273,
278
+ "u→": 274,
279
+ "u↑": 275,
280
+ "u↓↑": 276,
281
+ "u↓": 277,
282
+ "ɥ": 278,
283
+ "ɥ→": 279,
284
+ "ɥ↑": 280,
285
+ "ɥ↓↑": 281,
286
+ "ɥ↓": 282,
287
+ "ts`⁼ɹ": 283,
288
+ "ts`⁼ɹ→": 284,
289
+ "ts`⁼ɹ↑": 285,
290
+ "ts`⁼ɹ↓↑": 286,
291
+ "ts`⁼ɹ↓": 287,
292
+ "ts`ʰɹ": 288,
293
+ "ts`ʰɹ→": 289,
294
+ "ts`ʰɹ↑": 290,
295
+ "ts`ʰɹ↓↑": 291,
296
+ "ts`ʰɹ↓": 292,
297
+ "s`ɹ": 293,
298
+ "s`ɹ→": 294,
299
+ "s`ɹ↑": 295,
300
+ "s`ɹ↓↑": 296,
301
+ "s`ɹ↓": 297,
302
+ "ɹ`ɹ": 298,
303
+ "ɹ`ɹ→": 299,
304
+ "ɹ`ɹ↑": 300,
305
+ "ɹ`ɹ↓↑": 301,
306
+ "ɹ`ɹ↓": 302,
307
+ "ts⁼ɹ": 303,
308
+ "ts⁼ɹ→": 304,
309
+ "ts⁼ɹ↑": 305,
310
+ "ts⁼ɹ↓↑": 306,
311
+ "ts⁼ɹ↓": 307,
312
+ "tsʰɹ": 308,
313
+ "tsʰɹ→": 309,
314
+ "tsʰɹ↑": 310,
315
+ "tsʰɹ↓↑": 311,
316
+ "tsʰɹ↓": 312,
317
+ "sɹ": 313,
318
+ "sɹ→": 314,
319
+ "sɹ↑": 315,
320
+ "sɹ↓↑": 316,
321
+ "sɹ↓": 317,
322
+
323
+ "ɯ": 318,
324
+ "e": 319,
325
+ "aː": 320,
326
+ "ɯː": 321,
327
+ "eː": 322,
328
+ "ç": 323,
329
+ "ɸ": 324,
330
+ "ɰᵝ": 325,
331
+ "ɴ": 326,
332
+ "g": 327,
333
+ "dʑ": 328,
334
+ "q": 329,
335
+ "ː": 330,
336
+ "bj": 331,
337
+ "tɕ": 332,
338
+ "dej": 333,
339
+ "tej": 334,
340
+ "gj": 335,
341
+ "gɯ": 336,
342
+ "çj": 337,
343
+ "kj": 338,
344
+ "kɯ": 339,
345
+ "mj": 340,
346
+ "nj": 341,
347
+ "pj": 342,
348
+ "ɾj": 343,
349
+ "ɕ": 344,
350
+ "tsɯ": 345,
351
+
352
+ "ɐ": 346,
353
+ "ɑ": 347,
354
+ "ɒ": 348,
355
+ "ɜ": 349,
356
+ "ɫ": 350,
357
+ "ʑ": 351,
358
+ "ʲ": 352,
359
+
360
+ "y": 353,
361
+ "ø": 354,
362
+ "œ": 355,
363
+ "ʁ": 356,
364
+ "̃": 357,
365
+ "ɲ": 358,
366
+
367
+ ":": 359,
368
+ ";": 360,
369
+ "'": 361,
370
+ "…": 362
371
+ }
372
+ }
src/YingMusicSinger/utils/f5_tts/g2p/g2p_generation.py ADDED
@@ -0,0 +1,129 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import json
8
+ from typing import List
9
+
10
+ from src.YingMusicSinger.utils.f5_tts.g2p.g2p import PhonemeBpeTokenizer
11
+ from src.YingMusicSinger.utils.f5_tts.g2p.utils.g2p import phonemizer_g2p
12
+
13
+
14
+ def ph_g2p(text, language):
15
+ return phonemizer_g2p(text=text, language=language)
16
+
17
+
18
+ def g2p(text, sentence, language):
19
+ return text_tokenizer.tokenize(text=text, sentence=sentence, language=language)
20
+
21
+
22
+ def is_chinese(char):
23
+ if char >= "\u4e00" and char <= "\u9fa5":
24
+ return True
25
+ else:
26
+ return False
27
+
28
+
29
+ def is_alphabet(char):
30
+ if (char >= "\u0041" and char <= "\u005a") or (
31
+ char >= "\u0061" and char <= "\u007a"
32
+ ):
33
+ return True
34
+ else:
35
+ return False
36
+
37
+
38
+ def is_other(char):
39
+ if not (is_chinese(char) or is_alphabet(char)):
40
+ return True
41
+ else:
42
+ return False
43
+
44
+
45
+ def get_segment(text: str) -> List[str]:
46
+ # sentence --> [ch_part, en_part, ch_part, ...]
47
+ segments = []
48
+ types = []
49
+ flag = 0
50
+ temp_seg = ""
51
+ temp_lang = ""
52
+
53
+ # Determine the type of each character. type: blank, chinese, alphabet, number, unk and point.
54
+ for i, ch in enumerate(text):
55
+ if is_chinese(ch):
56
+ types.append("zh")
57
+ elif is_alphabet(ch):
58
+ types.append("en")
59
+ else:
60
+ types.append("other")
61
+
62
+ assert len(types) == len(text)
63
+
64
+ for i in range(len(types)):
65
+ # find the first char of the seg
66
+ if flag == 0:
67
+ temp_seg += text[i]
68
+ temp_lang = types[i]
69
+ flag = 1
70
+ else:
71
+ if temp_lang == "other":
72
+ if types[i] == temp_lang:
73
+ temp_seg += text[i]
74
+ else:
75
+ temp_seg += text[i]
76
+ temp_lang = types[i]
77
+ else:
78
+ if types[i] == temp_lang:
79
+ temp_seg += text[i]
80
+ elif types[i] == "other":
81
+ temp_seg += text[i]
82
+ else:
83
+ segments.append((temp_seg, temp_lang))
84
+ temp_seg = text[i]
85
+ temp_lang = types[i]
86
+ flag = 1
87
+
88
+ segments.append((temp_seg, temp_lang))
89
+ return segments
90
+
91
+
92
+ def chn_eng_g2p(text: str):
93
+ # now only en and ch
94
+ segments = get_segment(text)
95
+ all_phoneme = ""
96
+ all_tokens = []
97
+
98
+ for index in range(len(segments)):
99
+ seg = segments[index]
100
+ phoneme, token = g2p(seg[0], text, seg[1])
101
+ all_phoneme += phoneme + "|"
102
+ all_tokens += token
103
+
104
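+ # If the final segment is English, drop the trailing word-boundary "_" (and its token) that espeak leaves at the end of the phoneme string.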
+ if seg[1] == "en" and index == len(segments) - 1 and all_phoneme[-2] == "_":
105
+ all_phoneme = all_phoneme[:-2]
106
+ all_tokens = all_tokens[:-1]
107
+ return all_phoneme, all_tokens
108
+
109
+
110
+ text_tokenizer = PhonemeBpeTokenizer()
111
+ with open("./src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json", "r") as f:
112
+ json_data = f.read()
113
+ data = json.loads(json_data)
114
+ vocab = data["vocab"]
115
+
116
+ if __name__ == "__main__":
117
+ phone, token = chn_eng_g2p("你好,hello world")
118
+ phone, token = chn_eng_g2p(
119
+ "你好,hello world, Bonjour, 테스트 해 보겠습니다, 五月雨緑"
120
+ )
121
+ print(phone)
122
+ print(token)
123
+
124
+ # phone, token = text_tokenizer.tokenize("你好,hello world, Bonjour, 테스트 해 보겠습니다, 五月雨緑", "", "auto")
125
+ phone, token = text_tokenizer.tokenize("緑", "", "auto")
126
+ # phone, token = text_tokenizer.tokenize("आइए इसका परीक्षण करें", "", "auto")
127
+ # phone, token = text_tokenizer.tokenize("आइए इसका परीक्षण करें", "", "other")
128
+ print(phone)
129
+ print(token)
src/YingMusicSinger/utils/f5_tts/g2p/infer_dpo.py ADDED
@@ -0,0 +1,277 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import random
5
+
6
+ import numpy as np
7
+ import torch
8
+ from f5_tts.infer.utils_infer import load_checkpoint
9
+ from f5_tts.model import CFM, DiT
10
+ from f5_tts.model.alsp_lance.data.npydata import FloatData
11
+ from f5_tts.model.alsp_lance.tools import LanceReader, LanceWriter
12
+ from tqdm import tqdm
13
+
14
+ filter_keyword_list = [
15
+ "纯音乐",
16
+ "编曲",
17
+ "作词",
18
+ "作曲",
19
+ "调音",
20
+ "制作人",
21
+ "录音师",
22
+ ]
23
+
24
+ filter_full_list = ["music", "end"]
25
+
26
+
27
+ def check_lyric(time: float, lyric: str):
28
+ if time < 0.1:
29
+ return False
30
+ for filter_keyword in filter_keyword_list:
31
+ if filter_keyword in lyric:
32
+ return False
33
+ for filter_full in filter_full_list:
34
+ if filter_full == lyric.strip().lower():
35
+ return False
36
+ if len(lyric) == 0:
37
+ return False
38
+ return True
39
+
40
+
41
+ def parse_lyrics(lyrics: str):
42
+ lyrics_with_time = []
43
+ lyrics = lyrics.strip()
44
+ for line in lyrics.split("\n"):
45
+ try:
46
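+ # Expect standard LRC lines of the form "[mm:ss.xx]lyric": characters 1-8 are the timestamp, everything from position 10 on is the lyric text.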
+ time, lyric = line[1:9], line[10:]
47
+ lyric = lyric.strip()
48
+ mins, secs = time.split(":")
49
+ secs = int(mins) * 60 + float(secs)
50
+ # print(lyric, check_lyric(secs, lyric))
51
+ if not check_lyric(secs, lyric):
52
+ continue
53
+ lyrics_with_time.append((secs, lyric))
54
+ except:
55
+ # traceback.print_exc()
56
+ continue
57
+ # print("error", line)
58
+ return lyrics_with_time
59
+
60
+
61
+ class CNENTokenizer:
62
+ def __init__(self):
63
+ with open("./src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json", "r") as file:
64
+ self.phone2id: dict = json.load(file)["vocab"]
65
+ self.id2phone = {v: k for (k, v) in self.phone2id.items()}
66
+ from f5_tts.g2p.g2p_generation import chn_eng_g2p
67
+
68
+ self.tokenizer = chn_eng_g2p
69
+
70
+ def encode(self, text):
71
+ phone, token = self.tokenizer(text)
72
+ token = [x + 1 for x in token]
73
+ return token
74
+
75
+ def decode(self, token):
76
+ return "|".join([self.id2phone[x - 1] for x in token])
77
+
78
+
79
+ def inference(
+ model,
+ cond,
+ text,
+ duration,
+ style_prompt,
+ output_dir,
+ start_time,
+ latent_pred_start_frame,
+ latent_pred_end_frame,
+ cfg_strength,
+ # the metadata arguments below are unused here and optional, so the call in __main__ (which omits them) still runs
+ style=None,
+ song_name=None,
+ ckpt_step=None,
+ epoch=None,
+ ):
95
+ # import pdb; pdb.set_trace()
96
+ with torch.inference_mode():
97
+ generated, _ = model.sample(
98
+ cond=cond,
99
+ text=text,
100
+ duration=duration,
101
+ style_prompt=style_prompt,
102
+ steps=32,
103
+ cfg_strength=cfg_strength,
104
+ sway_sampling_coef=None,
105
+ start_time=start_time,
106
+ latent_pred_start_frame=latent_pred_start_frame,
107
+ latent_pred_end_frame=latent_pred_end_frame,
108
+ )
109
+
110
+ generated = generated.to(torch.float32) # [b t d]
111
+ latent = generated.transpose(1, 2) # [b d t]
112
+ latent = latent.detach().cpu().numpy()
113
+
114
+ return latent
115
+
116
+
117
+ def get_style_prompt(device, song_name, song_name2ref_npy):
118
+ mulan_style_path = song_name2ref_npy[song_name]
119
+ mulan_style = np.load(mulan_style_path)
120
+
121
+ style_prompt = torch.from_numpy(mulan_style).to(device) # [1, 512]
122
+ style_prompt = style_prompt.half()
123
+
124
+ return style_prompt
125
+
126
+
127
+ def get_lrc_prompt(text, tokenizer, dit_model, max_secs):
128
+ max_frames = 2048
129
+ lyrics_shift = 2
130
+ sampling_rate = 44100
131
+ downsample_rate = 2048
132
+
133
+ pad_token_id = 0
134
+ comma_token_id = 1
135
+ period_token_id = 2
136
+
137
+ fsmin = -10
138
+ fsmax = 10
139
+
140
+ lrc_with_time = parse_lyrics(text)
141
+
142
+ modified_lrc_with_time = []
143
+ for i in range(len(lrc_with_time)):
144
+ time, line = lrc_with_time[i]
145
+ # line_token = self.tokenizer.encode(line)
146
+ line_token = tokenizer.encode(line)
147
+ modified_lrc_with_time.append((time, line_token))
148
+
149
+ lrc_with_time = modified_lrc_with_time
150
+
151
+ lrc_with_time = [
152
+ (time_start, line)
153
+ for (time_start, line) in lrc_with_time
154
+ if time_start < max_secs
155
+ ]
156
+ # latent_end_time = lrc_with_time[-1][0] if len(lrc_with_time) >= 1 else -1
157
+ lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time
158
+
159
+ normalized_start_time = 0.0
160
+
161
+ lrc = torch.zeros((max_frames,), dtype=torch.long)
162
+
163
+ tokens_count = 0
164
+ last_end_pos = 0
165
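+ # Write each lyric line's tokens at the latent frame matching its timestamp (44100 Hz audio, 2048-sample hop), jittered by a small random shift and clamped so lines never overlap the previous one.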
+ for time_start, line in lrc_with_time:
166
+ tokens = [
167
+ token if token != period_token_id else comma_token_id for token in line
168
+ ] + [period_token_id]
169
+ tokens = torch.tensor(tokens, dtype=torch.long)
170
+ num_tokens = tokens.shape[0]
171
+
172
+ gt_frame_start = int(time_start * sampling_rate / downsample_rate)
173
+
174
+ frame_shift = random.randint(int(fsmin), int(fsmax))
175
+
176
+ frame_start = max(gt_frame_start - frame_shift, last_end_pos)
177
+ frame_len = min(num_tokens, max_frames - frame_start)
178
+
179
+ # print(gt_frame_start, frame_shift, frame_start, frame_len, tokens_count, last_end_pos, full_pos_emb.shape)
180
+
181
+ lrc[frame_start : frame_start + frame_len] = tokens[:frame_len]
182
+
183
+ tokens_count += num_tokens
184
+ last_end_pos = frame_start + frame_len
185
+
186
+ lrc_emb = lrc.unsqueeze(0).to(dit_model.device)
187
+
188
+ normalized_start_time = (
189
+ torch.tensor(normalized_start_time).unsqueeze(0).to(dit_model.device)
190
+ )
191
+
192
+ return lrc_emb, normalized_start_time
193
+
194
+
195
+ if __name__ == "__main__":
196
+ parser = argparse.ArgumentParser()
197
+
198
+ parser.add_argument("--model-config", type=str, default=None)
199
+ parser.add_argument("--ckpt-path", type=str, default=None)
200
+ parser.add_argument("--output-dir", type=str, default=None) # lance
201
+ parser.add_argument("--lrc-path", type=str, default=None)
202
+ parser.add_argument("--mulan-style-path", type=str, default=None) # lance
203
+ parser.add_argument("--cfg-strength", type=float, default=None)
204
+
205
+ args = parser.parse_args()
206
+
207
+ lrc_path = args.lrc_path
208
+ cfg_strength = args.cfg_strength
209
+ style_path = args.mulan_style_path
210
+
211
+ with open(args.model_config) as f:
212
+ model_config = json.load(f)
213
+
214
+ model_cls = DiT
215
+ ckpt_path = args.ckpt_path
216
+ device = "cuda"
217
+ use_style_prompt = True
218
+ dit_model = CFM(
219
+ transformer=model_cls(
220
+ **model_config["model"], use_style_prompt=use_style_prompt
221
+ ),
222
+ num_channels=model_config["model"]["mel_dim"],
223
+ use_style_prompt=use_style_prompt,
224
+ )
225
+ dit_model = dit_model.to(device)
226
+ dit_model = load_checkpoint(dit_model, ckpt_path, device=device, use_ema=True)
227
+
228
+ lrc_tokenizer = CNENTokenizer()
229
+
230
+ sampling_rate = 44100
231
+ downsample_rate = 2048
232
+ max_frames = 2048
233
+ max_secs = max_frames / (sampling_rate / downsample_rate)
234
+
235
+ output_dir = args.output_dir
236
+ writer = LanceWriter(output_dir, target_cls=FloatData)
237
+
238
+ reader = LanceReader(style_path, target_cls=FloatData)
239
+
240
+ WRITE_INTERVAL = 500
241
+
242
+ latent_data = []
243
+ for id in tqdm(reader.get_ids()):
244
+ item = reader.get_datas_by_rowids(row_ids=[id._rowid])[0]
245
+ data_id = item.data_id
246
+ style_prompt = torch.from_numpy(item.data).to(device)
247
+ style_prompt = style_prompt.half()
248
+
249
+ item_lrc_path = os.path.join(args.lrc_path, f"{data_id}.lrc")
250
+ with open(item_lrc_path, "r") as f:
251
+ lrc = f.read()
252
+ lrc_prompt, start_time = get_lrc_prompt(lrc, lrc_tokenizer, dit_model, max_secs)
253
+
254
+ latent_prompt = torch.zeros(1, max_frames, 64).to(device)
255
+ sf = 0
256
+ ef = max_frames
257
+
258
+ generated_latent = inference(
259
+ model=dit_model,
260
+ cond=latent_prompt,
261
+ text=lrc_prompt,
262
+ duration=max_frames,
263
+ style_prompt=style_prompt,
264
+ output_dir=output_dir,
265
+ start_time=start_time,
266
+ latent_pred_start_frame=sf,
267
+ latent_pred_end_frame=ef,
268
+ cfg_strength=cfg_strength,
269
+ ) # [b d t] numpy
270
+
271
+ latent_data.append(generated_latent)
272
+
273
+ if len(latent_data) > WRITE_INTERVAL:
274
+ writer.write_parallel(latent_data)
275
+ latent_data = []
276
+
277
+ writer.write_parallel(latent_data)
src/YingMusicSinger/utils/f5_tts/g2p/sources/bpmf_2_pinyin.txt ADDED
@@ -0,0 +1,41 @@
1
+ b ㄅ
2
+ p ㄆ
3
+ m ㄇ
4
+ f ㄈ
5
+ d ㄉ
6
+ t ㄊ
7
+ n ㄋ
8
+ l ㄌ
9
+ g ㄍ
10
+ k ㄎ
11
+ h ㄏ
12
+ j ㄐ
13
+ q ㄑ
14
+ x ㄒ
15
+ zh ㄓ
16
+ ch ㄔ
17
+ sh ㄕ
18
+ r ㄖ
19
+ z ㄗ
20
+ c ㄘ
21
+ s ㄙ
22
+ i ㄧ
23
+ u ㄨ
24
+ v ㄩ
25
+ a ㄚ
26
+ o ㄛ
27
+ e ㄜ
28
+ e ㄝ
29
+ ai ㄞ
30
+ ei ㄟ
31
+ ao ㄠ
32
+ ou ㄡ
33
+ an ㄢ
34
+ en ㄣ
35
+ ang ㄤ
36
+ eng ㄥ
37
+ er ㄦ
38
+ 2 ˊ
39
+ 3 ˇ
40
+ 4 ˋ
41
+ 0 ˙
src/YingMusicSinger/utils/f5_tts/g2p/sources/chinese_lexicon.txt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3a7685d1c3e68eb2fa304bfc63e90c90c3c1a1948839a5b1b507b2131b3e2fb
3
+ size 14779443
src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/config.json ADDED
@@ -0,0 +1,819 @@
1
+ {
2
+ "_name_or_path": "/BERT-POLY-v2/pretrained_models/mini_bert",
3
+ "architectures": [
4
+ "BertPoly"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "directionality": "bidi",
9
+ "gradient_checkpointing": false,
10
+ "hidden_act": "gelu",
11
+ "hidden_dropout_prob": 0.1,
12
+ "hidden_size": 384,
13
+ "id2label": {
14
+ "0": "LABEL_0",
15
+ "1": "LABEL_1",
16
+ "2": "LABEL_2",
17
+ "3": "LABEL_3",
18
+ "4": "LABEL_4",
19
+ "5": "LABEL_5",
20
+ "6": "LABEL_6",
21
+ "7": "LABEL_7",
22
+ "8": "LABEL_8",
23
+ "9": "LABEL_9",
24
+ "10": "LABEL_10",
25
+ "11": "LABEL_11",
26
+ "12": "LABEL_12",
27
+ "13": "LABEL_13",
28
+ "14": "LABEL_14",
29
+ "15": "LABEL_15",
30
+ "16": "LABEL_16",
31
+ "17": "LABEL_17",
32
+ "18": "LABEL_18",
33
+ "19": "LABEL_19",
34
+ "20": "LABEL_20",
35
+ "21": "LABEL_21",
36
+ "22": "LABEL_22",
37
+ "23": "LABEL_23",
38
+ "24": "LABEL_24",
39
+ "25": "LABEL_25",
40
+ "26": "LABEL_26",
41
+ "27": "LABEL_27",
42
+ "28": "LABEL_28",
43
+ "29": "LABEL_29",
44
+ "30": "LABEL_30",
45
+ "31": "LABEL_31",
46
+ "32": "LABEL_32",
47
+ "33": "LABEL_33",
48
+ "34": "LABEL_34",
49
+ "35": "LABEL_35",
50
+ "36": "LABEL_36",
51
+ "37": "LABEL_37",
52
+ "38": "LABEL_38",
53
+ "39": "LABEL_39",
54
+ "40": "LABEL_40",
55
+ "41": "LABEL_41",
56
+ "42": "LABEL_42",
57
+ "43": "LABEL_43",
58
+ "44": "LABEL_44",
59
+ "45": "LABEL_45",
60
+ "46": "LABEL_46",
61
+ "47": "LABEL_47",
62
+ "48": "LABEL_48",
63
+ "49": "LABEL_49",
64
+ "50": "LABEL_50",
65
+ "51": "LABEL_51",
66
+ "52": "LABEL_52",
67
+ "53": "LABEL_53",
68
+ "54": "LABEL_54",
69
+ "55": "LABEL_55",
70
+ "56": "LABEL_56",
71
+ "57": "LABEL_57",
72
+ "58": "LABEL_58",
73
+ "59": "LABEL_59",
74
+ "60": "LABEL_60",
75
+ "61": "LABEL_61",
76
+ "62": "LABEL_62",
77
+ "63": "LABEL_63",
78
+ "64": "LABEL_64",
79
+ "65": "LABEL_65",
80
+ "66": "LABEL_66",
81
+ "67": "LABEL_67",
82
+ "68": "LABEL_68",
83
+ "69": "LABEL_69",
84
+ "70": "LABEL_70",
85
+ "71": "LABEL_71",
86
+ "72": "LABEL_72",
87
+ "73": "LABEL_73",
88
+ "74": "LABEL_74",
89
+ "75": "LABEL_75",
90
+ "76": "LABEL_76",
91
+ "77": "LABEL_77",
92
+ "78": "LABEL_78",
93
+ "79": "LABEL_79",
94
+ "80": "LABEL_80",
95
+ "81": "LABEL_81",
96
+ "82": "LABEL_82",
97
+ "83": "LABEL_83",
98
+ "84": "LABEL_84",
99
+ "85": "LABEL_85",
100
+ "86": "LABEL_86",
101
+ "87": "LABEL_87",
102
+ "88": "LABEL_88",
103
+ "89": "LABEL_89",
104
+ "90": "LABEL_90",
105
+ "91": "LABEL_91",
106
+ "92": "LABEL_92",
107
+ "93": "LABEL_93",
108
+ "94": "LABEL_94",
109
+ "95": "LABEL_95",
110
+ "96": "LABEL_96",
111
+ "97": "LABEL_97",
112
+ "98": "LABEL_98",
113
+ "99": "LABEL_99",
114
+ "100": "LABEL_100",
115
+ "101": "LABEL_101",
116
+ "102": "LABEL_102",
117
+ "103": "LABEL_103",
118
+ "104": "LABEL_104",
119
+ "105": "LABEL_105",
120
+ "106": "LABEL_106",
121
+ "107": "LABEL_107",
122
+ "108": "LABEL_108",
123
+ "109": "LABEL_109",
124
+ "110": "LABEL_110",
125
+ "111": "LABEL_111",
126
+ "112": "LABEL_112",
127
+ "113": "LABEL_113",
128
+ "114": "LABEL_114",
129
+ "115": "LABEL_115",
130
+ "116": "LABEL_116",
131
+ "117": "LABEL_117",
132
+ "118": "LABEL_118",
133
+ "119": "LABEL_119",
134
+ "120": "LABEL_120",
135
+ "121": "LABEL_121",
136
+ "122": "LABEL_122",
137
+ "123": "LABEL_123",
138
+ "124": "LABEL_124",
139
+ "125": "LABEL_125",
140
+ "126": "LABEL_126",
141
+ "127": "LABEL_127",
142
+ "128": "LABEL_128",
143
+ "129": "LABEL_129",
144
+ "130": "LABEL_130",
145
+ "131": "LABEL_131",
146
+ "132": "LABEL_132",
147
+ "133": "LABEL_133",
148
+ "134": "LABEL_134",
149
+ "135": "LABEL_135",
150
+ "136": "LABEL_136",
151
+ "137": "LABEL_137",
152
+ "138": "LABEL_138",
153
+ "139": "LABEL_139",
154
+ "140": "LABEL_140",
155
+ "141": "LABEL_141",
156
+ "142": "LABEL_142",
157
+ "143": "LABEL_143",
158
+ "144": "LABEL_144",
159
+ "145": "LABEL_145",
160
+ "146": "LABEL_146",
161
+ "147": "LABEL_147",
162
+ "148": "LABEL_148",
163
+ "149": "LABEL_149",
164
+ "150": "LABEL_150",
165
+ "151": "LABEL_151",
166
+ "152": "LABEL_152",
167
+ "153": "LABEL_153",
168
+ "154": "LABEL_154",
169
+ "155": "LABEL_155",
170
+ "156": "LABEL_156",
171
+ "157": "LABEL_157",
172
+ "158": "LABEL_158",
173
+ "159": "LABEL_159",
174
+ "160": "LABEL_160",
175
+ "161": "LABEL_161",
176
+ "162": "LABEL_162",
177
+ "163": "LABEL_163",
178
+ "164": "LABEL_164",
179
+ "165": "LABEL_165",
180
+ "166": "LABEL_166",
181
+ "167": "LABEL_167",
182
+ "168": "LABEL_168",
183
+ "169": "LABEL_169",
184
+ "170": "LABEL_170",
185
+ "171": "LABEL_171",
186
+ "172": "LABEL_172",
187
+ "173": "LABEL_173",
188
+ "174": "LABEL_174",
189
+ "175": "LABEL_175",
190
+ "176": "LABEL_176",
191
+ "177": "LABEL_177",
192
+ "178": "LABEL_178",
193
+ "179": "LABEL_179",
194
+ "180": "LABEL_180",
195
+ "181": "LABEL_181",
196
+ "182": "LABEL_182",
197
+ "183": "LABEL_183",
198
+ "184": "LABEL_184",
199
+ "185": "LABEL_185",
200
+ "186": "LABEL_186",
201
+ "187": "LABEL_187",
202
+ "188": "LABEL_188",
203
+ "189": "LABEL_189",
204
+ "190": "LABEL_190",
205
+ "191": "LABEL_191",
206
+ "192": "LABEL_192",
207
+ "193": "LABEL_193",
208
+ "194": "LABEL_194",
209
+ "195": "LABEL_195",
210
+ "196": "LABEL_196",
211
+ "197": "LABEL_197",
212
+ "198": "LABEL_198",
213
+ "199": "LABEL_199",
214
+ "200": "LABEL_200",
215
+ "201": "LABEL_201",
216
+ "202": "LABEL_202",
217
+ "203": "LABEL_203",
218
+ "204": "LABEL_204",
219
+ "205": "LABEL_205",
220
+ "206": "LABEL_206",
221
+ "207": "LABEL_207",
222
+ "208": "LABEL_208",
223
+ "209": "LABEL_209",
224
+ "210": "LABEL_210",
225
+ "211": "LABEL_211",
226
+ "212": "LABEL_212",
227
+ "213": "LABEL_213",
228
+ "214": "LABEL_214",
229
+ "215": "LABEL_215",
230
+ "216": "LABEL_216",
231
+ "217": "LABEL_217",
232
+ "218": "LABEL_218",
233
+ "219": "LABEL_219",
234
+ "220": "LABEL_220",
235
+ "221": "LABEL_221",
236
+ "222": "LABEL_222",
237
+ "223": "LABEL_223",
238
+ "224": "LABEL_224",
239
+ "225": "LABEL_225",
240
+ "226": "LABEL_226",
241
+ "227": "LABEL_227",
242
+ "228": "LABEL_228",
243
+ "229": "LABEL_229",
244
+ "230": "LABEL_230",
245
+ "231": "LABEL_231",
246
+ "232": "LABEL_232",
247
+ "233": "LABEL_233",
248
+ "234": "LABEL_234",
249
+ "235": "LABEL_235",
250
+ "236": "LABEL_236",
251
+ "237": "LABEL_237",
252
+ "238": "LABEL_238",
253
+ "239": "LABEL_239",
254
+ "240": "LABEL_240",
255
+ "241": "LABEL_241",
256
+ "242": "LABEL_242",
257
+ "243": "LABEL_243",
258
+ "244": "LABEL_244",
259
+ "245": "LABEL_245",
260
+ "246": "LABEL_246",
261
+ "247": "LABEL_247",
262
+ "248": "LABEL_248",
263
+ "249": "LABEL_249",
264
+ "250": "LABEL_250",
265
+ "251": "LABEL_251",
266
+ "252": "LABEL_252",
267
+ "253": "LABEL_253",
268
+ "254": "LABEL_254",
269
+ "255": "LABEL_255",
270
+ "256": "LABEL_256",
271
+ "257": "LABEL_257",
272
+ "258": "LABEL_258",
273
+ "259": "LABEL_259",
274
+ "260": "LABEL_260",
275
+ "261": "LABEL_261",
276
+ "262": "LABEL_262",
277
+ "263": "LABEL_263",
278
+ "264": "LABEL_264",
279
+ "265": "LABEL_265",
280
+ "266": "LABEL_266",
281
+ "267": "LABEL_267",
282
+ "268": "LABEL_268",
283
+ "269": "LABEL_269",
284
+ "270": "LABEL_270",
285
+ "271": "LABEL_271",
286
+ "272": "LABEL_272",
287
+ "273": "LABEL_273",
288
+ "274": "LABEL_274",
289
+ "275": "LABEL_275",
290
+ "276": "LABEL_276",
291
+ "277": "LABEL_277",
292
+ "278": "LABEL_278",
293
+ "279": "LABEL_279",
294
+ "280": "LABEL_280",
295
+ "281": "LABEL_281",
296
+ "282": "LABEL_282",
297
+ "283": "LABEL_283",
298
+ "284": "LABEL_284",
299
+ "285": "LABEL_285",
300
+ "286": "LABEL_286",
301
+ "287": "LABEL_287",
302
+ "288": "LABEL_288",
303
+ "289": "LABEL_289",
304
+ "290": "LABEL_290",
305
+ "291": "LABEL_291",
306
+ "292": "LABEL_292",
307
+ "293": "LABEL_293",
308
+ "294": "LABEL_294",
309
+ "295": "LABEL_295",
310
+ "296": "LABEL_296",
311
+ "297": "LABEL_297",
312
+ "298": "LABEL_298",
313
+ "299": "LABEL_299",
314
+ "300": "LABEL_300",
315
+ "301": "LABEL_301",
316
+ "302": "LABEL_302",
317
+ "303": "LABEL_303",
318
+ "304": "LABEL_304",
319
+ "305": "LABEL_305",
320
+ "306": "LABEL_306",
321
+ "307": "LABEL_307",
322
+ "308": "LABEL_308",
323
+ "309": "LABEL_309",
324
+ "310": "LABEL_310",
325
+ "311": "LABEL_311",
326
+ "312": "LABEL_312",
327
+ "313": "LABEL_313",
328
+ "314": "LABEL_314",
329
+ "315": "LABEL_315",
330
+ "316": "LABEL_316",
331
+ "317": "LABEL_317",
332
+ "318": "LABEL_318",
333
+ "319": "LABEL_319",
334
+ "320": "LABEL_320",
335
+ "321": "LABEL_321",
336
+ "322": "LABEL_322",
337
+ "323": "LABEL_323",
338
+ "324": "LABEL_324",
339
+ "325": "LABEL_325",
340
+ "326": "LABEL_326",
341
+ "327": "LABEL_327",
342
+ "328": "LABEL_328",
343
+ "329": "LABEL_329",
344
+ "330": "LABEL_330",
345
+ "331": "LABEL_331",
346
+ "332": "LABEL_332",
347
+ "333": "LABEL_333",
348
+ "334": "LABEL_334",
349
+ "335": "LABEL_335",
350
+ "336": "LABEL_336",
351
+ "337": "LABEL_337",
352
+ "338": "LABEL_338",
353
+ "339": "LABEL_339",
354
+ "340": "LABEL_340",
355
+ "341": "LABEL_341",
356
+ "342": "LABEL_342",
357
+ "343": "LABEL_343",
358
+ "344": "LABEL_344",
359
+ "345": "LABEL_345",
360
+ "346": "LABEL_346",
361
+ "347": "LABEL_347",
362
+ "348": "LABEL_348",
363
+ "349": "LABEL_349",
364
+ "350": "LABEL_350",
365
+ "351": "LABEL_351",
366
+ "352": "LABEL_352",
367
+ "353": "LABEL_353",
368
+ "354": "LABEL_354",
369
+ "355": "LABEL_355",
370
+ "356": "LABEL_356",
371
+ "357": "LABEL_357",
372
+ "358": "LABEL_358",
373
+ "359": "LABEL_359",
374
+ "360": "LABEL_360",
375
+ "361": "LABEL_361",
376
+ "362": "LABEL_362",
377
+ "363": "LABEL_363",
378
+ "364": "LABEL_364",
379
+ "365": "LABEL_365",
380
+ "366": "LABEL_366",
381
+ "367": "LABEL_367",
382
+ "368": "LABEL_368",
383
+ "369": "LABEL_369",
384
+ "370": "LABEL_370",
385
+ "371": "LABEL_371",
386
+ "372": "LABEL_372",
387
+ "373": "LABEL_373",
388
+ "374": "LABEL_374",
389
+ "375": "LABEL_375",
390
+ "376": "LABEL_376",
391
+ "377": "LABEL_377",
392
+ "378": "LABEL_378",
393
+ "379": "LABEL_379",
394
+ "380": "LABEL_380",
395
+ "381": "LABEL_381",
396
+ "382": "LABEL_382",
397
+ "383": "LABEL_383",
398
+ "384": "LABEL_384",
399
+ "385": "LABEL_385",
400
+ "386": "LABEL_386",
401
+ "387": "LABEL_387",
402
+ "388": "LABEL_388",
403
+ "389": "LABEL_389",
404
+ "390": "LABEL_390"
405
+ },
406
+ "initializer_range": 0.02,
407
+ "intermediate_size": 1536,
408
+ "label2id": {
409
+ "LABEL_0": 0,
410
+ "LABEL_1": 1,
411
+ "LABEL_10": 10,
412
+ "LABEL_100": 100,
413
+ "LABEL_101": 101,
414
+ "LABEL_102": 102,
415
+ "LABEL_103": 103,
416
+ "LABEL_104": 104,
417
+ "LABEL_105": 105,
418
+ "LABEL_106": 106,
419
+ "LABEL_107": 107,
420
+ "LABEL_108": 108,
421
+ "LABEL_109": 109,
422
+ "LABEL_11": 11,
423
+ "LABEL_110": 110,
424
+ "LABEL_111": 111,
425
+ "LABEL_112": 112,
426
+ "LABEL_113": 113,
427
+ "LABEL_114": 114,
428
+ "LABEL_115": 115,
429
+ "LABEL_116": 116,
430
+ "LABEL_117": 117,
431
+ "LABEL_118": 118,
432
+ "LABEL_119": 119,
433
+ "LABEL_12": 12,
434
+ "LABEL_120": 120,
435
+ "LABEL_121": 121,
436
+ "LABEL_122": 122,
437
+ "LABEL_123": 123,
438
+ "LABEL_124": 124,
439
+ "LABEL_125": 125,
440
+ "LABEL_126": 126,
441
+ "LABEL_127": 127,
442
+ "LABEL_128": 128,
443
+ "LABEL_129": 129,
444
+ "LABEL_13": 13,
445
+ "LABEL_130": 130,
446
+ "LABEL_131": 131,
447
+ "LABEL_132": 132,
448
+ "LABEL_133": 133,
449
+ "LABEL_134": 134,
450
+ "LABEL_135": 135,
451
+ "LABEL_136": 136,
452
+ "LABEL_137": 137,
453
+ "LABEL_138": 138,
454
+ "LABEL_139": 139,
455
+ "LABEL_14": 14,
456
+ "LABEL_140": 140,
457
+ "LABEL_141": 141,
458
+ "LABEL_142": 142,
459
+ "LABEL_143": 143,
460
+ "LABEL_144": 144,
461
+ "LABEL_145": 145,
462
+ "LABEL_146": 146,
463
+ "LABEL_147": 147,
464
+ "LABEL_148": 148,
465
+ "LABEL_149": 149,
466
+ "LABEL_15": 15,
467
+ "LABEL_150": 150,
468
+ "LABEL_151": 151,
469
+ "LABEL_152": 152,
470
+ "LABEL_153": 153,
471
+ "LABEL_154": 154,
472
+ "LABEL_155": 155,
473
+ "LABEL_156": 156,
474
+ "LABEL_157": 157,
475
+ "LABEL_158": 158,
476
+ "LABEL_159": 159,
477
+ "LABEL_16": 16,
478
+ "LABEL_160": 160,
479
+ "LABEL_161": 161,
480
+ "LABEL_162": 162,
481
+ "LABEL_163": 163,
482
+ "LABEL_164": 164,
483
+ "LABEL_165": 165,
484
+ "LABEL_166": 166,
485
+ "LABEL_167": 167,
486
+ "LABEL_168": 168,
487
+ "LABEL_169": 169,
488
+ "LABEL_17": 17,
489
+ "LABEL_170": 170,
490
+ "LABEL_171": 171,
491
+ "LABEL_172": 172,
492
+ "LABEL_173": 173,
493
+ "LABEL_174": 174,
494
+ "LABEL_175": 175,
495
+ "LABEL_176": 176,
496
+ "LABEL_177": 177,
497
+ "LABEL_178": 178,
498
+ "LABEL_179": 179,
499
+ "LABEL_18": 18,
500
+ "LABEL_180": 180,
501
+ "LABEL_181": 181,
502
+ "LABEL_182": 182,
503
+ "LABEL_183": 183,
504
+ "LABEL_184": 184,
505
+ "LABEL_185": 185,
506
+ "LABEL_186": 186,
507
+ "LABEL_187": 187,
508
+ "LABEL_188": 188,
509
+ "LABEL_189": 189,
510
+ "LABEL_19": 19,
511
+ "LABEL_190": 190,
512
+ "LABEL_191": 191,
513
+ "LABEL_192": 192,
514
+ "LABEL_193": 193,
515
+ "LABEL_194": 194,
516
+ "LABEL_195": 195,
517
+ "LABEL_196": 196,
518
+ "LABEL_197": 197,
519
+ "LABEL_198": 198,
520
+ "LABEL_199": 199,
521
+ "LABEL_2": 2,
522
+ "LABEL_20": 20,
523
+ "LABEL_200": 200,
524
+ "LABEL_201": 201,
525
+ "LABEL_202": 202,
526
+ "LABEL_203": 203,
527
+ "LABEL_204": 204,
528
+ "LABEL_205": 205,
529
+ "LABEL_206": 206,
530
+ "LABEL_207": 207,
531
+ "LABEL_208": 208,
532
+ "LABEL_209": 209,
533
+ "LABEL_21": 21,
534
+ "LABEL_210": 210,
535
+ "LABEL_211": 211,
536
+ "LABEL_212": 212,
537
+ "LABEL_213": 213,
538
+ "LABEL_214": 214,
539
+ "LABEL_215": 215,
540
+ "LABEL_216": 216,
541
+ "LABEL_217": 217,
542
+ "LABEL_218": 218,
543
+ "LABEL_219": 219,
544
+ "LABEL_22": 22,
545
+ "LABEL_220": 220,
546
+ "LABEL_221": 221,
547
+ "LABEL_222": 222,
548
+ "LABEL_223": 223,
549
+ "LABEL_224": 224,
550
+ "LABEL_225": 225,
551
+ "LABEL_226": 226,
552
+ "LABEL_227": 227,
553
+ "LABEL_228": 228,
554
+ "LABEL_229": 229,
555
+ "LABEL_23": 23,
556
+ "LABEL_230": 230,
557
+ "LABEL_231": 231,
558
+ "LABEL_232": 232,
559
+ "LABEL_233": 233,
560
+ "LABEL_234": 234,
561
+ "LABEL_235": 235,
562
+ "LABEL_236": 236,
563
+ "LABEL_237": 237,
564
+ "LABEL_238": 238,
565
+ "LABEL_239": 239,
566
+ "LABEL_24": 24,
567
+ "LABEL_240": 240,
568
+ "LABEL_241": 241,
569
+ "LABEL_242": 242,
570
+ "LABEL_243": 243,
571
+ "LABEL_244": 244,
572
+ "LABEL_245": 245,
573
+ "LABEL_246": 246,
574
+ "LABEL_247": 247,
575
+ "LABEL_248": 248,
576
+ "LABEL_249": 249,
577
+ "LABEL_25": 25,
578
+ "LABEL_250": 250,
579
+ "LABEL_251": 251,
580
+ "LABEL_252": 252,
581
+ "LABEL_253": 253,
582
+ "LABEL_254": 254,
583
+ "LABEL_255": 255,
584
+ "LABEL_256": 256,
585
+ "LABEL_257": 257,
586
+ "LABEL_258": 258,
587
+ "LABEL_259": 259,
588
+ "LABEL_26": 26,
589
+ "LABEL_260": 260,
590
+ "LABEL_261": 261,
591
+ "LABEL_262": 262,
592
+ "LABEL_263": 263,
593
+ "LABEL_264": 264,
594
+ "LABEL_265": 265,
595
+ "LABEL_266": 266,
596
+ "LABEL_267": 267,
597
+ "LABEL_268": 268,
598
+ "LABEL_269": 269,
599
+ "LABEL_27": 27,
600
+ "LABEL_270": 270,
601
+ "LABEL_271": 271,
602
+ "LABEL_272": 272,
603
+ "LABEL_273": 273,
604
+ "LABEL_274": 274,
605
+ "LABEL_275": 275,
606
+ "LABEL_276": 276,
607
+ "LABEL_277": 277,
608
+ "LABEL_278": 278,
609
+ "LABEL_279": 279,
610
+ "LABEL_28": 28,
611
+ "LABEL_280": 280,
612
+ "LABEL_281": 281,
613
+ "LABEL_282": 282,
614
+ "LABEL_283": 283,
615
+ "LABEL_284": 284,
616
+ "LABEL_285": 285,
617
+ "LABEL_286": 286,
618
+ "LABEL_287": 287,
619
+ "LABEL_288": 288,
620
+ "LABEL_289": 289,
621
+ "LABEL_29": 29,
622
+ "LABEL_290": 290,
623
+ "LABEL_291": 291,
624
+ "LABEL_292": 292,
625
+ "LABEL_293": 293,
626
+ "LABEL_294": 294,
627
+ "LABEL_295": 295,
628
+ "LABEL_296": 296,
629
+ "LABEL_297": 297,
630
+ "LABEL_298": 298,
631
+ "LABEL_299": 299,
632
+ "LABEL_3": 3,
633
+ "LABEL_30": 30,
634
+ "LABEL_300": 300,
635
+ "LABEL_301": 301,
636
+ "LABEL_302": 302,
637
+ "LABEL_303": 303,
638
+ "LABEL_304": 304,
639
+ "LABEL_305": 305,
640
+ "LABEL_306": 306,
641
+ "LABEL_307": 307,
642
+ "LABEL_308": 308,
643
+ "LABEL_309": 309,
644
+ "LABEL_31": 31,
645
+ "LABEL_310": 310,
646
+ "LABEL_311": 311,
647
+ "LABEL_312": 312,
648
+ "LABEL_313": 313,
649
+ "LABEL_314": 314,
650
+ "LABEL_315": 315,
651
+ "LABEL_316": 316,
652
+ "LABEL_317": 317,
653
+ "LABEL_318": 318,
654
+ "LABEL_319": 319,
655
+ "LABEL_32": 32,
656
+ "LABEL_320": 320,
657
+ "LABEL_321": 321,
658
+ "LABEL_322": 322,
659
+ "LABEL_323": 323,
660
+ "LABEL_324": 324,
661
+ "LABEL_325": 325,
662
+ "LABEL_326": 326,
663
+ "LABEL_327": 327,
664
+ "LABEL_328": 328,
665
+ "LABEL_329": 329,
666
+ "LABEL_33": 33,
667
+ "LABEL_330": 330,
668
+ "LABEL_331": 331,
669
+ "LABEL_332": 332,
670
+ "LABEL_333": 333,
671
+ "LABEL_334": 334,
672
+ "LABEL_335": 335,
673
+ "LABEL_336": 336,
674
+ "LABEL_337": 337,
675
+ "LABEL_338": 338,
676
+ "LABEL_339": 339,
677
+ "LABEL_34": 34,
678
+ "LABEL_340": 340,
679
+ "LABEL_341": 341,
680
+ "LABEL_342": 342,
681
+ "LABEL_343": 343,
682
+ "LABEL_344": 344,
683
+ "LABEL_345": 345,
684
+ "LABEL_346": 346,
685
+ "LABEL_347": 347,
686
+ "LABEL_348": 348,
687
+ "LABEL_349": 349,
688
+ "LABEL_35": 35,
689
+ "LABEL_350": 350,
690
+ "LABEL_351": 351,
691
+ "LABEL_352": 352,
692
+ "LABEL_353": 353,
693
+ "LABEL_354": 354,
694
+ "LABEL_355": 355,
695
+ "LABEL_356": 356,
696
+ "LABEL_357": 357,
697
+ "LABEL_358": 358,
698
+ "LABEL_359": 359,
699
+ "LABEL_36": 36,
700
+ "LABEL_360": 360,
701
+ "LABEL_361": 361,
702
+ "LABEL_362": 362,
703
+ "LABEL_363": 363,
704
+ "LABEL_364": 364,
705
+ "LABEL_365": 365,
706
+ "LABEL_366": 366,
707
+ "LABEL_367": 367,
708
+ "LABEL_368": 368,
709
+ "LABEL_369": 369,
710
+ "LABEL_37": 37,
711
+ "LABEL_370": 370,
712
+ "LABEL_371": 371,
713
+ "LABEL_372": 372,
714
+ "LABEL_373": 373,
715
+ "LABEL_374": 374,
716
+ "LABEL_375": 375,
717
+ "LABEL_376": 376,
718
+ "LABEL_377": 377,
719
+ "LABEL_378": 378,
720
+ "LABEL_379": 379,
721
+ "LABEL_38": 38,
722
+ "LABEL_380": 380,
723
+ "LABEL_381": 381,
724
+ "LABEL_382": 382,
725
+ "LABEL_383": 383,
726
+ "LABEL_384": 384,
727
+ "LABEL_385": 385,
728
+ "LABEL_386": 386,
729
+ "LABEL_387": 387,
730
+ "LABEL_388": 388,
731
+ "LABEL_389": 389,
732
+ "LABEL_39": 39,
733
+ "LABEL_390": 390,
734
+ "LABEL_4": 4,
735
+ "LABEL_40": 40,
736
+ "LABEL_41": 41,
737
+ "LABEL_42": 42,
738
+ "LABEL_43": 43,
739
+ "LABEL_44": 44,
740
+ "LABEL_45": 45,
741
+ "LABEL_46": 46,
742
+ "LABEL_47": 47,
743
+ "LABEL_48": 48,
744
+ "LABEL_49": 49,
745
+ "LABEL_5": 5,
746
+ "LABEL_50": 50,
747
+ "LABEL_51": 51,
748
+ "LABEL_52": 52,
749
+ "LABEL_53": 53,
750
+ "LABEL_54": 54,
751
+ "LABEL_55": 55,
752
+ "LABEL_56": 56,
753
+ "LABEL_57": 57,
754
+ "LABEL_58": 58,
755
+ "LABEL_59": 59,
756
+ "LABEL_6": 6,
757
+ "LABEL_60": 60,
758
+ "LABEL_61": 61,
759
+ "LABEL_62": 62,
760
+ "LABEL_63": 63,
761
+ "LABEL_64": 64,
762
+ "LABEL_65": 65,
763
+ "LABEL_66": 66,
764
+ "LABEL_67": 67,
765
+ "LABEL_68": 68,
766
+ "LABEL_69": 69,
767
+ "LABEL_7": 7,
768
+ "LABEL_70": 70,
769
+ "LABEL_71": 71,
770
+ "LABEL_72": 72,
771
+ "LABEL_73": 73,
772
+ "LABEL_74": 74,
773
+ "LABEL_75": 75,
774
+ "LABEL_76": 76,
775
+ "LABEL_77": 77,
776
+ "LABEL_78": 78,
777
+ "LABEL_79": 79,
778
+ "LABEL_8": 8,
779
+ "LABEL_80": 80,
780
+ "LABEL_81": 81,
781
+ "LABEL_82": 82,
782
+ "LABEL_83": 83,
783
+ "LABEL_84": 84,
784
+ "LABEL_85": 85,
785
+ "LABEL_86": 86,
786
+ "LABEL_87": 87,
787
+ "LABEL_88": 88,
788
+ "LABEL_89": 89,
789
+ "LABEL_9": 9,
790
+ "LABEL_90": 90,
791
+ "LABEL_91": 91,
792
+ "LABEL_92": 92,
793
+ "LABEL_93": 93,
794
+ "LABEL_94": 94,
795
+ "LABEL_95": 95,
796
+ "LABEL_96": 96,
797
+ "LABEL_97": 97,
798
+ "LABEL_98": 98,
799
+ "LABEL_99": 99
800
+ },
801
+ "layer_norm_eps": 1e-12,
802
+ "max_position_embeddings": 512,
803
+ "model_type": "bert",
804
+ "num_attention_heads": 12,
805
+ "num_hidden_layers": 6,
806
+ "num_relation_heads": 32,
807
+ "pad_token_id": 0,
808
+ "pooler_fc_size": 768,
809
+ "pooler_num_attention_heads": 12,
810
+ "pooler_num_fc_layers": 3,
811
+ "pooler_size_per_head": 128,
812
+ "pooler_type": "first_token_transform",
813
+ "position_embedding_type": "absolute",
814
+ "torch_dtype": "float32",
815
+ "transformers_version": "4.44.1",
816
+ "type_vocab_size": 2,
817
+ "use_cache": true,
818
+ "vocab_size": 21128
819
+ }
src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/poly_bert_model.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8765d835ffdf9811c832d4dc7b6a552757aa8615c01d1184db716a50c20aebbc
3
+ size 76583333
src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polychar.txt ADDED
@@ -0,0 +1,159 @@
1
+ 丧
2
+ 中
3
+ 为
4
+ 乌
5
+ 乐
6
+ 了
7
+ 什
8
+ 仔
9
+ 令
10
+ 任
11
+ 会
12
+ 传
13
+ 佛
14
+ 供
15
+ 便
16
+ 倒
17
+ 假
18
+ 兴
19
+ 冠
20
+ 冲
21
+ 几
22
+ 分
23
+ 切
24
+ 划
25
+ 创
26
+ 剥
27
+ 勒
28
+ 区
29
+ 华
30
+ 单
31
+ 卜
32
+ 占
33
+ 卡
34
+ 卷
35
+ 厦
36
+ 参
37
+ 发
38
+ 只
39
+ 号
40
+ 同
41
+ 吐
42
+ 和
43
+ 喝
44
+ 圈
45
+ 地
46
+ 塞
47
+ 壳
48
+ 处
49
+ 奇
50
+ 奔
51
+ 好
52
+ 宁
53
+ 宿
54
+ 将
55
+ 少
56
+ 尽
57
+ 岗
58
+ 差
59
+ 巷
60
+ 帖
61
+ 干
62
+ 应
63
+ 度
64
+ 弹
65
+ 强
66
+ 当
67
+ 待
68
+ 得
69
+ 恶
70
+ 扁
71
+ 扇
72
+ 扎
73
+ 扫
74
+ 担
75
+ 挑
76
+ 据
77
+ 撒
78
+ 教
79
+ 散
80
+ 数
81
+ 斗
82
+ 晃
83
+ 曝
84
+ 曲
85
+ 更
86
+ 曾
87
+ 朝
88
+ 朴
89
+ 杆
90
+ 查
91
+ 校
92
+ 模
93
+ 横
94
+ 没
95
+ 泡
96
+ 济
97
+ 混
98
+ 漂
99
+ 炸
100
+ 熟
101
+ 燕
102
+ 片
103
+ 率
104
+ 畜
105
+ 的
106
+ 盛
107
+ 相
108
+ 省
109
+ 看
110
+ 着
111
+ 矫
112
+ 禁
113
+ 种
114
+ 称
115
+ 空
116
+ 答
117
+ 粘
118
+ 糊
119
+ 系
120
+ 累
121
+ 纤
122
+ 结
123
+ 给
124
+ 缝
125
+ 肖
126
+ 背
127
+ 脏
128
+ 舍
129
+ 色
130
+ 落
131
+ 蒙
132
+ 薄
133
+ 藏
134
+ 血
135
+ 行
136
+ 要
137
+ 观
138
+ 觉
139
+ 角
140
+ 解
141
+ 说
142
+ 调
143
+ 踏
144
+ 车
145
+ 转
146
+ 载
147
+ 还
148
+ 遂
149
+ 都
150
+ 重
151
+ 量
152
+ 钻
153
+ 铺
154
+ 长
155
+ 间
156
+ 降
157
+ 难
158
+ 露
159
+ 鲜
src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polydict.json ADDED
@@ -0,0 +1,393 @@
1
+ {
2
+ "1": "丧{sang1}",
3
+ "2": "丧{sang4}",
4
+ "3": "中{zhong1}",
5
+ "4": "中{zhong4}",
6
+ "5": "为{wei2}",
7
+ "6": "为{wei4}",
8
+ "7": "乌{wu1}",
9
+ "8": "乌{wu4}",
10
+ "9": "乐{lao4}",
11
+ "10": "乐{le4}",
12
+ "11": "乐{le5}",
13
+ "12": "乐{yao4}",
14
+ "13": "乐{yve4}",
15
+ "14": "了{le5}",
16
+ "15": "了{liao3}",
17
+ "16": "了{liao5}",
18
+ "17": "什{shen2}",
19
+ "18": "什{shi2}",
20
+ "19": "仔{zai3}",
21
+ "20": "仔{zai5}",
22
+ "21": "仔{zi3}",
23
+ "22": "仔{zi5}",
24
+ "23": "令{ling2}",
25
+ "24": "令{ling4}",
26
+ "25": "任{ren2}",
27
+ "26": "任{ren4}",
28
+ "27": "会{hui4}",
29
+ "28": "会{hui5}",
30
+ "29": "会{kuai4}",
31
+ "30": "传{chuan2}",
32
+ "31": "传{zhuan4}",
33
+ "32": "佛{fo2}",
34
+ "33": "佛{fu2}",
35
+ "34": "供{gong1}",
36
+ "35": "供{gong4}",
37
+ "36": "便{bian4}",
38
+ "37": "便{pian2}",
39
+ "38": "倒{dao3}",
40
+ "39": "倒{dao4}",
41
+ "40": "假{jia3}",
42
+ "41": "假{jia4}",
43
+ "42": "兴{xing1}",
44
+ "43": "兴{xing4}",
45
+ "44": "冠{guan1}",
46
+ "45": "冠{guan4}",
47
+ "46": "冲{chong1}",
48
+ "47": "冲{chong4}",
49
+ "48": "几{ji1}",
50
+ "49": "几{ji2}",
51
+ "50": "几{ji3}",
52
+ "51": "分{fen1}",
53
+ "52": "分{fen4}",
54
+ "53": "分{fen5}",
55
+ "54": "切{qie1}",
56
+ "55": "切{qie4}",
57
+ "56": "划{hua2}",
58
+ "57": "划{hua4}",
59
+ "58": "划{hua5}",
60
+ "59": "创{chuang1}",
61
+ "60": "创{chuang4}",
62
+ "61": "剥{bao1}",
63
+ "62": "剥{bo1}",
64
+ "63": "勒{le4}",
65
+ "64": "勒{le5}",
66
+ "65": "勒{lei1}",
67
+ "66": "区{ou1}",
68
+ "67": "区{qu1}",
69
+ "68": "华{hua2}",
70
+ "69": "华{hua4}",
71
+ "70": "单{chan2}",
72
+ "71": "单{dan1}",
73
+ "72": "单{shan4}",
74
+ "73": "卜{bo5}",
75
+ "74": "卜{bu3}",
76
+ "75": "占{zhan1}",
77
+ "76": "占{zhan4}",
78
+ "77": "卡{ka2}",
79
+ "78": "卡{ka3}",
80
+ "79": "卡{qia3}",
81
+ "80": "卷{jvan3}",
82
+ "81": "卷{jvan4}",
83
+ "82": "厦{sha4}",
84
+ "83": "厦{xia4}",
85
+ "84": "参{can1}",
86
+ "85": "参{cen1}",
87
+ "86": "参{shen1}",
88
+ "87": "发{fa1}",
89
+ "88": "发{fa4}",
90
+ "89": "发{fa5}",
91
+ "90": "只{zhi1}",
92
+ "91": "只{zhi3}",
93
+ "92": "号{hao2}",
94
+ "93": "号{hao4}",
95
+ "94": "号{hao5}",
96
+ "95": "同{tong2}",
97
+ "96": "同{tong4}",
98
+ "97": "同{tong5}",
99
+ "98": "吐{tu2}",
100
+ "99": "吐{tu3}",
101
+ "100": "吐{tu4}",
102
+ "101": "和{he2}",
103
+ "102": "和{he4}",
104
+ "103": "和{he5}",
105
+ "104": "和{huo2}",
106
+ "105": "和{huo4}",
107
+ "106": "和{huo5}",
108
+ "107": "喝{he1}",
109
+ "108": "喝{he4}",
110
+ "109": "圈{jvan4}",
111
+ "110": "圈{qvan1}",
112
+ "111": "圈{qvan5}",
113
+ "112": "地{de5}",
114
+ "113": "地{di4}",
115
+ "114": "地{di5}",
116
+ "115": "塞{sai1}",
117
+ "116": "塞{sai2}",
118
+ "117": "塞{sai4}",
119
+ "118": "塞{se4}",
120
+ "119": "壳{ke2}",
121
+ "120": "壳{qiao4}",
122
+ "121": "处{chu3}",
123
+ "122": "处{chu4}",
124
+ "123": "奇{ji1}",
125
+ "124": "奇{qi2}",
126
+ "125": "奔{ben1}",
127
+ "126": "奔{ben4}",
128
+ "127": "好{hao3}",
129
+ "128": "好{hao4}",
130
+ "129": "好{hao5}",
131
+ "130": "宁{ning2}",
132
+ "131": "宁{ning4}",
133
+ "132": "宁{ning5}",
134
+ "133": "宿{su4}",
135
+ "134": "宿{xiu3}",
136
+ "135": "宿{xiu4}",
137
+ "136": "将{jiang1}",
138
+ "137": "将{jiang4}",
139
+ "138": "少{shao3}",
140
+ "139": "少{shao4}",
141
+ "140": "尽{jin3}",
142
+ "141": "尽{jin4}",
143
+ "142": "岗{gang1}",
144
+ "143": "岗{gang3}",
145
+ "144": "差{cha1}",
146
+ "145": "差{cha4}",
147
+ "146": "差{chai1}",
148
+ "147": "差{ci1}",
149
+ "148": "巷{hang4}",
150
+ "149": "巷{xiang4}",
151
+ "150": "帖{tie1}",
152
+ "151": "帖{tie3}",
153
+ "152": "帖{tie4}",
154
+ "153": "干{gan1}",
155
+ "154": "干{gan4}",
156
+ "155": "应{ying1}",
157
+ "156": "应{ying4}",
158
+ "157": "应{ying5}",
159
+ "158": "度{du4}",
160
+ "159": "度{du5}",
161
+ "160": "度{duo2}",
162
+ "161": "弹{dan4}",
163
+ "162": "弹{tan2}",
164
+ "163": "弹{tan5}",
165
+ "164": "强{jiang4}",
166
+ "165": "强{qiang2}",
167
+ "166": "强{qiang3}",
168
+ "167": "当{dang1}",
169
+ "168": "当{dang4}",
170
+ "169": "当{dang5}",
171
+ "170": "待{dai1}",
172
+ "171": "待{dai4}",
173
+ "172": "得{de2}",
174
+ "173": "得{de5}",
175
+ "174": "得{dei3}",
176
+ "175": "得{dei5}",
177
+ "176": "恶{e3}",
178
+ "177": "恶{e4}",
179
+ "178": "恶{wu4}",
180
+ "179": "扁{bian3}",
181
+ "180": "扁{pian1}",
182
+ "181": "扇{shan1}",
183
+ "182": "扇{shan4}",
184
+ "183": "扎{za1}",
185
+ "184": "扎{zha1}",
186
+ "185": "扎{zha2}",
187
+ "186": "扫{sao3}",
188
+ "187": "扫{sao4}",
189
+ "188": "担{dan1}",
190
+ "189": "担{dan4}",
191
+ "190": "担{dan5}",
192
+ "191": "挑{tiao1}",
193
+ "192": "挑{tiao3}",
194
+ "193": "据{jv1}",
195
+ "194": "据{jv4}",
196
+ "195": "撒{sa1}",
197
+ "196": "撒{sa3}",
198
+ "197": "撒{sa5}",
199
+ "198": "教{jiao1}",
200
+ "199": "教{jiao4}",
201
+ "200": "散{san3}",
202
+ "201": "散{san4}",
203
+ "202": "散{san5}",
204
+ "203": "数{shu3}",
205
+ "204": "数{shu4}",
206
+ "205": "数{shu5}",
207
+ "206": "斗{dou3}",
208
+ "207": "斗{dou4}",
209
+ "208": "晃{huang3}",
210
+ "209": "曝{bao4}",
211
+ "210": "曲{qu1}",
212
+ "211": "曲{qu3}",
213
+ "212": "更{geng1}",
214
+ "213": "更{geng4}",
215
+ "214": "曾{ceng1}",
216
+ "215": "曾{ceng2}",
217
+ "216": "曾{zeng1}",
218
+ "217": "朝{chao2}",
219
+ "218": "朝{zhao1}",
220
+ "219": "朴{piao2}",
221
+ "220": "朴{pu2}",
222
+ "221": "朴{pu3}",
223
+ "222": "杆{gan1}",
224
+ "223": "杆{gan3}",
225
+ "224": "查{cha2}",
226
+ "225": "查{zha1}",
227
+ "226": "校{jiao4}",
228
+ "227": "校{xiao4}",
229
+ "228": "模{mo2}",
230
+ "229": "模{mu2}",
231
+ "230": "横{heng2}",
232
+ "231": "横{heng4}",
233
+ "232": "没{mei2}",
234
+ "233": "没{mo4}",
235
+ "234": "泡{pao1}",
236
+ "235": "泡{pao4}",
237
+ "236": "泡{pao5}",
238
+ "237": "济{ji3}",
239
+ "238": "济{ji4}",
240
+ "239": "混{hun2}",
241
+ "240": "混{hun3}",
242
+ "241": "混{hun4}",
243
+ "242": "混{hun5}",
244
+ "243": "漂{piao1}",
245
+ "244": "漂{piao3}",
246
+ "245": "漂{piao4}",
247
+ "246": "炸{zha2}",
248
+ "247": "炸{zha4}",
249
+ "248": "熟{shou2}",
250
+ "249": "熟{shu2}",
251
+ "250": "燕{yan1}",
252
+ "251": "燕{yan4}",
253
+ "252": "片{pian1}",
254
+ "253": "片{pian4}",
255
+ "254": "率{lv4}",
256
+ "255": "率{shuai4}",
257
+ "256": "畜{chu4}",
258
+ "257": "畜{xu4}",
259
+ "258": "的{de5}",
260
+ "259": "的{di1}",
261
+ "260": "的{di2}",
262
+ "261": "的{di4}",
263
+ "262": "的{di5}",
264
+ "263": "盛{cheng2}",
265
+ "264": "盛{sheng4}",
266
+ "265": "相{xiang1}",
267
+ "266": "相{xiang4}",
268
+ "267": "相{xiang5}",
269
+ "268": "省{sheng3}",
270
+ "269": "省{xing3}",
271
+ "270": "看{kan1}",
272
+ "271": "看{kan4}",
273
+ "272": "看{kan5}",
274
+ "273": "着{zhao1}",
275
+ "274": "着{zhao2}",
276
+ "275": "着{zhao5}",
277
+ "276": "着{zhe5}",
278
+ "277": "着{zhuo2}",
279
+ "278": "着{zhuo5}",
280
+ "279": "矫{jiao3}",
281
+ "280": "禁{jin1}",
282
+ "281": "禁{jin4}",
283
+ "282": "种{zhong3}",
284
+ "283": "种{zhong4}",
285
+ "284": "称{chen4}",
286
+ "285": "称{cheng1}",
287
+ "286": "空{kong1}",
288
+ "287": "空{kong4}",
289
+ "288": "答{da1}",
290
+ "289": "答{da2}",
291
+ "290": "粘{nian2}",
292
+ "291": "粘{zhan1}",
293
+ "292": "糊{hu2}",
294
+ "293": "糊{hu5}",
295
+ "294": "系{ji4}",
296
+ "295": "系{xi4}",
297
+ "296": "系{xi5}",
298
+ "297": "累{lei2}",
299
+ "298": "累{lei3}",
300
+ "299": "累{lei4}",
301
+ "300": "累{lei5}",
302
+ "301": "纤{qian4}",
303
+ "302": "纤{xian1}",
304
+ "303": "结{jie1}",
305
+ "304": "结{jie2}",
306
+ "305": "结{jie5}",
307
+ "306": "给{gei3}",
308
+ "307": "给{gei5}",
309
+ "308": "给{ji3}",
310
+ "309": "缝{feng2}",
311
+ "310": "缝{feng4}",
312
+ "311": "缝{feng5}",
313
+ "312": "肖{xiao1}",
314
+ "313": "肖{xiao4}",
315
+ "314": "背{bei1}",
316
+ "315": "背{bei4}",
317
+ "316": "脏{zang1}",
318
+ "317": "脏{zang4}",
319
+ "318": "舍{she3}",
320
+ "319": "舍{she4}",
321
+ "320": "色{se4}",
322
+ "321": "色{shai3}",
323
+ "322": "落{lao4}",
324
+ "323": "落{luo4}",
325
+ "324": "蒙{meng1}",
326
+ "325": "蒙{meng2}",
327
+ "326": "蒙{meng3}",
328
+ "327": "薄{bao2}",
329
+ "328": "薄{bo2}",
330
+ "329": "薄{bo4}",
331
+ "330": "藏{cang2}",
332
+ "331": "藏{zang4}",
333
+ "332": "血{xie3}",
334
+ "333": "血{xue4}",
335
+ "334": "行{hang2}",
336
+ "335": "行{hang5}",
337
+ "336": "行{heng5}",
338
+ "337": "行{xing2}",
339
+ "338": "行{xing4}",
340
+ "339": "要{yao1}",
341
+ "340": "要{yao4}",
342
+ "341": "观{guan1}",
343
+ "342": "观{guan4}",
344
+ "343": "觉{jiao4}",
345
+ "344": "觉{jiao5}",
346
+ "345": "觉{jve2}",
347
+ "346": "角{jiao3}",
348
+ "347": "角{jve2}",
349
+ "348": "解{jie3}",
350
+ "349": "解{jie4}",
351
+ "350": "解{xie4}",
352
+ "351": "说{shui4}",
353
+ "352": "说{shuo1}",
354
+ "353": "调{diao4}",
355
+ "354": "调{tiao2}",
356
+ "355": "踏{ta1}",
357
+ "356": "踏{ta4}",
358
+ "357": "车{che1}",
359
+ "358": "车{jv1}",
360
+ "359": "转{zhuan3}",
361
+ "360": "转{zhuan4}",
362
+ "361": "载{zai3}",
363
+ "362": "载{zai4}",
364
+ "363": "还{hai2}",
365
+ "364": "还{huan2}",
366
+ "365": "遂{sui2}",
367
+ "366": "遂{sui4}",
368
+ "367": "都{dou1}",
369
+ "368": "都{du1}",
370
+ "369": "重{chong2}",
371
+ "370": "重{zhong4}",
372
+ "371": "量{liang2}",
373
+ "372": "量{liang4}",
374
+ "373": "量{liang5}",
375
+ "374": "钻{zuan1}",
376
+ "375": "钻{zuan4}",
377
+ "376": "铺{pu1}",
378
+ "377": "铺{pu4}",
379
+ "378": "长{chang2}",
380
+ "379": "长{chang3}",
381
+ "380": "长{zhang3}",
382
+ "381": "间{jian1}",
383
+ "382": "间{jian4}",
384
+ "383": "降{jiang4}",
385
+ "384": "降{xiang2}",
386
+ "385": "难{nan2}",
387
+ "386": "难{nan4}",
388
+ "387": "难{nan5}",
389
+ "388": "露{lou4}",
390
+ "389": "露{lu4}",
391
+ "390": "鲜{xian1}",
392
+ "391": "鲜{xian3}"
393
+ }
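polydict.json above maps the classifier label indices (as strings) to "character{pinyin}" entries; polydict_r.json below is the reverse lookup. A minimal sketch, assuming the file is read directly from this sources directory, that regroups the entries into a {character: [pinyin, ...]} index:

import json
import re
from collections import defaultdict

with open("polydict.json", "r", encoding="utf-8") as f:
    id_to_entry = json.load(f)  # e.g. "37": "便{pian2}"

char_to_pinyins = defaultdict(list)
for entry in id_to_entry.values():
    match = re.match(r"(.+)\{(.+)\}", entry)
    if match:
        char_to_pinyins[match.group(1)].append(match.group(2))

# char_to_pinyins["便"] -> ["bian4", "pian2"]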
src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polydict_r.json ADDED
@@ -0,0 +1,393 @@
1
+ {
2
+ "丧{sang1}": 1,
3
+ "丧{sang4}": 2,
4
+ "中{zhong1}": 3,
5
+ "中{zhong4}": 4,
6
+ "为{wei2}": 5,
7
+ "为{wei4}": 6,
8
+ "乌{wu1}": 7,
9
+ "乌{wu4}": 8,
10
+ "乐{lao4}": 9,
11
+ "乐{le4}": 10,
12
+ "乐{le5}": 11,
13
+ "乐{yao4}": 12,
14
+ "乐{yve4}": 13,
15
+ "了{le5}": 14,
16
+ "了{liao3}": 15,
17
+ "了{liao5}": 16,
18
+ "什{shen2}": 17,
19
+ "什{shi2}": 18,
20
+ "仔{zai3}": 19,
21
+ "仔{zai5}": 20,
22
+ "仔{zi3}": 21,
23
+ "仔{zi5}": 22,
24
+ "令{ling2}": 23,
25
+ "令{ling4}": 24,
26
+ "任{ren2}": 25,
27
+ "任{ren4}": 26,
28
+ "会{hui4}": 27,
29
+ "会{hui5}": 28,
30
+ "会{kuai4}": 29,
31
+ "传{chuan2}": 30,
32
+ "传{zhuan4}": 31,
33
+ "佛{fo2}": 32,
34
+ "佛{fu2}": 33,
35
+ "供{gong1}": 34,
36
+ "供{gong4}": 35,
37
+ "便{bian4}": 36,
38
+ "便{pian2}": 37,
39
+ "倒{dao3}": 38,
40
+ "倒{dao4}": 39,
41
+ "假{jia3}": 40,
42
+ "假{jia4}": 41,
43
+ "兴{xing1}": 42,
44
+ "兴{xing4}": 43,
45
+ "冠{guan1}": 44,
46
+ "冠{guan4}": 45,
47
+ "冲{chong1}": 46,
48
+ "冲{chong4}": 47,
49
+ "几{ji1}": 48,
50
+ "几{ji2}": 49,
51
+ "几{ji3}": 50,
52
+ "分{fen1}": 51,
53
+ "分{fen4}": 52,
54
+ "分{fen5}": 53,
55
+ "切{qie1}": 54,
56
+ "切{qie4}": 55,
57
+ "划{hua2}": 56,
58
+ "划{hua4}": 57,
59
+ "划{hua5}": 58,
60
+ "创{chuang1}": 59,
61
+ "创{chuang4}": 60,
62
+ "剥{bao1}": 61,
63
+ "剥{bo1}": 62,
64
+ "勒{le4}": 63,
65
+ "勒{le5}": 64,
66
+ "勒{lei1}": 65,
67
+ "区{ou1}": 66,
68
+ "区{qu1}": 67,
69
+ "华{hua2}": 68,
70
+ "华{hua4}": 69,
71
+ "单{chan2}": 70,
72
+ "单{dan1}": 71,
73
+ "单{shan4}": 72,
74
+ "卜{bo5}": 73,
75
+ "卜{bu3}": 74,
76
+ "占{zhan1}": 75,
77
+ "占{zhan4}": 76,
78
+ "卡{ka2}": 77,
79
+ "卡{ka3}": 78,
80
+ "卡{qia3}": 79,
81
+ "卷{jvan3}": 80,
82
+ "卷{jvan4}": 81,
83
+ "厦{sha4}": 82,
84
+ "厦{xia4}": 83,
85
+ "参{can1}": 84,
86
+ "参{cen1}": 85,
87
+ "参{shen1}": 86,
88
+ "发{fa1}": 87,
89
+ "发{fa4}": 88,
90
+ "发{fa5}": 89,
91
+ "只{zhi1}": 90,
92
+ "只{zhi3}": 91,
93
+ "号{hao2}": 92,
94
+ "号{hao4}": 93,
95
+ "号{hao5}": 94,
96
+ "同{tong2}": 95,
97
+ "同{tong4}": 96,
98
+ "同{tong5}": 97,
99
+ "吐{tu2}": 98,
100
+ "吐{tu3}": 99,
101
+ "吐{tu4}": 100,
102
+ "和{he2}": 101,
103
+ "和{he4}": 102,
104
+ "和{he5}": 103,
105
+ "和{huo2}": 104,
106
+ "和{huo4}": 105,
107
+ "和{huo5}": 106,
108
+ "喝{he1}": 107,
109
+ "喝{he4}": 108,
110
+ "圈{jvan4}": 109,
111
+ "圈{qvan1}": 110,
112
+ "圈{qvan5}": 111,
113
+ "地{de5}": 112,
114
+ "地{di4}": 113,
115
+ "地{di5}": 114,
116
+ "塞{sai1}": 115,
117
+ "塞{sai2}": 116,
118
+ "塞{sai4}": 117,
119
+ "塞{se4}": 118,
120
+ "壳{ke2}": 119,
121
+ "壳{qiao4}": 120,
122
+ "处{chu3}": 121,
123
+ "处{chu4}": 122,
124
+ "奇{ji1}": 123,
125
+ "奇{qi2}": 124,
126
+ "奔{ben1}": 125,
127
+ "奔{ben4}": 126,
128
+ "好{hao3}": 127,
129
+ "好{hao4}": 128,
130
+ "好{hao5}": 129,
131
+ "宁{ning2}": 130,
132
+ "宁{ning4}": 131,
133
+ "宁{ning5}": 132,
134
+ "宿{su4}": 133,
135
+ "宿{xiu3}": 134,
136
+ "宿{xiu4}": 135,
137
+ "将{jiang1}": 136,
138
+ "将{jiang4}": 137,
139
+ "少{shao3}": 138,
140
+ "少{shao4}": 139,
141
+ "尽{jin3}": 140,
142
+ "尽{jin4}": 141,
143
+ "岗{gang1}": 142,
144
+ "岗{gang3}": 143,
145
+ "差{cha1}": 144,
146
+ "差{cha4}": 145,
147
+ "差{chai1}": 146,
148
+ "差{ci1}": 147,
149
+ "巷{hang4}": 148,
150
+ "巷{xiang4}": 149,
151
+ "帖{tie1}": 150,
152
+ "帖{tie3}": 151,
153
+ "帖{tie4}": 152,
154
+ "干{gan1}": 153,
155
+ "干{gan4}": 154,
156
+ "应{ying1}": 155,
157
+ "应{ying4}": 156,
158
+ "应{ying5}": 157,
159
+ "度{du4}": 158,
160
+ "度{du5}": 159,
161
+ "度{duo2}": 160,
162
+ "弹{dan4}": 161,
163
+ "弹{tan2}": 162,
164
+ "弹{tan5}": 163,
165
+ "强{jiang4}": 164,
166
+ "强{qiang2}": 165,
167
+ "强{qiang3}": 166,
168
+ "当{dang1}": 167,
169
+ "当{dang4}": 168,
170
+ "当{dang5}": 169,
171
+ "待{dai1}": 170,
172
+ "待{dai4}": 171,
173
+ "得{de2}": 172,
174
+ "得{de5}": 173,
175
+ "得{dei3}": 174,
176
+ "得{dei5}": 175,
177
+ "恶{e3}": 176,
178
+ "恶{e4}": 177,
179
+ "恶{wu4}": 178,
180
+ "扁{bian3}": 179,
181
+ "扁{pian1}": 180,
182
+ "扇{shan1}": 181,
183
+ "扇{shan4}": 182,
184
+ "扎{za1}": 183,
185
+ "扎{zha1}": 184,
186
+ "扎{zha2}": 185,
187
+ "扫{sao3}": 186,
188
+ "扫{sao4}": 187,
189
+ "担{dan1}": 188,
190
+ "担{dan4}": 189,
191
+ "担{dan5}": 190,
192
+ "挑{tiao1}": 191,
193
+ "挑{tiao3}": 192,
194
+ "据{jv1}": 193,
195
+ "据{jv4}": 194,
196
+ "撒{sa1}": 195,
197
+ "撒{sa3}": 196,
198
+ "撒{sa5}": 197,
199
+ "教{jiao1}": 198,
200
+ "教{jiao4}": 199,
201
+ "散{san3}": 200,
202
+ "散{san4}": 201,
203
+ "散{san5}": 202,
204
+ "数{shu3}": 203,
205
+ "数{shu4}": 204,
206
+ "数{shu5}": 205,
207
+ "斗{dou3}": 206,
208
+ "斗{dou4}": 207,
209
+ "晃{huang3}": 208,
210
+ "曝{bao4}": 209,
211
+ "曲{qu1}": 210,
212
+ "曲{qu3}": 211,
213
+ "更{geng1}": 212,
214
+ "更{geng4}": 213,
215
+ "曾{ceng1}": 214,
216
+ "曾{ceng2}": 215,
217
+ "曾{zeng1}": 216,
218
+ "朝{chao2}": 217,
219
+ "朝{zhao1}": 218,
220
+ "朴{piao2}": 219,
221
+ "朴{pu2}": 220,
222
+ "朴{pu3}": 221,
223
+ "杆{gan1}": 222,
224
+ "杆{gan3}": 223,
225
+ "查{cha2}": 224,
226
+ "查{zha1}": 225,
227
+ "校{jiao4}": 226,
228
+ "校{xiao4}": 227,
229
+ "模{mo2}": 228,
230
+ "模{mu2}": 229,
231
+ "横{heng2}": 230,
232
+ "横{heng4}": 231,
233
+ "没{mei2}": 232,
234
+ "没{mo4}": 233,
235
+ "泡{pao1}": 234,
236
+ "泡{pao4}": 235,
237
+ "泡{pao5}": 236,
238
+ "济{ji3}": 237,
239
+ "济{ji4}": 238,
240
+ "混{hun2}": 239,
241
+ "混{hun3}": 240,
242
+ "混{hun4}": 241,
243
+ "混{hun5}": 242,
244
+ "漂{piao1}": 243,
245
+ "漂{piao3}": 244,
246
+ "漂{piao4}": 245,
247
+ "炸{zha2}": 246,
248
+ "炸{zha4}": 247,
249
+ "熟{shou2}": 248,
250
+ "熟{shu2}": 249,
251
+ "燕{yan1}": 250,
252
+ "燕{yan4}": 251,
253
+ "片{pian1}": 252,
254
+ "片{pian4}": 253,
255
+ "率{lv4}": 254,
256
+ "率{shuai4}": 255,
257
+ "畜{chu4}": 256,
258
+ "畜{xu4}": 257,
259
+ "的{de5}": 258,
260
+ "的{di1}": 259,
261
+ "的{di2}": 260,
262
+ "的{di4}": 261,
263
+ "的{di5}": 262,
264
+ "盛{cheng2}": 263,
265
+ "盛{sheng4}": 264,
266
+ "相{xiang1}": 265,
267
+ "相{xiang4}": 266,
268
+ "相{xiang5}": 267,
269
+ "省{sheng3}": 268,
270
+ "省{xing3}": 269,
271
+ "看{kan1}": 270,
272
+ "看{kan4}": 271,
273
+ "看{kan5}": 272,
274
+ "着{zhao1}": 273,
275
+ "着{zhao2}": 274,
276
+ "着{zhao5}": 275,
277
+ "着{zhe5}": 276,
278
+ "着{zhuo2}": 277,
279
+ "着{zhuo5}": 278,
280
+ "矫{jiao3}": 279,
281
+ "禁{jin1}": 280,
282
+ "禁{jin4}": 281,
283
+ "种{zhong3}": 282,
284
+ "种{zhong4}": 283,
285
+ "称{chen4}": 284,
286
+ "称{cheng1}": 285,
287
+ "空{kong1}": 286,
288
+ "空{kong4}": 287,
289
+ "答{da1}": 288,
290
+ "答{da2}": 289,
291
+ "粘{nian2}": 290,
292
+ "粘{zhan1}": 291,
293
+ "糊{hu2}": 292,
294
+ "糊{hu5}": 293,
295
+ "系{ji4}": 294,
296
+ "系{xi4}": 295,
297
+ "系{xi5}": 296,
298
+ "累{lei2}": 297,
299
+ "累{lei3}": 298,
300
+ "累{lei4}": 299,
301
+ "累{lei5}": 300,
302
+ "纤{qian4}": 301,
303
+ "纤{xian1}": 302,
304
+ "结{jie1}": 303,
305
+ "结{jie2}": 304,
306
+ "结{jie5}": 305,
307
+ "给{gei3}": 306,
308
+ "给{gei5}": 307,
309
+ "给{ji3}": 308,
310
+ "缝{feng2}": 309,
311
+ "缝{feng4}": 310,
312
+ "缝{feng5}": 311,
313
+ "肖{xiao1}": 312,
314
+ "肖{xiao4}": 313,
315
+ "背{bei1}": 314,
316
+ "背{bei4}": 315,
317
+ "脏{zang1}": 316,
318
+ "脏{zang4}": 317,
319
+ "舍{she3}": 318,
320
+ "舍{she4}": 319,
321
+ "色{se4}": 320,
322
+ "色{shai3}": 321,
323
+ "落{lao4}": 322,
324
+ "落{luo4}": 323,
325
+ "蒙{meng1}": 324,
326
+ "蒙{meng2}": 325,
327
+ "蒙{meng3}": 326,
328
+ "薄{bao2}": 327,
329
+ "薄{bo2}": 328,
330
+ "薄{bo4}": 329,
331
+ "藏{cang2}": 330,
332
+ "藏{zang4}": 331,
333
+ "血{xie3}": 332,
334
+ "血{xue4}": 333,
335
+ "行{hang2}": 334,
336
+ "行{hang5}": 335,
337
+ "行{heng5}": 336,
338
+ "行{xing2}": 337,
339
+ "行{xing4}": 338,
340
+ "要{yao1}": 339,
341
+ "要{yao4}": 340,
342
+ "观{guan1}": 341,
343
+ "观{guan4}": 342,
344
+ "觉{jiao4}": 343,
345
+ "觉{jiao5}": 344,
346
+ "觉{jve2}": 345,
347
+ "角{jiao3}": 346,
348
+ "角{jve2}": 347,
349
+ "解{jie3}": 348,
350
+ "解{jie4}": 349,
351
+ "解{xie4}": 350,
352
+ "说{shui4}": 351,
353
+ "说{shuo1}": 352,
354
+ "调{diao4}": 353,
355
+ "调{tiao2}": 354,
356
+ "踏{ta1}": 355,
357
+ "踏{ta4}": 356,
358
+ "车{che1}": 357,
359
+ "车{jv1}": 358,
360
+ "转{zhuan3}": 359,
361
+ "转{zhuan4}": 360,
362
+ "载{zai3}": 361,
363
+ "载{zai4}": 362,
364
+ "还{hai2}": 363,
365
+ "还{huan2}": 364,
366
+ "遂{sui2}": 365,
367
+ "遂{sui4}": 366,
368
+ "都{dou1}": 367,
369
+ "都{du1}": 368,
370
+ "重{chong2}": 369,
371
+ "重{zhong4}": 370,
372
+ "量{liang2}": 371,
373
+ "量{liang4}": 372,
374
+ "量{liang5}": 373,
375
+ "钻{zuan1}": 374,
376
+ "钻{zuan4}": 375,
377
+ "铺{pu1}": 376,
378
+ "铺{pu4}": 377,
379
+ "长{chang2}": 378,
380
+ "长{chang3}": 379,
381
+ "长{zhang3}": 380,
382
+ "间{jian1}": 381,
383
+ "间{jian4}": 382,
384
+ "降{jiang4}": 383,
385
+ "降{xiang2}": 384,
386
+ "难{nan2}": 385,
387
+ "难{nan4}": 386,
388
+ "难{nan5}": 387,
389
+ "露{lou4}": 388,
390
+ "露{lu4}": 389,
391
+ "鲜{xian1}": 390,
392
+ "鲜{xian3}": 391
393
+ }
src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
src/YingMusicSinger/utils/f5_tts/g2p/sources/pinyin_2_bpmf.txt ADDED
@@ -0,0 +1,429 @@
1
+ a ㄚ
2
+ ai ㄞ
3
+ an ㄢ
4
+ ang ㄤ
5
+ ao ㄠ
6
+ ba ㄅㄚ
7
+ bai ㄅㄞ
8
+ ban ㄅㄢ
9
+ bang ㄅㄤ
10
+ bao ㄅㄠ
11
+ bei ㄅㄟ
12
+ ben ㄅㄣ
13
+ beng ㄅㄥ
14
+ bi ㄅㄧ
15
+ bian ㄅㄧㄢ
16
+ biang ㄅㄧㄤ
17
+ biao ㄅㄧㄠ
18
+ bie ㄅㄧㄝ
19
+ bin ㄅㄧㄣ
20
+ bing ㄅㄧㄥ
21
+ bo ㄅㄛ
22
+ bu ㄅㄨ
23
+ ca ㄘㄚ
24
+ cai ㄘㄞ
25
+ can ㄘㄢ
26
+ cang ㄘㄤ
27
+ cao ㄘㄠ
28
+ ce ㄘㄜ
29
+ cen ㄘㄣ
30
+ ceng ㄘㄥ
31
+ cha ㄔㄚ
32
+ chai ㄔㄞ
33
+ chan ㄔㄢ
34
+ chang ㄔㄤ
35
+ chao ㄔㄠ
36
+ che ㄔㄜ
37
+ chen ㄔㄣ
38
+ cheng ㄔㄥ
39
+ chi ㄔ
40
+ chong ㄔㄨㄥ
41
+ chou ㄔㄡ
42
+ chu ㄔㄨ
43
+ chua ㄔㄨㄚ
44
+ chuai ㄔㄨㄞ
45
+ chuan ㄔㄨㄢ
46
+ chuang ㄔㄨㄤ
47
+ chui ㄔㄨㄟ
48
+ chun ㄔㄨㄣ
49
+ chuo ㄔㄨㄛ
50
+ ci ㄘ
51
+ cong ㄘㄨㄥ
52
+ cou ㄘㄡ
53
+ cu ㄘㄨ
54
+ cuan ㄘㄨㄢ
55
+ cui ㄘㄨㄟ
56
+ cun ㄘㄨㄣ
57
+ cuo ㄘㄨㄛ
58
+ da ㄉㄚ
59
+ dai ㄉㄞ
60
+ dan ㄉㄢ
61
+ dang ㄉㄤ
62
+ dao ㄉㄠ
63
+ de ㄉㄜ
64
+ dei ㄉㄟ
65
+ den ㄉㄣ
66
+ deng ㄉㄥ
67
+ di ㄉㄧ
68
+ dia ㄉㄧㄚ
69
+ dian ㄉㄧㄢ
70
+ diao ㄉㄧㄠ
71
+ die ㄉㄧㄝ
72
+ din ㄉㄧㄣ
73
+ ding ㄉㄧㄥ
74
+ diu ㄉㄧㄡ
75
+ dong ㄉㄨㄥ
76
+ dou ㄉㄡ
77
+ du ㄉㄨ
78
+ duan ㄉㄨㄢ
79
+ dui ㄉㄨㄟ
80
+ dun ㄉㄨㄣ
81
+ duo ㄉㄨㄛ
82
+ e ㄜ
83
+ ei ㄟ
84
+ en ㄣ
85
+ eng ㄥ
86
+ er ㄦ
87
+ fa ㄈㄚ
88
+ fan ㄈㄢ
89
+ fang ㄈㄤ
90
+ fei ㄈㄟ
91
+ fen ㄈㄣ
92
+ feng ㄈㄥ
93
+ fo ㄈㄛ
94
+ fou ㄈㄡ
95
+ fu ㄈㄨ
96
+ ga ㄍㄚ
97
+ gai ㄍㄞ
98
+ gan ㄍㄢ
99
+ gang ㄍㄤ
100
+ gao ㄍㄠ
101
+ ge ㄍㄜ
102
+ gei ㄍㄟ
103
+ gen ㄍㄣ
104
+ geng ㄍㄥ
105
+ gong ㄍㄨㄥ
106
+ gou ㄍㄡ
107
+ gu ㄍㄨ
108
+ gua ㄍㄨㄚ
109
+ guai ㄍㄨㄞ
110
+ guan ㄍㄨㄢ
111
+ guang ㄍㄨㄤ
112
+ gui ㄍㄨㄟ
113
+ gun ㄍㄨㄣ
114
+ guo ㄍㄨㄛ
115
+ ha ㄏㄚ
116
+ hai ㄏㄞ
117
+ han ㄏㄢ
118
+ hang ㄏㄤ
119
+ hao ㄏㄠ
120
+ he ㄏㄜ
121
+ hei ㄏㄟ
122
+ hen ㄏㄣ
123
+ heng ㄏㄥ
124
+ hm ㄏㄇ
125
+ hong ㄏㄨㄥ
126
+ hou ㄏㄡ
127
+ hu ㄏㄨ
128
+ hua ㄏㄨㄚ
129
+ huai ㄏㄨㄞ
130
+ huan ㄏㄨㄢ
131
+ huang ㄏㄨㄤ
132
+ hui ㄏㄨㄟ
133
+ hun ㄏㄨㄣ
134
+ huo ㄏㄨㄛ
135
+ ji ㄐㄧ
136
+ jia ㄐㄧㄚ
137
+ jian ㄐㄧㄢ
138
+ jiang ㄐㄧㄤ
139
+ jiao ㄐㄧㄠ
140
+ jie ㄐㄧㄝ
141
+ jin ㄐㄧㄣ
142
+ jing ㄐㄧㄥ
143
+ jiong ㄐㄩㄥ
144
+ jiu ㄐㄧㄡ
145
+ ju ㄐㄩ
146
+ jv ㄐㄩ
147
+ juan ㄐㄩㄢ
148
+ jvan ㄐㄩㄢ
149
+ jue ㄐㄩㄝ
150
+ jve ㄐㄩㄝ
151
+ jun ㄐㄩㄣ
152
+ ka ㄎㄚ
153
+ kai ㄎㄞ
154
+ kan ㄎㄢ
155
+ kang ㄎㄤ
156
+ kao ㄎㄠ
157
+ ke ㄎㄜ
158
+ kei ㄎㄟ
159
+ ken ㄎㄣ
160
+ keng ㄎㄥ
161
+ kong ㄎㄨㄥ
162
+ kou ㄎㄡ
163
+ ku ㄎㄨ
164
+ kua ㄎㄨㄚ
165
+ kuai ㄎㄨㄞ
166
+ kuan ㄎㄨㄢ
167
+ kuang ㄎㄨㄤ
168
+ kui ㄎㄨㄟ
169
+ kun ㄎㄨㄣ
170
+ kuo ㄎㄨㄛ
171
+ la ㄌㄚ
172
+ lai ㄌㄞ
173
+ lan ㄌㄢ
174
+ lang ㄌㄤ
175
+ lao ㄌㄠ
176
+ le ㄌㄜ
177
+ lei ㄌㄟ
178
+ leng ㄌㄥ
179
+ li ㄌㄧ
180
+ lia ㄌㄧㄚ
181
+ lian ㄌㄧㄢ
182
+ liang ㄌㄧㄤ
183
+ liao ㄌㄧㄠ
184
+ lie ㄌㄧㄝ
185
+ lin ㄌㄧㄣ
186
+ ling ㄌㄧㄥ
187
+ liu ㄌㄧㄡ
188
+ lo ㄌㄛ
189
+ long ㄌㄨㄥ
190
+ lou ㄌㄡ
191
+ lu ㄌㄨ
192
+ luan ㄌㄨㄢ
193
+ lue ㄌㄩㄝ
194
+ lun ㄌㄨㄣ
195
+ luo ㄌㄨㄛ
196
+ lv ㄌㄩ
197
+ lve ㄌㄩㄝ
198
+ m ㄇㄨ
199
+ ma ㄇㄚ
200
+ mai ㄇㄞ
201
+ man ㄇㄢ
202
+ mang ㄇㄤ
203
+ mao ㄇㄠ
204
+ me ㄇㄜ
205
+ mei ㄇㄟ
206
+ men ㄇㄣ
207
+ meng ㄇㄥ
208
+ mi ㄇㄧ
209
+ mian ㄇㄧㄢ
210
+ miao ㄇㄧㄠ
211
+ mie ㄇㄧㄝ
212
+ min ㄇㄧㄣ
213
+ ming ㄇㄧㄥ
214
+ miu ㄇㄧㄡ
215
+ mo ㄇㄛ
216
+ mou ㄇㄡ
217
+ mu ㄇㄨ
218
+ n ㄣ
219
+ na ㄋㄚ
220
+ nai ㄋㄞ
221
+ nan ㄋㄢ
222
+ nang ㄋㄤ
223
+ nao ㄋㄠ
224
+ ne ㄋㄜ
225
+ nei ㄋㄟ
226
+ nen ㄋㄣ
227
+ neng ㄋㄥ
228
+ ng ㄣ
229
+ ni ㄋㄧ
230
+ nian ㄋㄧㄢ
231
+ niang ㄋㄧㄤ
232
+ niao ㄋㄧㄠ
233
+ nie ㄋㄧㄝ
234
+ nin ㄋㄧㄣ
235
+ ning ㄋㄧㄥ
236
+ niu ㄋㄧㄡ
237
+ nong ㄋㄨㄥ
238
+ nou ㄋㄡ
239
+ nu ㄋㄨ
240
+ nuan ㄋㄨㄢ
241
+ nue ㄋㄩㄝ
242
+ nun ㄋㄨㄣ
243
+ nuo ㄋㄨㄛ
244
+ nv ㄋㄩ
245
+ nve ㄋㄩㄝ
246
+ o ㄛ
247
+ ou ㄡ
248
+ pa ㄆㄚ
249
+ pai ㄆㄞ
250
+ pan ㄆㄢ
251
+ pang ㄆㄤ
252
+ pao ㄆㄠ
253
+ pei ㄆㄟ
254
+ pen ㄆㄣ
255
+ peng ㄆㄥ
256
+ pi ㄆㄧ
257
+ pian ㄆㄧㄢ
258
+ piao ㄆㄧㄠ
259
+ pie ㄆㄧㄝ
260
+ pin ㄆㄧㄣ
261
+ ping ㄆㄧㄥ
262
+ po ㄆㄛ
263
+ pou ㄆㄡ
264
+ pu ㄆㄨ
265
+ qi ㄑㄧ
266
+ qia ㄑㄧㄚ
267
+ qian ㄑㄧㄢ
268
+ qiang ㄑㄧㄤ
269
+ qiao ㄑㄧㄠ
270
+ qie ㄑㄧㄝ
271
+ qin ㄑㄧㄣ
272
+ qing ㄑㄧㄥ
273
+ qiong ㄑㄩㄥ
274
+ qiu ㄑㄧㄡ
275
+ qu ㄑㄩ
276
+ quan ㄑㄩㄢ
277
+ qvan ㄑㄩㄢ
278
+ que ㄑㄩㄝ
279
+ qun ㄑㄩㄣ
280
+ ran ㄖㄢ
281
+ rang ㄖㄤ
282
+ rao ㄖㄠ
283
+ re ㄖㄜ
284
+ ren ㄖㄣ
285
+ reng ㄖㄥ
286
+ ri ㄖ
287
+ rong ㄖㄨㄥ
288
+ rou ㄖㄡ
289
+ ru ㄖㄨ
290
+ rua ㄖㄨㄚ
291
+ ruan ㄖㄨㄢ
292
+ rui ㄖㄨㄟ
293
+ run ㄖㄨㄣ
294
+ ruo ㄖㄨㄛ
295
+ sa ㄙㄚ
296
+ sai ㄙㄞ
297
+ san ㄙㄢ
298
+ sang ㄙㄤ
299
+ sao ㄙㄠ
300
+ se ㄙㄜ
301
+ sen ㄙㄣ
302
+ seng ㄙㄥ
303
+ sha ㄕㄚ
304
+ shai ㄕㄞ
305
+ shan ㄕㄢ
306
+ shang ㄕㄤ
307
+ shao ㄕㄠ
308
+ she ㄕㄜ
309
+ shei ㄕㄟ
310
+ shen ㄕㄣ
311
+ sheng ㄕㄥ
312
+ shi ㄕ
313
+ shou ㄕㄡ
314
+ shu ㄕㄨ
315
+ shua ㄕㄨㄚ
316
+ shuai ㄕㄨㄞ
317
+ shuan ㄕㄨㄢ
318
+ shuang ㄕㄨㄤ
319
+ shui ㄕㄨㄟ
320
+ shun ㄕㄨㄣ
321
+ shuo ㄕㄨㄛ
322
+ si ㄙ
323
+ song ㄙㄨㄥ
324
+ sou ㄙㄡ
325
+ su ㄙㄨ
326
+ suan ㄙㄨㄢ
327
+ sui ㄙㄨㄟ
328
+ sun ㄙㄨㄣ
329
+ suo ㄙㄨㄛ
330
+ ta ㄊㄚ
331
+ tai ㄊㄞ
332
+ tan ㄊㄢ
333
+ tang ㄊㄤ
334
+ tao ㄊㄠ
335
+ te ㄊㄜ
336
+ tei ㄊㄟ
337
+ teng ㄊㄥ
338
+ ti ㄊㄧ
339
+ tian ㄊㄧㄢ
340
+ tiao ㄊㄧㄠ
341
+ tie ㄊㄧㄝ
342
+ ting ㄊㄧㄥ
343
+ tong ㄊㄨㄥ
344
+ tou ㄊㄡ
345
+ tsuo ㄘㄨㄛ
346
+ tu ㄊㄨ
347
+ tuan ㄊㄨㄢ
348
+ tui ㄊㄨㄟ
349
+ tun ㄊㄨㄣ
350
+ tuo ㄊㄨㄛ
351
+ tzan ㄗㄢ
352
+ wa ㄨㄚ
353
+ wai ㄨㄞ
354
+ wan ㄨㄢ
355
+ wang ㄨㄤ
356
+ wei ㄨㄟ
357
+ wen ㄨㄣ
358
+ weng ㄨㄥ
359
+ wo ㄨㄛ
360
+ wong ㄨㄥ
361
+ wu ㄨ
362
+ xi ㄒㄧ
363
+ xia ㄒㄧㄚ
364
+ xian ㄒㄧㄢ
365
+ xiang ㄒㄧㄤ
366
+ xiao ㄒㄧㄠ
367
+ xie ㄒㄧㄝ
368
+ xin ㄒㄧㄣ
369
+ xing ㄒㄧㄥ
370
+ xiong ㄒㄩㄥ
371
+ xiu ㄒㄧㄡ
372
+ xu ㄒㄩ
373
+ xuan ㄒㄩㄢ
374
+ xue ㄒㄩㄝ
375
+ xun ㄒㄩㄣ
376
+ ya ㄧㄚ
377
+ yai ㄧㄞ
378
+ yan ㄧㄢ
379
+ yang ㄧㄤ
380
+ yao ㄧㄠ
381
+ ye ㄧㄝ
382
+ yi ㄧ
383
+ yin ㄧㄣ
384
+ ying ㄧㄥ
385
+ yo ㄧㄛ
386
+ yong ㄩㄥ
387
+ you ㄧㄡ
388
+ yu ㄩ
389
+ yuan ㄩㄢ
390
+ yue ㄩㄝ
391
+ yve ㄩㄝ
392
+ yun ㄩㄣ
393
+ za ㄗㄚ
394
+ zai ㄗㄞ
395
+ zan ㄗㄢ
396
+ zang ㄗㄤ
397
+ zao ㄗㄠ
398
+ ze ㄗㄜ
399
+ zei ㄗㄟ
400
+ zen ㄗㄣ
401
+ zeng ㄗㄥ
402
+ zha ㄓㄚ
403
+ zhai ㄓㄞ
404
+ zhan ㄓㄢ
405
+ zhang ㄓㄤ
406
+ zhao ㄓㄠ
407
+ zhe ㄓㄜ
408
+ zhei ㄓㄟ
409
+ zhen ㄓㄣ
410
+ zheng ㄓㄥ
411
+ zhi ㄓ
412
+ zhong ㄓㄨㄥ
413
+ zhou ㄓㄡ
414
+ zhu ㄓㄨ
415
+ zhua ㄓㄨㄚ
416
+ zhuai ㄓㄨㄞ
417
+ zhuan ㄓㄨㄢ
418
+ zhuang ㄓㄨㄤ
419
+ zhui ㄓㄨㄟ
420
+ zhun ㄓㄨㄣ
421
+ zhuo ㄓㄨㄛ
422
+ zi ㄗ
423
+ zong ㄗㄨㄥ
424
+ zou ㄗㄡ
425
+ zu ㄗㄨ
426
+ zuan ㄗㄨㄢ
427
+ zui ㄗㄨㄟ
428
+ zun ㄗㄨㄣ
429
+ zuo ㄗㄨㄛ
src/YingMusicSinger/utils/f5_tts/g2p/utils/front_utils.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ def generate_poly_lexicon(file_path: str):
8
+ """Generate poly char lexicon for Mandarin Chinese."""
9
+ poly_dict = {}
10
+
11
+ with open(file_path, "r", encoding="utf-8") as readf:
12
+ txt_list = readf.readlines()
13
+ for txt in txt_list:
14
+ word = txt.strip("\n")
15
+ if word not in poly_dict:
16
+ poly_dict[word] = 1
17
+ readf.close()
18
+ return poly_dict
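generate_poly_lexicon above collects one polyphonic character per line into a dict keyed by character, used as a membership set. A usage sketch, pointing at the polychar.txt file added in this commit:

poly_dict = generate_poly_lexicon(
    "src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polychar.txt"
)
print(len(poly_dict))     # 159 polyphonic characters
print("便" in poly_dict)  # True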
src/YingMusicSinger/utils/f5_tts/g2p/utils/g2p.py ADDED
@@ -0,0 +1,139 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ from typing import List, Union
9
+
10
+ from phonemizer.backend import EspeakBackend
11
+ from phonemizer.separator import Separator
12
+ from phonemizer.utils import list2str, str2list
13
+
14
+ # separator=Separator(phone=' ', word=' _ ', syllable='|'),
15
+ separator = Separator(word=" _ ", syllable="|", phone=" ")
16
+
17
+ phonemizer_zh = EspeakBackend(
18
+ "cmn", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
19
+ )
20
+ # phonemizer_zh.separator = separator
21
+
22
+ phonemizer_en = EspeakBackend(
23
+ "en-us",
24
+ preserve_punctuation=False,
25
+ with_stress=False,
26
+ language_switch="remove-flags",
27
+ )
28
+ # phonemizer_en.separator = separator
29
+
30
+ phonemizer_ja = EspeakBackend(
31
+ "ja", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
32
+ )
33
+ # phonemizer_ja.separator = separator
34
+
35
+ phonemizer_ko = EspeakBackend(
36
+ "ko", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
37
+ )
38
+ # phonemizer_ko.separator = separator
39
+
40
+ phonemizer_fr = EspeakBackend(
41
+ "fr-fr",
42
+ preserve_punctuation=False,
43
+ with_stress=False,
44
+ language_switch="remove-flags",
45
+ )
46
+ # phonemizer_fr.separator = separator
47
+
48
+ phonemizer_de = EspeakBackend(
49
+ "de", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
50
+ )
51
+ # phonemizer_de.separator = separator
52
+
53
+
54
+ lang2backend = {
55
+ "zh": phonemizer_zh,
56
+ "ja": phonemizer_ja,
57
+ "en": phonemizer_en,
58
+ "fr": phonemizer_fr,
59
+ "ko": phonemizer_ko,
60
+ "de": phonemizer_de,
61
+ }
62
+
63
+ with open("./src/YingMusicSinger/utils/f5_tts/g2p/utils/mls_en.json", "r") as f:
64
+ json_data = f.read()
65
+ token = json.loads(json_data)
66
+
67
+
68
+ def phonemizer_g2p(text, language):
69
+ langbackend = lang2backend[language]
70
+ phonemes = _phonemize(
71
+ langbackend,
72
+ text,
73
+ separator,
74
+ strip=True,
75
+ njobs=1,
76
+ prepend_text=False,
77
+ preserve_empty_lines=False,
78
+ )
79
+ token_id = []
80
+ if isinstance(phonemes, list):
81
+ for phone in phonemes:
82
+ phonemes_split = phone.split(" ")
83
+ token_id.append([token[p] for p in phonemes_split if p in token])
84
+ else:
85
+ phonemes_split = phonemes.split(" ")
86
+ token_id = [token[p] for p in phonemes_split if p in token]
87
+ return phonemes, token_id
88
+
89
+
90
+ def _phonemize( # pylint: disable=too-many-arguments
91
+ backend,
92
+ text: Union[str, List[str]],
93
+ separator: Separator,
94
+ strip: bool,
95
+ njobs: int,
96
+ prepend_text: bool,
97
+ preserve_empty_lines: bool,
98
+ ):
99
+ """Auxiliary function to phonemize()
100
+
101
+ Does the phonemization and returns the phonemized text. Raises a
102
+ RuntimeError on error.
103
+
104
+ """
105
+ # remember the text type for output (either list or string)
106
+ text_type = type(text)
107
+
108
+ # force the text as a list
109
+ text = [line.strip(os.linesep) for line in str2list(text)]
110
+
111
+ # if preserving empty lines, note the index of each empty line
112
+ if preserve_empty_lines:
113
+ empty_lines = [n for n, line in enumerate(text) if not line.strip()]
114
+
115
+ # ignore empty lines
116
+ text = [line for line in text if line.strip()]
117
+
118
+ if text:
119
+ # phonemize the text
120
+ phonemized = backend.phonemize(
121
+ text, separator=separator, strip=strip, njobs=njobs
122
+ )
123
+ else:
124
+ phonemized = []
125
+
126
+ # if preserving empty lines, reinsert them into text and phonemized lists
127
+ if preserve_empty_lines:
128
+ for i in empty_lines: # noqa
129
+ if prepend_text:
130
+ text.insert(i, "")
131
+ phonemized.insert(i, "")
132
+
133
+ # at that point, the phonemized text is a list of str. Format it as
134
+ # expected by the parameters
135
+ if prepend_text:
136
+ return list(zip(text, phonemized))
137
+ if text_type == str:
138
+ return list2str(phonemized)
139
+ return phonemized
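phonemizer_g2p above runs the espeak-ng backend selected by the language code and returns both the phoneme string (phones separated by spaces, words by " _ ") and the ids looked up in mls_en.json. A usage sketch, assuming the phonemizer package and espeak-ng are installed and the script runs from the repository root so the relative mls_en.json path resolves:

phones, ids = phonemizer_g2p("hello world", "en")
print(phones)  # e.g. "h ə l oʊ _ w ɜː l d" (exact phones depend on the espeak-ng version)
print(ids)     # token ids from mls_en.json for phones found in the vocab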
src/YingMusicSinger/utils/f5_tts/g2p/utils/log.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import functools
8
+ import logging
9
+
10
+ __all__ = [
11
+ "logger",
12
+ ]
13
+
14
+
15
+ class Logger(object):
16
+ def __init__(self, name: str = None):
17
+ name = "PaddleSpeech" if not name else name
18
+ self.logger = logging.getLogger(name)
19
+
20
+ log_config = {
21
+ "DEBUG": 10,
22
+ "INFO": 20,
23
+ "TRAIN": 21,
24
+ "EVAL": 22,
25
+ "WARNING": 30,
26
+ "ERROR": 40,
27
+ "CRITICAL": 50,
28
+ "EXCEPTION": 100,
29
+ }
30
+ for key, level in log_config.items():
31
+ logging.addLevelName(level, key)
32
+ if key == "EXCEPTION":
33
+ self.__dict__[key.lower()] = self.logger.exception
34
+ else:
35
+ self.__dict__[key.lower()] = functools.partial(self.__call__, level)
36
+
37
+ self.format = logging.Formatter(
38
+ fmt="[%(asctime)-15s] [%(levelname)8s] - %(message)s"
39
+ )
40
+
41
+ self.handler = logging.StreamHandler()
42
+ self.handler.setFormatter(self.format)
43
+
44
+ self.logger.addHandler(self.handler)
45
+ self.logger.setLevel(logging.INFO)
46
+ self.logger.propagate = False
47
+
48
+ def __call__(self, log_level: str, msg: str):
49
+ self.logger.log(log_level, msg)
50
+
51
+
52
+ logger = Logger()
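The Logger wrapper above registers extra levels (TRAIN=21, EVAL=22, and so on) and exposes each as a lowercase method through functools.partial, so the module-level logger can be used as below (the messages are illustrative only):

logger.info("loading g2p resources")
logger.train("step 100: loss 0.4219")   # custom TRAIN level registered at value 21
logger.warning("lexicon entry missing for character")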
src/YingMusicSinger/utils/f5_tts/g2p/utils/mls_en.json ADDED
@@ -0,0 +1,335 @@
1
+ {
2
+ "[UNK]": 0,
3
+ "_": 1,
4
+ "b": 2,
5
+ "d": 3,
6
+ "f": 4,
7
+ "h": 5,
8
+ "i": 6,
9
+ "j": 7,
10
+ "k": 8,
11
+ "l": 9,
12
+ "m": 10,
13
+ "n": 11,
14
+ "p": 12,
15
+ "r": 13,
16
+ "s": 14,
17
+ "t": 15,
18
+ "v": 16,
19
+ "w": 17,
20
+ "x": 18,
21
+ "z": 19,
22
+ "æ": 20,
23
+ "ç": 21,
24
+ "ð": 22,
25
+ "ŋ": 23,
26
+ "ɐ": 24,
27
+ "ɔ": 25,
28
+ "ə": 26,
29
+ "ɚ": 27,
30
+ "ɛ": 28,
31
+ "ɡ": 29,
32
+ "ɪ": 30,
33
+ "ɬ": 31,
34
+ "ɹ": 32,
35
+ "ɾ": 33,
36
+ "ʃ": 34,
37
+ "ʊ": 35,
38
+ "ʌ": 36,
39
+ "ʒ": 37,
40
+ "ʔ": 38,
41
+ "θ": 39,
42
+ "ᵻ": 40,
43
+ "aɪ": 41,
44
+ "aʊ": 42,
45
+ "dʒ": 43,
46
+ "eɪ": 44,
47
+ "iə": 45,
48
+ "iː": 46,
49
+ "n̩": 47,
50
+ "oʊ": 48,
51
+ "oː": 49,
52
+ "tʃ": 50,
53
+ "uː": 51,
54
+ "ææ": 52,
55
+ "ɐɐ": 53,
56
+ "ɑː": 54,
57
+ "ɑ̃": 55,
58
+ "ɔɪ": 56,
59
+ "ɔː": 57,
60
+ "ɔ̃": 58,
61
+ "əl": 59,
62
+ "ɛɹ": 60,
63
+ "ɜː": 61,
64
+ "ɡʲ": 62,
65
+ "ɪɹ": 63,
66
+ "ʊɹ": 64,
67
+ "aɪə": 65,
68
+ "aɪɚ": 66,
69
+ "iːː": 67,
70
+ "oːɹ": 68,
71
+ "ɑːɹ": 69,
72
+ "ɔːɹ": 70,
73
+
74
+ "1": 71,
75
+ "a": 72,
76
+ "e": 73,
77
+ "o": 74,
78
+ "q": 75,
79
+ "u": 76,
80
+ "y": 77,
81
+ "ɑ": 78,
82
+ "ɒ": 79,
83
+ "ɕ": 80,
84
+ "ɣ": 81,
85
+ "ɫ": 82,
86
+ "ɯ": 83,
87
+ "ʐ": 84,
88
+ "ʲ": 85,
89
+ "a1": 86,
90
+ "a2": 87,
91
+ "a5": 88,
92
+ "ai": 89,
93
+ "aɜ": 90,
94
+ "aː": 91,
95
+ "ei": 92,
96
+ "eə": 93,
97
+ "i.": 94,
98
+ "i1": 95,
99
+ "i2": 96,
100
+ "i5": 97,
101
+ "io": 98,
102
+ "iɑ": 99,
103
+ "iɛ": 100,
104
+ "iɜ": 101,
105
+ "i̪": 102,
106
+ "kh": 103,
107
+ "nʲ": 104,
108
+ "o1": 105,
109
+ "o2": 106,
110
+ "o5": 107,
111
+ "ou": 108,
112
+ "oɜ": 109,
113
+ "ph": 110,
114
+ "s.": 111,
115
+ "th": 112,
116
+ "ts": 113,
117
+ "tɕ": 114,
118
+ "u1": 115,
119
+ "u2": 116,
120
+ "u5": 117,
121
+ "ua": 118,
122
+ "uo": 119,
123
+ "uə": 120,
124
+ "uɜ": 121,
125
+ "y1": 122,
126
+ "y2": 123,
127
+ "y5": 124,
128
+ "yu": 125,
129
+ "yæ": 126,
130
+ "yə": 127,
131
+ "yɛ": 128,
132
+ "yɜ": 129,
133
+ "ŋɜ": 130,
134
+ "ŋʲ": 131,
135
+ "ɑ1": 132,
136
+ "ɑ2": 133,
137
+ "ɑ5": 134,
138
+ "ɑu": 135,
139
+ "ɑɜ": 136,
140
+ "ɑʲ": 137,
141
+ "ə1": 138,
142
+ "ə2": 139,
143
+ "ə5": 140,
144
+ "ər": 141,
145
+ "əɜ": 142,
146
+ "əʊ": 143,
147
+ "ʊə": 144,
148
+ "ai1": 145,
149
+ "ai2": 146,
150
+ "ai5": 147,
151
+ "aiɜ": 148,
152
+ "ei1": 149,
153
+ "ei2": 150,
154
+ "ei5": 151,
155
+ "eiɜ": 152,
156
+ "i.1": 153,
157
+ "i.2": 154,
158
+ "i.5": 155,
159
+ "i.ɜ": 156,
160
+ "io5": 157,
161
+ "iou": 158,
162
+ "iɑ1": 159,
163
+ "iɑ2": 160,
164
+ "iɑ5": 161,
165
+ "iɑɜ": 162,
166
+ "iɛ1": 163,
167
+ "iɛ2": 164,
168
+ "iɛ5": 165,
169
+ "iɛɜ": 166,
170
+ "i̪1": 167,
171
+ "i̪2": 168,
172
+ "i̪5": 169,
173
+ "i̪ɜ": 170,
174
+ "onɡ": 171,
175
+ "ou1": 172,
176
+ "ou2": 173,
177
+ "ou5": 174,
178
+ "ouɜ": 175,
179
+ "ts.": 176,
180
+ "tsh": 177,
181
+ "tɕh": 178,
182
+ "u5ʲ": 179,
183
+ "ua1": 180,
184
+ "ua2": 181,
185
+ "ua5": 182,
186
+ "uai": 183,
187
+ "uaɜ": 184,
188
+ "uei": 185,
189
+ "uo1": 186,
190
+ "uo2": 187,
191
+ "uo5": 188,
192
+ "uoɜ": 189,
193
+ "uə1": 190,
194
+ "uə2": 191,
195
+ "uə5": 192,
196
+ "uəɜ": 193,
197
+ "yiɜ": 194,
198
+ "yu2": 195,
199
+ "yu5": 196,
200
+ "yæ2": 197,
201
+ "yæ5": 198,
202
+ "yæɜ": 199,
203
+ "yə2": 200,
204
+ "yə5": 201,
205
+ "yəɜ": 202,
206
+ "yɛ1": 203,
207
+ "yɛ2": 204,
208
+ "yɛ5": 205,
209
+ "yɛɜ": 206,
210
+ "ɑu1": 207,
211
+ "ɑu2": 208,
212
+ "ɑu5": 209,
213
+ "ɑuɜ": 210,
214
+ "ər1": 211,
215
+ "ər2": 212,
216
+ "ər5": 213,
217
+ "ərɜ": 214,
218
+ "əː1": 215,
219
+ "iou1": 216,
220
+ "iou2": 217,
221
+ "iou5": 218,
222
+ "iouɜ": 219,
223
+ "onɡ1": 220,
224
+ "onɡ2": 221,
225
+ "onɡ5": 222,
226
+ "onɡɜ": 223,
227
+ "ts.h": 224,
228
+ "uai2": 225,
229
+ "uai5": 226,
230
+ "uaiɜ": 227,
231
+ "uei1": 228,
232
+ "uei2": 229,
233
+ "uei5": 230,
234
+ "ueiɜ": 231,
235
+ "uoɜʲ": 232,
236
+ "yɛ5ʲ": 233,
237
+ "ɑu2ʲ": 234,
238
+
239
+ "2": 235,
240
+ "5": 236,
241
+ "ɜ": 237,
242
+ "ʂ": 238,
243
+ "dʑ": 239,
244
+ "iɪ": 240,
245
+ "uɪ": 241,
246
+ "xʲ": 242,
247
+ "ɑt": 243,
248
+ "ɛɜ": 244,
249
+ "ɛː": 245,
250
+ "ɪː": 246,
251
+ "phʲ": 247,
252
+ "ɑ5ʲ": 248,
253
+ "ɑuʲ": 249,
254
+ "ərə": 250,
255
+ "uozʰ": 251,
256
+ "ər1ʲ": 252,
257
+ "tɕhtɕh": 253,
258
+
259
+ "c": 254,
260
+ "ʋ": 255,
261
+ "ʍ": 256,
262
+ "ʑ": 257,
263
+ "ː": 258,
264
+ "aə": 259,
265
+ "eː": 260,
266
+ "hʲ": 261,
267
+ "iʊ": 262,
268
+ "kʲ": 263,
269
+ "lʲ": 264,
270
+ "oə": 265,
271
+ "oɪ": 266,
272
+ "oʲ": 267,
273
+ "pʲ": 268,
274
+ "sʲ": 269,
275
+ "u4": 270,
276
+ "uʲ": 271,
277
+ "yi": 272,
278
+ "yʲ": 273,
279
+ "ŋ2": 274,
280
+ "ŋ5": 275,
281
+ "ŋ̩": 276,
282
+ "ɑɪ": 277,
283
+ "ɑʊ": 278,
284
+ "ɕʲ": 279,
285
+ "ət": 280,
286
+ "əə": 281,
287
+ "əɪ": 282,
288
+ "əʲ": 283,
289
+ "ɛ1": 284,
290
+ "ɛ5": 285,
291
+ "aiə": 286,
292
+ "aiɪ": 287,
293
+ "azʰ": 288,
294
+ "eiə": 289,
295
+ "eiɪ": 290,
296
+ "eiʊ": 291,
297
+ "i.ə": 292,
298
+ "i.ɪ": 293,
299
+ "i.ʊ": 294,
300
+ "ioɜ": 295,
301
+ "izʰ": 296,
302
+ "iɑə": 297,
303
+ "iɑʊ": 298,
304
+ "iɑʲ": 299,
305
+ "iɛə": 300,
306
+ "iɛɪ": 301,
307
+ "iɛʊ": 302,
308
+ "i̪ə": 303,
309
+ "i̪ʊ": 304,
310
+ "khʲ": 305,
311
+ "ouʲ": 306,
312
+ "tsʲ": 307,
313
+ "u2ʲ": 308,
314
+ "uoɪ": 309,
315
+ "uzʰ": 310,
316
+ "uɜʲ": 311,
317
+ "yæɪ": 312,
318
+ "yəʊ": 313,
319
+ "ərt": 314,
320
+ "ərɪ": 315,
321
+ "ərʲ": 316,
322
+ "əːt": 317,
323
+ "iouə": 318,
324
+ "iouʊ": 319,
325
+ "iouʲ": 320,
326
+ "iɛzʰ": 321,
327
+ "onɡə": 322,
328
+ "onɡɪ": 323,
329
+ "onɡʊ": 324,
330
+ "ouzʰ": 325,
331
+ "uai1": 326,
332
+ "ueiɪ": 327,
333
+ "ɑuzʰ": 328,
334
+ "iouzʰ": 329
335
+ }
src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/LangSegment.py ADDED
@@ -0,0 +1,1251 @@
1
+ """
2
+ This file bundles language identification functions.
3
+
4
+ Modifications (fork): Copyright (c) 2021, Adrien Barbaresi.
5
+
6
+ Original code: Copyright (c) 2011 Marco Lui <saffsd@gmail.com>.
7
+ Based on research by Marco Lui and Tim Baldwin.
8
+
9
+ See LICENSE file for more info.
10
+ https://github.com/adbar/py3langid
11
+
12
+ Projects:
13
+ https://github.com/juntaosun/LangSegment
14
+ """
15
+
16
+ import re
17
+ from collections import Counter, defaultdict
18
+
19
+ import numpy as np
20
+
21
+ # import langid
22
+ # import py3langid as langid
23
+ # pip install py3langid==0.2.2
24
+ # 启用语言预测概率归一化,概率预测的分数。因此,实现重新规范化 产生 0-1 范围内的输出。
25
+ # langid disables probability normalization by default. For command-line usages of , it can be enabled by passing the flag.
26
+ # For probability normalization in library use, the user must instantiate their own . An example of such usage is as follows:
27
+ from py3langid.langid import MODEL_FILE, LanguageIdentifier
28
+
29
+ langid = LanguageIdentifier.from_pickled_model(MODEL_FILE, norm_probs=True)
30
+
31
+ # Digital processing
32
+ try:
33
+ from src.YingMusicSinger.utils.f5_tts.thirdparty.LangSegment.utils.num import (
34
+ num2str,
35
+ )
36
+ except ImportError:
37
+ try:
38
+ from thirdparty.LangSegment.utils.num import num2str
39
+ except ImportError as e:
40
+ raise e
41
+
42
+ # -----------------------------------
43
+ # 更新日志:新版本分词更加精准。
44
+ # Changelog: The new version of the word segmentation is more accurate.
45
+ # チェンジログ:新しいバージョンの単語セグメンテーションはより正確です。
46
+ # Changelog: 분할이라는 단어의 새로운 버전이 더 정확합니다.
47
+ # -----------------------------------
48
+
49
+
50
+ # Word segmentation function:
51
+ # automatically identify and split the words (Chinese/English/Japanese/Korean) in the article or sentence according to different languages,
52
+ # making it more suitable for TTS processing.
53
+ # This code is designed for front-end text multi-lingual mixed annotation distinction, multi-language mixed training and inference of various TTS projects.
54
+ # This processing result is mainly for (Chinese = zh, Japanese = ja, English = en, Korean = ko), and can actually support up to 97 different language mixing processing.
55
+
56
+ # ===========================================================================================================
57
+ # 分かち書き機能:文章や文章の中の例えば(中国語/英語/日本語/韓国語)を、異なる言語で自動的に認識して分割し、TTS処理により適したものにします。
58
+ # このコードは、さまざまなTTSプロジェクトのフロントエンドテキストの多言語混合注釈区別、多言語混合トレーニング、および推論のために特別に作成されています。
59
+ # ===========================================================================================================
60
+ # (1)自動分詞:「韓国語では何を読むのですかあなたの体育の先生は誰ですか?今回の発表会では、iPhone 15シリーズの4機種が登場しました」
61
+ # (2)手动分词:“あなたの名前は<ja>佐々木ですか?<ja>ですか?”
62
+ # この処理結果は主に(中国語=ja、日本語=ja、英語=en、韓国語=ko)を対象としており、実際には最大97の異なる言語の混合処理をサポートできます。
63
+ # ===========================================================================================================
64
+
65
+ # ===========================================================================================================
66
+ # 단어 분할 기능: 기사 또는 문장에서 단어(중국어/영어/일본어/한국어)를 다른 언어에 따라 자동으로 식별하고 분할하여 TTS 처리에 더 적합합니다.
67
+ # 이 코드는 프런트 엔드 텍스트 다국어 혼합 주석 분화, 다국어 혼합 교육 및 다양한 TTS 프로젝트의 추론을 위해 설계되었습니다.
68
+ # ===========================================================================================================
69
+ # (1) 자동 단어 분할: "한국어로 무엇을 읽습니까? 스포츠 씨? 이 컨퍼런스는 4개의 iPhone 15 시리즈 모델을 제공합니다."
70
+ # (2) 수동 참여: "이름이 <ja>Saki입니까? <ja>?"
71
+ # 이 처리 결과는 주로 (중국어 = zh, 일본어 = ja, 영어 = en, 한국어 = ko)를 위한 것이며 실제로 혼합 처리를 위해 최대 97개의 언어를 지원합니다.
72
+ # ===========================================================================================================
73
+
74
+ # ===========================================================================================================
75
+ # 分词功能:将文章或句子里的例如(中/英/日/韩),按不同语言自动识别并拆分,让它更适合TTS处理。
76
+ # 本代码专为各种 TTS 项目的前端文本多语种混合标注区分,多语言混合训练和推理而编写。
77
+ # ===========================================================================================================
78
+ # (1)自动分词:“韩语中的오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型”
79
+ # (2)手动分词:“你的名字叫<ja>佐々木?<ja>吗?”
80
+ # 本处理结果主要针对(中文=zh , 日文=ja , 英文=en , 韩语=ko), 实际上可支持多达 97 种不同的语言混合处理。
81
+ # ===========================================================================================================
82
+
83
+
84
+ # 手动分词标签规范:<语言标签>文本内容</语言标签>
85
+ # 수동 단어 분할 태그 사양: <언어 태그> 텍스트 내용</언어 태그>
86
+ # Manual word segmentation tag specification: <language tags> text content </language tags>
87
+ # 手動分詞タグ仕様:<言語タグ>テキスト内容</言語タグ>
88
+ # ===========================================================================================================
89
+ # For manual word segmentation, labels need to appear in pairs, such as:
90
+ # 如需手动分词,标签需要成对出现,例如:“<ja>佐々木<ja>” 或者 “<ja>佐々木</ja>”
91
+ # 错误示范:“你的名字叫<ja>佐々木。” 此句子中出现的单个<ja>标签将被忽略,不会处理。
92
+ # Error demonstration: "Your name is <ja>佐々木。" Single <ja> tags that appear in this sentence will be ignored and will not be processed.
93
+ # ===========================================================================================================
94
+
95
+
96
+ # ===========================================================================================================
97
+ # 语音合成标记语言 SSML , 这里只支持它的标签(非 XML)Speech Synthesis Markup Language SSML, only its tags are supported here (not XML)
98
+ # 想支持更多的 SSML 标签?欢迎 PR! Want to support more SSML tags? PRs are welcome!
99
+ # 说明:除了中文以外,它也可改造成支持多语种 SSML ,不仅仅是中文。
100
+ # Note: In addition to Chinese, it can also be modified to support multi-language SSML, not just Chinese.
101
+ # ===========================================================================================================
102
+ # 中文实现:Chinese implementation:
103
+ # 【SSML】<number>=中文大写数字读法(单字)
104
+ # 【SSML】<telephone>=数字转成中文电话号码大写汉字(单字)
105
+ # 【SSML】<currency>=按金额发音。
106
+ # 【SSML】<date>=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。
107
+ # ===========================================================================================================
108
+ class LangSSML:
109
+ # 纯数字
110
+ _zh_numerals_number = {
111
+ "0": "零",
112
+ "1": "一",
113
+ "2": "二",
114
+ "3": "三",
115
+ "4": "四",
116
+ "5": "五",
117
+ "6": "六",
118
+ "7": "七",
119
+ "8": "八",
120
+ "9": "九",
121
+ }
122
+
123
+ # 将2024/8/24, 2024-08, 08-24, 24 标准化“年月日”
124
+ # Standardize 2024/8/24, 2024-08, 08-24, 24 to "year-month-day"
125
+ def _format_chinese_data(date_str: str):
126
+ # 处理日期格式
127
+ input_date = date_str
128
+ if date_str is None or date_str.strip() == "":
129
+ return ""
130
+ date_str = re.sub(r"[\/\._|年|月]", "-", date_str)
131
+ date_str = re.sub(r"日", r"", date_str)
132
+ date_arrs = date_str.split(" ")
133
+ if len(date_arrs) == 1 and ":" in date_arrs[0]:
134
+ time_str = date_arrs[0]
135
+ date_arrs = []
136
+ else:
137
+ time_str = date_arrs[1] if len(date_arrs) >= 2 else ""
138
+
139
+ def nonZero(num, cn, func=None):
140
+ if func is not None:
141
+ num = func(num)
142
+ return f"{num}{cn}" if num is not None and num != "" and num != "0" else ""
143
+
144
+ f_number = LangSSML.to_chinese_number
145
+ f_currency = LangSSML.to_chinese_currency
146
+ # year, month, day
147
+ year_month_day = ""
148
+ if len(date_arrs) > 0:
149
+ year, month, day = "", "", ""
150
+ parts = date_arrs[0].split("-")
151
+ if len(parts) == 3: # 格式为 YYYY-MM-DD
152
+ year, month, day = parts
153
+ elif len(parts) == 2: # 格式为 MM-DD 或 YYYY-MM
154
+ if len(parts[0]) == 4: # 年-月
155
+ year, month = parts
156
+ else:
157
+ month, day = parts # 月-日
158
+ elif len(parts[0]) > 0: # 仅有月-日或年
159
+ if len(parts[0]) == 4:
160
+ year = parts[0]
161
+ else:
162
+ day = parts[0]
163
+ year, month, day = (
164
+ nonZero(year, "年", f_number),
165
+ nonZero(month, "月", f_currency),
166
+ nonZero(day, "日", f_currency),
167
+ )
168
+ year_month_day = re.sub(r"([年|月|日])+", r"\1", f"{year}{month}{day}")
169
+ # hours, minutes, seconds
170
+ time_str = re.sub(r"[\/\.\-:_]", ":", time_str)
171
+ time_arrs = time_str.split(":")
172
+ hours, minutes, seconds = "", "", ""
173
+ if len(time_arrs) == 3: # H/M/S
174
+ hours, minutes, seconds = time_arrs
175
+ elif len(time_arrs) == 2: # H/M
176
+ hours, minutes = time_arrs
177
+ elif len(time_arrs[0]) > 0:
178
+ hours = f"{time_arrs[0]}点" # H
179
+ if len(time_arrs) > 1:
180
+ hours, minutes, seconds = (
181
+ nonZero(hours, "点", f_currency),
182
+ nonZero(minutes, "分", f_currency),
183
+ nonZero(seconds, "秒", f_currency),
184
+ )
185
+ hours_minutes_seconds = re.sub(
186
+ r"([点|分|秒])+", r"\1", f"{hours}{minutes}{seconds}"
187
+ )
188
+ output_date = f"{year_month_day}{hours_minutes_seconds}"
189
+ return output_date
190
+
191
+ # 【SSML】number=中文大写数字读法(单字)
192
+ # Chinese Numbers(single word)
193
+ def to_chinese_number(num: str):
194
+ pattern = r"(\d+)"
195
+ zh_numerals = LangSSML._zh_numerals_number
196
+ arrs = re.split(pattern, num)
197
+ output = ""
198
+ for item in arrs:
199
+ if re.match(pattern, item):
200
+ output += "".join(
201
+ zh_numerals[digit] if digit in zh_numerals else ""
202
+ for digit in str(item)
203
+ )
204
+ else:
205
+ output += item
206
+ output = output.replace(".", "点")
207
+ return output
208
+
209
+ # 【SSML】telephone=数字转成中文电话号码大写汉字(单字)
210
+ # Convert numbers to Chinese phone numbers in uppercase Chinese characters(single word)
211
+ def to_chinese_telephone(num: str):
212
+ output = LangSSML.to_chinese_number(num.replace("+86", "")) # zh +86
213
+ output = output.replace("一", "幺")
214
+ return output
215
+
216
+ # 【SSML】currency=按金额发音。
217
+ # Digital processing from GPT_SoVITS num.py (thanks)
218
+ def to_chinese_currency(num: str):
219
+ pattern = r"(\d+)"
220
+ arrs = re.split(pattern, num)
221
+ output = ""
222
+ for item in arrs:
223
+ if re.match(pattern, item):
224
+ output += num2str(item)
225
+ else:
226
+ output += item
227
+ output = output.replace(".", "点")
228
+ return output
229
+
230
+ # 【SSML】date=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。
231
+ def to_chinese_date(num: str):
232
+ chinese_date = LangSSML._format_chinese_data(num)
233
+ return chinese_date
234
+
235
+
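+ # Illustrative expectations for the converters above (a sketch, not part of
+ # the original file; exact readings depend on num2str from utils/num.py):
+ #   LangSSML.to_chinese_number("315")            -> "三一五"  (digit by digit)
+ #   LangSSML.to_chinese_telephone("13800138000") -> "幺三八零零幺三八零零零"
+ #   LangSSML.to_chinese_currency("315")          -> "三百一十五"
+ #   LangSSML.to_chinese_date("2024-08-24")       -> roughly "二零二四年八月二十四日"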
236
+ class LangSegment:
237
+ _text_cache = None
238
+ _text_lasts = None
239
+ _text_langs = None
240
+ _lang_count = None
241
+ _lang_eos = None
242
+
243
+ # 可自定义语言匹配标签:カスタマイズ可能な言語対応タグ:사용자 지정 가능한 언어 일치 태그:
244
+ # Customizable language matching tags: These are supported,이 표현들은 모두 지지합니다
245
+ # <zh>你好<zh> , <ja>佐々木</ja> , <en>OK<en> , <ko>오빠</ko> 这些写法均支持
246
+ SYMBOLS_PATTERN = r"(<([a-zA-Z|-]*)>(.*?)<\/*[a-zA-Z|-]*>)"
247
+
248
+ # 语言过滤组功能, 可以指定保留语言。不在过滤组中的语言将被清除。您可随心搭配TTS语音合成所支持的语言。
249
+ # 언어 필터 그룹 기능을 사용하면 예약된 언어를 지정할 수 있습니다. 필터 그룹에 없는 언어는 지워집니다. TTS 텍스트에서 지원하는 언어를 원하는 대로 일치시킬 수 있습니다.
250
+ # 言語フィルターグループ機能では、予約言語を指定できます。フィルターグループに含まれていない言語はクリアされます。TTS音声合成がサポートする言語を自由に組み合わせることができます。
251
+ # The language filter group function allows you to specify reserved languages.
252
+ # Languages not in the filter group will be cleared. You can match the languages supported by TTS Text To Speech as you like.
253
+ # 排名越前,优先级越高,The higher the ranking, the higher the priority,ランキングが上位になるほど、優先度が高くなります。
254
+
255
+ # 系统默认过滤器。System default filter。(ISO 639-1 codes given)
256
+ # ----------------------------------------------------------------------------------------------------------------------------------
257
+ # "zh"中文=Chinese ,"en"英语=English ,"ja"日语=Japanese ,"ko"韩语=Korean ,"fr"法语=French ,"vi"越南语=Vietnamese , "ru"俄语=Russian
258
+ # "th"泰语=Thai
259
+ # ----------------------------------------------------------------------------------------------------------------------------------
260
+ DEFAULT_FILTERS = ["zh", "ja", "ko", "en"]
261
+
262
+ # 用户可自定义过滤器。User-defined filters
263
+ Langfilters = DEFAULT_FILTERS[:] # 创建副本
264
+
265
+ # 合并文本
266
+ isLangMerge = True
267
+
268
+ # 试验性支持:您可自定义添加:"fr"法语 , "vi"越南语。Experimental: You can customize to add: "fr" French, "vi" Vietnamese.
269
+ # 请使用API启用:LangSegment.setfilters(["zh", "en", "ja", "ko", "fr", "vi" , "ru" , "th"]) # 您可自定义添加,如:"fr"法语 , "vi"越南语。
270
+
271
+ # 预览版功能,自动启用或禁用,无需设置
272
+ # Preview feature, automatically enabled or disabled, no settings required
273
+ EnablePreview = False
274
+
275
+ # 除此以外,它支持简写过滤器,只需按不同语种任意组合即可。
276
+ # In addition to that, it supports abbreviation filters, allowing for any combination of different languages.
277
+ # 示例:您可以任意指定多种组合,进行过滤
278
+ # Example: You can specify any combination to filter
279
+
280
+ # 中/日语言优先级阀值(评分范围为 0 ~ 1):评分低于设定阀值 <0.89 时,启用 filters 中的优先级。\n
281
+ # 중/일본어 우선 순위 임계값(점수 범위 0-1): 점수가 설정된 임계값 <0.89보다 낮을 때 필터에서 우선 순위를 활성화합니다.
282
+ # 中国語/日本語の優先度しきい値(スコア範囲0〜1):スコアが設定されたしきい値<0.89未満の場合、フィルターの優先度が有効になります。\n
283
+ # Chinese and Japanese language priority threshold (score range is 0 ~ 1): The default threshold is 0.89. \n
284
+ # Only the common characters between Chinese and Japanese are processed with confidence and priority. \n
285
+ LangPriorityThreshold = 0.89
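+ # Illustrative effect (a sketch, not part of the original file): with
+ # setfilters(["ja", "zh", "en", "ko"]), a short kanji-only fragment whose
+ # langid score falls below this threshold is resolved to "ja", because "ja"
+ # precedes "zh" in the filter order (see _parse_language below).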
286
+
287
+ # Langfilters = ["zh"] # 按中文识别
288
+ # Langfilters = ["en"] # 按英文识别
289
+ # Langfilters = ["ja"] # 按日文识别
290
+ # Langfilters = ["ko"] # 按韩文识别
291
+ # Langfilters = ["zh_ja"] # 中日混合识别
292
+ # Langfilters = ["zh_en"] # 中英混合识别
293
+ # Langfilters = ["ja_en"] # 日英混合识别
294
+ # Langfilters = ["zh_ko"] # 中韩混合识别
295
+ # Langfilters = ["ja_ko"] # 日韩混合识别
296
+ # Langfilters = ["en_ko"] # 英韩混合识别
297
+ # Langfilters = ["zh_ja_en"] # 中日英混合识别
298
+ # Langfilters = ["zh_ja_en_ko"] # 中日英韩混合识别
299
+
300
+ # 更多过滤组合,请您随意。。。For more filter combinations, please feel free to......
301
+ # より多くのフィルターの組み合わせ、お気軽に。。。더 많은 필터 조합을 원하시면 자유롭게 해주세요. .....
302
+
303
+ # 可选保留:支持中文数字拼音格式,更方便前端实现拼音音素修改和推理,默认关闭 False 。
304
+ # 开启后 True ,括号内的数字拼音格式均保留,并识别输出为:"zh"中文。
305
+ keepPinyin = False
306
+
307
+ # DEFINITION
308
+ PARSE_TAG = re.compile(r"(⑥\$*\d+[\d]{6,}⑥)")
309
+
310
+ @staticmethod
311
+ def _clears():
312
+ LangSegment._text_cache = None
313
+ LangSegment._text_lasts = None
314
+ LangSegment._text_langs = None
315
+ LangSegment._text_waits = None
316
+ LangSegment._lang_count = None
317
+ LangSegment._lang_eos = None
318
+ pass
319
+
320
+ @staticmethod
321
+ def _is_english_word(word):
322
+ return bool(re.match(r"^[a-zA-Z]+$", word))
323
+
324
+ @staticmethod
325
+ def _is_chinese(word):
326
+ for char in word:
327
+ if "\u4e00" <= char <= "\u9fff":
328
+ return True
329
+ return False
330
+
331
+ @staticmethod
332
+ def _is_japanese_kana(word):
333
+ pattern = re.compile(r"[\u3040-\u309F\u30A0-\u30FF]+")
334
+ matches = pattern.findall(word)
335
+ return len(matches) > 0
336
+
337
+ @staticmethod
338
+ def _insert_english_uppercase(word):
339
+ modified_text = re.sub(r"(?<!\b)([A-Z])", r" \1", word)
340
+ modified_text = modified_text.strip("-")
341
+ return modified_text + " "
342
+
343
+ @staticmethod
344
+ def _split_camel_case(word):
345
+ return re.sub(r"(?<!^)(?=[A-Z])", " ", word)
346
+
347
+ @staticmethod
348
+ def _statistics(language, text):
349
+ # Language word statistics:
350
+ # Chinese characters usually occupy double bytes
351
+ if LangSegment._lang_count is None or not isinstance(
352
+ LangSegment._lang_count, defaultdict
353
+ ):
354
+ LangSegment._lang_count = defaultdict(int)
355
+ lang_count = LangSegment._lang_count
356
+ if "|" not in language:
357
+ lang_count[language] += (
358
+ int(len(text) * 2) if language == "zh" else len(text)
359
+ )
360
+ LangSegment._lang_count = lang_count
361
+ pass
362
+
363
+ @staticmethod
364
+ def _clear_text_number(text):
365
+ if text == "\n":
366
+ return text, False # Keep Line Breaks
367
+ clear_text = re.sub(r"([^\w\s]+)", "", re.sub(r"\n+", "", text)).strip()
368
+ is_number = len(re.sub(re.compile(r"(\d+)"), "", clear_text)) == 0
369
+ return clear_text, is_number
370
+
371
+ @staticmethod
372
+ def _saveData(words, language: str, text: str, score: float, symbol=None):
373
+ # Pre-detection
374
+ clear_text, is_number = LangSegment._clear_text_number(text)
375
+ # Merge the same language and save the results
376
+ preData = words[-1] if len(words) > 0 else None
377
+ if symbol is not None:
378
+ pass
379
+ elif preData is not None and preData["symbol"] is None:
380
+ if len(clear_text) == 0:
381
+ language = preData["lang"]
382
+ elif is_number == True:
383
+ language = preData["lang"]
384
+ _, pre_is_number = LangSegment._clear_text_number(preData["text"])
385
+ if preData["lang"] == language:
386
+ LangSegment._statistics(preData["lang"], text)
387
+ text = preData["text"] + text
388
+ preData["text"] = text
389
+ return preData
390
+ elif pre_is_number == True:
391
+ text = f"{preData['text']}{text}"
392
+ words.pop()
393
+ elif is_number == True:
394
+ priority_language = LangSegment._get_filters_string()[:2]
395
+ if priority_language in "ja-zh-en-ko-fr-vi":
396
+ language = priority_language
397
+ data = {"lang": language, "text": text, "score": score, "symbol": symbol}
398
+ filters = LangSegment.Langfilters
399
+ if (
400
+ filters is None
401
+ or len(filters) == 0
402
+ or "?" in language
403
+ or language in filters
404
+ or language in filters[0]
405
+ or filters[0] == "*"
406
+ or filters[0] in "alls-mixs-autos"
407
+ ):
408
+ words.append(data)
409
+ LangSegment._statistics(data["lang"], data["text"])
410
+ return data
411
+
412
+ @staticmethod
413
+ def _addwords(words, language, text, score, symbol=None):
414
+ if text == "\n":
415
+ pass # Keep Line Breaks
416
+ elif text is None or len(text.strip()) == 0:
417
+ return True
418
+ if language is None:
419
+ language = ""
420
+ language = language.lower()
421
+ if language == "en":
422
+ text = LangSegment._insert_english_uppercase(text)
423
+ # text = re.sub(r'[(())]', ',' , text) # Keep it.
424
+ text_waits = LangSegment._text_waits
425
+ ispre_waits = len(text_waits) > 0
426
+ preResult = text_waits.pop() if ispre_waits else None
427
+ if preResult is None:
428
+ preResult = words[-1] if len(words) > 0 else None
429
+ if preResult and ("|" in preResult["lang"]):
430
+ pre_lang = preResult["lang"]
431
+ if language in pre_lang:
432
+ preResult["lang"] = language = language.split("|")[0]
433
+ else:
434
+ preResult["lang"] = pre_lang.split("|")[0]
435
+ if ispre_waits:
436
+ preResult = LangSegment._saveData(
437
+ words,
438
+ preResult["lang"],
439
+ preResult["text"],
440
+ preResult["score"],
441
+ preResult["symbol"],
442
+ )
443
+ pre_lang = preResult["lang"] if preResult else None
444
+ if ("|" in language) and (
445
+ pre_lang and pre_lang not in language and "…" not in language
446
+ ):
447
+ language = language.split("|")[0]
448
+ if "|" in language:
449
+ LangSegment._text_waits.append(
450
+ {"lang": language, "text": text, "score": score, "symbol": symbol}
451
+ )
452
+ else:
453
+ LangSegment._saveData(words, language, text, score, symbol)
454
+ return False
455
+
456
+ @staticmethod
457
+ def _get_prev_data(words):
458
+ data = words[-1] if words and len(words) > 0 else None
459
+ if data:
460
+ return (data["lang"], data["text"])
461
+ return (None, "")
462
+
463
+ @staticmethod
464
+ def _match_ending(input, index):
465
+ if input is None or len(input) == 0:
466
+ return False, None
467
+ input = re.sub(r"\s+", "", input)
468
+ if len(input) == 0 or abs(index) > len(input):
469
+ return False, None
470
+ ending_pattern = re.compile(r'([「」“”‘’"\'::。.!!?.?])')
471
+ return ending_pattern.match(input[index]), input[index]
472
+
473
+ @staticmethod
474
+ def _cleans_text(cleans_text):
475
+ cleans_text = re.sub(r"(.*?)([^\w]+)", r"\1 ", cleans_text)
476
+ cleans_text = re.sub(r"(.)\1+", r"\1", cleans_text)
477
+ return cleans_text.strip()
478
+
479
+ @staticmethod
480
+ def _mean_processing(text: str):
481
+ if text is None or (text.strip()) == "":
482
+ return None, 0.0
483
+ arrs = LangSegment._split_camel_case(text).split(" ")
484
+ langs = []
485
+ for t in arrs:
486
+ if len(t.strip()) <= 3:
487
+ continue
488
+ language, score = langid.classify(t)
489
+ langs.append({"lang": language})
490
+ if len(langs) == 0:
491
+ return None, 0.0
492
+ return Counter([item["lang"] for item in langs]).most_common(1)[0][0], 1.0
493
+
494
+ @staticmethod
495
+ def _lang_classify(cleans_text):
496
+ language, score = langid.classify(cleans_text)
497
+ # fix: Huggingface is np.float32
498
+ if (
499
+ score is not None
500
+ and isinstance(score, np.generic)
501
+ and hasattr(score, "item")
502
+ ):
503
+ score = score.item()
504
+ score = round(score, 3)
505
+ return language, score
506
+
507
+ @staticmethod
508
+ def _get_filters_string():
509
+ filters = LangSegment.Langfilters
510
+ return "-".join(filters).lower().strip() if filters is not None else ""
511
+
512
+ @staticmethod
513
+ def _parse_language(words, segment):
514
+ LANG_JA = "ja"
515
+ LANG_ZH = "zh"
516
+ LANG_ZH_JA = f"{LANG_ZH}|{LANG_JA}"
517
+ LANG_JA_ZH = f"{LANG_JA}|{LANG_ZH}"
518
+ language = LANG_ZH
519
+ regex_pattern = re.compile(r"([^\w\s]+)")
520
+ lines = regex_pattern.split(segment)
521
+ lines_max = len(lines)
522
+ LANG_EOS = LangSegment._lang_eos
523
+ for index, text in enumerate(lines):
524
+ if len(text) == 0:
525
+ continue
526
+ EOS = index >= (lines_max - 1)
527
+ nextId = index + 1
528
+ nextText = lines[nextId] if not EOS else ""
529
+ nextPunc = (
530
+ len(re.sub(regex_pattern, "", re.sub(r"\n+", "", nextText)).strip())
531
+ == 0
532
+ )
533
+ textPunc = (
534
+ len(re.sub(regex_pattern, "", re.sub(r"\n+", "", text)).strip()) == 0
535
+ )
536
+ if not EOS and (
537
+ textPunc == True or (len(nextText.strip()) >= 0 and nextPunc == True)
538
+ ):
539
+ lines[nextId] = f"{text}{nextText}"
540
+ continue
541
+ number_tags = re.compile(r"(⑥\d{6,}⑥)")
542
+ cleans_text = re.sub(number_tags, "", text)
543
+ cleans_text = re.sub(r"\d+", "", cleans_text)
544
+ cleans_text = LangSegment._cleans_text(cleans_text)
545
+ # fix:Langid's recognition of short sentences is inaccurate, and it is spliced longer.
546
+ if not EOS and len(cleans_text) <= 2:
547
+ lines[nextId] = f"{text}{nextText}"
548
+ continue
549
+ language, score = LangSegment._lang_classify(cleans_text)
550
+ prev_language, prev_text = LangSegment._get_prev_data(words)
551
+ if language != LANG_ZH and all(
552
+ "\u4e00" <= c <= "\u9fff" for c in re.sub(r"\s", "", cleans_text)
553
+ ):
554
+ language, score = LANG_ZH, 1
555
+ if len(cleans_text) <= 5 and LangSegment._is_chinese(cleans_text):
556
+ filters_string = LangSegment._get_filters_string()
557
+ if (
558
+ score < LangSegment.LangPriorityThreshold
559
+ and len(filters_string) > 0
560
+ ):
561
+ index_ja, index_zh = (
562
+ filters_string.find(LANG_JA),
563
+ filters_string.find(LANG_ZH),
564
+ )
565
+ if index_ja != -1 and index_ja < index_zh:
566
+ language = LANG_JA
567
+ elif index_zh != -1 and index_zh < index_ja:
568
+ language = LANG_ZH
569
+ if LangSegment._is_japanese_kana(cleans_text):
570
+ language = LANG_JA
571
+ elif len(cleans_text) > 2 and score > 0.90:
572
+ pass
573
+ elif EOS and LANG_EOS:
574
+ language = LANG_ZH if len(cleans_text) <= 1 else language
575
+ else:
576
+ LANG_UNKNOWN = (
577
+ LANG_ZH_JA
578
+ if language == LANG_ZH
579
+ or (len(cleans_text) <= 2 and prev_language == LANG_ZH)
580
+ else LANG_JA_ZH
581
+ )
582
+ match_end, match_char = LangSegment._match_ending(text, -1)
583
+ referen = (
584
+ prev_language in LANG_UNKNOWN or LANG_UNKNOWN in prev_language
585
+ if prev_language
586
+ else False
587
+ )
588
+ if match_char in "。.":
589
+ language = (
590
+ prev_language if referen and len(words) > 0 else language
591
+ )
592
+ else:
593
+ language = f"{LANG_UNKNOWN}|…"
594
+ text, *_ = re.subn(number_tags, LangSegment._restore_number, text)
595
+ LangSegment._addwords(words, language, text, score)
596
+ pass
597
+ pass
598
+
599
+ # ----------------------------------------------------------
600
+ # 【SSML】中文数字处理:Chinese Number Processing (SSML support)
601
+ # 这里默认都是中文,用于处理 SSML 中文标签。当然可以支持任意语言,例如:
602
+ # The default here is Chinese, which is used to process SSML Chinese tags. Of course, any language can be supported, for example:
603
+ # 中文电话号码:<telephone>1234567</telephone>
604
+ # 中文数字号码:<number>1234567</number>
605
+ @staticmethod
606
+ def _process_symbol_SSML(words, data):
607
+ tag, match = data
608
+ language = SSML = match[1]
609
+ text = match[2]
610
+ score = 1.0
611
+ if SSML == "telephone":
612
+ # 中文-电话号码
613
+ language = "zh"
614
+ text = LangSSML.to_chinese_telephone(text)
615
+ pass
616
+ elif SSML == "number":
617
+ # 中文-数字读法
618
+ language = "zh"
619
+ text = LangSSML.to_chinese_number(text)
620
+ pass
621
+ elif SSML == "currency":
622
+ # 中文-按金额发音
623
+ language = "zh"
624
+ text = LangSSML.to_chinese_currency(text)
625
+ pass
626
+ elif SSML == "date":
627
+ # 中文-按日期发音
628
+ language = "zh"
629
+ text = LangSSML.to_chinese_date(text)
630
+ pass
631
+ LangSegment._addwords(words, language, text, score, SSML)
632
+ pass
633
+
634
+ # ----------------------------------------------------------
635
+
636
+ @staticmethod
637
+ def _restore_number(matche):
638
+ value = matche.group(0)
639
+ text_cache = LangSegment._text_cache
640
+ if value in text_cache:
641
+ process, data = text_cache[value]
642
+ tag, match = data
643
+ value = match
644
+ return value
645
+
646
+ @staticmethod
647
+ def _pattern_symbols(item, text):
648
+ if text is None:
649
+ return text
650
+ tag, pattern, process = item
651
+ matches = pattern.findall(text)
652
+ if len(matches) == 1 and "".join(matches[0]) == text:
653
+ return text
654
+ for i, match in enumerate(matches):
655
+ key = f"⑥{tag}{i:06d}⑥"
656
+ text = re.sub(pattern, key, text, count=1)
657
+ LangSegment._text_cache[key] = (process, (tag, match))
658
+ return text
659
+
660
+ @staticmethod
661
+ def _process_symbol(words, data):
662
+ tag, match = data
663
+ language = match[1]
664
+ text = match[2]
665
+ score = 1.0
666
+ filters = LangSegment._get_filters_string()
667
+ if language not in filters:
668
+ LangSegment._process_symbol_SSML(words, data)
669
+ else:
670
+ LangSegment._addwords(words, language, text, score, True)
671
+ pass
672
+
673
+ @staticmethod
674
+ def _process_english(words, data):
675
+ tag, match = data
676
+ text = match[0]
677
+ filters = LangSegment._get_filters_string()
678
+ priority_language = filters[:2]
679
+ # Preview feature, other language segmentation processing
680
+ enablePreview = LangSegment.EnablePreview
681
+ if enablePreview == True:
682
+ # Experimental: Other language support
683
+ regex_pattern = re.compile(r"(.*?[。.??!!]+[\n]{,1})")
684
+ lines = regex_pattern.split(text)
685
+ for index, text in enumerate(lines):
686
+ if len(text.strip()) == 0:
687
+ continue
688
+ cleans_text = LangSegment._cleans_text(text)
689
+ language, score = LangSegment._lang_classify(cleans_text)
690
+ if language not in filters:
691
+ language, score = LangSegment._mean_processing(cleans_text)
692
+ if language is None or score <= 0.0:
693
+ continue
694
+ elif language in filters:
695
+ pass # pass
696
+ elif score >= 0.95:
697
+ continue # High score, but not in the filter, excluded.
698
+ elif score <= 0.15 and filters[:2] == "fr":
699
+ language = priority_language
700
+ else:
701
+ language = "en"
702
+ LangSegment._addwords(words, language, text, score)
703
+ else:
704
+ # Default is English
705
+ language, score = "en", 1.0
706
+ LangSegment._addwords(words, language, text, score)
707
+ pass
708
+
709
+ @staticmethod
710
+ def _process_Russian(words, data):
711
+ tag, match = data
712
+ text = match[0]
713
+ language = "ru"
714
+ score = 1.0
715
+ LangSegment._addwords(words, language, text, score)
716
+ pass
717
+
718
+ @staticmethod
719
+ def _process_Thai(words, data):
720
+ tag, match = data
721
+ text = match[0]
722
+ language = "th"
723
+ score = 1.0
724
+ LangSegment._addwords(words, language, text, score)
725
+ pass
726
+
727
+ @staticmethod
728
+ def _process_korean(words, data):
729
+ tag, match = data
730
+ text = match[0]
731
+ language = "ko"
732
+ score = 1.0
733
+ LangSegment._addwords(words, language, text, score)
734
+ pass
735
+
736
+ @staticmethod
737
+ def _process_quotes(words, data):
738
+ tag, match = data
739
+ text = "".join(match)
740
+ childs = LangSegment.PARSE_TAG.findall(text)
741
+ if len(childs) > 0:
742
+ LangSegment._process_tags(words, text, False)
743
+ else:
744
+ cleans_text = LangSegment._cleans_text(match[1])
745
+ if len(cleans_text) <= 5:
746
+ LangSegment._parse_language(words, text)
747
+ else:
748
+ language, score = LangSegment._lang_classify(cleans_text)
749
+ LangSegment._addwords(words, language, text, score)
750
+ pass
751
+
752
+ @staticmethod
753
+ def _process_pinyin(words, data):
754
+ tag, match = data
755
+ text = match
756
+ language = "zh"
757
+ score = 1.0
758
+ LangSegment._addwords(words, language, text, score)
759
+ pass
760
+
761
+ @staticmethod
762
+ def _process_number(words, data): # "$0" process only
763
+ """
764
+ Numbers alone cannot accurately identify language.
765
+ Because numbers are universal in all languages.
766
+ So it won't be executed here, just for testing.
767
+ """
768
+ tag, match = data
769
+ language = words[0]["lang"] if len(words) > 0 else "zh"
770
+ text = match
771
+ score = 0.0
772
+ LangSegment._addwords(words, language, text, score)
773
+ pass
774
+
775
+ @staticmethod
776
+ def _process_tags(words, text, root_tag):
777
+ text_cache = LangSegment._text_cache
778
+ segments = re.split(LangSegment.PARSE_TAG, text)
779
+ segments_len = len(segments) - 1
780
+ for index, text in enumerate(segments):
781
+ if root_tag:
782
+ LangSegment._lang_eos = index >= segments_len
783
+ if LangSegment.PARSE_TAG.match(text):
784
+ process, data = text_cache[text]
785
+ if process:
786
+ process(words, data)
787
+ else:
788
+ LangSegment._parse_language(words, text)
789
+ pass
790
+ return words
791
+
792
+ @staticmethod
793
+ def _merge_results(words):
794
+ new_word = []
795
+ for index, cur_data in enumerate(words):
796
+ if "symbol" in cur_data:
797
+ del cur_data["symbol"]
798
+ if index == 0:
799
+ new_word.append(cur_data)
800
+ else:
801
+ pre_data = new_word[-1]
802
+ if cur_data["lang"] == pre_data["lang"]:
803
+ pre_data["text"] = f"{pre_data['text']}{cur_data['text']}"
804
+ else:
805
+ new_word.append(cur_data)
806
+ return new_word
807
+
808
+ @staticmethod
809
+ def _parse_symbols(text):
810
+ TAG_NUM = "00" # "00" => default channels , "$0" => testing channel
811
+ TAG_S1, TAG_S2, TAG_P1, TAG_P2, TAG_EN, TAG_KO, TAG_RU, TAG_TH = (
812
+ "$1",
813
+ "$2",
814
+ "$3",
815
+ "$4",
816
+ "$5",
817
+ "$6",
818
+ "$7",
819
+ "$8",
820
+ )
821
+ TAG_BASE = re.compile(r'(([【《((“‘"\']*[LANGUAGE]+[\W\s]*)+)')
822
+ # Get custom language filter
823
+ filters = LangSegment.Langfilters
824
+ filters = filters if filters is not None else ""
825
+ # =======================================================================================================
826
+ # Experimental: Other language support.Thử nghiệm: Hỗ trợ ngôn ngữ khác.Expérimental : prise en charge d’autres langues.
827
+ # 相关语言字符如有缺失,熟悉相关语言的朋友,可以提交把缺失的发音符号补全。
828
+ # If relevant language characters are missing, friends who are familiar with the relevant languages can submit a submission to complete the missing pronunciation symbols.
829
+ # S'il manque des caractères linguistiques pertinents, les amis qui connaissent les langues concernées peuvent soumettre une soumission pour compléter les symboles de prononciation manquants.
830
+ # Nếu thiếu ký tự ngôn ngữ liên quan, những người bạn quen thuộc với ngôn ngữ liên quan có thể gửi bài để hoàn thành các ký hiệu phát âm còn thiếu.
831
+ # -------------------------------------------------------------------------------------------------------
832
+ # Preview feature, other language support
833
+ enablePreview = LangSegment.EnablePreview
834
+ if "fr" in filters or "vi" in filters:
835
+ enablePreview = True
836
+ LangSegment.EnablePreview = enablePreview
837
+ # 实验性:法语字符支持。Prise en charge des caractères français
838
+ RE_FR = "" if not enablePreview else "àáâãäåæçèéêëìíîïðñòóôõöùúûüýþÿ"
839
+ # 实验性:越南语字符支持。Hỗ trợ ký tự tiếng Việt
840
+ RE_VI = (
841
+ ""
842
+ if not enablePreview
843
+ else "đơưăáàảãạắằẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựôâêơưỷỹ"
844
+ )
845
+ # -------------------------------------------------------------------------------------------------------
846
+ # Basic options:
847
+ process_list = [
848
+ (
849
+ TAG_S1,
850
+ re.compile(LangSegment.SYMBOLS_PATTERN),
851
+ LangSegment._process_symbol,
852
+ ), # Symbol Tag
853
+ (
854
+ TAG_KO,
855
+ re.compile(re.sub(r"LANGUAGE", "\uac00-\ud7a3", TAG_BASE.pattern)),
856
+ LangSegment._process_korean,
857
+ ), # Korean words
858
+ (
859
+ TAG_TH,
860
+ re.compile(re.sub(r"LANGUAGE", "\u0e00-\u0e7f", TAG_BASE.pattern)),
861
+ LangSegment._process_Thai,
862
+ ), # Thai words support.
863
+ (
864
+ TAG_RU,
865
+ re.compile(re.sub(r"LANGUAGE", "А-Яа-яЁё", TAG_BASE.pattern)),
866
+ LangSegment._process_Russian,
867
+ ), # Russian words support.
868
+ (
869
+ TAG_NUM,
870
+ re.compile(r"(\W*\d+\W+\d*\W*\d*)"),
871
+ LangSegment._process_number,
872
+ ), # Number words, Universal in all languages, Ignore it.
873
+ (
874
+ TAG_EN,
875
+ re.compile(
876
+ re.sub(r"LANGUAGE", f"a-zA-Z{RE_FR}{RE_VI}", TAG_BASE.pattern)
877
+ ),
878
+ LangSegment._process_english,
879
+ ), # English words + Other language support.
880
+ (
881
+ TAG_P1,
882
+ re.compile(r'(["\'])(.*?)(\1)'),
883
+ LangSegment._process_quotes,
884
+ ), # Regular quotes
885
+ (
886
+ TAG_P2,
887
+ re.compile(
888
+ r"([\n]*[【《((“‘])([^【《((“‘’”))》】]{3,})([’”))》】][\W\s]*[\n]{,1})"
889
+ ),
890
+ LangSegment._process_quotes,
891
+ ), # Special quotes, There are left and right.
892
+ ]
893
+ # Extended options: Default False
894
+ if LangSegment.keepPinyin == True:
895
+ process_list.insert(
896
+ 1,
897
+ (
898
+ TAG_S2,
899
+ re.compile(r"([\(({](?:\s*\w*\d\w*\s*)+[})\)])"),
900
+ LangSegment._process_pinyin,
901
+ ), # Chinese Pinyin Tag.
902
+ )
903
+ # -------------------------------------------------------------------------------------------------------
904
+ words = []
905
+ lines = re.findall(r".*\n*", re.sub(LangSegment.PARSE_TAG, "", text))
906
+ for index, text in enumerate(lines):
907
+ if len(text.strip()) == 0:
908
+ continue
909
+ LangSegment._lang_eos = False
910
+ LangSegment._text_cache = {}
911
+ for item in process_list:
912
+ text = LangSegment._pattern_symbols(item, text)
913
+ cur_word = LangSegment._process_tags([], text, True)
914
+ if len(cur_word) == 0:
915
+ continue
916
+ cur_data = cur_word[0] if len(cur_word) > 0 else None
917
+ pre_data = words[-1] if len(words) > 0 else None
918
+ if (
919
+ cur_data
920
+ and pre_data
921
+ and cur_data["lang"] == pre_data["lang"]
922
+ and cur_data["symbol"] == False
923
+ and pre_data["symbol"]
924
+ ):
925
+ cur_data["text"] = f"{pre_data['text']}{cur_data['text']}"
926
+ words.pop()
927
+ words += cur_word
928
+ if LangSegment.isLangMerge == True:
929
+ words = LangSegment._merge_results(words)
930
+ lang_count = LangSegment._lang_count
931
+ if lang_count and len(lang_count) > 0:
932
+ lang_count = dict(
933
+ sorted(lang_count.items(), key=lambda x: x[1], reverse=True)
934
+ )
935
+ lang_count = list(lang_count.items())
936
+ LangSegment._lang_count = lang_count
937
+ return words
938
+
939
+ @staticmethod
940
+ def setfilters(filters):
941
+ # 当过滤器更改时,清除缓存
942
+ # 필터가 변경되면 캐시를 지웁니다.
943
+ # フィルタが変更されると、キャッシュがクリアされます
944
+ # When the filter changes, clear the cache
945
+ if LangSegment.Langfilters != filters:
946
+ LangSegment._clears()
947
+ LangSegment.Langfilters = filters
948
+ pass
949
+
950
+ @staticmethod
951
+ def getfilters():
952
+ return LangSegment.Langfilters
953
+
954
+ @staticmethod
955
+ def setPriorityThreshold(threshold: float):
956
+ LangSegment.LangPriorityThreshold = threshold
957
+ pass
958
+
959
+ @staticmethod
960
+ def getPriorityThreshold():
961
+ return LangSegment.LangPriorityThreshold
962
+
963
+ @staticmethod
964
+ def getCounts():
965
+ lang_count = LangSegment._lang_count
966
+ if lang_count is not None:
967
+ return lang_count
968
+ text_langs = LangSegment._text_langs
969
+ if text_langs is None or len(text_langs) == 0:
970
+ return [("zh", 0)]
971
+ lang_counts = defaultdict(int)
972
+ for d in text_langs:
973
+ lang_counts[d["lang"]] += (
974
+ int(len(d["text"]) * 2) if d["lang"] == "zh" else len(d["text"])
975
+ )
976
+ lang_counts = dict(
977
+ sorted(lang_counts.items(), key=lambda x: x[1], reverse=True)
978
+ )
979
+ lang_counts = list(lang_counts.items())
980
+ LangSegment._lang_count = lang_counts
981
+ return lang_counts
982
+
983
+ @staticmethod
984
+ def getTexts(text: str):
985
+ if text is None or len(text.strip()) == 0:
986
+ LangSegment._clears()
987
+ return []
988
+ # lasts
989
+ text_langs = LangSegment._text_langs
990
+ if LangSegment._text_lasts == text and text_langs is not None:
991
+ return text_langs
992
+ # parse
993
+ LangSegment._text_waits = []
994
+ LangSegment._lang_count = None
995
+ LangSegment._text_lasts = text
996
+ text = LangSegment._parse_symbols(text)
997
+ LangSegment._text_langs = text
998
+ return text
999
+
1000
+ @staticmethod
1001
+ def classify(text: str):
1002
+ return LangSegment.getTexts(text)
1003
+
1004
+
1005
+ def setLangMerge(value: bool):
1006
+ """是否优化合并结果"""
1007
+ LangSegment.isLangMerge = value
1008
+ pass
1009
+
1010
+
1011
+ def getLangMerge():
1012
+ """是否优化合并结果"""
1013
+ return LangSegment.isLangMerge
1014
+
1015
+
1016
+ def setfilters(filters):
1017
+ """
1018
+ 功能:语言过滤组功能, 可以指定保留语言。不在过滤组中的语言将被清除。您可随心搭配TTS语音合成所支持的语言。
1019
+ 기능: 언어 필터 그룹 기능, 예약된 언어를 지정할 수 있습니다. 필터 그룹에 없는 언어는 지워집니다. TTS 텍스트에서 지원하는 언어를 원하는 대로 일치시킬 수 있습니다.
1020
+ 機能:言語フィルターグループ機能で、予約言語を指定できます。フィルターグループに含まれていない言語はクリアされます。TTS音声合成がサポートする言語を自由に組み合わせることができます。
1021
+ Function: Language filter group function, you can specify reserved languages. \n
1022
+ Languages not in the filter group will be cleared. You can match the languages supported by TTS Text To Speech as you like.\n
1023
+ Args:
1024
+ filters (list): ["zh", "en", "ja", "ko"] 排名越前,优先级越高
1025
+ """
1026
+ LangSegment.setfilters(filters)
1027
+ pass
1028
+
1029
+
1030
+ def getfilters():
1031
+ """
1032
+ 功能:语言过滤组功能, 可以指定保留语言。不在过滤组中的语言将被清除。您可随心搭配TTS语音合成所支持的语言。
1033
+ 기능: 언어 필터 그룹 기능, 예약된 언어를 지정할 수 있습니다. 필터 그룹에 없는 언어는 지워집니다. TTS 텍스트에서 지원하는 언어를 원하는 대로 일치시킬 수 있습니다.
1034
+ 機能:言語フィルターグループ機能で、予約言語を指定できます。フィルターグループに含まれていない言語はクリアされます。TTS音声合成がサポートする言語を自由に組み合わせることができます。
1035
+ Function: Language filter group function, you can specify reserved languages. \n
1036
+ Languages not in the filter group will be cleared. You can match the languages supported by TTS Text To Speech as you like.\n
1037
+ Args:
1038
+ filters (list): ["zh", "en", "ja", "ko"] 排名越前,优先级越高
1039
+ """
1040
+ return LangSegment.getfilters()
1041
+
1042
+
1043
+ # # @Deprecated:Use shorter setfilters
1044
+ # def setLangfilters(filters):
1045
+ # """
1046
+ # >0.1.9废除:使用更简短的setfilters
1047
+ # """
1048
+ # setfilters(filters)
1049
+ # # @Deprecated:Use shorter getfilters
1050
+ # def getLangfilters():
1051
+ # """
1052
+ # >0.1.9废除:使用更简短的getfilters
1053
+ # """
1054
+ # return getfilters()
1055
+
1056
+
1057
+ def setKeepPinyin(value: bool):
1058
+ """
1059
+ 可选保留:支持中文数字拼音格式,更方便前端实现拼音音素修改和推理,默认关闭 False 。\n
1060
+ 开启后 True ,括号内的数字拼音格式均保留,并识别输出为:"zh"中文。
1061
+ """
1062
+ LangSegment.keepPinyin = value
1063
+ pass
1064
+
1065
+
1066
+ def getKeepPinyin():
1067
+ """
1068
+ 可选保留:支持中文数字拼音格式,更方便前端实现拼音音素修改和推理,默认关闭 False 。\n
1069
+ 开启后 True ,括号内的数字拼音格式均保留,并识别输出为:"zh"中文。
1070
+ """
1071
+ return LangSegment.keepPinyin
1072
+
1073
+
1074
+ def setEnablePreview(value: bool):
1075
+ """
1076
+ 启用预览版功能(默认关闭)
1077
+ Enable preview functionality (off by default)
1078
+ Args:
1079
+ value (bool): True=开启, False=关闭
1080
+ """
1081
+ LangSegment.EnablePreview = value == True
1082
+ pass
1083
+
1084
+
1085
+ def getEnablePreview():
1086
+ """
1087
+ 启用预览版功能(默认关闭)
1088
+ Enable preview functionality (off by default)
1089
+ Args:
1090
+ value (bool): True=开启, False=关闭
1091
+ """
1092
+ return LangSegment.EnablePreview == True
1093
+
1094
+
1095
+ def setPriorityThreshold(threshold: float):
1096
+ """
1097
+ 中/日语言优先级阀值(评分范围为 0 ~ 1):评分低于设定阀值 <0.89 时,启用 filters 中的优先级。\n
1098
+ 中国語/日本語の優先度しきい値(スコア範囲0〜1):スコアが設定されたしきい値<0.89未満の場合、フィルターの優先度が有効になります。\n
1099
+ 중/일본어 우선 순위 임계값(점수 범위 0-1): 점수가 설정된 임계값 <0.89보다 낮을 때 필터에서 우선 순위를 활성화합니다.
1100
+ Chinese and Japanese language priority threshold (score range is 0 ~ 1): The default threshold is 0.89. \n
1101
+ Only the common characters between Chinese and Japanese are processed with confidence and priority. \n
1102
+ Args:
1103
+ threshold:float (score range is 0 ~ 1)
1104
+ """
1105
+ LangSegment.setPriorityThreshold(threshold)
1106
+ pass
1107
+
1108
+
1109
+ def getPriorityThreshold():
1110
+ """
1111
+ 中/日语言优先级阀值(评分范围为 0 ~ 1):评分低于设定阀值 <0.89 时,启用 filters 中的优先级。\n
1112
+ 中国語/日本語の優先度しきい値(スコア範囲0〜1):スコアが設定されたしきい値<0.89未満の場合、フィルターの優先度が有効になります。\n
1113
+ 중/일본어 우선 순위 임계값(점수 범위 0-1): 점수가 설정된 임계값 <0.89보다 낮을 때 필터에서 우선 순위를 활성화합니다.
1114
+ Chinese and Japanese language priority threshold (score range is 0 ~ 1): The default threshold is 0.89. \n
1115
+ Only the common characters between Chinese and Japanese are processed with confidence and priority. \n
1116
+ Args:
1117
+ threshold:float (score range is 0 ~ 1)
1118
+ """
1119
+ return LangSegment.getPriorityThreshold()
1120
+
1121
+
1122
+ def getTexts(text: str):
1123
+ """
1124
+ 功能:对输入的文本进行多语种分词\n
1125
+ 기능: 입력 텍스트의 다국어 분할 \n
1126
+ 機能:入力されたテキストの多言語セグメンテーション\n
1127
+ Feature: Tokenizing multilingual text input.\n
1128
+ 参数-Args:
1129
+ text (str): Text content,文本内容\n
1130
+ 返回-Returns:
1131
+ list: 示例结果:[{'lang':'zh','text':'?'},...]\n
1132
+ lang=语种 , text=内容\n
1133
+ """
1134
+ return LangSegment.getTexts(text)
1135
+
1136
+
1137
+ def getCounts():
1138
+ """
1139
+ 功能:分词结果统计,按语种字数降序,用于确定其主要语言\n
1140
+ 기능: 주요 언어를 결정하는 데 사용되는 언어별 단어 수 내림차순으로 단어 분할 결과의 통계 \n
1141
+ 機能:主な言語を決定するために使用される、言語の単語数の降順による単語分割結果の統計\n
1142
+ Function: Tokenizing multilingual text input.\n
1143
+ 返回-Returns:
1144
+ list: 示例结果:[('zh', 5), ('ja', 2), ('en', 1)] = [(语种,字数含标点)]\n
1145
+ """
1146
+ return LangSegment.getCounts()
1147
+
1148
+
1149
+ def classify(text: str):
1150
+ """
1151
+ 功能:兼容接口实现
1152
+ Function: Compatible interface implementation
1153
+ """
1154
+ return LangSegment.classify(text)
1155
+
1156
+
1157
+ def printList(langlist):
1158
+ """
1159
+ 功能:打印数组结果
1160
+ 기능: 어레이 결과 인쇄
1161
+ 機能:配列結果を印刷
1162
+ Function: Print array results
1163
+ """
1164
+ print("\n===================【打印结果】===================")
1165
+ if langlist is None or len(langlist) == 0:
1166
+ print("无内容结果,No content result")
1167
+ return
1168
+ for line in langlist:
1169
+ print(line)
1170
+ pass
1171
+
1172
+
1173
+ def main():
1174
+ # -----------------------------------
1175
+ # 更新日志:新版本分词更加精准。
1176
+ # Changelog: The new version of the word segmentation is more accurate.
1177
+ # チェンジログ:新しいバージョンの単語セグメンテーションはより正確です。
1178
+ # Changelog: 분할이라는 단어의 새로운 버전이 더 정확합니다.
1179
+ # -----------------------------------
1180
+
1181
+ # 输入示例1:(包含日文,中文)Input Example 1: (including Japanese, Chinese)
1182
+ # text = "“昨日は雨が降った,音楽、映画。。。”你今天学习日语了吗?春は桜の季節です。语种分词是语音合成必不可少的环节。言語分詞は音声合成に欠かせない環節である!"
1183
+
1184
+ # 输入示例2:(包含日文,中文)Input Example 1: (including Japanese, Chinese)
1185
+ # text = "欢迎来玩。東京,は日本の首都です。欢迎来玩. 太好了!"
1186
+
1187
+ # 输入示例3:(包含日文,中文)Input Example 1: (including Japanese, Chinese)
1188
+ # text = "明日、私たちは海辺にバカンスに行きます。你会说日语吗:“中国語、話せますか” 你的日语真好啊!"
1189
+
1190
+ # 输入示例4:(包含日文,中文,韩语,英文)Input Example 4: (including Japanese, Chinese, Korean, English)
1191
+ # text = "你的名字叫<ja>佐々木?<ja>吗?韩语中的안녕 오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型和三款Apple Watch等一系列新品,这次的iPad Air采用了LCD屏幕"
1192
+
1193
+ # 试验性支持:"fr"法语 , "vi"越南语 , "ru"俄语 , "th"泰语。Experimental: Other language support.
1194
+ LangSegment.setfilters(["fr", "vi", "ja", "zh", "ko", "en", "ru", "th"])
1195
+ text = """
1196
+ 我喜欢在雨天里听音乐。
1197
+ I enjoy listening to music on rainy days.
1198
+ 雨の日に音楽を聴くのが好きです。
1199
+ 비 오는 날에 음악을 듣는 것을 즐깁니다。
1200
+ J'aime écouter de la musique les jours de pluie.
1201
+ Tôi thích nghe nhạc vào những ngày mưa.
1202
+ Мне нравится слушать музыку в дождливую погоду.
1203
+ ฉันชอบฟังเพ���งในวันที่ฝนตก
1204
+ """
1205
+
1206
+ # 进行分词:(接入TTS项目仅需一行代码调用)Segmentation: (Only one line of code is required to access the TTS project)
1207
+ langlist = LangSegment.getTexts(text)
1208
+ printList(langlist)
1209
+
1210
+ # 语种统计:Language statistics:
1211
+ print("\n===================【语种统计】===================")
1212
+ # 获取所有语种数组结果,根据内容字数降序排列
1213
+ # Get the array results in all languages, sorted in descending order according to the number of content words
1214
+ langCounts = LangSegment.getCounts()
1215
+ print(langCounts, "\n")
1216
+
1217
+ # 根据结果获取内容的主要语种 (语言,字数含标点)
1218
+ # Get the main language of content based on the results (language, word count including punctuation)
1219
+ lang, count = langCounts[0]
1220
+ print(f"输入内容的主要语言为 = {lang} ,字数 = {count}")
1221
+ print("==================================================\n")
1222
+
1223
+ # 分词输出:lang=语言,text=内容。Word output: lang = language, text = content
1224
+ # ===================【打印结果】===================
1225
+ # {'lang': 'zh', 'text': '你的名字叫'}
1226
+ # {'lang': 'ja', 'text': '佐々木?'}
1227
+ # {'lang': 'zh', 'text': '吗?韩语中的'}
1228
+ # {'lang': 'ko', 'text': '안녕 오빠'}
1229
+ # {'lang': 'zh', 'text': '读什么呢?'}
1230
+ # {'lang': 'ja', 'text': 'あなたの体育の先生は誰ですか?'}
1231
+ # {'lang': 'zh', 'text': ' 此次发布会带来了四款'}
1232
+ # {'lang': 'en', 'text': 'i Phone '}
1233
+ # {'lang': 'zh', 'text': '15系列机型和三款'}
1234
+ # {'lang': 'en', 'text': 'Apple Watch '}
1235
+ # {'lang': 'zh', 'text': '等一系列新品,这次的'}
1236
+ # {'lang': 'en', 'text': 'i Pad Air '}
1237
+ # {'lang': 'zh', 'text': '采用了'}
1238
+ # {'lang': 'en', 'text': 'L C D '}
1239
+ # {'lang': 'zh', 'text': '屏幕'}
1240
+ # ===================【语种统计】===================
1241
+
1242
+ # ===================【语种统计】===================
1243
+ # [('zh', 51), ('ja', 19), ('en', 18), ('ko', 5)]
1244
+
1245
+ # 输入内容的主要语言为 = zh ,字数 = 51
1246
+ # ==================================================
1247
+ # The main language of the input content is = zh, word count = 51
1248
+
1249
+
1250
+ if __name__ == "__main__":
1251
+ main()
src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/__init__.py ADDED
@@ -0,0 +1,24 @@
+ from .LangSegment import (
+     LangSegment,
+     classify,
+     getCounts,
+     getEnablePreview,
+     getfilters,
+     getKeepPinyin,
+     getLangMerge,
+     getPriorityThreshold,
+     getTexts,
+     printList,
+     setEnablePreview,
+     setfilters,
+     setKeepPinyin,
+     setLangMerge,
+     setPriorityThreshold,
+ )
+ 
+ # release
+ __version__ = "0.3.5"
+ 
+ 
+ # develop
+ __develop__ = "dev-0.0.1"
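+ 
+ # Illustrative usage of the re-exported API (a sketch, not part of the
+ # original file; the import path is abbreviated here):
+ #   from ...thirdparty.LangSegment import setfilters, getTexts
+ #   setfilters(["zh", "en"])
+ #   for seg in getTexts("hello 世界"):
+ #       print(seg["lang"], seg["text"])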
src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/utils/__init__.py ADDED
File without changes
src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/utils/num.py ADDED
@@ -0,0 +1,332 @@
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Digital processing from GPT_SoVITS num.py (thanks)
15
+ """
16
+ Rules to verbalize numbers into Chinese characters.
17
+ https://zh.wikipedia.org/wiki/中文数字#現代中文
18
+ """
19
+
20
+ import re
21
+ from collections import OrderedDict
22
+ from typing import List
23
+
24
+ DIGITS = {str(i): tran for i, tran in enumerate("零一二三四五六七八九")}
25
+ UNITS = OrderedDict(
26
+ {
27
+ 1: "十",
28
+ 2: "百",
29
+ 3: "千",
30
+ 4: "万",
31
+ 8: "亿",
32
+ }
33
+ )
34
+
35
+ COM_QUANTIFIERS = "(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)"
36
+
37
+ # 分数表达式
38
+ RE_FRAC = re.compile(r"(-?)(\d+)/(\d+)")
39
+
40
+
41
+ def replace_frac(match) -> str:
42
+ """
43
+ Args:
44
+ match (re.Match)
45
+ Returns:
46
+ str
47
+ """
48
+ sign = match.group(1)
49
+ nominator = match.group(2)
50
+ denominator = match.group(3)
51
+ sign: str = "负" if sign else ""
52
+ nominator: str = num2str(nominator)
53
+ denominator: str = num2str(denominator)
54
+ result = f"{sign}{denominator}分之{nominator}"
55
+ return result
56
+
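+ # Illustrative expectation (a sketch, not part of the original file):
+ #   RE_FRAC.sub(replace_frac, "大约3/4的样本") -> "大约四分之三的样本"
+ # replace_frac relies on num2str, which is defined further down in this module.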
57
+
58
+ # 百分数表达式
59
+ RE_PERCENTAGE = re.compile(r"(-?)(\d+(\.\d+)?)%")
60
+
61
+
62
+ def replace_percentage(match) -> str:
63
+ """
64
+ Args:
65
+ match (re.Match)
66
+ Returns:
67
+ str
68
+ """
69
+ sign = match.group(1)
70
+ percent = match.group(2)
71
+ sign: str = "负" if sign else ""
72
+ percent: str = num2str(percent)
73
+ result = f"{sign}百分之{percent}"
74
+ return result
75
+
76
+
77
+ # 整数表达式
78
+ # 带负号的整数 -10
79
+ RE_INTEGER = re.compile(r"(-)" r"(\d+)")
80
+
81
+
82
+ def replace_negative_num(match) -> str:
83
+ """
84
+ Args:
85
+ match (re.Match)
86
+ Returns:
87
+ str
88
+ """
89
+ sign = match.group(1)
90
+ number = match.group(2)
91
+ sign: str = "负" if sign else ""
92
+ number: str = num2str(number)
93
+ result = f"{sign}{number}"
94
+ return result
95
+
96
+
97
+ # 编号-无符号整形
98
+ # 00078
99
+ RE_DEFAULT_NUM = re.compile(r"\d{3}\d*")
100
+
101
+
102
+ def replace_default_num(match):
103
+ """
104
+ Args:
105
+ match (re.Match)
106
+ Returns:
107
+ str
108
+ """
109
+ number = match.group(0)
110
+ return verbalize_digit(number, alt_one=True)
111
+
112
+
113
+ # 加减乘除
114
+ # RE_ASMD = re.compile(
115
+ # r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
116
+ RE_ASMD = re.compile(
117
+ r"((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))"
118
+ )
119
+
120
+ asmd_map = {"+": "加", "-": "减", "×": "乘", "÷": "除", "=": "等于"}
121
+
122
+
123
+ def replace_asmd(match) -> str:
124
+ """
125
+ Args:
126
+ match (re.Match)
127
+ Returns:
128
+ str
129
+ """
130
+ result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
131
+ return result
132
+
133
+
134
+ # 次方专项
135
+ RE_POWER = re.compile(r"[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+")
136
+
137
+ power_map = {
138
+ "⁰": "0",
139
+ "¹": "1",
140
+ "²": "2",
141
+ "³": "3",
142
+ "⁴": "4",
143
+ "⁵": "5",
144
+ "⁶": "6",
145
+ "⁷": "7",
146
+ "⁸": "8",
147
+ "⁹": "9",
148
+ "ˣ": "x",
149
+ "ʸ": "y",
150
+ "ⁿ": "n",
151
+ }
152
+
153
+
154
+ def replace_power(match) -> str:
155
+ """
156
+ Args:
157
+ match (re.Match)
158
+ Returns:
159
+ str
160
+ """
161
+ power_num = ""
162
+ for m in match.group(0):
163
+ power_num += power_map[m]
164
+ result = "的" + power_num + "次方"
165
+ return result
166
+
167
+
168
+ # 数字表达式
169
+ # 纯小数
170
+ RE_DECIMAL_NUM = re.compile(r"(-?)((\d+)(\.\d+))" r"|(\.(\d+))")
171
+ # 正整数 + 量词
172
+ RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
173
+ RE_NUMBER = re.compile(r"(-?)((\d+)(\.\d+)?)" r"|(\.(\d+))")
174
+
175
+
176
+ def replace_positive_quantifier(match) -> str:
177
+ """
178
+ Args:
179
+ match (re.Match)
180
+ Returns:
181
+ str
182
+ """
183
+ number = match.group(1)
184
+ match_2 = match.group(2)
185
+ if match_2 == "+":
186
+ match_2 = "多"
187
+ match_2: str = match_2 if match_2 else ""
188
+ quantifiers: str = match.group(3)
189
+ number: str = num2str(number)
190
+ result = f"{number}{match_2}{quantifiers}"
191
+ return result
192
+
193
+
194
+ def replace_number(match) -> str:
195
+ """
196
+ Args:
197
+ match (re.Match)
198
+ Returns:
199
+ str
200
+ """
201
+ sign = match.group(1)
202
+ number = match.group(2)
203
+ pure_decimal = match.group(5)
204
+ if pure_decimal:
205
+ result = num2str(pure_decimal)
206
+ else:
207
+ sign: str = "负" if sign else ""
208
+ number: str = num2str(number)
209
+ result = f"{sign}{number}"
210
+ return result
211
+
212
+
213
+ # 范围表达式
214
+ # match.group(1) and match.group(8) are copy from RE_NUMBER
215
+
216
+ RE_RANGE = re.compile(
217
+ r"""
218
+ (?<![\d\+\-\×÷=]) # 使用反向前瞻以确保数字范围之前没有其他数字和操作符
219
+ ((-?)((\d+)(\.\d+)?)) # 匹配范围起始的负数或正数(整数或小数)
220
+ [-~] # 匹配范围分隔符
221
+ ((-?)((\d+)(\.\d+)?)) # 匹配范围结束的负数或正数(整数或小数)
222
+ (?![\d\+\-\×÷=]) # 使用正向前瞻以确保数字范围之后没有其他数字和操作符
223
+ """,
224
+ re.VERBOSE,
225
+ )
226
+
227
+
228
+ def replace_range(match) -> str:
229
+ """
230
+ Args:
231
+ match (re.Match)
232
+ Returns:
233
+ str
234
+ """
235
+ first, second = match.group(1), match.group(6)
236
+ first = RE_NUMBER.sub(replace_number, first)
237
+ second = RE_NUMBER.sub(replace_number, second)
238
+ result = f"{first}到{second}"
239
+ return result
240
+
241
+
242
+ # ~至表达式
243
+ RE_TO_RANGE = re.compile(
244
+ r"((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)"
245
+ )
246
+
247
+
248
+ def replace_to_range(match) -> str:
249
+ """
250
+ Args:
251
+ match (re.Match)
252
+ Returns:
253
+ str
254
+ """
255
+ result = match.group(0).replace("~", "至")
256
+ return result
257
+
258
+
259
+ def _get_value(value_string: str, use_zero: bool = True) -> List[str]:
260
+ stripped = value_string.lstrip("0")
261
+ if len(stripped) == 0:
262
+ return []
263
+ elif len(stripped) == 1:
264
+ if use_zero and len(stripped) < len(value_string):
265
+ return [DIGITS["0"], DIGITS[stripped]]
266
+ else:
267
+ return [DIGITS[stripped]]
268
+ else:
269
+ largest_unit = next(
270
+ power for power in reversed(UNITS.keys()) if power < len(stripped)
271
+ )
272
+ first_part = value_string[:-largest_unit]
273
+ second_part = value_string[-largest_unit:]
274
+ return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(second_part)
275
+
276
+
277
+ def verbalize_cardinal(value_string: str) -> str:
278
+ if not value_string:
279
+ return ""
280
+
281
+ # 000 -> '零' , 0 -> '零'
282
+ value_string = value_string.lstrip("0")
283
+ if len(value_string) == 0:
284
+ return DIGITS["0"]
285
+
286
+ result_symbols = _get_value(value_string)
287
+ # verbalized number starting with '一十*' is abbreviated as `十*`
288
+ if (
289
+ len(result_symbols) >= 2
290
+ and result_symbols[0] == DIGITS["1"]
291
+ and result_symbols[1] == UNITS[1]
292
+ ):
293
+ result_symbols = result_symbols[1:]
294
+ return "".join(result_symbols)
295
+
296
+
297
+ def verbalize_digit(value_string: str, alt_one=False) -> str:
298
+ result_symbols = [DIGITS[digit] for digit in value_string]
299
+ result = "".join(result_symbols)
300
+ if alt_one:
301
+ result = result.replace("一", "幺")
302
+ return result
303
+
304
+
305
+ def num2str(value_string: str) -> str:
306
+ integer_decimal = value_string.split(".")
307
+ if len(integer_decimal) == 1:
308
+ integer = integer_decimal[0]
309
+ decimal = ""
310
+ elif len(integer_decimal) == 2:
311
+ integer, decimal = integer_decimal
312
+ else:
313
+ raise ValueError(
314
+ f"The value string: '${value_string}' has more than one point in it."
315
+ )
316
+
317
+ result = verbalize_cardinal(integer)
318
+
319
+ decimal = decimal.rstrip("0")
320
+ if decimal:
321
+ # '.22' is verbalized as '零点二二'
322
+ # '3.20' is verbalized as '三点二'
323
+ result = result if result else "零"
324
+ result += "点" + verbalize_digit(decimal)
325
+ return result
326
+
327
+
328
+ if __name__ == "__main__":
329
+ text = ""
330
+ text = num2str(text)
331
+ print(text)
332
+ pass
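A small usage sketch of the verbalization helpers defined in this file, assuming the module is importable as num; the expected outputs are shown as comments.

from num import RE_FRAC, RE_PERCENTAGE, num2str, replace_frac, replace_percentage

print(num2str("1234.50"))                                   # 一千二百三十四点五
print(RE_PERCENTAGE.sub(replace_percentage, "涨幅-3.5%"))   # 涨幅负百分之三点五
print(RE_FRAC.sub(replace_frac, "1/3"))                     # 三分之一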
src/YingMusicSinger/utils/stable_audio_tools/__init__.py ADDED
File without changes
src/YingMusicSinger/utils/stable_audio_tools/adp.py ADDED
@@ -0,0 +1,1686 @@
1
+ # Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
2
+ # License can be found in LICENSES/LICENSE_ADP.txt
3
+
4
+ import math
5
+ from inspect import isfunction
6
+ from math import ceil, floor, log, log2, pi
7
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from dac.nn.layers import Snake1d
12
+ from einops import rearrange, reduce, repeat
13
+ from einops.layers.torch import Rearrange
14
+ from einops_exts import rearrange_many
15
+ from packaging import version
16
+ from torch import Tensor, einsum
17
+ from torch.backends.cuda import sdp_kernel
18
+ from torch.nn import functional as F
19
+
20
+ """
21
+ Utils
22
+ """
23
+
24
+
25
+ class ConditionedSequential(nn.Module):
26
+ def __init__(self, *modules):
27
+ super().__init__()
28
+ self.module_list = nn.ModuleList(*modules)
29
+
30
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
31
+ for module in self.module_list:
32
+ x = module(x, mapping)
33
+ return x
34
+
35
+
36
+ T = TypeVar("T")
37
+
38
+
39
+ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
40
+ if exists(val):
41
+ return val
42
+ return d() if isfunction(d) else d
43
+
44
+
45
+ def exists(val: Optional[T]) -> T:
46
+ return val is not None
47
+
48
+
49
+ def closest_power_2(x: float) -> int:
50
+ exponent = log2(x)
51
+ distance_fn = lambda z: abs(x - 2**z) # noqa
52
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
53
+ return 2 ** int(exponent_closest)
54
+
55
+
56
+ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
57
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
58
+ for key in d.keys():
59
+ no_prefix = int(not key.startswith(prefix))
60
+ return_dicts[no_prefix][key] = d[key]
61
+ return return_dicts
62
+
63
+
64
+ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
65
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
66
+ if keep_prefix:
67
+ return kwargs_with_prefix, kwargs
68
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
69
+ return kwargs_no_prefix, kwargs
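A tiny illustration of the prefix-grouping helper above, which UNet1d later uses to split attention_* keyword arguments out of its **kwargs:

attention_kwargs, rest = groupby("attention_", {"attention_heads": 8, "attention_features": 64, "patch_size": 1})
# attention_kwargs == {"heads": 8, "features": 64}  (prefix stripped because keep_prefix=False)
# rest == {"patch_size": 1}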
70
+
71
+
72
+ """
73
+ Convolutional Blocks
74
+ """
75
+ import typing as tp
76
+
77
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License
78
+ # License available in LICENSES/LICENSE_META.txt
79
+
80
+
81
+ def get_extra_padding_for_conv1d(
82
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
83
+ ) -> int:
84
+ """See `pad_for_conv1d`."""
85
+ length = x.shape[-1]
86
+ n_frames = (length - kernel_size + padding_total) / stride + 1
87
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
88
+ return ideal_length - length
89
+
90
+
91
+ def pad_for_conv1d(
92
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
93
+ ):
94
+ """Pad for a convolution to make sure that the last window is full.
95
+ Extra padding is added at the end. This is required to ensure that we can rebuild
96
+ an output of the same length, as otherwise, even with padding, some time steps
97
+ might get removed.
98
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
99
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
100
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
101
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
102
+ 1 2 3 4 # once you removed padding, we are missing one time step !
103
+ """
104
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
105
+ return F.pad(x, (0, extra_padding))
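A worked instance of the docstring above, using the same kernel_size=4, stride=2, padding_total=4 setting (a sketch assuming these helpers are in scope):

import torch

x = torch.randn(1, 1, 7)
# n_frames = (7 - 4 + 4) / 2 + 1 = 4.5, so one extra sample is needed to fill the last window.
print(get_extra_padding_for_conv1d(x, kernel_size=4, stride=2, padding_total=4))  # 1
print(pad_for_conv1d(x, kernel_size=4, stride=2, padding_total=4).shape)          # torch.Size([1, 1, 8])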
106
+
107
+
108
+ def pad1d(
109
+ x: torch.Tensor,
110
+ paddings: tp.Tuple[int, int],
111
+ mode: str = "constant",
112
+ value: float = 0.0,
113
+ ):
114
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
115
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
116
+ """
117
+ length = x.shape[-1]
118
+ padding_left, padding_right = paddings
119
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
120
+ if mode == "reflect":
121
+ max_pad = max(padding_left, padding_right)
122
+ extra_pad = 0
123
+ if length <= max_pad:
124
+ extra_pad = max_pad - length + 1
125
+ x = F.pad(x, (0, extra_pad))
126
+ padded = F.pad(x, paddings, mode, value)
127
+ end = padded.shape[-1] - extra_pad
128
+ return padded[..., :end]
129
+ else:
130
+ return F.pad(x, paddings, mode, value)
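Why the wrapper exists: plain reflect padding raises when the requested padding is not smaller than the input length, whereas pad1d zero-extends the signal first and trims the surplus afterwards. A quick sketch, assuming pad1d is in scope:

import torch

x = torch.arange(3.0).view(1, 1, 3)
# torch.nn.functional.pad(x, (4, 4), mode="reflect") would raise because 4 >= the input length of 3.
print(pad1d(x, (4, 4), mode="reflect").shape)  # torch.Size([1, 1, 11])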
131
+
132
+
133
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
134
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
135
+ padding_left, padding_right = paddings
136
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
137
+ assert (padding_left + padding_right) <= x.shape[-1]
138
+ end = x.shape[-1] - padding_right
139
+ return x[..., padding_left:end]
140
+
141
+
142
+ class Conv1d(nn.Conv1d):
143
+ def __init__(self, *args, **kwargs):
144
+ super().__init__(*args, **kwargs)
145
+
146
+ def forward(self, x: Tensor, causal=False) -> Tensor:
147
+ kernel_size = self.kernel_size[0]
148
+ stride = self.stride[0]
149
+ dilation = self.dilation[0]
150
+ kernel_size = (
151
+ kernel_size - 1
152
+ ) * dilation + 1 # effective kernel size with dilations
153
+ padding_total = kernel_size - stride
154
+ extra_padding = get_extra_padding_for_conv1d(
155
+ x, kernel_size, stride, padding_total
156
+ )
157
+ if causal:
158
+ # Left padding for causal
159
+ x = pad1d(x, (padding_total, extra_padding))
160
+ else:
161
+ # Asymmetric padding required for odd strides
162
+ padding_right = padding_total // 2
163
+ padding_left = padding_total - padding_right
164
+ x = pad1d(x, (padding_left, padding_right + extra_padding))
165
+ return super().forward(x)
166
+
167
+
168
+ class ConvTranspose1d(nn.ConvTranspose1d):
169
+ def __init__(self, *args, **kwargs):
170
+ super().__init__(*args, **kwargs)
171
+
172
+ def forward(self, x: Tensor, causal=False) -> Tensor:
173
+ kernel_size = self.kernel_size[0]
174
+ stride = self.stride[0]
175
+ padding_total = kernel_size - stride
176
+
177
+ y = super().forward(x)
178
+
179
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
180
+ # removed at the very end, when keeping only the right length for the output,
181
+ # as removing it here would require also passing the length at the matching layer
182
+ # in the encoder.
183
+ if causal:
184
+ padding_right = ceil(padding_total)
185
+ padding_left = padding_total - padding_right
186
+ y = unpad1d(y, (padding_left, padding_right))
187
+ else:
188
+ # Asymmetric padding required for odd strides
189
+ padding_right = padding_total // 2
190
+ padding_left = padding_total - padding_right
191
+ y = unpad1d(y, (padding_left, padding_right))
192
+ return y
193
+
194
+
195
+ def Downsample1d(
196
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
197
+ ) -> nn.Module:
198
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
199
+
200
+ return Conv1d(
201
+ in_channels=in_channels,
202
+ out_channels=out_channels,
203
+ kernel_size=factor * kernel_multiplier + 1,
204
+ stride=factor,
205
+ )
206
+
207
+
208
+ def Upsample1d(
209
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
210
+ ) -> nn.Module:
211
+ if factor == 1:
212
+ return Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3)
213
+
214
+ if use_nearest:
215
+ return nn.Sequential(
216
+ nn.Upsample(scale_factor=factor, mode="nearest"),
217
+ Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
218
+ )
219
+ else:
220
+ return ConvTranspose1d(
221
+ in_channels=in_channels,
222
+ out_channels=out_channels,
223
+ kernel_size=factor * 2,
224
+ stride=factor,
225
+ )
226
+
227
+
228
+ class ConvBlock1d(nn.Module):
229
+ def __init__(
230
+ self,
231
+ in_channels: int,
232
+ out_channels: int,
233
+ *,
234
+ kernel_size: int = 3,
235
+ stride: int = 1,
236
+ dilation: int = 1,
237
+ num_groups: int = 8,
238
+ use_norm: bool = True,
239
+ use_snake: bool = False,
240
+ ) -> None:
241
+ super().__init__()
242
+
243
+ self.groupnorm = (
244
+ nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
245
+ if use_norm
246
+ else nn.Identity()
247
+ )
248
+
249
+ if use_snake:
250
+ self.activation = Snake1d(in_channels)
251
+ else:
252
+ self.activation = nn.SiLU()
253
+
254
+ self.project = Conv1d(
255
+ in_channels=in_channels,
256
+ out_channels=out_channels,
257
+ kernel_size=kernel_size,
258
+ stride=stride,
259
+ dilation=dilation,
260
+ )
261
+
262
+ def forward(
263
+ self,
264
+ x: Tensor,
265
+ scale_shift: Optional[Tuple[Tensor, Tensor]] = None,
266
+ causal=False,
267
+ ) -> Tensor:
268
+ x = self.groupnorm(x)
269
+ if exists(scale_shift):
270
+ scale, shift = scale_shift
271
+ x = x * (scale + 1) + shift
272
+ x = self.activation(x)
273
+ return self.project(x, causal=causal)
274
+
275
+
276
+ class MappingToScaleShift(nn.Module):
277
+ def __init__(
278
+ self,
279
+ features: int,
280
+ channels: int,
281
+ ):
282
+ super().__init__()
283
+
284
+ self.to_scale_shift = nn.Sequential(
285
+ nn.SiLU(),
286
+ nn.Linear(in_features=features, out_features=channels * 2),
287
+ )
288
+
289
+ def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
290
+ scale_shift = self.to_scale_shift(mapping)
291
+ scale_shift = rearrange(scale_shift, "b c -> b c 1")
292
+ scale, shift = scale_shift.chunk(2, dim=1)
293
+ return scale, shift
294
+
295
+
296
+ class ResnetBlock1d(nn.Module):
297
+ def __init__(
298
+ self,
299
+ in_channels: int,
300
+ out_channels: int,
301
+ *,
302
+ kernel_size: int = 3,
303
+ stride: int = 1,
304
+ dilation: int = 1,
305
+ use_norm: bool = True,
306
+ use_snake: bool = False,
307
+ num_groups: int = 8,
308
+ context_mapping_features: Optional[int] = None,
309
+ ) -> None:
310
+ super().__init__()
311
+
312
+ self.use_mapping = exists(context_mapping_features)
313
+
314
+ self.block1 = ConvBlock1d(
315
+ in_channels=in_channels,
316
+ out_channels=out_channels,
317
+ kernel_size=kernel_size,
318
+ stride=stride,
319
+ dilation=dilation,
320
+ use_norm=use_norm,
321
+ num_groups=num_groups,
322
+ use_snake=use_snake,
323
+ )
324
+
325
+ if self.use_mapping:
326
+ assert exists(context_mapping_features)
327
+ self.to_scale_shift = MappingToScaleShift(
328
+ features=context_mapping_features, channels=out_channels
329
+ )
330
+
331
+ self.block2 = ConvBlock1d(
332
+ in_channels=out_channels,
333
+ out_channels=out_channels,
334
+ use_norm=use_norm,
335
+ num_groups=num_groups,
336
+ use_snake=use_snake,
337
+ )
338
+
339
+ self.to_out = (
340
+ Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
341
+ if in_channels != out_channels
342
+ else nn.Identity()
343
+ )
344
+
345
+ def forward(
346
+ self, x: Tensor, mapping: Optional[Tensor] = None, causal=False
347
+ ) -> Tensor:
348
+ assert_message = "context mapping required if context_mapping_features > 0"
349
+ assert not (self.use_mapping ^ exists(mapping)), assert_message
350
+
351
+ h = self.block1(x, causal=causal)
352
+
353
+ scale_shift = None
354
+ if self.use_mapping:
355
+ scale_shift = self.to_scale_shift(mapping)
356
+
357
+ h = self.block2(h, scale_shift=scale_shift, causal=causal)
358
+
359
+ return h + self.to_out(x)
360
+
361
+
362
+ class Patcher(nn.Module):
363
+ def __init__(
364
+ self,
365
+ in_channels: int,
366
+ out_channels: int,
367
+ patch_size: int,
368
+ context_mapping_features: Optional[int] = None,
369
+ use_snake: bool = False,
370
+ ):
371
+ super().__init__()
372
+ assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
373
+ assert out_channels % patch_size == 0, assert_message
374
+ self.patch_size = patch_size
375
+
376
+ self.block = ResnetBlock1d(
377
+ in_channels=in_channels,
378
+ out_channels=out_channels // patch_size,
379
+ num_groups=1,
380
+ context_mapping_features=context_mapping_features,
381
+ use_snake=use_snake,
382
+ )
383
+
384
+ def forward(
385
+ self, x: Tensor, mapping: Optional[Tensor] = None, causal=False
386
+ ) -> Tensor:
387
+ x = self.block(x, mapping, causal=causal)
388
+ x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
389
+ return x
390
+
391
+
392
+ class Unpatcher(nn.Module):
393
+ def __init__(
394
+ self,
395
+ in_channels: int,
396
+ out_channels: int,
397
+ patch_size: int,
398
+ context_mapping_features: Optional[int] = None,
399
+ use_snake: bool = False,
400
+ ):
401
+ super().__init__()
402
+ assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
403
+ assert in_channels % patch_size == 0, assert_message
404
+ self.patch_size = patch_size
405
+
406
+ self.block = ResnetBlock1d(
407
+ in_channels=in_channels // patch_size,
408
+ out_channels=out_channels,
409
+ num_groups=1,
410
+ context_mapping_features=context_mapping_features,
411
+ use_snake=use_snake,
412
+ )
413
+
414
+ def forward(
415
+ self, x: Tensor, mapping: Optional[Tensor] = None, causal=False
416
+ ) -> Tensor:
417
+ x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
418
+ x = self.block(x, mapping, causal=causal)
419
+ return x
420
+
421
+
422
+ """
423
+ Attention Components
424
+ """
425
+
426
+
427
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
428
+ mid_features = features * multiplier
429
+ return nn.Sequential(
430
+ nn.Linear(in_features=features, out_features=mid_features),
431
+ nn.GELU(),
432
+ nn.Linear(in_features=mid_features, out_features=features),
433
+ )
434
+
435
+
436
+ def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
437
+ b, ndim = sim.shape[0], mask.ndim
438
+ if ndim == 3:
439
+ mask = rearrange(mask, "b n m -> b 1 n m")
440
+ if ndim == 2:
441
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
442
+ max_neg_value = -torch.finfo(sim.dtype).max
443
+ sim = sim.masked_fill(~mask, max_neg_value)
444
+ return sim
445
+
446
+
447
+ def causal_mask(q: Tensor, k: Tensor) -> Tensor:
448
+ b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
449
+ mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
450
+ mask = repeat(mask, "n m -> b n m", b=b)
451
+ return mask
452
+
453
+
454
+ class AttentionBase(nn.Module):
455
+ def __init__(
456
+ self,
457
+ features: int,
458
+ *,
459
+ head_features: int,
460
+ num_heads: int,
461
+ out_features: Optional[int] = None,
462
+ ):
463
+ super().__init__()
464
+ self.scale = head_features**-0.5
465
+ self.num_heads = num_heads
466
+ mid_features = head_features * num_heads
467
+ out_features = default(out_features, features)
468
+
469
+ self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
470
+
471
+ self.use_flash = torch.cuda.is_available() and version.parse(
472
+ torch.__version__
473
+ ) >= version.parse("2.0.0")
474
+
475
+ if not self.use_flash:
476
+ return
477
+
478
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
479
+
480
+ if device_properties.major == 8 and device_properties.minor == 0:
481
+ # Use flash attention for A100 GPUs
482
+ self.sdp_kernel_config = (True, False, False)
483
+ else:
484
+ # Don't use flash attention for other GPUs
485
+ self.sdp_kernel_config = (False, True, True)
486
+
487
+ def forward(
488
+ self,
489
+ q: Tensor,
490
+ k: Tensor,
491
+ v: Tensor,
492
+ mask: Optional[Tensor] = None,
493
+ is_causal: bool = False,
494
+ ) -> Tensor:
495
+ # Split heads
496
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
497
+
498
+ if not self.use_flash:
499
+ if is_causal and not mask:
500
+ # Mask out future tokens for causal attention
501
+ mask = causal_mask(q, k)
502
+
503
+ # Compute similarity matrix and add eventual mask
504
+ sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
505
+ sim = add_mask(sim, mask) if exists(mask) else sim
506
+
507
+ # Get attention matrix with softmax
508
+ attn = sim.softmax(dim=-1, dtype=torch.float32)
509
+
510
+ # Compute values
511
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
512
+ else:
513
+ with sdp_kernel(*self.sdp_kernel_config):
514
+ out = F.scaled_dot_product_attention(
515
+ q, k, v, attn_mask=mask, is_causal=is_causal
516
+ )
517
+
518
+ out = rearrange(out, "b h n d -> b n (h d)")
519
+ return self.to_out(out)
520
+
521
+
522
+ class Attention(nn.Module):
523
+ def __init__(
524
+ self,
525
+ features: int,
526
+ *,
527
+ head_features: int,
528
+ num_heads: int,
529
+ out_features: Optional[int] = None,
530
+ context_features: Optional[int] = None,
531
+ causal: bool = False,
532
+ ):
533
+ super().__init__()
534
+ self.context_features = context_features
535
+ self.causal = causal
536
+ mid_features = head_features * num_heads
537
+ context_features = default(context_features, features)
538
+
539
+ self.norm = nn.LayerNorm(features)
540
+ self.norm_context = nn.LayerNorm(context_features)
541
+ self.to_q = nn.Linear(
542
+ in_features=features, out_features=mid_features, bias=False
543
+ )
544
+ self.to_kv = nn.Linear(
545
+ in_features=context_features, out_features=mid_features * 2, bias=False
546
+ )
547
+ self.attention = AttentionBase(
548
+ features,
549
+ num_heads=num_heads,
550
+ head_features=head_features,
551
+ out_features=out_features,
552
+ )
553
+
554
+ def forward(
555
+ self,
556
+ x: Tensor, # [b, n, c]
557
+ context: Optional[Tensor] = None, # [b, m, d]
558
+ context_mask: Optional[Tensor] = None, # [b, m], false is masked,
559
+ causal: Optional[bool] = False,
560
+ ) -> Tensor:
561
+ assert_message = "You must provide a context when using context_features"
562
+ assert not self.context_features or exists(context), assert_message
563
+ # Use context if provided
564
+ context = default(context, x)
565
+ # Normalize then compute q from input and k,v from context
566
+ x, context = self.norm(x), self.norm_context(context)
567
+
568
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
569
+
570
+ if exists(context_mask):
571
+ # Mask out cross-attention for padding tokens
572
+ mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
573
+ k, v = k * mask, v * mask
574
+
575
+ # Compute and return attention
576
+ return self.attention(q, k, v, is_causal=self.causal or causal)
577
+
578
+
579
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
580
+ mid_features = features * multiplier
581
+ return nn.Sequential(
582
+ nn.Linear(in_features=features, out_features=mid_features),
583
+ nn.GELU(),
584
+ nn.Linear(in_features=mid_features, out_features=features),
585
+ )
586
+
587
+
588
+ """
589
+ Transformer Blocks
590
+ """
591
+
592
+
593
+ class TransformerBlock(nn.Module):
594
+ def __init__(
595
+ self,
596
+ features: int,
597
+ num_heads: int,
598
+ head_features: int,
599
+ multiplier: int,
600
+ context_features: Optional[int] = None,
601
+ ):
602
+ super().__init__()
603
+
604
+ self.use_cross_attention = exists(context_features) and context_features > 0
605
+
606
+ self.attention = Attention(
607
+ features=features, num_heads=num_heads, head_features=head_features
608
+ )
609
+
610
+ if self.use_cross_attention:
611
+ self.cross_attention = Attention(
612
+ features=features,
613
+ num_heads=num_heads,
614
+ head_features=head_features,
615
+ context_features=context_features,
616
+ )
617
+
618
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
619
+
620
+ def forward(
621
+ self,
622
+ x: Tensor,
623
+ *,
624
+ context: Optional[Tensor] = None,
625
+ context_mask: Optional[Tensor] = None,
626
+ causal: Optional[bool] = False,
627
+ ) -> Tensor:
628
+ x = self.attention(x, causal=causal) + x
629
+ if self.use_cross_attention:
630
+ x = self.cross_attention(x, context=context, context_mask=context_mask) + x
631
+ x = self.feed_forward(x) + x
632
+ return x
633
+
634
+
635
+ """
636
+ Transformers
637
+ """
638
+
639
+
640
+ class Transformer1d(nn.Module):
641
+ def __init__(
642
+ self,
643
+ num_layers: int,
644
+ channels: int,
645
+ num_heads: int,
646
+ head_features: int,
647
+ multiplier: int,
648
+ context_features: Optional[int] = None,
649
+ ):
650
+ super().__init__()
651
+
652
+ self.to_in = nn.Sequential(
653
+ nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
654
+ Conv1d(
655
+ in_channels=channels,
656
+ out_channels=channels,
657
+ kernel_size=1,
658
+ ),
659
+ Rearrange("b c t -> b t c"),
660
+ )
661
+
662
+ self.blocks = nn.ModuleList(
663
+ [
664
+ TransformerBlock(
665
+ features=channels,
666
+ head_features=head_features,
667
+ num_heads=num_heads,
668
+ multiplier=multiplier,
669
+ context_features=context_features,
670
+ )
671
+ for i in range(num_layers)
672
+ ]
673
+ )
674
+
675
+ self.to_out = nn.Sequential(
676
+ Rearrange("b t c -> b c t"),
677
+ Conv1d(
678
+ in_channels=channels,
679
+ out_channels=channels,
680
+ kernel_size=1,
681
+ ),
682
+ )
683
+
684
+ def forward(
685
+ self,
686
+ x: Tensor,
687
+ *,
688
+ context: Optional[Tensor] = None,
689
+ context_mask: Optional[Tensor] = None,
690
+ causal=False,
691
+ ) -> Tensor:
692
+ x = self.to_in(x)
693
+ for block in self.blocks:
694
+ x = block(x, context=context, context_mask=context_mask, causal=causal)
695
+ x = self.to_out(x)
696
+ return x
697
+
698
+
699
+ """
700
+ Time Embeddings
701
+ """
702
+
703
+
704
+ class SinusoidalEmbedding(nn.Module):
705
+ def __init__(self, dim: int):
706
+ super().__init__()
707
+ self.dim = dim
708
+
709
+ def forward(self, x: Tensor) -> Tensor:
710
+ device, half_dim = x.device, self.dim // 2
711
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
712
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
713
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
714
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
715
+
716
+
717
+ class LearnedPositionalEmbedding(nn.Module):
718
+ """Used for continuous time"""
719
+
720
+ def __init__(self, dim: int):
721
+ super().__init__()
722
+ assert (dim % 2) == 0
723
+ half_dim = dim // 2
724
+ self.weights = nn.Parameter(torch.randn(half_dim))
725
+
726
+ def forward(self, x: Tensor) -> Tensor:
727
+ x = rearrange(x, "b -> b 1")
728
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
729
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
730
+ fouriered = torch.cat((x, fouriered), dim=-1)
731
+ return fouriered
732
+
733
+
734
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
735
+ return nn.Sequential(
736
+ LearnedPositionalEmbedding(dim),
737
+ nn.Linear(in_features=dim + 1, out_features=out_features),
738
+ )
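A short sketch of how the continuous-time embedding above is used: one scalar timestep per batch element, and the Linear consumes dim + 1 features because the raw timestep is concatenated onto the Fourier features (assumes the classes above are in scope):

import torch

to_time = TimePositionalEmbedding(dim=64, out_features=256)
t = torch.rand(8)    # continuous timesteps, one per batch element
emb = to_time(t)     # LearnedPositionalEmbedding -> (8, 65) -> Linear -> (8, 256)
print(emb.shape)     # torch.Size([8, 256])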
739
+
740
+
741
+ """
742
+ Encoder/Decoder Components
743
+ """
744
+
745
+
746
+ class DownsampleBlock1d(nn.Module):
747
+ def __init__(
748
+ self,
749
+ in_channels: int,
750
+ out_channels: int,
751
+ *,
752
+ factor: int,
753
+ num_groups: int,
754
+ num_layers: int,
755
+ kernel_multiplier: int = 2,
756
+ use_pre_downsample: bool = True,
757
+ use_skip: bool = False,
758
+ use_snake: bool = False,
759
+ extract_channels: int = 0,
760
+ context_channels: int = 0,
761
+ num_transformer_blocks: int = 0,
762
+ attention_heads: Optional[int] = None,
763
+ attention_features: Optional[int] = None,
764
+ attention_multiplier: Optional[int] = None,
765
+ context_mapping_features: Optional[int] = None,
766
+ context_embedding_features: Optional[int] = None,
767
+ ):
768
+ super().__init__()
769
+ self.use_pre_downsample = use_pre_downsample
770
+ self.use_skip = use_skip
771
+ self.use_transformer = num_transformer_blocks > 0
772
+ self.use_extract = extract_channels > 0
773
+ self.use_context = context_channels > 0
774
+
775
+ channels = out_channels if use_pre_downsample else in_channels
776
+
777
+ self.downsample = Downsample1d(
778
+ in_channels=in_channels,
779
+ out_channels=out_channels,
780
+ factor=factor,
781
+ kernel_multiplier=kernel_multiplier,
782
+ )
783
+
784
+ self.blocks = nn.ModuleList(
785
+ [
786
+ ResnetBlock1d(
787
+ in_channels=channels + context_channels if i == 0 else channels,
788
+ out_channels=channels,
789
+ num_groups=num_groups,
790
+ context_mapping_features=context_mapping_features,
791
+ use_snake=use_snake,
792
+ )
793
+ for i in range(num_layers)
794
+ ]
795
+ )
796
+
797
+ if self.use_transformer:
798
+ assert (exists(attention_heads) or exists(attention_features)) and exists(
799
+ attention_multiplier
800
+ )
801
+
802
+ if attention_features is None and attention_heads is not None:
803
+ attention_features = channels // attention_heads
804
+
805
+ if attention_heads is None and attention_features is not None:
806
+ attention_heads = channels // attention_features
807
+
808
+ self.transformer = Transformer1d(
809
+ num_layers=num_transformer_blocks,
810
+ channels=channels,
811
+ num_heads=attention_heads,
812
+ head_features=attention_features,
813
+ multiplier=attention_multiplier,
814
+ context_features=context_embedding_features,
815
+ )
816
+
817
+ if self.use_extract:
818
+ num_extract_groups = min(num_groups, extract_channels)
819
+ self.to_extracted = ResnetBlock1d(
820
+ in_channels=out_channels,
821
+ out_channels=extract_channels,
822
+ num_groups=num_extract_groups,
823
+ use_snake=use_snake,
824
+ )
825
+
826
+ def forward(
827
+ self,
828
+ x: Tensor,
829
+ *,
830
+ mapping: Optional[Tensor] = None,
831
+ channels: Optional[Tensor] = None,
832
+ embedding: Optional[Tensor] = None,
833
+ embedding_mask: Optional[Tensor] = None,
834
+ causal: Optional[bool] = False,
835
+ ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
836
+ if self.use_pre_downsample:
837
+ x = self.downsample(x)
838
+
839
+ if self.use_context and exists(channels):
840
+ x = torch.cat([x, channels], dim=1)
841
+
842
+ skips = []
843
+ for block in self.blocks:
844
+ x = block(x, mapping=mapping, causal=causal)
845
+ skips += [x] if self.use_skip else []
846
+
847
+ if self.use_transformer:
848
+ x = self.transformer(
849
+ x, context=embedding, context_mask=embedding_mask, causal=causal
850
+ )
851
+ skips += [x] if self.use_skip else []
852
+
853
+ if not self.use_pre_downsample:
854
+ x = self.downsample(x)
855
+
856
+ if self.use_extract:
857
+ extracted = self.to_extracted(x)
858
+ return x, extracted
859
+
860
+ return (x, skips) if self.use_skip else x
861
+
862
+
863
+ class UpsampleBlock1d(nn.Module):
864
+ def __init__(
865
+ self,
866
+ in_channels: int,
867
+ out_channels: int,
868
+ *,
869
+ factor: int,
870
+ num_layers: int,
871
+ num_groups: int,
872
+ use_nearest: bool = False,
873
+ use_pre_upsample: bool = False,
874
+ use_skip: bool = False,
875
+ use_snake: bool = False,
876
+ skip_channels: int = 0,
877
+ use_skip_scale: bool = False,
878
+ extract_channels: int = 0,
879
+ num_transformer_blocks: int = 0,
880
+ attention_heads: Optional[int] = None,
881
+ attention_features: Optional[int] = None,
882
+ attention_multiplier: Optional[int] = None,
883
+ context_mapping_features: Optional[int] = None,
884
+ context_embedding_features: Optional[int] = None,
885
+ ):
886
+ super().__init__()
887
+
888
+ self.use_extract = extract_channels > 0
889
+ self.use_pre_upsample = use_pre_upsample
890
+ self.use_transformer = num_transformer_blocks > 0
891
+ self.use_skip = use_skip
892
+ self.skip_scale = 2**-0.5 if use_skip_scale else 1.0
893
+
894
+ channels = out_channels if use_pre_upsample else in_channels
895
+
896
+ self.blocks = nn.ModuleList(
897
+ [
898
+ ResnetBlock1d(
899
+ in_channels=channels + skip_channels,
900
+ out_channels=channels,
901
+ num_groups=num_groups,
902
+ context_mapping_features=context_mapping_features,
903
+ use_snake=use_snake,
904
+ )
905
+ for _ in range(num_layers)
906
+ ]
907
+ )
908
+
909
+ if self.use_transformer:
910
+ assert (exists(attention_heads) or exists(attention_features)) and exists(
911
+ attention_multiplier
912
+ )
913
+
914
+ if attention_features is None and attention_heads is not None:
915
+ attention_features = channels // attention_heads
916
+
917
+ if attention_heads is None and attention_features is not None:
918
+ attention_heads = channels // attention_features
919
+
920
+ self.transformer = Transformer1d(
921
+ num_layers=num_transformer_blocks,
922
+ channels=channels,
923
+ num_heads=attention_heads,
924
+ head_features=attention_features,
925
+ multiplier=attention_multiplier,
926
+ context_features=context_embedding_features,
927
+ )
928
+
929
+ self.upsample = Upsample1d(
930
+ in_channels=in_channels,
931
+ out_channels=out_channels,
932
+ factor=factor,
933
+ use_nearest=use_nearest,
934
+ )
935
+
936
+ if self.use_extract:
937
+ num_extract_groups = min(num_groups, extract_channels)
938
+ self.to_extracted = ResnetBlock1d(
939
+ in_channels=out_channels,
940
+ out_channels=extract_channels,
941
+ num_groups=num_extract_groups,
942
+ use_snake=use_snake,
943
+ )
944
+
945
+ def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
946
+ return torch.cat([x, skip * self.skip_scale], dim=1)
947
+
948
+ def forward(
949
+ self,
950
+ x: Tensor,
951
+ *,
952
+ skips: Optional[List[Tensor]] = None,
953
+ mapping: Optional[Tensor] = None,
954
+ embedding: Optional[Tensor] = None,
955
+ embedding_mask: Optional[Tensor] = None,
956
+ causal: Optional[bool] = False,
957
+ ) -> Union[Tuple[Tensor, Tensor], Tensor]:
958
+ if self.use_pre_upsample:
959
+ x = self.upsample(x)
960
+
961
+ for block in self.blocks:
962
+ x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
963
+ x = block(x, mapping=mapping, causal=causal)
964
+
965
+ if self.use_transformer:
966
+ x = self.transformer(
967
+ x, context=embedding, context_mask=embedding_mask, causal=causal
968
+ )
969
+
970
+ if not self.use_pre_upsample:
971
+ x = self.upsample(x)
972
+
973
+ if self.use_extract:
974
+ extracted = self.to_extracted(x)
975
+ return x, extracted
976
+
977
+ return x
978
+
979
+
980
+ class BottleneckBlock1d(nn.Module):
981
+ def __init__(
982
+ self,
983
+ channels: int,
984
+ *,
985
+ num_groups: int,
986
+ num_transformer_blocks: int = 0,
987
+ attention_heads: Optional[int] = None,
988
+ attention_features: Optional[int] = None,
989
+ attention_multiplier: Optional[int] = None,
990
+ context_mapping_features: Optional[int] = None,
991
+ context_embedding_features: Optional[int] = None,
992
+ use_snake: bool = False,
993
+ ):
994
+ super().__init__()
995
+ self.use_transformer = num_transformer_blocks > 0
996
+
997
+ self.pre_block = ResnetBlock1d(
998
+ in_channels=channels,
999
+ out_channels=channels,
1000
+ num_groups=num_groups,
1001
+ context_mapping_features=context_mapping_features,
1002
+ use_snake=use_snake,
1003
+ )
1004
+
1005
+ if self.use_transformer:
1006
+ assert (exists(attention_heads) or exists(attention_features)) and exists(
1007
+ attention_multiplier
1008
+ )
1009
+
1010
+ if attention_features is None and attention_heads is not None:
1011
+ attention_features = channels // attention_heads
1012
+
1013
+ if attention_heads is None and attention_features is not None:
1014
+ attention_heads = channels // attention_features
1015
+
1016
+ self.transformer = Transformer1d(
1017
+ num_layers=num_transformer_blocks,
1018
+ channels=channels,
1019
+ num_heads=attention_heads,
1020
+ head_features=attention_features,
1021
+ multiplier=attention_multiplier,
1022
+ context_features=context_embedding_features,
1023
+ )
1024
+
1025
+ self.post_block = ResnetBlock1d(
1026
+ in_channels=channels,
1027
+ out_channels=channels,
1028
+ num_groups=num_groups,
1029
+ context_mapping_features=context_mapping_features,
1030
+ use_snake=use_snake,
1031
+ )
1032
+
1033
+ def forward(
1034
+ self,
1035
+ x: Tensor,
1036
+ *,
1037
+ mapping: Optional[Tensor] = None,
1038
+ embedding: Optional[Tensor] = None,
1039
+ embedding_mask: Optional[Tensor] = None,
1040
+ causal: Optional[bool] = False,
1041
+ ) -> Tensor:
1042
+ x = self.pre_block(x, mapping=mapping, causal=causal)
1043
+ if self.use_transformer:
1044
+ x = self.transformer(
1045
+ x, context=embedding, context_mask=embedding_mask, causal=causal
1046
+ )
1047
+ x = self.post_block(x, mapping=mapping, causal=causal)
1048
+ return x
1049
+
1050
+
1051
+ """
1052
+ UNet
1053
+ """
1054
+
1055
+
1056
+ class UNet1d(nn.Module):
1057
+ def __init__(
1058
+ self,
1059
+ in_channels: int,
1060
+ channels: int,
1061
+ multipliers: Sequence[int],
1062
+ factors: Sequence[int],
1063
+ num_blocks: Sequence[int],
1064
+ attentions: Sequence[int],
1065
+ patch_size: int = 1,
1066
+ resnet_groups: int = 8,
1067
+ use_context_time: bool = True,
1068
+ kernel_multiplier_downsample: int = 2,
1069
+ use_nearest_upsample: bool = False,
1070
+ use_skip_scale: bool = True,
1071
+ use_snake: bool = False,
1072
+ use_stft: bool = False,
1073
+ use_stft_context: bool = False,
1074
+ out_channels: Optional[int] = None,
1075
+ context_features: Optional[int] = None,
1076
+ context_features_multiplier: int = 4,
1077
+ context_channels: Optional[Sequence[int]] = None,
1078
+ context_embedding_features: Optional[int] = None,
1079
+ **kwargs,
1080
+ ):
1081
+ super().__init__()
1082
+ out_channels = default(out_channels, in_channels)
1083
+ context_channels = list(default(context_channels, []))
1084
+ num_layers = len(multipliers) - 1
1085
+ use_context_features = exists(context_features)
1086
+ use_context_channels = len(context_channels) > 0
1087
+ context_mapping_features = None
1088
+
1089
+ attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)
1090
+
1091
+ self.num_layers = num_layers
1092
+ self.use_context_time = use_context_time
1093
+ self.use_context_features = use_context_features
1094
+ self.use_context_channels = use_context_channels
1095
+ self.use_stft = use_stft
1096
+ self.use_stft_context = use_stft_context
1097
+
1098
+ self.context_features = context_features
1099
+ context_channels_pad_length = num_layers + 1 - len(context_channels)
1100
+ context_channels = context_channels + [0] * context_channels_pad_length
1101
+ self.context_channels = context_channels
1102
+ self.context_embedding_features = context_embedding_features
1103
+
1104
+ if use_context_channels:
1105
+ has_context = [c > 0 for c in context_channels]
1106
+ self.has_context = has_context
1107
+ self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
1108
+
1109
+ assert (
1110
+ len(factors) == num_layers
1111
+ and len(attentions) >= num_layers
1112
+ and len(num_blocks) == num_layers
1113
+ )
1114
+
1115
+ if use_context_time or use_context_features:
1116
+ context_mapping_features = channels * context_features_multiplier
1117
+
1118
+ self.to_mapping = nn.Sequential(
1119
+ nn.Linear(context_mapping_features, context_mapping_features),
1120
+ nn.GELU(),
1121
+ nn.Linear(context_mapping_features, context_mapping_features),
1122
+ nn.GELU(),
1123
+ )
1124
+
1125
+ if use_context_time:
1126
+ assert exists(context_mapping_features)
1127
+ self.to_time = nn.Sequential(
1128
+ TimePositionalEmbedding(
1129
+ dim=channels, out_features=context_mapping_features
1130
+ ),
1131
+ nn.GELU(),
1132
+ )
1133
+
1134
+ if use_context_features:
1135
+ assert exists(context_features) and exists(context_mapping_features)
1136
+ self.to_features = nn.Sequential(
1137
+ nn.Linear(
1138
+ in_features=context_features, out_features=context_mapping_features
1139
+ ),
1140
+ nn.GELU(),
1141
+ )
1142
+
1143
+ if use_stft:
1144
+ stft_kwargs, kwargs = groupby("stft_", kwargs)
1145
+ assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
1146
+ stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
1147
+ in_channels *= stft_channels
1148
+ out_channels *= stft_channels
1149
+ context_channels[0] *= stft_channels if use_stft_context else 1
1150
+ assert exists(in_channels) and exists(out_channels)
1151
+ self.stft = STFT(**stft_kwargs)
1152
+
1153
+ assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
1154
+
1155
+ self.to_in = Patcher(
1156
+ in_channels=in_channels + context_channels[0],
1157
+ out_channels=channels * multipliers[0],
1158
+ patch_size=patch_size,
1159
+ context_mapping_features=context_mapping_features,
1160
+ use_snake=use_snake,
1161
+ )
1162
+
1163
+ self.downsamples = nn.ModuleList(
1164
+ [
1165
+ DownsampleBlock1d(
1166
+ in_channels=channels * multipliers[i],
1167
+ out_channels=channels * multipliers[i + 1],
1168
+ context_mapping_features=context_mapping_features,
1169
+ context_channels=context_channels[i + 1],
1170
+ context_embedding_features=context_embedding_features,
1171
+ num_layers=num_blocks[i],
1172
+ factor=factors[i],
1173
+ kernel_multiplier=kernel_multiplier_downsample,
1174
+ num_groups=resnet_groups,
1175
+ use_pre_downsample=True,
1176
+ use_skip=True,
1177
+ use_snake=use_snake,
1178
+ num_transformer_blocks=attentions[i],
1179
+ **attention_kwargs,
1180
+ )
1181
+ for i in range(num_layers)
1182
+ ]
1183
+ )
1184
+
1185
+ self.bottleneck = BottleneckBlock1d(
1186
+ channels=channels * multipliers[-1],
1187
+ context_mapping_features=context_mapping_features,
1188
+ context_embedding_features=context_embedding_features,
1189
+ num_groups=resnet_groups,
1190
+ num_transformer_blocks=attentions[-1],
1191
+ use_snake=use_snake,
1192
+ **attention_kwargs,
1193
+ )
1194
+
1195
+ self.upsamples = nn.ModuleList(
1196
+ [
1197
+ UpsampleBlock1d(
1198
+ in_channels=channels * multipliers[i + 1],
1199
+ out_channels=channels * multipliers[i],
1200
+ context_mapping_features=context_mapping_features,
1201
+ context_embedding_features=context_embedding_features,
1202
+ num_layers=num_blocks[i] + (1 if attentions[i] else 0),
1203
+ factor=factors[i],
1204
+ use_nearest=use_nearest_upsample,
1205
+ num_groups=resnet_groups,
1206
+ use_skip_scale=use_skip_scale,
1207
+ use_pre_upsample=False,
1208
+ use_skip=True,
1209
+ use_snake=use_snake,
1210
+ skip_channels=channels * multipliers[i + 1],
1211
+ num_transformer_blocks=attentions[i],
1212
+ **attention_kwargs,
1213
+ )
1214
+ for i in reversed(range(num_layers))
1215
+ ]
1216
+ )
1217
+
1218
+ self.to_out = Unpatcher(
1219
+ in_channels=channels * multipliers[0],
1220
+ out_channels=out_channels,
1221
+ patch_size=patch_size,
1222
+ context_mapping_features=context_mapping_features,
1223
+ use_snake=use_snake,
1224
+ )
1225
+
1226
+ def get_channels(
1227
+ self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
1228
+ ) -> Optional[Tensor]:
1229
+ """Gets context channels at `layer` and checks that shape is correct"""
1230
+ use_context_channels = self.use_context_channels and self.has_context[layer]
1231
+ if not use_context_channels:
1232
+ return None
1233
+ assert exists(channels_list), "Missing context"
1234
+ # Get channels index (skipping zero channel contexts)
1235
+ channels_id = self.channels_ids[layer]
1236
+ # Get channels
1237
+ channels = channels_list[channels_id]
1238
+ message = f"Missing context for layer {layer} at index {channels_id}"
1239
+ assert exists(channels), message
1240
+ # Check channels
1241
+ num_channels = self.context_channels[layer]
1242
+ message = f"Expected context with {num_channels} channels at idx {channels_id}"
1243
+ assert channels.shape[1] == num_channels, message
1244
+ # STFT channels if requested
1245
+ channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
1246
+ return channels
1247
+
1248
+ def get_mapping(
1249
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
1250
+ ) -> Optional[Tensor]:
1251
+ """Combines context time features and features into mapping"""
1252
+ items, mapping = [], None
1253
+ # Compute time features
1254
+ if self.use_context_time:
1255
+ assert_message = "use_context_time=True but no time features provided"
1256
+ assert exists(time), assert_message
1257
+ items += [self.to_time(time)]
1258
+ # Compute features
1259
+ if self.use_context_features:
1260
+ assert_message = "context_features exists but no features provided"
1261
+ assert exists(features), assert_message
1262
+ items += [self.to_features(features)]
1263
+ # Compute joint mapping
1264
+ if self.use_context_time or self.use_context_features:
1265
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
1266
+ mapping = self.to_mapping(mapping)
1267
+ return mapping
1268
+
1269
+ def forward(
1270
+ self,
1271
+ x: Tensor,
1272
+ time: Optional[Tensor] = None,
1273
+ *,
1274
+ features: Optional[Tensor] = None,
1275
+ channels_list: Optional[Sequence[Tensor]] = None,
1276
+ embedding: Optional[Tensor] = None,
1277
+ embedding_mask: Optional[Tensor] = None,
1278
+ causal: Optional[bool] = False,
1279
+ ) -> Tensor:
1280
+ channels = self.get_channels(channels_list, layer=0)
1281
+ # Apply stft if required
1282
+ x = self.stft.encode1d(x) if self.use_stft else x # type: ignore
1283
+ # Concat context channels at layer 0 if provided
1284
+ x = torch.cat([x, channels], dim=1) if exists(channels) else x
1285
+ # Compute mapping from time and features
1286
+ mapping = self.get_mapping(time, features)
1287
+ x = self.to_in(x, mapping, causal=causal)
1288
+ skips_list = [x]
1289
+
1290
+ for i, downsample in enumerate(self.downsamples):
1291
+ channels = self.get_channels(channels_list, layer=i + 1)
1292
+ x, skips = downsample(
1293
+ x,
1294
+ mapping=mapping,
1295
+ channels=channels,
1296
+ embedding=embedding,
1297
+ embedding_mask=embedding_mask,
1298
+ causal=causal,
1299
+ )
1300
+ skips_list += [skips]
1301
+
1302
+ x = self.bottleneck(
1303
+ x,
1304
+ mapping=mapping,
1305
+ embedding=embedding,
1306
+ embedding_mask=embedding_mask,
1307
+ causal=causal,
1308
+ )
1309
+
1310
+ for i, upsample in enumerate(self.upsamples):
1311
+ skips = skips_list.pop()
1312
+ x = upsample(
1313
+ x,
1314
+ skips=skips,
1315
+ mapping=mapping,
1316
+ embedding=embedding,
1317
+ embedding_mask=embedding_mask,
1318
+ causal=causal,
1319
+ )
1320
+
1321
+ x += skips_list.pop()
1322
+ x = self.to_out(x, mapping, causal=causal)
1323
+ x = self.stft.decode1d(x) if self.use_stft else x
1324
+
1325
+ return x
1326
+
1327
+
1328
+ """ Conditioning Modules """
1329
+
1330
+
1331
+ class FixedEmbedding(nn.Module):
1332
+ def __init__(self, max_length: int, features: int):
1333
+ super().__init__()
1334
+ self.max_length = max_length
1335
+ self.embedding = nn.Embedding(max_length, features)
1336
+
1337
+ def forward(self, x: Tensor) -> Tensor:
1338
+ batch_size, length, device = *x.shape[0:2], x.device
1339
+ assert_message = "Input sequence length must be <= max_length"
1340
+ assert length <= self.max_length, assert_message
1341
+ position = torch.arange(length, device=device)
1342
+ fixed_embedding = self.embedding(position)
1343
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
1344
+ return fixed_embedding
1345
+
1346
+
1347
+ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
1348
+ if proba == 1:
1349
+ return torch.ones(shape, device=device, dtype=torch.bool)
1350
+ elif proba == 0:
1351
+ return torch.zeros(shape, device=device, dtype=torch.bool)
1352
+ else:
1353
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
1354
+
1355
+
1356
+ class UNetCFG1d(UNet1d):
1357
+ """UNet1d with Classifier-Free Guidance"""
1358
+
1359
+ def __init__(
1360
+ self,
1361
+ context_embedding_max_length: int,
1362
+ context_embedding_features: int,
1363
+ use_xattn_time: bool = False,
1364
+ **kwargs,
1365
+ ):
1366
+ super().__init__(
1367
+ context_embedding_features=context_embedding_features, **kwargs
1368
+ )
1369
+
1370
+ self.use_xattn_time = use_xattn_time
1371
+
1372
+ if use_xattn_time:
1373
+ assert exists(context_embedding_features)
1374
+ self.to_time_embedding = nn.Sequential(
1375
+ TimePositionalEmbedding(
1376
+ dim=kwargs["channels"], out_features=context_embedding_features
1377
+ ),
1378
+ nn.GELU(),
1379
+ )
1380
+
1381
+ context_embedding_max_length += 1 # Add one for time embedding
1382
+
1383
+ self.fixed_embedding = FixedEmbedding(
1384
+ max_length=context_embedding_max_length, features=context_embedding_features
1385
+ )
1386
+
1387
+ def forward( # type: ignore
1388
+ self,
1389
+ x: Tensor,
1390
+ time: Tensor,
1391
+ *,
1392
+ embedding: Tensor,
1393
+ embedding_mask: Optional[Tensor] = None,
1394
+ embedding_scale: float = 1.0,
1395
+ embedding_mask_proba: float = 0.0,
1396
+ batch_cfg: bool = False,
1397
+ rescale_cfg: bool = False,
1398
+ scale_phi: float = 0.4,
1399
+ negative_embedding: Optional[Tensor] = None,
1400
+ negative_embedding_mask: Optional[Tensor] = None,
1401
+ **kwargs,
1402
+ ) -> Tensor:
1403
+ b, device = embedding.shape[0], embedding.device
1404
+
1405
+ if self.use_xattn_time:
1406
+ embedding = torch.cat(
1407
+ [embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1
1408
+ )
1409
+
1410
+ if embedding_mask is not None:
1411
+ embedding_mask = torch.cat(
1412
+ [embedding_mask, torch.ones((b, 1), device=device)], dim=1
1413
+ )
1414
+
1415
+ fixed_embedding = self.fixed_embedding(embedding)
1416
+
1417
+ if embedding_mask_proba > 0.0:
1418
+ # Randomly mask embedding
1419
+ batch_mask = rand_bool(
1420
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
1421
+ )
1422
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
1423
+
1424
+ if embedding_scale != 1.0:
1425
+ if batch_cfg:
1426
+ batch_x = torch.cat([x, x], dim=0)
1427
+ batch_time = torch.cat([time, time], dim=0)
1428
+
1429
+ if negative_embedding is not None:
1430
+ if negative_embedding_mask is not None:
1431
+ negative_embedding_mask = negative_embedding_mask.to(
1432
+ torch.bool
1433
+ ).unsqueeze(2)
1434
+
1435
+ negative_embedding = torch.where(
1436
+ negative_embedding_mask, negative_embedding, fixed_embedding
1437
+ )
1438
+
1439
+ batch_embed = torch.cat([embedding, negative_embedding], dim=0)
1440
+
1441
+ else:
1442
+ batch_embed = torch.cat([embedding, fixed_embedding], dim=0)
1443
+
1444
+ batch_mask = None
1445
+ if embedding_mask is not None:
1446
+ batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)
1447
+
1448
+ batch_features = None
1449
+ features = kwargs.pop("features", None)
1450
+ if self.use_context_features:
1451
+ batch_features = torch.cat([features, features], dim=0)
1452
+
1453
+ batch_channels = None
1454
+ channels_list = kwargs.pop("channels_list", None)
1455
+ if self.use_context_channels:
1456
+ batch_channels = []
1457
+ for channels in channels_list:
1458
+ batch_channels += [torch.cat([channels, channels], dim=0)]
1459
+
1460
+ # Compute both normal and fixed embedding outputs
1461
+ batch_out = super().forward(
1462
+ batch_x,
1463
+ batch_time,
1464
+ embedding=batch_embed,
1465
+ embedding_mask=batch_mask,
1466
+ features=batch_features,
1467
+ channels_list=batch_channels,
1468
+ **kwargs,
1469
+ )
1470
+ out, out_masked = batch_out.chunk(2, dim=0)
1471
+
1472
+ else:
1473
+ # Compute both normal and fixed embedding outputs
1474
+ out = super().forward(
1475
+ x,
1476
+ time,
1477
+ embedding=embedding,
1478
+ embedding_mask=embedding_mask,
1479
+ **kwargs,
1480
+ )
1481
+ out_masked = super().forward(
1482
+ x,
1483
+ time,
1484
+ embedding=fixed_embedding,
1485
+ embedding_mask=embedding_mask,
1486
+ **kwargs,
1487
+ )
1488
+
1489
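+ # Classifier-free guidance: start from the unconditional (masked-embedding) output and move toward the conditional output by embedding_scale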
+ out_cfg = out_masked + (out - out_masked) * embedding_scale
1490
+
1491
+ if rescale_cfg:
1492
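+ # CFG rescaling: match the guided output's std over channels to the conditional output's, then blend with the raw guided output using scale_phi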
+ out_std = out.std(dim=1, keepdim=True)
1493
+ out_cfg_std = out_cfg.std(dim=1, keepdim=True)
1494
+
1495
+ return (
1496
+ scale_phi * (out_cfg * (out_std / out_cfg_std))
1497
+ + (1 - scale_phi) * out_cfg
1498
+ )
1499
+
1500
+ else:
1501
+ return out_cfg
1502
+
1503
+ else:
1504
+ return super().forward(
1505
+ x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs
1506
+ )
1507
+
1508
+
1509
+ class UNetNCCA1d(UNet1d):
1510
+ """UNet1d with Noise Channel Conditioning Augmentation"""
1511
+
1512
+ def __init__(self, context_features: int, **kwargs):
1513
+ super().__init__(context_features=context_features, **kwargs)
1514
+ self.embedder = NumberEmbedder(features=context_features)
1515
+
1516
+ def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
1517
+ x = x if torch.is_tensor(x) else torch.tensor(x)
1518
+ return x.expand(shape)
1519
+
1520
+ def forward( # type: ignore
1521
+ self,
1522
+ x: Tensor,
1523
+ time: Tensor,
1524
+ *,
1525
+ channels_list: Sequence[Tensor],
1526
+ channels_augmentation: Union[
1527
+ bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
1528
+ ] = False,
1529
+ channels_scale: Union[
1530
+ float, Sequence[float], Sequence[Sequence[float]], Tensor
1531
+ ] = 0,
1532
+ **kwargs,
1533
+ ) -> Tensor:
1534
+ b, n = x.shape[0], len(channels_list)
1535
+ channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
1536
+ channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)
1537
+
1538
+ # Augmentation (for each channel list item)
1539
+ for i in range(n):
1540
+ scale = channels_scale[:, i] * channels_augmentation[:, i]
1541
+ scale = rearrange(scale, "b -> b 1 1")
1542
+ item = channels_list[i]
1543
+ channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa
1544
+
1545
+ # Scale embedding (sum reduction if more than one channel list item)
1546
+ channels_scale_emb = self.embedder(channels_scale)
1547
+ channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")
1548
+
1549
+ return super().forward(
1550
+ x=x,
1551
+ time=time,
1552
+ channels_list=channels_list,
1553
+ features=channels_scale_emb,
1554
+ **kwargs,
1555
+ )
1556
+
1557
+
1558
+ class UNetAll1d(UNetCFG1d, UNetNCCA1d):
1559
+ def __init__(self, *args, **kwargs):
1560
+ super().__init__(*args, **kwargs)
1561
+
1562
+ def forward(self, *args, **kwargs): # type: ignore
1563
+ return UNetCFG1d.forward(self, *args, **kwargs)
1564
+
1565
+
1566
+ def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
1567
+ if type == "base":
1568
+ return UNet1d(**kwargs)
1569
+ elif type == "all":
1570
+ return UNetAll1d(**kwargs)
1571
+ elif type == "cfg":
1572
+ return UNetCFG1d(**kwargs)
1573
+ elif type == "ncca":
1574
+ return UNetNCCA1d(**kwargs)
1575
+ else:
1576
+ raise ValueError(f"Unknown XUNet1d type: {type}")
1577
+
1578
+
1579
+ class NumberEmbedder(nn.Module):
1580
+ def __init__(
1581
+ self,
1582
+ features: int,
1583
+ dim: int = 256,
1584
+ ):
1585
+ super().__init__()
1586
+ self.features = features
1587
+ self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
1588
+
1589
+ def forward(self, x: Union[List[float], Tensor]) -> Tensor:
1590
+ if not torch.is_tensor(x):
1591
+ device = next(self.embedding.parameters()).device
1592
+ x = torch.tensor(x, device=device)
1593
+ assert isinstance(x, Tensor)
1594
+ shape = x.shape
1595
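+ # Flatten to a 1-D batch of scalars, embed each scalar, then restore the original shape with a trailing feature dim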
+ x = rearrange(x, "... -> (...)")
1596
+ embedding = self.embedding(x)
1597
+ x = embedding.view(*shape, self.features)
1598
+ return x # type: ignore
1599
+
1600
+
1601
+ """
1602
+ Audio Transforms
1603
+ """
1604
+
1605
+
1606
+ class STFT(nn.Module):
1607
+ """Helper for torch stft and istft"""
1608
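+ # Illustrative usage (not in the original source): given wave of shape (batch, channels, time),
+ # spec_a, spec_b = stft.encode(wave); wave_rec = stft.decode(spec_a, spec_b)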
+
1609
+ def __init__(
1610
+ self,
1611
+ num_fft: int = 1023,
1612
+ hop_length: int = 256,
1613
+ window_length: Optional[int] = None,
1614
+ length: Optional[int] = None,
1615
+ use_complex: bool = False,
1616
+ ):
1617
+ super().__init__()
1618
+ self.num_fft = num_fft
1619
+ self.hop_length = default(hop_length, floor(num_fft // 4))
1620
+ self.window_length = default(window_length, num_fft)
1621
+ self.length = length
1622
+ self.register_buffer("window", torch.hann_window(self.window_length))
1623
+ self.use_complex = use_complex
1624
+
1625
+ def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
1626
+ b = wave.shape[0]
1627
+ wave = rearrange(wave, "b c t -> (b c) t")
1628
+
1629
+ stft = torch.stft(
1630
+ wave,
1631
+ n_fft=self.num_fft,
1632
+ hop_length=self.hop_length,
1633
+ win_length=self.window_length,
1634
+ window=self.window, # type: ignore
1635
+ return_complex=True,
1636
+ normalized=True,
1637
+ )
1638
+
1639
+ if self.use_complex:
1640
+ # Returns real and imaginary
1641
+ stft_a, stft_b = stft.real, stft.imag
1642
+ else:
1643
+ # Returns magnitude and phase matrices
1644
+ magnitude, phase = torch.abs(stft), torch.angle(stft)
1645
+ stft_a, stft_b = magnitude, phase
1646
+
1647
+ return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)
1648
+
1649
+ def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
1650
+ b, l = stft_a.shape[0], stft_a.shape[-1] # noqa
1651
+ length = closest_power_2(l * self.hop_length)
1652
+
1653
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")
1654
+
1655
+ if self.use_complex:
1656
+ real, imag = stft_a, stft_b
1657
+ else:
1658
+ magnitude, phase = stft_a, stft_b
1659
+ real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)
1660
+
1661
+ stft = torch.stack([real, imag], dim=-1)
1662
+
1663
+ wave = torch.istft(
1664
+ stft,
1665
+ n_fft=self.num_fft,
1666
+ hop_length=self.hop_length,
1667
+ win_length=self.window_length,
1668
+ window=self.window, # type: ignore
1669
+ length=default(self.length, length),
1670
+ normalized=True,
1671
+ )
1672
+
1673
+ return rearrange(wave, "(b c) t -> b c t", b=b)
1674
+
1675
+ def encode1d(
1676
+ self, wave: Tensor, stacked: bool = True
1677
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
1678
+ stft_a, stft_b = self.encode(wave)
1679
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
1680
+ return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)
1681
+
1682
+ def decode1d(self, stft_pair: Tensor) -> Tensor:
1683
+ f = self.num_fft // 2 + 1
1684
+ stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
1685
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
1686
+ return self.decode(stft_a, stft_b)
src/YingMusicSinger/utils/stable_audio_tools/autoencoders.py ADDED
@@ -0,0 +1,975 @@
1
+ import math
2
+ from typing import Any, Dict, Literal
3
+
4
+ import numpy as np
5
+ import torch
6
+ from alias_free_torch import Activation1d
7
+ from dac.nn.layers import WNConv1d, WNConvTranspose1d
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torchaudio import transforms as T
11
+
12
+ # from ..inference.sampling import sample
13
+ # from ..inference.utils import prepare_audio
14
+ from .blocks import SnakeBeta
15
+ from .bottleneck import Bottleneck, DiscreteBottleneck
16
+ from .diffusion import (
17
+ ConditionedDiffusionModel,
18
+ DAU1DCondWrapper,
19
+ DiTWrapper,
20
+ UNet1DCondWrapper,
21
+ )
22
+ from .factory import create_bottleneck_from_config, create_pretransform_from_config
23
+ from .pretransforms import Pretransform
24
+
25
+
26
+ def checkpoint(function, *args, **kwargs):
27
+ kwargs.setdefault("use_reentrant", False)
28
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
29
+
30
+
31
+ def get_activation(
32
+ activation: Literal["elu", "snake", "none"], antialias=False, channels=None
33
+ ) -> nn.Module:
34
+ if activation == "elu":
35
+ act = nn.ELU()
36
+ elif activation == "snake":
37
+ act = SnakeBeta(channels)
38
+ elif activation == "none":
39
+ act = nn.Identity()
40
+ else:
41
+ raise ValueError(f"Unknown activation {activation}")
42
+
43
+ if antialias:
44
+ act = Activation1d(act)
45
+
46
+ return act
47
+
48
+
49
+ class ResidualUnit(nn.Module):
50
+ def __init__(
51
+ self,
52
+ in_channels,
53
+ out_channels,
54
+ dilation,
55
+ use_snake=False,
56
+ antialias_activation=False,
57
+ ):
58
+ super().__init__()
59
+
60
+ self.dilation = dilation
61
+
62
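+ # 'same' padding for the kernel_size-7 dilated convolution below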
+ padding = (dilation * (7 - 1)) // 2
63
+
64
+ self.layers = nn.Sequential(
65
+ get_activation(
66
+ "snake" if use_snake else "elu",
67
+ antialias=antialias_activation,
68
+ channels=out_channels,
69
+ ),
70
+ WNConv1d(
71
+ in_channels=in_channels,
72
+ out_channels=out_channels,
73
+ kernel_size=7,
74
+ dilation=dilation,
75
+ padding=padding,
76
+ ),
77
+ get_activation(
78
+ "snake" if use_snake else "elu",
79
+ antialias=antialias_activation,
80
+ channels=out_channels,
81
+ ),
82
+ WNConv1d(
83
+ in_channels=out_channels, out_channels=out_channels, kernel_size=1
84
+ ),
85
+ )
86
+
87
+ def forward(self, x):
88
+ res = x
89
+
90
+ # x = checkpoint(self.layers, x)
91
+ x = self.layers(x)
92
+
93
+ return x + res
94
+
95
+
96
+ class EncoderBlock(nn.Module):
97
+ def __init__(
98
+ self,
99
+ in_channels,
100
+ out_channels,
101
+ stride,
102
+ use_snake=False,
103
+ antialias_activation=False,
104
+ ):
105
+ super().__init__()
106
+
107
+ self.layers = nn.Sequential(
108
+ ResidualUnit(
109
+ in_channels=in_channels,
110
+ out_channels=in_channels,
111
+ dilation=1,
112
+ use_snake=use_snake,
113
+ ),
114
+ ResidualUnit(
115
+ in_channels=in_channels,
116
+ out_channels=in_channels,
117
+ dilation=3,
118
+ use_snake=use_snake,
119
+ ),
120
+ ResidualUnit(
121
+ in_channels=in_channels,
122
+ out_channels=in_channels,
123
+ dilation=9,
124
+ use_snake=use_snake,
125
+ ),
126
+ get_activation(
127
+ "snake" if use_snake else "elu",
128
+ antialias=antialias_activation,
129
+ channels=in_channels,
130
+ ),
131
+ WNConv1d(
132
+ in_channels=in_channels,
133
+ out_channels=out_channels,
134
+ kernel_size=2 * stride,
135
+ stride=stride,
136
+ padding=math.ceil(stride / 2),
137
+ ),
138
+ )
139
+
140
+ def forward(self, x):
141
+ return self.layers(x)
142
+
143
+
144
+ class DecoderBlock(nn.Module):
145
+ def __init__(
146
+ self,
147
+ in_channels,
148
+ out_channels,
149
+ stride,
150
+ use_snake=False,
151
+ antialias_activation=False,
152
+ use_nearest_upsample=False,
153
+ ):
154
+ super().__init__()
155
+
156
+ if use_nearest_upsample:
157
+ upsample_layer = nn.Sequential(
158
+ nn.Upsample(scale_factor=stride, mode="nearest"),
159
+ WNConv1d(
160
+ in_channels=in_channels,
161
+ out_channels=out_channels,
162
+ kernel_size=2 * stride,
163
+ stride=1,
164
+ bias=False,
165
+ padding="same",
166
+ ),
167
+ )
168
+ else:
169
+ upsample_layer = WNConvTranspose1d(
170
+ in_channels=in_channels,
171
+ out_channels=out_channels,
172
+ kernel_size=2 * stride,
173
+ stride=stride,
174
+ padding=math.ceil(stride / 2),
175
+ )
176
+
177
+ self.layers = nn.Sequential(
178
+ get_activation(
179
+ "snake" if use_snake else "elu",
180
+ antialias=antialias_activation,
181
+ channels=in_channels,
182
+ ),
183
+ upsample_layer,
184
+ ResidualUnit(
185
+ in_channels=out_channels,
186
+ out_channels=out_channels,
187
+ dilation=1,
188
+ use_snake=use_snake,
189
+ ),
190
+ ResidualUnit(
191
+ in_channels=out_channels,
192
+ out_channels=out_channels,
193
+ dilation=3,
194
+ use_snake=use_snake,
195
+ ),
196
+ ResidualUnit(
197
+ in_channels=out_channels,
198
+ out_channels=out_channels,
199
+ dilation=9,
200
+ use_snake=use_snake,
201
+ ),
202
+ )
203
+
204
+ def forward(self, x):
205
+ return self.layers(x)
206
+
207
+
208
+ class OobleckEncoder(nn.Module):
209
+ def __init__(
210
+ self,
211
+ in_channels=2,
212
+ channels=128,
213
+ latent_dim=32,
214
+ c_mults=[1, 2, 4, 8],
215
+ strides=[2, 4, 8, 8],
216
+ use_snake=False,
217
+ antialias_activation=False,
218
+ ):
219
+ super().__init__()
220
+
221
+ c_mults = [1] + c_mults
222
+
223
+ self.depth = len(c_mults)
224
+
225
+ layers = [
226
+ WNConv1d(
227
+ in_channels=in_channels,
228
+ out_channels=c_mults[0] * channels,
229
+ kernel_size=7,
230
+ padding=3,
231
+ )
232
+ ]
233
+
234
+ for i in range(self.depth - 1):
235
+ layers += [
236
+ EncoderBlock(
237
+ in_channels=c_mults[i] * channels,
238
+ out_channels=c_mults[i + 1] * channels,
239
+ stride=strides[i],
240
+ use_snake=use_snake,
241
+ )
242
+ ]
243
+
244
+ layers += [
245
+ get_activation(
246
+ "snake" if use_snake else "elu",
247
+ antialias=antialias_activation,
248
+ channels=c_mults[-1] * channels,
249
+ ),
250
+ WNConv1d(
251
+ in_channels=c_mults[-1] * channels,
252
+ out_channels=latent_dim,
253
+ kernel_size=3,
254
+ padding=1,
255
+ ),
256
+ ]
257
+
258
+ self.layers = nn.Sequential(*layers)
259
+
260
+ def forward(self, x):
261
+ return self.layers(x)
262
+
263
+
264
+ class OobleckDecoder(nn.Module):
265
+ def __init__(
266
+ self,
267
+ out_channels=2,
268
+ channels=128,
269
+ latent_dim=32,
270
+ c_mults=[1, 2, 4, 8],
271
+ strides=[2, 4, 8, 8],
272
+ use_snake=False,
273
+ antialias_activation=False,
274
+ use_nearest_upsample=False,
275
+ final_tanh=True,
276
+ ):
277
+ super().__init__()
278
+
279
+ c_mults = [1] + c_mults
280
+
281
+ self.depth = len(c_mults)
282
+
283
+ layers = [
284
+ WNConv1d(
285
+ in_channels=latent_dim,
286
+ out_channels=c_mults[-1] * channels,
287
+ kernel_size=7,
288
+ padding=3,
289
+ ),
290
+ ]
291
+
292
+ for i in range(self.depth - 1, 0, -1):
293
+ layers += [
294
+ DecoderBlock(
295
+ in_channels=c_mults[i] * channels,
296
+ out_channels=c_mults[i - 1] * channels,
297
+ stride=strides[i - 1],
298
+ use_snake=use_snake,
299
+ antialias_activation=antialias_activation,
300
+ use_nearest_upsample=use_nearest_upsample,
301
+ )
302
+ ]
303
+
304
+ layers += [
305
+ get_activation(
306
+ "snake" if use_snake else "elu",
307
+ antialias=antialias_activation,
308
+ channels=c_mults[0] * channels,
309
+ ),
310
+ WNConv1d(
311
+ in_channels=c_mults[0] * channels,
312
+ out_channels=out_channels,
313
+ kernel_size=7,
314
+ padding=3,
315
+ bias=False,
316
+ ),
317
+ nn.Tanh() if final_tanh else nn.Identity(),
318
+ ]
319
+
320
+ self.layers = nn.Sequential(*layers)
321
+
322
+ def forward(self, x):
323
+ return self.layers(x)
324
+
325
+
326
+ class DACEncoderWrapper(nn.Module):
327
+ def __init__(self, in_channels=1, **kwargs):
328
+ super().__init__()
329
+
330
+ from dac.model.dac import Encoder as DACEncoder
331
+
332
+ latent_dim = kwargs.pop("latent_dim", None)
333
+
334
+ encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
335
+ self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
336
+ self.latent_dim = latent_dim
337
+
338
+ # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
339
+ self.proj_out = (
340
+ nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1)
341
+ if latent_dim is not None
342
+ else nn.Identity()
343
+ )
344
+
345
+ if in_channels != 1:
346
+ self.encoder.block[0] = WNConv1d(
347
+ in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3
348
+ )
349
+
350
+ def forward(self, x):
351
+ x = self.encoder(x)
352
+ x = self.proj_out(x)
353
+ return x
354
+
355
+
356
+ class DACDecoderWrapper(nn.Module):
357
+ def __init__(self, latent_dim, out_channels=1, **kwargs):
358
+ super().__init__()
359
+
360
+ from dac.model.dac import Decoder as DACDecoder
361
+
362
+ self.decoder = DACDecoder(
363
+ **kwargs, input_channel=latent_dim, d_out=out_channels
364
+ )
365
+
366
+ self.latent_dim = latent_dim
367
+
368
+ def forward(self, x):
369
+ return self.decoder(x)
370
+
371
+
372
+ class AudioAutoencoder(nn.Module):
373
+ def __init__(
374
+ self,
375
+ encoder,
376
+ decoder,
377
+ latent_dim,
378
+ downsampling_ratio,
379
+ sample_rate,
380
+ io_channels=2,
381
+ bottleneck: Bottleneck = None,
382
+ pretransform: Pretransform = None,
383
+ in_channels=None,
384
+ out_channels=None,
385
+ soft_clip=False,
386
+ ):
387
+ super().__init__()
388
+
389
+ self.downsampling_ratio = downsampling_ratio
390
+ self.sample_rate = sample_rate
391
+
392
+ self.latent_dim = latent_dim
393
+ self.io_channels = io_channels
394
+ self.in_channels = io_channels
395
+ self.out_channels = io_channels
396
+
397
+ self.min_length = self.downsampling_ratio
398
+
399
+ if in_channels is not None:
400
+ self.in_channels = in_channels
401
+
402
+ if out_channels is not None:
403
+ self.out_channels = out_channels
404
+
405
+ self.bottleneck = bottleneck
406
+
407
+ self.encoder = encoder
408
+
409
+ self.decoder = decoder
410
+
411
+ self.pretransform = pretransform
412
+
413
+ self.soft_clip = soft_clip
414
+
415
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
416
+
417
+ def encode(
418
+ self,
419
+ audio,
420
+ return_info=False,
421
+ skip_pretransform=False,
422
+ iterate_batch=False,
423
+ **kwargs,
424
+ ):
425
+ info = {}
426
+
427
+ if self.pretransform is not None and not skip_pretransform:
428
+ if self.pretransform.enable_grad:
429
+ if iterate_batch:
430
+ audios = []
431
+ for i in range(audio.shape[0]):
432
+ audios.append(self.pretransform.encode(audio[i : i + 1]))
433
+ audio = torch.cat(audios, dim=0)
434
+ else:
435
+ audio = self.pretransform.encode(audio)
436
+ else:
437
+ with torch.no_grad():
438
+ if iterate_batch:
439
+ audios = []
440
+ for i in range(audio.shape[0]):
441
+ audios.append(self.pretransform.encode(audio[i : i + 1]))
442
+ audio = torch.cat(audios, dim=0)
443
+ else:
444
+ audio = self.pretransform.encode(audio)
445
+
446
+ if self.encoder is not None:
447
+ if iterate_batch:
448
+ latents = []
449
+ for i in range(audio.shape[0]):
450
+ latents.append(self.encoder(audio[i : i + 1]))
451
+ latents = torch.cat(latents, dim=0)
452
+ else:
453
+ latents = self.encoder(audio)
454
+ else:
455
+ latents = audio
456
+
457
+ if self.bottleneck is not None:
458
+ # TODO: Add iterate batch logic, needs to merge the info dicts
459
+ latents, bottleneck_info = self.bottleneck.encode(
460
+ latents, return_info=True, **kwargs
461
+ )
462
+
463
+ info.update(bottleneck_info)
464
+
465
+ if return_info:
466
+ return latents, info
467
+
468
+ return latents
469
+
470
+ def decode(self, latents, iterate_batch=False, **kwargs):
471
+ if self.bottleneck is not None:
472
+ if iterate_batch:
473
+ decoded = []
474
+ for i in range(latents.shape[0]):
475
+ decoded.append(self.bottleneck.decode(latents[i : i + 1]))
476
+ latents = torch.cat(decoded, dim=0)
477
+ else:
478
+ latents = self.bottleneck.decode(latents)
479
+
480
+ if iterate_batch:
481
+ decoded = []
482
+ for i in range(latents.shape[0]):
483
+ decoded.append(self.decoder(latents[i : i + 1]))
484
+ decoded = torch.cat(decoded, dim=0)
485
+ else:
486
+ decoded = self.decoder(latents, **kwargs)
487
+
488
+ if self.pretransform is not None:
489
+ if self.pretransform.enable_grad:
490
+ if iterate_batch:
491
+ decodeds = []
492
+ for i in range(decoded.shape[0]):
493
+ decodeds.append(self.pretransform.decode(decoded[i : i + 1]))
494
+ decoded = torch.cat(decodeds, dim=0)
495
+ else:
496
+ decoded = self.pretransform.decode(decoded)
497
+ else:
498
+ with torch.no_grad():
499
+ if iterate_batch:
500
+ decodeds = []
501
+ for i in range(latents.shape[0]):
502
+ decodeds.append(
503
+ self.pretransform.decode(decoded[i : i + 1])
504
+ )
505
+ decoded = torch.cat(decodeds, dim=0)
506
+ else:
507
+ decoded = self.pretransform.decode(decoded)
508
+
509
+ if self.soft_clip:
510
+ decoded = torch.tanh(decoded)
511
+
512
+ return decoded
513
+
514
+ def decode_tokens(self, tokens, **kwargs):
515
+ """
516
+ Decode discrete tokens to audio
517
+ Only works with discrete autoencoders
518
+ """
519
+
520
+ assert isinstance(self.bottleneck, DiscreteBottleneck), (
521
+ "decode_tokens only works with discrete autoencoders"
522
+ )
523
+
524
+ latents = self.bottleneck.decode_tokens(tokens, **kwargs)
525
+
526
+ return self.decode(latents, **kwargs)
527
+
528
+ def preprocess_audio_for_encoder(self, audio, in_sr):
529
+ """
530
+ Preprocess a single audio tensor (Channels x Length) to be compatible with the encoder.
531
+ If the model is mono, stereo audio will be converted to mono.
532
+ Audio will be silence-padded to be a multiple of the model's downsampling ratio.
533
+ Audio will be resampled to the model's sample rate.
534
+ The output will have batch size 1 and be shape (1 x Channels x Length)
535
+ """
536
+ return self.preprocess_audio_list_for_encoder([audio], [in_sr])
537
+
538
+ def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
539
+ """
540
+ Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
541
+ The audio in that list can be of different lengths and channels.
542
+ in_sr can be an integer or a list. If it's an integer, it is assumed to be the input sample_rate for every audio.
543
+ All audio will be resampled to the model's sample rate.
544
+ Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
545
+ If the model is mono, all audio will be converted to mono.
546
+ The output will be a tensor of shape (Batch x Channels x Length)
547
+ """
548
+ batch_size = len(audio_list)
549
+ if isinstance(in_sr_list, int):
550
+ in_sr_list = [in_sr_list] * batch_size
551
+ assert len(in_sr_list) == batch_size, (
552
+ "list of sample rates must be the same length of audio_list"
553
+ )
554
+ new_audio = []
555
+ max_length = 0
556
+ # resample & find the max length
557
+ for i in range(batch_size):
558
+ audio = audio_list[i]
559
+ in_sr = in_sr_list[i]
560
+ if len(audio.shape) == 3 and audio.shape[0] == 1:
561
+ # batchsize 1 was given by accident. Just squeeze it.
562
+ audio = audio.squeeze(0)
563
+ elif len(audio.shape) == 1:
564
+ # Mono signal, channel dimension is missing, unsqueeze it in
565
+ audio = audio.unsqueeze(0)
566
+ assert len(audio.shape) == 2, (
567
+ "Audio should be shape (Channels x Length) with no batch dimension"
568
+ )
569
+ # Resample audio
570
+ if in_sr != self.sample_rate:
571
+ resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
572
+ audio = resample_tf(audio)
573
+ new_audio.append(audio)
574
+ if audio.shape[-1] > max_length:
575
+ max_length = audio.shape[-1]
576
+ # Pad every audio to the same length, multiple of model's downsampling ratio
577
+ padded_audio_length = (
578
+ max_length
579
+ + (self.min_length - (max_length % self.min_length)) % self.min_length
580
+ )
581
+ for i in range(batch_size):
582
+ # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
583
+ new_audio[i] = prepare_audio(
584
+ new_audio[i],
585
+ in_sr=in_sr,
586
+ target_sr=in_sr,
587
+ target_length=padded_audio_length,
588
+ target_channels=self.in_channels,
589
+ device=new_audio[i].device,
590
+ ).squeeze(0)
591
+ # convert to tensor
592
+ return torch.stack(new_audio)
593
+
594
+ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
595
+ """
596
+ Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
597
+ If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
598
+ Overlap and chunk_size params are both measured in number of latents (not audio samples)
599
+ and therefore you likely could use the same values with decode_audio.
600
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
601
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
602
+ You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
603
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
604
+ Smaller chunk_size uses less memory, but more compute.
605
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
606
+ For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
607
+ """
608
+ if not chunked:
609
+ # default behavior. Encode the entire audio in parallel
610
+ return self.encode(audio, **kwargs)
611
+ else:
612
+ # CHUNKED ENCODING
613
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
614
+ samples_per_latent = self.downsampling_ratio
615
+ total_size = audio.shape[2] # in samples
616
+ batch_size = audio.shape[0]
617
+ chunk_size *= samples_per_latent # converting metric in latents to samples
618
+ overlap *= samples_per_latent # converting metric in latents to samples
619
+ hop_size = chunk_size - overlap
620
+ chunks = []
621
+ for i in range(0, total_size - chunk_size + 1, hop_size):
622
+ chunk = audio[:, :, i : i + chunk_size]
623
+ chunks.append(chunk)
624
+ if i + chunk_size != total_size:
625
+ # Final chunk
626
+ chunk = audio[:, :, -chunk_size:]
627
+ chunks.append(chunk)
628
+ chunks = torch.stack(chunks)
629
+ num_chunks = chunks.shape[0]
630
+ # Note: y_size might be a different value from the latent length used in diffusion training
631
+ # because we can encode audio of varying lengths
632
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
633
+ y_size = total_size // samples_per_latent
634
+ # Create an empty latent, we will populate it with chunks as we encode them
635
+ y_final = torch.zeros((batch_size, self.latent_dim, y_size)).to(
636
+ audio.device
637
+ )
638
+ for i in range(num_chunks):
639
+ x_chunk = chunks[i, :]
640
+ # encode the chunk
641
+ y_chunk = self.encode(x_chunk)
642
+ # figure out where to put the audio along the time domain
643
+ if i == num_chunks - 1:
644
+ # final chunk always goes at the end
645
+ t_end = y_size
646
+ t_start = t_end - y_chunk.shape[2]
647
+ else:
648
+ t_start = i * hop_size // samples_per_latent
649
+ t_end = t_start + chunk_size // samples_per_latent
650
+ # remove the edges of the overlaps
651
+ ol = overlap // samples_per_latent // 2
652
+ chunk_start = 0
653
+ chunk_end = y_chunk.shape[2]
654
+ if i > 0:
655
+ # no overlap for the start of the first chunk
656
+ t_start += ol
657
+ chunk_start += ol
658
+ if i < num_chunks - 1:
659
+ # no overlap for the end of the last chunk
660
+ t_end -= ol
661
+ chunk_end -= ol
662
+ # paste the chunked audio into our y_final output audio
663
+ y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
664
+ return y_final
665
+
666
+ def decode_audio(
667
+ self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs
668
+ ):
669
+ """
670
+ Decode latents to audio.
671
+ If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
672
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
673
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
674
+ You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
675
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
676
+ Smaller chunk_size uses less memory, but more compute.
677
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
678
+ For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
679
+ """
680
+ if not chunked:
681
+ # default behavior. Decode the entire latent in parallel
682
+ return self.decode(latents, **kwargs)
683
+ else:
684
+ # chunked decoding
685
+ hop_size = chunk_size - overlap
686
+ total_size = latents.shape[2]
687
+ batch_size = latents.shape[0]
688
+ chunks = []
689
+ if total_size < chunk_size:
690
+ # pad the latents to be at least chunk_size
691
+ # Note: if we pad here, the song generated afterwards turns into noise
692
+ pad_size = chunk_size - total_size + 1
693
+ latents = F.pad(latents, (0, pad_size), mode="replicate")
694
+ total_size = latents.shape[2]
695
+ # import pdb; pdb.set_trace()
696
+ for i in range(0, total_size - chunk_size + 1, hop_size):
697
+ chunk = latents[:, :, i : i + chunk_size]
698
+ chunks.append(chunk)
699
+ if i + chunk_size != total_size:
700
+ # Final chunk
701
+ chunk = latents[:, :, -chunk_size:]
702
+ chunks.append(chunk)
703
+ chunks = torch.stack(chunks)
704
+ num_chunks = chunks.shape[0]
705
+ # samples_per_latent is just the downsampling ratio
706
+ samples_per_latent = self.downsampling_ratio
707
+ # Create an empty waveform; we will populate it with chunks as we decode them
708
+ y_size = total_size * samples_per_latent
709
+ y_final = torch.zeros((batch_size, self.out_channels, y_size)).to(
710
+ latents.device
711
+ )
712
+ for i in range(num_chunks):
713
+ x_chunk = chunks[i, :]
714
+ # decode the chunk
715
+ y_chunk = self.decode(x_chunk)
716
+ # figure out where to put the audio along the time domain
717
+ if i == num_chunks - 1:
718
+ # final chunk always goes at the end
719
+ t_end = y_size
720
+ t_start = t_end - y_chunk.shape[2]
721
+ else:
722
+ t_start = i * hop_size * samples_per_latent
723
+ t_end = t_start + chunk_size * samples_per_latent
724
+ # remove the edges of the overlaps
725
+ ol = (overlap // 2) * samples_per_latent
726
+ chunk_start = 0
727
+ chunk_end = y_chunk.shape[2]
728
+ if i > 0:
729
+ # no overlap for the start of the first chunk
730
+ t_start += ol
731
+ chunk_start += ol
732
+ if i < num_chunks - 1:
733
+ # no overlap for the end of the last chunk
734
+ t_end -= ol
735
+ chunk_end -= ol
736
+ # paste the chunked audio into our y_final output audio
737
+ y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
738
+ return y_final
739
+
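+ # Minimal usage sketch (illustrative only; assumes a valid model config dict and a preprocessed batch):
+ # ae = create_autoencoder_from_config(config)
+ # latents = ae.encode_audio(batch, chunked=True, overlap=32, chunk_size=128)
+ # audio = ae.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)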
740
+
741
+ class DiffusionAutoencoder(AudioAutoencoder):
742
+ def __init__(
743
+ self,
744
+ diffusion: ConditionedDiffusionModel,
745
+ diffusion_downsampling_ratio,
746
+ *args,
747
+ **kwargs,
748
+ ):
749
+ super().__init__(*args, **kwargs)
750
+
751
+ self.diffusion = diffusion
752
+
753
+ self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
754
+
755
+ if self.encoder is not None:
756
+ # Shrink the initial encoder parameters to avoid saturated latents
757
+ with torch.no_grad():
758
+ for param in self.encoder.parameters():
759
+ param *= 0.5
760
+
761
+ def decode(self, latents, steps=100):
762
+ upsampled_length = latents.shape[2] * self.downsampling_ratio
763
+
764
+ if self.bottleneck is not None:
765
+ latents = self.bottleneck.decode(latents)
766
+
767
+ if self.decoder is not None:
768
+ latents = self.decoder(latents)
769
+
770
+ # Upsample latents to match diffusion length
771
+ if latents.shape[2] != upsampled_length:
772
+ latents = F.interpolate(latents, size=upsampled_length, mode="nearest")
773
+
774
+ noise = torch.randn(
775
+ latents.shape[0], self.io_channels, upsampled_length, device=latents.device
776
+ )
777
+ decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
778
+
779
+ if self.pretransform is not None:
780
+ if self.pretransform.enable_grad:
781
+ decoded = self.pretransform.decode(decoded)
782
+ else:
783
+ with torch.no_grad():
784
+ decoded = self.pretransform.decode(decoded)
785
+
786
+ return decoded
787
+
788
+
789
+ # AE factories
790
+
791
+
792
+ def create_encoder_from_config(encoder_config: Dict[str, Any]):
793
+ encoder_type = encoder_config.get("type", None)
794
+ assert encoder_type is not None, "Encoder type must be specified"
795
+
796
+ if encoder_type == "oobleck":
797
+ encoder = OobleckEncoder(**encoder_config["config"])
798
+
799
+ elif encoder_type == "seanet":
800
+ from encodec.modules import SEANetEncoder
801
+
802
+ seanet_encoder_config = encoder_config["config"]
803
+
804
+ # SEANet encoder expects strides in reverse order
805
+ seanet_encoder_config["ratios"] = list(
806
+ reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2]))
807
+ )
808
+ encoder = SEANetEncoder(**seanet_encoder_config)
809
+ elif encoder_type == "dac":
810
+ dac_config = encoder_config["config"]
811
+
812
+ encoder = DACEncoderWrapper(**dac_config)
813
+ elif encoder_type == "local_attn":
814
+ from .local_attention import TransformerEncoder1D
815
+
816
+ local_attn_config = encoder_config["config"]
817
+
818
+ encoder = TransformerEncoder1D(**local_attn_config)
819
+ else:
820
+ raise ValueError(f"Unknown encoder type {encoder_type}")
821
+
822
+ requires_grad = encoder_config.get("requires_grad", True)
823
+ if not requires_grad:
824
+ for param in encoder.parameters():
825
+ param.requires_grad = False
826
+
827
+ return encoder
828
+
829
+
830
+ def create_decoder_from_config(decoder_config: Dict[str, Any]):
831
+ decoder_type = decoder_config.get("type", None)
832
+ assert decoder_type is not None, "Decoder type must be specified"
833
+
834
+ if decoder_type == "oobleck":
835
+ decoder = OobleckDecoder(**decoder_config["config"])
836
+ elif decoder_type == "seanet":
837
+ from encodec.modules import SEANetDecoder
838
+
839
+ decoder = SEANetDecoder(**decoder_config["config"])
840
+ elif decoder_type == "dac":
841
+ dac_config = decoder_config["config"]
842
+
843
+ decoder = DACDecoderWrapper(**dac_config)
844
+ elif decoder_type == "local_attn":
845
+ from .local_attention import TransformerDecoder1D
846
+
847
+ local_attn_config = decoder_config["config"]
848
+
849
+ decoder = TransformerDecoder1D(**local_attn_config)
850
+ else:
851
+ raise ValueError(f"Unknown decoder type {decoder_type}")
852
+
853
+ requires_grad = decoder_config.get("requires_grad", True)
854
+ if not requires_grad:
855
+ for param in decoder.parameters():
856
+ param.requires_grad = False
857
+
858
+ return decoder
859
+
860
+
861
+ def create_autoencoder_from_config(config: Dict[str, Any]):
862
+ ae_config = config["model"]
863
+
864
+ encoder = create_encoder_from_config(ae_config["encoder"])
865
+ decoder = create_decoder_from_config(ae_config["decoder"])
866
+
867
+ bottleneck = ae_config.get("bottleneck", None)
868
+
869
+ latent_dim = ae_config.get("latent_dim", None)
870
+ assert latent_dim is not None, "latent_dim must be specified in model config"
871
+ downsampling_ratio = ae_config.get("downsampling_ratio", None)
872
+ assert downsampling_ratio is not None, (
873
+ "downsampling_ratio must be specified in model config"
874
+ )
875
+ io_channels = ae_config.get("io_channels", None)
876
+ assert io_channels is not None, "io_channels must be specified in model config"
877
+ sample_rate = config.get("sample_rate", None)
878
+ assert sample_rate is not None, "sample_rate must be specified in model config"
879
+
880
+ in_channels = ae_config.get("in_channels", None)
881
+ out_channels = ae_config.get("out_channels", None)
882
+
883
+ pretransform = ae_config.get("pretransform", None)
884
+
885
+ if pretransform is not None:
886
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
887
+
888
+ if bottleneck is not None:
889
+ bottleneck = create_bottleneck_from_config(bottleneck)
890
+
891
+ soft_clip = ae_config["decoder"].get("soft_clip", False)
892
+
893
+ return AudioAutoencoder(
894
+ encoder,
895
+ decoder,
896
+ io_channels=io_channels,
897
+ latent_dim=latent_dim,
898
+ downsampling_ratio=downsampling_ratio,
899
+ sample_rate=sample_rate,
900
+ bottleneck=bottleneck,
901
+ pretransform=pretransform,
902
+ in_channels=in_channels,
903
+ out_channels=out_channels,
904
+ soft_clip=soft_clip,
905
+ )
906
+
907
+
908
+ def create_diffAE_from_config(config: Dict[str, Any]):
909
+ diffae_config = config["model"]
910
+
911
+ if "encoder" in diffae_config:
912
+ encoder = create_encoder_from_config(diffae_config["encoder"])
913
+ else:
914
+ encoder = None
915
+
916
+ if "decoder" in diffae_config:
917
+ decoder = create_decoder_from_config(diffae_config["decoder"])
918
+ else:
919
+ decoder = None
920
+
921
+ diffusion_model_type = diffae_config["diffusion"]["type"]
922
+
923
+ if diffusion_model_type == "DAU1d":
924
+ diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
925
+ elif diffusion_model_type == "adp_1d":
926
+ diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
927
+ elif diffusion_model_type == "dit":
928
+ diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
929
+
930
+ latent_dim = diffae_config.get("latent_dim", None)
931
+ assert latent_dim is not None, "latent_dim must be specified in model config"
932
+ downsampling_ratio = diffae_config.get("downsampling_ratio", None)
933
+ assert downsampling_ratio is not None, (
934
+ "downsampling_ratio must be specified in model config"
935
+ )
936
+ io_channels = diffae_config.get("io_channels", None)
937
+ assert io_channels is not None, "io_channels must be specified in model config"
938
+ sample_rate = config.get("sample_rate", None)
939
+ assert sample_rate is not None, "sample_rate must be specified in model config"
940
+
941
+ bottleneck = diffae_config.get("bottleneck", None)
942
+
943
+ pretransform = diffae_config.get("pretransform", None)
944
+
945
+ if pretransform is not None:
946
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
947
+
948
+ if bottleneck is not None:
949
+ bottleneck = create_bottleneck_from_config(bottleneck)
950
+
951
+ diffusion_downsampling_ratio = None
952
+
953
+ if diffusion_model_type == "DAU1d":
954
+ diffusion_downsampling_ratio = np.prod(
955
+ diffae_config["diffusion"]["config"]["strides"]
956
+ )
957
+ elif diffusion_model_type == "adp_1d":
958
+ diffusion_downsampling_ratio = np.prod(
959
+ diffae_config["diffusion"]["config"]["factors"]
960
+ )
961
+ elif diffusion_model_type == "dit":
962
+ diffusion_downsampling_ratio = 1
963
+
964
+ return DiffusionAutoencoder(
965
+ encoder=encoder,
966
+ decoder=decoder,
967
+ diffusion=diffusion,
968
+ io_channels=io_channels,
969
+ sample_rate=sample_rate,
970
+ latent_dim=latent_dim,
971
+ downsampling_ratio=downsampling_ratio,
972
+ diffusion_downsampling_ratio=diffusion_downsampling_ratio,
973
+ bottleneck=bottleneck,
974
+ pretransform=pretransform,
975
+ )
src/YingMusicSinger/utils/stable_audio_tools/blocks.py ADDED
@@ -0,0 +1,398 @@
1
+ import math
2
+ from functools import reduce
3
+
4
+ import numpy as np
5
+ import torch
6
+ from dac.nn.layers import Snake1d
7
+ from packaging import version
8
+ from torch import nn
9
+ from torch.backends.cuda import sdp_kernel
10
+ from torch.nn import functional as F
11
+
12
+
13
+ class ResidualBlock(nn.Module):
14
+ def __init__(self, main, skip=None):
15
+ super().__init__()
16
+ self.main = nn.Sequential(*main)
17
+ self.skip = skip if skip else nn.Identity()
18
+
19
+ def forward(self, input):
20
+ return self.main(input) + self.skip(input)
21
+
22
+
23
+ class ResConvBlock(ResidualBlock):
24
+ def __init__(
25
+ self,
26
+ c_in,
27
+ c_mid,
28
+ c_out,
29
+ is_last=False,
30
+ kernel_size=5,
31
+ conv_bias=True,
32
+ use_snake=False,
33
+ ):
34
+ skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
35
+ super().__init__(
36
+ [
37
+ nn.Conv1d(
38
+ c_in, c_mid, kernel_size, padding=kernel_size // 2, bias=conv_bias
39
+ ),
40
+ nn.GroupNorm(1, c_mid),
41
+ Snake1d(c_mid) if use_snake else nn.GELU(),
42
+ nn.Conv1d(
43
+ c_mid, c_out, kernel_size, padding=kernel_size // 2, bias=conv_bias
44
+ ),
45
+ nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
46
+ (Snake1d(c_out) if use_snake else nn.GELU())
47
+ if not is_last
48
+ else nn.Identity(),
49
+ ],
50
+ skip,
51
+ )
52
+
53
+
54
+ class SelfAttention1d(nn.Module):
55
+ def __init__(self, c_in, n_head=1, dropout_rate=0.0):
56
+ super().__init__()
57
+ assert c_in % n_head == 0
58
+ self.norm = nn.GroupNorm(1, c_in)
59
+ self.n_head = n_head
60
+ self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
61
+ self.out_proj = nn.Conv1d(c_in, c_in, 1)
62
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
63
+
64
+ self.use_flash = torch.cuda.is_available() and version.parse(
65
+ torch.__version__
66
+ ) >= version.parse("2.0.0")
67
+
68
+ if not self.use_flash:
69
+ return
70
+
71
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
72
+
73
+ if device_properties.major == 8 and device_properties.minor == 0:
74
+ # Use flash attention for A100 GPUs
75
+ self.sdp_kernel_config = (True, False, False)
76
+ else:
77
+ # Don't use flash attention for other GPUs
78
+ self.sdp_kernel_config = (False, True, True)
79
+
80
+ def forward(self, input):
81
+ n, c, s = input.shape
82
+ qkv = self.qkv_proj(self.norm(input))
83
+ qkv = qkv.view([n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
84
+ q, k, v = qkv.chunk(3, dim=1)
85
+ scale = k.shape[3] ** -0.25
86
+
87
+ if self.use_flash:
88
+ with sdp_kernel(*self.sdp_kernel_config):
89
+ y = (
90
+ F.scaled_dot_product_attention(q, k, v, is_causal=False)
91
+ .contiguous()
92
+ .view([n, c, s])
93
+ )
94
+ else:
95
+ att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
96
+ y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
97
+
98
+ return input + self.dropout(self.out_proj(y))
99
+
100
+
101
+ class SkipBlock(nn.Module):
102
+ def __init__(self, *main):
103
+ super().__init__()
104
+ self.main = nn.Sequential(*main)
105
+
106
+ def forward(self, input):
107
+ return torch.cat([self.main(input), input], dim=1)
108
+
109
+
110
+ class FourierFeatures(nn.Module):
111
+ def __init__(self, in_features, out_features, std=1.0):
112
+ super().__init__()
113
+ assert out_features % 2 == 0
114
+ self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)
115
+
116
+ def forward(self, input):
117
+ f = 2 * math.pi * input @ self.weight.T
118
+ return torch.cat([f.cos(), f.sin()], dim=-1)
119
+
120
+
121
+ def expand_to_planes(input, shape):
122
+ return input[..., None].repeat([1, 1, shape[2]])
123
+
124
+
125
+ _kernels = {
126
+ "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
127
+ "cubic": [
128
+ -0.01171875,
129
+ -0.03515625,
130
+ 0.11328125,
131
+ 0.43359375,
132
+ 0.43359375,
133
+ 0.11328125,
134
+ -0.03515625,
135
+ -0.01171875,
136
+ ],
137
+ "lanczos3": [
138
+ 0.003689131001010537,
139
+ 0.015056144446134567,
140
+ -0.03399861603975296,
141
+ -0.066637322306633,
142
+ 0.13550527393817902,
143
+ 0.44638532400131226,
144
+ 0.44638532400131226,
145
+ 0.13550527393817902,
146
+ -0.066637322306633,
147
+ -0.03399861603975296,
148
+ 0.015056144446134567,
149
+ 0.003689131001010537,
150
+ ],
151
+ }
152
+
153
+
154
+ class Downsample1d(nn.Module):
155
+ def __init__(self, kernel="linear", pad_mode="reflect", channels_last=False):
156
+ super().__init__()
157
+ self.pad_mode = pad_mode
158
+ kernel_1d = torch.tensor(_kernels[kernel])
159
+ self.pad = kernel_1d.shape[0] // 2 - 1
160
+ self.register_buffer("kernel", kernel_1d)
161
+ self.channels_last = channels_last
162
+
163
+ def forward(self, x):
164
+ if self.channels_last:
165
+ x = x.permute(0, 2, 1)
166
+ x = F.pad(x, (self.pad,) * 2, self.pad_mode)
167
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
168
+ indices = torch.arange(x.shape[1], device=x.device)
169
+ weight[indices, indices] = self.kernel.to(weight)
170
+ x = F.conv1d(x, weight, stride=2)
171
+ if self.channels_last:
172
+ x = x.permute(0, 2, 1)
173
+ return x
174
+
175
+
176
+ class Upsample1d(nn.Module):
177
+ def __init__(self, kernel="linear", pad_mode="reflect", channels_last=False):
178
+ super().__init__()
179
+ self.pad_mode = pad_mode
180
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
181
+ self.pad = kernel_1d.shape[0] // 2 - 1
182
+ self.register_buffer("kernel", kernel_1d)
183
+ self.channels_last = channels_last
184
+
185
+ def forward(self, x):
186
+ if self.channels_last:
187
+ x = x.permute(0, 2, 1)
188
+ x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
189
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
190
+ indices = torch.arange(x.shape[1], device=x.device)
191
+ weight[indices, indices] = self.kernel.to(weight)
192
+ x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
193
+ if self.channels_last:
194
+ x = x.permute(0, 2, 1)
195
+ return x
196
+
197
+
198
+ def Downsample1d_2(
199
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
200
+ ) -> nn.Module:
201
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
202
+
203
+ return nn.Conv1d(
204
+ in_channels=in_channels,
205
+ out_channels=out_channels,
206
+ kernel_size=factor * kernel_multiplier + 1,
207
+ stride=factor,
208
+ padding=factor * (kernel_multiplier // 2),
209
+ )
210
+
211
+
212
+ def Upsample1d_2(
213
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
214
+ ) -> nn.Module:
215
+ if factor == 1:
216
+ return nn.Conv1d(
217
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
218
+ )
219
+
220
+ if use_nearest:
221
+ return nn.Sequential(
222
+ nn.Upsample(scale_factor=factor, mode="nearest"),
223
+ nn.Conv1d(
224
+ in_channels=in_channels,
225
+ out_channels=out_channels,
226
+ kernel_size=3,
227
+ padding=1,
228
+ ),
229
+ )
230
+ else:
231
+ return nn.ConvTranspose1d(
232
+ in_channels=in_channels,
233
+ out_channels=out_channels,
234
+ kernel_size=factor * 2,
235
+ stride=factor,
236
+ padding=factor // 2 + factor % 2,
237
+ output_padding=factor % 2,
238
+ )
239
+
240
+
241
+ def zero_init(layer):
242
+ nn.init.zeros_(layer.weight)
243
+ if layer.bias is not None:
244
+ nn.init.zeros_(layer.bias)
245
+ return layer
246
+
247
+
248
+ def rms_norm(x, scale, eps):
249
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
250
+ mean_sq = torch.mean(x.to(dtype) ** 2, dim=-1, keepdim=True)
251
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
252
+ return x * scale.to(x.dtype)
253
+
254
+
255
+ # rms_norm = torch.compile(rms_norm)
256
+
257
+
258
+ class AdaRMSNorm(nn.Module):
259
+ def __init__(self, features, cond_features, eps=1e-6):
260
+ super().__init__()
261
+ self.eps = eps
262
+ self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
263
+
264
+ def extra_repr(self):
265
+ return f"eps={self.eps},"
266
+
267
+ def forward(self, x, cond):
268
+ return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
269
+
270
+
271
+ def normalize(x, eps=1e-4):
272
+ dim = list(range(1, x.ndim))
273
+ n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
274
+ alpha = np.sqrt(n.numel() / x.numel())
275
+ return x / torch.add(eps, n, alpha=alpha)
276
+
277
+
278
+ class ForcedWNConv1d(nn.Module):
279
+ def __init__(self, in_channels, out_channels, kernel_size=1):
280
+ super().__init__()
281
+ self.weight = nn.Parameter(
282
+ torch.randn([out_channels, in_channels, kernel_size])
283
+ )
284
+
285
+ def forward(self, x):
286
+ if self.training:
287
+ with torch.no_grad():
288
+ self.weight.copy_(normalize(self.weight))
289
+
290
+ fan_in = self.weight[0].numel()
291
+
292
+ w = normalize(self.weight) / math.sqrt(fan_in)
293
+
294
+ return F.conv1d(x, w, padding="same")
295
+
296
+
297
+ # Kernels
298
+
299
+ use_compile = True
300
+
301
+
302
+ def compile(function, *args, **kwargs):
303
+ if not use_compile:
304
+ return function
305
+ try:
306
+ return torch.compile(function, *args, **kwargs)
307
+ except RuntimeError:
308
+ return function
309
+
310
+
311
+ @compile
312
+ def linear_geglu(x, weight, bias=None):
313
+ x = x @ weight.mT
314
+ if bias is not None:
315
+ x = x + bias
316
+ x, gate = x.chunk(2, dim=-1)
317
+ return x * F.gelu(gate)
318
+
319
+
320
+ @compile
321
+ def rms_norm(x, scale, eps):
322
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
323
+ mean_sq = torch.mean(x.to(dtype) ** 2, dim=-1, keepdim=True)
324
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
325
+ return x * scale.to(x.dtype)
326
+
327
+
328
+ # Layers
329
+
330
+
331
+ class LinearGEGLU(nn.Linear):
332
+ def __init__(self, in_features, out_features, bias=True):
333
+ super().__init__(in_features, out_features * 2, bias=bias)
334
+ self.out_features = out_features
335
+
336
+ def forward(self, x):
337
+ return linear_geglu(x, self.weight, self.bias)
338
+
339
+
340
+ class RMSNorm(nn.Module):
341
+ def __init__(self, shape, fix_scale=False, eps=1e-6):
342
+ super().__init__()
343
+ self.eps = eps
344
+
345
+ if fix_scale:
346
+ self.register_buffer("scale", torch.ones(shape))
347
+ else:
348
+ self.scale = nn.Parameter(torch.ones(shape))
349
+
350
+ def extra_repr(self):
351
+ return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
352
+
353
+ def forward(self, x):
354
+ return rms_norm(x, self.scale, self.eps)
355
+
356
+
357
+ def snake_beta(x, alpha, beta):
358
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
359
+
360
+
361
+ # try:
362
+ # snake_beta = torch.compile(snake_beta)
363
+ # except RuntimeError:
364
+ # pass
365
+
366
+
367
+ # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
368
+ # License available in LICENSES/LICENSE_NVIDIA.txt
369
+ class SnakeBeta(nn.Module):
370
+ def __init__(
371
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
372
+ ):
373
+ super(SnakeBeta, self).__init__()
374
+ self.in_features = in_features
375
+
376
+ # initialize alpha
377
+ self.alpha_logscale = alpha_logscale
378
+ if self.alpha_logscale: # log scale alphas initialized to zeros
379
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
380
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
381
+ else: # linear scale alphas initialized to ones
382
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
383
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
384
+
385
+ self.alpha.requires_grad = alpha_trainable
386
+ self.beta.requires_grad = alpha_trainable
387
+
388
+ self.no_div_by_zero = 0.000000001
389
+
390
+ def forward(self, x):
391
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
392
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
393
+ if self.alpha_logscale:
394
+ alpha = torch.exp(alpha)
395
+ beta = torch.exp(beta)
396
+ x = snake_beta(x, alpha, beta)
397
+
398
+ return x
src/YingMusicSinger/utils/stable_audio_tools/bottleneck copy.py ADDED
@@ -0,0 +1,393 @@
1
+ import numpy as np
2
+ import torch
3
+ from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
4
+ from einops import rearrange
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from vector_quantize_pytorch import FSQ, ResidualVQ
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ def __init__(self, is_discrete: bool = False):
12
+ super().__init__()
13
+
14
+ self.is_discrete = is_discrete
15
+
16
+ def encode(self, x, return_info=False, **kwargs):
17
+ raise NotImplementedError
18
+
19
+ def decode(self, x):
20
+ raise NotImplementedError
21
+
22
+
23
+ class DiscreteBottleneck(Bottleneck):
24
+ def __init__(self, num_quantizers, codebook_size, tokens_id):
25
+ super().__init__(is_discrete=True)
26
+
27
+ self.num_quantizers = num_quantizers
28
+ self.codebook_size = codebook_size
29
+ self.tokens_id = tokens_id
30
+
31
+ def decode_tokens(self, codes, **kwargs):
32
+ raise NotImplementedError
33
+
34
+
35
+ class TanhBottleneck(Bottleneck):
36
+ def __init__(self):
37
+ super().__init__(is_discrete=False)
38
+ self.tanh = nn.Tanh()
39
+
40
+ def encode(self, x, return_info=False):
41
+ info = {}
42
+
43
+ x = torch.tanh(x)
44
+
45
+ if return_info:
46
+ return x, info
47
+ else:
48
+ return x
49
+
50
+ def decode(self, x):
51
+ return x
52
+
53
+
54
+ def vae_sample(mean, scale):
55
+ stdev = nn.functional.softplus(scale) + 1e-4
56
+ var = stdev * stdev
57
+ logvar = torch.log(var)
58
+ latents = torch.randn_like(mean) * stdev + mean
59
+
60
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
61
+
62
+ return latents, kl
63
+
64
+
65
+ class VAEBottleneck(Bottleneck):
66
+ def __init__(self):
67
+ super().__init__(is_discrete=False)
68
+
69
+ def encode(self, x, return_info=False, **kwargs):
70
+ info = {}
71
+
72
+ mean, scale = x.chunk(2, dim=1)
73
+
74
+ x, kl = vae_sample(mean, scale)
75
+
76
+ info["kl"] = kl
77
+
78
+ if return_info:
79
+ return x, info
80
+ else:
81
+ return x
82
+
83
+ def decode(self, x):
84
+ return x
85
+
86
+
87
+ def compute_mean_kernel(x, y):
88
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
89
+ return torch.exp(-kernel_input).mean()
90
+
91
+
92
+ def compute_mmd(latents):
93
+ latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
94
+ noise = torch.randn_like(latents_reshaped)
95
+
96
+ latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
97
+ noise_kernel = compute_mean_kernel(noise, noise)
98
+ latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
99
+
100
+ mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
101
+ return mmd.mean()
102
+
103
+
104
+ class WassersteinBottleneck(Bottleneck):
105
+ def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
106
+ super().__init__(is_discrete=False)
107
+
108
+ self.noise_augment_dim = noise_augment_dim
109
+ self.bypass_mmd = bypass_mmd
110
+
111
+ def encode(self, x, return_info=False):
112
+ info = {}
113
+
114
+ if self.training and return_info:
115
+ if self.bypass_mmd:
116
+ mmd = torch.tensor(0.0)
117
+ else:
118
+ mmd = compute_mmd(x)
119
+
120
+ info["mmd"] = mmd
121
+
122
+ if return_info:
123
+ return x, info
124
+
125
+ return x
126
+
127
+ def decode(self, x):
128
+ if self.noise_augment_dim > 0:
129
+ noise = torch.randn(
130
+ x.shape[0], self.noise_augment_dim, x.shape[-1]
131
+ ).type_as(x)
132
+ x = torch.cat([x, noise], dim=1)
133
+
134
+ return x
135
+
136
+
137
+ class L2Bottleneck(Bottleneck):
138
+ def __init__(self):
139
+ super().__init__(is_discrete=False)
140
+
141
+ def encode(self, x, return_info=False):
142
+ info = {}
143
+
144
+ x = F.normalize(x, dim=1)
145
+
146
+ if return_info:
147
+ return x, info
148
+ else:
149
+ return x
150
+
151
+ def decode(self, x):
152
+ return F.normalize(x, dim=1)
153
+
154
+
155
+ class RVQBottleneck(DiscreteBottleneck):
156
+ def __init__(self, **quantizer_kwargs):
157
+ super().__init__(
158
+ num_quantizers=quantizer_kwargs["num_quantizers"],
159
+ codebook_size=quantizer_kwargs["codebook_size"],
160
+ tokens_id="quantizer_indices",
161
+ )
162
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
163
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
164
+
165
+ def encode(self, x, return_info=False, **kwargs):
166
+ info = {}
167
+
168
+ x = rearrange(x, "b c n -> b n c")
169
+ x, indices, loss = self.quantizer(x)
170
+ x = rearrange(x, "b n c -> b c n")
171
+
172
+ info["quantizer_indices"] = indices
173
+ info["quantizer_loss"] = loss.mean()
174
+
175
+ if return_info:
176
+ return x, info
177
+ else:
178
+ return x
179
+
180
+ def decode(self, x):
181
+ return x
182
+
183
+ def decode_tokens(self, codes, **kwargs):
184
+ latents = self.quantizer.get_outputs_from_indices(codes)
185
+
186
+ return self.decode(latents, **kwargs)
187
+
188
+
189
+ class RVQVAEBottleneck(DiscreteBottleneck):
190
+ def __init__(self, **quantizer_kwargs):
191
+ super().__init__(
192
+ num_quantizers=quantizer_kwargs["num_quantizers"],
193
+ codebook_size=quantizer_kwargs["codebook_size"],
194
+ tokens_id="quantizer_indices",
195
+ )
196
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
197
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
198
+
199
+ def encode(self, x, return_info=False):
200
+ info = {}
201
+
202
+ x, kl = vae_sample(*x.chunk(2, dim=1))
203
+
204
+ info["kl"] = kl
205
+
206
+ x = rearrange(x, "b c n -> b n c")
207
+ x, indices, loss = self.quantizer(x)
208
+ x = rearrange(x, "b n c -> b c n")
209
+
210
+ info["quantizer_indices"] = indices
211
+ info["quantizer_loss"] = loss.mean()
212
+
213
+ if return_info:
214
+ return x, info
215
+ else:
216
+ return x
217
+
218
+ def decode(self, x):
219
+ return x
220
+
221
+ def decode_tokens(self, codes, **kwargs):
222
+ latents = self.quantizer.get_outputs_from_indices(codes)
223
+
224
+ return self.decode(latents, **kwargs)
225
+
226
+
227
+ class DACRVQBottleneck(DiscreteBottleneck):
228
+ def __init__(
229
+ self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs
230
+ ):
231
+ super().__init__(
232
+ num_quantizers=quantizer_kwargs["n_codebooks"],
233
+ codebook_size=quantizer_kwargs["codebook_size"],
234
+ tokens_id="codes",
235
+ )
236
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
237
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
238
+ self.quantize_on_decode = quantize_on_decode
239
+ self.noise_augment_dim = noise_augment_dim
240
+
241
+ def encode(self, x, return_info=False, **kwargs):
242
+ info = {}
243
+
244
+ info["pre_quantizer"] = x
245
+
246
+ if self.quantize_on_decode:
247
+ return x, info if return_info else x
248
+
249
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
250
+
251
+ output = {
252
+ "z": z,
253
+ "codes": codes,
254
+ "latents": latents,
255
+ "vq/commitment_loss": commitment_loss,
256
+ "vq/codebook_loss": codebook_loss,
257
+ }
258
+
259
+ output["vq/commitment_loss"] /= self.num_quantizers
260
+ output["vq/codebook_loss"] /= self.num_quantizers
261
+
262
+ info.update(output)
263
+
264
+ if return_info:
265
+ return output["z"], info
266
+
267
+ return output["z"]
268
+
269
+ def decode(self, x):
270
+ if self.quantize_on_decode:
271
+ x = self.quantizer(x)[0]
272
+
273
+ if self.noise_augment_dim > 0:
274
+ noise = torch.randn(
275
+ x.shape[0], self.noise_augment_dim, x.shape[-1]
276
+ ).type_as(x)
277
+ x = torch.cat([x, noise], dim=1)
278
+
279
+ return x
280
+
281
+ def decode_tokens(self, codes, **kwargs):
282
+ latents, _, _ = self.quantizer.from_codes(codes)
283
+
284
+ return self.decode(latents, **kwargs)
285
+
286
+
287
+ class DACRVQVAEBottleneck(DiscreteBottleneck):
288
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
289
+ super().__init__(
290
+ num_quantizers=quantizer_kwargs["n_codebooks"],
291
+ codebook_size=quantizer_kwargs["codebook_size"],
292
+ tokens_id="codes",
293
+ )
294
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
295
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
296
+ self.quantize_on_decode = quantize_on_decode
297
+
298
+ def encode(self, x, return_info=False, n_quantizers: int = None):
299
+ info = {}
300
+
301
+ mean, scale = x.chunk(2, dim=1)
302
+
303
+ x, kl = vae_sample(mean, scale)
304
+
305
+ info["pre_quantizer"] = x
306
+ info["kl"] = kl
307
+
308
+ if self.quantize_on_decode:
309
+ return x, info if return_info else x
310
+
311
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
312
+ x, n_quantizers=n_quantizers
313
+ )
314
+
315
+ output = {
316
+ "z": z,
317
+ "codes": codes,
318
+ "latents": latents,
319
+ "vq/commitment_loss": commitment_loss,
320
+ "vq/codebook_loss": codebook_loss,
321
+ }
322
+
323
+ output["vq/commitment_loss"] /= self.num_quantizers
324
+ output["vq/codebook_loss"] /= self.num_quantizers
325
+
326
+ info.update(output)
327
+
328
+ if return_info:
329
+ return output["z"], info
330
+
331
+ return output["z"]
332
+
333
+ def decode(self, x):
334
+ if self.quantize_on_decode:
335
+ x = self.quantizer(x)[0]
336
+
337
+ return x
338
+
339
+ def decode_tokens(self, codes, **kwargs):
340
+ latents, _, _ = self.quantizer.from_codes(codes)
341
+
342
+ return self.decode(latents, **kwargs)
343
+
344
+
345
+ class FSQBottleneck(DiscreteBottleneck):
346
+ def __init__(self, noise_augment_dim=0, **kwargs):
347
+ super().__init__(
348
+ num_quantizers=kwargs.get("num_codebooks", 1),
349
+ codebook_size=np.prod(kwargs["levels"]),
350
+ tokens_id="quantizer_indices",
351
+ )
352
+
353
+ self.noise_augment_dim = noise_augment_dim
354
+
355
+ self.quantizer = FSQ(
356
+ **kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64]
357
+ )
358
+
359
+ def encode(self, x, return_info=False):
360
+ info = {}
361
+
362
+ orig_dtype = x.dtype
363
+ x = x.float()
364
+
365
+ x = rearrange(x, "b c n -> b n c")
366
+ x, indices = self.quantizer(x)
367
+ x = rearrange(x, "b n c -> b c n")
368
+
369
+ x = x.to(orig_dtype)
370
+
371
+ # Reorder indices to match the expected format
372
+ indices = rearrange(indices, "b n q -> b q n")
373
+
374
+ info["quantizer_indices"] = indices
375
+
376
+ if return_info:
377
+ return x, info
378
+ else:
379
+ return x
380
+
381
+ def decode(self, x):
382
+ if self.noise_augment_dim > 0:
383
+ noise = torch.randn(
384
+ x.shape[0], self.noise_augment_dim, x.shape[-1]
385
+ ).type_as(x)
386
+ x = torch.cat([x, noise], dim=1)
387
+
388
+ return x
389
+
390
+ def decode_tokens(self, tokens, **kwargs):
391
+ latents = self.quantizer.indices_to_codes(tokens)
392
+
393
+ return self.decode(latents, **kwargs)
src/YingMusicSinger/utils/stable_audio_tools/bottleneck.py ADDED
@@ -0,0 +1,393 @@
1
+ import numpy as np
2
+ import torch
3
+ from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
4
+ from einops import rearrange
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from vector_quantize_pytorch import FSQ, ResidualVQ
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ def __init__(self, is_discrete: bool = False):
12
+ super().__init__()
13
+
14
+ self.is_discrete = is_discrete
15
+
16
+ def encode(self, x, return_info=False, **kwargs):
17
+ raise NotImplementedError
18
+
19
+ def decode(self, x):
20
+ raise NotImplementedError
21
+
22
+
23
+ class DiscreteBottleneck(Bottleneck):
24
+ def __init__(self, num_quantizers, codebook_size, tokens_id):
25
+ super().__init__(is_discrete=True)
26
+
27
+ self.num_quantizers = num_quantizers
28
+ self.codebook_size = codebook_size
29
+ self.tokens_id = tokens_id
30
+
31
+ def decode_tokens(self, codes, **kwargs):
32
+ raise NotImplementedError
33
+
34
+
35
+ class TanhBottleneck(Bottleneck):
36
+ def __init__(self):
37
+ super().__init__(is_discrete=False)
38
+ self.tanh = nn.Tanh()
39
+
40
+ def encode(self, x, return_info=False):
41
+ info = {}
42
+
43
+ x = torch.tanh(x)
44
+
45
+ if return_info:
46
+ return x, info
47
+ else:
48
+ return x
49
+
50
+ def decode(self, x):
51
+ return x
52
+
53
+
54
+ def vae_sample(mean, scale):
55
+ stdev = nn.functional.softplus(scale) + 1e-4
56
+ var = stdev * stdev
57
+ logvar = torch.log(var)
58
+ latents = torch.randn_like(mean) * stdev + mean
59
+
60
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
61
+
62
+ return latents, kl
63
+
64
+
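The vae_sample helper above draws latents with the reparameterization trick and returns a KL penalty against a standard normal prior. For reference, the closed-form KL it corresponds to is D_KL(N(mu, sigma^2) || N(0, 1)) = 1/2 * sum_i (mu_i^2 + sigma_i^2 - log sigma_i^2 - 1), summed over channels and averaged over the batch; the code keeps the same expression but omits the constant 1/2, which only rescales the weight of the KL term.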
65
+ class VAEBottleneck(Bottleneck):
66
+ def __init__(self):
67
+ super().__init__(is_discrete=False)
68
+
69
+ def encode(self, x, return_info=False, **kwargs):
70
+ info = {}
71
+
72
+ mean, scale = x.chunk(2, dim=1)
73
+
74
+ x, kl = vae_sample(mean, scale)
75
+
76
+ info["kl"] = kl
77
+
78
+ if return_info:
79
+ return x, info
80
+ else:
81
+ return x
82
+
83
+ def decode(self, x):
84
+ return x
85
+
86
+
87
+ def compute_mean_kernel(x, y):
88
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
89
+ return torch.exp(-kernel_input).mean()
90
+
91
+
92
+ def compute_mmd(latents):
93
+ latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
94
+ noise = torch.randn_like(latents_reshaped)
95
+
96
+ latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
97
+ noise_kernel = compute_mean_kernel(noise, noise)
98
+ latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
99
+
100
+ mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
101
+ return mmd.mean()
102
+
103
+
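compute_mmd above estimates the (biased) maximum mean discrepancy between the batch of latents and samples from a standard normal, using an RBF-style kernel: MMD^2 ~ E[k(z, z')] + E[k(n, n')] - 2 E[k(z, n)], where z are the flattened latent vectors and n are Gaussian noise vectors of the same shape. The WassersteinBottleneck below uses this as a training-time regularizer (reported under the "mmd" key of the info dict) instead of a sampled KL term.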
104
+ class WassersteinBottleneck(Bottleneck):
105
+ def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
106
+ super().__init__(is_discrete=False)
107
+
108
+ self.noise_augment_dim = noise_augment_dim
109
+ self.bypass_mmd = bypass_mmd
110
+
111
+ def encode(self, x, return_info=False):
112
+ info = {}
113
+
114
+ if self.training and return_info:
115
+ if self.bypass_mmd:
116
+ mmd = torch.tensor(0.0)
117
+ else:
118
+ mmd = compute_mmd(x)
119
+
120
+ info["mmd"] = mmd
121
+
122
+ if return_info:
123
+ return x, info
124
+
125
+ return x
126
+
127
+ def decode(self, x):
128
+ if self.noise_augment_dim > 0:
129
+ noise = torch.randn(
130
+ x.shape[0], self.noise_augment_dim, x.shape[-1]
131
+ ).type_as(x)
132
+ x = torch.cat([x, noise], dim=1)
133
+
134
+ return x
135
+
136
+
137
+ class L2Bottleneck(Bottleneck):
138
+ def __init__(self):
139
+ super().__init__(is_discrete=False)
140
+
141
+ def encode(self, x, return_info=False):
142
+ info = {}
143
+
144
+ x = F.normalize(x, dim=1)
145
+
146
+ if return_info:
147
+ return x, info
148
+ else:
149
+ return x
150
+
151
+ def decode(self, x):
152
+ return F.normalize(x, dim=1)
153
+
154
+
155
+ class RVQBottleneck(DiscreteBottleneck):
156
+ def __init__(self, **quantizer_kwargs):
157
+ super().__init__(
158
+ num_quantizers=quantizer_kwargs["num_quantizers"],
159
+ codebook_size=quantizer_kwargs["codebook_size"],
160
+ tokens_id="quantizer_indices",
161
+ )
162
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
163
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
164
+
165
+ def encode(self, x, return_info=False, **kwargs):
166
+ info = {}
167
+
168
+ x = rearrange(x, "b c n -> b n c")
169
+ x, indices, loss = self.quantizer(x)
170
+ x = rearrange(x, "b n c -> b c n")
171
+
172
+ info["quantizer_indices"] = indices
173
+ info["quantizer_loss"] = loss.mean()
174
+
175
+ if return_info:
176
+ return x, info
177
+ else:
178
+ return x
179
+
180
+ def decode(self, x):
181
+ return x
182
+
183
+ def decode_tokens(self, codes, **kwargs):
184
+ latents = self.quantizer.get_outputs_from_indices(codes)
185
+
186
+ return self.decode(latents, **kwargs)
187
+
188
+
189
+ class RVQVAEBottleneck(DiscreteBottleneck):
190
+ def __init__(self, **quantizer_kwargs):
191
+ super().__init__(
192
+ num_quantizers=quantizer_kwargs["num_quantizers"],
193
+ codebook_size=quantizer_kwargs["codebook_size"],
194
+ tokens_id="quantizer_indices",
195
+ )
196
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
197
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
198
+
199
+ def encode(self, x, return_info=False):
200
+ info = {}
201
+
202
+ x, kl = vae_sample(*x.chunk(2, dim=1))
203
+
204
+ info["kl"] = kl
205
+
206
+ x = rearrange(x, "b c n -> b n c")
207
+ x, indices, loss = self.quantizer(x)
208
+ x = rearrange(x, "b n c -> b c n")
209
+
210
+ info["quantizer_indices"] = indices
211
+ info["quantizer_loss"] = loss.mean()
212
+
213
+ if return_info:
214
+ return x, info
215
+ else:
216
+ return x
217
+
218
+ def decode(self, x):
219
+ return x
220
+
221
+ def decode_tokens(self, codes, **kwargs):
222
+ latents = self.quantizer.get_outputs_from_indices(codes)
223
+
224
+ return self.decode(latents, **kwargs)
225
+
226
+
227
+ class DACRVQBottleneck(DiscreteBottleneck):
228
+ def __init__(
229
+ self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs
230
+ ):
231
+ super().__init__(
232
+ num_quantizers=quantizer_kwargs["n_codebooks"],
233
+ codebook_size=quantizer_kwargs["codebook_size"],
234
+ tokens_id="codes",
235
+ )
236
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
237
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
238
+ self.quantize_on_decode = quantize_on_decode
239
+ self.noise_augment_dim = noise_augment_dim
240
+
241
+ def encode(self, x, return_info=False, **kwargs):
242
+ info = {}
243
+
244
+ info["pre_quantizer"] = x
245
+
246
+ if self.quantize_on_decode:
247
+ return (x, info) if return_info else x
248
+
249
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
250
+
251
+ output = {
252
+ "z": z,
253
+ "codes": codes,
254
+ "latents": latents,
255
+ "vq/commitment_loss": commitment_loss,
256
+ "vq/codebook_loss": codebook_loss,
257
+ }
258
+
259
+ output["vq/commitment_loss"] /= self.num_quantizers
260
+ output["vq/codebook_loss"] /= self.num_quantizers
261
+
262
+ info.update(output)
263
+
264
+ if return_info:
265
+ return output["z"], info
266
+
267
+ return output["z"]
268
+
269
+ def decode(self, x):
270
+ if self.quantize_on_decode:
271
+ x = self.quantizer(x)[0]
272
+
273
+ if self.noise_augment_dim > 0:
274
+ noise = torch.randn(
275
+ x.shape[0], self.noise_augment_dim, x.shape[-1]
276
+ ).type_as(x)
277
+ x = torch.cat([x, noise], dim=1)
278
+
279
+ return x
280
+
281
+ def decode_tokens(self, codes, **kwargs):
282
+ latents, _, _ = self.quantizer.from_codes(codes)
283
+
284
+ return self.decode(latents, **kwargs)
285
+
286
+
287
+ class DACRVQVAEBottleneck(DiscreteBottleneck):
288
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
289
+ super().__init__(
290
+ num_quantizers=quantizer_kwargs["n_codebooks"],
291
+ codebook_size=quantizer_kwargs["codebook_size"],
292
+ tokens_id="codes",
293
+ )
294
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
295
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
296
+ self.quantize_on_decode = quantize_on_decode
297
+
298
+ def encode(self, x, return_info=False, n_quantizers: int = None):
299
+ info = {}
300
+
301
+ mean, scale = x.chunk(2, dim=1)
302
+
303
+ x, kl = vae_sample(mean, scale)
304
+
305
+ info["pre_quantizer"] = x
306
+ info["kl"] = kl
307
+
308
+ if self.quantize_on_decode:
309
+ return (x, info) if return_info else x
310
+
311
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
312
+ x, n_quantizers=n_quantizers
313
+ )
314
+
315
+ output = {
316
+ "z": z,
317
+ "codes": codes,
318
+ "latents": latents,
319
+ "vq/commitment_loss": commitment_loss,
320
+ "vq/codebook_loss": codebook_loss,
321
+ }
322
+
323
+ output["vq/commitment_loss"] /= self.num_quantizers
324
+ output["vq/codebook_loss"] /= self.num_quantizers
325
+
326
+ info.update(output)
327
+
328
+ if return_info:
329
+ return output["z"], info
330
+
331
+ return output["z"]
332
+
333
+ def decode(self, x):
334
+ if self.quantize_on_decode:
335
+ x = self.quantizer(x)[0]
336
+
337
+ return x
338
+
339
+ def decode_tokens(self, codes, **kwargs):
340
+ latents, _, _ = self.quantizer.from_codes(codes)
341
+
342
+ return self.decode(latents, **kwargs)
343
+
344
+
345
+ class FSQBottleneck(DiscreteBottleneck):
346
+ def __init__(self, noise_augment_dim=0, **kwargs):
347
+ super().__init__(
348
+ num_quantizers=kwargs.get("num_codebooks", 1),
349
+ codebook_size=np.prod(kwargs["levels"]),
350
+ tokens_id="quantizer_indices",
351
+ )
352
+
353
+ self.noise_augment_dim = noise_augment_dim
354
+
355
+ self.quantizer = FSQ(
356
+ **kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64]
357
+ )
358
+
359
+ def encode(self, x, return_info=False):
360
+ info = {}
361
+
362
+ orig_dtype = x.dtype
363
+ x = x.float()
364
+
365
+ x = rearrange(x, "b c n -> b n c")
366
+ x, indices = self.quantizer(x)
367
+ x = rearrange(x, "b n c -> b c n")
368
+
369
+ x = x.to(orig_dtype)
370
+
371
+ # Reorder indices to match the expected format
372
+ indices = rearrange(indices, "b n q -> b q n")
373
+
374
+ info["quantizer_indices"] = indices
375
+
376
+ if return_info:
377
+ return x, info
378
+ else:
379
+ return x
380
+
381
+ def decode(self, x):
382
+ if self.noise_augment_dim > 0:
383
+ noise = torch.randn(
384
+ x.shape[0], self.noise_augment_dim, x.shape[-1]
385
+ ).type_as(x)
386
+ x = torch.cat([x, noise], dim=1)
387
+
388
+ return x
389
+
390
+ def decode_tokens(self, tokens, **kwargs):
391
+ latents = self.quantizer.indices_to_codes(tokens)
392
+
393
+ return self.decode(latents, **kwargs)
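A minimal usage sketch for the continuous bottlenecks defined above; the import path, channel counts and sequence length are illustrative, and VAEBottleneck expects its input channels to split in half into mean and scale:

import torch
# assuming the src/ directory is on PYTHONPATH so the module imports as below
from YingMusicSinger.utils.stable_audio_tools.bottleneck import TanhBottleneck, VAEBottleneck

vae = VAEBottleneck()
x = torch.randn(2, 128, 256)                      # (batch, 2 * latent_dim, time)
latents, info = vae.encode(x, return_info=True)   # latents: (2, 64, 256)
print(latents.shape, float(info["kl"]))

tanh = TanhBottleneck()
y = tanh.decode(tanh.encode(torch.randn(2, 64, 256)))  # decode is the identity after tanh squashing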
src/YingMusicSinger/utils/stable_audio_tools/conditioners.py ADDED
@@ -0,0 +1,664 @@
1
+ # Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py
2
+
3
+ import gc
4
+ import logging
5
+ import string
6
+ import typing as tp
7
+ import warnings
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from .adp import NumberEmbedder
13
+
14
+ # from ..inference.utils import set_audio_channels
15
+ from .factory import create_pretransform_from_config
16
+ from .pretransforms import Pretransform
17
+
18
+ # from ..training.utils import copy_state_dict
19
+ from .utils import load_ckpt_state_dict
20
+
21
+
22
+ class Conditioner(nn.Module):
23
+ def __init__(self, dim: int, output_dim: int, project_out: bool = False):
24
+ super().__init__()
25
+
26
+ self.dim = dim
27
+ self.output_dim = output_dim
28
+ self.proj_out = (
29
+ nn.Linear(dim, output_dim)
30
+ if (dim != output_dim or project_out)
31
+ else nn.Identity()
32
+ )
33
+
34
+ def forward(self, x: tp.Any) -> tp.Any:
35
+ raise NotImplementedError()
36
+
37
+
38
+ class IntConditioner(Conditioner):
39
+ def __init__(self, output_dim: int, min_val: int = 0, max_val: int = 512):
40
+ super().__init__(output_dim, output_dim)
41
+
42
+ self.min_val = min_val
43
+ self.max_val = max_val
44
+ self.int_embedder = nn.Embedding(
45
+ max_val - min_val + 1, output_dim
46
+ ).requires_grad_(True)
47
+
48
+ def forward(self, ints: tp.List[int], device=None) -> tp.Any:
49
+ # self.int_embedder.to(device)
50
+
51
+ ints = torch.tensor(ints).to(device)
52
+ ints = ints.clamp(self.min_val, self.max_val)
53
+
54
+ int_embeds = self.int_embedder(ints).unsqueeze(1)
55
+
56
+ return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)]
57
+
58
+
59
+ class NumberConditioner(Conditioner):
60
+ """
61
+ Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
62
+ """
63
+
64
+ def __init__(self, output_dim: int, min_val: float = 0, max_val: float = 1):
65
+ super().__init__(output_dim, output_dim)
66
+
67
+ self.min_val = min_val
68
+ self.max_val = max_val
69
+
70
+ self.embedder = NumberEmbedder(features=output_dim)
71
+
72
+ def forward(self, floats: tp.List[float], device=None) -> tp.Any:
73
+ # Cast the inputs to floats
74
+ floats = [float(x) for x in floats]
75
+
76
+ floats = torch.tensor(floats).to(device)
77
+
78
+ floats = floats.clamp(self.min_val, self.max_val)
79
+
80
+ normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
81
+
82
+ # Cast floats to same type as embedder
83
+ embedder_dtype = next(self.embedder.parameters()).dtype
84
+ normalized_floats = normalized_floats.to(embedder_dtype)
85
+
86
+ float_embeds = self.embedder(normalized_floats).unsqueeze(1)
87
+
88
+ return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
89
+
90
+
91
+ class CLAPTextConditioner(Conditioner):
92
+ def __init__(
93
+ self,
94
+ output_dim: int,
95
+ clap_ckpt_path,
96
+ use_text_features=False,
97
+ feature_layer_ix: int = -1,
98
+ audio_model_type="HTSAT-base",
99
+ enable_fusion=True,
100
+ project_out: bool = False,
101
+ finetune: bool = False,
102
+ ):
103
+ super().__init__(
104
+ 768 if use_text_features else 512, output_dim, project_out=project_out
105
+ )
106
+
107
+ self.use_text_features = use_text_features
108
+ self.feature_layer_ix = feature_layer_ix
109
+ self.finetune = finetune
110
+
111
+ # Suppress logging from transformers
112
+ previous_level = logging.root.manager.disable
113
+ logging.disable(logging.ERROR)
114
+ with warnings.catch_warnings():
115
+ warnings.simplefilter("ignore")
116
+ try:
117
+ import laion_clap
118
+ from laion_clap.clap_module.factory import (
119
+ load_state_dict as clap_load_state_dict,
120
+ )
121
+
122
+ model = laion_clap.CLAP_Module(
123
+ enable_fusion=enable_fusion, amodel=audio_model_type, device="cpu"
124
+ )
125
+
126
+ if self.finetune:
127
+ self.model = model
128
+ else:
129
+ self.__dict__["model"] = model
130
+
131
+ state_dict = clap_load_state_dict(clap_ckpt_path)
132
+ self.model.model.load_state_dict(state_dict, strict=False)
133
+
134
+ if self.finetune:
135
+ self.model.model.text_branch.requires_grad_(True)
136
+ self.model.model.text_branch.train()
137
+ else:
138
+ self.model.model.text_branch.requires_grad_(False)
139
+ self.model.model.text_branch.eval()
140
+
141
+ finally:
142
+ logging.disable(previous_level)
143
+
144
+ del self.model.model.audio_branch
145
+
146
+ gc.collect()
147
+ torch.cuda.empty_cache()
148
+
149
+ def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
150
+ prompt_tokens = self.model.tokenizer(prompts)
151
+ attention_mask = prompt_tokens["attention_mask"].to(
152
+ device=device, non_blocking=True
153
+ )
154
+ prompt_features = self.model.model.text_branch(
155
+ input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True),
156
+ attention_mask=attention_mask,
157
+ output_hidden_states=True,
158
+ )["hidden_states"][layer_ix]
159
+
160
+ return prompt_features, attention_mask
161
+
162
+ def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
163
+ self.model.to(device)
164
+
165
+ if self.use_text_features:
166
+ if len(texts) == 1:
167
+ text_features, text_attention_mask = self.get_clap_features(
168
+ [texts[0], ""], layer_ix=self.feature_layer_ix, device=device
169
+ )
170
+ text_features = text_features[:1, ...]
171
+ text_attention_mask = text_attention_mask[:1, ...]
172
+ else:
173
+ text_features, text_attention_mask = self.get_clap_features(
174
+ texts, layer_ix=self.feature_layer_ix, device=device
175
+ )
176
+ return [self.proj_out(text_features), text_attention_mask]
177
+
178
+ # Fix for CLAP bug when only one text is passed
179
+ if len(texts) == 1:
180
+ text_embedding = self.model.get_text_embedding(
181
+ [texts[0], ""], use_tensor=True
182
+ )[:1, ...]
183
+ else:
184
+ text_embedding = self.model.get_text_embedding(texts, use_tensor=True)
185
+
186
+ text_embedding = text_embedding.unsqueeze(1).to(device)
187
+
188
+ return [
189
+ self.proj_out(text_embedding),
190
+ torch.ones(text_embedding.shape[0], 1).to(device),
191
+ ]
192
+
193
+
194
+ class CLAPAudioConditioner(Conditioner):
195
+ def __init__(
196
+ self,
197
+ output_dim: int,
198
+ clap_ckpt_path,
199
+ audio_model_type="HTSAT-base",
200
+ enable_fusion=True,
201
+ project_out: bool = False,
+ finetune: bool = False,
+ ):
+ super().__init__(512, output_dim, project_out=project_out)
+ self.finetune = finetune
204
+
205
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
206
+
207
+ # Suppress logging from transformers
208
+ previous_level = logging.root.manager.disable
209
+ logging.disable(logging.ERROR)
210
+ with warnings.catch_warnings():
211
+ warnings.simplefilter("ignore")
212
+ try:
213
+ import laion_clap
214
+ from laion_clap.clap_module.factory import (
215
+ load_state_dict as clap_load_state_dict,
216
+ )
217
+
218
+ model = laion_clap.CLAP_Module(
219
+ enable_fusion=enable_fusion, amodel=audio_model_type, device="cpu"
220
+ )
221
+
222
+ if self.finetune:
223
+ self.model = model
224
+ else:
225
+ self.__dict__["model"] = model
226
+
227
+ state_dict = clap_load_state_dict(clap_ckpt_path)
228
+ self.model.model.load_state_dict(state_dict, strict=False)
229
+
230
+ if self.finetune:
231
+ self.model.model.audio_branch.requires_grad_(True)
232
+ self.model.model.audio_branch.train()
233
+ else:
234
+ self.model.model.audio_branch.requires_grad_(False)
235
+ self.model.model.audio_branch.eval()
236
+
237
+ finally:
238
+ logging.disable(previous_level)
239
+
240
+ del self.model.model.text_branch
241
+
242
+ gc.collect()
243
+ torch.cuda.empty_cache()
244
+
245
+ def forward(
246
+ self,
247
+ audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]],
248
+ device: tp.Any = "cuda",
249
+ ) -> tp.Any:
250
+ self.model.to(device)
251
+
252
+ if isinstance(audios, list) or isinstance(audios, tuple):
253
+ audios = torch.cat(audios, dim=0)
254
+
255
+ # Convert to mono
256
+ mono_audios = audios.mean(dim=1)
257
+
258
+ with torch.cuda.amp.autocast(enabled=False):
259
+ audio_embedding = self.model.get_audio_embedding_from_data(
260
+ mono_audios.float(), use_tensor=True
261
+ )
262
+
263
+ audio_embedding = audio_embedding.unsqueeze(1).to(device)
264
+
265
+ return [
266
+ self.proj_out(audio_embedding),
267
+ torch.ones(audio_embedding.shape[0], 1).to(device),
268
+ ]
269
+
270
+
271
+ class T5Conditioner(Conditioner):
272
+ T5_MODELS = [
273
+ "t5-small",
274
+ "t5-base",
275
+ "t5-large",
276
+ "t5-3b",
277
+ "t5-11b",
278
+ "google/flan-t5-small",
279
+ "google/flan-t5-base",
280
+ "google/flan-t5-large",
281
+ "google/flan-t5-xl",
282
+ "google/flan-t5-xxl",
283
+ ]
284
+
285
+ T5_MODEL_DIMS = {
286
+ "t5-small": 512,
287
+ "t5-base": 768,
288
+ "t5-large": 1024,
289
+ "t5-3b": 1024,
290
+ "t5-11b": 1024,
291
+ "t5-xl": 2048,
292
+ "t5-xxl": 4096,
293
+ "google/flan-t5-small": 512,
294
+ "google/flan-t5-base": 768,
295
+ "google/flan-t5-large": 1024,
296
+ "google/flan-t5-3b": 1024,
297
+ "google/flan-t5-11b": 1024,
298
+ "google/flan-t5-xl": 2048,
299
+ "google/flan-t5-xxl": 4096,
300
+ }
301
+
302
+ def __init__(
303
+ self,
304
+ output_dim: int,
305
+ t5_model_name: str = "t5-base",
306
+ max_length: str = 128,
307
+ enable_grad: bool = False,
308
+ project_out: bool = False,
309
+ ):
310
+ assert t5_model_name in self.T5_MODELS, (
311
+ f"Unknown T5 model name: {t5_model_name}"
312
+ )
313
+ super().__init__(
314
+ self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out
315
+ )
316
+
317
+ from transformers import AutoTokenizer, T5EncoderModel
318
+
319
+ self.max_length = max_length
320
+ self.enable_grad = enable_grad
321
+
322
+ # Suppress logging from transformers
323
+ previous_level = logging.root.manager.disable
324
+ logging.disable(logging.ERROR)
325
+ with warnings.catch_warnings():
326
+ warnings.simplefilter("ignore")
327
+ try:
328
+ # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length)
329
+ # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad)
330
+ self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
331
+ model = (
332
+ T5EncoderModel.from_pretrained(t5_model_name)
333
+ .train(enable_grad)
334
+ .requires_grad_(enable_grad)
335
+ .to(torch.float16)
336
+ )
337
+ finally:
338
+ logging.disable(previous_level)
339
+
340
+ if self.enable_grad:
341
+ self.model = model
342
+ else:
343
+ self.__dict__["model"] = model
344
+
345
+ def forward(
346
+ self, texts: tp.List[str], device: tp.Union[torch.device, str]
347
+ ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
348
+ self.model.to(device)
349
+ self.proj_out.to(device)
350
+
351
+ encoded = self.tokenizer(
352
+ texts,
353
+ truncation=True,
354
+ max_length=self.max_length,
355
+ padding="max_length",
356
+ return_tensors="pt",
357
+ )
358
+
359
+ input_ids = encoded["input_ids"].to(device)
360
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
361
+
362
+ self.model.eval()
363
+
364
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(
+ self.enable_grad
+ ):
367
+ embeddings = self.model(input_ids=input_ids, attention_mask=attention_mask)[
368
+ "last_hidden_state"
369
+ ]
370
+
371
+ embeddings = self.proj_out(embeddings.float())
372
+
373
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
374
+
375
+ return embeddings, attention_mask
376
+
377
+
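A brief usage sketch for T5Conditioner; the model name, batch and dimensions are illustrative, and since the encoder is cast to float16 in __init__, a CUDA device is assumed here:

import torch

t5 = T5Conditioner(output_dim=768, t5_model_name="t5-base", max_length=64)
embeddings, mask = t5(["warm analog synth pad", "soft female vocal"], device="cuda")
# embeddings: (2, 64, 768), already zeroed where the mask is False
# mask: (2, 64) boolean attention mask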
378
+ class PhonemeConditioner(Conditioner):
379
+ """
380
+ A conditioner that turns text into phonemes and embeds them using a lookup table
381
+ Only works for English text
382
+
383
+ Args:
384
+ output_dim: the dimension of the output embeddings
385
+ max_length: the maximum number of phonemes to embed
386
+ project_out: whether to add another linear projection to the output embeddings
387
+ """
388
+
389
+ def __init__(
390
+ self,
391
+ output_dim: int,
392
+ max_length: int = 1024,
393
+ project_out: bool = False,
394
+ ):
395
+ super().__init__(output_dim, output_dim, project_out=project_out)
396
+
397
+ from g2p_en import G2p
398
+
399
+ self.max_length = max_length
400
+
401
+ self.g2p = G2p()
402
+
403
+ # Reserving 0 for padding, 1 for ignored
404
+ self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim)
405
+
406
+ def forward(
407
+ self, texts: tp.List[str], device: tp.Union[torch.device, str]
408
+ ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
409
+ self.phoneme_embedder.to(device)
410
+ self.proj_out.to(device)
411
+
412
+ batch_phonemes = [
413
+ self.g2p(text) for text in texts
414
+ ] # shape [batch_size, length]
415
+
416
+ phoneme_ignore = [" ", *string.punctuation]
417
+
418
+ # Remove ignored phonemes and cut to max length
419
+ batch_phonemes = [
420
+ [p if p not in phoneme_ignore else "_" for p in phonemes]
421
+ for phonemes in batch_phonemes
422
+ ]
423
+
424
+ # Convert to ids
425
+ phoneme_ids = [
426
+ [self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes]
427
+ for phonemes in batch_phonemes
428
+ ]
429
+
430
+ # Pad to match longest and make a mask tensor for the padding
431
+ longest = max([len(ids) for ids in phoneme_ids])
432
+ phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids]
433
+
434
+ phoneme_ids = torch.tensor(phoneme_ids).to(device)
435
+
436
+ # Convert to embeddings
437
+ phoneme_embeds = self.phoneme_embedder(phoneme_ids)
438
+
439
+ phoneme_embeds = self.proj_out(phoneme_embeds)
440
+
441
+ return phoneme_embeds, torch.ones(
442
+ phoneme_embeds.shape[0], phoneme_embeds.shape[1]
443
+ ).to(device)
444
+
445
+
446
+ class TokenizerLUTConditioner(Conditioner):
447
+ """
448
+ A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary
449
+
450
+ Args:
451
+ tokenizer_name: the name of the tokenizer from the Hugging Face transformers library
452
+ output_dim: the dimension of the output embeddings
453
+ max_length: the maximum length of the text to embed
454
+ project_out: whether to add another linear projection to the output embeddings
455
+ """
456
+
457
+ def __init__(
458
+ self,
459
+ tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library
460
+ output_dim: int,
461
+ max_length: int = 1024,
462
+ project_out: bool = False,
463
+ ):
464
+ super().__init__(output_dim, output_dim, project_out=project_out)
465
+
466
+ from transformers import AutoTokenizer
467
+
468
+ # Suppress logging from transformers
469
+ previous_level = logging.root.manager.disable
470
+ logging.disable(logging.ERROR)
471
+ with warnings.catch_warnings():
472
+ warnings.simplefilter("ignore")
473
+ try:
474
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
475
+ finally:
476
+ logging.disable(previous_level)
477
+
478
+ self.max_length = max_length
479
+
480
+ self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim)
481
+
482
+ def forward(
483
+ self, texts: tp.List[str], device: tp.Union[torch.device, str]
484
+ ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
485
+ self.proj_out.to(device)
486
+
487
+ encoded = self.tokenizer(
488
+ texts,
489
+ truncation=True,
490
+ max_length=self.max_length,
491
+ padding="max_length",
492
+ return_tensors="pt",
493
+ )
494
+
495
+ input_ids = encoded["input_ids"].to(device)
496
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
497
+
498
+ embeddings = self.token_embedder(input_ids)
499
+
500
+ embeddings = self.proj_out(embeddings)
501
+
502
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
503
+
504
+ return embeddings, attention_mask
505
+
506
+
507
+ class PretransformConditioner(Conditioner):
508
+ """
509
+ A conditioner that uses a pretransform's encoder for conditioning
510
+
511
+ Args:
512
+ pretransform: an instantiated pretransform to use for conditioning
513
+ output_dim: the dimension of the output embeddings
514
+ """
515
+
516
+ def __init__(self, pretransform: Pretransform, output_dim: int):
517
+ super().__init__(pretransform.encoded_channels, output_dim)
518
+
519
+ self.pretransform = pretransform
520
+
521
+ def forward(
522
+ self,
523
+ audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]],
524
+ device: tp.Union[torch.device, str],
525
+ ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
526
+ self.pretransform.to(device)
527
+ self.proj_out.to(device)
528
+
529
+ if isinstance(audio, list) or isinstance(audio, tuple):
530
+ audio = torch.cat(audio, dim=0)
531
+
532
+ # Convert audio to pretransform input channels
533
+ audio = set_audio_channels(audio, self.pretransform.io_channels)
534
+
535
+ latents = self.pretransform.encode(audio)
536
+
537
+ latents = self.proj_out(latents)
538
+
539
+ return [
540
+ latents,
541
+ torch.ones(latents.shape[0], latents.shape[2]).to(latents.device),
542
+ ]
543
+
544
+
545
+ class MultiConditioner(nn.Module):
546
+ """
547
+ A module that applies multiple conditioners to an input dictionary based on the keys
548
+
549
+ Args:
550
+ conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt")
551
+ default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"})
552
+ """
553
+
554
+ def __init__(
555
+ self,
556
+ conditioners: tp.Dict[str, Conditioner],
557
+ default_keys: tp.Dict[str, str] = {},
558
+ ):
559
+ super().__init__()
560
+
561
+ self.conditioners = nn.ModuleDict(conditioners)
562
+ self.default_keys = default_keys
563
+
564
+ def forward(
565
+ self,
566
+ batch_metadata: tp.List[tp.Dict[str, tp.Any]],
567
+ device: tp.Union[torch.device, str],
568
+ ) -> tp.Dict[str, tp.Any]:
569
+ output = {}
570
+
571
+ for key, conditioner in self.conditioners.items():
572
+ condition_key = key
573
+
574
+ conditioner_inputs = []
575
+
576
+ for x in batch_metadata:
577
+ if condition_key not in x:
578
+ if condition_key in self.default_keys:
579
+ condition_key = self.default_keys[condition_key]
580
+ else:
581
+ raise ValueError(
582
+ f"Conditioner key {condition_key} not found in batch metadata"
583
+ )
584
+
585
+ # Unwrap the condition info if it's a single-element list or tuple, this is to support collation functions that wrap everything in a list
586
+ if isinstance(x[condition_key], (list, tuple)) and len(x[condition_key]) == 1:
591
+ conditioner_input = x[condition_key][0]
592
+
593
+ else:
594
+ conditioner_input = x[condition_key]
595
+
596
+ conditioner_inputs.append(conditioner_input)
597
+
598
+ output[key] = conditioner(conditioner_inputs, device)
599
+
600
+ return output
601
+
602
+
603
+ def create_multi_conditioner_from_conditioning_config(
604
+ config: tp.Dict[str, tp.Any],
605
+ ) -> MultiConditioner:
606
+ """
607
+ Create a MultiConditioner from a conditioning config dictionary
608
+
609
+ Args:
610
+ config: the conditioning config dictionary
611
+ device: the device to put the conditioners on
612
+ """
613
+ conditioners = {}
614
+ cond_dim = config["cond_dim"]
615
+
616
+ default_keys = config.get("default_keys", {})
617
+
618
+ for conditioner_info in config["configs"]:
619
+ id = conditioner_info["id"]
620
+
621
+ conditioner_type = conditioner_info["type"]
622
+
623
+ conditioner_config = {"output_dim": cond_dim}
624
+
625
+ conditioner_config.update(conditioner_info["config"])
626
+
627
+ if conditioner_type == "t5":
628
+ conditioners[id] = T5Conditioner(**conditioner_config)
629
+ elif conditioner_type == "clap_text":
630
+ conditioners[id] = CLAPTextConditioner(**conditioner_config)
631
+ elif conditioner_type == "clap_audio":
632
+ conditioners[id] = CLAPAudioConditioner(**conditioner_config)
633
+ elif conditioner_type == "int":
634
+ conditioners[id] = IntConditioner(**conditioner_config)
635
+ elif conditioner_type == "number":
636
+ conditioners[id] = NumberConditioner(**conditioner_config)
637
+ elif conditioner_type == "phoneme":
638
+ conditioners[id] = PhonemeConditioner(**conditioner_config)
639
+ elif conditioner_type == "lut":
640
+ conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
641
+ elif conditioner_type == "pretransform":
642
+ sample_rate = conditioner_config.pop("sample_rate", None)
643
+ assert sample_rate is not None, (
644
+ "Sample rate must be specified for pretransform conditioners"
645
+ )
646
+
647
+ pretransform = create_pretransform_from_config(
648
+ conditioner_config.pop("pretransform_config"), sample_rate=sample_rate
649
+ )
650
+
651
+ if conditioner_config.get("pretransform_ckpt_path", None) is not None:
652
+ pretransform.load_state_dict(
653
+ load_ckpt_state_dict(
654
+ conditioner_config.pop("pretransform_ckpt_path")
655
+ )
656
+ )
657
+
658
+ conditioners[id] = PretransformConditioner(
659
+ pretransform, **conditioner_config
660
+ )
661
+ else:
662
+ raise ValueError(f"Unknown conditioner type: {conditioner_type}")
663
+
664
+ return MultiConditioner(conditioners, default_keys=default_keys)
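A sketch of the conditioning config consumed by create_multi_conditioner_from_conditioning_config; the ids, dimensions and ranges are illustrative, but the keys ("cond_dim", "default_keys", "configs", "id", "type", "config") match the parsing logic above:

conditioning_config = {
    "cond_dim": 768,
    "default_keys": {"prompt_t5": "prompt"},
    "configs": [
        {
            "id": "prompt_t5",
            "type": "t5",
            "config": {"t5_model_name": "t5-base", "max_length": 128},
        },
        {
            "id": "seconds_total",
            "type": "number",
            "config": {"min_val": 0, "max_val": 512},
        },
    ],
}

conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
# Calling conditioner(batch_metadata, device) then returns a dict of
# (embedding, mask) pairs keyed by the conditioner ids above.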
src/YingMusicSinger/utils/stable_audio_tools/diffusion.py ADDED
@@ -0,0 +1,740 @@
1
+ import typing as tp
2
+ from functools import partial
3
+ from time import time
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ # from ..inference.generation import generate_diffusion_cond
11
+ from .adp import UNet1d, UNetCFG1d
12
+ from .blocks import (
13
+ Downsample1d,
14
+ Downsample1d_2,
15
+ FourierFeatures,
16
+ ResConvBlock,
17
+ SelfAttention1d,
18
+ SkipBlock,
19
+ Upsample1d,
20
+ Upsample1d_2,
21
+ expand_to_planes,
22
+ )
23
+ from .conditioners import (
24
+ MultiConditioner,
25
+ create_multi_conditioner_from_conditioning_config,
26
+ )
27
+ from .dit import DiffusionTransformer
28
+ from .factory import create_pretransform_from_config
29
+ from .pretransforms import Pretransform
30
+
31
+
32
+ class Profiler:
33
+ def __init__(self):
34
+ self.ticks = [[time(), None]]
35
+
36
+ def tick(self, msg):
37
+ self.ticks.append([time(), msg])
38
+
39
+ def __repr__(self):
40
+ rep = 80 * "=" + "\n"
41
+ for i in range(1, len(self.ticks)):
42
+ msg = self.ticks[i][1]
43
+ elapsed = self.ticks[i][0] - self.ticks[i - 1][0]
+ rep += msg + f": {elapsed * 1000:.2f}ms\n"
45
+ rep += 80 * "=" + "\n\n\n"
46
+ return rep
47
+
48
+
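The Profiler above is a simple wall-clock helper; a usage sketch:

p = Profiler()
# ... run the denoiser forward pass ...
p.tick("forward")
# ... compute the loss ...
p.tick("loss")
print(p)   # prints the elapsed milliseconds between consecutive ticks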
49
+ class DiffusionModel(nn.Module):
50
+ def __init__(self, *args, **kwargs):
51
+ super().__init__(*args, **kwargs)
52
+
53
+ def forward(self, x, t, **kwargs):
54
+ raise NotImplementedError()
55
+
56
+
57
+ class DiffusionModelWrapper(nn.Module):
58
+ def __init__(
59
+ self,
60
+ model: DiffusionModel,
61
+ io_channels,
62
+ sample_size,
63
+ sample_rate,
64
+ min_input_length,
65
+ pretransform: tp.Optional[Pretransform] = None,
66
+ ):
67
+ super().__init__()
68
+ self.io_channels = io_channels
69
+ self.sample_size = sample_size
70
+ self.sample_rate = sample_rate
71
+ self.min_input_length = min_input_length
72
+
73
+ self.model = model
74
+
75
+ if pretransform is not None:
76
+ self.pretransform = pretransform
77
+ else:
78
+ self.pretransform = None
79
+
80
+ def forward(self, x, t, **kwargs):
81
+ return self.model(x, t, **kwargs)
82
+
83
+
84
+ class ConditionedDiffusionModel(nn.Module):
85
+ def __init__(
86
+ self,
87
+ *args,
88
+ supports_cross_attention: bool = False,
89
+ supports_input_concat: bool = False,
90
+ supports_global_cond: bool = False,
91
+ supports_prepend_cond: bool = False,
92
+ **kwargs,
93
+ ):
94
+ super().__init__(*args, **kwargs)
95
+ self.supports_cross_attention = supports_cross_attention
96
+ self.supports_input_concat = supports_input_concat
97
+ self.supports_global_cond = supports_global_cond
98
+ self.supports_prepend_cond = supports_prepend_cond
99
+
100
+ def forward(
101
+ self,
102
+ x: torch.Tensor,
103
+ t: torch.Tensor,
104
+ cross_attn_cond: torch.Tensor = None,
105
+ cross_attn_mask: torch.Tensor = None,
106
+ input_concat_cond: torch.Tensor = None,
107
+ global_embed: torch.Tensor = None,
108
+ prepend_cond: torch.Tensor = None,
109
+ prepend_cond_mask: torch.Tensor = None,
110
+ cfg_scale: float = 1.0,
111
+ cfg_dropout_prob: float = 0.0,
112
+ batch_cfg: bool = False,
113
+ rescale_cfg: bool = False,
114
+ **kwargs,
115
+ ):
116
+ raise NotImplementedError()
117
+
118
+
119
+ class ConditionedDiffusionModelWrapper(nn.Module):
120
+ """
121
+ A diffusion model that takes in conditioning
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ model: ConditionedDiffusionModel,
127
+ conditioner: MultiConditioner,
128
+ io_channels,
129
+ sample_rate,
130
+ min_input_length: int,
131
+ diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
132
+ pretransform: tp.Optional[Pretransform] = None,
133
+ cross_attn_cond_ids: tp.List[str] = [],
134
+ global_cond_ids: tp.List[str] = [],
135
+ input_concat_ids: tp.List[str] = [],
136
+ prepend_cond_ids: tp.List[str] = [],
137
+ ):
138
+ super().__init__()
139
+
140
+ self.model = model
141
+ self.conditioner = conditioner
142
+ self.io_channels = io_channels
143
+ self.sample_rate = sample_rate
144
+ self.diffusion_objective = diffusion_objective
145
+ self.pretransform = pretransform
146
+ self.cross_attn_cond_ids = cross_attn_cond_ids
147
+ self.global_cond_ids = global_cond_ids
148
+ self.input_concat_ids = input_concat_ids
149
+ self.prepend_cond_ids = prepend_cond_ids
150
+ self.min_input_length = min_input_length
151
+
152
+ def get_conditioning_inputs(
153
+ self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False
154
+ ):
155
+ cross_attention_input = None
156
+ cross_attention_masks = None
157
+ global_cond = None
158
+ input_concat_cond = None
159
+ prepend_cond = None
160
+ prepend_cond_mask = None
161
+
162
+ if len(self.cross_attn_cond_ids) > 0:
163
+ # Concatenate all cross-attention inputs over the sequence dimension
164
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
165
+ cross_attention_input = []
166
+ cross_attention_masks = []
167
+
168
+ for key in self.cross_attn_cond_ids:
169
+ cross_attn_in, cross_attn_mask = conditioning_tensors[key]
170
+
171
+ # Add sequence dimension if it's not there
172
+ if len(cross_attn_in.shape) == 2:
173
+ cross_attn_in = cross_attn_in.unsqueeze(1)
174
+ cross_attn_mask = cross_attn_mask.unsqueeze(1)
175
+
176
+ cross_attention_input.append(cross_attn_in)
177
+ cross_attention_masks.append(cross_attn_mask)
178
+
179
+ cross_attention_input = torch.cat(cross_attention_input, dim=1)
180
+ cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
181
+
182
+ if len(self.global_cond_ids) > 0:
183
+ # Concatenate all global conditioning inputs over the channel dimension
184
+ # Assumes that the global conditioning inputs are of shape (batch, channels)
185
+ global_conds = []
186
+ for key in self.global_cond_ids:
187
+ global_cond_input = conditioning_tensors[key][0]
188
+
189
+ global_conds.append(global_cond_input)
190
+
191
+ # Concatenate over the channel dimension
192
+ global_cond = torch.cat(global_conds, dim=-1)
193
+
194
+ if len(global_cond.shape) == 3:
195
+ global_cond = global_cond.squeeze(1)
196
+
197
+ if len(self.input_concat_ids) > 0:
198
+ # Concatenate all input concat conditioning inputs over the channel dimension
199
+ # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
200
+ input_concat_cond = torch.cat(
201
+ [conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1
202
+ )
203
+
204
+ if len(self.prepend_cond_ids) > 0:
205
+ # Concatenate all prepend conditioning inputs over the sequence dimension
206
+ # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
207
+ prepend_conds = []
208
+ prepend_cond_masks = []
209
+
210
+ for key in self.prepend_cond_ids:
211
+ prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
212
+ prepend_conds.append(prepend_cond_input)
213
+ prepend_cond_masks.append(prepend_cond_mask)
214
+
215
+ prepend_cond = torch.cat(prepend_conds, dim=1)
216
+ prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)
217
+
218
+ if negative:
219
+ return {
220
+ "negative_cross_attn_cond": cross_attention_input,
221
+ "negative_cross_attn_mask": cross_attention_masks,
222
+ "negative_global_cond": global_cond,
223
+ "negative_input_concat_cond": input_concat_cond,
224
+ }
225
+ else:
226
+ return {
227
+ "cross_attn_cond": cross_attention_input,
228
+ "cross_attn_mask": cross_attention_masks,
229
+ "global_cond": global_cond,
230
+ "input_concat_cond": input_concat_cond,
231
+ "prepend_cond": prepend_cond,
232
+ "prepend_cond_mask": prepend_cond_mask,
233
+ }
234
+
235
+ def forward(
236
+ self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs
237
+ ):
238
+ return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
239
+
240
+ def generate(self, *args, **kwargs):
241
+ return generate_diffusion_cond(self, *args, **kwargs)
242
+
243
+
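The generate method above delegates to generate_diffusion_cond, whose import is commented out at the top of this file, so sampling has to be wired up by the caller's environment. The conditioning path itself is self-contained; a hypothetical call pattern (keys, shapes and device are illustrative) would be:

import torch

# `wrapper` is an already-built ConditionedDiffusionModelWrapper
metadata = [{"prompt": "warm analog synth pad", "seconds_total": 30}]
cond = wrapper.conditioner(metadata, device="cuda")

x = torch.randn(1, wrapper.io_channels, 1024, device="cuda")   # noisy latents
t = torch.rand(1, device="cuda")                               # diffusion timestep
out = wrapper(x, t, cond=cond)  # routes cond into cross-attn / global / concat / prepend inputs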
244
+ class UNetCFG1DWrapper(ConditionedDiffusionModel):
245
+ def __init__(self, *args, **kwargs):
246
+ super().__init__(
247
+ supports_cross_attention=True,
248
+ supports_global_cond=True,
249
+ supports_input_concat=True,
250
+ )
251
+
252
+ self.model = UNetCFG1d(*args, **kwargs)
253
+
254
+ with torch.no_grad():
255
+ for param in self.model.parameters():
256
+ param *= 0.5
257
+
258
+ def forward(
259
+ self,
260
+ x,
261
+ t,
262
+ cross_attn_cond=None,
263
+ cross_attn_mask=None,
264
+ input_concat_cond=None,
265
+ global_cond=None,
266
+ cfg_scale=1.0,
267
+ cfg_dropout_prob: float = 0.0,
268
+ batch_cfg: bool = False,
269
+ rescale_cfg: bool = False,
270
+ negative_cross_attn_cond=None,
271
+ negative_cross_attn_mask=None,
272
+ negative_global_cond=None,
273
+ negative_input_concat_cond=None,
274
+ prepend_cond=None,
275
+ prepend_cond_mask=None,
276
+ **kwargs,
277
+ ):
278
+ p = Profiler()
279
+
280
+ p.tick("start")
281
+
282
+ channels_list = None
283
+ if input_concat_cond is not None:
284
+ channels_list = [input_concat_cond]
285
+
286
+ outputs = self.model(
287
+ x,
288
+ t,
289
+ embedding=cross_attn_cond,
290
+ embedding_mask=cross_attn_mask,
291
+ features=global_cond,
292
+ channels_list=channels_list,
293
+ embedding_scale=cfg_scale,
294
+ embedding_mask_proba=cfg_dropout_prob,
295
+ batch_cfg=batch_cfg,
296
+ rescale_cfg=rescale_cfg,
297
+ negative_embedding=negative_cross_attn_cond,
298
+ negative_embedding_mask=negative_cross_attn_mask,
299
+ **kwargs,
300
+ )
301
+
302
+ p.tick("UNetCFG1D forward")
303
+
304
+ # print(f"Profiler: {p}")
305
+ return outputs
306
+
307
+
308
+ class UNet1DCondWrapper(ConditionedDiffusionModel):
309
+ def __init__(self, *args, **kwargs):
310
+ super().__init__(
311
+ supports_cross_attention=False,
312
+ supports_global_cond=True,
313
+ supports_input_concat=True,
314
+ )
315
+
316
+ self.model = UNet1d(*args, **kwargs)
317
+
318
+ with torch.no_grad():
319
+ for param in self.model.parameters():
320
+ param *= 0.5
321
+
322
+ def forward(
323
+ self,
324
+ x,
325
+ t,
326
+ input_concat_cond=None,
327
+ global_cond=None,
328
+ cross_attn_cond=None,
329
+ cross_attn_mask=None,
330
+ prepend_cond=None,
331
+ prepend_cond_mask=None,
332
+ cfg_scale=1.0,
333
+ cfg_dropout_prob: float = 0.0,
334
+ batch_cfg: bool = False,
335
+ rescale_cfg: bool = False,
336
+ negative_cross_attn_cond=None,
337
+ negative_cross_attn_mask=None,
338
+ negative_global_cond=None,
339
+ negative_input_concat_cond=None,
340
+ **kwargs,
341
+ ):
342
+ channels_list = None
343
+ if input_concat_cond is not None:
344
+ # Interpolate input_concat_cond to the same length as x
345
+ if input_concat_cond.shape[2] != x.shape[2]:
346
+ input_concat_cond = F.interpolate(
347
+ input_concat_cond, (x.shape[2],), mode="nearest"
348
+ )
349
+
350
+ channels_list = [input_concat_cond]
351
+
352
+ outputs = self.model(
353
+ x, t, features=global_cond, channels_list=channels_list, **kwargs
354
+ )
355
+
356
+ return outputs
357
+
358
+
359
+ class UNet1DUncondWrapper(DiffusionModel):
360
+ def __init__(self, in_channels, *args, **kwargs):
361
+ super().__init__()
362
+
363
+ self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
364
+
365
+ self.io_channels = in_channels
366
+
367
+ with torch.no_grad():
368
+ for param in self.model.parameters():
369
+ param *= 0.5
370
+
371
+ def forward(self, x, t, **kwargs):
372
+ return self.model(x, t, **kwargs)
373
+
374
+
375
+ class DAU1DCondWrapper(ConditionedDiffusionModel):
376
+ def __init__(self, *args, **kwargs):
377
+ super().__init__(
378
+ supports_cross_attention=False,
379
+ supports_global_cond=False,
380
+ supports_input_concat=True,
381
+ )
382
+
383
+ self.model = DiffusionAttnUnet1D(*args, **kwargs)
384
+
385
+ with torch.no_grad():
386
+ for param in self.model.parameters():
387
+ param *= 0.5
388
+
389
+ def forward(
390
+ self,
391
+ x,
392
+ t,
393
+ input_concat_cond=None,
394
+ cross_attn_cond=None,
395
+ cross_attn_mask=None,
396
+ global_cond=None,
397
+ cfg_scale=1.0,
398
+ cfg_dropout_prob: float = 0.0,
399
+ batch_cfg: bool = False,
400
+ rescale_cfg: bool = False,
401
+ negative_cross_attn_cond=None,
402
+ negative_cross_attn_mask=None,
403
+ negative_global_cond=None,
404
+ negative_input_concat_cond=None,
405
+ prepend_cond=None,
406
+ **kwargs,
407
+ ):
408
+ return self.model(x, t, cond=input_concat_cond)
409
+
410
+
411
+ class DiffusionAttnUnet1D(nn.Module):
412
+ def __init__(
413
+ self,
414
+ io_channels=2,
415
+ depth=14,
416
+ n_attn_layers=6,
417
+ channels=[128, 128, 256, 256] + [512] * 10,
418
+ cond_dim=0,
419
+ cond_noise_aug=False,
420
+ kernel_size=5,
421
+ learned_resample=False,
422
+ strides=[2] * 13,
423
+ conv_bias=True,
424
+ use_snake=False,
425
+ ):
426
+ super().__init__()
427
+
428
+ self.cond_noise_aug = cond_noise_aug
429
+
430
+ self.io_channels = io_channels
431
+
432
+ if self.cond_noise_aug:
433
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
434
+
435
+ self.timestep_embed = FourierFeatures(1, 16)
436
+
437
+ attn_layer = depth - n_attn_layers
438
+
439
+ strides = [1] + strides
440
+
441
+ block = nn.Identity()
442
+
443
+ conv_block = partial(
444
+ ResConvBlock,
445
+ kernel_size=kernel_size,
446
+ conv_bias=conv_bias,
447
+ use_snake=use_snake,
448
+ )
449
+
450
+ for i in range(depth, 0, -1):
451
+ c = channels[i - 1]
452
+ stride = strides[i - 1]
453
+ if stride > 2 and not learned_resample:
454
+ raise ValueError("Must have stride 2 without learned resampling")
455
+
456
+ if i > 1:
457
+ c_prev = channels[i - 2]
458
+ add_attn = i >= attn_layer and n_attn_layers > 0
459
+ block = SkipBlock(
460
+ Downsample1d_2(c_prev, c_prev, stride)
461
+ if (learned_resample or stride == 1)
462
+ else Downsample1d("cubic"),
463
+ conv_block(c_prev, c, c),
464
+ SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
465
+ conv_block(c, c, c),
466
+ SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
467
+ conv_block(c, c, c),
468
+ SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
469
+ block,
470
+ conv_block(c * 2 if i != depth else c, c, c),
471
+ SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
472
+ conv_block(c, c, c),
473
+ SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
474
+ conv_block(c, c, c_prev),
475
+ SelfAttention1d(c_prev, c_prev // 32)
476
+ if add_attn
477
+ else nn.Identity(),
478
+ Upsample1d_2(c_prev, c_prev, stride)
479
+ if learned_resample
480
+ else Upsample1d(kernel="cubic"),
481
+ )
482
+ else:
483
+ cond_embed_dim = 16 if not self.cond_noise_aug else 32
484
+ block = nn.Sequential(
485
+ conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
486
+ conv_block(c, c, c),
487
+ conv_block(c, c, c),
488
+ block,
489
+ conv_block(c * 2, c, c),
490
+ conv_block(c, c, c),
491
+ conv_block(c, c, io_channels, is_last=True),
492
+ )
493
+ self.net = block
494
+
495
+ with torch.no_grad():
496
+ for param in self.net.parameters():
497
+ param *= 0.5
498
+
499
+ def forward(self, x, t, cond=None, cond_aug_scale=None):
500
+ timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
501
+
502
+ inputs = [x, timestep_embed]
503
+
504
+ if cond is not None:
505
+ if cond.shape[2] != x.shape[2]:
506
+ cond = F.interpolate(
507
+ cond, (x.shape[2],), mode="linear", align_corners=False
508
+ )
509
+
510
+ if self.cond_noise_aug:
511
+ # Draw a quasi-random augmentation level between 0 and 1 from the Sobol sequence
512
+ if cond_aug_scale is None:
513
+ aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
514
+ else:
515
+ aug_level = (
516
+ torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
517
+ )
518
+
519
+ # Add noise to the conditioning signal
520
+ cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
521
+
522
+ # Get embedding for the noise cond level, reusing timestep_embed
523
+ aug_level_embed = expand_to_planes(
524
+ self.timestep_embed(aug_level[:, None]), x.shape
525
+ )
526
+
527
+ inputs.append(aug_level_embed)
528
+
529
+ inputs.append(cond)
530
+
531
+ outputs = self.net(torch.cat(inputs, dim=1))
532
+
533
+ return outputs
534
+
535
+
536
+ class DiTWrapper(ConditionedDiffusionModel):
537
+ def __init__(self, *args, **kwargs):
538
+ super().__init__(
539
+ supports_cross_attention=True,
540
+ supports_global_cond=False,
541
+ supports_input_concat=False,
542
+ )
543
+
544
+ self.model = DiffusionTransformer(*args, **kwargs)
545
+
546
+ with torch.no_grad():
547
+ for param in self.model.parameters():
548
+ param *= 0.5
549
+
550
+ def forward(
551
+ self,
552
+ x,
553
+ t,
554
+ cross_attn_cond=None,
555
+ cross_attn_mask=None,
556
+ negative_cross_attn_cond=None,
557
+ negative_cross_attn_mask=None,
558
+ input_concat_cond=None,
559
+ negative_input_concat_cond=None,
560
+ global_cond=None,
561
+ negative_global_cond=None,
562
+ prepend_cond=None,
563
+ prepend_cond_mask=None,
564
+ cfg_scale=1.0,
565
+ cfg_dropout_prob: float = 0.0,
566
+ batch_cfg: bool = True,
567
+ rescale_cfg: bool = False,
568
+ scale_phi: float = 0.0,
569
+ **kwargs,
570
+ ):
571
+ assert batch_cfg, "batch_cfg must be True for DiTWrapper"
572
+ # assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
573
+
574
+ return self.model(
575
+ x,
576
+ t,
577
+ cross_attn_cond=cross_attn_cond,
578
+ cross_attn_cond_mask=cross_attn_mask,
579
+ negative_cross_attn_cond=negative_cross_attn_cond,
580
+ negative_cross_attn_mask=negative_cross_attn_mask,
581
+ input_concat_cond=input_concat_cond,
582
+ prepend_cond=prepend_cond,
583
+ prepend_cond_mask=prepend_cond_mask,
584
+ cfg_scale=cfg_scale,
585
+ cfg_dropout_prob=cfg_dropout_prob,
586
+ scale_phi=scale_phi,
587
+ global_embed=global_cond,
588
+ **kwargs,
589
+ )
590
+
591
+
592
+ class DiTUncondWrapper(DiffusionModel):
593
+ def __init__(self, in_channels, *args, **kwargs):
594
+ super().__init__()
595
+
596
+ self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs)
597
+
598
+ self.io_channels = in_channels
599
+
600
+ with torch.no_grad():
601
+ for param in self.model.parameters():
602
+ param *= 0.5
603
+
604
+ def forward(self, x, t, **kwargs):
605
+ return self.model(x, t, **kwargs)
606
+
607
+
608
+ def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
609
+ diffusion_uncond_config = config["model"]
610
+
611
+ model_type = diffusion_uncond_config.get("type", None)
612
+
613
+ diffusion_config = diffusion_uncond_config.get("config", {})
614
+
615
+ assert model_type is not None, "Must specify model type in config"
616
+
617
+ pretransform = diffusion_uncond_config.get("pretransform", None)
618
+
619
+ sample_size = config.get("sample_size", None)
620
+ assert sample_size is not None, "Must specify sample size in config"
621
+
622
+ sample_rate = config.get("sample_rate", None)
623
+ assert sample_rate is not None, "Must specify sample rate in config"
624
+
625
+ if pretransform is not None:
626
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
627
+ min_input_length = pretransform.downsampling_ratio
628
+ else:
629
+ min_input_length = 1
630
+
631
+ if model_type == "DAU1d":
632
+ model = DiffusionAttnUnet1D(**diffusion_config)
633
+
634
+ elif model_type == "adp_uncond_1d":
635
+ model = UNet1DUncondWrapper(**diffusion_config)
636
+
637
+ elif model_type == "dit":
638
+ model = DiTUncondWrapper(**diffusion_config)
639
+
640
+ else:
641
+ raise NotImplementedError(f"Unknown model type: {model_type}")
642
+
643
+ return DiffusionModelWrapper(
644
+ model,
645
+ io_channels=model.io_channels,
646
+ sample_size=sample_size,
647
+ sample_rate=sample_rate,
648
+ pretransform=pretransform,
649
+ min_input_length=min_input_length,
650
+ )
651
+
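+ # For reference, create_diffusion_uncond_from_config above expects a dict shaped
+ # roughly like the following (hypothetical illustrative values, not from this repo):
+ #   {
+ #       "sample_size": 65536,
+ #       "sample_rate": 44100,
+ #       "model": {
+ #           "type": "dit",               # "DAU1d", "adp_uncond_1d", or "dit"
+ #           "config": {"in_channels": 64},
+ #           "pretransform": None,        # or a pretransform config dict
+ #       },
+ #   }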
652
+
653
+ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
654
+ model_config = config["model"]
655
+
656
+ model_type = config["model_type"]
657
+
658
+ diffusion_config = model_config.get("diffusion", None)
659
+ assert diffusion_config is not None, "Must specify diffusion config"
660
+
661
+ diffusion_model_type = diffusion_config.get("type", None)
662
+ assert diffusion_model_type is not None, "Must specify diffusion model type"
663
+
664
+ diffusion_model_config = diffusion_config.get("config", None)
665
+ assert diffusion_model_config is not None, "Must specify diffusion model config"
666
+
667
+ if diffusion_model_type == "adp_cfg_1d":
668
+ diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
669
+ elif diffusion_model_type == "adp_1d":
670
+ diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
671
+ elif diffusion_model_type == "dit":
672
+ diffusion_model = DiTWrapper(**diffusion_model_config)
673
+
674
+ io_channels = model_config.get("io_channels", None)
675
+ assert io_channels is not None, "Must specify io_channels in model config"
676
+
677
+ sample_rate = config.get("sample_rate", None)
678
+ assert sample_rate is not None, "Must specify sample_rate in config"
679
+
680
+ diffusion_objective = diffusion_config.get("diffusion_objective", "v")
681
+
682
+ conditioning_config = model_config.get("conditioning", None)
683
+
684
+ conditioner = None
685
+ if conditioning_config is not None:
686
+ conditioner = create_multi_conditioner_from_conditioning_config(
687
+ conditioning_config
688
+ )
689
+
690
+ cross_attention_ids = diffusion_config.get("cross_attention_cond_ids", [])
691
+ global_cond_ids = diffusion_config.get("global_cond_ids", [])
692
+ input_concat_ids = diffusion_config.get("input_concat_ids", [])
693
+ prepend_cond_ids = diffusion_config.get("prepend_cond_ids", [])
694
+
695
+ pretransform = model_config.get("pretransform", None)
696
+
697
+ if pretransform is not None:
698
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
699
+ min_input_length = pretransform.downsampling_ratio
700
+ else:
701
+ min_input_length = 1
702
+
703
+ if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
704
+ min_input_length *= np.prod(diffusion_model_config["factors"])
705
+ elif diffusion_model_type == "dit":
706
+ min_input_length *= diffusion_model.model.patch_size
707
+
708
+ # Get the proper wrapper class
709
+
710
+ extra_kwargs = {}
711
+
712
+ if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint":
713
+ wrapper_fn = ConditionedDiffusionModelWrapper
714
+
715
+ extra_kwargs["diffusion_objective"] = diffusion_objective
716
+
717
+ elif model_type == "diffusion_prior":
718
+ prior_type = model_config.get("prior_type", None)
719
+ assert prior_type is not None, (
720
+ "Must specify prior_type in diffusion prior model config"
721
+ )
722
+
723
+ if prior_type == "mono_stereo":
724
+ from .diffusion_prior import MonoToStereoDiffusionPrior
725
+
726
+ wrapper_fn = MonoToStereoDiffusionPrior
727
+
728
+ return wrapper_fn(
729
+ diffusion_model,
730
+ conditioner,
731
+ min_input_length=min_input_length,
732
+ sample_rate=sample_rate,
733
+ cross_attn_cond_ids=cross_attention_ids,
734
+ global_cond_ids=global_cond_ids,
735
+ input_concat_ids=input_concat_ids,
736
+ prepend_cond_ids=prepend_cond_ids,
737
+ pretransform=pretransform,
738
+ io_channels=io_channels,
739
+ **extra_kwargs,
740
+ )
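For reference, a minimal sketch of the config dict that create_diffusion_cond_from_config above reads. Only the key names come from the lookups in the function; all values (including the conditioner id "prompt") are hypothetical illustrations.

# Hypothetical example values; only the key names mirror the function above.
example_cond_config = {
    "model_type": "diffusion_cond",
    "sample_rate": 44100,
    "model": {
        "io_channels": 64,
        "pretransform": None,   # or a pretransform config dict
        "conditioning": None,   # or a config for create_multi_conditioner_from_conditioning_config
        "diffusion": {
            "type": "dit",      # "adp_cfg_1d", "adp_1d", or "dit"
            "diffusion_objective": "v",
            "cross_attention_cond_ids": ["prompt"],
            "global_cond_ids": [],
            "input_concat_ids": [],
            "prepend_cond_ids": [],
            "config": {"io_channels": 64, "embed_dim": 768, "depth": 12, "num_heads": 8},
        },
    },
}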
src/YingMusicSinger/utils/stable_audio_tools/dit.py ADDED
@@ -0,0 +1,451 @@
1
+ import typing as tp
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from x_transformers import ContinuousTransformerWrapper, Encoder
8
+
9
+ from .blocks import FourierFeatures
10
+ from .transformer import ContinuousTransformer
11
+
12
+
13
+ class DiffusionTransformer(nn.Module):
14
+ def __init__(
15
+ self,
16
+ io_channels=32,
17
+ patch_size=1,
18
+ embed_dim=768,
19
+ cond_token_dim=0,
20
+ project_cond_tokens=True,
21
+ global_cond_dim=0,
22
+ project_global_cond=True,
23
+ input_concat_dim=0,
24
+ prepend_cond_dim=0,
25
+ depth=12,
26
+ num_heads=8,
27
+ transformer_type: tp.Literal[
28
+ "x-transformers", "continuous_transformer"
29
+ ] = "x-transformers",
30
+ global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
31
+ **kwargs,
32
+ ):
33
+ super().__init__()
34
+
35
+ self.cond_token_dim = cond_token_dim
36
+
37
+ # Timestep embeddings
38
+ timestep_features_dim = 256
39
+
40
+ self.timestep_features = FourierFeatures(1, timestep_features_dim)
41
+
42
+ self.to_timestep_embed = nn.Sequential(
43
+ nn.Linear(timestep_features_dim, embed_dim, bias=True),
44
+ nn.SiLU(),
45
+ nn.Linear(embed_dim, embed_dim, bias=True),
46
+ )
47
+
48
+ if cond_token_dim > 0:
49
+ # Conditioning tokens
50
+
51
+ cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
52
+ self.to_cond_embed = nn.Sequential(
53
+ nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
54
+ nn.SiLU(),
55
+ nn.Linear(cond_embed_dim, cond_embed_dim, bias=False),
56
+ )
57
+ else:
58
+ cond_embed_dim = 0
59
+
60
+ if global_cond_dim > 0:
61
+ # Global conditioning
62
+ global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
63
+ self.to_global_embed = nn.Sequential(
64
+ nn.Linear(global_cond_dim, global_embed_dim, bias=False),
65
+ nn.SiLU(),
66
+ nn.Linear(global_embed_dim, global_embed_dim, bias=False),
67
+ )
68
+
69
+ if prepend_cond_dim > 0:
70
+ # Prepend conditioning
71
+ self.to_prepend_embed = nn.Sequential(
72
+ nn.Linear(prepend_cond_dim, embed_dim, bias=False),
73
+ nn.SiLU(),
74
+ nn.Linear(embed_dim, embed_dim, bias=False),
75
+ )
76
+
77
+ self.input_concat_dim = input_concat_dim
78
+
79
+ dim_in = io_channels + self.input_concat_dim
80
+
81
+ self.patch_size = patch_size
82
+
83
+ # Transformer
84
+
85
+ self.transformer_type = transformer_type
86
+
87
+ self.global_cond_type = global_cond_type
88
+
89
+ if self.transformer_type == "x-transformers":
90
+ self.transformer = ContinuousTransformerWrapper(
91
+ dim_in=dim_in * patch_size,
92
+ dim_out=io_channels * patch_size,
93
+ max_seq_len=0, # Not relevant without absolute positional embeds
94
+ attn_layers=Encoder(
95
+ dim=embed_dim,
96
+ depth=depth,
97
+ heads=num_heads,
98
+ attn_flash=True,
99
+ cross_attend=cond_token_dim > 0,
100
+ dim_context=None if cond_embed_dim == 0 else cond_embed_dim,
101
+ zero_init_branch_output=True,
102
+ use_abs_pos_emb=False,
103
+ rotary_pos_emb=True,
104
+ ff_swish=True,
105
+ ff_glu=True,
106
+ **kwargs,
107
+ ),
108
+ )
109
+
110
+ elif self.transformer_type == "continuous_transformer":
111
+ global_dim = None
112
+
113
+ if self.global_cond_type == "adaLN":
114
+ # The global conditioning is projected to the embed_dim already at this point
115
+ global_dim = embed_dim
116
+
117
+ self.transformer = ContinuousTransformer(
118
+ dim=embed_dim,
119
+ depth=depth,
120
+ dim_heads=embed_dim // num_heads,
121
+ dim_in=dim_in * patch_size,
122
+ dim_out=io_channels * patch_size,
123
+ cross_attend=cond_token_dim > 0,
124
+ cond_token_dim=cond_embed_dim,
125
+ global_cond_dim=global_dim,
126
+ **kwargs,
127
+ )
128
+
129
+ else:
130
+ raise ValueError(f"Unknown transformer type: {self.transformer_type}")
131
+
132
+ self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
133
+ nn.init.zeros_(self.preprocess_conv.weight)
134
+ self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
135
+ nn.init.zeros_(self.postprocess_conv.weight)
136
+
137
+ def _forward(
138
+ self,
139
+ x,
140
+ t,
141
+ mask=None,
142
+ cross_attn_cond=None,
143
+ cross_attn_cond_mask=None,
144
+ input_concat_cond=None,
145
+ global_embed=None,
146
+ prepend_cond=None,
147
+ prepend_cond_mask=None,
148
+ return_info=False,
149
+ **kwargs,
150
+ ):
151
+ if cross_attn_cond is not None:
152
+ cross_attn_cond = self.to_cond_embed(cross_attn_cond)
153
+
154
+ if global_embed is not None:
155
+ # Project the global conditioning to the embedding dimension
156
+ global_embed = self.to_global_embed(global_embed)
157
+
158
+ prepend_inputs = None
159
+ prepend_mask = None
160
+ prepend_length = 0
161
+ if prepend_cond is not None:
162
+ # Project the prepend conditioning to the embedding dimension
163
+ prepend_cond = self.to_prepend_embed(prepend_cond)
164
+
165
+ prepend_inputs = prepend_cond
166
+ if prepend_cond_mask is not None:
167
+ prepend_mask = prepend_cond_mask
168
+
169
+ if input_concat_cond is not None:
170
+ # Interpolate input_concat_cond to the same length as x
171
+ if input_concat_cond.shape[2] != x.shape[2]:
172
+ input_concat_cond = F.interpolate(
173
+ input_concat_cond, (x.shape[2],), mode="nearest"
174
+ )
175
+
176
+ x = torch.cat([x, input_concat_cond], dim=1)
177
+
178
+ # Get the batch of timestep embeddings
179
+ timestep_embed = self.to_timestep_embed(
180
+ self.timestep_features(t[:, None])
181
+ ) # (b, embed_dim)
182
+
183
+ # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
184
+ if global_embed is not None:
185
+ global_embed = global_embed + timestep_embed
186
+ else:
187
+ global_embed = timestep_embed
188
+
189
+ # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
190
+ if self.global_cond_type == "prepend":
191
+ if prepend_inputs is None:
192
+ # Prepend inputs are just the global embed, and the mask is all ones
193
+ prepend_inputs = global_embed.unsqueeze(1)
194
+ prepend_mask = torch.ones(
195
+ (x.shape[0], 1), device=x.device, dtype=torch.bool
196
+ )
197
+ else:
198
+ # Prepend inputs are the prepend conditioning + the global embed
199
+ prepend_inputs = torch.cat(
200
+ [prepend_inputs, global_embed.unsqueeze(1)], dim=1
201
+ )
202
+ prepend_mask = torch.cat(
203
+ [
204
+ prepend_mask,
205
+ torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool),
206
+ ],
207
+ dim=1,
208
+ )
209
+
210
+ prepend_length = prepend_inputs.shape[1]
211
+
212
+ x = self.preprocess_conv(x) + x
213
+
214
+ x = rearrange(x, "b c t -> b t c")
215
+
216
+ extra_args = {}
217
+
218
+ if self.global_cond_type == "adaLN":
219
+ extra_args["global_cond"] = global_embed
220
+
221
+ if self.patch_size > 1:
222
+ x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
223
+
224
+ if self.transformer_type == "x-transformers":
225
+ output = self.transformer(
226
+ x,
227
+ prepend_embeds=prepend_inputs,
228
+ context=cross_attn_cond,
229
+ context_mask=cross_attn_cond_mask,
230
+ mask=mask,
231
+ prepend_mask=prepend_mask,
232
+ **extra_args,
233
+ **kwargs,
234
+ )
235
+ elif self.transformer_type == "continuous_transformer":
236
+ output = self.transformer(
237
+ x,
238
+ prepend_embeds=prepend_inputs,
239
+ context=cross_attn_cond,
240
+ context_mask=cross_attn_cond_mask,
241
+ mask=mask,
242
+ prepend_mask=prepend_mask,
243
+ return_info=return_info,
244
+ **extra_args,
245
+ **kwargs,
246
+ )
247
+
248
+ if return_info:
249
+ output, info = output
250
+ elif self.transformer_type == "mm_transformer":
251
+ output = self.transformer(
252
+ x,
253
+ context=cross_attn_cond,
254
+ mask=mask,
255
+ context_mask=cross_attn_cond_mask,
256
+ **extra_args,
257
+ **kwargs,
258
+ )
259
+
260
+ output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:]
261
+
262
+ if self.patch_size > 1:
263
+ output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
264
+
265
+ output = self.postprocess_conv(output) + output
266
+
267
+ if return_info:
268
+ return output, info
269
+
270
+ return output
271
+
272
+ def forward(
273
+ self,
274
+ x,
275
+ t,
276
+ cross_attn_cond=None,
277
+ cross_attn_cond_mask=None,
278
+ negative_cross_attn_cond=None,
279
+ negative_cross_attn_mask=None,
280
+ input_concat_cond=None,
281
+ global_embed=None,
282
+ negative_global_embed=None,
283
+ prepend_cond=None,
284
+ prepend_cond_mask=None,
285
+ cfg_scale=1.0,
286
+ cfg_dropout_prob=0.0,
287
+ causal=False,
288
+ scale_phi=0.0,
289
+ mask=None,
290
+ return_info=False,
291
+ **kwargs,
292
+ ):
293
+ assert causal == False, "Causal mode is not supported for DiffusionTransformer"
294
+
295
+ if cross_attn_cond_mask is not None:
296
+ cross_attn_cond_mask = cross_attn_cond_mask.bool()
297
+
298
+ cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
299
+
300
+ if prepend_cond_mask is not None:
301
+ prepend_cond_mask = prepend_cond_mask.bool()
302
+
303
+ # CFG dropout
304
+ if cfg_dropout_prob > 0.0:
305
+ if cross_attn_cond is not None:
306
+ null_embed = torch.zeros_like(
307
+ cross_attn_cond, device=cross_attn_cond.device
308
+ )
309
+ dropout_mask = torch.bernoulli(
310
+ torch.full(
311
+ (cross_attn_cond.shape[0], 1, 1),
312
+ cfg_dropout_prob,
313
+ device=cross_attn_cond.device,
314
+ )
315
+ ).to(torch.bool)
316
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
317
+
318
+ if prepend_cond is not None:
319
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
320
+ dropout_mask = torch.bernoulli(
321
+ torch.full(
322
+ (prepend_cond.shape[0], 1, 1),
323
+ cfg_dropout_prob,
324
+ device=prepend_cond.device,
325
+ )
326
+ ).to(torch.bool)
327
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
328
+
329
+ if cfg_scale != 1.0 and (
330
+ cross_attn_cond is not None or prepend_cond is not None
331
+ ):
332
+ # Classifier-free guidance
333
+ # Concatenate conditioned and unconditioned inputs on the batch dimension
334
+ batch_inputs = torch.cat([x, x], dim=0)
335
+ batch_timestep = torch.cat([t, t], dim=0)
336
+
337
+ if global_embed is not None:
338
+ batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
339
+ else:
340
+ batch_global_cond = None
341
+
342
+ if input_concat_cond is not None:
343
+ batch_input_concat_cond = torch.cat(
344
+ [input_concat_cond, input_concat_cond], dim=0
345
+ )
346
+ else:
347
+ batch_input_concat_cond = None
348
+
349
+ batch_cond = None
350
+ batch_cond_masks = None
351
+
352
+ # Handle CFG for cross-attention conditioning
353
+ if cross_attn_cond is not None:
354
+ null_embed = torch.zeros_like(
355
+ cross_attn_cond, device=cross_attn_cond.device
356
+ )
357
+
358
+ # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
359
+ if negative_cross_attn_cond is not None:
360
+ # If there's a negative cross-attention mask, set the masked tokens to the null embed
361
+ if negative_cross_attn_mask is not None:
362
+ negative_cross_attn_mask = negative_cross_attn_mask.to(
363
+ torch.bool
364
+ ).unsqueeze(2)
365
+
366
+ negative_cross_attn_cond = torch.where(
367
+ negative_cross_attn_mask,
368
+ negative_cross_attn_cond,
369
+ null_embed,
370
+ )
371
+
372
+ batch_cond = torch.cat(
373
+ [cross_attn_cond, negative_cross_attn_cond], dim=0
374
+ )
375
+
376
+ else:
377
+ batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
378
+
379
+ if cross_attn_cond_mask is not None:
380
+ batch_cond_masks = torch.cat(
381
+ [cross_attn_cond_mask, cross_attn_cond_mask], dim=0
382
+ )
383
+
384
+ batch_prepend_cond = None
385
+ batch_prepend_cond_mask = None
386
+
387
+ if prepend_cond is not None:
388
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
389
+
390
+ batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
391
+
392
+ if prepend_cond_mask is not None:
393
+ batch_prepend_cond_mask = torch.cat(
394
+ [prepend_cond_mask, prepend_cond_mask], dim=0
395
+ )
396
+
397
+ if mask is not None:
398
+ batch_masks = torch.cat([mask, mask], dim=0)
399
+ else:
400
+ batch_masks = None
401
+
402
+ batch_output = self._forward(
403
+ batch_inputs,
404
+ batch_timestep,
405
+ cross_attn_cond=batch_cond,
406
+ cross_attn_cond_mask=batch_cond_masks,
407
+ mask=batch_masks,
408
+ input_concat_cond=batch_input_concat_cond,
409
+ global_embed=batch_global_cond,
410
+ prepend_cond=batch_prepend_cond,
411
+ prepend_cond_mask=batch_prepend_cond_mask,
412
+ return_info=return_info,
413
+ **kwargs,
414
+ )
415
+
416
+ if return_info:
417
+ batch_output, info = batch_output
418
+
419
+ cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
420
+ cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
421
+
422
+ # CFG Rescale
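+ # Equivalent formula (std taken over the channel dim, dim=1):
+ #   output = scale_phi * cfg_output * std(cond_output) / std(cfg_output)
+ #            + (1 - scale_phi) * cfg_output
+ # i.e. the guided output is partially rescaled back toward the standard deviation
+ # of the conditioned prediction.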
423
+ if scale_phi != 0.0:
424
+ cond_out_std = cond_output.std(dim=1, keepdim=True)
425
+ out_cfg_std = cfg_output.std(dim=1, keepdim=True)
426
+ output = (
427
+ scale_phi * (cfg_output * (cond_out_std / out_cfg_std))
428
+ + (1 - scale_phi) * cfg_output
429
+ )
430
+ else:
431
+ output = cfg_output
432
+
433
+ if return_info:
434
+ return output, info
435
+
436
+ return output
437
+
438
+ else:
439
+ return self._forward(
440
+ x,
441
+ t,
442
+ cross_attn_cond=cross_attn_cond,
443
+ cross_attn_cond_mask=cross_attn_cond_mask,
444
+ input_concat_cond=input_concat_cond,
445
+ global_embed=global_embed,
446
+ prepend_cond=prepend_cond,
447
+ prepend_cond_mask=prepend_cond_mask,
448
+ mask=mask,
449
+ return_info=return_info,
450
+ **kwargs,
451
+ )
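For reference, the classifier-free guidance branch in DiffusionTransformer.forward above reduces to the following arithmetic. This is a minimal standalone sketch with a hypothetical model_fn, not code from this repository:

import torch

def cfg_step(model_fn, x, t, cond, null_cond, cfg_scale):
    # Double the batch so the conditioned and unconditioned passes run together,
    # as in DiffusionTransformer.forward when cfg_scale != 1.0.
    out = model_fn(torch.cat([x, x]), torch.cat([t, t]), torch.cat([cond, null_cond]))
    cond_out, uncond_out = torch.chunk(out, 2, dim=0)
    # Linear guidance away from the unconditioned prediction.
    return uncond_out + (cond_out - uncond_out) * cfg_scale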
src/YingMusicSinger/utils/stable_audio_tools/factory.py ADDED
@@ -0,0 +1,185 @@
1
+ import json
2
+
3
+
4
+ def create_model_from_config(model_config):
5
+ model_type = model_config.get("model_type", None)
6
+
7
+ assert model_type is not None, "model_type must be specified in model config"
8
+
9
+ if model_type == "autoencoder":
10
+ from .autoencoders import create_autoencoder_from_config
11
+
12
+ return create_autoencoder_from_config(model_config)
13
+ elif model_type == "diffusion_uncond":
14
+ from .diffusion import create_diffusion_uncond_from_config
15
+
16
+ return create_diffusion_uncond_from_config(model_config)
17
+ elif (
18
+ model_type == "diffusion_cond"
19
+ or model_type == "diffusion_cond_inpaint"
20
+ or model_type == "diffusion_prior"
21
+ ):
22
+ from .diffusion import create_diffusion_cond_from_config
23
+
24
+ return create_diffusion_cond_from_config(model_config)
25
+ elif model_type == "diffusion_autoencoder":
26
+ from .autoencoders import create_diffAE_from_config
27
+
28
+ return create_diffAE_from_config(model_config)
29
+ elif model_type == "lm":
30
+ from .lm import create_audio_lm_from_config
31
+
32
+ return create_audio_lm_from_config(model_config)
33
+ else:
34
+ raise NotImplementedError(f"Unknown model type: {model_type}")
35
+
36
+
37
+ def create_model_from_config_path(model_config_path):
38
+ with open(model_config_path) as f:
39
+ model_config = json.load(f)
40
+
41
+ return create_model_from_config(model_config)
42
+
43
+
44
+ def create_pretransform_from_config(pretransform_config, sample_rate):
45
+ pretransform_type = pretransform_config.get("type", None)
46
+
47
+ assert pretransform_type is not None, (
48
+ "type must be specified in pretransform config"
49
+ )
50
+
51
+ if pretransform_type == "autoencoder":
52
+ from .autoencoders import create_autoencoder_from_config
53
+ from .pretransforms import AutoencoderPretransform
54
+
55
+ # Create fake top-level config to pass sample rate to autoencoder constructor
56
+ # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
57
+ autoencoder_config = {
58
+ "sample_rate": sample_rate,
59
+ "model": pretransform_config["config"],
60
+ }
61
+ autoencoder = create_autoencoder_from_config(autoencoder_config)
62
+
63
+ scale = pretransform_config.get("scale", 1.0)
64
+ model_half = pretransform_config.get("model_half", False)
65
+ iterate_batch = pretransform_config.get("iterate_batch", False)
66
+ chunked = pretransform_config.get("chunked", False)
67
+
68
+ pretransform = AutoencoderPretransform(
69
+ autoencoder,
70
+ scale=scale,
71
+ model_half=model_half,
72
+ iterate_batch=iterate_batch,
73
+ chunked=chunked,
74
+ )
75
+ elif pretransform_type == "wavelet":
76
+ from .pretransforms import WaveletPretransform
77
+
78
+ wavelet_config = pretransform_config["config"]
79
+ channels = wavelet_config["channels"]
80
+ levels = wavelet_config["levels"]
81
+ wavelet = wavelet_config["wavelet"]
82
+
83
+ pretransform = WaveletPretransform(channels, levels, wavelet)
84
+ elif pretransform_type == "pqmf":
85
+ from .pretransforms import PQMFPretransform
86
+
87
+ pqmf_config = pretransform_config["config"]
88
+ pretransform = PQMFPretransform(**pqmf_config)
89
+ elif pretransform_type == "dac_pretrained":
90
+ from .pretransforms import PretrainedDACPretransform
91
+
92
+ pretrained_dac_config = pretransform_config["config"]
93
+ pretransform = PretrainedDACPretransform(**pretrained_dac_config)
94
+ elif pretransform_type == "audiocraft_pretrained":
95
+ from .pretransforms import AudiocraftCompressionPretransform
96
+
97
+ audiocraft_config = pretransform_config["config"]
98
+ pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
99
+ else:
100
+ raise NotImplementedError(f"Unknown pretransform type: {pretransform_type}")
101
+
102
+ enable_grad = pretransform_config.get("enable_grad", False)
103
+ pretransform.enable_grad = enable_grad
104
+
105
+ pretransform.eval().requires_grad_(pretransform.enable_grad)
106
+
107
+ return pretransform
108
+
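+ # For reference, an autoencoder pretransform config consumed above is shaped roughly
+ # like this (hypothetical illustrative values; the nested "config" is an autoencoder
+ # model config for create_autoencoder_from_config):
+ #   {
+ #       "type": "autoencoder",   # or "wavelet", "pqmf", "dac_pretrained", "audiocraft_pretrained"
+ #       "scale": 1.0,
+ #       "model_half": False,
+ #       "iterate_batch": False,
+ #       "chunked": False,
+ #       "enable_grad": False,
+ #       "config": { ... },
+ #   }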
109
+
110
+ def create_bottleneck_from_config(bottleneck_config):
111
+ bottleneck_type = bottleneck_config.get("type", None)
112
+
113
+ assert bottleneck_type is not None, "type must be specified in bottleneck config"
114
+
115
+ if bottleneck_type == "tanh":
116
+ from .bottleneck import TanhBottleneck
117
+
118
+ bottleneck = TanhBottleneck()
119
+ elif bottleneck_type == "vae":
120
+ from .bottleneck import VAEBottleneck
121
+
122
+ bottleneck = VAEBottleneck()
123
+ elif bottleneck_type == "rvq":
124
+ from .bottleneck import RVQBottleneck
125
+
126
+ quantizer_params = {
127
+ "dim": 128,
128
+ "codebook_size": 1024,
129
+ "num_quantizers": 8,
130
+ "decay": 0.99,
131
+ "kmeans_init": True,
132
+ "kmeans_iters": 50,
133
+ "threshold_ema_dead_code": 2,
134
+ }
135
+
136
+ quantizer_params.update(bottleneck_config["config"])
137
+
138
+ bottleneck = RVQBottleneck(**quantizer_params)
139
+ elif bottleneck_type == "dac_rvq":
140
+ from .bottleneck import DACRVQBottleneck
141
+
142
+ bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
143
+
144
+ elif bottleneck_type == "rvq_vae":
145
+ from .bottleneck import RVQVAEBottleneck
146
+
147
+ quantizer_params = {
148
+ "dim": 128,
149
+ "codebook_size": 1024,
150
+ "num_quantizers": 8,
151
+ "decay": 0.99,
152
+ "kmeans_init": True,
153
+ "kmeans_iters": 50,
154
+ "threshold_ema_dead_code": 2,
155
+ }
156
+
157
+ quantizer_params.update(bottleneck_config["config"])
158
+
159
+ bottleneck = RVQVAEBottleneck(**quantizer_params)
160
+
161
+ elif bottleneck_type == "dac_rvq_vae":
162
+ from .bottleneck import DACRVQVAEBottleneck
163
+
164
+ bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
165
+ elif bottleneck_type == "l2_norm":
166
+ from .bottleneck import L2Bottleneck
167
+
168
+ bottleneck = L2Bottleneck()
169
+ elif bottleneck_type == "wasserstein":
170
+ from .bottleneck import WassersteinBottleneck
171
+
172
+ bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
173
+ elif bottleneck_type == "fsq":
174
+ from .bottleneck import FSQBottleneck
175
+
176
+ bottleneck = FSQBottleneck(**bottleneck_config["config"])
177
+ else:
178
+ raise NotImplementedError(f"Unknown bottleneck type: {bottleneck_type}")
179
+
180
+ requires_grad = bottleneck_config.get("requires_grad", True)
181
+ if not requires_grad:
182
+ for param in bottleneck.parameters():
183
+ param.requires_grad = False
184
+
185
+ return bottleneck
src/YingMusicSinger/utils/stable_audio_tools/pretransforms.py ADDED
@@ -0,0 +1,425 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import nn
4
+
5
+
6
+ class Pretransform(nn.Module):
7
+ def __init__(self, enable_grad, io_channels, is_discrete):
8
+ super().__init__()
9
+
10
+ self.is_discrete = is_discrete
11
+ self.io_channels = io_channels
12
+ self.encoded_channels = None
13
+ self.downsampling_ratio = None
14
+
15
+ self.enable_grad = enable_grad
16
+
17
+ def encode(self, x):
18
+ raise NotImplementedError
19
+
20
+ def decode(self, z):
21
+ raise NotImplementedError
22
+
23
+ def tokenize(self, x):
24
+ raise NotImplementedError
25
+
26
+ def decode_tokens(self, tokens):
27
+ raise NotImplementedError
28
+
29
+
30
+ class AutoencoderPretransform(Pretransform):
31
+ def __init__(
32
+ self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False
33
+ ):
34
+ super().__init__(
35
+ enable_grad=False,
36
+ io_channels=model.io_channels,
37
+ is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete,
38
+ )
39
+ self.model = model
40
+ self.model.requires_grad_(False).eval()
41
+ self.scale = scale
42
+ self.downsampling_ratio = model.downsampling_ratio
43
+ self.io_channels = model.io_channels
44
+ self.sample_rate = model.sample_rate
45
+
46
+ self.model_half = model_half
47
+ self.iterate_batch = iterate_batch
48
+
49
+ self.encoded_channels = model.latent_dim
50
+ self.latent_dim = model.latent_dim
51
+
52
+ self.chunked = chunked
53
+ self.num_quantizers = (
54
+ model.bottleneck.num_quantizers
55
+ if model.bottleneck is not None and model.bottleneck.is_discrete
56
+ else None
57
+ )
58
+ self.codebook_size = (
59
+ model.bottleneck.codebook_size
60
+ if model.bottleneck is not None and model.bottleneck.is_discrete
61
+ else None
62
+ )
63
+
64
+ if self.model_half:
65
+ self.model.half()
66
+
67
+ def encode(self, x, **kwargs):
68
+ if self.model_half:
69
+ x = x.half()
70
+ self.model.to(torch.float16)
71
+
72
+ encoded = self.model.encode_audio(
73
+ x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs
74
+ )
75
+
76
+ if self.model_half:
77
+ encoded = encoded.float()
78
+
79
+ return encoded / self.scale
80
+
81
+ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
82
+ """
83
+ Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
84
+ If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
85
+ Overlap and chunk_size params are both measured in number of latents (not audio samples)
86
+ and therefore you likely could use the same values with decode_audio.
87
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
88
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
89
+ You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
90
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
91
+ Smaller chunk_size uses less memory, but more compute.
92
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
93
+ For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it produces more chunks (a usage sketch follows this method).
94
+ """
95
+ if not chunked:
96
+ # default behavior. Encode the entire audio in parallel
97
+ return self.encode(audio, **kwargs)
98
+ else:
99
+ # CHUNKED ENCODING
100
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
101
+ samples_per_latent = self.downsampling_ratio
102
+ total_size = audio.shape[2] # in samples
103
+ batch_size = audio.shape[0]
104
+ chunk_size *= samples_per_latent # converting metric in latents to samples
105
+ overlap *= samples_per_latent # converting metric in latents to samples
106
+ hop_size = chunk_size - overlap
107
+ chunks = []
108
+ for i in range(0, total_size - chunk_size + 1, hop_size):
109
+ chunk = audio[:, :, i : i + chunk_size]
110
+ chunks.append(chunk)
111
+ if i + chunk_size != total_size:
112
+ # Final chunk
113
+ chunk = audio[:, :, -chunk_size:]
114
+ chunks.append(chunk)
115
+ chunks = torch.stack(chunks)
116
+ num_chunks = chunks.shape[0]
117
+ # Note: y_size might be a different value from the latent length used in diffusion training
118
+ # because we can encode audio of varying lengths
119
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
120
+ y_size = total_size // samples_per_latent
121
+ # Create an empty latent, we will populate it with chunks as we encode them
122
+ y_final = torch.zeros((batch_size, self.latent_dim, y_size)).to(
123
+ audio.device
124
+ )
125
+ for i in range(num_chunks):
126
+ x_chunk = chunks[i, :]
127
+ # encode the chunk
128
+ y_chunk = self.encode(x_chunk)
129
+ # figure out where to put the audio along the time domain
130
+ if i == num_chunks - 1:
131
+ # final chunk always goes at the end
132
+ t_end = y_size
133
+ t_start = t_end - y_chunk.shape[2]
134
+ else:
135
+ t_start = i * hop_size // samples_per_latent
136
+ t_end = t_start + chunk_size // samples_per_latent
137
+ # remove the edges of the overlaps
138
+ ol = overlap // samples_per_latent // 2
139
+ chunk_start = 0
140
+ chunk_end = y_chunk.shape[2]
141
+ if i > 0:
142
+ # no overlap for the start of the first chunk
143
+ t_start += ol
144
+ chunk_start += ol
145
+ if i < num_chunks - 1:
146
+ # no overlap for the end of the last chunk
147
+ t_end -= ol
148
+ chunk_end -= ol
149
+ # paste the chunked audio into our y_final output audio
150
+ y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
151
+ return y_final
152
+
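+ # Usage sketch for the chunked path above (hypothetical names; the audio length is
+ # assumed to already be padded to a multiple of the downsampling ratio):
+ #   latents = pretransform.encode_audio(audio, chunked=True, overlap=32, chunk_size=128)
+ #   recon = pretransform.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)
+ # With hop_size = chunk_size - overlap, interior chunks are trimmed by overlap // 2
+ # latents on each shared edge before being pasted into y_final, and the final chunk
+ # is always aligned to the end of the signal.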
153
+ def decode(self, z, **kwargs):
154
+ z = z * self.scale
155
+
156
+ if self.model_half:
157
+ z = z.half()
158
+ self.model.to(torch.float16)
159
+
160
+ decoded = self.model.decode_audio(
161
+ z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs
162
+ )
163
+
164
+ if self.model_half:
165
+ decoded = decoded.float()
166
+
167
+ return decoded
168
+
169
+ def decode_audio(
170
+ self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs
171
+ ):
172
+ if not chunked:
173
+ # default behavior. Decode the entire latent in parallel
174
+ return self.decode(latents, **kwargs)
175
+ else:
176
+ # chunked decoding
177
+ hop_size = chunk_size - overlap
178
+ total_size = latents.shape[2]
179
+ batch_size = latents.shape[0]
180
+ chunks = []
181
+ i = 0
182
+ for i in range(0, total_size - chunk_size + 1, hop_size):
183
+ chunk = latents[:, :, i : i + chunk_size]
184
+ chunks.append(chunk)
185
+ if i + chunk_size != total_size:
186
+ # Final chunk
187
+ chunk = latents[:, :, -chunk_size:]
188
+ chunks.append(chunk)
189
+ chunks = torch.stack(chunks)
190
+ num_chunks = chunks.shape[0]
191
+ # samples_per_latent is just the downsampling ratio
192
+ samples_per_latent = self.downsampling_ratio
193
+ # Create an empty waveform; we will populate it with chunks as we decode them
194
+ y_size = total_size * samples_per_latent
195
+ y_final = torch.zeros((batch_size, self.io_channels, y_size)).to(
196
+ latents.device
197
+ )
198
+ for i in range(num_chunks):
199
+ x_chunk = chunks[i, :]
200
+ # decode the chunk
201
+ y_chunk = self.decode(x_chunk)
202
+ # figure out where to put the audio along the time domain
203
+ if i == num_chunks - 1:
204
+ # final chunk always goes at the end
205
+ t_end = y_size
206
+ t_start = t_end - y_chunk.shape[2]
207
+ else:
208
+ t_start = i * hop_size * samples_per_latent
209
+ t_end = t_start + chunk_size * samples_per_latent
210
+ # remove the edges of the overlaps
211
+ ol = (overlap // 2) * samples_per_latent
212
+ chunk_start = 0
213
+ chunk_end = y_chunk.shape[2]
214
+ if i > 0:
215
+ # no overlap for the start of the first chunk
216
+ t_start += ol
217
+ chunk_start += ol
218
+ if i < num_chunks - 1:
219
+ # no overlap for the end of the last chunk
220
+ t_end -= ol
221
+ chunk_end -= ol
222
+ # paste the chunked audio into our y_final output audio
223
+ y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
224
+ return y_final
225
+
226
+ def tokenize(self, x, **kwargs):
227
+ assert self.model.is_discrete, "Cannot tokenize with a continuous model"
228
+
229
+ _, info = self.model.encode(x, return_info=True, **kwargs)
230
+
231
+ return info[self.model.bottleneck.tokens_id]
232
+
233
+ def decode_tokens(self, tokens, **kwargs):
234
+ assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
235
+
236
+ return self.model.decode_tokens(tokens, **kwargs)
237
+
238
+ def load_state_dict(self, state_dict, strict=True):
239
+ self.model.load_state_dict(state_dict, strict=strict)
240
+
241
+
242
+ class WaveletPretransform(Pretransform):
243
+ def __init__(self, channels, levels, wavelet):
244
+ super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
245
+
246
+ from .wavelets import WaveletDecode1d, WaveletEncode1d
247
+
248
+ self.encoder = WaveletEncode1d(channels, levels, wavelet)
249
+ self.decoder = WaveletDecode1d(channels, levels, wavelet)
250
+
251
+ self.downsampling_ratio = 2**levels
252
+ self.io_channels = channels
253
+ self.encoded_channels = channels * self.downsampling_ratio
254
+
255
+ def encode(self, x):
256
+ return self.encoder(x)
257
+
258
+ def decode(self, z):
259
+ return self.decoder(z)
260
+
261
+
262
+ class PQMFPretransform(Pretransform):
263
+ def __init__(self, attenuation=100, num_bands=16):
264
+ # TODO: Fix PQMF to take in in-channels
265
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
266
+ from .pqmf import PQMF
267
+
268
+ self.pqmf = PQMF(attenuation, num_bands)
269
+
270
+ def encode(self, x):
271
+ # x is (Batch x Channels x Time)
272
+ x = self.pqmf.forward(x)
273
+ # pqmf.forward returns (Batch x Channels x Bands x Time)
274
+ # but Pretransform needs Batch x Channels x Time
275
+ # so concatenate channels and bands into one axis
276
+ return rearrange(x, "b c n t -> b (c n) t")
277
+
278
+ def decode(self, x):
279
+ # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
280
+ x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
281
+ # returns (Batch x Channels x Time)
282
+ return self.pqmf.inverse(x)
283
+
284
+
285
+ class PretrainedDACPretransform(Pretransform):
286
+ def __init__(
287
+ self,
288
+ model_type="44khz",
289
+ model_bitrate="8kbps",
290
+ scale=1.0,
291
+ quantize_on_decode: bool = True,
292
+ chunked=True,
293
+ ):
294
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
295
+
296
+ import dac
297
+
298
+ model_path = dac.utils.download(
299
+ model_type=model_type, model_bitrate=model_bitrate
300
+ )
301
+
302
+ self.model = dac.DAC.load(model_path)
303
+
304
+ self.quantize_on_decode = quantize_on_decode
305
+
306
+ if model_type == "44khz":
307
+ self.downsampling_ratio = 512
308
+ else:
309
+ self.downsampling_ratio = 320
310
+
311
+ self.io_channels = 1
312
+
313
+ self.scale = scale
314
+
315
+ self.chunked = chunked
316
+
317
+ self.encoded_channels = self.model.latent_dim
318
+
319
+ self.num_quantizers = self.model.n_codebooks
320
+
321
+ self.codebook_size = self.model.codebook_size
322
+
323
+ def encode(self, x):
324
+ latents = self.model.encoder(x)
325
+
326
+ if self.quantize_on_decode:
327
+ output = latents
328
+ else:
329
+ z, _, _, _, _ = self.model.quantizer(
330
+ latents, n_quantizers=self.model.n_codebooks
331
+ )
332
+ output = z
333
+
334
+ if self.scale != 1.0:
335
+ output = output / self.scale
336
+
337
+ return output
338
+
339
+ def decode(self, z):
340
+ if self.scale != 1.0:
341
+ z = z * self.scale
342
+
343
+ if self.quantize_on_decode:
344
+ z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
345
+
346
+ return self.model.decode(z)
347
+
348
+ def tokenize(self, x):
349
+ return self.model.encode(x)[1]
350
+
351
+ def decode_tokens(self, tokens):
352
+ latents = self.model.quantizer.from_codes(tokens)
353
+ return self.model.decode(latents)
354
+
355
+
356
+ class AudiocraftCompressionPretransform(Pretransform):
357
+ def __init__(
358
+ self,
359
+ model_type="facebook/encodec_32khz",
360
+ scale=1.0,
361
+ quantize_on_decode: bool = True,
362
+ ):
363
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
364
+
365
+ try:
366
+ from audiocraft.models import CompressionModel
367
+ except ImportError:
368
+ raise ImportError(
369
+ "Audiocraft is not installed. Please install audiocraft to use Audiocraft models."
370
+ )
371
+
372
+ self.model = CompressionModel.get_pretrained(model_type)
373
+
374
+ self.quantize_on_decode = quantize_on_decode
375
+
376
+ self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
377
+
378
+ self.sample_rate = self.model.sample_rate
379
+
380
+ self.io_channels = self.model.channels
381
+
382
+ self.scale = scale
383
+
384
+ # self.encoded_channels = self.model.latent_dim
385
+
386
+ self.num_quantizers = self.model.num_codebooks
387
+
388
+ self.codebook_size = self.model.cardinality
389
+
390
+ self.model.to(torch.float16).eval().requires_grad_(False)
391
+
392
+ def encode(self, x):
393
+ assert False, "Audiocraft compression models do not support continuous encoding"
394
+
395
+ # latents = self.model.encoder(x)
396
+
397
+ # if self.quantize_on_decode:
398
+ # output = latents
399
+ # else:
400
+ # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
401
+ # output = z
402
+
403
+ # if self.scale != 1.0:
404
+ # output = output / self.scale
405
+
406
+ # return output
407
+
408
+ def decode(self, z):
409
+ assert False, "Audiocraft compression models do not support continuous decoding"
410
+
411
+ # if self.scale != 1.0:
412
+ # z = z * self.scale
413
+
414
+ # if self.quantize_on_decode:
415
+ # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
416
+
417
+ # return self.model.decode(z)
418
+
419
+ def tokenize(self, x):
420
+ with torch.cuda.amp.autocast(enabled=False):
421
+ return self.model.encode(x.to(torch.float16))[0]
422
+
423
+ def decode_tokens(self, tokens):
424
+ with torch.cuda.amp.autocast(enabled=False):
425
+ return self.model.decode(tokens)