This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +47 -0
- src/YingMusicSinger/infer/YingMusicSinger.py +263 -0
- src/YingMusicSinger/melody/Gconform.py +298 -0
- src/YingMusicSinger/melody/Gconv.py +60 -0
- src/YingMusicSinger/melody/SmoothMelody.py +144 -0
- src/YingMusicSinger/melody/midi_extractor.py +208 -0
- src/YingMusicSinger/models/__init__.py +1 -0
- src/YingMusicSinger/models/dit.py +472 -0
- src/YingMusicSinger/models/model.py +423 -0
- src/YingMusicSinger/models/modules.py +961 -0
- src/YingMusicSinger/utils/f5_tts/g2p/g2p/__init__.py +91 -0
- src/YingMusicSinger/utils/f5_tts/g2p/g2p/chinese_model_g2p.py +209 -0
- src/YingMusicSinger/utils/f5_tts/g2p/g2p/cleaners.py +28 -0
- src/YingMusicSinger/utils/f5_tts/g2p/g2p/english.py +202 -0
- src/YingMusicSinger/utils/f5_tts/g2p/g2p/french.py +149 -0
- src/YingMusicSinger/utils/f5_tts/g2p/g2p/german.py +94 -0
- src/YingMusicSinger/utils/f5_tts/g2p/g2p/korean.py +81 -0
- src/YingMusicSinger/utils/f5_tts/g2p/g2p/mandarin.py +603 -0
- src/YingMusicSinger/utils/f5_tts/g2p/g2p/text_tokenizers.py +82 -0
- src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json +372 -0
- src/YingMusicSinger/utils/f5_tts/g2p/g2p_generation.py +129 -0
- src/YingMusicSinger/utils/f5_tts/g2p/infer_dpo.py +277 -0
- src/YingMusicSinger/utils/f5_tts/g2p/sources/bpmf_2_pinyin.txt +41 -0
- src/YingMusicSinger/utils/f5_tts/g2p/sources/chinese_lexicon.txt +3 -0
- src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/config.json +819 -0
- src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/poly_bert_model.onnx +3 -0
- src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polychar.txt +159 -0
- src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polydict.json +393 -0
- src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polydict_r.json +393 -0
- src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/vocab.txt +0 -0
- src/YingMusicSinger/utils/f5_tts/g2p/sources/pinyin_2_bpmf.txt +429 -0
- src/YingMusicSinger/utils/f5_tts/g2p/utils/front_utils.py +18 -0
- src/YingMusicSinger/utils/f5_tts/g2p/utils/g2p.py +139 -0
- src/YingMusicSinger/utils/f5_tts/g2p/utils/log.py +52 -0
- src/YingMusicSinger/utils/f5_tts/g2p/utils/mls_en.json +335 -0
- src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/LangSegment.py +1251 -0
- src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/__init__.py +24 -0
- src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/utils/__init__.py +0 -0
- src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/utils/num.py +332 -0
- src/YingMusicSinger/utils/stable_audio_tools/__init__.py +0 -0
- src/YingMusicSinger/utils/stable_audio_tools/adp.py +1686 -0
- src/YingMusicSinger/utils/stable_audio_tools/autoencoders.py +975 -0
- src/YingMusicSinger/utils/stable_audio_tools/blocks.py +398 -0
- src/YingMusicSinger/utils/stable_audio_tools/bottleneck copy.py +393 -0
- src/YingMusicSinger/utils/stable_audio_tools/bottleneck.py +393 -0
- src/YingMusicSinger/utils/stable_audio_tools/conditioners.py +664 -0
- src/YingMusicSinger/utils/stable_audio_tools/diffusion.py +740 -0
- src/YingMusicSinger/utils/stable_audio_tools/dit.py +451 -0
- src/YingMusicSinger/utils/stable_audio_tools/factory.py +185 -0
- src/YingMusicSinger/utils/stable_audio_tools/pretransforms.py +425 -0
.gitattributes
ADDED
@@ -0,0 +1,47 @@
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.avi filter=lfs diff=lfs merge=lfs -text
+*.dylib filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.svg filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+*.m4a filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.mkv filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.docx filter=lfs diff=lfs merge=lfs -text
+*.ppt filter=lfs diff=lfs merge=lfs -text
+*.pptx filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.ico filter=lfs diff=lfs merge=lfs -text
+*.flac filter=lfs diff=lfs merge=lfs -text
+*.ogg filter=lfs diff=lfs merge=lfs -text
+*.xls filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.webm filter=lfs diff=lfs merge=lfs -text
+*.doc filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.pdf filter=lfs diff=lfs merge=lfs -text
+*.xlsx filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+*.dll filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.bmp filter=lfs diff=lfs merge=lfs -text
+*.mov filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
+*.hdf5 filter=lfs diff=lfs merge=lfs -text
+*.ttf filter=lfs diff=lfs merge=lfs -text
+src/YingMusicSinger/utils/f5_tts/g2p/sources/chinese_lexicon.txt filter=lfs diff=lfs merge=lfs -text
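These 47 rules route binary and large-media files through Git LFS. A minimal sketch, not part of the diff, of how such extension globs classify a path (the `lfs_patterns` subset and the `is_lfs_tracked` helper are ours; real .gitattributes matching follows gitignore semantics, which `fnmatch` on the basename only approximates):

import: from fnmatch import fnmatch

lfs_patterns = ["*.ckpt", "*.onnx", "*.pt", "*.safetensors", "*.wav"]  # hand-picked subset of the 47 rules

def is_lfs_tracked(path: str) -> bool:
    # Bare glob patterns in .gitattributes match the basename of any path component.
    basename = path.rsplit("/", 1)[-1]
    return any(fnmatch(basename, pattern) for pattern in lfs_patterns)

print(is_lfs_tracked("ckpts/YingMusicSinger_model.pt"))     # True, via *.pt
print(is_lfs_tracked("src/YingMusicSinger/models/dit.py"))  # False, plain text file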
src/YingMusicSinger/infer/YingMusicSinger.py
ADDED
@@ -0,0 +1,263 @@
+import hydra
+import torch
+import torch.nn as nn
+import torchaudio
+from einops import rearrange
+from ema_pytorch import EMA
+from huggingface_hub import PyTorchModelHubMixin
+from omegaconf import OmegaConf
+
+from src.YingMusicSinger.melody.midi_extractor import MIDIExtractor
+from src.YingMusicSinger.models.model import Singer
+from src.YingMusicSinger.utils.cnen_tokenizer import CNENTokenizer
+from src.YingMusicSinger.utils.lrc_align import (
+    align_lrc_put_to_front,
+    align_lrc_sentence_level,
+)
+from src.YingMusicSinger.utils.mel_spectrogram import MelodySpectrogram
+from src.YingMusicSinger.utils.stable_audio_tools.vae_copysyn import StableAudioInfer
+
+
+class YingMusicSinger(nn.Module, PyTorchModelHubMixin):
+    def __init__(
+        self,
+        model_cfg_path,
+        ckpt_path=None,
+        vae_config_path=None,
+        vae_ckpt_path=None,
+        midi_teacher_ckpt_path=None,
+        is_distilled=False,
+        use_ema=True,
+    ):
+        super().__init__()
+        self.cfg = OmegaConf.load(model_cfg_path)
+        model_cls = hydra.utils.get_class(
+            f"src.YingMusicSinger.models.{self.cfg.model.backbone}"
+        )
+        self.melody_input_source = self.cfg.model.melody_input_source
+        self.is_tts_pretrain = self.cfg.model.is_tts_pretrain
+
+        self.model = Singer(
+            transformer=model_cls(
+                **self.cfg.model.arch,
+                text_num_embeds=self.cfg.datasets_cfg.text_num_embeds,
+                mel_dim=self.cfg.model.mel_spec.n_mel_channels,
+                use_guidance_scale_embed=is_distilled,
+            ),
+            mel_spec_kwargs=self.cfg.model.mel_spec,
+            is_tts_pretrain=self.is_tts_pretrain,
+            melody_input_source=self.melody_input_source,
+            cka_disabled=self.cfg.model.cka_disabled,
+            num_channels=None,
+            extra_parameters=self.cfg.extra_parameters,
+            distill_stage=1,
+            use_guidance_scale_embed=is_distilled,
+        )
+
+        self.vae = StableAudioInfer(
+            model_config_path=vae_config_path,
+            model_ckpt_path=vae_ckpt_path,
+        )
+
+        self._need_midi = self.melody_input_source in {
+            "some_pretrain",
+            "some_pretrain_fuzzdisturb",
+            "some_pretrain_postprocess_embedding",
+        }
+        self.midi_teacher = None
+        if self._need_midi:
+            self.midi_teacher = MIDIExtractor()
+            if midi_teacher_ckpt_path is not None:
+                self.midi_teacher._load_form_ckpt(midi_teacher_ckpt_path)
+            for p in self.midi_teacher.parameters():
+                p.requires_grad = False
+
+        self.melody_spectrogram_extract = MelodySpectrogram()
+
+        self.vae_frame_rate = 44100 / 2048
+
+        if ckpt_path is not None:
+            ckpt = torch.load(ckpt_path, map_location="cpu")
+            if use_ema:
+                ema_model = EMA(self.model, include_online_model=False)
+                ema_model.load_state_dict(ckpt["ema_model_state_dict"])
+
+                self.model = ema_model.ema_model
+            else:
+                self.model.load_state_dict(ckpt["model_state_dict"])
+
+        self.cnen_tokenizer = CNENTokenizer()
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def prepare_input(
+        self,
+        ref_audio_path,
+        melody_audio_path,
+        ref_text,
+        target_text,
+        sil_len_to_end,
+        lrc_align_mode,
+    ):
+        ref_audio, ref_audio_sr = torchaudio.load(ref_audio_path)
+        silence = torch.zeros(ref_audio.shape[0], int(ref_audio_sr * sil_len_to_end))
+        ref_wav = torch.cat([ref_audio, silence], dim=1)
+        ref_latent = self.vae.encode_audio(ref_wav, in_sr=ref_audio_sr).transpose(
+            1, 2
+        )  # [B, T, D]
+
+        melody_wav, melody_sr = torchaudio.load(melody_audio_path)
+        melody_latent = self.vae.encode_audio(melody_wav, in_sr=melody_sr).transpose(
+            1, 2
+        )  # [B, T, D]
+
+        midi_in = torch.cat([ref_latent, melody_latent], dim=1)
+        if self.is_tts_pretrain:
+            midi_in = torch.zeros_like(midi_in)
+
+        ref_latent_len = ref_latent.shape[1]
+        total_len = int(ref_latent.shape[1] + melody_latent.shape[1])
+
+        if self._need_midi:
+            ref_mel = self.melody_spectrogram_extract(audio=ref_wav, sr=ref_audio_sr)
+            melody_mel = self.melody_spectrogram_extract(audio=melody_wav, sr=melody_sr)
+            melody_mel_spec = torch.cat([ref_mel, melody_mel], dim=2)
+        else:
+            raise NotImplementedError()
+
+        assert isinstance(ref_text, str) and isinstance(target_text, str)
+        text_list = [ref_text] + [target_text]
+
+        if lrc_align_mode == "put_to_front":
+            lrc_token, _ = align_lrc_put_to_front(
+                tokenizer=self.cnen_tokenizer,
+                lrc_start_times=None,
+                lrc_lines=text_list,
+                total_lens=total_len,
+            )
+        elif lrc_align_mode == "sentence_level":
+            lrc_token, _ = align_lrc_sentence_level(
+                tokenizer=self.cnen_tokenizer,
+                lrc_start_times=[0.0, ref_latent_len / self.vae_frame_rate],
+                lrc_lines=text_list,
+                total_lens=total_len,
+                vae_frame_rate=self.vae_frame_rate,
+            )
+        else:
+            raise ValueError(f"Unsupported lrc_align_mode: {lrc_align_mode}")
+
+        text_tokens = (
+            torch.tensor(lrc_token, dtype=torch.int64).unsqueeze(0).to(self.device)
+        )
+
+        midi_p, bound_p = None, None
+        if self._need_midi:
+            with torch.no_grad():
+                midi_p, bound_p = self.midi_teacher(melody_mel_spec.transpose(1, 2))
+
+        return (
+            ref_latent,
+            ref_latent_len,
+            text_tokens,
+            total_len,
+            midi_in,
+            midi_p,
+            bound_p,
+        )
+
+    def forward(
+        self,
+        ref_audio_path,
+        melody_audio_path,
+        ref_text,
+        target_text,
+        lrc_align_mode: str = "sentence_level",
+        sil_len_to_end: float = 0.5,
+        t_shift: float = 0.5,
+        nfe_step: int = 32,
+        cfg_strength: float = 3.0,
+        seed: int = 666,
+        is_tts_pretrain: bool = False,
+    ):
+        """
+        Args:
+            ref_audio_path: Path to the reference audio (for timbre)
+            melody_audio_path: Path to the melody reference audio (provides target duration and melody information)
+            ref_text: Text corresponding to the reference audio
+            target_text: Target text to be synthesized
+            lrc_align_mode: Lyric alignment mode "sentence_level" | "put_to_front"
+            sil_len_to_end: Duration of silence appended to the end of the reference audio (seconds)
+            t_shift: Sampling time offset
+            nfe_step: ODE sampling steps
+            cfg_strength: CFG strength
+            seed: Random seed
+            is_tts_pretrain: If True, melody is not provided (TTS mode)
+        """
+        ref_latent, ref_latent_len, text_tokens, total_len, midi_in, midi_p, bound_p = (
+            self.prepare_input(
+                ref_audio_path=ref_audio_path,
+                melody_audio_path=melody_audio_path,
+                ref_text=ref_text,
+                target_text=target_text,
+                sil_len_to_end=sil_len_to_end,
+                lrc_align_mode=lrc_align_mode,
+            )
+        )
+
+        assert midi_p is not None and bound_p is not None
+        with torch.inference_mode():
+            generated_latent, _ = self.model.sample(
+                cond=ref_latent,
+                midi_in=midi_in,
+                text=text_tokens,
+                duration=total_len,
+                steps=nfe_step,
+                cfg_strength=cfg_strength,
+                sway_sampling_coef=None,
+                use_epss=False,
+                seed=seed,
+                midi_p=midi_p,
+                t_shift=t_shift,
+                bound_p=bound_p,
+                guidance_scale=cfg_strength,
+            )
+        generated_latent = generated_latent.to(torch.float32)
+        generated_latent = generated_latent[:, ref_latent_len:, :]
+        generated_latent = generated_latent.permute(0, 2, 1)  # [B, D, T]
+
+        generated_audio = self.vae.decode_audio(generated_latent)
+        audio = rearrange(generated_audio, "b d n -> d (b n)")
+
+        audio = audio.to(torch.float32).cpu()
+
+        return audio, 44100
+
+
+if __name__ == "__main__":
+    # === Export to HuggingFace safetensors (optional) ===
+    # model = YingMusicSinger(
+    #     model_cfg_path="src/YingMusicSinger/config/YingMusic_Singer.yaml",
+    #     ckpt_path="ckpts/YingMusicSinger_model.pt",
+    #     vae_config_path="src/YingMusicSinger/config/stable_audio_2_0_vae_20hz_official.json",
+    #     vae_ckpt_path="ckpts/stable_audio_2_0_vae_20hz_official.ckpt",
+    #     midi_teacher_ckpt_path="ckpts/model_ckpt_steps_100000_simplified.ckpt",
+    # )
+    # model.save_pretrained("path/to/save")
+
+    # === Inference Example ===
+    model = YingMusicSinger.from_pretrained("ASLP-lab/YingMusic-Singer")
+    model.to("cuda:0")
+    model.eval()
+
+    waveform, sample_rate = model(
+        ref_audio_path="path/to/ref_audio",  # Timbre reference audio
+        melody_audio_path="path/to/melody_audio",  # Melody-providing singing clip
+        ref_text="oh the reason i hold on",  # Lyrics corresponding to ref_audio
+        target_text="oldest book broken watch|bare feet in grassy spot",  # Modified target lyrics
+        seed=42,
+    )
+
+    torchaudio.save("output.wav", waveform, sample_rate=sample_rate)
+    print("Saved to output.wav")
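The constant `vae_frame_rate = 44100 / 2048` above fixes the time resolution of the VAE latents, so the `duration` passed to `model.sample` is counted in latent frames, not samples or seconds. A quick sketch of that arithmetic (the values follow directly from the code; the helper name is ours):

vae_frame_rate = 44100 / 2048  # ~21.53 latent frames per second of 44.1 kHz audio

def seconds_to_latent_frames(seconds: float) -> int:
    return int(seconds * vae_frame_rate)

# e.g. a 5 s timbre reference (plus the default 0.5 s appended silence)
# and a 10 s melody reference give roughly this total_len for sampling:
total_len = seconds_to_latent_frames(5.0 + 0.5) + seconds_to_latent_frames(10.0)
print(total_len)  # 333 latent frames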
src/YingMusicSinger/melody/Gconform.py
ADDED
@@ -0,0 +1,298 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+
+class GLU(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        out, gate = x.chunk(2, dim=self.dim)
+
+        return out * gate.sigmoid()
+
+
+class conform_conv(nn.Module):
+    def __init__(
+        self, channels: int, kernel_size: int = 31, DropoutL=0.1, bias: bool = True
+    ):
+        super().__init__()
+        self.act2 = nn.SiLU()
+        self.act1 = GLU(1)
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
+        )
+
+        # self.lorder is used to distinguish if it's a causal convolution,
+        # if self.lorder > 0:
+        #    it's a causal convolution, the input will be padded with
+        #    `self.lorder` frames on the left in forward (causal conv impl).
+        # else: it's a symmetrical convolution
+
+        assert (kernel_size - 1) % 2 == 0
+        padding = (kernel_size - 1) // 2
+
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+            bias=bias,
+        )
+
+        self.norm = nn.BatchNorm1d(channels)
+
+        self.pointwise_conv2 = nn.Conv1d(
+            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
+        )
+        self.drop = nn.Dropout(DropoutL) if DropoutL > 0.0 else nn.Identity()
+
+    def forward(self, x):
+        x = x.transpose(1, 2)
+        x = self.act1(self.pointwise_conv1(x))
+        x = self.depthwise_conv(x)
+        x = self.norm(x)
+        x = self.act2(x)
+        x = self.pointwise_conv2(x)
+        return self.drop(x).transpose(1, 2)
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32, conditiondim=None):
+        super().__init__()
+        if conditiondim is None:
+            conditiondim = dim
+
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_q = nn.Linear(dim, hidden_dim, bias=False)
+        self.to_kv = nn.Linear(conditiondim, hidden_dim * 2, bias=False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(
+                hidden_dim,
+                dim,
+            ),
+        )
+
+    def forward(self, q, kv=None, mask=None):
+        # b, c, h, w = x.shape
+        if kv is None:
+            kv = q
+        # q, kv = map(
+        #     lambda t: rearrange(t, "b c t -> b t c", ), (q, kv)
+        # )
+
+        q = self.to_q(q)
+        k, v = self.to_kv(kv).chunk(2, dim=2)
+
+        q, k, v = map(
+            lambda t: rearrange(t, "b t (h c) -> b h t c", h=self.heads), (q, k, v)
+        )
+
+        if mask is not None:
+            mask = mask.unsqueeze(1).unsqueeze(1)
+
+        with torch.backends.cuda.sdp_kernel(enable_math=False):
+            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+
+        out = rearrange(
+            out,
+            "b h t c -> b t (h c) ",
+            h=self.heads,
+        )
+        return self.to_out(out)
+
+
+class conform_ffn(nn.Module):
+    def __init__(self, dim, DropoutL1: float = 0.1, DropoutL2: float = 0.1):
+        super().__init__()
+        self.ln1 = nn.Linear(dim, dim * 4)
+        self.ln2 = nn.Linear(dim * 4, dim)
+        self.drop1 = nn.Dropout(DropoutL1) if DropoutL1 > 0.0 else nn.Identity()
+        self.drop2 = nn.Dropout(DropoutL2) if DropoutL2 > 0.0 else nn.Identity()
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        x = self.ln1(x)
+        x = self.act(x)
+        x = self.drop1(x)
+        x = self.ln2(x)
+        return self.drop2(x)
+
+
+class conform_blocke(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        kernel_size: int = 31,
+        conv_drop: float = 0.1,
+        ffn_latent_drop: float = 0.1,
+        ffn_out_drop: float = 0.1,
+        attention_drop: float = 0.1,
+        attention_heads: int = 4,
+        attention_heads_dim: int = 64,
+    ):
+        super().__init__()
+        self.ffn1 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop)
+        self.ffn2 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop)
+        self.att = Attention(dim, heads=attention_heads, dim_head=attention_heads_dim)
+        self.attdrop = (
+            nn.Dropout(attention_drop) if attention_drop > 0.0 else nn.Identity()
+        )
+        self.conv = conform_conv(
+            dim,
+            kernel_size=kernel_size,
+            DropoutL=conv_drop,
+        )
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.norm4 = nn.LayerNorm(dim)
+        self.norm5 = nn.LayerNorm(dim)
+
+    def forward(
+        self,
+        x,
+        mask=None,
+    ):
+        x = self.ffn1(self.norm1(x)) * 0.5 + x
+
+        x = self.attdrop(self.att(self.norm2(x), mask=mask)) + x
+        x = self.conv(self.norm3(x)) + x
+        x = self.ffn2(self.norm4(x)) * 0.5 + x
+        return self.norm5(x)
+
+        # return x
+
+
+class Gcf(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        kernel_size: int = 31,
+        conv_drop: float = 0.1,
+        ffn_latent_drop: float = 0.1,
+        ffn_out_drop: float = 0.1,
+        attention_drop: float = 0.1,
+        attention_heads: int = 4,
+        attention_heads_dim: int = 64,
+    ):
+        super().__init__()
+        self.att1 = conform_blocke(
+            dim=dim,
+            kernel_size=kernel_size,
+            conv_drop=conv_drop,
+            ffn_latent_drop=ffn_latent_drop,
+            ffn_out_drop=ffn_out_drop,
+            attention_drop=attention_drop,
+            attention_heads=attention_heads,
+            attention_heads_dim=attention_heads_dim,
+        )
+        self.att2 = conform_blocke(
+            dim=dim,
+            kernel_size=kernel_size,
+            conv_drop=conv_drop,
+            ffn_latent_drop=ffn_latent_drop,
+            ffn_out_drop=ffn_out_drop,
+            attention_drop=attention_drop,
+            attention_heads=attention_heads,
+            attention_heads_dim=attention_heads_dim,
+        )
+        self.glu1 = nn.Sequential(nn.Linear(dim, dim * 2), GLU(2))
+        self.glu2 = nn.Sequential(nn.Linear(dim, dim * 2), GLU(2))
+
+    def forward(self, midi, bound):
+        midi = self.att1(midi)
+        bound = self.att2(bound)
+        midis = self.glu1(midi)
+        bounds = self.glu2(bound)
+        return midi + bounds, bound + midis
+
+
+class Gmidi_conform(nn.Module):
+    def __init__(
+        self,
+        lay: int,
+        dim: int,
+        indim: int,
+        outdim: int,
+        use_lay_skip: bool,
+        kernel_size: int = 31,
+        conv_drop: float = 0.1,
+        ffn_latent_drop: float = 0.1,
+        ffn_out_drop: float = 0.1,
+        attention_drop: float = 0.1,
+        attention_heads: int = 4,
+        attention_heads_dim: int = 64,
+    ):
+        super().__init__()
+
+        self.inln = nn.Linear(indim, dim)
+        self.inln1 = nn.Linear(indim, dim)
+        self.outln = nn.Linear(dim, outdim)
+        self.cutheard = nn.Linear(dim, 1)
+        # self.cutheard = nn.Linear(dim, outdim)
+        self.lay = lay
+        self.use_lay_skip = use_lay_skip
+        self.cf_lay = nn.ModuleList(
+            [
+                Gcf(
+                    dim=dim,
+                    kernel_size=kernel_size,
+                    conv_drop=conv_drop,
+                    ffn_latent_drop=ffn_latent_drop,
+                    ffn_out_drop=ffn_out_drop,
+                    attention_drop=attention_drop,
+                    attention_heads=attention_heads,
+                    attention_heads_dim=attention_heads_dim,
+                )
+                for _ in range(lay)
+            ]
+        )
+        self.att1 = conform_blocke(
+            dim=dim,
+            kernel_size=kernel_size,
+            conv_drop=conv_drop,
+            ffn_latent_drop=ffn_latent_drop,
+            ffn_out_drop=ffn_out_drop,
+            attention_drop=attention_drop,
+            attention_heads=attention_heads,
+            attention_heads_dim=attention_heads_dim,
+        )
+        self.att2 = conform_blocke(
+            dim=dim,
+            kernel_size=kernel_size,
+            conv_drop=conv_drop,
+            ffn_latent_drop=ffn_latent_drop,
+            ffn_out_drop=ffn_out_drop,
+            attention_drop=attention_drop,
+            attention_heads=attention_heads,
+            attention_heads_dim=attention_heads_dim,
+        )
+
+    def forward(self, x, mask=None):
+        x1 = x.clone()
+
+        x = self.inln(x)
+        x1 = self.inln1(x1)
+        if mask is not None:
+            x = x.masked_fill(~mask.unsqueeze(-1), 0)
+        for idx, i in enumerate(self.cf_lay):
+            x, x1 = i(x, x1)
+
+            if mask is not None:
+                x = x.masked_fill(~mask.unsqueeze(-1), 0)
+        x, x1 = self.att1(x), self.att2(x1)
+
+        cutprp = self.cutheard(x1)
+        midiout = self.outln(x)
+
+        return midiout, cutprp
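`Gmidi_conform` consumes a mel-like feature sequence and emits two streams: per-frame pitch logits (`outdim` wide) from the melody branch and one note-boundary logit per frame from the segmentation branch. A minimal shape sketch, using the same hyperparameters that `midi_extractor.py` below passes in (assumes the repo root is on `sys.path` so the `src.` import resolves):

import torch
from src.YingMusicSinger.melody.Gconform import Gmidi_conform

model = Gmidi_conform(
    lay=8, dim=512, indim=80, outdim=128, use_lay_skip=True,
    attention_heads=8, attention_heads_dim=64,
)
mel = torch.randn(1, 200, 80)    # [B, T, indim] mel-spectrogram frames
midi_logits, bound_logits = model(mel)
print(midi_logits.shape)         # torch.Size([1, 200, 128]) per-frame pitch logits
print(bound_logits.shape)        # torch.Size([1, 200, 1]) note-boundary logits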
src/YingMusicSinger/melody/Gconv.py
ADDED
@@ -0,0 +1,60 @@
+import torch.nn as nn
+
+
+class GLU(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        out, gate = x.chunk(2, dim=self.dim)
+
+        return out * gate.sigmoid()
+
+
+class conform_conv(nn.Module):
+    def __init__(
+        self, channels: int, kernel_size: int = 31, DropoutL=0.1, bias: bool = True
+    ):
+        super().__init__()
+        self.act2 = nn.SiLU()
+        self.act1 = GLU(1)
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
+        )
+
+        # self.lorder is used to distinguish if it's a causal convolution,
+        # if self.lorder > 0:
+        #    it's a causal convolution, the input will be padded with
+        #    `self.lorder` frames on the left in forward (causal conv impl).
+        # else: it's a symmetrical convolution
+
+        assert (kernel_size - 1) % 2 == 0
+        padding = (kernel_size - 1) // 2
+
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+            bias=bias,
+        )
+
+        self.norm = nn.BatchNorm1d(channels)
+
+        self.pointwise_conv2 = nn.Conv1d(
+            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
+        )
+        self.drop = nn.Dropout(DropoutL) if DropoutL > 0.0 else nn.Identity()
+
+    def forward(self, x):
+        x = x.transpose(1, 2)
+        x = self.act1(self.pointwise_conv1(x))
+        x = self.depthwise_conv(x)
+        x = self.norm(x)
+        x = self.act2(x)
+        x = self.pointwise_conv2(x)
+        return self.drop(x).transpose(1, 2)
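This file repeats the `GLU`/`conform_conv` pair from `Gconform.py` as a standalone module. Because the depthwise convolution uses symmetric padding `(kernel_size - 1) // 2`, the block is sequence-length preserving, which a tiny sketch (same `sys.path` assumption as above) can confirm:

import torch
from src.YingMusicSinger.melody.Gconv import conform_conv

block = conform_conv(channels=64, kernel_size=31)
x = torch.randn(2, 100, 64)  # [B, T, C]
y = block(x)
print(y.shape)               # torch.Size([2, 100, 64]): same [B, T, C] out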
src/YingMusicSinger/melody/SmoothMelody.py
ADDED
@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+
+
+class MIDIFuzzDisturb(nn.Module):
+    """Applies fuzzing perturbations to MIDI latent representations.
+
+    The raw MIDI teacher model output preserves good prosody but causes
+    pronunciation interference. This module mitigates that by applying
+    blur, temporal dropout, and noise to the melody latent.
+    """
+
+    def __init__(
+        self, dim=128, drop_prob=0.3, noise_scale=0.1, blur_kernel=3, drop_type="random"
+    ):
+        super().__init__()
+        self.blur = None
+        self.drop_prob = None
+        self.noise_scale = None
+        self.dim = dim
+        self.drop_type = drop_type
+
+        assert drop_prob is not None
+        assert drop_type is not None
+        if drop_type == "random":
+            # drop_prob is a float
+            if drop_prob != 0:
+                self.drop_prob = drop_prob
+        elif drop_type == "equal_space":
+            # drop_prob is a [drop, keep] list, e.g., [1, 1] means 1 frame drop, 1 frame keep
+            self.drop_prob = drop_prob
+        else:
+            raise ValueError(f"Unknown drop_type: {drop_type}")
+
+        if noise_scale != 0:
+            self.noise_scale = noise_scale
+        if blur_kernel != 0:
+            assert blur_kernel % 2 == 1, f"blur_kernel {blur_kernel} must be odd"
+            self.blur = nn.AvgPool1d(
+                kernel_size=blur_kernel, stride=1, padding=blur_kernel // 2
+            )
+
+    def _create_equal_space_mask(self, batch_size, seq_len, device):
+        """Create an equally-spaced mask cycling [drop, keep] frames."""
+        drop_frames, keep_frames = self.drop_prob
+        cycle_len = drop_frames + keep_frames
+
+        # Pattern: first drop_frames are 0 (drop), next keep_frames are 1 (keep)
+        pattern = torch.cat(
+            [
+                torch.zeros(drop_frames, device=device),
+                torch.ones(keep_frames, device=device),
+            ]
+        )
+
+        # Repeat pattern to cover the full sequence length
+        num_repeats = (seq_len + cycle_len - 1) // cycle_len
+        mask = pattern.repeat(num_repeats)[:seq_len]  # [T]
+
+        # Expand to [B, T, 1]
+        mask = mask.view(1, seq_len, 1).expand(batch_size, -1, -1)
+
+        return mask
+
+    def forward(self, x):
+        # x: [B, T, D=128], pre-sigmoid logits
+        x = torch.sigmoid(x)
+
+        assert x.shape[-1] == self.dim, (
+            f"MIDIFuzzDisturb: expected dim={self.dim}, got {x.shape[-1]}"
+        )
+
+        if self.blur:
+            x = self.blur(x.transpose(1, 2)).transpose(1, 2)
+
+        if self.drop_prob:
+            if self.drop_type == "random":
+                time_mask = (
+                    torch.rand(x.shape[0], x.shape[1], 1, device=x.device)
+                    > self.drop_prob
+                )
+                x = x * time_mask.float()
+            elif self.drop_type == "equal_space":
+                time_mask = self._create_equal_space_mask(
+                    x.shape[0], x.shape[1], x.device
+                )
+                x = x * time_mask.float()
+            else:
+                raise ValueError(f"Unknown drop_type: {self.drop_type}")
+
+        if self.noise_scale:
+            noise = torch.randn_like(x) * self.noise_scale
+            x = x + noise
+
+        return x
+
+
+class MIDIDigitalEmbedding(nn.Module):
+    """Embeds continuous MIDI values into discrete token embeddings.
+
+    Continuous MIDI values in [0, 127] are quantized at a configurable
+    resolution (mark_distinguish_scale) and mapped to learned embeddings.
+    """
+
+    def __init__(self, embed_dim=128, num_classes=128, mark_distinguish_scale=2):
+        super().__init__()
+
+        # num_classes covers the input range [0, 127] plus 2 special tokens
+        self.num_classes = num_classes + 2
+        self.mark_distinguish_scale = mark_distinguish_scale
+        self.embedding_input_num_class = self.num_classes * self.mark_distinguish_scale
+        self.embedding = nn.Embedding(self.embedding_input_num_class, embed_dim)
+
+    def midi_to_class(self, midi_values):
+        """Map continuous MIDI values to discrete class indices.
+
+        Args:
+            midi_values: [B, T] continuous MIDI values, roughly in [0, 127]
+
+        Returns:
+            class_indices: [B, T] discrete class indices
+        """
+        # Round to nearest quantization step
+        # e.g., with scale=2: 0->0, 0.3->1, 0.5->1, 0.8->2, 1.0->2, ...
+        class_indices = torch.round(midi_values * self.mark_distinguish_scale).long()
+
+        # Clamp to valid range
+        class_indices = torch.clamp(
+            class_indices, 0, self.embedding_input_num_class - 1
+        )
+
+        return class_indices
+
+    def forward(self, midi_values):
+        """
+        Args:
+            midi_values: [B, T] continuous MIDI values
+
+        Returns:
+            embeddings: [B, T, embed_dim] embedding vectors
+        """
+        class_indices = self.midi_to_class(midi_values)
+        embeddings = self.embedding(class_indices)
+        return embeddings
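The two modules above are alternative ways of feeding teacher MIDI into the singer: `MIDIFuzzDisturb` sigmoid-activates the teacher logits and then blurs and masks them, while `MIDIDigitalEmbedding` quantizes decoded MIDI values into embedding tokens. A small usage sketch of both (shapes and the `[1, 1]` drop/keep pattern follow the docstrings above; same `sys.path` assumption as the earlier sketches):

import torch
from src.YingMusicSinger.melody.SmoothMelody import MIDIDigitalEmbedding, MIDIFuzzDisturb

logits = torch.randn(1, 8, 128)  # [B, T, 128] pre-sigmoid teacher output
fuzz = MIDIFuzzDisturb(drop_prob=[1, 1], drop_type="equal_space", noise_scale=0.0, blur_kernel=3)
disturbed = fuzz(logits)         # sigmoid -> blur -> every other frame zeroed
print(disturbed.shape)           # torch.Size([1, 8, 128])

embed = MIDIDigitalEmbedding(embed_dim=128, mark_distinguish_scale=2)
midi_values = torch.tensor([[60.0, 60.4, 61.0]])  # [B, T] continuous MIDI pitch
print(embed.midi_to_class(midi_values))           # tensor([[120, 121, 122]]): half-step resolution
print(embed(midi_values).shape)                   # torch.Size([1, 3, 128])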
src/YingMusicSinger/melody/midi_extractor.py
ADDED
@@ -0,0 +1,208 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
+
+from src.YingMusicSinger.melody.Gconform import Gmidi_conform
+
+# midi decoding utils
+
+
+def decode_gaussian_blurred_probs(probs, vmin, vmax, deviation, threshold):
+    num_bins = int(probs.shape[-1])
+    interval = (vmax - vmin) / (num_bins - 1)
+    width = int(3 * deviation / interval)  # 3 * sigma
+    idx = torch.arange(num_bins, device=probs.device)[None, None, :]  # [1, 1, N]
+    idx_values = idx * interval + vmin
+    center = torch.argmax(probs, dim=-1, keepdim=True)  # [B, T, 1]
+    start = torch.clip(center - width, min=0)  # [B, T, 1]
+    end = torch.clip(center + width + 1, max=num_bins)  # [B, T, 1]
+    idx_masks = (idx >= start) & (idx < end)  # [B, T, N]
+    weights = probs * idx_masks  # [B, T, N]
+    product_sum = torch.sum(weights * idx_values, dim=2)  # [B, T]
+    weight_sum = torch.sum(weights, dim=2)  # [B, T]
+    values = product_sum / (
+        weight_sum + (weight_sum == 0)
+    )  # avoid dividing by zero, [B, T]
+    rest = probs.max(dim=-1)[0] < threshold  # [B, T]
+    return values, rest
+
+
+def decode_bounds_to_alignment(bounds, use_diff=True):
+    bounds_step = bounds.cumsum(dim=1).round().long()
+    if use_diff:
+        bounds_inc = (
+            torch.diff(
+                bounds_step,
+                dim=1,
+                prepend=torch.full(
+                    (bounds.shape[0], 1),
+                    fill_value=-1,
+                    dtype=bounds_step.dtype,
+                    device=bounds_step.device,
+                ),
+            )
+            > 0
+        )
+    else:
+        bounds_inc = F.pad(
+            (bounds_step[:, 1:] > bounds_step[:, :-1]), [1, 0], value=True
+        )
+    frame2item = bounds_inc.long().cumsum(dim=1)
+    return frame2item
+
+
+def decode_note_sequence(frame2item, values, masks, threshold=0.5):
+    """
+
+    :param frame2item: [1, 1, 1, 1, 2, 2, 3, 3, 3]
+    :param values:
+    :param masks:
+    :param threshold: minimum ratio of unmasked frames required to be regarded as an unmasked item
+    :return: item_values, item_dur, item_masks
+    """
+    b = frame2item.shape[0]
+    space = frame2item.max() + 1
+
+    item_dur = frame2item.new_zeros(b, space, dtype=frame2item.dtype).scatter_add(
+        1, frame2item, torch.ones_like(frame2item)
+    )[:, 1:]
+    item_unmasked_dur = frame2item.new_zeros(
+        b, space, dtype=frame2item.dtype
+    ).scatter_add(1, frame2item, masks.long())[:, 1:]
+    item_masks = item_unmasked_dur / item_dur >= threshold
+
+    values_quant = values.round().long()
+    histogram = (
+        frame2item.new_zeros(b, space * 128, dtype=frame2item.dtype)
+        .scatter_add(
+            1, frame2item * 128 + values_quant, torch.ones_like(frame2item) * masks
+        )
+        .unflatten(1, [space, 128])[:, 1:, :]
+    )
+    item_values_center = histogram.float().argmax(dim=2).to(dtype=values.dtype)
+    values_center = torch.gather(F.pad(item_values_center, [1, 0]), 1, frame2item)
+    values_near_center = (
+        masks & (values >= values_center - 0.5) & (values <= values_center + 0.5)
+    )
+    item_valid_dur = frame2item.new_zeros(b, space, dtype=frame2item.dtype).scatter_add(
+        1, frame2item, values_near_center.long()
+    )[:, 1:]
+    item_values = values.new_zeros(b, space, dtype=values.dtype).scatter_add(
+        1, frame2item, values * values_near_center
+    )[:, 1:] / (item_valid_dur + (item_valid_dur == 0))
+
+    return item_values, item_dur, item_masks
+
+
+def expand_batch_padded(feature_tensor, counts_tensor, padding_value=0.0):
+    assert feature_tensor.dim() == 2 and counts_tensor.dim() == 2
+
+    lengths = torch.sum(counts_tensor, dim=1)
+
+    feature_tensor = feature_tensor.reshape(-1)
+    counts_tensor = counts_tensor.reshape(-1)
+    expanded_flat = torch.repeat_interleave(feature_tensor, counts_tensor)
+
+    ragged_list = torch.split(expanded_flat, lengths.tolist())
+
+    padded_tensor = pad_sequence(
+        ragged_list, batch_first=True, padding_value=padding_value
+    )
+
+    return padded_tensor, lengths
+
+
+class midi_loss(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.loss = nn.BCELoss()
+
+    def forward(self, x, target):
+        midiout, cutp = x
+        midi_target, cutp_target = target
+
+        cutploss = self.loss(cutp, cutp_target)
+        midiloss = self.loss(midiout, midi_target)
+        return midiloss, cutploss
+
+
+class MIDIExtractor(nn.Module):
+    def __init__(self, in_dim=None, out_dim=None):
+        super().__init__()
+
+        cfg = {
+            "attention_drop": 0.1,
+            "attention_heads": 8,
+            "attention_heads_dim": 64,
+            "conv_drop": 0.1,
+            "dim": 512,
+            "ffn_latent_drop": 0.1,
+            "ffn_out_drop": 0.1,
+            "kernel_size": 31,
+            "lay": 8,
+            "use_lay_skip": True,
+            "indim": 80,
+            "outdim": 128,
+        }
+        if in_dim is not None:
+            cfg["indim"] = in_dim
+        if out_dim is not None:
+            cfg["outdim"] = out_dim
+
+        self.midi_conform = Gmidi_conform(**cfg)
+
+        self.midi_min = 0
+        self.midi_max = 127
+        self.midi_deviation = 1.0
+        self.rest_threshold = 0.1
+
+    def _load_form_ckpt(self, ckpt_path, device="cpu"):
+        from collections import OrderedDict
+
+        if ckpt_path is None:
+            raise ValueError("midi_extractor_path is required")
+
+        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+        prefix_in_ckpt = "model.model"
+        state_dict = OrderedDict(
+            {
+                k.replace(f"{prefix_in_ckpt}.", "midi_conform."): v
+                for k, v in state_dict.items()
+                if k.startswith(f"{prefix_in_ckpt}.")
+            }
+        )
+        self.load_state_dict(state_dict, strict=True)
+        # self.to(device)
+
+    def forward(self, x, mask=None):
+        midi, bound = self.midi_conform(x, mask)
+
+        return midi, bound
+
+    def postprocess(self, midi, bounds, with_expand=False):
+        probs = torch.sigmoid(midi)
+
+        bound_probs = torch.sigmoid(bounds)
+        bound_probs = torch.squeeze(bound_probs, -1)
+
+        masks = torch.ones_like(bound_probs).bool()
+        # Avoid in-place ops on tensors needed for autograd (outputs of SigmoidBackward)
+        probs = probs * masks[..., None]
+        bound_probs = bound_probs * masks
+        unit2note_pred = decode_bounds_to_alignment(bound_probs) * masks
+        midi_pred, rest_pred = decode_gaussian_blurred_probs(
+            probs,
+            vmin=self.midi_min,
+            vmax=self.midi_max,
+            deviation=self.midi_deviation,
+            threshold=self.rest_threshold,
+        )
+        note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence(
+            unit2note_pred, midi_pred, ~rest_pred & masks
+        )
+        if not with_expand:
+            return note_midi_pred, note_dur_pred
+
+        note_midi_expand, _ = expand_batch_padded(note_midi_pred, note_dur_pred)
+        return note_midi_expand, None
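`postprocess` chains the decoders above: boundary probabilities become a frame-to-note index map, blurred pitch probabilities become per-frame MIDI values, and `decode_note_sequence` aggregates frames into notes. The boundary step is easiest to see on a toy tensor (a sketch; the input values are chosen by hand):

import torch
from src.YingMusicSinger.melody.midi_extractor import decode_bounds_to_alignment

bound_probs = torch.tensor([[0.6, 0.2, 0.9, 0.1, 0.8]])  # [B, T] boundary probabilities
# cumsum -> [0.6, 0.8, 1.7, 1.8, 2.6]; rounding -> [1, 1, 2, 2, 3]; each
# increment starts a new note, so frames map to note ids 1, 1, 2, 2, 3.
print(decode_bounds_to_alignment(bound_probs))  # tensor([[1, 1, 2, 2, 3]])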
src/YingMusicSinger/models/__init__.py
ADDED
@@ -0,0 +1 @@
+from .dit import DiT
src/YingMusicSinger/models/dit.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ein notation:
|
| 3 |
+
b - batch
|
| 4 |
+
n - sequence
|
| 5 |
+
nt - text sequence
|
| 6 |
+
nw - raw wave length
|
| 7 |
+
d - dimension
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from torch import nn
|
| 15 |
+
from x_transformers.x_transformers import RotaryEmbedding
|
| 16 |
+
|
| 17 |
+
from src.YingMusicSinger.models.modules import (
|
| 18 |
+
AdaLayerNorm_Final,
|
| 19 |
+
ConvNeXtV2Block,
|
| 20 |
+
ConvPositionEmbedding,
|
| 21 |
+
DiTBlock,
|
| 22 |
+
TimestepGuidanceEmbedding,
|
| 23 |
+
get_pos_embed_indices,
|
| 24 |
+
precompute_freqs_cis,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Text embedding
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TextEmbedding(nn.Module):
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
text_num_embeds,
|
| 35 |
+
text_dim,
|
| 36 |
+
mask_padding=False,
|
| 37 |
+
average_upsampling=False,
|
| 38 |
+
conv_layers=0,
|
| 39 |
+
conv_mult=2,
|
| 40 |
+
):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.text_embed = nn.Embedding(
|
| 43 |
+
text_num_embeds + 1, text_dim
|
| 44 |
+
) # index 0 reserved as filler token
|
| 45 |
+
|
| 46 |
+
self.mask_padding = mask_padding
|
| 47 |
+
self.average_upsampling = average_upsampling # ZipVoice-style late average upsampling (after text encoder)
|
| 48 |
+
if average_upsampling:
|
| 49 |
+
assert mask_padding, (
|
| 50 |
+
"text_embedding_average_upsampling requires text_mask_padding to be True"
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
if conv_layers > 0:
|
| 54 |
+
self.extra_modeling = True
|
| 55 |
+
self.precompute_max_pos = 4096 # ~44s of 24kHz audio
|
| 56 |
+
self.register_buffer(
|
| 57 |
+
"freqs_cis",
|
| 58 |
+
precompute_freqs_cis(text_dim, self.precompute_max_pos),
|
| 59 |
+
persistent=False,
|
| 60 |
+
)
|
| 61 |
+
self.text_blocks = nn.Sequential(
|
| 62 |
+
*[
|
| 63 |
+
ConvNeXtV2Block(text_dim, text_dim * conv_mult)
|
| 64 |
+
for _ in range(conv_layers)
|
| 65 |
+
]
|
| 66 |
+
)
|
| 67 |
+
else:
|
| 68 |
+
self.extra_modeling = False
|
| 69 |
+
|
| 70 |
+
print(
|
| 71 |
+
f"[info] TextEmbedding: mask_padding={mask_padding}, average_upsampling={average_upsampling}, conv_layers={conv_layers}"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
|
| 75 |
+
batch, text_len, text_dim = text.shape
|
| 76 |
+
|
| 77 |
+
if audio_mask is None:
|
| 78 |
+
audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
|
| 79 |
+
valid_mask = audio_mask & text_mask
|
| 80 |
+
audio_lens = audio_mask.sum(dim=1) # [batch]
|
| 81 |
+
valid_lens = valid_mask.sum(dim=1) # [batch]
|
| 82 |
+
|
| 83 |
+
upsampled_text = torch.zeros_like(text)
|
| 84 |
+
|
| 85 |
+
for i in range(batch):
|
| 86 |
+
audio_len = audio_lens[i].item()
|
| 87 |
+
valid_len = valid_lens[i].item()
|
| 88 |
+
|
| 89 |
+
if valid_len == 0:
|
| 90 |
+
continue
|
| 91 |
+
|
| 92 |
+
valid_ind = torch.where(valid_mask[i])[0]
|
| 93 |
+
valid_data = text[i, valid_ind, :] # [valid_len, text_dim]
|
| 94 |
+
|
| 95 |
+
base_repeat = audio_len // valid_len
|
| 96 |
+
remainder = audio_len % valid_len
|
| 97 |
+
|
| 98 |
+
indices = []
|
| 99 |
+
for j in range(valid_len):
|
| 100 |
+
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
|
| 101 |
+
indices.extend([j] * repeat_count)
|
| 102 |
+
|
| 103 |
+
indices = torch.tensor(
|
| 104 |
+
indices[:audio_len], device=text.device, dtype=torch.long
|
| 105 |
+
)
|
| 106 |
+
upsampled = valid_data[indices] # [audio_len, text_dim]
|
| 107 |
+
|
| 108 |
+
upsampled_text[i, :audio_len, :] = upsampled
|
| 109 |
+
|
| 110 |
+
return upsampled_text
|
| 111 |
+
    def forward(
        self,
        text: int["b nt"],
        seq_len,
        drop_text=False,
        audio_mask: bool["b n"] | None = None,
    ):  # noqa: F722
        # Text tokens start from 0; shift by 1 so that 0 is never a valid token
        text = text + 1
        # Note: 1 is used as the PAD token
        text = text[
            :, :seq_len
        ]  # Truncate if text tokens exceed mel spectrogram length
        batch, text_len = text.shape[0], text.shape[1]
        text = F.pad(text, (0, seq_len - text_len), value=1)

        if self.mask_padding:
            text_mask = text == 1
        else:
            text_mask = torch.zeros(
                (batch, seq_len), device=text.device, dtype=torch.bool
            )

        if drop_text:  # CFG for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # Optional extra modeling
        if self.extra_modeling:
            # Sinusoidal positional embedding
            batch_start = torch.zeros((batch,), device=text.device, dtype=torch.long)
            pos_idx = get_pos_embed_indices(
                batch_start, seq_len, max_pos=self.precompute_max_pos
            )
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # ConvNeXtV2 blocks
            if self.mask_padding:
                text = text.masked_fill(
                    text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
                )
                for block in self.text_blocks:
                    text = block(text)
                    text = text.masked_fill(
                        text_mask.unsqueeze(-1).expand(-1, -1, text.size(-1)), 0.0
                    )
            else:
                text = self.text_blocks(text)

        if self.average_upsampling:
            text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)

        return text, text_mask
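A standalone sketch (hypothetical token ids, not from the repo) of the shift-and-pad convention in forward: ids are shifted by one, truncated or right-padded to seq_len with the pad value 1, and the padding mask is then simply text == 1.

import torch
import torch.nn.functional as F

text = torch.tensor([[4, 7, 2]])  # hypothetical raw token ids, shape [b=1, nt=3]
seq_len = 6
text = text + 1                   # -> [[5, 8, 3]]
text = text[:, :seq_len]          # truncate if longer than seq_len
text = F.pad(text, (0, seq_len - text.shape[1]), value=1)
print(text)        # tensor([[5, 8, 3, 1, 1, 1]])
print(text == 1)   # padding mask: [False, False, False, True, True, True]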
# Noised input audio and context mixing embedding


class InputEmbedding(nn.Module):
    def __init__(self, mel_dim, text_dim, out_dim, midi_dim=128):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim + midi_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
        self.midi_proj = nn.Linear(128, 128)

    def forward(
        self,
        x: float["b n d"],  # noqa: F722
        cond: float["b n d"],  # noqa: F722
        text_embed: float["b n d"],  # noqa: F722
        midi,
        drop_audio_cond=False,
        drop_midi=False,
    ):
        if drop_audio_cond:  # CFG for conditioning audio
            cond = torch.zeros_like(cond)

        midi = self.midi_proj(midi)

        if drop_midi:  # CFG for melody
            midi = torch.zeros_like(midi)

        x = self.proj(torch.cat((x, cond, text_embed, midi), dim=-1))
        x = self.conv_pos_embed(x) + x
        return x
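A shape sketch with hypothetical sizes showing what InputEmbedding.proj consumes: the noised latent, conditioning latent, text embedding, and projected melody are concatenated channel-wise, so the linear layer sees mel_dim * 2 + text_dim + midi_dim features per frame.

import torch

b, n = 2, 50
mel_dim, text_dim, midi_dim = 100, 100, 128  # hypothetical sizes
x = torch.randn(b, n, mel_dim)           # noised audio latent
cond = torch.randn(b, n, mel_dim)        # masked conditioning audio
text_embed = torch.randn(b, n, text_dim)
midi = torch.randn(b, n, midi_dim)       # after midi_proj
cat = torch.cat((x, cond, text_embed, midi), dim=-1)
print(cat.shape)  # torch.Size([2, 50, 428]) == mel_dim * 2 + text_dim + midi_dim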
# Transformer backbone using DiT blocks


class DiT(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_num_embeds=256,
        text_dim=None,
        n_f0_bins=512,
        text_mask_padding=True,
        text_embedding_average_upsampling=False,
        qk_norm=None,
        conv_layers=0,
        pe_attn_head=None,
        attn_backend="torch",  # "torch" | "flash_attn"
        attn_mask_enabled=False,
        long_skip_connection=False,
        checkpoint_activations=False,
        use_guidance_scale_embed: bool = False,
        guidance_scale_embed_dim: int = 192,
    ):
        super().__init__()

        self.time_embed = TimestepGuidanceEmbedding(
            dim,
            use_guidance_scale_embed=use_guidance_scale_embed,
            guidance_scale_embed_dim=guidance_scale_embed_dim,
        )
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed_p = TextEmbedding(
            text_num_embeds,
            text_dim,
            mask_padding=text_mask_padding,
            average_upsampling=text_embedding_average_upsampling,
            conv_layers=conv_layers,
        )
        self.text_cond, self.text_uncond = None, None  # text cache
        self.input_embed_with_midi = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)
        self.use_guidance_scale_embed = use_guidance_scale_embed

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [
                DiTBlock(
                    dim=dim,
                    heads=heads,
                    dim_head=dim_head,
                    ff_mult=ff_mult,
                    dropout=dropout,
                    qk_norm=qk_norm,
                    pe_attn_head=pe_attn_head,
                    attn_backend=attn_backend,
                    attn_mask_enabled=attn_mask_enabled,
                )
                for _ in range(depth)
            ]
        )
        self.long_skip_connection = (
            nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
        )

        self.norm_out = AdaLayerNorm_Final(dim)  # Final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

        self.checkpoint_activations = checkpoint_activations

        self.initialize_weights()

    def initialize_weights(self):
        # Zero-out AdaLN layers in DiT blocks
        for block in self.transformer_blocks:
            nn.init.constant_(block.attn_norm.linear.weight, 0)
            nn.init.constant_(block.attn_norm.linear.bias, 0)

        # Zero-out output layers
        nn.init.constant_(self.norm_out.linear.weight, 0)
        nn.init.constant_(self.norm_out.linear.bias, 0)
        nn.init.constant_(self.proj_out.weight, 0)
        nn.init.constant_(self.proj_out.bias, 0)

        nn.init.zeros_(self.input_embed_with_midi.midi_proj.weight)
        nn.init.zeros_(self.input_embed_with_midi.midi_proj.bias)

    def ckpt_wrapper(self, module):
        # Ref: https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs

        return ckpt_forward
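Zero-initializing midi_proj and the AdaLN/output layers makes those paths inert at step zero, so melody conditioning fades in during training rather than injecting noise from the start. A quick standalone check of that property:

import torch
from torch import nn

midi_proj = nn.Linear(128, 128)
nn.init.zeros_(midi_proj.weight)
nn.init.zeros_(midi_proj.bias)
with torch.no_grad():
    out = midi_proj(torch.randn(4, 128))
print(out.abs().max())  # tensor(0.): the melody branch contributes nothing at first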
    def get_input_embed(
        self,
        x,  # b n d
        cond,  # b n d
        text,  # b nt
        midi,  # b n
        drop_audio_cond: bool = False,
        drop_text: bool = False,
        drop_midi: bool = False,
        cache: bool = True,
        audio_mask: bool["b n"] | None = None,  # noqa: F722
    ):
        seq_len = x.shape[1]

        if cache:
            if drop_text:
                if self.text_uncond is None:
                    self.text_uncond, _ = self.text_embed_p(
                        text, seq_len, drop_text=True, audio_mask=audio_mask
                    )
                text_embed = self.text_uncond
            else:
                if self.text_cond is None:
                    self.text_cond, _ = self.text_embed_p(
                        text, seq_len, drop_text=False, audio_mask=audio_mask
                    )
                text_embed = self.text_cond
        else:
            text_embed, text_mask = self.text_embed_p(
                text, seq_len, drop_text=drop_text, audio_mask=audio_mask
            )

        if midi is None:
            midi = torch.zeros(
                (x.size(0), x.size(1)), device=x.device, dtype=torch.long
            )

        x = self.input_embed_with_midi(
            x,
            cond,
            text_embed,
            midi,
            drop_audio_cond=drop_audio_cond,
            drop_midi=drop_midi,
        )

        return x, None
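Note the cache contract here: text_cond / text_uncond stay valid only while text and seq_len are unchanged (e.g. across the repeated velocity evaluations of one ODE solve), which is why the sampler calls clear_cache() afterwards. A toy standalone sketch (not the real class) of why a stale cache is dangerous:

import torch

class TinyTextCache:
    # Editor's toy stand-in for the DiT text cache, not the real class.
    def __init__(self):
        self.text_cond = None

    def embed(self, text, cache=True):
        if cache:
            if self.text_cond is None:
                self.text_cond = text.float() * 2  # stand-in for text_embed_p
            return self.text_cond
        return text.float() * 2

    def clear_cache(self):
        self.text_cond = None

m = TinyTextCache()
a = m.embed(torch.tensor([1, 2, 3]))  # computed and cached
b = m.embed(torch.tensor([9, 9, 9]))  # stale: the cache still holds the first text
print(torch.equal(a, b))              # True -> call clear_cache() before a new prompt
m.clear_cache()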
    def clear_cache(self):
        self.text_cond, self.text_uncond = None, None

    def forward(
        self,
        x: float["b n d"],  # Noised input audio  # noqa: F722
        cond: float["b n d"],  # Masked conditioning audio  # noqa: F722
        text: int["b nt"],  # Text tokens  # noqa: F722
        time: float["b"] | float[""],  # Timestep  # noqa: F821 F722
        midi: float["b n"] | None = None,  # Melody latent  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        drop_audio_cond: bool = False,  # CFG for conditioning audio
        drop_text: bool = False,  # CFG for text
        drop_midi: bool = False,  # CFG for melody
        cfg_infer: bool = False,  # CFG inference: pack cond & uncond forward
        cache: bool = False,
        guidance_scale=None,
        cfg_infer_ids=None,  # tuple(bool): (x_cond, x_uncond, x_uncond_cc, x_drop_all_cond)
    ):
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)

        # Timestep embedding (with optional distillation guidance scale)
        t = self.time_embed(time, guidance_scale=guidance_scale)

        if cfg_infer:  # Pack cond & uncond forward: b n d -> Kb n d
            x_cond, x_uncond, x_uncond_cc, x_drop_all_cond = None, None, None, None
            if cfg_infer_ids is None or cfg_infer_ids[0]:
                x_cond, _ = self.get_input_embed(
                    x,
                    cond,
                    text,
                    midi,
                    drop_audio_cond=False,
                    drop_text=False,
                    drop_midi=False,
                    cache=cache,
                    audio_mask=mask,
                )
            if cfg_infer_ids is None or cfg_infer_ids[1]:
                x_uncond, _ = self.get_input_embed(
                    x,
                    cond,
                    text,
                    midi,
                    drop_audio_cond=True,
                    drop_text=False,
                    drop_midi=False,
                    cache=cache,
                    audio_mask=mask,
                )
            if cfg_infer_ids is None or cfg_infer_ids[2]:
                x_uncond_cc, _ = self.get_input_embed(
                    x,
                    cond,
                    text,
                    midi,
                    drop_audio_cond=False,
                    drop_text=True,
                    drop_midi=True,
                    cache=cache,
                    audio_mask=mask,
                )
            if cfg_infer_ids is None or cfg_infer_ids[3]:
                x_drop_all_cond, _ = self.get_input_embed(
                    x,
                    cond,
                    text,
                    midi,
                    drop_audio_cond=True,
                    drop_text=True,
                    drop_midi=True,
                    cache=cache,
                    audio_mask=mask,
                )

            # Concatenate only non-None tensors
            x_list = [
                xi
                for xi in [x_cond, x_uncond, x_uncond_cc, x_drop_all_cond]
                if xi is not None
            ]
            x = torch.cat(x_list, dim=0)
            t = torch.cat([t] * len(x_list), dim=0)
            mask = torch.cat([mask] * len(x_list), dim=0) if mask is not None else None
        else:
            x, text_inner_sim_matrix = self.get_input_embed(
                x,
                cond,
                text,
                midi,
                drop_audio_cond=drop_audio_cond,
                drop_text=drop_text,
                drop_midi=drop_midi,
                cache=cache,
                audio_mask=mask,
            )

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        if self.long_skip_connection is not None:
            residual = x

        # Mask is all zeros during inference
        for block in self.transformer_blocks:
            if self.checkpoint_activations:
                x = torch.utils.checkpoint.checkpoint(
                    self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False
                )
            else:
                x = block(x, t, mask=mask, rope=rope)

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        return output, (text_inner_sim_matrix if not cfg_infer else None)
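Downstream, the packed CFG output is split back on the batch axis and recombined. With cfg_infer_ids=(True, False, False, True) the caller in model.py computes pred + (pred - pred_drop_all_cond) * guidance_scale. A standalone sketch of that recombination with hypothetical shapes:

import torch

# Hypothetical packed output: 2 * b items stacked on the batch axis.
b, n, d = 2, 8, 4
pred_cfg = torch.randn(2 * b, n, d)
pred, pred_drop_all_cond = torch.chunk(pred_cfg, 2, dim=0)
guidance_scale = 2.0
guided = pred + (pred - pred_drop_all_cond) * guidance_scale
print(guided.shape)  # torch.Size([2, 8, 4]) -- back to the original batch size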
src/YingMusicSinger/models/model.py
ADDED
@@ -0,0 +1,423 @@
from __future__ import annotations

from typing import Callable

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torchdiffeq import odeint

from src.YingMusicSinger.melody.midi_extractor import MIDIExtractor
from src.YingMusicSinger.utils.common import (
    default,
    exists,
    get_epss_timesteps,
    lens_to_mask,
)


def interpolation_midi_continuous(midi_p, bound_p, total_len):
    """Temporally interpolate 3D melody latent to match target length."""
    if midi_p.shape[1] != total_len:
        midi = (
            F.interpolate(
                midi_p.clone().detach().transpose(1, 2),
                size=total_len,
                mode="linear",
                align_corners=False,
            )
            .transpose(1, 2)
            .clone()
            .detach()
        )
        if bound_p is not None:
            midi_bound = (
                F.interpolate(
                    bound_p.clone().detach().transpose(1, 2),
                    size=total_len,
                    mode="linear",
                    align_corners=False,
                )
                .transpose(1, 2)
                .clone()
                .detach()
            )
    else:
        midi = midi_p.clone().detach()
        if bound_p is not None:
            midi_bound = bound_p.clone().detach()
    if bound_p is not None:
        return midi, midi_bound
    else:
        return midi
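A runnable sketch of the interpolation path with hypothetical lengths: F.interpolate(mode="linear") expects [b, c, t], which is why the latent is transposed on the way in and back out.

import torch
import torch.nn.functional as F

midi_p = torch.randn(1, 40, 128)  # [b, t, c] melody latent, hypothetical length
total_len = 100
midi = (
    F.interpolate(
        midi_p.transpose(1, 2), size=total_len, mode="linear", align_corners=False
    )
    .transpose(1, 2)
)
print(midi.shape)  # torch.Size([1, 100, 128])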
def interpolation_midi_continuous_2_dim(midi_p, bound_p, total_len):
    """Temporally interpolate 2D melody latent to match target length."""
    assert len(midi_p.shape) == 2

    if midi_p.shape[1] != total_len:
        midi = (
            F.interpolate(
                midi_p.unsqueeze(2).clone().detach().transpose(1, 2),
                size=total_len,
                mode="linear",
                align_corners=False,
            )
            .transpose(1, 2)
            .clone()
            .detach()
        )
        if bound_p is not None:  # explicit None check avoids ambiguous tensor truthiness
            midi_bound = (
                F.interpolate(
                    bound_p.unsqueeze(2).clone().detach().transpose(1, 2),
                    size=total_len,
                    mode="linear",
                    align_corners=False,
                )
                .transpose(1, 2)
                .clone()
                .detach()
            )
    else:
        # Keep a trailing channel dim so the shared squeeze below works on both paths
        midi = midi_p.clone().detach().unsqueeze(2)
        if bound_p is not None:
            midi_bound = bound_p.clone().detach().unsqueeze(2)
    if bound_p is not None:
        return midi.squeeze(2), midi_bound.squeeze(2)
    else:
        return midi.squeeze(2)
class Singer(nn.Module):
    def __init__(
        self,
        transformer: nn.Module,
        is_tts_pretrain,
        melody_input_source,
        cka_disabled,
        distill_stage,
        use_guidance_scale_embed,
        sigma=0.0,
        odeint_kwargs: dict = dict(method="euler"),
        audio_drop_prob=0.3,
        cond_drop_prob=0.2,
        num_channels=None,
        mel_spec_module: nn.Module | None = None,
        mel_spec_kwargs: dict = dict(),
        frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
        extra_parameters=None,
    ):
        super().__init__()

        self.is_tts_pretrain = is_tts_pretrain

        if distill_stage is None:
            self.distill_stage = 0
        else:
            self.distill_stage = int(distill_stage)

        self.use_guidance_scale_embed = use_guidance_scale_embed

        assert melody_input_source in {
            "student_model",
            "some_pretrain",
            "some_pretrain_fuzzdisturb",
            "some_pretrain_postprocess_embedding",
            "none",
        }
        from src.YingMusicSinger.melody.SmoothMelody import MIDIFuzzDisturb

        if melody_input_source == "some_pretrain_fuzzdisturb":
            self.smoothMelody_MIDIFuzzDisturb = MIDIFuzzDisturb(
                dim=extra_parameters.some_pretrain_fuzzdisturb.dim,
                drop_prob=extra_parameters.some_pretrain_fuzzdisturb.drop_prob,
                noise_scale=extra_parameters.some_pretrain_fuzzdisturb.noise_scale,
                blur_kernel=extra_parameters.some_pretrain_fuzzdisturb.blur_kernel,
                drop_type=extra_parameters.some_pretrain_fuzzdisturb.drop_type,
            )
        from src.YingMusicSinger.melody.SmoothMelody import MIDIDigitalEmbedding

        if melody_input_source == "some_pretrain_postprocess_embedding":
            self.smoothMelody_MIDIDigitalEmbedding = MIDIDigitalEmbedding(
                embed_dim=extra_parameters.some_pretrain_postprocess_embedding.embed_dim,
                num_classes=extra_parameters.some_pretrain_postprocess_embedding.num_classes,
                mark_distinguish_scale=extra_parameters.some_pretrain_postprocess_embedding.mark_distinguish_scale,
            )

        self.melody_input_source = melody_input_source
        self.cka_disabled = cka_disabled

        self.frac_lengths_mask = frac_lengths_mask

        # NOTE: assumes attribute-style access (e.g. an omegaconf DictConfig), not a plain dict
        num_channels = default(num_channels, mel_spec_kwargs.n_mel_channels)
        self.num_channels = num_channels

        # Classifier-free guidance drop probabilities
        self.audio_drop_prob = audio_drop_prob
        self.cond_drop_prob = cond_drop_prob

        # Transformer backbone
        self.transformer = transformer
        dim = transformer.dim
        self.dim = dim

        # Conditional flow matching
        self.sigma = sigma
        self.odeint_kwargs = odeint_kwargs

        # Melody extractor
        self.midi_extractor = MIDIExtractor(in_dim=num_channels)

    @property
    def device(self):
        return next(self.parameters()).device

    @torch.no_grad()
    def sample(
        self,
        cond: float["b n d"] | float["b nw"],  # noqa: F722
        text: int["b nt"] | list[str],  # noqa: F722
        duration: int | int["b"] | None = None,  # noqa: F821
        *,
        midi_in: float["b n d"] | None = None,
        lens: int["b"] | None = None,  # noqa: F821
        steps=32,
        cfg_strength=1.0,
        sway_sampling_coef=None,
        seed: int | None = None,
        max_duration=4096,  # Maximum total length (including ICL prompt), ~190s
        vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None,  # noqa: F722
        use_epss=True,
        no_ref_audio=False,
        duplicate_test=False,
        t_inter=0.1,
        t_shift=1.0,  # Sampling timestep shift (ZipVoice-style)
        guidance_scale=None,
        edit_mask=None,
        midi_p=None,
        bound_p=None,
        enable_melody_control=True,
    ):
        self.eval()

        assert isinstance(cond, torch.Tensor)
        assert not edit_mask, "edit_mask is not supported in this mode"
        assert not duplicate_test, "duplicate_test is not supported in this mode"

        if self.melody_input_source == "student_model":
            assert midi_p is None and bound_p is None
        elif self.melody_input_source in {
            "some_pretrain",
            "some_pretrain_fuzzdisturb",
            "some_pretrain_postprocess_embedding",
        }:
            assert midi_p is not None and bound_p is not None
        elif self.melody_input_source == "none":
            assert midi_p is None and bound_p is None
        else:
            raise ValueError(
                f"Unsupported melody_input_source: {self.melody_input_source}"
            )

        # duration is the total latent sequence length
        assert duration

        cond = cond.to(next(self.parameters()).dtype)

        # Extract or interpolate melody representation
        if self.melody_input_source == "student_model":
            midi, midi_bound = self.midi_extractor(midi_in)

        elif self.melody_input_source == "some_pretrain":
            midi, midi_bound = interpolation_midi_continuous(
                midi_p=midi_p, bound_p=bound_p, total_len=text.shape[1]
            )
        elif self.melody_input_source == "some_pretrain_fuzzdisturb":
            midi, midi_bound = interpolation_midi_continuous(
                midi_p=midi_p, bound_p=bound_p, total_len=text.shape[1]
            )
            midi = self.smoothMelody_MIDIFuzzDisturb(midi)

        elif self.melody_input_source == "some_pretrain_postprocess_embedding":
            midi_after_postprocess, _ = self.midi_extractor.postprocess(
                midi=midi_p, bounds=bound_p, with_expand=True
            )
            midi = interpolation_midi_continuous_2_dim(
                midi_p=midi_after_postprocess, bound_p=None, total_len=text.shape[1]
            )
            midi = self.smoothMelody_MIDIDigitalEmbedding(midi)
            midi_bound = None

        elif self.melody_input_source == "none":
            midi = torch.zeros(
                text.shape[0], text.shape[1], 128, dtype=cond.dtype, device=text.device
            )
            midi_bound = None
        else:
            raise NotImplementedError()

        batch, cond_seq_len, device = *cond.shape[:2], cond.device
        if not exists(lens):
            lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)

        assert isinstance(text, torch.Tensor)

        cond_mask = lens_to_mask(lens)

        if edit_mask is not None:
            cond_mask = cond_mask & edit_mask

        if isinstance(duration, int):
            duration = torch.full((batch,), duration, device=device, dtype=torch.long)

        # Duration must be at least max(text_len, audio_prompt_len) + 1
        duration = torch.maximum(
            torch.maximum((text != 0).sum(dim=-1), lens) + 1, duration
        )
        duration = duration.clamp(max=max_duration)

        max_duration = duration.amax()

        # Duplicate test: interpolate between noise and conditioning
        if duplicate_test:
            test_cond = F.pad(
                cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0
            )

        # Zero-pad conditioning latent to max_duration
        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value=0.0)

        if no_ref_audio:
            cond = torch.zeros_like(cond)

        cond_mask = F.pad(
            cond_mask, (0, max_duration - cond_mask.shape[-1]), value=False
        )
        cond_mask = cond_mask.unsqueeze(-1)
        step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))

        assert max_duration == midi.shape[1]

        # Zero out melody in prompt region; optionally disable melody control entirely
        if enable_melody_control:
            midi = torch.where(cond_mask, torch.zeros_like(midi), midi)
        else:
            midi = torch.zeros_like(midi)

        if self.is_tts_pretrain:
            midi = torch.zeros_like(midi)

        # For batched inference, explicit mask prevents causal attention fallback
        if batch > 1:
            mask = lens_to_mask(duration)
        else:
            mask = None

        # ODE velocity function
        def fn(t, x):
            if cfg_strength < 1e-5:
                # No classifier-free guidance
                pred, _ = self.transformer(
                    x=x,
                    cond=step_cond,
                    text=text,
                    midi=midi,
                    time=t,
                    mask=mask,
                    drop_audio_cond=False,
                    drop_text=False,
                    drop_midi=not enable_melody_control,
                    cache=False,
                )
                return pred
            else:
                if self.use_guidance_scale_embed:
                    # Distilled model with built-in CFG
                    assert enable_melody_control
                    pred_cfg, _ = self.transformer(
                        x=x,
                        cond=step_cond,
                        text=text,
                        midi=midi,
                        time=t,
                        mask=mask,
                        drop_audio_cond=False,
                        drop_text=False,
                        drop_midi=not enable_melody_control,
                        cache=False,
                        guidance_scale=torch.tensor([guidance_scale], device=device),
                    )
                    print(
                        f"Runtime CFG adjustment has no effect: the model is distilled, and the embedded guidance_scale={guidance_scale} is used instead"
                    )
                    return pred_cfg
                else:
                    # Standard CFG: cond + uncond forward
                    # BUG If enable_melody_control is False, there might be a slight issue here
                    assert guidance_scale is not None
                    pred_cfg, _ = self.transformer(
                        x=x,
                        cond=step_cond,
                        text=text,
                        midi=midi,
                        time=t,
                        mask=mask,
                        cfg_infer=True,
                        cache=False,
                        cfg_infer_ids=(True, False, False, True),
                    )

                    pred, pred_drop_all_cond = torch.chunk(pred_cfg, 2, dim=0)
                    return pred + (pred - pred_drop_all_cond) * float(guidance_scale)

        # Generate initial noise (per-sample seeding for batch consistency)
        y0 = []
        for dur in duration:
            if exists(seed):
                torch.manual_seed(seed)
            y0.append(
                torch.randn(
                    dur, self.num_channels, device=self.device, dtype=step_cond.dtype
                )
            )
        y0 = pad_sequence(y0, padding_value=0, batch_first=True)

        t_start = 0

        if duplicate_test:
            t_start = t_inter
            y0 = (1 - t_start) * y0 + t_start * test_cond
            steps = int(steps * (1 - t_start))

        # Build timestep schedule
        assert not use_epss and sway_sampling_coef is None, (
            "Use timestep shift instead of the strategy in F5"
        )
        if t_start == 0 and use_epss:
            # Empirically Pruned Step Sampling for low NFE
            t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype)
        else:
            t = torch.linspace(
                t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype
            )

        if sway_sampling_coef is not None:
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

        # Apply timestep shift
        t = t_shift * t / (1 + (t_shift - 1) * t)

        trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
        self.transformer.clear_cache()

        sampled = trajectory[-1]
        out = sampled

        if exists(vocoder):
            out = out.permute(0, 2, 1)
            out = vocoder(out)

        return out, trajectory
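The timestep shift t' = t_shift * t / (1 + (t_shift - 1) * t) keeps the endpoints 0 and 1 fixed while warping the interior of the schedule: for t_shift > 1 the grid is denser near t = 1, for t_shift < 1 denser near 0. A quick standalone look at the warped values:

import torch

t = torch.linspace(0, 1, 6)
for t_shift in (0.5, 1.0, 2.0):
    warped = t_shift * t / (1 + (t_shift - 1) * t)
    print(t_shift, [round(v, 3) for v in warped.tolist()])
# 0.5 -> [0.0, 0.111, 0.25, 0.429, 0.667, 1.0]  (steps cluster near 0)
# 1.0 -> the uniform grid, unchanged
# 2.0 -> [0.0, 0.333, 0.571, 0.75, 0.889, 1.0]  (steps cluster near 1)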
src/YingMusicSinger/models/modules.py
ADDED
@@ -0,0 +1,961 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from librosa.filters import mel as librosa_mel_fn
from x_transformers.x_transformers import apply_rotary_pos_emb

from src.YingMusicSinger.utils.common import is_package_available

# raw wav to mel spec


mel_basis_cache = {}
hann_window_cache = {}


def get_bigvgan_mel_spectrogram(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
    fmin=0,
    fmax=None,
    center=False,
):  # Copied from https://github.com/NVIDIA/BigVGAN/tree/main
    device = waveform.device
    key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"

    if key not in mel_basis_cache:
        mel = librosa_mel_fn(
            sr=target_sample_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=fmin,
            fmax=fmax,
        )
        mel_basis_cache[key] = (
            torch.from_numpy(mel).float().to(device)
        )  # TODO: why do they need .float()?
        hann_window_cache[key] = torch.hann_window(win_length).to(device)

    mel_basis = mel_basis_cache[key]
    hann_window = hann_window_cache[key]

    padding = (n_fft - hop_length) // 2
    waveform = torch.nn.functional.pad(
        waveform.unsqueeze(1), (padding, padding), mode="reflect"
    ).squeeze(1)

    spec = torch.stft(
        waveform,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)

    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))

    return mel_spec


def get_vocos_mel_spectrogram(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
):
    mel_stft = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=n_mel_channels,
        power=1,
        center=True,
        normalized=False,
        norm=None,
    ).to(waveform.device)
    if len(waveform.shape) == 3:
        waveform = waveform.squeeze(1)  # 'b 1 nw -> b nw'

    assert len(waveform.shape) == 2

    mel = mel_stft(waveform)
    mel = mel.clamp(min=1e-5).log()
    return mel


class MelSpec(nn.Module):
    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24_000,
        mel_spec_type="vocos",
    ):
        super().__init__()
        assert mel_spec_type in ["vocos", "bigvgan"], (
            "We only support two mel extraction backends: vocos or bigvgan"
        )

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.target_sample_rate = target_sample_rate

        if mel_spec_type == "vocos":
            self.extractor = get_vocos_mel_spectrogram
        elif mel_spec_type == "bigvgan":
            self.extractor = get_bigvgan_mel_spectrogram

        self.register_buffer("dummy", torch.tensor(0), persistent=False)

    def forward(self, wav):
        if self.dummy.device != wav.device:
            self.to(wav.device)

        mel = self.extractor(
            waveform=wav,
            n_fft=self.n_fft,
            n_mel_channels=self.n_mel_channels,
            target_sample_rate=self.target_sample_rate,
            hop_length=self.hop_length,
            win_length=self.win_length,
        )

        return mel
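A usage sketch for MelSpec (assuming the class above is in scope; hypothetical one-second input): with the default vocos backend (center=True, hop_length=256), a waveform of L samples yields L // 256 + 1 frames of 100 mel channels.

import torch

mel_spec = MelSpec()            # defaults: vocos backend, 24 kHz, hop 256
wav = torch.randn(1, 24_000)    # hypothetical 1 s of audio
mel = mel_spec(wav)
print(mel.shape)                # torch.Size([1, 100, 94]) == 24000 // 256 + 1 frames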
# sinusoidal position embedding


class SinusPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# convolutional position embedding


class ConvPositionEmbedding(nn.Module):
    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
        )

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.0)

        x = x.permute(0, 2, 1)
        x = self.conv1d(x)
        out = x.permute(0, 2, 1)

        if mask is not None:
            out = out.masked_fill(~mask, 0.0)

        return out
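A quick shape check for the sinusoidal embedding (assuming the class above is in scope): a [b] tensor of timesteps maps to [b, dim], half sines followed by half cosines, with scale=1000 spreading the [0, 1] timesteps across the frequency bands.

import torch

emb = SinusPositionEmbedding(dim=256)
t = torch.rand(4)   # hypothetical flow-matching timesteps in [0, 1]
out = emb(t)
print(out.shape)    # torch.Size([4, 256])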
# rotary positional embedding related


def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0
):
    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    theta *= theta_rescale_factor ** (dim / (dim - 2))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return torch.cat([freqs_cos, freqs_sin], dim=-1)


def get_pos_embed_indices(start, length, max_pos, scale=1.0):
    # length = length if isinstance(length, int) else length.max()
    scale = scale * torch.ones_like(
        start, dtype=torch.float32
    )  # in case scale is a scalar
    pos = (
        start.unsqueeze(1)
        + (
            torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0)
            * scale.unsqueeze(1)
        ).long()
    )
    # clip so indices never run past the precomputed table
    pos = torch.where(pos < max_pos, pos, max_pos - 1)
    return pos
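Shape sketch for the rotary tables (assuming the two functions above are in scope): precompute_freqs_cis(dim, end) returns [end, dim] with the cosine half concatenated before the sine half, and get_pos_embed_indices yields clipped [b, length] indices into it.

import torch

freqs_cis = precompute_freqs_cis(64, end=4096)
print(freqs_cis.shape)              # torch.Size([4096, 64]) -- cos half || sin half
start = torch.zeros(2, dtype=torch.long)
pos = get_pos_embed_indices(start, length=10, max_pos=4096)
print(pos.shape, pos.max().item())  # torch.Size([2, 10]) 9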
# Global Response Normalization layer (from ConvNeXt-V2)


class GRN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108


class ConvNeXtV2Block(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(
            dim, intermediate_dim
        )  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = x.transpose(1, 2)  # b n d -> b d n
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        return residual + x
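The ConvNeXt-V2 block is shape-preserving over [b, n, d], which is what lets TextEmbedding stack them freely in text_blocks. A quick check (assuming the class above is in scope):

import torch

block = ConvNeXtV2Block(dim=100, intermediate_dim=200)
x = torch.randn(2, 50, 100)  # [b, n, d]
print(block(x).shape)        # torch.Size([2, 50, 100]) -- residual keeps the shape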
# RMSNorm


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        # NOTE: naive version-string parse; assumes versions of the form "2.4.x"
        self.native_rms_norm = float(torch.__version__[:3]) >= 2.4

    def forward(self, x):
        if self.native_rms_norm:
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                x = x.to(self.weight.dtype)
            x = F.rms_norm(
                x, normalized_shape=(x.shape[-1],), weight=self.weight, eps=self.eps
            )
        else:
            variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.eps)
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                x = x.to(self.weight.dtype)
            x = x * self.weight

        return x


# AdaLayerNorm
# return with modulated x for attn input, and params for later mlp modulation


class AdaLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(
            emb, 6, dim=1
        )

        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


# AdaLayerNorm for the final layer
# returns only the modulated x, since there is no further mlp modulation


class AdaLayerNorm_Final(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        scale, shift = torch.chunk(emb, 2, dim=1)

        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x


# FeedForward


class FeedForward(nn.Module):
    def __init__(
        self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        activation = nn.GELU(approximate=approximate)
        project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
        self.ff = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.ff(x)


# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py


class Attention(nn.Module):
    def __init__(
        self,
        processor: JointAttnProcessor | AttnProcessor,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only: bool = False,
        qk_norm: Optional[str] = None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "Attention requires PyTorch 2.0; please upgrade PyTorch to use it."
            )

        self.processor = processor

        self.dim = dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout

        self.context_dim = context_dim
        self.context_pre_only = context_pre_only

        self.to_q = nn.Linear(dim, self.inner_dim)
        self.to_k = nn.Linear(dim, self.inner_dim)
        self.to_v = nn.Linear(dim, self.inner_dim)

        if qk_norm is None:
            self.q_norm = None
            self.k_norm = None
        elif qk_norm == "rms_norm":
            self.q_norm = RMSNorm(dim_head, eps=1e-6)
            self.k_norm = RMSNorm(dim_head, eps=1e-6)
        else:
            raise ValueError(f"Unimplemented qk_norm: {qk_norm}")

        if self.context_dim is not None:
            self.to_q_c = nn.Linear(context_dim, self.inner_dim)
            self.to_k_c = nn.Linear(context_dim, self.inner_dim)
            self.to_v_c = nn.Linear(context_dim, self.inner_dim)
            if qk_norm is None:
                self.c_q_norm = None
                self.c_k_norm = None
            elif qk_norm == "rms_norm":
                self.c_q_norm = RMSNorm(dim_head, eps=1e-6)
                self.c_k_norm = RMSNorm(dim_head, eps=1e-6)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

        if self.context_dim is not None and not self.context_pre_only:
            self.to_out_c = nn.Linear(self.inner_dim, context_dim)

    def forward(
        self,
        x: float["b n d"],  # noised input x
        c: float["b n d"] = None,  # context c
        mask: bool["b n"] | None = None,
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.Tensor:
        if c is not None:
            return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
        else:
            return self.processor(self, x, mask=mask, rope=rope)


# Attention processor

if is_package_available("flash_attn"):
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import pad_input, unpad_input


class AttnProcessor:
    def __init__(
        self,
        pe_attn_head: int
        | None = None,  # number of attention heads to apply rope to, None for all
        attn_backend: str = "torch",  # "torch" or "flash_attn"
        attn_mask_enabled: bool = True,
    ):
        if attn_backend == "flash_attn":
            assert is_package_available("flash_attn"), (
                "Please install flash-attn first."
            )

        self.pe_attn_head = pe_attn_head
        self.attn_backend = attn_backend
        self.attn_mask_enabled = attn_mask_enabled

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x
        mask: bool["b n"] | None = None,
        rope=None,  # rotary position embedding
    ) -> torch.FloatTensor:
        batch_size = x.shape[0]

        # `sample` projections
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # qk norm
        if attn.q_norm is not None:
            query = attn.q_norm(query)
        if attn.k_norm is not None:
            key = attn.k_norm(key)

        # apply rotary position embedding
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (
                (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            )

            if self.pe_attn_head is not None:
                pn = self.pe_attn_head
                query[:, :pn, :, :] = apply_rotary_pos_emb(
                    query[:, :pn, :, :], freqs, q_xpos_scale
                )
                key[:, :pn, :, :] = apply_rotary_pos_emb(
                    key[:, :pn, :, :], freqs, k_xpos_scale
                )
            else:
                query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
                key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)

        if self.attn_backend == "torch":
            # mask, e.g. inference got a batch with different target durations: mask out the padding
            if self.attn_mask_enabled and mask is not None:
                attn_mask = mask
                attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
                attn_mask = attn_mask.expand(
                    batch_size, attn.heads, query.shape[-2], key.shape[-2]
                )
            else:
                attn_mask = None
            x = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
            )
            x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        elif self.attn_backend == "flash_attn":
            query = query.transpose(1, 2)  # [b, h, n, d] -> [b, n, h, d]
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            if self.attn_mask_enabled and mask is not None:
                query, indices, q_cu_seqlens, q_max_seqlen_in_batch, _ = unpad_input(
                    query, mask
                )
                key, _, k_cu_seqlens, k_max_seqlen_in_batch, _ = unpad_input(key, mask)
                value, _, _, _, _ = unpad_input(value, mask)
                x = flash_attn_varlen_func(
                    query,
                    key,
                    value,
                    q_cu_seqlens,
                    k_cu_seqlens,
                    q_max_seqlen_in_batch,
                    k_max_seqlen_in_batch,
                )
                x = pad_input(x, indices, batch_size, q_max_seqlen_in_batch)
                x = x.reshape(batch_size, -1, attn.heads * head_dim)
            else:
                x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
                x = x.reshape(batch_size, -1, attn.heads * head_dim)

        x = x.to(query.dtype)

        # linear projection
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)

        return x
|
| 588 |
+
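
# A minimal, self-contained sketch of the padding-mask expansion used by the torch
# backend above, assuming only torch; the shapes and the length mask are illustrative.
import torch
import torch.nn.functional as F

_b, _h, _n, _d = 2, 4, 6, 8
_q = _k = _v = torch.randn(_b, _h, _n, _d)
_mask = torch.tensor([[True] * 6, [True] * 4 + [False] * 2])  # 'b n' length mask
_attn_mask = _mask.unsqueeze(1).unsqueeze(1).expand(_b, _h, _n, _n)  # 'b 1 1 n' -> 'b h n n'
_out = F.scaled_dot_product_attention(_q, _k, _v, attn_mask=_attn_mask)
assert _out.shape == (_b, _h, _n, _d)
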
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py


class JointAttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x
        c: float["b nt d"] = None,  # context c, here text
        mask: bool["b n"] | None = None,
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.FloatTensor:
        residual = x

        batch_size = c.shape[0]

        # `sample` projections
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # `context` projections
        c_query = attn.to_q_c(c)
        c_key = attn.to_k_c(c)
        c_value = attn.to_v_c(c)

        # attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        c_query = c_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        c_key = c_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        c_value = c_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # qk norm
        if attn.q_norm is not None:
            query = attn.q_norm(query)
        if attn.k_norm is not None:
            key = attn.k_norm(key)
        if attn.c_q_norm is not None:
            c_query = attn.c_q_norm(c_query)
        if attn.c_k_norm is not None:
            c_key = attn.c_k_norm(c_key)

        # apply rope for context and noised input independently
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (
                (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            )
            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
        if c_rope is not None:
            freqs, xpos_scale = c_rope
            q_xpos_scale, k_xpos_scale = (
                (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            )
            c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
            c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)

        # joint attention
        query = torch.cat([query, c_query], dim=2)
        key = torch.cat([key, c_key], dim=2)
        value = torch.cat([value, c_value], dim=2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = F.pad(mask, (0, c.shape[1]), value=True)  # no mask for c (text)
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(
                batch_size, attn.heads, query.shape[-2], key.shape[-2]
            )
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
        )
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # Split the attention outputs.
        x, c = (
            x[:, : residual.shape[1]],
            x[:, residual.shape[1] :],
        )

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)
        if not attn.context_pre_only:
            c = attn.to_out_c(c)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)
            # c = c.masked_fill(~mask, 0.)  # no mask for c (text)

        return x, c

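
# A minimal sketch of the joint-attention mask above, assuming only torch: the text
# context is appended after the noised input along the sequence dim, so the 'b n'
# mask is padded with True over the nt context positions (text is never masked out).
import torch
import torch.nn.functional as F

_n, _nt = 6, 3
_mask = torch.tensor([[True] * 6, [True] * 4 + [False] * 2])
_joint_mask = F.pad(_mask, (0, _nt), value=True)  # 'b n' -> 'b (n + nt)'
assert _joint_mask.shape == (2, _n + _nt) and bool(_joint_mask[1, -_nt:].all())
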
# DiT Block


class DiTBlock(nn.Module):
    def __init__(
        self,
        dim,
        heads,
        dim_head,
        ff_mult=4,
        dropout=0.1,
        qk_norm=None,
        pe_attn_head=None,
        attn_backend="torch",  # "torch" or "flash_attn"
        attn_mask_enabled=True,
    ):
        super().__init__()

        self.attn_norm = AdaLayerNorm(dim)
        self.attn = Attention(
            processor=AttnProcessor(
                pe_attn_head=pe_attn_head,
                attn_backend=attn_backend,
                attn_mask_enabled=attn_mask_enabled,
            ),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            qk_norm=qk_norm,
        )

        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(
            dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
        )

    def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # attention
        attn_output = self.attn(x=norm, mask=mask, rope=rope)

        # process attention output for input x
        x = x + gate_msa.unsqueeze(1) * attn_output

        norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(norm)
        x = x + gate_mlp.unsqueeze(1) * ff_output

        return x

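
# A minimal sketch of the adaLN-zero pattern DiTBlock.forward implements above,
# assuming only torch: the time embedding supplies shift/scale/gate, and a zero
# gate makes the block an identity mapping (the usual adaLN-zero initialization).
import torch
import torch.nn as nn

_dim = 16
_norm = nn.LayerNorm(_dim, elementwise_affine=False, eps=1e-6)
_x = torch.randn(2, 10, _dim)
_shift, _scale = torch.randn(2, _dim), torch.randn(2, _dim)
_gate = torch.zeros(2, _dim)
_h = _norm(_x) * (1 + _scale[:, None]) + _shift[:, None]  # modulation
_out = _x + _gate.unsqueeze(1) * _h  # zero gate -> identity at init
assert torch.allclose(_out, _x)
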
# MMDiT Block https://arxiv.org/abs/2403.03206


class MMDiTBlock(nn.Module):
    r"""
    modified from diffusers/src/diffusers/models/attention.py

    notes:
    _c: context related. text, cond, etc. (left part in SD3 fig. 2b)
    _x: noised input related. (right part)
    context_pre_only: the last layer only does prenorm + modulation, since there is no further FFN for the context
    """

    def __init__(
        self,
        dim,
        heads,
        dim_head,
        ff_mult=4,
        dropout=0.1,
        context_dim=None,
        context_pre_only=False,
        qk_norm=None,
    ):
        super().__init__()
        if context_dim is None:
            context_dim = dim
        self.context_pre_only = context_pre_only

        self.attn_norm_c = (
            AdaLayerNorm_Final(context_dim)
            if context_pre_only
            else AdaLayerNorm(context_dim)
        )
        self.attn_norm_x = AdaLayerNorm(dim)
        self.attn = Attention(
            processor=JointAttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            context_dim=context_dim,
            context_pre_only=context_pre_only,
            qk_norm=qk_norm,
        )

        if not context_pre_only:
            self.ff_norm_c = nn.LayerNorm(
                context_dim, elementwise_affine=False, eps=1e-6
            )
            self.ff_c = FeedForward(
                dim=context_dim, mult=ff_mult, dropout=dropout, approximate="tanh"
            )
        else:
            self.ff_norm_c = None
            self.ff_c = None
        self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_x = FeedForward(
            dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh"
        )

    def forward(
        self, x, c, t, mask=None, rope=None, c_rope=None
    ):  # x: noised input, c: context, t: time embedding
        # pre-norm & modulation for attention input
        if self.context_pre_only:
            norm_c = self.attn_norm_c(c, t)
        else:
            norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(
                c, emb=t
            )
        norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(
            x, emb=t
        )

        # attention
        x_attn_output, c_attn_output = self.attn(
            x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope
        )

        # process attention output for context c
        if self.context_pre_only:
            c = None
        else:  # if not last layer
            c = c + c_gate_msa.unsqueeze(1) * c_attn_output

            norm_c = (
                self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            )
            c_ff_output = self.ff_c(norm_c)
            c = c + c_gate_mlp.unsqueeze(1) * c_ff_output

        # process attention output for input x
        x = x + x_gate_msa.unsqueeze(1) * x_attn_output

        norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
        x_ff_output = self.ff_x(norm_x)
        x = x + x_gate_mlp.unsqueeze(1) * x_ff_output

        return c, x

# time step conditioning embedding


# class TimestepEmbedding(nn.Module):
#     def __init__(self, dim, freq_embed_dim=256):
#         super().__init__()
#         self.time_embed = SinusPositionEmbedding(freq_embed_dim)
#         self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

#     def forward(self, timestep: float["b"]):
#         time_hidden = self.time_embed(timestep)
#         time_hidden = time_hidden.to(timestep.dtype)
#         time = self.time_mlp(time_hidden)  # b d
#         return time


def zipvoice_timestep_embedding(timesteps, dim, max_period=10000):
    """Create sinusoidal timestep embeddings.

    :param timesteps: shape of (N) or (N, T)
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: a Tensor of positional embeddings, shape of (N, dim) or (T, N, dim)
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
        / half
    )

    if timesteps.dim() == 2:
        timesteps = timesteps.transpose(0, 1)  # (N, T) -> (T, N)

    args = timesteps[..., None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[..., :1])], dim=-1)
    return embedding

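
# A quick shape check for zipvoice_timestep_embedding, assuming only torch:
# a (N,) input yields (N, dim) cos/sin features, with an odd dim zero-padded.
import torch

_emb = zipvoice_timestep_embedding(torch.tensor([0.0, 0.5, 1.0]), dim=9)
assert _emb.shape == (3, 9)
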
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """
    Behaves like a constructor of a modified version of nn.Linear
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Linear accepts,
        e.g. in_features, out_features, bias=False.

        initial_scale: you can override this if you want to increase
            or decrease the initial magnitude of the module's output
            (affects the initialization of weight_scale and bias_scale).
            Another option, if you want to do something like this, is
            to re-initialize the parameters.
    """
    ans = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
    return ans

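
# A quick check of ScaledLinear, assuming only torch: nn.Linear's default init is
# bounded by 1/sqrt(in_features), so scaling by 0.1 shrinks that bound accordingly,
# and the bias is re-drawn from U(-0.1 * initial_scale, 0.1 * initial_scale).
import torch

_lin = ScaledLinear(8, 4, initial_scale=0.1)
assert _lin.weight.abs().max() <= 0.1 / (8 ** 0.5) + 1e-6
assert _lin.bias.abs().max() <= 0.1 * 0.1 + 1e-6
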
# Used during distillation!
class TimestepGuidanceEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        freq_embed_dim=256,
        use_guidance_scale_embed=False,
        guidance_scale_embed_dim=192,
    ):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
        )
        if use_guidance_scale_embed:
            self.guidance_scale_embed = ScaledLinear(
                guidance_scale_embed_dim,
                freq_embed_dim,
                bias=False,
                initial_scale=0.1,
            )
            self.guidance_scale_embed_dim = guidance_scale_embed_dim
        else:
            self.guidance_scale_embed = None

    def forward(self, timestep: float["b"], guidance_scale=None):
        time_hidden = self.time_embed(timestep)

        if self.guidance_scale_embed is not None:
            assert guidance_scale is not None
            guidance_scale_emb = self.guidance_scale_embed(
                zipvoice_timestep_embedding(
                    guidance_scale, self.guidance_scale_embed_dim
                )
            )
            time_hidden = time_hidden + guidance_scale_emb
        else:
            assert guidance_scale is None

        time_hidden = time_hidden.to(timestep.dtype)
        time = self.time_mlp(time_hidden)  # b d
        return time
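
# A hypothetical usage sketch, assuming SinusPositionEmbedding is defined earlier in
# this module (it is referenced above): with use_guidance_scale_embed=True, the CFG
# scale is sinusoidally embedded, projected, and added to the time embedding, which
# is how a distilled student can be conditioned on the guidance scale.
import torch

_embed = TimestepGuidanceEmbedding(dim=512, use_guidance_scale_embed=True)
_t = _embed(torch.rand(4), guidance_scale=torch.full((4,), 2.0))
assert _t.shape == (4, 512)
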
src/YingMusicSinger/utils/f5_tts/g2p/g2p/__init__.py
ADDED
@@ -0,0 +1,91 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import re

from tokenizers import Tokenizer

from src.YingMusicSinger.utils.f5_tts.g2p.g2p import cleaners
from src.YingMusicSinger.utils.f5_tts.g2p.g2p.text_tokenizers import TextTokenizer

# import LangSegment
from src.YingMusicSinger.utils.f5_tts.thirdparty.LangSegment import LangSegment


class PhonemeBpeTokenizer:
    def __init__(
        self, vocab_path="./src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json"
    ):
        self.lang2backend = {
            "zh": "cmn",
            "ja": "ja",
            "en": "en-us",
            "fr": "fr-fr",
            "ko": "ko",
            "de": "de",
        }
        self.text_tokenizers = {}
        self.init_text_tokenizers()

        with open(vocab_path, "r") as f:
            json_data = f.read()
        data = json.loads(json_data)
        self.vocab = data["vocab"]
        LangSegment.setfilters(["en", "zh", "ja", "ko", "fr", "de"])

    def init_text_tokenizers(self):
        for key, value in self.lang2backend.items():
            self.text_tokenizers[key] = TextTokenizer(language=value)

    def tokenize(self, text, sentence, language):
        # 1. convert text to phonemes
        phonemes = []
        if language == "auto":
            seglist = LangSegment.getTexts(text)
            tmp_ph = []
            for seg in seglist:
                tmp_ph.append(
                    self._clean_text(
                        seg["text"], sentence, seg["lang"], ["cjekfd_cleaners"]
                    )
                )
            phonemes = "|_|".join(tmp_ph)
        else:
            phonemes = self._clean_text(text, sentence, language, ["cjekfd_cleaners"])
        # print('clean text: ', phonemes)

        # 2. tokenize phonemes
        phoneme_tokens = self.phoneme2token(phonemes)
        # print('encode: ', phoneme_tokens)

        # # 3. decode tokens [optional]
        # decoded_text = self.tokenizer.decode(phoneme_tokens)
        # print('decoded: ', decoded_text)

        return phonemes, phoneme_tokens

    def _clean_text(self, text, sentence, language, cleaner_names):
        for name in cleaner_names:
            cleaner = getattr(cleaners, name)
            if not cleaner:
                raise Exception("Unknown cleaner: %s" % name)
            text = cleaner(text, sentence, language, self.text_tokenizers)
        return text

    def phoneme2token(self, phonemes):
        tokens = []
        if isinstance(phonemes, list):
            for phone in phonemes:
                phone = phone.split("\t")[0]
                phonemes_split = phone.split("|")
                tokens.append(
                    [self.vocab[p] for p in phonemes_split if p in self.vocab]
                )
        else:
            phonemes = phonemes.split("\t")[0]
            phonemes_split = phonemes.split("|")
            tokens = [self.vocab[p] for p in phonemes_split if p in self.vocab]
        return tokens
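
# A minimal sketch of the phoneme2token lookup above, with an illustrative vocab
# (the real ids come from vocab.json): phonemes arrive as a '|'-separated string,
# anything after a tab is dropped, and unknown symbols are skipped.
_vocab = {"h": 1, "ə": 2, "l": 3, "oʊ": 4}
_phonemes = "h|ə|l|oʊ\tignored-field"
_ids = [_vocab[p] for p in _phonemes.split("\t")[0].split("|") if p in _vocab]
assert _ids == [1, 2, 3, 4]
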
src/YingMusicSinger/utils/f5_tts/g2p/g2p/chinese_model_g2p.py
ADDED
@@ -0,0 +1,209 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os

import numpy as np
import torch
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer
from transformers.models.bert.modeling_bert import *


class PolyDataset(Dataset):
    def __init__(self, words, labels, word_pad_idx=0, label_pad_idx=-1):
        self.dataset = self.preprocess(words, labels)
        self.word_pad_idx = word_pad_idx
        self.label_pad_idx = label_pad_idx

    def preprocess(self, origin_sentences, origin_labels):
        """
        Maps tokens and tags to their indices and stores them in the dict data.
        examples:
            word: ['[CLS]', '浙', '商', '银', '行', '企', '业', '信', '贷', '部']
            sentence: ([101, 3851, 1555, 7213, 6121, 821, 689, 928, 6587, 6956],
                       array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]))
            label: [3, 13, 13, 13, 0, 0, 0, 0, 0]
        """
        data = []
        labels = []
        sentences = []
        # tokenize
        for line in origin_sentences:
            # replace each token by its index
            # we can not use encode_plus because our sentences are aligned to labels in list type
            words = []
            word_lens = []
            for token in line:
                words.append(token)
                word_lens.append(1)
            token_start_idxs = 1 + np.cumsum([0] + word_lens[:-1])
            sentences.append(((words, token_start_idxs), 0))
        ###
        for tag in origin_labels:
            labels.append(tag)

        for sentence, label in zip(sentences, labels):
            data.append((sentence, label))
        return data

    def __getitem__(self, idx):
        """sample data to get batch"""
        word = self.dataset[idx][0]
        label = self.dataset[idx][1]
        return [word, label]

    def __len__(self):
        """get dataset size"""
        return len(self.dataset)

    def collate_fn(self, batch):
        sentences = [x[0][0] for x in batch]
        ori_sents = [x[0][1] for x in batch]
        labels = [x[1] for x in batch]
        batch_len = len(sentences)

        # compute length of longest sentence in batch
        max_len = max([len(s[0]) for s in sentences])
        max_label_len = 0
        batch_data = np.ones((batch_len, max_len))
        batch_label_starts = []

        # padding and aligning
        for j in range(batch_len):
            cur_len = len(sentences[j][0])
            batch_data[j][:cur_len] = sentences[j][0]
            label_start_idx = sentences[j][-1]
            label_starts = np.zeros(max_len)
            label_starts[[idx for idx in label_start_idx if idx < max_len]] = 1
            batch_label_starts.append(label_starts)
            max_label_len = max(int(sum(label_starts)), max_label_len)

        # padding label
        batch_labels = self.label_pad_idx * np.ones((batch_len, max_label_len))
        batch_pmasks = self.label_pad_idx * np.ones((batch_len, max_label_len))
        for j in range(batch_len):
            cur_tags_len = len(labels[j])
            batch_labels[j][:cur_tags_len] = labels[j]
            batch_pmasks[j][:cur_tags_len] = [
                1 if item > 0 else 0 for item in labels[j]
            ]

        # convert data to torch LongTensors
        batch_data = torch.tensor(batch_data, dtype=torch.long)
        batch_label_starts = torch.tensor(batch_label_starts, dtype=torch.long)
        batch_labels = torch.tensor(batch_labels, dtype=torch.long)
        batch_pmasks = torch.tensor(batch_pmasks, dtype=torch.long)
        return [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]


class BertPolyPredict:
    def __init__(self, bert_model, jsonr_file, json_file):
        self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)
        with open(jsonr_file, "r", encoding="utf8") as fp:
            self.pron_dict = json.load(fp)
        with open(json_file, "r", encoding="utf8") as fp:
            self.pron_dict_id_2_pinyin = json.load(fp)
        self.num_polyphone = len(self.pron_dict)
        self.device = "cpu"
        self.polydataset = PolyDataset
        options = SessionOptions()  # initialize session options
        options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        print(os.path.join(bert_model, "poly_bert_model.onnx"))
        self.session = InferenceSession(
            os.path.join(bert_model, "poly_bert_model.onnx"),
            sess_options=options,
            providers=[
                "CPUExecutionProvider",
                "CUDAExecutionProvider",
            ],  # CPUExecutionProvider # CUDAExecutionProvider
        )
        # self.session.set_providers(['CUDAExecutionProvider', "CPUExecutionProvider"], [ {'device_id': 0}])

        # disable session.run() fallback mechanism, it prevents a reset of the execution provider
        self.session.disable_fallback()

    def predict_process(self, txt_list):
        word_test, label_test, texts_test = self.get_examples_po(txt_list)
        data = self.polydataset(word_test, label_test)
        predict_loader = DataLoader(
            data, batch_size=1, shuffle=False, collate_fn=data.collate_fn
        )
        pred_tags = self.predict_onnx(predict_loader)
        return pred_tags

    def predict_onnx(self, dev_loader):
        pred_tags = []
        with torch.no_grad():
            for idx, batch_samples in enumerate(dev_loader):
                # [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
                batch_data, batch_label_starts, batch_labels, batch_pmasks, _ = (
                    batch_samples
                )
                # shift tensors to GPU if available
                batch_data = batch_data.to(self.device)
                batch_label_starts = batch_label_starts.to(self.device)
                batch_labels = batch_labels.to(self.device)
                batch_pmasks = batch_pmasks.to(self.device)
                batch_data = np.asarray(batch_data, dtype=np.int32)
                batch_pmasks = np.asarray(batch_pmasks, dtype=np.int32)
                # batch_output = self.session.run(output_names=['outputs'], input_feed={"input_ids": batch_data, "input_pmasks": batch_pmasks})[0][0]
                batch_output = self.session.run(
                    output_names=["outputs"], input_feed={"input_ids": batch_data}
                )[0]
                label_masks = batch_pmasks == 1
                batch_labels = batch_labels.to("cpu").numpy()
                for i, indices in enumerate(np.argmax(batch_output, axis=2)):
                    for j, idx in enumerate(indices):
                        if label_masks[i][j]:
                            # pred_tag.append(idx)
                            pred_tags.append(self.pron_dict_id_2_pinyin[str(idx + 1)])
        return pred_tags

    def get_examples_po(self, text_list):
        word_list = []
        label_list = []
        sentence_list = []
        id = 0
        for line in [text_list]:
            sentence = line[0]
            words = []
            tokens = line[0]
            index = line[-1]
            front = index
            back = len(tokens) - index - 1
            labels = [0] * front + [1] + [0] * back
            words = ["[CLS]"] + [item for item in sentence]
            words = self.tokenizer.convert_tokens_to_ids(words)
            word_list.append(words)
            label_list.append(labels)
            sentence_list.append(sentence)

            id += 1
            # mask_list.append(masks)
            assert len(labels) + 1 == len(words), (
                "Number of labels does not match number of words: "
                "{} ({} chars, {} word ids, {} labels)".format(
                    sentence, len(sentence), len(words), len(labels)
                )
            )
            assert len(labels) == len(sentence), (
                "Number of labels does not match number of characters"
            )
            assert len(word_list) == len(label_list), (
                "Number of label sentences does not match number of word sentences"
            )
        return word_list, label_list, text_list
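
# A minimal sketch of the per-character label construction in get_examples_po above:
# the polyphonic character sits at `index` and gets label 1, all other characters
# get label 0 (the sentence below is an illustrative example).
_sentence = "中国银行"
_index = 3  # "行" is polyphonic (xíng / háng)
_front, _back = _index, len(_sentence) - _index - 1
_labels = [0] * _front + [1] + [0] * _back
assert _labels == [0, 0, 0, 1]
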
src/YingMusicSinger/utils/f5_tts/g2p/g2p/cleaners.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from src.YingMusicSinger.utils.f5_tts.g2p.g2p.english import english_to_ipa
from src.YingMusicSinger.utils.f5_tts.g2p.g2p.french import french_to_ipa
from src.YingMusicSinger.utils.f5_tts.g2p.g2p.german import german_to_ipa
from src.YingMusicSinger.utils.f5_tts.g2p.g2p.korean import korean_to_ipa
from src.YingMusicSinger.utils.f5_tts.g2p.g2p.mandarin import chinese_to_ipa


def cjekfd_cleaners(text, sentence, language, text_tokenizers):
    if language == "zh":
        return chinese_to_ipa(text, sentence, text_tokenizers["zh"])
    elif language == "ja":
        # japanese_to_ipa is not imported in this file, so raise a clear
        # error here instead of hitting a NameError at runtime
        raise NotImplementedError("Japanese g2p is not available in this module")
    elif language == "en":
        return english_to_ipa(text, text_tokenizers["en"])
    elif language == "fr":
        return french_to_ipa(text, text_tokenizers["fr"])
    elif language == "ko":
        return korean_to_ipa(text, text_tokenizers["ko"])
    elif language == "de":
        return german_to_ipa(text, text_tokenizers["de"])
    else:
        raise Exception("Unknown language: %s" % language)
src/YingMusicSinger/utils/f5_tts/g2p/g2p/english.py
ADDED
@@ -0,0 +1,202 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re

import inflect

"""
Text clean time
"""
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
    (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
    for x in [
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
        ("etc", "et cetera"),
        ("btw", "by the way"),
    ]
]

_special_map = [
    ("t|ɹ", "tɹ"),
    ("d|ɹ", "dɹ"),
    ("t|s", "ts"),
    ("d|z", "dz"),
    ("ɪ|ɹ", "ɪɹ"),
    ("ɐ", "ɚ"),
    ("ᵻ", "ɪ"),
    ("əl", "l"),
    ("x", "k"),
    ("ɬ", "l"),
    ("ʔ", "t"),
    ("n̩", "n"),
    ("oː|ɹ", "oːɹ"),
]


def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def _remove_commas(m):
    return m.group(1).replace(",", "")


def _expand_decimal_point(m):
    return m.group(1).replace(".", " point ")


def _expand_percent(m):
    return m.group(1).replace("%", " percent ")


def _expand_dollars(m):
    match = m.group(1)
    parts = match.split(".")
    if len(parts) > 2:
        return " " + match + " dollars "  # Unexpected format
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"
        return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        return " %s %s " % (dollars, dollar_unit)
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return " %s %s " % (cents, cent_unit)
    else:
        return " zero dollars "


def fraction_to_words(numerator, denominator):
    if numerator == 1 and denominator == 2:
        return " one half "
    if numerator == 1 and denominator == 4:
        return " one quarter "
    if denominator == 2:
        return " " + _inflect.number_to_words(numerator) + " halves "
    if denominator == 4:
        return " " + _inflect.number_to_words(numerator) + " quarters "
    return (
        " "
        + _inflect.number_to_words(numerator)
        + " "
        + _inflect.ordinal(_inflect.number_to_words(denominator))
        + " "
    )


def _expand_fraction(m):
    numerator = int(m.group(1))
    denominator = int(m.group(2))
    return fraction_to_words(numerator, denominator)


def _expand_ordinal(m):
    return " " + _inflect.number_to_words(m.group(0)) + " "


def _expand_number(m):
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return " two thousand "
        elif num > 2000 and num < 2010:
            return " two thousand " + _inflect.number_to_words(num % 100) + " "
        elif num % 100 == 0:
            return " " + _inflect.number_to_words(num // 100) + " hundred "
        else:
            return (
                " "
                + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
                    ", ", " "
                )
                + " "
            )
    else:
        return " " + _inflect.number_to_words(num, andword="") + " "


# Normalize numbers pronunciation
def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r"\1 pounds", text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_fraction_re, _expand_fraction, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_percent_number_re, _expand_percent, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text


def _english_to_ipa(text):
    # text = unidecode(text).lower()
    text = expand_abbreviations(text)
    text = normalize_numbers(text)
    return text


# special map
def special_map(text):
    for regex, replacement in _special_map:
        regex = regex.replace("|", "\\|")
        while re.search(r"(^|[_|]){}([_|]|$)".format(regex), text):
            text = re.sub(
                r"(^|[_|]){}([_|]|$)".format(regex), r"\1{}\2".format(replacement), text
            )
    # text = re.sub(r'([,.!?])', r'|\1', text)
    return text


# Add some special operation
def english_to_ipa(text, text_tokenizer):
    if type(text) == str:
        text = _english_to_ipa(text)
    else:
        text = [_english_to_ipa(t) for t in text]
    phonemes = text_tokenizer(text)
    if phonemes[-1] in "p⁼ʰmftnlkxʃs`ɹaoəɛɪeɑʊŋiuɥwæjː":
        phonemes += "|_"
    if type(text) == str:
        return special_map(phonemes)
    else:
        result_ph = []
        for phone in phonemes:
            result_ph.append(special_map(phone))
        return result_ph
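
# A quick worked example of normalize_numbers above (spacing is approximate, since
# each expansion pads with spaces): "$3.50" expands via _expand_dollars and "3rd"
# via _expand_ordinal; any digits left over are expanded by _expand_number.
_out = normalize_numbers("pay $3.50 for the 3rd item")
assert "dollars" in _out and "third" in _out
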
src/YingMusicSinger/utils/f5_tts/g2p/g2p/french.py
ADDED
@@ -0,0 +1,149 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re

"""
Text clean time
"""
# List of (regular expression, replacement) pairs for abbreviations in french:
_abbreviations = [
    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
    for x in [
        ("M", "monsieur"),
        ("Mlle", "mademoiselle"),
        ("Mlles", "mesdemoiselles"),
        ("Mme", "Madame"),
        ("Mmes", "Mesdames"),
        ("N.B", "nota bene"),
        ("M", "monsieur"),
        ("p.c.q", "parce que"),
        ("Pr", "professeur"),
        ("qqch", "quelque chose"),
        ("rdv", "rendez-vous"),
        ("max", "maximum"),
        ("min", "minimum"),
        ("no", "numéro"),
        ("adr", "adresse"),
        ("dr", "docteur"),
        ("st", "saint"),
        ("co", "companie"),
        ("jr", "junior"),
        ("sgt", "sergent"),
        ("capt", "capitain"),
        ("col", "colonel"),
        ("av", "avenue"),
        ("av. J.-C", "avant Jésus-Christ"),
        ("apr. J.-C", "après Jésus-Christ"),
        ("art", "article"),
        ("boul", "boulevard"),
        ("c.-à-d", "c’est-à-dire"),
        ("etc", "et cetera"),
        ("ex", "exemple"),
        ("excl", "exclusivement"),
        ("boul", "boulevard"),
    ]
] + [
    (re.compile("\\b%s" % x[0]), x[1])
    for x in [
        ("Mlle", "mademoiselle"),
        ("Mlles", "mesdemoiselles"),
        ("Mme", "Madame"),
        ("Mmes", "Mesdames"),
    ]
]

rep_map = {
    ":": ",",
    ";": ",",
    ",": ",",
    "。": ".",
    "!": "!",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": ".",
    "…": ".",
    "$": ".",
    "“": "",
    "”": "",
    "‘": "",
    "’": "",
    "(": "",
    ")": "",
    "(": "",
    ")": "",
    "《": "",
    "》": "",
    "【": "",
    "】": "",
    "[": "",
    "]": "",
    "—": "",
    "~": "-",
    "~": "-",
    "「": "",
    "」": "",
    "¿": "",
    "¡": "",
}


def collapse_whitespace(text):
    # Regular expression matching whitespace:
    _whitespace_re = re.compile(r"\s+")
    return re.sub(_whitespace_re, " ", text).strip()


def remove_punctuation_at_begin(text):
    return re.sub(r"^[,.!?]+", "", text)


def remove_aux_symbols(text):
    text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
    return text


def replace_symbols(text):
    text = text.replace(";", ",")
    text = text.replace("-", " ")
    text = text.replace(":", ",")
    text = text.replace("&", " et ")
    return text


def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def replace_punctuation(text):
    pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
    return replaced_text


def text_normalize(text):
    text = expand_abbreviations(text)
    text = replace_punctuation(text)
    text = replace_symbols(text)
    text = remove_aux_symbols(text)
    text = remove_punctuation_at_begin(text)
    text = collapse_whitespace(text)
    text = re.sub(r"([^\.,!\?\-…])$", r"\1.", text)  # ensure terminal punctuation
    return text


def french_to_ipa(text, text_tokenizer):
    if type(text) == str:
        text = text_normalize(text)
        phonemes = text_tokenizer(text)
        return phonemes
    else:
        for i, t in enumerate(text):
            text[i] = text_normalize(t)
        return text_tokenizer(text)
src/YingMusicSinger/utils/f5_tts/g2p/g2p/german.py
ADDED
@@ -0,0 +1,94 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re

"""
Text clean time
"""
rep_map = {
    ":": ",",
    ";": ",",
    ",": ",",
    "。": ".",
    "!": "!",
    "?": "?",
    "\n": ".",
    "·": ",",
    "、": ",",
    "...": ".",
    "…": ".",
    "$": ".",
    "“": "",
    "”": "",
    "‘": "",
    "’": "",
    "(": "",
    ")": "",
    "(": "",
    ")": "",
    "《": "",
    "》": "",
    "【": "",
    "】": "",
    "[": "",
    "]": "",
    "—": "",
    "~": "-",
    "~": "-",
    "「": "",
    "」": "",
    "¿": "",
    "¡": "",
}


def collapse_whitespace(text):
    # Regular expression matching whitespace:
    _whitespace_re = re.compile(r"\s+")
    return re.sub(_whitespace_re, " ", text).strip()


def remove_punctuation_at_begin(text):
    return re.sub(r"^[,.!?]+", "", text)


def remove_aux_symbols(text):
    text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
    return text


def replace_symbols(text):
    text = text.replace(";", ",")
    text = text.replace("-", " ")
    text = text.replace(":", ",")
    return text


def replace_punctuation(text):
    pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
    replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
    return replaced_text


def text_normalize(text):
    text = replace_punctuation(text)
    text = replace_symbols(text)
    text = remove_aux_symbols(text)
    text = remove_punctuation_at_begin(text)
    text = collapse_whitespace(text)
    text = re.sub(r"([^\.,!\?\-…])$", r"\1.", text)  # ensure terminal punctuation
    return text


def german_to_ipa(text, text_tokenizer):
    if type(text) == str:
        text = text_normalize(text)
        phonemes = text_tokenizer(text)
        return phonemes
    else:
        for i, t in enumerate(text):
            text[i] = text_normalize(t)
        return text_tokenizer(text)
src/YingMusicSinger/utils/f5_tts/g2p/g2p/korean.py
ADDED
@@ -0,0 +1,81 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re

"""
Text clean time
"""
english_dictionary = {
    "KOREA": "코리아",
    "IDOL": "아이돌",
    "IT": "아이티",
    "IQ": "아이큐",
    "UP": "업",
    "DOWN": "다운",
    "PC": "피씨",
    "CCTV": "씨씨티비",
    "SNS": "에스엔에스",
    "AI": "에이아이",
    "CEO": "씨이오",
    "A": "에이",
    "B": "비",
    "C": "씨",
    "D": "디",
    "E": "이",
    "F": "에프",
    "G": "지",
    "H": "에이치",
    "I": "아이",
    "J": "제이",
    "K": "케이",
    "L": "엘",
    "M": "엠",
    "N": "엔",
    "O": "오",
    "P": "피",
    "Q": "큐",
    "R": "알",
    "S": "에스",
    "T": "티",
    "U": "유",
    "V": "브이",
    "W": "더블유",
    "X": "엑스",
    "Y": "와이",
    "Z": "제트",
}


def normalize(text):
    text = text.strip()
    text = re.sub(
        "[⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]", "", text
    )
    text = normalize_english(text)
    text = text.lower()
    return text


def normalize_english(text):
    def fn(m):
        word = m.group()
        if word in english_dictionary:
            return english_dictionary.get(word)
        return word

    text = re.sub("([A-Za-z]+)", fn, text)
    return text


def korean_to_ipa(text, text_tokenizer):
    if type(text) == str:
        text = normalize(text)
        phonemes = text_tokenizer(text)
        return phonemes
    else:
        for i, t in enumerate(text):
            text[i] = normalize(t)
        return text_tokenizer(text)
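
# A quick worked example of normalize_english above: alphabetic runs that appear in
# english_dictionary are transliterated to hangul, everything else passes through.
assert normalize_english("AI CEO") == "에이아이 씨이오"
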
src/YingMusicSinger/utils/f5_tts/g2p/g2p/mandarin.py
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import re
from typing import List

import cn2an
import jieba
from pypinyin import BOPOMOFO, lazy_pinyin

from src.YingMusicSinger.utils.f5_tts.g2p.g2p.chinese_model_g2p import BertPolyPredict
from src.YingMusicSinger.utils.f5_tts.g2p.utils.front_utils import *

# from g2pw import G2PWConverter


# set blank level, {0: "none", 1: "char", 2: "word"}
BLANK_LEVEL = 0

# conv = G2PWConverter(style='pinyin', enable_non_tradional_chinese=True)
resource_path = r"./src/YingMusicSinger/utils/f5_tts/g2p"
poly_all_class_path = os.path.join(
    resource_path, "sources", "g2p_chinese_model", "polychar.txt"
)
if not os.path.exists(poly_all_class_path):
    print(
        "Incorrect path for polyphonic character class dictionary: {}, please check...".format(
            poly_all_class_path
        )
    )
    exit()
poly_dict = generate_poly_lexicon(poly_all_class_path)

# Set up G2PW model parameters
g2pw_poly_model_path = os.path.join(resource_path, "sources", "g2p_chinese_model")
if not os.path.exists(g2pw_poly_model_path):
    print(
        "Incorrect path for g2pw polyphonic character model: {}, please check...".format(
            g2pw_poly_model_path
        )
    )
    exit()

json_file_path = os.path.join(
    resource_path, "sources", "g2p_chinese_model", "polydict.json"
)
if not os.path.exists(json_file_path):
    print(
        "Incorrect path for g2pw id to pinyin dictionary: {}, please check...".format(
            json_file_path
        )
    )
    exit()

jsonr_file_path = os.path.join(
    resource_path, "sources", "g2p_chinese_model", "polydict_r.json"
)
if not os.path.exists(jsonr_file_path):
    print(
        "Incorrect path for g2pw pinyin to id dictionary: {}, please check...".format(
            jsonr_file_path
        )
    )
    exit()

g2pw_poly_predict = BertPolyPredict(
    g2pw_poly_model_path, jsonr_file_path, json_file_path
)


"""
Text clean time
"""
# List of (Latin alphabet, bopomofo) pairs:
_latin_to_bopomofo = [
    (re.compile("%s" % x[0], re.IGNORECASE), x[1])
    for x in [
        ("a", "ㄟˉ"),
        ("b", "ㄅㄧˋ"),
        ("c", "ㄙㄧˉ"),
        ("d", "ㄉㄧˋ"),
        ("e", "ㄧˋ"),
        ("f", "ㄝˊㄈㄨˋ"),
        ("g", "ㄐㄧˋ"),
        ("h", "ㄝˇㄑㄩˋ"),
        ("i", "ㄞˋ"),
        ("j", "ㄐㄟˋ"),
        ("k", "ㄎㄟˋ"),
        ("l", "ㄝˊㄛˋ"),
        ("m", "ㄝˊㄇㄨˋ"),
        ("n", "ㄣˉ"),
        ("o", "ㄡˉ"),
        ("p", "ㄆㄧˉ"),
        ("q", "ㄎㄧㄡˉ"),
        ("r", "ㄚˋ"),
        ("s", "ㄝˊㄙˋ"),
        ("t", "ㄊㄧˋ"),
        ("u", "ㄧㄡˉ"),
        ("v", "ㄨㄧˉ"),
        ("w", "ㄉㄚˋㄅㄨˋㄌㄧㄡˋ"),
        ("x", "ㄝˉㄎㄨˋㄙˋ"),
        ("y", "ㄨㄞˋ"),
        ("z", "ㄗㄟˋ"),
    ]
]

# List of (bopomofo, ipa) pairs:
_bopomofo_to_ipa = [
    (re.compile("%s" % x[0]), x[1])
    for x in [
        ("ㄅㄛ", "p⁼wo"),
        ("ㄆㄛ", "pʰwo"),
        ("ㄇㄛ", "mwo"),
        ("ㄈㄛ", "fwo"),
        ("ㄧㄢ", "|jɛn"),
        ("ㄩㄢ", "|ɥæn"),
        ("ㄧㄣ", "|in"),
        ("ㄩㄣ", "|ɥn"),
        ("ㄧㄥ", "|iŋ"),
        ("ㄨㄥ", "|ʊŋ"),
        ("ㄩㄥ", "|jʊŋ"),
        # Add
        ("ㄧㄚ", "|ia"),
        ("ㄧㄝ", "|iɛ"),
        ("ㄧㄠ", "|iɑʊ"),
        ("ㄧㄡ", "|ioʊ"),
        ("ㄧㄤ", "|iɑŋ"),
        ("ㄨㄚ", "|ua"),
        ("ㄨㄛ", "|uo"),
        ("ㄨㄞ", "|uaɪ"),
        ("ㄨㄟ", "|ueɪ"),
        ("ㄨㄢ", "|uan"),
        ("ㄨㄣ", "|uən"),
        ("ㄨㄤ", "|uɑŋ"),
        ("ㄩㄝ", "|ɥɛ"),
        # End
        ("ㄅ", "p⁼"),
        ("ㄆ", "pʰ"),
        ("ㄇ", "m"),
        ("ㄈ", "f"),
        ("ㄉ", "t⁼"),
        ("ㄊ", "tʰ"),
        ("ㄋ", "n"),
        ("ㄌ", "l"),
        ("ㄍ", "k⁼"),
        ("ㄎ", "kʰ"),
        ("ㄏ", "x"),
        ("ㄐ", "tʃ⁼"),
        ("ㄑ", "tʃʰ"),
        ("ㄒ", "ʃ"),
        ("ㄓ", "ts`⁼"),
        ("ㄔ", "ts`ʰ"),
        ("ㄕ", "s`"),
        ("ㄖ", "ɹ`"),
        ("ㄗ", "ts⁼"),
        ("ㄘ", "tsʰ"),
        ("ㄙ", "|s"),
        ("ㄚ", "|a"),
        ("ㄛ", "|o"),
        ("ㄜ", "|ə"),
        ("ㄝ", "|ɛ"),
        ("ㄞ", "|aɪ"),
        ("ㄟ", "|eɪ"),
        ("ㄠ", "|ɑʊ"),
        ("ㄡ", "|oʊ"),
        ("ㄢ", "|an"),
        ("ㄣ", "|ən"),
        ("ㄤ", "|ɑŋ"),
        ("ㄥ", "|əŋ"),
        ("ㄦ", "əɹ"),
        ("ㄧ", "|i"),
        ("ㄨ", "|u"),
        ("ㄩ", "|ɥ"),
        ("ˉ", "→|"),
        ("ˊ", "↑|"),
        ("ˇ", "↓↑|"),
        ("ˋ", "↓|"),
        ("˙", "|"),
    ]
]
must_not_er_words = {"女儿", "老儿", "男儿", "少儿", "小儿"}

word_pinyin_dict = {}
with open(
    r"src/YingMusicSinger/utils/f5_tts/g2p/sources/chinese_lexicon.txt",
    "r",
    encoding="utf-8",
) as fread:
    txt_list = fread.readlines()
    for txt in txt_list:
        word, pinyin = txt.strip().split("\t")
        word_pinyin_dict[word] = pinyin
    fread.close()

pinyin_2_bopomofo_dict = {}
with open(
    r"./src/YingMusicSinger/utils/f5_tts/g2p/sources/pinyin_2_bpmf.txt",
    "r",
    encoding="utf-8",
) as fread:
    txt_list = fread.readlines()
    for txt in txt_list:
        pinyin, bopomofo = txt.strip().split("\t")
        pinyin_2_bopomofo_dict[pinyin] = bopomofo
    fread.close()

tone_dict = {
    "0": "˙",
    "5": "˙",
    "1": "",
    "2": "ˊ",
    "3": "ˇ",
    "4": "ˋ",
}

bopomofos2pinyin_dict = {}
with open(
    r"./src/YingMusicSinger/utils/f5_tts/g2p/sources/bpmf_2_pinyin.txt",
    "r",
    encoding="utf-8",
) as fread:
    txt_list = fread.readlines()
    for txt in txt_list:
        v, k = txt.strip().split("\t")
        bopomofos2pinyin_dict[k] = v
    fread.close()


def bpmf_to_pinyin(text):
    bopomofo_list = text.split("|")
    pinyin_list = []
    for info in bopomofo_list:
        pinyin = ""
        for c in info:
            if c in bopomofos2pinyin_dict:
                pinyin += bopomofos2pinyin_dict[c]
        if len(pinyin) == 0:
            continue
        if pinyin[-1] not in "01234":
            pinyin += "1"
        if pinyin[:-1] == "ve":
            pinyin = "y" + pinyin
        if pinyin[:-1] == "sh":
            pinyin = pinyin[:-1] + "i" + pinyin[-1]
        if pinyin == "sh":
            pinyin = pinyin[:-1] + "i"
        if pinyin[:-1] == "s":
            pinyin = "si" + pinyin[-1]
        if pinyin[:-1] == "c":
            pinyin = "ci" + pinyin[-1]
        if pinyin[:-1] == "i":
            pinyin = "yi" + pinyin[-1]
        if pinyin[:-1] == "iou":
            pinyin = "you" + pinyin[-1]
        if pinyin[:-1] == "ien":
            pinyin = "yin" + pinyin[-1]
        if "iou" in pinyin and pinyin[-4:-1] == "iou":
            pinyin = pinyin[:-4] + "iu" + pinyin[-1]
        if "uei" in pinyin:
            if pinyin[:-1] == "uei":
                pinyin = "wei" + pinyin[-1]
            elif pinyin[-4:-1] == "uei":
                pinyin = pinyin[:-4] + "ui" + pinyin[-1]
        if "uen" in pinyin and pinyin[-4:-1] == "uen":
            if pinyin[:-1] == "uen":
                pinyin = "wen" + pinyin[-1]
            elif pinyin[-4:-1] == "uen":  # was "uei", which this branch could never match
                pinyin = pinyin[:-4] + "un" + pinyin[-1]
        if "van" in pinyin and pinyin[-4:-1] == "van":
            if pinyin[:-1] == "van":
                pinyin = "yuan" + pinyin[-1]
            elif pinyin[-4:-1] == "van":
                pinyin = pinyin[:-4] + "uan" + pinyin[-1]
        if "ueng" in pinyin and pinyin[-5:-1] == "ueng":
            pinyin = pinyin[:-5] + "ong" + pinyin[-1]
        if pinyin[:-1] == "veng":
            pinyin = "yong" + pinyin[-1]
        if "veng" in pinyin and pinyin[-5:-1] == "veng":
            pinyin = pinyin[:-5] + "iong" + pinyin[-1]
        if pinyin[:-1] == "ieng":
            pinyin = "ying" + pinyin[-1]
        if pinyin[:-1] == "u":
            pinyin = "wu" + pinyin[-1]
        if pinyin[:-1] == "v":
            pinyin = "yv" + pinyin[-1]
        if pinyin[:-1] == "ing":
            pinyin = "ying" + pinyin[-1]
        if pinyin[:-1] == "z":
            pinyin = "zi" + pinyin[-1]
        if pinyin[:-1] == "zh":
            pinyin = "zhi" + pinyin[-1]
        if pinyin[0] == "u":
            pinyin = "w" + pinyin[1:]
        if pinyin[0] == "i":
            pinyin = "y" + pinyin[1:]
        pinyin = pinyin.replace("ien", "in")

        pinyin_list.append(pinyin)
    return " ".join(pinyin_list)


# Convert numbers to Chinese pronunciation
def number_to_chinese(text):
    # numbers = re.findall(r'\d+(?:\.?\d+)?', text)
    # for number in numbers:
    #     text = text.replace(number, cn2an.an2cn(number), 1)
    text = cn2an.transform(text, "an2cn")
    return text


def normalization(text):
    text = text.replace(",", ",")
    text = text.replace("。", ".")
    text = text.replace("!", "!")
    text = text.replace("?", "?")
    text = text.replace(";", ";")
    text = text.replace(":", ":")
    text = text.replace("、", ",")
    text = text.replace("‘", "'")
    text = text.replace("’", "'")
    text = text.replace("⋯", "…")
    text = text.replace("···", "…")
    text = text.replace("・・・", "…")
    text = text.replace("...", "…")
    text = re.sub(r"\s+", "", text)
    text = re.sub(r"[^\u4e00-\u9fff\s_,\.\?!;:\'…]", "", text)
    text = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", text)
    return text


def change_tone(bopomofo: str, tone: str) -> str:
    if bopomofo[-1] not in "˙ˊˇˋ":
        bopomofo = bopomofo + tone
    else:
        bopomofo = bopomofo[:-1] + tone
    return bopomofo


def er_sandhi(word: str, bopomofos: List[str]) -> List[str]:
    if len(word) > 1 and word[-1] == "儿" and word not in must_not_er_words:
        bopomofos[-1] = change_tone(bopomofos[-1], "˙")
    return bopomofos


def bu_sandhi(word: str, bopomofos: List[str]) -> List[str]:
    valid_char = set(word)
    if len(valid_char) == 1 and "不" in valid_char:
        pass
    elif word in ["不字"]:
        pass
    elif len(word) == 3 and word[1] == "不" and bopomofos[1][:-1] == "ㄅㄨ":
        bopomofos[1] = bopomofos[1][:-1] + "˙"
    else:
        for i, char in enumerate(word):
            if (
                i + 1 < len(bopomofos)
                and char == "不"
                and i + 1 < len(word)
                and 0 < len(bopomofos[i + 1])
                and bopomofos[i + 1][-1] == "ˋ"
            ):
                bopomofos[i] = bopomofos[i][:-1] + "ˊ"
    return bopomofos


def yi_sandhi(word: str, bopomofos: List[str]) -> List[str]:
    punc = ":,;。?!“”‘’':,;.?!()(){}【】[]-~`、 "
    if word.find("一") != -1 and any(
        [item.isnumeric() for item in word if item != "一"]
    ):
        for i in range(len(word)):
            if (
                i == 0
                and word[0] == "一"
                and len(word) > 1
                and word[1]
                not in [
                    "零",
                    "一",
                    "二",
                    "三",
                    "四",
                    "五",
                    "六",
                    "七",
                    "八",
                    "九",
                    "十",
                ]
            ):
                if len(bopomofos[0]) > 0 and bopomofos[1][-1] in ["ˋ", "˙"]:
                    bopomofos[0] = change_tone(bopomofos[0], "ˊ")
                else:
                    bopomofos[0] = change_tone(bopomofos[0], "ˋ")
            elif word[i] == "一":
                bopomofos[i] = change_tone(bopomofos[i], "")
        return bopomofos
    elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
        bopomofos[1] = change_tone(bopomofos[1], "˙")
    elif word.startswith("第一"):
        bopomofos[1] = change_tone(bopomofos[1], "")
    elif word.startswith("一月") or word.startswith("一日") or word.startswith("一号"):
        bopomofos[0] = change_tone(bopomofos[0], "")
    else:
        for i, char in enumerate(word):
            if char == "一" and i + 1 < len(word):
                if (
                    len(bopomofos) > i + 1
                    and len(bopomofos[i + 1]) > 0
                    and bopomofos[i + 1][-1] in {"ˋ"}
                ):
                    bopomofos[i] = change_tone(bopomofos[i], "ˊ")
                else:
                    if word[i + 1] not in punc:
                        bopomofos[i] = change_tone(bopomofos[i], "ˋ")
                    else:
                        pass
    return bopomofos


def merge_bu(seg: List) -> List:
    new_seg = []
    last_word = ""
    for word in seg:
        if word != "不":
            if last_word == "不":
                word = last_word + word
            new_seg.append(word)
        last_word = word
    return new_seg


def merge_er(seg: List) -> List:
    new_seg = []
    for i, word in enumerate(seg):
        if i - 1 >= 0 and word == "儿":
            new_seg[-1] = new_seg[-1] + seg[i]
        else:
            new_seg.append(word)
    return new_seg


def merge_yi(seg: List) -> List:
    new_seg = []
    # function 1
    for i, word in enumerate(seg):
        if (
            i - 1 >= 0
            and word == "一"
            and i + 1 < len(seg)
            and seg[i - 1] == seg[i + 1]
        ):
            if i - 1 < len(new_seg):
                new_seg[i - 1] = new_seg[i - 1] + "一" + new_seg[i - 1]
            else:
                new_seg.append(word)
                new_seg.append(seg[i + 1])
        else:
            if i - 2 >= 0 and seg[i - 1] == "一" and seg[i - 2] == word:
                continue
            else:
                new_seg.append(word)
    seg = new_seg
    new_seg = []
    isnumeric_flag = False
    for i, word in enumerate(seg):
        if all([item.isnumeric() for item in word]) and not isnumeric_flag:
            isnumeric_flag = True
            new_seg.append(word)
        else:
            new_seg.append(word)
    seg = new_seg
    new_seg = []
    # function 2
    for i, word in enumerate(seg):
        if new_seg and new_seg[-1] == "一":
            new_seg[-1] = new_seg[-1] + word
        else:
            new_seg.append(word)
    return new_seg


# Word segmentation, and convert Chinese pronunciation to pinyin (bopomofo)
def chinese_to_bopomofo(text_short, sentence):
    # bopomofos = conv(text_short)
    words = jieba.lcut(text_short, cut_all=False)
    words = merge_yi(words)
    words = merge_bu(words)
    words = merge_er(words)
    text = ""

    char_index = 0
    for word in words:
        bopomofos = []
        if word in word_pinyin_dict and word not in poly_dict:
            pinyin = word_pinyin_dict[word]
            for py in pinyin.split(" "):
                if py[:-1] in pinyin_2_bopomofo_dict and py[-1] in tone_dict:
                    bopomofos.append(
                        pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
                    )
                    if BLANK_LEVEL == 1:
                        bopomofos.append("_")
                else:
                    bopomofos_lazy = lazy_pinyin(word, BOPOMOFO)
                    bopomofos += bopomofos_lazy
                    if BLANK_LEVEL == 1:
                        bopomofos.append("_")
        else:
            for i in range(len(word)):
                c = word[i]
                if c in poly_dict:
                    poly_pinyin = g2pw_poly_predict.predict_process(
                        [text_short, char_index + i]
                    )[0]
                    py = poly_pinyin[2:-1]
                    bopomofos.append(
                        pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
                    )
                    if BLANK_LEVEL == 1:
                        bopomofos.append("_")
                elif c in word_pinyin_dict:
                    py = word_pinyin_dict[c]
                    bopomofos.append(
                        pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
                    )
                    if BLANK_LEVEL == 1:
                        bopomofos.append("_")
                else:
                    bopomofos.append(c)
                    if BLANK_LEVEL == 1:
                        bopomofos.append("_")
        if BLANK_LEVEL == 2:
            bopomofos.append("_")
        char_index += len(word)

        if (
            len(word) == 3
            and bopomofos[0][-1] == "ˇ"
            and bopomofos[1][-1] == "ˇ"
            and bopomofos[-1][-1] == "ˇ"
        ):
            bopomofos[0] = bopomofos[0] + "ˊ"
            bopomofos[1] = bopomofos[1] + "ˊ"
        if len(word) == 2 and bopomofos[0][-1] == "ˇ" and bopomofos[-1][-1] == "ˇ":
            bopomofos[0] = bopomofos[0][:-1] + "ˊ"
        bopomofos = bu_sandhi(word, bopomofos)
        bopomofos = yi_sandhi(word, bopomofos)
        bopomofos = er_sandhi(word, bopomofos)
        if not re.search("[\u4e00-\u9fff]", word):
            text += "|" + word
            continue
        for i in range(len(bopomofos)):
            bopomofos[i] = re.sub(r"([\u3105-\u3129])$", r"\1ˉ", bopomofos[i])
        if text != "":
            text += "|"
        text += "|".join(bopomofos)
    return text


# Convert latin pronunciation to pinyin (bopomofo)
def latin_to_bopomofo(text):
    for regex, replacement in _latin_to_bopomofo:
        text = re.sub(regex, replacement, text)
    return text


# Convert pinyin (bopomofo) to IPA
def bopomofo_to_ipa(text):
    for regex, replacement in _bopomofo_to_ipa:
        text = re.sub(regex, replacement, text)
    return text


def _chinese_to_ipa(text, sentence):
    text = number_to_chinese(text.strip())
    text = normalization(text)
    text = chinese_to_bopomofo(text, sentence)
    # pinyin = bpmf_to_pinyin(text)
    text = latin_to_bopomofo(text)
    text = bopomofo_to_ipa(text)
    text = re.sub("([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
    text = re.sub("([s][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
    text = re.sub(r"^\||[^\w\s_,\.\?!;:\'…\|→↓↑⁼ʰ`]", "", text)
    text = re.sub(r"([,\.\?!;:\'…])", r"|\1|", text)
    text = re.sub(r"\|+", "|", text)
    text = text.rstrip("|")
    return text


# Convert Chinese to IPA
def chinese_to_ipa(text, sentence, text_tokenizer):
    # phonemes = text_tokenizer(text.strip())
    if type(text) == str:
        return _chinese_to_ipa(text, sentence)
    else:
        result_ph = []
        for t in text:
            result_ph.append(_chinese_to_ipa(t, sentence))
        return result_ph
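A minimal usage sketch for the Mandarin G2P module above (not part of the diff; the sample sentence is an assumption, and the resource files under ./src/YingMusicSinger/utils/f5_tts/g2p must resolve from the working directory because they are loaded at import time):

from src.YingMusicSinger.utils.f5_tts.g2p.g2p.mandarin import chinese_to_ipa

# text_tokenizer is accepted but unused on the Chinese path, so None is fine here.
ipa = chinese_to_ipa("我有2个女儿", sentence=None, text_tokenizer=None)
print(ipa)  # "|"-separated IPA segments with tone arrows (→ ↑ ↓↑ ↓)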
src/YingMusicSinger/utils/f5_tts/g2p/g2p/text_tokenizers.py
ADDED
@@ -0,0 +1,82 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import List, Union

from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.separator import Separator
from phonemizer.utils import list2str, str2list


class TextTokenizer:
    """Phonemize text."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="|_|", syllable="-", phone="|"),
        preserve_punctuation=True,
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "remove-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        self.preserve_punctuation_marks = ",.?!;:'…"
        self.backend = EspeakBackend(
            language,
            punctuation_marks=self.preserve_punctuation_marks,
            preserve_punctuation=preserve_punctuation,
            with_stress=with_stress,
            tie=tie,
            language_switch=language_switch,
            words_mismatch=words_mismatch,
        )

        self.separator = separator

    # convert Chinese punctuation to English punctuation
    def convert_chinese_punctuation(self, text: str) -> str:
        text = text.replace(",", ",")
        text = text.replace("。", ".")
        text = text.replace("!", "!")
        text = text.replace("?", "?")
        text = text.replace(";", ";")
        text = text.replace(":", ":")
        text = text.replace("、", ",")
        text = text.replace("‘", "'")
        text = text.replace("’", "'")
        text = text.replace("⋯", "…")
        text = text.replace("···", "…")
        text = text.replace("・・・", "…")
        text = text.replace("...", "…")
        return text

    def __call__(self, text, strip=True) -> List[str]:
        text_type = type(text)
        normalized_text = []
        for line in str2list(text):
            line = self.convert_chinese_punctuation(line.strip())
            line = re.sub(r"[^\w\s_,\.\?!;:\'…]", "", line)
            line = re.sub(r"\s*([,\.\?!;:\'…])\s*", r"\1", line)
            line = re.sub(r"\s+", " ", line)
            normalized_text.append(line)
        # print("Normalized text: ", normalized_text[0])
        phonemized = self.backend.phonemize(
            normalized_text, separator=self.separator, strip=strip, njobs=1
        )
        if text_type == str:
            phonemized = re.sub(r"([,\.\?!;:\'…])", r"|\1|", list2str(phonemized))
            phonemized = re.sub(r"\|+", "|", phonemized)
            phonemized = phonemized.rstrip("|")
        else:
            for i in range(len(phonemized)):
                phonemized[i] = re.sub(r"([,\.\?!;:\'…])", r"|\1|", phonemized[i])
                phonemized[i] = re.sub(r"\|+", "|", phonemized[i])
                phonemized[i] = phonemized[i].rstrip("|")
        return phonemized
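A short usage sketch for TextTokenizer (an assumption, not from the repository; phonemizer's EspeakBackend requires espeak-ng to be installed on the system):

from src.YingMusicSinger.utils.f5_tts.g2p.g2p.text_tokenizers import TextTokenizer

tokenizer = TextTokenizer(language="en-us")
print(tokenizer("Hello, world!"))       # str in  -> phones joined by "|", words by "|_|"
print(tokenizer(["one line", "two"]))   # list in -> list of phonemized strings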
src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json
ADDED
@@ -0,0 +1,372 @@
{
  "vocab": {
    ",": 0,
    ".": 1,
    "?": 2,
    "!": 3,
    "_": 4,
    "iː": 5,
    "ɪ": 6,
    "ɜː": 7,
    "ɚ": 8,
    "oːɹ": 9,
    "ɔː": 10,
    "ɔːɹ": 11,
    "ɑː": 12,
    "uː": 13,
    "ʊ": 14,
    "ɑːɹ": 15,
    "ʌ": 16,
    "ɛ": 17,
    "æ": 18,
    "eɪ": 19,
    "aɪ": 20,
    "ɔɪ": 21,
    "aʊ": 22,
    "oʊ": 23,
    "ɪɹ": 24,
    "ɛɹ": 25,
    "ʊɹ": 26,
    "p": 27,
    "b": 28,
    "t": 29,
    "d": 30,
    "k": 31,
    "ɡ": 32,
    "f": 33,
    "v": 34,
    "θ": 35,
    "ð": 36,
    "s": 37,
    "z": 38,
    "ʃ": 39,
    "ʒ": 40,
    "h": 41,
    "tʃ": 42,
    "dʒ": 43,
    "m": 44,
    "n": 45,
    "ŋ": 46,
    "j": 47,
    "w": 48,
    "ɹ": 49,
    "l": 50,
    "tɹ": 51,
    "dɹ": 52,
    "ts": 53,
    "dz": 54,
    "i": 55,
    "ɔ": 56,
    "ə": 57,
    "ɾ": 58,
    "iə": 59,
    "r": 60,
    "u": 61,
    "oː": 62,
    "ɛː": 63,
    "ɪː": 64,
    "aɪə": 65,
    "aɪɚ": 66,
    "ɑ̃": 67,
    "ç": 68,
    "ɔ̃": 69,
    "ææ": 70,
    "ɐɐ": 71,
    "ɡʲ": 72,
    "nʲ": 73,
    "iːː": 74,

    "p⁼": 75,
    "pʰ": 76,
    "t⁼": 77,
    "tʰ": 78,
    "k⁼": 79,
    "kʰ": 80,
    "x": 81,
    "tʃ⁼": 82,
    "tʃʰ": 83,
    "ts`⁼": 84,
    "ts`ʰ": 85,
    "s`": 86,
    "ɹ`": 87,
    "ts⁼": 88,
    "tsʰ": 89,
    "p⁼wo": 90,
    "p⁼wo→": 91,
    "p⁼wo↑": 92,
    "p⁼wo↓↑": 93,
    "p⁼wo↓": 94,
    "pʰwo": 95,
    "pʰwo→": 96,
    "pʰwo↑": 97,
    "pʰwo↓↑": 98,
    "pʰwo↓": 99,
    "mwo": 100,
    "mwo→": 101,
    "mwo↑": 102,
    "mwo↓↑": 103,
    "mwo↓": 104,
    "fwo": 105,
    "fwo→": 106,
    "fwo↑": 107,
    "fwo↓↑": 108,
    "fwo↓": 109,
    "jɛn": 110,
    "jɛn→": 111,
    "jɛn↑": 112,
    "jɛn↓↑": 113,
    "jɛn↓": 114,
    "ɥæn": 115,
    "ɥæn→": 116,
    "ɥæn↑": 117,
    "ɥæn↓↑": 118,
    "ɥæn↓": 119,
    "in": 120,
    "in→": 121,
    "in↑": 122,
    "in↓↑": 123,
    "in↓": 124,
    "ɥn": 125,
    "ɥn→": 126,
    "ɥn↑": 127,
    "ɥn↓↑": 128,
    "ɥn↓": 129,
    "iŋ": 130,
    "iŋ→": 131,
    "iŋ↑": 132,
    "iŋ↓↑": 133,
    "iŋ↓": 134,
    "ʊŋ": 135,
    "ʊŋ→": 136,
    "ʊŋ↑": 137,
    "ʊŋ↓↑": 138,
    "ʊŋ↓": 139,
    "jʊŋ": 140,
    "jʊŋ→": 141,
    "jʊŋ↑": 142,
    "jʊŋ↓↑": 143,
    "jʊŋ↓": 144,
    "ia": 145,
    "ia→": 146,
    "ia↑": 147,
    "ia↓↑": 148,
    "ia↓": 149,
    "iɛ": 150,
    "iɛ→": 151,
    "iɛ↑": 152,
    "iɛ↓↑": 153,
    "iɛ↓": 154,
    "iɑʊ": 155,
    "iɑʊ→": 156,
    "iɑʊ↑": 157,
    "iɑʊ↓↑": 158,
    "iɑʊ↓": 159,
    "ioʊ": 160,
    "ioʊ→": 161,
    "ioʊ↑": 162,
    "ioʊ↓↑": 163,
    "ioʊ↓": 164,
    "iɑŋ": 165,
    "iɑŋ→": 166,
    "iɑŋ↑": 167,
    "iɑŋ↓↑": 168,
    "iɑŋ↓": 169,
    "ua": 170,
    "ua→": 171,
    "ua↑": 172,
    "ua↓↑": 173,
    "ua↓": 174,
    "uo": 175,
    "uo→": 176,
    "uo↑": 177,
    "uo↓↑": 178,
    "uo↓": 179,
    "uaɪ": 180,
    "uaɪ→": 181,
    "uaɪ↑": 182,
    "uaɪ↓↑": 183,
    "uaɪ↓": 184,
    "ueɪ": 185,
    "ueɪ→": 186,
    "ueɪ↑": 187,
    "ueɪ↓↑": 188,
    "ueɪ↓": 189,
    "uan": 190,
    "uan→": 191,
    "uan↑": 192,
    "uan↓↑": 193,
    "uan↓": 194,
    "uən": 195,
    "uən→": 196,
    "uən↑": 197,
    "uən↓↑": 198,
    "uən↓": 199,
    "uɑŋ": 200,
    "uɑŋ→": 201,
    "uɑŋ↑": 202,
    "uɑŋ↓↑": 203,
    "uɑŋ↓": 204,
    "ɥɛ": 205,
    "ɥɛ→": 206,
    "ɥɛ↑": 207,
    "ɥɛ↓↑": 208,
    "ɥɛ↓": 209,
    "a": 210,
    "a→": 211,
    "a↑": 212,
    "a↓↑": 213,
    "a↓": 214,
    "o": 215,
    "o→": 216,
    "o↑": 217,
    "o↓↑": 218,
    "o↓": 219,
    "ə→": 220,
    "ə↑": 221,
    "ə↓↑": 222,
    "ə↓": 223,
    "ɛ→": 224,
    "ɛ↑": 225,
    "ɛ↓↑": 226,
    "ɛ↓": 227,
    "aɪ→": 228,
    "aɪ↑": 229,
    "aɪ↓↑": 230,
    "aɪ↓": 231,
    "eɪ→": 232,
    "eɪ↑": 233,
    "eɪ↓↑": 234,
    "eɪ↓": 235,
    "ɑʊ": 236,
    "ɑʊ→": 237,
    "ɑʊ↑": 238,
    "ɑʊ↓↑": 239,
    "ɑʊ↓": 240,
    "oʊ→": 241,
    "oʊ↑": 242,
    "oʊ↓↑": 243,
    "oʊ↓": 244,
    "an": 245,
    "an→": 246,
    "an↑": 247,
    "an↓↑": 248,
    "an↓": 249,
    "ən": 250,
    "ən→": 251,
    "ən↑": 252,
    "ən↓↑": 253,
    "ən↓": 254,
    "ɑŋ": 255,
    "ɑŋ→": 256,
    "ɑŋ↑": 257,
    "ɑŋ↓↑": 258,
    "ɑŋ↓": 259,
    "əŋ": 260,
    "əŋ→": 261,
    "əŋ↑": 262,
    "əŋ↓↑": 263,
    "əŋ↓": 264,
    "əɹ": 265,
    "əɹ→": 266,
    "əɹ↑": 267,
    "əɹ↓↑": 268,
    "əɹ↓": 269,
    "i→": 270,
    "i↑": 271,
    "i↓↑": 272,
    "i↓": 273,
    "u→": 274,
    "u↑": 275,
    "u↓↑": 276,
    "u↓": 277,
    "ɥ": 278,
    "ɥ→": 279,
    "ɥ↑": 280,
    "ɥ↓↑": 281,
    "ɥ↓": 282,
    "ts`⁼ɹ": 283,
    "ts`⁼ɹ→": 284,
    "ts`⁼ɹ↑": 285,
    "ts`⁼ɹ↓↑": 286,
    "ts`⁼ɹ↓": 287,
    "ts`ʰɹ": 288,
    "ts`ʰɹ→": 289,
    "ts`ʰɹ↑": 290,
    "ts`ʰɹ↓↑": 291,
    "ts`ʰɹ↓": 292,
    "s`ɹ": 293,
    "s`ɹ→": 294,
    "s`ɹ↑": 295,
    "s`ɹ↓↑": 296,
    "s`ɹ↓": 297,
    "ɹ`ɹ": 298,
    "ɹ`ɹ→": 299,
    "ɹ`ɹ↑": 300,
    "ɹ`ɹ↓↑": 301,
    "ɹ`ɹ↓": 302,
    "ts⁼ɹ": 303,
    "ts⁼ɹ→": 304,
    "ts⁼ɹ↑": 305,
    "ts⁼ɹ↓↑": 306,
    "ts⁼ɹ↓": 307,
    "tsʰɹ": 308,
    "tsʰɹ→": 309,
    "tsʰɹ↑": 310,
    "tsʰɹ↓↑": 311,
    "tsʰɹ↓": 312,
    "sɹ": 313,
    "sɹ→": 314,
    "sɹ↑": 315,
    "sɹ↓↑": 316,
    "sɹ↓": 317,

    "ɯ": 318,
    "e": 319,
    "aː": 320,
    "ɯː": 321,
    "eː": 322,
    "ç": 323,
    "ɸ": 324,
    "ɰᵝ": 325,
    "ɴ": 326,
    "g": 327,
    "dʑ": 328,
    "q": 329,
    "ː": 330,
    "bj": 331,
    "tɕ": 332,
    "dej": 333,
    "tej": 334,
    "gj": 335,
    "gɯ": 336,
    "çj": 337,
    "kj": 338,
    "kɯ": 339,
    "mj": 340,
    "nj": 341,
    "pj": 342,
    "ɾj": 343,
    "ɕ": 344,
    "tsɯ": 345,

    "ɐ": 346,
    "ɑ": 347,
    "ɒ": 348,
    "ɜ": 349,
    "ɫ": 350,
    "ʑ": 351,
    "ʲ": 352,

    "y": 353,
    "ø": 354,
    "œ": 355,
    "ʁ": 356,
    "̃": 357,
    "ɲ": 358,

    ":": 359,
    ";": 360,
    "'": 361,
    "…": 362
  }
}
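The map is consumed by direct token lookup; a small sketch of that (the phoneme string is illustrative, not a captured output):

import json

with open("./src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json", encoding="utf-8") as f:
    vocab = json.load(f)["vocab"]

phones = "n|i↓↑|x|ɑʊ↓↑".split("|")            # roughly "你好" after bopomofo -> IPA
ids = [vocab[p] for p in phones if p in vocab]
print(ids)                                     # [45, 272, 81, 239]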
src/YingMusicSinger/utils/f5_tts/g2p/g2p_generation.py
ADDED
@@ -0,0 +1,129 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import json
from typing import List

from src.YingMusicSinger.utils.f5_tts.g2p.g2p import PhonemeBpeTokenizer
from src.YingMusicSinger.utils.f5_tts.g2p.utils.g2p import phonemizer_g2p


def ph_g2p(text, language):
    return phonemizer_g2p(text=text, language=language)


def g2p(text, sentence, language):
    return text_tokenizer.tokenize(text=text, sentence=sentence, language=language)


def is_chinese(char):
    if char >= "\u4e00" and char <= "\u9fa5":
        return True
    else:
        return False


def is_alphabet(char):
    if (char >= "\u0041" and char <= "\u005a") or (
        char >= "\u0061" and char <= "\u007a"
    ):
        return True
    else:
        return False


def is_other(char):
    if not (is_chinese(char) or is_alphabet(char)):
        return True
    else:
        return False


def get_segment(text: str) -> List[str]:
    # sentence --> [ch_part, en_part, ch_part, ...]
    segments = []
    types = []
    flag = 0
    temp_seg = ""
    temp_lang = ""

    # Determine the type of each character. type: blank, chinese, alphabet, number, unk and point.
    for i, ch in enumerate(text):
        if is_chinese(ch):
            types.append("zh")
        elif is_alphabet(ch):
            types.append("en")
        else:
            types.append("other")

    assert len(types) == len(text)

    for i in range(len(types)):
        # find the first char of the seg
        if flag == 0:
            temp_seg += text[i]
            temp_lang = types[i]
            flag = 1
        else:
            if temp_lang == "other":
                if types[i] == temp_lang:
                    temp_seg += text[i]
                else:
                    temp_seg += text[i]
                    temp_lang = types[i]
            else:
                if types[i] == temp_lang:
                    temp_seg += text[i]
                elif types[i] == "other":
                    temp_seg += text[i]
                else:
                    segments.append((temp_seg, temp_lang))
                    temp_seg = text[i]
                    temp_lang = types[i]
                    flag = 1

    segments.append((temp_seg, temp_lang))
    return segments


def chn_eng_g2p(text: str):
    # now only en and ch
    segments = get_segment(text)
    all_phoneme = ""
    all_tokens = []

    for index in range(len(segments)):
        seg = segments[index]
        phoneme, token = g2p(seg[0], text, seg[1])
        all_phoneme += phoneme + "|"
        all_tokens += token

        if seg[1] == "en" and index == len(segments) - 1 and all_phoneme[-2] == "_":
            all_phoneme = all_phoneme[:-2]
            all_tokens = all_tokens[:-1]
    return all_phoneme, all_tokens


text_tokenizer = PhonemeBpeTokenizer()
with open("./src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json", "r") as f:
    json_data = f.read()
data = json.loads(json_data)
vocab = data["vocab"]

if __name__ == "__main__":
    phone, token = chn_eng_g2p("你好,hello world")
    phone, token = chn_eng_g2p(
        "你好,hello world, Bonjour, 테스트 해 보겠습니다, 五月雨緑"
    )
    print(phone)
    print(token)

    # phone, token = text_tokenizer.tokenize("你好,hello world, Bonjour, 테스트 해 보겠습니다, 五月雨緑", "", "auto")
    phone, token = text_tokenizer.tokenize("緑", "", "auto")
    # phone, token = text_tokenizer.tokenize("आइए इसका परीक्षण करें", "", "auto")
    # phone, token = text_tokenizer.tokenize("आइए इसका परीक्षण करें", "", "other")
    print(phone)
    print(token)
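A sketch of how the segmentation behaves (assumed import path; note that importing this module also instantiates PhonemeBpeTokenizer and reads vocab.json, so it must run from the repo root):

from src.YingMusicSinger.utils.f5_tts.g2p.g2p_generation import get_segment, chn_eng_g2p

print(get_segment("你好hello世界"))
# -> [('你好', 'zh'), ('hello', 'en'), ('世界', 'zh')]

phoneme, tokens = chn_eng_g2p("你好hello")  # "|"-joined phonemes plus their vocab ids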
src/YingMusicSinger/utils/f5_tts/g2p/infer_dpo.py
ADDED
@@ -0,0 +1,277 @@
import argparse
import json
import os
import random

import numpy as np
import torch
from f5_tts.infer.utils_infer import load_checkpoint
from f5_tts.model import CFM, DiT
from f5_tts.model.alsp_lance.data.npydata import FloatData
from f5_tts.model.alsp_lance.tools import LanceReader, LanceWriter
from tqdm import tqdm

filter_keyword_list = [
    "纯音乐",
    "编曲",
    "作词",
    "作曲",
    "调音",
    "制作人",
    "录音师",
]

filter_full_list = ["music", "end"]


def check_lyric(time: float, lyric: str):
    if time < 0.1:
        return False
    for filter_keyword in filter_keyword_list:
        if filter_keyword in lyric:
            return False
    for filter_full in filter_full_list:
        if filter_full == lyric.strip().lower():
            return False
    if len(lyric) == 0:
        return False
    return True


def parse_lyrics(lyrics: str):
    lyrics_with_time = []
    lyrics = lyrics.strip()
    for line in lyrics.split("\n"):
        try:
            time, lyric = line[1:9], line[10:]
            lyric = lyric.strip()
            mins, secs = time.split(":")
            secs = int(mins) * 60 + float(secs)
            # print(lyric, check_lyric(secs, lyric))
            if not check_lyric(secs, lyric):
                continue
            lyrics_with_time.append((secs, lyric))
        except:
            # traceback.print_exc()
            continue
            # print("error", line)
    return lyrics_with_time


class CNENTokenizer:
    def __init__(self):
        with open("./src/YingMusicSinger/utils/f5_tts/g2p/g2p/vocab.json", "r") as file:
            self.phone2id: dict = json.load(file)["vocab"]
        self.id2phone = {v: k for (k, v) in self.phone2id.items()}
        from f5_tts.g2p.g2p_generation import chn_eng_g2p

        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        phone, token = self.tokenizer(text)
        token = [x + 1 for x in token]
        return token

    def decode(self, token):
        return "|".join([self.id2phone[x - 1] for x in token])


def inference(
    model,
    cond,
    text,
    duration,
    style_prompt,
    # the remaining parameters are defaulted so the keyword-only call below
    # does not raise TypeError; several of them are unused in this function
    style=None,
    output_dir=None,
    song_name=None,
    ckpt_step=None,
    start_time=None,
    latent_pred_start_frame=None,
    latent_pred_end_frame=None,
    epoch=None,
    cfg_strength=None,
):
    # import pdb; pdb.set_trace()
    with torch.inference_mode():
        generated, _ = model.sample(
            cond=cond,
            text=text,
            duration=duration,
            style_prompt=style_prompt,
            steps=32,
            cfg_strength=cfg_strength,
            sway_sampling_coef=None,
            start_time=start_time,
            latent_pred_start_frame=latent_pred_start_frame,
            latent_pred_end_frame=latent_pred_end_frame,
        )

        generated = generated.to(torch.float32)  # [b t d]
        latent = generated.transpose(1, 2)  # [b d t]
        latent = latent.detach().cpu().numpy()  # was ".cpu.numpy()", a missing call

        return latent


def get_style_prompt(device, song_name, song_name2ref_npy):
    mulan_style_path = song_name2ref_npy[song_name]
    mulan_style = np.load(mulan_style_path)  # was "mulan_stlye"

    style_prompt = torch.from_numpy(mulan_style).to(device)  # [1, 512]
    style_prompt = style_prompt.half()

    return style_prompt


def get_lrc_prompt(text, tokenizer, dit_model, max_secs):
    max_frames = 2048
    lyrics_shift = 2
    sampling_rate = 44100
    downsample_rate = 2048

    pad_token_id = 0
    comma_token_id = 1
    period_token_id = 2

    fsmin = -10
    fsmax = 10

    lrc_with_time = parse_lyrics(text)

    modified_lrc_with_time = []
    for i in range(len(lrc_with_time)):
        time, line = lrc_with_time[i]
        # line_token = self.tokenizer.encode(line)
        line_token = tokenizer.encode(line)
        modified_lrc_with_time.append((time, line_token))

    lrc_with_time = modified_lrc_with_time

    lrc_with_time = [
        (time_start, line)
        for (time_start, line) in lrc_with_time
        if time_start < max_secs
    ]
    # latent_end_time = lrc_with_time[-1][0] if len(lrc_with_time) >= 1 else -1
    lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time

    normalized_start_time = 0.0

    lrc = torch.zeros((max_frames,), dtype=torch.long)

    tokens_count = 0
    last_end_pos = 0
    for time_start, line in lrc_with_time:
        tokens = [
            token if token != period_token_id else comma_token_id for token in line
        ] + [period_token_id]
        tokens = torch.tensor(tokens, dtype=torch.long)
        num_tokens = tokens.shape[0]

        gt_frame_start = int(time_start * sampling_rate / downsample_rate)

        frame_shift = random.randint(int(fsmin), int(fsmax))

        frame_start = max(gt_frame_start - frame_shift, last_end_pos)
        frame_len = min(num_tokens, max_frames - frame_start)

        # print(gt_frame_start, frame_shift, frame_start, frame_len, tokens_count, last_end_pos, full_pos_emb.shape)

        lrc[frame_start : frame_start + frame_len] = tokens[:frame_len]

        tokens_count += num_tokens
        last_end_pos = frame_start + frame_len

    lrc_emb = lrc.unsqueeze(0).to(dit_model.device)

    normalized_start_time = (
        torch.tensor(normalized_start_time).unsqueeze(0).to(dit_model.device)
    )

    return lrc_emb, normalized_start_time


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model-config", type=str, default=None)
    parser.add_argument("--ckpt-path", type=str, default=None)
    parser.add_argument("--output-dir", type=str, default=None)  # lance
    parser.add_argument("--lrc-path", type=str, default=None)
    parser.add_argument("--mulan-style-path", type=str, default=None)  # lance
    parser.add_argument("--cfg-strength", type=float, default=None)

    args = parser.parse_args()

    lrc_path = args.lrc_path
    cfg_strength = args.cfg_strength
    style_path = args.mulan_style_path

    with open(args.model_config) as f:
        model_config = json.load(f)

    model_cls = DiT
    ckpt_path = args.ckpt_path
    device = "cuda"
    use_style_prompt = True
    dit_model = CFM(
        transformer=model_cls(
            **model_config["model"], use_style_prompt=use_style_prompt
        ),
        num_channels=model_config["model"]["mel_dim"],
        use_style_prompt=use_style_prompt,
    )
    dit_model = dit_model.to(device)
    dit_model = load_checkpoint(dit_model, ckpt_path, device=device, use_ema=True)

    lrc_tokenizer = CNENTokenizer()

    sampling_rate = 44100
    downsample_rate = 2048
    max_frames = 2048
    max_secs = max_frames / (sampling_rate / downsample_rate)

    output_dir = args.output_dir
    writer = LanceWriter(output_dir, target_cls=FloatData)

    reader = LanceReader(style_path, target_cls=FloatData)

    WRITE_INTERVAL = 500

    latent_data = []
    for id in tqdm(reader.get_ids()):
        item = reader.get_datas_by_rowids(row_ids=[id._rowid])[0]
        data_id = item.data_id
        style_prompt = torch.from_numpy(item.data).to(device)
        style_prompt = style_prompt.half()  # was assigned to "stlye_prompt", discarding the cast

        # was reassigning lrc_path itself, so the path grew on every iteration
        lrc_file = os.path.join(lrc_path, f"{data_id}.lrc")
        with open(lrc_file, "r") as f:  # was "with (open(lrc_path), "r") as f:"
            lrc = f.read()  # parse_lyrics expects the raw LRC string, not a list of lines
        lrc_prompt, start_time = get_lrc_prompt(lrc, lrc_tokenizer, dit_model, max_secs)

        latent_prompt = torch.zeros(1, max_frames, 64).to(device)
        sf = 0
        ef = max_frames

        generated_latent = inference(
            model=dit_model,
            cond=latent_prompt,
            text=lrc_prompt,
            duration=max_frames,
            style_prompt=style_prompt,
            output_dir=output_dir,
            start_time=start_time,
            latent_pred_start_frame=sf,
            latent_pred_end_frame=ef,
            cfg_strength=cfg_strength,
        )  # [b d t] numpy

        latent_data.append(generated_latent)

        if len(latent_data) > WRITE_INTERVAL:
            writer.write_parallel(latent_data)
            latent_data = []

    writer.write_parallel(latent_data)
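A quick sketch of the LRC parsing at the top of this script (the sample lyrics are invented; timestamps are "[mm:ss.xx]lyric", and check_lyric drops credit lines and near-zero timestamps). Note that importing the module pulls in the f5_tts/lance dependencies:

from src.YingMusicSinger.utils.f5_tts.g2p.infer_dpo import parse_lyrics

sample = "[00:00.00]作词 : somebody\n[00:12.34]第一句歌词\n[00:18.50]第二句歌词"
print(parse_lyrics(sample))  # [(12.34, '第一句歌词'), (18.5, '第二句歌词')]
# the first line is dropped: its timestamp is < 0.1 s and it matches the "作词" filter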
src/YingMusicSinger/utils/f5_tts/g2p/sources/bpmf_2_pinyin.txt
ADDED
@@ -0,0 +1,41 @@
b	ㄅ
p	ㄆ
m	ㄇ
f	ㄈ
d	ㄉ
t	ㄊ
n	ㄋ
l	ㄌ
g	ㄍ
k	ㄎ
h	ㄏ
j	ㄐ
q	ㄑ
x	ㄒ
zh	ㄓ
ch	ㄔ
sh	ㄕ
r	ㄖ
z	ㄗ
c	ㄘ
s	ㄙ
i	ㄧ
u	ㄨ
v	ㄩ
a	ㄚ
o	ㄛ
e	ㄜ
e	ㄝ
ai	ㄞ
ei	ㄟ
ao	ㄠ
ou	ㄡ
an	ㄢ
en	ㄣ
ang	ㄤ
eng	ㄥ
er	ㄦ
2	ˊ
3	ˇ
4	ˋ
0	˙
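The table is read as tab-separated "pinyin<TAB>bopomofo" rows; a sketch of building the reverse map, mirroring the loader in mandarin.py (the relative path is an assumption about the working directory):

bopomofo_to_pinyin = {}
with open("src/YingMusicSinger/utils/f5_tts/g2p/sources/bpmf_2_pinyin.txt", encoding="utf-8") as f:
    for line in f:
        pinyin, bopomofo = line.rstrip("\n").split("\t")
        bopomofo_to_pinyin[bopomofo] = pinyin  # e.g. "ㄅ" -> "b", "ˊ" -> "2"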
src/YingMusicSinger/utils/f5_tts/g2p/sources/chinese_lexicon.txt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a3a7685d1c3e68eb2fa304bfc63e90c90c3c1a1948839a5b1b507b2131b3e2fb
size 14779443
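This is a Git LFS pointer, not the lexicon itself; a note on fetching it (the include pattern is illustrative):

# Fetch the real file before running the g2p code, e.g.:
#   git lfs pull --include "src/YingMusicSinger/utils/f5_tts/g2p/sources/chinese_lexicon.txt"
# Each row is then "word<TAB>space-separated pinyin", as expected by the loader in mandarin.py.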
src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/config.json
ADDED
|
@@ -0,0 +1,819 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+{
+  "_name_or_path": "/BERT-POLY-v2/pretrained_models/mini_bert",
+  "architectures": ["BertPoly"],
+  "attention_probs_dropout_prob": 0.1,
+  "classifier_dropout": null,
+  "directionality": "bidi",
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 384,
+  "id2label": { "0": "LABEL_0", "1": "LABEL_1", "2": "LABEL_2", ..., "389": "LABEL_389", "390": "LABEL_390" },
+  "initializer_range": 0.02,
+  "intermediate_size": 1536,
+  "label2id": { "LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2, ..., "LABEL_389": 389, "LABEL_390": 390 },
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 6,
+  "num_relation_heads": 32,
+  "pad_token_id": 0,
+  "pooler_fc_size": 768,
+  "pooler_num_attention_heads": 12,
+  "pooler_num_fc_layers": 3,
+  "pooler_size_per_head": 128,
+  "pooler_type": "first_token_transform",
+  "position_embedding_type": "absolute",
+  "torch_dtype": "float32",
+  "transformers_version": "4.44.1",
+  "type_vocab_size": 2,
+  "use_cache": true,
+  "vocab_size": 21128
+}
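Taken together, this config describes a compact 6-layer BERT (hidden size 384, vocab 21128) with a 391-way classification head; the anonymous `LABEL_*` entries presumably index the polyphone pronunciation classes enumerated in `polydict.json` further down in this diff. A minimal sketch for inspecting it, assuming the repo-relative path of the file added above:

```python
# Minimal sketch: read the classifier config shipped with the polyphone model.
import json

CFG = "src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/config.json"

with open(CFG, encoding="utf-8") as f:
    cfg = json.load(f)

print(cfg["model_type"], cfg["num_hidden_layers"], cfg["hidden_size"])  # bert 6 384
print(len(cfg["id2label"]))  # 391 output classes for the polyphone classifier
```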
src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/poly_bert_model.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8765d835ffdf9811c832d4dc7b6a552757aa8615c01d1184db716a50c20aebbc
+size 76583333
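Note that `poly_bert_model.onnx` is checked in as a Git LFS pointer: the three lines above are the entire blob stored in Git, and the real ~76 MB ONNX graph only appears locally after `git lfs pull`. A sketch of loading it, assuming `onnxruntime` is installed and the LFS object has been fetched; it only verifies the graph loads and reports what the session expects as inputs:

```python
# Sketch, assuming `git lfs pull` has materialized the real file.
import onnxruntime as ort

ONNX = "src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/poly_bert_model.onnx"
sess = ort.InferenceSession(ONNX, providers=["CPUExecutionProvider"])
print([(i.name, i.shape) for i in sess.get_inputs()])  # input names/shapes of the BERT-POLY graph
```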
src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polychar.txt
ADDED
@@ -0,0 +1,159 @@
+丧 中 为 乌 乐 了 什 仔 令 任 会 传 佛 供 便 倒 假 兴 冠 冲
+几 分 切 划 创 剥 勒 区 华 单 卜 占 卡 卷 厦 参 发 只 号 同
+吐 和 喝 圈 地 塞 壳 处 奇 奔 好 宁 宿 将 少 尽 岗 差 巷 帖
+干 应 度 弹 强 当 待 得 恶 扁 扇 扎 扫 担 挑 据 撒 教 散 数
+斗 晃 曝 曲 更 曾 朝 朴 杆 查 校 模 横 没 泡 济 混 漂 炸 熟
+燕 片 率 畜 的 盛 相 省 看 着 矫 禁 种 称 空 答 粘 糊 系 累
+纤 结 给 缝 肖 背 脏 舍 色 落 蒙 薄 藏 血 行 要 观 觉 角 解
+说 调 踏 车 转 载 还 遂 都 重 量 钻 铺 长 间 降 难 露 鲜
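These 159 characters are the polyphones the model knows about; a plausible use of this file is as a trigger set, so that only characters listed here are routed through the ONNX classifier while every other character keeps its dictionary pinyin. A small sketch of loading it (the file stores one character per line):

```python
# Sketch: read polychar.txt into a set of polyphonic characters.
POLY = "src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polychar.txt"

with open(POLY, encoding="utf-8") as f:
    poly_chars = {line.strip() for line in f if line.strip()}

assert len(poly_chars) == 159
print("乐" in poly_chars)  # True: 乐 has several readings (le4 / yve4 / lao4 / ...)
```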
src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polydict.json
ADDED
@@ -0,0 +1,393 @@
+{
+  "1": "丧{sang1}", "2": "丧{sang4}", "3": "中{zhong1}", "4": "中{zhong4}", "5": "为{wei2}",
+  "6": "为{wei4}", "7": "乌{wu1}", "8": "乌{wu4}", "9": "乐{lao4}", "10": "乐{le4}",
+  "11": "乐{le5}", "12": "乐{yao4}", "13": "乐{yve4}", "14": "了{le5}", "15": "了{liao3}",
+  "16": "了{liao5}", "17": "什{shen2}", "18": "什{shi2}", "19": "仔{zai3}", "20": "仔{zai5}",
+  "21": "仔{zi3}", "22": "仔{zi5}", "23": "令{ling2}", "24": "令{ling4}", "25": "任{ren2}",
+  "26": "任{ren4}", "27": "会{hui4}", "28": "会{hui5}", "29": "会{kuai4}", "30": "传{chuan2}",
+  "31": "传{zhuan4}", "32": "佛{fo2}", "33": "佛{fu2}", "34": "供{gong1}", "35": "供{gong4}",
+  "36": "便{bian4}", "37": "便{pian2}", "38": "倒{dao3}", "39": "倒{dao4}", "40": "假{jia3}",
+  "41": "假{jia4}", "42": "兴{xing1}", "43": "兴{xing4}", "44": "冠{guan1}", "45": "冠{guan4}",
+  "46": "冲{chong1}", "47": "冲{chong4}", "48": "几{ji1}", "49": "几{ji2}", "50": "几{ji3}",
+  "51": "分{fen1}", "52": "分{fen4}", "53": "分{fen5}", "54": "切{qie1}", "55": "切{qie4}",
+  "56": "划{hua2}", "57": "划{hua4}", "58": "划{hua5}", "59": "创{chuang1}", "60": "创{chuang4}",
+  "61": "剥{bao1}", "62": "剥{bo1}", "63": "勒{le4}", "64": "勒{le5}", "65": "勒{lei1}",
+  "66": "区{ou1}", "67": "区{qu1}", "68": "华{hua2}", "69": "华{hua4}", "70": "单{chan2}",
+  "71": "单{dan1}", "72": "单{shan4}", "73": "卜{bo5}", "74": "卜{bu3}", "75": "占{zhan1}",
+  "76": "占{zhan4}", "77": "卡{ka2}", "78": "卡{ka3}", "79": "卡{qia3}", "80": "卷{jvan3}",
+  "81": "卷{jvan4}", "82": "厦{sha4}", "83": "厦{xia4}", "84": "参{can1}", "85": "参{cen1}",
+  "86": "参{shen1}", "87": "发{fa1}", "88": "发{fa4}", "89": "发{fa5}", "90": "只{zhi1}",
+  "91": "只{zhi3}", "92": "号{hao2}", "93": "号{hao4}", "94": "号{hao5}", "95": "同{tong2}",
+  "96": "同{tong4}", "97": "同{tong5}", "98": "吐{tu2}", "99": "吐{tu3}", "100": "吐{tu4}",
+  "101": "和{he2}", "102": "和{he4}", "103": "和{he5}", "104": "和{huo2}", "105": "和{huo4}",
+  "106": "和{huo5}", "107": "喝{he1}", "108": "喝{he4}", "109": "圈{jvan4}", "110": "圈{qvan1}",
+  "111": "圈{qvan5}", "112": "地{de5}", "113": "地{di4}", "114": "地{di5}", "115": "塞{sai1}",
+  "116": "塞{sai2}", "117": "塞{sai4}", "118": "塞{se4}", "119": "壳{ke2}", "120": "壳{qiao4}",
+  "121": "处{chu3}", "122": "处{chu4}", "123": "奇{ji1}", "124": "奇{qi2}", "125": "奔{ben1}",
+  "126": "奔{ben4}", "127": "好{hao3}", "128": "好{hao4}", "129": "好{hao5}", "130": "宁{ning2}",
+  "131": "宁{ning4}", "132": "宁{ning5}", "133": "宿{su4}", "134": "宿{xiu3}", "135": "宿{xiu4}",
+  "136": "将{jiang1}", "137": "将{jiang4}", "138": "少{shao3}", "139": "少{shao4}", "140": "尽{jin3}",
+  "141": "尽{jin4}", "142": "岗{gang1}", "143": "岗{gang3}", "144": "差{cha1}", "145": "差{cha4}",
+  "146": "差{chai1}", "147": "差{ci1}", "148": "巷{hang4}", "149": "巷{xiang4}", "150": "帖{tie1}",
+  "151": "帖{tie3}", "152": "帖{tie4}", "153": "干{gan1}", "154": "干{gan4}", "155": "应{ying1}",
+  "156": "应{ying4}", "157": "应{ying5}", "158": "度{du4}", "159": "度{du5}", "160": "度{duo2}",
+  "161": "弹{dan4}", "162": "弹{tan2}", "163": "弹{tan5}", "164": "强{jiang4}", "165": "强{qiang2}",
+  "166": "强{qiang3}", "167": "当{dang1}", "168": "当{dang4}", "169": "当{dang5}", "170": "待{dai1}",
+  "171": "待{dai4}", "172": "得{de2}", "173": "得{de5}", "174": "得{dei3}", "175": "得{dei5}",
+  "176": "恶{e3}", "177": "恶{e4}", "178": "恶{wu4}", "179": "扁{bian3}", "180": "扁{pian1}",
+  "181": "扇{shan1}", "182": "扇{shan4}", "183": "扎{za1}", "184": "扎{zha1}", "185": "扎{zha2}",
+  "186": "扫{sao3}", "187": "扫{sao4}", "188": "担{dan1}", "189": "担{dan4}", "190": "担{dan5}",
+  "191": "挑{tiao1}", "192": "挑{tiao3}", "193": "据{jv1}", "194": "据{jv4}", "195": "撒{sa1}",
+  "196": "撒{sa3}", "197": "撒{sa5}", "198": "教{jiao1}", "199": "教{jiao4}", "200": "散{san3}",
+  "201": "散{san4}", "202": "散{san5}", "203": "数{shu3}", "204": "数{shu4}", "205": "数{shu5}",
+  "206": "斗{dou3}", "207": "斗{dou4}", "208": "晃{huang3}", "209": "曝{bao4}", "210": "曲{qu1}",
+  "211": "曲{qu3}", "212": "更{geng1}", "213": "更{geng4}", "214": "曾{ceng1}", "215": "曾{ceng2}",
+  "216": "曾{zeng1}", "217": "朝{chao2}", "218": "朝{zhao1}", "219": "朴{piao2}", "220": "朴{pu2}",
+  "221": "朴{pu3}", "222": "杆{gan1}", "223": "杆{gan3}", "224": "查{cha2}", "225": "查{zha1}",
+  "226": "校{jiao4}", "227": "校{xiao4}", "228": "模{mo2}", "229": "模{mu2}", "230": "横{heng2}",
+  "231": "横{heng4}", "232": "没{mei2}", "233": "没{mo4}", "234": "泡{pao1}", "235": "泡{pao4}",
+  "236": "泡{pao5}", "237": "济{ji3}", "238": "济{ji4}", "239": "混{hun2}", "240": "混{hun3}",
+  "241": "混{hun4}", "242": "混{hun5}", "243": "漂{piao1}", "244": "漂{piao3}", "245": "漂{piao4}",
+  "246": "炸{zha2}", "247": "炸{zha4}", "248": "熟{shou2}", "249": "熟{shu2}", "250": "燕{yan1}",
+  "251": "燕{yan4}", "252": "片{pian1}", "253": "片{pian4}", "254": "率{lv4}", "255": "率{shuai4}",
+  "256": "畜{chu4}", "257": "畜{xu4}", "258": "的{de5}", "259": "的{di1}", "260": "的{di2}",
+  "261": "的{di4}", "262": "的{di5}", "263": "盛{cheng2}", "264": "盛{sheng4}", "265": "相{xiang1}",
+  "266": "相{xiang4}", "267": "相{xiang5}", "268": "省{sheng3}", "269": "省{xing3}", "270": "看{kan1}",
+  "271": "看{kan4}", "272": "看{kan5}", "273": "着{zhao1}", "274": "着{zhao2}", "275": "着{zhao5}",
+  "276": "着{zhe5}", "277": "着{zhuo2}", "278": "着{zhuo5}", "279": "矫{jiao3}", "280": "禁{jin1}",
+  "281": "禁{jin4}", "282": "种{zhong3}", "283": "种{zhong4}", "284": "称{chen4}", "285": "称{cheng1}",
+  "286": "空{kong1}", "287": "空{kong4}", "288": "答{da1}", "289": "答{da2}", "290": "粘{nian2}",
+  "291": "粘{zhan1}", "292": "糊{hu2}", "293": "糊{hu5}", "294": "系{ji4}", "295": "系{xi4}",
+  "296": "系{xi5}", "297": "累{lei2}", "298": "累{lei3}", "299": "累{lei4}", "300": "累{lei5}",
+  "301": "纤{qian4}", "302": "纤{xian1}", "303": "结{jie1}", "304": "结{jie2}", "305": "结{jie5}",
+  "306": "给{gei3}", "307": "给{gei5}", "308": "给{ji3}", "309": "缝{feng2}", "310": "缝{feng4}",
+  "311": "缝{feng5}", "312": "肖{xiao1}", "313": "肖{xiao4}", "314": "背{bei1}", "315": "背{bei4}",
+  "316": "脏{zang1}", "317": "脏{zang4}", "318": "舍{she3}", "319": "舍{she4}", "320": "色{se4}",
+  "321": "色{shai3}", "322": "落{lao4}", "323": "落{luo4}", "324": "蒙{meng1}", "325": "蒙{meng2}",
+  "326": "蒙{meng3}", "327": "薄{bao2}", "328": "薄{bo2}", "329": "薄{bo4}", "330": "藏{cang2}",
+  "331": "藏{zang4}", "332": "血{xie3}", "333": "血{xue4}", "334": "行{hang2}", "335": "行{hang5}",
+  "336": "行{heng5}", "337": "行{xing2}", "338": "行{xing4}", "339": "要{yao1}", "340": "要{yao4}",
+  "341": "观{guan1}", "342": "观{guan4}", "343": "觉{jiao4}", "344": "觉{jiao5}", "345": "觉{jve2}",
+  "346": "角{jiao3}", "347": "角{jve2}", "348": "解{jie3}", "349": "解{jie4}", "350": "解{xie4}",
+  "351": "说{shui4}", "352": "说{shuo1}", "353": "调{diao4}", "354": "调{tiao2}", "355": "踏{ta1}",
+  "356": "踏{ta4}", "357": "车{che1}", "358": "车{jv1}", "359": "转{zhuan3}", "360": "转{zhuan4}",
+  "361": "载{zai3}", "362": "载{zai4}", "363": "还{hai2}", "364": "还{huan2}", "365": "遂{sui2}",
+  "366": "遂{sui4}", "367": "都{dou1}", "368": "都{du1}", "369": "重{chong2}", "370": "重{zhong4}",
+  "371": "量{liang2}", "372": "量{liang4}", "373": "量{liang5}", "374": "钻{zuan1}", "375": "钻{zuan4}",
+  "376": "铺{pu1}", "377": "铺{pu4}", "378": "长{chang2}", "379": "长{chang3}", "380": "长{zhang3}",
+  "381": "间{jian1}", "382": "间{jian4}", "383": "降{jiang4}", "384": "降{xiang2}", "385": "难{nan2}",
+  "386": "难{nan4}", "387": "难{nan5}", "388": "露{lou4}", "389": "露{lu4}", "390": "鲜{xian1}", "391": "鲜{xian3}"
+}
src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polydict_r.json
ADDED
@@ -0,0 +1,393 @@
+{
+  "丧{sang1}": 1, "丧{sang4}": 2, "中{zhong1}": 3, "中{zhong4}": 4, "为{wei2}": 5,
+  ...
+  "难{nan5}": 387, "露{lou4}": 388, "露{lu4}": 389, "鲜{xian1}": 390, "鲜{xian3}": 391
+}
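`polydict.json` and `polydict_r.json` are exact inverses of each other: the first decodes a class index into a `char{pinyin}` string, the second encodes that string back to its index (presumably used when building training targets). A sketch of the round trip; `decode` is a hypothetical helper, not repo code:

```python
# Sketch: round-trip between class indices and char{pinyin} strings
# using the two dictionaries added above.
import json
import re

BASE = "src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/"
with open(BASE + "polydict.json", encoding="utf-8") as f:
    idx2pron = json.load(f)
with open(BASE + "polydict_r.json", encoding="utf-8") as f:
    pron2idx = json.load(f)

# The two files mirror each other entry for entry.
assert all(pron2idx[v] == int(k) for k, v in idx2pron.items())

def decode(index: int):
    """Split e.g. '乐{yve4}' into ('乐', 'yve4')."""
    char, pinyin = re.match(r"(.+)\{(.+)\}", idx2pron[str(index)]).groups()
    return char, pinyin

print(decode(13))  # ('乐', 'yve4')
```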
src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/vocab.txt
ADDED
The diff for this file is too large to render. See raw diff.
src/YingMusicSinger/utils/f5_tts/g2p/sources/pinyin_2_bpmf.txt
ADDED
@@ -0,0 +1,429 @@
a ㄚ
ai ㄞ
an ㄢ
ang ㄤ
ao ㄠ
ba ㄅㄚ
bai ㄅㄞ
ban ㄅㄢ
bang ㄅㄤ
bao ㄅㄠ
bei ㄅㄟ
ben ㄅㄣ
beng ㄅㄥ
bi ㄅㄧ
bian ㄅㄧㄢ
biang ㄅㄧㄤ
biao ㄅㄧㄠ
bie ㄅㄧㄝ
bin ㄅㄧㄣ
bing ㄅㄧㄥ
bo ㄅㄛ
bu ㄅㄨ
ca ㄘㄚ
cai ㄘㄞ
can ㄘㄢ
cang ㄘㄤ
cao ㄘㄠ
ce ㄘㄜ
cen ㄘㄣ
ceng ㄘㄥ
cha ㄔㄚ
chai ㄔㄞ
chan ㄔㄢ
chang ㄔㄤ
chao ㄔㄠ
che ㄔㄜ
chen ㄔㄣ
cheng ㄔㄥ
chi ㄔ
chong ㄔㄨㄥ
chou ㄔㄡ
chu ㄔㄨ
chua ㄔㄨㄚ
chuai ㄔㄨㄞ
chuan ㄔㄨㄢ
chuang ㄔㄨㄤ
chui ㄔㄨㄟ
chun ㄔㄨㄣ
chuo ㄔㄨㄛ
ci ㄘ
cong ㄘㄨㄥ
cou ㄘㄡ
cu ㄘㄨ
cuan ㄘㄨㄢ
cui ㄘㄨㄟ
cun ㄘㄨㄣ
cuo ㄘㄨㄛ
da ㄉㄚ
dai ㄉㄞ
dan ㄉㄢ
dang ㄉㄤ
dao ㄉㄠ
de ㄉㄜ
dei ㄉㄟ
den ㄉㄣ
deng ㄉㄥ
di ㄉㄧ
dia ㄉㄧㄚ
dian ㄉㄧㄢ
diao ㄉㄧㄠ
die ㄉㄧㄝ
din ㄉㄧㄣ
ding ㄉㄧㄥ
diu ㄉㄧㄡ
dong ㄉㄨㄥ
dou ㄉㄡ
du ㄉㄨ
duan ㄉㄨㄢ
dui ㄉㄨㄟ
dun ㄉㄨㄣ
duo ㄉㄨㄛ
e ㄜ
ei ㄟ
en ㄣ
eng ㄥ
er ㄦ
fa ㄈㄚ
fan ㄈㄢ
fang ㄈㄤ
fei ㄈㄟ
fen ㄈㄣ
feng ㄈㄥ
fo ㄈㄛ
fou ㄈㄡ
fu ㄈㄨ
ga ㄍㄚ
gai ㄍㄞ
gan ㄍㄢ
gang ㄍㄤ
gao ㄍㄠ
ge ㄍㄜ
gei ㄍㄟ
gen ㄍㄣ
geng ㄍㄥ
gong ㄍㄨㄥ
gou ㄍㄡ
gu ㄍㄨ
gua ㄍㄨㄚ
guai ㄍㄨㄞ
guan ㄍㄨㄢ
guang ㄍㄨㄤ
gui ㄍㄨㄟ
gun ㄍㄨㄣ
guo ㄍㄨㄛ
ha ㄏㄚ
hai ㄏㄞ
han ㄏㄢ
hang ㄏㄤ
hao ㄏㄠ
he ㄏㄜ
hei ㄏㄟ
hen ㄏㄣ
heng ㄏㄥ
hm ㄏㄇ
hong ㄏㄨㄥ
hou ㄏㄡ
hu ㄏㄨ
hua ㄏㄨㄚ
huai ㄏㄨㄞ
huan ㄏㄨㄢ
huang ㄏㄨㄤ
hui ㄏㄨㄟ
hun ㄏㄨㄣ
huo ㄏㄨㄛ
ji ㄐㄧ
jia ㄐㄧㄚ
jian ㄐㄧㄢ
jiang ㄐㄧㄤ
jiao ㄐㄧㄠ
jie ㄐㄧㄝ
jin ㄐㄧㄣ
jing ㄐㄧㄥ
jiong ㄐㄩㄥ
jiu ㄐㄧㄡ
ju ㄐㄩ
jv ㄐㄩ
juan ㄐㄩㄢ
jvan ㄐㄩㄢ
jue ㄐㄩㄝ
jve ㄐㄩㄝ
jun ㄐㄩㄣ
ka ㄎㄚ
kai ㄎㄞ
kan ㄎㄢ
kang ㄎㄤ
kao ㄎㄠ
ke ㄎㄜ
kei ㄎㄟ
ken ㄎㄣ
keng ㄎㄥ
kong ㄎㄨㄥ
kou ㄎㄡ
ku ㄎㄨ
kua ㄎㄨㄚ
kuai ㄎㄨㄞ
kuan ㄎㄨㄢ
kuang ㄎㄨㄤ
kui ㄎㄨㄟ
kun ㄎㄨㄣ
kuo ㄎㄨㄛ
la ㄌㄚ
lai ㄌㄞ
lan ㄌㄢ
lang ㄌㄤ
lao ㄌㄠ
le ㄌㄜ
lei ㄌㄟ
leng ㄌㄥ
li ㄌㄧ
lia ㄌㄧㄚ
lian ㄌㄧㄢ
liang ㄌㄧㄤ
liao ㄌㄧㄠ
lie ㄌㄧㄝ
lin ㄌㄧㄣ
ling ㄌㄧㄥ
liu ㄌㄧㄡ
lo ㄌㄛ
long ㄌㄨㄥ
lou ㄌㄡ
lu ㄌㄨ
luan ㄌㄨㄢ
lue ㄌㄩㄝ
lun ㄌㄨㄣ
luo ㄌㄨㄛ
lv ㄌㄩ
lve ㄌㄩㄝ
m ㄇㄨ
ma ㄇㄚ
mai ㄇㄞ
man ㄇㄢ
mang ㄇㄤ
mao ㄇㄠ
me ㄇㄜ
mei ㄇㄟ
men ㄇㄣ
meng ㄇㄥ
mi ㄇㄧ
mian ㄇㄧㄢ
miao ㄇㄧㄠ
mie ㄇㄧㄝ
min ㄇㄧㄣ
ming ㄇㄧㄥ
miu ㄇㄧㄡ
mo ㄇㄛ
mou ㄇㄡ
mu ㄇㄨ
n ㄣ
na ㄋㄚ
nai ㄋㄞ
nan ㄋㄢ
nang ㄋㄤ
nao ㄋㄠ
ne ㄋㄜ
nei ㄋㄟ
nen ㄋㄣ
neng ㄋㄥ
ng ㄣ
ni ㄋㄧ
nian ㄋㄧㄢ
niang ㄋㄧㄤ
niao ㄋㄧㄠ
nie ㄋㄧㄝ
nin ㄋㄧㄣ
ning ㄋㄧㄥ
niu ㄋㄧㄡ
nong ㄋㄨㄥ
nou ㄋㄡ
nu ㄋㄨ
nuan ㄋㄨㄢ
nue ㄋㄩㄝ
nun ㄋㄨㄣ
nuo ㄋㄨㄛ
nv ㄋㄩ
nve ㄋㄩㄝ
o ㄛ
ou ㄡ
pa ㄆㄚ
pai ㄆㄞ
pan ㄆㄢ
pang ㄆㄤ
pao ㄆㄠ
pei ㄆㄟ
pen ㄆㄣ
peng ㄆㄥ
pi ㄆㄧ
pian ㄆㄧㄢ
piao ㄆㄧㄠ
pie ㄆㄧㄝ
pin ㄆㄧㄣ
ping ㄆㄧㄥ
po ㄆㄛ
pou ㄆㄡ
pu ㄆㄨ
qi ㄑㄧ
qia ㄑㄧㄚ
qian ㄑㄧㄢ
qiang ㄑㄧㄤ
qiao ㄑㄧㄠ
qie ㄑㄧㄝ
qin ㄑㄧㄣ
qing ㄑㄧㄥ
qiong ㄑㄩㄥ
qiu ㄑㄧㄡ
qu ㄑㄩ
quan ㄑㄩㄢ
qvan ㄑㄩㄢ
que ㄑㄩㄝ
qun ㄑㄩㄣ
ran ㄖㄢ
rang ㄖㄤ
rao ㄖㄠ
re ㄖㄜ
ren ㄖㄣ
reng ㄖㄥ
ri ㄖ
rong ㄖㄨㄥ
rou ㄖㄡ
ru ㄖㄨ
rua ㄖㄨㄚ
ruan ㄖㄨㄢ
rui ㄖㄨㄟ
run ㄖㄨㄣ
ruo ㄖㄨㄛ
sa ㄙㄚ
sai ㄙㄞ
san ㄙㄢ
sang ㄙㄤ
sao ㄙㄠ
se ㄙㄜ
sen ㄙㄣ
seng ㄙㄥ
sha ㄕㄚ
shai ㄕㄞ
shan ㄕㄢ
shang ㄕㄤ
shao ㄕㄠ
she ㄕㄜ
shei ㄕㄟ
shen ㄕㄣ
sheng ㄕㄥ
shi ㄕ
shou ㄕㄡ
shu ㄕㄨ
shua ㄕㄨㄚ
shuai ㄕㄨㄞ
shuan ㄕㄨㄢ
shuang ㄕㄨㄤ
shui ㄕㄨㄟ
shun ㄕㄨㄣ
shuo ㄕㄨㄛ
si ㄙ
song ㄙㄨㄥ
sou ㄙㄡ
su ㄙㄨ
suan ㄙㄨㄢ
sui ㄙㄨㄟ
sun ㄙㄨㄣ
suo ㄙㄨㄛ
ta ㄊㄚ
tai ㄊㄞ
tan ㄊㄢ
tang ㄊㄤ
tao ㄊㄠ
te ㄊㄜ
tei ㄊㄟ
teng ㄊㄥ
ti ㄊㄧ
tian ㄊㄧㄢ
tiao ㄊㄧㄠ
tie ㄊㄧㄝ
ting ㄊㄧㄥ
tong ㄊㄨㄥ
tou ㄊㄡ
tsuo ㄘㄨㄛ
tu ㄊㄨ
tuan ㄊㄨㄢ
tui ㄊㄨㄟ
tun ㄊㄨㄣ
tuo ㄊㄨㄛ
tzan ㄗㄢ
wa ㄨㄚ
wai ㄨㄞ
wan ㄨㄢ
wang ㄨㄤ
wei ㄨㄟ
wen ㄨㄣ
weng ㄨㄥ
wo ㄨㄛ
wong ㄨㄥ
wu ㄨ
xi ㄒㄧ
xia ㄒㄧㄚ
xian ㄒㄧㄢ
xiang ㄒㄧㄤ
xiao ㄒㄧㄠ
xie ㄒㄧㄝ
xin ㄒㄧㄣ
xing ㄒㄧㄥ
xiong ㄒㄩㄥ
xiu ㄒㄧㄡ
xu ㄒㄩ
xuan ㄒㄩㄢ
xue ㄒㄩㄝ
xun ㄒㄩㄣ
ya ㄧㄚ
yai ㄧㄞ
yan ㄧㄢ
yang ㄧㄤ
yao ㄧㄠ
ye ㄧㄝ
yi ㄧ
yin ㄧㄣ
ying ㄧㄥ
yo ㄧㄛ
yong ㄩㄥ
you ㄧㄡ
yu ㄩ
yuan ㄩㄢ
yue ㄩㄝ
yve ㄩㄝ
yun ㄩㄣ
za ㄗㄚ
zai ㄗㄞ
zan ㄗㄢ
zang ㄗㄤ
zao ㄗㄠ
ze ㄗㄜ
zei ㄗㄟ
zen ㄗㄣ
zeng ㄗㄥ
zha ㄓㄚ
zhai ㄓㄞ
zhan ㄓㄢ
zhang ㄓㄤ
zhao ㄓㄠ
zhe ㄓㄜ
zhei ㄓㄟ
zhen ㄓㄣ
zheng ㄓㄥ
zhi ㄓ
zhong ㄓㄨㄥ
zhou ㄓㄡ
zhu ㄓㄨ
zhua ㄓㄨㄚ
zhuai ㄓㄨㄞ
zhuan ㄓㄨㄢ
zhuang ㄓㄨㄤ
zhui ㄓㄨㄟ
zhun ㄓㄨㄣ
zhuo ㄓㄨㄛ
zi ㄗ
zong ㄗㄨㄥ
zou ㄗㄡ
zu ㄗㄨ
zuan ㄗㄨㄢ
zui ㄗㄨㄟ
zun ㄗㄨㄣ
zuo ㄗㄨㄛ
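This table maps toneless pinyin syllables, including the nonstandard spellings the G2P emits (jv, jvan, jve, qvan, yve, tsuo, tzan), to bopomofo. A minimal loading sketch, not repo code: the two-column whitespace-separated format is read off the listing above, and the path is an assumption.

# Build a pinyin -> bopomofo lookup from the two-column table.
pinyin2bpmf = {}
with open(
    "src/YingMusicSinger/utils/f5_tts/g2p/sources/pinyin_2_bpmf.txt",
    encoding="utf-8",
) as f:
    for line in f:
        parts = line.split()
        if len(parts) == 2:
            pinyin2bpmf[parts[0]] = parts[1]

print(pinyin2bpmf["zhuang"])  # -> ㄓㄨㄤ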
src/YingMusicSinger/utils/f5_tts/g2p/utils/front_utils.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


def generate_poly_lexicon(file_path: str):
    """Generate poly char lexicon for Mandarin Chinese."""
    poly_dict = {}

    with open(file_path, "r", encoding="utf-8") as readf:
        txt_list = readf.readlines()
        for txt in txt_list:
            word = txt.strip("\n")
            if word not in poly_dict:
                poly_dict[word] = 1
        readf.close()  # redundant: the with-block already closes the file
    return poly_dict
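A hedged usage sketch, not repo code: the polychar.txt path is assumed from this repo's file layout, and the returned dict effectively acts as a set of polyphonic characters.

poly_chars = generate_poly_lexicon(
    "src/YingMusicSinger/utils/f5_tts/g2p/sources/g2p_chinese_model/polychar.txt"
)
# Membership test: is this character polyphonic?
print("调" in poly_chars)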
src/YingMusicSinger/utils/f5_tts/g2p/utils/g2p.py
ADDED
@@ -0,0 +1,139 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from typing import List, Union

from phonemizer.backend import EspeakBackend
from phonemizer.separator import Separator
from phonemizer.utils import list2str, str2list

# separator=Separator(phone=' ', word=' _ ', syllable='|'),
separator = Separator(word=" _ ", syllable="|", phone=" ")

phonemizer_zh = EspeakBackend(
    "cmn", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
)
# phonemizer_zh.separator = separator

phonemizer_en = EspeakBackend(
    "en-us",
    preserve_punctuation=False,
    with_stress=False,
    language_switch="remove-flags",
)
# phonemizer_en.separator = separator

phonemizer_ja = EspeakBackend(
    "ja", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
)
# phonemizer_ja.separator = separator

phonemizer_ko = EspeakBackend(
    "ko", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
)
# phonemizer_ko.separator = separator

phonemizer_fr = EspeakBackend(
    "fr-fr",
    preserve_punctuation=False,
    with_stress=False,
    language_switch="remove-flags",
)
# phonemizer_fr.separator = separator

phonemizer_de = EspeakBackend(
    "de", preserve_punctuation=False, with_stress=False, language_switch="remove-flags"
)
# phonemizer_de.separator = separator


lang2backend = {
    "zh": phonemizer_zh,
    "ja": phonemizer_ja,
    "en": phonemizer_en,
    "fr": phonemizer_fr,
    "ko": phonemizer_ko,
    "de": phonemizer_de,
}

with open("./src/YingMusicSinger/utils/f5_tts/g2p/utils/mls_en.json", "r") as f:
    json_data = f.read()
token = json.loads(json_data)


def phonemizer_g2p(text, language):
    langbackend = lang2backend[language]
    phonemes = _phonemize(
        langbackend,
        text,
        separator,
        strip=True,
        njobs=1,
        prepend_text=False,
        preserve_empty_lines=False,
    )
    token_id = []
    if isinstance(phonemes, list):
        for phone in phonemes:
            phonemes_split = phone.split(" ")
            token_id.append([token[p] for p in phonemes_split if p in token])
    else:
        phonemes_split = phonemes.split(" ")
        token_id = [token[p] for p in phonemes_split if p in token]
    return phonemes, token_id


def _phonemize(  # pylint: disable=too-many-arguments
    backend,
    text: Union[str, List[str]],
    separator: Separator,
    strip: bool,
    njobs: int,
    prepend_text: bool,
    preserve_empty_lines: bool,
):
    """Auxiliary function to phonemize()

    Does the phonemization and returns the phonemized text. Raises a
    RuntimeError on error.

    """
    # remember the text type for output (either list or string)
    text_type = type(text)

    # force the text as a list
    text = [line.strip(os.linesep) for line in str2list(text)]

    # if preserving empty lines, note the index of each empty line
    if preserve_empty_lines:
        empty_lines = [n for n, line in enumerate(text) if not line.strip()]

    # ignore empty lines
    text = [line for line in text if line.strip()]

    if text:
        # phonemize the text
        phonemized = backend.phonemize(
            text, separator=separator, strip=strip, njobs=njobs
        )
    else:
        phonemized = []

    # if preserving empty lines, reinsert them into text and phonemized lists
    if preserve_empty_lines:
        for i in empty_lines:  # noqa
            if prepend_text:
                text.insert(i, "")
            phonemized.insert(i, "")

    # at that point, the phonemized text is a list of str. Format it as
    # expected by the parameters
    if prepend_text:
        return list(zip(text, phonemized))
    if text_type == str:
        return list2str(phonemized)
    return phonemized
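A hedged usage sketch of the module above. Assumptions: the espeak-ng backend that phonemizer wraps is installed on the system, and the script runs from the repo root so the hard-coded mls_en.json path resolves.

phones, ids = phonemizer_g2p("Hello world", "en")
print(phones)  # space-separated IPA phones, words joined by the " _ " separator
print(ids)     # the same phones mapped to integer ids via mls_en.json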
src/YingMusicSinger/utils/f5_tts/g2p/utils/log.py
ADDED
@@ -0,0 +1,52 @@
# Copyright (c) 2024 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import functools
import logging

__all__ = [
    "logger",
]


class Logger(object):
    def __init__(self, name: str = None):
        name = "PaddleSpeech" if not name else name
        self.logger = logging.getLogger(name)

        log_config = {
            "DEBUG": 10,
            "INFO": 20,
            "TRAIN": 21,
            "EVAL": 22,
            "WARNING": 30,
            "ERROR": 40,
            "CRITICAL": 50,
            "EXCEPTION": 100,
        }
        for key, level in log_config.items():
            logging.addLevelName(level, key)
            if key == "EXCEPTION":
                self.__dict__[key.lower()] = self.logger.exception
            else:
                self.__dict__[key.lower()] = functools.partial(self.__call__, level)

        self.format = logging.Formatter(
            fmt="[%(asctime)-15s] [%(levelname)8s] - %(message)s"
        )

        self.handler = logging.StreamHandler()
        self.handler.setFormatter(self.format)

        self.logger.addHandler(self.handler)
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False

    def __call__(self, log_level: str, msg: str):
        self.logger.log(log_level, msg)


logger = Logger()
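Sketch of the resulting API, which follows directly from the code above: every key in log_config becomes a lower-case method on the singleton via functools.partial, so the custom TRAIN and EVAL levels are callable alongside the standard ones.

logger.info("loading checkpoint")
logger.train("epoch 1 started")  # custom level 21, registered via addLevelName
logger.error("missing vocab file")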
src/YingMusicSinger/utils/f5_tts/g2p/utils/mls_en.json
ADDED
@@ -0,0 +1,335 @@
{
"[UNK]": 0,
"_": 1,
"b": 2,
"d": 3,
"f": 4,
"h": 5,
"i": 6,
"j": 7,
"k": 8,
"l": 9,
"m": 10,
"n": 11,
"p": 12,
"r": 13,
"s": 14,
"t": 15,
"v": 16,
"w": 17,
"x": 18,
"z": 19,
"æ": 20,
"ç": 21,
"ð": 22,
"ŋ": 23,
"ɐ": 24,
"ɔ": 25,
"ə": 26,
"ɚ": 27,
"ɛ": 28,
"ɡ": 29,
"ɪ": 30,
"ɬ": 31,
"ɹ": 32,
"ɾ": 33,
"ʃ": 34,
"ʊ": 35,
"ʌ": 36,
"ʒ": 37,
"ʔ": 38,
"θ": 39,
"ᵻ": 40,
"aɪ": 41,
"aʊ": 42,
"dʒ": 43,
"eɪ": 44,
"iə": 45,
"iː": 46,
"n̩": 47,
"oʊ": 48,
"oː": 49,
"tʃ": 50,
"uː": 51,
"ææ": 52,
"ɐɐ": 53,
"ɑː": 54,
"ɑ̃": 55,
"ɔɪ": 56,
"ɔː": 57,
"ɔ̃": 58,
"əl": 59,
"ɛɹ": 60,
"ɜː": 61,
"ɡʲ": 62,
"ɪɹ": 63,
"ʊɹ": 64,
"aɪə": 65,
"aɪɚ": 66,
"iːː": 67,
"oːɹ": 68,
"ɑːɹ": 69,
"ɔːɹ": 70,

"1": 71,
"a": 72,
"e": 73,
"o": 74,
"q": 75,
"u": 76,
"y": 77,
"ɑ": 78,
"ɒ": 79,
"ɕ": 80,
"ɣ": 81,
"ɫ": 82,
"ɯ": 83,
"ʐ": 84,
"ʲ": 85,
"a1": 86,
"a2": 87,
"a5": 88,
"ai": 89,
"aɜ": 90,
"aː": 91,
"ei": 92,
"eə": 93,
"i.": 94,
"i1": 95,
"i2": 96,
"i5": 97,
"io": 98,
"iɑ": 99,
"iɛ": 100,
"iɜ": 101,
"i̪": 102,
"kh": 103,
"nʲ": 104,
"o1": 105,
"o2": 106,
"o5": 107,
"ou": 108,
"oɜ": 109,
"ph": 110,
"s.": 111,
"th": 112,
"ts": 113,
"tɕ": 114,
"u1": 115,
"u2": 116,
"u5": 117,
"ua": 118,
"uo": 119,
"uə": 120,
"uɜ": 121,
"y1": 122,
"y2": 123,
"y5": 124,
"yu": 125,
"yæ": 126,
"yə": 127,
"yɛ": 128,
"yɜ": 129,
"ŋɜ": 130,
"ŋʲ": 131,
"ɑ1": 132,
"ɑ2": 133,
"ɑ5": 134,
"ɑu": 135,
"ɑɜ": 136,
"ɑʲ": 137,
"ə1": 138,
"ə2": 139,
"ə5": 140,
"ər": 141,
"əɜ": 142,
"əʊ": 143,
"ʊə": 144,
"ai1": 145,
"ai2": 146,
"ai5": 147,
"aiɜ": 148,
"ei1": 149,
"ei2": 150,
"ei5": 151,
"eiɜ": 152,
"i.1": 153,
"i.2": 154,
"i.5": 155,
"i.ɜ": 156,
"io5": 157,
"iou": 158,
"iɑ1": 159,
"iɑ2": 160,
"iɑ5": 161,
"iɑɜ": 162,
"iɛ1": 163,
"iɛ2": 164,
"iɛ5": 165,
"iɛɜ": 166,
"i̪1": 167,
"i̪2": 168,
"i̪5": 169,
"i̪ɜ": 170,
"onɡ": 171,
"ou1": 172,
"ou2": 173,
"ou5": 174,
"ouɜ": 175,
"ts.": 176,
"tsh": 177,
"tɕh": 178,
"u5ʲ": 179,
"ua1": 180,
"ua2": 181,
"ua5": 182,
"uai": 183,
"uaɜ": 184,
"uei": 185,
"uo1": 186,
"uo2": 187,
"uo5": 188,
"uoɜ": 189,
"uə1": 190,
"uə2": 191,
"uə5": 192,
"uəɜ": 193,
"yiɜ": 194,
"yu2": 195,
"yu5": 196,
"yæ2": 197,
"yæ5": 198,
"yæɜ": 199,
"yə2": 200,
"yə5": 201,
"yəɜ": 202,
"yɛ1": 203,
"yɛ2": 204,
"yɛ5": 205,
"yɛɜ": 206,
"ɑu1": 207,
"ɑu2": 208,
"ɑu5": 209,
"ɑuɜ": 210,
"ər1": 211,
"ər2": 212,
"ər5": 213,
"ərɜ": 214,
"əː1": 215,
"iou1": 216,
"iou2": 217,
"iou5": 218,
"iouɜ": 219,
"onɡ1": 220,
"onɡ2": 221,
"onɡ5": 222,
"onɡɜ": 223,
"ts.h": 224,
"uai2": 225,
"uai5": 226,
"uaiɜ": 227,
"uei1": 228,
"uei2": 229,
"uei5": 230,
"ueiɜ": 231,
"uoɜʲ": 232,
"yɛ5ʲ": 233,
"ɑu2ʲ": 234,

"2": 235,
"5": 236,
"ɜ": 237,
"ʂ": 238,
"dʑ": 239,
"iɪ": 240,
"uɪ": 241,
"xʲ": 242,
"ɑt": 243,
"ɛɜ": 244,
"ɛː": 245,
"ɪː": 246,
"phʲ": 247,
"ɑ5ʲ": 248,
"ɑuʲ": 249,
"ərə": 250,
"uozʰ": 251,
"ər1ʲ": 252,
"tɕhtɕh": 253,

"c": 254,
"ʋ": 255,
"ʍ": 256,
"ʑ": 257,
"ː": 258,
"aə": 259,
"eː": 260,
"hʲ": 261,
"iʊ": 262,
"kʲ": 263,
"lʲ": 264,
"oə": 265,
"oɪ": 266,
"oʲ": 267,
"pʲ": 268,
"sʲ": 269,
"u4": 270,
"uʲ": 271,
"yi": 272,
"yʲ": 273,
"ŋ2": 274,
"ŋ5": 275,
"ŋ̩": 276,
"ɑɪ": 277,
"ɑʊ": 278,
"ɕʲ": 279,
"ət": 280,
"əə": 281,
"əɪ": 282,
"əʲ": 283,
"ɛ1": 284,
"ɛ5": 285,
"aiə": 286,
"aiɪ": 287,
"azʰ": 288,
"eiə": 289,
"eiɪ": 290,
"eiʊ": 291,
"i.ə": 292,
"i.ɪ": 293,
"i.ʊ": 294,
"ioɜ": 295,
"izʰ": 296,
"iɑə": 297,
"iɑʊ": 298,
"iɑʲ": 299,
"iɛə": 300,
"iɛɪ": 301,
"iɛʊ": 302,
"i̪ə": 303,
"i̪ʊ": 304,
"khʲ": 305,
"ouʲ": 306,
"tsʲ": 307,
"u2ʲ": 308,
"uoɪ": 309,
"uzʰ": 310,
"uɜʲ": 311,
"yæɪ": 312,
"yəʊ": 313,
"ərt": 314,
"ərɪ": 315,
"ərʲ": 316,
"əːt": 317,
"iouə": 318,
"iouʊ": 319,
"iouʲ": 320,
"iɛzʰ": 321,
"onɡə": 322,
"onɡɪ": 323,
"onɡʊ": 324,
"ouzʰ": 325,
"uai1": 326,
"ueiɪ": 327,
"ɑuzʰ": 328,
"iouzʰ": 329
}
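A minimal decoding sketch, not repo code and with the path assumed: since this file is the phone-to-id vocabulary consumed by phonemizer_g2p above, the inverse map turns its token_id output back into phone strings.

import json

with open("src/YingMusicSinger/utils/f5_tts/g2p/utils/mls_en.json", encoding="utf-8") as f:
    phone2id = json.load(f)

id2phone = {idx: phone for phone, idx in phone2id.items()}
print(id2phone[50])  # -> "tʃ"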
src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/LangSegment.py
ADDED
@@ -0,0 +1,1251 @@
| 1 |
+
"""
|
| 2 |
+
This file bundles language identification functions.
|
| 3 |
+
|
| 4 |
+
Modifications (fork): Copyright (c) 2021, Adrien Barbaresi.
|
| 5 |
+
|
| 6 |
+
Original code: Copyright (c) 2011 Marco Lui <saffsd@gmail.com>.
|
| 7 |
+
Based on research by Marco Lui and Tim Baldwin.
|
| 8 |
+
|
| 9 |
+
See LICENSE file for more info.
|
| 10 |
+
https://github.com/adbar/py3langid
|
| 11 |
+
|
| 12 |
+
Projects:
|
| 13 |
+
https://github.com/juntaosun/LangSegment
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import re
|
| 17 |
+
from collections import Counter, defaultdict
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
# import langid
|
| 22 |
+
# import py3langid as langid
|
| 23 |
+
# pip install py3langid==0.2.2
|
| 24 |
+
# 启用语言预测概率归一化,概率预测的分数。因此,实现重新规范化 产生 0-1 范围内的输出。
|
| 25 |
+
# langid disables probability normalization by default. For command-line usages of , it can be enabled by passing the flag.
|
| 26 |
+
# For probability normalization in library use, the user must instantiate their own . An example of such usage is as follows:
|
| 27 |
+
from py3langid.langid import MODEL_FILE, LanguageIdentifier
|
| 28 |
+
|
| 29 |
+
langid = LanguageIdentifier.from_pickled_model(MODEL_FILE, norm_probs=True)
|
| 30 |
+
|
| 31 |
+
# Digital processing
|
| 32 |
+
try:
|
| 33 |
+
from src.YingMusicSinger.utils.f5_tts.thirdparty.LangSegment.utils.num import (
|
| 34 |
+
num2str,
|
| 35 |
+
)
|
| 36 |
+
except ImportError:
|
| 37 |
+
try:
|
| 38 |
+
from thirdparty.LangSegment.utils.num import num2str
|
| 39 |
+
except ImportError as e:
|
| 40 |
+
raise e
|
| 41 |
+
|
| 42 |
+
# -----------------------------------
|
| 43 |
+
# 更新日志:新版本分词更加精准。
|
| 44 |
+
# Changelog: The new version of the word segmentation is more accurate.
|
| 45 |
+
# チェンジログ:新しいバージョンの単語セグメンテーションはより正確です。
|
| 46 |
+
# Changelog: 분할이라는 단어의 새로운 버전이 더 정확합니다.
|
| 47 |
+
# -----------------------------------
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# Word segmentation function:
|
| 51 |
+
# automatically identify and split the words (Chinese/English/Japanese/Korean) in the article or sentence according to different languages,
|
| 52 |
+
# making it more suitable for TTS processing.
|
| 53 |
+
# This code is designed for front-end text multi-lingual mixed annotation distinction, multi-language mixed training and inference of various TTS projects.
|
| 54 |
+
# This processing result is mainly for (Chinese = zh, Japanese = ja, English = en, Korean = ko), and can actually support up to 97 different language mixing processing.
|
| 55 |
+
|
| 56 |
+
# ===========================================================================================================
|
| 57 |
+
# 分かち書き機能:文章や文章の中の例えば(中国語/英語/日本語/韓国語)を、異なる言語で自動的に認識して分割し、TTS処理により適したものにします。
|
| 58 |
+
# このコードは、さまざまなTTSプロジェクトのフロントエンドテキストの多言語混合注釈区別、多言語混合トレーニング、および推論のために特別に作成されています。
|
| 59 |
+
# ===========================================================================================================
|
| 60 |
+
# (1)自動分詞:「韓国語では何を読むのですかあなたの体育の先生は誰ですか?今回の発表会では、iPhone 15シリーズの4機種が登場しました」
|
| 61 |
+
# (2)手动分词:“あなたの名前は<ja>佐々木ですか?<ja>ですか?”
|
| 62 |
+
# この処理結果は主に(中国語=ja、日本語=ja、英語=en、韓国語=ko)を対象としており、実際には最大97の異なる言語の混合処理をサポートできます。
|
| 63 |
+
# ===========================================================================================================
|
| 64 |
+
|
| 65 |
+
# ===========================================================================================================
|
| 66 |
+
# 단어 분할 기능: 기사 또는 문장에서 단어(중국어/영어/일본어/한국어)를 다른 언어에 따라 자동으로 식별하고 분할하여 TTS 처리에 더 적합합니다.
|
| 67 |
+
# 이 코드는 프런트 엔드 텍스트 다국어 혼합 주석 분화, 다국어 혼합 교육 및 다양한 TTS 프로젝트의 추론을 위해 설계되었습니다.
|
| 68 |
+
# ===========================================================================================================
|
| 69 |
+
# (1) 자동 단어 분할: "한국어로 무엇을 읽습니까? 스포츠 씨? 이 컨퍼런스는 4개의 iPhone 15 시리즈 모델을 제공합니다."
|
| 70 |
+
# (2) 수동 참여: "이름이 <ja>Saki입니까? <ja>?"
|
| 71 |
+
# 이 처리 결과는 주로 (중국어 = zh, 일본어 = ja, 영어 = en, 한국어 = ko)를 위한 것이며 실제로 혼합 처리를 위해 최대 97개의 언어를 지원합니다.
|
| 72 |
+
# ===========================================================================================================
|
| 73 |
+
|
| 74 |
+
# ===========================================================================================================
|
| 75 |
+
# 分词功能:将文章或句子里的例如(中/英/日/韩),按不同语言自动识别并拆分,让它更适合TTS处理。
|
| 76 |
+
# 本代码专为各种 TTS 项目的前端文本多语种混合标注区分,多语言混合训练和推理而编写。
|
| 77 |
+
# ===========================================================================================================
|
| 78 |
+
# (1)自动分词:“韩语中的오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型”
|
| 79 |
+
# (2)手动分词:“你的名字叫<ja>佐々木?<ja>吗?”
|
| 80 |
+
# 本处理结果主要针对(中文=zh , 日文=ja , 英文=en , 韩语=ko), 实际上可支持多达 97 种不同的语言混合处理。
|
| 81 |
+
# ===========================================================================================================
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# 手动分词标签规范:<语言标签>文本内容</语言标签>
|
| 85 |
+
# 수동 단어 분할 태그 사양: <언어 태그> 텍스트 내용</언어 태그>
|
| 86 |
+
# Manual word segmentation tag specification: <language tags> text content </language tags>
|
| 87 |
+
# 手動分詞タグ仕様:<言語タグ>テキスト内容</言語タグ>
|
| 88 |
+
# ===========================================================================================================
|
| 89 |
+
# For manual word segmentation, labels need to appear in pairs, such as:
|
| 90 |
+
# 如需手动分词,标签需要成对出现,例如:“<ja>佐々木<ja>” 或者 “<ja>佐々木</ja>”
|
| 91 |
+
# 错误示范:“你的名字叫<ja>佐々木。” 此句子中出现的单个<ja>标签将被忽略,不会处理。
|
| 92 |
+
# Error demonstration: "Your name is <ja>佐々木。" Single <ja> tags that appear in this sentence will be ignored and will not be processed.
|
| 93 |
+
# ===========================================================================================================
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ===========================================================================================================
|
| 97 |
+
# 语音合成标记语言 SSML , 这里只支持它的标签(非 XML)Speech Synthesis Markup Language SSML, only its tags are supported here (not XML)
|
| 98 |
+
# 想支持更多的 SSML 标签?欢迎 PR! Want to support more SSML tags? PRs are welcome!
|
| 99 |
+
# 说明:除了中文以外,它也可改造成支持多语种 SSML ,不仅仅是中文。
|
| 100 |
+
# Note: In addition to Chinese, it can also be modified to support multi-language SSML, not just Chinese.
|
| 101 |
+
# ===========================================================================================================
|
| 102 |
+
# 中文实现:Chinese implementation:
|
| 103 |
+
# 【SSML】<number>=中文大写数字读法(单字)
|
| 104 |
+
# 【SSML】<telephone>=数字转成中文电话号码大写汉字(单字)
|
| 105 |
+
# 【SSML】<currency>=按金额发音。
|
| 106 |
+
# 【SSML】<date>=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。
|
| 107 |
+
# ===========================================================================================================
|
| 108 |
+
class LangSSML:
|
| 109 |
+
# 纯数字
|
| 110 |
+
_zh_numerals_number = {
|
| 111 |
+
"0": "零",
|
| 112 |
+
"1": "一",
|
| 113 |
+
"2": "二",
|
| 114 |
+
"3": "三",
|
| 115 |
+
"4": "四",
|
| 116 |
+
"5": "五",
|
| 117 |
+
"6": "六",
|
| 118 |
+
"7": "七",
|
| 119 |
+
"8": "八",
|
| 120 |
+
"9": "九",
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
# 将2024/8/24, 2024-08, 08-24, 24 标准化“年月日”
|
| 124 |
+
# Standardize 2024/8/24, 2024-08, 08-24, 24 to "year-month-day"
|
| 125 |
+
def _format_chinese_data(date_str: str):
|
| 126 |
+
# 处理日期格式
|
| 127 |
+
input_date = date_str
|
| 128 |
+
if date_str is None or date_str.strip() == "":
|
| 129 |
+
return ""
|
| 130 |
+
date_str = re.sub(r"[\/\._|年|月]", "-", date_str)
|
| 131 |
+
date_str = re.sub(r"日", r"", date_str)
|
| 132 |
+
date_arrs = date_str.split(" ")
|
| 133 |
+
if len(date_arrs) == 1 and ":" in date_arrs[0]:
|
| 134 |
+
time_str = date_arrs[0]
|
| 135 |
+
date_arrs = []
|
| 136 |
+
else:
|
| 137 |
+
time_str = date_arrs[1] if len(date_arrs) >= 2 else ""
|
| 138 |
+
|
| 139 |
+
def nonZero(num, cn, func=None):
|
| 140 |
+
if func is not None:
|
| 141 |
+
num = func(num)
|
| 142 |
+
return f"{num}{cn}" if num is not None and num != "" and num != "0" else ""
|
| 143 |
+
|
| 144 |
+
f_number = LangSSML.to_chinese_number
|
| 145 |
+
f_currency = LangSSML.to_chinese_currency
|
| 146 |
+
# year, month, day
|
| 147 |
+
year_month_day = ""
|
| 148 |
+
if len(date_arrs) > 0:
|
| 149 |
+
year, month, day = "", "", ""
|
| 150 |
+
parts = date_arrs[0].split("-")
|
| 151 |
+
if len(parts) == 3: # 格式为 YYYY-MM-DD
|
| 152 |
+
year, month, day = parts
|
| 153 |
+
elif len(parts) == 2: # 格式为 MM-DD 或 YYYY-MM
|
| 154 |
+
if len(parts[0]) == 4: # 年-月
|
| 155 |
+
year, month = parts
|
| 156 |
+
else:
|
| 157 |
+
month, day = parts # 月-日
|
| 158 |
+
elif len(parts[0]) > 0: # 仅有月-日或年
|
| 159 |
+
if len(parts[0]) == 4:
|
| 160 |
+
year = parts[0]
|
| 161 |
+
else:
|
| 162 |
+
day = parts[0]
|
| 163 |
+
year, month, day = (
|
| 164 |
+
nonZero(year, "年", f_number),
|
| 165 |
+
nonZero(month, "月", f_currency),
|
| 166 |
+
nonZero(day, "日", f_currency),
|
| 167 |
+
)
|
| 168 |
+
year_month_day = re.sub(r"([年|月|日])+", r"\1", f"{year}{month}{day}")
|
| 169 |
+
# hours, minutes, seconds
|
| 170 |
+
time_str = re.sub(r"[\/\.\-:_]", ":", time_str)
|
| 171 |
+
time_arrs = time_str.split(":")
|
| 172 |
+
hours, minutes, seconds = "", "", ""
|
| 173 |
+
if len(time_arrs) == 3: # H/M/S
|
| 174 |
+
hours, minutes, seconds = time_arrs
|
| 175 |
+
elif len(time_arrs) == 2: # H/M
|
| 176 |
+
hours, minutes = time_arrs
|
| 177 |
+
elif len(time_arrs[0]) > 0:
|
| 178 |
+
hours = f"{time_arrs[0]}点" # H
|
| 179 |
+
if len(time_arrs) > 1:
|
| 180 |
+
hours, minutes, seconds = (
|
| 181 |
+
nonZero(hours, "点", f_currency),
|
| 182 |
+
nonZero(minutes, "分", f_currency),
|
| 183 |
+
nonZero(seconds, "秒", f_currency),
|
| 184 |
+
)
|
| 185 |
+
hours_minutes_seconds = re.sub(
|
| 186 |
+
r"([点|分|秒])+", r"\1", f"{hours}{minutes}{seconds}"
|
| 187 |
+
)
|
| 188 |
+
output_date = f"{year_month_day}{hours_minutes_seconds}"
|
| 189 |
+
return output_date
|
| 190 |
+
|
| 191 |
+
# 【SSML】number=中文大写数字读法(单字)
|
| 192 |
+
# Chinese Numbers(single word)
|
| 193 |
+
def to_chinese_number(num: str):
|
| 194 |
+
pattern = r"(\d+)"
|
| 195 |
+
zh_numerals = LangSSML._zh_numerals_number
|
| 196 |
+
arrs = re.split(pattern, num)
|
| 197 |
+
output = ""
|
| 198 |
+
for item in arrs:
|
| 199 |
+
if re.match(pattern, item):
|
| 200 |
+
output += "".join(
|
| 201 |
+
zh_numerals[digit] if digit in zh_numerals else ""
|
| 202 |
+
for digit in str(item)
|
| 203 |
+
)
|
| 204 |
+
else:
|
| 205 |
+
output += item
|
| 206 |
+
output = output.replace(".", "点")
|
| 207 |
+
return output
|
| 208 |
+
|
| 209 |
+
# 【SSML】telephone=数字转成中文电话号码大写汉字(单字)
|
| 210 |
+
# Convert numbers to Chinese phone numbers in uppercase Chinese characters(single word)
|
| 211 |
+
def to_chinese_telephone(num: str):
|
| 212 |
+
output = LangSSML.to_chinese_number(num.replace("+86", "")) # zh +86
|
| 213 |
+
output = output.replace("一", "幺")
|
| 214 |
+
return output
|
| 215 |
+
|
| 216 |
+
# 【SSML】currency=按金额发音。
|
| 217 |
+
# Digital processing from GPT_SoVITS num.py (thanks)
|
| 218 |
+
def to_chinese_currency(num: str):
|
| 219 |
+
pattern = r"(\d+)"
|
| 220 |
+
arrs = re.split(pattern, num)
|
| 221 |
+
output = ""
|
| 222 |
+
for item in arrs:
|
| 223 |
+
if re.match(pattern, item):
|
| 224 |
+
output += num2str(item)
|
| 225 |
+
else:
|
| 226 |
+
output += item
|
| 227 |
+
output = output.replace(".", "点")
|
| 228 |
+
return output
|
| 229 |
+
|
| 230 |
+
# 【SSML】date=按日期发音。支持 2024年08月24, 2024/8/24, 2024-08, 08-24, 24 等输入。
|
| 231 |
+
def to_chinese_date(num: str):
|
| 232 |
+
chinese_date = LangSSML._format_chinese_data(num)
|
| 233 |
+
return chinese_date
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class LangSegment:
|
| 237 |
+
_text_cache = None
|
| 238 |
+
_text_lasts = None
|
| 239 |
+
_text_langs = None
|
| 240 |
+
_lang_count = None
|
| 241 |
+
_lang_eos = None
|
| 242 |
+
|
| 243 |
+
# 可自定义语言匹配标签:カスタマイズ可能な言語対応タグ:사용자 지정 가능한 언어 일치 태그:
|
| 244 |
+
# Customizable language matching tags: These are supported,이 표현들은 모두 지지합니다
|
| 245 |
+
# <zh>你好<zh> , <ja>佐々木</ja> , <en>OK<en> , <ko>오빠</ko> 这些写法均支持
|
| 246 |
+
SYMBOLS_PATTERN = r"(<([a-zA-Z|-]*)>(.*?)<\/*[a-zA-Z|-]*>)"
|
| 247 |
+
|
| 248 |
+
# 语言过滤组功能, 可以指定保留语言。不在过滤组中的语言将被清除。您可随心搭配TTS语音合成所支持的语言。
|
| 249 |
+
# 언어 필터 그룹 기능을 사용하면 예약된 언어를 지정할 수 있습니다. 필터 그룹에 없는 언어는 지워집니다. TTS 텍스트에서 지원하는 언어를 원하는 대로 일치시킬 수 있습니다.
|
| 250 |
+
# 言語フィルターグループ機能では、予約言語を指定できます。フィルターグループに含まれていない言語はクリアされます。TTS音声合成がサポートする言語を自由に組み合わせることができます。
|
| 251 |
+
# The language filter group function allows you to specify reserved languages.
|
| 252 |
+
# Languages not in the filter group will be cleared. You can match the languages supported by TTS Text To Speech as you like.
|
| 253 |
+
# 排名越前,优先级越高,The higher the ranking, the higher the priority,ランキングが上位になるほど、優先度が高くなります。
|
| 254 |
+
|
| 255 |
+
# 系统默认过滤器。System default filter。(ISO 639-1 codes given)
|
| 256 |
+
# ----------------------------------------------------------------------------------------------------------------------------------
|
| 257 |
+
# "zh"中文=Chinese ,"en"英语=English ,"ja"日语=Japanese ,"ko"韩语=Korean ,"fr"法语=French ,"vi"越南语=Vietnamese , "ru"俄语=Russian
|
| 258 |
+
# "th"泰语=Thai
|
| 259 |
+
# ----------------------------------------------------------------------------------------------------------------------------------
|
| 260 |
+
DEFAULT_FILTERS = ["zh", "ja", "ko", "en"]
|
| 261 |
+
|
| 262 |
+
# 用户可自定义过滤器。User-defined filters
|
| 263 |
+
Langfilters = DEFAULT_FILTERS[:] # 创建副本
|
| 264 |
+
|
| 265 |
+
# 合并文本
|
| 266 |
+
isLangMerge = True
|
| 267 |
+
|
| 268 |
+
# 试验性支持:您可自定义添加:"fr"法语 , "vi"越南语。Experimental: You can customize to add: "fr" French, "vi" Vietnamese.
|
| 269 |
+
# 请使用API启用:LangSegment.setfilters(["zh", "en", "ja", "ko", "fr", "vi" , "ru" , "th"]) # 您可自定义添加,如:"fr"法语 , "vi"越南语。
|
| 270 |
+
|
| 271 |
+
# 预览版功能,自动启用或禁用,无需设置
|
| 272 |
+
# Preview feature, automatically enabled or disabled, no settings required
|
| 273 |
+
EnablePreview = False
|
| 274 |
+
|
| 275 |
+
# 除此以外,它支持简写过滤器,只需按不同语种任意组合即可。
|
| 276 |
+
# In addition to that, it supports abbreviation filters, allowing for any combination of different languages.
|
| 277 |
+
# 示例:您可以任意指定多种组合,进行过滤
|
| 278 |
+
# Example: You can specify any combination to filter
|
| 279 |
+
|
| 280 |
+
# 中/日语言优先级阀值(评分范围为 0 ~ 1):评分低于设定阀值 <0.89 时,启用 filters 中的优先级。\n
|
| 281 |
+
# 중/일본어 우선 순위 임계값(점수 범위 0-1): 점수가 설정된 임계값 <0.89보다 낮을 때 필터에서 우선 순위를 활성화합니다.
|
| 282 |
+
# 中国語/日本語の優先度しきい値(スコア範囲0〜1):スコアが設定されたしきい値<0.89未満の場合、フィルターの優先度が有効になります。\n
|
| 283 |
+
# Chinese and Japanese language priority threshold (score range is 0 ~ 1): The default threshold is 0.89. \n
|
| 284 |
+
# Only the common characters between Chinese and Japanese are processed with confidence and priority. \n
|
| 285 |
+
LangPriorityThreshold = 0.89
|
| 286 |
+
|
| 287 |
+
# Langfilters = ["zh"] # 按中文识别
|
| 288 |
+
# Langfilters = ["en"] # 按英文识别
|
| 289 |
+
# Langfilters = ["ja"] # 按日文识别
|
| 290 |
+
# Langfilters = ["ko"] # 按韩文识别
|
| 291 |
+
# Langfilters = ["zh_ja"] # 中日混合识别
|
| 292 |
+
# Langfilters = ["zh_en"] # 中英混合识别
|
| 293 |
+
# Langfilters = ["ja_en"] # 日英混合识别
|
| 294 |
+
# Langfilters = ["zh_ko"] # 中韩混合识别
|
| 295 |
+
# Langfilters = ["ja_ko"] # 日韩混合识别
|
| 296 |
+
# Langfilters = ["en_ko"] # 英韩混合识别
|
| 297 |
+
# Langfilters = ["zh_ja_en"] # 中日英混合识别
|
| 298 |
+
# Langfilters = ["zh_ja_en_ko"] # 中日英韩混合识别
|
| 299 |
+
|
| 300 |
+
# 更多过滤组合,请您随意。。。For more filter combinations, please feel free to......
|
| 301 |
+
# より多くのフィルターの組み合わせ、お気軽に。。。더 많은 필터 조합을 원하시면 자유롭게 해주세요. .....
|
| 302 |
+
|
| 303 |
+
# 可选保留:支持中文数字拼音格式,更方便前端实现拼音音素修改和推理,默认关闭 False 。
|
| 304 |
+
# 开启后 True ,括号内的数字拼音格式均保留,并识别输出为:"zh"中文。
|
| 305 |
+
keepPinyin = False
|
| 306 |
+
|
| 307 |
+
# DEFINITION
|
| 308 |
+
PARSE_TAG = re.compile(r"(⑥\$*\d+[\d]{6,}⑥)")
|
| 309 |
+
|
| 310 |
+
@staticmethod
|
| 311 |
+
def _clears():
|
| 312 |
+
LangSegment._text_cache = None
|
| 313 |
+
LangSegment._text_lasts = None
|
| 314 |
+
LangSegment._text_langs = None
|
| 315 |
+
LangSegment._text_waits = None
|
| 316 |
+
LangSegment._lang_count = None
|
| 317 |
+
LangSegment._lang_eos = None
|
| 318 |
+
pass
|
| 319 |
+
|
| 320 |
+
@staticmethod
|
| 321 |
+
def _is_english_word(word):
|
| 322 |
+
return bool(re.match(r"^[a-zA-Z]+$", word))
|
| 323 |
+
|
| 324 |
+
@staticmethod
|
| 325 |
+
def _is_chinese(word):
|
| 326 |
+
for char in word:
|
| 327 |
+
if "\u4e00" <= char <= "\u9fff":
|
| 328 |
+
return True
|
| 329 |
+
return False
|
| 330 |
+
|
| 331 |
+
@staticmethod
|
| 332 |
+
def _is_japanese_kana(word):
|
| 333 |
+
pattern = re.compile(r"[\u3040-\u309F\u30A0-\u30FF]+")
|
| 334 |
+
matches = pattern.findall(word)
|
| 335 |
+
return len(matches) > 0
|
| 336 |
+
|
| 337 |
+
@staticmethod
|
| 338 |
+
def _insert_english_uppercase(word):
|
| 339 |
+
modified_text = re.sub(r"(?<!\b)([A-Z])", r" \1", word)
|
| 340 |
+
modified_text = modified_text.strip("-")
|
| 341 |
+
return modified_text + " "
|
| 342 |
+
|
| 343 |
+
@staticmethod
|
| 344 |
+
def _split_camel_case(word):
|
| 345 |
+
return re.sub(r"(?<!^)(?=[A-Z])", " ", word)
|
| 346 |
+
|
| 347 |
+
@staticmethod
|
| 348 |
+
def _statistics(language, text):
|
| 349 |
+
# Language word statistics:
|
| 350 |
+
# Chinese characters usually occupy double bytes
|
| 351 |
+
if LangSegment._lang_count is None or not isinstance(
|
| 352 |
+
LangSegment._lang_count, defaultdict
|
| 353 |
+
):
|
| 354 |
+
LangSegment._lang_count = defaultdict(int)
|
| 355 |
+
lang_count = LangSegment._lang_count
|
| 356 |
+
if "|" not in language:
|
| 357 |
+
lang_count[language] += (
|
| 358 |
+
int(len(text) * 2) if language == "zh" else len(text)
|
| 359 |
+
)
|
| 360 |
+
LangSegment._lang_count = lang_count
|
| 361 |
+
pass
|
| 362 |
+
|
| 363 |
+
@staticmethod
|
| 364 |
+
def _clear_text_number(text):
|
| 365 |
+
if text == "\n":
|
| 366 |
+
return text, False # Keep Line Breaks
|
| 367 |
+
clear_text = re.sub(r"([^\w\s]+)", "", re.sub(r"\n+", "", text)).strip()
|
| 368 |
+
is_number = len(re.sub(re.compile(r"(\d+)"), "", clear_text)) == 0
|
| 369 |
+
return clear_text, is_number
|
| 370 |
+
|
| 371 |
+
@staticmethod
|
| 372 |
+
def _saveData(words, language: str, text: str, score: float, symbol=None):
|
| 373 |
+
# Pre-detection
|
| 374 |
+
clear_text, is_number = LangSegment._clear_text_number(text)
|
| 375 |
+
# Merge the same language and save the results
|
| 376 |
+
preData = words[-1] if len(words) > 0 else None
|
| 377 |
+
if symbol is not None:
|
| 378 |
+
pass
|
| 379 |
+
elif preData is not None and preData["symbol"] is None:
|
| 380 |
+
if len(clear_text) == 0:
|
| 381 |
+
language = preData["lang"]
|
| 382 |
+
elif is_number == True:
|
| 383 |
+
language = preData["lang"]
|
| 384 |
+
_, pre_is_number = LangSegment._clear_text_number(preData["text"])
|
| 385 |
+
if preData["lang"] == language:
|
| 386 |
+
LangSegment._statistics(preData["lang"], text)
|
| 387 |
+
text = preData["text"] + text
|
| 388 |
+
preData["text"] = text
|
| 389 |
+
return preData
|
| 390 |
+
elif pre_is_number == True:
|
| 391 |
+
text = f"{preData['text']}{text}"
|
| 392 |
+
words.pop()
|
| 393 |
+
elif is_number == True:
|
| 394 |
+
priority_language = LangSegment._get_filters_string()[:2]
|
| 395 |
+
if priority_language in "ja-zh-en-ko-fr-vi":
|
| 396 |
+
language = priority_language
|
| 397 |
+
data = {"lang": language, "text": text, "score": score, "symbol": symbol}
|
| 398 |
+
filters = LangSegment.Langfilters
|
| 399 |
+
if (
|
| 400 |
+
filters is None
|
| 401 |
+
or len(filters) == 0
|
| 402 |
+
or "?" in language
|
| 403 |
+
or language in filters
|
| 404 |
+
or language in filters[0]
|
| 405 |
+
or filters[0] == "*"
|
| 406 |
+
or filters[0] in "alls-mixs-autos"
|
| 407 |
+
):
|
| 408 |
+
words.append(data)
|
| 409 |
+
LangSegment._statistics(data["lang"], data["text"])
|
| 410 |
+
return data
|
| 411 |
+
|
| 412 |
+

    @staticmethod
    def _addwords(words, language, text, score, symbol=None):
        if text == "\n":
            pass  # Keep line breaks
        elif text is None or len(text.strip()) == 0:
            return True
        if language is None:
            language = ""
        language = language.lower()
        if language == "en":
            text = LangSegment._insert_english_uppercase(text)
        # text = re.sub(r'[(())]', ',' , text) # Keep it.
        text_waits = LangSegment._text_waits
        ispre_waits = len(text_waits) > 0
        preResult = text_waits.pop() if ispre_waits else None
        if preResult is None:
            preResult = words[-1] if len(words) > 0 else None
        if preResult and ("|" in preResult["lang"]):
            pre_lang = preResult["lang"]
            if language in pre_lang:
                preResult["lang"] = language = language.split("|")[0]
            else:
                preResult["lang"] = pre_lang.split("|")[0]
            if ispre_waits:
                preResult = LangSegment._saveData(
                    words,
                    preResult["lang"],
                    preResult["text"],
                    preResult["score"],
                    preResult["symbol"],
                )
        pre_lang = preResult["lang"] if preResult else None
        if ("|" in language) and (
            pre_lang and pre_lang not in language and "…" not in language
        ):
            language = language.split("|")[0]
        if "|" in language:
            LangSegment._text_waits.append(
                {"lang": language, "text": text, "score": score, "symbol": symbol}
            )
        else:
            LangSegment._saveData(words, language, text, score, symbol)
        return False

    @staticmethod
    def _get_prev_data(words):
        data = words[-1] if words and len(words) > 0 else None
        if data:
            return (data["lang"], data["text"])
        return (None, "")

    @staticmethod
    def _match_ending(input, index):
        if input is None or len(input) == 0:
            return False, None
        input = re.sub(r"\s+", "", input)
        if len(input) == 0 or abs(index) > len(input):
            return False, None
        ending_pattern = re.compile(r'([「」“”‘’"\'::。.!!?.?])')
        return ending_pattern.match(input[index]), input[index]

    @staticmethod
    def _cleans_text(cleans_text):
        cleans_text = re.sub(r"(.*?)([^\w]+)", r"\1 ", cleans_text)
        cleans_text = re.sub(r"(.)\1+", r"\1", cleans_text)
        return cleans_text.strip()

    @staticmethod
    def _mean_processing(text: str):
        if text is None or (text.strip()) == "":
            return None, 0.0
        arrs = LangSegment._split_camel_case(text).split(" ")
        langs = []
        for t in arrs:
            if len(t.strip()) <= 3:
                continue
            language, score = langid.classify(t)
            langs.append({"lang": language})
        if len(langs) == 0:
            return None, 0.0
        return Counter([item["lang"] for item in langs]).most_common(1)[0][0], 1.0

    @staticmethod
    def _lang_classify(cleans_text):
        language, score = langid.classify(cleans_text)
        # fix: on Hugging Face the score comes back as np.float32
        if (
            score is not None
            and isinstance(score, np.generic)
            and hasattr(score, "item")
        ):
            score = score.item()
        score = round(score, 3)
        return language, score

    @staticmethod
    def _get_filters_string():
        filters = LangSegment.Langfilters
        return "-".join(filters).lower().strip() if filters is not None else ""

    @staticmethod
    def _parse_language(words, segment):
        LANG_JA = "ja"
        LANG_ZH = "zh"
        LANG_ZH_JA = f"{LANG_ZH}|{LANG_JA}"
        LANG_JA_ZH = f"{LANG_JA}|{LANG_ZH}"
        language = LANG_ZH
        regex_pattern = re.compile(r"([^\w\s]+)")
        lines = regex_pattern.split(segment)
        lines_max = len(lines)
        LANG_EOS = LangSegment._lang_eos
        for index, text in enumerate(lines):
            if len(text) == 0:
                continue
            EOS = index >= (lines_max - 1)
            nextId = index + 1
            nextText = lines[nextId] if not EOS else ""
            nextPunc = (
                len(re.sub(regex_pattern, "", re.sub(r"\n+", "", nextText)).strip())
                == 0
            )
            textPunc = (
                len(re.sub(regex_pattern, "", re.sub(r"\n+", "", text)).strip()) == 0
            )
            if not EOS and (
                textPunc == True or (len(nextText.strip()) >= 0 and nextPunc == True)
            ):
                lines[nextId] = f"{text}{nextText}"
                continue
            number_tags = re.compile(r"(⑥\d{6,}⑥)")
            cleans_text = re.sub(number_tags, "", text)
            cleans_text = re.sub(r"\d+", "", cleans_text)
            cleans_text = LangSegment._cleans_text(cleans_text)
            # fix: langid is unreliable on very short sentences, so short pieces
            # are spliced onto the next line to give it more context.
            if not EOS and len(cleans_text) <= 2:
                lines[nextId] = f"{text}{nextText}"
                continue
            language, score = LangSegment._lang_classify(cleans_text)
            prev_language, prev_text = LangSegment._get_prev_data(words)
            if language != LANG_ZH and all(
                "\u4e00" <= c <= "\u9fff" for c in re.sub(r"\s", "", cleans_text)
            ):
                language, score = LANG_ZH, 1
            if len(cleans_text) <= 5 and LangSegment._is_chinese(cleans_text):
                filters_string = LangSegment._get_filters_string()
                if (
                    score < LangSegment.LangPriorityThreshold
                    and len(filters_string) > 0
                ):
                    index_ja, index_zh = (
                        filters_string.find(LANG_JA),
                        filters_string.find(LANG_ZH),
                    )
                    if index_ja != -1 and index_ja < index_zh:
                        language = LANG_JA
                    elif index_zh != -1 and index_zh < index_ja:
                        language = LANG_ZH
                if LangSegment._is_japanese_kana(cleans_text):
                    language = LANG_JA
            elif len(cleans_text) > 2 and score > 0.90:
                pass
            elif EOS and LANG_EOS:
                language = LANG_ZH if len(cleans_text) <= 1 else language
            else:
                LANG_UNKNOWN = (
                    LANG_ZH_JA
                    if language == LANG_ZH
                    or (len(cleans_text) <= 2 and prev_language == LANG_ZH)
                    else LANG_JA_ZH
                )
                match_end, match_char = LangSegment._match_ending(text, -1)
                referen = (
                    prev_language in LANG_UNKNOWN or LANG_UNKNOWN in prev_language
                    if prev_language
                    else False
                )
                if match_char in "。.":
                    language = (
                        prev_language if referen and len(words) > 0 else language
                    )
                else:
                    language = f"{LANG_UNKNOWN}|…"
            text, *_ = re.subn(number_tags, LangSegment._restore_number, text)
            LangSegment._addwords(words, language, text, score)
            pass
        pass

    # ----------------------------------------------------------
    # [SSML] Chinese number processing (SSML support)
    # Everything here defaults to Chinese, to handle the SSML Chinese tags.
    # Any language could be supported the same way, for example:
    # Chinese telephone number: <telephone>1234567</telephone>
    # Chinese digit string:     <number>1234567</number>
    @staticmethod
    def _process_symbol_SSML(words, data):
        tag, match = data
        language = SSML = match[1]
        text = match[2]
        score = 1.0
        if SSML == "telephone":
            # Chinese - read as a telephone number
            language = "zh"
            text = LangSSML.to_chinese_telephone(text)
            pass
        elif SSML == "number":
            # Chinese - read as a number
            language = "zh"
            text = LangSSML.to_chinese_number(text)
            pass
        elif SSML == "currency":
            # Chinese - read as a monetary amount
            language = "zh"
            text = LangSSML.to_chinese_currency(text)
            pass
        elif SSML == "date":
            # Chinese - read as a date
            language = "zh"
            text = LangSSML.to_chinese_date(text)
            pass
        LangSegment._addwords(words, language, text, score, SSML)
        pass

    # ----------------------------------------------------------
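
    # Illustrative sketch (not in the original file): how an SSML tag flows
    # through the handler above, assuming the LangSSML converters behave as
    # their names suggest. With input such as
    #
    #   "请拨打<telephone>1234567</telephone>"
    #
    # the <telephone> branch routes the digits through
    # LangSSML.to_chinese_telephone(), and the segment is emitted with
    # lang="zh" so a Chinese TTS frontend reads it digit by digit.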

    @staticmethod
    def _restore_number(matche):
        value = matche.group(0)
        text_cache = LangSegment._text_cache
        if value in text_cache:
            process, data = text_cache[value]
            tag, match = data
            value = match
        return value

    @staticmethod
    def _pattern_symbols(item, text):
        if text is None:
            return text
        tag, pattern, process = item
        matches = pattern.findall(text)
        if len(matches) == 1 and "".join(matches[0]) == text:
            return text
        for i, match in enumerate(matches):
            key = f"⑥{tag}{i:06d}⑥"
            text = re.sub(pattern, key, text, count=1)
            LangSegment._text_cache[key] = (process, (tag, match))
        return text

    @staticmethod
    def _process_symbol(words, data):
        tag, match = data
        language = match[1]
        text = match[2]
        score = 1.0
        filters = LangSegment._get_filters_string()
        if language not in filters:
            LangSegment._process_symbol_SSML(words, data)
        else:
            LangSegment._addwords(words, language, text, score, True)
        pass
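
    # Illustrative note (not in the original file): _pattern_symbols protects
    # matched spans by swapping each one for a placeholder key of the form
    # "⑥{tag}{index:06d}⑥" and caching (process, (tag, match)) under that key.
    # For example, a quoted span matched by the "$3" rule may be replaced by
    # "⑥$3000000⑥" in the working text; _process_tags later looks the key up
    # in _text_cache and dispatches it to the cached handler, while
    # _restore_number puts any unprocessed placeholders back verbatim.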

    @staticmethod
    def _process_english(words, data):
        tag, match = data
        text = match[0]
        filters = LangSegment._get_filters_string()
        priority_language = filters[:2]
        # Preview feature: segmentation handling for other languages
        enablePreview = LangSegment.EnablePreview
        if enablePreview == True:
            # Experimental: other language support
            regex_pattern = re.compile(r"(.*?[。.??!!]+[\n]{,1})")
            lines = regex_pattern.split(text)
            for index, text in enumerate(lines):
                if len(text.strip()) == 0:
                    continue
                cleans_text = LangSegment._cleans_text(text)
                language, score = LangSegment._lang_classify(cleans_text)
                if language not in filters:
                    language, score = LangSegment._mean_processing(cleans_text)
                if language is None or score <= 0.0:
                    continue
                elif language in filters:
                    pass
                elif score >= 0.95:
                    continue  # High score, but not in the filter: excluded.
                elif score <= 0.15 and filters[:2] == "fr":
                    language = priority_language
                else:
                    language = "en"
                LangSegment._addwords(words, language, text, score)
        else:
            # Default is English
            language, score = "en", 1.0
            LangSegment._addwords(words, language, text, score)
        pass

    @staticmethod
    def _process_Russian(words, data):
        tag, match = data
        text = match[0]
        language = "ru"
        score = 1.0
        LangSegment._addwords(words, language, text, score)
        pass

    @staticmethod
    def _process_Thai(words, data):
        tag, match = data
        text = match[0]
        language = "th"
        score = 1.0
        LangSegment._addwords(words, language, text, score)
        pass

    @staticmethod
    def _process_korean(words, data):
        tag, match = data
        text = match[0]
        language = "ko"
        score = 1.0
        LangSegment._addwords(words, language, text, score)
        pass

    @staticmethod
    def _process_quotes(words, data):
        tag, match = data
        text = "".join(match)
        childs = LangSegment.PARSE_TAG.findall(text)
        if len(childs) > 0:
            LangSegment._process_tags(words, text, False)
        else:
            cleans_text = LangSegment._cleans_text(match[1])
            if len(cleans_text) <= 5:
                LangSegment._parse_language(words, text)
            else:
                language, score = LangSegment._lang_classify(cleans_text)
                LangSegment._addwords(words, language, text, score)
        pass

    @staticmethod
    def _process_pinyin(words, data):
        tag, match = data
        text = match
        language = "zh"
        score = 1.0
        LangSegment._addwords(words, language, text, score)
        pass

    @staticmethod
    def _process_number(words, data):  # "$0" process only
        """
        Numbers alone cannot reliably identify a language, since digits are
        universal across languages. This handler is therefore not wired into
        the default pipeline; it exists only for testing.
        """
        tag, match = data
        language = words[0]["lang"] if len(words) > 0 else "zh"
        text = match
        score = 0.0
        LangSegment._addwords(words, language, text, score)
        pass

    @staticmethod
    def _process_tags(words, text, root_tag):
        text_cache = LangSegment._text_cache
        segments = re.split(LangSegment.PARSE_TAG, text)
        segments_len = len(segments) - 1
        for index, text in enumerate(segments):
            if root_tag:
                LangSegment._lang_eos = index >= segments_len
            if LangSegment.PARSE_TAG.match(text):
                process, data = text_cache[text]
                if process:
                    process(words, data)
            else:
                LangSegment._parse_language(words, text)
            pass
        return words

    @staticmethod
    def _merge_results(words):
        new_word = []
        for index, cur_data in enumerate(words):
            if "symbol" in cur_data:
                del cur_data["symbol"]
            if index == 0:
                new_word.append(cur_data)
            else:
                pre_data = new_word[-1]
                if cur_data["lang"] == pre_data["lang"]:
                    pre_data["text"] = f"{pre_data['text']}{cur_data['text']}"
                else:
                    new_word.append(cur_data)
        return new_word
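
    # Illustrative sketch (not in the original file): what _merge_results does
    # when isLangMerge is enabled. Adjacent segments with the same language are
    # concatenated (and the "symbol" key is dropped), e.g.
    #
    #   [{"lang": "zh", "text": "你好"}, {"lang": "zh", "text": "世界"},
    #    {"lang": "en", "text": "hi"}]
    #   -> [{"lang": "zh", "text": "你好世界"}, {"lang": "en", "text": "hi"}]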

    @staticmethod
    def _parse_symbols(text):
        TAG_NUM = "00"  # "00" => default channels , "$0" => testing channel
        TAG_S1, TAG_S2, TAG_P1, TAG_P2, TAG_EN, TAG_KO, TAG_RU, TAG_TH = (
            "$1",
            "$2",
            "$3",
            "$4",
            "$5",
            "$6",
            "$7",
            "$8",
        )
        TAG_BASE = re.compile(r'(([【《((“‘"\']*[LANGUAGE]+[\W\s]*)+)')
        # Get the custom language filter
        filters = LangSegment.Langfilters
        filters = filters if filters is not None else ""
        # =======================================================================================================
        # Experimental: other language support. If characters for one of these
        # languages are missing below, contributors familiar with that language
        # are welcome to submit the missing pronunciation symbols.
        # -------------------------------------------------------------------------------------------------------
        # Preview feature: other language support
        enablePreview = LangSegment.EnablePreview
        if "fr" in filters or "vi" in filters:
            enablePreview = True
        LangSegment.EnablePreview = enablePreview
        # Experimental: French character support
        RE_FR = "" if not enablePreview else "àáâãäåæçèéêëìíîïðñòóôõöùúûüýþÿ"
        # Experimental: Vietnamese character support
        RE_VI = (
            ""
            if not enablePreview
            else "đơưăáàảãạắằẳẵặấầẩẫậéèẻẽẹếềểễệíìỉĩịóòỏõọốồổỗộớờởỡợúùủũụứừửữựôâêơưỷỹ"
        )
        # -------------------------------------------------------------------------------------------------------
        # Basic options:
        process_list = [
            (
                TAG_S1,
                re.compile(LangSegment.SYMBOLS_PATTERN),
                LangSegment._process_symbol,
            ),  # Symbol tag
            (
                TAG_KO,
                re.compile(re.sub(r"LANGUAGE", "\uac00-\ud7a3", TAG_BASE.pattern)),
                LangSegment._process_korean,
            ),  # Korean words
            (
                TAG_TH,
                re.compile(re.sub(r"LANGUAGE", "\u0e00-\u0e7f", TAG_BASE.pattern)),
                LangSegment._process_Thai,
            ),  # Thai words support.
            (
                TAG_RU,
                re.compile(re.sub(r"LANGUAGE", "А-Яа-яЁё", TAG_BASE.pattern)),
                LangSegment._process_Russian,
            ),  # Russian words support.
            (
                TAG_NUM,
                re.compile(r"(\W*\d+\W+\d*\W*\d*)"),
                LangSegment._process_number,
            ),  # Number words; universal in all languages, so ignored.
            (
                TAG_EN,
                re.compile(
                    re.sub(r"LANGUAGE", f"a-zA-Z{RE_FR}{RE_VI}", TAG_BASE.pattern)
                ),
                LangSegment._process_english,
            ),  # English words + other language support.
            (
                TAG_P1,
                re.compile(r'(["\'])(.*?)(\1)'),
                LangSegment._process_quotes,
            ),  # Regular quotes
            (
                TAG_P2,
                re.compile(
                    r"([\n]*[【《((“‘])([^【《((“‘’”))》】]{3,})([’”))》】][\W\s]*[\n]{,1})"
                ),
                LangSegment._process_quotes,
            ),  # Special quotes with distinct left and right marks.
        ]
        # Extended options: default False
        if LangSegment.keepPinyin == True:
            process_list.insert(
                1,
                (
                    TAG_S2,
                    re.compile(r"([\(({](?:\s*\w*\d\w*\s*)+[})\)])"),
                    LangSegment._process_pinyin,
                ),  # Chinese pinyin tag.
            )
        # -------------------------------------------------------------------------------------------------------
        words = []
        lines = re.findall(r".*\n*", re.sub(LangSegment.PARSE_TAG, "", text))
        for index, text in enumerate(lines):
            if len(text.strip()) == 0:
                continue
            LangSegment._lang_eos = False
            LangSegment._text_cache = {}
            for item in process_list:
                text = LangSegment._pattern_symbols(item, text)
            cur_word = LangSegment._process_tags([], text, True)
            if len(cur_word) == 0:
                continue
            cur_data = cur_word[0] if len(cur_word) > 0 else None
            pre_data = words[-1] if len(words) > 0 else None
            if (
                cur_data
                and pre_data
                and cur_data["lang"] == pre_data["lang"]
                and cur_data["symbol"] == False
                and pre_data["symbol"]
            ):
                cur_data["text"] = f"{pre_data['text']}{cur_data['text']}"
                words.pop()
            words += cur_word
        if LangSegment.isLangMerge == True:
            words = LangSegment._merge_results(words)
        lang_count = LangSegment._lang_count
        if lang_count and len(lang_count) > 0:
            lang_count = dict(
                sorted(lang_count.items(), key=lambda x: x[1], reverse=True)
            )
            lang_count = list(lang_count.items())
            LangSegment._lang_count = lang_count
        return words

    @staticmethod
    def setfilters(filters):
        # When the filter changes, clear the cache
        if LangSegment.Langfilters != filters:
            LangSegment._clears()
            LangSegment.Langfilters = filters
        pass

    @staticmethod
    def getfilters():
        return LangSegment.Langfilters

    @staticmethod
    def setPriorityThreshold(threshold: float):
        LangSegment.LangPriorityThreshold = threshold
        pass

    @staticmethod
    def getPriorityThreshold():
        return LangSegment.LangPriorityThreshold

    @staticmethod
    def getCounts():
        lang_count = LangSegment._lang_count
        if lang_count is not None:
            return lang_count
        text_langs = LangSegment._text_langs
        if text_langs is None or len(text_langs) == 0:
            return [("zh", 0)]
        lang_counts = defaultdict(int)
        for d in text_langs:
            lang_counts[d["lang"]] += (
                int(len(d["text"]) * 2) if d["lang"] == "zh" else len(d["text"])
            )
        lang_counts = dict(
            sorted(lang_counts.items(), key=lambda x: x[1], reverse=True)
        )
        lang_counts = list(lang_counts.items())
        LangSegment._lang_count = lang_counts
        return lang_counts

    @staticmethod
    def getTexts(text: str):
        if text is None or len(text.strip()) == 0:
            LangSegment._clears()
            return []
        # lasts
        text_langs = LangSegment._text_langs
        if LangSegment._text_lasts == text and text_langs is not None:
            return text_langs
        # parse
        LangSegment._text_waits = []
        LangSegment._lang_count = None
        LangSegment._text_lasts = text
        text = LangSegment._parse_symbols(text)
        LangSegment._text_langs = text
        return text

    @staticmethod
    def classify(text: str):
        return LangSegment.getTexts(text)


def setLangMerge(value: bool):
    """Whether to merge adjacent same-language results."""
    LangSegment.isLangMerge = value
    pass


def getLangMerge():
    """Whether to merge adjacent same-language results."""
    return LangSegment.isLangMerge


def setfilters(filters):
    """
    Language filter group: specify which languages to keep. Languages not in
    the filter group are removed. Combine the languages supported by your TTS
    engine however you like.

    Args:
        filters (list): e.g. ["zh", "en", "ja", "ko"]; earlier entries have
        higher priority.
    """
    LangSegment.setfilters(filters)
    pass


def getfilters():
    """
    Language filter group: specify which languages to keep. Languages not in
    the filter group are removed. Combine the languages supported by your TTS
    engine however you like.

    Args:
        filters (list): e.g. ["zh", "en", "ja", "ko"]; earlier entries have
        higher priority.
    """
    return LangSegment.getfilters()


# # @Deprecated: use the shorter setfilters
# def setLangfilters(filters):
#     """
#     Removed after 0.1.9: use the shorter setfilters
#     """
#     setfilters(filters)
# # @Deprecated: use the shorter getfilters
# def getLangfilters():
#     """
#     Removed after 0.1.9: use the shorter getfilters
#     """
#     return getfilters()


def setKeepPinyin(value: bool):
    """
    Optional: keep Chinese digit-pinyin tags, which makes it easier for a
    frontend to patch pinyin phonemes for inference. Off (False) by default.
    When True, digit-pinyin tags inside parentheses are kept verbatim and the
    segment is emitted as "zh" Chinese.
    """
    LangSegment.keepPinyin = value
    pass


def getKeepPinyin():
    """
    Optional: keep Chinese digit-pinyin tags, which makes it easier for a
    frontend to patch pinyin phonemes for inference. Off (False) by default.
    When True, digit-pinyin tags inside parentheses are kept verbatim and the
    segment is emitted as "zh" Chinese.
    """
    return LangSegment.keepPinyin


def setEnablePreview(value: bool):
    """
    Enable preview functionality (off by default)
    Args:
        value (bool): True = on, False = off
    """
    LangSegment.EnablePreview = value == True
    pass


def getEnablePreview():
    """
    Enable preview functionality (off by default)
    Args:
        value (bool): True = on, False = off
    """
    return LangSegment.EnablePreview == True
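

# Illustrative sketch (not part of the original file): with keepPinyin
# enabled, a parenthesized digit-pinyin tag is preserved and routed as
# Chinese, e.g.
#
#   setKeepPinyin(True)
#   getTexts("我喜欢唱(chang4)歌")
#   # the "(chang4)" span matches the pinyin rule above and stays inside a
#   # segment labelled {"lang": "zh", ...}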


def setPriorityThreshold(threshold: float):
    """
    Chinese/Japanese priority threshold (score range 0 ~ 1, default 0.89):
    when a segment's score falls below the threshold, the priority order of
    the filters list is used instead. Only characters shared between Chinese
    and Japanese are subject to this confidence/priority handling.

    Args:
        threshold: float (score range 0 ~ 1)
    """
    LangSegment.setPriorityThreshold(threshold)
    pass


def getPriorityThreshold():
    """
    Chinese/Japanese priority threshold (score range 0 ~ 1, default 0.89):
    when a segment's score falls below the threshold, the priority order of
    the filters list is used instead. Only characters shared between Chinese
    and Japanese are subject to this confidence/priority handling.

    Args:
        threshold: float (score range 0 ~ 1)
    """
    return LangSegment.getPriorityThreshold()


def getTexts(text: str):
    """
    Multilingual tokenization of the input text.

    Args:
        text (str): text content

    Returns:
        list: e.g. [{'lang':'zh','text':'?'},...]
        where lang is the language and text is the content.
    """
    return LangSegment.getTexts(text)


def getCounts():
    """
    Statistics over the tokenization result, sorted by character count per
    language in descending order; used to determine the dominant language.

    Returns:
        list: e.g. [('zh', 5), ('ja', 2), ('en', 1)]
        as [(language, character count including punctuation)]
    """
    return LangSegment.getCounts()


def classify(text: str):
    """
    Compatibility interface.
    """
    return LangSegment.classify(text)


def printList(langlist):
    """
    Print the list of results.
    """
    print("\n===================【打印结果】===================")
    if langlist is None or len(langlist) == 0:
        print("无内容结果,No content result")
        return
    for line in langlist:
        print(line)
    pass
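

# Illustrative integration sketch (not part of the original file): a host TTS
# project typically routes each detected segment to a per-language frontend.
# `synthesize_zh` and `synthesize_en` below are hypothetical callbacks
# supplied by the host project.
#
#   setfilters(["zh", "en"])
#   for seg in getTexts("你好 world"):
#       handler = {"zh": synthesize_zh, "en": synthesize_en}.get(seg["lang"])
#       if handler is not None:
#           handler(seg["text"])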


def main():
    # -----------------------------------
    # Changelog: the new version of the word segmentation is more accurate.
    # -----------------------------------

    # Input example 1 (Japanese and Chinese):
    # text = "“昨日は雨が降った,音楽、映画。。。”你今天学习日语了吗?春は桜の季節です。语种分词是语音合成必不可少的环节。言語分詞は音声合成に欠かせない環節である!"

    # Input example 2 (Japanese and Chinese):
    # text = "欢迎来玩。東京,は日本の首都です。欢迎来玩. 太好了!"

    # Input example 3 (Japanese and Chinese):
    # text = "明日、私たちは海辺にバカンスに行きます。你会说日语吗:“中国語、話せますか” 你的日语真好啊!"

    # Input example 4 (Japanese, Chinese, Korean and English):
    # text = "你的名字叫<ja>佐々木?<ja>吗?韩语中的안녕 오빠读什么呢?あなたの体育の先生は誰ですか? 此次发布会带来了四款iPhone 15系列机型和三款Apple Watch等一系列新品,这次的iPad Air采用了LCD屏幕"

    # Experimental support: "fr" French, "vi" Vietnamese, "ru" Russian, "th" Thai.
    LangSegment.setfilters(["fr", "vi", "ja", "zh", "ko", "en", "ru", "th"])
    text = """
我喜欢在雨天里听音乐。
I enjoy listening to music on rainy days.
雨の日に音楽を聴くのが好きです。
비 오는 날에 음악을 듣는 것을 즐깁니다。
J'aime écouter de la musique les jours de pluie.
Tôi thích nghe nhạc vào những ngày mưa.
Мне нравится слушать музыку в дождливую погоду.
ฉันชอบฟังเพลงในวันที่ฝนตก
    """

    # Tokenize (integrating with a TTS project takes a single call):
    langlist = LangSegment.getTexts(text)
    printList(langlist)

    # Language statistics:
    print("\n===================【语种统计】===================")
    # Get the per-language results, sorted by character count in descending order.
    langCounts = LangSegment.getCounts()
    print(langCounts, "\n")

    # Take the dominant language of the content (language, character count
    # including punctuation) from the first entry.
    lang, count = langCounts[0]
    print(f"输入内容的主要语言为 = {lang} ,字数 = {count}")
    print("==================================================\n")

    # Tokenized output for input example 4 (lang = language, text = content):
    # ===================【打印结果】===================
    # {'lang': 'zh', 'text': '你的名字叫'}
    # {'lang': 'ja', 'text': '佐々木?'}
    # {'lang': 'zh', 'text': '吗?韩语中的'}
    # {'lang': 'ko', 'text': '안녕 오빠'}
    # {'lang': 'zh', 'text': '读什么呢?'}
    # {'lang': 'ja', 'text': 'あなたの体育の先生は誰ですか?'}
    # {'lang': 'zh', 'text': ' 此次发布会带来了四款'}
    # {'lang': 'en', 'text': 'i Phone '}
    # {'lang': 'zh', 'text': '15系列机型和三款'}
    # {'lang': 'en', 'text': 'Apple Watch '}
    # {'lang': 'zh', 'text': '等一系列新品,这次的'}
    # {'lang': 'en', 'text': 'i Pad Air '}
    # {'lang': 'zh', 'text': '采用了'}
    # {'lang': 'en', 'text': 'L C D '}
    # {'lang': 'zh', 'text': '屏幕'}
    # ===================【语种统计】===================
    # [('zh', 51), ('ja', 19), ('en', 18), ('ko', 5)]
    #
    # The main language of the input content is zh, with 51 characters.
    # ==================================================


if __name__ == "__main__":
    main()
src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/__init__.py
ADDED
@@ -0,0 +1,24 @@
from .LangSegment import (
    LangSegment,
    classify,
    getCounts,
    getEnablePreview,
    getfilters,
    getKeepPinyin,
    getLangMerge,
    getPriorityThreshold,
    getTexts,
    printList,
    setEnablePreview,
    setfilters,
    setKeepPinyin,
    setLangMerge,
    setPriorityThreshold,
)

# release
__version__ = "0.3.5"


# develop
__develop__ = "dev-0.0.1"
src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/utils/__init__.py
ADDED
File without changes
src/YingMusicSinger/utils/f5_tts/thirdparty/LangSegment/utils/num.py
ADDED
@@ -0,0 +1,332 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Digital processing from GPT_SoVITS num.py (thanks)
"""
Rules to verbalize numbers into Chinese characters.
https://zh.wikipedia.org/wiki/中文数字#現代中文
"""

import re
from collections import OrderedDict
from typing import List

DIGITS = {str(i): tran for i, tran in enumerate("零一二三四五六七八九")}
UNITS = OrderedDict(
    {
        1: "十",
        2: "百",
        3: "千",
        4: "万",
        8: "亿",
    }
)

COM_QUANTIFIERS = "(处|台|架|枚|趟|幅|平|方|堵|间|床|株|批|项|例|列|篇|栋|注|亩|封|艘|把|目|套|段|人|所|朵|匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|毫|厘|(公)分|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|小时|旬|纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|元|(亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|美|)元|(亿|千万|百万|万|千|百|十|)吨|(亿|千万|百万|万|千|百|)块|角|毛|分)"

# Fraction expressions
RE_FRAC = re.compile(r"(-?)(\d+)/(\d+)")


def replace_frac(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    nominator = match.group(2)
    denominator = match.group(3)
    sign: str = "负" if sign else ""
    nominator: str = num2str(nominator)
    denominator: str = num2str(denominator)
    result = f"{sign}{denominator}分之{nominator}"
    return result


# Percentage expressions
RE_PERCENTAGE = re.compile(r"(-?)(\d+(\.\d+)?)%")


def replace_percentage(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    percent = match.group(2)
    sign: str = "负" if sign else ""
    percent: str = num2str(percent)
    result = f"{sign}百分之{percent}"
    return result
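

# Illustrative sketch (not part of the original file): applying the two rules
# above with re.sub. Expected readings assume num2str as defined at the end of
# this module.
#
#   RE_FRAC.sub(replace_frac, "-3/4")                  # -> "负四分之三"
#   RE_PERCENTAGE.sub(replace_percentage, "涨了2.5%")   # -> "涨了百分之二点五"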


# Integer expressions
# Integers with a negative sign, e.g. -10
RE_INTEGER = re.compile(r"(-)" r"(\d+)")


def replace_negative_num(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    number = match.group(2)
    sign: str = "负" if sign else ""
    number: str = num2str(number)
    result = f"{sign}{number}"
    return result


# Identifier-style unsigned integers, e.g. 00078
RE_DEFAULT_NUM = re.compile(r"\d{3}\d*")


def replace_default_num(match):
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    number = match.group(0)
    return verbalize_digit(number, alt_one=True)


# Arithmetic: addition, subtraction, multiplication, division
# RE_ASMD = re.compile(
# r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\×÷=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
RE_ASMD = re.compile(
    r"((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))([\+\-\×÷=])((-?)((\d+)(\.\d+)?[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|(\.\d+[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*)|([A-Za-z][⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]*))"
)

asmd_map = {"+": "加", "-": "减", "×": "乘", "÷": "除", "=": "等于"}


def replace_asmd(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
    return result


# Exponents (superscript powers)
RE_POWER = re.compile(r"[⁰¹²³⁴⁵⁶⁷⁸⁹ˣʸⁿ]+")

power_map = {
    "⁰": "0",
    "¹": "1",
    "²": "2",
    "³": "3",
    "⁴": "4",
    "⁵": "5",
    "⁶": "6",
    "⁷": "7",
    "⁸": "8",
    "⁹": "9",
    "ˣ": "x",
    "ʸ": "y",
    "ⁿ": "n",
}


def replace_power(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    power_num = ""
    for m in match.group(0):
        power_num += power_map[m]
    result = "的" + power_num + "次方"
    return result


# Number expressions
# Pure decimals
RE_DECIMAL_NUM = re.compile(r"(-?)((\d+)(\.\d+))" r"|(\.(\d+))")
# Positive integer + quantifier
RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([多余几\+])?" + COM_QUANTIFIERS)
RE_NUMBER = re.compile(r"(-?)((\d+)(\.\d+)?)" r"|(\.(\d+))")


def replace_positive_quantifier(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    number = match.group(1)
    match_2 = match.group(2)
    if match_2 == "+":
        match_2 = "多"
    match_2: str = match_2 if match_2 else ""
    quantifiers: str = match.group(3)
    number: str = num2str(number)
    result = f"{number}{match_2}{quantifiers}"
    return result


def replace_number(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    sign = match.group(1)
    number = match.group(2)
    pure_decimal = match.group(5)
    if pure_decimal:
        result = num2str(pure_decimal)
    else:
        sign: str = "负" if sign else ""
        number: str = num2str(number)
        result = f"{sign}{number}"
    return result


# Range expressions
# match.group(1) and match.group(6) are copied from RE_NUMBER

RE_RANGE = re.compile(
    r"""
    (?<![\d\+\-\×÷=])      # negative lookbehind: no other digits or operators before the range
    ((-?)((\d+)(\.\d+)?))  # start of the range: negative or positive, integer or decimal
    [-~]                   # range separator
    ((-?)((\d+)(\.\d+)?))  # end of the range: negative or positive, integer or decimal
    (?![\d\+\-\×÷=])       # negative lookahead: no other digits or operators after the range
    """,
    re.VERBOSE,
)


def replace_range(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    first, second = match.group(1), match.group(6)
    first = RE_NUMBER.sub(replace_number, first)
    second = RE_NUMBER.sub(replace_number, second)
    result = f"{first}到{second}"
    return result


# "~" ranges with units, read as 至
RE_TO_RANGE = re.compile(
    r"((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|°C|℃|度|摄氏度|cm2|cm²|cm3|cm³|cm|db|ds|kg|km|m2|m²|m³|m3|ml|m|mm|s)"
)


def replace_to_range(match) -> str:
    """
    Args:
        match (re.Match)
    Returns:
        str
    """
    result = match.group(0).replace("~", "至")
    return result
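

# Illustrative sketch (not part of the original file):
#
#   RE_RANGE.sub(replace_range, "3~5个")          # -> "三到五个"
#   RE_TO_RANGE.sub(replace_to_range, "5%~10%")   # -> "5%至10%"
#   # (for RE_TO_RANGE only the separator is rewritten here; the digits are
#   # verbalized by the later number passes)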


def _get_value(value_string: str, use_zero: bool = True) -> List[str]:
    stripped = value_string.lstrip("0")
    if len(stripped) == 0:
        return []
    elif len(stripped) == 1:
        if use_zero and len(stripped) < len(value_string):
            return [DIGITS["0"], DIGITS[stripped]]
        else:
            return [DIGITS[stripped]]
    else:
        largest_unit = next(
            power for power in reversed(UNITS.keys()) if power < len(stripped)
        )
        first_part = value_string[:-largest_unit]
        second_part = value_string[-largest_unit:]
        return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(second_part)


def verbalize_cardinal(value_string: str) -> str:
    if not value_string:
        return ""

    # '000' -> '零' , '0' -> '零'
    value_string = value_string.lstrip("0")
    if len(value_string) == 0:
        return DIGITS["0"]

    result_symbols = _get_value(value_string)
    # a verbalized number starting with '一十*' is abbreviated as '十*'
    if (
        len(result_symbols) >= 2
        and result_symbols[0] == DIGITS["1"]
        and result_symbols[1] == UNITS[1]
    ):
        result_symbols = result_symbols[1:]
    return "".join(result_symbols)


def verbalize_digit(value_string: str, alt_one=False) -> str:
    result_symbols = [DIGITS[digit] for digit in value_string]
    result = "".join(result_symbols)
    if alt_one:
        result = result.replace("一", "幺")
    return result
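

# Illustrative sketch (not part of the original file), using the helpers above
# and num2str defined just below:
#
#   verbalize_cardinal("109")               # -> "一百零九"
#   verbalize_digit("1234", alt_one=True)   # -> "幺二三四" (telephone style)
#   num2str("3.20")                         # -> "三点二" (trailing zeros dropped)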


def num2str(value_string: str) -> str:
    integer_decimal = value_string.split(".")
    if len(integer_decimal) == 1:
        integer = integer_decimal[0]
        decimal = ""
    elif len(integer_decimal) == 2:
        integer, decimal = integer_decimal
    else:
        raise ValueError(
            f"The value string: '{value_string}' has more than one point in it."
        )

    result = verbalize_cardinal(integer)

    decimal = decimal.rstrip("0")
    if decimal:
        # '.22' is verbalized as '零点二二'
        # '3.20' is verbalized as '三点二'
        result = result if result else "零"
        result += "点" + verbalize_digit(decimal)
    return result


if __name__ == "__main__":
    text = ""
    text = num2str(text)
    print(text)
    pass
src/YingMusicSinger/utils/stable_audio_tools/__init__.py
ADDED
File without changes
src/YingMusicSinger/utils/stable_audio_tools/adp.py
ADDED
@@ -0,0 +1,1686 @@
# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
|
| 2 |
+
# License can be found in LICENSES/LICENSE_ADP.txt
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from inspect import isfunction
|
| 6 |
+
from math import ceil, floor, log, log2, pi
|
| 7 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from dac.nn.layers import Snake1d
|
| 12 |
+
from einops import rearrange, reduce, repeat
|
| 13 |
+
from einops.layers.torch import Rearrange
|
| 14 |
+
from einops_exts import rearrange_many
|
| 15 |
+
from packaging import version
|
| 16 |
+
from torch import Tensor, einsum
|
| 17 |
+
from torch.backends.cuda import sdp_kernel
|
| 18 |
+
from torch.nn import functional as F
|
| 19 |
+
|
| 20 |
+
"""
|
| 21 |
+
Utils
|
| 22 |
+
"""
|


class ConditionedSequential(nn.Module):
    def __init__(self, *modules):
        super().__init__()
        self.module_list = nn.ModuleList(*modules)

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
        for module in self.module_list:
            x = module(x, mapping)
        return x


T = TypeVar("T")


def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    if exists(val):
        return val
    return d() if isfunction(d) else d


def exists(val: Optional[T]) -> bool:
    return val is not None


def closest_power_2(x: float) -> int:
    exponent = log2(x)
    distance_fn = lambda z: abs(x - 2**z)  # noqa
    exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
    return 2 ** int(exponent_closest)
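
# For example, closest_power_2(9) computes log2(9) ~= 3.17, compares the two
# candidate exponents (floor -> 2**3 = 8, distance 1; ceil -> 2**4 = 16,
# distance 7), and returns 8.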


def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    return_dicts: Tuple[Dict, Dict] = ({}, {})
    for key in d.keys():
        no_prefix = int(not key.startswith(prefix))
        return_dicts[no_prefix][key] = d[key]
    return return_dicts


def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
    if keep_prefix:
        return kwargs_with_prefix, kwargs
    kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
    return kwargs_no_prefix, kwargs
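
# This is how the UNet below routes its **kwargs. For example,
# groupby("stft_", {"stft_num_fft": 1024, "channels": 2}) returns
# ({"num_fft": 1024}, {"channels": 2}): keys with the prefix are split out and
# stripped, while keep_prefix=True (used for "attention_" kwargs) retains it.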


"""
Convolutional Blocks
"""
import typing as tp

# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License
# License available in LICENSES/LICENSE_META.txt


def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
        1 2 3 4             # once you remove padding, we are missing one time step!
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))
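
# Worked example for get_extra_padding_for_conv1d: with length = 7,
# kernel_size = 4, stride = 2, padding_total = 2 we get
# n_frames = (7 - 4 + 2) / 2 + 1 = 3.5 and ideal_length =
# (ceil(3.5) - 1) * 2 + (4 - 2) = 8, so one extra sample of right padding
# makes the last window full.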


def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "constant",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, properly handling zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]


class Conv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, causal=False) -> Tensor:
        kernel_size = self.kernel_size[0]
        stride = self.stride[0]
        dilation = self.dilation[0]
        kernel_size = (
            kernel_size - 1
        ) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(
            x, kernel_size, stride, padding_total
        )
        if causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding))
        return super().forward(x)
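
# Padding example for the causal branch above: with kernel_size = 3,
# dilation = 2, stride = 1, the effective kernel size is (3 - 1) * 2 + 1 = 5
# and padding_total = 4; all four samples are padded on the left, so no output
# frame depends on future inputs.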


class ConvTranspose1d(nn.ConvTranspose1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, causal=False) -> Tensor:
        kernel_size = self.kernel_size[0]
        stride = self.stride[0]
        padding_total = kernel_size - stride

        y = super().forward(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if causal:
            padding_right = ceil(padding_total)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y


def Downsample1d(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"

    return Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=factor * kernel_multiplier + 1,
        stride=factor,
    )


def Upsample1d(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
    if factor == 1:
        return Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3)

    if use_nearest:
        return nn.Sequential(
            nn.Upsample(scale_factor=factor, mode="nearest"),
            Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
        )
    else:
        return ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=factor * 2,
            stride=factor,
        )
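
# Sizing example: Downsample1d with factor = 2 and the default
# kernel_multiplier = 2 is a stride-2 Conv1d with kernel size 2 * 2 + 1 = 5;
# Upsample1d with factor = 2 and use_nearest = False is a stride-2
# ConvTranspose1d with kernel size 4, halving and doubling the sequence length
# respectively.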


class ConvBlock1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        num_groups: int = 8,
        use_norm: bool = True,
        use_snake: bool = False,
    ) -> None:
        super().__init__()

        self.groupnorm = (
            nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
            if use_norm
            else nn.Identity()
        )

        if use_snake:
            self.activation = Snake1d(in_channels)
        else:
            self.activation = nn.SiLU()

        self.project = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def forward(
        self,
        x: Tensor,
        scale_shift: Optional[Tuple[Tensor, Tensor]] = None,
        causal=False,
    ) -> Tensor:
        x = self.groupnorm(x)
        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        x = self.activation(x)
        return self.project(x, causal=causal)


class MappingToScaleShift(nn.Module):
    def __init__(
        self,
        features: int,
        channels: int,
    ):
        super().__init__()

        self.to_scale_shift = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features=features, out_features=channels * 2),
        )

    def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
        scale_shift = self.to_scale_shift(mapping)
        scale_shift = rearrange(scale_shift, "b c -> b c 1")
        scale, shift = scale_shift.chunk(2, dim=1)
        return scale, shift
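
# FiLM-style conditioning: ConvBlock1d applies the pair as x * (scale + 1) + shift,
# so when the linear layer outputs zeros the modulation is the identity, which
# keeps conditioned blocks well-behaved at initialization.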


class ResnetBlock1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        use_norm: bool = True,
        use_snake: bool = False,
        num_groups: int = 8,
        context_mapping_features: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.use_mapping = exists(context_mapping_features)

        self.block1 = ConvBlock1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            use_norm=use_norm,
            num_groups=num_groups,
            use_snake=use_snake,
        )

        if self.use_mapping:
            assert exists(context_mapping_features)
            self.to_scale_shift = MappingToScaleShift(
                features=context_mapping_features, channels=out_channels
            )

        self.block2 = ConvBlock1d(
            in_channels=out_channels,
            out_channels=out_channels,
            use_norm=use_norm,
            num_groups=num_groups,
            use_snake=use_snake,
        )

        self.to_out = (
            Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(
        self, x: Tensor, mapping: Optional[Tensor] = None, causal=False
    ) -> Tensor:
        assert_message = "context mapping required if context_mapping_features > 0"
        assert not (self.use_mapping ^ exists(mapping)), assert_message

        h = self.block1(x, causal=causal)

        scale_shift = None
        if self.use_mapping:
            scale_shift = self.to_scale_shift(mapping)

        h = self.block2(h, scale_shift=scale_shift, causal=causal)

        return h + self.to_out(x)


class Patcher(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        context_mapping_features: Optional[int] = None,
        use_snake: bool = False,
    ):
        super().__init__()
        assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
        assert out_channels % patch_size == 0, assert_message
        self.patch_size = patch_size

        self.block = ResnetBlock1d(
            in_channels=in_channels,
            out_channels=out_channels // patch_size,
            num_groups=1,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

    def forward(
        self, x: Tensor, mapping: Optional[Tensor] = None, causal=False
    ) -> Tensor:
        x = self.block(x, mapping, causal=causal)
        x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
        return x


class Unpatcher(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        context_mapping_features: Optional[int] = None,
        use_snake: bool = False,
    ):
        super().__init__()
        assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
        assert in_channels % patch_size == 0, assert_message
        self.patch_size = patch_size

        self.block = ResnetBlock1d(
            in_channels=in_channels // patch_size,
            out_channels=out_channels,
            num_groups=1,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

    def forward(
        self, x: Tensor, mapping: Optional[Tensor] = None, causal=False
    ) -> Tensor:
        x = rearrange(x, "b (c p) l -> b c (l p)", p=self.patch_size)
        x = self.block(x, mapping, causal=causal)
        return x
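
# Shape example: Patcher with patch_size = 2 maps (b, c_in, 1024) through its
# resnet block to (b, c_out // 2, 1024), then folds pairs of time steps into
# channels to give (b, c_out, 512); Unpatcher applies the inverse rearrangement
# before its resnet block.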


"""
Attention Components
"""


def FeedForward(features: int, multiplier: int) -> nn.Module:
    mid_features = features * multiplier
    return nn.Sequential(
        nn.Linear(in_features=features, out_features=mid_features),
        nn.GELU(),
        nn.Linear(in_features=mid_features, out_features=features),
    )


def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
    b, ndim = sim.shape[0], mask.ndim
    if ndim == 3:
        mask = rearrange(mask, "b n m -> b 1 n m")
    if ndim == 2:
        mask = repeat(mask, "n m -> b 1 n m", b=b)
    max_neg_value = -torch.finfo(sim.dtype).max
    sim = sim.masked_fill(~mask, max_neg_value)
    return sim


def causal_mask(q: Tensor, k: Tensor) -> Tensor:
    b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
    mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
    mask = repeat(mask, "n m -> b n m", b=b)
    return mask
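
# For i = j = 3, causal_mask produces a lower-triangular boolean matrix:
#   [[True, False, False],
#    [True, True,  False],
#    [True, True,  True ]]
# so each query position attends only to itself and to earlier key positions.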


class AttentionBase(nn.Module):
    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
    ):
        super().__init__()
        self.scale = head_features**-0.5
        self.num_heads = num_heads
        mid_features = head_features * num_heads
        out_features = default(out_features, features)

        self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)

        self.use_flash = torch.cuda.is_available() and version.parse(
            torch.__version__
        ) >= version.parse("2.0.0")

        if not self.use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))

        if device_properties.major == 8 and device_properties.minor == 0:
            # Use flash attention for A100 GPUs
            self.sdp_kernel_config = (True, False, False)
        else:
            # Don't use flash attention for other GPUs
            self.sdp_kernel_config = (False, True, True)

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor] = None,
        is_causal: bool = False,
    ) -> Tensor:
        # Split heads
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)

        if not self.use_flash:
            if is_causal and not exists(mask):
                # Mask out future tokens for causal attention
                mask = causal_mask(q, k)

            # Compute similarity matrix and add eventual mask
            sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
            sim = add_mask(sim, mask) if exists(mask) else sim

            # Get attention matrix with softmax
            attn = sim.softmax(dim=-1, dtype=torch.float32)

            # Compute values
            out = einsum("... n m, ... m d -> ... n d", attn, v)
        else:
            with sdp_kernel(*self.sdp_kernel_config):
                out = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=mask, is_causal=is_causal
                )

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
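
# Note on sdp_kernel_config: the three booleans are passed positionally to
# torch.backends.cuda.sdp_kernel, whose torch 2.x flags are (enable_flash,
# enable_math, enable_mem_efficient) -- flash-only on A100-class devices
# (compute capability 8.0), math + memory-efficient kernels elsewhere.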


class Attention(nn.Module):
    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
        context_features: Optional[int] = None,
        causal: bool = False,
    ):
        super().__init__()
        self.context_features = context_features
        self.causal = causal
        mid_features = head_features * num_heads
        context_features = default(context_features, features)

        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(context_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=mid_features, bias=False
        )
        self.to_kv = nn.Linear(
            in_features=context_features, out_features=mid_features * 2, bias=False
        )
        self.attention = AttentionBase(
            features,
            num_heads=num_heads,
            head_features=head_features,
            out_features=out_features,
        )

    def forward(
        self,
        x: Tensor,  # [b, n, c]
        context: Optional[Tensor] = None,  # [b, m, d]
        context_mask: Optional[Tensor] = None,  # [b, m], false is masked
        causal: Optional[bool] = False,
    ) -> Tensor:
        assert_message = "You must provide a context when using context_features"
        assert not self.context_features or exists(context), assert_message
        # Use context if provided
        context = default(context, x)
        # Normalize then compute q from input and k,v from context
        x, context = self.norm(x), self.norm_context(context)

        q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))

        if exists(context_mask):
            # Mask out cross-attention for padding tokens
            mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
            k, v = k * mask, v * mask

        # Compute and return attention
        return self.attention(q, k, v, is_causal=self.causal or causal)


"""
Transformer Blocks
"""


class TransformerBlock(nn.Module):
    def __init__(
        self,
        features: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        context_features: Optional[int] = None,
    ):
        super().__init__()

        self.use_cross_attention = exists(context_features) and context_features > 0

        self.attention = Attention(
            features=features, num_heads=num_heads, head_features=head_features
        )

        if self.use_cross_attention:
            self.cross_attention = Attention(
                features=features,
                num_heads=num_heads,
                head_features=head_features,
                context_features=context_features,
            )

        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(
        self,
        x: Tensor,
        *,
        context: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False,
    ) -> Tensor:
        x = self.attention(x, causal=causal) + x
        if self.use_cross_attention:
            x = self.cross_attention(x, context=context, context_mask=context_mask) + x
        x = self.feed_forward(x) + x
        return x


"""
Transformers
"""


class Transformer1d(nn.Module):
    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        context_features: Optional[int] = None,
    ):
        super().__init__()

        self.to_in = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
            Conv1d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=1,
            ),
            Rearrange("b c t -> b t c"),
        )

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    features=channels,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    context_features=context_features,
                )
                for i in range(num_layers)
            ]
        )

        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            Conv1d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=1,
            ),
        )

    def forward(
        self,
        x: Tensor,
        *,
        context: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
        causal=False,
    ) -> Tensor:
        x = self.to_in(x)
        for block in self.blocks:
            x = block(x, context=context, context_mask=context_mask, causal=causal)
        x = self.to_out(x)
        return x
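
# Data layout: to_in group-normalizes, applies a 1x1 convolution, and transposes
# (b, c, t) -> (b, t, c) for the attention blocks; to_out transposes back and
# projects again, so Transformer1d preserves the input shape end to end.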


"""
Time Embeddings
"""


class SinusoidalEmbedding(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        device, half_dim = x.device, self.dim // 2
        emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
        return torch.cat((emb.sin(), emb.cos()), dim=-1)


class LearnedPositionalEmbedding(nn.Module):
    """Used for continuous time"""

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x: Tensor) -> Tensor:
        x = rearrange(x, "b -> b 1")
        freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        fouriered = torch.cat((x, fouriered), dim=-1)
        return fouriered


def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    return nn.Sequential(
        LearnedPositionalEmbedding(dim),
        nn.Linear(in_features=dim + 1, out_features=out_features),
    )
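
# LearnedPositionalEmbedding concatenates the raw input with dim // 2 sine and
# dim // 2 cosine features, yielding dim + 1 values per element -- which is why
# the Linear layer above takes in_features=dim + 1.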


"""
Encoder/Decoder Components
"""


class DownsampleBlock1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        factor: int,
        num_groups: int,
        num_layers: int,
        kernel_multiplier: int = 2,
        use_pre_downsample: bool = True,
        use_skip: bool = False,
        use_snake: bool = False,
        extract_channels: int = 0,
        context_channels: int = 0,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
    ):
        super().__init__()
        self.use_pre_downsample = use_pre_downsample
        self.use_skip = use_skip
        self.use_transformer = num_transformer_blocks > 0
        self.use_extract = extract_channels > 0
        self.use_context = context_channels > 0

        channels = out_channels if use_pre_downsample else in_channels

        self.downsample = Downsample1d(
            in_channels=in_channels,
            out_channels=out_channels,
            factor=factor,
            kernel_multiplier=kernel_multiplier,
        )

        self.blocks = nn.ModuleList(
            [
                ResnetBlock1d(
                    in_channels=channels + context_channels if i == 0 else channels,
                    out_channels=channels,
                    num_groups=num_groups,
                    context_mapping_features=context_mapping_features,
                    use_snake=use_snake,
                )
                for i in range(num_layers)
            ]
        )

        if self.use_transformer:
            assert (exists(attention_heads) or exists(attention_features)) and exists(
                attention_multiplier
            )

            if attention_features is None and attention_heads is not None:
                attention_features = channels // attention_heads

            if attention_heads is None and attention_features is not None:
                attention_heads = channels // attention_features

            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=attention_heads,
                head_features=attention_features,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        if self.use_extract:
            num_extract_groups = min(num_groups, extract_channels)
            self.to_extracted = ResnetBlock1d(
                in_channels=out_channels,
                out_channels=extract_channels,
                num_groups=num_extract_groups,
                use_snake=use_snake,
            )

    def forward(
        self,
        x: Tensor,
        *,
        mapping: Optional[Tensor] = None,
        channels: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False,
    ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
        if self.use_pre_downsample:
            x = self.downsample(x)

        if self.use_context and exists(channels):
            x = torch.cat([x, channels], dim=1)

        skips = []
        for block in self.blocks:
            x = block(x, mapping=mapping, causal=causal)
            skips += [x] if self.use_skip else []

        if self.use_transformer:
            x = self.transformer(
                x, context=embedding, context_mask=embedding_mask, causal=causal
            )
            skips += [x] if self.use_skip else []

        if not self.use_pre_downsample:
            x = self.downsample(x)

        if self.use_extract:
            extracted = self.to_extracted(x)
            return x, extracted

        return (x, skips) if self.use_skip else x


class UpsampleBlock1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        factor: int,
        num_layers: int,
        num_groups: int,
        use_nearest: bool = False,
        use_pre_upsample: bool = False,
        use_skip: bool = False,
        use_snake: bool = False,
        skip_channels: int = 0,
        use_skip_scale: bool = False,
        extract_channels: int = 0,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
    ):
        super().__init__()

        self.use_extract = extract_channels > 0
        self.use_pre_upsample = use_pre_upsample
        self.use_transformer = num_transformer_blocks > 0
        self.use_skip = use_skip
        self.skip_scale = 2**-0.5 if use_skip_scale else 1.0

        channels = out_channels if use_pre_upsample else in_channels

        self.blocks = nn.ModuleList(
            [
                ResnetBlock1d(
                    in_channels=channels + skip_channels,
                    out_channels=channels,
                    num_groups=num_groups,
                    context_mapping_features=context_mapping_features,
                    use_snake=use_snake,
                )
                for _ in range(num_layers)
            ]
        )

        if self.use_transformer:
            assert (exists(attention_heads) or exists(attention_features)) and exists(
                attention_multiplier
            )

            if attention_features is None and attention_heads is not None:
                attention_features = channels // attention_heads

            if attention_heads is None and attention_features is not None:
                attention_heads = channels // attention_features

            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=attention_heads,
                head_features=attention_features,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        self.upsample = Upsample1d(
            in_channels=in_channels,
            out_channels=out_channels,
            factor=factor,
            use_nearest=use_nearest,
        )

        if self.use_extract:
            num_extract_groups = min(num_groups, extract_channels)
            self.to_extracted = ResnetBlock1d(
                in_channels=out_channels,
                out_channels=extract_channels,
                num_groups=num_extract_groups,
                use_snake=use_snake,
            )

    def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
        return torch.cat([x, skip * self.skip_scale], dim=1)

    def forward(
        self,
        x: Tensor,
        *,
        skips: Optional[List[Tensor]] = None,
        mapping: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False,
    ) -> Union[Tuple[Tensor, Tensor], Tensor]:
        if self.use_pre_upsample:
            x = self.upsample(x)

        for block in self.blocks:
            x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
            x = block(x, mapping=mapping, causal=causal)

        if self.use_transformer:
            x = self.transformer(
                x, context=embedding, context_mask=embedding_mask, causal=causal
            )

        if not self.use_pre_upsample:
            x = self.upsample(x)

        if self.use_extract:
            extracted = self.to_extracted(x)
            return x, extracted

        return x


class BottleneckBlock1d(nn.Module):
    def __init__(
        self,
        channels: int,
        *,
        num_groups: int,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
        use_snake: bool = False,
    ):
        super().__init__()
        self.use_transformer = num_transformer_blocks > 0

        self.pre_block = ResnetBlock1d(
            in_channels=channels,
            out_channels=channels,
            num_groups=num_groups,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

        if self.use_transformer:
            assert (exists(attention_heads) or exists(attention_features)) and exists(
                attention_multiplier
            )

            if attention_features is None and attention_heads is not None:
                attention_features = channels // attention_heads

            if attention_heads is None and attention_features is not None:
                attention_heads = channels // attention_features

            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=attention_heads,
                head_features=attention_features,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        self.post_block = ResnetBlock1d(
            in_channels=channels,
            out_channels=channels,
            num_groups=num_groups,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

    def forward(
        self,
        x: Tensor,
        *,
        mapping: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False,
    ) -> Tensor:
        x = self.pre_block(x, mapping=mapping, causal=causal)
        if self.use_transformer:
            x = self.transformer(
                x, context=embedding, context_mask=embedding_mask, causal=causal
            )
        x = self.post_block(x, mapping=mapping, causal=causal)
        return x
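
# Head bookkeeping shared by the three blocks above: given channels = 512 and
# attention_features = 64, attention_heads is derived as 512 // 64 = 8 (and
# features are derived symmetrically when only the head count is given), so one
# of the two attention hyperparameters may be left unset.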


"""
UNet
"""


class UNet1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        channels: int,
        multipliers: Sequence[int],
        factors: Sequence[int],
        num_blocks: Sequence[int],
        attentions: Sequence[int],
        patch_size: int = 1,
        resnet_groups: int = 8,
        use_context_time: bool = True,
        kernel_multiplier_downsample: int = 2,
        use_nearest_upsample: bool = False,
        use_skip_scale: bool = True,
        use_snake: bool = False,
        use_stft: bool = False,
        use_stft_context: bool = False,
        out_channels: Optional[int] = None,
        context_features: Optional[int] = None,
        context_features_multiplier: int = 4,
        context_channels: Optional[Sequence[int]] = None,
        context_embedding_features: Optional[int] = None,
        **kwargs,
    ):
        super().__init__()
        out_channels = default(out_channels, in_channels)
        context_channels = list(default(context_channels, []))
        num_layers = len(multipliers) - 1
        use_context_features = exists(context_features)
        use_context_channels = len(context_channels) > 0
        context_mapping_features = None

        attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)

        self.num_layers = num_layers
        self.use_context_time = use_context_time
        self.use_context_features = use_context_features
        self.use_context_channels = use_context_channels
        self.use_stft = use_stft
        self.use_stft_context = use_stft_context

        self.context_features = context_features
        context_channels_pad_length = num_layers + 1 - len(context_channels)
        context_channels = context_channels + [0] * context_channels_pad_length
        self.context_channels = context_channels
        self.context_embedding_features = context_embedding_features

        if use_context_channels:
            has_context = [c > 0 for c in context_channels]
            self.has_context = has_context
            self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]

        assert (
            len(factors) == num_layers
            and len(attentions) >= num_layers
            and len(num_blocks) == num_layers
        )

        if use_context_time or use_context_features:
            context_mapping_features = channels * context_features_multiplier

            self.to_mapping = nn.Sequential(
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
                nn.Linear(context_mapping_features, context_mapping_features),
                nn.GELU(),
            )

        if use_context_time:
            assert exists(context_mapping_features)
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(
                    dim=channels, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_context_features:
            assert exists(context_features) and exists(context_mapping_features)
            self.to_features = nn.Sequential(
                nn.Linear(
                    in_features=context_features, out_features=context_mapping_features
                ),
                nn.GELU(),
            )

        if use_stft:
            stft_kwargs, kwargs = groupby("stft_", kwargs)
            assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
            stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
            in_channels *= stft_channels
            out_channels *= stft_channels
            context_channels[0] *= stft_channels if use_stft_context else 1
            assert exists(in_channels) and exists(out_channels)
            self.stft = STFT(**stft_kwargs)

        assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"

        self.to_in = Patcher(
            in_channels=in_channels + context_channels[0],
            out_channels=channels * multipliers[0],
            patch_size=patch_size,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

        self.downsamples = nn.ModuleList(
            [
                DownsampleBlock1d(
                    in_channels=channels * multipliers[i],
                    out_channels=channels * multipliers[i + 1],
                    context_mapping_features=context_mapping_features,
                    context_channels=context_channels[i + 1],
                    context_embedding_features=context_embedding_features,
                    num_layers=num_blocks[i],
                    factor=factors[i],
                    kernel_multiplier=kernel_multiplier_downsample,
                    num_groups=resnet_groups,
                    use_pre_downsample=True,
                    use_skip=True,
                    use_snake=use_snake,
                    num_transformer_blocks=attentions[i],
                    **attention_kwargs,
                )
                for i in range(num_layers)
            ]
        )

        self.bottleneck = BottleneckBlock1d(
            channels=channels * multipliers[-1],
            context_mapping_features=context_mapping_features,
            context_embedding_features=context_embedding_features,
            num_groups=resnet_groups,
            num_transformer_blocks=attentions[-1],
            use_snake=use_snake,
            **attention_kwargs,
        )

        self.upsamples = nn.ModuleList(
            [
                UpsampleBlock1d(
                    in_channels=channels * multipliers[i + 1],
                    out_channels=channels * multipliers[i],
                    context_mapping_features=context_mapping_features,
                    context_embedding_features=context_embedding_features,
                    num_layers=num_blocks[i] + (1 if attentions[i] else 0),
                    factor=factors[i],
                    use_nearest=use_nearest_upsample,
                    num_groups=resnet_groups,
                    use_skip_scale=use_skip_scale,
                    use_pre_upsample=False,
                    use_skip=True,
                    use_snake=use_snake,
                    skip_channels=channels * multipliers[i + 1],
                    num_transformer_blocks=attentions[i],
                    **attention_kwargs,
                )
                for i in reversed(range(num_layers))
            ]
        )

        self.to_out = Unpatcher(
            in_channels=channels * multipliers[0],
            out_channels=out_channels,
            patch_size=patch_size,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

    def get_channels(
        self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
    ) -> Optional[Tensor]:
        """Gets context channels at `layer` and checks that shape is correct"""
        use_context_channels = self.use_context_channels and self.has_context[layer]
        if not use_context_channels:
            return None
        assert exists(channels_list), "Missing context"
        # Get channels index (skipping zero channel contexts)
        channels_id = self.channels_ids[layer]
        # Get channels
        channels = channels_list[channels_id]
        message = f"Missing context for layer {layer} at index {channels_id}"
        assert exists(channels), message
        # Check channels
        num_channels = self.context_channels[layer]
        message = f"Expected context with {num_channels} channels at idx {channels_id}"
        assert channels.shape[1] == num_channels, message
        # STFT channels if requested
        channels = self.stft.encode1d(channels) if self.use_stft_context else channels  # type: ignore # noqa
        return channels

    def get_mapping(
        self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
    ) -> Optional[Tensor]:
        """Combines context time features and features into mapping"""
        items, mapping = [], None
        # Compute time features
        if self.use_context_time:
            assert_message = "use_context_time=True but no time features provided"
            assert exists(time), assert_message
            items += [self.to_time(time)]
        # Compute features
        if self.use_context_features:
            assert_message = "context_features exists but no features provided"
            assert exists(features), assert_message
            items += [self.to_features(features)]
        # Compute joint mapping
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)
        return mapping

    def forward(
        self,
        x: Tensor,
        time: Optional[Tensor] = None,
        *,
        features: Optional[Tensor] = None,
        channels_list: Optional[Sequence[Tensor]] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False,
    ) -> Tensor:
        channels = self.get_channels(channels_list, layer=0)
        # Apply stft if required
        x = self.stft.encode1d(x) if self.use_stft else x  # type: ignore
        # Concat context channels at layer 0 if provided
        x = torch.cat([x, channels], dim=1) if exists(channels) else x
        # Compute mapping from time and features
        mapping = self.get_mapping(time, features)
        x = self.to_in(x, mapping, causal=causal)
        skips_list = [x]

        for i, downsample in enumerate(self.downsamples):
            channels = self.get_channels(channels_list, layer=i + 1)
            x, skips = downsample(
                x,
                mapping=mapping,
                channels=channels,
                embedding=embedding,
                embedding_mask=embedding_mask,
                causal=causal,
            )
            skips_list += [skips]

        x = self.bottleneck(
            x,
            mapping=mapping,
            embedding=embedding,
            embedding_mask=embedding_mask,
            causal=causal,
        )

        for i, upsample in enumerate(self.upsamples):
            skips = skips_list.pop()
            x = upsample(
                x,
                skips=skips,
                mapping=mapping,
                embedding=embedding,
                embedding_mask=embedding_mask,
                causal=causal,
            )

        x += skips_list.pop()
        x = self.to_out(x, mapping, causal=causal)
        x = self.stft.decode1d(x) if self.use_stft else x

        return x
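
# Configuration invariants asserted in __init__: with num_layers =
# len(multipliers) - 1, both `factors` and `num_blocks` need exactly num_layers
# entries and `attentions` at least num_layers (its last entry also configures
# the bottleneck). E.g. multipliers=(1, 2, 4), factors=(2, 4),
# num_blocks=(2, 2), attentions=(0, 1, 1) gives a two-stage UNet whose deeper
# stage and bottleneck each use one transformer block.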


""" Conditioning Modules """


class FixedEmbedding(nn.Module):
    def __init__(self, max_length: int, features: int):
        super().__init__()
        self.max_length = max_length
        self.embedding = nn.Embedding(max_length, features)

    def forward(self, x: Tensor) -> Tensor:
        batch_size, length, device = *x.shape[0:2], x.device
        assert_message = "Input sequence length must be <= max_length"
        assert length <= self.max_length, assert_message
        position = torch.arange(length, device=device)
        fixed_embedding = self.embedding(position)
        fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
        return fixed_embedding


def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
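
# rand_bool draws an independent Bernoulli(proba) boolean per element.
# UNetCFG1d below calls it with shape (b, 1, 1) to replace whole conditioning
# sequences with the learned fixed embedding at random during training, which
# is what enables classifier-free guidance at sampling time.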


class UNetCFG1d(UNet1d):
    """UNet1d with Classifier-Free Guidance"""

    def __init__(
        self,
        context_embedding_max_length: int,
        context_embedding_features: int,
        use_xattn_time: bool = False,
        **kwargs,
    ):
        super().__init__(
            context_embedding_features=context_embedding_features, **kwargs
        )

        self.use_xattn_time = use_xattn_time

        if use_xattn_time:
            assert exists(context_embedding_features)
            self.to_time_embedding = nn.Sequential(
                TimePositionalEmbedding(
                    dim=kwargs["channels"], out_features=context_embedding_features
                ),
                nn.GELU(),
            )

            context_embedding_max_length += 1  # Add one for time embedding

        self.fixed_embedding = FixedEmbedding(
            max_length=context_embedding_max_length, features=context_embedding_features
        )

    def forward(  # type: ignore
        self,
        x: Tensor,
        time: Tensor,
        *,
        embedding: Tensor,
        embedding_mask: Optional[Tensor] = None,
        embedding_scale: float = 1.0,
        embedding_mask_proba: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        scale_phi: float = 0.4,
        negative_embedding: Optional[Tensor] = None,
        negative_embedding_mask: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        b, device = embedding.shape[0], embedding.device

        if self.use_xattn_time:
            embedding = torch.cat(
                [embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1
            )

            if embedding_mask is not None:
                embedding_mask = torch.cat(
                    [embedding_mask, torch.ones((b, 1), device=device)], dim=1
                )

        fixed_embedding = self.fixed_embedding(embedding)

        if embedding_mask_proba > 0.0:
            # Randomly mask embedding
            batch_mask = rand_bool(
                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
            )
            embedding = torch.where(batch_mask, fixed_embedding, embedding)

        if embedding_scale != 1.0:
            if batch_cfg:
                batch_x = torch.cat([x, x], dim=0)
                batch_time = torch.cat([time, time], dim=0)

                if negative_embedding is not None:
                    if negative_embedding_mask is not None:
                        negative_embedding_mask = negative_embedding_mask.to(
                            torch.bool
                        ).unsqueeze(2)

                        negative_embedding = torch.where(
                            negative_embedding_mask, negative_embedding, fixed_embedding
                        )

                    batch_embed = torch.cat([embedding, negative_embedding], dim=0)

                else:
                    batch_embed = torch.cat([embedding, fixed_embedding], dim=0)

                batch_mask = None
                if embedding_mask is not None:
                    batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)

                batch_features = None
                features = kwargs.pop("features", None)
                if self.use_context_features:
                    batch_features = torch.cat([features, features], dim=0)

                batch_channels = None
                channels_list = kwargs.pop("channels_list", None)
                if self.use_context_channels:
                    batch_channels = []
                    for channels in channels_list:
                        batch_channels += [torch.cat([channels, channels], dim=0)]

                # Compute both normal and fixed embedding outputs
                batch_out = super().forward(
                    batch_x,
                    batch_time,
                    embedding=batch_embed,
                    embedding_mask=batch_mask,
                    features=batch_features,
                    channels_list=batch_channels,
                    **kwargs,
                )
                out, out_masked = batch_out.chunk(2, dim=0)

            else:
                # Compute both normal and fixed embedding outputs
                out = super().forward(
                    x,
                    time,
                    embedding=embedding,
                    embedding_mask=embedding_mask,
                    **kwargs,
                )
                out_masked = super().forward(
                    x,
                    time,
                    embedding=fixed_embedding,
                    embedding_mask=embedding_mask,
                    **kwargs,
                )

            out_cfg = out_masked + (out - out_masked) * embedding_scale

            if rescale_cfg:
                out_std = out.std(dim=1, keepdim=True)
                out_cfg_std = out_cfg.std(dim=1, keepdim=True)

                return (
                    scale_phi * (out_cfg * (out_std / out_cfg_std))
                    + (1 - scale_phi) * out_cfg
                )

            else:
                return out_cfg

        else:
            return super().forward(
                x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs
            )
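
# Guidance recap for the forward pass above: with conditional output `out` and
# the output under the learned fixed (null) embedding `out_masked`, the guided
# estimate is
#   out_cfg = out_masked + embedding_scale * (out - out_masked),
# the standard classifier-free guidance combination. The rescale_cfg branch
# renormalizes out_cfg's standard deviation (taken over the channel dimension)
# toward that of the conditional output and blends with scale_phi, which
# appears to follow the CFG-rescale trick used to counteract over-saturation at
# high guidance scales.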


class UNetNCCA1d(UNet1d):
    """UNet1d with Noise Channel Conditioning Augmentation"""

    def __init__(self, context_features: int, **kwargs):
        super().__init__(context_features=context_features, **kwargs)
        self.embedder = NumberEmbedder(features=context_features)

    def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
        x = x if torch.is_tensor(x) else torch.tensor(x)
        return x.expand(shape)

    def forward(  # type: ignore
        self,
        x: Tensor,
        time: Tensor,
        *,
        channels_list: Sequence[Tensor],
        channels_augmentation: Union[
            bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
        ] = False,
        channels_scale: Union[
            float, Sequence[float], Sequence[Sequence[float]], Tensor
        ] = 0,
        **kwargs,
    ) -> Tensor:
        b, n = x.shape[0], len(channels_list)
        channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
        channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)

        # Augmentation (for each channel list item)
        for i in range(n):
            scale = channels_scale[:, i] * channels_augmentation[:, i]
            scale = rearrange(scale, "b -> b 1 1")
            item = channels_list[i]
            channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale)  # type: ignore # noqa

        # Scale embedding (sum reduction if more than one channel list item)
        channels_scale_emb = self.embedder(channels_scale)
        channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")

        return super().forward(
            x=x,
            time=time,
            channels_list=channels_list,
            features=channels_scale_emb,
            **kwargs,
        )
| 1556 |
+
|
| 1557 |
+
|
| 1558 |
+
class UNetAll1d(UNetCFG1d, UNetNCCA1d):
|
| 1559 |
+
def __init__(self, *args, **kwargs):
|
| 1560 |
+
super().__init__(*args, **kwargs)
|
| 1561 |
+
|
| 1562 |
+
def forward(self, *args, **kwargs): # type: ignore
|
| 1563 |
+
return UNetCFG1d.forward(self, *args, **kwargs)
|
| 1564 |
+
|
| 1565 |
+
|
| 1566 |
+
def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
|
| 1567 |
+
if type == "base":
|
| 1568 |
+
return UNet1d(**kwargs)
|
| 1569 |
+
elif type == "all":
|
| 1570 |
+
return UNetAll1d(**kwargs)
|
| 1571 |
+
elif type == "cfg":
|
| 1572 |
+
return UNetCFG1d(**kwargs)
|
| 1573 |
+
elif type == "ncca":
|
| 1574 |
+
return UNetNCCA1d(**kwargs)
|
| 1575 |
+
else:
|
| 1576 |
+
raise ValueError(f"Unknown XUNet1d type: {type}")


class NumberEmbedder(nn.Module):
    def __init__(
        self,
        features: int,
        dim: int = 256,
    ):
        super().__init__()
        self.features = features
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        if not torch.is_tensor(x):
            device = next(self.embedding.parameters()).device
            x = torch.tensor(x, device=device)
        assert isinstance(x, Tensor)
        shape = x.shape
        x = rearrange(x, "... -> (...)")
        embedding = self.embedding(x)
        x = embedding.view(*shape, self.features)
        return x  # type: ignore


"""
Audio Transforms
"""


class STFT(nn.Module):
    """Helper for torch stft and istft"""

    def __init__(
        self,
        num_fft: int = 1023,
        hop_length: int = 256,
        window_length: Optional[int] = None,
        length: Optional[int] = None,
        use_complex: bool = False,
    ):
        super().__init__()
        self.num_fft = num_fft
        self.hop_length = default(hop_length, floor(num_fft // 4))
        self.window_length = default(window_length, num_fft)
        self.length = length
        self.register_buffer("window", torch.hann_window(self.window_length))
        self.use_complex = use_complex

    def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
        b = wave.shape[0]
        wave = rearrange(wave, "b c t -> (b c) t")

        stft = torch.stft(
            wave,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            return_complex=True,
            normalized=True,
        )

        if self.use_complex:
            # Returns real and imaginary
            stft_a, stft_b = stft.real, stft.imag
        else:
            # Returns magnitude and phase matrices
            magnitude, phase = torch.abs(stft), torch.angle(stft)
            stft_a, stft_b = magnitude, phase

        return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)

    def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
        b, l = stft_a.shape[0], stft_a.shape[-1]  # noqa
        length = closest_power_2(l * self.hop_length)

        stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")

        if self.use_complex:
            real, imag = stft_a, stft_b
        else:
            magnitude, phase = stft_a, stft_b
            real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)

        stft = torch.stack([real, imag], dim=-1)

        wave = torch.istft(
            stft,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            length=default(self.length, length),
            normalized=True,
        )

        return rearrange(wave, "(b c) t -> b c t", b=b)

    def encode1d(
        self, wave: Tensor, stacked: bool = True
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        stft_a, stft_b = self.encode(wave)
        stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
        return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)

    def decode1d(self, stft_pair: Tensor) -> Tensor:
        f = self.num_fft // 2 + 1
        stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
        stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
        return self.decode(stft_a, stft_b)
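As a usage note for the class above, encode/decode are designed to roughly round-trip; a sketch, assuming the surrounding adp.py helpers (default, closest_power_2, rearrange_many) are importable. Note the output length is snapped via closest_power_2 unless `length` is set in the constructor:

import torch

stft = STFT(num_fft=1023, hop_length=256)  # magnitude/phase mode by default
wave = torch.randn(1, 2, 32768)            # (batch, channels, samples)
mag, phase = stft.encode(wave)             # each (b, c, freq_bins, frames)
recon = stft.decode(mag, phase)            # (b, c, samples)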

src/YingMusicSinger/utils/stable_audio_tools/autoencoders.py
ADDED
@@ -0,0 +1,975 @@
import math
from typing import Any, Dict, Literal

import numpy as np
import torch
from alias_free_torch import Activation1d
from dac.nn.layers import WNConv1d, WNConvTranspose1d
from torch import nn
from torch.nn import functional as F
from torchaudio import transforms as T

# from ..inference.sampling import sample
# from ..inference.utils import prepare_audio
# NOTE: `sample` and `prepare_audio` are still referenced below
# (DiffusionAutoencoder.decode and preprocess_audio_list_for_encoder);
# those code paths will raise NameError while these imports stay commented out.
from .blocks import SnakeBeta
from .bottleneck import Bottleneck, DiscreteBottleneck
from .diffusion import (
    ConditionedDiffusionModel,
    DAU1DCondWrapper,
    DiTWrapper,
    UNet1DCondWrapper,
)
from .factory import create_bottleneck_from_config, create_pretransform_from_config
from .pretransforms import Pretransform


def checkpoint(function, *args, **kwargs):
    kwargs.setdefault("use_reentrant", False)
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)


def get_activation(
    activation: Literal["elu", "snake", "none"], antialias=False, channels=None
) -> nn.Module:
    if activation == "elu":
        act = nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        act = Activation1d(act)

    return act


class ResidualUnit(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        dilation,
        use_snake=False,
        antialias_activation=False,
    ):
        super().__init__()

        self.dilation = dilation

        padding = (dilation * (7 - 1)) // 2

        self.layers = nn.Sequential(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=out_channels,
            ),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=7,
                dilation=dilation,
                padding=padding,
            ),
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=out_channels,
            ),
            WNConv1d(
                in_channels=out_channels, out_channels=out_channels, kernel_size=1
            ),
        )

    def forward(self, x):
        res = x

        # x = checkpoint(self.layers, x)
        x = self.layers(x)

        return x + res


class EncoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False,
    ):
        super().__init__()

        self.layers = nn.Sequential(
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=1,
                use_snake=use_snake,
            ),
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=3,
                use_snake=use_snake,
            ),
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=9,
                use_snake=use_snake,
            ),
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels,
            ),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.layers(x)


class DecoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False,
    ):
        super().__init__()

        if use_nearest_upsample:
            upsample_layer = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=2 * stride,
                    stride=1,
                    bias=False,
                    padding="same",
                ),
            )
        else:
            upsample_layer = WNConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )

        self.layers = nn.Sequential(
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=in_channels,
            ),
            upsample_layer,
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=1,
                use_snake=use_snake,
            ),
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=3,
                use_snake=use_snake,
            ),
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=9,
                use_snake=use_snake,
            ),
        )

    def forward(self, x):
        return self.layers(x)


class OobleckEncoder(nn.Module):
    def __init__(
        self,
        in_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake=False,
        antialias_activation=False,
    ):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(
                in_channels=in_channels,
                out_channels=c_mults[0] * channels,
                kernel_size=7,
                padding=3,
            )
        ]

        for i in range(self.depth - 1):
            layers += [
                EncoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i + 1] * channels,
                    stride=strides[i],
                    use_snake=use_snake,
                )
            ]

        layers += [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=c_mults[-1] * channels,
            ),
            WNConv1d(
                in_channels=c_mults[-1] * channels,
                out_channels=latent_dim,
                kernel_size=3,
                padding=1,
            ),
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class OobleckDecoder(nn.Module):
    def __init__(
        self,
        out_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=[1, 2, 4, 8],
        strides=[2, 4, 8, 8],
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False,
        final_tanh=True,
    ):
        super().__init__()

        c_mults = [1] + c_mults

        self.depth = len(c_mults)

        layers = [
            WNConv1d(
                in_channels=latent_dim,
                out_channels=c_mults[-1] * channels,
                kernel_size=7,
                padding=3,
            ),
        ]

        for i in range(self.depth - 1, 0, -1):
            layers += [
                DecoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i - 1] * channels,
                    stride=strides[i - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample,
                )
            ]

        layers += [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=c_mults[0] * channels,
            ),
            WNConv1d(
                in_channels=c_mults[0] * channels,
                out_channels=out_channels,
                kernel_size=7,
                padding=3,
                bias=False,
            ),
            nn.Tanh() if final_tanh else nn.Identity(),
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
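A quick shape sanity-check for the Oobleck pair above, assuming the default strides [2, 4, 8, 8] (so one latent frame per 2 * 4 * 8 * 8 = 512 samples):

import torch

enc = OobleckEncoder(in_channels=2, latent_dim=32)
dec = OobleckDecoder(out_channels=2, latent_dim=32)
audio = torch.randn(1, 2, 512 * 16)  # 16 latent frames of stereo audio
latents = enc(audio)                 # -> (1, 32, 16)
recon = dec(latents)                 # -> (1, 2, 8192), tanh-clipped by default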


class DACEncoderWrapper(nn.Module):
    def __init__(self, in_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Encoder as DACEncoder

        latent_dim = kwargs.pop("latent_dim", None)

        encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
        self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
        self.latent_dim = latent_dim

        # Latent-dim support was added to DAC after this was first written,
        # and implemented differently, so this is for backwards compatibility
        self.proj_out = (
            nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1)
            if latent_dim is not None
            else nn.Identity()
        )

        if in_channels != 1:
            self.encoder.block[0] = WNConv1d(
                in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3
            )

    def forward(self, x):
        x = self.encoder(x)
        x = self.proj_out(x)
        return x


class DACDecoderWrapper(nn.Module):
    def __init__(self, latent_dim, out_channels=1, **kwargs):
        super().__init__()

        from dac.model.dac import Decoder as DACDecoder

        self.decoder = DACDecoder(
            **kwargs, input_channel=latent_dim, d_out=out_channels
        )

        self.latent_dim = latent_dim

    def forward(self, x):
        return self.decoder(x)


class AudioAutoencoder(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: Bottleneck = None,
        pretransform: Pretransform = None,
        in_channels=None,
        out_channels=None,
        soft_clip=False,
    ):
        super().__init__()

        self.downsampling_ratio = downsampling_ratio
        self.sample_rate = sample_rate

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels
        self.out_channels = io_channels

        self.min_length = self.downsampling_ratio

        if in_channels is not None:
            self.in_channels = in_channels

        if out_channels is not None:
            self.out_channels = out_channels

        self.bottleneck = bottleneck

        self.encoder = encoder

        self.decoder = decoder

        self.pretransform = pretransform

        self.soft_clip = soft_clip

        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete

    def encode(
        self,
        audio,
        return_info=False,
        skip_pretransform=False,
        iterate_batch=False,
        **kwargs,
    ):
        info = {}

        if self.pretransform is not None and not skip_pretransform:
            if self.pretransform.enable_grad:
                if iterate_batch:
                    audios = []
                    for i in range(audio.shape[0]):
                        audios.append(self.pretransform.encode(audio[i : i + 1]))
                    audio = torch.cat(audios, dim=0)
                else:
                    audio = self.pretransform.encode(audio)
            else:
                with torch.no_grad():
                    if iterate_batch:
                        audios = []
                        for i in range(audio.shape[0]):
                            audios.append(self.pretransform.encode(audio[i : i + 1]))
                        audio = torch.cat(audios, dim=0)
                    else:
                        audio = self.pretransform.encode(audio)

        if self.encoder is not None:
            if iterate_batch:
                latents = []
                for i in range(audio.shape[0]):
                    latents.append(self.encoder(audio[i : i + 1]))
                latents = torch.cat(latents, dim=0)
            else:
                latents = self.encoder(audio)
        else:
            latents = audio

        if self.bottleneck is not None:
            # TODO: Add iterate batch logic, needs to merge the info dicts
            latents, bottleneck_info = self.bottleneck.encode(
                latents, return_info=True, **kwargs
            )

            info.update(bottleneck_info)

        if return_info:
            return latents, info

        return latents

    def decode(self, latents, iterate_batch=False, **kwargs):
        if self.bottleneck is not None:
            if iterate_batch:
                decoded = []
                for i in range(latents.shape[0]):
                    decoded.append(self.bottleneck.decode(latents[i : i + 1]))
                latents = torch.cat(decoded, dim=0)
            else:
                latents = self.bottleneck.decode(latents)

        if iterate_batch:
            decoded = []
            for i in range(latents.shape[0]):
                decoded.append(self.decoder(latents[i : i + 1]))
            decoded = torch.cat(decoded, dim=0)
        else:
            decoded = self.decoder(latents, **kwargs)

        if self.pretransform is not None:
            if self.pretransform.enable_grad:
                if iterate_batch:
                    decodeds = []
                    for i in range(decoded.shape[0]):
                        decodeds.append(self.pretransform.decode(decoded[i : i + 1]))
                    decoded = torch.cat(decodeds, dim=0)
                else:
                    decoded = self.pretransform.decode(decoded)
            else:
                with torch.no_grad():
                    if iterate_batch:
                        decodeds = []
                        for i in range(latents.shape[0]):
                            decodeds.append(
                                self.pretransform.decode(decoded[i : i + 1])
                            )
                        decoded = torch.cat(decodeds, dim=0)
                    else:
                        decoded = self.pretransform.decode(decoded)

        if self.soft_clip:
            decoded = torch.tanh(decoded)

        return decoded

    def decode_tokens(self, tokens, **kwargs):
        """
        Decode discrete tokens to audio.
        Only works with discrete autoencoders.
        """

        assert isinstance(self.bottleneck, DiscreteBottleneck), (
            "decode_tokens only works with discrete autoencoders"
        )

        latents = self.bottleneck.decode_tokens(tokens, **kwargs)

        return self.decode(latents, **kwargs)

    def preprocess_audio_for_encoder(self, audio, in_sr):
        """
        Preprocess a single audio tensor (Channels x Length) to be compatible with the encoder.
        If the model is mono, stereo audio will be converted to mono.
        Audio will be silence-padded to be a multiple of the model's downsampling ratio.
        Audio will be resampled to the model's sample rate.
        The output will have batch size 1 and be shape (1 x Channels x Length).
        """
        return self.preprocess_audio_list_for_encoder([audio], [in_sr])

    def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
        """
        Preprocess a list of audio tensors (Channels x Length) into a batch tensor to be compatible with the encoder.
        The audio in that list can be of different lengths and channels.
        in_sr can be an integer or a list. If it's an integer, it is assumed to be the input sample rate for every audio.
        All audio will be resampled to the model's sample rate.
        Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
        If the model is mono, all audio will be converted to mono.
        The output will be a tensor of shape (Batch x Channels x Length).
        """
        batch_size = len(audio_list)
        if isinstance(in_sr_list, int):
            in_sr_list = [in_sr_list] * batch_size
        assert len(in_sr_list) == batch_size, (
            "list of sample rates must be the same length as audio_list"
        )
        new_audio = []
        max_length = 0
        # resample & find the max length
        for i in range(batch_size):
            audio = audio_list[i]
            in_sr = in_sr_list[i]
            if len(audio.shape) == 3 and audio.shape[0] == 1:
                # batch size 1 was given by accident. Just squeeze it.
                audio = audio.squeeze(0)
            elif len(audio.shape) == 1:
                # Mono signal, channel dimension is missing, unsqueeze it in
                audio = audio.unsqueeze(0)
            assert len(audio.shape) == 2, (
                "Audio should be shape (Channels x Length) with no batch dimension"
            )
            # Resample audio
            if in_sr != self.sample_rate:
                resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
                audio = resample_tf(audio)
            new_audio.append(audio)
            if audio.shape[-1] > max_length:
                max_length = audio.shape[-1]
        # Pad every audio to the same length, a multiple of the model's downsampling ratio
        padded_audio_length = (
            max_length
            + (self.min_length - (max_length % self.min_length)) % self.min_length
        )
        for i in range(batch_size):
            # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
            new_audio[i] = prepare_audio(
                new_audio[i],
                in_sr=in_sr,
                target_sr=in_sr,
                target_length=padded_audio_length,
                target_channels=self.in_channels,
                device=new_audio[i].device,
            ).squeeze(0)
        # convert to tensor
        return torch.stack(new_audio)
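A sketch of driving the preprocessing helper above; `model` stands in for any AudioAutoencoder instance, and the clips and sample rates are invented for illustration:

import torch

clip_a = torch.randn(2, 44100)  # stereo, 1 s at 44.1 kHz
clip_b = torch.randn(1, 24000)  # mono, 1 s at 24 kHz
batch = model.preprocess_audio_list_for_encoder([clip_a, clip_b], [44100, 24000])
# batch: (2, model.in_channels, L), with L padded up to a multiple of model.min_length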

    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
        """
        Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
        If chunked is True, split the audio into chunks of a given maximum size chunk_size, with the given overlap.
        The overlap and chunk_size params are both measured in number of latents (not audio samples),
        so you can likely use the same values with decode_audio.
        An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
        Every autoencoder will have a different receptive field size, and thus an ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked output and looking at the maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        A smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
        For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
        """
        if not chunked:
            # default behavior. Encode the entire audio in parallel
            return self.encode(audio, **kwargs)
        else:
            # CHUNKED ENCODING
            # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
            samples_per_latent = self.downsampling_ratio
            total_size = audio.shape[2]  # in samples
            batch_size = audio.shape[0]
            chunk_size *= samples_per_latent  # convert the latent metric to samples
            overlap *= samples_per_latent  # convert the latent metric to samples
            hop_size = chunk_size - overlap
            chunks = []
            for i in range(0, total_size - chunk_size + 1, hop_size):
                chunk = audio[:, :, i : i + chunk_size]
                chunks.append(chunk)
            if i + chunk_size != total_size:
                # Final chunk
                chunk = audio[:, :, -chunk_size:]
                chunks.append(chunk)
            chunks = torch.stack(chunks)
            num_chunks = chunks.shape[0]
            # Note: y_size might be a different value from the latent length used in diffusion training
            # because we can encode audio of varying lengths.
            # However, the audio should have been padded to a multiple of samples_per_latent by now.
            y_size = total_size // samples_per_latent
            # Create an empty latent; we will populate it with chunks as we encode them
            y_final = torch.zeros((batch_size, self.latent_dim, y_size)).to(
                audio.device
            )
            for i in range(num_chunks):
                x_chunk = chunks[i, :]
                # encode the chunk
                y_chunk = self.encode(x_chunk)
                # figure out where to put the latents along the time domain
                if i == num_chunks - 1:
                    # final chunk always goes at the end
                    t_end = y_size
                    t_start = t_end - y_chunk.shape[2]
                else:
                    t_start = i * hop_size // samples_per_latent
                    t_end = t_start + chunk_size // samples_per_latent
                # remove the edges of the overlaps
                ol = overlap // samples_per_latent // 2
                chunk_start = 0
                chunk_end = y_chunk.shape[2]
                if i > 0:
                    # no overlap for the start of the first chunk
                    t_start += ol
                    chunk_start += ol
                if i < num_chunks - 1:
                    # no overlap for the end of the last chunk
                    t_end -= ol
                    chunk_end -= ol
                # paste the chunked latents into our y_final output
                y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
            return y_final
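To make the chunk bookkeeping above concrete, a worked example with assumed numbers (a downsampling ratio of 512; chunk_size and overlap at the signature defaults):

samples_per_latent = 512                          # assumed downsampling ratio
chunk_size_samples = 128 * samples_per_latent     # 65536 samples per chunk
overlap_samples = 32 * samples_per_latent         # 16384 samples shared between chunks
hop_size = chunk_size_samples - overlap_samples   # 49152 samples between chunk starts
ol = overlap_samples // samples_per_latent // 2   # 16 latents trimmed at each interior seam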

    def decode_audio(
        self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs
    ):
        """
        Decode latents to audio.
        If chunked is True, split the latents into chunks of a given maximum size chunk_size, with the given overlap, both of which are measured in number of latents.
        An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
        Every autoencoder will have a different receptive field size, and thus an ideal overlap.
        You can determine it empirically by diffing unchunked vs chunked audio and looking at the maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        A smaller chunk_size uses less memory, but more compute.
        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
        For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
        """
        if not chunked:
            # default behavior. Decode the entire latent in parallel
            return self.decode(latents, **kwargs)
        else:
            # chunked decoding
            hop_size = chunk_size - overlap
            total_size = latents.shape[2]
            batch_size = latents.shape[0]
            chunks = []
            if total_size < chunk_size:
                # pad the latents to be at least chunk_size
                # NOTE (translated): padding here has been observed to turn the generated song into noise downstream
                pad_size = chunk_size - total_size + 1
                latents = F.pad(latents, (0, pad_size), mode="replicate")
                total_size = latents.shape[2]
            for i in range(0, total_size - chunk_size + 1, hop_size):
                chunk = latents[:, :, i : i + chunk_size]
                chunks.append(chunk)
            if i + chunk_size != total_size:
                # Final chunk
                chunk = latents[:, :, -chunk_size:]
                chunks.append(chunk)
            chunks = torch.stack(chunks)
            num_chunks = chunks.shape[0]
            # samples_per_latent is just the downsampling ratio
            samples_per_latent = self.downsampling_ratio
            # Create an empty waveform; we will populate it with chunks as we decode them
            y_size = total_size * samples_per_latent
            y_final = torch.zeros((batch_size, self.out_channels, y_size)).to(
                latents.device
            )
            for i in range(num_chunks):
                x_chunk = chunks[i, :]
                # decode the chunk
                y_chunk = self.decode(x_chunk)
                # figure out where to put the audio along the time domain
                if i == num_chunks - 1:
                    # final chunk always goes at the end
                    t_end = y_size
                    t_start = t_end - y_chunk.shape[2]
                else:
                    t_start = i * hop_size * samples_per_latent
                    t_end = t_start + chunk_size * samples_per_latent
                # remove the edges of the overlaps
                ol = (overlap // 2) * samples_per_latent
                chunk_start = 0
                chunk_end = y_chunk.shape[2]
                if i > 0:
                    # no overlap for the start of the first chunk
                    t_start += ol
                    chunk_start += ol
                if i < num_chunks - 1:
                    # no overlap for the end of the last chunk
                    t_end -= ol
                    chunk_end -= ol
                # paste the chunked audio into our y_final output audio
                y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
            return y_final
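A chunked round-trip sketch tying encode_audio and decode_audio together; `model` is an assumed AudioAutoencoder whose input has already been through the preprocessing above:

import torch

audio = torch.randn(1, 2, 512 * 1024)  # assumes downsampling_ratio == 512
latents = model.encode_audio(audio, chunked=True, overlap=32, chunk_size=128)
recon = model.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)
# recon.shape[-1] == latents.shape[-1] * model.downsampling_ratio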


class DiffusionAutoencoder(AudioAutoencoder):
    def __init__(
        self,
        diffusion: ConditionedDiffusionModel,
        diffusion_downsampling_ratio,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.diffusion = diffusion

        self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio

        if self.encoder is not None:
            # Shrink the initial encoder parameters to avoid saturated latents
            with torch.no_grad():
                for param in self.encoder.parameters():
                    param *= 0.5

    def decode(self, latents, steps=100):
        upsampled_length = latents.shape[2] * self.downsampling_ratio

        if self.bottleneck is not None:
            latents = self.bottleneck.decode(latents)

        if self.decoder is not None:
            # Run the convolutional decoder (calling self.decode here would recurse)
            latents = self.decoder(latents)

        # Upsample latents to match diffusion length
        if latents.shape[2] != upsampled_length:
            latents = F.interpolate(latents, size=upsampled_length, mode="nearest")

        noise = torch.randn(
            latents.shape[0], self.io_channels, upsampled_length, device=latents.device
        )
        decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)

        if self.pretransform is not None:
            if self.pretransform.enable_grad:
                decoded = self.pretransform.decode(decoded)
            else:
                with torch.no_grad():
                    decoded = self.pretransform.decode(decoded)

        return decoded


# AE factories


def create_encoder_from_config(encoder_config: Dict[str, Any]):
    encoder_type = encoder_config.get("type", None)
    assert encoder_type is not None, "Encoder type must be specified"

    if encoder_type == "oobleck":
        encoder = OobleckEncoder(**encoder_config["config"])

    elif encoder_type == "seanet":
        from encodec.modules import SEANetEncoder

        seanet_encoder_config = encoder_config["config"]

        # SEANet encoder expects strides in reverse order
        seanet_encoder_config["ratios"] = list(
            reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2]))
        )
        encoder = SEANetEncoder(**seanet_encoder_config)
    elif encoder_type == "dac":
        dac_config = encoder_config["config"]

        encoder = DACEncoderWrapper(**dac_config)
    elif encoder_type == "local_attn":
        from .local_attention import TransformerEncoder1D

        local_attn_config = encoder_config["config"]

        encoder = TransformerEncoder1D(**local_attn_config)
    else:
        raise ValueError(f"Unknown encoder type {encoder_type}")

    requires_grad = encoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in encoder.parameters():
            param.requires_grad = False

    return encoder


def create_decoder_from_config(decoder_config: Dict[str, Any]):
    decoder_type = decoder_config.get("type", None)
    assert decoder_type is not None, "Decoder type must be specified"

    if decoder_type == "oobleck":
        decoder = OobleckDecoder(**decoder_config["config"])
    elif decoder_type == "seanet":
        from encodec.modules import SEANetDecoder

        decoder = SEANetDecoder(**decoder_config["config"])
    elif decoder_type == "dac":
        dac_config = decoder_config["config"]

        decoder = DACDecoderWrapper(**dac_config)
    elif decoder_type == "local_attn":
        from .local_attention import TransformerDecoder1D

        local_attn_config = decoder_config["config"]

        decoder = TransformerDecoder1D(**local_attn_config)
    else:
        raise ValueError(f"Unknown decoder type {decoder_type}")

    requires_grad = decoder_config.get("requires_grad", True)
    if not requires_grad:
        for param in decoder.parameters():
            param.requires_grad = False

    return decoder


def create_autoencoder_from_config(config: Dict[str, Any]):
    ae_config = config["model"]

    encoder = create_encoder_from_config(ae_config["encoder"])
    decoder = create_decoder_from_config(ae_config["decoder"])

    bottleneck = ae_config.get("bottleneck", None)

    latent_dim = ae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = ae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, (
        "downsampling_ratio must be specified in model config"
    )
    io_channels = ae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    in_channels = ae_config.get("in_channels", None)
    out_channels = ae_config.get("out_channels", None)

    pretransform = ae_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = create_bottleneck_from_config(bottleneck)

    soft_clip = ae_config["decoder"].get("soft_clip", False)

    return AudioAutoencoder(
        encoder,
        decoder,
        io_channels=io_channels,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        sample_rate=sample_rate,
        bottleneck=bottleneck,
        pretransform=pretransform,
        in_channels=in_channels,
        out_channels=out_channels,
        soft_clip=soft_clip,
    )
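A minimal, illustrative config for the factory above. Field values are assumptions chosen to satisfy the asserts (the "vae" bottleneck type in particular is assumed to be registered in factory.py); this is not a shipped preset:

config = {
    "sample_rate": 44100,
    "model": {
        "encoder": {"type": "oobleck", "config": {"in_channels": 2, "latent_dim": 128}},
        "decoder": {"type": "oobleck", "config": {"out_channels": 2, "latent_dim": 64}},
        "bottleneck": {"type": "vae"},
        "latent_dim": 64,
        "downsampling_ratio": 512,
        "io_channels": 2,
    },
}
autoencoder = create_autoencoder_from_config(config)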


def create_diffAE_from_config(config: Dict[str, Any]):
    diffae_config = config["model"]

    if "encoder" in diffae_config:
        encoder = create_encoder_from_config(diffae_config["encoder"])
    else:
        encoder = None

    if "decoder" in diffae_config:
        decoder = create_decoder_from_config(diffae_config["decoder"])
    else:
        decoder = None

    diffusion_model_type = diffae_config["diffusion"]["type"]

    if diffusion_model_type == "DAU1d":
        diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
    elif diffusion_model_type == "adp_1d":
        diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
    elif diffusion_model_type == "dit":
        diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
    else:
        # Guard added: without it, `diffusion` would be undefined below
        raise ValueError(f"Unknown diffusion type {diffusion_model_type}")

    latent_dim = diffae_config.get("latent_dim", None)
    assert latent_dim is not None, "latent_dim must be specified in model config"
    downsampling_ratio = diffae_config.get("downsampling_ratio", None)
    assert downsampling_ratio is not None, (
        "downsampling_ratio must be specified in model config"
    )
    io_channels = diffae_config.get("io_channels", None)
    assert io_channels is not None, "io_channels must be specified in model config"
    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "sample_rate must be specified in model config"

    bottleneck = diffae_config.get("bottleneck", None)

    pretransform = diffae_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)

    if bottleneck is not None:
        bottleneck = create_bottleneck_from_config(bottleneck)

    diffusion_downsampling_ratio = None  # was `(None,)`, a stray-comma tuple

    if diffusion_model_type == "DAU1d":
        diffusion_downsampling_ratio = np.prod(
            diffae_config["diffusion"]["config"]["strides"]
        )
    elif diffusion_model_type == "adp_1d":
        diffusion_downsampling_ratio = np.prod(
            diffae_config["diffusion"]["config"]["factors"]
        )
    elif diffusion_model_type == "dit":
        diffusion_downsampling_ratio = 1

    return DiffusionAutoencoder(
        encoder=encoder,
        decoder=decoder,
        diffusion=diffusion,
        io_channels=io_channels,
        sample_rate=sample_rate,
        latent_dim=latent_dim,
        downsampling_ratio=downsampling_ratio,
        diffusion_downsampling_ratio=diffusion_downsampling_ratio,
        bottleneck=bottleneck,
        pretransform=pretransform,
    )
src/YingMusicSinger/utils/stable_audio_tools/blocks.py
ADDED
@@ -0,0 +1,398 @@
import math
from functools import reduce

import numpy as np
import torch
from dac.nn.layers import Snake1d
from packaging import version
from torch import nn
from torch.backends.cuda import sdp_kernel
from torch.nn import functional as F


class ResidualBlock(nn.Module):
    def __init__(self, main, skip=None):
        super().__init__()
        self.main = nn.Sequential(*main)
        self.skip = skip if skip else nn.Identity()

    def forward(self, input):
        return self.main(input) + self.skip(input)


class ResConvBlock(ResidualBlock):
    def __init__(
        self,
        c_in,
        c_mid,
        c_out,
        is_last=False,
        kernel_size=5,
        conv_bias=True,
        use_snake=False,
    ):
        skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
        super().__init__(
            [
                nn.Conv1d(
                    c_in, c_mid, kernel_size, padding=kernel_size // 2, bias=conv_bias
                ),
                nn.GroupNorm(1, c_mid),
                Snake1d(c_mid) if use_snake else nn.GELU(),
                nn.Conv1d(
                    c_mid, c_out, kernel_size, padding=kernel_size // 2, bias=conv_bias
                ),
                nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
                (Snake1d(c_out) if use_snake else nn.GELU())
                if not is_last
                else nn.Identity(),
            ],
            skip,
        )


class SelfAttention1d(nn.Module):
    def __init__(self, c_in, n_head=1, dropout_rate=0.0):
        super().__init__()
        assert c_in % n_head == 0
        self.norm = nn.GroupNorm(1, c_in)
        self.n_head = n_head
        self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
        self.out_proj = nn.Conv1d(c_in, c_in, 1)
        self.dropout = nn.Dropout(dropout_rate, inplace=True)

        self.use_flash = torch.cuda.is_available() and version.parse(
            torch.__version__
        ) >= version.parse("2.0.0")

        if not self.use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))

        if device_properties.major == 8 and device_properties.minor == 0:
            # Use flash attention for A100 GPUs
            self.sdp_kernel_config = (True, False, False)
        else:
            # Don't use flash attention for other GPUs
            self.sdp_kernel_config = (False, True, True)

    def forward(self, input):
        n, c, s = input.shape
        qkv = self.qkv_proj(self.norm(input))
        qkv = qkv.view([n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
        q, k, v = qkv.chunk(3, dim=1)
        scale = k.shape[3] ** -0.25

        if self.use_flash:
            with sdp_kernel(*self.sdp_kernel_config):
                y = (
                    F.scaled_dot_product_attention(q, k, v, is_causal=False)
                    .contiguous()
                    .view([n, c, s])
                )
        else:
            att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
            y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])

        return input + self.dropout(self.out_proj(y))
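The sdp_kernel_config tuples above unpack positionally into torch.backends.cuda.sdp_kernel(enable_flash, enable_math, enable_mem_efficient). A small usage sketch of the block itself (on CPU it falls back to the manual softmax path, since use_flash requires CUDA):

import torch

attn = SelfAttention1d(c_in=64, n_head=4)
x = torch.randn(2, 64, 256)  # (batch, channels, sequence)
y = attn(x)                  # residual output, same shape as x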


class SkipBlock(nn.Module):
    def __init__(self, *main):
        super().__init__()
        self.main = nn.Sequential(*main)

    def forward(self, input):
        return torch.cat([self.main(input), input], dim=1)


class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1.0):
        super().__init__()
        assert out_features % 2 == 0
        self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)

    def forward(self, input):
        f = 2 * math.pi * input @ self.weight.T
        return torch.cat([f.cos(), f.sin()], dim=-1)
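A tiny demonstration of the random Fourier embedding above, as commonly used for diffusion timestep conditioning (dimensions are arbitrary):

import torch

ff = FourierFeatures(in_features=1, out_features=16)
t = torch.rand(8, 1)  # batch of scalar timesteps
emb = ff(t)           # (8, 16): cos features then sin features of 8 random frequencies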


def expand_to_planes(input, shape):
    return input[..., None].repeat([1, 1, shape[2]])


_kernels = {
    "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
    "cubic": [
        -0.01171875,
        -0.03515625,
        0.11328125,
        0.43359375,
        0.43359375,
        0.11328125,
        -0.03515625,
        -0.01171875,
    ],
    "lanczos3": [
        0.003689131001010537,
        0.015056144446134567,
        -0.03399861603975296,
        -0.066637322306633,
        0.13550527393817902,
        0.44638532400131226,
        0.44638532400131226,
        0.13550527393817902,
        -0.066637322306633,
        -0.03399861603975296,
        0.015056144446134567,
        0.003689131001010537,
    ],
}


class Downsample1d(nn.Module):
    def __init__(self, kernel="linear", pad_mode="reflect", channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel])
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer("kernel", kernel_1d)
        self.channels_last = channels_last

    def forward(self, x):
        if self.channels_last:
            x = x.permute(0, 2, 1)
        x = F.pad(x, (self.pad,) * 2, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        x = F.conv1d(x, weight, stride=2)
        if self.channels_last:
            x = x.permute(0, 2, 1)
        return x


class Upsample1d(nn.Module):
    def __init__(self, kernel="linear", pad_mode="reflect", channels_last=False):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor(_kernels[kernel]) * 2
        self.pad = kernel_1d.shape[0] // 2 - 1
        self.register_buffer("kernel", kernel_1d)
        self.channels_last = channels_last

    def forward(self, x):
        if self.channels_last:
            x = x.permute(0, 2, 1)
        x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
        weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
        indices = torch.arange(x.shape[1], device=x.device)
        weight[indices, indices] = self.kernel.to(weight)
        x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
        if self.channels_last:
            x = x.permute(0, 2, 1)
        return x
|
| 196 |
+
|
| 197 |
+
|
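Both modules build a diagonal (per-channel) FIR filter from the chosen kernel, so resampling is anti-aliased and channels never mix. A round-trip sketch with arbitrary sizes:

import torch

down = Downsample1d(kernel="cubic")
up = Upsample1d(kernel="cubic")

x = torch.randn(1, 8, 64)
h = down(x)   # [1, 8, 32]: reflect pad, then stride-2 conv with the FIR kernel
y = up(h)     # [1, 8, 64]: transposed conv with the kernel scaled by 2
assert h.shape == (1, 8, 32) and y.shape == (1, 8, 64)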
def Downsample1d_2(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"

    return nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=factor * kernel_multiplier + 1,
        stride=factor,
        padding=factor * (kernel_multiplier // 2),
    )


def Upsample1d_2(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
    if factor == 1:
        return nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
        )

    if use_nearest:
        return nn.Sequential(
            nn.Upsample(scale_factor=factor, mode="nearest"),
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
            ),
        )
    else:
        return nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=factor * 2,
            stride=factor,
            padding=factor // 2 + factor % 2,
            output_padding=factor % 2,
        )


def zero_init(layer):
    nn.init.zeros_(layer.weight)
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)
    return layer


def rms_norm(x, scale, eps):
    dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
    mean_sq = torch.mean(x.to(dtype) ** 2, dim=-1, keepdim=True)
    scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
    return x * scale.to(x.dtype)


# rms_norm = torch.compile(rms_norm)


class AdaRMSNorm(nn.Module):
    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.linear = zero_init(nn.Linear(cond_features, features, bias=False))

    def extra_repr(self):
        return f"eps={self.eps},"

    def forward(self, x, cond):
        return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
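AdaRMSNorm is adaptive RMSNorm: the conditioning vector is mapped through a zero-initialized linear layer to a per-feature scale offset, so at initialization the module behaves as plain RMSNorm with scale 1. A sketch with assumed sizes:

import torch

norm = AdaRMSNorm(features=64, cond_features=16)
x = torch.randn(2, 100, 64)   # [batch, sequence, features]
cond = torch.randn(2, 16)     # per-example conditioning (e.g. a timestep embedding)

y = norm(x, cond)             # zero-init linear means scale == 1 at this point
torch.testing.assert_close(y, rms_norm(x, torch.ones(64), norm.eps))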
def normalize(x, eps=1e-4):
    dim = list(range(1, x.ndim))
    n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
    alpha = np.sqrt(n.numel() / x.numel())
    return x / torch.add(eps, n, alpha=alpha)


class ForcedWNConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn([out_channels, in_channels, kernel_size])
        )

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(self.weight))

        fan_in = self.weight[0].numel()

        w = normalize(self.weight) / math.sqrt(fan_in)

        return F.conv1d(x, w, padding="same")
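ForcedWNConv1d is forced weight normalization: the stored weight is re-normalized in place during training, and the effective weight used in the convolution is always scaled to roughly unit norm per output filter. A small sketch (sizes arbitrary):

import math
import torch

conv = ForcedWNConv1d(in_channels=16, out_channels=32, kernel_size=3)
conv.eval()   # skip the in-place re-normalization of the stored weight

x = torch.randn(2, 16, 100)
y = conv(x)   # "same" padding preserves the sequence length
assert y.shape == (2, 32, 100)

w = normalize(conv.weight) / math.sqrt(conv.weight[0].numel())
print(w.flatten(1).norm(dim=1))   # ~1.0 for each of the 32 output filters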
# Kernels

use_compile = True


def compile(function, *args, **kwargs):
    if not use_compile:
        return function
    try:
        return torch.compile(function, *args, **kwargs)
    except RuntimeError:
        return function


@compile
def linear_geglu(x, weight, bias=None):
    x = x @ weight.mT
    if bias is not None:
        x = x + bias
    x, gate = x.chunk(2, dim=-1)
    return x * F.gelu(gate)


@compile
def rms_norm(x, scale, eps):
    dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
    mean_sq = torch.mean(x.to(dtype) ** 2, dim=-1, keepdim=True)
    scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
    return x * scale.to(x.dtype)


# Layers


class LinearGEGLU(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features * 2, bias=bias)
        self.out_features = out_features

    def forward(self, x):
        return linear_geglu(x, self.weight, self.bias)


class RMSNorm(nn.Module):
    def __init__(self, shape, fix_scale=False, eps=1e-6):
        super().__init__()
        self.eps = eps

        if fix_scale:
            self.register_buffer("scale", torch.ones(shape))
        else:
            self.scale = nn.Parameter(torch.ones(shape))

    def extra_repr(self):
        return f"shape={tuple(self.scale.shape)}, eps={self.eps}"

    def forward(self, x):
        return rms_norm(x, self.scale, self.eps)


def snake_beta(x, alpha, beta):
    return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)


# try:
#     snake_beta = torch.compile(snake_beta)
# except RuntimeError:
#     pass


# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
# License available in LICENSES/LICENSE_NVIDIA.txt
class SnakeBeta(nn.Module):
    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
    ):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = snake_beta(x, alpha, beta)

        return x
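SnakeBeta computes x + (1/beta) * sin^2(alpha * x) with per-channel learnable alpha (frequency) and beta (magnitude); under the default log-scale parameterization both start at exp(0) = 1. A minimal sketch:

import torch

act = SnakeBeta(in_features=32)   # log-scale params start at 0, so alpha == beta == 1
x = torch.randn(2, 32, 1000)      # [batch, channels, time]
y = act(x)

# At initialization this reduces to x + sin(x)^2:
torch.testing.assert_close(y, x + torch.sin(x) ** 2)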
src/YingMusicSinger/utils/stable_audio_tools/bottleneck copy.py
ADDED
@@ -0,0 +1,393 @@
import numpy as np
import torch
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from vector_quantize_pytorch import FSQ, ResidualVQ


class Bottleneck(nn.Module):
    def __init__(self, is_discrete: bool = False):
        super().__init__()

        self.is_discrete = is_discrete

    def encode(self, x, return_info=False, **kwargs):
        raise NotImplementedError

    def decode(self, x):
        raise NotImplementedError


class DiscreteBottleneck(Bottleneck):
    def __init__(self, num_quantizers, codebook_size, tokens_id):
        super().__init__(is_discrete=True)

        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.tokens_id = tokens_id

    def decode_tokens(self, codes, **kwargs):
        raise NotImplementedError


class TanhBottleneck(Bottleneck):
    def __init__(self):
        super().__init__(is_discrete=False)
        self.tanh = nn.Tanh()

    def encode(self, x, return_info=False):
        info = {}

        x = torch.tanh(x)

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return x


def vae_sample(mean, scale):
    stdev = nn.functional.softplus(scale) + 1e-4
    var = stdev * stdev
    logvar = torch.log(var)
    latents = torch.randn_like(mean) * stdev + mean

    kl = (mean * mean + var - logvar - 1).sum(1).mean()

    return latents, kl


class VAEBottleneck(Bottleneck):
    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False, **kwargs):
        info = {}

        mean, scale = x.chunk(2, dim=1)

        x, kl = vae_sample(mean, scale)

        info["kl"] = kl

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return x


def compute_mean_kernel(x, y):
    kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
    return torch.exp(-kernel_input).mean()


def compute_mmd(latents):
    latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
    noise = torch.randn_like(latents_reshaped)

    latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
    noise_kernel = compute_mean_kernel(noise, noise)
    latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)

    mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
    return mmd.mean()


class WassersteinBottleneck(Bottleneck):
    def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
        super().__init__(is_discrete=False)

        self.noise_augment_dim = noise_augment_dim
        self.bypass_mmd = bypass_mmd

    def encode(self, x, return_info=False):
        info = {}

        if self.training and return_info:
            if self.bypass_mmd:
                mmd = torch.tensor(0.0)
            else:
                mmd = compute_mmd(x)

            info["mmd"] = mmd

        if return_info:
            return x, info

        return x

    def decode(self, x):
        if self.noise_augment_dim > 0:
            noise = torch.randn(
                x.shape[0], self.noise_augment_dim, x.shape[-1]
            ).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x


class L2Bottleneck(Bottleneck):
    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False):
        info = {}

        x = F.normalize(x, dim=1)

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return F.normalize(x, dim=1)


class RVQBottleneck(DiscreteBottleneck):
    def __init__(self, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["num_quantizers"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        self.quantizer = ResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["num_quantizers"]

    def encode(self, x, return_info=False, **kwargs):
        info = {}

        x = rearrange(x, "b c n -> b n c")
        x, indices, loss = self.quantizer(x)
        x = rearrange(x, "b n c -> b c n")

        info["quantizer_indices"] = indices
        info["quantizer_loss"] = loss.mean()

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return x

    def decode_tokens(self, codes, **kwargs):
        latents = self.quantizer.get_outputs_from_indices(codes)

        return self.decode(latents, **kwargs)


class RVQVAEBottleneck(DiscreteBottleneck):
    def __init__(self, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["num_quantizers"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        self.quantizer = ResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["num_quantizers"]

    def encode(self, x, return_info=False):
        info = {}

        x, kl = vae_sample(*x.chunk(2, dim=1))

        info["kl"] = kl

        x = rearrange(x, "b c n -> b n c")
        x, indices, loss = self.quantizer(x)
        x = rearrange(x, "b n c -> b c n")

        info["quantizer_indices"] = indices
        info["quantizer_loss"] = loss.mean()

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return x

    def decode_tokens(self, codes, **kwargs):
        latents = self.quantizer.get_outputs_from_indices(codes)

        return self.decode(latents, **kwargs)


class DACRVQBottleneck(DiscreteBottleneck):
    def __init__(
        self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs
    ):
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["n_codebooks"]
        self.quantize_on_decode = quantize_on_decode
        self.noise_augment_dim = noise_augment_dim

    def encode(self, x, return_info=False, **kwargs):
        info = {}

        info["pre_quantizer"] = x

        if self.quantize_on_decode:
            # Parenthesized: the original `return x, info if return_info else x`
            # always returned a tuple because of conditional-expression precedence.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)

        output = {
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }

        output["vq/commitment_loss"] /= self.num_quantizers
        output["vq/codebook_loss"] /= self.num_quantizers

        info.update(output)

        if return_info:
            return output["z"], info

        return output["z"]

    def decode(self, x):
        if self.quantize_on_decode:
            x = self.quantizer(x)[0]

        if self.noise_augment_dim > 0:
            noise = torch.randn(
                x.shape[0], self.noise_augment_dim, x.shape[-1]
            ).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x

    def decode_tokens(self, codes, **kwargs):
        latents, _, _ = self.quantizer.from_codes(codes)

        return self.decode(latents, **kwargs)


class DACRVQVAEBottleneck(DiscreteBottleneck):
    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["n_codebooks"]
        self.quantize_on_decode = quantize_on_decode

    def encode(self, x, return_info=False, n_quantizers: int = None):
        info = {}

        mean, scale = x.chunk(2, dim=1)

        x, kl = vae_sample(mean, scale)

        info["pre_quantizer"] = x
        info["kl"] = kl

        if self.quantize_on_decode:
            # Parenthesized for the same precedence reason as above.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
            x, n_quantizers=n_quantizers
        )

        output = {
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }

        output["vq/commitment_loss"] /= self.num_quantizers
        output["vq/codebook_loss"] /= self.num_quantizers

        info.update(output)

        if return_info:
            return output["z"], info

        return output["z"]

    def decode(self, x):
        if self.quantize_on_decode:
            x = self.quantizer(x)[0]

        return x

    def decode_tokens(self, codes, **kwargs):
        latents, _, _ = self.quantizer.from_codes(codes)

        return self.decode(latents, **kwargs)


class FSQBottleneck(DiscreteBottleneck):
    def __init__(self, noise_augment_dim=0, **kwargs):
        super().__init__(
            num_quantizers=kwargs.get("num_codebooks", 1),
            codebook_size=np.prod(kwargs["levels"]),
            tokens_id="quantizer_indices",
        )

        self.noise_augment_dim = noise_augment_dim

        self.quantizer = FSQ(
            **kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64]
        )

    def encode(self, x, return_info=False):
        info = {}

        orig_dtype = x.dtype
        x = x.float()

        x = rearrange(x, "b c n -> b n c")
        x, indices = self.quantizer(x)
        x = rearrange(x, "b n c -> b c n")

        x = x.to(orig_dtype)

        # Reorder indices to match the expected format
        indices = rearrange(indices, "b n q -> b q n")

        info["quantizer_indices"] = indices

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        if self.noise_augment_dim > 0:
            noise = torch.randn(
                x.shape[0], self.noise_augment_dim, x.shape[-1]
            ).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x

    def decode_tokens(self, tokens, **kwargs):
        latents = self.quantizer.indices_to_codes(tokens)

        return self.decode(latents, **kwargs)
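For reference, the VAE path above samples with the reparameterization trick and computes a diagonal-Gaussian KL against the standard normal. A small demonstration with assumed shapes:

import torch

mean = torch.zeros(4, 64, 100)    # [batch, latent_channels, time]
scale = torch.zeros(4, 64, 100)   # softplus(0) + 1e-4 gives a stdev of ~0.693

latents, kl = vae_sample(mean, scale)
print(latents.shape, kl.item())   # torch.Size([4, 64, 100]) and a scalar KL

# VAEBottleneck expects mean and scale stacked along the channel dimension:
z, info = VAEBottleneck().encode(torch.randn(4, 128, 100), return_info=True)
assert z.shape == (4, 64, 100) and "kl" in info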
src/YingMusicSinger/utils/stable_audio_tools/bottleneck.py
ADDED
@@ -0,0 +1,393 @@
import numpy as np
import torch
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from vector_quantize_pytorch import FSQ, ResidualVQ


class Bottleneck(nn.Module):
    def __init__(self, is_discrete: bool = False):
        super().__init__()

        self.is_discrete = is_discrete

    def encode(self, x, return_info=False, **kwargs):
        raise NotImplementedError

    def decode(self, x):
        raise NotImplementedError


class DiscreteBottleneck(Bottleneck):
    def __init__(self, num_quantizers, codebook_size, tokens_id):
        super().__init__(is_discrete=True)

        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size
        self.tokens_id = tokens_id

    def decode_tokens(self, codes, **kwargs):
        raise NotImplementedError


class TanhBottleneck(Bottleneck):
    def __init__(self):
        super().__init__(is_discrete=False)
        self.tanh = nn.Tanh()

    def encode(self, x, return_info=False):
        info = {}

        x = torch.tanh(x)

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return x


def vae_sample(mean, scale):
    stdev = nn.functional.softplus(scale) + 1e-4
    var = stdev * stdev
    logvar = torch.log(var)
    latents = torch.randn_like(mean) * stdev + mean

    kl = (mean * mean + var - logvar - 1).sum(1).mean()

    return latents, kl


class VAEBottleneck(Bottleneck):
    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False, **kwargs):
        info = {}

        mean, scale = x.chunk(2, dim=1)

        x, kl = vae_sample(mean, scale)

        info["kl"] = kl

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return x


def compute_mean_kernel(x, y):
    kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
    return torch.exp(-kernel_input).mean()


def compute_mmd(latents):
    latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
    noise = torch.randn_like(latents_reshaped)

    latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
    noise_kernel = compute_mean_kernel(noise, noise)
    latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)

    mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
    return mmd.mean()


class WassersteinBottleneck(Bottleneck):
    def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
        super().__init__(is_discrete=False)

        self.noise_augment_dim = noise_augment_dim
        self.bypass_mmd = bypass_mmd

    def encode(self, x, return_info=False):
        info = {}

        if self.training and return_info:
            if self.bypass_mmd:
                mmd = torch.tensor(0.0)
            else:
                mmd = compute_mmd(x)

            info["mmd"] = mmd

        if return_info:
            return x, info

        return x

    def decode(self, x):
        if self.noise_augment_dim > 0:
            noise = torch.randn(
                x.shape[0], self.noise_augment_dim, x.shape[-1]
            ).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x


class L2Bottleneck(Bottleneck):
    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False):
        info = {}

        x = F.normalize(x, dim=1)

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return F.normalize(x, dim=1)


class RVQBottleneck(DiscreteBottleneck):
    def __init__(self, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["num_quantizers"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        self.quantizer = ResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["num_quantizers"]

    def encode(self, x, return_info=False, **kwargs):
        info = {}

        x = rearrange(x, "b c n -> b n c")
        x, indices, loss = self.quantizer(x)
        x = rearrange(x, "b n c -> b c n")

        info["quantizer_indices"] = indices
        info["quantizer_loss"] = loss.mean()

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return x

    def decode_tokens(self, codes, **kwargs):
        latents = self.quantizer.get_outputs_from_indices(codes)

        return self.decode(latents, **kwargs)


class RVQVAEBottleneck(DiscreteBottleneck):
    def __init__(self, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["num_quantizers"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        self.quantizer = ResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["num_quantizers"]

    def encode(self, x, return_info=False):
        info = {}

        x, kl = vae_sample(*x.chunk(2, dim=1))

        info["kl"] = kl

        x = rearrange(x, "b c n -> b n c")
        x, indices, loss = self.quantizer(x)
        x = rearrange(x, "b n c -> b c n")

        info["quantizer_indices"] = indices
        info["quantizer_loss"] = loss.mean()

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        return x

    def decode_tokens(self, codes, **kwargs):
        latents = self.quantizer.get_outputs_from_indices(codes)

        return self.decode(latents, **kwargs)


class DACRVQBottleneck(DiscreteBottleneck):
    def __init__(
        self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs
    ):
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["n_codebooks"]
        self.quantize_on_decode = quantize_on_decode
        self.noise_augment_dim = noise_augment_dim

    def encode(self, x, return_info=False, **kwargs):
        info = {}

        info["pre_quantizer"] = x

        if self.quantize_on_decode:
            # Parenthesized: the original `return x, info if return_info else x`
            # always returned a tuple because of conditional-expression precedence.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)

        output = {
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }

        output["vq/commitment_loss"] /= self.num_quantizers
        output["vq/codebook_loss"] /= self.num_quantizers

        info.update(output)

        if return_info:
            return output["z"], info

        return output["z"]

    def decode(self, x):
        if self.quantize_on_decode:
            x = self.quantizer(x)[0]

        if self.noise_augment_dim > 0:
            noise = torch.randn(
                x.shape[0], self.noise_augment_dim, x.shape[-1]
            ).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x

    def decode_tokens(self, codes, **kwargs):
        latents, _, _ = self.quantizer.from_codes(codes)

        return self.decode(latents, **kwargs)


class DACRVQVAEBottleneck(DiscreteBottleneck):
    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["n_codebooks"]
        self.quantize_on_decode = quantize_on_decode

    def encode(self, x, return_info=False, n_quantizers: int = None):
        info = {}

        mean, scale = x.chunk(2, dim=1)

        x, kl = vae_sample(mean, scale)

        info["pre_quantizer"] = x
        info["kl"] = kl

        if self.quantize_on_decode:
            # Parenthesized for the same precedence reason as above.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
            x, n_quantizers=n_quantizers
        )

        output = {
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }

        output["vq/commitment_loss"] /= self.num_quantizers
        output["vq/codebook_loss"] /= self.num_quantizers

        info.update(output)

        if return_info:
            return output["z"], info

        return output["z"]

    def decode(self, x):
        if self.quantize_on_decode:
            x = self.quantizer(x)[0]

        return x

    def decode_tokens(self, codes, **kwargs):
        latents, _, _ = self.quantizer.from_codes(codes)

        return self.decode(latents, **kwargs)


class FSQBottleneck(DiscreteBottleneck):
    def __init__(self, noise_augment_dim=0, **kwargs):
        super().__init__(
            num_quantizers=kwargs.get("num_codebooks", 1),
            codebook_size=np.prod(kwargs["levels"]),
            tokens_id="quantizer_indices",
        )

        self.noise_augment_dim = noise_augment_dim

        self.quantizer = FSQ(
            **kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64]
        )

    def encode(self, x, return_info=False):
        info = {}

        orig_dtype = x.dtype
        x = x.float()

        x = rearrange(x, "b c n -> b n c")
        x, indices = self.quantizer(x)
        x = rearrange(x, "b n c -> b c n")

        x = x.to(orig_dtype)

        # Reorder indices to match the expected format
        indices = rearrange(indices, "b n q -> b q n")

        info["quantizer_indices"] = indices

        if return_info:
            return x, info
        else:
            return x

    def decode(self, x):
        if self.noise_augment_dim > 0:
            noise = torch.randn(
                x.shape[0], self.noise_augment_dim, x.shape[-1]
            ).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x

    def decode_tokens(self, tokens, **kwargs):
        latents = self.quantizer.indices_to_codes(tokens)

        return self.decode(latents, **kwargs)
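compute_mmd estimates an RBF-kernel maximum mean discrepancy between the flattened latents and a standard normal; WassersteinBottleneck passes latents through unchanged and, during training, reports that MMD in info as the regularization term. A sketch with assumed sizes:

import torch

bottleneck = WassersteinBottleneck(noise_augment_dim=2)
bottleneck.train()   # the MMD term is only computed in training mode

latents = torch.randn(4, 32, 50)
z, info = bottleneck.encode(latents, return_info=True)
print(info["mmd"].item())   # small for latents that already look Gaussian

decoded = bottleneck.decode(z)   # extra noise channels appended for the decoder
assert decoded.shape == (4, 34, 50)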
src/YingMusicSinger/utils/stable_audio_tools/conditioners.py
ADDED
@@ -0,0 +1,664 @@
# Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py

import gc
import logging
import string
import typing as tp
import warnings

import torch
from torch import nn

from .adp import NumberEmbedder

# from ..inference.utils import set_audio_channels
from .factory import create_pretransform_from_config
from .pretransforms import Pretransform

# from ..training.utils import copy_state_dict
from .utils import load_ckpt_state_dict


class Conditioner(nn.Module):
    def __init__(self, dim: int, output_dim: int, project_out: bool = False):
        super().__init__()

        self.dim = dim
        self.output_dim = output_dim
        self.proj_out = (
            nn.Linear(dim, output_dim)
            if (dim != output_dim or project_out)
            else nn.Identity()
        )

    def forward(self, x: tp.Any) -> tp.Any:
        raise NotImplementedError()


class IntConditioner(Conditioner):
    def __init__(self, output_dim: int, min_val: int = 0, max_val: int = 512):
        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val
        self.int_embedder = nn.Embedding(
            max_val - min_val + 1, output_dim
        ).requires_grad_(True)

    def forward(self, ints: tp.List[int], device=None) -> tp.Any:
        # self.int_embedder.to(device)

        ints = torch.tensor(ints).to(device)
        ints = ints.clamp(self.min_val, self.max_val)

        int_embeds = self.int_embedder(ints).unsqueeze(1)

        return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)]


class NumberConditioner(Conditioner):
    """
    Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
    """

    def __init__(self, output_dim: int, min_val: float = 0, max_val: float = 1):
        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val

        self.embedder = NumberEmbedder(features=output_dim)

    def forward(self, floats: tp.List[float], device=None) -> tp.Any:
        # Cast the inputs to floats
        floats = [float(x) for x in floats]

        floats = torch.tensor(floats).to(device)

        floats = floats.clamp(self.min_val, self.max_val)

        normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)

        # Cast floats to same type as embedder
        embedder_dtype = next(self.embedder.parameters()).dtype
        normalized_floats = normalized_floats.to(embedder_dtype)

        float_embeds = self.embedder(normalized_floats).unsqueeze(1)

        return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
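Every Conditioner returns a pair of [embeddings, mask] with shapes [batch, seq, dim] and [batch, seq]; for the scalar conditioners the sequence length is 1. A sketch for NumberConditioner with a hypothetical value range:

import torch

cond = NumberConditioner(output_dim=768, min_val=0.0, max_val=512.0)

embeds, mask = cond([30.0, 95.0], device="cpu")
print(embeds.shape, mask.shape)   # torch.Size([2, 1, 768]) torch.Size([2, 1])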
class CLAPTextConditioner(Conditioner):
    def __init__(
        self,
        output_dim: int,
        clap_ckpt_path,
        use_text_features=False,
        feature_layer_ix: int = -1,
        audio_model_type="HTSAT-base",
        enable_fusion=True,
        project_out: bool = False,
        finetune: bool = False,
    ):
        super().__init__(
            768 if use_text_features else 512, output_dim, project_out=project_out
        )

        self.use_text_features = use_text_features
        self.feature_layer_ix = feature_layer_ix
        self.finetune = finetune

        # Suppress logging from transformers
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                import laion_clap
                from laion_clap.clap_module.factory import (
                    load_state_dict as clap_load_state_dict,
                )

                model = laion_clap.CLAP_Module(
                    enable_fusion=enable_fusion, amodel=audio_model_type, device="cpu"
                )

                if self.finetune:
                    self.model = model
                else:
                    self.__dict__["model"] = model

                state_dict = clap_load_state_dict(clap_ckpt_path)
                self.model.model.load_state_dict(state_dict, strict=False)

                if self.finetune:
                    self.model.model.text_branch.requires_grad_(True)
                    self.model.model.text_branch.train()
                else:
                    self.model.model.text_branch.requires_grad_(False)
                    self.model.model.text_branch.eval()

            finally:
                logging.disable(previous_level)

        del self.model.model.audio_branch

        gc.collect()
        torch.cuda.empty_cache()

    def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
        prompt_tokens = self.model.tokenizer(prompts)
        attention_mask = prompt_tokens["attention_mask"].to(
            device=device, non_blocking=True
        )
        prompt_features = self.model.model.text_branch(
            input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True),
            attention_mask=attention_mask,
            output_hidden_states=True,
        )["hidden_states"][layer_ix]

        return prompt_features, attention_mask

    def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
        self.model.to(device)

        if self.use_text_features:
            if len(texts) == 1:
                text_features, text_attention_mask = self.get_clap_features(
                    [texts[0], ""], layer_ix=self.feature_layer_ix, device=device
                )
                text_features = text_features[:1, ...]
                text_attention_mask = text_attention_mask[:1, ...]
            else:
                text_features, text_attention_mask = self.get_clap_features(
                    texts, layer_ix=self.feature_layer_ix, device=device
                )
            return [self.proj_out(text_features), text_attention_mask]

        # Fix for CLAP bug when only one text is passed
        if len(texts) == 1:
            text_embedding = self.model.get_text_embedding(
                [texts[0], ""], use_tensor=True
            )[:1, ...]
        else:
            text_embedding = self.model.get_text_embedding(texts, use_tensor=True)

        text_embedding = text_embedding.unsqueeze(1).to(device)

        return [
            self.proj_out(text_embedding),
            torch.ones(text_embedding.shape[0], 1).to(device),
        ]


class CLAPAudioConditioner(Conditioner):
    def __init__(
        self,
        output_dim: int,
        clap_ckpt_path,
        audio_model_type="HTSAT-base",
        enable_fusion=True,
        project_out: bool = False,
        finetune: bool = False,
    ):
        super().__init__(512, output_dim, project_out=project_out)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # `self.finetune` was read below but never assigned in the original;
        # mirror the text conditioner and default it to False.
        self.finetune = finetune

        # Suppress logging from transformers
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                import laion_clap
                from laion_clap.clap_module.factory import (
                    load_state_dict as clap_load_state_dict,
                )

                model = laion_clap.CLAP_Module(
                    enable_fusion=enable_fusion, amodel=audio_model_type, device="cpu"
                )

                if self.finetune:
                    self.model = model
                else:
                    self.__dict__["model"] = model

                state_dict = clap_load_state_dict(clap_ckpt_path)
                self.model.model.load_state_dict(state_dict, strict=False)

                if self.finetune:
                    self.model.model.audio_branch.requires_grad_(True)
                    self.model.model.audio_branch.train()
                else:
                    self.model.model.audio_branch.requires_grad_(False)
                    self.model.model.audio_branch.eval()

            finally:
                logging.disable(previous_level)

        del self.model.model.text_branch

        gc.collect()
        torch.cuda.empty_cache()

    def forward(
        self,
        audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]],
        device: tp.Any = "cuda",
    ) -> tp.Any:
        self.model.to(device)

        if isinstance(audios, list) or isinstance(audios, tuple):
            audios = torch.cat(audios, dim=0)

        # Convert to mono
        mono_audios = audios.mean(dim=1)

        with torch.cuda.amp.autocast(enabled=False):
            audio_embedding = self.model.get_audio_embedding_from_data(
                mono_audios.float(), use_tensor=True
            )

        audio_embedding = audio_embedding.unsqueeze(1).to(device)

        return [
            self.proj_out(audio_embedding),
            torch.ones(audio_embedding.shape[0], 1).to(device),
        ]


class T5Conditioner(Conditioner):
    T5_MODELS = [
        "t5-small",
        "t5-base",
        "t5-large",
        "t5-3b",
        "t5-11b",
        "google/flan-t5-small",
        "google/flan-t5-base",
        "google/flan-t5-large",
        "google/flan-t5-xl",
        "google/flan-t5-xxl",
    ]

    T5_MODEL_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "t5-xl": 2048,
        "t5-xxl": 4096,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        "google/flan-t5-3b": 1024,
        "google/flan-t5-11b": 1024,
        "google/flan-t5-xl": 2048,
        "google/flan-t5-xxl": 4096,
    }

    def __init__(
        self,
        output_dim: int,
        t5_model_name: str = "t5-base",
        max_length: int = 128,
        enable_grad: bool = False,
        project_out: bool = False,
    ):
        assert t5_model_name in self.T5_MODELS, (
            f"Unknown T5 model name: {t5_model_name}"
        )
        super().__init__(
            self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out
        )

        from transformers import AutoTokenizer, T5EncoderModel

        self.max_length = max_length
        self.enable_grad = enable_grad

        # Suppress logging from transformers
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length)
                # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad)
                self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)
                model = (
                    T5EncoderModel.from_pretrained(t5_model_name)
                    .train(enable_grad)
                    .requires_grad_(enable_grad)
                    .to(torch.float16)
                )
            finally:
                logging.disable(previous_level)

        if self.enable_grad:
            self.model = model
        else:
            self.__dict__["model"] = model

    def forward(
        self, texts: tp.List[str], device: tp.Union[torch.device, str]
    ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        self.model.to(device)
        self.proj_out.to(device)

        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device).to(torch.bool)

        self.model.eval()

        # Two context managers: the original `with A and B:` only ever
        # entered the grad-enabled context, never the autocast one.
        with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(
            self.enable_grad
        ):
            embeddings = self.model(input_ids=input_ids, attention_mask=attention_mask)[
                "last_hidden_state"
            ]

        embeddings = self.proj_out(embeddings.float())

        embeddings = embeddings * attention_mask.unsqueeze(-1).float()

        return embeddings, attention_mask
| 378 |
+
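# A minimal usage sketch for T5Conditioner (assumes the `transformers` package,
# a CUDA device for the float16 autocast, and that "t5-base" weights can be
# downloaded; shapes follow from max_length=128 and the t5-base width of 768):
#
#   t5_cond = T5Conditioner(output_dim=768, t5_model_name="t5-base")
#   embeddings, mask = t5_cond(["a dreamy synth melody"], device="cuda")
#   # embeddings: (1, 128, 768), zeroed at padding positions; mask: (1, 128) bool
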
class PhonemeConditioner(Conditioner):
    """
    A conditioner that turns text into phonemes and embeds them using a lookup table
    Only works for English text

    Args:
        output_dim: the dimension of the output embeddings
        max_length: the maximum number of phonemes to embed
        project_out: whether to add another linear projection to the output embeddings
    """

    def __init__(
        self,
        output_dim: int,
        max_length: int = 1024,
        project_out: bool = False,
    ):
        super().__init__(output_dim, output_dim, project_out=project_out)

        from g2p_en import G2p

        self.max_length = max_length  # retained for config compatibility; not applied below

        self.g2p = G2p()

        # Reserving 0 for padding, 1 for ignored
        self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim)

    def forward(
        self, texts: tp.List[str], device: tp.Union[torch.device, str]
    ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        self.phoneme_embedder.to(device)
        self.proj_out.to(device)

        batch_phonemes = [
            self.g2p(text) for text in texts
        ]  # shape [batch_size, length]

        phoneme_ignore = [" ", *string.punctuation]

        # Replace ignored phonemes (spaces and punctuation) with a placeholder
        batch_phonemes = [
            [p if p not in phoneme_ignore else "_" for p in phonemes]
            for phonemes in batch_phonemes
        ]

        # Convert to ids
        phoneme_ids = [
            [self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes]
            for phonemes in batch_phonemes
        ]

        # Pad to match the longest sequence in the batch
        longest = max([len(ids) for ids in phoneme_ids])
        phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids]

        phoneme_ids = torch.tensor(phoneme_ids).to(device)

        # Convert to embeddings
        phoneme_embeds = self.phoneme_embedder(phoneme_ids)

        phoneme_embeds = self.proj_out(phoneme_embeds)

        return phoneme_embeds, torch.ones(
            phoneme_embeds.shape[0], phoneme_embeds.shape[1]
        ).to(device)


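# A minimal usage sketch for PhonemeConditioner (assumes the `g2p_en` package is
# installed; English-only, per the docstring above):
#
#   ph_cond = PhonemeConditioner(output_dim=768)
#   embeds, mask = ph_cond(["sing a quiet lullaby"], device="cpu")
#   # embeds: (1, n_phonemes, 768); spaces/punctuation become the "_" placeholder,
#   # and the returned mask is all ones (padding is not masked out here)
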
class TokenizerLUTConditioner(Conditioner):
    """
    A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary

    Args:
        tokenizer_name: the name of the tokenizer from the Hugging Face transformers library
        output_dim: the dimension of the output embeddings
        max_length: the maximum length of the text to embed
        project_out: whether to add another linear projection to the output embeddings
    """

    def __init__(
        self,
        tokenizer_name: str,  # Name of a tokenizer from the Hugging Face transformers library
        output_dim: int,
        max_length: int = 1024,
        project_out: bool = False,
    ):
        super().__init__(output_dim, output_dim, project_out=project_out)

        from transformers import AutoTokenizer

        # Suppress logging from transformers
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            finally:
                logging.disable(previous_level)

        self.max_length = max_length

        self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim)

    def forward(
        self, texts: tp.List[str], device: tp.Union[torch.device, str]
    ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        # Move the embedding table as well, matching PhonemeConditioner above
        self.token_embedder.to(device)
        self.proj_out.to(device)

        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device).to(torch.bool)

        embeddings = self.token_embedder(input_ids)

        embeddings = self.proj_out(embeddings)

        embeddings = embeddings * attention_mask.unsqueeze(-1).float()

        return embeddings, attention_mask


class PretransformConditioner(Conditioner):
    """
    A conditioner that uses a pretransform's encoder for conditioning

    Args:
        pretransform: an instantiated pretransform to use for conditioning
        output_dim: the dimension of the output embeddings
    """

    def __init__(self, pretransform: Pretransform, output_dim: int):
        super().__init__(pretransform.encoded_channels, output_dim)

        self.pretransform = pretransform

    def forward(
        self,
        audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]],
        device: tp.Union[torch.device, str],
    ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        self.pretransform.to(device)
        self.proj_out.to(device)

        if isinstance(audio, (list, tuple)):
            audio = torch.cat(audio, dim=0)

        # Convert audio to pretransform input channels
        audio = set_audio_channels(audio, self.pretransform.io_channels)

        latents = self.pretransform.encode(audio)

        latents = self.proj_out(latents)

        return [
            latents,
            torch.ones(latents.shape[0], latents.shape[2]).to(latents.device),
        ]


class MultiConditioner(nn.Module):
    """
    A module that applies multiple conditioners to an input dictionary based on the keys

    Args:
        conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt")
        default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"})
    """

    def __init__(
        self,
        conditioners: tp.Dict[str, Conditioner],
        default_keys: tp.Dict[str, str] = {},
    ):
        super().__init__()

        self.conditioners = nn.ModuleDict(conditioners)
        self.default_keys = default_keys

    def forward(
        self,
        batch_metadata: tp.List[tp.Dict[str, tp.Any]],
        device: tp.Union[torch.device, str],
    ) -> tp.Dict[str, tp.Any]:
        output = {}

        for key, conditioner in self.conditioners.items():
            condition_key = key

            conditioner_inputs = []

            for x in batch_metadata:
                if condition_key not in x:
                    if condition_key in self.default_keys:
                        condition_key = self.default_keys[condition_key]
                    else:
                        raise ValueError(
                            f"Conditioner key {condition_key} not found in batch metadata"
                        )

                # Unwrap the condition info if it's a single-element list or tuple,
                # to support collation functions that wrap everything in a list.
                # Note the tuple form of isinstance: with a bare `or` here, every
                # list would be unwrapped regardless of length.
                if (
                    isinstance(x[condition_key], (list, tuple))
                    and len(x[condition_key]) == 1
                ):
                    conditioner_input = x[condition_key][0]
                else:
                    conditioner_input = x[condition_key]

                conditioner_inputs.append(conditioner_input)

            output[key] = conditioner(conditioner_inputs, device)

        return output


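# A minimal sketch of how MultiConditioner is driven (keys are illustrative; the
# default_keys mapping redirects a missing "prompt" entry to "caption"):
#
#   conditioner = MultiConditioner(
#       {"prompt": T5Conditioner(output_dim=768)},
#       default_keys={"prompt": "caption"},
#   )
#   batch_metadata = [{"caption": "warm vinyl crackle"}, {"caption": "soft piano"}]
#   conds = conditioner(batch_metadata, device="cuda")
#   # conds["prompt"] is the (embeddings, attention_mask) pair from the T5 conditioner
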
def create_multi_conditioner_from_conditioning_config(
    config: tp.Dict[str, tp.Any],
) -> MultiConditioner:
    """
    Create a MultiConditioner from a conditioning config dictionary

    Args:
        config: the conditioning config dictionary
    """
    conditioners = {}
    cond_dim = config["cond_dim"]

    default_keys = config.get("default_keys", {})

    for conditioner_info in config["configs"]:
        conditioner_id = conditioner_info["id"]

        conditioner_type = conditioner_info["type"]

        conditioner_config = {"output_dim": cond_dim}

        conditioner_config.update(conditioner_info["config"])

        if conditioner_type == "t5":
            conditioners[conditioner_id] = T5Conditioner(**conditioner_config)
        elif conditioner_type == "clap_text":
            conditioners[conditioner_id] = CLAPTextConditioner(**conditioner_config)
        elif conditioner_type == "clap_audio":
            conditioners[conditioner_id] = CLAPAudioConditioner(**conditioner_config)
        elif conditioner_type == "int":
            conditioners[conditioner_id] = IntConditioner(**conditioner_config)
        elif conditioner_type == "number":
            conditioners[conditioner_id] = NumberConditioner(**conditioner_config)
        elif conditioner_type == "phoneme":
            conditioners[conditioner_id] = PhonemeConditioner(**conditioner_config)
        elif conditioner_type == "lut":
            conditioners[conditioner_id] = TokenizerLUTConditioner(**conditioner_config)
        elif conditioner_type == "pretransform":
            sample_rate = conditioner_config.pop("sample_rate", None)
            assert sample_rate is not None, (
                "Sample rate must be specified for pretransform conditioners"
            )

            pretransform = create_pretransform_from_config(
                conditioner_config.pop("pretransform_config"), sample_rate=sample_rate
            )

            if conditioner_config.get("pretransform_ckpt_path", None) is not None:
                pretransform.load_state_dict(
                    load_ckpt_state_dict(
                        conditioner_config.pop("pretransform_ckpt_path")
                    )
                )

            conditioners[conditioner_id] = PretransformConditioner(
                pretransform, **conditioner_config
            )
        else:
            raise ValueError(f"Unknown conditioner type: {conditioner_type}")

    return MultiConditioner(conditioners, default_keys=default_keys)

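# A sketch of the conditioning config this factory expects (field names follow
# the parsing code above; the ids and model choices are illustrative):
#
#   conditioning_config = {
#       "cond_dim": 768,
#       "default_keys": {"prompt": "caption"},
#       "configs": [
#           {"id": "prompt", "type": "t5", "config": {"t5_model_name": "t5-base"}},
#           {"id": "lyrics", "type": "lut",
#            "config": {"tokenizer_name": "t5-base", "max_length": 512}},
#       ],
#   }
#   conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
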
src/YingMusicSinger/utils/stable_audio_tools/diffusion.py
ADDED
@@ -0,0 +1,740 @@
import typing as tp
from functools import partial
from time import time

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

# from ..inference.generation import generate_diffusion_cond
from .adp import UNet1d, UNetCFG1d
from .blocks import (
    Downsample1d,
    Downsample1d_2,
    FourierFeatures,
    ResConvBlock,
    SelfAttention1d,
    SkipBlock,
    Upsample1d,
    Upsample1d_2,
    expand_to_planes,
)
from .conditioners import (
    MultiConditioner,
    create_multi_conditioner_from_conditioning_config,
)
from .dit import DiffusionTransformer
from .factory import create_pretransform_from_config
from .pretransforms import Pretransform


class Profiler:
    def __init__(self):
        self.ticks = [[time(), None]]

    def tick(self, msg):
        self.ticks.append([time(), msg])

    def __repr__(self):
        rep = 80 * "=" + "\n"
        for i in range(1, len(self.ticks)):
            msg = self.ticks[i][1]
            elapsed = self.ticks[i][0] - self.ticks[i - 1][0]
            rep += msg + f": {elapsed * 1000:.2f}ms\n"
        rep += 80 * "=" + "\n\n\n"
        return rep


class DiffusionModel(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x, t, **kwargs):
        raise NotImplementedError()


class DiffusionModelWrapper(nn.Module):
    def __init__(
        self,
        model: DiffusionModel,
        io_channels,
        sample_size,
        sample_rate,
        min_input_length,
        pretransform: tp.Optional[Pretransform] = None,
    ):
        super().__init__()
        self.io_channels = io_channels
        self.sample_size = sample_size
        self.sample_rate = sample_rate
        self.min_input_length = min_input_length

        self.model = model

        self.pretransform = pretransform

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)


class ConditionedDiffusionModel(nn.Module):
    def __init__(
        self,
        *args,
        supports_cross_attention: bool = False,
        supports_input_concat: bool = False,
        supports_global_cond: bool = False,
        supports_prepend_cond: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.supports_cross_attention = supports_cross_attention
        self.supports_input_concat = supports_input_concat
        self.supports_global_cond = supports_global_cond
        self.supports_prepend_cond = supports_prepend_cond

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        cross_attn_cond: torch.Tensor = None,
        cross_attn_mask: torch.Tensor = None,
        input_concat_cond: torch.Tensor = None,
        global_embed: torch.Tensor = None,
        prepend_cond: torch.Tensor = None,
        prepend_cond_mask: torch.Tensor = None,
        cfg_scale: float = 1.0,
        cfg_dropout_prob: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        **kwargs,
    ):
        raise NotImplementedError()


class ConditionedDiffusionModelWrapper(nn.Module):
    """
    A diffusion model that takes in conditioning
    """

    def __init__(
        self,
        model: ConditionedDiffusionModel,
        conditioner: MultiConditioner,
        io_channels,
        sample_rate,
        min_input_length: int,
        diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
        pretransform: tp.Optional[Pretransform] = None,
        cross_attn_cond_ids: tp.List[str] = [],
        global_cond_ids: tp.List[str] = [],
        input_concat_ids: tp.List[str] = [],
        prepend_cond_ids: tp.List[str] = [],
    ):
        super().__init__()

        self.model = model
        self.conditioner = conditioner
        self.io_channels = io_channels
        self.sample_rate = sample_rate
        self.diffusion_objective = diffusion_objective
        self.pretransform = pretransform
        self.cross_attn_cond_ids = cross_attn_cond_ids
        self.global_cond_ids = global_cond_ids
        self.input_concat_ids = input_concat_ids
        self.prepend_cond_ids = prepend_cond_ids
        self.min_input_length = min_input_length

    def get_conditioning_inputs(
        self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False
    ):
        cross_attention_input = None
        cross_attention_masks = None
        global_cond = None
        input_concat_cond = None
        prepend_cond = None
        prepend_cond_mask = None

        if len(self.cross_attn_cond_ids) > 0:
            # Concatenate all cross-attention inputs over the sequence dimension
            # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
            cross_attention_input = []
            cross_attention_masks = []

            for key in self.cross_attn_cond_ids:
                cross_attn_in, cross_attn_mask = conditioning_tensors[key]

                # Add sequence dimension if it's not there
                if len(cross_attn_in.shape) == 2:
                    cross_attn_in = cross_attn_in.unsqueeze(1)
                    cross_attn_mask = cross_attn_mask.unsqueeze(1)

                cross_attention_input.append(cross_attn_in)
                cross_attention_masks.append(cross_attn_mask)

            cross_attention_input = torch.cat(cross_attention_input, dim=1)
            cross_attention_masks = torch.cat(cross_attention_masks, dim=1)

        if len(self.global_cond_ids) > 0:
            # Concatenate all global conditioning inputs over the channel dimension
            # Assumes that the global conditioning inputs are of shape (batch, channels)
            global_conds = []
            for key in self.global_cond_ids:
                global_cond_input = conditioning_tensors[key][0]

                global_conds.append(global_cond_input)

            # Concatenate over the channel dimension
            global_cond = torch.cat(global_conds, dim=-1)

            if len(global_cond.shape) == 3:
                global_cond = global_cond.squeeze(1)

        if len(self.input_concat_ids) > 0:
            # Concatenate all input concat conditioning inputs over the channel dimension
            # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
            input_concat_cond = torch.cat(
                [conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1
            )

        if len(self.prepend_cond_ids) > 0:
            # Concatenate all prepend conditioning inputs over the sequence dimension
            # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
            prepend_conds = []
            prepend_cond_masks = []

            for key in self.prepend_cond_ids:
                prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
                prepend_conds.append(prepend_cond_input)
                prepend_cond_masks.append(prepend_cond_mask)

            prepend_cond = torch.cat(prepend_conds, dim=1)
            prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)

        if negative:
            return {
                "negative_cross_attn_cond": cross_attention_input,
                "negative_cross_attn_mask": cross_attention_masks,
                "negative_global_cond": global_cond,
                "negative_input_concat_cond": input_concat_cond,
            }
        else:
            return {
                "cross_attn_cond": cross_attention_input,
                "cross_attn_mask": cross_attention_masks,
                "global_cond": global_cond,
                "input_concat_cond": input_concat_cond,
                "prepend_cond": prepend_cond,
                "prepend_cond_mask": prepend_cond_mask,
            }

    def forward(
        self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs
    ):
        return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)

    def generate(self, *args, **kwargs):
        # NOTE: depends on generate_diffusion_cond; its import at the top of this
        # file is commented out, so this raises a NameError unless it is restored
        return generate_diffusion_cond(self, *args, **kwargs)


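# A minimal sketch of the conditioning flow through the wrapper above (names are
# illustrative; `wrapper` is assumed to come from create_diffusion_cond_from_config):
#
#   cond_tensors = wrapper.conditioner(batch_metadata, device)    # {id: (tensor, mask)}
#   model_kwargs = wrapper.get_conditioning_inputs(cond_tensors)  # kwargs for the inner model
#   v_pred = wrapper(x_noisy, t, cond_tensors)                    # same as wrapper.model(x_noisy, t, **model_kwargs)
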
class UNetCFG1DWrapper(ConditionedDiffusionModel):
    def __init__(self, *args, **kwargs):
        super().__init__(
            supports_cross_attention=True,
            supports_global_cond=True,
            supports_input_concat=True,
        )

        self.model = UNetCFG1d(*args, **kwargs)

        # Scale down the initial weights
        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(
        self,
        x,
        t,
        cross_attn_cond=None,
        cross_attn_mask=None,
        input_concat_cond=None,
        global_cond=None,
        cfg_scale=1.0,
        cfg_dropout_prob: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        negative_cross_attn_cond=None,
        negative_cross_attn_mask=None,
        negative_global_cond=None,
        negative_input_concat_cond=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        **kwargs,
    ):
        p = Profiler()

        p.tick("start")

        channels_list = None
        if input_concat_cond is not None:
            channels_list = [input_concat_cond]

        outputs = self.model(
            x,
            t,
            embedding=cross_attn_cond,
            embedding_mask=cross_attn_mask,
            features=global_cond,
            channels_list=channels_list,
            embedding_scale=cfg_scale,
            embedding_mask_proba=cfg_dropout_prob,
            batch_cfg=batch_cfg,
            rescale_cfg=rescale_cfg,
            negative_embedding=negative_cross_attn_cond,
            negative_embedding_mask=negative_cross_attn_mask,
            **kwargs,
        )

        p.tick("UNetCFG1D forward")

        # print(f"Profiler: {p}")
        return outputs


class UNet1DCondWrapper(ConditionedDiffusionModel):
    def __init__(self, *args, **kwargs):
        super().__init__(
            supports_cross_attention=False,
            supports_global_cond=True,
            supports_input_concat=True,
        )

        self.model = UNet1d(*args, **kwargs)

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(
        self,
        x,
        t,
        input_concat_cond=None,
        global_cond=None,
        cross_attn_cond=None,
        cross_attn_mask=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        cfg_scale=1.0,
        cfg_dropout_prob: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        negative_cross_attn_cond=None,
        negative_cross_attn_mask=None,
        negative_global_cond=None,
        negative_input_concat_cond=None,
        **kwargs,
    ):
        channels_list = None
        if input_concat_cond is not None:
            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(
                    input_concat_cond, (x.shape[2],), mode="nearest"
                )

            channels_list = [input_concat_cond]

        outputs = self.model(
            x, t, features=global_cond, channels_list=channels_list, **kwargs
        )

        return outputs


class UNet1DUncondWrapper(DiffusionModel):
    def __init__(self, in_channels, *args, **kwargs):
        super().__init__()

        self.model = UNet1d(in_channels=in_channels, *args, **kwargs)

        self.io_channels = in_channels

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)


class DAU1DCondWrapper(ConditionedDiffusionModel):
    def __init__(self, *args, **kwargs):
        super().__init__(
            supports_cross_attention=False,
            supports_global_cond=False,
            supports_input_concat=True,
        )

        self.model = DiffusionAttnUnet1D(*args, **kwargs)

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(
        self,
        x,
        t,
        input_concat_cond=None,
        cross_attn_cond=None,
        cross_attn_mask=None,
        global_cond=None,
        cfg_scale=1.0,
        cfg_dropout_prob: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        negative_cross_attn_cond=None,
        negative_cross_attn_mask=None,
        negative_global_cond=None,
        negative_input_concat_cond=None,
        prepend_cond=None,
        **kwargs,
    ):
        # Only input-concat conditioning is used; the remaining conditioning
        # arguments are accepted for interface compatibility and ignored here
        return self.model(x, t, cond=input_concat_cond)


class DiffusionAttnUnet1D(nn.Module):
    def __init__(
        self,
        io_channels=2,
        depth=14,
        n_attn_layers=6,
        channels=[128, 128, 256, 256] + [512] * 10,
        cond_dim=0,
        cond_noise_aug=False,
        kernel_size=5,
        learned_resample=False,
        strides=[2] * 13,
        conv_bias=True,
        use_snake=False,
    ):
        super().__init__()

        self.cond_noise_aug = cond_noise_aug

        self.io_channels = io_channels

        if self.cond_noise_aug:
            self.rng = torch.quasirandom.SobolEngine(1, scramble=True)

        self.timestep_embed = FourierFeatures(1, 16)

        attn_layer = depth - n_attn_layers

        strides = [1] + strides

        block = nn.Identity()

        conv_block = partial(
            ResConvBlock,
            kernel_size=kernel_size,
            conv_bias=conv_bias,
            use_snake=use_snake,
        )

        # Build the U-Net from the innermost block outwards
        for i in range(depth, 0, -1):
            c = channels[i - 1]
            stride = strides[i - 1]
            if stride > 2 and not learned_resample:
                raise ValueError("Strides greater than 2 require learned resampling")

            if i > 1:
                c_prev = channels[i - 2]
                add_attn = i >= attn_layer and n_attn_layers > 0
                block = SkipBlock(
                    Downsample1d_2(c_prev, c_prev, stride)
                    if (learned_resample or stride == 1)
                    else Downsample1d("cubic"),
                    conv_block(c_prev, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    block,
                    conv_block(c * 2 if i != depth else c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c),
                    SelfAttention1d(c, c // 32) if add_attn else nn.Identity(),
                    conv_block(c, c, c_prev),
                    SelfAttention1d(c_prev, c_prev // 32)
                    if add_attn
                    else nn.Identity(),
                    Upsample1d_2(c_prev, c_prev, stride)
                    if learned_resample
                    else Upsample1d(kernel="cubic"),
                )
            else:
                cond_embed_dim = 16 if not self.cond_noise_aug else 32
                block = nn.Sequential(
                    conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, c),
                    block,
                    conv_block(c * 2, c, c),
                    conv_block(c, c, c),
                    conv_block(c, c, io_channels, is_last=True),
                )
        self.net = block

        with torch.no_grad():
            for param in self.net.parameters():
                param *= 0.5

    def forward(self, x, t, cond=None, cond_aug_scale=None):
        timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)

        inputs = [x, timestep_embed]

        if cond is not None:
            if cond.shape[2] != x.shape[2]:
                cond = F.interpolate(
                    cond, (x.shape[2],), mode="linear", align_corners=False
                )

            if self.cond_noise_aug:
                # Draw a quasi-random augmentation level in [0, 1)
                if cond_aug_scale is None:
                    aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
                else:
                    aug_level = (
                        torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
                    )

                # Add noise to the conditioning signal
                cond = cond + torch.randn_like(cond) * aug_level[:, None, None]

                # Get embedding for the noise cond level, reusing timestep_embed
                aug_level_embed = expand_to_planes(
                    self.timestep_embed(aug_level[:, None]), x.shape
                )

                inputs.append(aug_level_embed)

            inputs.append(cond)

        outputs = self.net(torch.cat(inputs, dim=1))

        return outputs


class DiTWrapper(ConditionedDiffusionModel):
    def __init__(self, *args, **kwargs):
        super().__init__(
            supports_cross_attention=True,
            supports_global_cond=False,
            supports_input_concat=False,
        )

        self.model = DiffusionTransformer(*args, **kwargs)

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(
        self,
        x,
        t,
        cross_attn_cond=None,
        cross_attn_mask=None,
        negative_cross_attn_cond=None,
        negative_cross_attn_mask=None,
        input_concat_cond=None,
        negative_input_concat_cond=None,
        global_cond=None,
        negative_global_cond=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        cfg_scale=1.0,
        cfg_dropout_prob: float = 0.0,
        batch_cfg: bool = True,
        rescale_cfg: bool = False,
        scale_phi: float = 0.0,
        **kwargs,
    ):
        assert batch_cfg, "batch_cfg must be True for DiTWrapper"
        # assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"

        return self.model(
            x,
            t,
            cross_attn_cond=cross_attn_cond,
            cross_attn_cond_mask=cross_attn_mask,
            negative_cross_attn_cond=negative_cross_attn_cond,
            negative_cross_attn_mask=negative_cross_attn_mask,
            input_concat_cond=input_concat_cond,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            cfg_scale=cfg_scale,
            cfg_dropout_prob=cfg_dropout_prob,
            scale_phi=scale_phi,
            global_embed=global_cond,
            **kwargs,
        )


class DiTUncondWrapper(DiffusionModel):
    def __init__(self, in_channels, *args, **kwargs):
        super().__init__()

        self.model = DiffusionTransformer(io_channels=in_channels, *args, **kwargs)

        self.io_channels = in_channels

        with torch.no_grad():
            for param in self.model.parameters():
                param *= 0.5

    def forward(self, x, t, **kwargs):
        return self.model(x, t, **kwargs)


def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
    diffusion_uncond_config = config["model"]

    model_type = diffusion_uncond_config.get("type", None)

    diffusion_config = diffusion_uncond_config.get("config", {})

    assert model_type is not None, "Must specify model type in config"

    pretransform = diffusion_uncond_config.get("pretransform", None)

    sample_size = config.get("sample_size", None)
    assert sample_size is not None, "Must specify sample size in config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "Must specify sample rate in config"

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    if model_type == "DAU1d":
        model = DiffusionAttnUnet1D(**diffusion_config)

    elif model_type == "adp_uncond_1d":
        model = UNet1DUncondWrapper(**diffusion_config)

    elif model_type == "dit":
        model = DiTUncondWrapper(**diffusion_config)

    else:
        raise NotImplementedError(f"Unknown model type: {model_type}")

    return DiffusionModelWrapper(
        model,
        io_channels=model.io_channels,
        sample_size=sample_size,
        sample_rate=sample_rate,
        pretransform=pretransform,
        min_input_length=min_input_length,
    )


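# A sketch of a minimal config for create_diffusion_uncond_from_config (keys taken
# from the lookups above; the DiT hyperparameters are illustrative):
#
#   config = {
#       "sample_size": 65536,
#       "sample_rate": 44100,
#       "model": {
#           "type": "dit",
#           "config": {"in_channels": 2, "embed_dim": 512, "depth": 8, "num_heads": 8},
#       },
#   }
#   model = create_diffusion_uncond_from_config(config)
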
def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
    model_config = config["model"]

    model_type = config["model_type"]

    diffusion_config = model_config.get("diffusion", None)
    assert diffusion_config is not None, "Must specify diffusion config"

    diffusion_model_type = diffusion_config.get("type", None)
    assert diffusion_model_type is not None, "Must specify diffusion model type"

    diffusion_model_config = diffusion_config.get("config", None)
    assert diffusion_model_config is not None, "Must specify diffusion model config"

    if diffusion_model_type == "adp_cfg_1d":
        diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
    elif diffusion_model_type == "adp_1d":
        diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
    elif diffusion_model_type == "dit":
        diffusion_model = DiTWrapper(**diffusion_model_config)
    else:
        raise NotImplementedError(
            f"Unknown diffusion model type: {diffusion_model_type}"
        )

    io_channels = model_config.get("io_channels", None)
    assert io_channels is not None, "Must specify io_channels in model config"

    sample_rate = config.get("sample_rate", None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    diffusion_objective = diffusion_config.get("diffusion_objective", "v")

    conditioning_config = model_config.get("conditioning", None)

    conditioner = None
    if conditioning_config is not None:
        conditioner = create_multi_conditioner_from_conditioning_config(
            conditioning_config
        )

    cross_attention_ids = diffusion_config.get("cross_attention_cond_ids", [])
    global_cond_ids = diffusion_config.get("global_cond_ids", [])
    input_concat_ids = diffusion_config.get("input_concat_ids", [])
    prepend_cond_ids = diffusion_config.get("prepend_cond_ids", [])

    pretransform = model_config.get("pretransform", None)

    if pretransform is not None:
        pretransform = create_pretransform_from_config(pretransform, sample_rate)
        min_input_length = pretransform.downsampling_ratio
    else:
        min_input_length = 1

    if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
        min_input_length *= np.prod(diffusion_model_config["factors"])
    elif diffusion_model_type == "dit":
        min_input_length *= diffusion_model.model.patch_size

    # Get the proper wrapper class

    extra_kwargs = {}

    if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint":
        wrapper_fn = ConditionedDiffusionModelWrapper

        extra_kwargs["diffusion_objective"] = diffusion_objective

    elif model_type == "diffusion_prior":
        prior_type = model_config.get("prior_type", None)
        assert prior_type is not None, (
            "Must specify prior_type in diffusion prior model config"
        )

        if prior_type == "mono_stereo":
            from .diffusion_prior import MonoToStereoDiffusionPrior

            wrapper_fn = MonoToStereoDiffusionPrior
        else:
            raise NotImplementedError(f"Unknown prior type: {prior_type}")

    else:
        raise NotImplementedError(f"Unknown model type: {model_type}")

    return wrapper_fn(
        diffusion_model,
        conditioner,
        min_input_length=min_input_length,
        sample_rate=sample_rate,
        cross_attn_cond_ids=cross_attention_ids,
        global_cond_ids=global_cond_ids,
        input_concat_ids=input_concat_ids,
        prepend_cond_ids=prepend_cond_ids,
        pretransform=pretransform,
        io_channels=io_channels,
        **extra_kwargs,
    )

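# A sketch of the top-level config create_diffusion_cond_from_config consumes
# (structure inferred from the lookups above; values are illustrative, and the
# pretransform/conditioning sub-configs are elided):
#
#   config = {
#       "model_type": "diffusion_cond",
#       "sample_rate": 44100,
#       "model": {
#           "io_channels": 64,
#           "conditioning": {...},  # see create_multi_conditioner_from_conditioning_config
#           "diffusion": {
#               "type": "dit",
#               "diffusion_objective": "v",
#               "cross_attention_cond_ids": ["prompt"],
#               "config": {"io_channels": 64, "embed_dim": 768, "depth": 12, "num_heads": 8},
#           },
#       },
#   }
#   model = create_diffusion_cond_from_config(config)
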
src/YingMusicSinger/utils/stable_audio_tools/dit.py
ADDED
@@ -0,0 +1,451 @@
import typing as tp

import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from x_transformers import ContinuousTransformerWrapper, Encoder

from .blocks import FourierFeatures
from .transformer import ContinuousTransformer


class DiffusionTransformer(nn.Module):
    def __init__(
        self,
        io_channels=32,
        patch_size=1,
        embed_dim=768,
        cond_token_dim=0,
        project_cond_tokens=True,
        global_cond_dim=0,
        project_global_cond=True,
        input_concat_dim=0,
        prepend_cond_dim=0,
        depth=12,
        num_heads=8,
        transformer_type: tp.Literal[
            "x-transformers", "continuous_transformer"
        ] = "x-transformers",
        global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
        **kwargs,
    ):
        super().__init__()

        self.cond_token_dim = cond_token_dim

        # Timestep embeddings
        timestep_features_dim = 256

        self.timestep_features = FourierFeatures(1, timestep_features_dim)

        self.to_timestep_embed = nn.Sequential(
            nn.Linear(timestep_features_dim, embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim, bias=True),
        )

        if cond_token_dim > 0:
            # Conditioning tokens
            cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
            self.to_cond_embed = nn.Sequential(
                nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(cond_embed_dim, cond_embed_dim, bias=False),
            )
        else:
            cond_embed_dim = 0

        if global_cond_dim > 0:
            # Global conditioning
            global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
            self.to_global_embed = nn.Sequential(
                nn.Linear(global_cond_dim, global_embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(global_embed_dim, global_embed_dim, bias=False),
            )

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False),
            )

        self.input_concat_dim = input_concat_dim

        dim_in = io_channels + self.input_concat_dim

        self.patch_size = patch_size

        # Transformer

        self.transformer_type = transformer_type

        self.global_cond_type = global_cond_type

        if self.transformer_type == "x-transformers":
            self.transformer = ContinuousTransformerWrapper(
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                max_seq_len=0,  # Not relevant without absolute positional embeds
                attn_layers=Encoder(
                    dim=embed_dim,
                    depth=depth,
                    heads=num_heads,
                    attn_flash=True,
                    cross_attend=cond_token_dim > 0,
                    dim_context=None if cond_embed_dim == 0 else cond_embed_dim,
                    zero_init_branch_output=True,
                    use_abs_pos_emb=False,
                    rotary_pos_emb=True,
                    ff_swish=True,
                    ff_glu=True,
                    **kwargs,
                ),
            )

        elif self.transformer_type == "continuous_transformer":
            global_dim = None

            if self.global_cond_type == "adaLN":
                # The global conditioning is projected to the embed_dim already at this point
                global_dim = embed_dim

            self.transformer = ContinuousTransformer(
                dim=embed_dim,
                depth=depth,
                dim_heads=embed_dim // num_heads,
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                cross_attend=cond_token_dim > 0,
                cond_token_dim=cond_embed_dim,
                global_cond_dim=global_dim,
                **kwargs,
            )

        else:
            raise ValueError(f"Unknown transformer type: {self.transformer_type}")

        self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
        nn.init.zeros_(self.preprocess_conv.weight)
        self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
        nn.init.zeros_(self.postprocess_conv.weight)

    def _forward(
        self,
        x,
        t,
        mask=None,
        cross_attn_cond=None,
        cross_attn_cond_mask=None,
        input_concat_cond=None,
        global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        return_info=False,
        **kwargs,
    ):
        if cross_attn_cond is not None:
            cross_attn_cond = self.to_cond_embed(cross_attn_cond)

        if global_embed is not None:
            # Project the global conditioning to the embedding dimension
            global_embed = self.to_global_embed(global_embed)

        prepend_inputs = None
        prepend_mask = None
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension
            prepend_cond = self.to_prepend_embed(prepend_cond)

            prepend_inputs = prepend_cond
            if prepend_cond_mask is not None:
                prepend_mask = prepend_cond_mask

        if input_concat_cond is not None:
            # Interpolate input_concat_cond to the same length as x
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(
                    input_concat_cond, (x.shape[2],), mode="nearest"
                )

            x = torch.cat([x, input_concat_cond], dim=1)

        # Get the batch of timestep embeddings
        timestep_embed = self.to_timestep_embed(
            self.timestep_features(t[:, None])
        )  # (b, embed_dim)

        # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
        if global_embed is not None:
            global_embed = global_embed + timestep_embed
        else:
            global_embed = timestep_embed

        # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
        if self.global_cond_type == "prepend":
            if prepend_inputs is None:
                # Prepend inputs are just the global embed, and the mask is all ones
                prepend_inputs = global_embed.unsqueeze(1)
                prepend_mask = torch.ones(
                    (x.shape[0], 1), device=x.device, dtype=torch.bool
                )
            else:
                # Prepend inputs are the prepend conditioning + the global embed
                prepend_inputs = torch.cat(
                    [prepend_inputs, global_embed.unsqueeze(1)], dim=1
                )
                prepend_mask = torch.cat(
                    [
                        prepend_mask,
                        torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool),
                    ],
                    dim=1,
                )

            prepend_length = prepend_inputs.shape[1]

        x = self.preprocess_conv(x) + x

        x = rearrange(x, "b c t -> b t c")

        extra_args = {}

        if self.global_cond_type == "adaLN":
            extra_args["global_cond"] = global_embed

        if self.patch_size > 1:
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

        if self.transformer_type == "x-transformers":
            output = self.transformer(
                x,
                prepend_embeds=prepend_inputs,
                context=cross_attn_cond,
                context_mask=cross_attn_cond_mask,
                mask=mask,
                prepend_mask=prepend_mask,
                **extra_args,
                **kwargs,
            )
        elif self.transformer_type == "continuous_transformer":
            output = self.transformer(
                x,
                prepend_embeds=prepend_inputs,
                context=cross_attn_cond,
                context_mask=cross_attn_cond_mask,
                mask=mask,
                prepend_mask=prepend_mask,
                return_info=return_info,
                **extra_args,
                **kwargs,
            )

            if return_info:
                output, info = output
        elif self.transformer_type == "mm_transformer":
            output = self.transformer(
                x,
                context=cross_attn_cond,
                mask=mask,
                context_mask=cross_attn_cond_mask,
                **extra_args,
                **kwargs,
            )

        output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:]

        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)

        output = self.postprocess_conv(output) + output

        if return_info:
            return output, info

        return output

def forward(
|
| 273 |
+
self,
|
| 274 |
+
x,
|
| 275 |
+
t,
|
| 276 |
+
cross_attn_cond=None,
|
| 277 |
+
cross_attn_cond_mask=None,
|
| 278 |
+
negative_cross_attn_cond=None,
|
| 279 |
+
negative_cross_attn_mask=None,
|
| 280 |
+
input_concat_cond=None,
|
| 281 |
+
global_embed=None,
|
| 282 |
+
negative_global_embed=None,
|
| 283 |
+
prepend_cond=None,
|
| 284 |
+
prepend_cond_mask=None,
|
| 285 |
+
cfg_scale=1.0,
|
| 286 |
+
cfg_dropout_prob=0.0,
|
| 287 |
+
causal=False,
|
| 288 |
+
scale_phi=0.0,
|
| 289 |
+
mask=None,
|
| 290 |
+
return_info=False,
|
| 291 |
+
**kwargs,
|
| 292 |
+
):
|
| 293 |
+
assert causal == False, "Causal mode is not supported for DiffusionTransformer"
|
| 294 |
+
|
| 295 |
+
if cross_attn_cond_mask is not None:
|
| 296 |
+
cross_attn_cond_mask = cross_attn_cond_mask.bool()
|
| 297 |
+
|
| 298 |
+
cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
|
| 299 |
+
|
| 300 |
+
if prepend_cond_mask is not None:
|
| 301 |
+
prepend_cond_mask = prepend_cond_mask.bool()
|
| 302 |
+
|
| 303 |
+
# CFG dropout
|
| 304 |
+
if cfg_dropout_prob > 0.0:
|
| 305 |
+
if cross_attn_cond is not None:
|
| 306 |
+
null_embed = torch.zeros_like(
|
| 307 |
+
cross_attn_cond, device=cross_attn_cond.device
|
| 308 |
+
)
|
| 309 |
+
dropout_mask = torch.bernoulli(
|
| 310 |
+
torch.full(
|
| 311 |
+
(cross_attn_cond.shape[0], 1, 1),
|
| 312 |
+
cfg_dropout_prob,
|
| 313 |
+
device=cross_attn_cond.device,
|
| 314 |
+
)
|
| 315 |
+
).to(torch.bool)
|
| 316 |
+
cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
|
| 317 |
+
|
| 318 |
+
if prepend_cond is not None:
|
| 319 |
+
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
|
| 320 |
+
dropout_mask = torch.bernoulli(
|
| 321 |
+
torch.full(
|
| 322 |
+
(prepend_cond.shape[0], 1, 1),
|
| 323 |
+
cfg_dropout_prob,
|
| 324 |
+
device=prepend_cond.device,
|
| 325 |
+
)
|
| 326 |
+
).to(torch.bool)
|
| 327 |
+
prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
|
| 328 |
+
|
| 329 |
+
if cfg_scale != 1.0 and (
|
| 330 |
+
cross_attn_cond is not None or prepend_cond is not None
|
| 331 |
+
):
|
| 332 |
+
# Classifier-free guidance
|
| 333 |
+
# Concatenate conditioned and unconditioned inputs on the batch dimension
|
| 334 |
+
batch_inputs = torch.cat([x, x], dim=0)
|
| 335 |
+
batch_timestep = torch.cat([t, t], dim=0)
|
| 336 |
+
|
| 337 |
+
if global_embed is not None:
|
| 338 |
+
batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
|
| 339 |
+
else:
|
| 340 |
+
batch_global_cond = None
|
| 341 |
+
|
| 342 |
+
if input_concat_cond is not None:
|
| 343 |
+
batch_input_concat_cond = torch.cat(
|
| 344 |
+
[input_concat_cond, input_concat_cond], dim=0
|
| 345 |
+
)
|
| 346 |
+
else:
|
| 347 |
+
batch_input_concat_cond = None
|
| 348 |
+
|
| 349 |
+
batch_cond = None
|
| 350 |
+
batch_cond_masks = None
|
| 351 |
+
|
| 352 |
+
# Handle CFG for cross-attention conditioning
|
| 353 |
+
if cross_attn_cond is not None:
|
| 354 |
+
null_embed = torch.zeros_like(
|
| 355 |
+
cross_attn_cond, device=cross_attn_cond.device
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
# For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
|
| 359 |
+
if negative_cross_attn_cond is not None:
|
| 360 |
+
# If there's a negative cross-attention mask, set the masked tokens to the null embed
|
| 361 |
+
if negative_cross_attn_mask is not None:
|
| 362 |
+
negative_cross_attn_mask = negative_cross_attn_mask.to(
|
| 363 |
+
torch.bool
|
| 364 |
+
).unsqueeze(2)
|
| 365 |
+
|
| 366 |
+
negative_cross_attn_cond = torch.where(
|
| 367 |
+
negative_cross_attn_mask,
|
| 368 |
+
negative_cross_attn_cond,
|
| 369 |
+
null_embed,
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
batch_cond = torch.cat(
|
| 373 |
+
[cross_attn_cond, negative_cross_attn_cond], dim=0
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
else:
|
| 377 |
+
batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
|
| 378 |
+
|
| 379 |
+
if cross_attn_cond_mask is not None:
|
| 380 |
+
batch_cond_masks = torch.cat(
|
| 381 |
+
[cross_attn_cond_mask, cross_attn_cond_mask], dim=0
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
batch_prepend_cond = None
|
| 385 |
+
batch_prepend_cond_mask = None
|
| 386 |
+
|
| 387 |
+
if prepend_cond is not None:
|
| 388 |
+
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
|
| 389 |
+
|
| 390 |
+
batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
|
| 391 |
+
|
| 392 |
+
if prepend_cond_mask is not None:
|
| 393 |
+
batch_prepend_cond_mask = torch.cat(
|
| 394 |
+
[prepend_cond_mask, prepend_cond_mask], dim=0
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
if mask is not None:
|
| 398 |
+
batch_masks = torch.cat([mask, mask], dim=0)
|
| 399 |
+
else:
|
| 400 |
+
batch_masks = None
|
| 401 |
+
|
| 402 |
+
batch_output = self._forward(
|
| 403 |
+
batch_inputs,
|
| 404 |
+
batch_timestep,
|
| 405 |
+
cross_attn_cond=batch_cond,
|
| 406 |
+
cross_attn_cond_mask=batch_cond_masks,
|
| 407 |
+
mask=batch_masks,
|
| 408 |
+
input_concat_cond=batch_input_concat_cond,
|
| 409 |
+
global_embed=batch_global_cond,
|
| 410 |
+
prepend_cond=batch_prepend_cond,
|
| 411 |
+
prepend_cond_mask=batch_prepend_cond_mask,
|
| 412 |
+
return_info=return_info,
|
| 413 |
+
**kwargs,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
if return_info:
|
| 417 |
+
batch_output, info = batch_output
|
| 418 |
+
|
| 419 |
+
cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
|
| 420 |
+
cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
|
| 421 |
+
|
| 422 |
+
# CFG Rescale
|
| 423 |
+
if scale_phi != 0.0:
|
| 424 |
+
cond_out_std = cond_output.std(dim=1, keepdim=True)
|
| 425 |
+
out_cfg_std = cfg_output.std(dim=1, keepdim=True)
|
| 426 |
+
output = (
|
| 427 |
+
scale_phi * (cfg_output * (cond_out_std / out_cfg_std))
|
| 428 |
+
+ (1 - scale_phi) * cfg_output
|
| 429 |
+
)
|
| 430 |
+
else:
|
| 431 |
+
output = cfg_output
|
| 432 |
+
|
| 433 |
+
if return_info:
|
| 434 |
+
return output, info
|
| 435 |
+
|
| 436 |
+
return output
|
| 437 |
+
|
| 438 |
+
else:
|
| 439 |
+
return self._forward(
|
| 440 |
+
x,
|
| 441 |
+
t,
|
| 442 |
+
cross_attn_cond=cross_attn_cond,
|
| 443 |
+
cross_attn_cond_mask=cross_attn_cond_mask,
|
| 444 |
+
input_concat_cond=input_concat_cond,
|
| 445 |
+
global_embed=global_embed,
|
| 446 |
+
prepend_cond=prepend_cond,
|
| 447 |
+
prepend_cond_mask=prepend_cond_mask,
|
| 448 |
+
mask=mask,
|
| 449 |
+
return_info=return_info,
|
| 450 |
+
**kwargs,
|
| 451 |
+
)
|
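The classifier-free guidance arithmetic in forward above is easier to follow in isolation. Below is a minimal, self-contained sketch of just the guidance and rescale steps, using dummy (batch, channels, time) tensors; the helper name cfg_with_rescale and the example shapes are illustrative, not part of the library.

import torch

def cfg_with_rescale(cond_output, uncond_output, cfg_scale=4.0, scale_phi=0.75):
    # Standard CFG: push the unconditional prediction toward the
    # conditional one, amplified by cfg_scale
    cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
    if scale_phi == 0.0:
        return cfg_output
    # CFG rescale: renormalize the guided output so its per-channel std
    # matches the conditional prediction, then blend by scale_phi
    cond_std = cond_output.std(dim=1, keepdim=True)
    cfg_std = cfg_output.std(dim=1, keepdim=True)
    return scale_phi * (cfg_output * (cond_std / cfg_std)) + (1 - scale_phi) * cfg_output

cond, uncond = torch.randn(2, 64, 256), torch.randn(2, 64, 256)
guided = cfg_with_rescale(cond, uncond)

The rescale step counteracts the variance inflation that large cfg_scale values cause: scale_phi=0.0 recovers plain CFG, while scale_phi=1.0 fully matches the guided output's per-channel std to the conditional prediction's.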
src/YingMusicSinger/utils/stable_audio_tools/factory.py
ADDED
@@ -0,0 +1,185 @@
import json


def create_model_from_config(model_config):
    model_type = model_config.get("model_type", None)

    assert model_type is not None, "model_type must be specified in model config"

    if model_type == "autoencoder":
        from .autoencoders import create_autoencoder_from_config

        return create_autoencoder_from_config(model_config)
    elif model_type == "diffusion_uncond":
        from .diffusion import create_diffusion_uncond_from_config

        return create_diffusion_uncond_from_config(model_config)
    elif (
        model_type == "diffusion_cond"
        or model_type == "diffusion_cond_inpaint"
        or model_type == "diffusion_prior"
    ):
        from .diffusion import create_diffusion_cond_from_config

        return create_diffusion_cond_from_config(model_config)
    elif model_type == "diffusion_autoencoder":
        from .autoencoders import create_diffAE_from_config

        return create_diffAE_from_config(model_config)
    elif model_type == "lm":
        from .lm import create_audio_lm_from_config

        return create_audio_lm_from_config(model_config)
    else:
        raise NotImplementedError(f"Unknown model type: {model_type}")


def create_model_from_config_path(model_config_path):
    with open(model_config_path) as f:
        model_config = json.load(f)

    return create_model_from_config(model_config)


def create_pretransform_from_config(pretransform_config, sample_rate):
    pretransform_type = pretransform_config.get("type", None)

    assert pretransform_type is not None, (
        "type must be specified in pretransform config"
    )

    if pretransform_type == "autoencoder":
        from .autoencoders import create_autoencoder_from_config
        from .pretransforms import AutoencoderPretransform

        # Create a fake top-level config to pass the sample rate to the autoencoder constructor.
        # This is a bit of a hack, but it keeps us from re-defining the sample rate in the config.
        autoencoder_config = {
            "sample_rate": sample_rate,
            "model": pretransform_config["config"],
        }
        autoencoder = create_autoencoder_from_config(autoencoder_config)

        scale = pretransform_config.get("scale", 1.0)
        model_half = pretransform_config.get("model_half", False)
        iterate_batch = pretransform_config.get("iterate_batch", False)
        chunked = pretransform_config.get("chunked", False)

        pretransform = AutoencoderPretransform(
            autoencoder,
            scale=scale,
            model_half=model_half,
            iterate_batch=iterate_batch,
            chunked=chunked,
        )
    elif pretransform_type == "wavelet":
        from .pretransforms import WaveletPretransform

        wavelet_config = pretransform_config["config"]
        channels = wavelet_config["channels"]
        levels = wavelet_config["levels"]
        wavelet = wavelet_config["wavelet"]

        pretransform = WaveletPretransform(channels, levels, wavelet)
    elif pretransform_type == "pqmf":
        from .pretransforms import PQMFPretransform

        pqmf_config = pretransform_config["config"]
        pretransform = PQMFPretransform(**pqmf_config)
    elif pretransform_type == "dac_pretrained":
        from .pretransforms import PretrainedDACPretransform

        pretrained_dac_config = pretransform_config["config"]
        pretransform = PretrainedDACPretransform(**pretrained_dac_config)
    elif pretransform_type == "audiocraft_pretrained":
        from .pretransforms import AudiocraftCompressionPretransform

        audiocraft_config = pretransform_config["config"]
        pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
    else:
        raise NotImplementedError(f"Unknown pretransform type: {pretransform_type}")

    enable_grad = pretransform_config.get("enable_grad", False)
    pretransform.enable_grad = enable_grad

    pretransform.eval().requires_grad_(pretransform.enable_grad)

    return pretransform


def create_bottleneck_from_config(bottleneck_config):
    bottleneck_type = bottleneck_config.get("type", None)

    assert bottleneck_type is not None, "type must be specified in bottleneck config"

    if bottleneck_type == "tanh":
        from .bottleneck import TanhBottleneck

        bottleneck = TanhBottleneck()
    elif bottleneck_type == "vae":
        from .bottleneck import VAEBottleneck

        bottleneck = VAEBottleneck()
    elif bottleneck_type == "rvq":
        from .bottleneck import RVQBottleneck

        quantizer_params = {
            "dim": 128,
            "codebook_size": 1024,
            "num_quantizers": 8,
            "decay": 0.99,
            "kmeans_init": True,
            "kmeans_iters": 50,
            "threshold_ema_dead_code": 2,
        }

        quantizer_params.update(bottleneck_config["config"])

        bottleneck = RVQBottleneck(**quantizer_params)
    elif bottleneck_type == "dac_rvq":
        from .bottleneck import DACRVQBottleneck

        bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == "rvq_vae":
        from .bottleneck import RVQVAEBottleneck

        quantizer_params = {
            "dim": 128,
            "codebook_size": 1024,
            "num_quantizers": 8,
            "decay": 0.99,
            "kmeans_init": True,
            "kmeans_iters": 50,
            "threshold_ema_dead_code": 2,
        }

        quantizer_params.update(bottleneck_config["config"])

        bottleneck = RVQVAEBottleneck(**quantizer_params)
    elif bottleneck_type == "dac_rvq_vae":
        from .bottleneck import DACRVQVAEBottleneck

        bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == "l2_norm":
        from .bottleneck import L2Bottleneck

        bottleneck = L2Bottleneck()
    elif bottleneck_type == "wasserstein":
        from .bottleneck import WassersteinBottleneck

        bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
    elif bottleneck_type == "fsq":
        from .bottleneck import FSQBottleneck

        bottleneck = FSQBottleneck(**bottleneck_config["config"])
    else:
        raise NotImplementedError(f"Unknown bottleneck type: {bottleneck_type}")

    requires_grad = bottleneck_config.get("requires_grad", True)
    if not requires_grad:
        for param in bottleneck.parameters():
            param.requires_grad = False

    return bottleneck
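A minimal usage sketch for these factories, assuming the package is importable as YingMusicSinger.utils.stable_audio_tools (the import path and the values inside "config" are illustrative; the accepted keys depend on the create_* functions in the sibling modules):

from YingMusicSinger.utils.stable_audio_tools.factory import (
    create_bottleneck_from_config,
    create_model_from_config_path,
)

# "config" entries are merged over the "rvq" defaults shown above
bottleneck = create_bottleneck_from_config({
    "type": "rvq",
    "config": {"dim": 64, "num_quantizers": 4},
    "requires_grad": False,  # freezes the quantizer parameters after creation
})

# Full models are typically built from a JSON file carrying a "model_type" key:
# model = create_model_from_config_path("model_config.json")

Note the asymmetric defaults: bottlenecks are trainable unless the config sets requires_grad to False, while pretransforms are frozen (enable_grad defaults to False) unless the config opts back in.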
src/YingMusicSinger/utils/stable_audio_tools/pretransforms.py
ADDED
@@ -0,0 +1,425 @@
import torch
from einops import rearrange
from torch import nn


class Pretransform(nn.Module):
    def __init__(self, enable_grad, io_channels, is_discrete):
        super().__init__()

        self.is_discrete = is_discrete
        self.io_channels = io_channels
        self.encoded_channels = None
        self.downsampling_ratio = None

        self.enable_grad = enable_grad

    def encode(self, x):
        raise NotImplementedError

    def decode(self, z):
        raise NotImplementedError

    def tokenize(self, x):
        raise NotImplementedError

    def decode_tokens(self, tokens):
        raise NotImplementedError


class AutoencoderPretransform(Pretransform):
    def __init__(
        self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False
    ):
        super().__init__(
            enable_grad=False,
            io_channels=model.io_channels,
            is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete,
        )
        self.model = model
        self.model.requires_grad_(False).eval()
        self.scale = scale
        self.downsampling_ratio = model.downsampling_ratio
        self.io_channels = model.io_channels
        self.sample_rate = model.sample_rate

        self.model_half = model_half
        self.iterate_batch = iterate_batch

        self.encoded_channels = model.latent_dim
        self.latent_dim = model.latent_dim

        self.chunked = chunked
        self.num_quantizers = (
            model.bottleneck.num_quantizers
            if model.bottleneck is not None and model.bottleneck.is_discrete
            else None
        )
        self.codebook_size = (
            model.bottleneck.codebook_size
            if model.bottleneck is not None and model.bottleneck.is_discrete
            else None
        )

        if self.model_half:
            self.model.half()

    def encode(self, x, **kwargs):
        if self.model_half:
            x = x.half()
            self.model.to(torch.float16)

        encoded = self.model.encode_audio(
            x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs
        )

        if self.model_half:
            encoded = encoded.float()

        return encoded / self.scale

    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
        """
        Encode audio into latents. The audio should already be preprocessed by preprocess_audio_for_encoder.
        If chunked is True, split the audio into chunks of a given maximum size chunk_size, with the given overlap.
        The overlap and chunk_size params are both measured in number of latents (not audio samples),
        so you can likely use the same values with decode_audio.
        An overlap of zero will cause discontinuity artefacts. The overlap should be >= the receptive field size.
        Every autoencoder has a different receptive field size, and thus a different ideal overlap.
        You can determine it empirically by diffing unchunked vs. chunked output and looking at the maximum diff.
        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
        A smaller chunk_size uses less memory but more compute.
        The chunk_size vs. memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
        For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
        """
        if not chunked:
            # Default behavior: encode the entire audio in parallel
            return self.encode(audio, **kwargs)
        else:
            # CHUNKED ENCODING
            # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
            samples_per_latent = self.downsampling_ratio
            total_size = audio.shape[2]  # in samples
            batch_size = audio.shape[0]
            chunk_size *= samples_per_latent  # convert from latents to samples
            overlap *= samples_per_latent  # convert from latents to samples
            hop_size = chunk_size - overlap
            chunks = []
            i = 0  # so the final-chunk check below works even if the loop body never runs
            for i in range(0, total_size - chunk_size + 1, hop_size):
                chunk = audio[:, :, i : i + chunk_size]
                chunks.append(chunk)
            if i + chunk_size != total_size:
                # Final chunk
                chunk = audio[:, :, -chunk_size:]
                chunks.append(chunk)
            chunks = torch.stack(chunks)
            num_chunks = chunks.shape[0]
            # Note: y_size might be a different value from the latent length used in diffusion training
            # because we can encode audio of varying lengths.
            # However, the audio should've been padded to a multiple of samples_per_latent by now.
            y_size = total_size // samples_per_latent
            # Create an empty latent; we will populate it with chunks as we encode them
            y_final = torch.zeros((batch_size, self.latent_dim, y_size)).to(
                audio.device
            )
            for i in range(num_chunks):
                x_chunk = chunks[i, :]
                # Encode the chunk
                y_chunk = self.encode(x_chunk)
                # Figure out where to put the latents along the time domain
                if i == num_chunks - 1:
                    # The final chunk always goes at the end
                    t_end = y_size
                    t_start = t_end - y_chunk.shape[2]
                else:
                    t_start = i * hop_size // samples_per_latent
                    t_end = t_start + chunk_size // samples_per_latent
                # Remove the edges of the overlaps
                ol = overlap // samples_per_latent // 2
                chunk_start = 0
                chunk_end = y_chunk.shape[2]
                if i > 0:
                    # No overlap trim for the start of the first chunk
                    t_start += ol
                    chunk_start += ol
                if i < num_chunks - 1:
                    # No overlap trim for the end of the last chunk
                    t_end -= ol
                    chunk_end -= ol
                # Paste the chunked latents into the y_final output
                y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
            return y_final

    def decode(self, z, **kwargs):
        z = z * self.scale

        if self.model_half:
            z = z.half()
            self.model.to(torch.float16)

        decoded = self.model.decode_audio(
            z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs
        )

        if self.model_half:
            decoded = decoded.float()

        return decoded

    def decode_audio(
        self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs
    ):
        if not chunked:
            # Default behavior: decode the entire latent in parallel
            return self.decode(latents, **kwargs)
        else:
            # Chunked decoding
            hop_size = chunk_size - overlap
            total_size = latents.shape[2]
            batch_size = latents.shape[0]
            chunks = []
            i = 0  # so the final-chunk check below works even if the loop body never runs
            for i in range(0, total_size - chunk_size + 1, hop_size):
                chunk = latents[:, :, i : i + chunk_size]
                chunks.append(chunk)
            if i + chunk_size != total_size:
                # Final chunk
                chunk = latents[:, :, -chunk_size:]
                chunks.append(chunk)
            chunks = torch.stack(chunks)
            num_chunks = chunks.shape[0]
            # samples_per_latent is just the downsampling ratio
            samples_per_latent = self.downsampling_ratio
            # Create an empty waveform; we will populate it with chunks as we decode them
            y_size = total_size * samples_per_latent
            y_final = torch.zeros((batch_size, self.io_channels, y_size)).to(
                latents.device
            )
            for i in range(num_chunks):
                x_chunk = chunks[i, :]
                # Decode the chunk
                y_chunk = self.decode(x_chunk)
                # Figure out where to put the audio along the time domain
                if i == num_chunks - 1:
                    # The final chunk always goes at the end
                    t_end = y_size
                    t_start = t_end - y_chunk.shape[2]
                else:
                    t_start = i * hop_size * samples_per_latent
                    t_end = t_start + chunk_size * samples_per_latent
                # Remove the edges of the overlaps
                ol = (overlap // 2) * samples_per_latent
                chunk_start = 0
                chunk_end = y_chunk.shape[2]
                if i > 0:
                    # No overlap trim for the start of the first chunk
                    t_start += ol
                    chunk_start += ol
                if i < num_chunks - 1:
                    # No overlap trim for the end of the last chunk
                    t_end -= ol
                    chunk_end -= ol
                # Paste the chunked audio into the y_final output
                y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
            return y_final

    def tokenize(self, x, **kwargs):
        assert self.model.is_discrete, "Cannot tokenize with a continuous model"

        _, info = self.model.encode(x, return_info=True, **kwargs)

        return info[self.model.bottleneck.tokens_id]

    def decode_tokens(self, tokens, **kwargs):
        assert self.model.is_discrete, "Cannot decode tokens with a continuous model"

        return self.model.decode_tokens(tokens, **kwargs)

    def load_state_dict(self, state_dict, strict=True):
        self.model.load_state_dict(state_dict, strict=strict)


class WaveletPretransform(Pretransform):
    def __init__(self, channels, levels, wavelet):
        super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)

        from .wavelets import WaveletDecode1d, WaveletEncode1d

        self.encoder = WaveletEncode1d(channels, levels, wavelet)
        self.decoder = WaveletDecode1d(channels, levels, wavelet)

        self.downsampling_ratio = 2**levels
        self.io_channels = channels
        self.encoded_channels = channels * self.downsampling_ratio

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z):
        return self.decoder(z)


class PQMFPretransform(Pretransform):
    def __init__(self, attenuation=100, num_bands=16):
        # TODO: Fix PQMF to take in in-channels
        super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
        from .pqmf import PQMF

        self.pqmf = PQMF(attenuation, num_bands)

    def encode(self, x):
        # x is (Batch x Channels x Time)
        x = self.pqmf.forward(x)
        # pqmf.forward returns (Batch x Channels x Bands x Time),
        # but Pretransform needs (Batch x Channels x Time),
        # so concatenate channels and bands into one axis
        return rearrange(x, "b c n t -> b (c n) t")

    def decode(self, x):
        # x is (Batch x (Channels Bands) x Time); convert back to (Batch x Channels x Bands x Time)
        x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
        # returns (Batch x Channels x Time)
        return self.pqmf.inverse(x)


class PretrainedDACPretransform(Pretransform):
    def __init__(
        self,
        model_type="44khz",
        model_bitrate="8kbps",
        scale=1.0,
        quantize_on_decode: bool = True,
        chunked=True,
    ):
        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)

        import dac

        model_path = dac.utils.download(
            model_type=model_type, model_bitrate=model_bitrate
        )

        self.model = dac.DAC.load(model_path)

        self.quantize_on_decode = quantize_on_decode

        if model_type == "44khz":
            self.downsampling_ratio = 512
        else:
            self.downsampling_ratio = 320

        self.io_channels = 1

        self.scale = scale

        self.chunked = chunked

        self.encoded_channels = self.model.latent_dim

        self.num_quantizers = self.model.n_codebooks

        self.codebook_size = self.model.codebook_size

    def encode(self, x):
        latents = self.model.encoder(x)

        if self.quantize_on_decode:
            output = latents
        else:
            z, _, _, _, _ = self.model.quantizer(
                latents, n_quantizers=self.model.n_codebooks
            )
            output = z

        if self.scale != 1.0:
            output = output / self.scale

        return output

    def decode(self, z):
        if self.scale != 1.0:
            z = z * self.scale

        if self.quantize_on_decode:
            z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)

        return self.model.decode(z)

    def tokenize(self, x):
        return self.model.encode(x)[1]

    def decode_tokens(self, tokens):
        latents = self.model.quantizer.from_codes(tokens)
        return self.model.decode(latents)


class AudiocraftCompressionPretransform(Pretransform):
    def __init__(
        self,
        model_type="facebook/encodec_32khz",
        scale=1.0,
        quantize_on_decode: bool = True,
    ):
        super().__init__(enable_grad=False, io_channels=1, is_discrete=True)

        try:
            from audiocraft.models import CompressionModel
        except ImportError:
            raise ImportError(
                "Audiocraft is not installed. Please install audiocraft to use Audiocraft models."
            )

        self.model = CompressionModel.get_pretrained(model_type)

        self.quantize_on_decode = quantize_on_decode

        self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)

        self.sample_rate = self.model.sample_rate

        self.io_channels = self.model.channels

        self.scale = scale

        # self.encoded_channels = self.model.latent_dim

        self.num_quantizers = self.model.num_codebooks

        self.codebook_size = self.model.cardinality

        self.model.to(torch.float16).eval().requires_grad_(False)

    def encode(self, x):
        # Raise instead of `assert False`, which silently disappears under `python -O`
        raise NotImplementedError(
            "Audiocraft compression models do not support continuous encoding"
        )

    def decode(self, z):
        raise NotImplementedError(
            "Audiocraft compression models do not support continuous decoding"
        )

    def tokenize(self, x):
        with torch.cuda.amp.autocast(enabled=False):
            return self.model.encode(x.to(torch.float16))[0]

    def decode_tokens(self, tokens):
        with torch.cuda.amp.autocast(enabled=False):
            return self.model.decode(tokens)
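The overlap bookkeeping shared by encode_audio and decode_audio is easiest to check with concrete numbers. The standalone helper below (a hypothetical illustration, not part of the module) reproduces the placement arithmetic in pure Python, working in latent units and assuming total_size >= chunk_size:

def chunk_spans(total_size, chunk_size=128, overlap=32):
    # Returns the [t_start, t_end) slot each chunk is pasted into,
    # mirroring the loop logic in encode_audio/decode_audio
    hop = chunk_size - overlap
    starts = list(range(0, total_size - chunk_size + 1, hop))
    if starts[-1] + chunk_size != total_size:
        starts.append(total_size - chunk_size)  # final chunk anchored at the end
    ol = overlap // 2
    spans = []
    for i, s in enumerate(starts):
        t_start, t_end = s, s + chunk_size
        if i > 0:
            t_start += ol  # trim half the overlap from interior left edges
        if i < len(starts) - 1:
            t_end -= ol  # and from interior right edges
        spans.append((t_start, t_end))
    return spans

print(chunk_spans(300))  # [(0, 112), (112, 208), (188, 300)]

Interior chunks meet exactly in the middle of their shared overlap, so each latent is written once; only the final chunk may overwrite a longer stretch, since it is re-anchored to the end of the signal to keep chunk_size constant.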