Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff.
- Amphion/models/base/base_trainer.py +348 -0
- Amphion/models/codec/ns3_codec/__pycache__/melspec.cpython-310.pyc +0 -0
- Amphion/models/codec/ns3_codec/__pycache__/transformer.cpython-310.pyc +0 -0
- Amphion/models/codec/ns3_codec/alias_free_torch/filter.py +96 -0
- Amphion/models/svc/diffusion/diffusion_trainer.py +102 -0
- Amphion/models/tta/autoencoder/__init__.py +0 -0
- Amphion/models/tta/autoencoder/autoencoder_loss.py +305 -0
- Amphion/models/tta/ldm/audioldm_inference.py +193 -0
- Amphion/models/tts/base/tts_inferece.py +278 -0
- Amphion/models/tts/fastspeech2/fs2_dataset.py +424 -0
- Amphion/models/tts/naturalspeech2/diffusion.py +124 -0
- Amphion/models/tts/vits/vits_inference.py +163 -0
- Amphion/models/vocoders/flow/flow_vocoder_trainer.py +0 -0
- Amphion/models/vocoders/gan/gan_vocoder_dataset.py +205 -0
- Amphion/modules/anti_aliasing/__init__.py +8 -0
- Amphion/modules/encoder/condition_encoder.py +244 -0
- Amphion/modules/general/__init__.py +3 -0
- Amphion/modules/monotonic_align/__init__.py +21 -0
- Amphion/modules/neural_source_filter/__init__.py +6 -0
- Amphion/modules/transformer/Layers.py +137 -0
- Amphion/modules/wenet_extractor/cif/predictor.py +274 -0
- Amphion/modules/wenet_extractor/paraformer/search/beam_search.py +479 -0
- Amphion/modules/wenet_extractor/paraformer/search/ctc_prefix_score.py +377 -0
- Amphion/modules/wenet_extractor/squeezeformer/positionwise_feed_forward.py +88 -0
- Amphion/modules/wenet_extractor/transformer/decoder_layer.py +140 -0
- Amphion/modules/wenet_extractor/transformer/subsampling.py +257 -0
- Amphion/modules/wenet_extractor/utils/__init__.py +0 -0
- Amphion/preprocessors/cdmusiceval.py +174 -0
- Amphion/utils/data_utils.py +588 -0
- Amphion/utils/distribution.py +270 -0
- Amphion/utils/mel.py +280 -0
- Amphion/utils/prompt_preparer.py +68 -0
- __pycache__/model.cpython-310.pyc +0 -0
- conf/default.yaml +70 -0
- exp/bmi__fa-codec/2024_05_20--16_21_26.log +4 -0
- exp/bmi__fa-codec/2024_05_20--16_22_35.log +110 -0
- exp/bmi__fa-codec/2024_05_20--16_24_01.log +4 -0
- exp/bmi__fa-codec/amplified_signals/S06021_L0014_HA-output.wav +0 -0
- exp/bmi__fa-codec/amplified_signals/S06026_L0088_HA-output.wav +0 -0
- exp/bmi__fa-codec/amplified_signals/S06031_L0096_HA-output.wav +0 -0
- exp/bmi__fa-codec/amplified_signals/S06036_L0036_HA-output.wav +0 -0
- exp/bmi__fa-codec/amplified_signals/S06066_L0042_HA-output.wav +0 -0
- exp/bmi__fa-codec/amplified_signals/S06071_L0089_HA-output.wav +0 -0
- exp/bmi__fa-codec/amplified_signals/S06086_L0072_HA-output.wav +0 -0
- exp/bmi__fa-codec/amplified_signals/S06091_L0099_HA-output.wav +0 -0
- exp/bmi__fa-codec/amplified_signals/S06101_L0042_HA-output.wav +0 -0
- exp/bmi__fa-codec/amplified_signals/S06111_L0002_HA-output.wav +0 -0
- exp/bmi__fa-codec/amplified_signals/S06116_L0092_HA-output.wav +0 -0
- exp/bmi__fa-codec/amplified_signals/S06126_L0069_HA-output.wav +0 -0
- exp/bmi__fa-codec/amplified_signals/S06146_L0017_HA-output.wav +0 -0
Amphion/models/base/base_trainer.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import collections
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import sys
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.distributed as dist
|
| 14 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 15 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
| 16 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 17 |
+
|
| 18 |
+
from models.base.base_sampler import BatchSampler
|
| 19 |
+
from utils.util import (
|
| 20 |
+
Logger,
|
| 21 |
+
remove_older_ckpt,
|
| 22 |
+
save_config,
|
| 23 |
+
set_all_random_seed,
|
| 24 |
+
ValueWindow,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class BaseTrainer(object):
    """Generic training-loop skeleton for Amphion models.

    Subclasses must implement the ``build_*`` factory methods (dataset, model,
    optimizer, scheduler, criterion) together with ``train_step`` /
    ``eval_step`` / ``write_summary`` / ``write_valid_summary`` /
    ``get_state_dict`` / ``load_model``. This base class owns the outer
    epoch/step loop, logging, checkpointing, and the (partially implemented)
    DDP setup.
    """

    def __init__(self, args, cfg):
        self.args = args
        self.log_dir = args.log_dir
        self.cfg = cfg

        self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        # Only rank 0 (or single-process runs) writes TensorBoard events / logs.
        if not cfg.train.ddp or args.local_rank == 0:
            self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
            self.logger = self.build_logger()
        # Sliding window over recent per-step wall-clock times ("s/step" in logs).
        self.time_window = ValueWindow(50)

        self.step = 0
        # Starts at -1; echo_log prints `self.epoch + 1`.
        self.epoch = -1
        self.max_epochs = self.cfg.train.epochs
        self.max_steps = self.cfg.train.max_steps

        # set random seed & init distributed training
        set_all_random_seed(self.cfg.train.random_seed)
        if cfg.train.ddp:
            dist.init_process_group(backend="nccl")

        # The singer LUT is only needed for singing-voice tasks, not the
        # TTA model types listed here.
        if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
            self.singers = self.build_singers_lut()

        # setup data_loader
        self.data_loader = self.build_data_loader()

        # setup model & enable distributed training
        self.model = self.build_model()
        print(self.model)

        if isinstance(self.model, dict):
            # Multi-module models (e.g. GAN generator/discriminator) are moved
            # to the GPU and wrapped per-module; "PQMF" is skipped for DDP
            # (presumably because it has no trainable parameters — confirm).
            for key, value in self.model.items():
                value.cuda(self.args.local_rank)
                if key == "PQMF":
                    continue
                if cfg.train.ddp:
                    self.model[key] = DistributedDataParallel(
                        value, device_ids=[self.args.local_rank]
                    )
        else:
            self.model.cuda(self.args.local_rank)
            if cfg.train.ddp:
                self.model = DistributedDataParallel(
                    self.model, device_ids=[self.args.local_rank]
                )

        # create criterion
        self.criterion = self.build_criterion()
        if isinstance(self.criterion, dict):
            for key, value in self.criterion.items():
                self.criterion[key].cuda(args.local_rank)
        else:
            self.criterion.cuda(self.args.local_rank)

        # optimizer
        self.optimizer = self.build_optimizer()
        self.scheduler = self.build_scheduler()

        # save config file
        self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")

    def build_logger(self):
        """Create a file-backed logger under the checkpoint directory."""
        log_file = os.path.join(self.checkpoint_dir, "train.log")
        logger = Logger(log_file, level=self.args.log_level).logger

        return logger

    def build_dataset(self):
        # Must return a (Dataset, Collator) pair of classes; see build_data_loader.
        raise NotImplementedError

    def build_data_loader(self):
        """Build train/valid DataLoaders over the concatenation of all configured datasets.

        Returns:
            dict with keys "train" and "valid" mapping to DataLoader objects.
        """
        Dataset, Collator = self.build_dataset()
        # build dataset instance for each dataset and combine them by ConcatDataset
        datasets_list = []
        for dataset in self.cfg.dataset:
            subdataset = Dataset(self.cfg, dataset, is_valid=False)
            datasets_list.append(subdataset)
        train_dataset = ConcatDataset(datasets_list)

        train_collate = Collator(self.cfg)
        # TODO: multi-GPU training
        if self.cfg.train.ddp:
            raise NotImplementedError("DDP is not supported yet.")

        # sampler will provide indices to batch_sampler, which will perform batching and yield batch indices
        batch_sampler = BatchSampler(
            cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list
        )

        # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
        train_loader = DataLoader(
            train_dataset,
            collate_fn=train_collate,
            num_workers=self.args.num_workers,
            batch_sampler=batch_sampler,
            pin_memory=False,
        )
        if not self.cfg.train.ddp or self.args.local_rank == 0:
            # NOTE(review): datasets_list is rebuilt here for validation; the
            # train-time list used by the sampler above is intentionally replaced.
            datasets_list = []
            for dataset in self.cfg.dataset:
                subdataset = Dataset(self.cfg, dataset, is_valid=True)
                datasets_list.append(subdataset)
            valid_dataset = ConcatDataset(datasets_list)
            valid_collate = Collator(self.cfg)
            batch_sampler = BatchSampler(
                cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list
            )
            valid_loader = DataLoader(
                valid_dataset,
                collate_fn=valid_collate,
                num_workers=1,
                batch_sampler=batch_sampler,
            )
        else:
            raise NotImplementedError("DDP is not supported yet.")
            # valid_loader = None
        data_loader = {"train": train_loader, "valid": valid_loader}
        return data_loader

    def build_singers_lut(self):
        """Merge per-dataset singer-to-id maps into one LUT and persist it.

        Singers not yet in the merged map get the next free integer id; the
        result is dumped to ``<log_dir>/<spk2id>`` and returned.
        """
        # combine singers
        if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)):
            singers = collections.OrderedDict()
        else:
            with open(
                os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r"
            ) as singer_file:
                singers = json.load(singer_file)
        singer_count = len(singers)
        for dataset in self.cfg.dataset:
            singer_lut_path = os.path.join(
                self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
            )
            # NOTE(review): the file handle shadows the path variable here.
            with open(singer_lut_path, "r") as singer_lut_path:
                singer_lut = json.load(singer_lut_path)
            for singer in singer_lut.keys():
                if singer not in singers:
                    singers[singer] = singer_count
                    singer_count += 1
        with open(
            os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w"
        ) as singer_file:
            json.dump(singers, singer_file, indent=4, ensure_ascii=False)
        print(
            "singers have been dumped to {}".format(
                os.path.join(self.log_dir, self.cfg.preprocess.spk2id)
            )
        )
        return singers

    def build_model(self):
        raise NotImplementedError()

    def build_optimizer(self):
        raise NotImplementedError

    def build_scheduler(self):
        raise NotImplementedError()

    def build_criterion(self):
        raise NotImplementedError

    def get_state_dict(self):
        raise NotImplementedError

    def save_config_file(self):
        """Persist the run configuration next to the checkpoints."""
        save_config(self.config_save_path, self.cfg)

    # TODO, save without module.
    def save_checkpoint(self, state_dict, saved_model_path):
        torch.save(state_dict, saved_model_path)

    def load_checkpoint(self):
        """Load the newest checkpoint listed in the ``checkpoint`` index file."""
        checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint")
        assert os.path.exists(checkpoint_path)
        # The last line of the index file names the most recent checkpoint.
        checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
        model_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
        assert os.path.exists(model_path)
        if not self.cfg.train.ddp or self.args.local_rank == 0:
            self.logger.info(f"Re(store) from {model_path}")
        checkpoint = torch.load(model_path, map_location="cpu")
        return checkpoint

    def load_model(self, checkpoint):
        raise NotImplementedError

    def restore(self):
        """Restore training state from the latest checkpoint."""
        checkpoint = self.load_checkpoint()
        self.load_model(checkpoint)

    def train_step(self, data):
        raise NotImplementedError(
            f"Need to implement function {sys._getframe().f_code.co_name} in "
            f"your sub-class of {self.__class__.__name__}. "
        )

    @torch.no_grad()
    def eval_step(self):
        raise NotImplementedError(
            f"Need to implement function {sys._getframe().f_code.co_name} in "
            f"your sub-class of {self.__class__.__name__}. "
        )

    def write_summary(self, losses, stats):
        raise NotImplementedError(
            f"Need to implement function {sys._getframe().f_code.co_name} in "
            f"your sub-class of {self.__class__.__name__}. "
        )

    def write_valid_summary(self, losses, stats):
        raise NotImplementedError(
            f"Need to implement function {sys._getframe().f_code.co_name} in "
            f"your sub-class of {self.__class__.__name__}. "
        )

    def echo_log(self, losses, mode="Training"):
        """Log a one-line summary of the current step's (possibly nested) losses."""
        message = [
            "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
                mode, self.epoch + 1, self.step, self.time_window.average
            )
        ]

        for key in sorted(losses.keys()):
            if isinstance(losses[key], dict):
                for k, v in losses[key].items():
                    message.append(
                        str(k).split("/")[-1] + "=" + str(round(float(v), 5))
                    )
            else:
                message.append(
                    str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
                )
        self.logger.info(", ".join(message))

    def eval_epoch(self):
        """Run one pass over the validation loader; return (avg losses, last stats).

        NOTE(review): ``i`` and ``valid_stats`` are only bound inside the loop,
        so an empty validation loader would raise NameError below — confirm the
        loader is guaranteed non-empty.
        """
        self.logger.info("Validation...")
        valid_losses = {}
        for i, batch_data in enumerate(self.data_loader["valid"]):
            for k, v in batch_data.items():
                if isinstance(v, torch.Tensor):
                    batch_data[k] = v.cuda()
            valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i)
            for key in valid_loss:
                if key not in valid_losses:
                    valid_losses[key] = 0
                valid_losses[key] += valid_loss[key]

        # Add mel and audio to the Tensorboard
        # Average loss
        for key in valid_losses:
            valid_losses[key] /= i + 1
        self.echo_log(valid_losses, "Valid")
        return valid_losses, valid_stats

    def train_epoch(self):
        """Run one pass over the training loader: step, log, checkpoint, validate."""
        for i, batch_data in enumerate(self.data_loader["train"]):
            start_time = time.time()
            # Put the data to cuda device
            for k, v in batch_data.items():
                if isinstance(v, torch.Tensor):
                    batch_data[k] = v.cuda(self.args.local_rank)

            # Training step
            train_losses, train_stats, total_loss = self.train_step(batch_data)
            self.time_window.append(time.time() - start_time)

            # Logging / checkpointing / validation are rank-0-only duties.
            if self.args.local_rank == 0 or not self.cfg.train.ddp:
                if self.step % self.args.stdout_interval == 0:
                    self.echo_log(train_losses, "Training")

                if self.step % self.cfg.train.save_summary_steps == 0:
                    self.logger.info(f"Save summary as step {self.step}")
                    self.write_summary(train_losses, train_stats)

                if (
                    self.step % self.cfg.train.save_checkpoints_steps == 0
                    and self.step != 0
                ):
                    saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
                        self.step, total_loss
                    )
                    saved_model_path = os.path.join(
                        self.checkpoint_dir, saved_model_name
                    )
                    saved_state_dict = self.get_state_dict()
                    self.save_checkpoint(saved_state_dict, saved_model_path)
                    self.save_config_file()
                    # keep max n models
                    remove_older_ckpt(
                        saved_model_name,
                        self.checkpoint_dir,
                        max_to_keep=self.cfg.train.keep_checkpoint_max,
                    )

                if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
                    # Switch to eval mode for validation, then back to train mode.
                    if isinstance(self.model, dict):
                        for key in self.model.keys():
                            self.model[key].eval()
                    else:
                        self.model.eval()
                    # Evaluate one epoch and get average loss
                    valid_losses, valid_stats = self.eval_epoch()
                    if isinstance(self.model, dict):
                        for key in self.model.keys():
                            self.model[key].train()
                    else:
                        self.model.train()
                    # Write validation losses to summary.
                    self.write_valid_summary(valid_losses, valid_stats)
            self.step += 1

    def train(self):
        """Outer loop: run epochs up to max_epochs, stopping once max_steps is exceeded."""
        for epoch in range(max(0, self.epoch), self.max_epochs):
            self.train_epoch()
            self.epoch += 1
            if self.step > self.max_steps:
                self.logger.info("Training finished!")
                break
|
Amphion/models/codec/ns3_codec/__pycache__/melspec.cpython-310.pyc
ADDED
|
Binary file (2.81 kB). View file
|
|
|
Amphion/models/codec/ns3_codec/__pycache__/transformer.cpython-310.pyc
ADDED
|
Binary file (5.42 kB). View file
|
|
|
Amphion/models/codec/ns3_codec/alias_free_torch/filter.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
if "sinc" in dir(torch):
    sinc = torch.sinc
else:
    # This code is adopted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        # Handle the removable singularity at x == 0 explicitly.
        at_zero = torch.tensor(1.0, device=x.device, dtype=x.dtype)
        return torch.where(x == 0, at_zero, torch.sin(math.pi * x) / math.pi / x)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
| 26 |
+
# https://adefossez.github.io/julius/julius/lowpass.html
|
| 27 |
+
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
def kaiser_sinc_filter1d(
    cutoff, half_width, kernel_size
):  # return filter [1,1,kernel_size]
    """Design a Kaiser-windowed sinc low-pass filter.

    Args:
        cutoff: Normalized cutoff frequency in [0, 0.5] (0.5 = Nyquist).
        half_width: Transition-band half width; controls the Kaiser beta.
        kernel_size: Number of filter taps (even or odd).

    Returns:
        Tensor of shape [1, 1, kernel_size] with unit sum (all zeros when
        ``cutoff == 0``).
    """
    even = kernel_size % 2 == 0
    half_size = kernel_size // 2

    # For kaiser window: estimate stop-band attenuation A, then beta via
    # Kaiser's empirical formula.
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.0:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.0:
        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    # Tap positions: half-sample-shifted for even sizes, centered for odd.
    if even:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        # Degenerate cutoff: an all-zero filter. Fixes two bugs in the
        # original: (a) normalizing by filter_.sum() here divided 0 by 0 and
        # produced NaNs; (b) zeros_like(time) inherited int64 dtype for odd
        # kernel sizes instead of float.
        filter_ = torch.zeros(kernel_size)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()

    # (Renamed from `filter`, which shadowed the builtin.)
    return filter_.view(1, 1, kernel_size)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class LowPassFilter1d(nn.Module):
    """Depthwise 1D anti-aliasing low-pass filter using a Kaiser-windowed sinc kernel."""

    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride: int = 1,
        padding: bool = True,
        padding_mode: str = "replicate",
        kernel_size: int = 12,
    ):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < -0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")
        self.kernel_size = kernel_size
        self.even = kernel_size % 2 == 0
        # Asymmetric padding for even kernels keeps the output length aligned.
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        # Non-trainable kernel, stored as a buffer so it moves with the module.
        self.register_buffer(
            "filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        )

    def forward(self, x):
        """Filter x of shape [B, C, T] channel-wise; returns [B, C, T'] (T' = T when padding and stride=1)."""
        num_channels = x.shape[1]

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        # groups == C applies the same 1-channel kernel to every channel independently.
        return F.conv1d(
            x,
            self.filter.expand(num_channels, -1, -1),
            stride=self.stride,
            groups=num_channels,
        )
|
Amphion/models/svc/diffusion/diffusion_trainer.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from diffusers import DDPMScheduler
|
| 8 |
+
|
| 9 |
+
from models.svc.base import SVCTrainer
|
| 10 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
| 11 |
+
from .diffusion_wrapper import DiffusionWrapper
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DiffusionTrainer(SVCTrainer):
    r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
    implements ``_build_model`` and ``_forward_step`` methods.
    """

    def __init__(self, args=None, cfg=None):
        # NOTE(review): the parent __init__ presumably invokes self._build_model(),
        # so condition_encoder/acoustic_mapper exist after this call — confirm.
        SVCTrainer.__init__(self, args, cfg)

        # Only for SVC tasks using diffusion
        self.noise_scheduler = DDPMScheduler(
            **self.cfg.model.diffusion.scheduler_settings,
        )
        # Number of diffusion steps T used to sample training timesteps.
        self.diffusion_timesteps = (
            self.cfg.model.diffusion.scheduler_settings.num_train_timesteps
        )

    ### Following are methods only for diffusion models ###
    def _build_model(self):
        r"""Build the model for training. This function is called in ``__init__`` function."""

        # TODO: sort out the config
        # Propagate the preprocessing F0 range into the condition-encoder config.
        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
        self.acoustic_mapper = DiffusionWrapper(self.cfg)
        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])

        # Log parameter counts (in millions) for both sub-modules.
        num_of_params_encoder = self.count_parameters(self.condition_encoder)
        num_of_params_am = self.count_parameters(self.acoustic_mapper)
        num_of_params = num_of_params_encoder + num_of_params_am
        log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format(
            num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6
        )
        self.logger.info(log)

        return model

    def count_parameters(self, model):
        """Return the total parameter count of `model` (a module, or a dict of modules)."""
        model_param = 0.0
        if isinstance(model, dict):
            for key, value in model.items():
                model_param += sum(p.numel() for p in model[key].parameters())
        else:
            model_param = sum(p.numel() for p in model.parameters())
        return model_param

    def _check_nan(self, batch, loss, y_pred, y_gt):
        # Dump the entire batch before delegating, to help locate the NaN source.
        if torch.any(torch.isnan(loss)):
            for k, v in batch.items():
                self.logger.info(k)
                self.logger.info(v)

        super()._check_nan(loss, y_pred, y_gt)

    def _forward_step(self, batch):
        r"""Forward step for training and inference. This function is called
        in ``_train_step`` & ``_test_step`` function.

        Standard DDPM epsilon-prediction objective: add scheduler noise to the
        mel at a random timestep and train the mapper to predict that noise.
        """
        device = self.accelerator.device

        if self.online_features_extraction:
            # On-the-fly features extraction
            batch = self._extract_svc_features(batch)

        # To debug
        # for k, v in batch.items():
        #     print(k, v.shape, v)
        # exit()

        mel_input = batch["mel"]
        noise = torch.randn_like(mel_input, device=device, dtype=torch.float32)
        batch_size = mel_input.size(0)
        # One uniformly-sampled timestep in [0, T) per example.
        timesteps = torch.randint(
            0,
            self.diffusion_timesteps,
            (batch_size,),
            device=device,
            dtype=torch.long,
        )

        noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps)
        conditioner = self.condition_encoder(batch)

        # Predict the added noise from the noised mel + conditioning.
        y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner)

        # Loss against the true noise, masked by the batch's validity mask.
        loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"])
        self._check_nan(batch, loss, y_pred, noise)

        return loss
|
Amphion/models/tta/autoencoder/__init__.py
ADDED
|
File without changes
|
Amphion/models/tta/autoencoder/autoencoder_loss.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import functools
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss: mean of relu(1 - real) and relu(1 + fake), halved."""
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def vanilla_d_loss(logits_real, logits_fake):
    """Non-saturating (vanilla) GAN discriminator loss via softplus."""
    real_term = torch.mean(F.softplus(-logits_real))
    fake_term = torch.mean(F.softplus(logits_fake))
    return 0.5 * (real_term + fake_term)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return `weight` once `global_step` reaches `threshold`; `value` before that.

    Used to delay the adversarial loss term until the generator has warmed up.
    """
    return value if global_step < threshold else weight
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ActNorm(nn.Module):
    """Per-channel affine (activation) normalization with data-dependent init.

    ``loc`` and ``scale`` are learned per-channel parameters applied as
    ``scale * (x + loc)``.  On the first forward pass in training mode they
    are initialized from the batch statistics so each channel starts with
    roughly zero mean and unit variance (Glow-style ActNorm — presumably;
    TODO confirm intended reference).
    """

    def __init__(
        self, num_features, logdet=False, affine=True, allow_reverse_init=False
    ):
        # Only the affine variant is implemented.
        assert affine
        super().__init__()
        self.logdet = logdet
        # Shape (1, C, 1, 1) so the transform broadcasts over batch and space.
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        # 0/1 flag persisted in checkpoints: whether data-dependent init ran.
        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        """Set loc/scale from the per-channel mean/std of ``input`` (N, C, H, W)."""
        with torch.no_grad():
            # Collapse batch and spatial dims: (C, N*H*W).
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            # Reshape the per-channel statistics back to (1, C, 1, 1).
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            # After init, scale * (x + loc) ~= (x - mean) / std.
            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        """Apply ``scale * (input + loc)``.

        Accepts (N, C) or (N, C, H, W) input; 2-D input is temporarily
        expanded to 4-D so the parameters broadcast.  With ``logdet=True``
        a per-sample log-determinant is returned as well.  ``reverse=True``
        delegates to :meth:`reverse`.
        """
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            # Promote (N, C) to (N, C, 1, 1).
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            # One-time data-dependent initialization on the first training batch.
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            # log|det J| = H * W * sum_c log|scale_c|; identical for every
            # sample, hence multiplied by a ones vector of batch size.
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        """Invert the transform: ``output / scale - loc``.

        Raises:
            RuntimeError: when called before initialization in training mode
                and ``allow_reverse_init`` is False.
        """
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def weights_init(m):
    """DCGAN-style initializer, intended for use via ``module.apply``.

    Modules whose class name contains "Conv" get N(0, 0.02) weights;
    those containing "BatchNorm" get N(1, 0.02) weights and zero bias.
    Every other module type is left untouched.
    """
    kind = type(m).__name__
    if "Conv" in kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator as used in Pix2Pix.

    See https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Build the discriminator.

        Args:
            input_nc (int): number of channels in the input images.
            ndf (int): number of filters in the first conv layer.
            n_layers (int): number of strided conv blocks.
            use_actnorm (bool): use ActNorm instead of BatchNorm2d.
        """
        super(NLayerDiscriminator, self).__init__()
        norm_layer = ActNorm if use_actnorm else nn.BatchNorm2d
        # BatchNorm2d already carries an affine shift, so a conv bias in
        # front of it would be redundant.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kernel, pad = 4, 1
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # Gradually widen the feature maps while halving the resolution,
        # capping the channel multiplier at 8.
        mult = 1
        for n in range(1, n_layers):
            prev, mult = mult, min(2**n, 8)
            layers += [
                nn.Conv2d(
                    ndf * prev,
                    ndf * mult,
                    kernel_size=kernel,
                    stride=2,
                    padding=pad,
                    bias=use_bias,
                ),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One extra block at stride 1 before the prediction head.
        prev, mult = mult, min(2**n_layers, 8)
        layers += [
            nn.Conv2d(
                ndf * prev,
                ndf * mult,
                kernel_size=kernel,
                stride=1,
                padding=pad,
                bias=use_bias,
            ),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
        ]

        # Collapse to a single-channel patch prediction map.
        layers.append(
            nn.Conv2d(ndf * mult, 1, kernel_size=kernel, stride=1, padding=pad)
        )
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward: return the patch logits map."""
        return self.main(input)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class AutoencoderLossWithDiscriminator(nn.Module):
    """VAE reconstruction/KL loss combined with a PatchGAN adversarial term.

    Mirrors the LPIPS-free variant of the taming-transformers / Stable
    Diffusion ``LPIPSWithDiscriminator`` loss (presumably — TODO confirm):
    an L1 reconstruction NLL with a learned global log-variance, a
    KL-divergence term from the posterior, and a hinge-GAN loss whose
    weight is balanced adaptively against the reconstruction gradient.
    """

    def __init__(self, cfg):
        """Build the loss module.

        Args:
            cfg: config namespace; reads kl_weight, logvar_init,
                disc_in_channels, disc_num_layers, use_actnorm, disc_start,
                disc_weight, disc_factor, min/max_adapt_d_weight.
        """
        super().__init__()
        self.cfg = cfg
        self.kl_weight = cfg.kl_weight
        # Learned scalar log-variance of the reconstruction likelihood.
        self.logvar = nn.Parameter(torch.ones(size=()) * cfg.logvar_init)

        self.discriminator = NLayerDiscriminator(
            input_nc=cfg.disc_in_channels,
            n_layers=cfg.disc_num_layers,
            use_actnorm=cfg.use_actnorm,
        ).apply(weights_init)

        # Step before which the adversarial term is zeroed (see adopt_weight).
        self.discriminator_iter_start = cfg.disc_start
        self.discriminator_weight = cfg.disc_weight
        self.disc_factor = cfg.disc_factor
        self.disc_loss = hinge_d_loss

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
        """Balance the GAN term against the NLL term.

        Returns a detached scalar proportional to
        ||grad(nll)|| / ||grad(g_loss)|| at ``last_layer``, clamped to the
        configured range and scaled by ``disc_weight``.  Raises RuntimeError
        (from autograd) when grads are unavailable, e.g. in eval mode.
        """
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(
            d_weight, self.cfg.min_adapt_d_weight, self.cfg.max_adapt_d_weight
        ).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        inputs,
        reconstructions,
        posteriors,
        optimizer_idx,
        global_step,
        last_layer,
        split="train",
        weights=None,
    ):
        """Compute either the generator loss (optimizer_idx == 0) or the
        discriminator loss (optimizer_idx == 1).

        Args:
            inputs: ground-truth tensor.
            reconstructions: decoder output, same shape as ``inputs``.
            posteriors: distribution object exposing ``.kl()``
                (presumably a DiagonalGaussianDistribution — TODO confirm).
            optimizer_idx: 0 → generator/autoencoder step, 1 → discriminator.
            global_step: current training step, gates the GAN term.
            last_layer: parameter used for adaptive weight balancing.
            split: unused here beyond the signature (kept for API parity).
            weights: optional per-element weights for the NLL term.

        Returns:
            dict of scalar losses/diagnostics for the selected branch.
        """
        rec_loss = torch.abs(
            inputs.contiguous() - reconstructions.contiguous()
        )  # l1 loss
        # Gaussian NLL with a single learned log-variance.
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        # weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        weighted_nll_loss = torch.mean(weighted_nll_loss)
        # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        nll_loss = torch.mean(nll_loss)
        kl_loss = posteriors.kl()
        # Sum over latent dims, mean over the batch.
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        # ? kl_loss = torch.mean(kl_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # Generator step: fool the discriminator on reconstructions.
            logits_fake = self.discriminator(reconstructions.contiguous())
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                except RuntimeError:
                    # Autograd fails without a graph; expected only in eval.
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            # Zero out the GAN term until discriminator_iter_start.
            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )

            total_loss = (
                weighted_nll_loss
                + self.kl_weight * kl_loss
                + d_weight * disc_factor * g_loss
            )

            return {
                "loss": total_loss,
                "kl_loss": kl_loss,
                "rec_loss": rec_loss.mean(),
                "nll_loss": nll_loss,
                "g_loss": g_loss,
                "d_weight": d_weight,
                "disc_factor": torch.tensor(disc_factor),
            }

        if optimizer_idx == 1:
            # Discriminator step: inputs/reconstructions are detached so no
            # gradients flow back into the autoencoder.
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            return {
                "d_loss": d_loss,
                "logits_real": logits_real.mean(),
                "logits_fake": logits_fake.mean(),
            }
|
Amphion/models/tta/ldm/audioldm_inference.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from collections import OrderedDict
|
| 13 |
+
import json
|
| 14 |
+
|
| 15 |
+
from models.tta.autoencoder.autoencoder import AutoencoderKL
|
| 16 |
+
from models.tta.ldm.inference_utils.vocoder import Generator
|
| 17 |
+
from models.tta.ldm.audioldm import AudioLDM
|
| 18 |
+
from transformers import T5EncoderModel, AutoTokenizer
|
| 19 |
+
from diffusers import PNDMScheduler
|
| 20 |
+
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
from scipy.io.wavfile import write
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AttrDict(dict):
    """Dictionary whose keys are also accessible as attributes."""

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Alias the attribute namespace to the mapping itself so item
        # access and attribute access share the same storage.
        self.__dict__ = self
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class AudioLDMInference:
    """Text-to-audio inference pipeline for AudioLDM.

    Encodes a text prompt with T5, runs classifier-free-guided PNDM
    sampling in the AutoencoderKL latent space, decodes to a mel
    spectrogram, and vocodes to a 16 kHz waveform.  All models run on
    CUDA device ``args.local_rank``.
    """

    def __init__(self, args, cfg):
        """Load every sub-model and prepare the output directories.

        Args:
            args: CLI namespace; reads output_dir, local_rank, text,
                num_steps, guidance_scale, checkpoint_path,
                vocoder_config_path, vocoder_path.
            cfg: experiment config; reads model.autoencoderkl,
                model.autoencoder_path, model.audioldm.
        """
        self.cfg = cfg
        self.args = args

        self.build_autoencoderkl()
        self.build_textencoder()

        self.model = self.build_model()
        self.load_state_dict()

        self.build_vocoder()

        # Mel images and wav files are written to separate subfolders.
        self.out_path = self.args.output_dir
        self.out_mel_path = os.path.join(self.out_path, "mel")
        self.out_wav_path = os.path.join(self.out_path, "wav")
        os.makedirs(self.out_mel_path, exist_ok=True)
        os.makedirs(self.out_wav_path, exist_ok=True)

    def build_autoencoderkl(self):
        """Load the frozen AutoencoderKL used to decode latents to mels."""
        self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl)
        self.autoencoder_path = self.cfg.model.autoencoder_path
        checkpoint = torch.load(self.autoencoder_path, map_location="cpu")
        self.autoencoderkl.load_state_dict(checkpoint["model"])
        self.autoencoderkl.cuda(self.args.local_rank)
        self.autoencoderkl.requires_grad_(requires_grad=False)
        self.autoencoderkl.eval()

    def build_textencoder(self):
        """Load the frozen T5-base tokenizer and text encoder."""
        self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
        self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
        self.text_encoder.cuda(self.args.local_rank)
        self.text_encoder.requires_grad_(requires_grad=False)
        self.text_encoder.eval()

    def build_vocoder(self):
        """Load the HiFi-GAN-style vocoder from its JSON config + checkpoint."""
        config_file = os.path.join(self.args.vocoder_config_path)
        with open(config_file) as f:
            data = f.read()
        json_config = json.loads(data)
        h = AttrDict(json_config)
        self.vocoder = Generator(h).to(self.args.local_rank)
        checkpoint_dict = torch.load(
            self.args.vocoder_path, map_location=self.args.local_rank
        )
        self.vocoder.load_state_dict(checkpoint_dict["generator"])

    def build_model(self):
        """Instantiate the AudioLDM denoiser and return it."""
        self.model = AudioLDM(self.cfg.model.audioldm)
        return self.model

    def load_state_dict(self):
        """Load the denoiser weights from ``args.checkpoint_path`` onto CUDA."""
        self.checkpoint_path = self.args.checkpoint_path
        checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
        self.model.load_state_dict(checkpoint["model"])
        self.model.cuda(self.args.local_rank)

    def get_text_embedding(self):
        """Return stacked [unconditional; conditional] T5 embeddings.

        Batch dim is 2 so classifier-free guidance can run both branches
        in one forward pass.
        """
        text = self.args.text

        prompt = [text]

        text_input = self.tokenizer(
            prompt,
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            padding="do_not_pad",
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(
            text_input.input_ids.to(self.args.local_rank)
        )[0]

        # The empty prompt is padded to the conditional prompt's length so
        # the two embeddings can be concatenated along the batch dim.
        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = self.text_encoder(
            uncond_input.input_ids.to(self.args.local_rank)
        )[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def inference(self):
        """Sample latents with PNDM + guidance, decode, vocode, write files.

        Writes ``<text>.png`` (mel image) and ``<text>.wav`` (16 kHz audio)
        under the configured output directories.
        """
        text_embeddings = self.get_text_embedding()
        print(text_embeddings.shape)

        num_steps = self.args.num_steps
        guidance_scale = self.args.guidance_scale

        noise_scheduler = PNDMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            skip_prk_steps=True,
            set_alpha_to_one=False,
            steps_offset=1,
            prediction_type="epsilon",
        )

        noise_scheduler.set_timesteps(num_steps)

        # Initial Gaussian latent; 80/624 are downsampled by the KL
        # autoencoder's total stride (presumably mel bins x frames —
        # TODO confirm against the preprocessing config).
        latents = torch.randn(
            (
                1,
                self.cfg.model.autoencoderkl.z_channels,
                80 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)),
                624 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)),
            )
        ).to(self.args.local_rank)

        self.model.eval()
        for t in tqdm(noise_scheduler.timesteps):
            t = t.to(self.args.local_rank)

            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)

            latent_model_input = noise_scheduler.scale_model_input(
                latent_model_input, timestep=t
            )
            # print(latent_model_input.shape)

            # predict the noise residual
            with torch.no_grad():
                noise_pred = self.model(
                    latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings
                )

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
            # print(latents.shape)

        latents_out = latents
        print(latents_out.shape)

        # Decode latents back to a mel spectrogram.
        with torch.no_grad():
            mel_out = self.autoencoderkl.decode(latents_out)
        print(mel_out.shape)

        melspec = mel_out[0, 0].cpu().detach().numpy()
        plt.imsave(os.path.join(self.out_mel_path, self.args.text + ".png"), melspec)

        self.vocoder.eval()
        self.vocoder.remove_weight_norm()
        with torch.no_grad():
            melspec = np.expand_dims(melspec, 0)
            melspec = torch.FloatTensor(melspec).to(self.args.local_rank)

            y = self.vocoder(melspec)
            audio = y.squeeze()
            # Scale [-1, 1] float audio to 16-bit PCM.
            audio = audio * 32768.0
            audio = audio.cpu().numpy().astype("int16")

        write(os.path.join(self.out_wav_path, self.args.text + ".wav"), 16000, audio)
|
Amphion/models/tts/base/tts_inferece.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
import time
|
| 9 |
+
import accelerate
|
| 10 |
+
import random
|
| 11 |
+
import numpy as np
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from accelerate.logging import get_logger
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
from safetensors.torch import load_file
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
from abc import abstractmethod
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from utils.io import save_audio
|
| 21 |
+
from utils.util import load_config
|
| 22 |
+
from models.vocoders.vocoder_inference import synthesis
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class TTSInference(object):
    """Base class for TTS inference, driven by HuggingFace Accelerate.

    Supports two modes (``args.mode``): "single" synthesizes one utterance,
    "batch" runs over a test dataset.  Subclasses implement model/dataset
    construction and per-batch inference.
    """

    def __init__(self, args=None, cfg=None):
        """Set up accelerator, logging, seeding, data, model, and checkpoint.

        Args:
            args: CLI namespace; reads mode, acoustics_dir, checkpoint_path,
                vocoder_dir, output_dir, log_level.
            cfg: experiment config; reads train.random_seed,
                train.batch_size, preprocess.sample_rate.
        """
        super().__init__()

        start = time.monotonic_ns()
        self.args = args
        self.cfg = cfg
        self.infer_type = args.mode

        # get exp_dir
        if self.args.acoustics_dir is not None:
            self.exp_dir = self.args.acoustics_dir
        elif self.args.checkpoint_path is not None:
            # checkpoint_path points inside <exp_dir>/checkpoint/<step>/...
            self.exp_dir = os.path.dirname(os.path.dirname(self.args.checkpoint_path))

        # Init accelerator
        self.accelerator = accelerate.Accelerator()
        self.accelerator.wait_for_everyone()
        self.device = self.accelerator.device

        # Get logger
        with self.accelerator.main_process_first():
            self.logger = get_logger("inference", log_level=args.log_level)

        # Log some info
        self.logger.info("=" * 56)
        self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
        self.logger.info("=" * 56)
        self.logger.info("\n")

        self.acoustic_model_dir = args.acoustics_dir
        self.logger.debug(f"Acoustic model dir: {args.acoustics_dir}")

        if args.vocoder_dir is not None:
            self.vocoder_dir = args.vocoder_dir
            self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")

        os.makedirs(args.output_dir, exist_ok=True)

        # Set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            self.logger.debug(
                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
            )
            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # Setup data loader
        if self.infer_type == "batch":
            with self.accelerator.main_process_first():
                self.logger.info("Building dataset...")
                start = time.monotonic_ns()
                self.test_dataloader = self._build_test_dataloader()
                end = time.monotonic_ns()
                self.logger.info(
                    f"Building dataset done in {(end - start) / 1e6:.2f}ms"
                )

        # Build model
        with self.accelerator.main_process_first():
            self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.model = self._build_model()
            end = time.monotonic_ns()
            self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")

        # Init with accelerate
        self.logger.info("Initializing accelerate...")
        start = time.monotonic_ns()
        self.accelerator = accelerate.Accelerator()
        self.model = self.accelerator.prepare(self.model)
        if self.infer_type == "batch":
            self.test_dataloader = self.accelerator.prepare(self.test_dataloader)
        end = time.monotonic_ns()
        self.accelerator.wait_for_everyone()
        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")

        with self.accelerator.main_process_first():
            self.logger.info("Loading checkpoint...")
            start = time.monotonic_ns()
            if args.acoustics_dir is not None:
                self._load_model(
                    checkpoint_dir=os.path.join(args.acoustics_dir, "checkpoint")
                )
            elif args.checkpoint_path is not None:
                self._load_model(checkpoint_path=args.checkpoint_path)
            else:
                print("Either checkpoint dir or checkpoint path should be provided.")

            end = time.monotonic_ns()
            self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")

        self.model.eval()
        self.accelerator.wait_for_everyone()

    def _build_test_dataset(self):
        # To be overridden: return (dataset_class, collate_class).
        pass

    def _build_model(self):
        # To be overridden: return the acoustic model instance.
        pass

    # TODO: LEGACY CODE
    def _build_test_dataloader(self):
        """Build the test DataLoader from the subclass-provided dataset/collate."""
        datasets, collate = self._build_test_dataset()
        self.test_dataset = datasets(self.args, self.cfg)
        self.test_collate = collate(self.cfg)
        # Never exceed the number of test utterances.
        self.test_batch_size = min(
            self.cfg.train.batch_size, len(self.test_dataset.metadata)
        )
        test_dataloader = DataLoader(
            self.test_dataset,
            collate_fn=self.test_collate,
            num_workers=1,
            batch_size=self.test_batch_size,
            shuffle=False,
        )
        return test_dataloader

    def _load_model(
        self,
        checkpoint_dir: str = None,
        checkpoint_path: str = None,
        old_mode: bool = False,
    ):
        r"""Load model from checkpoint. If checkpoint_path is None, it will
        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
        None, it will load the checkpoint specified by checkpoint_path. **Only use this
        method after** ``accelerator.prepare()``.
        """

        if checkpoint_path is None:
            assert checkpoint_dir is not None
            # Load the latest accelerator state dicts
            ls = [
                str(i) for i in Path(checkpoint_dir).glob("*") if not "audio" in str(i)
            ]
            # Sort by the step number embedded in the directory name
            # (presumably epoch-xxxx_step-xxxx_... — TODO confirm naming).
            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
            checkpoint_path = ls[0]

        # Older accelerate versions cannot load safetensors checkpoints
        # through load_state, so fall back to a direct state-dict load.
        if (
            Path(os.path.join(checkpoint_path, "model.safetensors")).exists()
            and accelerate.__version__ < "0.25"
        ):
            self.model.load_state_dict(
                load_file(os.path.join(checkpoint_path, "model.safetensors")),
                strict=False,
            )
        else:
            self.accelerator.load_state(str(checkpoint_path))
        return str(checkpoint_path)

    def inference(self):
        """Dispatch to single-utterance or batch inference and save wavs."""
        if self.infer_type == "single":
            out_dir = os.path.join(self.args.output_dir, "single")
            os.makedirs(out_dir, exist_ok=True)

            pred_audio = self.inference_for_single_utterance()
            save_path = os.path.join(out_dir, "test_pred.wav")
            save_audio(save_path, pred_audio, self.cfg.preprocess.sample_rate)

        elif self.infer_type == "batch":
            out_dir = os.path.join(self.args.output_dir, "batch")
            os.makedirs(out_dir, exist_ok=True)

            # NOTE(review): inference_for_batches has no return statement
            # below, so pred_audio_list may be None here — verify.
            pred_audio_list = self.inference_for_batches()
            for it, wav in zip(self.test_dataset.metadata, pred_audio_list):
                uid = it["Uid"]
                save_audio(
                    os.path.join(out_dir, f"{uid}.wav"),
                    wav.numpy(),
                    self.cfg.preprocess.sample_rate,
                    add_silence=True,
                    turn_up=True,
                )
                # Remove the intermediate mel tensor once audio is saved.
                tmp_file = os.path.join(out_dir, f"{uid}.pt")
                if os.path.exists(tmp_file):
                    os.remove(tmp_file)
            print("Saved to: ", out_dir)

    @torch.inference_mode()
    def inference_for_batches(self):
        """Predict mels for every batch, save them, then vocode to audio."""
        y_pred = []
        for i, batch in tqdm(enumerate(self.test_dataloader)):
            y_pred, mel_lens, _ = self._inference_each_batch(batch)
            # Split the padded batch back into per-utterance tensors.
            y_ls = y_pred.chunk(self.test_batch_size)
            tgt_ls = mel_lens.chunk(self.test_batch_size)
            j = 0
            for it, l in zip(y_ls, tgt_ls):
                l = l.item()
                # Trim padding frames before saving.
                it = it.squeeze(0)[:l].detach().cpu()

                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
                torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
                j += 1

        # Vocode all saved mel predictions in one pass.
        vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
        res = synthesis(
            cfg=vocoder_cfg,
            vocoder_weight_file=vocoder_ckpt,
            n_samples=None,
            pred=[
                torch.load(
                    os.path.join(self.args.output_dir, "{}.pt".format(item["Uid"]))
                ).numpy()
                for item in self.test_dataset.metadata
            ],
        )
        for it, wav in zip(self.test_dataset.metadata, res):
            uid = it["Uid"]
            save_audio(
                os.path.join(self.args.output_dir, f"{uid}.wav"),
                wav.numpy(),
                22050,
                add_silence=True,
                turn_up=True,
            )

    @abstractmethod
    @torch.inference_mode()
    def _inference_each_batch(self, batch_data):
        # To be overridden: return (mel_pred, mel_lens, extra) for a batch.
        pass

    def inference_for_single_utterance(self, text):
        # To be overridden: return predicted audio for one utterance.
        pass

    def synthesis_by_vocoder(self, pred):
        """Vocode predicted mels using the pre-configured vocoder.

        Assumes ``self.vocoder_cfg`` and ``self.checkpoint_dir_vocoder``
        are set by a subclass — TODO confirm; they are not set here.
        """
        audios_pred = synthesis(
            self.vocoder_cfg,
            self.checkpoint_dir_vocoder,
            len(pred),
            pred,
        )

        return audios_pred

    @staticmethod
    def _parse_vocoder(vocoder_dir):
        r"""Parse vocoder config"""
        vocoder_dir = os.path.abspath(vocoder_dir)
        # Pick the newest checkpoint (files are named by step number).
        ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
        ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
        ckpt_path = str(ckpt_list[0])
        vocoder_cfg = load_config(
            os.path.join(vocoder_dir, "args.json"), lowercase=True
        )
        return vocoder_cfg, ckpt_path

    def _set_random_seed(self, seed):
        """Set random seed for all possible random modules."""
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)
|
Amphion/models/tts/fastspeech2/fs2_dataset.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import random
|
| 7 |
+
import torch
|
| 8 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 9 |
+
from utils.data_utils import *
|
| 10 |
+
from models.base.base_dataset import (
|
| 11 |
+
BaseOfflineCollator,
|
| 12 |
+
BaseOfflineDataset,
|
| 13 |
+
BaseTestDataset,
|
| 14 |
+
BaseTestCollator,
|
| 15 |
+
)
|
| 16 |
+
from text import text_to_sequence
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class FS2Dataset(BaseOfflineDataset):
    """Offline training dataset for FastSpeech2.

    On top of the features prepared by ``BaseOfflineDataset`` (mel, etc.),
    this dataset loads per-utterance phone durations, pitch/energy contours,
    phone label files, and (optionally) speaker ids.
    """

    def __init__(self, cfg, dataset, is_valid=False):
        BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
        self.batch_size = cfg.train.batch_size
        cfg = cfg.preprocess

        # Map "<dataset>_<uid>" -> path of the phone-duration .npy file.
        self.utt2duration_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            self.utt2duration_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.duration_dir,
                uid + ".npy",
            )
        self.utt2dur = self.read_duration()

        if cfg.use_frame_energy:
            self.frame_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                # Fix: was `self.preprocess.utt2spk`, but the dataset has no
                # `preprocess` attribute; the sibling branches use
                # `self.utt2spk`.
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_energy:
            self.phone_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        if cfg.use_frame_pitch:
            self.frame_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.pitch_dir,
                # Fix: was `use_log_scale=cfg.energy_extract_mode` (a mode
                # string, always truthy); mirror the phone-pitch branch and
                # pass the boolean pitch log-scale flag.
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_pitch:
            self.phone_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_pitch_dir,
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        # Map "<dataset>_<uid>" -> path of the phone label (.txt) file.
        self.utt2lab_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            self.utt2lab_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.lab_dir,
                uid + ".txt",
            )

        # Optional speaker map for multi-speaker corpora.
        self.speaker_map = {}
        spk2id_path = os.path.join(cfg.processed_dir, "spk2id.json")
        if os.path.exists(spk2id_path):
            # Fix: previously `open(os.path.exists(path))` was called, which
            # passes a bool to open() and raises a TypeError.
            with open(spk2id_path, "r") as f:
                self.speaker_map = json.load(f)

        # Drop utterances whose mel or duration files are missing.
        self.metadata = self.check_metadata()

    def __getitem__(self, index):
        """Return the base features plus duration/pitch/energy/text for one utterance."""
        single_feature = BaseOfflineDataset.__getitem__(self, index)

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        duration = self.utt2dur[utt]

        # Phone label text: the first line of the label file.
        with open(self.utt2lab_path[utt], "r") as f:
            phones = f.readlines()[0].strip()
        # todo: add cleaner(chenxi)
        phones_ids = np.array(text_to_sequence(phones, ["english_cleaners"]))
        text_len = len(phones_ids)

        # NOTE(review): the config is expected to enable exactly one pitch
        # source and one energy source; otherwise `pitch`/`energy` would be
        # unbound here — confirm against the preprocess configs.
        if self.cfg.preprocess.use_frame_pitch:
            pitch = self.frame_utt2pitch[utt]
        elif self.cfg.preprocess.use_phone_pitch:
            pitch = self.phone_utt2pitch[utt]

        if self.cfg.preprocess.use_frame_energy:
            energy = self.frame_utt2energy[utt]
        elif self.cfg.preprocess.use_phone_energy:
            energy = self.phone_utt2energy[utt]

        # Speaker id (0 for single-speaker corpora).
        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        single_feature.update(
            {
                "durations": duration,
                "texts": phones_ids,
                "spk_id": speaker_id,
                "text_len": text_len,
                "pitch": pitch,
                "energy": energy,
                "uid": uid,
            }
        )
        return self.clip_if_too_long(single_feature)

    def read_duration(self):
        """Load phone durations for every utterance with complete features.

        Skips utterances whose mel or duration file is missing, and asserts
        that the durations sum to the mel frame count.
        """
        utt2dur = {}
        for index in range(len(self.metadata)):
            utt_info = self.metadata[index]
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            if not os.path.exists(self.utt2mel_path[utt]) or not os.path.exists(
                self.utt2duration_path[utt]
            ):
                continue

            mel = np.load(self.utt2mel_path[utt]).transpose(1, 0)
            duration = np.load(self.utt2duration_path[utt])
            assert mel.shape[0] == sum(
                duration
            ), f"{utt}: mismatch length between mel {mel.shape[0]} and sum(duration) {sum(duration)}"
            utt2dur[utt] = duration
        return utt2dur

    def __len__(self):
        return len(self.metadata)

    def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
        """Pick a random [start, end) window of ``max_seq_len`` frames.

        ending_ts: to avoid invalid whisper features for over 30s audios
        2812 = 30 * 24000 // 256
        """
        ts = max(feature_seq_len - max_seq_len, 0)
        ts = min(ts, ending_ts - max_seq_len)

        start = random.randint(0, ts)
        end = start + max_seq_len
        return start, end

    def clip_if_too_long(self, sample, max_seq_len=1000):
        """Clip frame-level features to at most ``max_seq_len`` frames.

        sample :
            {
                'spk_id': (1,),
                'target_len': int
                'mel': (seq_len, dim),
                'frame_pitch': (seq_len,)
                'frame_energy': (seq_len,)
                'content_vector_feat': (seq_len, dim)
            }
        """
        if sample["target_len"] <= max_seq_len:
            return sample

        start, end = self.random_select(sample["target_len"], max_seq_len)
        sample["target_len"] = end - start

        # Fix: also skip the scalar/string entries "text_len" and "uid" —
        # slicing an int raises TypeError and slicing the uid string silently
        # corrupts it. NOTE(review): phone-level entries such as "texts" and
        # "durations" are still sliced by frame indices as in the original
        # code — confirm whether over-long samples are expected here at all.
        for k in sample.keys():
            if k not in ["spk_id", "target_len", "text_len", "uid"]:
                sample[k] = sample[k][start:end]

        return sample

    def check_metadata(self):
        """Return only the utterances whose duration and mel files both exist."""
        new_metadata = []
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            if not os.path.exists(self.utt2duration_path[utt]) or not os.path.exists(
                self.utt2mel_path[utt]
            ):
                continue
            else:
                new_metadata.append(utt_info)
        return new_metadata
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class FS2Collator(BaseOfflineCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        BaseOfflineCollator.__init__(self, cfg)
        self.sort = cfg.train.sort_sample
        self.batch_size = cfg.train.batch_size
        self.drop_last = cfg.train.drop_last

    def __call__(self, batch):
        """Collate a list of per-utterance feature dicts into padded tensors.

        Shapes after collation:
          mel: [b, T, n_mels]; frame_pitch / frame_energy: [b, T];
          target_len / text_len: [b]; spk_id: [b]; mask / text_mask: [b, T, 1].
        """
        packed = dict()

        for key in batch[0].keys():
            if key in ("target_len", "text_len"):
                # Lengths become a LongTensor plus a per-sample 0/1 mask
                # padded to the longest sequence in the batch.
                lengths = [item[key] for item in batch]
                packed[key] = torch.LongTensor(lengths)
                length_masks = [
                    torch.ones((n, 1), dtype=torch.long) for n in lengths
                ]
                mask_key = "mask" if key == "target_len" else "text_mask"
                packed[mask_key] = pad_sequence(
                    length_masks, batch_first=True, padding_value=0
                )
            elif key == "spk_id":
                packed["spk_id"] = torch.LongTensor(
                    [item["spk_id"] for item in batch]
                )
            elif key == "uid":
                # uids stay as a plain list of strings.
                packed[key] = [item["uid"] for item in batch]
            else:
                # Everything else is a numpy array: convert and zero-pad.
                packed[key] = pad_sequence(
                    [torch.from_numpy(item[key]) for item in batch],
                    batch_first=True,
                    padding_value=0,
                )
        return packed
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class FS2TestDataset(BaseTestDataset):
    """Test-time dataset for FastSpeech2 inference.

    Loads the test metadata (an explicit test list file or a preprocessed
    split), the phone label files, and the optional speaker map.
    """

    def __init__(self, args, cfg, infer_type=None):
        datasets = cfg.dataset
        cfg = cfg.preprocess
        is_bigdata = False

        assert len(datasets) >= 1
        if len(datasets) > 1:
            # Multiple corpora are merged under one "bigdata" directory whose
            # name is the sorted concatenation of the dataset names.
            datasets.sort()
            bigdata_version = "_".join(datasets)
            processed_data_dir = os.path.join(cfg.processed_dir, bigdata_version)
            is_bigdata = True
        else:
            processed_data_dir = os.path.join(cfg.processed_dir, args.dataset)

        if args.test_list_file:
            self.metafile_path = args.test_list_file
            self.metadata = self.get_metadata()
        else:
            assert args.testing_set
            source_metafile_path = os.path.join(
                cfg.processed_dir,
                args.dataset,
                "{}.json".format(args.testing_set),
            )
            with open(source_metafile_path, "r") as f:
                self.metadata = json.load(f)

        self.cfg = cfg
        self.datasets = datasets
        self.data_root = processed_data_dir
        self.is_bigdata = is_bigdata
        self.source_dataset = args.dataset

        ######### Load source acoustic features #########
        if cfg.use_spkid:
            spk2id_path = os.path.join(self.data_root, cfg.spk2id)
            utt2sp_path = os.path.join(self.data_root, cfg.utt2spk)
            self.spk2id, self.utt2spk = get_spk_map(spk2id_path, utt2sp_path, datasets)

        # Map "<dataset>_<uid>" -> path of the phone label (.txt) file.
        self.utt2lab_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            self.utt2lab_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.lab_dir,
                uid + ".txt",
            )

        # Optional speaker map for multi-speaker corpora.
        self.speaker_map = {}
        spk2id_json = os.path.join(cfg.processed_dir, "spk2id.json")
        if os.path.exists(spk2id_json):
            # Fix: previously `open(os.path.exists(path))` was called, which
            # passes a bool to open() and raises a TypeError.
            with open(spk2id_json, "r") as f:
                self.speaker_map = json.load(f)

    def __getitem__(self, index):
        """Return phone ids, speaker id, and text length for one utterance."""
        single_feature = {}

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        # Phone label text: the first line of the label file.
        with open(self.utt2lab_path[utt], "r") as f:
            phones = f.readlines()[0].strip()

        phones_ids = np.array(text_to_sequence(phones, self.cfg.text_cleaners))
        text_len = len(phones_ids)

        # Speaker id (0 for single-speaker corpora).
        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        single_feature.update(
            {
                "texts": phones_ids,
                "spk_id": speaker_id,
                "text_len": text_len,
            }
        )

        return single_feature

    def __len__(self):
        return len(self.metadata)

    def get_metadata(self):
        """Load the test metadata list from ``self.metafile_path``."""
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        return metadata
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class FS2TestCollator(BaseTestCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        """Collate a list of test feature dicts into padded batch tensors.

        Shapes after collation:
          target_len / text_len: [b]; spk_id: [b];
          mask / text_mask: [b, T, 1]; other keys: [b, T, ...] zero-padded.
        """
        packed = dict()

        for key in batch[0].keys():
            if key in ("target_len", "text_len"):
                # Lengths become a LongTensor plus a per-sample 0/1 mask
                # padded to the longest sequence in the batch.
                lengths = [item[key] for item in batch]
                packed[key] = torch.LongTensor(lengths)
                length_masks = [
                    torch.ones((n, 1), dtype=torch.long) for n in lengths
                ]
                mask_key = "mask" if key == "target_len" else "text_mask"
                packed[mask_key] = pad_sequence(
                    length_masks, batch_first=True, padding_value=0
                )
            elif key == "spk_id":
                packed["spk_id"] = torch.LongTensor(
                    [item["spk_id"] for item in batch]
                )
            else:
                # Everything else is a numpy array: convert and zero-pad.
                packed[key] = pad_sequence(
                    [torch.from_numpy(item[key]) for item in batch],
                    batch_first=True,
                    padding_value=0,
                )

        return packed
|
Amphion/models/tts/naturalspeech2/diffusion.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from models.tts.naturalspeech2.wavenet import WaveNet
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Diffusion(nn.Module):
    """Diffusion decoder used by NaturalSpeech2.

    Implements the forward (noising) process with a linear beta schedule and
    an ODE-based reverse (denoising) process; a WaveNet network estimates x0.
    """

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        self.diff_estimator = WaveNet(cfg.wavenet)
        self.beta_min = cfg.beta_min
        self.beta_max = cfg.beta_max
        self.sigma = cfg.sigma
        self.noise_factor = cfg.noise_factor

    def forward(self, x, x_mask, cond, spk_query_emb, offset=1e-5):
        """
        x: (B, 128, T)
        x_mask: (B, T), mask is 0
        cond: (B, T, 512)
        spk_query_emb: (B, 32, 512)
        """
        # Sample one diffusion time per batch element, kept away from {0, 1}.
        t = torch.rand(
            x.shape[0], dtype=x.dtype, device=x.device, requires_grad=False
        )
        t = torch.clamp(t, offset, 1.0 - offset)
        xt, z = self.forward_diffusion(x0=x, diffusion_step=t)

        sigma_sq = self.sigma**2
        cum_beta = self.get_cum_beta(t.unsqueeze(-1).unsqueeze(-1))
        x0_pred = self.diff_estimator(xt, x_mask, cond, t, spk_query_emb)
        # Recover the implied noise from the x0 prediction.
        mean_pred = x0_pred * torch.exp(-0.5 * cum_beta / sigma_sq)
        variance = sigma_sq * (1.0 - torch.exp(-cum_beta / sigma_sq))
        noise_pred = (xt - mean_pred) / (torch.sqrt(variance) * self.noise_factor)
        return {"x0_pred": x0_pred, "noise_pred": noise_pred, "noise": z}

    @torch.no_grad()
    def get_cum_beta(self, time_step):
        # Integral of the linear schedule beta(s) over [0, t].
        delta = self.beta_max - self.beta_min
        return self.beta_min * time_step + 0.5 * delta * (time_step**2)

    @torch.no_grad()
    def get_beta_t(self, time_step):
        # Linear noise schedule evaluated at t.
        return self.beta_min + (self.beta_max - self.beta_min) * time_step

    @torch.no_grad()
    def forward_diffusion(self, x0, diffusion_step):
        """
        x0: (B, 128, T)
        time_step: (B,)
        """
        t = diffusion_step.unsqueeze(-1).unsqueeze(-1)
        sigma_sq = self.sigma**2
        cum_beta = self.get_cum_beta(t)
        mean = x0 * torch.exp(-0.5 * cum_beta / sigma_sq)
        variance = sigma_sq * (1 - torch.exp(-cum_beta / sigma_sq))
        z = torch.randn(
            x0.shape, dtype=x0.dtype, device=x0.device, requires_grad=False
        )
        xt = mean + z * torch.sqrt(variance) * self.noise_factor
        return xt, z

    @torch.no_grad()
    def cal_dxt(self, xt, x_mask, cond, spk_query_emb, diffusion_step, h):
        """One ODE increment at ``diffusion_step`` with step size ``h``."""
        t = diffusion_step.unsqueeze(-1).unsqueeze(-1)
        sigma_sq = self.sigma**2
        cum_beta = self.get_cum_beta(time_step=t)
        beta_t = self.get_beta_t(time_step=t)
        x0_pred = self.diff_estimator(xt, x_mask, cond, diffusion_step, spk_query_emb)
        mean_pred = x0_pred * torch.exp(-0.5 * cum_beta / sigma_sq)
        noise_pred = xt - mean_pred
        variance = sigma_sq * (1.0 - torch.exp(-cum_beta / sigma_sq))
        # Score estimate (epsilon regularized to avoid division by zero).
        logp = -noise_pred / (variance + 1e-8)
        return -0.5 * h * beta_t * (logp + xt / sigma_sq)

    @torch.no_grad()
    def reverse_diffusion(self, z, x_mask, cond, n_timesteps, spk_query_emb):
        # Full reverse trajectory: identical to integrating from t_start = 1.
        return self.reverse_diffusion_from_t(
            z, x_mask, cond, n_timesteps, spk_query_emb, t_start=1.0
        )

    @torch.no_grad()
    def reverse_diffusion_from_t(
        self, z, x_mask, cond, n_timesteps, spk_query_emb, t_start
    ):
        """Integrate the reverse ODE from ``t_start`` down to 0 in ``n_timesteps`` steps."""
        step_size = t_start / max(n_timesteps, 1)
        xt = z
        for step in range(n_timesteps):
            # Midpoint-of-interval time for this step.
            t = (t_start - (step + 0.5) * step_size) * torch.ones(
                z.shape[0], dtype=z.dtype, device=z.device
            )
            dxt = self.cal_dxt(
                xt, x_mask, cond, spk_query_emb, diffusion_step=t, h=step_size
            )
            euler_xt = xt - dxt
            if self.cfg.ode_solver == "midpoint":
                midpoint = 0.5 * (euler_xt + xt)
                dxt = self.cal_dxt(
                    midpoint,
                    x_mask,
                    cond,
                    spk_query_emb,
                    diffusion_step=t + 0.5 * step_size,
                    h=step_size,
                )
                xt = xt - dxt
            elif self.cfg.ode_solver == "euler":
                xt = euler_xt
        return xt
|
Amphion/models/tts/vits/vits_inference.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
import numpy as np
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import torch
|
| 11 |
+
import json
|
| 12 |
+
from models.tts.base.tts_inferece import TTSInference
|
| 13 |
+
from models.tts.vits.vits_dataset import VITSTestDataset, VITSTestCollator
|
| 14 |
+
from models.tts.vits.vits import SynthesizerTrn
|
| 15 |
+
from processors.phone_extractor import phoneExtractor
|
| 16 |
+
from text.text_token_collation import phoneIDCollation
|
| 17 |
+
from utils.data_utils import *
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class VitsInference(TTSInference):
    """Inference wrapper for the VITS end-to-end TTS model."""

    def __init__(self, args=None, cfg=None):
        TTSInference.__init__(self, args, cfg)

    def _build_model(self):
        """Instantiate the VITS synthesizer (``SynthesizerTrn``) from the config."""
        net_g = SynthesizerTrn(
            self.cfg.model.text_token_num,
            self.cfg.preprocess.n_fft // 2 + 1,
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            **self.cfg.model,
        )

        return net_g

    def _build_test_dataset(self):
        # Fix: the parameter was misspelled `sefl` (worked only because it is
        # positional, but it shadows the conventional `self`).
        return VITSTestDataset, VITSTestCollator

    def build_save_dir(self, dataset, speaker):
        """Build (and create) the output directory for synthesized audio.

        Args:
            dataset: dataset name, or None for text-only inference.
            speaker: speaker id, or -1 when not speaker-specific.
        """
        save_dir = os.path.join(
            self.args.output_dir,
            "tts_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
        )
        if dataset is not None:
            save_dir = os.path.join(save_dir, "data_{}".format(dataset))
        if speaker != -1:
            save_dir = os.path.join(
                save_dir,
                "spk_{}".format(speaker),
            )
        os.makedirs(save_dir, exist_ok=True)
        print("Saving to ", save_dir)
        return save_dir

    def inference_for_batches(
        self, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
    ):
        """Synthesize waveforms for every batch in the test dataloader.

        Returns:
            List of 1-D float CPU tensors, one trimmed waveform per utterance.
        """
        ###### Construct test_batch ######
        n_batch = len(self.test_dataloader)
        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
        print(
            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
                now, self.test_batch_size, n_batch
            )
        )
        self.model.eval()

        ###### Inference for each batch ######
        pred_res = []
        with torch.no_grad():
            for i, batch_data in enumerate(
                self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
            ):
                # Speaker ids are only meaningful in multi-speaker training.
                spk_id = None
                if (
                    self.cfg.preprocess.use_spkid
                    and self.cfg.train.multi_speaker_training
                ):
                    spk_id = batch_data["spk_id"]

                outputs = self.model.infer(
                    batch_data["phone_seq"],
                    batch_data["phone_len"],
                    spk_id,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                )

                audios = outputs["y_hat"]
                masks = outputs["mask"]

                for idx in range(audios.size(0)):
                    audio = audios[idx, 0, :].data.cpu().float()
                    mask = masks[idx, :, :]
                    # Trim padding: valid length = mask frames * hop size.
                    audio_length = (
                        mask.sum([0, 1]).long() * self.cfg.preprocess.hop_size
                    )
                    audio_length = audio_length.cpu().numpy()
                    audio = audio[:audio_length]
                    pred_res.append(audio)

        return pred_res

    def inference_for_single_utterance(
        self, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
    ):
        """Synthesize a waveform for ``self.args.text``.

        Returns:
            1-D numpy float array with the synthesized audio.
        """
        text = self.args.text

        # get phone symbol file
        phone_symbol_file = None
        if self.cfg.preprocess.phone_extractor != "lexicon":
            phone_symbol_file = os.path.join(
                self.exp_dir, self.cfg.preprocess.symbols_dict
            )
            assert os.path.exists(phone_symbol_file)
        # convert text to phone sequence
        phone_extractor = phoneExtractor(self.cfg)
        phone_seq = phone_extractor.extract_phone(text)  # phone_seq: list
        # convert phone sequence to phone id sequence
        phon_id_collator = phoneIDCollation(
            self.cfg, symbols_dict_file=phone_symbol_file
        )
        phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)

        if self.cfg.preprocess.add_blank:
            phone_id_seq = intersperse(phone_id_seq, 0)

        # convert phone sequence to phone id sequence
        phone_id_seq = np.array(phone_id_seq)
        phone_id_seq = torch.from_numpy(phone_id_seq)

        # get speaker id if multi-speaker training and use speaker id
        speaker_id = None
        if self.cfg.preprocess.use_spkid and self.cfg.train.multi_speaker_training:
            spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
            with open(spk2id_file, "r") as f:
                spk2id = json.load(f)
                speaker_name = self.args.speaker_name
                assert (
                    speaker_name in spk2id
                ), f"Speaker {speaker_name} not found in the spk2id keys. \
                    Please make sure you've specified the correct speaker name in infer_speaker_name."
                speaker_id = spk2id[speaker_name]
                speaker_id = torch.from_numpy(
                    np.array([speaker_id], dtype=np.int32)
                ).unsqueeze(0)

        with torch.no_grad():
            x_tst = phone_id_seq.to(self.device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([phone_id_seq.size(0)]).to(self.device)
            if speaker_id is not None:
                speaker_id = speaker_id.to(self.device)
            outputs = self.model.infer(
                x_tst,
                x_tst_lengths,
                sid=speaker_id,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )

            audio = outputs["y_hat"][0, 0].data.cpu().float().numpy()

        return audio
|
Amphion/models/vocoders/flow/flow_vocoder_trainer.py
ADDED
|
File without changes
|
Amphion/models/vocoders/gan/gan_vocoder_dataset.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import random
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
|
| 13 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 14 |
+
from utils.data_utils import *
|
| 15 |
+
from models.vocoders.vocoder_dataset import VocoderDataset
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class GANVocoderDataset(VocoderDataset):
    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: config
            dataset: dataset name
            is_valid: whether to use train or valid dataset
        """
        super().__init__(cfg, dataset, is_valid)

        # Pick one random utterance whose features are kept around for
        # evaluation/monitoring during training.
        eval_index = random.randint(0, len(self.metadata) - 1)
        eval_utt_info = self.metadata[eval_index]
        eval_utt = "{}_{}".format(eval_utt_info["Dataset"], eval_utt_info["Uid"])
        self.eval_audio = np.load(self.utt2audio_path[eval_utt])
        if cfg.preprocess.use_mel:
            self.eval_mel = np.load(self.utt2mel_path[eval_utt])
        if cfg.preprocess.use_frame_pitch:
            self.eval_pitch = np.load(self.utt2frame_pitch_path[eval_utt])

    def __getitem__(self, index):
        """Load one utterance's features, cropped/padded to a fixed window.

        All frame-level features share the random [start, end) crop chosen by
        the first feature longer than ``cut_mel_frame``; the raw audio uses
        the same window scaled by ``hop_size``. Shorter features are
        zero-padded up to the window length.
        """
        utt_info = self.metadata[index]

        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        single_feature = dict()

        if self.cfg.preprocess.use_mel:
            mel = np.load(self.utt2mel_path[utt])  # (n_mel, frame)
            assert mel.shape[0] == self.cfg.preprocess.n_mel

            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = mel.shape[1]

            if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
                mel = np.pad(
                    mel,
                    ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
                    mode="constant",
                )
            else:
                if "start" not in single_feature.keys():
                    start = random.randint(
                        0, mel.shape[-1] - self.cfg.preprocess.cut_mel_frame
                    )
                    end = start + self.cfg.preprocess.cut_mel_frame
                    single_feature["start"] = start
                    single_feature["end"] = end
                mel = mel[:, single_feature["start"] : single_feature["end"]]
            single_feature["mel"] = mel

        if self.cfg.preprocess.use_frame_pitch:
            frame_pitch = np.load(self.utt2frame_pitch_path[utt])
            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = len(frame_pitch)
            aligned_frame_pitch = align_length(
                frame_pitch, single_feature["target_len"]
            )

            if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
                # BUGFIX: the original computed the pad width as
                # cut_mel_frame * hop_size - audio.shape[-1], referencing
                # `audio` before it exists (NameError) and using a
                # sample-level length for a frame-level feature. Pad the
                # pitch to cut_mel_frame using its own length instead.
                aligned_frame_pitch = np.pad(
                    aligned_frame_pitch,
                    (
                        0,
                        self.cfg.preprocess.cut_mel_frame
                        - aligned_frame_pitch.shape[-1],
                    ),
                    mode="constant",
                )
            else:
                if "start" not in single_feature.keys():
                    start = random.randint(
                        0,
                        aligned_frame_pitch.shape[-1]
                        - self.cfg.preprocess.cut_mel_frame,
                    )
                    end = start + self.cfg.preprocess.cut_mel_frame
                    single_feature["start"] = start
                    single_feature["end"] = end
                aligned_frame_pitch = aligned_frame_pitch[
                    single_feature["start"] : single_feature["end"]
                ]
            single_feature["frame_pitch"] = aligned_frame_pitch

        if self.cfg.preprocess.use_audio:
            audio = np.load(self.utt2audio_path[utt])

            # Audio is cropped consistently with the frame-level features,
            # which must have fixed target_len first.
            assert "target_len" in single_feature.keys()

            if (
                audio.shape[-1]
                <= self.cfg.preprocess.cut_mel_frame * self.cfg.preprocess.hop_size
            ):
                audio = np.pad(
                    audio,
                    (
                        0,
                        self.cfg.preprocess.cut_mel_frame
                        * self.cfg.preprocess.hop_size
                        - audio.shape[-1],
                    ),
                    mode="constant",
                )
            else:
                if "start" not in single_feature.keys():
                    audio = audio[
                        0 : self.cfg.preprocess.cut_mel_frame
                        * self.cfg.preprocess.hop_size
                    ]
                else:
                    audio = audio[
                        single_feature["start"]
                        * self.cfg.preprocess.hop_size : single_feature["end"]
                        * self.cfg.preprocess.hop_size,
                    ]
            single_feature["audio"] = audio

        if self.cfg.preprocess.use_amplitude_phase:
            logamp = np.load(self.utt2logamp_path[utt])
            pha = np.load(self.utt2pha_path[utt])
            rea = np.load(self.utt2rea_path[utt])
            imag = np.load(self.utt2imag_path[utt])

            assert "target_len" in single_feature.keys()

            if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
                # BUGFIX: the original derived the pad width from
                # mel.shape[-1], which is undefined when use_mel is False
                # and already equals cut_mel_frame after mel's own padding
                # (so these features were never actually padded). Pad each
                # feature by its own last-dimension length instead.
                cut_frames = self.cfg.preprocess.cut_mel_frame
                logamp = np.pad(
                    logamp,
                    ((0, 0), (0, cut_frames - logamp.shape[-1])),
                    mode="constant",
                )
                pha = np.pad(
                    pha,
                    ((0, 0), (0, cut_frames - pha.shape[-1])),
                    mode="constant",
                )
                rea = np.pad(
                    rea,
                    ((0, 0), (0, cut_frames - rea.shape[-1])),
                    mode="constant",
                )
                imag = np.pad(
                    imag,
                    ((0, 0), (0, cut_frames - imag.shape[-1])),
                    mode="constant",
                )
            else:
                logamp = logamp[:, single_feature["start"] : single_feature["end"]]
                pha = pha[:, single_feature["start"] : single_feature["end"]]
                rea = rea[:, single_feature["start"] : single_feature["end"]]
                imag = imag[:, single_feature["start"] : single_feature["end"]]
            single_feature["logamp"] = logamp
            single_feature["pha"] = pha
            single_feature["rea"] = rea
            single_feature["imag"] = imag

        return single_feature
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class GANVocoderCollator(object):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        # mel: [b, n_mels, frame]
        # frame_pitch: [b, frame]
        # audios: [b, frame * hop_size]
        scalar_keys = {"target_len", "start", "end"}
        packed_batch_features = dict()

        for feature_name in batch[0].keys():
            if feature_name in scalar_keys:
                # Bookkeeping ints, not tensors to collate.
                continue
            tensors = [torch.from_numpy(sample[feature_name]) for sample in batch]
            packed_batch_features[feature_name] = pad_sequence(
                tensors, batch_first=True, padding_value=0
            )

        return packed_batch_features
|
Amphion/modules/anti_aliasing/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .act import *
|
| 7 |
+
from .filter import *
|
| 8 |
+
from .resample import *
|
Amphion/modules/encoder/condition_encoder.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from torchaudio.models import Conformer
|
| 10 |
+
from models.svc.transformer.transformer import PositionalEncoding
|
| 11 |
+
|
| 12 |
+
from utils.f0 import f0_to_coarse
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ContentEncoder(nn.Module):
    """Project frame-level content features to the content-encoder dimension.

    Optionally runs the features through a positional encoding plus a
    Conformer before the final linear projection, controlled by
    ``cfg.use_conformer_for_content_features``.
    """

    def __init__(self, cfg, input_dim, output_dim):
        super().__init__()
        self.cfg = cfg

        assert input_dim != 0
        self.nn = nn.Linear(input_dim, output_dim)

        # Introduce conformer or not
        wants_conformer = (
            "use_conformer_for_content_features" in cfg
            and cfg.use_conformer_for_content_features
        )
        if wants_conformer:
            self.pos_encoder = PositionalEncoding(input_dim)
            self.conformer = Conformer(
                input_dim=input_dim,
                num_heads=2,
                ffn_dim=256,
                num_layers=6,
                depthwise_conv_kernel_size=3,
            )
        else:
            self.conformer = None

    def forward(self, x, length=None):
        # x: (N, seq_len, input_dim) -> (N, seq_len, output_dim)
        if self.conformer is None:
            return self.nn(x)
        x = self.pos_encoder(x)
        x, _ = self.conformer(x, length)
        return self.nn(x)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class MelodyEncoder(nn.Module):
    """Encode frame-level F0 (melody) into embedding space.

    Two modes, selected by ``cfg.n_bins_melody``:
      * n_bins == 0: no quantization — the raw F0 value is just unsqueezed
        to (B, frame_len, 1) and returned without projection.
      * n_bins > 0: F0 is quantized with ``f0_to_coarse`` and looked up in
        an Embedding; an unvoiced-flag embedding can be added on top.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = self.cfg.input_melody_dim
        self.output_dim = self.cfg.output_melody_dim
        self.n_bins = self.cfg.n_bins_melody

        if self.input_dim != 0:
            if self.n_bins == 0:
                # Not use quantization
                self.nn = nn.Linear(self.input_dim, self.output_dim)
            else:
                # Quantize F0 into n_bins coarse bins within [f0_min, f0_max].
                self.f0_min = cfg.f0_min
                self.f0_max = cfg.f0_max

                self.nn = nn.Embedding(
                    num_embeddings=self.n_bins,
                    embedding_dim=self.output_dim,
                    padding_idx=None,
                )
                # Embedding for the binary unvoiced (uv) flag.
                self.uv_embedding = nn.Embedding(2, self.output_dim)

    def forward(self, x, uv=None, length=None):
        # x: (B, frame_len)
        if self.n_bins == 0:
            x = x.unsqueeze(-1)
        else:
            x = f0_to_coarse(x, self.n_bins, self.f0_min, self.f0_max)
            x = self.nn(x)

            # NOTE(review): uv_embedding only exists in the quantized
            # branch, so the uv addition is applied there — confirm this
            # nesting against upstream (indentation lost in this dump).
            if self.cfg.use_uv:
                uv = self.uv_embedding(uv)
                x = x + uv
        return x
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class LoudnessEncoder(nn.Module):
    """Encode frame-level loudness (energy) into embedding space.

    With ``cfg.n_bins_loudness == 0`` the raw value is linearly projected;
    otherwise it is bucketized onto log-spaced energy bins and embedded.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = self.cfg.input_loudness_dim
        self.output_dim = self.cfg.output_loudness_dim
        self.n_bins = self.cfg.n_bins_loudness

        if self.input_dim != 0:
            if self.n_bins == 0:
                # Not use quantization
                self.nn = nn.Linear(self.input_dim, self.output_dim)
            else:
                # TODO: set empirically now
                self.loudness_min = 1e-30
                self.loudness_max = 1.5
                # n_bins - 1 boundaries, log-spaced between min and max.
                log_boundaries = torch.linspace(
                    np.log(self.loudness_min),
                    np.log(self.loudness_max),
                    self.n_bins - 1,
                )
                self.energy_bins = nn.Parameter(
                    torch.exp(log_boundaries),
                    requires_grad=False,
                )
                self.nn = nn.Embedding(
                    num_embeddings=self.n_bins,
                    embedding_dim=self.output_dim,
                    padding_idx=None,
                )

    def forward(self, x):
        # x: (N, frame_len)
        if self.n_bins == 0:
            return self.nn(x.unsqueeze(-1))
        return self.nn(torch.bucketize(x, self.energy_bins))
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class SingerEncoder(nn.Module):
    """Look up a singer/speaker embedding from an integer singer id."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = 1
        self.output_dim = self.cfg.output_singer_dim

        # One embedding row per singer in the lookup table.
        self.nn = nn.Embedding(
            num_embeddings=cfg.singer_table_size,
            embedding_dim=self.output_dim,
            padding_idx=None,
        )

    def forward(self, x):
        # x: (N, 1) -> (N, 1, output_dim)
        return self.nn(x)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class ConditionEncoder(nn.Module):
    """Fuse content, prosody, and speaker features into one condition sequence.

    Every enabled sub-encoder emits a (N, seq_len, dim) sequence; the results
    are merged along the feature axis ("concat") or summed element-wise
    ("add"), as selected by ``cfg.merge_mode``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.merge_mode = cfg.merge_mode

        ### Semantic Features ###
        if cfg.use_whisper:
            self.whisper_encoder = ContentEncoder(
                self.cfg, self.cfg.whisper_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_contentvec:
            self.contentvec_encoder = ContentEncoder(
                self.cfg, self.cfg.contentvec_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_mert:
            self.mert_encoder = ContentEncoder(
                self.cfg, self.cfg.mert_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_wenet:
            self.wenet_encoder = ContentEncoder(
                self.cfg, self.cfg.wenet_dim, self.cfg.content_encoder_dim
            )

        ### Prosody Features ###
        if cfg.use_f0:
            self.melody_encoder = MelodyEncoder(self.cfg)
        if cfg.use_energy:
            self.loudness_encoder = LoudnessEncoder(self.cfg)

        ### Speaker Features ###
        if cfg.use_spkid:
            self.singer_encoder = SingerEncoder(self.cfg)

    def forward(self, x):
        feats = []

        if self.cfg.use_f0:
            uv = x["frame_uv"] if self.cfg.use_uv else None
            feats.append(
                self.melody_encoder(x["frame_pitch"], uv=uv, length=x["target_len"])
            )

        if self.cfg.use_energy:
            feats.append(self.loudness_encoder(x["frame_energy"]))

        if self.cfg.use_whisper:
            # whisper_feat: [b, T, whisper_dim]
            whisper_enc_out = self.whisper_encoder(
                x["whisper_feat"], length=x["target_len"]
            )
            feats.append(whisper_enc_out)
            seq_len = whisper_enc_out.shape[1]

        if self.cfg.use_contentvec:
            contentvec_enc_out = self.contentvec_encoder(
                x["contentvec_feat"], length=x["target_len"]
            )
            feats.append(contentvec_enc_out)
            seq_len = contentvec_enc_out.shape[1]

        if self.cfg.use_mert:
            mert_enc_out = self.mert_encoder(x["mert_feat"], length=x["target_len"])
            feats.append(mert_enc_out)
            seq_len = mert_enc_out.shape[1]

        if self.cfg.use_wenet:
            wenet_enc_out = self.wenet_encoder(x["wenet_feat"], length=x["target_len"])
            feats.append(wenet_enc_out)
            seq_len = wenet_enc_out.shape[1]

        if self.cfg.use_spkid:
            speaker_enc_out = self.singer_encoder(x["spk_id"])  # [b, 1, dim]
            # Broadcasting the speaker vector over time requires at least
            # one content feature to have defined seq_len above.
            assert (
                "whisper_feat" in x.keys()
                or "contentvec_feat" in x.keys()
                or "mert_feat" in x.keys()
                or "wenet_feat" in x.keys()
            )
            feats.append(speaker_enc_out.expand(-1, seq_len, -1))

        encoder_output = None
        if self.merge_mode == "concat":
            encoder_output = torch.cat(feats, dim=-1)
        if self.merge_mode == "add":
            # Stack to (#modules, N, seq_len, dim), sum over modules.
            encoder_output = torch.stack(feats, dim=0).sum(dim=0)

        return encoder_output
|
Amphion/modules/general/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .input_strategies import PromptedFeatures, PromptedPrecomputedFeatures
|
| 2 |
+
from .scaling import BalancedDoubleSwish
|
| 3 |
+
from .utils import Transpose
|
Amphion/modules/monotonic_align/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This code from https://github.com/jaywalnut310/vits/
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from .monotonic_align.core import maximum_path_c
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def maximum_path(neg_cent, mask):
    """Cython optimized version.
    neg_cent: [b, t_t, t_s]
    mask: [b, t_t, t_s]
    """
    orig_device = neg_cent.device
    orig_dtype = neg_cent.dtype

    scores = neg_cent.data.cpu().numpy().astype(np.float32)
    path = np.zeros(scores.shape, dtype=np.int32)

    # Per-example valid lengths along each axis, recovered from the mask.
    t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
    t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)

    maximum_path_c(path, scores, t_t_max, t_s_max)
    return torch.from_numpy(path).to(device=orig_device, dtype=orig_dtype)
|
Amphion/modules/neural_source_filter/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .sine_excitation import *
|
Amphion/modules/transformer/Layers.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
from .SubLayers import MultiHeadAttention, PositionwiseFeedForward
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FFTBlock(torch.nn.Module):
    """FFT Block: self-attention followed by a position-wise feed-forward
    layer, zeroing padded frames after each sub-layer."""

    def __init__(self, d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=0.1):
        super(FFTBlock, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(
            d_model, d_inner, kernel_size, dropout=dropout
        )

    def forward(self, enc_input, mask=None, slf_attn_mask=None):
        out, attn_weights = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask
        )
        # Zero out padded positions after each sub-layer.
        pad_positions = mask.unsqueeze(-1)
        out = out.masked_fill(pad_positions, 0)

        out = self.pos_ffn(out)
        out = out.masked_fill(pad_positions, 0)

        return out, attn_weights
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ConvNorm(torch.nn.Module):
    """1-D convolution with automatic "same" padding for odd kernel sizes."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super(ConvNorm, self).__init__()

        # Derive "same" padding when none was given; needs an odd kernel.
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        # NOTE: w_init_gain is accepted for API compatibility but unused.
        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, signal):
        return self.conv(signal)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class PostNet(nn.Module):
    """
    PostNet: Five 1-d convolution with 512 channels and kernel size 5
    """

    def __init__(
        self,
        n_mel_channels=80,
        postnet_embedding_dim=512,
        postnet_kernel_size=5,
        postnet_n_convolutions=5,
    ):
        super(PostNet, self).__init__()
        self.convolutions = nn.ModuleList()

        # Channel plan: mel -> embed -> ... -> embed -> mel, with tanh
        # activations everywhere except the final (linear) layer.
        channel_plan = (
            [(n_mel_channels, postnet_embedding_dim, "tanh")]
            + [
                (postnet_embedding_dim, postnet_embedding_dim, "tanh")
                for _ in range(postnet_n_convolutions - 2)
            ]
            + [(postnet_embedding_dim, n_mel_channels, "linear")]
        )
        for in_ch, out_ch, gain in channel_plan:
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(
                        in_ch,
                        out_ch,
                        kernel_size=postnet_kernel_size,
                        stride=1,
                        padding=int((postnet_kernel_size - 1) / 2),
                        dilation=1,
                        w_init_gain=gain,
                    ),
                    nn.BatchNorm1d(out_ch),
                )
            )

    def forward(self, x):
        # (B, T, n_mels) -> (B, n_mels, T) for Conv1d.
        x = x.contiguous().transpose(1, 2)

        for conv in self.convolutions[:-1]:
            x = F.dropout(torch.tanh(conv(x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        return x.contiguous().transpose(1, 2)
|
Amphion/modules/wenet_extractor/cif/predictor.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
|
| 2 |
+
|
| 3 |
+
# ## Citations
|
| 4 |
+
|
| 5 |
+
# ```bibtex
|
| 6 |
+
# @inproceedings{yao2021wenet,
|
| 7 |
+
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
|
| 8 |
+
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
|
| 9 |
+
# booktitle={Proc. Interspeech},
|
| 10 |
+
# year={2021},
|
| 11 |
+
# address={Brno, Czech Republic },
|
| 12 |
+
# organization={IEEE}
|
| 13 |
+
# }
|
| 14 |
+
|
| 15 |
+
# @article{zhang2022wenet,
|
| 16 |
+
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
|
| 17 |
+
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
|
| 18 |
+
# journal={arXiv preprint arXiv:2203.15455},
|
| 19 |
+
# year={2022}
|
| 20 |
+
# }
|
| 21 |
+
#
|
| 22 |
+
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
from torch import nn
|
| 27 |
+
from modules.wenet_extractor.utils.mask import make_pad_mask
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Predictor(nn.Module):
|
| 31 |
+
    def __init__(
        self,
        idim,
        l_order,
        r_order,
        threshold=1.0,
        dropout=0.1,
        smooth_factor=1.0,
        noise_threshold=0,
        tail_threshold=0.45,
    ):
        """CIF (Continuous Integrate-and-Fire) weight predictor.

        Args:
            idim: feature dimension of the encoder hidden states.
            l_order: left context of the depthwise convolution.
            r_order: right context of the depthwise convolution.
            threshold: firing threshold passed to the CIF integration.
            dropout: dropout probability applied to the conv output.
            smooth_factor: scale applied to the sigmoid weights in forward().
            noise_threshold: weights are shifted down by this before the
                ReLU clip in forward().
            tail_threshold: extra weight appended at the sequence tail (see
                tail_process_fn) so a trailing partial accumulation can fire.
        """
        super().__init__()

        # Depthwise conv over a (l_order + r_order + 1)-frame window;
        # ConstantPad1d keeps the output length equal to the input length.
        self.pad = nn.ConstantPad1d((l_order, r_order), 0.0)
        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
        # Projects each frame to a scalar pre-sigmoid CIF weight.
        self.cif_output = nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.tail_threshold = tail_threshold
|
| 52 |
+
|
| 53 |
+
    def forward(
        self,
        hidden,
        target_label: Optional[torch.Tensor] = None,
        mask: torch.Tensor = torch.tensor(0),
        ignore_id: int = -1,
        mask_chunk_predictor: Optional[torch.Tensor] = None,
        target_label_length: Optional[torch.Tensor] = None,
    ):
        """Predict CIF weights and integrate hidden states into token embeddings.

        Args:
            hidden: encoder hidden states, (B, T, idim).
            target_label: training targets; positions equal to ignore_id are
                padding and excluded from the target length.
            mask: padding mask for the encoder frames.
            ignore_id: padding id in target_label.
            mask_chunk_predictor: optional extra mask multiplied onto the
                weights (chunk-based streaming).
            target_label_length: explicit target lengths; takes precedence
                over target_label when both are given.

        Returns:
            (acoustic_embeds, token_num, alphas, cif_peak) as produced by
            the external ``cif`` integration.
        """
        # Depthwise conv with residual over the (padded) frame axis.
        h = hidden
        context = h.transpose(1, 2)
        queries = self.pad(context)
        memory = self.cif_conv1d(queries)
        output = memory + context
        output = self.dropout(output)
        output = output.transpose(1, 2)
        output = torch.relu(output)
        output = self.cif_output(output)
        # Per-frame weights in (0, 1), then scaled/shifted and clipped at 0.
        alphas = torch.sigmoid(output)
        alphas = torch.nn.functional.relu(
            alphas * self.smooth_factor - self.noise_threshold
        )
        # NOTE(review): the `mask` default is a mutable tensor default
        # argument and is also re-assigned below — confirm callers always
        # pass an explicit mask.
        if mask is not None:
            mask = mask.transpose(-1, -2).float()
            alphas = alphas * mask
        if mask_chunk_predictor is not None:
            alphas = alphas * mask_chunk_predictor
        alphas = alphas.squeeze(-1)
        mask = mask.squeeze(-1)
        # Resolve the target token count: explicit length wins, then the
        # non-padding count of target_label, else unsupervised (None).
        if target_label_length is not None:
            target_length = target_label_length
        elif target_label is not None:
            target_length = (target_label != ignore_id).float().sum(-1)
        else:
            target_length = None
        token_num = alphas.sum(-1)
        if target_length is not None:
            # Scale weights so they integrate exactly to the target length.
            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
        elif self.tail_threshold > 0.0:
            # Inference: append tail weight so the last accumulation fires.
            hidden, alphas, token_num = self.tail_process_fn(
                hidden, alphas, token_num, mask=mask
            )

        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

        if target_length is None and self.tail_threshold > 0.0:
            # Trim to the longest fired token count in the batch.
            token_num_int = torch.max(token_num).type(torch.int32).item()
            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]

        return acoustic_embeds, token_num, alphas, cif_peak
|
| 103 |
+
|
| 104 |
+
    def tail_process_fn(
        self,
        hidden,
        alphas,
        token_num: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ):
        """Append a tail weight so a trailing partial CIF accumulation fires.

        Extends ``alphas`` (and ``hidden``) by one position and adds
        ``self.tail_threshold`` there — at each sequence's true end when a
        mask is given, or uniformly at the appended slot otherwise — then
        recomputes the (floored) token count.

        Returns:
            (hidden, alphas, token_num_floor), each extended by one frame.
        """
        b, t, d = hidden.size()
        tail_threshold = self.tail_threshold
        if mask is not None:
            # mask_2 - mask_1 is a one-hot marking the first padded slot
            # (i.e. the position right after each sequence's last frame).
            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
            ones_t = torch.ones_like(zeros_t)
            mask_1 = torch.cat([mask, zeros_t], dim=1)
            mask_2 = torch.cat([ones_t, mask], dim=1)
            mask = mask_2 - mask_1
            tail_threshold = mask * tail_threshold
            alphas = torch.cat([alphas, zeros_t], dim=1)
            alphas = torch.add(alphas, tail_threshold)
        else:
            # No mask: just append one slot holding the tail weight.
            tail_threshold_tensor = torch.tensor(
                [tail_threshold], dtype=alphas.dtype
            ).to(alphas.device)
            tail_threshold_tensor = torch.reshape(tail_threshold_tensor, (1, 1))
            alphas = torch.cat([alphas, tail_threshold_tensor], dim=1)
        # Extend hidden with a zero frame to match the extra alpha slot.
        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
        hidden = torch.cat([hidden, zeros], dim=1)
        token_num = alphas.sum(dim=-1)
        token_num_floor = torch.floor(token_num)

        return hidden, alphas, token_num_floor
|
| 134 |
+
|
| 135 |
+
def gen_frame_alignments(
    self, alphas: torch.Tensor = None, encoder_sequence_length: torch.Tensor = None
):
    """Derive hard frame-to-token alignments from soft CIF weights.

    Converts the (batch, T) alpha weights into an integer alignment map of
    the same shape, where the value at each frame encodes which frames are
    consumed by each emitted token.  Both returned tensors are detached, so
    no gradient flows through the alignment.

    Args:
        alphas: (batch, T) CIF weights.
        encoder_sequence_length: (batch,) valid frame counts.

    Returns:
        Tuple of the (batch, T) int32 alignment tensor and the (batch,)
        per-sequence alignment lengths (same dtype as
        ``encoder_sequence_length``).
    """
    batch_size, maximum_length = alphas.size()
    int_type = torch.int32

    is_training = self.training
    # Training rounds the expected token count; inference floors it
    # (matching the tail-threshold handling elsewhere in this predictor).
    if is_training:
        token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
    else:
        token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)

    max_token_num = torch.max(token_num).item()

    # Integer part of the running alpha sum: jumps by 1 at each fire.
    # Tiled to (batch, max_token_num, T) for per-token comparisons.
    alphas_cumsum = torch.cumsum(alphas, dim=1)
    alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
    alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)

    # index[b, k, t] == k + 1: the 1-based token index each row represents.
    index = torch.ones([batch_size, max_token_num], dtype=int_type)
    index = torch.cumsum(index, dim=1)
    index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)

    # For token k, count the frames whose integer cumsum is still below
    # k+1; +1 gives the (1-based) frame index where token k+1 completes.
    index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
    index_div_bool_zeros = index_div.eq(0)
    index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
    index_div_bool_zeros_count = torch.clamp(
        index_div_bool_zeros_count, 0, encoder_sequence_length.max()
    )
    # Zero out rows beyond each sequence's actual token count.
    token_num_mask = (~make_pad_mask(token_num, max_len=max_token_num)).to(
        token_num.device
    )
    index_div_bool_zeros_count *= token_num_mask

    # Expand each token's end-frame index over the time axis and compare
    # against a per-frame ramp to build a 0/1 "frames before the fire"
    # indicator per token; summing over tokens yields the alignment map.
    index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(
        1, 1, maximum_length
    )
    ones = torch.ones_like(index_div_bool_zeros_count_tile)
    zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
    ones = torch.cumsum(ones, dim=2)
    cond = index_div_bool_zeros_count_tile == ones
    index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)

    index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(
        torch.bool
    )
    index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(
        int_type
    )
    index_div_bool_zeros_count_tile_out = torch.sum(
        index_div_bool_zeros_count_tile, dim=1
    )
    index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(
        int_type
    )
    # Mask out padded encoder frames.
    predictor_mask = (
        (
            ~make_pad_mask(
                encoder_sequence_length, max_len=encoder_sequence_length.max()
            )
        )
        .type(int_type)
        .to(encoder_sequence_length.device)
    )
    index_div_bool_zeros_count_tile_out = (
        index_div_bool_zeros_count_tile_out * predictor_mask
    )

    predictor_alignments = index_div_bool_zeros_count_tile_out
    predictor_alignments_length = predictor_alignments.sum(-1).type(
        encoder_sequence_length.dtype
    )
    return predictor_alignments.detach(), predictor_alignments_length.detach()
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class MAELoss(nn.Module):
    """Mean-absolute-error loss on predicted token counts.

    Sums the element-wise L1 differences and divides either by the batch
    size (default) or by the total target token count when
    ``normalize_length`` is set.
    """

    def __init__(self, normalize_length=False):
        super(MAELoss, self).__init__()
        # When True, normalize by the summed target lengths instead of
        # the batch size.
        self.normalize_length = normalize_length
        self.criterion = torch.nn.L1Loss(reduction="sum")

    def forward(self, token_length, pre_token_length):
        """Return the normalized L1 loss between target and predicted lengths."""
        if self.normalize_length:
            normalizer = token_length.sum().type(torch.float32)
        else:
            normalizer = token_length.size(0)
        return self.criterion(token_length, pre_token_length) / normalizer
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def cif(hidden: torch.Tensor, alphas: torch.Tensor, threshold: float):
    """Continuous Integrate-and-Fire (CIF) over encoder frames.

    Integrates the per-frame weights ``alphas``; whenever the accumulator
    reaches ``threshold``, the weighted sum of frames accumulated so far is
    emitted as one acoustic embedding and the remainder of the current
    frame's weight starts the next token.

    Args:
        hidden: (B, T, D) encoder features.
        alphas: (B, T) non-negative CIF weights.
        threshold: firing threshold (typically 1.0).

    Returns:
        Tuple of the emitted embeddings, padded to the max token count
        across the batch (B, max_tokens, D), and the raw integrator values
        per frame (B, T).
    """
    batch_size, len_time, hidden_size = hidden.size()

    # Loop state: running weight integrator and partially accumulated frame.
    integrate = torch.zeros([batch_size], device=hidden.device)
    frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
    # Intermediate values collected along time.
    list_fires = []
    list_frames = []

    for t in range(len_time):
        alpha = alphas[:, t]
        # Weight still needed to complete the current token.
        distribution_completion = (
            torch.ones([batch_size], device=hidden.device) - integrate
        )

        integrate += alpha
        list_fires.append(integrate)

        fire_place = integrate >= threshold
        # Keep only the remainder of the integrator where a fire occurred.
        integrate = torch.where(
            fire_place,
            integrate - torch.ones([batch_size], device=hidden.device),
            integrate,
        )
        # On a fire, only the completing portion of alpha goes into the
        # emitted frame; the remainder seeds the next token.
        cur = torch.where(fire_place, distribution_completion, alpha)
        remainds = alpha - cur

        frame += cur[:, None] * hidden[:, t, :]
        list_frames.append(frame)
        frame = torch.where(
            fire_place[:, None].repeat(1, hidden_size),
            remainds[:, None] * hidden[:, t, :],
            frame,
        )

    fires = torch.stack(list_fires, 1)
    frames = torch.stack(list_frames, 1)
    list_ls = []
    len_labels = torch.round(alphas.sum(-1)).int()
    max_label_len = len_labels.max()
    for b in range(batch_size):
        fire = fires[b, :]
        # BUG FIX: use squeeze(-1), not squeeze().  With exactly one fire,
        # nonzero() returns shape (1, 1) and squeeze() collapses it to a
        # 0-dim tensor, which index_select rejects ("Index is supposed to
        # be a vector").  squeeze(-1) always yields a 1-D index.
        l = torch.index_select(
            frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze(-1)
        )
        pad_l = torch.zeros(
            [int(max_label_len - l.size(0)), hidden_size], device=hidden.device
        )
        list_ls.append(torch.cat([l, pad_l], 0))
    return torch.stack(list_ls, 0), fires
|
Amphion/modules/wenet_extractor/paraformer/search/beam_search.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
|
| 2 |
+
|
| 3 |
+
# ## Citations
|
| 4 |
+
|
| 5 |
+
# ```bibtex
|
| 6 |
+
# @inproceedings{yao2021wenet,
|
| 7 |
+
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
|
| 8 |
+
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
|
| 9 |
+
# booktitle={Proc. Interspeech},
|
| 10 |
+
# year={2021},
|
| 11 |
+
# address={Brno, Czech Republic },
|
| 12 |
+
# organization={IEEE}
|
| 13 |
+
# }
|
| 14 |
+
|
| 15 |
+
# @article{zhang2022wenet,
|
| 16 |
+
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
|
| 17 |
+
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
|
| 18 |
+
# journal={arXiv preprint arXiv:2203.15455},
|
| 19 |
+
# year={2022}
|
| 20 |
+
# }
|
| 21 |
+
#
|
| 22 |
+
|
| 23 |
+
from itertools import chain
|
| 24 |
+
from typing import Any
|
| 25 |
+
from typing import Dict
|
| 26 |
+
from typing import List
|
| 27 |
+
from typing import Tuple
|
| 28 |
+
from typing import Union
|
| 29 |
+
from typing import NamedTuple
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
|
| 33 |
+
from modules.wenet_extractor.paraformer.utils import end_detect
|
| 34 |
+
from modules.wenet_extractor.paraformer.search.ctc import CTCPrefixScorer
|
| 35 |
+
from modules.wenet_extractor.paraformer.search.scorer_interface import (
|
| 36 |
+
ScorerInterface,
|
| 37 |
+
PartialScorerInterface,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Hypothesis(NamedTuple):
    """Hypothesis data type."""

    # Decoded token-id sequence so far (begins with <sos>).
    yseq: torch.Tensor
    # Accumulated weighted total score of this hypothesis.
    score: Union[float, torch.Tensor] = 0
    # Per-scorer cumulative scores, keyed by scorer name.
    # NOTE(review): the dict() defaults below are evaluated once and shared
    # across every instance that relies on them; the beam search always
    # supplies fresh dicts, but confirm before mutating a defaulted instance.
    scores: Dict[str, Union[float, torch.Tensor]] = dict()
    # Per-scorer opaque decoding states, keyed by scorer name.
    states: Dict[str, Any] = dict()

    def asdict(self) -> dict:
        """Convert data to JSON-friendly dict."""
        # _replace builds a copy with tensors collapsed to plain Python
        # lists/floats so the result is serializable (states are kept as-is).
        return self._replace(
            yseq=self.yseq.tolist(),
            score=float(self.score),
            scores={k: float(v) for k, v in self.scores.items()},
        )._asdict()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class BeamSearchCIF(torch.nn.Module):
    """Beam search over per-step acoustic (CIF) scores plus attached scorers."""

    def __init__(
        self,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        beam_size: int,
        vocab_size: int,
        sos: int,
        eos: int,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
    ):
        """Initialize beam search.

        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules
                e.g., Decoder, CTCPrefixScorer, LM
                The scorer will be ignored if it is `None`
            weights (dict[str, float]): Dict of weights for each scorers
                The scorer will be ignored if its weight is 0
            beam_size (int): The number of hypotheses kept during search
            vocab_size (int): The number of vocabulary
            sos (int): Start of sequence id
            eos (int): End of sequence id
            pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): beam size in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`

        """
        super().__init__()
        # Partition scorers: full scorers score the whole vocabulary,
        # partial scorers only a candidate subset (pre-beam).
        self.weights = weights
        self.scorers = dict()
        self.full_scorers = dict()
        self.part_scorers = dict()
        # this module dict is required for recursive cast
        # `self.to(device, dtype)` in `recog.py`
        self.nn_dict = torch.nn.ModuleDict()
        for k, v in scorers.items():
            w = weights.get(k, 0)
            # Scorers with zero weight or no module are dropped entirely.
            if w == 0 or v is None:
                continue
            assert isinstance(
                v, ScorerInterface
            ), f"{k} ({type(v)}) does not implement ScorerInterface"
            self.scorers[k] = v
            if isinstance(v, PartialScorerInterface):
                self.part_scorers[k] = v
            else:
                self.full_scorers[k] = v
            if isinstance(v, torch.nn.Module):
                self.nn_dict[k] = v

        # set configurations
        self.sos = sos
        self.eos = eos
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
        self.n_vocab = vocab_size
        if (
            pre_beam_score_key is not None
            and pre_beam_score_key != "full"
            and pre_beam_score_key not in self.full_scorers
        ):
            raise KeyError(
                f"{pre_beam_score_key} is not found in " f"{self.full_scorers}"
            )
        self.pre_beam_score_key = pre_beam_score_key
        # Pre-beam pruning only pays off when partial scorers exist and
        # the pre-beam is actually narrower than the vocabulary.
        self.do_pre_beam = (
            self.pre_beam_score_key is not None
            and self.pre_beam_size < self.n_vocab
            and len(self.part_scorers) > 0
        )

    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            Hypothesis: The initial hypothesis.

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.init_state(x)
            init_scores[k] = 0.0
        return [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor([self.sos], device=x.device),
            )
        ]

    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Append new token to prefix tokens.

        Args:
            xs (torch.Tensor): The prefix token
            x (int): The new token to append

        Returns:
            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and
                xs.device

        """
        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        return torch.cat((xs, x))

    def score_full(
        self, hyp: Hypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
        return scores, states

    def score_partial(
        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 1D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(len(ids),)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
        return scores, states

    def beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each
                tokens.
                Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(self.beam_size,)`

        """
        # no pre beam performed
        if weighted_scores.size(0) == ids.size(0):
            top_ids = weighted_scores.topk(self.beam_size)[1]
            return top_ids, top_ids

        # mask pruned in pre-beam not to select in topk
        tmp = weighted_scores[ids]
        weighted_scores[:] = -float("inf")
        weighted_scores[ids] = tmp
        top_ids = weighted_scores.topk(self.beam_size)[1]
        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
        return top_ids, local_ids

    @staticmethod
    def merge_scores(
        prev_scores: Dict[str, float],
        next_full_scores: Dict[str, torch.Tensor],
        full_idx: int,
        next_part_scores: Dict[str, torch.Tensor],
        part_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Merge scores for new hypothesis.

        Args:
            prev_scores (Dict[str, float]):
                The previous hypothesis scores by `self.scorers`
            next_full_scores (Dict[str, torch.Tensor]): scores by
                `self.full_scorers`
            full_idx (int): The next token id for `next_full_scores`
            next_part_scores (Dict[str, torch.Tensor]):
                scores of partial tokens by `self.part_scorers`
            part_idx (int): The new token id for `next_part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and
                `self.part_scorers`.
                Its values are scalar tensors by the scorers.

        """
        new_scores = dict()
        for k, v in next_full_scores.items():
            new_scores[k] = prev_scores[k] + v[full_idx]
        for k, v in next_part_scores.items():
            new_scores[k] = prev_scores[k] + v[part_idx]
        return new_scores

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and
                `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, d in self.part_scorers.items():
            new_states[k] = d.select_state(part_states[k], part_idx)
        return new_states

    def search(
        self, running_hyps: List[Hypothesis], x: torch.Tensor, am_score: torch.Tensor
    ) -> List[Hypothesis]:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (List[Hypothesis]): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)
            am_score (torch.Tensor): Per-token acoustic scores for this
                step, shape `(self.n_vocab,)`.

        Returns:
            List[Hypothesis]: Best sorted hypotheses

        """
        best_hyps = []
        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
        for hyp in running_hyps:
            # scoring: start from the acoustic score shared by all hyps
            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
            weighted_scores += am_score
            scores, states = self.score_full(hyp, x)
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
            # partial scoring
            if self.do_pre_beam:
                pre_beam_scores = (
                    weighted_scores
                    if self.pre_beam_score_key == "full"
                    else scores[self.pre_beam_score_key]
                )
                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
            part_scores, part_states = self.score_partial(hyp, part_ids, x)
            for k in self.part_scorers:
                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
            # add previous hyp score
            weighted_scores += hyp.score

            # update hyps
            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
                # will be (2 x beam at most)
                best_hyps.append(
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(
                            hyp.scores, scores, j, part_scores, part_j
                        ),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )

        # sort and prune 2 x beam -> beam
        best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
            : min(len(best_hyps), self.beam_size)
        ]
        return best_hyps

    def forward(
        self,
        x: torch.Tensor,
        am_scores: torch.Tensor,
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            am_scores (torch.Tensor): Per-step acoustic scores
                (num_steps, n_vocab); its first dimension fixes the
                number of decoding steps.
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
                If maxlenratio<0.0, its absolute value is interpreted
                as a constant max output length.
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        """
        # set length bounds
        maxlen = am_scores.shape[0]

        # main loop of prefix search
        running_hyps = self.init_hyp(x)
        ended_hyps = []
        for i in range(maxlen):
            best = self.search(running_hyps, x, am_scores[i])
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                break

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            # BUG FIX: the retry must pass `am_scores` through; the original
            # passed `maxlenratio` in the `am_scores` position, which would
            # crash at `am_scores.shape[0]` on the recursive call.
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, am_scores, maxlenratio, max(0.0, minlenratio - 0.1))
            )

        return nbest_hyps

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: List[Hypothesis],
        ended_hyps: List[Hypothesis],
    ) -> List[Hypothesis]:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (int): The maximum length ratio in beam search.
            running_hyps (List[Hypothesis]): The running hypotheses in beam
                search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            List[Hypothesis]: The new running hypotheses.

        """

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            # logging.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos))
                for h in running_hyps
            ]

        # add ended hypotheses to a final list, and removed them from current
        # hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in running_hyps:
            if hyp.yseq[-1] == self.eos:
                # e.g., Word LM needs to add final <eos> score
                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
                    s = d.final_score(hyp.states[k])
                    hyp.scores[k] += s
                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        return remained_hyps
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def build_beam_search(model, args, device):
    """Assemble a :class:`BeamSearchCIF` decoder for ``model``.

    Registers a CTC prefix scorer when the model has a CTC head, weights it
    against the decoder per ``args.ctc_weight`` (plus a length bonus from
    ``args.penalty``), and returns the searcher moved to ``device`` in eval
    mode.
    """
    scorers = {}
    if model.ctc is not None:
        # CTC prefix scoring over the encoder output.
        scorers["ctc"] = CTCPrefixScorer(ctc=model.ctc, eos=model.eos)
    weights = {
        "decoder": 1.0 - args.ctc_weight,
        "ctc": args.ctc_weight,
        "length_bonus": args.penalty,
    }
    searcher = BeamSearchCIF(
        beam_size=args.beam_size,
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        vocab_size=model.vocab_size,
        # Pure-CTC decoding has no full scorer to pre-beam on.
        pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
    )
    # `.to()` and `.eval()` both return the module itself.
    return searcher.to(device=device, dtype=torch.float32).eval()
|
Amphion/modules/wenet_extractor/paraformer/search/ctc_prefix_score.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
|
| 2 |
+
|
| 3 |
+
# ## Citations
|
| 4 |
+
|
| 5 |
+
# ```bibtex
|
| 6 |
+
# @inproceedings{yao2021wenet,
|
| 7 |
+
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
|
| 8 |
+
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
|
| 9 |
+
# booktitle={Proc. Interspeech},
|
| 10 |
+
# year={2021},
|
| 11 |
+
# address={Brno, Czech Republic },
|
| 12 |
+
# organization={IEEE}
|
| 13 |
+
# }
|
| 14 |
+
|
| 15 |
+
# @article{zhang2022wenet,
|
| 16 |
+
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
|
| 17 |
+
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
|
| 18 |
+
# journal={arXiv preprint arXiv:2203.15455},
|
| 19 |
+
# year={2022}
|
| 20 |
+
# }
|
| 21 |
+
#
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import numpy as np
|
| 25 |
+
|
| 26 |
+
import six
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CTCPrefixScore(object):
    """Compute CTC label sequence scores.

    Based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the probabilities of multiple labels
    simultaneously.
    """

    def __init__(self, x, blank, eos, xp):
        """Construct a CTC prefix scorer.

        :param x: label posterior sequence, indexed as x[t, label]
            (presumably log-probabilities; shape (T, O) — confirm with caller)
        :param blank: blank label id
        :param eos: end-of-sequence label id
        :param xp: array module (numpy-compatible, e.g. numpy or cupy)
        """
        self.xp = xp
        # Finite stand-in for log(0); avoids -inf arithmetic issues.
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.input_length = len(x)
        self.x = x

    def initial_state(self):
        """Obtain an initial CTC state

        :return: CTC state
        """
        # initial CTC state is made of a frame x 2 tensor that corresponds to
        # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
        # superscripts n and b (non-blank and blank), respectively.
        r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
        # Only the blank path is reachable before any label is emitted:
        # accumulate blank log-probs over time.
        r[0, 1] = self.x[0, self.blank]
        for i in six.moves.range(1, self.input_length):
            r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
        return r

    def __call__(self, y, cs, r_prev):
        """Compute CTC prefix scores for next labels

        :param y : prefix label sequence
        :param cs : array of next labels
        :param r_prev: previous CTC state
        :return ctc_scores, ctc_states
        """
        # initialize CTC states
        output_length = len(y) - 1  # ignore sos
        # new CTC states are prepared as a frame x (n or b) x n_labels tensor
        # that corresponds to r_t^n(h) and r_t^b(h).
        r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
        # Posteriors restricted to the candidate next labels.
        xs = self.x[:, cs]
        if output_length == 0:
            r[0, 0] = xs[0]
            r[0, 1] = self.logzero
        else:
            # Frames before the prefix length cannot host the new label.
            r[output_length - 1] = self.logzero

        # prepare forward probabilities for the last label
        r_sum = self.xp.logaddexp(
            r_prev[:, 0], r_prev[:, 1]
        )  # log(r_t^n(g) + r_t^b(g))
        last = y[-1]
        if output_length > 0 and last in cs:
            # When a candidate repeats the last label, only the blank path of
            # the previous state may precede it (CTC collapse rule).
            log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
            for i in six.moves.range(len(cs)):
                log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
        else:
            log_phi = r_sum

        # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
        # and log prefix probabilities log(psi)
        start = max(output_length, 1)
        log_psi = r[start - 1, 0]
        for t in six.moves.range(start, self.input_length):
            r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
            r[t, 1] = (
                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
            )
            log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])

        # get P(...eos|X) that ends with the prefix itself
        eos_pos = self.xp.where(cs == self.eos)[0]
        if len(eos_pos) > 0:
            log_psi[eos_pos] = r_sum[-1]  # log(r_T^n(g) + r_T^b(g))

        # exclude blank probs
        blank_pos = self.xp.where(cs == self.blank)[0]
        if len(blank_pos) > 0:
            log_psi[blank_pos] = self.logzero

        # return the log prefix probability and CTC states, where the label axis
        # of the CTC states is moved to the first axis to slice it easily
        return log_psi, self.xp.rollaxis(r, 2)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class CTCPrefixScoreTH(object):
    """Batch processing of CTCPrefixScore.

    Based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the label probabilities for multiple
    hypotheses simultaneously.
    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
    """

    def __init__(self, x, xlens, blank, eos, margin=0):
        """Construct CTC prefix scorer

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        :param torch.Tensor xlens: input lengths (B,)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param int margin: margin parameter for windowing (0 means no windowing)
        """
        # In the comment lines,
        # we assume T: input_length, B: batch size, W: beam width, O: output dim
        # Finite stand-in for log(0); avoids -inf arithmetic issues.
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.batch = x.size(0)
        self.input_length = x.size(1)
        self.odim = x.size(2)
        self.dtype = x.dtype
        self.device = (
            torch.device("cuda:%d" % x.get_device())
            if x.is_cuda
            else torch.device("cpu")
        )
        # Pad the rest of posteriors in the batch
        # NOTE: this mutates the caller's tensor `x` in place.
        # TODO(takaaki-hori): need a better way without for-loops
        for i, l in enumerate(xlens):
            if l < self.input_length:
                x[i, l:, :] = self.logzero
                x[i, l:, blank] = 0
        # Reshape input x
        xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
        # xb: blank posterior broadcast over the output dimension.
        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
        self.x = torch.stack([xn, xb])  # (2, T, B, O)
        self.end_frames = torch.as_tensor(xlens) - 1

        # Setup CTC windowing
        self.margin = margin
        if margin > 0:
            self.frame_ids = torch.arange(
                self.input_length, dtype=self.dtype, device=self.device
            )
        # Base indices for index conversion
        self.idx_bh = None
        self.idx_b = torch.arange(self.batch, device=self.device)
        self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)

    def __call__(self, y, state, scoring_ids=None, att_w=None):
        """Compute CTC prefix scores for next labels

        :param list y: prefix label sequences
        :param tuple state: previous CTC state
        :param torch.Tensor scoring_ids: label ids to pre-select for scoring
            (BW, snum); None scores the full output vocabulary
        :param torch.Tensor att_w: attention weights to decide CTC window
        :return new_state, ctc_local_scores (BW, O)
        """
        output_length = len(y[0]) - 1  # ignore sos
        last_ids = [yi[-1] for yi in y]  # last output label ids
        n_bh = len(last_ids)  # batch * hyps
        n_hyps = n_bh // self.batch  # assuming each utterance has the same
        self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
        # prepare state info
        if state is None:
            # First call: only the all-blank path is reachable (cf.
            # CTCPrefixScore.initial_state, vectorized over the batch).
            r_prev = torch.full(
                (self.input_length, 2, self.batch, n_hyps),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, n_bh)
            s_prev = 0.0
            f_min_prev = 0
            f_max_prev = 1
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

        # select input dimensions for scoring
        if self.scoring_num > 0:
            # scoring_idmap maps a label id back to its column in the reduced
            # (snum-wide) scoring tensors; -1 marks labels not being scored.
            scoring_idmap = torch.full(
                (n_bh, self.odim), -1, dtype=torch.long, device=self.device
            )
            snum = self.scoring_num
            if self.idx_bh is None or n_bh > len(self.idx_bh):
                self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
                snum, device=self.device
            )
            scoring_idx = (
                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
            ).view(-1)
            x_ = torch.index_select(
                self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
            ).view(2, -1, n_bh, snum)
        else:
            scoring_ids = None
            scoring_idmap = None
            snum = self.odim
            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)

        # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
        # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
        r = torch.full(
            (self.input_length, 2, n_bh, snum),
            self.logzero,
            dtype=self.dtype,
            device=self.device,
        )
        if output_length == 0:
            r[0, 0] = x_[0, 0]

        r_sum = torch.logsumexp(r_prev, 1)
        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
        # CTC collapse rule: when a candidate repeats the hypothesis' last
        # label, only the blank path of the previous state may precede it.
        if scoring_ids is not None:
            for idx in range(n_bh):
                pos = scoring_idmap[idx, last_ids[idx]]
                if pos >= 0:
                    log_phi[:, idx, pos] = r_prev[:, 1, idx]
        else:
            for idx in range(n_bh):
                log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]

        # decide start and end frames based on attention weights
        if att_w is not None and self.margin > 0:
            f_arg = torch.matmul(att_w, self.frame_ids)
            f_min = max(int(f_arg.min().cpu()), f_min_prev)
            f_max = max(int(f_arg.max().cpu()), f_max_prev)
            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
            end = min(f_max + self.margin, self.input_length)
        else:
            f_min = f_max = 0
            start = max(output_length, 1)
            end = self.input_length

        # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
        for t in range(start, end):
            rp = r[t - 1]
            # rr stacks the summands for the non-blank (rp[0], log_phi) and
            # blank (rp[0], rp[1]) recursions so one logsumexp handles both.
            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
                2, 2, n_bh, snum
            )
            r[t] = torch.logsumexp(rr, 1) + x_[:, t]

        # compute log prefix probabilities log(psi)
        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
        if scoring_ids is not None:
            # Scatter the reduced scores back into the full (BW, O) layout;
            # unscored labels keep logzero.
            log_psi = torch.full(
                (n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
            )
            log_psi_ = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )
            for si in range(n_bh):
                log_psi[si, scoring_ids[si]] = log_psi_[si]
        else:
            log_psi = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )

        # eos score is the total probability of the prefix itself at its
        # utterance's last frame.
        for si in range(n_bh):
            log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]

        # exclude blank probs
        log_psi[:, self.blank] = self.logzero

        return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)

    def index_select_state(self, state, best_ids):
        """Select CTC states according to best ids

        :param state : CTC state
        :param best_ids : index numbers selected by beam pruning (B, W)
        :return selected_state
        """
        r, s, f_min, f_max, scoring_idmap = state
        # convert ids to BHO space
        n_bh = len(s)
        n_hyps = n_bh // self.batch
        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
        # select hypothesis scores
        s_new = torch.index_select(s.view(-1), 0, vidx)
        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
        # convert ids to BHS space (S: scoring_num)
        if scoring_idmap is not None:
            snum = self.scoring_num
            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
                -1
            )
            label_ids = torch.fmod(best_ids, self.odim).view(-1)
            score_idx = scoring_idmap[hyp_idx, label_ids]
            # Unscored labels (-1) fall back to column 0; their forward probs
            # are logzero so this does not affect results.
            score_idx[score_idx == -1] = 0
            vidx = score_idx + hyp_idx * snum
        else:
            snum = self.odim
        # select forward probabilities
        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
            -1, 2, n_bh
        )
        return r_new, s_new, f_min, f_max

    def extend_prob(self, x):
        """Extend CTC prob.

        Used in streaming decoding when new posterior frames arrive; grows
        self.x along the time axis, keeping previously seen frames.

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        """

        if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
            # Pad the rest of posteriors in the batch
            # NOTE(review): xlens is hard-coded to a single length, so this
            # path presumably assumes batch size 1 — confirm with callers.
            # TODO(takaaki-hori): need a better way without for-loops
            xlens = [x.size(1)]
            for i, l in enumerate(xlens):
                if l < self.input_length:
                    x[i, l:, :] = self.logzero
                    x[i, l:, self.blank] = 0
            tmp_x = self.x
            xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
            self.x = torch.stack([xn, xb])  # (2, T, B, O)
            # Keep the previously computed frames untouched.
            self.x[:, : tmp_x.shape[1], :, :] = tmp_x
            self.input_length = x.size(1)
            self.end_frames = torch.as_tensor(xlens) - 1

    def extend_state(self, state):
        """Compute CTC prefix state.

        Extends a previous state to the (possibly grown) current input
        length by continuing the blank-path accumulation.

        :param state    : CTC state
        :return ctc_state
        """

        if state is None:
            # nothing to do
            return state
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

            r_prev_new = torch.full(
                (self.input_length, 2),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            start = max(r_prev.shape[0], 1)
            r_prev_new[0:start] = r_prev
            for t in six.moves.range(start, self.input_length):
                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]

            return r_prev_new, s_prev, f_min_prev, f_max_prev
|
Amphion/modules/wenet_extractor/squeezeformer/positionwise_feed_forward.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
|
| 2 |
+
|
| 3 |
+
# ## Citations
|
| 4 |
+
|
| 5 |
+
# ```bibtex
|
| 6 |
+
# @inproceedings{yao2021wenet,
|
| 7 |
+
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
|
| 8 |
+
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
|
| 9 |
+
# booktitle={Proc. Interspeech},
|
| 10 |
+
# year={2021},
|
| 11 |
+
# address={Brno, Czech Republic },
|
| 12 |
+
# organization={IEEE}
|
| 13 |
+
# }
|
| 14 |
+
|
| 15 |
+
# @article{zhang2022wenet,
|
| 16 |
+
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
|
| 17 |
+
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
|
| 18 |
+
# journal={arXiv preprint arXiv:2203.15455},
|
| 19 |
+
# year={2022}
|
| 20 |
+
# }
|
| 21 |
+
#
|
| 22 |
+
|
| 23 |
+
"""Positionwise feed forward layer definition."""
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    FeedForward is applied independently at each position of the sequence.
    The output dim is the same as the input dim.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function.
        adaptive_scale (bool): If True, a learnable per-channel scale/bias is
            applied to the input (Squeezeformer-style); if False the
            parameters are frozen at identity (scale=1, bias=0).
        init_weights (bool): If True, re-initialize both linear layers with
            a uniform distribution bounded by dim**-0.5.
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
        adaptive_scale: bool = False,
        init_weights: bool = False,
    ):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.idim = idim
        self.hidden_units = hidden_units
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.adaptive_scale = adaptive_scale
        # The parameters are registered unconditionally so the state-dict
        # layout is the same regardless of adaptive_scale; when disabled they
        # are frozen at identity and forward() skips them entirely.
        # (Removed the dead `self.ada_scale = None` / `self.ada_bias = None`
        # assignments that were immediately overwritten below.)
        self.ada_scale = torch.nn.Parameter(
            torch.ones([1, 1, idim]), requires_grad=adaptive_scale
        )
        self.ada_bias = torch.nn.Parameter(
            torch.zeros([1, 1, idim]), requires_grad=adaptive_scale
        )
        if init_weights:
            self.init_weights()

    def init_weights(self):
        """Uniformly re-initialize w_1/w_2 with bound fan_in**-0.5."""
        ffn1_max = self.idim**-0.5
        ffn2_max = self.hidden_units**-0.5
        torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max)
        torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max)
        torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max)
        torch.nn.init.uniform_(self.w_2.bias.data, -ffn2_max, ffn2_max)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        if self.adaptive_scale:
            # Per-channel affine rescaling of the input before the FFN.
            xs = self.ada_scale * xs + self.ada_bias
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
Amphion/modules/wenet_extractor/transformer/decoder_layer.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
|
| 2 |
+
|
| 3 |
+
# ## Citations
|
| 4 |
+
|
| 5 |
+
# ```bibtex
|
| 6 |
+
# @inproceedings{yao2021wenet,
|
| 7 |
+
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
|
| 8 |
+
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
|
| 9 |
+
# booktitle={Proc. Interspeech},
|
| 10 |
+
# year={2021},
|
| 11 |
+
# address={Brno, Czech Republic },
|
| 12 |
+
# organization={IEEE}
|
| 13 |
+
# }
|
| 14 |
+
|
| 15 |
+
# @article{zhang2022wenet,
|
| 16 |
+
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
|
| 17 |
+
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
|
| 18 |
+
# journal={arXiv preprint arXiv:2203.15455},
|
| 19 |
+
# year={2022}
|
| 20 |
+
# }
|
| 21 |
+
#
|
| 22 |
+
|
| 23 |
+
"""Decoder self-attention layer definition."""
|
| 24 |
+
from typing import Optional, Tuple
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
from torch import nn
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Inter-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
            If `None` is passed, Inter-attention is not used, such as
            CIF, GPT, and other decoder only model.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: to use layer_norm after each sub-block.
    """

    def __init__(
        self,
        size: int,
        self_attn: nn.Module,
        src_attn: Optional[nn.Module],
        feed_forward: nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
    ):
        """Construct an DecoderLayer object."""
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-5)
        self.norm2 = nn.LayerNorm(size, eps=1e-5)
        self.norm3 = nn.LayerNorm(size, eps=1e-5)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before

    def forward(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor
                (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory
                (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask
                (#batch, maxlen_in).
            cache (torch.Tensor): cached tensors.
                (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).

        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            # BUGFIX: the assert message was a plain string (missing the `f`
            # prefix), so shapes were never interpolated into the diagnostic.
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = tgt_mask[:, -1:, :]

        # Self-attention sub-block with residual connection.
        x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
        if not self.normalize_before:
            x = self.norm1(x)

        # Optional cross-attention over the encoder memory.
        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm2(x)
            x = residual + self.dropout(
                self.src_attn(x, memory, memory, memory_mask)[0]
            )
            if not self.normalize_before:
                x = self.norm2(x)

        # Position-wise feed-forward sub-block.
        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            # Re-attach the cached frames so the output covers the full length.
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
|
Amphion/modules/wenet_extractor/transformer/subsampling.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
|
| 2 |
+
|
| 3 |
+
# ## Citations
|
| 4 |
+
|
| 5 |
+
# ```bibtex
|
| 6 |
+
# @inproceedings{yao2021wenet,
|
| 7 |
+
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
|
| 8 |
+
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
|
| 9 |
+
# booktitle={Proc. Interspeech},
|
| 10 |
+
# year={2021},
|
| 11 |
+
# address={Brno, Czech Republic },
|
| 12 |
+
# organization={IEEE}
|
| 13 |
+
# }
|
| 14 |
+
|
| 15 |
+
# @article{zhang2022wenet,
|
| 16 |
+
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
|
| 17 |
+
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
|
| 18 |
+
# journal={arXiv preprint arXiv:2203.15455},
|
| 19 |
+
# year={2022}
|
| 20 |
+
# }
|
| 21 |
+
#
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
"""Subsampling layer definition."""
|
| 25 |
+
|
| 26 |
+
from typing import Tuple, Union
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BaseSubsampling(torch.nn.Module):
    """Common base class for subsampling front-ends.

    Provides default bookkeeping attributes and delegates positional-encoding
    lookups to ``self.pos_enc``, which every concrete subclass is expected to
    assign in its constructor.
    """

    def __init__(self):
        super().__init__()
        # Defaults describe a front-end that performs no subsampling.
        self.subsampling_rate = 1
        self.right_context = 0

    def position_encoding(
        self, offset: Union[int, torch.Tensor], size: int
    ) -> torch.Tensor:
        """Return positional encodings from the attached ``pos_enc`` module."""
        return self.pos_enc.position_encoding(offset, size)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class LinearNoSubsampling(BaseSubsampling):
    """Linearly project the input without any temporal subsampling.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc_class (torch.nn.Module): Positional-encoding module; called
            as ``pos_enc_class(x, offset)`` and expected to return
            ``(x, pos_emb)``.
    """

    def __init__(
        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
    ):
        """Construct an linear object."""
        super().__init__()
        projection = torch.nn.Linear(idim, odim)
        norm = torch.nn.LayerNorm(odim, eps=1e-5)
        drop = torch.nn.Dropout(dropout_rate)
        self.out = torch.nn.Sequential(projection, norm, drop)
        self.pos_enc = pos_enc_class
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project the input frames; time resolution is unchanged.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: linear output tensor (#batch, time, odim).
            torch.Tensor: positional encoding.
            torch.Tensor: the input mask, unchanged (#batch, 1, time).
        """
        projected = self.out(x)
        encoded, pos_emb = self.pos_enc(projected, offset)
        return encoded, pos_emb, x_mask
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class Conv2dSubsampling4(BaseSubsampling):
    """Convolutional 2D subsampling to 1/4 of the input length.

    Two stride-2 3x3 convolutions, each followed by ReLU, then a linear
    projection of the flattened channel/frequency axes.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (unused here; kept for a uniform
            constructor signature across subsampling classes).
    """

    def __init__(
        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
    ):
        """Construct an Conv2dSubsampling4 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # Frequency axis shrinks the same way as time: ((idim-1)//2 - 1)//2.
        freq_after_conv = ((idim - 1) // 2 - 1) // 2
        self.out = torch.nn.Sequential(torch.nn.Linear(odim * freq_after_conv, odim))
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 4
        # 6 = (3 - 1) * 1 + (3 - 1) * 2
        self.right_context = 6

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x by a factor of 4 in time.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            torch.Tensor: positional encoding.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 4.
        """
        feats = x.unsqueeze(1)  # (b, c=1, t, f): add a channel axis for Conv2d
        feats = self.conv(feats)
        b, c, t, f = feats.size()
        flattened = feats.transpose(1, 2).contiguous().view(b, t, c * f)
        out = self.out(flattened)
        out, pos_emb = self.pos_enc(out, offset)
        # Each stride-2 conv halves time; drop the same frames from the mask.
        subsampled_mask = x_mask[:, :, 2::2][:, :, 2::2]
        return out, pos_emb, subsampled_mask
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class Conv2dSubsampling6(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/6 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc_class (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(
        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
    ):
        """Construct an Conv2dSubsampling6 object."""
        super().__init__()
        # A stride-2 3x3 conv followed by a stride-3 5x5 conv: 2 * 3 = 6x
        # time subsampling overall.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim)
        self.pos_enc = pos_enc_class
        # 10 = (3 - 1) * 1 + (5 - 1) * 2
        self.subsampling_rate = 6
        self.right_context = 10

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x to 1/6 of its time length.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset: Position offset forwarded to the positional encoding.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 6.
            torch.Tensor: Positional encoding.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 6.
        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        batch, channels, frames, freq = feats.size()
        projected = self.linear(
            feats.transpose(1, 2).contiguous().view(batch, frames, channels * freq)
        )
        out, pos_emb = self.pos_enc(projected, offset)
        # Mirror the conv strides on the mask: stride 2 then stride 3.
        return out, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class Conv2dSubsampling8(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/8 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc_class (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(
        self, idim: int, odim: int, dropout_rate: float, pos_enc_class: torch.nn.Module
    ):
        """Construct an Conv2dSubsampling8 object."""
        super().__init__()
        # Three stride-2 3x3 convolutions: 2 * 2 * 2 = 8x time subsampling.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.linear = torch.nn.Linear(
            odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim
        )
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 8
        # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
        self.right_context = 14

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x to 1/8 of its time length.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
            offset: Position offset forwarded to the positional encoding.

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 8.
            torch.Tensor: Positional encoding.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 8.
        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        batch, channels, frames, freq = feats.size()
        projected = self.linear(
            feats.transpose(1, 2).contiguous().view(batch, frames, channels * freq)
        )
        out, pos_emb = self.pos_enc(projected, offset)
        # Mirror the three stride-2 convolutions on the mask.
        return out, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
|
Amphion/modules/wenet_extractor/utils/__init__.py
ADDED
|
File without changes
|
Amphion/preprocessors/cdmusiceval.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from glob import glob
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import torchaudio
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
|
| 13 |
+
from utils.util import has_existed, remove_and_create
|
| 14 |
+
from utils.audio_slicer import split_utterances_from_audio
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def split_to_utterances(input_dir, output_dir):
    """Slice every raw audio file under input_dir into short utterances.

    File names are expected to carry "<song>-<singer>" in their third
    underscore-separated field; utterances are written to
    output_dir/<singer>/<song>/.
    """
    print("Splitting to utterances for {}...".format(input_dir))

    for wav_file in tqdm(sorted(glob("*", root_dir=input_dir))):
        # Singer name, Song name (encoded in the file name).
        song_name, singer_name = wav_file.split("_")[2].split("-")
        save_dir = os.path.join(output_dir, singer_name, song_name)

        # Cap each slice at 10 seconds.
        split_utterances_from_audio(
            os.path.join(input_dir, wav_file), save_dir, max_duration_of_utterance=10
        )
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _main(dataset_path):
    """Rebuild the "utterances" directory by slicing the raw vocal audio."""
    utterance_dir = os.path.join(dataset_path, "utterances")
    # Start from a clean directory so stale slices never survive a re-run.
    remove_and_create(utterance_dir)
    split_to_utterances(os.path.join(dataset_path, "vocal"), utterance_dir)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def statistics(utterance_dir):
    """Scan utterance_dir (<singer>/<song>/<uid>.wav) and summarize it.

    Args:
        utterance_dir: Root directory holding the split utterances.

    Returns:
        tuple: (singers2songs, unique_singers) where singers2songs maps
        singer -> song -> list of utterance ids, and unique_singers is the
        sorted list of distinct singer names.
    """
    singers = []
    songs = []
    singers2songs = defaultdict(lambda: defaultdict(list))

    singer_infos = glob(utterance_dir + "/*")

    for singer_info in singer_infos:
        # Fix: use os.path.basename instead of split("/") so the code also
        # works with OS-native path separators (e.g. "\\" on Windows).
        singer = os.path.basename(singer_info)

        song_infos = glob(singer_info + "/*")

        for song_info in song_infos:
            song = os.path.basename(song_info)

            singers.append(singer)
            songs.append(song)

            utts = glob(song_info + "/*.wav")

            for utt in utts:
                # uid is the file name up to the first dot (original behavior).
                uid = os.path.basename(utt).split(".")[0]
                singers2songs[singer][song].append(uid)

    unique_singers = list(set(singers))
    unique_songs = list(set(songs))
    unique_singers.sort()
    unique_songs.sort()

    print(
        "Statistics: {} singers, {} utterances ({} unique songs)".format(
            len(unique_singers), len(songs), len(unique_songs)
        )
    )
    print("Singers: \n{}".format("\t".join(unique_singers)))
    return singers2songs, unique_singers
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def main(output_path, dataset_path):
    """Prepare CD Music Eval metadata: train/test JSON, singer LUT, utt2singer.

    Every utterance is placed in the test split; the train split is
    intentionally left empty (all samples serve as evaluation songs).

    Args:
        output_path: Root under which "cdmusiceval/" output files are written.
        dataset_path: Dataset root containing "vocal/" (and "utterances/").
    """
    print("-" * 10)
    print("Preparing samples for CD Music Eval...\n")

    if not os.path.exists(os.path.join(dataset_path, "utterances")):
        # Fix: typo "Spliting" -> "Splitting" in the progress message.
        print("Splitting into utterances...\n")
        _main(dataset_path)

    save_dir = os.path.join(output_path, "cdmusiceval")
    os.makedirs(save_dir, exist_ok=True)
    train_output_file = os.path.join(save_dir, "train.json")
    test_output_file = os.path.join(save_dir, "test.json")
    singer_dict_file = os.path.join(save_dir, "singers.json")
    utt2singer_file = os.path.join(save_dir, "utt2singer")
    # Skip everything if all outputs already exist.
    if (
        has_existed(train_output_file)
        and has_existed(test_output_file)
        and has_existed(singer_dict_file)
        and has_existed(utt2singer_file)
    ):
        return

    # Load
    utt_path = os.path.join(dataset_path, "utterances")
    singers2songs, unique_singers = statistics(utt_path)

    # We select songs of standard samples as test songs
    train = []
    test = []

    train_index_count = 0
    test_index_count = 0

    train_total_duration = 0
    test_total_duration = 0

    # Fix: the original opened utt2singer and never closed it; a context
    # manager guarantees the handle is flushed and closed.
    with open(utt2singer_file, "w") as utt2singer:
        for singer, songs in tqdm(singers2songs.items()):
            song_names = list(songs.keys())

            for chosen_song in song_names:
                for chosen_uid in songs[chosen_song]:
                    res = {
                        "Dataset": "cdmusiceval",
                        "Singer": singer,
                        "Uid": "{}_{}_{}".format(singer, chosen_song, chosen_uid),
                    }
                    res["Path"] = "{}/{}/{}.wav".format(
                        singer, chosen_song, chosen_uid
                    )
                    res["Path"] = os.path.join(utt_path, res["Path"])
                    assert os.path.exists(res["Path"])

                    waveform, sample_rate = torchaudio.load(res["Path"])
                    duration = waveform.size(-1) / sample_rate
                    res["Duration"] = duration

                    # Skip (near-)empty clips.
                    if duration <= 1e-8:
                        continue

                    res["index"] = test_index_count
                    test_total_duration += duration
                    test.append(res)
                    test_index_count += 1

                    utt2singer.write("{}\t{}\n".format(res["Uid"], res["Singer"]))

    print("#Train = {}, #Test = {}".format(len(train), len(test)))
    print(
        "#Train hours= {}, #Test hours= {}".format(
            train_total_duration / 3600, test_total_duration / 3600
        )
    )

    # Save train.json and test.json
    with open(train_output_file, "w") as f:
        json.dump(train, f, indent=4, ensure_ascii=False)
    with open(test_output_file, "w") as f:
        json.dump(test, f, indent=4, ensure_ascii=False)

    # Save singers.json
    singer_lut = {name: i for i, name in enumerate(unique_singers)}
    with open(singer_dict_file, "w") as f:
        json.dump(singer_lut, f, indent=4, ensure_ascii=False)
|
Amphion/utils/data_utils.py
ADDED
|
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from scipy.interpolate import interp1d
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from sklearn.preprocessing import StandardScaler
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def intersperse(lst, item):
    """Insert *item* between any two consecutive elements of *lst*, including
    the beginning and end of the list.

    Fix: the original docstring example called ``intersperse(0, [...])`` with
    the arguments swapped relative to the signature; corrected below.

    Example:
        >>> intersperse([1, 74, 5, 31], 0)
        [0, 1, 0, 74, 0, 5, 0, 31, 0]

    Args:
        lst: Source sequence.
        item: Separator value to interleave.

    Returns:
        list: New list of length ``len(lst) * 2 + 1``.
    """
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load_content_feature_path(meta_data, processed_dir, feat_dir):
    """Map each utterance key ("<Dataset>_<Uid>") to its .npy feature path.

    Args:
        meta_data: Iterable of dicts with "Dataset" and "Uid" keys.
        processed_dir: Root directory of preprocessed features.
        feat_dir: Feature sub-directory name under each dataset.

    Returns:
        dict: utterance key -> absolute-ish feature file path.
    """
    return {
        info["Dataset"] + "_" + info["Uid"]: os.path.join(
            processed_dir, info["Dataset"], feat_dir, f'{info["Uid"]}.npy'
        )
        for info in meta_data
    }
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def load_source_content_feature_path(meta_data, feat_dir):
    """Map each utterance id directly to its .npy feature path under feat_dir.

    Args:
        meta_data: Iterable of utterance ids.
        feat_dir: Directory holding one "<utt>.npy" file per utterance.

    Returns:
        dict: utterance id -> feature file path.
    """
    return {utt: os.path.join(feat_dir, f"{utt}.npy") for utt in meta_data}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_spk_map(spk2id_path, utt2spk_path):
    """Load the speaker-id table and the utterance-to-speaker mapping.

    Args:
        spk2id_path: JSON file mapping speaker name -> integer id.
        utt2spk_path: Text file with one tab-separated "utt<TAB>spk" per line.

    Returns:
        tuple: (spk2id dict, utt2spk dict).
    """
    with open(spk2id_path, "r") as spk2id_file:
        spk2id = json.load(spk2id_file)

    utt2spk = {}
    with open(utt2spk_path, encoding="utf-8") as f:
        for line in f:
            utt, spk = line.strip().split("\t")
            utt2spk[utt] = spk

    return spk2id, utt2spk
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_target_f0_median(f0_dir):
    """Compute the median of all voiced (non-zero) F0 values found in f0_dir.

    Args:
        f0_dir: Directory containing per-utterance F0 arrays saved as .npy.

    Returns:
        float: Median over every non-zero F0 value across all files.
    """
    all_f0 = []
    for fname in os.listdir(f0_dir):
        # Only .npy files count; anything else in the directory is ignored.
        if not fname.endswith(".npy"):
            continue
        all_f0 += np.load(os.path.join(f0_dir, fname)).tolist()

    all_f0 = np.array(all_f0)
    # Zeros mark unvoiced frames and are excluded from the median.
    voiced = np.where(all_f0 != 0)
    return np.median(all_f0[voiced])
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_conversion_f0_factor(source_f0, target_median, source_median=None):
    """Align the median between source f0 and target f0.

    Note: Here we use multiplication, whose factor is target_median/source_median

    Reference: Frequency and pitch interval
    http://blog.ccyg.studio/article/be12c2ee-d47c-4098-9782-ca76da3035e4/
    """
    if source_median is None:
        # Derive the source median from the voiced (non-zero) frames only.
        source_median = np.median(source_f0[np.where(source_f0 != 0)])
    return source_median, target_median / source_median
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def transpose_key(frame_pitch, trans_key):
    """Shift pitch by trans_key semitones (12 semitones doubles the frequency)."""
    # Transpose by user's argument
    print("Transpose key = {} ...\n".format(trans_key))

    return frame_pitch * 2 ** (trans_key / 12)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def pitch_shift_to_target(frame_pitch, target_pitch_median, source_pitch_median=None):
    """Rescale frame_pitch so its voiced median matches target_pitch_median."""
    # Loading F0 Base (median) and shift
    source_pitch_median, factor = get_conversion_f0_factor(
        frame_pitch, target_pitch_median, source_pitch_median
    )
    print(
        "Auto transposing: source f0 median = {:.1f}, target f0 median = {:.1f}, factor = {:.2f}".format(
            source_pitch_median, target_pitch_median, factor
        )
    )
    return frame_pitch * factor
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def load_frame_pitch(
    meta_data,
    processed_dir,
    pitch_dir,
    use_log_scale=False,
    return_norm=False,
    interoperate=False,
    utt2spk=None,
):
    """Load frame-level pitch (.npy) for every utterance in meta_data.

    Args:
        meta_data: List of dicts with at least "Dataset" and "Uid" keys.
        processed_dir: Root directory of the preprocessed features.
        pitch_dir: Sub-directory name holding the pitch .npy files.
        use_log_scale: If True, take the log of voiced (non-zero) values.
        return_norm: If True, z-normalize pitch with the fitted mean/std.
        interoperate: Unused in this function (kept for signature parity with
            load_phone_pitch).
        utt2spk: Optional utt -> speaker map; when given, statistics are
            fitted per speaker instead of globally.

    Returns:
        tuple: (utt2pitch, utt2uv, pitch_statistic) where utt2uv marks voiced
        frames per utterance, and pitch_statistic is a {"mean", "std"} dict
        (global case) or a list of per-speaker dicts.
    """
    utt2pitch = {}
    utt2uv = {}
    if utt2spk is None:
        # Global branch: one scaler fitted incrementally over all utterances.
        pitch_scaler = StandardScaler()
        for utt_info in meta_data:
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            pitch_path = os.path.join(
                processed_dir, utt_info["Dataset"], pitch_dir, f'{utt_info["Uid"]}.npy'
            )
            pitch = np.load(pitch_path)
            assert len(pitch) > 0
            # Zero frames are unvoiced; keep the mask for callers.
            uv = pitch != 0
            utt2uv[utt] = uv
            if use_log_scale:
                nonzero_idxes = np.where(pitch != 0)[0]
                pitch[nonzero_idxes] = np.log(pitch[nonzero_idxes])
            utt2pitch[utt] = pitch
            pitch_scaler.partial_fit(pitch.reshape(-1, 1))

        mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0]
        if return_norm:
            # Second pass: normalize every utterance with the global stats.
            for utt_info in meta_data:
                utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
                pitch = utt2pitch[utt]
                normalized_pitch = (pitch - mean) / std
                utt2pitch[utt] = normalized_pitch
        pitch_statistic = {"mean": mean, "std": std}
    else:
        # Per-speaker branch: group utterances by speaker, then fit one
        # scaler per speaker.
        spk2utt = {}
        pitch_statistic = []
        for utt_info in meta_data:
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            if not utt2spk[utt] in spk2utt:
                spk2utt[utt2spk[utt]] = []
            spk2utt[utt2spk[utt]].append(utt)

        for spk in spk2utt:
            pitch_scaler = StandardScaler()
            for utt in spk2utt[spk]:
                # Utterance keys are "<Dataset>_<Uid>"; Uid may itself
                # contain underscores, hence the re-join.
                dataset = utt.split("_")[0]
                uid = "_".join(utt.split("_")[1:])
                pitch_path = os.path.join(
                    processed_dir, dataset, pitch_dir, f"{uid}.npy"
                )
                pitch = np.load(pitch_path)
                assert len(pitch) > 0
                uv = pitch != 0
                utt2uv[utt] = uv
                if use_log_scale:
                    nonzero_idxes = np.where(pitch != 0)[0]
                    pitch[nonzero_idxes] = np.log(pitch[nonzero_idxes])
                utt2pitch[utt] = pitch
                pitch_scaler.partial_fit(pitch.reshape(-1, 1))

            mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0]
            if return_norm:
                for utt in spk2utt[spk]:
                    pitch = utt2pitch[utt]
                    normalized_pitch = (pitch - mean) / std
                    utt2pitch[utt] = normalized_pitch
            pitch_statistic.append({"spk": spk, "mean": mean, "std": std})

    return utt2pitch, utt2uv, pitch_statistic
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# discard
|
| 186 |
+
def load_phone_pitch(
    meta_data,
    processed_dir,
    pitch_dir,
    utt2dur,
    use_log_scale=False,
    return_norm=False,
    interoperate=True,
    utt2spk=None,
):
    """Load phone-level pitch by averaging frame F0 over phone durations.

    Marked "# discard" upstream — kept for backward compatibility only.

    Args:
        meta_data: List of dicts with "Dataset" and "Uid" keys.
        processed_dir: Root directory of the preprocessed features.
        pitch_dir: Sub-directory name holding frame-level pitch .npy files.
        utt2dur: utt key -> per-phone durations, consumed by
            phone_average_pitch.
        use_log_scale: If True, take the log of voiced (non-zero) values.
        return_norm: If True, z-normalize with the fitted mean/std and track
            the min/max of the normalized values.
        interoperate: Forwarded to phone_average_pitch (interpolate over
            unvoiced frames before averaging).
        utt2spk: Optional utt -> speaker map for per-speaker statistics.

    Returns:
        tuple: (utt2pitch, utt2uv, pitch_statistic); statistic dicts carry
        "mean"/"std"/"min_value"/"max_value" (plus "spk" per speaker).
    """
    print("Load Phone Pitch")
    utt2pitch = {}
    utt2uv = {}
    if utt2spk is None:
        # Global branch: one scaler fitted over all utterances.
        pitch_scaler = StandardScaler()
        for utt_info in tqdm(meta_data):
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            pitch_path = os.path.join(
                processed_dir, utt_info["Dataset"], pitch_dir, f'{utt_info["Uid"]}.npy'
            )
            frame_pitch = np.load(pitch_path)
            assert len(frame_pitch) > 0
            # Voiced/unvoiced mask is taken at frame level, before averaging.
            uv = frame_pitch != 0
            utt2uv[utt] = uv
            phone_pitch = phone_average_pitch(frame_pitch, utt2dur[utt], interoperate)
            if use_log_scale:
                nonzero_idxes = np.where(phone_pitch != 0)[0]
                phone_pitch[nonzero_idxes] = np.log(phone_pitch[nonzero_idxes])
            utt2pitch[utt] = phone_pitch
            # Outliers are excluded from the statistics only, not from the
            # stored pitch values.
            pitch_scaler.partial_fit(remove_outlier(phone_pitch).reshape(-1, 1))

        mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0]
        # Start from the opposite extremes so any real value replaces them.
        max_value = np.finfo(np.float64).min
        min_value = np.finfo(np.float64).max
        if return_norm:
            for utt_info in meta_data:
                utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
                pitch = utt2pitch[utt]
                normalized_pitch = (pitch - mean) / std
                max_value = max(max_value, max(normalized_pitch))
                min_value = min(min_value, min(normalized_pitch))
                utt2pitch[utt] = normalized_pitch
                # NOTE(review): this path is computed but never used or
                # written to in this function — confirm whether a save call
                # went missing.
                phone_normalized_pitch_path = os.path.join(
                    processed_dir,
                    utt_info["Dataset"],
                    "phone_level_" + pitch_dir,
                    f'{utt_info["Uid"]}.npy',
                )
        pitch_statistic = {
            "mean": mean,
            "std": std,
            "min_value": min_value,
            "max_value": max_value,
        }
    else:
        # Per-speaker branch: group utterances by speaker, then fit one
        # scaler per speaker.
        spk2utt = {}
        pitch_statistic = []
        for utt_info in tqdm(meta_data):
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            if not utt2spk[utt] in spk2utt:
                spk2utt[utt2spk[utt]] = []
            spk2utt[utt2spk[utt]].append(utt)

        for spk in spk2utt:
            pitch_scaler = StandardScaler()
            for utt in spk2utt[spk]:
                # Keys are "<Dataset>_<Uid>"; Uid may contain underscores.
                dataset = utt.split("_")[0]
                uid = "_".join(utt.split("_")[1:])
                pitch_path = os.path.join(
                    processed_dir, dataset, pitch_dir, f"{uid}.npy"
                )
                frame_pitch = np.load(pitch_path)
                assert len(frame_pitch) > 0
                uv = frame_pitch != 0
                utt2uv[utt] = uv
                phone_pitch = phone_average_pitch(
                    frame_pitch, utt2dur[utt], interoperate
                )
                if use_log_scale:
                    nonzero_idxes = np.where(phone_pitch != 0)[0]
                    phone_pitch[nonzero_idxes] = np.log(phone_pitch[nonzero_idxes])
                utt2pitch[utt] = phone_pitch
                pitch_scaler.partial_fit(remove_outlier(phone_pitch).reshape(-1, 1))

            mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0]
            max_value = np.finfo(np.float64).min
            min_value = np.finfo(np.float64).max

            if return_norm:
                for utt in spk2utt[spk]:
                    pitch = utt2pitch[utt]
                    normalized_pitch = (pitch - mean) / std
                    max_value = max(max_value, max(normalized_pitch))
                    min_value = min(min_value, min(normalized_pitch))
                    utt2pitch[utt] = normalized_pitch
            pitch_statistic.append(
                {
                    "spk": spk,
                    "mean": mean,
                    "std": std,
                    "min_value": min_value,
                    "max_value": max_value,
                }
            )

    return utt2pitch, utt2uv, pitch_statistic
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def phone_average_pitch(pitch, dur, interoperate=False):
    """Collapse a frame-level pitch contour to one averaged value per phone.

    Args:
        pitch: Frame-level F0 array (zeros mark unvoiced frames).
        dur: Per-phone durations in frames.
        interoperate: If True, linearly interpolate over unvoiced frames
            before averaging, extending the edge voiced values outward.

    Returns:
        np.ndarray: One averaged pitch value per entry of dur.
    """
    if interoperate:
        # Fill unvoiced gaps by interpolating between voiced neighbours;
        # out-of-range positions take the first/last voiced value.
        voiced_ids = np.where(pitch != 0)[0]
        fill_fn = interp1d(
            voiced_ids,
            pitch[voiced_ids],
            fill_value=(pitch[voiced_ids[0]], pitch[voiced_ids[-1]]),
            bounds_error=False,
        )
        pitch = fill_fn(np.arange(0, len(pitch)))

    phone_pitch = np.zeros(len(dur))
    start = 0
    for idx, d in enumerate(dur):
        d = int(d)
        if d > 0 and start < len(pitch):
            # Average the frames belonging to this phone.
            phone_pitch[idx] = np.mean(pitch[start : start + d])
        else:
            # Zero-length phones, or phones past the end of the contour.
            phone_pitch[idx] = 0
        start += d
    return phone_pitch
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def load_energy(
    meta_data,
    processed_dir,
    energy_dir,
    use_log_scale=False,
    return_norm=False,
    utt2spk=None,
):
    """Load frame-level energy features for every utterance in meta_data.

    Args:
        meta_data: List of dicts with "Dataset", "Uid" (and "Singer") keys.
        processed_dir: Root directory of the preprocessed features.
        energy_dir: Sub-directory name holding the energy .npy files.
        use_log_scale: If True, take the log of non-zero energy values.
        return_norm: If True, z-normalize the energies.
        utt2spk: Optional utt -> speaker map; when given, statistics are
            fitted per speaker with a StandardScaler instead of read from
            the precomputed statistics.json.

    Returns:
        tuple: (utt2energy, energy_statistic); energy_statistic is a
        {"mean", "std"} dict (global case) or a list of per-speaker dicts.
    """
    utt2energy = {}
    if utt2spk is None:
        for utt_info in meta_data:
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            energy_path = os.path.join(
                processed_dir, utt_info["Dataset"], energy_dir, f'{utt_info["Uid"]}.npy'
            )
            # Missing energy files are skipped silently (best-effort load).
            if not os.path.exists(energy_path):
                continue
            energy = np.load(energy_path)
            assert len(energy) > 0

            if use_log_scale:
                nonzero_idxes = np.where(energy != 0)[0]
                energy[nonzero_idxes] = np.log(energy[nonzero_idxes])
            utt2energy[utt] = energy

        if return_norm:
            # NOTE(review): `utt_info` here is whatever the loop above left
            # behind (the last metadata entry) — statistics.json is resolved
            # from that single dataset. Also, the std is hard-coded to the
            # "LJSpeech_LJSpeech" entry while the mean is per-singer; confirm
            # this asymmetry is intentional.
            with open(
                os.path.join(
                    processed_dir, utt_info["Dataset"], energy_dir, "statistics.json"
                )
            ) as f:
                stats = json.load(f)
                mean, std = (
                    stats[utt_info["Dataset"] + "_" + utt_info["Singer"]][
                        "voiced_positions"
                    ]["mean"],
                    stats["LJSpeech_LJSpeech"]["voiced_positions"]["std"],
                )
            for utt in utt2energy.keys():
                energy = utt2energy[utt]
                normalized_energy = (energy - mean) / std
                utt2energy[utt] = normalized_energy

            # NOTE(review): with return_norm=False this branch never defines
            # mean/std, so the statistic below would raise NameError — the
            # original indentation is ambiguous here; verify against callers.
            energy_statistic = {"mean": mean, "std": std}
    else:
        # Per-speaker branch: group utterances by speaker, then fit one
        # StandardScaler per speaker.
        spk2utt = {}
        energy_statistic = []
        for utt_info in meta_data:
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            if not utt2spk[utt] in spk2utt:
                spk2utt[utt2spk[utt]] = []
            spk2utt[utt2spk[utt]].append(utt)

        for spk in spk2utt:
            energy_scaler = StandardScaler()
            for utt in spk2utt[spk]:
                # Keys are "<Dataset>_<Uid>"; Uid may contain underscores.
                dataset = utt.split("_")[0]
                uid = "_".join(utt.split("_")[1:])
                energy_path = os.path.join(
                    processed_dir, dataset, energy_dir, f"{uid}.npy"
                )
                if not os.path.exists(energy_path):
                    continue
                frame_energy = np.load(energy_path)
                assert len(frame_energy) > 0

                if use_log_scale:
                    nonzero_idxes = np.where(frame_energy != 0)[0]
                    frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes])
                utt2energy[utt] = frame_energy
                energy_scaler.partial_fit(frame_energy.reshape(-1, 1))

            mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0]
            if return_norm:
                for utt in spk2utt[spk]:
                    energy = utt2energy[utt]
                    normalized_energy = (energy - mean) / std
                    utt2energy[utt] = normalized_energy
            energy_statistic.append({"spk": spk, "mean": mean, "std": std})

    return utt2energy, energy_statistic
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def load_frame_energy(
    meta_data,
    processed_dir,
    energy_dir,
    use_log_scale=False,
    return_norm=False,
    interoperate=False,
    utt2spk=None,
):
    """Load per-utterance frame-level energy arrays from ``.npy`` files.

    Args:
        meta_data: iterable of dicts with at least "Dataset" and "Uid" keys.
        processed_dir: root directory of preprocessed features.
        energy_dir: sub-directory (under each dataset) holding "<uid>.npy".
        use_log_scale: if True, take log of the nonzero energy values in place.
        return_norm: if True, z-normalize each utterance's energy with the
            computed mean/std before returning.
        interoperate: accepted but never used in this function.
            # NOTE(review): likely a misspelling of "interpolate" — confirm with callers.
        utt2spk: optional mapping "dataset_uid" -> speaker. When given,
            statistics are computed per speaker instead of globally.

    Returns:
        (utt2energy, energy_statistic):
            utt2energy: dict "dataset_uid" -> np.ndarray of frame energies.
            energy_statistic: {"mean", "std"} dict in the global case, or a
            list of {"spk", "mean", "std"} dicts in the per-speaker case.
    """
    utt2energy = {}
    if utt2spk is None:
        # Global statistics: one StandardScaler accumulated over all utterances.
        energy_scaler = StandardScaler()
        for utt_info in meta_data:
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            energy_path = os.path.join(
                processed_dir, utt_info["Dataset"], energy_dir, f'{utt_info["Uid"]}.npy'
            )
            # NOTE(review): unlike the per-speaker branch of the sibling loader,
            # there is no os.path.exists() guard here — a missing file raises.
            frame_energy = np.load(energy_path)
            assert len(frame_energy) > 0

            if use_log_scale:
                # Only log-transform nonzero frames; zeros (silence) stay zero.
                nonzero_idxes = np.where(frame_energy != 0)[0]
                frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes])
            utt2energy[utt] = frame_energy
            energy_scaler.partial_fit(frame_energy.reshape(-1, 1))

        mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0]
        if return_norm:
            for utt_info in meta_data:
                utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
                energy = utt2energy[utt]
                normalized_energy = (energy - mean) / std
                utt2energy[utt] = normalized_energy
        energy_statistic = {"mean": mean, "std": std}

    else:
        # Per-speaker statistics: group utterances by speaker first.
        spk2utt = {}
        energy_statistic = []
        for utt_info in meta_data:
            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
            if not utt2spk[utt] in spk2utt:
                spk2utt[utt2spk[utt]] = []
            spk2utt[utt2spk[utt]].append(utt)

        for spk in spk2utt:
            energy_scaler = StandardScaler()
            for utt in spk2utt[spk]:
                # "dataset_uid" -> (dataset, uid); uid itself may contain "_".
                dataset = utt.split("_")[0]
                uid = "_".join(utt.split("_")[1:])
                energy_path = os.path.join(
                    processed_dir, dataset, energy_dir, f"{uid}.npy"
                )
                frame_energy = np.load(energy_path)
                assert len(frame_energy) > 0

                if use_log_scale:
                    nonzero_idxes = np.where(frame_energy != 0)[0]
                    frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes])
                utt2energy[utt] = frame_energy
                energy_scaler.partial_fit(frame_energy.reshape(-1, 1))

            mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0]
            if return_norm:
                for utt in spk2utt[spk]:
                    energy = utt2energy[utt]
                    normalized_energy = (energy - mean) / std
                    utt2energy[utt] = normalized_energy
            energy_statistic.append({"spk": spk, "mean": mean, "std": std})

    return utt2energy, energy_statistic
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
def align_length(feature, target_len, pad_value=0.0):
    """Pad or truncate ``feature`` along its last axis to exactly ``target_len``.

    Args:
        feature (np.ndarray): 1-D array of shape (T,) or 2-D array of shape (C, T).
        target_len (int): desired length of the last axis.
        pad_value (float): value used for right-padding when the feature is short.

    Returns:
        np.ndarray: feature whose last axis has length ``target_len``.

    Raises:
        NotImplementedError: if ``feature`` is neither 1-D nor 2-D.
    """
    feature_len = feature.shape[-1]
    dim = len(feature.shape)
    # Align 2-D data (C, T): pad/cut along the time axis.
    # (Bug fix: the branch comments were swapped — this branch handles 2-D input.)
    if dim == 2:
        if target_len > feature_len:
            feature = np.pad(
                feature,
                ((0, 0), (0, target_len - feature_len)),
                constant_values=pad_value,
            )
        else:
            feature = feature[:, :target_len]
    # Align 1-D data (T,).
    elif dim == 1:
        if target_len > feature_len:
            feature = np.pad(
                feature, (0, target_len - feature_len), constant_values=pad_value
            )
        else:
            feature = feature[:target_len]
    else:
        raise NotImplementedError
    return feature
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def align_whisper_feauture_length(
    feature, target_len, fast_mapping=True, source_hop=320, target_hop=256
):
    """Resample Whisper encoder frames (hop ``source_hop``) onto the target
    frame grid (hop ``target_hop``) by integer up-sampling followed by
    block-averaging, then crop to ``target_len`` frames.

    Args:
        feature (np.ndarray): (source_len, dim) Whisper features.
        target_len (int): number of frames wanted on the target grid.
        fast_mapping (bool): if True, only keep the source prefix actually
            needed to produce ``target_len`` frames.
        source_hop (int): hop size of the source features, in samples.
        target_hop (int): hop size of the target features, in samples.

    Returns:
        np.ndarray: (target_len, dim) features on the target grid.
    """
    gcd = np.gcd(source_hop, target_hop)
    src_hop = source_hop // gcd
    tgt_hop = target_hop // gcd

    # Whisper emits at most 1500 encoder frames; cap the request accordingly.
    max_source_len = 1500
    target_len = min(target_len, max_source_len * src_hop // tgt_hop)

    width = feature.shape[-1]

    if fast_mapping:
        src_len = target_len * tgt_hop // src_hop + 1
        feature = feature[:src_len]
    else:
        src_len = max_source_len

    # Largest multiple of tgt_hop not exceeding src_len * src_hop
    # (approximately target_len * tgt_hop).
    const = src_len * src_hop // tgt_hop * tgt_hop

    # (src_len * src_hop, dim): repeat each source frame src_hop times ...
    upsampled = np.repeat(feature, src_hop, axis=0)
    # ... then average every tgt_hop consecutive rows back down:
    # (const, dim) -> (const/tgt_hop, tgt_hop, dim) -> (const/tgt_hop, dim)
    downsampled = np.average(
        upsampled[:const].reshape(-1, tgt_hop, width), axis=1
    )
    assert len(downsampled) >= target_len

    # (target_len, dim)
    return downsampled[:target_len]
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def align_content_feature_length(feature, target_len, source_hop=320, target_hop=256):
    """Resample content features from hop ``source_hop`` to hop ``target_hop``
    and align them to exactly ``target_len`` frames.

    Up-samples by integer repetition, block-averages down to the target hop,
    then pads (by repeating the last frame) or crops to ``target_len``.

    Args:
        feature (np.ndarray): (source_len, dim) content features.
        target_len (int): desired number of frames on the target grid.
        source_hop (int): hop size of the source features, in samples.
        target_hop (int): hop size of the target features, in samples.

    Returns:
        np.ndarray: (target_len, dim) aligned features.

    Raises:
        ValueError: if the resampled length deviates from ``target_len`` by
            more than 4 frames (indicates mismatched inputs).
    """
    factor = np.gcd(source_hop, target_hop)
    source_hop //= factor
    target_hop //= factor

    # (source_len, dim)
    source_len, width = feature.shape

    # const ~= target_len * target_hop (largest multiple of target_hop)
    const = source_len * source_hop // target_hop * target_hop

    # (source_len * source_hop, dim)
    up_sampling_feats = np.repeat(feature, source_hop, axis=0)
    # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
    down_sampling_feats = np.average(
        up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
    )

    err = abs(target_len - len(down_sampling_feats))
    if err > 4:  ## why 4 not 3?
        # Fix: the original printed diagnostics and called exit(), killing the
        # whole process from library code. Raise a recoverable error instead.
        raise ValueError(
            "Aligned feature length deviates too much from target: "
            "target_len={}, raw feature shape={}, up_sampling shape={}, "
            "down_sampling shape={}".format(
                target_len,
                feature.shape,
                up_sampling_feats.shape,
                down_sampling_feats.shape,
            )
        )
    if len(down_sampling_feats) < target_len:
        # (1, dim) -> (err, dim): repeat the last frame to cover the deficit.
        end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
        down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)

    # (target_len, dim)
    return down_sampling_feats[:target_len]
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def remove_outlier(values):
|
| 582 |
+
values = np.array(values)
|
| 583 |
+
p25 = np.percentile(values, 25)
|
| 584 |
+
p75 = np.percentile(values, 75)
|
| 585 |
+
lower = p25 - 1.5 * (p75 - p25)
|
| 586 |
+
upper = p75 + 1.5 * (p75 - p25)
|
| 587 |
+
normal_indices = np.logical_and(values > lower, values < upper)
|
| 588 |
+
return values[normal_indices]
|
Amphion/utils/distribution.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from torch.distributions import Normal
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def log_sum_exp(x):
    """Numerically stable log(sum(exp(x))) over the last dimension of ``x``."""
    # TF ordering: reduce over the trailing axis.
    last = x.dim() - 1
    x_max, _ = torch.max(x, dim=last)
    x_max_keep, _ = torch.max(x, dim=last, keepdim=True)
    # Subtract the max before exponentiating to avoid overflow.
    return x_max + torch.log(torch.sum(torch.exp(x - x_max_keep), dim=last))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def discretized_mix_logistic_loss(
    y_hat, y, num_classes=256, log_scale_min=-7.0, reduce=True
):
    """Discretized mixture of logistic distributions loss

    Note that it is assumed that input is scaled to [-1, 1].

    Args:
        y_hat (Tensor): Predicted output (B x C x T), where C = 3 * num_mixtures
            packing [mixture logits | means | log scales].
        y (Tensor): Target (B x T x 1).
        num_classes (int): Number of classes
        log_scale_min (float): Log scale minimum value
        reduce (bool): If True, the losses are averaged or summed for each
            minibatch.

    Returns
        Tensor: loss
    """
    assert y_hat.dim() == 3
    assert y_hat.size(1) % 3 == 0
    nr_mix = y_hat.size(1) // 3

    # (B x T x C)
    y_hat = y_hat.transpose(1, 2)

    # unpack parameters. (B, T, num_mixtures) x 3
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix : 2 * nr_mix]
    # Clamp log-scales from below so exp(-log_scales) cannot blow up.
    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)

    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)

    # CDF of the logistic, evaluated at the two edges of the quantization bin
    # (bin half-width is 1/(num_classes-1) in [-1, 1] space).
    centered_y = y - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1))
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1))
    cdf_min = torch.sigmoid(min_in)

    # log probability for edge case of 0 (before scaling)
    # equivalent: torch.log(torch.sigmoid(plus_in))
    log_cdf_plus = plus_in - F.softplus(plus_in)

    # log probability for edge case of 255 (before scaling)
    # equivalent: (1 - torch.sigmoid(min_in)).log()
    log_one_minus_cdf_min = -F.softplus(min_in)

    # probability for all other cases: mass inside the bin
    cdf_delta = cdf_plus - cdf_min

    mid_in = inv_stdv * centered_y
    # log probability in the center of the bin, to be used in extreme cases
    # (not actually used in our code)
    log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)

    # tf equivalent
    """
    log_probs = tf.where(x < -0.999, log_cdf_plus,
                         tf.where(x > 0.999, log_one_minus_cdf_min,
                                  tf.where(cdf_delta > 1e-5,
                                           tf.log(tf.maximum(cdf_delta, 1e-12)),
                                           log_pdf_mid - np.log(127.5))))
    """
    # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
    # for num_classes=65536 case? 1e-7? not sure..
    # The nested tf.where above is emulated below with float masks
    # (1.0 where the condition holds, 0.0 elsewhere).
    inner_inner_cond = (cdf_delta > 1e-5).float()

    inner_inner_out = inner_inner_cond * torch.log(
        torch.clamp(cdf_delta, min=1e-12)
    ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
    inner_cond = (y > 0.999).float()
    inner_out = (
        inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
    )
    cond = (y < -0.999).float()
    log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out

    # Weight each mixture component by its (log) mixture probability.
    log_probs = log_probs + F.log_softmax(logit_probs, -1)

    if reduce:
        return -torch.sum(log_sum_exp(log_probs))
    else:
        return -log_sum_exp(log_probs).unsqueeze(-1)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def to_one_hot(tensor, n, fill_with=1.0):
    """One-hot encode ``tensor`` along a new trailing axis of size ``n``.

    Returns a float tensor of shape ``tensor.size() + (n,)`` with
    ``fill_with`` at each encoded index, on CUDA if the input is on CUDA.
    """
    target_shape = tensor.size() + (n,)
    encoded = torch.zeros(target_shape, dtype=torch.float32)
    if tensor.is_cuda:
        encoded = encoded.cuda()
    # Scatter fill_with into the trailing (last) dimension at the index values.
    encoded.scatter_(tensor.dim(), tensor.unsqueeze(-1), fill_with)
    return encoded
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0, clamp_log_scale=False):
    """
    Sample from discretized mixture of logistic distributions

    Args:
        y (Tensor): B x C x T, where C = 3 * num_mixtures packing
            [mixture logits | means | log scales].
        log_scale_min (float): Log scale minimum value
        clamp_log_scale (bool): if True, clamp the selected log scales at
            ``log_scale_min`` before sampling.

    Returns:
        Tensor: sample in range of [-1, 1].
    """
    assert y.size(1) % 3 == 0
    nr_mix = y.size(1) // 3

    # B x T x C
    y = y.transpose(1, 2)
    logit_probs = y[:, :, :nr_mix]

    # sample mixture indicator from softmax
    # (Gumbel-max trick: argmax(logits + Gumbel noise) ~ Categorical(softmax(logits)))
    temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
    temp = logit_probs.data - torch.log(-torch.log(temp))
    _, argmax = temp.max(dim=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = to_one_hot(argmax, nr_mix)
    # select logistic parameters of the chosen component via the one-hot mask
    means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
    log_scales = torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1)
    if clamp_log_scale:
        log_scales = torch.clamp(log_scales, min=log_scale_min)
    # sample from logistic & clip to interval
    # (inverse-CDF sampling: x = mu + s * logit(u) for u ~ Uniform(0, 1))
    # we don't actually round to the nearest 8bit value when sampling
    u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))

    x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0)

    return x
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# we can easily define discretized version of the gaussian loss, however,
|
| 158 |
+
# use continuous version as same as the https://clarinet-demo.github.io/
|
| 159 |
+
def mix_gaussian_loss(y_hat, y, log_scale_min=-7.0, reduce=True):
    """Mixture of continuous gaussian distributions loss

    Note that it is assumed that input is scaled to [-1, 1].

    Args:
        y_hat (Tensor): Predicted output (B x C x T). C == 2 means a single
            Gaussian packed as [mean | log scale]; otherwise C must be a
            multiple of 3 packing [mixture logits | means | log scales].
        y (Tensor): Target (B x T x 1).
        log_scale_min (float): Log scale minimum value
        reduce (bool): If True, the losses are averaged or summed for each
            minibatch.
    Returns
        Tensor: loss
    """
    assert y_hat.dim() == 3
    C = y_hat.size(1)
    if C == 2:
        nr_mix = 1
    else:
        assert y_hat.size(1) % 3 == 0
        nr_mix = y_hat.size(1) // 3

    # (B x T x C)
    y_hat = y_hat.transpose(1, 2)

    # unpack parameters.
    if C == 2:
        # special case for C == 2, just for compatibility
        logit_probs = None
        means = y_hat[:, :, 0:1]
        log_scales = torch.clamp(y_hat[:, :, 1:2], min=log_scale_min)
    else:
        # (B, T, num_mixtures) x 3
        logit_probs = y_hat[:, :, :nr_mix]
        means = y_hat[:, :, nr_mix : 2 * nr_mix]
        log_scales = torch.clamp(
            y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min
        )

    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)

    # Evaluate a zero-mean Normal at the residual (equivalent to
    # Normal(means, scales).log_prob(y)).
    centered_y = y - means
    dist = Normal(loc=0.0, scale=torch.exp(log_scales))
    # do we need to add a trick to avoid log(0)?
    log_probs = dist.log_prob(centered_y)

    if nr_mix > 1:
        # Weight each component by its (log) mixture probability.
        log_probs = log_probs + F.log_softmax(logit_probs, -1)

    if reduce:
        if nr_mix == 1:
            return -torch.sum(log_probs)
        else:
            return -torch.sum(log_sum_exp(log_probs))
    else:
        if nr_mix == 1:
            return -log_probs
        else:
            return -log_sum_exp(log_probs).unsqueeze(-1)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def sample_from_mix_gaussian(y, log_scale_min=-7.0):
    """
    Sample from (discretized) mixture of gaussian distributions
    Args:
        y (Tensor): B x C x T. C == 2 means a single Gaussian packed as
            [mean | log scale]; otherwise C must be a multiple of 3 packing
            [mixture logits | means | log scales].
        log_scale_min (float): Log scale minimum value
            # NOTE(review): accepted but not applied here — no clamping is
            # performed in this function; confirm against callers.
    Returns:
        Tensor: sample in range of [-1, 1].
    """
    C = y.size(1)
    if C == 2:
        nr_mix = 1
    else:
        assert y.size(1) % 3 == 0
        nr_mix = y.size(1) // 3

    # B x T x C
    y = y.transpose(1, 2)

    if C == 2:
        logit_probs = None
    else:
        logit_probs = y[:, :, :nr_mix]

    if nr_mix > 1:
        # sample mixture indicator from softmax
        # (Gumbel-max trick on the mixture logits)
        temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
        temp = logit_probs.data - torch.log(-torch.log(temp))
        _, argmax = temp.max(dim=-1)

        # (B, T) -> (B, T, nr_mix)
        one_hot = to_one_hot(argmax, nr_mix)

        # Select means and log scales of the chosen component
        means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
        log_scales = torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1)
    else:
        if C == 2:
            means, log_scales = y[:, :, 0], y[:, :, 1]
        elif C == 3:
            # Single mixture stored in 3-channel layout: skip the logit channel.
            means, log_scales = y[:, :, 1], y[:, :, 2]
        else:
            assert False, "shouldn't happen"

    scales = torch.exp(log_scales)
    dist = Normal(loc=means, scale=scales)
    x = dist.sample()

    x = torch.clamp(x, min=-1.0, max=1.0)
    return x
|
Amphion/utils/mel.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from librosa.filters import mel as librosa_mel_fn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress magnitudes: log(max(x, clip_val) * C).

    Min value: ln(1e-5) = -11.5129 (for the default clip_val).
    """
    clipped = torch.clamp(x, min=clip_val)
    return torch.log(clipped * C)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def spectral_normalize_torch(magnitudes):
    """Apply dynamic-range (log) compression to spectrogram magnitudes."""
    return dynamic_range_compression_torch(magnitudes)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def extract_linear_features(y, cfg, center=False):
    """Compute the linear (magnitude) spectrogram of ``y`` via STFT.

    Args:
        y (tensor): audio of shape (B, T); values expected in [-1, 1]
            (out-of-range values are only reported, not clipped).
        cfg: configuration providing n_fft, hop_size, win_size.
        center (bool, optional): In STFT, whether the t-th frame is centered
            at time t*hop_length. Defaults to False.

    Returns:
        tensor: magnitude spectrogram (n_fft//2+1, frames) after squeezing
            a leading batch dimension of size 1.
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global hann_window
    # NOTE(review): the per-device window cache is overwritten unconditionally
    # on every call — the dict only avoids re-allocation across devices.
    hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device)

    # Reflect-pad so the number of frames matches hop-aligned expectations.
    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    # complex tensor as default, then use view_as_real for future pytorch compatibility
    spec = torch.stft(
        y,
        cfg.n_fft,
        hop_length=cfg.hop_size,
        win_length=cfg.win_size,
        window=hann_window[str(y.device)],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.view_as_real(spec)
    # Magnitude with a small epsilon for numerical stability of the sqrt grad.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
    spec = torch.squeeze(spec, 0)
    return spec
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def mel_spectrogram_torch(y, cfg, center=False):
    """Compute a log-mel spectrogram of ``y``.

    TODO: to merge this funtion with the extract_mel_features below

    Args:
        y (tensor): audio of shape (B, T); values expected in [-1, 1]
            (out-of-range values are only reported, not clipped).
        cfg: configuration providing sample_rate, n_fft, n_mel, fmin, fmax,
            hop_size, win_size.
        center (bool, optional): In STFT, whether the t-th frame is centered
            at time t*hop_length. Defaults to False.

    Returns:
        tensor: log-mel spectrogram (B, n_mel, frames).
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    mel_key = str(cfg.fmax) + "_" + str(y.device)
    # Fix: the cache used to be probed with bare `cfg.fmax` although entries
    # are stored under "<fmax>_<device>", so the lookup never hit and the mel
    # filterbank was rebuilt on every call. Probe with the storage key.
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(
            sr=cfg.sample_rate,
            n_fft=cfg.n_fft,
            n_mels=cfg.n_mel,
            fmin=cfg.fmin,
            fmax=cfg.fmax,
        )
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        cfg.n_fft,
        hop_length=cfg.hop_size,
        win_length=cfg.win_size,
        window=hann_window[str(y.device)],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )

    spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    # Project the linear spectrogram onto the mel basis, then log-compress.
    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# Module-level caches shared by the extractors in this file:
#   mel_basis:   "<fmax>_<device>" -> mel filterbank tensor
#   hann_window: "<device>" -> Hann window tensor
mel_basis = {}
hann_window = {}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def extract_mel_features(
    y,
    cfg,
    center=False,
):
    """Extract mel features

    Args:
        y (tensor): audio data in tensor, shape (B, T); values expected in
            [-1, 1] (out-of-range values are only reported, not clipped).
        cfg (dict): configuration in cfg.preprocess
        center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False.

    Returns:
        tensor: a tensor containing the mel feature calculated based on STFT result
    """
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    mel_key = str(cfg.fmax) + "_" + str(y.device)
    # Fix: the cache used to be probed with bare `cfg.fmax` although entries
    # are stored under "<fmax>_<device>", so the lookup never hit and the mel
    # filterbank was rebuilt on every call. Probe with the storage key.
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(
            sr=cfg.sample_rate,
            n_fft=cfg.n_fft,
            n_mels=cfg.n_mel,
            fmin=cfg.fmin,
            fmax=cfg.fmax,
        )
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    # complex tensor as default, then use view_as_real for future pytorch compatibility
    spec = torch.stft(
        y,
        cfg.n_fft,
        hop_length=cfg.hop_size,
        win_length=cfg.win_size,
        window=hann_window[str(y.device)],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.view_as_real(spec)
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    # Project the linear spectrogram onto the mel basis, then log-compress.
    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)
    return spec.squeeze(0)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def extract_mel_features_tts(
    y,
    cfg,
    center=False,
    taco=False,
    _stft=None,
):
    """Extract mel features

    Args:
        y (tensor): audio data in tensor
        cfg (dict): configuration in cfg.preprocess
        center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False.
        taco: use tacotron mel; when True, ``_stft`` must provide a
            ``mel_spectrogram`` method (external Tacotron-style STFT object).
        _stft: external STFT module; only used when ``taco`` is True.

    Returns:
        tensor: a tensor containing the mel feature calculated based on STFT result
    """
    if not taco:
        if torch.min(y) < -1.0:
            print("min value is ", torch.min(y))
        if torch.max(y) > 1.0:
            print("max value is ", torch.max(y))

        global mel_basis, hann_window
        mel_key = str(cfg.fmax) + "_" + str(y.device)
        # Fix: the cache used to be probed with bare `cfg.fmax` although
        # entries are stored under "<fmax>_<device>", so the lookup never hit
        # and the mel filterbank was rebuilt on every call.
        if mel_key not in mel_basis:
            mel = librosa_mel_fn(
                sr=cfg.sample_rate,
                n_fft=cfg.n_fft,
                n_mels=cfg.n_mel,
                fmin=cfg.fmin,
                fmax=cfg.fmax,
            )
            mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
            hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device)

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
            mode="reflect",
        )
        y = y.squeeze(1)

        # complex tensor as default, then use view_as_real for future pytorch compatibility
        spec = torch.stft(
            y,
            cfg.n_fft,
            hop_length=cfg.hop_size,
            win_length=cfg.win_size,
            window=hann_window[str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(mel_basis[mel_key], spec)
        spec = spectral_normalize_torch(spec)
    else:
        # Tacotron path: clip to [-1, 1] and delegate to the external STFT.
        audio = torch.clip(y, -1, 1)
        audio = torch.autograd.Variable(audio, requires_grad=False)
        spec, energy = _stft.mel_spectrogram(audio)

    return spec.squeeze(0)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def amplitude_phase_spectrum(y, cfg):
    """Compute log-amplitude and phase spectra of ``y`` via STFT.

    Args:
        y (tensor): audio of shape (B, T).
        cfg: configuration providing n_fft, hop_size, win_size.

    Returns:
        (log_amplitude, phase, rea, imag), each of shape
        (n_fft//2+1, frames) for a single utterance or
        (B, n_fft//2+1, frames) for a batch.
    """
    window = torch.hann_window(cfg.win_size).to(y.device)

    pad = int((cfg.n_fft - cfg.hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    stft_spec = torch.stft(
        y,
        cfg.n_fft,
        hop_length=cfg.hop_size,
        win_length=cfg.win_size,
        window=window,
        center=False,
        return_complex=True,
    )

    stft_spec = torch.view_as_real(stft_spec)
    # Drop a singleton batch dimension so single utterances come back 2-D.
    if stft_spec.size()[0] == 1:
        stft_spec = stft_spec.squeeze(0)

    # Last axis holds (real, imag); ellipsis indexing covers both the batched
    # (B, n_fft//2+1, frames, 2) and unbatched (n_fft//2+1, frames, 2) layouts.
    rea = stft_spec[..., 0]
    imag = stft_spec[..., 1]

    magnitude = torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))
    log_amplitude = torch.log(torch.abs(magnitude) + 1e-5)
    phase = torch.atan2(imag, rea)

    return log_amplitude, phase, rea, imag
|
Amphion/utils/prompt_preparer.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PromptPreparer:
    """Builds prompt/target embeddings for non-autoregressive (NAR) stages.

    Mixin: the host class is expected to provide ``prefix_mode``,
    ``nar_audio_embeddings`` (one embedding layer per quantizer),
    ``num_quantizers``, ``audio_token_num`` and ``rng`` — none of these are
    defined here. (NOTE(review): contract inferred from attribute usage in
    this file; confirm against the host class.)
    """

    def prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes):
        """Return ``(y_emb, prefix_len)`` for the configured prefix mode.

        Args:
            y: first-quantizer token ids (used as the base embedding input).
            y_lens: per-sample lengths (modes 1/2 only).
            codes: acoustic codes for all quantizers, indexed as
                ``codes[..., q]`` for quantizer ``q``.
            nar_stage: index of the NAR stage currently being trained; only
                quantizers below it contribute to the target embedding.
            y_prompts_codes: externally supplied prompt codes (mode 4).

        Raises:
            ValueError: if ``self.prefix_mode`` is not one of 0, 1, 2, 4.
        """
        mode = self.prefix_mode
        if mode == 0:
            return self._handle_prefix_mode_0(y, codes, nar_stage)
        if mode == 1:
            return self._handle_prefix_mode_1(y, y_lens, codes, nar_stage)
        if mode in (2, 4):
            return self._handle_prefix_mode_2_4(
                y, y_lens, codes, nar_stage, y_prompts_codes
            )
        raise ValueError("Invalid prefix mode")

    def _handle_prefix_mode_0(self, y, codes, nar_stage):
        # No prompt prefix: embed the whole sequence and sum in the codes of
        # every quantizer below the current NAR stage.
        y_emb = self.nar_audio_embeddings[0](y)
        for q in range(1, nar_stage):
            y_emb = y_emb + self.nar_audio_embeddings[q](codes[..., q])
        return y_emb, 0

    def _handle_prefix_mode_1(self, y, y_lens, codes, nar_stage):
        # Random prefix length in [25%, 50%) of the shortest sample,
        # capped at 225 frames.
        low = (0.25 * y_lens.min()).type(torch.int64).item()
        prefix_len = min(torch.randint(low, low * 2, size=()).item(), 225)

        emb = self.nar_audio_embeddings
        y_prompts = emb[0](y[:, :prefix_len])
        y_emb = emb[0](y[:, prefix_len:])
        for q in range(1, self.num_quantizers):
            # The prompt part always sums all quantizers; the target part
            # only sums quantizers below the current stage.
            y_prompts += emb[q](codes[:, :prefix_len, q])
            if q < nar_stage:
                y_emb += emb[q](codes[:, prefix_len:, q])
        return torch.concat([y_prompts, y_emb], axis=1), prefix_len

    def _handle_prefix_mode_2_4(self, y, y_lens, codes, nar_stage, y_prompts_codes):
        if self.prefix_mode == 2:
            # Mode 2: carve a random window out of each sample to act as the
            # prompt, then mask those positions at the current stage with the
            # placeholder token so they are not predicted as targets.
            prefix_len = min(225, int(0.25 * y_lens.min().item()))

            prompt_chunks = []
            for b in range(codes.shape[0]):
                start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                prompt_chunks.append(
                    torch.clone(codes[b, start : start + prefix_len])
                )
                codes[b, start : start + prefix_len, nar_stage] = self.audio_token_num
            y_prompts_codes = torch.stack(prompt_chunks, dim=0)
        else:
            # Mode 4: prompt codes come from the caller; their length defines
            # the prefix.
            prefix_len = y_prompts_codes.shape[1]

        emb = self.nar_audio_embeddings
        y_prompts = emb[0](y_prompts_codes[..., 0])
        y_emb = emb[0](y)
        for q in range(1, self.num_quantizers):
            y_prompts += emb[q](y_prompts_codes[..., q])
            if q < nar_stage:
                y_emb += emb[q](codes[..., q])
        y_emb = torch.concat([y_prompts, y_emb], axis=1)

        return y_emb, prefix_len
|
__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
conf/default.yaml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
is_left_ear: true
|
| 2 |
+
loss_function: MultiResolutionL1SpecLoss
|
| 3 |
+
|
| 4 |
+
trainer:
|
| 5 |
+
acoustic_hop_length: 256
|
| 6 |
+
acoustic_n_fft: 512
|
| 7 |
+
acoustic_sr: 16000
|
| 8 |
+
acoustic_win_length: 512
|
| 9 |
+
adam_beta1: 0.9
|
| 10 |
+
adam_beta2: 0.999
|
| 11 |
+
adam_epsilon: 1.0e-08
|
| 12 |
+
dataloader_drop_last: true
|
| 13 |
+
dataloader_num_workers: 8
|
| 14 |
+
dataloader_persistent_workers: false
|
| 15 |
+
dataloader_pin_memory: true
|
| 16 |
+
dataloader_prefetch_factor: 2
|
| 17 |
+
ddp_find_unused_parameters: true
|
| 18 |
+
debug: false
|
| 19 |
+
do_eval: false
|
| 20 |
+
do_predict: false
|
| 21 |
+
do_train: true
|
| 22 |
+
early_stopping_patience: 20
|
| 23 |
+
eval_batch_size: 1
|
| 24 |
+
eval_epoch_interval: 20
|
| 25 |
+
gradient_accumulation_steps: 1
|
| 26 |
+
greater_is_better: true
|
| 27 |
+
learning_rate: 0.001
|
| 28 |
+
lr_scheduler_type: constant_schedule_with_warmup
|
| 29 |
+
max_grad_norm: 5
|
| 30 |
+
max_steps: 0
|
| 31 |
+
metric_for_best_model: OVRL
|
| 32 |
+
num_train_epochs: 2000
|
| 33 |
+
optim: adamw
|
| 34 |
+
output_dir: exp/bmi__bmi__fix-ddp-sampler__warmup-2000__fixed-length__rerun
|
| 35 |
+
per_device_train_batch_size: 8
|
| 36 |
+
plot_lr: false
|
| 37 |
+
resume_from_checkpoint: "no"
|
| 38 |
+
save_epoch_interval: 20
|
| 39 |
+
save_total_limit: 100
|
| 40 |
+
seed: 20220815
|
| 41 |
+
warmup_ratio: 0.0
|
| 42 |
+
warmup_steps: 2000
|
| 43 |
+
|
| 44 |
+
predict_dataset:
|
| 45 |
+
enroll_folder: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/dev/speaker_adapt
|
| 46 |
+
normalize: true
|
| 47 |
+
scenes_file_fpath: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/metadata/scenes.dev.json
|
| 48 |
+
scenes_folder: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/dev/scenes/
|
| 49 |
+
scenes_listeners_file: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/metadata/scenes_listeners.dev.json
|
| 50 |
+
small_test: false
|
| 51 |
+
|
| 52 |
+
train_dataset:
|
| 53 |
+
enroll_folder: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/train/targets
|
| 54 |
+
enroll_len_limit: 10
|
| 55 |
+
limit: -1
|
| 56 |
+
normalize: false
|
| 57 |
+
sample_len_limit: 4
|
| 58 |
+
scenes_file_fpath: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/metadata/scenes.train.json
|
| 59 |
+
scenes_folder: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/train/scenes/
|
| 60 |
+
sr: 44100
|
| 61 |
+
use_all_enroll: false
|
| 62 |
+
use_additional_data: true
|
| 63 |
+
|
| 64 |
+
eval_dev_dataset:
|
| 65 |
+
scenes_listeners_file: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/metadata/scenes_listeners.dev.json
|
| 66 |
+
scenes_folder: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/dev/scenes/
|
| 67 |
+
enroll_folder: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/dev/speaker_adapt
|
| 68 |
+
limit: 500
|
| 69 |
+
normalize: false
|
| 70 |
+
scenes_file_fpath: /data/xhao/clarity-ICASSP2023/clarity_CEC2_data/clarity_data/metadata/scenes.dev.json
|
exp/bmi__fa-codec/2024_05_20--16_21_26.log
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
05-20 16:21:26 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
|
| 2 |
+
05-20 16:21:26 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
|
| 3 |
+
05-20 16:21:26 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
|
| 4 |
+
05-20 16:21:26 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
|
exp/bmi__fa-codec/2024_05_20--16_22_35.log
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
05-20 16:22:35 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
|
| 2 |
+
05-20 16:22:35 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
|
| 3 |
+
05-20 16:22:39 INFO [logging.py:61]:
|
| 4 |
+
Environment information:
|
| 5 |
+
- `Accelerate` version: 0.28.0
|
| 6 |
+
- Platform: Linux-6.1.0-18-amd64-x86_64-with-glibc2.36
|
| 7 |
+
- Python version: 3.10.13
|
| 8 |
+
- Numpy version: 1.26.4
|
| 9 |
+
- PyTorch version (GPU?): 2.2.2 (True)
|
| 10 |
+
- System RAM: 503.49 GB
|
| 11 |
+
- GPU Available: True
|
| 12 |
+
- GPU IDs: 8
|
| 13 |
+
- GPU type: NVIDIA RTX A6000
|
| 14 |
+
05-20 16:22:39 INFO [logging.py:61]:
|
| 15 |
+
===============================================================================================
|
| 16 |
+
Layer (type:depth-idx) Param #
|
| 17 |
+
===============================================================================================
|
| 18 |
+
Model --
|
| 19 |
+
├─Linear: 1-1 8,224
|
| 20 |
+
├─FACodecEncoder: 1-2 --
|
| 21 |
+
│ └─Sequential: 2-1 --
|
| 22 |
+
│ │ └─Conv1d: 3-1 (288)
|
| 23 |
+
│ │ └─EncoderBlock: 3-2 (33,728)
|
| 24 |
+
│ │ └─EncoderBlock: 3-3 (165,760)
|
| 25 |
+
│ │ └─EncoderBlock: 3-4 (724,736)
|
| 26 |
+
│ │ └─EncoderBlock: 3-5 (2,891,264)
|
| 27 |
+
│ │ └─Activation1d: 3-6 (1,024)
|
| 28 |
+
│ │ └─Conv1d: 3-7 (393,728)
|
| 29 |
+
├─FACodecDecoder: 1-3 --
|
| 30 |
+
│ └─ModuleList: 2-2 --
|
| 31 |
+
│ │ └─ResidualVQ: 3-8 (12,816)
|
| 32 |
+
│ │ └─ResidualVQ: 3-9 (25,632)
|
| 33 |
+
│ │ └─ResidualVQ: 3-10 (38,448)
|
| 34 |
+
│ └─Sequential: 2-3 --
|
| 35 |
+
│ │ └─Conv1d: 3-11 (1,837,056)
|
| 36 |
+
│ │ └─DecoderBlock: 3-12 (11,550,208)
|
| 37 |
+
│ │ └─DecoderBlock: 3-13 (2,891,520)
|
| 38 |
+
│ │ └─DecoderBlock: 3-14 (659,328)
|
| 39 |
+
│ │ └─DecoderBlock: 3-15 (133,056)
|
| 40 |
+
│ │ └─Activation1d: 3-16 (128)
|
| 41 |
+
│ │ └─Conv1d: 3-17 (450)
|
| 42 |
+
│ │ └─Tanh: 3-18 --
|
| 43 |
+
│ └─TransformerEncoder: 2-4 --
|
| 44 |
+
│ │ └─PositionalEncoding: 3-19 --
|
| 45 |
+
│ │ └─ModuleList: 3-20 (7,353,344)
|
| 46 |
+
│ │ └─LayerNorm: 3-21 (512)
|
| 47 |
+
│ └─Linear: 2-5 (131,584)
|
| 48 |
+
│ └─LayerNorm: 2-6 --
|
| 49 |
+
│ └─CNNLSTM: 2-7 --
|
| 50 |
+
│ │ └─Sequential: 3-22 (1,579,520)
|
| 51 |
+
│ │ └─ModuleList: 3-23 (514)
|
| 52 |
+
│ └─CNNLSTM: 2-8 --
|
| 53 |
+
│ │ └─Sequential: 3-24 (1,579,520)
|
| 54 |
+
│ │ └─ModuleList: 3-25 (1,285,771)
|
| 55 |
+
│ └─Sequential: 2-9 --
|
| 56 |
+
│ │ └─GradientReversal: 3-26 --
|
| 57 |
+
│ │ └─CNNLSTM: 3-27 (1,580,034)
|
| 58 |
+
│ └─Sequential: 2-10 --
|
| 59 |
+
│ │ └─GradientReversal: 3-28 --
|
| 60 |
+
│ │ └─CNNLSTM: 3-29 (2,865,291)
|
| 61 |
+
│ └─Sequential: 2-11 --
|
| 62 |
+
│ │ └─GradientReversal: 3-30 --
|
| 63 |
+
│ │ └─CNNLSTM: 3-31 (64,595,920)
|
| 64 |
+
├─ERB: 1-4 --
|
| 65 |
+
│ └─Linear: 2-12 (3,728)
|
| 66 |
+
│ └─Linear: 2-13 (3,728)
|
| 67 |
+
├─SubbandFeatureExtractor: 1-5 --
|
| 68 |
+
│ └─Unfold: 2-14 --
|
| 69 |
+
├─Sequential: 1-6 --
|
| 70 |
+
│ └─GroupNorm: 2-15 36
|
| 71 |
+
│ └─Conv1d: 2-16 608
|
| 72 |
+
├─ModuleList: 1-7 --
|
| 73 |
+
│ └─TriplePathRNN: 2-17 --
|
| 74 |
+
│ │ └─ResRNN: 3-32 54,368
|
| 75 |
+
│ │ └─ResRNN: 3-33 54,368
|
| 76 |
+
│ │ └─Linear: 3-34 2,080
|
| 77 |
+
│ └─TriplePathRNN: 2-18 --
|
| 78 |
+
│ │ └─ResRNN: 3-35 54,368
|
| 79 |
+
│ │ └─ResRNN: 3-36 54,368
|
| 80 |
+
│ │ └─Linear: 3-37 2,080
|
| 81 |
+
│ └─TriplePathRNN: 2-19 --
|
| 82 |
+
│ │ └─ResRNN: 3-38 54,368
|
| 83 |
+
│ │ └─ResRNN: 3-39 54,368
|
| 84 |
+
│ │ └─Linear: 3-40 2,080
|
| 85 |
+
├─Sequential: 1-8 --
|
| 86 |
+
│ └─GroupNorm: 2-20 64
|
| 87 |
+
│ └─Conv1d: 2-21 2,112
|
| 88 |
+
│ └─Tanh: 2-22 --
|
| 89 |
+
│ └─Conv1d: 2-23 4,160
|
| 90 |
+
│ └─Tanh: 2-24 --
|
| 91 |
+
│ └─Conv1d: 2-25 260
|
| 92 |
+
===============================================================================================
|
| 93 |
+
Total params: 102,686,548
|
| 94 |
+
Trainable params: 347,912
|
| 95 |
+
Non-trainable params: 102,338,636
|
| 96 |
+
===============================================================================================
|
| 97 |
+
05-20 16:22:40 INFO [logging.py:61]: warmup_steps=2000. warmup_ratio will be ignored.
|
| 98 |
+
05-20 16:22:41 INFO [logging.py:61]: Will start from scratch (no checkpoint will be loaded).
|
| 99 |
+
05-20 16:22:41 INFO [logging.py:61]: ***** Running training *****
|
| 100 |
+
05-20 16:22:41 INFO [logging.py:61]: Num Epochs = 2,000
|
| 101 |
+
05-20 16:22:41 INFO [logging.py:61]: `steps_per_epoch` = 125
|
| 102 |
+
05-20 16:22:41 INFO [logging.py:61]: Instantaneous batch size per device = 16
|
| 103 |
+
05-20 16:22:41 INFO [logging.py:61]: Gradient Accumulation steps = 1
|
| 104 |
+
05-20 16:22:41 INFO [logging.py:61]: Total optimization steps = 250,000
|
| 105 |
+
05-20 16:22:41 INFO [logging.py:61]: ========= Epoch 1 out of 2000 =========
|
| 106 |
+
05-20 16:22:41 INFO [logging.py:61]: Begin training...
|
| 107 |
+
05-20 16:23:34 INFO [logging.py:61]: Loss 'loss' on epoch 1: 0.28071990609169006
|
| 108 |
+
05-20 16:23:34 INFO [logging.py:61]: Loss 'norm_before' on epoch 1: 0.033199019730091095
|
| 109 |
+
05-20 16:23:34 INFO [logging.py:61]: ========= Epoch 2 out of 2000 =========
|
| 110 |
+
05-20 16:23:34 INFO [logging.py:61]: Begin training...
|
exp/bmi__fa-codec/2024_05_20--16_24_01.log
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
05-20 16:24:01 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
|
| 2 |
+
05-20 16:24:01 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
|
| 3 |
+
05-20 16:24:01 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
|
| 4 |
+
05-20 16:24:01 INFO [logger.py:80]: Initialized logger with log file in exp/bmi__fa-codec.
|
exp/bmi__fa-codec/amplified_signals/S06021_L0014_HA-output.wav
ADDED
|
Binary file (554 kB). View file
|
|
|
exp/bmi__fa-codec/amplified_signals/S06026_L0088_HA-output.wav
ADDED
|
Binary file (449 kB). View file
|
|
|
exp/bmi__fa-codec/amplified_signals/S06031_L0096_HA-output.wav
ADDED
|
Binary file (478 kB). View file
|
|
|
exp/bmi__fa-codec/amplified_signals/S06036_L0036_HA-output.wav
ADDED
|
Binary file (416 kB). View file
|
|
|
exp/bmi__fa-codec/amplified_signals/S06066_L0042_HA-output.wav
ADDED
|
Binary file (535 kB). View file
|
|
|
exp/bmi__fa-codec/amplified_signals/S06071_L0089_HA-output.wav
ADDED
|
Binary file (528 kB). View file
|
|
|
exp/bmi__fa-codec/amplified_signals/S06086_L0072_HA-output.wav
ADDED
|
Binary file (442 kB). View file
|
|
|
exp/bmi__fa-codec/amplified_signals/S06091_L0099_HA-output.wav
ADDED
|
Binary file (623 kB). View file
|
|
|
exp/bmi__fa-codec/amplified_signals/S06101_L0042_HA-output.wav
ADDED
|
Binary file (484 kB). View file
|
|
|
exp/bmi__fa-codec/amplified_signals/S06111_L0002_HA-output.wav
ADDED
|
Binary file (481 kB). View file
|
|
|
exp/bmi__fa-codec/amplified_signals/S06116_L0092_HA-output.wav
ADDED
|
Binary file (643 kB). View file
|
|
|
exp/bmi__fa-codec/amplified_signals/S06126_L0069_HA-output.wav
ADDED
|
Binary file (522 kB). View file
|
|
|
exp/bmi__fa-codec/amplified_signals/S06146_L0017_HA-output.wav
ADDED
|
Binary file (522 kB). View file
|
|
|