haoxiangsnr commited on
Commit
fe777b2
·
verified ·
1 Parent(s): 7024957

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Amphion/models/base/base_trainer.py +348 -0
  2. Amphion/models/codec/ns3_codec/__pycache__/melspec.cpython-310.pyc +0 -0
  3. Amphion/models/codec/ns3_codec/__pycache__/transformer.cpython-310.pyc +0 -0
  4. Amphion/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
  5. Amphion/models/svc/diffusion/diffusion_trainer.py +102 -0
  6. Amphion/models/tta/autoencoder/__init__.py +0 -0
  7. Amphion/models/tta/autoencoder/autoencoder_loss.py +305 -0
  8. Amphion/models/tta/ldm/audioldm_inference.py +193 -0
  9. Amphion/models/tts/base/tts_inferece.py +278 -0
  10. Amphion/models/tts/fastspeech2/fs2_dataset.py +424 -0
  11. Amphion/models/tts/naturalspeech2/diffusion.py +124 -0
  12. Amphion/models/tts/vits/vits_inference.py +163 -0
  13. Amphion/models/vocoders/flow/flow_vocoder_trainer.py +0 -0
  14. Amphion/models/vocoders/gan/gan_vocoder_dataset.py +205 -0
  15. Amphion/modules/anti_aliasing/__init__.py +8 -0
  16. Amphion/modules/encoder/condition_encoder.py +244 -0
  17. Amphion/modules/general/__init__.py +3 -0
  18. Amphion/modules/monotonic_align/__init__.py +21 -0
  19. Amphion/modules/neural_source_filter/__init__.py +6 -0
  20. Amphion/modules/transformer/Layers.py +137 -0
  21. Amphion/modules/wenet_extractor/cif/predictor.py +274 -0
  22. Amphion/modules/wenet_extractor/paraformer/search/beam_search.py +479 -0
  23. Amphion/modules/wenet_extractor/paraformer/search/ctc_prefix_score.py +377 -0
  24. Amphion/modules/wenet_extractor/squeezeformer/positionwise_feed_forward.py +88 -0
  25. Amphion/modules/wenet_extractor/transformer/decoder_layer.py +140 -0
  26. Amphion/modules/wenet_extractor/transformer/subsampling.py +257 -0
  27. Amphion/modules/wenet_extractor/utils/__init__.py +0 -0
  28. Amphion/preprocessors/cdmusiceval.py +174 -0
  29. Amphion/utils/data_utils.py +588 -0
  30. Amphion/utils/distribution.py +270 -0
  31. Amphion/utils/mel.py +280 -0
  32. Amphion/utils/prompt_preparer.py +68 -0
  33. __pycache__/model.cpython-310.pyc +0 -0
  34. conf/default.yaml +70 -0
  35. exp/bmi__fa-codec/2024_05_20--16_21_26.log +4 -0
  36. exp/bmi__fa-codec/2024_05_20--16_22_35.log +110 -0
  37. exp/bmi__fa-codec/2024_05_20--16_24_01.log +4 -0
  38. exp/bmi__fa-codec/amplified_signals/S06021_L0014_HA-output.wav +0 -0
  39. exp/bmi__fa-codec/amplified_signals/S06026_L0088_HA-output.wav +0 -0
  40. exp/bmi__fa-codec/amplified_signals/S06031_L0096_HA-output.wav +0 -0
  41. exp/bmi__fa-codec/amplified_signals/S06036_L0036_HA-output.wav +0 -0
  42. exp/bmi__fa-codec/amplified_signals/S06066_L0042_HA-output.wav +0 -0
  43. exp/bmi__fa-codec/amplified_signals/S06071_L0089_HA-output.wav +0 -0
  44. exp/bmi__fa-codec/amplified_signals/S06086_L0072_HA-output.wav +0 -0
  45. exp/bmi__fa-codec/amplified_signals/S06091_L0099_HA-output.wav +0 -0
  46. exp/bmi__fa-codec/amplified_signals/S06101_L0042_HA-output.wav +0 -0
  47. exp/bmi__fa-codec/amplified_signals/S06111_L0002_HA-output.wav +0 -0
  48. exp/bmi__fa-codec/amplified_signals/S06116_L0092_HA-output.wav +0 -0
  49. exp/bmi__fa-codec/amplified_signals/S06126_L0069_HA-output.wav +0 -0
  50. exp/bmi__fa-codec/amplified_signals/S06146_L0017_HA-output.wav +0 -0
Amphion/models/base/base_trainer.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import collections
7
+ import json
8
+ import os
9
+ import sys
10
+ import time
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ from torch.nn.parallel import DistributedDataParallel
15
+ from torch.utils.data import ConcatDataset, DataLoader
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ from models.base.base_sampler import BatchSampler
19
+ from utils.util import (
20
+ Logger,
21
+ remove_older_ckpt,
22
+ save_config,
23
+ set_all_random_seed,
24
+ ValueWindow,
25
+ )
26
+
27
+
28
class BaseTrainer(object):
    """Abstract training driver shared by Amphion models.

    Wires up logging, data loading, model construction, loss, optimization,
    checkpointing and the train/eval loop. Subclasses supply the concrete
    pieces via the ``build_*`` factory hooks and the ``*_step`` methods.
    """

    def __init__(self, args, cfg):
        self.args = args
        self.log_dir = args.log_dir
        self.cfg = cfg

        self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        # Only rank 0 owns the TensorBoard writer and the log file.
        if not cfg.train.ddp or args.local_rank == 0:
            self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
            self.logger = self.build_logger()
        # Sliding window over recent per-step durations (for s/step reporting).
        self.time_window = ValueWindow(50)

        self.step = 0
        self.epoch = -1  # -1 == "no epoch finished yet"; see train()
        self.max_epochs = self.cfg.train.epochs
        self.max_steps = self.cfg.train.max_steps

        # set random seed & init distributed training
        set_all_random_seed(self.cfg.train.random_seed)
        if cfg.train.ddp:
            dist.init_process_group(backend="nccl")

        # TTA models carry no singer identity; every other model type gets
        # a merged singer lookup table.
        if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
            self.singers = self.build_singers_lut()

        # setup data_loader
        self.data_loader = self.build_data_loader()

        # setup model & enable distributed training
        self.model = self.build_model()
        print(self.model)

        if isinstance(self.model, dict):
            for key, value in self.model.items():
                value.cuda(self.args.local_rank)
                # PQMF is a fixed filter bank with no trainable parameters,
                # so it is never wrapped in DDP.
                if key == "PQMF":
                    continue
                if cfg.train.ddp:
                    self.model[key] = DistributedDataParallel(
                        value, device_ids=[self.args.local_rank]
                    )
        else:
            self.model.cuda(self.args.local_rank)
            if cfg.train.ddp:
                self.model = DistributedDataParallel(
                    self.model, device_ids=[self.args.local_rank]
                )

        # create criterion
        self.criterion = self.build_criterion()
        if isinstance(self.criterion, dict):
            for key, value in self.criterion.items():
                self.criterion[key].cuda(args.local_rank)
        else:
            self.criterion.cuda(self.args.local_rank)

        # optimizer
        self.optimizer = self.build_optimizer()
        self.scheduler = self.build_scheduler()

        # save config file
        self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")

    def build_logger(self):
        """Create a file-backed logger under the checkpoint directory."""
        log_file = os.path.join(self.checkpoint_dir, "train.log")
        logger = Logger(log_file, level=self.args.log_level).logger

        return logger

    def build_dataset(self):
        """Return a ``(DatasetClass, CollatorClass)`` pair; subclass hook."""
        raise NotImplementedError

    def build_data_loader(self):
        """Build train/valid DataLoaders over the concatenation of all
        configured datasets. Returns ``{"train": ..., "valid": ...}``.
        """
        Dataset, Collator = self.build_dataset()
        # build dataset instance for each dataset and combine them by ConcatDataset
        datasets_list = []
        for dataset in self.cfg.dataset:
            subdataset = Dataset(self.cfg, dataset, is_valid=False)
            datasets_list.append(subdataset)
        train_dataset = ConcatDataset(datasets_list)

        train_collate = Collator(self.cfg)
        # TODO: multi-GPU training
        if self.cfg.train.ddp:
            raise NotImplementedError("DDP is not supported yet.")

        # sampler will provide indices to batch_sampler, which will perform batching and yield batch indices
        batch_sampler = BatchSampler(
            cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list
        )

        # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
        train_loader = DataLoader(
            train_dataset,
            collate_fn=train_collate,
            num_workers=self.args.num_workers,
            batch_sampler=batch_sampler,
            pin_memory=False,
        )
        if not self.cfg.train.ddp or self.args.local_rank == 0:
            datasets_list = []
            for dataset in self.cfg.dataset:
                subdataset = Dataset(self.cfg, dataset, is_valid=True)
                datasets_list.append(subdataset)
            valid_dataset = ConcatDataset(datasets_list)
            valid_collate = Collator(self.cfg)
            batch_sampler = BatchSampler(
                cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list
            )
            valid_loader = DataLoader(
                valid_dataset,
                collate_fn=valid_collate,
                num_workers=1,
                batch_sampler=batch_sampler,
            )
        else:
            raise NotImplementedError("DDP is not supported yet.")
            # valid_loader = None
        data_loader = {"train": train_loader, "valid": valid_loader}
        return data_loader

    def build_singers_lut(self):
        """Merge the per-dataset singer->id tables into one LUT, assigning
        fresh ids to unseen singers, and persist it under ``log_dir``.
        """
        # combine singers
        if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)):
            singers = collections.OrderedDict()
        else:
            with open(
                os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r"
            ) as singer_file:
                singers = json.load(singer_file)
        singer_count = len(singers)
        for dataset in self.cfg.dataset:
            singer_lut_path = os.path.join(
                self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
            )
            # NOTE(review): `as singer_lut_path` rebinds the path variable to
            # the file handle; harmless because the path is not reused after.
            with open(singer_lut_path, "r") as singer_lut_path:
                singer_lut = json.load(singer_lut_path)
            for singer in singer_lut.keys():
                if singer not in singers:
                    singers[singer] = singer_count
                    singer_count += 1
        with open(
            os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w"
        ) as singer_file:
            json.dump(singers, singer_file, indent=4, ensure_ascii=False)
        print(
            "singers have been dumped to {}".format(
                os.path.join(self.log_dir, self.cfg.preprocess.spk2id)
            )
        )
        return singers

    def build_model(self):
        """Return the model (a Module or a dict of Modules); subclass hook."""
        raise NotImplementedError()

    def build_optimizer(self):
        raise NotImplementedError

    def build_scheduler(self):
        raise NotImplementedError()

    def build_criterion(self):
        raise NotImplementedError

    def get_state_dict(self):
        """Return everything that should go into a checkpoint; subclass hook."""
        raise NotImplementedError

    def save_config_file(self):
        """Persist the (possibly updated) config next to the checkpoints."""
        save_config(self.config_save_path, self.cfg)

    # TODO, save without module.
    def save_checkpoint(self, state_dict, saved_model_path):
        torch.save(state_dict, saved_model_path)

    def load_checkpoint(self):
        """Load the newest checkpoint named by the ``checkpoint`` index file."""
        checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint")
        assert os.path.exists(checkpoint_path)
        # Last line of the index file names the most recent checkpoint.
        # NOTE(review): the file handle here is never closed explicitly.
        checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
        model_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
        assert os.path.exists(model_path)
        if not self.cfg.train.ddp or self.args.local_rank == 0:
            self.logger.info(f"Re(store) from {model_path}")
        # Load to CPU first; tensors are moved to the right device by callers.
        checkpoint = torch.load(model_path, map_location="cpu")
        return checkpoint

    def load_model(self, checkpoint):
        """Restore model/optimizer state from a checkpoint dict; subclass hook."""
        raise NotImplementedError

    def restore(self):
        """Resume training state from the latest checkpoint."""
        checkpoint = self.load_checkpoint()
        self.load_model(checkpoint)

    def train_step(self, data):
        raise NotImplementedError(
            f"Need to implement function {sys._getframe().f_code.co_name} in "
            f"your sub-class of {self.__class__.__name__}. "
        )

    @torch.no_grad()
    def eval_step(self):
        raise NotImplementedError(
            f"Need to implement function {sys._getframe().f_code.co_name} in "
            f"your sub-class of {self.__class__.__name__}. "
        )

    def write_summary(self, losses, stats):
        raise NotImplementedError(
            f"Need to implement function {sys._getframe().f_code.co_name} in "
            f"your sub-class of {self.__class__.__name__}. "
        )

    def write_valid_summary(self, losses, stats):
        raise NotImplementedError(
            f"Need to implement function {sys._getframe().f_code.co_name} in "
            f"your sub-class of {self.__class__.__name__}. "
        )

    def echo_log(self, losses, mode="Training"):
        """Log a single line summarizing the current step's losses."""
        message = [
            "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
                mode, self.epoch + 1, self.step, self.time_window.average
            )
        ]

        # Losses may be flat {name: value} or nested {group: {name: value}}.
        for key in sorted(losses.keys()):
            if isinstance(losses[key], dict):
                for k, v in losses[key].items():
                    message.append(
                        str(k).split("/")[-1] + "=" + str(round(float(v), 5))
                    )
            else:
                message.append(
                    str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
                )
        self.logger.info(", ".join(message))

    def eval_epoch(self):
        """Run validation over the whole valid loader.

        Returns the per-key averaged losses and the stats of the last batch.
        """
        self.logger.info("Validation...")
        valid_losses = {}
        for i, batch_data in enumerate(self.data_loader["valid"]):
            for k, v in batch_data.items():
                if isinstance(v, torch.Tensor):
                    batch_data[k] = v.cuda()
            valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i)
            for key in valid_loss:
                if key not in valid_losses:
                    valid_losses[key] = 0
                valid_losses[key] += valid_loss[key]

        # Add mel and audio to the Tensorboard
        # Average loss
        # NOTE(review): `i` and `valid_stats` are unbound if the loader is empty.
        for key in valid_losses:
            valid_losses[key] /= i + 1
        self.echo_log(valid_losses, "Valid")
        return valid_losses, valid_stats

    def train_epoch(self):
        """One pass over the training loader: step, log, checkpoint, validate."""
        for i, batch_data in enumerate(self.data_loader["train"]):
            start_time = time.time()
            # Put the data to cuda device
            for k, v in batch_data.items():
                if isinstance(v, torch.Tensor):
                    batch_data[k] = v.cuda(self.args.local_rank)

            # Training step
            train_losses, train_stats, total_loss = self.train_step(batch_data)
            self.time_window.append(time.time() - start_time)

            # Logging / checkpointing / validation happen on rank 0 only.
            if self.args.local_rank == 0 or not self.cfg.train.ddp:
                if self.step % self.args.stdout_interval == 0:
                    self.echo_log(train_losses, "Training")

                if self.step % self.cfg.train.save_summary_steps == 0:
                    self.logger.info(f"Save summary as step {self.step}")
                    self.write_summary(train_losses, train_stats)

                if (
                    self.step % self.cfg.train.save_checkpoints_steps == 0
                    and self.step != 0
                ):
                    saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
                        self.step, total_loss
                    )
                    saved_model_path = os.path.join(
                        self.checkpoint_dir, saved_model_name
                    )
                    saved_state_dict = self.get_state_dict()
                    self.save_checkpoint(saved_state_dict, saved_model_path)
                    self.save_config_file()
                    # keep max n models
                    remove_older_ckpt(
                        saved_model_name,
                        self.checkpoint_dir,
                        max_to_keep=self.cfg.train.keep_checkpoint_max,
                    )

                if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
                    # Switch to eval mode for validation, then back to train.
                    if isinstance(self.model, dict):
                        for key in self.model.keys():
                            self.model[key].eval()
                    else:
                        self.model.eval()
                    # Evaluate one epoch and get average loss
                    valid_losses, valid_stats = self.eval_epoch()
                    if isinstance(self.model, dict):
                        for key in self.model.keys():
                            self.model[key].train()
                    else:
                        self.model.train()
                    # Write validation losses to summary.
                    self.write_valid_summary(valid_losses, valid_stats)
            self.step += 1

    def train(self):
        """Main loop: run epochs until ``max_epochs``, or until the global
        step exceeds ``max_steps``.
        """
        # self.epoch starts at -1 (or a restored value after restore()),
        # so a fresh run starts at epoch 0.
        for epoch in range(max(0, self.epoch), self.max_epochs):
            self.train_epoch()
            self.epoch += 1
            # Checked only at epoch boundaries, so the last epoch may overshoot.
            if self.step > self.max_steps:
                self.logger.info("Training finished!")
                break
Amphion/models/codec/ns3_codec/__pycache__/melspec.cpython-310.pyc ADDED
Binary file (2.81 kB). View file
 
Amphion/models/codec/ns3_codec/__pycache__/transformer.cpython-310.pyc ADDED
Binary file (5.42 kB). View file
 
Amphion/models/codec/ns3_codec/alias_free_torch/filter.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+
8
if hasattr(torch, "sinc"):
    # Recent torch ships a native normalized sinc.
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    def sinc(x: torch.Tensor):
        """Normalized sinc: sin(pi*x) / (pi*x), with sinc(0) == 1.

        __Warning__: unlike julius.sinc, the input is multiplied by `pi`.
        """
        one = torch.tensor(1.0, device=x.device, dtype=x.dtype)
        return torch.where(x == 0, one, torch.sin(math.pi * x) / math.pi / x)


# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # return filter [1,1,kernel_size]
    """Design a Kaiser-windowed sinc low-pass FIR kernel.

    `cutoff` and `half_width` are in normalized frequency (Nyquist = 0.5).
    Returns a tensor of shape [1, 1, kernel_size] whose taps sum to 1.
    """
    is_even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # Kaiser design: estimate stop-band attenuation A from the transition
    # band width, then derive the window's beta parameter.
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    # Sample times centered on zero; even kernels are offset by half a sample.
    if is_even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size

    taps = (
        torch.zeros_like(time)
        if cutoff == 0
        else 2 * cutoff * window * sinc(2 * cutoff * time)
    )
    # Normalize to unit DC gain, otherwise a small leakage of the constant
    # component of the input signal remains.
    taps /= taps.sum()

    return taps.view(1, 1, kernel_size)
59
+
60
+
61
class LowPassFilter1d(nn.Module):
    """Channel-wise anti-aliasing low-pass FIR filter for 1D signals.

    The fixed kernel is designed by ``kaiser_sinc_filter1d`` and applied as a
    depthwise (grouped) convolution, optionally with "same"-style padding.
    """

    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        # NOTE(review): `cutoff < -0.0` equals `< 0.0`, so cutoff == 0 passes
        # validation even though the message says it must be larger than zero.
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        # Asymmetric padding for even kernels keeps the output length at T / stride.
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        # Buffer (not Parameter): fixed taps that still follow .to()/.cuda().
        self.register_buffer("filter", filter)

    # input [B, C, T]
    def forward(self, x):
        _, C, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        # Depthwise conv: the same 1D kernel is applied independently per channel.
        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)

        return out
Amphion/models/svc/diffusion/diffusion_trainer.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from diffusers import DDPMScheduler
8
+
9
+ from models.svc.base import SVCTrainer
10
+ from modules.encoder.condition_encoder import ConditionEncoder
11
+ from .diffusion_wrapper import DiffusionWrapper
12
+
13
+
14
class DiffusionTrainer(SVCTrainer):
    r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
    implements ``_build_model`` and ``_forward_step`` methods.
    """

    def __init__(self, args=None, cfg=None):
        SVCTrainer.__init__(self, args, cfg)

        # Only for SVC tasks using diffusion
        # DDPM noise schedule is configured entirely from the experiment config.
        self.noise_scheduler = DDPMScheduler(
            **self.cfg.model.diffusion.scheduler_settings,
        )
        self.diffusion_timesteps = (
            self.cfg.model.diffusion.scheduler_settings.num_train_timesteps
        )

    ### Following are methods only for diffusion models ###
    def _build_model(self):
        r"""Build the model for training. This function is called in ``__init__`` function."""

        # TODO: sort out the config
        # The condition encoder needs the preprocessing F0 range.
        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
        self.acoustic_mapper = DiffusionWrapper(self.cfg)
        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

        num_of_params_encoder = self.count_parameters(self.condition_encoder)
        num_of_params_am = self.count_parameters(self.acoustic_mapper)
        num_of_params = num_of_params_encoder + num_of_params_am
        log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format(
            num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6
        )
        self.logger.info(log)

        return model

    def count_parameters(self, model):
        """Total number of parameters in ``model`` (a Module or dict of Modules)."""
        model_param = 0.0
        if isinstance(model, dict):
            for key, value in model.items():
                model_param += sum(p.numel() for p in model[key].parameters())
        else:
            model_param = sum(p.numel() for p in model.parameters())
        return model_param

    def _check_nan(self, batch, loss, y_pred, y_gt):
        # Dump the entire batch before delegating, to help locate the bad sample.
        if torch.any(torch.isnan(loss)):
            for k, v in batch.items():
                self.logger.info(k)
                self.logger.info(v)

        super()._check_nan(loss, y_pred, y_gt)

    def _forward_step(self, batch):
        r"""Forward step for training and inference. This function is called
        in ``_train_step`` & ``_test_step`` function.

        Standard DDPM training: sample Gaussian noise and a random timestep
        per item, noise the mel target, and train the network to predict the
        injected noise.
        """
        device = self.accelerator.device

        if self.online_features_extraction:
            # On-the-fly features extraction
            batch = self._extract_svc_features(batch)

        # To debug
        # for k, v in batch.items():
        #     print(k, v.shape, v)
        # exit()

        mel_input = batch["mel"]
        noise = torch.randn_like(mel_input, device=device, dtype=torch.float32)
        batch_size = mel_input.size(0)
        timesteps = torch.randint(
            0,
            self.diffusion_timesteps,
            (batch_size,),
            device=device,
            dtype=torch.long,
        )

        noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps)
        conditioner = self.condition_encoder(batch)

        y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner)

        # Loss against the injected noise, masked to valid frames.
        loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"])
        self._check_nan(batch, loss, y_pred, noise)

        return loss
Amphion/models/tta/autoencoder/__init__.py ADDED
File without changes
Amphion/models/tta/autoencoder/autoencoder_loss.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import functools
9
+ import torch.nn.functional as F
10
+
11
+
12
def hinge_d_loss(logits_real, logits_fake):
    """Hinge GAN discriminator loss: mean of relu(1 - real) and relu(1 + fake)."""
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
17
+
18
+
19
def vanilla_d_loss(logits_real, logits_fake):
    """Non-saturating (vanilla) GAN discriminator loss via softplus terms."""
    real_term = torch.mean(F.softplus(-logits_real))
    fake_term = torch.mean(F.softplus(logits_fake))
    return 0.5 * (real_term + fake_term)
24
+
25
+
26
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return `weight`, but substitute `value` until `global_step` reaches `threshold`."""
    return value if global_step < threshold else weight
30
+
31
+
32
class ActNorm(nn.Module):
    """Activation normalization (Glow-style): a per-channel affine transform
    whose shift/scale are data-initialized on the first training batch so the
    outputs start out roughly zero-mean / unit-variance.
    """

    def __init__(
        self, num_features, logdet=False, affine=True, allow_reverse_init=False
    ):
        # Only the affine variant is implemented.
        assert affine
        super().__init__()
        self.logdet = logdet
        # Shift and scale, broadcastable over (N, C, H, W).
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        # 0/1 flag kept as a buffer so it persists through checkpoints.
        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        """Data-dependent init: set loc/scale from per-channel batch statistics."""
        with torch.no_grad():
            # Collapse everything but the channel dim: (C, N*H*W).
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        # Accept 2D (N, C) input by faking H = W = 1.
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        # Lazy data-dependent initialization on the first training forward.
        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            # log|det J| of the per-pixel affine map, replicated per batch item.
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        """Inverse transform: recover the input from a normalized output."""
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
116
+
117
+
118
def weights_init(m):
    """DCGAN-style initializer: N(0, 0.02) for conv weights, N(1, 0.02) for
    norm-layer weights with zero bias. Intended for use with ``Module.apply``.
    """
    name = type(m).__name__
    if "Conv" in name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
125
+
126
+
127
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int) -- the number of filters in the last conv layer
            n_layers (int) -- the number of conv layers in the discriminator
            norm_layer -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        # BatchNorm by default; ActNorm when requested (e.g. for small batches).
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if (
            type(norm_layer) == functools.partial
        ):  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        # 4x4 kernels with padding 1, the classic PatchGAN configuration.
        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)  # filter count doubles per layer, capped at 8x
            sequence += [
                nn.Conv2d(
                    ndf * nf_mult_prev,
                    ndf * nf_mult,
                    kernel_size=kw,
                    stride=2,
                    padding=padw,
                    bias=use_bias,
                ),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One extra stride-1 conv block before the prediction head.
        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(
                ndf * nf_mult_prev,
                ndf * nf_mult,
                kernel_size=kw,
                stride=1,
                padding=padw,
                bias=use_bias,
            ),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
199
+
200
+
201
class AutoencoderLossWithDiscriminator(nn.Module):
    """VAE-GAN loss (LDM-style): L1 reconstruction with a learnable log-variance,
    a KL term against the encoder posterior, and an adversarial term from a
    PatchGAN discriminator with an adaptively balanced weight.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.kl_weight = cfg.kl_weight
        # Learnable global log-variance of the reconstruction likelihood.
        self.logvar = nn.Parameter(torch.ones(size=()) * cfg.logvar_init)

        self.discriminator = NLayerDiscriminator(
            input_nc=cfg.disc_in_channels,
            n_layers=cfg.disc_num_layers,
            use_actnorm=cfg.use_actnorm,
        ).apply(weights_init)

        # The adversarial term is disabled until this global step.
        self.discriminator_iter_start = cfg.disc_start
        self.discriminator_weight = cfg.disc_weight
        self.disc_factor = cfg.disc_factor
        self.disc_loss = hinge_d_loss

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        """Balance generator vs. reconstruction gradients at ``last_layer``
        by the ratio of their gradient norms, then clamp and rescale.
        """
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(
            d_weight, self.cfg.min_adapt_d_weight, self.cfg.max_adapt_d_weight
        ).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        inputs,
        reconstructions,
        posteriors,
        optimizer_idx,
        global_step,
        last_layer,
        split="train",
        weights=None,
    ):
        """Compute the generator losses (optimizer_idx == 0) or the
        discriminator loss (optimizer_idx == 1) as a dict of tensors.

        NOTE(review): `posteriors` is assumed to expose a ``.kl()`` method
        (a diagonal-Gaussian posterior) — confirm against the autoencoder.
        Any other `optimizer_idx` value falls through and returns None.
        """
        rec_loss = torch.abs(
            inputs.contiguous() - reconstructions.contiguous()
        )  # l1 loss
        # Gaussian NLL with a single learnable log-variance.
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        # weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        weighted_nll_loss = torch.mean(weighted_nll_loss)
        # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        nll_loss = torch.mean(nll_loss)
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        # ? kl_loss = torch.mean(kl_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # Generator update: fool the discriminator on reconstructions.
            logits_fake = self.discriminator(reconstructions.contiguous())
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                except RuntimeError:
                    # autograd.grad fails in eval/no-grad mode; only tolerated there.
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            # Zero out the adversarial term before discriminator warm-up starts.
            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )

            total_loss = (
                weighted_nll_loss
                + self.kl_weight * kl_loss
                + d_weight * disc_factor * g_loss
            )

            return {
                "loss": total_loss,
                "kl_loss": kl_loss,
                "rec_loss": rec_loss.mean(),
                "nll_loss": nll_loss,
                "g_loss": g_loss,
                "d_weight": d_weight,
                "disc_factor": torch.tensor(disc_factor),
            }

        if optimizer_idx == 1:
            # Discriminator update: both sides detached from the generator graph.
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            return {
                "d_loss": d_loss,
                "logits_real": logits_real.mean(),
                "logits_fake": logits_fake.mean(),
            }
Amphion/models/tta/ldm/audioldm_inference.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import time
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+ import torch.nn as nn
12
+ from collections import OrderedDict
13
+ import json
14
+
15
+ from models.tta.autoencoder.autoencoder import AutoencoderKL
16
+ from models.tta.ldm.inference_utils.vocoder import Generator
17
+ from models.tta.ldm.audioldm import AudioLDM
18
+ from transformers import T5EncoderModel, AutoTokenizer
19
+ from diffusers import PNDMScheduler
20
+
21
+ import matplotlib.pyplot as plt
22
+ from scipy.io.wavfile import write
23
+
24
+
25
class AttrDict(dict):
    """Dict whose entries are also reachable as attributes (h.key == h["key"])."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias the attribute namespace to the dict itself so that reads and
        # writes through either interface stay in sync.
        self.__dict__ = self
29
+
30
+
31
class AudioLDMInference:
    """Text-to-audio inference pipeline for AudioLDM.

    Wires together a frozen T5 text encoder (prompt conditioning), the
    AudioLDM latent diffusion model, a frozen KL autoencoder (latents -> mel
    spectrogram), and a HiFi-GAN-style vocoder (mel -> waveform).
    Outputs are written under ``args.output_dir`` as ``mel/<text>.png`` and
    ``wav/<text>.wav``.
    """

    def __init__(self, args, cfg):
        self.cfg = cfg
        self.args = args

        # Frozen sub-modules first; the diffusion model itself is built and
        # restored from args.checkpoint_path afterwards.
        self.build_autoencoderkl()
        self.build_textencoder()

        self.model = self.build_model()
        self.load_state_dict()

        self.build_vocoder()

        self.out_path = self.args.output_dir
        self.out_mel_path = os.path.join(self.out_path, "mel")
        self.out_wav_path = os.path.join(self.out_path, "wav")
        os.makedirs(self.out_mel_path, exist_ok=True)
        os.makedirs(self.out_wav_path, exist_ok=True)

    def build_autoencoderkl(self):
        """Load the pretrained KL autoencoder and freeze it in eval mode."""
        self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl)
        self.autoencoder_path = self.cfg.model.autoencoder_path
        checkpoint = torch.load(self.autoencoder_path, map_location="cpu")
        self.autoencoderkl.load_state_dict(checkpoint["model"])
        self.autoencoderkl.cuda(self.args.local_rank)
        self.autoencoderkl.requires_grad_(requires_grad=False)
        self.autoencoderkl.eval()

    def build_textencoder(self):
        """Load the frozen T5 tokenizer/encoder used for prompt embedding."""
        self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
        self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
        self.text_encoder.cuda(self.args.local_rank)
        self.text_encoder.requires_grad_(requires_grad=False)
        self.text_encoder.eval()

    def build_vocoder(self):
        """Build the mel-to-waveform vocoder from its JSON config and checkpoint."""
        config_file = os.path.join(self.args.vocoder_config_path)
        with open(config_file) as f:
            data = f.read()
        json_config = json.loads(data)
        h = AttrDict(json_config)
        self.vocoder = Generator(h).to(self.args.local_rank)
        checkpoint_dict = torch.load(
            self.args.vocoder_path, map_location=self.args.local_rank
        )
        self.vocoder.load_state_dict(checkpoint_dict["generator"])

    def build_model(self):
        """Instantiate the latent diffusion model (weights loaded separately)."""
        self.model = AudioLDM(self.cfg.model.audioldm)
        return self.model

    def load_state_dict(self):
        """Restore diffusion-model weights from args.checkpoint_path, move to GPU."""
        self.checkpoint_path = self.args.checkpoint_path
        checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
        self.model.load_state_dict(checkpoint["model"])
        self.model.cuda(self.args.local_rank)

    def get_text_embedding(self):
        """Encode the prompt and an empty (unconditional) prompt with T5.

        Returns a tensor whose first batch row is the unconditional embedding
        and whose second is the conditional one, as expected by
        classifier-free guidance.
        """
        text = self.args.text

        prompt = [text]

        text_input = self.tokenizer(
            prompt,
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            padding="do_not_pad",
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(
            text_input.input_ids.to(self.args.local_rank)
        )[0]

        # Pad the empty prompt to the same token length so the two batches
        # can be concatenated.
        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = self.text_encoder(
            uncond_input.input_ids.to(self.args.local_rank)
        )[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def inference(self):
        """Run the PNDM sampling loop and write the mel PNG and 16 kHz wav."""
        text_embeddings = self.get_text_embedding()
        print(text_embeddings.shape)

        num_steps = self.args.num_steps
        guidance_scale = self.args.guidance_scale

        noise_scheduler = PNDMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            skip_prk_steps=True,
            set_alpha_to_one=False,
            steps_offset=1,
            prediction_type="epsilon",
        )

        noise_scheduler.set_timesteps(num_steps)

        # Latent spatial size: an 80x624 mel divided by the autoencoder's
        # total downsampling factor 2**(len(ch_mult)-1).
        latents = torch.randn(
            (
                1,
                self.cfg.model.autoencoderkl.z_channels,
                80 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)),
                624 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)),
            )
        ).to(self.args.local_rank)

        self.model.eval()
        for t in tqdm(noise_scheduler.timesteps):
            t = t.to(self.args.local_rank)

            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)

            latent_model_input = noise_scheduler.scale_model_input(
                latent_model_input, timestep=t
            )
            # print(latent_model_input.shape)

            # predict the noise residual
            with torch.no_grad():
                noise_pred = self.model(
                    latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings
                )

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
            # print(latents.shape)

        latents_out = latents
        print(latents_out.shape)

        # Decode latents back to a mel spectrogram with the frozen autoencoder.
        with torch.no_grad():
            mel_out = self.autoencoderkl.decode(latents_out)
        print(mel_out.shape)

        melspec = mel_out[0, 0].cpu().detach().numpy()
        plt.imsave(os.path.join(self.out_mel_path, self.args.text + ".png"), melspec)

        self.vocoder.eval()
        self.vocoder.remove_weight_norm()
        with torch.no_grad():
            melspec = np.expand_dims(melspec, 0)
            melspec = torch.FloatTensor(melspec).to(self.args.local_rank)

            y = self.vocoder(melspec)
            # Scale [-1, 1] float audio to 16-bit PCM range.
            audio = y.squeeze()
            audio = audio * 32768.0
            audio = audio.cpu().numpy().astype("int16")

        write(os.path.join(self.out_wav_path, self.args.text + ".wav"), 16000, audio)
Amphion/models/tts/base/tts_inferece.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import torch
8
+ import time
9
+ import accelerate
10
+ import random
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+ from accelerate.logging import get_logger
14
+ from torch.utils.data import DataLoader
15
+ from safetensors.torch import load_file
16
+
17
+
18
+ from abc import abstractmethod
19
+ from pathlib import Path
20
+ from utils.io import save_audio
21
+ from utils.util import load_config
22
+ from models.vocoders.vocoder_inference import synthesis
23
+
24
+
25
class TTSInference(object):
    """Base class for TTS inference.

    Handles the shared plumbing: accelerate setup, logging, random seeding,
    optional batch dataloader construction, model building, and checkpoint
    restoration. Subclasses implement ``_build_model``,
    ``_build_test_dataset``, ``_inference_each_batch`` and
    ``inference_for_single_utterance``.
    """

    def __init__(self, args=None, cfg=None):
        super().__init__()

        start = time.monotonic_ns()
        self.args = args
        self.cfg = cfg
        # "single" (one utterance from CLI text) or "batch" (a test set).
        self.infer_type = args.mode

        # get exp_dir
        if self.args.acoustics_dir is not None:
            self.exp_dir = self.args.acoustics_dir
        elif self.args.checkpoint_path is not None:
            self.exp_dir = os.path.dirname(os.path.dirname(self.args.checkpoint_path))

        # Init accelerator
        self.accelerator = accelerate.Accelerator()
        self.accelerator.wait_for_everyone()
        self.device = self.accelerator.device

        # Get logger
        with self.accelerator.main_process_first():
            self.logger = get_logger("inference", log_level=args.log_level)

        # Log some info
        self.logger.info("=" * 56)
        self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
        self.logger.info("=" * 56)
        self.logger.info("\n")

        self.acoustic_model_dir = args.acoustics_dir
        self.logger.debug(f"Acoustic model dir: {args.acoustics_dir}")

        if args.vocoder_dir is not None:
            self.vocoder_dir = args.vocoder_dir
            self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")

        os.makedirs(args.output_dir, exist_ok=True)

        # Set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            self.logger.debug(
                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
            )
            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # Setup data loader (batch mode only)
        if self.infer_type == "batch":
            with self.accelerator.main_process_first():
                self.logger.info("Building dataset...")
                start = time.monotonic_ns()
                self.test_dataloader = self._build_test_dataloader()
                end = time.monotonic_ns()
                self.logger.info(
                    f"Building dataset done in {(end - start) / 1e6:.2f}ms"
                )

        # Build model
        with self.accelerator.main_process_first():
            self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.model = self._build_model()
            end = time.monotonic_ns()
            self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")

        # Init with accelerate
        self.logger.info("Initializing accelerate...")
        start = time.monotonic_ns()
        self.accelerator = accelerate.Accelerator()
        self.model = self.accelerator.prepare(self.model)
        if self.infer_type == "batch":
            self.test_dataloader = self.accelerator.prepare(self.test_dataloader)
        end = time.monotonic_ns()
        self.accelerator.wait_for_everyone()
        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")

        with self.accelerator.main_process_first():
            self.logger.info("Loading checkpoint...")
            start = time.monotonic_ns()
            if args.acoustics_dir is not None:
                self._load_model(
                    checkpoint_dir=os.path.join(args.acoustics_dir, "checkpoint")
                )
            elif args.checkpoint_path is not None:
                self._load_model(checkpoint_path=args.checkpoint_path)
            else:
                # NOTE(review): this only prints and continues with random
                # weights — confirm whether it should raise instead.
                print("Either checkpoint dir or checkpoint path should be provided.")

            end = time.monotonic_ns()
            self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")

        self.model.eval()
        self.accelerator.wait_for_everyone()

    def _build_test_dataset(self):
        # Subclass hook: return (DatasetClass, CollatorClass) for batch mode.
        pass

    def _build_model(self):
        # Subclass hook: return the acoustic model instance.
        pass

    # TODO: LEGACY CODE
    def _build_test_dataloader(self):
        """Build the batch-mode DataLoader from the subclass dataset/collator."""
        datasets, collate = self._build_test_dataset()
        self.test_dataset = datasets(self.args, self.cfg)
        self.test_collate = collate(self.cfg)
        # Never ask for more items per batch than the test set contains.
        self.test_batch_size = min(
            self.cfg.train.batch_size, len(self.test_dataset.metadata)
        )
        test_dataloader = DataLoader(
            self.test_dataset,
            collate_fn=self.test_collate,
            num_workers=1,
            batch_size=self.test_batch_size,
            shuffle=False,
        )
        return test_dataloader

    def _load_model(
        self,
        checkpoint_dir: str = None,
        checkpoint_path: str = None,
        old_mode: bool = False,
    ):
        r"""Load model from checkpoint. If checkpoint_path is None, it will
        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
        None, it will load the checkpoint specified by checkpoint_path. **Only use this
        method after** ``accelerator.prepare()``.
        """

        if checkpoint_path is None:
            assert checkpoint_dir is not None
            # Load the latest accelerator state dicts
            ls = [
                str(i) for i in Path(checkpoint_dir).glob("*") if not "audio" in str(i)
            ]
            # Checkpoint dirs embed the step number; sort newest-first by it.
            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
            checkpoint_path = ls[0]

        if (
            Path(os.path.join(checkpoint_path, "model.safetensors")).exists()
            and accelerate.__version__ < "0.25"
        ):
            # Older accelerate cannot restore safetensors states itself; read
            # the weights file directly.
            self.model.load_state_dict(
                load_file(os.path.join(checkpoint_path, "model.safetensors")),
                strict=False,
            )
        else:
            self.accelerator.load_state(str(checkpoint_path))
        return str(checkpoint_path)

    def inference(self):
        """Dispatch to single-utterance or batch synthesis and write wavs."""
        if self.infer_type == "single":
            out_dir = os.path.join(self.args.output_dir, "single")
            os.makedirs(out_dir, exist_ok=True)

            pred_audio = self.inference_for_single_utterance()
            save_path = os.path.join(out_dir, "test_pred.wav")
            save_audio(save_path, pred_audio, self.cfg.preprocess.sample_rate)

        elif self.infer_type == "batch":
            out_dir = os.path.join(self.args.output_dir, "batch")
            os.makedirs(out_dir, exist_ok=True)

            pred_audio_list = self.inference_for_batches()
            for it, wav in zip(self.test_dataset.metadata, pred_audio_list):
                uid = it["Uid"]
                save_audio(
                    os.path.join(out_dir, f"{uid}.wav"),
                    wav.numpy(),
                    self.cfg.preprocess.sample_rate,
                    add_silence=True,
                    turn_up=True,
                )
                # NOTE(review): inference_for_batches writes the .pt files to
                # args.output_dir, not out_dir — this cleanup may be a no-op;
                # verify the intended path.
                tmp_file = os.path.join(out_dir, f"{uid}.pt")
                if os.path.exists(tmp_file):
                    os.remove(tmp_file)
            print("Saved to: ", out_dir)

    @torch.inference_mode()
    def inference_for_batches(self):
        """Synthesize mels for every test batch, then vocode them to audio.

        NOTE(review): this method has no return statement, yet ``inference``
        iterates over its result — confirm whether it should return the
        synthesized waveforms (``res``).
        """
        y_pred = []
        for i, batch in tqdm(enumerate(self.test_dataloader)):
            y_pred, mel_lens, _ = self._inference_each_batch(batch)
            y_ls = y_pred.chunk(self.test_batch_size)
            tgt_ls = mel_lens.chunk(self.test_batch_size)
            j = 0
            for it, l in zip(y_ls, tgt_ls):
                l = l.item()
                # Trim padding frames and move to CPU before dumping to disk.
                it = it.squeeze(0)[:l].detach().cpu()

                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
                torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
                j += 1

        vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
        res = synthesis(
            cfg=vocoder_cfg,
            vocoder_weight_file=vocoder_ckpt,
            n_samples=None,
            pred=[
                torch.load(
                    os.path.join(self.args.output_dir, "{}.pt".format(item["Uid"]))
                ).numpy()
                for item in self.test_dataset.metadata
            ],
        )
        for it, wav in zip(self.test_dataset.metadata, res):
            uid = it["Uid"]
            save_audio(
                os.path.join(self.args.output_dir, f"{uid}.wav"),
                wav.numpy(),
                22050,
                add_silence=True,
                turn_up=True,
            )

    @abstractmethod
    @torch.inference_mode()
    def _inference_each_batch(self, batch_data):
        # Subclass hook: return (mels, mel_lens, ...) for one batch.
        pass

    def inference_for_single_utterance(self, text):
        # Subclass hook: synthesize one utterance and return the waveform.
        pass

    def synthesis_by_vocoder(self, pred):
        """Vocode predicted mels with the vocoder referenced on this instance.

        NOTE(review): relies on ``self.vocoder_cfg`` and
        ``self.checkpoint_dir_vocoder``, which are not assigned in this class
        — confirm subclasses set them before calling.
        """
        audios_pred = synthesis(
            self.vocoder_cfg,
            self.checkpoint_dir_vocoder,
            len(pred),
            pred,
        )

        return audios_pred

    @staticmethod
    def _parse_vocoder(vocoder_dir):
        r"""Parse vocoder config"""
        vocoder_dir = os.path.abspath(vocoder_dir)
        ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
        # Checkpoint files are named by step number; pick the newest.
        ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
        ckpt_path = str(ckpt_list[0])
        vocoder_cfg = load_config(
            os.path.join(vocoder_dir, "args.json"), lowercase=True
        )
        return vocoder_cfg, ckpt_path

    def _set_random_seed(self, seed):
        """Set random seed for all possible random modules."""
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)
Amphion/models/tts/fastspeech2/fs2_dataset.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import torch
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ from utils.data_utils import *
10
+ from models.base.base_dataset import (
11
+ BaseOfflineCollator,
12
+ BaseOfflineDataset,
13
+ BaseTestDataset,
14
+ BaseTestCollator,
15
+ )
16
+ from text import text_to_sequence
17
+
18
+
19
class FS2Dataset(BaseOfflineDataset):
    """FastSpeech 2 training dataset.

    Extends BaseOfflineDataset with per-utterance durations, pitch and energy
    (frame- or phone-level), phone-label text, and an optional speaker map.
    Utterances whose mel or duration files are missing are dropped.
    """

    def __init__(self, cfg, dataset, is_valid=False):
        BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
        self.batch_size = cfg.train.batch_size
        cfg = cfg.preprocess
        # utt2duration: "<dataset>_<uid>" -> path of the duration .npy file
        self.utt2duration_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            self.utt2duration_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.duration_dir,
                uid + ".npy",
            )
        self.utt2dur = self.read_duration()

        if cfg.use_frame_energy:
            self.frame_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                # BUGFIX: was `self.preprocess.utt2spk` — no such attribute;
                # every sibling branch uses `self.utt2spk`.
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_energy:
            self.phone_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        if cfg.use_frame_pitch:
            self.frame_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.pitch_dir,
                # BUGFIX: was `use_log_scale=cfg.energy_extract_mode`, which
                # is not a log-scale flag; mirror the phone-pitch branch.
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        elif cfg.use_phone_pitch:
            self.phone_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_pitch_dir,
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        # utt2lab: "<dataset>_<uid>" -> path of the phone-label .txt file
        self.utt2lab_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            self.utt2lab_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.lab_dir,
                uid + ".txt",
            )

        self.speaker_map = {}
        spk2id_path = os.path.join(cfg.processed_dir, "spk2id.json")
        if os.path.exists(spk2id_path):
            # BUGFIX: the original called open(os.path.exists(path)), passing
            # a bool to open(), which raised whenever the file existed.
            with open(spk2id_path) as f:
                self.speaker_map = json.load(f)

        self.metadata = self.check_metadata()

    def __getitem__(self, index):
        """Return one training example (mel features from the base class plus
        durations, phone ids, pitch/energy, and speaker id)."""
        single_feature = BaseOfflineDataset.__getitem__(self, index)

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        duration = self.utt2dur[utt]

        # text: one line of phones per utterance
        with open(self.utt2lab_path[utt], "r") as f:
            phones = f.readlines()[0].strip()
        # todo: add cleaner(chenxi)
        phones_ids = np.array(text_to_sequence(phones, ["english_cleaners"]))
        text_len = len(phones_ids)

        if self.cfg.preprocess.use_frame_pitch:
            pitch = self.frame_utt2pitch[utt]
        elif self.cfg.preprocess.use_phone_pitch:
            pitch = self.phone_utt2pitch[utt]

        if self.cfg.preprocess.use_frame_energy:
            energy = self.frame_utt2energy[utt]
        elif self.cfg.preprocess.use_phone_energy:
            energy = self.phone_utt2energy[utt]

        # speaker: fall back to id 0 when no speaker map was loaded
        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        single_feature.update(
            {
                "durations": duration,
                "texts": phones_ids,
                "spk_id": speaker_id,
                "text_len": text_len,
                "pitch": pitch,
                "energy": energy,
                "uid": uid,
            }
        )
        return self.clip_if_too_long(single_feature)

    def read_duration(self):
        """Read per-utterance durations, skipping utterances with missing
        files, and sanity-check that mel length == sum(duration)."""
        utt2dur = {}
        for index in range(len(self.metadata)):
            utt_info = self.metadata[index]
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            if not os.path.exists(self.utt2mel_path[utt]) or not os.path.exists(
                self.utt2duration_path[utt]
            ):
                continue

            mel = np.load(self.utt2mel_path[utt]).transpose(1, 0)
            duration = np.load(self.utt2duration_path[utt])
            assert mel.shape[0] == sum(
                duration
            ), f"{utt}: mismatch length between mel {mel.shape[0]} and sum(duration) {sum(duration)}"
            utt2dur[utt] = duration
        return utt2dur

    def __len__(self):
        return len(self.metadata)

    def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
        """
        ending_ts: to avoid invalid whisper features for over 30s audios
        2812 = 30 * 24000 // 256
        """
        ts = max(feature_seq_len - max_seq_len, 0)
        # Clamp at 0 so random.randint never receives a negative upper bound
        # (possible when ending_ts < max_seq_len).
        ts = max(min(ts, ending_ts - max_seq_len), 0)

        start = random.randint(0, ts)
        end = start + max_seq_len
        return start, end

    def clip_if_too_long(self, sample, max_seq_len=1000):
        """
        sample :
        {
            'spk_id': (1,),
            'target_len': int
            'mel': (seq_len, dim),
            'frame_pitch': (seq_len,)
            'frame_energy': (seq_len,)
            'content_vector_feat': (seq_len, dim)
        }
        """
        if sample["target_len"] <= max_seq_len:
            return sample

        start, end = self.random_select(sample["target_len"], max_seq_len)
        sample["target_len"] = end - start

        # BUGFIX: also exclude the scalar "text_len" and the string "uid" —
        # slicing an int raised TypeError for any over-long utterance.
        # NOTE(review): phone-level entries ("texts", "durations") are still
        # sliced with frame indices; confirm that is intended.
        for k in sample.keys():
            if k not in ["spk_id", "target_len", "text_len", "uid"]:
                sample[k] = sample[k][start:end]

        return sample

    def check_metadata(self):
        """Drop utterances whose duration or mel files are missing on disk."""
        new_metadata = []
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            if not os.path.exists(self.utt2duration_path[utt]) or not os.path.exists(
                self.utt2mel_path[utt]
            ):
                continue
            else:
                new_metadata.append(utt_info)
        return new_metadata
222
+
223
+
224
class FS2Collator(BaseOfflineCollator):
    """Zero-pads model inputs and targets based on number of frames per step."""

    def __init__(self, cfg):
        BaseOfflineCollator.__init__(self, cfg)
        self.sort = cfg.train.sort_sample
        self.batch_size = cfg.train.batch_size
        self.drop_last = cfg.train.drop_last

    def __call__(self, batch):
        # Output shapes:
        #   mel: [b, T, n_mels]; frame_pitch / frame_energy: [1, T]
        #   target_len: [1]; spk_id: [b, 1]; mask: [b, T, 1]
        collated = dict()

        for key in batch[0].keys():
            if key in ("target_len", "text_len"):
                # Lengths plus a padded 0/1 mask derived from them.
                collated[key] = torch.LongTensor([item[key] for item in batch])
                ones = [
                    torch.ones((item[key], 1), dtype=torch.long) for item in batch
                ]
                mask_name = "mask" if key == "target_len" else "text_mask"
                collated[mask_name] = pad_sequence(
                    ones, batch_first=True, padding_value=0
                )
            elif key == "spk_id":
                collated[key] = torch.LongTensor([item[key] for item in batch])
            elif key == "uid":
                collated[key] = [item["uid"] for item in batch]
            else:
                # Variable-length numpy features: convert and zero-pad.
                collated[key] = pad_sequence(
                    [torch.from_numpy(item[key]) for item in batch],
                    batch_first=True,
                    padding_value=0,
                )
        return collated
274
+
275
+
276
class FS2TestDataset(BaseTestDataset):
    """FastSpeech 2 test/inference dataset: phone-id sequences plus speaker ids."""

    def __init__(self, args, cfg, infer_type=None):
        datasets = cfg.dataset
        cfg = cfg.preprocess
        is_bigdata = False

        # Multiple datasets share a merged "bigdata" processed dir named by
        # the sorted, underscore-joined dataset list.
        assert len(datasets) >= 1
        if len(datasets) > 1:
            datasets.sort()
            bigdata_version = "_".join(datasets)
            processed_data_dir = os.path.join(cfg.processed_dir, bigdata_version)
            is_bigdata = True
        else:
            processed_data_dir = os.path.join(cfg.processed_dir, args.dataset)

        # Metadata comes either from an explicit test-list file or from the
        # named testing split of the source dataset.
        if args.test_list_file:
            self.metafile_path = args.test_list_file
            self.metadata = self.get_metadata()
        else:
            assert args.testing_set
            source_metafile_path = os.path.join(
                cfg.processed_dir,
                args.dataset,
                "{}.json".format(args.testing_set),
            )
            with open(source_metafile_path, "r") as f:
                self.metadata = json.load(f)

        self.cfg = cfg
        self.datasets = datasets
        self.data_root = processed_data_dir
        self.is_bigdata = is_bigdata
        self.source_dataset = args.dataset

        ######### Load source acoustic features #########
        if cfg.use_spkid:
            spk2id_path = os.path.join(self.data_root, cfg.spk2id)
            utt2sp_path = os.path.join(self.data_root, cfg.utt2spk)
            self.spk2id, self.utt2spk = get_spk_map(spk2id_path, utt2sp_path, datasets)

        # utt2lab: "<dataset>_<uid>" -> path of the phone-label .txt file
        self.utt2lab_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            self.utt2lab_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.lab_dir,
                uid + ".txt",
            )

        self.speaker_map = {}
        spk2id_json = os.path.join(cfg.processed_dir, "spk2id.json")
        if os.path.exists(spk2id_json):
            # BUGFIX: the original called open(os.path.exists(path)), passing
            # a bool to open(), which raised whenever the file existed.
            with open(spk2id_json) as f:
                self.speaker_map = json.load(f)

    def __getitem__(self, index):
        """Return phone ids, speaker id and text length for one utterance."""
        single_feature = {}

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        # text: one line of phones per utterance
        with open(self.utt2lab_path[utt], "r") as f:
            phones = f.readlines()[0].strip()

        phones_ids = np.array(text_to_sequence(phones, self.cfg.text_cleaners))
        text_len = len(phones_ids)

        # speaker: fall back to id 0 when no speaker map was loaded
        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        single_feature.update(
            {
                "texts": phones_ids,
                "spk_id": speaker_id,
                "text_len": text_len,
            }
        )

        return single_feature

    def __len__(self):
        return len(self.metadata)

    def get_metadata(self):
        """Load the utterance metadata list from the test-list JSON file."""
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        return metadata
376
+
377
+
378
class FS2TestCollator(BaseTestCollator):
    """Zero-pads model inputs and targets based on number of frames per step."""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        # Output shapes:
        #   mel: [b, T, n_mels]; frame_pitch / frame_energy: [1, T]
        #   target_len: [1]; spk_id: [b, 1]; mask: [b, T, 1]
        collated = dict()

        for key in batch[0].keys():
            if key in ("target_len", "text_len"):
                # Lengths plus a padded 0/1 mask derived from them.
                collated[key] = torch.LongTensor([item[key] for item in batch])
                ones = [
                    torch.ones((item[key], 1), dtype=torch.long) for item in batch
                ]
                mask_name = "mask" if key == "target_len" else "text_mask"
                collated[mask_name] = pad_sequence(
                    ones, batch_first=True, padding_value=0
                )
            elif key == "spk_id":
                collated[key] = torch.LongTensor([item[key] for item in batch])
            else:
                # Variable-length numpy features: convert and zero-pad.
                collated[key] = pad_sequence(
                    [torch.from_numpy(item[key]) for item in batch],
                    batch_first=True,
                    padding_value=0,
                )

        return collated
Amphion/models/tts/naturalspeech2/diffusion.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+ import torch.nn.functional as F
10
+ from models.tts.naturalspeech2.wavenet import WaveNet
11
+
12
+
13
class Diffusion(nn.Module):
    """Score-based diffusion decoder for NaturalSpeech 2.

    A WaveNet estimator predicts x0 directly; the forward/reverse processes
    follow a VP-style SDE with a linear beta schedule between ``beta_min``
    and ``beta_max`` and stationary variance ``sigma**2``.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        # x0-estimator network; cfg.wavenet configures the WaveNet backbone.
        self.diff_estimator = WaveNet(cfg.wavenet)
        self.beta_min = cfg.beta_min
        self.beta_max = cfg.beta_max
        self.sigma = cfg.sigma
        self.noise_factor = cfg.noise_factor

    def forward(self, x, x_mask, cond, spk_query_emb, offset=1e-5):
        """
        x: (B, 128, T)
        x_mask: (B, T), mask is 0
        cond: (B, T, 512)
        spk_query_emb: (B, 32, 512)

        Samples one random diffusion time per batch element, diffuses x to
        xt, and returns the estimator's x0 prediction together with the
        implied noise prediction and the true noise (for training losses).
        """
        diffusion_step = torch.rand(
            x.shape[0], dtype=x.dtype, device=x.device, requires_grad=False
        )
        # Keep t strictly inside (0, 1) for numerical stability.
        diffusion_step = torch.clamp(diffusion_step, offset, 1.0 - offset)
        xt, z = self.forward_diffusion(x0=x, diffusion_step=diffusion_step)

        cum_beta = self.get_cum_beta(diffusion_step.unsqueeze(-1).unsqueeze(-1))
        x0_pred = self.diff_estimator(xt, x_mask, cond, diffusion_step, spk_query_emb)
        # Invert the forward-diffusion mean/variance to recover the noise
        # implied by the predicted x0.
        mean_pred = x0_pred * torch.exp(-0.5 * cum_beta / (self.sigma**2))
        variance = (self.sigma**2) * (1.0 - torch.exp(-cum_beta / (self.sigma**2)))
        noise_pred = (xt - mean_pred) / (torch.sqrt(variance) * self.noise_factor)
        noise = z
        diff_out = {"x0_pred": x0_pred, "noise_pred": noise_pred, "noise": noise}
        return diff_out

    @torch.no_grad()
    def get_cum_beta(self, time_step):
        # Integral of beta(s) ds over [0, t] for the linear schedule.
        return self.beta_min * time_step + 0.5 * (self.beta_max - self.beta_min) * (
            time_step**2
        )

    @torch.no_grad()
    def get_beta_t(self, time_step):
        # Instantaneous beta(t) of the linear schedule.
        return self.beta_min + (self.beta_max - self.beta_min) * time_step

    @torch.no_grad()
    def forward_diffusion(self, x0, diffusion_step):
        """
        x0: (B, 128, T)
        diffusion_step: (B,)

        Returns the diffused sample xt and the Gaussian noise z used.
        """
        time_step = diffusion_step.unsqueeze(-1).unsqueeze(-1)
        cum_beta = self.get_cum_beta(time_step)
        mean = x0 * torch.exp(-0.5 * cum_beta / (self.sigma**2))
        variance = (self.sigma**2) * (1 - torch.exp(-cum_beta / (self.sigma**2)))
        z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False)
        xt = mean + z * torch.sqrt(variance) * self.noise_factor
        return xt, z

    @torch.no_grad()
    def cal_dxt(self, xt, x_mask, cond, spk_query_emb, diffusion_step, h):
        """One reverse-ODE increment dxt for step size h at time diffusion_step."""
        time_step = diffusion_step.unsqueeze(-1).unsqueeze(-1)
        cum_beta = self.get_cum_beta(time_step=time_step)
        beta_t = self.get_beta_t(time_step=time_step)
        x0_pred = self.diff_estimator(xt, x_mask, cond, diffusion_step, spk_query_emb)
        mean_pred = x0_pred * torch.exp(-0.5 * cum_beta / (self.sigma**2))
        noise_pred = xt - mean_pred
        variance = (self.sigma**2) * (1.0 - torch.exp(-cum_beta / (self.sigma**2)))
        # Score (grad log p) estimate derived from the predicted noise;
        # 1e-8 guards against division by a vanishing variance.
        logp = -noise_pred / (variance + 1e-8)
        dxt = -0.5 * h * beta_t * (logp + xt / (self.sigma**2))
        return dxt

    @torch.no_grad()
    def reverse_diffusion(self, z, x_mask, cond, n_timesteps, spk_query_emb):
        """Integrate the reverse ODE from t=1 down to t=0 starting at z,
        with either a midpoint or an Euler solver (cfg.ode_solver)."""
        h = 1.0 / max(n_timesteps, 1)
        xt = z
        for i in range(n_timesteps):
            # Midpoint of the i-th sub-interval, walking backwards from t=1.
            t = (1.0 - (i + 0.5) * h) * torch.ones(
                z.shape[0], dtype=z.dtype, device=z.device
            )
            dxt = self.cal_dxt(xt, x_mask, cond, spk_query_emb, diffusion_step=t, h=h)
            xt_ = xt - dxt
            if self.cfg.ode_solver == "midpoint":
                # Re-evaluate the derivative at the midpoint state.
                x_mid = 0.5 * (xt_ + xt)
                dxt = self.cal_dxt(
                    x_mid, x_mask, cond, spk_query_emb, diffusion_step=t + 0.5 * h, h=h
                )
                xt = xt - dxt
            elif self.cfg.ode_solver == "euler":
                xt = xt_
        return xt

    @torch.no_grad()
    def reverse_diffusion_from_t(
        self, z, x_mask, cond, n_timesteps, spk_query_emb, t_start
    ):
        """Same as reverse_diffusion but starting from time t_start instead of 1."""
        h = t_start / max(n_timesteps, 1)
        xt = z
        for i in range(n_timesteps):
            t = (t_start - (i + 0.5) * h) * torch.ones(
                z.shape[0], dtype=z.dtype, device=z.device
            )
            dxt = self.cal_dxt(xt, x_mask, cond, spk_query_emb, diffusion_step=t, h=h)
            xt_ = xt - dxt
            if self.cfg.ode_solver == "midpoint":
                x_mid = 0.5 * (xt_ + xt)
                dxt = self.cal_dxt(
                    x_mid, x_mask, cond, spk_query_emb, diffusion_step=t + 0.5 * h, h=h
                )
                xt = xt - dxt
            elif self.cfg.ode_solver == "euler":
                xt = xt_
        return xt
Amphion/models/tts/vits/vits_inference.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import time
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+ import torch
11
+ import json
12
+ from models.tts.base.tts_inferece import TTSInference
13
+ from models.tts.vits.vits_dataset import VITSTestDataset, VITSTestCollator
14
+ from models.tts.vits.vits import SynthesizerTrn
15
+ from processors.phone_extractor import phoneExtractor
16
+ from text.text_token_collation import phoneIDCollation
17
+ from utils.data_utils import *
18
+
19
+
20
class VitsInference(TTSInference):
    """Inference pipeline for the VITS end-to-end TTS model."""

    def __init__(self, args=None, cfg=None):
        TTSInference.__init__(self, args, cfg)

    def _build_model(self):
        """Instantiate the VITS synthesizer from the experiment config."""
        net_g = SynthesizerTrn(
            self.cfg.model.text_token_num,
            self.cfg.preprocess.n_fft // 2 + 1,
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            **self.cfg.model,
        )

        return net_g

    def _build_test_dataset(self):
        # Fix: the first parameter was misspelled "sefl".
        return VITSTestDataset, VITSTestCollator

    def build_save_dir(self, dataset, speaker):
        """Build (and create) the output directory for synthesized audio.

        Args:
            dataset: dataset name, or None to skip the dataset sub-folder.
            speaker: speaker id; -1 means single-speaker (no speaker sub-folder).
        """
        save_dir = os.path.join(
            self.args.output_dir,
            "tts_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
        )
        if dataset is not None:
            save_dir = os.path.join(save_dir, "data_{}".format(dataset))
        if speaker != -1:
            save_dir = os.path.join(
                save_dir,
                "spk_{}".format(speaker),
            )
        os.makedirs(save_dir, exist_ok=True)
        print("Saving to ", save_dir)
        return save_dir

    def inference_for_batches(
        self, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
    ):
        """Synthesize waveforms for every batch in the test dataloader.

        Returns:
            list of 1-D float CPU tensors, one trimmed waveform per utterance.
        """
        ###### Construct test_batch ######
        n_batch = len(self.test_dataloader)
        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
        print(
            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
                now, self.test_batch_size, n_batch
            )
        )
        self.model.eval()

        ###### Inference for each batch ######
        pred_res = []
        with torch.no_grad():
            for i, batch_data in enumerate(
                self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
            ):
                # Only pass a speaker id when multi-speaker training was enabled.
                spk_id = None
                if (
                    self.cfg.preprocess.use_spkid
                    and self.cfg.train.multi_speaker_training
                ):
                    spk_id = batch_data["spk_id"]

                outputs = self.model.infer(
                    batch_data["phone_seq"],
                    batch_data["phone_len"],
                    spk_id,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                )

                audios = outputs["y_hat"]
                masks = outputs["mask"]

                for idx in range(audios.size(0)):
                    audio = audios[idx, 0, :].data.cpu().float()
                    mask = masks[idx, :, :]
                    # Valid sample count = (#unmasked frames) * hop size.
                    audio_length = (
                        mask.sum([0, 1]).long() * self.cfg.preprocess.hop_size
                    )
                    audio_length = audio_length.cpu().numpy()
                    audio = audio[:audio_length]
                    pred_res.append(audio)

        return pred_res

    def inference_for_single_utterance(
        self, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
    ):
        """Synthesize one utterance from ``self.args.text``.

        Returns:
            1-D numpy float array with the synthesized waveform.
        """
        text = self.args.text

        # get phone symbol file
        phone_symbol_file = None
        if self.cfg.preprocess.phone_extractor != "lexicon":
            phone_symbol_file = os.path.join(
                self.exp_dir, self.cfg.preprocess.symbols_dict
            )
            assert os.path.exists(phone_symbol_file)
        # convert text to phone sequence
        phone_extractor = phoneExtractor(self.cfg)
        phone_seq = phone_extractor.extract_phone(text)  # phone_seq: list
        # convert phone sequence to phone id sequence
        phon_id_collator = phoneIDCollation(
            self.cfg, symbols_dict_file=phone_symbol_file
        )
        phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)

        # Interleave blank tokens (id 0) between phones, as done in training.
        if self.cfg.preprocess.add_blank:
            phone_id_seq = intersperse(phone_id_seq, 0)

        # convert phone sequence to phone id sequence
        phone_id_seq = np.array(phone_id_seq)
        phone_id_seq = torch.from_numpy(phone_id_seq)

        # get speaker id if multi-speaker training and use speaker id
        speaker_id = None
        if self.cfg.preprocess.use_spkid and self.cfg.train.multi_speaker_training:
            spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
            with open(spk2id_file, "r") as f:
                spk2id = json.load(f)
            speaker_name = self.args.speaker_name
            assert (
                speaker_name in spk2id
            ), f"Speaker {speaker_name} not found in the spk2id keys. \
                Please make sure you've specified the correct speaker name in infer_speaker_name."
            speaker_id = spk2id[speaker_name]
            speaker_id = torch.from_numpy(
                np.array([speaker_id], dtype=np.int32)
            ).unsqueeze(0)

        with torch.no_grad():
            x_tst = phone_id_seq.to(self.device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([phone_id_seq.size(0)]).to(self.device)
            if speaker_id is not None:
                speaker_id = speaker_id.to(self.device)
            outputs = self.model.infer(
                x_tst,
                x_tst_lengths,
                sid=speaker_id,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )

        audio = outputs["y_hat"][0, 0].data.cpu().float().numpy()

        return audio
Amphion/models/vocoders/flow/flow_vocoder_trainer.py ADDED
File without changes
Amphion/models/vocoders/gan/gan_vocoder_dataset.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import random
8
+
9
+ import numpy as np
10
+
11
+ from torch.nn import functional as F
12
+
13
+ from torch.nn.utils.rnn import pad_sequence
14
+ from utils.data_utils import *
15
+ from models.vocoders.vocoder_dataset import VocoderDataset
16
+
17
+
18
class GANVocoderDataset(VocoderDataset):
    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: config
            dataset: dataset name
            is_valid: whether to use train or valid dataset
        """
        super().__init__(cfg, dataset, is_valid)

        # Cache one random utterance's features for periodic evaluation.
        eval_index = random.randint(0, len(self.metadata) - 1)
        eval_utt_info = self.metadata[eval_index]
        eval_utt = "{}_{}".format(eval_utt_info["Dataset"], eval_utt_info["Uid"])
        self.eval_audio = np.load(self.utt2audio_path[eval_utt])
        if cfg.preprocess.use_mel:
            self.eval_mel = np.load(self.utt2mel_path[eval_utt])
        if cfg.preprocess.use_frame_pitch:
            self.eval_pitch = np.load(self.utt2frame_pitch_path[eval_utt])

    def __getitem__(self, index):
        """Return one sample with every feature cropped or zero-padded to a
        fixed window of ``cut_mel_frame`` frames (audio to the matching number
        of samples). The random crop offset is shared across features via
        the "start"/"end" entries so they stay aligned.
        """
        utt_info = self.metadata[index]

        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        single_feature = dict()

        if self.cfg.preprocess.use_mel:
            mel = np.load(self.utt2mel_path[utt])
            assert mel.shape[0] == self.cfg.preprocess.n_mel

            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = mel.shape[1]

            if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
                mel = np.pad(
                    mel,
                    ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
                    mode="constant",
                )
            else:
                if "start" not in single_feature.keys():
                    start = random.randint(
                        0, mel.shape[-1] - self.cfg.preprocess.cut_mel_frame
                    )
                    end = start + self.cfg.preprocess.cut_mel_frame
                    single_feature["start"] = start
                    single_feature["end"] = end
                mel = mel[:, single_feature["start"] : single_feature["end"]]
            single_feature["mel"] = mel

        if self.cfg.preprocess.use_frame_pitch:
            frame_pitch = np.load(self.utt2frame_pitch_path[utt])
            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = len(frame_pitch)
            aligned_frame_pitch = align_length(
                frame_pitch, single_feature["target_len"]
            )

            if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
                # Fix: the original computed the pad width from the (not yet
                # loaded) `audio` array in sample units, raising a NameError;
                # pitch is a per-frame feature, so pad it to cut_mel_frame.
                aligned_frame_pitch = np.pad(
                    aligned_frame_pitch,
                    (
                        0,
                        self.cfg.preprocess.cut_mel_frame
                        - aligned_frame_pitch.shape[-1],
                    ),
                    mode="constant",
                )
            else:
                if "start" not in single_feature.keys():
                    start = random.randint(
                        0,
                        aligned_frame_pitch.shape[-1]
                        - self.cfg.preprocess.cut_mel_frame,
                    )
                    end = start + self.cfg.preprocess.cut_mel_frame
                    single_feature["start"] = start
                    single_feature["end"] = end
                aligned_frame_pitch = aligned_frame_pitch[
                    single_feature["start"] : single_feature["end"]
                ]
            single_feature["frame_pitch"] = aligned_frame_pitch

        if self.cfg.preprocess.use_audio:
            audio = np.load(self.utt2audio_path[utt])

            assert "target_len" in single_feature.keys()

            if (
                audio.shape[-1]
                <= self.cfg.preprocess.cut_mel_frame * self.cfg.preprocess.hop_size
            ):
                audio = np.pad(
                    audio,
                    (
                        (
                            0,
                            self.cfg.preprocess.cut_mel_frame
                            * self.cfg.preprocess.hop_size
                            - audio.shape[-1],
                        )
                    ),
                    mode="constant",
                )
            else:
                if "start" not in single_feature.keys():
                    audio = audio[
                        0 : self.cfg.preprocess.cut_mel_frame
                        * self.cfg.preprocess.hop_size
                    ]
                else:
                    # Crop audio samples to match the frame crop chosen above.
                    audio = audio[
                        single_feature["start"]
                        * self.cfg.preprocess.hop_size : single_feature["end"]
                        * self.cfg.preprocess.hop_size,
                    ]
            single_feature["audio"] = audio

        if self.cfg.preprocess.use_amplitude_phase:
            logamp = np.load(self.utt2logamp_path[utt])
            pha = np.load(self.utt2pha_path[utt])
            rea = np.load(self.utt2rea_path[utt])
            imag = np.load(self.utt2imag_path[utt])

            assert "target_len" in single_feature.keys()

            if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
                # Fix: pad widths previously used mel.shape[-1], which is already
                # padded to cut_mel_frame (width 0) and undefined when use_mel is
                # False; pad each feature based on its own length instead.
                def _pad_frames(feat):
                    # Zero-pad a (bins, frames) feature along the frame axis.
                    return np.pad(
                        feat,
                        (
                            (0, 0),
                            (0, self.cfg.preprocess.cut_mel_frame - feat.shape[-1]),
                        ),
                        mode="constant",
                    )

                logamp = _pad_frames(logamp)
                pha = _pad_frames(pha)
                rea = _pad_frames(rea)
                imag = _pad_frames(imag)
            else:
                logamp = logamp[:, single_feature["start"] : single_feature["end"]]
                pha = pha[:, single_feature["start"] : single_feature["end"]]
                rea = rea[:, single_feature["start"] : single_feature["end"]]
                imag = imag[:, single_feature["start"] : single_feature["end"]]
            single_feature["logamp"] = logamp
            single_feature["pha"] = pha
            single_feature["rea"] = rea
            single_feature["imag"] = imag

        return single_feature
181
+
182
+
183
class GANVocoderCollator(object):
    """Collate function for GAN vocoder batches.

    Converts each per-utterance numpy feature to a tensor and zero-pads it
    along the first (time/frame) dimension. Bookkeeping entries
    ("target_len", "start", "end") are dropped from the packed batch.
    """

    META_KEYS = ("target_len", "start", "end")

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        # Expected shapes after packing:
        #   mel: [b, n_mels, frame]; frame_pitch: [b, frame];
        #   audio: [b, frame * hop_size]
        packed = dict()
        for key in batch[0].keys():
            if key in self.META_KEYS:
                continue
            tensors = [torch.from_numpy(sample[key]) for sample in batch]
            packed[key] = pad_sequence(tensors, batch_first=True, padding_value=0)
        return packed
Amphion/modules/anti_aliasing/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .act import *
7
+ from .filter import *
8
+ from .resample import *
Amphion/modules/encoder/condition_encoder.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torchaudio.models import Conformer
10
+ from models.svc.transformer.transformer import PositionalEncoding
11
+
12
+ from utils.f0 import f0_to_coarse
13
+
14
+
15
class ContentEncoder(nn.Module):
    """Project frame-level content features to the shared encoder dimension.

    Optionally refines the features with a positional encoding plus a
    Conformer stack before the final linear projection.
    """

    def __init__(self, cfg, input_dim, output_dim):
        super().__init__()
        self.cfg = cfg

        assert input_dim != 0
        self.nn = nn.Linear(input_dim, output_dim)

        # Optional Conformer refinement, controlled by a config flag.
        use_conformer = (
            "use_conformer_for_content_features" in cfg
            and cfg.use_conformer_for_content_features
        )
        if use_conformer:
            self.pos_encoder = PositionalEncoding(input_dim)
            self.conformer = Conformer(
                input_dim=input_dim,
                num_heads=2,
                ffn_dim=256,
                num_layers=6,
                depthwise_conv_kernel_size=3,
            )
        else:
            self.conformer = None

    def forward(self, x, length=None):
        """x: (N, seq_len, input_dim) -> (N, seq_len, output_dim)."""
        hidden = x
        if self.conformer:
            hidden = self.pos_encoder(hidden)
            hidden, _ = self.conformer(hidden, length)
        return self.nn(hidden)
45
+
46
+
47
class MelodyEncoder(nn.Module):
    """Encode frame-level F0 (optionally with a voiced/unvoiced flag).

    With quantization (``n_bins_melody > 0``) the F0 contour is bucketed and
    embedded; without it the raw contour is passed through unchanged as a
    single-channel feature.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = self.cfg.input_melody_dim
        self.output_dim = self.cfg.output_melody_dim
        self.n_bins = self.cfg.n_bins_melody

        if self.input_dim == 0:
            return
        if self.n_bins == 0:
            # Continuous F0 path: a plain linear projection, no quantization.
            self.nn = nn.Linear(self.input_dim, self.output_dim)
        else:
            self.f0_min = cfg.f0_min
            self.f0_max = cfg.f0_max
            self.nn = nn.Embedding(
                num_embeddings=self.n_bins,
                embedding_dim=self.output_dim,
                padding_idx=None,
            )
            self.uv_embedding = nn.Embedding(2, self.output_dim)

    def forward(self, x, uv=None, length=None):
        """x: (B, frame_len) F0 values; uv: (B, frame_len) int voiced flags."""
        if self.n_bins == 0:
            # NOTE(review): in this branch the linear layer is never applied;
            # the raw contour is returned as (B, frame_len, 1) — confirm intended.
            out = x.unsqueeze(-1)
        else:
            out = self.nn(f0_to_coarse(x, self.n_bins, self.f0_min, self.f0_max))

        if self.cfg.use_uv:
            out = out + self.uv_embedding(uv)
        return out
83
+
84
+
85
class LoudnessEncoder(nn.Module):
    """Encode frame-level loudness, either continuously or via log-spaced bins."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = self.cfg.input_loudness_dim
        self.output_dim = self.cfg.output_loudness_dim
        self.n_bins = self.cfg.n_bins_loudness

        if self.input_dim == 0:
            return
        if self.n_bins == 0:
            # Continuous loudness: plain linear projection.
            self.nn = nn.Linear(self.input_dim, self.output_dim)
        else:
            # TODO: set empirically now
            self.loudness_min = 1e-30
            self.loudness_max = 1.5
            # Log-spaced bin edges; stored as a non-trainable parameter so they
            # move with the module across devices and land in the state dict.
            self.energy_bins = nn.Parameter(
                torch.exp(
                    torch.linspace(
                        np.log(self.loudness_min),
                        np.log(self.loudness_max),
                        self.n_bins - 1,
                    )
                ),
                requires_grad=False,
            )
            self.nn = nn.Embedding(
                num_embeddings=self.n_bins,
                embedding_dim=self.output_dim,
                padding_idx=None,
            )

    def forward(self, x):
        """x: (N, frame_len) -> (N, frame_len, output_dim)."""
        if self.n_bins == 0:
            feats = x.unsqueeze(-1)
        else:
            feats = torch.bucketize(x, self.energy_bins)
        return self.nn(feats)
126
+
127
+
128
class SingerEncoder(nn.Module):
    """Look up a learned embedding for a singer/speaker id."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = 1
        self.output_dim = self.cfg.output_singer_dim

        # One embedding row per singer in the table.
        self.nn = nn.Embedding(
            num_embeddings=cfg.singer_table_size,
            embedding_dim=self.output_dim,
            padding_idx=None,
        )

    def forward(self, x):
        """x: (N, 1) integer ids -> (N, 1, output_dim) embeddings."""
        return self.nn(x)
145
+
146
+
147
class ConditionEncoder(nn.Module):
    """Fuse content, prosody, and speaker features into one conditioning tensor.

    Sub-encoders are created only for the feature types enabled in ``cfg``;
    their outputs are merged by concatenation or summation (``cfg.merge_mode``).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.merge_mode = cfg.merge_mode

        ### Semantic Features ###
        if cfg.use_whisper:
            self.whisper_encoder = ContentEncoder(
                self.cfg, self.cfg.whisper_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_contentvec:
            self.contentvec_encoder = ContentEncoder(
                self.cfg, self.cfg.contentvec_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_mert:
            self.mert_encoder = ContentEncoder(
                self.cfg, self.cfg.mert_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_wenet:
            self.wenet_encoder = ContentEncoder(
                self.cfg, self.cfg.wenet_dim, self.cfg.content_encoder_dim
            )

        ### Prosody Features ###
        if cfg.use_f0:
            self.melody_encoder = MelodyEncoder(self.cfg)
        if cfg.use_energy:
            self.loudness_encoder = LoudnessEncoder(self.cfg)

        ### Speaker Features ###
        if cfg.use_spkid:
            self.singer_encoder = SingerEncoder(self.cfg)

    def forward(self, x):
        """Encode a batch dict of features into a single conditioning tensor.

        Args:
            x: dict holding the per-feature batch tensors ("frame_pitch",
               "whisper_feat", "spk_id", ...) for the enabled feature types.

        Returns:
            (N, seq_len, D) tensor; D is the sum of sub-encoder dims under
            "concat" merging, or the shared dim under "add" merging.
        """
        outputs = []

        if self.cfg.use_f0:
            if self.cfg.use_uv:
                pitch_enc_out = self.melody_encoder(
                    x["frame_pitch"], uv=x["frame_uv"], length=x["target_len"]
                )
            else:
                pitch_enc_out = self.melody_encoder(
                    x["frame_pitch"], uv=None, length=x["target_len"]
                )
            outputs.append(pitch_enc_out)

        if self.cfg.use_energy:
            loudness_enc_out = self.loudness_encoder(x["frame_energy"])
            outputs.append(loudness_enc_out)

        if self.cfg.use_whisper:
            # whisper_feat: [b, T, 1024]
            whiser_enc_out = self.whisper_encoder(
                x["whisper_feat"], length=x["target_len"]
            )
            outputs.append(whiser_enc_out)
            # seq_len is remembered so the speaker embedding can be broadcast.
            seq_len = whiser_enc_out.shape[1]

        if self.cfg.use_contentvec:
            contentvec_enc_out = self.contentvec_encoder(
                x["contentvec_feat"], length=x["target_len"]
            )
            outputs.append(contentvec_enc_out)
            seq_len = contentvec_enc_out.shape[1]

        if self.cfg.use_mert:
            mert_enc_out = self.mert_encoder(x["mert_feat"], length=x["target_len"])
            outputs.append(mert_enc_out)
            seq_len = mert_enc_out.shape[1]

        if self.cfg.use_wenet:
            wenet_enc_out = self.wenet_encoder(x["wenet_feat"], length=x["target_len"])
            outputs.append(wenet_enc_out)
            seq_len = wenet_enc_out.shape[1]

        if self.cfg.use_spkid:
            speaker_enc_out = self.singer_encoder(x["spk_id"])  # [b, 1, 384]
            # At least one content feature must be present: it defines seq_len.
            assert (
                "whisper_feat" in x.keys()
                or "contentvec_feat" in x.keys()
                or "mert_feat" in x.keys()
                or "wenet_feat" in x.keys()
            )
            # Broadcast the single speaker vector across all frames.
            singer_info = speaker_enc_out.expand(-1, seq_len, -1)
            outputs.append(singer_info)

        encoder_output = None
        if self.merge_mode == "concat":
            encoder_output = torch.cat(outputs, dim=-1)
        if self.merge_mode == "add":
            # (#modules, N, seq_len, output_dim)
            outputs = torch.cat([out[None, :, :, :] for out in outputs], dim=0)
            # (N, seq_len, output_dim)
            encoder_output = torch.sum(outputs, dim=0)

        return encoder_output
Amphion/modules/general/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .input_strategies import PromptedFeatures, PromptedPrecomputedFeatures
2
+ from .scaling import BalancedDoubleSwish
3
+ from .utils import Transpose
Amphion/modules/monotonic_align/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code from https://github.com/jaywalnut310/vits/
2
+
3
+ import numpy as np
4
+ import torch
5
+ from .monotonic_align.core import maximum_path_c
6
+
7
+
8
def maximum_path(neg_cent, mask):
    """Cython-optimized monotonic alignment search.

    neg_cent: [b, t_t, t_s]
    mask: [b, t_t, t_s]
    Returns the hard alignment path on neg_cent's original device/dtype.
    """
    device, dtype = neg_cent.device, neg_cent.dtype
    scores = neg_cent.data.cpu().numpy().astype(np.float32)
    path = np.zeros(scores.shape, dtype=np.int32)

    # Per-sample valid lengths along each axis, read off the first row/column.
    t_t_lens = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
    t_s_lens = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
    maximum_path_c(path, scores, t_t_lens, t_s_lens)  # fills `path` in place
    return torch.from_numpy(path).to(device=device, dtype=dtype)
Amphion/modules/neural_source_filter/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .sine_excitation import *
Amphion/modules/transformer/Layers.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn import functional as F
9
+ from .SubLayers import MultiHeadAttention, PositionwiseFeedForward
10
+
11
+
12
class FFTBlock(torch.nn.Module):
    """Feed-forward Transformer block: self-attention + convolutional FFN.

    Padded positions are zeroed after each sub-layer using ``mask``.
    """

    def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, kernel_size, dropout=dropout
        )

    def forward(self, enc_input, mask=None, slf_attn_mask=None):
        # NOTE(review): despite the default, a non-None `mask` is required
        # (masked_fill below dereferences it) — confirm callers always pass one.
        out, attn_weights = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask
        )
        out = out.masked_fill(mask.unsqueeze(-1), 0)

        out = self.pos_ffn(out)
        out = out.masked_fill(mask.unsqueeze(-1), 0)

        return out, attn_weights
+
33
+
34
+ class ConvNorm(torch.nn.Module):
35
+ def __init__(
36
+ self,
37
+ in_channels,
38
+ out_channels,
39
+ kernel_size=1,
40
+ stride=1,
41
+ padding=None,
42
+ dilation=1,
43
+ bias=True,
44
+ w_init_gain="linear",
45
+ ):
46
+ super(ConvNorm, self).__init__()
47
+
48
+ if padding is None:
49
+ assert kernel_size % 2 == 1
50
+ padding = int(dilation * (kernel_size - 1) / 2)
51
+
52
+ self.conv = torch.nn.Conv1d(
53
+ in_channels,
54
+ out_channels,
55
+ kernel_size=kernel_size,
56
+ stride=stride,
57
+ padding=padding,
58
+ dilation=dilation,
59
+ bias=bias,
60
+ )
61
+
62
+ def forward(self, signal):
63
+ conv_signal = self.conv(signal)
64
+
65
+ return conv_signal
66
+
67
+
68
+ class PostNet(nn.Module):
69
+ """
70
+ PostNet: Five 1-d convolution with 512 channels and kernel size 5
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ n_mel_channels=80,
76
+ postnet_embedding_dim=512,
77
+ postnet_kernel_size=5,
78
+ postnet_n_convolutions=5,
79
+ ):
80
+ super(PostNet, self).__init__()
81
+ self.convolutions = nn.ModuleList()
82
+
83
+ self.convolutions.append(
84
+ nn.Sequential(
85
+ ConvNorm(
86
+ n_mel_channels,
87
+ postnet_embedding_dim,
88
+ kernel_size=postnet_kernel_size,
89
+ stride=1,
90
+ padding=int((postnet_kernel_size - 1) / 2),
91
+ dilation=1,
92
+ w_init_gain="tanh",
93
+ ),
94
+ nn.BatchNorm1d(postnet_embedding_dim),
95
+ )
96
+ )
97
+
98
+ for i in range(1, postnet_n_convolutions - 1):
99
+ self.convolutions.append(
100
+ nn.Sequential(
101
+ ConvNorm(
102
+ postnet_embedding_dim,
103
+ postnet_embedding_dim,
104
+ kernel_size=postnet_kernel_size,
105
+ stride=1,
106
+ padding=int((postnet_kernel_size - 1) / 2),
107
+ dilation=1,
108
+ w_init_gain="tanh",
109
+ ),
110
+ nn.BatchNorm1d(postnet_embedding_dim),
111
+ )
112
+ )
113
+
114
+ self.convolutions.append(
115
+ nn.Sequential(
116
+ ConvNorm(
117
+ postnet_embedding_dim,
118
+ n_mel_channels,
119
+ kernel_size=postnet_kernel_size,
120
+ stride=1,
121
+ padding=int((postnet_kernel_size - 1) / 2),
122
+ dilation=1,
123
+ w_init_gain="linear",
124
+ ),
125
+ nn.BatchNorm1d(n_mel_channels),
126
+ )
127
+ )
128
+
129
+ def forward(self, x):
130
+ x = x.contiguous().transpose(1, 2)
131
+
132
+ for i in range(len(self.convolutions) - 1):
133
+ x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
134
+ x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
135
+
136
+ x = x.contiguous().transpose(1, 2)
137
+ return x
Amphion/modules/wenet_extractor/cif/predictor.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This module is from [WeNet](https://github.com/wenet-e2e/wenet).
2
+
3
+ # ## Citations
4
+
5
+ # ```bibtex
6
+ # @inproceedings{yao2021wenet,
7
+ # title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
8
+ # author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
9
+ # booktitle={Proc. Interspeech},
10
+ # year={2021},
11
+ # address={Brno, Czech Republic },
12
+ # organization={IEEE}
13
+ # }
14
+
15
+ # @article{zhang2022wenet,
16
+ # title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
17
+ # author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
18
+ # journal={arXiv preprint arXiv:2203.15455},
19
+ # year={2022}
20
+ # }
21
+ #
22
+
23
+ from typing import Optional
24
+
25
+ import torch
26
+ from torch import nn
27
+ from modules.wenet_extractor.utils.mask import make_pad_mask
28
+
29
+
30
class Predictor(nn.Module):
    """CIF (Continuous Integrate-and-Fire) predictor.

    Derives a frame-level firing weight ("alpha") from each encoder frame and
    integrates the hidden states into token-level acoustic embeddings via the
    module-level ``cif`` function (defined elsewhere in this file).
    """

    def __init__(
        self,
        idim,
        l_order,
        r_order,
        threshold=1.0,
        dropout=0.1,
        smooth_factor=1.0,
        noise_threshold=0,
        tail_threshold=0.45,
    ):
        super().__init__()

        # Depthwise conv over an (l_order + r_order + 1)-frame context window;
        # the constant padding keeps the output length equal to the input.
        self.pad = nn.ConstantPad1d((l_order, r_order), 0.0)
        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
        self.cif_output = nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold  # firing threshold used by cif()
        self.smooth_factor = smooth_factor  # scale applied to raw sigmoid weights
        self.noise_threshold = noise_threshold  # floor subtracted before ReLU
        self.tail_threshold = tail_threshold  # extra weight appended at the tail

    def forward(
        self,
        hidden,
        target_label: Optional[torch.Tensor] = None,
        mask: torch.Tensor = torch.tensor(0),
        ignore_id: int = -1,
        mask_chunk_predictor: Optional[torch.Tensor] = None,
        target_label_length: Optional[torch.Tensor] = None,
    ):
        """Predict firing weights and integrate hidden states into tokens.

        Returns:
            (acoustic_embeds, token_num, alphas, cif_peak)
        """
        h = hidden
        # (B, T, D) -> (B, D, T) for the depthwise conv; residual add after.
        context = h.transpose(1, 2)
        queries = self.pad(context)
        memory = self.cif_conv1d(queries)
        output = memory + context
        output = self.dropout(output)
        output = output.transpose(1, 2)
        output = torch.relu(output)
        output = self.cif_output(output)
        # Raw weights in (0, 1); rescale and clip small "noise" weights to 0.
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(
            alphas * self.smooth_factor - self.noise_threshold
        )
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        # Target token count, if available (training with labels).
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            # Rescale alphas so their sum matches the target token count exactly.
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            # Inference: append tail weight so the last partial token fires.
            hidden, alphas, token_num = self.tail_process_fn(
                hidden, alphas, token_num, mask=mask
            )

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        if target_length is None and self.tail_threshold > 0.0:
            # Trim embeddings to the longest integrated token count in the batch.
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak

    def tail_process_fn(
        self,
        hidden,
        alphas,
        token_num: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ):
        """Append a tail firing weight (and a zero hidden frame) per sequence.

        Ensures leftover accumulated weight at the end of a sequence still
        triggers a final token. Returns the extended hidden/alphas and the
        floored token count.
        """
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            # mask_2 - mask_1 is 1 exactly at the first padded position of each
            # sequence, placing the tail weight right after the valid frames.
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            # No mask: simply append the tail weight at the end.
            tail_threshold_tensor = torch.tensor(
                [tail_threshold], dtype=alphas.dtype
            ).to(alphas.device)
            tail_threshold_tensor = torch.reshape(tail_threshold_tensor, (1, 1))
            alphas = torch.cat([alphas, tail_threshold_tensor], dim=1)
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)

        return hidden, alphas, token_num_floor

    def gen_frame_alignments(
        self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None
    ):
        """Derive hard frame-to-token alignments from the firing weights.

        Returns:
            (alignments, alignment_lengths), both detached.
        """
        batch_size, maximum_length = alphas.size()
        int_type = torch.int32

        # Round during training, floor during inference (matches forward()).
        is_training = self.training
        if is_training:
            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
        else:
            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)

        max_token_num = torch.max(token_num).item()

        # Cumulative integrated weight per frame, replicated per token slot.
        alphas_cumsum = torch.cumsum(alphas, dim=1)
        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

        # index[b, k, :] == k+1: the token ordinal each row compares against.
        index = torch.ones([batch_size, max_token_num], dtype=int_type)
        index = torch.cumsum(index, dim=1)
        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

        # Frames where the cumulative count has not yet reached token k.
        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
        index_div_bool_zeros = index_div.eq(0)
        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
        index_div_bool_zeros_count = torch.clamp(
            index_div_bool_zeros_count, 0, encoder_sequence_length.max()
        )
        # Zero out slots beyond each sequence's actual token count.
        token_num_mask = (~make_pad_mask(token_num, max_len=max_token_num)).to(
            token_num.device
        )
        index_div_bool_zeros_count *= token_num_mask

        # Convert per-token boundary frame indices to a one-hot-per-frame map,
        # then sum over tokens to get each frame's token assignment.
        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(
            1, 1, maximum_length
        )
        ones = torch.ones_like(index_div_bool_zeros_count_tile)
        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
        ones = torch.cumsum(ones, dim=2)
        cond = index_div_bool_zeros_count_tile == ones
        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(
            torch.bool
        )
        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(
            int_type
        )
        index_div_bool_zeros_count_tile_out = torch.sum(
            index_div_bool_zeros_count_tile, dim=1
        )
        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(
            int_type
        )
        # Zero out frames beyond each sequence's encoder length.
        predictor_mask = (
            (
                ~make_pad_mask(
                    encoder_sequence_length, max_len=encoder_sequence_length.max()
                )
            )
            .type(int_type)
            .to(encoder_sequence_length.device)
        )
        index_div_bool_zeros_count_tile_out = (
            index_div_bool_zeros_count_tile_out * predictor_mask
        )

        predictor_alignments = index_div_bool_zeros_count_tile_out
        predictor_alignments_length = predictor_alignments.sum(-1).type(
            encoder_sequence_length.dtype
        )
        return predictor_alignments.detach(), predictor_alignments_length.detach()
207
+
208
+
209
class MAELoss(nn.Module):
    """L1 (mean absolute error) loss between true and predicted token counts.

    The summed L1 distance is divided either by the batch size (default) or,
    when ``normalize_length`` is set, by the total number of tokens.
    """

    def __init__(self, normalize_length=False):
        super().__init__()
        # When True, normalize by total token count instead of batch size.
        self.normalize_length = normalize_length
        self.criterion = torch.nn.L1Loss(reduction="sum")

    def forward(self, token_length, pre_token_length):
        """Return the normalized L1 loss between the two length tensors."""
        summed = self.criterion(token_length, pre_token_length)
        if self.normalize_length:
            denominator = token_length.sum().type(torch.float32)
        else:
            denominator = token_length.size(0)
        return summed / denominator
222
+
223
+
224
def cif(hidden: torch.Tensor, alphas: torch.Tensor, threshold: float):
    """Continuous Integrate-and-Fire: aggregate frames into token embeddings.

    Accumulates per-frame weights ``alphas`` over time; whenever the integral
    crosses ``threshold``, a token embedding is "fired" as the weighted sum of
    the frames since the previous firing, and the leftover weight starts the
    next token.

    Args:
        hidden: encoder outputs, shape (batch, T, hidden_size).
        alphas: per-frame firing weights, shape (batch, T).
        threshold: firing threshold (typically 1.0).

    Returns:
        Tuple of:
        - fired token embeddings, shape (batch, max_label_len, hidden_size),
          zero-padded per utterance to the largest rounded token count;
        - fires, shape (batch, T): the running integral at every frame.
    """
    batch_size, len_time, hidden_size = hidden.size()

    # loop vars
    integrate = torch.zeros([batch_size], device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
    # intermediate vars along time
    list_fires = []
    list_frames = []

    for t in range(len_time):
        alpha = alphas[:, t]
        # Weight still missing to reach 1.0 before this frame's contribution.
        distribution_completion = (
            torch.ones([batch_size], device=hidden.device) - integrate
        )

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        # Where a fire happens, carry the excess weight into the next token.
        integrate = torch.where(
            fire_place,
            integrate - torch.ones([batch_size], device=hidden.device),
            integrate,
        )
        # `cur` goes into the finishing token; `remainds` seeds the next one.
        cur = torch.where(fire_place, distribution_completion, alpha)
        remainds = alpha - cur

        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        frame = torch.where(
            fire_place[:, None].repeat(1, hidden_size),
            remainds[:, None] * hidden[:, t, :],
            frame,
        )

    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)
    list_ls = []
    len_labels = torch.round(alphas.sum(-1)).int()
    max_label_len = len_labels.max()
    for b in range(batch_size):
        fire = fires[b, :]
        # BUGFIX: use squeeze(-1), not squeeze(). With exactly one firing
        # frame, nonzero() returns shape (1, 1) and squeeze() collapsed it to
        # a 0-d tensor, which torch.index_select rejects (index must be 1-D).
        fired_idx = torch.nonzero(fire >= threshold).squeeze(-1)
        fired_frames = torch.index_select(frames[b, :, :], 0, fired_idx)
        pad_l = torch.zeros(
            [int(max_label_len - fired_frames.size(0)), hidden_size],
            device=hidden.device,
        )
        list_ls.append(torch.cat([fired_frames, pad_l], 0))
    return torch.stack(list_ls, 0), fires
Amphion/modules/wenet_extractor/paraformer/search/beam_search.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This module is from [WeNet](https://github.com/wenet-e2e/wenet).
2
+
3
+ # ## Citations
4
+
5
+ # ```bibtex
6
+ # @inproceedings{yao2021wenet,
7
+ # title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
8
+ # author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
9
+ # booktitle={Proc. Interspeech},
10
+ # year={2021},
11
+ # address={Brno, Czech Republic },
12
+ # organization={IEEE}
13
+ # }
14
+
15
+ # @article{zhang2022wenet,
16
+ # title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
17
+ # author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
18
+ # journal={arXiv preprint arXiv:2203.15455},
19
+ # year={2022}
20
+ # }
21
+ #
22
+
23
+ from itertools import chain
24
+ from typing import Any
25
+ from typing import Dict
26
+ from typing import List
27
+ from typing import Tuple
28
+ from typing import Union
29
+ from typing import NamedTuple
30
+
31
+ import torch
32
+
33
+ from modules.wenet_extractor.paraformer.utils import end_detect
34
+ from modules.wenet_extractor.paraformer.search.ctc import CTCPrefixScorer
35
+ from modules.wenet_extractor.paraformer.search.scorer_interface import (
36
+ ScorerInterface,
37
+ PartialScorerInterface,
38
+ )
39
+
40
+
41
class Hypothesis(NamedTuple):
    """Hypothesis data type for beam search.

    Fields:
        yseq: token id sequence; ``init_hyp`` seeds it with ``[sos]`` and
            ``append_token`` extends it one id at a time.
        score: accumulated weighted total score of this hypothesis.
        scores: per-scorer accumulated scores, keyed by scorer name.
        states: per-scorer decoding states, keyed by scorer name.

    NOTE(review): the ``dict()`` defaults are evaluated once at class creation
    and shared by every instance that relies on them; the beam search always
    passes explicit dicts, so this is presumably harmless — but do not mutate
    a default-constructed Hypothesis's dicts.
    """

    yseq: torch.Tensor
    score: Union[float, torch.Tensor] = 0
    scores: Dict[str, Union[float, torch.Tensor]] = dict()
    states: Dict[str, Any] = dict()

    def asdict(self) -> dict:
        """Convert data to JSON-friendly dict (tensors/scores to plain types)."""
        return self._replace(
            yseq=self.yseq.tolist(),
            score=float(self.score),
            scores={k: float(v) for k, v in self.scores.items()},
        )._asdict()
56
+
57
+
58
class BeamSearchCIF(torch.nn.Module):
    """Beam search implementation driven by per-step CIF acoustic scores."""

    def __init__(
        self,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        beam_size: int,
        vocab_size: int,
        sos: int,
        eos: int,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
    ):
        """Initialize beam search.

        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules
                e.g., Decoder, CTCPrefixScorer, LM
                The scorer will be ignored if it is `None`
            weights (dict[str, float]): Dict of weights for each scorers
                The scorer will be ignored if its weight is 0
            beam_size (int): The number of hypotheses kept during search
            vocab_size (int): The number of vocabulary
            sos (int): Start of sequence id
            eos (int): End of sequence id
            pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): beam size in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`

        """
        super().__init__()
        # set scorers
        self.weights = weights
        self.scorers = dict()
        self.full_scorers = dict()
        self.part_scorers = dict()
        # this module dict is required for recursive cast
        # `self.to(device, dtype)` in `recog.py`
        self.nn_dict = torch.nn.ModuleDict()
        for k, v in scorers.items():
            w = weights.get(k, 0)
            # Zero-weight or missing scorers contribute nothing; skip them.
            if w == 0 or v is None:
                continue
            assert isinstance(
                v, ScorerInterface
            ), f"{k} ({type(v)}) does not implement ScorerInterface"
            self.scorers[k] = v
            if isinstance(v, PartialScorerInterface):
                self.part_scorers[k] = v
            else:
                self.full_scorers[k] = v
            if isinstance(v, torch.nn.Module):
                self.nn_dict[k] = v

        # set configurations
        self.sos = sos
        self.eos = eos
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
        self.n_vocab = vocab_size
        if (
            pre_beam_score_key is not None
            and pre_beam_score_key != "full"
            and pre_beam_score_key not in self.full_scorers
        ):
            raise KeyError(
                f"{pre_beam_score_key} is not found in " f"{self.full_scorers}"
            )
        self.pre_beam_score_key = pre_beam_score_key
        # Pre-beam pruning only helps when partial scorers exist and the
        # pre-beam actually narrows the vocabulary.
        self.do_pre_beam = (
            self.pre_beam_score_key is not None
            and self.pre_beam_size < self.n_vocab
            and len(self.part_scorers) > 0
        )

    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            Hypothesis: The initial hypothesis.

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.init_state(x)
            init_scores[k] = 0.0
        return [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor([self.sos], device=x.device),
            )
        ]

    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Append new token to prefix tokens.

        Args:
            xs (torch.Tensor): The prefix token
            x (int): The new token to append

        Returns:
            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and
                xs.device

        """
        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        return torch.cat((xs, x))

    def score_full(
        self, hyp: Hypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
        return scores, states

    def score_partial(
        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 1D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(len(ids),)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
        return scores, states

    def beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each
                tokens.
                Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(self.beam_size,)`

        """
        # no pre beam performed
        if weighted_scores.size(0) == ids.size(0):
            top_ids = weighted_scores.topk(self.beam_size)[1]
            return top_ids, top_ids

        # mask pruned in pre-beam not to select in topk
        tmp = weighted_scores[ids]
        weighted_scores[:] = -float("inf")
        weighted_scores[ids] = tmp
        top_ids = weighted_scores.topk(self.beam_size)[1]
        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
        return top_ids, local_ids

    @staticmethod
    def merge_scores(
        prev_scores: Dict[str, float],
        next_full_scores: Dict[str, torch.Tensor],
        full_idx: int,
        next_part_scores: Dict[str, torch.Tensor],
        part_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Merge scores for new hypothesis.

        Args:
            prev_scores (Dict[str, float]):
                The previous hypothesis scores by `self.scorers`
            next_full_scores (Dict[str, torch.Tensor]): scores by
                `self.full_scorers`
            full_idx (int): The next token id for `next_full_scores`
            next_part_scores (Dict[str, torch.Tensor]):
                scores of partial tokens by `self.part_scorers`
            part_idx (int): The new token id for `next_part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and
                `self.part_scorers`.
                Its values are scalar tensors by the scorers.

        """
        new_scores = dict()
        for k, v in next_full_scores.items():
            new_scores[k] = prev_scores[k] + v[full_idx]
        for k, v in next_part_scores.items():
            new_scores[k] = prev_scores[k] + v[part_idx]
        return new_scores

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and
                `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, d in self.part_scorers.items():
            new_states[k] = d.select_state(part_states[k], part_idx)
        return new_states

    def search(
        self, running_hyps: List[Hypothesis], x: torch.Tensor, am_score: torch.Tensor
    ) -> List[Hypothesis]:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (List[Hypothesis]): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)
            am_score (torch.Tensor): Acoustic scores for the current step,
                shape `(self.n_vocab,)`; added to every hypothesis.

        Returns:
            List[Hypotheses]: Best sorted hypotheses

        """
        best_hyps = []
        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
        for hyp in running_hyps:
            # scoring
            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
            weighted_scores += am_score
            scores, states = self.score_full(hyp, x)
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
            # partial scoring
            if self.do_pre_beam:
                pre_beam_scores = (
                    weighted_scores
                    if self.pre_beam_score_key == "full"
                    else scores[self.pre_beam_score_key]
                )
                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
            part_scores, part_states = self.score_partial(hyp, part_ids, x)
            for k in self.part_scorers:
                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
            # add previous hyp score
            weighted_scores += hyp.score

            # update hyps
            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
                # will be (2 x beam at most)
                best_hyps.append(
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(
                            hyp.scores, scores, j, part_scores, part_j
                        ),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )

        # sort and prune 2 x beam -> beam
        best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
            : min(len(best_hyps), self.beam_size)
        ]
        return best_hyps

    def forward(
        self,
        x: torch.Tensor,
        am_scores: torch.Tensor,
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            am_scores (torch.Tensor): Per-step acoustic scores (L, n_vocab);
                its length fixes the number of decoding steps.
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
                If maxlenratio<0.0, its absolute value is interpreted
                as a constant max output length.
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        """
        # set length bounds: one decoding step per acoustic score row
        maxlen = am_scores.shape[0]

        # main loop of prefix search
        running_hyps = self.init_hyp(x)
        ended_hyps = []
        for i in range(maxlen):
            best = self.search(running_hyps, x, am_scores[i])
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                break

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            # Retry with a relaxed minimum length. BUGFIX: `am_scores` must be
            # forwarded here; the original call dropped it, passing the float
            # `maxlenratio` where a (L, n_vocab) tensor was expected, which
            # would crash on `am_scores.shape[0]` in the recursive call.
            return (
                []
                if minlenratio < 0.1
                else self.forward(
                    x, am_scores, maxlenratio, max(0.0, minlenratio - 0.1)
                )
            )

        return nbest_hyps

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: List[Hypothesis],
        ended_hyps: List[Hypothesis],
    ) -> List[Hypothesis]:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (int): The maximum length ratio in beam search.
            running_hyps (List[Hypothesis]): The running hypotheses in beam
                search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            List[Hypothesis]: The new running hypotheses.

        """

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            # logging.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos))
                for h in running_hyps
            ]

        # add ended hypotheses to a final list, and removed them from current
        # hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in running_hyps:
            if hyp.yseq[-1] == self.eos:
                # e.g., Word LM needs to add final <eos> score
                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
                    s = d.final_score(hyp.states[k])
                    hyp.scores[k] += s
                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        return remained_hyps
457
+
458
+
459
def build_beam_search(model, args, device):
    """Assemble a BeamSearchCIF decoder from a model and argparse-style args.

    A CTC prefix scorer is registered only when the model exposes one; the
    decoder/CTC/length-bonus weights come straight from ``args``. The returned
    module is moved to ``device`` as float32 and set to eval mode.
    """
    scorers = {}
    if model.ctc is not None:
        scorers["ctc"] = CTCPrefixScorer(ctc=model.ctc, eos=model.eos)
    weights = {
        "decoder": 1.0 - args.ctc_weight,
        "ctc": args.ctc_weight,
        "length_bonus": args.penalty,
    }
    # Pre-beam pruning over full scores is pointless with pure-CTC decoding.
    pre_beam_key = None if args.ctc_weight == 1.0 else "full"
    searcher = BeamSearchCIF(
        scorers=scorers,
        weights=weights,
        beam_size=args.beam_size,
        vocab_size=model.vocab_size,
        sos=model.sos,
        eos=model.eos,
        pre_beam_score_key=pre_beam_key,
    )
    searcher.to(device=device, dtype=torch.float32).eval()
    return searcher
Amphion/modules/wenet_extractor/paraformer/search/ctc_prefix_score.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This module is from [WeNet](https://github.com/wenet-e2e/wenet).
2
+
3
+ # ## Citations
4
+
5
+ # ```bibtex
6
+ # @inproceedings{yao2021wenet,
7
+ # title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
8
+ # author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
9
+ # booktitle={Proc. Interspeech},
10
+ # year={2021},
11
+ # address={Brno, Czech Republic },
12
+ # organization={IEEE}
13
+ # }
14
+
15
+ # @article{zhang2022wenet,
16
+ # title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
17
+ # author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
18
+ # journal={arXiv preprint arXiv:2203.15455},
19
+ # year={2022}
20
+ # }
21
+ #
22
+
23
+ import torch
24
+ import numpy as np
25
+
26
+ import six
27
+
28
+
29
class CTCPrefixScore(object):
    """Compute CTC label sequence scores.

    Based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the probabilities of multiple labels
    simultaneously.
    """

    def __init__(self, x, blank, eos, xp):
        # xp is the array backend (numpy-compatible module) so the same code
        # can run against numpy or an equivalent API.
        self.xp = xp
        # Practical -inf for the log domain.
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.input_length = len(x)
        # x: per-frame label posteriors, indexed as x[t, label] — presumably
        # log-probabilities (values are summed with logaddexp); verify caller.
        self.x = x

    def initial_state(self):
        """Obtain an initial CTC state

        :return: CTC state
        """
        # initial CTC state is made of a frame x 2 tensor that corresponds to
        # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
        # superscripts n and b (non-blank and blank), respectively.
        r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
        r[0, 1] = self.x[0, self.blank]
        # The blank path probability accumulates over consecutive blanks.
        for i in six.moves.range(1, self.input_length):
            r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
        return r

    def __call__(self, y, cs, r_prev):
        """Compute CTC prefix scores for next labels

        :param y : prefix label sequence
        :param cs : array of next labels
        :param r_prev: previous CTC state
        :return ctc_scores, ctc_states
        """
        # initialize CTC states
        output_length = len(y) - 1  # ignore sos
        # new CTC states are prepared as a frame x (n or b) x n_labels tensor
        # that corresponds to r_t^n(h) and r_t^b(h).
        r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
        xs = self.x[:, cs]
        if output_length == 0:
            # Empty prefix: a candidate can start at frame 0 as non-blank.
            r[0, 0] = xs[0]
            r[0, 1] = self.logzero
        else:
            # Frames before the prefix length cannot host the new label.
            r[output_length - 1] = self.logzero

        # prepare forward probabilities for the last label
        r_sum = self.xp.logaddexp(
            r_prev[:, 0], r_prev[:, 1]
        )  # log(r_t^n(g) + r_t^b(g))
        last = y[-1]
        if output_length > 0 and last in cs:
            # Repeating the last label requires an intervening blank, so that
            # candidate only connects through the blank path r_prev[:, 1].
            log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
            for i in six.moves.range(len(cs)):
                log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
        else:
            log_phi = r_sum

        # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
        # and log prefix probabilities log(psi)
        start = max(output_length, 1)
        log_psi = r[start - 1, 0]
        for t in six.moves.range(start, self.input_length):
            r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
            r[t, 1] = (
                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
            )
            log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])

        # get P(...eos|X) that ends with the prefix itself
        eos_pos = self.xp.where(cs == self.eos)[0]
        if len(eos_pos) > 0:
            log_psi[eos_pos] = r_sum[-1]  # log(r_T^n(g) + r_T^b(g))

        # exclude blank probs
        blank_pos = self.xp.where(cs == self.blank)[0]
        if len(blank_pos) > 0:
            log_psi[blank_pos] = self.logzero

        # return the log prefix probability and CTC states, where the label axis
        # of the CTC states is moved to the first axis to slice it easily
        return log_psi, self.xp.rollaxis(r, 2)
116
+
117
+
118
class CTCPrefixScoreTH(object):
    """Batch processing of CTCPrefixScore.

    Based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the label probabilities for multiple
    hypotheses simultaneously.
    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
    """

    def __init__(self, x, xlens, blank, eos, margin=0):
        """Construct CTC prefix scorer

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        :param torch.Tensor xlens: input lengths (B,)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param int margin: margin parameter for windowing (0 means no windowing)
        """
        # In the comment lines,
        # we assume T: input_length, B: batch size, W: beam width, O: output dim
        self.logzero = -10000000000.0  # practical -inf in the log domain
        self.blank = blank
        self.eos = eos
        self.batch = x.size(0)
        self.input_length = x.size(1)
        self.odim = x.size(2)
        self.dtype = x.dtype
        self.device = (
            torch.device("cuda:%d" % x.get_device())
            if x.is_cuda
            else torch.device("cpu")
        )
        # Pad the rest of posteriors in the batch.
        # NOTE: this mutates the caller's `x` in place.
        # TODO(takaaki-hori): need a better way without for-loops
        for i, l in enumerate(xlens):
            if l < self.input_length:
                x[i, l:, :] = self.logzero
                x[i, l:, blank] = 0
        # Reshape input x
        xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
        # xb broadcasts the blank posterior over the label axis so that both
        # the non-blank (xn) and blank (xb) views have the same shape.
        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
        self.x = torch.stack([xn, xb])  # (2, T, B, O)
        self.end_frames = torch.as_tensor(xlens) - 1

        # Setup CTC windowing
        self.margin = margin
        if margin > 0:
            self.frame_ids = torch.arange(
                self.input_length, dtype=self.dtype, device=self.device
            )
        # Base indices for index conversion
        self.idx_bh = None
        self.idx_b = torch.arange(self.batch, device=self.device)
        self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)

    def __call__(self, y, state, scoring_ids=None, att_w=None):
        """Compute CTC prefix scores for next labels

        :param list y: prefix label sequences
        :param tuple state: previous CTC state
        :param torch.Tensor scoring_ids: selected label ids for pre-selection
            of hypotheses (BW, O)
        :param torch.Tensor att_w: attention weights to decide CTC window
        :return new_state, ctc_local_scores (BW, O)
        """
        output_length = len(y[0]) - 1  # ignore sos
        last_ids = [yi[-1] for yi in y]  # last output label ids
        n_bh = len(last_ids)  # batch * hyps
        n_hyps = n_bh // self.batch  # assuming each utterance has the same
        self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
        # prepare state info
        if state is None:
            # First call: blank path accumulates blank posteriors; non-blank
            # path starts at logzero (same as CTCPrefixScore.initial_state).
            r_prev = torch.full(
                (self.input_length, 2, self.batch, n_hyps),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, n_bh)
            s_prev = 0.0
            f_min_prev = 0
            f_max_prev = 1
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

        # select input dimensions for scoring
        if self.scoring_num > 0:
            # scoring_idmap maps a full label id to its position inside the
            # reduced scoring set (-1 for labels not being scored).
            scoring_idmap = torch.full(
                (n_bh, self.odim), -1, dtype=torch.long, device=self.device
            )
            snum = self.scoring_num
            if self.idx_bh is None or n_bh > len(self.idx_bh):
                self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
                snum, device=self.device
            )
            scoring_idx = (
                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
            ).view(-1)
            x_ = torch.index_select(
                self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
            ).view(2, -1, n_bh, snum)
        else:
            scoring_ids = None
            scoring_idmap = None
            snum = self.odim
            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)

        # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
        # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
        r = torch.full(
            (self.input_length, 2, n_bh, snum),
            self.logzero,
            dtype=self.dtype,
            device=self.device,
        )
        if output_length == 0:
            r[0, 0] = x_[0, 0]

        r_sum = torch.logsumexp(r_prev, 1)
        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
        # Repeating the last label must go through the blank path only.
        if scoring_ids is not None:
            for idx in range(n_bh):
                pos = scoring_idmap[idx, last_ids[idx]]
                if pos >= 0:
                    log_phi[:, idx, pos] = r_prev[:, 1, idx]
        else:
            for idx in range(n_bh):
                log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]

        # decide start and end frames based on attention weights
        if att_w is not None and self.margin > 0:
            f_arg = torch.matmul(att_w, self.frame_ids)
            f_min = max(int(f_arg.min().cpu()), f_min_prev)
            f_max = max(int(f_arg.max().cpu()), f_max_prev)
            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
            end = min(f_max + self.margin, self.input_length)
        else:
            f_min = f_max = 0
            start = max(output_length, 1)
            end = self.input_length

        # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
        for t in range(start, end):
            rp = r[t - 1]
            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
                2, 2, n_bh, snum
            )
            r[t] = torch.logsumexp(rr, 1) + x_[:, t]

        # compute log prefix probabilities log(psi)
        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
        if scoring_ids is not None:
            # Scatter the reduced scores back into the full label dimension;
            # labels outside the scoring set keep logzero.
            log_psi = torch.full(
                (n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
            )
            log_psi_ = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )
            for si in range(n_bh):
                log_psi[si, scoring_ids[si]] = log_psi_[si]
        else:
            log_psi = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )

        # eos score is the total prefix probability at the last valid frame.
        for si in range(n_bh):
            log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]

        # exclude blank probs
        log_psi[:, self.blank] = self.logzero

        # Return scores relative to the previous accumulated score.
        return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)

    def index_select_state(self, state, best_ids):
        """Select CTC states according to best ids

        :param state : CTC state
        :param best_ids : index numbers selected by beam pruning (B, W)
        :return selected_state
        """
        r, s, f_min, f_max, scoring_idmap = state
        # convert ids to BHO space
        n_bh = len(s)
        n_hyps = n_bh // self.batch
        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
        # select hypothesis scores
        s_new = torch.index_select(s.view(-1), 0, vidx)
        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
        # convert ids to BHS space (S: scoring_num)
        if scoring_idmap is not None:
            snum = self.scoring_num
            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
                -1
            )
            label_ids = torch.fmod(best_ids, self.odim).view(-1)
            score_idx = scoring_idmap[hyp_idx, label_ids]
            # Labels outside the scoring set map to slot 0 as a safe default.
            score_idx[score_idx == -1] = 0
            vidx = score_idx + hyp_idx * snum
        else:
            snum = self.odim
        # select forward probabilities
        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
            -1, 2, n_bh
        )
        return r_new, s_new, f_min, f_max

    def extend_prob(self, x):
        """Extend CTC prob.

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        """

        if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
            # Pad the rest of posteriors in the batch
            # TODO(takaaki-hori): need a better way without for-loops
            xlens = [x.size(1)]
            for i, l in enumerate(xlens):
                if l < self.input_length:
                    x[i, l:, :] = self.logzero
                    x[i, l:, self.blank] = 0
            tmp_x = self.x
            xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
            self.x = torch.stack([xn, xb])  # (2, T, B, O)
            # Keep the previously seen frames; only new frames are appended.
            self.x[:, : tmp_x.shape[1], :, :] = tmp_x
            self.input_length = x.size(1)
            self.end_frames = torch.as_tensor(xlens) - 1

    def extend_state(self, state):
        """Compute CTC prefix state.


        :param state : CTC state
        :return ctc_state
        """

        if state is None:
            # nothing to do
            return state
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

            r_prev_new = torch.full(
                (self.input_length, 2),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            start = max(r_prev.shape[0], 1)
            r_prev_new[0:start] = r_prev
            # Extend the blank path over the newly appended frames.
            for t in six.moves.range(start, self.input_length):
                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]

            return r_prev_new, s_prev, f_min_prev, f_max_prev
Amphion/modules/wenet_extractor/squeezeformer/positionwise_feed_forward.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This module is from [WeNet](https://github.com/wenet-e2e/wenet).
2
+
3
+ # ## Citations
4
+
5
+ # ```bibtex
6
+ # @inproceedings{yao2021wenet,
7
+ # title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
8
+ # author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
9
+ # booktitle={Proc. Interspeech},
10
+ # year={2021},
11
+ # address={Brno, Czech Republic },
12
+ # organization={IEEE}
13
+ # }
14
+
15
+ # @article{zhang2022wenet,
16
+ # title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
17
+ # author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
18
+ # journal={arXiv preprint arXiv:2203.15455},
19
+ # year={2022}
20
+ # }
21
+ #
22
+
23
+ """Positionwise feed forward layer definition."""
24
+
25
+ import torch
26
+
27
+
28
class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    FeedForward is applied on each position of the sequence.
    The output dim is the same as the input dim.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function.
        adaptive_scale (bool): Whether the learnable per-channel input
            scale/bias (``ada_scale`` / ``ada_bias``) is applied and trainable.
            The parameters are always registered so checkpoint layouts stay
            stable; when False they are frozen at identity (scale=1, bias=0)
            and not applied in ``forward``.
        init_weights (bool): If True, re-initialize both linear layers with a
            fan-in-scaled uniform distribution (Squeezeformer-style init).
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
        adaptive_scale: bool = False,
        init_weights: bool = False,
    ):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.idim = idim
        self.hidden_units = hidden_units
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.adaptive_scale = adaptive_scale
        # Always registered (trainable only when adaptive_scale is set); the
        # original's dead `self.ada_scale = None` pre-assignments are removed.
        self.ada_scale = torch.nn.Parameter(
            torch.ones([1, 1, idim]), requires_grad=adaptive_scale
        )
        self.ada_bias = torch.nn.Parameter(
            torch.zeros([1, 1, idim]), requires_grad=adaptive_scale
        )
        if init_weights:
            self.init_weights()

    def init_weights(self):
        """Re-initialize both linear layers with fan-in-scaled uniform init."""
        ffn1_max = self.idim**-0.5
        ffn2_max = self.hidden_units**-0.5
        torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max)
        torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max)
        torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max)
        torch.nn.init.uniform_(self.w_2.bias.data, -ffn2_max, ffn2_max)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        if self.adaptive_scale:
            # Learnable per-channel affine transform of the input.
            xs = self.ada_scale * xs + self.ada_bias
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
Amphion/modules/wenet_extractor/transformer/decoder_layer.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This module is from [WeNet](https://github.com/wenet-e2e/wenet).
2
+
3
+ # ## Citations
4
+
5
+ # ```bibtex
6
+ # @inproceedings{yao2021wenet,
7
+ # title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
8
+ # author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
9
+ # booktitle={Proc. Interspeech},
10
+ # year={2021},
11
+ # address={Brno, Czech Republic },
12
+ # organization={IEEE}
13
+ # }
14
+
15
+ # @article{zhang2022wenet,
16
+ # title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
17
+ # author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
18
+ # journal={arXiv preprint arXiv:2203.15455},
19
+ # year={2022}
20
+ # }
21
+ #
22
+
23
+ """Decoder self-attention layer definition."""
24
+ from typing import Optional, Tuple
25
+
26
+ import torch
27
+ from torch import nn
28
+
29
+
30
class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Inter-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
            If `None` is passed, Inter-attention is not used, such as
            CIF, GPT, and other decoder only model.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: to use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: nn.Module,
        src_attn: Optional[nn.Module],
        feed_forward: nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
    ):
        """Construct an DecoderLayer object."""
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # One LayerNorm per sub-block: self-attention, cross-attention, FFN.
        self.norm1 = nn.LayerNorm(size, eps=1e-5)
        self.norm2 = nn.LayerNorm(size, eps=1e-5)
        self.norm3 = nn.LayerNorm(size, eps=1e-5)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before

    def forward(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor
                (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory
                (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask
                (#batch, maxlen_in).
            cache (torch.Tensor): cached tensors.
                (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).

        """
        # ---- Self-attention sub-block (pre- or post-norm residual). ----
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            # Fix: the message is now an f-string; the original was a plain
            # string literal, so the shapes were never interpolated.
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = tgt_mask[:, -1:, :]

        x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
        if not self.normalize_before:
            x = self.norm1(x)

        # ---- Cross-attention sub-block (skipped for decoder-only models). ----
        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm2(x)
            x = residual + self.dropout(
                self.src_attn(x, memory, memory, memory_mask)[0]
            )
            if not self.normalize_before:
                x = self.norm2(x)

        # ---- Feed-forward sub-block. ----
        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        # Re-attach the cached prefix so the caller gets the full sequence.
        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
Amphion/modules/wenet_extractor/transformer/subsampling.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This module is from [WeNet](https://github.com/wenet-e2e/wenet).
2
+
3
+ # ## Citations
4
+
5
+ # ```bibtex
6
+ # @inproceedings{yao2021wenet,
7
+ # title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
8
+ # author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
9
+ # booktitle={Proc. Interspeech},
10
+ # year={2021},
11
+ # address={Brno, Czech Republic },
12
+ # organization={IEEE}
13
+ # }
14
+
15
+ # @article{zhang2022wenet,
16
+ # title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
17
+ # author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
18
+ # journal={arXiv preprint arXiv:2203.15455},
19
+ # year={2022}
20
+ # }
21
+ #
22
+
23
+
24
+ """Subsampling layer definition."""
25
+
26
+ from typing import Tuple, Union
27
+
28
+ import torch
29
+
30
+
31
class BaseSubsampling(torch.nn.Module):
    """Common interface for the subsampling front-ends.

    Subclasses overwrite ``subsampling_rate`` (frame-rate reduction factor)
    and ``right_context`` (future input frames one output frame depends on),
    and attach a ``pos_enc`` positional-encoding module.
    """

    def __init__(self):
        super().__init__()
        # Defaults correspond to "no subsampling at all".
        self.right_context = 0
        self.subsampling_rate = 1

    def position_encoding(
        self, offset: Union[int, torch.Tensor], size: int
    ) -> torch.Tensor:
        """Delegate positional-encoding lookup to the attached ``pos_enc``."""
        return self.pos_enc.position_encoding(offset, size)
41
+
42
+
43
class LinearNoSubsampling(BaseSubsampling):
    """Project the input linearly without any subsampling.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
    """

    def __init__(
        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
    ):
        """Construct an linear object."""
        super().__init__()
        # Project, normalize, regularize -- the time axis is untouched.
        projection = torch.nn.Linear(idim, odim)
        norm = torch.nn.LayerNorm(odim, eps=1e-5)
        drop = torch.nn.Dropout(dropout_rate)
        self.out = torch.nn.Sequential(projection, norm, drop)
        self.pos_enc = pos_enc_class
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform x without changing its length.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: projected tensor (#batch, time, odim).
            torch.Tensor: positional encoding.
            torch.Tensor: the unchanged input mask (#batch, 1, time).
        """
        projected = self.out(x)
        encoded, pos_emb = self.pos_enc(projected, offset)
        return encoded, pos_emb, x_mask
89
+
90
+
91
class Conv2dSubsampling4(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/4 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
    """

    def __init__(
        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
    ):
        """Construct an Conv2dSubsampling4 object."""
        super().__init__()
        # Two stride-2 3x3 convolutions: each halves both time and features.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # Feature width left after the two convolutions.
        freq_out = ((idim - 1) // 2 - 1) // 2
        self.out = torch.nn.Sequential(torch.nn.Linear(odim * freq_out, odim))
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 4
        # 6 = (3 - 1) * 1 + (3 - 1) * 2
        self.right_context = 6

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x to 1/4 of its length.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time // 4, odim).
            torch.Tensor: positional encoding.
            torch.Tensor: Subsampled mask (#batch, 1, time // 4).
        """
        feats = self.conv(x.unsqueeze(1))  # (b, 1, t, f) -> (b, c, t', f')
        batch, chans, frames, freq = feats.size()
        flat = feats.transpose(1, 2).contiguous().view(batch, frames, chans * freq)
        out, pos_emb = self.pos_enc(self.out(flat), offset)
        # Two stride-2 stages -> drop mask frames the same way.
        subsampled_mask = x_mask[:, :, 2::2][:, :, 2::2]
        return out, pos_emb, subsampled_mask
148
+
149
+
150
class Conv2dSubsampling6(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/6 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc_class (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(
        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
    ):
        """Construct an Conv2dSubsampling6 object."""
        super().__init__()
        # A stride-2 3x3 conv followed by a stride-3 5x5 conv: 2 * 3 = 6x
        # reduction along the time axis.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        freq_out = ((idim - 1) // 2 - 2) // 3
        self.linear = torch.nn.Linear(odim * freq_out, odim)
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 6
        # 10 = (3 - 1) * 1 + (5 - 1) * 2
        self.right_context = 10

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x to 1/6 of its length.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time // 6, odim).
            torch.Tensor: positional encoding.
            torch.Tensor: Subsampled mask (#batch, 1, time // 6).
        """
        feats = self.conv(x.unsqueeze(1))  # (b, 1, t, f) -> (b, c, t', f')
        batch, chans, frames, freq = feats.size()
        flat = feats.transpose(1, 2).contiguous().view(batch, frames, chans * freq)
        out, pos_emb = self.pos_enc(self.linear(flat), offset)
        # Stride-2 then stride-3 -> matching slicing of the mask.
        return out, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
200
+
201
+
202
class Conv2dSubsampling8(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/8 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
    """

    def __init__(
        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
    ):
        """Construct an Conv2dSubsampling8 object."""
        super().__init__()
        # Three stride-2 3x3 convolutions: 2 * 2 * 2 = 8x time reduction.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        freq_out = (((idim - 1) // 2 - 1) // 2 - 1) // 2
        self.linear = torch.nn.Linear(odim * freq_out, odim)
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 8
        # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
        self.right_context = 14

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x to 1/8 of its length.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time // 8, odim).
            torch.Tensor: positional encoding.
            torch.Tensor: Subsampled mask (#batch, 1, time // 8).
        """
        feats = self.conv(x.unsqueeze(1))  # (b, 1, t, f) -> (b, c, t', f')
        batch, chans, frames, freq = feats.size()
        flat = feats.transpose(1, 2).contiguous().view(batch, frames, chans * freq)
        out, pos_emb = self.pos_enc(self.linear(flat), offset)
        # Three stride-2 stages -> slice the mask three times.
        return out, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
Amphion/modules/wenet_extractor/utils/__init__.py ADDED
File without changes
Amphion/preprocessors/cdmusiceval.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from glob import glob
7
+ import os
8
+ import json
9
+ import torchaudio
10
+ from tqdm import tqdm
11
+ from collections import defaultdict
12
+
13
+ from utils.util import has_existed, remove_and_create
14
+ from utils.audio_slicer import split_utterances_from_audio
15
+
16
+
17
def split_to_utterances(input_dir, output_dir):
    """Slice every vocal file under ``input_dir`` into short utterances.

    Each file name is parsed as ``*_*_{song}-{singer}``; its utterances are
    written to ``output_dir/{singer}/{song}/`` (max 10 s per utterance).
    """
    print("Splitting to utterances for {}...".format(input_dir))

    for wav_file in tqdm(sorted(glob("*", root_dir=input_dir))):
        # Singer name and song name are encoded in the third "_" field.
        song_name, singer_name = wav_file.split("_")[2].split("-")
        save_dir = os.path.join(output_dir, singer_name, song_name)

        split_utterances_from_audio(
            os.path.join(input_dir, wav_file), save_dir, max_duration_of_utterance=10
        )
44
+
45
+
46
def _main(dataset_path):
    """Rebuild ``{dataset_path}/utterances`` by splitting the raw vocal files."""
    utterance_dir = os.path.join(dataset_path, "utterances")
    remove_and_create(utterance_dir)
    split_to_utterances(os.path.join(dataset_path, "vocal"), utterance_dir)
53
+
54
+
55
def statistics(utterance_dir):
    """Scan ``utterance_dir`` (layout: singer/song/uid.wav) and summarize it.

    Returns:
        singers2songs: singer -> song -> list of utterance ids.
        unique_singers: sorted list of distinct singer names.
    """
    singers = []
    songs = []
    singers2songs = defaultdict(lambda: defaultdict(list))

    for singer_info in glob(utterance_dir + "/*"):
        singer = singer_info.split("/")[-1]

        for song_info in glob(singer_info + "/*"):
            song = song_info.split("/")[-1]

            singers.append(singer)
            songs.append(song)

            for utt in glob(song_info + "/*.wav"):
                uid = utt.split("/")[-1].split(".")[0]
                singers2songs[singer][song].append(uid)

    unique_singers = sorted(set(singers))
    unique_songs = sorted(set(songs))

    print(
        "Statistics: {} singers, {} utterances ({} unique songs)".format(
            len(unique_singers), len(songs), len(unique_songs)
        )
    )
    print("Singers: \n{}".format("\t".join(unique_singers)))
    return singers2songs, unique_singers
91
+
92
+
93
def main(output_path, dataset_path):
    """Prepare the CD Music Eval metadata for Amphion.

    Splits the raw "vocal" stems into utterances if not done yet, then writes
    ``train.json`` / ``test.json``, the singer lookup table ``singers.json``
    and an ``utt2singer`` mapping under ``{output_path}/cdmusiceval``.

    Note: every utterance goes to the *test* split; the train split is
    intentionally left empty for this evaluation-only dataset.
    """
    print("-" * 10)
    print("Preparing samples for CD Music Eval...\n")

    if not os.path.exists(os.path.join(dataset_path, "utterances")):
        print("Spliting into utterances...\n")
        _main(dataset_path)

    save_dir = os.path.join(output_path, "cdmusiceval")
    os.makedirs(save_dir, exist_ok=True)
    train_output_file = os.path.join(save_dir, "train.json")
    test_output_file = os.path.join(save_dir, "test.json")
    singer_dict_file = os.path.join(save_dir, "singers.json")
    utt2singer_file = os.path.join(save_dir, "utt2singer")
    # Skip the (expensive) re-scan when all outputs already exist.
    if (
        has_existed(train_output_file)
        and has_existed(test_output_file)
        and has_existed(singer_dict_file)
        and has_existed(utt2singer_file)
    ):
        return

    # Load
    utt_path = os.path.join(dataset_path, "utterances")
    singers2songs, unique_singers = statistics(utt_path)

    # We select songs of standard samples as test songs
    train = []
    test = []

    train_index_count = 0
    test_index_count = 0

    train_total_duration = 0
    test_total_duration = 0

    # Fix: the mapping file is now opened in a context manager so it is
    # closed (and flushed) even if an exception is raised mid-scan; the
    # original leaked the handle.
    with open(utt2singer_file, "w") as utt2singer:
        for singer, songs in tqdm(singers2songs.items()):
            song_names = list(songs.keys())

            for chosen_song in song_names:
                for chosen_uid in songs[chosen_song]:
                    res = {
                        "Dataset": "cdmusiceval",
                        "Singer": singer,
                        "Uid": "{}_{}_{}".format(singer, chosen_song, chosen_uid),
                    }
                    res["Path"] = "{}/{}/{}.wav".format(
                        singer, chosen_song, chosen_uid
                    )
                    res["Path"] = os.path.join(utt_path, res["Path"])
                    assert os.path.exists(res["Path"])

                    waveform, sample_rate = torchaudio.load(res["Path"])
                    duration = waveform.size(-1) / sample_rate
                    res["Duration"] = duration

                    # Skip (effectively) empty clips.
                    if duration <= 1e-8:
                        continue

                    res["index"] = test_index_count
                    test_total_duration += duration
                    test.append(res)
                    test_index_count += 1

                    utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))

    print("#Train = {}, #Test = {}".format(len(train), len(test)))
    print(
        "#Train hours= {}, #Test hours= {}".format(
            train_total_duration / 3600, test_total_duration / 3600
        )
    )

    # Save train.json and test.json
    with open(train_output_file, "w") as f:
        json.dump(train, f, indent=4, ensure_ascii=False)
    with open(test_output_file, "w") as f:
        json.dump(test, f, indent=4, ensure_ascii=False)

    # Save singers.json
    singer_lut = {name: i for i, name in enumerate(unique_singers)}
    with open(singer_dict_file, "w") as f:
        json.dump(singer_lut, f, indent=4, ensure_ascii=False)
Amphion/utils/data_utils.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+
9
+ import numpy as np
10
+ from scipy.interpolate import interp1d
11
+ from tqdm import tqdm
12
+ from sklearn.preprocessing import StandardScaler
13
+
14
+
15
def intersperse(lst, item):
    """
    Insert ``item`` between any two consecutive elements of ``lst``,
    including the beginning and end of the list.

    Args:
        lst: the source sequence.
        item: the separator value to insert.

    Returns:
        A new list of length ``2 * len(lst) + 1``.

    Example:
        >>> intersperse([1, 74, 5, 31], 0)
        [0, 1, 0, 74, 0, 5, 0, 31, 0]
    """
    # Fix: the original doctest called intersperse(0, [1, 74, 5, 31]),
    # i.e. with the arguments swapped relative to the signature.
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result
26
+
27
+
28
def load_content_feature_path(meta_data, processed_dir, feat_dir):
    """Map each utterance key ("{Dataset}_{Uid}") to its content-feature path.

    The path is ``{processed_dir}/{Dataset}/{feat_dir}/{Uid}.npy``.
    """
    return {
        info["Dataset"] + "_" + info["Uid"]: os.path.join(
            processed_dir, info["Dataset"], feat_dir, f'{info["Uid"]}.npy'
        )
        for info in meta_data
    }
38
+
39
+
40
def load_source_content_feature_path(meta_data, feat_dir):
    """Map each utterance id in ``meta_data`` to ``{feat_dir}/{utt}.npy``."""
    return {utt: os.path.join(feat_dir, f"{utt}.npy") for utt in meta_data}
47
+
48
+
49
def get_spk_map(spk2id_path, utt2spk_path):
    """Load the speaker-id table and the utterance-to-speaker mapping.

    Args:
        spk2id_path: JSON file mapping speaker name -> integer id.
        utt2spk_path: text file of tab-separated "utt\\tspk" lines.

    Returns:
        (spk2id, utt2spk) dictionaries.
    """
    with open(spk2id_path, "r") as spk2id_file:
        spk2id = json.load(spk2id_file)
    utt2spk = {}
    with open(utt2spk_path, encoding="utf-8") as f:
        for line in f:
            utt, spk = line.strip().split("\t")
            utt2spk[utt] = spk
    return spk2id, utt2spk
58
+
59
+
60
def get_target_f0_median(f0_dir):
    """Median F0 over all voiced (non-zero) frames of every .npy file in ``f0_dir``."""
    collected = []
    for fname in os.listdir(f0_dir):
        # Only F0 feature files; everything else in the folder is ignored.
        if fname.endswith(".npy"):
            collected.extend(np.load(os.path.join(f0_dir, fname)).tolist())

    all_f0 = np.array(collected)
    voiced = np.where(all_f0 != 0)
    return np.median(all_f0[voiced])
72
+
73
+
74
def get_conversion_f0_factor(source_f0, target_median, source_median=None):
    """Align the median between source f0 and target f0.

    Note: Here we use multiplication, whose factor is target_median/source_median

    Reference: Frequency and pitch interval
    http://blog.ccyg.studio/article/be12c2ee-d47c-4098-9782-ca76da3035e4/
    """
    if source_median is None:
        # Median over voiced (non-zero) frames only.
        source_median = np.median(source_f0[np.where(source_f0 != 0)])
    return source_median, target_median / source_median
87
+
88
+
89
def transpose_key(frame_pitch, trans_key):
    """Shift a frame-level pitch contour by ``trans_key`` semitones."""
    # Transpose by user's argument
    print("Transpose key = {} ...\n".format(trans_key))

    semitone_ratio = 2 ** (trans_key / 12)
    return frame_pitch * semitone_ratio
95
+
96
+
97
def pitch_shift_to_target(frame_pitch, target_pitch_median, source_pitch_median=None):
    """Scale a pitch contour so its median matches ``target_pitch_median``."""
    # Loading F0 Base (median) and shift
    source_pitch_median, factor = get_conversion_f0_factor(
        frame_pitch, target_pitch_median, source_pitch_median
    )
    print(
        "Auto transposing: source f0 median = {:.1f}, target f0 median = {:.1f}, factor = {:.2f}".format(
            source_pitch_median, target_pitch_median, factor
        )
    )
    return frame_pitch * factor
109
+
110
+
111
def load_frame_pitch(
    meta_data,
    processed_dir,
    pitch_dir,
    use_log_scale=False,
    return_norm=False,
    interoperate=False,
    utt2spk=None,
):
    """Load frame-level pitch (F0) features for every utterance in ``meta_data``.

    Pitch is read from ``{processed_dir}/{Dataset}/{pitch_dir}/{Uid}.npy``.
    Voiced frames (pitch != 0) may be log-scaled, and the contours may be
    z-normalized. With ``utt2spk`` given, normalization statistics are
    computed per speaker instead of globally.

    Returns:
        utt2pitch: utt -> pitch array (normalized if ``return_norm``).
        utt2uv: utt -> boolean voiced mask (pitch != 0).
        pitch_statistic: dict with mean/std (global) or a list of
            per-speaker dicts with spk/mean/std.
    """
    utt2pitch = {}
    utt2uv = {}
    if utt2spk is None:
        # Global branch: one scaler fitted over all utterances.
        pitch_scaler = StandardScaler()
        for utt_info in meta_data:
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            pitch_path = os.path.join(
                processed_dir, utt_info["Dataset"], pitch_dir, f'{utt_info["Uid"]}.npy'
            )
            pitch = np.load(pitch_path)
            assert len(pitch) > 0
            uv = pitch != 0
            utt2uv[utt] = uv
            if use_log_scale:
                # Log-scale voiced frames only; zeros stay zero.
                nonzero_idxes = np.where(pitch != 0)[0]
                pitch[nonzero_idxes] = np.log(pitch[nonzero_idxes])
            utt2pitch[utt] = pitch
            pitch_scaler.partial_fit(pitch.reshape(-1, 1))

        mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0]
        if return_norm:
            # Second pass: z-normalize with the fitted global statistics.
            for utt_info in meta_data:
                utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
                pitch = utt2pitch[utt]
                normalized_pitch = (pitch - mean) / std
                utt2pitch[utt] = normalized_pitch
        pitch_statistic = {"mean": mean, "std": std}
    else:
        # Per-speaker branch: group utterances by speaker, fit one scaler each.
        spk2utt = {}
        pitch_statistic = []
        for utt_info in meta_data:
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            if not utt2spk[utt] in spk2utt:
                spk2utt[utt2spk[utt]] = []
            spk2utt[utt2spk[utt]].append(utt)

        for spk in spk2utt:
            pitch_scaler = StandardScaler()
            for utt in spk2utt[spk]:
                # utt key is "{Dataset}_{Uid}"; Uid itself may contain "_".
                dataset = utt.split("_")[0]
                uid = "_".join(utt.split("_")[1:])
                pitch_path = os.path.join(
                    processed_dir, dataset, pitch_dir, f"{uid}.npy"
                )
                pitch = np.load(pitch_path)
                assert len(pitch) > 0
                uv = pitch != 0
                utt2uv[utt] = uv
                if use_log_scale:
                    nonzero_idxes = np.where(pitch != 0)[0]
                    pitch[nonzero_idxes] = np.log(pitch[nonzero_idxes])
                utt2pitch[utt] = pitch
                pitch_scaler.partial_fit(pitch.reshape(-1, 1))

            mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0]
            if return_norm:
                for utt in spk2utt[spk]:
                    pitch = utt2pitch[utt]
                    normalized_pitch = (pitch - mean) / std
                    utt2pitch[utt] = normalized_pitch
            pitch_statistic.append({"spk": spk, "mean": mean, "std": std})

    return utt2pitch, utt2uv, pitch_statistic
183
+
184
+
185
# discard (marked as discarded upstream; kept for reference)
def load_phone_pitch(
    meta_data,
    processed_dir,
    pitch_dir,
    utt2dur,
    use_log_scale=False,
    return_norm=False,
    interoperate=True,
    utt2spk=None,
):
    """Load phone-level averaged pitch for every utterance in ``meta_data``.

    Frame-level F0 is read from ``{processed_dir}/{Dataset}/{pitch_dir}/{Uid}.npy``,
    averaged per phone via ``phone_average_pitch`` using durations in
    ``utt2dur``, optionally log-scaled on voiced values and z-normalized.
    With ``utt2spk`` given, statistics are computed per speaker.

    Returns:
        utt2pitch: utt -> phone-level pitch array.
        utt2uv: utt -> frame-level voiced mask (frame pitch != 0).
        pitch_statistic: dict (global) or list of per-speaker dicts with
            mean/std/min_value/max_value.
    """
    print("Load Phone Pitch")
    utt2pitch = {}
    utt2uv = {}
    if utt2spk is None:
        # Global branch: one scaler over all utterances.
        pitch_scaler = StandardScaler()
        for utt_info in tqdm(meta_data):
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            pitch_path = os.path.join(
                processed_dir, utt_info["Dataset"], pitch_dir, f'{utt_info["Uid"]}.npy'
            )
            frame_pitch = np.load(pitch_path)
            assert len(frame_pitch) > 0
            uv = frame_pitch != 0
            utt2uv[utt] = uv
            phone_pitch = phone_average_pitch(frame_pitch, utt2dur[utt], interoperate)
            if use_log_scale:
                # Log-scale voiced values only; zeros stay zero.
                nonzero_idxes = np.where(phone_pitch != 0)[0]
                phone_pitch[nonzero_idxes] = np.log(phone_pitch[nonzero_idxes])
            utt2pitch[utt] = phone_pitch
            # Outlier-filtered values feed the scaler only; the stored
            # contour keeps all values.
            pitch_scaler.partial_fit(remove_outlier(phone_pitch).reshape(-1, 1))

        mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0]
        max_value = np.finfo(np.float64).min
        min_value = np.finfo(np.float64).max
        if return_norm:
            for utt_info in meta_data:
                utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
                pitch = utt2pitch[utt]
                normalized_pitch = (pitch - mean) / std
                max_value = max(max_value, max(normalized_pitch))
                min_value = min(min_value, min(normalized_pitch))
                utt2pitch[utt] = normalized_pitch
                # NOTE(review): this path is computed but never used or
                # written — apparently dead code in this discarded loader.
                phone_normalized_pitch_path = os.path.join(
                    processed_dir,
                    utt_info["Dataset"],
                    "phone_level_" + pitch_dir,
                    f'{utt_info["Uid"]}.npy',
                )
        pitch_statistic = {
            "mean": mean,
            "std": std,
            "min_value": min_value,
            "max_value": max_value,
        }
    else:
        # Per-speaker branch: group utterances by speaker first.
        spk2utt = {}
        pitch_statistic = []
        for utt_info in tqdm(meta_data):
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            if not utt2spk[utt] in spk2utt:
                spk2utt[utt2spk[utt]] = []
            spk2utt[utt2spk[utt]].append(utt)

        for spk in spk2utt:
            pitch_scaler = StandardScaler()
            for utt in spk2utt[spk]:
                # utt key is "{Dataset}_{Uid}"; Uid itself may contain "_".
                dataset = utt.split("_")[0]
                uid = "_".join(utt.split("_")[1:])
                pitch_path = os.path.join(
                    processed_dir, dataset, pitch_dir, f"{uid}.npy"
                )
                frame_pitch = np.load(pitch_path)
                assert len(frame_pitch) > 0
                uv = frame_pitch != 0
                utt2uv[utt] = uv
                phone_pitch = phone_average_pitch(
                    frame_pitch, utt2dur[utt], interoperate
                )
                if use_log_scale:
                    nonzero_idxes = np.where(phone_pitch != 0)[0]
                    phone_pitch[nonzero_idxes] = np.log(phone_pitch[nonzero_idxes])
                utt2pitch[utt] = phone_pitch
                pitch_scaler.partial_fit(remove_outlier(phone_pitch).reshape(-1, 1))

            mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0]
            max_value = np.finfo(np.float64).min
            min_value = np.finfo(np.float64).max

            if return_norm:
                for utt in spk2utt[spk]:
                    pitch = utt2pitch[utt]
                    normalized_pitch = (pitch - mean) / std
                    max_value = max(max_value, max(normalized_pitch))
                    min_value = min(min_value, min(normalized_pitch))
                    utt2pitch[utt] = normalized_pitch
            pitch_statistic.append(
                {
                    "spk": spk,
                    "mean": mean,
                    "std": std,
                    "min_value": min_value,
                    "max_value": max_value,
                }
            )

    return utt2pitch, utt2uv, pitch_statistic
292
+
293
+
294
def phone_average_pitch(pitch, dur, interoperate=False):
    """Average frame-level pitch into one value per phone.

    Args:
        pitch: 1-D array of frame-level pitch values (0 = unvoiced).
        dur: per-phone durations, in frames.
        interoperate: if True, linearly interpolate across unvoiced (zero)
            frames before averaging (likely a misspelling of "interpolate").

    Returns:
        np.ndarray of shape (len(dur),) with the mean pitch of each phone;
        phones with zero duration or starting beyond the pitch track get 0.
    """
    if interoperate:
        voiced = np.where(pitch != 0)[0]
        # Extend with the first/last voiced value outside the voiced range.
        edge_fill = (pitch[voiced[0]], pitch[voiced[-1]])
        interp_fn = interp1d(
            voiced, pitch[voiced], fill_value=edge_fill, bounds_error=False
        )
        pitch = interp_fn(np.arange(0, len(pitch)))

    averaged = np.zeros(len(dur))
    frame_cursor = 0
    for idx, raw_dur in enumerate(dur):
        n_frames = int(raw_dur)
        if n_frames > 0 and frame_cursor < len(pitch):
            averaged[idx] = np.mean(pitch[frame_cursor : frame_cursor + n_frames])
        else:
            averaged[idx] = 0
        frame_cursor += n_frames
    return averaged
316
+
317
+
318
def load_energy(
    meta_data,
    processed_dir,
    energy_dir,
    use_log_scale=False,
    return_norm=False,
    utt2spk=None,
):
    """Load per-utterance energy features and optionally z-normalize them.

    In the global branch (``utt2spk is None``) normalization statistics are
    read from a precomputed ``statistics.json``; in the per-speaker branch a
    StandardScaler is fitted per speaker.

    Returns:
        (utt2energy, energy_statistic)
    """
    utt2energy = {}
    if utt2spk is None:
        for utt_info in meta_data:
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            energy_path = os.path.join(
                processed_dir, utt_info["Dataset"], energy_dir, f'{utt_info["Uid"]}.npy'
            )
            # Missing files are silently skipped.
            if not os.path.exists(energy_path):
                continue
            energy = np.load(energy_path)
            assert len(energy) > 0

            if use_log_scale:
                nonzero_idxes = np.where(energy != 0)[0]
                energy[nonzero_idxes] = np.log(energy[nonzero_idxes])
            utt2energy[utt] = energy

        if return_norm:
            # NOTE(review): relies on `utt_info` leaking out of the loop above,
            # i.e. the statistics file of the *last* utterance's dataset is
            # read; verify this is intended for multi-dataset runs.
            with open(
                os.path.join(
                    processed_dir, utt_info["Dataset"], energy_dir, "statistics.json"
                )
            ) as f:
                stats = json.load(f)
            # NOTE(review): mean is looked up per dataset/singer, but std is
            # hard-coded to the "LJSpeech_LJSpeech" entry — looks like a bug;
            # confirm before reuse outside LJSpeech.
            mean, std = (
                stats[utt_info["Dataset"] + "_" + utt_info["Singer"]][
                    "voiced_positions"
                ]["mean"],
                stats["LJSpeech_LJSpeech"]["voiced_positions"]["std"],
            )
            for utt in utt2energy.keys():
                energy = utt2energy[utt]
                normalized_energy = (energy - mean) / std
                utt2energy[utt] = normalized_energy

            # NOTE(review): energy_statistic (and mean/std) are only defined
            # when return_norm is True in this branch.
            energy_statistic = {"mean": mean, "std": std}
    else:
        # Per-speaker statistics: group utterances by speaker first.
        spk2utt = {}
        energy_statistic = []
        for utt_info in meta_data:
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            if not utt2spk[utt] in spk2utt:
                spk2utt[utt2spk[utt]] = []
            spk2utt[utt2spk[utt]].append(utt)

        for spk in spk2utt:
            energy_scaler = StandardScaler()
            for utt in spk2utt[spk]:
                # Keys are "<Dataset>_<Uid>"; the Uid itself may contain "_".
                dataset = utt.split("_")[0]
                uid = "_".join(utt.split("_")[1:])
                energy_path = os.path.join(
                    processed_dir, dataset, energy_dir, f"{uid}.npy"
                )
                if not os.path.exists(energy_path):
                    continue
                frame_energy = np.load(energy_path)
                assert len(frame_energy) > 0

                if use_log_scale:
                    nonzero_idxes = np.where(frame_energy != 0)[0]
                    frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes])
                utt2energy[utt] = frame_energy
                energy_scaler.partial_fit(frame_energy.reshape(-1, 1))

            mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0]
            if return_norm:
                for utt in spk2utt[spk]:
                    energy = utt2energy[utt]
                    normalized_energy = (energy - mean) / std
                    utt2energy[utt] = normalized_energy
            energy_statistic.append({"spk": spk, "mean": mean, "std": std})

    return utt2energy, energy_statistic
399
+
400
+
401
def load_frame_energy(
    meta_data,
    processed_dir,
    energy_dir,
    use_log_scale=False,
    return_norm=False,
    interoperate=False,
    utt2spk=None,
):
    """Load frame-level energy and optionally z-normalize it.

    A StandardScaler is fitted either globally (``utt2spk is None``) or per
    speaker.

    Args:
        meta_data: list of dicts with "Dataset" and "Uid" keys.
        processed_dir: root directory of preprocessed features.
        energy_dir: sub-directory holding per-utterance energy ``.npy`` files.
        use_log_scale: log-transform non-zero energy values.
        return_norm: if True, replace energies with (e - mean) / std.
        interoperate: unused here; kept for signature parity with the pitch
            loaders.
        utt2spk: optional utterance -> speaker mapping for per-speaker stats.

    Returns:
        (utt2energy, energy_statistic); dict of stats in the global case,
        list of per-speaker dicts otherwise.
    """
    utt2energy = {}
    if utt2spk is None:
        energy_scaler = StandardScaler()
        for utt_info in meta_data:
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            energy_path = os.path.join(
                processed_dir, utt_info["Dataset"], energy_dir, f'{utt_info["Uid"]}.npy'
            )
            frame_energy = np.load(energy_path)
            assert len(frame_energy) > 0

            if use_log_scale:
                nonzero_idxes = np.where(frame_energy != 0)[0]
                frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes])
            utt2energy[utt] = frame_energy
            energy_scaler.partial_fit(frame_energy.reshape(-1, 1))

        mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0]
        if return_norm:
            for utt_info in meta_data:
                utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
                energy = utt2energy[utt]
                normalized_energy = (energy - mean) / std
                utt2energy[utt] = normalized_energy
        energy_statistic = {"mean": mean, "std": std}

    else:
        # Per-speaker statistics: group utterances by speaker first.
        spk2utt = {}
        energy_statistic = []
        for utt_info in meta_data:
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            if not utt2spk[utt] in spk2utt:
                spk2utt[utt2spk[utt]] = []
            spk2utt[utt2spk[utt]].append(utt)

        for spk in spk2utt:
            energy_scaler = StandardScaler()
            for utt in spk2utt[spk]:
                # Keys are "<Dataset>_<Uid>"; the Uid itself may contain "_".
                dataset = utt.split("_")[0]
                uid = "_".join(utt.split("_")[1:])
                energy_path = os.path.join(
                    processed_dir, dataset, energy_dir, f"{uid}.npy"
                )
                frame_energy = np.load(energy_path)
                assert len(frame_energy) > 0

                if use_log_scale:
                    nonzero_idxes = np.where(frame_energy != 0)[0]
                    frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes])
                utt2energy[utt] = frame_energy
                energy_scaler.partial_fit(frame_energy.reshape(-1, 1))

            mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0]
            if return_norm:
                for utt in spk2utt[spk]:
                    energy = utt2energy[utt]
                    normalized_energy = (energy - mean) / std
                    utt2energy[utt] = normalized_energy
            energy_statistic.append({"spk": spk, "mean": mean, "std": std})

    return utt2energy, energy_statistic
471
+
472
+
473
def align_length(feature, target_len, pad_value=0.0):
    """Pad or truncate ``feature`` along its last axis to ``target_len``.

    Supports 1-D arrays of shape (T,) and 2-D arrays of shape (C, T); padding
    uses ``pad_value``.
    """
    current_len = feature.shape[-1]
    ndim = len(feature.shape)
    if ndim == 2:
        # 2-D data: pad/cut the time (last) axis only.
        if target_len > current_len:
            pad_spec = ((0, 0), (0, target_len - current_len))
            feature = np.pad(feature, pad_spec, constant_values=pad_value)
        else:
            feature = feature[:, :target_len]
    elif ndim == 1:
        # 1-D data.
        if target_len > current_len:
            feature = np.pad(
                feature, (0, target_len - current_len), constant_values=pad_value
            )
        else:
            feature = feature[:target_len]
    else:
        raise NotImplementedError
    return feature
497
+
498
+
499
def align_whisper_feauture_length(
    feature, target_len, fast_mapping=True, source_hop=320, target_hop=256
):
    """Resample Whisper features from the source hop size to the target hop.

    Frames are upsampled by repetition, block-averaged down to the target
    frame rate, and truncated to ``target_len``. Whisper emits at most 1500
    frames, so ``target_len`` is clipped accordingly.
    """
    hop_gcd = np.gcd(source_hop, target_hop)
    src_hop = source_hop // hop_gcd
    tgt_hop = target_hop // hop_gcd
    # tgt_hop source frames map onto src_hop target frames.

    max_source_len = 1500
    target_len = min(target_len, max_source_len * src_hop // tgt_hop)

    width = feature.shape[-1]

    if fast_mapping:
        # Keep only the source frames needed to produce target_len outputs.
        source_len = target_len * tgt_hop // src_hop + 1
        feature = feature[:source_len]
    else:
        source_len = max_source_len

    # const ~= target_len * target_hop; made a multiple of tgt_hop for reshape.
    const = source_len * src_hop // tgt_hop * tgt_hop

    # (source_len * src_hop, width)
    upsampled = np.repeat(feature, src_hop, axis=0)
    # (const, width) -> (const/tgt_hop, tgt_hop, width) -> block means
    downsampled = np.average(upsampled[:const].reshape(-1, tgt_hop, width), axis=1)
    assert len(downsampled) >= target_len

    # (target_len, width)
    return downsampled[:target_len]
538
+
539
+
540
def align_content_feature_length(feature, target_len, source_hop=320, target_hop=256):
    """Resample content features (e.g. SSL features) to the target frame rate.

    Each source frame is repeated ``source_hop`` times (after reducing both
    hops by their gcd) and block-averaged with window ``target_hop``. A small
    length mismatch against ``target_len`` is tolerated and fixed by repeating
    the last frame / truncating.

    Args:
        feature: array of shape (source_len, dim).
        target_len: desired number of output frames.
        source_hop: hop size (in samples) of the source features.
        target_hop: hop size (in samples) of the target frames.

    Returns:
        Array of shape (target_len, dim).

    Raises:
        ValueError: if the resampled length differs from ``target_len`` by
            more than 4 frames. (Previously this printed diagnostics and
            called ``exit()``, killing the whole process — library code
            should raise instead so callers can handle the bad utterance.)
    """
    factor = np.gcd(source_hop, target_hop)
    source_hop //= factor
    target_hop //= factor

    # (source_len, dim)
    source_len, width = feature.shape

    # const ~= target_len * target_hop; a multiple of target_hop for reshape.
    const = source_len * source_hop // target_hop * target_hop

    # (source_len * source_hop, dim)
    up_sampling_feats = np.repeat(feature, source_hop, axis=0)
    # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
    down_sampling_feats = np.average(
        up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
    )

    err = abs(target_len - len(down_sampling_feats))
    if err > 4:  ## why 4 not 3?
        raise ValueError(
            f"Length mismatch: target_len={target_len}, "
            f"raw feature={feature.shape}, "
            f"up_sampling={up_sampling_feats.shape}, "
            f"down_sampling_feats={down_sampling_feats.shape}"
        )
    if len(down_sampling_feats) < target_len:
        # (1, dim) -> (err, dim): pad by repeating the last frame.
        end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
        down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)

    # (target_len, dim)
    feat = down_sampling_feats[:target_len]

    return feat
579
+
580
+
581
def remove_outlier(values):
    """Drop values outside the Tukey fences (1.5 * IQR beyond Q1/Q3)."""
    values = np.array(values)
    q1 = np.percentile(values, 25)
    q3 = np.percentile(values, 75)
    iqr = q3 - q1
    lower_fence = q1 - 1.5 * iqr
    upper_fence = q3 + 1.5 * iqr
    # Strict inequalities: values exactly on a fence are discarded.
    keep = np.logical_and(values > lower_fence, values < upper_fence)
    return values[keep]
Amphion/utils/distribution.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from torch.distributions import Normal
11
+
12
+
13
def log_sum_exp(x):
    """Numerically stable log-sum-exp over the last dimension (TF ordering)."""
    last_dim = x.dim() - 1
    max_vals, _ = torch.max(x, dim=last_dim)
    max_keep, _ = torch.max(x, dim=last_dim, keepdim=True)
    # Subtract the max before exponentiating to prevent overflow.
    return max_vals + torch.log(torch.sum(torch.exp(x - max_keep), dim=last_dim))
20
+
21
+
22
def discretized_mix_logistic_loss(
    y_hat, y, num_classes=256, log_scale_min=-7.0, reduce=True
):
    """Discretized mixture of logistic distributions loss

    Note that it is assumed that input is scaled to [-1, 1].

    Args:
        y_hat (Tensor): Predicted output (B x C x T)
        y (Tensor): Target (B x T x 1).
        num_classes (int): Number of classes
        log_scale_min (float): Log scale minimum value
        reduce (bool): If True, the losses are averaged or summed for each
            minibatch.

    Returns
        Tensor: loss
    """
    assert y_hat.dim() == 3
    # C packs (mixture logits, means, log scales) for nr_mix components.
    assert y_hat.size(1) % 3 == 0
    nr_mix = y_hat.size(1) // 3

    # (B x T x C)
    y_hat = y_hat.transpose(1, 2)

    # unpack parameters. (B, T, num_mixtures) x 3
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix : 2 * nr_mix]
    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)

    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)

    centered_y = y - means
    inv_stdv = torch.exp(-log_scales)
    # Logistic CDF evaluated at the upper/lower edge of the quantization bin
    # of width 2/(num_classes-1).
    plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1))
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1))
    cdf_min = torch.sigmoid(min_in)

    # log probability for edge case of 0 (before scaling)
    # equivalent: torch.log(torch.sigmoid(plus_in))
    log_cdf_plus = plus_in - F.softplus(plus_in)

    # log probability for edge case of 255 (before scaling)
    # equivalent: (1 - torch.sigmoid(min_in)).log()
    log_one_minus_cdf_min = -F.softplus(min_in)

    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min

    mid_in = inv_stdv * centered_y
    # log probability in the center of the bin, to be used in extreme cases
    # (not actually used in our code)
    log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)

    # tf equivalent
    """
    log_probs = tf.where(x < -0.999, log_cdf_plus,
                         tf.where(x > 0.999, log_one_minus_cdf_min,
                                  tf.where(cdf_delta > 1e-5,
                                           tf.log(tf.maximum(cdf_delta, 1e-12)),
                                           log_pdf_mid - np.log(127.5))))
    """
    # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
    # for num_classes=65536 case? 1e-7? not sure..
    inner_inner_cond = (cdf_delta > 1e-5).float()

    inner_inner_out = inner_inner_cond * torch.log(
        torch.clamp(cdf_delta, min=1e-12)
    ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
    inner_cond = (y > 0.999).float()
    inner_out = (
        inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
    )
    cond = (y < -0.999).float()
    log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out

    # Weight per-component log-likelihoods by the mixture weights.
    log_probs = log_probs + F.log_softmax(logit_probs, -1)

    if reduce:
        return -torch.sum(log_sum_exp(log_probs))
    else:
        return -log_sum_exp(log_probs).unsqueeze(-1)
106
+
107
+
108
def to_one_hot(tensor, n, fill_with=1.0):
    """Return a one-hot encoding of ``tensor`` along a new last axis of size ``n``."""
    shape = tensor.size() + (n,)
    one_hot = torch.FloatTensor(*shape).zero_()
    if tensor.is_cuda:
        one_hot = one_hot.cuda()
    # Scatter fill_with into the class positions along the new last axis.
    one_hot.scatter_(tensor.dim(), tensor.unsqueeze(-1), fill_with)
    return one_hot
115
+
116
+
117
def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0, clamp_log_scale=False):
    """
    Sample from discretized mixture of logistic distributions

    Args:
        y (Tensor): B x C x T
        log_scale_min (float): Log scale minimum value
        clamp_log_scale (bool): If True, clamp the selected log scales at
            ``log_scale_min``.

    Returns:
        Tensor: sample in range of [-1, 1].
    """
    assert y.size(1) % 3 == 0
    nr_mix = y.size(1) // 3

    # B x T x C
    y = y.transpose(1, 2)
    logit_probs = y[:, :, :nr_mix]

    # sample mixture indicator from softmax
    # (Gumbel-max trick: argmax of logits plus Gumbel noise.)
    temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
    temp = logit_probs.data - torch.log(-torch.log(temp))
    _, argmax = temp.max(dim=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = to_one_hot(argmax, nr_mix)
    # select logistic parameters
    means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
    log_scales = torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1)
    if clamp_log_scale:
        log_scales = torch.clamp(log_scales, min=log_scale_min)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    # (inverse-CDF sampling: logit of a uniform sample, scaled and shifted).
    u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))

    x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0)

    return x
155
+
156
+
157
# we can easily define discretized version of the gaussian loss, however,
# use continuous version as same as the https://clarinet-demo.github.io/
def mix_gaussian_loss(y_hat, y, log_scale_min=-7.0, reduce=True):
    """Mixture of continuous gaussian distributions loss

    Note that it is assumed that input is scaled to [-1, 1].

    Args:
        y_hat (Tensor): Predicted output (B x C x T)
        y (Tensor): Target (B x T x 1).
        log_scale_min (float): Log scale minimum value
        reduce (bool): If True, the losses are averaged or summed for each
            minibatch.
    Returns
        Tensor: loss
    """
    assert y_hat.dim() == 3
    C = y_hat.size(1)
    # C == 2 means a single Gaussian (mean, log-scale) with no mixture logits.
    if C == 2:
        nr_mix = 1
    else:
        assert y_hat.size(1) % 3 == 0
        nr_mix = y_hat.size(1) // 3

    # (B x T x C)
    y_hat = y_hat.transpose(1, 2)

    # unpack parameters.
    if C == 2:
        # special case for C == 2, just for compatibility
        logit_probs = None
        means = y_hat[:, :, 0:1]
        log_scales = torch.clamp(y_hat[:, :, 1:2], min=log_scale_min)
    else:
        # (B, T, num_mixtures) x 3
        logit_probs = y_hat[:, :, :nr_mix]
        means = y_hat[:, :, nr_mix : 2 * nr_mix]
        log_scales = torch.clamp(
            y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min
        )

    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)

    centered_y = y - means
    dist = Normal(loc=0.0, scale=torch.exp(log_scales))
    # do we need to add a trick to avoid log(0)?
    log_probs = dist.log_prob(centered_y)

    if nr_mix > 1:
        # Weight per-component likelihoods by the mixture weights.
        log_probs = log_probs + F.log_softmax(logit_probs, -1)

    if reduce:
        if nr_mix == 1:
            return -torch.sum(log_probs)
        else:
            return -torch.sum(log_sum_exp(log_probs))
    else:
        if nr_mix == 1:
            return -log_probs
        else:
            return -log_sum_exp(log_probs).unsqueeze(-1)
219
+
220
+
221
def sample_from_mix_gaussian(y, log_scale_min=-7.0):
    """
    Sample from (discretized) mixture of gaussian distributions
    Args:
        y (Tensor): B x C x T
        log_scale_min (float): Log scale minimum value
    Returns:
        Tensor: sample in range of [-1, 1].
    """
    C = y.size(1)
    # C == 2 encodes a single Gaussian (mean, log-scale) without mixture logits.
    if C == 2:
        nr_mix = 1
    else:
        assert y.size(1) % 3 == 0
        nr_mix = y.size(1) // 3

    # B x T x C
    y = y.transpose(1, 2)

    if C == 2:
        logit_probs = None
    else:
        logit_probs = y[:, :, :nr_mix]

    if nr_mix > 1:
        # sample mixture indicator from softmax
        # (Gumbel-max trick: argmax over logits plus Gumbel noise.)
        temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
        temp = logit_probs.data - torch.log(-torch.log(temp))
        _, argmax = temp.max(dim=-1)

        # (B, T) -> (B, T, nr_mix)
        one_hot = to_one_hot(argmax, nr_mix)

        # Select means and log scales
        means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
        log_scales = torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1)
    else:
        if C == 2:
            means, log_scales = y[:, :, 0], y[:, :, 1]
        elif C == 3:
            means, log_scales = y[:, :, 1], y[:, :, 2]
        else:
            assert False, "shouldn't happen"

    # NOTE(review): log_scale_min is accepted but never applied here
    # (no clamping of log_scales) — confirm whether that is intended.
    scales = torch.exp(log_scales)
    dist = Normal(loc=means, scale=scales)
    x = dist.sample()

    x = torch.clamp(x, min=-1.0, max=1.0)
    return x
Amphion/utils/mel.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from librosa.filters import mel as librosa_mel_fn
8
+
9
+
10
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress magnitudes: log(max(x, clip_val) * C).

    With the default clip_val the minimum output is ln(1e-5) = -11.5129.
    """
    clipped = torch.clamp(x, min=clip_val)
    return torch.log(clipped * C)
13
+
14
+
15
def spectral_normalize_torch(magnitudes):
    """Apply dynamic-range (log) compression to a magnitude spectrogram."""
    return dynamic_range_compression_torch(magnitudes)
18
+
19
+
20
def extract_linear_features(y, cfg, center=False):
    """Compute the linear (magnitude) spectrogram of ``y``.

    Args:
        y: audio tensor, expected in [-1, 1], shape (B, T).
        cfg: preprocess config providing n_fft, hop_size, win_size.
        center: passed to torch.stft; when False the signal is reflect-padded
            here to keep frame alignment with hop_size.

    Returns:
        Magnitude spectrogram, batch dim squeezed: (n_fft//2+1, frames).
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global hann_window
    # NOTE(review): the window is rebuilt on every call; a `not in` check (as
    # intended in the mel functions) would let the cached one be reused.
    hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    # complex tensor as default, then use view_as_real for future pytorch compatibility
    spec = torch.stft(
        y,
        cfg.n_fft,
        hop_length=cfg.hop_size,
        win_length=cfg.win_size,
        window=hann_window[str(y.device)],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.view_as_real(spec)
    # Magnitude with a small epsilon for numerical stability.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
    spec = torch.squeeze(spec, 0)
    return spec
53
+
54
+
55
def mel_spectrogram_torch(y, cfg, center=False):
    """Compute a log-compressed mel spectrogram of ``y``.

    TODO: to merge this funtion with the extract_mel_features below

    Args:
        y: audio tensor, expected in [-1, 1], shape (B, T).
        cfg: preprocess config with sample_rate, n_fft, n_mel, fmin, fmax,
            hop_size, win_size.
        center: passed to torch.stft; if False the signal is reflect-padded
            here so frames stay aligned with hop_size.

    Returns:
        Mel spectrogram tensor of shape (B, n_mel, frames).
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    # Cache the mel filter bank per (fmax, device). The original check was
    # `cfg.fmax not in mel_basis`, which can never match the composite string
    # key actually stored, so the filter bank was rebuilt on every call.
    mel_key = str(cfg.fmax) + "_" + str(y.device)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(
            sr=cfg.sample_rate,
            n_fft=cfg.n_fft,
            n_mels=cfg.n_mel,
            fmin=cfg.fmin,
            fmax=cfg.fmax,
        )
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    # Cache the analysis window per device as well.
    if str(y.device) not in hann_window:
        hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        cfg.n_fft,
        hop_length=cfg.hop_size,
        win_length=cfg.win_size,
        window=hann_window[str(y.device)],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )

    spec = torch.view_as_real(spec)
    # Magnitude with a small epsilon for numerical stability.
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
105
+
106
+
107
+ mel_basis = {}
108
+ hann_window = {}
109
+
110
+
111
def extract_mel_features(
    y,
    cfg,
    center=False,
):
    """Extract mel features

    Args:
        y (tensor): audio data in tensor, expected in [-1, 1], shape (B, T)
        cfg (dict): configuration in cfg.preprocess (sample_rate, n_fft,
            n_mel, fmin, fmax, hop_size, win_size)
        center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False.

    Returns:
        tensor: a tensor containing the mel feature calculated based on STFT result
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    # Cache the mel filter bank per (fmax, device). The original check was
    # `cfg.fmax not in mel_basis`, which can never match the composite string
    # key actually stored, so the filter bank was rebuilt on every call.
    mel_key = str(cfg.fmax) + "_" + str(y.device)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(
            sr=cfg.sample_rate,
            n_fft=cfg.n_fft,
            n_mels=cfg.n_mel,
            fmin=cfg.fmin,
            fmax=cfg.fmax,
        )
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    # Cache the analysis window per device as well.
    if str(y.device) not in hann_window:
        hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    # complex tensor as default, then use view_as_real for future pytorch compatibility
    spec = torch.stft(
        y,
        cfg.n_fft,
        hop_length=cfg.hop_size,
        win_length=cfg.win_size,
        window=hann_window[str(y.device)],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.view_as_real(spec)
    # Magnitude with a small epsilon for numerical stability.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)
    return spec.squeeze(0)
171
+
172
+
173
def extract_mel_features_tts(
    y,
    cfg,
    center=False,
    taco=False,
    _stft=None,
):
    """Extract mel features

    Args:
        y (tensor): audio data in tensor
        cfg (dict): configuration in cfg.preprocess
        center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False.
        taco: use tacotron mel; requires ``_stft`` (an object providing
            ``mel_spectrogram``) to be supplied by the caller.

    Returns:
        tensor: a tensor containing the mel feature calculated based on STFT result
    """
    if not taco:
        if torch.min(y) < -1.0:
            print("min value is ", torch.min(y))
        if torch.max(y) > 1.0:
            print("max value is ", torch.max(y))

        global mel_basis, hann_window
        # NOTE(review): this checks the raw fmax against composite string
        # keys ("<fmax>_<device>"), so it is always True and the filter bank
        # and window are rebuilt on every call.
        if cfg.fmax not in mel_basis:
            mel = librosa_mel_fn(
                sr=cfg.sample_rate,
                n_fft=cfg.n_fft,
                n_mels=cfg.n_mel,
                fmin=cfg.fmin,
                fmax=cfg.fmax,
            )
            mel_basis[str(cfg.fmax) + "_" + str(y.device)] = (
                torch.from_numpy(mel).float().to(y.device)
            )
            hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device)

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
            mode="reflect",
        )
        y = y.squeeze(1)

        # complex tensor as default, then use view_as_real for future pytorch compatibility
        spec = torch.stft(
            y,
            cfg.n_fft,
            hop_length=cfg.hop_size,
            win_length=cfg.win_size,
            window=hann_window[str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        # Magnitude with a small epsilon for numerical stability.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec)
        spec = spectral_normalize_torch(spec)
    else:
        # Tacotron-style extraction delegates to the provided STFT module;
        # the energy output is discarded here.
        audio = torch.clip(y, -1, 1)
        audio = torch.autograd.Variable(audio, requires_grad=False)
        spec, energy = _stft.mel_spectrogram(audio)

    return spec.squeeze(0)
242
+
243
+
244
def amplitude_phase_spectrum(y, cfg):
    """Compute log-amplitude, phase, real and imaginary STFT components of ``y``.

    Returns:
        (log_amplitude, phase, rea, imag), each shaped (n_fft//2+1, frames)
        for a single utterance, or with a leading batch axis otherwise.
    """
    window = torch.hann_window(cfg.win_size).to(y.device)

    pad = int((cfg.n_fft - cfg.hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    stft_spec = torch.stft(
        y,
        cfg.n_fft,
        hop_length=cfg.hop_size,
        win_length=cfg.win_size,
        window=window,
        center=False,
        return_complex=True,
    )

    stft_spec = torch.view_as_real(stft_spec)
    # Drop the batch axis for single-utterance input.
    if stft_spec.size()[0] == 1:
        stft_spec = stft_spec.squeeze(0)

    if len(list(stft_spec.size())) == 4:
        rea = stft_spec[:, :, :, 0]  # [batch_size, n_fft//2+1, frames]
        imag = stft_spec[:, :, :, 1]  # [batch_size, n_fft//2+1, frames]
    else:
        rea = stft_spec[:, :, 0]  # [n_fft//2+1, frames]
        imag = stft_spec[:, :, 1]  # [n_fft//2+1, frames]

    # Log-magnitude with an epsilon; phase via atan2.
    log_amplitude = torch.log(
        torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
    )  # [n_fft//2+1, frames]
    phase = torch.atan2(imag, rea)  # [n_fft//2+1, frames]

    return log_amplitude, phase, rea, imag
Amphion/utils/prompt_preparer.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+
9
class PromptPreparer:
    """Mixin that builds NAR-stage embedding inputs plus their prompt prefixes.

    Expects the host class to provide: ``prefix_mode``,
    ``nar_audio_embeddings`` (one embedding layer per quantizer),
    ``num_quantizers``, ``audio_token_num``, and ``rng`` (used by
    prefix_mode 2).
    """

    def prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes):
        """Dispatch to the prefix strategy selected by ``self.prefix_mode``.

        Returns:
            (y_emb, prefix_len): summed audio embeddings and the number of
            leading positions acting as the prompt prefix.
        """
        if self.prefix_mode == 0:
            y_emb, prefix_len = self._handle_prefix_mode_0(y, codes, nar_stage)
        elif self.prefix_mode == 1:
            y_emb, prefix_len = self._handle_prefix_mode_1(y, y_lens, codes, nar_stage)
        elif self.prefix_mode in [2, 4]:
            y_emb, prefix_len = self._handle_prefix_mode_2_4(
                y, y_lens, codes, nar_stage, y_prompts_codes
            )
        else:
            raise ValueError("Invalid prefix mode")

        return y_emb, prefix_len

    def _handle_prefix_mode_0(self, y, codes, nar_stage):
        # No prompt prefix: sum the embeddings of all stages below nar_stage.
        prefix_len = 0
        y_emb = self.nar_audio_embeddings[0](y)
        for j in range(1, nar_stage):
            y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
        return y_emb, 0

    def _handle_prefix_mode_1(self, y, y_lens, codes, nar_stage):
        # Use the first 25%-50% of the shortest utterance in the batch as the
        # prompt, capped at 225 positions.
        int_low = (0.25 * y_lens.min()).type(torch.int64).item()
        prefix_len = torch.randint(int_low, int_low * 2, size=()).item()
        prefix_len = min(prefix_len, 225)

        y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
        y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
        for j in range(1, self.num_quantizers):
            # Prompts accumulate every quantizer level; the target part only
            # accumulates stages below the current NAR stage.
            y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
            if j < nar_stage:
                y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
        y_emb = torch.concat([y_prompts, y_emb], axis=1)
        return y_emb, prefix_len

    def _handle_prefix_mode_2_4(self, y, y_lens, codes, nar_stage, y_prompts_codes):
        if self.prefix_mode == 2:
            # Sample a random window per batch item as the prompt, and mask it
            # out of the target codes with the placeholder token. NOTE: this
            # mutates the caller's ``codes`` tensor in place.
            prefix_len = min(225, int(0.25 * y_lens.min().item()))

            y_prompts_codes = []
            for b in range(codes.shape[0]):
                start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                y_prompts_codes.append(
                    torch.clone(codes[b, start : start + prefix_len])
                )
                codes[b, start : start + prefix_len, nar_stage] = self.audio_token_num
            y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
        else:
            # Mode 4: prompt codes are provided by the caller.
            prefix_len = y_prompts_codes.shape[1]

        y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
        y_emb = self.nar_audio_embeddings[0](y)
        for j in range(1, self.num_quantizers):
            y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
            if j < nar_stage:
                y_emb += self.nar_audio_embeddings[j](codes[..., j])
        y_emb = torch.concat([y_prompts, y_emb], axis=1)

        return y_emb, prefix_len
__pycache__/model.cpython-310.pyc ADDED
Binary file (10.5 kB). View file
 
conf/default.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ is_left_ear: true
2
+ loss_function: MultiResolutionL1SpecLoss
3
+
4
+ trainer:
5
+ acoustic_hop_length: 256
6
+ acoustic_n_fft: 512
7
+ acoustic_sr: 16000
8
+ acoustic_win_length: 512
9
+ adam_beta1: 0.9
10
+ adam_beta2: 0.999
11
+ adam_epsilon: 1.0e-08
12
+ dataloader_drop_last: true
13
+ dataloader_num_workers: 8
14
+ dataloader_persistent_workers: false
15
+ dataloader_pin_memory: true
16
+ dataloader_prefetch_factor: 2
17
+ ddp_find_unused_parameters: true
18
+ debug: false
19
+ do_eval: false
20
+ do_predict: false
21
+ do_train: true
22
+ early_stopping_patience: 20
23
+ eval_batch_size: 1
24
+ eval_epoch_interval: 20
25
+ gradient_accumulation_steps: 1
26
+ greater_is_better: true
27
+ learning_rate: 0.001
28
+ lr_scheduler_type: constant_schedule_with_warmup
29
+ max_grad_norm: 5
30
+ max_steps: 0
31
+ metric_for_best_model: OVRL
32
+ num_train_epochs: 2000
33
+ optim: adamw
34
+ output_dir: exp/bmi__bmi__fix-ddp-sampler__warmup-2000__fixed-length__rerun
35
+ per_device_train_batch_size: 8
36
+ plot_lr: false
37
+ resume_from_checkpoint: "no"
38
+ save_epoch_interval: 20
39
+ save_total_limit: 100
40
+ seed: 20220815
41
+ warmup_ratio: 0.0
42
+ warmup_steps: 2000
43
+
44
+ predict_dataset:
45
+ enroll_folder: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/dev/speaker_adapt
46
+ normalize: true
47
+ scenes_file_fpath: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/metadata/scenes.dev.json
48
+ scenes_folder: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/dev/scenes/
49
+ scenes_listeners_file: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/metadata/scenes_listeners.dev.json
50
+ small_test: false
51
+
52
+ train_dataset:
53
+ enroll_folder: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/train/targets
54
+ enroll_len_limit: 10
55
+ limit: -1
56
+ normalize: false
57
+ sample_len_limit: 4
58
+ scenes_file_fpath: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/metadata/scenes.train.json
59
+ scenes_folder: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/train/scenes/
60
+ sr: 44100
61
+ use_all_enroll: false
62
+ use_additional_data: true
63
+
64
+ eval_dev_dataset:
65
+ scenes_listeners_file: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/metadata/scenes_listeners.dev.json
66
+ scenes_folder: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/dev/scenes/
67
+ enroll_folder: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/dev/speaker_adapt
68
+ limit: 500
69
+ normalize: false
70
+ scenes_file_fpath: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/metadata/scenes.dev.json
exp/bmi__fa-codec/2024_05_20--16_21_26.log ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ 05-20 16:21:26 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
2
+ 05-20 16:21:26 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
3
+ 05-20 16:21:26 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
4
+ 05-20 16:21:26 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
exp/bmi__fa-codec/2024_05_20--16_22_35.log ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 05-20 16:22:35 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
2
+ 05-20 16:22:35 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
3
+ 05-20 16:22:39 INFO [logging.py:61]:
4
+ Environment information:
5
+ - `Accelerate` version: 0.28.0
6
+ - Platform: Linux-6.1.0-18-amd64-x86_64-with-glibc2.36
7
+ - Python version: 3.10.13
8
+ - Numpy version: 1.26.4
9
+ - PyTorch version (GPU?): 2.2.2 (True)
10
+ - System RAM: 503.49 GB
11
+ - GPU Available: True
12
+ - GPU IDs: 8
13
+ - GPU type: NVIDIA RTX A6000
14
+ 05-20 16:22:39 INFO [logging.py:61]:
15
+ ===============================================================================================
16
+ Layer (type:depth-idx) Param #
17
+ ===============================================================================================
18
+ Model --
19
+ ├─Linear: 1-1 8,224
20
+ ├─FACodecEncoder: 1-2 --
21
+ │ └─Sequential: 2-1 --
22
+ │ │ └─Conv1d: 3-1 (288)
23
+ │ │ └─EncoderBlock: 3-2 (33,728)
24
+ │ │ └─EncoderBlock: 3-3 (165,760)
25
+ │ │ └─EncoderBlock: 3-4 (724,736)
26
+ │ │ └─EncoderBlock: 3-5 (2,891,264)
27
+ │ │ └─Activation1d: 3-6 (1,024)
28
+ │ │ └─Conv1d: 3-7 (393,728)
29
+ ├─FACodecDecoder: 1-3 --
30
+ │ └─ModuleList: 2-2 --
31
+ │ │ └─ResidualVQ: 3-8 (12,816)
32
+ │ │ └─ResidualVQ: 3-9 (25,632)
33
+ │ │ └─ResidualVQ: 3-10 (38,448)
34
+ │ └─Sequential: 2-3 --
35
+ │ │ └─Conv1d: 3-11 (1,837,056)
36
+ │ │ └─DecoderBlock: 3-12 (11,550,208)
37
+ │ │ └─DecoderBlock: 3-13 (2,891,520)
38
+ │ │ └─DecoderBlock: 3-14 (659,328)
39
+ │ │ └─DecoderBlock: 3-15 (133,056)
40
+ │ │ └─Activation1d: 3-16 (128)
41
+ │ │ └─Conv1d: 3-17 (450)
42
+ │ │ └─Tanh: 3-18 --
43
+ │ └─TransformerEncoder: 2-4 --
44
+ │ │ └─PositionalEncoding: 3-19 --
45
+ │ │ └─ModuleList: 3-20 (7,353,344)
46
+ │ │ └─LayerNorm: 3-21 (512)
47
+ │ └─Linear: 2-5 (131,584)
48
+ │ └─LayerNorm: 2-6 --
49
+ │ └─CNNLSTM: 2-7 --
50
+ │ │ └─Sequential: 3-22 (1,579,520)
51
+ │ │ └─ModuleList: 3-23 (514)
52
+ │ └─CNNLSTM: 2-8 --
53
+ │ │ └─Sequential: 3-24 (1,579,520)
54
+ │ │ └─ModuleList: 3-25 (1,285,771)
55
+ │ └─Sequential: 2-9 --
56
+ │ │ └─GradientReversal: 3-26 --
57
+ │ │ └─CNNLSTM: 3-27 (1,580,034)
58
+ │ └─Sequential: 2-10 --
59
+ │ │ └─GradientReversal: 3-28 --
60
+ │ │ └─CNNLSTM: 3-29 (2,865,291)
61
+ │ └─Sequential: 2-11 --
62
+ │ │ └─GradientReversal: 3-30 --
63
+ │ │ └─CNNLSTM: 3-31 (64,595,920)
64
+ ├─ERB: 1-4 --
65
+ │ └─Linear: 2-12 (3,728)
66
+ │ └─Linear: 2-13 (3,728)
67
+ ├─SubbandFeatureExtractor: 1-5 --
68
+ │ └─Unfold: 2-14 --
69
+ ├─Sequential: 1-6 --
70
+ │ └─GroupNorm: 2-15 36
71
+ │ └─Conv1d: 2-16 608
72
+ ├─ModuleList: 1-7 --
73
+ │ └─TriplePathRNN: 2-17 --
74
+ │ │ └─ResRNN: 3-32 54,368
75
+ │ │ └─ResRNN: 3-33 54,368
76
+ │ │ └─Linear: 3-34 2,080
77
+ │ └─TriplePathRNN: 2-18 --
78
+ │ │ └─ResRNN: 3-35 54,368
79
+ │ │ └─ResRNN: 3-36 54,368
80
+ │ │ └─Linear: 3-37 2,080
81
+ │ └─TriplePathRNN: 2-19 --
82
+ │ │ └─ResRNN: 3-38 54,368
83
+ │ │ └─ResRNN: 3-39 54,368
84
+ │ │ └─Linear: 3-40 2,080
85
+ ├─Sequential: 1-8 --
86
+ │ └─GroupNorm: 2-20 64
87
+ │ └─Conv1d: 2-21 2,112
88
+ │ └─Tanh: 2-22 --
89
+ │ └─Conv1d: 2-23 4,160
90
+ │ └─Tanh: 2-24 --
91
+ │ └─Conv1d: 2-25 260
92
+ ===============================================================================================
93
+ Total params: 102,686,548
94
+ Trainable params: 347,912
95
+ Non-trainable params: 102,338,636
96
+ ===============================================================================================
97
+ 05-20 16:22:40 INFO [logging.py:61]: warmup_steps=2000. warmup_ratio will be ignored.
98
+ 05-20 16:22:41 INFO [logging.py:61]: Will start from scratch (no checkpoint will be loaded).
99
+ 05-20 16:22:41 INFO [logging.py:61]: ***** Running training *****
100
+ 05-20 16:22:41 INFO [logging.py:61]: Num Epochs = 2,000
101
+ 05-20 16:22:41 INFO [logging.py:61]: `steps_per_epoch` = 125
102
+ 05-20 16:22:41 INFO [logging.py:61]: Instantaneous batch size per device = 16
103
+ 05-20 16:22:41 INFO [logging.py:61]: Gradient Accumulation steps = 1
104
+ 05-20 16:22:41 INFO [logging.py:61]: Total optimization steps = 250,000
105
+ 05-20 16:22:41 INFO [logging.py:61]: ========= Epoch 1 out of 2000 =========
106
+ 05-20 16:22:41 INFO [logging.py:61]: Begin training...
107
+ 05-20 16:23:34 INFO [logging.py:61]: Loss 'loss' on epoch 1: 0.28071990609169006
108
+ 05-20 16:23:34 INFO [logging.py:61]: Loss 'norm_before' on epoch 1: 0.033199019730091095
109
+ 05-20 16:23:34 INFO [logging.py:61]: ========= Epoch 2 out of 2000 =========
110
+ 05-20 16:23:34 INFO [logging.py:61]: Begin training...
exp/bmi__fa-codec/2024_05_20--16_24_01.log ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ 05-20 16:24:01 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
2
+ 05-20 16:24:01 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
3
+ 05-20 16:24:01 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
4
+ 05-20 16:24:01 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
exp/bmi__fa-codec/amplified_signals/S06021_L0014_HA-output.wav ADDED
Binary file (554 kB). View file
 
exp/bmi__fa-codec/amplified_signals/S06026_L0088_HA-output.wav ADDED
Binary file (449 kB). View file
 
exp/bmi__fa-codec/amplified_signals/S06031_L0096_HA-output.wav ADDED
Binary file (478 kB). View file
 
exp/bmi__fa-codec/amplified_signals/S06036_L0036_HA-output.wav ADDED
Binary file (416 kB). View file
 
exp/bmi__fa-codec/amplified_signals/S06066_L0042_HA-output.wav ADDED
Binary file (535 kB). View file
 
exp/bmi__fa-codec/amplified_signals/S06071_L0089_HA-output.wav ADDED
Binary file (528 kB). View file
 
exp/bmi__fa-codec/amplified_signals/S06086_L0072_HA-output.wav ADDED
Binary file (442 kB). View file
 
exp/bmi__fa-codec/amplified_signals/S06091_L0099_HA-output.wav ADDED
Binary file (623 kB). View file
 
exp/bmi__fa-codec/amplified_signals/S06101_L0042_HA-output.wav ADDED
Binary file (484 kB). View file
 
exp/bmi__fa-codec/amplified_signals/S06111_L0002_HA-output.wav ADDED
Binary file (481 kB). View file
 
exp/bmi__fa-codec/amplified_signals/S06116_L0092_HA-output.wav ADDED
Binary file (643 kB). View file
 
exp/bmi__fa-codec/amplified_signals/S06126_L0069_HA-output.wav ADDED
Binary file (522 kB). View file
 
exp/bmi__fa-codec/amplified_signals/S06146_L0017_HA-output.wav ADDED
Binary file (522 kB). View file