Spaces:
Runtime error
Runtime error
import random
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
import torch
import torchaudio as ta
from lightning import LightningDataModule
from torch.utils.data.dataloader import DataLoader

from matcha.text import text_to_sequence
from matcha.utils.audio import mel_spectrogram
from matcha.utils.model import fix_len_compatibility, normalize
from matcha.utils.utils import intersperse
def parse_filelist(filelist_path, split_char="|"):
    """Read a filelist and return one list of fields per line, split on `split_char`."""
    with open(filelist_path, encoding="utf-8") as filelist:
        return [line.strip().split(split_char) for line in filelist]
class TextMelDataModule(LightningDataModule):
    """Lightning data module that builds the train/validation text-to-mel
    datasets from filelists and serves them through padded-batch dataloaders.

    All constructor arguments are stored via ``save_hyperparameters`` and read
    back through ``self.hparams``.
    """

    def __init__(  # pylint: disable=unused-argument
        self,
        name,
        train_filelist_path,
        valid_filelist_path,
        batch_size,
        num_workers,
        pin_memory,
        cleaners,
        add_blank,
        n_spks,
        n_fft,
        n_feats,
        sample_rate,
        hop_length,
        win_length,
        f_min,
        f_max,
        data_statistics,
        seed,
        load_durations,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

    def _build_dataset(self, filelist_path):
        """Construct a TextMelDataset for `filelist_path` from the saved hparams."""
        return TextMelDataset(
            filelist_path,
            self.hparams.n_spks,
            self.hparams.cleaners,
            self.hparams.add_blank,
            self.hparams.n_fft,
            self.hparams.n_feats,
            self.hparams.sample_rate,
            self.hparams.hop_length,
            self.hparams.win_length,
            self.hparams.f_min,
            self.hparams.f_max,
            self.hparams.data_statistics,
            self.hparams.seed,
            self.hparams.load_durations,
        )

    def setup(self, stage: Optional[str] = None):  # pylint: disable=unused-argument
        """Load data. Set variables: `self.trainset`, `self.validset`.

        This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
        careful not to execute things like random split twice!
        """
        # Load datasets only if not loaded already: setup() can run more than
        # once (fit + test), and rebuilding would reshuffle the filelists.
        if getattr(self, "trainset", None) is None:
            # pylint: disable-next=attribute-defined-outside-init
            self.trainset = self._build_dataset(self.hparams.train_filelist_path)
        if getattr(self, "validset", None) is None:
            # pylint: disable-next=attribute-defined-outside-init
            self.validset = self._build_dataset(self.hparams.valid_filelist_path)

    def train_dataloader(self):
        """Dataloader over the training set (shuffled, padded batches)."""
        return DataLoader(
            dataset=self.trainset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
            collate_fn=TextMelBatchCollate(self.hparams.n_spks),
        )

    def val_dataloader(self):
        """Dataloader over the validation set (fixed order, padded batches)."""
        return DataLoader(
            dataset=self.validset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
            collate_fn=TextMelBatchCollate(self.hparams.n_spks),
        )

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass  # pylint: disable=unnecessary-pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass  # pylint: disable=unnecessary-pass
class TextMelDataset(torch.utils.data.Dataset):
    """Dataset yielding token-id tensors and normalized mel spectrograms.

    Each filelist line is ``wav_path|text`` (or ``wav_path|spk_id|text`` when
    ``n_spks > 1``). Mels are computed on the fly from the audio; optional
    precomputed durations are loaded from ``<data_dir>/durations/<name>.npy``.
    """

    def __init__(
        self,
        filelist_path,
        n_spks,
        cleaners,
        add_blank=True,
        n_fft=1024,
        n_mels=80,
        sample_rate=22050,
        hop_length=256,
        win_length=1024,
        f_min=0.0,
        f_max=8000,
        data_parameters=None,
        seed=None,
        load_durations=False,
    ):
        self.filepaths_and_text = parse_filelist(filelist_path)
        self.n_spks = n_spks
        self.cleaners = cleaners
        self.add_blank = add_blank
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.win_length = win_length
        self.f_min = f_min
        self.f_max = f_max
        self.load_durations = load_durations
        # Fall back to identity normalization when no statistics are given.
        if data_parameters is not None:
            self.data_parameters = data_parameters
        else:
            self.data_parameters = {"mel_mean": 0, "mel_std": 1}

        # Shuffle once at construction so the on-disk order is not the epoch order.
        random.seed(seed)
        random.shuffle(self.filepaths_and_text)

    def get_datapoint(self, filepath_and_text):
        """Build one sample dict (tokens, mel, speaker, durations) from a filelist entry."""
        if self.n_spks > 1:
            filepath, spk, text = (
                filepath_and_text[0],
                int(filepath_and_text[1]),
                filepath_and_text[2],
            )
        else:
            filepath, text = filepath_and_text[0], filepath_and_text[1]
            spk = None  # single-speaker corpora carry no speaker id

        text, cleaned_text = self.get_text(text, add_blank=self.add_blank)
        mel = self.get_mel(filepath)
        durations = self.get_durations(filepath, text) if self.load_durations else None

        return {"x": text, "y": mel, "spk": spk, "filepath": filepath, "x_text": cleaned_text, "durations": durations}

    def get_durations(self, filepath, text):
        """Load precomputed per-token durations for `filepath`.

        Raises FileNotFoundError when the durations file is missing, and
        asserts the duration count matches the token count.
        """
        filepath = Path(filepath)
        data_dir, name = filepath.parent.parent, filepath.stem
        dur_loc = data_dir / "durations" / f"{name}.npy"

        try:
            durs = torch.from_numpy(np.load(dur_loc).astype(int))
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Tried loading the durations but durations didn't exist at {dur_loc}, make sure you've generated the durations first using: python matcha/utils/get_durations_from_trained_model.py \n"
            ) from e

        assert len(durs) == len(text), f"Length of durations {len(durs)} and text {len(text)} do not match"

        return durs

    def get_mel(self, filepath):
        """Compute the normalized mel spectrogram of the audio at `filepath`."""
        audio, sr = ta.load(filepath)
        assert sr == self.sample_rate
        mel = mel_spectrogram(
            audio,
            self.n_fft,
            self.n_mels,
            self.sample_rate,
            self.hop_length,
            self.win_length,
            self.f_min,
            self.f_max,
            center=False,
        ).squeeze()
        # Normalize with the dataset statistics (identity when none were given).
        mel = normalize(mel, self.data_parameters["mel_mean"], self.data_parameters["mel_std"])
        return mel

    def get_text(self, text, add_blank=True):
        """Convert `text` to an IntTensor of token ids; also return the cleaned text.

        When `add_blank` is true, a blank token (0) is interspersed between ids.
        """
        text_norm, cleaned_text = text_to_sequence(text, self.cleaners)
        # Fix: honor the `add_blank` parameter (previously it was ignored in
        # favor of self.add_blank; the in-file caller passes the same value).
        if add_blank:
            text_norm = intersperse(text_norm, 0)
        text_norm = torch.IntTensor(text_norm)
        return text_norm, cleaned_text

    def __getitem__(self, index):
        datapoint = self.get_datapoint(self.filepaths_and_text[index])
        return datapoint

    def __len__(self):
        return len(self.filepaths_and_text)
class TextMelBatchCollate:
    """Collate function: zero-pads a list of dataset samples into batch tensors."""

    def __init__(self, n_spks):
        self.n_spks = n_spks

    def __call__(self, batch):
        batch_size = len(batch)
        # Pad mel frames up to a model-compatible length; tokens to the batch max.
        y_max_length = fix_len_compatibility(max(item["y"].shape[-1] for item in batch))
        x_max_length = max(item["x"].shape[-1] for item in batch)
        n_feats = batch[0]["y"].shape[-2]

        y = torch.zeros((batch_size, n_feats, y_max_length), dtype=torch.float32)
        x = torch.zeros((batch_size, x_max_length), dtype=torch.long)
        durations = torch.zeros((batch_size, x_max_length), dtype=torch.long)

        y_lengths, x_lengths = [], []
        spk_ids, filepaths, x_texts = [], [], []
        for idx, item in enumerate(batch):
            mel, tokens = item["y"], item["x"]
            y_lengths.append(mel.shape[-1])
            x_lengths.append(tokens.shape[-1])
            y[idx, :, : mel.shape[-1]] = mel
            x[idx, : tokens.shape[-1]] = tokens
            spk_ids.append(item["spk"])
            filepaths.append(item["filepath"])
            x_texts.append(item["x_text"])
            dur = item["durations"]
            if dur is not None:
                durations[idx, : dur.shape[-1]] = dur

        y_lengths = torch.tensor(y_lengths, dtype=torch.long)
        x_lengths = torch.tensor(x_lengths, dtype=torch.long)
        # Speaker ids only exist for multi-speaker setups.
        spks = torch.tensor(spk_ids, dtype=torch.long) if self.n_spks > 1 else None

        return {
            "x": x,
            "x_lengths": x_lengths,
            "y": y,
            "y_lengths": y_lengths,
            "spks": spks,
            "filepaths": filepaths,
            "x_texts": x_texts,
            # An all-zero durations tensor means no sample carried durations.
            "durations": durations if not torch.eq(durations, 0).all() else None,
        }