breadlicker45 committed
Commit f12fa11 · verified · 1 Parent(s): c34b897

Upload 44 files

Files changed (44)
  1. configs/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml +93 -0
  2. configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml +93 -0
  3. data/demo.txt +4 -0
  4. decoder/__init__.py +4 -0
  5. decoder/__pycache__/__init__.cpython-310.pyc +0 -0
  6. decoder/__pycache__/__init__.cpython-38.pyc +0 -0
  7. decoder/__pycache__/__init__.cpython-39.pyc +0 -0
  8. decoder/__pycache__/dataset.cpython-310.pyc +0 -0
  9. decoder/__pycache__/discriminator_dac.cpython-310.pyc +0 -0
  10. decoder/__pycache__/discriminators.cpython-310.pyc +0 -0
  11. decoder/__pycache__/experiment.cpython-310.pyc +0 -0
  12. decoder/__pycache__/feature_extractors.cpython-310.pyc +0 -0
  13. decoder/__pycache__/feature_extractors.cpython-38.pyc +0 -0
  14. decoder/__pycache__/feature_extractors.cpython-39.pyc +0 -0
  15. decoder/__pycache__/heads.cpython-310.pyc +0 -0
  16. decoder/__pycache__/heads.cpython-39.pyc +0 -0
  17. decoder/__pycache__/helpers.cpython-310.pyc +0 -0
  18. decoder/__pycache__/loss.cpython-310.pyc +0 -0
  19. decoder/__pycache__/models.cpython-310.pyc +0 -0
  20. decoder/__pycache__/models.cpython-39.pyc +0 -0
  21. decoder/__pycache__/modules.cpython-310.pyc +0 -0
  22. decoder/__pycache__/modules.cpython-38.pyc +0 -0
  23. decoder/__pycache__/modules.cpython-39.pyc +0 -0
  24. decoder/__pycache__/pretrained.cpython-310.pyc +0 -0
  25. decoder/__pycache__/pretrained.cpython-38.pyc +0 -0
  26. decoder/__pycache__/pretrained.cpython-39.pyc +0 -0
  27. decoder/__pycache__/pretrained_model.cpython-310.pyc +0 -0
  28. decoder/__pycache__/spectral_ops.cpython-310.pyc +0 -0
  29. decoder/__pycache__/spectral_ops.cpython-39.pyc +0 -0
  30. decoder/dataset.py +84 -0
  31. decoder/discriminator_dac.py +249 -0
  32. decoder/discriminators.py +202 -0
  33. decoder/experiment.py +474 -0
  34. decoder/feature_extractors.py +141 -0
  35. decoder/heads.py +157 -0
  36. decoder/helpers.py +71 -0
  37. decoder/loss.py +159 -0
  38. decoder/models.py +264 -0
  39. decoder/modules.py +213 -0
  40. decoder/pretrained.py +239 -0
  41. decoder/pretrained_model.py +192 -0
  42. decoder/spectral_ops.py +192 -0
  43. infer.py +73 -0
  44. train.py +15 -0
configs/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml ADDED
@@ -0,0 +1,93 @@
+ seed_everything: 3407
+
+ data:
+   class_path: decoder.dataset.VocosDataModule
+   init_args:
+     train_params:
+       filelist_path: ./WavTokenizer/data/train/libritts_train
+       sampling_rate: 24000
+       num_samples: 72000
+       batch_size: 40  # 20
+       num_workers: 8
+
+     val_params:
+       filelist_path: ./WavTokenizer/data/infer/librttts_val
+       sampling_rate: 24000
+       num_samples: 72000
+       batch_size: 5  # 10
+       num_workers: 8
+
+ model:
+   class_path: decoder.experiment.WavTokenizer
+   init_args:
+     sample_rate: 24000
+     initial_learning_rate: 2e-4
+     mel_loss_coeff: 45
+     mrd_loss_coeff: 1.0
+     num_warmup_steps: 0  # Optimizer warmup steps
+     pretrain_mel_steps: 0  # 0 means GAN objective from the first iteration
+
+     # automatic evaluation
+     evaluate_utmos: true
+     evaluate_pesq: true
+     evaluate_periodicty: true
+
+     resume: false
+     resume_config: ./WavTokenizer/configs/wavtokenizer_smalldata_frame40_3s_nq1_code16384_dim512_kmeans800_attn.yaml
+     resume_model: ./version_3/checkpoints/xxx.ckpt
+
+     feature_extractor:
+       class_path: decoder.feature_extractors.EncodecFeatures
+       init_args:
+         encodec_model: encodec_24khz
+         bandwidths: [6.6, 6.6, 6.6, 6.6]
+         train_codebooks: true
+         num_quantizers: 1
+         dowmsamples: [6, 5, 5, 4]
+         vq_bins: 4096
+         vq_kmeans: 200
+
+     backbone:
+       class_path: decoder.models.VocosBackbone
+       init_args:
+         input_channels: 512
+         dim: 768
+         intermediate_dim: 2304
+         num_layers: 12
+         adanorm_num_embeddings: 4
+
+     head:
+       class_path: decoder.heads.ISTFTHead
+       init_args:
+         dim: 768
+         n_fft: 2400
+         hop_length: 600
+         padding: same
+
+ trainer:
+   logger:
+     class_path: pytorch_lightning.loggers.TensorBoardLogger
+     init_args:
+       save_dir: ./WavTokenizer/result/train/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn/
+   callbacks:
+     - class_path: pytorch_lightning.callbacks.LearningRateMonitor
+     - class_path: pytorch_lightning.callbacks.ModelSummary
+       init_args:
+         max_depth: 2
+     - class_path: pytorch_lightning.callbacks.ModelCheckpoint
+       init_args:
+         monitor: val_loss
+         filename: wavtokenizer_checkpoint_{epoch}_{step}_{val_loss:.4f}
+         save_top_k: 10
+         save_last: true
+     - class_path: decoder.helpers.GradNormCallback
+
+   # Lightning counts max_steps across all optimizer steps (rather than the number of batches),
+   # so this corresponds to 10M steps for the generator and 10M for the discriminator
+   max_steps: 20000000
+   # You might want to limit val batches when evaluating all the metrics, as they are time-consuming
+   limit_val_batches: 200
+   accelerator: gpu
+   strategy: ddp
+   devices: [0,1,2,3,4,5,6,7]
+   log_every_n_steps: 1000
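For reference, a minimal sketch of how a LightningCLI-style config like the one above is typically consumed. It assumes train.py (part of this commit, +15 lines, not shown here) wraps pytorch_lightning's LightningCLI as Vocos does; treat it as a sketch under that assumption, not the committed code:

# Hypothetical launcher, assuming train.py wraps LightningCLI as in Vocos.
# The class_path/init_args layout in the YAML above is the LightningCLI convention.
from pytorch_lightning.cli import LightningCLI

if __name__ == "__main__":
    cli = LightningCLI(run=False)  # builds model, datamodule, and trainer from --config
    cli.trainer.fit(model=cli.model, datamodule=cli.datamodule)

Invoked as, e.g., python train.py --config configs/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml.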
configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml ADDED
@@ -0,0 +1,93 @@
+ seed_everything: 3407
+
+ data:
+   class_path: decoder.dataset.VocosDataModule
+   init_args:
+     train_params:
+       filelist_path: ./WavTokenizer/data/train/libritts_train
+       sampling_rate: 24000
+       num_samples: 72000
+       batch_size: 40  # 20
+       num_workers: 8
+
+     val_params:
+       filelist_path: ./WavTokenizer/data/infer/librttts_val
+       sampling_rate: 24000
+       num_samples: 72000
+       batch_size: 5  # 10
+       num_workers: 8
+
+ model:
+   class_path: decoder.experiment.WavTokenizer
+   init_args:
+     sample_rate: 24000
+     initial_learning_rate: 2e-4
+     mel_loss_coeff: 45
+     mrd_loss_coeff: 1.0
+     num_warmup_steps: 0  # Optimizer warmup steps
+     pretrain_mel_steps: 0  # 0 means GAN objective from the first iteration
+
+     # automatic evaluation
+     evaluate_utmos: true
+     evaluate_pesq: true
+     evaluate_periodicty: true
+
+     resume: false
+     resume_config: ./WavTokenizer/configs/wavtokenizer_smalldata_frame75_3s_nq1_code16384_dim512_kmeans800_attn.yaml
+     resume_model: ./WavTokenizer/result/train/wavtokenizer_smalldata_frame75_3s_nq1_code16384_dim512_kmeans800_attn/xxx.ckpt
+
+     feature_extractor:
+       class_path: decoder.feature_extractors.EncodecFeatures
+       init_args:
+         encodec_model: encodec_24khz
+         bandwidths: [6.6, 6.6, 6.6, 6.6]
+         train_codebooks: true
+         num_quantizers: 1
+         dowmsamples: [8, 5, 4, 2]
+         vq_bins: 4096
+         vq_kmeans: 200
+
+     backbone:
+       class_path: decoder.models.VocosBackbone
+       init_args:
+         input_channels: 512
+         dim: 768
+         intermediate_dim: 2304
+         num_layers: 12
+         adanorm_num_embeddings: 4
+
+     head:
+       class_path: decoder.heads.ISTFTHead
+       init_args:
+         dim: 768
+         n_fft: 1280
+         hop_length: 320
+         padding: same
+
+ trainer:
+   logger:
+     class_path: pytorch_lightning.loggers.TensorBoardLogger
+     init_args:
+       save_dir: ./WavTokenizer/result/train/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn/
+   callbacks:
+     - class_path: pytorch_lightning.callbacks.LearningRateMonitor
+     - class_path: pytorch_lightning.callbacks.ModelSummary
+       init_args:
+         max_depth: 2
+     - class_path: pytorch_lightning.callbacks.ModelCheckpoint
+       init_args:
+         monitor: val_loss
+         filename: wavtokenizer_checkpoint_{epoch}_{step}_{val_loss:.4f}
+         save_top_k: 10
+         save_last: true
+     - class_path: decoder.helpers.GradNormCallback
+
+   # Lightning counts max_steps across all optimizer steps (rather than the number of batches),
+   # so this corresponds to 10M steps for the generator and 10M for the discriminator
+   max_steps: 20000000
+   # You might want to limit val batches when evaluating all the metrics, as they are time-consuming
+   limit_val_batches: 100
+   accelerator: gpu
+   strategy: ddp
+   devices: [0,1,2,3,4,5,6,7]
+   log_every_n_steps: 1000
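The two configs differ only in the encoder downsampling ratios and the ISTFT head geometry. A quick consistency check on those numbers (values copied from the configs above):

# Token rate implied by the downsampling ratios vs. the ISTFT hop length.
from math import prod

sample_rate = 24000
for name, ratios, hop_length in [
    ("frame40", [6, 5, 5, 4], 600),  # n_fft 2400
    ("frame75", [8, 5, 4, 2], 320),  # n_fft 1280
]:
    assert prod(ratios) == hop_length          # one waveform hop per token
    print(name, sample_rate / prod(ratios))    # 40.0 and 75.0 tokens per second

So "frame40" and "frame75" in the filenames are the token rates in Hz.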
data/demo.txt ADDED
@@ -0,0 +1,4 @@
+ ./example1.wav
+ ./example2.wav
+ ./example3.mp3
+ ./example4.flac
decoder/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from decoder.pretrained import WavTokenizer
+
+
+ __version__ = "0.0.3"
decoder/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (224 Bytes)
decoder/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (234 Bytes)
decoder/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (233 Bytes)
decoder/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (3.34 kB)
decoder/__pycache__/discriminator_dac.cpython-310.pyc ADDED
Binary file (8.14 kB)
decoder/__pycache__/discriminators.cpython-310.pyc ADDED
Binary file (6.97 kB)
decoder/__pycache__/experiment.cpython-310.pyc ADDED
Binary file (14.6 kB)
decoder/__pycache__/feature_extractors.cpython-310.pyc ADDED
Binary file (4.66 kB)
decoder/__pycache__/feature_extractors.cpython-38.pyc ADDED
Binary file (4.12 kB)
decoder/__pycache__/feature_extractors.cpython-39.pyc ADDED
Binary file (4.43 kB)
decoder/__pycache__/heads.cpython-310.pyc ADDED
Binary file (6.69 kB)
decoder/__pycache__/heads.cpython-39.pyc ADDED
Binary file (6.63 kB)
decoder/__pycache__/helpers.cpython-310.pyc ADDED
Binary file (2.73 kB)
decoder/__pycache__/loss.cpython-310.pyc ADDED
Binary file (6.18 kB)
decoder/__pycache__/models.cpython-310.pyc ADDED
Binary file (8.09 kB)
decoder/__pycache__/models.cpython-39.pyc ADDED
Binary file (8.01 kB)
decoder/__pycache__/modules.cpython-310.pyc ADDED
Binary file (6.64 kB)
decoder/__pycache__/modules.cpython-38.pyc ADDED
Binary file (6.62 kB)
decoder/__pycache__/modules.cpython-39.pyc ADDED
Binary file (6.59 kB)
decoder/__pycache__/pretrained.cpython-310.pyc ADDED
Binary file (8.32 kB)
decoder/__pycache__/pretrained.cpython-38.pyc ADDED
Binary file (8.08 kB)
decoder/__pycache__/pretrained.cpython-39.pyc ADDED
Binary file (8.41 kB)
decoder/__pycache__/pretrained_model.cpython-310.pyc ADDED
Binary file (7.12 kB)
decoder/__pycache__/spectral_ops.cpython-310.pyc ADDED
Binary file (6.85 kB)
decoder/__pycache__/spectral_ops.cpython-39.pyc ADDED
Binary file (6.9 kB)
decoder/dataset.py ADDED
@@ -0,0 +1,84 @@
+ from dataclasses import dataclass
+
+ import numpy as np
+ import torch
+ import torchaudio
+ from pytorch_lightning import LightningDataModule
+ from torch.utils.data import Dataset, DataLoader
+
+ import soundfile
+ # import librosa
+
+ torch.set_num_threads(1)
+
+
+ @dataclass
+ class DataConfig:
+     filelist_path: str
+     sampling_rate: int
+     num_samples: int
+     batch_size: int
+     num_workers: int
+
+
+ class VocosDataModule(LightningDataModule):
+     def __init__(self, train_params: DataConfig, val_params: DataConfig):
+         super().__init__()
+         self.train_config = train_params
+         self.val_config = val_params
+
+     def _get_dataloder(self, cfg: DataConfig, train: bool):
+         dataset = VocosDataset(cfg, train=train)
+         dataloader = DataLoader(
+             dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True,
+         )
+         return dataloader
+
+     def train_dataloader(self) -> DataLoader:
+         return self._get_dataloder(self.train_config, train=True)
+
+     def val_dataloader(self) -> DataLoader:
+         return self._get_dataloder(self.val_config, train=False)
+
+
+ class VocosDataset(Dataset):
+     def __init__(self, cfg: DataConfig, train: bool):
+         with open(cfg.filelist_path) as f:
+             self.filelist = f.read().splitlines()
+         self.sampling_rate = cfg.sampling_rate
+         self.num_samples = cfg.num_samples
+         self.train = train
+
+     def __len__(self) -> int:
+         return len(self.filelist)
+
+     def __getitem__(self, index: int) -> torch.Tensor:
+         audio_path = self.filelist[index]
+         # y, sr = torchaudio.load(audio_path)
+         # print(audio_path, "111")
+         y1, sr = soundfile.read(audio_path)
+         # y1, sr = librosa.load(audio_path, sr=None)
+         y = torch.tensor(y1).float().unsqueeze(0)
+         # if y.size(0) > 1:
+         #     # mix to mono
+         #     y = y.mean(dim=0, keepdim=True)
+         if y.ndim > 2:
+             # mix to mono
+             # print("Something is wrong here, in the data-processing part")
+             y = y.mean(dim=-1, keepdim=False)
+         gain = np.random.uniform(-1, -6) if self.train else -3
+         y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]])
+         if sr != self.sampling_rate:
+             y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate)
+         if y.size(-1) < self.num_samples:
+             pad_length = self.num_samples - y.size(-1)
+             padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1))
+             y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1)
+         elif self.train:
+             start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1)
+             y = y[:, start : start + self.num_samples]
+         else:
+             # During validation, always take the first segment for determinism
+             y = y[:, : self.num_samples]
+
+         return y[0]
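A sketch of standalone usage of this datamodule, mirroring the values in the configs above (the filelist paths are the placeholders from the configs, not verified files):

from decoder.dataset import DataConfig, VocosDataModule

train_cfg = DataConfig(
    filelist_path="./WavTokenizer/data/train/libritts_train",  # one audio path per line, as in data/demo.txt
    sampling_rate=24000,
    num_samples=72000,  # 3 s at 24 kHz
    batch_size=40,
    num_workers=8,
)
val_cfg = DataConfig(
    filelist_path="./WavTokenizer/data/infer/librttts_val",
    sampling_rate=24000,
    num_samples=72000,
    batch_size=5,
    num_workers=8,
)
dm = VocosDataModule(train_params=train_cfg, val_params=val_cfg)
batch = next(iter(dm.train_dataloader()))  # FloatTensor of shape (40, 72000)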
decoder/discriminator_dac.py ADDED
@@ -0,0 +1,249 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ # from audiotools import AudioSignal
+ # from audiotools import ml
+ # from audiotools import STFTParams
+ from einops import rearrange
+ from torch.nn.utils import weight_norm
+
+ from collections import namedtuple
+
+ STFTParams = namedtuple(
+     "STFTParams",
+     ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
+ )
+
+ STFTParams.__new__.__defaults__ = (None, None, None, None, None)
+
+
+ def WNConv1d(*args, **kwargs):
+     act = kwargs.pop("act", True)
+     conv = weight_norm(nn.Conv1d(*args, **kwargs))
+     if not act:
+         return conv
+     return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+ def WNConv2d(*args, **kwargs):
+     act = kwargs.pop("act", True)
+     conv = weight_norm(nn.Conv2d(*args, **kwargs))
+     if not act:
+         return conv
+     return nn.Sequential(conv, nn.LeakyReLU(0.1))
+
+
+ class MPD(nn.Module):
+     def __init__(self, period):
+         super().__init__()
+         self.period = period
+         self.convs = nn.ModuleList(
+             [
+                 WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
+                 WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
+                 WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
+                 WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
+                 WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
+             ]
+         )
+         self.conv_post = WNConv2d(
+             1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
+         )
+
+     def pad_to_period(self, x):
+         t = x.shape[-1]
+         x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
+         return x
+
+     def forward(self, x):
+         fmap = []
+
+         x = self.pad_to_period(x)
+         x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
+
+         for layer in self.convs:
+             x = layer(x)
+             fmap.append(x)
+
+         x = self.conv_post(x)
+         fmap.append(x)
+
+         return fmap
+
+
+ class MSD(nn.Module):
+     def __init__(self, rate: int = 1, sample_rate: int = 24000):
+         super().__init__()
+         self.convs = nn.ModuleList(
+             [
+                 WNConv1d(1, 16, 15, 1, padding=7),
+                 WNConv1d(16, 64, 41, 4, groups=4, padding=20),
+                 WNConv1d(64, 256, 41, 4, groups=16, padding=20),
+                 WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
+                 WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
+                 WNConv1d(1024, 1024, 5, 1, padding=2),
+             ]
+         )
+         self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
+         self.sample_rate = sample_rate
+         self.rate = rate
+
+     def forward(self, x):
+         # x = AudioSignal(x, self.sample_rate)
+         # x.resample(self.sample_rate // self.rate)
+         # x = x.audio_data
+
+         fmap = []
+
+         for l in self.convs:
+             x = l(x)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+
+         return fmap
+
+
+ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
+
+
+ class MRD(nn.Module):
+     def __init__(
+         self,
+         window_length: int,
+         hop_factor: float = 0.25,
+         sample_rate: int = 24000,
+         bands: list = BANDS,
+     ):
+         """Complex multi-band spectrogram discriminator.
+         Parameters
+         ----------
+         window_length : int
+             Window length of STFT.
+         hop_factor : float, optional
+             Hop factor of the STFT, so that hop_length = window_length * hop_factor. Defaults to 0.25.
+         sample_rate : int, optional
+             Sampling rate of audio in Hz, by default 24000
+         bands : list, optional
+             Bands to run discriminator over.
+         """
+         super().__init__()
+
+         self.window_length = window_length
+         self.hop_factor = hop_factor
+         self.sample_rate = sample_rate
+         self.stft_params = STFTParams(
+             window_length=window_length,
+             hop_length=int(window_length * hop_factor),
+             match_stride=True,
+         )
+
+         n_fft = window_length // 2 + 1  # number of frequency bins
+         bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+         self.bands = bands
+         self.n_fft = window_length
+
+         ch = 32
+         convs = lambda: nn.ModuleList(
+             [
+                 WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
+                 WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+                 WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+                 WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+                 WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
+             ]
+         )
+         self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+         self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
+
+     def spectrogram(self, x):
+         # x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
+         # x = torch.view_as_real(x.stft())
+
+         # x.squeeze(0).stft(n_fft=1024, win_length=1024, return_complex=True).size()
+         # breakpoint()
+         if x.size(0) == 1:
+             # x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.window_length, return_complex=True).unsqueeze(0))
+             x = torch.view_as_real(x.squeeze(0).stft(n_fft=self.n_fft, return_complex=True).unsqueeze(0))
+         else:
+             # x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.window_length, return_complex=True).unsqueeze(1))
+             x = torch.view_as_real(x.squeeze(1).stft(n_fft=self.n_fft, return_complex=True).unsqueeze(1))
+         x = rearrange(x, "b 1 f t c -> (b 1) c t f")
+         # Split into bands
+         x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+         return x_bands
+
+     def forward(self, x):
+         x_bands = self.spectrogram(x)
+         fmap = []
+
+         x = []
+         for band, stack in zip(x_bands, self.band_convs):
+             for layer in stack:
+                 band = layer(band)
+                 fmap.append(band)
+             x.append(band)
+
+         x = torch.cat(x, dim=-1)
+         x = self.conv_post(x)
+         fmap.append(x)
+
+         return fmap
+
+
+ # class DACDiscriminator(ml.BaseModel):
+ class DACDiscriminator(nn.Module):
+     def __init__(
+         self,
+         rates: list = [],
+         periods: list = [2, 3, 5, 7, 11],
+         fft_sizes: list = [2048, 1024, 512],
+         sample_rate: int = 24000,
+         bands: list = BANDS,
+     ):
+         """Discriminator that combines multiple discriminators.
+
+         Parameters
+         ----------
+         rates : list, optional
+             sampling rates (in Hz) to run MSD at, by default [].
+             If empty, MSD is not used.
+         periods : list, optional
+             periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
+         fft_sizes : list, optional
+             Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
+         sample_rate : int, optional
+             Sampling rate of audio in Hz, by default 24000
+         bands : list, optional
+             Bands to run MRD at, by default `BANDS`
+         """
+         super().__init__()
+         discs = []
+         discs += [MPD(p) for p in periods]
+         discs += [MSD(r, sample_rate=sample_rate) for r in rates]
+         discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
+         self.discriminators = nn.ModuleList(discs)
+
+     def preprocess(self, y):
+         # Remove DC offset
+         y = y - y.mean(dim=-1, keepdims=True)
+         # Peak normalize the volume of input audio
+         y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+         return y
+
+     def forward(self, x):
+         x = self.preprocess(x)
+         fmaps = [d(x) for d in self.discriminators]
+         return fmaps
+
+
+ if __name__ == "__main__":
+     disc = DACDiscriminator()
+     x = torch.zeros(1, 1, 24000)
+     results = disc(x)
+     breakpoint()
+     for i, result in enumerate(results):
+         print(f"disc{i}")
+         for i, r in enumerate(result):
+             print(r.shape, r.mean(), r.min(), r.max())
+     print("00")
decoder/discriminators.py ADDED
@@ -0,0 +1,202 @@
+ from typing import Tuple, List
+
+ import torch
+ from torch import nn
+ from torch.nn import Conv2d
+ from torch.nn.utils import weight_norm
+
+
+ class MultiPeriodDiscriminator(nn.Module):
+     """
+     Multi-Period Discriminator module adapted from https://github.com/jik876/hifi-gan.
+     Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+     Args:
+         periods (tuple[int]): Tuple of periods for each discriminator.
+         num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+             Defaults to None.
+     """
+
+     def __init__(self, periods: Tuple[int] = (2, 3, 5, 7, 11), num_embeddings: int = None):
+         super().__init__()
+         self.discriminators = nn.ModuleList([DiscriminatorP(period=p, num_embeddings=num_embeddings) for p in periods])
+
+     def forward(
+         self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+     ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for d in self.discriminators:
+             y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+             y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class DiscriminatorP(nn.Module):
+     def __init__(
+         self,
+         period: int,
+         in_channels: int = 1,
+         kernel_size: int = 5,
+         stride: int = 3,
+         lrelu_slope: float = 0.1,
+         num_embeddings: int = None,
+     ):
+         super().__init__()
+         self.period = period
+         self.convs = nn.ModuleList(
+             [
+                 weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+                 weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+                 weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+                 weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
+                 weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
+             ]
+         )
+         if num_embeddings is not None:
+             self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=1024)
+             torch.nn.init.zeros_(self.emb.weight)
+
+         self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+         self.lrelu_slope = lrelu_slope
+
+     def forward(
+         self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
+     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+         x = x.unsqueeze(1)
+         fmap = []
+         # 1d to 2d
+         b, c, t = x.shape
+         if t % self.period != 0:  # pad first
+             n_pad = self.period - (t % self.period)
+             x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
+             t = t + n_pad
+         x = x.view(b, c, t // self.period, self.period)
+
+         for i, l in enumerate(self.convs):
+             x = l(x)
+             x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
+             if i > 0:
+                 fmap.append(x)
+         if cond_embedding_id is not None:
+             emb = self.emb(cond_embedding_id)
+             h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+         else:
+             h = 0
+         x = self.conv_post(x)
+         fmap.append(x)
+         x += h
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiResolutionDiscriminator(nn.Module):
+     def __init__(
+         self,
+         resolutions: Tuple[Tuple[int, int, int]] = ((1024, 256, 1024), (2048, 512, 2048), (512, 128, 512)),
+         num_embeddings: int = None,
+     ):
+         """
+         Multi-Resolution Discriminator module adapted from https://github.com/mindslab-ai/univnet.
+         Additionally, it allows incorporating conditional information with a learned embeddings table.
+
+         Args:
+             resolutions (tuple[tuple[int, int, int]]): Tuple of resolutions for each discriminator.
+                 Each resolution should be a tuple of (n_fft, hop_length, win_length).
+             num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
+                 Defaults to None.
+         """
+         super().__init__()
+         self.discriminators = nn.ModuleList(
+             [DiscriminatorR(resolution=r, num_embeddings=num_embeddings) for r in resolutions]
+         )
+
+     def forward(
+         self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
+     ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+
+         for d in self.discriminators:
+             y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
+             y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class DiscriminatorR(nn.Module):
+     def __init__(
+         self,
+         resolution: Tuple[int, int, int],
+         channels: int = 64,
+         in_channels: int = 1,
+         num_embeddings: int = None,
+         lrelu_slope: float = 0.1,
+     ):
+         super().__init__()
+         self.resolution = resolution
+         self.in_channels = in_channels
+         self.lrelu_slope = lrelu_slope
+         self.convs = nn.ModuleList(
+             [
+                 weight_norm(nn.Conv2d(in_channels, channels, kernel_size=(7, 5), stride=(2, 2), padding=(3, 2))),
+                 weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 1), padding=(2, 1))),
+                 weight_norm(nn.Conv2d(channels, channels, kernel_size=(5, 3), stride=(2, 2), padding=(2, 1))),
+                 weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 1), padding=1)),
+                 weight_norm(nn.Conv2d(channels, channels, kernel_size=3, stride=(2, 2), padding=1)),
+             ]
+         )
+         if num_embeddings is not None:
+             self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
+             torch.nn.init.zeros_(self.emb.weight)
+         self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), padding=(1, 1)))
+
+     def forward(
+         self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None
+     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+         fmap = []
+         x = self.spectrogram(x)
+         x = x.unsqueeze(1)
+         for l in self.convs:
+             x = l(x)
+             x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
+             fmap.append(x)
+         if cond_embedding_id is not None:
+             emb = self.emb(cond_embedding_id)
+             h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
+         else:
+             h = 0
+         x = self.conv_post(x)
+         fmap.append(x)
+         x += h
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+     def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
+         n_fft, hop_length, win_length = self.resolution
+         magnitude_spectrogram = torch.stft(
+             x,
+             n_fft=n_fft,
+             hop_length=hop_length,
+             win_length=win_length,
+             window=None,  # interestingly, a rectangular window kind of works here
+             center=True,
+             return_complex=True,
+         ).abs()
+
+         return magnitude_spectrogram
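The period discriminators fold the 1-D waveform into a 2-D grid so the (k, 1) kernels convolve across time while each column holds one phase of the period; a minimal sketch of that reshape, matching DiscriminatorP.forward above:

import torch

period, t = 5, 24003                 # deliberately not a multiple of the period
x = torch.randn(1, 1, t)             # (batch, channels, time)
if t % period != 0:                  # pad first, exactly as in forward()
    n_pad = period - (t % period)    # 2 samples here
    x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
    t = t + n_pad
x = x.view(1, 1, t // period, period)  # -> (1, 1, 4801, 5)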
decoder/experiment.py ADDED
@@ -0,0 +1,474 @@
+ import math
+
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch
+ import torchaudio
+ import transformers
+ import yaml
+
+ from decoder.discriminator_dac import DACDiscriminator
+
+ from decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
+ from decoder.feature_extractors import FeatureExtractor
+ from decoder.heads import FourierHead
+ from decoder.helpers import plot_spectrogram_to_numpy
+ from decoder.loss import DiscriminatorLoss, GeneratorLoss, FeatureMatchingLoss, MelSpecReconstructionLoss, DACGANLoss
+ from decoder.models import Backbone
+ from decoder.modules import safe_log
+ from decoder.pretrained_model import instantiate_class
+
+
+ class VocosExp(pl.LightningModule):
+     # noinspection PyUnusedLocal
+     def __init__(
+         self,
+         feature_extractor: FeatureExtractor,
+         backbone: Backbone,
+         head: FourierHead,
+         resume_config: str,
+         resume_model: str,
+         sample_rate: int = 24000,
+         initial_learning_rate: float = 2e-4,
+         num_warmup_steps: int = 0,
+         mel_loss_coeff: float = 45,
+         mrd_loss_coeff: float = 1.0,
+         pretrain_mel_steps: int = 0,
+         decay_mel_coeff: bool = False,
+         evaluate_utmos: bool = False,
+         evaluate_pesq: bool = False,
+         evaluate_periodicty: bool = False,
+         resume: bool = False,
+     ):
+         """
+         Args:
+             feature_extractor (FeatureExtractor): An instance of FeatureExtractor to extract features from audio signals.
+             backbone (Backbone): An instance of Backbone model.
+             head (FourierHead): An instance of Fourier head to generate spectral coefficients and reconstruct a waveform.
+             sample_rate (int): Sampling rate of the audio signals.
+             initial_learning_rate (float): Initial learning rate for the optimizer.
+             num_warmup_steps (int): Number of steps for the warmup phase of the learning rate scheduler. Default is 0.
+             mel_loss_coeff (float, optional): Coefficient for Mel-spectrogram loss in the loss function. Default is 45.
+             mrd_loss_coeff (float, optional): Coefficient for Multi-Resolution Discriminator loss. Default is 1.0.
+             pretrain_mel_steps (int, optional): Number of steps to pre-train the model without the GAN objective. Default is 0.
+             decay_mel_coeff (bool, optional): If True, the Mel-spectrogram loss coefficient is decayed during training. Default is False.
+             evaluate_utmos (bool, optional): If True, UTMOS scores are computed for each validation run.
+             evaluate_pesq (bool, optional): If True, PESQ scores are computed for each validation run.
+             evaluate_periodicty (bool, optional): If True, periodicity scores are computed for each validation run.
+         """
+         super().__init__()
+         self.save_hyperparameters(ignore=["feature_extractor", "backbone", "head"])
+
+         self.feature_extractor = feature_extractor
+         self.backbone = backbone
+         self.head = head
+
+         self.resume_config = resume_config
+         self.resume_model = resume_model
+         self.resume = resume
+
+         self.multiperioddisc = MultiPeriodDiscriminator()
+         self.multiresddisc = MultiResolutionDiscriminator()
+
+
+         self.dac = DACDiscriminator()
+
+         self.dacdiscriminator = DACGANLoss(self.dac)
+
+         self.disc_loss = DiscriminatorLoss()
+         self.gen_loss = GeneratorLoss()
+         self.feat_matching_loss = FeatureMatchingLoss()
+         self.melspec_loss = MelSpecReconstructionLoss(sample_rate=sample_rate)
+
+         self.train_discriminator = False
+         self.base_mel_coeff = self.mel_loss_coeff = mel_loss_coeff
+
+     def configure_optimizers(self):
+         disc_params = [
+             {"params": self.multiperioddisc.parameters()},
+             {"params": self.multiresddisc.parameters()},
+             {"params": self.dac.parameters()},
+         ]
+         gen_params = [
+             {"params": self.feature_extractor.parameters()},
+             {"params": self.backbone.parameters()},
+             {"params": self.head.parameters()},
+         ]
+
+         opt_disc = torch.optim.AdamW(disc_params, lr=self.hparams.initial_learning_rate)
+         opt_gen = torch.optim.AdamW(gen_params, lr=self.hparams.initial_learning_rate)
+
+         max_steps = self.trainer.max_steps // 2  # Max steps per optimizer
+         scheduler_disc = transformers.get_cosine_schedule_with_warmup(
+             opt_disc, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
+         )
+         scheduler_gen = transformers.get_cosine_schedule_with_warmup(
+             opt_gen, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=max_steps,
+         )
+
+         return (
+             [opt_disc, opt_gen],
+             [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}],
+         )
+
+     def forward(self, audio_input, **kwargs):
+         features, _, commit_loss = self.feature_extractor(audio_input, **kwargs)
+         # print('1111', self.feature_extractor.state_dict()['encodec.decoder.model.3.convtr.convtr.weight_g'])
+         x = self.backbone(features, **kwargs)
+         audio_output = self.head(x)
+         return audio_output, commit_loss
+
+     def training_step(self, batch, batch_idx, optimizer_idx, **kwargs):
+         audio_input = batch
+
+         # train discriminator
+         if optimizer_idx == 0 and self.train_discriminator:
+             with torch.no_grad():
+                 audio_hat, _ = self(audio_input, **kwargs)
+
+
+             loss_dac = self.dacdiscriminator.discriminator_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
+
+             real_score_mp, gen_score_mp, _, _ = self.multiperioddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
+             real_score_mrd, gen_score_mrd, _, _ = self.multiresddisc(y=audio_input, y_hat=audio_hat, **kwargs,)
+             loss_mp, loss_mp_real, _ = self.disc_loss(
+                 disc_real_outputs=real_score_mp, disc_generated_outputs=gen_score_mp
+             )
+             loss_mrd, loss_mrd_real, _ = self.disc_loss(
+                 disc_real_outputs=real_score_mrd, disc_generated_outputs=gen_score_mrd
+             )
+             loss_mp /= len(loss_mp_real)
+             loss_mrd /= len(loss_mrd_real)
+             loss = loss_mp + self.hparams.mrd_loss_coeff * loss_mrd + loss_dac
+
+             self.log("discriminator/total", loss, prog_bar=True)
+             self.log("discriminator/multi_period_loss", loss_mp)
+             self.log("discriminator/multi_res_loss", loss_mrd)
+             self.log("discriminator/dac", loss_dac)
+             return loss
+
+         # train generator
+         if optimizer_idx == 1:
+             audio_hat, commit_loss = self(audio_input, **kwargs)
+             if self.train_discriminator:
+
+                 loss_dac_1, loss_dac_2 = self.dacdiscriminator.generator_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
+                 _, gen_score_mp, fmap_rs_mp, fmap_gs_mp = self.multiperioddisc(
+                     y=audio_input, y_hat=audio_hat, **kwargs,
+                 )
+                 _, gen_score_mrd, fmap_rs_mrd, fmap_gs_mrd = self.multiresddisc(
+                     y=audio_input, y_hat=audio_hat, **kwargs,
+                 )
+                 loss_gen_mp, list_loss_gen_mp = self.gen_loss(disc_outputs=gen_score_mp)
+                 loss_gen_mrd, list_loss_gen_mrd = self.gen_loss(disc_outputs=gen_score_mrd)
+                 loss_gen_mp = loss_gen_mp / len(list_loss_gen_mp)
+                 loss_gen_mrd = loss_gen_mrd / len(list_loss_gen_mrd)
+                 loss_fm_mp = self.feat_matching_loss(fmap_r=fmap_rs_mp, fmap_g=fmap_gs_mp) / len(fmap_rs_mp)
+                 loss_fm_mrd = self.feat_matching_loss(fmap_r=fmap_rs_mrd, fmap_g=fmap_gs_mrd) / len(fmap_rs_mrd)
+
+                 self.log("generator/multi_period_loss", loss_gen_mp)
+                 self.log("generator/multi_res_loss", loss_gen_mrd)
+                 self.log("generator/feature_matching_mp", loss_fm_mp)
+                 self.log("generator/feature_matching_mrd", loss_fm_mrd)
+                 self.log("generator/loss_dac_1", loss_dac_1)
+                 self.log("generator/loss_dac_2", loss_dac_2)
+             else:
+                 # zero the DAC terms too, so the sum below is defined in the pretraining phase
+                 loss_gen_mp = loss_gen_mrd = loss_fm_mp = loss_fm_mrd = loss_dac_1 = loss_dac_2 = 0
+
+             mel_loss = self.melspec_loss(audio_hat, audio_input)
+             loss = (
+                 loss_gen_mp
+                 + self.hparams.mrd_loss_coeff * loss_gen_mrd
+                 + loss_fm_mp
+                 + self.hparams.mrd_loss_coeff * loss_fm_mrd
+                 + self.mel_loss_coeff * mel_loss
+                 + 1000 * commit_loss
+                 + loss_dac_1
+                 + loss_dac_2
+             )
+
+             self.log("generator/total_loss", loss, prog_bar=True)
+             self.log("mel_loss_coeff", self.mel_loss_coeff)
+             self.log("generator/mel_loss", mel_loss)
+             self.log("commit_loss", commit_loss)
+
+             if self.global_step % 1000 == 0 and self.global_rank == 0:
+                 self.logger.experiment.add_audio(
+                     "train/audio_in", audio_input[0].data.cpu(), self.global_step, self.hparams.sample_rate
+                 )
+                 self.logger.experiment.add_audio(
+                     "train/audio_pred", audio_hat[0].data.cpu(), self.global_step, self.hparams.sample_rate
+                 )
+                 with torch.no_grad():
+                     mel = safe_log(self.melspec_loss.mel_spec(audio_input[0]))
+                     mel_hat = safe_log(self.melspec_loss.mel_spec(audio_hat[0]))
+                 self.logger.experiment.add_image(
+                     "train/mel_target",
+                     plot_spectrogram_to_numpy(mel.data.cpu().numpy()),
+                     self.global_step,
+                     dataformats="HWC",
+                 )
+                 self.logger.experiment.add_image(
+                     "train/mel_pred",
+                     plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
+                     self.global_step,
+                     dataformats="HWC",
+                 )
+
+             return loss
+
+     def on_validation_epoch_start(self):
+         if self.hparams.evaluate_utmos:
+             from metrics.UTMOS import UTMOSScore
+
+             if not hasattr(self, "utmos_model"):
+                 self.utmos_model = UTMOSScore(device=self.device)
+
+     def validation_step(self, batch, batch_idx, **kwargs):
+         audio_input = batch
+         audio_hat, commit_loss = self(audio_input, **kwargs)
+
+         audio_16_khz = torchaudio.functional.resample(audio_input, orig_freq=self.hparams.sample_rate, new_freq=16000)
+         audio_hat_16khz = torchaudio.functional.resample(audio_hat, orig_freq=self.hparams.sample_rate, new_freq=16000)
+
+         if self.hparams.evaluate_periodicty:
+             from metrics.periodicity import calculate_periodicity_metrics
+
+             periodicity_loss, pitch_loss, f1_score = calculate_periodicity_metrics(audio_16_khz, audio_hat_16khz)
+         else:
+             periodicity_loss = pitch_loss = f1_score = 0
+
+         if self.hparams.evaluate_utmos:
+             utmos_score = self.utmos_model.score(audio_hat_16khz.unsqueeze(1)).mean()
+         else:
+             utmos_score = torch.zeros(1, device=self.device)
+
+         if self.hparams.evaluate_pesq:
+             from pesq import pesq
+
+             pesq_score = 0
+             for ref, deg in zip(audio_16_khz.cpu().numpy(), audio_hat_16khz.cpu().numpy()):
+                 pesq_score += pesq(16000, ref, deg, "wb", on_error=1)
+             pesq_score /= len(audio_16_khz)
+             pesq_score = torch.tensor(pesq_score)
+         else:
+             pesq_score = torch.zeros(1, device=self.device)
+
+         mel_loss = self.melspec_loss(audio_hat.unsqueeze(1), audio_input.unsqueeze(1))
+         total_loss = mel_loss + (5 - utmos_score) + (5 - pesq_score) + 1000 * commit_loss
+
+         return {
+             "val_loss": total_loss,
+             "mel_loss": mel_loss,
+             "utmos_score": utmos_score,
+             "pesq_score": pesq_score,
+             "periodicity_loss": periodicity_loss,
+             "pitch_loss": pitch_loss,
+             "f1_score": f1_score,
+             "audio_input": audio_input[0],
+             "audio_pred": audio_hat[0],
+         }
+
+     def validation_epoch_end(self, outputs):
+         if self.global_rank == 0:
+             *_, audio_in, audio_pred = outputs[0].values()
+             self.logger.experiment.add_audio(
+                 "val_in", audio_in.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
+             )
+             self.logger.experiment.add_audio(
+                 "val_pred", audio_pred.data.cpu().numpy(), self.global_step, self.hparams.sample_rate
+             )
+             mel_target = safe_log(self.melspec_loss.mel_spec(audio_in))
+             mel_hat = safe_log(self.melspec_loss.mel_spec(audio_pred))
+             self.logger.experiment.add_image(
+                 "val_mel_target",
+                 plot_spectrogram_to_numpy(mel_target.data.cpu().numpy()),
+                 self.global_step,
+                 dataformats="HWC",
+             )
+             self.logger.experiment.add_image(
+                 "val_mel_hat",
+                 plot_spectrogram_to_numpy(mel_hat.data.cpu().numpy()),
+                 self.global_step,
+                 dataformats="HWC",
+             )
+         avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
+         mel_loss = torch.stack([x["mel_loss"] for x in outputs]).mean()
+         utmos_score = torch.stack([x["utmos_score"] for x in outputs]).mean()
+         pesq_score = torch.stack([x["pesq_score"] for x in outputs]).mean()
+         periodicity_loss = np.array([x["periodicity_loss"] for x in outputs]).mean()
+         pitch_loss = np.array([x["pitch_loss"] for x in outputs]).mean()
+         f1_score = np.array([x["f1_score"] for x in outputs]).mean()
+
+         self.log("val_loss", avg_loss, sync_dist=True)
+         self.log("val/mel_loss", mel_loss, sync_dist=True)
+         self.log("val/utmos_score", utmos_score, sync_dist=True)
+         self.log("val/pesq_score", pesq_score, sync_dist=True)
+         self.log("val/periodicity_loss", periodicity_loss, sync_dist=True)
+         self.log("val/pitch_loss", pitch_loss, sync_dist=True)
+         self.log("val/f1_score", f1_score, sync_dist=True)
+
+     @property
+     def global_step(self):
+         """
+         Override global_step so that it returns the total number of batches processed
+         """
+         return self.trainer.fit_loop.epoch_loop.total_batch_idx
+
+     def on_train_batch_start(self, *args):
+         if self.global_step >= self.hparams.pretrain_mel_steps:
+             self.train_discriminator = True
+         else:
+             self.train_discriminator = False
+
+     def on_train_batch_end(self, *args):
+         def mel_loss_coeff_decay(current_step, num_cycles=0.5):
+             max_steps = self.trainer.max_steps // 2
+             if current_step < self.hparams.num_warmup_steps:
+                 return 1.0
+             progress = float(current_step - self.hparams.num_warmup_steps) / float(
+                 max(1, max_steps - self.hparams.num_warmup_steps)
+             )
+             return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+         if self.hparams.decay_mel_coeff:
+             self.mel_loss_coeff = self.base_mel_coeff * mel_loss_coeff_decay(self.global_step + 1)
+
+
+ class WavTokenizer(VocosExp):
+     """
+     WavTokenizer is a subclass of VocosExp that overrides the parent experiment to function as a conditional GAN.
+     It manages an additional `bandwidth_id` attribute, which denotes a learnable embedding corresponding to
+     a specific bandwidth value of EnCodec. During training, a random bandwidth_id is generated for each step,
+     while during validation, a fixed bandwidth_id is used.
+     """
+
+     def __init__(
+         self,
+         feature_extractor: FeatureExtractor,
+         backbone: Backbone,
+         head: FourierHead,
+         resume_config: str,
+         resume_model: str,
+         sample_rate: int = 24000,
+         initial_learning_rate: float = 2e-4,
+         num_warmup_steps: int = 0,
+         mel_loss_coeff: float = 45,
+         mrd_loss_coeff: float = 1.0,
+         pretrain_mel_steps: int = 0,
+         decay_mel_coeff: bool = False,
+         evaluate_utmos: bool = False,
+         evaluate_pesq: bool = False,
+         evaluate_periodicty: bool = False,
+         resume: bool = False,
+     ):
+         super().__init__(
+             feature_extractor,
+             backbone,
+             head,
+             resume_config,
+             resume_model,
+             sample_rate,
+             initial_learning_rate,
+             num_warmup_steps,
+             mel_loss_coeff,
+             mrd_loss_coeff,
+             pretrain_mel_steps,
+             decay_mel_coeff,
+             evaluate_utmos,
+             evaluate_pesq,
+             evaluate_periodicty,
+             resume
+         )
+         # Override with conditional discriminators
+         # VocosExp.__init__(self, feature_extractor, backbone, head, resume_config, resume_model)
+         # if self.resume:
+         #     VocosExp.load_from_checkpoint(self.resume_model)
+         self.multiperioddisc = MultiPeriodDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
+         self.multiresddisc = MultiResolutionDiscriminator(num_embeddings=len(self.feature_extractor.bandwidths))
+         self.dac = DACDiscriminator()
+         if self.resume:
+             print('Loading pretrained model:', self.resume_model)
+             # with open(self.resume_config, "r") as f:
+             #     config = yaml.safe_load(f)
+             # feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
+             # backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
+             # head = instantiate_class(args=(), init=config['model']['init_args']["head"])
+
+             # Load only part of the quantizer weights
+             state_dict_raw = torch.load(self.resume_model, map_location=self.device)['state_dict']
+             state_dict_fa_qa = dict()
+             state_dict_fa_en = dict()
+             state_dict_fa_de = dict()
+             state_dict_bb = dict()
+             state_dict_hd = dict()
+             state_dict_mp = dict()
+             state_dict_mr = dict()
+             state_dict_dac = dict()
+             for k, v in state_dict_raw.items():
+                 # breakpoint()
+                 if k.startswith('feature_extractor.encodec.quantizer'):
+                     # breakpoint()
+                     # print("*****", k)
+                     ss = k[46:48]
+                     if ss[-1] == '.':
+                         num = int(ss[0])
+                         # print("num,k", num, k[36:])
+                         if num <= 7:
+                             state_dict_fa_qa[k[36:]] = v
+                 if k.startswith('feature_extractor.encodec.encoder'):
+                     state_dict_fa_en[k[34:]] = v
+                 if k.startswith('feature_extractor.encodec.decoder'):
+                     state_dict_fa_de[k[34:]] = v
+                 if k.startswith('backbone.'):
+                     state_dict_bb[k[9:]] = v
+                 if k.startswith('head.'):
+                     state_dict_hd[k[5:]] = v
+                 if k.startswith('multiperioddisc.'):
+                     state_dict_mp[k[16:]] = v
+                 if k.startswith('multiresddisc.'):
+                     state_dict_mr[k[14:]] = v
+                 if k.startswith('dac.'):
+                     state_dict_dac[k[4:]] = v
+             # breakpoint()
+             # feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
+             feature_extractor.encodec.encoder.load_state_dict(state_dict_fa_en, strict=True)
+             feature_extractor.encodec.decoder.load_state_dict(state_dict_fa_de, strict=True)
+             feature_extractor.encodec.quantizer.load_state_dict(state_dict_fa_qa, strict=True)
+             backbone.load_state_dict(state_dict_bb, strict=True)
+             head.load_state_dict(state_dict_hd, strict=True)
+             self.feature_extractor = feature_extractor.to(self.device)
+             self.backbone = backbone.to(self.device)
+             self.head = head.to(self.device)
+             self.multiperioddisc.load_state_dict(state_dict_mp, strict=True)
+             self.multiresddisc.load_state_dict(state_dict_mr, strict=True)
+             self.dac.load_state_dict(state_dict_dac, strict=True)
+
+     def training_step(self, *args):
+         # print('-------------------train--------------------')
+         # if self.global_rank == 0 and self.resume:
+         #     config_path = self.resume_config
+         #     model_path = self.resume_model
+         #     self.pretrained_load(config_path, model_path)
+         #     print('Loading pretrained model:', model_path)
+         bandwidth_id = torch.randint(low=0, high=len(self.feature_extractor.bandwidths), size=(1,), device=self.device,)
+         output = super().training_step(*args, bandwidth_id=bandwidth_id)
+         return output
+
+     def validation_step(self, *args):
+         # print('-------------------valid--------------------')
+         bandwidth_id = torch.tensor([0], device=self.device)
+         output = super().validation_step(*args, bandwidth_id=bandwidth_id)
+         return output
+
+     def validation_epoch_end(self, outputs):
+         if self.global_rank == 0:
+             *_, audio_in, _ = outputs[0].values()
+             # Resynthesis with encodec for reference
+             self.feature_extractor.encodec.set_target_bandwidth(self.feature_extractor.bandwidths[0])
+             encodec_audio = self.feature_extractor.encodec(audio_in[None, None, :])
+             self.logger.experiment.add_audio(
+                 "encodec", encodec_audio[0, 0].data.cpu().numpy(), self.global_step, self.hparams.sample_rate,
+             )
+
+         super().validation_epoch_end(outputs)
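The resume branch above strips module prefixes with hard-coded index slices (k[36:], k[34:], k[9:], ...). As a refactoring sketch only (not the committed code), the same grouping can be written with explicit prefixes; the quantizer branch would still need its extra filter that keeps only VQ layers 0-7:

# Sketch: group checkpoint tensors by module prefix without magic offsets.
def split_by_prefix(state_dict, prefixes):
    """Return {prefix: sub_state_dict} with each prefix stripped from its keys."""
    groups = {p: {} for p in prefixes}
    for key, value in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                groups[prefix][key[len(prefix):]] = value
    return groups

# groups = split_by_prefix(state_dict_raw, (
#     "feature_extractor.encodec.encoder.", "feature_extractor.encodec.decoder.",
#     "backbone.", "head.", "multiperioddisc.", "multiresddisc.", "dac.",
# ))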
decoder/feature_extractors.py ADDED
@@ -0,0 +1,141 @@
+ from typing import List
+
+ import torch
+ import torchaudio
+ from torch import nn
+ import math
+ from decoder.modules import safe_log
+ from encoder.modules import SEANetEncoder, SEANetDecoder
+ from encoder import EncodecModel
+ from encoder.quantization import ResidualVectorQuantizer
+
+
+ class FeatureExtractor(nn.Module):
+     """Base class for feature extractors."""
+
+     def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
+         """
+         Extract features from the given audio.
+
+         Args:
+             audio (Tensor): Input audio waveform.
+
+         Returns:
+             Tensor: Extracted features of shape (B, C, L), where B is the batch size,
+                 C denotes output features, and L is the sequence length.
+         """
+         raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+ class MelSpectrogramFeatures(FeatureExtractor):
+     def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
+         super().__init__()
+         if padding not in ["center", "same"]:
+             raise ValueError("Padding must be 'center' or 'same'.")
+         self.padding = padding
+         self.mel_spec = torchaudio.transforms.MelSpectrogram(
+             sample_rate=sample_rate,
+             n_fft=n_fft,
+             hop_length=hop_length,
+             n_mels=n_mels,
+             center=padding == "center",
+             power=1,
+         )
+
+     def forward(self, audio, **kwargs):
+         if self.padding == "same":
+             pad = self.mel_spec.win_length - self.mel_spec.hop_length
+             audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
+         mel = self.mel_spec(audio)
+         features = safe_log(mel)
+         return features
+
+
+ class EncodecFeatures(FeatureExtractor):
+     def __init__(
+         self,
+         encodec_model: str = "encodec_24khz",
+         bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
+         train_codebooks: bool = False,
+         num_quantizers: int = 1,
+         dowmsamples: List[int] = [6, 5, 5, 4],
+         vq_bins: int = 16384,
+         vq_kmeans: int = 800,
+     ):
+         super().__init__()
+
+         # breakpoint()
+         self.frame_rate = 25  # not used
+         # n_q = int(bandwidths[-1]*1000/(math.log2(2048) * self.frame_rate))
+         n_q = num_quantizers  # important
+         encoder = SEANetEncoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
+                                 dimension=512, channels=1, n_filters=32, ratios=dowmsamples, activation='ELU',
+                                 kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
+                                 true_skip=False, compress=2)
+         decoder = SEANetDecoder(causal=False, n_residual_layers=1, norm='weight_norm', pad_mode='reflect', lstm=2,
+                                 dimension=512, channels=1, n_filters=32, ratios=[8, 5, 4, 2], activation='ELU',
+                                 kernel_size=7, residual_kernel_size=3, last_kernel_size=7, dilation_base=2,
+                                 true_skip=False, compress=2)
+         quantizer = ResidualVectorQuantizer(dimension=512, n_q=n_q, bins=vq_bins, kmeans_iters=vq_kmeans,
+                                             decay=0.99, kmeans_init=True)
+
+         # breakpoint()
+         if encodec_model == "encodec_24khz":
+             self.encodec = EncodecModel(encoder=encoder, decoder=decoder, quantizer=quantizer,
+                                         target_bandwidths=bandwidths, sample_rate=24000, channels=1)
+         else:
+             raise ValueError(
+                 f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz'."
+             )
+         for param in self.encodec.parameters():
+             param.requires_grad = True
+         # self.num_q = n_q
+         # codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
+         # self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
+         self.bandwidths = bandwidths
+
+     # @torch.no_grad()
+     # def get_encodec_codes(self, audio):
+     #     audio = audio.unsqueeze(1)
+     #     emb = self.encodec.encoder(audio)
+     #     codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
+     #     return codes
+
+     def forward(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
+         if self.training:
+             self.encodec.train()
+
+         audio = audio.unsqueeze(1)  # audio(16,24000)
+
+         # breakpoint()
+
+         emb = self.encodec.encoder(audio)
+         q_res = self.encodec.quantizer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
+         quantized = q_res.quantized
+         codes = q_res.codes
+         commit_loss = q_res.penalty  # codes(8,16,75), features(16,128,75)
+
+         return quantized, codes, commit_loss
+
+     # codes = self.get_encodec_codes(audio)
+     # # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights`
+     # # with offsets given by the number of bins, and finally summed in a vectorized operation.
+     # offsets = torch.arange(
+     #     0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
+     # )
+     # embeddings_idxs = codes + offsets.view(-1, 1, 1)
+     # features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
+     # return features.transpose(1, 2)
+
+     def infer(self, audio: torch.Tensor, bandwidth_id: torch.Tensor):
+         if self.training:
+             self.encodec.train()
+
+         audio = audio.unsqueeze(1)  # audio(16,24000)
+         emb = self.encodec.encoder(audio)
+         q_res = self.encodec.quantizer.infer(emb, self.frame_rate, bandwidth=self.bandwidths[bandwidth_id])
+         quantized = q_res.quantized
+         codes = q_res.codes
+         commit_loss = q_res.penalty  # codes(8,16,75), features(16,128,75)
+
+         return quantized, codes, commit_loss
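A shape sketch for EncodecFeatures.forward, assuming the frame75 config above (dowmsamples [8, 5, 4, 2], a single quantizer, 4096 bins); the numbers mirror the inline shape comments in the code:

# Expected shapes for one 3-second batch at 24 kHz (an assumption-based sketch).
B, T = 16, 72000
downsample = 8 * 5 * 4 * 2          # = 320, i.e. 75 tokens per second
frames = T // downsample            # = 225
# quantized:   (B, 512, frames)  -- continuous features fed to the backbone
# codes:       (num_quantizers, B, frames) = (1, 16, 225), values in [0, 4096)
# commit_loss: scalar VQ commitment penalty, added to the generator loss with weight 1000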
decoder/heads.py ADDED
@@ -0,0 +1,157 @@
+ from typing import Optional
+
+ import torch
+ from torch import nn
+ from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
+
+ from decoder.spectral_ops import IMDCT, ISTFT
+ from decoder.modules import symexp
+
+
+ class FourierHead(nn.Module):
+     """Base class for inverse Fourier modules."""
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                 L is the sequence length, and H denotes the model dimension.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+ class ISTFTHead(FourierHead):
+     """
+     ISTFT Head module for predicting STFT complex coefficients.
+
+     Args:
+         dim (int): Hidden dimension of the model.
+         n_fft (int): Size of Fourier transform.
+         hop_length (int): The distance between neighboring sliding window frames, which should align with
+             the resolution of the input features.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+     """
+
+     def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
+         super().__init__()
+         out_dim = n_fft + 2
+         self.out = torch.nn.Linear(dim, out_dim)
+         self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the ISTFTHead module.
+
+         Args:
+             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                 L is the sequence length, and H denotes the model dimension.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         x = self.out(x).transpose(1, 2)
+         mag, p = x.chunk(2, dim=1)
+         mag = torch.exp(mag)
+         mag = torch.clip(mag, max=1e2)  # safeguard to prevent excessively large magnitudes
+         # phase wrapping happens here; these two lines produce the real and imaginary parts
+         x = torch.cos(p)
+         y = torch.sin(p)
+         # recalculating the phase via atan2 would not produce anything new and only costs time:
+         # phase = torch.atan2(y, x)
+         # S = mag * torch.exp(phase * 1j)
+         # it is better to produce the complex value directly
+         S = mag * (x + 1j * y)
+         audio = self.istft(S)
+         return audio
+
+
+ class IMDCTSymExpHead(FourierHead):
+     """
+     IMDCT Head module for predicting MDCT coefficients with a symmetric exponential function.
+
+     Args:
+         dim (int): Hidden dimension of the model.
+         mdct_frame_len (int): Length of the MDCT frame.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+         sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
+             based on perceptual scaling. Defaults to None.
+         clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+     """
+
+     def __init__(
+         self, dim: int, mdct_frame_len: int, padding: str = "same", sample_rate: Optional[int] = None,
+         clip_audio: bool = False,
+     ):
+         super().__init__()
+         out_dim = mdct_frame_len // 2
+         self.out = nn.Linear(dim, out_dim)
+         self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+         self.clip_audio = clip_audio
+
+         if sample_rate is not None:
+             # optionally init the last layer following the mel scale
+             m_max = _hz_to_mel(sample_rate // 2)
+             m_pts = torch.linspace(0, m_max, out_dim)
+             f_pts = _mel_to_hz(m_pts)
+             scale = 1 - (f_pts / f_pts.max())
+
+             with torch.no_grad():
+                 self.out.weight.mul_(scale.view(-1, 1))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the IMDCTSymExpHead module.
+
+         Args:
+             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                 L is the sequence length, and H denotes the model dimension.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         x = self.out(x)
+         x = symexp(x)
+         x = torch.clip(x, min=-1e2, max=1e2)  # safeguard to prevent excessively large magnitudes
+         audio = self.imdct(x)
+         if self.clip_audio:
+             audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the waveform, not the pre-IMDCT coefficients
+
+         return audio
+
+
+ class IMDCTCosHead(FourierHead):
+     """
+     IMDCT Head module for predicting MDCT coefficients with the parametrization MDCT = exp(m) · cos(p).
+
+     Args:
+         dim (int): Hidden dimension of the model.
+         mdct_frame_len (int): Length of the MDCT frame.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+         clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
+     """
+
+     def __init__(self, dim: int, mdct_frame_len: int, padding: str = "same", clip_audio: bool = False):
+         super().__init__()
+         self.clip_audio = clip_audio
+         self.out = nn.Linear(dim, mdct_frame_len)
+         self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the IMDCTCosHead module.
+
+         Args:
+             x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
+                 L is the sequence length, and H denotes the model dimension.
+
+         Returns:
+             Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
+         """
+         x = self.out(x)
+         m, p = x.chunk(2, dim=2)
+         m = torch.exp(m).clip(max=1e2)  # safeguard to prevent excessively large magnitudes
+         audio = self.imdct(m * torch.cos(p))
+         if self.clip_audio:
+             audio = torch.clip(audio, min=-1.0, max=1.0)  # clip the waveform, not the pre-IMDCT coefficients
+         return audio
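A quick shape check for ISTFTHead, the head these configs pair with the backbone; the `n_fft`/`hop_length` values here are illustrative (a hop of 600 corresponds to the 40 Hz frame rate at 24 kHz mentioned in the config names):

    import torch
    from decoder.heads import ISTFTHead

    head = ISTFTHead(dim=512, n_fft=2400, hop_length=600, padding="same")
    x = torch.randn(2, 100, 512)              # (B, L, H) backbone output
    audio = head(x)                           # (B, T) waveform, T = L * hop_length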
decoder/helpers.py ADDED
@@ -0,0 +1,71 @@
+ import matplotlib
+ import numpy as np
+ import torch
+ from matplotlib import pyplot as plt
+ from pytorch_lightning import Callback
+
+ matplotlib.use("Agg")
+
+
+ def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray:
+     """
+     Save a matplotlib figure to a numpy array.
+
+     Args:
+         fig (Figure): Matplotlib figure object.
+
+     Returns:
+         ndarray: Numpy array representing the figure.
+     """
+     # np.fromstring is deprecated for binary data; frombuffer is the supported equivalent
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     return data
+
+
+ def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray:
+     """
+     Plot a spectrogram and convert it to a numpy array.
+
+     Args:
+         spectrogram (ndarray): Spectrogram data.
+
+     Returns:
+         ndarray: Numpy array representing the plotted spectrogram.
+     """
+     spectrogram = spectrogram.astype(np.float32)
+     fig, ax = plt.subplots(figsize=(12, 3))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+     plt.colorbar(im, ax=ax)
+     plt.xlabel("Frames")
+     plt.ylabel("Channels")
+     plt.tight_layout()
+
+     fig.canvas.draw()
+     data = save_figure_to_numpy(fig)
+     plt.close()
+     return data
+
+
+ class GradNormCallback(Callback):
+     """
+     Callback to log the gradient norm.
+     """
+
+     def on_after_backward(self, trainer, model):
+         model.log("grad_norm", gradient_norm(model))
+
+
+ def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor:
+     """
+     Compute the gradient norm.
+
+     Args:
+         model (Module): PyTorch model.
+         norm_type (float, optional): Type of the norm. Defaults to 2.0.
+
+     Returns:
+         Tensor: Gradient norm.
+     """
+     grads = [p.grad for p in model.parameters() if p.grad is not None]
+     total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type)
+     return total_norm
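`gradient_norm` works on any module with populated gradients; a self-contained check:

    import torch
    from decoder.helpers import gradient_norm

    model = torch.nn.Linear(8, 1)
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    print(gradient_norm(model))               # L2 norm over all parameter gradients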
decoder/loss.py ADDED
@@ -0,0 +1,159 @@
+ from typing import List, Tuple
+
+ import torch
+ import torchaudio
+ from torch import nn
+
+ from decoder.modules import safe_log
+
+ import torch.nn.functional as F
+
+
+ class MelSpecReconstructionLoss(nn.Module):
+     """
+     L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample
+     """
+
+     def __init__(
+         self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100,
+     ):
+         super().__init__()
+         self.mel_spec = torchaudio.transforms.MelSpectrogram(
+             sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1,
+         )
+
+     def forward(self, y_hat, y) -> torch.Tensor:
+         """
+         Args:
+             y_hat (Tensor): Predicted audio waveform.
+             y (Tensor): Ground truth audio waveform.
+
+         Returns:
+             Tensor: L1 loss between the mel-scaled magnitude spectrograms.
+         """
+         mel_hat = safe_log(self.mel_spec(y_hat))
+         mel = safe_log(self.mel_spec(y))
+
+         loss = torch.nn.functional.l1_loss(mel, mel_hat)
+
+         return loss
+
+
+ class GeneratorLoss(nn.Module):
+     """
+     Generator Loss module. Calculates the loss for the generator based on discriminator outputs.
+     """
+
+     def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+         """
+         Args:
+             disc_outputs (List[Tensor]): List of discriminator outputs.
+
+         Returns:
+             Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from
+                 the sub-discriminators.
+         """
+         loss = 0
+         gen_losses = []
+         for dg in disc_outputs:
+             l = torch.mean(torch.clamp(1 - dg, min=0))
+             gen_losses.append(l)
+             loss += l
+
+         return loss, gen_losses
+
+
+ class DiscriminatorLoss(nn.Module):
+     """
+     Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs.
+     """
+
+     def forward(
+         self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
+     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
+         """
+         Args:
+             disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples.
+             disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples.
+
+         Returns:
+             Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from
+                 the sub-discriminators for real outputs, and a list of loss values for generated outputs.
+         """
+         loss = 0
+         r_losses = []
+         g_losses = []
+         for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+             r_loss = torch.mean(torch.clamp(1 - dr, min=0))
+             g_loss = torch.mean(torch.clamp(1 + dg, min=0))
+             loss += r_loss + g_loss
+             r_losses.append(r_loss.item())
+             g_losses.append(g_loss.item())
+
+         return loss, r_losses, g_losses
+
+
+ class FeatureMatchingLoss(nn.Module):
+     """
+     Feature Matching Loss module. Calculates the feature matching loss between feature maps of the sub-discriminators.
+     """
+
+     def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
+         """
+         Args:
+             fmap_r (List[List[Tensor]]): List of feature maps from real samples.
+             fmap_g (List[List[Tensor]]): List of feature maps from generated samples.
+
+         Returns:
+             Tensor: The calculated feature matching loss.
+         """
+         loss = 0
+         for dr, dg in zip(fmap_r, fmap_g):
+             for rl, gl in zip(dr, dg):
+                 loss += torch.mean(torch.abs(rl - gl))
+
+         return loss
+
+
+ class DACGANLoss(nn.Module):
+     """
+     Computes a discriminator loss, given a discriminator on
+     generated waveforms/spectrograms compared to ground truth
+     waveforms/spectrograms. Computes the loss for both the
+     discriminator and the generator in separate functions.
+     """
+
+     def __init__(self, discriminator):
+         super().__init__()
+         self.discriminator = discriminator
+
+     def forward(self, fake, real):
+         # d_fake = self.discriminator(fake.audio_data)
+         # d_real = self.discriminator(real.audio_data)
+         d_fake = self.discriminator(fake)
+         d_real = self.discriminator(real)
+         return d_fake, d_real
+
+     def discriminator_loss(self, fake, real):
+         d_fake, d_real = self.forward(fake.clone().detach(), real)
+
+         loss_d = 0
+         for x_fake, x_real in zip(d_fake, d_real):
+             loss_d += torch.mean(x_fake[-1] ** 2)
+             loss_d += torch.mean((1 - x_real[-1]) ** 2)
+         return loss_d
+
+     def generator_loss(self, fake, real):
+         d_fake, d_real = self.forward(fake, real)
+
+         loss_g = 0
+         for x_fake in d_fake:
+             loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+         loss_feature = 0
+
+         for i in range(len(d_fake)):
+             for j in range(len(d_fake[i]) - 1):
+                 loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+         return loss_g, loss_feature
+
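GeneratorLoss and DiscriminatorLoss are standard hinge-style GAN losses; a toy invocation with a single sub-discriminator (the tensors stand in for discriminator logits):

    import torch
    from decoder.loss import GeneratorLoss, DiscriminatorLoss

    disc_loss_fn, gen_loss_fn = DiscriminatorLoss(), GeneratorLoss()
    d_real = [torch.tensor([0.9, 1.2])]       # logits for real audio
    d_fake = [torch.tensor([-0.8, 0.1])]      # logits for generated audio
    loss_d, r_losses, g_losses = disc_loss_fn(d_real, d_fake)
    loss_g, gen_losses = gen_loss_fn(d_fake)  # pushes fake logits above 1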
decoder/models.py ADDED
@@ -0,0 +1,264 @@
+ from typing import List, Optional
+
+ import torch
+ from torch import nn
+ from torch.nn.utils import weight_norm
+
+ from decoder.modules import ConvNeXtBlock, ResBlock1, AdaLayerNorm
+
+
+ def nonlinearity(x):
+     # swish
+     return x * torch.sigmoid(x)
+
+
+ def Normalize(in_channels, num_groups=32):
+     return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+ class ResnetBlock(nn.Module):
+     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+                  dropout, temb_channels=512):
+         super().__init__()
+         self.in_channels = in_channels
+         out_channels = in_channels if out_channels is None else out_channels
+         self.out_channels = out_channels
+         self.use_conv_shortcut = conv_shortcut
+
+         self.norm1 = Normalize(in_channels)
+         self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+         if temb_channels > 0:
+             self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+         self.norm2 = Normalize(out_channels)
+         self.dropout = torch.nn.Dropout(dropout)
+         self.conv2 = torch.nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 self.conv_shortcut = torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+             else:
+                 self.nin_shortcut = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x, temb=None):
+         h = x
+         h = self.norm1(h)
+         h = nonlinearity(h)
+         h = self.conv1(h)
+
+         if temb is not None:
+             # broadcast over the time axis of a 1D conv (the 2D source this was adapted from used [:, :, None, None])
+             h = h + self.temb_proj(nonlinearity(temb))[:, :, None]
+
+         h = self.norm2(h)
+         h = nonlinearity(h)
+         h = self.dropout(h)
+         h = self.conv2(h)
+
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 x = self.conv_shortcut(x)
+             else:
+                 x = self.nin_shortcut(x)
+
+         return x + h
+
+
+ class AttnBlock(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.in_channels = in_channels
+
+         self.norm = Normalize(in_channels)
+         self.q = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.k = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.v = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.proj_out = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         h_ = x
+         h_ = self.norm(h_)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         # compute attention over the time axis
+         b, c, t = q.shape
+         q = q.permute(0, 2, 1)  # b, t, c
+         w_ = torch.bmm(q, k)  # b, t, t; w_[b, i, j] = sum_c q[b, i, c] * k[b, c, j]
+         w_ = w_ * (int(c) ** (-0.5))
+         w_ = torch.nn.functional.softmax(w_, dim=2)
+
+         # attend to values
+         w_ = w_.permute(0, 2, 1)  # b, t, t (first index over k positions, second over q positions)
+         h_ = torch.bmm(v, w_)  # b, c, t; h_[b, c, j] = sum_i v[b, c, i] * w_[b, i, j]
+
+         h_ = self.proj_out(h_)
+
+         return x + h_
+
+
+ def make_attn(in_channels, attn_type="vanilla"):
+     assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
+     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+     if attn_type == "vanilla":
+         return AttnBlock(in_channels)
+     # only "vanilla" attention is implemented in this repository
+     raise NotImplementedError(f"attn_type '{attn_type}' is not implemented")
+
+
+ class Backbone(nn.Module):
+     """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
+
+     def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+         """
+         Args:
+             x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
+                 C denotes output features, and L is the sequence length.
+
+         Returns:
+             Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
+                 and H denotes the model dimension.
+         """
+         raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+ class VocosBackbone(Backbone):
+     """
+     Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization.
+
+     Args:
+         input_channels (int): Number of input features channels.
+         dim (int): Hidden dimension of the model.
+         intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
+         num_layers (int): Number of ConvNeXtBlock layers.
+         layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
+         adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+             None means non-conditional model. Defaults to None.
+     """
+
+     def __init__(
+         self,
+         input_channels: int,
+         dim: int,
+         intermediate_dim: int,
+         num_layers: int,
+         layer_scale_init_value: Optional[float] = None,
+         adanorm_num_embeddings: Optional[int] = None,
+     ):
+         super().__init__()
+         self.input_channels = input_channels
+         self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
+         self.adanorm = adanorm_num_embeddings is not None
+         if adanorm_num_embeddings:
+             self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+         else:
+             self.norm = nn.LayerNorm(dim, eps=1e-6)
+         layer_scale_init_value = layer_scale_init_value or 1 / num_layers
+         self.convnext = nn.ModuleList(
+             [
+                 ConvNeXtBlock(
+                     dim=dim,
+                     intermediate_dim=intermediate_dim,
+                     layer_scale_init_value=layer_scale_init_value,
+                     adanorm_num_embeddings=adanorm_num_embeddings,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
+         self.apply(self._init_weights)
+
+         self.temb_ch = 0
+         block_in = dim
+         dropout = 0.1
+         attn_type = "vanilla"
+
+         pos_net: List[nn.Module] = [
+             ResnetBlock(in_channels=block_in, out_channels=block_in,
+                         temb_channels=self.temb_ch, dropout=dropout),
+             ResnetBlock(in_channels=block_in, out_channels=block_in,
+                         temb_channels=self.temb_ch, dropout=dropout),
+             make_attn(block_in, attn_type=attn_type),
+             ResnetBlock(in_channels=block_in, out_channels=block_in,
+                         temb_channels=self.temb_ch, dropout=dropout),
+             ResnetBlock(in_channels=block_in, out_channels=block_in,
+                         temb_channels=self.temb_ch, dropout=dropout),
+             Normalize(block_in)
+         ]
+
+         self.pos_net = nn.Sequential(*pos_net)
+
+     def _init_weights(self, m):
+         if isinstance(m, (nn.Conv1d, nn.Linear)):
+             nn.init.trunc_normal_(m.weight, std=0.02)
+             nn.init.constant_(m.bias, 0)
+
+     def forward(self, x: torch.Tensor, bandwidth_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+         x = self.embed(x)
+         x = self.pos_net(x)
+         if self.adanorm:
+             assert bandwidth_id is not None
+             x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
+         else:
+             x = self.norm(x.transpose(1, 2))
+         x = x.transpose(1, 2)
+         for conv_block in self.convnext:
+             x = conv_block(x, cond_embedding_id=bandwidth_id)
+         x = self.final_layer_norm(x.transpose(1, 2))
+         return x
+
+
+ class VocosResNetBackbone(Backbone):
+     """
+     Vocos backbone module built with ResBlocks.
+
+     Args:
+         input_channels (int): Number of input features channels.
+         dim (int): Hidden dimension of the model.
+         num_blocks (int): Number of ResBlock1 blocks.
+         layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
+     """
+
+     def __init__(
+         self, input_channels, dim, num_blocks, layer_scale_init_value=None,
+     ):
+         super().__init__()
+         self.input_channels = input_channels
+         self.embed = weight_norm(nn.Conv1d(input_channels, dim, kernel_size=3, padding=1))
+         layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
+         self.resnet = nn.Sequential(
+             *[ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) for _ in range(num_blocks)]
+         )
+
+     def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+         x = self.embed(x)
+         x = self.resnet(x)
+         x = x.transpose(1, 2)
+         return x
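A shape check for VocosBackbone; the dimensions below are illustrative stand-ins for the values the YAML config would supply, not the project's actual hyperparameters:

    import torch
    from decoder.models import VocosBackbone

    backbone = VocosBackbone(input_channels=512, dim=768, intermediate_dim=2304, num_layers=12)
    x = torch.randn(1, 512, 100)              # (B, C, L) quantized features
    y = backbone(x)                           # (B, L, H) = (1, 100, 768)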
decoder/modules.py ADDED
@@ -0,0 +1,213 @@
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+ from torch.nn.utils import weight_norm, remove_weight_norm
+
+
+ class ConvNeXtBlock(nn.Module):
+     """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
+
+     Args:
+         dim (int): Number of input channels.
+         intermediate_dim (int): Dimensionality of the intermediate layer.
+         layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+             Defaults to None.
+         adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
+             None means non-conditional LayerNorm. Defaults to None.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         intermediate_dim: int,
+         layer_scale_init_value: Optional[float] = None,
+         adanorm_num_embeddings: Optional[int] = None,
+     ):
+         super().__init__()
+         self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
+         self.adanorm = adanorm_num_embeddings is not None
+         if adanorm_num_embeddings:
+             self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
+         else:
+             self.norm = nn.LayerNorm(dim, eps=1e-6)
+         self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
+         self.act = nn.GELU()
+         self.pwconv2 = nn.Linear(intermediate_dim, dim)
+         # guard against the default None before comparing with 0
+         self.gamma = (
+             nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
+             if layer_scale_init_value is not None and layer_scale_init_value > 0
+             else None
+         )
+
+     def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+         residual = x
+         x = self.dwconv(x)
+         x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
+         if self.adanorm:
+             assert cond_embedding_id is not None
+             x = self.norm(x, cond_embedding_id)
+         else:
+             x = self.norm(x)
+         x = self.pwconv1(x)
+         x = self.act(x)
+         x = self.pwconv2(x)
+         if self.gamma is not None:
+             x = self.gamma * x
+         x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
+
+         x = residual + x
+         return x
+
+
+ class AdaLayerNorm(nn.Module):
+     """
+     Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes.
+
+     Args:
+         num_embeddings (int): Number of embeddings.
+         embedding_dim (int): Dimension of the embeddings.
+     """
+
+     def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.dim = embedding_dim
+         self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+         self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
+         torch.nn.init.ones_(self.scale.weight)
+         torch.nn.init.zeros_(self.shift.weight)
+
+     def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
+         scale = self.scale(cond_embedding_id)
+         shift = self.shift(cond_embedding_id)
+         x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
+         x = x * scale + shift
+         return x
+
+
+ class ResBlock1(nn.Module):
+     """
+     ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
+     but without upsampling layers.
+
+     Args:
+         dim (int): Number of input channels.
+         kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
+         dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
+             Defaults to (1, 3, 5).
+         lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
+             Defaults to 0.1.
+         layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
+             Defaults to None.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         kernel_size: int = 3,
+         dilation: Tuple[int, int, int] = (1, 3, 5),
+         lrelu_slope: float = 0.1,
+         layer_scale_init_value: Optional[float] = None,
+     ):
+         super().__init__()
+         self.lrelu_slope = lrelu_slope
+         self.convs1 = nn.ModuleList(
+             [
+                 weight_norm(
+                     nn.Conv1d(
+                         dim,
+                         dim,
+                         kernel_size,
+                         1,
+                         dilation=dilation[0],
+                         padding=self.get_padding(kernel_size, dilation[0]),
+                     )
+                 ),
+                 weight_norm(
+                     nn.Conv1d(
+                         dim,
+                         dim,
+                         kernel_size,
+                         1,
+                         dilation=dilation[1],
+                         padding=self.get_padding(kernel_size, dilation[1]),
+                     )
+                 ),
+                 weight_norm(
+                     nn.Conv1d(
+                         dim,
+                         dim,
+                         kernel_size,
+                         1,
+                         dilation=dilation[2],
+                         padding=self.get_padding(kernel_size, dilation[2]),
+                     )
+                 ),
+             ]
+         )
+
+         self.convs2 = nn.ModuleList(
+             [
+                 weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+                 weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+                 weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))),
+             ]
+         )
+
+         self.gamma = nn.ParameterList(
+             [
+                 nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+                 if layer_scale_init_value is not None
+                 else None,
+                 nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+                 if layer_scale_init_value is not None
+                 else None,
+                 nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True)
+                 if layer_scale_init_value is not None
+                 else None,
+             ]
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
+             xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
+             xt = c1(xt)
+             xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
+             xt = c2(xt)
+             if gamma is not None:
+                 xt = gamma * xt
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+
+     @staticmethod
+     def get_padding(kernel_size: int, dilation: int = 1) -> int:
+         return int((kernel_size * dilation - dilation) / 2)
+
+
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
+     """
+     Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
+
+     Args:
+         x (Tensor): Input tensor.
+         clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
+
+     Returns:
+         Tensor: Element-wise logarithm of the input tensor with clipping applied.
+     """
+     return torch.log(torch.clip(x, min=clip_val))
+
+
+ def symlog(x: torch.Tensor) -> torch.Tensor:
+     return torch.sign(x) * torch.log1p(x.abs())
+
+
+ def symexp(x: torch.Tensor) -> torch.Tensor:
+     return torch.sign(x) * (torch.exp(x.abs()) - 1)
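`symlog` and `symexp` are exact inverses, which is what lets IMDCTSymExpHead predict signed coefficients on a compressed scale:

    import torch
    from decoder.modules import symlog, symexp, safe_log

    x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
    assert torch.allclose(symexp(symlog(x)), x, atol=1e-5)  # inverse pair
    print(safe_log(torch.tensor([0.0, 1.0])))               # clipped at 1e-7, no -inf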
decoder/pretrained.py ADDED
@@ -0,0 +1,239 @@
+ import os
+ from typing import Tuple, Any, Union, Dict
+
+ import torch
+ import yaml
+ from huggingface_hub import hf_hub_download
+ from torch import nn
+ from decoder.feature_extractors import FeatureExtractor, EncodecFeatures
+ from decoder.heads import FourierHead
+ from decoder.models import Backbone
+
+
+ def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
+     """Instantiates a class with the given args and init.
+
+     Args:
+         args: Positional arguments required for instantiation.
+         init: Dict of the form {"class_path":...,"init_args":...}.
+
+     Returns:
+         The instantiated class object.
+     """
+     kwargs = init.get("init_args", {})
+     if not isinstance(args, tuple):
+         args = (args,)
+     class_module, class_name = init["class_path"].rsplit(".", 1)
+     module = __import__(class_module, fromlist=[class_name])
+     args_class = getattr(module, class_name)
+     return args_class(*args, **kwargs)
+
+
+ class WavTokenizer(nn.Module):
+     """
+     WavTokenizer is a Vocos-style Fourier-based neural vocoder for audio synthesis.
+     This class is primarily designed for inference, with support for loading from pretrained
+     model checkpoints. It consists of three main components: a feature extractor,
+     a backbone, and a head.
+     """
+
+     def __init__(
+         self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
+     ):
+         super().__init__()
+         self.feature_extractor = feature_extractor
+         self.backbone = backbone
+         self.head = head
+
+     @classmethod
+     def from_hparams(cls, config_path: str) -> "WavTokenizer":
+         """
+         Class method to create a new model instance from hyperparameters stored in a yaml configuration file.
+         """
+         with open(config_path, "r") as f:
+             config = yaml.safe_load(f)
+         feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
+         backbone = instantiate_class(args=(), init=config["backbone"])
+         head = instantiate_class(args=(), init=config["head"])
+         model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
+         return model
+
+     @classmethod
+     def from_pretrained(cls, repo_id: str) -> "WavTokenizer":
+         """
+         Class method to create a new model instance from a pre-trained model stored in the Hugging Face model hub.
+         """
+         config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml")
+         model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
+         model = cls.from_hparams(config_path)
+         state_dict = torch.load(model_path, map_location="cpu")
+         if isinstance(model.feature_extractor, EncodecFeatures):
+             encodec_parameters = {
+                 "feature_extractor.encodec." + key: value
+                 for key, value in model.feature_extractor.encodec.state_dict().items()
+             }
+             state_dict.update(encodec_parameters)
+         model.load_state_dict(state_dict)
+         model.eval()
+         return model
+
+     @classmethod
+     def from_hparams0802(cls, config_path: str) -> "WavTokenizer":
+         """
+         Class method to create a new model instance from hyperparameters stored in a yaml configuration file
+         (training-CLI layout, i.e. under model.init_args).
+         """
+         with open(config_path, "r") as f:
+             config = yaml.safe_load(f)
+         feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
+         backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
+         head = instantiate_class(args=(), init=config['model']['init_args']["head"])
+         model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
+         return model
+
+     @classmethod
+     def from_pretrained0802(cls, config_path, model_path) -> "WavTokenizer":
+         """
+         Class method to create a new model instance from a local config file and Lightning checkpoint.
+         """
+         model = cls.from_hparams0802(config_path)
+         state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
+         state_dict = dict()
+         for k, v in state_dict_raw.items():
+             if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
+                 state_dict[k] = v
+         # if isinstance(model.feature_extractor, EncodecFeatures):
+         #     encodec_parameters = {
+         #         "feature_extractor.encodec." + key: value
+         #         for key, value in model.feature_extractor.encodec.state_dict().items()
+         #     }
+         #     state_dict.update(encodec_parameters)
+         model.load_state_dict(state_dict)
+         model.eval()
+         return model
+
+     @classmethod
+     def from_pretrained0911(cls, config_path, model_folder_path) -> "WavTokenizer":
+         """
+         Class method to create a new model instance by averaging the best checkpoints found in a local folder.
+         """
+         model = cls.from_hparams0802(config_path)
+
+         models = os.listdir(model_folder_path)
+         val_loss = []
+         for item in models:
+             if not item.startswith('vocos_'):
+                 continue
+             val_loss.append(item[-11:-5])
+         val_loss.sort()
+         val_loss = val_loss[:3]  # keep the three best-performing (lowest val_loss) checkpoints to average
+         state_dict = dict()
+         state_dicts = []
+         for item in models:
+             if not item.startswith('vocos_'):
+                 continue
+             ll = item[-11:-5]
+             if ll not in val_loss:
+                 continue
+             model_path = model_folder_path + '/' + item
+             state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
+             state_dict_single = dict()
+             for k, v in state_dict_raw.items():
+                 if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
+                     state_dict_single[k] = v
+             state_dicts.append(state_dict_single)
+         for kk in state_dicts[0].keys():
+             vv = state_dicts[0][kk]
+             for i in range(1, len(state_dicts)):
+                 ss = state_dicts[i]
+                 vv += ss[kk]
+             vm = vv / len(state_dicts)
+             state_dict[kk] = vm
+         model.load_state_dict(state_dict)
+         model.eval()
+         return model
+
+     @torch.inference_mode()
+     def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+         """
+         Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
+         which is then passed through the backbone and the head to reconstruct the audio output.
+
+         Args:
+             audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
+                 where B is the batch size and T is the waveform length.
+
+         Returns:
+             Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
+         """
+         features, _, _ = self.feature_extractor(audio_input, **kwargs)  # 0818
+         audio_output = self.decode(features, **kwargs)
+         return audio_output
+
+     # 0818
+     @torch.inference_mode()
+     def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
+         features, discrete_codes, _ = self.feature_extractor(audio_input, **kwargs)
+         return features, discrete_codes
+
+     # 0818
+     @torch.inference_mode()
+     def encode_infer(self, audio_input: torch.Tensor, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]:
+         features, discrete_codes, _ = self.feature_extractor.infer(audio_input, **kwargs)
+         return features, discrete_codes
+
+     @torch.inference_mode()
+     def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+         """
+         Method to decode audio waveform from already calculated features. The features input is passed through
+         the backbone and the head to reconstruct the audio output.
+
+         Args:
+             features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
+                 C denotes the feature dimension, and L is the sequence length.
+
+         Returns:
+             Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
+         """
+         x = self.backbone(features_input, **kwargs)
+         audio_output = self.head(x)
+         return audio_output
+
+     @torch.inference_mode()
+     def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
+         """
+         Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
+         codebook weights.
+
+         Args:
+             codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
+                 where K is the number of codebooks, B is the batch size and L is the sequence length.
+
+         Returns:
+             Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
+                 and L is the sequence length.
+         """
+         assert isinstance(
+             self.feature_extractor, EncodecFeatures
+         ), "Feature extractor should be an instance of EncodecFeatures"
+
+         if codes.dim() == 2:
+             codes = codes.unsqueeze(1)
+
+         n_bins = self.feature_extractor.encodec.quantizer.bins
+         offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
+         embeddings_idxs = codes + offsets.view(-1, 1, 1)
+
+         tmp = torch.cat([vq.codebook for vq in self.feature_extractor.encodec.quantizer.vq.layers], dim=0)
+         # features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
+         features = torch.nn.functional.embedding(embeddings_idxs, tmp).sum(dim=0)
+         features = features.transpose(1, 2)
+
+         return features
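Putting the pieces together, an encode/decode round trip via `from_pretrained0802`; the config and checkpoint paths are placeholders for your own files:

    import torch
    from decoder.pretrained import WavTokenizer

    wavtokenizer = WavTokenizer.from_pretrained0802("config.yaml", "model.ckpt")  # placeholder paths
    audio = torch.randn(1, 24000)             # 1 s of audio at 24 kHz
    bandwidth_id = torch.tensor([0])
    features, codes = wavtokenizer.encode_infer(audio, bandwidth_id=bandwidth_id)
    recon = wavtokenizer.decode(features, bandwidth_id=bandwidth_id)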
decoder/pretrained_model.py ADDED
@@ -0,0 +1,192 @@
+ from typing import Tuple, Any, Union, Dict
+
+ import torch
+ import yaml
+ from huggingface_hub import hf_hub_download
+ from torch import nn
+ from decoder.feature_extractors import FeatureExtractor, EncodecFeatures
+ from decoder.heads import FourierHead
+ from decoder.models import Backbone
+ from decoder.discriminators import MultiPeriodDiscriminator, MultiResolutionDiscriminator
+
+
+ def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
+     """Instantiates a class with the given args and init.
+
+     Args:
+         args: Positional arguments required for instantiation.
+         init: Dict of the form {"class_path":...,"init_args":...}.
+
+     Returns:
+         The instantiated class object.
+     """
+     kwargs = init.get("init_args", {})
+     if not isinstance(args, tuple):
+         args = (args,)
+     class_module, class_name = init["class_path"].rsplit(".", 1)
+     module = __import__(class_module, fromlist=[class_name])
+     args_class = getattr(module, class_name)
+     return args_class(*args, **kwargs)
+
+
+ class WavTokenizer(nn.Module):
+     """
+     WavTokenizer is a Vocos-style Fourier-based neural vocoder for audio synthesis.
+     This variant also carries the discriminators, so it can be reloaded for further GAN training.
+     It consists of three main components: a feature extractor, a backbone, and a head.
+     """
+
+     def __init__(
+         self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
+         multiperioddisc: MultiPeriodDiscriminator, multiresddisc: MultiResolutionDiscriminator,
+     ):
+         super().__init__()
+         self.feature_extractor = feature_extractor
+         self.backbone = backbone
+         self.head = head
+
+         self.multiperioddisc = multiperioddisc
+         self.multiresddisc = multiresddisc
+
+     @classmethod
+     def from_hparams0828(cls, config_path: str) -> "WavTokenizer":
+         """
+         Class method to create a new model instance from hyperparameters stored in a yaml configuration file.
+         """
+         with open(config_path, "r") as f:
+             config = yaml.safe_load(f)
+         feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
+         backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
+         head = instantiate_class(args=(), init=config['model']['init_args']["head"])
+         model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head,
+                     multiperioddisc=MultiPeriodDiscriminator(num_embeddings=4),
+                     multiresddisc=MultiResolutionDiscriminator(num_embeddings=4))
+         return model
+
+     @classmethod
+     def from_pretrained0828(cls, config_path, model_path) -> "WavTokenizer":
+         """
+         Class method to create a new model instance from a local config file and Lightning checkpoint,
+         including the discriminator weights.
+         """
+         model = cls.from_hparams0828(config_path)
+         state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
+         state_dict = dict()
+         for k, v in state_dict_raw.items():
+             if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.') \
+                     or k.startswith('multiperioddisc.') or k.startswith('multiresddisc.'):
+                 state_dict[k] = v
+         # if isinstance(model.feature_extractor, EncodecFeatures):
+         #     encodec_parameters = {
+         #         "feature_extractor.encodec." + key: value
+         #         for key, value in model.feature_extractor.encodec.state_dict().items()
+         #     }
+         #     state_dict.update(encodec_parameters)
+         model.load_state_dict(state_dict)
+         return model
+
+     @classmethod
+     def from_hparams0802(cls, config_path: str) -> "WavTokenizer":
+         """
+         Class method to create a new model instance from hyperparameters stored in a yaml configuration file.
+         """
+         with open(config_path, "r") as f:
+             config = yaml.safe_load(f)
+         feature_extractor = instantiate_class(args=(), init=config['model']['init_args']["feature_extractor"])
+         backbone = instantiate_class(args=(), init=config['model']['init_args']["backbone"])
+         head = instantiate_class(args=(), init=config['model']['init_args']["head"])
+         # __init__ requires the discriminators, so build them the same way as from_hparams0828 does
+         model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head,
+                     multiperioddisc=MultiPeriodDiscriminator(num_embeddings=4),
+                     multiresddisc=MultiResolutionDiscriminator(num_embeddings=4))
+         return model
+
+     @classmethod
+     def from_pretrained0802(cls, config_path, model_path) -> "WavTokenizer":
+         """
+         Class method to create a new model instance from a local config file and Lightning checkpoint.
+         """
+         model = cls.from_hparams0802(config_path)
+         state_dict_raw = torch.load(model_path, map_location="cpu")['state_dict']
+         state_dict = dict()
+         for k, v in state_dict_raw.items():
+             if k.startswith('backbone.') or k.startswith('head.') or k.startswith('feature_extractor.'):
+                 state_dict[k] = v
+         # if isinstance(model.feature_extractor, EncodecFeatures):
+         #     encodec_parameters = {
+         #         "feature_extractor.encodec." + key: value
+         #         for key, value in model.feature_extractor.encodec.state_dict().items()
+         #     }
+         #     state_dict.update(encodec_parameters)
+         model.load_state_dict(state_dict)
+         model.eval()
+         return model
+
+     @torch.inference_mode()
+     def forward(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+         """
+         Method to run a copy-synthesis from audio waveform. The feature extractor first processes the audio input,
+         which is then passed through the backbone and the head to reconstruct the audio output.
+
+         Args:
+             audio_input (Tensor): The input tensor representing the audio waveform of shape (B, T),
+                 where B is the batch size and T is the waveform length.
+
+         Returns:
+             Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
+         """
+         features, _, _ = self.feature_extractor(audio_input, **kwargs)  # 0818
+         audio_output = self.decode(features, **kwargs)
+         return audio_output
+
+     # 0818
+     @torch.inference_mode()
+     def encode(self, audio_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+         features, _, _ = self.feature_extractor(audio_input, **kwargs)
+         return features
+
+     @torch.inference_mode()
+     def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+         """
+         Method to decode audio waveform from already calculated features. The features input is passed through
+         the backbone and the head to reconstruct the audio output.
+
+         Args:
+             features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
+                 C denotes the feature dimension, and L is the sequence length.
+
+         Returns:
+             Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
+         """
+         x = self.backbone(features_input, **kwargs)
+         audio_output = self.head(x)
+         return audio_output
+
+     @torch.inference_mode()
+     def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
+         """
+         Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
+         codebook weights.
+
+         Args:
+             codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
+                 where K is the number of codebooks, B is the batch size and L is the sequence length.
+
+         Returns:
+             Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
+                 and L is the sequence length.
+         """
+         assert isinstance(
+             self.feature_extractor, EncodecFeatures
+         ), "Feature extractor should be an instance of EncodecFeatures"
+
+         if codes.dim() == 2:
+             codes = codes.unsqueeze(1)
+
+         n_bins = self.feature_extractor.encodec.quantizer.bins
+         offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
+         embeddings_idxs = codes + offsets.view(-1, 1, 1)
+         # `codebook_weights` is commented out in the feature extractor, so mirror decoder/pretrained.py
+         # and concatenate the live codebooks instead
+         codebook_weights = torch.cat([vq.codebook for vq in self.feature_extractor.encodec.quantizer.vq.layers], dim=0)
+         features = torch.nn.functional.embedding(embeddings_idxs, codebook_weights).sum(dim=0)
+         features = features.transpose(1, 2)
+
+         return features
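The offset trick in `codes_to_features` flattens K per-codebook tables into one embedding matrix: indices from codebook k are shifted by k * n_bins before a single embedding lookup. In isolation, with illustrative sizes:

    import torch

    n_bins, K, B, L = 4096, 2, 1, 3            # illustrative sizes
    codes = torch.randint(0, n_bins, (K, B, L))
    offsets = torch.arange(0, n_bins * K, n_bins)     # tensor([0, 4096])
    embeddings_idxs = codes + offsets.view(-1, 1, 1)  # row k now indexes table k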
decoder/spectral_ops.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import scipy
3
+ import torch
4
+ from torch import nn, view_as_real, view_as_complex
5
+
6
+
7
+ class ISTFT(nn.Module):
8
+ """
9
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
10
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
11
+ See issue: https://github.com/pytorch/pytorch/issues/62323
12
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
13
+ The NOLA constraint is met as we trim padded samples anyway.
14
+
15
+ Args:
16
+ n_fft (int): Size of Fourier transform.
17
+ hop_length (int): The distance between neighboring sliding window frames.
18
+ win_length (int): The size of window frame and STFT filter.
19
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
20
+ """
21
+
22
+ def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
23
+ super().__init__()
24
+ if padding not in ["center", "same"]:
25
+ raise ValueError("Padding must be 'center' or 'same'.")
26
+ self.padding = padding
27
+ self.n_fft = n_fft
28
+ self.hop_length = hop_length
29
+ self.win_length = win_length
30
+ window = torch.hann_window(win_length)
31
+ self.register_buffer("window", window)
32
+
33
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
34
+ """
35
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
36
+
37
+ Args:
38
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
39
+ N is the number of frequency bins, and T is the number of time frames.
40
+
41
+ Returns:
42
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
43
+ """
44
+ if self.padding == "center":
45
+ # Fallback to pytorch native implementation
46
+ return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
47
+ elif self.padding == "same":
48
+ pad = (self.win_length - self.hop_length) // 2
49
+ else:
50
+ raise ValueError("Padding must be 'center' or 'same'.")
51
+
52
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
53
+ B, N, T = spec.shape
54
+
55
+ # Inverse FFT
56
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
57
+ ifft = ifft * self.window[None, :, None]
58
+
59
+ # Overlap and Add
60
+ output_size = (T - 1) * self.hop_length + self.win_length
61
+ y = torch.nn.functional.fold(
62
+ ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
63
+ )[:, 0, 0, pad:-pad]
64
+
65
+ # Window envelope
66
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
67
+ window_envelope = torch.nn.functional.fold(
68
+ window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
69
+ ).squeeze()[pad:-pad]
70
+
71
+ # Normalize
72
+ assert (window_envelope > 1e-11).all()
73
+ y = y / window_envelope
74
+
75
+ return y
76
+
77
+
78
+ class MDCT(nn.Module):
79
+ """
80
+ Modified Discrete Cosine Transform (MDCT) module.
81
+
82
+ Args:
83
+ frame_len (int): Length of the MDCT frame.
84
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
85
+ """
86
+
87
+ def __init__(self, frame_len: int, padding: str = "same"):
88
+ super().__init__()
89
+ if padding not in ["center", "same"]:
90
+             raise ValueError("Padding must be 'center' or 'same'.")
+         self.padding = padding
+         self.frame_len = frame_len
+         N = frame_len // 2
+         n0 = (N + 1) / 2
+         window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
+         self.register_buffer("window", window)
+
+         pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
+         post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
+         # view_as_real: NCCL Backend does not support ComplexFloat data type
+         # https://github.com/pytorch/pytorch/issues/71613
+         self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+         self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+     def forward(self, audio: torch.Tensor) -> torch.Tensor:
+         """
+         Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
+
+         Args:
+             audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
+                 and T is the length of the audio.
+
+         Returns:
+             Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
+                 and N is the number of frequency bins.
+         """
+         if self.padding == "center":
+             audio = torch.nn.functional.pad(audio, (self.frame_len // 2, self.frame_len // 2))
+         elif self.padding == "same":
+             # hop_length is 1/2 frame_len
+             audio = torch.nn.functional.pad(audio, (self.frame_len // 4, self.frame_len // 4))
+         else:
+             raise ValueError("Padding must be 'center' or 'same'.")
+
+         x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
+         N = self.frame_len // 2
+         x = x * self.window.expand(x.shape)
+         X = torch.fft.fft(x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1)[..., :N]
+         res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
+         return torch.real(res) * np.sqrt(2)
+
+
+ class IMDCT(nn.Module):
+     """
+     Inverse Modified Discrete Cosine Transform (IMDCT) module.
+
+     Args:
+         frame_len (int): Length of the MDCT frame.
+         padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+     """
+
+     def __init__(self, frame_len: int, padding: str = "same"):
+         super().__init__()
+         if padding not in ["center", "same"]:
+             raise ValueError("Padding must be 'center' or 'same'.")
+         self.padding = padding
+         self.frame_len = frame_len
+         N = frame_len // 2
+         n0 = (N + 1) / 2
+         window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
+         self.register_buffer("window", window)
+
+         pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
+         post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
+         self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
+         self.register_buffer("post_twiddle", view_as_real(post_twiddle))
+
+     def forward(self, X: torch.Tensor) -> torch.Tensor:
+         """
+         Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
+
+         Args:
+             X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
+                 L is the number of frames, and N is the number of frequency bins.
+
+         Returns:
+             Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
+         """
+         B, L, N = X.shape
+         Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
+         Y[..., :N] = X
+         Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
+         y = torch.fft.ifft(Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1)
+         y = torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape)) * np.sqrt(N) * np.sqrt(2)
+         result = y * self.window.expand(y.shape)
+         output_size = (1, (L + 1) * N)
+         audio = torch.nn.functional.fold(
+             result.transpose(1, 2),
+             output_size=output_size,
+             kernel_size=(1, self.frame_len),
+             stride=(1, self.frame_len // 2),
+         )[:, 0, 0, :]
+
+         if self.padding == "center":
+             pad = self.frame_len // 2
+         elif self.padding == "same":
+             pad = self.frame_len // 4
+         else:
+             raise ValueError("Padding must be 'center' or 'same'.")
+
+         audio = audio[:, pad:-pad]
+         return audio
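A quick round-trip check of the MDCT/IMDCT pair above (a minimal sketch, assuming both classes are importable from decoder.spectral_ops; the frame length and batch shape are illustrative):

    import torch
    from decoder.spectral_ops import MDCT, IMDCT

    frame_len = 512                        # N = frame_len // 2 = 256 frequency bins
    mdct = MDCT(frame_len, padding="same")
    imdct = IMDCT(frame_len, padding="same")

    audio = torch.randn(1, 256 * 80)       # (B, T); T a multiple of the hop frame_len // 2
    coeffs = mdct(audio)                   # (B, L, N) = (1, 80, 256)
    recon = imdct(coeffs)                  # (B, T); matches the input away from the edges
    err = (audio[:, frame_len:-frame_len] - recon[:, frame_len:-frame_len]).abs().max()
    print(coeffs.shape, recon.shape, err)  # err should be at float32 round-off level

The cosine (sine) window satisfies the Princen-Bradley condition, so the overlap-add in IMDCT reconstructs interior samples; only the outermost half-frames are attenuated by the window.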
infer.py ADDED
@@ -0,0 +1,73 @@
+ # -*- coding: utf-8 -*-
+ import os
+
+ import torch
+ import torchaudio
+
+ from decoder.pretrained import WavTokenizer
+ from encoder.utils import convert_audio
+
+ device1 = torch.device('cuda:0')
+ device2 = torch.device('cpu')
+
+ input_path = "./WavTokenizer/data/infer/lirbitts_testclean"
+ out_folder = './WavTokenizer/result/infer'
+ ll = "wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn_testclean_epoch34"
+
+ tmptmp = out_folder + "/" + ll
+ os.system("rm -r %s" % (tmptmp))
+ os.system("mkdir -p %s" % (tmptmp))
+
+ # Load the model trained on our own data
+ config_path = "./WavTokenizer/configs/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
+ model_path = "./WavTokenizer/result/train/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn/lightning_logs/version_3/checkpoints/wavtokenizer_checkpoint_epoch=24_step=137150_val_loss=5.6731.ckpt"
+ wavtokenizer = WavTokenizer.from_pretrained0802(config_path, model_path)
+ wavtokenizer = wavtokenizer.to(device1)
+
+ with open(input_path, 'r') as fin:
+     x = fin.readlines()
+ x = [i.strip() for i in x]
+
+ # Two passes for speed: encode every file on the GPU first,
+ # then move the model to the CPU and decode.
+ features_all = []
+ for i in range(len(x)):
+     wav, sr = torchaudio.load(x[i])
+     # wav = convert_audio(wav, sr, 24000, 1)  # optional: resample to 24 kHz mono
+     bandwidth_id = torch.tensor([0])
+     wav = wav.to(device1)
+     print(i)
+     features, discrete_code = wavtokenizer.encode_infer(wav, bandwidth_id=bandwidth_id)
+     features_all.append(features.cpu())
+
+ wavtokenizer = wavtokenizer.to(device2)
+
+ for i in range(len(x)):
+     bandwidth_id = torch.tensor([0])
+     print(i)
+     audio_out = wavtokenizer.decode(features_all[i], bandwidth_id=bandwidth_id)
+     audio_path = out_folder + '/' + ll + '/' + x[i].split('/')[-1]
+     torchaudio.save(audio_path, audio_out, sample_rate=24000, encoding='PCM_S', bits_per_sample=16)
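For a quick single-file sanity check, the batch loop above collapses to a few lines (a sketch reusing the wavtokenizer model, device1, encode_infer, and decode from the script while the model is still on the GPU; the filenames are hypothetical):

    wav, sr = torchaudio.load("sample_24khz.wav")   # hypothetical 24 kHz mono input
    bandwidth_id = torch.tensor([0])
    features, discrete_code = wavtokenizer.encode_infer(wav.to(device1), bandwidth_id=bandwidth_id)
    audio_out = wavtokenizer.decode(features, bandwidth_id=bandwidth_id)
    torchaudio.save("sample_recon.wav", audio_out.cpu(), sample_rate=24000,
                    encoding='PCM_S', bits_per_sample=16)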
train.py ADDED
@@ -0,0 +1,15 @@
+ import os
+ # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
+ from pytorch_lightning.cli import LightningCLI, ArgsType
+
+
+ def cli_main(args: ArgsType = None):
+     # run=False: the CLI only parses the config and instantiates the objects,
+     # so fit() is launched exactly once by the explicit call below.
+     cli = LightningCLI(args=args, run=False)
+     cli.trainer.fit(model=cli.model, datamodule=cli.datamodule)
+
+
+ if __name__ == "__main__":
+     cli_main()
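With run=False the script takes no fit subcommand; a training run is driven entirely by one YAML file, e.g. one of the configs shipped in configs/ (a sketch of the invocation, assuming the data filelist paths named inside the config exist):

    python train.py --config configs/wavtokenizer_smalldata_frame40_3s_nq1_code4096_dim512_kmeans200_attn.yaml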