Spaces:
Runtime error
Runtime error
Commit ·
bd1743b
1
Parent(s): e0aa67f
Add dsd100 dataset
Browse files- cfg/config.yaml +4 -4
- cfg/exp/default.yaml +1 -1
- remfx/datasets.py +28 -11
- remfx/models.py +0 -1
- scripts/download.py +39 -6
- shell_vars.sh +1 -1
cfg/config.yaml
CHANGED
|
@@ -53,9 +53,9 @@ callbacks:
|
|
| 53 |
_target_: remfx.callbacks.MetricCallback
|
| 54 |
|
| 55 |
datamodule:
|
| 56 |
-
_target_: remfx.datasets.
|
| 57 |
train_dataset:
|
| 58 |
-
_target_: remfx.datasets.VocalSet
|
| 59 |
sample_rate: ${sample_rate}
|
| 60 |
root: ${oc.env:DATASET_ROOT}
|
| 61 |
chunk_size: ${chunk_size}
|
|
@@ -70,7 +70,7 @@ datamodule:
|
|
| 70 |
render_files: ${render_files}
|
| 71 |
render_root: ${render_root}
|
| 72 |
val_dataset:
|
| 73 |
-
_target_: remfx.datasets.VocalSet
|
| 74 |
sample_rate: ${sample_rate}
|
| 75 |
root: ${oc.env:DATASET_ROOT}
|
| 76 |
chunk_size: ${chunk_size}
|
|
@@ -85,7 +85,7 @@ datamodule:
|
|
| 85 |
render_files: ${render_files}
|
| 86 |
render_root: ${render_root}
|
| 87 |
test_dataset:
|
| 88 |
-
_target_: remfx.datasets.VocalSet
|
| 89 |
sample_rate: ${sample_rate}
|
| 90 |
root: ${oc.env:DATASET_ROOT}
|
| 91 |
chunk_size: ${chunk_size}
|
|
|
|
| 53 |
_target_: remfx.callbacks.MetricCallback
|
| 54 |
|
| 55 |
datamodule:
|
| 56 |
+
_target_: remfx.datasets.EffectDatamodule
|
| 57 |
train_dataset:
|
| 58 |
+
_target_: remfx.datasets.EffectDataset
|
| 59 |
sample_rate: ${sample_rate}
|
| 60 |
root: ${oc.env:DATASET_ROOT}
|
| 61 |
chunk_size: ${chunk_size}
|
|
|
|
| 70 |
render_files: ${render_files}
|
| 71 |
render_root: ${render_root}
|
| 72 |
val_dataset:
|
| 73 |
+
_target_: remfx.datasets.EffectDataset
|
| 74 |
sample_rate: ${sample_rate}
|
| 75 |
root: ${oc.env:DATASET_ROOT}
|
| 76 |
chunk_size: ${chunk_size}
|
|
|
|
| 85 |
render_files: ${render_files}
|
| 86 |
render_root: ${render_root}
|
| 87 |
test_dataset:
|
| 88 |
+
_target_: remfx.datasets.EffectDataset
|
| 89 |
sample_rate: ${sample_rate}
|
| 90 |
root: ${oc.env:DATASET_ROOT}
|
| 91 |
chunk_size: ${chunk_size}
|
cfg/exp/default.yaml
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# @package _global_
|
| 2 |
defaults:
|
| 3 |
-
- override /model:
|
| 4 |
- override /effects: all
|
| 5 |
seed: 12345
|
| 6 |
sample_rate: 48000
|
|
|
|
| 1 |
# @package _global_
|
| 2 |
defaults:
|
| 3 |
+
- override /model: umx
|
| 4 |
- override /effects: all
|
| 5 |
seed: 12345
|
| 6 |
sample_rate: 48000
|
remfx/datasets.py
CHANGED
|
@@ -55,6 +55,11 @@ idmt_bass_splits = {
|
|
| 55 |
"val": ["VIF"],
|
| 56 |
"test": ["VIS"],
|
| 57 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
idmt_drums_splits = {
|
| 59 |
"train": ["WaveDrum02", "TechnoDrum01"],
|
| 60 |
"val": ["RealDrum01"],
|
|
@@ -105,19 +110,28 @@ def locate_files(root: str, mode: str):
|
|
| 105 |
file_list += sorted(files)
|
| 106 |
print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
|
| 107 |
# ------------------------- IDMT-SMT-BASS -------------------------
|
| 108 |
-
idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
|
| 109 |
-
if os.path.isdir(idmt_smt_bass_dir):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
files = glob.glob(
|
| 111 |
-
os.path.join(idmt_smt_bass_dir, "**", "*.wav"),
|
| 112 |
recursive=True,
|
| 113 |
)
|
| 114 |
-
files = [
|
| 115 |
-
f
|
| 116 |
-
for f in files
|
| 117 |
-
if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
|
| 118 |
-
]
|
| 119 |
file_list += sorted(files)
|
| 120 |
-
print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
|
| 121 |
# ------------------------- IDMT-SMT-DRUMS -------------------------
|
| 122 |
idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
|
| 123 |
if os.path.isdir(idmt_smt_drums_dir):
|
|
@@ -133,7 +147,7 @@ def locate_files(root: str, mode: str):
|
|
| 133 |
return file_list
|
| 134 |
|
| 135 |
|
| 136 |
-
class VocalSet(Dataset):
|
| 137 |
def __init__(
|
| 138 |
self,
|
| 139 |
root: str,
|
|
@@ -199,6 +213,9 @@ class VocalSet(Dataset):
|
|
| 199 |
if resampled_chunk.shape[-1] < chunk_size:
|
| 200 |
# Skip if chunk is too small
|
| 201 |
continue
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
dry, wet, dry_effects, wet_effects = self.process_effects(
|
| 204 |
resampled_chunk
|
|
@@ -334,7 +351,7 @@ class VocalSet(Dataset):
|
|
| 334 |
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
| 335 |
|
| 336 |
|
| 337 |
-
class
|
| 338 |
def __init__(
|
| 339 |
self,
|
| 340 |
train_dataset,
|
|
|
|
| 55 |
"val": ["VIF"],
|
| 56 |
"test": ["VIS"],
|
| 57 |
}
|
| 58 |
+
dsd_100_splits = {
|
| 59 |
+
"train": ["train"],
|
| 60 |
+
"val": ["val"],
|
| 61 |
+
"test": ["test"],
|
| 62 |
+
}
|
| 63 |
idmt_drums_splits = {
|
| 64 |
"train": ["WaveDrum02", "TechnoDrum01"],
|
| 65 |
"val": ["RealDrum01"],
|
|
|
|
| 110 |
file_list += sorted(files)
|
| 111 |
print(f"Found {len(files)} files in IDMT-SMT-Guitar {mode}.")
|
| 112 |
# ------------------------- IDMT-SMT-BASS -------------------------
|
| 113 |
+
# idmt_smt_bass_dir = os.path.join(root, "IDMT-SMT-BASS")
|
| 114 |
+
# if os.path.isdir(idmt_smt_bass_dir):
|
| 115 |
+
# files = glob.glob(
|
| 116 |
+
# os.path.join(idmt_smt_bass_dir, "**", "*.wav"),
|
| 117 |
+
# recursive=True,
|
| 118 |
+
# )
|
| 119 |
+
# files = [
|
| 120 |
+
# f
|
| 121 |
+
# for f in files
|
| 122 |
+
# if os.path.basename(os.path.dirname(f)) in idmt_bass_splits[mode]
|
| 123 |
+
# ]
|
| 124 |
+
# file_list += sorted(files)
|
| 125 |
+
# print(f"Found {len(files)} files in IDMT-SMT-Bass {mode}.")
|
| 126 |
+
# ------------------------- DSD100 ---------------------------------
|
| 127 |
+
dsd_100_dir = os.path.join(root, "DSD100")
|
| 128 |
+
if os.path.isdir(dsd_100_dir):
|
| 129 |
files = glob.glob(
|
| 130 |
+
os.path.join(dsd_100_dir, mode, "**", "*.wav"),
|
| 131 |
recursive=True,
|
| 132 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
file_list += sorted(files)
|
| 134 |
+
print(f"Found {len(files)} files in DSD100 {mode}.")
|
| 135 |
# ------------------------- IDMT-SMT-DRUMS -------------------------
|
| 136 |
idmt_smt_drums_dir = os.path.join(root, "IDMT-SMT-DRUMS-V2")
|
| 137 |
if os.path.isdir(idmt_smt_drums_dir):
|
|
|
|
| 147 |
return file_list
|
| 148 |
|
| 149 |
|
| 150 |
+
class EffectDataset(Dataset):
|
| 151 |
def __init__(
|
| 152 |
self,
|
| 153 |
root: str,
|
|
|
|
| 213 |
if resampled_chunk.shape[-1] < chunk_size:
|
| 214 |
# Skip if chunk is too small
|
| 215 |
continue
|
| 216 |
+
# Sum to mono
|
| 217 |
+
if resampled_chunk.shape[0] > 1:
|
| 218 |
+
resampled_chunk = resampled_chunk.sum(0, keepdim=True)
|
| 219 |
|
| 220 |
dry, wet, dry_effects, wet_effects = self.process_effects(
|
| 221 |
resampled_chunk
|
|
|
|
| 351 |
return normalized_dry, normalized_wet, dry_labels_tensor, wet_labels_tensor
|
| 352 |
|
| 353 |
|
| 354 |
+
class EffectDatamodule(pl.LightningDataModule):
|
| 355 |
def __init__(
|
| 356 |
self,
|
| 357 |
train_dataset,
|
remfx/models.py
CHANGED
|
@@ -2,7 +2,6 @@ import torch
|
|
| 2 |
import torchmetrics
|
| 3 |
import pytorch_lightning as pl
|
| 4 |
from torch import Tensor, nn
|
| 5 |
-
from torch.nn import functional as F
|
| 6 |
from torchaudio.models import HDemucs
|
| 7 |
from audio_diffusion_pytorch import DiffusionModel
|
| 8 |
from auraloss.time import SISDRLoss
|
|
|
|
| 2 |
import torchmetrics
|
| 3 |
import pytorch_lightning as pl
|
| 4 |
from torch import Tensor, nn
|
|
|
|
| 5 |
from torchaudio.models import HDemucs
|
| 6 |
from audio_diffusion_pytorch import DiffusionModel
|
| 7 |
from auraloss.time import SISDRLoss
|
scripts/download.py
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
-
import sys
|
| 3 |
-
import glob
|
| 4 |
-
import torch
|
| 5 |
import argparse
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def download_zip_dataset(dataset_url: str, output_dir: str):
|
|
@@ -26,8 +24,42 @@ def process_dataset(dataset_dir: str, output_dir: str):
|
|
| 26 |
pass
|
| 27 |
elif dataset_dir == "IDMT-SMT-DRUMS-V2":
|
| 28 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
else:
|
| 30 |
-
raise
|
| 31 |
|
| 32 |
|
| 33 |
if __name__ == "__main__":
|
|
@@ -38,7 +70,7 @@ if __name__ == "__main__":
|
|
| 38 |
"vocalset",
|
| 39 |
"guitarset",
|
| 40 |
"idmt-smt-guitar",
|
| 41 |
-
"
|
| 42 |
"idmt-smt-drums",
|
| 43 |
],
|
| 44 |
nargs="+",
|
|
@@ -49,10 +81,11 @@ if __name__ == "__main__":
|
|
| 49 |
"vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
|
| 50 |
"guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
|
| 51 |
"IDMT-SMT-GUITAR_V2": "https://zenodo.org/record/7544110/files/IDMT-SMT-GUITAR_V2.zip",
|
| 52 |
-
"
|
| 53 |
"IDMT-SMT-DRUMS-V2": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
|
| 54 |
}
|
| 55 |
|
| 56 |
for dataset_name, dataset_url in dataset_urls.items():
|
| 57 |
if dataset_name in args.dataset_names:
|
| 58 |
download_zip_dataset(dataset_url, "~/data/remfx-data")
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
| 2 |
import argparse
|
| 3 |
+
import shutil
|
| 4 |
|
| 5 |
|
| 6 |
def download_zip_dataset(dataset_url: str, output_dir: str):
|
|
|
|
| 24 |
pass
|
| 25 |
elif dataset_dir == "IDMT-SMT-DRUMS-V2":
|
| 26 |
pass
|
| 27 |
+
elif dataset_dir == "DSD100":
|
| 28 |
+
shutil.rmtree(os.path.join(output_dir, dataset_dir, "Mixtures"))
|
| 29 |
+
for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Dev")):
|
| 30 |
+
source = os.path.join(output_dir, dataset_dir, "Sources", "Dev", dir)
|
| 31 |
+
shutil.move(source, os.path.join(output_dir, dataset_dir))
|
| 32 |
+
shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Dev"))
|
| 33 |
+
for dir in os.listdir(os.path.join(output_dir, dataset_dir, "Sources", "Test")):
|
| 34 |
+
source = os.path.join(output_dir, dataset_dir, "Sources", "Test", dir)
|
| 35 |
+
shutil.move(source, os.path.join(output_dir, dataset_dir))
|
| 36 |
+
shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources", "Test"))
|
| 37 |
+
shutil.rmtree(os.path.join(output_dir, dataset_dir, "Sources"))
|
| 38 |
+
|
| 39 |
+
os.mkdir(os.path.join(output_dir, dataset_dir, "train"))
|
| 40 |
+
os.mkdir(os.path.join(output_dir, dataset_dir, "val"))
|
| 41 |
+
os.mkdir(os.path.join(output_dir, dataset_dir, "test"))
|
| 42 |
+
files = os.listdir(os.path.join(output_dir, dataset_dir))
|
| 43 |
+
|
| 44 |
+
num = 0
|
| 45 |
+
for dir in files:
|
| 46 |
+
if not os.path.isdir(os.path.join(output_dir, dataset_dir, dir)):
|
| 47 |
+
continue
|
| 48 |
+
if dir == "train" or dir == "val" or dir == "test":
|
| 49 |
+
continue
|
| 50 |
+
source = os.path.join(output_dir, dataset_dir, dir, "bass.wav")
|
| 51 |
+
if num < 80:
|
| 52 |
+
dest = os.path.join(output_dir, dataset_dir, "train", f"{num}.wav")
|
| 53 |
+
elif num < 90:
|
| 54 |
+
dest = os.path.join(output_dir, dataset_dir, "val", f"{num}.wav")
|
| 55 |
+
else:
|
| 56 |
+
dest = os.path.join(output_dir, dataset_dir, "test", f"{num}.wav")
|
| 57 |
+
shutil.move(source, dest)
|
| 58 |
+
shutil.rmtree(os.path.join(output_dir, dataset_dir, dir))
|
| 59 |
+
num += 1
|
| 60 |
+
|
| 61 |
else:
|
| 62 |
+
raise NotImplementedError(f"Invalid dataset_dir = {dataset_dir}.")
|
| 63 |
|
| 64 |
|
| 65 |
if __name__ == "__main__":
|
|
|
|
| 70 |
"vocalset",
|
| 71 |
"guitarset",
|
| 72 |
"idmt-smt-guitar",
|
| 73 |
+
"dsd100",
|
| 74 |
"idmt-smt-drums",
|
| 75 |
],
|
| 76 |
nargs="+",
|
|
|
|
| 81 |
"vocalset": "https://zenodo.org/record/1442513/files/VocalSet1-2.zip",
|
| 82 |
"guitarset": "https://zenodo.org/record/3371780/files/audio_mono-mic.zip",
|
| 83 |
"IDMT-SMT-GUITAR_V2": "https://zenodo.org/record/7544110/files/IDMT-SMT-GUITAR_V2.zip",
|
| 84 |
+
"DSD100": "http://liutkus.net/DSD100.zip",
|
| 85 |
"IDMT-SMT-DRUMS-V2": "https://zenodo.org/record/7544164/files/IDMT-SMT-DRUMS-V2.zip",
|
| 86 |
}
|
| 87 |
|
| 88 |
for dataset_name, dataset_url in dataset_urls.items():
|
| 89 |
if dataset_name in args.dataset_names:
|
| 90 |
download_zip_dataset(dataset_url, "~/data/remfx-data")
|
| 91 |
+
process_dataset(dataset_name, "~/data/remfx-data")
|
shell_vars.sh
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
-
export DATASET_ROOT="./data/
|
| 2 |
export WANDB_PROJECT="RemFX"
|
| 3 |
export WANDB_ENTITY="mattricesound"
|
|
|
|
| 1 |
+
export DATASET_ROOT="./data/"
|
| 2 |
export WANDB_PROJECT="RemFX"
|
| 3 |
export WANDB_ENTITY="mattricesound"
|