youngNdum commited on
Commit
f47f762
·
verified ·
1 Parent(s): 172f29c

Upload 55 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. configs/callbacks/default.yaml +20 -0
  2. configs/callbacks/none.yaml +0 -0
  3. configs/callbacks/wandb.yaml +31 -0
  4. configs/config.yaml +44 -0
  5. configs/datamodule/musdb18_hq.yaml +50 -0
  6. configs/datamodule/musdb_dev14.yaml +28 -0
  7. configs/evaluation.yaml +43 -0
  8. configs/experiment/bass_dis.yaml +38 -0
  9. configs/experiment/drums_dis.yaml +38 -0
  10. configs/experiment/multigpu_default.yaml +26 -0
  11. configs/experiment/other_dis.yaml +38 -0
  12. configs/experiment/vocals_dis.yaml +38 -0
  13. configs/hydra/default.yaml +16 -0
  14. configs/infer.yaml +26 -0
  15. configs/logger/csv.yaml +8 -0
  16. configs/logger/many_loggers.yaml +10 -0
  17. configs/logger/neptune.yaml +11 -0
  18. configs/logger/none.yaml +0 -0
  19. configs/logger/tensorboard.yaml +10 -0
  20. configs/logger/wandb.yaml +15 -0
  21. configs/model/bass.yaml +28 -0
  22. configs/model/drums.yaml +28 -0
  23. configs/model/other.yaml +28 -0
  24. configs/model/vocals.yaml +28 -0
  25. configs/paths/default.yaml +18 -0
  26. configs/trainer/ddp.yaml +13 -0
  27. configs/trainer/default.yaml +19 -0
  28. configs/trainer/minimal.yaml +21 -0
  29. src/__init__.py +0 -0
  30. src/callbacks/__init__.py +0 -0
  31. src/callbacks/onnx_callback.py +49 -0
  32. src/callbacks/wandb_callbacks.py +280 -0
  33. src/datamodules/__init__.py +0 -0
  34. src/datamodules/datasets/__init__.py +0 -0
  35. src/datamodules/datasets/musdb.py +174 -0
  36. src/datamodules/musdb_datamodule.py +117 -0
  37. src/dp_tdf/__init__.py +0 -0
  38. src/dp_tdf/abstract.py +204 -0
  39. src/dp_tdf/bandsequence.py +136 -0
  40. src/dp_tdf/dp_tdf_net.py +118 -0
  41. src/dp_tdf/modules.py +158 -0
  42. src/evaluation/eval.py +120 -0
  43. src/evaluation/eval_demo.py +71 -0
  44. src/evaluation/separate.py +193 -0
  45. src/layers/__init__.py +2 -0
  46. src/layers/batch_norm.py +201 -0
  47. src/layers/chunk_size.py +53 -0
  48. src/train.py +152 -0
  49. src/utils/__init__.py +3 -0
  50. src/utils/data_augmentation.py +128 -0
configs/callbacks/default.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_checkpoint:
2
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
3
+ monitor: "val/usdr" # name of the logged metric which determines when model is improving
4
+ save_top_k: 5 # save k best models (determined by above metric)
5
+ save_last: True # additionaly always save model from last epoch
6
+ mode: "max" # can be "max" or "min"
7
+ verbose: False
8
+ dirpath: "checkpoints/"
9
+ filename: "{epoch:02d}-{step}"
10
+ #
11
+ #early_stopping:
12
+ # _target_: pytorch_lightning.callbacks.EarlyStopping
13
+ # monitor: "val/sdr" # name of the logged metric which determines when model is improving
14
+ # patience: 300 # how many epochs of not improving until training stops
15
+ # mode: "max" # can be "max" or "min"
16
+ # min_delta: 0.05 # minimum change in the monitored metric needed to qualify as an improvement
17
+
18
+ #make_onnx:
19
+ # _target_: src.callbacks.onnx_callback.MakeONNXCallback
20
+ # dirpath: "onnx/"
configs/callbacks/none.yaml ADDED
File without changes
configs/callbacks/wandb.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default.yaml
3
+
4
+ watch_model:
5
+ _target_: src.callbacks.wandb_callbacks.WatchModel
6
+ log: "all"
7
+ log_freq: 100
8
+
9
+ #upload_valid_track:
10
+ # _target_: src.callbacks.wandb_callbacks.UploadValidTrack
11
+ # crop: 3
12
+ # upload_after_n_epoch: -1
13
+
14
+ #upload_code_as_artifact:
15
+ # _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact
16
+ # code_dir: ${work_dir}/src
17
+ #
18
+ #upload_ckpts_as_artifact:
19
+ # _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
20
+ # ckpt_dir: "checkpoints/"
21
+ # upload_best_only: True
22
+ #
23
+ #log_f1_precision_recall_heatmap:
24
+ # _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap
25
+ #
26
+ #log_confusion_matrix:
27
+ # _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix
28
+ #
29
+ #log_image_predictions:
30
+ # _target_: src.callbacks.wandb_callbacks.LogImagePredictions
31
+ # num_samples: 8
configs/config.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # specify here default training configuration
4
+ defaults:
5
+ - datamodule: musdb18_hq
6
+ - model: null
7
+ - callbacks: default # set this to null if you don't want to use callbacks
8
+ - logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`)
9
+ - trainer: default
10
+ - hparams_search: null
11
+ - paths: default.yaml
12
+
13
+ - hydra: default
14
+
15
+ - experiment: null
16
+
17
+ # enable color logging
18
+ - override hydra/hydra_logging: colorlog
19
+ - override hydra/job_logging: colorlog
20
+
21
+
22
+ # path to original working directory
23
+ # hydra hijacks working directory by changing it to the current log directory,
24
+ # so it's useful to have this path as a special variable
25
+ # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
26
+ #work_dir: ${hydra:runtime.cwd}
27
+ #output_dir: ${hydra:runtime.output_dir}
28
+
29
+ # path to folder with data
30
+
31
+
32
+ # use `python run.py debug=true` for easy debugging!
33
+ # this will run 1 train, val and test loop with only 1 batch
34
+ # equivalent to running `python run.py trainer.fast_dev_run=true`
35
+ # (this is placed here just for easier access from command line)
36
+ debug: False
37
+
38
+ # pretty print config at the start of the run using Rich library
39
+ print_config: True
40
+
41
+ # disable python warnings if they annoy you
42
+ ignore_warnings: True
43
+
44
+ wandb_api_key: ${oc.env:wandb_api_key}
configs/datamodule/musdb18_hq.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.datamodules.musdb_datamodule.MusdbDataModule
2
+
3
+ # data_dir is specified in config.yaml
4
+ data_dir: null
5
+
6
+ single_channel: False
7
+
8
+ # chunk_size = (hop_length * (dim_t - 1) / sample_rate) secs
9
+ sample_rate: 44100
10
+ hop_length: ${model.hop_length} # stft hop_length
11
+ dim_t: ${model.dim_t} # number of stft frames
12
+
13
+ # number of overlapping wave samples between chunks when separating a whole track
14
+ overlap: ${model.overlap}
15
+
16
+ source_names:
17
+ - bass
18
+ - drums
19
+ - other
20
+ - vocals
21
+ target_name: ${model.target_name}
22
+
23
+ external_datasets: null
24
+ #external_datasets:
25
+ # - test
26
+
27
+
28
+ batch_size: 8
29
+ num_workers: 0
30
+ pin_memory: False
31
+
32
+ aug_params:
33
+ - 2 # maximum pitch shift in semitones (-x < shift param < x)
34
+ - 20 # maximum time stretch percentage (-x < stretch param < x)
35
+
36
+ validation_set:
37
+ - Actions - One Minute Smile
38
+ - Clara Berry And Wooldog - Waltz For My Victims
39
+ - Johnny Lokke - Promises & Lies
40
+ - Patrick Talbot - A Reason To Leave
41
+ - Triviul - Angelsaint
42
+ # - Alexander Ross - Goodbye Bolero
43
+ # - Fergessen - Nos Palpitants
44
+ # - Leaf - Summerghost
45
+ # - Skelpolu - Human Mistakes
46
+ # - Young Griffo - Pennies
47
+ # - ANiMAL - Rockshow
48
+ # - James May - On The Line
49
+ # - Meaxic - Take A Step
50
+ # - Traffic Experiment - Sirens
configs/datamodule/musdb_dev14.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ defaults:
3
+ - musdb18_hq
4
+
5
+ data_dir: ${oc.env:data_dir}
6
+
7
+ has_split_structure: True
8
+
9
+ validation_set:
10
+ # - Meaxic - Take A Step
11
+ # - Skelpolu - Human Mistakes
12
+ - Actions - One Minute Smile
13
+ - Clara Berry And Wooldog - Waltz For My Victims
14
+ - Johnny Lokke - Promises & Lies
15
+ - Patrick Talbot - A Reason To Leave
16
+ - Triviul - Angelsaint
17
+ - Alexander Ross - Goodbye Bolero
18
+ - Fergessen - Nos Palpitants
19
+ - Leaf - Summerghost
20
+ - Skelpolu - Human Mistakes
21
+ - Young Griffo - Pennies
22
+ - ANiMAL - Rockshow
23
+ - James May - On The Line
24
+ - Meaxic - Take A Step
25
+ - Traffic Experiment - Sirens
26
+
27
+
28
+ mode: musdb18hq
configs/evaluation.yaml ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # specify here default training configuration
4
+ defaults:
5
+ - model: ConvTDFNet_vocals
6
+ - logger:
7
+ - wandb
8
+ - tensorboard
9
+ - paths: default.yaml
10
+ # enable color logging
11
+ - override hydra/hydra_logging: colorlog
12
+ - override hydra/job_logging: colorlog
13
+
14
+ hydra:
15
+ run:
16
+ dir: ${get_eval_log_dir:${ckpt_path}}
17
+
18
+ #ckpt_path: "G:\\Experiments\\KLRef\\vocals.onnx"
19
+ ckpt_path: ${oc.env:ckpt_path}
20
+
21
+ split: 'test'
22
+ batch_size: 4
23
+ device: 'cuda:0'
24
+ bss: fast # fast or official
25
+ single: False # for debug investigation, only run the model on 1 single song
26
+
27
+ #data_dir: ${oc.env:data_dir}
28
+ eval_dir: ${oc.env:data_dir}
29
+ wandb_api_key: ${oc.env:wandb_api_key}
30
+
31
+ logger:
32
+ wandb:
33
+ # project: mdx_eval_${split}
34
+ project: new_eval_order
35
+ name: ${get_eval_log_dir:${ckpt_path}}
36
+
37
+ pool_workers: 8
38
+ double_chunk: False
39
+
40
+ overlap_add:
41
+ overlap_rate: 0.5
42
+ tmp_root: ${paths.root_dir}/tmp # for saving temp chunks, since we use ffmpeg and will need io to disk
43
+ samplerate: 44100
configs/experiment/bass_dis.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python run.py experiment=example_simple.yaml
5
+
6
+ defaults:
7
+ - multigpu_default
8
+ - override /model: bass.yaml
9
+
10
+ seed: 2021
11
+
12
+ exp_name: bass_g32
13
+
14
+ # the name inside project
15
+ logger:
16
+ wandb:
17
+ name: ${exp_name}
18
+
19
+ model:
20
+ lr: 0.0002
21
+ optimizer: adamW
22
+ bn_norm: syncBN
23
+ audio_ch: 2 # datamodule.single_channel
24
+ g: 32
25
+
26
+ trainer:
27
+ devices: 2 # int or list
28
+ sync_batchnorm: True
29
+ track_grad_norm: 2
30
+ # gradient_clip_val: 5
31
+
32
+ datamodule:
33
+ batch_size: 8
34
+ num_workers: ${oc.decode:${oc.env:NUM_WORKERS}}
35
+ pin_memory: False
36
+ overlap: ${model.overlap}
37
+ audio_ch: ${model.audio_ch}
38
+ epoch_size:
configs/experiment/drums_dis.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python run.py experiment=example_simple.yaml
5
+
6
+ defaults:
7
+ - multigpu_default
8
+ - override /model: drums.yaml
9
+
10
+ seed: 2021
11
+
12
+ exp_name: drums_g32
13
+
14
+ # the name inside project
15
+ logger:
16
+ wandb:
17
+ name: ${exp_name}
18
+
19
+ model:
20
+ lr: 0.0002
21
+ optimizer: adamW
22
+ bn_norm: syncBN
23
+ audio_ch: 2 # datamodule.single_channel
24
+ g: 32
25
+
26
+ trainer:
27
+ devices: 2 # int or list
28
+ sync_batchnorm: True
29
+ track_grad_norm: 2
30
+ # gradient_clip_val: 5
31
+
32
+ datamodule:
33
+ batch_size: 8
34
+ num_workers: ${oc.decode:${oc.env:NUM_WORKERS}}
35
+ pin_memory: False
36
+ overlap: ${model.overlap}
37
+ audio_ch: ${model.audio_ch}
38
+ epoch_size:
configs/experiment/multigpu_default.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python run.py experiment=example_simple.yaml
5
+
6
+ defaults:
7
+ - override /callbacks: default
8
+ - override /logger:
9
+ - wandb
10
+ - tensorboard
11
+
12
+
13
+ #callbacks:
14
+ # early_stopping:
15
+ # patience: 1000000
16
+
17
+ #datamodule:
18
+ # external_datasets:
19
+ # - test
20
+
21
+ trainer:
22
+ max_epochs: 1000000
23
+ accelerator: cuda
24
+ amp_backend: native
25
+ precision: 16
26
+ track_grad_norm: -1
configs/experiment/other_dis.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python run.py experiment=example_simple.yaml
5
+
6
+ defaults:
7
+ - multigpu_default
8
+ - override /model: other.yaml
9
+
10
+ seed: 2021
11
+
12
+ exp_name: other_g32
13
+
14
+ # the name inside project
15
+ logger:
16
+ wandb:
17
+ name: ${exp_name}
18
+
19
+ model:
20
+ lr: 0.0002
21
+ optimizer: adamW
22
+ bn_norm: syncBN
23
+ audio_ch: 2 # datamodule.single_channel
24
+ g: 32
25
+
26
+ trainer:
27
+ devices: 2 # int or list
28
+ sync_batchnorm: True
29
+ track_grad_norm: 2
30
+ # gradient_clip_val: 5
31
+
32
+ datamodule:
33
+ batch_size: 8
34
+ num_workers: ${oc.decode:${oc.env:NUM_WORKERS}}
35
+ pin_memory: False
36
+ overlap: ${model.overlap}
37
+ audio_ch: ${model.audio_ch}
38
+ epoch_size:
configs/experiment/vocals_dis.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python run.py experiment=example_simple.yaml
5
+
6
+ defaults:
7
+ - multigpu_default
8
+ - override /model: vocals.yaml
9
+
10
+ seed: 2021
11
+
12
+ exp_name: vocals_g32
13
+
14
+ # the name inside project
15
+ logger:
16
+ wandb:
17
+ name: ${exp_name}
18
+
19
+ model:
20
+ lr: 0.0002
21
+ optimizer: adamW
22
+ bn_norm: syncBN
23
+ audio_ch: 2 # datamodule.single_channel
24
+ g: 32
25
+
26
+ trainer:
27
+ devices: 2 # int or list
28
+ sync_batchnorm: True
29
+ track_grad_norm: 2
30
+ # gradient_clip_val: 5
31
+
32
+ datamodule:
33
+ batch_size: 8
34
+ num_workers: ${oc.decode:${oc.env:NUM_WORKERS}}
35
+ pin_memory: False
36
+ overlap: ${model.overlap}
37
+ audio_ch: ${model.audio_ch}
38
+ epoch_size:
configs/hydra/default.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # output paths for hydra logs
2
+ run:
3
+ # dir: logs/runs/${datamodule.target_name}_${exp_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
4
+ dir: ${get_train_log_dir:${datamodule.target_name},${exp_name}}
5
+
6
+ sweep:
7
+ # dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S}
8
+ dir: ${get_sweep_log_dir:${datamodule.target_name},${exp_name}}
9
+ subdir: ${hydra.job.num}
10
+
11
+ # you can set here environment variables that are universal for all users
12
+ # for system specific variables (like data paths) it's better to use .env file!
13
+ job:
14
+ env_set:
15
+ EXAMPLE_VAR: "example_value"
16
+
configs/infer.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # specify here default training configuration
4
+ defaults:
5
+ - model: vocals
6
+ - paths: default.yaml
7
+ # enable color logging
8
+ - override hydra/hydra_logging: colorlog
9
+ - override hydra/job_logging: colorlog
10
+
11
+ #hydra:
12
+ # run:
13
+ # dir: ${get_eval_log_dir:${ckpt_path}}
14
+
15
+ #ckpt_path: "G:\\Experiments\\KLRef\\vocals.onnx"
16
+ ckpt_path:
17
+ mixture_path:
18
+ batch_size: 4
19
+ device: 'cuda:0'
20
+
21
+ double_chunk: False
22
+
23
+ overlap_add:
24
+ overlap_rate: 0.5
25
+ tmp_root: ${paths.root_dir}/tmp # for saving temp chunks, since we use ffmpeg and will need io to disk
26
+ samplerate: 44100
configs/logger/csv.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # csv logger built in lightning
2
+
3
+ csv:
4
+ _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
5
+ save_dir: "."
6
+ name: "csv/"
7
+ version: null
8
+ prefix: ""
configs/logger/many_loggers.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # train with many loggers at once
2
+
3
+ defaults:
4
+ # - aim.yaml
5
+ # - comet.yaml
6
+ - csv.yaml
7
+ # - mlflow.yaml
8
+ # - neptune.yaml
9
+ # - tensorboard.yaml
10
+ - wandb.yaml
configs/logger/neptune.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://neptune.ai
2
+
3
+ neptune:
4
+ _target_: pytorch_lightning.loggers.neptune.NeptuneLogger
5
+ api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is laoded from environment variable
6
+ project_name: your_name/template-tests
7
+ close_after_fit: True
8
+ offline_mode: False
9
+ experiment_name: null
10
+ experiment_id: null
11
+ prefix: ""
configs/logger/none.yaml ADDED
File without changes
configs/logger/tensorboard.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://www.tensorflow.org/tensorboard/
2
+
3
+ tensorboard:
4
+ _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
5
+ save_dir: "tensorboard/"
6
+ name: "default"
7
+ version: null
8
+ log_graph: False
9
+ default_hp_metric: True
10
+ prefix: ""
configs/logger/wandb.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://wandb.ai
2
+
3
+ wandb:
4
+ _target_: pytorch_lightning.loggers.wandb.WandbLogger
5
+ project: dtt_${model.target_name}
6
+ name: null
7
+ save_dir: ${hydra:run.dir}
8
+ offline: False # set True to store all logs only locally
9
+ id: null # pass correct id to resume experiment!
10
+ # entity: "" # set to name of your wandb team or just remove it
11
+ log_model: False
12
+ prefix: ""
13
+ job_type: "train"
14
+ group: ""
15
+ tags: []
configs/model/bass.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.dp_tdf.dp_tdf_net.DPTDFNet
2
+
3
+ # abstract parent class
4
+ target_name: 'bass'
5
+ lr: 0.0001
6
+ optimizer: adamW
7
+
8
+ dim_f: 864
9
+ dim_t: 256
10
+ n_fft: 6144
11
+ hop_length: 1024
12
+ overlap: 3072
13
+
14
+ audio_ch: 2
15
+
16
+ block_type: TFC_TDF_Res2
17
+ num_blocks: 5
18
+ l: 3
19
+ g: 32
20
+ k: 3
21
+ bn: 2
22
+ bias: False
23
+ bn_norm: BN
24
+ bandsequence:
25
+ rnn_type: LSTM
26
+ bidirectional: True
27
+ num_layers: 4
28
+ n_heads: 2
configs/model/drums.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.dp_tdf.dp_tdf_net.DPTDFNet
2
+
3
+ # abstract parent class
4
+ target_name: 'drums'
5
+ lr: 0.0001
6
+ optimizer: adamW
7
+
8
+ dim_f: 2048
9
+ dim_t: 256
10
+ n_fft: 6144
11
+ hop_length: 1024
12
+ overlap: 3072
13
+
14
+ audio_ch: 2
15
+
16
+ block_type: TFC_TDF_Res2
17
+ num_blocks: 5
18
+ l: 3
19
+ g: 32
20
+ k: 3
21
+ bn: 8
22
+ bias: False
23
+ bn_norm: BN
24
+ bandsequence:
25
+ rnn_type: LSTM
26
+ bidirectional: True
27
+ num_layers: 4
28
+ n_heads: 2
configs/model/other.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.dp_tdf.dp_tdf_net.DPTDFNet
2
+
3
+ # abstract parent class
4
+ target_name: 'other'
5
+ lr: 0.0001
6
+ optimizer: adamW
7
+
8
+ dim_f: 2048
9
+ dim_t: 256
10
+ n_fft: 6144
11
+ hop_length: 1024
12
+ overlap: 3072
13
+
14
+ audio_ch: 2
15
+
16
+ block_type: TFC_TDF_Res2
17
+ num_blocks: 5
18
+ l: 3
19
+ g: 32
20
+ k: 3
21
+ bn: 8
22
+ bias: False
23
+ bn_norm: BN
24
+ bandsequence:
25
+ rnn_type: LSTM
26
+ bidirectional: True
27
+ num_layers: 4
28
+ n_heads: 2
configs/model/vocals.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.dp_tdf.dp_tdf_net.DPTDFNet
2
+
3
+ # abstract parent class
4
+ target_name: 'vocals'
5
+ lr: 0.0001
6
+ optimizer: adamW
7
+
8
+ dim_f: 2048
9
+ dim_t: 256
10
+ n_fft: 6144
11
+ hop_length: 1024
12
+ overlap: 3072
13
+
14
+ audio_ch: 2
15
+
16
+ block_type: TFC_TDF_Res2
17
+ num_blocks: 5
18
+ l: 3
19
+ g: 32
20
+ k: 3
21
+ bn: 8
22
+ bias: False
23
+ bn_norm: BN
24
+ bandsequence:
25
+ rnn_type: LSTM
26
+ bidirectional: True
27
+ num_layers: 4
28
+ n_heads: 2
configs/paths/default.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # path to root directory
2
+ # this requires PROJECT_ROOT environment variable to exist
3
+ # you can replace it with "." if you want the root to be the current working directory
4
+ root_dir: ${oc.env:PROJECT_ROOT}
5
+
6
+ # path to data directory
7
+ data_dir: ${paths.root_dir}/data/
8
+
9
+ # path to logging directory
10
+ log_dir: ${oc.env:LOG_DIR}
11
+
12
+ # path to output directory, created dynamically by hydra
13
+ # path generation pattern is specified in `configs/hydra/default.yaml`
14
+ # use it to store all files generated during the run, like ckpts and metrics
15
+ output_dir: ${hydra:runtime.output_dir}
16
+
17
+ # path to working directory
18
+ work_dir: ${hydra:runtime.cwd}
configs/trainer/ddp.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - default.yaml
3
+
4
+ # use "ddp_spawn" instead of "ddp",
5
+ # it's slower but normal "ddp" currently doesn't work ideally with hydra
6
+ # https://github.com/facebookresearch/hydra/issues/2070
7
+ # https://pytorch-lightning.readthedocs.io/en/latest/accelerators/gpu_intermediate.html#distributed-data-parallel-spawn
8
+ strategy: ddp_spawn
9
+
10
+ accelerator: gpu
11
+ devices: 2
12
+ num_nodes: 1
13
+ sync_batchnorm: True
configs/trainer/default.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: pytorch_lightning.Trainer
2
+
3
+ default_root_dir: ${paths.output_dir}
4
+
5
+ min_epochs: 1 # prevents early stopping
6
+ max_epochs: 10
7
+
8
+ accelerator: cpu
9
+ devices: 1
10
+
11
+ # mixed precision for extra speed-up
12
+ # precision: 16
13
+
14
+ # perform a validation loop every N training epochs
15
+ check_val_every_n_epoch: 1
16
+
17
+ # set True to to ensure deterministic results
18
+ # makes training slower but gives more reproducibility than just setting seeds
19
+ deterministic: False
configs/trainer/minimal.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: pytorch_lightning.Trainer
2
+
3
+ defaults:
4
+ - default
5
+
6
+ devices: 4
7
+
8
+ resume_from_checkpoint:
9
+ auto_lr_find: False
10
+ deterministic: True
11
+ accelerator: dp
12
+ sync_batchnorm: False
13
+
14
+ max_epochs: 3000
15
+ min_epochs: 1
16
+ check_val_every_n_epoch: 10
17
+ num_sanity_val_steps: 1
18
+
19
+ precision: 16
20
+ amp_backend: "native"
21
+ amp_level: "O2"
src/__init__.py ADDED
File without changes
src/callbacks/__init__.py ADDED
File without changes
src/callbacks/onnx_callback.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ from typing import Dict, Any
3
+
4
+ import torch
5
+ from pytorch_lightning import Callback
6
+ import pytorch_lightning as pl
7
+ import inspect
8
+ from src.models.mdxnet import AbstractMDXNet
9
+
10
+
11
+ class MakeONNXCallback(Callback):
12
+ """Upload all *.py files to wandb as an artifact, at the beginning of the run."""
13
+
14
+ def __init__(self, dirpath: str):
15
+ self.dirpath = dirpath
16
+ if not os.path.exists(self.dirpath):
17
+ os.mkdir(self.dirpath)
18
+
19
+ def on_save_checkpoint(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule',
20
+ checkpoint: Dict[str, Any]) -> dict:
21
+ res = super().on_save_checkpoint(trainer, pl_module, checkpoint)
22
+
23
+ var = inspect.signature(pl_module.__init__).parameters
24
+ model = pl_module.__class__(**dict((name, pl_module.__dict__[name]) for name in var))
25
+ model.load_state_dict(pl_module.state_dict())
26
+
27
+ target_dir = '{}epoch_{}'.format(self.dirpath, pl_module.current_epoch)
28
+
29
+ try:
30
+ if not os.path.exists(target_dir):
31
+ os.mkdir(target_dir)
32
+
33
+ with torch.no_grad():
34
+ torch.onnx.export(model,
35
+ torch.zeros(model.inference_chunk_shape),
36
+ '{}/{}.onnx'.format(target_dir, model.target_name),
37
+ export_params=True, # store the trained parameter weights inside the model file
38
+ opset_version=13, # the ONNX version to export the model to
39
+ do_constant_folding=True, # whether to execute constant folding for optimization
40
+ input_names=['input'], # the model's input names
41
+ output_names=['output'], # the model's output names
42
+ dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
43
+ 'output': {0: 'batch_size'}})
44
+ except:
45
+ print('onnx error')
46
+ finally:
47
+ del model
48
+
49
+ return res
src/callbacks/wandb_callbacks.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+ from typing import List, Optional, Any
4
+
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sn
7
+ import torch
8
+ import wandb
9
+ from pytorch_lightning import Callback, Trainer
10
+ from pytorch_lightning.loggers import WandbLogger
11
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
12
+ from sklearn import metrics
13
+ from sklearn.metrics import f1_score, precision_score, recall_score
14
+
15
+
16
+ def get_wandb_logger(trainer: Trainer) -> WandbLogger:
17
+ """Safely get Weights&Biases logger from Trainer."""
18
+
19
+ if isinstance(trainer.logger, WandbLogger):
20
+ return trainer.logger
21
+
22
+ if isinstance(trainer.loggers, list):
23
+ for logger in trainer.loggers:
24
+ if isinstance(logger, WandbLogger):
25
+ return logger
26
+
27
+ raise Exception(
28
+ "You are using wandb related callback, but WandbLogger was not found for some reason..."
29
+ )
30
+
31
+
32
+ class UploadValidTrack(Callback):
33
+ def __init__(self, crop: int, upload_after_n_epoch: int):
34
+ self.sample_length = crop * 44100
35
+ self.upload_after_n_epoch = upload_after_n_epoch
36
+ self.len_left_window = self.len_right_window = self.sample_length // 2
37
+
38
+ def on_validation_batch_end(
39
+ self,
40
+ trainer: 'pl.Trainer',
41
+ pl_module: 'pl.LightningModule',
42
+ outputs: Optional[STEP_OUTPUT],
43
+ batch: Any,
44
+ batch_idx: int,
45
+ dataloader_idx: int,
46
+ ) -> None:
47
+ if outputs is None:
48
+ return
49
+ track_id = outputs['track_id']
50
+ track = outputs['track']
51
+
52
+ logger = get_wandb_logger(trainer=trainer)
53
+ experiment = logger.experiment
54
+ if pl_module.current_epoch < self.upload_after_n_epoch:
55
+ return None
56
+
57
+ mid = track.shape[-1]//2
58
+ track = track[:, mid-self.len_left_window:mid+self.len_right_window]
59
+
60
+ experiment.log({'track={}_epoch={}'.format(track_id, pl_module.current_epoch):
61
+ [wandb.Audio(track.T, sample_rate=44100)]})
62
+
63
+
64
+ class WatchModel(Callback):
65
+ """Make wandb watch model at the beginning of the run."""
66
+
67
+ def __init__(self, log: str = "gradients", log_freq: int = 100):
68
+ self.log = log
69
+ self.log_freq = log_freq
70
+
71
+ def on_train_start(self, trainer, pl_module):
72
+ logger = get_wandb_logger(trainer=trainer)
73
+ logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
74
+
75
+
76
+ class UploadCodeAsArtifact(Callback):
77
+ """Upload all *.py files to wandb as an artifact, at the beginning of the run."""
78
+
79
+ def __init__(self, code_dir: str):
80
+ self.code_dir = code_dir
81
+
82
+ def on_train_start(self, trainer, pl_module):
83
+ logger = get_wandb_logger(trainer=trainer)
84
+ experiment = logger.experiment
85
+
86
+ code = wandb.Artifact("project-source", type="code")
87
+ for path in glob.glob(os.path.join(self.code_dir, "**/*.py"), recursive=True):
88
+ code.add_file(path)
89
+
90
+ experiment.use_artifact(code)
91
+
92
+
93
+ class UploadCheckpointsAsArtifact(Callback):
94
+ """Upload checkpoints to wandb as an artifact, at the end of run."""
95
+
96
+ def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
97
+ self.ckpt_dir = ckpt_dir
98
+ self.upload_best_only = upload_best_only
99
+
100
+ def on_train_end(self, trainer, pl_module):
101
+ logger = get_wandb_logger(trainer=trainer)
102
+ experiment = logger.experiment
103
+
104
+ ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")
105
+
106
+ if self.upload_best_only:
107
+ ckpts.add_file(trainer.checkpoint_callback.best_model_path)
108
+ else:
109
+ for path in glob.glob(os.path.join(self.ckpt_dir, "**/*.ckpt"), recursive=True):
110
+ ckpts.add_file(path)
111
+
112
+ experiment.use_artifact(ckpts)
113
+
114
+
115
+ class LogConfusionMatrix(Callback):
116
+ """Generate confusion matrix every epoch and send it to wandb.
117
+ Expects validation step to return predictions and targets.
118
+ """
119
+
120
+ def __init__(self):
121
+ self.preds = []
122
+ self.targets = []
123
+ self.ready = True
124
+
125
+ def on_sanity_check_start(self, trainer, pl_module) -> None:
126
+ self.ready = False
127
+
128
+ def on_sanity_check_end(self, trainer, pl_module):
129
+ """Start executing this callback only after all validation sanity checks end."""
130
+ self.ready = True
131
+
132
+ def on_validation_batch_end(
133
+ self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
134
+ ):
135
+ """Gather data from single batch."""
136
+ if self.ready:
137
+ self.preds.append(outputs["preds"])
138
+ self.targets.append(outputs["targets"])
139
+
140
+ def on_validation_epoch_end(self, trainer, pl_module):
141
+ """Generate confusion matrix."""
142
+ if self.ready:
143
+ logger = get_wandb_logger(trainer)
144
+ experiment = logger.experiment
145
+
146
+ preds = torch.cat(self.preds).cpu().numpy()
147
+ targets = torch.cat(self.targets).cpu().numpy()
148
+
149
+ confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds)
150
+
151
+ # set figure size
152
+ plt.figure(figsize=(14, 8))
153
+
154
+ # set labels size
155
+ sn.set(font_scale=1.4)
156
+
157
+ # set font size
158
+ sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g")
159
+
160
+ # names should be uniqe or else charts from different experiments in wandb will overlap
161
+ experiment.log({f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False)
162
+
163
+ # according to wandb docs this should also work but it crashes
164
+ # experiment.log(f{"confusion_matrix/{experiment.name}": plt})
165
+
166
+ # reset plot
167
+ plt.clf()
168
+
169
+ self.preds.clear()
170
+ self.targets.clear()
171
+
172
+
173
+ class LogF1PrecRecHeatmap(Callback):
174
+ """Generate f1, precision, recall heatmap every epoch and send it to wandb.
175
+ Expects validation step to return predictions and targets.
176
+ """
177
+
178
+ def __init__(self, class_names: List[str] = None):
179
+ self.preds = []
180
+ self.targets = []
181
+ self.ready = True
182
+
183
+ def on_sanity_check_start(self, trainer, pl_module):
184
+ self.ready = False
185
+
186
+ def on_sanity_check_end(self, trainer, pl_module):
187
+ """Start executing this callback only after all validation sanity checks end."""
188
+ self.ready = True
189
+
190
+ def on_validation_batch_end(
191
+ self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
192
+ ):
193
+ """Gather data from single batch."""
194
+ if self.ready:
195
+ self.preds.append(outputs["preds"])
196
+ self.targets.append(outputs["targets"])
197
+
198
+ def on_validation_epoch_end(self, trainer, pl_module):
199
+ """Generate f1, precision and recall heatmap."""
200
+ if self.ready:
201
+ logger = get_wandb_logger(trainer=trainer)
202
+ experiment = logger.experiment
203
+
204
+ preds = torch.cat(self.preds).cpu().numpy()
205
+ targets = torch.cat(self.targets).cpu().numpy()
206
+ f1 = f1_score(preds, targets, average=None)
207
+ r = recall_score(preds, targets, average=None)
208
+ p = precision_score(preds, targets, average=None)
209
+ data = [f1, p, r]
210
+
211
+ # set figure size
212
+ plt.figure(figsize=(14, 3))
213
+
214
+ # set labels size
215
+ sn.set(font_scale=1.2)
216
+
217
+ # set font size
218
+ sn.heatmap(
219
+ data,
220
+ annot=True,
221
+ annot_kws={"size": 10},
222
+ fmt=".3f",
223
+ yticklabels=["F1", "Precision", "Recall"],
224
+ )
225
+
226
+ # names should be uniqe or else charts from different experiments in wandb will overlap
227
+ experiment.log({f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False)
228
+
229
+ # reset plot
230
+ plt.clf()
231
+
232
+ self.preds.clear()
233
+ self.targets.clear()
234
+
235
+
236
+ class LogImagePredictions(Callback):
237
+ """Logs a validation batch and their predictions to wandb.
238
+ Example adapted from:
239
+ https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY
240
+ """
241
+
242
+ def __init__(self, num_samples: int = 8):
243
+ super().__init__()
244
+ self.num_samples = num_samples
245
+ self.ready = True
246
+
247
+ def on_sanity_check_start(self, trainer, pl_module):
248
+ self.ready = False
249
+
250
+ def on_sanity_check_end(self, trainer, pl_module):
251
+ """Start executing this callback only after all validation sanity checks end."""
252
+ self.ready = True
253
+
254
+ def on_validation_epoch_end(self, trainer, pl_module):
255
+ if self.ready:
256
+ logger = get_wandb_logger(trainer=trainer)
257
+ experiment = logger.experiment
258
+
259
+ # get a validation batch from the validation dat loader
260
+ val_samples = next(iter(trainer.datamodule.val_dataloader()))
261
+ val_imgs, val_labels = val_samples
262
+
263
+ # run the batch through the network
264
+ val_imgs = val_imgs.to(device=pl_module.device)
265
+ logits = pl_module(val_imgs)
266
+ preds = torch.argmax(logits, axis=-1)
267
+
268
+ # log the images as wandb Image
269
+ experiment.log(
270
+ {
271
+ f"Images/{experiment.name}": [
272
+ wandb.Image(x, caption=f"Pred:{pred}, Label:{y}")
273
+ for x, pred, y in zip(
274
+ val_imgs[: self.num_samples],
275
+ preds[: self.num_samples],
276
+ val_labels[: self.num_samples],
277
+ )
278
+ ]
279
+ }
280
+ )
src/datamodules/__init__.py ADDED
File without changes
src/datamodules/datasets/__init__.py ADDED
File without changes
src/datamodules/datasets/musdb.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from abc import ABCMeta, ABC
3
+ from pathlib import Path
4
+
5
+ import soundfile
6
+ from torch.utils.data import Dataset
7
+ import torch
8
+ import numpy as np
9
+ import random
10
+ from tqdm import tqdm
11
+
12
+ from src.utils.utils import load_wav
13
+ from src import utils
14
+ import numpy as np
15
+
16
+ log = utils.get_pylogger(__name__)
17
+
18
def check_target_name(target_name, source_names):
    """Validate the configured target stem name.

    Prints an error and exits the process with status -1 when
    `target_name` is missing or not one of `source_names` / 'all'.

    Args:
        target_name: stem to separate (e.g. 'vocals'), or 'all'.
        source_names: list of valid stem names.
    """
    # Plain `if` checks instead of assert/except: asserts are stripped
    # under `python -O`, which would silently skip this validation.
    if target_name is None:
        print('[ERROR] please identify target name. ex) +datamodule.target_name="vocals"')
        exit(-1)
    if target_name not in source_names and target_name != 'all':
        # Fixed message typo: "should one of" -> "should be one of".
        print('[ERROR] target name should be one of "bass", "drums", "other", "vocals", "all"')
        exit(-1)
29
+
30
+
31
def check_sample_rate(sr, sample_track):
    """Verify that `sample_track`'s sample rate matches the configured `sr`.

    Prints an error and exits the process with status -1 on mismatch.

    Args:
        sr: expected sample rate from the config.
        sample_track: path to an audio file readable by soundfile.
    """
    # Read the file once and compare explicitly; the original asserted in a
    # try-block and re-read the whole track inside `except` just to print it.
    sample_rate = soundfile.read(sample_track)[1]
    if sample_rate != sr:
        print('[ERROR] sampling rate mismatched')
        print('\t=> sr in Config file: {}, but sr of data: {}'.format(sr, sample_rate))
        exit(-1)
40
+
41
+
42
class MusdbDataset(Dataset):
    """Common base for the MUSDB train/valid datasets.

    Holds the dataset root path, the chunk length in samples, and the
    fixed stem ordering shared by mixing and target lookup.
    """
    __metaclass__ = ABCMeta

    def __init__(self, data_dir, chunk_size):
        self.source_names = list(('bass', 'drums', 'other', 'vocals'))
        self.chunk_size, self.musdb_path = chunk_size, Path(data_dir)
49
+
50
+
51
class MusdbTrainDataset(MusdbDataset):
    """Training dataset: random chunks with random inter-track source mixing.

    Scans the (optionally pitch/tempo-augmented) train splits and caches
    per-track lengths in `<musdb>/metadata/<split>.pkl`.
    """

    def __init__(self, data_dir, chunk_size, target_name, aug_params, external_datasets, single_channel, epoch_size):
        """
        Args:
            data_dir: root of the MUSDB-HQ style directory tree.
            chunk_size: number of waveform samples per training example.
            target_name: stem to learn ('bass'/'drums'/'other'/'vocals' or 'all').
            aug_params: (max_pitch, max_tempo) ranges of pre-rendered augmentations.
            external_datasets: extra split names to include, or None.
            single_channel: if True, average stereo down to mono.
            epoch_size: fixed number of samples per epoch, or None to derive it
                from the total un-augmented track length // chunk_size.
        """
        super(MusdbTrainDataset, self).__init__(data_dir, chunk_size)

        self.single_channel = single_channel
        # NOTE(review): neg_lst is computed but not referenced anywhere in
        # this class — possibly dead code; confirm before removing.
        self.neg_lst = [x for x in self.source_names if x != target_name]

        self.target_name = target_name
        check_target_name(self.target_name, self.source_names)

        # metadata cache directory is created lazily on first run
        if not self.musdb_path.joinpath('metadata').exists():
            os.mkdir(self.musdb_path.joinpath('metadata'))

        splits = ['train']
        if external_datasets is not None:
            splits += external_datasets

        # collect paths for datasets and metadata (track names and duration)
        datasets, metadata_caches = [], []
        raw_datasets = []  # un-augmented datasets (used for the epoch-size estimate)
        for split in splits:
            raw_datasets.append(self.musdb_path.joinpath(split))
            max_pitch, max_tempo = aug_params
            # augmented splits are assumed to be pre-rendered on disk
            # under directories named '<split>_p=<pitch>_t=<tempo>'
            for p in range(-max_pitch, max_pitch+1):
                for t in range(-max_tempo, max_tempo+1, 10):
                    aug_split = split if p==t==0 else split + f'_p={p}_t={t}'
                    datasets.append(self.musdb_path.joinpath(aug_split))
                    metadata_caches.append(self.musdb_path.joinpath('metadata').joinpath(aug_split + '.pkl'))

        # collect all track names and their duration
        self.metadata = []
        raw_track_lengths = []  # for calculating epoch size
        for i, (dataset, metadata_cache) in enumerate(tqdm(zip(datasets, metadata_caches))):
            try:
                metadata = torch.load(metadata_cache)
            except FileNotFoundError:
                # first run: measure each track once; all stems of a track are
                # assumed to share the vocals stem's length
                print('creating metadata for', dataset)
                metadata = []
                for track_name in sorted(os.listdir(dataset)):
                    track_path = dataset.joinpath(track_name)
                    track_length = load_wav(track_path.joinpath('vocals.wav')).shape[-1]
                    metadata.append((track_path, track_length))
                torch.save(metadata, metadata_cache)

            self.metadata += metadata
            if dataset in raw_datasets:
                raw_track_lengths += [length for path, length in metadata]

        self.epoch_size = sum(raw_track_lengths) // self.chunk_size if epoch_size is None else epoch_size
        log.info(f'epoch size: {self.epoch_size}')

    def __getitem__(self, _):
        """Return one random (mix, target) chunk pair; the index is ignored."""
        sources = []
        for source_name in self.source_names:
            track_path, track_length = random.choice(self.metadata) # random mixing between tracks
            source = load_wav(track_path.joinpath(source_name + '.wav'),
                              track_length=track_length, chunk_size=self.chunk_size) # (2, times)
            sources.append(source)

        # the mixture is re-synthesized from the independently drawn stems
        mix = sum(sources)

        if self.target_name == 'all':
            # Targets for models that separate all four sources (ex. Demucs).
            # This adds additional 'source' dimension => batch_shape=[batch, source, channel, time]
            target = sources
        else:
            target = sources[self.source_names.index(self.target_name)]

        mix, target = torch.tensor(mix), torch.tensor(target)
        if self.single_channel:
            # average stereo channels down to mono, keeping the channel axis
            mix = torch.mean(mix, dim=0, keepdim=True)
            target = torch.mean(target, dim=0, keepdim=True)
        return mix, target

    def __len__(self):
        # epoch length is decoupled from the number of tracks on disk
        return self.epoch_size
127
+
128
+
129
class MusdbValidDataset(MusdbDataset):
    """Validation dataset: one whole track per item, pre-split into
    overlapping chunk batches so inference stays memory-bounded.
    """

    def __init__(self, data_dir, chunk_size, target_name, overlap, batch_size, single_channel):
        """
        Args:
            data_dir: root of the MUSDB-HQ style directory tree.
            chunk_size: model input length in samples.
            target_name: stem to evaluate, or 'all'.
            overlap: samples trimmed from each side of every chunk's output.
            batch_size: chunks per batch fed to the model.
            single_channel: if True, average stereo down to mono.
        """
        super(MusdbValidDataset, self).__init__(data_dir, chunk_size)

        self.target_name = target_name
        check_target_name(self.target_name, self.source_names)

        self.overlap = overlap
        self.batch_size = batch_size
        self.single_channel = single_channel

        musdb_valid_path = self.musdb_path.joinpath('valid')
        self.track_paths = [musdb_valid_path.joinpath(track_name)
                            for track_name in os.listdir(musdb_valid_path)]

    def __getitem__(self, index):
        """Return (list of chunk batches, full-length target) for one track."""
        mix = load_wav(self.track_paths[index].joinpath('mixture.wav')) # (2, time)

        if self.target_name == 'all':
            # Targets for models that separate all four sources (ex. Demucs).
            # This adds additional 'source' dimension => batch_shape=[batch, source, channel, time]
            target = [load_wav(self.track_paths[index].joinpath(source_name + '.wav'))
                      for source_name in self.source_names]
        else:
            target = load_wav(self.track_paths[index].joinpath(self.target_name + '.wav'))

        # pad both ends so consecutive chunks overlap by `overlap` samples;
        # each chunk contributes only its center (chunk_size - 2*overlap) samples
        chunk_output_size = self.chunk_size - 2 * self.overlap
        left_pad = np.zeros([2, self.overlap])
        right_pad = np.zeros([2, self.overlap + chunk_output_size - (mix.shape[-1] % chunk_output_size)])
        mix_padded = np.concatenate([left_pad, mix, right_pad], 1)

        num_chunks = mix_padded.shape[-1] // chunk_output_size
        mix_chunks = np.array([mix_padded[:, i * chunk_output_size: i * chunk_output_size + self.chunk_size]
                               for i in range(num_chunks)])
        mix_chunk_batches = torch.tensor(mix_chunks, dtype=torch.float32).split(self.batch_size)
        target = torch.tensor(target)

        if self.single_channel:
            # chunk tensors are (batch, channel, time) -> average dim 1;
            # the target is (channel, time) -> average dim 0
            mix_chunk_batches = [torch.mean(t, dim=1, keepdim=True) for t in mix_chunk_batches]
            target = torch.mean(target, dim=0, keepdim=True)

        return mix_chunk_batches, target

    def __len__(self):
        return len(self.track_paths)
src/datamodules/musdb_datamodule.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from os.path import exists, join
3
+ from pathlib import Path
4
+ from typing import Optional, Tuple
5
+
6
+ from pytorch_lightning import LightningDataModule
7
+ from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
8
+
9
+ from src.datamodules.datasets.musdb import MusdbTrainDataset, MusdbValidDataset
10
+
11
+
12
class MusdbDataModule(LightningDataModule):
    """
    LightningDataModule for Musdb18-HQ dataset.
    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))
    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data
    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """

    def __init__(
        self,
        data_dir: str,
        aug_params,
        target_name: str,
        overlap: int,
        hop_length: int,
        dim_t: int,
        sample_rate: int,
        batch_size: int,
        num_workers: int,
        pin_memory: bool,
        external_datasets,
        audio_ch: int,
        epoch_size,
        **kwargs,
    ):
        """
        Args:
            data_dir: root of the MUSDB-HQ style directory tree.
            aug_params: (max_pitch, max_tempo) augmentation ranges, forwarded
                to MusdbTrainDataset.
            target_name: stem to separate, or 'all'.
            overlap: per-side overlap (in samples) for validation chunks.
            hop_length: STFT hop size; chunk_size = hop_length * (dim_t - 1).
            dim_t: number of STFT frames per chunk.
            sample_rate: expected audio sample rate (stored, not validated here).
            batch_size: training batch size (validation always uses 1 track).
            num_workers / pin_memory: forwarded to both DataLoaders.
            external_datasets: extra split names for training, or None.
            audio_ch: 1 for mono, 2 for stereo.
            epoch_size: fixed train epoch length, or None to derive it.
            **kwargs: must contain 'validation_set', the list of track names
                that make up the validation split.
        """
        super().__init__()

        self.data_dir = Path(data_dir)
        self.target_name = target_name
        self.aug_params = aug_params
        self.external_datasets = external_datasets

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        # audio-related
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.single_channel = audio_ch == 1

        # derived
        self.chunk_size = hop_length * (dim_t - 1)
        self.overlap = overlap

        self.epoch_size = epoch_size

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        trainset_path = self.data_dir.joinpath('train')
        validset_path = self.data_dir.joinpath('valid')

        # create validation split
        # NOTE: side effect — on first run this MOVES the configured tracks
        # out of <data_dir>/train into <data_dir>/valid on disk.
        if not exists(validset_path):
            from shutil import move
            os.mkdir(validset_path)
            for track in kwargs['validation_set']:
                if trainset_path.joinpath(track).exists():
                    move(trainset_path.joinpath(track), validset_path.joinpath(track))
        else:
            # an existing valid/ directory must match the configured split exactly
            valid_files = os.listdir(validset_path)
            assert set(valid_files) == set(kwargs['validation_set'])

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: self.data_train, self.data_val, self.data_test."""
        self.data_train = MusdbTrainDataset(self.data_dir,
                                            self.chunk_size,
                                            self.target_name,
                                            self.aug_params,
                                            self.external_datasets,
                                            self.single_channel,
                                            self.epoch_size)

        self.data_val = MusdbValidDataset(self.data_dir,
                                          self.chunk_size,
                                          self.target_name,
                                          self.overlap,
                                          self.batch_size,
                                          self.single_channel)

    def train_dataloader(self):
        """Shuffled training loader over random chunks."""
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self):
        """Validation loader; batch_size must stay 1 (one whole track per item)."""
        return DataLoader(
            dataset=self.data_val,
            batch_size=1,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=False,
        )
src/dp_tdf/__init__.py ADDED
File without changes
src/dp_tdf/abstract.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABCMeta
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from pytorch_lightning import LightningModule
10
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
11
+
12
+ from src.utils.utils import sdr, simplified_msseval
13
+
14
+
15
class AbstractModel(LightningModule):
    """Base LightningModule for spectrogram-domain source separation.

    Provides STFT/iSTFT helpers, an L1 waveform training loss, whole-track
    chunked validation (uSDR/cSDR), and overlap-add inference (`demix`).
    Subclasses implement `forward`, which maps an input spectrogram of shape
    (batch, audio_ch*2, dim_f, frames) to the estimated target spectrogram
    of the same shape.
    """
    __metaclass__ = ABCMeta

    def __init__(self, target_name,
                 lr, optimizer,
                 dim_f, dim_t, n_fft, hop_length, overlap,
                 audio_ch,
                 **kwargs):
        """
        Args:
            target_name: stem this model separates (e.g. 'vocals').
            lr: learning rate.
            optimizer: 'rmsprop' or 'adamW'.
            dim_f: number of low-frequency STFT bins kept as model input.
            dim_t: number of STFT frames per training chunk.
            n_fft: FFT size.
            hop_length: STFT hop size in samples.
            overlap: samples trimmed from each side of a validation chunk.
            audio_ch: number of audio channels (1=mono, 2=stereo).
        """
        super().__init__()
        self.target_name = target_name
        self.lr = lr
        self.optimizer = optimizer
        # real and imaginary parts are stacked on the channel axis, hence *2
        self.dim_c_in = audio_ch * 2
        self.dim_c_out = audio_ch * 2
        self.dim_f = dim_f
        self.dim_t = dim_t
        self.n_fft = n_fft
        self.n_bins = n_fft // 2 + 1
        self.hop_length = hop_length
        self.audio_ch = audio_ch

        # waveform samples per training chunk; inference uses ~2x longer chunks
        self.chunk_size = hop_length * (self.dim_t - 1)
        self.inference_chunk_size = hop_length * (self.dim_t*2 - 1)
        self.overlap = overlap
        # non-trainable Parameters so they follow the module's device/dtype
        self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False)
        # zeros re-inserted by istft() for the high bins dropped by stft() (dim_f -> n_bins)
        self.freq_pad = nn.Parameter(torch.zeros([1, self.dim_c_out, self.n_bins - self.dim_f, 1]), requires_grad=False)
        self.inference_chunk_shape = (self.stft(torch.zeros([1, audio_ch, self.inference_chunk_size]))).shape


    def configure_optimizers(self):
        """Instantiate the optimizer named by `self.optimizer`."""
        # NOTE(review): any other optimizer name silently returns None — confirm intended.
        if self.optimizer == 'rmsprop':
            print("Using RMSprop optimizer")
            return torch.optim.RMSprop(self.parameters(), self.lr)
        elif self.optimizer == 'adamW':
            print("Using AdamW optimizer")
            return torch.optim.AdamW(self.parameters(), self.lr)

    def comp_loss(self, pred_detail, target_wave):
        """L1 loss between the iSTFT of the predicted spectrogram and the target waveform."""
        pred_detail = self.istft(pred_detail)

        comp_loss = F.l1_loss(pred_detail, target_wave)

        self.log("train/comp_loss", comp_loss, sync_dist=True, on_step=False, on_epoch=True, prog_bar=False)

        return comp_loss


    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
        """One optimization step on a (mix, target) waveform chunk pair."""
        mix_wave, target_wave = args[0] # (batch, c, 261120)

        # input 1
        stft_44k = self.stft(mix_wave) # (batch, c*2, 1044, 256)
        # forward
        t_est_stft = self(stft_44k) # (batch, c, 1044, 256)

        loss = self.comp_loss(t_est_stft, target_wave)

        self.log("train/loss", loss, sync_dist=True, on_step=True, on_epoch=True, prog_bar=True)

        return {"loss": loss}


    # Validation SDR is calculated on whole tracks and not chunks since
    # short inputs have high possibility of being silent (all-zero signal)
    # which leads to very low sdr values regardless of the model.
    # A natural procedure would be to split a track into chunk batches and
    # load them on multiple gpus, but aggregation was too difficult.
    # So instead we load one whole track on a single device (data_loader batch_size should always be 1)
    # and do all the batch splitting and aggregation on a single device.
    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        """Separate one full track and score it (song-level and chunk-level SDR)."""
        mix_chunk_batches, target = args[0]

        # remove data_loader batch dimension
        # [(b, c, time)], (c, all_times)
        mix_chunk_batches, target = [batch[0] for batch in mix_chunk_batches], target[0]

        # process whole track in batches of chunks
        target_hat_chunks = []
        for batch in mix_chunk_batches:
            # input
            stft_44k = self.stft(batch) # (batch, c*2, 1044, 256)
            pred_detail = self(stft_44k) # (batch, c, 1044, 256), irm
            pred_detail = self.istft(pred_detail)

            # keep only each chunk's center; `overlap` samples per side are discarded
            target_hat_chunks.append(pred_detail[..., self.overlap:-self.overlap])
        target_hat_chunks = torch.cat(target_hat_chunks) # (b*len(ls),c,t)

        # concat all output chunks (c, all_times)
        target_hat = target_hat_chunks.transpose(0, 1).reshape(self.audio_ch, -1)[..., :target.shape[-1]]

        ests = target_hat.detach().cpu().numpy() # (c, all_times)
        references = target.cpu().numpy()
        score = sdr(ests, references)

        # (src, t, c)
        SDR = simplified_msseval(np.expand_dims(references.T, axis=0), np.expand_dims(ests.T, axis=0), chunk_size=44100)
        # self.log("val/sdr", score, sync_dist=True, on_step=False, on_epoch=True, logger=True)

        return {'song': score, 'chunk': SDR}

    def validation_epoch_end(self, outputs) -> None:
        """Aggregate per-track scores: mean song SDR (uSDR) and median chunk SDR (cSDR)."""
        avg_uSDR = torch.Tensor([x['song'] for x in outputs]).mean()
        self.log("val/usdr", avg_uSDR, sync_dist=True, on_step=False, on_epoch=True, logger=True)

        chunks = [x['chunk'][0, :] for x in outputs]
        # concat np array
        chunks = np.concatenate(chunks, axis=0)
        # nanmedian: chunks whose SDR is undefined (NaN) are excluded
        median_cSDR = np.nanmedian(chunks.flatten(), axis=0)
        median_cSDR = float(median_cSDR)
        self.log("val/csdr", median_cSDR, sync_dist=True, on_step=False, on_epoch=True, logger=True)

    def stft(self, x):
        '''
        Waveform -> cropped spectrogram with real/imag stacked on channels.

        Args:
            x: (batch, c, 261120)
        Returns:
            (batch, c*2, dim_f, frames) — only the lowest `dim_f` bins are kept.
        '''
        dim_b = x.shape[0]
        x = x.reshape([dim_b * self.audio_ch, -1]) # (batch*c, 261120)
        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, center=True) # (batch*c, 3073, 256, 2)
        x = x.permute([0, 3, 1, 2]) # (batch*c, 2, 3073, 256)
        x = x.reshape([dim_b, self.audio_ch, 2, self.n_bins, -1]).reshape([dim_b, self.audio_ch * 2, self.n_bins, -1]) # (batch, c*2, 3073, 256)
        return x[:, :, :self.dim_f] # (batch, c*2, 2048, 256)

    def istft(self, x):
        '''
        Inverse of `stft`: zero-pads the dropped bins back, then iSTFT.

        Args:
            x: (batch, c*2, 2048, 256)
        Returns:
            (batch, c, samples)
        '''
        dim_b = x.shape[0]
        x = torch.cat([x, self.freq_pad.repeat([x.shape[0], 1, 1, x.shape[-1]])], -2) # (batch, c*2, 3073, 256)
        x = x.reshape([dim_b, self.audio_ch, 2, self.n_bins, -1]).reshape([dim_b * self.audio_ch, 2, self.n_bins, -1]) # (batch*c, 2, 3073, 256)
        x = x.permute([0, 2, 3, 1]) # (batch*c, 3073, 256, 2)
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, center=True) # (batch*c, 261120)
        return x.reshape([dim_b, self.audio_ch, -1]) # (batch,c,261120)

    def demix(self, mix, inf_chunk_size, batch_size=5, inf_overf=4):
        '''
        Overlap-add inference over a whole mixture.

        Args:
            mix: (C, L) waveform.
            inf_chunk_size: chunk length in samples.
            batch_size: chunks per forward pass.
            inf_overf: overlap factor; hop = inf_chunk_size // inf_overf.
        Returns:
            est: (src, C, L)

        NOTE(review): assumes stereo input and CUDA availability
        (`torch.zeros(2, ...)` and `.cuda()`) — confirm before using with
        mono audio or on CPU-only hosts.
        '''

        # batch_size = self.config.inference.batch_size
        # = self.chunk_size
        # self.instruments = ['bass', 'drums', 'other', 'vocals']
        num_instruments = 1

        inf_hop = inf_chunk_size // inf_overf # hop size
        L = mix.shape[1]
        # pad both ends so the chunk grid tiles the signal exactly
        pad_size = inf_hop - (L - inf_chunk_size) % inf_hop
        mix = torch.cat([torch.zeros(2, inf_chunk_size - inf_hop), torch.Tensor(mix), torch.zeros(2, pad_size + inf_chunk_size - inf_hop)], 1)
        mix = mix.cuda()

        chunks = []
        i = 0
        while i + inf_chunk_size <= mix.shape[1]:
            chunks.append(mix[:, i:i + inf_chunk_size])
            i += inf_hop
        chunks = torch.stack(chunks)

        batches = []
        i = 0
        while i < len(chunks):
            batches.append(chunks[i:i + batch_size])
            i = i + batch_size

        X = torch.zeros(num_instruments, 2, inf_chunk_size - inf_hop) # (src, c, t)
        X = X.cuda()
        with torch.cuda.amp.autocast():
            with torch.no_grad():
                for batch in batches:
                    x = self.stft(batch)
                    x = self(x)
                    x = self.istft(x) # (batch, c, 261120)
                    # insert new axis, the model only predict 1 src so we need to add axis
                    x = x[:,None, ...] # (batch, 1, c, 261120)
                    x = x.repeat([ 1, num_instruments, 1, 1]) # (batch, src, c, 261120)
                    for w in x: # iterate over batch
                        # overlap-add: keep the settled prefix, sum the
                        # overlapping region, then append the new tail
                        a = X[..., :-(inf_chunk_size - inf_hop)]
                        b = X[..., -(inf_chunk_size - inf_hop):] + w[..., :(inf_chunk_size - inf_hop)]
                        c = w[..., (inf_chunk_size - inf_hop):]
                        X = torch.cat([a, b, c], -1)

        # strip the warm-up/padding regions and normalize by the overlap factor
        estimated_sources = X[..., inf_chunk_size - inf_hop:-(pad_size + inf_chunk_size - inf_hop)] / inf_overf

        assert L == estimated_sources.shape[-1]

        return estimated_sources
204
+
src/dp_tdf/bandsequence.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ # Original code from https://github.com/amanteur/BandSplitRNN-Pytorch
5
+ class RNNModule(nn.Module):
6
+ """
7
+ RNN submodule of BandSequence module
8
+ """
9
+
10
+ def __init__(
11
+ self,
12
+ group_num: int,
13
+ input_dim_size: int,
14
+ hidden_dim_size: int,
15
+ rnn_type: str = 'lstm',
16
+ bidirectional: bool = True
17
+ ):
18
+ super(RNNModule, self).__init__()
19
+ self.groupnorm = nn.GroupNorm(group_num, input_dim_size)
20
+ self.rnn = getattr(nn, rnn_type)(
21
+ input_dim_size, hidden_dim_size, batch_first=True, bidirectional=bidirectional # 输出是2*hidden_dim_size,因为是bi
22
+ )
23
+ self.fc = nn.Linear(
24
+ hidden_dim_size * 2 if bidirectional else hidden_dim_size,
25
+ input_dim_size
26
+ )
27
+
28
+ def forward(
29
+ self,
30
+ x: torch.Tensor
31
+ ):
32
+ """
33
+ Input shape:
34
+ across T - [batch_size, k_subbands, time, n_features]
35
+ OR
36
+ across K - [batch_size, time, k_subbands, n_features]
37
+ """
38
+ B, K, T, N = x.shape # across T across K (keep in mind T->K, K->T)
39
+ # print(x.shape)
40
+
41
+ out = x.view(B * K, T, N) # [BK, T, N] [BT, K, N]
42
+
43
+ # print(out.shape)
44
+ # print(self.groupnorm)
45
+ out = self.groupnorm(
46
+ out.transpose(-1, -2)
47
+ ).transpose(-1, -2) # [BK, T, N] [BT, K, N]
48
+ out = self.rnn(out)[0] # [BK, T, H] [BT, K, H], 最后一维是特征
49
+ out = self.fc(out) # [BK, T, N] [BT, K, N]
50
+
51
+ x = out.view(B, K, T, N) + x # [B, K, T, N] [B, T, K, N]
52
+
53
+ x = x.permute(0, 2, 1, 3).contiguous() # [B, T, K, N] [B, K, T, N]
54
+ return x
55
+
56
+
57
class BandSequenceModelModule(nn.Module):
    """
    BandSequence (2nd) Module of BandSplitRNN.
    Runs input through n BiLSTMs in two dimensions - time and subbands -
    splitting the channel axis into `n_heads` independent heads first.
    """

    def __init__(
            self,
            # group_num,
            input_dim_size: int,
            hidden_dim_size: int,
            rnn_type: str = 'lstm',
            bidirectional: bool = True,
            num_layers: int = 12,
            n_heads: int = 4,
    ):
        super(BandSequenceModelModule, self).__init__()

        self.n_heads = n_heads
        self.bsrnn = nn.ModuleList([])

        # each head works on an equal slice of the channel/hidden dimensions
        per_head_dim = input_dim_size // n_heads
        per_head_hidden = hidden_dim_size // n_heads
        group_num = per_head_dim // 16

        for _ in range(num_layers):
            across_time = RNNModule(group_num, per_head_dim, per_head_hidden, rnn_type, bidirectional)
            across_bands = RNNModule(group_num, per_head_dim, per_head_hidden, rnn_type, bidirectional)
            # each RNNModule swaps the middle axes, so chaining two of them
            # alternates time-direction and subband-direction processing
            self.bsrnn.append(nn.Sequential(across_time, across_bands))

    def forward(self, x: torch.Tensor):
        """
        Input shape: [batch_size, k_subbands, time, n_features]
        Output shape: [batch_size, k_subbands, time, n_features]
        """
        # x (b,c,t,f): split channels into heads, then put features last
        b, c, t, f = x.shape
        heads = x.view(b * self.n_heads, c // self.n_heads, t, f)
        heads = heads.permute(0, 3, 2, 1).contiguous()  # [b*n_heads, f, t, c//n_heads]

        for stage in self.bsrnn:
            heads = stage(heads)

        # undo the permutation and re-merge the heads
        heads = heads.permute(0, 3, 2, 1).contiguous()  # [b*n_heads, c//n_heads, t, f]
        return heads.view(b, c, t, f)
112
+
113
+
114
if __name__ == '__main__':
    # Smoke test for BandSequenceModelModule.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # BUGFIX: forward() treats dim 1 as the channel axis and splits it into
    # n_heads groups, so the channel size must be input_dim_size (and divisible
    # by n_heads). The old demo fed (batch, 41, 512, 128), which crashes the
    # head-split view; subbands belong on the last axis.
    batch_size, n_channels, t_timesteps, k_subbands = 4, 128, 512, 41
    in_features = torch.rand(batch_size, n_channels, t_timesteps, k_subbands).to(device)

    # BUGFIX: 'group_num' is not a BandSequenceModelModule parameter (it is
    # derived internally as input_dim_size // n_heads // 16); passing it via
    # **cfg raised TypeError.
    cfg = {
        "input_dim_size": 128,
        "hidden_dim_size": 256,
        "rnn_type": "LSTM",
        "bidirectional": True,
        "num_layers": 1,
        "n_heads": 4,
    }
    model = BandSequenceModelModule(**cfg).to(device)
    _ = model.eval()

    with torch.no_grad():
        out_features = model(in_features)

    print(f"In: {in_features.shape}\nOut: {out_features.shape}")
    print(f"Total number of parameters: {sum([p.numel() for p in model.parameters()])}")
src/dp_tdf/dp_tdf_net.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ from src.dp_tdf.modules import TFC_TDF, TFC_TDF_Res1, TFC_TDF_Res2
5
+ from src.dp_tdf.bandsequence import BandSequenceModelModule
6
+
7
+ from src.layers import (get_norm)
8
+ from src.dp_tdf.abstract import AbstractModel
9
+
10
class DPTDFNet(AbstractModel):
    """U-Net-style TFC-TDF network with a BandSequence (RNN) bottleneck.

    Encoder: `n = num_blocks // 2` stages of [block -> strided conv] that
    halve time and frequency and add `g` channels each. The decoder mirrors
    this with transposed convolutions and multiplicative skip connections.
    """

    def __init__(self, num_blocks, l, g, k, bn, bias, bn_norm, bandsequence, block_type, **kwargs):
        """
        Args:
            num_blocks: total encoder+decoder block count (n per side).
            l: conv layers per TFC block.
            g: channel growth per scale.
            k: conv kernel size.
            bn: TDF bottleneck factor (None disables TDF; 0 = full-width layer).
            bias: whether the TDF linear layers use bias.
            bn_norm: normalization variant name, resolved by get_norm.
            bandsequence: extra kwargs for the BandSequenceModelModule bottleneck.
            block_type: 'TFC_TDF', 'TFC_TDF_Res1' or 'TFC_TDF_Res2'.
            **kwargs: forwarded to AbstractModel (dim_f, dim_t, n_fft, ...).
        """
        super(DPTDFNet, self).__init__(**kwargs)
        # self.save_hyperparameters()

        self.num_blocks = num_blocks
        self.l = l
        self.g = g
        self.k = k
        self.bn = bn
        self.bias = bias

        self.n = num_blocks // 2
        scale = (2, 2)  # down/upsampling factor per stage (time, freq)

        if block_type == "TFC_TDF":
            T_BLOCK = TFC_TDF
        elif block_type == "TFC_TDF_Res1":
            T_BLOCK = TFC_TDF_Res1
        elif block_type == "TFC_TDF_Res2":
            T_BLOCK = TFC_TDF_Res2
        else:
            raise ValueError(f"Unknown block type {block_type}")

        # 1x1 conv lifting the stacked real/imag channels to g feature maps
        self.first_conv = nn.Sequential(
            nn.Conv2d(in_channels=self.dim_c_in, out_channels=g, kernel_size=(1, 1)),
            get_norm(bn_norm, g),
            nn.ReLU(),
        )

        f = self.dim_f
        c = g
        self.encoding_blocks = nn.ModuleList()
        self.ds = nn.ModuleList()

        for i in range(self.n):
            c_in = c

            self.encoding_blocks.append(T_BLOCK(c_in, c, l, f, k, bn, bn_norm, bias=bias))
            # strided conv: halve (t, f), grow channels by g
            self.ds.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
                    get_norm(bn_norm, c + g),
                    nn.ReLU()
                )
            )
            f = f // 2
            c += g

        self.bottleneck_block1 = T_BLOCK(c, c, l, f, k, bn, bn_norm, bias=bias)
        self.bottleneck_block2 = BandSequenceModelModule(
            **bandsequence,
            input_dim_size=c,
            hidden_dim_size=2*c
        )

        self.decoding_blocks = nn.ModuleList()
        self.us = nn.ModuleList()
        for i in range(self.n):
            # print(f"i: {i}, in channels: {c}")
            # transposed conv: double (t, f), shrink channels by g
            self.us.append(
                nn.Sequential(
                    nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale),
                    get_norm(bn_norm, c - g),
                    nn.ReLU()
                )
            )

            f = f * 2
            c -= g

            self.decoding_blocks.append(T_BLOCK(c, c, l, f, k, bn, bn_norm, bias=bias))

        # 1x1 conv projecting back to the stacked real/imag output channels
        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=self.dim_c_out, kernel_size=(1, 1)),
        )

    def forward(self, x):
        '''
        Args:
            x: (batch, c*2, 2048, 256) spectrogram (freq, time)
        Returns:
            (batch, c*2, 2048, 256) estimated target spectrogram
        '''
        x = self.first_conv(x)

        # blocks operate on (batch, channels, time, freq)
        x = x.transpose(-1, -2)

        ds_outputs = []
        for i in range(self.n):
            x = self.encoding_blocks[i](x)
            ds_outputs.append(x)
            x = self.ds[i](x)

        # print(f"bottleneck in: {x.shape}")
        x = self.bottleneck_block1(x)
        x = self.bottleneck_block2(x)

        for i in range(self.n):
            x = self.us[i](x)
            # print(f"us{i} in: {x.shape}")
            # print(f"ds{i} out: {ds_outputs[-i - 1].shape}")
            # multiplicative (gating-style) skip connection with the
            # matching encoder output
            x = x * ds_outputs[-i - 1]
            x = self.decoding_blocks[i](x)

        # back to (batch, channels, freq, time)
        x = x.transpose(-1, -2)

        x = self.final_conv(x)

        return x
src/dp_tdf/modules.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from src.layers import (get_norm)
5
+
6
class TFC(nn.Module):
    """Time-Frequency Convolution block: `l` stacked Conv2d -> norm -> ReLU
    layers with 'same' padding (odd kernel size assumed)."""

    def __init__(self, c_in, c_out, l, k, bn_norm):
        super(TFC, self).__init__()

        # first layer maps c_in -> c_out; every subsequent layer is c_out -> c_out
        stages = []
        for i in range(l):
            in_ch = c_in if i == 0 else c_out
            stages.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=in_ch, out_channels=c_out, kernel_size=k, stride=1, padding=k // 2),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                )
            )
        self.H = nn.ModuleList(stages)

    def forward(self, x):
        """Apply the conv stack sequentially."""
        out = x
        for stage in self.H:
            out = stage(out)
        return out
28
+
29
+
30
class DenseTFC(nn.Module):
    """Densely connected TFC: each layer receives the concatenation of the
    block input and all previous layer outputs (DenseNet-style), and the
    final layer produces c_out channels."""

    def __init__(self, c_in, c_out, l, k, bn_norm):
        super(DenseTFC, self).__init__()

        self.conv = nn.ModuleList()
        for i in range(l):
            # BUGFIX: forward() concatenates each layer's output (c_out channels)
            # onto its input, so layer i sees c_in + i*c_out channels. The
            # original used in_channels=c_in for every layer, which raises a
            # channel-mismatch error for any l >= 2. Unchanged for l == 1.
            in_ch = c_in + i * c_out
            self.conv.append(
                nn.Sequential(
                    nn.Conv2d(in_channels=in_ch, out_channels=c_out, kernel_size=k, stride=1, padding=k // 2),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                )
            )

    def forward(self, x):
        """Dense forward: grow the feature stack, then project with the last layer."""
        for layer in self.conv[:-1]:
            x = torch.cat([layer(x), x], 1)
        return self.conv[-1](x)
48
+
49
+
50
class TFC_TDF(nn.Module):
    """TFC block followed by an optional TDF (time-distributed fully-connected)
    branch added residually along the frequency axis."""

    def __init__(self, c_in, c_out, l, f, k, bn, bn_norm, dense=False, bias=True):

        super(TFC_TDF, self).__init__()

        # bn is the TDF bottleneck factor: None disables the TDF branch,
        # 0 means a single full-width linear layer.
        self.use_tdf = bn is not None

        tfc_cls = DenseTFC if dense else TFC
        self.tfc = tfc_cls(c_in, c_out, l, k, bn_norm)

        if self.use_tdf:
            if bn == 0:
                stages = [
                    nn.Linear(f, f, bias=bias),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                ]
            else:
                # bottleneck: f -> f//bn -> f
                stages = [
                    nn.Linear(f, f // bn, bias=bias),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                    nn.Linear(f // bn, f, bias=bias),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                ]
            self.tdf = nn.Sequential(*stages)

    def forward(self, x):
        """tfc(x), plus a residual TDF branch when enabled."""
        out = self.tfc(x)
        if self.use_tdf:
            return out + self.tdf(out)
        return out
81
+
82
+
83
class TFC_TDF_Res1(nn.Module):
    """TFC block with a single-layer convolutional shortcut, followed by an
    optional residual TDF branch applied after the shortcut addition."""

    def __init__(self, c_in, c_out, l, f, k, bn, bn_norm, dense=False, bias=True):

        super(TFC_TDF_Res1, self).__init__()

        # bn is the TDF bottleneck factor: None disables the TDF branch,
        # 0 means a single full-width linear layer.
        self.use_tdf = bn is not None

        tfc_cls = DenseTFC if dense else TFC
        self.tfc = tfc_cls(c_in, c_out, l, k, bn_norm)

        # one-layer conv shortcut matching the block's output channels
        self.res = TFC(c_in, c_out, 1, k, bn_norm)

        if self.use_tdf:
            if bn == 0:
                stages = [
                    nn.Linear(f, f, bias=bias),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                ]
            else:
                # bottleneck: f -> f//bn -> f
                stages = [
                    nn.Linear(f, f // bn, bias=bias),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                    nn.Linear(f // bn, f, bias=bias),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                ]
            self.tdf = nn.Sequential(*stages)

    def forward(self, x):
        """shortcut(x) + tfc(x), then an optional residual TDF on top."""
        shortcut = self.res(x)
        out = self.tfc(x) + shortcut
        if self.use_tdf:
            out = out + self.tdf(out)
        return out
118
+
119
+
120
class TFC_TDF_Res2(nn.Module):
    """Two TFC stacks with the optional TDF branch applied between them, plus
    a single-layer convolutional shortcut around the whole block.

    Note: `dense` is accepted for interface parity with the other blocks but
    is not used here (both stacks are plain TFCs, as in the original).
    """

    def __init__(self, c_in, c_out, l, f, k, bn, bn_norm, dense=False, bias=True):

        super(TFC_TDF_Res2, self).__init__()

        # bn is the TDF bottleneck factor: None disables the TDF branch,
        # 0 means a single full-width linear layer.
        self.use_tdf = bn is not None

        self.tfc1 = TFC(c_in, c_out, l, k, bn_norm)
        self.tfc2 = TFC(c_in, c_out, l, k, bn_norm)

        # one-layer conv shortcut matching the block's output channels
        self.res = TFC(c_in, c_out, 1, k, bn_norm)

        if self.use_tdf:
            if bn == 0:
                stages = [
                    nn.Linear(f, f, bias=bias),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                ]
            else:
                # bottleneck: f -> f//bn -> f
                stages = [
                    nn.Linear(f, f // bn, bias=bias),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                    nn.Linear(f // bn, f, bias=bias),
                    get_norm(bn_norm, c_out),
                    nn.ReLU(),
                ]
            self.tdf = nn.Sequential(*stages)

    def forward(self, x):
        """tfc2(tfc1(x) [+ tdf]) + shortcut(x)."""
        shortcut = self.res(x)
        out = self.tfc1(x)
        if self.use_tdf:
            out = out + self.tdf(out)
        out = self.tfc2(out)
        return out + shortcut
src/evaluation/eval.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import listdir
2
+ from pathlib import Path
3
+ from typing import Optional, List
4
+
5
+ from concurrent import futures
6
+ import hydra
7
+ import wandb
8
+ import os
9
+ import shutil
10
+ from omegaconf import DictConfig
11
+ from pytorch_lightning import LightningDataModule, LightningModule
12
+ from pytorch_lightning.loggers import Logger, WandbLogger
13
+ import soundfile as sf
14
+
15
+ from tqdm import tqdm
16
+ import numpy as np
17
+ from src.callbacks.wandb_callbacks import get_wandb_logger
18
+ from src.evaluation.separate import separate_with_onnx_TDF, separate_with_ckpt_TDF
19
+ from src.utils import utils
20
+ from src.utils.utils import load_wav, sdr, get_median_csdr, save_results, get_metrics
21
+
22
+ from src.utils import pylogger
23
+
24
log = pylogger.get_pylogger(__name__)


def evaluation(config: DictConfig):
    """Evaluate a separation model on every track of one dataset split.

    Separates each track in ``<eval_dir>/<split>`` with the model weights at
    ``config.ckpt_path`` (ONNX or Lightning checkpoint, detected from the file
    extension), computes BSS metrics in a process pool, logs per-song and
    aggregate scores to the configured loggers, and saves raw results to the
    current working directory.

    Returns:
        (cSDR, uSDR): median chunk-level SDR over all tracks and mean
        song-level SDR.
    """

    assert config.split in ['train', 'valid', 'test']

    data_dir = Path(config.get('eval_dir')).joinpath(config['split'])
    assert data_dir.exists()

    # Init Lightning loggers from the hydra config.
    loggers: List[Logger] = []
    if "logger" in config:
        for _, lg_conf in config.logger.items():
            if "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                loggers.append(hydra.utils.instantiate(lg_conf))

    if any([isinstance(l, WandbLogger) for l in loggers]):
        utils.wandb_login(key=config.wandb_api_key)

    model = hydra.utils.instantiate(config.model)
    target_name = model.target_name
    ckpt_path = Path(config.ckpt_path)
    # Model format is decided purely by file extension.
    is_onnx = os.path.split(ckpt_path)[-1].split('.')[-1] == 'onnx'
    shutil.copy(ckpt_path, os.getcwd())  # copy model next to the run's logs

    ssdrs = []
    bss_lst = []
    bss_perms = []
    num_tracks = len(listdir(data_dir))
    target_list = [config.model.target_name, "complement"]

    # NOTE: `pool` is first the executor class, then rebound to the executor
    # instance inside the `with` — confusing but intentional.
    pool = futures.ProcessPoolExecutor
    with pool(config.pool_workers) as pool:
        datas = sorted(listdir(data_dir))
        if len(datas) > 27:  # if not debugging
            # move idx 27 to head — presumably a known heavy/problematic track
            # processed first; TODO confirm why this specific index.
            datas = [datas[27]] + datas[:27] + datas[28:]
        # iterate tracks in batches of `pool_workers`; separation runs in this
        # process (GPU), metric computation is farmed out to the pool.
        for k in range(0, len(datas), config.pool_workers):
            batch = datas[k:k + config.pool_workers]
            pendings = []
            for i, track in tqdm(enumerate(batch)):
                folder_name = track
                track = data_dir.joinpath(track)
                mixture = load_wav(track.joinpath('mixture.wav'))  # (c, t)
                target = load_wav(track.joinpath(target_name + '.wav'))

                if model.audio_ch == 1:
                    # fold stereo down to mono for single-channel models
                    mixture = np.mean(mixture, axis=0, keepdims=True)
                    target = np.mean(target, axis=0, keepdims=True)
                if is_onnx:
                    target_hat = separate_with_onnx_TDF(config.batch_size, model, ckpt_path, mixture)
                else:
                    target_hat = separate_with_ckpt_TDF(config.batch_size, model, ckpt_path, mixture, config.device, config.double_chunk, config.overlap_add)

                # metric computation is expensive (BSS eval) -> async in pool
                pendings.append((folder_name, pool.submit(
                    get_metrics, target_hat, target, mixture, sr=44100, version=config.bss)))

                # log a 6-second audio snippet from the middle of the estimate
                for wandb_logger in [logger for logger in loggers if isinstance(logger, WandbLogger)]:
                    mid = mixture.shape[-1] // 2
                    track = target_hat[:, mid - 44100 * 3:mid + 44100 * 3]
                    wandb_logger.experiment.log(
                        {f'track={k+i}_target={target_name}': [wandb.Audio(track.T, sample_rate=44100)]})

            # collect this batch's metric futures (blocks until done)
            for i, (track_name, pending) in tqdm(enumerate(pendings)):
                pending = pending.result()
                bssmetrics, perms, ssdr = pending
                bss_lst.append(bssmetrics)
                bss_perms.append(perms)
                ssdrs.append(ssdr)

                # `k+i` is the global track index used as the logging step
                for logger in loggers:
                    logger.log_metrics({'song/ssdr': ssdr}, k+i)
                    logger.log_metrics({'song/csdr': get_median_csdr([bssmetrics])}, k+i)

    log_dir = os.getcwd()
    save_results(log_dir, bss_lst, target_list, bss_perms, ssdrs)

    cSDR = get_median_csdr(bss_lst)
    uSDR = sum(ssdrs) / num_tracks
    for logger in loggers:
        logger.log_metrics({'metrics/mean_sdr_' + target_name: sum(ssdrs) / num_tracks})
        logger.log_metrics({'metrics/median_csdr_' + target_name: get_median_csdr(bss_lst)})
        # non-wandb loggers (e.g. tensorboard) need an explicit close
        if not isinstance(logger, WandbLogger):
            logger.experiment.close()

    if any([isinstance(logger, WandbLogger) for logger in loggers]):
        wandb.finish()

    return cSDR, uSDR
src/evaluation/eval_demo.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import listdir
2
+ from pathlib import Path
3
+ from typing import Optional, List
4
+
5
+ from concurrent import futures
6
+ import hydra
7
+ import wandb
8
+ import os
9
+ import shutil
10
+ from omegaconf import DictConfig
11
+ from pytorch_lightning import LightningDataModule, LightningModule
12
+ from pytorch_lightning.loggers import Logger, WandbLogger
13
+
14
+ from tqdm import tqdm
15
+ import numpy as np
16
+ from src.callbacks.wandb_callbacks import get_wandb_logger
17
+ from src.evaluation.separate import separate_with_onnx_TDF, separate_with_ckpt_TDF
18
+ from src.utils import utils
19
+ from src.utils.utils import load_wav, sdr, get_median_csdr, save_results, get_metrics
20
+
21
+ from src.utils import pylogger
22
+ import soundfile as sf
23
log = pylogger.get_pylogger(__name__)


def evaluation(config: DictConfig, idx):
    """Separate a single evaluation track (selected by *idx*) and return its BSS metrics.

    The metrics object is also pickled to ``bssmetrics.pkl`` in the current
    working directory for offline inspection.
    """
    assert config.split in ['train', 'valid', 'test']

    data_dir = Path(config.get('eval_dir')).joinpath(config['split'])
    assert data_dir.exists()

    model = hydra.utils.instantiate(config.model)
    target_name = model.target_name
    ckpt_path = Path(config.ckpt_path)
    is_onnx = ckpt_path.name.split('.')[-1] == 'onnx'
    shutil.copy(ckpt_path, os.getcwd())  # keep a copy of the model next to the logs

    track_names = sorted(listdir(data_dir))
    if len(track_names) > 27:  # if not debugging
        # move idx 27 to head (same reordering as the full evaluation loop)
        track_names = [track_names[27]] + track_names[:27] + track_names[28:]

    track_dir = data_dir.joinpath(track_names[idx])
    print(track_dir)
    mixture = load_wav(track_dir.joinpath('mixture.wav'))  # (c, t)
    target = load_wav(track_dir.joinpath(target_name + '.wav'))
    if model.audio_ch == 1:
        # fold stereo down to mono for single-channel models
        mixture = np.mean(mixture, axis=0, keepdims=True)
        target = np.mean(target, axis=0, keepdims=True)

    if is_onnx:
        target_hat = separate_with_onnx_TDF(config.batch_size, model, ckpt_path, mixture)
    else:
        target_hat = separate_with_ckpt_TDF(config.batch_size, model, ckpt_path, mixture,
                                            config.device, config.double_chunk,
                                            overlap_factor=config.overlap_factor)

    bssmetrics, perms, ssdr = get_metrics(target_hat, target, mixture, sr=44100, version=config.bss)

    # dump bssmetrics into pkl
    import pickle
    with open(os.path.join(os.getcwd(), 'bssmetrics.pkl'), 'wb') as f:
        pickle.dump(bssmetrics, f)

    return bssmetrics
66
+
67
+
68
+
69
+
70
+
71
+
src/evaluation/separate.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import listdir
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import numpy as np
6
+ import onnxruntime as ort
7
+ import math
8
+ import os
9
+ from src.utils.utils import split_nparray_with_overlap, join_chunks
10
+
11
def separate_with_onnx(batch_size, model, onnx_path: Path, mix):
    """Chunk a stereo mixture, run each chunk through an ONNX model, and re-join.

    Args:
        batch_size: number of chunks per ONNX inference call.
        model: provides n_fft, sampling_size and the stft/istft transforms.
        onnx_path: path to the exported .onnx model.
        mix: (2, t) mixture waveform.

    Returns:
        (2, t) separated waveform, trimmed back to the input length.
    """
    n_sample = mix.shape[1]

    # each chunk carries `trim` margin samples on both sides; the margins are
    # discarded after inference so the kept centers tile the signal exactly.
    trim = model.n_fft // 2
    gen_size = model.sampling_size - 2 * trim
    pad = gen_size - n_sample % gen_size
    mix_p = np.concatenate((np.zeros((2, trim)), mix, np.zeros((2, pad)), np.zeros((2, trim))), 1)

    mix_waves = []
    i = 0
    while i < n_sample + pad:
        waves = np.array(mix_p[:, i:i + model.sampling_size], dtype=np.float32)
        mix_waves.append(waves)
        i += gen_size
    mix_waves_batched = torch.tensor(mix_waves, dtype=torch.float32).split(batch_size)

    tar_signals = []

    with torch.no_grad():
        _ort = ort.InferenceSession(str(onnx_path))
        for mix_waves in mix_waves_batched:
            # waveform -> spectrogram -> ONNX forward -> waveform
            tar_waves = model.istft(torch.tensor(
                _ort.run(None, {'input': model.stft(mix_waves).numpy()})[0]
            ))
            # drop margins, flatten (batch, ch, t) chunks into one (2, t) stream
            tar_signals.append(tar_waves[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).numpy())
    tar_signal = np.concatenate(tar_signals, axis=-1)[:, :-pad]

    return tar_signal
39
+
40
+
41
def separate_with_ckpt(batch_size, model, ckpt_path: Path, mix, device, double_chunk):
    """Separate a stereo mixture using a Lightning checkpoint (no overlap-add).

    NOTE(review): near-duplicate of separate_with_ckpt_TDF below (which adds
    overlap-add support); this variant trims by ``model.trim``.

    Args:
        batch_size: chunks per forward pass.
        model: LightningModule providing load_from_checkpoint, stft/istft,
            trim, sampling_size and inference_chunk_size.
        ckpt_path: checkpoint file to load weights from.
        mix: (2, t) mixture waveform.
        device: torch device for inference.
        double_chunk: use the larger inference_chunk_size instead of sampling_size.

    Returns:
        (2, t) separated waveform, trimmed to the input length.
    """
    model = model.load_from_checkpoint(ckpt_path).to(device)
    if double_chunk:
        inf_ck = model.inference_chunk_size
    else:
        inf_ck = model.sampling_size
    true_samples = inf_ck - 2 * model.trim

    # pad so that the trimmed chunk outputs tile the input exactly
    right_pad = true_samples + model.trim - ((mix.shape[-1]) % true_samples)
    mixture = np.concatenate((np.zeros((2, model.trim), dtype='float32'),
                              mix,
                              np.zeros((2, right_pad), dtype='float32')),
                             1)
    num_chunks = mixture.shape[-1] // true_samples
    mix_waves_batched = [mixture[:, i * true_samples: i * true_samples + inf_ck] for i in
                         range(num_chunks)]
    mix_waves_batched = torch.tensor(mix_waves_batched, dtype=torch.float32).split(batch_size)

    target_wav_hats = []

    with torch.no_grad():
        model.eval()
        for mixture_wav in mix_waves_batched:
            mix_spec = model.stft(mixture_wav.to(device))
            spec_hat = model(mix_spec)
            target_wav_hat = model.istft(spec_hat)
            target_wav_hat = target_wav_hat.cpu().detach().numpy()
            target_wav_hats.append(target_wav_hat)

    # drop per-chunk margins and stitch chunks back in time order
    target_wav_hat = np.vstack(target_wav_hats)[:, :, model.trim:-model.trim]
    target_wav_hat = np.concatenate(target_wav_hat, axis=-1)[:, :mix.shape[-1]]
    return target_wav_hat
73
+
74
+
75
+
76
+
77
def separate_with_onnx_TDF(batch_size, model, onnx_path: Path, mix):
    """Chunked ONNX inference using the model's larger inference_chunk_size.

    NOTE(review): near-duplicate of separate_with_onnx above; differences are
    the chunk size (inference_chunk_size vs sampling_size) and the CUDA
    execution provider. Consider merging them.

    Args:
        batch_size: number of chunks per ONNX inference call.
        model: provides n_fft, inference_chunk_size and stft/istft.
        onnx_path: path to the exported .onnx model.
        mix: (2, t) mixture waveform.

    Returns:
        (2, t) separated waveform, trimmed back to the input length.
    """
    n_sample = mix.shape[1]

    # margin samples discarded on each side of every chunk after inference
    overlap = model.n_fft // 2
    gen_size = model.inference_chunk_size - 2 * overlap
    pad = gen_size - n_sample % gen_size
    mix_p = np.concatenate((np.zeros((2, overlap)), mix, np.zeros((2, pad)), np.zeros((2, overlap))), 1)

    mix_waves = []
    i = 0
    while i < n_sample + pad:
        waves = np.array(mix_p[:, i:i + model.inference_chunk_size], dtype=np.float32)
        mix_waves.append(waves)
        i += gen_size
    mix_waves_batched = torch.tensor(mix_waves, dtype=torch.float32).split(batch_size)

    tar_signals = []

    with torch.no_grad():
        _ort = ort.InferenceSession(str(onnx_path), providers=['CUDAExecutionProvider'])
        for mix_waves in mix_waves_batched:
            # waveform -> spectrogram -> ONNX forward -> waveform
            tar_waves = model.istft(torch.tensor(
                _ort.run(None, {'input': model.stft(mix_waves).numpy()})[0]
            ))
            tar_signals.append(tar_waves[:, :, overlap:-overlap].transpose(0, 1).reshape(2, -1).numpy())
    tar_signal = np.concatenate(tar_signals, axis=-1)[:, :-pad]

    return tar_signal
105
+
106
+
107
+
108
def separate_with_ckpt_TDF(batch_size, model, ckpt_path: Path, mix, device, double_chunk, overlap_add):
    '''
    Load weights from *ckpt_path* into *model*, then separate *mix* either
    with plain chunking or overlap-add.

    Args:
        batch_size: the inference batch size
        model: the model to be used
        ckpt_path: the path to the checkpoint
        mix: (c, t)
        device: the device to be used
        double_chunk: whether to use double chunk size
        overlap_add: overlap-add config object, or None for plain chunking
    Returns:
        target_wav_hat: (c, t)
    '''
    state = torch.load(ckpt_path)
    model.load_state_dict(state["state_dict"])
    model = model.to(device)

    inf_ck = model.inference_chunk_size if double_chunk else model.chunk_size

    if overlap_add is None:
        return no_overlap_inference(model, mix, device, batch_size, inf_ck)

    # overlap-add needs scratch space for intermediate chunk files
    if not os.path.exists(overlap_add.tmp_root):
        os.makedirs(overlap_add.tmp_root)
    return overlap_inference(model, mix, device, batch_size, inf_ck,
                             overlap_add.overlap_rate, overlap_add.tmp_root,
                             overlap_add.samplerate)
137
+
138
def no_overlap_inference(model, mix, device, batch_size, inf_ck):
    '''
    Chunk *mix* without overlap-add, run the model chunk-wise, and re-join.

    Each chunk carries `model.overlap` margin samples on both sides that are
    discarded after inference, so the kept centers tile the signal exactly.

    Args:
        model: provides overlap, audio_ch, stft/istft and the forward pass
        mix: (c, t)
        device: torch device for inference
        batch_size: chunks per forward pass
        inf_ck: model input chunk length (includes both margins)
    Returns:
        (c, t) separated waveform, trimmed to the input length
    '''
    margin = model.overlap
    hop = inf_ck - 2 * margin

    # pad so that the trimmed chunk outputs tile the input exactly
    pad_right = hop + margin - (mix.shape[-1] % hop)
    padded = np.concatenate(
        (np.zeros((model.audio_ch, margin), dtype='float32'),
         mix,
         np.zeros((model.audio_ch, pad_right), dtype='float32')),
        1)

    n_chunks = padded.shape[-1] // hop
    chunks = [padded[:, j * hop: j * hop + inf_ck] for j in range(n_chunks)]
    batches = torch.tensor(chunks, dtype=torch.float32).split(batch_size)

    outputs = []
    with torch.no_grad():
        model.eval()
        for batch in batches:
            spec = model.stft(batch.to(device))
            est = model.istft(model(spec))
            outputs.append(est.cpu().detach().numpy())  # (b, c, t)

    joined = np.vstack(outputs)[:, :, margin:-margin]  # (sum(b), c, hop)
    return np.concatenate(joined, axis=-1)[:, :mix.shape[-1]]
165
+
166
+
167
def overlap_inference(model, mix, device, batch_size, inf_ck, overlap_rate, tmp_root, samplerate):
    '''
    Overlap-add inference: consecutive chunks overlap by `overlap_rate` of the
    chunk length and are merged back by `join_chunks` (which uses `tmp_root`
    as scratch space for intermediate wav files at `samplerate`).

    Args:
        mix: (c, t)
    Returns:
        (c, t) separated waveform, trimmed to the input length
    '''
    # hop between chunk starts; the remainder of each chunk overlaps the next
    hop_length = math.ceil((1 - overlap_rate) * inf_ck)
    overlap_size = inf_ck - hop_length
    step_t = mix.shape[1]
    # split_nparray_with_overlap works on (t, c) data, hence the transposes
    mix_waves_batched = split_nparray_with_overlap(mix.T, hop_length, overlap_size)

    mix_waves_batched = torch.tensor(mix_waves_batched, dtype=torch.float32).split(batch_size)  # [(b, c, t)]

    target_wav_hats = []

    with torch.no_grad():
        model.eval()
        for mixture_wav in mix_waves_batched:
            mix_spec = model.stft(mixture_wav.to(device))
            spec_hat = model(mix_spec)
            target_wav_hat = model.istft(spec_hat)
            target_wav_hat = target_wav_hat.cpu().detach().numpy()
            target_wav_hats.append(target_wav_hat)  # (b, c, t)

    target_wav_hat = np.vstack(target_wav_hats)  # (sum(b), c, t)
    target_wav_hat = np.transpose(target_wav_hat, (0, 2, 1))  # (sum(b), t, c)
    # cross-fade the overlapping regions back into one continuous signal
    target_wav_hat = join_chunks(tmp_root, target_wav_hat, samplerate, overlap_size)  # (t, c)
    return target_wav_hat[:step_t].T  # (c, t)
src/layers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+ from .batch_norm import *
src/layers/batch_norm.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+ __all__ = ["IBN", "get_norm"]
8
+
9
+
10
class BatchNorm(nn.BatchNorm2d):
    """BatchNorm2d with configurable initial values and optionally frozen affine params.

    Extra keyword arguments are accepted (and ignored) so `get_norm` can pass
    a uniform kwarg set to every norm type.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False,
                 weight_init=1.0, bias_init=0.0, **kwargs):
        super().__init__(num_features, eps=eps, momentum=momentum)
        for param, init, freeze in ((self.weight, weight_init, weight_freeze),
                                    (self.bias, bias_init, bias_freeze)):
            if init is not None:
                nn.init.constant_(param, init)
            param.requires_grad_(not freeze)
18
+
19
+
20
class SyncBatchNorm(nn.SyncBatchNorm):
    """SyncBatchNorm with configurable initial values and optionally frozen affine params."""

    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False,
                 weight_init=1.0, bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        for param, init, freeze in ((self.weight, weight_init, weight_freeze),
                                    (self.bias, bias_init, bias_freeze)):
            if init is not None:
                nn.init.constant_(param, init)
            param.requires_grad_(not freeze)
28
+
29
+
30
class IBN(nn.Module):
    """Instance-Batch Normalization: InstanceNorm over the first half of the
    channels, the configured batch norm over the remaining half."""

    def __init__(self, planes, bn_norm, **kwargs):
        super(IBN, self).__init__()
        in_channels = int(planes / 2)
        self.half = in_channels
        self.IN = nn.InstanceNorm2d(in_channels, affine=True)
        self.BN = get_norm(bn_norm, planes - in_channels, **kwargs)

    def forward(self, x):
        chunks = torch.split(x, self.half, 1)
        normed_in = self.IN(chunks[0].contiguous())
        normed_bn = self.BN(chunks[1].contiguous())
        return torch.cat((normed_in, normed_bn), 1)
45
+
46
+
47
class GhostBatchNorm(BatchNorm):
    """Ghost Batch Normalization: normalizes `num_splits` virtual sub-batches
    independently by folding the split factor into the channel dimension
    (cf. "Train longer, generalize better", Hoffer et al., 2017).
    """

    def __init__(self, num_features, num_splits=1, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits
        # re-register running stats (shadows the parent's buffers of the same
        # name) so they can be repeated/averaged below without affecting BN.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            # Tile the running stats per split, view (N, C, H, W) as
            # (N/num_splits, C*num_splits, H, W) so each virtual batch is
            # normalized independently, then average the updated stats back
            # to a single per-channel set.
            self.running_mean = self.running_mean.repeat(self.num_splits)
            self.running_var = self.running_var.repeat(self.num_splits)
            outputs = F.batch_norm(
                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
            return outputs
        else:
            # eval mode: plain batch norm with the averaged running stats
            return F.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)
70
+
71
+
72
class FrozenBatchNorm(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.
    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.
    Other pre-trained backbone models may contain all 4 parameters.
    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    # bumped when the buffer semantics change; see _load_from_state_dict
    _version = 3

    def __init__(self, num_features, eps=1e-5, **kwargs):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        # buffers (not parameters): never updated by the optimizer
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # initialized to 1 - eps so that the default transform is the identity
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            return x * scale + bias
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provide more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Migrate checkpoints saved with older _version values in place."""
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions
            # This will silent the warnings
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        if version is not None and version < 3:
            logger = logging.getLogger(__name__)
            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
            # In version < 3, running_var are used without +eps.
            state_dict[prefix + "running_var"] -= self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
        Args:
            module (torch.nn.Module):
        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.
        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            # recurse into children, replacing converted ones in place
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
176
+
177
+
178
def get_norm(norm, out_channels, **kwargs):
    """
    Args:
        norm (str or callable): either one of BN, GhostBN, FrozenBN, GN or SyncBN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module. An empty string yields None.
        out_channels: number of channels for normalization layer

    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if not norm:
            return None
        # unknown names raise KeyError, matching the historical behavior
        factories = {
            "BN": BatchNorm,
            "syncBN": SyncBatchNorm,
            "GhostBN": GhostBatchNorm,
            "FrozenBN": FrozenBatchNorm,
            "GN": lambda channels, **args: nn.GroupNorm(32, channels),
        }
        norm = factories[norm]
    return norm(out_channels, **kwargs)
src/layers/chunk_size.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+
7
def wave_to_batches(mix, inf_ck, overlap, batch_size):
    '''
    Args:
        mix: (2, N) numpy array
        inf_ck: int, the chunk size as the model input (contains 2*overlap)
            inf_ck = overlap + true_samples + overlap
        overlap: int, the discarded samples at each side
        batch_size: int, chunks per batch
    Returns:
        a tuple of batches, each batch is a (batch, 2, inf_ck) torch tensor
    '''
    hop = inf_ck - 2 * overlap
    n_ch = mix.shape[0]

    # left-pad by the overlap margin; right-pad so the kept centers tile exactly
    pad_right = hop + overlap - (mix.shape[-1] % hop)
    padded = np.concatenate(
        (np.zeros((n_ch, overlap), dtype='float32'),
         mix,
         np.zeros((n_ch, pad_right), dtype='float32')),
        1)

    n_chunks = padded.shape[-1] // hop
    chunks = np.array([padded[:, j * hop: j * hop + inf_ck]
                       for j in range(n_chunks)])  # (x, 2, inf_ck)
    return torch.tensor(chunks, dtype=torch.float32).split(batch_size)
30
+
31
def batches_to_wave(target_hat_chunks, overlap, org_len):
    '''
    Args:
        target_hat_chunks: a list of (batch, 2, inf_ck) torch tensors
        overlap: int, the discarded samples at each side
        org_len: int, the original length of the mixture
    Returns:
        (2, N) numpy array
    '''
    # drop the overlap margins, then stitch all chunks back in time order
    trimmed = torch.cat([chunk[..., overlap:-overlap] for chunk in target_hat_chunks])
    stitched = trimmed.transpose(0, 1).reshape(2, -1)
    return stitched[..., :org_len].detach().cpu().numpy()
45
+
46
if __name__ == '__main__':
    # Smoke test: chunk a full-length random mixture and re-join it; the
    # printed shape should match the input (2, 14318640).
    mix = np.random.rand(2, 14318640)
    inf_ck = 261120
    overlap = 3072
    batch_size = 8
    out = wave_to_batches(mix, inf_ck, overlap, batch_size)
    in_wav = batches_to_wave(out, overlap, mix.shape[-1])
    print(in_wav.shape)
src/train.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+
3
+ import hydra
4
+ import pytorch_lightning as pl
5
+ import pyrootutils
6
+ import torch
7
+ import os
8
+ import shutil
9
+ from omegaconf import DictConfig
10
+ from pytorch_lightning import (
11
+ Callback,
12
+ LightningDataModule,
13
+ LightningModule,
14
+ Trainer,
15
+ seed_everything,
16
+ )
17
+ from pytorch_lightning.loggers import WandbLogger
18
+ from hydra.core.hydra_config import HydraConfig
19
+
20
+ from src import utils
21
+
22
# Make the repository root importable regardless of the invocation directory.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

log = utils.get_pylogger(__name__)


@utils.task_wrapper
def train(cfg: DictConfig) -> Optional[float]:
    """Contains training pipeline.
    Instantiates all PyTorch Lightning objects from config.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Optional[float]: Metric score for hyperparameter optimization.
        NOTE(review): currently always returns (None, None) — confirm whether
        task_wrapper expects this tuple shape.
    """

    # Set seed for random number generators in pytorch, numpy and python.random
    try:
        if "seed" in cfg:
            if cfg.get("seed"):
                pl.seed_everything(cfg.seed, workers=True)

        else:
            # NOTE(review): ModuleNotFoundError is used purely as a local
            # control-flow signal for "seed missing", not an import failure.
            raise ModuleNotFoundError

    except ModuleNotFoundError:
        print('[Error] seed should be fixed for reproducibility \n=> e.g. python run.py +seed=$SEED')
        exit(-1)

    # Init Lightning datamodule
    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datamodule)

    # Init Lightning model
    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    # Init Lightning callbacks
    callbacks: List[Callback] = []
    if "callbacks" in cfg:
        for _, cb_conf in cfg["callbacks"].items():
            if "_target_" in cb_conf:
                log.info(f"Instantiating callback <{cb_conf._target_}>")
                callbacks.append(hydra.utils.instantiate(cb_conf))

    # When resuming, carry the previous run's logger state into the new run dir
    # so tensorboard/wandb histories stay continuous.
    if "resume_from_checkpoint" in cfg.trainer:
        ckpt_path = cfg.trainer.resume_from_checkpoint
        # get the parent directory of the checkpoint path
        log_dir = os.path.dirname(os.path.dirname(ckpt_path))
        tensorboard_dir = os.path.join(log_dir, "tensorboard")
        if os.path.exists(tensorboard_dir):
            # copy tensorboard dir to the parent directory of the checkpoint path
            # HydraConfig.get().run.dir returns new dir so do not use it! (now fixed)
            shutil.copytree(tensorboard_dir, os.path.join(os.getcwd(), "tensorboard"))

        wandb_dir = os.path.join(log_dir, "wandb")
        if os.path.exists(wandb_dir):
            shutil.copytree(wandb_dir, os.path.join(os.getcwd(), "wandb"))

    # Init Lightning loggers
    logger: List = []
    if "logger" in cfg:
        for _, lg_conf in cfg["logger"].items():
            if "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(hydra.utils.instantiate(lg_conf))

    for wandb_logger in [l for l in logger if isinstance(l, WandbLogger)]:
        utils.wandb_login(key=cfg.wandb_api_key)
        # utils.wandb_watch_all(wandb_logger, model) # TODO buggy
        break

    # Init Lightning trainer
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    # USE_GLOO env var forces the gloo DDP backend (e.g. where NCCL is unavailable)
    use_gloo = os.environ.get("USE_GLOO", False)
    if use_gloo:
        from pytorch_lightning.strategies import DDPStrategy
        ddp = DDPStrategy(process_group_backend='gloo')
        trainer: Trainer = hydra.utils.instantiate(
            cfg.trainer, strategy=ddp, callbacks=callbacks, logger=logger, _convert_="partial"
        )
    else:
        trainer: Trainer = hydra.utils.instantiate(
            cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
        )

    # Send some parameters from config to all lightning loggers
    log.info("Logging hyperparameters!")
    utils.log_hyperparameters(
        dict(
            cfg=cfg,
            model=model,
            datamodule=datamodule,
            trainer=trainer,
            callbacks=callbacks,
            logger=logger,
        )
    )

    # Train the model
    log.info("Starting training!")
    trainer.fit(model=model, datamodule=datamodule)

    # Make sure everything closed properly
    log.info("Finalizing!")

    # Return metric score for hyperparameter optimization
    return None, None
src/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from src.utils.pylogger import get_pylogger
2
+ from src.utils.rich_utils import enforce_tags, print_config_tree
3
+ from src.utils.utils import *
src/utils/data_augmentation.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess as sp
3
+ import tempfile
4
+ import warnings
5
+ from argparse import ArgumentParser
6
+ from concurrent import futures
7
+
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import torch
11
+ from tqdm import tqdm
12
+
13
+ warnings.simplefilter(action='ignore', category=Warning)
14
+ source_names = ['vocals', 'drums', 'bass', 'other']
15
+ sample_rate = 44100
16
+
17
def main(args):
    """Generate pitch/tempo-shifted augmentations for the MUSDB splits selected in *args*.

    Training-split jobs fan out across a process pool; valid/test splits are
    processed synchronously in the submitting process, as in the original script.
    """
    data_root = args.data_dir
    do_train = args.train
    do_test = args.test
    do_valid = args.valid

    train_path = data_root + 'train/'
    test_path = data_root + 'test/'
    valid_path = data_root + 'valid/'
    print(f"train={do_train}, test={do_test}, valid={do_valid}")

    mix_name = 'mixture'  # kept for parity with the original script (unused below)

    semitones = [-3, -2, -1, 0, 1, 2, 3]            # pitch shift amounts (in semitones)
    tempo_deltas = [-30, -20, -10, 0, 10, 20, 30]   # time stretch amounts (10 means 10% slower)

    jobs = []
    with futures.ProcessPoolExecutor(13) as executor:
        for pitch in semitones:
            for tempo in tempo_deltas:
                if pitch == 0 and tempo == 0:
                    continue  # identity transform — nothing to augment
                if do_train:
                    jobs.append(executor.submit(save_shifted_dataset, pitch, tempo, train_path))
                if do_valid:
                    save_shifted_dataset(pitch, tempo, valid_path)
                if do_test:
                    save_shifted_dataset(pitch, tempo, test_path)
        # Propagate any worker exception and block until all train jobs finish.
        for job in jobs:
            job.result()
49
+
50
+
51
def shift(wav, pitch, tempo, voice=False, quick=False, samplerate=44100, tmp_dir=None):
    """Pitch-shift and time-stretch *wav* with the external ``soundstretch`` tool.

    tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
    pitch is in semi tones.
    Requires `soundstretch` to be installed, see
    https://www.surina.net/soundtouch/soundstretch.html

    Args:
        wav: audio as a torch tensor; ``.t()`` below assumes a 2-D
            (channels, samples) layout — TODO confirm against callers.
        pitch: shift in semitones.
        tempo: relative tempo delta in percent.
        voice: pass ``-speech`` to soundstretch (tuned for vocals).
        quick: pass ``-quick`` (faster, lower quality).
        samplerate: expected sample rate of *wav* and of the result.
        tmp_dir: directory for the intermediate WAV files. ``None`` (default)
            uses the system temp directory; the previous hard-coded
            ``/root/autodl-tmp/tmp`` only existed on one specific machine.

    Returns:
        The shifted audio as a float32 numpy array, as read back by soundfile.

    Raises:
        RuntimeError: if soundstretch exits with a non-zero status.
    """

    def i16_pcm(t):
        # Float PCM in [-1, 1] -> int16 PCM. NOTE: the original compared the
        # torch dtype against np.int16 (never equal), so int16 input was
        # re-converted; compare against torch.int16 so it passes through.
        if t.dtype == torch.int16:
            return t
        return (t * 2 ** 15).clamp_(-2 ** 15, 2 ** 15 - 1).short()

    # (The old dead helper f32_pcm referenced np.float, removed in NumPy 1.20;
    # it was never called, so it is dropped here.)

    # Context-manage the temp files so they are removed even on failure.
    with tempfile.NamedTemporaryFile(dir=tmp_dir, suffix=".wav") as inputfile, \
            tempfile.NamedTemporaryFile(dir=tmp_dir, suffix=".wav") as outfile:
        sf.write(inputfile.name, data=i16_pcm(wav).t().numpy(),
                 samplerate=samplerate, format='WAV')
        command = [
            "soundstretch",
            inputfile.name,
            outfile.name,
            f"-pitch={pitch}",
            f"-tempo={tempo:.6f}",
        ]
        if quick:
            command += ["-quick"]
        if voice:
            command += ["-speech"]
        try:
            sp.run(command, capture_output=True, check=True)
        except sp.CalledProcessError as error:
            raise RuntimeError(
                f"Could not change bpm because {error.stderr.decode('utf-8')}"
            ) from error
        wav, sr = sf.read(outfile.name, dtype='float32')
    assert sr == samplerate
    return wav
93
+
94
+
95
def save_shifted_dataset(delta_pitch, delta_tempo, data_path):
    """Write a pitch/tempo-shifted copy of every track directory under *data_path*.

    The augmented split lands in a sibling directory named
    ``<split>_p=<delta_pitch>_t=<delta_tempo>/``, mirroring the per-track,
    per-stem WAV layout of the source split.
    """
    out_path = data_path[:-1] + f'_p={delta_pitch}_t={delta_tempo}/'
    try:
        os.mkdir(out_path)
    except FileExistsError:
        pass  # re-running over an existing split just overwrites its files

    track_names = [name for name in sorted(os.listdir(data_path))
                   if os.path.isdir(f'{data_path}/{name}')]
    for track_name in tqdm(track_names):
        try:
            os.mkdir(f'{out_path}/{track_name}')
        except FileExistsError:
            pass
        for s_name in source_names:
            source = load_wav(f'{data_path}/{track_name}/{s_name}.wav')
            # soundstretch's speech mode is only appropriate for the vocal stem
            shifted = shift(
                torch.tensor(source),
                delta_pitch,
                delta_tempo,
                voice=s_name == 'vocals')
            sf.write(f'{out_path}/{track_name}/{s_name}.wav', shifted,
                     samplerate=sample_rate, format='WAV')
115
+
116
+
117
def load_wav(path, sr=None):
    """Read *path* as float32 audio and return it channels-first (transposed).

    NOTE(review): soundfile only honours ``samplerate`` for headerless RAW
    files; for WAV it must stay ``None`` — confirm callers never pass one.
    """
    data, _ = sf.read(path, samplerate=sr, dtype='float32')
    return data.T
119
+
120
+
121
def str2bool(value):
    """Parse a CLI boolean string into a real bool.

    argparse's ``type=bool`` is broken: any non-empty string — including
    "False" — converts to ``True``, so ``--valid False`` silently enabled the
    split. Accepts true/false, yes/no, 1/0 (case-insensitive); raises
    ValueError otherwise, which argparse reports as a usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 'yes', '1'):
        return True
    if lowered in ('false', 'no', '0'):
        return False
    raise ValueError(f"invalid boolean value: {value!r}")


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--data_dir', type=str)
    # Keep the original `--flag True/False` CLI shape, but parse it correctly.
    parser.add_argument('--train', type=str2bool, default=True)
    parser.add_argument('--valid', type=str2bool, default=False)
    parser.add_argument('--test', type=str2bool, default=False)

    main(parser.parse_args())