diff --git a/dito/.gitignore b/dito/.gitignore
deleted file mode 100644
index ed8ebf583f771da9150c35db3955987b7d757904..0000000000000000000000000000000000000000
--- a/dito/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-__pycache__
\ No newline at end of file
diff --git a/dito/configs/datasets/dae.yaml b/dito/configs/datasets/dae.yaml
deleted file mode 100644
index 8fd9018ae2636223cb3c70f192dd58ae397c1c56..0000000000000000000000000000000000000000
--- a/dito/configs/datasets/dae.yaml
+++ /dev/null
@@ -1,79 +0,0 @@
-# Datasets
-datasets:
- train:
- name: wrapper_audio_cae
- args:
- dataset:
- name: audio_dataset_from_folders
- args:
- folders:
- Emilia_EN: ["/home/masuser/minimax-audio/dataset/Emilia/EN"]
- sample_rate: 24000
- duration: 0.38
- n_examples: 10000000
- shuffle: true
- mono: true
- sample_rate: 24000
- duration: 0.38
- mono: true
- normalize: true
- return_coords: true
- loader:
- batch_size: 64
- num_workers: 8
- drop_last: true
-
- val:
- name: wrapper_audio_cae
- args:
- dataset:
- name: audio_dataset_from_folders
- args:
- folders:
-            LibriTTS: ["/home/masuser/minimax-audio/dataset/libritts"]
- sample_rate: 24000
- duration: 5.0
- n_examples: 100
- shuffle: false
- mono: true
- sample_rate: 24000
- duration: 5.0
- mono: true
- normalize: true
- return_coords: true
- loader:
- batch_size: 4
- num_workers: 8
- drop_last: false
-
- eval_ae:
- name: wrapper_audio_cae
- args:
- dataset:
- name: audio_dataset_from_folders
- args:
- folders:
-            LibriTTS: ["/home/masuser/minimax-audio/dataset/libritts"]
- sample_rate: 24000
- duration: 10.0
- n_examples: 1000
- shuffle: false
- mono: true
- sample_rate: 24000
- duration: 10.0
- mono: true
- normalize: true
- return_coords: true
- loader:
- batch_size: 1
- num_workers: 8
- drop_last: false
-
-# Visualization
-visualize_ae_dir: /mnt/nvme/dito_audio
-visualize_ae_random_n_samples: 32
-eval_ae_max_samples: 100
-val_idx: [0, 1, 2, 3, 4, 5, 6, 7]
-
-# Enable autoencoder evaluation
-evaluate_ae: true
\ No newline at end of file
diff --git a/dito/configs/datasets/imagenet_ae.yaml b/dito/configs/datasets/imagenet_ae.yaml
deleted file mode 100644
index 5109901f168de8d1f9b30a301fb087e38fcd440e..0000000000000000000000000000000000000000
--- a/dito/configs/datasets/imagenet_ae.yaml
+++ /dev/null
@@ -1,47 +0,0 @@
-datasets:
- train:
- name: wrapper_cae
- args:
- dataset:
- name: class_folder
- args: {root_path: /home/masuser/minimax-audio/mnist_png/training, resize: 256, rand_crop: 256, rand_flip: true, image_only: true}
- resize_inp: 256
- gt_glores_lb: 256
- gt_glores_ub: 256
- gt_patch_size: 256
- loader:
- batch_size: 14
- num_workers: 24
-
- val:
- name: wrapper_cae
- args:
- dataset:
- name: class_folder
- args: {root_path: /home/masuser/minimax-audio/mnist_png/testing, resize: 256, square_crop: true, image_only: true}
- resize_inp: 256
- gt_glores_lb: 256
- gt_glores_ub: 256
- gt_patch_size: 256
- loader:
- batch_size: 14
- num_workers: 24
-
- eval_ae:
- name: wrapper_cae
- args:
- dataset:
- name: class_folder
- args: {root_path: /home/masuser/minimax-audio/mnist_png/testing, resize: 256, square_crop: true, image_only: true}
- resize_inp: 256
- gt_glores_lb: 256
- gt_glores_ub: 256
- gt_patch_size: 256
- loader:
- batch_size: 14
- num_workers: 24
- drop_last: false
-
-visualize_ae_dir: /mnt/nvme/dito
-visualize_ae_random_n_samples: 32
-eval_ae_max_samples: 5000
\ No newline at end of file
diff --git a/dito/configs/datasets/imagenet_zdm.yaml b/dito/configs/datasets/imagenet_zdm.yaml
deleted file mode 100644
index 2a54041e020d8ede4c9e3e8647c6c05c5339f0af..0000000000000000000000000000000000000000
--- a/dito/configs/datasets/imagenet_zdm.yaml
+++ /dev/null
@@ -1,53 +0,0 @@
-datasets:
- train:
- name: wrapper_cae
- args:
- dataset:
- name: class_folder
- args: {root_path: /home/masuser/minimax-audio/mnist_png/training, resize: 256, square_crop: true, rand_flip: true, drop_label_p: 0.1}
- resize_inp: 256
- gt_glores_lb: 256
- gt_glores_ub: 256
- gt_patch_size: 256
- loader:
- batch_size: 64
- num_workers: 24
-
- val:
- name: wrapper_cae
- args:
- dataset:
- name: class_folder
- args: {root_path: /home/masuser/minimax-audio/mnist_png/testing, resize: 256, square_crop: true}
- resize_inp: 256
- gt_glores_lb: 256
- gt_glores_ub: 256
- gt_patch_size: 256
- loader:
- batch_size: 64
- num_workers: 24
-
- eval_zdm:
- name: wrapper_cae
- args:
- dataset:
- name: class_folder
- args: {root_path: /home/masuser/minimax-audio/mnist_png/testing, resize: 256, square_crop: true}
- resize_inp: 256
- gt_glores_lb: 256
- gt_glores_ub: 256
- gt_patch_size: 256
- loader:
- batch_size: 64
- num_workers: 24
- drop_last: false
-
-visualize_zdm_file: null
-visualize_zdm_setting:
- name: class
- n_classes: 1000
-visualize_zdm_random_n_samples: 12
-visualize_zdm_batch_size: 6
-visualize_zdm_guidance_list: [4]
-visualize_zdm_denoising_file: null
-eval_zdm_max_samples: 5000
\ No newline at end of file
diff --git a/dito/configs/experiments/dito-B-audio.yaml b/dito/configs/experiments/dito-B-audio.yaml
deleted file mode 100644
index 179a84a947164ba7a687fdbddde627970989fb4b..0000000000000000000000000000000000000000
--- a/dito/configs/experiments/dito-B-audio.yaml
+++ /dev/null
@@ -1,44 +0,0 @@
-__base__:
- - configs/datasets/dae.yaml
- - configs/trainers/dito.yaml
-
-model:
- name: dito_audio
- args:
- # Encoder
- encoder:
- name: dac_encoder
- args: {config_name: snakebeta}
-
- # Latent configuration - now fully convolutional
- z_channels: 64 # Number of latent channels
- z_downsample_factor: 320 # Product of encoder_rates: 2*4*5*8
- z_layernorm: true
-
- # Decoder (identity for DiTo)
- decoder:
- name: identity
-
- # Renderer - Fully convolutional for dynamic duration
- renderer:
- name: audio_renderer_wrapper
- args:
- net:
- name: consistency_decoder_unet # Fully Convolutional Network
- args:
- in_channels: 1
- z_dec_channels: 64
- c0: 128
- c1: 256
- c2: 512
- pe_dim: 320
- t_dim: 1280
-
- # Diffusion configuration
- render_diffusion:
- name: fm
- args: {timescale: 1000.0}
-
- render_sampler: {name: fm_euler_sampler}
- render_n_steps: 50
-
diff --git a/dito/configs/experiments/dito-B-f8c4-noise-sync.yaml b/dito/configs/experiments/dito-B-f8c4-noise-sync.yaml
deleted file mode 100644
index 5eb1df5b78efb87d3aea6d9850924a8b85cc8da1..0000000000000000000000000000000000000000
--- a/dito/configs/experiments/dito-B-f8c4-noise-sync.yaml
+++ /dev/null
@@ -1,43 +0,0 @@
-__base__:
- - configs/datasets/imagenet_ae.yaml
- - configs/trainers/dito.yaml
-
-model:
- name: dito
- args:
- encoder:
- name: vqgan_encoder
- args: {config_name: f8c4}
-
- z_shape: [64, 1, 1]
- z_layernorm: true
-
- zaug_p: 0.1
- zaug_decoding_loss_type: suffix
- zaug_zdm_diffusion:
- name: fm
- args: {timescale: 1000.0}
-
- decoder: {name: identity}
-
- renderer:
- name: fixres_renderer_wrapper
- args:
- net:
- name: consistency_decoder_unet
- args:
- in_channels: 3
- z_dec_channels: 64
- c0: 128
- c1: 256
- c2: 512
- pe_dim: 320
- t_dim: 1280
-
- render_diffusion:
- name: fm
- args: {timescale: 1000.0}
- render_sampler: {name: fm_euler_sampler}
- render_n_steps: 50
-
- loss_config: {}
diff --git a/dito/configs/experiments/dito-B-f8c4.yaml b/dito/configs/experiments/dito-B-f8c4.yaml
deleted file mode 100644
index feeab68b5dace98a9701bbc16a27f56d81e0dfc7..0000000000000000000000000000000000000000
--- a/dito/configs/experiments/dito-B-f8c4.yaml
+++ /dev/null
@@ -1,37 +0,0 @@
-__base__:
- - configs/datasets/imagenet_ae.yaml
- - configs/trainers/dito.yaml
-
-model:
- name: dito
- args:
- encoder:
- name: vqgan_encoder
- args: {config_name: f8c4}
-
- z_shape: [4, 32, 32]
- z_layernorm: true
-
- decoder: {name: identity}
-
- renderer:
- name: fixres_renderer_wrapper
- args:
- net:
- name: consistency_decoder_unet
- args:
- in_channels: 3
- z_dec_channels: 4
- c0: 128
- c1: 256
- c2: 512
- pe_dim: 320
- t_dim: 1280
-
- render_diffusion:
- name: fm
- args: {timescale: 1000.0}
- render_sampler: {name: fm_euler_sampler}
- render_n_steps: 50
-
- loss_config: {}
diff --git a/dito/configs/experiments/dito-L-f8c4.yaml b/dito/configs/experiments/dito-L-f8c4.yaml
deleted file mode 100644
index 5f242bf6f27755c559e439eba78ae03809f07898..0000000000000000000000000000000000000000
--- a/dito/configs/experiments/dito-L-f8c4.yaml
+++ /dev/null
@@ -1,37 +0,0 @@
-__base__:
- - configs/datasets/imagenet_ae.yaml
- - configs/trainers/dito.yaml
-
-model:
- name: dito
- args:
- encoder:
- name: vqgan_encoder
- args: {config_name: f8c4}
-
- z_shape: [4, 32, 32]
- z_layernorm: true
-
- decoder: {name: identity}
-
- renderer:
- name: fixres_renderer_wrapper
- args:
- net:
- name: consistency_decoder_unet
- args:
- in_channels: 3
- z_dec_channels: 4
- c0: 192
- c1: 384
- c2: 768
- pe_dim: 320
- t_dim: 1280
-
- render_diffusion:
- name: fm
- args: {timescale: 1000.0}
- render_sampler: {name: fm_euler_sampler}
- render_n_steps: 50
-
- loss_config: {}
diff --git a/dito/configs/experiments/dito-XL-f8c4-noise-sync.yaml b/dito/configs/experiments/dito-XL-f8c4-noise-sync.yaml
deleted file mode 100644
index ec7761d6570e5456b8099dd6a1eb28754de8c953..0000000000000000000000000000000000000000
--- a/dito/configs/experiments/dito-XL-f8c4-noise-sync.yaml
+++ /dev/null
@@ -1,43 +0,0 @@
-__base__:
- - configs/datasets/imagenet_ae.yaml
- - configs/trainers/dito.yaml
-
-model:
- name: dito
- args:
- encoder:
- name: vqgan_encoder
- args: {config_name: f8c4}
-
- z_shape: [4, 32, 32]
- z_layernorm: true
-
- zaug_p: 0.1
- zaug_decoding_loss_type: suffix
- zaug_zdm_diffusion:
- name: fm
- args: {timescale: 1000.0}
-
- decoder: {name: identity}
-
- renderer:
- name: fixres_renderer_wrapper
- args:
- net:
- name: consistency_decoder_unet
- args:
- in_channels: 3
- z_dec_channels: 4
- c0: 320
- c1: 640
- c2: 1024
- pe_dim: 320
- t_dim: 1280
-
- render_diffusion:
- name: fm
- args: {timescale: 1000.0}
- render_sampler: {name: fm_euler_sampler}
- render_n_steps: 50
-
- loss_config: {}
diff --git a/dito/configs/experiments/dito-XL-f8c4.yaml b/dito/configs/experiments/dito-XL-f8c4.yaml
deleted file mode 100644
index 8610d23c1d1479538b1dd5d568d427f08ab7e9ec..0000000000000000000000000000000000000000
--- a/dito/configs/experiments/dito-XL-f8c4.yaml
+++ /dev/null
@@ -1,37 +0,0 @@
-__base__:
- - configs/datasets/imagenet_ae.yaml
- - configs/trainers/dito.yaml
-
-model:
- name: dito
- args:
- encoder:
- name: vqgan_encoder
- args: {config_name: f8c4}
-
- z_shape: [4, 32, 32]
- z_layernorm: true
-
- decoder: {name: identity}
-
- renderer:
- name: fixres_renderer_wrapper
- args:
- net:
- name: consistency_decoder_unet
- args:
- in_channels: 3
- z_dec_channels: 4
- c0: 320
- c1: 640
- c2: 1024
- pe_dim: 320
- t_dim: 1280
-
- render_diffusion:
- name: fm
- args: {timescale: 1000.0}
- render_sampler: {name: fm_euler_sampler}
- render_n_steps: 50
-
- loss_config: {}
diff --git a/dito/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4-noise-sync.yaml b/dito/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4-noise-sync.yaml
deleted file mode 100644
index 679cc9fa97413b19ca1d1581c0a008c4107b7bb9..0000000000000000000000000000000000000000
--- a/dito/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4-noise-sync.yaml
+++ /dev/null
@@ -1,44 +0,0 @@
-__base__:
- - configs/datasets/imagenet_zdm.yaml
- - configs/models/zdm-XL_imagenet.yaml
- - configs/trainers/zdm.yaml
-
-eval_zdm_max_samples: 50000
-
-model:
- load_ckpt: save/zdm-XL_dito-XL-f8c4-noise-sync/ckpt-last.pth
- name: dito
- args:
- zdm_force_guidance: 2.0
- renderer_ema_rate: 1
-
- encoder:
- name: vqgan_encoder
- args: {config_name: f8c4}
-
- z_shape: [4, 32, 32]
- z_layernorm: true
-
- decoder: {name: identity}
-
- renderer:
- name: fixres_renderer_wrapper
- args:
- net:
- name: consistency_decoder_unet
- args:
- in_channels: 3
- z_dec_channels: 4
- c0: 320
- c1: 640
- c2: 1024
- pe_dim: 320
- t_dim: 1280
-
- render_diffusion:
- name: fm
- args: {timescale: 1000.0}
- render_sampler: {name: fm_euler_sampler}
- render_n_steps: 50
-
- loss_config: {}
diff --git a/dito/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4.yaml b/dito/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4.yaml
deleted file mode 100644
index f5d360f8536c7652bf77cdeeaeff2dc33885136e..0000000000000000000000000000000000000000
--- a/dito/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4.yaml
+++ /dev/null
@@ -1,44 +0,0 @@
-__base__:
- - configs/datasets/imagenet_zdm.yaml
- - configs/models/zdm-XL_imagenet.yaml
- - configs/trainers/zdm.yaml
-
-eval_zdm_max_samples: 50000
-
-model:
- load_ckpt: save/zdm-XL_dito-XL-f8c4/ckpt-last.pth
- name: dito
- args:
- zdm_force_guidance: 2.0
- renderer_ema_rate: 1
-
- encoder:
- name: vqgan_encoder
- args: {config_name: f8c4}
-
- z_shape: [4, 32, 32]
- z_layernorm: true
-
- decoder: {name: identity}
-
- renderer:
- name: fixres_renderer_wrapper
- args:
- net:
- name: consistency_decoder_unet
- args:
- in_channels: 3
- z_dec_channels: 4
- c0: 320
- c1: 640
- c2: 1024
- pe_dim: 320
- t_dim: 1280
-
- render_diffusion:
- name: fm
- args: {timescale: 1000.0}
- render_sampler: {name: fm_euler_sampler}
- render_n_steps: 50
-
- loss_config: {}
diff --git a/dito/configs/experiments/zdm-XL_dito-XL-f8c4-noise-sync.yaml b/dito/configs/experiments/zdm-XL_dito-XL-f8c4-noise-sync.yaml
deleted file mode 100644
index 283a17ad9a677a47a152a5e2b5a44a3905183dae..0000000000000000000000000000000000000000
--- a/dito/configs/experiments/zdm-XL_dito-XL-f8c4-noise-sync.yaml
+++ /dev/null
@@ -1,41 +0,0 @@
-__base__:
- - configs/datasets/imagenet_zdm.yaml
- - configs/models/zdm-XL_imagenet.yaml
- - configs/trainers/zdm.yaml
-
-model:
- load_ckpt: save/dito-XL-f8c4-noise-sync/ckpt-last.pth
- name: dito
- args:
- renderer_ema_rate: 1
-
- encoder:
- name: vqgan_encoder
- args: {config_name: f8c4}
-
- z_shape: [4, 32, 32]
- z_layernorm: true
-
- decoder: {name: identity}
-
- renderer:
- name: fixres_renderer_wrapper
- args:
- net:
- name: consistency_decoder_unet
- args:
- in_channels: 3
- z_dec_channels: 4
- c0: 320
- c1: 640
- c2: 1024
- pe_dim: 320
- t_dim: 1280
-
- render_diffusion:
- name: fm
- args: {timescale: 1000.0}
- render_sampler: {name: fm_euler_sampler}
- render_n_steps: 50
-
- loss_config: {}
diff --git a/dito/configs/experiments/zdm-XL_dito-XL-f8c4.yaml b/dito/configs/experiments/zdm-XL_dito-XL-f8c4.yaml
deleted file mode 100644
index eeee58d880334bb1806ec51b8772f72bd1f34a75..0000000000000000000000000000000000000000
--- a/dito/configs/experiments/zdm-XL_dito-XL-f8c4.yaml
+++ /dev/null
@@ -1,41 +0,0 @@
-__base__:
- - configs/datasets/imagenet_zdm.yaml
- - configs/models/zdm-XL_imagenet.yaml
- - configs/trainers/zdm.yaml
-
-model:
- load_ckpt:
- name: dito
- args:
- renderer_ema_rate: 1
-
- encoder:
- name: vqgan_encoder
- args: {config_name: f8c4}
-
- z_shape: [4, 32, 32]
- z_layernorm: true
-
- decoder: {name: identity}
-
- renderer:
- name: fixres_renderer_wrapper
- args:
- net:
- name: consistency_decoder_unet
- args:
- in_channels: 3
- z_dec_channels: 4
- c0: 320
- c1: 640
- c2: 1024
- pe_dim: 320
- t_dim: 1280
-
- render_diffusion:
- name: fm
- args: {timescale: 1000.0}
- render_sampler: {name: fm_euler_sampler}
- render_n_steps: 50
-
- loss_config: {}
diff --git a/dito/configs/models/zdm-XL_imagenet.yaml b/dito/configs/models/zdm-XL_imagenet.yaml
deleted file mode 100644
index 27cb8116834b7f90c09c612f39507f9acf541d4e..0000000000000000000000000000000000000000
--- a/dito/configs/models/zdm-XL_imagenet.yaml
+++ /dev/null
@@ -1,12 +0,0 @@
-model:
- args:
- zdm_net:
- name: dit_xl_2
- args: {n_classes: 1001}
- zdm_diffusion:
- name: fm
- args: {timescale: 1000.0}
- zdm_sampler: {name: fm_euler_sampler}
- zdm_n_steps: 200
- zdm_train_normalize: false
- zdm_class_cond: 1000
\ No newline at end of file
diff --git a/dito/datasets/__init__.py b/dito/datasets/__init__.py
deleted file mode 100644
index 52749f4dfe1a9a36afab67115bdce2f243851bd6..0000000000000000000000000000000000000000
--- a/dito/datasets/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .datasets import register, make
-from . import image_folder, class_folder, webdataset
-from . import wrapper_cae
diff --git a/dito/datasets/class_folder.py b/dito/datasets/class_folder.py
deleted file mode 100644
index 91070b4545e1752c64c011d2f2c07b7e3bbb8b3f..0000000000000000000000000000000000000000
--- a/dito/datasets/class_folder.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import os
-import random
-from PIL import Image, ImageFile
-
-from datasets import register
-from torch.utils.data import Dataset
-from torchvision import transforms
-
-
-Image.MAX_IMAGE_PIXELS = 933120000
-ImageFile.LOAD_TRUNCATED_IMAGES = True
-IMAGE_EXTS = ('.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.webp')
-
-
-@register('class_folder')
-class ClassFolder(Dataset):
-
- def __init__(self, root_path, resize=None, square_crop=False, rand_crop=None, rand_flip=False, drop_label_p=0.0, image_only=False):
- folders = []
-        for folder in sorted(os.listdir(root_path)):
-            if os.path.isdir(os.path.join(root_path, folder)):
-                folders.append(folder)
-        self.files = []
-        self.labels = []
-        for i, folder in enumerate(folders):
-            for file in sorted(os.listdir(os.path.join(root_path, folder))):
-                if file.endswith(IMAGE_EXTS):
-                    self.files.append(os.path.join(root_path, folder, file))
-                    self.labels.append(i)
-
- self.resize = resize
- self.square_crop = square_crop
- self.rand_crop = rand_crop
- self.rand_flip = transforms.RandomHorizontalFlip() if rand_flip else None
-
- self.n_classes = len(folders)
- self.drop_label_p = drop_label_p
-
- self.image_only = image_only
-
- def __len__(self):
- return len(self.files)
-
- def __getitem__(self, idx):
- try:
- image = Image.open(self.files[idx]).convert('RGB')
- label = self.labels[idx]
-        except Exception:
- print('Error loading image:', self.files[idx])
- return self.__getitem__((idx + 1) % self.__len__())
-
- if self.resize is not None:
- r = self.resize
- if isinstance(r, int):
- w, h = image.size
- if w < h:
- r = (r, int(h / w * r))
- else:
- r = (int(w / h * r), r)
- image = image.resize(r, Image.LANCZOS)
-
- if self.square_crop:
- w, h = image.size
- l = min(w, h)
- left, upper = (w - l) // 2, (h - l) // 2
- image = image.crop((left, upper, left + l, upper + l))
-
- if self.rand_crop is not None:
- w, h = image.size
- left = random.randint(0, w - self.rand_crop)
- upper = random.randint(0, h - self.rand_crop)
- image = image.crop((left, upper, left + self.rand_crop, upper + self.rand_crop))
-
- if self.rand_flip is not None:
- image = self.rand_flip(image)
-
- if self.drop_label_p > 0.0 and random.random() < self.drop_label_p:
- label = self.n_classes
-
- if self.image_only:
- return image
- else:
- return {
- 'image': image,
- 'class_labels': label,
- }
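-
-
-# Usage sketch (illustrative only; the path is hypothetical, the dataset
-# expects one sub-folder per class under root_path):
-#
-#     ds = ClassFolder('/path/to/train', resize=256, square_crop=True)
-#     sample = ds[0]  # {'image': <PIL.Image>, 'class_labels': <int>}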
diff --git a/dito/datasets/datasets.py b/dito/datasets/datasets.py
deleted file mode 100644
index 3f102233b144c4c04c35204fd97b052236502661..0000000000000000000000000000000000000000
--- a/dito/datasets/datasets.py
+++ /dev/null
@@ -1,17 +0,0 @@
-datasets = dict()
-
-
-def register(name):
- def decorator(cls):
- datasets[name] = cls
- return cls
- return decorator
-
-
-def make(spec):
- args = spec.get('args')
- if args is None:
- args = dict()
- dataset = datasets[spec['name']](**args)
- return dataset
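-
-
-# Usage sketch of the registry pattern above (the 'toy' name is
-# hypothetical, for illustration only):
-#
-#     @register('toy')
-#     class Toy:
-#         def __init__(self, n=3):
-#             self.n = n
-#
-#     ds = make({'name': 'toy', 'args': {'n': 5}})  # instantiates Toy(n=5)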
diff --git a/dito/datasets/image_folder.py b/dito/datasets/image_folder.py
deleted file mode 100644
index 7564ef3fa31ea6f6aadd26a2d37b77c213e00095..0000000000000000000000000000000000000000
--- a/dito/datasets/image_folder.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import os
-import random
-from PIL import Image, ImageFile
-
-from datasets import register
-from torch.utils.data import Dataset
-from torchvision import transforms
-
-
-Image.MAX_IMAGE_PIXELS = 933120000
-ImageFile.LOAD_TRUNCATED_IMAGES = True
-IMAGE_EXTS = ('.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.webp')
-
-
-@register('image_folder')
-class ImageFolder(Dataset):
-
- def __init__(self, root_path, resize=None, square_crop=False, rand_crop=None, rand_flip=False):
- files = sorted(os.listdir(root_path))
- self.files = [os.path.join(root_path, _) for _ in files if _.endswith(IMAGE_EXTS)]
-
- self.resize = resize
- self.square_crop = square_crop
- self.rand_crop = rand_crop
- self.rand_flip = transforms.RandomHorizontalFlip() if rand_flip else None
-
- def __len__(self):
- return len(self.files)
-
- def __getitem__(self, idx):
- try:
- image = Image.open(self.files[idx]).convert('RGB')
-        except Exception:
- print('Error loading image:', self.files[idx])
- return self.__getitem__((idx + 1) % self.__len__())
-
- if self.resize is not None:
- r = self.resize
- if isinstance(r, int):
- w, h = image.size
- if w < h:
- r = (r, int(h / w * r))
- else:
- r = (int(w / h * r), r)
- image = image.resize(r, Image.LANCZOS)
-
- if self.square_crop:
- w, h = image.size
- l = min(w, h)
- left, upper = (w - l) // 2, (h - l) // 2
- image = image.crop((left, upper, left + l, upper + l))
-
- if self.rand_crop is not None:
- w, h = image.size
- left = random.randint(0, w - self.rand_crop)
- upper = random.randint(0, h - self.rand_crop)
- image = image.crop((left, upper, left + self.rand_crop, upper + self.rand_crop))
-
- if self.rand_flip is not None:
- image = self.rand_flip(image)
-
- return image
diff --git a/dito/datasets/webdataset.py b/dito/datasets/webdataset.py
deleted file mode 100644
index 7772713a688566eb8ac6fc7c38efcaf1c2993d66..0000000000000000000000000000000000000000
--- a/dito/datasets/webdataset.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import json
-
-import webdataset as wds
-from webdataset.handlers import warn_and_continue
-
-from datasets import register
-
-
-def webdataset_preprocessors(square_crop=True):
- def identity(x):
- if isinstance(x, bytes):
- x = x.decode('utf-8')
- return x
-
- def transform(image):
- w, h = image.size
- l = min(w, h)
- left, upper = (w - l) // 2, (h - l) // 2
- return image.crop((left, upper, left + l, upper + l))
-
- ret = [
- ('jpg;png', transform if square_crop else lambda x: x, 'image'),
- ('txt', identity, 'caption'),
- ]
-
- return ret
-
-
-@register('webdataset')
-def make_webdataset(json_file, **kwargs):
- with open(json_file, 'r') as file:
- tar_list = json.load(file)
- preprocessors = webdataset_preprocessors(**kwargs)
- handler = warn_and_continue
- dataset = wds.WebDataset(
- tar_list, resampled=True, handler=handler
- ).shuffle(690, handler=handler).decode(
- "pilrgb", handler=handler
- ).to_tuple(
- *[p[0] for p in preprocessors], handler=handler
- ).map_tuple(
- *[p[1] for p in preprocessors], handler=handler
- ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)})
-
- return dataset
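-
-
-# Usage sketch (assumes a JSON file holding a list of tar shard paths/URLs;
-# 'shards.json' is a hypothetical name):
-#
-#     ds = make_webdataset('shards.json', square_crop=True)
-#     sample = next(iter(ds))  # {'image': <PIL.Image>, 'caption': <str>}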
diff --git a/dito/datasets/wrapper_cae.py b/dito/datasets/wrapper_cae.py
deleted file mode 100644
index 1a684e1b959287c6647ad49399d7f5458f0818d0..0000000000000000000000000000000000000000
--- a/dito/datasets/wrapper_cae.py
+++ /dev/null
@@ -1,308 +0,0 @@
-import random
-from PIL import Image
-
-import torch
-from torch.utils.data import Dataset, IterableDataset
-from torchvision import transforms
-
-import datasets
-from datasets import register
-from utils.geometry import make_coord_scale_grid
-
-
-from models.ldm.dac.audiotools import AudioSignal
-import numpy as np
-
-from models.ldm.dac.audiotools.data.datasets import AudioDataset, AudioLoader
-from models.ldm.dac.audiotools import transforms as tfm
-
-
-class BaseWrapperCAE:
-
- def __init__(
- self,
- dataset,
- resize_inp,
- return_gt=True,
- gt_glores_lb=None,
- gt_glores_ub=None,
- gt_patch_size=None,
- p_whole=0.0,
- p_max=0.0
- ):
- self.dataset = datasets.make(dataset)
- self.resize_inp = resize_inp
- self.return_gt = return_gt
- self.gt_glores_lb = gt_glores_lb
- self.gt_glores_ub = gt_glores_ub
- self.gt_patch_size = gt_patch_size
- self.p_whole = p_whole
- self.p_max = p_max
- self.transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize(0.5, 0.5),
- ])
-
- def process(self, image):
- assert image.size[0] == image.size[1]
- ret = {}
-
- inp = image.resize((self.resize_inp, self.resize_inp), Image.LANCZOS)
- inp = self.transform(inp)
- ret.update({'inp': inp})
- if not self.return_gt:
- return ret
-
- if self.gt_glores_lb is None:
- glo = self.transform(image)
- else:
- if random.random() < self.p_whole:
- r = self.gt_patch_size
- elif random.random() < self.p_max:
- r = min(image.size[0], self.gt_glores_ub)
- else:
- r = random.randint(
- self.gt_glores_lb,
- max(self.gt_glores_lb, min(image.size[0], self.gt_glores_ub))
- )
- glo = image.resize((r, r), Image.LANCZOS)
- glo = self.transform(glo)
-
- p = self.gt_patch_size
- ii = random.randint(0, glo.shape[1] - p)
- jj = random.randint(0, glo.shape[2] - p)
- gt_patch = glo[:, ii: ii + p, jj: jj + p]
-
- x0, y0 = ii / glo.shape[-2], jj / glo.shape[-1]
- x1, y1 = (ii + p) / glo.shape[-2], (jj + p) / glo.shape[-1]
- coord, scale = make_coord_scale_grid((p, p), range=[[x0, x1], [y0, y1]])
- ret['gt'] = torch.cat([
- gt_patch, # 3 p p
- coord.permute(2, 0, 1), # 2 p p
- scale.permute(2, 0, 1), # 2 p p
- ], dim=0)
-
- return ret
-
-
-@register('wrapper_cae')
-class WrapperCAE(BaseWrapperCAE, Dataset):
-
- def __len__(self):
- return len(self.dataset)
-
- def __getitem__(self, idx):
- data = self.dataset[idx]
- if isinstance(data, dict):
- ret = dict()
- ret.update(self.process(data.pop('image')))
- ret.update(data)
- return ret
- else:
- return self.process(data)
-
-
-@register('wrapper_cae_iterable')
-class WrapperCAEIterable(BaseWrapperCAE, IterableDataset):
-
- def __iter__(self):
- for data in self.dataset:
- if isinstance(data, dict):
- ret = dict()
- ret.update(self.process(data.pop('image')))
- ret.update(data)
- yield ret
- else:
- yield self.process(data)
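-
-
-# Spec sketch mirroring configs/datasets/imagenet_ae.yaml (paths hypothetical):
-#
-#     ds = datasets.make({
-#         'name': 'wrapper_cae',
-#         'args': {
-#             'dataset': {'name': 'image_folder',
-#                         'args': {'root_path': '/path/to/images',
-#                                  'resize': 256, 'square_crop': True}},
-#             'resize_inp': 256,
-#             'gt_glores_lb': 256, 'gt_glores_ub': 256, 'gt_patch_size': 256,
-#         },
-#     })
-#     item = ds[0]  # {'inp': (3, 256, 256), 'gt': (7, 256, 256)}; gt = patch + coord + scale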
-
-
-class BaseWrapperAudioCAE:
- """Base wrapper for audio Convolutional Autoencoder (CAE) training.
-
- Similar to the image wrapper, but for audio data.
- """
-
- def __init__(
- self,
- dataset,
- sample_rate=24000,
- duration=0.38, # Duration in seconds
- n_samples=None, # Alternative: specify exact number of samples
- return_gt=True,
- gt_sample_rate=None, # Ground truth sample rate (if different)
- mono=True,
- normalize=True,
- return_coords=True, # Whether to return coordinate grids
- ):
-        # accept either a registry spec (dict) or an already-built dataset
-        self.dataset = datasets.make(dataset) if isinstance(dataset, dict) else dataset
- self.sample_rate = sample_rate
- self.duration = duration
- self.n_samples = n_samples or int(duration * sample_rate)
- self.return_gt = return_gt
- self.gt_sample_rate = gt_sample_rate or sample_rate
- self.mono = mono
- self.normalize = normalize
- self.return_coords = return_coords
-
- def process(self, audio_data):
- """Process audio data for DiTo training.
-
- Args:
- audio_data: Dictionary with 'signal' key containing AudioSignal
- or AudioSignal directly
- """
- ret = {}
-
- # Extract AudioSignal
- if isinstance(audio_data, dict):
- signal = audio_data['signal']
- else:
- signal = audio_data
-
- # Convert to mono if needed
- if self.mono and signal.num_channels > 1:
- signal = signal.to_mono()
-
- # Resample to target sample rate
- if signal.sample_rate != self.sample_rate:
- signal = signal.resample(self.sample_rate)
-
- # Extract fixed duration
- if signal.duration < self.duration:
- # Pad if too short
- signal = signal.zero_pad_to(self.n_samples)
- else:
- # Take random excerpt if too long
- max_start = signal.num_samples - self.n_samples
- if max_start > 0:
- start_idx = random.randint(0, max_start)
- signal = signal[..., start_idx:start_idx + self.n_samples]
- else:
- signal = signal[..., :self.n_samples]
-
- # Normalize audio
- audio_tensor = signal.audio_data # Shape: [channels, samples]
- if self.normalize:
- # Normalize to [-1, 1]
- max_val = audio_tensor.abs().max()
- if max_val > 0:
- audio_tensor = audio_tensor / max_val
-
- # Create input tensor
- ret['inp'] = audio_tensor
-
- if not self.return_gt:
- return ret
-
- ret['gt'] = audio_tensor
-
- return ret
-
-
-@register('wrapper_audio_cae')
-class WrapperAudioCAE(BaseWrapperAudioCAE, Dataset):
- """Dataset wrapper for audio CAE training."""
-
- def __len__(self):
- return len(self.dataset)
-
- def __getitem__(self, idx):
- data = self.dataset[idx]
- return self.process(data)
-
-
-@register('wrapper_audio_cae_iterable')
-class WrapperAudioCAEIterable(BaseWrapperAudioCAE, IterableDataset):
- """Iterable dataset wrapper for audio CAE training."""
-
- def __iter__(self):
- for data in self.dataset:
- yield self.process(data)
-
-
-# Example usage with your existing AudioDataset
-def create_dito_audio_dataset(config):
- """Create DiTo audio dataset from config."""
-
- # Create base audio dataset using audiotools
-
- # Setup audio loaders
- train_folders = config.get("train_folders", {})
-
- loader = AudioLoader(
- sources=list(train_folders.values()),
- transform=tfm.Compose(
- tfm.VolumeNorm(("uniform", -20, -10)),
- tfm.RescaleAudio(),
- ),
- ext=['.wav', '.flac', '.mp3'],
- )
-
- # Create base dataset
- base_dataset = AudioDataset(
- loaders=loader,
- sample_rate=config['sample_rate'],
- duration=config['duration'],
- n_examples=config['n_examples'],
- num_channels=1 if config.get('mono', True) else 2,
- )
-
- # Wrap with DiTo wrapper
- dito_dataset = WrapperAudioCAE(
- dataset=base_dataset,
- sample_rate=config['sample_rate'],
- duration=config['duration'],
- mono=config.get('mono', True),
- normalize=True,
- return_coords=True,
- )
-
- return dito_dataset
-
-
-# For your training config, you would use it like:
-"""
-datasets:
- train:
- name: wrapper_audio_cae
- args:
- dataset:
- name: audio_dataset # Your base audio dataset
- args:
- sources: ["/path/to/audio/files"]
- sample_rate: 44100
- duration: 2.0
- n_examples: 10000
- sample_rate: 44100
- duration: 2.0
- mono: true
- normalize: true
- return_coords: true
- loader:
- batch_size: 16
- num_workers: 8
-
- val:
- name: wrapper_audio_cae
- args:
- dataset:
- name: audio_dataset
- args:
- sources: ["/path/to/val/audio/files"]
- sample_rate: 44100
- duration: 2.0
- n_examples: 1000
- sample_rate: 44100
- duration: 2.0
- mono: true
- normalize: true
- return_coords: true
- loader:
- batch_size: 16
- num_workers: 8
-"""
\ No newline at end of file
diff --git a/dito/load/dito.png b/dito/load/dito.png
deleted file mode 100644
index bbc7f7f11845d568c7c7e4d3439299c87073f4cb..0000000000000000000000000000000000000000
--- a/dito/load/dito.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:25999470ed7eaba155f1e8ab639fd43fe35eb53a53e9770829451b3d0434c467
-size 245154
diff --git a/dito/load/vgg_lpips.pth b/dito/load/vgg_lpips.pth
deleted file mode 100644
index 47e943cfacabf7040b4af8cf4084ab91177f1b88..0000000000000000000000000000000000000000
Binary files a/dito/load/vgg_lpips.pth and /dev/null differ
diff --git a/dito/load/wandb.yaml b/dito/load/wandb.yaml
deleted file mode 100644
index 792009be6e6f698d184cd8176c463da38b9f9075..0000000000000000000000000000000000000000
--- a/dito/load/wandb.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-entity:
-api_key:
-project:
\ No newline at end of file
diff --git a/dito/models/__init__.py b/dito/models/__init__.py
deleted file mode 100644
index b5a61e9db76e11bad1e378d95caa94e06341e0ee..0000000000000000000000000000000000000000
--- a/dito/models/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .models import register, make
-from . import ldm
-from . import diffusion
-from . import networks
\ No newline at end of file
diff --git a/dito/models/diffusion/__init__.py b/dito/models/diffusion/__init__.py
deleted file mode 100644
index 6578ad49c27b9797ee3e62c0b5d952089e253ed9..0000000000000000000000000000000000000000
--- a/dito/models/diffusion/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from . import fm
-from . import samplers
diff --git a/dito/models/diffusion/fm.py b/dito/models/diffusion/fm.py
deleted file mode 100644
index 0c008882977b3e240ee4ecfaf07802e565997fd0..0000000000000000000000000000000000000000
--- a/dito/models/diffusion/fm.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import torch
-
-from models import register
-
-
-@register('fm')
-class FM:
-
- def __init__(self, sigma_min=1e-5, timescale=1.0):
- self.sigma_min = sigma_min
- self.prediction_type = None
- self.timescale = timescale
-
- def alpha(self, t):
- return 1.0 - t
-
- def sigma(self, t):
- return self.sigma_min + t * (1.0 - self.sigma_min)
-
- def A(self, t):
- return 1.0
-
- def B(self, t):
- return -(1.0 - self.sigma_min)
-
- def get_betas(self, n_timesteps):
- return torch.zeros(n_timesteps) # Not VP and not supported
-
- def add_noise(self, x, t, noise=None):
- noise = torch.randn_like(x) if noise is None else noise
- s = [x.shape[0]] + [1] * (x.dim() - 1)
- x_t = self.alpha(t).view(*s) * x + self.sigma(t).view(*s) * noise
- return x_t, noise
-
- def loss(self, net, x, t=None, net_kwargs=None, return_loss_unreduced=False, return_all=False):
- if net_kwargs is None:
- net_kwargs = {}
-
- if t is None:
- t = torch.rand(x.shape[0], device=x.device)
-        x_t, noise = self.add_noise(x, t)
-        pred = net(x_t, t=t * self.timescale, **net_kwargs)
-
-        target = self.A(t) * x + self.B(t) * noise  # -dxt/dt
-        if return_loss_unreduced:
-            # reduce over all non-batch dims so both image (B, C, H, W)
-            # and audio (B, C, T) tensors are supported
-            loss = ((pred.float() - target.float()) ** 2).mean(dim=list(range(1, pred.dim())))
-            if return_all:
-                return loss, t, x_t, pred
-            else:
-                return loss, t
-        else:
-            loss = ((pred.float() - target.float()) ** 2).mean()
- if return_all:
- return loss, x_t, pred
- else:
- return loss
-
- def get_prediction(
- self,
- net,
- x_t,
- t,
- net_kwargs=None,
- uncond_net_kwargs=None,
- guidance=1.0,
- ):
- if net_kwargs is None:
- net_kwargs = {}
- pred = net(x_t, t=t * self.timescale, **net_kwargs)
- if guidance != 1.0:
- assert uncond_net_kwargs is not None
- uncond_pred = net(x_t, t=t * self.timescale, **uncond_net_kwargs)
- pred = uncond_pred + guidance * (pred - uncond_pred)
- return pred
-
- def convert_sample_prediction(self, x_t, t, pred):
- M = torch.tensor([
- [self.alpha(t), self.sigma(t)],
- [self.A(t), self.B(t)],
- ], dtype=torch.float64)
- M_inv = torch.linalg.inv(M)
- sample_pred = M_inv[0, 0].item() * x_t + M_inv[0, 1].item() * pred
- return sample_pred
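-
-
-# A small illustrative sketch of the schedule above (added example, not part
-# of the training code): x_t interpolates linearly between data and noise,
-# and the regression target is x - (1 - sigma_min) * noise, i.e. the
-# negative velocity -dx_t/dt.
-#
-#     fm = FM(timescale=1000.0)
-#     x = torch.randn(8, 4, 32, 32)
-#     t = torch.full((8,), 0.5)
-#     x_t, noise = fm.add_noise(x, t)         # ~0.5 * x + ~0.5 * noise
-#     target = fm.A(t) * x + fm.B(t) * noise  # = x - (1 - sigma_min) * noise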
diff --git a/dito/models/diffusion/samplers.py b/dito/models/diffusion/samplers.py
deleted file mode 100644
index 75fc52249e08e8f5bee516349f9545f0de7c5513..0000000000000000000000000000000000000000
--- a/dito/models/diffusion/samplers.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import numpy as np
-import torch
-
-from models import register
-
-
-@register('fm_euler_sampler')
-class FMEulerSampler:
-
- def __init__(self, diffusion):
- self.diffusion = diffusion
-
- def sample(
- self,
- net,
- shape,
- n_steps,
- net_kwargs=None,
- uncond_net_kwargs=None,
- guidance=1.0,
- noise=None,
- ):
- device = next(net.parameters()).device
- x_t = torch.randn(shape, device=device) if noise is None else noise
- t_steps = torch.linspace(1, 0, n_steps + 1, device=device)
-
- with torch.no_grad():
- for i in range(n_steps):
- t = t_steps[i].repeat(x_t.shape[0])
- neg_v = self.diffusion.get_prediction(
- net,
- x_t,
- t,
- net_kwargs=net_kwargs,
- uncond_net_kwargs=uncond_net_kwargs,
- guidance=guidance,
- )
- x_t = x_t + neg_v * (t_steps[i] - t_steps[i + 1])
- return x_t
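-
-
-# Minimal usage sketch (illustrative; `net` stands for any velocity model
-# with signature net(x_t, t=..., **kwargs), e.g. the renderer UNet):
-#
-#     sampler = FMEulerSampler(FM(timescale=1000.0))
-#     samples = sampler.sample(net, shape=(4, 4, 32, 32), n_steps=50)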
diff --git a/dito/models/ldm/__init__.py b/dito/models/ldm/__init__.py
deleted file mode 100644
index e8103bdfc078dae107945147b04cb44ed5758f5e..0000000000000000000000000000000000000000
--- a/dito/models/ldm/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from . import glpto, dito
-from . import renderers
-from . import vqgan
-from . import dac
\ No newline at end of file
diff --git a/dito/models/ldm/dac/__init__.py b/dito/models/ldm/dac/__init__.py
deleted file mode 100644
index 90f60fdd89ad8575faafe45188bd1d968852fc67..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .utils import *
\ No newline at end of file
diff --git a/dito/models/ldm/dac/audiotools/__init__.py b/dito/models/ldm/dac/audiotools/__init__.py
deleted file mode 100644
index b251ff37628c56a19bb38976fca99d9536e64bbf..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-__version__ = "0.7.4"
-from .core import AudioSignal
-from .core import STFTParams
-from .core import Meter
-from .core import util
-from . import metrics
-from . import data
-from . import ml
-from .data import datasets
-from .data import transforms
diff --git a/dito/models/ldm/dac/audiotools/core/__init__.py b/dito/models/ldm/dac/audiotools/core/__init__.py
deleted file mode 100644
index 8660c4e67f43d0ded584a38939425e2c28d95cd3..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from . import util
-from .audio_signal import AudioSignal
-from .audio_signal import STFTParams
-from .loudness import Meter
diff --git a/dito/models/ldm/dac/audiotools/core/audio_signal.py b/dito/models/ldm/dac/audiotools/core/audio_signal.py
deleted file mode 100644
index fb6d751cb968a003656e3e7874c487b83d94c82e..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/audio_signal.py
+++ /dev/null
@@ -1,1682 +0,0 @@
-import copy
-import functools
-import hashlib
-import math
-import pathlib
-import tempfile
-import typing
-import warnings
-from collections import namedtuple
-from pathlib import Path
-
-import julius
-import numpy as np
-import soundfile
-import torch
-
-from . import util
-from .display import DisplayMixin
-from .dsp import DSPMixin
-from .effects import EffectMixin
-from .effects import ImpulseResponseMixin
-from .ffmpeg import FFMPEGMixin
-from .loudness import LoudnessMixin
-from .playback import PlayMixin
-from .whisper import WhisperMixin
-
-
-STFTParams = namedtuple(
- "STFTParams",
- ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
-)
-"""
-STFTParams object is a container that holds STFT parameters - window_length,
-hop_length, and window_type. Not all parameters need to be specified. Ones that
-are not specified will be inferred by the AudioSignal parameters.
-
-Parameters
-----------
-window_length : int, optional
- Window length of STFT, by default ``0.032 * self.sample_rate``.
-hop_length : int, optional
- Hop length of STFT, by default ``window_length // 4``.
-window_type : str, optional
-    Type of window to use, by default ``sqrt_hann``.
-match_stride : bool, optional
- Whether to match the stride of convolutional layers, by default False
-padding_type : str, optional
- Type of padding to use, by default 'reflect'
-"""
-STFTParams.__new__.__defaults__ = (None, None, None, None, None)
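-
-
-# Example (illustrative): a 32 ms window with an 8 ms hop at 16 kHz
-# corresponds to STFTParams(window_length=512, hop_length=128,
-# window_type='hann').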
-
-
-class AudioSignal(
- EffectMixin,
- LoudnessMixin,
- PlayMixin,
- ImpulseResponseMixin,
- DSPMixin,
- DisplayMixin,
- FFMPEGMixin,
- WhisperMixin,
-):
- """This is the core object of this library. Audio is always
- loaded into an AudioSignal, which then enables all the features
- of this library, including audio augmentations, I/O, playback,
- and more.
-
- The structure of this object is that the base functionality
- is defined in ``core/audio_signal.py``, while extensions to
- that functionality are defined in the other ``core/*.py``
- files. For example, all the display-based functionality
- (e.g. plot spectrograms, waveforms, write to tensorboard)
- are in ``core/display.py``.
-
- Parameters
- ----------
- audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray]
-        Object to create AudioSignal from. Can be a tensor, numpy array,
-        or a path to a file. The data is always reshaped to 3 dimensions,
-        (batch_size, num_channels, num_samples).
- sample_rate : int, optional
- Sample rate of the audio. If different from underlying file, resampling is
- performed. If passing in an array or tensor, this must be defined,
- by default None
- stft_params : STFTParams, optional
- Parameters of STFT to use. , by default None
- offset : float, optional
- Offset in seconds to read from file, by default 0
- duration : float, optional
- Duration in seconds to read from file, by default None
- device : str, optional
- Device to load audio onto, by default None
-
- Examples
- --------
- Loading an AudioSignal from an array, at a sample rate of
- 44100.
-
- >>> signal = AudioSignal(torch.randn(5*44100), 44100)
-
- Note, the signal is reshaped to have a batch size, and one
- audio channel:
-
- >>> print(signal.shape)
-    (1, 1, 220500)
-
- You can treat AudioSignals like tensors, and many of the same
- functions you might use on tensors are defined for AudioSignals
- as well:
-
- >>> signal.to("cuda")
- >>> signal.cuda()
- >>> signal.clone()
- >>> signal.detach()
-
- Indexing AudioSignals returns an AudioSignal:
-
- >>> signal[..., 3*44100:4*44100]
-
- The above signal is 1 second long, and is also an AudioSignal.
- """
-
- def __init__(
- self,
- audio_path_or_array: typing.Union[torch.Tensor, str, Path, np.ndarray],
- sample_rate: int = None,
- stft_params: STFTParams = None,
- offset: float = 0,
- duration: float = None,
- device: str = None,
- ):
- audio_path = None
- audio_array = None
-
- if isinstance(audio_path_or_array, str):
- audio_path = audio_path_or_array
- elif isinstance(audio_path_or_array, pathlib.Path):
- audio_path = audio_path_or_array
- elif isinstance(audio_path_or_array, np.ndarray):
- audio_array = audio_path_or_array
- elif torch.is_tensor(audio_path_or_array):
- audio_array = audio_path_or_array
- else:
- raise ValueError(
- "audio_path_or_array must be either a Path, "
- "string, numpy array, or torch Tensor!"
- )
-
- self.path_to_file = None
-
- self.audio_data = None
- self.sources = None # List of AudioSignal objects.
- self.stft_data = None
- if audio_path is not None:
- self.load_from_file(
- audio_path, offset=offset, duration=duration, device=device
- )
- elif audio_array is not None:
- assert sample_rate is not None, "Must set sample rate!"
- self.load_from_array(audio_array, sample_rate, device=device)
-
- self.window = None
- self.stft_params = stft_params
-
- self.metadata = {
- "offset": offset,
- "duration": duration,
- }
-
- @property
- def path_to_input_file(
- self,
- ):
- """
- Path to input file, if it exists.
- Alias to ``path_to_file`` for backwards compatibility
- """
- return self.path_to_file
-
- @classmethod
- def excerpt(
- cls,
- audio_path: typing.Union[str, Path],
- offset: float = None,
- duration: float = None,
- state: typing.Union[np.random.RandomState, int] = None,
- **kwargs,
- ):
- """Randomly draw an excerpt of ``duration`` seconds from an
- audio file specified at ``audio_path``, between ``offset`` seconds
- and end of file. ``state`` can be used to seed the random draw.
-
- Parameters
- ----------
- audio_path : typing.Union[str, Path]
- Path to audio file to grab excerpt from.
- offset : float, optional
- Lower bound for the start time, in seconds drawn from
- the file, by default None.
- duration : float, optional
- Duration of excerpt, in seconds, by default None
- state : typing.Union[np.random.RandomState, int], optional
- RandomState or seed of random state, by default None
-
- Returns
- -------
- AudioSignal
- AudioSignal containing excerpt.
-
- Examples
- --------
- >>> signal = AudioSignal.excerpt("path/to/audio", duration=5)
- """
- info = util.info(audio_path)
- total_duration = info.duration
-
- state = util.random_state(state)
- lower_bound = 0 if offset is None else offset
- upper_bound = max(total_duration - duration, 0)
- offset = state.uniform(lower_bound, upper_bound)
-
- signal = cls(audio_path, offset=offset, duration=duration, **kwargs)
- signal.metadata["offset"] = offset
- signal.metadata["duration"] = duration
-
- return signal
-
- @classmethod
- def salient_excerpt(
- cls,
- audio_path: typing.Union[str, Path],
- loudness_cutoff: float = None,
- num_tries: int = 8,
- state: typing.Union[np.random.RandomState, int] = None,
- **kwargs,
- ):
- """Similar to AudioSignal.excerpt, except it extracts excerpts only
- if they are above a specified loudness threshold, which is computed via
- a fast LUFS routine.
-
- Parameters
- ----------
- audio_path : typing.Union[str, Path]
- Path to audio file to grab excerpt from.
- loudness_cutoff : float, optional
- Loudness threshold in dB. Typical values are ``-40, -60``,
- etc, by default None
- num_tries : int, optional
- Number of tries to grab an excerpt above the threshold
- before giving up, by default 8.
- state : typing.Union[np.random.RandomState, int], optional
- RandomState or seed of random state, by default None
- kwargs : dict
- Keyword arguments to AudioSignal.excerpt
-
- Returns
- -------
- AudioSignal
- AudioSignal containing excerpt.
-
-
- .. warning::
- if ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can
- result in an infinite loop if ``audio_path`` does not have
- any loud enough excerpts.
-
- Examples
- --------
- >>> signal = AudioSignal.salient_excerpt(
- "path/to/audio",
- loudness_cutoff=-40,
- duration=5
- )
- """
- state = util.random_state(state)
- if loudness_cutoff is None:
- excerpt = cls.excerpt(audio_path, state=state, **kwargs)
- else:
- loudness = -np.inf
- num_try = 0
- while loudness <= loudness_cutoff:
- excerpt = cls.excerpt(audio_path, state=state, **kwargs)
- loudness = excerpt.loudness()
- num_try += 1
- if num_tries is not None and num_try >= num_tries:
- break
- return excerpt
-
- @classmethod
- def zeros(
- cls,
- duration: float,
- sample_rate: int,
- num_channels: int = 1,
- batch_size: int = 1,
- **kwargs,
- ):
- """Helper function create an AudioSignal of all zeros.
-
- Parameters
- ----------
- duration : float
- Duration of AudioSignal
- sample_rate : int
- Sample rate of AudioSignal
- num_channels : int, optional
- Number of channels, by default 1
- batch_size : int, optional
- Batch size, by default 1
-
- Returns
- -------
- AudioSignal
- AudioSignal containing all zeros.
-
- Examples
- --------
- Generate 5 seconds of all zeros at a sample rate of 44100.
-
- >>> signal = AudioSignal.zeros(5.0, 44100)
- """
- n_samples = int(duration * sample_rate)
- return cls(
- torch.zeros(batch_size, num_channels, n_samples), sample_rate, **kwargs
- )
-
- @classmethod
- def wave(
- cls,
- frequency: float,
- duration: float,
- sample_rate: int,
- num_channels: int = 1,
- shape: str = "sine",
- **kwargs,
- ):
- """
- Generate a waveform of a given frequency and shape.
-
- Parameters
- ----------
- frequency : float
- Frequency of the waveform
- duration : float
- Duration of the waveform
- sample_rate : int
- Sample rate of the waveform
- num_channels : int, optional
- Number of channels, by default 1
- shape : str, optional
-            Shape of the waveform, by default "sine"
- One of "sawtooth", "square", "sine", "triangle"
- kwargs : dict
- Keyword arguments to AudioSignal
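-
-        Examples
-        --------
-        Generate one second of a 440 Hz sine wave at 44100 Hz:
-
-        >>> signal = AudioSignal.wave(440, 1.0, 44100, shape="sine")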
- """
- n_samples = int(duration * sample_rate)
- t = torch.linspace(0, duration, n_samples)
- if shape == "sawtooth":
- from scipy.signal import sawtooth
-
- wave_data = sawtooth(2 * np.pi * frequency * t, 0.5)
- elif shape == "square":
- from scipy.signal import square
-
- wave_data = square(2 * np.pi * frequency * t)
- elif shape == "sine":
- wave_data = np.sin(2 * np.pi * frequency * t)
- elif shape == "triangle":
- from scipy.signal import sawtooth
-
- # frequency is doubled by the abs call, so omit the 2 in 2pi
- wave_data = sawtooth(np.pi * frequency * t, 0.5)
- wave_data = -np.abs(wave_data) * 2 + 1
- else:
- raise ValueError(f"Invalid shape {shape}")
-
- wave_data = torch.tensor(wave_data, dtype=torch.float32)
- wave_data = wave_data.unsqueeze(0).unsqueeze(0).repeat(1, num_channels, 1)
- return cls(wave_data, sample_rate, **kwargs)
-
- @classmethod
- def batch(
- cls,
- audio_signals: list,
- pad_signals: bool = False,
- truncate_signals: bool = False,
- resample: bool = False,
- dim: int = 0,
- ):
- """Creates a batched AudioSignal from a list of AudioSignals.
-
- Parameters
- ----------
- audio_signals : list[AudioSignal]
- List of AudioSignal objects
- pad_signals : bool, optional
- Whether to pad signals to length of the maximum length
- AudioSignal in the list, by default False
- truncate_signals : bool, optional
- Whether to truncate signals to length of shortest length
- AudioSignal in the list, by default False
- resample : bool, optional
- Whether to resample AudioSignal to the sample rate of
- the first AudioSignal in the list, by default False
- dim : int, optional
-            Dimension along which to batch the signals, by default 0.
-
- Returns
- -------
- AudioSignal
- Batched AudioSignal.
-
- Raises
- ------
- RuntimeError
- If not all AudioSignals are the same sample rate, and
- ``resample=False``, an error is raised.
- RuntimeError
- If not all AudioSignals are the same the length, and
- both ``pad_signals=False`` and ``truncate_signals=False``,
- an error is raised.
-
- Examples
- --------
- Batching a bunch of random signals:
-
- >>> signal_list = [AudioSignal(torch.randn(44100), 44100) for _ in range(10)]
- >>> signal = AudioSignal.batch(signal_list)
- >>> print(signal.shape)
- (10, 1, 44100)
-
- """
- signal_lengths = [x.signal_length for x in audio_signals]
- sample_rates = [x.sample_rate for x in audio_signals]
-
- if len(set(sample_rates)) != 1:
- if resample:
- for x in audio_signals:
- x.resample(sample_rates[0])
- else:
- raise RuntimeError(
- f"Not all signals had the same sample rate! Got {sample_rates}. "
- f"All signals must have the same sample rate, or resample must be True. "
- )
-
- if len(set(signal_lengths)) != 1:
- if pad_signals:
- max_length = max(signal_lengths)
- for x in audio_signals:
- pad_len = max_length - x.signal_length
- x.zero_pad(0, pad_len)
- elif truncate_signals:
- min_length = min(signal_lengths)
- for x in audio_signals:
- x.truncate_samples(min_length)
- else:
- raise RuntimeError(
- f"Not all signals had the same length! Got {signal_lengths}. "
- f"All signals must be the same length, or pad_signals/truncate_signals "
- f"must be True. "
- )
- # Concatenate along the specified dimension (default 0)
- audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim)
- audio_paths = [x.path_to_file for x in audio_signals]
-
- batched_signal = cls(
- audio_data,
- sample_rate=audio_signals[0].sample_rate,
- )
- batched_signal.path_to_file = audio_paths
- return batched_signal
-
- # I/O
- def load_from_file(
- self,
- audio_path: typing.Union[str, Path],
- offset: float,
- duration: float,
- device: str = "cpu",
- ):
- """Loads data from file. Used internally when AudioSignal
- is instantiated with a path to a file.
-
- Parameters
- ----------
- audio_path : typing.Union[str, Path]
- Path to file
- offset : float
- Offset in seconds
- duration : float
- Duration in seconds
- device : str, optional
- Device to put AudioSignal on, by default "cpu"
-
- Returns
- -------
- AudioSignal
- AudioSignal loaded from file
- """
- import librosa
-
- data, sample_rate = librosa.load(
- audio_path,
- offset=offset,
- duration=duration,
- sr=None,
- mono=False,
- )
- data = util.ensure_tensor(data)
- if data.shape[-1] == 0:
- raise RuntimeError(
- f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!"
- )
-
- if data.ndim < 2:
- data = data.unsqueeze(0)
- if data.ndim < 3:
- data = data.unsqueeze(0)
- self.audio_data = data
-
- self.original_signal_length = self.signal_length
-
- self.sample_rate = sample_rate
- self.path_to_file = audio_path
- return self.to(device)
-
- def load_from_array(
- self,
- audio_array: typing.Union[torch.Tensor, np.ndarray],
- sample_rate: int,
- device: str = "cpu",
- ):
- """Loads data from array, reshaping it to be exactly 3
- dimensions. Used internally when AudioSignal is called
- with a tensor or an array.
-
- Parameters
- ----------
- audio_array : typing.Union[torch.Tensor, np.ndarray]
- Array/tensor of audio of samples.
- sample_rate : int
- Sample rate of audio
- device : str, optional
- Device to move audio onto, by default "cpu"
-
- Returns
- -------
- AudioSignal
- AudioSignal loaded from array
- """
- audio_data = util.ensure_tensor(audio_array)
-
- if audio_data.dtype == torch.double:
- audio_data = audio_data.float()
-
- if audio_data.ndim < 2:
- audio_data = audio_data.unsqueeze(0)
- if audio_data.ndim < 3:
- audio_data = audio_data.unsqueeze(0)
- self.audio_data = audio_data
-
- self.original_signal_length = self.signal_length
-
- self.sample_rate = sample_rate
- return self.to(device)
-
- def write(self, audio_path: typing.Union[str, Path]):
- """Writes audio to a file. Only writes the audio
- that is in the very first item of the batch. To write other items
- in the batch, index the signal along the batch dimension
- before writing. After writing, the signal's ``path_to_file``
- attribute is updated to the new path.
-
- Parameters
- ----------
- audio_path : typing.Union[str, Path]
- Path to write audio to.
-
- Returns
- -------
- AudioSignal
- Returns original AudioSignal, so you can use this in a fluent
- interface.
-
- Examples
- --------
- Creating and writing a signal to disk:
-
- >>> signal = AudioSignal(torch.randn(10, 1, 44100), 44100)
- >>> signal.write("/tmp/out.wav")
-
- Writing a different element of the batch:
-
- >>> signal[5].write("/tmp/out.wav")
-
- Using this in a fluent interface:
-
- >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav")
-
- """
- if self.audio_data[0].abs().max() > 1:
- warnings.warn("Audio amplitude > 1 clipped when saving")
- soundfile.write(str(audio_path), self.audio_data[0].numpy().T, self.sample_rate)
-
- self.path_to_file = audio_path
- return self
-
- def deepcopy(self):
- """Copies the signal and all of its attributes.
-
- Returns
- -------
- AudioSignal
- Deep copy of the audio signal.
- """
- return copy.deepcopy(self)
-
- def copy(self):
- """Shallow copy of signal.
-
- Returns
- -------
- AudioSignal
- Shallow copy of the audio signal.
- """
- return copy.copy(self)
-
- def clone(self):
- """Clones all tensors contained in the AudioSignal,
- and returns a copy of the signal with everything
- cloned. Useful when using AudioSignal within autograd
- computation graphs.
-
- Relevant attributes are the stft data, the audio data,
- and the loudness of the file.
-
- Returns
- -------
- AudioSignal
- Clone of AudioSignal.
- """
- clone = type(self)(
- self.audio_data.clone(),
- self.sample_rate,
- stft_params=self.stft_params,
- )
- if self.stft_data is not None:
- clone.stft_data = self.stft_data.clone()
- if self._loudness is not None:
- clone._loudness = self._loudness.clone()
- clone.path_to_file = copy.deepcopy(self.path_to_file)
- clone.metadata = copy.deepcopy(self.metadata)
- return clone
-
- def detach(self):
- """Detaches tensors contained in AudioSignal.
-
- Relevant attributes are the stft data, the audio data,
- and the loudness of the file.
-
- Returns
- -------
- AudioSignal
- Same signal, but with all tensors detached.
- """
- if self._loudness is not None:
- self._loudness = self._loudness.detach()
- if self.stft_data is not None:
- self.stft_data = self.stft_data.detach()
-
- self.audio_data = self.audio_data.detach()
- return self
-
- def hash(self):
- """Writes the audio data to a temporary file, and then
- hashes it using hashlib. Useful for creating a file
- name based on the audio content.
-
- Returns
- -------
- str
- Hash of audio data.
-
- Examples
- --------
- Creating a signal, and writing it to a unique file name:
-
- >>> signal = AudioSignal(torch.randn(44100), 44100)
- >>> hash = signal.hash()
- >>> signal.write(f"{hash}.wav")
-
- """
- with tempfile.NamedTemporaryFile(suffix=".wav") as f:
- self.write(f.name)
- h = hashlib.sha256()
- b = bytearray(128 * 1024)
- mv = memoryview(b)
- with open(f.name, "rb", buffering=0) as f:
- for n in iter(lambda: f.readinto(mv), 0):
- h.update(mv[:n])
- file_hash = h.hexdigest()
- return file_hash
-
- # Signal operations
- def to_mono(self):
- """Converts audio data to mono audio, by taking the mean
- along the channels dimension.
-
- Returns
- -------
- AudioSignal
- AudioSignal with mean of channels.
- """
- self.audio_data = self.audio_data.mean(1, keepdim=True)
- return self
-
- def resample(self, sample_rate: int):
- """Resamples the audio, using sinc interpolation. This works on both
- cpu and gpu, and is much faster on gpu.
-
- Parameters
- ----------
- sample_rate : int
- Sample rate to resample to.
-
- Returns
- -------
- AudioSignal
- Resampled AudioSignal
- """
- if sample_rate == self.sample_rate:
- return self
- self.audio_data = julius.resample_frac(
- self.audio_data, self.sample_rate, sample_rate
- )
- self.sample_rate = sample_rate
- return self
-
- # Tensor operations
- def to(self, device: str):
- """Moves all tensors contained in signal to the specified device.
-
- Parameters
- ----------
- device : str
- Device to move AudioSignal onto. Typical values are
- "cuda", "cpu", or "cuda:n" to specify the nth gpu.
-
- Returns
- -------
- AudioSignal
- AudioSignal with all tensors moved to specified device.
- """
- if self._loudness is not None:
- self._loudness = self._loudness.to(device)
- if self.stft_data is not None:
- self.stft_data = self.stft_data.to(device)
- if self.audio_data is not None:
- self.audio_data = self.audio_data.to(device)
- return self
-
- def float(self):
- """Calls ``.float()`` on ``self.audio_data``.
-
- Returns
- -------
- AudioSignal
- """
- self.audio_data = self.audio_data.float()
- return self
-
- def cpu(self):
- """Moves AudioSignal to cpu.
-
- Returns
- -------
- AudioSignal
- """
- return self.to("cpu")
-
- def cuda(self): # pragma: no cover
- """Moves AudioSignal to cuda.
-
- Returns
- -------
- AudioSignal
- """
- return self.to("cuda")
-
- def numpy(self):
- """Detaches ``self.audio_data``, moves to cpu, and converts to numpy.
-
- Returns
- -------
- np.ndarray
- Audio data as a numpy array.
- """
- return self.audio_data.detach().cpu().numpy()
-
- def zero_pad(self, before: int, after: int):
- """Zero pads the audio_data tensor before and after.
-
- Parameters
- ----------
- before : int
- How many zeros to prepend to audio.
- after : int
- How many zeros to append to audio.
-
- Returns
- -------
- AudioSignal
- AudioSignal with padding applied.
- """
- self.audio_data = torch.nn.functional.pad(self.audio_data, (before, after))
- return self
-
- def zero_pad_to(self, length: int, mode: str = "after"):
- """Pad with zeros to a specified length, either before or after
- the audio data.
-
- Parameters
- ----------
- length : int
- Length to pad to
- mode : str, optional
- Whether to prepend or append zeros to signal, by default "after"
-
- Returns
- -------
- AudioSignal
- AudioSignal with padding applied.
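-
-        Examples
-        --------
-        A small sketch, padding a one-second signal out to 48000 samples:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> signal.zero_pad_to(48000).signal_length
-        48000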
- """
- if mode == "before":
- self.zero_pad(max(length - self.signal_length, 0), 0)
- elif mode == "after":
- self.zero_pad(0, max(length - self.signal_length, 0))
- return self
-
- def trim(self, before: int, after: int):
- """Trims the audio_data tensor before and after.
-
- Parameters
- ----------
- before : int
- How many samples to trim from beginning.
- after : int
- How many samples to trim from end.
-
- Returns
- -------
- AudioSignal
- AudioSignal with trimming applied.
- """
- if after == 0:
- self.audio_data = self.audio_data[..., before:]
- else:
- self.audio_data = self.audio_data[..., before:-after]
- return self
-
- def truncate_samples(self, length_in_samples: int):
- """Truncate signal to specified length.
-
- Parameters
- ----------
- length_in_samples : int
- Truncate to this many samples.
-
- Returns
- -------
- AudioSignal
- AudioSignal with truncation applied.
- """
- self.audio_data = self.audio_data[..., :length_in_samples]
- return self
-
- @property
- def device(self):
- """Get device that AudioSignal is on.
-
- Returns
- -------
- torch.device
- Device that AudioSignal is on.
- """
- if self.audio_data is not None:
- device = self.audio_data.device
- elif self.stft_data is not None:
- device = self.stft_data.device
- return device
-
- # Properties
- @property
- def audio_data(self):
- """Returns the audio data tensor in the object.
-
- Audio data is always of the shape
-        (batch_size, num_channels, num_samples). The setter requires a
-        3-dim torch.Tensor; inputs with fewer dims (e.g. of shape
-        (num_channels, num_samples)) are reshaped to a batch size of 1
-        by the constructor before being set.
-
- Parameters
- ----------
- data : typing.Union[torch.Tensor, np.ndarray]
- Audio data to set.
-
- Returns
- -------
- torch.Tensor
- Audio samples.
- """
- return self._audio_data
-
- @audio_data.setter
- def audio_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
- if data is not None:
- assert torch.is_tensor(data), "audio_data should be torch.Tensor"
- assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)"
- self._audio_data = data
- # Old loudness value not guaranteed to be right, reset it.
- self._loudness = None
- return
-
- # alias for audio_data
- samples = audio_data
-
- @property
- def stft_data(self):
- """Returns the STFT data inside the signal. Shape is
- (batch, channels, frequencies, time).
-
- Returns
- -------
- torch.Tensor
- Complex spectrogram data.
- """
- return self._stft_data
-
- @stft_data.setter
- def stft_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
- if data is not None:
- assert torch.is_tensor(data) and torch.is_complex(data)
- if self.stft_data is not None and self.stft_data.shape != data.shape:
- warnings.warn("stft_data changed shape")
- self._stft_data = data
- return
-
- @property
- def batch_size(self):
- """Batch size of audio signal.
-
- Returns
- -------
- int
- Batch size of signal.
- """
- return self.audio_data.shape[0]
-
- @property
- def signal_length(self):
- """Length of audio signal.
-
- Returns
- -------
- int
- Length of signal in samples.
- """
- return self.audio_data.shape[-1]
-
- # alias for signal_length
- length = signal_length
-
- @property
- def shape(self):
- """Shape of audio data.
-
- Returns
- -------
- tuple
- Shape of audio data.
- """
- return self.audio_data.shape
-
- @property
- def signal_duration(self):
- """Length of audio signal in seconds.
-
- Returns
- -------
- float
- Length of signal in seconds.
- """
- return self.signal_length / self.sample_rate
-
- # alias for signal_duration
- duration = signal_duration
-
- @property
- def num_channels(self):
- """Number of audio channels.
-
- Returns
- -------
- int
- Number of audio channels.
- """
- return self.audio_data.shape[1]
-
- # STFT
- @staticmethod
- @functools.lru_cache(None)
- def get_window(window_type: str, window_length: int, device: str):
- """Wrapper around scipy.signal.get_window so one can also get the
- popular sqrt-hann window. This function caches for efficiency
-        using functools.lru_cache.
-
- Parameters
- ----------
- window_type : str
- Type of window to get
- window_length : int
- Length of the window
- device : str
- Device to put window onto.
-
- Returns
- -------
- torch.Tensor
- Window returned by scipy.signal.get_window, as a tensor.
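-
-        Examples
-        --------
-        For instance, fetching a square-root Hann window on the CPU:
-
-        >>> window = AudioSignal.get_window("sqrt_hann", 2048, "cpu")
-        >>> window.shape
-        torch.Size([2048])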
- """
- from scipy import signal
-
- if window_type == "average":
- window = np.ones(window_length) / window_length
- elif window_type == "sqrt_hann":
- window = np.sqrt(signal.get_window("hann", window_length))
- else:
- window = signal.get_window(window_type, window_length)
- window = torch.from_numpy(window).to(device).float()
- return window
-
- @property
- def stft_params(self):
- """Returns STFTParams object, which can be re-used to other
- AudioSignals.
-
- This property can be set as well. If values are not defined in STFTParams,
-        they are inferred automatically from the signal properties. The default is a
-        window length equal to the next power of two above 32 ms, a hop of one
-        quarter of the window length, and a Hann window.
-
- Returns
- -------
- STFTParams
- STFT parameters for the AudioSignal.
-
- Examples
- --------
- >>> stft_params = STFTParams(128, 32)
- >>> signal1 = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params)
- >>> signal2 = AudioSignal(torch.randn(44100), 44100, stft_params=signal1.stft_params)
- >>> signal1.stft_params = STFTParams() # Defaults
- """
- return self._stft_params
-
- @stft_params.setter
- def stft_params(self, value: STFTParams):
- default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate))))
- default_hop_len = default_win_len // 4
- default_win_type = "hann"
- default_match_stride = False
- default_padding_type = "reflect"
-
- default_stft_params = STFTParams(
- window_length=default_win_len,
- hop_length=default_hop_len,
- window_type=default_win_type,
- match_stride=default_match_stride,
- padding_type=default_padding_type,
- )._asdict()
-
- value = value._asdict() if value else default_stft_params
-
- for key in default_stft_params:
- if value[key] is None:
- value[key] = default_stft_params[key]
-
- self._stft_params = STFTParams(**value)
- self.stft_data = None
-
- def compute_stft_padding(
- self, window_length: int, hop_length: int, match_stride: bool
- ):
- """Compute how the STFT should be padded, based on match\_stride.
-
- Parameters
- ----------
- window_length : int
- Window length of STFT.
- hop_length : int
- Hop length of STFT.
- match_stride : bool
- Whether or not to match stride, making the STFT have the same alignment as
- convolutional layers.
-
- Returns
- -------
- tuple
- Amount to pad on either side of audio.
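-
-        Examples
-        --------
-        A worked example for a one-second, 44.1 kHz signal with a
-        2048-sample window and 512-sample hop, so that
-        ``hop_length == window_length // 4``:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> signal.compute_stft_padding(2048, 512, match_stride=True)
-        (444, 768)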
- """
- length = self.signal_length
-
- if match_stride:
- assert (
- hop_length == window_length // 4
- ), "For match_stride, hop must equal n_fft // 4"
- right_pad = math.ceil(length / hop_length) * hop_length - length
- pad = (window_length - hop_length) // 2
- else:
- right_pad = 0
- pad = 0
-
- return right_pad, pad
-
- def stft(
- self,
- window_length: int = None,
- hop_length: int = None,
- window_type: str = None,
- match_stride: bool = None,
- padding_type: str = None,
- ):
- """Computes the short-time Fourier transform of the audio data,
- with specified STFT parameters.
-
- Parameters
- ----------
- window_length : int, optional
-            Window length of STFT, by default the next power of two
-            above ``0.032 * self.sample_rate``.
- hop_length : int, optional
- Hop length of STFT, by default ``window_length // 4``.
- window_type : str, optional
-            Type of window to use, by default ``hann``.
- match_stride : bool, optional
- Whether to match the stride of convolutional layers, by default False
- padding_type : str, optional
- Type of padding to use, by default 'reflect'
-
- Returns
- -------
- torch.Tensor
- STFT of audio data.
-
- Examples
- --------
- Compute the STFT of an AudioSignal:
-
- >>> signal = AudioSignal(torch.randn(44100), 44100)
- >>> signal.stft()
-
- Vary the window and hop length:
-
- >>> stft_params = [STFTParams(128, 32), STFTParams(512, 128)]
- >>> for stft_param in stft_params:
- >>> signal.stft_params = stft_params
- >>> signal.stft()
-
- """
- window_length = (
- self.stft_params.window_length
- if window_length is None
- else int(window_length)
- )
- hop_length = (
- self.stft_params.hop_length if hop_length is None else int(hop_length)
- )
- window_type = (
- self.stft_params.window_type if window_type is None else window_type
- )
- match_stride = (
- self.stft_params.match_stride if match_stride is None else match_stride
- )
- padding_type = (
- self.stft_params.padding_type if padding_type is None else padding_type
- )
-
- window = self.get_window(window_type, window_length, self.audio_data.device)
- window = window.to(self.audio_data.device)
-
- audio_data = self.audio_data
- right_pad, pad = self.compute_stft_padding(
- window_length, hop_length, match_stride
- )
- audio_data = torch.nn.functional.pad(
- audio_data, (pad, pad + right_pad), padding_type
- )
- stft_data = torch.stft(
- audio_data.reshape(-1, audio_data.shape[-1]),
- n_fft=window_length,
- hop_length=hop_length,
- window=window,
- return_complex=True,
- center=True,
- )
- _, nf, nt = stft_data.shape
- stft_data = stft_data.reshape(self.batch_size, self.num_channels, nf, nt)
-
- if match_stride:
- # Drop first two and last two frames, which are added
- # because of padding. Now num_frames * hop_length = num_samples.
- stft_data = stft_data[..., 2:-2]
- self.stft_data = stft_data
-
- return stft_data
-
- def istft(
- self,
- window_length: int = None,
- hop_length: int = None,
- window_type: str = None,
- match_stride: bool = None,
- length: int = None,
- ):
- """Computes inverse STFT and sets it to audio\_data.
-
- Parameters
- ----------
- window_length : int, optional
-            Window length of STFT, by default the next power of two
-            above ``0.032 * self.sample_rate``.
- hop_length : int, optional
- Hop length of STFT, by default ``window_length // 4``.
- window_type : str, optional
-            Type of window to use, by default ``hann``.
- match_stride : bool, optional
- Whether to match the stride of convolutional layers, by default False
- length : int, optional
- Original length of signal, by default None
-
- Returns
- -------
- AudioSignal
- AudioSignal with istft applied.
-
- Raises
- ------
- RuntimeError
- Raises an error if stft was not called prior to istft on the signal,
- or if stft_data is not set.
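-
-        Examples
-        --------
-        A round-trip sketch:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> signal.stft()
-        >>> signal = signal.istft()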
- """
- if self.stft_data is None:
- raise RuntimeError("Cannot do inverse STFT without self.stft_data!")
-
- window_length = (
- self.stft_params.window_length
- if window_length is None
- else int(window_length)
- )
- hop_length = (
- self.stft_params.hop_length if hop_length is None else int(hop_length)
- )
- window_type = (
- self.stft_params.window_type if window_type is None else window_type
- )
- match_stride = (
- self.stft_params.match_stride if match_stride is None else match_stride
- )
-
- window = self.get_window(window_type, window_length, self.stft_data.device)
-
- nb, nch, nf, nt = self.stft_data.shape
- stft_data = self.stft_data.reshape(nb * nch, nf, nt)
- right_pad, pad = self.compute_stft_padding(
- window_length, hop_length, match_stride
- )
-
- if length is None:
- length = self.original_signal_length
- length = length + 2 * pad + right_pad
-
- if match_stride:
- # Zero-pad the STFT on either side, putting back the frames that were
- # dropped in stft().
- stft_data = torch.nn.functional.pad(stft_data, (2, 2))
-
- audio_data = torch.istft(
- stft_data,
- n_fft=window_length,
- hop_length=hop_length,
- window=window,
- length=length,
- center=True,
- )
- audio_data = audio_data.reshape(nb, nch, -1)
- if match_stride:
- audio_data = audio_data[..., pad : -(pad + right_pad)]
- self.audio_data = audio_data
-
- return self
-
- @staticmethod
- @functools.lru_cache(None)
- def get_mel_filters(
- sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None
- ):
- """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins.
-
- Parameters
- ----------
- sr : int
- Sample rate of audio
- n_fft : int
- Number of FFT bins
- n_mels : int
- Number of mels
- fmin : float, optional
- Lowest frequency, in Hz, by default 0.0
- fmax : float, optional
- Highest frequency, by default None
-
- Returns
- -------
- np.ndarray [shape=(n_mels, 1 + n_fft/2)]
- Mel transform matrix
- """
- from librosa.filters import mel as librosa_mel_fn
-
- return librosa_mel_fn(
- sr=sr,
- n_fft=n_fft,
- n_mels=n_mels,
- fmin=fmin,
- fmax=fmax,
- )
-
- def mel_spectrogram(
- self, n_mels: int = 80, mel_fmin: float = 0.0, mel_fmax: float = None, **kwargs
- ):
- """Computes a Mel spectrogram.
-
- Parameters
- ----------
- n_mels : int, optional
- Number of mels, by default 80
- mel_fmin : float, optional
- Lowest frequency, in Hz, by default 0.0
- mel_fmax : float, optional
- Highest frequency, by default None
- kwargs : dict, optional
- Keyword arguments to self.stft().
-
- Returns
- -------
- torch.Tensor [shape=(batch, channels, mels, time)]
- Mel spectrogram.
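-
-        Examples
-        --------
-        A minimal sketch (assuming ``librosa`` is installed); the number of
-        time frames depends on the STFT settings:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> mels = signal.mel_spectrogram(n_mels=80)
-        >>> mels.shape[:3]
-        torch.Size([1, 1, 80])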
- """
- stft = self.stft(**kwargs)
- magnitude = torch.abs(stft)
-
- nf = magnitude.shape[2]
- mel_basis = self.get_mel_filters(
- sr=self.sample_rate,
- n_fft=2 * (nf - 1),
- n_mels=n_mels,
- fmin=mel_fmin,
- fmax=mel_fmax,
- )
- mel_basis = torch.from_numpy(mel_basis).to(self.device)
-
- mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
- mel_spectrogram = mel_spectrogram.transpose(-1, 2)
- return mel_spectrogram
-
- @staticmethod
- @functools.lru_cache(None)
- def get_dct(n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None):
- """Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``),
- it can be normalized depending on norm. For more information about dct:
- http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
-
- Parameters
- ----------
- n_mfcc : int
- Number of mfccs
- n_mels : int
- Number of mels
- norm : str
- Use "ortho" to get a orthogonal matrix or None, by default "ortho"
- device : str, optional
- Device to load the transformation matrix on, by default None
-
- Returns
- -------
-        torch.Tensor [shape=(n_mels, n_mfcc)]
- The dct transformation matrix.
- """
- from torchaudio.functional import create_dct
-
- return create_dct(n_mfcc, n_mels, norm).to(device)
-
- def mfcc(
- self, n_mfcc: int = 40, n_mels: int = 80, log_offset: float = 1e-6, **kwargs
- ):
- """Computes mel-frequency cepstral coefficients (MFCCs).
-
- Parameters
- ----------
-        n_mfcc : int, optional
-            Number of MFCCs, by default 40
-        n_mels : int, optional
-            Number of mels, by default 80
-        log_offset : float, optional
-            Small value to prevent numerical issues when trying to compute log(0), by default 1e-6
-        kwargs : dict, optional
-            Keyword arguments to self.mel_spectrogram(); note that some of them are forwarded to self.stft()
-
- Returns
- -------
- torch.Tensor [shape=(batch, channels, mfccs, time)]
- MFCCs.
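-
-        Examples
-        --------
-        A minimal sketch (assuming ``librosa`` and ``torchaudio`` are
-        installed):
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> mfccs = signal.mfcc(n_mfcc=13, n_mels=40)
-        >>> mfccs.shape[:3]
-        torch.Size([1, 1, 13])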
- """
-
- mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs)
- mel_spectrogram = torch.log(mel_spectrogram + log_offset)
- dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device)
-
- mfcc = mel_spectrogram.transpose(-1, -2) @ dct_mat
- mfcc = mfcc.transpose(-1, -2)
- return mfcc
-
- @property
- def magnitude(self):
- """Computes and returns the absolute value of the STFT, which
- is the magnitude. This value can also be set to some tensor.
- When set, ``self.stft_data`` is manipulated so that its magnitude
- matches what this is set to, and modulated by the phase.
-
- Returns
- -------
- torch.Tensor
- Magnitude of STFT.
-
- Examples
- --------
- >>> signal = AudioSignal(torch.randn(44100), 44100)
- >>> magnitude = signal.magnitude # Computes stft if not computed
- >>> magnitude[magnitude < magnitude.mean()] = 0
- >>> signal.magnitude = magnitude
- >>> signal.istft()
- """
- if self.stft_data is None:
- self.stft()
- return torch.abs(self.stft_data)
-
- @magnitude.setter
- def magnitude(self, value):
- self.stft_data = value * torch.exp(1j * self.phase)
- return
-
- def log_magnitude(
- self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0
- ):
- """Computes the log-magnitude of the spectrogram.
-
- Parameters
- ----------
-        ref_value : float, optional
-            The magnitude is scaled relative to ``ref_value``:
-            ``20 * log10(S / ref_value)``. Zeros in the output correspond to
-            positions where ``S == ref_value``, by default 1.0
-        amin : float, optional
-            Minimum threshold for ``S`` and ``ref_value``, by default 1e-5
-        top_db : float, optional
-            Threshold the output at ``top_db`` below the peak:
-            ``max(20 * log10(S / ref_value)) - top_db``, by default 80.0
-
- Returns
- -------
- torch.Tensor
- Log-magnitude spectrogram
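-
-        Examples
-        --------
-        For instance, referencing the spectrogram to its own peak so the
-        maximum sits at 0 dB:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> log_mag = signal.log_magnitude(ref_value=signal.magnitude.max().item())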
- """
- magnitude = self.magnitude
-
- amin = amin**2
- log_spec = 10.0 * torch.log10(magnitude.pow(2).clamp(min=amin))
- log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
-
- if top_db is not None:
- log_spec = torch.maximum(log_spec, log_spec.max() - top_db)
- return log_spec
-
- @property
- def phase(self):
- """Computes and returns the phase of the STFT.
- This value can also be set to some tensor.
-        When set, ``self.stft_data`` is manipulated so that its phase
-        matches what this is set to, with the original magnitude.
-
- Returns
- -------
- torch.Tensor
- Phase of STFT.
-
- Examples
- --------
- >>> signal = AudioSignal(torch.randn(44100), 44100)
- >>> phase = signal.phase # Computes stft if not computed
- >>> phase[phase < phase.mean()] = 0
- >>> signal.phase = phase
- >>> signal.istft()
- """
- if self.stft_data is None:
- self.stft()
- return torch.angle(self.stft_data)
-
- @phase.setter
- def phase(self, value):
- self.stft_data = self.magnitude * torch.exp(1j * value)
- return
-
- # Operator overloading
- def __add__(self, other):
- new_signal = self.clone()
- new_signal.audio_data += util._get_value(other)
- return new_signal
-
- def __iadd__(self, other):
- self.audio_data += util._get_value(other)
- return self
-
- def __radd__(self, other):
- return self + other
-
- def __sub__(self, other):
- new_signal = self.clone()
- new_signal.audio_data -= util._get_value(other)
- return new_signal
-
- def __isub__(self, other):
- self.audio_data -= util._get_value(other)
- return self
-
- def __mul__(self, other):
- new_signal = self.clone()
- new_signal.audio_data *= util._get_value(other)
- return new_signal
-
- def __imul__(self, other):
- self.audio_data *= util._get_value(other)
- return self
-
- def __rmul__(self, other):
- return self * other
-
- # Representation
- def _info(self):
- dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]"
- info = {
- "duration": f"{dur} seconds",
- "batch_size": self.batch_size,
- "path": self.path_to_file if self.path_to_file else "path unknown",
- "sample_rate": self.sample_rate,
- "num_channels": self.num_channels if self.num_channels else "[unknown]",
- "audio_data.shape": self.audio_data.shape,
- "stft_params": self.stft_params,
- "device": self.device,
- }
-
- return info
-
- def markdown(self):
- """Produces a markdown representation of AudioSignal, in a markdown table.
-
- Returns
- -------
- str
- Markdown representation of AudioSignal.
-
- Examples
- --------
- >>> signal = AudioSignal(torch.randn(44100), 44100)
- >>> print(signal.markdown())
- | Key | Value
- |---|---
- | duration | 1.000 seconds |
- | batch_size | 1 |
- | path | path unknown |
- | sample_rate | 44100 |
- | num_channels | 1 |
- | audio_data.shape | torch.Size([1, 1, 44100]) |
-        | stft_params | STFTParams(window_length=2048, hop_length=512, window_type='hann', match_stride=False, padding_type='reflect') |
- | device | cpu |
- """
- info = self._info()
-
- FORMAT = "| Key | Value \n" "|---|--- \n"
- for k, v in info.items():
- row = f"| {k} | {v} |\n"
- FORMAT += row
- return FORMAT
-
- def __str__(self):
- info = self._info()
-
- desc = ""
- for k, v in info.items():
- desc += f"{k}: {v}\n"
- return desc
-
- def __rich__(self):
- from rich.table import Table
-
- info = self._info()
-
- table = Table(title=f"{self.__class__.__name__}")
- table.add_column("Key", style="green")
- table.add_column("Value", style="cyan")
-
- for k, v in info.items():
- table.add_row(k, str(v))
- return table
-
- # Comparison
- def __eq__(self, other):
- for k, v in list(self.__dict__.items()):
- if torch.is_tensor(v):
- if not torch.allclose(v, other.__dict__[k], atol=1e-6):
- max_error = (v - other.__dict__[k]).abs().max()
- print(f"Max abs error for {k}: {max_error}")
- return False
- return True
-
- # Indexing
- def __getitem__(self, key):
- if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
- assert self.batch_size == 1
- audio_data = self.audio_data
- _loudness = self._loudness
- stft_data = self.stft_data
-
- elif isinstance(key, (bool, int, list, slice, tuple)) or (
- torch.is_tensor(key) and key.ndim <= 1
- ):
- # Indexing only on the batch dimension.
- # Then let's copy over relevant stuff.
- # Future work: make this work for time-indexing
- # as well, using the hop length.
- audio_data = self.audio_data[key]
- _loudness = self._loudness[key] if self._loudness is not None else None
- stft_data = self.stft_data[key] if self.stft_data is not None else None
-
- sources = None
-
- copy = type(self)(audio_data, self.sample_rate, stft_params=self.stft_params)
- copy._loudness = _loudness
- copy._stft_data = stft_data
- copy.sources = sources
-
- return copy
-
- def __setitem__(self, key, value):
- if not isinstance(value, type(self)):
- self.audio_data[key] = value
- return
-
- if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
- assert self.batch_size == 1
- self.audio_data = value.audio_data
- self._loudness = value._loudness
- self.stft_data = value.stft_data
- return
-
- elif isinstance(key, (bool, int, list, slice, tuple)) or (
- torch.is_tensor(key) and key.ndim <= 1
- ):
- if self.audio_data is not None and value.audio_data is not None:
- self.audio_data[key] = value.audio_data
- if self._loudness is not None and value._loudness is not None:
- self._loudness[key] = value._loudness
- if self.stft_data is not None and value.stft_data is not None:
- self.stft_data[key] = value.stft_data
- return
-
- def __ne__(self, other):
- return not self == other
diff --git a/dito/models/ldm/dac/audiotools/core/display.py b/dito/models/ldm/dac/audiotools/core/display.py
deleted file mode 100644
index 66cbcf34cb2cf9fdf8d67ec4418a887eba73f184..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/display.py
+++ /dev/null
@@ -1,194 +0,0 @@
-import inspect
-import typing
-from functools import wraps
-
-from . import util
-
-
-def format_figure(func):
- """Decorator for formatting figures produced by the code below.
- See :py:func:`audiotools.core.util.format_figure` for more.
-
- Parameters
- ----------
- func : Callable
- Plotting function that is decorated by this function.
-
- """
-
- @wraps(func)
- def wrapper(*args, **kwargs):
- f_keys = inspect.signature(util.format_figure).parameters.keys()
- f_kwargs = {}
- for k, v in list(kwargs.items()):
- if k in f_keys:
- kwargs.pop(k)
- f_kwargs[k] = v
- func(*args, **kwargs)
- util.format_figure(**f_kwargs)
-
- return wrapper
-
-
-class DisplayMixin:
- @format_figure
- def specshow(
- self,
- preemphasis: bool = False,
- x_axis: str = "time",
- y_axis: str = "linear",
- n_mels: int = 128,
- **kwargs,
- ):
- """Displays a spectrogram, using ``librosa.display.specshow``.
-
- Parameters
- ----------
- preemphasis : bool, optional
- Whether or not to apply preemphasis, which makes high
- frequency detail easier to see, by default False
- x_axis : str, optional
- How to label the x axis, by default "time"
- y_axis : str, optional
- How to label the y axis, by default "linear"
- n_mels : int, optional
- If displaying a mel spectrogram with ``y_axis = "mel"``,
- this controls the number of mels, by default 128.
- kwargs : dict, optional
- Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
- """
- import librosa
- import librosa.display
-
- # Always re-compute the STFT data before showing it, in case
- # it changed.
- signal = self.clone()
- signal.stft_data = None
-
- if preemphasis:
- signal.preemphasis()
-
- ref = signal.magnitude.max()
- log_mag = signal.log_magnitude(ref_value=ref)
-
- if y_axis == "mel":
- log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10()
- log_mag -= log_mag.max()
-
- librosa.display.specshow(
- log_mag.numpy()[0].mean(axis=0),
- x_axis=x_axis,
- y_axis=y_axis,
- sr=signal.sample_rate,
- **kwargs,
- )
-
- @format_figure
- def waveplot(self, x_axis: str = "time", **kwargs):
- """Displays a waveform plot, using ``librosa.display.waveshow``.
-
- Parameters
- ----------
- x_axis : str, optional
- How to label the x axis, by default "time"
- kwargs : dict, optional
- Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
- """
- import librosa
- import librosa.display
-
- audio_data = self.audio_data[0].mean(dim=0)
- audio_data = audio_data.cpu().numpy()
-
- plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot"
- wave_plot_fn = getattr(librosa.display, plot_fn)
- wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)
-
- @format_figure
- def wavespec(self, x_axis: str = "time", **kwargs):
- """Displays a waveform plot, using ``librosa.display.waveshow``.
-
- Parameters
- ----------
- x_axis : str, optional
- How to label the x axis, by default "time"
- kwargs : dict, optional
- Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`.
- """
- import matplotlib.pyplot as plt
- from matplotlib.gridspec import GridSpec
-
- gs = GridSpec(6, 1)
- plt.subplot(gs[0, :])
- self.waveplot(x_axis=x_axis)
- plt.subplot(gs[1:, :])
- self.specshow(x_axis=x_axis, **kwargs)
-
- def write_audio_to_tb(
- self,
- tag: str,
- writer,
- step: int = None,
- plot_fn: typing.Union[typing.Callable, str] = "specshow",
- **kwargs,
- ):
- """Writes a signal and its spectrogram to Tensorboard. Will show up
- under the Audio and Images tab in Tensorboard.
-
- Parameters
- ----------
- tag : str
- Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be
- written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``).
- writer : SummaryWriter
- A SummaryWriter object from PyTorch library.
- step : int, optional
- The step to write the signal to, by default None
- plot_fn : typing.Union[typing.Callable, str], optional
- How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
- kwargs : dict, optional
- Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
- whatever ``plot_fn`` is set to.
- """
- import matplotlib.pyplot as plt
-
- audio_data = self.audio_data[0, 0].detach().cpu()
- sample_rate = self.sample_rate
- writer.add_audio(tag, audio_data, step, sample_rate)
-
- if plot_fn is not None:
- if isinstance(plot_fn, str):
- plot_fn = getattr(self, plot_fn)
- fig = plt.figure()
- plt.clf()
- plot_fn(**kwargs)
- writer.add_figure(tag.replace("wav", "png"), fig, step)
-
- def save_image(
- self,
- image_path: str,
- plot_fn: typing.Union[typing.Callable, str] = "specshow",
- **kwargs,
- ):
- """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
- a specified file.
-
- Parameters
- ----------
- image_path : str
- Where to save the file to.
- plot_fn : typing.Union[typing.Callable, str], optional
- How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
- kwargs : dict, optional
- Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
- whatever ``plot_fn`` is set to.
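-
-        Examples
-        --------
-        A sketch (assuming ``matplotlib`` and ``librosa`` are installed;
-        ``spec.png`` is a hypothetical output path):
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> signal.save_image("spec.png")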
- """
- import matplotlib.pyplot as plt
-
- if isinstance(plot_fn, str):
- plot_fn = getattr(self, plot_fn)
-
- plt.clf()
- plot_fn(**kwargs)
- plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
- plt.close()
diff --git a/dito/models/ldm/dac/audiotools/core/dsp.py b/dito/models/ldm/dac/audiotools/core/dsp.py
deleted file mode 100644
index f9be51a119537b77e497ddc2dac126d569533d7c..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/dsp.py
+++ /dev/null
@@ -1,390 +0,0 @@
-import typing
-
-import julius
-import numpy as np
-import torch
-
-from . import util
-
-
-class DSPMixin:
- _original_batch_size = None
- _original_num_channels = None
- _padded_signal_length = None
-
- def _preprocess_signal_for_windowing(self, window_duration, hop_duration):
- self._original_batch_size = self.batch_size
- self._original_num_channels = self.num_channels
-
- window_length = int(window_duration * self.sample_rate)
- hop_length = int(hop_duration * self.sample_rate)
-
- if window_length % hop_length != 0:
- factor = window_length // hop_length
- window_length = factor * hop_length
-
- self.zero_pad(hop_length, hop_length)
- self._padded_signal_length = self.signal_length
-
- return window_length, hop_length
-
- def windows(
- self, window_duration: float, hop_duration: float, preprocess: bool = True
- ):
- """Generator which yields windows of specified duration from signal with a specified
- hop length.
-
- Parameters
- ----------
- window_duration : float
- Duration of every window in seconds.
- hop_duration : float
- Hop between windows in seconds.
- preprocess : bool, optional
- Whether to preprocess the signal, so that the first sample is in
- the middle of the first window, by default True
-
- Yields
- ------
- AudioSignal
- Each window is returned as an AudioSignal.
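-
-        Examples
-        --------
-        A small sketch, iterating over half-second windows with a
-        quarter-second hop (``windows`` mutates the signal, so clone first):
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> durations = [w.signal_duration for w in signal.clone().windows(0.5, 0.25)]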
- """
- if preprocess:
- window_length, hop_length = self._preprocess_signal_for_windowing(
- window_duration, hop_duration
- )
-
- self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length)
-
- for b in range(self.batch_size):
- i = 0
- while True:
- start_idx = i * hop_length
- i += 1
- end_idx = start_idx + window_length
- if end_idx > self.signal_length:
- break
- yield self[b, ..., start_idx:end_idx]
-
- def collect_windows(
- self, window_duration: float, hop_duration: float, preprocess: bool = True
- ):
- """Reshapes signal into windows of specified duration from signal with a specified
-        hop length. Windows are placed along the batch dimension. Use with
- :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the
- original signal.
-
- Parameters
- ----------
- window_duration : float
- Duration of every window in seconds.
- hop_duration : float
- Hop between windows in seconds.
- preprocess : bool, optional
- Whether to preprocess the signal, so that the first sample is in
- the middle of the first window, by default True
-
- Returns
- -------
- AudioSignal
- AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)``
- """
- if preprocess:
- window_length, hop_length = self._preprocess_signal_for_windowing(
- window_duration, hop_duration
- )
-
- # self.audio_data: (nb, nch, nt).
- unfolded = torch.nn.functional.unfold(
- self.audio_data.reshape(-1, 1, 1, self.signal_length),
- kernel_size=(1, window_length),
- stride=(1, hop_length),
- )
- # unfolded: (nb * nch, window_length, num_windows).
- # -> (nb * nch * num_windows, 1, window_length)
- unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length)
- self.audio_data = unfolded
- return self
-
- def overlap_and_add(self, hop_duration: float):
- """Function which takes a list of windows and overlap adds them into a
- signal the same length as ``audio_signal``.
-
- Parameters
- ----------
- hop_duration : float
- How much to shift for each window
- (overlap is window_duration - hop_duration) in seconds.
-
- Returns
- -------
- AudioSignal
- overlap-and-added signal.
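-
-        Examples
-        --------
-        A round-trip sketch with ``collect_windows``:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> windowed = signal.clone().collect_windows(0.5, 0.25)
-        >>> restored = windowed.overlap_and_add(0.25)
-        >>> restored.signal_length == signal.signal_length
-        True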
- """
- hop_length = int(hop_duration * self.sample_rate)
- window_length = self.signal_length
-
- nb, nch = self._original_batch_size, self._original_num_channels
-
- unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1)
- folded = torch.nn.functional.fold(
- unfolded,
- output_size=(1, self._padded_signal_length),
- kernel_size=(1, window_length),
- stride=(1, hop_length),
- )
-
- norm = torch.ones_like(unfolded, device=unfolded.device)
- norm = torch.nn.functional.fold(
- norm,
- output_size=(1, self._padded_signal_length),
- kernel_size=(1, window_length),
- stride=(1, hop_length),
- )
-
- folded = folded / norm
-
- folded = folded.reshape(nb, nch, -1)
- self.audio_data = folded
- self.trim(hop_length, hop_length)
- return self
-
- def low_pass(
- self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
- ):
- """Low-passes the signal in-place. Each item in the batch
- can have a different low-pass cutoff, if the input
- to this signal is an array or tensor. If a float, all
- items are given the same low-pass filter.
-
- Parameters
- ----------
- cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
- Cutoff in Hz of low-pass filter.
- zeros : int, optional
- Number of taps to use in low-pass filter, by default 51
-
- Returns
- -------
- AudioSignal
- Low-passed AudioSignal.
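-
-        Examples
-        --------
-        For instance, applying the same 4 kHz cutoff to every item in the
-        batch (assuming ``julius`` is installed):
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> signal = signal.low_pass(4000)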
- """
- cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
- cutoffs = cutoffs / self.sample_rate
- filtered = torch.empty_like(self.audio_data)
-
- for i, cutoff in enumerate(cutoffs):
- lp_filter = julius.LowPassFilter(cutoff.cpu(), zeros=zeros).to(self.device)
- filtered[i] = lp_filter(self.audio_data[i])
-
- self.audio_data = filtered
- self.stft_data = None
- return self
-
- def high_pass(
- self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
- ):
- """High-passes the signal in-place. Each item in the batch
- can have a different high-pass cutoff, if the input
- to this signal is an array or tensor. If a float, all
- items are given the same high-pass filter.
-
- Parameters
- ----------
- cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
- Cutoff in Hz of high-pass filter.
- zeros : int, optional
- Number of taps to use in high-pass filter, by default 51
-
- Returns
- -------
- AudioSignal
- High-passed AudioSignal.
- """
- cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
- cutoffs = cutoffs / self.sample_rate
- filtered = torch.empty_like(self.audio_data)
-
- for i, cutoff in enumerate(cutoffs):
- hp_filter = julius.HighPassFilter(cutoff.cpu(), zeros=zeros).to(self.device)
- filtered[i] = hp_filter(self.audio_data[i])
-
- self.audio_data = filtered
- self.stft_data = None
- return self
-
- def mask_frequencies(
- self,
- fmin_hz: typing.Union[torch.Tensor, np.ndarray, float],
- fmax_hz: typing.Union[torch.Tensor, np.ndarray, float],
- val: float = 0.0,
- ):
- """Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them
- with the value specified by ``val``. Useful for implementing SpecAug.
- The min and max can be different for every item in the batch.
-
- Parameters
- ----------
- fmin_hz : typing.Union[torch.Tensor, np.ndarray, float]
- Lower end of band to mask out.
- fmax_hz : typing.Union[torch.Tensor, np.ndarray, float]
- Upper end of band to mask out.
- val : float, optional
- Value to fill in, by default 0.0
-
- Returns
- -------
- AudioSignal
- Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
- masked audio data.
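-
-        Examples
-        --------
-        A SpecAug-style sketch, masking the band between 500 Hz and 2 kHz:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> signal = signal.mask_frequencies(fmin_hz=500, fmax_hz=2000).istft()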
- """
- # SpecAug
- mag, phase = self.magnitude, self.phase
- fmin_hz = util.ensure_tensor(fmin_hz, ndim=mag.ndim)
- fmax_hz = util.ensure_tensor(fmax_hz, ndim=mag.ndim)
- assert torch.all(fmin_hz < fmax_hz)
-
- # build mask
- nbins = mag.shape[-2]
- bins_hz = torch.linspace(0, self.sample_rate / 2, nbins, device=self.device)
- bins_hz = bins_hz[None, None, :, None].repeat(
- self.batch_size, 1, 1, mag.shape[-1]
- )
- mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz)
- mask = mask.to(self.device)
-
- mag = mag.masked_fill(mask, val)
- phase = phase.masked_fill(mask, val)
- self.stft_data = mag * torch.exp(1j * phase)
- return self
-
- def mask_timesteps(
- self,
- tmin_s: typing.Union[torch.Tensor, np.ndarray, float],
- tmax_s: typing.Union[torch.Tensor, np.ndarray, float],
- val: float = 0.0,
- ):
- """Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them
- with the value specified by ``val``. Useful for implementing SpecAug.
- The min and max can be different for every item in the batch.
-
- Parameters
- ----------
- tmin_s : typing.Union[torch.Tensor, np.ndarray, float]
- Lower end of timesteps to mask out.
- tmax_s : typing.Union[torch.Tensor, np.ndarray, float]
- Upper end of timesteps to mask out.
- val : float, optional
- Value to fill in, by default 0.0
-
- Returns
- -------
- AudioSignal
- Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
- masked audio data.
- """
- # SpecAug
- mag, phase = self.magnitude, self.phase
- tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim)
- tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim)
-
- assert torch.all(tmin_s < tmax_s)
-
- # build mask
- nt = mag.shape[-1]
- bins_t = torch.linspace(0, self.signal_duration, nt, device=self.device)
- bins_t = bins_t[None, None, None, :].repeat(
- self.batch_size, 1, mag.shape[-2], 1
- )
- mask = (tmin_s <= bins_t) & (bins_t < tmax_s)
-
- mag = mag.masked_fill(mask, val)
- phase = phase.masked_fill(mask, val)
- self.stft_data = mag * torch.exp(1j * phase)
- return self
-
- def mask_low_magnitudes(
- self, db_cutoff: typing.Union[torch.Tensor, np.ndarray, float], val: float = 0.0
- ):
- """Mask away magnitudes below a specified threshold, which
- can be different for every item in the batch.
-
- Parameters
- ----------
- db_cutoff : typing.Union[torch.Tensor, np.ndarray, float]
-            Decibel value below which magnitudes are masked away.
- val : float, optional
- Value to fill in for masked portions, by default 0.0
-
- Returns
- -------
- AudioSignal
- Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
- masked audio data.
- """
- mag = self.magnitude
- log_mag = self.log_magnitude()
-
- db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim)
- mask = log_mag < db_cutoff
- mag = mag.masked_fill(mask, val)
-
- self.magnitude = mag
- return self
-
- def shift_phase(self, shift: typing.Union[torch.Tensor, np.ndarray, float]):
- """Shifts the phase by a constant value.
-
- Parameters
- ----------
- shift : typing.Union[torch.Tensor, np.ndarray, float]
- What to shift the phase by.
-
- Returns
- -------
- AudioSignal
- Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
- masked audio data.
- """
- shift = util.ensure_tensor(shift, ndim=self.phase.ndim)
- self.phase = self.phase + shift
- return self
-
- def corrupt_phase(self, scale: typing.Union[torch.Tensor, np.ndarray, float]):
- """Corrupts the phase randomly by some scaled value.
-
- Parameters
- ----------
- scale : typing.Union[torch.Tensor, np.ndarray, float]
- Standard deviation of noise to add to the phase.
-
- Returns
- -------
- AudioSignal
- Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
- masked audio data.
- """
- scale = util.ensure_tensor(scale, ndim=self.phase.ndim)
- self.phase = self.phase + scale * torch.randn_like(self.phase)
- return self
-
- def preemphasis(self, coef: float = 0.85):
- """Applies pre-emphasis to audio signal.
-
- Parameters
- ----------
- coef : float, optional
-            How much pre-emphasis to apply; lower values do less, and 0
-            does nothing. By default 0.85
-
- Returns
- -------
- AudioSignal
- Pre-emphasized signal.
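-
-        Examples
-        --------
-        A quick sketch, boosting high-frequency content before analysis:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> signal = signal.preemphasis(coef=0.97)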
- """
- kernel = torch.tensor([1, -coef, 0]).view(1, 1, -1).to(self.device)
- x = self.audio_data.reshape(-1, 1, self.signal_length)
- x = torch.nn.functional.conv1d(x, kernel, padding=1)
- self.audio_data = x.reshape(*self.audio_data.shape)
- return self
diff --git a/dito/models/ldm/dac/audiotools/core/effects.py b/dito/models/ldm/dac/audiotools/core/effects.py
deleted file mode 100644
index fb534cbcb2d457575de685fc9248d1716879145b..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/effects.py
+++ /dev/null
@@ -1,647 +0,0 @@
-import typing
-
-import julius
-import numpy as np
-import torch
-import torchaudio
-
-from . import util
-
-
-class EffectMixin:
- GAIN_FACTOR = np.log(10) / 20
- """Gain factor for converting between amplitude and decibels."""
- CODEC_PRESETS = {
- "8-bit": {"format": "wav", "encoding": "ULAW", "bits_per_sample": 8},
- "GSM-FR": {"format": "gsm"},
- "MP3": {"format": "mp3", "compression": -9},
- "Vorbis": {"format": "vorbis", "compression": -1},
- "Ogg": {
- "format": "ogg",
- "compression": -1,
- },
- "Amr-nb": {"format": "amr-nb"},
- }
- """Presets for applying codecs via torchaudio."""
-
- def mix(
- self,
- other,
- snr: typing.Union[torch.Tensor, np.ndarray, float] = 10,
- other_eq: typing.Union[torch.Tensor, np.ndarray] = None,
- ):
- """Mixes noise with signal at specified
- signal-to-noise ratio. Optionally, the
- other signal can be equalized in-place.
-
-
- Parameters
- ----------
- other : AudioSignal
- AudioSignal object to mix with.
- snr : typing.Union[torch.Tensor, np.ndarray, float], optional
- Signal to noise ratio, by default 10
- other_eq : typing.Union[torch.Tensor, np.ndarray], optional
- EQ curve to apply to other signal, if any, by default None
-
- Returns
- -------
- AudioSignal
- In-place modification of AudioSignal.
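-
-        Examples
-        --------
-        A small sketch, mixing in noise 6 dB below the signal's loudness:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> noise = AudioSignal(torch.randn(44100), 44100)
-        >>> signal = signal.mix(noise, snr=6)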
- """
- snr = util.ensure_tensor(snr).to(self.device)
-
- pad_len = max(0, self.signal_length - other.signal_length)
- other.zero_pad(0, pad_len)
- other.truncate_samples(self.signal_length)
- if other_eq is not None:
- other = other.equalizer(other_eq)
-
- tgt_loudness = self.loudness() - snr
- other = other.normalize(tgt_loudness)
-
- self.audio_data = self.audio_data + other.audio_data
- return self
-
- def convolve(self, other, start_at_max: bool = True):
- """Convolves self with other.
- This function uses FFTs to do the convolution.
-
- Parameters
- ----------
- other : AudioSignal
- Signal to convolve with.
- start_at_max : bool, optional
- Whether to start at the max value of other signal, to
- avoid inducing delays, by default True
-
- Returns
- -------
- AudioSignal
- Convolved signal, in-place.
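-
-        Examples
-        --------
-        A small sketch, convolving a signal with a shorter impulse response:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> ir = AudioSignal(torch.randn(4410), 44100)
-        >>> signal = signal.convolve(ir)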
- """
- from . import AudioSignal
-
- pad_len = self.signal_length - other.signal_length
-
- if pad_len > 0:
- other.zero_pad(0, pad_len)
- else:
- other.truncate_samples(self.signal_length)
-
- if start_at_max:
- # Use roll to rotate over the max for every item
- # so that the impulse responses don't induce any
- # delay.
- idx = other.audio_data.abs().argmax(axis=-1)
- irs = torch.zeros_like(other.audio_data)
- for i in range(other.batch_size):
- irs[i] = torch.roll(other.audio_data[i], -idx[i].item(), -1)
- other = AudioSignal(irs, other.sample_rate)
-
- delta = torch.zeros_like(other.audio_data)
- delta[..., 0] = 1
-
- length = self.signal_length
- delta_fft = torch.fft.rfft(delta, length)
- other_fft = torch.fft.rfft(other.audio_data, length)
- self_fft = torch.fft.rfft(self.audio_data, length)
-
- convolved_fft = other_fft * self_fft
- convolved_audio = torch.fft.irfft(convolved_fft, length)
-
- delta_convolved_fft = other_fft * delta_fft
- delta_audio = torch.fft.irfft(delta_convolved_fft, length)
-
- # Use the delta to rescale the audio exactly as needed.
- delta_max = delta_audio.abs().max(dim=-1, keepdims=True)[0]
- scale = 1 / delta_max.clamp(1e-5)
- convolved_audio = convolved_audio * scale
-
- self.audio_data = convolved_audio
-
- return self
-
- def apply_ir(
- self,
- ir,
- drr: typing.Union[torch.Tensor, np.ndarray, float] = None,
- ir_eq: typing.Union[torch.Tensor, np.ndarray] = None,
- use_original_phase: bool = False,
- ):
- """Applies an impulse response to the signal. If ` is`ir_eq``
- is specified, the impulse response is equalized before
- it is applied, using the given curve.
-
- Parameters
- ----------
- ir : AudioSignal
- Impulse response to convolve with.
- drr : typing.Union[torch.Tensor, np.ndarray, float], optional
- Direct-to-reverberant ratio that impulse response will be
- altered to, if specified, by default None
- ir_eq : typing.Union[torch.Tensor, np.ndarray], optional
- Equalization that will be applied to impulse response
- if specified, by default None
- use_original_phase : bool, optional
- Whether to use the original phase, instead of the convolved
- phase, by default False
-
- Returns
- -------
- AudioSignal
- Signal with impulse response applied to it
- """
- if ir_eq is not None:
- ir = ir.equalizer(ir_eq)
- if drr is not None:
- ir = ir.alter_drr(drr)
-
- # Save the peak before
- max_spk = self.audio_data.abs().max(dim=-1, keepdims=True).values
-
- # Augment the impulse response to simulate microphone effects
- # and with varying direct-to-reverberant ratio.
- phase = self.phase
- self.convolve(ir)
-
- # Use the input phase
- if use_original_phase:
- self.stft()
- self.stft_data = self.magnitude * torch.exp(1j * phase)
- self.istft()
-
- # Rescale to the input's amplitude
- max_transformed = self.audio_data.abs().max(dim=-1, keepdims=True).values
- scale_factor = max_spk.clamp(1e-8) / max_transformed.clamp(1e-8)
- self = self * scale_factor
-
- return self
-
- def ensure_max_of_audio(self, max: float = 1.0):
- """Ensures that ``abs(audio_data) <= max``.
-
- Parameters
- ----------
- max : float, optional
- Max absolute value of signal, by default 1.0
-
- Returns
- -------
- AudioSignal
- Signal with values scaled between -max and max.
- """
- peak = self.audio_data.abs().max(dim=-1, keepdims=True)[0]
- peak_gain = torch.ones_like(peak)
- peak_gain[peak > max] = max / peak[peak > max]
- self.audio_data = self.audio_data * peak_gain
- return self
-
- def normalize(self, db: typing.Union[torch.Tensor, np.ndarray, float] = -24.0):
- """Normalizes the signal's volume to the specified db, in LUFS.
- This is GPU-compatible, making for very fast loudness normalization.
-
- Parameters
- ----------
- db : typing.Union[torch.Tensor, np.ndarray, float], optional
- Loudness to normalize to, by default -24.0
-
- Returns
- -------
- AudioSignal
- Normalized audio signal.
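-
-        Examples
-        --------
-        For instance, normalizing to -16 LUFS:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> signal = signal.normalize(-16)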
- """
- db = util.ensure_tensor(db).to(self.device)
- ref_db = self.loudness()
- gain = db - ref_db
- gain = torch.exp(gain * self.GAIN_FACTOR)
-
- self.audio_data = self.audio_data * gain[:, None, None]
- return self
-
- def volume_change(self, db: typing.Union[torch.Tensor, np.ndarray, float]):
- """Change volume of signal by some amount, in dB.
-
- Parameters
- ----------
- db : typing.Union[torch.Tensor, np.ndarray, float]
- Amount to change volume by.
-
- Returns
- -------
- AudioSignal
- Signal at new volume.
- """
- db = util.ensure_tensor(db, ndim=1).to(self.device)
- gain = torch.exp(db * self.GAIN_FACTOR)
- self.audio_data = self.audio_data * gain[:, None, None]
- return self
-
- def _to_2d(self):
- waveform = self.audio_data.reshape(-1, self.signal_length)
- return waveform
-
- def _to_3d(self, waveform):
- return waveform.reshape(self.batch_size, self.num_channels, -1)
-
- def pitch_shift(self, n_semitones: int, quick: bool = True):
- """Pitch shift the signal. All items in the batch
- get the same pitch shift.
-
- Parameters
- ----------
- n_semitones : int
- How many semitones to shift the signal by.
- quick : bool, optional
- Using quick pitch shifting, by default True
-
- Returns
- -------
- AudioSignal
- Pitch shifted audio signal.
- """
- device = self.device
- effects = [
- ["pitch", str(n_semitones * 100)],
- ["rate", str(self.sample_rate)],
- ]
- if quick:
- effects[0].insert(1, "-q")
-
- waveform = self._to_2d().cpu()
- waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
- waveform, self.sample_rate, effects, channels_first=True
- )
- self.sample_rate = sample_rate
- self.audio_data = self._to_3d(waveform)
- return self.to(device)
-
- def time_stretch(self, factor: float, quick: bool = True):
- """Time stretch the audio signal.
-
- Parameters
- ----------
- factor : float
- Factor by which to stretch the AudioSignal. Typically
- between 0.8 and 1.2.
- quick : bool, optional
- Whether to use quick time stretching, by default True
-
- Returns
- -------
- AudioSignal
- Time-stretched AudioSignal.
- """
- device = self.device
- effects = [
- ["tempo", str(factor)],
- ["rate", str(self.sample_rate)],
- ]
- if quick:
- effects[0].insert(1, "-q")
-
- waveform = self._to_2d().cpu()
- waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
- waveform, self.sample_rate, effects, channels_first=True
- )
- self.sample_rate = sample_rate
- self.audio_data = self._to_3d(waveform)
- return self.to(device)
-
- def apply_codec(
- self,
- preset: str = None,
- format: str = "wav",
- encoding: str = None,
- bits_per_sample: int = None,
- compression: int = None,
- ): # pragma: no cover
- """Applies an audio codec to the signal.
-
- Parameters
- ----------
- preset : str, optional
- One of the keys in ``self.CODEC_PRESETS``, by default None
- format : str, optional
- Format for audio codec, by default "wav"
- encoding : str, optional
- Encoding to use, by default None
- bits_per_sample : int, optional
- How many bits per sample, by default None
- compression : int, optional
- Compression amount of codec, by default None
-
- Returns
- -------
- AudioSignal
- AudioSignal with codec applied.
-
- Raises
- ------
- ValueError
- If preset is not in ``self.CODEC_PRESETS``, an error
- is thrown.
- """
- torchaudio_version_070 = "0.7" in torchaudio.__version__
- if torchaudio_version_070:
- return self
-
- kwargs = {
- "format": format,
- "encoding": encoding,
- "bits_per_sample": bits_per_sample,
- "compression": compression,
- }
-
- if preset is not None:
- if preset in self.CODEC_PRESETS:
- kwargs = self.CODEC_PRESETS[preset]
- else:
- raise ValueError(
- f"Unknown preset: {preset}. "
- f"Known presets: {list(self.CODEC_PRESETS.keys())}"
- )
-
- waveform = self._to_2d()
- if kwargs["format"] in ["vorbis", "mp3", "ogg", "amr-nb"]:
- # Apply it in a for loop
- augmented = torch.cat(
- [
- torchaudio.functional.apply_codec(
- waveform[i][None, :], self.sample_rate, **kwargs
- )
- for i in range(waveform.shape[0])
- ],
- dim=0,
- )
- else:
- augmented = torchaudio.functional.apply_codec(
- waveform, self.sample_rate, **kwargs
- )
- augmented = self._to_3d(augmented)
-
- self.audio_data = augmented
- return self
-
- def mel_filterbank(self, n_bands: int):
- """Breaks signal into mel bands.
-
- Parameters
- ----------
- n_bands : int
- Number of mel bands to use.
-
- Returns
- -------
- torch.Tensor
- Mel-filtered bands, with last axis being the band index.
- """
- filterbank = (
- julius.SplitBands(self.sample_rate, n_bands).float().to(self.device)
- )
- filtered = filterbank(self.audio_data)
- return filtered.permute(1, 2, 3, 0)
-
- def equalizer(self, db: typing.Union[torch.Tensor, np.ndarray]):
- """Applies a mel-spaced equalizer to the audio signal.
-
- Parameters
- ----------
- db : typing.Union[torch.Tensor, np.ndarray]
- EQ curve to apply.
-
- Returns
- -------
- AudioSignal
- AudioSignal with equalization applied.
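-
-        Examples
-        --------
-        A sketch with a flat 6-band curve; 0 dB in every band leaves the
-        signal unchanged up to filterbank reconstruction error:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> signal = signal.equalizer(torch.zeros(6))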
- """
- db = util.ensure_tensor(db)
- n_bands = db.shape[-1]
- fbank = self.mel_filterbank(n_bands)
-
- # If there's a batch dimension, make sure it's the same.
- if db.ndim == 2:
- if db.shape[0] != 1:
- assert db.shape[0] == fbank.shape[0]
- else:
- db = db.unsqueeze(0)
-
- weights = (10**db).to(self.device).float()
- fbank = fbank * weights[:, None, None, :]
- eq_audio_data = fbank.sum(-1)
- self.audio_data = eq_audio_data
- return self
-
- def clip_distortion(
- self, clip_percentile: typing.Union[torch.Tensor, np.ndarray, float]
- ):
- """Clips the signal at a given percentile. The higher it is,
- the lower the threshold for clipping.
-
- Parameters
- ----------
- clip_percentile : typing.Union[torch.Tensor, np.ndarray, float]
- Values are between 0.0 to 1.0. Typical values are 0.1 or below.
-
- Returns
- -------
- AudioSignal
- Audio signal with clipped audio data.
- """
- clip_percentile = util.ensure_tensor(clip_percentile, ndim=1)
- min_thresh = torch.quantile(self.audio_data, clip_percentile / 2, dim=-1)
- max_thresh = torch.quantile(self.audio_data, 1 - (clip_percentile / 2), dim=-1)
-
- nc = self.audio_data.shape[1]
- min_thresh = min_thresh[:, :nc, :]
- max_thresh = max_thresh[:, :nc, :]
-
- self.audio_data = self.audio_data.clamp(min_thresh, max_thresh)
-
- return self
-
- def quantization(
- self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int]
- ):
- """Applies quantization to the input waveform.
-
- Parameters
- ----------
- quantization_channels : typing.Union[torch.Tensor, np.ndarray, int]
- Number of evenly spaced quantization channels to quantize
- to.
-
- Returns
- -------
- AudioSignal
- Quantized AudioSignal.
- """
- quantization_channels = util.ensure_tensor(quantization_channels, ndim=3)
-
- x = self.audio_data
- x = (x + 1) / 2
- x = x * quantization_channels
- x = x.floor()
- x = x / quantization_channels
- x = 2 * x - 1
-
- residual = (self.audio_data - x).detach()
- self.audio_data = self.audio_data - residual
- return self
-
- def mulaw_quantization(
- self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int]
- ):
- """Applies mu-law quantization to the input waveform.
-
- Parameters
- ----------
- quantization_channels : typing.Union[torch.Tensor, np.ndarray, int]
- Number of mu-law spaced quantization channels to quantize
- to.
-
- Returns
- -------
- AudioSignal
- Quantized AudioSignal.
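-
-        Examples
-        --------
-        A straight-through sketch with 256 mu-law channels:
-
-        >>> signal = AudioSignal(torch.randn(44100), 44100)
-        >>> signal = signal.mulaw_quantization(256)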
- """
- mu = quantization_channels - 1.0
- mu = util.ensure_tensor(mu, ndim=3)
-
- x = self.audio_data
-
- # quantize
- x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
- x = ((x + 1) / 2 * mu + 0.5).to(torch.int64)
-
- # unquantize
- x = (x / mu) * 2 - 1.0
- x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
-
- residual = (self.audio_data - x).detach()
- self.audio_data = self.audio_data - residual
- return self
-
- def __matmul__(self, other):
- return self.convolve(other)
-
-
-class ImpulseResponseMixin:
- """These functions are generally only used with AudioSignals that are derived
- from impulse responses, not other sources like music or speech. These methods
- are used to replicate the data augmentation described in [1].
-
- 1. Bryan, Nicholas J. "Impulse response data augmentation and deep
- neural networks for blind room acoustic parameter estimation."
- ICASSP 2020-2020 IEEE International Conference on Acoustics,
- Speech and Signal Processing (ICASSP). IEEE, 2020.
- """
-
- def decompose_ir(self):
- """Decomposes an impulse response into early and late
- field responses.
- """
- # Equations 1 and 2
- # -----------------
- # Breaking up into early
- # response + late field response.
-
- td = torch.argmax(self.audio_data, dim=-1, keepdim=True)
- t0 = int(self.sample_rate * 0.0025)
-
- idx = torch.arange(self.audio_data.shape[-1], device=self.device)[None, None, :]
- idx = idx.expand(self.batch_size, -1, -1)
- early_idx = (idx >= td - t0) * (idx <= td + t0)
-
- early_response = torch.zeros_like(self.audio_data, device=self.device)
- early_response[early_idx] = self.audio_data[early_idx]
-
- late_idx = ~early_idx
- late_field = torch.zeros_like(self.audio_data, device=self.device)
- late_field[late_idx] = self.audio_data[late_idx]
-
- # Equation 4
- # ----------
- # Decompose early response into windowed
- # direct path and windowed residual.
-
- window = torch.zeros_like(self.audio_data, device=self.device)
- for idx in range(self.batch_size):
- window_idx = early_idx[idx, 0].nonzero()
- window[idx, ..., window_idx] = self.get_window(
- "hann", window_idx.shape[-1], self.device
- )
- return early_response, late_field, window
-
- def measure_drr(self):
- """Measures the direct-to-reverberant ratio of the impulse
- response.
-
- Returns
- -------
-        torch.Tensor
- Direct-to-reverberant ratio
- """
- early_response, late_field, _ = self.decompose_ir()
- num = (early_response**2).sum(dim=-1)
- den = (late_field**2).sum(dim=-1)
- drr = 10 * torch.log10(num / den)
- return drr
-
- @staticmethod
- def solve_alpha(early_response, late_field, wd, target_drr):
- """Used to solve for the alpha value, which is used
- to alter the drr.
- """
- # Equation 5
- # ----------
- # Apply the good ol' quadratic formula.
-
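- # The sums below are the coefficients of a * alpha**2 + b * alpha + c = 0,
- # obtained by setting the energy of (alpha * wd + (1 - wd)) * early_response
- # relative to the late-field energy equal to the target DRR.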
- wd_sq = wd**2
- wd_sq_1 = (1 - wd) ** 2
- e_sq = early_response**2
- l_sq = late_field**2
- a = (wd_sq * e_sq).sum(dim=-1)
- b = (2 * (1 - wd) * wd * e_sq).sum(dim=-1)
- c = (wd_sq_1 * e_sq).sum(dim=-1) - torch.pow(10, target_drr / 10) * l_sq.sum(
- dim=-1
- )
-
- expr = ((b**2) - 4 * a * c).sqrt()
- alpha = torch.maximum(
- (-b - expr) / (2 * a),
- (-b + expr) / (2 * a),
- )
- return alpha
-
- def alter_drr(self, drr: typing.Union[torch.Tensor, np.ndarray, float]):
- """Alters the direct-to-reverberant ratio of the impulse response.
-
- Parameters
- ----------
- drr : typing.Union[torch.Tensor, np.ndarray, float]
- Target direct-to-reverberant ratio in dB that the impulse
- response will be altered to.
-
- Returns
- -------
- AudioSignal
- Altered impulse response.
- """
- drr = util.ensure_tensor(drr, 2, self.batch_size).to(self.device)
-
- early_response, late_field, window = self.decompose_ir()
- alpha = self.solve_alpha(early_response, late_field, window, drr)
- min_alpha = (
- late_field.abs().max(dim=-1)[0] / early_response.abs().max(dim=-1)[0]
- )
- alpha = torch.maximum(alpha, min_alpha)[..., None]
-
- aug_ir_data = (
- alpha * window * early_response
- + ((1 - window) * early_response)
- + late_field
- )
- self.audio_data = aug_ir_data
- self.ensure_max_of_audio()
- return self
diff --git a/dito/models/ldm/dac/audiotools/core/ffmpeg.py b/dito/models/ldm/dac/audiotools/core/ffmpeg.py
deleted file mode 100644
index 83f9cd197d7dc8748a16be77614cc593a6a33297..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/ffmpeg.py
+++ /dev/null
@@ -1,211 +0,0 @@
-import json
-import shlex
-import subprocess
-import tempfile
-from pathlib import Path
-from typing import Tuple
-
-import ffmpy
-import numpy as np
-import torch
-
-
-def r128stats(filepath: str, quiet: bool):
- """Takes a path to an audio file, returns a dict with the loudness
- stats computed by the ffmpeg ebur128 filter.
-
- Parameters
- ----------
- filepath : str
- Path to compute loudness stats on.
- quiet : bool
- Whether to suppress FFMPEG output during computation.
-
- Returns
- -------
- dict
- Dictionary containing loudness stats.
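-
- Examples
- --------
- A hypothetical sketch (path is illustrative; requires ffmpeg on PATH):
-
- >>> stats = r128stats("speech.wav", quiet=True)
- >>> stats["I"]  # integrated loudness in LUFS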
- """
- ffargs = [
- "ffmpeg",
- "-nostats",
- "-i",
- filepath,
- "-filter_complex",
- "ebur128",
- "-f",
- "null",
- "-",
- ]
- if quiet:
- ffargs += ["-hide_banner"]
- proc = subprocess.Popen(ffargs, stderr=subprocess.PIPE, universal_newlines=True)
- stats = proc.communicate()[1]
- summary_index = stats.rfind("Summary:")
-
- summary_list = stats[summary_index:].split()
- i_lufs = float(summary_list[summary_list.index("I:") + 1])
- i_thresh = float(summary_list[summary_list.index("I:") + 4])
- lra = float(summary_list[summary_list.index("LRA:") + 1])
- lra_thresh = float(summary_list[summary_list.index("LRA:") + 4])
- lra_low = float(summary_list[summary_list.index("low:") + 1])
- lra_high = float(summary_list[summary_list.index("high:") + 1])
- stats_dict = {
- "I": i_lufs,
- "I Threshold": i_thresh,
- "LRA": lra,
- "LRA Threshold": lra_thresh,
- "LRA Low": lra_low,
- "LRA High": lra_high,
- }
-
- return stats_dict
-
-
-def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]:
- """Given a path to a file, returns the start time offset and codec of
- the first audio stream.
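-
- Example (hypothetical path):
-
- >>> seconds_offset, codec = ffprobe_offset_and_codec("clip.m4a")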
- """
- ff = ffmpy.FFprobe(
- inputs={path: None},
- global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet",
- )
- streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"]
- seconds_offset = 0.0
- codec = None
-
- # Get the offset and codec of the first audio stream we find
- # and return its start time, if it has one.
- for stream in streams:
- if stream["codec_type"] == "audio":
- seconds_offset = stream.get("start_time", 0.0)
- codec = stream.get("codec_name")
- break
- return float(seconds_offset), codec
-
-
-class FFMPEGMixin:
- _loudness = None
-
- def ffmpeg_loudness(self, quiet: bool = True):
- """Computes loudness of audio file using FFMPEG.
-
- Parameters
- ----------
- quiet : bool, optional
- Whether to suppress FFMPEG output during computation,
- by default True
-
- Returns
- -------
- torch.Tensor
- Loudness of every item in the batch, computed via
- FFMPEG.
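-
- Examples
- --------
- A hypothetical sketch, assuming ``signal`` is a batched AudioSignal:
-
- >>> lufs = signal.ffmpeg_loudness()  # one LUFS value per batch item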
- """
- loudness = []
-
- with tempfile.NamedTemporaryFile(suffix=".wav") as f:
- for i in range(self.batch_size):
- self[i].write(f.name)
- loudness_stats = r128stats(f.name, quiet=quiet)
- loudness.append(loudness_stats["I"])
-
- self._loudness = torch.from_numpy(np.array(loudness)).float()
- return self.loudness()
-
- def ffmpeg_resample(self, sample_rate: int, quiet: bool = True):
- """Resamples AudioSignal using FFMPEG. More memory-efficient
- than using julius.resample for long audio files.
-
- Parameters
- ----------
- sample_rate : int
- Sample rate to resample to.
- quiet : bool, optional
- Whether to suppress FFMPEG output during computation,
- by default True
-
- Returns
- -------
- AudioSignal
- Resampled AudioSignal.
- """
- from audiotools import AudioSignal
-
- if sample_rate == self.sample_rate:
- return self
-
- with tempfile.NamedTemporaryFile(suffix=".wav") as f:
- self.write(f.name)
- f_out = f.name.replace("wav", "rs.wav")
- command = f"ffmpeg -i {f.name} -ar {sample_rate} {f_out}"
- if quiet:
- command += " -hide_banner -loglevel error"
- subprocess.check_call(shlex.split(command))
- resampled = AudioSignal(f_out)
- Path.unlink(Path(f_out))
- return resampled
-
- @classmethod
- def load_from_file_with_ffmpeg(cls, audio_path: str, quiet: bool = True, **kwargs):
- """Loads AudioSignal object after decoding it to a wav file using FFMPEG.
- Useful for loading audio that isn't covered by librosa's loading mechanism. Also
- useful for loading mp3 files, without any offset.
-
- Parameters
- ----------
- audio_path : str
- Path to load AudioSignal from.
- quiet : bool, optional
- Whether to suppress FFMPEG output during computation,
- by default True
-
- Returns
- -------
- AudioSignal
- AudioSignal loaded from file with FFMPEG.
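-
- Examples
- --------
- A hypothetical sketch (path is illustrative):
-
- >>> signal = AudioSignal.load_from_file_with_ffmpeg("video.mp4")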
- """
- audio_path = str(audio_path)
- with tempfile.TemporaryDirectory() as d:
- wav_file = str(Path(d) / "extracted.wav")
- padded_wav = str(Path(d) / "padded.wav")
-
- global_options = "-y"
- if quiet:
- global_options += " -loglevel error"
-
- ff = ffmpy.FFmpeg(
- inputs={audio_path: None},
- # For inputs that are m4a (and others?), the input audio can
- # have samples that don't match the sample rate. This aresample
- # option forces ffmpeg to read timing information in the source
- # file instead of assuming constant sample rate.
- #
- # This fixes an issue where an input m4a file might be a
- # different length than the output wav file
- outputs={wav_file: "-af aresample=async=1000"},
- global_options=global_options,
- )
- ff.run()
-
- # We pad the file using the start time offset in case it's an audio
- # stream starting at some offset in a video container.
- pad, codec = ffprobe_offset_and_codec(audio_path)
-
- # For mp3s, don't pad files with discrepancies less than 0.027s -
- # it's likely due to codec latency. The encoder delay introduced
- # by mp3 is 1152 samples, which is about 0.0261s at 44.1kHz. So we set
- # the threshold here slightly above that.
- # Source: https://lame.sourceforge.io/tech-FAQ.txt.
- if codec == "mp3" and pad < 0.027:
- pad = 0.0
- ff = ffmpy.FFmpeg(
- inputs={wav_file: None},
- outputs={padded_wav: f"-af 'adelay={pad*1000}:all=true'"},
- global_options=global_options,
- )
- ff.run()
-
- signal = cls(padded_wav, **kwargs)
-
- return signal
diff --git a/dito/models/ldm/dac/audiotools/core/loudness.py b/dito/models/ldm/dac/audiotools/core/loudness.py
deleted file mode 100644
index cb3ee2675d7cb71f4c00106b0c1e901b8e51b842..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/loudness.py
+++ /dev/null
@@ -1,320 +0,0 @@
-import copy
-
-import julius
-import numpy as np
-import scipy
-import torch
-import torch.nn.functional as F
-import torchaudio
-
-
-class Meter(torch.nn.Module):
- """Tensorized version of pyloudnorm.Meter. Works with batched audio tensors.
-
- Parameters
- ----------
- rate : int
- Sample rate of audio.
- filter_class : str, optional
- Class of weighting filter used.
- One of 'K-weighting' (default), 'Fenton/Lee 1',
- 'Fenton/Lee 2', or 'Dash et al.',
- by default "K-weighting"
- block_size : float, optional
- Gating block size in seconds, by default 0.400
- zeros : int, optional
- Number of zeros to use in FIR approximation of
- IIR filters, by default 512
- use_fir : bool, optional
- Whether to use FIR approximation or exact IIR formulation.
- If computing on GPU, the FIR approximation is used regardless, as
- it's much faster, by default False
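-
- Examples
- --------
- A minimal sketch (shapes are illustrative):
-
- >>> meter = Meter(44100)
- >>> audio = torch.randn(1, 44100, 1)  # (nb, nt, nch)
- >>> lufs = meter.integrated_loudness(audio)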
- """
-
- def __init__(
- self,
- rate: int,
- filter_class: str = "K-weighting",
- block_size: float = 0.400,
- zeros: int = 512,
- use_fir: bool = False,
- ):
- super().__init__()
-
- self.rate = rate
- self.filter_class = filter_class
- self.block_size = block_size
- self.use_fir = use_fir
-
- G = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.41, 1.41]))
- self.register_buffer("G", G)
-
- # Compute impulse responses so that filtering is fast via
- # a convolution at runtime, on GPU, unlike lfilter.
- impulse = np.zeros((zeros,))
- impulse[..., 0] = 1.0
-
- firs = np.zeros((len(self._filters), 1, zeros))
- passband_gain = torch.zeros(len(self._filters))
-
- for i, (_, filter_stage) in enumerate(self._filters.items()):
- firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a, impulse)
- passband_gain[i] = filter_stage.passband_gain
-
- firs = torch.from_numpy(firs[..., ::-1].copy()).float()
-
- self.register_buffer("firs", firs)
- self.register_buffer("passband_gain", passband_gain)
-
- def apply_filter_gpu(self, data: torch.Tensor):
- """Performs FIR approximation of loudness computation.
-
- Parameters
- ----------
- data : torch.Tensor
- Audio data of shape (nb, nt, nch).
-
- Returns
- -------
- torch.Tensor
- Filtered audio data.
- """
- # Data is of shape (nb, nt, nch).
- # Reshape to (nb*nch, 1, nt) so each channel is filtered independently.
- nb, nt, nch = data.shape
- data = data.permute(0, 2, 1)
- data = data.reshape(nb * nch, 1, nt)
-
- # Apply padding
- pad_length = self.firs.shape[-1]
-
- # Apply filtering in sequence
- for i in range(self.firs.shape[0]):
- data = F.pad(data, (pad_length, pad_length))
- data = julius.fftconv.fft_conv1d(data, self.firs[i, None, ...])
- data = self.passband_gain[i] * data
- data = data[..., 1 : nt + 1]
-
- data = data.permute(0, 2, 1)
- data = data[:, :nt, :]
- return data
-
- def apply_filter_cpu(self, data: torch.Tensor):
- """Performs IIR formulation of loudness computation.
-
- Parameters
- ----------
- data : torch.Tensor
- Audio data of shape (nb, nt, nch).
-
- Returns
- -------
- torch.Tensor
- Filtered audio data.
- """
- for _, filter_stage in self._filters.items():
- passband_gain = filter_stage.passband_gain
-
- a_coeffs = torch.from_numpy(filter_stage.a).float().to(data.device)
- b_coeffs = torch.from_numpy(filter_stage.b).float().to(data.device)
-
- _data = data.permute(0, 2, 1)
- filtered = torchaudio.functional.lfilter(
- _data, a_coeffs, b_coeffs, clamp=False
- )
- data = passband_gain * filtered.permute(0, 2, 1)
- return data
-
- def apply_filter(self, data: torch.Tensor):
- """Applies filter on either CPU or GPU, depending
- on if the audio is on GPU or is on CPU, or if
- ``self.use_fir`` is True.
-
- Parameters
- ----------
- data : torch.Tensor
- Audio data of shape (nb, nt, nch).
-
- Returns
- -------
- torch.Tensor
- Filtered audio data.
- """
- if data.is_cuda or self.use_fir:
- data = self.apply_filter_gpu(data)
- else:
- data = self.apply_filter_cpu(data)
- return data
-
- def forward(self, data: torch.Tensor):
- """Computes integrated loudness of data.
-
- Parameters
- ----------
- data : torch.Tensor
- Audio data of shape (nb, nt, nch).
-
- Returns
- -------
- torch.Tensor
- Integrated loudness of the audio data.
- """
- return self.integrated_loudness(data)
-
- def _unfold(self, input_data):
- T_g = self.block_size
- overlap = 0.75 # overlap of 75% of the block duration
- step = 1.0 - overlap # step size by percentage
-
- kernel_size = int(T_g * self.rate)
- stride = int(T_g * self.rate * step)
- unfolded = julius.core.unfold(input_data.permute(0, 2, 1), kernel_size, stride)
- unfolded = unfolded.transpose(-1, -2)
-
- return unfolded
-
- def integrated_loudness(self, data: torch.Tensor):
- """Computes integrated loudness of data.
-
- Parameters
- ----------
- data : torch.Tensor
- Audio data of shape (nb, nt, nch).
-
- Returns
- -------
- torch.Tensor
- Integrated loudness (LUFS) of each item in the batch.
- """
- if not torch.is_tensor(data):
- data = torch.from_numpy(data).float()
- else:
- data = data.float()
-
- input_data = copy.copy(data)
- # Data always has a batch and channel dimension.
- # Is of shape (nb, nt, nch)
- if input_data.ndim < 2:
- input_data = input_data.unsqueeze(-1)
- if input_data.ndim < 3:
- input_data = input_data.unsqueeze(0)
-
- nb, nt, nch = input_data.shape
-
- # Apply frequency weighting filters - account
- # for the acoustic response of the head and auditory system
- input_data = self.apply_filter(input_data)
-
- G = self.G # channel gains
- T_g = self.block_size # 400 ms gating block standard
- Gamma_a = -70.0 # -70 LKFS = absolute loudness threshold
-
- unfolded = self._unfold(input_data)
-
- z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
- l = -0.691 + 10.0 * torch.log10((G[None, :nch, None] * z).sum(1, keepdim=True))
- l = l.expand_as(z)
-
- # find gating block indices above absolute threshold
- z_avg_gated = z
- z_avg_gated[l <= Gamma_a] = 0
- masked = l > Gamma_a
- z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
-
- # calculate the relative threshold value (see eq. 6)
- Gamma_r = (
- -0.691 + 10.0 * torch.log10((z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
- )
- Gamma_r = Gamma_r[:, None, None]
- Gamma_r = Gamma_r.expand(nb, nch, l.shape[-1])
-
- # find gating block indices above relative and absolute thresholds (end of eq. 7)
- z_avg_gated = z
- z_avg_gated[l <= Gamma_a] = 0
- z_avg_gated[l <= Gamma_r] = 0
- masked = (l > Gamma_a) * (l > Gamma_r)
- z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
-
- # # Cannot use nan_to_num (pytorch 1.8 does not come with GCP-supported cuda version)
- # z_avg_gated = torch.nan_to_num(z_avg_gated)
- z_avg_gated = torch.where(
- z_avg_gated.isnan(), torch.zeros_like(z_avg_gated), z_avg_gated
- )
- z_avg_gated[z_avg_gated == float("inf")] = float(np.finfo(np.float32).max)
- z_avg_gated[z_avg_gated == -float("inf")] = float(np.finfo(np.float32).min)
-
- LUFS = -0.691 + 10.0 * torch.log10((G[None, :nch] * z_avg_gated).sum(1))
- return LUFS.float()
-
- @property
- def filter_class(self):
- return self._filter_class
-
- @filter_class.setter
- def filter_class(self, value):
- from pyloudnorm import Meter
-
- meter = Meter(self.rate)
- meter.filter_class = value
- self._filter_class = value
- self._filters = meter._filters
-
-
-class LoudnessMixin:
- _loudness = None
- MIN_LOUDNESS = -70
- """Minimum loudness possible."""
-
- def loudness(
- self, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
- ):
- """Calculates loudness using an implementation of ITU-R BS.1770-4.
- Allows control over gating block size and frequency weighting filters for
- additional control. Measure the integrated gated loudness of a signal.
-
- API is derived from PyLoudnorm, but this implementation is ported to PyTorch
- and is tensorized across batches. When on GPU, an FIR approximation of the IIR
- filters is used to compute loudness for speed.
-
- Uses the weighting filters and block size defined by the meter;
- the integrated loudness is measured based upon the gating algorithm
- defined in the ITU-R BS.1770-4 specification.
-
- Parameters
- ----------
- filter_class : str, optional
- Class of weighting filter used.
- One of 'K-weighting' (default), 'Fenton/Lee 1',
- 'Fenton/Lee 2', or 'Dash et al.',
- by default "K-weighting"
- block_size : float, optional
- Gating block size in seconds, by default 0.400
- kwargs : dict, optional
- Keyword arguments to :py:func:`audiotools.core.loudness.Meter`.
-
- Returns
- -------
- torch.Tensor
- Loudness of audio data.
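-
- Examples
- --------
- A hypothetical sketch, assuming ``signal`` is an AudioSignal:
-
- >>> lufs = signal.loudness()  # K-weighted integrated loudness, in LUFS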
- """
- if self._loudness is not None:
- return self._loudness.to(self.device)
- original_length = self.signal_length
- if self.signal_duration < 0.5:
- pad_len = int((0.5 - self.signal_duration) * self.sample_rate)
- self.zero_pad(0, pad_len)
-
- # create BS.1770 meter
- meter = Meter(
- self.sample_rate, filter_class=filter_class, block_size=block_size, **kwargs
- )
- meter = meter.to(self.device)
- # measure loudness
- loudness = meter.integrated_loudness(self.audio_data.permute(0, 2, 1))
- self.truncate_samples(original_length)
- min_loudness = (
- torch.ones_like(loudness, device=loudness.device) * self.MIN_LOUDNESS
- )
- self._loudness = torch.maximum(loudness, min_loudness)
-
- return self._loudness.to(self.device)
diff --git a/dito/models/ldm/dac/audiotools/core/playback.py b/dito/models/ldm/dac/audiotools/core/playback.py
deleted file mode 100644
index 5d0f21aaa392494f35305c0084c05b87667ea14d..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/playback.py
+++ /dev/null
@@ -1,252 +0,0 @@
-"""
-These are utilities that allow one to embed an AudioSignal
-as a playable object in a Jupyter notebook, or to play audio from
-the terminal, etc.
-""" # fmt: skip
-import base64
-import io
-import random
-import string
-import subprocess
-from tempfile import NamedTemporaryFile
-
-import importlib_resources as pkg_resources
-
-from . import templates
-from .util import _close_temp_files
-from .util import format_figure
-
-headers = pkg_resources.files(templates).joinpath("headers.html").read_text()
-widget = pkg_resources.files(templates).joinpath("widget.html").read_text()
-
-DEFAULT_EXTENSION = ".wav"
-
-
-def _check_imports(): # pragma: no cover
- try:
- import ffmpy
- except ImportError:
- ffmpy = False
-
- try:
- import IPython
- except ImportError:
- raise ImportError("IPython must be installed in order to use this function!")
- return ffmpy, IPython
-
-
-class PlayMixin:
- def embed(self, ext: str = None, display: bool = True, return_html: bool = False):
- """Embeds audio as a playable audio embed in a notebook, or HTML
- document, etc.
-
- Parameters
- ----------
- ext : str, optional
- Extension to use when saving the audio, by default ".wav"
- display : bool, optional
- This controls whether or not to display the audio when called. This
- is used when the embed is the last line in a Jupyter cell, to prevent
- the audio from being embedded twice, by default True
- return_html : bool, optional
- Whether to return the data wrapped in an HTML audio element, by default False
-
- Returns
- -------
- str
- Either the element for display, or the HTML string of it.
- """
- if ext is None:
- ext = DEFAULT_EXTENSION
- ext = f".{ext}" if not ext.startswith(".") else ext
- ffmpy, IPython = _check_imports()
- sr = self.sample_rate
- tmpfiles = []
-
- with _close_temp_files(tmpfiles):
- tmp_wav = NamedTemporaryFile(mode="w+", suffix=".wav", delete=False)
- tmpfiles.append(tmp_wav)
- self.write(tmp_wav.name)
- if ext != ".wav" and ffmpy:
- tmp_converted = NamedTemporaryFile(mode="w+", suffix=ext, delete=False)
- tmpfiles.append(tmp_converted)
- ff = ffmpy.FFmpeg(
- inputs={tmp_wav.name: None},
- outputs={
- tmp_converted.name: "-write_xing 0 -codec:a libmp3lame -b:a 128k -y -hide_banner -loglevel error"
- },
- )
- ff.run()
- else:
- tmp_converted = tmp_wav
-
- audio_element = IPython.display.Audio(data=tmp_converted.name, rate=sr)
- if display:
- IPython.display.display(audio_element)
-
- if return_html:
- audio_element = (
- f"<audio controls src='{audio_element.src_attr()}'></audio>"
- )
- return audio_element
-
- def widget(
- self,
- title: str = None,
- ext: str = ".wav",
- add_headers: bool = True,
- player_width: str = "100%",
- margin: str = "10px",
- plot_fn: str = "specshow",
- return_html: bool = False,
- **kwargs,
- ):
- """Creates a playable widget with spectrogram. Inspired (heavily) by
- https://sjvasquez.github.io/blog/melnet/.
-
- Parameters
- ----------
- title : str, optional
- Title of plot, placed in upper right of top-most axis.
- ext : str, optional
- Extension for embedding, by default ".wav"
- add_headers : bool, optional
- Whether or not to add headers (use for first embed, False for later embeds), by default True
- player_width : str, optional
- Width of the player, as a string in a CSS rule, by default "100%"
- margin : str, optional
- Margin on all sides of player, by default "10px"
- plot_fn : str or callable, optional
- Plotting function to use, either a method name on this object or a
- callable, by default "specshow".
- return_html : bool, optional
- Whether to return the data wrapped in an HTML audio element, by default False
- kwargs : dict, optional
- Keyword arguments to plot_fn (by default self.specshow).
-
- Returns
- -------
- HTML
- HTML object.
- """
- import matplotlib.pyplot as plt
-
- def _save_fig_to_tag():
- buffer = io.BytesIO()
-
- plt.savefig(buffer, bbox_inches="tight", pad_inches=0)
- plt.close()
-
- buffer.seek(0)
- data_uri = base64.b64encode(buffer.read()).decode("ascii")
- tag = "data:image/png;base64,{0}".format(data_uri)
-
- return tag
-
- _, IPython = _check_imports()
-
- header_html = ""
-
- if add_headers:
- header_html = headers.replace("PLAYER_WIDTH", str(player_width))
- header_html = header_html.replace("MARGIN", str(margin))
- IPython.display.display(IPython.display.HTML(header_html))
-
- widget_html = widget
- if isinstance(plot_fn, str):
- plot_fn = getattr(self, plot_fn)
- kwargs["title"] = title
- plot_fn(**kwargs)
-
- fig = plt.gcf()
- pixels = fig.get_size_inches() * fig.dpi
-
- tag = _save_fig_to_tag()
-
- # Make the source image for the levels
- self.specshow()
- format_figure((12, 1.5))
- levels_tag = _save_fig_to_tag()
-
- player_id = "".join(random.choice(string.ascii_uppercase) for _ in range(10))
-
- audio_elem = self.embed(ext=ext, display=False)
- widget_html = widget_html.replace("AUDIO_SRC", audio_elem.src_attr())
- widget_html = widget_html.replace("IMAGE_SRC", tag)
- widget_html = widget_html.replace("LEVELS_SRC", levels_tag)
- widget_html = widget_html.replace("PLAYER_ID", player_id)
-
- # Calculate width/height of figure based on figure size.
- widget_html = widget_html.replace("PADDING_AMOUNT", f"{int(pixels[1])}px")
- widget_html = widget_html.replace("MAX_WIDTH", f"{int(pixels[0])}px")
-
- IPython.display.display(IPython.display.HTML(widget_html))
-
- if return_html:
- html = header_html if add_headers else ""
- html += widget_html
- return html
-
- def play(self):
- """
- Plays an audio signal if ffplay from the ffmpeg suite of tools is installed.
- Otherwise, will fail. The audio signal is written to a temporary file
- and then played with ffplay.
- """
- tmpfiles = []
- with _close_temp_files(tmpfiles):
- tmp_wav = NamedTemporaryFile(suffix=".wav", delete=False)
- tmpfiles.append(tmp_wav)
- self.write(tmp_wav.name)
- print(self)
- subprocess.call(
- [
- "ffplay",
- "-nodisp",
- "-autoexit",
- "-hide_banner",
- "-loglevel",
- "error",
- tmp_wav.name,
- ]
- )
- return self
-
-
-if __name__ == "__main__": # pragma: no cover
- from audiotools import AudioSignal
-
- signal = AudioSignal(
- "tests/audio/spk/f10_script4_produced.mp3", offset=5, duration=5
- )
-
- wave_html = signal.widget(
- "Waveform",
- plot_fn="waveplot",
- return_html=True,
- )
-
- spec_html = signal.widget("Spectrogram", return_html=True, add_headers=False)
-
- combined_html = signal.widget(
- "Waveform + spectrogram",
- plot_fn="wavespec",
- return_html=True,
- add_headers=False,
- )
-
- signal.low_pass(8000)
- lowpass_html = signal.widget(
- "Lowpassed audio",
- plot_fn="wavespec",
- return_html=True,
- add_headers=False,
- )
-
- with open("/tmp/index.html", "w") as f:
- f.write(wave_html)
- f.write(spec_html)
- f.write(combined_html)
- f.write(lowpass_html)
diff --git a/dito/models/ldm/dac/audiotools/core/templates/__init__.py b/dito/models/ldm/dac/audiotools/core/templates/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/dito/models/ldm/dac/audiotools/core/templates/headers.html b/dito/models/ldm/dac/audiotools/core/templates/headers.html
deleted file mode 100644
index 9eaef4a94d575f7826608ad63dcc77fab13b7b19..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/templates/headers.html
+++ /dev/null
@@ -1,322 +0,0 @@
diff --git a/dito/models/ldm/dac/audiotools/core/templates/pandoc.css b/dito/models/ldm/dac/audiotools/core/templates/pandoc.css
deleted file mode 100644
index 842be7be6d65580dab44c6a8013259644f38e6ee..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/templates/pandoc.css
+++ /dev/null
@@ -1,407 +0,0 @@
-/*
-Copyright (c) 2017 Chris Patuzzo
-https://twitter.com/chrispatuzzo
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-*/
-
-body {
- font-family: Helvetica, arial, sans-serif;
- font-size: 14px;
- line-height: 1.6;
- padding-top: 10px;
- padding-bottom: 10px;
- background-color: white;
- padding: 30px;
- color: #333;
-}
-
-body > *:first-child {
- margin-top: 0 !important;
-}
-
-body > *:last-child {
- margin-bottom: 0 !important;
-}
-
-a {
- color: #4183C4;
- text-decoration: none;
-}
-
-a.absent {
- color: #cc0000;
-}
-
-a.anchor {
- display: block;
- padding-left: 30px;
- margin-left: -30px;
- cursor: pointer;
- position: absolute;
- top: 0;
- left: 0;
- bottom: 0;
-}
-
-h1, h2, h3, h4, h5, h6 {
- margin: 20px 0 10px;
- padding: 0;
- font-weight: bold;
- -webkit-font-smoothing: antialiased;
- cursor: text;
- position: relative;
-}
-
-h2:first-child, h1:first-child, h1:first-child + h2, h3:first-child, h4:first-child, h5:first-child, h6:first-child {
- margin-top: 0;
- padding-top: 0;
-}
-
-h1:hover a.anchor, h2:hover a.anchor, h3:hover a.anchor, h4:hover a.anchor, h5:hover a.anchor, h6:hover a.anchor {
- text-decoration: none;
-}
-
-h1 tt, h1 code {
- font-size: inherit;
-}
-
-h2 tt, h2 code {
- font-size: inherit;
-}
-
-h3 tt, h3 code {
- font-size: inherit;
-}
-
-h4 tt, h4 code {
- font-size: inherit;
-}
-
-h5 tt, h5 code {
- font-size: inherit;
-}
-
-h6 tt, h6 code {
- font-size: inherit;
-}
-
-h1 {
- font-size: 28px;
- color: black;
-}
-
-h2 {
- font-size: 24px;
- border-bottom: 1px solid #cccccc;
- color: black;
-}
-
-h3 {
- font-size: 18px;
-}
-
-h4 {
- font-size: 16px;
-}
-
-h5 {
- font-size: 14px;
-}
-
-h6 {
- color: #777777;
- font-size: 14px;
-}
-
-p, blockquote, ul, ol, dl, li, table, pre {
- margin: 15px 0;
-}
-
-hr {
- border: 0 none;
- color: #cccccc;
- height: 4px;
- padding: 0;
-}
-
-body > h2:first-child {
- margin-top: 0;
- padding-top: 0;
-}
-
-body > h1:first-child {
- margin-top: 0;
- padding-top: 0;
-}
-
-body > h1:first-child + h2 {
- margin-top: 0;
- padding-top: 0;
-}
-
-body > h3:first-child, body > h4:first-child, body > h5:first-child, body > h6:first-child {
- margin-top: 0;
- padding-top: 0;
-}
-
-a:first-child h1, a:first-child h2, a:first-child h3, a:first-child h4, a:first-child h5, a:first-child h6 {
- margin-top: 0;
- padding-top: 0;
-}
-
-h1 p, h2 p, h3 p, h4 p, h5 p, h6 p {
- margin-top: 0;
-}
-
-li p.first {
- display: inline-block;
-}
-
-ul, ol {
- padding-left: 30px;
-}
-
-ul :first-child, ol :first-child {
- margin-top: 0;
-}
-
-ul :last-child, ol :last-child {
- margin-bottom: 0;
-}
-
-dl {
- padding: 0;
-}
-
-dl dt {
- font-size: 14px;
- font-weight: bold;
- font-style: italic;
- padding: 0;
- margin: 15px 0 5px;
-}
-
-dl dt:first-child {
- padding: 0;
-}
-
-dl dt > :first-child {
- margin-top: 0;
-}
-
-dl dt > :last-child {
- margin-bottom: 0;
-}
-
-dl dd {
- margin: 0 0 15px;
- padding: 0 15px;
-}
-
-dl dd > :first-child {
- margin-top: 0;
-}
-
-dl dd > :last-child {
- margin-bottom: 0;
-}
-
-blockquote {
- border-left: 4px solid #dddddd;
- padding: 0 15px;
- color: #777777;
-}
-
-blockquote > :first-child {
- margin-top: 0;
-}
-
-blockquote > :last-child {
- margin-bottom: 0;
-}
-
-table {
- padding: 0;
-}
-table tr {
- border-top: 1px solid #cccccc;
- background-color: white;
- margin: 0;
- padding: 0;
-}
-
-table tr:nth-child(2n) {
- background-color: #f8f8f8;
-}
-
-table tr th {
- font-weight: bold;
- border: 1px solid #cccccc;
- text-align: left;
- margin: 0;
- padding: 6px 13px;
-}
-
-table tr td {
- border: 1px solid #cccccc;
- text-align: left;
- margin: 0;
- padding: 6px 13px;
-}
-
-table tr th :first-child, table tr td :first-child {
- margin-top: 0;
-}
-
-table tr th :last-child, table tr td :last-child {
- margin-bottom: 0;
-}
-
-img {
- max-width: 100%;
-}
-
-span.frame {
- display: block;
- overflow: hidden;
-}
-
-span.frame > span {
- border: 1px solid #dddddd;
- display: block;
- float: left;
- overflow: hidden;
- margin: 13px 0 0;
- padding: 7px;
- width: auto;
-}
-
-span.frame span img {
- display: block;
- float: left;
-}
-
-span.frame span span {
- clear: both;
- color: #333333;
- display: block;
- padding: 5px 0 0;
-}
-
-span.align-center {
- display: block;
- overflow: hidden;
- clear: both;
-}
-
-span.align-center > span {
- display: block;
- overflow: hidden;
- margin: 13px auto 0;
- text-align: center;
-}
-
-span.align-center span img {
- margin: 0 auto;
- text-align: center;
-}
-
-span.align-right {
- display: block;
- overflow: hidden;
- clear: both;
-}
-
-span.align-right > span {
- display: block;
- overflow: hidden;
- margin: 13px 0 0;
- text-align: right;
-}
-
-span.align-right span img {
- margin: 0;
- text-align: right;
-}
-
-span.float-left {
- display: block;
- margin-right: 13px;
- overflow: hidden;
- float: left;
-}
-
-span.float-left span {
- margin: 13px 0 0;
-}
-
-span.float-right {
- display: block;
- margin-left: 13px;
- overflow: hidden;
- float: right;
-}
-
-span.float-right > span {
- display: block;
- overflow: hidden;
- margin: 13px auto 0;
- text-align: right;
-}
-
-code, tt {
- margin: 0 2px;
- padding: 0 5px;
- white-space: nowrap;
- border-radius: 3px;
-}
-
-pre code {
- margin: 0;
- padding: 0;
- white-space: pre;
- border: none;
- background: transparent;
-}
-
-.highlight pre {
- font-size: 13px;
- line-height: 19px;
- overflow: auto;
- padding: 6px 10px;
- border-radius: 3px;
-}
-
-pre {
- font-size: 13px;
- line-height: 19px;
- overflow: auto;
- padding: 6px 10px;
- border-radius: 3px;
-}
-
-pre code, pre tt {
- background-color: transparent;
- border: none;
-}
-
-body {
- max-width: 600px;
-}
diff --git a/dito/models/ldm/dac/audiotools/core/templates/widget.html b/dito/models/ldm/dac/audiotools/core/templates/widget.html
deleted file mode 100644
index 0b44e8aec64fd1db929da5fa6208dee00247c967..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/templates/widget.html
+++ /dev/null
@@ -1,52 +0,0 @@
diff --git a/dito/models/ldm/dac/audiotools/core/util.py b/dito/models/ldm/dac/audiotools/core/util.py
deleted file mode 100644
index ece1344658d10836aa2eb693f275294ad8cdbb52..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/util.py
+++ /dev/null
@@ -1,671 +0,0 @@
-import csv
-import glob
-import math
-import numbers
-import os
-import random
-import typing
-from contextlib import contextmanager
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Dict
-from typing import List
-
-import numpy as np
-import torch
-import torchaudio
-from flatten_dict import flatten
-from flatten_dict import unflatten
-
-
-@dataclass
-class Info:
- """Shim for torchaudio.info API changes."""
-
- sample_rate: float
- num_frames: int
-
- @property
- def duration(self) -> float:
- return self.num_frames / self.sample_rate
-
-
-def info(audio_path: str):
- """Shim for torchaudio.info to make 0.7.2 API match 0.8.0.
-
- Parameters
- ----------
- audio_path : str
- Path to audio file.
- """
- # try default backend first, then fallback to soundfile
- try:
- info = torchaudio.info(str(audio_path))
- except: # pragma: no cover
- info = torchaudio.backend.soundfile_backend.info(str(audio_path))
-
- if isinstance(info, tuple): # pragma: no cover
- signal_info = info[0]
- info = Info(sample_rate=signal_info.rate, num_frames=signal_info.length)
- else:
- info = Info(sample_rate=info.sample_rate, num_frames=info.num_frames)
-
- return info
-
-
-def ensure_tensor(
- x: typing.Union[np.ndarray, torch.Tensor, float, int],
- ndim: int = None,
- batch_size: int = None,
-):
- """Ensures that the input ``x`` is a tensor of specified
- dimensions and batch size.
-
- Parameters
- ----------
- x : typing.Union[np.ndarray, torch.Tensor, float, int]
- Data that will become a tensor on its way out.
- ndim : int, optional
- How many dimensions should be in the output, by default None
- batch_size : int, optional
- The batch size of the output, by default None
-
- Returns
- -------
- torch.Tensor
- Modified version of ``x`` as a tensor.
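-
- Examples
- --------
- A worked example of the unsqueeze-and-expand behavior:
-
- >>> ensure_tensor(0.5, ndim=3, batch_size=4).shape
- torch.Size([4, 1, 1])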
- """
- if not torch.is_tensor(x):
- x = torch.as_tensor(x)
- if ndim is not None:
- assert x.ndim <= ndim
- while x.ndim < ndim:
- x = x.unsqueeze(-1)
- if batch_size is not None:
- if x.shape[0] != batch_size:
- shape = list(x.shape)
- shape[0] = batch_size
- x = x.expand(*shape)
- return x
-
-
-def _get_value(other):
- from . import AudioSignal
-
- if isinstance(other, AudioSignal):
- return other.audio_data
- return other
-
-
-def hz_to_bin(hz: torch.Tensor, n_fft: int, sample_rate: int):
- """Closest frequency bin given a frequency, number
- of bins, and a sampling rate.
-
- Parameters
- ----------
- hz : torch.Tensor
- Tensor of frequencies in Hz.
- n_fft : int
- Number of FFT bins.
- sample_rate : int
- Sample rate of audio.
-
- Returns
- -------
- torch.Tensor
- Closest bins to the data.
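-
- Examples
- --------
- A small sketch mapping frequencies to their nearest FFT bins:
-
- >>> bins = hz_to_bin(torch.tensor([440.0, 1000.0]), n_fft=2048, sample_rate=44100)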
- """
- shape = hz.shape
- hz = hz.flatten()
- freqs = torch.linspace(0, sample_rate / 2, 2 + n_fft // 2)
- hz[hz > sample_rate / 2] = sample_rate / 2
-
- closest = (hz[None, :] - freqs[:, None]).abs()
- closest_bins = closest.min(dim=0).indices
-
- return closest_bins.reshape(*shape)
-
-
-def random_state(seed: typing.Union[int, np.random.RandomState]):
- """
- Turn seed into a np.random.RandomState instance.
-
- Parameters
- ----------
- seed : typing.Union[int, np.random.RandomState] or None
- If seed is None, return the RandomState singleton used by np.random.
- If seed is an int, return a new RandomState instance seeded with seed.
- If seed is already a RandomState instance, return it.
- Otherwise raise ValueError.
-
- Returns
- -------
- np.random.RandomState
- Random state object.
-
- Raises
- ------
- ValueError
- If seed is not valid, an error is thrown.
- """
- if seed is None or seed is np.random:
- return np.random.mtrand._rand
- elif isinstance(seed, (numbers.Integral, np.integer, int)):
- return np.random.RandomState(seed)
- elif isinstance(seed, np.random.RandomState):
- return seed
- else:
- raise ValueError(
- "%r cannot be used to seed a numpy.random.RandomState" " instance" % seed
- )
-
-
-def seed(random_seed, set_cudnn=False):
- """
- Seeds all random states with the same random seed
- for reproducibility. Seeds ``numpy``, ``random`` and ``torch``
- random generators.
- For full reproducibility, two further options must be set
- according to the torch documentation:
- https://pytorch.org/docs/stable/notes/randomness.html
- To do this, ``set_cudnn`` must be True. It defaults to
- False, since setting it to True results in a performance
- hit.
-
- Args:
- random_seed (int): integer corresponding to random seed to
- use.
- set_cudnn (bool): Whether or not to set cudnn into deterministic
- mode and off of benchmark mode. Defaults to False.
- """
-
- torch.manual_seed(random_seed)
- np.random.seed(random_seed)
- random.seed(random_seed)
-
- if set_cudnn:
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
-
-
-@contextmanager
-def _close_temp_files(tmpfiles: list):
- """Utility function for creating a context and closing all temporary files
- once the context is exited. For correct functionality, all temporary file
- handles created inside the context must be appended to the ``tmpfiles``
- list.
-
- This function is taken wholesale from Scaper.
-
- Parameters
- ----------
- tmpfiles : list
- List of temporary file handles
- """
-
- def _close():
- for t in tmpfiles:
- try:
- t.close()
- os.unlink(t.name)
- except:
- pass
-
- try:
- yield
- except: # pragma: no cover
- _close()
- raise
- _close()
-
-
-AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]
-
-
-def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
- """Finds all audio files in a directory recursively.
- Returns a list.
-
- Parameters
- ----------
- folder : str
- Folder to look for audio files in, recursively.
- ext : List[str], optional
- Extensions to look for, including the leading dot, by default
- ``['.wav', '.flac', '.mp3', '.mp4']``.
- """
- folder = Path(folder)
- # Take care of case where user has passed in an audio file directly
- # into one of the calling functions.
- if str(folder).endswith(tuple(ext)):
- # if, however, there's a glob in the path, we need to
- # return the glob, not the file.
- if "*" in str(folder):
- return glob.glob(str(folder), recursive=("**" in str(folder)))
- else:
- return [folder]
-
- files = []
- for x in ext:
- files += folder.glob(f"**/*{x}")
- return files
-
-
-def read_sources(
- sources: List[str],
- remove_empty: bool = True,
- relative_path: str = "",
- ext: List[str] = AUDIO_EXTENSIONS,
-):
- """Reads audio sources that can either be folders
- full of audio files, or CSV files that contain paths
- to audio files. CSV files that adhere to the expected
- format can be generated by
- :py:func:`audiotools.data.preprocess.create_csv`.
-
- Parameters
- ----------
- sources : List[str]
- List of audio sources to be converted into a
- list of lists of audio files.
- remove_empty : bool, optional
- Whether or not to remove rows with an empty "path"
- from each CSV file, by default True.
- relative_path : str, optional
- Path audio should be loaded relative to, by default ""
- ext : List[str], optional
- List of extensions to search for when a source is a folder, by default
- ``['.wav', '.flac', '.mp3', '.mp4']``.
-
- Returns
- -------
- list
- List of lists of rows of CSV files.
- """
- files = []
- relative_path = Path(relative_path)
- for source in sources:
- source = str(source)
- _files = []
- if source.endswith(".csv"):
- with open(source, "r") as f:
- reader = csv.DictReader(f)
- for x in reader:
- if remove_empty and x["path"] == "":
- continue
- if x["path"] != "":
- x["path"] = str(relative_path / x["path"])
- _files.append(x)
- else:
- for x in find_audio(source, ext=ext):
- x = str(relative_path / x)
- _files.append({"path": x})
- files.append(sorted(_files, key=lambda x: x["path"]))
- return files
-
-
-def choose_from_list_of_lists(
- state: np.random.RandomState, list_of_lists: list, p: float = None
-):
- """Choose a single item from a list of lists.
-
- Parameters
- ----------
- state : np.random.RandomState
- Random state to use when choosing an item.
- list_of_lists : list
- A list of lists from which items will be drawn.
- p : List[float], optional
- Probability of sampling each list, by default None
-
- Returns
- -------
- tuple
- The chosen item, along with its source index and item index.
- """
- source_idx = state.choice(list(range(len(list_of_lists))), p=p)
- item_idx = state.randint(len(list_of_lists[source_idx]))
- return list_of_lists[source_idx][item_idx], source_idx, item_idx
-
-
-@contextmanager
-def chdir(newdir: typing.Union[Path, str]):
- """
- Context manager for switching directories to run a
- function. Useful for when you want to use relative
- paths to different runs.
-
- Parameters
- ----------
- newdir : typing.Union[Path, str]
- Directory to switch to.
- """
- curdir = os.getcwd()
- try:
- os.chdir(newdir)
- yield
- finally:
- os.chdir(curdir)
-
-
-def prepare_batch(batch: typing.Union[dict, list, torch.Tensor], device: str = "cpu"):
- """Moves items in a batch (typically generated by a DataLoader as a list
- or a dict) to the specified device. This works even if dictionaries
- are nested.
-
- Parameters
- ----------
- batch : typing.Union[dict, list, torch.Tensor]
- Batch, typically generated by a dataloader, that will be moved to
- the device.
- device : str, optional
- Device to move batch to, by default "cpu"
-
- Returns
- -------
- typing.Union[dict, list, torch.Tensor]
- Batch with all values moved to the specified device.
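-
- Examples
- --------
- A small sketch with a nested dictionary batch:
-
- >>> batch = {"audio": torch.zeros(2, 1, 100), "meta": {"idx": torch.arange(2)}}
- >>> batch = prepare_batch(batch, device="cpu")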
- """
- if isinstance(batch, dict):
- batch = flatten(batch)
- for key, val in batch.items():
- try:
- batch[key] = val.to(device)
- except:
- pass
- batch = unflatten(batch)
- elif torch.is_tensor(batch):
- batch = batch.to(device)
- elif isinstance(batch, list):
- for i in range(len(batch)):
- try:
- batch[i] = batch[i].to(device)
- except:
- pass
- return batch
-
-
-def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None):
- """Samples from a distribution defined by a tuple. The first
- item in the tuple is the distribution type, and the rest of the
- items are arguments to that distribution. The distribution function
- is gotten from the ``np.random.RandomState`` object.
-
- Parameters
- ----------
- dist_tuple : tuple
- Distribution tuple
- state : np.random.RandomState, optional
- Random state, or seed to use, by default None
-
- Returns
- -------
- typing.Union[float, int, str]
- Draw from the distribution.
-
- Examples
- --------
- Sample from a uniform distribution:
-
- >>> dist_tuple = ("uniform", 0, 1)
- >>> sample_from_dist(dist_tuple)
-
- Sample from a constant distribution:
-
- >>> dist_tuple = ("const", 0)
- >>> sample_from_dist(dist_tuple)
-
- Sample from a normal distribution:
-
- >>> dist_tuple = ("normal", 0, 0.5)
- >>> sample_from_dist(dist_tuple)
-
- """
- if dist_tuple[0] == "const":
- return dist_tuple[1]
- state = random_state(state)
- dist_fn = getattr(state, dist_tuple[0])
- return dist_fn(*dist_tuple[1:])
-
-
-def collate(list_of_dicts: list, n_splits: int = None):
- """Collates a list of dictionaries (e.g. as returned by a
- dataloader) into a dictionary with batched values. This routine
- uses the default torch collate function for everything
- except AudioSignal objects, which are handled by the
- :py:func:`audiotools.core.audio_signal.AudioSignal.batch`
- function.
-
- This function takes n_splits to enable splitting a batch
- into multiple sub-batches for the purposes of gradient accumulation,
- etc.
-
- Parameters
- ----------
- list_of_dicts : list
- List of dictionaries to be collated.
- n_splits : int
- Number of splits to make when creating the batches (split into
- sub-batches). Useful for things like gradient accumulation.
-
- Returns
- -------
- dict
- Dictionary containing batched data.
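-
- Examples
- --------
- A small sketch with tensor-only items:
-
- >>> items = [{"x": torch.zeros(3)} for _ in range(4)]
- >>> collate(items)["x"].shape
- torch.Size([4, 3])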
- """
-
- from . import AudioSignal
-
- batches = []
- list_len = len(list_of_dicts)
-
- return_list = False if n_splits is None else True
- n_splits = 1 if n_splits is None else n_splits
- n_items = int(math.ceil(list_len / n_splits))
-
- for i in range(0, list_len, n_items):
- # Flatten the dictionaries to avoid recursion.
- list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]]
- dict_of_lists = {
- k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0]
- }
-
- batch = {}
- for k, v in dict_of_lists.items():
- if isinstance(v, list):
- if all(isinstance(s, AudioSignal) for s in v):
- batch[k] = AudioSignal.batch(v, pad_signals=True)
- else:
- # Borrow the default collate fn from torch.
- batch[k] = torch.utils.data._utils.collate.default_collate(v)
- batches.append(unflatten(batch))
-
- batches = batches[0] if not return_list else batches
- return batches
-
-
-BASE_SIZE = 864
-DEFAULT_FIG_SIZE = (9, 3)
-
-
-def format_figure(
- fig_size: tuple = None,
- title: str = None,
- fig=None,
- format_axes: bool = True,
- format: bool = True,
- font_color: str = "white",
-):
- """Prettifies the spectrogram and waveform plots. A title
- can be inset into the top right corner, and the axes can be
- inset into the figure, allowing the data to take up the entire
- image. Used in
-
- - :py:func:`audiotools.core.display.DisplayMixin.specshow`
- - :py:func:`audiotools.core.display.DisplayMixin.waveplot`
- - :py:func:`audiotools.core.display.DisplayMixin.wavespec`
-
- Parameters
- ----------
- fig_size : tuple, optional
- Size of figure, by default (9, 3)
- title : str, optional
- Title to inset in top right, by default None
- fig : matplotlib.figure.Figure, optional
- Figure object, if None ``plt.gcf()`` will be used, by default None
- format_axes : bool, optional
- Format the axes to be inside the figure, by default True
- format : bool, optional
- This formatting can be skipped entirely by passing ``format=False``
- to any of the plotting functions that use this formatter, by default True
- font_color : str, optional
- Color of font of axes, by default "white"
- """
- import matplotlib
- import matplotlib.pyplot as plt
-
- if fig_size is None:
- fig_size = DEFAULT_FIG_SIZE
- if not format:
- return
- if fig is None:
- fig = plt.gcf()
- fig.set_size_inches(*fig_size)
- axs = fig.axes
-
- pixels = (fig.get_size_inches() * fig.dpi)[0]
- font_scale = pixels / BASE_SIZE
-
- if format_axes:
- axs = fig.axes
-
- for ax in axs:
- ymin, _ = ax.get_ylim()
- xmin, _ = ax.get_xlim()
-
- ticks = ax.get_yticks()
- for t in ticks[2:-1]:
- t = axs[0].annotate(
- f"{(t / 1000):2.1f}k",
- xy=(xmin, t),
- xycoords="data",
- xytext=(5, -5),
- textcoords="offset points",
- ha="left",
- va="top",
- color=font_color,
- fontsize=12 * font_scale,
- alpha=0.75,
- )
-
- ticks = ax.get_xticks()[2:]
- for t in ticks[:-1]:
- t = axs[0].annotate(
- f"{t:2.1f}s",
- xy=(t, ymin),
- xycoords="data",
- xytext=(5, 5),
- textcoords="offset points",
- ha="center",
- va="bottom",
- color=font_color,
- fontsize=12 * font_scale,
- alpha=0.75,
- )
-
- ax.margins(0, 0)
- ax.set_axis_off()
- ax.xaxis.set_major_locator(plt.NullLocator())
- ax.yaxis.set_major_locator(plt.NullLocator())
-
- plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
-
- if title is not None:
- t = axs[0].annotate(
- title,
- xy=(1, 1),
- xycoords="axes fraction",
- fontsize=20 * font_scale,
- xytext=(-5, -5),
- textcoords="offset points",
- ha="right",
- va="top",
- color="white",
- )
- t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
-
-
-def generate_chord_dataset(
- max_voices: int = 8,
- sample_rate: int = 44100,
- num_items: int = 5,
- duration: float = 1.0,
- min_note: str = "C2",
- max_note: str = "C6",
- output_dir: Path = "chords",
-):
- """
- Generates a toy multitrack dataset of chords, synthesized from sine waves.
-
-
- Parameters
- ----------
- max_voices : int, optional
- Maximum number of voices in a chord, by default 8
- sample_rate : int, optional
- Sample rate of audio, by default 44100
- num_items : int, optional
- Number of items to generate, by default 5
- duration : float, optional
- Duration of each item, by default 1.0
- min_note : str, optional
- Minimum note in the dataset, by default "C2"
- max_note : str, optional
- Maximum note in the dataset, by default "C6"
- output_dir : Path, optional
- Directory to save the dataset, by default "chords"
-
- """
- import librosa
- from . import AudioSignal
- from ..data.preprocess import create_csv
-
- min_midi = librosa.note_to_midi(min_note)
- max_midi = librosa.note_to_midi(max_note)
-
- tracks = []
- for idx in range(num_items):
- track = {}
- # figure out how many voices to put in this track
- num_voices = random.randint(1, max_voices)
- for voice_idx in range(num_voices):
- # choose some random params
- midinote = random.randint(min_midi, max_midi)
- dur = random.uniform(0.85 * duration, duration)
-
- sig = AudioSignal.wave(
- frequency=librosa.midi_to_hz(midinote),
- duration=dur,
- sample_rate=sample_rate,
- shape="sine",
- )
- track[f"voice_{voice_idx}"] = sig
- tracks.append(track)
-
- # save the tracks to disk
- output_dir = Path(output_dir)
- output_dir.mkdir(exist_ok=True)
- for idx, track in enumerate(tracks):
- track_dir = output_dir / f"track_{idx}"
- track_dir.mkdir(exist_ok=True)
- for voice_name, sig in track.items():
- sig.write(track_dir / f"{voice_name}.wav")
-
- all_voices = list(set([k for track in tracks for k in track.keys()]))
- voice_lists = {voice: [] for voice in all_voices}
- for track in tracks:
- for voice_name in all_voices:
- if voice_name in track:
- voice_lists[voice_name].append(track[voice_name].path_to_file)
- else:
- voice_lists[voice_name].append("")
-
- for voice_name, paths in voice_lists.items():
- create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)
-
- return output_dir
diff --git a/dito/models/ldm/dac/audiotools/core/whisper.py b/dito/models/ldm/dac/audiotools/core/whisper.py
deleted file mode 100644
index 46c071f934fc3e2be3138e7596b1c6d2ef79eade..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/core/whisper.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import torch
-
-
-class WhisperMixin:
- is_initialized = False
-
- def setup_whisper(
- self,
- pretrained_model_name_or_path: str = "openai/whisper-base.en",
- device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
- ):
- from transformers import WhisperForConditionalGeneration
- from transformers import WhisperProcessor
-
- self.whisper_device = device
- self.whisper_processor = WhisperProcessor.from_pretrained(
- pretrained_model_name_or_path
- )
- self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
- pretrained_model_name_or_path
- ).to(self.whisper_device)
- self.is_initialized = True
-
- def get_whisper_features(self) -> torch.Tensor:
- """Preprocess audio signal as per the whisper model's training config.
-
- Returns
- -------
- torch.Tensor
- The preprocessed input features of the audio signal. Shape: (batch, n_mels, n_frames)
- """
- import torch
-
- if not self.is_initialized:
- self.setup_whisper()
-
- signal = self.to(self.device)
- raw_speech = list(
- (
- signal.clone()
- .resample(self.whisper_processor.feature_extractor.sampling_rate)
- .audio_data[:, 0, :]
- .numpy()
- )
- )
-
- with torch.inference_mode():
- input_features = self.whisper_processor(
- raw_speech,
- sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
- return_tensors="pt",
- ).input_features
-
- return input_features
-
- def get_whisper_transcript(self) -> str:
- """Get the transcript of the audio signal using the whisper model.
-
- Returns
- -------
- str
- The transcript of the audio signal, including special tokens such as <|startoftranscript|> and <|endoftext|>.
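-
- Examples
- --------
- A hypothetical sketch, assuming the mixin is part of an AudioSignal ``signal``:
-
- >>> text = signal.get_whisper_transcript()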
- """
-
- if not self.is_initialized:
- self.setup_whisper()
-
- input_features = self.get_whisper_features()
-
- with torch.inference_mode():
- input_features = input_features.to(self.whisper_device)
- generated_ids = self.whisper_model.generate(inputs=input_features)
-
- transcription = self.whisper_processor.batch_decode(generated_ids)
- return transcription[0]
-
- def get_whisper_embeddings(self) -> torch.Tensor:
- """Get the last hidden state embeddings of the audio signal using the whisper model.
-
- Returns
- -------
- torch.Tensor
- The Whisper embeddings of the audio signal. Shape: (1, seq_len, hidden_size)
- """
- import torch
-
- if not self.is_initialized:
- self.setup_whisper()
-
- input_features = self.get_whisper_features()
- encoder = self.whisper_model.get_encoder()
-
- with torch.inference_mode():
- input_features = input_features.to(self.whisper_device)
- embeddings = encoder(input_features)
-
- return embeddings.last_hidden_state
diff --git a/dito/models/ldm/dac/audiotools/data/__init__.py b/dito/models/ldm/dac/audiotools/data/__init__.py
deleted file mode 100644
index aead269f26f3782043e68418b4c87ee323cbd015..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/data/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from . import datasets
-from . import preprocess
-from . import transforms
diff --git a/dito/models/ldm/dac/audiotools/data/datasets.py b/dito/models/ldm/dac/audiotools/data/datasets.py
deleted file mode 100644
index 12e7a60963399aa15ff865de2d06537818ce18ee..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/data/datasets.py
+++ /dev/null
@@ -1,517 +0,0 @@
-from pathlib import Path
-from typing import Callable
-from typing import Dict
-from typing import List
-from typing import Union
-
-import numpy as np
-from torch.utils.data import SequentialSampler
-from torch.utils.data.distributed import DistributedSampler
-
-from ..core import AudioSignal
-from ..core import util
-
-
-class AudioLoader:
- """Loads audio endlessly from a list of audio sources
- containing paths to audio files. Audio sources can be
- folders full of audio files (which are found via file
- extension) or CSV files which contain paths
- to audio files.
-
- Parameters
- ----------
- sources : List[str], optional
- Sources containing folders, or CSVs with
- paths to audio files, by default None
- weights : List[float], optional
- Weights to sample audio files from each source, by default None
- relative_path : str, optional
- Path audio should be loaded relative to, by default ""
- transform : Callable, optional
- Transform to instantiate alongside audio sample,
- by default None
- ext : List[str]
- List of extensions used to find audio within each source. Can
- also be a file name (e.g. "vocals.wav"), by default
- ``['.wav', '.flac', '.mp3', '.mp4']``.
- shuffle: bool
- Whether to shuffle the files within the dataloader. Defaults to True.
- shuffle_state: int
- State to use to seed the shuffle of the files.
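-
- Examples
- --------
- A hypothetical sketch (folder path is illustrative):
-
- >>> loader = AudioLoader(sources=["audio/speech"])
- >>> state = util.random_state(0)
- >>> item = loader(state, sample_rate=44100, duration=1.0)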
- """
-
- def __init__(
- self,
- sources: List[str] = None,
- weights: List[float] = None,
- transform: Callable = None,
- relative_path: str = "",
- ext: List[str] = util.AUDIO_EXTENSIONS,
- shuffle: bool = True,
- shuffle_state: int = 0,
- ):
- self.audio_lists = util.read_sources(
- sources, relative_path=relative_path, ext=ext
- )
-
- self.audio_indices = [
- (src_idx, item_idx)
- for src_idx, src in enumerate(self.audio_lists)
- for item_idx in range(len(src))
- ]
- if shuffle:
- state = util.random_state(shuffle_state)
- state.shuffle(self.audio_indices)
-
- self.sources = sources
- self.weights = weights
- self.transform = transform
-
- def __call__(
- self,
- state,
- sample_rate: int,
- duration: float,
- loudness_cutoff: float = -40,
- num_channels: int = 1,
- offset: float = None,
- source_idx: int = None,
- item_idx: int = None,
- global_idx: int = None,
- ):
- if source_idx is not None and item_idx is not None:
- try:
- audio_info = self.audio_lists[source_idx][item_idx]
- except:
- audio_info = {"path": "none"}
- elif global_idx is not None:
- source_idx, item_idx = self.audio_indices[
- global_idx % len(self.audio_indices)
- ]
- audio_info = self.audio_lists[source_idx][item_idx]
- else:
- audio_info, source_idx, item_idx = util.choose_from_list_of_lists(
- state, self.audio_lists, p=self.weights
- )
-
- path = audio_info["path"]
- signal = AudioSignal.zeros(duration, sample_rate, num_channels)
-
- if path != "none":
- if offset is None:
- signal = AudioSignal.salient_excerpt(
- path,
- duration=duration,
- state=state,
- loudness_cutoff=loudness_cutoff,
- )
- else:
- signal = AudioSignal(
- path,
- offset=offset,
- duration=duration,
- )
-
- if num_channels == 1:
- signal = signal.to_mono()
- signal = signal.resample(sample_rate)
-
- if signal.duration < duration:
- signal = signal.zero_pad_to(int(duration * sample_rate))
-
- for k, v in audio_info.items():
- signal.metadata[k] = v
-
- item = {
- "signal": signal,
- "source_idx": source_idx,
- "item_idx": item_idx,
- "source": str(self.sources[source_idx]),
- "path": str(path),
- }
- if self.transform is not None:
- item["transform_args"] = self.transform.instantiate(state, signal=signal)
- return item
-
-
-def default_matcher(x, y):
- return Path(x).parent == Path(y).parent
-
-
-def align_lists(lists, matcher: Callable = default_matcher):
- longest_list = lists[np.argmax([len(l) for l in lists])]
- for i, x in enumerate(longest_list):
- for l in lists:
- if i >= len(l):
- l.append({"path": "none"})
- elif not matcher(l[i]["path"], x["path"]):
- l.insert(i, {"path": "none"})
- return lists
-
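- # Illustrative sketch of the alignment behavior (hypothetical paths): entries
- # whose parent directory does not match the longest list's entry at the same
- # position get a {"path": "none"} placeholder inserted, so that index i refers
- # to the same track across all lists:
- #
- #     vocals = [{"path": "song_a/vocals.wav"}, {"path": "song_b/vocals.wav"}]
- #     drums = [{"path": "song_b/drums.wav"}]
- #     align_lists([vocals, drums])
- #     [d["path"] for d in drums]  # -> ['none', 'song_b/drums.wav']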
-
-class AudioDataset:
- """Loads audio from multiple loaders (with associated transforms)
- for a specified number of samples. Excerpts are drawn randomly
- of the specified duration, above a specified loudness threshold
- and are resampled on the fly to the desired sample rate
- (if it is different from the audio source sample rate).
-
- This takes either a single AudioLoader object,
- a list of AudioLoader objects, or a dictionary of AudioLoader
- objects. Each AudioLoader is called by the dataset, and the
- result is placed in the output dictionary. A transform can also be
- specified for the entire dataset, rather than for each specific
- loader. This transform can be applied to the output of all the
- loaders if desired.
-
- AudioLoader objects can be specified as aligned, which means the
- loaders correspond to multitrack audio (e.g. a vocals, bass,
- drums, and other loader for multitrack music mixtures).
-
-
- Parameters
- ----------
- loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]]
- AudioLoaders to sample audio from.
- sample_rate : int
- Desired sample rate.
- n_examples : int, optional
- Number of examples (length of dataset), by default 1000
- duration : float, optional
- Duration of audio samples, by default 0.5
- loudness_cutoff : float, optional
- Loudness cutoff threshold for audio samples, by default -40
- num_channels : int, optional
- Number of channels in output audio, by default 1
- transform : Callable, optional
- Transform to instantiate alongside each dataset item, by default None
- aligned : bool, optional
- Whether the loaders should be sampled in an aligned manner (e.g. same
- offset, duration, and matched file name), by default False
- shuffle_loaders : bool, optional
- Whether to shuffle the loaders before sampling from them, by default False
- matcher : Callable
- How to match files from adjacent audio lists (e.g. for a multitrack audio loader),
- by default uses the parent directory of each file.
- without_replacement : bool
- Whether to choose files with or without replacement, by default True.
-
-
- Examples
- --------
- >>> from audiotools.data.datasets import AudioLoader
- >>> from audiotools.data.datasets import AudioDataset
- >>> from audiotools import transforms as tfm
- >>> import numpy as np
- >>>
- >>> loaders = [
- >>> AudioLoader(
- >>> sources=[f"tests/audio/spk"],
- >>> transform=tfm.Equalizer(),
- >>> ext=["wav"],
- >>> )
- >>> for i in range(5)
- >>> ]
- >>>
- >>> dataset = AudioDataset(
- >>> loaders = loaders,
- >>> sample_rate = 44100,
- >>> duration = 1.0,
- >>> transform = tfm.RescaleAudio(),
- >>> )
- >>>
- >>> item = dataset[np.random.randint(len(dataset))]
- >>>
- >>> for i in range(len(loaders)):
- >>> item[i]["signal"] = loaders[i].transform(
- >>> item[i]["signal"], **item[i]["transform_args"]
- >>> )
- >>> item[i]["signal"].widget(i)
- >>>
- >>> mix = sum([item[i]["signal"] for i in range(len(loaders))])
- >>> mix = dataset.transform(mix, **item["transform_args"])
- >>> mix.widget("mix")
-
- Below is an example of how one could load MUSDB multitrack data:
-
- >>> import audiotools as at
- >>> from pathlib import Path
- >>> from audiotools import transforms as tfm
- >>> import numpy as np
- >>> import torch
- >>>
- >>> def build_dataset(
- >>> sample_rate: int = 44100,
- >>> duration: float = 5.0,
- >>> musdb_path: str = "~/.data/musdb/",
- >>> ):
- >>> musdb_path = Path(musdb_path).expanduser()
- >>> loaders = {
- >>> src: at.datasets.AudioLoader(
- >>> sources=[musdb_path],
- >>> transform=tfm.Compose(
- >>> tfm.VolumeNorm(("uniform", -20, -10)),
- >>> tfm.Silence(prob=0.1),
- >>> ),
- >>> ext=[f"{src}.wav"],
- >>> )
- >>> for src in ["vocals", "bass", "drums", "other"]
- >>> }
- >>>
- >>> dataset = at.datasets.AudioDataset(
- >>> loaders=loaders,
- >>> sample_rate=sample_rate,
- >>> duration=duration,
- >>> num_channels=1,
- >>> aligned=True,
- >>> transform=tfm.RescaleAudio(),
- >>> shuffle_loaders=True,
- >>> )
- >>> return dataset, list(loaders.keys())
- >>>
- >>> train_data, sources = build_dataset()
- >>> dataloader = torch.utils.data.DataLoader(
- >>> train_data,
- >>> batch_size=16,
- >>> num_workers=0,
- >>> collate_fn=train_data.collate,
- >>> )
- >>> batch = next(iter(dataloader))
- >>>
- >>> for k in sources:
- >>> src = batch[k]
- >>> src["transformed"] = train_data.loaders[k].transform(
- >>> src["signal"].clone(), **src["transform_args"]
- >>> )
- >>>
- >>> mixture = sum(batch[k]["transformed"] for k in sources)
- >>> mixture = train_data.transform(mixture, **batch["transform_args"])
- >>>
- >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time).
- >>> # Construct the targets:
- >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1)
-
- Similarly, here's example code for loading Slakh data:
-
- >>> import audiotools as at
- >>> from pathlib import Path
- >>> from audiotools import transforms as tfm
- >>> import numpy as np
- >>> import torch
- >>> import glob
- >>>
- >>> def build_dataset(
- >>> sample_rate: int = 16000,
- >>> duration: float = 10.0,
- >>> slakh_path: str = "~/.data/slakh/",
- >>> ):
- >>> slakh_path = Path(slakh_path).expanduser()
- >>>
- >>> # Find the max number of sources in Slakh
- >>> src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)]
- >>> n_sources = len(list(set(src_names)))
- >>>
- >>> loaders = {
- >>> f"S{i:02d}": at.datasets.AudioLoader(
- >>> sources=[slakh_path],
- >>> transform=tfm.Compose(
- >>> tfm.VolumeNorm(("uniform", -20, -10)),
- >>> tfm.Silence(prob=0.1),
- >>> ),
- >>> ext=[f"S{i:02d}.wav"],
- >>> )
- >>> for i in range(n_sources)
- >>> }
- >>> dataset = at.datasets.AudioDataset(
- >>> loaders=loaders,
- >>> sample_rate=sample_rate,
- >>> duration=duration,
- >>> num_channels=1,
- >>> aligned=True,
- >>> transform=tfm.RescaleAudio(),
- >>> shuffle_loaders=False,
- >>> )
- >>>
- >>> return dataset, list(loaders.keys())
- >>>
- >>> train_data, sources = build_dataset()
- >>> dataloader = torch.utils.data.DataLoader(
- >>> train_data,
- >>> batch_size=16,
- >>> num_workers=0,
- >>> collate_fn=train_data.collate,
- >>> )
- >>> batch = next(iter(dataloader))
- >>>
- >>> for k in sources:
- >>> src = batch[k]
- >>> src["transformed"] = train_data.loaders[k].transform(
- >>> src["signal"].clone(), **src["transform_args"]
- >>> )
- >>>
- >>> mixture = sum(batch[k]["transformed"] for k in sources)
- >>> mixture = train_data.transform(mixture, **batch["transform_args"])
-
- """
-
- def __init__(
- self,
- loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]],
- sample_rate: int,
- n_examples: int = 1000,
- duration: float = 0.5,
- offset: float = None,
- loudness_cutoff: float = -40,
- num_channels: int = 1,
- transform: Callable = None,
- aligned: bool = False,
- shuffle_loaders: bool = False,
- matcher: Callable = default_matcher,
- without_replacement: bool = True,
- ):
- # Internally we convert loaders to a dictionary
- if isinstance(loaders, list):
- loaders = {i: l for i, l in enumerate(loaders)}
- elif isinstance(loaders, AudioLoader):
- loaders = {0: loaders}
-
- self.loaders = loaders
- self.loudness_cutoff = loudness_cutoff
- self.num_channels = num_channels
-
- self.length = n_examples
- self.transform = transform
- self.sample_rate = sample_rate
- self.duration = duration
- self.offset = offset
- self.aligned = aligned
- self.shuffle_loaders = shuffle_loaders
- self.without_replacement = without_replacement
-
- if aligned:
- loaders_list = list(loaders.values())
- for i in range(len(loaders_list[0].audio_lists)):
- input_lists = [l.audio_lists[i] for l in loaders_list]
- # Alignment happens in-place
- align_lists(input_lists, matcher)
-
- def __getitem__(self, idx):
- state = util.random_state(idx)
- offset = None if self.offset is None else self.offset
- item = {}
-
- keys = list(self.loaders.keys())
- if self.shuffle_loaders:
- state.shuffle(keys)
-
- loader_kwargs = {
- "state": state,
- "sample_rate": self.sample_rate,
- "duration": self.duration,
- "loudness_cutoff": self.loudness_cutoff,
- "num_channels": self.num_channels,
- "global_idx": idx if self.without_replacement else None,
- }
-
- # Draw item from first loader
- loader = self.loaders[keys[0]]
- item[keys[0]] = loader(**loader_kwargs)
-
- for key in keys[1:]:
- loader = self.loaders[key]
- if self.aligned:
- # When aligned, reuse the offset and file indices returned by
- # the first loader so every loader draws the matching excerpt.
- offset = item[keys[0]]["signal"].metadata["offset"]
- loader_kwargs.update(
- {
- "offset": offset,
- "source_idx": item[keys[0]]["source_idx"],
- "item_idx": item[keys[0]]["item_idx"],
- }
- )
- item[key] = loader(**loader_kwargs)
-
- # Sort dictionary back into original order
- keys = list(self.loaders.keys())
- item = {k: item[k] for k in keys}
-
- item["idx"] = idx
- if self.transform is not None:
- item["transform_args"] = self.transform.instantiate(
- state=state, signal=item[keys[0]]["signal"]
- )
-
- # If there's only one loader, pop it up
- # to the main dictionary, instead of keeping it
- # nested.
- if len(keys) == 1:
- item.update(item.pop(keys[0]))
-
- return item
-
- def __len__(self):
- return self.length
-
- @staticmethod
- def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
- """Collates items drawn from this dataset. Uses
- :py:func:`audiotools.core.util.collate`.
-
- Parameters
- ----------
- list_of_dicts : typing.Union[list, dict]
- Data drawn from each item.
- n_splits : int
- Number of splits to make when creating the batches (split into
- sub-batches). Useful for things like gradient accumulation.
-
- Returns
- -------
- dict
- Dictionary of batched data.
- """
- return util.collate(list_of_dicts, n_splits=n_splits)
-
-
-class ConcatDataset(AudioDataset):
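- """Concatenates multiple datasets, interleaving them: item ``idx`` is drawn
- from dataset ``idx % len(datasets)`` at position ``idx // len(datasets)``,
- and the total length is the sum of the individual dataset lengths."""
-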
- def __init__(self, datasets: list):
- self.datasets = datasets
-
- def __len__(self):
- return sum([len(d) for d in self.datasets])
-
- def __getitem__(self, idx):
- dataset = self.datasets[idx % len(self.datasets)]
- return dataset[idx // len(self.datasets)]
-
-
-class ResumableDistributedSampler(DistributedSampler): # pragma: no cover
- """Distributed sampler that can be resumed from a given start index."""
-
- def __init__(self, dataset, start_idx: int = None, **kwargs):
- super().__init__(dataset, **kwargs)
- # Start index; allows resuming an experiment from where it left off
- self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0
-
- def __iter__(self):
- for i, idx in enumerate(super().__iter__()):
- if i >= self.start_idx:
- yield idx
- self.start_idx = 0 # reset so the next epoch starts from the beginning
-
-
-class ResumableSequentialSampler(SequentialSampler): # pragma: no cover
- """Sequential sampler that can be resumed from a given start index."""
-
- def __init__(self, dataset, start_idx: int = None, **kwargs):
- super().__init__(dataset, **kwargs)
- # Start index; allows resuming an experiment from where it left off
- self.start_idx = start_idx if start_idx is not None else 0
-
- def __iter__(self):
- for i, idx in enumerate(super().__iter__()):
- if i >= self.start_idx:
- yield idx
- self.start_idx = 0 # reset so the next epoch starts from the beginning
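-
-
- # Illustrative usage sketch (hypothetical variable names): to resume mid-epoch,
- # pass the number of examples already consumed as ``start_idx``; the sampler
- # skips that many indices on the first epoch and then behaves normally.
- #
- #     sampler = ResumableSequentialSampler(dataset, start_idx=examples_seen)
- #     loader = torch.utils.data.DataLoader(dataset, batch_size=16, sampler=sampler)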
diff --git a/dito/models/ldm/dac/audiotools/data/preprocess.py b/dito/models/ldm/dac/audiotools/data/preprocess.py
deleted file mode 100644
index d90de210115e45838bc8d69b350f7516ba730406..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/data/preprocess.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import csv
-import os
-from pathlib import Path
-
-from tqdm import tqdm
-
-from ..core import AudioSignal
-
-
-def create_csv(
- audio_files: list, output_csv: Path, loudness: bool = False, data_path: str = None
-):
- """Converts a folder of audio files to a CSV file. If ``loudness = True``,
- the output of this function will create a CSV file that looks something
- like:
-
- .. csv-table::
- :header: path,loudness
-
- daps/produced/f1_script1_produced.wav,-16.299999237060547
- daps/produced/f1_script2_produced.wav,-16.600000381469727
- daps/produced/f1_script3_produced.wav,-17.299999237060547
- daps/produced/f1_script4_produced.wav,-16.100000381469727
- daps/produced/f1_script5_produced.wav,-16.700000762939453
- daps/produced/f3_script1_produced.wav,-16.5
-
- .. note::
- The paths above are written relative to the ``data_path`` argument
- which defaults to the environment variable ``PATH_TO_DATA`` if
- it isn't passed to this function, and defaults to the empty string
- if that environment variable is not set.
-
- You can produce a CSV file from a directory of audio files via:
-
- >>> import audiotools
- >>> directory = ...
- >>> audio_files = audiotools.util.find_audio(directory)
- >>> output_csv = "train.csv"
- >>> audiotools.data.preprocess.create_csv(
- >>> audio_files, output_csv, loudness=True
- >>> )
-
- Note that you can create empty rows in the CSV file by passing an empty
- string in the ``audio_files`` list. This is useful if you want to
- sync multiple CSV files in a multitrack setting. The loudness of these
- empty rows will be set to -inf.
-
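- For example, two multitrack CSVs can be kept row-aligned when a stem is
- missing (illustrative paths):
-
- >>> audiotools.data.preprocess.create_csv(
- >>>     ["song_a/vocals.wav", "", "song_c/vocals.wav"], "vocals.csv"
- >>> )
- >>> audiotools.data.preprocess.create_csv(
- >>>     ["song_a/drums.wav", "song_b/drums.wav", ""], "drums.csv"
- >>> )
-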
- Parameters
- ----------
- audio_files : list
- List of audio files.
- output_csv : Path
- Output CSV, with each row containing the relative path of every file
- to ``data_path``, if specified.
- loudness : bool
- Compute loudness of entire file and store alongside path.
- data_path : str, optional
- Path that the audio file paths are written relative to (see the note
- above), by default None.
- """
-
- info = []
- pbar = tqdm(audio_files)
- for af in pbar:
- af = Path(af)
- pbar.set_description(f"Processing {af.name}")
- _info = {}
- if af.name == "":
- _info["path"] = ""
- if loudness:
- _info["loudness"] = -float("inf")
- else:
- _info["path"] = af.relative_to(data_path) if data_path is not None else af
- if loudness:
- _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item()
-
- info.append(_info)
-
- with open(output_csv, "w") as f:
- writer = csv.DictWriter(f, fieldnames=list(info[0].keys()))
- writer.writeheader()
-
- for item in info:
- writer.writerow(item)
diff --git a/dito/models/ldm/dac/audiotools/data/transforms.py b/dito/models/ldm/dac/audiotools/data/transforms.py
deleted file mode 100644
index 504e87dc61777e36ba95eb794f497bed4cdc7d2c..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/data/transforms.py
+++ /dev/null
@@ -1,1592 +0,0 @@
-import copy
-from contextlib import contextmanager
-from inspect import signature
-from typing import List
-
-import numpy as np
-import torch
-from flatten_dict import flatten
-from flatten_dict import unflatten
-from numpy.random import RandomState
-
-from .. import ml
-from ..core import AudioSignal
-from ..core import util
-from .datasets import AudioLoader
-
-tt = torch.tensor
-"""Shorthand for converting things to torch.tensor."""
-
-
-class BaseTransform:
- """This is the base class for all transforms that are implemented
- in this library. Transforms have two main operations: ``transform``
- and ``instantiate``.
-
- ``instantiate`` sets the parameters randomly
- from distribution tuples for each parameter. For example, for the
- ``BackgroundNoise`` transform, the signal-to-noise ratio (``snr``)
- is chosen randomly by instantiate. By default, it is chosen uniformly
- between 10.0 and 30.0 (the tuple is set to ``("uniform", 10.0, 30.0)``).
-
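- The same convention is used throughout this module: ``("const", value)``
- always returns ``value``, ``("uniform", lo, hi)`` samples uniformly between
- ``lo`` and ``hi``, and ``("choice", [a, b, ...])`` picks one element from
- the list.
-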
- ``transform`` applies the transform using the instantiated parameters.
- A simple example is as follows:
-
- >>> seed = 0
- >>> signal = ...
- >>> transform = transforms.NoiseFloor(db = ("uniform", -50.0, -30.0))
- >>> kwargs = transform.instantiate()
- >>> output = transform(signal.clone(), **kwargs)
-
- By breaking apart the instantiation of parameters from the actual audio
- processing of the transform, we can make things more reproducible, while
- also applying the transform on batches of data efficiently on GPU,
- rather than on individual audio samples.
-
- .. note::
- We call ``signal.clone()`` for the input to the ``transform`` function
- because signals are modified in-place! If you don't clone the signal,
- you will lose the original data.
-
- Parameters
- ----------
- keys : list, optional
- Keys that the transform looks for when
- calling ``self.transform``, by default []. In general this is
- set automatically, and you won't need to manipulate this argument.
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
-
- Examples
- --------
-
- >>> seed = 0
- >>>
- >>> audio_path = "tests/audio/spk/f10_script4_produced.wav"
- >>> signal = AudioSignal(audio_path, offset=10, duration=2)
- >>> transform = tfm.Compose(
- >>> [
- >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
- >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
- >>> ],
- >>> )
- >>>
- >>> kwargs = transform.instantiate(seed, signal)
- >>> output = transform(signal, **kwargs)
-
- """
-
- def __init__(self, keys: list = [], name: str = None, prob: float = 1.0):
- # Get keys from the _transform signature.
- tfm_keys = list(signature(self._transform).parameters.keys())
-
- # Filter out signal and kwargs keys.
- ignore_keys = ["signal", "kwargs"]
- tfm_keys = [k for k in tfm_keys if k not in ignore_keys]
-
- # Combine keys specified by the child class, the keys found in
- # _transform signature, and the mask key.
- self.keys = keys + tfm_keys + ["mask"]
-
- self.prob = prob
-
- if name is None:
- name = self.__class__.__name__
- self.name = name
-
- def _prepare(self, batch: dict):
- sub_batch = batch[self.name]
-
- for k in self.keys:
- assert k in sub_batch.keys(), f"{k} not in batch"
-
- return sub_batch
-
- def _transform(self, signal):
- return signal
-
- def _instantiate(self, state: RandomState, signal: AudioSignal = None):
- return {}
-
- @staticmethod
- def apply_mask(batch: dict, mask: torch.Tensor):
- """Applies a mask to the batch.
-
- Parameters
- ----------
- batch : dict
- Batch whose values will be masked in the ``transform`` pass.
- mask : torch.Tensor
- Mask to apply to batch.
-
- Returns
- -------
- dict
- A dictionary that contains values only where ``mask = True``.
- """
- masked_batch = {k: v[mask] for k, v in flatten(batch).items()}
- return unflatten(masked_batch)
-
- def transform(self, signal: AudioSignal, **kwargs):
- """Apply the transform to the audio signal,
- with given keyword arguments.
-
- Parameters
- ----------
- signal : AudioSignal
- Signal that will be modified by the transforms in-place.
- kwargs: dict
- Keyword arguments to the specific transforms ``self._transform``
- function.
-
- Returns
- -------
- AudioSignal
- Transformed AudioSignal.
-
- Examples
- --------
-
- >>> for seed in range(10):
- >>> kwargs = transform.instantiate(seed, signal)
- >>> output = transform(signal.clone(), **kwargs)
-
- """
- tfm_kwargs = self._prepare(kwargs)
- mask = tfm_kwargs["mask"]
-
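- # ``mask`` marks which batch items are actually transformed; it is drawn
- # against ``self.prob`` during ``instantiate``, so items with mask False
- # pass through unchanged.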
- if torch.any(mask):
- tfm_kwargs = self.apply_mask(tfm_kwargs, mask)
- tfm_kwargs = {k: v for k, v in tfm_kwargs.items() if k != "mask"}
- signal[mask] = self._transform(signal[mask], **tfm_kwargs)
-
- return signal
-
- def __call__(self, *args, **kwargs):
- return self.transform(*args, **kwargs)
-
- def instantiate(
- self,
- state: RandomState = None,
- signal: AudioSignal = None,
- ):
- """Instantiates parameters for the transform.
-
- Parameters
- ----------
- state : RandomState, optional
- Random state (or seed) used to instantiate the parameters, by default None
- signal : AudioSignal, optional
- Signal passed to ``self._instantiate`` if this transform needs it, by default None
-
- Returns
- -------
- dict
- Dictionary containing instantiated arguments for every keyword
- argument to ``self._transform``.
-
- Examples
- --------
-
- >>> for seed in range(10):
- >>> kwargs = transform.instantiate(seed, signal)
- >>> output = transform(signal.clone(), **kwargs)
-
- """
- state = util.random_state(state)
-
- # Not all instantiates need the signal. Check if signal
- # is needed before passing it in, so that the end-user
- # doesn't need to have variables they're not using flowing
- # into their function.
- needs_signal = "signal" in set(signature(self._instantiate).parameters.keys())
- kwargs = {}
- if needs_signal:
- kwargs = {"signal": signal}
-
- # Instantiate the parameters for the transform.
- params = self._instantiate(state, **kwargs)
- for k in list(params.keys()):
- v = params[k]
- if isinstance(v, (AudioSignal, torch.Tensor, dict)):
- params[k] = v
- else:
- params[k] = tt(v)
- mask = state.rand() <= self.prob
- params[f"mask"] = tt(mask)
-
- # Put the params into a nested dictionary that will be
- # used later when calling the transform. This is to avoid
- # collisions in the dictionary.
- params = {self.name: params}
-
- return params
-
- def batch_instantiate(
- self,
- states: list = None,
- signal: AudioSignal = None,
- ):
- """Instantiates arguments for every item in a batch,
- given a list of states. Each state in the list
- corresponds to one item in the batch.
-
- Parameters
- ----------
- states : list, optional
- List of states, by default None
- signal : AudioSignal, optional
- AudioSignal to pass to the ``self.instantiate`` section
- if it is needed for this transform, by default None
-
- Returns
- -------
- dict
- Collated dictionary of arguments.
-
- Examples
- --------
-
- >>> batch_size = 4
- >>> signal = AudioSignal(audio_path, offset=10, duration=2)
- >>> signal_batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)])
- >>>
- >>> states = [seed + idx for idx in list(range(batch_size))]
- >>> kwargs = transform.batch_instantiate(states, signal_batch)
- >>> batch_output = transform(signal_batch, **kwargs)
- """
- kwargs = []
- for state in states:
- kwargs.append(self.instantiate(state, signal))
- kwargs = util.collate(kwargs)
- return kwargs
-
-
-class Identity(BaseTransform):
- """This transform just returns the original signal."""
-
- pass
-
-
-class SpectralTransform(BaseTransform):
- """Spectral transforms require STFT data to exist, since manipulations
- of the STFT require the spectrogram. This just calls ``stft`` before
- the transform is called, and ``istft`` after the transform is called,
- so that the spectral manipulation is written back to the audio data.
- """
-
- def transform(self, signal, **kwargs):
- signal.stft()
- super().transform(signal, **kwargs)
- signal.istft()
- return signal
-
-
-class Compose(BaseTransform):
- """Compose applies transforms in sequence, one after the other. The
- transforms are passed in as positional arguments or as a list like so:
-
- >>> transform = tfm.Compose(
- >>> [
- >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
- >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
- >>> ],
- >>> )
-
- This will convolve the signal with a room impulse response, and then
- add background noise to the signal. Instantiate instantiates
- all the parameters for every transform in the transform list so the
- interface for using the Compose transform is the same as everything
- else:
-
- >>> kwargs = transform.instantiate()
- >>> output = transform(signal.clone(), **kwargs)
-
- Under the hood, ``Compose`` maps each transform to a unique name
- of the form ``{position}.{name}``, where ``position``
- is the index of the transform in the list. ``Compose`` can nest
- within other ``Compose`` transforms, like so:
-
- >>> preprocess = transforms.Compose(
- >>> tfm.GlobalVolumeNorm(),
- >>> tfm.CrossTalk(),
- >>> name="preprocess",
- >>> )
- >>> augment = transforms.Compose(
- >>> tfm.RoomImpulseResponse(),
- >>> tfm.BackgroundNoise(),
- >>> name="augment",
- >>> )
- >>> postprocess = transforms.Compose(
- >>> tfm.VolumeChange(),
- >>> tfm.RescaleAudio(),
- >>> tfm.ShiftPhase(),
- >>> name="postprocess",
- >>> )
- >>> transform = transforms.Compose(preprocess, augment, postprocess)
-
- This defines 3 composed transforms, and then composes them in sequence
- with one another.
-
- Parameters
- ----------
- *transforms : list
- List of transforms to apply
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(self, *transforms: list, name: str = None, prob: float = 1.0):
- if isinstance(transforms[0], list):
- transforms = transforms[0]
-
- for i, tfm in enumerate(transforms):
- tfm.name = f"{i}.{tfm.name}"
-
- keys = [tfm.name for tfm in transforms]
- super().__init__(keys=keys, name=name, prob=prob)
-
- self.transforms = transforms
- self.transforms_to_apply = keys
-
- @contextmanager
- def filter(self, *names: list):
- """This can be used to skip transforms entirely when applying
- the sequence of transforms to a signal. For example, take
- the following transforms with the names ``preprocess, augment, postprocess``.
-
- >>> preprocess = transforms.Compose(
- >>> tfm.GlobalVolumeNorm(),
- >>> tfm.CrossTalk(),
- >>> name="preprocess",
- >>> )
- >>> augment = transforms.Compose(
- >>> tfm.RoomImpulseResponse(),
- >>> tfm.BackgroundNoise(),
- >>> name="augment",
- >>> )
- >>> postprocess = transforms.Compose(
- >>> tfm.VolumeChange(),
- >>> tfm.RescaleAudio(),
- >>> tfm.ShiftPhase(),
- >>> name="postprocess",
- >>> )
- >>> transform = transforms.Compose(preprocess, augment, postprocess)
-
- If we wanted to apply all 3 to a signal, we do:
-
- >>> kwargs = transform.instantiate()
- >>> output = transform(signal.clone(), **kwargs)
-
- But if we only wanted to apply the ``preprocess`` and ``postprocess``
- transforms to the signal, we do:
-
- >>> with transform.filter("preprocess", "postprocess"):
- >>> output = transform(signal.clone(), **kwargs)
-
- Parameters
- ----------
- *names : list
- List of transforms, identified by name, to apply to signal.
- """
- old_transforms = self.transforms_to_apply
- self.transforms_to_apply = names
- yield
- self.transforms_to_apply = old_transforms
-
- def _transform(self, signal, **kwargs):
- for transform in self.transforms:
- if any([x in transform.name for x in self.transforms_to_apply]):
- signal = transform(signal, **kwargs)
- return signal
-
- def _instantiate(self, state: RandomState, signal: AudioSignal = None):
- parameters = {}
- for transform in self.transforms:
- parameters.update(transform.instantiate(state, signal=signal))
- return parameters
-
- def __getitem__(self, idx):
- return self.transforms[idx]
-
- def __len__(self):
- return len(self.transforms)
-
- def __iter__(self):
- for transform in self.transforms:
- yield transform
-
-
-class Choose(Compose):
- """Choose logic is the same as :py:func:`audiotools.data.transforms.Compose`,
- but instead of applying all the transforms in sequence, it applies just a single transform,
- which is chosen for each item in the batch.
-
- Parameters
- ----------
- *transforms : list
- List of transforms to apply
- weights : list
- Probability of choosing any specific transform.
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
-
- Examples
- --------
-
- >>> transforms.Choose(tfm.LowPass(), tfm.HighPass())
- """
-
- def __init__(
- self,
- *transforms: list,
- weights: list = None,
- name: str = None,
- prob: float = 1.0,
- ):
- super().__init__(*transforms, name=name, prob=prob)
-
- if weights is None:
- _len = len(self.transforms)
- weights = [1 / _len for _ in range(_len)]
- self.weights = np.array(weights)
-
- def _instantiate(self, state: RandomState, signal: AudioSignal = None):
- kwargs = super()._instantiate(state, signal)
- tfm_idx = list(range(len(self.transforms)))
- tfm_idx = state.choice(tfm_idx, p=self.weights)
- one_hot = []
- for i, t in enumerate(self.transforms):
- mask = kwargs[t.name]["mask"]
- if mask.item():
- kwargs[t.name]["mask"] = tt(i == tfm_idx)
- one_hot.append(kwargs[t.name]["mask"])
- kwargs["one_hot"] = one_hot
- return kwargs
-
-
-class Repeat(Compose):
- """Repeatedly applies a given transform ``n_repeat`` times."
-
- Parameters
- ----------
- transform : BaseTransform
- Transform to repeat.
- n_repeat : int, optional
- Number of times to repeat transform, by default 1
- """
-
- def __init__(
- self,
- transform,
- n_repeat: int = 1,
- name: str = None,
- prob: float = 1.0,
- ):
- transforms = [copy.copy(transform) for _ in range(n_repeat)]
- super().__init__(transforms, name=name, prob=prob)
-
- self.n_repeat = n_repeat
-
-
-class RepeatUpTo(Choose):
- """Repeatedly applies a given transform up to ``max_repeat`` times."
-
- Parameters
- ----------
- transform : BaseTransform
- Transform to repeat.
- max_repeat : int, optional
- Max number of times to repeat transform, by default 5
- weights : list
- Probability of choosing any specific number up to ``max_repeat``.
- """
-
- def __init__(
- self,
- transform,
- max_repeat: int = 5,
- weights: list = None,
- name: str = None,
- prob: float = 1.0,
- ):
- transforms = []
- for n in range(1, max_repeat):
- transforms.append(Repeat(transform, n_repeat=n))
- super().__init__(transforms, name=name, prob=prob, weights=weights)
-
- self.max_repeat = max_repeat
-
-
-class ClippingDistortion(BaseTransform):
- """Adds clipping distortion to signal. Corresponds
- to :py:func:`audiotools.core.effects.EffectMixin.clip_distortion`.
-
- Parameters
- ----------
- perc : tuple, optional
- Clipping percentile. Values are between 0.0 and 1.0.
- Typical values are 0.1 or below, by default ("uniform", 0.0, 0.1)
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- perc: tuple = ("uniform", 0.0, 0.1),
- name: str = None,
- prob: float = 1.0,
- ):
- super().__init__(name=name, prob=prob)
-
- self.perc = perc
-
- def _instantiate(self, state: RandomState):
- return {"perc": util.sample_from_dist(self.perc, state)}
-
- def _transform(self, signal, perc):
- return signal.clip_distortion(perc)
-
-
-class Equalizer(BaseTransform):
- """Applies an equalization curve to the audio signal. Corresponds
- to :py:func:`audiotools.core.effects.EffectMixin.equalizer`.
-
- Parameters
- ----------
- eq_amount : tuple, optional
- The maximum dB cut to apply to the audio in any band,
- by default ("const", 1.0 dB)
- n_bands : int, optional
- Number of bands in EQ, by default 6
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- eq_amount: tuple = ("const", 1.0),
- n_bands: int = 6,
- name: str = None,
- prob: float = 1.0,
- ):
- super().__init__(name=name, prob=prob)
-
- self.eq_amount = eq_amount
- self.n_bands = n_bands
-
- def _instantiate(self, state: RandomState):
- eq_amount = util.sample_from_dist(self.eq_amount, state)
- eq = -eq_amount * state.rand(self.n_bands)
- return {"eq": eq}
-
- def _transform(self, signal, eq):
- return signal.equalizer(eq)
-
-
-class Quantization(BaseTransform):
- """Applies quantization to the input waveform. Corresponds
- to :py:func:`audiotools.core.effects.EffectMixin.quantization`.
-
- Parameters
- ----------
- channels : tuple, optional
- Number of evenly spaced quantization channels to quantize
- to, by default ("choice", [8, 32, 128, 256, 1024])
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- channels: tuple = ("choice", [8, 32, 128, 256, 1024]),
- name: str = None,
- prob: float = 1.0,
- ):
- super().__init__(name=name, prob=prob)
-
- self.channels = channels
-
- def _instantiate(self, state: RandomState):
- return {"channels": util.sample_from_dist(self.channels, state)}
-
- def _transform(self, signal, channels):
- return signal.quantization(channels)
-
-
-class MuLawQuantization(BaseTransform):
- """Applies mu-law quantization to the input waveform. Corresponds
- to :py:func:`audiotools.core.effects.EffectMixin.mulaw_quantization`.
-
- Parameters
- ----------
- channels : tuple, optional
- Number of mu-law spaced quantization channels to quantize
- to, by default ("choice", [8, 32, 128, 256, 1024])
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- channels: tuple = ("choice", [8, 32, 128, 256, 1024]),
- name: str = None,
- prob: float = 1.0,
- ):
- super().__init__(name=name, prob=prob)
-
- self.channels = channels
-
- def _instantiate(self, state: RandomState):
- return {"channels": util.sample_from_dist(self.channels, state)}
-
- def _transform(self, signal, channels):
- return signal.mulaw_quantization(channels)
-
-
-class NoiseFloor(BaseTransform):
- """Adds a noise floor of Gaussian noise to the signal at a specified
- dB.
-
- Parameters
- ----------
- db : tuple, optional
- Level of noise to add to signal, by default ("const", -50.0)
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- db: tuple = ("const", -50.0),
- name: str = None,
- prob: float = 1.0,
- ):
- super().__init__(name=name, prob=prob)
-
- self.db = db
-
- def _instantiate(self, state: RandomState, signal: AudioSignal):
- db = util.sample_from_dist(self.db, state)
- audio_data = state.randn(signal.num_channels, signal.signal_length)
- nz_signal = AudioSignal(audio_data, signal.sample_rate)
- nz_signal.normalize(db)
- return {"nz_signal": nz_signal}
-
- def _transform(self, signal, nz_signal):
- # Add the noise floor generated in ``_instantiate`` to the signal.
- return signal + nz_signal
-
-
-class BackgroundNoise(BaseTransform):
- """Adds background noise from audio specified by a set of CSV files.
- A valid CSV file looks like, and is typically generated by
- :py:func:`audiotools.data.preprocess.create_csv`:
-
- .. csv-table::
- :header: path
-
- room_tone/m6_script2_clean.wav
- room_tone/m6_script2_cleanraw.wav
- room_tone/m6_script2_ipad_balcony1.wav
- room_tone/m6_script2_ipad_bedroom1.wav
- room_tone/m6_script2_ipad_confroom1.wav
- room_tone/m6_script2_ipad_confroom2.wav
- room_tone/m6_script2_ipad_livingroom1.wav
- room_tone/m6_script2_ipad_office1.wav
-
- .. note::
- All paths are relative to an environment variable called ``PATH_TO_DATA``,
- so that CSV files are portable across machines where data may be
- located in different places.
-
- This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix`
- and :py:func:`audiotools.core.effects.EffectMixin.equalizer` under the
- hood.
-
- Parameters
- ----------
- snr : tuple, optional
- Signal-to-noise ratio, by default ("uniform", 10.0, 30.0)
- sources : List[str], optional
- Sources containing folders, or CSVs with paths to audio files,
- by default None
- weights : List[float], optional
- Weights to sample audio files from each source, by default None
- eq_amount : tuple, optional
- Amount of equalization to apply, by default ("const", 1.0)
- n_bands : int, optional
- Number of bands in equalizer, by default 3
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- loudness_cutoff : float, optional
- Loudness cutoff when loading from audio files, by default None
- """
-
- def __init__(
- self,
- snr: tuple = ("uniform", 10.0, 30.0),
- sources: List[str] = None,
- weights: List[float] = None,
- eq_amount: tuple = ("const", 1.0),
- n_bands: int = 3,
- name: str = None,
- prob: float = 1.0,
- loudness_cutoff: float = None,
- ):
- super().__init__(name=name, prob=prob)
-
- self.snr = snr
- self.eq_amount = eq_amount
- self.n_bands = n_bands
- self.loader = AudioLoader(sources, weights)
- self.loudness_cutoff = loudness_cutoff
-
- def _instantiate(self, state: RandomState, signal: AudioSignal):
- eq_amount = util.sample_from_dist(self.eq_amount, state)
- eq = -eq_amount * state.rand(self.n_bands)
- snr = util.sample_from_dist(self.snr, state)
-
- bg_signal = self.loader(
- state,
- signal.sample_rate,
- duration=signal.signal_duration,
- loudness_cutoff=self.loudness_cutoff,
- num_channels=signal.num_channels,
- )["signal"]
-
- return {"eq": eq, "bg_signal": bg_signal, "snr": snr}
-
- def _transform(self, signal, bg_signal, snr, eq):
- # Clone bg_signal so that transform can be repeatedly applied
- # to different signals with the same effect.
- return signal.mix(bg_signal.clone(), snr, eq)
-
-
-class CrossTalk(BaseTransform):
- """Adds crosstalk between speakers, whose audio is drawn from a CSV file
- that was produced via :py:func:`audiotools.data.preprocess.create_csv`.
-
- This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix`
- under the hood.
-
- Parameters
- ----------
- snr : tuple, optional
- How loud the cross-talk speaker is relative to the original signal, in dB,
- by default ("uniform", 0.0, 10.0)
- sources : List[str], optional
- Sources containing folders, or CSVs with paths to audio files,
- by default None
- weights : List[float], optional
- Weights to sample audio files from each source, by default None
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- loudness_cutoff : float, optional
- Loudness cutoff when loading from audio files, by default -40
- """
-
- def __init__(
- self,
- snr: tuple = ("uniform", 0.0, 10.0),
- sources: List[str] = None,
- weights: List[float] = None,
- name: str = None,
- prob: float = 1.0,
- loudness_cutoff: float = -40,
- ):
- super().__init__(name=name, prob=prob)
-
- self.snr = snr
- self.loader = AudioLoader(sources, weights)
- self.loudness_cutoff = loudness_cutoff
-
- def _instantiate(self, state: RandomState, signal: AudioSignal):
- snr = util.sample_from_dist(self.snr, state)
- crosstalk_signal = self.loader(
- state,
- signal.sample_rate,
- duration=signal.signal_duration,
- loudness_cutoff=self.loudness_cutoff,
- num_channels=signal.num_channels,
- )["signal"]
-
- return {"crosstalk_signal": crosstalk_signal, "snr": snr}
-
- def _transform(self, signal, crosstalk_signal, snr):
- # Clone crosstalk_signal so that the transform can be repeatedly applied
- # to different signals with the same effect.
- loudness = signal.loudness()
- mix = signal.mix(crosstalk_signal.clone(), snr)
- mix.normalize(loudness)
- return mix
-
-
-class RoomImpulseResponse(BaseTransform):
- """Convolves signal with a room impulse response, at a specified
- direct-to-reverberant ratio, with equalization applied. Room impulse
- response data is drawn from a CSV file that was produced via
- :py:func:`audiotools.data.preprocess.create_csv`.
-
- This transform calls :py:func:`audiotools.core.effects.EffectMixin.apply_ir`
- under the hood.
-
- Parameters
- ----------
- drr : tuple, optional
- _description_, by default ("uniform", 0.0, 30.0)
- sources : List[str], optional
- Sources containing folders, or CSVs with paths to audio files,
- by default None
- weights : List[float], optional
- Weights to sample audio files from each source, by default None
- eq_amount : tuple, optional
- Amount of equalization to apply, by default ("const", 1.0)
- n_bands : int, optional
- Number of bands in equalizer, by default 6
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- use_original_phase : bool, optional
- Whether or not to use the original phase, by default False
- offset : float, optional
- Offset from each impulse response file to use, by default 0.0
- duration : float, optional
- Duration of each impulse response, by default 1.0
- """
-
- def __init__(
- self,
- drr: tuple = ("uniform", 0.0, 30.0),
- sources: List[str] = None,
- weights: List[float] = None,
- eq_amount: tuple = ("const", 1.0),
- n_bands: int = 6,
- name: str = None,
- prob: float = 1.0,
- use_original_phase: bool = False,
- offset: float = 0.0,
- duration: float = 1.0,
- ):
- super().__init__(name=name, prob=prob)
-
- self.drr = drr
- self.eq_amount = eq_amount
- self.n_bands = n_bands
- self.use_original_phase = use_original_phase
-
- self.loader = AudioLoader(sources, weights)
- self.offset = offset
- self.duration = duration
-
- def _instantiate(self, state: RandomState, signal: AudioSignal = None):
- eq_amount = util.sample_from_dist(self.eq_amount, state)
- eq = -eq_amount * state.rand(self.n_bands)
- drr = util.sample_from_dist(self.drr, state)
-
- ir_signal = self.loader(
- state,
- signal.sample_rate,
- offset=self.offset,
- duration=self.duration,
- loudness_cutoff=None,
- num_channels=signal.num_channels,
- )["signal"]
- ir_signal.zero_pad_to(signal.sample_rate)
-
- return {"eq": eq, "ir_signal": ir_signal, "drr": drr}
-
- def _transform(self, signal, ir_signal, drr, eq):
- # Clone ir_signal so that transform can be repeatedly applied
- # to different signals with the same effect.
- return signal.apply_ir(
- ir_signal.clone(), drr, eq, use_original_phase=self.use_original_phase
- )
-
-
-class VolumeChange(BaseTransform):
- """Changes the volume of the input signal.
-
- Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
-
- Parameters
- ----------
- db : tuple, optional
- Change in volume in decibels, by default ("uniform", -12.0, 0.0)
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- db: tuple = ("uniform", -12.0, 0.0),
- name: str = None,
- prob: float = 1.0,
- ):
- super().__init__(name=name, prob=prob)
- self.db = db
-
- def _instantiate(self, state: RandomState):
- return {"db": util.sample_from_dist(self.db, state)}
-
- def _transform(self, signal, db):
- return signal.volume_change(db)
-
-
-class VolumeNorm(BaseTransform):
- """Normalizes the volume of the excerpt to a specified decibel.
-
- Uses :py:func:`audiotools.core.effects.EffectMixin.normalize`.
-
- Parameters
- ----------
- db : tuple, optional
- dB to normalize signal to, by default ("const", -24)
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- db: tuple = ("const", -24),
- name: str = None,
- prob: float = 1.0,
- ):
- super().__init__(name=name, prob=prob)
-
- self.db = db
-
- def _instantiate(self, state: RandomState):
- return {"db": util.sample_from_dist(self.db, state)}
-
- def _transform(self, signal, db):
- return signal.normalize(db)
-
-
-class GlobalVolumeNorm(BaseTransform):
- """Similar to :py:func:`audiotools.data.transforms.VolumeNorm`, this
- transform also normalizes the volume of a signal, but it uses
- the volume of the entire audio file the loaded excerpt comes from,
- rather than the volume of just the excerpt. The volume of the
- entire audio file is expected in ``signal.metadata["loudness"]``.
- If loading audio from a CSV generated by :py:func:`audiotools.data.preprocess.create_csv`
- with ``loudness = True``, like the following:
-
- .. csv-table::
- :header: path,loudness
-
- daps/produced/f1_script1_produced.wav,-16.299999237060547
- daps/produced/f1_script2_produced.wav,-16.600000381469727
- daps/produced/f1_script3_produced.wav,-17.299999237060547
- daps/produced/f1_script4_produced.wav,-16.100000381469727
- daps/produced/f1_script5_produced.wav,-16.700000762939453
- daps/produced/f3_script1_produced.wav,-16.5
-
- The ``AudioLoader`` will automatically load the loudness column into
- the metadata of the signal.
-
- Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
-
- Parameters
- ----------
- db : tuple, optional
- dB to normalize signal to, by default ("const", -24)
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- db: tuple = ("const", -24),
- name: str = None,
- prob: float = 1.0,
- ):
- super().__init__(name=name, prob=prob)
-
- self.db = db
-
- def _instantiate(self, state: RandomState, signal: AudioSignal):
- if "loudness" not in signal.metadata:
- db_change = 0.0
- elif float(signal.metadata["loudness"]) == float("-inf"):
- db_change = 0.0
- else:
- db = util.sample_from_dist(self.db, state)
- db_change = db - float(signal.metadata["loudness"])
-
- return {"db": db_change}
-
- def _transform(self, signal, db):
- return signal.volume_change(db)
-
-
-class Silence(BaseTransform):
- """Zeros out the signal with some probability.
-
- Parameters
- ----------
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 0.1
- """
-
- def __init__(self, name: str = None, prob: float = 0.1):
- super().__init__(name=name, prob=prob)
-
- def _transform(self, signal):
- _loudness = signal._loudness
- signal = AudioSignal(
- torch.zeros_like(signal.audio_data),
- sample_rate=signal.sample_rate,
- stft_params=signal.stft_params,
- )
- # So that the amount of noise added is as if it wasn't silenced.
- # TODO: improve this hack
- signal._loudness = _loudness
-
- return signal
-
-
-class LowPass(BaseTransform):
- """Applies a LowPass filter.
-
- Uses :py:func:`audiotools.core.dsp.DSPMixin.low_pass`.
-
- Parameters
- ----------
- cutoff : tuple, optional
- Cutoff frequency distribution,
- by default ``("choice", [4000, 8000, 16000])``
- zeros : int, optional
- Number of zero-crossings in filter, argument to
- ``julius.LowPassFilters``, by default 51
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- cutoff: tuple = ("choice", [4000, 8000, 16000]),
- zeros: int = 51,
- name: str = None,
- prob: float = 1,
- ):
- super().__init__(name=name, prob=prob)
-
- self.cutoff = cutoff
- self.zeros = zeros
-
- def _instantiate(self, state: RandomState):
- return {"cutoff": util.sample_from_dist(self.cutoff, state)}
-
- def _transform(self, signal, cutoff):
- return signal.low_pass(cutoff, zeros=self.zeros)
-
-
-class HighPass(BaseTransform):
- """Applies a HighPass filter.
-
- Uses :py:func:`audiotools.core.dsp.DSPMixin.high_pass`.
-
- Parameters
- ----------
- cutoff : tuple, optional
- Cutoff frequency distribution,
- by default ``("choice", [50, 100, 250, 500, 1000])``
- zeros : int, optional
- Number of zero-crossings in filter, argument to
- ``julius.LowPassFilters``, by default 51
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- cutoff: tuple = ("choice", [50, 100, 250, 500, 1000]),
- zeros: int = 51,
- name: str = None,
- prob: float = 1,
- ):
- super().__init__(name=name, prob=prob)
-
- self.cutoff = cutoff
- self.zeros = zeros
-
- def _instantiate(self, state: RandomState):
- return {"cutoff": util.sample_from_dist(self.cutoff, state)}
-
- def _transform(self, signal, cutoff):
- return signal.high_pass(cutoff, zeros=self.zeros)
-
-
-class RescaleAudio(BaseTransform):
- """Rescales the audio so it is in between ``-val`` and ``val``
- only if the original audio exceeds those bounds. Useful if
- transforms have caused the audio to clip.
-
- Uses :py:func:`audiotools.core.effects.EffectMixin.ensure_max_of_audio`.
-
- Parameters
- ----------
- val : float, optional
- Max absolute value of signal, by default 1.0
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(self, val: float = 1.0, name: str = None, prob: float = 1):
- super().__init__(name=name, prob=prob)
-
- self.val = val
-
- def _transform(self, signal):
- return signal.ensure_max_of_audio(self.val)
-
-
-class ShiftPhase(SpectralTransform):
- """Shifts the phase of the audio.
-
- Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`.
-
- Parameters
- ----------
- shift : tuple, optional
- How much to shift phase by, by default ("uniform", -np.pi, np.pi)
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- shift: tuple = ("uniform", -np.pi, np.pi),
- name: str = None,
- prob: float = 1,
- ):
- super().__init__(name=name, prob=prob)
- self.shift = shift
-
- def _instantiate(self, state: RandomState):
- return {"shift": util.sample_from_dist(self.shift, state)}
-
- def _transform(self, signal, shift):
- return signal.shift_phase(shift)
-
-
-class InvertPhase(ShiftPhase):
- """Inverts the phase of the audio.
-
- Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`.
-
- Parameters
- ----------
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(self, name: str = None, prob: float = 1):
- super().__init__(shift=("const", np.pi), name=name, prob=prob)
-
-
-class CorruptPhase(SpectralTransform):
- """Corrupts the phase of the audio.
-
- Uses :py:func:`audiotools.core.dsp.DSPMixin.corrupt_phase`.
-
- Parameters
- ----------
- scale : tuple, optional
- How much to corrupt phase by, by default ("uniform", 0, np.pi)
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self, scale: tuple = ("uniform", 0, np.pi), name: str = None, prob: float = 1
- ):
- super().__init__(name=name, prob=prob)
- self.scale = scale
-
- def _instantiate(self, state: RandomState, signal: AudioSignal = None):
- scale = util.sample_from_dist(self.scale, state)
- corruption = state.normal(scale=scale, size=signal.phase.shape[1:])
- return {"corruption": corruption.astype("float32")}
-
- def _transform(self, signal, corruption):
- return signal.shift_phase(shift=corruption)
-
-
-class FrequencyMask(SpectralTransform):
- """Masks a band of frequencies at a center frequency
- from the audio.
-
- Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_frequencies`.
-
- Parameters
- ----------
- f_center : tuple, optional
- Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
- f_width : tuple, optional
- Width of zero'd out band, by default ("const", 0.1)
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- f_center: tuple = ("uniform", 0.0, 1.0),
- f_width: tuple = ("const", 0.1),
- name: str = None,
- prob: float = 1,
- ):
- super().__init__(name=name, prob=prob)
- self.f_center = f_center
- self.f_width = f_width
-
- def _instantiate(self, state: RandomState, signal: AudioSignal):
- f_center = util.sample_from_dist(self.f_center, state)
- f_width = util.sample_from_dist(self.f_width, state)
-
- fmin = max(f_center - (f_width / 2), 0.0)
- fmax = min(f_center + (f_width / 2), 1.0)
-
- fmin_hz = (signal.sample_rate / 2) * fmin
- fmax_hz = (signal.sample_rate / 2) * fmax
-
- return {"fmin_hz": fmin_hz, "fmax_hz": fmax_hz}
-
- def _transform(self, signal, fmin_hz: float, fmax_hz: float):
- return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
-
-
-class TimeMask(SpectralTransform):
- """Masks out contiguous time-steps from signal.
-
- Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_timesteps`.
-
- Parameters
- ----------
- t_center : tuple, optional
- Center time in terms of 0.0 and 1.0 (duration of signal),
- by default ("uniform", 0.0, 1.0)
- t_width : tuple, optional
- Width of dropped out portion, by default ("const", 0.025)
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- t_center: tuple = ("uniform", 0.0, 1.0),
- t_width: tuple = ("const", 0.025),
- name: str = None,
- prob: float = 1,
- ):
- super().__init__(name=name, prob=prob)
- self.t_center = t_center
- self.t_width = t_width
-
- def _instantiate(self, state: RandomState, signal: AudioSignal):
- t_center = util.sample_from_dist(self.t_center, state)
- t_width = util.sample_from_dist(self.t_width, state)
-
- tmin = max(t_center - (t_width / 2), 0.0)
- tmax = min(t_center + (t_width / 2), 1.0)
-
- tmin_s = signal.signal_duration * tmin
- tmax_s = signal.signal_duration * tmax
- return {"tmin_s": tmin_s, "tmax_s": tmax_s}
-
- def _transform(self, signal, tmin_s: float, tmax_s: float):
- return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s)
-
-
-class MaskLowMagnitudes(SpectralTransform):
- """Masks low magnitude regions out of signal.
-
- Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_low_magnitudes`.
-
- Parameters
- ----------
- db_cutoff : tuple, optional
- Decibel value below which magnitudes are masked away,
- by default ("uniform", -10, 10)
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- db_cutoff: tuple = ("uniform", -10, 10),
- name: str = None,
- prob: float = 1,
- ):
- super().__init__(name=name, prob=prob)
- self.db_cutoff = db_cutoff
-
- def _instantiate(self, state: RandomState, signal: AudioSignal = None):
- return {"db_cutoff": util.sample_from_dist(self.db_cutoff, state)}
-
- def _transform(self, signal, db_cutoff: float):
- return signal.mask_low_magnitudes(db_cutoff)
-
-
-class Smoothing(BaseTransform):
- """Convolves the signal with a smoothing window.
-
- Uses :py:func:`audiotools.core.effects.EffectMixin.convolve`.
-
- Parameters
- ----------
- window_type : tuple, optional
- Type of window to use, by default ("const", "average")
- window_length : tuple, optional
- Length of smoothing window, by
- default ("choice", [8, 16, 32, 64, 128, 256, 512])
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- window_type: tuple = ("const", "average"),
- window_length: tuple = ("choice", [8, 16, 32, 64, 128, 256, 512]),
- name: str = None,
- prob: float = 1,
- ):
- super().__init__(name=name, prob=prob)
- self.window_type = window_type
- self.window_length = window_length
-
- def _instantiate(self, state: RandomState, signal: AudioSignal = None):
- window_type = util.sample_from_dist(self.window_type, state)
- window_length = util.sample_from_dist(self.window_length, state)
- window = signal.get_window(
- window_type=window_type, window_length=window_length, device="cpu"
- )
- return {"window": AudioSignal(window, signal.sample_rate)}
-
- def _transform(self, signal, window):
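- # Record the input's peak level so the smoothed output can be rescaled
- # back to it; zero peaks are replaced with 1.0 to avoid division by zero.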
- sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values
- sscale[sscale == 0.0] = 1.0
-
- out = signal.convolve(window)
-
- oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values
- oscale[oscale == 0.0] = 1.0
-
- out = out * (sscale / oscale)
- return out
-
-
-class TimeNoise(TimeMask):
- """Similar to :py:func:`audiotools.data.transforms.TimeMask`, but
- replaces with noise instead of zeros.
-
- Parameters
- ----------
- t_center : tuple, optional
- Center time in terms of 0.0 and 1.0 (duration of signal),
- by default ("uniform", 0.0, 1.0)
- t_width : tuple, optional
- Width of dropped out portion, by default ("const", 0.025)
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- t_center: tuple = ("uniform", 0.0, 1.0),
- t_width: tuple = ("const", 0.025),
- name: str = None,
- prob: float = 1,
- ):
- super().__init__(t_center=t_center, t_width=t_width, name=name, prob=prob)
-
- def _transform(self, signal, tmin_s: float, tmax_s: float):
- signal = signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s, val=0.0)
- mag, phase = signal.magnitude, signal.phase
-
- mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
- mask = (mag == 0.0) * (phase == 0.0)
-
- mag[mask] = mag_r[mask]
- phase[mask] = phase_r[mask]
-
- signal.magnitude = mag
- signal.phase = phase
- return signal
-
-
-class FrequencyNoise(FrequencyMask):
- """Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but
- replaces with noise instead of zeros.
-
- Parameters
- ----------
- f_center : tuple, optional
- Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
- f_width : tuple, optional
- Width of zero'd out band, by default ("const", 0.1)
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- f_center: tuple = ("uniform", 0.0, 1.0),
- f_width: tuple = ("const", 0.1),
- name: str = None,
- prob: float = 1,
- ):
- super().__init__(f_center=f_center, f_width=f_width, name=name, prob=prob)
-
- def _transform(self, signal, fmin_hz: float, fmax_hz: float):
- signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
- mag, phase = signal.magnitude, signal.phase
-
- mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
- mask = (mag == 0.0) * (phase == 0.0)
-
- mag[mask] = mag_r[mask]
- phase[mask] = phase_r[mask]
-
- signal.magnitude = mag
- signal.phase = phase
- return signal
-
-
-class SpectralDenoising(Equalizer):
- """Applies denoising algorithm detailed in
- :py:func:`audiotools.ml.layers.spectral_gate.SpectralGate`,
- using a randomly generated noise signal for denoising.
-
- Parameters
- ----------
- eq_amount : tuple, optional
- Amount of eq to apply to noise signal, by default ("const", 1.0)
- denoise_amount : tuple, optional
- Amount to denoise by, by default ("uniform", 0.8, 1.0)
- nz_volume : float, optional
- Volume of noise to denoise with, by default -40
- n_bands : int, optional
- Number of bands in equalizer, by default 6
- n_freq : int, optional
- Number of frequency bins to smooth by, by default 3
- n_time : int, optional
- Number of time bins to smooth by, by default 5
- name : str, optional
- Name of this transform, used to identify it in the dictionary
- produced by ``self.instantiate``, by default None
- prob : float, optional
- Probability of applying this transform, by default 1.0
- """
-
- def __init__(
- self,
- eq_amount: tuple = ("const", 1.0),
- denoise_amount: tuple = ("uniform", 0.8, 1.0),
- nz_volume: float = -40,
- n_bands: int = 6,
- n_freq: int = 3,
- n_time: int = 5,
- name: str = None,
- prob: float = 1,
- ):
- super().__init__(eq_amount=eq_amount, n_bands=n_bands, name=name, prob=prob)
-
- self.nz_volume = nz_volume
- self.denoise_amount = denoise_amount
- self.spectral_gate = ml.layers.SpectralGate(n_freq, n_time)
-
- def _transform(self, signal, nz, eq, denoise_amount):
- nz = nz.normalize(self.nz_volume).equalizer(eq)
- self.spectral_gate = self.spectral_gate.to(signal.device)
- signal = self.spectral_gate(signal, nz, denoise_amount)
- return signal
-
- def _instantiate(self, state: RandomState):
- kwargs = super()._instantiate(state)
- kwargs["denoise_amount"] = util.sample_from_dist(self.denoise_amount, state)
- kwargs["nz"] = AudioSignal(state.randn(22050), 44100)
- return kwargs
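A minimal usage sketch (not part of the deleted file), exercising the transforms above through the private _instantiate/_transform hooks they define; the import paths and the sine test signal are assumptions:

import numpy as np
import torch
from audiotools import AudioSignal                      # assumed import path for this vendored copy
from audiotools.data.transforms import TimeNoise        # assumed import path

state = np.random.RandomState(0)
# 1 second of a 440-cycle sine as an arbitrary test signal
signal = AudioSignal(torch.sin(torch.linspace(0, 2 * np.pi * 440, 44100)), 44100)

tfm = TimeNoise(t_center=("uniform", 0.25, 0.75), t_width=("const", 0.1))
kwargs = tfm._instantiate(state, signal)           # samples tmin_s / tmax_s from the given distributions
noisy = tfm._transform(signal.clone(), **kwargs)   # masked span is replaced with noise in the STFT domain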
diff --git a/dito/models/ldm/dac/audiotools/metrics/__init__.py b/dito/models/ldm/dac/audiotools/metrics/__init__.py
deleted file mode 100644
index c9c8d2df61f94afae8e39e57abf156e8e4059a9e..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/metrics/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""
-Functions for comparing AudioSignal objects to one another.
-""" # fmt: skip
-from . import distance
-from . import quality
-from . import spectral
diff --git a/dito/models/ldm/dac/audiotools/metrics/distance.py b/dito/models/ldm/dac/audiotools/metrics/distance.py
deleted file mode 100644
index ce78739bfc29f9ddc39b23063b4243ddac10adaf..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/metrics/distance.py
+++ /dev/null
@@ -1,131 +0,0 @@
-import torch
-from torch import nn
-
-from .. import AudioSignal
-
-
-class L1Loss(nn.L1Loss):
- """L1 Loss between AudioSignals. Defaults
- to comparing ``audio_data``, but any
- attribute of an AudioSignal can be used.
-
- Parameters
- ----------
- attribute : str, optional
- Attribute of signal to compare, defaults to ``audio_data``.
- weight : float, optional
- Weight of this loss, defaults to 1.0.
- """
-
- def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
- self.attribute = attribute
- self.weight = weight
- super().__init__(**kwargs)
-
- def forward(self, x: AudioSignal, y: AudioSignal):
- """
- Parameters
- ----------
- x : AudioSignal
- Estimate AudioSignal
- y : AudioSignal
- Reference AudioSignal
-
- Returns
- -------
- torch.Tensor
- L1 loss between AudioSignal attributes.
- """
- if isinstance(x, AudioSignal):
- x = getattr(x, self.attribute)
- y = getattr(y, self.attribute)
- return super().forward(x, y)
-
-
-class SISDRLoss(nn.Module):
- """
- Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
- of estimated and reference audio signals or aligned features.
-
- Parameters
- ----------
- scaling : int, optional
- Whether to use scale-invariant (True) or
- signal-to-noise ratio (False), by default True
- reduction : str, optional
- How to reduce across the batch (either 'mean',
- 'sum', or 'none'), by default 'mean'
- zero_mean : int, optional
- Zero mean the references and estimates before
- computing the loss, by default True
- clip_min : int, optional
- The minimum possible loss value. Helps network
- to not focus on making already good examples better, by default None
- weight : float, optional
- Weight of this loss, defaults to 1.0.
- """
-
- def __init__(
- self,
- scaling: int = True,
- reduction: str = "mean",
- zero_mean: int = True,
- clip_min: int = None,
- weight: float = 1.0,
- ):
- self.scaling = scaling
- self.reduction = reduction
- self.zero_mean = zero_mean
- self.clip_min = clip_min
- self.weight = weight
- super().__init__()
-
- def forward(self, x: AudioSignal, y: AudioSignal):
- eps = 1e-8
- # nb, nc, nt
- if isinstance(x, AudioSignal):
- references = x.audio_data
- estimates = y.audio_data
- else:
- references = x
- estimates = y
-
- nb = references.shape[0]
- references = references.reshape(nb, 1, -1).permute(0, 2, 1)
- estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
-
- # samples now on axis 1
- if self.zero_mean:
- mean_reference = references.mean(dim=1, keepdim=True)
- mean_estimate = estimates.mean(dim=1, keepdim=True)
- else:
- mean_reference = 0
- mean_estimate = 0
-
- _references = references - mean_reference
- _estimates = estimates - mean_estimate
-
- references_projection = (_references**2).sum(dim=-2) + eps
- references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
-
- scale = (
- (references_on_estimates / references_projection).unsqueeze(1)
- if self.scaling
- else 1
- )
-
- e_true = scale * _references
- e_res = _estimates - e_true
-
- signal = (e_true**2).sum(dim=1)
- noise = (e_res**2).sum(dim=1)
- sdr = -10 * torch.log10(signal / noise + eps)
-
- if self.clip_min is not None:
- sdr = torch.clamp(sdr, min=self.clip_min)
-
- if self.reduction == "mean":
- sdr = sdr.mean()
- elif self.reduction == "sum":
- sdr = sdr.sum()
- return sdr
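A minimal sketch (not part of the deleted file) of comparing two AudioSignals with the losses above; the import path, the random test signals, and the in-place perturbation are assumptions, while the call signatures follow the forward() methods defined in this file:

import torch
from audiotools import AudioSignal  # assumed import path for this vendored copy

ref = AudioSignal(torch.randn(1, 1, 44100), 44100)
est = ref.clone()
est.audio_data = est.audio_data + 0.01 * torch.randn_like(est.audio_data)

l1 = L1Loss()                        # compares .audio_data by default
sisdr = SISDRLoss(reduction="mean")
print(l1(est, ref))                  # L1Loss takes (estimate, reference)
print(sisdr(ref, est))               # SISDRLoss takes (references, estimates)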
diff --git a/dito/models/ldm/dac/audiotools/metrics/quality.py b/dito/models/ldm/dac/audiotools/metrics/quality.py
deleted file mode 100644
index 1608f25507082b49ccbf49289025a5a94a422808..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/metrics/quality.py
+++ /dev/null
@@ -1,159 +0,0 @@
-import os
-
-import numpy as np
-import torch
-
-from .. import AudioSignal
-
-
-def stoi(
- estimates: AudioSignal,
- references: AudioSignal,
- extended: int = False,
-):
- """Short term objective intelligibility
- Computes the STOI (See [1][2]) of a denoised signal compared to a clean
- signal. The output is expected to have a monotonic relation with the
- subjective speech-intelligibility, where a higher score denotes better
- speech intelligibility. Uses pystoi under the hood.
-
- Parameters
- ----------
- estimates : AudioSignal
- Denoised speech
- references : AudioSignal
- Clean original speech
- extended : int, optional
- Boolean, whether to use the extended STOI described in [3], by default False
-
- Returns
- -------
- Tensor[float]
- Short time objective intelligibility measure between clean and
- denoised speech
-
- References
- ----------
- 1. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
- Objective Intelligibility Measure for Time-Frequency Weighted Noisy
- Speech', ICASSP 2010, Texas, Dallas.
- 2. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
- Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
- IEEE Transactions on Audio, Speech, and Language Processing, 2011.
- 3. Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
- Intelligibility of Speech Masked by Modulated Noise Maskers',
- IEEE Transactions on Audio, Speech and Language Processing, 2016.
- """
- import pystoi
-
- estimates = estimates.clone().to_mono()
- references = references.clone().to_mono()
-
- stois = []
- for i in range(estimates.batch_size):
- _stoi = pystoi.stoi(
- references.audio_data[i, 0].detach().cpu().numpy(),
- estimates.audio_data[i, 0].detach().cpu().numpy(),
- references.sample_rate,
- extended=extended,
- )
- stois.append(_stoi)
- return torch.from_numpy(np.array(stois))
-
-
-def pesq(
- estimates: AudioSignal,
- references: AudioSignal,
- mode: str = "wb",
- target_sr: float = 16000,
-):
- """_summary_
-
- Parameters
- ----------
- estimates : AudioSignal
- Degraded AudioSignal
- references : AudioSignal
- Reference AudioSignal
- mode : str, optional
- 'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
- target_sr : float, optional
- Target sample rate, by default 16000
-
- Returns
- -------
- Tensor[float]
- PESQ score: P.862.2 Prediction (MOS-LQO)
- """
- from pesq import pesq as pesq_fn
-
- estimates = estimates.clone().to_mono().resample(target_sr)
- references = references.clone().to_mono().resample(target_sr)
-
- pesqs = []
- for i in range(estimates.batch_size):
- _pesq = pesq_fn(
- estimates.sample_rate,
- references.audio_data[i, 0].detach().cpu().numpy(),
- estimates.audio_data[i, 0].detach().cpu().numpy(),
- mode,
- )
- pesqs.append(_pesq)
- return torch.from_numpy(np.array(pesqs))
-
-
-def visqol(
- estimates: AudioSignal,
- references: AudioSignal,
- mode: str = "audio",
-): # pragma: no cover
- """ViSQOL score.
-
- Parameters
- ----------
- estimates : AudioSignal
- Degraded AudioSignal
- references : AudioSignal
- Reference AudioSignal
- mode : str, optional
- 'audio' or 'speech', by default 'audio'
-
- Returns
- -------
- Tensor[float]
- ViSQOL score (MOS-LQO)
- """
- from visqol import visqol_lib_py
- from visqol.pb2 import visqol_config_pb2
- from visqol.pb2 import similarity_result_pb2
-
- config = visqol_config_pb2.VisqolConfig()
- if mode == "audio":
- target_sr = 48000
- config.options.use_speech_scoring = False
- svr_model_path = "libsvm_nu_svr_model.txt"
- elif mode == "speech":
- target_sr = 16000
- config.options.use_speech_scoring = True
- svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
- else:
- raise ValueError(f"Unrecognized mode: {mode}")
- config.audio.sample_rate = target_sr
- config.options.svr_model_path = os.path.join(
- os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path
- )
-
- api = visqol_lib_py.VisqolApi()
- api.Create(config)
-
- estimates = estimates.clone().to_mono().resample(target_sr)
- references = references.clone().to_mono().resample(target_sr)
-
- visqols = []
- for i in range(estimates.batch_size):
- _visqol = api.Measure(
- references.audio_data[i, 0].detach().cpu().numpy().astype(float),
- estimates.audio_data[i, 0].detach().cpu().numpy().astype(float),
- )
- visqols.append(_visqol.moslqo)
- return torch.from_numpy(np.array(visqols))
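A minimal sketch (not part of the deleted file) of the metrics above, each of which takes a batch of estimate/reference AudioSignals and returns one score per item; the import path and file names are assumptions, and pystoi / pesq must be installed:

from audiotools import AudioSignal  # assumed import path for this vendored copy

references = AudioSignal("clean.wav")     # hypothetical reference recording
estimates = AudioSignal("denoised.wav")   # hypothetical degraded/denoised recording

stoi_scores = stoi(estimates, references)              # intelligibility, higher is better
pesq_scores = pesq(estimates, references, mode="wb")   # MOS-LQO, wide-band
print(stoi_scores.mean().item(), pesq_scores.mean().item())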
diff --git a/dito/models/ldm/dac/audiotools/metrics/spectral.py b/dito/models/ldm/dac/audiotools/metrics/spectral.py
deleted file mode 100644
index 7ce953882efa4e5b777a0348bee6c1be39279a6c..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/metrics/spectral.py
+++ /dev/null
@@ -1,247 +0,0 @@
-import typing
-from typing import List
-
-import numpy as np
-from torch import nn
-
-from .. import AudioSignal
-from .. import STFTParams
-
-
-class MultiScaleSTFTLoss(nn.Module):
- """Computes the multi-scale STFT loss from [1].
-
- Parameters
- ----------
- window_lengths : List[int], optional
- Length of each window of each STFT, by default [2048, 512]
- loss_fn : typing.Callable, optional
- How to compare each loss, by default nn.L1Loss()
- clamp_eps : float, optional
- Lower clamp applied to the magnitude before taking the log, by default 1e-5
- mag_weight : float, optional
- Weight of raw magnitude portion of loss, by default 1.0
- log_weight : float, optional
- Weight of log magnitude portion of loss, by default 1.0
- pow : float, optional
- Power to raise magnitude to before taking log, by default 2.0
- weight : float, optional
- Weight of this loss, by default 1.0
- match_stride : bool, optional
- Whether to match the stride of convolutional layers, by default False
-
- References
- ----------
-
- 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
- "DDSP: Differentiable Digital Signal Processing."
- International Conference on Learning Representations. 2019.
- """
-
- def __init__(
- self,
- window_lengths: List[int] = [2048, 512],
- loss_fn: typing.Callable = nn.L1Loss(),
- clamp_eps: float = 1e-5,
- mag_weight: float = 1.0,
- log_weight: float = 1.0,
- pow: float = 2.0,
- weight: float = 1.0,
- match_stride: bool = False,
- window_type: str = None,
- ):
- super().__init__()
- self.stft_params = [
- STFTParams(
- window_length=w,
- hop_length=w // 4,
- match_stride=match_stride,
- window_type=window_type,
- )
- for w in window_lengths
- ]
- self.loss_fn = loss_fn
- self.log_weight = log_weight
- self.mag_weight = mag_weight
- self.clamp_eps = clamp_eps
- self.weight = weight
- self.pow = pow
-
- def forward(self, x: AudioSignal, y: AudioSignal):
- """Computes multi-scale STFT between an estimate and a reference
- signal.
-
- Parameters
- ----------
- x : AudioSignal
- Estimate signal
- y : AudioSignal
- Reference signal
-
- Returns
- -------
- torch.Tensor
- Multi-scale STFT loss.
- """
- loss = 0.0
- for s in self.stft_params:
- x.stft(s.window_length, s.hop_length, s.window_type)
- y.stft(s.window_length, s.hop_length, s.window_type)
- loss += self.log_weight * self.loss_fn(
- x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
- y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
- )
- loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
- return loss
-
-
-class MelSpectrogramLoss(nn.Module):
- """Compute distance between mel spectrograms. Can be used
- in a multi-scale way.
-
- Parameters
- ----------
- n_mels : List[int]
- Number of mels per STFT, by default [150, 80]
- window_lengths : List[int], optional
- Length of each window of each STFT, by default [2048, 512]
- loss_fn : typing.Callable, optional
- How to compare each loss, by default nn.L1Loss()
- clamp_eps : float, optional
- Lower clamp applied to the magnitude before taking the log, by default 1e-5
- mag_weight : float, optional
- Weight of raw magnitude portion of loss, by default 1.0
- log_weight : float, optional
- Weight of log magnitude portion of loss, by default 1.0
- pow : float, optional
- Power to raise magnitude to before taking log, by default 2.0
- weight : float, optional
- Weight of this loss, by default 1.0
- match_stride : bool, optional
- Whether to match the stride of convolutional layers, by default False
- """
-
- def __init__(
- self,
- n_mels: List[int] = [150, 80],
- window_lengths: List[int] = [2048, 512],
- loss_fn: typing.Callable = nn.L1Loss(),
- clamp_eps: float = 1e-5,
- mag_weight: float = 1.0,
- log_weight: float = 1.0,
- pow: float = 2.0,
- weight: float = 1.0,
- match_stride: bool = False,
- mel_fmin: List[float] = [0.0, 0.0],
- mel_fmax: List[float] = [None, None],
- window_type: str = None,
- ):
- super().__init__()
- self.stft_params = [
- STFTParams(
- window_length=w,
- hop_length=w // 4,
- match_stride=match_stride,
- window_type=window_type,
- )
- for w in window_lengths
- ]
- self.n_mels = n_mels
- self.loss_fn = loss_fn
- self.clamp_eps = clamp_eps
- self.log_weight = log_weight
- self.mag_weight = mag_weight
- self.weight = weight
- self.mel_fmin = mel_fmin
- self.mel_fmax = mel_fmax
- self.pow = pow
-
- def forward(self, x: AudioSignal, y: AudioSignal):
- """Computes mel loss between an estimate and a reference
- signal.
-
- Parameters
- ----------
- x : AudioSignal
- Estimate signal
- y : AudioSignal
- Reference signal
-
- Returns
- -------
- torch.Tensor
- Mel loss.
- """
- loss = 0.0
- for n_mels, fmin, fmax, s in zip(
- self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
- ):
- kwargs = {
- "window_length": s.window_length,
- "hop_length": s.hop_length,
- "window_type": s.window_type,
- }
- x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
- y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
-
- loss += self.log_weight * self.loss_fn(
- x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
- y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
- )
- loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
- return loss
-
-
-class PhaseLoss(nn.Module):
- """Difference between phase spectrograms.
-
- Parameters
- ----------
- window_length : int, optional
- Length of STFT window, by default 2048
- hop_length : int, optional
- Hop length of STFT window, by default 512
- weight : float, optional
- Weight of loss, by default 1.0
- """
-
- def __init__(
- self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
- ):
- super().__init__()
-
- self.weight = weight
- self.stft_params = STFTParams(window_length, hop_length)
-
- def forward(self, x: AudioSignal, y: AudioSignal):
- """Computes phase loss between an estimate and a reference
- signal.
-
- Parameters
- ----------
- x : AudioSignal
- Estimate signal
- y : AudioSignal
- Reference signal
-
- Returns
- -------
- torch.Tensor
- Phase loss.
- """
- s = self.stft_params
- x.stft(s.window_length, s.hop_length, s.window_type)
- y.stft(s.window_length, s.hop_length, s.window_type)
-
- # Take circular difference
- diff = x.phase - y.phase
- diff[diff < -np.pi] += 2 * np.pi
- diff[diff > np.pi] -= 2 * np.pi
-
- # Scale true magnitude to weights in [0, 1]
- x_min, x_max = x.magnitude.min(), x.magnitude.max()
- weights = (x.magnitude - x_min) / (x_max - x_min)
-
- # Take weighted mean of all phase errors
- loss = ((weights * diff) ** 2).mean()
- return loss
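A minimal sketch (not part of the deleted file) of the spectral losses above, which are plain nn.Modules consuming AudioSignal pairs; the import path and the random batches are assumptions:

import torch
from audiotools import AudioSignal  # assumed import path for this vendored copy

x = AudioSignal(torch.randn(2, 1, 44100), 44100)  # estimate batch
y = AudioSignal(torch.randn(2, 1, 44100), 44100)  # reference batch

stft_loss = MultiScaleSTFTLoss(window_lengths=[2048, 512])
mel_loss = MelSpectrogramLoss(n_mels=[150, 80], window_lengths=[2048, 512])

total = stft_loss(x, y) + mel_loss(x, y)   # scalar tensor; in training it would be backpropagated
print(total.item())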
diff --git a/dito/models/ldm/dac/audiotools/ml/__init__.py b/dito/models/ldm/dac/audiotools/ml/__init__.py
deleted file mode 100644
index a9ca69977bad57e1a92b7551d601d9224ee854ab..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/ml/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from . import decorators
-from . import layers
-from .accelerator import Accelerator
-from .experiment import Experiment
-from .layers import BaseModel
diff --git a/dito/models/ldm/dac/audiotools/ml/accelerator.py b/dito/models/ldm/dac/audiotools/ml/accelerator.py
deleted file mode 100644
index 37c6e8d954f112b8b0aff257894e62add8874e30..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/ml/accelerator.py
+++ /dev/null
@@ -1,184 +0,0 @@
-import os
-import typing
-
-import torch
-import torch.distributed as dist
-from torch.nn.parallel import DataParallel
-from torch.nn.parallel import DistributedDataParallel
-
-from ..data.datasets import ResumableDistributedSampler as DistributedSampler
-from ..data.datasets import ResumableSequentialSampler as SequentialSampler
-
-
-class Accelerator: # pragma: no cover
- """This class is used to prepare models and dataloaders for
- usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
- prepare the respective objects. In the case of models, they are moved to
- the appropriate GPU and SyncBatchNorm is applied to them. In the case of
- dataloaders, a sampler is created and the dataloader is initialized with
- that sampler.
-
- If the world size is 1, prepare_model and prepare_dataloader are
- no-ops. If the environment variable ``LOCAL_RANK`` is not set, then the
- script was launched without ``torchrun``, and ``DataParallel``
- will be used instead of ``DistributedDataParallel`` (not recommended), if
- the world size (number of GPUs) is greater than 1.
-
- Parameters
- ----------
- amp : bool, optional
- Whether or not to enable automatic mixed precision, by default False
- """
-
- def __init__(self, amp: bool = False):
- local_rank = os.getenv("LOCAL_RANK", None)
- self.world_size = torch.cuda.device_count()
-
- self.use_ddp = self.world_size > 1 and local_rank is not None
- self.use_dp = self.world_size > 1 and local_rank is None
- self.device = "cpu" if self.world_size == 0 else "cuda"
-
- if self.use_ddp:
- local_rank = int(local_rank)
- dist.init_process_group(
- "nccl",
- init_method="env://",
- world_size=self.world_size,
- rank=local_rank,
- )
-
- self.local_rank = 0 if local_rank is None else local_rank
- self.amp = amp
-
- class DummyScaler:
- def __init__(self):
- pass
-
- def step(self, optimizer):
- optimizer.step()
-
- def scale(self, loss):
- return loss
-
- def unscale_(self, optimizer):
- return optimizer
-
- def update(self):
- pass
-
- self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
- self.device_ctx = (
- torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
- )
-
- def __enter__(self):
- if self.device_ctx is not None:
- self.device_ctx.__enter__()
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- if self.device_ctx is not None:
- self.device_ctx.__exit__(exc_type, exc_value, traceback)
-
- def prepare_model(self, model: torch.nn.Module, **kwargs):
- """Prepares model for DDP or DP. The model is moved to
- the device of the correct rank.
-
- Parameters
- ----------
- model : torch.nn.Module
- Model that is converted for DDP or DP.
-
- Returns
- -------
- torch.nn.Module
- Wrapped model, or original model if DDP and DP are turned off.
- """
- model = model.to(self.device)
- if self.use_ddp:
- model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
- model = DistributedDataParallel(
- model, device_ids=[self.local_rank], **kwargs
- )
- elif self.use_dp:
- model = DataParallel(model, **kwargs)
- return model
-
- # Automatic mixed-precision utilities
- def autocast(self, *args, **kwargs):
- """Context manager for autocasting. Arguments
- go to ``torch.cuda.amp.autocast``.
- """
- return torch.cuda.amp.autocast(self.amp, *args, **kwargs)
-
- def backward(self, loss: torch.Tensor):
- """Backwards pass, after scaling the loss if ``amp`` is
- enabled.
-
- Parameters
- ----------
- loss : torch.Tensor
- Loss value.
- """
- self.scaler.scale(loss).backward()
-
- def step(self, optimizer: torch.optim.Optimizer):
- """Steps the optimizer, using a ``scaler`` if ``amp`` is
- enabled.
-
- Parameters
- ----------
- optimizer : torch.optim.Optimizer
- Optimizer to step forward.
- """
- self.scaler.step(optimizer)
-
- def update(self):
- """Updates the scale factor."""
- self.scaler.update()
-
- def prepare_dataloader(
- self, dataset: typing.Iterable, start_idx: int = None, **kwargs
- ):
- """Wraps a dataset with a DataLoader, using the correct sampler if DDP is
- enabled.
-
- Parameters
- ----------
- dataset : typing.Iterable
- Dataset to build Dataloader around.
- start_idx : int, optional
- Start index of sampler, useful if resuming from some epoch,
- by default None
-
- Returns
- -------
- torch.utils.data.DataLoader
- DataLoader wrapping ``dataset``, built with the appropriate sampler for DDP or
- single-process use.
- """
-
- if self.use_ddp:
- sampler = DistributedSampler(
- dataset,
- start_idx,
- num_replicas=self.world_size,
- rank=self.local_rank,
- )
- if "num_workers" in kwargs:
- kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1)
- kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1)
- else:
- sampler = SequentialSampler(dataset, start_idx)
-
- dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)
- return dataloader
-
- @staticmethod
- def unwrap(model):
- """Unwraps the model if it was wrapped in DDP or DP, otherwise
- just returns the model. Use this to unwrap the model returned by
- :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
- """
- if hasattr(model, "module"):
- return model.module
- return model
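A hedged training-loop sketch (not part of the deleted file) showing how the Accelerator above is meant to be driven: prepare_model / prepare_dataloader, then autocast, backward, step, update. The model, dataset, batch layout, and import path are assumptions:

import torch
from audiotools.ml import Accelerator  # assumed import path for this vendored copy

accel = Accelerator(amp=True)
model = accel.prepare_model(torch.nn.Linear(10, 1))
loader = accel.prepare_dataloader(my_dataset, start_idx=0, batch_size=32, num_workers=4)  # my_dataset is hypothetical
optimizer = torch.optim.Adam(model.parameters())

for batch in loader:
    optimizer.zero_grad()
    with accel.autocast():
        loss = model(batch["x"]).pow(2).mean()  # placeholder loss on an assumed batch key
    accel.backward(loss)    # scales the loss when amp=True
    accel.step(optimizer)   # unscales and steps via the GradScaler
    accel.update()          # updates the scale factor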
diff --git a/dito/models/ldm/dac/audiotools/ml/decorators.py b/dito/models/ldm/dac/audiotools/ml/decorators.py
deleted file mode 100644
index 3a435b06c47a48dc3600fa54ac092006f5c5bb27..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/ml/decorators.py
+++ /dev/null
@@ -1,441 +0,0 @@
-import math
-import os
-import time
-from collections import defaultdict
-from functools import wraps
-
-import torch
-import torch.distributed as dist
-from rich import box
-from rich.console import Console
-from rich.console import Group
-from rich.live import Live
-from rich.markdown import Markdown
-from rich.padding import Padding
-from rich.panel import Panel
-from rich.progress import BarColumn
-from rich.progress import Progress
-from rich.progress import SpinnerColumn
-from rich.progress import TimeElapsedColumn
-from rich.progress import TimeRemainingColumn
-from rich.rule import Rule
-from rich.table import Table
-from torch.utils.tensorboard import SummaryWriter
-
-
-# This is here so that the history can be pickled.
-def default_list():
- return []
-
-
-class Mean:
- """Keeps track of the running mean, along with the latest
- value.
- """
-
- def __init__(self):
- self.reset()
-
- def __call__(self):
- mean = self.total / max(self.count, 1)
- return mean
-
- def reset(self):
- self.count = 0
- self.total = 0
-
- def update(self, val):
- if math.isfinite(val):
- self.count += 1
- self.total += val
-
-
-def when(condition):
- """Runs a function only when the condition is met. The condition is
- a function that is run.
-
- Parameters
- ----------
- condition : Callable
- Function to run to check whether or not to run the decorated
- function.
-
- Example
- -------
- Checkpoint only runs every 100 iterations, and only if the
- local rank is 0.
-
- >>> i = 0
- >>> rank = 0
- >>>
- >>> @when(lambda: i % 100 == 0 and rank == 0)
- >>> def checkpoint():
- >>> print("Saving to /runs/exp1")
- >>>
- >>> for i in range(1000):
- >>> checkpoint()
-
- """
-
- def decorator(fn):
- @wraps(fn)
- def decorated(*args, **kwargs):
- if condition():
- return fn(*args, **kwargs)
-
- return decorated
-
- return decorator
-
-
-def timer(prefix: str = "time"):
- """Adds execution time to the output dictionary of the decorated
- function. The function decorated by this must output a dictionary.
- The key added will follow the form "[prefix]/[name_of_function]"
-
- Parameters
- ----------
- prefix : str, optional
- The key added will follow the form "[prefix]/[name_of_function]",
- by default "time".
- """
-
- def decorator(fn):
- @wraps(fn)
- def decorated(*args, **kwargs):
- s = time.perf_counter()
- output = fn(*args, **kwargs)
- assert isinstance(output, dict)
- e = time.perf_counter()
- output[f"{prefix}/{fn.__name__}"] = e - s
- return output
-
- return decorated
-
- return decorator
-
-
-class Tracker:
- """
- A tracker class that helps to monitor the progress of training and logging the metrics.
-
- Attributes
- ----------
- metrics : dict
- A dictionary containing the metrics for each label.
- history : dict
- A dictionary containing the history of metrics for each label.
- writer : SummaryWriter
- A SummaryWriter object for logging the metrics.
- rank : int
- The rank of the current process.
- step : int
- The current step of the training.
- tasks : dict
- A dictionary containing the progress bars and tables for each label.
- pbar : Progress
- A progress bar object for displaying the progress.
- consoles : list
- A list of console objects for logging.
- live : Live
- A Live object for updating the display live.
-
- Methods
- -------
- print(msg: str)
- Prints the given message to all consoles.
- update(label: str, fn_name: str)
- Updates the progress bar and table for the given label.
- done(label: str, title: str)
- Resets the progress bar and table for the given label and prints the final result.
- track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ)
- A decorator for tracking the progress and metrics of a function.
- log(label: str, value_type: str = "value", history: bool = True)
- A decorator for logging the metrics of a function.
- is_best(label: str, key: str) -> bool
- Checks if the latest value of the given key in the label is the best so far.
- state_dict() -> dict
- Returns a dictionary containing the state of the tracker.
- load_state_dict(state_dict: dict) -> Tracker
- Loads the state of the tracker from the given state dictionary.
- """
-
- def __init__(
- self,
- writer: SummaryWriter = None,
- log_file: str = None,
- rank: int = 0,
- console_width: int = 100,
- step: int = 0,
- ):
- """
- Initializes the Tracker object.
-
- Parameters
- ----------
- writer : SummaryWriter, optional
- A SummaryWriter object for logging the metrics, by default None.
- log_file : str, optional
- The path to the log file, by default None.
- rank : int, optional
- The rank of the current process, by default 0.
- console_width : int, optional
- The width of the console, by default 100.
- step : int, optional
- The current step of the training, by default 0.
- """
- self.metrics = {}
- self.history = {}
- self.writer = writer
- self.rank = rank
- self.step = step
-
- # Create progress bars etc.
- self.tasks = {}
- self.pbar = Progress(
- SpinnerColumn(),
- "[progress.description]{task.description}",
- "{task.completed}/{task.total}",
- BarColumn(),
- TimeElapsedColumn(),
- "/",
- TimeRemainingColumn(),
- )
- self.consoles = [Console(width=console_width)]
- self.live = Live(console=self.consoles[0], refresh_per_second=10)
- if log_file is not None:
- self.consoles.append(Console(width=console_width, file=open(log_file, "a")))
-
- def print(self, msg):
- """
- Prints the given message to all consoles.
-
- Parameters
- ----------
- msg : str
- The message to be printed.
- """
- if self.rank == 0:
- for c in self.consoles:
- c.log(msg)
-
- def update(self, label, fn_name):
- """
- Updates the progress bar and table for the given label.
-
- Parameters
- ----------
- label : str
- The label of the progress bar and table to be updated.
- fn_name : str
- The name of the function associated with the label.
- """
- if self.rank == 0:
- self.pbar.advance(self.tasks[label]["pbar"])
-
- # Create table
- table = Table(title=label, expand=True, box=box.MINIMAL)
- table.add_column("key", style="cyan")
- table.add_column("value", style="bright_blue")
- table.add_column("mean", style="bright_green")
-
- keys = self.metrics[label]["value"].keys()
- for k in keys:
- value = self.metrics[label]["value"][k]
- mean = self.metrics[label]["mean"][k]()
- table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")
-
- self.tasks[label]["table"] = table
- tables = [t["table"] for t in self.tasks.values()]
- group = Group(*tables, self.pbar)
- self.live.update(
- Group(
- Padding("", (0, 0)),
- Rule(f"[italic]{fn_name}()", style="white"),
- Padding("", (0, 0)),
- Panel.fit(
- group, padding=(0, 5), title="[b]Progress", border_style="blue"
- ),
- )
- )
-
- def done(self, label: str, title: str):
- """
- Resets the progress bar and table for the given label and prints the final result.
-
- Parameters
- ----------
- label : str
- The label of the progress bar and table to be reset.
- title : str
- The title to be displayed when printing the final result.
- """
- for label in self.metrics:
- for v in self.metrics[label]["mean"].values():
- v.reset()
-
- if self.rank == 0:
- self.pbar.reset(self.tasks[label]["pbar"])
- tables = [t["table"] for t in self.tasks.values()]
- group = Group(Markdown(f"# {title}"), *tables, self.pbar)
- self.print(group)
-
- def track(
- self,
- label: str,
- length: int,
- completed: int = 0,
- op: dist.ReduceOp = dist.ReduceOp.AVG,
- ddp_active: bool = "LOCAL_RANK" in os.environ,
- ):
- """
- A decorator for tracking the progress and metrics of a function.
-
- Parameters
- ----------
- label : str
- The label to be associated with the progress and metrics.
- length : int
- The total number of iterations to be completed.
- completed : int, optional
- The number of iterations already completed, by default 0.
- op : dist.ReduceOp, optional
- The reduce operation to be used, by default dist.ReduceOp.AVG.
- ddp_active : bool, optional
- Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ.
- """
- self.tasks[label] = {
- "pbar": self.pbar.add_task(
- f"[white]Iteration ({label})", total=length, completed=completed
- ),
- "table": Table(),
- }
- self.metrics[label] = {
- "value": defaultdict(),
- "mean": defaultdict(lambda: Mean()),
- }
-
- def decorator(fn):
- @wraps(fn)
- def decorated(*args, **kwargs):
- output = fn(*args, **kwargs)
- if not isinstance(output, dict):
- self.update(label, fn.__name__)
- return output
- # Collect across all DDP processes
- scalar_keys = []
- for k, v in output.items():
- if isinstance(v, (int, float)):
- v = torch.tensor([v])
- if not torch.is_tensor(v):
- continue
- if ddp_active and v.is_cuda: # pragma: no cover
- dist.all_reduce(v, op=op)
- output[k] = v.detach()
- if torch.numel(v) == 1:
- scalar_keys.append(k)
- output[k] = v.item()
-
- # Save the outputs to tracker
- for k, v in output.items():
- if k not in scalar_keys:
- continue
- self.metrics[label]["value"][k] = v
- # Update the running mean
- self.metrics[label]["mean"][k].update(v)
-
- self.update(label, fn.__name__)
- return output
-
- return decorated
-
- return decorator
-
- def log(self, label: str, value_type: str = "value", history: bool = True):
- """
- A decorator for logging the metrics of a function.
-
- Parameters
- ----------
- label : str
- The label to be associated with the logging.
- value_type : str, optional
- The type of value to be logged, by default "value".
- history : bool, optional
- Whether to save the history of the metrics, by default True.
- """
- assert value_type in ["mean", "value"]
- if history:
- if label not in self.history:
- self.history[label] = defaultdict(default_list)
-
- def decorator(fn):
- @wraps(fn)
- def decorated(*args, **kwargs):
- output = fn(*args, **kwargs)
- if self.rank == 0:
- nonlocal value_type, label
- metrics = self.metrics[label][value_type]
- for k, v in metrics.items():
- v = v() if isinstance(v, Mean) else v
- if self.writer is not None:
- # self.writer.add_scalar(f"{k}/{label}", v, self.step)
- self.writer.log_metric(f"{k}_{label}", v, step=self.step)
- if label in self.history:
- self.history[label][k].append(v)
-
- if label in self.history:
- self.history[label]["step"].append(self.step)
-
- return output
-
- return decorated
-
- return decorator
-
- def is_best(self, label, key):
- """
- Checks if the latest value of the given key in the label is the best so far.
-
- Parameters
- ----------
- label : str
- The label of the metrics to be checked.
- key : str
- The key of the metric to be checked.
-
- Returns
- -------
- bool
- True if the latest value is the best so far, otherwise False.
- """
- return self.history[label][key][-1] == min(self.history[label][key])
-
- def state_dict(self):
- """
- Returns a dictionary containing the state of the tracker.
-
- Returns
- -------
- dict
- A dictionary containing the history and step of the tracker.
- """
- return {"history": self.history, "step": self.step}
-
- def load_state_dict(self, state_dict):
- """
- Loads the state of the tracker from the given state dictionary.
-
- Parameters
- ----------
- state_dict : dict
- A dictionary containing the history and step of the tracker.
-
- Returns
- -------
- Tracker
- The tracker object with the loaded state.
- """
- self.history = state_dict["history"]
- self.step = state_dict["step"]
- return self
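A minimal sketch (not part of the deleted file) wiring the Tracker above around a toy step function with the @track and @log decorators, plus the @timer decorator from this file; the loss values are made up and the writer is left unset:

tracker = Tracker(writer=None, rank=0)

@tracker.log("train", "value", history=True)
@tracker.track("train", length=100)
@timer()
def train_step(step):
    # Scalar outputs are folded into the running mean shown in the live table.
    return {"loss": 1.0 / (step + 1)}

with tracker.live:
    for i in range(100):
        train_step(i)
        tracker.step += 1
    tracker.done("train", "Epoch 0")

print(tracker.is_best("train", "loss"))  # True if the latest loss is the minimum recorded so far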
diff --git a/dito/models/ldm/dac/audiotools/ml/experiment.py b/dito/models/ldm/dac/audiotools/ml/experiment.py
deleted file mode 100644
index 62833d0f8f80dcdf496a1a5d2785ef666e0a15b6..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/ml/experiment.py
+++ /dev/null
@@ -1,90 +0,0 @@
-"""
-Useful class for Experiment tracking, and ensuring code is
-saved alongside files.
-""" # fmt: skip
-import datetime
-import os
-import shlex
-import shutil
-import subprocess
-import typing
-from pathlib import Path
-
-import randomname
-
-
-class Experiment:
- """This class contains utilities for managing experiments.
- It is a context manager, that when you enter it, changes
- your directory to a specified experiment folder (which
- optionally can have an automatically generated experiment
- name, or a specified one), and changes the CUDA device used
- to the specified device (or devices).
-
- Parameters
- ----------
- exp_directory : str
- Folder where all experiments are saved, by default "runs/".
- exp_name : str, optional
- Name of the experiment, by default uses the current time, date, and
- hostname to save.
- """
-
- def __init__(
- self,
- exp_directory: str = "runs/",
- exp_name: str = None,
- ):
- if exp_name is None:
- exp_name = self.generate_exp_name()
- exp_dir = Path(exp_directory) / exp_name
- exp_dir.mkdir(parents=True, exist_ok=True)
-
- self.exp_dir = exp_dir
- self.exp_name = exp_name
- self.git_tracked_files = (
- subprocess.check_output(
- shlex.split("git ls-tree --full-tree --name-only -r HEAD")
- )
- .decode("utf-8")
- .splitlines()
- )
- self.parent_directory = Path(".").absolute()
-
- def __enter__(self):
- self.prev_dir = os.getcwd()
- os.chdir(self.exp_dir)
- return self
-
- def __exit__(self, exc_type, exc_value, traceback):
- os.chdir(self.prev_dir)
-
- @staticmethod
- def generate_exp_name():
- """Generates a random experiment name based on the date
- and a randomly generated adjective-noun tuple.
-
- Returns
- -------
- str
- Randomly generated experiment name.
- """
- date = datetime.datetime.now().strftime("%y%m%d")
- name = f"{date}-{randomname.get_name()}"
- return name
-
- def snapshot(self, filter_fn: typing.Callable = lambda f: True):
- """Captures a full snapshot of all the files tracked by git at the time
- the experiment is run. It also captures the diff against the committed
- code as a separate file.
-
- Parameters
- ----------
- filter_fn : typing.Callable, optional
- Function that can be used to exclude some files
- from the snapshot, by default accepts all files
- """
- for f in self.git_tracked_files:
- if filter_fn(f):
- Path(f).parent.mkdir(parents=True, exist_ok=True)
- shutil.copyfile(self.parent_directory / f, f)
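A minimal sketch (not part of the deleted file) of the Experiment context manager above, which changes into a (possibly auto-named) run folder and can snapshot git-tracked source; the exclusion rule passed to snapshot() is an arbitrary example:

from audiotools.ml import Experiment  # assumed import path for this vendored copy

with Experiment("runs/", exp_name=None) as exp:  # auto-generates a name like "240101-brave-otter"
    exp.snapshot(filter_fn=lambda f: not f.endswith(".ipynb"))  # copy tracked code, skip notebooks
    print("running inside", exp.exp_dir)
    # ... training code would run here, with the working directory set to the experiment folder ...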
diff --git a/dito/models/ldm/dac/audiotools/ml/layers/__init__.py b/dito/models/ldm/dac/audiotools/ml/layers/__init__.py
deleted file mode 100644
index 92a016cab2ddf06bf5dadfae241b7e5d9def4878..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/ml/layers/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .base import BaseModel
-from .spectral_gate import SpectralGate
diff --git a/dito/models/ldm/dac/audiotools/ml/layers/base.py b/dito/models/ldm/dac/audiotools/ml/layers/base.py
deleted file mode 100644
index b82c96cdd7336ca6b8ed6fc7f0192d69a8e998dd..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/ml/layers/base.py
+++ /dev/null
@@ -1,328 +0,0 @@
-import inspect
-import shutil
-import tempfile
-import typing
-from pathlib import Path
-
-import torch
-from torch import nn
-
-
-class BaseModel(nn.Module):
- """This is a class that adds useful save/load functionality to a
- ``torch.nn.Module`` object. ``BaseModel`` objects can be saved
- as ``torch.package`` easily, making them super easy to port between
- machines without requiring a ton of dependencies. Files can also be
- saved as just weights, in the standard way.
-
- >>> class Model(ml.BaseModel):
- >>> def __init__(self, arg1: float = 1.0):
- >>> super().__init__()
- >>> self.arg1 = arg1
- >>> self.linear = nn.Linear(1, 1)
- >>>
- >>> def forward(self, x):
- >>> return self.linear(x)
- >>>
- >>> model1 = Model()
- >>>
- >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
- >>> model1.save(
- >>> f.name,
- >>> )
- >>> model2 = Model.load(f.name)
- >>> out2 = seed_and_run(model2, x)
- >>> assert torch.allclose(out1, out2)
- >>>
- >>> model1.save(f.name, package=True)
- >>> model2 = Model.load(f.name)
- >>> model2.save(f.name, package=False)
- >>> model3 = Model.load(f.name)
- >>> out3 = seed_and_run(model3, x)
- >>>
- >>> with tempfile.TemporaryDirectory() as d:
- >>> model1.save_to_folder(d, {"data": 1.0})
- >>> Model.load_from_folder(d)
-
- """
-
- EXTERN = [
- "audiotools.**",
- "tqdm",
- "__main__",
- "numpy.**",
- "julius.**",
- "torchaudio.**",
- "scipy.**",
- "einops",
- ]
- """Names of libraries that are external to the torch.package saving mechanism.
- Source code from these libraries will not be packaged into the model. This can
- be edited by the user of this class by editing ``model.EXTERN``."""
- INTERN = []
- """Names of libraries that are internal to the torch.package saving mechanism.
- Source code from these libraries will be saved alongside the model."""
-
- def save(
- self,
- path: str,
- metadata: dict = None,
- package: bool = True,
- intern: list = [],
- extern: list = [],
- mock: list = [],
- ):
- """Saves the model, either as a torch package, or just as
- weights, alongside some specified metadata.
-
- Parameters
- ----------
- path : str
- Path to save model to.
- metadata : dict, optional
- Any metadata to save alongside the model,
- by default None
- package : bool, optional
- Whether to use ``torch.package`` to save the model in
- a format that is portable, by default True
- intern : list, optional
- List of additional libraries that are internal
- to the model, used with torch.package, by default []
- extern : list, optional
- List of additional libraries that are external to
- the model, used with torch.package, by default []
- mock : list, optional
- List of libraries to mock, used with torch.package,
- by default []
-
- Returns
- -------
- str
- Path to saved model.
- """
- sig = inspect.signature(self.__class__)
- args = {}
-
- for key, val in sig.parameters.items():
- arg_val = val.default
- if arg_val is not inspect.Parameter.empty:
- args[key] = arg_val
-
- # Look up attributes in self, and if any of them are in args,
- # overwrite them in args.
- for attribute in dir(self):
- if attribute in args:
- args[attribute] = getattr(self, attribute)
-
- metadata = {} if metadata is None else metadata
- metadata["kwargs"] = args
- if not hasattr(self, "metadata"):
- self.metadata = {}
- self.metadata.update(metadata)
-
- if not package:
- state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
- torch.save(state_dict, path)
- else:
- self._save_package(path, intern=intern, extern=extern, mock=mock)
-
- return path
-
- @property
- def device(self):
- """Gets the device the model is on by looking at the device of
- the first parameter. May not be valid if model is split across
- multiple devices.
- """
- return list(self.parameters())[0].device
-
- @classmethod
- def load(
- cls,
- location: str,
- *args,
- package_name: str = None,
- strict: bool = False,
- **kwargs,
- ):
- """Load model from a path. Tries first to load as a package, and if
- that fails, tries to load as weights. The arguments to the class are
- specified inside the model weights file.
-
- Parameters
- ----------
- location : str
- Path to file.
- package_name : str, optional
- Name of package, by default ``cls.__name__``.
- strict : bool, optional
- Ignore unmatched keys, by default False
- kwargs : dict
- Additional keyword arguments to the model instantiation, if
- not loading from package.
-
- Returns
- -------
- BaseModel
- A model that inherits from BaseModel.
- """
- try:
- model = cls._load_package(location, package_name=package_name)
- except:
- model_dict = torch.load(location, "cpu")
- metadata = model_dict["metadata"]
- metadata["kwargs"].update(kwargs)
-
- sig = inspect.signature(cls)
- class_keys = list(sig.parameters.keys())
- for k in list(metadata["kwargs"].keys()):
- if k not in class_keys:
- metadata["kwargs"].pop(k)
-
- model = cls(*args, **metadata["kwargs"])
- model.load_state_dict(model_dict["state_dict"], strict=strict)
- model.metadata = metadata
-
- return model
-
- def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
- package_name = type(self).__name__
- resource_name = f"{type(self).__name__}.pth"
-
- # Below is for loading and re-saving a package.
- if hasattr(self, "importer"):
- kwargs["importer"] = (self.importer, torch.package.sys_importer)
- del self.importer
-
- # Why do we use a tempfile, you ask?
- # It's so we can load a packaged model and then re-save
- # it to the same location. torch.package throws an
- # error if it's loading and writing to the same
- # file (this is undocumented).
- with tempfile.NamedTemporaryFile(suffix=".pth") as f:
- with torch.package.PackageExporter(f.name, **kwargs) as exp:
- exp.intern(self.INTERN + intern)
- exp.mock(mock)
- exp.extern(self.EXTERN + extern)
- exp.save_pickle(package_name, resource_name, self)
-
- if hasattr(self, "metadata"):
- exp.save_pickle(
- package_name, f"{package_name}.metadata", self.metadata
- )
-
- shutil.copyfile(f.name, path)
-
- # Must reset the importer back to `self` if it existed
- # so that you can save the model again!
- if "importer" in kwargs:
- self.importer = kwargs["importer"][0]
- return path
-
- @classmethod
- def _load_package(cls, path, package_name=None):
- package_name = cls.__name__ if package_name is None else package_name
- resource_name = f"{package_name}.pth"
-
- imp = torch.package.PackageImporter(path)
- model = imp.load_pickle(package_name, resource_name, "cpu")
- try:
- model.metadata = imp.load_pickle(package_name, f"{package_name}.metadata")
- except: # pragma: no cover
- pass
- model.importer = imp
-
- return model
-
- def save_to_folder(
- self,
- folder: typing.Union[str, Path],
- extra_data: dict = None,
- package: bool = True,
- ):
- """Dumps a model into a folder, as both a package
- and as weights, as well as anything specified in
- ``extra_data``. ``extra_data`` is a dictionary of other
- pickleable files, with the keys being the paths
- to save them in. The model is saved under a subfolder
- specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
- if the model name was ``Generator``).
-
- >>> with tempfile.TemporaryDirectory() as d:
- >>> extra_data = {
- >>> "optimizer.pth": optimizer.state_dict()
- >>> }
- >>> model.save_to_folder(d, extra_data)
- >>> Model.load_from_folder(d)
-
- Parameters
- ----------
- folder : typing.Union[str, Path]
- Folder to save the model (and any extra data) into.
- extra_data : dict, optional
- Dictionary mapping relative file names to pickleable objects saved alongside the model, by default None
-
- Returns
- -------
- str
- Path to folder
- """
- extra_data = {} if extra_data is None else extra_data
- model_name = type(self).__name__.lower()
- target_base = Path(f"{folder}/{model_name}/")
- target_base.mkdir(exist_ok=True, parents=True)
-
- if package:
- package_path = target_base / f"package.pth"
- self.save(package_path)
-
- weights_path = target_base / f"weights.pth"
- self.save(weights_path, package=False)
-
- for path, obj in extra_data.items():
- torch.save(obj, target_base / path)
-
- return target_base
-
- @classmethod
- def load_from_folder(
- cls,
- folder: typing.Union[str, Path],
- package: bool = True,
- strict: bool = False,
- **kwargs,
- ):
- """Loads the model from a folder generated by
- :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
- Like that function, this one looks for a subfolder that has
- the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
- model name was ``Generator``).
-
- Parameters
- ----------
- folder : typing.Union[str, Path]
- Folder previously generated by ``save_to_folder``.
- package : bool, optional
- Whether to use ``torch.package`` to load the model,
- loading the model from ``package.pth``.
- strict : bool, optional
- Ignore unmatched keys, by default False
-
- Returns
- -------
- tuple
- tuple of model and extra data as saved by
- :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
- """
- folder = Path(folder) / cls.__name__.lower()
- model_pth = "package.pth" if package else "weights.pth"
- model_pth = folder / model_pth
-
- model = cls.load(model_pth, strict=strict)
- extra_data = {}
- excluded = ["package.pth", "weights.pth"]
- files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded]
- for f in files:
- extra_data[f.name] = torch.load(f, **kwargs)
-
- return model, extra_data
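A minimal sketch (not part of the deleted file) round-tripping a model through save_to_folder / load_from_folder as the docstrings above describe; the Model subclass is the toy one from the class docstring, and the optimizer state is an arbitrary example of extra data:

import tempfile
import torch

with tempfile.TemporaryDirectory() as d:
    model = Model(arg1=2.0)                      # toy BaseModel subclass from the docstring above
    optimizer = torch.optim.Adam(model.parameters())
    # package=False skips the torch.package export and saves plain weights only.
    model.save_to_folder(d, {"optimizer.pth": optimizer.state_dict()}, package=False)
    model2, extra = Model.load_from_folder(d, package=False)
    optimizer.load_state_dict(extra["optimizer.pth"])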
diff --git a/dito/models/ldm/dac/audiotools/ml/layers/spectral_gate.py b/dito/models/ldm/dac/audiotools/ml/layers/spectral_gate.py
deleted file mode 100644
index c4ae8b5eab2e56ce13541695f52a11a454759dae..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/ml/layers/spectral_gate.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-from ...core import AudioSignal
-from ...core import STFTParams
-from ...core import util
-
-
-class SpectralGate(nn.Module):
- """Spectral gating algorithm for noise reduction,
- as in Audacity/Ocenaudio. The steps are as follows:
-
- 1. An FFT is calculated over the noise audio clip
- 2. Statistics are calculated over the FFT of the noise
- (in frequency)
- 3. A threshold is calculated based upon the statistics
- of the noise (and the desired sensitivity of the algorithm)
- 4. An FFT is calculated over the signal
- 5. A mask is determined by comparing the signal FFT to the
- threshold
- 6. The mask is smoothed with a filter over frequency and time
- 7. The mask is applied to the FFT of the signal, and the result is inverted
-
- Implementation inspired by Tim Sainburg's noisereduce:
-
- https://timsainburg.com/noise-reduction-python.html
-
- Parameters
- ----------
- n_freq : int, optional
- Number of frequency bins to smooth by, by default 3
- n_time : int, optional
- Number of time bins to smooth by, by default 5
- """
-
- def __init__(self, n_freq: int = 3, n_time: int = 5):
- super().__init__()
-
- smoothing_filter = torch.outer(
- torch.cat(
- [
- torch.linspace(0, 1, n_freq + 2)[:-1],
- torch.linspace(1, 0, n_freq + 2),
- ]
- )[..., 1:-1],
- torch.cat(
- [
- torch.linspace(0, 1, n_time + 2)[:-1],
- torch.linspace(1, 0, n_time + 2),
- ]
- )[..., 1:-1],
- )
- smoothing_filter = smoothing_filter / smoothing_filter.sum()
- smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0)
- self.register_buffer("smoothing_filter", smoothing_filter)
-
- def forward(
- self,
- audio_signal: AudioSignal,
- nz_signal: AudioSignal,
- denoise_amount: float = 1.0,
- n_std: float = 3.0,
- win_length: int = 2048,
- hop_length: int = 512,
- ):
- """Perform noise reduction.
-
- Parameters
- ----------
- audio_signal : AudioSignal
- Audio signal that noise will be removed from.
- nz_signal : AudioSignal, optional
- Noise signal to compute noise statistics from.
- denoise_amount : float, optional
- Amount to denoise by, by default 1.0
- n_std : float, optional
- Number of standard deviations above which to consider
- noise, by default 3.0
- win_length : int, optional
- Length of window for STFT, by default 2048
- hop_length : int, optional
- Hop length for STFT, by default 512
-
- Returns
- -------
- AudioSignal
- Denoised audio signal.
- """
- stft_params = STFTParams(win_length, hop_length, "sqrt_hann")
-
- audio_signal = audio_signal.clone()
- audio_signal.stft_data = None
- audio_signal.stft_params = stft_params
-
- nz_signal = nz_signal.clone()
- nz_signal.stft_params = stft_params
-
- nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10()
- nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1)
- nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1)
-
- nz_thresh = nz_freq_mean + nz_freq_std * n_std
-
- stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10()
- nb, nac, nf, nt = stft_db.shape
- db_thresh = nz_thresh.expand(nb, nac, -1, nt)
-
- stft_mask = (stft_db < db_thresh).float()
- shape = stft_mask.shape
-
- stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt)
- pad_tuple = (
- self.smoothing_filter.shape[-2] // 2,
- self.smoothing_filter.shape[-1] // 2,
- )
- stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple)
- stft_mask = stft_mask.reshape(*shape)
- stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim).to(
- audio_signal.device
- )
- stft_mask = 1 - stft_mask
-
- audio_signal.stft_data *= stft_mask
- audio_signal.istft()
-
- return audio_signal
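A minimal sketch (not part of the deleted file) of denoising with the SpectralGate above, given a clip that contains only noise; the import path and the two wav files are assumptions:

from audiotools import AudioSignal  # assumed import path for this vendored copy

gate = SpectralGate(n_freq=3, n_time=5)
noisy = AudioSignal("noisy.wav")       # hypothetical recording to clean up
noise_only = AudioSignal("noise.wav")  # hypothetical noise-only excerpt for the statistics
clean = gate(noisy, noise_only, denoise_amount=0.8, n_std=3.0)
clean.write("denoised.wav")            # AudioSignal.write is assumed to exist in this copy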
diff --git a/dito/models/ldm/dac/audiotools/post.py b/dito/models/ldm/dac/audiotools/post.py
deleted file mode 100644
index 6ced2d1e66a4ffda3269685bd45593b01038739f..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/post.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import tempfile
-import typing
-import zipfile
-from pathlib import Path
-
-import markdown2 as md
-import matplotlib.pyplot as plt
-import torch
-from IPython.display import HTML
-
-
-def audio_table(
- audio_dict: dict,
- first_column: str = None,
- format_fn: typing.Callable = None,
- **kwargs,
-): # pragma: no cover
- """Embeds an audio table into HTML, or as the output cell
- in a notebook.
-
- Parameters
- ----------
- audio_dict : dict
- Dictionary of data to embed.
- first_column : str, optional
- The label for the first column of the table, by default None
- format_fn : typing.Callable, optional
- How to format the data, by default None
-
- Returns
- -------
- str
- Table as a string
-
- Examples
- --------
-
- >>> audio_dict = {}
- >>> for i in range(signal_batch.batch_size):
- >>> audio_dict[i] = {
- >>> "input": signal_batch[i],
- >>> "output": output_batch[i]
- >>> }
- >>> audiotools.post.audio_zip(audio_dict)
-
- """
- from audiotools import AudioSignal
-
- output = []
- columns = None
-
- def _default_format_fn(label, x, **kwargs):
- if torch.is_tensor(x):
- x = x.tolist()
-
- if x is None:
- return "."
- elif isinstance(x, AudioSignal):
- return x.embed(display=False, return_html=True, **kwargs)
- else:
- return str(x)
-
- if format_fn is None:
- format_fn = _default_format_fn
-
- if first_column is None:
- first_column = "."
-
- for k, v in audio_dict.items():
- if not isinstance(v, dict):
- v = {"Audio": v}
-
- v_keys = list(v.keys())
- if columns is None:
- columns = [first_column] + v_keys
- output.append(" | ".join(columns))
-
- layout = "|---" + len(v_keys) * "|:-:"
- output.append(layout)
-
- formatted_audio = []
- for col in columns[1:]:
- formatted_audio.append(format_fn(col, v[col], **kwargs))
-
- row = f"| {k} | "
- row += " | ".join(formatted_audio)
- output.append(row)
-
- output = "\n" + "\n".join(output)
- return output
-
-
-def in_notebook(): # pragma: no cover
- """Determines if code is running in a notebook.
-
- Returns
- -------
- bool
- Whether or not this is running in a notebook.
- """
- try:
- from IPython import get_ipython
-
- if "IPKernelApp" not in get_ipython().config: # pragma: no cover
- return False
- except ImportError:
- return False
- except AttributeError:
- return False
- return True
-
-
-def disp(obj, **kwargs): # pragma: no cover
- """Displays an object, depending on if its in a notebook
- or not.
-
- Parameters
- ----------
- obj : typing.Any
- Any object to display.
-
- """
- from audiotools import AudioSignal
-
- IN_NOTEBOOK = in_notebook()
-
- if isinstance(obj, AudioSignal):
- audio_elem = obj.embed(display=False, return_html=True)
- if IN_NOTEBOOK:
- return HTML(audio_elem)
- else:
- print(audio_elem)
- if isinstance(obj, dict):
- table = audio_table(obj, **kwargs)
- if IN_NOTEBOOK:
- return HTML(md.markdown(table, extras=["tables"]))
- else:
- print(table)
- if isinstance(obj, plt.Figure):
- plt.show()
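Taken together, the removed `post.py` helpers compose roughly as below. The file names and dictionary keys in this sketch are illustrative, and it assumes `audio_table`, `disp`, and `audiotools.AudioSignal` from the code above are importable in the current environment.

```python
from audiotools import AudioSignal

# Hypothetical example files and labels.
rows = {
    "sample_0": {"input": AudioSignal("input_0.wav"), "output": AudioSignal("output_0.wav")},
    "sample_1": {"input": AudioSignal("input_1.wav"), "output": AudioSignal("output_1.wav")},
}

# Markdown table with one row per key and an embedded player per column.
table = audio_table(rows, first_column="example")

# In a notebook this renders the table as HTML; otherwise it prints the markdown.
disp(rows, first_column="example")
```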
diff --git a/dito/models/ldm/dac/audiotools/preference.py b/dito/models/ldm/dac/audiotools/preference.py
deleted file mode 100644
index 800a852e8119dd18ea65784cf95182de2470fbc4..0000000000000000000000000000000000000000
--- a/dito/models/ldm/dac/audiotools/preference.py
+++ /dev/null
@@ -1,600 +0,0 @@
-##############################################################
-### Tools for creating preference tests (MUSHRA, ABX, etc) ###
-##############################################################
-import copy
-import csv
-import random
-import sys
-import traceback
-from collections import defaultdict
-from pathlib import Path
-from typing import List
-
-import gradio as gr
-
-from audiotools.core.util import find_audio
-
-################################################################
-### Logic for audio player, and adding audio / play buttons. ###
-################################################################
-
-WAVESURFER = """"""
-
-CUSTOM_CSS = """
-.gradio-container {
- max-width: 840px !important;
-}
-region.wavesurfer-region:before {
- content: attr(data-region-label);
-}
-
-block {
- min-width: 0 !important;
-}
-
-#wave-timeline {
- background-color: rgba(0, 0, 0, 0.8);
-}
-
-.head.svelte-1cl284s {
- display: none;
-}
-"""
-
-load_wavesurfer_js = """
-function load_wavesurfer() {
- function load_script(url) {
- const script = document.createElement('script');
- script.src = url;
- document.body.appendChild(script);
-
- return new Promise((res, rej) => {
- script.onload = function() {
- res();
- }
- script.onerror = function () {
- rej();
- }
- });
- }
-
- function create_wavesurfer() {
- var options = {
- container: '#waveform',
- waveColor: '#F2F2F2', // light grey waveform
- progressColor: 'white', // white progress overlay
- loaderColor: 'white', // white loading indicator
- cursorColor: 'black', // black playback cursor
- backgroundColor: '#00AAFF', // blue background behind the waveform
- barWidth: 4,
- barRadius: 3,
- barHeight: 1, // the height of the wave
- plugins: [
- WaveSurfer.regions.create({
- regionsMinLength: 0.0,
- dragSelection: {
- slop: 5
- },
- color: 'hsla(200, 50%, 70%, 0.4)',
- }),
- WaveSurfer.timeline.create({
- container: "#wave-timeline",
- primaryLabelInterval: 5.0,
- secondaryLabelInterval: 1.0,
- primaryFontColor: '#F2F2F2',
- secondaryFontColor: '#F2F2F2',
- }),
- ]
- };
- wavesurfer = WaveSurfer.create(options);
- wavesurfer.on('region-created', region => {
- wavesurfer.regions.clear();
- });
- wavesurfer.on('finish', function () {
- var loop = document.getElementById("loop-button").textContent.includes("ON");
- if (loop) {
- wavesurfer.play();
- }
- else {
- var button_elements = document.getElementsByClassName('playpause')
- var buttons = Array.from(button_elements);
-
- for (let j = 0; j < buttons.length; j++) {
- buttons[j].classList.remove("primary");
- buttons[j].classList.add("secondary");
- buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
- }
- }
- });
-
- wavesurfer.on('region-out', function () {
- var loop = document.getElementById("loop-button").textContent.includes("ON");
- if (!loop) {
- var button_elements = document.getElementsByClassName('playpause')
- var buttons = Array.from(button_elements);
-
- for (let j = 0; j < buttons.length; j++) {
- buttons[j].classList.remove("primary");
- buttons[j].classList.add("secondary");
- buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
- }
- wavesurfer.pause();
- }
- });
-
- console.log("Created WaveSurfer object.")
- }
-
- load_script('https://unpkg.com/wavesurfer.js@6.6.4')
- .then(() => {
- load_script("https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.timeline.min.js")
- .then(() => {
- load_script('https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.regions.min.js')
- .then(() => {
- console.log("Loaded regions");
- create_wavesurfer();
- document.getElementById("start-survey").click();
- })
- })
- });
-}
-"""
-
-play = lambda i: """
-function play() {
- var audio_elements = document.getElementsByTagName('audio');
- var button_elements = document.getElementsByClassName('playpause')
-
- var audio_array = Array.from(audio_elements);
- var buttons = Array.from(button_elements);
-
- var src_link = audio_array[{i}].getAttribute("src");
- console.log(src_link);
-
- var loop = document.getElementById("loop-button").textContent.includes("ON");
- var playing = buttons[{i}].textContent.includes("Stop");
-
- for (let j = 0; j < buttons.length; j++) {
- if (j != {i} || playing) {
- buttons[j].classList.remove("primary");
- buttons[j].classList.add("secondary");
- buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
- }
- else {
- buttons[j].classList.remove("secondary");
- buttons[j].classList.add("primary");
- buttons[j].textContent = buttons[j].textContent.replace("Play", "Stop")
- }
- }
-
- if (playing) {
- wavesurfer.pause();
- wavesurfer.seekTo(0.0);
- }
- else {
- wavesurfer.load(src_link);
- wavesurfer.on('ready', function () {
- var region = Object.values(wavesurfer.regions.list)[0];
-
- if (region != null) {
- region.loop = loop;
- region.play();
- } else {
- wavesurfer.play();
- }
- });
- }
-}
-""".replace(
- "{i}", str(i)
-)
-
-clear_regions = """
-function clear_regions() {
- wavesurfer.clearRegions();
-}
-"""
-
-reset_player = """
-function reset_player() {
- wavesurfer.clearRegions();
- wavesurfer.pause();
- wavesurfer.seekTo(0.0);
-
- var button_elements = document.getElementsByClassName('playpause')
- var buttons = Array.from(button_elements);
-
- for (let j = 0; j < buttons.length; j++) {
- buttons[j].classList.remove("primary");
- buttons[j].classList.add("secondary");
- buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
- }
-}
-"""
-
-loop_region = """
-function loop_region() {
- var element = document.getElementById("loop-button");
- var loop = element.textContent.includes("OFF");
- console.log(loop);
-
- try {
- var region = Object.values(wavesurfer.regions.list)[0];
- region.loop = loop;
- } catch {}
-
- if (loop) {
- element.classList.remove("secondary");
- element.classList.add("primary");
- element.textContent = "Looping ON";
- } else {
- element.classList.remove("primary");
- element.classList.add("secondary");
- element.textContent = "Looping OFF";
- }
-}
-"""
-
-
-class Player:
- def __init__(self, app):
- self.app = app
-
- self.app.load(_js=load_wavesurfer_js)
- self.app.css = CUSTOM_CSS
-
- self.wavs = []
- self.position = 0
-
- def create(self):
- gr.HTML(WAVESURFER)
- gr.Markdown(
- "Click and drag on the waveform above to select a region for playback. "
- "Once created, the region can be moved around and resized. "
- "Clear the regions using the button below. Hit play on one of the buttons below to start!"
- )
-
- with gr.Row():
- clear = gr.Button("Clear region")
- loop = gr.Button("Looping OFF", elem_id="loop-button")
-
- loop.click(None, _js=loop_region)
- clear.click(None, _js=clear_regions)
-
- gr.HTML("")
-
- def add(self, name: str = "Play"):
- i = self.position
- self.wavs.append(
- {
- "audio": gr.Audio(visible=False),
- "button": gr.Button(name, elem_classes=["playpause"]),
- "position": i,
- }
- )
- self.wavs[-1]["button"].click(None, _js=play(i))
- self.position += 1
- return self.wavs[-1]
-
- def to_list(self):
- return [x["audio"] for x in self.wavs]
-
-
-############################################################
-### Keeping track of users, and CSS for the progress bar ###
-############################################################
-
-load_tracker = lambda name: """
-function load_name() {
- function setCookie(name, value, exp_days) {
- var d = new Date();
- d.setTime(d.getTime() + (exp_days*24*60*60*1000));
- var expires = "expires=" + d.toGMTString();
- document.cookie = name + "=" + value + ";" + expires + ";path=/";
- }
-
- function getCookie(name) {
- var cname = name + "=";
- var decodedCookie = decodeURIComponent(document.cookie);
- var ca = decodedCookie.split(';');
- for(var i = 0; i < ca.length; i++){
- var c = ca[i];
- while(c.charAt(0) == ' '){
- c = c.substring(1);
- }
- if(c.indexOf(cname) == 0){
- return c.substring(cname.length, c.length);
- }
- }
- return "";
- }
-
- name = getCookie("{name}");
- if (name == "") {
- name = Math.random().toString(36).slice(2);
- console.log(name);
- setCookie("name", name, 30);
- }
- name = getCookie("{name}");
- return name;
-}
-""".replace(
- "{name}", name
-)
-
-# Progress bar
-
-progress_template = """
-
-
-
- Progress Bar
-
-
-
-
-
-
- {TEXT}
-
-
-
-"""
-
-
-def create_tracker(app, cookie_name="name"):
- user = gr.Text(label="user", interactive=True, visible=False, elem_id="user")
- app.load(_js=load_tracker(cookie_name), outputs=user)
- return user
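For orientation, here is a rough sketch of how the `Player` class and `create_tracker` above would be wired into a Gradio Blocks app. It assumes the Gradio 3.x `_js`-based API used throughout this file, and the button labels are made up.

```python
import gradio as gr

with gr.Blocks() as app:
    user = create_tracker(app)       # hidden textbox populated from the per-user cookie
    player = Player(app)             # injects the wavesurfer loader JS and the custom CSS
    player.create()                  # waveform container plus "Clear region" / looping buttons

    # One hidden gr.Audio and one play/stop button per condition under test.
    ref = player.add("Play Reference")
    cond_a = player.add("Play Condition A")

app.launch()
```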
-
-
-#################################################################
-### CSS and HTML for labeling sliders for both ABX and MUSHRA ###
-#################################################################
-
-slider_abx = """
-
-
-
-
- Labels Example
-
-
-
-