Spaces:
Running
Running
Commit ·
a00b67a
1
Parent(s): da27cbe
first commit
Browse files- .gitattributes +0 -34
- LICENSE +21 -0
- README.md +2 -13
- add.py +293 -0
- configs/delimit_6_s.yaml +92 -0
- dataloader/__init__.py +8 -0
- dataloader/dataset.py +579 -0
- dataloader/delimit_dataset.py +573 -0
- dataloader/singleset.py +95 -0
- eval_delimit/calc_flops.py +44 -0
- eval_delimit/score_calc_delimit.py +145 -0
- eval_delimit/score_diff_dyn_complexity.py +87 -0
- eval_delimit/score_fad.py +75 -0
- eval_delimit/score_features.py +233 -0
- eval_delimit/score_peaq.py +77 -0
- eval_delimit/score_peaq_aggregate.py +88 -0
- inference.py +165 -0
- main_ddp.py +49 -0
- models/__init__.py +1 -0
- models/base_models.py +239 -0
- models/load_models.py +87 -0
- prepro/delimit_save_delimiter_stems.py +93 -0
- prepro/delimit_save_musdb_loudnorm.py +118 -0
- prepro/delimit_train_ozone_prepro.py +293 -0
- prepro/delimit_valid_L_prepro.py +41 -0
- prepro/delimit_valid_custom_limiter_prepro.py +59 -0
- prepro/delimit_valid_prepro.py +41 -0
- requirements.txt +13 -0
- separate_func/__init__.py +1 -0
- separate_func/conv_tasnet_separate.py +89 -0
- solver_ddp.py +643 -0
- test_ddp.py +245 -0
- train_ddp.py +56 -0
- utils/__init__.py +19 -0
- utils/logging.py +79 -0
- utils/loudness_utils.py +71 -0
- utils/lr_scheduler.py +80 -0
- utils/read_wave_utils.py +109 -0
- utils/train_utils.py +27 -0
- weight/all.json +957 -0
- weight/all.pth +3 -0
.gitattributes
CHANGED
|
@@ -1,35 +1 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 jeonchangbin49
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,13 +1,2 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
emoji: 🏃
|
| 4 |
-
colorFrom: pink
|
| 5 |
-
colorTo: indigo
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 3.39.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: mit
|
| 11 |
-
---
|
| 12 |
-
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
+
# De-limiter
|
| 2 |
+
An official demo of "Music De-limiter Networks via Sample-wise Gain Inversion", which will be presented in WASPAA 2023.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
add.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import torch
|
| 8 |
+
import tqdm
|
| 9 |
+
import librosa
|
| 10 |
+
import librosa.display
|
| 11 |
+
import soundfile as sf
|
| 12 |
+
import pyloudnorm as pyln
|
| 13 |
+
from dotmap import DotMap
|
| 14 |
+
import gradio as gr
|
| 15 |
+
|
| 16 |
+
from models import load_model_with_args
|
| 17 |
+
from separate_func import (
|
| 18 |
+
conv_tasnet_separate,
|
| 19 |
+
)
|
| 20 |
+
from utils import db2linear
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
tqdm.monitor_interval = 0
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def separate_track_with_model(
|
| 27 |
+
args, model, device, track_audio, track_name, meter, augmented_gain
|
| 28 |
+
):
|
| 29 |
+
with torch.no_grad():
|
| 30 |
+
if (
|
| 31 |
+
args.model_loss_params.architecture == "conv_tasnet_mask_on_output"
|
| 32 |
+
or args.model_loss_params.architecture == "conv_tasnet"
|
| 33 |
+
):
|
| 34 |
+
estimates = conv_tasnet_separate(
|
| 35 |
+
args,
|
| 36 |
+
model,
|
| 37 |
+
device,
|
| 38 |
+
track_audio,
|
| 39 |
+
track_name,
|
| 40 |
+
meter=meter,
|
| 41 |
+
augmented_gain=augmented_gain,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
return estimates
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def main(input, mix_coefficient):
|
| 48 |
+
parser = argparse.ArgumentParser(description="model test.py")
|
| 49 |
+
parser.add_argument("--target", type=str, default="all")
|
| 50 |
+
parser.add_argument("--weight_directory", type=str, default="weight")
|
| 51 |
+
parser.add_argument("--output_directory", type=str, default="output")
|
| 52 |
+
parser.add_argument("--use_gpu", type=bool, default=True)
|
| 53 |
+
parser.add_argument("--save_name_as_target", type=bool, default=False)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--loudnorm_input_lufs",
|
| 56 |
+
type=float,
|
| 57 |
+
default=None,
|
| 58 |
+
help="If you want to use loudnorm for input",
|
| 59 |
+
)
|
| 60 |
+
parser.add_argument(
|
| 61 |
+
"--save_output_loudnorm",
|
| 62 |
+
type=float,
|
| 63 |
+
default=-14.0,
|
| 64 |
+
help="Save loudness normalized outputs or not. If you want to save, input target loudness",
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--save_mixed_output",
|
| 68 |
+
type=float,
|
| 69 |
+
default=None,
|
| 70 |
+
help="Save original+delimited-estimation mixed output with a ratio of default 0.5 (orginal) and 1 - 0.5 (estimation)",
|
| 71 |
+
)
|
| 72 |
+
parser.add_argument(
|
| 73 |
+
"--save_16k_mono",
|
| 74 |
+
type=bool,
|
| 75 |
+
default=False,
|
| 76 |
+
help="Save 16k mono wav files for FAD evaluation.",
|
| 77 |
+
)
|
| 78 |
+
parser.add_argument(
|
| 79 |
+
"--save_histogram",
|
| 80 |
+
type=bool,
|
| 81 |
+
default=False,
|
| 82 |
+
help="Save histogram of the output. Only valid when the task is 'delimit'",
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--use_singletrackset",
|
| 86 |
+
type=bool,
|
| 87 |
+
default=False,
|
| 88 |
+
help="Use SingleTrackSet if input data is too long.",
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
args, _ = parser.parse_known_args()
|
| 92 |
+
|
| 93 |
+
with open(f"{args.weight_directory}/{args.target}.json", "r") as f:
|
| 94 |
+
args_dict = json.load(f)
|
| 95 |
+
args_dict = DotMap(args_dict)
|
| 96 |
+
|
| 97 |
+
for key, value in args_dict["args"].items():
|
| 98 |
+
if key in list(vars(args).keys()):
|
| 99 |
+
pass
|
| 100 |
+
else:
|
| 101 |
+
setattr(args, key, value)
|
| 102 |
+
|
| 103 |
+
args.test_output_dir = f"{args.output_directory}"
|
| 104 |
+
os.makedirs(args.test_output_dir, exist_ok=True)
|
| 105 |
+
|
| 106 |
+
device = torch.device(
|
| 107 |
+
"cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
###################### Define Models ######################
|
| 111 |
+
our_model = load_model_with_args(args)
|
| 112 |
+
our_model = our_model.to(device)
|
| 113 |
+
|
| 114 |
+
target_model_path = f"{args.weight_directory}/{args.target}.pth"
|
| 115 |
+
checkpoint = torch.load(target_model_path, map_location=device)
|
| 116 |
+
our_model.load_state_dict(checkpoint)
|
| 117 |
+
|
| 118 |
+
our_model.eval()
|
| 119 |
+
|
| 120 |
+
meter = pyln.Meter(44100)
|
| 121 |
+
|
| 122 |
+
sr, track_audio = input
|
| 123 |
+
track_audio = track_audio.T
|
| 124 |
+
track_name = "gradio_demo"
|
| 125 |
+
|
| 126 |
+
orig_audio = track_audio.copy()
|
| 127 |
+
|
| 128 |
+
if sr != 44100:
|
| 129 |
+
raise ValueError("Sample rate should be 44100")
|
| 130 |
+
augmented_gain = None
|
| 131 |
+
|
| 132 |
+
if args.loudnorm_input_lufs: # If you want to use loud-normalized input
|
| 133 |
+
track_lufs = meter.integrated_loudness(track_audio.T)
|
| 134 |
+
augmented_gain = args.loudnorm_input_lufs - track_lufs
|
| 135 |
+
track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
|
| 136 |
+
|
| 137 |
+
track_audio = (
|
| 138 |
+
torch.as_tensor(track_audio, dtype=torch.float32).unsqueeze(0).to(device)
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
estimates = separate_track_with_model(
|
| 142 |
+
args, our_model, device, track_audio, track_name, meter, augmented_gain
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
if args.save_mixed_output:
|
| 146 |
+
track_lufs = meter.integrated_loudness(orig_audio.T)
|
| 147 |
+
augmented_gain = args.save_output_loudnorm - track_lufs
|
| 148 |
+
orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)
|
| 149 |
+
|
| 150 |
+
mixed_output = orig_audio * args.save_mixed_output + estimates * (
|
| 151 |
+
1 - args.save_mixed_output
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
sf.write(
|
| 155 |
+
f"{args.test_output_dir}/{track_name}/{track_name}_mixed.wav",
|
| 156 |
+
mixed_output.T,
|
| 157 |
+
args.data_params.sample_rate,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
return (
|
| 161 |
+
(sr, estimates.T),
|
| 162 |
+
(sr, orig_audio.T),
|
| 163 |
+
(sr, orig_audio.T * mix_coefficient + estimates.T * (1 - mix_coefficient)),
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def parallel_mix(input, output, mix_coefficient):
|
| 168 |
+
sr = 44100
|
| 169 |
+
return sr, input[1] * mix_coefficient + output[1] * (1 - mix_coefficient)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def int16_to_float32(wav):
|
| 173 |
+
wav = np.frombuffer(wav, dtype=np.int16)
|
| 174 |
+
X = wav / 32768
|
| 175 |
+
return X
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def waveform_plot(input, output, prl_mix_ouptut, figsize_x=20, figsize_y=9):
|
| 179 |
+
sr = 44100
|
| 180 |
+
fig, ax = plt.subplots(
|
| 181 |
+
nrows=3, sharex=True, sharey=True, figsize=(figsize_x, figsize_y)
|
| 182 |
+
)
|
| 183 |
+
librosa.display.waveshow(int16_to_float32(input[1]).T, sr=sr, ax=ax[0])
|
| 184 |
+
ax[0].set(title="Loudness Normalized Input")
|
| 185 |
+
ax[0].label_outer()
|
| 186 |
+
librosa.display.waveshow(int16_to_float32(output[1]).T, sr=sr, ax=ax[1])
|
| 187 |
+
ax[1].set(title="De-limiter Output")
|
| 188 |
+
ax[1].label_outer()
|
| 189 |
+
librosa.display.waveshow(int16_to_float32(prl_mix_ouptut[1]).T, sr=sr, ax=ax[2])
|
| 190 |
+
ax[2].set(title="Parallel Mix of the Input and its De-limiter Output")
|
| 191 |
+
ax[2].label_outer()
|
| 192 |
+
return fig
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
with gr.Blocks() as demo:
|
| 196 |
+
gr.HTML(
|
| 197 |
+
"""
|
| 198 |
+
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
|
| 199 |
+
<div
|
| 200 |
+
style="
|
| 201 |
+
display: inline-flex;
|
| 202 |
+
align-items: center;
|
| 203 |
+
gap: 0.8rem;
|
| 204 |
+
font-size: 1.75rem;
|
| 205 |
+
"
|
| 206 |
+
>
|
| 207 |
+
<h1 style="font-weight: 900; margin-bottom: 7px;">
|
| 208 |
+
Music De-limiter
|
| 209 |
+
</h1>
|
| 210 |
+
</div>
|
| 211 |
+
<p style="margin-bottom: 10px; font-size: 94%">
|
| 212 |
+
A demo for "Music De-limiter via Sample-wise Gain Inversion" to appear in WASPAA 2023.
|
| 213 |
+
You can first upload a music (.wav or .mp3) file and then press "De-limit" button to apply the De-limiter. Since we use a CPU instead of a GPU, it may require a few minute.
|
| 214 |
+
Then, you can apply a Parallel Mix technique, which is a simple linear mixing technique of "loudness normalized input" and the "de-limiter output".
|
| 215 |
+
You can modify the mixing coefficient by yourself.
|
| 216 |
+
If the coefficient is 0.3 then the output will be the "loudness_normalized_input * 0.3 + de-limiter_output * 0.7"
|
| 217 |
+
</div>
|
| 218 |
+
"""
|
| 219 |
+
)
|
| 220 |
+
with gr.Row().style(mobile_collapse=False, equal_height=True):
|
| 221 |
+
with gr.Column():
|
| 222 |
+
with gr.Box():
|
| 223 |
+
input_audio = gr.Audio(source="upload", label="De-limiter Input")
|
| 224 |
+
btn = gr.Button("De-limit")
|
| 225 |
+
with gr.Column():
|
| 226 |
+
with gr.Box():
|
| 227 |
+
loud_norm_input = gr.Audio(label="Loudness Normalized Input (-14LUFS)")
|
| 228 |
+
with gr.Box():
|
| 229 |
+
output_audio = gr.Audio(label="De-limiter Output")
|
| 230 |
+
with gr.Box():
|
| 231 |
+
output_audio_parallel = gr.Audio(
|
| 232 |
+
label="Parallel Mix of the Input and its De-limiter Output"
|
| 233 |
+
)
|
| 234 |
+
slider = gr.Slider(
|
| 235 |
+
minimum=0,
|
| 236 |
+
maximum=1,
|
| 237 |
+
step=0.1,
|
| 238 |
+
value=0.5,
|
| 239 |
+
label="Parallel Mix Coefficient",
|
| 240 |
+
)
|
| 241 |
+
btn.click(
|
| 242 |
+
main,
|
| 243 |
+
inputs=[input_audio, slider],
|
| 244 |
+
outputs=[output_audio, loud_norm_input, output_audio_parallel],
|
| 245 |
+
)
|
| 246 |
+
slider.release(
|
| 247 |
+
parallel_mix,
|
| 248 |
+
inputs=[input_audio, output_audio, slider],
|
| 249 |
+
outputs=output_audio_parallel,
|
| 250 |
+
)
|
| 251 |
+
with gr.Row().style(mobile_collapse=False, equal_height=True):
|
| 252 |
+
with gr.Column():
|
| 253 |
+
with gr.Box():
|
| 254 |
+
plot = gr.Plot(label="Plots")
|
| 255 |
+
btn2 = gr.Button("Show Plots")
|
| 256 |
+
slider_plot_x = gr.Slider(
|
| 257 |
+
minimum=1,
|
| 258 |
+
maximum=100,
|
| 259 |
+
step=1,
|
| 260 |
+
value=20,
|
| 261 |
+
label="Plot X-axis size",
|
| 262 |
+
)
|
| 263 |
+
slider_plot_y = gr.Slider(
|
| 264 |
+
minimum=1,
|
| 265 |
+
maximum=30,
|
| 266 |
+
step=1,
|
| 267 |
+
value=9,
|
| 268 |
+
label="Plot Y-axis size",
|
| 269 |
+
)
|
| 270 |
+
btn2.click(
|
| 271 |
+
waveform_plot,
|
| 272 |
+
inputs=[
|
| 273 |
+
loud_norm_input,
|
| 274 |
+
output_audio,
|
| 275 |
+
output_audio_parallel,
|
| 276 |
+
slider_plot_x,
|
| 277 |
+
slider_plot_y,
|
| 278 |
+
],
|
| 279 |
+
outputs=plot,
|
| 280 |
+
)
|
| 281 |
+
slider.release(
|
| 282 |
+
waveform_plot,
|
| 283 |
+
inputs=[
|
| 284 |
+
loud_norm_input,
|
| 285 |
+
output_audio,
|
| 286 |
+
output_audio_parallel,
|
| 287 |
+
slider_plot_x,
|
| 288 |
+
slider_plot_y,
|
| 289 |
+
],
|
| 290 |
+
outputs=plot,
|
| 291 |
+
)
|
| 292 |
+
if __name__ == "__main__":
|
| 293 |
+
demo.launch(debug=True)
|
configs/delimit_6_s.yaml
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# For De-limit task, Conv-TasNet.
|
| 2 |
+
# si_sdr loss
|
| 3 |
+
#
|
| 4 |
+
# ozone_train_fixed is about 6.36 hours
|
| 5 |
+
# 300,000 segments is about 333.33 hours
|
| 6 |
+
# ratio should be about 0.019
|
| 7 |
+
|
| 8 |
+
wandb_params:
|
| 9 |
+
use_wandb: true
|
| 10 |
+
entity: null # your wandb id
|
| 11 |
+
project: delimit # your wandb project
|
| 12 |
+
rerun_id: null # use when you rerun wandb.
|
| 13 |
+
sweep: false
|
| 14 |
+
|
| 15 |
+
sys_params:
|
| 16 |
+
nb_workers: 4
|
| 17 |
+
seed: 777
|
| 18 |
+
n_nodes: 1
|
| 19 |
+
port: null
|
| 20 |
+
rank: 0
|
| 21 |
+
|
| 22 |
+
task_params:
|
| 23 |
+
target: all # choices=["all"]
|
| 24 |
+
train: true
|
| 25 |
+
dataset: delimit # choices=["musdb", "delimit"]
|
| 26 |
+
|
| 27 |
+
dir_params:
|
| 28 |
+
root: /path/to/musdb18hq
|
| 29 |
+
output_directory: /path/to/results
|
| 30 |
+
exp_name: convtasnet_6_s # you MUST specify this
|
| 31 |
+
resume: null # "path of checkpoint folder"
|
| 32 |
+
continual_train: false # when we want to use a pre-trained model but not want to use lr_scheduler history.
|
| 33 |
+
delimit_valid_root: null
|
| 34 |
+
delimit_valid_L_root: null
|
| 35 |
+
ozone_root: /path/to/musdb-XL-train # you have to specify data_params.use_fixed
|
| 36 |
+
|
| 37 |
+
hyperparams:
|
| 38 |
+
batch_size: 8 # with 1 gpus (we used 2080ti 11GB)
|
| 39 |
+
epochs: 200
|
| 40 |
+
optimizer: adamw
|
| 41 |
+
weight_decay: 0.01
|
| 42 |
+
lr: 0.00003
|
| 43 |
+
lr_decay_gamma: 0.5
|
| 44 |
+
lr_decay_patience: 15
|
| 45 |
+
patience: 50
|
| 46 |
+
lr_scheduler: step_lr
|
| 47 |
+
gradient_clip: 5.0
|
| 48 |
+
ema: false
|
| 49 |
+
|
| 50 |
+
data_params:
|
| 51 |
+
nfft: 4096
|
| 52 |
+
nhop: 1024
|
| 53 |
+
nb_channels: 2
|
| 54 |
+
sample_rate: 44100
|
| 55 |
+
seq_dur: 4.0
|
| 56 |
+
singleset_num_frames: null
|
| 57 |
+
samples_per_track: 128 # "Number of samples per track to use for training."
|
| 58 |
+
limitaug_method: ozone
|
| 59 |
+
limitaug_mode: null
|
| 60 |
+
limitaug_custom_target_lufs: null
|
| 61 |
+
limitaug_custom_target_lufs_std: null
|
| 62 |
+
target_loudnorm_lufs: -14.0
|
| 63 |
+
random_mix: true
|
| 64 |
+
target_limitaug_mode: null
|
| 65 |
+
target_limitaug_custom_target_lufs: null
|
| 66 |
+
target_limitaug_custom_target_lufs_std: null
|
| 67 |
+
custom_limiter_attack_range: null
|
| 68 |
+
custom_limiter_release_range: null
|
| 69 |
+
use_fixed: 0.019 # range 0.0 ~ 1.0 => 1.0 will use fixed Ozoned_mixture training examples only.
|
| 70 |
+
|
| 71 |
+
model_loss_params:
|
| 72 |
+
architecture: conv_tasnet_mask_on_output # Sample-wise Gain Inversion (SGI)
|
| 73 |
+
train_loss_func: [si_sdr]
|
| 74 |
+
train_loss_scales: [1.]
|
| 75 |
+
valid_loss_func: [si_sdr]
|
| 76 |
+
valid_loss_scales: [1.]
|
| 77 |
+
|
| 78 |
+
conv_tasnet_params:
|
| 79 |
+
encoder_activation: relu
|
| 80 |
+
n_filters: 512
|
| 81 |
+
kernel_size: 128 # about 3ms in 44100Hz
|
| 82 |
+
stride: 64
|
| 83 |
+
n_blocks: 5
|
| 84 |
+
n_repeats: 2
|
| 85 |
+
bn_chan: 128
|
| 86 |
+
hid_chan: 512
|
| 87 |
+
skip_chan: 128
|
| 88 |
+
# conv_kernel_size:
|
| 89 |
+
# norm_type:
|
| 90 |
+
mask_act: relu
|
| 91 |
+
# causal:
|
| 92 |
+
decoder_activation: sigmoid
|
dataloader/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .dataset import aug_from_str, MusdbTrainDataset, MusdbValidDataset
|
| 2 |
+
from .singleset import SingleTrackSet
|
| 3 |
+
from .delimit_dataset import (
|
| 4 |
+
DelimitTrainDataset,
|
| 5 |
+
DelimitValidDataset,
|
| 6 |
+
OzoneTrainDataset,
|
| 7 |
+
OzoneValidDataset,
|
| 8 |
+
)
|
dataloader/dataset.py
ADDED
|
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dataloader based on https://github.com/jeonchangbin49/LimitAug
|
| 2 |
+
import os
|
| 3 |
+
from glob import glob
|
| 4 |
+
import random
|
| 5 |
+
from typing import Optional, Callable
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import librosa
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
import pyloudnorm as pyln
|
| 12 |
+
from pedalboard import Pedalboard, Limiter, Gain, Compressor, Clipping
|
| 13 |
+
|
| 14 |
+
from utils import load_wav_arbitrary_position_stereo, db2linear
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# based on https://github.com/sigsep/open-unmix-pytorch
|
| 18 |
+
def aug_from_str(list_of_function_names: list):
|
| 19 |
+
if list_of_function_names:
|
| 20 |
+
return Compose([globals()["_augment_" + aug] for aug in list_of_function_names])
|
| 21 |
+
else:
|
| 22 |
+
return lambda audio: audio
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Compose(object):
|
| 26 |
+
"""Composes several augmentation transforms.
|
| 27 |
+
Args:
|
| 28 |
+
augmentations: list of augmentations to compose.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self, transforms):
|
| 32 |
+
self.transforms = transforms
|
| 33 |
+
|
| 34 |
+
def __call__(self, audio: torch.Tensor) -> torch.Tensor:
|
| 35 |
+
for t in self.transforms:
|
| 36 |
+
audio = t(audio)
|
| 37 |
+
return audio
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# numpy based augmentation
|
| 41 |
+
# based on https://github.com/sigsep/open-unmix-pytorch
|
| 42 |
+
def _augment_gain(audio, low=0.25, high=1.25):
|
| 43 |
+
"""Applies a random gain between `low` and `high`"""
|
| 44 |
+
g = low + random.random() * (high - low)
|
| 45 |
+
return audio * g
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _augment_channelswap(audio):
|
| 49 |
+
"""Swap channels of stereo signals with a probability of p=0.5"""
|
| 50 |
+
if audio.shape[0] == 2 and random.random() < 0.5:
|
| 51 |
+
return np.flip(audio, axis=0) # axis=0 must be given
|
| 52 |
+
else:
|
| 53 |
+
return audio
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Linear gain increasing implementation for Method (1)
|
| 57 |
+
def apply_linear_gain_increase(mixture, target, board, meter, samplerate, target_lufs):
|
| 58 |
+
mixture, target = mixture.T, target.T
|
| 59 |
+
loudness = meter.integrated_loudness(mixture)
|
| 60 |
+
|
| 61 |
+
if np.isinf(loudness):
|
| 62 |
+
augmented_gain = 0.0
|
| 63 |
+
board[0].gain_db = augmented_gain
|
| 64 |
+
else:
|
| 65 |
+
augmented_gain = target_lufs - loudness
|
| 66 |
+
board[0].gain_db = augmented_gain
|
| 67 |
+
mixture = board(mixture.T, samplerate)
|
| 68 |
+
target = board(target.T, samplerate)
|
| 69 |
+
return mixture, target
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# LimitAug implementation for Method (2) and
|
| 73 |
+
# implementation of LimitAug then Loudness normalization for Method (4)
|
| 74 |
+
def apply_limitaug(
|
| 75 |
+
audio,
|
| 76 |
+
board,
|
| 77 |
+
meter,
|
| 78 |
+
samplerate,
|
| 79 |
+
target_lufs,
|
| 80 |
+
target_loudnorm_lufs=None,
|
| 81 |
+
loudness=None,
|
| 82 |
+
):
|
| 83 |
+
audio = audio.T
|
| 84 |
+
if loudness is None:
|
| 85 |
+
loudness = meter.integrated_loudness(audio)
|
| 86 |
+
|
| 87 |
+
if np.isinf(loudness):
|
| 88 |
+
augmented_gain = 0.0
|
| 89 |
+
board[0].gain_db = augmented_gain
|
| 90 |
+
else:
|
| 91 |
+
augmented_gain = target_lufs - loudness
|
| 92 |
+
board[0].gain_db = augmented_gain
|
| 93 |
+
audio = board(audio.T, samplerate)
|
| 94 |
+
|
| 95 |
+
if target_loudnorm_lufs:
|
| 96 |
+
after_loudness = meter.integrated_loudness(audio.T)
|
| 97 |
+
|
| 98 |
+
if np.isinf(after_loudness):
|
| 99 |
+
pass
|
| 100 |
+
else:
|
| 101 |
+
target_gain = target_loudnorm_lufs - after_loudness
|
| 102 |
+
audio = audio * db2linear(target_gain)
|
| 103 |
+
return audio, loudness
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
"""
|
| 107 |
+
This dataloader implementation is based on https://github.com/sigsep/open-unmix-pytorch
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class MusdbTrainDataset(Dataset):
|
| 112 |
+
def __init__(
|
| 113 |
+
self,
|
| 114 |
+
target: str = "vocals",
|
| 115 |
+
root: str = None,
|
| 116 |
+
seq_duration: Optional[float] = 6.0,
|
| 117 |
+
samples_per_track: int = 64,
|
| 118 |
+
source_augmentations: Optional[Callable] = lambda audio: audio,
|
| 119 |
+
sample_rate: int = 44100,
|
| 120 |
+
seed: int = 42,
|
| 121 |
+
limitaug_method: str = "limitaug_then_loudnorm",
|
| 122 |
+
limitaug_mode: str = "normal_L",
|
| 123 |
+
limitaug_custom_target_lufs: float = None,
|
| 124 |
+
limitaug_custom_target_lufs_std: float = None,
|
| 125 |
+
target_loudnorm_lufs: float = -14.0,
|
| 126 |
+
custom_limiter_attack_range: list = [2.0, 2.0],
|
| 127 |
+
custom_limiter_release_range: list = [200.0, 200.0],
|
| 128 |
+
*args,
|
| 129 |
+
**kwargs,
|
| 130 |
+
) -> None:
|
| 131 |
+
"""
|
| 132 |
+
Parameters
|
| 133 |
+
----------
|
| 134 |
+
limitaug_method : str
|
| 135 |
+
choose from ["linear_gain_increase", "limitaug", "limitaug_then_loudnorm", "only_loudnorm"]
|
| 136 |
+
limitaug_mode : str
|
| 137 |
+
choose from ["uniform", "normal", "normal_L", "normal_XL", "normal_short_term", "normal_L_short_term", "normal_XL_short_term", "custom"]
|
| 138 |
+
limitaug_custom_target_lufs : float
|
| 139 |
+
valid only when
|
| 140 |
+
limitaug_mode == "custom"
|
| 141 |
+
limitaug_custom_target_lufs_std : float
|
| 142 |
+
also valid only when
|
| 143 |
+
limitaug_mode == "custom
|
| 144 |
+
target_loudnorm_lufs : float
|
| 145 |
+
valid only when
|
| 146 |
+
limitaug_method == 'limitaug_then_loudnorm' or 'only_loudnorm'
|
| 147 |
+
default is -14.
|
| 148 |
+
To the best of my knowledge, Spotify and Youtube music is using -14 as a reference loudness normalization level.
|
| 149 |
+
No special reason for the choice of -14 as target_loudnorm_lufs.
|
| 150 |
+
target : str
|
| 151 |
+
target name of the source to be separated, defaults to ``vocals``.
|
| 152 |
+
root : str
|
| 153 |
+
root path of MUSDB
|
| 154 |
+
seq_duration : float
|
| 155 |
+
training is performed in chunks of ``seq_duration`` (in seconds,
|
| 156 |
+
defaults to ``None`` which loads the full audio track
|
| 157 |
+
samples_per_track : int
|
| 158 |
+
sets the number of samples, yielded from each track per epoch.
|
| 159 |
+
Defaults to 64
|
| 160 |
+
source_augmentations : list[callables]
|
| 161 |
+
provide list of augmentation function that take a multi-channel
|
| 162 |
+
audio file of shape (src, samples) as input and output. Defaults to
|
| 163 |
+
no-augmentations (input = output)
|
| 164 |
+
seed : int
|
| 165 |
+
control randomness of dataset iterations
|
| 166 |
+
args, kwargs : additional keyword arguments
|
| 167 |
+
used to add further control for the musdb dataset
|
| 168 |
+
initialization function.
|
| 169 |
+
"""
|
| 170 |
+
|
| 171 |
+
self.seed = seed
|
| 172 |
+
random.seed(seed)
|
| 173 |
+
self.seq_duration = seq_duration
|
| 174 |
+
self.target = target
|
| 175 |
+
self.samples_per_track = samples_per_track
|
| 176 |
+
self.source_augmentations = source_augmentations
|
| 177 |
+
self.sample_rate = sample_rate
|
| 178 |
+
|
| 179 |
+
self.root = root
|
| 180 |
+
self.sources = ["vocals", "bass", "drums", "other"]
|
| 181 |
+
self.train_list = glob(f"{self.root}/train/*")
|
| 182 |
+
self.valid_list = [
|
| 183 |
+
"ANiMAL - Rockshow",
|
| 184 |
+
"Actions - One Minute Smile",
|
| 185 |
+
"Alexander Ross - Goodbye Bolero",
|
| 186 |
+
"Clara Berry And Wooldog - Waltz For My Victims",
|
| 187 |
+
"Fergessen - Nos Palpitants",
|
| 188 |
+
"James May - On The Line",
|
| 189 |
+
"Johnny Lokke - Promises & Lies",
|
| 190 |
+
"Leaf - Summerghost",
|
| 191 |
+
"Meaxic - Take A Step",
|
| 192 |
+
"Patrick Talbot - A Reason To Leave",
|
| 193 |
+
"Skelpolu - Human Mistakes",
|
| 194 |
+
"Traffic Experiment - Sirens",
|
| 195 |
+
"Triviul - Angelsaint",
|
| 196 |
+
"Young Griffo - Pennies",
|
| 197 |
+
]
|
| 198 |
+
|
| 199 |
+
self.train_list = [
|
| 200 |
+
x for x in self.train_list if os.path.basename(x) not in self.valid_list
|
| 201 |
+
]
|
| 202 |
+
|
| 203 |
+
# limitaug related
|
| 204 |
+
self.limitaug_method = limitaug_method
|
| 205 |
+
self.limitaug_mode = limitaug_mode
|
| 206 |
+
self.limitaug_custom_target_lufs = limitaug_custom_target_lufs
|
| 207 |
+
self.limitaug_custom_target_lufs_std = limitaug_custom_target_lufs_std
|
| 208 |
+
self.target_loudnorm_lufs = target_loudnorm_lufs
|
| 209 |
+
self.meter = pyln.Meter(self.sample_rate)
|
| 210 |
+
|
| 211 |
+
# Method (1) in our paper's Results section and Table 5
|
| 212 |
+
if self.limitaug_method == "linear_gain_increase":
|
| 213 |
+
print("using linear gain increasing!")
|
| 214 |
+
self.board = Pedalboard([Gain(gain_db=0.0)])
|
| 215 |
+
|
| 216 |
+
# Method (2) in our paper's Results section and Table 5
|
| 217 |
+
elif self.limitaug_method == "limitaug":
|
| 218 |
+
print("using limitaug!")
|
| 219 |
+
self.board = Pedalboard(
|
| 220 |
+
[Gain(gain_db=0.0), Limiter(threshold_db=0.0, release_ms=100.0)]
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Method (3) in our paper's Results section and Table 5
|
| 224 |
+
elif self.limitaug_method == "only_loudnorm":
|
| 225 |
+
print("using only loudness normalized inputs")
|
| 226 |
+
|
| 227 |
+
# Method (4) in our paper's Results section and Table 5
|
| 228 |
+
elif self.limitaug_method == "limitaug_then_loudnorm":
|
| 229 |
+
print("using limitaug then loudness normalize!")
|
| 230 |
+
self.board = Pedalboard(
|
| 231 |
+
[Gain(gain_db=0.0), Limiter(threshold_db=0.0, release_ms=100.0)]
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
elif self.limitaug_method == "custom_limiter_limitaug":
|
| 235 |
+
print("using Custom limiter limitaug!")
|
| 236 |
+
self.custom_limiter_attack_range = custom_limiter_attack_range
|
| 237 |
+
self.custom_limiter_release_range = custom_limiter_release_range
|
| 238 |
+
self.board = Pedalboard(
|
| 239 |
+
[
|
| 240 |
+
Gain(gain_db=0.0),
|
| 241 |
+
Compressor(
|
| 242 |
+
threshold_db=-10.0, ratio=4.0, attack_ms=2.0, release_ms=200.0
|
| 243 |
+
), # attack_ms and release_ms will be changed later.
|
| 244 |
+
Compressor(
|
| 245 |
+
threshold_db=0.0,
|
| 246 |
+
ratio=1000.0,
|
| 247 |
+
attack_ms=0.001,
|
| 248 |
+
release_ms=100.0,
|
| 249 |
+
),
|
| 250 |
+
Gain(gain_db=3.75),
|
| 251 |
+
Clipping(threshold_db=0.0),
|
| 252 |
+
]
|
| 253 |
+
) # This implementation is the same as JUCE Limiter.
|
| 254 |
+
# However, we want the first compressor to have a variable attack and release time.
|
| 255 |
+
# Therefore, we use the Custom Limiter instead of the JUCE Limiter.
|
| 256 |
+
|
| 257 |
+
self.limitaug_mode_statistics = {
|
| 258 |
+
"normal": [
|
| 259 |
+
-15.954,
|
| 260 |
+
1.264,
|
| 261 |
+
], # -15.954 is mean LUFS of musdb-hq and 1.264 is standard deviation
|
| 262 |
+
"normal_L": [
|
| 263 |
+
-10.887,
|
| 264 |
+
1.191,
|
| 265 |
+
], # -10.887 is mean LUFS of musdb-L and 1.191 is standard deviation
|
| 266 |
+
"normal_XL": [
|
| 267 |
+
-8.608,
|
| 268 |
+
1.165,
|
| 269 |
+
], # -8.608 is mean LUFS of musdb-L and 1.165 is standard deviation
|
| 270 |
+
"normal_short_term": [
|
| 271 |
+
-17.317,
|
| 272 |
+
5.036,
|
| 273 |
+
], # In our experiments, short-term statistics were not helpful.
|
| 274 |
+
"normal_L_short_term": [-12.303, 5.233],
|
| 275 |
+
"normal_XL_short_term": [-9.988, 5.518],
|
| 276 |
+
"custom": [limitaug_custom_target_lufs, limitaug_custom_target_lufs_std],
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
def sample_target_lufs(self):
|
| 280 |
+
if (
|
| 281 |
+
self.limitaug_mode == "uniform"
|
| 282 |
+
): # if limitaug_mode is uniform, then choose target_lufs from uniform distribution
|
| 283 |
+
target_lufs = random.uniform(-20, -5)
|
| 284 |
+
else: # else, choose target_lufs from gaussian distribution
|
| 285 |
+
target_lufs = random.gauss(
|
| 286 |
+
self.limitaug_mode_statistics[self.limitaug_mode][0],
|
| 287 |
+
self.limitaug_mode_statistics[self.limitaug_mode][1],
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
return target_lufs
|
| 291 |
+
|
| 292 |
+
def get_limitaug_results(self, mixture, target):
|
| 293 |
+
# Apply linear gain increasing (Method (1))
|
| 294 |
+
if self.limitaug_method == "linear_gain_increase":
|
| 295 |
+
target_lufs = self.sample_target_lufs()
|
| 296 |
+
mixture, target = apply_linear_gain_increase(
|
| 297 |
+
mixture,
|
| 298 |
+
target,
|
| 299 |
+
self.board,
|
| 300 |
+
self.meter,
|
| 301 |
+
self.sample_rate,
|
| 302 |
+
target_lufs=target_lufs,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
# Apply LimitAug (Method (2))
|
| 306 |
+
elif self.limitaug_method == "limitaug":
|
| 307 |
+
self.board[1].release_ms = random.uniform(30.0, 200.0)
|
| 308 |
+
mixture_orig = mixture.copy()
|
| 309 |
+
target_lufs = self.sample_target_lufs()
|
| 310 |
+
mixture, _ = apply_limitaug(
|
| 311 |
+
mixture,
|
| 312 |
+
self.board,
|
| 313 |
+
self.meter,
|
| 314 |
+
self.sample_rate,
|
| 315 |
+
target_lufs=target_lufs,
|
| 316 |
+
)
|
| 317 |
+
print("mixture shape:", mixture.shape)
|
| 318 |
+
print("target shape:", target.shape)
|
| 319 |
+
target *= mixture / (mixture_orig + 1e-8)
|
| 320 |
+
|
| 321 |
+
# Apply only loudness normalization (Method(3))
|
| 322 |
+
elif self.limitaug_method == "only_loudnorm":
|
| 323 |
+
mixture_loudness = self.meter.integrated_loudness(mixture.T)
|
| 324 |
+
if np.isinf(
|
| 325 |
+
mixture_loudness
|
| 326 |
+
): # if the source is silence, then mixture_loudness is -inf.
|
| 327 |
+
pass
|
| 328 |
+
else:
|
| 329 |
+
augmented_gain = (
|
| 330 |
+
self.target_loudnorm_lufs - mixture_loudness
|
| 331 |
+
) # default target_loudnorm_lufs is -14.
|
| 332 |
+
mixture = mixture * db2linear(augmented_gain)
|
| 333 |
+
target = target * db2linear(augmented_gain)
|
| 334 |
+
|
| 335 |
+
# Apply LimitAug then loudness normalization (Method (4))
|
| 336 |
+
elif self.limitaug_method == "limitaug_then_loudnorm":
|
| 337 |
+
self.board[1].release_ms = random.uniform(30.0, 200.0)
|
| 338 |
+
mixture_orig = mixture.copy()
|
| 339 |
+
target_lufs = self.sample_target_lufs()
|
| 340 |
+
mixture, _ = apply_limitaug(
|
| 341 |
+
mixture,
|
| 342 |
+
self.board,
|
| 343 |
+
self.meter,
|
| 344 |
+
self.sample_rate,
|
| 345 |
+
target_lufs=target_lufs,
|
| 346 |
+
target_loudnorm_lufs=self.target_loudnorm_lufs,
|
| 347 |
+
)
|
| 348 |
+
target *= mixture / (mixture_orig + 1e-8)
|
| 349 |
+
|
| 350 |
+
# Apply LimitAug using Custom Limiter
|
| 351 |
+
elif self.limitaug_method == "custom_limiter_limitaug":
|
| 352 |
+
# Change attack time of First compressor of the Limiter
|
| 353 |
+
self.board[1].attack_ms = random.uniform(
|
| 354 |
+
self.custom_limiter_attack_range[0], self.custom_limiter_attack_range[1]
|
| 355 |
+
)
|
| 356 |
+
# Change release time of First compressor of the Limiter
|
| 357 |
+
self.board[1].release_ms = random.uniform(
|
| 358 |
+
self.custom_limiter_release_range[0],
|
| 359 |
+
self.custom_limiter_release_range[1],
|
| 360 |
+
)
|
| 361 |
+
# Change release time of Second compressor of the Limiter
|
| 362 |
+
self.board[2].release_ms = random.uniform(30.0, 200.0)
|
| 363 |
+
mixture_orig = mixture.copy()
|
| 364 |
+
target_lufs = self.sample_target_lufs()
|
| 365 |
+
mixture, _ = apply_limitaug(
|
| 366 |
+
mixture,
|
| 367 |
+
self.board,
|
| 368 |
+
self.meter,
|
| 369 |
+
self.sample_rate,
|
| 370 |
+
target_lufs=target_lufs,
|
| 371 |
+
target_loudnorm_lufs=self.target_loudnorm_lufs,
|
| 372 |
+
)
|
| 373 |
+
target *= mixture / (mixture_orig + 1e-8)
|
| 374 |
+
|
| 375 |
+
return mixture, target
|
| 376 |
+
|
| 377 |
+
def __getitem__(self, index):
|
| 378 |
+
audio_sources = []
|
| 379 |
+
target_ind = None
|
| 380 |
+
|
| 381 |
+
for k, source in enumerate(self.sources):
|
| 382 |
+
# memorize index of target source
|
| 383 |
+
if source == self.target: # if source is 'vocals'
|
| 384 |
+
target_ind = k
|
| 385 |
+
track_path = self.train_list[
|
| 386 |
+
index // self.samples_per_track
|
| 387 |
+
] # we want to use # training samples per each track.
|
| 388 |
+
audio_path = f"{track_path}/{source}.wav"
|
| 389 |
+
audio = load_wav_arbitrary_position_stereo(
|
| 390 |
+
audio_path, self.sample_rate, self.seq_duration
|
| 391 |
+
)
|
| 392 |
+
else:
|
| 393 |
+
track_path = random.choice(self.train_list)
|
| 394 |
+
audio_path = f"{track_path}/{source}.wav"
|
| 395 |
+
audio = load_wav_arbitrary_position_stereo(
|
| 396 |
+
audio_path, self.sample_rate, self.seq_duration
|
| 397 |
+
)
|
| 398 |
+
audio = self.source_augmentations(audio)
|
| 399 |
+
audio_sources.append(audio)
|
| 400 |
+
|
| 401 |
+
stems = np.stack(audio_sources, axis=0)
|
| 402 |
+
|
| 403 |
+
# # apply linear mix over source index=0
|
| 404 |
+
x = stems.sum(0)
|
| 405 |
+
# get the target stem
|
| 406 |
+
y = stems[target_ind]
|
| 407 |
+
|
| 408 |
+
# Apply the limitaug,
|
| 409 |
+
x, y = self.get_limitaug_results(x, y)
|
| 410 |
+
|
| 411 |
+
x = torch.as_tensor(x, dtype=torch.float32)
|
| 412 |
+
y = torch.as_tensor(y, dtype=torch.float32)
|
| 413 |
+
|
| 414 |
+
return x, y
|
| 415 |
+
|
| 416 |
+
def __len__(self):
|
| 417 |
+
return len(self.train_list) * self.samples_per_track
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
class MusdbValidDataset(Dataset):
|
| 421 |
+
def __init__(
|
| 422 |
+
self,
|
| 423 |
+
target: str = "vocals",
|
| 424 |
+
root: str = None,
|
| 425 |
+
*args,
|
| 426 |
+
**kwargs,
|
| 427 |
+
) -> None:
|
| 428 |
+
"""MUSDB18 torch.data.Dataset that samples from the MUSDB tracks
|
| 429 |
+
using track and excerpts with replacement.
|
| 430 |
+
Parameters
|
| 431 |
+
----------
|
| 432 |
+
target : str
|
| 433 |
+
target name of the source to be separated, defaults to ``vocals``.
|
| 434 |
+
root : str
|
| 435 |
+
root path of MUSDB18HQ dataset, defaults to ``None``.
|
| 436 |
+
args, kwargs : additional keyword arguments
|
| 437 |
+
used to add further control for the musdb dataset
|
| 438 |
+
initialization function.
|
| 439 |
+
"""
|
| 440 |
+
self.target = target
|
| 441 |
+
self.sample_rate = 44100.0 # musdb is fixed sample rate
|
| 442 |
+
|
| 443 |
+
self.root = root
|
| 444 |
+
self.sources = ["vocals", "bass", "drums", "other"]
|
| 445 |
+
self.train_list = glob(f"{self.root}/train/*")
|
| 446 |
+
|
| 447 |
+
self.valid_list = [
|
| 448 |
+
"ANiMAL - Rockshow",
|
| 449 |
+
"Actions - One Minute Smile",
|
| 450 |
+
"Alexander Ross - Goodbye Bolero",
|
| 451 |
+
"Clara Berry And Wooldog - Waltz For My Victims",
|
| 452 |
+
"Fergessen - Nos Palpitants",
|
| 453 |
+
"James May - On The Line",
|
| 454 |
+
"Johnny Lokke - Promises & Lies",
|
| 455 |
+
"Leaf - Summerghost",
|
| 456 |
+
"Meaxic - Take A Step",
|
| 457 |
+
"Patrick Talbot - A Reason To Leave",
|
| 458 |
+
"Skelpolu - Human Mistakes",
|
| 459 |
+
"Traffic Experiment - Sirens",
|
| 460 |
+
"Triviul - Angelsaint",
|
| 461 |
+
"Young Griffo - Pennies",
|
| 462 |
+
]
|
| 463 |
+
self.valid_list = [
|
| 464 |
+
x for x in self.train_list if os.path.basename(x) in self.valid_list
|
| 465 |
+
]
|
| 466 |
+
|
| 467 |
+
def __getitem__(self, index):
|
| 468 |
+
audio_sources = []
|
| 469 |
+
target_ind = None
|
| 470 |
+
|
| 471 |
+
for k, source in enumerate(self.sources):
|
| 472 |
+
# memorize index of target source
|
| 473 |
+
if source == self.target: # if source is 'vocals'
|
| 474 |
+
target_ind = k
|
| 475 |
+
track_path = self.valid_list[index]
|
| 476 |
+
song_name = os.path.basename(track_path)
|
| 477 |
+
audio_path = f"{track_path}/{source}.wav"
|
| 478 |
+
# audio = utils.load_wav_stereo(audio_path, self.sample_rate)
|
| 479 |
+
audio = librosa.load(audio_path, mono=False, sr=self.sample_rate)[0]
|
| 480 |
+
else:
|
| 481 |
+
track_path = self.valid_list[index]
|
| 482 |
+
song_name = os.path.basename(track_path)
|
| 483 |
+
audio_path = f"{track_path}/{source}.wav"
|
| 484 |
+
# audio = utils.load_wav_stereo(audio_path, self.sample_rate)
|
| 485 |
+
audio = librosa.load(audio_path, mono=False, sr=self.sample_rate)[0]
|
| 486 |
+
|
| 487 |
+
audio = torch.as_tensor(audio, dtype=torch.float32)
|
| 488 |
+
audio_sources.append(audio)
|
| 489 |
+
|
| 490 |
+
stems = torch.stack(audio_sources, dim=0)
|
| 491 |
+
# # apply linear mix over source index=0
|
| 492 |
+
x = stems.sum(0)
|
| 493 |
+
# get the target stem
|
| 494 |
+
y = stems[target_ind]
|
| 495 |
+
|
| 496 |
+
return x, y, song_name
|
| 497 |
+
|
| 498 |
+
def __len__(self):
|
| 499 |
+
return len(self.valid_list)
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# If you want to check the LUFS values of training examples, run this.
|
| 503 |
+
if __name__ == "__main__":
|
| 504 |
+
import argparse
|
| 505 |
+
|
| 506 |
+
parser = argparse.ArgumentParser(
|
| 507 |
+
description="Make musdb-L and musdb-XL dataset from its ratio data"
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
parser.add_argument(
|
| 511 |
+
"--musdb_root",
|
| 512 |
+
type=str,
|
| 513 |
+
default="/path/to/musdb",
|
| 514 |
+
help="root path of musdb-hq dataset",
|
| 515 |
+
)
|
| 516 |
+
parser.add_argument(
|
| 517 |
+
"--limitaug_method",
|
| 518 |
+
type=str,
|
| 519 |
+
default="limitaug",
|
| 520 |
+
choices=[
|
| 521 |
+
"linear_gain_increase",
|
| 522 |
+
"limitaug",
|
| 523 |
+
"limitaug_then_loudnorm",
|
| 524 |
+
"only_loudnorm",
|
| 525 |
+
None,
|
| 526 |
+
],
|
| 527 |
+
help="choose limitaug method",
|
| 528 |
+
)
|
| 529 |
+
parser.add_argument(
|
| 530 |
+
"--limitaug_mode",
|
| 531 |
+
type=str,
|
| 532 |
+
default="normal_L",
|
| 533 |
+
choices=[
|
| 534 |
+
"uniform",
|
| 535 |
+
"normal",
|
| 536 |
+
"normal_L",
|
| 537 |
+
"normal_XL",
|
| 538 |
+
"normal_short_term",
|
| 539 |
+
"normal_L_short_term",
|
| 540 |
+
"normal_XL_short_term",
|
| 541 |
+
"custom",
|
| 542 |
+
],
|
| 543 |
+
help="if you use LimitAug, what lufs distribution to target",
|
| 544 |
+
)
|
| 545 |
+
parser.add_argument(
|
| 546 |
+
"--limitaug_custom_target_lufs",
|
| 547 |
+
type=float,
|
| 548 |
+
default=None,
|
| 549 |
+
help="if limitaug_mode is custom, set custom target lufs for LimitAug",
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
args, _ = parser.parse_known_args()
|
| 553 |
+
|
| 554 |
+
source_augmentations_ = aug_from_str(["gain", "channelswap"])
|
| 555 |
+
|
| 556 |
+
train_dataset = MusdbTrainDataset(
|
| 557 |
+
target="vocals",
|
| 558 |
+
root=args.musdb_root,
|
| 559 |
+
seq_duration=6.0,
|
| 560 |
+
source_augmentations=source_augmentations_,
|
| 561 |
+
limitaug_method=args.limitaug_method,
|
| 562 |
+
limitaug_mode=args.limitaug_mode,
|
| 563 |
+
limitaug_custom_target_lufs=args.limitaug_custom_target_lufs,
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
dataloader = torch.utils.data.DataLoader(
|
| 567 |
+
train_dataset,
|
| 568 |
+
batch_size=1,
|
| 569 |
+
shuffle=True,
|
| 570 |
+
num_workers=4,
|
| 571 |
+
pin_memory=True,
|
| 572 |
+
drop_last=False,
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
meter = pyln.Meter(44100)
|
| 576 |
+
for i in range(5):
|
| 577 |
+
for x, y in dataloader:
|
| 578 |
+
loudness = meter.integrated_loudness(x[0].numpy().T)
|
| 579 |
+
print(f"mixture loudness : {loudness} LUFS")
|
dataloader/delimit_dataset.py
ADDED
|
@@ -0,0 +1,573 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
from typing import Optional, Callable
|
| 4 |
+
import json
|
| 5 |
+
import glob
|
| 6 |
+
import csv
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import librosa
|
| 11 |
+
import pyloudnorm as pyln
|
| 12 |
+
from pedalboard import Pedalboard, Limiter, Gain, Compressor, Clipping
|
| 13 |
+
|
| 14 |
+
from .dataset import (
|
| 15 |
+
MusdbTrainDataset,
|
| 16 |
+
MusdbValidDataset,
|
| 17 |
+
apply_limitaug,
|
| 18 |
+
)
|
| 19 |
+
from utils import (
|
| 20 |
+
load_wav_arbitrary_position_stereo,
|
| 21 |
+
load_wav_specific_position_stereo,
|
| 22 |
+
db2linear,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DelimitTrainDataset(MusdbTrainDataset):
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
target: str = "all",
|
| 30 |
+
root: str = None,
|
| 31 |
+
seq_duration: Optional[float] = 6.0,
|
| 32 |
+
samples_per_track: int = 64,
|
| 33 |
+
source_augmentations: Optional[Callable] = lambda audio: audio,
|
| 34 |
+
sample_rate: int = 44100,
|
| 35 |
+
seed: int = 42,
|
| 36 |
+
limitaug_method: str = "limitaug",
|
| 37 |
+
limitaug_mode: str = "normal_L",
|
| 38 |
+
limitaug_custom_target_lufs: float = None,
|
| 39 |
+
limitaug_custom_target_lufs_std: float = None,
|
| 40 |
+
target_loudnorm_lufs: float = -14.0,
|
| 41 |
+
target_limitaug_mode: str = None,
|
| 42 |
+
target_limitaug_custom_target_lufs: float = None,
|
| 43 |
+
target_limitaug_custom_target_lufs_std: float = None,
|
| 44 |
+
custom_limiter_attack_range: list = [2.0, 2.0],
|
| 45 |
+
custom_limiter_release_range: list = [200.0, 200.0],
|
| 46 |
+
*args,
|
| 47 |
+
**kwargs,
|
| 48 |
+
) -> None:
|
| 49 |
+
super().__init__(
|
| 50 |
+
target=target,
|
| 51 |
+
root=root,
|
| 52 |
+
seq_duration=seq_duration,
|
| 53 |
+
samples_per_track=samples_per_track,
|
| 54 |
+
source_augmentations=source_augmentations,
|
| 55 |
+
sample_rate=sample_rate,
|
| 56 |
+
seed=seed,
|
| 57 |
+
limitaug_method=limitaug_method,
|
| 58 |
+
limitaug_mode=limitaug_mode,
|
| 59 |
+
limitaug_custom_target_lufs=limitaug_custom_target_lufs,
|
| 60 |
+
limitaug_custom_target_lufs_std=limitaug_custom_target_lufs_std,
|
| 61 |
+
target_loudnorm_lufs=target_loudnorm_lufs,
|
| 62 |
+
custom_limiter_attack_range=custom_limiter_attack_range,
|
| 63 |
+
custom_limiter_release_range=custom_limiter_release_range,
|
| 64 |
+
*args,
|
| 65 |
+
**kwargs,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
self.target_limitaug_mode = target_limitaug_mode
|
| 69 |
+
|
| 70 |
+
self.target_limitaug_custom_target_lufs = (target_limitaug_custom_target_lufs,)
|
| 71 |
+
self.target_limitaug_custom_target_lufs_std = (
|
| 72 |
+
target_limitaug_custom_target_lufs_std,
|
| 73 |
+
)
|
| 74 |
+
self.limitaug_mode_statistics["target_custom"] = [
|
| 75 |
+
target_limitaug_custom_target_lufs,
|
| 76 |
+
target_limitaug_custom_target_lufs_std,
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
"""
|
| 80 |
+
Parameters
|
| 81 |
+
----------
|
| 82 |
+
limitaug_method : str
|
| 83 |
+
choose from ["linear_gain_increase", "limitaug", "limitaug_then_loudnorm", "only_loudnorm"]
|
| 84 |
+
limitaug_mode : str
|
| 85 |
+
choose from ["uniform", "normal", "normal_L", "normal_XL", "normal_short_term", "normal_L_short_term", "normal_XL_short_term", "custom"]
|
| 86 |
+
limitaug_custom_target_lufs : float
|
| 87 |
+
valid only when
|
| 88 |
+
limitaug_mode == "custom"
|
| 89 |
+
target_loudnorm_lufs : float
|
| 90 |
+
valid only when
|
| 91 |
+
limitaug_method == 'limitaug_then_loudnorm' or 'only_loudnorm'
|
| 92 |
+
default is -14.
|
| 93 |
+
To the best of my knowledge, Spotify and Youtube music is using -14 as a reference loudness normalization level.
|
| 94 |
+
No special reason for the choice of -14 as target_loudnorm_lufs.
|
| 95 |
+
target : str
|
| 96 |
+
target name of the source to be separated, defaults to ``vocals``.
|
| 97 |
+
root : str
|
| 98 |
+
root path of MUSDB
|
| 99 |
+
seq_duration : float
|
| 100 |
+
training is performed in chunks of ``seq_duration`` (in seconds,
|
| 101 |
+
defaults to ``None`` which loads the full audio track
|
| 102 |
+
samples_per_track : int
|
| 103 |
+
sets the number of samples, yielded from each track per epoch.
|
| 104 |
+
Defaults to 64
|
| 105 |
+
source_augmentations : list[callables]
|
| 106 |
+
provide list of augmentation function that take a multi-channel
|
| 107 |
+
audio file of shape (src, samples) as input and output. Defaults to
|
| 108 |
+
no-augmentations (input = output)
|
| 109 |
+
seed : int
|
| 110 |
+
control randomness of dataset iterations
|
| 111 |
+
args, kwargs : additional keyword arguments
|
| 112 |
+
used to add further control for the musdb dataset
|
| 113 |
+
initialization function.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
# Get a limitaug result without target (individual stem source)
|
| 117 |
+
def get_limitaug_mixture(self, mixture):
|
| 118 |
+
if self.limitaug_method == "limitaug":
|
| 119 |
+
self.board[1].release_ms = random.uniform(30.0, 200.0)
|
| 120 |
+
target_lufs = self.sample_target_lufs()
|
| 121 |
+
mixture_limited, mixture_lufs = apply_limitaug(
|
| 122 |
+
mixture,
|
| 123 |
+
self.board,
|
| 124 |
+
self.meter,
|
| 125 |
+
self.sample_rate,
|
| 126 |
+
target_lufs=target_lufs,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
elif self.limitaug_method == "limitaug_then_loudnorm":
|
| 130 |
+
self.board[1].release_ms = random.uniform(30.0, 200.0)
|
| 131 |
+
target_lufs = self.sample_target_lufs()
|
| 132 |
+
mixture_limited, mixture_lufs = (
|
| 133 |
+
apply_limitaug(
|
| 134 |
+
mixture,
|
| 135 |
+
self.board,
|
| 136 |
+
self.meter,
|
| 137 |
+
self.sample_rate,
|
| 138 |
+
target_lufs=target_lufs,
|
| 139 |
+
target_loudnorm_lufs=self.target_loudnorm_lufs,
|
| 140 |
+
),
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# Apply LimitAug using Custom Limiter
|
| 144 |
+
elif self.limitaug_method == "custom_limiter_limitaug":
|
| 145 |
+
# Change attack time of First compressor of the Limiter
|
| 146 |
+
self.board[1].attack_ms = random.uniform(
|
| 147 |
+
self.custom_limiter_attack_range[0], self.custom_limiter_attack_range[1]
|
| 148 |
+
)
|
| 149 |
+
# Change release time of First compressor of the Limiter
|
| 150 |
+
self.board[1].release_ms = random.uniform(
|
| 151 |
+
self.custom_limiter_release_range[0],
|
| 152 |
+
self.custom_limiter_release_range[1],
|
| 153 |
+
)
|
| 154 |
+
# Change release time of Second compressor of the Limiter
|
| 155 |
+
self.board[2].release_ms = random.uniform(30.0, 200.0)
|
| 156 |
+
target_lufs = self.sample_target_lufs()
|
| 157 |
+
mixture_limited, mixture_lufs = apply_limitaug(
|
| 158 |
+
mixture,
|
| 159 |
+
self.board,
|
| 160 |
+
self.meter,
|
| 161 |
+
self.sample_rate,
|
| 162 |
+
target_lufs=target_lufs,
|
| 163 |
+
target_loudnorm_lufs=self.target_loudnorm_lufs,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
# When we want to force NN to output an appropriately compressed target output
|
| 167 |
+
if self.target_limitaug_mode:
|
| 168 |
+
mixture_target_lufs = random.gauss(
|
| 169 |
+
self.limitaug_mode_statistics[self.target_limitaug_mode][0],
|
| 170 |
+
self.limitaug_mode_statistics[self.target_limitaug_mode][1],
|
| 171 |
+
)
|
| 172 |
+
mixture, target_lufs = apply_limitaug(
|
| 173 |
+
mixture,
|
| 174 |
+
self.board,
|
| 175 |
+
self.meter,
|
| 176 |
+
self.sample_rate,
|
| 177 |
+
target_lufs=mixture_target_lufs,
|
| 178 |
+
loudness=mixture_lufs,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
if np.isinf(mixture_lufs):
|
| 182 |
+
mixture_loudnorm = mixture
|
| 183 |
+
else:
|
| 184 |
+
augmented_gain = self.target_loudnorm_lufs - mixture_lufs
|
| 185 |
+
mixture_loudnorm = mixture * db2linear(augmented_gain, eps=0.0)
|
| 186 |
+
|
| 187 |
+
return mixture_limited, mixture_loudnorm
|
| 188 |
+
|
| 189 |
+
def __getitem__(self, index):
|
| 190 |
+
audio_sources = []
|
| 191 |
+
|
| 192 |
+
for k, source in enumerate(self.sources):
|
| 193 |
+
# memorize index of target source
|
| 194 |
+
if source == self.target: # if source is 'vocals'
|
| 195 |
+
track_path = self.train_list[
|
| 196 |
+
index // self.samples_per_track
|
| 197 |
+
] # we want to use # training samples per each track.
|
| 198 |
+
audio_path = f"{track_path}/{source}.wav"
|
| 199 |
+
audio = load_wav_arbitrary_position_stereo(
|
| 200 |
+
audio_path, self.sample_rate, self.seq_duration
|
| 201 |
+
)
|
| 202 |
+
else:
|
| 203 |
+
track_path = random.choice(self.train_list)
|
| 204 |
+
audio_path = f"{track_path}/{source}.wav"
|
| 205 |
+
audio = load_wav_arbitrary_position_stereo(
|
| 206 |
+
audio_path, self.sample_rate, self.seq_duration
|
| 207 |
+
)
|
| 208 |
+
audio = self.source_augmentations(audio)
|
| 209 |
+
audio_sources.append(audio)
|
| 210 |
+
|
| 211 |
+
stems = np.stack(audio_sources, axis=0)
|
| 212 |
+
|
| 213 |
+
# apply linear mix over source index=0
|
| 214 |
+
# and here, linear mixture is a target unlike in MusdbTrainDataset
|
| 215 |
+
mixture = stems.sum(0)
|
| 216 |
+
mixture_limited, mixture_loudnorm = self.get_limitaug_mixture(mixture)
|
| 217 |
+
# We will give mixture_limited as an input and mixture_loudnorm as a target to the model.
|
| 218 |
+
|
| 219 |
+
mixture_limited = np.clip(mixture_limited, -1.0, 1.0)
|
| 220 |
+
mixture_limited = torch.as_tensor(mixture_limited, dtype=torch.float32)
|
| 221 |
+
mixture_loudnorm = torch.as_tensor(mixture_loudnorm, dtype=torch.float32)
|
| 222 |
+
|
| 223 |
+
return mixture_limited, mixture_loudnorm
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class OzoneTrainDataset(DelimitTrainDataset):
|
| 227 |
+
def __init__(
|
| 228 |
+
self,
|
| 229 |
+
target: str = "all",
|
| 230 |
+
root: str = None,
|
| 231 |
+
ozone_root: str = None,
|
| 232 |
+
use_fixed: float = 0.1, # ratio of fixed samples
|
| 233 |
+
seq_duration: Optional[float] = 6.0,
|
| 234 |
+
samples_per_track: int = 64,
|
| 235 |
+
source_augmentations: Optional[Callable] = lambda audio: audio,
|
| 236 |
+
sample_rate: int = 44100,
|
| 237 |
+
seed: int = 42,
|
| 238 |
+
limitaug_method: str = "limitaug",
|
| 239 |
+
limitaug_mode: str = "normal_L",
|
| 240 |
+
limitaug_custom_target_lufs: float = None,
|
| 241 |
+
limitaug_custom_target_lufs_std: float = None,
|
| 242 |
+
target_loudnorm_lufs: float = -14.0,
|
| 243 |
+
target_limitaug_mode: str = None,
|
| 244 |
+
target_limitaug_custom_target_lufs: float = None,
|
| 245 |
+
target_limitaug_custom_target_lufs_std: float = None,
|
| 246 |
+
custom_limiter_attack_range: list = [2.0, 2.0],
|
| 247 |
+
custom_limiter_release_range: list = [200.0, 200.0],
|
| 248 |
+
*args,
|
| 249 |
+
**kwargs,
|
| 250 |
+
) -> None:
|
| 251 |
+
super().__init__(
|
| 252 |
+
target,
|
| 253 |
+
root,
|
| 254 |
+
seq_duration,
|
| 255 |
+
samples_per_track,
|
| 256 |
+
source_augmentations,
|
| 257 |
+
sample_rate,
|
| 258 |
+
seed,
|
| 259 |
+
limitaug_method,
|
| 260 |
+
limitaug_mode,
|
| 261 |
+
limitaug_custom_target_lufs,
|
| 262 |
+
limitaug_custom_target_lufs_std,
|
| 263 |
+
target_loudnorm_lufs,
|
| 264 |
+
target_limitaug_mode,
|
| 265 |
+
target_limitaug_custom_target_lufs,
|
| 266 |
+
target_limitaug_custom_target_lufs_std,
|
| 267 |
+
custom_limiter_attack_range,
|
| 268 |
+
custom_limiter_release_range,
|
| 269 |
+
*args,
|
| 270 |
+
**kwargs,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
self.ozone_root = ozone_root
|
| 274 |
+
self.use_fixed = use_fixed
|
| 275 |
+
self.list_train_fixed = glob.glob(f"{self.ozone_root}/ozone_train_fixed/*.wav")
|
| 276 |
+
self.list_train_random = glob.glob(
|
| 277 |
+
f"{self.ozone_root}/ozone_train_random/*.wav"
|
| 278 |
+
)
|
| 279 |
+
self.dict_train_random = {}
|
| 280 |
+
|
| 281 |
+
# Load information of pre-generated random training examples
|
| 282 |
+
list_csv_files = glob.glob(f"{self.ozone_root}/ozone_train_random_*.csv")
|
| 283 |
+
list_csv_files.sort()
|
| 284 |
+
for csv_file in list_csv_files:
|
| 285 |
+
with open(csv_file, "r") as f:
|
| 286 |
+
reader = csv.reader(f)
|
| 287 |
+
next(reader)
|
| 288 |
+
for row in reader:
|
| 289 |
+
self.dict_train_random[row[0]] = {
|
| 290 |
+
"max_threshold": float(row[1]),
|
| 291 |
+
"max_character": float(row[2]),
|
| 292 |
+
"vocals": {
|
| 293 |
+
"name": row[3],
|
| 294 |
+
"start_sec": float(row[4]),
|
| 295 |
+
"gain": float(row[5]),
|
| 296 |
+
"channelswap": bool(row[6]),
|
| 297 |
+
},
|
| 298 |
+
"bass": {
|
| 299 |
+
"name": row[7],
|
| 300 |
+
"start_sec": float(row[8]),
|
| 301 |
+
"gain": float(row[9]),
|
| 302 |
+
"channelswap": bool(row[10]),
|
| 303 |
+
},
|
| 304 |
+
"drums": {
|
| 305 |
+
"name": row[11],
|
| 306 |
+
"start_sec": float(row[12]),
|
| 307 |
+
"gain": float(row[13]),
|
| 308 |
+
"channelswap": bool(row[14]),
|
| 309 |
+
},
|
| 310 |
+
"other": {
|
| 311 |
+
"name": row[15],
|
| 312 |
+
"start_sec": float(row[16]),
|
| 313 |
+
"gain": float(row[17]),
|
| 314 |
+
"channelswap": bool(row[18]),
|
| 315 |
+
},
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
def __getitem__(self, idx):
|
| 319 |
+
use_fixed_prob = random.random()
|
| 320 |
+
|
| 321 |
+
if use_fixed_prob <= self.use_fixed:
|
| 322 |
+
# Fixed examples
|
| 323 |
+
audio_path = random.choice(self.list_train_fixed)
|
| 324 |
+
song_name = os.path.basename(audio_path).replace(".wav", "")
|
| 325 |
+
mixture_limited, start_pos_sec = load_wav_arbitrary_position_stereo(
|
| 326 |
+
audio_path, self.sample_rate, self.seq_duration, return_pos=True
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
audio_sources = []
|
| 330 |
+
track_path = f"{self.root}/train/{song_name}"
|
| 331 |
+
for source in self.sources:
|
| 332 |
+
audio_path = f"{track_path}/{source}.wav"
|
| 333 |
+
audio = load_wav_specific_position_stereo(
|
| 334 |
+
audio_path,
|
| 335 |
+
self.sample_rate,
|
| 336 |
+
self.seq_duration,
|
| 337 |
+
start_position=start_pos_sec,
|
| 338 |
+
)
|
| 339 |
+
audio_sources.append(audio)
|
| 340 |
+
|
| 341 |
+
else:
|
| 342 |
+
# Random examples
|
| 343 |
+
# Load mixture_limited (pre-generated)
|
| 344 |
+
audio_path = random.choice(self.list_train_random)
|
| 345 |
+
seg_name = os.path.basename(audio_path).replace(".wav", "")
|
| 346 |
+
mixture_limited, sr = librosa.load(
|
| 347 |
+
audio_path, sr=self.sample_rate, mono=False
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
# Load mixture_unlimited (from the original musdb18, using metadata)
|
| 351 |
+
audio_sources = []
|
| 352 |
+
for source in self.sources:
|
| 353 |
+
dict_seg_info = self.dict_train_random[seg_name]
|
| 354 |
+
dict_seg_source_info = dict_seg_info[source]
|
| 355 |
+
audio_path = (
|
| 356 |
+
f"{self.root}/train/{dict_seg_source_info['name']}/{source}.wav"
|
| 357 |
+
)
|
| 358 |
+
audio = load_wav_specific_position_stereo(
|
| 359 |
+
audio_path,
|
| 360 |
+
self.sample_rate,
|
| 361 |
+
self.seq_duration,
|
| 362 |
+
start_position=dict_seg_source_info["start_sec"],
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# apply augmentations
|
| 366 |
+
audio = audio * dict_seg_source_info["gain"]
|
| 367 |
+
if dict_seg_source_info["channelswap"]:
|
| 368 |
+
audio = np.flip(audio, axis=0)
|
| 369 |
+
|
| 370 |
+
audio_sources.append(audio)
|
| 371 |
+
|
| 372 |
+
stems = np.stack(audio_sources, axis=0)
|
| 373 |
+
mixture = stems.sum(axis=0)
|
| 374 |
+
mixture_lufs = self.meter.integrated_loudness(mixture.T)
|
| 375 |
+
if np.isinf(mixture_lufs):
|
| 376 |
+
mixture_loudnorm = mixture
|
| 377 |
+
else:
|
| 378 |
+
augmented_gain = self.target_loudnorm_lufs - mixture_lufs
|
| 379 |
+
mixture_loudnorm = mixture * db2linear(augmented_gain, eps=0.0)
|
| 380 |
+
|
| 381 |
+
return mixture_limited, mixture_loudnorm
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class DelimitValidDataset(MusdbValidDataset):
|
| 385 |
+
def __init__(
|
| 386 |
+
self,
|
| 387 |
+
target: str = "vocals",
|
| 388 |
+
root: str = None,
|
| 389 |
+
delimit_valid_root: str = None,
|
| 390 |
+
valid_target_lufs: float = -8.05, # From the Table 1 of the "Towards robust music source separation on loud commercial music" paper, the average loudness of commerical music.
|
| 391 |
+
target_loudnorm_lufs: float = -14.0,
|
| 392 |
+
delimit_valid_L_root: str = None, # This will be used when using the target as compressed (normal_L) mixture.
|
| 393 |
+
use_custom_limiter: bool = False,
|
| 394 |
+
custom_limiter_attack_range: list = [0.1, 10.0],
|
| 395 |
+
custom_limiter_release_range: list = [30.0, 200.0],
|
| 396 |
+
*args,
|
| 397 |
+
**kwargs,
|
| 398 |
+
) -> None:
|
| 399 |
+
super().__init__(target=target, root=root, *args, **kwargs)
|
| 400 |
+
self.delimit_valid_root = delimit_valid_root
|
| 401 |
+
if self.delimit_valid_root:
|
| 402 |
+
with open(f"{self.delimit_valid_root}/valid_loudness.json", "r") as f:
|
| 403 |
+
self.dict_valid_loudness = json.load(f)
|
| 404 |
+
self.delimit_valid_L_root = delimit_valid_L_root
|
| 405 |
+
if self.delimit_valid_L_root:
|
| 406 |
+
with open(f"{self.delimit_valid_L_root}/valid_loudness.json", "r") as f:
|
| 407 |
+
self.dict_valid_L_loudness = json.load(f)
|
| 408 |
+
|
| 409 |
+
self.valid_target_lufs = valid_target_lufs
|
| 410 |
+
self.target_loudnorm_lufs = target_loudnorm_lufs
|
| 411 |
+
self.meter = pyln.Meter(self.sample_rate)
|
| 412 |
+
self.use_custom_limiter = use_custom_limiter
|
| 413 |
+
|
| 414 |
+
if self.use_custom_limiter:
|
| 415 |
+
print("using Custom limiter limitaug for validation!!")
|
| 416 |
+
self.custom_limiter_attack_range = custom_limiter_attack_range
|
| 417 |
+
self.custom_limiter_release_range = custom_limiter_release_range
|
| 418 |
+
self.board = Pedalboard(
|
| 419 |
+
[
|
| 420 |
+
Gain(gain_db=0.0),
|
| 421 |
+
Compressor(
|
| 422 |
+
threshold_db=-10.0, ratio=4.0, attack_ms=2.0, release_ms=200.0
|
| 423 |
+
), # attack_ms and release_ms will be changed later.
|
| 424 |
+
Compressor(
|
| 425 |
+
threshold_db=0.0,
|
| 426 |
+
ratio=1000.0,
|
| 427 |
+
attack_ms=0.001,
|
| 428 |
+
release_ms=100.0,
|
| 429 |
+
),
|
| 430 |
+
Gain(gain_db=3.75),
|
| 431 |
+
Clipping(threshold_db=0.0),
|
| 432 |
+
]
|
| 433 |
+
) # This implementation is the same as JUCE Limiter.
|
| 434 |
+
# However, we want the first compressor to have a variable attack and release time.
|
| 435 |
+
# Therefore, we use the Custom Limiter instead of the JUCE Limiter.
|
| 436 |
+
else:
|
| 437 |
+
self.board = Pedalboard(
|
| 438 |
+
[Gain(gain_db=0.0), Limiter(threshold_db=0.0, release_ms=100.0)]
|
| 439 |
+
) # Currently, we are using a limiter with a release time of 100ms.
|
| 440 |
+
|
| 441 |
+
def __getitem__(self, index):
|
| 442 |
+
audio_sources = []
|
| 443 |
+
target_ind = None
|
| 444 |
+
|
| 445 |
+
for k, source in enumerate(self.sources):
|
| 446 |
+
# memorize index of target source
|
| 447 |
+
if source == self.target: # if source is 'vocals'
|
| 448 |
+
target_ind = k
|
| 449 |
+
track_path = self.valid_list[index]
|
| 450 |
+
song_name = os.path.basename(track_path)
|
| 451 |
+
audio_path = f"{track_path}/{source}.wav"
|
| 452 |
+
# audio = utils.load_wav_stereo(audio_path, self.sample_rate)
|
| 453 |
+
audio = librosa.load(audio_path, mono=False, sr=self.sample_rate)[0]
|
| 454 |
+
else:
|
| 455 |
+
track_path = self.valid_list[index]
|
| 456 |
+
song_name = os.path.basename(track_path)
|
| 457 |
+
audio_path = f"{track_path}/{source}.wav"
|
| 458 |
+
# audio = utils.load_wav_stereo(audio_path, self.sample_rate)
|
| 459 |
+
audio = librosa.load(audio_path, mono=False, sr=self.sample_rate)[0]
|
| 460 |
+
|
| 461 |
+
audio = torch.as_tensor(audio, dtype=torch.float32)
|
| 462 |
+
audio_sources.append(audio)
|
| 463 |
+
|
| 464 |
+
stems = np.stack(audio_sources, axis=0)
|
| 465 |
+
|
| 466 |
+
# apply linear mix over source index=0
|
| 467 |
+
# and here, linear mixture is a target unlike in MusdbTrainDataset
|
| 468 |
+
mixture = stems.sum(0)
|
| 469 |
+
if (
|
| 470 |
+
self.delimit_valid_root
|
| 471 |
+
): # If there exists a pre-processed delimit valid dataset
|
| 472 |
+
audio_path = f"{self.delimit_valid_root}/valid/{song_name}.wav"
|
| 473 |
+
mixture_limited = librosa.load(audio_path, mono=False, sr=self.sample_rate)[
|
| 474 |
+
0
|
| 475 |
+
]
|
| 476 |
+
mixture_lufs = self.dict_valid_loudness[song_name]
|
| 477 |
+
|
| 478 |
+
else:
|
| 479 |
+
if self.use_custom_limiter:
|
| 480 |
+
custom_limiter_attack = random.uniform(
|
| 481 |
+
self.custom_limiter_attack_range[0],
|
| 482 |
+
self.custom_limiter_attack_range[1],
|
| 483 |
+
)
|
| 484 |
+
self.board[1].attack_ms = custom_limiter_attack
|
| 485 |
+
|
| 486 |
+
custom_limiter_release = random.uniform(
|
| 487 |
+
self.custom_limiter_release_range[0],
|
| 488 |
+
self.custom_limiter_release_range[1],
|
| 489 |
+
)
|
| 490 |
+
self.board[1].release_ms = custom_limiter_release
|
| 491 |
+
|
| 492 |
+
mixture_limited, mixture_lufs = apply_limitaug(
|
| 493 |
+
mixture,
|
| 494 |
+
self.board,
|
| 495 |
+
self.meter,
|
| 496 |
+
self.sample_rate,
|
| 497 |
+
target_lufs=self.valid_target_lufs,
|
| 498 |
+
)
|
| 499 |
+
else:
|
| 500 |
+
mixture_limited, mixture_lufs = apply_limitaug(
|
| 501 |
+
mixture,
|
| 502 |
+
self.board,
|
| 503 |
+
self.meter,
|
| 504 |
+
self.sample_rate,
|
| 505 |
+
target_lufs=self.valid_target_lufs,
|
| 506 |
+
# target_loudnorm_lufs=self.target_loudnorm_lufs,
|
| 507 |
+
) # mixture_limited is a limiter applied mixture
|
| 508 |
+
# We will give mixture_limited as an input and mixture_loudnorm as a target to the model.
|
| 509 |
+
|
| 510 |
+
if self.delimit_valid_L_root:
|
| 511 |
+
audio_L_path = f"{self.delimit_valid_L_root}/valid/{song_name}.wav"
|
| 512 |
+
mixture_loudnorm = librosa.load(
|
| 513 |
+
audio_L_path, mono=False, sr=self.sample_rate
|
| 514 |
+
)[0]
|
| 515 |
+
mixture_lufs = self.dict_valid_L_loudness[song_name]
|
| 516 |
+
mixture = mixture_loudnorm
|
| 517 |
+
|
| 518 |
+
augmented_gain = self.target_loudnorm_lufs - mixture_lufs
|
| 519 |
+
mixture_loudnorm = mixture * db2linear(augmented_gain)
|
| 520 |
+
|
| 521 |
+
if self.use_custom_limiter:
|
| 522 |
+
return (
|
| 523 |
+
mixture_limited,
|
| 524 |
+
mixture_loudnorm,
|
| 525 |
+
song_name,
|
| 526 |
+
mixture_lufs,
|
| 527 |
+
custom_limiter_attack,
|
| 528 |
+
custom_limiter_release,
|
| 529 |
+
)
|
| 530 |
+
else:
|
| 531 |
+
return mixture_limited, mixture_loudnorm, song_name, mixture_lufs
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
class OzoneValidDataset(MusdbValidDataset):
|
| 535 |
+
def __init__(
|
| 536 |
+
self,
|
| 537 |
+
target: str = "all",
|
| 538 |
+
root: str = None,
|
| 539 |
+
ozone_root: str = None,
|
| 540 |
+
target_loudnorm_lufs: float = -14.0,
|
| 541 |
+
*args,
|
| 542 |
+
**kwargs,
|
| 543 |
+
) -> None:
|
| 544 |
+
super().__init__(target=target, root=root, *args, **kwargs)
|
| 545 |
+
|
| 546 |
+
self.ozone_root = ozone_root
|
| 547 |
+
self.target_loudnorm_lufs = target_loudnorm_lufs
|
| 548 |
+
|
| 549 |
+
with open(f"{self.ozone_root}/valid_loudness.json", "r") as f:
|
| 550 |
+
self.dict_valid_loudness = json.load(f)
|
| 551 |
+
|
| 552 |
+
def __getitem__(self, index):
|
| 553 |
+
audio_sources = []
|
| 554 |
+
|
| 555 |
+
track_path = self.valid_list[index]
|
| 556 |
+
song_name = os.path.basename(track_path)
|
| 557 |
+
for k, source in enumerate(self.sources):
|
| 558 |
+
audio_path = f"{track_path}/{source}.wav"
|
| 559 |
+
audio = librosa.load(audio_path, mono=False, sr=self.sample_rate)[0]
|
| 560 |
+
audio_sources.append(audio)
|
| 561 |
+
|
| 562 |
+
stems = np.stack(audio_sources, axis=0)
|
| 563 |
+
|
| 564 |
+
mixture = stems.sum(0)
|
| 565 |
+
|
| 566 |
+
audio_path = f"{self.ozone_root}/ozone_train_fixed/{song_name}.wav"
|
| 567 |
+
mixture_limited = librosa.load(audio_path, mono=False, sr=self.sample_rate)[0]
|
| 568 |
+
|
| 569 |
+
mixture_lufs = self.dict_valid_loudness[song_name]
|
| 570 |
+
augmented_gain = self.target_loudnorm_lufs - mixture_lufs
|
| 571 |
+
mixture_loudnorm = mixture * db2linear(augmented_gain)
|
| 572 |
+
|
| 573 |
+
return mixture_limited, mixture_loudnorm, song_name, mixture_lufs
|
dataloader/singleset.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
# Modified version from woosungchoi's original implementation
|
| 8 |
+
class SingleTrackSet(Dataset):
|
| 9 |
+
def __init__(self, track, hop_length, num_frame=128, target_name="vocals"):
|
| 10 |
+
|
| 11 |
+
assert len(track.shape) == 2
|
| 12 |
+
assert track.shape[0] == 2 # check stereo audio
|
| 13 |
+
|
| 14 |
+
self.hop_length = hop_length
|
| 15 |
+
self.window_length = hop_length * (num_frame - 1) # 130048
|
| 16 |
+
self.trim_length = self.get_trim_length(self.hop_length) # 5120
|
| 17 |
+
|
| 18 |
+
self.true_samples = self.window_length - 2 * self.trim_length # 119808
|
| 19 |
+
|
| 20 |
+
self.lengths = [track.shape[1]] # track lengths (in sample level)
|
| 21 |
+
self.source_names = [
|
| 22 |
+
"vocals",
|
| 23 |
+
"drums",
|
| 24 |
+
"bass",
|
| 25 |
+
"other",
|
| 26 |
+
] # == self.musdb_train.targets_names[:-2]
|
| 27 |
+
|
| 28 |
+
self.target_names = [target_name]
|
| 29 |
+
|
| 30 |
+
self.num_tracks = 1
|
| 31 |
+
|
| 32 |
+
import math
|
| 33 |
+
|
| 34 |
+
num_chunks = [
|
| 35 |
+
math.ceil(length / self.true_samples) for length in self.lengths
|
| 36 |
+
] # example : 44.1khz 180sec audio, => [67]
|
| 37 |
+
self.acc_chunk_final_ids = [
|
| 38 |
+
sum(num_chunks[: i + 1]) for i in range(self.num_tracks)
|
| 39 |
+
] # [67]
|
| 40 |
+
|
| 41 |
+
self.cache_mode = True
|
| 42 |
+
self.cache = {}
|
| 43 |
+
self.cache[0] = {}
|
| 44 |
+
self.cache[0]["linear_mixture"] = track
|
| 45 |
+
|
| 46 |
+
def __len__(self):
|
| 47 |
+
return self.acc_chunk_final_ids[-1] * len(self.target_names) # 67
|
| 48 |
+
|
| 49 |
+
def __getitem__(self, idx):
|
| 50 |
+
|
| 51 |
+
target_offset = idx % len(self.target_names) # 0
|
| 52 |
+
idx = idx // len(self.target_names) # idx
|
| 53 |
+
|
| 54 |
+
target_name = self.target_names[target_offset] # 'vocals'
|
| 55 |
+
mixture_idx, start_pos = self.idx_to_track_offset(
|
| 56 |
+
idx
|
| 57 |
+
) # idx * self.true_samples
|
| 58 |
+
|
| 59 |
+
length = self.true_samples
|
| 60 |
+
left_padding_num = right_padding_num = self.trim_length # 5120
|
| 61 |
+
if mixture_idx is None:
|
| 62 |
+
raise StopIteration
|
| 63 |
+
mixture_length = self.lengths[mixture_idx]
|
| 64 |
+
if start_pos + length > mixture_length: # last
|
| 65 |
+
right_padding_num += self.true_samples - (mixture_length - start_pos)
|
| 66 |
+
length = None
|
| 67 |
+
|
| 68 |
+
mixture = self.get_audio(mixture_idx, "linear_mixture", start_pos, length)
|
| 69 |
+
mixture = F.pad(mixture, (left_padding_num, right_padding_num), "constant", 0)
|
| 70 |
+
|
| 71 |
+
return mixture
|
| 72 |
+
|
| 73 |
+
def idx_to_track_offset(self, idx):
|
| 74 |
+
|
| 75 |
+
for i, last_chunk in enumerate(self.acc_chunk_final_ids):
|
| 76 |
+
if idx < last_chunk:
|
| 77 |
+
if i != 0:
|
| 78 |
+
offset = (idx - self.acc_chunk_final_ids[i - 1]) * self.true_samples
|
| 79 |
+
else:
|
| 80 |
+
offset = idx * self.true_samples
|
| 81 |
+
return i, offset
|
| 82 |
+
|
| 83 |
+
return None, None
|
| 84 |
+
|
| 85 |
+
def get_audio(self, idx, target_name, pos=0, length=None):
|
| 86 |
+
track = self.cache[idx][target_name]
|
| 87 |
+
return track[:, pos : pos + length] if length is not None else track[:, pos:]
|
| 88 |
+
|
| 89 |
+
def get_trim_length(self, hop_length, min_trim=5000):
|
| 90 |
+
trim_per_hop = math.ceil(min_trim / hop_length)
|
| 91 |
+
|
| 92 |
+
trim_length = trim_per_hop * hop_length
|
| 93 |
+
assert trim_per_hop > 1
|
| 94 |
+
return trim_length
|
| 95 |
+
|
eval_delimit/calc_flops.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from deepspeed.profiling.flops_profiler import get_model_profile
|
| 7 |
+
|
| 8 |
+
from utils import get_config
|
| 9 |
+
from models import load_model_with_args
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# def main():
|
| 13 |
+
parser = argparse.ArgumentParser(description="FLOPs calculation")
|
| 14 |
+
|
| 15 |
+
parser.add_argument(
|
| 16 |
+
"-c", "--config", default="delimit_6_s", type=str, help="Name of the setting file."
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
config_args = parser.parse_args()
|
| 20 |
+
|
| 21 |
+
args = get_config(config_args.config)
|
| 22 |
+
print(args)
|
| 23 |
+
|
| 24 |
+
with torch.cuda.device(0):
|
| 25 |
+
model = load_model_with_args(args)
|
| 26 |
+
batch_size = 1
|
| 27 |
+
flops, macs, params = get_model_profile(
|
| 28 |
+
model=model, # model
|
| 29 |
+
input_shape=(batch_size, 2, 44100 * 60), # input shape to the model. If specified, the model takes a tensor with this shape as the only positional argument.
|
| 30 |
+
args=[], # list of positional arguments to the model.
|
| 31 |
+
kwargs={}, # dictionary of keyword arguments to the model.
|
| 32 |
+
print_profile=True, # prints the model graph with the measured profile attached to each module
|
| 33 |
+
detailed=True, # print the detailed profile
|
| 34 |
+
module_depth=-1, # depth into the nested modules, with -1 being the inner most modules
|
| 35 |
+
top_modules=1, # the number of top modules to print aggregated profile
|
| 36 |
+
warm_up=1, # the number of warm-ups before measuring the time of each module
|
| 37 |
+
as_string=True, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
|
| 38 |
+
output_file=None, # path to the output file. If None, the profiler prints to stdout.
|
| 39 |
+
ignore_modules=None,
|
| 40 |
+
) # the list of modules to ignore in the profiling
|
| 41 |
+
print(args.dir_params.exp_name)
|
| 42 |
+
print('flops: ', flops)
|
| 43 |
+
print('macs: ', macs)
|
| 44 |
+
print('params: ', params)
|
eval_delimit/score_calc_delimit.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Calculate SI-SDR, Multi-resolution spectrogram mse score of the pre-inferenced sources
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import csv
|
| 5 |
+
import json
|
| 6 |
+
import glob
|
| 7 |
+
|
| 8 |
+
import tqdm
|
| 9 |
+
import numpy as np
|
| 10 |
+
import librosa
|
| 11 |
+
import pyloudnorm as pyln
|
| 12 |
+
from asteroid.metrics import get_metrics
|
| 13 |
+
|
| 14 |
+
from utils import str2bool
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def multi_resolution_spectrogram_mse(
|
| 18 |
+
gt, est, n_fft=[2048, 1024, 512], n_hop=[512, 256, 128]
|
| 19 |
+
):
|
| 20 |
+
assert gt.shape == est.shape
|
| 21 |
+
assert len(n_fft) == len(n_hop)
|
| 22 |
+
|
| 23 |
+
score = 0.0
|
| 24 |
+
for i in range(len(n_fft)):
|
| 25 |
+
gt_spec = librosa.magphase(
|
| 26 |
+
librosa.stft(gt, n_fft=n_fft[i], hop_length=n_hop[i])
|
| 27 |
+
)[0]
|
| 28 |
+
est_spec = librosa.magphase(
|
| 29 |
+
librosa.stft(est, n_fft=n_fft[i], hop_length=n_hop[i])
|
| 30 |
+
)[0]
|
| 31 |
+
score = score + np.mean((gt_spec - est_spec) ** 2)
|
| 32 |
+
|
| 33 |
+
return score
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
parser = argparse.ArgumentParser(description="model test.py")
|
| 37 |
+
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--target",
|
| 40 |
+
type=str,
|
| 41 |
+
default="all",
|
| 42 |
+
help="target source. all, vocals, drums, bass, other, 0.5_mixed",
|
| 43 |
+
)
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
"--root", type=str, default="/path/to/musdb18hq_loudnorm"
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument("--exp_name", type=str, default="convtasnet_6_s")
|
| 48 |
+
parser.add_argument(
|
| 49 |
+
"--output_directory",
|
| 50 |
+
type=str,
|
| 51 |
+
default="/path/to/results",
|
| 52 |
+
)
|
| 53 |
+
parser.add_argument("--loudnorm_lufs", type=float, default=-14.0)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--calc_mse",
|
| 56 |
+
type=str2bool,
|
| 57 |
+
default=True,
|
| 58 |
+
help="calculate multi-resolution spectrogram mse",
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"--calc_results",
|
| 63 |
+
type=str2bool,
|
| 64 |
+
default=True,
|
| 65 |
+
help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
args, _ = parser.parse_known_args()
|
| 69 |
+
|
| 70 |
+
args.sample_rate = 44100
|
| 71 |
+
|
| 72 |
+
meter = pyln.Meter(args.sample_rate)
|
| 73 |
+
|
| 74 |
+
if args.calc_results:
|
| 75 |
+
args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
|
| 76 |
+
else:
|
| 77 |
+
args.test_output_dir = f"{args.output_directory}/{args.exp_name}"
|
| 78 |
+
|
| 79 |
+
if args.target == "all" or args.target == "0.5_mixed":
|
| 80 |
+
test_tracks = glob.glob(f"{args.root}/*/mixture.wav")
|
| 81 |
+
else:
|
| 82 |
+
test_tracks = glob.glob(f"{args.root}/*/{args.target}.wav")
|
| 83 |
+
i = 0
|
| 84 |
+
|
| 85 |
+
dict_song_score = {}
|
| 86 |
+
list_si_sdr = []
|
| 87 |
+
list_multi_mse = []
|
| 88 |
+
for track in tqdm.tqdm(test_tracks):
|
| 89 |
+
if args.target == "all": # for standard de-limiter estimation
|
| 90 |
+
audio_name = os.path.basename(os.path.dirname(track))
|
| 91 |
+
gt_source = librosa.load(track, sr=args.sample_rate, mono=False)[0]
|
| 92 |
+
|
| 93 |
+
est_delimiter = librosa.load(
|
| 94 |
+
f"{args.test_output_dir}/{audio_name}/all.wav",
|
| 95 |
+
sr=args.sample_rate,
|
| 96 |
+
mono=False,
|
| 97 |
+
)[0]
|
| 98 |
+
|
| 99 |
+
else: # for source-separated de-limiter estimation
|
| 100 |
+
audio_name = os.path.basename(os.path.dirname(track))
|
| 101 |
+
gt_source = librosa.load(track, sr=args.sample_rate, mono=False)[0]
|
| 102 |
+
est_delimiter = librosa.load(
|
| 103 |
+
f"{args.test_output_dir}/{audio_name}/{args.target}.wav",
|
| 104 |
+
sr=args.sample_rate,
|
| 105 |
+
mono=False,
|
| 106 |
+
)[0]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
metrics_dict = get_metrics(
|
| 110 |
+
gt_source + est_delimiter,
|
| 111 |
+
gt_source,
|
| 112 |
+
est_delimiter,
|
| 113 |
+
sample_rate=args.sample_rate,
|
| 114 |
+
metrics_list=["si_sdr"],
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if args.calc_mse:
|
| 118 |
+
multi_resolution_spectrogram_mse_score = multi_resolution_spectrogram_mse(
|
| 119 |
+
gt_source, est_delimiter
|
| 120 |
+
)
|
| 121 |
+
else:
|
| 122 |
+
multi_resolution_spectrogram_mse_score = None
|
| 123 |
+
|
| 124 |
+
dict_song_score[audio_name] = {
|
| 125 |
+
"si_sdr": metrics_dict["si_sdr"],
|
| 126 |
+
"multi_mse": multi_resolution_spectrogram_mse_score,
|
| 127 |
+
}
|
| 128 |
+
list_si_sdr.append(metrics_dict["si_sdr"])
|
| 129 |
+
list_multi_mse.append(multi_resolution_spectrogram_mse_score)
|
| 130 |
+
|
| 131 |
+
i += 1
|
| 132 |
+
|
| 133 |
+
print(f"{args.exp_name} on {args.target}")
|
| 134 |
+
print(f"SI-SDR score: {sum(list_si_sdr) / len(list_si_sdr)}")
|
| 135 |
+
if args.calc_mse:
|
| 136 |
+
print(f"multi-mse score: {sum(list_multi_mse) / len(list_multi_mse)}")
|
| 137 |
+
|
| 138 |
+
if args.target != "all":
|
| 139 |
+
# save dict_song_score to json file
|
| 140 |
+
with open(f"{args.test_output_dir}/score_{args.target}.json", "w") as f:
|
| 141 |
+
json.dump(dict_song_score, f, indent=4)
|
| 142 |
+
else:
|
| 143 |
+
# save dict_song_score to json file
|
| 144 |
+
with open(f"{args.test_output_dir}/score.json", "w") as f:
|
| 145 |
+
json.dump(dict_song_score, f, indent=4)
|
eval_delimit/score_diff_dyn_complexity.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import csv
|
| 4 |
+
import json
|
| 5 |
+
import glob
|
| 6 |
+
|
| 7 |
+
import tqdm
|
| 8 |
+
import numpy as np
|
| 9 |
+
import librosa
|
| 10 |
+
import musdb
|
| 11 |
+
import pyloudnorm as pyln
|
| 12 |
+
|
| 13 |
+
from utils import str2bool, db2linear
|
| 14 |
+
|
| 15 |
+
parser = argparse.ArgumentParser(description="model test.py")
|
| 16 |
+
|
| 17 |
+
parser.add_argument(
|
| 18 |
+
"--target",
|
| 19 |
+
type=str,
|
| 20 |
+
default="all",
|
| 21 |
+
help="target source. all, vocals, bass, drums, other.",
|
| 22 |
+
)
|
| 23 |
+
parser.add_argument(
|
| 24 |
+
"--root",
|
| 25 |
+
type=str,
|
| 26 |
+
default="/path/to/musdb18hq_loudnorm",
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--output_directory",
|
| 30 |
+
type=str,
|
| 31 |
+
default="/path/to/results",
|
| 32 |
+
)
|
| 33 |
+
parser.add_argument("--exp_name", type=str, default="convtasnet_6_s")
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--calc_results",
|
| 36 |
+
type=str2bool,
|
| 37 |
+
default=True,
|
| 38 |
+
help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
args, _ = parser.parse_known_args()
|
| 42 |
+
|
| 43 |
+
args.sample_rate = 44100
|
| 44 |
+
meter = pyln.Meter(args.sample_rate)
|
| 45 |
+
|
| 46 |
+
if args.calc_results:
|
| 47 |
+
args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
|
| 48 |
+
else:
|
| 49 |
+
args.test_output_dir = f"{args.output_directory}/{args.exp_name}"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
est_track_list = glob.glob(f"{args.test_output_dir}/*/{args.target}.wav")
|
| 53 |
+
f = open(
|
| 54 |
+
f"{args.test_output_dir}/score_feature_{args.target}.json",
|
| 55 |
+
encoding="UTF-8",
|
| 56 |
+
)
|
| 57 |
+
dict_song_score_est = json.loads(f.read())
|
| 58 |
+
|
| 59 |
+
if args.target == "all":
|
| 60 |
+
ref_track_list = glob.glob(f"{args.root}/*/mixture.wav")
|
| 61 |
+
f = open(f"{args.root}/score_feature.json", encoding="UTF-8")
|
| 62 |
+
dict_song_score_ref = json.loads(f.read())
|
| 63 |
+
else:
|
| 64 |
+
ref_track_list = glob.glob(f"{args.root}/*/{args.target}.wav")
|
| 65 |
+
f = open(f"{args.root}/score_feature_{args.target}.json", encoding="UTF-8")
|
| 66 |
+
dict_song_score_ref = json.loads(f.read())
|
| 67 |
+
|
| 68 |
+
i = 0
|
| 69 |
+
|
| 70 |
+
dict_song_score = {}
|
| 71 |
+
list_diff_dynamic_complexity = []
|
| 72 |
+
|
| 73 |
+
for track in tqdm.tqdm(ref_track_list):
|
| 74 |
+
audio_name = os.path.basename(os.path.dirname(track))
|
| 75 |
+
ref_dyn_complexity = dict_song_score_ref[audio_name]["dynamic_complexity_score"]
|
| 76 |
+
est_dyn_complexity = dict_song_score_est[audio_name]["dynamic_complexity_score"]
|
| 77 |
+
|
| 78 |
+
list_diff_dynamic_complexity.append(est_dyn_complexity - ref_dyn_complexity)
|
| 79 |
+
|
| 80 |
+
i += 1
|
| 81 |
+
|
| 82 |
+
print(
|
| 83 |
+
f"Dynamic complexity difference {args.exp_name} vs {os.path.basename(args.root)} on {args.target}"
|
| 84 |
+
)
|
| 85 |
+
print("mean: ", np.mean(list_diff_dynamic_complexity))
|
| 86 |
+
print("median: ", np.median(list_diff_dynamic_complexity))
|
| 87 |
+
print("std: ", np.std(list_diff_dynamic_complexity))
|
eval_delimit/score_fad.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# We are going to use FAD based on https://github.com/gudgud96/frechet-audio-distance
|
| 2 |
+
import os
|
| 3 |
+
import subprocess
|
| 4 |
+
import glob
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
from frechet_audio_distance import FrechetAudioDistance
|
| 8 |
+
|
| 9 |
+
from utils import str2bool
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
parser = argparse.ArgumentParser(description="model test.py")
|
| 13 |
+
|
| 14 |
+
parser.add_argument(
|
| 15 |
+
"--target",
|
| 16 |
+
type=str,
|
| 17 |
+
default="all",
|
| 18 |
+
help="target source. all, vocals, drums, bass, other",
|
| 19 |
+
)
|
| 20 |
+
parser.add_argument(
|
| 21 |
+
"--root",
|
| 22 |
+
type=str,
|
| 23 |
+
default="/path/to/musdb18hq_loudnorm",
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--output_directory",
|
| 27 |
+
type=str,
|
| 28 |
+
default="/path/to/results",
|
| 29 |
+
)
|
| 30 |
+
parser.add_argument("--exp_name", type=str, default="delimit_6_s")
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--calc_results",
|
| 33 |
+
type=str2bool,
|
| 34 |
+
default=True,
|
| 35 |
+
help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
args, _ = parser.parse_known_args()
|
| 39 |
+
|
| 40 |
+
os.makedirs(f"{args.root}/musdb_hq_loudnorm_16k_mono_link", exist_ok=True)
|
| 41 |
+
|
| 42 |
+
song_list = glob.glob(f"{args.root}/musdb_hq_loudnorm_16k_mono/*/mixture.wav")
|
| 43 |
+
for song in song_list:
|
| 44 |
+
song_name = os.path.basename(os.path.dirname(song))
|
| 45 |
+
subprocess.run(
|
| 46 |
+
f'ln --symbolic "{song}" "{args.root}/musdb_hq_loudnorm_16k_mono_link/{song_name}.wav"',
|
| 47 |
+
shell=True,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
if args.calc_results:
|
| 52 |
+
args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
|
| 53 |
+
else:
|
| 54 |
+
args.test_output_dir = f"{args.output_directory}/{args.exp_name}"
|
| 55 |
+
|
| 56 |
+
os.makedirs(f"{args.test_output_dir}_16k_mono_link", exist_ok=True)
|
| 57 |
+
|
| 58 |
+
song_list = glob.glob(f"{args.test_output_dir}_16k_mono/*/{args.target}.wav")
|
| 59 |
+
for song in song_list:
|
| 60 |
+
song_name = os.path.basename(os.path.dirname(song))
|
| 61 |
+
subprocess.run(
|
| 62 |
+
f'ln --symbolic "{song}" "{args.test_output_dir}_16k_mono_link/{song_name}.wav"',
|
| 63 |
+
shell=True,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
frechet = FrechetAudioDistance()
|
| 68 |
+
|
| 69 |
+
fad_score = frechet.score(
|
| 70 |
+
f"{args.root}/musdb_hq_loudnorm_16k_mono_link",
|
| 71 |
+
f"{args.test_output_dir}_16k_mono_link",
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
print(f"{args.exp_name}")
|
| 75 |
+
print(f"FAD score: {fad_score}")
|
eval_delimit/score_features.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import csv
|
| 4 |
+
import json
|
| 5 |
+
import glob
|
| 6 |
+
from typing import Any, Optional, Union, Collection
|
| 7 |
+
|
| 8 |
+
import tqdm
|
| 9 |
+
import numpy as np
|
| 10 |
+
import librosa
|
| 11 |
+
from librosa.core.spectrum import _spectrogram
|
| 12 |
+
import musdb
|
| 13 |
+
import essentia
|
| 14 |
+
import essentia.standard
|
| 15 |
+
import pyloudnorm as pyln
|
| 16 |
+
|
| 17 |
+
from utils import str2bool, db2linear
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def spectral_crest(
|
| 21 |
+
*,
|
| 22 |
+
y: Optional[np.ndarray] = None,
|
| 23 |
+
S: Optional[np.ndarray] = None,
|
| 24 |
+
n_fft: int = 2048,
|
| 25 |
+
hop_length: int = 512,
|
| 26 |
+
win_length: Optional[int] = None,
|
| 27 |
+
window: str = "hann",
|
| 28 |
+
center: bool = True,
|
| 29 |
+
pad_mode: str = "constant",
|
| 30 |
+
amin: float = 1e-10,
|
| 31 |
+
power: float = 2.0,
|
| 32 |
+
) -> np.ndarray:
|
| 33 |
+
"""Compute spectral crest
|
| 34 |
+
|
| 35 |
+
Spectral crest (or tonality coefficient) is a measure of
|
| 36 |
+
the ratio of the maximum of the spectrum to the arithmetic mean of the spectrum
|
| 37 |
+
|
| 38 |
+
A higher spectral crest => more tonality,
|
| 39 |
+
A lower spectral crest => more noisy.
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
Parameters
|
| 43 |
+
----------
|
| 44 |
+
y : np.ndarray [shape=(..., n)] or None
|
| 45 |
+
audio time series. Multi-channel is supported.
|
| 46 |
+
S : np.ndarray [shape=(..., d, t)] or None
|
| 47 |
+
(optional) pre-computed spectrogram magnitude
|
| 48 |
+
n_fft : int > 0 [scalar]
|
| 49 |
+
FFT window size
|
| 50 |
+
hop_length : int > 0 [scalar]
|
| 51 |
+
hop length for STFT. See `librosa.stft` for details.
|
| 52 |
+
win_length : int <= n_fft [scalar]
|
| 53 |
+
Each frame of audio is windowed by `window()`.
|
| 54 |
+
The window will be of length `win_length` and then padded
|
| 55 |
+
with zeros to match ``n_fft``.
|
| 56 |
+
If unspecified, defaults to ``win_length = n_fft``.
|
| 57 |
+
window : string, tuple, number, function, or np.ndarray [shape=(n_fft,)]
|
| 58 |
+
- a window specification (string, tuple, or number);
|
| 59 |
+
see `scipy.signal.get_window`
|
| 60 |
+
- a window function, such as `scipy.signal.windows.hann`
|
| 61 |
+
- a vector or array of length ``n_fft``
|
| 62 |
+
.. see also:: `librosa.filters.get_window`
|
| 63 |
+
center : boolean
|
| 64 |
+
- If `True`, the signal ``y`` is padded so that frame
|
| 65 |
+
``t`` is centered at ``y[t * hop_length]``.
|
| 66 |
+
- If `False`, then frame `t` begins at ``y[t * hop_length]``
|
| 67 |
+
pad_mode : string
|
| 68 |
+
If ``center=True``, the padding mode to use at the edges of the signal.
|
| 69 |
+
By default, STFT uses zero padding.
|
| 70 |
+
amin : float > 0 [scalar]
|
| 71 |
+
minimum threshold for ``S`` (=added noise floor for numerical stability)
|
| 72 |
+
power : float > 0 [scalar]
|
| 73 |
+
Exponent for the magnitude spectrogram.
|
| 74 |
+
e.g., 1 for energy, 2 for power, etc.
|
| 75 |
+
Power spectrogram is usually used for computing spectral flatness.
|
| 76 |
+
|
| 77 |
+
Returns
|
| 78 |
+
-------
|
| 79 |
+
crest : np.ndarray [shape=(..., 1, t)]
|
| 80 |
+
spectral crest for each frame.
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
S, n_fft = _spectrogram(
|
| 86 |
+
y=y,
|
| 87 |
+
S=S,
|
| 88 |
+
n_fft=n_fft,
|
| 89 |
+
hop_length=hop_length,
|
| 90 |
+
power=1.0,
|
| 91 |
+
win_length=win_length,
|
| 92 |
+
window=window,
|
| 93 |
+
center=center,
|
| 94 |
+
pad_mode=pad_mode,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
S_thresh = np.maximum(amin, S**power)
|
| 98 |
+
# gmean = np.exp(np.mean(np.log(S_thresh), axis=-2, keepdims=True))
|
| 99 |
+
gmax = np.max(S_thresh, axis=-2, keepdims=True)
|
| 100 |
+
amean = np.mean(S_thresh, axis=-2, keepdims=True)
|
| 101 |
+
crest: np.ndarray = gmax / amean
|
| 102 |
+
return crest
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
parser = argparse.ArgumentParser(description="model test.py")
|
| 106 |
+
|
| 107 |
+
parser.add_argument(
|
| 108 |
+
"--target",
|
| 109 |
+
type=str,
|
| 110 |
+
default="all",
|
| 111 |
+
help="target source. all, vocals, drums, bass, other",
|
| 112 |
+
)
|
| 113 |
+
parser.add_argument(
|
| 114 |
+
"--root", type=str, default="/path/to/musdb18hq_loudnorm"
|
| 115 |
+
)
|
| 116 |
+
parser.add_argument("--exp_name", type=str, default="delimit_6_s")
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
"--output_directory",
|
| 119 |
+
type=str,
|
| 120 |
+
default="/path/to/results",
|
| 121 |
+
)
|
| 122 |
+
parser.add_argument(
|
| 123 |
+
"--calc_results",
|
| 124 |
+
type=str2bool,
|
| 125 |
+
default=True,
|
| 126 |
+
help="calculate results or musdb-hq or musdb-XL test dataset",
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
args, _ = parser.parse_known_args()
|
| 131 |
+
|
| 132 |
+
args.sample_rate = 44100
|
| 133 |
+
|
| 134 |
+
args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
|
| 135 |
+
|
| 136 |
+
if args.calc_results:
|
| 137 |
+
track_list = glob.glob(
|
| 138 |
+
f"{args.output_directory}/test/{args.exp_name}/*/{args.target}.wav"
|
| 139 |
+
)
|
| 140 |
+
else:
|
| 141 |
+
if args.target == "all":
|
| 142 |
+
track_list = glob.glob(f"{args.root}/*/mixture.wav")
|
| 143 |
+
else:
|
| 144 |
+
track_list = glob.glob(f"{args.root}/*/{args.target}.wav")
|
| 145 |
+
|
| 146 |
+
i = 0
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
dynamic_complexity = essentia.standard.DynamicComplexity()
|
| 150 |
+
loudness_range = essentia.standard.LoudnessEBUR128()
|
| 151 |
+
spectral_centroid = essentia.standard.SpectralCentroidTime()
|
| 152 |
+
crest = essentia.standard.Crest()
|
| 153 |
+
dynamic_spread = essentia.standard.DistributionShape()
|
| 154 |
+
central_moments = essentia.standard.CentralMoments()
|
| 155 |
+
|
| 156 |
+
dict_song_score = {}
|
| 157 |
+
list_rms = []
|
| 158 |
+
list_crest_factor = []
|
| 159 |
+
list_dc_score = []
|
| 160 |
+
list_lra_score = []
|
| 161 |
+
list_sc_hertz = []
|
| 162 |
+
list_sf_score = []
|
| 163 |
+
list_spectral_crest_score = []
|
| 164 |
+
|
| 165 |
+
for track in tqdm.tqdm(track_list):
|
| 166 |
+
audio_name = os.path.basename(os.path.dirname(track))
|
| 167 |
+
gt_source_librosa = librosa.load(f"{track}", sr=args.sample_rate, mono=False)[
|
| 168 |
+
0
|
| 169 |
+
] # (nb_channels, nb_samples)
|
| 170 |
+
gt_source_librosa_mono = librosa.to_mono(gt_source_librosa) # (nb_samples)
|
| 171 |
+
|
| 172 |
+
gt_source_essentia = essentia.standard.AudioLoader(filename=f"{track}")()[
|
| 173 |
+
0
|
| 174 |
+
] # (nb_samples, nb_channels)
|
| 175 |
+
gt_source_essentia_cat = np.concatenate(
|
| 176 |
+
[gt_source_essentia[:, 0], gt_source_essentia[:, 1]]
|
| 177 |
+
) # (nb_samples * nb_channels)
|
| 178 |
+
gt_source_essentia_mono = np.mean(gt_source_essentia, axis=1) # (nb_samples)
|
| 179 |
+
|
| 180 |
+
rms = np.sqrt(np.mean(gt_source_essentia_cat**2))
|
| 181 |
+
crest_factor = np.max(np.abs(gt_source_essentia_cat)) / rms
|
| 182 |
+
|
| 183 |
+
dc_score, _ = dynamic_complexity(gt_source_essentia_mono)
|
| 184 |
+
_, _, _, lra_score = loudness_range(gt_source_essentia)
|
| 185 |
+
sc_hertz = spectral_centroid(gt_source_essentia_mono)
|
| 186 |
+
sf_score = np.mean(librosa.feature.spectral_flatness(gt_source_librosa_mono))
|
| 187 |
+
spectral_crest_score = np.mean(spectral_crest(y=gt_source_librosa_mono))
|
| 188 |
+
|
| 189 |
+
dict_song_score[audio_name] = {
|
| 190 |
+
"rms": float(rms),
|
| 191 |
+
"crest_factor": float(crest_factor),
|
| 192 |
+
"dynamic_complexity_score": float(dc_score),
|
| 193 |
+
"lra_score": float(lra_score),
|
| 194 |
+
"spectral_centroid_hertz": float(sc_hertz),
|
| 195 |
+
"spectral_flatness_score": float(sf_score),
|
| 196 |
+
"spectral_crest_score": float(spectral_crest_score),
|
| 197 |
+
}
|
| 198 |
+
list_rms.append(rms)
|
| 199 |
+
list_crest_factor.append(crest_factor)
|
| 200 |
+
list_dc_score.append(dc_score)
|
| 201 |
+
list_lra_score.append(lra_score)
|
| 202 |
+
list_sc_hertz.append(sc_hertz)
|
| 203 |
+
list_sf_score.append(sf_score)
|
| 204 |
+
list_spectral_crest_score.append(spectral_crest_score)
|
| 205 |
+
|
| 206 |
+
i += 1
|
| 207 |
+
|
| 208 |
+
if args.calc_results:
|
| 209 |
+
print(f"{args.exp_name} on {args.target}")
|
| 210 |
+
else:
|
| 211 |
+
print(f"{os.path.basename(args.root)} on {args.target}")
|
| 212 |
+
print(f"rms: {np.mean(list_rms)}")
|
| 213 |
+
print(f"crest_factor: {np.mean(list_crest_factor)}")
|
| 214 |
+
print(f"dynamic_complexity_score: {np.mean(list_dc_score)}")
|
| 215 |
+
print(f"lra_score: {np.mean(list_lra_score)}")
|
| 216 |
+
print(f"sc_hertz: {np.mean(list_sc_hertz)}")
|
| 217 |
+
print(f"sf_score: {np.mean(list_sf_score)}")
|
| 218 |
+
print(f"spectral_crest_score: {np.mean(list_spectral_crest_score)}")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# save dict_song_score to json file
|
| 222 |
+
if args.target == "all":
|
| 223 |
+
file_name = "score_features"
|
| 224 |
+
else:
|
| 225 |
+
file_name = f"score_feature_{args.target}"
|
| 226 |
+
if args.calc_results:
|
| 227 |
+
with open(
|
| 228 |
+
f"{args.output_directory}/test/{args.exp_name}/{file_name}.json", "w"
|
| 229 |
+
) as f:
|
| 230 |
+
json.dump(dict_song_score, f, indent=4)
|
| 231 |
+
else:
|
| 232 |
+
with open(f"{args.root}/{file_name}.json", "w") as f:
|
| 233 |
+
json.dump(dict_song_score, f, indent=4)
|
eval_delimit/score_peaq.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# We are going to use PEAQ based on https://github.com/HSU-ANT/gstpeaq
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
python3 score_peaq.py --exp_name=delimit_6_s | tee /path/to/results/delimit_6_s/score_peaq.txt
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import subprocess
|
| 11 |
+
import glob
|
| 12 |
+
import argparse
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def str2bool(v):
|
| 16 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
| 17 |
+
return True
|
| 18 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
| 19 |
+
return False
|
| 20 |
+
else:
|
| 21 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
parser = argparse.ArgumentParser(description="model test.py")
|
| 25 |
+
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--target",
|
| 28 |
+
type=str,
|
| 29 |
+
default="all",
|
| 30 |
+
help="target source. all, vocals, drums, bass, other",
|
| 31 |
+
)
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--root",
|
| 34 |
+
type=str,
|
| 35 |
+
default="/path/to/musdb_XL_loudnorm",
|
| 36 |
+
)
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--output_directory",
|
| 39 |
+
type=str,
|
| 40 |
+
default="/path/to/results/",
|
| 41 |
+
)
|
| 42 |
+
parser.add_argument("--exp_name", type=str, default="delimit_6_s")
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--calc_results",
|
| 45 |
+
type=str2bool,
|
| 46 |
+
default=True,
|
| 47 |
+
help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
args, _ = parser.parse_known_args()
|
| 51 |
+
|
| 52 |
+
if args.calc_results:
|
| 53 |
+
args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
|
| 54 |
+
else:
|
| 55 |
+
args.test_output_dir = f"{args.output_directory}/{args.exp_name}"
|
| 56 |
+
|
| 57 |
+
if args.target == "all":
|
| 58 |
+
song_list = sorted(glob.glob(f"{args.root}/*/mixture.wav"))
|
| 59 |
+
|
| 60 |
+
for song in song_list:
|
| 61 |
+
song_name = os.path.basename(os.path.dirname(song))
|
| 62 |
+
est_path = f"{args.test_output_dir}/{song_name}/{args.target}.wav"
|
| 63 |
+
subprocess.run(
|
| 64 |
+
f'peaq --gst-plugin-load=/usr/local/lib/gstreamer-1.0/libgstpeaq.so "{song}" "{est_path}"',
|
| 65 |
+
shell=True,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
else:
|
| 69 |
+
song_list = sorted(glob.glob(f"{args.root}/*/{args.target}.wav"))
|
| 70 |
+
|
| 71 |
+
for song in song_list:
|
| 72 |
+
song_name = os.path.basename(os.path.dirname(song))
|
| 73 |
+
est_path = f"{args.test_output_dir}/{song_name}/{args.target}.wav"
|
| 74 |
+
subprocess.run(
|
| 75 |
+
f'peaq --gst-plugin-load=/usr/local/lib/gstreamer-1.0/libgstpeaq.so "{song}" "{est_path}"',
|
| 76 |
+
shell=True,
|
| 77 |
+
)
|
eval_delimit/score_peaq_aggregate.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PEAQ aggregate score
|
| 2 |
+
"""
|
| 3 |
+
/path/to/results/delimit_6_s/score_peaq.txt
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import glob
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def str2bool(v):
|
| 13 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
| 14 |
+
return True
|
| 15 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
| 16 |
+
return False
|
| 17 |
+
else:
|
| 18 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
parser = argparse.ArgumentParser(description="model test.py")
|
| 22 |
+
|
| 23 |
+
parser.add_argument(
|
| 24 |
+
"--target",
|
| 25 |
+
type=str,
|
| 26 |
+
default="all",
|
| 27 |
+
help="target source. all, vocals, drums, bass, other",
|
| 28 |
+
)
|
| 29 |
+
parser.add_argument(
|
| 30 |
+
"--root",
|
| 31 |
+
type=str,
|
| 32 |
+
default="/path/to/musdb18hq_loudnorm",
|
| 33 |
+
)
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"--output_directory",
|
| 36 |
+
type=str,
|
| 37 |
+
default="/path/to/results",
|
| 38 |
+
)
|
| 39 |
+
parser.add_argument("--exp_name", type=str, default="delimit_6_s")
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--calc_results",
|
| 42 |
+
type=str2bool,
|
| 43 |
+
default=True,
|
| 44 |
+
help="Set this True when you want to calculate the results of the test set. Set this False when calculating musdb-hq vs musdb-XL. (top row in Table 1.)",
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
args, _ = parser.parse_known_args()
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
if args.calc_results:
|
| 51 |
+
args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
|
| 52 |
+
else:
|
| 53 |
+
args.test_output_dir = f"{args.output_directory}/{args.exp_name}"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
if args.target == "all":
|
| 57 |
+
score_path = f"{args.test_output_dir}/score_peaq.txt"
|
| 58 |
+
else:
|
| 59 |
+
score_path = f"{args.test_output_dir}/score_peaq_{args.target}.txt"
|
| 60 |
+
|
| 61 |
+
# write the code to load score_peaq.txt
|
| 62 |
+
with open(score_path, "r") as f:
|
| 63 |
+
score_txt = f.readlines()
|
| 64 |
+
|
| 65 |
+
song_list = glob.glob(f"{args.root}/*")
|
| 66 |
+
|
| 67 |
+
dict_song_peaq = {}
|
| 68 |
+
list_peaq = []
|
| 69 |
+
for idx, song in enumerate(song_list):
|
| 70 |
+
song_name = os.path.basename(song)
|
| 71 |
+
peaq = float(score_txt[idx * 2].replace("Objective Difference Grade: ", ""))
|
| 72 |
+
dict_song_peaq[song_name] = peaq
|
| 73 |
+
list_peaq.append(peaq)
|
| 74 |
+
|
| 75 |
+
print(f"{args.exp_name} on {args.target}")
|
| 76 |
+
print(f"PEAQ score: {sum(list_peaq) / len(list_peaq)}")
|
| 77 |
+
|
| 78 |
+
if args.target == "all":
|
| 79 |
+
# save dict_song_peaq to json file
|
| 80 |
+
with open(f"{args.test_output_dir}/score_peaq.json", "w") as f:
|
| 81 |
+
json.dump(dict_song_peaq, f, indent=4)
|
| 82 |
+
else:
|
| 83 |
+
# save dict_song_peaq to json file
|
| 84 |
+
with open(
|
| 85 |
+
f"{args.test_output_dir}/score_peaq_{args.target}.json",
|
| 86 |
+
"w",
|
| 87 |
+
) as f:
|
| 88 |
+
json.dump(dict_song_peaq, f, indent=4)
|
inference.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import argparse
|
| 4 |
+
import glob
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import tqdm
|
| 8 |
+
import librosa
|
| 9 |
+
import soundfile as sf
|
| 10 |
+
import pyloudnorm as pyln
|
| 11 |
+
from dotmap import DotMap
|
| 12 |
+
|
| 13 |
+
from models import load_model_with_args
|
| 14 |
+
from separate_func import (
|
| 15 |
+
conv_tasnet_separate,
|
| 16 |
+
)
|
| 17 |
+
from utils import str2bool, db2linear
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
tqdm.monitor_interval = 0
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def separate_track_with_model(
|
| 24 |
+
args, model, device, track_audio, track_name, meter, augmented_gain
|
| 25 |
+
):
|
| 26 |
+
with torch.no_grad():
|
| 27 |
+
if (
|
| 28 |
+
args.model_loss_params.architecture == "conv_tasnet_mask_on_output"
|
| 29 |
+
or args.model_loss_params.architecture == "conv_tasnet"
|
| 30 |
+
):
|
| 31 |
+
estimates = conv_tasnet_separate(
|
| 32 |
+
args,
|
| 33 |
+
model,
|
| 34 |
+
device,
|
| 35 |
+
track_audio,
|
| 36 |
+
track_name,
|
| 37 |
+
meter=meter,
|
| 38 |
+
augmented_gain=augmented_gain,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
return estimates
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def main():
|
| 45 |
+
parser = argparse.ArgumentParser(description="model test.py")
|
| 46 |
+
parser.add_argument("--target", type=str, default="all")
|
| 47 |
+
parser.add_argument("--data_root", type=str, default="./input_data")
|
| 48 |
+
parser.add_argument("--weight_directory", type=str, default="./weight")
|
| 49 |
+
parser.add_argument("--output_directory", type=str, default="./output")
|
| 50 |
+
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
| 51 |
+
parser.add_argument("--save_name_as_target", type=str2bool, default=False)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--loudnorm_input_lufs",
|
| 54 |
+
type=float,
|
| 55 |
+
default=None,
|
| 56 |
+
help="If you want to use loudnorm for input",
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--save_output_loudnorm",
|
| 60 |
+
type=float,
|
| 61 |
+
default=-14.0,
|
| 62 |
+
help="Save loudness normalized outputs or not. If you want to save, input target loudness",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--save_mixed_output",
|
| 66 |
+
type=float,
|
| 67 |
+
default=None,
|
| 68 |
+
help="Save original+delimited-estimation mixed output with a ratio of default 0.5 (orginal) and 1 - 0.5 (estimation)",
|
| 69 |
+
)
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--save_16k_mono",
|
| 72 |
+
type=str2bool,
|
| 73 |
+
default=False,
|
| 74 |
+
help="Save 16k mono wav files for FAD evaluation.",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--save_histogram",
|
| 78 |
+
type=str2bool,
|
| 79 |
+
default=False,
|
| 80 |
+
help="Save histogram of the output. Only valid when the task is 'delimit'",
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--use_singletrackset",
|
| 84 |
+
type=str2bool,
|
| 85 |
+
default=False,
|
| 86 |
+
help="Use SingleTrackSet if input data is too long.",
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
args, _ = parser.parse_known_args()
|
| 90 |
+
|
| 91 |
+
with open(f"{args.weight_directory}/{args.target}.json", "r") as f:
|
| 92 |
+
args_dict = json.load(f)
|
| 93 |
+
args_dict = DotMap(args_dict)
|
| 94 |
+
|
| 95 |
+
for key, value in args_dict["args"].items():
|
| 96 |
+
if key in list(vars(args).keys()):
|
| 97 |
+
pass
|
| 98 |
+
else:
|
| 99 |
+
setattr(args, key, value)
|
| 100 |
+
|
| 101 |
+
args.test_output_dir = f"{args.output_directory}"
|
| 102 |
+
os.makedirs(args.test_output_dir, exist_ok=True)
|
| 103 |
+
|
| 104 |
+
device = torch.device(
|
| 105 |
+
"cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
###################### Define Models ######################
|
| 109 |
+
our_model = load_model_with_args(args)
|
| 110 |
+
our_model = our_model.to(device)
|
| 111 |
+
|
| 112 |
+
target_model_path = f"{args.weight_directory}/{args.target}.pth"
|
| 113 |
+
checkpoint = torch.load(target_model_path, map_location=device)
|
| 114 |
+
our_model.load_state_dict(checkpoint)
|
| 115 |
+
|
| 116 |
+
our_model.eval()
|
| 117 |
+
|
| 118 |
+
meter = pyln.Meter(44100)
|
| 119 |
+
|
| 120 |
+
test_tracks = glob.glob(f"{args.data_root}/*.wav") + glob.glob(
|
| 121 |
+
f"{args.data_root}/*.mp3"
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
for track in tqdm.tqdm(test_tracks):
|
| 125 |
+
track_name = os.path.basename(track).replace(".wav", "").replace(".mp3", "")
|
| 126 |
+
track_audio, sr = librosa.load(track, sr=None, mono=False) # sr should be 44100
|
| 127 |
+
|
| 128 |
+
orig_audio = track_audio.copy()
|
| 129 |
+
|
| 130 |
+
if sr != 44100:
|
| 131 |
+
raise ValueError("Sample rate should be 44100")
|
| 132 |
+
augmented_gain = None
|
| 133 |
+
print("Now De-limiting : ", track_name)
|
| 134 |
+
|
| 135 |
+
if args.loudnorm_input_lufs: # If you want to use loud-normalized input
|
| 136 |
+
track_lufs = meter.integrated_loudness(track_audio.T)
|
| 137 |
+
augmented_gain = args.loudnorm_input_lufs - track_lufs
|
| 138 |
+
track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
|
| 139 |
+
|
| 140 |
+
track_audio = (
|
| 141 |
+
torch.as_tensor(track_audio, dtype=torch.float32).unsqueeze(0).to(device)
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
estimates = separate_track_with_model(
|
| 145 |
+
args, our_model, device, track_audio, track_name, meter, augmented_gain
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
if args.save_mixed_output:
|
| 149 |
+
track_lufs = meter.integrated_loudness(orig_audio.T)
|
| 150 |
+
augmented_gain = args.save_output_loudnorm - track_lufs
|
| 151 |
+
orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)
|
| 152 |
+
|
| 153 |
+
mixed_output = orig_audio * args.save_mixed_output + estimates * (
|
| 154 |
+
1 - args.save_mixed_output
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
sf.write(
|
| 158 |
+
f"{args.test_output_dir}/{track_name}/{track_name}_mixed.wav",
|
| 159 |
+
mixed_output.T,
|
| 160 |
+
args.data_params.sample_rate,
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
if __name__ == "__main__":
|
| 165 |
+
main()
|
main_ddp.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from train_ddp import train
|
| 8 |
+
from utils import get_config
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main():
|
| 12 |
+
parser = argparse.ArgumentParser(description="Trainer")
|
| 13 |
+
|
| 14 |
+
# Put every argumnet in './configs/yymmdd_architecture_number.yaml' and load it.
|
| 15 |
+
parser.add_argument(
|
| 16 |
+
"-c",
|
| 17 |
+
"--config",
|
| 18 |
+
default="delimit_6_s",
|
| 19 |
+
type=str,
|
| 20 |
+
help="Name of the setting file.",
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
config_args = parser.parse_args()
|
| 24 |
+
|
| 25 |
+
args = get_config(config_args.config)
|
| 26 |
+
|
| 27 |
+
args.img_check = (
|
| 28 |
+
f"{args.dir_params.output_directory}/img_check/{args.dir_params.exp_name}"
|
| 29 |
+
)
|
| 30 |
+
args.output = (
|
| 31 |
+
f"{args.dir_params.output_directory}/checkpoint/{args.dir_params.exp_name}"
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Set which devices to use
|
| 35 |
+
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
| 36 |
+
os.environ["MASTER_PORT"] = str(random.randint(0, 1800))
|
| 37 |
+
|
| 38 |
+
os.makedirs(args.img_check, exist_ok=True)
|
| 39 |
+
os.makedirs(args.output, exist_ok=True)
|
| 40 |
+
|
| 41 |
+
torch.manual_seed(args.sys_params.seed)
|
| 42 |
+
random.seed(args.sys_params.seed)
|
| 43 |
+
|
| 44 |
+
print(args)
|
| 45 |
+
train(args)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if __name__ == "__main__":
|
| 49 |
+
main()
|
models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .load_models import load_model_with_args
|
models/base_models.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from asteroid.models.base_models import (
|
| 4 |
+
BaseEncoderMaskerDecoder,
|
| 5 |
+
_unsqueeze_to_3d,
|
| 6 |
+
_shape_reconstructed,
|
| 7 |
+
)
|
| 8 |
+
from asteroid.utils.torch_utils import pad_x_to_y, jitable_shape
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BaseEncoderMaskerDecoderWithConfigs(BaseEncoderMaskerDecoder):
|
| 13 |
+
def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
|
| 14 |
+
super().__init__(encoder, masker, decoder, encoder_activation)
|
| 15 |
+
self.use_encoder = kwargs.get("use_encoder", True)
|
| 16 |
+
self.apply_mask = kwargs.get("apply_mask", True)
|
| 17 |
+
self.use_decoder = kwargs.get("use_decoder", True)
|
| 18 |
+
|
| 19 |
+
def forward(self, wav):
|
| 20 |
+
"""
|
| 21 |
+
Enc/Mask/Dec model forward with some additional options.
|
| 22 |
+
Some of the models we use, like TFC-TDF-UNet, have no masker.
|
| 23 |
+
In UMX or X-UMX, they already use masking in their model implementation.
|
| 24 |
+
Since we do not want to manipulate the model codes, we use this wrapper.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.
|
| 28 |
+
|
| 29 |
+
Returns:
|
| 30 |
+
torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
|
| 31 |
+
"""
|
| 32 |
+
# Remember shape to shape reconstruction, cast to Tensor for torchscript
|
| 33 |
+
shape = jitable_shape(wav)
|
| 34 |
+
# Reshape to (batch, n_mix, time)
|
| 35 |
+
wav = _unsqueeze_to_3d(wav)
|
| 36 |
+
|
| 37 |
+
# Real forward
|
| 38 |
+
if self.use_encoder:
|
| 39 |
+
tf_rep = self.forward_encoder(wav)
|
| 40 |
+
else:
|
| 41 |
+
tf_rep = wav
|
| 42 |
+
|
| 43 |
+
est_masks = self.forward_masker(tf_rep)
|
| 44 |
+
|
| 45 |
+
if self.apply_mask:
|
| 46 |
+
masked_tf_rep = self.apply_masks(tf_rep, est_masks)
|
| 47 |
+
else: # model already used masking
|
| 48 |
+
masked_tf_rep = est_masks
|
| 49 |
+
|
| 50 |
+
if self.use_decoder:
|
| 51 |
+
decoded = self.forward_decoder(masked_tf_rep)
|
| 52 |
+
reconstructed = pad_x_to_y(decoded, wav)
|
| 53 |
+
|
| 54 |
+
return masked_tf_rep, _shape_reconstructed(reconstructed, shape)
|
| 55 |
+
|
| 56 |
+
else: # In UMX or X-UMX, decoder is not used
|
| 57 |
+
decoded = masked_tf_rep
|
| 58 |
+
|
| 59 |
+
return decoded
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class BaseEncoderMaskerDecoder_mixture_consistency(BaseEncoderMaskerDecoder):
|
| 63 |
+
def __init__(self, encoder, masker, decoder, encoder_activation=None):
|
| 64 |
+
super().__init__(encoder, masker, decoder, encoder_activation)
|
| 65 |
+
|
| 66 |
+
def forward(self, wav):
|
| 67 |
+
"""Enc/Mask/Dec model forward with mixture consistent output
|
| 68 |
+
|
| 69 |
+
References:
|
| 70 |
+
[1] : Wisdom, Scott, et al. "Differentiable consistency constraints for improved deep speech enhancement." ICASSP 2019.
|
| 71 |
+
[2] : Wisdom, Scott, et al. "Unsupervised sound separation using mixture invariant training." NeurIPS 2020.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
|
| 78 |
+
"""
|
| 79 |
+
# Remember shape to shape reconstruction, cast to Tensor for torchscript
|
| 80 |
+
shape = jitable_shape(wav)
|
| 81 |
+
# Reshape to (batch, n_mix, time)
|
| 82 |
+
wav = _unsqueeze_to_3d(wav)
|
| 83 |
+
|
| 84 |
+
# Real forward
|
| 85 |
+
tf_rep = self.forward_encoder(wav)
|
| 86 |
+
est_masks = self.forward_masker(tf_rep)
|
| 87 |
+
masked_tf_rep = self.apply_masks(tf_rep, est_masks)
|
| 88 |
+
decoded = self.forward_decoder(masked_tf_rep)
|
| 89 |
+
|
| 90 |
+
reconstructed = _shape_reconstructed(pad_x_to_y(decoded, wav), shape)
|
| 91 |
+
|
| 92 |
+
reconstructed = reconstructed + 1 / reconstructed.shape[1] * (
|
| 93 |
+
wav - reconstructed.sum(dim=1, keepdim=True)
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
return reconstructed
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class BaseEncoderMaskerDecoderWithConfigsMaskOnOutput(BaseEncoderMaskerDecoder):
|
| 100 |
+
def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
|
| 101 |
+
super().__init__(encoder, masker, decoder, encoder_activation)
|
| 102 |
+
self.use_encoder = kwargs.get("use_encoder", True)
|
| 103 |
+
self.apply_mask = kwargs.get("apply_mask", True)
|
| 104 |
+
self.use_decoder = kwargs.get("use_decoder", True)
|
| 105 |
+
self.nb_channels = kwargs.get("nb_channels", 2)
|
| 106 |
+
self.decoder_activation = kwargs.get("decoder_activation", "sigmoid")
|
| 107 |
+
if self.decoder_activation == "sigmoid":
|
| 108 |
+
self.act_after_dec = nn.Sigmoid()
|
| 109 |
+
elif self.decoder_activation == "relu":
|
| 110 |
+
self.act_after_dec = nn.ReLU()
|
| 111 |
+
elif self.decoder_activation == "relu6":
|
| 112 |
+
self.act_after_dec = nn.ReLU6()
|
| 113 |
+
elif self.decoder_activation == "tanh":
|
| 114 |
+
self.act_after_dec = nn.Tanh()
|
| 115 |
+
elif self.decoder_activation == "none":
|
| 116 |
+
self.act_after_dec = nn.Identity()
|
| 117 |
+
else:
|
| 118 |
+
self.act_after_dec = nn.Sigmoid()
|
| 119 |
+
|
| 120 |
+
def forward(self, wav):
|
| 121 |
+
"""
|
| 122 |
+
For the De-limit task, we will apply the mask on the output of the decoder.
|
| 123 |
+
We want decoder to learn the sample-wise ratio of the sources.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
|
| 130 |
+
"""
|
| 131 |
+
# Remember shape to shape reconstruction, cast to Tensor for torchscript
|
| 132 |
+
shape = jitable_shape(wav)
|
| 133 |
+
# Reshape to (batch, n_mix, time)
|
| 134 |
+
wav = _unsqueeze_to_3d(wav) # (batch, n_channels, time)
|
| 135 |
+
|
| 136 |
+
# Real forward
|
| 137 |
+
if self.use_encoder:
|
| 138 |
+
tf_rep = self.forward_encoder(wav) # (batch, n_channels, freq, time)
|
| 139 |
+
else:
|
| 140 |
+
tf_rep = wav
|
| 141 |
+
|
| 142 |
+
if self.nb_channels == 2:
|
| 143 |
+
tf_rep = rearrange(
|
| 144 |
+
tf_rep, "b c f t -> b (c f) t"
|
| 145 |
+
) # c == 2 when stereo input.
|
| 146 |
+
est_masks = self.forward_masker(tf_rep) # (batch, 1, freq, time)
|
| 147 |
+
|
| 148 |
+
# we are going to apply the mask on the output of the decoder
|
| 149 |
+
if self.use_decoder:
|
| 150 |
+
if self.nb_channels == 2:
|
| 151 |
+
est_masks = rearrange(est_masks, "b 1 f t -> b f t")
|
| 152 |
+
est_masks_decoded = self.forward_decoder(est_masks)
|
| 153 |
+
est_masks_decoded = pad_x_to_y(est_masks_decoded, wav) # (batch, 1, time)
|
| 154 |
+
est_masks_decoded = self.act_after_dec(
|
| 155 |
+
est_masks_decoded
|
| 156 |
+
) # (batch, 1, time)
|
| 157 |
+
decoded = wav * est_masks_decoded # (batch, n_channels, time)
|
| 158 |
+
|
| 159 |
+
return (
|
| 160 |
+
est_masks_decoded,
|
| 161 |
+
decoded,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
else:
|
| 165 |
+
decoded = est_masks
|
| 166 |
+
|
| 167 |
+
return (decoded,)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid(BaseEncoderMaskerDecoder):
|
| 171 |
+
def __init__(self, encoder, masker, decoder, encoder_activation=None, **kwargs):
|
| 172 |
+
super().__init__(encoder, masker, decoder, encoder_activation)
|
| 173 |
+
self.use_encoder = kwargs.get("use_encoder", True)
|
| 174 |
+
self.apply_mask = kwargs.get("apply_mask", True)
|
| 175 |
+
self.use_decoder = kwargs.get("use_decoder", True)
|
| 176 |
+
self.nb_channels = kwargs.get("nb_channels", 2)
|
| 177 |
+
self.decoder_activation = kwargs.get("decoder_activation", "none")
|
| 178 |
+
if self.decoder_activation == "sigmoid":
|
| 179 |
+
self.act_after_dec = nn.Sigmoid()
|
| 180 |
+
elif self.decoder_activation == "relu":
|
| 181 |
+
self.act_after_dec = nn.ReLU()
|
| 182 |
+
elif self.decoder_activation == "relu6":
|
| 183 |
+
self.act_after_dec = nn.ReLU6()
|
| 184 |
+
elif self.decoder_activation == "tanh":
|
| 185 |
+
self.act_after_dec = nn.Tanh()
|
| 186 |
+
elif self.decoder_activation == "none":
|
| 187 |
+
self.act_after_dec = nn.Identity()
|
| 188 |
+
else:
|
| 189 |
+
self.act_after_dec = nn.Sigmoid()
|
| 190 |
+
|
| 191 |
+
def forward(self, wav):
|
| 192 |
+
"""
|
| 193 |
+
Enc/Mask/Dec model forward with some additional options.
|
| 194 |
+
For MultiChannel usage of asteroid-based models. (e.g. ConvTasNet)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
wav (torch.Tensor): waveform tensor. 1D, 2D or 3D tensor, time last.
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
torch.Tensor, of shape (batch, n_src, time) or (n_src, time).
|
| 202 |
+
"""
|
| 203 |
+
# Remember shape to shape reconstruction, cast to Tensor for torchscript
|
| 204 |
+
shape = jitable_shape(wav)
|
| 205 |
+
# Reshape to (batch, n_mix, time)
|
| 206 |
+
wav = _unsqueeze_to_3d(wav)
|
| 207 |
+
|
| 208 |
+
# Real forward
|
| 209 |
+
if self.use_encoder:
|
| 210 |
+
tf_rep = self.forward_encoder(wav)
|
| 211 |
+
else:
|
| 212 |
+
tf_rep = wav
|
| 213 |
+
|
| 214 |
+
if self.nb_channels == 2:
|
| 215 |
+
tf_rep = rearrange(
|
| 216 |
+
tf_rep, "b c f t -> b (c f) t"
|
| 217 |
+
) # c == 2 when stereo input.
|
| 218 |
+
est_masks = self.forward_masker(tf_rep)
|
| 219 |
+
|
| 220 |
+
if self.nb_channels == 2:
|
| 221 |
+
tf_rep = rearrange(tf_rep, "b (c f) t -> b c f t", c=self.nb_channels)
|
| 222 |
+
|
| 223 |
+
if self.apply_mask:
|
| 224 |
+
# Since original asteroid implementation of masking includes unnecessary unsqueeze operation, we will do it manually.
|
| 225 |
+
masked_tf_rep = est_masks * tf_rep
|
| 226 |
+
else:
|
| 227 |
+
masked_tf_rep = est_masks
|
| 228 |
+
|
| 229 |
+
if self.use_decoder:
|
| 230 |
+
decoded = self.forward_decoder(masked_tf_rep)
|
| 231 |
+
reconstructed = pad_x_to_y(decoded, wav)
|
| 232 |
+
reconstructed = self.act_after_dec(reconstructed)
|
| 233 |
+
|
| 234 |
+
return masked_tf_rep, _shape_reconstructed(reconstructed, shape)
|
| 235 |
+
|
| 236 |
+
else:
|
| 237 |
+
decoded = masked_tf_rep
|
| 238 |
+
|
| 239 |
+
return decoded
|
models/load_models.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from asteroid_filterbanks import make_enc_dec
|
| 5 |
+
|
| 6 |
+
from asteroid.masknn import TDConvNet
|
| 7 |
+
|
| 8 |
+
import utils
|
| 9 |
+
from .base_models import (
|
| 10 |
+
BaseEncoderMaskerDecoderWithConfigs,
|
| 11 |
+
BaseEncoderMaskerDecoderWithConfigsMaskOnOutput,
|
| 12 |
+
BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_model_with_args(args):
|
| 17 |
+
if args.model_loss_params.architecture == "conv_tasnet_mask_on_output":
|
| 18 |
+
encoder, decoder = make_enc_dec(
|
| 19 |
+
"free",
|
| 20 |
+
n_filters=args.conv_tasnet_params.n_filters,
|
| 21 |
+
kernel_size=args.conv_tasnet_params.kernel_size,
|
| 22 |
+
stride=args.conv_tasnet_params.stride,
|
| 23 |
+
sample_rate=args.sample_rate,
|
| 24 |
+
)
|
| 25 |
+
masker = TDConvNet(
|
| 26 |
+
in_chan=encoder.n_feats_out * args.data_params.nb_channels, # stereo
|
| 27 |
+
n_src=1, # for de-limit task.
|
| 28 |
+
out_chan=encoder.n_feats_out,
|
| 29 |
+
n_blocks=args.conv_tasnet_params.n_blocks,
|
| 30 |
+
n_repeats=args.conv_tasnet_params.n_repeats,
|
| 31 |
+
bn_chan=args.conv_tasnet_params.bn_chan,
|
| 32 |
+
hid_chan=args.conv_tasnet_params.hid_chan,
|
| 33 |
+
skip_chan=args.conv_tasnet_params.skip_chan,
|
| 34 |
+
# conv_kernel_size=args.conv_tasnet_params.conv_kernel_size,
|
| 35 |
+
norm_type=args.conv_tasnet_params.norm_type if args.conv_tasnet_params.norm_type else 'gLN',
|
| 36 |
+
mask_act=args.conv_tasnet_params.mask_act,
|
| 37 |
+
# causal=args.conv_tasnet_params.causal,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
model = BaseEncoderMaskerDecoderWithConfigsMaskOnOutput(
|
| 41 |
+
encoder,
|
| 42 |
+
masker,
|
| 43 |
+
decoder,
|
| 44 |
+
encoder_activation=args.conv_tasnet_params.encoder_activation,
|
| 45 |
+
use_encoder=True,
|
| 46 |
+
apply_mask=True,
|
| 47 |
+
use_decoder=True,
|
| 48 |
+
decoder_activation=args.conv_tasnet_params.decoder_activation,
|
| 49 |
+
)
|
| 50 |
+
model.use_encoder_to_target = False
|
| 51 |
+
|
| 52 |
+
elif args.model_loss_params.architecture == "conv_tasnet":
|
| 53 |
+
encoder, decoder = make_enc_dec(
|
| 54 |
+
"free",
|
| 55 |
+
n_filters=args.conv_tasnet_params.n_filters,
|
| 56 |
+
kernel_size=args.conv_tasnet_params.kernel_size,
|
| 57 |
+
stride=args.conv_tasnet_params.stride,
|
| 58 |
+
sample_rate=args.sample_rate,
|
| 59 |
+
)
|
| 60 |
+
masker = TDConvNet(
|
| 61 |
+
in_chan=encoder.n_feats_out * args.data_params.nb_channels, # stereo
|
| 62 |
+
n_src=args.conv_tasnet_params.n_src, # for de-limit task with the standard conv-tasnet setting.
|
| 63 |
+
out_chan=encoder.n_feats_out,
|
| 64 |
+
n_blocks=args.conv_tasnet_params.n_blocks,
|
| 65 |
+
n_repeats=args.conv_tasnet_params.n_repeats,
|
| 66 |
+
bn_chan=args.conv_tasnet_params.bn_chan,
|
| 67 |
+
hid_chan=args.conv_tasnet_params.hid_chan,
|
| 68 |
+
skip_chan=args.conv_tasnet_params.skip_chan,
|
| 69 |
+
# conv_kernel_size=args.conv_tasnet_params.conv_kernel_size,
|
| 70 |
+
norm_type=args.conv_tasnet_params.norm_type if args.conv_tasnet_params.norm_type else 'gLN',
|
| 71 |
+
mask_act=args.conv_tasnet_params.mask_act,
|
| 72 |
+
# causal=args.conv_tasnet_params.causal,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
model = BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid(
|
| 76 |
+
encoder,
|
| 77 |
+
masker,
|
| 78 |
+
decoder,
|
| 79 |
+
encoder_activation=args.conv_tasnet_params.encoder_activation,
|
| 80 |
+
use_encoder=True,
|
| 81 |
+
apply_mask=False if args.conv_tasnet_params.synthesis else True,
|
| 82 |
+
use_decoder=True,
|
| 83 |
+
decoder_activation=args.conv_tasnet_params.decoder_activation,
|
| 84 |
+
)
|
| 85 |
+
model.use_encoder_to_target = False
|
| 86 |
+
|
| 87 |
+
return model
|
prepro/delimit_save_delimiter_stems.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Save loudness normalized (-14 LUFS) musdb-XL audio files for delimiter evaluation
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
import tqdm
|
| 7 |
+
import musdb
|
| 8 |
+
import soundfile as sf
|
| 9 |
+
import librosa
|
| 10 |
+
import pyloudnorm as pyln
|
| 11 |
+
|
| 12 |
+
from utils import db2linear, str2bool
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
tqdm.monitor_interval = 0
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def main():
|
| 19 |
+
parser = argparse.ArgumentParser(description="model test.py")
|
| 20 |
+
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
"--target",
|
| 23 |
+
type=str,
|
| 24 |
+
default="vocals",
|
| 25 |
+
help="target source. all, vocals, drums, bass, other",
|
| 26 |
+
)
|
| 27 |
+
parser.add_argument("--data_root", type=str, default="/path/to/musdb_XL")
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--data_root_hq",
|
| 30 |
+
type=str,
|
| 31 |
+
default="/data1/Music/musdb18hq",
|
| 32 |
+
help="this is used when saving loud-norm stem of musdb-XL")
|
| 33 |
+
parser.add_argument(
|
| 34 |
+
"--output_directory",
|
| 35 |
+
type=str,
|
| 36 |
+
default="/path/to/results",
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument("--exp_name", type=str, default="delimit_6_s")
|
| 39 |
+
parser.add_argument(
|
| 40 |
+
"--save_16k_mono",
|
| 41 |
+
type=str2bool,
|
| 42 |
+
default=False,
|
| 43 |
+
help="Save 16k mono wav files for FAD evaluation.",
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
args, _ = parser.parse_known_args()
|
| 48 |
+
|
| 49 |
+
os.makedirs(args.output_directory, exist_ok=True)
|
| 50 |
+
|
| 51 |
+
meter = pyln.Meter(44100)
|
| 52 |
+
args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
|
| 53 |
+
|
| 54 |
+
test_tracks = musdb.DB(root=args.data_root, subsets="test", is_wav=True)
|
| 55 |
+
if args.target != "mixture": # In this file, args.target should not be "mixture"
|
| 56 |
+
hq_tracks = musdb.DB(root=args.data_root_hq, subsets='test', is_wav=True)
|
| 57 |
+
|
| 58 |
+
for idx, track in tqdm.tqdm(enumerate(test_tracks)):
|
| 59 |
+
track_name = track.name
|
| 60 |
+
if (
|
| 61 |
+
os.path.basename(args.data_root) == "musdb18hq"
|
| 62 |
+
and track_name == "PR - Oh No"
|
| 63 |
+
): # We have to consider this exception because 'PR - Oh No' mixture.wav is left-panned. We will use the linear mixture instead.
|
| 64 |
+
# Please refer https://github.com/jeonchangbin49/musdb-XL/blob/main/make_L_and_XL.py
|
| 65 |
+
track_audio = (
|
| 66 |
+
track.targets["vocals"].audio
|
| 67 |
+
+ track.targets["drums"].audio
|
| 68 |
+
+ track.targets["bass"].audio
|
| 69 |
+
+ track.targets["other"].audio
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
track_audio = track.audio
|
| 73 |
+
|
| 74 |
+
delimiter_track = librosa.load(f"{args.test_output_dir}/{track_name}/all.wav", sr=44100, mono=False)[0].T
|
| 75 |
+
|
| 76 |
+
print(track_name)
|
| 77 |
+
|
| 78 |
+
if args.target != "mixture":
|
| 79 |
+
hq_track = hq_tracks[idx]
|
| 80 |
+
hq_audio = hq_track.audio
|
| 81 |
+
hq_stem = hq_track.targets[args.target].audio
|
| 82 |
+
hq_samplewise_gain = track_audio / (hq_audio + 1e-8)
|
| 83 |
+
XL_stem = hq_samplewise_gain * hq_stem
|
| 84 |
+
XL_samplewise_gain = delimiter_track / (track_audio + 1e-8)
|
| 85 |
+
delimiter_stem = XL_samplewise_gain * XL_stem
|
| 86 |
+
|
| 87 |
+
sf.write(
|
| 88 |
+
f"{args.test_output_dir}/{track_name}/{args.target}.wav", delimiter_stem, 44100
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
main()
|
prepro/delimit_save_musdb_loudnorm.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Save loudness normalized (-14 LUFS) musdb-XL audio files for evaluations of de-limiter
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
import tqdm
|
| 7 |
+
import musdb
|
| 8 |
+
import soundfile as sf
|
| 9 |
+
import librosa
|
| 10 |
+
import pyloudnorm as pyln
|
| 11 |
+
|
| 12 |
+
from utils import db2linear, str2bool
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
tqdm.monitor_interval = 0
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def main():
|
| 19 |
+
parser = argparse.ArgumentParser(description="model test.py")
|
| 20 |
+
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
"--target",
|
| 23 |
+
type=str,
|
| 24 |
+
default="mixture",
|
| 25 |
+
help="target source. all, vocals, drums, bass, other",
|
| 26 |
+
)
|
| 27 |
+
parser.add_argument("--data_root", type=str, default="/path/to/musdb_XL")
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--data_root_hq",
|
| 30 |
+
type=str,
|
| 31 |
+
default="/path/to/musdb18hq",
|
| 32 |
+
help="this is used when saving loud-norm stem of musdb-XL")
|
| 33 |
+
parser.add_argument(
|
| 34 |
+
"--output_directory",
|
| 35 |
+
type=str,
|
| 36 |
+
default="/path/to/musdb_XL_loudnorm",
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--loudnorm_input_lufs",
|
| 40 |
+
type=float,
|
| 41 |
+
default=-14.0,
|
| 42 |
+
help="If you want to use loudnorm, input target lufs",
|
| 43 |
+
)
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
"--save_16k_mono",
|
| 46 |
+
type=str2bool,
|
| 47 |
+
default=True,
|
| 48 |
+
help="Save 16k mono wav files for FAD evaluation.",
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
args, _ = parser.parse_known_args()
|
| 53 |
+
|
| 54 |
+
os.makedirs(args.output_directory, exist_ok=True)
|
| 55 |
+
|
| 56 |
+
meter = pyln.Meter(44100)
|
| 57 |
+
|
| 58 |
+
test_tracks = musdb.DB(root=args.data_root, subsets="test", is_wav=True)
|
| 59 |
+
if args.target != "mixture":
|
| 60 |
+
hq_tracks = musdb.DB(root=args.data_root_hq, subsets='test', is_wav=True)
|
| 61 |
+
|
| 62 |
+
for idx, track in tqdm.tqdm(enumerate(test_tracks)):
|
| 63 |
+
track_name = track.name
|
| 64 |
+
if (
|
| 65 |
+
os.path.basename(args.data_root) == "musdb18hq"
|
| 66 |
+
and track_name == "PR - Oh No"
|
| 67 |
+
): # We have to consider this exception because 'PR - Oh No' mixture.wav is left-panned. We will use the linear mixture instead.
|
| 68 |
+
# Please refer https://github.com/jeonchangbin49/musdb-XL/blob/main/make_L_and_XL.py
|
| 69 |
+
track_audio = (
|
| 70 |
+
track.targets["vocals"].audio
|
| 71 |
+
+ track.targets["drums"].audio
|
| 72 |
+
+ track.targets["bass"].audio
|
| 73 |
+
+ track.targets["other"].audio
|
| 74 |
+
)
|
| 75 |
+
else:
|
| 76 |
+
track_audio = track.audio
|
| 77 |
+
|
| 78 |
+
print(track_name)
|
| 79 |
+
|
| 80 |
+
augmented_gain = None
|
| 81 |
+
|
| 82 |
+
track_lufs = meter.integrated_loudness(track_audio)
|
| 83 |
+
augmented_gain = args.loudnorm_input_lufs - track_lufs
|
| 84 |
+
if os.path.basename(args.data_root) == "musdb18hq":
|
| 85 |
+
if args.target != "mixture":
|
| 86 |
+
track_audio = track.targets[args.target].audio
|
| 87 |
+
track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
|
| 88 |
+
elif os.path.basename(args.data_root) == "musdb_XL":
|
| 89 |
+
track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
|
| 90 |
+
if args.target != "mixture":
|
| 91 |
+
hq_track = hq_tracks[idx]
|
| 92 |
+
hq_audio = hq_track.audio
|
| 93 |
+
hq_stem = hq_track.targets[args.target].audio
|
| 94 |
+
samplewise_gain = track_audio / (hq_audio + 1e-8)
|
| 95 |
+
track_audio = samplewise_gain * hq_stem
|
| 96 |
+
|
| 97 |
+
os.makedirs(f"{args.output_directory}/{track_name}", exist_ok=True)
|
| 98 |
+
sf.write(
|
| 99 |
+
f"{args.output_directory}/{track_name}/{args.target}.wav", track_audio, 44100
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
if args.save_16k_mono:
|
| 103 |
+
track_audio_16k_mono = librosa.to_mono(track_audio.T)
|
| 104 |
+
track_audio_16k_mono = librosa.resample(
|
| 105 |
+
track_audio_16k_mono,
|
| 106 |
+
orig_sr=44100,
|
| 107 |
+
target_sr=16000,
|
| 108 |
+
)
|
| 109 |
+
os.makedirs(f"{args.output_directory}_16k_mono/{track_name}", exist_ok=True)
|
| 110 |
+
sf.write(
|
| 111 |
+
f"{args.output_directory}_16k_mono/{track_name}/{args.target}.wav",
|
| 112 |
+
track_audio_16k_mono,
|
| 113 |
+
samplerate=16000,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
if __name__ == "__main__":
|
| 118 |
+
main()
|
prepro/delimit_train_ozone_prepro.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import csv
|
| 4 |
+
import glob
|
| 5 |
+
import argparse
|
| 6 |
+
import random
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import librosa
|
| 10 |
+
import soundfile as sf
|
| 11 |
+
import pedalboard
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pyloudnorm as pyln
|
| 14 |
+
from scipy.stats import gamma
|
| 15 |
+
import torchaudio
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def str2bool(v):
|
| 19 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
| 20 |
+
return True
|
| 21 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
| 22 |
+
return False
|
| 23 |
+
else:
|
| 24 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _augment_gain_ozone(audio, low=0.25, high=1.25):
|
| 28 |
+
"""Applies a random gain between `low` and `high`"""
|
| 29 |
+
g = low + random.random() * (high - low)
|
| 30 |
+
return audio * g, g
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _augment_channelswap_ozone(audio):
|
| 34 |
+
"""Swap channels of stereo signals with a probability of p=0.5"""
|
| 35 |
+
if audio.shape[0] == 2 and random.random() < 0.5:
|
| 36 |
+
return np.flip(audio, axis=0), True # axis=0 must be given
|
| 37 |
+
else:
|
| 38 |
+
return audio, False
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# load wav file from arbitrary positions of 16bit stereo wav file
|
| 42 |
+
def load_wav_arbitrary_position_stereo(
|
| 43 |
+
filename, sample_rate, seq_duration, return_pos=False
|
| 44 |
+
):
|
| 45 |
+
# stereo
|
| 46 |
+
# seq_duration[second]
|
| 47 |
+
length = torchaudio.info(filename).num_frames
|
| 48 |
+
|
| 49 |
+
random_start = random.randint(
|
| 50 |
+
0, int(length - math.ceil(seq_duration * sample_rate) - 1)
|
| 51 |
+
)
|
| 52 |
+
random_start_sec = librosa.samples_to_time(random_start, sr=sample_rate)
|
| 53 |
+
X, sr = librosa.load(
|
| 54 |
+
filename, sr=None, mono=False, offset=random_start_sec, duration=seq_duration
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
if return_pos:
|
| 58 |
+
return X, random_start_sec
|
| 59 |
+
else:
|
| 60 |
+
return X
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# def main():
|
| 64 |
+
parser = argparse.ArgumentParser(description="Preprocess audio files for training")
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--root",
|
| 67 |
+
type=str,
|
| 68 |
+
default="/path/to/musdb18hq",
|
| 69 |
+
help="Root directory",
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--output",
|
| 73 |
+
type=str,
|
| 74 |
+
default="/path/to/musdb-XL-train",
|
| 75 |
+
help="Where to save output files",
|
| 76 |
+
)
|
| 77 |
+
parser.add_argument(
|
| 78 |
+
"--n_samples", type=int, default=300000, help="Number of samples to save"
|
| 79 |
+
)
|
| 80 |
+
parser.add_argument("--seq_duration", type=float, default=4.0, help="Sequence duration")
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"--save_fixed", type=str2bool, default=False, help="Save fixed mixture audio"
|
| 83 |
+
)
|
| 84 |
+
parser.add_argument(
|
| 85 |
+
"--target_lufs_mean", type=float, default=-8.0, help="Target LUFS mean"
|
| 86 |
+
)
|
| 87 |
+
parser.add_argument(
|
| 88 |
+
"--target_lufs_std", type=float, default=-1.0, help="Target LUFS std"
|
| 89 |
+
)
|
| 90 |
+
parser.add_argument("--sample_rate", type=int, default=44100, help="Sample rate")
|
| 91 |
+
parser.add_argument("--seed", type=int, default=46, help="Random seed")
|
| 92 |
+
args = parser.parse_args()
|
| 93 |
+
random.seed(args.seed)
|
| 94 |
+
|
| 95 |
+
valid_list = [
|
| 96 |
+
"ANiMAL - Rockshow",
|
| 97 |
+
"Actions - One Minute Smile",
|
| 98 |
+
"Alexander Ross - Goodbye Bolero",
|
| 99 |
+
"Clara Berry And Wooldog - Waltz For My Victims",
|
| 100 |
+
"Fergessen - Nos Palpitants",
|
| 101 |
+
"James May - On The Line",
|
| 102 |
+
"Johnny Lokke - Promises & Lies",
|
| 103 |
+
"Leaf - Summerghost",
|
| 104 |
+
"Meaxic - Take A Step",
|
| 105 |
+
"Patrick Talbot - A Reason To Leave",
|
| 106 |
+
"Skelpolu - Human Mistakes",
|
| 107 |
+
"Traffic Experiment - Sirens",
|
| 108 |
+
"Triviul - Angelsaint",
|
| 109 |
+
"Young Griffo - Pennies",
|
| 110 |
+
]
|
| 111 |
+
|
| 112 |
+
meter = pyln.Meter(args.sample_rate)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
sources = ["vocals", "bass", "drums", "other"]
|
| 116 |
+
song_list = glob.glob(f"{args.root}/train/*")
|
| 117 |
+
|
| 118 |
+
vst = pedalboard.load_plugin(
|
| 119 |
+
"/Library/Audio/Plug-Ins/Components/iZOzone9ElementsAUHook.component"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
if args.save_fixed:
|
| 123 |
+
vst_params = []
|
| 124 |
+
|
| 125 |
+
os.makedirs(f"{args.output}/ozone_train_fixed", exist_ok=True)
|
| 126 |
+
|
| 127 |
+
for song in song_list:
|
| 128 |
+
print(f"Processing {song}...")
|
| 129 |
+
song_name = os.path.basename(song)
|
| 130 |
+
audio_sources = []
|
| 131 |
+
for source in sources:
|
| 132 |
+
audio_path = f"{song}/{source}.wav"
|
| 133 |
+
audio, sr = librosa.load(audio_path, sr=args.sample_rate, mono=False)
|
| 134 |
+
audio_sources.append(audio)
|
| 135 |
+
stems = np.stack(audio_sources, axis=0)
|
| 136 |
+
mixture = stems.sum(0)
|
| 137 |
+
lufs = meter.integrated_loudness(mixture.T)
|
| 138 |
+
target_lufs = random.gauss(args.target_lufs_mean, args.target_lufs_std)
|
| 139 |
+
adjusted_loudness = target_lufs - lufs
|
| 140 |
+
|
| 141 |
+
vst.reset()
|
| 142 |
+
vst.eq_bypass = True
|
| 143 |
+
vst.img_bypass = True
|
| 144 |
+
vst.max_mode = 1.0 # Set IRC2 mode
|
| 145 |
+
vst.max_threshold = min(-adjusted_loudness, 0.0)
|
| 146 |
+
vst.max_character = min(gamma.rvs(2), 10.0)
|
| 147 |
+
|
| 148 |
+
print(
|
| 149 |
+
f"Applying Ozone 9 Elements IRC2 with threshold {vst.max_threshold} and character {vst.max_character}..."
|
| 150 |
+
)
|
| 151 |
+
limited_mixture = vst(mixture, args.sample_rate)
|
| 152 |
+
|
| 153 |
+
sf.write(
|
| 154 |
+
f"{args.output}/ozone_train_fixed/{song_name}.wav",
|
| 155 |
+
limited_mixture.T,
|
| 156 |
+
args.sample_rate,
|
| 157 |
+
)
|
| 158 |
+
vst_params.append([song_name, vst.max_threshold, vst.max_character])
|
| 159 |
+
# Save the song name and vst parameters (vst.max_threshold and vst.max_character) to a csv file
|
| 160 |
+
with open(f"{args.output}/ozone_train_fixed.csv", "w") as f:
|
| 161 |
+
writer = csv.writer(f)
|
| 162 |
+
writer.writerow(["song_name", "max_threshold", "max_character"])
|
| 163 |
+
for idx, list_vst_param in enumerate(vst_params):
|
| 164 |
+
writer.writerow(list_vst_param)
|
| 165 |
+
|
| 166 |
+
else:
|
| 167 |
+
if os.path.exists(f"{args.output}/ozone_train_random_0.csv"):
|
| 168 |
+
vst_params = []
|
| 169 |
+
list_csv_files = glob.glob(f"{args.output}/ozone_train_random_*.csv")
|
| 170 |
+
list_csv_files.sort()
|
| 171 |
+
for csv_file in list_csv_files:
|
| 172 |
+
with open(csv_file, "r") as f:
|
| 173 |
+
reader = csv.reader(f)
|
| 174 |
+
next(reader)
|
| 175 |
+
vst_params.extend([row for row in reader])
|
| 176 |
+
|
| 177 |
+
else:
|
| 178 |
+
vst_params = []
|
| 179 |
+
|
| 180 |
+
song_list = [x for x in song_list if os.path.basename(x) not in valid_list]
|
| 181 |
+
|
| 182 |
+
os.makedirs(f"{args.output}/ozone_train_random", exist_ok=True)
|
| 183 |
+
|
| 184 |
+
for n in range(len(vst_params), args.n_samples):
|
| 185 |
+
print(f"Processing {n} / {args.n_samples}...")
|
| 186 |
+
seg_name = f"ozone_seg_{n}"
|
| 187 |
+
|
| 188 |
+
lufs_not_inf = True
|
| 189 |
+
while lufs_not_inf:
|
| 190 |
+
audio_sources = []
|
| 191 |
+
source_song_names = {}
|
| 192 |
+
source_start_secs = {}
|
| 193 |
+
source_gains = {}
|
| 194 |
+
source_channelswaps = {}
|
| 195 |
+
for source in sources:
|
| 196 |
+
track_path = random.choice(song_list)
|
| 197 |
+
song_name = os.path.basename(track_path)
|
| 198 |
+
audio_path = f"{track_path}/{source}.wav"
|
| 199 |
+
audio, start_sec = load_wav_arbitrary_position_stereo(
|
| 200 |
+
audio_path, args.sample_rate, args.seq_duration, return_pos=True
|
| 201 |
+
)
|
| 202 |
+
audio, gain = _augment_gain_ozone(audio)
|
| 203 |
+
audio, channelswap = _augment_channelswap_ozone(audio)
|
| 204 |
+
audio_sources.append(audio)
|
| 205 |
+
source_song_names[source] = song_name
|
| 206 |
+
source_start_secs[source] = start_sec
|
| 207 |
+
source_gains[source] = gain
|
| 208 |
+
source_channelswaps[source] = channelswap
|
| 209 |
+
|
| 210 |
+
stems = np.stack(audio_sources, axis=0)
|
| 211 |
+
mixture = stems.sum(0)
|
| 212 |
+
lufs = meter.integrated_loudness(mixture.T)
|
| 213 |
+
|
| 214 |
+
# if lufs is inf, then the mixture is silent, so we need to generate a new mixture
|
| 215 |
+
lufs_not_inf = np.isinf(lufs)
|
| 216 |
+
|
| 217 |
+
target_lufs = random.gauss(args.target_lufs_mean, args.target_lufs_std)
|
| 218 |
+
adjusted_loudness = target_lufs - lufs
|
| 219 |
+
|
| 220 |
+
vst.reset()
|
| 221 |
+
vst.eq_bypass = True
|
| 222 |
+
vst.img_bypass = True
|
| 223 |
+
vst.max_mode = 1.0 # Set IRC2 mode
|
| 224 |
+
vst.max_threshold = min(max(-20, -adjusted_loudness), 0.0)
|
| 225 |
+
vst.max_character = min(gamma.rvs(2), 10.0)
|
| 226 |
+
|
| 227 |
+
print(
|
| 228 |
+
f"Applying Ozone 9 Elements IRC2 with threshold {vst.max_threshold} and character {vst.max_character}..."
|
| 229 |
+
)
|
| 230 |
+
limited_mixture = vst(mixture, args.sample_rate)
|
| 231 |
+
|
| 232 |
+
sf.write(
|
| 233 |
+
f"{args.output}/ozone_train_random_0/{seg_name}.wav",
|
| 234 |
+
limited_mixture.T,
|
| 235 |
+
args.sample_rate,
|
| 236 |
+
)
|
| 237 |
+
vst_params.append(
|
| 238 |
+
[
|
| 239 |
+
seg_name,
|
| 240 |
+
vst.max_threshold,
|
| 241 |
+
vst.max_character,
|
| 242 |
+
source_song_names["vocals"],
|
| 243 |
+
source_start_secs["vocals"],
|
| 244 |
+
source_gains["vocals"],
|
| 245 |
+
source_channelswaps["vocals"],
|
| 246 |
+
source_song_names["bass"],
|
| 247 |
+
source_start_secs["bass"],
|
| 248 |
+
source_gains["bass"],
|
| 249 |
+
source_channelswaps["bass"],
|
| 250 |
+
source_song_names["drums"],
|
| 251 |
+
source_start_secs["drums"],
|
| 252 |
+
source_gains["drums"],
|
| 253 |
+
source_channelswaps["drums"],
|
| 254 |
+
source_song_names["other"],
|
| 255 |
+
source_start_secs["other"],
|
| 256 |
+
source_gains["other"],
|
| 257 |
+
source_channelswaps["other"],
|
| 258 |
+
]
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
if (n + 1) % 20000 == 0 or n == args.n_samples - 1:
|
| 262 |
+
# We will separate the csv file into multiple files to avoid memory error
|
| 263 |
+
# Save the song name and vst parameters (vst.max_threshold and vst.max_character) to a csv file
|
| 264 |
+
number = int(n // 20000)
|
| 265 |
+
with open(f"{args.output}/ozone_train_random_{number}.csv", "w") as f:
|
| 266 |
+
writer = csv.writer(f)
|
| 267 |
+
writer.writerow(
|
| 268 |
+
[
|
| 269 |
+
"song_name",
|
| 270 |
+
"max_threshold",
|
| 271 |
+
"max_character",
|
| 272 |
+
"vocals_name",
|
| 273 |
+
"vocals_start_sec",
|
| 274 |
+
"vocals_gain",
|
| 275 |
+
"vocals_channelswap",
|
| 276 |
+
"bass_name",
|
| 277 |
+
"bass_start_sec",
|
| 278 |
+
"bass_gain",
|
| 279 |
+
"bass_channelswap",
|
| 280 |
+
"drums_name",
|
| 281 |
+
"drums_start_sec",
|
| 282 |
+
"drums_gain",
|
| 283 |
+
"drums_channelswap",
|
| 284 |
+
"other_name",
|
| 285 |
+
"other_start_sec",
|
| 286 |
+
"other_gain",
|
| 287 |
+
"other_channelswap",
|
| 288 |
+
]
|
| 289 |
+
)
|
| 290 |
+
for idx, list_vst_param in enumerate(
|
| 291 |
+
vst_params[number * 20000 : (number + 1) * 20000]
|
| 292 |
+
):
|
| 293 |
+
writer.writerow(list_vst_param)
|
prepro/delimit_valid_L_prepro.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
import tqdm
|
| 7 |
+
|
| 8 |
+
from dataloader import DelimitValidDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main():
|
| 12 |
+
# Parameters
|
| 13 |
+
data_path = "/path/to/musdb18hq"
|
| 14 |
+
save_path = "/path/to/musdb18hq_limited_L"
|
| 15 |
+
batch_size = 1
|
| 16 |
+
num_workers = 1
|
| 17 |
+
sr = 44100
|
| 18 |
+
|
| 19 |
+
# Dataset
|
| 20 |
+
dataset = DelimitValidDataset(root=data_path, valid_target_lufs=-14.39)
|
| 21 |
+
data_loader = DataLoader(
|
| 22 |
+
dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
|
| 23 |
+
)
|
| 24 |
+
dict_valid_loudness = {}
|
| 25 |
+
# Preprocessing
|
| 26 |
+
for limited_audio, orig_audio, audio_name, loudness in tqdm.tqdm(data_loader):
|
| 27 |
+
audio_name = audio_name[0]
|
| 28 |
+
limited_audio = limited_audio[0].numpy()
|
| 29 |
+
loudness = float(loudness[0].numpy())
|
| 30 |
+
dict_valid_loudness[audio_name] = loudness
|
| 31 |
+
# Save audio
|
| 32 |
+
os.makedirs(os.path.join(save_path, "valid"), exist_ok=True)
|
| 33 |
+
audio_path = os.path.join(save_path, "valid", audio_name)
|
| 34 |
+
sf.write(f"{audio_path}.wav", limited_audio.T, sr)
|
| 35 |
+
# write json write code
|
| 36 |
+
with open(os.path.join(save_path, "valid_loudness.json"), "w") as f:
|
| 37 |
+
json.dump(dict_valid_loudness, f, indent=4)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
main()
|
prepro/delimit_valid_custom_limiter_prepro.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
import tqdm
|
| 7 |
+
|
| 8 |
+
from dataloader import DelimitValidDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main():
|
| 12 |
+
# Parameters
|
| 13 |
+
data_path = "/path/to/musdb18hq"
|
| 14 |
+
save_path = (
|
| 15 |
+
"/path/to/musdb18hq_custom_limiter_fixed_attack"
|
| 16 |
+
)
|
| 17 |
+
batch_size = 1
|
| 18 |
+
num_workers = 1
|
| 19 |
+
sr = 44100
|
| 20 |
+
|
| 21 |
+
# Dataset
|
| 22 |
+
dataset = DelimitValidDataset(
|
| 23 |
+
root=data_path, use_custom_limiter=True, custom_limiter_attack_range=[2.0, 2.0]
|
| 24 |
+
)
|
| 25 |
+
data_loader = DataLoader(
|
| 26 |
+
dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
|
| 27 |
+
)
|
| 28 |
+
dict_valid_loudness = {}
|
| 29 |
+
dict_limiter_params = {}
|
| 30 |
+
# Preprocessing
|
| 31 |
+
for (
|
| 32 |
+
limited_audio,
|
| 33 |
+
orig_audio,
|
| 34 |
+
audio_name,
|
| 35 |
+
loudness,
|
| 36 |
+
custom_attack,
|
| 37 |
+
custom_release,
|
| 38 |
+
) in tqdm.tqdm(data_loader):
|
| 39 |
+
audio_name = audio_name[0]
|
| 40 |
+
limited_audio = limited_audio[0].numpy()
|
| 41 |
+
loudness = float(loudness[0].numpy())
|
| 42 |
+
dict_valid_loudness[audio_name] = loudness
|
| 43 |
+
dict_limiter_params[audio_name] = {
|
| 44 |
+
"attack_ms": float(custom_attack[0].numpy()),
|
| 45 |
+
"release_ms": float(custom_release[0].numpy()),
|
| 46 |
+
}
|
| 47 |
+
# Save audio
|
| 48 |
+
os.makedirs(os.path.join(save_path, "valid"), exist_ok=True)
|
| 49 |
+
audio_path = os.path.join(save_path, "valid", audio_name)
|
| 50 |
+
sf.write(f"{audio_path}.wav", limited_audio.T, sr)
|
| 51 |
+
# write json write code
|
| 52 |
+
with open(os.path.join(save_path, "valid_loudness.json"), "w") as f:
|
| 53 |
+
json.dump(dict_valid_loudness, f, indent=4)
|
| 54 |
+
with open(os.path.join(save_path, "valid_limiter_params.json"), "w") as f:
|
| 55 |
+
json.dump(dict_limiter_params, f, indent=4)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
if __name__ == "__main__":
|
| 59 |
+
main()
|
prepro/delimit_valid_prepro.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
import soundfile as sf
|
| 6 |
+
import tqdm
|
| 7 |
+
|
| 8 |
+
from dataloader import DelimitValidDataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main():
|
| 12 |
+
# Parameters
|
| 13 |
+
data_path = "/path/to/musdb18hq"
|
| 14 |
+
save_path = "/path/to/musdb18hq_limited"
|
| 15 |
+
batch_size = 1
|
| 16 |
+
num_workers = 1
|
| 17 |
+
sr = 44100
|
| 18 |
+
|
| 19 |
+
# Dataset
|
| 20 |
+
dataset = DelimitValidDataset(root=data_path)
|
| 21 |
+
data_loader = DataLoader(
|
| 22 |
+
dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
|
| 23 |
+
)
|
| 24 |
+
dict_valid_loudness = {}
|
| 25 |
+
# Preprocessing
|
| 26 |
+
for limited_audio, orig_audio, audio_name, loudness in tqdm.tqdm(data_loader):
|
| 27 |
+
audio_name = audio_name[0]
|
| 28 |
+
limited_audio = limited_audio[0].numpy()
|
| 29 |
+
loudness = float(loudness[0].numpy())
|
| 30 |
+
dict_valid_loudness[audio_name] = loudness
|
| 31 |
+
# Save audio
|
| 32 |
+
os.makedirs(os.path.join(save_path, "valid"), exist_ok=True)
|
| 33 |
+
audio_path = os.path.join(save_path, "valid", audio_name)
|
| 34 |
+
sf.write(f"{audio_path}.wav", limited_audio.T, sr)
|
| 35 |
+
# write json write code
|
| 36 |
+
with open(os.path.join(save_path, "valid_loudness.json"), "w") as f:
|
| 37 |
+
json.dump(dict_valid_loudness, f, indent=4)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if __name__ == "__main__":
|
| 41 |
+
main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
git+https://github.com/asteroid-team/asteroid.git@master
|
| 2 |
+
numpy
|
| 3 |
+
librosa
|
| 4 |
+
soundfile
|
| 5 |
+
torch
|
| 6 |
+
torchaudio
|
| 7 |
+
matplotlib
|
| 8 |
+
wandb
|
| 9 |
+
musdb
|
| 10 |
+
dotmap
|
| 11 |
+
ema-pytorch
|
| 12 |
+
pedalboard
|
| 13 |
+
einops
|
separate_func/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .conv_tasnet_separate import conv_tasnet_separate
|
separate_func/conv_tasnet_separate.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import soundfile as sf
|
| 4 |
+
import torch
|
| 5 |
+
import pyloudnorm as pyln
|
| 6 |
+
import librosa
|
| 7 |
+
import matplotlib
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
|
| 10 |
+
from dataloader import SingleTrackSet
|
| 11 |
+
from utils import db2linear
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def conv_tasnet_separate(
|
| 15 |
+
args, our_model, device, track_audio, track_name, meter=None, augmented_gain=None
|
| 16 |
+
):
|
| 17 |
+
|
| 18 |
+
if args.use_singletrackset:
|
| 19 |
+
db = SingleTrackSet(
|
| 20 |
+
track_audio.squeeze(dim=0),
|
| 21 |
+
hop_length=args.data_params.nhop,
|
| 22 |
+
num_frame=128,
|
| 23 |
+
target_name=args.target,
|
| 24 |
+
)
|
| 25 |
+
separated = []
|
| 26 |
+
|
| 27 |
+
for item in db:
|
| 28 |
+
item = item.unsqueeze(0).to(device)
|
| 29 |
+
estimates, *estimates_vars = our_model(item)
|
| 30 |
+
if args.task_params.dataset == "delimit":
|
| 31 |
+
estimates = estimates_vars[0]
|
| 32 |
+
|
| 33 |
+
estimates = estimates.cpu().detach()
|
| 34 |
+
separated.append(
|
| 35 |
+
estimates[..., db.trim_length : -db.trim_length].cpu().detach().clone()
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
estimates = torch.cat(separated, dim=-1)
|
| 39 |
+
estimates = estimates[0, :, : track_audio.shape[-1]].numpy()
|
| 40 |
+
else:
|
| 41 |
+
estimates, *estimates_vars = our_model(track_audio)
|
| 42 |
+
if args.save_histogram and args.task_params.dataset == "delimit":
|
| 43 |
+
plt.figure(figsize=(10, 10))
|
| 44 |
+
plt.hist(estimates.cpu().detach().numpy().flatten(), bins=100)
|
| 45 |
+
os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)
|
| 46 |
+
plt.savefig(
|
| 47 |
+
f"{args.test_output_dir}/{track_name}/{args.target}_histogram.png"
|
| 48 |
+
)
|
| 49 |
+
if args.task_params.dataset == "delimit":
|
| 50 |
+
estimates = estimates_vars[0]
|
| 51 |
+
|
| 52 |
+
estimates = estimates.cpu().detach().numpy()
|
| 53 |
+
estimates = estimates[0, :, : track_audio.shape[-1]]
|
| 54 |
+
|
| 55 |
+
if args.save_name_as_target:
|
| 56 |
+
os.makedirs(f"{args.test_output_dir}/{track_name}", exist_ok=True)
|
| 57 |
+
|
| 58 |
+
if args.save_output_loudnorm:
|
| 59 |
+
print("SAVE Loudness normalized OUTPUT ")
|
| 60 |
+
loudness = meter.integrated_loudness(estimates.T)
|
| 61 |
+
estimates = estimates * db2linear(args.save_output_loudnorm - loudness, eps=0.0)
|
| 62 |
+
elif augmented_gain != None and args.save_output_loudnorm == None:
|
| 63 |
+
estimates = estimates * db2linear(-augmented_gain, eps=0.0)
|
| 64 |
+
|
| 65 |
+
sf.write(
|
| 66 |
+
f"{args.test_output_dir}/{track_name}/{args.target}.wav"
|
| 67 |
+
if args.save_name_as_target
|
| 68 |
+
else f"{args.test_output_dir}/{track_name}.wav",
|
| 69 |
+
estimates.T,
|
| 70 |
+
samplerate=args.data_params.sample_rate,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
if args.save_16k_mono:
|
| 74 |
+
estimates_16k_mono = librosa.to_mono(estimates)
|
| 75 |
+
estimates_16k_mono = librosa.resample(
|
| 76 |
+
estimates_16k_mono,
|
| 77 |
+
orig_sr=args.data_params.sample_rate,
|
| 78 |
+
target_sr=16000,
|
| 79 |
+
)
|
| 80 |
+
os.makedirs(f"{args.test_output_dir}_16k_mono/{track_name}", exist_ok=True)
|
| 81 |
+
sf.write(
|
| 82 |
+
f"{args.test_output_dir}_16k_mono/{track_name}/{args.target}.wav"
|
| 83 |
+
if args.save_name_as_target
|
| 84 |
+
else f"{args.test_output_dir}_16k_mono/{track_name}.wav",
|
| 85 |
+
estimates_16k_mono,
|
| 86 |
+
samplerate=16000,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
return estimates
|
solver_ddp.py
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import wandb
|
| 7 |
+
import matplotlib
|
| 8 |
+
|
| 9 |
+
matplotlib.use("Agg")
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import torch.distributed as dist
|
| 12 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 13 |
+
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
|
| 14 |
+
from asteroid.losses import (
|
| 15 |
+
pairwise_neg_sisdr,
|
| 16 |
+
PairwiseNegSDR,
|
| 17 |
+
)
|
| 18 |
+
from einops import rearrange, reduce
|
| 19 |
+
from ema_pytorch import EMA
|
| 20 |
+
|
| 21 |
+
from models import load_model_with_args
|
| 22 |
+
import utils
|
| 23 |
+
from dataloader import (
|
| 24 |
+
MusdbTrainDataset,
|
| 25 |
+
MusdbValidDataset,
|
| 26 |
+
DelimitTrainDataset,
|
| 27 |
+
DelimitValidDataset,
|
| 28 |
+
OzoneTrainDataset,
|
| 29 |
+
OzoneValidDataset,
|
| 30 |
+
aug_from_str,
|
| 31 |
+
SingleTrackSet,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class Solver(object):
|
| 36 |
+
def __init__(self):
|
| 37 |
+
pass
|
| 38 |
+
|
| 39 |
+
def set_gpu(self, args):
|
| 40 |
+
|
| 41 |
+
if args.wandb_params.use_wandb and args.gpu == 0:
|
| 42 |
+
if args.wandb_params.sweep:
|
| 43 |
+
wandb.init(
|
| 44 |
+
entity=args.wandb_params.entity,
|
| 45 |
+
project=args.wandb_params.project,
|
| 46 |
+
config=args,
|
| 47 |
+
resume=True
|
| 48 |
+
if args.dir_params.resume != None and args.gpu == 0
|
| 49 |
+
else False,
|
| 50 |
+
)
|
| 51 |
+
else:
|
| 52 |
+
wandb.init(
|
| 53 |
+
entity=args.wandb_params.entity,
|
| 54 |
+
project=args.wandb_params.project,
|
| 55 |
+
name=f"{args.dir_params.exp_name}",
|
| 56 |
+
config=args,
|
| 57 |
+
resume="must"
|
| 58 |
+
if args.dir_params.resume != None
|
| 59 |
+
and not args.dir_params.continual_train
|
| 60 |
+
else False,
|
| 61 |
+
id=args.wandb_params.rerun_id
|
| 62 |
+
if args.wandb_params.rerun_id
|
| 63 |
+
else None,
|
| 64 |
+
settings=wandb.Settings(start_method="fork"),
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
###################### Define Models ######################
|
| 68 |
+
self.model = load_model_with_args(args)
|
| 69 |
+
|
| 70 |
+
trainable_params = []
|
| 71 |
+
trainable_params = trainable_params + list(self.model.parameters())
|
| 72 |
+
|
| 73 |
+
if args.hyperparams.optimizer == "sgd":
|
| 74 |
+
print("Use SGD optimizer.")
|
| 75 |
+
self.optimizer = torch.optim.SGD(
|
| 76 |
+
params=trainable_params,
|
| 77 |
+
lr=args.hyperparams.lr,
|
| 78 |
+
momentum=0.9,
|
| 79 |
+
weight_decay=args.hyperparams.weight_decay,
|
| 80 |
+
)
|
| 81 |
+
elif args.hyperparams.optimizer == "adamw":
|
| 82 |
+
print("Use AdamW optimizer.")
|
| 83 |
+
self.optimizer = torch.optim.AdamW(
|
| 84 |
+
params=trainable_params,
|
| 85 |
+
lr=args.hyperparams.lr,
|
| 86 |
+
betas=(0.9, 0.999),
|
| 87 |
+
amsgrad=False,
|
| 88 |
+
weight_decay=args.hyperparams.weight_decay,
|
| 89 |
+
)
|
| 90 |
+
elif args.hyperparams.optimizer == "radam":
|
| 91 |
+
print("Use RAdam optimizer.")
|
| 92 |
+
self.optimizer = torch.optim.RAdam(
|
| 93 |
+
params=trainable_params,
|
| 94 |
+
lr=args.hyperparams.lr,
|
| 95 |
+
betas=(0.9, 0.999),
|
| 96 |
+
eps=1e-08,
|
| 97 |
+
weight_decay=args.hyperparams.weight_decay,
|
| 98 |
+
)
|
| 99 |
+
elif args.hyperparams.optimizer == "adam":
|
| 100 |
+
print("Use Adam optimizer.")
|
| 101 |
+
self.optimizer = torch.optim.Adam(
|
| 102 |
+
params=trainable_params,
|
| 103 |
+
lr=args.hyperparams.lr,
|
| 104 |
+
betas=(0.9, 0.999),
|
| 105 |
+
weight_decay=args.hyperparams.weight_decay,
|
| 106 |
+
)
|
| 107 |
+
else:
|
| 108 |
+
print("no optimizer loaded")
|
| 109 |
+
raise NotImplementedError
|
| 110 |
+
|
| 111 |
+
if args.hyperparams.lr_scheduler == "step_lr":
|
| 112 |
+
if args.model_loss_params.architecture == "umx":
|
| 113 |
+
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 114 |
+
self.optimizer,
|
| 115 |
+
mode="min",
|
| 116 |
+
factor=args.hyperparams.lr_decay_gamma,
|
| 117 |
+
patience=args.hyperparams.lr_decay_patience,
|
| 118 |
+
cooldown=10,
|
| 119 |
+
verbose=True,
|
| 120 |
+
)
|
| 121 |
+
else:
|
| 122 |
+
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 123 |
+
self.optimizer,
|
| 124 |
+
mode="min",
|
| 125 |
+
factor=args.hyperparams.lr_decay_gamma,
|
| 126 |
+
patience=args.hyperparams.lr_decay_patience,
|
| 127 |
+
cooldown=0,
|
| 128 |
+
min_lr=5e-5,
|
| 129 |
+
verbose=True,
|
| 130 |
+
)
|
| 131 |
+
elif args.hyperparams.lr_scheduler == "cos_warmup":
|
| 132 |
+
self.scheduler = utils.CosineAnnealingWarmUpRestarts(
|
| 133 |
+
self.optimizer,
|
| 134 |
+
T_0=40,
|
| 135 |
+
T_mult=1,
|
| 136 |
+
eta_max=args.hyperparams.lr,
|
| 137 |
+
T_up=10,
|
| 138 |
+
gamma=0.5,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
torch.cuda.set_device(args.gpu)
|
| 142 |
+
|
| 143 |
+
self.model = self.model.to(f"cuda:{args.gpu}")
|
| 144 |
+
|
| 145 |
+
############################################################
|
| 146 |
+
# Define Losses
|
| 147 |
+
self.criterion = {}
|
| 148 |
+
|
| 149 |
+
self.criterion["l1"] = nn.L1Loss().to(args.gpu)
|
| 150 |
+
self.criterion["mse"] = nn.MSELoss().to(args.gpu)
|
| 151 |
+
self.criterion["si_sdr"] = pairwise_neg_sisdr.to(args.gpu)
|
| 152 |
+
self.criterion["snr"] = PairwiseNegSDR("snr").to(args.gpu)
|
| 153 |
+
self.criterion["bcewithlogits"] = nn.BCEWithLogitsLoss().to(args.gpu)
|
| 154 |
+
self.criterion["bce"] = nn.BCELoss().to(args.gpu)
|
| 155 |
+
self.criterion["kl"] = nn.KLDivLoss(log_target=True).to(args.gpu)
|
| 156 |
+
|
| 157 |
+
print("Loss functions we use in this training:")
|
| 158 |
+
print(args.model_loss_params.train_loss_func)
|
| 159 |
+
|
| 160 |
+
# Early stopping utils
|
| 161 |
+
self.es = utils.EarlyStopping(patience=args.hyperparams.patience)
|
| 162 |
+
self.stop = False
|
| 163 |
+
|
| 164 |
+
if args.wandb_params.use_wandb and args.gpu == 0:
|
| 165 |
+
wandb.watch(self.model, log="all")
|
| 166 |
+
|
| 167 |
+
self.start_epoch = 1
|
| 168 |
+
self.train_losses = []
|
| 169 |
+
self.valid_losses = []
|
| 170 |
+
self.train_times = []
|
| 171 |
+
self.best_epoch = 0
|
| 172 |
+
|
| 173 |
+
if args.dir_params.resume and not args.hyperparams.ema:
|
| 174 |
+
self.resume(args)
|
| 175 |
+
|
| 176 |
+
# Distribute models to machine
|
| 177 |
+
self.model = DDP(
|
| 178 |
+
self.model,
|
| 179 |
+
device_ids=[args.gpu],
|
| 180 |
+
output_device=args.gpu,
|
| 181 |
+
find_unused_parameters=True,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
if args.hyperparams.ema:
|
| 185 |
+
self.model_ema = EMA(
|
| 186 |
+
self.model,
|
| 187 |
+
beta=0.999,
|
| 188 |
+
update_after_step=100,
|
| 189 |
+
update_every=10,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if args.resume and args.hyperparams.ema:
|
| 193 |
+
self.resume(args)
|
| 194 |
+
|
| 195 |
+
###################### Define data pipeline ######################
|
| 196 |
+
args.hyperparams.batch_size = int(
|
| 197 |
+
args.hyperparams.batch_size / args.ngpus_per_node
|
| 198 |
+
)
|
| 199 |
+
self.mp_context = torch.multiprocessing.get_context("fork")
|
| 200 |
+
|
| 201 |
+
if args.task_params.dataset == "musdb":
|
| 202 |
+
self.train_dataset = MusdbTrainDataset(
|
| 203 |
+
target=args.task_params.target,
|
| 204 |
+
root=args.dir_params.root,
|
| 205 |
+
seq_duration=args.data_params.seq_dur,
|
| 206 |
+
samples_per_track=args.data_params.samples_per_track,
|
| 207 |
+
source_augmentations=aug_from_str(
|
| 208 |
+
["gain", "channelswap"],
|
| 209 |
+
),
|
| 210 |
+
sample_rate=args.data_params.sample_rate,
|
| 211 |
+
seed=args.sys_params.seed,
|
| 212 |
+
limitaug_method=args.data_params.limitaug_method,
|
| 213 |
+
limitaug_mode=args.data_params.limitaug_mode,
|
| 214 |
+
limitaug_custom_target_lufs=args.data_params.limitaug_custom_target_lufs,
|
| 215 |
+
limitaug_custom_target_lufs_std=args.data_params.limitaug_custom_target_lufs_std,
|
| 216 |
+
target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
|
| 217 |
+
custom_limiter_attack_range=args.data_params.custom_limiter_attack_range,
|
| 218 |
+
custom_limiter_release_range=args.data_params.custom_limiter_release_range,
|
| 219 |
+
)
|
| 220 |
+
self.valid_dataset = MusdbValidDataset(
|
| 221 |
+
target=args.task_params.target, root=args.dir_params.root
|
| 222 |
+
)
|
| 223 |
+
elif args.task_params.dataset == "delimit":
|
| 224 |
+
if args.data_params.limitaug_method == "ozone":
|
| 225 |
+
self.train_dataset = OzoneTrainDataset(
|
| 226 |
+
target=args.task_params.target,
|
| 227 |
+
root=args.dir_params.root,
|
| 228 |
+
ozone_root=args.dir_params.ozone_root,
|
| 229 |
+
use_fixed=args.data_params.use_fixed,
|
| 230 |
+
seq_duration=args.data_params.seq_dur,
|
| 231 |
+
samples_per_track=args.data_params.samples_per_track,
|
| 232 |
+
source_augmentations=aug_from_str(
|
| 233 |
+
["gain", "channelswap"],
|
| 234 |
+
),
|
| 235 |
+
sample_rate=args.data_params.sample_rate,
|
| 236 |
+
seed=args.sys_params.seed,
|
| 237 |
+
limitaug_method=args.data_params.limitaug_method,
|
| 238 |
+
limitaug_mode=args.data_params.limitaug_mode,
|
| 239 |
+
limitaug_custom_target_lufs=args.data_params.limitaug_custom_target_lufs,
|
| 240 |
+
limitaug_custom_target_lufs_std=args.data_params.limitaug_custom_target_lufs_std,
|
| 241 |
+
target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
|
| 242 |
+
target_limitaug_mode=args.data_params.target_limitaug_mode,
|
| 243 |
+
target_limitaug_custom_target_lufs=args.data_params.target_limitaug_custom_target_lufs,
|
| 244 |
+
target_limitaug_custom_target_lufs_std=args.data_params.target_limitaug_custom_target_lufs_std,
|
| 245 |
+
custom_limiter_attack_range=args.data_params.custom_limiter_attack_range,
|
| 246 |
+
custom_limiter_release_range=args.data_params.custom_limiter_release_range,
|
| 247 |
+
)
|
| 248 |
+
self.valid_dataset = OzoneValidDataset(
|
| 249 |
+
target=args.task_params.target,
|
| 250 |
+
root=args.dir_params.root,
|
| 251 |
+
ozone_root=args.dir_params.ozone_root,
|
| 252 |
+
target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
|
| 253 |
+
)
|
| 254 |
+
else:
|
| 255 |
+
self.train_dataset = DelimitTrainDataset(
|
| 256 |
+
target=args.task_params.target,
|
| 257 |
+
root=args.dir_params.root,
|
| 258 |
+
seq_duration=args.data_params.seq_dur,
|
| 259 |
+
samples_per_track=args.data_params.samples_per_track,
|
| 260 |
+
source_augmentations=aug_from_str(
|
| 261 |
+
["gain", "channelswap"],
|
| 262 |
+
),
|
| 263 |
+
sample_rate=args.data_params.sample_rate,
|
| 264 |
+
seed=args.sys_params.seed,
|
| 265 |
+
limitaug_method=args.data_params.limitaug_method,
|
| 266 |
+
limitaug_mode=args.data_params.limitaug_mode,
|
| 267 |
+
limitaug_custom_target_lufs=args.data_params.limitaug_custom_target_lufs,
|
| 268 |
+
limitaug_custom_target_lufs_std=args.data_params.limitaug_custom_target_lufs_std,
|
| 269 |
+
target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
|
| 270 |
+
target_limitaug_mode=args.data_params.target_limitaug_mode,
|
| 271 |
+
target_limitaug_custom_target_lufs=args.data_params.target_limitaug_custom_target_lufs,
|
| 272 |
+
target_limitaug_custom_target_lufs_std=args.data_params.target_limitaug_custom_target_lufs_std,
|
| 273 |
+
custom_limiter_attack_range=args.data_params.custom_limiter_attack_range,
|
| 274 |
+
custom_limiter_release_range=args.data_params.custom_limiter_release_range,
|
| 275 |
+
)
|
| 276 |
+
self.valid_dataset = DelimitValidDataset(
|
| 277 |
+
target=args.task_params.target,
|
| 278 |
+
root=args.dir_params.root,
|
| 279 |
+
delimit_valid_root=args.dir_params.delimit_valid_root,
|
| 280 |
+
valid_target_lufs=args.data_params.valid_target_lufs,
|
| 281 |
+
target_loudnorm_lufs=args.data_params.target_loudnorm_lufs,
|
| 282 |
+
delimit_valid_L_root=args.dir_params.delimit_valid_L_root,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
self.train_sampler = DistributedSampler(
|
| 286 |
+
self.train_dataset, shuffle=True, rank=args.gpu
|
| 287 |
+
)
|
| 288 |
+
self.train_loader = torch.utils.data.DataLoader(
|
| 289 |
+
self.train_dataset,
|
| 290 |
+
batch_size=args.hyperparams.batch_size,
|
| 291 |
+
shuffle=False,
|
| 292 |
+
num_workers=args.sys_params.nb_workers,
|
| 293 |
+
multiprocessing_context=self.mp_context,
|
| 294 |
+
pin_memory=True,
|
| 295 |
+
sampler=self.train_sampler,
|
| 296 |
+
drop_last=False,
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
self.valid_sampler = DistributedSampler(
|
| 300 |
+
self.valid_dataset, shuffle=False, rank=args.gpu
|
| 301 |
+
)
|
| 302 |
+
self.valid_loader = torch.utils.data.DataLoader(
|
| 303 |
+
self.valid_dataset,
|
| 304 |
+
batch_size=1,
|
| 305 |
+
shuffle=False,
|
| 306 |
+
num_workers=args.sys_params.nb_workers,
|
| 307 |
+
multiprocessing_context=self.mp_context,
|
| 308 |
+
pin_memory=False,
|
| 309 |
+
sampler=self.valid_sampler,
|
| 310 |
+
drop_last=False,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
def train(self, args, epoch):
|
| 314 |
+
self.end = time.time()
|
| 315 |
+
self.model.train()
|
| 316 |
+
|
| 317 |
+
# get current learning rate
|
| 318 |
+
for param_group in self.optimizer.param_groups:
|
| 319 |
+
current_lr = param_group["lr"]
|
| 320 |
+
|
| 321 |
+
if (
|
| 322 |
+
args.sys_params.rank % args.ngpus_per_node == 0
|
| 323 |
+
): # when the last rank process is finished
|
| 324 |
+
print(f"Epoch {epoch}, Learning rate: {current_lr}")
|
| 325 |
+
|
| 326 |
+
losses = utils.AverageMeter()
|
| 327 |
+
loss_logger = {}
|
| 328 |
+
|
| 329 |
+
loss_logger["train/train loss"] = 0
|
| 330 |
+
# with torch.autograd.detect_anomaly(): # use this if you want to detect anomaly behavior while training.
|
| 331 |
+
for i, values in enumerate(self.train_loader):
|
| 332 |
+
mixture, clean, *train_vars = values
|
| 333 |
+
|
| 334 |
+
mixture = mixture.cuda(args.gpu, non_blocking=True)
|
| 335 |
+
clean = clean.cuda(args.gpu, non_blocking=True)
|
| 336 |
+
target = clean # target_shape = [batch_size, n_srcs, nb_channels (if stereo: 2), wave_length]
|
| 337 |
+
loss_input = {}
|
| 338 |
+
|
| 339 |
+
estimates, *estimates_vars = self.model(mixture)
|
| 340 |
+
# estimates = self.model(mixture)
|
| 341 |
+
|
| 342 |
+
# loss = []
|
| 343 |
+
dict_loss = {}
|
| 344 |
+
|
| 345 |
+
if args.task_params.dataset == "delimit":
|
| 346 |
+
estimates = estimates_vars[0]
|
| 347 |
+
|
| 348 |
+
for train_loss_idx, single_train_loss_func in enumerate(
|
| 349 |
+
args.model_loss_params.train_loss_func
|
| 350 |
+
):
|
| 351 |
+
if self.model.module.use_encoder_to_target:
|
| 352 |
+
target_spec = self.model.module.encoder(
|
| 353 |
+
rearrange(target, "b s c t -> (b s) c t")
|
| 354 |
+
)
|
| 355 |
+
target_spec = rearrange(
|
| 356 |
+
target_spec,
|
| 357 |
+
"(b s) c f t -> b s c f t",
|
| 358 |
+
s=args.task_params.bleeding_nsrcs,
|
| 359 |
+
)
|
| 360 |
+
loss_else = self.criterion[single_train_loss_func](
|
| 361 |
+
estimates,
|
| 362 |
+
target_spec
|
| 363 |
+
if self.model.module.use_encoder_to_target
|
| 364 |
+
else target,
|
| 365 |
+
)
|
| 366 |
+
dict_loss[single_train_loss_func] = (
|
| 367 |
+
loss_else.mean()
|
| 368 |
+
* args.model_loss_params.train_loss_scales[train_loss_idx]
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
loss = sum([value for key, value in dict_loss.items()])
|
| 372 |
+
|
| 373 |
+
############################################################
|
| 374 |
+
|
| 375 |
+
#################### 5. Back propagation ####################
|
| 376 |
+
loss.backward()
|
| 377 |
+
if args.hyperparams.gradient_clip:
|
| 378 |
+
nn.utils.clip_grad_norm_(
|
| 379 |
+
self.model.parameters(), max_norm=args.hyperparams.gradient_clip
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
losses.update(loss.item(), clean.size(0))
|
| 383 |
+
|
| 384 |
+
loss_logger["train/train loss"] = losses.avg
|
| 385 |
+
for key, value in dict_loss.items():
|
| 386 |
+
loss_logger[f"train/{key}"] = value.item()
|
| 387 |
+
|
| 388 |
+
self.optimizer.step()
|
| 389 |
+
|
| 390 |
+
self.model.zero_grad(
|
| 391 |
+
set_to_none=True
|
| 392 |
+
) # set_to_none=True is for memory saving
|
| 393 |
+
|
| 394 |
+
if args.hyperparams.ema:
|
| 395 |
+
self.model_ema.update()
|
| 396 |
+
############################################################
|
| 397 |
+
|
| 398 |
+
# ###################### 6. Plot ######################
|
| 399 |
+
|
| 400 |
+
if i % 30 == 0:
|
| 401 |
+
# loss print for multiple loss function
|
| 402 |
+
multiple_score = torch.Tensor(
|
| 403 |
+
[value for key, value in loss_logger.items()]
|
| 404 |
+
).to(args.gpu)
|
| 405 |
+
gathered_score_list = [
|
| 406 |
+
torch.ones_like(multiple_score)
|
| 407 |
+
for _ in range(dist.get_world_size())
|
| 408 |
+
]
|
| 409 |
+
dist.all_gather(gathered_score_list, multiple_score)
|
| 410 |
+
gathered_score = torch.mean(
|
| 411 |
+
torch.stack(gathered_score_list, dim=0), dim=0
|
| 412 |
+
)
|
| 413 |
+
if args.gpu == 0:
|
| 414 |
+
print(f"Epoch {epoch}, step {i} / {len(self.train_loader)}")
|
| 415 |
+
temp_loss_logger = {}
|
| 416 |
+
for index, (key, value) in enumerate(loss_logger.items()):
|
| 417 |
+
temp_key = key.replace("train/", "iter-wise/")
|
| 418 |
+
temp_loss_logger[temp_key] = round(
|
| 419 |
+
gathered_score[index].item(), 6
|
| 420 |
+
)
|
| 421 |
+
print(f"{key} : {round(gathered_score[index].item(), 6)}")
|
| 422 |
+
|
| 423 |
+
single_score = torch.Tensor([losses.avg]).to(args.gpu)
|
| 424 |
+
|
| 425 |
+
gathered_score_list = [
|
| 426 |
+
torch.ones_like(single_score) for _ in range(dist.get_world_size())
|
| 427 |
+
]
|
| 428 |
+
dist.all_gather(gathered_score_list, single_score)
|
| 429 |
+
gathered_score = torch.mean(torch.cat(gathered_score_list)).item()
|
| 430 |
+
if args.gpu == 0:
|
| 431 |
+
self.train_losses.append(gathered_score)
|
| 432 |
+
if args.wandb_params.use_wandb:
|
| 433 |
+
loss_logger["train/train loss"] = single_score
|
| 434 |
+
loss_logger["train/epoch"] = epoch
|
| 435 |
+
wandb.log(loss_logger)
|
| 436 |
+
############################################################
|
| 437 |
+
|
| 438 |
+
def multi_validate(self, args, epoch):
|
| 439 |
+
if args.gpu == 0:
|
| 440 |
+
print(f"Epoch {epoch} Validation session!")
|
| 441 |
+
|
| 442 |
+
losses = utils.AverageMeter()
|
| 443 |
+
|
| 444 |
+
loss_logger = {}
|
| 445 |
+
|
| 446 |
+
self.model.eval()
|
| 447 |
+
|
| 448 |
+
with torch.no_grad():
|
| 449 |
+
for i, values in enumerate(self.valid_loader, start=1):
|
| 450 |
+
mixture, clean, song_name, *valid_vars = values
|
| 451 |
+
|
| 452 |
+
mixture = mixture.cuda(args.gpu, non_blocking=True)
|
| 453 |
+
clean = clean.cuda(args.gpu, non_blocking=True)
|
| 454 |
+
target = clean
|
| 455 |
+
|
| 456 |
+
dict_loss = {}
|
| 457 |
+
if not args.data_params.singleset_num_frames:
|
| 458 |
+
if args.hyperparams.ema:
|
| 459 |
+
estimates, *estimates_vars = self.model_ema(mixture)
|
| 460 |
+
else:
|
| 461 |
+
estimates, *estimates_vars = self.model(mixture)
|
| 462 |
+
if args.task_params.dataset == "delimit":
|
| 463 |
+
estimates = estimates_vars[0]
|
| 464 |
+
|
| 465 |
+
estimates = estimates[..., : clean.size(-1)]
|
| 466 |
+
|
| 467 |
+
else: # use SingleTrackSet
|
| 468 |
+
db = SingleTrackSet(
|
| 469 |
+
mixture[0],
|
| 470 |
+
hop_length=args.data_params.nhop,
|
| 471 |
+
num_frame=args.data_params.singleset_num_frames,
|
| 472 |
+
target_name=args.task_params.target,
|
| 473 |
+
)
|
| 474 |
+
separated = []
|
| 475 |
+
|
| 476 |
+
for item in db:
|
| 477 |
+
|
| 478 |
+
if args.hyperparams.ema:
|
| 479 |
+
estimates, *estimates_vars = self.model_ema(
|
| 480 |
+
item.unsqueeze(0).to(args.gpu)
|
| 481 |
+
)
|
| 482 |
+
else:
|
| 483 |
+
estimates, *estimates_vars = self.model(
|
| 484 |
+
item.unsqueeze(0).to(args.gpu)
|
| 485 |
+
)
|
| 486 |
+
|
| 487 |
+
if args.task_params.dataset == "delimit":
|
| 488 |
+
estimates = estimates_vars[0]
|
| 489 |
+
|
| 490 |
+
separated.append(
|
| 491 |
+
estimates_vars[0][
|
| 492 |
+
..., db.trim_length : -db.trim_length
|
| 493 |
+
].clone()
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
estimates = torch.cat(separated, dim=-1)
|
| 497 |
+
estimates = estimates[..., : target.shape[-1]]
|
| 498 |
+
|
| 499 |
+
for valid_loss_idx, single_valid_loss_func in enumerate(
|
| 500 |
+
args.model_loss_params.valid_loss_func
|
| 501 |
+
):
|
| 502 |
+
loss_else = self.criterion[single_valid_loss_func](
|
| 503 |
+
estimates,
|
| 504 |
+
target,
|
| 505 |
+
)
|
| 506 |
+
dict_loss[single_valid_loss_func] = (
|
| 507 |
+
loss_else.mean()
|
| 508 |
+
* args.model_loss_params.valid_loss_scales[valid_loss_idx]
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
loss = sum([value for key, value in dict_loss.items()])
|
| 512 |
+
|
| 513 |
+
losses.update(loss.item(), clean.size(0))
|
| 514 |
+
|
| 515 |
+
list_sum_count = torch.Tensor([losses.sum, losses.count]).to(args.gpu)
|
| 516 |
+
list_gathered_sum_count = [
|
| 517 |
+
torch.ones_like(list_sum_count) for _ in range(dist.get_world_size())
|
| 518 |
+
]
|
| 519 |
+
dist.all_gather(list_gathered_sum_count, list_sum_count)
|
| 520 |
+
gathered_score = reduce(
|
| 521 |
+
torch.stack(list_gathered_sum_count), "s c -> c", "sum"
|
| 522 |
+
) # s: sum of losses.sum, c: sum of losses.count
|
| 523 |
+
gathered_score = (gathered_score[0] / gathered_score[1]).item()
|
| 524 |
+
|
| 525 |
+
loss_logger["valid/valid loss"] = gathered_score
|
| 526 |
+
for key, value in dict_loss.items():
|
| 527 |
+
loss_logger[f"valid/{key}"] = value.item()
|
| 528 |
+
|
| 529 |
+
if args.hyperparams.lr_scheduler == "step_lr":
|
| 530 |
+
self.scheduler.step(gathered_score)
|
| 531 |
+
elif args.hyperparams.lr_scheduler == "cos_warmup":
|
| 532 |
+
self.scheduler.step(epoch)
|
| 533 |
+
else:
|
| 534 |
+
self.scheduler.step(gathered_score)
|
| 535 |
+
|
| 536 |
+
if args.wandb_params.use_wandb and args.gpu == 0:
|
| 537 |
+
loss_logger["valid/epoch"] = epoch
|
| 538 |
+
wandb.log(loss_logger)
|
| 539 |
+
|
| 540 |
+
if args.gpu == 0:
|
| 541 |
+
self.valid_losses.append(gathered_score)
|
| 542 |
+
|
| 543 |
+
self.stop = self.es.step(gathered_score)
|
| 544 |
+
|
| 545 |
+
print(f"Epoch {epoch}, validation loss : {round(gathered_score, 6)}")
|
| 546 |
+
|
| 547 |
+
plt.plot(self.train_losses, label="train loss")
|
| 548 |
+
plt.plot(self.valid_losses, label="valid loss")
|
| 549 |
+
plt.legend(loc="upper right")
|
| 550 |
+
plt.savefig(f"{args.output}/loss_graph_{args.task_params.target}.png")
|
| 551 |
+
plt.close()
|
| 552 |
+
|
| 553 |
+
save_states = {
|
| 554 |
+
"epoch": epoch,
|
| 555 |
+
"state_dict": self.model.module.state_dict()
|
| 556 |
+
if not args.hyperparams.ema
|
| 557 |
+
else self.model_ema.state_dict(),
|
| 558 |
+
"best_loss": self.es.best,
|
| 559 |
+
"optimizer": self.optimizer.state_dict(),
|
| 560 |
+
"scheduler": self.scheduler.state_dict(),
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
utils.save_checkpoint(
|
| 564 |
+
save_states,
|
| 565 |
+
state_dict_only=gathered_score == self.es.best,
|
| 566 |
+
path=args.output,
|
| 567 |
+
target=args.task_params.target,
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
self.train_times.append(time.time() - self.end)
|
| 571 |
+
|
| 572 |
+
if gathered_score == self.es.best:
|
| 573 |
+
self.best_epoch = epoch
|
| 574 |
+
|
| 575 |
+
# save params
|
| 576 |
+
params = {
|
| 577 |
+
"epochs_trained": epoch,
|
| 578 |
+
"args": args.toDict(),
|
| 579 |
+
"best_loss": self.es.best,
|
| 580 |
+
"best_epoch": self.best_epoch,
|
| 581 |
+
"train_loss_history": self.train_losses,
|
| 582 |
+
"valid_loss_history": self.valid_losses,
|
| 583 |
+
"train_time_history": self.train_times,
|
| 584 |
+
"num_bad_epochs": self.es.num_bad_epochs,
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
with open(
|
| 588 |
+
f"{args.output}/{args.task_params.target}.json", "w"
|
| 589 |
+
) as outfile:
|
| 590 |
+
outfile.write(json.dumps(params, indent=4, sort_keys=True))
|
| 591 |
+
|
| 592 |
+
self.train_times.append(time.time() - self.end)
|
| 593 |
+
print(
|
| 594 |
+
f"Epoch {epoch} train completed. Took {round(self.train_times[-1], 3)} seconds"
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
def resume(self, args):
|
| 598 |
+
print(f"Resume checkpoint from: {args.dir_params.resume}:")
|
| 599 |
+
loc = f"cuda:{args.gpu}"
|
| 600 |
+
checkpoint_path = f"{args.dir_params.resume}/{args.task_params.target}"
|
| 601 |
+
with open(f"{checkpoint_path}.json", "r") as stream:
|
| 602 |
+
results = json.load(stream)
|
| 603 |
+
checkpoint = torch.load(f"{checkpoint_path}.chkpnt", map_location=loc)
|
| 604 |
+
|
| 605 |
+
if args.hyperparams.ema:
|
| 606 |
+
self.model_ema.load_state_dict(checkpoint["state_dict"])
|
| 607 |
+
else:
|
| 608 |
+
self.model.load_state_dict(checkpoint["state_dict"])
|
| 609 |
+
self.optimizer.load_state_dict(checkpoint["optimizer"])
|
| 610 |
+
|
| 611 |
+
if (
|
| 612 |
+
args.dir_params.continual_train
|
| 613 |
+
): # we want to use a pre-trained model but not want to use lr_scheduler history.
|
| 614 |
+
for param_group in self.optimizer.param_groups:
|
| 615 |
+
param_group["lr"] = args.hyperparams.lr
|
| 616 |
+
else:
|
| 617 |
+
self.scheduler.load_state_dict(checkpoint["scheduler"])
|
| 618 |
+
self.es.best = results["best_loss"]
|
| 619 |
+
self.es.num_bad_epochs = results["num_bad_epochs"]
|
| 620 |
+
|
| 621 |
+
self.start_epoch = results["epochs_trained"]
|
| 622 |
+
self.train_losses = results["train_loss_history"]
|
| 623 |
+
self.valid_losses = results["valid_loss_history"]
|
| 624 |
+
self.train_times = results["train_time_history"]
|
| 625 |
+
self.best_epoch = results["best_epoch"]
|
| 626 |
+
if args.sys_params.rank % args.ngpus_per_node == 0:
|
| 627 |
+
print(
|
| 628 |
+
f"=> loaded checkpoint {checkpoint_path} (epoch {results['epochs_trained']})"
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
def cal_loss(self, args, loss_input):
|
| 632 |
+
loss_dict = {}
|
| 633 |
+
for key, value in loss_input.items():
|
| 634 |
+
loss_dict[key] = self.criterion[key](*value)
|
| 635 |
+
|
| 636 |
+
return loss_dict
|
| 637 |
+
|
| 638 |
+
def cal_multiple_losses(self, args, dict_loss_name_input):
|
| 639 |
+
loss_dict = {}
|
| 640 |
+
for loss_name, loss_input in dict_loss_name_input.items():
|
| 641 |
+
loss_dict[loss_name] = self.cal_loss(args, loss_input)
|
| 642 |
+
|
| 643 |
+
return loss_dict
|
test_ddp.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# To be honest... this is not ddp.
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import argparse
|
| 5 |
+
import glob
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import tqdm
|
| 9 |
+
import musdb
|
| 10 |
+
import librosa
|
| 11 |
+
import soundfile as sf
|
| 12 |
+
import pyloudnorm as pyln
|
| 13 |
+
from dotmap import DotMap
|
| 14 |
+
|
| 15 |
+
from models import load_model_with_args
|
| 16 |
+
from separate_func import (
|
| 17 |
+
conv_tasnet_separate,
|
| 18 |
+
)
|
| 19 |
+
from utils import str2bool, db2linear
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
tqdm.monitor_interval = 0
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def separate_track_with_model(
|
| 26 |
+
args, model, device, track_audio, track_name, meter, augmented_gain
|
| 27 |
+
):
|
| 28 |
+
with torch.no_grad():
|
| 29 |
+
if (
|
| 30 |
+
args.model_loss_params.architecture == "conv_tasnet_mask_on_output"
|
| 31 |
+
or args.model_loss_params.architecture == "conv_tasnet"
|
| 32 |
+
):
|
| 33 |
+
estimates = conv_tasnet_separate(
|
| 34 |
+
args,
|
| 35 |
+
model,
|
| 36 |
+
device,
|
| 37 |
+
track_audio,
|
| 38 |
+
track_name,
|
| 39 |
+
meter=meter,
|
| 40 |
+
augmented_gain=augmented_gain,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
return estimates
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def main():
|
| 47 |
+
parser = argparse.ArgumentParser(description="model test.py")
|
| 48 |
+
|
| 49 |
+
parser.add_argument("--target", type=str, default="all")
|
| 50 |
+
parser.add_argument("--data_root", type=str, default="/path/to/musdb_XL")
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--use_musdb",
|
| 53 |
+
type=str2bool,
|
| 54 |
+
default=True,
|
| 55 |
+
help="Use musdb test data or just want to inference other samples?",
|
| 56 |
+
)
|
| 57 |
+
parser.add_argument("--exp_name", type=str, default="delimit_6_s')
|
| 58 |
+
parser.add_argument("--manual_output_name", type=str, default=None)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--output_directory", type=str, default="/path/to/results"
|
| 61 |
+
)
|
| 62 |
+
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
| 63 |
+
parser.add_arugment("--save_name_as_target", type=str2bool, default=True)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--loudnorm_input_lufs",
|
| 66 |
+
type=float,
|
| 67 |
+
default=None,
|
| 68 |
+
help="If you want to use loudnorm, input target lufs",
|
| 69 |
+
)
|
| 70 |
+
parser.add_argument(
|
| 71 |
+
"--use_singletrackset",
|
| 72 |
+
type=str2bool,
|
| 73 |
+
default=False,
|
| 74 |
+
help="Use SingleTrackSet for X-UMX",
|
| 75 |
+
)
|
| 76 |
+
parser.add_argument(
|
| 77 |
+
"--best_model",
|
| 78 |
+
type=str2bool,
|
| 79 |
+
default=True,
|
| 80 |
+
help="Use best model or lastly saved model",
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"--save_output_loudnorm",
|
| 84 |
+
type=float,
|
| 85 |
+
default=None,
|
| 86 |
+
help="Save loudness normalized outputs or not. If you want to save, input target loudness",
|
| 87 |
+
)
|
| 88 |
+
parser.add_argument(
|
| 89 |
+
"--save_mixed_output",
|
| 90 |
+
type=float,
|
| 91 |
+
default=None,
|
| 92 |
+
help="Save original+delimited-estimation mixed output with a ratio of default 0.5 (orginal) and 1 - 0.5 (estimation)",
|
| 93 |
+
)
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--save_16k_mono",
|
| 96 |
+
type=str2bool,
|
| 97 |
+
default=False,
|
| 98 |
+
help="Save 16k mono wav files for FAD evaluation.",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--save_histogram",
|
| 102 |
+
type=str2bool,
|
| 103 |
+
default=False,
|
| 104 |
+
help="Save histogram of the output. Only valid when the task is 'delimit'",
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
args, _ = parser.parse_known_args()
|
| 108 |
+
|
| 109 |
+
args.output_dir = f"{args.output_directory}/checkpoint/{args.exp_name}"
|
| 110 |
+
with open(f"{args.output_dir}/{args.target}.json", "r") as f:
|
| 111 |
+
args_dict = json.load(f)
|
| 112 |
+
args_dict = DotMap(args_dict)
|
| 113 |
+
|
| 114 |
+
for key, value in args_dict["args"].items():
|
| 115 |
+
if key in list(vars(args).keys()):
|
| 116 |
+
pass
|
| 117 |
+
else:
|
| 118 |
+
setattr(args, key, value)
|
| 119 |
+
|
| 120 |
+
args.test_output_dir = f"{args.output_directory}/test/{args.exp_name}"
|
| 121 |
+
|
| 122 |
+
if args.manual_output_name != None:
|
| 123 |
+
args.test_output_dir = f"{args.output_directory}/test/{args.manual_output_name}"
|
| 124 |
+
os.makedirs(args.test_output_dir, exist_ok=True)
|
| 125 |
+
|
| 126 |
+
device = torch.device(
|
| 127 |
+
"cuda" if torch.cuda.is_available() and args.use_gpu else "cpu"
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
###################### Define Models ######################
|
| 131 |
+
our_model = load_model_with_args(args)
|
| 132 |
+
our_model = our_model.to(device)
|
| 133 |
+
print(our_model)
|
| 134 |
+
pytorch_total_params = sum(
|
| 135 |
+
p.numel() for p in our_model.parameters() if p.requires_grad
|
| 136 |
+
)
|
| 137 |
+
print("Total number of parameters", pytorch_total_params)
|
| 138 |
+
# Future work => Torchinfo would be better for this purpose.
|
| 139 |
+
|
| 140 |
+
if args.best_model:
|
| 141 |
+
target_model_path = f"{args.output_dir}/{args.target}.pth"
|
| 142 |
+
checkpoint = torch.load(target_model_path, map_location=device)
|
| 143 |
+
our_model.load_state_dict(checkpoint)
|
| 144 |
+
else: # when using lastly saved model
|
| 145 |
+
target_model_path = f"{args.output_dir}/{args.target}.chkpnt"
|
| 146 |
+
checkpoint = torch.load(target_model_path, map_location=device)
|
| 147 |
+
our_model.load_state_dict(checkpoint["state_dict"])
|
| 148 |
+
|
| 149 |
+
our_model.eval()
|
| 150 |
+
|
| 151 |
+
meter = pyln.Meter(44100)
|
| 152 |
+
|
| 153 |
+
if args.use_musdb:
|
| 154 |
+
test_tracks = musdb.DB(root=args.data_root, subsets="test", is_wav=True)
|
| 155 |
+
|
| 156 |
+
for track in tqdm.tqdm(test_tracks):
|
| 157 |
+
track_name = track.name
|
| 158 |
+
track_audio = track.audio
|
| 159 |
+
|
| 160 |
+
orig_audio = track_audio.copy()
|
| 161 |
+
|
| 162 |
+
augmented_gain = None
|
| 163 |
+
print("Now De-limiting : ", track_name)
|
| 164 |
+
|
| 165 |
+
if args.loudnorm_input_lufs: # If you want to use loud-normalized input
|
| 166 |
+
track_lufs = meter.integrated_loudness(track_audio)
|
| 167 |
+
augmented_gain = args.loudnorm_input_lufs - track_lufs
|
| 168 |
+
track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
|
| 169 |
+
|
| 170 |
+
track_audio = (
|
| 171 |
+
torch.as_tensor(track_audio.T, dtype=torch.float32)
|
| 172 |
+
.unsqueeze(0)
|
| 173 |
+
.to(device)
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
estimates = separate_track_with_model(
|
| 177 |
+
args, our_model, device, track_audio, track_name, meter, augmented_gain
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
if args.save_mixed_output:
|
| 181 |
+
orig_audio = orig_audio.T
|
| 182 |
+
track_lufs = meter.integrated_loudness(orig_audio.T)
|
| 183 |
+
augmented_gain = args.save_output_loudnorm - track_lufs
|
| 184 |
+
orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)
|
| 185 |
+
|
| 186 |
+
mixed_output = orig_audio * args.save_mixed_output + estimates * (
|
| 187 |
+
1 - args.save_mixed_output
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
sf.write(
|
| 191 |
+
f"{args.test_output_dir}/{track_name}/{str(args.save_mixed_output)}_mixed.wav",
|
| 192 |
+
mixed_output.T,
|
| 193 |
+
args.data_params.sample_rate,
|
| 194 |
+
)
|
| 195 |
+
else:
|
| 196 |
+
test_tracks = glob.glob(f"{args.data_root}/*.wav") + glob.glob(
|
| 197 |
+
f"{args.data_root}/*.mp3"
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
for track in tqdm.tqdm(test_tracks):
|
| 201 |
+
track_name = os.path.basename(track).replace(".wav", "").replace(".mp3", "")
|
| 202 |
+
track_audio, sr = librosa.load(
|
| 203 |
+
track, sr=None, mono=False
|
| 204 |
+
) # sr should be 44100
|
| 205 |
+
|
| 206 |
+
orig_audio = track_audio.copy()
|
| 207 |
+
|
| 208 |
+
if sr != 44100:
|
| 209 |
+
raise ValueError("Sample rate should be 44100")
|
| 210 |
+
augmented_gain = None
|
| 211 |
+
print("Now De-limiting : ", track_name)
|
| 212 |
+
|
| 213 |
+
if args.loudnorm_input_lufs: # If you want to use loud-normalized input
|
| 214 |
+
track_lufs = meter.integrated_loudness(track_audio.T)
|
| 215 |
+
augmented_gain = args.loudnorm_input_lufs - track_lufs
|
| 216 |
+
track_audio = track_audio * db2linear(augmented_gain, eps=0.0)
|
| 217 |
+
|
| 218 |
+
track_audio = (
|
| 219 |
+
torch.as_tensor(track_audio, dtype=torch.float32)
|
| 220 |
+
.unsqueeze(0)
|
| 221 |
+
.to(device)
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
estimates = separate_track_with_model(
|
| 225 |
+
args, our_model, device, track_audio, track_name, meter, augmented_gain
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
if args.save_mixed_output:
|
| 229 |
+
track_lufs = meter.integrated_loudness(orig_audio.T)
|
| 230 |
+
augmented_gain = args.save_output_loudnorm - track_lufs
|
| 231 |
+
orig_audio = orig_audio * db2linear(augmented_gain, eps=0.0)
|
| 232 |
+
|
| 233 |
+
mixed_output = orig_audio * args.save_mixed_output + estimates * (
|
| 234 |
+
1 - args.save_mixed_output
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
sf.write(
|
| 238 |
+
f"{args.test_output_dir}/{track_name}/{track_name}_mixed.wav",
|
| 239 |
+
mixed_output.T,
|
| 240 |
+
args.data_params.sample_rate,
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
if __name__ == "__main__":
|
| 245 |
+
main()
|
train_ddp.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.multiprocessing as mp
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
import wandb
|
| 8 |
+
|
| 9 |
+
from solver_ddp import Solver
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def train(args):
|
| 13 |
+
print("hello")
|
| 14 |
+
solver = Solver()
|
| 15 |
+
|
| 16 |
+
ngpus_per_node = int(torch.cuda.device_count() / args.sys_params.n_nodes)
|
| 17 |
+
print(f"use {ngpus_per_node} gpu machine")
|
| 18 |
+
args.sys_params.world_size = ngpus_per_node * args.sys_params.n_nodes
|
| 19 |
+
mp.spawn(worker, nprocs=ngpus_per_node, args=(solver, ngpus_per_node, args))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def worker(gpu, solver, ngpus_per_node, args):
|
| 23 |
+
args.sys_params.rank = args.sys_params.rank * ngpus_per_node + gpu
|
| 24 |
+
dist.init_process_group(
|
| 25 |
+
backend="nccl",
|
| 26 |
+
world_size=args.sys_params.world_size,
|
| 27 |
+
init_method="env://",
|
| 28 |
+
rank=args.sys_params.rank,
|
| 29 |
+
)
|
| 30 |
+
args.gpu = gpu
|
| 31 |
+
args.ngpus_per_node = ngpus_per_node
|
| 32 |
+
|
| 33 |
+
solver.set_gpu(args)
|
| 34 |
+
|
| 35 |
+
start_epoch = solver.start_epoch
|
| 36 |
+
|
| 37 |
+
if args.dir_params.resume:
|
| 38 |
+
start_epoch = start_epoch + 1
|
| 39 |
+
|
| 40 |
+
for epoch in range(start_epoch, args.hyperparams.epochs + 1):
|
| 41 |
+
|
| 42 |
+
solver.train_sampler.set_epoch(epoch)
|
| 43 |
+
solver.train(args, epoch)
|
| 44 |
+
|
| 45 |
+
time.sleep(1)
|
| 46 |
+
|
| 47 |
+
solver.multi_validate(args, epoch)
|
| 48 |
+
|
| 49 |
+
if solver.stop == True:
|
| 50 |
+
print("Apply Early Stopping")
|
| 51 |
+
if args.wandb_params.use_wandb:
|
| 52 |
+
wandb.finish()
|
| 53 |
+
sys.exit()
|
| 54 |
+
|
| 55 |
+
if args.wandb_params.use_wandb:
|
| 56 |
+
wandb.finish()
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .read_wave_utils import (
|
| 2 |
+
load_wav_arbitrary_position_mono,
|
| 3 |
+
load_wav_specific_position_mono,
|
| 4 |
+
load_wav_arbitrary_position_stereo,
|
| 5 |
+
load_wav_specific_position_stereo,
|
| 6 |
+
)
|
| 7 |
+
from .loudness_utils import (
|
| 8 |
+
linear2db,
|
| 9 |
+
db2linear,
|
| 10 |
+
normalize_mag_spec,
|
| 11 |
+
denormalize_mag_spec,
|
| 12 |
+
loudness_match_and_norm,
|
| 13 |
+
loudness_normal_match_and_norm,
|
| 14 |
+
loudness_normal_match_and_norm_output_louder_first,
|
| 15 |
+
loudnorm,
|
| 16 |
+
)
|
| 17 |
+
from .logging import save_img_and_npy, save_checkpoint, AverageMeter, EarlyStopping
|
| 18 |
+
from .lr_scheduler import CosineAnnealingWarmUpRestarts
|
| 19 |
+
from .train_utils import worker_init_fn, str2bool, get_config
|
utils/logging.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import matplotlib
|
| 6 |
+
|
| 7 |
+
matplotlib.use("Agg")
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def save_img_and_npy(path, matrix):
|
| 12 |
+
plt.imsave(path + ".png", matrix, origin="lower")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def save_checkpoint(state, state_dict_only, path, target):
|
| 16 |
+
torch.save(state, os.path.join(path, target + ".chkpnt"))
|
| 17 |
+
if state_dict_only:
|
| 18 |
+
# save just the weights
|
| 19 |
+
torch.save(state["state_dict"], os.path.join(path, target + ".pth"))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class AverageMeter(object):
|
| 23 |
+
"""Computes and stores the average and current value"""
|
| 24 |
+
|
| 25 |
+
def __init__(self):
|
| 26 |
+
self.reset()
|
| 27 |
+
|
| 28 |
+
def reset(self):
|
| 29 |
+
self.val = 0
|
| 30 |
+
self.avg = 0
|
| 31 |
+
self.sum = 0
|
| 32 |
+
self.count = 0
|
| 33 |
+
|
| 34 |
+
def update(self, val, n=1):
|
| 35 |
+
self.val = val
|
| 36 |
+
self.sum += val * n
|
| 37 |
+
self.count += n
|
| 38 |
+
self.avg = self.sum / self.count
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class EarlyStopping(object):
|
| 42 |
+
def __init__(self, mode="min", min_delta=0, patience=10):
|
| 43 |
+
self.mode = mode
|
| 44 |
+
self.min_delta = min_delta
|
| 45 |
+
self.patience = patience
|
| 46 |
+
self.best = None
|
| 47 |
+
self.num_bad_epochs = 0
|
| 48 |
+
self.is_better = None
|
| 49 |
+
self._init_is_better(mode, min_delta)
|
| 50 |
+
|
| 51 |
+
if patience == 0:
|
| 52 |
+
self.is_better = lambda a, b: True
|
| 53 |
+
|
| 54 |
+
def step(self, metrics):
|
| 55 |
+
if self.best is None:
|
| 56 |
+
self.best = metrics
|
| 57 |
+
return False
|
| 58 |
+
|
| 59 |
+
if np.isnan(metrics):
|
| 60 |
+
return True
|
| 61 |
+
|
| 62 |
+
if self.is_better(metrics, self.best):
|
| 63 |
+
self.num_bad_epochs = 0
|
| 64 |
+
self.best = metrics
|
| 65 |
+
else:
|
| 66 |
+
self.num_bad_epochs += 1
|
| 67 |
+
|
| 68 |
+
if self.num_bad_epochs >= self.patience:
|
| 69 |
+
return True
|
| 70 |
+
|
| 71 |
+
return False
|
| 72 |
+
|
| 73 |
+
def _init_is_better(self, mode, min_delta):
|
| 74 |
+
if mode not in {"min", "max"}:
|
| 75 |
+
raise ValueError("mode " + mode + " is unknown!")
|
| 76 |
+
if mode == "min":
|
| 77 |
+
self.is_better = lambda a, best: a < best - min_delta
|
| 78 |
+
if mode == "max":
|
| 79 |
+
self.is_better = lambda a, best: a > best + min_delta
|
utils/loudness_utils.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def linear2db(x, eps=1e-5, scale=20):
|
| 8 |
+
return scale * np.log10(x + eps)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def db2linear(x, eps=1e-5, scale=20):
|
| 12 |
+
return 10 ** (x / scale) - eps
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def normalize_mag_spec(S, min_level_db=-100.0):
|
| 16 |
+
return torch.clamp((S - min_level_db) / -min_level_db, min=0.0, max=1.0)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def denormalize_mag_spec(S, min_level_db=-100.0):
|
| 20 |
+
return torch.clamp(S, min=0.0, max=1.0) * -min_level_db + min_level_db
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def loudness_match_and_norm(audio1, audio2, meter):
|
| 24 |
+
lufs_1 = meter.integrated_loudness(audio1)
|
| 25 |
+
lufs_2 = meter.integrated_loudness(audio2)
|
| 26 |
+
|
| 27 |
+
if np.isinf(lufs_1) or np.isinf(lufs_2):
|
| 28 |
+
return audio1, audio2
|
| 29 |
+
else:
|
| 30 |
+
audio2 = audio2 * db2linear(lufs_1 - lufs_2)
|
| 31 |
+
|
| 32 |
+
return audio1, audio2
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def loudness_normal_match_and_norm(audio1, audio2, meter):
|
| 36 |
+
lufs_1 = meter.integrated_loudness(audio1)
|
| 37 |
+
lufs_2 = meter.integrated_loudness(audio2)
|
| 38 |
+
|
| 39 |
+
if np.isinf(lufs_1) or np.isinf(lufs_2):
|
| 40 |
+
return audio1, audio2
|
| 41 |
+
else:
|
| 42 |
+
target_lufs = random.normalvariate(lufs_1, 6.0)
|
| 43 |
+
audio2 = audio2 * db2linear(target_lufs - lufs_2)
|
| 44 |
+
|
| 45 |
+
return audio1, audio2
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def loudness_normal_match_and_norm_output_louder_first(audio1, audio2, meter):
|
| 49 |
+
lufs_1 = meter.integrated_loudness(audio1)
|
| 50 |
+
lufs_2 = meter.integrated_loudness(audio2)
|
| 51 |
+
|
| 52 |
+
if np.isinf(lufs_1) or np.isinf(lufs_2):
|
| 53 |
+
return audio1, audio2
|
| 54 |
+
else:
|
| 55 |
+
target_lufs = random.normalvariate(
|
| 56 |
+
lufs_1 - 2.0, 2.0
|
| 57 |
+
) # we want audio1 to be louder than audio2 about target_lufs_diff
|
| 58 |
+
audio2 = audio2 * db2linear(target_lufs - lufs_2)
|
| 59 |
+
|
| 60 |
+
return audio1, audio2
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def loudnorm(audio, target_lufs, meter, eps=1e-5):
|
| 64 |
+
lufs = meter.integrated_loudness(audio)
|
| 65 |
+
if np.isinf(lufs):
|
| 66 |
+
return audio, 0.0
|
| 67 |
+
else:
|
| 68 |
+
adjusted_gain = target_lufs - lufs
|
| 69 |
+
audio = audio * db2linear(adjusted_gain, eps)
|
| 70 |
+
|
| 71 |
+
return audio, adjusted_gain
|
utils/lr_scheduler.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class CosineAnnealingWarmUpRestarts(_LRScheduler):
|
| 7 |
+
def __init__(
|
| 8 |
+
self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1.0, last_epoch=-1
|
| 9 |
+
):
|
| 10 |
+
if T_0 <= 0 or not isinstance(T_0, int):
|
| 11 |
+
raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
|
| 12 |
+
if T_mult < 1 or not isinstance(T_mult, int):
|
| 13 |
+
raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
|
| 14 |
+
if T_up < 0 or not isinstance(T_up, int):
|
| 15 |
+
raise ValueError("Expected positive integer T_up, but got {}".format(T_up))
|
| 16 |
+
self.T_0 = T_0
|
| 17 |
+
self.T_mult = T_mult
|
| 18 |
+
self.base_eta_max = eta_max
|
| 19 |
+
self.eta_max = eta_max
|
| 20 |
+
self.T_up = T_up
|
| 21 |
+
self.T_i = T_0
|
| 22 |
+
self.gamma = gamma
|
| 23 |
+
self.cycle = 0
|
| 24 |
+
self.T_cur = last_epoch
|
| 25 |
+
super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)
|
| 26 |
+
|
| 27 |
+
def get_lr(self):
|
| 28 |
+
if self.T_cur == -1:
|
| 29 |
+
return self.base_lrs
|
| 30 |
+
elif self.T_cur < self.T_up:
|
| 31 |
+
return [
|
| 32 |
+
(self.eta_max - base_lr) * self.T_cur / self.T_up + base_lr
|
| 33 |
+
for base_lr in self.base_lrs
|
| 34 |
+
]
|
| 35 |
+
else:
|
| 36 |
+
return [
|
| 37 |
+
base_lr
|
| 38 |
+
+ (self.eta_max - base_lr)
|
| 39 |
+
* (
|
| 40 |
+
1
|
| 41 |
+
+ math.cos(
|
| 42 |
+
math.pi * (self.T_cur - self.T_up) / (self.T_i - self.T_up)
|
| 43 |
+
)
|
| 44 |
+
)
|
| 45 |
+
/ 2
|
| 46 |
+
for base_lr in self.base_lrs
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
def step(self, epoch=None):
|
| 50 |
+
if epoch is None:
|
| 51 |
+
epoch = self.last_epoch + 1
|
| 52 |
+
self.T_cur = self.T_cur + 1
|
| 53 |
+
if self.T_cur >= self.T_i:
|
| 54 |
+
self.cycle += 1
|
| 55 |
+
self.T_cur = self.T_cur - self.T_i
|
| 56 |
+
self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
|
| 57 |
+
else:
|
| 58 |
+
if epoch >= self.T_0:
|
| 59 |
+
if self.T_mult == 1:
|
| 60 |
+
self.T_cur = epoch % self.T_0
|
| 61 |
+
self.cycle = epoch // self.T_0
|
| 62 |
+
else:
|
| 63 |
+
n = int(
|
| 64 |
+
math.log(
|
| 65 |
+
(epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult
|
| 66 |
+
)
|
| 67 |
+
)
|
| 68 |
+
self.cycle = n
|
| 69 |
+
self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (
|
| 70 |
+
self.T_mult - 1
|
| 71 |
+
)
|
| 72 |
+
self.T_i = self.T_0 * self.T_mult ** (n)
|
| 73 |
+
else:
|
| 74 |
+
self.T_i = self.T_0
|
| 75 |
+
self.T_cur = epoch
|
| 76 |
+
|
| 77 |
+
self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
|
| 78 |
+
self.last_epoch = math.floor(epoch)
|
| 79 |
+
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
|
| 80 |
+
param_group["lr"] = lr
|
utils/read_wave_utils.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import librosa
|
| 6 |
+
import torchaudio
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def load_wav_arbitrary_position_mono(filename, sample_rate, seq_duration):
|
| 10 |
+
# mono
|
| 11 |
+
# seq_duration[second]
|
| 12 |
+
length = torchaudio.info(filename).num_frames
|
| 13 |
+
|
| 14 |
+
read_length = librosa.time_to_samples(seq_duration, sr=sample_rate)
|
| 15 |
+
if length > read_length:
|
| 16 |
+
random_start = random.randint(0, int(length - read_length - 1)) / sample_rate
|
| 17 |
+
X, sr = librosa.load(
|
| 18 |
+
filename, sr=None, offset=random_start, duration=seq_duration
|
| 19 |
+
)
|
| 20 |
+
else:
|
| 21 |
+
random_start = 0
|
| 22 |
+
total_pad_length = read_length - length
|
| 23 |
+
X, sr = librosa.load(filename, sr=None, offset=0, duration=seq_duration)
|
| 24 |
+
pad_left = random.randint(0, total_pad_length)
|
| 25 |
+
X = np.pad(X, (pad_left, total_pad_length - pad_left))
|
| 26 |
+
|
| 27 |
+
return X
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_wav_specific_position_mono(
|
| 31 |
+
filename, sample_rate, seq_duration, start_position
|
| 32 |
+
):
|
| 33 |
+
# mono
|
| 34 |
+
# seq_duration[second]
|
| 35 |
+
# start_position[second]
|
| 36 |
+
length = torchaudio.info(filename).num_frames
|
| 37 |
+
read_length = librosa.time_to_samples(seq_duration, sr=sample_rate)
|
| 38 |
+
|
| 39 |
+
start_pos_sec = max(
|
| 40 |
+
start_position, 0
|
| 41 |
+
) # if start_position is minus, then start from 0.
|
| 42 |
+
start_pos_sample = librosa.time_to_samples(start_pos_sec, sr=sample_rate)
|
| 43 |
+
|
| 44 |
+
if (
|
| 45 |
+
length <= start_pos_sample
|
| 46 |
+
): # if start position exceeds audio length, then start from 0.
|
| 47 |
+
start_pos_sec = 0
|
| 48 |
+
start_pos_sample = 0
|
| 49 |
+
X, sr = librosa.load(filename, sr=None, offset=start_pos_sec, duration=seq_duration)
|
| 50 |
+
|
| 51 |
+
if length < start_pos_sample + read_length:
|
| 52 |
+
X = np.pad(X, (0, (start_pos_sample + read_length) - length))
|
| 53 |
+
|
| 54 |
+
return X
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# load wav file from arbitrary positions of 16bit stereo wav file
|
| 58 |
+
def load_wav_arbitrary_position_stereo(
|
| 59 |
+
filename, sample_rate, seq_duration, return_pos=False
|
| 60 |
+
):
|
| 61 |
+
# stereo
|
| 62 |
+
# seq_duration[second]
|
| 63 |
+
length = torchaudio.info(filename).num_frames
|
| 64 |
+
read_length = librosa.time_to_samples(seq_duration, sr=sample_rate)
|
| 65 |
+
|
| 66 |
+
random_start_sample = random.randint(
|
| 67 |
+
0, int(length - math.ceil(seq_duration * sample_rate) - 1)
|
| 68 |
+
)
|
| 69 |
+
random_start_sec = librosa.samples_to_time(random_start_sample, sr=sample_rate)
|
| 70 |
+
X, sr = librosa.load(
|
| 71 |
+
filename, sr=None, mono=False, offset=random_start_sec, duration=seq_duration
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
if length < random_start_sample + read_length:
|
| 75 |
+
X = np.pad(X, ((0, 0), (0, (random_start_sample + read_length) - length)))
|
| 76 |
+
|
| 77 |
+
if return_pos:
|
| 78 |
+
return X, random_start_sec
|
| 79 |
+
else:
|
| 80 |
+
return X
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def load_wav_specific_position_stereo(
|
| 84 |
+
filename, sample_rate, seq_duration, start_position
|
| 85 |
+
):
|
| 86 |
+
# stereo
|
| 87 |
+
# seq_duration[second]
|
| 88 |
+
# start_position[second]
|
| 89 |
+
length = torchaudio.info(filename).num_frames
|
| 90 |
+
read_length = librosa.time_to_samples(seq_duration, sr=sample_rate)
|
| 91 |
+
|
| 92 |
+
start_pos_sec = max(
|
| 93 |
+
start_position, 0
|
| 94 |
+
) # if start_position is minus, then start from 0.
|
| 95 |
+
start_pos_sample = librosa.time_to_samples(start_pos_sec, sr=sample_rate)
|
| 96 |
+
|
| 97 |
+
if (
|
| 98 |
+
length <= start_pos_sample
|
| 99 |
+
): # if start position exceeds audio length, then start from 0.
|
| 100 |
+
start_pos_sec = 0
|
| 101 |
+
start_pos_sample = 0
|
| 102 |
+
X, sr = librosa.load(
|
| 103 |
+
filename, sr=None, mono=False, offset=start_pos_sec, duration=seq_duration
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
if length < start_pos_sample + read_length:
|
| 107 |
+
X = np.pad(X, ((0, 0), (0, (start_pos_sample + read_length) - length)))
|
| 108 |
+
|
| 109 |
+
return X
|
utils/train_utils.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import yaml
|
| 4 |
+
from dotmap import DotMap
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def worker_init_fn(worker_id):
|
| 9 |
+
np.random.seed(np.random.get_state()[1][0] + worker_id)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def str2bool(v):
|
| 13 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
| 14 |
+
return True
|
| 15 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
| 16 |
+
return False
|
| 17 |
+
else:
|
| 18 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_config(config_name="default"):
|
| 22 |
+
|
| 23 |
+
with open(f"./configs/{config_name}.yaml", "r") as f:
|
| 24 |
+
|
| 25 |
+
config = yaml.load(f, Loader=yaml.FullLoader)
|
| 26 |
+
config = DotMap(config)
|
| 27 |
+
return config
|
weight/all.json
ADDED
|
@@ -0,0 +1,957 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"args": {
|
| 3 |
+
"classifier_params": {
|
| 4 |
+
"chosen_source_mean": 0.7,
|
| 5 |
+
"chosen_source_std": 0.15,
|
| 6 |
+
"classifier_activation": "softmax",
|
| 7 |
+
"classifier_n_classes": 4,
|
| 8 |
+
"classifier_n_srcs": 4,
|
| 9 |
+
"freeze_when_mixit": true,
|
| 10 |
+
"melspec_power": 2.0,
|
| 11 |
+
"model_name": "hrnet_w18_small",
|
| 12 |
+
"n_mels": 128,
|
| 13 |
+
"other_source_mean": 0.3,
|
| 14 |
+
"other_source_std": 0.15,
|
| 15 |
+
"pretrained_model": false,
|
| 16 |
+
"use_one_source_prob": 0.2,
|
| 17 |
+
"use_stereo": true
|
| 18 |
+
},
|
| 19 |
+
"conv_tasnet_params": {
|
| 20 |
+
"bn_chan": 128,
|
| 21 |
+
"decoder_activation": "sigmoid",
|
| 22 |
+
"encoder_activation": "relu",
|
| 23 |
+
"hid_chan": 512,
|
| 24 |
+
"kernel_size": 128,
|
| 25 |
+
"mask_act": "relu",
|
| 26 |
+
"n_blocks": 5,
|
| 27 |
+
"n_filters": 512,
|
| 28 |
+
"n_repeats": 2,
|
| 29 |
+
"skip_chan": 128,
|
| 30 |
+
"stride": 64
|
| 31 |
+
},
|
| 32 |
+
"data_params": {
|
| 33 |
+
"custom_limiter_attack_range": null,
|
| 34 |
+
"custom_limiter_release_range": null,
|
| 35 |
+
"limitaug_custom_target_lufs": null,
|
| 36 |
+
"limitaug_custom_target_lufs_std": null,
|
| 37 |
+
"limitaug_method": "ozone",
|
| 38 |
+
"limitaug_mode": null,
|
| 39 |
+
"nb_channels": 2,
|
| 40 |
+
"nfft": 4096,
|
| 41 |
+
"nhop": 1024,
|
| 42 |
+
"random_mix": true,
|
| 43 |
+
"sample_rate": 44100,
|
| 44 |
+
"samples_per_track": 128,
|
| 45 |
+
"seq_dur": 4.0,
|
| 46 |
+
"singleset_num_frames": null,
|
| 47 |
+
"target_limitaug_custom_target_lufs": null,
|
| 48 |
+
"target_limitaug_custom_target_lufs_std": null,
|
| 49 |
+
"target_limitaug_mode": null,
|
| 50 |
+
"target_loudnorm_lufs": -14.0,
|
| 51 |
+
"use_fixed": 0.019
|
| 52 |
+
},
|
| 53 |
+
"dir_params": {
|
| 54 |
+
"continual_train": false,
|
| 55 |
+
"delimit_valid_L_root": null,
|
| 56 |
+
"delimit_valid_root": null,
|
| 57 |
+
"exp_name": "convtasnet_35",
|
| 58 |
+
"output_directory": "/data2/personal/jeon/delimit/results",
|
| 59 |
+
"ozone_root": "/data5/personal/jeon/delimit/data",
|
| 60 |
+
"pretrained_classifier": null,
|
| 61 |
+
"resume": null,
|
| 62 |
+
"root": "/data1/Music/musdb18hq"
|
| 63 |
+
},
|
| 64 |
+
"gpu": 0,
|
| 65 |
+
"hyperparams": {
|
| 66 |
+
"batch_size": 8,
|
| 67 |
+
"ema": false,
|
| 68 |
+
"epochs": 200,
|
| 69 |
+
"gradient_clip": 5.0,
|
| 70 |
+
"lr": 3e-05,
|
| 71 |
+
"lr_decay_gamma": 0.5,
|
| 72 |
+
"lr_decay_patience": 15,
|
| 73 |
+
"lr_scheduler": "step_lr",
|
| 74 |
+
"optimizer": "adamw",
|
| 75 |
+
"patience": 50,
|
| 76 |
+
"weight_decay": 0.01
|
| 77 |
+
},
|
| 78 |
+
"img_check": "/data2/personal/jeon/delimit/results/img_check/convtasnet_35",
|
| 79 |
+
"invest_unet_params": {
|
| 80 |
+
"bn_factor": 16,
|
| 81 |
+
"f_down_layers": null,
|
| 82 |
+
"first_conv_activation": "relu",
|
| 83 |
+
"input_channels": 4,
|
| 84 |
+
"internal_channels": 24,
|
| 85 |
+
"kernel_size_f": 3,
|
| 86 |
+
"kernel_size_t": 3,
|
| 87 |
+
"last_activation": "identity",
|
| 88 |
+
"min_bn_units": 16,
|
| 89 |
+
"n_blocks": 7,
|
| 90 |
+
"n_internal_layers": 5,
|
| 91 |
+
"t_down_layers": null,
|
| 92 |
+
"tfc_tdf_activation": "relu",
|
| 93 |
+
"tfc_tdf_bias": true,
|
| 94 |
+
"tif_init_mode": null
|
| 95 |
+
},
|
| 96 |
+
"model_loss_params": {
|
| 97 |
+
"architecture": "conv_tasnet_mask_on_output",
|
| 98 |
+
"efficient_mixit_threshold": null,
|
| 99 |
+
"train_loss_func": [
|
| 100 |
+
"si_sdr"
|
| 101 |
+
],
|
| 102 |
+
"train_loss_scales": [
|
| 103 |
+
1.0
|
| 104 |
+
],
|
| 105 |
+
"valid_loss_func": [
|
| 106 |
+
"si_sdr"
|
| 107 |
+
],
|
| 108 |
+
"valid_loss_scales": [
|
| 109 |
+
1.0
|
| 110 |
+
]
|
| 111 |
+
},
|
| 112 |
+
"ngpus_per_node": 1,
|
| 113 |
+
"output": "/data2/personal/jeon/delimit/results/checkpoint/convtasnet_35",
|
| 114 |
+
"resume": {},
|
| 115 |
+
"sample_rate": {},
|
| 116 |
+
"sys_params": {
|
| 117 |
+
"n_nodes": 1,
|
| 118 |
+
"nb_workers": 4,
|
| 119 |
+
"port": null,
|
| 120 |
+
"rank": 0,
|
| 121 |
+
"seed": 777,
|
| 122 |
+
"world_size": 1
|
| 123 |
+
},
|
| 124 |
+
"task_params": {
|
| 125 |
+
"bleeding_nsrcs": null,
|
| 126 |
+
"dataset": "delimit",
|
| 127 |
+
"target": "all",
|
| 128 |
+
"train": true
|
| 129 |
+
},
|
| 130 |
+
"umx_params": {
|
| 131 |
+
"activation": "relu",
|
| 132 |
+
"dropout_rate": 0.05,
|
| 133 |
+
"hidden_size": 512,
|
| 134 |
+
"instead_tanh_activation": "tanh",
|
| 135 |
+
"lstm_dropout_rate": 0.4,
|
| 136 |
+
"nb_layers": 3,
|
| 137 |
+
"normalization": "bn",
|
| 138 |
+
"umx_get_statistics": false
|
| 139 |
+
},
|
| 140 |
+
"wandb_params": {
|
| 141 |
+
"entity": "vinyne",
|
| 142 |
+
"project": "delimit",
|
| 143 |
+
"rerun_id": null,
|
| 144 |
+
"sweep": false,
|
| 145 |
+
"use_wandb": true
|
| 146 |
+
}
|
| 147 |
+
},
|
| 148 |
+
"best_epoch": 183,
|
| 149 |
+
"best_loss": -14.165373802185059,
|
| 150 |
+
"epochs_trained": 200,
|
| 151 |
+
"num_bad_epochs": 17,
|
| 152 |
+
"train_loss_history": [
|
| 153 |
+
-11.723381042480469,
|
| 154 |
+
-11.759103775024414,
|
| 155 |
+
-11.818404197692871,
|
| 156 |
+
-11.88597583770752,
|
| 157 |
+
-11.882278442382812,
|
| 158 |
+
-11.943178176879883,
|
| 159 |
+
-11.909675598144531,
|
| 160 |
+
-11.93053913116455,
|
| 161 |
+
-11.922198295593262,
|
| 162 |
+
-12.013456344604492,
|
| 163 |
+
-12.106053352355957,
|
| 164 |
+
-11.999975204467773,
|
| 165 |
+
-12.067265510559082,
|
| 166 |
+
-12.079473495483398,
|
| 167 |
+
-12.13272762298584,
|
| 168 |
+
-12.15418529510498,
|
| 169 |
+
-12.08314037322998,
|
| 170 |
+
-12.152527809143066,
|
| 171 |
+
-12.096565246582031,
|
| 172 |
+
-12.219636917114258,
|
| 173 |
+
-12.246475219726562,
|
| 174 |
+
-12.170637130737305,
|
| 175 |
+
-12.188806533813477,
|
| 176 |
+
-12.230484962463379,
|
| 177 |
+
-12.207123756408691,
|
| 178 |
+
-12.307502746582031,
|
| 179 |
+
-12.200200080871582,
|
| 180 |
+
-12.284586906433105,
|
| 181 |
+
-12.244038581848145,
|
| 182 |
+
-12.302275657653809,
|
| 183 |
+
-12.200104713439941,
|
| 184 |
+
-12.31570816040039,
|
| 185 |
+
-12.42324447631836,
|
| 186 |
+
-12.352653503417969,
|
| 187 |
+
-12.367401123046875,
|
| 188 |
+
-12.295838356018066,
|
| 189 |
+
-12.404874801635742,
|
| 190 |
+
-12.338440895080566,
|
| 191 |
+
-12.365501403808594,
|
| 192 |
+
-12.365768432617188,
|
| 193 |
+
-12.225799560546875,
|
| 194 |
+
-12.26883602142334,
|
| 195 |
+
-12.390016555786133,
|
| 196 |
+
-12.410661697387695,
|
| 197 |
+
-12.311858177185059,
|
| 198 |
+
-12.408061027526855,
|
| 199 |
+
-12.396013259887695,
|
| 200 |
+
-12.353321075439453,
|
| 201 |
+
-12.470121383666992,
|
| 202 |
+
-12.469389915466309,
|
| 203 |
+
-12.452675819396973,
|
| 204 |
+
-12.381932258605957,
|
| 205 |
+
-12.31003475189209,
|
| 206 |
+
-12.412126541137695,
|
| 207 |
+
-12.267746925354004,
|
| 208 |
+
-12.440984725952148,
|
| 209 |
+
-12.413816452026367,
|
| 210 |
+
-12.417757034301758,
|
| 211 |
+
-12.4945650100708,
|
| 212 |
+
-12.445524215698242,
|
| 213 |
+
-12.38110065460205,
|
| 214 |
+
-12.454893112182617,
|
| 215 |
+
-12.390727996826172,
|
| 216 |
+
-12.339771270751953,
|
| 217 |
+
-12.528243064880371,
|
| 218 |
+
-12.434144973754883,
|
| 219 |
+
-12.43438720703125,
|
| 220 |
+
-12.458473205566406,
|
| 221 |
+
-12.424423217773438,
|
| 222 |
+
-12.387894630432129,
|
| 223 |
+
-12.438997268676758,
|
| 224 |
+
-12.528799057006836,
|
| 225 |
+
-12.423232078552246,
|
| 226 |
+
-12.534538269042969,
|
| 227 |
+
-12.495400428771973,
|
| 228 |
+
-12.53675651550293,
|
| 229 |
+
-12.551910400390625,
|
| 230 |
+
-12.478575706481934,
|
| 231 |
+
-12.461804389953613,
|
| 232 |
+
-12.483702659606934,
|
| 233 |
+
-12.474960327148438,
|
| 234 |
+
-12.441666603088379,
|
| 235 |
+
-12.42241096496582,
|
| 236 |
+
-12.48852252960205,
|
| 237 |
+
-12.513558387756348,
|
| 238 |
+
-12.40845012664795,
|
| 239 |
+
-12.555559158325195,
|
| 240 |
+
-12.589385032653809,
|
| 241 |
+
-12.395785331726074,
|
| 242 |
+
-12.496671676635742,
|
| 243 |
+
-12.554829597473145,
|
| 244 |
+
-12.530548095703125,
|
| 245 |
+
-12.564457893371582,
|
| 246 |
+
-12.52737808227539,
|
| 247 |
+
-12.608246803283691,
|
| 248 |
+
-12.3996000289917,
|
| 249 |
+
-12.433905601501465,
|
| 250 |
+
-12.490935325622559,
|
| 251 |
+
-12.477506637573242,
|
| 252 |
+
-12.470728874206543,
|
| 253 |
+
-12.564470291137695,
|
| 254 |
+
-12.525967597961426,
|
| 255 |
+
-12.502660751342773,
|
| 256 |
+
-12.440997123718262,
|
| 257 |
+
-12.576118469238281,
|
| 258 |
+
-12.538352966308594,
|
| 259 |
+
-12.512738227844238,
|
| 260 |
+
-12.525115966796875,
|
| 261 |
+
-12.511483192443848,
|
| 262 |
+
-12.571795463562012,
|
| 263 |
+
-12.59391975402832,
|
| 264 |
+
-12.442131996154785,
|
| 265 |
+
-12.617898941040039,
|
| 266 |
+
-12.495210647583008,
|
| 267 |
+
-12.551814079284668,
|
| 268 |
+
-12.4913330078125,
|
| 269 |
+
-12.626816749572754,
|
| 270 |
+
-12.556028366088867,
|
| 271 |
+
-12.477901458740234,
|
| 272 |
+
-12.596776008605957,
|
| 273 |
+
-12.597326278686523,
|
| 274 |
+
-12.484386444091797,
|
| 275 |
+
-12.660898208618164,
|
| 276 |
+
-12.440162658691406,
|
| 277 |
+
-12.530372619628906,
|
| 278 |
+
-12.51207447052002,
|
| 279 |
+
-12.503606796264648,
|
| 280 |
+
-12.670214653015137,
|
| 281 |
+
-12.51667308807373,
|
| 282 |
+
-12.546160697937012,
|
| 283 |
+
-12.504158020019531,
|
| 284 |
+
-12.6427001953125,
|
| 285 |
+
-12.56100082397461,
|
| 286 |
+
-12.506058692932129,
|
| 287 |
+
-12.637288093566895,
|
| 288 |
+
-12.572591781616211,
|
| 289 |
+
-12.544734001159668,
|
| 290 |
+
-12.604019165039062,
|
| 291 |
+
-12.549866676330566,
|
| 292 |
+
-12.521714210510254,
|
| 293 |
+
-12.601127624511719,
|
| 294 |
+
-12.629931449890137,
|
| 295 |
+
-12.587185859680176,
|
| 296 |
+
-12.605366706848145,
|
| 297 |
+
-12.606413841247559,
|
| 298 |
+
-12.536269187927246,
|
| 299 |
+
-12.577346801757812,
|
| 300 |
+
-12.703147888183594,
|
| 301 |
+
-12.60477066040039,
|
| 302 |
+
-12.603355407714844,
|
| 303 |
+
-12.536528587341309,
|
| 304 |
+
-12.601842880249023,
|
| 305 |
+
-12.698568344116211,
|
| 306 |
+
-12.72192668914795,
|
| 307 |
+
-12.663148880004883,
|
| 308 |
+
-12.644909858703613,
|
| 309 |
+
-12.631479263305664,
|
| 310 |
+
-12.596253395080566,
|
| 311 |
+
-12.61674690246582,
|
| 312 |
+
-12.701379776000977,
|
| 313 |
+
-12.664311408996582,
|
| 314 |
+
-12.646204948425293,
|
| 315 |
+
-12.597058296203613,
|
| 316 |
+
-12.652384757995605,
|
| 317 |
+
-12.579480171203613,
|
| 318 |
+
-12.757433891296387,
|
| 319 |
+
-12.686827659606934,
|
| 320 |
+
-12.65634536743164,
|
| 321 |
+
-12.552176475524902,
|
| 322 |
+
-12.625761032104492,
|
| 323 |
+
-12.652499198913574,
|
| 324 |
+
-12.668974876403809,
|
| 325 |
+
-12.700301170349121,
|
| 326 |
+
-12.591926574707031,
|
| 327 |
+
-12.54333782196045,
|
| 328 |
+
-12.541864395141602,
|
| 329 |
+
-12.720565795898438,
|
| 330 |
+
-12.625009536743164,
|
| 331 |
+
-12.577120780944824,
|
| 332 |
+
-12.67569637298584,
|
| 333 |
+
-12.634958267211914,
|
| 334 |
+
-12.660367012023926,
|
| 335 |
+
-12.646204948425293,
|
| 336 |
+
-12.713308334350586,
|
| 337 |
+
-12.734916687011719,
|
| 338 |
+
-12.602835655212402,
|
| 339 |
+
-12.596168518066406,
|
| 340 |
+
-12.66109848022461,
|
| 341 |
+
-12.568808555603027,
|
| 342 |
+
-12.719843864440918,
|
| 343 |
+
-12.746356010437012,
|
| 344 |
+
-12.602999687194824,
|
| 345 |
+
-12.632689476013184,
|
| 346 |
+
-12.715725898742676,
|
| 347 |
+
-12.671126365661621,
|
| 348 |
+
-12.659911155700684,
|
| 349 |
+
-12.755860328674316,
|
| 350 |
+
-12.591080665588379,
|
| 351 |
+
-12.623464584350586,
|
| 352 |
+
-12.643362045288086
|
| 353 |
+
],
|
| 354 |
+
"train_time_history": [
|
| 355 |
+
308.12283968925476,
|
| 356 |
+
308.12408661842346,
|
| 357 |
+
305.56318974494934,
|
| 358 |
+
305.6093053817749,
|
| 359 |
+
304.1926734447479,
|
| 360 |
+
304.2103099822998,
|
| 361 |
+
301.78035831451416,
|
| 362 |
+
301.7819468975067,
|
| 363 |
+
317.8168547153473,
|
| 364 |
+
317.818119764328,
|
| 365 |
+
314.8585801124573,
|
| 366 |
+
314.8601076602936,
|
| 367 |
+
311.61795926094055,
|
| 368 |
+
311.61953926086426,
|
| 369 |
+
316.2616910934448,
|
| 370 |
+
316.2639091014862,
|
| 371 |
+
312.59282636642456,
|
| 372 |
+
312.59408020973206,
|
| 373 |
+
314.6765525341034,
|
| 374 |
+
314.6778757572174,
|
| 375 |
+
314.4039900302887,
|
| 376 |
+
314.40531301498413,
|
| 377 |
+
313.9343922138214,
|
| 378 |
+
313.9356322288513,
|
| 379 |
+
315.1470823287964,
|
| 380 |
+
315.14854192733765,
|
| 381 |
+
317.65793561935425,
|
| 382 |
+
317.65903544425964,
|
| 383 |
+
316.41589403152466,
|
| 384 |
+
316.4171371459961,
|
| 385 |
+
316.253050327301,
|
| 386 |
+
316.2544617652893,
|
| 387 |
+
316.2039670944214,
|
| 388 |
+
316.20542550086975,
|
| 389 |
+
316.30707120895386,
|
| 390 |
+
316.30964159965515,
|
| 391 |
+
315.7812213897705,
|
| 392 |
+
315.7832131385803,
|
| 393 |
+
315.77191638946533,
|
| 394 |
+
315.7732570171356,
|
| 395 |
+
315.7776229381561,
|
| 396 |
+
315.77907848358154,
|
| 397 |
+
315.80343294143677,
|
| 398 |
+
315.8051166534424,
|
| 399 |
+
314.40133929252625,
|
| 400 |
+
314.403112411499,
|
| 401 |
+
314.32283997535706,
|
| 402 |
+
314.32424092292786,
|
| 403 |
+
314.90000677108765,
|
| 404 |
+
314.90242648124695,
|
| 405 |
+
313.8207128047943,
|
| 406 |
+
313.8227391242981,
|
| 407 |
+
313.86938881874084,
|
| 408 |
+
313.87079215049744,
|
| 409 |
+
316.9037547111511,
|
| 410 |
+
316.9056947231293,
|
| 411 |
+
317.4321286678314,
|
| 412 |
+
317.43361139297485,
|
| 413 |
+
316.41515493392944,
|
| 414 |
+
316.4182825088501,
|
| 415 |
+
315.69741559028625,
|
| 416 |
+
315.699245929718,
|
| 417 |
+
315.9285054206848,
|
| 418 |
+
315.930716753006,
|
| 419 |
+
314.25376319885254,
|
| 420 |
+
314.25567531585693,
|
| 421 |
+
312.997665643692,
|
| 422 |
+
313.0005877017975,
|
| 423 |
+
315.5962414741516,
|
| 424 |
+
315.5977747440338,
|
| 425 |
+
315.49425506591797,
|
| 426 |
+
315.4961242675781,
|
| 427 |
+
315.980491399765,
|
| 428 |
+
315.98283791542053,
|
| 429 |
+
315.5533638000488,
|
| 430 |
+
315.55492901802063,
|
| 431 |
+
313.9896593093872,
|
| 432 |
+
313.99131321907043,
|
| 433 |
+
314.3214478492737,
|
| 434 |
+
314.3232262134552,
|
| 435 |
+
314.6442220211029,
|
| 436 |
+
314.64620661735535,
|
| 437 |
+
315.69726514816284,
|
| 438 |
+
315.7001700401306,
|
| 439 |
+
314.78302001953125,
|
| 440 |
+
314.7847316265106,
|
| 441 |
+
313.14448523521423,
|
| 442 |
+
313.1465194225311,
|
| 443 |
+
311.8232834339142,
|
| 444 |
+
311.8251144886017,
|
| 445 |
+
318.88225960731506,
|
| 446 |
+
318.8843643665314,
|
| 447 |
+
319.20725083351135,
|
| 448 |
+
319.20886182785034,
|
| 449 |
+
317.81429648399353,
|
| 450 |
+
317.8159878253937,
|
| 451 |
+
320.23738193511963,
|
| 452 |
+
320.23904752731323,
|
| 453 |
+
315.8315763473511,
|
| 454 |
+
315.83344054222107,
|
| 455 |
+
317.32581615448,
|
| 456 |
+
317.3274848461151,
|
| 457 |
+
316.7596924304962,
|
| 458 |
+
316.7628848552704,
|
| 459 |
+
316.3167974948883,
|
| 460 |
+
316.3188827037811,
|
| 461 |
+
316.44567823410034,
|
| 462 |
+
316.44802141189575,
|
| 463 |
+
313.8653395175934,
|
| 464 |
+
313.8687484264374,
|
| 465 |
+
308.43933939933777,
|
| 466 |
+
308.44151163101196,
|
| 467 |
+
312.1857454776764,
|
| 468 |
+
312.18967509269714,
|
| 469 |
+
307.8407344818115,
|
| 470 |
+
307.84401679039,
|
| 471 |
+
307.48447585105896,
|
| 472 |
+
307.48623728752136,
|
| 473 |
+
310.300940990448,
|
| 474 |
+
310.3029022216797,
|
| 475 |
+
310.32225275039673,
|
| 476 |
+
310.3257050514221,
|
| 477 |
+
309.351779460907,
|
| 478 |
+
309.3539865016937,
|
| 479 |
+
309.4356527328491,
|
| 480 |
+
309.4380919933319,
|
| 481 |
+
312.63360381126404,
|
| 482 |
+
312.63535809516907,
|
| 483 |
+
311.7453818321228,
|
| 484 |
+
311.7476508617401,
|
| 485 |
+
311.3258364200592,
|
| 486 |
+
311.327698469162,
|
| 487 |
+
312.28111600875854,
|
| 488 |
+
312.2828998565674,
|
| 489 |
+
311.3383209705353,
|
| 490 |
+
311.34200048446655,
|
| 491 |
+
306.9764757156372,
|
| 492 |
+
306.9787657260895,
|
| 493 |
+
309.35506653785706,
|
| 494 |
+
309.3569576740265,
|
| 495 |
+
310.2506465911865,
|
| 496 |
+
310.2529339790344,
|
| 497 |
+
310.65880727767944,
|
| 498 |
+
310.66108298301697,
|
| 499 |
+
311.18562865257263,
|
| 500 |
+
311.1874952316284,
|
| 501 |
+
309.07765316963196,
|
| 502 |
+
309.07997822761536,
|
| 503 |
+
313.3008818626404,
|
| 504 |
+
313.3029179573059,
|
| 505 |
+
311.267498254776,
|
| 506 |
+
311.26989102363586,
|
| 507 |
+
310.62635374069214,
|
| 508 |
+
310.6306185722351,
|
| 509 |
+
308.1883268356323,
|
| 510 |
+
308.19112515449524,
|
| 511 |
+
310.65689158439636,
|
| 512 |
+
310.65896558761597,
|
| 513 |
+
308.98754620552063,
|
| 514 |
+
309.03386878967285,
|
| 515 |
+
309.21512937545776,
|
| 516 |
+
309.2185757160187,
|
| 517 |
+
309.93750405311584,
|
| 518 |
+
309.93965554237366,
|
| 519 |
+
310.2938587665558,
|
| 520 |
+
310.29592084884644,
|
| 521 |
+
308.24257493019104,
|
| 522 |
+
308.2463102340698,
|
| 523 |
+
310.6870594024658,
|
| 524 |
+
310.6905345916748,
|
| 525 |
+
310.7875945568085,
|
| 526 |
+
310.78995156288147,
|
| 527 |
+
310.9882712364197,
|
| 528 |
+
310.9906806945801,
|
| 529 |
+
310.95856285095215,
|
| 530 |
+
310.96066546440125,
|
| 531 |
+
312.4489221572876,
|
| 532 |
+
312.45125246047974,
|
| 533 |
+
312.24022579193115,
|
| 534 |
+
312.2863116264343,
|
| 535 |
+
309.68400406837463,
|
| 536 |
+
309.6862533092499,
|
| 537 |
+
309.64014887809753,
|
| 538 |
+
309.64232993125916,
|
| 539 |
+
309.9094281196594,
|
| 540 |
+
309.9119017124176,
|
| 541 |
+
309.40677762031555,
|
| 542 |
+
309.40893173217773,
|
| 543 |
+
309.1595506668091,
|
| 544 |
+
309.1617259979248,
|
| 545 |
+
308.4178020954132,
|
| 546 |
+
308.4198989868164,
|
| 547 |
+
308.5063133239746,
|
| 548 |
+
308.5085346698761,
|
| 549 |
+
307.5796904563904,
|
| 550 |
+
307.5972898006439,
|
| 551 |
+
309.66309905052185,
|
| 552 |
+
309.66530561447144,
|
| 553 |
+
312.70798993110657,
|
| 554 |
+
312.7102212905884,
|
| 555 |
+
310.2431013584137,
|
| 556 |
+
310.2453660964966,
|
| 557 |
+
312.2640459537506,
|
| 558 |
+
312.26635122299194,
|
| 559 |
+
311.27055287361145,
|
| 560 |
+
311.27321219444275,
|
| 561 |
+
312.58145689964294,
|
| 562 |
+
312.58376598358154,
|
| 563 |
+
313.1553518772125,
|
| 564 |
+
313.1574249267578,
|
| 565 |
+
308.4067575931549,
|
| 566 |
+
308.4089684486389,
|
| 567 |
+
311.0251498222351,
|
| 568 |
+
311.0274658203125,
|
| 569 |
+
308.0227520465851,
|
| 570 |
+
308.02498388290405,
|
| 571 |
+
308.0182030200958,
|
| 572 |
+
308.0204634666443,
|
| 573 |
+
308.63523149490356,
|
| 574 |
+
308.63751220703125,
|
| 575 |
+
308.53969383239746,
|
| 576 |
+
308.5420751571655,
|
| 577 |
+
306.51329946517944,
|
| 578 |
+
306.51555824279785,
|
| 579 |
+
309.59846591949463,
|
| 580 |
+
309.60128831863403,
|
| 581 |
+
305.3712034225464,
|
| 582 |
+
305.37409830093384,
|
| 583 |
+
305.43984270095825,
|
| 584 |
+
305.4421238899231,
|
| 585 |
+
309.3166663646698,
|
| 586 |
+
309.3195414543152,
|
| 587 |
+
308.8618497848511,
|
| 588 |
+
308.86409974098206,
|
| 589 |
+
304.8731882572174,
|
| 590 |
+
304.8755958080292,
|
| 591 |
+
306.6576888561249,
|
| 592 |
+
306.663143157959,
|
| 593 |
+
306.6716537475586,
|
| 594 |
+
306.6740062236786,
|
| 595 |
+
309.47339940071106,
|
| 596 |
+
309.47578954696655,
|
| 597 |
+
307.73386335372925,
|
| 598 |
+
307.7363700866699,
|
| 599 |
+
308.0688214302063,
|
| 600 |
+
308.07209277153015,
|
| 601 |
+
311.58968901634216,
|
| 602 |
+
311.6099576950073,
|
| 603 |
+
308.70460844039917,
|
| 604 |
+
308.70710158348083,
|
| 605 |
+
312.0563473701477,
|
| 606 |
+
312.05881452560425,
|
| 607 |
+
310.89456367492676,
|
| 608 |
+
310.9119510650635,
|
| 609 |
+
308.73097705841064,
|
| 610 |
+
308.73414373397827,
|
| 611 |
+
309.4255359172821,
|
| 612 |
+
309.42857813835144,
|
| 613 |
+
311.0751721858978,
|
| 614 |
+
311.07801842689514,
|
| 615 |
+
309.5860447883606,
|
| 616 |
+
309.5896680355072,
|
| 617 |
+
309.87396597862244,
|
| 618 |
+
309.8803391456604,
|
| 619 |
+
310.9183626174927,
|
| 620 |
+
310.92147397994995,
|
| 621 |
+
308.4321529865265,
|
| 622 |
+
308.4359757900238,
|
| 623 |
+
312.4424922466278,
|
| 624 |
+
312.44731879234314,
|
| 625 |
+
312.3443009853363,
|
| 626 |
+
312.3491401672363,
|
| 627 |
+
310.3139410018921,
|
| 628 |
+
310.3165555000305,
|
| 629 |
+
312.09410762786865,
|
| 630 |
+
312.09656262397766,
|
| 631 |
+
311.11144399642944,
|
| 632 |
+
311.1577796936035,
|
| 633 |
+
309.1589603424072,
|
| 634 |
+
309.16152119636536,
|
| 635 |
+
312.51157093048096,
|
| 636 |
+
312.51463317871094,
|
| 637 |
+
314.15198159217834,
|
| 638 |
+
314.15485286712646,
|
| 639 |
+
310.00070810317993,
|
| 640 |
+
310.0033264160156,
|
| 641 |
+
311.2290298938751,
|
| 642 |
+
311.23188829421997,
|
| 643 |
+
313.0510983467102,
|
| 644 |
+
313.05362153053284,
|
| 645 |
+
313.48791670799255,
|
| 646 |
+
313.4910161495209,
|
| 647 |
+
307.60272216796875,
|
| 648 |
+
307.6053590774536,
|
| 649 |
+
303.84622287750244,
|
| 650 |
+
303.8494029045105,
|
| 651 |
+
304.8547012805939,
|
| 652 |
+
304.85784125328064,
|
| 653 |
+
310.63141536712646,
|
| 654 |
+
310.63450264930725,
|
| 655 |
+
304.8634753227234,
|
| 656 |
+
304.8664004802704,
|
| 657 |
+
308.1505949497223,
|
| 658 |
+
308.15428018569946,
|
| 659 |
+
310.18936228752136,
|
| 660 |
+
310.1920323371887,
|
| 661 |
+
309.2550263404846,
|
| 662 |
+
309.2577428817749,
|
| 663 |
+
310.08596634864807,
|
| 664 |
+
310.08910751342773,
|
| 665 |
+
307.4643654823303,
|
| 666 |
+
307.4670605659485,
|
| 667 |
+
308.558221578598,
|
| 668 |
+
308.5638659000397,
|
| 669 |
+
309.7440264225006,
|
| 670 |
+
309.7467608451843,
|
| 671 |
+
308.2091956138611,
|
| 672 |
+
308.2125828266144,
|
| 673 |
+
307.0199763774872,
|
| 674 |
+
307.02332496643066,
|
| 675 |
+
306.3482081890106,
|
| 676 |
+
306.35128688812256,
|
| 677 |
+
307.3764581680298,
|
| 678 |
+
307.37923669815063,
|
| 679 |
+
311.61060428619385,
|
| 680 |
+
311.6135311126709,
|
| 681 |
+
306.8187861442566,
|
| 682 |
+
306.8240280151367,
|
| 683 |
+
305.19880175590515,
|
| 684 |
+
305.20313119888306,
|
| 685 |
+
309.252712726593,
|
| 686 |
+
309.256165266037,
|
| 687 |
+
310.80801463127136,
|
| 688 |
+
310.81236577033997,
|
| 689 |
+
309.1079206466675,
|
| 690 |
+
309.11073756217957,
|
| 691 |
+
310.6556165218353,
|
| 692 |
+
310.65838623046875,
|
| 693 |
+
310.94868993759155,
|
| 694 |
+
310.95155143737793,
|
| 695 |
+
308.4552607536316,
|
| 696 |
+
308.4580717086792,
|
| 697 |
+
308.2857587337494,
|
| 698 |
+
308.2886221408844,
|
| 699 |
+
306.4856150150299,
|
| 700 |
+
306.4887855052948,
|
| 701 |
+
306.8667871952057,
|
| 702 |
+
306.86966013908386,
|
| 703 |
+
306.1964519023895,
|
| 704 |
+
306.2005341053009,
|
| 705 |
+
308.2178611755371,
|
| 706 |
+
308.22126364707947,
|
| 707 |
+
305.94888377189636,
|
| 708 |
+
305.9523375034332,
|
| 709 |
+
307.48926973342896,
|
| 710 |
+
307.4920620918274,
|
| 711 |
+
307.60354018211365,
|
| 712 |
+
307.63674998283386,
|
| 713 |
+
307.2473645210266,
|
| 714 |
+
307.2501358985901,
|
| 715 |
+
308.16573452949524,
|
| 716 |
+
308.2115182876587,
|
| 717 |
+
307.30736780166626,
|
| 718 |
+
307.3109815120697,
|
| 719 |
+
307.2137475013733,
|
| 720 |
+
307.2178246974945,
|
| 721 |
+
308.5944905281067,
|
| 722 |
+
308.59843826293945,
|
| 723 |
+
307.2346291542053,
|
| 724 |
+
307.2382435798645,
|
| 725 |
+
308.417338848114,
|
| 726 |
+
308.4208617210388,
|
| 727 |
+
305.5816307067871,
|
| 728 |
+
305.5852439403534,
|
| 729 |
+
307.69459652900696,
|
| 730 |
+
307.6975119113922,
|
| 731 |
+
307.20833134651184,
|
| 732 |
+
307.212299823761,
|
| 733 |
+
305.9614431858063,
|
| 734 |
+
305.965185880661,
|
| 735 |
+
305.31594157218933,
|
| 736 |
+
305.3195445537567,
|
| 737 |
+
307.46696519851685,
|
| 738 |
+
307.47079825401306,
|
| 739 |
+
306.23966455459595,
|
| 740 |
+
306.2433180809021,
|
| 741 |
+
306.1235647201538,
|
| 742 |
+
306.1273248195648,
|
| 743 |
+
307.02436780929565,
|
| 744 |
+
307.02733421325684,
|
| 745 |
+
306.9687819480896,
|
| 746 |
+
306.97225856781006,
|
| 747 |
+
306.23205065727234,
|
| 748 |
+
306.2356073856354,
|
| 749 |
+
305.3567383289337,
|
| 750 |
+
305.36028504371643,
|
| 751 |
+
305.94446635246277,
|
| 752 |
+
305.9480822086334,
|
| 753 |
+
307.2553553581238
|
| 754 |
+
],
|
| 755 |
+
"valid_loss_history": [
|
| 756 |
+
-12.743322372436523,
|
| 757 |
+
-12.724347114562988,
|
| 758 |
+
-12.86701488494873,
|
| 759 |
+
-12.694435119628906,
|
| 760 |
+
-12.706733703613281,
|
| 761 |
+
-13.048251152038574,
|
| 762 |
+
-12.943618774414062,
|
| 763 |
+
-13.120084762573242,
|
| 764 |
+
-13.121935844421387,
|
| 765 |
+
-13.146740913391113,
|
| 766 |
+
-13.197364807128906,
|
| 767 |
+
-13.224929809570312,
|
| 768 |
+
-13.255891799926758,
|
| 769 |
+
-13.311783790588379,
|
| 770 |
+
-13.386489868164062,
|
| 771 |
+
-13.390006065368652,
|
| 772 |
+
-13.45509147644043,
|
| 773 |
+
-13.444679260253906,
|
| 774 |
+
-13.456311225891113,
|
| 775 |
+
-13.36051082611084,
|
| 776 |
+
-13.478644371032715,
|
| 777 |
+
-13.503388404846191,
|
| 778 |
+
-13.540580749511719,
|
| 779 |
+
-13.579903602600098,
|
| 780 |
+
-13.551591873168945,
|
| 781 |
+
-13.638075828552246,
|
| 782 |
+
-13.617512702941895,
|
| 783 |
+
-13.64240550994873,
|
| 784 |
+
-13.618767738342285,
|
| 785 |
+
-13.65319538116455,
|
| 786 |
+
-13.601574897766113,
|
| 787 |
+
-13.693778038024902,
|
| 788 |
+
-13.658882141113281,
|
| 789 |
+
-13.649510383605957,
|
| 790 |
+
-13.477263450622559,
|
| 791 |
+
-13.643564224243164,
|
| 792 |
+
-13.732584953308105,
|
| 793 |
+
-13.643271446228027,
|
| 794 |
+
-13.655325889587402,
|
| 795 |
+
-13.71172046661377,
|
| 796 |
+
-13.564180374145508,
|
| 797 |
+
-13.708178520202637,
|
| 798 |
+
-13.688010215759277,
|
| 799 |
+
-13.711198806762695,
|
| 800 |
+
-13.612863540649414,
|
| 801 |
+
-13.702019691467285,
|
| 802 |
+
-13.704530715942383,
|
| 803 |
+
-13.716957092285156,
|
| 804 |
+
-13.76714038848877,
|
| 805 |
+
-13.719636917114258,
|
| 806 |
+
-13.738469123840332,
|
| 807 |
+
-13.759002685546875,
|
| 808 |
+
-13.721348762512207,
|
| 809 |
+
-13.727803230285645,
|
| 810 |
+
-13.768327713012695,
|
| 811 |
+
-13.73253345489502,
|
| 812 |
+
-13.75208568572998,
|
| 813 |
+
-13.754429817199707,
|
| 814 |
+
-13.76417064666748,
|
| 815 |
+
-13.805985450744629,
|
| 816 |
+
-13.762914657592773,
|
| 817 |
+
-13.75927448272705,
|
| 818 |
+
-13.781553268432617,
|
| 819 |
+
-13.744827270507812,
|
| 820 |
+
-13.805213928222656,
|
| 821 |
+
-13.792055130004883,
|
| 822 |
+
-13.736992835998535,
|
| 823 |
+
-13.804685592651367,
|
| 824 |
+
-13.802186012268066,
|
| 825 |
+
-13.812178611755371,
|
| 826 |
+
-13.781081199645996,
|
| 827 |
+
-13.836441993713379,
|
| 828 |
+
-13.787053108215332,
|
| 829 |
+
-13.824462890625,
|
| 830 |
+
-13.827963829040527,
|
| 831 |
+
-13.768393516540527,
|
| 832 |
+
-13.824796676635742,
|
| 833 |
+
-13.809252738952637,
|
| 834 |
+
-13.820283889770508,
|
| 835 |
+
-13.811989784240723,
|
| 836 |
+
-13.845786094665527,
|
| 837 |
+
-13.801295280456543,
|
| 838 |
+
-13.795866966247559,
|
| 839 |
+
-13.847658157348633,
|
| 840 |
+
-13.841630935668945,
|
| 841 |
+
-13.887687683105469,
|
| 842 |
+
-13.838217735290527,
|
| 843 |
+
-13.833791732788086,
|
| 844 |
+
-13.8090181350708,
|
| 845 |
+
-13.810338973999023,
|
| 846 |
+
-13.812939643859863,
|
| 847 |
+
-13.813563346862793,
|
| 848 |
+
-13.72245979309082,
|
| 849 |
+
-13.829062461853027,
|
| 850 |
+
-13.820122718811035,
|
| 851 |
+
-13.764768600463867,
|
| 852 |
+
-13.882962226867676,
|
| 853 |
+
-13.887824058532715,
|
| 854 |
+
-13.874728202819824,
|
| 855 |
+
-13.83934211730957,
|
| 856 |
+
-13.854304313659668,
|
| 857 |
+
-13.853861808776855,
|
| 858 |
+
-13.878510475158691,
|
| 859 |
+
-13.855673789978027,
|
| 860 |
+
-13.935111999511719,
|
| 861 |
+
-13.873315811157227,
|
| 862 |
+
-13.88434886932373,
|
| 863 |
+
-13.913508415222168,
|
| 864 |
+
-13.804875373840332,
|
| 865 |
+
-13.874313354492188,
|
| 866 |
+
-13.925950050354004,
|
| 867 |
+
-13.898317337036133,
|
| 868 |
+
-13.861913681030273,
|
| 869 |
+
-13.83596134185791,
|
| 870 |
+
-13.907777786254883,
|
| 871 |
+
-13.832358360290527,
|
| 872 |
+
-13.936162948608398,
|
| 873 |
+
-13.925071716308594,
|
| 874 |
+
-13.906752586364746,
|
| 875 |
+
-13.87073040008545,
|
| 876 |
+
-13.964620590209961,
|
| 877 |
+
-13.925311088562012,
|
| 878 |
+
-13.974698066711426,
|
| 879 |
+
-13.957905769348145,
|
| 880 |
+
-13.918564796447754,
|
| 881 |
+
-13.975790023803711,
|
| 882 |
+
-13.988444328308105,
|
| 883 |
+
-13.959516525268555,
|
| 884 |
+
-14.01569652557373,
|
| 885 |
+
-13.992425918579102,
|
| 886 |
+
-14.039790153503418,
|
| 887 |
+
-13.940314292907715,
|
| 888 |
+
-14.011497497558594,
|
| 889 |
+
-13.953152656555176,
|
| 890 |
+
-13.920698165893555,
|
| 891 |
+
-13.960227966308594,
|
| 892 |
+
-13.907439231872559,
|
| 893 |
+
-14.014067649841309,
|
| 894 |
+
-13.972914695739746,
|
| 895 |
+
-13.942621231079102,
|
| 896 |
+
-14.019667625427246,
|
| 897 |
+
-14.037107467651367,
|
| 898 |
+
-13.85366153717041,
|
| 899 |
+
-13.980110168457031,
|
| 900 |
+
-13.97785472869873,
|
| 901 |
+
-13.983843803405762,
|
| 902 |
+
-13.843756675720215,
|
| 903 |
+
-14.002585411071777,
|
| 904 |
+
-14.026784896850586,
|
| 905 |
+
-14.028115272521973,
|
| 906 |
+
-14.02059268951416,
|
| 907 |
+
-13.985837936401367,
|
| 908 |
+
-14.076154708862305,
|
| 909 |
+
-14.060620307922363,
|
| 910 |
+
-13.936518669128418,
|
| 911 |
+
-13.957221031188965,
|
| 912 |
+
-14.017061233520508,
|
| 913 |
+
-13.995661735534668,
|
| 914 |
+
-14.056286811828613,
|
| 915 |
+
-14.037705421447754,
|
| 916 |
+
-13.940332412719727,
|
| 917 |
+
-14.092416763305664,
|
| 918 |
+
-14.024917602539062,
|
| 919 |
+
-14.002346992492676,
|
| 920 |
+
-14.026989936828613,
|
| 921 |
+
-13.944084167480469,
|
| 922 |
+
-14.002883911132812,
|
| 923 |
+
-14.120462417602539,
|
| 924 |
+
-14.043062210083008,
|
| 925 |
+
-14.008293151855469,
|
| 926 |
+
-14.040563583374023,
|
| 927 |
+
-13.994155883789062,
|
| 928 |
+
-14.08944034576416,
|
| 929 |
+
-14.078422546386719,
|
| 930 |
+
-14.014589309692383,
|
| 931 |
+
-14.083242416381836,
|
| 932 |
+
-14.104707717895508,
|
| 933 |
+
-14.103189468383789,
|
| 934 |
+
-14.063937187194824,
|
| 935 |
+
-14.0596284866333,
|
| 936 |
+
-14.059121131896973,
|
| 937 |
+
-14.102814674377441,
|
| 938 |
+
-14.165373802185059,
|
| 939 |
+
-14.106118202209473,
|
| 940 |
+
-14.107162475585938,
|
| 941 |
+
-14.085371017456055,
|
| 942 |
+
-14.123793601989746,
|
| 943 |
+
-14.053537368774414,
|
| 944 |
+
-14.077792167663574,
|
| 945 |
+
-14.056371688842773,
|
| 946 |
+
-14.033655166625977,
|
| 947 |
+
-14.096640586853027,
|
| 948 |
+
-14.057114601135254,
|
| 949 |
+
-14.115262985229492,
|
| 950 |
+
-14.074142456054688,
|
| 951 |
+
-14.067980766296387,
|
| 952 |
+
-14.118453025817871,
|
| 953 |
+
-14.117535591125488,
|
| 954 |
+
-14.126029968261719,
|
| 955 |
+
-14.117874145507812
|
| 956 |
+
]
|
| 957 |
+
}
|
weight/all.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:34f2ef4e5c32542060621f7ea9f7a06a2acf91be22825a38f9270077a7346679
|
| 3 |
+
size 9424379
|