Wave-Trainer-Fit | Neural vocoder from SSL features

[Code of Wave-Trainer-Fit] [Audio samples]

Abstract:
We propose WaveTrainerFit, a neural vocoder that performs high-quality waveform generation from data-driven features such as SSL features. WaveTrainerFit builds upon the WaveFit vocoder, which integrates diffusion model and generative adversarial network. Furthermore, the proposed method incorporates the following key improvements: 1. By introducing trainable priors, the inference process starts from noise close to the target speech instead of Gaussian noise. 2. Reference-aware gain adjustment is performed by imposing constraints on the trainable prior to matching the speech energy. These improvements are expected to reduce the complexity of waveform modeling from data-driven features, enabling high-quality waveform generation with fewer inference steps. Through experiments, we showed that WaveTrainerFit can generate highly natural waveforms with improved speaker similarity from data-driven features, while requiring fewer iterations than WaveFit. Moreover, we showed that the proposed method works robustly with respect to the depth at which SSL features are extracted.
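The reference-aware gain adjustment described above can be illustrated with a minimal sketch: the trainable prior is rescaled so that its energy matches that of the reference speech. The `match_energy` function below is a hypothetical illustration of this idea, not the actual implementation.

```python
import numpy as np

def match_energy(prior: np.ndarray, reference: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Rescale `prior` so that its mean energy matches that of `reference`."""
    e_prior = np.mean(prior ** 2)
    e_ref = np.mean(reference ** 2)
    gain = np.sqrt(e_ref / (e_prior + eps))  # scalar gain matching the energies
    return prior * gain

# Example: a low-energy prior is rescaled to the reference energy.
rng = np.random.default_rng(0)
prior = rng.standard_normal(24000) * 0.1      # low-energy prior noise
reference = rng.standard_normal(24000) * 0.5  # stand-in for reference speech
adjusted = match_energy(prior, reference)
```

After adjustment, the prior has (approximately) the same mean energy as the reference, so inference starts from noise whose loudness already matches the target.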

![Concept of WaveTrainerFit](concept.png)

This repository provides pre-trained models and their optimizer states. The models were pre-trained on LibriTTS-R (train-clean-360). They reconstruct 24 kHz audio from SSL features extracted from 16 kHz audio.

Pre-trained model list

⚠️ License Notice: The model weights provided in this repository are licensed under different terms. The xlsr* and whisper* models are licensed differently from the wavlm* models. Please refer to the License section for details.

The list of available models is as follows:

| Model name | Conditional features | Layer | #iters |
|---|---|---|---|
| wavlm2_wavetrainerfit5 | WavLM-large | 2 | 5 |
| wavlm2_wavefit5 | WavLM-large | 2 | 5 |
| wavlm8_wavetrainerfit5 | WavLM-large | 8 | 5 |
| wavlm8_wavefit5 | WavLM-large | 8 | 5 |
| wavlm24_wavetrainerfit5 | WavLM-large | 24 | 5 |
| wavlm24_wavefit5 | WavLM-large | 24 | 5 |
| xlsr8_wavetrainerfit5 | XLS-R-300m | 8 | 5 |
| xlsr8_wavefit5 | XLS-R-300m | 8 | 5 |
| whisper8_wavetrainerfit5 | Whisper-medium | 8 | 5 |
| whisper8_wavefit5 | Whisper-medium | 8 | 5 |
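The model names above follow the pattern `{ssl}{layer}_{vocoder}{iters}`. A small helper for composing them (this helper is hypothetical, not part of the released package):

```python
def model_name(ssl: str, layer: int, vocoder: str = "wavetrainerfit", iters: int = 5) -> str:
    """Compose a pre-trained model name, e.g. 'wavlm2_wavetrainerfit5'."""
    return f"{ssl}{layer}_{vocoder}{iters}"

print(model_name("wavlm", 2))               # wavlm2_wavetrainerfit5
print(model_name("whisper", 8, "wavefit"))  # whisper8_wavefit5
```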

※ In our verification, we found that Whisper features exhibit amplitude decay after about 2.0 seconds. During evaluation, our model therefore processed inputs by splitting them into 2.0-second segments → extracting features with the Whisper encoder → recombining → resynthesizing. If you use these models in your application, the upstream feature extraction must follow the same flow.
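The segment-wise flow above can be sketched as follows. This is a minimal illustration of the chunk-and-recombine step only; `encode_segment` is a hypothetical placeholder standing in for running the Whisper encoder on one 2.0-second segment.

```python
import numpy as np

def split_into_segments(waveform: np.ndarray, sr: int = 16000, seg_sec: float = 2.0):
    """Split a mono waveform into consecutive segments of at most `seg_sec` seconds."""
    seg_len = int(sr * seg_sec)
    return [waveform[i:i + seg_len] for i in range(0, len(waveform), seg_len)]

def extract_features(waveform, encode_segment, sr: int = 16000, seg_sec: float = 2.0):
    """Extract features segment by segment and recombine along the time axis.

    `encode_segment` is a placeholder for the per-segment Whisper-encoder call.
    """
    feats = [encode_segment(seg) for seg in split_into_segments(waveform, sr, seg_sec)]
    return np.concatenate(feats, axis=0)
```

The recombined feature sequence is then passed to the vocoder exactly as in the usage example below.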

Usage

Please refer to our GitHub repository for instructions on how to use the models provided here. The following is reproduced from the GitHub repository:

```python
import torch
import torchaudio
from transformers import AutoFeatureExtractor, WavLMModel

from wavetrainerfit import load_pretrained_vocoder

# Load the SSL feature extractor and model (WavLM-large expects 16 kHz input).
ssl_preprocessor = AutoFeatureExtractor.from_pretrained('microsoft/wavlm-large')
ssl_model: WavLMModel = WavLMModel.from_pretrained('microsoft/wavlm-large')

# Load the vocoder matching the SSL layer the features are taken from.
layer = 2
ssl_vocoder, cfg = load_pretrained_vocoder(f'wavlm{layer}_wavetrainerfit5')

# Load audio and resample to 16 kHz if needed.
waveform, sr = torchaudio.load('./assets/ljspeech-samples/LJ037-0171.wav')
if sr != 16000:
    waveform = torchaudio.transforms.Resample(
        orig_freq=sr,
        new_freq=16000,
    )(waveform)
inputs = ssl_preprocessor(
    waveform[0].numpy(),
    sampling_rate=16000,
    return_tensors="pt",
)

with torch.no_grad():
    outputs = ssl_model(**inputs, output_hidden_states=True)
    features = outputs.hidden_states[layer]  # (Batch, Timeframe, Featuredim)
    generated_waveform = ssl_vocoder.pred(
        conditional_feature=features,  # (Batch, Timeframe, Featuredim)
        T_=5,  # number of inference iterations
    )

# Save the final iterate at 24 kHz.
torchaudio.save(
    './assets/ljspeech-samples/LJ037-0171-reconstructed.wav',
    generated_waveform[-1][:, 0].cpu(),
    24000,
)
```

License

Model Weights

xlsr* and whisper* models: Licensed under CC BY 4.0

wavlm* models: Licensed under CC BY-SA 3.0

Training data: LibriTTS-R (CC BY 4.0) - https://www.openslr.org/141/

When using these models, you must comply with both our license and the original upstream model licenses.
