Spaces:
Build error
Build error
| import argparse | |
| from typing import Any | |
| import tensorflow as tf | |
class EasyDict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes."""

    def __getattr__(self, name: str) -> Any:
        # Reached only after normal attribute lookup has failed, so the
        # regular dict methods are never shadowed by stored keys.
        try:
            value = self[name]
        except KeyError:
            raise AttributeError(name)
        else:
            return value

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute assignment is stored as a dictionary item.
        self.__setitem__(name, value)

    def __delattr__(self, name: str) -> None:
        # Attribute deletion removes the item; a missing key raises KeyError.
        self.__delitem__(name)
def str2bool(v):
    """Convert a command-line style truth value to a ``bool``.

    Suitable as an ``argparse`` ``type=`` callable.

    Args:
        v: A ``bool`` (returned unchanged) or a string. Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        ``True`` for "yes", "true", "t", "y" or "1"; ``False`` for "no",
        "false", "f", "n" or "0".

    Raises:
        argparse.ArgumentTypeError: If ``v`` is neither a ``bool`` nor one of
            the recognised strings.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Reject non-strings explicitly instead of failing on a missing
        # ``.lower()`` with an unrelated AttributeError.
        raise argparse.ArgumentTypeError("Boolean value expected.")
    # Normalise once so both membership tests see the same token.
    token = v.strip().lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def params_args(args):
    """Parse the command line and store the options, plus derived settings, on ``args``.

    Each declared option is copied onto ``args`` as an attribute, in
    declaration order. The following derived attributes are then set:

    * ``latlen``: 128 when ``small`` is true, otherwise 256.
    * ``coordlen``: ``(latlen // 2) * 3``.
    * ``datatype``: ``tf.float16`` when mixed precision stays enabled,
      otherwise ``tf.float32``.

    When no GPU is listed by TensorFlow, or ``cpu`` was requested, ``cpu`` is
    forced to ``True``, ``mixed_precision`` to ``False`` and all GPUs are
    hidden from TensorFlow.

    Args:
        args: Object that accepts attribute assignment; it is mutated in place.

    Returns:
        The same ``args`` object.
    """

    def announce(message):
        # Status lines are framed by one empty line on each side.
        print()
        print(message)
        print()

    # (flag, type converter, default, help text) for every supported option.
    option_table = (
        ("--hop", int, 256, "Hop size (window size = 4*hop)"),
        ("--mel_bins", int, 256, "Mel bins in mel-spectrograms"),
        ("--sr", int, 44100, "Sampling Rate"),
        (
            "--small",
            str2bool,
            False,
            "If True, use model with shorter available context, useful for small datasets",
        ),
        ("--latdepth", int, 64, "Depth of generated latent vectors"),
        (
            "--coorddepth",
            int,
            64,
            "Dimension of latent coordinate and style random vectors",
        ),
        (
            "--base_channels",
            int,
            128,
            "Base channels for generator and discriminator architectures",
        ),
        ("--shape", int, 128, "Length of spectrograms time axis"),
        ("--window", int, 64, "Generator spectrogram window (must divide shape)"),
        ("--mu_rescale", float, -25.0, "Spectrogram mu used to normalize"),
        ("--sigma_rescale", float, 75.0, "Spectrogram sigma used to normalize"),
        (
            "--load_path_1",
            str,
            "checkpoints/techno/",
            "Path of pretrained networks weights 1",
        ),
        (
            "--load_path_2",
            str,
            "checkpoints/metal/",
            "Path of pretrained networks weights 2",
        ),
        (
            "--load_path_3",
            str,
            "checkpoints/misc/",
            "Path of pretrained networks weights 3",
        ),
        ("--dec_path", str, "checkpoints/ae/", "Path of pretrained decoders weights"),
        (
            "--testing",
            str2bool,
            True,
            "True if optimizers weight do not need to be loaded",
        ),
        ("--cpu", str2bool, False, "True if you wish to use cpu"),
        (
            "--mixed_precision",
            str2bool,
            True,
            "True if your GPU supports mixed precision",
        ),
    )

    parser = argparse.ArgumentParser()
    destinations = []
    for flag, converter, default, description in option_table:
        action = parser.add_argument(
            flag, type=converter, default=default, help=description
        )
        destinations.append(action.dest)

    # With no explicit argument list, argparse reads the process command line.
    parsed = parser.parse_args()
    for dest in destinations:
        setattr(args, dest, getattr(parsed, dest))

    args.latlen = 128 if args.small else 256
    args.coordlen = (args.latlen // 2) * 3

    print()

    args.datatype = tf.float32
    gpus = tf.config.list_physical_devices("GPU")

    if args.cpu or not gpus:
        # CPU was requested or no GPU is available: disable GPU features.
        args.cpu = True
        args.mixed_precision = False
        tf.config.set_visible_devices([], "GPU")
        announce("Using CPU...")
    elif args.mixed_precision:
        args.datatype = tf.float16
        announce("Using GPU with mixed precision enabled...")
    else:
        announce("Using GPU without mixed precision...")

    return args
def parse_args():
    """Return a fresh ``EasyDict`` populated by ``params_args``."""
    return params_args(EasyDict())