---
language:
- en
tags:
- speech
license:
- cc-by-sa-3.0
- cc-by-4.0
---
# Wave-Trainer-Fit | Neural vocoder from SSL features
[[Code](https://github.com/line/WaveTrainerFit)] [[Audio samples](https://i17oonaka-h.github.io/projects/research_topics/wave_trainer_fit/)]
>**Abstract:**<br>
We propose WaveTrainerFit, a neural vocoder that performs high-quality waveform generation from data-driven features such as SSL features. WaveTrainerFit builds upon the WaveFit vocoder, which integrates a diffusion model and a generative adversarial network. The proposed method incorporates the following key improvements: (1) by introducing trainable priors, the inference process starts from noise close to the target speech instead of Gaussian noise; (2) reference-aware gain adjustment is performed by imposing constraints on the trainable prior to match the speech energy. These improvements are expected to reduce the complexity of waveform modeling from data-driven features, enabling high-quality waveform generation with fewer inference steps. Through experiments, we showed that WaveTrainerFit can generate highly natural waveforms with improved speaker similarity from data-driven features, while requiring fewer iterations than WaveFit. Moreover, we showed that the proposed method works robustly with respect to the depth at which SSL features are extracted.
![concept.png](./assets/concept.png)
This repository provides pre-trained models and their optimizers.
The models were pre-trained on [LibriTTS-R](https://www.openslr.org/141/) (train-clean-360).
They take SSL features extracted from 16 kHz audio as input and reconstruct 24 kHz waveforms.
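The 16 kHz-in / 24 kHz-out relation implies a fixed per-frame upsampling factor. As a rough sketch (assuming the SSL front-end's standard 20 ms hop, i.e. 50 feature frames per second; the helper name below is ours, not part of the released code):

```python
# Bookkeeping for the 16 kHz -> SSL features -> 24 kHz pipeline.
# Assumption: the SSL front-end (WavLM / XLS-R) uses a 20 ms hop,
# i.e. 50 feature frames per second of input audio.
IN_SR = 16000        # sample rate expected by the SSL model
OUT_SR = 24000       # sample rate of the reconstructed waveform
FRAMES_PER_SEC = 50  # SSL feature frame rate (20 ms hop)

def expected_output_samples(n_input_samples: int) -> int:
    """Approximate number of 24 kHz output samples for a 16 kHz input."""
    n_frames = n_input_samples * FRAMES_PER_SEC // IN_SR
    # Each feature frame corresponds to OUT_SR / FRAMES_PER_SEC = 480 output samples.
    return n_frames * (OUT_SR // FRAMES_PER_SEC)
```

For example, a 2-second 16 kHz clip (32,000 samples) maps to roughly 100 feature frames and 48,000 output samples at 24 kHz.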
## Pre-trained model list
> [!IMPORTANT]
> ⚠️ **License Notice**: The model weights provided in this repository are licensed under different terms. The `xlsr*` and `whisper*` models are licensed differently from the `wavlm*` models. Please refer to the [License](#license) section for details.
The list of available models is as follows:
| Model name | Conditional features | Layer | #Iterations |
|:-----------|:--------------------:|:-----:|:-----------:|
| wavlm2_wavetrainerfit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 2 | 5 |
| wavlm2_wavefit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 2 | 5 |
| wavlm8_wavetrainerfit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 8 | 5 |
| wavlm8_wavefit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 8 | 5 |
| wavlm24_wavetrainerfit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 24 | 5 |
| wavlm24_wavefit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 24 | 5 |
| xlsr8_wavetrainerfit5 | [XLS-R-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 8 | 5 |
| xlsr8_wavefit5 | [XLS-R-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 8 | 5 |
| whisper8_wavetrainerfit5 | ※ [Whisper-medium](https://huggingface.co/openai/whisper-medium) | 8 | 5 |
| whisper8_wavefit5 | ※ [Whisper-medium](https://huggingface.co/openai/whisper-medium) | 8 | 5 |
※ In our verification, we found that Whisper features exhibit amplitude decay after about 2.0 seconds.
During evaluation, our model therefore processed inputs by splitting them into 2.0-second segments, extracting features with the Whisper encoder for each segment, recombining the features, and then resynthesizing the waveform.
If you use these models in your application, the upstream feature extraction must follow the same flow.
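The segment-wise extraction described above could be sketched as follows. This is our own illustration, not code from the released repository: the helper name is hypothetical, and it assumes Whisper's encoder emits 50 frames per second (1500 frames for its fixed 30-second padded window), so only the frames covering each 2.0-second segment are kept before concatenation.

```python
import torch
from transformers import WhisperFeatureExtractor, WhisperModel

SEGMENT_SEC = 2.0
SAMPLE_RATE = 16000
FRAMES_PER_SEC = 50  # Whisper encoder resolution: 1500 frames / 30 s

def extract_whisper_features(waveform_16k: torch.Tensor, layer: int = 8) -> torch.Tensor:
    """Split a mono 16 kHz waveform (1, num_samples) into 2.0 s segments,
    encode each segment with the Whisper encoder, keep only the frames that
    cover the segment (dropping the 30 s padding), and concatenate in time."""
    extractor = WhisperFeatureExtractor.from_pretrained('openai/whisper-medium')
    model = WhisperModel.from_pretrained('openai/whisper-medium')
    seg_len = int(SEGMENT_SEC * SAMPLE_RATE)
    chunks = []
    for start in range(0, waveform_16k.shape[-1], seg_len):
        segment = waveform_16k[..., start:start + seg_len]
        inputs = extractor(
            segment.squeeze(0).numpy(), sampling_rate=SAMPLE_RATE, return_tensors='pt'
        )
        with torch.no_grad():
            enc = model.encoder(inputs.input_features, output_hidden_states=True)
        # Keep only the frames corresponding to this segment's actual duration.
        n_frames = int(segment.shape[-1] / SAMPLE_RATE * FRAMES_PER_SEC)
        chunks.append(enc.hidden_states[layer][:, :n_frames])
    return torch.cat(chunks, dim=1)  # (Batch, Timeframe, Featuredim)
```

The concatenated features can then be passed to the vocoder in the same way as in the usage example below.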
## Usage
Please refer to [our GitHub repository](https://github.com/line/WaveTrainerFit) for full instructions on using the models provided here.
The following example is reproduced from that repository:
```python
import torch
import torchaudio
from transformers import AutoFeatureExtractor, WavLMModel

from wavetrainerfit import load_pretrained_vocoder

# Load the SSL feature extractor and the matching pre-trained vocoder.
ssl_preprocessor = AutoFeatureExtractor.from_pretrained('microsoft/wavlm-large')
ssl_model: WavLMModel = WavLMModel.from_pretrained('microsoft/wavlm-large')
layer = 2
ssl_vocoder, cfg = load_pretrained_vocoder(f'wavlm{layer}_wavetrainerfit5')

# Load the input audio and resample it to 16 kHz if necessary.
waveform, sr = torchaudio.load('./assets/ljspeech-samples/LJ037-0171.wav')
if sr != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)

# Extract SSL features from the chosen layer.
inputs = ssl_preprocessor(
    waveform[0].numpy(),
    sampling_rate=16000,
    return_tensors="pt"
)
with torch.no_grad():
    outputs = ssl_model(**inputs, output_hidden_states=True)
    features = outputs.hidden_states[layer]  # (Batch, Timeframe, Featuredim)

# Generate the 24 kHz waveform from the SSL features.
generated_waveform = ssl_vocoder.pred(
    conditional_feature=features,  # (Batch, Timeframe, Featuredim)
    T_=5  # number of iterations
)
torchaudio.save(
    './assets/ljspeech-samples/LJ037-0171-reconstructed.wav',
    generated_waveform[-1][:, 0].cpu(),
    24000
)
```
## License
### Model Weights
**`xlsr*` and `whisper*` models**: Licensed under [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/)
- Based on XLS-R by facebook (Apache 2.0) - https://huggingface.co/facebook/wav2vec2-xls-r-300m
- Based on Whisper by OpenAI (Apache 2.0) - https://huggingface.co/openai/whisper-medium
**`wavlm*` models**: Licensed under [CC BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0/)
- Based on WavLM by Microsoft Corporation ([CC BY-SA 3.0](https://github.com/microsoft/UniSpeech/blob/main/LICENSE)) - https://huggingface.co/microsoft/wavlm-large
- ⚠️ Derivative works must also use CC BY-SA 3.0
**Training data**: LibriTTS-R (CC BY 4.0) - https://www.openslr.org/141/
When using these models, you must comply with both our license and the original upstream model licenses.