Spaces:
Sleeping
Sleeping
Upload 55 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- configs/callbacks/default.yaml +20 -0
- configs/callbacks/none.yaml +0 -0
- configs/callbacks/wandb.yaml +31 -0
- configs/config.yaml +44 -0
- configs/datamodule/musdb18_hq.yaml +50 -0
- configs/datamodule/musdb_dev14.yaml +28 -0
- configs/evaluation.yaml +43 -0
- configs/experiment/bass_dis.yaml +38 -0
- configs/experiment/drums_dis.yaml +38 -0
- configs/experiment/multigpu_default.yaml +26 -0
- configs/experiment/other_dis.yaml +38 -0
- configs/experiment/vocals_dis.yaml +38 -0
- configs/hydra/default.yaml +16 -0
- configs/infer.yaml +26 -0
- configs/logger/csv.yaml +8 -0
- configs/logger/many_loggers.yaml +10 -0
- configs/logger/neptune.yaml +11 -0
- configs/logger/none.yaml +0 -0
- configs/logger/tensorboard.yaml +10 -0
- configs/logger/wandb.yaml +15 -0
- configs/model/bass.yaml +28 -0
- configs/model/drums.yaml +28 -0
- configs/model/other.yaml +28 -0
- configs/model/vocals.yaml +28 -0
- configs/paths/default.yaml +18 -0
- configs/trainer/ddp.yaml +13 -0
- configs/trainer/default.yaml +19 -0
- configs/trainer/minimal.yaml +21 -0
- src/__init__.py +0 -0
- src/callbacks/__init__.py +0 -0
- src/callbacks/onnx_callback.py +49 -0
- src/callbacks/wandb_callbacks.py +280 -0
- src/datamodules/__init__.py +0 -0
- src/datamodules/datasets/__init__.py +0 -0
- src/datamodules/datasets/musdb.py +174 -0
- src/datamodules/musdb_datamodule.py +117 -0
- src/dp_tdf/__init__.py +0 -0
- src/dp_tdf/abstract.py +204 -0
- src/dp_tdf/bandsequence.py +136 -0
- src/dp_tdf/dp_tdf_net.py +118 -0
- src/dp_tdf/modules.py +158 -0
- src/evaluation/eval.py +120 -0
- src/evaluation/eval_demo.py +71 -0
- src/evaluation/separate.py +193 -0
- src/layers/__init__.py +2 -0
- src/layers/batch_norm.py +201 -0
- src/layers/chunk_size.py +53 -0
- src/train.py +152 -0
- src/utils/__init__.py +3 -0
- src/utils/data_augmentation.py +128 -0
configs/callbacks/default.yaml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_checkpoint:
|
| 2 |
+
_target_: pytorch_lightning.callbacks.ModelCheckpoint
|
| 3 |
+
monitor: "val/usdr" # name of the logged metric which determines when model is improving
|
| 4 |
+
save_top_k: 5 # save k best models (determined by above metric)
|
| 5 |
+
save_last: True # additionaly always save model from last epoch
|
| 6 |
+
mode: "max" # can be "max" or "min"
|
| 7 |
+
verbose: False
|
| 8 |
+
dirpath: "checkpoints/"
|
| 9 |
+
filename: "{epoch:02d}-{step}"
|
| 10 |
+
#
|
| 11 |
+
#early_stopping:
|
| 12 |
+
# _target_: pytorch_lightning.callbacks.EarlyStopping
|
| 13 |
+
# monitor: "val/sdr" # name of the logged metric which determines when model is improving
|
| 14 |
+
# patience: 300 # how many epochs of not improving until training stops
|
| 15 |
+
# mode: "max" # can be "max" or "min"
|
| 16 |
+
# min_delta: 0.05 # minimum change in the monitored metric needed to qualify as an improvement
|
| 17 |
+
|
| 18 |
+
#make_onnx:
|
| 19 |
+
# _target_: src.callbacks.onnx_callback.MakeONNXCallback
|
| 20 |
+
# dirpath: "onnx/"
|
configs/callbacks/none.yaml
ADDED
|
File without changes
|
configs/callbacks/wandb.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default.yaml
|
| 3 |
+
|
| 4 |
+
watch_model:
|
| 5 |
+
_target_: src.callbacks.wandb_callbacks.WatchModel
|
| 6 |
+
log: "all"
|
| 7 |
+
log_freq: 100
|
| 8 |
+
|
| 9 |
+
#upload_valid_track:
|
| 10 |
+
# _target_: src.callbacks.wandb_callbacks.UploadValidTrack
|
| 11 |
+
# crop: 3
|
| 12 |
+
# upload_after_n_epoch: -1
|
| 13 |
+
|
| 14 |
+
#upload_code_as_artifact:
|
| 15 |
+
# _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact
|
| 16 |
+
# code_dir: ${work_dir}/src
|
| 17 |
+
#
|
| 18 |
+
#upload_ckpts_as_artifact:
|
| 19 |
+
# _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
|
| 20 |
+
# ckpt_dir: "checkpoints/"
|
| 21 |
+
# upload_best_only: True
|
| 22 |
+
#
|
| 23 |
+
#log_f1_precision_recall_heatmap:
|
| 24 |
+
# _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap
|
| 25 |
+
#
|
| 26 |
+
#log_confusion_matrix:
|
| 27 |
+
# _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix
|
| 28 |
+
#
|
| 29 |
+
#log_image_predictions:
|
| 30 |
+
# _target_: src.callbacks.wandb_callbacks.LogImagePredictions
|
| 31 |
+
# num_samples: 8
|
configs/config.yaml
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# specify here default training configuration
|
| 4 |
+
defaults:
|
| 5 |
+
- datamodule: musdb18_hq
|
| 6 |
+
- model: null
|
| 7 |
+
- callbacks: default # set this to null if you don't want to use callbacks
|
| 8 |
+
- logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`)
|
| 9 |
+
- trainer: default
|
| 10 |
+
- hparams_search: null
|
| 11 |
+
- paths: default.yaml
|
| 12 |
+
|
| 13 |
+
- hydra: default
|
| 14 |
+
|
| 15 |
+
- experiment: null
|
| 16 |
+
|
| 17 |
+
# enable color logging
|
| 18 |
+
- override hydra/hydra_logging: colorlog
|
| 19 |
+
- override hydra/job_logging: colorlog
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# path to original working directory
|
| 23 |
+
# hydra hijacks working directory by changing it to the current log directory,
|
| 24 |
+
# so it's useful to have this path as a special variable
|
| 25 |
+
# learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
|
| 26 |
+
#work_dir: ${hydra:runtime.cwd}
|
| 27 |
+
#output_dir: ${hydra:runtime.output_dir}
|
| 28 |
+
|
| 29 |
+
# path to folder with data
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# use `python run.py debug=true` for easy debugging!
|
| 33 |
+
# this will run 1 train, val and test loop with only 1 batch
|
| 34 |
+
# equivalent to running `python run.py trainer.fast_dev_run=true`
|
| 35 |
+
# (this is placed here just for easier access from command line)
|
| 36 |
+
debug: False
|
| 37 |
+
|
| 38 |
+
# pretty print config at the start of the run using Rich library
|
| 39 |
+
print_config: True
|
| 40 |
+
|
| 41 |
+
# disable python warnings if they annoy you
|
| 42 |
+
ignore_warnings: True
|
| 43 |
+
|
| 44 |
+
wandb_api_key: ${oc.env:wandb_api_key}
|
configs/datamodule/musdb18_hq.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.datamodules.musdb_datamodule.MusdbDataModule
|
| 2 |
+
|
| 3 |
+
# data_dir is specified in config.yaml
|
| 4 |
+
data_dir: null
|
| 5 |
+
|
| 6 |
+
single_channel: False
|
| 7 |
+
|
| 8 |
+
# chunk_size = (hop_length * (dim_t - 1) / sample_rate) secs
|
| 9 |
+
sample_rate: 44100
|
| 10 |
+
hop_length: ${model.hop_length} # stft hop_length
|
| 11 |
+
dim_t: ${model.dim_t} # number of stft frames
|
| 12 |
+
|
| 13 |
+
# number of overlapping wave samples between chunks when separating a whole track
|
| 14 |
+
overlap: ${model.overlap}
|
| 15 |
+
|
| 16 |
+
source_names:
|
| 17 |
+
- bass
|
| 18 |
+
- drums
|
| 19 |
+
- other
|
| 20 |
+
- vocals
|
| 21 |
+
target_name: ${model.target_name}
|
| 22 |
+
|
| 23 |
+
external_datasets: null
|
| 24 |
+
#external_datasets:
|
| 25 |
+
# - test
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
batch_size: 8
|
| 29 |
+
num_workers: 0
|
| 30 |
+
pin_memory: False
|
| 31 |
+
|
| 32 |
+
aug_params:
|
| 33 |
+
- 2 # maximum pitch shift in semitones (-x < shift param < x)
|
| 34 |
+
- 20 # maximum time stretch percentage (-x < stretch param < x)
|
| 35 |
+
|
| 36 |
+
validation_set:
|
| 37 |
+
- Actions - One Minute Smile
|
| 38 |
+
- Clara Berry And Wooldog - Waltz For My Victims
|
| 39 |
+
- Johnny Lokke - Promises & Lies
|
| 40 |
+
- Patrick Talbot - A Reason To Leave
|
| 41 |
+
- Triviul - Angelsaint
|
| 42 |
+
# - Alexander Ross - Goodbye Bolero
|
| 43 |
+
# - Fergessen - Nos Palpitants
|
| 44 |
+
# - Leaf - Summerghost
|
| 45 |
+
# - Skelpolu - Human Mistakes
|
| 46 |
+
# - Young Griffo - Pennies
|
| 47 |
+
# - ANiMAL - Rockshow
|
| 48 |
+
# - James May - On The Line
|
| 49 |
+
# - Meaxic - Take A Step
|
| 50 |
+
# - Traffic Experiment - Sirens
|
configs/datamodule/musdb_dev14.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
defaults:
|
| 3 |
+
- musdb18_hq
|
| 4 |
+
|
| 5 |
+
data_dir: ${oc.env:data_dir}
|
| 6 |
+
|
| 7 |
+
has_split_structure: True
|
| 8 |
+
|
| 9 |
+
validation_set:
|
| 10 |
+
# - Meaxic - Take A Step
|
| 11 |
+
# - Skelpolu - Human Mistakes
|
| 12 |
+
- Actions - One Minute Smile
|
| 13 |
+
- Clara Berry And Wooldog - Waltz For My Victims
|
| 14 |
+
- Johnny Lokke - Promises & Lies
|
| 15 |
+
- Patrick Talbot - A Reason To Leave
|
| 16 |
+
- Triviul - Angelsaint
|
| 17 |
+
- Alexander Ross - Goodbye Bolero
|
| 18 |
+
- Fergessen - Nos Palpitants
|
| 19 |
+
- Leaf - Summerghost
|
| 20 |
+
- Skelpolu - Human Mistakes
|
| 21 |
+
- Young Griffo - Pennies
|
| 22 |
+
- ANiMAL - Rockshow
|
| 23 |
+
- James May - On The Line
|
| 24 |
+
- Meaxic - Take A Step
|
| 25 |
+
- Traffic Experiment - Sirens
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
mode: musdb18hq
|
configs/evaluation.yaml
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# specify here default training configuration
|
| 4 |
+
defaults:
|
| 5 |
+
- model: ConvTDFNet_vocals
|
| 6 |
+
- logger:
|
| 7 |
+
- wandb
|
| 8 |
+
- tensorboard
|
| 9 |
+
- paths: default.yaml
|
| 10 |
+
# enable color logging
|
| 11 |
+
- override hydra/hydra_logging: colorlog
|
| 12 |
+
- override hydra/job_logging: colorlog
|
| 13 |
+
|
| 14 |
+
hydra:
|
| 15 |
+
run:
|
| 16 |
+
dir: ${get_eval_log_dir:${ckpt_path}}
|
| 17 |
+
|
| 18 |
+
#ckpt_path: "G:\\Experiments\\KLRef\\vocals.onnx"
|
| 19 |
+
ckpt_path: ${oc.env:ckpt_path}
|
| 20 |
+
|
| 21 |
+
split: 'test'
|
| 22 |
+
batch_size: 4
|
| 23 |
+
device: 'cuda:0'
|
| 24 |
+
bss: fast # fast or official
|
| 25 |
+
single: False # for debug investigation, only run the model on 1 single song
|
| 26 |
+
|
| 27 |
+
#data_dir: ${oc.env:data_dir}
|
| 28 |
+
eval_dir: ${oc.env:data_dir}
|
| 29 |
+
wandb_api_key: ${oc.env:wandb_api_key}
|
| 30 |
+
|
| 31 |
+
logger:
|
| 32 |
+
wandb:
|
| 33 |
+
# project: mdx_eval_${split}
|
| 34 |
+
project: new_eval_order
|
| 35 |
+
name: ${get_eval_log_dir:${ckpt_path}}
|
| 36 |
+
|
| 37 |
+
pool_workers: 8
|
| 38 |
+
double_chunk: False
|
| 39 |
+
|
| 40 |
+
overlap_add:
|
| 41 |
+
overlap_rate: 0.5
|
| 42 |
+
tmp_root: ${paths.root_dir}/tmp # for saving temp chunks, since we use ffmpeg and will need io to disk
|
| 43 |
+
samplerate: 44100
|
configs/experiment/bass_dis.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python run.py experiment=example_simple.yaml
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- multigpu_default
|
| 8 |
+
- override /model: bass.yaml
|
| 9 |
+
|
| 10 |
+
seed: 2021
|
| 11 |
+
|
| 12 |
+
exp_name: bass_g32
|
| 13 |
+
|
| 14 |
+
# the name inside project
|
| 15 |
+
logger:
|
| 16 |
+
wandb:
|
| 17 |
+
name: ${exp_name}
|
| 18 |
+
|
| 19 |
+
model:
|
| 20 |
+
lr: 0.0002
|
| 21 |
+
optimizer: adamW
|
| 22 |
+
bn_norm: syncBN
|
| 23 |
+
audio_ch: 2 # datamodule.single_channel
|
| 24 |
+
g: 32
|
| 25 |
+
|
| 26 |
+
trainer:
|
| 27 |
+
devices: 2 # int or list
|
| 28 |
+
sync_batchnorm: True
|
| 29 |
+
track_grad_norm: 2
|
| 30 |
+
# gradient_clip_val: 5
|
| 31 |
+
|
| 32 |
+
datamodule:
|
| 33 |
+
batch_size: 8
|
| 34 |
+
num_workers: ${oc.decode:${oc.env:NUM_WORKERS}}
|
| 35 |
+
pin_memory: False
|
| 36 |
+
overlap: ${model.overlap}
|
| 37 |
+
audio_ch: ${model.audio_ch}
|
| 38 |
+
epoch_size:
|
configs/experiment/drums_dis.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python run.py experiment=example_simple.yaml
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- multigpu_default
|
| 8 |
+
- override /model: drums.yaml
|
| 9 |
+
|
| 10 |
+
seed: 2021
|
| 11 |
+
|
| 12 |
+
exp_name: drums_g32
|
| 13 |
+
|
| 14 |
+
# the name inside project
|
| 15 |
+
logger:
|
| 16 |
+
wandb:
|
| 17 |
+
name: ${exp_name}
|
| 18 |
+
|
| 19 |
+
model:
|
| 20 |
+
lr: 0.0002
|
| 21 |
+
optimizer: adamW
|
| 22 |
+
bn_norm: syncBN
|
| 23 |
+
audio_ch: 2 # datamodule.single_channel
|
| 24 |
+
g: 32
|
| 25 |
+
|
| 26 |
+
trainer:
|
| 27 |
+
devices: 2 # int or list
|
| 28 |
+
sync_batchnorm: True
|
| 29 |
+
track_grad_norm: 2
|
| 30 |
+
# gradient_clip_val: 5
|
| 31 |
+
|
| 32 |
+
datamodule:
|
| 33 |
+
batch_size: 8
|
| 34 |
+
num_workers: ${oc.decode:${oc.env:NUM_WORKERS}}
|
| 35 |
+
pin_memory: False
|
| 36 |
+
overlap: ${model.overlap}
|
| 37 |
+
audio_ch: ${model.audio_ch}
|
| 38 |
+
epoch_size:
|
configs/experiment/multigpu_default.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python run.py experiment=example_simple.yaml
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- override /callbacks: default
|
| 8 |
+
- override /logger:
|
| 9 |
+
- wandb
|
| 10 |
+
- tensorboard
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
#callbacks:
|
| 14 |
+
# early_stopping:
|
| 15 |
+
# patience: 1000000
|
| 16 |
+
|
| 17 |
+
#datamodule:
|
| 18 |
+
# external_datasets:
|
| 19 |
+
# - test
|
| 20 |
+
|
| 21 |
+
trainer:
|
| 22 |
+
max_epochs: 1000000
|
| 23 |
+
accelerator: cuda
|
| 24 |
+
amp_backend: native
|
| 25 |
+
precision: 16
|
| 26 |
+
track_grad_norm: -1
|
configs/experiment/other_dis.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python run.py experiment=example_simple.yaml
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- multigpu_default
|
| 8 |
+
- override /model: other.yaml
|
| 9 |
+
|
| 10 |
+
seed: 2021
|
| 11 |
+
|
| 12 |
+
exp_name: other_g32
|
| 13 |
+
|
| 14 |
+
# the name inside project
|
| 15 |
+
logger:
|
| 16 |
+
wandb:
|
| 17 |
+
name: ${exp_name}
|
| 18 |
+
|
| 19 |
+
model:
|
| 20 |
+
lr: 0.0002
|
| 21 |
+
optimizer: adamW
|
| 22 |
+
bn_norm: syncBN
|
| 23 |
+
audio_ch: 2 # datamodule.single_channel
|
| 24 |
+
g: 32
|
| 25 |
+
|
| 26 |
+
trainer:
|
| 27 |
+
devices: 2 # int or list
|
| 28 |
+
sync_batchnorm: True
|
| 29 |
+
track_grad_norm: 2
|
| 30 |
+
# gradient_clip_val: 5
|
| 31 |
+
|
| 32 |
+
datamodule:
|
| 33 |
+
batch_size: 8
|
| 34 |
+
num_workers: ${oc.decode:${oc.env:NUM_WORKERS}}
|
| 35 |
+
pin_memory: False
|
| 36 |
+
overlap: ${model.overlap}
|
| 37 |
+
audio_ch: ${model.audio_ch}
|
| 38 |
+
epoch_size:
|
configs/experiment/vocals_dis.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# to execute this experiment run:
|
| 4 |
+
# python run.py experiment=example_simple.yaml
|
| 5 |
+
|
| 6 |
+
defaults:
|
| 7 |
+
- multigpu_default
|
| 8 |
+
- override /model: vocals.yaml
|
| 9 |
+
|
| 10 |
+
seed: 2021
|
| 11 |
+
|
| 12 |
+
exp_name: vocals_g32
|
| 13 |
+
|
| 14 |
+
# the name inside project
|
| 15 |
+
logger:
|
| 16 |
+
wandb:
|
| 17 |
+
name: ${exp_name}
|
| 18 |
+
|
| 19 |
+
model:
|
| 20 |
+
lr: 0.0002
|
| 21 |
+
optimizer: adamW
|
| 22 |
+
bn_norm: syncBN
|
| 23 |
+
audio_ch: 2 # datamodule.single_channel
|
| 24 |
+
g: 32
|
| 25 |
+
|
| 26 |
+
trainer:
|
| 27 |
+
devices: 2 # int or list
|
| 28 |
+
sync_batchnorm: True
|
| 29 |
+
track_grad_norm: 2
|
| 30 |
+
# gradient_clip_val: 5
|
| 31 |
+
|
| 32 |
+
datamodule:
|
| 33 |
+
batch_size: 8
|
| 34 |
+
num_workers: ${oc.decode:${oc.env:NUM_WORKERS}}
|
| 35 |
+
pin_memory: False
|
| 36 |
+
overlap: ${model.overlap}
|
| 37 |
+
audio_ch: ${model.audio_ch}
|
| 38 |
+
epoch_size:
|
configs/hydra/default.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# output paths for hydra logs
|
| 2 |
+
run:
|
| 3 |
+
# dir: logs/runs/${datamodule.target_name}_${exp_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
| 4 |
+
dir: ${get_train_log_dir:${datamodule.target_name},${exp_name}}
|
| 5 |
+
|
| 6 |
+
sweep:
|
| 7 |
+
# dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S}
|
| 8 |
+
dir: ${get_sweep_log_dir:${datamodule.target_name},${exp_name}}
|
| 9 |
+
subdir: ${hydra.job.num}
|
| 10 |
+
|
| 11 |
+
# you can set here environment variables that are universal for all users
|
| 12 |
+
# for system specific variables (like data paths) it's better to use .env file!
|
| 13 |
+
job:
|
| 14 |
+
env_set:
|
| 15 |
+
EXAMPLE_VAR: "example_value"
|
| 16 |
+
|
configs/infer.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _global_
|
| 2 |
+
|
| 3 |
+
# specify here default training configuration
|
| 4 |
+
defaults:
|
| 5 |
+
- model: vocals
|
| 6 |
+
- paths: default.yaml
|
| 7 |
+
# enable color logging
|
| 8 |
+
- override hydra/hydra_logging: colorlog
|
| 9 |
+
- override hydra/job_logging: colorlog
|
| 10 |
+
|
| 11 |
+
#hydra:
|
| 12 |
+
# run:
|
| 13 |
+
# dir: ${get_eval_log_dir:${ckpt_path}}
|
| 14 |
+
|
| 15 |
+
#ckpt_path: "G:\\Experiments\\KLRef\\vocals.onnx"
|
| 16 |
+
ckpt_path:
|
| 17 |
+
mixture_path:
|
| 18 |
+
batch_size: 4
|
| 19 |
+
device: 'cuda:0'
|
| 20 |
+
|
| 21 |
+
double_chunk: False
|
| 22 |
+
|
| 23 |
+
overlap_add:
|
| 24 |
+
overlap_rate: 0.5
|
| 25 |
+
tmp_root: ${paths.root_dir}/tmp # for saving temp chunks, since we use ffmpeg and will need io to disk
|
| 26 |
+
samplerate: 44100
|
configs/logger/csv.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# csv logger built in lightning
|
| 2 |
+
|
| 3 |
+
csv:
|
| 4 |
+
_target_: pytorch_lightning.loggers.csv_logs.CSVLogger
|
| 5 |
+
save_dir: "."
|
| 6 |
+
name: "csv/"
|
| 7 |
+
version: null
|
| 8 |
+
prefix: ""
|
configs/logger/many_loggers.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# train with many loggers at once
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
# - aim.yaml
|
| 5 |
+
# - comet.yaml
|
| 6 |
+
- csv.yaml
|
| 7 |
+
# - mlflow.yaml
|
| 8 |
+
# - neptune.yaml
|
| 9 |
+
# - tensorboard.yaml
|
| 10 |
+
- wandb.yaml
|
configs/logger/neptune.yaml
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://neptune.ai
|
| 2 |
+
|
| 3 |
+
neptune:
|
| 4 |
+
_target_: pytorch_lightning.loggers.neptune.NeptuneLogger
|
| 5 |
+
api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is laoded from environment variable
|
| 6 |
+
project_name: your_name/template-tests
|
| 7 |
+
close_after_fit: True
|
| 8 |
+
offline_mode: False
|
| 9 |
+
experiment_name: null
|
| 10 |
+
experiment_id: null
|
| 11 |
+
prefix: ""
|
configs/logger/none.yaml
ADDED
|
File without changes
|
configs/logger/tensorboard.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://www.tensorflow.org/tensorboard/
|
| 2 |
+
|
| 3 |
+
tensorboard:
|
| 4 |
+
_target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
|
| 5 |
+
save_dir: "tensorboard/"
|
| 6 |
+
name: "default"
|
| 7 |
+
version: null
|
| 8 |
+
log_graph: False
|
| 9 |
+
default_hp_metric: True
|
| 10 |
+
prefix: ""
|
configs/logger/wandb.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# https://wandb.ai
|
| 2 |
+
|
| 3 |
+
wandb:
|
| 4 |
+
_target_: pytorch_lightning.loggers.wandb.WandbLogger
|
| 5 |
+
project: dtt_${model.target_name}
|
| 6 |
+
name: null
|
| 7 |
+
save_dir: ${hydra:run.dir}
|
| 8 |
+
offline: False # set True to store all logs only locally
|
| 9 |
+
id: null # pass correct id to resume experiment!
|
| 10 |
+
# entity: "" # set to name of your wandb team or just remove it
|
| 11 |
+
log_model: False
|
| 12 |
+
prefix: ""
|
| 13 |
+
job_type: "train"
|
| 14 |
+
group: ""
|
| 15 |
+
tags: []
|
configs/model/bass.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.dp_tdf.dp_tdf_net.DPTDFNet
|
| 2 |
+
|
| 3 |
+
# abstract parent class
|
| 4 |
+
target_name: 'bass'
|
| 5 |
+
lr: 0.0001
|
| 6 |
+
optimizer: adamW
|
| 7 |
+
|
| 8 |
+
dim_f: 864
|
| 9 |
+
dim_t: 256
|
| 10 |
+
n_fft: 6144
|
| 11 |
+
hop_length: 1024
|
| 12 |
+
overlap: 3072
|
| 13 |
+
|
| 14 |
+
audio_ch: 2
|
| 15 |
+
|
| 16 |
+
block_type: TFC_TDF_Res2
|
| 17 |
+
num_blocks: 5
|
| 18 |
+
l: 3
|
| 19 |
+
g: 32
|
| 20 |
+
k: 3
|
| 21 |
+
bn: 2
|
| 22 |
+
bias: False
|
| 23 |
+
bn_norm: BN
|
| 24 |
+
bandsequence:
|
| 25 |
+
rnn_type: LSTM
|
| 26 |
+
bidirectional: True
|
| 27 |
+
num_layers: 4
|
| 28 |
+
n_heads: 2
|
configs/model/drums.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.dp_tdf.dp_tdf_net.DPTDFNet
|
| 2 |
+
|
| 3 |
+
# abstract parent class
|
| 4 |
+
target_name: 'drums'
|
| 5 |
+
lr: 0.0001
|
| 6 |
+
optimizer: adamW
|
| 7 |
+
|
| 8 |
+
dim_f: 2048
|
| 9 |
+
dim_t: 256
|
| 10 |
+
n_fft: 6144
|
| 11 |
+
hop_length: 1024
|
| 12 |
+
overlap: 3072
|
| 13 |
+
|
| 14 |
+
audio_ch: 2
|
| 15 |
+
|
| 16 |
+
block_type: TFC_TDF_Res2
|
| 17 |
+
num_blocks: 5
|
| 18 |
+
l: 3
|
| 19 |
+
g: 32
|
| 20 |
+
k: 3
|
| 21 |
+
bn: 8
|
| 22 |
+
bias: False
|
| 23 |
+
bn_norm: BN
|
| 24 |
+
bandsequence:
|
| 25 |
+
rnn_type: LSTM
|
| 26 |
+
bidirectional: True
|
| 27 |
+
num_layers: 4
|
| 28 |
+
n_heads: 2
|
configs/model/other.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.dp_tdf.dp_tdf_net.DPTDFNet
|
| 2 |
+
|
| 3 |
+
# abstract parent class
|
| 4 |
+
target_name: 'other'
|
| 5 |
+
lr: 0.0001
|
| 6 |
+
optimizer: adamW
|
| 7 |
+
|
| 8 |
+
dim_f: 2048
|
| 9 |
+
dim_t: 256
|
| 10 |
+
n_fft: 6144
|
| 11 |
+
hop_length: 1024
|
| 12 |
+
overlap: 3072
|
| 13 |
+
|
| 14 |
+
audio_ch: 2
|
| 15 |
+
|
| 16 |
+
block_type: TFC_TDF_Res2
|
| 17 |
+
num_blocks: 5
|
| 18 |
+
l: 3
|
| 19 |
+
g: 32
|
| 20 |
+
k: 3
|
| 21 |
+
bn: 8
|
| 22 |
+
bias: False
|
| 23 |
+
bn_norm: BN
|
| 24 |
+
bandsequence:
|
| 25 |
+
rnn_type: LSTM
|
| 26 |
+
bidirectional: True
|
| 27 |
+
num_layers: 4
|
| 28 |
+
n_heads: 2
|
configs/model/vocals.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: src.dp_tdf.dp_tdf_net.DPTDFNet
|
| 2 |
+
|
| 3 |
+
# abstract parent class
|
| 4 |
+
target_name: 'vocals'
|
| 5 |
+
lr: 0.0001
|
| 6 |
+
optimizer: adamW
|
| 7 |
+
|
| 8 |
+
dim_f: 2048
|
| 9 |
+
dim_t: 256
|
| 10 |
+
n_fft: 6144
|
| 11 |
+
hop_length: 1024
|
| 12 |
+
overlap: 3072
|
| 13 |
+
|
| 14 |
+
audio_ch: 2
|
| 15 |
+
|
| 16 |
+
block_type: TFC_TDF_Res2
|
| 17 |
+
num_blocks: 5
|
| 18 |
+
l: 3
|
| 19 |
+
g: 32
|
| 20 |
+
k: 3
|
| 21 |
+
bn: 8
|
| 22 |
+
bias: False
|
| 23 |
+
bn_norm: BN
|
| 24 |
+
bandsequence:
|
| 25 |
+
rnn_type: LSTM
|
| 26 |
+
bidirectional: True
|
| 27 |
+
num_layers: 4
|
| 28 |
+
n_heads: 2
|
configs/paths/default.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# path to root directory
|
| 2 |
+
# this requires PROJECT_ROOT environment variable to exist
|
| 3 |
+
# you can replace it with "." if you want the root to be the current working directory
|
| 4 |
+
root_dir: ${oc.env:PROJECT_ROOT}
|
| 5 |
+
|
| 6 |
+
# path to data directory
|
| 7 |
+
data_dir: ${paths.root_dir}/data/
|
| 8 |
+
|
| 9 |
+
# path to logging directory
|
| 10 |
+
log_dir: ${oc.env:LOG_DIR}
|
| 11 |
+
|
| 12 |
+
# path to output directory, created dynamically by hydra
|
| 13 |
+
# path generation pattern is specified in `configs/hydra/default.yaml`
|
| 14 |
+
# use it to store all files generated during the run, like ckpts and metrics
|
| 15 |
+
output_dir: ${hydra:runtime.output_dir}
|
| 16 |
+
|
| 17 |
+
# path to working directory
|
| 18 |
+
work_dir: ${hydra:runtime.cwd}
|
configs/trainer/ddp.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- default.yaml
|
| 3 |
+
|
| 4 |
+
# use "ddp_spawn" instead of "ddp",
|
| 5 |
+
# it's slower but normal "ddp" currently doesn't work ideally with hydra
|
| 6 |
+
# https://github.com/facebookresearch/hydra/issues/2070
|
| 7 |
+
# https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn
|
| 8 |
+
strategy: ddp_spawn
|
| 9 |
+
|
| 10 |
+
accelerator: gpu
|
| 11 |
+
devices: 2
|
| 12 |
+
num_nodes: 1
|
| 13 |
+
sync_batchnorm: True
|
configs/trainer/default.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: pytorch_lightning.Trainer
|
| 2 |
+
|
| 3 |
+
default_root_dir: ${paths.output_dir}
|
| 4 |
+
|
| 5 |
+
min_epochs: 1 # prevents early stopping
|
| 6 |
+
max_epochs: 10
|
| 7 |
+
|
| 8 |
+
accelerator: cpu
|
| 9 |
+
devices: 1
|
| 10 |
+
|
| 11 |
+
# mixed precision for extra speed-up
|
| 12 |
+
# precision: 16
|
| 13 |
+
|
| 14 |
+
# perform a validation loop every N training epochs
|
| 15 |
+
check_val_every_n_epoch: 1
|
| 16 |
+
|
| 17 |
+
# set True to to ensure deterministic results
|
| 18 |
+
# makes training slower but gives more reproducibility than just setting seeds
|
| 19 |
+
deterministic: False
|
configs/trainer/minimal.yaml
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: pytorch_lightning.Trainer
|
| 2 |
+
|
| 3 |
+
defaults:
|
| 4 |
+
- default
|
| 5 |
+
|
| 6 |
+
devices: 4
|
| 7 |
+
|
| 8 |
+
resume_from_checkpoint:
|
| 9 |
+
auto_lr_find: False
|
| 10 |
+
deterministic: True
|
| 11 |
+
accelerator: dp
|
| 12 |
+
sync_batchnorm: False
|
| 13 |
+
|
| 14 |
+
max_epochs: 3000
|
| 15 |
+
min_epochs: 1
|
| 16 |
+
check_val_every_n_epoch: 10
|
| 17 |
+
num_sanity_val_steps: 1
|
| 18 |
+
|
| 19 |
+
precision: 16
|
| 20 |
+
amp_backend: "native"
|
| 21 |
+
amp_level: "O2"
|
src/__init__.py
ADDED
|
File without changes
|
src/callbacks/__init__.py
ADDED
|
File without changes
|
src/callbacks/onnx_callback.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
from typing import Dict, Any
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from pytorch_lightning import Callback
|
| 6 |
+
import pytorch_lightning as pl
|
| 7 |
+
import inspect
|
| 8 |
+
from src.models.mdxnet import AbstractMDXNet
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MakeONNXCallback(Callback):
|
| 12 |
+
"""Upload all *.py files to wandb as an artifact, at the beginning of the run."""
|
| 13 |
+
|
| 14 |
+
def __init__(self, dirpath: str):
|
| 15 |
+
self.dirpath = dirpath
|
| 16 |
+
if not os.path.exists(self.dirpath):
|
| 17 |
+
os.mkdir(self.dirpath)
|
| 18 |
+
|
| 19 |
+
def on_save_checkpoint(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule',
|
| 20 |
+
checkpoint: Dict[str, Any]) -> dict:
|
| 21 |
+
res = super().on_save_checkpoint(trainer, pl_module, checkpoint)
|
| 22 |
+
|
| 23 |
+
var = inspect.signature(pl_module.__init__).parameters
|
| 24 |
+
model = pl_module.__class__(**dict((name, pl_module.__dict__[name]) for name in var))
|
| 25 |
+
model.load_state_dict(pl_module.state_dict())
|
| 26 |
+
|
| 27 |
+
target_dir = '{}epoch_{}'.format(self.dirpath, pl_module.current_epoch)
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
if not os.path.exists(target_dir):
|
| 31 |
+
os.mkdir(target_dir)
|
| 32 |
+
|
| 33 |
+
with torch.no_grad():
|
| 34 |
+
torch.onnx.export(model,
|
| 35 |
+
torch.zeros(model.inference_chunk_shape),
|
| 36 |
+
'{}/{}.onnx'.format(target_dir, model.target_name),
|
| 37 |
+
export_params=True, # store the trained parameter weights inside the model file
|
| 38 |
+
opset_version=13, # the ONNX version to export the model to
|
| 39 |
+
do_constant_folding=True, # whether to execute constant folding for optimization
|
| 40 |
+
input_names=['input'], # the model's input names
|
| 41 |
+
output_names=['output'], # the model's output names
|
| 42 |
+
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
| 43 |
+
'output': {0: 'batch_size'}})
|
| 44 |
+
except:
|
| 45 |
+
print('onnx error')
|
| 46 |
+
finally:
|
| 47 |
+
del model
|
| 48 |
+
|
| 49 |
+
return res
|
src/callbacks/wandb_callbacks.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
from typing import List, Optional, Any
|
| 4 |
+
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import seaborn as sn
|
| 7 |
+
import torch
|
| 8 |
+
import wandb
|
| 9 |
+
from pytorch_lightning import Callback, Trainer
|
| 10 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 11 |
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
| 12 |
+
from sklearn import metrics
|
| 13 |
+
from sklearn.metrics import f1_score, precision_score, recall_score
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
    """Safely get Weights&Biases logger from Trainer."""

    # Single-logger setup: the primary logger is already the one we want.
    if isinstance(trainer.logger, WandbLogger):
        return trainer.logger

    # Multi-logger setup: scan the list for the first wandb logger.
    if isinstance(trainer.loggers, list):
        wandb_loggers = [lg for lg in trainer.loggers if isinstance(lg, WandbLogger)]
        if wandb_loggers:
            return wandb_loggers[0]

    raise Exception(
        "You are using wandb related callback, but WandbLogger was not found for some reason..."
    )
| 31 |
+
|
| 32 |
+
class UploadValidTrack(Callback):
    """Upload a centered crop of each validation output track to wandb as audio."""

    def __init__(self, crop: int, upload_after_n_epoch: int):
        # Total number of samples to upload (`crop` seconds at 44.1 kHz).
        self.sample_length = crop * 44100
        self.upload_after_n_epoch = upload_after_n_epoch
        # The crop extends half the window to each side of the track midpoint.
        self.len_left_window = self.len_right_window = self.sample_length // 2

    def on_validation_batch_end(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
        outputs: Optional[STEP_OUTPUT],
        batch: Any,
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        if outputs is None:
            return

        track_id, track = outputs['track_id'], outputs['track']

        # Resolve the wandb experiment first (raises if no WandbLogger),
        # then skip uploads for early epochs.
        experiment = get_wandb_logger(trainer=trainer).experiment
        if pl_module.current_epoch < self.upload_after_n_epoch:
            return None

        midpoint = track.shape[-1] // 2
        clip = track[:, midpoint - self.len_left_window:midpoint + self.len_right_window]

        key = 'track={}_epoch={}'.format(track_id, pl_module.current_epoch)
        experiment.log({key: [wandb.Audio(clip.T, sample_rate=44100)]})
|
| 63 |
+
|
| 64 |
+
class WatchModel(Callback):
    """Make wandb watch model at the beginning of the run."""

    def __init__(self, log: str = "gradients", log_freq: int = 100):
        # What to watch ("gradients", "parameters", ...) and logging cadence.
        self.log = log
        self.log_freq = log_freq

    def on_train_start(self, trainer, pl_module):
        wandb_logger = get_wandb_logger(trainer=trainer)
        wandb_logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
|
| 75 |
+
|
| 76 |
+
class UploadCodeAsArtifact(Callback):
    """Upload all *.py files to wandb as an artifact, at the beginning of the run."""

    def __init__(self, code_dir: str):
        self.code_dir = code_dir

    def on_train_start(self, trainer, pl_module):
        experiment = get_wandb_logger(trainer=trainer).experiment

        artifact = wandb.Artifact("project-source", type="code")
        py_files = glob.glob(os.path.join(self.code_dir, "**/*.py"), recursive=True)
        for py_file in py_files:
            artifact.add_file(py_file)

        experiment.use_artifact(artifact)
|
| 92 |
+
|
| 93 |
+
class UploadCheckpointsAsArtifact(Callback):
    """Upload checkpoints to wandb as an artifact, at the end of run."""

    def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
        self.ckpt_dir = ckpt_dir
        self.upload_best_only = upload_best_only

    def on_train_end(self, trainer, pl_module):
        experiment = get_wandb_logger(trainer=trainer).experiment

        artifact = wandb.Artifact("experiment-ckpts", type="checkpoints")

        if self.upload_best_only:
            # Only the best-scoring checkpoint tracked by Lightning.
            artifact.add_file(trainer.checkpoint_callback.best_model_path)
        else:
            ckpt_pattern = os.path.join(self.ckpt_dir, "**/*.ckpt")
            for ckpt_path in glob.glob(ckpt_pattern, recursive=True):
                artifact.add_file(ckpt_path)

        experiment.use_artifact(artifact)
|
| 114 |
+
|
| 115 |
+
class LogConfusionMatrix(Callback):
    """Generate confusion matrix every epoch and send it to wandb.
    Expects validation step to return predictions and targets.
    """

    def __init__(self):
        # Per-epoch accumulators; cleared after each confusion matrix is logged.
        self.preds = []
        self.targets = []
        self.ready = True

    def on_sanity_check_start(self, trainer, pl_module) -> None:
        # Don't accumulate during Lightning's validation sanity check.
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """Start executing this callback only after all validation sanity checks end."""
        self.ready = True

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """Gather data from single batch."""
        if self.ready:
            self.preds.append(outputs["preds"])
            self.targets.append(outputs["targets"])

    def on_validation_epoch_end(self, trainer, pl_module):
        """Generate confusion matrix."""
        if self.ready:
            logger = get_wandb_logger(trainer)
            experiment = logger.experiment

            preds = torch.cat(self.preds).cpu().numpy()
            targets = torch.cat(self.targets).cpu().numpy()

            confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds)

            # set figure size
            plt.figure(figsize=(14, 8))

            # set labels size
            sn.set(font_scale=1.4)

            # set font size
            sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g")

            # names should be uniqe or else charts from different experiments in wandb will overlap
            experiment.log({f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False)

            # according to wandb docs this should also work but it crashes
            # experiment.log(f{"confusion_matrix/{experiment.name}": plt})

            # reset plot
            plt.clf()

            self.preds.clear()
            self.targets.clear()
|
| 172 |
+
|
| 173 |
+
class LogF1PrecRecHeatmap(Callback):
    """Generate f1, precision, recall heatmap every epoch and send it to wandb.
    Expects validation step to return predictions and targets.
    """

    def __init__(self, class_names: List[str] = None):
        # Per-epoch accumulators; cleared after each heatmap is logged.
        # (class_names is currently unused but kept for interface compatibility.)
        self.preds = []
        self.targets = []
        self.ready = True

    def on_sanity_check_start(self, trainer, pl_module):
        # Don't accumulate during Lightning's validation sanity check.
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """Start executing this callback only after all validation sanity checks end."""
        self.ready = True

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        """Gather data from single batch."""
        if self.ready:
            self.preds.append(outputs["preds"])
            self.targets.append(outputs["targets"])

    def on_validation_epoch_end(self, trainer, pl_module):
        """Generate f1, precision and recall heatmap."""
        if self.ready:
            logger = get_wandb_logger(trainer=trainer)
            experiment = logger.experiment

            preds = torch.cat(self.preds).cpu().numpy()
            targets = torch.cat(self.targets).cpu().numpy()
            # BUGFIX: sklearn metrics take (y_true, y_pred). The original
            # passed (preds, targets), which swapped the meaning of the
            # precision and recall rows in the logged heatmap.
            f1 = f1_score(targets, preds, average=None)
            r = recall_score(targets, preds, average=None)
            p = precision_score(targets, preds, average=None)
            data = [f1, p, r]

            # set figure size
            plt.figure(figsize=(14, 3))

            # set labels size
            sn.set(font_scale=1.2)

            # set font size
            sn.heatmap(
                data,
                annot=True,
                annot_kws={"size": 10},
                fmt=".3f",
                yticklabels=["F1", "Precision", "Recall"],
            )

            # names should be uniqe or else charts from different experiments in wandb will overlap
            experiment.log({f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False)

            # reset plot
            plt.clf()

            self.preds.clear()
            self.targets.clear()
+
|
| 235 |
+
|
| 236 |
+
class LogImagePredictions(Callback):
    """Logs a validation batch and their predictions to wandb.
    Example adapted from:
        https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY
    """

    def __init__(self, num_samples: int = 8):
        super().__init__()
        # Number of images from the first validation batch to log.
        self.num_samples = num_samples
        self.ready = True

    def on_sanity_check_start(self, trainer, pl_module):
        # Don't log during Lightning's validation sanity check.
        self.ready = False

    def on_sanity_check_end(self, trainer, pl_module):
        """Start executing this callback only after all validation sanity checks end."""
        self.ready = True

    def on_validation_epoch_end(self, trainer, pl_module):
        if self.ready:
            logger = get_wandb_logger(trainer=trainer)
            experiment = logger.experiment

            # get a validation batch from the validation dat loader
            # NOTE(review): assumes the datamodule yields (images, labels)
            # pairs — this callback does not fit the audio datamodule in
            # this project; verify before enabling it.
            val_samples = next(iter(trainer.datamodule.val_dataloader()))
            val_imgs, val_labels = val_samples

            # run the batch through the network
            val_imgs = val_imgs.to(device=pl_module.device)
            logits = pl_module(val_imgs)
            preds = torch.argmax(logits, axis=-1)

            # log the images as wandb Image
            experiment.log(
                {
                    f"Images/{experiment.name}": [
                        wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
                        for x, pred, y in zip(
                            val_imgs[: self.num_samples],
                            preds[: self.num_samples],
                            val_labels[: self.num_samples],
                        )
                    ]
                }
            )
src/datamodules/__init__.py
ADDED
|
File without changes
|
src/datamodules/datasets/__init__.py
ADDED
|
File without changes
|
src/datamodules/datasets/musdb.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from abc import ABCMeta, ABC
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import soundfile
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
import random
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from src.utils.utils import load_wav
|
| 13 |
+
from src import utils
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
log = utils.get_pylogger(__name__)
|
| 17 |
+
|
| 18 |
+
def check_target_name(target_name, source_names):
    """Validate the configured target source name; print an error and exit on failure.

    Uses plain conditionals instead of the original try/assert pattern so the
    validation still runs under ``python -O`` (asserts are stripped there).

    Args:
        target_name: the source to separate, one of ``source_names`` or ``'all'``.
        source_names: list of valid stem names.
    """
    if target_name is None:
        print('[ERROR] please identify target name. ex) +datamodule.target_name="vocals"')
        exit(-1)
    if target_name not in source_names and target_name != 'all':
        print('[ERROR] target name should one of "bass", "drums", "other", "vocals", "all"')
        exit(-1)
|
| 30 |
+
|
| 31 |
+
def check_sample_rate(sr, sample_track):
    """Verify that an audio file's sample rate matches the config; exit on mismatch.

    Reads the file once (the original re-read it on failure) and avoids
    ``assert`` so the check survives ``python -O``.

    Args:
        sr: sample rate expected by the configuration.
        sample_track: path of an audio file to probe.
    """
    sample_rate = soundfile.read(sample_track)[1]
    if sample_rate != sr:
        print('[ERROR] sampling rate mismatched')
        print('\t=> sr in Config file: {}, but sr of data: {}'.format(sr, sample_rate))
        exit(-1)
|
| 41 |
+
|
| 42 |
+
class MusdbDataset(Dataset):
    """Base dataset: holds the stem list, chunk size and MUSDB root path."""

    __metaclass__ = ABCMeta

    def __init__(self, data_dir, chunk_size):
        # The four MUSDB18-HQ stems, in canonical alphabetical order.
        self.source_names = ['bass', 'drums', 'other', 'vocals']
        self.chunk_size = chunk_size
        self.musdb_path = Path(data_dir)
+
|
| 51 |
+
class MusdbTrainDataset(MusdbDataset):
    """Training dataset: every access draws random chunks of each stem from
    random tracks (inter-track mixing) and returns (mixture, target)."""

    def __init__(self, data_dir, chunk_size, target_name, aug_params, external_datasets, single_channel, epoch_size):
        super(MusdbTrainDataset, self).__init__(data_dir, chunk_size)

        self.single_channel = single_channel
        # Stems other than the target (kept for reference; not used below).
        self.neg_lst = [x for x in self.source_names if x != target_name]

        self.target_name = target_name
        check_target_name(self.target_name, self.source_names)

        # Metadata caches (pickled track lists with lengths) live here.
        if not self.musdb_path.joinpath('metadata').exists():
            os.mkdir(self.musdb_path.joinpath('metadata'))

        splits = ['train']
        if external_datasets is not None:
            splits += external_datasets

        # collect paths for datasets and metadata (track names and duration)
        datasets, metadata_caches = [], []
        raw_datasets = []  # un-augmented datasets
        for split in splits:
            raw_datasets.append(self.musdb_path.joinpath(split))
            max_pitch, max_tempo = aug_params
            # Pitch/tempo-augmented copies are sibling directories named
            # '<split>_p=<p>_t=<t>'; p == t == 0 is the raw split itself.
            for p in range(-max_pitch, max_pitch+1):
                for t in range(-max_tempo, max_tempo+1, 10):
                    aug_split = split if p==t==0 else split + f'_p={p}_t={t}'
                    datasets.append(self.musdb_path.joinpath(aug_split))
                    metadata_caches.append(self.musdb_path.joinpath('metadata').joinpath(aug_split + '.pkl'))

        # collect all track names and their duration
        self.metadata = []
        raw_track_lengths = []  # for calculating epoch size
        for i, (dataset, metadata_cache) in enumerate(tqdm(zip(datasets, metadata_caches))):
            try:
                # Load the cached (path, length) list to avoid re-scanning wavs.
                metadata = torch.load(metadata_cache)
            except FileNotFoundError:
                print('creating metadata for', dataset)
                metadata = []
                for track_name in sorted(os.listdir(dataset)):
                    track_path = dataset.joinpath(track_name)
                    # Length measured on vocals.wav — assumes all stems of a
                    # track have identical length (TODO confirm for externals).
                    track_length = load_wav(track_path.joinpath('vocals.wav')).shape[-1]
                    metadata.append((track_path, track_length))
                torch.save(metadata, metadata_cache)

            self.metadata += metadata
            if dataset in raw_datasets:
                raw_track_lengths += [length for path, length in metadata]

        # One epoch ≈ total un-augmented audio divided into chunks, unless an
        # explicit epoch_size is configured.
        self.epoch_size = sum(raw_track_lengths) // self.chunk_size if epoch_size is None else epoch_size
        log.info(f'epoch size: {self.epoch_size}')

    def __getitem__(self, _):
        # Index is ignored: each call is an independent random draw.
        sources = []
        for source_name in self.source_names:
            track_path, track_length = random.choice(self.metadata)  # random mixing between tracks
            source = load_wav(track_path.joinpath(source_name + '.wav'),
                              track_length=track_length, chunk_size=self.chunk_size)  # (2, times)
            sources.append(source)

        # The mixture is simply the sum of the independently drawn stems.
        mix = sum(sources)

        if self.target_name == 'all':
            # Targets for models that separate all four sources (ex. Demucs).
            # This adds additional 'source' dimension => batch_shape=[batch, source, channel, time]
            target = sources
        else:
            target = sources[self.source_names.index(self.target_name)]

        mix, target = torch.tensor(mix), torch.tensor(target)
        if self.single_channel:
            # Downmix stereo to mono by averaging channels.
            mix = torch.mean(mix, dim=0, keepdim=True)
            target = torch.mean(target, dim=0, keepdim=True)
        return mix, target

    def __len__(self):
        return self.epoch_size
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class MusdbValidDataset(MusdbDataset):
    """Validation dataset: yields one whole track per index, pre-split into
    overlapping chunk batches ready for model-side aggregation."""

    def __init__(self, data_dir, chunk_size, target_name, overlap, batch_size, single_channel):
        super(MusdbValidDataset, self).__init__(data_dir, chunk_size)

        self.target_name = target_name
        check_target_name(self.target_name, self.source_names)

        self.overlap = overlap
        self.batch_size = batch_size
        self.single_channel = single_channel

        musdb_valid_path = self.musdb_path.joinpath('valid')
        self.track_paths = [musdb_valid_path.joinpath(track_name)
                            for track_name in os.listdir(musdb_valid_path)]

    def __getitem__(self, index):
        mix = load_wav(self.track_paths[index].joinpath('mixture.wav'))  # (2, time)

        if self.target_name == 'all':
            # Targets for models that separate all four sources (ex. Demucs).
            # This adds additional 'source' dimension => batch_shape=[batch, source, channel, time]
            target = [load_wav(self.track_paths[index].joinpath(source_name + '.wav'))
                      for source_name in self.source_names]
        else:
            target = load_wav(self.track_paths[index].joinpath(self.target_name + '.wav'))

        # Each chunk contributes chunk_output_size samples once the model
        # trims `overlap` samples from each side; pad so chunks tile exactly.
        chunk_output_size = self.chunk_size - 2 * self.overlap
        left_pad = np.zeros([2, self.overlap])
        right_pad = np.zeros([2, self.overlap + chunk_output_size - (mix.shape[-1] % chunk_output_size)])
        mix_padded = np.concatenate([left_pad, mix, right_pad], 1)

        # Sliding windows of chunk_size samples, stepping by chunk_output_size.
        num_chunks = mix_padded.shape[-1] // chunk_output_size
        mix_chunks = np.array([mix_padded[:, i * chunk_output_size: i * chunk_output_size + self.chunk_size]
                               for i in range(num_chunks)])
        mix_chunk_batches = torch.tensor(mix_chunks, dtype=torch.float32).split(self.batch_size)
        target = torch.tensor(target)

        if self.single_channel:
            # Downmix stereo to mono by averaging channels.
            mix_chunk_batches = [torch.mean(t, dim=1, keepdim=True) for t in mix_chunk_batches]
            target = torch.mean(target, dim=0, keepdim=True)

        return mix_chunk_batches, target

    def __len__(self):
        return len(self.track_paths)
|
src/datamodules/musdb_datamodule.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from os.path import exists, join
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Optional, Tuple
|
| 5 |
+
|
| 6 |
+
from pytorch_lightning import LightningDataModule
|
| 7 |
+
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
| 8 |
+
|
| 9 |
+
from src.datamodules.datasets.musdb import MusdbTrainDataset, MusdbValidDataset
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MusdbDataModule(LightningDataModule):
    """
    LightningDataModule for Musdb18-HQ dataset.
    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))
    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data
    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """

    def __init__(
        self,
        data_dir: str,
        aug_params,
        target_name: str,
        overlap: int,
        hop_length: int,
        dim_t: int,
        sample_rate: int,
        batch_size: int,
        num_workers: int,
        pin_memory: bool,
        external_datasets,
        audio_ch: int,
        epoch_size,
        **kwargs,
    ):
        super().__init__()

        self.data_dir = Path(data_dir)
        self.target_name = target_name
        self.aug_params = aug_params
        self.external_datasets = external_datasets

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        # audio-related
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.single_channel = audio_ch == 1

        # derived
        # Chunk length in samples such that the STFT yields exactly dim_t frames.
        self.chunk_size = hop_length * (dim_t - 1)
        self.overlap = overlap

        self.epoch_size = epoch_size

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        trainset_path = self.data_dir.joinpath('train')
        validset_path = self.data_dir.joinpath('valid')

        # create validation split
        # NOTE(review): this mutates the dataset directory at construction
        # time, moving kwargs['validation_set'] tracks out of train/.
        if not exists(validset_path):
            from shutil import move
            os.mkdir(validset_path)
            for track in kwargs['validation_set']:
                if trainset_path.joinpath(track).exists():
                    move(trainset_path.joinpath(track), validset_path.joinpath(track))
        else:
            # Guard against a stale valid/ split left over from another config.
            valid_files = os.listdir(validset_path)
            assert set(valid_files) == set(kwargs['validation_set'])

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: self.data_train, self.data_val, self.data_test."""
        self.data_train = MusdbTrainDataset(self.data_dir,
                                            self.chunk_size,
                                            self.target_name,
                                            self.aug_params,
                                            self.external_datasets,
                                            self.single_channel,
                                            self.epoch_size)

        self.data_val = MusdbValidDataset(self.data_dir,
                                          self.chunk_size,
                                          self.target_name,
                                          self.overlap,
                                          self.batch_size,
                                          self.single_channel)

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self):
        # batch_size=1: validation operates on whole tracks (see MusdbValidDataset).
        return DataLoader(
            dataset=self.data_val,
            batch_size=1,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
        )
src/dp_tdf/__init__.py
ADDED
|
File without changes
|
src/dp_tdf/abstract.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABCMeta
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from pytorch_lightning import LightningModule
|
| 10 |
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
| 11 |
+
|
| 12 |
+
from src.utils.utils import sdr, simplified_msseval
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class AbstractModel(LightningModule):
|
| 16 |
+
__metaclass__ = ABCMeta
|
| 17 |
+
|
| 18 |
+
    def __init__(self, target_name,
                 lr, optimizer,
                 dim_f, dim_t, n_fft, hop_length, overlap,
                 audio_ch,
                 **kwargs):
        """Store STFT/optimizer hyper-parameters and precompute STFT buffers."""
        super().__init__()
        self.target_name = target_name
        self.lr = lr
        self.optimizer = optimizer
        # Real + imaginary parts per audio channel.
        self.dim_c_in = audio_ch * 2
        self.dim_c_out = audio_ch * 2
        self.dim_f = dim_f
        self.dim_t = dim_t
        self.n_fft = n_fft
        self.n_bins = n_fft // 2 + 1
        self.hop_length = hop_length
        self.audio_ch = audio_ch

        # Training chunk covers dim_t STFT frames; inference uses double.
        self.chunk_size = hop_length * (self.dim_t - 1)
        self.inference_chunk_size = hop_length * (self.dim_t*2 - 1)
        self.overlap = overlap
        # Registered as non-trainable Parameters so they move with .to(device)
        # and are saved with the state dict.
        self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False)
        # Zero pad for the frequency bins above dim_f when restoring a full spectrum.
        self.freq_pad = nn.Parameter(torch.zeros([1, self.dim_c_out, self.n_bins - self.dim_f, 1]), requires_grad=False)
        # Shape probe used by the ONNX export callback; relies on self.stft,
        # which is defined outside this view (subclass/mixin) — verify there.
        self.inference_chunk_shape = (self.stft(torch.zeros([1, audio_ch, self.inference_chunk_size]))).shape
| 43 |
+
|
| 44 |
+
def configure_optimizers(self):
|
| 45 |
+
if self.optimizer == 'rmsprop':
|
| 46 |
+
print("Using RMSprop optimizer")
|
| 47 |
+
return torch.optim.RMSprop(self.parameters(), self.lr)
|
| 48 |
+
elif self.optimizer == 'adamW':
|
| 49 |
+
print("Using AdamW optimizer")
|
| 50 |
+
return torch.optim.AdamW(self.parameters(), self.lr)
|
| 51 |
+
|
| 52 |
+
    def comp_loss(self, pred_detail, target_wave):
        """L1 loss between the inverse-STFT of the prediction and the target waveform.

        Also logs the epoch-aggregated value as "train/comp_loss".
        """
        # Back to the time domain before comparing against the target wave.
        pred_detail = self.istft(pred_detail)

        comp_loss = F.l1_loss(pred_detail, target_wave)

        self.log("train/comp_loss", comp_loss, sync_dist=True, on_step=False, on_epoch=True, prog_bar=False)

        return comp_loss
|
| 61 |
+
|
| 62 |
+
    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
        """One optimization step: STFT the mixture, run the net, L1 loss vs target wave."""
        mix_wave, target_wave = args[0]  # (batch, c, 261120)

        # input 1
        stft_44k = self.stft(mix_wave)  # (batch, c*2, 1044, 256)
        # forward
        t_est_stft = self(stft_44k)  # (batch, c, 1044, 256)

        # comp_loss applies the inverse STFT and compares waveforms.
        loss = self.comp_loss(t_est_stft, target_wave)

        self.log("train/loss", loss, sync_dist=True, on_step=True, on_epoch=True, prog_bar=True)

        return {"loss": loss}
|
| 76 |
+
|
| 77 |
+
# Validation SDR is calculated on whole tracks and not chunks since
|
| 78 |
+
# short inputs have high possibility of being silent (all-zero signal)
|
| 79 |
+
# which leads to very low sdr values regardless of the model.
|
| 80 |
+
# A natural procedure would be to split a track into chunk batches and
|
| 81 |
+
# load them on multiple gpus, but aggregation was too difficult.
|
| 82 |
+
# So instead we load one whole track on a single device (data_loader batch_size should always be 1)
|
| 83 |
+
# and do all the batch splitting and aggregation on a single device.
|
| 84 |
+
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
    """Separate one whole validation track and score it.

    ``args[0]`` is ``(mix_chunk_batches, target)`` where
    ``mix_chunk_batches`` is a list of chunk mini-batches covering the
    track and ``target`` the full-length reference wave; the data loader
    batch size must be 1 (its leading dim is stripped below).

    Returns:
        dict with the track-level 'song' SDR (scalar) and per-chunk
        'chunk' SDR array from ``simplified_msseval``.
    """
    mix_chunk_batches, target = args[0]

    # remove data_loader batch dimension
    # [(b, c, time)], (c, all_times)
    mix_chunk_batches, target = [batch[0] for batch in mix_chunk_batches], target[0]

    # process whole track in batches of chunks
    target_hat_chunks = []
    for batch in mix_chunk_batches:
        # input
        stft_44k = self.stft(batch)  # (batch, c*2, dim_f, frames)
        pred_detail = self(stft_44k)  # predicted spectrogram for the target source
        pred_detail = self.istft(pred_detail)

        # drop the overlapped margin on both sides of each chunk
        target_hat_chunks.append(pred_detail[..., self.overlap:-self.overlap])
    target_hat_chunks = torch.cat(target_hat_chunks)  # (b*len(ls),c,t)

    # concat all output chunks into (c, all_times), trimming tail padding
    target_hat = target_hat_chunks.transpose(0, 1).reshape(self.audio_ch, -1)[..., :target.shape[-1]]

    ests = target_hat.detach().cpu().numpy()  # (c, all_times)
    references = target.cpu().numpy()
    score = sdr(ests, references)

    # chunk-wise SDR over 44100-sample windows; msseval wants (src, t, c)
    SDR = simplified_msseval(np.expand_dims(references.T, axis=0), np.expand_dims(ests.T, axis=0), chunk_size=44100)
    # self.log("val/sdr", score, sync_dist=True, on_step=False, on_epoch=True, logger=True)

    return {'song': score, 'chunk': SDR}
|
| 114 |
+
|
| 115 |
+
def validation_epoch_end(self, outputs) -> None:
    """Aggregate validation metrics: mean per-song SDR and median per-chunk SDR."""
    song_scores = torch.Tensor([out['song'] for out in outputs])
    self.log("val/usdr", song_scores.mean(), sync_dist=True, on_step=False, on_epoch=True, logger=True)

    # first row of each track's chunk-SDR matrix, joined into one array
    chunk_rows = np.concatenate([out['chunk'][0, :] for out in outputs], axis=0)
    median_chunk_sdr = float(np.nanmedian(chunk_rows.flatten(), axis=0))
    self.log("val/csdr", median_chunk_sdr, sync_dist=True, on_step=False, on_epoch=True, logger=True)
|
| 125 |
+
|
| 126 |
+
def stft(self, x):
    '''
    Waveform -> stacked real/imaginary spectrogram.

    Args:
        x: (batch, c, time) waveform batch.
    Returns:
        (batch, c*2, dim_f, frames): real and imaginary parts stacked
        along the channel axis, cropped to the lowest dim_f frequency bins.
    '''
    dim_b = x.shape[0]
    # fold channels into the batch so a single stft call covers them all
    x = x.reshape([dim_b * self.audio_ch, -1])  # (batch*c, time)
    # NOTE(review): torch.stft without return_complex relies on the legacy
    # real-valued (..., 2) output, deprecated in recent PyTorch — confirm
    # the pinned torch version still supports it.
    x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, center=True)  # (batch*c, n_bins, frames, 2)
    x = x.permute([0, 3, 1, 2])  # (batch*c, 2, n_bins, frames)
    x = x.reshape([dim_b, self.audio_ch, 2, self.n_bins, -1]).reshape([dim_b, self.audio_ch * 2, self.n_bins, -1])  # (batch, c*2, n_bins, frames)
    return x[:, :, :self.dim_f]  # crop high bins: (batch, c*2, dim_f, frames)
|
| 137 |
+
|
| 138 |
+
def istft(self, x):
    '''
    Stacked real/imaginary spectrogram -> waveform (inverse of self.stft).

    Args:
        x: (batch, c*2, dim_f, frames)
    Returns:
        (batch, c, time)
    '''
    dim_b = x.shape[0]
    # restore the n_bins - dim_f cropped high-frequency bins with zeros
    x = torch.cat([x, self.freq_pad.repeat([x.shape[0], 1, 1, x.shape[-1]])], -2)  # (batch, c*2, n_bins, frames)
    x = x.reshape([dim_b, self.audio_ch, 2, self.n_bins, -1]).reshape([dim_b * self.audio_ch, 2, self.n_bins, -1])  # (batch*c, 2, n_bins, frames)
    x = x.permute([0, 2, 3, 1])  # legacy real/imag layout: (batch*c, n_bins, frames, 2)
    x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, center=True)  # (batch*c, time)
    return x.reshape([dim_b, self.audio_ch, -1])  # (batch, c, time)
|
| 149 |
+
|
| 150 |
+
def demix(self, mix, inf_chunk_size, batch_size=5, inf_overf=4):
    '''
    Overlap-add inference over a full track on the GPU.

    Args:
        mix: (C, L) full mixture (numpy array or tensor).
        inf_chunk_size: samples per inference chunk.
        batch_size: chunks per forward pass.
        inf_overf: overlap factor; hop = inf_chunk_size // inf_overf.
    Returns:
        est: (src, C, L) estimated source(s); src is 1 for this model.
    '''
    # self.instruments = ['bass', 'drums', 'other', 'vocals']
    num_instruments = 1  # this model predicts a single target source

    # fix: channel count was hard-coded to 2 (stereo), which broke the
    # mono (audio_ch == 1) configuration supported elsewhere
    n_ch = mix.shape[0]

    inf_hop = inf_chunk_size // inf_overf  # hop size
    L = mix.shape[1]
    pad_size = inf_hop - (L - inf_chunk_size) % inf_hop
    # zero-pad both ends so every sample is covered by inf_overf chunks
    mix = torch.cat([torch.zeros(n_ch, inf_chunk_size - inf_hop), torch.Tensor(mix), torch.zeros(n_ch, pad_size + inf_chunk_size - inf_hop)], 1)
    mix = mix.cuda()

    # slice the padded track into overlapping chunks
    chunks = []
    i = 0
    while i + inf_chunk_size <= mix.shape[1]:
        chunks.append(mix[:, i:i + inf_chunk_size])
        i += inf_hop
    chunks = torch.stack(chunks)

    # group chunks into mini-batches for the forward pass
    batches = []
    i = 0
    while i < len(chunks):
        batches.append(chunks[i:i + batch_size])
        i = i + batch_size

    # running overlap-add buffer (src, c, t); grows by inf_hop per chunk
    X = torch.zeros(num_instruments, n_ch, inf_chunk_size - inf_hop)
    X = X.cuda()
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            for batch in batches:
                x = self.stft(batch)
                x = self(x)
                x = self.istft(x)  # (batch, c, time)
                # insert a source axis; the model only predicts one source
                x = x[:, None, ...]  # (batch, 1, c, time)
                x = x.repeat([1, num_instruments, 1, 1])  # (batch, src, c, time)
                for w in x:  # overlap-add each chunk of the batch onto X
                    a = X[..., :-(inf_chunk_size - inf_hop)]
                    b = X[..., -(inf_chunk_size - inf_hop):] + w[..., :(inf_chunk_size - inf_hop)]
                    c = w[..., (inf_chunk_size - inf_hop):]
                    X = torch.cat([a, b, c], -1)

    # strip padding and normalise for the inf_overf-fold overlap
    estimated_sources = X[..., inf_chunk_size - inf_hop:-(pad_size + inf_chunk_size - inf_hop)] / inf_overf

    assert L == estimated_sources.shape[-1]

    return estimated_sources
|
| 204 |
+
|
src/dp_tdf/bandsequence.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
# Original code from https://github.com/amanteur/BandSplitRNN-Pytorch
|
| 5 |
+
class RNNModule(nn.Module):
    """RNN submodule of the BandSequence module.

    Normalises the feature dim, runs a (bi)RNN over dim 2, projects back,
    adds a residual connection, then swaps dims 1 and 2 so the next
    RNNModule processes the other axis (time <-> subbands).
    """

    def __init__(
            self,
            group_num: int,
            input_dim_size: int,
            hidden_dim_size: int,
            rnn_type: str = 'LSTM',  # fix: was 'lstm'; getattr(nn, 'lstm') raises AttributeError
            bidirectional: bool = True
    ):
        """
        Args:
            group_num: groups for the GroupNorm over the feature dim.
            input_dim_size: feature size N of the input.
            hidden_dim_size: RNN hidden size (doubled at the output if bidirectional).
            rnn_type: attribute name on torch.nn, e.g. 'LSTM' or 'GRU'.
            bidirectional: whether the RNN runs in both directions.
        """
        super(RNNModule, self).__init__()
        self.groupnorm = nn.GroupNorm(group_num, input_dim_size)
        self.rnn = getattr(nn, rnn_type)(
            input_dim_size, hidden_dim_size, batch_first=True, bidirectional=bidirectional
        )
        # project the (possibly doubled) hidden size back to the input size
        self.fc = nn.Linear(
            hidden_dim_size * 2 if bidirectional else hidden_dim_size,
            input_dim_size
        )

    def forward(
            self,
            x: torch.Tensor
    ):
        """
        Input shape:
            across T - [batch_size, k_subbands, time, n_features]
            OR
            across K - [batch_size, time, k_subbands, n_features]
        Output: same tensor with dims 1 and 2 transposed.
        """
        B, K, T, N = x.shape  # across T / across K (keep in mind T->K, K->T)

        out = x.view(B * K, T, N)  # [BK, T, N] / [BT, K, N]

        # GroupNorm expects channels in dim 1, hence the transposes
        out = self.groupnorm(
            out.transpose(-1, -2)
        ).transpose(-1, -2)  # [BK, T, N] / [BT, K, N]
        out = self.rnn(out)[0]  # [BK, T, H]; last dim holds features
        out = self.fc(out)  # [BK, T, N] / [BT, K, N]

        x = out.view(B, K, T, N) + x  # residual connection

        x = x.permute(0, 2, 1, 3).contiguous()  # swap the two sequence axes
        return x
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class BandSequenceModelModule(nn.Module):
    """
    BandSequence (2nd) Module of BandSplitRNN.

    Applies ``num_layers`` pairs of RNNModules, alternating between the
    time axis and the subband axis, with the channel dimension split into
    ``n_heads`` independent heads folded into the batch.
    """

    def __init__(
            self,
            # group_num,
            input_dim_size: int,
            hidden_dim_size: int,
            rnn_type: str = 'lstm',
            bidirectional: bool = True,
            num_layers: int = 12,
            n_heads: int = 4,
    ):
        super(BandSequenceModelModule, self).__init__()

        self.bsrnn = nn.ModuleList([])
        self.n_heads = n_heads

        # each head works on an equal slice of the channel / hidden dims
        input_dim_size = input_dim_size // n_heads
        hidden_dim_size = hidden_dim_size // n_heads
        group_num = input_dim_size // 16  # GroupNorm groups derived from head width

        for _ in range(num_layers):
            across_time = RNNModule(
                group_num, input_dim_size, hidden_dim_size, rnn_type, bidirectional
            )
            across_bands = RNNModule(
                group_num, input_dim_size, hidden_dim_size, rnn_type, bidirectional
            )
            self.bsrnn.append(nn.Sequential(across_time, across_bands))

    def forward(self, x: torch.Tensor):
        """
        Input shape:  [batch_size, channels, time, n_subbands]
        Output shape: [batch_size, channels, time, n_subbands]
        """
        b, c, t, f = x.shape
        # fold the heads into the batch dimension
        x = x.view(b * self.n_heads, c // self.n_heads, t, f)

        x = x.permute(0, 3, 2, 1).contiguous()  # [b*heads, f, t, c//heads]
        for layer in self.bsrnn:
            x = layer(x)

        x = x.permute(0, 3, 2, 1).contiguous()  # [b*heads, c//heads, t, f]
        return x.view(b, c, t, f)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
if __name__ == '__main__':
    # Smoke test: push random features through the module and print shapes.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # forward expects [batch, channels, time, subbands] where channels ==
    # input_dim_size (fix: the original passed k_subbands in the channel slot,
    # which is neither divisible by n_heads nor equal to input_dim_size)
    batch_size, k_subbands, t_timesteps, input_dim = 4, 41, 512, 128
    in_features = torch.rand(batch_size, input_dim, t_timesteps, k_subbands).to(device)

    cfg = {
        # "t_timesteps": t_timesteps,
        # fix: "group_num" removed — BandSequenceModelModule derives it
        # internally (input_dim_size // n_heads // 16) and passing it
        # raised TypeError: unexpected keyword argument.
        "input_dim_size": 128,
        "hidden_dim_size": 256,
        "rnn_type": "LSTM",
        "bidirectional": True,
        "num_layers": 1
    }
    model = BandSequenceModelModule(**cfg).to(device)
    _ = model.eval()

    with torch.no_grad():
        out_features = model(in_features)

    print(f"In: {in_features.shape}\nOut: {out_features.shape}")
    print(f"Total number of parameters: {sum([p.numel() for p in model.parameters()])}")
|
src/dp_tdf/dp_tdf_net.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from src.dp_tdf.modules import TFC_TDF, TFC_TDF_Res1, TFC_TDF_Res2
|
| 5 |
+
from src.dp_tdf.bandsequence import BandSequenceModelModule
|
| 6 |
+
|
| 7 |
+
from src.layers import (get_norm)
|
| 8 |
+
from src.dp_tdf.abstract import AbstractModel
|
| 9 |
+
|
| 10 |
+
class DPTDFNet(AbstractModel):
    """U-Net over spectrograms built from TFC/TDF blocks, with a
    BandSequence (dual-path RNN) bottleneck."""

    def __init__(self, num_blocks, l, g, k, bn, bias, bn_norm, bandsequence, block_type, **kwargs):
        """
        Args:
            num_blocks: total conv blocks; the U-Net has num_blocks // 2 scales.
            l: conv layers per TFC block.
            g: channel growth per scale (also the first conv's output channels).
            k: conv kernel size inside TFC blocks.
            bn: TDF bottleneck factor (None disables TDF, 0 = single full-width linear).
            bias: bias flag for the TDF linear layers.
            bn_norm: normalisation spec understood by get_norm.
            bandsequence: kwargs dict for BandSequenceModelModule.
            block_type: one of 'TFC_TDF', 'TFC_TDF_Res1', 'TFC_TDF_Res2'.
            **kwargs: forwarded to AbstractModel (provides dim_c_in, dim_c_out, dim_f, ...).
        """

        super(DPTDFNet, self).__init__(**kwargs)
        # self.save_hyperparameters()

        self.num_blocks = num_blocks
        self.l = l
        self.g = g
        self.k = k
        self.bn = bn
        self.bias = bias

        self.n = num_blocks // 2
        scale = (2, 2)  # each scale halves (doubles) time and frequency

        if block_type == "TFC_TDF":
            T_BLOCK = TFC_TDF
        elif block_type == "TFC_TDF_Res1":
            T_BLOCK = TFC_TDF_Res1
        elif block_type == "TFC_TDF_Res2":
            T_BLOCK = TFC_TDF_Res2
        else:
            raise ValueError(f"Unknown block type {block_type}")

        # 1x1 conv lifting the stacked real/imag channels to g feature maps
        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels=self.dim_c_in, out_channels=g, kernel_size=(1, 1)),
            get_norm(bn_norm, g),
            nn.ReLU(),
        )

        # f tracks the current frequency size, c the current channel count
        f = self.dim_f
        c = g
        self.encoding_blocks = nn.ModuleList()
        self.ds = nn.ModuleList()

        for i in range(self.n):
            c_in = c

            self.encoding_blocks.append(T_BLOCK(c_in, c, l, f, k, bn, bn_norm, bias=bias))
            # strided conv downsamples by 2 and adds g channels
            self.ds.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
                    get_norm(bn_norm, c + g),
                    nn.ReLU()
                )
            )
            f = f // 2
            c += g

        self.bottleneck_block1 = T_BLOCK(c, c, l, f, k, bn, bn_norm, bias=bias)
        # dual-path RNN over the bottleneck's time/subband axes
        self.bottleneck_block2 = BandSequenceModelModule(
            **bandsequence,
            input_dim_size=c,
            hidden_dim_size=2*c
        )

        self.decoding_blocks = nn.ModuleList()
        self.us = nn.ModuleList()
        for i in range(self.n):
            # transposed conv upsamples by 2 and removes g channels
            self.us.append(
                nn.Sequential(
                    nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale),
                    get_norm(bn_norm, c - g),
                    nn.ReLU()
                )
            )

            f = f * 2
            c -= g

            self.decoding_blocks.append(T_BLOCK(c, c, l, f, k, bn, bn_norm, bias=bias))

        # 1x1 conv back to the output spectrogram channel count
        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=self.dim_c_out, kernel_size=(1, 1)),
        )

    def forward(self, x):
        '''
        Args:
            x: (batch, c*2, dim_f, frames) stacked real/imag spectrogram.
        Returns:
            (batch, dim_c_out, dim_f, frames) predicted spectrogram.
        '''
        x = self.first_conv(x)

        # blocks operate with time before frequency; swap the last two dims
        x = x.transpose(-1, -2)

        ds_outputs = []
        for i in range(self.n):
            x = self.encoding_blocks[i](x)
            ds_outputs.append(x)
            x = self.ds[i](x)

        x = self.bottleneck_block1(x)
        x = self.bottleneck_block2(x)

        for i in range(self.n):
            x = self.us[i](x)
            # multiplicative (gating) skip connection with the matching encoder output
            x = x * ds_outputs[-i - 1]
            x = self.decoding_blocks[i](x)

        x = x.transpose(-1, -2)

        x = self.final_conv(x)

        return x
|
src/dp_tdf/modules.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from src.layers import (get_norm)
|
| 5 |
+
|
| 6 |
+
class TFC(nn.Module):
    """Time-Frequency Convolutions: a stack of ``l`` Conv2d + norm + ReLU layers."""

    def __init__(self, c_in, c_out, l, k, bn_norm):
        """
        Args:
            c_in: input channels of the first conv layer.
            c_out: output channels of every conv layer.
            l: number of conv layers.
            k: square kernel size; padding k // 2 keeps spatial dims for odd k.
            bn_norm: normalisation spec understood by get_norm.
        """
        super(TFC, self).__init__()

        self.H = nn.ModuleList()
        for i in range(l):
            # only the first layer adapts the channel count (idiom cleanup of
            # the original's no-op `c_in = c_in` branch)
            in_ch = c_in if i == 0 else c_out
            self.H.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=in_ch, out_channels=c_out, kernel_size=k, stride=1, padding=k // 2),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                )
            )

    def forward(self, x):
        for h in self.H:
            x = h(x)
        return x
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class DenseTFC(nn.Module):
    """Densely-connected TFC: each layer sees the concat of all previous outputs.

    Bug fix: the original built every conv with in_channels=c_in, but the
    forward pass concatenates each layer's output onto its input, so layer i
    actually receives c_in + i * c_out channels. For any l > 1 the original
    crashed with a channel-count mismatch at runtime.
    """

    def __init__(self, c_in, c_out, l, k, bn_norm):
        super(DenseTFC, self).__init__()

        self.conv = nn.ModuleList()
        for i in range(l):
            self.conv.append(
                nn.Sequential(
                    # layer i input: original c_in plus the i previous c_out outputs
                    nn.Conv2d(in_channels=c_in + i * c_out, out_channels=c_out, kernel_size=k, stride=1, padding=k // 2),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                )
            )

    def forward(self, x):
        # all layers but the last grow the feature stack; the last reduces it
        for layer in self.conv[:-1]:
            x = torch.cat([layer(x), x], 1)
        return self.conv[-1](x)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class TFC_TDF(nn.Module):
    """TFC followed by an optional TDF (time-distributed fully-connected)
    bottleneck over the frequency axis, added residually in forward()."""

    def __init__(self, c_in, c_out, l, f, k, bn, bn_norm, dense=False, bias=True):
        super(TFC_TDF, self).__init__()

        # bn is the TDF bottleneck factor: None disables the TDF branch,
        # 0 selects a single full-width linear layer.
        self.use_tdf = bn is not None

        tfc_cls = DenseTFC if dense else TFC
        self.tfc = tfc_cls(c_in, c_out, l, k, bn_norm)

        if self.use_tdf:
            if bn == 0:
                tdf_layers = [nn.Linear(f, f, bias=bias), get_norm(bn_norm, c_out), nn.ReLU()]
            else:
                tdf_layers = [
                    nn.Linear(f, f // bn, bias=bias), get_norm(bn_norm, c_out), nn.ReLU(),
                    nn.Linear(f // bn, f, bias=bias), get_norm(bn_norm, c_out), nn.ReLU(),
                ]
            self.tdf = nn.Sequential(*tdf_layers)

    def forward(self, x):
        x = self.tfc(x)
        if self.use_tdf:
            return x + self.tdf(x)
        return x
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class TFC_TDF_Res1(nn.Module):
    """TFC_TDF variant with an extra single-layer TFC shortcut around the
    main TFC stack."""

    def __init__(self, c_in, c_out, l, f, k, bn, bn_norm, dense=False, bias=True):
        super(TFC_TDF_Res1, self).__init__()

        # bn: TDF bottleneck factor (None = no TDF branch, 0 = full-width linear)
        self.use_tdf = bn is not None

        tfc_cls = DenseTFC if dense else TFC
        self.tfc = tfc_cls(c_in, c_out, l, k, bn_norm)

        # one-layer conv path carrying the input around the main stack
        self.res = TFC(c_in, c_out, 1, k, bn_norm)

        if self.use_tdf:
            if bn == 0:
                tdf_layers = [nn.Linear(f, f, bias=bias), get_norm(bn_norm, c_out), nn.ReLU()]
            else:
                tdf_layers = [
                    nn.Linear(f, f // bn, bias=bias), get_norm(bn_norm, c_out), nn.ReLU(),
                    nn.Linear(f // bn, f, bias=bias), get_norm(bn_norm, c_out), nn.ReLU(),
                ]
            self.tdf = nn.Sequential(*tdf_layers)

    def forward(self, x):
        shortcut = self.res(x)
        x = self.tfc(x) + shortcut
        if self.use_tdf:
            return x + self.tdf(x)
        return x
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class TFC_TDF_Res2(nn.Module):
    """Two TFC stacks with an optional TDF between them and a one-layer
    TFC shortcut around the whole block."""

    def __init__(self, c_in, c_out, l, f, k, bn, bn_norm, dense=False, bias=True):
        super(TFC_TDF_Res2, self).__init__()

        # bn: TDF bottleneck factor (None = no TDF branch, 0 = full-width linear)
        self.use_tdf = bn is not None

        self.tfc1 = TFC(c_in, c_out, l, k, bn_norm)
        self.tfc2 = TFC(c_in, c_out, l, k, bn_norm)

        # one-layer conv path carrying the input around the whole block
        self.res = TFC(c_in, c_out, 1, k, bn_norm)

        if self.use_tdf:
            if bn == 0:
                tdf_layers = [nn.Linear(f, f, bias=bias), get_norm(bn_norm, c_out), nn.ReLU()]
            else:
                tdf_layers = [
                    nn.Linear(f, f // bn, bias=bias), get_norm(bn_norm, c_out), nn.ReLU(),
                    nn.Linear(f // bn, f, bias=bias), get_norm(bn_norm, c_out), nn.ReLU(),
                ]
            self.tdf = nn.Sequential(*tdf_layers)

    def forward(self, x):
        shortcut = self.res(x)
        x = self.tfc1(x)
        if self.use_tdf:
            x = x + self.tdf(x)
        x = self.tfc2(x)
        return x + shortcut
|
src/evaluation/eval.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import listdir
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Optional, List
|
| 4 |
+
|
| 5 |
+
from concurrent import futures
|
| 6 |
+
import hydra
|
| 7 |
+
import wandb
|
| 8 |
+
import os
|
| 9 |
+
import shutil
|
| 10 |
+
from omegaconf import DictConfig
|
| 11 |
+
from pytorch_lightning import LightningDataModule, LightningModule
|
| 12 |
+
from pytorch_lightning.loggers import Logger, WandbLogger
|
| 13 |
+
import soundfile as sf
|
| 14 |
+
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
import numpy as np
|
| 17 |
+
from src.callbacks.wandb_callbacks import get_wandb_logger
|
| 18 |
+
from src.evaluation.separate import separate_with_onnx_TDF, separate_with_ckpt_TDF
|
| 19 |
+
from src.utils import utils
|
| 20 |
+
from src.utils.utils import load_wav, sdr, get_median_csdr, save_results, get_metrics
|
| 21 |
+
|
| 22 |
+
from src.utils import pylogger
|
| 23 |
+
|
| 24 |
+
log = pylogger.get_pylogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def evaluation(config: DictConfig):
    """Evaluate a separation checkpoint on every track of one MUSDB split.

    Separates each track's mixture with the checkpoint in ``config.ckpt_path``
    (ONNX or Lightning ckpt), scores it with BSS metrics in a process pool,
    logs per-song and aggregate metrics to the configured loggers, and saves
    results to the current working directory.

    Returns:
        (cSDR, uSDR): median chunk-level SDR and mean song-level SDR.
    """

    assert config.split in ['train', 'valid', 'test']

    data_dir = Path(config.get('eval_dir')).joinpath(config['split'])
    assert data_dir.exists()

    # Init Lightning loggers
    loggers: List[Logger] = []
    if "logger" in config:
        for _, lg_conf in config.logger.items():
            if "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                loggers.append(hydra.utils.instantiate(lg_conf))

    if any([isinstance(l, WandbLogger) for l in loggers]):
        utils.wandb_login(key=config.wandb_api_key)

    model = hydra.utils.instantiate(config.model)
    target_name = model.target_name
    ckpt_path = Path(config.ckpt_path)
    # '.onnx' checkpoints take the onnxruntime separation path below
    is_onnx = os.path.split(ckpt_path)[-1].split('.')[-1] == 'onnx'
    shutil.copy(ckpt_path,os.getcwd()) # copy model

    ssdrs = []
    bss_lst = []
    bss_perms = []
    num_tracks = len(listdir(data_dir))
    target_list = [config.model.target_name,"complement"]


    # NOTE(review): the executor class is deliberately rebound to the pool
    # instance by the `with` statement below — confusing but harmless.
    pool = futures.ProcessPoolExecutor
    with pool(config.pool_workers) as pool:
        datas = sorted(listdir(data_dir))
        if len(datas) > 27: # if not debugging
            # move idx 27 to head
            datas = [datas[27]] + datas[:27] + datas[28:]
        # iterate datas in batches of pool_workers tracks
        for k in range(0, len(datas), config.pool_workers):
            batch = datas[k:k + config.pool_workers]
            pendings = []
            for i, track in tqdm(enumerate(batch)):
                folder_name = track
                track = data_dir.joinpath(track)
                mixture = load_wav(track.joinpath('mixture.wav')) # (c, t)
                target = load_wav(track.joinpath(target_name + '.wav'))

                if model.audio_ch == 1:
                    # mono model: collapse stereo to one averaged channel
                    mixture = np.mean(mixture, axis=0, keepdims=True)
                    target = np.mean(target, axis=0, keepdims=True)
                #target_hat = {source: separate(config['batch_size'], models[source], onnxs[source], mixture) for source in sources}
                if is_onnx:
                    target_hat = separate_with_onnx_TDF(config.batch_size, model, ckpt_path, mixture)
                else:
                    target_hat = separate_with_ckpt_TDF(config.batch_size, model, ckpt_path, mixture, config.device, config.double_chunk, config.overlap_add)


                # metric computation is slow, so it is farmed out to the pool
                pendings.append((folder_name, pool.submit(
                    get_metrics, target_hat, target, mixture, sr=44100,version=config.bss)))

                for wandb_logger in [logger for logger in loggers if isinstance(logger, WandbLogger)]:
                    # log a 6-second excerpt from the middle of the estimate
                    # (note: `track` is rebound here from Path to ndarray)
                    mid = mixture.shape[-1] // 2
                    track = target_hat[:, mid - 44100 * 3:mid + 44100 * 3]
                    wandb_logger.experiment.log(
                        {f'track={k+i}_target={target_name}': [wandb.Audio(track.T, sample_rate=44100)]})


            # collect this batch's pool results in submission order
            for i, (track_name, pending) in tqdm(enumerate(pendings)):
                pending = pending.result()
                bssmetrics, perms, ssdr = pending
                bss_lst.append(bssmetrics)
                bss_perms.append(perms)
                ssdrs.append(ssdr)

                for logger in loggers:
                    logger.log_metrics({'song/ssdr': ssdr}, k+i)
                    logger.log_metrics({'song/csdr': get_median_csdr([bssmetrics])}, k+i)

    log_dir = os.getcwd()
    save_results(log_dir, bss_lst, target_list, bss_perms, ssdrs)

    cSDR = get_median_csdr(bss_lst)
    uSDR = sum(ssdrs)/num_tracks
    for logger in loggers:
        logger.log_metrics({'metrics/mean_sdr_' + target_name: sum(ssdrs)/num_tracks})
        logger.log_metrics({'metrics/median_csdr_' + target_name: get_median_csdr(bss_lst)})
        # get the path of the log dir
        if not isinstance(logger, WandbLogger):
            logger.experiment.close()

    if any([isinstance(logger, WandbLogger) for logger in loggers]):
        wandb.finish()

    return cSDR, uSDR
|
src/evaluation/eval_demo.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import listdir
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Optional, List
|
| 4 |
+
|
| 5 |
+
from concurrent import futures
|
| 6 |
+
import hydra
|
| 7 |
+
import wandb
|
| 8 |
+
import os
|
| 9 |
+
import shutil
|
| 10 |
+
from omegaconf import DictConfig
|
| 11 |
+
from pytorch_lightning import LightningDataModule, LightningModule
|
| 12 |
+
from pytorch_lightning.loggers import Logger, WandbLogger
|
| 13 |
+
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
import numpy as np
|
| 16 |
+
from src.callbacks.wandb_callbacks import get_wandb_logger
|
| 17 |
+
from src.evaluation.separate import separate_with_onnx_TDF, separate_with_ckpt_TDF
|
| 18 |
+
from src.utils import utils
|
| 19 |
+
from src.utils.utils import load_wav, sdr, get_median_csdr, save_results, get_metrics
|
| 20 |
+
|
| 21 |
+
from src.utils import pylogger
|
| 22 |
+
import soundfile as sf
|
| 23 |
+
log = pylogger.get_pylogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def evaluation(config: DictConfig, idx):
    """Separate and score a single track (index ``idx``) of the eval split.

    Loads the checkpoint named in ``config.ckpt_path``, separates the idx-th
    track (after the idx-27 reordering used by eval.py), computes BSS metrics
    and pickles them to 'bssmetrics.pkl' in the current working directory.

    Returns:
        the BSS metrics object for the track.
    """

    assert config.split in ['train', 'valid', 'test']

    data_dir = Path(config.get('eval_dir')).joinpath(config['split'])
    assert data_dir.exists()

    model = hydra.utils.instantiate(config.model)
    target_name = model.target_name
    ckpt_path = Path(config.ckpt_path)
    # '.onnx' checkpoints take the onnxruntime separation path below
    is_onnx = os.path.split(ckpt_path)[-1].split('.')[-1] == 'onnx'
    shutil.copy(ckpt_path,os.getcwd()) # copy model

    datas = sorted(listdir(data_dir))
    if len(datas) > 27: # if not debugging
        # move idx 27 to head (same reordering as eval.py so indices match)
        datas = [datas[27]] + datas[:27] + datas[28:]


    track = datas[idx]
    track = data_dir.joinpath(track)
    print(track)
    mixture = load_wav(track.joinpath('mixture.wav')) # (c, t)
    target = load_wav(track.joinpath(target_name + '.wav'))
    if model.audio_ch == 1:
        # mono model: collapse stereo to one averaged channel
        mixture = np.mean(mixture, axis=0, keepdims=True)
        target = np.mean(target, axis=0, keepdims=True)
    #target_hat = {source: separate(config['batch_size'], models[source], onnxs[source], mixture) for source in sources}
    if is_onnx:
        target_hat = separate_with_onnx_TDF(config.batch_size, model, ckpt_path, mixture)
    else:
        # NOTE(review): passes overlap_factor=config.overlap_factor while
        # eval.py passes config.overlap_add positionally — confirm both
        # configs and the separate_with_ckpt_TDF signature agree.
        target_hat = separate_with_ckpt_TDF(config.batch_size, model, ckpt_path, mixture, config.device, config.double_chunk, overlap_factor=config.overlap_factor)

    bssmetrics, perms, ssdr = get_metrics(target_hat, target, mixture, sr=44100,version=config.bss)
    # dump bssmetrics into pkl
    import pickle
    with open(os.path.join(os.getcwd(),'bssmetrics.pkl'),'wb') as f:
        pickle.dump(bssmetrics,f)

    return bssmetrics
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
src/evaluation/separate.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from os import listdir
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import onnxruntime as ort
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
from src.utils.utils import split_nparray_with_overlap, join_chunks
|
| 10 |
+
|
| 11 |
+
def separate_with_onnx(batch_size, model, onnx_path: Path, mix):
    """Separate a stereo mixture with an exported ONNX model, chunk by chunk.

    Args:
        batch_size: number of chunks per inference batch.
        model: provides n_fft, sampling_size, stft() and istft().
        onnx_path: path to the exported .onnx graph.
        mix: (2, t) mixture waveform.

    Returns:
        (2, t) separated waveform as a numpy array.
    """
    n_sample = mix.shape[1]

    trim = model.n_fft // 2
    gen_size = model.sampling_size - 2 * trim
    pad = gen_size - n_sample % gen_size
    # Zero-pad so every chunk is full-sized; `trim` samples on each side of a
    # chunk are discarded after inference to hide window edge artifacts.
    mix_p = np.concatenate((np.zeros((2, trim)), mix, np.zeros((2, pad)), np.zeros((2, trim))), 1)

    chunks = []
    i = 0
    while i < n_sample + pad:
        chunks.append(np.array(mix_p[:, i:i + model.sampling_size], dtype=np.float32))
        i += gen_size
    # Stack first: torch.tensor on a list of ndarrays copies element-wise and
    # is far slower (and emits a warning on recent torch versions).
    mix_waves_batched = torch.from_numpy(np.stack(chunks)).split(batch_size)

    tar_signals = []

    with torch.no_grad():
        _ort = ort.InferenceSession(str(onnx_path))
        for mix_waves in mix_waves_batched:  # no longer shadows the chunk list
            tar_waves = model.istft(torch.tensor(
                _ort.run(None, {'input': model.stft(mix_waves).numpy()})[0]
            ))
            tar_signals.append(tar_waves[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).numpy())
    tar_signal = np.concatenate(tar_signals, axis=-1)[:, :-pad]

    return tar_signal
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def separate_with_ckpt(batch_size, model, ckpt_path: Path, mix, device, double_chunk):
    """Separate a stereo mixture with a Lightning checkpoint, chunk by chunk.

    Args:
        batch_size: number of chunks per inference batch.
        model: LightningModule class/instance exposing load_from_checkpoint,
            stft/istft, trim, sampling_size and inference_chunk_size.
        ckpt_path: path to the .ckpt file.
        mix: (2, t) mixture waveform.
        device: torch device for inference.
        double_chunk: use the larger inference_chunk_size instead of
            sampling_size.

    Returns:
        (2, t) separated waveform as a numpy array.
    """
    model = model.load_from_checkpoint(ckpt_path).to(device)
    if double_chunk:
        inf_ck = model.inference_chunk_size
    else:
        inf_ck = model.sampling_size
    true_samples = inf_ck - 2 * model.trim

    right_pad = true_samples + model.trim - ((mix.shape[-1]) % true_samples)
    # NOTE(review): the channel count is hard-coded to 2 here, while
    # no_overlap_inference uses model.audio_ch — confirm stereo-only is intended.
    mixture = np.concatenate((np.zeros((2, model.trim), dtype='float32'),
                              mix,
                              np.zeros((2, right_pad), dtype='float32')),
                             1)
    num_chunks = mixture.shape[-1] // true_samples
    chunks = [mixture[:, i * true_samples: i * true_samples + inf_ck]
              for i in range(num_chunks)]
    # Stack first: torch.tensor on a list of ndarrays copies element-wise and
    # is far slower (and emits a warning on recent torch versions).
    mix_waves_batched = torch.as_tensor(np.stack(chunks), dtype=torch.float32).split(batch_size)

    target_wav_hats = []

    with torch.no_grad():
        model.eval()
        for mixture_wav in mix_waves_batched:
            mix_spec = model.stft(mixture_wav.to(device))
            spec_hat = model(mix_spec)
            target_wav_hat = model.istft(spec_hat)
            target_wav_hats.append(target_wav_hat.cpu().detach().numpy())

    # Drop the trimmed margins, then lay the chunks end to end and cut the pad.
    target_wav_hat = np.vstack(target_wav_hats)[:, :, model.trim:-model.trim]
    target_wav_hat = np.concatenate(target_wav_hat, axis=-1)[:, :mix.shape[-1]]
    return target_wav_hat
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def separate_with_onnx_TDF(batch_size, model, onnx_path: Path, mix):
    """Separate a stereo mixture with an exported TDF ONNX model (CUDA EP).

    Args:
        batch_size: number of chunks per inference batch.
        model: provides n_fft, inference_chunk_size, stft() and istft().
        onnx_path: path to the exported .onnx graph.
        mix: (2, t) mixture waveform.

    Returns:
        (2, t) separated waveform as a numpy array.
    """
    n_sample = mix.shape[1]

    overlap = model.n_fft // 2
    gen_size = model.inference_chunk_size - 2 * overlap
    pad = gen_size - n_sample % gen_size
    # Zero-pad so every chunk is full-sized; `overlap` samples on each side of
    # a chunk are discarded after inference to hide window edge artifacts.
    mix_p = np.concatenate((np.zeros((2, overlap)), mix, np.zeros((2, pad)), np.zeros((2, overlap))), 1)

    chunks = []
    i = 0
    while i < n_sample + pad:
        chunks.append(np.array(mix_p[:, i:i + model.inference_chunk_size], dtype=np.float32))
        i += gen_size
    # Stack first: torch.tensor on a list of ndarrays copies element-wise and
    # is far slower (and emits a warning on recent torch versions).
    mix_waves_batched = torch.from_numpy(np.stack(chunks)).split(batch_size)

    tar_signals = []

    with torch.no_grad():
        _ort = ort.InferenceSession(str(onnx_path), providers=['CUDAExecutionProvider'])
        for mix_waves in mix_waves_batched:  # no longer shadows the chunk list
            tar_waves = model.istft(torch.tensor(
                _ort.run(None, {'input': model.stft(mix_waves).numpy()})[0]
            ))
            tar_signals.append(tar_waves[:, :, overlap:-overlap].transpose(0, 1).reshape(2, -1).numpy())
    tar_signal = np.concatenate(tar_signals, axis=-1)[:, :-pad]

    return tar_signal
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def separate_with_ckpt_TDF(batch_size, model, ckpt_path: Path, mix, device, double_chunk, overlap_add):
    '''
    Load a checkpoint into `model` and separate `mix`, dispatching to plain
    chunked inference or overlap-add inference.

    Args:
        batch_size: the inference batch size
        model: the model to be used
        ckpt_path: the path to the checkpoint
        mix: (c, t)
        device: the device to be used
        double_chunk: whether to use double chunk size
        overlap_add: None for plain chunking, otherwise a config object with
            overlap_rate, tmp_root and samplerate for overlap-add inference
    Returns:
        target_wav_hat: (c, t)
    '''
    # map_location lets a GPU-trained checkpoint load on a CPU-only host too.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    model = model.to(device)
    # model = model.load_from_checkpoint(ckpt_path).to(device)
    if double_chunk:
        inf_ck = model.inference_chunk_size
    else:
        inf_ck = model.chunk_size

    if overlap_add is None:
        target_wav_hat = no_overlap_inference(model, mix, device, batch_size, inf_ck)
    else:
        # exist_ok avoids the check-then-create race of an explicit exists() test
        os.makedirs(overlap_add.tmp_root, exist_ok=True)
        target_wav_hat = overlap_inference(model, mix, device, batch_size, inf_ck,
                                           overlap_add.overlap_rate, overlap_add.tmp_root,
                                           overlap_add.samplerate)

    return target_wav_hat
|
| 137 |
+
|
| 138 |
+
def no_overlap_inference(model, mix, device, batch_size, inf_ck):
    """Chunked inference without overlap-add blending.

    The waveform is padded with `model.overlap` zeros on each side, cut into
    chunks of `inf_ck` samples stepping by `inf_ck - 2*model.overlap`, run
    through the model batch-wise, and reassembled after trimming the margins.

    Args:
        model: provides overlap, audio_ch, stft(), istft() and __call__.
        mix: (c, t) mixture waveform.
        device: torch device for inference.
        batch_size: number of chunks per inference batch.
        inf_ck: chunk size fed to the model (includes both margins).

    Returns:
        (c, t) separated waveform as a numpy array.
    """
    true_samples = inf_ck - 2 * model.overlap

    right_pad = true_samples + model.overlap - ((mix.shape[-1]) % true_samples)
    mixture = np.concatenate((np.zeros((model.audio_ch, model.overlap), dtype='float32'),
                              mix,
                              np.zeros((model.audio_ch, right_pad), dtype='float32')),
                             1)
    num_chunks = mixture.shape[-1] // true_samples
    chunks = [mixture[:, i * true_samples: i * true_samples + inf_ck]
              for i in range(num_chunks)]
    # Stack first: torch.tensor on a list of ndarrays copies element-wise and
    # is far slower (and emits a warning on recent torch versions).
    mix_waves_batched = torch.as_tensor(np.stack(chunks), dtype=torch.float32).split(batch_size)

    target_wav_hats = []

    with torch.no_grad():
        model.eval()
        for mixture_wav in mix_waves_batched:
            mix_spec = model.stft(mixture_wav.to(device))
            spec_hat = model(mix_spec)
            target_wav_hat = model.istft(spec_hat)
            target_wav_hats.append(target_wav_hat.cpu().detach().numpy())  # (b, c, t)

    # Drop the trimmed margins, then lay the chunks end to end and cut the pad.
    target_wav_hat = np.vstack(target_wav_hats)[:, :, model.overlap:-model.overlap]  # (sum(b), c, t)
    target_wav_hat = np.concatenate(target_wav_hat, axis=-1)[:, :mix.shape[-1]]
    return target_wav_hat
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def overlap_inference(model, mix, device, batch_size, inf_ck, overlap_rate, tmp_root, samplerate):
    '''
    Chunked inference with overlap-add: chunks overlap by `overlap_rate` and
    are blended back together by `join_chunks` (which writes temporary files
    under `tmp_root` at `samplerate`).

    Args:
        mix: (c, t)
        overlap_rate: fraction of each chunk shared with its neighbour (0..1)
        tmp_root: scratch directory used by join_chunks
        samplerate: sample rate passed to join_chunks
    Returns:
        (c, t) separated waveform
    '''
    # hop between chunk starts; ceil keeps the hop positive for high overlap rates
    hop_length = math.ceil((1 - overlap_rate) * inf_ck)
    overlap_size = inf_ck - hop_length
    step_t = mix.shape[1]
    # NOTE(review): mix.T is (t, c), so chunks presumably come out (b, t, c),
    # not (b, c, t) as the comment below says — confirm against
    # split_nparray_with_overlap and model.stft's expected layout.
    mix_waves_batched = split_nparray_with_overlap(mix.T, hop_length, overlap_size)

    mix_waves_batched = torch.tensor(mix_waves_batched, dtype=torch.float32).split(batch_size) # [(b, c, t)]

    target_wav_hats = []

    with torch.no_grad():
        model.eval()  # disable dropout/batchnorm updates for inference
        for mixture_wav in mix_waves_batched:
            mix_spec = model.stft(mixture_wav.to(device))
            spec_hat = model(mix_spec)
            target_wav_hat = model.istft(spec_hat)
            target_wav_hat = target_wav_hat.cpu().detach().numpy()
            target_wav_hats.append(target_wav_hat) # (b, c, t)

    target_wav_hat = np.vstack(target_wav_hats) # (sum(b), c, t)
    target_wav_hat = np.transpose(target_wav_hat, (0, 2, 1)) # (sum(b), t, c)
    # join_chunks cross-fades the overlapping regions and returns one waveform
    target_wav_hat = join_chunks(tmp_root, target_wav_hat, samplerate, overlap_size) # (t, c)
    return target_wav_hat[:step_t].T # (c, t)
|
src/layers/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from .batch_norm import *
|
src/layers/batch_norm.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
__all__ = ["IBN", "get_norm"]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BatchNorm(nn.BatchNorm2d):
    """BatchNorm2d with constant parameter init and optional freezing.

    `weight_init` / `bias_init` set the affine parameters to a constant
    (pass None to keep PyTorch's defaults); `weight_freeze` / `bias_freeze`
    disable their gradients.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0, **kwargs):
        super().__init__(num_features, eps=eps, momentum=momentum)
        # Constant-initialise the affine parameters unless the caller opts out.
        if weight_init is not None:
            nn.init.constant_(self.weight, weight_init)
        if bias_init is not None:
            nn.init.constant_(self.bias, bias_init)
        # Freezing only disables gradients; running stats still update in train mode.
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SyncBatchNorm(nn.SyncBatchNorm):
    """Cross-process SyncBatchNorm with constant init and optional freezing.

    Mirrors `BatchNorm` above: constant init for weight/bias (None keeps
    defaults) and per-parameter gradient freezing.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        # Constant-initialise the affine parameters unless the caller opts out.
        if weight_init is not None:
            nn.init.constant_(self.weight, weight_init)
        if bias_init is not None:
            nn.init.constant_(self.bias, bias_init)
        # Freezing only disables gradients; running stats still update in train mode.
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class IBN(nn.Module):
    """Instance-Batch Normalization layer.

    InstanceNorm is applied to the first half of the channels and the
    configured batch-norm variant (via `get_norm`) to the remaining half;
    the two results are concatenated back along the channel axis.
    """

    def __init__(self, planes, bn_norm, **kwargs):
        super(IBN, self).__init__()
        in_channels = planes // 2
        self.half = in_channels
        self.IN = nn.InstanceNorm2d(in_channels, affine=True)
        self.BN = get_norm(bn_norm, planes - in_channels, **kwargs)

    def forward(self, x):
        # Split channel-wise, normalise each part, then stitch back together.
        parts = torch.split(x, self.half, 1)
        out_in = self.IN(parts[0].contiguous())
        out_bn = self.BN(parts[1].contiguous())
        return torch.cat((out_in, out_bn), 1)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class GhostBatchNorm(BatchNorm):
    """BatchNorm computed over `num_splits` virtual ("ghost") sub-batches.

    In training, the batch is reshaped so statistics are computed per
    sub-batch; the tiled running stats are then averaged back into a single
    (num_features,) buffer after each forward pass.
    """

    def __init__(self, num_features, num_splits=1, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits
        # Re-register the running stats so they start from fresh zero/one
        # buffers of shape (num_features,), replacing the parent's.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            # Tile the running stats so each ghost batch normalises with its
            # own copy, run a training-mode batch_norm over the reshaped
            # input, then average the updated stats back to one set.
            self.running_mean = self.running_mean.repeat(self.num_splits)
            self.running_var = self.running_var.repeat(self.num_splits)
            outputs = F.batch_norm(
                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
            return outputs
        else:
            # Eval mode: plain batch_norm with the averaged running stats.
            return F.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class FrozenBatchNorm(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.
    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.
    Other pre-trained backbone models may contain all 4 parameters.
    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    # Bumped when the serialized format changes; see _load_from_state_dict.
    _version = 3

    def __init__(self, num_features, eps=1e-5, **kwargs):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # Stored as (var - eps) so that forward's (running_var + eps) is an
        # exact identity for a freshly constructed layer.
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            return x * scale + bias
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provide more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Migrate checkpoints produced by older versions of this class.
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions
            # This will silent the warnings
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        if version is not None and version < 3:
            logger = logging.getLogger(__name__)
            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
            # In version < 3, running_var are used without +eps.
            state_dict[prefix + "running_var"] -= self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
        Args:
            module (torch.nn.Module):
        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.
        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            # Leaf BN layer: copy its parameters/stats into a frozen replacement.
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            # Container: recurse and swap converted children in place.
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def get_norm(norm, out_channels, **kwargs):
    """
    Args:
        norm (str or callable): one of "BN", "syncBN", "GhostBN", "FrozenBN",
            "GN"; or a callable that takes a channel number and returns
            the normalization layer as a nn.Module
        out_channels: number of channels for normalization layer

    Returns:
        nn.Module or None: the normalization layer (None for an empty string)
    """
    if isinstance(norm, str):
        if not norm:
            return None
        # Map the symbolic name onto its layer factory.
        registry = {
            "BN": BatchNorm,
            "syncBN": SyncBatchNorm,
            "GhostBN": GhostBatchNorm,
            "FrozenBN": FrozenBatchNorm,
            "GN": lambda channels, **args: nn.GroupNorm(32, channels),
        }
        norm = registry[norm]
    return norm(out_channels, **kwargs)
|
src/layers/chunk_size.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#%%
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def wave_to_batches(mix, inf_ck, overlap, batch_size):
    '''
    Args:
        mix: (2, N) numpy array
        inf_ck: int, the chunk size as the model input (contains 2*overlap)
            inf_ck = overlap + true_samples + overlap
        overlap: int, the discarded samples at each side
        batch_size: chunks per batch
    Returns:
        a tuple of batches, each batch is a (batch, 2, inf_ck) torch tensor
    '''
    true_samples = inf_ck - 2 * overlap
    n_ch = mix.shape[0]
    n_samples = mix.shape[-1]

    # Pad with `overlap` leading zeros plus enough trailing zeros that the
    # final chunk ends exactly at the padded signal's last sample.
    right_pad = true_samples + overlap - (n_samples % true_samples)
    padded = np.concatenate(
        (np.zeros((n_ch, overlap), dtype='float32'),
         mix,
         np.zeros((n_ch, right_pad), dtype='float32')),
        1)

    num_chunks = padded.shape[-1] // true_samples
    starts = [idx * true_samples for idx in range(num_chunks)]
    stacked = np.array([padded[:, s: s + inf_ck] for s in starts])  # (x, 2, inf_ck)
    return torch.as_tensor(stacked, dtype=torch.float32).split(batch_size)
|
| 30 |
+
|
| 31 |
+
def batches_to_wave(target_hat_chunks, overlap, org_len):
    '''
    Args:
        target_hat_chunks: a list of (batch, 2, inf_ck) torch tensors
        overlap: int, the discarded samples at each side
        org_len: int, the original length of the mixture
    Returns:
        (2, N) numpy array
    '''
    # Cut the overlap margins off every chunk, then merge all batches.
    trimmed = torch.cat([chunk[..., overlap:-overlap] for chunk in target_hat_chunks])

    # Flatten chunk-by-chunk along time (channels first), then crop the
    # right-padding back to the original length.
    joined = trimmed.transpose(0, 1).reshape(2, -1)
    return joined[..., :org_len].detach().cpu().numpy()
|
| 45 |
+
|
| 46 |
+
if __name__ == '__main__':
    # Smoke test: chunk a random stereo waveform and reassemble it.
    mix = np.random.rand(2, 14318640)
    inf_ck = 261120
    overlap = 3072
    batch_size = 8
    out = wave_to_batches(mix, inf_ck, overlap, batch_size)
    # Reassembly should restore the original (2, N) shape; values pass
    # through unchanged here because no model is applied in between.
    in_wav = batches_to_wave(out, overlap, mix.shape[-1])
    print(in_wav.shape)
|
src/train.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
|
| 3 |
+
import hydra
|
| 4 |
+
import pytorch_lightning as pl
|
| 5 |
+
import pyrootutils
|
| 6 |
+
import torch
|
| 7 |
+
import os
|
| 8 |
+
import shutil
|
| 9 |
+
from omegaconf import DictConfig
|
| 10 |
+
from pytorch_lightning import (
|
| 11 |
+
Callback,
|
| 12 |
+
LightningDataModule,
|
| 13 |
+
LightningModule,
|
| 14 |
+
Trainer,
|
| 15 |
+
seed_everything,
|
| 16 |
+
)
|
| 17 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 18 |
+
from hydra.core.hydra_config import HydraConfig
|
| 19 |
+
|
| 20 |
+
from src import utils
|
| 21 |
+
|
| 22 |
+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
| 23 |
+
|
| 24 |
+
log = utils.get_pylogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@utils.task_wrapper
def train(cfg: DictConfig) -> Optional[float]:
    """Contains training pipeline.
    Instantiates all PyTorch Lightning objects from config.

    Steps: enforce a fixed seed, build datamodule/model/callbacks/loggers from
    the Hydra config, optionally restore logger state next to a resumed
    checkpoint, build the trainer (gloo backend if USE_GLOO is set), log
    hyperparameters and run trainer.fit.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Optional[float]: Metric score for hyperparameter optimization.
    """

    # Set seed for random number generators in pytorch, numpy and python.random
    # NOTE(review): ModuleNotFoundError is (ab)used here as a local control-flow
    # signal for "seed missing from config" — it has nothing to do with imports.
    try:
        if "seed" in cfg:
            # set seed for random number generators in pytorch, numpy and python.random
            if cfg.get("seed"):
                pl.seed_everything(cfg.seed, workers=True)

        else:
            raise ModuleNotFoundError

    except ModuleNotFoundError:
        print('[Error] seed should be fixed for reproducibility \n=> e.g. python run.py +seed=$SEED')
        exit(-1)

    # Init Lightning datamodule
    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule)

    # Init Lightning model
    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    # Init Lightning callbacks
    callbacks: List[Callback] = []
    if "callbacks" in cfg:
        for _, cb_conf in cfg["callbacks"].items():
            if "_target_" in cb_conf:
                log.info(f"Instantiating callback <{cb_conf._target_}>")
                callbacks.append(hydra.utils.instantiate(cb_conf))

    # When resuming, carry the previous run's logger state (tensorboard/wandb
    # directories next to the checkpoint) into the new Hydra run directory.
    if "resume_from_checkpoint" in cfg.trainer:
        ckpt_path = cfg.trainer.resume_from_checkpoint
        # get the parent directory of the checkpoint path
        log_dir = os.path.dirname(os.path.dirname(ckpt_path))
        tensorboard_dir = os.path.join(log_dir, "tensorboard")
        if os.path.exists(tensorboard_dir):
            # copy tensorboard dir to the parent directory of the checkpoint path
            # HydraConfig.get().run.dir returns new dir so do not use it! (now fixed)
            shutil.copytree(tensorboard_dir,os.path.join(os.getcwd(),"tensorboard"))

        wandb_dir = os.path.join(log_dir, "wandb")
        if os.path.exists(wandb_dir):
            shutil.copytree(wandb_dir,os.path.join(os.getcwd(),"wandb"))

    # Init Lightning loggers
    logger: List = []
    if "logger" in cfg:
        for _, lg_conf in cfg["logger"].items():
            if "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(hydra.utils.instantiate(lg_conf))

    # Log in to wandb once if any WandbLogger is configured.
    for wandb_logger in [l for l in logger if isinstance(l, WandbLogger)]:
        utils.wandb_login(key=cfg.wandb_api_key)
        # utils.wandb_watch_all(wandb_logger, model) # TODO buggy
        break

    # Init Lightning trainer
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # get env variable use_gloo
    use_gloo = os.environ.get("USE_GLOO", False)
    if use_gloo:
        # gloo backend for hosts where NCCL is unavailable
        from pytorch_lightning.strategies import DDPStrategy
        ddp = DDPStrategy(process_group_backend='gloo')
        trainer: Trainer = hydra.utils.instantiate(
            cfg.trainer, strategy=ddp, callbacks=callbacks, logger=logger, _convert_="partial"
        )
    else:
        trainer: Trainer = hydra.utils.instantiate(
            cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
        )

    # Send some parameters from config to all lightning loggers
    log.info("Logging hyperparameters!")
    utils.log_hyperparameters(
        dict(
            cfg=cfg,
            model=model,
            datamodule=datamodule,
            trainer=trainer,
            callbacks=callbacks,
            logger=logger,
        )
    )

    # Train the model
    log.info("Starting training!")
    trainer.fit(model=model, datamodule=datamodule)

    # Evaluate model on test set after training
    # if not cfg.trainer.get("fast_dev_run"):
    #     log.info("Starting testing!")
    #     trainer.test()

    # Make sure everything closed properly
    log.info("Finalizing!")
    # utils.finish(
    #     config=cfg,
    #     model=model,
    #     datamodule=datamodule,
    #     trainer=trainer,
    #     callbacks=callbacks,
    #     logger=logger,
    # )

    # Print path to best checkpoint
    # log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}")

    # Return metric score for hyperparameter optimization
    # optimized_metric = cfg.get("optimized_metric")
    # if optimized_metric:
    #     return trainer.callback_metrics[optimized_metric]
    # NOTE(review): returns a 2-tuple although the annotation says
    # Optional[float] — presumably utils.task_wrapper expects (metric, extras);
    # confirm against task_wrapper's contract.
    return None, None
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.utils.pylogger import get_pylogger
|
| 2 |
+
from src.utils.rich_utils import enforce_tags, print_config_tree
|
| 3 |
+
from src.utils.utils import *
|
src/utils/data_augmentation.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess as sp
|
| 3 |
+
import tempfile
|
| 4 |
+
import warnings
|
| 5 |
+
from argparse import ArgumentParser
|
| 6 |
+
from concurrent import futures
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import soundfile as sf
|
| 10 |
+
import torch
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
warnings.simplefilter(action='ignore', category=Warning)
|
| 14 |
+
source_names = ['vocals', 'drums', 'bass', 'other']
|
| 15 |
+
sample_rate = 44100
|
| 16 |
+
|
| 17 |
+
def main(args):
    """Create pitch-shifted / time-stretched copies of the dataset splits.

    For every (pitch, tempo) pair in P x T except the identity (0, 0),
    writes an augmented copy of each selected split next to the original,
    e.g. ``train_p=-1_t=20/``.

    Args:
        args: parsed CLI namespace with ``data_dir`` (dataset root ending
            in a path separator) and boolean flags ``train``, ``valid``,
            ``test`` selecting which splits to augment.
    """
    data_root = args.data_dir
    # Split sub-directory -> whether the user asked for it.
    split_flags = {
        'train/': args.train,
        'valid/': args.valid,
        'test/': args.test,
    }
    print(f"train={args.train}, test={args.test}, valid={args.valid}")

    pitch_shifts = [-3, -2, -1, 0, 1, 2, 3]      # pitch shift amounts (in semitones)
    tempo_deltas = [-30, -20, -10, 0, 10, 20, 30]  # time stretch amounts (10 means 10% slower)

    # Submit every selected split to the worker pool.  The original code
    # parallelized only the train split and processed valid/test serially
    # inside the pool context for no reason; all splits are now uniform.
    pendings = []
    with futures.ProcessPoolExecutor(max_workers=13) as executor:
        for p in pitch_shifts:
            for t in tempo_deltas:
                if p == 0 and t == 0:
                    continue  # identity transform: nothing to generate
                for subdir, enabled in split_flags.items():
                    if enabled:
                        pendings.append(
                            executor.submit(save_shifted_dataset, p, t, data_root + subdir))
        # Block until done and re-raise any worker exception here.
        for pending in pendings:
            pending.result()
|
| 50 |
+
|
| 51 |
+
def shift(wav, pitch, tempo, voice=False, quick=False, samplerate=44100, tmp_dir=None):
    """Pitch-shift and time-stretch a waveform via the ``soundstretch`` CLI.

    ``tempo`` is a relative delta in percentage, so tempo=10 means tempo at
    110%!  ``pitch`` is in semi tones.  Requires ``soundstretch`` to be
    installed, see https://www.surina.net/soundtouch/soundstretch.html

    Args:
        wav: torch tensor shaped (channels, samples); int16 PCM passes
            through, anything else is treated as float in [-1, 1].
        pitch: semitone shift (may be negative).
        tempo: tempo delta in percent.
        voice: pass ``-speech`` to soundstretch (tuned for vocals).
        quick: pass ``-quick`` (faster, lower quality).
        samplerate: sample rate of ``wav``; asserted against the output file.
        tmp_dir: directory for scratch WAV files.  Defaults to the system
            temp directory; the previous hard-coded ``/root/autodl-tmp/tmp``
            only existed on one machine and broke everywhere else.

    Returns:
        numpy float32 array as read back by ``soundfile`` (samples, channels).

    Raises:
        RuntimeError: if soundstretch exits with a non-zero status.
    """

    def i16_pcm(x):
        # Float in [-1, 1] -> int16 PCM; int16 input passes through untouched.
        # (The old f32_pcm helper was dead code and used the removed np.float.)
        if x.dtype == np.int16:
            return x
        return (x * 2 ** 15).clamp_(-2 ** 15, 2 ** 15 - 1).short()

    # Context managers guarantee both scratch files are removed even when
    # soundstretch fails.
    with tempfile.NamedTemporaryFile(dir=tmp_dir, suffix=".wav") as inputfile, \
            tempfile.NamedTemporaryFile(dir=tmp_dir, suffix=".wav") as outfile:
        sf.write(inputfile.name, data=i16_pcm(wav).t().numpy(),
                 samplerate=samplerate, format='WAV')
        command = [
            "soundstretch",
            inputfile.name,
            outfile.name,
            f"-pitch={pitch}",
            f"-tempo={tempo:.6f}",
        ]
        if quick:
            command += ["-quick"]
        if voice:
            command += ["-speech"]
        try:
            sp.run(command, capture_output=True, check=True)
        except sp.CalledProcessError as error:
            # Chain the cause so the soundstretch stderr and the original
            # traceback both survive.
            raise RuntimeError(
                f"Could not change bpm because {error.stderr.decode('utf-8')}") from error
        out, sr = sf.read(outfile.name, dtype='float32')
    assert sr == samplerate
    return out
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def save_shifted_dataset(delta_pitch, delta_tempo, data_path):
    """Write a pitch/tempo-shifted copy of every track under ``data_path``.

    The augmented split lands next to the original as
    ``<data_path without trailing '/'>_p=<delta_pitch>_t=<delta_tempo>/``,
    one WAV per source per track.

    Args:
        delta_pitch: semitone shift applied to every source.
        delta_tempo: tempo delta in percent (see ``shift``).
        data_path: split directory, expected to end with '/'.
    """
    # rstrip is robust even if the caller forgot the trailing slash
    # (the old data_path[:-1] would silently chop a real character).
    out_path = data_path.rstrip('/') + f'_p={delta_pitch}_t={delta_tempo}/'
    os.makedirs(out_path, exist_ok=True)  # idempotent; replaces mkdir+except
    track_names = sorted(
        name for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name)))
    for track_name in tqdm(track_names):
        os.makedirs(os.path.join(out_path, track_name), exist_ok=True)
        for s_name in source_names:
            source = load_wav(f'{data_path}/{track_name}/{s_name}.wav')
            shifted = shift(
                torch.tensor(source),
                delta_pitch,
                delta_tempo,
                voice=s_name == 'vocals')
            sf.write(f'{out_path}/{track_name}/{s_name}.wav', shifted,
                     samplerate=sample_rate, format='WAV')
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def load_wav(path, sr=None):
    """Read an audio file as float32 and return it shaped (channels, samples)."""
    data, _ = sf.read(path, samplerate=sr, dtype='float32')
    return data.T
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == '__main__':
    parser = ArgumentParser()
    # data_dir is dereferenced unconditionally in main(); fail fast here
    # instead of with a TypeError on `None + 'train/'`.
    parser.add_argument('--data_dir', type=str, required=True,
                        help="dataset root containing train/ (and optionally valid/, test/)")

    # argparse's type=bool is a footgun: bool('False') is True, so
    # `--train False` still enabled the flag.  Parse truthy spellings
    # explicitly; the defaults are unchanged.
    def _str2bool(value):
        return str(value).strip().lower() in ('1', 'true', 'yes', 'y')

    parser.add_argument('--train', type=_str2bool, default=True)
    parser.add_argument('--valid', type=_str2bool, default=False)
    parser.add_argument('--test', type=_str2bool, default=False)

    main(parser.parse_args())
|