Spaces:
Runtime error
Runtime error
| import argparse | |
| from typing import Any | |
| import tensorflow as tf | |
class EasyDict(dict):
    """A ``dict`` whose entries can also be used as attributes.

    ``cfg.key`` reads ``cfg["key"]``, ``cfg.key = v`` stores ``cfg["key"] = v``
    and ``del cfg.key`` removes ``cfg["key"]``. A missing key read as an
    attribute raises ``AttributeError`` (not ``KeyError``) so ``hasattr`` and
    ``getattr(obj, name, default)`` keep working.
    """

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails, so real dict
        # methods (keys, items, ...) are never shadowed by entries.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        else:
            return value

    def __setattr__(self, name: str, value: Any) -> None:
        # Store as an entry instead of in the instance __dict__.
        self.__setitem__(name, value)

    def __delattr__(self, name: str) -> None:
        # Deleting an absent key propagates dict's KeyError.
        self.__delitem__(name)
def str2bool(v):
    """Convert a command-line string to a bool, for use as an argparse ``type``.

    Accepted (case-insensitive) spellings are ``yes/true/t/y/1`` for True and
    ``no/false/f/n/0`` for False. A value that is already a ``bool`` is
    returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse reports
            a clean usage error.
    """
    if isinstance(v, bool):
        return v
    truthy = ("yes", "true", "t", "y", "1")
    if v.lower() in truthy:
        return True
    falsy = ("no", "false", "f", "n", "0")
    if v.lower() in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def _build_params_parser():
    """Return an ArgumentParser declaring every model/training/runtime option."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hop",
        type=int,
        default=256,
        help="Hop size (window size = 4*hop)",
    )
    parser.add_argument(
        "--mel_bins",
        type=int,
        default=256,
        help="Mel bins in mel-spectrograms",
    )
    parser.add_argument(
        "--sr",
        type=int,
        default=44100,
        help="Sampling Rate",
    )
    parser.add_argument(
        "--small",
        type=str2bool,
        default=False,
        help="If True, use model with shorter available context, useful for small datasets",
    )
    parser.add_argument(
        "--latdepth",
        type=int,
        default=64,
        help="Depth of generated latent vectors",
    )
    parser.add_argument(
        "--coorddepth",
        type=int,
        default=64,
        help="Dimension of latent coordinate and style random vectors",
    )
    parser.add_argument(
        "--max_lat_len",
        type=int,
        default=512,
        help="Length of .npy arrays used for training",
    )
    parser.add_argument(
        "--base_channels",
        type=int,
        default=128,
        help="Base channels for generator and discriminator architectures",
    )
    parser.add_argument(
        "--shape",
        type=int,
        default=128,
        help="Length of spectrograms time axis",
    )
    # NOTE(review): the help text says --window must divide --shape, but
    # nothing here validates it; consider checking shape % window == 0.
    parser.add_argument(
        "--window",
        type=int,
        default=64,
        help="Generator spectrogram window (must divide shape)",
    )
    parser.add_argument(
        "--bs",
        type=int,
        default=32,
        help="Batch size",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0001,
        help="Learning Rate",
    )
    parser.add_argument(
        "--gp_max_weight",
        type=float,
        default=10.0,
        help="Maximum allowed R1 gradient penalty loss weight. The weight will self-adapt if high values are not needed for stable training.",
    )
    parser.add_argument(
        "--totsamples",
        type=int,
        default=300000,
        help="Max samples chosen per epoch",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=250,
        help="Number of epochs",
    )
    parser.add_argument(
        "--save_every",
        type=int,
        default=1,
        help="Save after x epochs",
    )
    parser.add_argument(
        "--mu_rescale",
        type=float,
        default=-25.0,
        help="Spectrogram mu used to normalize",
    )
    parser.add_argument(
        "--sigma_rescale",
        type=float,
        default=75.0,
        help="Spectrogram sigma used to normalize",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="checkpoints",
        help="Path where to save checkpoints",
    )
    parser.add_argument(
        "--train_path",
        type=str,
        default="training_samples",
        help="Path of training samples",
    )
    parser.add_argument(
        "--dec_path",
        type=str,
        default="checkpoints/ae",
        help="Path of pretrained decoders weights",
    )
    # The default is the *string* "None", not the None object; presumably
    # downstream code compares against "None" -- confirm before changing.
    parser.add_argument(
        "--load_path",
        type=str,
        default="None",
        help="If not None, load models weights from this path",
    )
    parser.add_argument(
        "--base_path",
        type=str,
        default="checkpoints",
        help="Path where pretrained models are downloaded",
    )
    parser.add_argument(
        "--log_path",
        type=str,
        default="logs",
        help="Path where to save tensorboard logs",
    )
    parser.add_argument(
        "--testing",
        type=str2bool,
        default=False,
        help="True if optimizers weight do not need to be loaded",
    )
    parser.add_argument(
        "--cpu",
        type=str2bool,
        default=False,
        help="True if you wish to use cpu",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str2bool,
        default=True,
        help="True if your GPU supports mixed precision",
    )
    parser.add_argument(
        "--xla",
        type=str2bool,
        default=True,
        help="True if you wish to improve training speed with XLA",
    )
    parser.add_argument(
        "--share_gradio",
        type=str2bool,
        default=False,
        help="True if you wish to create a public URL for the Gradio interface",
    )
    return parser


def _configure_params_device(args):
    """Choose CPU or GPU, set ``args.datatype`` and print the choice.

    Falls back to CPU (and disables mixed precision) when no GPU is visible or
    ``args.cpu`` is set. Otherwise mixed precision switches the dtype to
    float16.
    """
    args.datatype = tf.float32
    gpuls = tf.config.list_physical_devices("GPU")
    if not gpuls or args.cpu:
        args.cpu = True
        args.mixed_precision = False
        # Hide all GPUs from TensorFlow so every op is placed on the CPU.
        tf.config.set_visible_devices([], "GPU")
        print()
        print("Using CPU...")
        print()
    if args.mixed_precision:
        args.datatype = tf.float16
        print()
        print("Using GPU with mixed precision enabled...")
        print()
    if not args.mixed_precision and not args.cpu:
        print()
        print("Using GPU without mixed precision...")
        print()


def params_args(args, argv=None):
    """Populate ``args`` from the command line and derive runtime settings.

    Every parsed option is copied onto ``args`` as an attribute. Then the
    following are added:
    ``latlen`` (128 when ``--small`` else 256), ``coordlen`` (``latlen // 2 * 3``)
    and ``datatype`` (``tf.float16`` with mixed precision, else ``tf.float32``).
    ``cpu`` and ``mixed_precision`` may also be overridden when no GPU is found.

    Args:
        args: Object receiving the options via ``setattr`` (e.g. an
            ``EasyDict``).
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, as before.

    Returns:
        The same ``args`` object, mutated in place.

    Raises:
        SystemExit: via argparse on invalid or unknown arguments.
    """
    tmp_args = _build_params_parser().parse_args(argv)
    # Copy every declared option generically (in declaration order) so new
    # options need to be added in only one place.
    for key, value in vars(tmp_args).items():
        setattr(args, key, value)
    args.latlen = 128 if args.small else 256
    args.coordlen = (args.latlen // 2) * 3
    print()
    _configure_params_device(args)
    return args
def parse_args():
    """Parse the process command line into a new attribute-accessible config.

    Returns:
        A fresh ``EasyDict`` filled in by ``params_args``.
    """
    return params_args(EasyDict())