| |
| |
| |
| |
|
|
| import random |
| import torch |
| import torchaudio |
| from torch.nn.utils.rnn import pad_sequence |
| from utils.data_utils import * |
| from models.base.base_dataset import ( |
| BaseOfflineCollator, |
| BaseOfflineDataset, |
| BaseTestDataset, |
| BaseTestCollator, |
| ) |
| from text import text_to_sequence |
|
|
|
|
class JetsDataset(BaseOfflineDataset):
    """Offline training dataset for the Jets model.

    Extends BaseOfflineDataset with per-utterance phone durations, phone-id
    sequences, pitch/energy contours, speaker ids, and (optionally) raw audio.
    """

    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: full experiment config; ``cfg.preprocess`` carries the
                feature switches and directory layout used below.
            dataset: dataset name, forwarded to the base class.
            is_valid: True when building the validation split.
        """
        BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
        self.batch_size = cfg.train.batch_size
        cfg = cfg.preprocess  # only the preprocess section is needed below

        # "dataset_uid" -> path of the saved duration array.
        self.utt2duration_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            self.utt2duration_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.duration_dir,
                uid + ".npy",
            )
        self.utt2dur = self.read_duration()

        if cfg.use_frame_energy:
            self.frame_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                # BUGFIX: was ``self.preprocess.utt2spk`` (attribute does not
                # exist on this class); every other branch uses ``self.utt2spk``.
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_energy:
            self.phone_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        if cfg.use_frame_pitch:
            self.frame_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.pitch_dir,
                # BUGFIX: was ``cfg.energy_extract_mode`` (a mode value, not a
                # log-scale flag); the phone-pitch branch uses this switch.
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_pitch:
            self.phone_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_pitch_dir,
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )

        # "dataset_uid" -> path of the phone-label (.txt) file.
        self.utt2lab_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            self.utt2lab_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.lab_dir,
                uid + ".txt",
            )

        # Optional global speaker map: speaker name -> integer id.
        self.speaker_map = {}
        spk2id_path = os.path.join(cfg.processed_dir, "spk2id.json")
        if os.path.exists(spk2id_path):
            # BUGFIX: the original passed ``os.path.exists(path)`` (a bool)
            # to open(); open the path itself, inside a context manager.
            with open(spk2id_path) as f:
                self.speaker_map = json.load(f)

        # Drop utterances whose mel or duration files are missing.
        self.metadata = self.check_metadata()
        if cfg.use_audios:
            self.utt2audio_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                if cfg.extract_audio:
                    self.utt2audio_path[utt] = os.path.join(
                        cfg.processed_dir,
                        dataset,
                        cfg.audio_dir,
                        uid + ".wav",
                    )
                else:
                    self.utt2audio_path[utt] = utt_info["Path"]

    def __getitem__(self, index):
        """Return the base features augmented with Jets-specific entries."""
        single_feature = BaseOfflineDataset.__getitem__(self, index)

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        duration = self.utt2dur[utt]

        # The first line of the lab file holds the phone string.
        with open(self.utt2lab_path[utt], "r") as f:
            phones = f.readline().strip()

        phones_ids = np.array(text_to_sequence(phones, ["english_cleaners"]))
        text_len = len(phones_ids)

        pitch = None
        if self.cfg.preprocess.use_frame_pitch:
            pitch = self.frame_utt2pitch[utt]
        elif self.cfg.preprocess.use_phone_pitch:
            pitch = self.phone_utt2pitch[utt]

        energy = None
        if self.cfg.preprocess.use_frame_energy:
            energy = self.frame_utt2energy[utt]
        elif self.cfg.preprocess.use_phone_energy:
            energy = self.phone_utt2energy[utt]

        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        single_feature.update(
            {
                "durations": duration,
                "texts": phones_ids,
                "spk_id": speaker_id,
                "text_len": text_len,
                "uid": uid,
            }
        )
        # BUGFIX: the original referenced ``pitch``/``energy`` unconditionally
        # and raised NameError when both frame/phone switches were off.
        if pitch is not None:
            single_feature["pitch"] = pitch
        if energy is not None:
            single_feature["energy"] = energy

        if self.cfg.preprocess.use_audios:
            audio, sr = torchaudio.load(self.utt2audio_path[utt])
            audio = audio.cpu().numpy().squeeze()
            single_feature["audio"] = audio
            single_feature["audio_len"] = audio.shape[0]
        return self.clip_if_too_long(single_feature)

    def read_duration(self):
        """Load duration arrays, skipping utterances with missing files.

        Returns:
            dict mapping "dataset_uid" -> np.ndarray of per-phone durations.

        Raises:
            AssertionError: if a mel's frame count != sum of its durations.
        """
        utt2dur = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            if not os.path.exists(self.utt2mel_path[utt]) or not os.path.exists(
                self.utt2duration_path[utt]
            ):
                continue

            mel = np.load(self.utt2mel_path[utt]).transpose(1, 0)
            duration = np.load(self.utt2duration_path[utt])
            assert mel.shape[0] == sum(
                duration
            ), f"{utt}: mismatch length between mel {mel.shape[0]} and sum(duration) {sum(duration)}"
            utt2dur[utt] = duration
        return utt2dur

    def __len__(self):
        return len(self.metadata)

    def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
        """Pick a random [start, end) window of length ``max_seq_len``.

        ending_ts: to avoid invalid whisper features for over 30s audios
            2812 = 30 * 24000 // 256
        """
        ts = max(feature_seq_len - max_seq_len, 0)
        ts = min(ts, ending_ts - max_seq_len)

        start = random.randint(0, ts)
        end = start + max_seq_len
        return start, end

    def clip_if_too_long(self, sample, max_seq_len=1000):
        """Clip frame-level features to at most ``max_seq_len`` frames.

        sample :
            {
                'spk_id': (1,),
                'target_len': int
                'mel': (seq_len, dim),
                'frame_pitch': (seq_len,)
                'frame_energy': (seq_len,)
                'content_vector_feat': (seq_len, dim)
            }
        """
        if sample["target_len"] <= max_seq_len:
            return sample

        start, end = self.random_select(sample["target_len"], max_seq_len)
        sample["target_len"] = end - start

        # NOTE(review): this slices every remaining key, including phone-level
        # ("texts", "durations") and scalar ("text_len", "uid") entries —
        # presumably inherited from a frame-feature-only dataset; verify
        # against callers before relying on it for long utterances.
        for k in sample.keys():
            if k not in ["spk_id", "target_len"]:
                sample[k] = sample[k][start:end]

        return sample

    def check_metadata(self):
        """Keep only utterances whose duration AND mel files exist on disk."""
        new_metadata = []
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            if os.path.exists(self.utt2duration_path[utt]) and os.path.exists(
                self.utt2mel_path[utt]
            ):
                new_metadata.append(utt_info)
        return new_metadata
|
|
|
|
class JetsCollator(BaseOfflineCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        BaseOfflineCollator.__init__(self, cfg)
        self.sort = cfg.train.sort_sample
        self.batch_size = cfg.train.batch_size
        self.drop_last = cfg.train.drop_last

    def __call__(self, batch):
        # Pack a list of per-utterance feature dicts into padded batch tensors.
        # Length-like keys become LongTensors plus binary padding masks; "uid"
        # stays a plain list; everything else is zero-padded along time.
        packed = dict()

        for key in batch[0].keys():
            if key in ("target_len", "text_len"):
                lengths = [item[key] for item in batch]
                packed[key] = torch.LongTensor(lengths)
                pad_mask = pad_sequence(
                    [torch.ones((n, 1), dtype=torch.long) for n in lengths],
                    batch_first=True,
                    padding_value=0,
                )
                mask_key = "mask" if key == "target_len" else "text_mask"
                packed[mask_key] = pad_mask
            elif key == "spk_id":
                packed["spk_id"] = torch.LongTensor([item["spk_id"] for item in batch])
            elif key == "uid":
                packed[key] = [item["uid"] for item in batch]
            elif key == "audio_len":
                packed["audio_len"] = torch.LongTensor(
                    [item["audio_len"] for item in batch]
                )
            else:
                sequences = [torch.from_numpy(item[key]) for item in batch]
                packed[key] = pad_sequence(
                    sequences, batch_first=True, padding_value=0
                )
        return packed
|
|
|
|
class JetsTestDataset(BaseTestDataset):
    """Inference-time dataset for the Jets model.

    Yields only what synthesis needs: phone-id sequence, its length, and a
    speaker id per utterance.
    """

    def __init__(self, args, cfg, infer_type=None):
        """
        Args:
            args: CLI arguments (``dataset``, ``test_list_file`` or
                ``testing_set``).
            cfg: full experiment config; ``cfg.preprocess`` is used below.
            infer_type: unused here; kept for interface compatibility.
        """
        datasets = cfg.dataset
        cfg = cfg.preprocess
        is_bigdata = False

        assert len(datasets) >= 1
        if len(datasets) > 1:
            # Multiple datasets are merged under a canonical "a_b_c" directory.
            datasets.sort()
            bigdata_version = "_".join(datasets)
            processed_data_dir = os.path.join(cfg.processed_dir, bigdata_version)
            is_bigdata = True
        else:
            processed_data_dir = os.path.join(cfg.processed_dir, args.dataset)

        if args.test_list_file:
            self.metafile_path = args.test_list_file
            self.metadata = self.get_metadata()
        else:
            assert args.testing_set
            source_metafile_path = os.path.join(
                cfg.processed_dir,
                args.dataset,
                "{}.json".format(args.testing_set),
            )
            with open(source_metafile_path, "r") as f:
                self.metadata = json.load(f)

        self.cfg = cfg
        self.datasets = datasets
        self.data_root = processed_data_dir
        self.is_bigdata = is_bigdata
        self.source_dataset = args.dataset

        if cfg.use_spkid:
            spk2id_path = os.path.join(self.data_root, cfg.spk2id)
            utt2sp_path = os.path.join(self.data_root, cfg.utt2spk)
            self.spk2id, self.utt2spk = get_spk_map(spk2id_path, utt2sp_path, datasets)

        # "dataset_uid" -> path of the phone-label (.txt) file.
        self.utt2lab_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            self.utt2lab_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.lab_dir,
                uid + ".txt",
            )

        # Optional global speaker map: speaker name -> integer id.
        self.speaker_map = {}
        spk2id_json = os.path.join(cfg.processed_dir, "spk2id.json")
        if os.path.exists(spk2id_json):
            # BUGFIX: the original passed ``os.path.exists(path)`` (a bool)
            # to open(); open the path itself, inside a context manager.
            with open(spk2id_json) as f:
                self.speaker_map = json.load(f)

    def __getitem__(self, index):
        single_feature = {}

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        # The first line of the lab file holds the phone string.
        with open(self.utt2lab_path[utt], "r") as f:
            phones = f.readline().strip()

        phones_ids = np.array(text_to_sequence(phones, self.cfg.text_cleaners))
        text_len = len(phones_ids)

        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        single_feature.update(
            {
                "texts": phones_ids,
                "spk_id": speaker_id,
                "text_len": text_len,
            }
        )

        return single_feature

    def __len__(self):
        return len(self.metadata)

    def get_metadata(self):
        """Load the utterance metadata list from ``self.metafile_path``."""
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        return metadata
|
|
|
|
class JetsTestCollator(BaseTestCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        # Pack a list of per-utterance feature dicts into padded batch tensors:
        # length-like keys get LongTensors plus binary masks, everything else
        # is zero-padded along the time axis.
        packed = dict()

        for key in batch[0].keys():
            if key == "target_len":
                frame_lens = [item["target_len"] for item in batch]
                packed["target_len"] = torch.LongTensor(frame_lens)
                packed["mask"] = pad_sequence(
                    [torch.ones((n, 1), dtype=torch.long) for n in frame_lens],
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "text_len":
                phone_lens = [item["text_len"] for item in batch]
                packed["text_len"] = torch.LongTensor(phone_lens)
                packed["text_mask"] = pad_sequence(
                    [torch.ones((n, 1), dtype=torch.long) for n in phone_lens],
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "spk_id":
                packed["spk_id"] = torch.LongTensor([item["spk_id"] for item in batch])
            else:
                sequences = [torch.from_numpy(item[key]) for item in batch]
                packed[key] = pad_sequence(
                    sequences, batch_first=True, padding_value=0
                )

        return packed
|
|