---
language:
- en
tags:
- speech
license:
- cc-by-sa-3.0
- cc-by-4.0
---

# Wave-Trainer-Fit | Neural vocoder from SSL features

[[Code of Wave-Trainer-Fit](https://github.com/line/WaveTrainerFit)] [[Audio samples](https://i17oonaka-h.github.io/projects/research_topics/wave_trainer_fit/)]

> **Abstract:**
> We propose WaveTrainerFit, a neural vocoder that generates high-quality waveforms from data-driven features such as SSL features. WaveTrainerFit builds upon the WaveFit vocoder, which integrates a diffusion model and a generative adversarial network, and incorporates the following key improvements:
> 1. By introducing trainable priors, the inference process starts from noise close to the target speech instead of Gaussian noise.
> 2. Reference-aware gain adjustment is performed by constraining the trainable prior to match the speech energy.
>
> These improvements are expected to reduce the complexity of waveform modeling from data-driven features, enabling high-quality waveform generation with fewer inference steps. Our experiments show that WaveTrainerFit generates highly natural waveforms with improved speaker similarity from data-driven features while requiring fewer iterations than WaveFit. Moreover, we show that the proposed method is robust to the depth at which the SSL features are extracted.

![concept.png](./assets/concept.png)

This repository provides pre-trained models and their optimizers. The models were pre-trained on [LibriTTS-R](https://www.openslr.org/141/) (train-clean-360). They reconstruct 24 kHz audio from SSL features extracted from 16 kHz audio.

## Pre-trained model list

> [!IMPORTANT]
> ⚠️ **License Notice**: The model weights provided in this repository are licensed under different terms: the `xlsr*` and `whisper*` models are licensed differently from the `wavlm*` models. Please refer to the [License](#license) section for details.
The list of available models is as follows:

| Model name | Conditional features | Layer num | #iters of model |
|:-----------|:--------------------:|:---------:|:---------------:|
| wavlm2_wavetrainerfit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 2 | 5 |
| wavlm2_wavefit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 2 | 5 |
| wavlm8_wavetrainerfit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 8 | 5 |
| wavlm8_wavefit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 8 | 5 |
| wavlm24_wavetrainerfit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 24 | 5 |
| wavlm24_wavefit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 24 | 5 |
| xlsr8_wavetrainerfit5 | [XLS-R-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 8 | 5 |
| xlsr8_wavefit5 | [XLS-R-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 8 | 5 |
| whisper8_wavetrainerfit5 | ※ [Whisper-medium](https://huggingface.co/openai/whisper-medium) | 8 | 5 |
| whisper8_wavefit5 | ※ [Whisper-medium](https://huggingface.co/openai/whisper-medium) | 8 | 5 |

※ In our verification, we found that amplitude decay occurs in Whisper features after about 2.0 seconds. During evaluation, our model therefore processed inputs as follows: `divide into 2.0-second segments → extract features with the Whisper encoder → recombine → resynthesize`. If you use these models in your application, the upstream feature extraction must follow the same flow.

## Usage

Please refer to [our GitHub repository](https://github.com/line/WaveTrainerFit) for instructions on how to use the models provided here.
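For the `whisper*` models, the segment-then-recombine flow described above can be sketched as follows. This is a minimal illustration, not the repository's implementation; `whisper_encoder` and `encode_in_segments` are hypothetical names, standing in for any callable that maps a `(1, T)` 16 kHz waveform chunk to features of shape `(1, T', D)`:

```python
import torch

def encode_in_segments(waveform: torch.Tensor, whisper_encoder,
                       sr: int = 16000, segment_sec: float = 2.0) -> torch.Tensor:
    """Encode a waveform in fixed-length segments and recombine the features."""
    seg_len = int(segment_sec * sr)  # 2.0 s -> 32000 samples at 16 kHz
    feats = []
    for start in range(0, waveform.shape[-1], seg_len):
        chunk = waveform[..., start:start + seg_len]
        feats.append(whisper_encoder(chunk))  # encode each segment separately
    return torch.cat(feats, dim=1)  # recombine along the time axis
```

The recombined features can then be passed to the vocoder in place of features extracted from the full utterance at once.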
The following example, adapted from the GitHub repository, reconstructs a waveform from WavLM features:

```python
import torch
import torchaudio
from transformers import AutoFeatureExtractor, WavLMModel

from wavetrainerfit import load_pretrained_vocoder

# Load the SSL feature extractor and the pre-trained vocoder.
ssl_preprocessor = AutoFeatureExtractor.from_pretrained('microsoft/wavlm-large')
ssl_model: WavLMModel = WavLMModel.from_pretrained('microsoft/wavlm-large')
layer = 2
ssl_vocoder, cfg = load_pretrained_vocoder(f'wavlm{layer}_wavetrainerfit5')

# Load audio and resample to 16 kHz for SSL feature extraction.
waveform, sr = torchaudio.load('./assets/ljspeech-samples/LJ037-0171.wav')
if sr != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)

# Extract SSL features from the specified layer.
inputs = ssl_preprocessor(
    waveform[0].numpy(),
    sampling_rate=16000,
    return_tensors="pt"
)
with torch.no_grad():
    outputs = ssl_model(**inputs, output_hidden_states=True)
features = outputs.hidden_states[layer]  # (Batch, Timeframe, Featuredim)

# Generate and save the 24 kHz waveform.
generated_waveform = ssl_vocoder.pred(
    conditional_feature=features,  # (Batch, Timeframe, Featuredim)
    T_=5  # number of iterations
)
torchaudio.save(
    './assets/ljspeech-samples/LJ037-0171-reconstructed.wav',
    generated_waveform[-1][:, 0].cpu(),
    24000
)
```

## License

### Model Weights

**`xlsr*` and `whisper*` models**: Licensed under [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/)
- Based on XLS-R by Facebook (Apache 2.0)
  - https://huggingface.co/facebook/wav2vec2-xls-r-300m
- Based on Whisper by OpenAI (Apache 2.0)
  - https://huggingface.co/openai/whisper-medium

**`wavlm*` models**: Licensed under [CC BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0/)
- Based on WavLM by Microsoft Corporation ([CC BY-SA 3.0](https://github.com/microsoft/UniSpeech/blob/main/LICENSE))
  - https://huggingface.co/microsoft/wavlm-large
  - ⚠️ Derivative works must also use CC BY-SA 3.0

**Training data**: LibriTTS-R (CC BY 4.0)
- https://www.openslr.org/141/

When using these models, you must comply with both our license and the original upstream model licenses.