---
language:
- en
tags:
- speech
license:
- cc-by-sa-3.0
- cc-by-4.0
---

# Wave-Trainer-Fit | Neural vocoder from SSL features

[[Code of Wave-Trainer-Fit](https://github.com/line/WaveTrainerFit)][[audio samples](https://i17oonaka-h.github.io/projects/research_topics/wave_trainer_fit/)]

>**Abstract:**<br>
We propose WaveTrainerFit, a neural vocoder that performs high-quality waveform generation from data-driven features such as SSL features. WaveTrainerFit builds upon the WaveFit vocoder, which integrates a diffusion model and a generative adversarial network. Furthermore, the proposed method incorporates the following key improvements: 1. By introducing trainable priors, the inference process starts from noise close to the target speech instead of Gaussian noise. 2. Reference-aware gain adjustment is performed by imposing constraints on the trainable prior to match the speech energy. These improvements are expected to reduce the complexity of waveform modeling from data-driven features, enabling high-quality waveform generation with fewer inference steps. Through experiments, we showed that WaveTrainerFit can generate highly natural waveforms with improved speaker similarity from data-driven features, while requiring fewer iterations than WaveFit. Moreover, we showed that the proposed method works robustly with respect to the depth at which SSL features are extracted.
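
The energy constraint mentioned in the abstract can be pictured with a generic gain that rescales a signal so its energy matches a reference. This is a minimal illustration only, assuming energy is measured as the sum of squared samples; `match_energy` is a hypothetical helper, and it is not the exact formulation used by WaveTrainerFit or WaveFit.

```python
import numpy as np


def match_energy(signal: np.ndarray, reference_energy: float, eps: float = 1e-8) -> np.ndarray:
    """Rescale `signal` so its energy (sum of squared samples) equals `reference_energy`.

    Illustrative only: a generic energy-matching gain, not the paper's exact constraint.
    """
    energy = float(np.sum(signal ** 2))
    gain = np.sqrt(reference_energy / (energy + eps))  # eps avoids division by zero
    return signal * gain


# Example: a noise "prior" at an arbitrary level is rescaled to the energy of a stand-in reference.
rng = np.random.default_rng(0)
prior = rng.standard_normal(24_000)         # 1 s of noise at 24 kHz
reference = 0.1 * rng.standard_normal(24_000)  # placeholder for reference speech
scaled_prior = match_energy(prior, float(np.sum(reference ** 2)))
```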

![concept.png](./assets/concept.png)

This repository provides pre-trained models and their optimizer states.
The models were pre-trained on [LibriTTS-R](https://www.openslr.org/141/) (train-clean-360).
They take SSL features extracted from 16 kHz audio as input and reconstruct 24 kHz audio.
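
As a quick sanity check on the two sampling rates: WavLM, XLS-R, and the Whisper encoder produce features at roughly 50 frames per second (a 20 ms hop, i.e. 320 samples at 16 kHz), so each conditioning frame corresponds to about 480 output samples at 24 kHz. The sketch below only does this arithmetic; the 320-sample hop is an assumption about the upstream encoders, `approx_output_samples` is a hypothetical helper, and the exact output length produced by the vocoder may differ slightly.

```python
SSL_SAMPLE_RATE = 16_000     # input audio rate expected by the SSL models
OUTPUT_SAMPLE_RATE = 24_000  # waveform rate produced by the vocoder
SSL_HOP_SAMPLES = 320        # assumed 20 ms hop of the SSL encoder at 16 kHz

# Feature frames per second of input audio
frame_rate_hz = SSL_SAMPLE_RATE / SSL_HOP_SAMPLES
# Output samples covered by one feature frame (24 kHz * 20 ms)
samples_per_frame = OUTPUT_SAMPLE_RATE * SSL_HOP_SAMPLES // SSL_SAMPLE_RATE


def approx_output_samples(num_frames: int) -> int:
    """Approximate number of 24 kHz output samples for `num_frames` SSL frames."""
    return num_frames * samples_per_frame
```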

## Pre-trained model list
> [!IMPORTANT]
> ⚠️ **License Notice**: The model weights in this repository are not all released under the same license: the `xlsr*` and `whisper*` models are licensed differently from the `wavlm*` models. Please refer to the [License](#license) section for details.

The list of available models is as follows:

| Model name | Conditioning features | SSL layer | #Iterations |
|:------|:---------:|:---:|:---:|
| wavlm2_wavetrainerfit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 2 | 5 |
| wavlm2_wavefit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 2 | 5 |
| wavlm8_wavetrainerfit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 8 | 5 |
| wavlm8_wavefit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 8 | 5 |
| wavlm24_wavetrainerfit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 24 | 5 |
| wavlm24_wavefit5 | [WavLM-large](https://huggingface.co/microsoft/wavlm-large) | 24 | 5 |
| xlsr8_wavetrainerfit5 | [XLS-R-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 8 | 5 |
| xlsr8_wavefit5 | [XLS-R-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 8 | 5 |
| whisper8_wavetrainerfit5 | ※ [Whisper-medium](https://huggingface.co/openai/whisper-medium) | 8 | 5 |
| whisper8_wavefit5 | ※ [Whisper-medium](https://huggingface.co/openai/whisper-medium) | 8 | 5 |
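
The model names in the table follow a `{features}{layer}_{method}{iterations}` pattern, which is how the usage example builds `f'wavlm{layer}_wavetrainerfit5'`. The hypothetical helper below (it is not part of the `wavetrainerfit` package) composes a name from those parts and checks it against the table before you pass it to `load_pretrained_vocoder`.

```python
# Released model names, copied from the table above
AVAILABLE_MODELS = {
    "wavlm2_wavetrainerfit5", "wavlm2_wavefit5",
    "wavlm8_wavetrainerfit5", "wavlm8_wavefit5",
    "wavlm24_wavetrainerfit5", "wavlm24_wavefit5",
    "xlsr8_wavetrainerfit5", "xlsr8_wavefit5",
    "whisper8_wavetrainerfit5", "whisper8_wavefit5",
}


def model_name(features: str, layer: int, method: str = "wavetrainerfit", iters: int = 5) -> str:
    """Compose a model name and raise early if that combination was not released."""
    name = f"{features}{layer}_{method}{iters}"
    if name not in AVAILABLE_MODELS:
        raise ValueError(f"{name!r} is not a released model; choose one of {sorted(AVAILABLE_MODELS)}")
    return name
```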

※ In our verification, we found that Whisper features exhibit amplitude decay beyond roughly 2.0 seconds of input.
To work around this, our models processed inputs during evaluation as follows: `split into 2.0-second segments → extract features with the Whisper encoder → recombine → resynthesize`.
If you use these models in your application, the upstream feature extraction must follow the same procedure.
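
A minimal sketch of this segment-wise flow, assuming a mono 16 kHz NumPy waveform. `extract_in_segments` and `dummy_extract` are hypothetical helpers, and `extract_fn` is a placeholder for your own Whisper-medium layer-8 feature extraction; how each segment is padded for the Whisper encoder and trimmed afterwards is not specified in this card, so that part is left to you.

```python
from typing import Callable

import numpy as np


def extract_in_segments(
    waveform: np.ndarray,
    extract_fn: Callable[[np.ndarray], np.ndarray],
    sample_rate: int = 16_000,
    segment_seconds: float = 2.0,
) -> np.ndarray:
    """Split a mono waveform into fixed-length segments, run `extract_fn` on each,
    and concatenate the resulting (Timeframe, Featuredim) arrays along time.
    The last segment may be shorter than `segment_seconds`."""
    segment_len = int(round(segment_seconds * sample_rate))
    features = [
        extract_fn(waveform[start:start + segment_len])
        for start in range(0, len(waveform), segment_len)
    ]
    return np.concatenate(features, axis=0)


# Dummy extractor mimicking a 50 Hz feature rate with 1024-dim features
# (the hidden size of Whisper-medium); replace it with real Whisper extraction.
def dummy_extract(segment: np.ndarray) -> np.ndarray:
    return np.zeros((len(segment) // 320, 1024), dtype=np.float32)


audio = np.zeros(5 * 16_000, dtype=np.float32)  # 5 s of silence
feats = extract_in_segments(audio, dummy_extract)  # 2.0 s + 2.0 s + 1.0 s segments
```

Before passing the result to `ssl_vocoder.pred`, convert it to a tensor and add a leading batch dimension so it matches the `(Batch, Timeframe, Featuredim)` shape used in the Usage example.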

## Usage
Please refer to [our GitHub repository](https://github.com/line/WaveTrainerFit) for instructions on how to use the models provided here.
The following is reproduced from the GitHub repository:
```python
import torch
import torchaudio
from transformers import AutoFeatureExtractor, WavLMModel
from wavetrainerfit import load_pretrained_vocoder

# Upstream SSL model (WavLM-large) and its input preprocessor
ssl_preprocessor = AutoFeatureExtractor.from_pretrained('microsoft/wavlm-large')
ssl_model: WavLMModel = WavLMModel.from_pretrained('microsoft/wavlm-large')

# The layer index must match the vocoder being loaded (here: wavlm2_wavetrainerfit5)
layer = 2
ssl_vocoder, cfg = load_pretrained_vocoder(f'wavlm{layer}_wavetrainerfit5')

# Load the input audio and resample it to the 16 kHz rate expected by WavLM
waveform, sr = torchaudio.load('./assets/ljspeech-samples/LJ037-0171.wav')
if sr != 16000:
    waveform = torchaudio.transforms.Resample(
        orig_freq=sr,
        new_freq=16000
    )(waveform)

# Use the first channel only; the preprocessor takes a 1-D array
inputs = ssl_preprocessor(
    waveform[0].numpy(),
    sampling_rate=16000,
    return_tensors="pt"
)

with torch.no_grad():
    outputs = ssl_model(**inputs, output_hidden_states=True)
    # hidden_states[0] is the input to the first Transformer layer,
    # so hidden_states[layer] is the output of Transformer layer `layer`
    features = outputs.hidden_states[layer]  # (Batch, Timeframe, Featuredim)
    generated_waveform = ssl_vocoder.pred(
        conditional_feature=features,  # (Batch, Timeframe, Featuredim)
        T_=5  # number of iterations
    )

# Save the last entry of the returned outputs as 24 kHz audio
torchaudio.save(
    './assets/ljspeech-samples/LJ037-0171-reconstructed.wav',
    generated_waveform[-1][:, 0].cpu(), 24000
)
```

## License

### Model Weights

**`xlsr*` and `whisper*` models**: Licensed under [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/)
- Based on XLS-R by facebook (Apache 2.0) - https://huggingface.co/facebook/wav2vec2-xls-r-300m
- Based on Whisper by OpenAI (Apache 2.0) - https://huggingface.co/openai/whisper-medium

**`wavlm*` models**: Licensed under [CC BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0/)
- Based on WavLM by Microsoft Corporation ([CC BY-SA 3.0](https://github.com/microsoft/UniSpeech/blob/main/LICENSE)) - https://huggingface.co/microsoft/wavlm-large
- ⚠️ Derivative works must also be licensed under CC BY-SA 3.0

**Training data**: LibriTTS-R (CC BY 4.0) - https://www.openslr.org/141/

When using these models, you must comply with both our license and the original upstream model licenses.