| | import random |
| | from typing import Any, Dict, Optional |
| |
|
| | import torch |
| | import torchaudio as ta |
| | from lightning import LightningDataModule |
| | from torch.utils.data.dataloader import DataLoader |
| |
|
| | from matcha.text import text_to_sequence |
| | from matcha.utils.audio import mel_spectrogram |
| | from matcha.utils.model import fix_len_compatibility, normalize |
| | from matcha.utils.utils import intersperse |
| |
|
| |
|
def parse_filelist(filelist_path, split_char="|"):
    """Read a filelist and split each non-blank line into its fields.

    Args:
        filelist_path: Path of a UTF-8 text file holding one entry per line.
        split_char: Separator placed between the fields of an entry.

    Returns:
        A list with one list of string fields per non-blank line, in file
        order. Surrounding whitespace is stripped from each line before it
        is split.
    """
    with open(filelist_path, encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        # Blank lines (typically a trailing empty line at the end of the
        # file) would otherwise turn into `[""]` entries that only fail
        # later, when the individual fields are indexed.
        return [line.split(split_char) for line in stripped_lines if line]
| |
|
| |
|
class TextMelDataModule(LightningDataModule):
    """Data module that serves text/mel batches for training and validation.

    Every constructor argument is stored with `save_hyperparameters` and read
    back through `self.hparams` when the datasets and dataloaders are built.
    """

    def __init__(
        self,
        name,
        train_filelist_path,
        valid_filelist_path,
        batch_size,
        num_workers,
        pin_memory,
        cleaners,
        add_blank,
        n_spks,
        n_fft,
        n_feats,
        sample_rate,
        hop_length,
        win_length,
        f_min,
        f_max,
        data_statistics,
        seed,
    ):
        super().__init__()

        # Capture all constructor arguments so the hooks below can read them
        # from `self.hparams`. Must stay a direct call inside `__init__`.
        self.save_hyperparameters(logger=False)

    def _make_dataset(self, filelist_path):
        """Build a `TextMelDataset` for `filelist_path` from the stored hyperparameters."""
        hp = self.hparams
        return TextMelDataset(
            filelist_path=filelist_path,
            n_spks=hp.n_spks,
            cleaners=hp.cleaners,
            add_blank=hp.add_blank,
            n_fft=hp.n_fft,
            n_mels=hp.n_feats,
            sample_rate=hp.sample_rate,
            hop_length=hp.hop_length,
            win_length=hp.win_length,
            f_min=hp.f_min,
            f_max=hp.f_max,
            data_parameters=hp.data_statistics,
            seed=hp.seed,
        )

    def _make_dataloader(self, dataset, shuffle):
        """Wrap `dataset` in a `DataLoader` configured from the stored hyperparameters."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
            collate_fn=TextMelBatchCollate(self.hparams.n_spks),
        )

    def setup(self, stage: Optional[str] = None):
        """Create the training and validation datasets.

        Sets `self.trainset` and `self.validset` from the two configured
        filelists. `stage` is not used, so every call rebuilds both datasets.
        """
        self.trainset = self._make_dataset(self.hparams.train_filelist_path)
        self.validset = self._make_dataset(self.hparams.valid_filelist_path)

    def train_dataloader(self):
        """Return the shuffled dataloader over `self.trainset`."""
        return self._make_dataloader(self.trainset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over `self.validset`."""
        return self._make_dataloader(self.validset, shuffle=False)

    def teardown(self, stage: Optional[str] = None):
        """Clean-up hook; this data module has nothing to release."""

    def state_dict(self):
        """Return the extra checkpoint state, which is empty for this module."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Checkpoint-restore hook; `state_dict` is ignored."""
| |
|
| |
|
class TextMelDataset(torch.utils.data.Dataset):
    """Dataset yielding a token sequence, a normalised mel and a speaker id.

    Entries come from a filelist parsed by `parse_filelist`. With
    `n_spks > 1` each entry is read as `path, speaker, text`; otherwise as
    `path, text`.

    Args:
        filelist_path: Path of the filelist to load.
        n_spks: Number of speakers; selects the filelist layout described above.
        cleaners: Forwarded to `text_to_sequence`.
        add_blank: Whether token `0` is interspersed into each token sequence.
        n_fft, n_mels, sample_rate, hop_length, win_length, f_min, f_max:
            Forwarded to `mel_spectrogram`. `sample_rate` is also the rate
            every audio file is required to have.
        data_parameters: Mapping with `mel_mean` and `mel_std` passed to
            `normalize`. Defaults to mean 0 and std 1.
        seed: Seed for the one-off shuffle of the filelist entries.
    """

    def __init__(
        self,
        filelist_path,
        n_spks,
        cleaners,
        add_blank=True,
        n_fft=1024,
        n_mels=80,
        sample_rate=22050,
        hop_length=256,
        win_length=1024,
        f_min=0.0,
        f_max=8000,
        data_parameters=None,
        seed=None,
    ):
        self.filepaths_and_text = parse_filelist(filelist_path)
        self.n_spks = n_spks
        self.cleaners = cleaners
        self.add_blank = add_blank
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.win_length = win_length
        self.f_min = f_min
        self.f_max = f_max
        if data_parameters is not None:
            self.data_parameters = data_parameters
        else:
            self.data_parameters = {"mel_mean": 0, "mel_std": 1}
        # NOTE: this seeds the module-level `random` generator, so it also
        # affects any later use of `random` in the same process.
        random.seed(seed)
        random.shuffle(self.filepaths_and_text)

    def get_datapoint(self, filepath_and_text):
        """Turn one filelist entry into a sample dict.

        Args:
            filepath_and_text: The fields of one filelist entry.

        Returns:
            A dict with `x` (token tensor), `y` (normalised mel) and `spk`
            (the integer speaker id, or `None` when `n_spks <= 1`).
        """
        if self.n_spks > 1:
            filepath, spk, text = (
                filepath_and_text[0],
                int(filepath_and_text[1]),
                filepath_and_text[2],
            )
        else:
            filepath, text = filepath_and_text[0], filepath_and_text[1]
            spk = None

        text = self.get_text(text, add_blank=self.add_blank)
        mel = self.get_mel(filepath)

        return {"x": text, "y": mel, "spk": spk}

    def get_mel(self, filepath):
        """Load an audio file and return its normalised mel-spectrogram.

        Args:
            filepath: Path of the audio file to load with torchaudio.

        Returns:
            The squeezed output of `mel_spectrogram`, normalised with
            `mel_mean` and `mel_std` from `self.data_parameters`.

        Raises:
            ValueError: If the file's sample rate differs from
                `self.sample_rate`.
        """
        audio, sr = ta.load(filepath)
        # Raised explicitly instead of asserted so the check still runs
        # when Python is started with -O.
        if sr != self.sample_rate:
            raise ValueError(
                f"{filepath}: sample rate {sr} does not match the expected {self.sample_rate}"
            )
        mel = mel_spectrogram(
            audio,
            self.n_fft,
            self.n_mels,
            self.sample_rate,
            self.hop_length,
            self.win_length,
            self.f_min,
            self.f_max,
            center=False,
        ).squeeze()
        mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"])
        return mel

    def get_text(self, text, add_blank=None):
        """Convert raw text into an integer token tensor.

        Args:
            text: The text to convert with `text_to_sequence`.
            add_blank: Whether to intersperse token `0` into the sequence.
                `None` (the default) falls back to `self.add_blank`.

        Returns:
            A `torch.IntTensor` holding the token ids.
        """
        # An explicit argument now takes effect; omitting it keeps the
        # instance-level setting, as before.
        if add_blank is None:
            add_blank = self.add_blank
        text_norm = text_to_sequence(text, self.cleaners)
        if add_blank:
            text_norm = intersperse(text_norm, 0)
        return torch.IntTensor(text_norm)

    def __getitem__(self, index):
        """Return the sample dict for the entry at `index` of the shuffled filelist."""
        return self.get_datapoint(self.filepaths_and_text[index])

    def __len__(self):
        """Return the number of filelist entries."""
        return len(self.filepaths_and_text)
| |
|
| |
|
class TextMelBatchCollate:
    """Collate function that zero-pads samples with `x`, `y` and `spk` into one batch."""

    def __init__(self, n_spks):
        self.n_spks = n_spks

    def __call__(self, batch):
        """Pad and stack a list of sample dicts.

        Args:
            batch: Sequence of dicts with keys `x` (token tensor), `y` (mel
                tensor whose last two dimensions are features and length)
                and `spk`.

        Returns:
            A dict with:
              - `x`: long tensor `(B, max_x_len)`, zero-padded.
              - `x_lengths`: long tensor of the unpadded token lengths.
              - `y`: float32 tensor `(B, n_feats, padded_y_len)`, zero-padded,
                where `padded_y_len` is `fix_len_compatibility` applied to
                the longest mel in the batch.
              - `y_lengths`: long tensor of the unpadded mel lengths.
              - `spks`: long tensor of speaker ids when `n_spks > 1`, else `None`.
        """
        mels = [sample["y"] for sample in batch]
        tokens = [sample["x"] for sample in batch]
        mel_lens = [mel.shape[-1] for mel in mels]
        token_lens = [seq.shape[-1] for seq in tokens]

        batch_size = len(batch)
        padded_mel_len = fix_len_compatibility(max(mel_lens))
        n_feats = mels[0].shape[-2]

        y = torch.zeros((batch_size, n_feats, padded_mel_len), dtype=torch.float32)
        x = torch.zeros((batch_size, max(token_lens)), dtype=torch.long)
        rows = zip(mels, tokens, mel_lens, token_lens)
        for row, (mel, seq, mel_len, seq_len) in enumerate(rows):
            y[row, :, :mel_len] = mel
            x[row, :seq_len] = seq

        if self.n_spks > 1:
            spks = torch.tensor([sample["spk"] for sample in batch], dtype=torch.long)
        else:
            spks = None

        return {
            "x": x,
            "x_lengths": torch.tensor(token_lens, dtype=torch.long),
            "y": y,
            "y_lengths": torch.tensor(mel_lens, dtype=torch.long),
            "spks": spks,
        }
| |
|