Spaces:
Runtime error
Runtime error
Commit
·
41de683
1
Parent(s):
8bce6ea
generation additions
Browse files- MusicModel/decode.py +21 -0
- MusicModel/encode.py +24 -0
- MusicModel/parse/parse_decode.py +210 -0
- MusicModel/parse/parse_encode.py +217 -0
- MusicModel/parse/parse_generate.py +217 -0
- MusicModel/parse/parse_test.py +196 -0
- MusicModel/utils_encode.py +214 -0
- requirements.txt +10 -2
MusicModel/decode.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
| 4 |
+
|
| 5 |
+
from parse.parse_decode import parse_args
|
| 6 |
+
from models import Models_functions
|
| 7 |
+
from utils import Utils_functions
|
| 8 |
+
|
| 9 |
+
if __name__ == "__main__":
|
| 10 |
+
|
| 11 |
+
# parse args
|
| 12 |
+
args = parse_args()
|
| 13 |
+
|
| 14 |
+
# initialize networks
|
| 15 |
+
M = Models_functions(args)
|
| 16 |
+
M.download_networks()
|
| 17 |
+
models_ls = M.get_networks()
|
| 18 |
+
|
| 19 |
+
# encode samples
|
| 20 |
+
U = Utils_functions(args)
|
| 21 |
+
U.decode_path(models_ls)
|
MusicModel/encode.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
| 4 |
+
|
| 5 |
+
from parse.parse_encode import parse_args
|
| 6 |
+
from models import Models_functions
|
| 7 |
+
from utils_encode import UtilsEncode_functions
|
| 8 |
+
|
| 9 |
+
if __name__ == "__main__":
|
| 10 |
+
|
| 11 |
+
# parse args
|
| 12 |
+
args = parse_args()
|
| 13 |
+
|
| 14 |
+
# initialize networks
|
| 15 |
+
M = Models_functions(args)
|
| 16 |
+
M.download_networks()
|
| 17 |
+
models_ls = M.get_networks()
|
| 18 |
+
|
| 19 |
+
# encode samples
|
| 20 |
+
U = UtilsEncode_functions(args)
|
| 21 |
+
if args.whole:
|
| 22 |
+
U.compress_whole_files(models_ls)
|
| 23 |
+
else:
|
| 24 |
+
U.compress_files(models_ls)
|
MusicModel/parse/parse_decode.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from typing import Any
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EasyDict(dict):
|
| 7 |
+
def __getattr__(self, name: str) -> Any:
|
| 8 |
+
try:
|
| 9 |
+
return self[name]
|
| 10 |
+
except KeyError:
|
| 11 |
+
raise AttributeError(name)
|
| 12 |
+
|
| 13 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
| 14 |
+
self[name] = value
|
| 15 |
+
|
| 16 |
+
def __delattr__(self, name: str) -> None:
|
| 17 |
+
del self[name]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def str2bool(v):
|
| 21 |
+
if isinstance(v, bool):
|
| 22 |
+
return v
|
| 23 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
| 24 |
+
return True
|
| 25 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
| 26 |
+
return False
|
| 27 |
+
else:
|
| 28 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def params_args(args):
|
| 32 |
+
parser = argparse.ArgumentParser()
|
| 33 |
+
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--hop",
|
| 36 |
+
type=int,
|
| 37 |
+
default=256,
|
| 38 |
+
help="Hop size (window size = 4*hop)",
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--mel_bins",
|
| 42 |
+
type=int,
|
| 43 |
+
default=256,
|
| 44 |
+
help="Mel bins in mel-spectrograms",
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--sr",
|
| 48 |
+
type=int,
|
| 49 |
+
default=44100,
|
| 50 |
+
help="Sampling Rate",
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--small",
|
| 54 |
+
type=str2bool,
|
| 55 |
+
default=False,
|
| 56 |
+
help="If True, use model with shorter available context, useful for small datasets",
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--latdepth",
|
| 60 |
+
type=int,
|
| 61 |
+
default=64,
|
| 62 |
+
help="Depth of generated latent vectors",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--coorddepth",
|
| 66 |
+
type=int,
|
| 67 |
+
default=64,
|
| 68 |
+
help="Dimension of latent coordinate and style random vectors",
|
| 69 |
+
)
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--max_lat_len",
|
| 72 |
+
type=int,
|
| 73 |
+
default=512,
|
| 74 |
+
help="Length of latent sequences: a random on-the-fly crop will be used for training",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--base_channels",
|
| 78 |
+
type=int,
|
| 79 |
+
default=128,
|
| 80 |
+
help="Base channels for generator and discriminator architectures",
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--shape",
|
| 84 |
+
type=int,
|
| 85 |
+
default=128,
|
| 86 |
+
help="Length of spectrograms time axis",
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--window",
|
| 90 |
+
type=int,
|
| 91 |
+
default=64,
|
| 92 |
+
help="Generator spectrogram window (must divide shape)",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--mu_rescale",
|
| 96 |
+
type=float,
|
| 97 |
+
default=-25.0,
|
| 98 |
+
help="Spectrogram mu used to normalize",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--sigma_rescale",
|
| 102 |
+
type=float,
|
| 103 |
+
default=75.0,
|
| 104 |
+
help="Spectrogram sigma used to normalize",
|
| 105 |
+
)
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
"--files_path",
|
| 108 |
+
type=str,
|
| 109 |
+
default="audio_samples/",
|
| 110 |
+
help="Path of compressed latent samples to decode",
|
| 111 |
+
)
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--save_path",
|
| 114 |
+
type=str,
|
| 115 |
+
default="decoded_samples/",
|
| 116 |
+
help="Path where decoded audio files will be saved",
|
| 117 |
+
)
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--dec_path",
|
| 120 |
+
type=str,
|
| 121 |
+
default="checkpoints/ae",
|
| 122 |
+
help="Path of pretrained decoders weights",
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--load_path",
|
| 126 |
+
type=str,
|
| 127 |
+
default="None",
|
| 128 |
+
help="If not None, load models weights from this path",
|
| 129 |
+
)
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--base_path",
|
| 132 |
+
type=str,
|
| 133 |
+
default="checkpoints",
|
| 134 |
+
help="Path where pretrained models are downloaded",
|
| 135 |
+
)
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--testing",
|
| 138 |
+
type=str2bool,
|
| 139 |
+
default=True,
|
| 140 |
+
help="True if optimizers weight do not need to be loaded",
|
| 141 |
+
)
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
"--cpu",
|
| 144 |
+
type=str2bool,
|
| 145 |
+
default=False,
|
| 146 |
+
help="True if you wish to use cpu",
|
| 147 |
+
)
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
"--mixed_precision",
|
| 150 |
+
type=str2bool,
|
| 151 |
+
default=True,
|
| 152 |
+
help="True if your GPU supports mixed precision",
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
tmp_args = parser.parse_args()
|
| 156 |
+
|
| 157 |
+
args.hop = tmp_args.hop
|
| 158 |
+
args.mel_bins = tmp_args.mel_bins
|
| 159 |
+
args.sr = tmp_args.sr
|
| 160 |
+
args.small = tmp_args.small
|
| 161 |
+
args.latdepth = tmp_args.latdepth
|
| 162 |
+
args.coorddepth = tmp_args.coorddepth
|
| 163 |
+
args.max_lat_len = tmp_args.max_lat_len
|
| 164 |
+
args.base_channels = tmp_args.base_channels
|
| 165 |
+
args.shape = tmp_args.shape
|
| 166 |
+
args.window = tmp_args.window
|
| 167 |
+
args.mu_rescale = tmp_args.mu_rescale
|
| 168 |
+
args.sigma_rescale = tmp_args.sigma_rescale
|
| 169 |
+
args.save_path = tmp_args.save_path
|
| 170 |
+
args.files_path = tmp_args.files_path
|
| 171 |
+
args.dec_path = tmp_args.dec_path
|
| 172 |
+
args.load_path = tmp_args.load_path
|
| 173 |
+
args.base_path = tmp_args.base_path
|
| 174 |
+
args.testing = tmp_args.testing
|
| 175 |
+
args.cpu = tmp_args.cpu
|
| 176 |
+
args.mixed_precision = tmp_args.mixed_precision
|
| 177 |
+
|
| 178 |
+
if args.small:
|
| 179 |
+
args.latlen = 128
|
| 180 |
+
else:
|
| 181 |
+
args.latlen = 256
|
| 182 |
+
args.coordlen = (args.latlen // 2) * 3
|
| 183 |
+
|
| 184 |
+
print()
|
| 185 |
+
|
| 186 |
+
args.datatype = tf.float32
|
| 187 |
+
gpuls = tf.config.list_physical_devices("GPU")
|
| 188 |
+
if len(gpuls) == 0 or args.cpu:
|
| 189 |
+
args.cpu = True
|
| 190 |
+
args.mixed_precision = False
|
| 191 |
+
tf.config.set_visible_devices([], "GPU")
|
| 192 |
+
print()
|
| 193 |
+
print("Using CPU...")
|
| 194 |
+
print()
|
| 195 |
+
if args.mixed_precision:
|
| 196 |
+
args.datatype = tf.float16
|
| 197 |
+
print()
|
| 198 |
+
print("Using GPU with mixed precision enabled...")
|
| 199 |
+
print()
|
| 200 |
+
if not args.mixed_precision and not args.cpu:
|
| 201 |
+
print()
|
| 202 |
+
print("Using GPU without mixed precision...")
|
| 203 |
+
print()
|
| 204 |
+
|
| 205 |
+
return args
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def parse_args():
|
| 209 |
+
args = EasyDict()
|
| 210 |
+
return params_args(args)
|
MusicModel/parse/parse_encode.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from typing import Any
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EasyDict(dict):
|
| 7 |
+
def __getattr__(self, name: str) -> Any:
|
| 8 |
+
try:
|
| 9 |
+
return self[name]
|
| 10 |
+
except KeyError:
|
| 11 |
+
raise AttributeError(name)
|
| 12 |
+
|
| 13 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
| 14 |
+
self[name] = value
|
| 15 |
+
|
| 16 |
+
def __delattr__(self, name: str) -> None:
|
| 17 |
+
del self[name]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def str2bool(v):
|
| 21 |
+
if isinstance(v, bool):
|
| 22 |
+
return v
|
| 23 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
| 24 |
+
return True
|
| 25 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
| 26 |
+
return False
|
| 27 |
+
else:
|
| 28 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def params_args(args):
|
| 32 |
+
parser = argparse.ArgumentParser()
|
| 33 |
+
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--whole",
|
| 36 |
+
type=str2bool,
|
| 37 |
+
default=False,
|
| 38 |
+
help="If True, encode a single audio file to a single compressed encoding file of variable length",
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--hop",
|
| 42 |
+
type=int,
|
| 43 |
+
default=256,
|
| 44 |
+
help="Hop size (window size = 4*hop)",
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--mel_bins",
|
| 48 |
+
type=int,
|
| 49 |
+
default=256,
|
| 50 |
+
help="Mel bins in mel-spectrograms",
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--sr",
|
| 54 |
+
type=int,
|
| 55 |
+
default=44100,
|
| 56 |
+
help="Sampling Rate",
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--small",
|
| 60 |
+
type=str2bool,
|
| 61 |
+
default=False,
|
| 62 |
+
help="If True, use model with shorter available context, useful for small datasets",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--latdepth",
|
| 66 |
+
type=int,
|
| 67 |
+
default=64,
|
| 68 |
+
help="Depth of generated latent vectors",
|
| 69 |
+
)
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--coorddepth",
|
| 72 |
+
type=int,
|
| 73 |
+
default=64,
|
| 74 |
+
help="Dimension of latent coordinate and style random vectors",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--max_lat_len",
|
| 78 |
+
type=int,
|
| 79 |
+
default=512,
|
| 80 |
+
help="Length of latent sequences: a random on-the-fly crop will be used for training",
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--base_channels",
|
| 84 |
+
type=int,
|
| 85 |
+
default=128,
|
| 86 |
+
help="Base channels for generator and discriminator architectures",
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--shape",
|
| 90 |
+
type=int,
|
| 91 |
+
default=128,
|
| 92 |
+
help="Length of spectrograms time axis",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--window",
|
| 96 |
+
type=int,
|
| 97 |
+
default=64,
|
| 98 |
+
help="Generator spectrogram window (must divide shape)",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--mu_rescale",
|
| 102 |
+
type=float,
|
| 103 |
+
default=-25.0,
|
| 104 |
+
help="Spectrogram mu used to normalize",
|
| 105 |
+
)
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
"--sigma_rescale",
|
| 108 |
+
type=float,
|
| 109 |
+
default=75.0,
|
| 110 |
+
help="Spectrogram sigma used to normalize",
|
| 111 |
+
)
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--files_path",
|
| 114 |
+
type=str,
|
| 115 |
+
default="audio_samples/",
|
| 116 |
+
help="Path of samples to encode",
|
| 117 |
+
)
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--save_path",
|
| 120 |
+
type=str,
|
| 121 |
+
default="encoded_samples/",
|
| 122 |
+
help="Path where compressed representations will be saved",
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--dec_path",
|
| 126 |
+
type=str,
|
| 127 |
+
default="checkpoints/ae",
|
| 128 |
+
help="Path of pretrained decoders weights",
|
| 129 |
+
)
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--load_path",
|
| 132 |
+
type=str,
|
| 133 |
+
default="None",
|
| 134 |
+
help="If not None, load models weights from this path",
|
| 135 |
+
)
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--base_path",
|
| 138 |
+
type=str,
|
| 139 |
+
default="checkpoints",
|
| 140 |
+
help="Path where pretrained models are downloaded",
|
| 141 |
+
)
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
"--testing",
|
| 144 |
+
type=str2bool,
|
| 145 |
+
default=True,
|
| 146 |
+
help="True if optimizers weight do not need to be loaded",
|
| 147 |
+
)
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
"--cpu",
|
| 150 |
+
type=str2bool,
|
| 151 |
+
default=False,
|
| 152 |
+
help="True if you wish to use cpu",
|
| 153 |
+
)
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--mixed_precision",
|
| 156 |
+
type=str2bool,
|
| 157 |
+
default=True,
|
| 158 |
+
help="True if your GPU supports mixed precision",
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
tmp_args = parser.parse_args()
|
| 162 |
+
|
| 163 |
+
args.whole = tmp_args.whole
|
| 164 |
+
args.hop = tmp_args.hop
|
| 165 |
+
args.mel_bins = tmp_args.mel_bins
|
| 166 |
+
args.sr = tmp_args.sr
|
| 167 |
+
args.small = tmp_args.small
|
| 168 |
+
args.latdepth = tmp_args.latdepth
|
| 169 |
+
args.coorddepth = tmp_args.coorddepth
|
| 170 |
+
args.max_lat_len = tmp_args.max_lat_len
|
| 171 |
+
args.base_channels = tmp_args.base_channels
|
| 172 |
+
args.shape = tmp_args.shape
|
| 173 |
+
args.window = tmp_args.window
|
| 174 |
+
args.mu_rescale = tmp_args.mu_rescale
|
| 175 |
+
args.sigma_rescale = tmp_args.sigma_rescale
|
| 176 |
+
args.save_path = tmp_args.save_path
|
| 177 |
+
args.files_path = tmp_args.files_path
|
| 178 |
+
args.dec_path = tmp_args.dec_path
|
| 179 |
+
args.load_path = tmp_args.load_path
|
| 180 |
+
args.base_path = tmp_args.base_path
|
| 181 |
+
args.testing = tmp_args.testing
|
| 182 |
+
args.cpu = tmp_args.cpu
|
| 183 |
+
args.mixed_precision = tmp_args.mixed_precision
|
| 184 |
+
|
| 185 |
+
if args.small:
|
| 186 |
+
args.latlen = 128
|
| 187 |
+
else:
|
| 188 |
+
args.latlen = 256
|
| 189 |
+
args.coordlen = (args.latlen // 2) * 3
|
| 190 |
+
|
| 191 |
+
print()
|
| 192 |
+
|
| 193 |
+
args.datatype = tf.float32
|
| 194 |
+
gpuls = tf.config.list_physical_devices("GPU")
|
| 195 |
+
if len(gpuls) == 0 or args.cpu:
|
| 196 |
+
args.cpu = True
|
| 197 |
+
args.mixed_precision = False
|
| 198 |
+
tf.config.set_visible_devices([], "GPU")
|
| 199 |
+
print()
|
| 200 |
+
print("Using CPU...")
|
| 201 |
+
print()
|
| 202 |
+
if args.mixed_precision:
|
| 203 |
+
args.datatype = tf.float16
|
| 204 |
+
print()
|
| 205 |
+
print("Using GPU with mixed precision enabled...")
|
| 206 |
+
print()
|
| 207 |
+
if not args.mixed_precision and not args.cpu:
|
| 208 |
+
print()
|
| 209 |
+
print("Using GPU without mixed precision...")
|
| 210 |
+
print()
|
| 211 |
+
|
| 212 |
+
return args
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def parse_args():
|
| 216 |
+
args = EasyDict()
|
| 217 |
+
return params_args(args)
|
MusicModel/parse/parse_generate.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from typing import Any
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EasyDict(dict):
|
| 7 |
+
def __getattr__(self, name: str) -> Any:
|
| 8 |
+
try:
|
| 9 |
+
return self[name]
|
| 10 |
+
except KeyError:
|
| 11 |
+
raise AttributeError(name)
|
| 12 |
+
|
| 13 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
| 14 |
+
self[name] = value
|
| 15 |
+
|
| 16 |
+
def __delattr__(self, name: str) -> None:
|
| 17 |
+
del self[name]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def str2bool(v):
|
| 21 |
+
if isinstance(v, bool):
|
| 22 |
+
return v
|
| 23 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
| 24 |
+
return True
|
| 25 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
| 26 |
+
return False
|
| 27 |
+
else:
|
| 28 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def params_args(args):
|
| 32 |
+
parser = argparse.ArgumentParser()
|
| 33 |
+
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--num_samples",
|
| 36 |
+
type=int,
|
| 37 |
+
default=1,
|
| 38 |
+
help="Number of desired generated samples",
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--seconds",
|
| 42 |
+
type=int,
|
| 43 |
+
default=120,
|
| 44 |
+
help="Length in seconds of generated samples",
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--save_path",
|
| 48 |
+
type=str,
|
| 49 |
+
default="generations",
|
| 50 |
+
help="Path where to save generated samples",
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--truncation",
|
| 54 |
+
type=float,
|
| 55 |
+
default=2.0,
|
| 56 |
+
help="Standard deviation of random vectors (truncation trick)",
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--hop",
|
| 60 |
+
type=int,
|
| 61 |
+
default=256,
|
| 62 |
+
help="Hop size (window size = 4*hop)",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--mel_bins",
|
| 66 |
+
type=int,
|
| 67 |
+
default=256,
|
| 68 |
+
help="Mel bins in mel-spectrograms",
|
| 69 |
+
)
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--sr",
|
| 72 |
+
type=int,
|
| 73 |
+
default=44100,
|
| 74 |
+
help="Sampling Rate",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--small",
|
| 78 |
+
type=str2bool,
|
| 79 |
+
default=False,
|
| 80 |
+
help="If True, use model with shorter available context, useful for small datasets",
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--latdepth",
|
| 84 |
+
type=int,
|
| 85 |
+
default=64,
|
| 86 |
+
help="Depth of generated latent vectors",
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--coorddepth",
|
| 90 |
+
type=int,
|
| 91 |
+
default=64,
|
| 92 |
+
help="Dimension of latent coordinate and style random vectors",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--base_channels",
|
| 96 |
+
type=int,
|
| 97 |
+
default=128,
|
| 98 |
+
help="Base channels for generator and discriminator architectures",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--shape",
|
| 102 |
+
type=int,
|
| 103 |
+
default=128,
|
| 104 |
+
help="Length of spectrograms time axis",
|
| 105 |
+
)
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
"--window",
|
| 108 |
+
type=int,
|
| 109 |
+
default=64,
|
| 110 |
+
help="Generator spectrogram window (must divide shape)",
|
| 111 |
+
)
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--mu_rescale",
|
| 114 |
+
type=float,
|
| 115 |
+
default=-25.0,
|
| 116 |
+
help="Spectrogram mu used to normalize",
|
| 117 |
+
)
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--sigma_rescale",
|
| 120 |
+
type=float,
|
| 121 |
+
default=75.0,
|
| 122 |
+
help="Spectrogram sigma used to normalize",
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--load_path",
|
| 126 |
+
type=str,
|
| 127 |
+
default="checkpoints/techno/",
|
| 128 |
+
help="Path of pretrained networks weights",
|
| 129 |
+
)
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--dec_path",
|
| 132 |
+
type=str,
|
| 133 |
+
default="checkpoints/ae/",
|
| 134 |
+
help="Path of pretrained decoders weights",
|
| 135 |
+
)
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--base_path",
|
| 138 |
+
type=str,
|
| 139 |
+
default="checkpoints",
|
| 140 |
+
help="Path where pretrained models are downloaded",
|
| 141 |
+
)
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
"--testing",
|
| 144 |
+
type=str2bool,
|
| 145 |
+
default=True,
|
| 146 |
+
help="True if optimizers weight do not need to be loaded",
|
| 147 |
+
)
|
| 148 |
+
parser.add_argument(
|
| 149 |
+
"--cpu",
|
| 150 |
+
type=str2bool,
|
| 151 |
+
default=False,
|
| 152 |
+
help="True if you wish to use cpu",
|
| 153 |
+
)
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--mixed_precision",
|
| 156 |
+
type=str2bool,
|
| 157 |
+
default=True,
|
| 158 |
+
help="True if your GPU supports mixed precision",
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
tmp_args = parser.parse_args()
|
| 162 |
+
|
| 163 |
+
args.num_samples = tmp_args.num_samples
|
| 164 |
+
args.seconds = tmp_args.seconds
|
| 165 |
+
args.save_path = tmp_args.save_path
|
| 166 |
+
args.truncation = tmp_args.truncation
|
| 167 |
+
args.hop = tmp_args.hop
|
| 168 |
+
args.mel_bins = tmp_args.mel_bins
|
| 169 |
+
args.sr = tmp_args.sr
|
| 170 |
+
args.small = tmp_args.small
|
| 171 |
+
args.latdepth = tmp_args.latdepth
|
| 172 |
+
args.coorddepth = tmp_args.coorddepth
|
| 173 |
+
args.base_channels = tmp_args.base_channels
|
| 174 |
+
args.shape = tmp_args.shape
|
| 175 |
+
args.window = tmp_args.window
|
| 176 |
+
args.mu_rescale = tmp_args.mu_rescale
|
| 177 |
+
args.sigma_rescale = tmp_args.sigma_rescale
|
| 178 |
+
args.load_path = tmp_args.load_path
|
| 179 |
+
args.base_path = tmp_args.base_path
|
| 180 |
+
args.dec_path = tmp_args.dec_path
|
| 181 |
+
args.testing = tmp_args.testing
|
| 182 |
+
args.cpu = tmp_args.cpu
|
| 183 |
+
args.mixed_precision = tmp_args.mixed_precision
|
| 184 |
+
|
| 185 |
+
if args.small:
|
| 186 |
+
args.latlen = 128
|
| 187 |
+
else:
|
| 188 |
+
args.latlen = 256
|
| 189 |
+
args.coordlen = (args.latlen // 2) * 3
|
| 190 |
+
|
| 191 |
+
print()
|
| 192 |
+
|
| 193 |
+
args.datatype = tf.float32
|
| 194 |
+
gpuls = tf.config.list_physical_devices("GPU")
|
| 195 |
+
if len(gpuls) == 0 or args.cpu:
|
| 196 |
+
args.cpu = True
|
| 197 |
+
args.mixed_precision = False
|
| 198 |
+
tf.config.set_visible_devices([], "GPU")
|
| 199 |
+
print()
|
| 200 |
+
print("Using CPU...")
|
| 201 |
+
print()
|
| 202 |
+
if args.mixed_precision:
|
| 203 |
+
args.datatype = tf.float16
|
| 204 |
+
print()
|
| 205 |
+
print("Using GPU with mixed precision enabled...")
|
| 206 |
+
print()
|
| 207 |
+
if not args.mixed_precision and not args.cpu:
|
| 208 |
+
print()
|
| 209 |
+
print("Using GPU without mixed precision...")
|
| 210 |
+
print()
|
| 211 |
+
|
| 212 |
+
return args
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def parse_args():
|
| 216 |
+
args = EasyDict()
|
| 217 |
+
return params_args(args)
|
MusicModel/parse/parse_test.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from typing import Any
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class EasyDict(dict):
|
| 7 |
+
def __getattr__(self, name: str) -> Any:
|
| 8 |
+
try:
|
| 9 |
+
return self[name]
|
| 10 |
+
except KeyError:
|
| 11 |
+
raise AttributeError(name)
|
| 12 |
+
|
| 13 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
| 14 |
+
self[name] = value
|
| 15 |
+
|
| 16 |
+
def __delattr__(self, name: str) -> None:
|
| 17 |
+
del self[name]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def str2bool(v):
|
| 21 |
+
if isinstance(v, bool):
|
| 22 |
+
return v
|
| 23 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
| 24 |
+
return True
|
| 25 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
| 26 |
+
return False
|
| 27 |
+
else:
|
| 28 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def params_args(args):
|
| 32 |
+
parser = argparse.ArgumentParser()
|
| 33 |
+
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--hop",
|
| 36 |
+
type=int,
|
| 37 |
+
default=256,
|
| 38 |
+
help="Hop size (window size = 4*hop)",
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--mel_bins",
|
| 42 |
+
type=int,
|
| 43 |
+
default=256,
|
| 44 |
+
help="Mel bins in mel-spectrograms",
|
| 45 |
+
)
|
| 46 |
+
parser.add_argument(
|
| 47 |
+
"--sr",
|
| 48 |
+
type=int,
|
| 49 |
+
default=44100,
|
| 50 |
+
help="Sampling Rate",
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--small",
|
| 54 |
+
type=str2bool,
|
| 55 |
+
default=False,
|
| 56 |
+
help="If True, use model with shorter available context, useful for small datasets",
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--latdepth",
|
| 60 |
+
type=int,
|
| 61 |
+
default=64,
|
| 62 |
+
help="Depth of generated latent vectors",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--coorddepth",
|
| 66 |
+
type=int,
|
| 67 |
+
default=64,
|
| 68 |
+
help="Dimension of latent coordinate and style random vectors",
|
| 69 |
+
)
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--base_channels",
|
| 72 |
+
type=int,
|
| 73 |
+
default=128,
|
| 74 |
+
help="Base channels for generator and discriminator architectures",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--shape",
|
| 78 |
+
type=int,
|
| 79 |
+
default=128,
|
| 80 |
+
help="Length of spectrograms time axis",
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--window",
|
| 84 |
+
type=int,
|
| 85 |
+
default=64,
|
| 86 |
+
help="Generator spectrogram window (must divide shape)",
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--mu_rescale",
|
| 90 |
+
type=float,
|
| 91 |
+
default=-25.0,
|
| 92 |
+
help="Spectrogram mu used to normalize",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--sigma_rescale",
|
| 96 |
+
type=float,
|
| 97 |
+
default=75.0,
|
| 98 |
+
help="Spectrogram sigma used to normalize",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--load_path",
|
| 102 |
+
type=str,
|
| 103 |
+
default="checkpoints/techno/",
|
| 104 |
+
help="Path of pretrained networks weights",
|
| 105 |
+
)
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
"--dec_path",
|
| 108 |
+
type=str,
|
| 109 |
+
default="checkpoints/ae/",
|
| 110 |
+
help="Path of pretrained decoders weights",
|
| 111 |
+
)
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--base_path",
|
| 114 |
+
type=str,
|
| 115 |
+
default="checkpoints",
|
| 116 |
+
help="Path where pretrained models are downloaded",
|
| 117 |
+
)
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
"--testing",
|
| 120 |
+
type=str2bool,
|
| 121 |
+
default=True,
|
| 122 |
+
help="True if optimizers weight do not need to be loaded",
|
| 123 |
+
)
|
| 124 |
+
parser.add_argument(
|
| 125 |
+
"--cpu",
|
| 126 |
+
type=str2bool,
|
| 127 |
+
default=False,
|
| 128 |
+
help="True if you wish to use cpu",
|
| 129 |
+
)
|
| 130 |
+
parser.add_argument(
|
| 131 |
+
"--mixed_precision",
|
| 132 |
+
type=str2bool,
|
| 133 |
+
default=True,
|
| 134 |
+
help="True if your GPU supports mixed precision",
|
| 135 |
+
)
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"--share_gradio",
|
| 138 |
+
type=str2bool,
|
| 139 |
+
default=False,
|
| 140 |
+
help="True if you wish to create a public URL for the Gradio interface",
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
tmp_args = parser.parse_args()
|
| 144 |
+
|
| 145 |
+
args.hop = tmp_args.hop
|
| 146 |
+
args.mel_bins = tmp_args.mel_bins
|
| 147 |
+
args.sr = tmp_args.sr
|
| 148 |
+
args.small = tmp_args.small
|
| 149 |
+
args.latdepth = tmp_args.latdepth
|
| 150 |
+
args.coorddepth = tmp_args.coorddepth
|
| 151 |
+
args.base_channels = tmp_args.base_channels
|
| 152 |
+
args.shape = tmp_args.shape
|
| 153 |
+
args.window = tmp_args.window
|
| 154 |
+
args.mu_rescale = tmp_args.mu_rescale
|
| 155 |
+
args.sigma_rescale = tmp_args.sigma_rescale
|
| 156 |
+
args.load_path = tmp_args.load_path
|
| 157 |
+
args.base_path = tmp_args.base_path
|
| 158 |
+
args.dec_path = tmp_args.dec_path
|
| 159 |
+
args.testing = tmp_args.testing
|
| 160 |
+
args.cpu = tmp_args.cpu
|
| 161 |
+
args.mixed_precision = tmp_args.mixed_precision
|
| 162 |
+
args.share_gradio = tmp_args.share_gradio
|
| 163 |
+
|
| 164 |
+
if args.small:
|
| 165 |
+
args.latlen = 128
|
| 166 |
+
else:
|
| 167 |
+
args.latlen = 256
|
| 168 |
+
args.coordlen = (args.latlen // 2) * 3
|
| 169 |
+
|
| 170 |
+
print()
|
| 171 |
+
|
| 172 |
+
args.datatype = tf.float32
|
| 173 |
+
gpuls = tf.config.list_physical_devices("GPU")
|
| 174 |
+
if len(gpuls) == 0 or args.cpu:
|
| 175 |
+
args.cpu = True
|
| 176 |
+
args.mixed_precision = False
|
| 177 |
+
tf.config.set_visible_devices([], "GPU")
|
| 178 |
+
print()
|
| 179 |
+
print("Using CPU...")
|
| 180 |
+
print()
|
| 181 |
+
if args.mixed_precision:
|
| 182 |
+
args.datatype = tf.float16
|
| 183 |
+
print()
|
| 184 |
+
print("Using GPU with mixed precision enabled...")
|
| 185 |
+
print()
|
| 186 |
+
if not args.mixed_precision and not args.cpu:
|
| 187 |
+
print()
|
| 188 |
+
print("Using GPU without mixed precision...")
|
| 189 |
+
print()
|
| 190 |
+
|
| 191 |
+
return args
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def parse_args():
|
| 195 |
+
args = EasyDict()
|
| 196 |
+
return params_args(args)
|
MusicModel/utils_encode.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
from pydub import AudioSegment
|
| 6 |
+
from glob import glob
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from utils import Utils_functions
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class UtilsEncode_functions:
|
| 13 |
+
def __init__(self, args):
|
| 14 |
+
|
| 15 |
+
self.args = args
|
| 16 |
+
self.U = Utils_functions(args)
|
| 17 |
+
self.paths = sorted(glob(self.args.files_path + "/*"))
|
| 18 |
+
|
| 19 |
+
def audio_generator(self):
|
| 20 |
+
for p in self.paths:
|
| 21 |
+
try:
|
| 22 |
+
tp, ext = os.path.splitext(p)
|
| 23 |
+
bname = os.path.basename(tp)
|
| 24 |
+
wvo = AudioSegment.from_file(p, format=ext[1:])
|
| 25 |
+
wvo = wvo.set_frame_rate(self.args.sr)
|
| 26 |
+
wvls = wvo.split_to_mono()
|
| 27 |
+
wvls = [s.get_array_of_samples() for s in wvls]
|
| 28 |
+
wv = np.array(wvls).T.astype(np.float32)
|
| 29 |
+
wv /= np.iinfo(wvls[0].typecode).max
|
| 30 |
+
yield np.squeeze(wv), bname
|
| 31 |
+
except Exception as e:
|
| 32 |
+
print(e)
|
| 33 |
+
print("Exception ignored! Continuing...")
|
| 34 |
+
pass
|
| 35 |
+
|
| 36 |
+
# def create_dataset(self):
|
| 37 |
+
# self.ds = (
|
| 38 |
+
# tf.data.Dataset.from_generator(
|
| 39 |
+
# self.audio_generator, output_signature=(tf.TensorSpec(shape=(None, 2), dtype=tf.float32))
|
| 40 |
+
# )
|
| 41 |
+
# .prefetch(tf.data.experimental.AUTOTUNE)
|
| 42 |
+
# .apply(tf.data.experimental.ignore_errors())
|
| 43 |
+
# )
|
| 44 |
+
|
| 45 |
+
def compress_files(self, models_ls=None):
|
| 46 |
+
critic, gen, enc, dec, enc2, dec2, gen_ema, [opt_dec, opt_disc], switch = models_ls
|
| 47 |
+
# self.create_dataset()
|
| 48 |
+
os.makedirs(self.args.save_path, exist_ok=True)
|
| 49 |
+
c = 0
|
| 50 |
+
time_compression_ratio = 16 # TODO: infer time compression ratio
|
| 51 |
+
shape2 = self.args.shape
|
| 52 |
+
pbar = tqdm(self.audio_generator(), position=0, leave=True, total=len(self.paths))
|
| 53 |
+
|
| 54 |
+
for (wv,bname) in pbar:
|
| 55 |
+
|
| 56 |
+
try:
|
| 57 |
+
|
| 58 |
+
if wv.shape[0] > self.args.hop * self.args.shape * 2 + 3 * self.args.hop:
|
| 59 |
+
|
| 60 |
+
split_limit = (
|
| 61 |
+
5 * 60 * self.args.sr
|
| 62 |
+
) # split very long waveforms (> 5 minutes) and process separately to avoid out of memory errors
|
| 63 |
+
|
| 64 |
+
nsplits = (wv.shape[0] // split_limit) + 1
|
| 65 |
+
wvsplits = []
|
| 66 |
+
for ns in range(nsplits):
|
| 67 |
+
if wv.shape[0] - (ns * split_limit) > self.args.hop * self.args.shape * 2 + 3 * self.args.hop:
|
| 68 |
+
wvsplits.append(wv[ns * split_limit : (ns + 1) * split_limit, :])
|
| 69 |
+
|
| 70 |
+
for wv in wvsplits:
|
| 71 |
+
|
| 72 |
+
wv = tf.image.random_crop(
|
| 73 |
+
wv,
|
| 74 |
+
size=[
|
| 75 |
+
(((wv.shape[0] - (3 * self.args.hop)) // (self.args.shape * self.args.hop)))
|
| 76 |
+
* self.args.shape
|
| 77 |
+
* self.args.hop
|
| 78 |
+
+ 3 * self.args.hop,
|
| 79 |
+
2,
|
| 80 |
+
],
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
chls = []
|
| 84 |
+
for channel in range(2):
|
| 85 |
+
|
| 86 |
+
x = wv[:, channel]
|
| 87 |
+
x = tf.expand_dims(tf.transpose(self.U.wv2spec(x, hop_size=self.args.hop), (1, 0)), -1)
|
| 88 |
+
ds = []
|
| 89 |
+
num = x.shape[1] // self.args.shape
|
| 90 |
+
rn = 0
|
| 91 |
+
for i in range(num):
|
| 92 |
+
ds.append(
|
| 93 |
+
x[:, rn + (i * self.args.shape) : rn + (i * self.args.shape) + self.args.shape, :]
|
| 94 |
+
)
|
| 95 |
+
del x
|
| 96 |
+
ds = tf.convert_to_tensor(ds, dtype=tf.float32)
|
| 97 |
+
lat = self.U.distribute_enc(ds, enc)
|
| 98 |
+
del ds
|
| 99 |
+
lat = tf.split(lat, lat.shape[0], 0)
|
| 100 |
+
lat = tf.concat(lat, -2)
|
| 101 |
+
lat = tf.squeeze(lat)
|
| 102 |
+
|
| 103 |
+
switch = False
|
| 104 |
+
if lat.shape[0] > (self.args.max_lat_len * time_compression_ratio):
|
| 105 |
+
switch = True
|
| 106 |
+
ds2 = []
|
| 107 |
+
num2 = lat.shape[-2] // shape2
|
| 108 |
+
rn2 = 0
|
| 109 |
+
for j in range(num2):
|
| 110 |
+
ds2.append(lat[rn2 + (j * shape2) : rn2 + (j * shape2) + shape2, :])
|
| 111 |
+
ds2 = tf.convert_to_tensor(ds2, dtype=tf.float32)
|
| 112 |
+
lat = self.U.distribute_enc(tf.expand_dims(ds2, -3), enc2)
|
| 113 |
+
del ds2
|
| 114 |
+
lat = tf.split(lat, lat.shape[0], 0)
|
| 115 |
+
lat = tf.concat(lat, -2)
|
| 116 |
+
lat = tf.squeeze(lat)
|
| 117 |
+
chls.append(lat)
|
| 118 |
+
|
| 119 |
+
if lat.shape[0] > self.args.max_lat_len and switch:
|
| 120 |
+
|
| 121 |
+
lat = tf.concat(chls, -1)
|
| 122 |
+
|
| 123 |
+
del chls
|
| 124 |
+
|
| 125 |
+
latc = lat[: (lat.shape[0] // self.args.max_lat_len) * self.args.max_lat_len, :]
|
| 126 |
+
latc = tf.split(latc, latc.shape[0] // self.args.max_lat_len, 0)
|
| 127 |
+
for el in latc:
|
| 128 |
+
np.save(self.args.save_path + f"/{bname}_{c}.npy", el)
|
| 129 |
+
c += 1
|
| 130 |
+
pbar.set_postfix({"Saved Files": c})
|
| 131 |
+
np.save(self.args.save_path + f"/{bname}_{c}.npy", lat[-self.args.max_lat_len :, :])
|
| 132 |
+
c += 1
|
| 133 |
+
pbar.set_postfix({"Saved Files": c})
|
| 134 |
+
|
| 135 |
+
del lat
|
| 136 |
+
del latc
|
| 137 |
+
|
| 138 |
+
except Exception as e:
|
| 139 |
+
print(e)
|
| 140 |
+
print("Exception ignored! Continuing...")
|
| 141 |
+
pass
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def compress_whole_files(self, models_ls=None):
|
| 145 |
+
critic, gen, enc, dec, enc2, dec2, gen_ema, [opt_dec, opt_disc], switch = models_ls
|
| 146 |
+
# self.create_dataset()
|
| 147 |
+
os.makedirs(self.args.save_path, exist_ok=True)
|
| 148 |
+
c = 0
|
| 149 |
+
time_compression_ratio = 16 # TODO: infer time compression ratio
|
| 150 |
+
shape2 = self.args.shape
|
| 151 |
+
pbar = tqdm(self.audio_generator(), position=0, leave=True, total=len(self.paths))
|
| 152 |
+
|
| 153 |
+
for (wv,bname) in pbar:
|
| 154 |
+
|
| 155 |
+
try:
|
| 156 |
+
|
| 157 |
+
# wv_len_orig = wv.shape[0]
|
| 158 |
+
|
| 159 |
+
if wv.shape[0] > self.args.hop * self.args.shape * 2 + 3 * self.args.hop:
|
| 160 |
+
|
| 161 |
+
rem = (wv.shape[0] - (3 * self.args.hop)) % (self.args.shape * self.args.hop)
|
| 162 |
+
|
| 163 |
+
if rem != 0:
|
| 164 |
+
wv = tf.concat([wv, tf.zeros([rem,2], dtype=tf.float32)], 0)
|
| 165 |
+
|
| 166 |
+
chls = []
|
| 167 |
+
for channel in range(2):
|
| 168 |
+
|
| 169 |
+
x = wv[:, channel]
|
| 170 |
+
x = tf.expand_dims(tf.transpose(self.U.wv2spec(x, hop_size=self.args.hop), (1, 0)), -1)
|
| 171 |
+
ds = []
|
| 172 |
+
num = x.shape[1] // self.args.shape
|
| 173 |
+
rn = 0
|
| 174 |
+
for i in range(num):
|
| 175 |
+
ds.append(
|
| 176 |
+
x[:, rn + (i * self.args.shape) : rn + (i * self.args.shape) + self.args.shape, :]
|
| 177 |
+
)
|
| 178 |
+
del x
|
| 179 |
+
ds = tf.convert_to_tensor(ds, dtype=tf.float32)
|
| 180 |
+
lat = self.U.distribute_enc(ds, enc)
|
| 181 |
+
del ds
|
| 182 |
+
lat = tf.split(lat, lat.shape[0], 0)
|
| 183 |
+
lat = tf.concat(lat, -2)
|
| 184 |
+
lat = tf.squeeze(lat)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
ds2 = []
|
| 189 |
+
num2 = lat.shape[-2] // shape2
|
| 190 |
+
rn2 = 0
|
| 191 |
+
for j in range(num2):
|
| 192 |
+
ds2.append(lat[rn2 + (j * shape2) : rn2 + (j * shape2) + shape2, :])
|
| 193 |
+
ds2 = tf.convert_to_tensor(ds2, dtype=tf.float32)
|
| 194 |
+
lat = self.U.distribute_enc(tf.expand_dims(ds2, -3), enc2)
|
| 195 |
+
del ds2
|
| 196 |
+
lat = tf.split(lat, lat.shape[0], 0)
|
| 197 |
+
lat = tf.concat(lat, -2)
|
| 198 |
+
lat = tf.squeeze(lat)
|
| 199 |
+
chls.append(lat)
|
| 200 |
+
|
| 201 |
+
lat = tf.concat(chls, -1)
|
| 202 |
+
|
| 203 |
+
del chls
|
| 204 |
+
|
| 205 |
+
np.save(self.args.save_path + f"/{bname}.npy", lat)
|
| 206 |
+
c += 1
|
| 207 |
+
pbar.set_postfix({"Saved Files": c})
|
| 208 |
+
|
| 209 |
+
del lat
|
| 210 |
+
|
| 211 |
+
except Exception as e:
|
| 212 |
+
print(e)
|
| 213 |
+
print("Exception ignored! Continuing...")
|
| 214 |
+
pass
|
requirements.txt
CHANGED
|
@@ -1,7 +1,15 @@
|
|
| 1 |
-
tensorflow==2.9.1
|
| 2 |
gdown==4.4.0
|
| 3 |
streamlit==1.10.0
|
| 4 |
streamlit-tags==1.2.7
|
| 5 |
torch==1.12.1
|
| 6 |
scikit-learn==0.22
|
| 7 |
-
transformers==4.24.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
gdown==4.4.0
|
| 2 |
streamlit==1.10.0
|
| 3 |
streamlit-tags==1.2.7
|
| 4 |
torch==1.12.1
|
| 5 |
scikit-learn==0.22
|
| 6 |
+
transformers==4.24.0
|
| 7 |
+
|
| 8 |
+
librosa==0.8.1
|
| 9 |
+
matplotlib==3.4.3
|
| 10 |
+
numpy==1.23.5
|
| 11 |
+
scipy==1.7.1
|
| 12 |
+
tensorboard==2.10.0
|
| 13 |
+
tensorflow==2.10.0
|
| 14 |
+
tqdm==4.62.3
|
| 15 |
+
pydub==0.25.1
|