| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
|
|
| import numpy as np |
| import torch |
| import torchaudio as ta |
| from functools import lru_cache |
| from typing import Optional |
|
|
| from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer |
| from .const import S3GEN_SR |
| from .flow import CausalMaskedDiffWithXvec |
| from .xvector import CAMPPlus |
| from .utils.mel import mel_spectrogram |
| from .f0_predictor import ConvRNNF0Predictor |
| from .hifigan import HiFTGenerator |
| from .transformer.upsample_encoder import UpsampleConformerEncoder |
| from .flow_matching import CausalConditionalCFM |
| from .decoder import ConditionalDecoder |
| from .configs import CFM_PARAMS |
|
|
|
|
def drop_invalid_tokens(x):
    """Filter out special/invalid tokens (id >= SPEECH_VOCAB_SIZE) from a token tensor.

    Accepts a 1-D token sequence `[T]` or a 2-D batch `[B=1, T]`; only batch
    size one is supported for now. Returns a flattened 1-D tensor containing
    only the valid speech-token ids, in order.

    Bug fix: the original asserted `x.shape[0] == 1` even for 1-D input, where
    shape[0] is the sequence length — rejecting any 1-D sequence longer than
    one token. The batch-size check now applies only to 2-D input.
    """
    assert x.dim() <= 2, "expected a 1D token sequence or a 2D [B, T] batch"
    if x.dim() == 2:
        assert x.shape[0] == 1, "only batch size of one allowed for now"
    # Boolean-mask indexing flattens the result to 1-D and keeps token order.
    return x[x < SPEECH_VOCAB_SIZE]
|
|
|
|
| |
@lru_cache(100)
def get_resampler(src_sr, dst_sr, device):
    """Return a `torchaudio` resampler from `src_sr` to `dst_sr` on `device`.

    Cached (up to 100 distinct (src, dst, device) triples) so repeated calls
    reuse the same transform instead of rebuilding its filter kernel.
    """
    resampler = ta.transforms.Resample(orig_freq=src_sr, new_freq=dst_sr)
    return resampler.to(device)
|
|
|
|
class S3Token2Mel(torch.nn.Module):
    """
    CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.

    Pipeline: S3 tokens -> UpsampleConformerEncoder -> causal conditional CFM
    (flow matching) -> 80-bin mel frames, conditioned on a speaker x-vector
    and reference mel prompt extracted from a reference waveform.

    TODO: make these modules configurable?
    """
    def __init__(self):
        super().__init__()
        self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
        # mel_extractor is a plain function, not a registered submodule.
        self.mel_extractor = mel_spectrogram
        self.speaker_encoder = CAMPPlus()

        # Upsamples 25 Hz speech tokens toward the mel frame rate.
        encoder = UpsampleConformerEncoder(
            output_size=512,
            attention_heads=8,
            linear_units=2048,
            num_blocks=6,
            dropout_rate=0.1,
            positional_dropout_rate=0.1,
            attention_dropout_rate=0.1,
            normalize_before=True,
            input_layer='linear',
            pos_enc_layer_type='rel_pos_espnet',
            selfattention_layer_type='rel_selfattn',
            input_size=512,
            use_cnn_module=False,
            macaron_style=False,
        )

        # CFM estimator network producing 80-bin mel frames.
        estimator = ConditionalDecoder(
            in_channels=320,
            out_channels=80,
            causal=True,
            channels=[256],
            dropout=0.0,
            attention_head_dim=64,
            n_blocks=4,
            num_mid_blocks=12,
            num_heads=8,
            act_fn='gelu',
        )
        cfm_params = CFM_PARAMS
        decoder = CausalConditionalCFM(
            spk_emb_dim=80,
            cfm_params=cfm_params,
            estimator=estimator,
        )

        self.flow = CausalMaskedDiffWithXvec(
            encoder=encoder,
            decoder=decoder,
        )

        # NOTE(review): appears unused — resampling goes through the
        # module-level get_resampler() cache. Kept for backward compatibility.
        self.resamplers = {}

    @property
    def device(self):
        # Report the device of the first tokenizer parameter; submodule
        # parameters track the whole module after .to(...)/.cuda().
        params = self.tokenizer.parameters()
        return next(params).device

    def embed_ref(
        self,
        ref_wav: torch.Tensor,
        ref_sr: int,
        device="auto",
        ref_fade_out=True,
    ):
        """
        Build the reference-conditioning dict from a reference waveform.

        Args
        ----
        - `ref_wav`: reference waveform, `[T]` or `[B=1, T]` (tensor or ndarray).
        - `ref_sr`: sample rate of `ref_wav`.
        - `device`: target device, or `"auto"` to use this module's device.
        - `ref_fade_out`: NOTE(review): currently unused; kept for interface
          compatibility.

        Returns a dict with `prompt_token`, `prompt_token_len`, `prompt_feat`,
        `prompt_feat_len` (always None here), and `embedding` (speaker x-vector),
        suitable for splatting into `self.flow.inference(...)`.
        """
        device = self.device if device == "auto" else device
        if isinstance(ref_wav, np.ndarray):
            ref_wav = torch.from_numpy(ref_wav).float()

        if ref_wav.device != device:
            ref_wav = ref_wav.to(device)

        if len(ref_wav.shape) == 1:
            ref_wav = ref_wav.unsqueeze(0)  # (B, L)

        if ref_wav.size(1) > 10 * ref_sr:
            # Consistency fix: use the logging module like the mel/token
            # mismatch warning below, rather than a bare print().
            logging.warning("cosydec received ref longer than 10s")

        # Mel extraction assumes a 24 kHz signal; resample if needed.
        ref_wav_24 = ref_wav
        if ref_sr != S3GEN_SR:
            ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)

        ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
        ref_mels_24_len = None

        # Speaker encoder and tokenizer both consume 16 kHz audio.
        ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)

        # Speaker x-vector embedding.
        ref_x_vector = self.speaker_encoder.inference(ref_wav_16)

        # Tokenize the 16 kHz reference.
        ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)

        # The flow expects mel_len == 2 * token_len (tokens run at half the
        # mel frame rate); trim tokens to match if the ratio is off.
        if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
            logging.warning(
                "Reference mel length is not equal to 2 * reference token length.\n"
            )
            ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2]
            ref_speech_token_lens[0] = ref_speech_tokens.shape[1]

        return dict(
            prompt_token=ref_speech_tokens.to(device),
            prompt_token_len=ref_speech_token_lens,
            prompt_feat=ref_mels_24,
            prompt_feat_len=ref_mels_24_len,
            embedding=ref_x_vector,
        )

    def forward(
        self,
        speech_tokens: torch.LongTensor,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor],
        ref_sr: Optional[int],
        # pre-computed ref embedding (mutex with ref_wav)
        ref_dict: Optional[dict] = None,
        finalize: bool = False,
    ):
        """
        Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from.

        NOTE:
        - The speaker encoder accepts 16 kHz waveform.
        - S3TokenizerV2 accepts 16 kHz waveform.
        - The mel-spectrogram for the reference assumes 24 kHz input signal.
        - This function is designed for batch_size=1 only.

        Args
        ----
        - `speech_tokens`: S3 speech tokens [B=1, T]
        - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T])
        - `ref_sr`: reference sample rate
        - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored.
        """
        assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})"

        if ref_dict is None:
            ref_dict = self.embed_ref(ref_wav, ref_sr)
        else:
            # Migrate a pre-computed ref_dict (possibly deserialized as numpy)
            # onto this module's device.
            for rk in list(ref_dict):
                if isinstance(ref_dict[rk], np.ndarray):
                    ref_dict[rk] = torch.from_numpy(ref_dict[rk])
                if torch.is_tensor(ref_dict[rk]):
                    ref_dict[rk] = ref_dict[rk].to(self.device)

        if len(speech_tokens.shape) == 1:
            speech_tokens = speech_tokens.unsqueeze(0)

        # assume batch size 1
        speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)

        output_mels, _ = self.flow.inference(
            token=speech_tokens,
            token_len=speech_token_lens,
            finalize=finalize,
            **ref_dict,
        )
        return output_mels
|
|
|
|
class S3Token2Wav(S3Token2Mel):
    """
    The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.

    TODO: make these modules configurable?
    """

    def __init__(self):
        super().__init__()

        f0_predictor = ConvRNNF0Predictor()
        self.mel2wav = HiFTGenerator(
            sampling_rate=S3GEN_SR,
            upsample_rates=[8, 5, 3],
            upsample_kernel_sizes=[16, 11, 7],
            source_resblock_kernel_sizes=[7, 7, 11],
            source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            f0_predictor=f0_predictor,
        )

        # Fade-in window to silence the vocoder's warm-up artifact: the first
        # n_trim samples are fully muted, followed by a raised-cosine ramp to 1.
        # Registered as a (non-persistent) buffer so it follows .to(device).
        n_trim = S3GEN_SR // 50  # 20 ms
        trim_fade = torch.zeros(2 * n_trim)
        trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
        self.register_buffer("trim_fade", trim_fade, persistent=False)

    def forward(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor],
        ref_sr: Optional[int],
        # pre-computed ref embedding (mutex with ref_wav)
        ref_dict: Optional[dict] = None,
        finalize: bool = False
    ):
        """Tokens -> mel (CFM, via parent) -> waveform (HiFiGAN), batch=1.

        See `S3Token2Mel.forward` for argument semantics. Returns `[B=1, T]`
        waveform samples at 24 kHz.
        """
        output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)

        # Empty HiFT cache (non-streaming path); allocate directly on device
        # rather than allocating on CPU and moving.
        hift_cache_source = torch.zeros(1, 1, 0, device=self.device)

        output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source)

        if not self.training:
            # Mask the vocoder warm-up artifact at inference time only, so
            # training gradients are unaffected.
            output_wavs[:, :len(self.trim_fade)] *= self.trim_fade

        return output_wavs

    @torch.inference_mode()
    def flow_inference(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (mutex with ref_wav)
        ref_dict: Optional[dict] = None,
        finalize: bool = False,
    ):
        """Run only the token-to-mel (CFM) stage; returns mel frames."""
        return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)

    @torch.inference_mode()
    def hift_inference(self, speech_feat, cache_source: Optional[torch.Tensor] = None):
        """Run only the mel-to-waveform (HiFiGAN) stage.

        `cache_source` carries streaming vocoder state; defaults to an empty
        cache. Returns whatever `HiFTGenerator.inference` returns
        (waveform plus source/state tensors).
        """
        if cache_source is None:
            cache_source = torch.zeros(1, 1, 0, device=self.device)
        return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)

    @torch.inference_mode()
    def inference(
        self,
        speech_tokens,
        # locally-computed ref embedding (mutex with ref_dict)
        ref_wav: Optional[torch.Tensor] = None,
        ref_sr: Optional[int] = None,
        # pre-computed ref embedding (mutex with ref_wav)
        ref_dict: Optional[dict] = None,
        cache_source: Optional[torch.Tensor] = None,
        finalize: bool = True,
    ):
        """Full inference: tokens -> mel -> waveform, with optional streaming cache.

        Returns `(output_wavs, output_sources)` from the vocoder, with the
        warm-up fade-in applied to the start of the waveform.
        """
        output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
        output_wavs, output_sources = self.hift_inference(output_mels, cache_source)

        # Mask the vocoder warm-up artifact.
        output_wavs[:, :len(self.trim_fade)] *= self.trim_fade

        return output_wavs, output_sources
|
|