Spaces:
Runtime error
Runtime error
Merge pull request #24 from mhrice/new-metrics
Browse files
- README.md +1 -1
- cfg/config.yaml +12 -7
- cfg/effects/all.yaml +29 -68
- cfg/effects/chorus.yaml +5 -18
- cfg/effects/compression.yaml +0 -22
- cfg/effects/compressor.yaml +9 -0
- cfg/effects/distortion.yaml +5 -12
- cfg/effects/reverb.yaml +11 -24
- cfg/exp/{demucs_compression.yaml → demucs_compressor.yaml} +1 -1
- cfg/exp/{umx_compression.yaml → umx_compressor.yaml} +1 -1
- remfx/datasets.py +11 -14
- remfx/models.py +28 -2
- remfx/utils.py +3 -4
- scripts/test.py +0 -1
- scripts/train.py +1 -0
README.md
CHANGED
|
@@ -22,7 +22,7 @@ Models and effects detailed below.
|
|
| 22 |
|
| 23 |
To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
|
| 24 |
|
| 25 |
-
Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices
|
| 26 |
|
| 27 |
### Current Models
|
| 28 |
- `umx`
|
|
|
|
| 22 |
|
| 23 |
To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
|
| 24 |
|
| 25 |
+
Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices=1`
|
| 26 |
|
| 27 |
### Current Models
|
| 28 |
- `umx`
|
cfg/config.yaml
CHANGED
|
@@ -6,8 +6,8 @@ defaults:
|
|
| 6 |
seed: 12345
|
| 7 |
train: True
|
| 8 |
sample_rate: 48000
|
|
|
|
| 9 |
logs_dir: "./logs"
|
| 10 |
-
log_every_n_steps: 1000
|
| 11 |
render_files: True
|
| 12 |
render_root: "./data/processed"
|
| 13 |
|
|
@@ -21,6 +21,9 @@ callbacks:
|
|
| 21 |
verbose: False
|
| 22 |
dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
|
| 23 |
filename: '{epoch:02d}-{valid_loss:.3f}'
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
datamodule:
|
| 26 |
_target_: remfx.datasets.VocalSetDatamodule
|
|
@@ -28,27 +31,27 @@ datamodule:
|
|
| 28 |
_target_: remfx.datasets.VocalSet
|
| 29 |
sample_rate: ${sample_rate}
|
| 30 |
root: ${oc.env:DATASET_ROOT}
|
| 31 |
-
|
| 32 |
mode: "train"
|
| 33 |
-
effect_types: ${effects
|
| 34 |
render_files: ${render_files}
|
| 35 |
render_root: ${render_root}
|
| 36 |
val_dataset:
|
| 37 |
_target_: remfx.datasets.VocalSet
|
| 38 |
sample_rate: ${sample_rate}
|
| 39 |
root: ${oc.env:DATASET_ROOT}
|
| 40 |
-
|
| 41 |
mode: "val"
|
| 42 |
-
effect_types: ${effects
|
| 43 |
render_files: ${render_files}
|
| 44 |
render_root: ${render_root}
|
| 45 |
test_dataset:
|
| 46 |
_target_: remfx.datasets.VocalSet
|
| 47 |
sample_rate: ${sample_rate}
|
| 48 |
root: ${oc.env:DATASET_ROOT}
|
| 49 |
-
|
| 50 |
mode: "test"
|
| 51 |
-
effect_types: ${effects
|
| 52 |
render_files: ${render_files}
|
| 53 |
render_root: ${render_root}
|
| 54 |
|
|
@@ -76,3 +79,5 @@ trainer:
|
|
| 76 |
accumulate_grad_batches: 1
|
| 77 |
accelerator: null
|
| 78 |
devices: 1
|
|
|
|
|
|
|
|
|
| 6 |
seed: 12345
|
| 7 |
train: True
|
| 8 |
sample_rate: 48000
|
| 9 |
+
chunk_size: 262144 # 2**18 samples, ~5.46 s at the 48 kHz sample_rate
|
| 10 |
logs_dir: "./logs"
|
|
|
|
| 11 |
render_files: True
|
| 12 |
render_root: "./data/processed"
|
| 13 |
|
|
|
|
| 21 |
verbose: False
|
| 22 |
dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
|
| 23 |
filename: '{epoch:02d}-{valid_loss:.3f}'
|
| 24 |
+
learning_rate_monitor:
|
| 25 |
+
_target_: pytorch_lightning.callbacks.LearningRateMonitor
|
| 26 |
+
logging_interval: "step"
|
| 27 |
|
| 28 |
datamodule:
|
| 29 |
_target_: remfx.datasets.VocalSetDatamodule
|
|
|
|
| 31 |
_target_: remfx.datasets.VocalSet
|
| 32 |
sample_rate: ${sample_rate}
|
| 33 |
root: ${oc.env:DATASET_ROOT}
|
| 34 |
+
chunk_size: ${chunk_size}
|
| 35 |
mode: "train"
|
| 36 |
+
effect_types: ${effects}
|
| 37 |
render_files: ${render_files}
|
| 38 |
render_root: ${render_root}
|
| 39 |
val_dataset:
|
| 40 |
_target_: remfx.datasets.VocalSet
|
| 41 |
sample_rate: ${sample_rate}
|
| 42 |
root: ${oc.env:DATASET_ROOT}
|
| 43 |
+
chunk_size: ${chunk_size}
|
| 44 |
mode: "val"
|
| 45 |
+
effect_types: ${effects}
|
| 46 |
render_files: ${render_files}
|
| 47 |
render_root: ${render_root}
|
| 48 |
test_dataset:
|
| 49 |
_target_: remfx.datasets.VocalSet
|
| 50 |
sample_rate: ${sample_rate}
|
| 51 |
root: ${oc.env:DATASET_ROOT}
|
| 52 |
+
chunk_size: ${chunk_size}
|
| 53 |
mode: "test"
|
| 54 |
+
effect_types: ${effects}
|
| 55 |
render_files: ${render_files}
|
| 56 |
render_root: ${render_root}
|
| 57 |
|
|
|
|
| 79 |
accumulate_grad_batches: 1
|
| 80 |
accelerator: null
|
| 81 |
devices: 1
|
| 82 |
+
gradient_clip_val: 10.0
|
| 83 |
+
max_steps: 50000
|
cfg/effects/all.yaml
CHANGED
|
@@ -1,70 +1,31 @@
|
|
| 1 |
# @package _global_
|
| 2 |
effects:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
_target_: remfx.effects.RandomPedalboardChorus
|
| 33 |
-
sample_rate: ${sample_rate}
|
| 34 |
-
min_rate_hz: 1.0
|
| 35 |
-
max_rate_hz: 1.0
|
| 36 |
-
min_depth: 0.3
|
| 37 |
-
max_depth: 0.3
|
| 38 |
-
min_centre_delay_ms: 7.5
|
| 39 |
-
max_centre_delay_ms: 7.5
|
| 40 |
-
min_feedback: 0.4
|
| 41 |
-
max_feedback: 0.4
|
| 42 |
-
min_mix: 0.4
|
| 43 |
-
max_mix: 0.4
|
| 44 |
-
Distortion:
|
| 45 |
-
_target_: remfx.effects.RandomPedalboardDistortion
|
| 46 |
-
sample_rate: ${sample_rate}
|
| 47 |
-
min_drive_db: 30
|
| 48 |
-
max_drive_db: 30
|
| 49 |
-
Compressor:
|
| 50 |
-
_target_: remfx.effects.RandomPedalboardCompressor
|
| 51 |
-
sample_rate: ${sample_rate}
|
| 52 |
-
min_threshold_db: -32
|
| 53 |
-
max_threshold_db: -32
|
| 54 |
-
min_ratio: 3.0
|
| 55 |
-
max_ratio: 3.0
|
| 56 |
-
min_attack_ms: 10.0
|
| 57 |
-
max_attack_ms: 10.0
|
| 58 |
-
min_release_ms: 40.0
|
| 59 |
-
max_release_ms: 40.0
|
| 60 |
-
Reverb:
|
| 61 |
-
_target_: remfx.effects.RandomPedalboardReverb
|
| 62 |
-
sample_rate: ${sample_rate}
|
| 63 |
-
min_room_size: 0.5
|
| 64 |
-
max_room_size: 0.5
|
| 65 |
-
min_damping: 0.5
|
| 66 |
-
max_damping: 0.5
|
| 67 |
-
min_wet_dry: 0.4
|
| 68 |
-
max_wet_dry: 0.4
|
| 69 |
-
min_width: 0.5
|
| 70 |
-
max_width: 0.5
|
|
|
|
| 1 |
# @package _global_
|
| 2 |
effects:
|
| 3 |
+
Chorus:
|
| 4 |
+
_target_: remfx.effects.RandomPedalboardChorus
|
| 5 |
+
sample_rate: ${sample_rate}
|
| 6 |
+
min_depth: 0.2
|
| 7 |
+
min_mix: 0.3
|
| 8 |
+
Distortion:
|
| 9 |
+
_target_: remfx.effects.RandomPedalboardDistortion
|
| 10 |
+
sample_rate: ${sample_rate}
|
| 11 |
+
min_drive_db: 10
|
| 12 |
+
max_drive_db: 50
|
| 13 |
+
Compressor:
|
| 14 |
+
_target_: remfx.effects.RandomPedalboardCompressor
|
| 15 |
+
sample_rate: ${sample_rate}
|
| 16 |
+
min_threshold_db: -42.0
|
| 17 |
+
max_threshold_db: -20.0
|
| 18 |
+
min_ratio: 1.5
|
| 19 |
+
max_ratio: 6.0
|
| 20 |
+
Reverb:
|
| 21 |
+
_target_: remfx.effects.RandomPedalboardReverb
|
| 22 |
+
sample_rate: ${sample_rate}
|
| 23 |
+
min_room_size: 0.3
|
| 24 |
+
max_room_size: 1.0
|
| 25 |
+
min_damping: 0.2
|
| 26 |
+
max_damping: 1.0
|
| 27 |
+
min_wet_dry: 0.2
|
| 28 |
+
max_wet_dry: 0.8
|
| 29 |
+
min_width: 0.2
|
| 30 |
+
max_width: 1.0
|
| 31 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cfg/effects/chorus.yaml
CHANGED
|
@@ -1,20 +1,7 @@
|
|
| 1 |
# @package _global_
|
| 2 |
effects:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
Chorus:
|
| 9 |
-
_target_: remfx.effects.RandomPedalboardChorus
|
| 10 |
-
sample_rate: ${sample_rate}
|
| 11 |
-
min_rate_hz: 1.0
|
| 12 |
-
max_rate_hz: 1.0
|
| 13 |
-
min_depth: 0.3
|
| 14 |
-
max_depth: 0.3
|
| 15 |
-
min_centre_delay_ms: 7.5
|
| 16 |
-
max_centre_delay_ms: 7.5
|
| 17 |
-
min_feedback: 0.4
|
| 18 |
-
max_feedback: 0.4
|
| 19 |
-
min_mix: 0.4
|
| 20 |
-
max_mix: 0.4
|
|
|
|
| 1 |
# @package _global_
|
| 2 |
effects:
|
| 3 |
+
Chorus:
|
| 4 |
+
_target_: remfx.effects.RandomPedalboardChorus
|
| 5 |
+
sample_rate: ${sample_rate}
|
| 6 |
+
min_depth: 0.2
|
| 7 |
+
min_mix: 0.3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cfg/effects/compression.yaml
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
# @package _global_
|
| 2 |
-
effects:
|
| 3 |
-
train_effects:
|
| 4 |
-
Compressor:
|
| 5 |
-
_target_: remfx.effects.RandomPedalboardCompressor
|
| 6 |
-
sample_rate: ${sample_rate}
|
| 7 |
-
min_threshold_db: -42.0
|
| 8 |
-
max_threshold_db: -20.0
|
| 9 |
-
min_ratio: 1.5
|
| 10 |
-
max_ratio: 6.0
|
| 11 |
-
val_effects:
|
| 12 |
-
Compressor:
|
| 13 |
-
_target_: remfx.effects.RandomPedalboardCompressor
|
| 14 |
-
sample_rate: ${sample_rate}
|
| 15 |
-
min_threshold_db: -32
|
| 16 |
-
max_threshold_db: -32
|
| 17 |
-
min_ratio: 3.0
|
| 18 |
-
max_ratio: 3.0
|
| 19 |
-
min_attack_ms: 10.0
|
| 20 |
-
max_attack_ms: 10.0
|
| 21 |
-
min_release_ms: 40.0
|
| 22 |
-
max_release_ms: 40.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cfg/effects/compressor.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
effects:
|
| 3 |
+
Compressor:
|
| 4 |
+
_target_: remfx.effects.RandomPedalboardCompressor
|
| 5 |
+
sample_rate: ${sample_rate}
|
| 6 |
+
min_threshold_db: -42.0
|
| 7 |
+
max_threshold_db: -20.0
|
| 8 |
+
min_ratio: 1.5
|
| 9 |
+
max_ratio: 6.0
|
cfg/effects/distortion.yaml
CHANGED
|
@@ -1,14 +1,7 @@
|
|
| 1 |
# @package _global_
|
| 2 |
effects:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
max_drive_db: 50
|
| 9 |
-
val_effects:
|
| 10 |
-
Distortion:
|
| 11 |
-
_target_: remfx.effects.RandomPedalboardDistortion
|
| 12 |
-
sample_rate: ${sample_rate}
|
| 13 |
-
min_drive_db: 30
|
| 14 |
-
max_drive_db: 30
|
|
|
|
| 1 |
# @package _global_
|
| 2 |
effects:
|
| 3 |
+
Distortion:
|
| 4 |
+
_target_: remfx.effects.RandomPedalboardDistortion
|
| 5 |
+
sample_rate: ${sample_rate}
|
| 6 |
+
min_drive_db: 10
|
| 7 |
+
max_drive_db: 50
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cfg/effects/reverb.yaml
CHANGED
|
@@ -1,26 +1,13 @@
|
|
| 1 |
# @package _global_
|
| 2 |
effects:
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
max_width: 1.0
|
| 15 |
-
val_effects:
|
| 16 |
-
Reverb:
|
| 17 |
-
_target_: remfx.effects.RandomPedalboardReverb
|
| 18 |
-
sample_rate: ${sample_rate}
|
| 19 |
-
min_room_size: 0.5
|
| 20 |
-
max_room_size: 0.5
|
| 21 |
-
min_damping: 0.5
|
| 22 |
-
max_damping: 0.5
|
| 23 |
-
min_wet_dry: 0.4
|
| 24 |
-
max_wet_dry: 0.4
|
| 25 |
-
min_width: 0.5
|
| 26 |
-
max_width: 0.5
|
|
|
|
| 1 |
# @package _global_
|
| 2 |
effects:
|
| 3 |
+
Reverb:
|
| 4 |
+
_target_: remfx.effects.RandomPedalboardReverb
|
| 5 |
+
sample_rate: ${sample_rate}
|
| 6 |
+
min_room_size: 0.3
|
| 7 |
+
max_room_size: 1.0
|
| 8 |
+
min_damping: 0.2
|
| 9 |
+
max_damping: 1.0
|
| 10 |
+
min_wet_dry: 0.2
|
| 11 |
+
max_wet_dry: 0.8
|
| 12 |
+
min_width: 0.2
|
| 13 |
+
max_width: 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cfg/exp/{demucs_compression.yaml → demucs_compressor.yaml}
RENAMED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
# @package _global_
|
| 2 |
defaults:
|
| 3 |
- override /model: demucs
|
| 4 |
-
- override /effects: compression
|
|
|
|
| 1 |
# @package _global_
|
| 2 |
defaults:
|
| 3 |
- override /model: demucs
|
| 4 |
+
- override /effects: compressor
|
cfg/exp/{umx_compression.yaml → umx_compressor.yaml}
RENAMED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
# @package _global_
|
| 2 |
defaults:
|
| 3 |
- override /model: umx
|
| 4 |
-
- override /effects: compression
|
|
|
|
| 1 |
# @package _global_
|
| 2 |
defaults:
|
| 3 |
- override /model: umx
|
| 4 |
+
- override /effects: compressor
|
remfx/datasets.py
CHANGED
|
@@ -17,7 +17,7 @@ class VocalSet(Dataset):
|
|
| 17 |
self,
|
| 18 |
root: str,
|
| 19 |
sample_rate: int,
|
| 20 |
-
|
| 21 |
effect_types: List[torch.nn.Module] = None,
|
| 22 |
render_files: bool = True,
|
| 23 |
render_root: str = None,
|
|
@@ -28,7 +28,7 @@ class VocalSet(Dataset):
|
|
| 28 |
self.song_idx = []
|
| 29 |
self.root = Path(root)
|
| 30 |
self.render_root = Path(render_root)
|
| 31 |
-
self.
|
| 32 |
self.sample_rate = sample_rate
|
| 33 |
self.mode = mode
|
| 34 |
|
|
@@ -36,9 +36,11 @@ class VocalSet(Dataset):
|
|
| 36 |
self.files = sorted(list(mode_path.glob("./**/*.wav")))
|
| 37 |
self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
|
| 38 |
self.effect_types = effect_types
|
| 39 |
-
|
| 40 |
-
self.processed_root = self.render_root / "processed" / self.mode
|
| 41 |
-
|
|
|
|
|
|
|
| 42 |
self.num_chunks = 0
|
| 43 |
print("Total files:", len(self.files))
|
| 44 |
print("Processing files...")
|
|
@@ -46,19 +48,14 @@ class VocalSet(Dataset):
|
|
| 46 |
# Split audio file into chunks, resample, then apply random effects
|
| 47 |
self.processed_root.mkdir(parents=True, exist_ok=True)
|
| 48 |
for audio_file in tqdm(self.files, total=len(self.files)):
|
| 49 |
-
chunks, orig_sr = create_sequential_chunks(
|
| 50 |
-
audio_file, self.chunk_size_in_sec
|
| 51 |
-
)
|
| 52 |
for chunk in chunks:
|
| 53 |
resampled_chunk = torchaudio.functional.resample(
|
| 54 |
chunk, orig_sr, sample_rate
|
| 55 |
)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
resampled_chunk,
|
| 60 |
-
(0, chunk_size_in_samples - resampled_chunk.shape[1]),
|
| 61 |
-
)
|
| 62 |
# Apply effect
|
| 63 |
effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
|
| 64 |
effect_name = list(self.effect_types.keys())[int(effect_idx)]
|
|
|
|
| 17 |
self,
|
| 18 |
root: str,
|
| 19 |
sample_rate: int,
|
| 20 |
+
chunk_size: int = 3,
|
| 21 |
effect_types: List[torch.nn.Module] = None,
|
| 22 |
render_files: bool = True,
|
| 23 |
render_root: str = None,
|
|
|
|
| 28 |
self.song_idx = []
|
| 29 |
self.root = Path(root)
|
| 30 |
self.render_root = Path(render_root)
|
| 31 |
+
self.chunk_size = chunk_size
|
| 32 |
self.sample_rate = sample_rate
|
| 33 |
self.mode = mode
|
| 34 |
|
|
|
|
| 36 |
self.files = sorted(list(mode_path.glob("./**/*.wav")))
|
| 37 |
self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
|
| 38 |
self.effect_types = effect_types
|
| 39 |
+
effect_str = "_".join([e for e in self.effect_types])
|
| 40 |
+
self.processed_root = self.render_root / "processed" / effect_str / self.mode
|
| 41 |
+
if self.processed_root.exists():
|
| 42 |
+
print("Found processed files.")
|
| 43 |
+
render_files = False
|
| 44 |
self.num_chunks = 0
|
| 45 |
print("Total files:", len(self.files))
|
| 46 |
print("Processing files...")
|
|
|
|
| 48 |
# Split audio file into chunks, resample, then apply random effects
|
| 49 |
self.processed_root.mkdir(parents=True, exist_ok=True)
|
| 50 |
for audio_file in tqdm(self.files, total=len(self.files)):
|
| 51 |
+
chunks, orig_sr = create_sequential_chunks(audio_file, self.chunk_size)
|
|
|
|
|
|
|
| 52 |
for chunk in chunks:
|
| 53 |
resampled_chunk = torchaudio.functional.resample(
|
| 54 |
chunk, orig_sr, sample_rate
|
| 55 |
)
|
| 56 |
+
if resampled_chunk.shape[-1] < chunk_size:
|
| 57 |
+
# Skip if chunk is too small
|
| 58 |
+
continue
|
|
|
|
|
|
|
|
|
|
| 59 |
# Apply effect
|
| 60 |
effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
|
| 61 |
effect_name = list(self.effect_types.keys())[int(effect_idx)]
|
remfx/models.py
CHANGED
|
@@ -55,6 +55,29 @@ class RemFXModel(pl.LightningModule):
|
|
| 55 |
)
|
| 56 |
return optimizer
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
def training_step(self, batch, batch_idx):
|
| 59 |
loss = self.common_step(batch, batch_idx, mode="train")
|
| 60 |
return loss
|
|
@@ -215,7 +238,7 @@ class OpenUnmixModel(torch.nn.Module):
|
|
| 215 |
X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
|
| 216 |
Y = self.model(X)
|
| 217 |
sep_out = self.separator(x).squeeze(1)
|
| 218 |
-
loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target)
|
| 219 |
|
| 220 |
return loss, sep_out
|
| 221 |
|
|
@@ -236,7 +259,7 @@ class DemucsModel(torch.nn.Module):
|
|
| 236 |
def forward(self, batch):
|
| 237 |
x, target, label = batch
|
| 238 |
output = self.model(x).squeeze(1)
|
| 239 |
-
loss = self.mrstftloss(output, target) + self.l1loss(output, target)
|
| 240 |
return loss, output
|
| 241 |
|
| 242 |
def sample(self, x: Tensor) -> Tensor:
|
|
@@ -264,10 +287,13 @@ def log_wandb_audio_batch(
|
|
| 264 |
samples: Tensor,
|
| 265 |
sampling_rate: int,
|
| 266 |
caption: str = "",
|
|
|
|
| 267 |
):
|
| 268 |
num_items = samples.shape[0]
|
| 269 |
samples = rearrange(samples, "b c t -> b t c")
|
| 270 |
for idx in range(num_items):
|
|
|
|
|
|
|
| 271 |
logger.experiment.log(
|
| 272 |
{
|
| 273 |
f"{id}_{idx}": wandb.Audio(
|
|
|
|
| 55 |
)
|
| 56 |
return optimizer
|
| 57 |
|
| 58 |
+
# Add step-based learning rate scheduler
|
| 59 |
+
def optimizer_step(
|
| 60 |
+
self,
|
| 61 |
+
epoch,
|
| 62 |
+
batch_idx,
|
| 63 |
+
optimizer,
|
| 64 |
+
optimizer_idx,
|
| 65 |
+
optimizer_closure,
|
| 66 |
+
on_tpu,
|
| 67 |
+
using_native_amp,
|
| 68 |
+
using_lbfgs,
|
| 69 |
+
):
|
| 70 |
+
# update params
|
| 71 |
+
optimizer.step(closure=optimizer_closure)
|
| 72 |
+
|
| 73 |
+
# update learning rate. Reduce by factor of 10 at 80% and 95% of training
|
| 74 |
+
if self.trainer.global_step == 0.8 * self.trainer.max_steps:
|
| 75 |
+
for pg in optimizer.param_groups:
|
| 76 |
+
pg["lr"] = 0.1 * pg["lr"]
|
| 77 |
+
if self.trainer.global_step == 0.95 * self.trainer.max_steps:
|
| 78 |
+
for pg in optimizer.param_groups:
|
| 79 |
+
pg["lr"] = 0.1 * pg["lr"]
|
| 80 |
+
|
| 81 |
def training_step(self, batch, batch_idx):
|
| 82 |
loss = self.common_step(batch, batch_idx, mode="train")
|
| 83 |
return loss
|
|
|
|
| 238 |
X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
|
| 239 |
Y = self.model(X)
|
| 240 |
sep_out = self.separator(x).squeeze(1)
|
| 241 |
+
loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target) * 100
|
| 242 |
|
| 243 |
return loss, sep_out
|
| 244 |
|
|
|
|
| 259 |
def forward(self, batch):
|
| 260 |
x, target, label = batch
|
| 261 |
output = self.model(x).squeeze(1)
|
| 262 |
+
loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
|
| 263 |
return loss, output
|
| 264 |
|
| 265 |
def sample(self, x: Tensor) -> Tensor:
|
|
|
|
| 287 |
samples: Tensor,
|
| 288 |
sampling_rate: int,
|
| 289 |
caption: str = "",
|
| 290 |
+
max_items: int = 10,
|
| 291 |
):
|
| 292 |
num_items = samples.shape[0]
|
| 293 |
samples = rearrange(samples, "b c t -> b t c")
|
| 294 |
for idx in range(num_items):
|
| 295 |
+
if idx >= max_items:
|
| 296 |
+
break
|
| 297 |
logger.experiment.log(
|
| 298 |
{
|
| 299 |
f"{id}_{idx}": wandb.Audio(
|
remfx/utils.py
CHANGED
|
@@ -132,10 +132,9 @@ def create_sequential_chunks(
|
|
| 132 |
"""
|
| 133 |
chunks = []
|
| 134 |
audio, sr = torchaudio.load(audio_file)
|
| 135 |
-
|
| 136 |
-
chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
|
| 137 |
for start in chunk_starts:
|
| 138 |
-
if start + chunk_size_in_samples > audio.shape[-1]:
|
| 139 |
break
|
| 140 |
-
chunks.append(audio[:, start : start + chunk_size_in_samples])
|
| 141 |
return chunks, sr
|
|
|
|
| 132 |
"""
|
| 133 |
chunks = []
|
| 134 |
audio, sr = torchaudio.load(audio_file)
|
| 135 |
+
chunk_starts = torch.arange(0, audio.shape[-1], chunk_size)
|
|
|
|
| 136 |
for start in chunk_starts:
|
| 137 |
+
if start + chunk_size > audio.shape[-1]:
|
| 138 |
break
|
| 139 |
+
chunks.append(audio[:, start : start + chunk_size])
|
| 140 |
return chunks, sr
|
scripts/test.py
CHANGED
|
@@ -14,7 +14,6 @@ def main(cfg: DictConfig):
|
|
| 14 |
# Apply seed for reproducibility
|
| 15 |
if cfg.seed:
|
| 16 |
pl.seed_everything(cfg.seed)
|
| 17 |
-
cfg.render_files = False
|
| 18 |
log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
|
| 19 |
datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
|
| 20 |
log.info(f"Instantiating model <{cfg.model._target_}>.")
|
|
|
|
| 14 |
# Apply seed for reproducibility
|
| 15 |
if cfg.seed:
|
| 16 |
pl.seed_everything(cfg.seed)
|
|
|
|
| 17 |
log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
|
| 18 |
datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
|
| 19 |
log.info(f"Instantiating model <{cfg.model._target_}>.")
|
scripts/train.py
CHANGED
|
@@ -42,6 +42,7 @@ def main(cfg: DictConfig):
|
|
| 42 |
summary = ModelSummary(model)
|
| 43 |
print(summary)
|
| 44 |
trainer.fit(model=model, datamodule=datamodule)
|
|
|
|
| 45 |
|
| 46 |
|
| 47 |
if __name__ == "__main__":
|
|
|
|
| 42 |
summary = ModelSummary(model)
|
| 43 |
print(summary)
|
| 44 |
trainer.fit(model=model, datamodule=datamodule)
|
| 45 |
+
trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
|
| 46 |
|
| 47 |
|
| 48 |
if __name__ == "__main__":
|