---
language:
- en
tags:
- speech
license:
- cc-by-sa-3.0
- cc-by-4.0
---

# Wave-Trainer-Fit | Neural vocoder from SSL features

[[Code of Wave-Trainer-Fit](https://github.com/line/WaveTrainerFit)] [[Audio samples](https://i17oonaka-h.github.io/projects/research_topics/wave_trainer_fit/)]

> **Abstract:**<br>
> We propose WaveTrainerFit, a neural vocoder that performs high-quality waveform generation from data-driven features such as SSL features. WaveTrainerFit builds upon the WaveFit vocoder, which integrates a diffusion model with a generative adversarial network, and incorporates two key improvements: 1. by introducing trainable priors, the inference process starts from noise close to the target speech instead of Gaussian noise; 2. reference-aware gain adjustment is performed by constraining the trainable prior to match the speech energy. These improvements reduce the complexity of waveform modeling from data-driven features, enabling high-quality waveform generation with fewer inference steps. Our experiments show that WaveTrainerFit generates highly natural waveforms with improved speaker similarity from data-driven features while requiring fewer iterations than WaveFit, and that it is robust to the depth at which the SSL features are extracted.
This repository provides pre-trained models and their optimizers.
The models were pre-trained on [LibriTTS-R](https://www.openslr.org/141/) (train-clean-360).
They reconstruct 24 kHz audio from SSL features extracted from 16 kHz audio.

## Pre-trained model list
> [!IMPORTANT]
> ⚠️ **License Notice**: The model weights provided in this repository are licensed under different terms: the `xlsr*` and `whisper*` models are licensed differently from the `wavlm*` models. Please refer to the [License](#license) section for details.

The list of available models is as follows:

| Model name | Conditional features | Layer num | #iters of model |
| :--- | :---: | :---: | :---: |
| wavlm2_wavetrainerfit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 2 | 5 |
| wavlm2_wavefit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 2 | 5 |
| wavlm8_wavetrainerfit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 8 | 5 |
| wavlm8_wavefit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 8 | 5 |
| wavlm24_wavetrainerfit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 24 | 5 |
| wavlm24_wavefit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 24 | 5 |
| xlsr8_wavetrainerfit5 | [XLS-R-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 8 | 5 |
| xlsr8_wavefit5 | [XLS-R-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 8 | 5 |
| whisper8_wavetrainerfit5 | ※ [Whisper-medium](https://huggingface.co/openai/whisper-medium) | 8 | 5 |
| whisper8_wavefit5 | ※ [Whisper-medium](https://huggingface.co/openai/whisper-medium) | 8 | 5 |

※ In our verification, we found that amplitude decay occurs in Whisper features after roughly 2.0 seconds.
During evaluation, our model therefore processed inputs by splitting them into 2.0-second segments, extracting features for each segment with the Whisper encoder, recombining the features, and resynthesizing the waveform.
If you use these models in your application, the upstream feature extraction must follow the same flow.
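The segmented extraction flow described above can be sketched as follows. This is an illustrative helper, not part of the released package: `encode_fn` stands in for a call to the Whisper encoder on one segment, and the frame count per segment in the example below assumes a 20 ms effective hop (320 samples at 16 kHz), which is an assumption for illustration only.

```python
import numpy as np

SAMPLE_RATE = 16000                        # SSL features are extracted from 16 kHz audio
SEGMENT_SAMPLES = int(2.0 * SAMPLE_RATE)   # 2.0-second segments

def extract_segmented_features(waveform, encode_fn, segment_samples=SEGMENT_SAMPLES):
    """Split a 1-D waveform into fixed-length segments, encode each segment
    independently, and concatenate the per-segment features along time."""
    segments = [
        waveform[start:start + segment_samples]
        for start in range(0, len(waveform), segment_samples)
    ]
    chunks = [encode_fn(segment) for segment in segments]
    return np.concatenate(chunks, axis=0)
```

In practice, `encode_fn` would run the Whisper encoder on each segment and return the hidden states of the chosen layer; the concatenated features are then passed to the vocoder exactly as in the Usage section below.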

## Usage
Please refer to [our GitHub repository](https://github.com/line/WaveTrainerFit) for instructions on how to use the models provided here.
The following is reproduced from the GitHub repository:
```python
import torch
import torchaudio
from transformers import AutoFeatureExtractor, WavLMModel

from wavetrainerfit import load_pretrained_vocoder

# Upstream SSL feature extractor (expects 16 kHz input).
ssl_preprocessor = AutoFeatureExtractor.from_pretrained('microsoft/wavlm-large')
ssl_model: WavLMModel = WavLMModel.from_pretrained('microsoft/wavlm-large')

# Load the vocoder conditioned on WavLM layer-2 features.
layer = 2
ssl_vocoder, cfg = load_pretrained_vocoder(f'wavlm{layer}_wavetrainerfit5')

waveform, sr = torchaudio.load('./assets/ljspeech-samples/LJ037-0171.wav')
if sr != 16000:
    waveform = torchaudio.transforms.Resample(
        orig_freq=sr,
        new_freq=16000,
    )(waveform)
inputs = ssl_preprocessor(
    waveform[0].numpy(),
    sampling_rate=16000,
    return_tensors="pt",
)

with torch.no_grad():
    inputs = ssl_model(**inputs, output_hidden_states=True)
    inputs = inputs.hidden_states[layer]  # (Batch, Timeframe, Featuredim)
    generated_waveform = ssl_vocoder.pred(
        conditional_feature=inputs,  # (Batch, Timeframe, Featuredim)
        T_=5,  # number of inference iterations
    )

# The vocoder outputs 24 kHz audio.
torchaudio.save(
    './assets/ljspeech-samples/LJ037-0171-reconstructed.wav',
    generated_waveform[-1][:, 0].cpu(),
    24000,
)
```
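Each vocoder in the model list expects features from a specific upstream checkpoint and layer. The mapping from the table can be encoded in a small lookup; this helper is purely illustrative and not part of the released package, but the checkpoint names and layer indices are taken directly from the table above.

```python
# Illustrative mapping from this repository's model-name prefixes to the
# upstream SSL checkpoint and hidden-state layer they expect.
MODEL_CONFIGS = {
    "wavlm2": ("microsoft/wavlm-large", 2),
    "wavlm8": ("microsoft/wavlm-large", 8),
    "wavlm24": ("microsoft/wavlm-large", 24),
    "xlsr8": ("facebook/wav2vec2-xls-r-300m", 8),
    "whisper8": ("openai/whisper-medium", 8),
}

def upstream_for(model_name: str) -> tuple:
    """Return (checkpoint, layer) for a name such as 'xlsr8_wavefit5'."""
    prefix = model_name.split("_", 1)[0]
    return MODEL_CONFIGS[prefix]
```

With this lookup, the Usage example above generalizes to the other models by swapping in the returned checkpoint for the feature extractor and the returned layer for `hidden_states[layer]` (for the `whisper*` models, remember the 2.0-second segmentation flow noted earlier).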

## License

### Model Weights

**`xlsr*` and `whisper*` models**: Licensed under [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/)
- Based on XLS-R by Facebook (Apache 2.0) - https://huggingface.co/facebook/wav2vec2-xls-r-300m
- Based on Whisper by OpenAI (Apache 2.0) - https://huggingface.co/openai/whisper-medium

**`wavlm*` models**: Licensed under [CC BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0/)
- Based on WavLM by Microsoft Corporation ([CC BY-SA 3.0](https://github.com/microsoft/UniSpeech/blob/main/LICENSE)) - https://huggingface.co/microsoft/wavlm-large
- ⚠️ Derivative works must also use CC BY-SA 3.0

**Training data**: LibriTTS-R (CC BY 4.0) - https://www.openslr.org/141/

When using these models, you must comply with both our license and the original upstream model licenses.