WaveTrainerFit | Neural vocoder from SSL features
[Code of WaveTrainerFit] | [Audio samples]
Abstract:
We propose WaveTrainerFit, a neural vocoder that performs high-quality waveform generation from data-driven features such as SSL features. WaveTrainerFit builds upon the WaveFit vocoder, which integrates a diffusion model and a generative adversarial network. The proposed method incorporates the following key improvements: (1) by introducing trainable priors, the inference process starts from noise close to the target speech instead of Gaussian noise; (2) reference-aware gain adjustment is performed by imposing constraints on the trainable prior so that it matches the speech energy. These improvements are expected to reduce the complexity of waveform modeling from data-driven features, enabling high-quality waveform generation with fewer inference steps. Through experiments, we show that WaveTrainerFit can generate highly natural waveforms with improved speaker similarity from data-driven features, while requiring fewer iterations than WaveFit. Moreover, we show that the proposed method works robustly with respect to the depth at which the SSL features are extracted.
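As a rough conceptual illustration of this inference flow (not the actual implementation, which is available in the linked code repository), the trainable prior, gain adjustment, and iterative refinement can be sketched as follows; `prior_net`, `refine_step`, and `reference_rms` are hypothetical placeholders:

```python
import torch

# Conceptual sketch only: all names below are hypothetical placeholders,
# not the real WaveTrainerFit API.
def wavetrainerfit_inference(ssl_features, prior_net, refine_step,
                             reference_rms, num_iters=5):
    # Trainable prior: start from a signal close to the target speech
    # instead of Gaussian noise.
    y = prior_net(ssl_features)
    # Reference-aware gain adjustment: scale the prior so that its energy
    # matches the reference speech energy.
    y = y * reference_rms / y.pow(2).mean().sqrt().clamp_min(1e-8)
    # WaveFit-style iterative refinement over a small number of steps.
    for t in range(num_iters):
        y = refine_step(y, ssl_features, t)
    return y
```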
This repository provides pre-trained models and their optimizer states. The models were pre-trained on LibriTTS-R (train-clean-360), and they reconstruct 24 kHz audio from SSL features extracted from 16 kHz audio.
Pre-trained model list
⚠️ License Notice: The model weights provided in this repository are licensed under different terms. The `xlsr*` and `whisper*` models are licensed differently from the `wavlm*` models. Please refer to the License section for details.
The list of available models is as follows:
| Model name | Conditioning features | SSL layer | Iterations |
|---|---|---|---|
| wavlm2_wavetrainerfit5 | WavLM-large | 2 | 5 |
| wavlm2_wavefit5 | WavLM-large | 2 | 5 |
| wavlm8_wavetrainerfit5 | WavLM-large | 8 | 5 |
| wavlm8_wavefit5 | WavLM-large | 8 | 5 |
| wavlm24_wavetrainerfit5 | WavLM-large | 24 | 5 |
| wavlm24_wavefit5 | WavLM-large | 24 | 5 |
| xlsr8_wavetrainerfit5 | XLS-R-300m | 8 | 5 |
| xlsr8_wavefit5 | XLS-R-300m | 8 | 5 |
| whisper8_wavetrainerfit5 | ※ Whisper-medium | 8 | 5 |
| whisper8_wavefit5 | ※ Whisper-medium | 8 | 5 |
※ In our verification, we found that amplitude decay occurs in Whisper features after roughly 2.0 seconds. During evaluation, we therefore processed inputs by splitting them into 2.0-second segments, extracting features with the Whisper encoder for each segment, recombining the segment features, and resynthesizing the waveform. If you use these models in your application, the upstream feature extraction must follow the same flow, as sketched below.
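The following is a minimal sketch of that segment-wise flow for the `whisper8_*` models, assuming the `openai/whisper-medium` encoder from `transformers` and that only the frames corresponding to each 2.0-second segment are kept (the Whisper feature extractor pads every input to 30 seconds); the exact evaluation setup may differ:

```python
import torch
import torchaudio
from transformers import WhisperModel, WhisperFeatureExtractor

SEGMENT_SEC = 2.0      # segment length noted above
SAMPLE_RATE = 16000    # Whisper expects 16 kHz input
FRAMES_PER_SEC = 50    # the Whisper encoder outputs 50 frames per second
layer = 8              # matches the whisper8_* models

feature_extractor = WhisperFeatureExtractor.from_pretrained('openai/whisper-medium')
whisper = WhisperModel.from_pretrained('openai/whisper-medium')

# Hypothetical input path; resample to 16 kHz if needed.
waveform, sr = torchaudio.load('input.wav')
if sr != SAMPLE_RATE:
    waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
audio = waveform[0]

segment_len = int(SEGMENT_SEC * SAMPLE_RATE)
features = []
with torch.no_grad():
    for start in range(0, audio.shape[0], segment_len):
        chunk = audio[start:start + segment_len]
        inputs = feature_extractor(
            chunk.numpy(), sampling_rate=SAMPLE_RATE, return_tensors="pt"
        )
        out = whisper.encoder(inputs.input_features, output_hidden_states=True)
        # Keep only the frames that correspond to this chunk, since the
        # feature extractor pads the input to 30 seconds.
        n_frames = int(round(chunk.shape[0] / SAMPLE_RATE * FRAMES_PER_SEC))
        features.append(out.hidden_states[layer][:, :n_frames])

# Recombine the segment features along time: (Batch, Timeframe, Featuredim)
conditional_feature = torch.cat(features, dim=1)
# conditional_feature can then be passed to ssl_vocoder.pred as in the Usage section.
```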
Usage
Please refer to our GitHub repository for instructions on how to use the models provided here. The following example is reproduced from that repository:
import torchaudio
import torch
from wavetrainerfit import load_pretrained_vocoder
from transformers import WavLMModel, AutoFeatureExtractor

# Load the WavLM feature extractor, the SSL model, and the vocoder trained
# on layer-2 WavLM features
ssl_preprocessor = AutoFeatureExtractor.from_pretrained('microsoft/wavlm-large')
ssl_model: WavLMModel = WavLMModel.from_pretrained('microsoft/wavlm-large')
layer = 2
ssl_vocoder, cfg = load_pretrained_vocoder(f'wavlm{layer}_wavetrainerfit5')

# Load the input waveform and resample it to 16 kHz for the SSL model
waveform, sr = torchaudio.load('./assets/ljspeech-samples/LJ037-0171.wav')
if sr != 16000:
    waveform = torchaudio.transforms.Resample(
        orig_freq=sr,
        new_freq=16000
    )(waveform)

# Extract SSL features from the chosen hidden layer
inputs = ssl_preprocessor(
    waveform[0].numpy(),
    sampling_rate=16000,
    return_tensors="pt"
)
with torch.no_grad():
    inputs = ssl_model(**inputs, output_hidden_states=True)
    inputs = inputs.hidden_states[layer]  # (Batch, Timeframe, Featuredim)

# Generate the 24 kHz waveform from the SSL features
generated_waveform = ssl_vocoder.pred(
    conditional_feature=inputs,  # (Batch, Timeframe, Featuredim)
    T_=5  # number of iterations
)

# Save the waveform from the final iteration at 24 kHz
torchaudio.save(
    './assets/ljspeech-samples/LJ037-0171-reconstructed.wav',
    generated_waveform[-1][:, 0].cpu(),
    24000
)
License
Model Weights
xlsr* and whisper* models: Licensed under CC BY 4.0
- Based on XLS-R by facebook (Apache 2.0) - https://huggingface.co/facebook/wav2vec2-xls-r-300m
- Based on Whisper by OpenAI (Apache 2.0) - https://huggingface.co/openai/whisper-medium
wavlm* models: Licensed under CC BY-SA 3.0
- Based on WavLM by Microsoft Corporation (CC BY-SA 3.0) - https://huggingface.co/microsoft/wavlm-large
- ⚠️ Derivative works must also use CC BY-SA 3.0
Training data: LibriTTS-R (CC BY 4.0) - https://www.openslr.org/141/
When using these models, you must comply with both our license and the original upstream model licenses.
