r"""
The file creates a JSON file where the values needed for loading of the dataset (mel mean and standard deviation)
are stored, so that the model can load them when needed.

Parameters from the yaml data config under configs/data will be used.
"""
| import argparse |
| import json |
| import os |
| import sys |
| from pathlib import Path |
|
|
| import rootutils |
| import torch |
| from hydra import compose, initialize |
| from omegaconf import open_dict |
| from tqdm.auto import tqdm |
|
|
| from matcha.data.text_mel_datamodule import TextMelDataModule |
| from matcha.utils.logging_utils import pylogger |
|
|
# Module-level logger; main() uses it to report progress once the dataloader is ready.
log = pylogger.get_pylogger(__name__)
|
|
|
|
def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
    """Compute the global mean and standard deviation of the mel values in a data loader.

    Each batch must be a mapping with a ``"y"`` tensor (the mels) and a ``"y_lengths"``
    tensor (their lengths). The sums run over the whole ``"y"`` tensor while the divisor
    is ``sum(y_lengths) * out_channels``, so the result is exact only if any positions
    beyond the given lengths hold zeros (assumed to be zero padding -- confirm against
    the collate function).

    Args:
        data_loader (torch.utils.data.DataLoader): iterable of batches as described above.
        out_channels (int): mel spectrogram channels.

    Returns:
        dict: ``{"mel_mean": float, "mel_std": float}``.

    Raises:
        ValueError: if there is nothing to average over (no batches, all lengths zero,
            or ``out_channels`` is 0).
    """
    total_mel_sum = 0
    total_mel_sq_sum = 0
    total_mel_len = 0

    for batch in tqdm(data_loader, leave=False):
        mels = batch["y"]
        mel_lengths = batch["y_lengths"]

        total_mel_len += torch.sum(mel_lengths)
        # Accumulate in float64: the variance below is a difference of two large,
        # nearly equal quantities, which loses precision in lower-precision sums.
        total_mel_sum += torch.sum(mels, dtype=torch.float64)
        total_mel_sq_sum += torch.sum(torch.pow(mels, 2), dtype=torch.float64)

    num_values = total_mel_len * out_channels
    if int(num_values) == 0:
        # Without this guard an empty loader raises a bare ZeroDivisionError and
        # all-zero lengths silently produce NaN statistics.
        raise ValueError(
            "Cannot compute statistics over zero values: the data loader yielded no mel frames "
            "or out_channels is 0."
        )

    data_mean = total_mel_sum / num_values
    variance = (total_mel_sq_sum / num_values) - torch.pow(data_mean, 2)
    # Rounding can push a near-zero variance slightly negative; clamp so sqrt never yields NaN.
    data_std = torch.sqrt(torch.clamp(variance, min=0.0))

    return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}
|
|
|
|
def main():
    """Compute mel statistics for a data config and save them as JSON.

    Command-line options:
        -i / --input-config: name of the yaml config file under ``configs/data``
            (default ``vctk.yaml``).
        -b / --batch-size: batch size used while iterating the training data (default 256).
        -f / --force: overwrite the output file if it already exists.

    The output path is the config name with its suffix replaced by ``.json`` (a relative
    path, so it is resolved against the current working directory). If that file exists
    and ``--force`` is not given, a message is printed and the process exits with status 1.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="vctk.yaml",
        help="The name of the yaml config file under configs/data",
    )

    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=256,
        help="Can have increased batch size for faster computation",
    )

    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        default=False,
        required=False,
        help="force overwrite the file",
    )
    args = parser.parse_args()
    output_file = Path(args.input_config).with_suffix(".json")

    # Refuse to clobber existing statistics unless explicitly asked to.
    if os.path.exists(output_file) and not args.force:
        print("File already exists. Use -f to force overwrite")
        sys.exit(1)

    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        # cfg is expanded into TextMelDataModule(**cfg) below, so these two keys are
        # removed first (presumably they are not constructor arguments -- confirm).
        del cfg["hydra"]
        del cfg["_target_"]
        # NOTE(review): presumably None makes the datamodule yield unnormalised mels,
        # which is what the statistics must be computed on -- confirm in TextMelDataModule.
        cfg["data_statistics"] = None
        cfg["seed"] = 1234
        cfg["batch_size"] = args.batch_size
        # Anchor the filelist paths from the config at the project root.
        cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
        cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()
    data_loader = text_mel_datamodule.train_dataloader()
    log.info("Dataloader loaded! Now computing stats...")
    params = compute_data_statistics(data_loader, cfg["n_feats"])
    print(params)
    # Use a context manager so the file is flushed and closed deterministically.
    with open(output_file, "w", encoding="utf-8") as stats_file:
        json.dump(params, stats_file)
|
|
|
|
# Run the statistics computation only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|