Upload 44 files

- configs/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml +93 -0
- configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml +93 -0
- data/demo.txt +4 -0
- decoder/__init__.py +4 -0
- decoder/__pycache__/__init__.cpython-310.pyc +0 -0
- decoder/__pycache__/__init__.cpython-38.pyc +0 -0
- decoder/__pycache__/__init__.cpython-39.pyc +0 -0
- decoder/__pycache__/dataset.cpython-310.pyc +0 -0
- decoder/__pycache__/discriminator_dac.cpython-310.pyc +0 -0
- decoder/__pycache__/discriminators.cpython-310.pyc +0 -0
- decoder/__pycache__/experiment.cpython-310.pyc +0 -0
- decoder/__pycache__/feature_extractors.cpython-310.pyc +0 -0
- decoder/__pycache__/feature_extractors.cpython-38.pyc +0 -0
- decoder/__pycache__/feature_extractors.cpython-39.pyc +0 -0
- decoder/__pycache__/heads.cpython-310.pyc +0 -0
- decoder/__pycache__/heads.cpython-39.pyc +0 -0
- decoder/__pycache__/helpers.cpython-310.pyc +0 -0
- decoder/__pycache__/loss.cpython-310.pyc +0 -0
- decoder/__pycache__/models.cpython-310.pyc +0 -0
- decoder/__pycache__/models.cpython-39.pyc +0 -0
- decoder/__pycache__/modules.cpython-310.pyc +0 -0
- decoder/__pycache__/modules.cpython-38.pyc +0 -0
- decoder/__pycache__/modules.cpython-39.pyc +0 -0
- decoder/__pycache__/pretrained.cpython-310.pyc +0 -0
- decoder/__pycache__/pretrained.cpython-38.pyc +0 -0
- decoder/__pycache__/pretrained.cpython-39.pyc +0 -0
- decoder/__pycache__/pretrained_model.cpython-310.pyc +0 -0
- decoder/__pycache__/spectral_ops.cpython-310.pyc +0 -0
- decoder/__pycache__/spectral_ops.cpython-39.pyc +0 -0
- decoder/dataset.py +84 -0
- decoder/discriminator_dac.py +249 -0
- decoder/discriminators.py +202 -0
- decoder/experiment.py +474 -0
- decoder/feature_extractors.py +141 -0
- decoder/heads.py +157 -0
- decoder/helpers.py +71 -0
- decoder/loss.py +159 -0
- decoder/models.py +264 -0
- decoder/modules.py +213 -0
- decoder/pretrained.py +239 -0
- decoder/pretrained_model.py +192 -0
- decoder/spectral_ops.py +192 -0
- infer.py +73 -0
- train.py +15 -0
configs/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml
ADDED
@@ -0,0 +1,93 @@
seed_everything: 3407

data:
  class_path: decoder.dataset.VocosDataModule
  init_args:
    train_params:
      filelist_path: ./WavTokenizer/data/train/libritts_train
      sampling_rate: 24000
      num_samples: 72000
      batch_size: 40  # 20
      num_workers: 8

    val_params:
      filelist_path: ./WavTokenizer/data/infer/librttts_val
      sampling_rate: 24000
      num_samples: 72000
      batch_size: 5  # 10
      num_workers: 8

model:
  class_path: decoder.experiment.WavTokenizer
  init_args:
    sample_rate: 24000
    initial_learning_rate: 2e-4
    mel_loss_coeff: 45
    mrd_loss_coeff: 1.0
    num_warmup_steps: 0  # Optimizers warmup steps
    pretrain_mel_steps: 0  # 0 means GAN objective from the first iteration

    # automatic evaluation
    evaluate_utmos: true
    evaluate_pesq: true
    evaluate_periodicty: true

    resume: false
    resume_config: ./WavTokenizer/configs/wavtokenizer_smalldata_frame40_3s_nq1_code16384_dim512_kmeans800_attn.yaml
    resume_model: ./version_3/checkpoints/xxx.ckpt

    feature_extractor:
      class_path: decoder.feature_extractors.EncodecFeatures
      init_args:
        encodec_model: encodec_24khz
        bandwidths: [6.6, 6.6, 6.6, 6.6]
        train_codebooks: true
        num_quantizers: 1
        dowmsamples: [6, 5, 5, 4]
        vq_bins: 4096
        vq_kmeans: 200

    backbone:
      class_path: decoder.models.VocosBackbone
      init_args:
        input_channels: 512
        dim: 768
        intermediate_dim: 2304
        num_layers: 12
        adanorm_num_embeddings: 4

    head:
      class_path: decoder.heads.ISTFTHead
      init_args:
        dim: 768
        n_fft: 2400
        hop_length: 600
        padding: same

trainer:
  logger:
    class_path: pytorch_lightning.loggers.TensorBoardLogger
    init_args:
      save_dir: ./WavTokenizer/result/train/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn/
  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
    - class_path: pytorch_lightning.callbacks.ModelSummary
      init_args:
        max_depth: 2
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        monitor: val_loss
        filename: wavtokenizer_checkpoint_{epoch}_{step}_{val_loss:.4f}
        save_top_k: 10
        save_last: true
    - class_path: decoder.helpers.GradNormCallback

  # Lightning calculates max_steps across all optimizer steps (rather than number of batches)
  # This equals to 1M steps per generator and 1M per discriminator
  max_steps: 20000000
  # You might want to limit val batches when evaluating all the metrics, as they are time-consuming
  limit_val_batches: 200
  accelerator: gpu
  strategy: ddp
  devices: [0,1,2,3,4,5,6,7]
  log_every_n_steps: 1000
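Note: train.py (+15 lines) is part of this upload but is not reproduced in this view. As a hypothetical sketch only: the class_path/init_args layout used in the config above is the pytorch_lightning LightningCLI convention, so a Vocos-style launcher would look roughly like the following; the actual train.py in this repository may differ.

    # Hypothetical launcher sketch; not the repository's train.py.
    from pytorch_lightning.cli import LightningCLI

    if __name__ == "__main__":
        # e.g. python train.py --config configs/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml
        cli = LightningCLI(run=False)  # builds data module, model and trainer from the YAML
        cli.trainer.fit(model=cli.model, datamodule=cli.datamodule)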
configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml
ADDED
@@ -0,0 +1,93 @@
seed_everything: 3407

data:
  class_path: decoder.dataset.VocosDataModule
  init_args:
    train_params:
      filelist_path: ./WavTokenizer/data/train/libritts_train
      sampling_rate: 24000
      num_samples: 72000
      batch_size: 40  # 20
      num_workers: 8

    val_params:
      filelist_path: ./WavTokenizer/data/infer/librttts_val
      sampling_rate: 24000
      num_samples: 72000
      batch_size: 5  # 10
      num_workers: 8

model:
  class_path: decoder.experiment.WavTokenizer
  init_args:
    sample_rate: 24000
    initial_learning_rate: 2e-4
    mel_loss_coeff: 45
    mrd_loss_coeff: 1.0
    num_warmup_steps: 0  # Optimizers warmup steps
    pretrain_mel_steps: 0  # 0 means GAN objective from the first iteration

    # automatic evaluation
    evaluate_utmos: true
    evaluate_pesq: true
    evaluate_periodicty: true

    resume: false
    resume_config: ./WavTokenizer/configs/wavtokenizer_smalldata_frame75_3s_nq1_code16384_dim512_kmeans800_attn.yaml
    resume_model: ./WavTokenizer/result/train/wavtokenizer_smalldata_frame75_3s_nq1_code16384_dim512_kmeans800_attn/xxx.ckpt

    feature_extractor:
      class_path: decoder.feature_extractors.EncodecFeatures
      init_args:
        encodec_model: encodec_24khz
        bandwidths: [6.6, 6.6, 6.6, 6.6]
        train_codebooks: true
        num_quantizers: 1
        dowmsamples: [8, 5, 4, 2]
        vq_bins: 4096
        vq_kmeans: 200

    backbone:
      class_path: decoder.models.VocosBackbone
      init_args:
        input_channels: 512
        dim: 768
        intermediate_dim: 2304
        num_layers: 12
        adanorm_num_embeddings: 4

    head:
      class_path: decoder.heads.ISTFTHead
      init_args:
        dim: 768
        n_fft: 1280
        hop_length: 320
        padding: same

trainer:
  logger:
    class_path: pytorch_lightning.loggers.TensorBoardLogger
    init_args:
      save_dir: ./WavTokenizer/result/train/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn/
  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
    - class_path: pytorch_lightning.callbacks.ModelSummary
      init_args:
        max_depth: 2
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        monitor: val_loss
        filename: wavtokenizer_checkpoint_{epoch}_{step}_{val_loss:.4f}
        save_top_k: 10
        save_last: true
    - class_path: decoder.helpers.GradNormCallback

  # Lightning calculates max_steps across all optimizer steps (rather than number of batches)
  # This equals to 1M steps per generator and 1M per discriminator
  max_steps: 20000000
  # You might want to limit val batches when evaluating all the metrics, as they are time-consuming
  limit_val_batches: 100
  accelerator: gpu
  strategy: ddp
  devices: [0,1,2,3,4,5,6,7]
  log_every_n_steps: 1000
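Note on the two configs: both use a 24000 Hz sampling rate, so the frame40 variant (EnCodec downsampling factors [6, 5, 5, 4], product 600, ISTFT hop_length 600) yields 24000 / 600 = 40 tokens per second, while the frame75 variant ([8, 5, 4, 2], product 320, hop_length 320) yields 24000 / 320 = 75 tokens per second; in both cases n_fft is four times the hop length (2400 and 1280 respectively).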
data/demo.txt
ADDED
@@ -0,0 +1,4 @@
./example1.wav
./example2.wav
./example3.mp3
./example4.flac
decoder/__init__.py
ADDED
@@ -0,0 +1,4 @@
from decoder.pretrained import WavTokenizer


__version__ = "0.0.3"
decoder/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (224 Bytes)

decoder/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (234 Bytes)

decoder/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (233 Bytes)

decoder/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (3.34 kB)

decoder/__pycache__/discriminator_dac.cpython-310.pyc
ADDED
Binary file (8.14 kB)

decoder/__pycache__/discriminators.cpython-310.pyc
ADDED
Binary file (6.97 kB)

decoder/__pycache__/experiment.cpython-310.pyc
ADDED
Binary file (14.6 kB)

decoder/__pycache__/feature_extractors.cpython-310.pyc
ADDED
Binary file (4.66 kB)

decoder/__pycache__/feature_extractors.cpython-38.pyc
ADDED
Binary file (4.12 kB)

decoder/__pycache__/feature_extractors.cpython-39.pyc
ADDED
Binary file (4.43 kB)

decoder/__pycache__/heads.cpython-310.pyc
ADDED
Binary file (6.69 kB)

decoder/__pycache__/heads.cpython-39.pyc
ADDED
Binary file (6.63 kB)

decoder/__pycache__/helpers.cpython-310.pyc
ADDED
Binary file (2.73 kB)

decoder/__pycache__/loss.cpython-310.pyc
ADDED
Binary file (6.18 kB)

decoder/__pycache__/models.cpython-310.pyc
ADDED
Binary file (8.09 kB)

decoder/__pycache__/models.cpython-39.pyc
ADDED
Binary file (8.01 kB)

decoder/__pycache__/modules.cpython-310.pyc
ADDED
Binary file (6.64 kB)

decoder/__pycache__/modules.cpython-38.pyc
ADDED
Binary file (6.62 kB)

decoder/__pycache__/modules.cpython-39.pyc
ADDED
Binary file (6.59 kB)

decoder/__pycache__/pretrained.cpython-310.pyc
ADDED
Binary file (8.32 kB)

decoder/__pycache__/pretrained.cpython-38.pyc
ADDED
Binary file (8.08 kB)

decoder/__pycache__/pretrained.cpython-39.pyc
ADDED
Binary file (8.41 kB)

decoder/__pycache__/pretrained_model.cpython-310.pyc
ADDED
Binary file (7.12 kB)

decoder/__pycache__/spectral_ops.cpython-310.pyc
ADDED
Binary file (6.85 kB)

decoder/__pycache__/spectral_ops.cpython-39.pyc
ADDED
Binary file (6.9 kB)
decoder/dataset.py
ADDED
@@ -0,0 +1,84 @@
from dataclasses import dataclass

import numpy as np
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader

import soundfile
# import librosa

torch.set_num_threads(1)


@dataclass
class DataConfig:
    filelist_path: str
    sampling_rate: int
    num_samples: int
    batch_size: int
    num_workers: int


class VocosDataModule(LightningDataModule):
    def __init__(self, train_params: DataConfig, val_params: DataConfig):
        super().__init__()
        self.train_config = train_params
        self.val_config = val_params

    def _get_dataloder(self, cfg: DataConfig, train: bool):
        dataset = VocosDataset(cfg, train=train)
        dataloader = DataLoader(
            dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True,
        )
        return dataloader

    def train_dataloader(self) -> DataLoader:
        return self._get_dataloder(self.train_config, train=True)

    def val_dataloader(self) -> DataLoader:
        return self._get_dataloder(self.val_config, train=False)


class VocosDataset(Dataset):
    def __init__(self, cfg: DataConfig, train: bool):
        with open(cfg.filelist_path) as f:
            self.filelist = f.read().splitlines()
        self.sampling_rate = cfg.sampling_rate
        self.num_samples = cfg.num_samples
        self.train = train

    def __len__(self) -> int:
        return len(self.filelist)

    def __getitem__(self, index: int) -> torch.Tensor:
        audio_path = self.filelist[index]
        # y, sr = torchaudio.load(audio_path)
        # print(audio_path,"111")
        y1, sr = soundfile.read(audio_path)
        # y1, sr = librosa.load(audio_path,sr=None)
        y = torch.tensor(y1).float().unsqueeze(0)
        # if y.size(0) > 1:
        #     # mix to mono
        #     y = y.mean(dim=0, keepdim=True)
        if y.ndim > 2:
            # mix to mono
            # print("there is a problem here, in the data-processing part")
            y = y.mean(dim=-1, keepdim=False)
        gain = np.random.uniform(-1, -6) if self.train else -3
        y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
        if sr != self.sampling_rate:
            y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
        if y.size(-1) < self.num_samples:
            pad_length = self.num_samples - y.size(-1)
            padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
            y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
        elif self.train:
            start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
            y = y[:, start : start + self.num_samples]
        else:
            # During validation, always take the first segment for determinism
            y = y[:, : self.num_samples]

        return y[0]
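A standalone usage sketch for the data module defined above, outside the Lightning CLI; the filelist path is the one from the configs and stands in for any text file with one audio path per line.

    # Minimal sketch: build a DataLoader directly from DataConfig / VocosDataModule.
    from decoder.dataset import DataConfig, VocosDataModule

    cfg = DataConfig(
        filelist_path="./WavTokenizer/data/train/libritts_train",  # one audio path per line
        sampling_rate=24000,
        num_samples=72000,  # 3-second crops at 24 kHz
        batch_size=4,
        num_workers=0,
    )
    dm = VocosDataModule(train_params=cfg, val_params=cfg)
    batch = next(iter(dm.train_dataloader()))  # tensor of shape (batch_size, num_samples)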
decoder/discriminator_dac.py
ADDED
@@ -0,0 +1,249 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# from audiotools import AudioSignal
# from audiotools import ml
# from audiotools import STFTParams
from einops import rearrange
from torch.nn.utils import weight_norm

from collections import namedtuple

STFTParams = namedtuple(
    "STFTParams",
    ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
)

STFTParams.__new__.__defaults__ = (None, None, None, None, None)


def WNConv1d(*args, **kwargs):
    act = kwargs.pop("act", True)
    conv = weight_norm(nn.Conv1d(*args, **kwargs))
    if not act:
        return conv
    return nn.Sequential(conv, nn.LeakyReLU(0.1))


def WNConv2d(*args, **kwargs):
    act = kwargs.pop("act", True)
    conv = weight_norm(nn.Conv2d(*args, **kwargs))
    if not act:
        return conv
    return nn.Sequential(conv, nn.LeakyReLU(0.1))


class MPD(nn.Module):
    def __init__(self, period):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList(
            [
                WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
            ]
        )
        self.conv_post = WNConv2d(
            1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
        )

    def pad_to_period(self, x):
        t = x.shape[-1]
        x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
        return x

    def forward(self, x):
        fmap = []

        x = self.pad_to_period(x)
        x = rearrange(x, "b c (l p) -> b c l p", p=self.period)

        for layer in self.convs:
            x = layer(x)
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)

        return fmap


class MSD(nn.Module):
    def __init__(self, rate: int = 1, sample_rate: int = 24000):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                WNConv1d(1, 16, 15, 1, padding=7),
                WNConv1d(16, 64, 41, 4, groups=4, padding=20),
                WNConv1d(64, 256, 41, 4, groups=16, padding=20),
                WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
                WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
                WNConv1d(1024, 1024, 5, 1, padding=2),
            ]
        )
        self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
        self.sample_rate = sample_rate
        self.rate = rate

    def forward(self, x):
        # x = AudioSignal(x, self.sample_rate)
        # x.resample(self.sample_rate // self.rate)
        # x = x.audio_data

        fmap = []

        for l in self.convs:
            x = l(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return fmap


BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]


class MRD(nn.Module):
    def __init__(
        self,
        window_length: int,
        hop_factor: float = 0.25,
        sample_rate: int = 24000,
        bands: list = BANDS,
    ):
        """Complex multi-band spectrogram discriminator.
        Parameters
        ----------
        window_length : int
            Window length of STFT.
        hop_factor : float, optional
            Hop factor of the STFT, defaults to ``0.25 * window_length``.
        sample_rate : int, optional
            Sampling rate of audio in Hz, by default 24000
        bands : list, optional
            Bands to run discriminator over.
        """
        super().__init__()

        self.window_length = window_length
        self.hop_factor = hop_factor
        self.sample_rate = sample_rate
        self.stft_params = STFTParams(
            window_length=window_length,
            hop_length=int(window_length * hop_factor),
            match_stride=True,
        )

        n_fft = window_length // 2 + 1
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands
        self.n_fft = window_length

        ch = 32
        convs = lambda: nn.ModuleList(
            [
                WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
        self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)

    def spectrogram(self, x):
        # x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
        # x = torch.view_as_real(x.stft())

        # x.squeeze(0).stft(n_fft=1024,win_length=1024,return_complex=True).size()
        # breakpoint()
        if x.size(0) == 1:
            # x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.window_length,return_complex=True).unsqueeze(0))
            x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.n_fft, return_complex=True).unsqueeze(0))
        else:
            # x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.window_length,return_complex=True).unsqueeze(1))
            x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.n_fft, return_complex=True).unsqueeze(1))
        x = rearrange(x, "b 1 f t c -> (b 1) c t f")
        # Split into bands
        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
        return x_bands

    def forward(self, x):
        x_bands = self.spectrogram(x)
        fmap = []

        x = []
        for band, stack in zip(x_bands, self.band_convs):
            for layer in stack:
                band = layer(band)
                fmap.append(band)
            x.append(band)

        x = torch.cat(x, dim=-1)
        x = self.conv_post(x)
        fmap.append(x)

        return fmap


# class DACDiscriminator(ml.BaseModel):
class DACDiscriminator(nn.Module):
    def __init__(
        self,
        rates: list = [],
        periods: list = [2, 3, 5, 7, 11],
        fft_sizes: list = [2048, 1024, 512],
        sample_rate: int = 24000,
        bands: list = BANDS,
    ):
        """Discriminator that combines multiple discriminators.

        Parameters
        ----------
        rates : list, optional
            sampling rates (in Hz) to run MSD at, by default []
            If empty, MSD is not used.
        periods : list, optional
            periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
        fft_sizes : list, optional
            Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
        sample_rate : int, optional
            Sampling rate of audio in Hz, by default 24000
        bands : list, optional
            Bands to run MRD at, by default `BANDS`
        """
        super().__init__()
        discs = []
        discs += [MPD(p) for p in periods]
        discs += [MSD(r, sample_rate=sample_rate) for r in rates]
        discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
        self.discriminators = nn.ModuleList(discs)

    def preprocess(self, y):
        # Remove DC offset
        y = y - y.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        return y

    def forward(self, x):
        x = self.preprocess(x)
        fmaps = [d(x) for d in self.discriminators]
        return fmaps


if __name__ == "__main__":
    disc = DACDiscriminator()
    x = torch.zeros(1, 1, 24000)
    results = disc(x)
    breakpoint()
    for i, result in enumerate(results):
        print(f"disc{i}")
        for i, r in enumerate(result):
            print(r.shape, r.mean(), r.min(), r.max())
            print("00")
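The __main__ smoke test above feeds the combined discriminator a (batch, 1, samples) tensor; the training step in decoder/experiment.py follows the same convention, passing audio_hat.unsqueeze(1) and audio_input.unsqueeze(1) through the DACGANLoss wrapper around this discriminator.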
decoder/discriminators.py
ADDED
@@ -0,0 +1,202 @@
from typing import Tuple, List

import torch
from torch import nn
from torch.nn import Conv2d
from torch.nn.utils import weight_norm


class MultiPeriodDiscriminator(nn.Module):
    """
    Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
    Additionally, it allows incorporating conditional information with a learned embeddings table.

    Args:
        periods (tuple[int]): Tuple of periods for each discriminator.
        num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
            Defaults to None.
    """

    def __init__(self, periods: Tuple[int] = (2, 3, 5, 7, 11), num_embeddings: int = None):
        super().__init__()
        self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorP(nn.Module):
    def __init__(
        self,
        period: int,
        in_channels: int = 1,
        kernel_size: int = 5,
        stride: int = 3,
        lrelu_slope: float = 0.1,
        num_embeddings: int = None,
    ):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList(
            [
                weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
                weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
                weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
                weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
                weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
            ]
        )
        if num_embeddings is not None:
            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024)
            torch.nn.init.zeros_(self.emb.weight)

        self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
        self.lrelu_slope = lrelu_slope

    def forward(
        self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        x = x.unsqueeze(1)
        fmap = []
        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for i, l in enumerate(self.convs):
            x = l(x)
            x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
            if i > 0:
                fmap.append(x)
        if cond_embedding_id is not None:
            emb = self.emb(cond_embedding_id)
            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
        else:
            h = 0
        x = self.conv_post(x)
        fmap.append(x)
        x += h
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiResolutionDiscriminator(nn.Module):
    def __init__(
        self,
        resolutions: Tuple[Tuple[int, int, int]] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)),
        num_embeddings: int = None,
    ):
        """
        Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet.
        Additionally, it allows incorporating conditional information with a learned embeddings table.

        Args:
            resolutions (tuple[tuple[int, int, int]]): Tuple of resolutions for each discriminator.
                Each resolution should be a tuple of (n_fft, hop_length, win_length).
            num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
                Defaults to None.
        """
        super().__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
        )

    def forward(
        self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for d in self.discriminators:
            y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
            y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorR(nn.Module):
    def __init__(
        self,
        resolution: Tuple[int, int, int],
        channels: int = 64,
        in_channels: int = 1,
        num_embeddings: int = None,
        lrelu_slope: float = 0.1,
    ):
        super().__init__()
        self.resolution = resolution
        self.in_channels = in_channels
        self.lrelu_slope = lrelu_slope
        self.convs = nn.ModuleList(
            [
                weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
                weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
                weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
                weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
                weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
            ]
        )
        if num_embeddings is not None:
            self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
            torch.nn.init.zeros_(self.emb.weight)
        self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))

    def forward(
        self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        fmap = []
        x = self.spectrogram(x)
        x = x.unsqueeze(1)
        for l in self.convs:
            x = l(x)
            x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
            fmap.append(x)
        if cond_embedding_id is not None:
            emb = self.emb(cond_embedding_id)
            h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
        else:
            h = 0
        x = self.conv_post(x)
        fmap.append(x)
        x += h
        x = torch.flatten(x, 1, -1)

        return x, fmap

    def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
        n_fft, hop_length, win_length = self.resolution
        magnitude_spectrogram = torch.stft(
            x,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=None,  # interestingly rectangular window kind of works here
            center=True,
            return_complex=True,
        ).abs()

        return magnitude_spectrogram
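A small shape-check sketch for the non-conditional form of the discriminators above; batch size and waveform length are arbitrary placeholders.

    # Shape check for the HiFi-GAN / UnivNet style discriminators defined above.
    import torch

    from decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator

    y = torch.randn(2, 24000)      # real waveforms, (batch, samples)
    y_hat = torch.randn(2, 24000)  # generated waveforms

    mpd = MultiPeriodDiscriminator()   # num_embeddings=None -> non-conditional
    mrd = MultiResolutionDiscriminator()
    scores_real, scores_gen, fmaps_real, fmaps_gen = mpd(y=y, y_hat=y_hat)
    print(len(scores_real), scores_real[0].shape)  # one flattened score tensor per period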
decoder/experiment.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
import torch
|
| 6 |
+
import torchaudio
|
| 7 |
+
import transformers
|
| 8 |
+
import yaml
|
| 9 |
+
|
| 10 |
+
from decoder.discriminator_dac import DACDiscriminator
|
| 11 |
+
|
| 12 |
+
from decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
|
| 13 |
+
from decoder.feature_extractors import FeatureExtractor
|
| 14 |
+
from decoder.heads import FourierHead
|
| 15 |
+
from decoder.helpers import plot_spectrogram_to_numpy
|
| 16 |
+
from decoder.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss, DACGANLoss
|
| 17 |
+
from decoder.models import Backbone
|
| 18 |
+
from decoder.modules import safe_log
|
| 19 |
+
from decoder.pretrained_model import instantiate_class
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class VocosExp(pl.LightningModule):
|
| 23 |
+
# noinspection PyUnusedLocal
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
feature_extractor: FeatureExtractor,
|
| 27 |
+
backbone: Backbone,
|
| 28 |
+
head: FourierHead,
|
| 29 |
+
resume_config: str,
|
| 30 |
+
resume_model: str,
|
| 31 |
+
sample_rate: int = 24000,
|
| 32 |
+
initial_learning_rate: float = 2e-4,
|
| 33 |
+
num_warmup_steps: int = 0,
|
| 34 |
+
mel_loss_coeff: float = 45,
|
| 35 |
+
mrd_loss_coeff: float = 1.0,
|
| 36 |
+
pretrain_mel_steps: int = 0,
|
| 37 |
+
decay_mel_coeff: bool = False,
|
| 38 |
+
evaluate_utmos: bool = False,
|
| 39 |
+
evaluate_pesq: bool = False,
|
| 40 |
+
evaluate_periodicty: bool = False,
|
| 41 |
+
resume: bool = False,
|
| 42 |
+
):
|
| 43 |
+
"""
|
| 44 |
+
Args:
|
| 45 |
+
feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals.
|
| 46 |
+
backbone (Backbone): An instance of Backbone model.
|
| 47 |
+
head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform.
|
| 48 |
+
sample_rate (int): Sampling rate of the audio signals.
|
| 49 |
+
initial_learning_rate (float): Initial learning rate for the optimizer.
|
| 50 |
+
num_warmup_steps (int): Number of steps for the warmup phase of learning rate scheduler. Default is 0.
|
| 51 |
+
mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45.
|
| 52 |
+
mrd_loss_coeff (float, optional): Coefficient for Multi Resolution Discriminator loss. Default is 1.0.
|
| 53 |
+
pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0.
|
| 54 |
+
decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False.
|
| 55 |
+
evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run.
|
| 56 |
+
evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run.
|
| 57 |
+
evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run.
|
| 58 |
+
"""
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"])
|
| 61 |
+
|
| 62 |
+
self.feature_extractor = feature_extractor
|
| 63 |
+
self.backbone = backbone
|
| 64 |
+
self.head = head
|
| 65 |
+
|
| 66 |
+
self.resume_config = resume_config
|
| 67 |
+
self.resume_model = resume_model
|
| 68 |
+
self.resume = resume
|
| 69 |
+
|
| 70 |
+
self.multiperioddisc = MultiPeriodDiscriminator()
|
| 71 |
+
self.multiresddisc = MultiResolutionDiscriminator()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
self.dac = DACDiscriminator()
|
| 75 |
+
|
| 76 |
+
self.dacdiscriminator = DACGANLoss(self.dac)
|
| 77 |
+
|
| 78 |
+
self.disc_loss = DiscriminatorLoss()
|
| 79 |
+
self.gen_loss = GeneratorLoss()
|
| 80 |
+
self.feat_matching_loss = FeatureMatchingLoss()
|
| 81 |
+
self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)
|
| 82 |
+
|
| 83 |
+
self.train_discriminator = False
|
| 84 |
+
self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
|
| 85 |
+
|
| 86 |
+
def configure_optimizers(self):
|
| 87 |
+
disc_params = [
|
| 88 |
+
{"params": self.multiperioddisc.parameters()},
|
| 89 |
+
{"params": self.multiresddisc.parameters()},
|
| 90 |
+
{"params": self.dac.parameters()},
|
| 91 |
+
]
|
| 92 |
+
gen_params = [
|
| 93 |
+
{"params": self.feature_extractor.parameters()},
|
| 94 |
+
{"params": self.backbone.parameters()},
|
| 95 |
+
{"params": self.head.parameters()},
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate)
|
| 99 |
+
opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate)
|
| 100 |
+
|
| 101 |
+
max_steps = self.trainer.max_steps // 2 # Max steps per optimizer
|
| 102 |
+
scheduler_disc = transformers.get_cosine_schedule_with_warmup(
|
| 103 |
+
opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
|
| 104 |
+
)
|
| 105 |
+
scheduler_gen = transformers.get_cosine_schedule_with_warmup(
|
| 106 |
+
opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
return (
|
| 110 |
+
[opt_disc, opt_gen],
|
| 111 |
+
[{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}],
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def forward(self, audio_input, **kwargs):
|
| 115 |
+
features, _, commit_loss = self.feature_extractor(audio_input, **kwargs)
|
| 116 |
+
# print('1111', self.feature_extractor.state_dict()['encodec.decoder.model.3.convtr.convtr.weight_g'])
|
| 117 |
+
x = self.backbone(features, **kwargs)
|
| 118 |
+
audio_output = self.head(x)
|
| 119 |
+
return audio_output, commit_loss
|
| 120 |
+
|
| 121 |
+
def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
|
| 122 |
+
audio_input = batch
|
| 123 |
+
|
| 124 |
+
# train discriminator
|
| 125 |
+
if optimizer_idx == 0 and self.train_discriminator:
|
| 126 |
+
with torch.no_grad():
|
| 127 |
+
audio_hat, _ = self(audio_input, **kwargs)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
loss_dac=self.dacdiscriminator.discriminator_loss(audio_hat.unsqueeze(1),audio_input.unsqueeze(1))
|
| 131 |
+
|
| 132 |
+
real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
|
| 133 |
+
real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
|
| 134 |
+
loss_mp, loss_mp_real, _ = self.disc_loss(
|
| 135 |
+
disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
|
| 136 |
+
)
|
| 137 |
+
loss_mrd, loss_mrd_real, _ = self.disc_loss(
|
| 138 |
+
disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
|
| 139 |
+
)
|
| 140 |
+
loss_mp /= len(loss_mp_real)
|
| 141 |
+
loss_mrd /= len(loss_mrd_real)
|
| 142 |
+
loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd + loss_dac
|
| 143 |
+
|
| 144 |
+
self.log("discriminator/total", loss, prog_bar=True)
|
| 145 |
+
self.log("discriminator/multi_period_loss", loss_mp)
|
| 146 |
+
self.log("discriminator/multi_res_loss", loss_mrd)
|
| 147 |
+
self.log("discriminator/dac", loss_dac)
|
| 148 |
+
return loss
|
| 149 |
+
|
| 150 |
+
# train generator
|
| 151 |
+
if optimizer_idx == 1:
|
| 152 |
+
audio_hat, commit_loss = self(audio_input, **kwargs)
|
| 153 |
+
if self.train_discriminator:
|
| 154 |
+
|
| 155 |
+
loss_dac_1,loss_dac_2 = self.dacdiscriminator.generator_loss(audio_hat.unsqueeze(1),audio_input.unsqueeze(1))
|
| 156 |
+
_, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
|
| 157 |
+
y=audio_input, y_hat=audio_hat, **kwargs,
|
| 158 |
+
)
|
| 159 |
+
_, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
|
| 160 |
+
y=audio_input, y_hat=audio_hat, **kwargs,
|
| 161 |
+
)
|
| 162 |
+
loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
|
| 163 |
+
loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
|
| 164 |
+
loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
|
| 165 |
+
loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
|
| 166 |
+
loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp)
|
| 167 |
+
loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd)
|
| 168 |
+
|
| 169 |
+
self.log("generator/multi_period_loss", loss_gen_mp)
|
| 170 |
+
self.log("generator/multi_res_loss", loss_gen_mrd)
|
| 171 |
+
self.log("generator/feature_matching_mp", loss_fm_mp)
|
| 172 |
+
self.log("generator/feature_matching_mrd", loss_fm_mrd)
|
| 173 |
+
self.log("generator/loss_dac_1", loss_dac_1)
|
| 174 |
+
self.log("generator/loss_dac_2", loss_dac_2)
|
| 175 |
+
else:
|
| 176 |
+
loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = 0
|
| 177 |
+
|
| 178 |
+
mel_loss = self.melspec_loss(audio_hat, audio_input)
|
| 179 |
+
loss = (
|
| 180 |
+
loss_gen_mp
|
| 181 |
+
+ self.hparams.mrd_loss_coeff * loss_gen_mrd
|
| 182 |
+
+ loss_fm_mp
|
| 183 |
+
+ self.hparams.mrd_loss_coeff * loss_fm_mrd
|
| 184 |
+
+ self.mel_loss_coeff * mel_loss
|
| 185 |
+
+ 1000 * commit_loss
|
| 186 |
+
+ loss_dac_1
|
| 187 |
+
+ loss_dac_2
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
self.log("generator/total_loss", loss, prog_bar=True)
|
| 191 |
+
self.log("mel_loss_coeff", self.mel_loss_coeff)
|
| 192 |
+
self.log("generator/mel_loss", mel_loss)
|
| 193 |
+
self.log("commit_loss", commit_loss)
|
| 194 |
+
|
| 195 |
+
if self.global_step % 1000 == 0 and self.global_rank == 0:
|
| 196 |
+
self.logger.experiment.add_audio(
|
| 197 |
+
"train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate
|
| 198 |
+
)
|
| 199 |
+
self.logger.experiment.add_audio(
|
| 200 |
+
"train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate
|
| 201 |
+
)
|
| 202 |
+
with torch.no_grad():
|
| 203 |
+
mel = safe_log(self.melspec_loss.mel_spec(audio_input[0]))
|
| 204 |
+
mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0]))
|
| 205 |
+
self.logger.experiment.add_image(
|
| 206 |
+
"train/mel_target",
|
| 207 |
+
plot_spectrogram_to_numpy(mel.data.cpu().numpy()),
|
| 208 |
+
self.global_step,
|
| 209 |
+
dataformats="HWC",
|
| 210 |
+
)
|
| 211 |
+
self.logger.experiment.add_image(
|
| 212 |
+
"train/mel_pred",
|
| 213 |
+
plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
|
| 214 |
+
self.global_step,
|
| 215 |
+
dataformats="HWC",
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
return loss
|
| 219 |
+
|
| 220 |
+
def on_validation_epoch_start(self):
|
| 221 |
+
if self.hparams.evaluate_utmos:
|
| 222 |
+
from metrics.UTMOS import UTMOSScore
|
| 223 |
+
|
| 224 |
+
if not hasattr(self, "utmos_model"):
|
| 225 |
+
self.utmos_model = UTMOSScore(device=self.device)
|
| 226 |
+
|
| 227 |
+
def validation_step(self, batch, batch_idx, **kwargs):
|
| 228 |
+
audio_input = batch
|
| 229 |
+
audio_hat, commit_loss = self(audio_input, **kwargs)
|
| 230 |
+
|
| 231 |
+
audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000)
|
| 232 |
+
audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000)
|
| 233 |
+
|
| 234 |
+
if self.hparams.evaluate_periodicty:
|
| 235 |
+
from metrics.periodicity import calculate_periodicity_metrics
|
| 236 |
+
|
| 237 |
+
periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz)
|
| 238 |
+
else:
|
| 239 |
+
periodicity_loss = pitch_loss = f1_score = 0
|
| 240 |
+
|
| 241 |
+
if self.hparams.evaluate_utmos:
|
| 242 |
+
utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean()
|
| 243 |
+
else:
|
| 244 |
+
utmos_score = torch.zeros(1, device=self.device)
|
| 245 |
+
|
| 246 |
+
if self.hparams.evaluate_pesq:
|
| 247 |
+
from pesq import pesq
|
| 248 |
+
|
| 249 |
+
pesq_score = 0
|
| 250 |
+
for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()):
|
| 251 |
+
pesq_score += pesq(16000, ref, deg, "wb", on_error=1)
|
| 252 |
+
pesq_score /= len(audio_16_khz)
|
| 253 |
+
pesq_score = torch.tensor(pesq_score)
|
| 254 |
+
else:
|
| 255 |
+
pesq_score = torch.zeros(1, device=self.device)
|
| 256 |
+
|
| 257 |
+
mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
|
| 258 |
+
total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score) + 1000 * commit_loss
|
| 259 |
+
|
| 260 |
+
return {
|
| 261 |
+
"val_loss": total_loss,
|
| 262 |
+
"mel_loss": mel_loss,
|
| 263 |
+
"utmos_score": utmos_score,
|
| 264 |
+
"pesq_score": pesq_score,
|
| 265 |
+
"periodicity_loss": periodicity_loss,
|
| 266 |
+
"pitch_loss": pitch_loss,
|
| 267 |
+
"f1_score": f1_score,
|
| 268 |
+
"audio_input": audio_input[0],
|
| 269 |
+
"audio_pred": audio_hat[0],
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
def validation_epoch_end(self, outputs):
|
| 273 |
+
if self.global_rank == 0:
|
| 274 |
+
*_, audio_in, audio_pred = outputs[0].values()
|
| 275 |
+
self.logger.experiment.add_audio(
|
| 276 |
+
"val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
            )
            self.logger.experiment.add_audio(
                "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
            )
            mel_target = safe_log(self.melspec_loss.mel_spec(audio_in))
            mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred))
            self.logger.experiment.add_image(
                "val_mel_target",
                plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
            self.logger.experiment.add_image(
                "val_mel_hat",
                plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
                self.global_step,
                dataformats="HWC",
            )
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean()
        utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean()
        pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean()
        periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean()
        pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean()
        f1_score = np.array([x["f1_score"] for x in outputs]).mean()

        self.log("val_loss", avg_loss, sync_dist=True)
        self.log("val/mel_loss", mel_loss, sync_dist=True)
        self.log("val/utmos_score", utmos_score, sync_dist=True)
        self.log("val/pesq_score", pesq_score, sync_dist=True)
        self.log("val/periodicity_loss", periodicity_loss, sync_dist=True)
        self.log("val/pitch_loss", pitch_loss, sync_dist=True)
        self.log("val/f1_score", f1_score, sync_dist=True)

    @property
    def global_step(self):
        """
        Override global_step so that it returns the total number of batches processed
        """
        return self.trainer.fit_loop.epoch_loop.total_batch_idx

    def on_train_batch_start(self, *args):
        if self.global_step >= self.hparams.pretrain_mel_steps:
            self.train_discriminator = True
        else:
            self.train_discriminator = False

    def on_train_batch_end(self, *args):
        def mel_loss_coeff_decay(current_step, num_cycles=0.5):
            max_steps = self.trainer.max_steps // 2
            if current_step < self.hparams.num_warmup_steps:
                return 1.0
            progress = float(current_step - self.hparams.num_warmup_steps) / float(
                max(1, max_steps - self.hparams.num_warmup_steps)
            )
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

        if self.hparams.decay_mel_coeff:
            self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1)


class WavTokenizer(VocosExp):
    """
    WavTokenizer is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN.
    It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to
    a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step,
    while during validation, a fixed bandwidth_id is used.
    """

    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        backbone: Backbone,
        head: FourierHead,
        resume_config: str,
        resume_model: str,
        sample_rate: int = 24000,
        initial_learning_rate: float = 2e-4,
        num_warmup_steps: int = 0,
        mel_loss_coeff: float = 45,
        mrd_loss_coeff: float = 1.0,
        pretrain_mel_steps: int = 0,
        decay_mel_coeff: bool = False,
        evaluate_utmos: bool = False,
        evaluate_pesq: bool = False,
        evaluate_periodicty: bool = False,
        resume: bool = False,
    ):
        super().__init__(
            feature_extractor,
            backbone,
            head,
            resume_config,
            resume_model,
            sample_rate,
            initial_learning_rate,
            num_warmup_steps,
            mel_loss_coeff,
            mrd_loss_coeff,
            pretrain_mel_steps,
            decay_mel_coeff,
            evaluate_utmos,
            evaluate_pesq,
            evaluate_periodicty,
            resume
        )
        # Override with conditional discriminators
        # VocosExp.__init__(self, feature_extractor, backbone, head, resume_config, resume_model)
        # if self.resume:
        #     VocosExp.load_from_checkpoint(self.resume_model)
        self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
        self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
        self.dac = DACDiscriminator()
        if self.resume:
            print('Loading pretrained model:', self.resume_model)
            # with open(self.resume_config, "r") as f:
            #     config = yaml.safe_load(f)
            # feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
            # backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
            # head = instantiate_class(args=(), init=config['model']['init_args']["head"])

            # do not load (part of) the quantizer weights
            state_dict_raw = torch.load(self.resume_model, map_location=self.device)['state_dict']
            state_dict_fa_qa = dict()
            state_dict_fa_en = dict()
            state_dict_fa_de = dict()
            state_dict_bb = dict()
            state_dict_hd = dict()
            state_dict_mp = dict()
            state_dict_mr = dict()
            state_dict_dac = dict()
            for k, v in state_dict_raw.items():
                # breakpoint()
                if k.startswith('feature_extractor.encodec.quantizer'):
                    # breakpoint()
                    # print("*****",k)
                    ss = k[46:48]
                    if ss[-1] == '.':
                        num = int(ss[0])
                        # print("num,k",num,k[36:])
                        if num <= 7:
                            state_dict_fa_qa[k[36:]] = v
                if k.startswith('feature_extractor.encodec.encoder'):
                    state_dict_fa_en[k[34:]] = v
                if k.startswith('feature_extractor.encodec.decoder'):
                    state_dict_fa_de[k[34:]] = v
                if k.startswith('backbone.'):
                    state_dict_bb[k[9:]] = v
                if k.startswith('head.'):
                    state_dict_hd[k[5:]] = v
                if k.startswith('multiperioddisc.'):
                    state_dict_mp[k[16:]] = v
                if k.startswith('multiresddisc.'):
                    state_dict_mr[k[14:]] = v
                if k.startswith('dac.'):
                    state_dict_dac[k[4:]] = v
            # breakpoint()
            # feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
            feature_extractor.encodec.encoder.load_state_dict(state_dict_fa_en, strict=True)
            feature_extractor.encodec.decoder.load_state_dict(state_dict_fa_de, strict=True)
            feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
            backbone.load_state_dict(state_dict_bb, strict=True)
            head.load_state_dict(state_dict_hd, strict=True)
            self.feature_extractor = feature_extractor.to(self.device)
            self.backbone = backbone.to(self.device)
            self.head = head.to(self.device)
            self.multiperioddisc.load_state_dict(state_dict_mp, strict=True)
            self.multiresddisc.load_state_dict(state_dict_mr, strict=True)
            self.dac.load_state_dict(state_dict_dac, strict=True)

    def training_step(self, *args):
        # print('-------------------train--------------------')
        # if self.global_rank == 0 and self.resume:
        #     config_path = self.resume_config
        #     model_path = self.resume_model
        #     self.pretrained_load(config_path, model_path)
        #     print('Loading pretrained model:', model_path)
        bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,)
        output = super().training_step(*args, bandwidth_id=bandwidth_id)
        return output

    def validation_step(self, *args):
        # print('-------------------valid--------------------')
        bandwidth_id = torch.tensor([0], device=self.device)
        output = super().validation_step(*args, bandwidth_id=bandwidth_id)
        return output

    def validation_epoch_end(self, outputs):
        if self.global_rank == 0:
            *_, audio_in, _ = outputs[0].values()
            # Resynthesis with encodec for reference
            self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0])
            encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :])
            self.logger.experiment.add_audio(
                "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate,
            )

        super().validation_epoch_end(outputs)
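For reference, the cosine schedule implemented by `mel_loss_coeff_decay` above can be checked in isolation. A standalone sketch with hypothetical values (`max_steps=2_000_000`, `num_warmup_steps=0`); the closure in `on_train_batch_end` reads these from the trainer and hparams instead:

import math

def mel_loss_coeff_decay(current_step, max_steps=2_000_000, num_warmup_steps=0, num_cycles=0.5):
    # mirrors the closure above: cosine decay over the first half of training
    half = max_steps // 2
    if current_step < num_warmup_steps:
        return 1.0
    progress = (current_step - num_warmup_steps) / max(1, half - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

print(mel_loss_coeff_decay(0))          # 1.0  -> mel coefficient starts at its base value (45 by default)
print(mel_loss_coeff_decay(500_000))    # 0.5  -> halved midway through the decay window
print(mel_loss_coeff_decay(1_000_000))  # 0.0  -> fully decayed (only used when decay_mel_coeff is True)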
decoder/feature_extractors.py
ADDED
@@ -0,0 +1,141 @@
from typing import List

import torch
import torchaudio
from torch import nn
import math
from decoder.modules import safe_log
from encoder.modules import SEANetEncoder, SEANetDecoder
from encoder import EncodecModel
from encoder.quantization import ResidualVectorQuantizer


class FeatureExtractor(nn.Module):
    """Base class for feature extractors."""

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Extract features from the given audio.

        Args:
            audio (Tensor): Input audio waveform.

        Returns:
            Tensor: Extracted features of shape (B, C, L), where B is the batch size,
                    C denotes output features, and L is the sequence length.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class MelSpectrogramFeatures(FeatureExtractor):
    def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=padding == "center",
            power=1,
        )

    def forward(self, audio, **kwargs):
        if self.padding == "same":
            pad = self.mel_spec.win_length - self.mel_spec.hop_length
            audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
        mel = self.mel_spec(audio)
        features = safe_log(mel)
        return features


class EncodecFeatures(FeatureExtractor):
    def __init__(
        self,
        encodec_model: str = "encodec_24khz",
        bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
        train_codebooks: bool = False,
        num_quantizers: int = 1,
        dowmsamples: List[int] = [6, 5, 5, 4],
        vq_bins: int = 16384,
        vq_kmeans: int = 800,
    ):
        super().__init__()

        # breakpoint()
        self.frame_rate = 25  # not used
        # n_q = int(bandwidths[-1]*1000/(math.log2(2048) * self.frame_rate))
        n_q = num_quantizers  # important
        encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
                                dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
                                kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
                                true_skip=False, compress=2)
        decoder = SEANetDecoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
                                dimension=512, channels=1, n_filters=32, ratios=[8, 5, 4, 2], activation='ELU',
                                kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
                                true_skip=False, compress=2)
        quantizer = ResidualVectorQuantizer(dimension=512, n_q=n_q, bins=vq_bins, kmeans_iters=vq_kmeans,
                                            decay=0.99, kmeans_init=True)

        # breakpoint()
        if encodec_model == "encodec_24khz":
            self.encodec = EncodecModel(encoder=encoder, decoder=decoder, quantizer=quantizer,
                                        target_bandwidths=bandwidths, sample_rate=24000, channels=1)
        else:
            raise ValueError(
                f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz'."
            )
        for param in self.encodec.parameters():
            param.requires_grad = True
        # self.num_q = n_q
        # codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
        # self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
        self.bandwidths = bandwidths

    # @torch.no_grad()
    # def get_encodec_codes(self, audio):
    #     audio = audio.unsqueeze(1)
    #     emb = self.encodec.encoder(audio)
    #     codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
    #     return codes

    def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
        if self.training:
            self.encodec.train()

        audio = audio.unsqueeze(1)  # audio(16,24000)

        # breakpoint()

        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized
        codes = q_res.codes
        commit_loss = q_res.penalty  # codes(8,16,75), features(16,128,75)

        return quantized, codes, commit_loss

        # codes = self.get_encodec_codes(audio)
        # # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights`
        # # with offsets given by the number of bins, and finally summed in a vectorized operation.
        # offsets = torch.arange(
        #     0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
        # )
        # embeddings_idxs = codes + offsets.view(-1, 1, 1)
        # features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
        # return features.transpose(1, 2)

    def infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
        if self.training:
            self.encodec.train()

        audio = audio.unsqueeze(1)  # audio(16,24000)
        emb = self.encodec.encoder(audio)
        q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
        quantized = q_res.quantized
        codes = q_res.codes
        commit_loss = q_res.penalty  # codes(8,16,75), features(16,128,75)

        return quantized, codes, commit_loss
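As the shape comments in `forward` indicate, `EncodecFeatures` returns the quantized latent, the discrete codes, and the VQ commitment penalty. A hedged usage sketch, assuming an already-constructed `EncodecFeatures` instance `fe` (building one requires the `encoder` package from the full repo) and a batch of 3-second 24 kHz waveforms:

import torch

audio = torch.randn(4, 72000)        # (B, T) waveform batch at 24 kHz
bandwidth_id = torch.tensor([0])     # index into fe.bandwidths
quantized, codes, commit_loss = fe(audio, bandwidth_id=bandwidth_id)
# quantized:   (B, D, L) continuous features for the decoder backbone (D = quantizer dimension, 512 here)
# codes:       (n_q, B, L) token ids from the vq_bins-entry codebook
# commit_loss: scalar tensor added to the generator objective during training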
decoder/heads.py
ADDED
@@ -0,0 +1,157 @@
import torch
from torch import nn
from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz

from decoder.spectral_ops import IMDCT, ISTFT
from decoder.modules import symexp


class FourierHead(nn.Module):
    """Base class for inverse fourier modules."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class ISTFTHead(FourierHead):
    """
    ISTFT Head module for predicting STFT complex coefficients.

    Args:
        dim (int): Hidden dimension of the model.
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames, which should align with
            the resolution of the input features.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
        super().__init__()
        out_dim = n_fft + 2
        self.out = torch.nn.Linear(dim, out_dim)
        self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the ISTFTHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x).transpose(1, 2)
        mag, p = x.chunk(2, dim=1)
        mag = torch.exp(mag)
        mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
        # wrapping happens here. These two lines produce real and imaginary value
        x = torch.cos(p)
        y = torch.sin(p)
        # recalculating phase here does not produce anything new
        # only costs time
        # phase = torch.atan2(y, x)
        # S = mag * torch.exp(phase * 1j)
        # better directly produce the complex value
        S = mag * (x + 1j * y)
        audio = self.istft(S)
        return audio


class IMDCTSymExpHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with symmetric exponential function

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
            based on perceptual scaling. Defaults to None.
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(
        self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: int = None, clip_audio: bool = False,
    ):
        super().__init__()
        out_dim = mdct_frame_len // 2
        self.out = nn.Linear(dim, out_dim)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
        self.clip_audio = clip_audio

        if sample_rate is not None:
            # optionally init the last layer following mel-scale
            m_max = _hz_to_mel(sample_rate // 2)
            m_pts = torch.linspace(0, m_max, out_dim)
            f_pts = _mel_to_hz(m_pts)
            scale = 1 - (f_pts / f_pts.max())

            with torch.no_grad():
                self.out.weight.mul_(scale.view(-1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTSymExpHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        x = symexp(x)
        x = torch.clip(x, min=-1e2, max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(x)
        if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the reconstructed waveform, not the MDCT coefficients

        return audio


class IMDCTCosHead(FourierHead):
    """
    IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)

    Args:
        dim (int): Hidden dimension of the model.
        mdct_frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
        clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
    """

    def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
        super().__init__()
        self.clip_audio = clip_audio
        self.out = nn.Linear(dim, mdct_frame_len)
        self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the IMDCTCosHead module.

        Args:
            x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
                        L is the sequence length, and H denotes the model dimension.

        Returns:
            Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
        """
        x = self.out(x)
        m, p = x.chunk(2, dim=2)
        m = torch.exp(m).clip(max=1e2)  # safeguard to prevent excessively large magnitudes
        audio = self.imdct(m * torch.cos(p))
        if self.clip_audio:
            audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the reconstructed waveform, not the head input
        return audio
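The `ISTFTHead` turns each frame of backbone features into `n_fft/2 + 1` magnitudes plus as many phases, combines them into complex STFT coefficients, and inverts them. A rough shape check; `dim`, `n_fft` and `hop_length` below are illustrative values chosen for the sketch, not taken from the shipped configs:

import torch
from decoder.heads import ISTFTHead

head = ISTFTHead(dim=512, n_fft=1200, hop_length=600)
feats = torch.randn(2, 75, 512)   # (B, L, H) backbone output
audio = head(feats)               # (B, T) waveform; T is roughly L * hop_length
print(audio.shape)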
decoder/helpers.py
ADDED
@@ -0,0 +1,71 @@
import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from pytorch_lightning import Callback

matplotlib.use("Agg")


def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
    """
    Save a matplotlib figure to a numpy array.

    Args:
        fig (Figure): Matplotlib figure object.

    Returns:
        ndarray: Numpy array representing the figure.
    """
    # np.fromstring is deprecated for binary input; np.frombuffer is the equivalent supported call
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data


def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
    """
    Plot a spectrogram and convert it to a numpy array.

    Args:
        spectrogram (ndarray): Spectrogram data.

    Returns:
        ndarray: Numpy array representing the plotted spectrogram.
    """
    spectrogram = spectrogram.astype(np.float32)
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data


class GradNormCallback(Callback):
    """
    Callback to log the gradient norm.
    """

    def on_after_backward(self, trainer, model):
        model.log("grad_norm", gradient_norm(model))


def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
    """
    Compute the gradient norm.

    Args:
        model (Module): PyTorch model.
        norm_type (float, optional): Type of the norm. Defaults to 2.0.

    Returns:
        Tensor: Gradient norm.
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
    return total_norm
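`GradNormCallback` hooks Lightning's `on_after_backward`, so enabling gradient-norm logging is a one-line change to the trainer. A minimal sketch (trainer arguments reduced to the essentials):

import pytorch_lightning as pl
from decoder.helpers import GradNormCallback

trainer = pl.Trainer(max_steps=1000, callbacks=[GradNormCallback()])
# every backward pass now logs "grad_norm", the global L2 norm over all parameter gradients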
decoder/loss.py
ADDED
@@ -0,0 +1,159 @@
from typing import List, Tuple

import torch
import torchaudio
from torch import nn

from decoder.modules import safe_log

import torch.nn.functional as F


class MelSpecReconstructionLoss(nn.Module):
    """
    L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
    """

    def __init__(
        self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100,
    ):
        super().__init__()
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1,
        )

    def forward(self, y_hat, y) -> torch.Tensor:
        """
        Args:
            y_hat (Tensor): Predicted audio waveform.
            y (Tensor): Ground truth audio waveform.

        Returns:
            Tensor: L1 loss between the mel-scaled magnitude spectrograms.
        """
        mel_hat = safe_log(self.mel_spec(y_hat))
        mel = safe_log(self.mel_spec(y))

        loss = torch.nn.functional.l1_loss(mel, mel_hat)

        return loss


class GeneratorLoss(nn.Module):
    """
    Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
    """

    def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Args:
            disc_outputs (List[Tensor]): List of discriminator outputs.

        Returns:
            Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
                                         the sub-discriminators
        """
        loss = 0
        gen_losses = []
        for dg in disc_outputs:
            l = torch.mean(torch.clamp(1 - dg, min=0))
            gen_losses.append(l)
            loss += l

        return loss, gen_losses


class DiscriminatorLoss(nn.Module):
    """
    Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
    """

    def forward(
        self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
        """
        Args:
            disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
            disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.

        Returns:
            Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
                                                       the sub-discriminators for real outputs, and a list of
                                                       loss values for generated outputs.
        """
        loss = 0
        r_losses = []
        g_losses = []
        for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
            r_loss = torch.mean(torch.clamp(1 - dr, min=0))
            g_loss = torch.mean(torch.clamp(1 + dg, min=0))
            loss += r_loss + g_loss
            r_losses.append(r_loss.item())
            g_losses.append(g_loss.item())

        return loss, r_losses, g_losses


class FeatureMatchingLoss(nn.Module):
    """
    Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
    """

    def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
        """
        Args:
            fmap_r (List[List[Tensor]]): List of feature maps from real samples.
            fmap_g (List[List[Tensor]]): List of feature maps from generated samples.

        Returns:
            Tensor: The calculated feature matching loss.
        """
        loss = 0
        for dr, dg in zip(fmap_r, fmap_g):
            for rl, gl in zip(dr, dg):
                loss += torch.mean(torch.abs(rl - gl))

        return loss


class DACGANLoss(nn.Module):
    """
    Computes a discriminator loss, given a discriminator on
    generated waveforms/spectrograms compared to ground truth
    waveforms/spectrograms. Computes the loss for both the
    discriminator and the generator in separate functions.
    """

    def __init__(self, discriminator):
        super().__init__()
        self.discriminator = discriminator

    def forward(self, fake, real):
        # d_fake = self.discriminator(fake.audio_data)
        # d_real = self.discriminator(real.audio_data)
        d_fake = self.discriminator(fake)
        d_real = self.discriminator(real)
        return d_fake, d_real

    def discriminator_loss(self, fake, real):
        d_fake, d_real = self.forward(fake.clone().detach(), real)

        loss_d = 0
        for x_fake, x_real in zip(d_fake, d_real):
            loss_d += torch.mean(x_fake[-1] ** 2)
            loss_d += torch.mean((1 - x_real[-1]) ** 2)
        return loss_d

    def generator_loss(self, fake, real):
        d_fake, d_real = self.forward(fake, real)

        loss_g = 0
        for x_fake in d_fake:
            loss_g += torch.mean((1 - x_fake[-1]) ** 2)

        loss_feature = 0

        for i in range(len(d_fake)):
            for j in range(len(d_fake[i]) - 1):
                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
        return loss_g, loss_feature
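`GeneratorLoss` and `DiscriminatorLoss` implement the usual hinge objectives: the discriminator pushes real scores above +1 and fake scores below -1, while the generator pushes fake scores above +1. A tiny numeric illustration with made-up scores:

import torch
from decoder.loss import GeneratorLoss, DiscriminatorLoss

disc_real = [torch.tensor([1.4, 0.2])]    # one sub-discriminator, two samples
disc_fake = [torch.tensor([-0.5, 0.8])]
d_loss, _, _ = DiscriminatorLoss()(disc_real, disc_fake)
# mean(max(0, 1 - dr)) + mean(max(0, 1 + dg)) = 0.4 + 1.15 = 1.55
g_loss, _ = GeneratorLoss()(disc_fake)
# mean(max(0, 1 - dg)) = 0.85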
decoder/models.py
ADDED
@@ -0,0 +1,264 @@
from typing import Optional
import typing as tp  # needed for the tp.List annotation used below

import torch
from torch import nn
from torch.nn.utils import weight_norm

from decoder.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv1d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv1d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv1d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv1d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb=None):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv1d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv1d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv1d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv1d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h = q.shape
        q = q.permute(0, 2, 1)   # b,hw,c
        w_ = torch.bmm(q, k)     # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)     # b,c,hw (hw of q)  h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]

        h_ = self.proj_out(h_)

        return x + h_


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)


class Backbone(nn.Module):
    """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
                        C denotes output features, and L is the sequence length.

        Returns:
            Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
                    and H denotes the model dimension.
        """
        raise NotImplementedError("Subclasses must implement the forward method.")


class VocosBackbone(Backbone):
    """
    Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
        num_layers (int): Number of ConvNeXtBlock layers.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional model. Defaults to None.
    """

    def __init__(
        self,
        input_channels: int,
        dim: int,
        intermediate_dim: int,
        num_layers: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        layer_scale_init_value = layer_scale_init_value or 1 / num_layers
        self.convnext = nn.ModuleList(
            [
                ConvNeXtBlock(
                    dim=dim,
                    intermediate_dim=intermediate_dim,
                    layer_scale_init_value=layer_scale_init_value,
                    adanorm_num_embeddings=adanorm_num_embeddings,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.apply(self._init_weights)

        self.temb_ch = 0
        block_in = dim
        dropout = 0.1
        attn_type = "vanilla"

        pos_net: tp.List[nn.Module] = [
            ResnetBlock(in_channels=block_in, out_channels=block_in,
                        temb_channels=self.temb_ch, dropout=dropout),
            ResnetBlock(in_channels=block_in, out_channels=block_in,
                        temb_channels=self.temb_ch, dropout=dropout),
            make_attn(block_in, attn_type=attn_type),
            ResnetBlock(in_channels=block_in, out_channels=block_in,
                        temb_channels=self.temb_ch, dropout=dropout),
            ResnetBlock(in_channels=block_in, out_channels=block_in,
                        temb_channels=self.temb_ch, dropout=dropout),
            Normalize(block_in)
        ]

        self.pos_net = nn.Sequential(*pos_net)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.embed(x)
        x = self.pos_net(x)
        if self.adanorm:
            assert bandwidth_id is not None
            x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
        else:
            x = self.norm(x.transpose(1, 2))
        x = x.transpose(1, 2)
        for conv_block in self.convnext:
            x = conv_block(x, cond_embedding_id=bandwidth_id)
        x = self.final_layer_norm(x.transpose(1, 2))
        return x


class VocosResNetBackbone(Backbone):
    """
    Vocos backbone module built with ResBlocks.

    Args:
        input_channels (int): Number of input features channels.
        dim (int): Hidden dimension of the model.
        num_blocks (int): Number of ResBlock1 blocks.
        layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
    """

    def __init__(
        self, input_channels, dim, num_blocks, layer_scale_init_value=None,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
        layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
        self.resnet = nn.Sequential(
            *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
        )

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.embed(x)
        x = self.resnet(x)
        x = x.transpose(1, 2)
        return x
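`VocosBackbone` therefore runs embed -> pos_net (ResNet blocks plus self-attention over the frame axis) -> (Ada)LayerNorm -> ConvNeXt stack -> final LayerNorm, mapping (B, C, L) features to (B, L, dim). A quick shape check with deliberately small toy sizes (the shipped configs use larger dimensions; these values are only for illustration):

import torch
from decoder.models import VocosBackbone

backbone = VocosBackbone(input_channels=512, dim=96, intermediate_dim=192,
                         num_layers=2, adanorm_num_embeddings=4)
feats = torch.randn(2, 512, 75)                       # (B, C, L) quantized EnCodec features
out = backbone(feats, bandwidth_id=torch.tensor([0]))
print(out.shape)                                      # torch.Size([2, 75, 96])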
decoder/modules.py
ADDED
@@ -0,0 +1,213 @@
from typing import Optional

import torch
from torch import nn
from torch.nn.utils import weight_norm, remove_weight_norm


class ConvNeXtBlock(nn.Module):
    """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.

    Args:
        dim (int): Number of input channels.
        intermediate_dim (int): Dimensionality of the intermediate layer.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
        adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
            None means non-conditional LayerNorm. Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        layer_scale_init_value: Optional[float] = None,
        adanorm_num_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.adanorm = adanorm_num_embeddings is not None
        if adanorm_num_embeddings:
            self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
        else:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(intermediate_dim, dim)
        # guard against None so that "None means no scaling" holds, as stated in the docstring
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            if layer_scale_init_value is not None and layer_scale_init_value > 0
            else None
        )

    def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if self.adanorm:
            assert cond_embedding_id is not None
            x = self.norm(x, cond_embedding_id)
        else:
            x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)

        x = residual + x
        return x


class AdaLayerNorm(nn.Module):
    """
    Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes

    Args:
        num_embeddings (int): Number of embeddings.
        embedding_dim (int): Dimension of the embeddings.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.dim = embedding_dim
        self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        torch.nn.init.ones_(self.scale.weight)
        torch.nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
        scale = self.scale(cond_embedding_id)
        shift = self.shift(cond_embedding_id)
        x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
        x = x * scale + shift
        return x


class ResBlock1(nn.Module):
    """
    ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
    but without upsampling layers.

    Args:
        dim (int): Number of input channels.
        kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
        dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
            Defaults to (1, 3, 5).
        lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
            Defaults to 0.1.
        layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
            Defaults to None.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int = 3,
        dilation: tuple[int] = (1, 3, 5),
        lrelu_slope: float = 0.1,
        layer_scale_init_value: float = None,
    ):
        super().__init__()
        self.lrelu_slope = lrelu_slope
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=self.get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=self.get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    nn.Conv1d(
                        dim,
                        dim,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=self.get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )

        self.convs2 = nn.ModuleList(
            [
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
                weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
            ]
        )

        self.gamma = nn.ParameterList(
            [
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
                nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
                if layer_scale_init_value is not None
                else None,
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
            xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
            xt = c1(xt)
            xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
            xt = c2(xt)
            if gamma is not None:
                xt = gamma * xt
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

    @staticmethod
    def get_padding(kernel_size: int, dilation: int = 1) -> int:
        return int((kernel_size * dilation - dilation) / 2)


def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
    """
    Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.

    Args:
        x (Tensor): Input tensor.
        clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.

    Returns:
        Tensor: Element-wise logarithm of the input tensor with clipping applied.
    """
    return torch.log(torch.clip(x, min=clip_val))


def symlog(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * torch.log1p(x.abs())


def symexp(x: torch.Tensor) -> torch.Tensor:
    return torch.sign(x) * (torch.exp(x.abs()) - 1)
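`safe_log`, `symlog` and `symexp` are small numeric helpers: `safe_log` clamps its input before the log so silent frames never yield -inf, and `symexp` inverts `symlog` on all reals. A quick check, assuming the functions as defined above:

import torch
from decoder.modules import safe_log, symlog, symexp

x = torch.tensor([0.0, 1e-9, 2.0])
print(safe_log(x))             # values are clipped to >= 1e-7 before the log, so no -inf
y = torch.tensor([-3.0, 0.0, 5.0])
print(symexp(symlog(y)))       # recovers [-3., 0., 5.] up to float rounding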
decoder/pretrained.py
ADDED
@@ -0,0 +1,239 @@
import os
from typing import Tuple, Any, Union, Dict

import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn
from decoder.feature_extractors import FeatureExtractor, EncodecFeatures
from decoder.heads import FourierHead
from decoder.models import Backbone


def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation.
        init: Dict of the form {"class_path":...,"init_args":...}.

    Returns:
        The instantiated class object.
    """
    kwargs = init.get("init_args", {})
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)


class WavTokenizer(nn.Module):
    """
    The Vocos class represents a Fourier-based neural vocoder for audio synthesis.
    This class is primarily designed for inference, with support for loading from pretrained
    model checkpoints. It consists of three main components: a feature extractor,
    a backbone, and a head.
    """

    def __init__(
        self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head

    @classmethod
    def from_hparams(cls, config_path: str) -> "Vocos":
        """
        Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
        backbone = instantiate_class(args=(), init=config["backbone"])
        head = instantiate_class(args=(), init=config["head"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
        return model

    @classmethod
    def from_pretrained(self, repo_id: str) -> "Vocos":
        """
        Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
        """
        config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
        model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
        model = self.from_hparams(config_path)
        state_dict = torch.load(model_path, map_location="cpu")
        if isinstance(model.feature_extractor, EncodecFeatures):
            encodec_parameters = {
                "feature_extractor.encodec." + key: value
                for key, value in model.feature_extractor.encodec.state_dict().items()
            }
            state_dict.update(encodec_parameters)
        model.load_state_dict(state_dict)
        model.eval()
        return model


    @classmethod
    def from_hparams0802(cls, config_path: str) -> "Vocos":
        """
        Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
        backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
        head = instantiate_class(args=(), init=config['model']['init_args']["head"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
        return model


    @classmethod
    def from_pretrained0802(self, config_path, model_path):
        """
        Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
        """
        model = self.from_hparams0802(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
        state_dict = dict()
        for k, v in state_dict_raw.items():
            if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
                state_dict[k] = v
        # if isinstance(model.feature_extractor, EncodecFeatures):
        #     encodec_parameters = {
        #         "feature_extractor.encodec." + key: value
        #         for key, value in model.feature_extractor.encodec.state_dict().items()
        #     }
        #     state_dict.update(encodec_parameters)
        model.load_state_dict(state_dict)
        model.eval()
        return model


    @classmethod
    def from_pretrained0911(self, config_path, model_folder_path):
        """
        Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
        """
        model = self.from_hparams0802(config_path)

        models = os.listdir(model_folder_path)
        val_loss = []
        for item in models:
            if not item.startswith('vocos_'):
                continue
            val_loss.append(item[-11:-5])
        val_loss.sort()
        val_loss = val_loss[:3]  # average the 3 checkpoints with the best validation loss
        state_dict = dict()
        state_dicts = []
        for item in models:
            if not item.startswith('vocos_'):
                continue
| 137 |
+
ll = item[-11:-5]
|
| 138 |
+
if ll not in val_loss:
|
| 139 |
+
continue
|
| 140 |
+
model_path = model_folder_path + '/' + item
|
| 141 |
+
state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
|
| 142 |
+
state_dict_single = dict()
|
| 143 |
+
for k, v in state_dict_raw.items():
|
| 144 |
+
if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
|
| 145 |
+
state_dict_single[k] = v
|
| 146 |
+
state_dicts.append(state_dict_single)
|
| 147 |
+
for kk in state_dicts[0].keys():
|
| 148 |
+
vv = state_dicts[0][kk]
|
| 149 |
+
for i in range(1, len(state_dicts)):
|
| 150 |
+
ss = state_dicts[i]
|
| 151 |
+
vv += ss[kk]
|
| 152 |
+
vm = vv/len(state_dicts)
|
| 153 |
+
state_dict[kk] = vm
|
| 154 |
+
model.load_state_dict(state_dict)
|
| 155 |
+
model.eval()
|
| 156 |
+
return model
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@torch.inference_mode()
|
| 160 |
+
def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
|
| 161 |
+
"""
|
| 162 |
+
Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
|
| 163 |
+
which is then passed through the backbone and the head to reconstruct the audio output.
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
|
| 167 |
+
where B is the batch size and L is the waveform length.
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
|
| 172 |
+
"""
|
| 173 |
+
features, _, _ = self.feature_extractor(audio_input, **kwargs) # 0818
|
| 174 |
+
audio_output = self.decode(features, **kwargs)
|
| 175 |
+
return audio_output
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# 0818
|
| 179 |
+
@torch.inference_mode()
|
| 180 |
+
def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
|
| 181 |
+
features, discrete_codes, _ = self.feature_extractor(audio_input, **kwargs)
|
| 182 |
+
return features,discrete_codes
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# 0818
|
| 186 |
+
@torch.inference_mode()
|
| 187 |
+
def encode_infer(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
|
| 188 |
+
features, discrete_codes, _ = self.feature_extractor.infer(audio_input, **kwargs)
|
| 189 |
+
return features,discrete_codes
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@torch.inference_mode()
|
| 193 |
+
def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
|
| 194 |
+
"""
|
| 195 |
+
Method to decode audio waveform from already calculated features. The features input is passed through
|
| 196 |
+
the backbone and the head to reconstruct the audio output.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
|
| 200 |
+
C denotes the feature dimension, and L is the sequence length.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
|
| 204 |
+
"""
|
| 205 |
+
x = self.backbone(features_input, **kwargs)
|
| 206 |
+
audio_output = self.head(x)
|
| 207 |
+
return audio_output
|
| 208 |
+
|
| 209 |
+
@torch.inference_mode()
|
| 210 |
+
def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
|
| 211 |
+
"""
|
| 212 |
+
Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
|
| 213 |
+
codebook weights.
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
|
| 217 |
+
where K is the number of codebooks, B is the batch size and L is the sequence length.
|
| 218 |
+
|
| 219 |
+
Returns:
|
| 220 |
+
Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
|
| 221 |
+
and L is the sequence length.
|
| 222 |
+
"""
|
| 223 |
+
assert isinstance(
|
| 224 |
+
self.feature_extractor, EncodecFeatures
|
| 225 |
+
), "Feature extractor should be an instance of EncodecFeatures"
|
| 226 |
+
|
| 227 |
+
if codes.dim() == 2:
|
| 228 |
+
codes = codes.unsqueeze(1)
|
| 229 |
+
|
| 230 |
+
n_bins = self.feature_extractor.encodec.quantizer.bins
|
| 231 |
+
offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
|
| 232 |
+
embeddings_idxs = codes + offsets.view(-1, 1, 1)
|
| 233 |
+
|
| 234 |
+
tmp=torch.cat([vq.codebook for vq in self.feature_extractor.encodec.quantizer.vq.layers],dim=0)
|
| 235 |
+
# features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
|
| 236 |
+
features = torch.nn.functional.embedding(embeddings_idxs, tmp).sum(dim=0)
|
| 237 |
+
features = features.transpose(1, 2)
|
| 238 |
+
|
| 239 |
+
return features
|
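A minimal usage sketch for the class above (not part of the uploaded files): the config path follows the repo layout, while the checkpoint path and input file are placeholders you would substitute with your own. It mirrors the encode/decode calls that infer.py makes below.

import torch
import torchaudio

from decoder.pretrained import WavTokenizer

config_path = "./WavTokenizer/configs/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
model_path = "path/to/wavtokenizer_checkpoint.ckpt"        # placeholder checkpoint path

model = WavTokenizer.from_pretrained0802(config_path, model_path)

wav, sr = torchaudio.load("sample_24khz_mono.wav")         # placeholder 24 kHz mono input
bandwidth_id = torch.tensor([0])

features, codes = model.encode_infer(wav, bandwidth_id=bandwidth_id)   # continuous features + discrete codes
audio_out = model.decode(features, bandwidth_id=bandwidth_id)          # reconstructed waveform, shape (1, T)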
decoder/pretrained_model.py
ADDED
@@ -0,0 +1,192 @@
from typing import Tuple, Any, Union, Dict

import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn
from decoder.feature_extractors import FeatureExtractor, EncodecFeatures
from decoder.heads import FourierHead
from decoder.models import Backbone
from decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator


def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation.
        init: Dict of the form {"class_path":...,"init_args":...}.

    Returns:
        The instantiated class object.
    """
    kwargs = init.get("init_args", {})
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)


class WavTokenizer(nn.Module):
    """
    WavTokenizer is a Vocos-style, Fourier-based neural vocoder for audio synthesis.
    This variant additionally carries the discriminators, so that a full training checkpoint
    (generator plus discriminators) can be restored. It consists of a feature extractor,
    a backbone, and a head.
    """

    def __init__(
        self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
        multiperioddisc: MultiPeriodDiscriminator = None, multiresddisc: MultiResolutionDiscriminator = None,
    ):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head

        # The discriminators default to None so the generator-only loaders below remain usable;
        # they are passed explicitly when restoring a full training checkpoint (from_hparams0828).
        self.multiperioddisc = multiperioddisc
        self.multiresddisc = multiresddisc

    @classmethod
    def from_hparams0828(cls, config_path: str) -> "WavTokenizer":
        """
        Class method to create a new model instance, including fresh discriminators, from hyperparameters
        stored in a yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
        backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
        head = instantiate_class(args=(), init=config['model']['init_args']["head"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head,
                    multiperioddisc=MultiPeriodDiscriminator(num_embeddings=4),
                    multiresddisc=MultiResolutionDiscriminator(num_embeddings=4))
        return model

    @classmethod
    def from_pretrained0828(cls, config_path, model_path) -> "WavTokenizer":
        """
        Class method to restore a full local training checkpoint (generator and discriminators).
        """
        model = cls.from_hparams0828(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
        state_dict = dict()
        for k, v in state_dict_raw.items():
            if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.') \
                    or k.startswith('multiperioddisc.') or k.startswith('multiresddisc.'):
                state_dict[k] = v
        model.load_state_dict(state_dict)
        return model

    @classmethod
    def from_hparams0802(cls, config_path: str) -> "WavTokenizer":
        """
        Class method to create a new generator-only model instance from hyperparameters stored in a
        yaml configuration file.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
        backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
        head = instantiate_class(args=(), init=config['model']['init_args']["head"])
        model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
        return model

    @classmethod
    def from_pretrained0802(cls, config_path, model_path) -> "WavTokenizer":
        """
        Class method to create a new model instance from a local training checkpoint (generator weights only).
        """
        model = cls.from_hparams0802(config_path)
        state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
        state_dict = dict()
        for k, v in state_dict_raw.items():
            if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
                state_dict[k] = v
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @torch.inference_mode()
    def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
        which is then passed through the backbone and the head to reconstruct the audio output.

        Args:
            audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
                where B is the batch size and T is the waveform length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        features, _, _ = self.feature_extractor(audio_input, **kwargs)
        audio_output = self.decode(features, **kwargs)
        return audio_output

    @torch.inference_mode()
    def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        features, _, _ = self.feature_extractor(audio_input, **kwargs)
        return features

    @torch.inference_mode()
    def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to decode audio waveform from already calculated features. The features input is passed through
        the backbone and the head to reconstruct the audio output.

        Args:
            features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
                C denotes the feature dimension, and L is the sequence length.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        x = self.backbone(features_input, **kwargs)
        audio_output = self.head(x)
        return audio_output

    @torch.inference_mode()
    def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
        codebook weights.

        Args:
            codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
                where K is the number of codebooks, B is the batch size and L is the sequence length.

        Returns:
            Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
                and L is the sequence length.
        """
        assert isinstance(
            self.feature_extractor, EncodecFeatures
        ), "Feature extractor should be an instance of EncodecFeatures"

        if codes.dim() == 2:
            codes = codes.unsqueeze(1)

        n_bins = self.feature_extractor.encodec.quantizer.bins
        offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
        embeddings_idxs = codes + offsets.view(-1, 1, 1)
        features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
        features = features.transpose(1, 2)

        return features
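To illustrate the {"class_path": ..., "init_args": ...} convention that instantiate_class expects (the same layout the yaml configs use for the feature extractor, backbone, and head), here is a small self-contained sketch; torch.nn.Linear is only an example target, not one of the decoder classes.

from decoder.pretrained_model import instantiate_class

init = {"class_path": "torch.nn.Linear", "init_args": {"in_features": 4, "out_features": 2}}
layer = instantiate_class(args=(), init=init)
print(layer)  # Linear(in_features=4, out_features=2, bias=True)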
decoder/spectral_ops.py
ADDED
@@ -0,0 +1,192 @@
import numpy as np
import scipy.signal
import torch
from torch import nn, view_as_real, view_as_complex


class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fallback to pytorch native implementation
            return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        B, N, T = spec.shape

        # Inverse FFT
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]

        # Overlap and Add
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Window envelope
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
        window_envelope = torch.nn.functional.fold(
            window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # Normalize
        assert (window_envelope > 1e-11).all()
        y = y / window_envelope

        return y


class MDCT(nn.Module):
    """
    Modified Discrete Cosine Transform (MDCT) module.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
        post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
        # view_as_real: NCCL Backend does not support ComplexFloat data type
        # https://github.com/pytorch/pytorch/issues/71613
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.

        Args:
            audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
                and T is the length of the audio.

        Returns:
            Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
                and N is the number of frequency bins.
        """
        if self.padding == "center":
            audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
        elif self.padding == "same":
            # hop_length is 1/2 frame_len
            audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
        N = self.frame_len // 2
        x = x * self.window.expand(x.shape)
        X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
        res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
        return torch.real(res) * np.sqrt(2)


class IMDCT(nn.Module):
    """
    Inverse Modified Discrete Cosine Transform (IMDCT) module.

    Args:
        frame_len (int): Length of the MDCT frame.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, frame_len: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.frame_len = frame_len
        N = frame_len // 2
        n0 = (N + 1) / 2
        window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
        self.register_buffer("window", window)

        pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
        post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
        self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
        self.register_buffer("post_twiddle", view_as_real(post_twiddle))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.

        Args:
            X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
                L is the number of frames, and N is the number of frequency bins.

        Returns:
            Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
        """
        B, L, N = X.shape
        Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
        Y[..., :N] = X
        Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
        y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
        y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
        result = y * self.window.expand(y.shape)
        output_size = (1, (L + 1) * N)
        audio = torch.nn.functional.fold(
            result.transpose(1, 2),
            output_size=output_size,
            kernel_size=(1, self.frame_len),
            stride=(1, self.frame_len // 2),
        )[:, 0, 0, :]

        if self.padding == "center":
            pad = self.frame_len // 2
        elif self.padding == "same":
            pad = self.frame_len // 4
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        audio = audio[:, pad:-pad]
        return audio
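A quick sanity check for the ISTFT module above (illustrative only, not part of the uploaded files): in "center" mode it simply falls back to torch.istft, so a centered torch.stft followed by the module should give back the input almost exactly; the vocoder head itself uses the custom "same" branch. The sizes below are arbitrary, with the signal length chosen as a multiple of the hop size.

import torch

from decoder.spectral_ops import ISTFT

n_fft, hop, win = 1024, 256, 1024
x = torch.randn(2, 256 * 100)                  # length is a multiple of the hop size
window = torch.hann_window(win)

spec = torch.stft(x, n_fft, hop_length=hop, win_length=win,
                  window=window, center=True, return_complex=True)

istft = ISTFT(n_fft=n_fft, hop_length=hop, win_length=win, padding="center")
y = istft(spec)

print(y.shape, (x - y).abs().max())            # shapes match; reconstruction error should be tiny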
infer.py
ADDED
@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
import os

import torch
import torchaudio

from encoder.utils import convert_audio
from decoder.pretrained import WavTokenizer

device1 = torch.device('cuda:0')
device2 = torch.device('cpu')

# Text file listing the wav paths to resynthesize, and the output folder.
input_path = "./WavTokenizer/data/infer/lirbitts_testclean"
out_folder = './WavTokenizer/result/infer'
ll = "wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn_testclean_epoch34"

tmptmp = out_folder + "/" + ll

os.system("rm -r %s" % (tmptmp))
os.system("mkdir -p %s" % (tmptmp))

# Load the model trained on our own data.
config_path = "./WavTokenizer/configs/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
model_path = "./WavTokenizer/result/train/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn/lightning_logs/version_3/checkpoints/wavtokenizer_checkpoint_epoch=24_step=137150_val_loss=5.6731.ckpt"
wavtokenizer = WavTokenizer.from_pretrained0802(config_path, model_path)
wavtokenizer = wavtokenizer.to(device1)

with open(input_path, 'r') as fin:
    x = fin.readlines()

x = [i.strip() for i in x]

# Speed-up: encode every file on the GPU first, then decode on the CPU.
features_all = []

for i in range(len(x)):
    wav, sr = torchaudio.load(x[i])
    # wav = convert_audio(wav, sr, 24000, 1)  # resample to 24 kHz mono if the source differs
    bandwidth_id = torch.tensor([0])
    wav = wav.to(device1)
    print(i)

    features, discrete_code = wavtokenizer.encode_infer(wav, bandwidth_id=bandwidth_id)
    features_all.append(features.cpu())

wavtokenizer = wavtokenizer.to(device2)

for i in range(len(x)):
    bandwidth_id = torch.tensor([0])
    print(i)
    audio_out = wavtokenizer.decode(features_all[i], bandwidth_id=bandwidth_id)
    audio_path = out_folder + '/' + ll + '/' + x[i].split('/')[-1]
    torchaudio.save(audio_path, audio_out, sample_rate=24000, encoding='PCM_S', bits_per_sample=16)
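infer.py caches the continuous features between the encode and decode passes. If instead only the discrete codes are stored, they can be mapped back to features with codes_to_features before decoding. The sketch below is illustrative only: the paths are placeholders, and it assumes an EncodecFeatures-style extractor and that the codes come back in the (K, L) or (K, B, L) layout codes_to_features expects.

import torch

from decoder.pretrained import WavTokenizer

model = WavTokenizer.from_pretrained0802("path/to/config.yaml", "path/to/checkpoint.ckpt")  # placeholder paths

wav = torch.randn(1, 72000)                           # stand-in for a 3 s, 24 kHz waveform
bandwidth_id = torch.tensor([0])

_, codes = model.encode_infer(wav, bandwidth_id=bandwidth_id)
features = model.codes_to_features(codes)             # look the tokens up in the quantizer codebooks
audio_out = model.decode(features, bandwidth_id=bandwidth_id)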
train.py
ADDED
@@ -0,0 +1,15 @@
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

from pytorch_lightning.cli import LightningCLI, ArgsType


def cli_main(args: ArgsType = None):
    cli = LightningCLI(args=args)
    cli.trainer.fit(model=cli.model, datamodule=cli.datamodule)


if __name__ == "__main__":
    cli_main()