Spaces:
Sleeping
Sleeping
r"""
Compute the mean and standard deviation of the mel spectrograms in a dataset's training split and
store them in a JSON file, so that the model can load them for data normalisation when needed.
The dataset is described by the Hydra data configs under configs/data.
"""
| import argparse | |
| import json | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import rootutils | |
| import torch | |
| from hydra import compose, initialize | |
| from omegaconf import open_dict | |
| from tqdm.auto import tqdm | |
| from matcha.data.text_mel_datamodule import TextMelDataModule | |
| from matcha.utils.logging_utils import pylogger | |
# Module-level logger obtained from the project's pylogger helper; main() uses it to report progress.
log = pylogger.get_pylogger(__name__)
def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int) -> dict:
    """Compute the global mean and standard deviation of the mel spectrograms in a data loader.

    The statistics are intended for data normalisation. Each batch must be a mapping holding the
    (padded) mel spectrograms under ``"y"`` and the number of valid frames per item under
    ``"y_lengths"``. Every valid frame is counted as ``out_channels`` values.

    NOTE(review): padded positions are included in the sums, so this assumes the collate function
    pads ``"y"`` with zeros (zeros change neither sum) — confirm against the data module.

    Args:
        data_loader (torch.utils.data.DataLoader): loader yielding batches as described above.
        out_channels (int): mel spectrogram channels, i.e. number of values per frame.

    Returns:
        dict: ``{"mel_mean": float, "mel_std": float}``.

    Raises:
        ValueError: if the loader yields no frames or ``out_channels`` is not positive.
    """
    # Accumulate in Python numbers (float64 / int). The variance below is E[x^2] - E[x]^2, which
    # loses precision to cancellation when accumulated in float32 over a large dataset. Calling
    # .item() also keeps the accumulators independent of the tensors' device.
    total_mel_sum = 0.0
    total_mel_sq_sum = 0.0
    total_mel_len = 0
    for batch in tqdm(data_loader, leave=False):
        mels = batch["y"].double()
        total_mel_sum += mels.sum().item()
        total_mel_sq_sum += mels.square().sum().item()
        total_mel_len += int(batch["y_lengths"].sum().item())

    n_values = total_mel_len * out_channels
    if n_values <= 0:
        raise ValueError(
            f"Cannot compute mel statistics: got {total_mel_len} frames with {out_channels} channels"
        )

    data_mean = total_mel_sum / n_values
    # Rounding can push the difference slightly below zero for near-constant data; clamp so the
    # square root never yields NaN.
    data_var = max(total_mel_sq_sum / n_values - data_mean**2, 0.0)
    data_std = data_var**0.5
    return {"mel_mean": data_mean, "mel_std": data_std}
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input-config",
        type=str,
        default="vctk.yaml",
        help="The name of the yaml config file under configs/data",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=256,
        help="Can have increased batch size for faster computation",
    )
    parser.add_argument(
        "-f",
        "--force",
        action="store_true",  # implies default=False and an optional flag
        help="force overwrite the file",
    )
    return parser.parse_args()


def _load_data_config(input_config: str, batch_size: int):
    """Compose the Hydra data config ``input_config`` and adapt it for ``TextMelDataModule(**cfg)``."""
    # Hydra resolves config_path relative to this source file.
    with initialize(version_base="1.3", config_path="../../configs/data"):
        cfg = compose(config_name=input_config, return_hydra_config=True, overrides=[])

    root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")

    with open_dict(cfg):
        # `hydra` is only present because of return_hydra_config=True and `_target_` is Hydra's
        # instantiation key; neither is forwarded as a data-module argument.
        del cfg["hydra"]
        del cfg["_target_"]
        # No normalisation statistics exist yet — they are what this script computes.
        cfg["data_statistics"] = None
        cfg["seed"] = 1234
        cfg["batch_size"] = batch_size
        # Resolve filelist paths against the project root (an absolute path is kept as-is by join).
        cfg["train_filelist_path"] = os.path.join(root_path, cfg["train_filelist_path"])
        cfg["valid_filelist_path"] = os.path.join(root_path, cfg["valid_filelist_path"])
    return cfg


def main():
    """Compute mel statistics for a data config and write them to ``<config name>.json``.

    The output path is the config name with its suffix replaced by ``.json``, relative to the
    current working directory. Exits with status 1 if that file exists and ``-f`` was not given.
    """
    args = _parse_args()

    output_file = Path(args.input_config).with_suffix(".json")
    if output_file.exists() and not args.force:
        print("File already exists. Use -f to force overwrite")
        sys.exit(1)

    cfg = _load_data_config(args.input_config, args.batch_size)

    text_mel_datamodule = TextMelDataModule(**cfg)
    text_mel_datamodule.setup()
    data_loader = text_mel_datamodule.train_dataloader()
    log.info("Dataloader loaded! Now computing stats...")

    params = compute_data_statistics(data_loader, cfg["n_feats"])
    print(params)

    # Context manager guarantees the JSON is flushed and the handle closed.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(params, f)


if __name__ == "__main__":
    main()