diff --git a/dito/.gitignore b/dito/.gitignore deleted file mode 100644 index ed8ebf583f771da9150c35db3955987b7d757904..0000000000000000000000000000000000000000 --- a/dito/.gitignore +++ /dev/null @@ -1 +0,0 @@ -__pycache__ \ No newline at end of file diff --git a/dito/configs/datasets/dae.yaml b/dito/configs/datasets/dae.yaml deleted file mode 100644 index 8fd9018ae2636223cb3c70f192dd58ae397c1c56..0000000000000000000000000000000000000000 --- a/dito/configs/datasets/dae.yaml +++ /dev/null @@ -1,79 +0,0 @@ -# Datasets -datasets: - train: - name: wrapper_audio_cae - args: - dataset: - name: audio_dataset_from_folders - args: - folders: - Emilia_EN: ["/home/masuser/minimax-audio/dataset/Emilia/EN"] - sample_rate: 24000 - duration: 0.38 - n_examples: 10000000 - shuffle: true - mono: true - sample_rate: 24000 - duration: 0.38 - mono: true - normalize: true - return_coords: true - loader: - batch_size: 64 - num_workers: 8 - drop_last: true - - val: - name: wrapper_audio_cae - args: - dataset: - name: audio_dataset_from_folders - args: - folders: - Emilia_EN: ["/home/masuser/minimax-audio/dataset/libritts"] - sample_rate: 24000 - duration: 5.0 - n_examples: 100 - shuffle: false - mono: true - sample_rate: 24000 - duration: 5.0 - mono: true - normalize: true - return_coords: true - loader: - batch_size: 4 - num_workers: 8 - drop_last: false - - eval_ae: - name: wrapper_audio_cae - args: - dataset: - name: audio_dataset_from_folders - args: - folders: - Emilia_EN: ["/home/masuser/minimax-audio/dataset/libritts"] - sample_rate: 24000 - duration: 10.0 - n_examples: 1000 - shuffle: false - mono: true - sample_rate: 24000 - duration: 10.0 - mono: true - normalize: true - return_coords: true - loader: - batch_size: 1 - num_workers: 8 - drop_last: false - -# Visualization -visualize_ae_dir: /mnt/nvme/dito_audio -visualize_ae_random_n_samples: 32 -eval_ae_max_samples: 100 -val_idx: [0, 1, 2, 3, 4, 5, 6, 7] - -# Enable autoencoder evaluation -evaluate_ae: true \ No newline at end of file diff --git a/dito/configs/datasets/imagenet_ae.yaml b/dito/configs/datasets/imagenet_ae.yaml deleted file mode 100644 index 5109901f168de8d1f9b30a301fb087e38fcd440e..0000000000000000000000000000000000000000 --- a/dito/configs/datasets/imagenet_ae.yaml +++ /dev/null @@ -1,47 +0,0 @@ -datasets: - train: - name: wrapper_cae - args: - dataset: - name: class_folder - args: {root_path: /home/masuser/minimax-audio/mnist_png/training, resize: 256, rand_crop: 256, rand_flip: true, image_only: true} - resize_inp: 256 - gt_glores_lb: 256 - gt_glores_ub: 256 - gt_patch_size: 256 - loader: - batch_size: 14 - num_workers: 24 - - val: - name: wrapper_cae - args: - dataset: - name: class_folder - args: {root_path: /home/masuser/minimax-audio/mnist_png/testing, resize: 256, square_crop: true, image_only: true} - resize_inp: 256 - gt_glores_lb: 256 - gt_glores_ub: 256 - gt_patch_size: 256 - loader: - batch_size: 14 - num_workers: 24 - - eval_ae: - name: wrapper_cae - args: - dataset: - name: class_folder - args: {root_path: /home/masuser/minimax-audio/mnist_png/testing, resize: 256, square_crop: true, image_only: true} - resize_inp: 256 - gt_glores_lb: 256 - gt_glores_ub: 256 - gt_patch_size: 256 - loader: - batch_size: 14 - num_workers: 24 - drop_last: false - -visualize_ae_dir: /mnt/nvme/dito -visualize_ae_random_n_samples: 32 -eval_ae_max_samples: 5000 \ No newline at end of file diff --git a/dito/configs/datasets/imagenet_zdm.yaml b/dito/configs/datasets/imagenet_zdm.yaml deleted file mode 100644 index 
2a54041e020d8ede4c9e3e8647c6c05c5339f0af..0000000000000000000000000000000000000000 --- a/dito/configs/datasets/imagenet_zdm.yaml +++ /dev/null @@ -1,53 +0,0 @@ -datasets: - train: - name: wrapper_cae - args: - dataset: - name: class_folder - args: {root_path: /home/masuser/minimax-audio/mnist_png/training, resize: 256, square_crop: true, rand_flip: true, drop_label_p: 0.1} - resize_inp: 256 - gt_glores_lb: 256 - gt_glores_ub: 256 - gt_patch_size: 256 - loader: - batch_size: 64 - num_workers: 24 - - val: - name: wrapper_cae - args: - dataset: - name: class_folder - args: {root_path: /home/masuser/minimax-audio/mnist_png/testing, resize: 256, square_crop: true} - resize_inp: 256 - gt_glores_lb: 256 - gt_glores_ub: 256 - gt_patch_size: 256 - loader: - batch_size: 64 - num_workers: 24 - - eval_zdm: - name: wrapper_cae - args: - dataset: - name: class_folder - args: {root_path: /home/masuser/minimax-audio/mnist_png/testing, resize: 256, square_crop: true} - resize_inp: 256 - gt_glores_lb: 256 - gt_glores_ub: 256 - gt_patch_size: 256 - loader: - batch_size: 64 - num_workers: 24 - drop_last: false - -visualize_zdm_file: null -visualize_zdm_setting: - name: class - n_classes: 1000 -visualize_zdm_random_n_samples: 12 -visualize_zdm_batch_size: 6 -visualize_zdm_guidance_list: [4] -visualize_zdm_denoising_file: null -eval_zdm_max_samples: 5000 \ No newline at end of file diff --git a/dito/configs/experiments/dito-B-audio.yaml b/dito/configs/experiments/dito-B-audio.yaml deleted file mode 100644 index 179a84a947164ba7a687fdbddde627970989fb4b..0000000000000000000000000000000000000000 --- a/dito/configs/experiments/dito-B-audio.yaml +++ /dev/null @@ -1,44 +0,0 @@ -__base__: - - configs/datasets/dae.yaml - - configs/trainers/dito.yaml - -model: - name: dito_audio - args: - # Encoder - encoder: - name: dac_encoder - args: {config_name: snakebeta} - - # Latent configuration - now fully convolutional - z_channels: 64 # Number of latent channels - z_downsample_factor: 320 # Product of encoder_rates: 2*4*5*8 - z_layernorm: true - - # Decoder (identity for DiTo) - decoder: - name: identity - - # Renderer - Fully convolutional for dynamic duration - renderer: - name: audio_renderer_wrapper - args: - net: - name: consistency_decoder_unet # Fully Convolutional Network - args: - in_channels: 1 - z_dec_channels: 64 - c0: 128 - c1: 256 - c2: 512 - pe_dim: 320 - t_dim: 1280 - - # Diffusion configuration - render_diffusion: - name: fm - args: {timescale: 1000.0} - - render_sampler: {name: fm_euler_sampler} - render_n_steps: 50 - diff --git a/dito/configs/experiments/dito-B-f8c4-noise-sync.yaml b/dito/configs/experiments/dito-B-f8c4-noise-sync.yaml deleted file mode 100644 index 5eb1df5b78efb87d3aea6d9850924a8b85cc8da1..0000000000000000000000000000000000000000 --- a/dito/configs/experiments/dito-B-f8c4-noise-sync.yaml +++ /dev/null @@ -1,43 +0,0 @@ -__base__: - - configs/datasets/imagenet_ae.yaml - - configs/trainers/dito.yaml - -model: - name: dito - args: - encoder: - name: vqgan_encoder - args: {config_name: f8c4} - - z_shape: [64, 1, 1] - z_layernorm: true - - zaug_p: 0.1 - zaug_decoding_loss_type: suffix - zaug_zdm_diffusion: - name: fm - args: {timescale: 1000.0} - - decoder: {name: identity} - - renderer: - name: fixres_renderer_wrapper - args: - net: - name: consistency_decoder_unet - args: - in_channels: 3 - z_dec_channels: 64 - c0: 128 - c1: 256 - c2: 512 - pe_dim: 320 - t_dim: 1280 - - render_diffusion: - name: fm - args: {timescale: 1000.0} - render_sampler: {name: fm_euler_sampler} - render_n_steps: 50 - - 
loss_config: {} diff --git a/dito/configs/experiments/dito-B-f8c4.yaml b/dito/configs/experiments/dito-B-f8c4.yaml deleted file mode 100644 index feeab68b5dace98a9701bbc16a27f56d81e0dfc7..0000000000000000000000000000000000000000 --- a/dito/configs/experiments/dito-B-f8c4.yaml +++ /dev/null @@ -1,37 +0,0 @@ -__base__: - - configs/datasets/imagenet_ae.yaml - - configs/trainers/dito.yaml - -model: - name: dito - args: - encoder: - name: vqgan_encoder - args: {config_name: f8c4} - - z_shape: [4, 32, 32] - z_layernorm: true - - decoder: {name: identity} - - renderer: - name: fixres_renderer_wrapper - args: - net: - name: consistency_decoder_unet - args: - in_channels: 3 - z_dec_channels: 4 - c0: 128 - c1: 256 - c2: 512 - pe_dim: 320 - t_dim: 1280 - - render_diffusion: - name: fm - args: {timescale: 1000.0} - render_sampler: {name: fm_euler_sampler} - render_n_steps: 50 - - loss_config: {} diff --git a/dito/configs/experiments/dito-L-f8c4.yaml b/dito/configs/experiments/dito-L-f8c4.yaml deleted file mode 100644 index 5f242bf6f27755c559e439eba78ae03809f07898..0000000000000000000000000000000000000000 --- a/dito/configs/experiments/dito-L-f8c4.yaml +++ /dev/null @@ -1,37 +0,0 @@ -__base__: - - configs/datasets/imagenet_ae.yaml - - configs/trainers/dito.yaml - -model: - name: dito - args: - encoder: - name: vqgan_encoder - args: {config_name: f8c4} - - z_shape: [4, 32, 32] - z_layernorm: true - - decoder: {name: identity} - - renderer: - name: fixres_renderer_wrapper - args: - net: - name: consistency_decoder_unet - args: - in_channels: 3 - z_dec_channels: 4 - c0: 192 - c1: 384 - c2: 768 - pe_dim: 320 - t_dim: 1280 - - render_diffusion: - name: fm - args: {timescale: 1000.0} - render_sampler: {name: fm_euler_sampler} - render_n_steps: 50 - - loss_config: {} diff --git a/dito/configs/experiments/dito-XL-f8c4-noise-sync.yaml b/dito/configs/experiments/dito-XL-f8c4-noise-sync.yaml deleted file mode 100644 index ec7761d6570e5456b8099dd6a1eb28754de8c953..0000000000000000000000000000000000000000 --- a/dito/configs/experiments/dito-XL-f8c4-noise-sync.yaml +++ /dev/null @@ -1,43 +0,0 @@ -__base__: - - configs/datasets/imagenet_ae.yaml - - configs/trainers/dito.yaml - -model: - name: dito - args: - encoder: - name: vqgan_encoder - args: {config_name: f8c4} - - z_shape: [4, 32, 32] - z_layernorm: true - - zaug_p: 0.1 - zaug_decoding_loss_type: suffix - zaug_zdm_diffusion: - name: fm - args: {timescale: 1000.0} - - decoder: {name: identity} - - renderer: - name: fixres_renderer_wrapper - args: - net: - name: consistency_decoder_unet - args: - in_channels: 3 - z_dec_channels: 4 - c0: 320 - c1: 640 - c2: 1024 - pe_dim: 320 - t_dim: 1280 - - render_diffusion: - name: fm - args: {timescale: 1000.0} - render_sampler: {name: fm_euler_sampler} - render_n_steps: 50 - - loss_config: {} diff --git a/dito/configs/experiments/dito-XL-f8c4.yaml b/dito/configs/experiments/dito-XL-f8c4.yaml deleted file mode 100644 index 8610d23c1d1479538b1dd5d568d427f08ab7e9ec..0000000000000000000000000000000000000000 --- a/dito/configs/experiments/dito-XL-f8c4.yaml +++ /dev/null @@ -1,37 +0,0 @@ -__base__: - - configs/datasets/imagenet_ae.yaml - - configs/trainers/dito.yaml - -model: - name: dito - args: - encoder: - name: vqgan_encoder - args: {config_name: f8c4} - - z_shape: [4, 32, 32] - z_layernorm: true - - decoder: {name: identity} - - renderer: - name: fixres_renderer_wrapper - args: - net: - name: consistency_decoder_unet - args: - in_channels: 3 - z_dec_channels: 4 - c0: 320 - c1: 640 - c2: 1024 - pe_dim: 320 - t_dim: 1280 - - 
render_diffusion: - name: fm - args: {timescale: 1000.0} - render_sampler: {name: fm_euler_sampler} - render_n_steps: 50 - - loss_config: {} diff --git a/dito/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4-noise-sync.yaml b/dito/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4-noise-sync.yaml deleted file mode 100644 index 679cc9fa97413b19ca1d1581c0a008c4107b7bb9..0000000000000000000000000000000000000000 --- a/dito/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4-noise-sync.yaml +++ /dev/null @@ -1,44 +0,0 @@ -__base__: - - configs/datasets/imagenet_zdm.yaml - - configs/models/zdm-XL_imagenet.yaml - - configs/trainers/zdm.yaml - -eval_zdm_max_samples: 50000 - -model: - load_ckpt: save/zdm-XL_dito-XL-f8c4-noise-sync/ckpt-last.pth - name: dito - args: - zdm_force_guidance: 2.0 - renderer_ema_rate: 1 - - encoder: - name: vqgan_encoder - args: {config_name: f8c4} - - z_shape: [4, 32, 32] - z_layernorm: true - - decoder: {name: identity} - - renderer: - name: fixres_renderer_wrapper - args: - net: - name: consistency_decoder_unet - args: - in_channels: 3 - z_dec_channels: 4 - c0: 320 - c1: 640 - c2: 1024 - pe_dim: 320 - t_dim: 1280 - - render_diffusion: - name: fm - args: {timescale: 1000.0} - render_sampler: {name: fm_euler_sampler} - render_n_steps: 50 - - loss_config: {} diff --git a/dito/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4.yaml b/dito/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4.yaml deleted file mode 100644 index f5d360f8536c7652bf77cdeeaeff2dc33885136e..0000000000000000000000000000000000000000 --- a/dito/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4.yaml +++ /dev/null @@ -1,44 +0,0 @@ -__base__: - - configs/datasets/imagenet_zdm.yaml - - configs/models/zdm-XL_imagenet.yaml - - configs/trainers/zdm.yaml - -eval_zdm_max_samples: 50000 - -model: - load_ckpt: save/zdm-XL_dito-XL-f8c4/ckpt-last.pth - name: dito - args: - zdm_force_guidance: 2.0 - renderer_ema_rate: 1 - - encoder: - name: vqgan_encoder - args: {config_name: f8c4} - - z_shape: [4, 32, 32] - z_layernorm: true - - decoder: {name: identity} - - renderer: - name: fixres_renderer_wrapper - args: - net: - name: consistency_decoder_unet - args: - in_channels: 3 - z_dec_channels: 4 - c0: 320 - c1: 640 - c2: 1024 - pe_dim: 320 - t_dim: 1280 - - render_diffusion: - name: fm - args: {timescale: 1000.0} - render_sampler: {name: fm_euler_sampler} - render_n_steps: 50 - - loss_config: {} diff --git a/dito/configs/experiments/zdm-XL_dito-XL-f8c4-noise-sync.yaml b/dito/configs/experiments/zdm-XL_dito-XL-f8c4-noise-sync.yaml deleted file mode 100644 index 283a17ad9a677a47a152a5e2b5a44a3905183dae..0000000000000000000000000000000000000000 --- a/dito/configs/experiments/zdm-XL_dito-XL-f8c4-noise-sync.yaml +++ /dev/null @@ -1,41 +0,0 @@ -__base__: - - configs/datasets/imagenet_zdm.yaml - - configs/models/zdm-XL_imagenet.yaml - - configs/trainers/zdm.yaml - -model: - load_ckpt: save/dito-XL-f8c4-noise-sync/ckpt-last.pth - name: dito - args: - renderer_ema_rate: 1 - - encoder: - name: vqgan_encoder - args: {config_name: f8c4} - - z_shape: [4, 32, 32] - z_layernorm: true - - decoder: {name: identity} - - renderer: - name: fixres_renderer_wrapper - args: - net: - name: consistency_decoder_unet - args: - in_channels: 3 - z_dec_channels: 4 - c0: 320 - c1: 640 - c2: 1024 - pe_dim: 320 - t_dim: 1280 - - render_diffusion: - name: fm - args: {timescale: 1000.0} - render_sampler: {name: fm_euler_sampler} - render_n_steps: 50 - - loss_config: {} diff --git a/dito/configs/experiments/zdm-XL_dito-XL-f8c4.yaml 
b/dito/configs/experiments/zdm-XL_dito-XL-f8c4.yaml deleted file mode 100644 index eeee58d880334bb1806ec51b8772f72bd1f34a75..0000000000000000000000000000000000000000 --- a/dito/configs/experiments/zdm-XL_dito-XL-f8c4.yaml +++ /dev/null @@ -1,41 +0,0 @@ -__base__: - - configs/datasets/imagenet_zdm.yaml - - configs/models/zdm-XL_imagenet.yaml - - configs/trainers/zdm.yaml - -model: - load_ckpt: - name: dito - args: - renderer_ema_rate: 1 - - encoder: - name: vqgan_encoder - args: {config_name: f8c4} - - z_shape: [4, 32, 32] - z_layernorm: true - - decoder: {name: identity} - - renderer: - name: fixres_renderer_wrapper - args: - net: - name: consistency_decoder_unet - args: - in_channels: 3 - z_dec_channels: 4 - c0: 320 - c1: 640 - c2: 1024 - pe_dim: 320 - t_dim: 1280 - - render_diffusion: - name: fm - args: {timescale: 1000.0} - render_sampler: {name: fm_euler_sampler} - render_n_steps: 50 - - loss_config: {} diff --git a/dito/configs/models/zdm-XL_imagenet.yaml b/dito/configs/models/zdm-XL_imagenet.yaml deleted file mode 100644 index 27cb8116834b7f90c09c612f39507f9acf541d4e..0000000000000000000000000000000000000000 --- a/dito/configs/models/zdm-XL_imagenet.yaml +++ /dev/null @@ -1,12 +0,0 @@ -model: - args: - zdm_net: - name: dit_xl_2 - args: {n_classes: 1001} - zdm_diffusion: - name: fm - args: {timescale: 1000.0} - zdm_sampler: {name: fm_euler_sampler} - zdm_n_steps: 200 - zdm_train_normalize: false - zdm_class_cond: 1000 \ No newline at end of file diff --git a/dito/datasets/__init__.py b/dito/datasets/__init__.py deleted file mode 100644 index 52749f4dfe1a9a36afab67115bdce2f243851bd6..0000000000000000000000000000000000000000 --- a/dito/datasets/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .datasets import register, make -from . import image_folder, class_folder, webdataset -from . 
import wrapper_cae diff --git a/dito/datasets/class_folder.py b/dito/datasets/class_folder.py deleted file mode 100644 index 91070b4545e1752c64c011d2f2c07b7e3bbb8b3f..0000000000000000000000000000000000000000 --- a/dito/datasets/class_folder.py +++ /dev/null @@ -1,89 +0,0 @@ -import os -import random -from PIL import Image, ImageFile - -from datasets import register -from torch.utils.data import Dataset -from torchvision import transforms - - -Image.MAX_IMAGE_PIXELS = 933120000 -ImageFile.LOAD_TRUNCATED_IMAGES = True -IMAGE_EXTS = ('.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.webp') - - -@register('class_folder') -class ClassFolder(Dataset): - - def __init__(self, root_path, resize=None, square_crop=False, rand_crop=None, rand_flip=False, drop_label_p=0.0, image_only=False): - folders = [] - print('root_path', root_path) - for folder in sorted(os.listdir(root_path)): - print('folder', folder) - if os.path.isdir(os.path.join(root_path, folder)): - folders.append(os.path.join(root_path, folder)) - print('folders', folders) - self.files = [] - self.labels = [] - for i, folder in enumerate(folders): - for file in sorted(os.listdir(os.path.join(root_path, folder))): - if file.endswith(IMAGE_EXTS): - self.files.append(os.path.join(root_path, folder, file)) - self.labels.append(i) - - self.resize = resize - self.square_crop = square_crop - self.rand_crop = rand_crop - self.rand_flip = transforms.RandomHorizontalFlip() if rand_flip else None - - self.n_classes = len(folders) - self.drop_label_p = drop_label_p - - self.image_only = image_only - - def __len__(self): - return len(self.files) - - def __getitem__(self, idx): - try: - image = Image.open(self.files[idx]).convert('RGB') - label = self.labels[idx] - except: - print('Error loading image:', self.files[idx]) - return self.__getitem__((idx + 1) % self.__len__()) - - if self.resize is not None: - r = self.resize - if isinstance(r, int): - w, h = image.size - if w < h: - r = (r, int(h / w * r)) - else: - r = (int(w / h * r), r) - image = image.resize(r, Image.LANCZOS) - - if self.square_crop: - w, h = image.size - l = min(w, h) - left, upper = (w - l) // 2, (h - l) // 2 - image = image.crop((left, upper, left + l, upper + l)) - - if self.rand_crop is not None: - w, h = image.size - left = random.randint(0, w - self.rand_crop) - upper = random.randint(0, h - self.rand_crop) - image = image.crop((left, upper, left + self.rand_crop, upper + self.rand_crop)) - - if self.rand_flip is not None: - image = self.rand_flip(image) - - if self.drop_label_p > 0.0 and random.random() < self.drop_label_p: - label = self.n_classes - - if self.image_only: - return image - else: - return { - 'image': image, - 'class_labels': label, - } diff --git a/dito/datasets/datasets.py b/dito/datasets/datasets.py deleted file mode 100644 index 3f102233b144c4c04c35204fd97b052236502661..0000000000000000000000000000000000000000 --- a/dito/datasets/datasets.py +++ /dev/null @@ -1,17 +0,0 @@ -datasets = dict() - - -def register(name): - def decorator(cls): - datasets[name] = cls - return cls - return decorator - - -def make(spec): - args = spec.get('args') - if args is None: - args = dict() - print('args:', args) - dataset = datasets[spec['name']](**args) - return dataset diff --git a/dito/datasets/image_folder.py b/dito/datasets/image_folder.py deleted file mode 100644 index 7564ef3fa31ea6f6aadd26a2d37b77c213e00095..0000000000000000000000000000000000000000 --- a/dito/datasets/image_folder.py +++ /dev/null @@ -1,62 +0,0 @@ -import os -import random -from PIL import Image, 
ImageFile - -from datasets import register -from torch.utils.data import Dataset -from torchvision import transforms - - -Image.MAX_IMAGE_PIXELS = 933120000 -ImageFile.LOAD_TRUNCATED_IMAGES = True -IMAGE_EXTS = ('.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.webp') - - -@register('image_folder') -class ImageFolder(Dataset): - - def __init__(self, root_path, resize=None, square_crop=False, rand_crop=None, rand_flip=False): - files = sorted(os.listdir(root_path)) - self.files = [os.path.join(root_path, _) for _ in files if _.endswith(IMAGE_EXTS)] - - self.resize = resize - self.square_crop = square_crop - self.rand_crop = rand_crop - self.rand_flip = transforms.RandomHorizontalFlip() if rand_flip else None - - def __len__(self): - return len(self.files) - - def __getitem__(self, idx): - try: - image = Image.open(self.files[idx]).convert('RGB') - except: - print('Error loading image:', self.files[idx]) - return self.__getitem__((idx + 1) % self.__len__()) - - if self.resize is not None: - r = self.resize - if isinstance(r, int): - w, h = image.size - if w < h: - r = (r, int(h / w * r)) - else: - r = (int(w / h * r), r) - image = image.resize(r, Image.LANCZOS) - - if self.square_crop: - w, h = image.size - l = min(w, h) - left, upper = (w - l) // 2, (h - l) // 2 - image = image.crop((left, upper, left + l, upper + l)) - - if self.rand_crop is not None: - w, h = image.size - left = random.randint(0, w - self.rand_crop) - upper = random.randint(0, h - self.rand_crop) - image = image.crop((left, upper, left + self.rand_crop, upper + self.rand_crop)) - - if self.rand_flip is not None: - image = self.rand_flip(image) - - return image diff --git a/dito/datasets/webdataset.py b/dito/datasets/webdataset.py deleted file mode 100644 index 7772713a688566eb8ac6fc7c38efcaf1c2993d66..0000000000000000000000000000000000000000 --- a/dito/datasets/webdataset.py +++ /dev/null @@ -1,45 +0,0 @@ -import json - -import webdataset as wds -from webdataset.handlers import warn_and_continue - -from datasets import register - - -def webdataset_preprocessors(square_crop=True): - def identity(x): - if isinstance(x, bytes): - x = x.decode('utf-8') - return x - - def transform(image): - w, h = image.size - l = min(w, h) - left, upper = (w - l) // 2, (h - l) // 2 - return image.crop((left, upper, left + l, upper + l)) - - ret = [ - ('jpg;png', transform if square_crop else lambda x: x, 'image'), - ('txt', identity, 'caption'), - ] - - return ret - - -@register('webdataset') -def make_webdataset(json_file, **kwargs): - with open(json_file, 'r') as file: - tar_list = json.load(file) - preprocessors = webdataset_preprocessors(**kwargs) - handler = warn_and_continue - dataset = wds.WebDataset( - tar_list, resampled=True, handler=handler - ).shuffle(690, handler=handler).decode( - "pilrgb", handler=handler - ).to_tuple( - *[p[0] for p in preprocessors], handler=handler - ).map_tuple( - *[p[1] for p in preprocessors], handler=handler - ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)}) - - return dataset diff --git a/dito/datasets/wrapper_cae.py b/dito/datasets/wrapper_cae.py deleted file mode 100644 index 1a684e1b959287c6647ad49399d7f5458f0818d0..0000000000000000000000000000000000000000 --- a/dito/datasets/wrapper_cae.py +++ /dev/null @@ -1,308 +0,0 @@ -import random -from PIL import Image - -import torch -from torch.utils.data import Dataset, IterableDataset -from torchvision import transforms - -import datasets -from datasets import register -from utils.geometry import make_coord_scale_grid - - -from 
models.ldm.dac.audiotools import AudioSignal -import numpy as np - -from models.ldm.dac.audiotools.data.datasets import AudioDataset, AudioLoader -from models.ldm.dac.audiotools import transforms as tfm - - -class BaseWrapperCAE: - - def __init__( - self, - dataset, - resize_inp, - return_gt=True, - gt_glores_lb=None, - gt_glores_ub=None, - gt_patch_size=None, - p_whole=0.0, - p_max=0.0 - ): - self.dataset = datasets.make(dataset) - self.resize_inp = resize_inp - self.return_gt = return_gt - self.gt_glores_lb = gt_glores_lb - self.gt_glores_ub = gt_glores_ub - self.gt_patch_size = gt_patch_size - self.p_whole = p_whole - self.p_max = p_max - self.transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize(0.5, 0.5), - ]) - - def process(self, image): - assert image.size[0] == image.size[1] - ret = {} - - inp = image.resize((self.resize_inp, self.resize_inp), Image.LANCZOS) - inp = self.transform(inp) - ret.update({'inp': inp}) - if not self.return_gt: - return ret - - if self.gt_glores_lb is None: - glo = self.transform(image) - else: - if random.random() < self.p_whole: - r = self.gt_patch_size - elif random.random() < self.p_max: - r = min(image.size[0], self.gt_glores_ub) - else: - r = random.randint( - self.gt_glores_lb, - max(self.gt_glores_lb, min(image.size[0], self.gt_glores_ub)) - ) - glo = image.resize((r, r), Image.LANCZOS) - glo = self.transform(glo) - - p = self.gt_patch_size - ii = random.randint(0, glo.shape[1] - p) - jj = random.randint(0, glo.shape[2] - p) - gt_patch = glo[:, ii: ii + p, jj: jj + p] - - x0, y0 = ii / glo.shape[-2], jj / glo.shape[-1] - x1, y1 = (ii + p) / glo.shape[-2], (jj + p) / glo.shape[-1] - coord, scale = make_coord_scale_grid((p, p), range=[[x0, x1], [y0, y1]]) - ret['gt'] = torch.cat([ - gt_patch, # 3 p p - coord.permute(2, 0, 1), # 2 p p - scale.permute(2, 0, 1), # 2 p p - ], dim=0) - - return ret - - -@register('wrapper_cae') -class WrapperCAE(BaseWrapperCAE, Dataset): - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - data = self.dataset[idx] - if isinstance(data, dict): - ret = dict() - ret.update(self.process(data.pop('image'))) - ret.update(data) - return ret - else: - return self.process(data) - - -@register('wrapper_cae_iterable') -class WrapperCAE(BaseWrapperCAE, IterableDataset): - - def __iter__(self): - for data in self.dataset: - if isinstance(data, dict): - ret = dict() - ret.update(self.process(data.pop('image'))) - ret.update(data) - yield ret - else: - yield self.process(data) - - - - - - -class BaseWrapperAudioCAE: - """Base wrapper for audio Convolutional Autoencoder (CAE) training. - - Similar to the image wrapper, but for audio data. - """ - - def __init__( - self, - dataset, - sample_rate=24000, - duration=0.38, # Duration in seconds - n_samples=None, # Alternative: specify exact number of samples - return_gt=True, - gt_sample_rate=None, # Ground truth sample rate (if different) - mono=True, - normalize=True, - return_coords=True, # Whether to return coordinate grids - ): - self.dataset = dataset - self.sample_rate = sample_rate - self.duration = duration - self.n_samples = n_samples or int(duration * sample_rate) - self.return_gt = return_gt - self.gt_sample_rate = gt_sample_rate or sample_rate - self.mono = mono - self.normalize = normalize - self.return_coords = return_coords - - def process(self, audio_data): - """Process audio data for DiTo training. 
- - Args: - audio_data: Dictionary with 'signal' key containing AudioSignal - or AudioSignal directly - """ - ret = {} - - # Extract AudioSignal - if isinstance(audio_data, dict): - signal = audio_data['signal'] - else: - signal = audio_data - - # Convert to mono if needed - if self.mono and signal.num_channels > 1: - signal = signal.to_mono() - - # Resample to target sample rate - if signal.sample_rate != self.sample_rate: - signal = signal.resample(self.sample_rate) - - # Extract fixed duration - if signal.duration < self.duration: - # Pad if too short - signal = signal.zero_pad_to(self.n_samples) - else: - # Take random excerpt if too long - max_start = signal.num_samples - self.n_samples - if max_start > 0: - start_idx = random.randint(0, max_start) - signal = signal[..., start_idx:start_idx + self.n_samples] - else: - signal = signal[..., :self.n_samples] - - # Normalize audio - audio_tensor = signal.audio_data # Shape: [channels, samples] - if self.normalize: - # Normalize to [-1, 1] - max_val = audio_tensor.abs().max() - if max_val > 0: - audio_tensor = audio_tensor / max_val - - # Create input tensor - ret['inp'] = audio_tensor - - if not self.return_gt: - return ret - - - ret['gt'] = audio_tensor - - return ret - - -@register('wrapper_audio_cae') -class WrapperAudioCAE(BaseWrapperAudioCAE, Dataset): - """Dataset wrapper for audio CAE training.""" - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - data = self.dataset[idx] - return self.process(data) - - -@register('wrapper_audio_cae_iterable') -class WrapperAudioCAEIterable(BaseWrapperAudioCAE, IterableDataset): - """Iterable dataset wrapper for audio CAE training.""" - - def __iter__(self): - for data in self.dataset: - yield self.process(data) - - -# Example usage with your existing AudioDataset -def create_dito_audio_dataset(config): - """Create DiTo audio dataset from config.""" - - # Create base audio dataset using audiotools - - # Setup audio loaders - train_folders = config.get("train_folders", {}) - - loader = AudioLoader( - sources=list(train_folders.values()), - transform=tfm.Compose( - tfm.VolumeNorm(("uniform", -20, -10)), - tfm.RescaleAudio(), - ), - ext=['.wav', '.flac', '.mp3'], - ) - - # Create base dataset - base_dataset = AudioDataset( - loaders=loader, - sample_rate=config['sample_rate'], - duration=config['duration'], - n_examples=config['n_examples'], - num_channels=1 if config.get('mono', True) else 2, - ) - - # Wrap with DiTo wrapper - dito_dataset = WrapperAudioCAE( - dataset=base_dataset, - sample_rate=config['sample_rate'], - duration=config['duration'], - mono=config.get('mono', True), - normalize=True, - return_coords=True, - ) - - return dito_dataset - - -# For your training config, you would use it like: -""" -datasets: - train: - name: wrapper_audio_cae - args: - dataset: - name: audio_dataset # Your base audio dataset - args: - sources: ["/path/to/audio/files"] - sample_rate: 44100 - duration: 2.0 - n_examples: 10000 - sample_rate: 44100 - duration: 2.0 - mono: true - normalize: true - return_coords: true - loader: - batch_size: 16 - num_workers: 8 - - val: - name: wrapper_audio_cae - args: - dataset: - name: audio_dataset - args: - sources: ["/path/to/val/audio/files"] - sample_rate: 44100 - duration: 2.0 - n_examples: 1000 - sample_rate: 44100 - duration: 2.0 - mono: true - normalize: true - return_coords: true - loader: - batch_size: 16 - num_workers: 8 -""" \ No newline at end of file diff --git a/dito/load/dito.png b/dito/load/dito.png deleted file mode 100644 
index bbc7f7f11845d568c7c7e4d3439299c87073f4cb..0000000000000000000000000000000000000000 --- a/dito/load/dito.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:25999470ed7eaba155f1e8ab639fd43fe35eb53a53e9770829451b3d0434c467 -size 245154 diff --git a/dito/load/vgg_lpips.pth b/dito/load/vgg_lpips.pth deleted file mode 100644 index 47e943cfacabf7040b4af8cf4084ab91177f1b88..0000000000000000000000000000000000000000 Binary files a/dito/load/vgg_lpips.pth and /dev/null differ diff --git a/dito/load/wandb.yaml b/dito/load/wandb.yaml deleted file mode 100644 index 792009be6e6f698d184cd8176c463da38b9f9075..0000000000000000000000000000000000000000 --- a/dito/load/wandb.yaml +++ /dev/null @@ -1,3 +0,0 @@ -entity: -api_key: -project: \ No newline at end of file diff --git a/dito/models/__init__.py b/dito/models/__init__.py deleted file mode 100644 index b5a61e9db76e11bad1e378d95caa94e06341e0ee..0000000000000000000000000000000000000000 --- a/dito/models/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .models import register, make -from . import ldm -from . import diffusion -from . import networks \ No newline at end of file diff --git a/dito/models/diffusion/__init__.py b/dito/models/diffusion/__init__.py deleted file mode 100644 index 6578ad49c27b9797ee3e62c0b5d952089e253ed9..0000000000000000000000000000000000000000 --- a/dito/models/diffusion/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import fm -from . import samplers diff --git a/dito/models/diffusion/fm.py b/dito/models/diffusion/fm.py deleted file mode 100644 index 0c008882977b3e240ee4ecfaf07802e565997fd0..0000000000000000000000000000000000000000 --- a/dito/models/diffusion/fm.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch - -from models import register - - -@register('fm') -class FM: - - def __init__(self, sigma_min=1e-5, timescale=1.0): - self.sigma_min = sigma_min - self.prediction_type = None - self.timescale = timescale - - def alpha(self, t): - return 1.0 - t - - def sigma(self, t): - return self.sigma_min + t * (1.0 - self.sigma_min) - - def A(self, t): - return 1.0 - - def B(self, t): - return -(1.0 - self.sigma_min) - - def get_betas(self, n_timesteps): - return torch.zeros(n_timesteps) # Not VP and not supported - - def add_noise(self, x, t, noise=None): - noise = torch.randn_like(x) if noise is None else noise - s = [x.shape[0]] + [1] * (x.dim() - 1) - x_t = self.alpha(t).view(*s) * x + self.sigma(t).view(*s) * noise - return x_t, noise - - def loss(self, net, x, t=None, net_kwargs=None, return_loss_unreduced=False, return_all=False): - if net_kwargs is None: - net_kwargs = {} - - if t is None: - t = torch.rand(x.shape[0], device=x.device) - print('x shape: ', x.shape) - x_t, noise = self.add_noise(x, t) - print('x_t shape: ', x_t.shape) - pred = net(x_t, t=t * self.timescale, **net_kwargs) - print('pred shape: ', pred.shape) - - target = self.A(t) * x + self.B(t) * noise # -dxt/dt - print('target shape: ', target.shape) - print('return_loss_unreduced: ', return_loss_unreduced, 'return_all: ', return_all) - if return_loss_unreduced: - loss = ((pred.float() - target.float()) ** 2).mean(dim=[1, 2, 3]) - if return_all: - return loss, t, x_t, pred - else: - return loss, t - else: - # here we go - loss = ((pred.float() - target.float()) ** 2).mean() - if return_all: - return loss, x_t, pred - else: - return loss - - def get_prediction( - self, - net, - x_t, - t, - net_kwargs=None, - uncond_net_kwargs=None, - guidance=1.0, - ): - if net_kwargs is None: - net_kwargs = {} - pred = net(x_t, t=t * 
self.timescale, **net_kwargs) - if guidance != 1.0: - assert uncond_net_kwargs is not None - uncond_pred = net(x_t, t=t * self.timescale, **uncond_net_kwargs) - pred = uncond_pred + guidance * (pred - uncond_pred) - return pred - - def convert_sample_prediction(self, x_t, t, pred): - M = torch.tensor([ - [self.alpha(t), self.sigma(t)], - [self.A(t), self.B(t)], - ], dtype=torch.float64) - M_inv = torch.linalg.inv(M) - sample_pred = M_inv[0, 0].item() * x_t + M_inv[0, 1].item() * pred - return sample_pred diff --git a/dito/models/diffusion/samplers.py b/dito/models/diffusion/samplers.py deleted file mode 100644 index 75fc52249e08e8f5bee516349f9545f0de7c5513..0000000000000000000000000000000000000000 --- a/dito/models/diffusion/samplers.py +++ /dev/null @@ -1,39 +0,0 @@ -import numpy as np -import torch - -from models import register - - -@register('fm_euler_sampler') -class FMEulerSampler: - - def __init__(self, diffusion): - self.diffusion = diffusion - - def sample( - self, - net, - shape, - n_steps, - net_kwargs=None, - uncond_net_kwargs=None, - guidance=1.0, - noise=None, - ): - device = next(net.parameters()).device - x_t = torch.randn(shape, device=device) if noise is None else noise - t_steps = torch.linspace(1, 0, n_steps + 1, device=device) - - with torch.no_grad(): - for i in range(n_steps): - t = t_steps[i].repeat(x_t.shape[0]) - neg_v = self.diffusion.get_prediction( - net, - x_t, - t, - net_kwargs=net_kwargs, - uncond_net_kwargs=uncond_net_kwargs, - guidance=guidance, - ) - x_t = x_t + neg_v * (t_steps[i] - t_steps[i + 1]) - return x_t diff --git a/dito/models/ldm/__init__.py b/dito/models/ldm/__init__.py deleted file mode 100644 index e8103bdfc078dae107945147b04cb44ed5758f5e..0000000000000000000000000000000000000000 --- a/dito/models/ldm/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from . import glpto, dito -from . import renderers -from . import vqgan -from . import dac \ No newline at end of file diff --git a/dito/models/ldm/dac/__init__.py b/dito/models/ldm/dac/__init__.py deleted file mode 100644 index 90f60fdd89ad8575faafe45188bd1d968852fc67..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .utils import * \ No newline at end of file diff --git a/dito/models/ldm/dac/audiotools/__init__.py b/dito/models/ldm/dac/audiotools/__init__.py deleted file mode 100644 index b251ff37628c56a19bb38976fca99d9536e64bbf..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -__version__ = "0.7.4" -from .core import AudioSignal -from .core import STFTParams -from .core import Meter -from .core import util -from . import metrics -from . import data -from . import ml -from .data import datasets -from .data import transforms diff --git a/dito/models/ldm/dac/audiotools/core/__init__.py b/dito/models/ldm/dac/audiotools/core/__init__.py deleted file mode 100644 index 8660c4e67f43d0ded584a38939425e2c28d95cd3..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from . 
import util -from .audio_signal import AudioSignal -from .audio_signal import STFTParams -from .loudness import Meter diff --git a/dito/models/ldm/dac/audiotools/core/audio_signal.py b/dito/models/ldm/dac/audiotools/core/audio_signal.py deleted file mode 100644 index fb6d751cb968a003656e3e7874c487b83d94c82e..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/audio_signal.py +++ /dev/null @@ -1,1682 +0,0 @@ -import copy -import functools -import hashlib -import math -import pathlib -import tempfile -import typing -import warnings -from collections import namedtuple -from pathlib import Path - -import julius -import numpy as np -import soundfile -import torch - -from . import util -from .display import DisplayMixin -from .dsp import DSPMixin -from .effects import EffectMixin -from .effects import ImpulseResponseMixin -from .ffmpeg import FFMPEGMixin -from .loudness import LoudnessMixin -from .playback import PlayMixin -from .whisper import WhisperMixin - - -STFTParams = namedtuple( - "STFTParams", - ["window_length", "hop_length", "window_type", "match_stride", "padding_type"], -) -""" -STFTParams object is a container that holds STFT parameters - window_length, -hop_length, and window_type. Not all parameters need to be specified. Ones that -are not specified will be inferred by the AudioSignal parameters. - -Parameters ----------- -window_length : int, optional - Window length of STFT, by default ``0.032 * self.sample_rate``. -hop_length : int, optional - Hop length of STFT, by default ``window_length // 4``. -window_type : str, optional - Type of window to use, by default ``sqrt\_hann``. -match_stride : bool, optional - Whether to match the stride of convolutional layers, by default False -padding_type : str, optional - Type of padding to use, by default 'reflect' -""" -STFTParams.__new__.__defaults__ = (None, None, None, None, None) - - -class AudioSignal( - EffectMixin, - LoudnessMixin, - PlayMixin, - ImpulseResponseMixin, - DSPMixin, - DisplayMixin, - FFMPEGMixin, - WhisperMixin, -): - """This is the core object of this library. Audio is always - loaded into an AudioSignal, which then enables all the features - of this library, including audio augmentations, I/O, playback, - and more. - - The structure of this object is that the base functionality - is defined in ``core/audio_signal.py``, while extensions to - that functionality are defined in the other ``core/*.py`` - files. For example, all the display-based functionality - (e.g. plot spectrograms, waveforms, write to tensorboard) - are in ``core/display.py``. - - Parameters - ---------- - audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray] - Object to create AudioSignal from. Can be a tensor, numpy array, - or a path to a file. The file is always reshaped to - sample_rate : int, optional - Sample rate of the audio. If different from underlying file, resampling is - performed. If passing in an array or tensor, this must be defined, - by default None - stft_params : STFTParams, optional - Parameters of STFT to use. , by default None - offset : float, optional - Offset in seconds to read from file, by default 0 - duration : float, optional - Duration in seconds to read from file, by default None - device : str, optional - Device to load audio onto, by default None - - Examples - -------- - Loading an AudioSignal from an array, at a sample rate of - 44100. 
- - >>> signal = AudioSignal(torch.randn(5*44100), 44100) - - Note, the signal is reshaped to have a batch size, and one - audio channel: - - >>> print(signal.shape) - (1, 1, 44100) - - You can treat AudioSignals like tensors, and many of the same - functions you might use on tensors are defined for AudioSignals - as well: - - >>> signal.to("cuda") - >>> signal.cuda() - >>> signal.clone() - >>> signal.detach() - - Indexing AudioSignals returns an AudioSignal: - - >>> signal[..., 3*44100:4*44100] - - The above signal is 1 second long, and is also an AudioSignal. - """ - - def __init__( - self, - audio_path_or_array: typing.Union[torch.Tensor, str, Path, np.ndarray], - sample_rate: int = None, - stft_params: STFTParams = None, - offset: float = 0, - duration: float = None, - device: str = None, - ): - audio_path = None - audio_array = None - - if isinstance(audio_path_or_array, str): - audio_path = audio_path_or_array - elif isinstance(audio_path_or_array, pathlib.Path): - audio_path = audio_path_or_array - elif isinstance(audio_path_or_array, np.ndarray): - audio_array = audio_path_or_array - elif torch.is_tensor(audio_path_or_array): - audio_array = audio_path_or_array - else: - raise ValueError( - "audio_path_or_array must be either a Path, " - "string, numpy array, or torch Tensor!" - ) - - self.path_to_file = None - - self.audio_data = None - self.sources = None # List of AudioSignal objects. - self.stft_data = None - if audio_path is not None: - self.load_from_file( - audio_path, offset=offset, duration=duration, device=device - ) - elif audio_array is not None: - assert sample_rate is not None, "Must set sample rate!" - self.load_from_array(audio_array, sample_rate, device=device) - - self.window = None - self.stft_params = stft_params - - self.metadata = { - "offset": offset, - "duration": duration, - } - - @property - def path_to_input_file( - self, - ): - """ - Path to input file, if it exists. - Alias to ``path_to_file`` for backwards compatibility - """ - return self.path_to_file - - @classmethod - def excerpt( - cls, - audio_path: typing.Union[str, Path], - offset: float = None, - duration: float = None, - state: typing.Union[np.random.RandomState, int] = None, - **kwargs, - ): - """Randomly draw an excerpt of ``duration`` seconds from an - audio file specified at ``audio_path``, between ``offset`` seconds - and end of file. ``state`` can be used to seed the random draw. - - Parameters - ---------- - audio_path : typing.Union[str, Path] - Path to audio file to grab excerpt from. - offset : float, optional - Lower bound for the start time, in seconds drawn from - the file, by default None. - duration : float, optional - Duration of excerpt, in seconds, by default None - state : typing.Union[np.random.RandomState, int], optional - RandomState or seed of random state, by default None - - Returns - ------- - AudioSignal - AudioSignal containing excerpt. 
- - Examples - -------- - >>> signal = AudioSignal.excerpt("path/to/audio", duration=5) - """ - info = util.info(audio_path) - total_duration = info.duration - - state = util.random_state(state) - lower_bound = 0 if offset is None else offset - upper_bound = max(total_duration - duration, 0) - offset = state.uniform(lower_bound, upper_bound) - - signal = cls(audio_path, offset=offset, duration=duration, **kwargs) - signal.metadata["offset"] = offset - signal.metadata["duration"] = duration - - return signal - - @classmethod - def salient_excerpt( - cls, - audio_path: typing.Union[str, Path], - loudness_cutoff: float = None, - num_tries: int = 8, - state: typing.Union[np.random.RandomState, int] = None, - **kwargs, - ): - """Similar to AudioSignal.excerpt, except it extracts excerpts only - if they are above a specified loudness threshold, which is computed via - a fast LUFS routine. - - Parameters - ---------- - audio_path : typing.Union[str, Path] - Path to audio file to grab excerpt from. - loudness_cutoff : float, optional - Loudness threshold in dB. Typical values are ``-40, -60``, - etc, by default None - num_tries : int, optional - Number of tries to grab an excerpt above the threshold - before giving up, by default 8. - state : typing.Union[np.random.RandomState, int], optional - RandomState or seed of random state, by default None - kwargs : dict - Keyword arguments to AudioSignal.excerpt - - Returns - ------- - AudioSignal - AudioSignal containing excerpt. - - - .. warning:: - if ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can - result in an infinite loop if ``audio_path`` does not have - any loud enough excerpts. - - Examples - -------- - >>> signal = AudioSignal.salient_excerpt( - "path/to/audio", - loudness_cutoff=-40, - duration=5 - ) - """ - state = util.random_state(state) - if loudness_cutoff is None: - excerpt = cls.excerpt(audio_path, state=state, **kwargs) - else: - loudness = -np.inf - num_try = 0 - while loudness <= loudness_cutoff: - excerpt = cls.excerpt(audio_path, state=state, **kwargs) - loudness = excerpt.loudness() - num_try += 1 - if num_tries is not None and num_try >= num_tries: - break - return excerpt - - @classmethod - def zeros( - cls, - duration: float, - sample_rate: int, - num_channels: int = 1, - batch_size: int = 1, - **kwargs, - ): - """Helper function create an AudioSignal of all zeros. - - Parameters - ---------- - duration : float - Duration of AudioSignal - sample_rate : int - Sample rate of AudioSignal - num_channels : int, optional - Number of channels, by default 1 - batch_size : int, optional - Batch size, by default 1 - - Returns - ------- - AudioSignal - AudioSignal containing all zeros. - - Examples - -------- - Generate 5 seconds of all zeros at a sample rate of 44100. - - >>> signal = AudioSignal.zeros(5.0, 44100) - """ - n_samples = int(duration * sample_rate) - return cls( - torch.zeros(batch_size, num_channels, n_samples), sample_rate, **kwargs - ) - - @classmethod - def wave( - cls, - frequency: float, - duration: float, - sample_rate: int, - num_channels: int = 1, - shape: str = "sine", - **kwargs, - ): - """ - Generate a waveform of a given frequency and shape. 
- - Parameters - ---------- - frequency : float - Frequency of the waveform - duration : float - Duration of the waveform - sample_rate : int - Sample rate of the waveform - num_channels : int, optional - Number of channels, by default 1 - shape : str, optional - Shape of the waveform, by default "saw" - One of "sawtooth", "square", "sine", "triangle" - kwargs : dict - Keyword arguments to AudioSignal - """ - n_samples = int(duration * sample_rate) - t = torch.linspace(0, duration, n_samples) - if shape == "sawtooth": - from scipy.signal import sawtooth - - wave_data = sawtooth(2 * np.pi * frequency * t, 0.5) - elif shape == "square": - from scipy.signal import square - - wave_data = square(2 * np.pi * frequency * t) - elif shape == "sine": - wave_data = np.sin(2 * np.pi * frequency * t) - elif shape == "triangle": - from scipy.signal import sawtooth - - # frequency is doubled by the abs call, so omit the 2 in 2pi - wave_data = sawtooth(np.pi * frequency * t, 0.5) - wave_data = -np.abs(wave_data) * 2 + 1 - else: - raise ValueError(f"Invalid shape {shape}") - - wave_data = torch.tensor(wave_data, dtype=torch.float32) - wave_data = wave_data.unsqueeze(0).unsqueeze(0).repeat(1, num_channels, 1) - return cls(wave_data, sample_rate, **kwargs) - - @classmethod - def batch( - cls, - audio_signals: list, - pad_signals: bool = False, - truncate_signals: bool = False, - resample: bool = False, - dim: int = 0, - ): - """Creates a batched AudioSignal from a list of AudioSignals. - - Parameters - ---------- - audio_signals : list[AudioSignal] - List of AudioSignal objects - pad_signals : bool, optional - Whether to pad signals to length of the maximum length - AudioSignal in the list, by default False - truncate_signals : bool, optional - Whether to truncate signals to length of shortest length - AudioSignal in the list, by default False - resample : bool, optional - Whether to resample AudioSignal to the sample rate of - the first AudioSignal in the list, by default False - dim : int, optional - Dimension along which to batch the signals. - - Returns - ------- - AudioSignal - Batched AudioSignal. - - Raises - ------ - RuntimeError - If not all AudioSignals are the same sample rate, and - ``resample=False``, an error is raised. - RuntimeError - If not all AudioSignals are the same the length, and - both ``pad_signals=False`` and ``truncate_signals=False``, - an error is raised. - - Examples - -------- - Batching a bunch of random signals: - - >>> signal_list = [AudioSignal(torch.randn(44100), 44100) for _ in range(10)] - >>> signal = AudioSignal.batch(signal_list) - >>> print(signal.shape) - (10, 1, 44100) - - """ - signal_lengths = [x.signal_length for x in audio_signals] - sample_rates = [x.sample_rate for x in audio_signals] - - if len(set(sample_rates)) != 1: - if resample: - for x in audio_signals: - x.resample(sample_rates[0]) - else: - raise RuntimeError( - f"Not all signals had the same sample rate! Got {sample_rates}. " - f"All signals must have the same sample rate, or resample must be True. " - ) - - if len(set(signal_lengths)) != 1: - if pad_signals: - max_length = max(signal_lengths) - for x in audio_signals: - pad_len = max_length - x.signal_length - x.zero_pad(0, pad_len) - elif truncate_signals: - min_length = min(signal_lengths) - for x in audio_signals: - x.truncate_samples(min_length) - else: - raise RuntimeError( - f"Not all signals had the same length! Got {signal_lengths}. " - f"All signals must be the same length, or pad_signals/truncate_signals " - f"must be True. 
" - ) - # Concatenate along the specified dimension (default 0) - audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim) - audio_paths = [x.path_to_file for x in audio_signals] - - batched_signal = cls( - audio_data, - sample_rate=audio_signals[0].sample_rate, - ) - batched_signal.path_to_file = audio_paths - return batched_signal - - # I/O - def load_from_file( - self, - audio_path: typing.Union[str, Path], - offset: float, - duration: float, - device: str = "cpu", - ): - """Loads data from file. Used internally when AudioSignal - is instantiated with a path to a file. - - Parameters - ---------- - audio_path : typing.Union[str, Path] - Path to file - offset : float - Offset in seconds - duration : float - Duration in seconds - device : str, optional - Device to put AudioSignal on, by default "cpu" - - Returns - ------- - AudioSignal - AudioSignal loaded from file - """ - import librosa - - data, sample_rate = librosa.load( - audio_path, - offset=offset, - duration=duration, - sr=None, - mono=False, - ) - data = util.ensure_tensor(data) - if data.shape[-1] == 0: - raise RuntimeError( - f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!" - ) - - if data.ndim < 2: - data = data.unsqueeze(0) - if data.ndim < 3: - data = data.unsqueeze(0) - self.audio_data = data - - self.original_signal_length = self.signal_length - - self.sample_rate = sample_rate - self.path_to_file = audio_path - return self.to(device) - - def load_from_array( - self, - audio_array: typing.Union[torch.Tensor, np.ndarray], - sample_rate: int, - device: str = "cpu", - ): - """Loads data from array, reshaping it to be exactly 3 - dimensions. Used internally when AudioSignal is called - with a tensor or an array. - - Parameters - ---------- - audio_array : typing.Union[torch.Tensor, np.ndarray] - Array/tensor of audio of samples. - sample_rate : int - Sample rate of audio - device : str, optional - Device to move audio onto, by default "cpu" - - Returns - ------- - AudioSignal - AudioSignal loaded from array - """ - audio_data = util.ensure_tensor(audio_array) - - if audio_data.dtype == torch.double: - audio_data = audio_data.float() - - if audio_data.ndim < 2: - audio_data = audio_data.unsqueeze(0) - if audio_data.ndim < 3: - audio_data = audio_data.unsqueeze(0) - self.audio_data = audio_data - - self.original_signal_length = self.signal_length - - self.sample_rate = sample_rate - return self.to(device) - - def write(self, audio_path: typing.Union[str, Path]): - """Writes audio to a file. Only writes the audio - that is in the very first item of the batch. To write other items - in the batch, index the signal along the batch dimension - before writing. After writing, the signal's ``path_to_file`` - attribute is updated to the new path. - - Parameters - ---------- - audio_path : typing.Union[str, Path] - Path to write audio to. - - Returns - ------- - AudioSignal - Returns original AudioSignal, so you can use this in a fluent - interface. 
- - Examples - -------- - Creating and writing a signal to disk: - - >>> signal = AudioSignal(torch.randn(10, 1, 44100), 44100) - >>> signal.write("/tmp/out.wav") - - Writing a different element of the batch: - - >>> signal[5].write("/tmp/out.wav") - - Using this in a fluent interface: - - >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav") - - """ - if self.audio_data[0].abs().max() > 1: - warnings.warn("Audio amplitude > 1 clipped when saving") - soundfile.write(str(audio_path), self.audio_data[0].numpy().T, self.sample_rate) - - self.path_to_file = audio_path - return self - - def deepcopy(self): - """Copies the signal and all of its attributes. - - Returns - ------- - AudioSignal - Deep copy of the audio signal. - """ - return copy.deepcopy(self) - - def copy(self): - """Shallow copy of signal. - - Returns - ------- - AudioSignal - Shallow copy of the audio signal. - """ - return copy.copy(self) - - def clone(self): - """Clones all tensors contained in the AudioSignal, - and returns a copy of the signal with everything - cloned. Useful when using AudioSignal within autograd - computation graphs. - - Relevant attributes are the stft data, the audio data, - and the loudness of the file. - - Returns - ------- - AudioSignal - Clone of AudioSignal. - """ - clone = type(self)( - self.audio_data.clone(), - self.sample_rate, - stft_params=self.stft_params, - ) - if self.stft_data is not None: - clone.stft_data = self.stft_data.clone() - if self._loudness is not None: - clone._loudness = self._loudness.clone() - clone.path_to_file = copy.deepcopy(self.path_to_file) - clone.metadata = copy.deepcopy(self.metadata) - return clone - - def detach(self): - """Detaches tensors contained in AudioSignal. - - Relevant attributes are the stft data, the audio data, - and the loudness of the file. - - Returns - ------- - AudioSignal - Same signal, but with all tensors detached. - """ - if self._loudness is not None: - self._loudness = self._loudness.detach() - if self.stft_data is not None: - self.stft_data = self.stft_data.detach() - - self.audio_data = self.audio_data.detach() - return self - - def hash(self): - """Writes the audio data to a temporary file, and then - hashes it using hashlib. Useful for creating a file - name based on the audio content. - - Returns - ------- - str - Hash of audio data. - - Examples - -------- - Creating a signal, and writing it to a unique file name: - - >>> signal = AudioSignal(torch.randn(44100), 44100) - >>> hash = signal.hash() - >>> signal.write(f"{hash}.wav") - - """ - with tempfile.NamedTemporaryFile(suffix=".wav") as f: - self.write(f.name) - h = hashlib.sha256() - b = bytearray(128 * 1024) - mv = memoryview(b) - with open(f.name, "rb", buffering=0) as f: - for n in iter(lambda: f.readinto(mv), 0): - h.update(mv[:n]) - file_hash = h.hexdigest() - return file_hash - - # Signal operations - def to_mono(self): - """Converts audio data to mono audio, by taking the mean - along the channels dimension. - - Returns - ------- - AudioSignal - AudioSignal with mean of channels. - """ - self.audio_data = self.audio_data.mean(1, keepdim=True) - return self - - def resample(self, sample_rate: int): - """Resamples the audio, using sinc interpolation. This works on both - cpu and gpu, and is much faster on gpu. - - Parameters - ---------- - sample_rate : int - Sample rate to resample to. 
- - Returns - ------- - AudioSignal - Resampled AudioSignal - """ - if sample_rate == self.sample_rate: - return self - self.audio_data = julius.resample_frac( - self.audio_data, self.sample_rate, sample_rate - ) - self.sample_rate = sample_rate - return self - - # Tensor operations - def to(self, device: str): - """Moves all tensors contained in signal to the specified device. - - Parameters - ---------- - device : str - Device to move AudioSignal onto. Typical values are - "cuda", "cpu", or "cuda:n" to specify the nth gpu. - - Returns - ------- - AudioSignal - AudioSignal with all tensors moved to specified device. - """ - if self._loudness is not None: - self._loudness = self._loudness.to(device) - if self.stft_data is not None: - self.stft_data = self.stft_data.to(device) - if self.audio_data is not None: - self.audio_data = self.audio_data.to(device) - return self - - def float(self): - """Calls ``.float()`` on ``self.audio_data``. - - Returns - ------- - AudioSignal - """ - self.audio_data = self.audio_data.float() - return self - - def cpu(self): - """Moves AudioSignal to cpu. - - Returns - ------- - AudioSignal - """ - return self.to("cpu") - - def cuda(self): # pragma: no cover - """Moves AudioSignal to cuda. - - Returns - ------- - AudioSignal - """ - return self.to("cuda") - - def numpy(self): - """Detaches ``self.audio_data``, moves to cpu, and converts to numpy. - - Returns - ------- - np.ndarray - Audio data as a numpy array. - """ - return self.audio_data.detach().cpu().numpy() - - def zero_pad(self, before: int, after: int): - """Zero pads the audio_data tensor before and after. - - Parameters - ---------- - before : int - How many zeros to prepend to audio. - after : int - How many zeros to append to audio. - - Returns - ------- - AudioSignal - AudioSignal with padding applied. - """ - self.audio_data = torch.nn.functional.pad(self.audio_data, (before, after)) - return self - - def zero_pad_to(self, length: int, mode: str = "after"): - """Pad with zeros to a specified length, either before or after - the audio data. - - Parameters - ---------- - length : int - Length to pad to - mode : str, optional - Whether to prepend or append zeros to signal, by default "after" - - Returns - ------- - AudioSignal - AudioSignal with padding applied. - """ - if mode == "before": - self.zero_pad(max(length - self.signal_length, 0), 0) - elif mode == "after": - self.zero_pad(0, max(length - self.signal_length, 0)) - return self - - def trim(self, before: int, after: int): - """Trims the audio_data tensor before and after. - - Parameters - ---------- - before : int - How many samples to trim from beginning. - after : int - How many samples to trim from end. - - Returns - ------- - AudioSignal - AudioSignal with trimming applied. - """ - if after == 0: - self.audio_data = self.audio_data[..., before:] - else: - self.audio_data = self.audio_data[..., before:-after] - return self - - def truncate_samples(self, length_in_samples: int): - """Truncate signal to specified length. - - Parameters - ---------- - length_in_samples : int - Truncate to this many samples. - - Returns - ------- - AudioSignal - AudioSignal with truncation applied. - """ - self.audio_data = self.audio_data[..., :length_in_samples] - return self - - @property - def device(self): - """Get device that AudioSignal is on. - - Returns - ------- - torch.device - Device that AudioSignal is on. 
- """ - if self.audio_data is not None: - device = self.audio_data.device - elif self.stft_data is not None: - device = self.stft_data.device - return device - - # Properties - @property - def audio_data(self): - """Returns the audio data tensor in the object. - - Audio data is always of the shape - (batch_size, num_channels, num_samples). If value has less - than 3 dims (e.g. is (num_channels, num_samples)), then it will - be reshaped to (1, num_channels, num_samples) - a batch size of 1. - - Parameters - ---------- - data : typing.Union[torch.Tensor, np.ndarray] - Audio data to set. - - Returns - ------- - torch.Tensor - Audio samples. - """ - return self._audio_data - - @audio_data.setter - def audio_data(self, data: typing.Union[torch.Tensor, np.ndarray]): - if data is not None: - assert torch.is_tensor(data), "audio_data should be torch.Tensor" - assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)" - self._audio_data = data - # Old loudness value not guaranteed to be right, reset it. - self._loudness = None - return - - # alias for audio_data - samples = audio_data - - @property - def stft_data(self): - """Returns the STFT data inside the signal. Shape is - (batch, channels, frequencies, time). - - Returns - ------- - torch.Tensor - Complex spectrogram data. - """ - return self._stft_data - - @stft_data.setter - def stft_data(self, data: typing.Union[torch.Tensor, np.ndarray]): - if data is not None: - assert torch.is_tensor(data) and torch.is_complex(data) - if self.stft_data is not None and self.stft_data.shape != data.shape: - warnings.warn("stft_data changed shape") - self._stft_data = data - return - - @property - def batch_size(self): - """Batch size of audio signal. - - Returns - ------- - int - Batch size of signal. - """ - return self.audio_data.shape[0] - - @property - def signal_length(self): - """Length of audio signal. - - Returns - ------- - int - Length of signal in samples. - """ - return self.audio_data.shape[-1] - - # alias for signal_length - length = signal_length - - @property - def shape(self): - """Shape of audio data. - - Returns - ------- - tuple - Shape of audio data. - """ - return self.audio_data.shape - - @property - def signal_duration(self): - """Length of audio signal in seconds. - - Returns - ------- - float - Length of signal in seconds. - """ - return self.signal_length / self.sample_rate - - # alias for signal_duration - duration = signal_duration - - @property - def num_channels(self): - """Number of audio channels. - - Returns - ------- - int - Number of audio channels. - """ - return self.audio_data.shape[1] - - # STFT - @staticmethod - @functools.lru_cache(None) - def get_window(window_type: str, window_length: int, device: str): - """Wrapper around scipy.signal.get_window so one can also get the - popular sqrt-hann window. This function caches for efficiency - using functools.lru\_cache. - - Parameters - ---------- - window_type : str - Type of window to get - window_length : int - Length of the window - device : str - Device to put window onto. - - Returns - ------- - torch.Tensor - Window returned by scipy.signal.get_window, as a tensor. 
- """ - from scipy import signal - - if window_type == "average": - window = np.ones(window_length) / window_length - elif window_type == "sqrt_hann": - window = np.sqrt(signal.get_window("hann", window_length)) - else: - window = signal.get_window(window_type, window_length) - window = torch.from_numpy(window).to(device).float() - return window - - @property - def stft_params(self): - """Returns STFTParams object, which can be re-used to other - AudioSignals. - - This property can be set as well. If values are not defined in STFTParams, - they are inferred automatically from the signal properties. The default is to use - 32ms windows, with 8ms hop length, and the square root of the hann window. - - Returns - ------- - STFTParams - STFT parameters for the AudioSignal. - - Examples - -------- - >>> stft_params = STFTParams(128, 32) - >>> signal1 = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params) - >>> signal2 = AudioSignal(torch.randn(44100), 44100, stft_params=signal1.stft_params) - >>> signal1.stft_params = STFTParams() # Defaults - """ - return self._stft_params - - @stft_params.setter - def stft_params(self, value: STFTParams): - default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate)))) - default_hop_len = default_win_len // 4 - default_win_type = "hann" - default_match_stride = False - default_padding_type = "reflect" - - default_stft_params = STFTParams( - window_length=default_win_len, - hop_length=default_hop_len, - window_type=default_win_type, - match_stride=default_match_stride, - padding_type=default_padding_type, - )._asdict() - - value = value._asdict() if value else default_stft_params - - for key in default_stft_params: - if value[key] is None: - value[key] = default_stft_params[key] - - self._stft_params = STFTParams(**value) - self.stft_data = None - - def compute_stft_padding( - self, window_length: int, hop_length: int, match_stride: bool - ): - """Compute how the STFT should be padded, based on match\_stride. - - Parameters - ---------- - window_length : int - Window length of STFT. - hop_length : int - Hop length of STFT. - match_stride : bool - Whether or not to match stride, making the STFT have the same alignment as - convolutional layers. - - Returns - ------- - tuple - Amount to pad on either side of audio. - """ - length = self.signal_length - - if match_stride: - assert ( - hop_length == window_length // 4 - ), "For match_stride, hop must equal n_fft // 4" - right_pad = math.ceil(length / hop_length) * hop_length - length - pad = (window_length - hop_length) // 2 - else: - right_pad = 0 - pad = 0 - - return right_pad, pad - - def stft( - self, - window_length: int = None, - hop_length: int = None, - window_type: str = None, - match_stride: bool = None, - padding_type: str = None, - ): - """Computes the short-time Fourier transform of the audio data, - with specified STFT parameters. - - Parameters - ---------- - window_length : int, optional - Window length of STFT, by default ``0.032 * self.sample_rate``. - hop_length : int, optional - Hop length of STFT, by default ``window_length // 4``. - window_type : str, optional - Type of window to use, by default ``sqrt\_hann``. - match_stride : bool, optional - Whether to match the stride of convolutional layers, by default False - padding_type : str, optional - Type of padding to use, by default 'reflect' - - Returns - ------- - torch.Tensor - STFT of audio data. 
- - Examples - -------- - Compute the STFT of an AudioSignal: - - >>> signal = AudioSignal(torch.randn(44100), 44100) - >>> signal.stft() - - Vary the window and hop length: - - >>> stft_params = [STFTParams(128, 32), STFTParams(512, 128)] - >>> for stft_param in stft_params: - ... signal.stft_params = stft_param - ... signal.stft() - - """ - window_length = ( - self.stft_params.window_length - if window_length is None - else int(window_length) - ) - hop_length = ( - self.stft_params.hop_length if hop_length is None else int(hop_length) - ) - window_type = ( - self.stft_params.window_type if window_type is None else window_type - ) - match_stride = ( - self.stft_params.match_stride if match_stride is None else match_stride - ) - padding_type = ( - self.stft_params.padding_type if padding_type is None else padding_type - ) - - window = self.get_window(window_type, window_length, self.audio_data.device) - window = window.to(self.audio_data.device) - - audio_data = self.audio_data - right_pad, pad = self.compute_stft_padding( - window_length, hop_length, match_stride - ) - audio_data = torch.nn.functional.pad( - audio_data, (pad, pad + right_pad), padding_type - ) - stft_data = torch.stft( - audio_data.reshape(-1, audio_data.shape[-1]), - n_fft=window_length, - hop_length=hop_length, - window=window, - return_complex=True, - center=True, - ) - _, nf, nt = stft_data.shape - stft_data = stft_data.reshape(self.batch_size, self.num_channels, nf, nt) - - if match_stride: - # Drop first two and last two frames, which are added - # because of padding. Now num_frames * hop_length = num_samples. - stft_data = stft_data[..., 2:-2] - self.stft_data = stft_data - - return stft_data - - def istft( - self, - window_length: int = None, - hop_length: int = None, - window_type: str = None, - match_stride: bool = None, - length: int = None, - ): - """Computes inverse STFT and sets it to audio\_data. - - Parameters - ---------- - window_length : int, optional - Window length of STFT, by default ``0.032 * self.sample_rate``. - hop_length : int, optional - Hop length of STFT, by default ``window_length // 4``. - window_type : str, optional - Type of window to use, by default the signal's ``stft_params.window_type`` (``hann``). - match_stride : bool, optional - Whether to match the stride of convolutional layers, by default False - length : int, optional - Original length of signal, by default None - - Returns - ------- - AudioSignal - AudioSignal with istft applied. - - Raises - ------ - RuntimeError - Raises an error if stft was not called prior to istft on the signal, - or if stft_data is not set.
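- - Examples - -------- - A minimal round-trip sketch (the one-second random signal is illustrative): - - >>> signal = AudioSignal(torch.randn(44100), 44100) - >>> signal.stft() - >>> signal.istft()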
- """ - if self.stft_data is None: - raise RuntimeError("Cannot do inverse STFT without self.stft_data!") - - window_length = ( - self.stft_params.window_length - if window_length is None - else int(window_length) - ) - hop_length = ( - self.stft_params.hop_length if hop_length is None else int(hop_length) - ) - window_type = ( - self.stft_params.window_type if window_type is None else window_type - ) - match_stride = ( - self.stft_params.match_stride if match_stride is None else match_stride - ) - - window = self.get_window(window_type, window_length, self.stft_data.device) - - nb, nch, nf, nt = self.stft_data.shape - stft_data = self.stft_data.reshape(nb * nch, nf, nt) - right_pad, pad = self.compute_stft_padding( - window_length, hop_length, match_stride - ) - - if length is None: - length = self.original_signal_length - length = length + 2 * pad + right_pad - - if match_stride: - # Zero-pad the STFT on either side, putting back the frames that were - # dropped in stft(). - stft_data = torch.nn.functional.pad(stft_data, (2, 2)) - - audio_data = torch.istft( - stft_data, - n_fft=window_length, - hop_length=hop_length, - window=window, - length=length, - center=True, - ) - audio_data = audio_data.reshape(nb, nch, -1) - if match_stride: - audio_data = audio_data[..., pad : -(pad + right_pad)] - self.audio_data = audio_data - - return self - - @staticmethod - @functools.lru_cache(None) - def get_mel_filters( - sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None - ): - """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins. - - Parameters - ---------- - sr : int - Sample rate of audio - n_fft : int - Number of FFT bins - n_mels : int - Number of mels - fmin : float, optional - Lowest frequency, in Hz, by default 0.0 - fmax : float, optional - Highest frequency, by default None - - Returns - ------- - np.ndarray [shape=(n_mels, 1 + n_fft/2)] - Mel transform matrix - """ - from librosa.filters import mel as librosa_mel_fn - - return librosa_mel_fn( - sr=sr, - n_fft=n_fft, - n_mels=n_mels, - fmin=fmin, - fmax=fmax, - ) - - def mel_spectrogram( - self, n_mels: int = 80, mel_fmin: float = 0.0, mel_fmax: float = None, **kwargs - ): - """Computes a Mel spectrogram. - - Parameters - ---------- - n_mels : int, optional - Number of mels, by default 80 - mel_fmin : float, optional - Lowest frequency, in Hz, by default 0.0 - mel_fmax : float, optional - Highest frequency, by default None - kwargs : dict, optional - Keyword arguments to self.stft(). - - Returns - ------- - torch.Tensor [shape=(batch, channels, mels, time)] - Mel spectrogram. - """ - stft = self.stft(**kwargs) - magnitude = torch.abs(stft) - - nf = magnitude.shape[2] - mel_basis = self.get_mel_filters( - sr=self.sample_rate, - n_fft=2 * (nf - 1), - n_mels=n_mels, - fmin=mel_fmin, - fmax=mel_fmax, - ) - mel_basis = torch.from_numpy(mel_basis).to(self.device) - - mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T - mel_spectrogram = mel_spectrogram.transpose(-1, 2) - return mel_spectrogram - - @staticmethod - @functools.lru_cache(None) - def get_dct(n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None): - """Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``), - it can be normalized depending on norm. 
For more information about dct: - http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II - - Parameters - ---------- - n_mfcc : int - Number of mfccs - n_mels : int - Number of mels - norm : str - Use "ortho" to get an orthogonal matrix, or None, by default "ortho" - device : str, optional - Device to load the transformation matrix on, by default None - - Returns - ------- - torch.Tensor [shape=(n_mels, n_mfcc)] - The DCT transformation matrix. - """ - from torchaudio.functional import create_dct - - return create_dct(n_mfcc, n_mels, norm).to(device) - - def mfcc( - self, n_mfcc: int = 40, n_mels: int = 80, log_offset: float = 1e-6, **kwargs - ): - """Computes mel-frequency cepstral coefficients (MFCCs). - - Parameters - ---------- - n_mfcc : int, optional - Number of MFCCs to compute, by default 40 - n_mels : int, optional - Number of mels, by default 80 - log_offset: float, optional - Small value to prevent numerical issues when trying to compute log(0), by default 1e-6 - kwargs : dict, optional - Keyword arguments to self.mel_spectrogram(), note that some of them will be used for self.stft() - - Returns - ------- - torch.Tensor [shape=(batch, channels, mfccs, time)] - MFCCs. - """ - - mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs) - mel_spectrogram = torch.log(mel_spectrogram + log_offset) - dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device) - - mfcc = mel_spectrogram.transpose(-1, -2) @ dct_mat - mfcc = mfcc.transpose(-1, -2) - return mfcc - - @property - def magnitude(self): - """Computes and returns the absolute value of the STFT, which - is the magnitude. This value can also be set to some tensor. - When set, ``self.stft_data`` is manipulated so that its magnitude - matches what this is set to, and modulated by the phase. - - Returns - ------- - torch.Tensor - Magnitude of STFT. - - Examples - -------- - >>> signal = AudioSignal(torch.randn(44100), 44100) - >>> magnitude = signal.magnitude # Computes stft if not computed - >>> magnitude[magnitude < magnitude.mean()] = 0 - >>> signal.magnitude = magnitude - >>> signal.istft() - """ - if self.stft_data is None: - self.stft() - return torch.abs(self.stft_data) - - @magnitude.setter - def magnitude(self, value): - self.stft_data = value * torch.exp(1j * self.phase) - return - - def log_magnitude( - self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0 - ): - """Computes the log-magnitude of the spectrogram. - - Parameters - ---------- - ref_value : float, optional - The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``. - Zeros in the output correspond to positions where ``S == ref``, - by default 1.0 - amin : float, optional - Minimum threshold for ``S`` and ``ref``, by default 1e-5 - top_db : float, optional - Threshold the output at ``top_db`` below the peak: - ``max(10 * log10(S/ref)) - top_db``, by default 80.0 - - Returns - ------- - torch.Tensor - Log-magnitude spectrogram - """ - magnitude = self.magnitude - - amin = amin**2 - log_spec = 10.0 * torch.log10(magnitude.pow(2).clamp(min=amin)) - log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value)) - - if top_db is not None: - log_spec = torch.maximum(log_spec, log_spec.max() - top_db) - return log_spec - - @property - def phase(self): - """Computes and returns the phase of the STFT. - This value can also be set to some tensor. - When set, ``self.stft_data`` is manipulated so that its phase - matches what this is set to, modulated by the original magnitude. - - Returns - ------- - torch.Tensor - Phase of STFT.
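- - Notes - ----- - Setting this property re-synthesizes ``stft_data`` as - ``self.magnitude * torch.exp(1j * value)``, mirroring the ``magnitude`` setter above.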
- - Examples - -------- - >>> signal = AudioSignal(torch.randn(44100), 44100) - >>> phase = signal.phase # Computes stft if not computed - >>> phase[phase < phase.mean()] = 0 - >>> signal.phase = phase - >>> signal.istft() - """ - if self.stft_data is None: - self.stft() - return torch.angle(self.stft_data) - - @phase.setter - def phase(self, value): - self.stft_data = self.magnitude * torch.exp(1j * value) - return - - # Operator overloading - def __add__(self, other): - new_signal = self.clone() - new_signal.audio_data += util._get_value(other) - return new_signal - - def __iadd__(self, other): - self.audio_data += util._get_value(other) - return self - - def __radd__(self, other): - return self + other - - def __sub__(self, other): - new_signal = self.clone() - new_signal.audio_data -= util._get_value(other) - return new_signal - - def __isub__(self, other): - self.audio_data -= util._get_value(other) - return self - - def __mul__(self, other): - new_signal = self.clone() - new_signal.audio_data *= util._get_value(other) - return new_signal - - def __imul__(self, other): - self.audio_data *= util._get_value(other) - return self - - def __rmul__(self, other): - return self * other - - # Representation - def _info(self): - dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]" - info = { - "duration": f"{dur} seconds", - "batch_size": self.batch_size, - "path": self.path_to_file if self.path_to_file else "path unknown", - "sample_rate": self.sample_rate, - "num_channels": self.num_channels if self.num_channels else "[unknown]", - "audio_data.shape": self.audio_data.shape, - "stft_params": self.stft_params, - "device": self.device, - } - - return info - - def markdown(self): - """Produces a markdown representation of AudioSignal, in a markdown table. - - Returns - ------- - str - Markdown representation of AudioSignal. 
- - Examples - -------- - >>> signal = AudioSignal(torch.randn(44100), 44100) - >>> print(signal.markdown()) - | Key | Value - |---|--- - | duration | 1.000 seconds | - | batch_size | 1 | - | path | path unknown | - | sample_rate | 44100 | - | num_channels | 1 | - | audio_data.shape | torch.Size([1, 1, 44100]) | - | stft_params | STFTParams(window_length=2048, hop_length=512, window_type='sqrt_hann', match_stride=False) | - | device | cpu | - """ - info = self._info() - - FORMAT = "| Key | Value \n" "|---|--- \n" - for k, v in info.items(): - row = f"| {k} | {v} |\n" - FORMAT += row - return FORMAT - - def __str__(self): - info = self._info() - - desc = "" - for k, v in info.items(): - desc += f"{k}: {v}\n" - return desc - - def __rich__(self): - from rich.table import Table - - info = self._info() - - table = Table(title=f"{self.__class__.__name__}") - table.add_column("Key", style="green") - table.add_column("Value", style="cyan") - - for k, v in info.items(): - table.add_row(k, str(v)) - return table - - # Comparison - def __eq__(self, other): - for k, v in list(self.__dict__.items()): - if torch.is_tensor(v): - if not torch.allclose(v, other.__dict__[k], atol=1e-6): - max_error = (v - other.__dict__[k]).abs().max() - print(f"Max abs error for {k}: {max_error}") - return False - return True - - # Indexing - def __getitem__(self, key): - if torch.is_tensor(key) and key.ndim == 0 and key.item() is True: - assert self.batch_size == 1 - audio_data = self.audio_data - _loudness = self._loudness - stft_data = self.stft_data - - elif isinstance(key, (bool, int, list, slice, tuple)) or ( - torch.is_tensor(key) and key.ndim <= 1 - ): - # Indexing only on the batch dimension. - # Then let's copy over relevant stuff. - # Future work: make this work for time-indexing - # as well, using the hop length. - audio_data = self.audio_data[key] - _loudness = self._loudness[key] if self._loudness is not None else None - stft_data = self.stft_data[key] if self.stft_data is not None else None - - sources = None - - copy = type(self)(audio_data, self.sample_rate, stft_params=self.stft_params) - copy._loudness = _loudness - copy._stft_data = stft_data - copy.sources = sources - - return copy - - def __setitem__(self, key, value): - if not isinstance(value, type(self)): - self.audio_data[key] = value - return - - if torch.is_tensor(key) and key.ndim == 0 and key.item() is True: - assert self.batch_size == 1 - self.audio_data = value.audio_data - self._loudness = value._loudness - self.stft_data = value.stft_data - return - - elif isinstance(key, (bool, int, list, slice, tuple)) or ( - torch.is_tensor(key) and key.ndim <= 1 - ): - if self.audio_data is not None and value.audio_data is not None: - self.audio_data[key] = value.audio_data - if self._loudness is not None and value._loudness is not None: - self._loudness[key] = value._loudness - if self.stft_data is not None and value.stft_data is not None: - self.stft_data[key] = value.stft_data - return - - def __ne__(self, other): - return not self == other diff --git a/dito/models/ldm/dac/audiotools/core/display.py b/dito/models/ldm/dac/audiotools/core/display.py deleted file mode 100644 index 66cbcf34cb2cf9fdf8d67ec4418a887eba73f184..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/display.py +++ /dev/null @@ -1,194 +0,0 @@ -import inspect -import typing -from functools import wraps - -from . import util - - -def format_figure(func): - """Decorator for formatting figures produced by the code below. 
- See :py:func:`audiotools.core.util.format_figure` for more. - - Parameters - ---------- - func : Callable - Plotting function that is decorated by this function. - - """ - - @wraps(func) - def wrapper(*args, **kwargs): - f_keys = inspect.signature(util.format_figure).parameters.keys() - f_kwargs = {} - for k, v in list(kwargs.items()): - if k in f_keys: - kwargs.pop(k) - f_kwargs[k] = v - func(*args, **kwargs) - util.format_figure(**f_kwargs) - - return wrapper - - -class DisplayMixin: - @format_figure - def specshow( - self, - preemphasis: bool = False, - x_axis: str = "time", - y_axis: str = "linear", - n_mels: int = 128, - **kwargs, - ): - """Displays a spectrogram, using ``librosa.display.specshow``. - - Parameters - ---------- - preemphasis : bool, optional - Whether or not to apply preemphasis, which makes high - frequency detail easier to see, by default False - x_axis : str, optional - How to label the x axis, by default "time" - y_axis : str, optional - How to label the y axis, by default "linear" - n_mels : int, optional - If displaying a mel spectrogram with ``y_axis = "mel"``, - this controls the number of mels, by default 128. - kwargs : dict, optional - Keyword arguments to :py:func:`audiotools.core.util.format_figure`. - """ - import librosa - import librosa.display - - # Always re-compute the STFT data before showing it, in case - # it changed. - signal = self.clone() - signal.stft_data = None - - if preemphasis: - signal.preemphasis() - - ref = signal.magnitude.max() - log_mag = signal.log_magnitude(ref_value=ref) - - if y_axis == "mel": - log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10() - log_mag -= log_mag.max() - - librosa.display.specshow( - log_mag.numpy()[0].mean(axis=0), - x_axis=x_axis, - y_axis=y_axis, - sr=signal.sample_rate, - **kwargs, - ) - - @format_figure - def waveplot(self, x_axis: str = "time", **kwargs): - """Displays a waveform plot, using ``librosa.display.waveshow``. - - Parameters - ---------- - x_axis : str, optional - How to label the x axis, by default "time" - kwargs : dict, optional - Keyword arguments to :py:func:`audiotools.core.util.format_figure`. - """ - import librosa - import librosa.display - - audio_data = self.audio_data[0].mean(dim=0) - audio_data = audio_data.cpu().numpy() - - plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot" - wave_plot_fn = getattr(librosa.display, plot_fn) - wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs) - - @format_figure - def wavespec(self, x_axis: str = "time", **kwargs): - """Displays a waveform plot stacked above a spectrogram, combining - ``waveplot`` and ``specshow``. - - Parameters - ---------- - x_axis : str, optional - How to label the x axis, by default "time" - kwargs : dict, optional - Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`. - """ - import matplotlib.pyplot as plt - from matplotlib.gridspec import GridSpec - - gs = GridSpec(6, 1) - plt.subplot(gs[0, :]) - self.waveplot(x_axis=x_axis) - plt.subplot(gs[1:, :]) - self.specshow(x_axis=x_axis, **kwargs) - - def write_audio_to_tb( - self, - tag: str, - writer, - step: int = None, - plot_fn: typing.Union[typing.Callable, str] = "specshow", - **kwargs, - ): - """Writes a signal and its spectrogram to Tensorboard. Will show up - under the Audio and Images tab in Tensorboard. - - Parameters - ---------- - tag : str - Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be - written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``).
- writer : SummaryWriter - A SummaryWriter object from PyTorch library. - step : int, optional - The step to write the signal to, by default None - plot_fn : typing.Union[typing.Callable, str], optional - How to create the image. Set to ``None`` to avoid plotting, by default "specshow" - kwargs : dict, optional - Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or - whatever ``plot_fn`` is set to. - """ - import matplotlib.pyplot as plt - - audio_data = self.audio_data[0, 0].detach().cpu() - sample_rate = self.sample_rate - writer.add_audio(tag, audio_data, step, sample_rate) - - if plot_fn is not None: - if isinstance(plot_fn, str): - plot_fn = getattr(self, plot_fn) - fig = plt.figure() - plt.clf() - plot_fn(**kwargs) - writer.add_figure(tag.replace("wav", "png"), fig, step) - - def save_image( - self, - image_path: str, - plot_fn: typing.Union[typing.Callable, str] = "specshow", - **kwargs, - ): - """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to - a specified file. - - Parameters - ---------- - image_path : str - Where to save the file to. - plot_fn : typing.Union[typing.Callable, str], optional - How to create the image. Set to ``None`` to avoid plotting, by default "specshow" - kwargs : dict, optional - Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or - whatever ``plot_fn`` is set to. - """ - import matplotlib.pyplot as plt - - if isinstance(plot_fn, str): - plot_fn = getattr(self, plot_fn) - - plt.clf() - plot_fn(**kwargs) - plt.savefig(image_path, bbox_inches="tight", pad_inches=0) - plt.close() diff --git a/dito/models/ldm/dac/audiotools/core/dsp.py b/dito/models/ldm/dac/audiotools/core/dsp.py deleted file mode 100644 index f9be51a119537b77e497ddc2dac126d569533d7c..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/dsp.py +++ /dev/null @@ -1,390 +0,0 @@ -import typing - -import julius -import numpy as np -import torch - -from . import util - - -class DSPMixin: - _original_batch_size = None - _original_num_channels = None - _padded_signal_length = None - - def _preprocess_signal_for_windowing(self, window_duration, hop_duration): - self._original_batch_size = self.batch_size - self._original_num_channels = self.num_channels - - window_length = int(window_duration * self.sample_rate) - hop_length = int(hop_duration * self.sample_rate) - - if window_length % hop_length != 0: - factor = window_length // hop_length - window_length = factor * hop_length - - self.zero_pad(hop_length, hop_length) - self._padded_signal_length = self.signal_length - - return window_length, hop_length - - def windows( - self, window_duration: float, hop_duration: float, preprocess: bool = True - ): - """Generator which yields windows of specified duration from signal with a specified - hop length. - - Parameters - ---------- - window_duration : float - Duration of every window in seconds. - hop_duration : float - Hop between windows in seconds. - preprocess : bool, optional - Whether to preprocess the signal, so that the first sample is in - the middle of the first window, by default True - - Yields - ------ - AudioSignal - Each window is returned as an AudioSignal. 
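- - Examples - -------- - A minimal sketch (the window and hop durations are illustrative): - - >>> signal = AudioSignal(torch.randn(44100), 44100) - >>> for window in signal.windows(0.5, 0.25): - ... pass # process each 0.5-second window here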
- """ - if preprocess: - window_length, hop_length = self._preprocess_signal_for_windowing( - window_duration, hop_duration - ) - - self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length) - - for b in range(self.batch_size): - i = 0 - start_idx = i * hop_length - while True: - start_idx = i * hop_length - i += 1 - end_idx = start_idx + window_length - if end_idx > self.signal_length: - break - yield self[b, ..., start_idx:end_idx] - - def collect_windows( - self, window_duration: float, hop_duration: float, preprocess: bool = True - ): - """Reshapes signal into windows of specified duration from signal with a specified - hop length. Window are placed along the batch dimension. Use with - :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the - original signal. - - Parameters - ---------- - window_duration : float - Duration of every window in seconds. - hop_duration : float - Hop between windows in seconds. - preprocess : bool, optional - Whether to preprocess the signal, so that the first sample is in - the middle of the first window, by default True - - Returns - ------- - AudioSignal - AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)`` - """ - if preprocess: - window_length, hop_length = self._preprocess_signal_for_windowing( - window_duration, hop_duration - ) - - # self.audio_data: (nb, nch, nt). - unfolded = torch.nn.functional.unfold( - self.audio_data.reshape(-1, 1, 1, self.signal_length), - kernel_size=(1, window_length), - stride=(1, hop_length), - ) - # unfolded: (nb * nch, window_length, num_windows). - # -> (nb * nch * num_windows, 1, window_length) - unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length) - self.audio_data = unfolded - return self - - def overlap_and_add(self, hop_duration: float): - """Function which takes a list of windows and overlap adds them into a - signal the same length as ``audio_signal``. - - Parameters - ---------- - hop_duration : float - How much to shift for each window - (overlap is window_duration - hop_duration) in seconds. - - Returns - ------- - AudioSignal - overlap-and-added signal. - """ - hop_length = int(hop_duration * self.sample_rate) - window_length = self.signal_length - - nb, nch = self._original_batch_size, self._original_num_channels - - unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1) - folded = torch.nn.functional.fold( - unfolded, - output_size=(1, self._padded_signal_length), - kernel_size=(1, window_length), - stride=(1, hop_length), - ) - - norm = torch.ones_like(unfolded, device=unfolded.device) - norm = torch.nn.functional.fold( - norm, - output_size=(1, self._padded_signal_length), - kernel_size=(1, window_length), - stride=(1, hop_length), - ) - - folded = folded / norm - - folded = folded.reshape(nb, nch, -1) - self.audio_data = folded - self.trim(hop_length, hop_length) - return self - - def low_pass( - self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51 - ): - """Low-passes the signal in-place. Each item in the batch - can have a different low-pass cutoff, if the input - to this signal is an array or tensor. If a float, all - items are given the same low-pass filter. - - Parameters - ---------- - cutoffs : typing.Union[torch.Tensor, np.ndarray, float] - Cutoff in Hz of low-pass filter. - zeros : int, optional - Number of taps to use in low-pass filter, by default 51 - - Returns - ------- - AudioSignal - Low-passed AudioSignal. 
- """ - cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size) - cutoffs = cutoffs / self.sample_rate - filtered = torch.empty_like(self.audio_data) - - for i, cutoff in enumerate(cutoffs): - lp_filter = julius.LowPassFilter(cutoff.cpu(), zeros=zeros).to(self.device) - filtered[i] = lp_filter(self.audio_data[i]) - - self.audio_data = filtered - self.stft_data = None - return self - - def high_pass( - self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51 - ): - """High-passes the signal in-place. Each item in the batch - can have a different high-pass cutoff, if the input - to this signal is an array or tensor. If a float, all - items are given the same high-pass filter. - - Parameters - ---------- - cutoffs : typing.Union[torch.Tensor, np.ndarray, float] - Cutoff in Hz of high-pass filter. - zeros : int, optional - Number of taps to use in high-pass filter, by default 51 - - Returns - ------- - AudioSignal - High-passed AudioSignal. - """ - cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size) - cutoffs = cutoffs / self.sample_rate - filtered = torch.empty_like(self.audio_data) - - for i, cutoff in enumerate(cutoffs): - hp_filter = julius.HighPassFilter(cutoff.cpu(), zeros=zeros).to(self.device) - filtered[i] = hp_filter(self.audio_data[i]) - - self.audio_data = filtered - self.stft_data = None - return self - - def mask_frequencies( - self, - fmin_hz: typing.Union[torch.Tensor, np.ndarray, float], - fmax_hz: typing.Union[torch.Tensor, np.ndarray, float], - val: float = 0.0, - ): - """Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them - with the value specified by ``val``. Useful for implementing SpecAug. - The min and max can be different for every item in the batch. - - Parameters - ---------- - fmin_hz : typing.Union[torch.Tensor, np.ndarray, float] - Lower end of band to mask out. - fmax_hz : typing.Union[torch.Tensor, np.ndarray, float] - Upper end of band to mask out. - val : float, optional - Value to fill in, by default 0.0 - - Returns - ------- - AudioSignal - Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the - masked audio data. - """ - # SpecAug - mag, phase = self.magnitude, self.phase - fmin_hz = util.ensure_tensor(fmin_hz, ndim=mag.ndim) - fmax_hz = util.ensure_tensor(fmax_hz, ndim=mag.ndim) - assert torch.all(fmin_hz < fmax_hz) - - # build mask - nbins = mag.shape[-2] - bins_hz = torch.linspace(0, self.sample_rate / 2, nbins, device=self.device) - bins_hz = bins_hz[None, None, :, None].repeat( - self.batch_size, 1, 1, mag.shape[-1] - ) - mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz) - mask = mask.to(self.device) - - mag = mag.masked_fill(mask, val) - phase = phase.masked_fill(mask, val) - self.stft_data = mag * torch.exp(1j * phase) - return self - - def mask_timesteps( - self, - tmin_s: typing.Union[torch.Tensor, np.ndarray, float], - tmax_s: typing.Union[torch.Tensor, np.ndarray, float], - val: float = 0.0, - ): - """Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them - with the value specified by ``val``. Useful for implementing SpecAug. - The min and max can be different for every item in the batch. - - Parameters - ---------- - tmin_s : typing.Union[torch.Tensor, np.ndarray, float] - Lower end of timesteps to mask out. - tmax_s : typing.Union[torch.Tensor, np.ndarray, float] - Upper end of timesteps to mask out. - val : float, optional - Value to fill in, by default 0.0 - - Returns - ------- - AudioSignal - Signal with ``stft_data`` manipulated. 
Apply ``.istft()`` to get the - masked audio data. - """ - # SpecAug - mag, phase = self.magnitude, self.phase - tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim) - tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim) - - assert torch.all(tmin_s < tmax_s) - - # build mask - nt = mag.shape[-1] - bins_t = torch.linspace(0, self.signal_duration, nt, device=self.device) - bins_t = bins_t[None, None, None, :].repeat( - self.batch_size, 1, mag.shape[-2], 1 - ) - mask = (tmin_s <= bins_t) & (bins_t < tmax_s) - - mag = mag.masked_fill(mask, val) - phase = phase.masked_fill(mask, val) - self.stft_data = mag * torch.exp(1j * phase) - return self - - def mask_low_magnitudes( - self, db_cutoff: typing.Union[torch.Tensor, np.ndarray, float], val: float = 0.0 - ): - """Mask away magnitudes below a specified threshold, which - can be different for every item in the batch. - - Parameters - ---------- - db_cutoff : typing.Union[torch.Tensor, np.ndarray, float] - Decibel value for which things below it will be masked away. - val : float, optional - Value to fill in for masked portions, by default 0.0 - - Returns - ------- - AudioSignal - Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the - masked audio data. - """ - mag = self.magnitude - log_mag = self.log_magnitude() - - db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim) - mask = log_mag < db_cutoff - mag = mag.masked_fill(mask, val) - - self.magnitude = mag - return self - - def shift_phase(self, shift: typing.Union[torch.Tensor, np.ndarray, float]): - """Shifts the phase by a constant value. - - Parameters - ---------- - shift : typing.Union[torch.Tensor, np.ndarray, float] - What to shift the phase by. - - Returns - ------- - AudioSignal - Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the - phase-shifted audio data. - """ - shift = util.ensure_tensor(shift, ndim=self.phase.ndim) - self.phase = self.phase + shift - return self - - def corrupt_phase(self, scale: typing.Union[torch.Tensor, np.ndarray, float]): - """Corrupts the phase randomly by some scaled value. - - Parameters - ---------- - scale : typing.Union[torch.Tensor, np.ndarray, float] - Standard deviation of noise to add to the phase. - - Returns - ------- - AudioSignal - Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the - corrupted audio data. - """ - scale = util.ensure_tensor(scale, ndim=self.phase.ndim) - self.phase = self.phase + scale * torch.randn_like(self.phase) - return self - - def preemphasis(self, coef: float = 0.85): - """Applies pre-emphasis to audio signal. - - Parameters - ---------- - coef : float, optional - How much pre-emphasis to apply; lower values do less and 0 does nothing, - by default 0.85 - - Returns - ------- - AudioSignal - Pre-emphasized signal. - """ - kernel = torch.tensor([1, -coef, 0]).view(1, 1, -1).to(self.device) - x = self.audio_data.reshape(-1, 1, self.signal_length) - x = torch.nn.functional.conv1d(x, kernel, padding=1) - self.audio_data = x.reshape(*self.audio_data.shape) - return self diff --git a/dito/models/ldm/dac/audiotools/core/effects.py b/dito/models/ldm/dac/audiotools/core/effects.py deleted file mode 100644 index fb534cbcb2d457575de685fc9248d1716879145b..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/effects.py +++ /dev/null @@ -1,647 +0,0 @@ -import typing - -import julius -import numpy as np -import torch -import torchaudio - -from .
import util - - -class EffectMixin: - GAIN_FACTOR = np.log(10) / 20 - """Gain factor for converting between amplitude and decibels.""" - CODEC_PRESETS = { - "8-bit": {"format": "wav", "encoding": "ULAW", "bits_per_sample": 8}, - "GSM-FR": {"format": "gsm"}, - "MP3": {"format": "mp3", "compression": -9}, - "Vorbis": {"format": "vorbis", "compression": -1}, - "Ogg": { - "format": "ogg", - "compression": -1, - }, - "Amr-nb": {"format": "amr-nb"}, - } - """Presets for applying codecs via torchaudio.""" - - def mix( - self, - other, - snr: typing.Union[torch.Tensor, np.ndarray, float] = 10, - other_eq: typing.Union[torch.Tensor, np.ndarray] = None, - ): - """Mixes noise with signal at specified - signal-to-noise ratio. Optionally, the - other signal can be equalized in-place. - - - Parameters - ---------- - other : AudioSignal - AudioSignal object to mix with. - snr : typing.Union[torch.Tensor, np.ndarray, float], optional - Signal to noise ratio, by default 10 - other_eq : typing.Union[torch.Tensor, np.ndarray], optional - EQ curve to apply to other signal, if any, by default None - - Returns - ------- - AudioSignal - In-place modification of AudioSignal. - """ - snr = util.ensure_tensor(snr).to(self.device) - - pad_len = max(0, self.signal_length - other.signal_length) - other.zero_pad(0, pad_len) - other.truncate_samples(self.signal_length) - if other_eq is not None: - other = other.equalizer(other_eq) - - tgt_loudness = self.loudness() - snr - other = other.normalize(tgt_loudness) - - self.audio_data = self.audio_data + other.audio_data - return self - - def convolve(self, other, start_at_max: bool = True): - """Convolves self with other. - This function uses FFTs to do the convolution. - - Parameters - ---------- - other : AudioSignal - Signal to convolve with. - start_at_max : bool, optional - Whether to start at the max value of other signal, to - avoid inducing delays, by default True - - Returns - ------- - AudioSignal - Convolved signal, in-place. - """ - from . import AudioSignal - - pad_len = self.signal_length - other.signal_length - - if pad_len > 0: - other.zero_pad(0, pad_len) - else: - other.truncate_samples(self.signal_length) - - if start_at_max: - # Use roll to rotate over the max for every item - # so that the impulse responses don't induce any - # delay. - idx = other.audio_data.abs().argmax(axis=-1) - irs = torch.zeros_like(other.audio_data) - for i in range(other.batch_size): - irs[i] = torch.roll(other.audio_data[i], -idx[i].item(), -1) - other = AudioSignal(irs, other.sample_rate) - - delta = torch.zeros_like(other.audio_data) - delta[..., 0] = 1 - - length = self.signal_length - delta_fft = torch.fft.rfft(delta, length) - other_fft = torch.fft.rfft(other.audio_data, length) - self_fft = torch.fft.rfft(self.audio_data, length) - - convolved_fft = other_fft * self_fft - convolved_audio = torch.fft.irfft(convolved_fft, length) - - delta_convolved_fft = other_fft * delta_fft - delta_audio = torch.fft.irfft(delta_convolved_fft, length) - - # Use the delta to rescale the audio exactly as needed. - delta_max = delta_audio.abs().max(dim=-1, keepdims=True)[0] - scale = 1 / delta_max.clamp(1e-5) - convolved_audio = convolved_audio * scale - - self.audio_data = convolved_audio - - return self - - def apply_ir( - self, - ir, - drr: typing.Union[torch.Tensor, np.ndarray, float] = None, - ir_eq: typing.Union[torch.Tensor, np.ndarray] = None, - use_original_phase: bool = False, - ): - """Applies an impulse response to the signal. 
If ``ir_eq`` - is specified, the impulse response is equalized before - it is applied, using the given curve. - - Parameters - ---------- - ir : AudioSignal - Impulse response to convolve with. - drr : typing.Union[torch.Tensor, np.ndarray, float], optional - Direct-to-reverberant ratio that impulse response will be - altered to, if specified, by default None - ir_eq : typing.Union[torch.Tensor, np.ndarray], optional - Equalization that will be applied to impulse response - if specified, by default None - use_original_phase : bool, optional - Whether to use the original phase, instead of the convolved - phase, by default False - - Returns - ------- - AudioSignal - Signal with impulse response applied to it - """ - if ir_eq is not None: - ir = ir.equalizer(ir_eq) - if drr is not None: - ir = ir.alter_drr(drr) - - # Save the peak before convolution - max_spk = self.audio_data.abs().max(dim=-1, keepdims=True).values - - # Augment the impulse response to simulate microphone effects - # and with varying direct-to-reverberant ratio. - phase = self.phase - self.convolve(ir) - - # Use the input phase - if use_original_phase: - self.stft() - self.stft_data = self.magnitude * torch.exp(1j * phase) - self.istft() - - # Rescale to the input's amplitude - max_transformed = self.audio_data.abs().max(dim=-1, keepdims=True).values - scale_factor = max_spk.clamp(1e-8) / max_transformed.clamp(1e-8) - self = self * scale_factor - - return self - - def ensure_max_of_audio(self, max: float = 1.0): - """Ensures that ``abs(audio_data) <= max``. - - Parameters - ---------- - max : float, optional - Max absolute value of signal, by default 1.0 - - Returns - ------- - AudioSignal - Signal with values scaled between -max and max. - """ - peak = self.audio_data.abs().max(dim=-1, keepdims=True)[0] - peak_gain = torch.ones_like(peak) - peak_gain[peak > max] = max / peak[peak > max] - self.audio_data = self.audio_data * peak_gain - return self - - def normalize(self, db: typing.Union[torch.Tensor, np.ndarray, float] = -24.0): - """Normalizes the signal's volume to the specified db, in LUFS. - This is GPU-compatible, making for very fast loudness normalization. - - Parameters - ---------- - db : typing.Union[torch.Tensor, np.ndarray, float], optional - Loudness to normalize to, by default -24.0 - - Returns - ------- - AudioSignal - Normalized audio signal. - """ - db = util.ensure_tensor(db).to(self.device) - ref_db = self.loudness() - gain = db - ref_db - gain = torch.exp(gain * self.GAIN_FACTOR) - - self.audio_data = self.audio_data * gain[:, None, None] - return self - - def volume_change(self, db: typing.Union[torch.Tensor, np.ndarray, float]): - """Change volume of signal by some amount, in dB. - - Parameters - ---------- - db : typing.Union[torch.Tensor, np.ndarray, float] - Amount to change volume by. - - Returns - ------- - AudioSignal - Signal at new volume. - """ - db = util.ensure_tensor(db, ndim=1).to(self.device) - gain = torch.exp(db * self.GAIN_FACTOR) - self.audio_data = self.audio_data * gain[:, None, None] - return self - - def _to_2d(self): - waveform = self.audio_data.reshape(-1, self.signal_length) - return waveform - - def _to_3d(self, waveform): - return waveform.reshape(self.batch_size, self.num_channels, -1) - - def pitch_shift(self, n_semitones: int, quick: bool = True): - """Pitch shift the signal. All items in the batch - get the same pitch shift. - - Parameters - ---------- - n_semitones : int - How many semitones to shift the signal by.
- quick : bool, optional - Using quick pitch shifting, by default True - - Returns - ------- - AudioSignal - Pitch shifted audio signal. - """ - device = self.device - effects = [ - ["pitch", str(n_semitones * 100)], - ["rate", str(self.sample_rate)], - ] - if quick: - effects[0].insert(1, "-q") - - waveform = self._to_2d().cpu() - waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor( - waveform, self.sample_rate, effects, channels_first=True - ) - self.sample_rate = sample_rate - self.audio_data = self._to_3d(waveform) - return self.to(device) - - def time_stretch(self, factor: float, quick: bool = True): - """Time stretch the audio signal. - - Parameters - ---------- - factor : float - Factor by which to stretch the AudioSignal. Typically - between 0.8 and 1.2. - quick : bool, optional - Whether to use quick time stretching, by default True - - Returns - ------- - AudioSignal - Time-stretched AudioSignal. - """ - device = self.device - effects = [ - ["tempo", str(factor)], - ["rate", str(self.sample_rate)], - ] - if quick: - effects[0].insert(1, "-q") - - waveform = self._to_2d().cpu() - waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor( - waveform, self.sample_rate, effects, channels_first=True - ) - self.sample_rate = sample_rate - self.audio_data = self._to_3d(waveform) - return self.to(device) - - def apply_codec( - self, - preset: str = None, - format: str = "wav", - encoding: str = None, - bits_per_sample: int = None, - compression: int = None, - ): # pragma: no cover - """Applies an audio codec to the signal. - - Parameters - ---------- - preset : str, optional - One of the keys in ``self.CODEC_PRESETS``, by default None - format : str, optional - Format for audio codec, by default "wav" - encoding : str, optional - Encoding to use, by default None - bits_per_sample : int, optional - How many bits per sample, by default None - compression : int, optional - Compression amount of codec, by default None - - Returns - ------- - AudioSignal - AudioSignal with codec applied. - - Raises - ------ - ValueError - If preset is not in ``self.CODEC_PRESETS``, an error - is thrown. - """ - torchaudio_version_070 = "0.7" in torchaudio.__version__ - if torchaudio_version_070: - return self - - kwargs = { - "format": format, - "encoding": encoding, - "bits_per_sample": bits_per_sample, - "compression": compression, - } - - if preset is not None: - if preset in self.CODEC_PRESETS: - kwargs = self.CODEC_PRESETS[preset] - else: - raise ValueError( - f"Unknown preset: {preset}. " - f"Known presets: {list(self.CODEC_PRESETS.keys())}" - ) - - waveform = self._to_2d() - if kwargs["format"] in ["vorbis", "mp3", "ogg", "amr-nb"]: - # Apply it in a for loop - augmented = torch.cat( - [ - torchaudio.functional.apply_codec( - waveform[i][None, :], self.sample_rate, **kwargs - ) - for i in range(waveform.shape[0]) - ], - dim=0, - ) - else: - augmented = torchaudio.functional.apply_codec( - waveform, self.sample_rate, **kwargs - ) - augmented = self._to_3d(augmented) - - self.audio_data = augmented - return self - - def mel_filterbank(self, n_bands: int): - """Breaks signal into mel bands. - - Parameters - ---------- - n_bands : int - Number of mel bands to use. - - Returns - ------- - torch.Tensor - Mel-filtered bands, with last axis being the band index. 
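- - Examples - -------- - Splitting a signal into mel bands (the band count is illustrative): - - >>> signal = AudioSignal(torch.randn(44100), 44100) - >>> bands = signal.mel_filterbank(6) - >>> bands.shape[-1] - 6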
- """ - filterbank = ( - julius.SplitBands(self.sample_rate, n_bands).float().to(self.device) - ) - filtered = filterbank(self.audio_data) - return filtered.permute(1, 2, 3, 0) - - def equalizer(self, db: typing.Union[torch.Tensor, np.ndarray]): - """Applies a mel-spaced equalizer to the audio signal. - - Parameters - ---------- - db : typing.Union[torch.Tensor, np.ndarray] - EQ curve to apply. - - Returns - ------- - AudioSignal - AudioSignal with equalization applied. - """ - db = util.ensure_tensor(db) - n_bands = db.shape[-1] - fbank = self.mel_filterbank(n_bands) - - # If there's a batch dimension, make sure it's the same. - if db.ndim == 2: - if db.shape[0] != 1: - assert db.shape[0] == fbank.shape[0] - else: - db = db.unsqueeze(0) - - weights = (10**db).to(self.device).float() - fbank = fbank * weights[:, None, None, :] - eq_audio_data = fbank.sum(-1) - self.audio_data = eq_audio_data - return self - - def clip_distortion( - self, clip_percentile: typing.Union[torch.Tensor, np.ndarray, float] - ): - """Clips the signal at a given percentile. The higher it is, - the lower the threshold for clipping. - - Parameters - ---------- - clip_percentile : typing.Union[torch.Tensor, np.ndarray, float] - Values are between 0.0 to 1.0. Typical values are 0.1 or below. - - Returns - ------- - AudioSignal - Audio signal with clipped audio data. - """ - clip_percentile = util.ensure_tensor(clip_percentile, ndim=1) - min_thresh = torch.quantile(self.audio_data, clip_percentile / 2, dim=-1) - max_thresh = torch.quantile(self.audio_data, 1 - (clip_percentile / 2), dim=-1) - - nc = self.audio_data.shape[1] - min_thresh = min_thresh[:, :nc, :] - max_thresh = max_thresh[:, :nc, :] - - self.audio_data = self.audio_data.clamp(min_thresh, max_thresh) - - return self - - def quantization( - self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int] - ): - """Applies quantization to the input waveform. - - Parameters - ---------- - quantization_channels : typing.Union[torch.Tensor, np.ndarray, int] - Number of evenly spaced quantization channels to quantize - to. - - Returns - ------- - AudioSignal - Quantized AudioSignal. - """ - quantization_channels = util.ensure_tensor(quantization_channels, ndim=3) - - x = self.audio_data - x = (x + 1) / 2 - x = x * quantization_channels - x = x.floor() - x = x / quantization_channels - x = 2 * x - 1 - - residual = (self.audio_data - x).detach() - self.audio_data = self.audio_data - residual - return self - - def mulaw_quantization( - self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int] - ): - """Applies mu-law quantization to the input waveform. - - Parameters - ---------- - quantization_channels : typing.Union[torch.Tensor, np.ndarray, int] - Number of mu-law spaced quantization channels to quantize - to. - - Returns - ------- - AudioSignal - Quantized AudioSignal. 
- """ - mu = quantization_channels - 1.0 - mu = util.ensure_tensor(mu, ndim=3) - - x = self.audio_data - - # quantize - x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu) - x = ((x + 1) / 2 * mu + 0.5).to(torch.int64) - - # unquantize - x = (x / mu) * 2 - 1.0 - x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu - - residual = (self.audio_data - x).detach() - self.audio_data = self.audio_data - residual - return self - - def __matmul__(self, other): - return self.convolve(other) - - -class ImpulseResponseMixin: - """These functions are generally only used with AudioSignals that are derived - from impulse responses, not other sources like music or speech. These methods - are used to replicate the data augmentation described in [1]. - - 1. Bryan, Nicholas J. "Impulse response data augmentation and deep - neural networks for blind room acoustic parameter estimation." - ICASSP 2020-2020 IEEE International Conference on Acoustics, - Speech and Signal Processing (ICASSP). IEEE, 2020. - """ - - def decompose_ir(self): - """Decomposes an impulse response into early and late - field responses. - """ - # Equations 1 and 2 - # ----------------- - # Breaking up into early - # response + late field response. - - td = torch.argmax(self.audio_data, dim=-1, keepdim=True) - t0 = int(self.sample_rate * 0.0025) - - idx = torch.arange(self.audio_data.shape[-1], device=self.device)[None, None, :] - idx = idx.expand(self.batch_size, -1, -1) - early_idx = (idx >= td - t0) * (idx <= td + t0) - - early_response = torch.zeros_like(self.audio_data, device=self.device) - early_response[early_idx] = self.audio_data[early_idx] - - late_idx = ~early_idx - late_field = torch.zeros_like(self.audio_data, device=self.device) - late_field[late_idx] = self.audio_data[late_idx] - - # Equation 4 - # ---------- - # Decompose early response into windowed - # direct path and windowed residual. - - window = torch.zeros_like(self.audio_data, device=self.device) - for idx in range(self.batch_size): - window_idx = early_idx[idx, 0].nonzero() - window[idx, ..., window_idx] = self.get_window( - "hann", window_idx.shape[-1], self.device - ) - return early_response, late_field, window - - def measure_drr(self): - """Measures the direct-to-reverberant ratio of the impulse - response. - - Returns - ------- - float - Direct-to-reverberant ratio - """ - early_response, late_field, _ = self.decompose_ir() - num = (early_response**2).sum(dim=-1) - den = (late_field**2).sum(dim=-1) - drr = 10 * torch.log10(num / den) - return drr - - @staticmethod - def solve_alpha(early_response, late_field, wd, target_drr): - """Used to solve for the alpha value, which is used - to alter the drr. - """ - # Equation 5 - # ---------- - # Apply the good ol' quadratic formula. - - wd_sq = wd**2 - wd_sq_1 = (1 - wd) ** 2 - e_sq = early_response**2 - l_sq = late_field**2 - a = (wd_sq * e_sq).sum(dim=-1) - b = (2 * (1 - wd) * wd * e_sq).sum(dim=-1) - c = (wd_sq_1 * e_sq).sum(dim=-1) - torch.pow(10, target_drr / 10) * l_sq.sum( - dim=-1 - ) - - expr = ((b**2) - 4 * a * c).sqrt() - alpha = torch.maximum( - (-b - expr) / (2 * a), - (-b + expr) / (2 * a), - ) - return alpha - - def alter_drr(self, drr: typing.Union[torch.Tensor, np.ndarray, float]): - """Alters the direct-to-reverberant ratio of the impulse response. 
- - Parameters - ---------- - drr : typing.Union[torch.Tensor, np.ndarray, float] - Direct-to-reverberant ratio that the impulse response will be - altered to. - - Returns - ------- - AudioSignal - Altered impulse response. - """ - drr = util.ensure_tensor(drr, 2, self.batch_size).to(self.device) - - early_response, late_field, window = self.decompose_ir() - alpha = self.solve_alpha(early_response, late_field, window, drr) - min_alpha = ( - late_field.abs().max(dim=-1)[0] / early_response.abs().max(dim=-1)[0] - ) - alpha = torch.maximum(alpha, min_alpha)[..., None] - - aug_ir_data = ( - alpha * window * early_response - + ((1 - window) * early_response) - + late_field - ) - self.audio_data = aug_ir_data - self.ensure_max_of_audio() - return self diff --git a/dito/models/ldm/dac/audiotools/core/ffmpeg.py b/dito/models/ldm/dac/audiotools/core/ffmpeg.py deleted file mode 100644 index 83f9cd197d7dc8748a16be77614cc593a6a33297..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/ffmpeg.py +++ /dev/null @@ -1,211 +0,0 @@ -import json -import shlex -import subprocess -import tempfile -from pathlib import Path -from typing import Tuple - -import ffmpy -import numpy as np -import torch - - -def r128stats(filepath: str, quiet: bool): - """Takes a path to an audio file, returns a dict with the loudness - stats computed by the ffmpeg ebur128 filter. - - Parameters - ---------- - filepath : str - Path to compute loudness stats on. - quiet : bool - Whether to show FFMPEG output during computation. - - Returns - ------- - dict - Dictionary containing loudness stats. - """ - ffargs = [ - "ffmpeg", - "-nostats", - "-i", - filepath, - "-filter_complex", - "ebur128", - "-f", - "null", - "-", - ] - if quiet: - ffargs += ["-hide_banner"] - proc = subprocess.Popen(ffargs, stderr=subprocess.PIPE, universal_newlines=True) - stats = proc.communicate()[1] - summary_index = stats.rfind("Summary:") - - summary_list = stats[summary_index:].split() - i_lufs = float(summary_list[summary_list.index("I:") + 1]) - i_thresh = float(summary_list[summary_list.index("I:") + 4]) - lra = float(summary_list[summary_list.index("LRA:") + 1]) - lra_thresh = float(summary_list[summary_list.index("LRA:") + 4]) - lra_low = float(summary_list[summary_list.index("low:") + 1]) - lra_high = float(summary_list[summary_list.index("high:") + 1]) - stats_dict = { - "I": i_lufs, - "I Threshold": i_thresh, - "LRA": lra, - "LRA Threshold": lra_thresh, - "LRA Low": lra_low, - "LRA High": lra_high, - } - - return stats_dict - - -def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]: - """Given a path to a file, returns the start time offset and codec of - the first audio stream. - """ - ff = ffmpy.FFprobe( - inputs={path: None}, - global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet", - ) - streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"] - seconds_offset = 0.0 - codec = None - - # Get the offset and codec of the first audio stream we find - # and return its start time, if it has one. - for stream in streams: - if stream["codec_type"] == "audio": - seconds_offset = stream.get("start_time", 0.0) - codec = stream.get("codec_name") - break - return float(seconds_offset), codec - - -class FFMPEGMixin: - _loudness = None - - def ffmpeg_loudness(self, quiet: bool = True): - """Computes loudness of audio file using FFMPEG.
- - Parameters - ---------- - quiet : bool, optional - Whether to show FFMPEG output during computation, - by default True - - Returns - ------- - torch.Tensor - Loudness of every item in the batch, computed via - FFMPEG. - """ - loudness = [] - - with tempfile.NamedTemporaryFile(suffix=".wav") as f: - for i in range(self.batch_size): - self[i].write(f.name) - loudness_stats = r128stats(f.name, quiet=quiet) - loudness.append(loudness_stats["I"]) - - self._loudness = torch.from_numpy(np.array(loudness)).float() - return self.loudness() - - def ffmpeg_resample(self, sample_rate: int, quiet: bool = True): - """Resamples AudioSignal using FFMPEG. More memory-efficient - than using julius.resample for long audio files. - - Parameters - ---------- - sample_rate : int - Sample rate to resample to. - quiet : bool, optional - Whether to show FFMPEG output during computation, - by default True - - Returns - ------- - AudioSignal - Resampled AudioSignal. - """ - from audiotools import AudioSignal - - if sample_rate == self.sample_rate: - return self - - with tempfile.NamedTemporaryFile(suffix=".wav") as f: - self.write(f.name) - f_out = f.name.replace("wav", "rs.wav") - command = f"ffmpeg -i {f.name} -ar {sample_rate} {f_out}" - if quiet: - command += " -hide_banner -loglevel error" - subprocess.check_call(shlex.split(command)) - resampled = AudioSignal(f_out) - Path.unlink(Path(f_out)) - return resampled - - @classmethod - def load_from_file_with_ffmpeg(cls, audio_path: str, quiet: bool = True, **kwargs): - """Loads AudioSignal object after decoding it to a wav file using FFMPEG. - Useful for loading audio that isn't covered by librosa's loading mechanism. Also - useful for loading mp3 files, without any offset. - - Parameters - ---------- - audio_path : str - Path to load AudioSignal from. - quiet : bool, optional - Whether to show FFMPEG output during computation, - by default True - - Returns - ------- - AudioSignal - AudioSignal loaded from file with FFMPEG. - """ - audio_path = str(audio_path) - with tempfile.TemporaryDirectory() as d: - wav_file = str(Path(d) / "extracted.wav") - padded_wav = str(Path(d) / "padded.wav") - - global_options = "-y" - if quiet: - global_options += " -loglevel error" - - ff = ffmpy.FFmpeg( - inputs={audio_path: None}, - # For inputs that are m4a (and others?), the input audio can - # have samples that don't match the sample rate. This aresample - # option forces ffmpeg to read timing information in the source - # file instead of assuming constant sample rate. - # - # This fixes an issue where an input m4a file might be a - # different length than the output wav file - outputs={wav_file: "-af aresample=async=1000"}, - global_options=global_options, - ) - ff.run() - - # We pad the file using the start time offset in case it's an audio - # stream starting at some offset in a video container. - pad, codec = ffprobe_offset_and_codec(audio_path) - - # For mp3s, don't pad files with discrepancies less than 0.027s - - # it's likely due to codec latency. The amount of latency introduced - # by mp3 is 1152, which is 0.0261 44khz. So we set the threshold - # here slightly above that. - # Source: https://lame.sourceforge.io/tech-FAQ.txt. 
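As a quick sanity check on the 0.027 s threshold applied just below: one mp3 granule is 1152 samples, which at 44.1 kHz lands slightly under the cutoff.

```python
mp3_granule = 1152          # samples of encoder delay, per the LAME FAQ cited above
print(mp3_granule / 44100)  # ≈ 0.02612 s, so 0.027 s sits just above one granule
```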
- if codec == "mp3" and pad < 0.027: - pad = 0.0 - ff = ffmpy.FFmpeg( - inputs={wav_file: None}, - outputs={padded_wav: f"-af 'adelay={pad*1000}:all=true'"}, - global_options=global_options, - ) - ff.run() - - signal = cls(padded_wav, **kwargs) - - return signal diff --git a/dito/models/ldm/dac/audiotools/core/loudness.py b/dito/models/ldm/dac/audiotools/core/loudness.py deleted file mode 100644 index cb3ee2675d7cb71f4c00106b0c1e901b8e51b842..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/loudness.py +++ /dev/null @@ -1,320 +0,0 @@ -import copy - -import julius -import numpy as np -import scipy -import torch -import torch.nn.functional as F -import torchaudio - - -class Meter(torch.nn.Module): - """Tensorized version of pyloudnorm.Meter. Works with batched audio tensors. - - Parameters - ---------- - rate : int - Sample rate of audio. - filter_class : str, optional - Class of weighting filter used. - K-weighting' (default), 'Fenton/Lee 1' - 'Fenton/Lee 2', 'Dash et al.' - by default "K-weighting" - block_size : float, optional - Gating block size in seconds, by default 0.400 - zeros : int, optional - Number of zeros to use in FIR approximation of - IIR filters, by default 512 - use_fir : bool, optional - Whether to use FIR approximation or exact IIR formulation. - If computing on GPU, ``use_fir=True`` will be used, as its - much faster, by default False - """ - - def __init__( - self, - rate: int, - filter_class: str = "K-weighting", - block_size: float = 0.400, - zeros: int = 512, - use_fir: bool = False, - ): - super().__init__() - - self.rate = rate - self.filter_class = filter_class - self.block_size = block_size - self.use_fir = use_fir - - G = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.41, 1.41])) - self.register_buffer("G", G) - - # Compute impulse responses so that filtering is fast via - # a convolution at runtime, on GPU, unlike lfilter. - impulse = np.zeros((zeros,)) - impulse[..., 0] = 1.0 - - firs = np.zeros((len(self._filters), 1, zeros)) - passband_gain = torch.zeros(len(self._filters)) - - for i, (_, filter_stage) in enumerate(self._filters.items()): - firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a, impulse) - passband_gain[i] = filter_stage.passband_gain - - firs = torch.from_numpy(firs[..., ::-1].copy()).float() - - self.register_buffer("firs", firs) - self.register_buffer("passband_gain", passband_gain) - - def apply_filter_gpu(self, data: torch.Tensor): - """Performs FIR approximation of loudness computation. - - Parameters - ---------- - data : torch.Tensor - Audio data of shape (nb, nch, nt). - - Returns - ------- - torch.Tensor - Filtered audio data. - """ - # Data is of shape (nb, nch, nt) - # Reshape to (nb*nch, 1, nt) - nb, nt, nch = data.shape - data = data.permute(0, 2, 1) - data = data.reshape(nb * nch, 1, nt) - - # Apply padding - pad_length = self.firs.shape[-1] - - # Apply filtering in sequence - for i in range(self.firs.shape[0]): - data = F.pad(data, (pad_length, pad_length)) - data = julius.fftconv.fft_conv1d(data, self.firs[i, None, ...]) - data = self.passband_gain[i] * data - data = data[..., 1 : nt + 1] - - data = data.permute(0, 2, 1) - data = data[:, :nt, :] - return data - - def apply_filter_cpu(self, data: torch.Tensor): - """Performs IIR formulation of loudness computation. - - Parameters - ---------- - data : torch.Tensor - Audio data of shape (nb, nch, nt). - - Returns - ------- - torch.Tensor - Filtered audio data. 
- """ - for _, filter_stage in self._filters.items(): - passband_gain = filter_stage.passband_gain - - a_coeffs = torch.from_numpy(filter_stage.a).float().to(data.device) - b_coeffs = torch.from_numpy(filter_stage.b).float().to(data.device) - - _data = data.permute(0, 2, 1) - filtered = torchaudio.functional.lfilter( - _data, a_coeffs, b_coeffs, clamp=False - ) - data = passband_gain * filtered.permute(0, 2, 1) - return data - - def apply_filter(self, data: torch.Tensor): - """Applies filter on either CPU or GPU, depending - on if the audio is on GPU or is on CPU, or if - ``self.use_fir`` is True. - - Parameters - ---------- - data : torch.Tensor - Audio data of shape (nb, nch, nt). - - Returns - ------- - torch.Tensor - Filtered audio data. - """ - if data.is_cuda or self.use_fir: - data = self.apply_filter_gpu(data) - else: - data = self.apply_filter_cpu(data) - return data - - def forward(self, data: torch.Tensor): - """Computes integrated loudness of data. - - Parameters - ---------- - data : torch.Tensor - Audio data of shape (nb, nch, nt). - - Returns - ------- - torch.Tensor - Filtered audio data. - """ - return self.integrated_loudness(data) - - def _unfold(self, input_data): - T_g = self.block_size - overlap = 0.75 # overlap of 75% of the block duration - step = 1.0 - overlap # step size by percentage - - kernel_size = int(T_g * self.rate) - stride = int(T_g * self.rate * step) - unfolded = julius.core.unfold(input_data.permute(0, 2, 1), kernel_size, stride) - unfolded = unfolded.transpose(-1, -2) - - return unfolded - - def integrated_loudness(self, data: torch.Tensor): - """Computes integrated loudness of data. - - Parameters - ---------- - data : torch.Tensor - Audio data of shape (nb, nch, nt). - - Returns - ------- - torch.Tensor - Filtered audio data. - """ - if not torch.is_tensor(data): - data = torch.from_numpy(data).float() - else: - data = data.float() - - input_data = copy.copy(data) - # Data always has a batch and channel dimension. - # Is of shape (nb, nt, nch) - if input_data.ndim < 2: - input_data = input_data.unsqueeze(-1) - if input_data.ndim < 3: - input_data = input_data.unsqueeze(0) - - nb, nt, nch = input_data.shape - - # Apply frequency weighting filters - account - # for the acoustic respose of the head and auditory system - input_data = self.apply_filter(input_data) - - G = self.G # channel gains - T_g = self.block_size # 400 ms gating block standard - Gamma_a = -70.0 # -70 LKFS = absolute loudness threshold - - unfolded = self._unfold(input_data) - - z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2) - l = -0.691 + 10.0 * torch.log10((G[None, :nch, None] * z).sum(1, keepdim=True)) - l = l.expand_as(z) - - # find gating block indices above absolute threshold - z_avg_gated = z - z_avg_gated[l <= Gamma_a] = 0 - masked = l > Gamma_a - z_avg_gated = z_avg_gated.sum(2) / masked.sum(2) - - # calculate the relative threshold value (see eq. 6) - Gamma_r = ( - -0.691 + 10.0 * torch.log10((z_avg_gated * G[None, :nch]).sum(-1)) - 10.0 - ) - Gamma_r = Gamma_r[:, None, None] - Gamma_r = Gamma_r.expand(nb, nch, l.shape[-1]) - - # find gating block indices above relative and absolute thresholds (end of eq. 
7) - z_avg_gated = z - z_avg_gated[l <= Gamma_a] = 0 - z_avg_gated[l <= Gamma_r] = 0 - masked = (l > Gamma_a) * (l > Gamma_r) - z_avg_gated = z_avg_gated.sum(2) / masked.sum(2) - - # # Cannot use nan_to_num (pytorch 1.8 does not come with GCP-supported cuda version) - # z_avg_gated = torch.nan_to_num(z_avg_gated) - z_avg_gated = torch.where( - z_avg_gated.isnan(), torch.zeros_like(z_avg_gated), z_avg_gated - ) - z_avg_gated[z_avg_gated == float("inf")] = float(np.finfo(np.float32).max) - z_avg_gated[z_avg_gated == -float("inf")] = float(np.finfo(np.float32).min) - - LUFS = -0.691 + 10.0 * torch.log10((G[None, :nch] * z_avg_gated).sum(1)) - return LUFS.float() - - @property - def filter_class(self): - return self._filter_class - - @filter_class.setter - def filter_class(self, value): - from pyloudnorm import Meter - - meter = Meter(self.rate) - meter.filter_class = value - self._filter_class = value - self._filters = meter._filters - - -class LoudnessMixin: - _loudness = None - MIN_LOUDNESS = -70 - """Minimum loudness possible.""" - - def loudness( - self, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs - ): - """Calculates loudness using an implementation of ITU-R BS.1770-4. - Allows control over gating block size and frequency weighting filters for - additional control. Measure the integrated gated loudness of a signal. - - API is derived from PyLoudnorm, but this implementation is ported to PyTorch - and is tensorized across batches. When on GPU, an FIR approximation of the IIR - filters is used to compute loudness for speed. - - Uses the weighting filters and block size defined by the meter - the integrated loudness is measured based upon the gating algorithm - defined in the ITU-R BS.1770-4 specification. - - Parameters - ---------- - filter_class : str, optional - Class of weighting filter used. - K-weighting' (default), 'Fenton/Lee 1' - 'Fenton/Lee 2', 'Dash et al.' - by default "K-weighting" - block_size : float, optional - Gating block size in seconds, by default 0.400 - kwargs : dict, optional - Keyword arguments to :py:func:`audiotools.core.loudness.Meter`. - - Returns - ------- - torch.Tensor - Loudness of audio data. - """ - if self._loudness is not None: - return self._loudness.to(self.device) - original_length = self.signal_length - if self.signal_duration < 0.5: - pad_len = int((0.5 - self.signal_duration) * self.sample_rate) - self.zero_pad(0, pad_len) - - # create BS.1770 meter - meter = Meter( - self.sample_rate, filter_class=filter_class, block_size=block_size, **kwargs - ) - meter = meter.to(self.device) - # measure loudness - loudness = meter.integrated_loudness(self.audio_data.permute(0, 2, 1)) - self.truncate_samples(original_length) - min_loudness = ( - torch.ones_like(loudness, device=loudness.device) * self.MIN_LOUDNESS - ) - self._loudness = torch.maximum(loudness, min_loudness) - - return self._loudness.to(self.device) diff --git a/dito/models/ldm/dac/audiotools/core/playback.py b/dito/models/ldm/dac/audiotools/core/playback.py deleted file mode 100644 index 5d0f21aaa392494f35305c0084c05b87667ea14d..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/playback.py +++ /dev/null @@ -1,252 +0,0 @@ -""" -These are utilities that allow one to embed an AudioSignal -as a playable object in a Jupyter notebook, or to play audio from -the terminal, etc. 
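Stripped of the filter plumbing, the gating in `integrated_loudness` reduces to a few lines. A minimal sketch of the BS.1770-style gate, assuming mono input that has already been K-weighted; the function name and test values are illustrative, and the real `Meter` additionally applies channel gains and guards against empty gates and infinities:

```python
import torch

def gated_loudness(x: torch.Tensor, rate: int, block: float = 0.400) -> torch.Tensor:
    # x: already K-weighted mono audio, shape (nb, nt)
    kernel = int(block * rate)
    step = int(kernel * 0.25)                      # 75% block overlap
    blocks = x.unfold(-1, kernel, step)            # (nb, n_blocks, kernel)
    z = blocks.square().mean(-1)                   # mean-square energy per block
    l = -0.691 + 10 * torch.log10(z)               # block loudness in LKFS
    abs_gate = l > -70.0                           # absolute threshold, -70 LKFS
    z_abs = (z * abs_gate).sum(-1) / abs_gate.sum(-1)
    gamma_r = -0.691 + 10 * torch.log10(z_abs) - 10.0  # relative threshold (eq. 6)
    gate = abs_gate & (l > gamma_r[:, None])           # both gates (eq. 7)
    z_gated = (z * gate).sum(-1) / gate.sum(-1)
    return -0.691 + 10 * torch.log10(z_gated)      # integrated loudness, LUFS

print(gated_loudness(0.1 * torch.randn(1, 48000), rate=48000))
```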
-""" # fmt: skip -import base64 -import io -import random -import string -import subprocess -from tempfile import NamedTemporaryFile - -import importlib_resources as pkg_resources - -from . import templates -from .util import _close_temp_files -from .util import format_figure - -headers = pkg_resources.files(templates).joinpath("headers.html").read_text() -widget = pkg_resources.files(templates).joinpath("widget.html").read_text() - -DEFAULT_EXTENSION = ".wav" - - -def _check_imports(): # pragma: no cover - try: - import ffmpy - except: - ffmpy = False - - try: - import IPython - except: - raise ImportError("IPython must be installed in order to use this function!") - return ffmpy, IPython - - -class PlayMixin: - def embed(self, ext: str = None, display: bool = True, return_html: bool = False): - """Embeds audio as a playable audio embed in a notebook, or HTML - document, etc. - - Parameters - ---------- - ext : str, optional - Extension to use when saving the audio, by default ".wav" - display : bool, optional - This controls whether or not to display the audio when called. This - is used when the embed is the last line in a Jupyter cell, to prevent - the audio from being embedded twice, by default True - return_html : bool, optional - Whether to return the data wrapped in an HTML audio element, by default False - - Returns - ------- - str - Either the element for display, or the HTML string of it. - """ - if ext is None: - ext = DEFAULT_EXTENSION - ext = f".{ext}" if not ext.startswith(".") else ext - ffmpy, IPython = _check_imports() - sr = self.sample_rate - tmpfiles = [] - - with _close_temp_files(tmpfiles): - tmp_wav = NamedTemporaryFile(mode="w+", suffix=".wav", delete=False) - tmpfiles.append(tmp_wav) - self.write(tmp_wav.name) - if ext != ".wav" and ffmpy: - tmp_converted = NamedTemporaryFile(mode="w+", suffix=ext, delete=False) - tmpfiles.append(tmp_wav) - ff = ffmpy.FFmpeg( - inputs={tmp_wav.name: None}, - outputs={ - tmp_converted.name: "-write_xing 0 -codec:a libmp3lame -b:a 128k -y -hide_banner -loglevel error" - }, - ) - ff.run() - else: - tmp_converted = tmp_wav - - audio_element = IPython.display.Audio(data=tmp_converted.name, rate=sr) - if display: - IPython.display.display(audio_element) - - if return_html: - audio_element = ( - f" " - ) - return audio_element - - def widget( - self, - title: str = None, - ext: str = ".wav", - add_headers: bool = True, - player_width: str = "100%", - margin: str = "10px", - plot_fn: str = "specshow", - return_html: bool = False, - **kwargs, - ): - """Creates a playable widget with spectrogram. Inspired (heavily) by - https://sjvasquez.github.io/blog/melnet/. - - Parameters - ---------- - title : str, optional - Title of plot, placed in upper right of top-most axis. - ext : str, optional - Extension for embedding, by default ".mp3" - add_headers : bool, optional - Whether or not to add headers (use for first embed, False for later embeds), by default True - player_width : str, optional - Width of the player, as a string in a CSS rule, by default "100%" - margin : str, optional - Margin on all sides of player, by default "10px" - plot_fn : function, optional - Plotting function to use (by default self.specshow). - return_html : bool, optional - Whether to return the data wrapped in an HTML audio element, by default False - kwargs : dict, optional - Keyword arguments to plot_fn (by default self.specshow). - - Returns - ------- - HTML - HTML object. 
- """ - import matplotlib.pyplot as plt - - def _save_fig_to_tag(): - buffer = io.BytesIO() - - plt.savefig(buffer, bbox_inches="tight", pad_inches=0) - plt.close() - - buffer.seek(0) - data_uri = base64.b64encode(buffer.read()).decode("ascii") - tag = "data:image/png;base64,{0}".format(data_uri) - - return tag - - _, IPython = _check_imports() - - header_html = "" - - if add_headers: - header_html = headers.replace("PLAYER_WIDTH", str(player_width)) - header_html = header_html.replace("MARGIN", str(margin)) - IPython.display.display(IPython.display.HTML(header_html)) - - widget_html = widget - if isinstance(plot_fn, str): - plot_fn = getattr(self, plot_fn) - kwargs["title"] = title - plot_fn(**kwargs) - - fig = plt.gcf() - pixels = fig.get_size_inches() * fig.dpi - - tag = _save_fig_to_tag() - - # Make the source image for the levels - self.specshow() - format_figure((12, 1.5)) - levels_tag = _save_fig_to_tag() - - player_id = "".join(random.choice(string.ascii_uppercase) for _ in range(10)) - - audio_elem = self.embed(ext=ext, display=False) - widget_html = widget_html.replace("AUDIO_SRC", audio_elem.src_attr()) - widget_html = widget_html.replace("IMAGE_SRC", tag) - widget_html = widget_html.replace("LEVELS_SRC", levels_tag) - widget_html = widget_html.replace("PLAYER_ID", player_id) - - # Calculate width/height of figure based on figure size. - widget_html = widget_html.replace("PADDING_AMOUNT", f"{int(pixels[1])}px") - widget_html = widget_html.replace("MAX_WIDTH", f"{int(pixels[0])}px") - - IPython.display.display(IPython.display.HTML(widget_html)) - - if return_html: - html = header_html if add_headers else "" - html += widget_html - return html - - def play(self): - """ - Plays an audio signal if ffplay from the ffmpeg suite of tools is installed. - Otherwise, will fail. The audio signal is written to a temporary file - and then played with ffplay. 
- """ - tmpfiles = [] - with _close_temp_files(tmpfiles): - tmp_wav = NamedTemporaryFile(suffix=".wav", delete=False) - tmpfiles.append(tmp_wav) - self.write(tmp_wav.name) - print(self) - subprocess.call( - [ - "ffplay", - "-nodisp", - "-autoexit", - "-hide_banner", - "-loglevel", - "error", - tmp_wav.name, - ] - ) - return self - - -if __name__ == "__main__": # pragma: no cover - from audiotools import AudioSignal - - signal = AudioSignal( - "tests/audio/spk/f10_script4_produced.mp3", offset=5, duration=5 - ) - - wave_html = signal.widget( - "Waveform", - plot_fn="waveplot", - return_html=True, - ) - - spec_html = signal.widget("Spectrogram", return_html=True, add_headers=False) - - combined_html = signal.widget( - "Waveform + spectrogram", - plot_fn="wavespec", - return_html=True, - add_headers=False, - ) - - signal.low_pass(8000) - lowpass_html = signal.widget( - "Lowpassed audio", - plot_fn="wavespec", - return_html=True, - add_headers=False, - ) - - with open("/tmp/index.html", "w") as f: - f.write(wave_html) - f.write(spec_html) - f.write(combined_html) - f.write(lowpass_html) diff --git a/dito/models/ldm/dac/audiotools/core/templates/__init__.py b/dito/models/ldm/dac/audiotools/core/templates/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/dito/models/ldm/dac/audiotools/core/templates/headers.html b/dito/models/ldm/dac/audiotools/core/templates/headers.html deleted file mode 100644 index 9eaef4a94d575f7826608ad63dcc77fab13b7b19..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/templates/headers.html +++ /dev/null @@ -1,322 +0,0 @@ - - - - - - diff --git a/dito/models/ldm/dac/audiotools/core/templates/pandoc.css b/dito/models/ldm/dac/audiotools/core/templates/pandoc.css deleted file mode 100644 index 842be7be6d65580dab44c6a8013259644f38e6ee..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/templates/pandoc.css +++ /dev/null @@ -1,407 +0,0 @@ -/* -Copyright (c) 2017 Chris Patuzzo -https://twitter.com/chrispatuzzo - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
-*/ - -body { - font-family: Helvetica, arial, sans-serif; - font-size: 14px; - line-height: 1.6; - padding-top: 10px; - padding-bottom: 10px; - background-color: white; - padding: 30px; - color: #333; -} - -body > *:first-child { - margin-top: 0 !important; -} - -body > *:last-child { - margin-bottom: 0 !important; -} - -a { - color: #4183C4; - text-decoration: none; -} - -a.absent { - color: #cc0000; -} - -a.anchor { - display: block; - padding-left: 30px; - margin-left: -30px; - cursor: pointer; - position: absolute; - top: 0; - left: 0; - bottom: 0; -} - -h1, h2, h3, h4, h5, h6 { - margin: 20px 0 10px; - padding: 0; - font-weight: bold; - -webkit-font-smoothing: antialiased; - cursor: text; - position: relative; -} - -h2:first-child, h1:first-child, h1:first-child + h2, h3:first-child, h4:first-child, h5:first-child, h6:first-child { - margin-top: 0; - padding-top: 0; -} - -h1:hover a.anchor, h2:hover a.anchor, h3:hover a.anchor, h4:hover a.anchor, h5:hover a.anchor, h6:hover a.anchor { - text-decoration: none; -} - -h1 tt, h1 code { - font-size: inherit; -} - -h2 tt, h2 code { - font-size: inherit; -} - -h3 tt, h3 code { - font-size: inherit; -} - -h4 tt, h4 code { - font-size: inherit; -} - -h5 tt, h5 code { - font-size: inherit; -} - -h6 tt, h6 code { - font-size: inherit; -} - -h1 { - font-size: 28px; - color: black; -} - -h2 { - font-size: 24px; - border-bottom: 1px solid #cccccc; - color: black; -} - -h3 { - font-size: 18px; -} - -h4 { - font-size: 16px; -} - -h5 { - font-size: 14px; -} - -h6 { - color: #777777; - font-size: 14px; -} - -p, blockquote, ul, ol, dl, li, table, pre { - margin: 15px 0; -} - -hr { - border: 0 none; - color: #cccccc; - height: 4px; - padding: 0; -} - -body > h2:first-child { - margin-top: 0; - padding-top: 0; -} - -body > h1:first-child { - margin-top: 0; - padding-top: 0; -} - -body > h1:first-child + h2 { - margin-top: 0; - padding-top: 0; -} - -body > h3:first-child, body > h4:first-child, body > h5:first-child, body > h6:first-child { - margin-top: 0; - padding-top: 0; -} - -a:first-child h1, a:first-child h2, a:first-child h3, a:first-child h4, a:first-child h5, a:first-child h6 { - margin-top: 0; - padding-top: 0; -} - -h1 p, h2 p, h3 p, h4 p, h5 p, h6 p { - margin-top: 0; -} - -li p.first { - display: inline-block; -} - -ul, ol { - padding-left: 30px; -} - -ul :first-child, ol :first-child { - margin-top: 0; -} - -ul :last-child, ol :last-child { - margin-bottom: 0; -} - -dl { - padding: 0; -} - -dl dt { - font-size: 14px; - font-weight: bold; - font-style: italic; - padding: 0; - margin: 15px 0 5px; -} - -dl dt:first-child { - padding: 0; -} - -dl dt > :first-child { - margin-top: 0; -} - -dl dt > :last-child { - margin-bottom: 0; -} - -dl dd { - margin: 0 0 15px; - padding: 0 15px; -} - -dl dd > :first-child { - margin-top: 0; -} - -dl dd > :last-child { - margin-bottom: 0; -} - -blockquote { - border-left: 4px solid #dddddd; - padding: 0 15px; - color: #777777; -} - -blockquote > :first-child { - margin-top: 0; -} - -blockquote > :last-child { - margin-bottom: 0; -} - -table { - padding: 0; -} -table tr { - border-top: 1px solid #cccccc; - background-color: white; - margin: 0; - padding: 0; -} - -table tr:nth-child(2n) { - background-color: #f8f8f8; -} - -table tr th { - font-weight: bold; - border: 1px solid #cccccc; - text-align: left; - margin: 0; - padding: 6px 13px; -} - -table tr td { - border: 1px solid #cccccc; - text-align: left; - margin: 0; - padding: 6px 13px; -} - -table tr th :first-child, table tr td :first-child { - margin-top: 
0; -} - -table tr th :last-child, table tr td :last-child { - margin-bottom: 0; -} - -img { - max-width: 100%; -} - -span.frame { - display: block; - overflow: hidden; -} - -span.frame > span { - border: 1px solid #dddddd; - display: block; - float: left; - overflow: hidden; - margin: 13px 0 0; - padding: 7px; - width: auto; -} - -span.frame span img { - display: block; - float: left; -} - -span.frame span span { - clear: both; - color: #333333; - display: block; - padding: 5px 0 0; -} - -span.align-center { - display: block; - overflow: hidden; - clear: both; -} - -span.align-center > span { - display: block; - overflow: hidden; - margin: 13px auto 0; - text-align: center; -} - -span.align-center span img { - margin: 0 auto; - text-align: center; -} - -span.align-right { - display: block; - overflow: hidden; - clear: both; -} - -span.align-right > span { - display: block; - overflow: hidden; - margin: 13px 0 0; - text-align: right; -} - -span.align-right span img { - margin: 0; - text-align: right; -} - -span.float-left { - display: block; - margin-right: 13px; - overflow: hidden; - float: left; -} - -span.float-left span { - margin: 13px 0 0; -} - -span.float-right { - display: block; - margin-left: 13px; - overflow: hidden; - float: right; -} - -span.float-right > span { - display: block; - overflow: hidden; - margin: 13px auto 0; - text-align: right; -} - -code, tt { - margin: 0 2px; - padding: 0 5px; - white-space: nowrap; - border-radius: 3px; -} - -pre code { - margin: 0; - padding: 0; - white-space: pre; - border: none; - background: transparent; -} - -.highlight pre { - font-size: 13px; - line-height: 19px; - overflow: auto; - padding: 6px 10px; - border-radius: 3px; -} - -pre { - font-size: 13px; - line-height: 19px; - overflow: auto; - padding: 6px 10px; - border-radius: 3px; -} - -pre code, pre tt { - background-color: transparent; - border: none; -} - -body { - max-width: 600px; -} diff --git a/dito/models/ldm/dac/audiotools/core/templates/widget.html b/dito/models/ldm/dac/audiotools/core/templates/widget.html deleted file mode 100644 index 0b44e8aec64fd1db929da5fa6208dee00247c967..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/templates/widget.html +++ /dev/null @@ -1,52 +0,0 @@ -
- - diff --git a/dito/models/ldm/dac/audiotools/core/util.py b/dito/models/ldm/dac/audiotools/core/util.py deleted file mode 100644 index ece1344658d10836aa2eb693f275294ad8cdbb52..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/util.py +++ /dev/null @@ -1,671 +0,0 @@ -import csv -import glob -import math -import numbers -import os -import random -import typing -from contextlib import contextmanager -from dataclasses import dataclass -from pathlib import Path -from typing import Dict -from typing import List - -import numpy as np -import torch -import torchaudio -from flatten_dict import flatten -from flatten_dict import unflatten - - -@dataclass -class Info: - """Shim for torchaudio.info API changes.""" - - sample_rate: float - num_frames: int - - @property - def duration(self) -> float: - return self.num_frames / self.sample_rate - - -def info(audio_path: str): - """Shim for torchaudio.info to make 0.7.2 API match 0.8.0. - - Parameters - ---------- - audio_path : str - Path to audio file. - """ - # try default backend first, then fallback to soundfile - try: - info = torchaudio.info(str(audio_path)) - except: # pragma: no cover - info = torchaudio.backend.soundfile_backend.info(str(audio_path)) - - if isinstance(info, tuple): # pragma: no cover - signal_info = info[0] - info = Info(sample_rate=signal_info.rate, num_frames=signal_info.length) - else: - info = Info(sample_rate=info.sample_rate, num_frames=info.num_frames) - - return info - - -def ensure_tensor( - x: typing.Union[np.ndarray, torch.Tensor, float, int], - ndim: int = None, - batch_size: int = None, -): - """Ensures that the input ``x`` is a tensor of specified - dimensions and batch size. - - Parameters - ---------- - x : typing.Union[np.ndarray, torch.Tensor, float, int] - Data that will become a tensor on its way out. - ndim : int, optional - How many dimensions should be in the output, by default None - batch_size : int, optional - The batch size of the output, by default None - - Returns - ------- - torch.Tensor - Modified version of ``x`` as a tensor. - """ - if not torch.is_tensor(x): - x = torch.as_tensor(x) - if ndim is not None: - assert x.ndim <= ndim - while x.ndim < ndim: - x = x.unsqueeze(-1) - if batch_size is not None: - if x.shape[0] != batch_size: - shape = list(x.shape) - shape[0] = batch_size - x = x.expand(*shape) - return x - - -def _get_value(other): - from . import AudioSignal - - if isinstance(other, AudioSignal): - return other.audio_data - return other - - -def hz_to_bin(hz: torch.Tensor, n_fft: int, sample_rate: int): - """Closest frequency bin given a frequency, number - of bins, and a sampling rate. - - Parameters - ---------- - hz : torch.Tensor - Tensor of frequencies in Hz. - n_fft : int - Number of FFT bins. - sample_rate : int - Sample rate of audio. - - Returns - ------- - torch.Tensor - Closest bins to the data. - """ - shape = hz.shape - hz = hz.flatten() - freqs = torch.linspace(0, sample_rate / 2, 2 + n_fft // 2) - hz[hz > sample_rate / 2] = sample_rate / 2 - - closest = (hz[None, :] - freqs[:, None]).abs() - closest_bins = closest.min(dim=0).indices - - return closest_bins.reshape(*shape) - - -def random_state(seed: typing.Union[int, np.random.RandomState]): - """ - Turn seed into a np.random.RandomState instance. - - Parameters - ---------- - seed : typing.Union[int, np.random.RandomState] or None - If seed is None, return the RandomState singleton used by np.random. - If seed is an int, return a new RandomState instance seeded with seed. 
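Two of the helpers above are easiest to grasp from their shapes. A small sketch, assuming the functions in this module are in scope (values are illustrative):

```python
import torch

# ensure_tensor: a scalar grows to ndim via unsqueeze, then expands to the batch size
x = ensure_tensor(0.5, ndim=3, batch_size=4)
print(x.shape)  # torch.Size([4, 1, 1])

# hz_to_bin: with a 1024-point FFT at 44.1 kHz, bins are ~43 Hz wide, so 440 Hz -> bin 10
print(hz_to_bin(torch.tensor([440.0]), n_fft=1024, sample_rate=44100))  # tensor([10])
```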
- If seed is already a RandomState instance, return it. - Otherwise raise ValueError. - - Returns - ------- - np.random.RandomState - Random state object. - - Raises - ------ - ValueError - If seed is not valid, an error is thrown. - """ - if seed is None or seed is np.random: - return np.random.mtrand._rand - elif isinstance(seed, (numbers.Integral, np.integer, int)): - return np.random.RandomState(seed) - elif isinstance(seed, np.random.RandomState): - return seed - else: - raise ValueError( - "%r cannot be used to seed a numpy.random.RandomState" " instance" % seed - ) - - -def seed(random_seed, set_cudnn=False): - """ - Seeds all random states with the same random seed - for reproducibility. Seeds ``numpy``, ``random`` and ``torch`` - random generators. - For full reproducibility, two further options must be set - according to the torch documentation: - https://pytorch.org/docs/stable/notes/randomness.html - To do this, ``set_cudnn`` must be True. It defaults to - False, since setting it to True results in a performance - hit. - - Args: - random_seed (int): integer corresponding to random seed to - use. - set_cudnn (bool): Whether or not to set cudnn into determinstic - mode and off of benchmark mode. Defaults to False. - """ - - torch.manual_seed(random_seed) - np.random.seed(random_seed) - random.seed(random_seed) - - if set_cudnn: - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - -@contextmanager -def _close_temp_files(tmpfiles: list): - """Utility function for creating a context and closing all temporary files - once the context is exited. For correct functionality, all temporary file - handles created inside the context must be appended to the ```tmpfiles``` - list. - - This function is taken wholesale from Scaper. - - Parameters - ---------- - tmpfiles : list - List of temporary file handles - """ - - def _close(): - for t in tmpfiles: - try: - t.close() - os.unlink(t.name) - except: - pass - - try: - yield - except: # pragma: no cover - _close() - raise - _close() - - -AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"] - - -def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS): - """Finds all audio files in a directory recursively. - Returns a list. - - Parameters - ---------- - folder : str - Folder to look for audio files in, recursively. - ext : List[str], optional - Extensions to look for without the ., by default - ``['.wav', '.flac', '.mp3', '.mp4']``. - """ - folder = Path(folder) - # Take care of case where user has passed in an audio file directly - # into one of the calling functions. - if str(folder).endswith(tuple(ext)): - # if, however, there's a glob in the path, we need to - # return the glob, not the file. - if "*" in str(folder): - return glob.glob(str(folder), recursive=("**" in str(folder))) - else: - return [folder] - - files = [] - for x in ext: - files += folder.glob(f"**/*{x}") - return files - - -def read_sources( - sources: List[str], - remove_empty: bool = True, - relative_path: str = "", - ext: List[str] = AUDIO_EXTENSIONS, -): - """Reads audio sources that can either be folders - full of audio files, or CSV files that contain paths - to audio files. CSV files that adhere to the expected - format can be generated by - :py:func:`audiotools.data.preprocess.create_csv`. - - Parameters - ---------- - sources : List[str] - List of audio sources to be converted into a - list of lists of audio files. 
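The two seeding helpers compose as below (a minimal sketch, assuming both are in scope):

```python
state = random_state(42)               # int -> a fresh, seeded np.random.RandomState
assert random_state(state) is state    # an existing RandomState passes straight through

seed(1234, set_cudnn=True)             # pins torch/numpy/random; cudnn determinism
                                       # costs speed, hence the False default
```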
- remove_empty : bool, optional - Whether or not to remove rows with an empty "path" - from each CSV file, by default True. - - Returns - ------- - list - List of lists of rows of CSV files. - """ - files = [] - relative_path = Path(relative_path) - for source in sources: - source = str(source) - _files = [] - if source.endswith(".csv"): - with open(source, "r") as f: - reader = csv.DictReader(f) - for x in reader: - if remove_empty and x["path"] == "": - continue - if x["path"] != "": - x["path"] = str(relative_path / x["path"]) - _files.append(x) - else: - for x in find_audio(source, ext=ext): - x = str(relative_path / x) - _files.append({"path": x}) - files.append(sorted(_files, key=lambda x: x["path"])) - return files - - -def choose_from_list_of_lists( - state: np.random.RandomState, list_of_lists: list, p: float = None -): - """Choose a single item from a list of lists. - - Parameters - ---------- - state : np.random.RandomState - Random state to use when choosing an item. - list_of_lists : list - A list of lists from which items will be drawn. - p : float, optional - Probabilities of each list, by default None - - Returns - ------- - typing.Any - An item from the list of lists. - """ - source_idx = state.choice(list(range(len(list_of_lists))), p=p) - item_idx = state.randint(len(list_of_lists[source_idx])) - return list_of_lists[source_idx][item_idx], source_idx, item_idx - - -@contextmanager -def chdir(newdir: typing.Union[Path, str]): - """ - Context manager for switching directories to run a - function. Useful for when you want to use relative - paths to different runs. - - Parameters - ---------- - newdir : typing.Union[Path, str] - Directory to switch to. - """ - curdir = os.getcwd() - try: - os.chdir(newdir) - yield - finally: - os.chdir(curdir) - - -def prepare_batch(batch: typing.Union[dict, list, torch.Tensor], device: str = "cpu"): - """Moves items in a batch (typically generated by a DataLoader as a list - or a dict) to the specified device. This works even if dictionaries - are nested. - - Parameters - ---------- - batch : typing.Union[dict, list, torch.Tensor] - Batch, typically generated by a dataloader, that will be moved to - the device. - device : str, optional - Device to move batch to, by default "cpu" - - Returns - ------- - typing.Union[dict, list, torch.Tensor] - Batch with all values moved to the specified device. - """ - if isinstance(batch, dict): - batch = flatten(batch) - for key, val in batch.items(): - try: - batch[key] = val.to(device) - except: - pass - batch = unflatten(batch) - elif torch.is_tensor(batch): - batch = batch.to(device) - elif isinstance(batch, list): - for i in range(len(batch)): - try: - batch[i] = batch[i].to(device) - except: - pass - return batch - - -def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None): - """Samples from a distribution defined by a tuple. The first - item in the tuple is the distribution type, and the rest of the - items are arguments to that distribution. The distribution function - is gotten from the ``np.random.RandomState`` object. - - Parameters - ---------- - dist_tuple : tuple - Distribution tuple - state : np.random.RandomState, optional - Random state, or seed to use, by default None - - Returns - ------- - typing.Union[float, int, str] - Draw from the distribution. 
- - Examples - -------- - Sample from a uniform distribution: - - >>> dist_tuple = ("uniform", 0, 1) - >>> sample_from_dist(dist_tuple) - - Sample from a constant distribution: - - >>> dist_tuple = ("const", 0) - >>> sample_from_dist(dist_tuple) - - Sample from a normal distribution: - - >>> dist_tuple = ("normal", 0, 0.5) - >>> sample_from_dist(dist_tuple) - - """ - if dist_tuple[0] == "const": - return dist_tuple[1] - state = random_state(state) - dist_fn = getattr(state, dist_tuple[0]) - return dist_fn(*dist_tuple[1:]) - - -def collate(list_of_dicts: list, n_splits: int = None): - """Collates a list of dictionaries (e.g. as returned by a - dataloader) into a dictionary with batched values. This routine - uses the default torch collate function for everything - except AudioSignal objects, which are handled by the - :py:func:`audiotools.core.audio_signal.AudioSignal.batch` - function. - - This function takes n_splits to enable splitting a batch - into multiple sub-batches for the purposes of gradient accumulation, - etc. - - Parameters - ---------- - list_of_dicts : list - List of dictionaries to be collated. - n_splits : int - Number of splits to make when creating the batches (split into - sub-batches). Useful for things like gradient accumulation. - - Returns - ------- - dict - Dictionary containing batched data. - """ - - from . import AudioSignal - - batches = [] - list_len = len(list_of_dicts) - - return_list = False if n_splits is None else True - n_splits = 1 if n_splits is None else n_splits - n_items = int(math.ceil(list_len / n_splits)) - - for i in range(0, list_len, n_items): - # Flatten the dictionaries to avoid recursion. - list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]] - dict_of_lists = { - k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0] - } - - batch = {} - for k, v in dict_of_lists.items(): - if isinstance(v, list): - if all(isinstance(s, AudioSignal) for s in v): - batch[k] = AudioSignal.batch(v, pad_signals=True) - else: - # Borrow the default collate fn from torch. - batch[k] = torch.utils.data._utils.collate.default_collate(v) - batches.append(unflatten(batch)) - - batches = batches[0] if not return_list else batches - return batches - - -BASE_SIZE = 864 -DEFAULT_FIG_SIZE = (9, 3) - - -def format_figure( - fig_size: tuple = None, - title: str = None, - fig=None, - format_axes: bool = True, - format: bool = True, - font_color: str = "white", -): - """Prettifies the spectrogram and waveform plots. A title - can be inset into the top right corner, and the axes can be - inset into the figure, allowing the data to take up the entire - image. 
Used in - - - :py:func:`audiotools.core.display.DisplayMixin.specshow` - - :py:func:`audiotools.core.display.DisplayMixin.waveplot` - - :py:func:`audiotools.core.display.DisplayMixin.wavespec` - - Parameters - ---------- - fig_size : tuple, optional - Size of figure, by default (9, 3) - title : str, optional - Title to inset in top right, by default None - fig : matplotlib.figure.Figure, optional - Figure object, if None ``plt.gcf()`` will be used, by default None - format_axes : bool, optional - Format the axes to be inside the figure, by default True - format : bool, optional - This formatting can be skipped entirely by passing ``format=False`` - to any of the plotting functions that use this formater, by default True - font_color : str, optional - Color of font of axes, by default "white" - """ - import matplotlib - import matplotlib.pyplot as plt - - if fig_size is None: - fig_size = DEFAULT_FIG_SIZE - if not format: - return - if fig is None: - fig = plt.gcf() - fig.set_size_inches(*fig_size) - axs = fig.axes - - pixels = (fig.get_size_inches() * fig.dpi)[0] - font_scale = pixels / BASE_SIZE - - if format_axes: - axs = fig.axes - - for ax in axs: - ymin, _ = ax.get_ylim() - xmin, _ = ax.get_xlim() - - ticks = ax.get_yticks() - for t in ticks[2:-1]: - t = axs[0].annotate( - f"{(t / 1000):2.1f}k", - xy=(xmin, t), - xycoords="data", - xytext=(5, -5), - textcoords="offset points", - ha="left", - va="top", - color=font_color, - fontsize=12 * font_scale, - alpha=0.75, - ) - - ticks = ax.get_xticks()[2:] - for t in ticks[:-1]: - t = axs[0].annotate( - f"{t:2.1f}s", - xy=(t, ymin), - xycoords="data", - xytext=(5, 5), - textcoords="offset points", - ha="center", - va="bottom", - color=font_color, - fontsize=12 * font_scale, - alpha=0.75, - ) - - ax.margins(0, 0) - ax.set_axis_off() - ax.xaxis.set_major_locator(plt.NullLocator()) - ax.yaxis.set_major_locator(plt.NullLocator()) - - plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) - - if title is not None: - t = axs[0].annotate( - title, - xy=(1, 1), - xycoords="axes fraction", - fontsize=20 * font_scale, - xytext=(-5, -5), - textcoords="offset points", - ha="right", - va="top", - color="white", - ) - t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black")) - - -def generate_chord_dataset( - max_voices: int = 8, - sample_rate: int = 44100, - num_items: int = 5, - duration: float = 1.0, - min_note: str = "C2", - max_note: str = "C6", - output_dir: Path = "chords", -): - """ - Generates a toy multitrack dataset of chords, synthesized from sine waves. - - - Parameters - ---------- - max_voices : int, optional - Maximum number of voices in a chord, by default 8 - sample_rate : int, optional - Sample rate of audio, by default 44100 - num_items : int, optional - Number of items to generate, by default 5 - duration : float, optional - Duration of each item, by default 1.0 - min_note : str, optional - Minimum note in the dataset, by default "C2" - max_note : str, optional - Maximum note in the dataset, by default "C6" - output_dir : Path, optional - Directory to save the dataset, by default "chords" - - """ - import librosa - from . 
import AudioSignal - from ..data.preprocess import create_csv - - min_midi = librosa.note_to_midi(min_note) - max_midi = librosa.note_to_midi(max_note) - - tracks = [] - for idx in range(num_items): - track = {} - # figure out how many voices to put in this track - num_voices = random.randint(1, max_voices) - for voice_idx in range(num_voices): - # choose some random params - midinote = random.randint(min_midi, max_midi) - dur = random.uniform(0.85 * duration, duration) - - sig = AudioSignal.wave( - frequency=librosa.midi_to_hz(midinote), - duration=dur, - sample_rate=sample_rate, - shape="sine", - ) - track[f"voice_{voice_idx}"] = sig - tracks.append(track) - - # save the tracks to disk - output_dir = Path(output_dir) - output_dir.mkdir(exist_ok=True) - for idx, track in enumerate(tracks): - track_dir = output_dir / f"track_{idx}" - track_dir.mkdir(exist_ok=True) - for voice_name, sig in track.items(): - sig.write(track_dir / f"{voice_name}.wav") - - all_voices = list(set([k for track in tracks for k in track.keys()])) - voice_lists = {voice: [] for voice in all_voices} - for track in tracks: - for voice_name in all_voices: - if voice_name in track: - voice_lists[voice_name].append(track[voice_name].path_to_file) - else: - voice_lists[voice_name].append("") - - for voice_name, paths in voice_lists.items(): - create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True) - - return output_dir diff --git a/dito/models/ldm/dac/audiotools/core/whisper.py b/dito/models/ldm/dac/audiotools/core/whisper.py deleted file mode 100644 index 46c071f934fc3e2be3138e7596b1c6d2ef79eade..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/core/whisper.py +++ /dev/null @@ -1,97 +0,0 @@ -import torch - - -class WhisperMixin: - is_initialized = False - - def setup_whisper( - self, - pretrained_model_name_or_path: str = "openai/whisper-base.en", - device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu"), - ): - from transformers import WhisperForConditionalGeneration - from transformers import WhisperProcessor - - self.whisper_device = device - self.whisper_processor = WhisperProcessor.from_pretrained( - pretrained_model_name_or_path - ) - self.whisper_model = WhisperForConditionalGeneration.from_pretrained( - pretrained_model_name_or_path - ).to(self.whisper_device) - self.is_initialized = True - - def get_whisper_features(self) -> torch.Tensor: - """Preprocess audio signal as per the whisper model's training config. - - Returns - ------- - torch.Tensor - The prepinput features of the audio signal. Shape: (1, channels, seq_len) - """ - import torch - - if not self.is_initialized: - self.setup_whisper() - - signal = self.to(self.device) - raw_speech = list( - ( - signal.clone() - .resample(self.whisper_processor.feature_extractor.sampling_rate) - .audio_data[:, 0, :] - .numpy() - ) - ) - - with torch.inference_mode(): - input_features = self.whisper_processor( - raw_speech, - sampling_rate=self.whisper_processor.feature_extractor.sampling_rate, - return_tensors="pt", - ).input_features - - return input_features - - def get_whisper_transcript(self) -> str: - """Get the transcript of the audio signal using the whisper model. - - Returns - ------- - str - The transcript of the audio signal, including special tokens such as <|startoftranscript|> and <|endoftext|>. 
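A hedged end-to-end sketch of the Whisper mixin (the checkpoint name is the default from `setup_whisper`; the audio path is hypothetical, and transformers must be installed):

```python
from audiotools import AudioSignal

signal = AudioSignal("speech.wav")                # hypothetical input file
signal.setup_whisper("openai/whisper-base.en")    # loads processor + model once
feats = signal.get_whisper_features()             # (1, channels, seq_len) input features
print(signal.get_whisper_transcript())            # decoded text, special tokens included
```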
- """ - - if not self.is_initialized: - self.setup_whisper() - - input_features = self.get_whisper_features() - - with torch.inference_mode(): - input_features = input_features.to(self.whisper_device) - generated_ids = self.whisper_model.generate(inputs=input_features) - - transcription = self.whisper_processor.batch_decode(generated_ids) - return transcription[0] - - def get_whisper_embeddings(self) -> torch.Tensor: - """Get the last hidden state embeddings of the audio signal using the whisper model. - - Returns - ------- - torch.Tensor - The Whisper embeddings of the audio signal. Shape: (1, seq_len, hidden_size) - """ - import torch - - if not self.is_initialized: - self.setup_whisper() - - input_features = self.get_whisper_features() - encoder = self.whisper_model.get_encoder() - - with torch.inference_mode(): - input_features = input_features.to(self.whisper_device) - embeddings = encoder(input_features) - - return embeddings.last_hidden_state diff --git a/dito/models/ldm/dac/audiotools/data/__init__.py b/dito/models/ldm/dac/audiotools/data/__init__.py deleted file mode 100644 index aead269f26f3782043e68418b4c87ee323cbd015..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/data/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from . import datasets -from . import preprocess -from . import transforms diff --git a/dito/models/ldm/dac/audiotools/data/datasets.py b/dito/models/ldm/dac/audiotools/data/datasets.py deleted file mode 100644 index 12e7a60963399aa15ff865de2d06537818ce18ee..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/data/datasets.py +++ /dev/null @@ -1,517 +0,0 @@ -from pathlib import Path -from typing import Callable -from typing import Dict -from typing import List -from typing import Union - -import numpy as np -from torch.utils.data import SequentialSampler -from torch.utils.data.distributed import DistributedSampler - -from ..core import AudioSignal -from ..core import util - - -class AudioLoader: - """Loads audio endlessly from a list of audio sources - containing paths to audio files. Audio sources can be - folders full of audio files (which are found via file - extension) or by providing a CSV file which contains paths - to audio files. - - Parameters - ---------- - sources : List[str], optional - Sources containing folders, or CSVs with - paths to audio files, by default None - weights : List[float], optional - Weights to sample audio files from each source, by default None - relative_path : str, optional - Path audio should be loaded relative to, by default "" - transform : Callable, optional - Transform to instantiate alongside audio sample, - by default None - ext : List[str] - List of extensions to find audio within each source by. Can - also be a file name (e.g. "vocals.wav"). by default - ``['.wav', '.flac', '.mp3', '.mp4']``. - shuffle: bool - Whether to shuffle the files within the dataloader. Defaults to True. - shuffle_state: int - State to use to seed the shuffle of the files. 
- """ - - def __init__( - self, - sources: List[str] = None, - weights: List[float] = None, - transform: Callable = None, - relative_path: str = "", - ext: List[str] = util.AUDIO_EXTENSIONS, - shuffle: bool = True, - shuffle_state: int = 0, - ): - self.audio_lists = util.read_sources( - sources, relative_path=relative_path, ext=ext - ) - - self.audio_indices = [ - (src_idx, item_idx) - for src_idx, src in enumerate(self.audio_lists) - for item_idx in range(len(src)) - ] - if shuffle: - state = util.random_state(shuffle_state) - state.shuffle(self.audio_indices) - - self.sources = sources - self.weights = weights - self.transform = transform - - def __call__( - self, - state, - sample_rate: int, - duration: float, - loudness_cutoff: float = -40, - num_channels: int = 1, - offset: float = None, - source_idx: int = None, - item_idx: int = None, - global_idx: int = None, - ): - if source_idx is not None and item_idx is not None: - try: - audio_info = self.audio_lists[source_idx][item_idx] - except: - audio_info = {"path": "none"} - elif global_idx is not None: - source_idx, item_idx = self.audio_indices[ - global_idx % len(self.audio_indices) - ] - audio_info = self.audio_lists[source_idx][item_idx] - else: - audio_info, source_idx, item_idx = util.choose_from_list_of_lists( - state, self.audio_lists, p=self.weights - ) - - path = audio_info["path"] - signal = AudioSignal.zeros(duration, sample_rate, num_channels) - - if path != "none": - if offset is None: - signal = AudioSignal.salient_excerpt( - path, - duration=duration, - state=state, - loudness_cutoff=loudness_cutoff, - ) - else: - signal = AudioSignal( - path, - offset=offset, - duration=duration, - ) - - if num_channels == 1: - signal = signal.to_mono() - signal = signal.resample(sample_rate) - - if signal.duration < duration: - signal = signal.zero_pad_to(int(duration * sample_rate)) - - for k, v in audio_info.items(): - signal.metadata[k] = v - - item = { - "signal": signal, - "source_idx": source_idx, - "item_idx": item_idx, - "source": str(self.sources[source_idx]), - "path": str(path), - } - if self.transform is not None: - item["transform_args"] = self.transform.instantiate(state, signal=signal) - return item - - -def default_matcher(x, y): - return Path(x).parent == Path(y).parent - - -def align_lists(lists, matcher: Callable = default_matcher): - longest_list = lists[np.argmax([len(l) for l in lists])] - for i, x in enumerate(longest_list): - for l in lists: - if i >= len(l): - l.append({"path": "none"}) - elif not matcher(l[i]["path"], x["path"]): - l.insert(i, {"path": "none"}) - return lists - - -class AudioDataset: - """Loads audio from multiple loaders (with associated transforms) - for a specified number of samples. Excerpts are drawn randomly - of the specified duration, above a specified loudness threshold - and are resampled on the fly to the desired sample rate - (if it is different from the audio source sample rate). - - This takes either a single AudioLoader object, - a dictionary of AudioLoader objects, or a dictionary of AudioLoader - objects. Each AudioLoader is called by the dataset, and the - result is placed in the output dictionary. A transform can also be - specified for the entire dataset, rather than for each specific - loader. This transform can be applied to the output of all the - loaders if desired. - - AudioLoader objects can be specified as aligned, which means the - loaders correspond to multitrack audio (e.g. a vocals, bass, - drums, and other loader for multitrack music mixtures). 
- - - Parameters - ---------- - loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]] - AudioLoaders to sample audio from. - sample_rate : int - Desired sample rate. - n_examples : int, optional - Number of examples (length of dataset), by default 1000 - duration : float, optional - Duration of audio samples, by default 0.5 - loudness_cutoff : float, optional - Loudness cutoff threshold for audio samples, by default -40 - num_channels : int, optional - Number of channels in output audio, by default 1 - transform : Callable, optional - Transform to instantiate alongside each dataset item, by default None - aligned : bool, optional - Whether the loaders should be sampled in an aligned manner (e.g. same - offset, duration, and matched file name), by default False - shuffle_loaders : bool, optional - Whether to shuffle the loaders before sampling from them, by default False - matcher : Callable - How to match files from adjacent audio lists (e.g. for a multitrack audio loader), - by default uses the parent directory of each file. - without_replacement : bool - Whether to choose files with or without replacement, by default True. - - - Examples - -------- - >>> from audiotools.data.datasets import AudioLoader - >>> from audiotools.data.datasets import AudioDataset - >>> from audiotools import transforms as tfm - >>> import numpy as np - >>> - >>> loaders = [ - >>> AudioLoader( - >>> sources=[f"tests/audio/spk"], - >>> transform=tfm.Equalizer(), - >>> ext=["wav"], - >>> ) - >>> for i in range(5) - >>> ] - >>> - >>> dataset = AudioDataset( - >>> loaders = loaders, - >>> sample_rate = 44100, - >>> duration = 1.0, - >>> transform = tfm.RescaleAudio(), - >>> ) - >>> - >>> item = dataset[np.random.randint(len(dataset))] - >>> - >>> for i in range(len(loaders)): - >>> item[i]["signal"] = loaders[i].transform( - >>> item[i]["signal"], **item[i]["transform_args"] - >>> ) - >>> item[i]["signal"].widget(i) - >>> - >>> mix = sum([item[i]["signal"] for i in range(len(loaders))]) - >>> mix = dataset.transform(mix, **item["transform_args"]) - >>> mix.widget("mix") - - Below is an example of how one could load MUSDB multitrack data: - - >>> import audiotools as at - >>> from pathlib import Path - >>> from audiotools import transforms as tfm - >>> import numpy as np - >>> import torch - >>> - >>> def build_dataset( - >>> sample_rate: int = 44100, - >>> duration: float = 5.0, - >>> musdb_path: str = "~/.data/musdb/", - >>> ): - >>> musdb_path = Path(musdb_path).expanduser() - >>> loaders = { - >>> src: at.datasets.AudioLoader( - >>> sources=[musdb_path], - >>> transform=tfm.Compose( - >>> tfm.VolumeNorm(("uniform", -20, -10)), - >>> tfm.Silence(prob=0.1), - >>> ), - >>> ext=[f"{src}.wav"], - >>> ) - >>> for src in ["vocals", "bass", "drums", "other"] - >>> } - >>> - >>> dataset = at.datasets.AudioDataset( - >>> loaders=loaders, - >>> sample_rate=sample_rate, - >>> duration=duration, - >>> num_channels=1, - >>> aligned=True, - >>> transform=tfm.RescaleAudio(), - >>> shuffle_loaders=True, - >>> ) - >>> return dataset, list(loaders.keys()) - >>> - >>> train_data, sources = build_dataset() - >>> dataloader = torch.utils.data.DataLoader( - >>> train_data, - >>> batch_size=16, - >>> num_workers=0, - >>> collate_fn=train_data.collate, - >>> ) - >>> batch = next(iter(dataloader)) - >>> - >>> for k in sources: - >>> src = batch[k] - >>> src["transformed"] = train_data.loaders[k].transform( - >>> src["signal"].clone(), **src["transform_args"] - >>> ) - >>> - >>> mixture = sum(batch[k]["transformed"] for 
k in sources) - >>> mixture = train_data.transform(mixture, **batch["transform_args"]) - >>> - >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time). - >>> # Construct the targets: - >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1) - - Similarly, here's example code for loading Slakh data: - - >>> import audiotools as at - >>> from pathlib import Path - >>> from audiotools import transforms as tfm - >>> import numpy as np - >>> import torch - >>> import glob - >>> - >>> def build_dataset( - >>> sample_rate: int = 16000, - >>> duration: float = 10.0, - >>> slakh_path: str = "~/.data/slakh/", - >>> ): - >>> slakh_path = Path(slakh_path).expanduser() - >>> - >>> # Find the max number of sources in Slakh - >>> src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)] - >>> n_sources = len(list(set(src_names))) - >>> - >>> loaders = { - >>> f"S{i:02d}": at.datasets.AudioLoader( - >>> sources=[slakh_path], - >>> transform=tfm.Compose( - >>> tfm.VolumeNorm(("uniform", -20, -10)), - >>> tfm.Silence(prob=0.1), - >>> ), - >>> ext=[f"S{i:02d}.wav"], - >>> ) - >>> for i in range(n_sources) - >>> } - >>> dataset = at.datasets.AudioDataset( - >>> loaders=loaders, - >>> sample_rate=sample_rate, - >>> duration=duration, - >>> num_channels=1, - >>> aligned=True, - >>> transform=tfm.RescaleAudio(), - >>> shuffle_loaders=False, - >>> ) - >>> - >>> return dataset, list(loaders.keys()) - >>> - >>> train_data, sources = build_dataset() - >>> dataloader = torch.utils.data.DataLoader( - >>> train_data, - >>> batch_size=16, - >>> num_workers=0, - >>> collate_fn=train_data.collate, - >>> ) - >>> batch = next(iter(dataloader)) - >>> - >>> for k in sources: - >>> src = batch[k] - >>> src["transformed"] = train_data.loaders[k].transform( - >>> src["signal"].clone(), **src["transform_args"] - >>> ) - >>> - >>> mixture = sum(batch[k]["transformed"] for k in sources) - >>> mixture = train_data.transform(mixture, **batch["transform_args"]) - - """ - - def __init__( - self, - loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]], - sample_rate: int, - n_examples: int = 1000, - duration: float = 0.5, - offset: float = None, - loudness_cutoff: float = -40, - num_channels: int = 1, - transform: Callable = None, - aligned: bool = False, - shuffle_loaders: bool = False, - matcher: Callable = default_matcher, - without_replacement: bool = True, - ): - # Internally we convert loaders to a dictionary - if isinstance(loaders, list): - loaders = {i: l for i, l in enumerate(loaders)} - elif isinstance(loaders, AudioLoader): - loaders = {0: loaders} - - self.loaders = loaders - self.loudness_cutoff = loudness_cutoff - self.num_channels = num_channels - - self.length = n_examples - self.transform = transform - self.sample_rate = sample_rate - self.duration = duration - self.offset = offset - self.aligned = aligned - self.shuffle_loaders = shuffle_loaders - self.without_replacement = without_replacement - - if aligned: - loaders_list = list(loaders.values()) - for i in range(len(loaders_list[0].audio_lists)): - input_lists = [l.audio_lists[i] for l in loaders_list] - # Alignment happens in-place - align_lists(input_lists, matcher) - - def __getitem__(self, idx): - state = util.random_state(idx) - offset = None if self.offset is None else self.offset - item = {} - - keys = list(self.loaders.keys()) - if self.shuffle_loaders: - state.shuffle(keys) - - loader_kwargs = { - "state": state, - "sample_rate": self.sample_rate, - "duration": 
self.duration, - "loudness_cutoff": self.loudness_cutoff, - "num_channels": self.num_channels, - "global_idx": idx if self.without_replacement else None, - } - - # Draw item from first loader - loader = self.loaders[keys[0]] - item[keys[0]] = loader(**loader_kwargs) - - for key in keys[1:]: - loader = self.loaders[key] - if self.aligned: - # Path mapper takes the current loader + everything - # returned by the first loader. - offset = item[keys[0]]["signal"].metadata["offset"] - loader_kwargs.update( - { - "offset": offset, - "source_idx": item[keys[0]]["source_idx"], - "item_idx": item[keys[0]]["item_idx"], - } - ) - item[key] = loader(**loader_kwargs) - - # Sort dictionary back into original order - keys = list(self.loaders.keys()) - item = {k: item[k] for k in keys} - - item["idx"] = idx - if self.transform is not None: - item["transform_args"] = self.transform.instantiate( - state=state, signal=item[keys[0]]["signal"] - ) - - # If there's only one loader, pop it up - # to the main dictionary, instead of keeping it - # nested. - if len(keys) == 1: - item.update(item.pop(keys[0])) - - return item - - def __len__(self): - return self.length - - @staticmethod - def collate(list_of_dicts: Union[list, dict], n_splits: int = None): - """Collates items drawn from this dataset. Uses - :py:func:`audiotools.core.util.collate`. - - Parameters - ---------- - list_of_dicts : typing.Union[list, dict] - Data drawn from each item. - n_splits : int - Number of splits to make when creating the batches (split into - sub-batches). Useful for things like gradient accumulation. - - Returns - ------- - dict - Dictionary of batched data. - """ - return util.collate(list_of_dicts, n_splits=n_splits) - - -class ConcatDataset(AudioDataset): - def __init__(self, datasets: list): - self.datasets = datasets - - def __len__(self): - return sum([len(d) for d in self.datasets]) - - def __getitem__(self, idx): - dataset = self.datasets[idx % len(self.datasets)] - return dataset[idx // len(self.datasets)] - - -class ResumableDistributedSampler(DistributedSampler): # pragma: no cover - """Distributed sampler that can be resumed from a given start index.""" - - def __init__(self, dataset, start_idx: int = None, **kwargs): - super().__init__(dataset, **kwargs) - # Start index, allows to resume an experiment at the index it was - self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0 - - def __iter__(self): - for i, idx in enumerate(super().__iter__()): - if i >= self.start_idx: - yield idx - self.start_idx = 0 # set the index back to 0 so for the next epoch - - -class ResumableSequentialSampler(SequentialSampler): # pragma: no cover - """Sequential sampler that can be resumed from a given start index.""" - - def __init__(self, dataset, start_idx: int = None, **kwargs): - super().__init__(dataset, **kwargs) - # Start index, allows to resume an experiment at the index it was - self.start_idx = start_idx if start_idx is not None else 0 - - def __iter__(self): - for i, idx in enumerate(super().__iter__()): - if i >= self.start_idx: - yield idx - self.start_idx = 0 # set the index back to 0 so for the next epoch diff --git a/dito/models/ldm/dac/audiotools/data/preprocess.py b/dito/models/ldm/dac/audiotools/data/preprocess.py deleted file mode 100644 index d90de210115e45838bc8d69b350f7516ba730406..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/data/preprocess.py +++ /dev/null @@ -1,81 +0,0 @@ -import csv -import os -from pathlib import Path - -from tqdm import tqdm - -from 
..core import AudioSignal - - -def create_csv( - audio_files: list, output_csv: Path, loudness: bool = False, data_path: str = None -): - """Converts a folder of audio files to a CSV file. If ``loudness = True``, - the output of this function will create a CSV file that looks something - like: - - .. csv-table:: - :header: path,loudness - - daps/produced/f1_script1_produced.wav,-16.299999237060547 - daps/produced/f1_script2_produced.wav,-16.600000381469727 - daps/produced/f1_script3_produced.wav,-17.299999237060547 - daps/produced/f1_script4_produced.wav,-16.100000381469727 - daps/produced/f1_script5_produced.wav,-16.700000762939453 - daps/produced/f3_script1_produced.wav,-16.5 - - .. note:: - The paths above are written relative to the ``data_path`` argument - which defaults to the environment variable ``PATH_TO_DATA`` if - it isn't passed to this function, and defaults to the empty string - if that environment variable is not set. - - You can produce a CSV file from a directory of audio files via: - - >>> import audiotools - >>> directory = ... - >>> audio_files = audiotools.util.find_audio(directory) - >>> output_path = "train.csv" - >>> audiotools.data.preprocess.create_csv( - >>> audio_files, output_csv, loudness=True - >>> ) - - Note that you can create empty rows in the CSV file by passing an empty - string or None in the ``audio_files`` list. This is useful if you want to - sync multiple CSV files in a multitrack setting. The loudness of these - empty rows will be set to -inf. - - Parameters - ---------- - audio_files : list - List of audio files. - output_csv : Path - Output CSV, with each row containing the relative path of every file - to ``data_path``, if specified (defaults to None). - loudness : bool - Compute loudness of entire file and store alongside path. - """ - - info = [] - pbar = tqdm(audio_files) - for af in pbar: - af = Path(af) - pbar.set_description(f"Processing {af.name}") - _info = {} - if af.name == "": - _info["path"] = "" - if loudness: - _info["loudness"] = -float("inf") - else: - _info["path"] = af.relative_to(data_path) if data_path is not None else af - if loudness: - _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item() - - info.append(_info) - - with open(output_csv, "w") as f: - writer = csv.DictWriter(f, fieldnames=list(info[0].keys())) - writer.writeheader() - - for item in info: - writer.writerow(item) diff --git a/dito/models/ldm/dac/audiotools/data/transforms.py b/dito/models/ldm/dac/audiotools/data/transforms.py deleted file mode 100644 index 504e87dc61777e36ba95eb794f497bed4cdc7d2c..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/data/transforms.py +++ /dev/null @@ -1,1592 +0,0 @@ -import copy -from contextlib import contextmanager -from inspect import signature -from typing import List - -import numpy as np -import torch -from flatten_dict import flatten -from flatten_dict import unflatten -from numpy.random import RandomState - -from .. import ml -from ..core import AudioSignal -from ..core import util -from .datasets import AudioLoader - -tt = torch.tensor -"""Shorthand for converting things to torch.tensor.""" - - -class BaseTransform: - """This is the base class for all transforms that are implemented - in this library. Transforms have two main operations: ``transform`` - and ``instantiate``. - - ``instantiate`` sets the parameters randomly - from distribution tuples for each parameter. 
For example, for the - ``BackgroundNoise`` transform, the signal-to-noise ratio (``snr``) - is chosen randomly by instantiate. By default, it is chosen uniformly - between 10.0 and 30.0 (the tuple is set to ``("uniform", 10.0, 30.0)``). - - ``transform`` applies the transform using the instantiated parameters. - A simple example is as follows: - - >>> seed = 0 - >>> signal = ... - >>> transform = transforms.NoiseFloor(db = ("uniform", -50.0, -30.0)) - >>> kwargs = transform.instantiate(seed, signal) - >>> output = transform(signal.clone(), **kwargs) - - By breaking apart the instantiation of parameters from the actual audio - processing of the transform, we can make things more reproducible, while - also applying the transform on batches of data efficiently on GPU, - rather than on individual audio samples. - - .. note:: - We call ``signal.clone()`` for the input to the ``transform`` function - because signals are modified in-place! If you don't clone the signal, - you will lose the original data. - - Parameters - ---------- - keys : list, optional - Keys that the transform looks for when - calling ``self.transform``, by default []. In general this is - set automatically, and you won't need to manipulate this argument. - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - - Examples - -------- - - >>> seed = 0 - >>> - >>> audio_path = "tests/audio/spk/f10_script4_produced.wav" - >>> signal = AudioSignal(audio_path, offset=10, duration=2) - >>> transform = tfm.Compose( - >>> [ - >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), - >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), - >>> ], - >>> ) - >>> - >>> kwargs = transform.instantiate(seed, signal) - >>> output = transform(signal, **kwargs) - - """ - - def __init__(self, keys: list = [], name: str = None, prob: float = 1.0): - # Get keys from the _transform signature. - tfm_keys = list(signature(self._transform).parameters.keys()) - - # Filter out signal and kwargs keys. - ignore_keys = ["signal", "kwargs"] - tfm_keys = [k for k in tfm_keys if k not in ignore_keys] - - # Combine keys specified by the child class, the keys found in - # _transform signature, and the mask key. - self.keys = keys + tfm_keys + ["mask"] - - self.prob = prob - - if name is None: - name = self.__class__.__name__ - self.name = name - - def _prepare(self, batch: dict): - sub_batch = batch[self.name] - - for k in self.keys: - assert k in sub_batch.keys(), f"{k} not in batch" - - return sub_batch - - def _transform(self, signal): - return signal - - def _instantiate(self, state: RandomState, signal: AudioSignal = None): - return {} - - @staticmethod - def apply_mask(batch: dict, mask: torch.Tensor): - """Applies a mask to the batch. - - Parameters - ---------- - batch : dict - Batch whose values will be masked in the ``transform`` pass. - mask : torch.Tensor - Mask to apply to batch. - - Returns - ------- - dict - A dictionary that contains values only where ``mask = True``. - """ - masked_batch = {k: v[mask] for k, v in flatten(batch).items()} - return unflatten(masked_batch) - - def transform(self, signal: AudioSignal, **kwargs): - """Apply the transform to the audio signal, - with given keyword arguments. - - Parameters - ---------- - signal : AudioSignal - Signal that will be modified by the transforms in-place. - kwargs: dict - Keyword arguments to the specific transform's ``self._transform`` - function. - - Returns - ------- - AudioSignal - Transformed AudioSignal. - - Examples - -------- - - >>> for seed in range(10): - >>> kwargs = transform.instantiate(seed, signal) - >>> output = transform(signal.clone(), **kwargs) - - """ - tfm_kwargs = self._prepare(kwargs) - mask = tfm_kwargs["mask"] - - if torch.any(mask): - tfm_kwargs = self.apply_mask(tfm_kwargs, mask) - tfm_kwargs = {k: v for k, v in tfm_kwargs.items() if k != "mask"} - signal[mask] = self._transform(signal[mask], **tfm_kwargs) - - return signal - - def __call__(self, *args, **kwargs): - return self.transform(*args, **kwargs) - - def instantiate( - self, - state: RandomState = None, - signal: AudioSignal = None, - ): - """Instantiates parameters for the transform. - - Parameters - ---------- - state : RandomState, optional - Seed or random state used to sample the parameters, by default None - signal : AudioSignal, optional - AudioSignal passed to ``self._instantiate`` if this transform - needs one, by default None - - Returns - ------- - dict - Dictionary containing instantiated arguments for every keyword - argument to ``self._transform``. - - Examples - -------- - - >>> for seed in range(10): - >>> kwargs = transform.instantiate(seed, signal) - >>> output = transform(signal.clone(), **kwargs) - - """ - state = util.random_state(state) - - # Not all instantiates need the signal. Check if signal - # is needed before passing it in, so that the end-user - # doesn't need to have variables they're not using flowing - # into their function. - needs_signal = "signal" in set(signature(self._instantiate).parameters.keys()) - kwargs = {} - if needs_signal: - kwargs = {"signal": signal} - - # Instantiate the parameters for the transform. - params = self._instantiate(state, **kwargs) - for k in list(params.keys()): - v = params[k] - if isinstance(v, (AudioSignal, torch.Tensor, dict)): - params[k] = v - else: - params[k] = tt(v) - mask = state.rand() <= self.prob - params["mask"] = tt(mask) - - # Put the params into a nested dictionary that will be - # used later when calling the transform. This is to avoid - # collisions in the dictionary. - params = {self.name: params} - - return params - - def batch_instantiate( - self, - states: list = None, - signal: AudioSignal = None, - ): - """Instantiates arguments for every item in a batch, - given a list of states. Each state in the list - corresponds to one item in the batch. - - Parameters - ---------- - states : list, optional - List of states, by default None - signal : AudioSignal, optional - AudioSignal to pass to ``self.instantiate`` - if it is needed for this transform, by default None - - Returns - ------- - dict - Collated dictionary of arguments. - - Examples - -------- - - >>> batch_size = 4 - >>> signal = AudioSignal(audio_path, offset=10, duration=2) - >>> signal_batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)]) - >>> - >>> states = [seed + idx for idx in list(range(batch_size))] - >>> kwargs = transform.batch_instantiate(states, signal_batch) - >>> batch_output = transform(signal_batch, **kwargs) - """ - kwargs = [] - for state in states: - kwargs.append(self.instantiate(state, signal)) - kwargs = util.collate(kwargs) - return kwargs - - -class Identity(BaseTransform): - """This transform just returns the original signal.""" - - pass - - -class SpectralTransform(BaseTransform): - """Spectral transforms require STFT data to exist, since manipulations - of the STFT require the spectrogram.
This just calls ``stft`` before - the transform is called, and calls ``istft`` after the transform is - called so that the audio data is written to after the spectral - manipulation. - """ - - def transform(self, signal, **kwargs): - signal.stft() - super().transform(signal, **kwargs) - signal.istft() - return signal - - -class Compose(BaseTransform): - """Compose applies transforms in sequence, one after the other. The - transforms are passed in as positional arguments or as a list like so: - - >>> transform = tfm.Compose( - >>> [ - >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), - >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), - >>> ], - >>> ) - - This will convolve the signal with a room impulse response, and then - add background noise to the signal. Instantiate instantiates - all the parameters for every transform in the transform list so the - interface for using the Compose transform is the same as everything - else: - - >>> kwargs = transform.instantiate() - >>> output = transform(signal.clone(), **kwargs) - - Under the hood, the transform maps each transform to a unique name - under the hood of the form ``{position}.{name}``, where ``position`` - is the index of the transform in the list. ``Compose`` can nest - within other ``Compose`` transforms, like so: - - >>> preprocess = transforms.Compose( - >>> tfm.GlobalVolumeNorm(), - >>> tfm.CrossTalk(), - >>> name="preprocess", - >>> ) - >>> augment = transforms.Compose( - >>> tfm.RoomImpulseResponse(), - >>> tfm.BackgroundNoise(), - >>> name="augment", - >>> ) - >>> postprocess = transforms.Compose( - >>> tfm.VolumeChange(), - >>> tfm.RescaleAudio(), - >>> tfm.ShiftPhase(), - >>> name="postprocess", - >>> ) - >>> transform = transforms.Compose(preprocess, augment, postprocess), - - This defines 3 composed transforms, and then composes them in sequence - with one another. - - Parameters - ---------- - *transforms : list - List of transforms to apply - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__(self, *transforms: list, name: str = None, prob: float = 1.0): - if isinstance(transforms[0], list): - transforms = transforms[0] - - for i, tfm in enumerate(transforms): - tfm.name = f"{i}.{tfm.name}" - - keys = [tfm.name for tfm in transforms] - super().__init__(keys=keys, name=name, prob=prob) - - self.transforms = transforms - self.transforms_to_apply = keys - - @contextmanager - def filter(self, *names: list): - """This can be used to skip transforms entirely when applying - the sequence of transforms to a signal. For example, take - the following transforms with the names ``preprocess, augment, postprocess``. 
- - >>> preprocess = transforms.Compose( - >>> tfm.GlobalVolumeNorm(), - >>> tfm.CrossTalk(), - >>> name="preprocess", - >>> ) - >>> augment = transforms.Compose( - >>> tfm.RoomImpulseResponse(), - >>> tfm.BackgroundNoise(), - >>> name="augment", - >>> ) - >>> postprocess = transforms.Compose( - >>> tfm.VolumeChange(), - >>> tfm.RescaleAudio(), - >>> tfm.ShiftPhase(), - >>> name="postprocess", - >>> ) - >>> transform = transforms.Compose(preprocess, augment, postprocess) - - If we wanted to apply all 3 to a signal, we do: - - >>> kwargs = transform.instantiate(signal=signal) - >>> output = transform(signal.clone(), **kwargs) - - But if we only wanted to apply the ``preprocess`` and ``postprocess`` - transforms to the signal, we do: - - >>> with transform.filter("preprocess", "postprocess"): - >>> output = transform(signal.clone(), **kwargs) - - Parameters - ---------- - *names : list - List of transforms, identified by name, to apply to signal. - """ - old_transforms = self.transforms_to_apply - self.transforms_to_apply = names - yield - self.transforms_to_apply = old_transforms - - def _transform(self, signal, **kwargs): - for transform in self.transforms: - if any([x in transform.name for x in self.transforms_to_apply]): - signal = transform(signal, **kwargs) - return signal - - def _instantiate(self, state: RandomState, signal: AudioSignal = None): - parameters = {} - for transform in self.transforms: - parameters.update(transform.instantiate(state, signal=signal)) - return parameters - - def __getitem__(self, idx): - return self.transforms[idx] - - def __len__(self): - return len(self.transforms) - - def __iter__(self): - for transform in self.transforms: - yield transform - - -class Choose(Compose): - """Choose logic is the same as :py:func:`audiotools.data.transforms.Compose`, - but instead of applying all the transforms in sequence, it applies just a single transform, - which is chosen for each item in the batch. - - Parameters - ---------- - *transforms : list - List of transforms to apply - weights : list - Probability of choosing any specific transform. - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - - Examples - -------- - - >>> transforms.Choose(tfm.LowPass(), tfm.HighPass()) - """ - - def __init__( - self, - *transforms: list, - weights: list = None, - name: str = None, - prob: float = 1.0, - ): - super().__init__(*transforms, name=name, prob=prob) - - if weights is None: - _len = len(self.transforms) - weights = [1 / _len for _ in range(_len)] - self.weights = np.array(weights) - - def _instantiate(self, state: RandomState, signal: AudioSignal = None): - kwargs = super()._instantiate(state, signal) - tfm_idx = list(range(len(self.transforms))) - tfm_idx = state.choice(tfm_idx, p=self.weights) - one_hot = [] - for i, t in enumerate(self.transforms): - mask = kwargs[t.name]["mask"] - if mask.item(): - kwargs[t.name]["mask"] = tt(i == tfm_idx) - one_hot.append(kwargs[t.name]["mask"]) - kwargs["one_hot"] = one_hot - return kwargs - - -class Repeat(Compose): - """Repeatedly applies a given transform ``n_repeat`` times. - - Parameters - ---------- - transform : BaseTransform - Transform to repeat. - n_repeat : int, optional - Number of times to repeat transform, by default 1 - """ - - def __init__( - self, - transform, - n_repeat: int = 1, - name: str = None, - prob: float = 1.0, - ): - transforms = [copy.copy(transform) for _ in range(n_repeat)] - super().__init__(transforms, name=name, prob=prob) - - self.n_repeat = n_repeat - - -class RepeatUpTo(Choose): - """Repeatedly applies a given transform up to ``max_repeat`` times. - - Parameters - ---------- - transform : BaseTransform - Transform to repeat. - max_repeat : int, optional - Max number of times to repeat transform, by default 5 - weights : list - Probability of choosing any specific number up to ``max_repeat``. - """ - - def __init__( - self, - transform, - max_repeat: int = 5, - weights: list = None, - name: str = None, - prob: float = 1.0, - ): - transforms = [] - for n in range(1, max_repeat): - transforms.append(Repeat(transform, n_repeat=n)) - super().__init__(transforms, name=name, prob=prob, weights=weights) - - self.max_repeat = max_repeat - - -class ClippingDistortion(BaseTransform): - """Adds clipping distortion to signal. Corresponds - to :py:func:`audiotools.core.effects.EffectMixin.clip_distortion`. - - Parameters - ---------- - perc : tuple, optional - Clipping percentile. Values are between 0.0 and 1.0. - Typical values are 0.1 or below, by default ("uniform", 0.0, 0.1) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - perc: tuple = ("uniform", 0.0, 0.1), - name: str = None, - prob: float = 1.0, - ): - super().__init__(name=name, prob=prob) - - self.perc = perc - - def _instantiate(self, state: RandomState): - return {"perc": util.sample_from_dist(self.perc, state)} - - def _transform(self, signal, perc): - return signal.clip_distortion(perc) - - -class Equalizer(BaseTransform): - """Applies an equalization curve to the audio signal. Corresponds - to :py:func:`audiotools.core.effects.EffectMixin.equalizer`. - - Parameters - ---------- - eq_amount : tuple, optional - The maximum dB cut to apply to the audio in any band, - by default ("const", 1.0) - n_bands : int, optional - Number of bands in EQ, by default 6 - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - eq_amount: tuple = ("const", 1.0), - n_bands: int = 6, - name: str = None, - prob: float = 1.0, - ): - super().__init__(name=name, prob=prob) - - self.eq_amount = eq_amount - self.n_bands = n_bands - - def _instantiate(self, state: RandomState): - eq_amount = util.sample_from_dist(self.eq_amount, state) - eq = -eq_amount * state.rand(self.n_bands) - return {"eq": eq} - - def _transform(self, signal, eq): - return signal.equalizer(eq) - - -class Quantization(BaseTransform): - """Applies quantization to the input waveform. Corresponds - to :py:func:`audiotools.core.effects.EffectMixin.quantization`.
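A minimal usage sketch, following the ``instantiate``/``transform`` pattern above (assumes an ``AudioSignal`` named ``signal`` is already in scope; the seed and channel count are illustrative):

>>> transform = Quantization(channels=("const", 256))
>>> kwargs = transform.instantiate(0, signal)
>>> quantized = transform(signal.clone(), **kwargs)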
- - Parameters - ---------- - channels : tuple, optional - Number of evenly spaced quantization channels to quantize - to, by default ("choice", [8, 32, 128, 256, 1024]) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - channels: tuple = ("choice", [8, 32, 128, 256, 1024]), - name: str = None, - prob: float = 1.0, - ): - super().__init__(name=name, prob=prob) - - self.channels = channels - - def _instantiate(self, state: RandomState): - return {"channels": util.sample_from_dist(self.channels, state)} - - def _transform(self, signal, channels): - return signal.quantization(channels) - - -class MuLawQuantization(BaseTransform): - """Applies mu-law quantization to the input waveform. Corresponds - to :py:func:`audiotools.core.effects.EffectMixin.mulaw_quantization`. - - Parameters - ---------- - channels : tuple, optional - Number of mu-law spaced quantization channels to quantize - to, by default ("choice", [8, 32, 128, 256, 1024]) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - channels: tuple = ("choice", [8, 32, 128, 256, 1024]), - name: str = None, - prob: float = 1.0, - ): - super().__init__(name=name, prob=prob) - - self.channels = channels - - def _instantiate(self, state: RandomState): - return {"channels": util.sample_from_dist(self.channels, state)} - - def _transform(self, signal, channels): - return signal.mulaw_quantization(channels) - - -class NoiseFloor(BaseTransform): - """Adds a noise floor of Gaussian noise to the signal at a specified - dB. - - Parameters - ---------- - db : tuple, optional - Level of noise to add to signal, by default ("const", -50.0) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - db: tuple = ("const", -50.0), - name: str = None, - prob: float = 1.0, - ): - super().__init__(name=name, prob=prob) - - self.db = db - - def _instantiate(self, state: RandomState, signal: AudioSignal): - db = util.sample_from_dist(self.db, state) - audio_data = state.randn(signal.num_channels, signal.signal_length) - nz_signal = AudioSignal(audio_data, signal.sample_rate) - nz_signal.normalize(db) - return {"nz_signal": nz_signal} - - def _transform(self, signal, nz_signal): - # Clone bg_signal so that transform can be repeatedly applied - # to different signals with the same effect. - return signal + nz_signal - - -class BackgroundNoise(BaseTransform): - """Adds background noise from audio specified by a set of CSV files. - A valid CSV file looks like, and is typically generated by - :py:func:`audiotools.data.preprocess.create_csv`: - - .. csv-table:: - :header: path - - room_tone/m6_script2_clean.wav - room_tone/m6_script2_cleanraw.wav - room_tone/m6_script2_ipad_balcony1.wav - room_tone/m6_script2_ipad_bedroom1.wav - room_tone/m6_script2_ipad_confroom1.wav - room_tone/m6_script2_ipad_confroom2.wav - room_tone/m6_script2_ipad_livingroom1.wav - room_tone/m6_script2_ipad_office1.wav - - .. 
note:: - All paths are relative to an environment variable called ``PATH_TO_DATA``, - so that CSV files are portable across machines where data may be - located in different places. - - This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix` - and :py:func:`audiotools.core.effects.EffectMixin.equalizer` under the - hood. - - Parameters - ---------- - snr : tuple, optional - Signal-to-noise ratio, by default ("uniform", 10.0, 30.0) - sources : List[str], optional - Sources containing folders, or CSVs with paths to audio files, - by default None - weights : List[float], optional - Weights to sample audio files from each source, by default None - eq_amount : tuple, optional - Amount of equalization to apply, by default ("const", 1.0) - n_bands : int, optional - Number of bands in equalizer, by default 3 - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - loudness_cutoff : float, optional - Loudness cutoff when loading from audio files, by default None - """ - - def __init__( - self, - snr: tuple = ("uniform", 10.0, 30.0), - sources: List[str] = None, - weights: List[float] = None, - eq_amount: tuple = ("const", 1.0), - n_bands: int = 3, - name: str = None, - prob: float = 1.0, - loudness_cutoff: float = None, - ): - super().__init__(name=name, prob=prob) - - self.snr = snr - self.eq_amount = eq_amount - self.n_bands = n_bands - self.loader = AudioLoader(sources, weights) - self.loudness_cutoff = loudness_cutoff - - def _instantiate(self, state: RandomState, signal: AudioSignal): - eq_amount = util.sample_from_dist(self.eq_amount, state) - eq = -eq_amount * state.rand(self.n_bands) - snr = util.sample_from_dist(self.snr, state) - - bg_signal = self.loader( - state, - signal.sample_rate, - duration=signal.signal_duration, - loudness_cutoff=self.loudness_cutoff, - num_channels=signal.num_channels, - )["signal"] - - return {"eq": eq, "bg_signal": bg_signal, "snr": snr} - - def _transform(self, signal, bg_signal, snr, eq): - # Clone bg_signal so that transform can be repeatedly applied - # to different signals with the same effect. - return signal.mix(bg_signal.clone(), snr, eq) - - -class CrossTalk(BaseTransform): - """Adds crosstalk between speakers, whose audio is drawn from a CSV file - that was produced via :py:func:`audiotools.data.preprocess.create_csv`. - - This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix` - under the hood. 
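A minimal usage sketch, assuming ``signal`` is an ``AudioSignal`` in scope and reusing the ``tests/audio/noises.csv`` fixture from the examples above:

>>> transform = BackgroundNoise(snr=("uniform", 10.0, 30.0), sources=["tests/audio/noises.csv"])
>>> kwargs = transform.instantiate(0, signal)
>>> noisy = transform(signal.clone(), **kwargs)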
- - Parameters - ---------- - snr : tuple, optional - How loud cross-talk speaker is relative to original signal in dB, - by default ("uniform", 0.0, 10.0) - sources : List[str], optional - Sources containing folders, or CSVs with paths to audio files, - by default None - weights : List[float], optional - Weights to sample audio files from each source, by default None - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - loudness_cutoff : float, optional - Loudness cutoff when loading from audio files, by default -40 - """ - - def __init__( - self, - snr: tuple = ("uniform", 0.0, 10.0), - sources: List[str] = None, - weights: List[float] = None, - name: str = None, - prob: float = 1.0, - loudness_cutoff: float = -40, - ): - super().__init__(name=name, prob=prob) - - self.snr = snr - self.loader = AudioLoader(sources, weights) - self.loudness_cutoff = loudness_cutoff - - def _instantiate(self, state: RandomState, signal: AudioSignal): - snr = util.sample_from_dist(self.snr, state) - crosstalk_signal = self.loader( - state, - signal.sample_rate, - duration=signal.signal_duration, - loudness_cutoff=self.loudness_cutoff, - num_channels=signal.num_channels, - )["signal"] - - return {"crosstalk_signal": crosstalk_signal, "snr": snr} - - def _transform(self, signal, crosstalk_signal, snr): - # Clone bg_signal so that transform can be repeatedly applied - # to different signals with the same effect. - loudness = signal.loudness() - mix = signal.mix(crosstalk_signal.clone(), snr) - mix.normalize(loudness) - return mix - - -class RoomImpulseResponse(BaseTransform): - """Convolves signal with a room impulse response, at a specified - direct-to-reverberant ratio, with equalization applied. Room impulse - response data is drawn from a CSV file that was produced via - :py:func:`audiotools.data.preprocess.create_csv`. - - This transform calls :py:func:`audiotools.core.effects.EffectMixin.apply_ir` - under the hood. 
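A minimal usage sketch, assuming ``signal`` is an ``AudioSignal`` in scope and reusing the ``tests/audio/irs.csv`` fixture from the examples above; the signal is needed at instantiation so the impulse response can be loaded to match it:

>>> transform = RoomImpulseResponse(drr=("uniform", 0.0, 30.0), sources=["tests/audio/irs.csv"])
>>> kwargs = transform.instantiate(0, signal)
>>> reverberant = transform(signal.clone(), **kwargs)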
- - Parameters - ---------- - drr : tuple, optional - _description_, by default ("uniform", 0.0, 30.0) - sources : List[str], optional - Sources containing folders, or CSVs with paths to audio files, - by default None - weights : List[float], optional - Weights to sample audio files from each source, by default None - eq_amount : tuple, optional - Amount of equalization to apply, by default ("const", 1.0) - n_bands : int, optional - Number of bands in equalizer, by default 6 - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - use_original_phase : bool, optional - Whether or not to use the original phase, by default False - offset : float, optional - Offset from each impulse response file to use, by default 0.0 - duration : float, optional - Duration of each impulse response, by default 1.0 - """ - - def __init__( - self, - drr: tuple = ("uniform", 0.0, 30.0), - sources: List[str] = None, - weights: List[float] = None, - eq_amount: tuple = ("const", 1.0), - n_bands: int = 6, - name: str = None, - prob: float = 1.0, - use_original_phase: bool = False, - offset: float = 0.0, - duration: float = 1.0, - ): - super().__init__(name=name, prob=prob) - - self.drr = drr - self.eq_amount = eq_amount - self.n_bands = n_bands - self.use_original_phase = use_original_phase - - self.loader = AudioLoader(sources, weights) - self.offset = offset - self.duration = duration - - def _instantiate(self, state: RandomState, signal: AudioSignal = None): - eq_amount = util.sample_from_dist(self.eq_amount, state) - eq = -eq_amount * state.rand(self.n_bands) - drr = util.sample_from_dist(self.drr, state) - - ir_signal = self.loader( - state, - signal.sample_rate, - offset=self.offset, - duration=self.duration, - loudness_cutoff=None, - num_channels=signal.num_channels, - )["signal"] - ir_signal.zero_pad_to(signal.sample_rate) - - return {"eq": eq, "ir_signal": ir_signal, "drr": drr} - - def _transform(self, signal, ir_signal, drr, eq): - # Clone ir_signal so that transform can be repeatedly applied - # to different signals with the same effect. - return signal.apply_ir( - ir_signal.clone(), drr, eq, use_original_phase=self.use_original_phase - ) - - -class VolumeChange(BaseTransform): - """Changes the volume of the input signal. - - Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`. - - Parameters - ---------- - db : tuple, optional - Change in volume in decibels, by default ("uniform", -12.0, 0.0) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - db: tuple = ("uniform", -12.0, 0.0), - name: str = None, - prob: float = 1.0, - ): - super().__init__(name=name, prob=prob) - self.db = db - - def _instantiate(self, state: RandomState): - return {"db": util.sample_from_dist(self.db, state)} - - def _transform(self, signal, db): - return signal.volume_change(db) - - -class VolumeNorm(BaseTransform): - """Normalizes the volume of the excerpt to a specified decibel. - - Uses :py:func:`audiotools.core.effects.EffectMixin.normalize`. 
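A minimal usage sketch (``signal`` is an assumed ``AudioSignal``; this transform samples its parameters without needing the signal):

>>> transform = VolumeNorm(db=("const", -24))
>>> kwargs = transform.instantiate(0)
>>> normalized = transform(signal.clone(), **kwargs)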
- - Parameters - ---------- - db : tuple, optional - dB to normalize signal to, by default ("const", -24) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - db: tuple = ("const", -24), - name: str = None, - prob: float = 1.0, - ): - super().__init__(name=name, prob=prob) - - self.db = db - - def _instantiate(self, state: RandomState): - return {"db": util.sample_from_dist(self.db, state)} - - def _transform(self, signal, db): - return signal.normalize(db) - - -class GlobalVolumeNorm(BaseTransform): - """Similar to :py:func:`audiotools.data.transforms.VolumeNorm`, this - transform also normalizes the volume of a signal, but it uses - the volume of the entire audio file the loaded excerpt comes from, - rather than the volume of just the excerpt. The volume of the - entire audio file is expected in ``signal.metadata["loudness"]``. - If loading audio from a CSV generated by :py:func:`audiotools.data.preprocess.create_csv` - with ``loudness = True``, like the following: - - .. csv-table:: - :header: path,loudness - - daps/produced/f1_script1_produced.wav,-16.299999237060547 - daps/produced/f1_script2_produced.wav,-16.600000381469727 - daps/produced/f1_script3_produced.wav,-17.299999237060547 - daps/produced/f1_script4_produced.wav,-16.100000381469727 - daps/produced/f1_script5_produced.wav,-16.700000762939453 - daps/produced/f3_script1_produced.wav,-16.5 - - The ``AudioLoader`` will automatically load the loudness column into - the metadata of the signal. - - Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`. - - Parameters - ---------- - db : tuple, optional - dB to normalize signal to, by default ("const", -24) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - db: tuple = ("const", -24), - name: str = None, - prob: float = 1.0, - ): - super().__init__(name=name, prob=prob) - - self.db = db - - def _instantiate(self, state: RandomState, signal: AudioSignal): - if "loudness" not in signal.metadata: - db_change = 0.0 - elif float(signal.metadata["loudness"]) == float("-inf"): - db_change = 0.0 - else: - db = util.sample_from_dist(self.db, state) - db_change = db - float(signal.metadata["loudness"]) - - return {"db": db_change} - - def _transform(self, signal, db): - return signal.volume_change(db) - - -class Silence(BaseTransform): - """Zeros out the signal with some probability. - - Parameters - ---------- - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 0.1 - """ - - def __init__(self, name: str = None, prob: float = 0.1): - super().__init__(name=name, prob=prob) - - def _transform(self, signal): - _loudness = signal._loudness - signal = AudioSignal( - torch.zeros_like(signal.audio_data), - sample_rate=signal.sample_rate, - stft_params=signal.stft_params, - ) - # So that the amount of noise added is as if it wasn't silenced. - # TODO: improve this hack - signal._loudness = _loudness - - return signal - - -class LowPass(BaseTransform): - """Applies a LowPass filter.
- - Uses :py:func:`audiotools.core.dsp.DSPMixin.low_pass`. - - Parameters - ---------- - cutoff : tuple, optional - Cutoff frequency distribution, - by default ``("choice", [4000, 8000, 16000])`` - zeros : int, optional - Number of zero-crossings in filter, argument to - ``julius.LowPassFilters``, by default 51 - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - cutoff: tuple = ("choice", [4000, 8000, 16000]), - zeros: int = 51, - name: str = None, - prob: float = 1, - ): - super().__init__(name=name, prob=prob) - - self.cutoff = cutoff - self.zeros = zeros - - def _instantiate(self, state: RandomState): - return {"cutoff": util.sample_from_dist(self.cutoff, state)} - - def _transform(self, signal, cutoff): - return signal.low_pass(cutoff, zeros=self.zeros) - - -class HighPass(BaseTransform): - """Applies a HighPass filter. - - Uses :py:func:`audiotools.core.dsp.DSPMixin.high_pass`. - - Parameters - ---------- - cutoff : tuple, optional - Cutoff frequency distribution, - by default ``("choice", [50, 100, 250, 500, 1000])`` - zeros : int, optional - Number of zero-crossings in filter, argument to - ``julius.LowPassFilters``, by default 51 - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - cutoff: tuple = ("choice", [50, 100, 250, 500, 1000]), - zeros: int = 51, - name: str = None, - prob: float = 1, - ): - super().__init__(name=name, prob=prob) - - self.cutoff = cutoff - self.zeros = zeros - - def _instantiate(self, state: RandomState): - return {"cutoff": util.sample_from_dist(self.cutoff, state)} - - def _transform(self, signal, cutoff): - return signal.high_pass(cutoff, zeros=self.zeros) - - -class RescaleAudio(BaseTransform): - """Rescales the audio so it is in between ``-val`` and ``val`` - only if the original audio exceeds those bounds. Useful if - transforms have caused the audio to clip. - - Uses :py:func:`audiotools.core.effects.EffectMixin.ensure_max_of_audio`. - - Parameters - ---------- - val : float, optional - Max absolute value of signal, by default 1.0 - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__(self, val: float = 1.0, name: str = None, prob: float = 1): - super().__init__(name=name, prob=prob) - - self.val = val - - def _transform(self, signal): - return signal.ensure_max_of_audio(self.val) - - -class ShiftPhase(SpectralTransform): - """Shifts the phase of the audio. - - Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`.
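A minimal usage sketch (``signal`` is an assumed ``AudioSignal``; as a ``SpectralTransform``, the ``stft``/``istft`` round-trip is handled inside ``transform``):

>>> transform = ShiftPhase(shift=("uniform", -np.pi, np.pi))
>>> kwargs = transform.instantiate(0)
>>> shifted = transform(signal.clone(), **kwargs)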
- - Parameters - ---------- - shift : tuple, optional - How much to shift phase by, by default ("uniform", -np.pi, np.pi) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - shift: tuple = ("uniform", -np.pi, np.pi), - name: str = None, - prob: float = 1, - ): - super().__init__(name=name, prob=prob) - self.shift = shift - - def _instantiate(self, state: RandomState): - return {"shift": util.sample_from_dist(self.shift, state)} - - def _transform(self, signal, shift): - return signal.shift_phase(shift) - - -class InvertPhase(ShiftPhase): - """Inverts the phase of the audio. - - Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`. - - Parameters - ---------- - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__(self, name: str = None, prob: float = 1): - super().__init__(shift=("const", np.pi), name=name, prob=prob) - - -class CorruptPhase(SpectralTransform): - """Corrupts the phase of the audio. - - Uses :py:func:`audiotools.core.dsp.DSPMixin.corrupt_phase`. - - Parameters - ---------- - scale : tuple, optional - How much to corrupt phase by, by default ("uniform", 0, np.pi) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, scale: tuple = ("uniform", 0, np.pi), name: str = None, prob: float = 1 - ): - super().__init__(name=name, prob=prob) - self.scale = scale - - def _instantiate(self, state: RandomState, signal: AudioSignal = None): - scale = util.sample_from_dist(self.scale, state) - corruption = state.normal(scale=scale, size=signal.phase.shape[1:]) - return {"corruption": corruption.astype("float32")} - - def _transform(self, signal, corruption): - return signal.shift_phase(shift=corruption) - - -class FrequencyMask(SpectralTransform): - """Masks a band of frequencies at a center frequency - from the audio. - - Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_frequencies`. 
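A minimal usage sketch (``signal`` is an assumed ``AudioSignal``; it is needed at instantiation to convert the relative band edges to Hz):

>>> transform = FrequencyMask(f_center=("const", 0.5), f_width=("const", 0.1))
>>> kwargs = transform.instantiate(0, signal)
>>> masked = transform(signal.clone(), **kwargs)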
- - Parameters - ---------- - f_center : tuple, optional - Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0) - f_width : tuple, optional - Width of zero'd out band, by default ("const", 0.1) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - f_center: tuple = ("uniform", 0.0, 1.0), - f_width: tuple = ("const", 0.1), - name: str = None, - prob: float = 1, - ): - super().__init__(name=name, prob=prob) - self.f_center = f_center - self.f_width = f_width - - def _instantiate(self, state: RandomState, signal: AudioSignal): - f_center = util.sample_from_dist(self.f_center, state) - f_width = util.sample_from_dist(self.f_width, state) - - fmin = max(f_center - (f_width / 2), 0.0) - fmax = min(f_center + (f_width / 2), 1.0) - - fmin_hz = (signal.sample_rate / 2) * fmin - fmax_hz = (signal.sample_rate / 2) * fmax - - return {"fmin_hz": fmin_hz, "fmax_hz": fmax_hz} - - def _transform(self, signal, fmin_hz: float, fmax_hz: float): - return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz) - - -class TimeMask(SpectralTransform): - """Masks out contiguous time-steps from signal. - - Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_timesteps`. - - Parameters - ---------- - t_center : tuple, optional - Center time in terms of 0.0 and 1.0 (duration of signal), - by default ("uniform", 0.0, 1.0) - t_width : tuple, optional - Width of dropped out portion, by default ("const", 0.025) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - t_center: tuple = ("uniform", 0.0, 1.0), - t_width: tuple = ("const", 0.025), - name: str = None, - prob: float = 1, - ): - super().__init__(name=name, prob=prob) - self.t_center = t_center - self.t_width = t_width - - def _instantiate(self, state: RandomState, signal: AudioSignal): - t_center = util.sample_from_dist(self.t_center, state) - t_width = util.sample_from_dist(self.t_width, state) - - tmin = max(t_center - (t_width / 2), 0.0) - tmax = min(t_center + (t_width / 2), 1.0) - - tmin_s = signal.signal_duration * tmin - tmax_s = signal.signal_duration * tmax - return {"tmin_s": tmin_s, "tmax_s": tmax_s} - - def _transform(self, signal, tmin_s: float, tmax_s: float): - return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s) - - -class MaskLowMagnitudes(SpectralTransform): - """Masks low magnitude regions out of signal. - - Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_low_magnitudes`. 
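A minimal usage sketch (``signal`` is an assumed ``AudioSignal``; the cutoff value is illustrative):

>>> transform = MaskLowMagnitudes(db_cutoff=("const", -6))
>>> kwargs = transform.instantiate(0, signal)
>>> output = transform(signal.clone(), **kwargs)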
- - Parameters - ---------- - db_cutoff : tuple, optional - Decibel value for which things below it will be masked away, - by default ("uniform", -10, 10) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - db_cutoff: tuple = ("uniform", -10, 10), - name: str = None, - prob: float = 1, - ): - super().__init__(name=name, prob=prob) - self.db_cutoff = db_cutoff - - def _instantiate(self, state: RandomState, signal: AudioSignal = None): - return {"db_cutoff": util.sample_from_dist(self.db_cutoff, state)} - - def _transform(self, signal, db_cutoff: float): - return signal.mask_low_magnitudes(db_cutoff) - - -class Smoothing(BaseTransform): - """Convolves the signal with a smoothing window. - - Uses :py:func:`audiotools.core.effects.EffectMixin.convolve`. - - Parameters - ---------- - window_type : tuple, optional - Type of window to use, by default ("const", "average") - window_length : tuple, optional - Length of smoothing window, by - default ("choice", [8, 16, 32, 64, 128, 256, 512]) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - window_type: tuple = ("const", "average"), - window_length: tuple = ("choice", [8, 16, 32, 64, 128, 256, 512]), - name: str = None, - prob: float = 1, - ): - super().__init__(name=name, prob=prob) - self.window_type = window_type - self.window_length = window_length - - def _instantiate(self, state: RandomState, signal: AudioSignal = None): - window_type = util.sample_from_dist(self.window_type, state) - window_length = util.sample_from_dist(self.window_length, state) - window = signal.get_window( - window_type=window_type, window_length=window_length, device="cpu" - ) - return {"window": AudioSignal(window, signal.sample_rate)} - - def _transform(self, signal, window): - sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values - sscale[sscale == 0.0] = 1.0 - - out = signal.convolve(window) - - oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values - oscale[oscale == 0.0] = 1.0 - - out = out * (sscale / oscale) - return out - - -class TimeNoise(TimeMask): - """Similar to :py:func:`audiotools.data.transforms.TimeMask`, but - replaces with noise instead of zeros. 
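A minimal usage sketch (``signal`` is an assumed ``AudioSignal``; the masked time span is filled with noise rather than zeros):

>>> transform = TimeNoise(t_width=("const", 0.05))
>>> kwargs = transform.instantiate(0, signal)
>>> noised = transform(signal.clone(), **kwargs)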
- - Parameters - ---------- - t_center : tuple, optional - Center time in terms of 0.0 and 1.0 (duration of signal), - by default ("uniform", 0.0, 1.0) - t_width : tuple, optional - Width of dropped out portion, by default ("const", 0.025) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - t_center: tuple = ("uniform", 0.0, 1.0), - t_width: tuple = ("const", 0.025), - name: str = None, - prob: float = 1, - ): - super().__init__(t_center=t_center, t_width=t_width, name=name, prob=prob) - - def _transform(self, signal, tmin_s: float, tmax_s: float): - signal = signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s, val=0.0) - mag, phase = signal.magnitude, signal.phase - - mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase) - mask = (mag == 0.0) * (phase == 0.0) - - mag[mask] = mag_r[mask] - phase[mask] = phase_r[mask] - - signal.magnitude = mag - signal.phase = phase - return signal - - -class FrequencyNoise(FrequencyMask): - """Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but - replaces with noise instead of zeros. - - Parameters - ---------- - f_center : tuple, optional - Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0) - f_width : tuple, optional - Width of zero'd out band, by default ("const", 0.1) - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - f_center: tuple = ("uniform", 0.0, 1.0), - f_width: tuple = ("const", 0.1), - name: str = None, - prob: float = 1, - ): - super().__init__(f_center=f_center, f_width=f_width, name=name, prob=prob) - - def _transform(self, signal, fmin_hz: float, fmax_hz: float): - signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz) - mag, phase = signal.magnitude, signal.phase - - mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase) - mask = (mag == 0.0) * (phase == 0.0) - - mag[mask] = mag_r[mask] - phase[mask] = phase_r[mask] - - signal.magnitude = mag - signal.phase = phase - return signal - - -class SpectralDenoising(Equalizer): - """Applies denoising algorithm detailed in - :py:func:`audiotools.ml.layers.spectral_gate.SpectralGate`, - using a randomly generated noise signal for denoising. 
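A minimal usage sketch (``noisy`` is an assumed ``AudioSignal`` to be denoised; this transform's ``_instantiate`` needs only a seed, no signal):

>>> transform = SpectralDenoising(denoise_amount=("uniform", 0.8, 1.0))
>>> kwargs = transform.instantiate(0)
>>> denoised = transform(noisy.clone(), **kwargs)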
- - Parameters - ---------- - eq_amount : tuple, optional - Amount of eq to apply to noise signal, by default ("const", 1.0) - denoise_amount : tuple, optional - Amount to denoise by, by default ("uniform", 0.8, 1.0) - nz_volume : float, optional - Volume of noise to denoise with, by default -40 - n_bands : int, optional - Number of bands in equalizer, by default 6 - n_freq : int, optional - Number of frequency bins to smooth by, by default 3 - n_time : int, optional - Number of time bins to smooth by, by default 5 - name : str, optional - Name of this transform, used to identify it in the dictionary - produced by ``self.instantiate``, by default None - prob : float, optional - Probability of applying this transform, by default 1.0 - """ - - def __init__( - self, - eq_amount: tuple = ("const", 1.0), - denoise_amount: tuple = ("uniform", 0.8, 1.0), - nz_volume: float = -40, - n_bands: int = 6, - n_freq: int = 3, - n_time: int = 5, - name: str = None, - prob: float = 1, - ): - super().__init__(eq_amount=eq_amount, n_bands=n_bands, name=name, prob=prob) - - self.nz_volume = nz_volume - self.denoise_amount = denoise_amount - self.spectral_gate = ml.layers.SpectralGate(n_freq, n_time) - - def _transform(self, signal, nz, eq, denoise_amount): - nz = nz.normalize(self.nz_volume).equalizer(eq) - self.spectral_gate = self.spectral_gate.to(signal.device) - signal = self.spectral_gate(signal, nz, denoise_amount) - return signal - - def _instantiate(self, state: RandomState): - kwargs = super()._instantiate(state) - kwargs["denoise_amount"] = util.sample_from_dist(self.denoise_amount, state) - kwargs["nz"] = AudioSignal(state.randn(22050), 44100) - return kwargs diff --git a/dito/models/ldm/dac/audiotools/metrics/__init__.py b/dito/models/ldm/dac/audiotools/metrics/__init__.py deleted file mode 100644 index c9c8d2df61f94afae8e39e57abf156e8e4059a9e..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/metrics/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Functions for comparing AudioSignal objects to one another. -""" # fmt: skip -from . import distance -from . import quality -from . import spectral diff --git a/dito/models/ldm/dac/audiotools/metrics/distance.py b/dito/models/ldm/dac/audiotools/metrics/distance.py deleted file mode 100644 index ce78739bfc29f9ddc39b23063b4243ddac10adaf..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/metrics/distance.py +++ /dev/null @@ -1,131 +0,0 @@ -import torch -from torch import nn - -from .. import AudioSignal - - -class L1Loss(nn.L1Loss): - """L1 Loss between AudioSignals. Defaults - to comparing ``audio_data``, but any - attribute of an AudioSignal can be used. - - Parameters - ---------- - attribute : str, optional - Attribute of signal to compare, defaults to ``audio_data``. - weight : float, optional - Weight of this loss, defaults to 1.0. - """ - - def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): - self.attribute = attribute - self.weight = weight - super().__init__(**kwargs) - - def forward(self, x: AudioSignal, y: AudioSignal): - """ - Parameters - ---------- - x : AudioSignal - Estimate AudioSignal - y : AudioSignal - Reference AudioSignal - - Returns - ------- - torch.Tensor - L1 loss between AudioSignal attributes. 
- """ - if isinstance(x, AudioSignal): - x = getattr(x, self.attribute) - y = getattr(y, self.attribute) - return super().forward(x, y) - - -class SISDRLoss(nn.Module): - """ - Computes the Scale-Invariant Source-to-Distortion Ratio between a batch - of estimated and reference audio signals or aligned features. - - Parameters - ---------- - scaling : int, optional - Whether to use scale-invariant (True) or - signal-to-noise ratio (False), by default True - reduction : str, optional - How to reduce across the batch (either 'mean', - 'sum', or none).], by default ' mean' - zero_mean : int, optional - Zero mean the references and estimates before - computing the loss, by default True - clip_min : int, optional - The minimum possible loss value. Helps network - to not focus on making already good examples better, by default None - weight : float, optional - Weight of this loss, defaults to 1.0. - """ - - def __init__( - self, - scaling: int = True, - reduction: str = "mean", - zero_mean: int = True, - clip_min: int = None, - weight: float = 1.0, - ): - self.scaling = scaling - self.reduction = reduction - self.zero_mean = zero_mean - self.clip_min = clip_min - self.weight = weight - super().__init__() - - def forward(self, x: AudioSignal, y: AudioSignal): - eps = 1e-8 - # nb, nc, nt - if isinstance(x, AudioSignal): - references = x.audio_data - estimates = y.audio_data - else: - references = x - estimates = y - - nb = references.shape[0] - references = references.reshape(nb, 1, -1).permute(0, 2, 1) - estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) - - # samples now on axis 1 - if self.zero_mean: - mean_reference = references.mean(dim=1, keepdim=True) - mean_estimate = estimates.mean(dim=1, keepdim=True) - else: - mean_reference = 0 - mean_estimate = 0 - - _references = references - mean_reference - _estimates = estimates - mean_estimate - - references_projection = (_references**2).sum(dim=-2) + eps - references_on_estimates = (_estimates * _references).sum(dim=-2) + eps - - scale = ( - (references_on_estimates / references_projection).unsqueeze(1) - if self.scaling - else 1 - ) - - e_true = scale * _references - e_res = _estimates - e_true - - signal = (e_true**2).sum(dim=1) - noise = (e_res**2).sum(dim=1) - sdr = -10 * torch.log10(signal / noise + eps) - - if self.clip_min is not None: - sdr = torch.clamp(sdr, min=self.clip_min) - - if self.reduction == "mean": - sdr = sdr.mean() - elif self.reduction == "sum": - sdr = sdr.sum() - return sdr diff --git a/dito/models/ldm/dac/audiotools/metrics/quality.py b/dito/models/ldm/dac/audiotools/metrics/quality.py deleted file mode 100644 index 1608f25507082b49ccbf49289025a5a94a422808..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/metrics/quality.py +++ /dev/null @@ -1,159 +0,0 @@ -import os - -import numpy as np -import torch - -from .. import AudioSignal - - -def stoi( - estimates: AudioSignal, - references: AudioSignal, - extended: int = False, -): - """Short term objective intelligibility - Computes the STOI (See [1][2]) of a denoised signal compared to a clean - signal, The output is expected to have a monotonic relation with the - subjective speech-intelligibility, where a higher score denotes better - speech intelligibility. Uses pystoi under the hood. 
- - Parameters - ---------- - estimates : AudioSignal - Denoised speech - references : AudioSignal - Clean original speech - extended : bool, optional - Whether to use the extended STOI described in [3], by default False - - Returns - ------- - Tensor[float] - Short time objective intelligibility measure between clean and - denoised speech - - References - ---------- - 1. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time - Objective Intelligibility Measure for Time-Frequency Weighted Noisy - Speech', ICASSP 2010, Texas, Dallas. - 2. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for - Intelligibility Prediction of Time-Frequency Weighted Noisy Speech', - IEEE Transactions on Audio, Speech, and Language Processing, 2011. - 3. Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the - Intelligibility of Speech Masked by Modulated Noise Maskers', - IEEE Transactions on Audio, Speech and Language Processing, 2016. - """ - import pystoi - - estimates = estimates.clone().to_mono() - references = references.clone().to_mono() - - stois = [] - for i in range(estimates.batch_size): - _stoi = pystoi.stoi( - references.audio_data[i, 0].detach().cpu().numpy(), - estimates.audio_data[i, 0].detach().cpu().numpy(), - references.sample_rate, - extended=extended, - ) - stois.append(_stoi) - return torch.from_numpy(np.array(stois)) - - -def pesq( - estimates: AudioSignal, - references: AudioSignal, - mode: str = "wb", - target_sr: float = 16000, -): - """Perceptual Evaluation of Speech Quality (PESQ). Computes the PESQ - score of a degraded signal against a reference, resampling both to - ``target_sr`` first. Uses the pesq package under the hood. - - Parameters - ---------- - estimates : AudioSignal - Degraded AudioSignal - references : AudioSignal - Reference AudioSignal - mode : str, optional - 'wb' (wide-band) or 'nb' (narrow-band), by default "wb" - target_sr : float, optional - Target sample rate, by default 16000 - - Returns - ------- - Tensor[float] - PESQ score: P.862.2 Prediction (MOS-LQO) - """ - from pesq import pesq as pesq_fn - - estimates = estimates.clone().to_mono().resample(target_sr) - references = references.clone().to_mono().resample(target_sr) - - pesqs = [] - for i in range(estimates.batch_size): - _pesq = pesq_fn( - estimates.sample_rate, - references.audio_data[i, 0].detach().cpu().numpy(), - estimates.audio_data[i, 0].detach().cpu().numpy(), - mode, - ) - pesqs.append(_pesq) - return torch.from_numpy(np.array(pesqs)) - - -def visqol( - estimates: AudioSignal, - references: AudioSignal, - mode: str = "audio", -): # pragma: no cover - """ViSQOL score.
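Both metrics above take batches of AudioSignals and return one score per item; pesq additionally downmixes and resamples before calling the underlying package. A usage sketch, assuming the standalone audiotools package (this tree vendors the same code) with pystoi and pesq installed; the file names are hypothetical:

from audiotools import AudioSignal
from audiotools.metrics.quality import pesq, stoi

references = AudioSignal("clean.wav")     # hypothetical reference recording
estimates = AudioSignal("denoised.wav")   # hypothetical system output

print(stoi(estimates, references))              # close to 1.0 means fully intelligible
print(pesq(estimates, references, mode="wb"))   # MOS-LQO, roughly in [1.0, 4.6]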
- - Parameters - ---------- - estimates : AudioSignal - Degraded AudioSignal - references : AudioSignal - Reference AudioSignal - mode : str, optional - 'audio' or 'speech', by default 'audio' - - Returns - ------- - Tensor[float] - ViSQOL score (MOS-LQO) - """ - from visqol import visqol_lib_py - from visqol.pb2 import visqol_config_pb2 - from visqol.pb2 import similarity_result_pb2 - - config = visqol_config_pb2.VisqolConfig() - if mode == "audio": - target_sr = 48000 - config.options.use_speech_scoring = False - svr_model_path = "libsvm_nu_svr_model.txt" - elif mode == "speech": - target_sr = 16000 - config.options.use_speech_scoring = True - svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite" - else: - raise ValueError(f"Unrecognized mode: {mode}") - config.audio.sample_rate = target_sr - config.options.svr_model_path = os.path.join( - os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path - ) - - api = visqol_lib_py.VisqolApi() - api.Create(config) - - estimates = estimates.clone().to_mono().resample(target_sr) - references = references.clone().to_mono().resample(target_sr) - - visqols = [] - for i in range(estimates.batch_size): - _visqol = api.Measure( - references.audio_data[i, 0].detach().cpu().numpy().astype(float), - estimates.audio_data[i, 0].detach().cpu().numpy().astype(float), - ) - visqols.append(_visqol.moslqo) - return torch.from_numpy(np.array(visqols)) diff --git a/dito/models/ldm/dac/audiotools/metrics/spectral.py b/dito/models/ldm/dac/audiotools/metrics/spectral.py deleted file mode 100644 index 7ce953882efa4e5b777a0348bee6c1be39279a6c..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/metrics/spectral.py +++ /dev/null @@ -1,247 +0,0 @@ -import typing -from typing import List - -import numpy as np -from torch import nn - -from .. import AudioSignal -from .. import STFTParams - - -class MultiScaleSTFTLoss(nn.Module): - """Computes the multi-scale STFT loss from [1]. - - Parameters - ---------- - window_lengths : List[int], optional - Length of each window of each STFT, by default [2048, 512] - loss_fn : typing.Callable, optional - How to compare each loss, by default nn.L1Loss() - clamp_eps : float, optional - Clamp on the log magnitude, below, by default 1e-5 - mag_weight : float, optional - Weight of raw magnitude portion of loss, by default 1.0 - log_weight : float, optional - Weight of log magnitude portion of loss, by default 1.0 - pow : float, optional - Power to raise magnitude to before taking log, by default 2.0 - weight : float, optional - Weight of this loss, by default 1.0 - match_stride : bool, optional - Whether to match the stride of convolutional layers, by default False - - References - ---------- - - 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. - "DDSP: Differentiable Digital Signal Processing." - International Conference on Learning Representations. 2019. 
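Each scale's contribution in the forward pass below is a weighted sum of an L1 term on raw magnitudes and an L1 term on log magnitudes, with the magnitude clamped from below so silent bins cannot produce log10(0). The core of one scale on bare tensors (illustrative shapes, default weights):

import torch

l1 = torch.nn.L1Loss()
clamp_eps, pow_, log_weight, mag_weight = 1e-5, 2.0, 1.0, 1.0

x_mag = torch.rand(1, 1, 1025, 100)   # |STFT| of the estimate
y_mag = torch.rand(1, 1, 1025, 100)   # |STFT| of the reference

loss = log_weight * l1(
    x_mag.clamp(clamp_eps).pow(pow_).log10(),
    y_mag.clamp(clamp_eps).pow(pow_).log10(),
) + mag_weight * l1(x_mag, y_mag)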
- """ - - def __init__( - self, - window_lengths: List[int] = [2048, 512], - loss_fn: typing.Callable = nn.L1Loss(), - clamp_eps: float = 1e-5, - mag_weight: float = 1.0, - log_weight: float = 1.0, - pow: float = 2.0, - weight: float = 1.0, - match_stride: bool = False, - window_type: str = None, - ): - super().__init__() - self.stft_params = [ - STFTParams( - window_length=w, - hop_length=w // 4, - match_stride=match_stride, - window_type=window_type, - ) - for w in window_lengths - ] - self.loss_fn = loss_fn - self.log_weight = log_weight - self.mag_weight = mag_weight - self.clamp_eps = clamp_eps - self.weight = weight - self.pow = pow - - def forward(self, x: AudioSignal, y: AudioSignal): - """Computes multi-scale STFT between an estimate and a reference - signal. - - Parameters - ---------- - x : AudioSignal - Estimate signal - y : AudioSignal - Reference signal - - Returns - ------- - torch.Tensor - Multi-scale STFT loss. - """ - loss = 0.0 - for s in self.stft_params: - x.stft(s.window_length, s.hop_length, s.window_type) - y.stft(s.window_length, s.hop_length, s.window_type) - loss += self.log_weight * self.loss_fn( - x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), - y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), - ) - loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) - return loss - - -class MelSpectrogramLoss(nn.Module): - """Compute distance between mel spectrograms. Can be used - in a multi-scale way. - - Parameters - ---------- - n_mels : List[int] - Number of mels per STFT, by default [150, 80], - window_lengths : List[int], optional - Length of each window of each STFT, by default [2048, 512] - loss_fn : typing.Callable, optional - How to compare each loss, by default nn.L1Loss() - clamp_eps : float, optional - Clamp on the log magnitude, below, by default 1e-5 - mag_weight : float, optional - Weight of raw magnitude portion of loss, by default 1.0 - log_weight : float, optional - Weight of log magnitude portion of loss, by default 1.0 - pow : float, optional - Power to raise magnitude to before taking log, by default 2.0 - weight : float, optional - Weight of this loss, by default 1.0 - match_stride : bool, optional - Whether to match the stride of convolutional layers, by default False - """ - - def __init__( - self, - n_mels: List[int] = [150, 80], - window_lengths: List[int] = [2048, 512], - loss_fn: typing.Callable = nn.L1Loss(), - clamp_eps: float = 1e-5, - mag_weight: float = 1.0, - log_weight: float = 1.0, - pow: float = 2.0, - weight: float = 1.0, - match_stride: bool = False, - mel_fmin: List[float] = [0.0, 0.0], - mel_fmax: List[float] = [None, None], - window_type: str = None, - ): - super().__init__() - self.stft_params = [ - STFTParams( - window_length=w, - hop_length=w // 4, - match_stride=match_stride, - window_type=window_type, - ) - for w in window_lengths - ] - self.n_mels = n_mels - self.loss_fn = loss_fn - self.clamp_eps = clamp_eps - self.log_weight = log_weight - self.mag_weight = mag_weight - self.weight = weight - self.mel_fmin = mel_fmin - self.mel_fmax = mel_fmax - self.pow = pow - - def forward(self, x: AudioSignal, y: AudioSignal): - """Computes mel loss between an estimate and a reference - signal. - - Parameters - ---------- - x : AudioSignal - Estimate signal - y : AudioSignal - Reference signal - - Returns - ------- - torch.Tensor - Mel loss. 
- """ - loss = 0.0 - for n_mels, fmin, fmax, s in zip( - self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params - ): - kwargs = { - "window_length": s.window_length, - "hop_length": s.hop_length, - "window_type": s.window_type, - } - x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) - y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) - - loss += self.log_weight * self.loss_fn( - x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), - y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), - ) - loss += self.mag_weight * self.loss_fn(x_mels, y_mels) - return loss - - -class PhaseLoss(nn.Module): - """Difference between phase spectrograms. - - Parameters - ---------- - window_length : int, optional - Length of STFT window, by default 2048 - hop_length : int, optional - Hop length of STFT window, by default 512 - weight : float, optional - Weight of loss, by default 1.0 - """ - - def __init__( - self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0 - ): - super().__init__() - - self.weight = weight - self.stft_params = STFTParams(window_length, hop_length) - - def forward(self, x: AudioSignal, y: AudioSignal): - """Computes phase loss between an estimate and a reference - signal. - - Parameters - ---------- - x : AudioSignal - Estimate signal - y : AudioSignal - Reference signal - - Returns - ------- - torch.Tensor - Phase loss. - """ - s = self.stft_params - x.stft(s.window_length, s.hop_length, s.window_type) - y.stft(s.window_length, s.hop_length, s.window_type) - - # Take circular difference - diff = x.phase - y.phase - diff[diff < -np.pi] += 2 * np.pi - diff[diff > np.pi] -= -2 * np.pi - - # Scale true magnitude to weights in [0, 1] - x_min, x_max = x.magnitude.min(), x.magnitude.max() - weights = (x.magnitude - x_min) / (x_max - x_min) - - # Take weighted mean of all phase errors - loss = ((weights * diff) ** 2).mean() - return loss diff --git a/dito/models/ldm/dac/audiotools/ml/__init__.py b/dito/models/ldm/dac/audiotools/ml/__init__.py deleted file mode 100644 index a9ca69977bad57e1a92b7551d601d9224ee854ab..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/ml/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from . import decorators -from . import layers -from .accelerator import Accelerator -from .experiment import Experiment -from .layers import BaseModel diff --git a/dito/models/ldm/dac/audiotools/ml/accelerator.py b/dito/models/ldm/dac/audiotools/ml/accelerator.py deleted file mode 100644 index 37c6e8d954f112b8b0aff257894e62add8874e30..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/ml/accelerator.py +++ /dev/null @@ -1,184 +0,0 @@ -import os -import typing - -import torch -import torch.distributed as dist -from torch.nn.parallel import DataParallel -from torch.nn.parallel import DistributedDataParallel - -from ..data.datasets import ResumableDistributedSampler as DistributedSampler -from ..data.datasets import ResumableSequentialSampler as SequentialSampler - - -class Accelerator: # pragma: no cover - """This class is used to prepare models and dataloaders for - usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to - prepare the respective objects. In the case of models, they are moved to - the appropriate GPU and SyncBatchNorm is applied to them. In the case of - dataloaders, a sampler is created and the dataloader is initialized with - that sampler. - - If the world size is 1, prepare_model and prepare_dataloader are - no-ops. 
If the environment variable ``LOCAL_RANK`` is not set, then the - script was launched without ``torchrun``, and ``DataParallel`` - will be used instead of ``DistributedDataParallel`` (not recommended), if - the world size (number of GPUs) is greater than 1. - - Parameters - ---------- - amp : bool, optional - Whether or not to enable automatic mixed precision, by default False - """ - - def __init__(self, amp: bool = False): - local_rank = os.getenv("LOCAL_RANK", None) - self.world_size = torch.cuda.device_count() - - self.use_ddp = self.world_size > 1 and local_rank is not None - self.use_dp = self.world_size > 1 and local_rank is None - self.device = "cpu" if self.world_size == 0 else "cuda" - - if self.use_ddp: - local_rank = int(local_rank) - dist.init_process_group( - "nccl", - init_method="env://", - world_size=self.world_size, - rank=local_rank, - ) - - self.local_rank = 0 if local_rank is None else local_rank - self.amp = amp - - class DummyScaler: - def __init__(self): - pass - - def step(self, optimizer): - optimizer.step() - - def scale(self, loss): - return loss - - def unscale_(self, optimizer): - return optimizer - - def update(self): - pass - - self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler() - self.device_ctx = ( - torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None - ) - - def __enter__(self): - if self.device_ctx is not None: - self.device_ctx.__enter__() - return self - - def __exit__(self, exc_type, exc_value, traceback): - if self.device_ctx is not None: - self.device_ctx.__exit__(exc_type, exc_value, traceback) - - def prepare_model(self, model: torch.nn.Module, **kwargs): - """Prepares model for DDP or DP. The model is moved to - the device of the correct rank. - - Parameters - ---------- - model : torch.nn.Module - Model that is converted for DDP or DP. - - Returns - ------- - torch.nn.Module - Wrapped model, or original model if DDP and DP are turned off. - """ - model = model.to(self.device) - if self.use_ddp: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - model = DistributedDataParallel( - model, device_ids=[self.local_rank], **kwargs - ) - elif self.use_dp: - model = DataParallel(model, **kwargs) - return model - - # Automatic mixed-precision utilities - def autocast(self, *args, **kwargs): - """Context manager for autocasting. Arguments - go to ``torch.cuda.amp.autocast``. - """ - return torch.cuda.amp.autocast(self.amp, *args, **kwargs) - - def backward(self, loss: torch.Tensor): - """Backwards pass, after scaling the loss if ``amp`` is - enabled. - - Parameters - ---------- - loss : torch.Tensor - Loss value. - """ - self.scaler.scale(loss).backward() - - def step(self, optimizer: torch.optim.Optimizer): - """Steps the optimizer, using a ``scaler`` if ``amp`` is - enabled. - - Parameters - ---------- - optimizer : torch.optim.Optimizer - Optimizer to step forward. - """ - self.scaler.step(optimizer) - - def update(self): - """Updates the scale factor.""" - self.scaler.update() - - def prepare_dataloader( - self, dataset: typing.Iterable, start_idx: int = None, **kwargs - ): - """Wraps a dataset with a DataLoader, using the correct sampler if DDP is - enabled. - - Parameters - ---------- - dataset : typing.Iterable - Dataset to build Dataloader around. 
- start_idx : int, optional - Start index of sampler, useful if resuming from some epoch, - by default None - - Returns - ------- - torch.utils.data.DataLoader - DataLoader wrapped around the dataset with the appropriate sampler - (under DDP, batch size and workers are divided across replicas). - """ - - if self.use_ddp: - sampler = DistributedSampler( - dataset, - start_idx, - num_replicas=self.world_size, - rank=self.local_rank, - ) - if "num_workers" in kwargs: - kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1) - kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1) - else: - sampler = SequentialSampler(dataset, start_idx) - - dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs) - return dataloader - - @staticmethod - def unwrap(model): - """Unwraps the model if it was wrapped in DDP or DP, otherwise - just returns the model. Use this to unwrap the model returned by - :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`. - """ - if hasattr(model, "module"): - return model.module - return model diff --git a/dito/models/ldm/dac/audiotools/ml/decorators.py b/dito/models/ldm/dac/audiotools/ml/decorators.py deleted file mode 100644 index 3a435b06c47a48dc3600fa54ac092006f5c5bb27..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/ml/decorators.py +++ /dev/null @@ -1,441 +0,0 @@ -import math -import os -import time -from collections import defaultdict -from functools import wraps - -import torch -import torch.distributed as dist -from rich import box -from rich.console import Console -from rich.console import Group -from rich.live import Live -from rich.markdown import Markdown -from rich.padding import Padding -from rich.panel import Panel -from rich.progress import BarColumn -from rich.progress import Progress -from rich.progress import SpinnerColumn -from rich.progress import TimeElapsedColumn -from rich.progress import TimeRemainingColumn -from rich.rule import Rule -from rich.table import Table -from torch.utils.tensorboard import SummaryWriter - - -# This is here so that the history can be pickled. -def default_list(): - return [] - - -class Mean: - """Keeps track of the running mean, along with the latest - value. - """ - - def __init__(self): - self.reset() - - def __call__(self): - mean = self.total / max(self.count, 1) - return mean - - def reset(self): - self.count = 0 - self.total = 0 - - def update(self, val): - if math.isfinite(val): - self.count += 1 - self.total += val - - -def when(condition): - """Runs a function only when the condition is met. The condition is - a function that is run. - - Parameters - ---------- - condition : Callable - Function to run to check whether or not to run the decorated - function. - - Example - ------- - Checkpoint only runs every 100 iterations, and only if the - local rank is 0. - - >>> i = 0 - >>> rank = 0 - >>> - >>> @when(lambda: i % 100 == 0 and rank == 0) - >>> def checkpoint(): - >>> print("Saving to /runs/exp1") - >>> - >>> for i in range(1000): - >>> checkpoint() - - """ - - def decorator(fn): - @wraps(fn) - def decorated(*args, **kwargs): - if condition(): - return fn(*args, **kwargs) - - return decorated - - return decorator - - -def timer(prefix: str = "time"): - """Adds execution time to the output dictionary of the decorated - function. The function decorated by this must output a dictionary. - The key added will follow the form "[prefix]/[name_of_function]" - - Parameters - ---------- - prefix : str, optional - The key added will follow the form "[prefix]/[name_of_function]", - by default "time".
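The two decorators compose naturally around a training step: @timer() adds a "time/<function>" entry to the step's output dictionary, and @when(...) gates occasional work such as validation. A small sketch (assumes the standalone audiotools package; the function names are illustrative):

from audiotools.ml.decorators import timer, when

step = 0

@timer()
def train_step():
    return {"loss": 0.123}            # @timer() requires a dict output

@when(lambda: step % 100 == 0)
def validate():
    print(f"validating at step {step}")

for step in range(300):
    out = train_step()                # out also contains "time/train_step"
    validate()                        # runs only at steps 0, 100, 200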
- """ - - def decorator(fn): - @wraps(fn) - def decorated(*args, **kwargs): - s = time.perf_counter() - output = fn(*args, **kwargs) - assert isinstance(output, dict) - e = time.perf_counter() - output[f"{prefix}/{fn.__name__}"] = e - s - return output - - return decorated - - return decorator - - -class Tracker: - """ - A tracker class that helps to monitor the progress of training and logging the metrics. - - Attributes - ---------- - metrics : dict - A dictionary containing the metrics for each label. - history : dict - A dictionary containing the history of metrics for each label. - writer : SummaryWriter - A SummaryWriter object for logging the metrics. - rank : int - The rank of the current process. - step : int - The current step of the training. - tasks : dict - A dictionary containing the progress bars and tables for each label. - pbar : Progress - A progress bar object for displaying the progress. - consoles : list - A list of console objects for logging. - live : Live - A Live object for updating the display live. - - Methods - ------- - print(msg: str) - Prints the given message to all consoles. - update(label: str, fn_name: str) - Updates the progress bar and table for the given label. - done(label: str, title: str) - Resets the progress bar and table for the given label and prints the final result. - track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ) - A decorator for tracking the progress and metrics of a function. - log(label: str, value_type: str = "value", history: bool = True) - A decorator for logging the metrics of a function. - is_best(label: str, key: str) -> bool - Checks if the latest value of the given key in the label is the best so far. - state_dict() -> dict - Returns a dictionary containing the state of the tracker. - load_state_dict(state_dict: dict) -> Tracker - Loads the state of the tracker from the given state dictionary. - """ - - def __init__( - self, - writer: SummaryWriter = None, - log_file: str = None, - rank: int = 0, - console_width: int = 100, - step: int = 0, - ): - """ - Initializes the Tracker object. - - Parameters - ---------- - writer : SummaryWriter, optional - A SummaryWriter object for logging the metrics, by default None. - log_file : str, optional - The path to the log file, by default None. - rank : int, optional - The rank of the current process, by default 0. - console_width : int, optional - The width of the console, by default 100. - step : int, optional - The current step of the training, by default 0. - """ - self.metrics = {} - self.history = {} - self.writer = writer - self.rank = rank - self.step = step - - # Create progress bars etc. - self.tasks = {} - self.pbar = Progress( - SpinnerColumn(), - "[progress.description]{task.description}", - "{task.completed}/{task.total}", - BarColumn(), - TimeElapsedColumn(), - "/", - TimeRemainingColumn(), - ) - self.consoles = [Console(width=console_width)] - self.live = Live(console=self.consoles[0], refresh_per_second=10) - if log_file is not None: - self.consoles.append(Console(width=console_width, file=open(log_file, "a"))) - - def print(self, msg): - """ - Prints the given message to all consoles. - - Parameters - ---------- - msg : str - The message to be printed. - """ - if self.rank == 0: - for c in self.consoles: - c.log(msg) - - def update(self, label, fn_name): - """ - Updates the progress bar and table for the given label. 
- - Parameters - ---------- - label : str - The label of the progress bar and table to be updated. - fn_name : str - The name of the function associated with the label. - """ - if self.rank == 0: - self.pbar.advance(self.tasks[label]["pbar"]) - - # Create table - table = Table(title=label, expand=True, box=box.MINIMAL) - table.add_column("key", style="cyan") - table.add_column("value", style="bright_blue") - table.add_column("mean", style="bright_green") - - keys = self.metrics[label]["value"].keys() - for k in keys: - value = self.metrics[label]["value"][k] - mean = self.metrics[label]["mean"][k]() - table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}") - - self.tasks[label]["table"] = table - tables = [t["table"] for t in self.tasks.values()] - group = Group(*tables, self.pbar) - self.live.update( - Group( - Padding("", (0, 0)), - Rule(f"[italic]{fn_name}()", style="white"), - Padding("", (0, 0)), - Panel.fit( - group, padding=(0, 5), title="[b]Progress", border_style="blue" - ), - ) - ) - - def done(self, label: str, title: str): - """ - Resets the progress bar and table for the given label and prints the final result. - - Parameters - ---------- - label : str - The label of the progress bar and table to be reset. - title : str - The title to be displayed when printing the final result. - """ - for label in self.metrics: - for v in self.metrics[label]["mean"].values(): - v.reset() - - if self.rank == 0: - self.pbar.reset(self.tasks[label]["pbar"]) - tables = [t["table"] for t in self.tasks.values()] - group = Group(Markdown(f"# {title}"), *tables, self.pbar) - self.print(group) - - def track( - self, - label: str, - length: int, - completed: int = 0, - op: dist.ReduceOp = dist.ReduceOp.AVG, - ddp_active: bool = "LOCAL_RANK" in os.environ, - ): - """ - A decorator for tracking the progress and metrics of a function. - - Parameters - ---------- - label : str - The label to be associated with the progress and metrics. - length : int - The total number of iterations to be completed. - completed : int, optional - The number of iterations already completed, by default 0. - op : dist.ReduceOp, optional - The reduce operation to be used, by default dist.ReduceOp.AVG. - ddp_active : bool, optional - Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ. 
- """ - self.tasks[label] = { - "pbar": self.pbar.add_task( - f"[white]Iteration ({label})", total=length, completed=completed - ), - "table": Table(), - } - self.metrics[label] = { - "value": defaultdict(), - "mean": defaultdict(lambda: Mean()), - } - - def decorator(fn): - @wraps(fn) - def decorated(*args, **kwargs): - output = fn(*args, **kwargs) - if not isinstance(output, dict): - self.update(label, fn.__name__) - return output - # Collect across all DDP processes - scalar_keys = [] - for k, v in output.items(): - if isinstance(v, (int, float)): - v = torch.tensor([v]) - if not torch.is_tensor(v): - continue - if ddp_active and v.is_cuda: # pragma: no cover - dist.all_reduce(v, op=op) - output[k] = v.detach() - if torch.numel(v) == 1: - scalar_keys.append(k) - output[k] = v.item() - - # Save the outputs to tracker - for k, v in output.items(): - if k not in scalar_keys: - continue - self.metrics[label]["value"][k] = v - # Update the running mean - self.metrics[label]["mean"][k].update(v) - - self.update(label, fn.__name__) - return output - - return decorated - - return decorator - - def log(self, label: str, value_type: str = "value", history: bool = True): - """ - A decorator for logging the metrics of a function. - - Parameters - ---------- - label : str - The label to be associated with the logging. - value_type : str, optional - The type of value to be logged, by default "value". - history : bool, optional - Whether to save the history of the metrics, by default True. - """ - assert value_type in ["mean", "value"] - if history: - if label not in self.history: - self.history[label] = defaultdict(default_list) - - def decorator(fn): - @wraps(fn) - def decorated(*args, **kwargs): - output = fn(*args, **kwargs) - if self.rank == 0: - nonlocal value_type, label - metrics = self.metrics[label][value_type] - for k, v in metrics.items(): - v = v() if isinstance(v, Mean) else v - if self.writer is not None: - # self.writer.add_scalar(f"{k}/{label}", v, self.step) - self.writer.log_metric(f"{k}_{label}", v, step=self.step) - if label in self.history: - self.history[label][k].append(v) - - if label in self.history: - self.history[label]["step"].append(self.step) - - return output - - return decorated - - return decorator - - def is_best(self, label, key): - """ - Checks if the latest value of the given key in the label is the best so far. - - Parameters - ---------- - label : str - The label of the metrics to be checked. - key : str - The key of the metric to be checked. - - Returns - ------- - bool - True if the latest value is the best so far, otherwise False. - """ - return self.history[label][key][-1] == min(self.history[label][key]) - - def state_dict(self): - """ - Returns a dictionary containing the state of the tracker. - - Returns - ------- - dict - A dictionary containing the history and step of the tracker. - """ - return {"history": self.history, "step": self.step} - - def load_state_dict(self, state_dict): - """ - Loads the state of the tracker from the given state dictionary. - - Parameters - ---------- - state_dict : dict - A dictionary containing the history and step of the tracker. - - Returns - ------- - Tracker - The tracker object with the loaded state. 
- """ - self.history = state_dict["history"] - self.step = state_dict["step"] - return self diff --git a/dito/models/ldm/dac/audiotools/ml/experiment.py b/dito/models/ldm/dac/audiotools/ml/experiment.py deleted file mode 100644 index 62833d0f8f80dcdf496a1a5d2785ef666e0a15b6..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/ml/experiment.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -Useful class for Experiment tracking, and ensuring code is -saved alongside files. -""" # fmt: skip -import datetime -import os -import shlex -import shutil -import subprocess -import typing -from pathlib import Path - -import randomname - - -class Experiment: - """This class contains utilities for managing experiments. - It is a context manager, that when you enter it, changes - your directory to a specified experiment folder (which - optionally can have an automatically generated experiment - name, or a specified one), and changes the CUDA device used - to the specified device (or devices). - - Parameters - ---------- - exp_directory : str - Folder where all experiments are saved, by default "runs/". - exp_name : str, optional - Name of the experiment, by default uses the current time, date, and - hostname to save. - """ - - def __init__( - self, - exp_directory: str = "runs/", - exp_name: str = None, - ): - if exp_name is None: - exp_name = self.generate_exp_name() - exp_dir = Path(exp_directory) / exp_name - exp_dir.mkdir(parents=True, exist_ok=True) - - self.exp_dir = exp_dir - self.exp_name = exp_name - self.git_tracked_files = ( - subprocess.check_output( - shlex.split("git ls-tree --full-tree --name-only -r HEAD") - ) - .decode("utf-8") - .splitlines() - ) - self.parent_directory = Path(".").absolute() - - def __enter__(self): - self.prev_dir = os.getcwd() - os.chdir(self.exp_dir) - return self - - def __exit__(self, exc_type, exc_value, traceback): - os.chdir(self.prev_dir) - - @staticmethod - def generate_exp_name(): - """Generates a random experiment name based on the date - and a randomly generated adjective-noun tuple. - - Returns - ------- - str - Randomly generated experiment name. - """ - date = datetime.datetime.now().strftime("%y%m%d") - name = f"{date}-{randomname.get_name()}" - return name - - def snapshot(self, filter_fn: typing.Callable = lambda f: True): - """Captures a full snapshot of all the files tracked by git at the time - the experiment is run. It also captures the diff against the committed - code as a separate file. 
- - Parameters - ---------- - filter_fn : typing.Callable, optional - Function that can be used to exclude some files - from the snapshot, by default accepts all files - """ - for f in self.git_tracked_files: - if filter_fn(f): - Path(f).parent.mkdir(parents=True, exist_ok=True) - shutil.copyfile(self.parent_directory / f, f) diff --git a/dito/models/ldm/dac/audiotools/ml/layers/__init__.py b/dito/models/ldm/dac/audiotools/ml/layers/__init__.py deleted file mode 100644 index 92a016cab2ddf06bf5dadfae241b7e5d9def4878..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/ml/layers/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .base import BaseModel -from .spectral_gate import SpectralGate diff --git a/dito/models/ldm/dac/audiotools/ml/layers/base.py b/dito/models/ldm/dac/audiotools/ml/layers/base.py deleted file mode 100644 index b82c96cdd7336ca6b8ed6fc7f0192d69a8e998dd..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/ml/layers/base.py +++ /dev/null @@ -1,328 +0,0 @@ -import inspect -import shutil -import tempfile -import typing -from pathlib import Path - -import torch -from torch import nn - - -class BaseModel(nn.Module): - """This is a class that adds useful save/load functionality to a - ``torch.nn.Module`` object. ``BaseModel`` objects can be saved - as ``torch.package`` easily, making them super easy to port between - machines without requiring a ton of dependencies. Files can also be - saved as just weights, in the standard way. - - >>> class Model(ml.BaseModel): - >>> def __init__(self, arg1: float = 1.0): - >>> super().__init__() - >>> self.arg1 = arg1 - >>> self.linear = nn.Linear(1, 1) - >>> - >>> def forward(self, x): - >>> return self.linear(x) - >>> - >>> model1 = Model() - >>> - >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f: - >>> model1.save( - >>> f.name, - >>> ) - >>> model2 = Model.load(f.name) - >>> out2 = seed_and_run(model2, x) - >>> assert torch.allclose(out1, out2) - >>> - >>> model1.save(f.name, package=True) - >>> model2 = Model.load(f.name) - >>> model2.save(f.name, package=False) - >>> model3 = Model.load(f.name) - >>> out3 = seed_and_run(model3, x) - >>> - >>> with tempfile.TemporaryDirectory() as d: - >>> model1.save_to_folder(d, {"data": 1.0}) - >>> Model.load_from_folder(d) - - """ - - EXTERN = [ - "audiotools.**", - "tqdm", - "__main__", - "numpy.**", - "julius.**", - "torchaudio.**", - "scipy.**", - "einops", - ] - """Names of libraries that are external to the torch.package saving mechanism. - Source code from these libraries will not be packaged into the model. This can - be edited by the user of this class by editing ``model.EXTERN``.""" - INTERN = [] - """Names of libraries that are internal to the torch.package saving mechanism. - Source code from these libraries will be saved alongside the model.""" - - def save( - self, - path: str, - metadata: dict = None, - package: bool = True, - intern: list = [], - extern: list = [], - mock: list = [], - ): - """Saves the model, either as a torch package, or just as - weights, alongside some specified metadata. - - Parameters - ---------- - path : str - Path to save model to. 
- metadata : dict, optional - Any metadata to save alongside the model, - by default None - package : bool, optional - Whether to use ``torch.package`` to save the model in - a format that is portable, by default True - intern : list, optional - List of additional libraries that are internal - to the model, used with torch.package, by default [] - extern : list, optional - List of additional libraries that are external to - the model, used with torch.package, by default [] - mock : list, optional - List of libraries to mock, used with torch.package, - by default [] - - Returns - ------- - str - Path to saved model. - """ - sig = inspect.signature(self.__class__) - args = {} - - for key, val in sig.parameters.items(): - arg_val = val.default - if arg_val is not inspect.Parameter.empty: - args[key] = arg_val - - # Look up attributes in self, and if any of them are in args, - # overwrite them in args. - for attribute in dir(self): - if attribute in args: - args[attribute] = getattr(self, attribute) - - metadata = {} if metadata is None else metadata - metadata["kwargs"] = args - if not hasattr(self, "metadata"): - self.metadata = {} - self.metadata.update(metadata) - - if not package: - state_dict = {"state_dict": self.state_dict(), "metadata": metadata} - torch.save(state_dict, path) - else: - self._save_package(path, intern=intern, extern=extern, mock=mock) - - return path - - @property - def device(self): - """Gets the device the model is on by looking at the device of - the first parameter. May not be valid if model is split across - multiple devices. - """ - return list(self.parameters())[0].device - - @classmethod - def load( - cls, - location: str, - *args, - package_name: str = None, - strict: bool = False, - **kwargs, - ): - """Load model from a path. Tries first to load as a package, and if - that fails, tries to load as weights. The arguments to the class are - specified inside the model weights file. - - Parameters - ---------- - location : str - Path to file. - package_name : str, optional - Name of package, by default ``cls.__name__``. - strict : bool, optional - Whether to require that all keys match when loading - the state dict, by default False - kwargs : dict - Additional keyword arguments to the model instantiation, if - not loading from package. - - Returns - ------- - BaseModel - A model that inherits from BaseModel. - """ - try: - model = cls._load_package(location, package_name=package_name) - except: - model_dict = torch.load(location, "cpu") - metadata = model_dict["metadata"] - metadata["kwargs"].update(kwargs) - - sig = inspect.signature(cls) - class_keys = list(sig.parameters.keys()) - for k in list(metadata["kwargs"].keys()): - if k not in class_keys: - metadata["kwargs"].pop(k) - - model = cls(*args, **metadata["kwargs"]) - model.load_state_dict(model_dict["state_dict"], strict=strict) - model.metadata = metadata - - return model - - def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs): - package_name = type(self).__name__ - resource_name = f"{type(self).__name__}.pth" - - # Below is for loading and re-saving a package. - if hasattr(self, "importer"): - kwargs["importer"] = (self.importer, torch.package.sys_importer) - del self.importer - - # Why do we use a tempfile, you ask? - # It's so we can load a packaged model and then re-save - # it to the same location. torch.package throws an - # error if it's loading and writing to the same - # file (this is undocumented).
- with tempfile.NamedTemporaryFile(suffix=".pth") as f: - with torch.package.PackageExporter(f.name, **kwargs) as exp: - exp.intern(self.INTERN + intern) - exp.mock(mock) - exp.extern(self.EXTERN + extern) - exp.save_pickle(package_name, resource_name, self) - - if hasattr(self, "metadata"): - exp.save_pickle( - package_name, f"{package_name}.metadata", self.metadata - ) - - shutil.copyfile(f.name, path) - - # Must reset the importer back to `self` if it existed - # so that you can save the model again! - if "importer" in kwargs: - self.importer = kwargs["importer"][0] - return path - - @classmethod - def _load_package(cls, path, package_name=None): - package_name = cls.__name__ if package_name is None else package_name - resource_name = f"{package_name}.pth" - - imp = torch.package.PackageImporter(path) - model = imp.load_pickle(package_name, resource_name, "cpu") - try: - model.metadata = imp.load_pickle(package_name, f"{package_name}.metadata") - except: # pragma: no cover - pass - model.importer = imp - - return model - - def save_to_folder( - self, - folder: typing.Union[str, Path], - extra_data: dict = None, - package: bool = True, - ): - """Dumps a model into a folder, as both a package - and as weights, as well as anything specified in - ``extra_data``. ``extra_data`` is a dictionary of other - pickleable files, with the keys being the paths - to save them in. The model is saved under a subfolder - specified by the name of the class (e.g. ``folder/generator/[package, weights].pth`` - if the model name was ``Generator``). - - >>> with tempfile.TemporaryDirectory() as d: - >>> extra_data = { - >>> "optimizer.pth": optimizer.state_dict() - >>> } - >>> model.save_to_folder(d, extra_data) - >>> Model.load_from_folder(d) - - Parameters - ---------- - folder : typing.Union[str, Path] - Folder to dump the model into. - extra_data : dict, optional - Dictionary of additional pickleable objects to save next to - the weights, keyed by file name, by default None - - Returns - ------- - Path - Path to the subfolder the model was saved under. - """ - extra_data = {} if extra_data is None else extra_data - model_name = type(self).__name__.lower() - target_base = Path(f"{folder}/{model_name}/") - target_base.mkdir(exist_ok=True, parents=True) - - if package: - package_path = target_base / "package.pth" - self.save(package_path) - - weights_path = target_base / "weights.pth" - self.save(weights_path, package=False) - - for path, obj in extra_data.items(): - torch.save(obj, target_base / path) - - return target_base - - @classmethod - def load_from_folder( - cls, - folder: typing.Union[str, Path], - package: bool = True, - strict: bool = False, - **kwargs, - ): - """Loads the model from a folder generated by - :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`. - Like that function, this one looks for a subfolder that has - the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the - model name was ``Generator``). - - Parameters - ---------- - folder : typing.Union[str, Path] - Folder that the model was saved to via save_to_folder. - package : bool, optional - Whether to use ``torch.package`` to load the model, - loading the model from ``package.pth``. - strict : bool, optional - Whether to require that all keys match when loading - the state dict, by default False - - Returns - ------- - tuple - tuple of model and extra data as saved by - :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
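A round trip through the two folder helpers looks like this; extra_data rides along as separate files next to the weights. A sketch (the model class is illustrative; package=False sidesteps torch.package for classes defined in a script):

import tempfile

import torch
from torch import nn
from audiotools import ml

class Generator(ml.BaseModel):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.dim = dim
        self.net = nn.Linear(dim, dim)

    def forward(self, x):
        return self.net(x)

model = Generator()
opt = torch.optim.Adam(model.parameters())

with tempfile.TemporaryDirectory() as d:
    # Writes d/generator/weights.pth and d/generator/optimizer.pth
    model.save_to_folder(d, {"optimizer.pth": opt.state_dict()}, package=False)
    model2, extra = Generator.load_from_folder(d, package=False)
    opt2 = torch.optim.Adam(model2.parameters())
    opt2.load_state_dict(extra["optimizer.pth"])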
- """ - folder = Path(folder) / cls.__name__.lower() - model_pth = "package.pth" if package else "weights.pth" - model_pth = folder / model_pth - - model = cls.load(model_pth, strict=strict) - extra_data = {} - excluded = ["package.pth", "weights.pth"] - files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded] - for f in files: - extra_data[f.name] = torch.load(f, **kwargs) - - return model, extra_data diff --git a/dito/models/ldm/dac/audiotools/ml/layers/spectral_gate.py b/dito/models/ldm/dac/audiotools/ml/layers/spectral_gate.py deleted file mode 100644 index c4ae8b5eab2e56ce13541695f52a11a454759dae..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/ml/layers/spectral_gate.py +++ /dev/null @@ -1,127 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn - -from ...core import AudioSignal -from ...core import STFTParams -from ...core import util - - -class SpectralGate(nn.Module): - """Spectral gating algorithm for noise reduction, - as in Audacity/Ocenaudio. The steps are as follows: - - 1. An FFT is calculated over the noise audio clip - 2. Statistics are calculated over FFT of the the noise - (in frequency) - 3. A threshold is calculated based upon the statistics - of the noise (and the desired sensitivity of the algorithm) - 4. An FFT is calculated over the signal - 5. A mask is determined by comparing the signal FFT to the - threshold - 6. The mask is smoothed with a filter over frequency and time - 7. The mask is appled to the FFT of the signal, and is inverted - - Implementation inspired by Tim Sainburg's noisereduce: - - https://timsainburg.com/noise-reduction-python.html - - Parameters - ---------- - n_freq : int, optional - Number of frequency bins to smooth by, by default 3 - n_time : int, optional - Number of time bins to smooth by, by default 5 - """ - - def __init__(self, n_freq: int = 3, n_time: int = 5): - super().__init__() - - smoothing_filter = torch.outer( - torch.cat( - [ - torch.linspace(0, 1, n_freq + 2)[:-1], - torch.linspace(1, 0, n_freq + 2), - ] - )[..., 1:-1], - torch.cat( - [ - torch.linspace(0, 1, n_time + 2)[:-1], - torch.linspace(1, 0, n_time + 2), - ] - )[..., 1:-1], - ) - smoothing_filter = smoothing_filter / smoothing_filter.sum() - smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0) - self.register_buffer("smoothing_filter", smoothing_filter) - - def forward( - self, - audio_signal: AudioSignal, - nz_signal: AudioSignal, - denoise_amount: float = 1.0, - n_std: float = 3.0, - win_length: int = 2048, - hop_length: int = 512, - ): - """Perform noise reduction. - - Parameters - ---------- - audio_signal : AudioSignal - Audio signal that noise will be removed from. - nz_signal : AudioSignal, optional - Noise signal to compute noise statistics from. - denoise_amount : float, optional - Amount to denoise by, by default 1.0 - n_std : float, optional - Number of standard deviations above which to consider - noise, by default 3.0 - win_length : int, optional - Length of window for STFT, by default 2048 - hop_length : int, optional - Hop length for STFT, by default 512 - - Returns - ------- - AudioSignal - Denoised audio signal. 
- """ - stft_params = STFTParams(win_length, hop_length, "sqrt_hann") - - audio_signal = audio_signal.clone() - audio_signal.stft_data = None - audio_signal.stft_params = stft_params - - nz_signal = nz_signal.clone() - nz_signal.stft_params = stft_params - - nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10() - nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1) - nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1) - - nz_thresh = nz_freq_mean + nz_freq_std * n_std - - stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10() - nb, nac, nf, nt = stft_db.shape - db_thresh = nz_thresh.expand(nb, nac, -1, nt) - - stft_mask = (stft_db < db_thresh).float() - shape = stft_mask.shape - - stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt) - pad_tuple = ( - self.smoothing_filter.shape[-2] // 2, - self.smoothing_filter.shape[-1] // 2, - ) - stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple) - stft_mask = stft_mask.reshape(*shape) - stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim).to( - audio_signal.device - ) - stft_mask = 1 - stft_mask - - audio_signal.stft_data *= stft_mask - audio_signal.istft() - - return audio_signal diff --git a/dito/models/ldm/dac/audiotools/post.py b/dito/models/ldm/dac/audiotools/post.py deleted file mode 100644 index 6ced2d1e66a4ffda3269685bd45593b01038739f..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/post.py +++ /dev/null @@ -1,140 +0,0 @@ -import tempfile -import typing -import zipfile -from pathlib import Path - -import markdown2 as md -import matplotlib.pyplot as plt -import torch -from IPython.display import HTML - - -def audio_table( - audio_dict: dict, - first_column: str = None, - format_fn: typing.Callable = None, - **kwargs, -): # pragma: no cover - """Embeds an audio table into HTML, or as the output cell - in a notebook. - - Parameters - ---------- - audio_dict : dict - Dictionary of data to embed. - first_column : str, optional - The label for the first column of the table, by default None - format_fn : typing.Callable, optional - How to format the data, by default None - - Returns - ------- - str - Table as a string - - Examples - -------- - - >>> audio_dict = {} - >>> for i in range(signal_batch.batch_size): - >>> audio_dict[i] = { - >>> "input": signal_batch[i], - >>> "output": output_batch[i] - >>> } - >>> audiotools.post.audio_zip(audio_dict) - - """ - from audiotools import AudioSignal - - output = [] - columns = None - - def _default_format_fn(label, x, **kwargs): - if torch.is_tensor(x): - x = x.tolist() - - if x is None: - return "." - elif isinstance(x, AudioSignal): - return x.embed(display=False, return_html=True, **kwargs) - else: - return str(x) - - if format_fn is None: - format_fn = _default_format_fn - - if first_column is None: - first_column = "." - - for k, v in audio_dict.items(): - if not isinstance(v, dict): - v = {"Audio": v} - - v_keys = list(v.keys()) - if columns is None: - columns = [first_column] + v_keys - output.append(" | ".join(columns)) - - layout = "|---" + len(v_keys) * "|:-:" - output.append(layout) - - formatted_audio = [] - for col in columns[1:]: - formatted_audio.append(format_fn(col, v[col], **kwargs)) - - row = f"| {k} | " - row += " | ".join(formatted_audio) - output.append(row) - - output = "\n" + "\n".join(output) - return output - - -def in_notebook(): # pragma: no cover - """Determines if code is running in a notebook. - - Returns - ------- - bool - Whether or not this is running in a notebook. 
- """ - try: - from IPython import get_ipython - - if "IPKernelApp" not in get_ipython().config: # pragma: no cover - return False - except ImportError: - return False - except AttributeError: - return False - return True - - -def disp(obj, **kwargs): # pragma: no cover - """Displays an object, depending on if its in a notebook - or not. - - Parameters - ---------- - obj : typing.Any - Any object to display. - - """ - from audiotools import AudioSignal - - IN_NOTEBOOK = in_notebook() - - if isinstance(obj, AudioSignal): - audio_elem = obj.embed(display=False, return_html=True) - if IN_NOTEBOOK: - return HTML(audio_elem) - else: - print(audio_elem) - if isinstance(obj, dict): - table = audio_table(obj, **kwargs) - if IN_NOTEBOOK: - return HTML(md.markdown(table, extras=["tables"])) - else: - print(table) - if isinstance(obj, plt.Figure): - plt.show() diff --git a/dito/models/ldm/dac/audiotools/preference.py b/dito/models/ldm/dac/audiotools/preference.py deleted file mode 100644 index 800a852e8119dd18ea65784cf95182de2470fbc4..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/audiotools/preference.py +++ /dev/null @@ -1,600 +0,0 @@ -############################################################## -### Tools for creating preference tests (MUSHRA, ABX, etc) ### -############################################################## -import copy -import csv -import random -import sys -import traceback -from collections import defaultdict -from pathlib import Path -from typing import List - -import gradio as gr - -from audiotools.core.util import find_audio - -################################################################ -### Logic for audio player, and adding audio / play buttons. ### -################################################################ - -WAVESURFER = """
""" - -CUSTOM_CSS = """ -.gradio-container { - max-width: 840px !important; -} -region.wavesurfer-region:before { - content: attr(data-region-label); -} - -block { - min-width: 0 !important; -} - -#wave-timeline { - background-color: rgba(0, 0, 0, 0.8); -} - -.head.svelte-1cl284s { - display: none; -} -""" - -load_wavesurfer_js = """ -function load_wavesurfer() { - function load_script(url) { - const script = document.createElement('script'); - script.src = url; - document.body.appendChild(script); - - return new Promise((res, rej) => { - script.onload = function() { - res(); - } - script.onerror = function () { - rej(); - } - }); - } - - function create_wavesurfer() { - var options = { - container: '#waveform', - waveColor: '#F2F2F2', // Set a darker wave color - progressColor: 'white', // Set a slightly lighter progress color - loaderColor: 'white', // Set a slightly lighter loader color - cursorColor: 'black', // Set a slightly lighter cursor color - backgroundColor: '#00AAFF', // Set a black background color - barWidth: 4, - barRadius: 3, - barHeight: 1, // the height of the wave - plugins: [ - WaveSurfer.regions.create({ - regionsMinLength: 0.0, - dragSelection: { - slop: 5 - }, - color: 'hsla(200, 50%, 70%, 0.4)', - }), - WaveSurfer.timeline.create({ - container: "#wave-timeline", - primaryLabelInterval: 5.0, - secondaryLabelInterval: 1.0, - primaryFontColor: '#F2F2F2', - secondaryFontColor: '#F2F2F2', - }), - ] - }; - wavesurfer = WaveSurfer.create(options); - wavesurfer.on('region-created', region => { - wavesurfer.regions.clear(); - }); - wavesurfer.on('finish', function () { - var loop = document.getElementById("loop-button").textContent.includes("ON"); - if (loop) { - wavesurfer.play(); - } - else { - var button_elements = document.getElementsByClassName('playpause') - var buttons = Array.from(button_elements); - - for (let j = 0; j < buttons.length; j++) { - buttons[j].classList.remove("primary"); - buttons[j].classList.add("secondary"); - buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") - } - } - }); - - wavesurfer.on('region-out', function () { - var loop = document.getElementById("loop-button").textContent.includes("ON"); - if (!loop) { - var button_elements = document.getElementsByClassName('playpause') - var buttons = Array.from(button_elements); - - for (let j = 0; j < buttons.length; j++) { - buttons[j].classList.remove("primary"); - buttons[j].classList.add("secondary"); - buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") - } - wavesurfer.pause(); - } - }); - - console.log("Created WaveSurfer object.") - } - - load_script('https://unpkg.com/wavesurfer.js@6.6.4') - .then(() => { - load_script("https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.timeline.min.js") - .then(() => { - load_script('https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.regions.min.js') - .then(() => { - console.log("Loaded regions"); - create_wavesurfer(); - document.getElementById("start-survey").click(); - }) - }) - }); -} -""" - -play = lambda i: """ -function play() { - var audio_elements = document.getElementsByTagName('audio'); - var button_elements = document.getElementsByClassName('playpause') - - var audio_array = Array.from(audio_elements); - var buttons = Array.from(button_elements); - - var src_link = audio_array[{i}].getAttribute("src"); - console.log(src_link); - - var loop = document.getElementById("loop-button").textContent.includes("ON"); - var playing = buttons[{i}].textContent.includes("Stop"); - - for (let j = 0; j 
< buttons.length; j++) { - if (j != {i} || playing) { - buttons[j].classList.remove("primary"); - buttons[j].classList.add("secondary"); - buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") - } - else { - buttons[j].classList.remove("secondary"); - buttons[j].classList.add("primary"); - buttons[j].textContent = buttons[j].textContent.replace("Play", "Stop") - } - } - - if (playing) { - wavesurfer.pause(); - wavesurfer.seekTo(0.0); - } - else { - wavesurfer.load(src_link); - wavesurfer.on('ready', function () { - var region = Object.values(wavesurfer.regions.list)[0]; - - if (region != null) { - region.loop = loop; - region.play(); - } else { - wavesurfer.play(); - } - }); - } -} -""".replace( - "{i}", str(i) -) - -clear_regions = """ -function clear_regions() { - wavesurfer.clearRegions(); -} -""" - -reset_player = """ -function reset_player() { - wavesurfer.clearRegions(); - wavesurfer.pause(); - wavesurfer.seekTo(0.0); - - var button_elements = document.getElementsByClassName('playpause') - var buttons = Array.from(button_elements); - - for (let j = 0; j < buttons.length; j++) { - buttons[j].classList.remove("primary"); - buttons[j].classList.add("secondary"); - buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") - } -} -""" - -loop_region = """ -function loop_region() { - var element = document.getElementById("loop-button"); - var loop = element.textContent.includes("OFF"); - console.log(loop); - - try { - var region = Object.values(wavesurfer.regions.list)[0]; - region.loop = loop; - } catch {} - - if (loop) { - element.classList.remove("secondary"); - element.classList.add("primary"); - element.textContent = "Looping ON"; - } else { - element.classList.remove("primary"); - element.classList.add("secondary"); - element.textContent = "Looping OFF"; - } -} -""" - - -class Player: - def __init__(self, app): - self.app = app - - self.app.load(_js=load_wavesurfer_js) - self.app.css = CUSTOM_CSS - - self.wavs = [] - self.position = 0 - - def create(self): - gr.HTML(WAVESURFER) - gr.Markdown( - "Click and drag on the waveform above to select a region for playback. " - "Once created, the region can be moved around and resized. " - "Clear the regions using the button below. Hit play on one of the buttons below to start!" - ) - - with gr.Row(): - clear = gr.Button("Clear region") - loop = gr.Button("Looping OFF", elem_id="loop-button") - - loop.click(None, _js=loop_region) - clear.click(None, _js=clear_regions) - - gr.HTML("
") - - def add(self, name: str = "Play"): - i = self.position - self.wavs.append( - { - "audio": gr.Audio(visible=False), - "button": gr.Button(name, elem_classes=["playpause"]), - "position": i, - } - ) - self.wavs[-1]["button"].click(None, _js=play(i)) - self.position += 1 - return self.wavs[-1] - - def to_list(self): - return [x["audio"] for x in self.wavs] - - -############################################################ -### Keeping track of users, and CSS for the progress bar ### -############################################################ - -load_tracker = lambda name: """ -function load_name() { - function setCookie(name, value, exp_days) { - var d = new Date(); - d.setTime(d.getTime() + (exp_days*24*60*60*1000)); - var expires = "expires=" + d.toGMTString(); - document.cookie = name + "=" + value + ";" + expires + ";path=/"; - } - - function getCookie(name) { - var cname = name + "="; - var decodedCookie = decodeURIComponent(document.cookie); - var ca = decodedCookie.split(';'); - for(var i = 0; i < ca.length; i++){ - var c = ca[i]; - while(c.charAt(0) == ' '){ - c = c.substring(1); - } - if(c.indexOf(cname) == 0){ - return c.substring(cname.length, c.length); - } - } - return ""; - } - - name = getCookie("{name}"); - if (name == "") { - name = Math.random().toString(36).slice(2); - console.log(name); - setCookie("name", name, 30); - } - name = getCookie("{name}"); - return name; -} -""".replace( - "{name}", name -) - -# Progress bar - -progress_template = """ - - - - Progress Bar - - - -
<div class="progress-container">
  <div class="progress-bar" style="width: {PROGRESS}%">
    {TEXT}
  </div>
</div>
- - -""" - - -def create_tracker(app, cookie_name="name"): - user = gr.Text(label="user", interactive=True, visible=False, elem_id="user") - app.load(_js=load_tracker(cookie_name), outputs=user) - return user - - -################################################################# -### CSS and HTML for labeling sliders for both ABX and MUSHRA ### -################################################################# - -slider_abx = """ - - - - - Labels Example - - - -
<div class="slider-labels">
  <div class="label">Prefer A</div>
  <div class="label">Toss-up</div>
  <div class="label">Prefer B</div>
</div>
- - -""" - -slider_mushra = """ - - - - - Labels Example - - - -
<div class="slider-labels">
  <div class="label">bad</div>
  <div class="label">poor</div>
  <div class="label">fair</div>
  <div class="label">good</div>
  <div class="label">excellent</div>
</div>
- - -""" - -######################################################### -### Handling loading audio and tracking session state ### -######################################################### - - -class Samples: - def __init__(self, folder: str, shuffle: bool = True, n_samples: int = None): - files = find_audio(folder) - samples = defaultdict(lambda: defaultdict()) - - for f in files: - condition = f.parent.stem - samples[f.name][condition] = f - - self.samples = samples - self.names = list(samples.keys()) - self.filtered = False - self.current = 0 - - if shuffle: - random.shuffle(self.names) - - self.n_samples = len(self.names) if n_samples is None else n_samples - - def get_updates(self, idx, order): - key = self.names[idx] - return [gr.update(value=str(self.samples[key][o])) for o in order] - - def progress(self): - try: - pct = self.current / len(self) * 100 - except: # pragma: no cover - pct = 100 - text = f"On {self.current} / {len(self)} samples" - pbar = ( - copy.copy(progress_template) - .replace("{PROGRESS}", str(pct)) - .replace("{TEXT}", str(text)) - ) - return gr.update(value=pbar) - - def __len__(self): - return self.n_samples - - def filter_completed(self, user, save_path): - if not self.filtered: - done = [] - if Path(save_path).exists(): - with open(save_path, "r") as f: - reader = csv.DictReader(f) - done = [r["sample"] for r in reader if r["user"] == user] - self.names = [k for k in self.names if k not in done] - self.names = self.names[: self.n_samples] - self.filtered = True # Avoid filtering more than once per session. - - def get_next_sample(self, reference, conditions): - random.shuffle(conditions) - if reference is not None: - self.order = [reference] + conditions - else: - self.order = conditions - - try: - updates = self.get_updates(self.current, self.order) - self.current += 1 - done = gr.update(interactive=True) - pbar = self.progress() - except: - traceback.print_exc() - updates = [gr.update() for _ in range(len(self.order))] - done = gr.update(value="No more samples!", interactive=False) - self.current = len(self) - pbar = self.progress() - - return updates, done, pbar - - -def save_result(result, save_path): - with open(save_path, mode="a", newline="") as file: - writer = csv.DictWriter(file, fieldnames=sorted(list(result.keys()))) - if file.tell() == 0: - writer.writeheader() - writer.writerow(result) diff --git a/dito/models/ldm/dac/base.py b/dito/models/ldm/dac/base.py deleted file mode 100644 index ede7e8d87f4ec6ceedc94a4d2b9d75217adfe8fe..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/base.py +++ /dev/null @@ -1,294 +0,0 @@ -import math -from dataclasses import dataclass -from pathlib import Path -from typing import Union - -import numpy as np -import torch -import tqdm -from audiotools import AudioSignal -from torch import nn - -SUPPORTED_VERSIONS = ["1.0.0"] - - -@dataclass -class DACFile: - codes: torch.Tensor - - # Metadata - chunk_length: int - original_length: int - input_db: float - channels: int - sample_rate: int - padding: bool - dac_version: str - - def save(self, path): - artifacts = { - "codes": self.codes.numpy().astype(np.uint16), - "metadata": { - "input_db": self.input_db.numpy().astype(np.float32), - "original_length": self.original_length, - "sample_rate": self.sample_rate, - "chunk_length": self.chunk_length, - "channels": self.channels, - "padding": self.padding, - "dac_version": SUPPORTED_VERSIONS[-1], - }, - } - path = Path(path).with_suffix(".dac") - with open(path, "wb") as f: - np.save(f, artifacts) - return 
path - - @classmethod - def load(cls, path): - artifacts = np.load(path, allow_pickle=True)[()] - codes = torch.from_numpy(artifacts["codes"].astype(int)) - if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: - raise RuntimeError( - f"Given file {path} can't be loaded with this version of descript-audio-codec." - ) - return cls(codes=codes, **artifacts["metadata"]) - - -class CodecMixin: - @property - def padding(self): - if not hasattr(self, "_padding"): - self._padding = True - return self._padding - - @padding.setter - def padding(self, value): - assert isinstance(value, bool) - - layers = [ - l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) - ] - - for layer in layers: - if value: - if hasattr(layer, "original_padding"): - layer.padding = layer.original_padding - else: - layer.original_padding = layer.padding - layer.padding = tuple(0 for _ in range(len(layer.padding))) - - self._padding = value - - def get_delay(self): - # Any number works here, delay is invariant to input length - l_out = self.get_output_length(0) - L = l_out - - layers = [] - for layer in self.modules(): - if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): - layers.append(layer) - - for layer in reversed(layers): - d = layer.dilation[0] - k = layer.kernel_size[0] - s = layer.stride[0] - - if isinstance(layer, nn.ConvTranspose1d): - L = ((L - d * (k - 1) - 1) / s) + 1 - elif isinstance(layer, nn.Conv1d): - L = (L - 1) * s + d * (k - 1) + 1 - - L = math.ceil(L) - - l_in = L - - return (l_in - l_out) // 2 - - def get_output_length(self, input_length): - L = input_length - # Calculate output length - for layer in self.modules(): - if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): - d = layer.dilation[0] - k = layer.kernel_size[0] - s = layer.stride[0] - - if isinstance(layer, nn.Conv1d): - L = ((L - d * (k - 1) - 1) / s) + 1 - elif isinstance(layer, nn.ConvTranspose1d): - L = (L - 1) * s + d * (k - 1) + 1 - - L = math.floor(L) - return L - - @torch.no_grad() - def compress( - self, - audio_path_or_signal: Union[str, Path, AudioSignal], - win_duration: float = 1.0, - verbose: bool = False, - normalize_db: float = -16, - n_quantizers: int = None, - ) -> DACFile: - """Processes an audio signal from a file or AudioSignal object into - discrete codes. This function processes the signal in short windows, - using constant GPU memory. 
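        Signals longer than win_duration are zero-padded by the codec delay
        and encoded chunk by chunk at hop-aligned offsets; shorter signals
        are padded and encoded in a single pass.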
- - Parameters - ---------- - audio_path_or_signal : Union[str, Path, AudioSignal] - audio signal to reconstruct - win_duration : float, optional - window duration in seconds, by default 5.0 - verbose : bool, optional - by default False - normalize_db : float, optional - normalize db, by default -16 - - Returns - ------- - DACFile - Object containing compressed codes and metadata - required for decompression - """ - audio_signal = audio_path_or_signal - if isinstance(audio_signal, (str, Path)): - audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) - - self.eval() - original_padding = self.padding - original_device = audio_signal.device - - audio_signal = audio_signal.clone() - original_sr = audio_signal.sample_rate - - resample_fn = audio_signal.resample - loudness_fn = audio_signal.loudness - - # If audio is > 10 minutes long, use the ffmpeg versions - if audio_signal.signal_duration >= 10 * 60 * 60: - resample_fn = audio_signal.ffmpeg_resample - loudness_fn = audio_signal.ffmpeg_loudness - - original_length = audio_signal.signal_length - resample_fn(self.sample_rate) - input_db = loudness_fn() - - if normalize_db is not None: - audio_signal.normalize(normalize_db) - audio_signal.ensure_max_of_audio() - - nb, nac, nt = audio_signal.audio_data.shape - audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) - win_duration = ( - audio_signal.signal_duration if win_duration is None else win_duration - ) - - if audio_signal.signal_duration <= win_duration: - # Unchunked compression (used if signal length < win duration) - self.padding = True - n_samples = nt - hop = nt - else: - # Chunked inference - self.padding = False - # Zero-pad signal on either side by the delay - audio_signal.zero_pad(self.delay, self.delay) - n_samples = int(win_duration * self.sample_rate) - # Round n_samples to nearest hop length multiple - n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) - hop = self.get_output_length(n_samples) - - codes = [] - range_fn = range if not verbose else tqdm.trange - - for i in range_fn(0, nt, hop): - x = audio_signal[..., i : i + n_samples] - x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) - - audio_data = x.audio_data.to(self.device) - audio_data = self.preprocess(audio_data, self.sample_rate) - _, c, _, _, _ = self.encode(audio_data, n_quantizers) - codes.append(c.to(original_device)) - chunk_length = c.shape[-1] - - codes = torch.cat(codes, dim=-1) - - dac_file = DACFile( - codes=codes, - chunk_length=chunk_length, - original_length=original_length, - input_db=input_db, - channels=nac, - sample_rate=original_sr, - padding=self.padding, - dac_version=SUPPORTED_VERSIONS[-1], - ) - - if n_quantizers is not None: - codes = codes[:, :n_quantizers, :] - - self.padding = original_padding - return dac_file - - @torch.no_grad() - def decompress( - self, - obj: Union[str, Path, DACFile], - verbose: bool = False, - ) -> AudioSignal: - """Reconstruct audio from a given .dac file - - Parameters - ---------- - obj : Union[str, Path, DACFile] - .dac file location or corresponding DACFile object. 
- verbose : bool, optional - Prints progress if True, by default False - - Returns - ------- - AudioSignal - Object with the reconstructed audio - """ - self.eval() - if isinstance(obj, (str, Path)): - obj = DACFile.load(obj) - - original_padding = self.padding - self.padding = obj.padding - - range_fn = range if not verbose else tqdm.trange - codes = obj.codes - original_device = codes.device - chunk_length = obj.chunk_length - recons = [] - - for i in range_fn(0, codes.shape[-1], chunk_length): - c = codes[..., i : i + chunk_length].to(self.device) - z = self.quantizer.from_codes(c)[0] - r = self.decode(z) - recons.append(r.to(original_device)) - - recons = torch.cat(recons, dim=-1) - recons = AudioSignal(recons, self.sample_rate) - - resample_fn = recons.resample - loudness_fn = recons.loudness - - # If audio is > 10 minutes long, use the ffmpeg versions - if recons.signal_duration >= 10 * 60 * 60: - resample_fn = recons.ffmpeg_resample - loudness_fn = recons.ffmpeg_loudness - - recons.normalize(obj.input_db) - resample_fn(obj.sample_rate) - recons = recons[..., : obj.original_length] - loudness_fn() - recons.audio_data = recons.audio_data.reshape( - -1, obj.channels, obj.original_length - ) - - self.padding = original_padding - return recons \ No newline at end of file diff --git a/dito/models/ldm/dac/layers.py b/dito/models/ldm/dac/layers.py deleted file mode 100644 index a0cc6fb4021f2b34d2a1c9cee151a8576a8e5285..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/layers.py +++ /dev/null @@ -1,80 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch.nn.utils import weight_norm - - -def WNConv1d(*args, **kwargs): - return weight_norm(nn.Conv1d(*args, **kwargs)) - - -def WNConvTranspose1d(*args, **kwargs): - return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) - - -# Scripting this brings model speed up 1.4x -@torch.jit.script -def snake(x, alpha): - shape = x.shape - x = x.reshape(shape[0], shape[1], -1) - x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) - x = x.reshape(shape) - return x - - -class Snake1d(nn.Module): - def __init__(self, channels): - super().__init__() - self.alpha = nn.Parameter(torch.ones(1, channels, 1)) - - def forward(self, x): - return snake(x, self.alpha) - -def snake_beta(x, alpha, beta): - return x + (1.0 / (beta + 0.000000001)) * torch.pow(torch.sin(x * alpha), 2) -# License available in LICENSES/LICENSE_NVIDIA.txt -class SnakeBeta(nn.Module): - - def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): - super(SnakeBeta, self).__init__() - self.in_features = in_features - - # initialize alpha - self.alpha_logscale = alpha_logscale - if self.alpha_logscale: # log scale alphas initialized to zeros - self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) - self.beta = nn.Parameter(torch.zeros(in_features) * alpha) - else: # linear scale alphas initialized to ones - self.alpha = nn.Parameter(torch.ones(in_features) * alpha) - self.beta = nn.Parameter(torch.ones(in_features) * alpha) - - self.alpha.requires_grad = alpha_trainable - self.beta.requires_grad = alpha_trainable - - self.no_div_by_zero = 0.000000001 - - def forward(self, x): - alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] - beta = self.beta.unsqueeze(0).unsqueeze(-1) - if self.alpha_logscale: - alpha = torch.exp(alpha) - beta = torch.exp(beta) - x = snake_beta(x, alpha, beta) - - return x - -def 
get_activation(activation, channels, alpha): - if activation == "snake": - return Snake1d(channels) - elif activation == "relu": - return nn.ReLU() - elif activation == "leaky_relu": - return nn.LeakyReLU() - elif activation == "tanh": - return nn.Tanh() - elif activation == "snake_beta": - return SnakeBeta(channels, alpha) - else: - raise ValueError(f"Activation {activation} not supported") \ No newline at end of file diff --git a/dito/models/ldm/dac/loss.py b/dito/models/ldm/dac/loss.py deleted file mode 100644 index 2a5fc6f38ea44ce666522ba96ec24751a3e4f1ee..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/loss.py +++ /dev/null @@ -1,374 +0,0 @@ -import typing -from typing import List - -import torch -import torch.nn.functional as F -from audiotools import AudioSignal -from audiotools import STFTParams -from torch import nn - - -class L1Loss(nn.L1Loss): - """L1 Loss between AudioSignals. Defaults - to comparing ``audio_data``, but any - attribute of an AudioSignal can be used. - - Parameters - ---------- - attribute : str, optional - Attribute of signal to compare, defaults to ``audio_data``. - weight : float, optional - Weight of this loss, defaults to 1.0. - - Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py - """ - - def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): - self.attribute = attribute - self.weight = weight - super().__init__(**kwargs) - - def forward(self, x: AudioSignal, y: AudioSignal): - """ - Parameters - ---------- - x : AudioSignal - Estimate AudioSignal - y : AudioSignal - Reference AudioSignal - - Returns - ------- - torch.Tensor - L1 loss between AudioSignal attributes. - """ - if isinstance(x, AudioSignal): - x = getattr(x, self.attribute) - y = getattr(y, self.attribute) - return super().forward(x, y) - - -class SISDRLoss(nn.Module): - """ - Computes the Scale-Invariant Source-to-Distortion Ratio between a batch - of estimated and reference audio signals or aligned features. - - Parameters - ---------- - scaling : int, optional - Whether to use scale-invariant (True) or - signal-to-noise ratio (False), by default True - reduction : str, optional - How to reduce across the batch (either 'mean', - 'sum', or none).], by default ' mean' - zero_mean : int, optional - Zero mean the references and estimates before - computing the loss, by default True - clip_min : int, optional - The minimum possible loss value. Helps network - to not focus on making already good examples better, by default None - weight : float, optional - Weight of this loss, defaults to 1.0. 
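    Note that the returned value is the negative SI-SDR
    (-10 * log10(signal / noise)), so it can be minimized directly as a loss.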
- - Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py - """ - - def __init__( - self, - scaling: int = True, - reduction: str = "mean", - zero_mean: int = True, - clip_min: int = None, - weight: float = 1.0, - ): - self.scaling = scaling - self.reduction = reduction - self.zero_mean = zero_mean - self.clip_min = clip_min - self.weight = weight - super().__init__() - - def forward(self, x: AudioSignal, y: AudioSignal): - eps = 1e-8 - # nb, nc, nt - if isinstance(x, AudioSignal): - references = x.audio_data - estimates = y.audio_data - else: - references = x - estimates = y - - nb = references.shape[0] - references = references.reshape(nb, 1, -1).permute(0, 2, 1) - estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) - - # samples now on axis 1 - if self.zero_mean: - mean_reference = references.mean(dim=1, keepdim=True) - mean_estimate = estimates.mean(dim=1, keepdim=True) - else: - mean_reference = 0 - mean_estimate = 0 - - _references = references - mean_reference - _estimates = estimates - mean_estimate - - references_projection = (_references**2).sum(dim=-2) + eps - references_on_estimates = (_estimates * _references).sum(dim=-2) + eps - - scale = ( - (references_on_estimates / references_projection).unsqueeze(1) - if self.scaling - else 1 - ) - - e_true = scale * _references - e_res = _estimates - e_true - - signal = (e_true**2).sum(dim=1) - noise = (e_res**2).sum(dim=1) - sdr = -10 * torch.log10(signal / noise + eps) - - if self.clip_min is not None: - sdr = torch.clamp(sdr, min=self.clip_min) - - if self.reduction == "mean": - sdr = sdr.mean() - elif self.reduction == "sum": - sdr = sdr.sum() - return sdr - - -class MultiScaleSTFTLoss(nn.Module): - """Computes the multi-scale STFT loss from [1]. - - Parameters - ---------- - window_lengths : List[int], optional - Length of each window of each STFT, by default [2048, 512] - loss_fn : typing.Callable, optional - How to compare each loss, by default nn.L1Loss() - clamp_eps : float, optional - Clamp on the log magnitude, below, by default 1e-5 - mag_weight : float, optional - Weight of raw magnitude portion of loss, by default 1.0 - log_weight : float, optional - Weight of log magnitude portion of loss, by default 1.0 - pow : float, optional - Power to raise magnitude to before taking log, by default 2.0 - weight : float, optional - Weight of this loss, by default 1.0 - match_stride : bool, optional - Whether to match the stride of convolutional layers, by default False - - References - ---------- - - 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. - "DDSP: Differentiable Digital Signal Processing." - International Conference on Learning Representations. 2019. 
- - Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py - """ - - def __init__( - self, - window_lengths: List[int] = [2048, 512], - loss_fn: typing.Callable = nn.L1Loss(), - clamp_eps: float = 1e-5, - mag_weight: float = 1.0, - log_weight: float = 1.0, - pow: float = 2.0, - weight: float = 1.0, - match_stride: bool = False, - window_type: str = None, - ): - super().__init__() - self.stft_params = [ - STFTParams( - window_length=w, - hop_length=w // 4, - match_stride=match_stride, - window_type=window_type, - ) - for w in window_lengths - ] - self.loss_fn = loss_fn - self.log_weight = log_weight - self.mag_weight = mag_weight - self.clamp_eps = clamp_eps - self.weight = weight - self.pow = pow - - def forward(self, x: AudioSignal, y: AudioSignal): - """Computes multi-scale STFT between an estimate and a reference - signal. - - Parameters - ---------- - x : AudioSignal - Estimate signal - y : AudioSignal - Reference signal - - Returns - ------- - torch.Tensor - Multi-scale STFT loss. - """ - loss = 0.0 - for s in self.stft_params: - x.stft(s.window_length, s.hop_length, s.window_type) - y.stft(s.window_length, s.hop_length, s.window_type) - loss += self.log_weight * self.loss_fn( - x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), - y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), - ) - loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) - return loss - - -class MelSpectrogramLoss(nn.Module): - """Compute distance between mel spectrograms. Can be used - in a multi-scale way. - - Parameters - ---------- - n_mels : List[int] - Number of mels per STFT, by default [150, 80], - window_lengths : List[int], optional - Length of each window of each STFT, by default [2048, 512] - loss_fn : typing.Callable, optional - How to compare each loss, by default nn.L1Loss() - clamp_eps : float, optional - Clamp on the log magnitude, below, by default 1e-5 - mag_weight : float, optional - Weight of raw magnitude portion of loss, by default 1.0 - log_weight : float, optional - Weight of log magnitude portion of loss, by default 1.0 - pow : float, optional - Power to raise magnitude to before taking log, by default 2.0 - weight : float, optional - Weight of this loss, by default 1.0 - match_stride : bool, optional - Whether to match the stride of convolutional layers, by default False - - Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py - """ - - def __init__( - self, - n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320], - window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048], - loss_fn: typing.Callable = nn.L1Loss(), - clamp_eps: float = 1e-5, - mag_weight: float = 0.0, - log_weight: float = 1.0, - pow: float = 1.0, - weight: float = 1.0, - match_stride: bool = False, - mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0], - mel_fmax: List[float] = [None, None, None, None, None, None, None], - window_type: str = None, - ): - super().__init__() - self.stft_params = [ - STFTParams( - window_length=w, - hop_length=w // 4, - match_stride=match_stride, - window_type=window_type, - ) - for w in window_lengths - ] - self.n_mels = n_mels - self.loss_fn = loss_fn - self.clamp_eps = clamp_eps - self.log_weight = log_weight - self.mag_weight = mag_weight - self.weight = weight - self.mel_fmin = mel_fmin - self.mel_fmax = mel_fmax - self.pow = pow - - def forward(self, x: AudioSignal, 
y: AudioSignal): - """Computes mel loss between an estimate and a reference - signal. - - Parameters - ---------- - x : AudioSignal - Estimate signal - y : AudioSignal - Reference signal - - Returns - ------- - torch.Tensor - Mel loss. - """ - loss = 0.0 - for n_mels, fmin, fmax, s in zip( - self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params - ): - kwargs = { - "window_length": s.window_length, - "hop_length": s.hop_length, - "window_type": s.window_type, - } - x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) - y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) - - loss += self.log_weight * self.loss_fn( - x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), - y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), - ) - loss += self.mag_weight * self.loss_fn(x_mels, y_mels) - return loss - - -class GANLoss(nn.Module): - """ - Computes a discriminator loss, given a discriminator on - generated waveforms/spectrograms compared to ground truth - waveforms/spectrograms. Computes the loss for both the - discriminator and the generator in separate functions. - """ - - def __init__(self, discriminator): - super().__init__() - self.discriminator = discriminator - - def forward(self, fake, real): - d_fake = self.discriminator(fake.audio_data) - d_real = self.discriminator(real.audio_data) - return d_fake, d_real - - def discriminator_loss(self, fake, real): - d_fake, d_real = self.forward(fake.clone().detach(), real) - - loss_d = 0 - for x_fake, x_real in zip(d_fake, d_real): - loss_d += torch.mean(x_fake[-1] ** 2) - loss_d += torch.mean((1 - x_real[-1]) ** 2) - return loss_d - - def generator_loss(self, fake, real): - d_fake, d_real = self.forward(fake, real) - - loss_g = 0 - for x_fake in d_fake: - loss_g += torch.mean((1 - x_fake[-1]) ** 2) - - loss_feature = 0 - - for i in range(len(d_fake)): - for j in range(len(d_fake[i]) - 1): - loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) - return loss_g, loss_feature - - -def kl_loss(logs, m): - kl = 0.5 * (m**2 + torch.exp(logs) - logs - 1).sum(dim=1) - kl = torch.mean(kl) - return kl \ No newline at end of file diff --git a/dito/models/ldm/dac/model.py b/dito/models/ldm/dac/model.py deleted file mode 100644 index 5fa0e6c958442c3214bd49c6ce2b2643670db8b3..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/model.py +++ /dev/null @@ -1,729 +0,0 @@ -import math -from typing import List, Union - -import numpy as np -import torch -from einops import rearrange -from torch import nn -from torch.nn import functional as F -from torch.nn.utils import weight_norm - -from .audiotools import AudioSignal, STFTParams, ml -from .audiotools.ml import BaseModel -from .base import CodecMixin -from .layers import WNConv1d, WNConvTranspose1d, get_activation - - -def init_weights(m, mean=0.0, std=0.02, init_type="xavier", gain=0.02): - """ - Initialize weights of the entire model using xavier_normal_ or kaiming_normal_. - Args: - m (nn.Module): The module to initialize. - mean (float): Mean for weight initialization. - std (float): Standard deviation for weight initialization. - init_type (str): Type of initialization ('xavier' or 'kaiming'). - gain (float): Gain for xavier initialization. 
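    Any init_type other than "xavier" falls back to Kaiming initialization.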
- """ - classname = m.__class__.__name__ - - if init_type == "xavier": - # Handle convolutional layers - if "Depthwise_Separable" in classname: - nn.init.xavier_normal_(m.depth_conv.weight.data, gain=gain) - nn.init.xavier_normal_(m.point_conv.weight.data, gain=gain) - if hasattr(m.depth_conv, "bias") and m.depth_conv.bias is not None: - nn.init.zeros_(m.depth_conv.bias.data) - if hasattr(m.point_conv, "bias") and m.point_conv.bias is not None: - nn.init.zeros_(m.point_conv.bias.data) - elif classname.find("Conv") != -1: - nn.init.xavier_normal_(m.weight.data, gain=gain) - if hasattr(m, "bias") and m.bias is not None: - nn.init.zeros_(m.bias.data) - - # Handle batch normalization layers - elif classname.find("BatchNorm") != -1: - if hasattr(m, "weight") and m.weight is not None: - nn.init.xavier_normal_(m.weight.data, gain=gain) - if hasattr(m, "bias") and m.bias is not None: - nn.init.zeros_(m.bias.data) - - # Handle custom layers like Snake1d and SnakeBeta - elif classname == "Snake1d": - if hasattr(m, "alpha") and m.alpha is not None: - if m.alpha.data.dim() >= 2: - nn.init.xavier_normal_(m.alpha.data, gain=gain) - else: - nn.init.normal_(m.alpha.data, mean=1.0, std=std) - elif classname == "SnakeBeta": - # Respect the alpha_logscale setting in SnakeBeta - if hasattr(m, "alpha") and m.alpha is not None: - if m.alpha_logscale: - nn.init.constant_(m.alpha.data, 0.0) # Matches SnakeBeta's default - else: - nn.init.constant_(m.alpha.data, 1.0) - if hasattr(m, "beta") and m.beta is not None: - if m.alpha_logscale: - nn.init.constant_(m.beta.data, 0.0) # Matches SnakeBeta's default - else: - nn.init.constant_(m.beta.data, 1.0) - - # Handle residual scaling parameters - elif hasattr(m, "residual_scale") and m.residual_scale is not None: - nn.init.xavier_normal_(m.residual_scale.data, gain=gain) - - else: - # Kaiming initialization - if "Depthwise_Separable" in classname: - nn.init.kaiming_normal_( - m.depth_conv.weight.data, mode="fan_out", nonlinearity="relu" - ) - nn.init.kaiming_normal_( - m.point_conv.weight.data, mode="fan_out", nonlinearity="relu" - ) - elif classname.find("Conv") != -1: - nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu") - if hasattr(m, "bias") and m.bias is not None: - nn.init.zeros_(m.bias.data) - elif classname.find("BatchNorm") != -1: - if hasattr(m, "weight") and m.weight is not None: - nn.init.normal_(m.weight.data, 1.0, std) - if hasattr(m, "bias") and m.bias is not None: - nn.init.zeros_(m.bias.data) - elif classname == "Snake1d": - if hasattr(m, "alpha") and m.alpha is not None: - nn.init.normal_(m.alpha.data, 1.0, std) - elif classname == "SnakeBeta": - if hasattr(m, "beta") and m.beta is not None: - nn.init.normal_(m.beta.data, 1.0, std) - elif ( - hasattr(m, "alpha") and m.alpha is not None - ): # Fallback if SnakeBeta uses alpha - nn.init.normal_(m.alpha.data, 1.0, std) - - elif hasattr(m, "residual_scale") and m.residual_scale is not None: - nn.init.normal_(m.residual_scale.data, 0.1, std) - - -class ResidualUnit(nn.Module): - def __init__( - self, - dim: int = 16, - dilation: int = 1, - activation: str = "snake", - alpha: float = 1.0, - scale_residual: bool = False, - ): - """ - Residual Unit with weight normalization and dilated convolutions. - Args: - dim (int): Number of input and output channels. - dilation (int): Dilation factor for the convolution. - activation (str): Activation function to use. - alpha (float): Scaling factor for the activation function. 
- """ - super().__init__() - pad = ((7 - 1) * dilation) // 2 - self.block = nn.Sequential( - get_activation(activation=activation, channels=dim, alpha=alpha), - WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), - get_activation(activation=activation, channels=dim, alpha=alpha), - WNConv1d(dim, dim, kernel_size=1), - ) - self.scale_residual = scale_residual - if self.scale_residual: - self.res_scale = nn.Parameter(torch.tensor(0.0)) # start at 0 - - def forward(self, x): - y = self.block(x) - pad = (x.shape[-1] - y.shape[-1]) // 2 - if pad > 0: - x = x[..., pad:-pad] - if self.scale_residual: - y = self.res_scale * y - return x + y - - -class EncoderBlock(nn.Module): - def __init__( - self, - dim: int = 16, - stride: int = 1, - activation: str = "snake", - alpha: float = 1.0, - scale_residual: bool = False, - ): - """ - Encoder block that downsamples the input and applies residual units. - """ - super().__init__() - self.block = nn.Sequential( - ResidualUnit( - dim // 2, - dilation=1, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - ), - ResidualUnit( - dim // 2, - dilation=3, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - ), - ResidualUnit( - dim // 2, - dilation=9, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - ), - get_activation(activation=activation, channels=dim // 2, alpha=alpha), - WNConv1d( - dim // 2, - dim, - kernel_size=2 * stride, - stride=stride, - padding=math.ceil(stride / 2), - ), - ) - - def forward(self, x): - return self.block(x) - - -class Encoder(nn.Module): - def __init__( - self, - d_model: int = 64, - strides: list = [2, 4, 8, 8], - d_latent: int = 64, - d_in: int = 1, - activation: str = "snake", - alpha: float = 1.0, - scale_residual: bool = False, - weight_init: str = "xavier", - gain: float = 1.0, - ): - super().__init__() - # Create first convolution - self.block = [WNConv1d(d_in, d_model, kernel_size=7, padding=3)] - - # Create EncoderBlocks that double channels as they downsample by `stride` - for stride in strides: - d_model *= 2 - self.block += [ - EncoderBlock( - d_model, - stride=stride, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - ) - ] - - # Create last convolution - self.block += [ - get_activation(activation=activation, channels=d_model, alpha=alpha), - WNConv1d(d_model, d_latent, kernel_size=3, padding=1), - ] - - # Wrap black into nn.Sequential - self.block = nn.Sequential(*self.block) - self.enc_dim = d_model - - self.apply(lambda m: init_weights(m, init_type=weight_init, gain=gain)) - - def forward(self, x): - x = F.leaky_relu(x) - return self.block(x) - - -class DecoderBlock(nn.Module): - def __init__( - self, - input_dim: int = 16, - output_dim: int = 8, - stride: int = 1, - norm: bool = False, - activation: str = "snake", - alpha: float = 1.0, - scale_residual: bool = False, - ): - """ - Decoder block that upsamples the input and applies residual units. 
- """ - super().__init__() - if not norm: - self.block = nn.Sequential( - get_activation(activation=activation, channels=input_dim, alpha=alpha), - WNConvTranspose1d( - input_dim, - output_dim, - kernel_size=2 * stride, - stride=stride, - padding=math.ceil(stride / 2), - output_padding=0 if stride % 2 == 0 else 1, - ), - ResidualUnit( - output_dim, - dilation=1, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - ), - ResidualUnit( - output_dim, - dilation=3, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - ), - ResidualUnit( - output_dim, - dilation=9, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - ), - ) - else: - self.block = nn.Sequential( - get_activation(activation=activation, channels=input_dim, alpha=alpha), - WNConvTranspose1d( - input_dim, - output_dim, - kernel_size=2 * stride, - stride=stride, - padding=math.ceil(stride / 2), - output_padding=0 if stride % 2 == 0 else 1, - ), - nn.BatchNorm1d(output_dim), - ResidualUnit( - output_dim, - dilation=1, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - ), - nn.BatchNorm1d(output_dim), - ResidualUnit( - output_dim, - dilation=3, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - ), - nn.BatchNorm1d(output_dim), - ResidualUnit( - output_dim, - dilation=9, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - ), - ) - - def forward(self, x): - return self.block(x) - - -class Decoder(nn.Module): - def __init__( - self, - input_channel, - channels, - rates, - d_out: int = 1, - norm: bool = False, - activation: str = "snake", - alpha: float = 1.0, - scale_residual: bool = False, - use_tanh_as_final: bool = True, - use_bias_at_final: bool = True, - ): - super().__init__() - - # Add first conv layer - layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] - - # Add upsampling + MRF blocks - for i, stride in enumerate(rates): - input_dim = channels // 2**i - output_dim = channels // 2 ** (i + 1) - layers += [ - DecoderBlock( - input_dim, - output_dim, - stride, - norm=norm, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - ) - ] - - # Add final conv layer - layers += [ - get_activation(activation=activation, channels=output_dim, alpha=alpha), - WNConv1d( - output_dim, d_out, kernel_size=7, padding=3, bias=use_bias_at_final - ), - nn.Tanh() if use_tanh_as_final else nn.Identity(), - ] - self.use_tanh_as_final = use_tanh_as_final - - self.model = nn.Sequential(*layers) - - def forward(self, x): - x = self.model(x) - if not self.use_tanh_as_final: - x = torch.clamp( - x, min=-1.0, max=1.0 - ) # Ensure output is within [-1, 1] range - return x - - -class DACVAE(BaseModel, CodecMixin): - def __init__( - self, - encoder_dim: int = 64, - encoder_rates: List[int] = [2, 4, 5, 8], - latent_dim: int = 64, - decoder_dim: int = 1536, - decoder_rates: List[int] = [8, 5, 4, 2], - sample_rate: int = 44100, - d_in: int = 2, - d_out: int = 2, - weight_init: str = "xavier", - norm: bool = False, - activation: str = "snake", - alpha: float = 1.0, - gain: float = 0.02, - scale_residual: bool = False, - use_tanh_as_final: bool = True, - use_bias_at_final: bool = True, - ): - super().__init__() - - self.encoder_dim = encoder_dim - self.encoder_rates = encoder_rates - self.decoder_dim = decoder_dim - self.decoder_rates = decoder_rates - self.sample_rate = sample_rate - self.d_in = d_in - self.d_out = d_out - - if latent_dim is None: - latent_dim = encoder_dim * (2 ** 
len(encoder_rates)) - - self.latent_dim = latent_dim - - self.hop_length = np.prod(encoder_rates) - self.encoder = Encoder( - encoder_dim, - encoder_rates, - latent_dim, - d_in=d_in, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - ) - - self.decoder = Decoder( - latent_dim, - decoder_dim, - decoder_rates, - d_out=d_out, - norm=norm, - activation=activation, - alpha=alpha, - scale_residual=scale_residual, - use_tanh_as_final=use_tanh_as_final, - use_bias_at_final=use_bias_at_final, - ) - - self.en_conv_post = WNConv1d( - self.latent_dim, 2 * self.latent_dim, kernel_size=1 - ) - - self.de_conv_pre = WNConv1d(self.latent_dim, self.latent_dim, kernel_size=1) - - self.sample_rate = sample_rate - self.apply(lambda m: init_weights(m, init_type=weight_init, gain=gain)) - self.step = 0 # Initialize step counter for noise decay - - def freeze_encoder(self): - for param in self.encoder.parameters(): - param.requires_grad = False - for param in self.en_conv_post.parameters(): - param.requires_grad = False - print("Encoder and en_conv_post frozen") - - def preprocess(self, audio_data, sample_rate): - if sample_rate is None: - sample_rate = self.sample_rate - assert sample_rate == self.sample_rate - length = audio_data.shape[-1] - # print(f"Audio length: {length}", "math.ceil(length / self.hop_length) * self.hop_length: ", math.ceil(length / self.hop_length) * self.hop_length) - right_pad = math.ceil(length / self.hop_length) * self.hop_length - length - - audio_data = nn.functional.pad(audio_data, (0, right_pad)) - return audio_data - - def encode( - self, - audio_data: torch.Tensor, - training: bool = True, - ): - x = self.encoder(audio_data) - x = self.en_conv_post(x) - m, logs = torch.split(x, self.latent_dim, dim=1) - logs = torch.clamp(logs, min=-14.0, max=14.0) - - z = m + torch.randn_like(m) * torch.exp(logs) - - return z, m, logs - - def decode(self, z: torch.Tensor): - z = self.de_conv_pre(z) - z = self.decoder(z) - return z - - def forward( - self, - audio_data: torch.Tensor, - sample_rate: int = 24000, - ): - # print(f"Audio data shape: {audio_data.shape}") - length = audio_data.shape[-1] - audio_data = self.preprocess(audio_data, sample_rate) - z, m, logs = self.encode(audio_data) - x = self.decode(z) - return { - "audio": x[..., :length], - "z": z, - "mu": m, - "logs": logs, - } - - -def WNConv1d(*args, **kwargs): - act = kwargs.pop("act", True) - conv = weight_norm(nn.Conv1d(*args, **kwargs)) - if not act: - return conv - return nn.Sequential(conv, nn.LeakyReLU(0.1)) - - -def WNConv2d(*args, **kwargs): - act = kwargs.pop("act", True) - conv = weight_norm(nn.Conv2d(*args, **kwargs)) - if not act: - return conv - return nn.Sequential(conv, nn.LeakyReLU(0.1)) - - -class MPD(nn.Module): - def __init__(self, period): - super().__init__() - self.period = period - self.convs = nn.ModuleList( - [ - WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), - WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), - WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), - WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), - WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), - ] - ) - self.conv_post = WNConv2d( - 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False - ) - - def pad_to_period(self, x): - t = x.shape[-1] - x = F.pad(x, (0, self.period - t % self.period), mode="reflect") - return x - - def forward(self, x): - fmap = [] - - x = self.pad_to_period(x) - x = rearrange(x, "b c (l p) -> b c l p", p=self.period) - - for layer in self.convs: - x = layer(x) - fmap.append(x) - - x = 
self.conv_post(x) - fmap.append(x) - - return fmap - - -class MSD(nn.Module): - def __init__(self, rate: int = 1, sample_rate: int = 44100): - super().__init__() - self.convs = nn.ModuleList( - [ - WNConv1d(1, 16, 15, 1, padding=7), - WNConv1d(16, 64, 41, 4, groups=4, padding=20), - WNConv1d(64, 256, 41, 4, groups=16, padding=20), - WNConv1d(256, 1024, 41, 4, groups=64, padding=20), - WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), - WNConv1d(1024, 1024, 5, 1, padding=2), - ] - ) - self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) - self.sample_rate = sample_rate - self.rate = rate - - def forward(self, x): - x = AudioSignal(x, self.sample_rate) - x.resample(self.sample_rate // self.rate) - x = x.audio_data - - fmap = [] - - for l in self.convs: - x = l(x) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - - return fmap - - -BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] - - -class MRD(nn.Module): - def __init__( - self, - window_length: int, - hop_factor: float = 0.25, - sample_rate: int = 44100, - bands: list = BANDS, - ): - """Complex multi-band spectrogram discriminator. - Parameters - ---------- - window_length : int - Window length of STFT. - hop_factor : float, optional - Hop factor of the STFT, defaults to ``0.25 * window_length``. - sample_rate : int, optional - Sampling rate of audio in Hz, by default 44100 - bands : list, optional - Bands to run discriminator over. - """ - super().__init__() - - self.window_length = window_length - self.hop_factor = hop_factor - self.sample_rate = sample_rate - self.stft_params = STFTParams( - window_length=window_length, - hop_length=int(window_length * hop_factor), - match_stride=True, - ) - - n_fft = window_length // 2 + 1 - bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] - self.bands = bands - - ch = 32 - convs = lambda: nn.ModuleList( - [ - WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), - WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), - WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), - WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), - WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), - ] - ) - self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) - self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) - - def spectrogram(self, x): - x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) - x = torch.view_as_real(x.stft()) - x = rearrange(x, "b 1 f t c -> (b 1) c t f") - # Split into bands - x_bands = [x[..., b[0] : b[1]] for b in self.bands] - return x_bands - - def forward(self, x): - x_bands = self.spectrogram(x) - fmap = [] - - x = [] - for band, stack in zip(x_bands, self.band_convs): - for layer in stack: - band = layer(band) - fmap.append(band) - x.append(band) - - x = torch.cat(x, dim=-1) - x = self.conv_post(x) - fmap.append(x) - - return fmap - - -class Discriminator(ml.BaseModel): - def __init__( - self, - rates: list = [], - periods: list = [2, 3, 5, 7, 11], - fft_sizes: list = [2048, 1024, 512], - sample_rate: int = 44100, - bands: list = BANDS, - d_in: int = 1, - ): - """Discriminator that combines multiple discriminators. - - Parameters - ---------- - rates : list, optional - sampling rates (in Hz) to run MSD at, by default [] - If empty, MSD is not used. 
- periods : list, optional - periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] - fft_sizes : list, optional - Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] - sample_rate : int, optional - Sampling rate of audio in Hz, by default 44100 - bands : list, optional - Bands to run MRD at, by default `BANDS` - """ - super().__init__() - discs = [] - discs += [MPD(p) for p in periods] - discs += [MSD(r, sample_rate=sample_rate) for r in rates] - discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] - self.discriminators = nn.ModuleList(discs) - - def preprocess(self, y): - # Remove DC offset - y = y - y.mean(dim=-1, keepdims=True) - # Peak normalize the volume of input audio - y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) - return y - - def forward(self, x): - x = self.preprocess(x) - fmaps = [d(x) for d in self.discriminators] - return fmaps - - -if __name__ == "__main__": - disc = Discriminator() - x = torch.zeros(1, 1, 44100) - results = disc(x) - for i, result in enumerate(results): - print(f"disc{i}") - for i, r in enumerate(result): - print(r.shape, r.mean(), r.min(), r.max()) - print() diff --git a/dito/models/ldm/dac/utils.py b/dito/models/ldm/dac/utils.py deleted file mode 100644 index e7724a9dde8a937e3c4e06146707915625d63595..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dac/utils.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch.nn as nn - - -from models import register -from .model import Encoder, Decoder, WNConv1d - - -default_configs = { - 'snake': dict( - encoder_dim=64, - encoder_rates=[2, 4, 5, 8], - latent_dim=64, - d_in=1, - activation='snake', - ), - 'snake': dict( - encoder_dim=64, - encoder_rates=[2, 4, 5, 8], - latent_dim=64, - d_in=1, - activation='snakebeta', - ), -} - - -@register('dac_encoder') -def make_dac_encoder(config_name, **kwargs): - encoder_kwargs = default_configs[config_name] - encoder_kwargs.update(kwargs) - latent_dim = encoder_kwargs['latent_dim'] - return nn.Sequential( - Encoder(**encoder_kwargs), - WNConv1d(latent_dim, latent_dim, kernel_size=1), - ) - - -@register('vqgan_decoder') -def make_vqgan_decoder(config_name, **kwargs): - decoder_kwargs = default_configs[config_name] - decoder_kwargs.update(kwargs) - latent_dim = decoder_kwargs['latent_dim'] - return nn.Sequential( - WNConv1d(latent_dim, latent_dim, kernel_size=1), - Decoder(**decoder_kwargs), - ) diff --git a/dito/models/ldm/dito.py b/dito/models/ldm/dito.py deleted file mode 100644 index 1559d99cb095c56de1991099234f613cb5a93984..0000000000000000000000000000000000000000 --- a/dito/models/ldm/dito.py +++ /dev/null @@ -1,180 +0,0 @@ -import copy -import math - -import torch - -import models -from omegaconf import OmegaConf -from models import register -from models.ldm.ldm_base import LDMBase -from models.ldm.vqgan.lpips import LPIPS - - -@register('dito') -class DiTo(LDMBase): - - def __init__(self, render_diffusion, render_sampler, render_n_steps, renderer_guidance=1, lpips=False, **kwargs): - super().__init__(**kwargs) - self.render_diffusion = models.make(render_diffusion) - - if OmegaConf.is_config(render_sampler): - render_sampler = OmegaConf.to_container(render_sampler, resolve=True) - render_sampler = copy.deepcopy(render_sampler) - if render_sampler.get('args') is None: - render_sampler['args'] = {} - render_sampler['args']['diffusion'] = self.render_diffusion - self.render_sampler = models.make(render_sampler) - self.render_n_steps = render_n_steps - self.renderer_guidance = renderer_guidance - - 
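        # Running EMA of the diffusion loss in 10 uniform timestep buckets,
        # reported as _loss_t{q} to monitor which noise levels dominate training.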
self.t_loss_monitor_v = [0 for _ in range(10)] - self.t_loss_monitor_n = [0 for _ in range(10)] - self.t_loss_monitor_decay = 0.99 - - self.use_lpips = lpips - if lpips: - self.lpips_loss = LPIPS().eval() - - def render(self, z_dec, coord, scale): - shape = (coord.size(0), 3, coord.size(2), coord.size(3)) - net_kwargs = {'coord': coord, 'scale': scale, 'z_dec': z_dec} - - if self.use_ema_renderer: - self.swap_ema_renderer() - - if self.renderer_guidance > 1: - uncond_z_dec = self.drop_z_emb.unsqueeze(0).expand(z_dec.shape[0], -1, -1, -1) - uncond_net_kwargs = {'coord': coord, 'scale': scale, 'z_dec': uncond_z_dec} - else: - uncond_net_kwargs = None - - ret = self.render_sampler.sample( - net=self.renderer, - shape=shape, - n_steps=self.render_n_steps, - net_kwargs=net_kwargs, - uncond_net_kwargs=uncond_net_kwargs, - guidance=self.renderer_guidance, - ) - - if self.use_ema_renderer: - self.swap_ema_renderer() - - return ret - - def forward(self, data, mode, has_optimizer=None): - if mode in ['z', 'z_dec']: - ret_z, _ = super().forward(data, mode=mode, has_optimizer=has_optimizer) - return ret_z - - grad = self.get_grad_plan(has_optimizer) - loss_config = self.loss_config - print('mode', mode) - if mode == 'pred': - z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer) - - gt_patch = data['gt'][:, :3, ...] - coord = data['gt'][:, 3:5, ...] - scale = data['gt'][:, 5:7, ...] - - if grad['renderer']: - return self.render(z_dec, coord, scale) - else: - with torch.no_grad(): - return self.render(z_dec, coord, scale) - - elif mode == 'loss': - if not grad['renderer']: # Only training zdm - print('not grad[renderer]') - _, ret = super().forward(data, mode='z', has_optimizer=has_optimizer) - return ret - - gt_patch = data['gt'][:, :3, ...] - coord = data['gt'][:, 3:5, ...] - scale = data['gt'][:, 5:7, ...] 
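            # data['gt'] packs the target along channels: [0:3] RGB patch,
            # [3:5] pixel coordinates, [5:7] cell scale (same layout as 'pred').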
- - z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer) - net_kwargs = {'z_dec': z_dec} - - print('latent z_dec shape: ', z_dec.shape) - - t = torch.rand(gt_patch.shape[0], device=gt_patch.device) - - print('self.gt_noise_lb:', self.gt_noise_lb) - if self.gt_noise_lb is not None: - tmin = torch.ones_like(t) * self.gt_noise_lb - tmax = torch.ones_like(t) * 1 - t = tmin + (tmax - tmin) * torch.rand_like(tmin) - - print('self.zaug_p:', self.zaug_p) - print('self.training:', self.training) - - if (self.zaug_p is not None) and self.training: - tz = self._tz - mask_aug = self._mask_aug - - typ = self.zaug_decoding_loss_type - if typ == 'all': - tmin = torch.ones_like(tz) * 0 - tmax = torch.ones_like(tz) * 1 - elif typ == 'suffix': - tmin = tz - tmax = torch.ones_like(tz) * 1 - elif typ == 'tz': - tmin = tz - tmax = tz - elif typ == 'tmax': - tmin = torch.ones_like(tz) * 1 - tmax = torch.ones_like(tz) * 1 - else: - raise NotImplementedError - t_aug = tmin + (tmax - tmin) * torch.rand_like(tmin) - - t = mask_aug * t_aug + (1 - mask_aug) * t - print('self.use_lpips:', self.use_lpips) - if not self.use_lpips: - loss, t = self.render_diffusion.loss( - net=self.renderer, - x=gt_patch, - t=t, - net_kwargs=net_kwargs, - return_loss_unreduced=True - ) - else: - loss, t, x_t, pred = self.render_diffusion.loss( - net=self.renderer, - x=gt_patch, - t=t, - net_kwargs=net_kwargs, - return_loss_unreduced=True, - return_all=True - ) - - sample_pred = x_t + t.view(-1, 1, 1, 1) * pred - lpips_loss = self.lpips_loss(sample_pred, gt_patch).mean() - ret['lpips_loss'] = lpips_loss.item() - lpips_loss_w = loss_config.get('lpips_loss', 1) - ret['loss'] = ret['loss'] + lpips_loss * lpips_loss_w - - # Visualize diffusion network loss for different timesteps # - if self.training: - m = len(self.t_loss_monitor_v) - for i in range(len(loss)): - q = min(math.floor(t[i].item() * m), m - 1) - self.t_loss_monitor_v[q] = self.t_loss_monitor_v[q] * self.t_loss_monitor_decay + loss[i].item() * (1 - self.t_loss_monitor_decay) - self.t_loss_monitor_n[q] += 1 - for q in range(m): - if self.t_loss_monitor_n[q] > 0: - if self.t_loss_monitor_n[q] < 500: - r = 1 - math.pow(self.t_loss_monitor_decay, self.t_loss_monitor_n[q]) - else: - r = 1 - ret[f'_loss_t{q}'] = self.t_loss_monitor_v[q] / r - # - # - - dae_loss = loss.mean() - - ret['dae_loss'] = dae_loss.item() - dae_loss_w = loss_config.get('dae_loss', 1) - ret['loss'] = ret['loss'] + dae_loss * dae_loss_w - return ret diff --git a/dito/models/ldm/glpto.py b/dito/models/ldm/glpto.py deleted file mode 100644 index 03cddc0747222259b0a2da5c92362c6dff1dde78..0000000000000000000000000000000000000000 --- a/dito/models/ldm/glpto.py +++ /dev/null @@ -1,137 +0,0 @@ -import os - -import torch -import torch.nn.functional as F -import torch.distributed as dist - -from models import register -from models.ldm.ldm_base import LDMBase -from models.ldm.vqgan.lpips import LPIPS -from models.ldm.vqgan.discriminator import make_discriminator - - -@register('glpto') -class GLPTo(LDMBase): - - def __init__(self, lpips=True, disc=True, adaptive_gan_weight=True, noise_render=False, **kwargs): - super().__init__(**kwargs) - if lpips: - self.lpips_loss = LPIPS().eval() - self.disc = make_discriminator(input_nc=3) if disc else None - self.adaptive_gan_weight = adaptive_gan_weight - self.noise_render = noise_render - - def get_parameters(self, name): - if name == 'disc': - return self.disc.parameters() - else: - return super().get_parameters(name) - - def render(self, z_dec, coord, 
scale): - if not self.noise_render: - return self.renderer(z_dec, coord=coord, scale=scale) - else: - shape = (coord.shape[0], 3, coord.shape[2], coord.shape[3]) - noise = torch.randn(shape, device=z_dec.device) - return self.renderer(noise, coord=coord, scale=scale, z_dec=z_dec) - - def forward(self, data, mode, has_optimizer=None, use_gan=False): - if mode in ['z', 'z_dec']: - ret_z, _ = super().forward(data, mode=mode, has_optimizer=has_optimizer) - return ret_z - - grad = self.get_grad_plan(has_optimizer) - loss_config = self.loss_config - - if mode == 'pred': - z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer) - - gt_patch = data['gt'][:, :3, ...] - coord = data['gt'][:, 3:5, ...] - scale = data['gt'][:, 5:7, ...] - - if grad['renderer']: - return self.render(z_dec, coord, scale) - else: - with torch.no_grad(): - return self.render(z_dec, coord, scale) - - elif mode == 'loss': - if not grad['renderer']: # Only training zdm - _, ret = super().forward(data, mode='z', has_optimizer=has_optimizer) - return ret - - gt_patch = data['gt'][:, :3, ...] - coord = data['gt'][:, 3:5, ...] - scale = data['gt'][:, 5:7, ...] - - z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer) - pred = self.render(z_dec, coord, scale) - - l1_loss = torch.abs(pred - gt_patch).mean() - ret['l1_loss'] = l1_loss.item() - l1_loss_w = loss_config.get('l1_loss', 1) - ret['loss'] = ret['loss'] + l1_loss * l1_loss_w - - lpips_loss = self.lpips_loss(pred, gt_patch).mean() - ret['lpips_loss'] = lpips_loss.item() - lpips_loss_w = loss_config.get('lpips_loss', 1) - ret['loss'] = ret['loss'] + lpips_loss * lpips_loss_w - - if use_gan: - logits_fake = self.disc(pred) - - gan_g_loss = -torch.mean(logits_fake) - ret['gan_g_loss'] = gan_g_loss.item() - weight = loss_config.get('gan_g_loss', 1) - - if self.training and self.adaptive_gan_weight: - nll_loss = l1_loss * l1_loss_w + lpips_loss * lpips_loss_w - adaptive_gan_w = self.calculate_adaptive_gan_w(nll_loss, gan_g_loss, self.renderer.get_last_layer_weight()) - ret['adaptive_gan_w'] = adaptive_gan_w.item() - weight = weight * adaptive_gan_w - - ret['loss'] = ret['loss'] + gan_g_loss * weight - - return ret - - elif mode == 'disc_loss': - gt_patch = data['gt'][:, :3, ...] - coord = data['gt'][:, 3:5, ...] - scale = data['gt'][:, 5:7, ...] - - with torch.no_grad(): - z_dec, _ = super().forward(data, mode='z_dec', has_optimizer=None) - pred = self.render(z_dec, coord, scale) - - logits_real = self.disc(gt_patch) - logits_fake = self.disc(pred) - - disc_loss_type = loss_config.get('disc_loss_type', 'hinge') - if disc_loss_type == 'hinge': - loss_real = torch.mean(F.relu(1. - logits_real)) - loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) - loss = (loss_real + loss_fake) / 2 - elif disc_loss_type == 'vanilla': - loss_real = torch.mean(F.softplus(-logits_real)) - loss_fake = torch.mean(F.softplus(logits_fake)) - loss = (loss_real + loss_fake) / 2 - - return { - 'loss': loss, - 'disc_logits_real': logits_real.mean().item(), - 'disc_logits_fake': logits_fake.mean().item(), - } - - def calculate_adaptive_gan_w(self, nll_loss, g_loss, last_layer): - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - world_size = int(os.environ.get('WORLD_SIZE', '1')) - if world_size > 1: - dist.all_reduce(nll_grads, op=dist.ReduceOp.SUM) - nll_grads.div_(world_size) - dist.all_reduce(g_grads, op=dist.ReduceOp.SUM) - g_grads.div_(world_size) - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - return d_weight diff --git a/dito/models/ldm/ldm_base.py b/dito/models/ldm/ldm_base.py deleted file mode 100644 index 69e8650214e2126f4d49dd6a3be59b2728647d2f..0000000000000000000000000000000000000000 --- a/dito/models/ldm/ldm_base.py +++ /dev/null @@ -1,444 +0,0 @@ -import copy -import math - -import numpy as np -import torch -import torch.nn as nn -from omegaconf import OmegaConf - -import models -from models.ldm.vqgan.quantizer import VectorQuantizer - - -class LDMBase(nn.Module): - - def __init__( - self, - encoder, - z_shape, - decoder, - renderer, - encoder_ema_rate=None, - decoder_ema_rate=None, - renderer_ema_rate=None, - z_gaussian=False, - z_gaussian_sample=True, - z_quantizer=False, - z_quantizer_n_embed=8192, - z_quantizer_beta=0.25, - z_layernorm=False, - zaug_p=None, - zaug_tmax=1.0, - zaug_tmax_always=False, - zaug_decoding_loss_type='all', - zaug_zdm_diffusion=None, - gt_noise_lb=None, - drop_z_p=0.0, - zdm_net=None, - zdm_diffusion=None, - zdm_sampler=None, - zdm_n_steps=None, - zdm_ema_rate=0.9999, - zdm_train_normalize=False, - zdm_class_cond=None, - zdm_force_guidance=None, - loss_config=None, - use_ema_encoder=False, - use_ema_decoder=False, - use_ema_renderer=False, - ): - super().__init__() - self.loss_config = loss_config if loss_config is not None else dict() - - self.encoder = models.make(encoder) - self.decoder = models.make(decoder) - self.renderer = models.make(renderer) - - self.z_shape = tuple(z_shape) - - self.z_gaussian = z_gaussian - self.z_gaussian_sample = z_gaussian_sample - - self.z_quantizer = VectorQuantizer( - z_quantizer_n_embed, - z_shape[0], - beta=z_quantizer_beta, - remap=None, - sane_index_shape=False - ) if z_quantizer else None - - self.z_layernorm = nn.LayerNorm( - list(z_shape), - elementwise_affine=False - ) if z_layernorm else None - - self.zaug_p = zaug_p - self.zaug_tmax = zaug_tmax - self.zaug_tmax_always = zaug_tmax_always - self.zaug_decoding_loss_type = zaug_decoding_loss_type - if zaug_zdm_diffusion is not None: - self.zaug_zdm_diffusion = models.make(zaug_zdm_diffusion) - - self.drop_z_p = drop_z_p - if self.drop_z_p > 0: - self.drop_z_emb = nn.Parameter(torch.zeros(z_shape[0], z_shape[1], z_shape[2]), requires_grad=False) - - self.gt_noise_lb = gt_noise_lb - - # EMA models # - self.encoder_ema_rate = encoder_ema_rate - if self.encoder_ema_rate is not None: - self.encoder_ema = copy.deepcopy(self.encoder) - for p in self.encoder_ema.parameters(): - p.requires_grad = False - - self.decoder_ema_rate = decoder_ema_rate - if self.decoder_ema_rate is not None: - self.decoder_ema = copy.deepcopy(self.decoder) - for p in 
self.decoder_ema.parameters(): - p.requires_grad = False - - self.renderer_ema_rate = renderer_ema_rate - if self.renderer_ema_rate is not None: - self.renderer_ema = copy.deepcopy(self.renderer) - for p in self.renderer_ema.parameters(): - p.requires_grad = False - # - # - - # z DM # - if zdm_diffusion is not None: - self.zdm_diffusion = models.make(zdm_diffusion) - - if OmegaConf.is_config(zdm_sampler): - zdm_sampler = OmegaConf.to_container(zdm_sampler, resolve=True) - zdm_sampler = copy.deepcopy(zdm_sampler) - if zdm_sampler.get('args') is None: - zdm_sampler['args'] = {} - zdm_sampler['args']['diffusion'] = self.zdm_diffusion - self.zdm_sampler = models.make(zdm_sampler) - self.zdm_n_steps = zdm_n_steps - - self.zdm_net = models.make(zdm_net) - - self.zdm_net_ema = copy.deepcopy(self.zdm_net) - for p in self.zdm_net_ema.parameters(): - p.requires_grad = False - self.zdm_ema_rate = zdm_ema_rate - - self.zdm_class_cond = zdm_class_cond - - self.zdm_force_guidance = zdm_force_guidance - else: - self.zdm_diffusion = None - - self.zdm_train_normalize = zdm_train_normalize - if zdm_train_normalize: - self.register_buffer('zdm_Ez_v', torch.tensor(0.)) - self.register_buffer('zdm_Ez_n', torch.tensor(0.)) - self.register_buffer('zdm_Ez2_v', torch.tensor(0.)) - self.register_buffer('zdm_Ez2_n', torch.tensor(0.)) - # - # - - self.use_ema_encoder = use_ema_encoder - self.use_ema_decoder = use_ema_decoder - self.use_ema_renderer = use_ema_renderer - - def get_parameters(self, name): - if name == 'encoder': - return self.encoder.parameters() - elif name == 'decoder': - p = list(self.decoder.parameters()) - if self.z_quantizer is not None: - p += list(self.z_quantizer.parameters()) - return p - elif name == 'renderer': - return self.renderer.parameters() - elif name == 'zdm': - return self.zdm_net.parameters() - - def encode(self, x, return_loss=False, ret=None): - if self.use_ema_encoder: - self.swap_ema_encoder() - - z = self.encoder(x) - - if self.use_ema_encoder: - self.swap_ema_encoder() - - if self.z_gaussian: - posterior = DiagonalGaussianDistribution(z) - if self.z_gaussian_sample: - z = posterior.sample() - else: - z = posterior.mode() - kl_loss = posterior.kl().mean() - - if ret is not None: - ret['z_gau_mean_abs'] = posterior.mean.abs().mean().item() - ret['z_gau_std'] = posterior.std.mean().item() - else: - kl_loss = None - - if self.z_layernorm is not None: - z = self.z_layernorm(z) - - if (self.zaug_p is not None) and self.training: - assert self.z_layernorm is not None # ensure 0 mean 1 std - if self.zaug_tmax_always: - tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax - else: - tz = torch.rand(z.shape[0], device=z.device) * self.zaug_tmax - zt, _ = self.zaug_zdm_diffusion.add_noise(z, tz) - mask_aug = (torch.rand(z.shape[0], device=z.device) < self.zaug_p).float() - z = mask_aug.view(-1, 1, 1, 1) * zt + (1 - mask_aug).view(-1, 1, 1, 1) * z - self._tz = tz - self._mask_aug = mask_aug - - if return_loss: - return z, kl_loss - else: - return z - - def decode(self, z, return_loss=False): - if self.z_quantizer is not None: - z, quant_loss, _ = self.z_quantizer(z) - else: - quant_loss = None - - if self.use_ema_decoder: - self.swap_ema_decoder() - - z_dec = self.decoder(z) - - if self.use_ema_decoder: - self.swap_ema_decoder() - - if return_loss: - return z_dec, quant_loss - else: - return z_dec - - def render(self, z_dec, coord, cell): - raise NotImplementedError - - def normalize_for_zdm(self, z): - if
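# A toy sketch of the latent noise augmentation performed in encode() above, assuming
# a flow-matching-style corruption add_noise(z, t) = ((1 - t) * z + t * eps, eps)
# (the actual form is defined by the configured zaug_zdm_diffusion):
import torch

def add_noise(z, t):
    eps = torch.randn_like(z)
    t = t.view(-1, 1, 1, 1)
    return (1 - t) * z + t * eps, eps

z = torch.randn(8, 64, 1, 1)                   # layer-normalized latents
tz = torch.rand(z.shape[0]) * 1.0              # t ~ U(0, zaug_tmax)
zt, _ = add_noise(z, tz)
mask = (torch.rand(z.shape[0]) < 0.1).float()  # a zaug_p fraction gets the noised latent
z = mask.view(-1, 1, 1, 1) * zt + (1 - mask).view(-1, 1, 1, 1) * z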
self.zdm_train_normalize: - mean = self.zdm_Ez_v - var = self.zdm_Ez2_v - mean ** 2 - return (z - mean) / torch.sqrt(var) - else: - return z - - def denormalize_for_zdm(self, z): - if self.zdm_train_normalize: - mean = self.zdm_Ez_v - var = self.zdm_Ez2_v - mean ** 2 - return z * torch.sqrt(var) + mean - else: - return z - - def forward(self, data, mode, has_optimizer=None): - grad = self.get_grad_plan(has_optimizer) - loss = torch.tensor(0., device=data['inp'].device) - loss_config = self.loss_config - ret = dict() - - # Encoder - if grad['encoder']: - z, kl_loss = self.encode(data['inp'], return_loss=True, ret=ret) - - # if self.z_gaussian: - # ret['kl_loss'] = kl_loss.item() - # loss = loss + kl_loss * loss_config.get('kl_loss', 0.0) - else: - with torch.no_grad(): - z, kl_loss = self.encode(data['inp'], return_loss=True, ret=ret) - - if self.training and self.drop_z_p > 0: - drop_mask = (torch.rand(z.shape[0], device=z.device) < self.drop_z_p).to(z.dtype) - z = drop_mask.view(-1, 1, 1, 1) * self.drop_z_emb.unsqueeze(0) + (1 - drop_mask).view(-1, 1, 1, 1) * z - - # Z DM - if grad['zdm']: - if self.zdm_train_normalize and self.training: - self.zdm_Ez_v = ( - self.zdm_Ez_v * (self.zdm_Ez_n / (self.zdm_Ez_n + 1)) - + z.mean().item() / (self.zdm_Ez_n + 1) - ) - self.zdm_Ez_n = self.zdm_Ez_n + 1 - - self.zdm_Ez2_v = ( - self.zdm_Ez2_v * (self.zdm_Ez2_n / (self.zdm_Ez2_n + 1)) - + (z ** 2).mean().item() / (self.zdm_Ez2_n + 1) - ) - self.zdm_Ez2_n = self.zdm_Ez2_n + 1 - - ret['normalize_z_mean'] = self.zdm_Ez_v.item() - ret['normalize_z_std'] = math.sqrt((self.zdm_Ez2_v - self.zdm_Ez_v ** 2).item()) - - z_for_dm = self.normalize_for_zdm(z) - - net_kwargs = dict() - if self.zdm_class_cond is not None: - net_kwargs['class_labels'] = data['class_labels'] - - zdm_loss = self.zdm_diffusion.loss(self.zdm_net, z_for_dm, net_kwargs=net_kwargs) - ret['zdm_loss'] = zdm_loss.item() - loss = loss + zdm_loss * loss_config.get('zdm_loss', 1.0) - - if not self.training: - ret['zdm_ema_loss'] = self.zdm_diffusion.loss(self.zdm_net_ema, z_for_dm, net_kwargs=net_kwargs).item() - - # Decoder - if mode == 'z': - ret_z = z - elif mode == 'z_dec': - if grad['decoder']: - z_dec, quant_loss = self.decode(z, return_loss=True) - else: - with torch.no_grad(): - z_dec, quant_loss = self.decode(z, return_loss=True) - ret_z = z_dec - - # if self.z_quantizer is not None: - # ret['quant_loss'] = quant_loss.item() - # loss = loss + quant_loss * loss_config.get('quant_loss', 1.0) - - ret['loss'] = loss - return ret_z, ret - - def get_grad_plan(self, has_optimizer): - if has_optimizer is None: - has_optimizer = dict() - grad = dict() - grad['encoder'] = has_optimizer.get('encoder', False) - grad['decoder'] = grad['encoder'] or has_optimizer.get('decoder', False) - grad['renderer'] = grad['decoder'] or has_optimizer.get('renderer', False) - grad['zdm'] = has_optimizer.get('zdm', False) # not in chain definition - return grad - - def update_ema_fn(self, net_ema, net, rate): - if rate != 1: - for ema_p, cur_p in zip(net_ema.parameters(), net.parameters()): - ema_p.data.lerp_(cur_p.data, 1 - rate) - - def update_ema(self): - if self.encoder_ema_rate is not None: - self.update_ema_fn(self.encoder_ema, self.encoder, self.encoder_ema_rate) - if self.decoder_ema_rate is
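# The z-normalization bookkeeping in forward() keeps per-batch running averages of
# E[z] and E[z^2]; a compact standalone sketch of the same update and the derived
# statistics used by (de)normalize_for_zdm:
import math
import torch

Ez, Ez2, n = 0.0, 0.0, 0
for _ in range(100):
    z = torch.randn(8, 64, 1, 1) * 2.0 + 0.5   # pretend encoder latents
    Ez = Ez * (n / (n + 1)) + z.mean().item() / (n + 1)
    Ez2 = Ez2 * (n / (n + 1)) + (z ** 2).mean().item() / (n + 1)
    n += 1
mean, std = Ez, math.sqrt(Ez2 - Ez ** 2)       # ~0.5 and ~2.0 here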
not None: - self.update_ema_fn(self.decoder_ema, self.decoder, self.decoder_ema_rate) - if self.renderer_ema_rate is not None: - self.update_ema_fn(self.renderer_ema, self.renderer, self.renderer_ema_rate) - if (self.zdm_diffusion is not None) and (self.zdm_ema_rate is not None): - self.update_ema_fn(self.zdm_net_ema, self.zdm_net, self.zdm_ema_rate) - - def generate_samples( - self, - batch_size, - n_steps, - net_kwargs=None, - uncond_net_kwargs=None, - ema=False, - guidance=1.0, - noise=None, - render_res=(256, 256), - return_z=False, - ): - if self.zdm_force_guidance is not None: - guidance = self.zdm_force_guidance - - shape = (batch_size,) + self.z_shape - net = self.zdm_net if not ema else self.zdm_net_ema - - z = self.zdm_sampler.sample( - net, - shape, - n_steps, - net_kwargs=net_kwargs, - uncond_net_kwargs=uncond_net_kwargs, - guidance=guidance, - noise=noise, - ) - - if return_z: - return z - - if (self.zaug_p is not None) and self.zaug_tmax_always: - tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax - z, _ = self.zaug_zdm_diffusion.add_noise(z, tz) - - z = self.denormalize_for_zdm(z) - z_dec = self.decode(z) - - coord = torch.zeros(batch_size, 2, render_res[0], render_res[1], device=z_dec.device) - scale = torch.zeros(batch_size, 2, render_res[0], render_res[1], device=z_dec.device) - return self.render(z_dec, coord, scale) - - def swap_ema_encoder(self): - _ = self.encoder - self.encoder = self.encoder_ema - self.encoder_ema = _ - - def swap_ema_decoder(self): - _ = self.decoder - self.decoder = self.decoder_ema - self.decoder_ema = _ - - def swap_ema_renderer(self): - _ = self.renderer - self.renderer = self.renderer_ema - self.renderer_ema = _ - - -class DiagonalGaussianDistribution(object): - - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.]) - else: - if other is None: - return 0.5 * torch.sum( - torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) - - def nll(self, sample, dims=[1,2,3]): - if self.deterministic: - return torch.Tensor([0.]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) - - def mode(self): - return self.mean diff --git a/dito/models/ldm/renderers.py b/dito/models/ldm/renderers.py deleted file mode 100644 index c7fcd2189dd0ad26fd313f738860d7d5eda39b08..0000000000000000000000000000000000000000 --- a/dito/models/ldm/renderers.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch.nn as nn - -import models -from models import register - - -@register('fixres_renderer_wrapper') -class FixresRendererWrapper(nn.Module): - - def __init__(self, net): - super().__init__() - self.net = models.make(net) - - def forward(self, x, coord=None, scale=None, **kwargs): - return self.net(x, **kwargs) - - def 
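# Usage sketch for DiagonalGaussianDistribution above: an encoder emits 2*C channels
# that chunk into mean and log-variance, sampling uses the reparameterization trick,
# and kl() is taken against a standard normal:
import torch

params = torch.randn(4, 8, 16, 16)   # 2*C parameter channels, C = 4
post = DiagonalGaussianDistribution(params)
z = post.sample()                    # mean + std * N(0, I), shape (4, 4, 16, 16)
kl = post.kl().mean()                # 0.5 * sum(mean^2 + var - 1 - logvar) per sample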
get_last_layer_weight(self): - return self.net.get_last_layer_weight() diff --git a/dito/models/ldm/vqgan/__init__.py b/dito/models/ldm/vqgan/__init__.py deleted file mode 100644 index 16281fe0b66dbac563229823d656ef173736e306..0000000000000000000000000000000000000000 --- a/dito/models/ldm/vqgan/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .utils import * diff --git a/dito/models/ldm/vqgan/discriminator.py b/dito/models/ldm/vqgan/discriminator.py deleted file mode 100644 index b420d617d08bc2625944f8f8dead7aa60c866808..0000000000000000000000000000000000000000 --- a/dito/models/ldm/vqgan/discriminator.py +++ /dev/null @@ -1,154 +0,0 @@ -import functools -import torch -import torch.nn as nn - - -def make_discriminator(**kwargs): - return NLayerDiscriminator(**kwargs).apply(weights_init) - - -def weights_init(m): - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find('BatchNorm') != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0) - - -class NLayerDiscriminator(nn.Module): - """Defines a PatchGAN discriminator as in Pix2Pix - --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py - """ - def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): - """Construct a PatchGAN discriminator - Parameters: - input_nc (int) -- the number of channels in input images - ndf (int) -- the number of filters in the last conv layer - n_layers (int) -- the number of conv layers in the discriminator - norm_layer -- normalization layer - """ - super(NLayerDiscriminator, self).__init__() - if not use_actnorm: - norm_layer = nn.BatchNorm2d - else: - norm_layer = ActNorm - if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters - use_bias = norm_layer.func != nn.BatchNorm2d - else: - use_bias = norm_layer != nn.BatchNorm2d - - kw = 4 - padw = 1 - sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] - nf_mult = 1 - nf_mult_prev = 1 - for n in range(1, n_layers): # gradually increase the number of filters - nf_mult_prev = nf_mult - nf_mult = min(2 ** n, 8) - sequence += [ - nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), - norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2, True) - ] - - nf_mult_prev = nf_mult - nf_mult = min(2 ** n_layers, 8) - sequence += [ - nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), - norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2, True) - ] - - sequence += [ - nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map - self.main = nn.Sequential(*sequence) - - def forward(self, input): - """Standard forward.""" - return self.main(input) - - -class ActNorm(nn.Module): - def __init__(self, num_features, logdet=False, affine=True, - allow_reverse_init=False): - assert affine - super().__init__() - self.logdet = logdet - self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) - self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) - self.allow_reverse_init = allow_reverse_init - - self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) - - def initialize(self, input): - with torch.no_grad(): - flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) - mean = ( - flatten.mean(1) - .unsqueeze(1) - .unsqueeze(2) - .unsqueeze(3) - .permute(1, 0, 2, 3) - ) - 
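# With the defaults (input_nc=3, ndf=64, n_layers=3) the PatchGAN above maps a
# 256x256 image to a 30x30 grid of logits, one per overlapping ~70x70 patch;
# a quick shape check:
import torch

disc = make_discriminator(input_nc=3, ndf=64, n_layers=3)
logits = disc(torch.randn(1, 3, 256, 256))
print(logits.shape)   # torch.Size([1, 1, 30, 30])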
std = ( - flatten.std(1) - .unsqueeze(1) - .unsqueeze(2) - .unsqueeze(3) - .permute(1, 0, 2, 3) - ) - - self.loc.data.copy_(-mean) - self.scale.data.copy_(1 / (std + 1e-6)) - - def forward(self, input, reverse=False): - if reverse: - return self.reverse(input) - if len(input.shape) == 2: - input = input[:,:,None,None] - squeeze = True - else: - squeeze = False - - _, _, height, width = input.shape - - if self.training and self.initialized.item() == 0: - self.initialize(input) - self.initialized.fill_(1) - - h = self.scale * (input + self.loc) - - if squeeze: - h = h.squeeze(-1).squeeze(-1) - - if self.logdet: - log_abs = torch.log(torch.abs(self.scale)) - logdet = height*width*torch.sum(log_abs) - logdet = logdet * torch.ones(input.shape[0]).to(input) - return h, logdet - - return h - - def reverse(self, output): - if self.training and self.initialized.item() == 0: - if not self.allow_reverse_init: - raise RuntimeError( - "Initializing ActNorm in reverse direction is " - "disabled by default. Use allow_reverse_init=True to enable." - ) - else: - self.initialize(output) - self.initialized.fill_(1) - - if len(output.shape) == 2: - output = output[:,:,None,None] - squeeze = True - else: - squeeze = False - - h = output / self.scale - self.loc - - if squeeze: - h = h.squeeze(-1).squeeze(-1) - return h diff --git a/dito/models/ldm/vqgan/lpips.py b/dito/models/ldm/vqgan/lpips.py deleted file mode 100644 index 289e50e34c8945a944486d7f0ee04f76de68eb78..0000000000000000000000000000000000000000 --- a/dito/models/ldm/vqgan/lpips.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" - -import torch -import torch.nn as nn -from torchvision import models -from collections import namedtuple - - -class LPIPS(nn.Module): - # Learned perceptual metric - def __init__(self, ckpt='load/vgg_lpips.pth', use_dropout=True): - super().__init__() - self.scaling_layer = ScalingLayer() - self.chns = [64, 128, 256, 512, 512] # vgg16 features - self.net = vgg16(pretrained=True, requires_grad=False) - self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) - self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) - self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) - self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) - self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) - self.load_from_pretrained(ckpt) - for param in self.parameters(): - param.requires_grad = False - - def load_from_pretrained(self, ckpt): - self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) - # print("loaded pretrained LPIPS loss from {}".format(ckpt)) - - def forward(self, input, target): - in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) - outs0, outs1 = self.net(in0_input), self.net(in1_input) - feats0, feats1, diffs = {}, {}, {} - lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] - for kk in range(len(self.chns)): - feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) - diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 - - res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] - val = res[0] - for l in range(1, len(self.chns)): - val += res[l] - return val - - -class ScalingLayer(nn.Module): - def __init__(self): - super(ScalingLayer, self).__init__() - self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) - self.register_buffer('scale',
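# ActNorm initializes loc/scale from the first batch it sees in training mode, so the
# first output is per-channel zero-mean and unit-variance; a quick check:
import torch

norm = ActNorm(num_features=8)
norm.train()
x = torch.randn(16, 8, 4, 4) * 3.0 + 1.0
y = norm(x)                                    # first call triggers data-dependent init
flat = y.permute(1, 0, 2, 3).reshape(8, -1)
print(flat.mean(1), flat.std(1))               # ~0 and ~1 per channel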
torch.Tensor([.458, .448, .450])[None, :, None, None]) - - def forward(self, inp): - return (inp - self.shift) / self.scale - - -class NetLinLayer(nn.Module): - """ A single linear layer which does a 1x1 conv """ - def __init__(self, chn_in, chn_out=1, use_dropout=False): - super(NetLinLayer, self).__init__() - layers = [nn.Dropout(), ] if (use_dropout) else [] - layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] - self.model = nn.Sequential(*layers) - - -class vgg16(torch.nn.Module): - def __init__(self, requires_grad=False, pretrained=True): - super(vgg16, self).__init__() - if pretrained: - vgg_pretrained_features = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features - else: - vgg_pretrained_features = models.vgg16().features - self.slice1 = torch.nn.Sequential() - self.slice2 = torch.nn.Sequential() - self.slice3 = torch.nn.Sequential() - self.slice4 = torch.nn.Sequential() - self.slice5 = torch.nn.Sequential() - self.N_slices = 5 - for x in range(4): - self.slice1.add_module(str(x), vgg_pretrained_features[x]) - for x in range(4, 9): - self.slice2.add_module(str(x), vgg_pretrained_features[x]) - for x in range(9, 16): - self.slice3.add_module(str(x), vgg_pretrained_features[x]) - for x in range(16, 23): - self.slice4.add_module(str(x), vgg_pretrained_features[x]) - for x in range(23, 30): - self.slice5.add_module(str(x), vgg_pretrained_features[x]) - if not requires_grad: - for param in self.parameters(): - param.requires_grad = False - - def forward(self, X): - h = self.slice1(X) - h_relu1_2 = h - h = self.slice2(h) - h_relu2_2 = h - h = self.slice3(h) - h_relu3_3 = h - h = self.slice4(h) - h_relu4_3 = h - h = self.slice5(h) - h_relu5_3 = h - vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) - out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) - return out - - -def normalize_tensor(x,eps=1e-10): - norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) - return x/(norm_factor+eps) - - -def spatial_average(x, keepdim=True): - return x.mean([2,3],keepdim=keepdim) diff --git a/dito/models/ldm/vqgan/model.py b/dito/models/ldm/vqgan/model.py deleted file mode 100644 index 6b82e6314bfd43a87ad48c45db6434c60bbc8662..0000000000000000000000000000000000000000 --- a/dito/models/ldm/vqgan/model.py +++ /dev/null @@ -1,845 +0,0 @@ -# pytorch_diffusion + derived encoder decoder -import math -import torch -import torch.nn as nn -import numpy as np -from einops import rearrange - -from models import register - - -class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super().__init__() - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) - context = torch.einsum('bhdn,bhen->bhde', k, v) - out = torch.einsum('bhde,bhdn->bhen', context, q) - out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) - return self.to_out(out) - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. 
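# Usage sketch for the LPIPS metric above: inputs are expected roughly in [-1, 1],
# and the weights file is the constructor's default path (it must exist locally;
# the VGG16 backbone itself comes from torchvision):
import torch

lpips = LPIPS(ckpt='load/vgg_lpips.pth').eval()
x = torch.rand(2, 3, 256, 256) * 2 - 1
y = torch.rand(2, 3, 256, 256) * 2 - 1
with torch.no_grad():
    d = lpips(x, y)        # shape (2, 1, 1, 1); larger means more perceptually different
print(d.flatten())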
- This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0,1,0,0)) - return emb - - -def nonlinearity(x): - # swish - return x*torch.sigmoid(x) - - -def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=2, - padding=0) - - def forward(self, x): - if self.with_conv: - pad = (0,1,0,1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, kernel_size=3, conv_shortcut=False, - dropout, temb_channels=512, normalize=True): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) if normalize else torch.nn.Identity() - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - padding=(kernel_size - 1) // 2) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, - out_channels) - self.norm2 = Normalize(out_channels) if normalize else torch.nn.Identity() - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - padding=(kernel_size - 1) // 2) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - padding=(kernel_size - 1) // 2) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x+h - - -class LinAttnBlock(LinearAttention): - """to match AttnBlock usage""" - def __init__(self, in_channels): - 
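# A quick check of get_timestep_embedding above: the first half of the channels are
# sine terms and the second half cosine terms at geometrically spaced frequencies:
import torch

emb = get_timestep_embedding(torch.tensor([0, 10, 999]), embedding_dim=128)
print(emb.shape)       # torch.Size([3, 128])
print(emb[0, :64])     # all zeros: sin(0)
print(emb[0, 64:])     # all ones: cos(0)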
super().__init__(dim=in_channels, heads=1, dim_head=in_channels) - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) - - h_ = self.proj_out(h_) - - return x+h_ - - -def make_attn(in_channels, attn_type="vanilla"): - assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' - # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - return AttnBlock(in_channels) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - return LinAttnBlock(in_channels) - - -class Model(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): - super().__init__() - if use_linear_attn: attn_type = "linear" - self.ch = ch - self.temb_ch = self.ch*4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, - self.temb_ch), - torch.nn.Linear(self.temb_ch, - self.temb_ch), - ]) - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) - - curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions-1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - 
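# AttnBlock above is single-head attention over the h*w spatial positions; its bmm
# arithmetic matches softmax(q k^T / sqrt(c)) v, which can be checked against
# torch's fused implementation (PyTorch >= 2.0):
import torch
import torch.nn.functional as F

b, c, hw = 2, 32, 64
q = torch.randn(b, hw, c)
k = torch.randn(b, c, hw)
v = torch.randn(b, c, hw)
w_ = torch.softmax(torch.bmm(q, k) * c ** -0.5, dim=2)        # (b, hw, hw)
out = torch.bmm(v, w_.permute(0, 2, 1))                        # (b, c, hw), as in AttnBlock
ref = F.scaled_dot_product_attention(q, k.permute(0, 2, 1), v.permute(0, 2, 1))
print(torch.allclose(out, ref.permute(0, 2, 1), atol=1e-5))    # True up to float tolerance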
out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - skip_in = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - if i_block == self.num_res_blocks: - skip_in = ch*in_ch_mult[i_level] - block.append(ResnetBlock(in_channels=block_in+skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x, t=None, context=None): - #assert x.shape[2] == x.shape[3] == self.resolution - if context is not None: - # assume aligned context, cat along channel axis - x = torch.cat((x, context), dim=1) - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions-1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - def get_last_layer(self): - return self.conv_out.weight - - -class Encoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8,16), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", - **ignore_kwargs): - super().__init__() - if use_linear_attn: attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - print('ch_mult: ', ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) - - curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - print('num_resolutions: ', self.num_resolutions) - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - 
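# Model.forward above threads skip connections through a plain Python list used as a
# stack: push every activation on the way down, pop (in reverse) on the way up.
# The pattern in miniature, with toy convs in place of the ResNet/attention blocks:
import torch
import torch.nn as nn

downs = nn.ModuleList([nn.Conv2d(8, 8, 3, stride=2, padding=1) for _ in range(3)])
ups = nn.ModuleList([nn.ConvTranspose2d(16, 8, 2, stride=2) for _ in range(3)])
x = torch.randn(1, 8, 32, 32)
skips = [x]
for d in downs:
    x = d(x)
    skips.append(x)                                # push at every resolution
for u in ups:
    x = u(torch.cat([x, skips.pop()], dim=1))      # pop in reverse order
assert x.shape == (1, 8, 32, 32)                   # (the full model also consumes skips[0])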
out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions-1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # end - z_out = 2*z_channels if double_z else z_channels - print('z_out: ', z_out) - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - z_out, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions-1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - print('encoder h shape: ', h.shape) - return h - - -class Decoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - attn_type="vanilla", **ignorekwargs): - super().__init__() - if use_linear_attn: attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - self.tanh_out = tanh_out - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,)+tuple(ch_mult) - block_in = ch*ch_mult[self.num_resolutions-1] - curr_res = resolution // 2**(self.num_resolutions-1) - self.z_shape = (1,z_channels,curr_res,curr_res) - # print("Working with z of shape {} = {} dimensions.".format( - # self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, - block_in, - kernel_size=3, - stride=1, - padding=1) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, 
attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - if not self.give_pre_end: - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, z): - #assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - if self.tanh_out: - h = torch.tanh(h) - return h - - -@register('simple_renderer_net') -class SimpleRendererNet(nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels=3, kernel_size=3, normalize=True, *args, **kwargs): - super().__init__() - self.model = nn.ModuleList([nn.Conv2d(in_channels, hidden_channels, kernel_size, padding=(kernel_size - 1) // 2), - ResnetBlock(in_channels=hidden_channels, - out_channels=hidden_channels, - kernel_size=kernel_size, - temb_channels=0, dropout=0.0, normalize=normalize), - ResnetBlock(in_channels=hidden_channels, - out_channels=hidden_channels, - kernel_size=kernel_size, - temb_channels=0, dropout=0.0, normalize=normalize)]) - self.norm_out = Normalize(hidden_channels) if normalize else torch.nn.Identity() - self.conv_out = torch.nn.Conv2d(hidden_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2) - - def get_last_layer_weight(self): - return self.conv_out.weight - - def forward(self, x): - for i, layer in enumerate(self.model): - if i in [1, 2]: - x = layer(x, None) - else: - x = layer(x) - - h = self.norm_out(x) - h = nonlinearity(h) - x = self.conv_out(h) - return x - - -@register('vqgan_last_conv') -class VQGANLastConv(nn.Module): - def __init__(self, in_channels, out_channels=3): - super().__init__() - self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def get_last_layer_weight(self): - return self.conv_out.weight - - def forward(self, x): - x = self.norm_out(x) - x = nonlinearity(x) - x = self.conv_out(x) - return x - - -class SimpleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, *args, **kwargs): - super().__init__() - self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - nn.Conv2d(2*in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True)]) - # end - self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - for i, layer in enumerate(self.model): - if i in 
[1,2,3]: - x = layer(x, None) - else: - x = layer(x) - - h = self.norm_out(x) - h = nonlinearity(h) - x = self.conv_out(h) - return x - - -class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, - ch_mult=(2,2), dropout=0.0): - super().__init__() - # upsampling - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - block_in = in_channels - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.res_blocks = nn.ModuleList() - self.upsample_blocks = nn.ModuleList() - for i_level in range(self.num_resolutions): - res_block = [] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - res_block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - self.res_blocks.append(nn.ModuleList(res_block)) - if i_level != self.num_resolutions - 1: - self.upsample_blocks.append(Upsample(block_in, True)) - curr_res = curr_res * 2 - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - # upsampling - h = x - for k, i_level in enumerate(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.res_blocks[i_level][i_block](h, None) - if i_level != self.num_resolutions - 1: - h = self.upsample_blocks[k](h) - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): - super().__init__() - # residual block, interpolate, residual block - self.factor = factor - self.conv_in = nn.Conv2d(in_channels, - mid_channels, - kernel_size=3, - stride=1, - padding=1) - self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) - self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) - - self.conv_out = nn.Conv2d(mid_channels, - out_channels, - kernel_size=1, - ) - - def forward(self, x): - x = self.conv_in(x) - for block in self.res_block1: - x = block(x, None) - x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) - x = self.attn(x) - for block in self.res_block2: - x = block(x, None) - x = self.conv_out(x) - return x - - -class MergedRescaleEncoder(nn.Module): - def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, - ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): - super().__init__() - intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, - z_channels=intermediate_chn, double_z=False, resolution=resolution, - attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, - out_ch=None) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, - mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) - - def forward(self, x): - x = self.encoder(x) - x = self.rescaler(x) - return x - - -class MergedRescaleDecoder(nn.Module): - def __init__(self, z_channels, out_ch, 
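# LatentRescaler above resizes by an arbitrary (even fractional) factor: residual
# blocks, a plain interpolation to the target size, attention, more residual blocks.
# E.g. factor=0.5 halves a latent grid while remapping 16 -> 8 channels:
import torch

rescaler = LatentRescaler(factor=0.5, in_channels=16, mid_channels=32, out_channels=8, depth=1)
print(rescaler(torch.randn(1, 16, 32, 32)).shape)   # torch.Size([1, 8, 16, 16])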
resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), - dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): - super().__init__() - tmp_chn = z_channels*ch_mult[-1] - self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, - resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, - ch_mult=ch_mult, resolution=resolution, ch=ch) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, - out_channels=tmp_chn, depth=rescale_module_depth) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): - super().__init__() - assert out_size >= in_size - num_blocks = int(np.log2(out_size//in_size))+1 - factor_up = 1.+ (out_size % in_size) - print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") - self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, - out_channels=in_channels) - self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, - attn_resolutions=[], in_channels=None, ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)]) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): - super().__init__() - self.with_conv = learned - self.mode = mode - if self.with_conv: - print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode") - raise NotImplementedError() - assert in_channels is not None - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=4, - stride=2, - padding=1) - - def forward(self, x, scale_factor=1.0): - if scale_factor==1.0: - return x - else: - x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) - return x diff --git a/dito/models/ldm/vqgan/quantizer.py b/dito/models/ldm/vqgan/quantizer.py deleted file mode 100644 index 863475a68248d28775708875189144d2704540cc..0000000000000000000000000000000000000000 --- a/dito/models/ldm/vqgan/quantizer.py +++ /dev/null @@ -1,123 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange - - -class VectorQuantizer(nn.Module): - """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly - avoids costly matrix multiplications and allows for post-hoc remapping of indices. - """ - # NOTE: due to a bug the beta term was applied to the wrong term. for - # backwards compatibility we use the buggy version by default, but you can - # specify legacy=False to fix it.
- def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", - sane_index_shape=False, legacy=True): - super().__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - self.legacy = legacy - - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed+1 - print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices.") - else: - self.re_embed = n_e - - self.sane_index_shape = sane_index_shape - - def remap_to_used(self, inds): - ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) - used = self.used.to(inds) - match = (inds[:,:,None]==used[None,None,...]).long() - new = match.argmax(-1) - unknown = match.sum(2)<1 - if self.unknown_index == "random": - new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds): - ishape = inds.shape - assert len(ishape)>1 - inds = inds.reshape(ishape[0],-1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds>=self.used.shape[0]] = 0 # simply set to zero - back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) - return back.reshape(ishape) - - def forward(self, z, temp=None, rescale_logits=False, return_logits=False): - assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" - assert rescale_logits==False, "Only for interface compatible with Gumbel" - assert return_logits==False, "Only for interface compatible with Gumbel" - # reshape z -> (batch, height, width, channel) and flatten - z = rearrange(z, 'b c h w -> b h w c').contiguous() - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ - torch.sum(self.embedding.weight**2, dim=1) - 2 * \ - torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) - - min_encoding_indices = torch.argmin(d, dim=1) - z_q = self.embedding(min_encoding_indices).view(z.shape) - perplexity = None - min_encodings = None - - # compute loss for embedding - if not self.legacy: - loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ - torch.mean((z_q - z.detach()) ** 2) - else: - loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ - torch.mean((z_q - z.detach()) ** 2) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten - - if self.sane_index_shape: - min_encoding_indices = min_encoding_indices.reshape( - z_q.shape[0], z_q.shape[2], z_q.shape[3]) - - return z_q, loss, (perplexity, min_encodings, min_encoding_indices) - - def get_codebook_entry(self, indices, shape): - # shape specifying (batch, height, width, channel) - if 
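# Usage sketch for the quantizer above: inputs are (B, C, H, W) latents with
# C == e_dim; gradients reach the encoder through the straight-through estimator
# z_q = z + (z_q - z).detach() despite the hard argmin assignment:
import torch

vq = VectorQuantizer(n_e=512, e_dim=16, beta=0.25, remap=None, sane_index_shape=True)
z = torch.randn(2, 16, 8, 8, requires_grad=True)
z_q, loss, (_, _, indices) = vq(z)
print(z_q.shape, indices.shape)   # (2, 16, 8, 8) and (2, 8, 8) with sane_index_shape
z_q.sum().backward()              # grad flows back to z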
self.remap is not None: - indices = indices.reshape(shape[0],-1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q diff --git a/dito/models/ldm/vqgan/utils.py b/dito/models/ldm/vqgan/utils.py deleted file mode 100644 index f895a1ab1d3d32cb6cdae7ef1bfca5765bd16b35..0000000000000000000000000000000000000000 --- a/dito/models/ldm/vqgan/utils.py +++ /dev/null @@ -1,57 +0,0 @@ -import torch.nn as nn - - -from models import register -from .model import Encoder, Decoder - - -default_configs = { - 'f8c4': dict( - double_z=False, - z_channels=64, - resolution=256, - in_channels=3, - out_ch=3, - ch=128, - ch_mult=[1, 2, 2, 4, 4, 4, 4, 8, 8], - num_res_blocks=2, - attn_resolutions=[], - dropout=0.0, - give_pre_end=True, - ), - 'f16c8': dict( - double_z=False, - z_channels=8, - resolution=256, - in_channels=3, - out_ch=3, - ch=128, - ch_mult=[1, 2, 4, 4, 4], - num_res_blocks=2, - attn_resolutions=[], - dropout=0.0, - give_pre_end=True, - ), -} - - -@register('vqgan_encoder') -def make_vqgan_encoder(config_name, **kwargs): - encoder_kwargs = default_configs[config_name] - encoder_kwargs.update(kwargs) - enc_out_channels = encoder_kwargs['z_channels'] * (2 if encoder_kwargs['double_z'] else 1) - return nn.Sequential( - Encoder(**encoder_kwargs), - nn.Conv2d(enc_out_channels, enc_out_channels, 1), - ) - - -@register('vqgan_decoder') -def make_vqgan_decoder(config_name, **kwargs): - decoder_kwargs = default_configs[config_name] - decoder_kwargs.update(kwargs) - dec_in_channels = decoder_kwargs['z_channels'] - return nn.Sequential( - nn.Conv2d(dec_in_channels, dec_in_channels, 1), - Decoder(**decoder_kwargs), - ) diff --git a/dito/models/models.py b/dito/models/models.py deleted file mode 100644 index 56e8d189cb0418a75de81461dadb472691590cca..0000000000000000000000000000000000000000 --- a/dito/models/models.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch - - -models = dict() - - -def register(name): - def decorator(cls): - models[name] = cls - return cls - return decorator - - -def load_sd_from_ckpt(ckpt, keys_only=None): - sd = torch.load(ckpt, map_location='cpu')['model']['sd'] - if keys_only is not None: - keys_only_dot = tuple([_ + '.' for _ in keys_only]) - keys_only = set(keys_only) - for k in list(sd.keys()): - if not (k in keys_only or k.startswith(keys_only_dot)): - sd.pop(k) - return sd - - -def make(spec, load_sd=False): - args = spec.get('args') - if args is None: - args = dict() - model = models[spec['name']](**args) - print('args', args) - - if spec.get('load_ckpt') is not None: - sd = load_sd_from_ckpt(spec['load_ckpt'], spec.get('load_ckpt_keys_only')) - model.load_state_dict(sd, strict=False) - - if load_sd: - model.load_state_dict(spec['sd']) - - return model - - -@register('identity') -def make_identity(): - return torch.nn.Identity() diff --git a/dito/models/networks/__init__.py b/dito/models/networks/__init__.py deleted file mode 100644 index 8014d9824e1b32c71956711e6dcd307e9a4c5920..0000000000000000000000000000000000000000 --- a/dito/models/networks/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import consistency_decoder_unet -from . 
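# Note the spatial reduction each config implies: the Encoder halves the resolution
# between consecutive ch_mult levels, so the factor is 2**(len(ch_mult)-1) --
# 2**8 = 256 for 'f8c4' (a 256x256 image becomes a 1x1, 64-channel latent) and
# 2**4 = 16 for 'f16c8'. A shape check (slow on CPU; the model is large):
import torch

enc = make_vqgan_encoder('f8c4')
with torch.no_grad():
    z = enc(torch.randn(1, 3, 256, 256))
print(z.shape)   # torch.Size([1, 64, 1, 1])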
import dit \ No newline at end of file diff --git a/dito/models/networks/consistency_decoder_unet.py b/dito/models/networks/consistency_decoder_unet.py deleted file mode 100644 index 2a729348e88d988901cb7cada77eb699523c603b..0000000000000000000000000000000000000000 --- a/dito/models/networks/consistency_decoder_unet.py +++ /dev/null @@ -1,268 +0,0 @@ -# https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac - -import torch -import torch.nn.functional as F -import torch.nn as nn - -from models import register - - -class TimestepEmbedding(nn.Module): - def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None: - super().__init__() - self.emb = nn.Embedding(n_time, n_emb) - self.f_1 = nn.Linear(n_emb, n_out) - self.f_2 = nn.Linear(n_out, n_out) - - def forward(self, x) -> torch.Tensor: - x = self.emb(x) - x = self.f_1(x) - x = F.silu(x) - return self.f_2(x) - - -class PositionalEmbedding(nn.Module): - def __init__(self, pe_dim=320, out_dim=1280, max_positions=10000, endpoint=True): - super().__init__() - self.num_channels = pe_dim - self.max_positions = max_positions - self.endpoint = endpoint - self.f_1 = nn.Linear(pe_dim, out_dim) - self.f_2 = nn.Linear(out_dim, out_dim) - - def forward(self, x): - freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device) - freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) - freqs = (1 / self.max_positions) ** freqs - x = x.ger(freqs.to(x.dtype)) - x = torch.cat([x.cos(), x.sin()], dim=1) - - x = self.f_1(x) - x = F.silu(x) - return self.f_2(x) - - -class ImageEmbedding(nn.Module): - def __init__(self, in_channels, out_channels=320) -> None: - super().__init__() - self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) - - def forward(self, x) -> torch.Tensor: - return self.f(x) - - -class ImageUnembedding(nn.Module): - def __init__(self, in_channels=320, out_channels=3) -> None: - super().__init__() - self.gn = nn.GroupNorm(32, in_channels) - self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) - - def forward(self, x) -> torch.Tensor: - return self.f(F.silu(self.gn(x))) - - -class ConvResblock(nn.Module): - def __init__(self, in_features, out_features, t_dim) -> None: - super().__init__() - self.f_t = nn.Linear(t_dim, out_features * 2) - - self.gn_1 = nn.GroupNorm(32, in_features) - self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1) - - self.gn_2 = nn.GroupNorm(32, out_features) - self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1) - - skip_conv = in_features != out_features - self.f_s = ( - nn.Conv2d(in_features, out_features, kernel_size=1, padding=0) - if skip_conv - else nn.Identity() - ) - - def forward(self, x, t): - x_skip = x - t = self.f_t(F.silu(t)) - t = t.chunk(2, dim=1) - t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1 - t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3) - - gn_1 = F.silu(self.gn_1(x)) - f_1 = self.f_1(gn_1) - - gn_2 = self.gn_2(f_1) - - return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2)) - - -# Also ConvResblock -class Downsample(nn.Module): - def __init__(self, in_channels, t_dim) -> None: - super().__init__() - self.f_t = nn.Linear(t_dim, in_channels * 2) - - self.gn_1 = nn.GroupNorm(32, in_channels) - self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) - self.gn_2 = nn.GroupNorm(32, in_channels) - - self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) - - def forward(self, x, t) -> torch.Tensor: - x_skip = x - - t = 
self.f_t(F.silu(t)) - t_1, t_2 = t.chunk(2, dim=1) - t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1 - t_2 = t_2.unsqueeze(2).unsqueeze(3) - - gn_1 = F.silu(self.gn_1(x)) - avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None) - f_1 = self.f_1(avg_pool2d) - gn_2 = self.gn_2(f_1) - - f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2))) - - return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None) - - -# Also ConvResblock -class Upsample(nn.Module): - def __init__(self, in_channels, t_dim) -> None: - super().__init__() - self.f_t = nn.Linear(t_dim, in_channels * 2) - - self.gn_1 = nn.GroupNorm(32, in_channels) - self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) - self.gn_2 = nn.GroupNorm(32, in_channels) - - self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) - - def forward(self, x, t) -> torch.Tensor: - x_skip = x - - t = self.f_t(F.silu(t)) - t_1, t_2 = t.chunk(2, dim=1) - t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1 - t_2 = t_2.unsqueeze(2).unsqueeze(3) - - gn_1 = F.silu(self.gn_1(x)) - upsample = F.upsample_nearest(gn_1, scale_factor=2) - f_1 = self.f_1(upsample) - gn_2 = self.gn_2(f_1) - - f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2))) - - return f_2 + F.upsample_nearest(x_skip, scale_factor=2) - - -@register('consistency_decoder_unet') -class ConsistencyDecoderUNet(nn.Module): - def __init__(self, in_channels=3, z_dec_channels=None, c0=320, c1=640, c2=1024, pe_dim=320, t_dim=1280) -> None: - super().__init__() - if z_dec_channels is not None: - in_channels += z_dec_channels - self.embed_image = ImageEmbedding(in_channels=in_channels, out_channels=c0) - self.embed_time = PositionalEmbedding(pe_dim=pe_dim, out_dim=t_dim) - - down_0 = nn.ModuleList([ - ConvResblock(c0, c0, t_dim), - ConvResblock(c0, c0, t_dim), - ConvResblock(c0, c0, t_dim), - Downsample(c0, t_dim), - ]) - down_1 = nn.ModuleList([ - ConvResblock(c0, c1, t_dim), - ConvResblock(c1, c1, t_dim), - ConvResblock(c1, c1, t_dim), - Downsample(c1, t_dim), - ]) - down_2 = nn.ModuleList([ - ConvResblock(c1, c2, t_dim), - ConvResblock(c2, c2, t_dim), - ConvResblock(c2, c2, t_dim), - Downsample(c2, t_dim), - ]) - down_3 = nn.ModuleList([ - ConvResblock(c2, c2, t_dim), - ConvResblock(c2, c2, t_dim), - ConvResblock(c2, c2, t_dim), - ]) - self.down = nn.ModuleList([ - down_0, - down_1, - down_2, - down_3, - ]) - - self.mid = nn.ModuleList([ - ConvResblock(c2, c2, t_dim), - ConvResblock(c2, c2, t_dim), - ]) - - up_3 = nn.ModuleList([ - ConvResblock(c2 * 2, c2, t_dim), - ConvResblock(c2 * 2, c2, t_dim), - ConvResblock(c2 * 2, c2, t_dim), - ConvResblock(c2 * 2, c2, t_dim), - Upsample(c2, t_dim), - ]) - up_2 = nn.ModuleList([ - ConvResblock(c2 * 2, c2, t_dim), - ConvResblock(c2 * 2, c2, t_dim), - ConvResblock(c2 * 2, c2, t_dim), - ConvResblock(c2 + c1, c2, t_dim), - Upsample(c2, t_dim), - ]) - up_1 = nn.ModuleList([ - ConvResblock(c2 + c1, c1, t_dim), - ConvResblock(c1 * 2, c1, t_dim), - ConvResblock(c1 * 2, c1, t_dim), - ConvResblock(c0 + c1, c1, t_dim), - Upsample(c1, t_dim), - ]) - up_0 = nn.ModuleList([ - ConvResblock(c0 + c1, c0, t_dim), - ConvResblock(c0 * 2, c0, t_dim), - ConvResblock(c0 * 2, c0, t_dim), - ConvResblock(c0 * 2, c0, t_dim), - ]) - self.up = nn.ModuleList([ - up_0, - up_1, - up_2, - up_3, - ]) - - self.output = ImageUnembedding(in_channels=c0) - - def get_last_layer_weight(self): - return self.output.f.weight - - def forward(self, x, t=None, z_dec=None) -> torch.Tensor: - if z_dec is not None: - if z_dec.shape[-2] != x.shape[-2] or z_dec.shape[-1] != x.shape[-1]: - assert x.shape[-2] 
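# All three block types above (ConvResblock, Downsample, Upsample) condition on time
# the same way: a linear layer maps the SiLU'd embedding to a scale/shift pair applied
# after the second GroupNorm (FiLM-style); the modulation in isolation:
import torch
import torch.nn as nn
import torch.nn.functional as F

t_dim, ch = 1280, 64
f_t = nn.Linear(t_dim, ch * 2)
gn = nn.GroupNorm(32, ch)
h = torch.randn(2, ch, 16, 16)
t_emb = torch.randn(2, t_dim)
t_1, t_2 = f_t(F.silu(t_emb)).chunk(2, dim=1)
t_1 = t_1[:, :, None, None] + 1     # scale, biased toward identity
t_2 = t_2[:, :, None, None]         # shift
h = F.silu(gn(h) * t_1 + t_2)       # as in the blocks above (before the second conv)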
// z_dec.shape[-2] == x.shape[-1] // z_dec.shape[-1] - z_dec = F.interpolate(z_dec, scale_factor=x.shape[-2] // z_dec.shape[-2], mode='nearest') - x = torch.cat([x, z_dec], dim=1) - - x = self.embed_image(x) - - if t is None: - t = torch.zeros(x.shape[0], device=x.device) - t = self.embed_time(t) - - skips = [x] - for down in self.down: - for block in down: - x = block(x, t) - skips.append(x) - - for mid in self.mid: - x = mid(x, t) - - for up in self.up[::-1]: - for block in up: - if isinstance(block, ConvResblock): - x = torch.cat([x, skips.pop()], dim=1) - x = block(x, t) - - return self.output(x) diff --git a/dito/models/networks/dit.py b/dito/models/networks/dit.py deleted file mode 100644 index 439a6331f10c13a9eeb2c81793ab933475b61c2b..0000000000000000000000000000000000000000 --- a/dito/models/networks/dit.py +++ /dev/null @@ -1,384 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- -# References: -# GLIDE: https://github.com/openai/glide-text2im -# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py -# -------------------------------------------------------- -import math - -import torch -import torch.nn as nn -import numpy as np -from timm.models.vision_transformer import PatchEmbed, Attention, Mlp - -from models import register - - -def modulate(x, shift, scale): - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - - -################################################################################# -# Embedding Layers for Timesteps and Class Labels # -################################################################################# - -class TimestepEmbedder(nn.Module): - """ - Embeds scalar timesteps into vector representations. - """ - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - :param t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an (N, D) Tensor of positional embeddings. - """ - # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=t.device) - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - - -class LabelEmbedder(nn.Module): - """ - Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance: during training, each label is independently replaced by a learned null class (index num_classes) with probability dropout_prob.
- """ - def __init__(self, num_classes, hidden_size, dropout_prob): - super().__init__() - use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) - self.num_classes = num_classes - self.dropout_prob = dropout_prob - - def token_drop(self, labels, force_drop_ids=None): - """ - Drops labels to enable classifier-free guidance. - """ - if force_drop_ids is None: - drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob - else: - drop_ids = force_drop_ids == 1 - labels = torch.where(drop_ids, self.num_classes, labels) - return labels - - def forward(self, labels, train, force_drop_ids=None): - use_dropout = self.dropout_prob > 0 - if (train and use_dropout) or (force_drop_ids is not None): - labels = self.token_drop(labels, force_drop_ids) - embeddings = self.embedding_table(labels) - return embeddings - - -################################################################################# -# Core DiT Model # -################################################################################# - -class DiTBlock(nn.Module): - """ - A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. - """ - def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): - super().__init__() - self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) - self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - approx_gelu = lambda: nn.GELU(approximate="tanh") - self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) - x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) - x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FinalLayer(nn.Module): - """ - The final layer of DiT. - """ - def __init__(self, hidden_size, patch_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_size, 2 * hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - - -class DiT(nn.Module): - """ - Diffusion model with a Transformer backbone. 
- """ - def __init__( - self, - input_size=32, - patch_size=2, - in_channels=4, - hidden_size=1152, - depth=28, - num_heads=16, - mlp_ratio=4.0, - class_dropout_prob=0.0, - n_classes=1000, - learn_sigma=False, - ): - super().__init__() - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels * 2 if learn_sigma else in_channels - self.patch_size = patch_size - self.num_heads = num_heads - - self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(n_classes, hidden_size, class_dropout_prob) - num_patches = self.x_embedder.num_patches - # Will use fixed sin-cos embedding: - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) - - self.blocks = nn.ModuleList([ - DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) - ]) - self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) - self.initialize_weights() - - def initialize_weights(self): - # Initialize transformer layers: - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - self.apply(_basic_init) - - # Initialize (and freeze) pos_embed by sin-cos embedding: - pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) - self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) - - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # Zero-out adaLN modulation layers in DiT blocks: - for block in self.blocks: - nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - def unpatchify(self, x): - """ - x: (N, T, patch_size**2 * C) - imgs: (N, H, W, C) - """ - c = self.out_channels - p = self.x_embedder.patch_size[0] - h = w = int(x.shape[1] ** 0.5) - assert h * w == x.shape[1] - - x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) - x = torch.einsum('nhwpqc->nchpwq', x) - imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) - return imgs - - def forward(self, x, t, class_labels): - """ - Forward pass of DiT. 
- x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) - t: (N,) tensor of diffusion timesteps - class_labels: (N,) tensor of class labels - """ - x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 - t = self.t_embedder(t) # (N, D) - y = self.y_embedder(class_labels, self.training) # (N, D) - c = t + y # (N, D) - for block in self.blocks: - x = block(x, c) # (N, T, D) - x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) - x = self.unpatchify(x) # (N, out_channels, H, W) - return x - - def forward_with_cfg(self, x, t, y, cfg_scale): - """ - Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. - """ - # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb - half = x[: len(x) // 2] - combined = torch.cat([half, half], dim=0) - model_out = self.forward(combined, t, y) - # For exact reproducibility reasons, we apply classifier-free guidance on only - # three channels by default. The standard approach to cfg applies it to all channels. - # This can be done by uncommenting the following line and commenting-out the line following that. - # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] - eps, rest = model_out[:, :3], model_out[:, 3:] - cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) - half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) - eps = torch.cat([half_eps, half_eps], dim=0) - return torch.cat([eps, rest], dim=1) - - -################################################################################# -# Sine/Cosine Positional Embedding Functions # -################################################################################# -# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py - -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): - """ - grid_size: int of the grid height and width - return: - pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size, dtype=np.float32) - grid_w = np.arange(grid_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token and extra_tokens > 0: - pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) - out: (M, D) - """ - assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=np.float64) - omega /= embed_dim / 2. - omega = 1.
/ 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -################################################################################# -# DiT Configs # -################################################################################# - -@register('dit_xl_2') -def DiT_XL_2(**kwargs): - return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) - -@register('dit_xl_4') -def DiT_XL_4(**kwargs): - return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) - -@register('dit_xl_8') -def DiT_XL_8(**kwargs): - return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) - -@register('dit_l_2') -def DiT_L_2(**kwargs): - return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) - -@register('dit_l_4') -def DiT_L_4(**kwargs): - return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) - -@register('dit_l_8') -def DiT_L_8(**kwargs): - return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) - -@register('dit_b_2') -def DiT_B_2(**kwargs): - return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) - -@register('dit_b_4') -def DiT_B_4(**kwargs): - return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) - -@register('dit_b_8') -def DiT_B_8(**kwargs): - return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) - -@register('dit_s_2') -def DiT_S_2(**kwargs): - return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) - -@register('dit_s_4') -def DiT_S_4(**kwargs): - return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) - -@register('dit_s_8') -def DiT_S_8(**kwargs): - return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) - - -DiT_models = { - 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, - 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, - 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, - 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, -} \ No newline at end of file diff --git a/dito/reconstruction.py b/dito/reconstruction.py deleted file mode 100644 index a3fee6a4b8779c9cdbdadba258d2589024b6478d..0000000000000000000000000000000000000000 --- a/dito/reconstruction.py +++ /dev/null @@ -1,189 +0,0 @@ -import torch -import torch.nn as nn -from PIL import Image -from torchvision import transforms -import numpy as np -from pathlib import Path -import argparse - -# You'll need to have the DiTo codebase available -import models -from omegaconf import OmegaConf - -class DiToInference: - def __init__(self, checkpoint_path, device='cuda'): - """Initialize DiTo model from checkpoint""" - self.device = device - - # Load checkpoint - print(f"Loading checkpoint from {checkpoint_path}") - ckpt = torch.load(checkpoint_path, map_location='cpu') - - # Extract config - self.config = OmegaConf.create(ckpt['config']) - - # Create model - self.model = models.make(self.config['model']) - - # Load state dict - self.model.load_state_dict(ckpt['model']['sd']) - - # Move to device and set to eval - self.model = self.model.to(device) - self.model.eval() - - # Setup image transforms based on config - self.transform = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(256), - transforms.ToTensor(), - 
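# scale [0, 1] tensors to [-1, 1], the range the model was trained on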
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - ]) - - print("Model loaded successfully!") - - def reconstruct_image(self, image_path, debug=True): - """Reconstruct a single image""" - # Load and preprocess image - image = Image.open(image_path).convert('RGB') - - if debug: - debug_transform = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(256), - ]) - debug_image = debug_transform(image) - debug_image.save('debug_1_resized_cropped.png') - print("Saved debug_1_resized_cropped.png") - - image_tensor = self.transform(image).unsqueeze(0).to(self.device) - - with torch.no_grad(): - # Step 1: Encode to latent - z = self.model.encode(image_tensor) - - # Step 2: Decode to features (in DiTo this is identity) - z_dec = self.model.decode(z) - print('z_dec.shape:', z_dec.shape) - - # Step 3: Prepare coordinate grids - # Based on the training code, coord and scale are dummy values - b, c, h, w = image_tensor.shape - coord = torch.zeros(b, 2, h, w, device=self.device) - scale = torch.zeros(b, 2, h, w, device=self.device) - - # Step 4: Render using diffusion - reconstructed = self.model.render(z_dec, coord, scale) - - # Denormalize from [-1, 1] to [0, 1] - reconstructed = (reconstructed * 0.5 + 0.5).clamp(0, 1) - - return reconstructed - - def save_reconstruction(self, image_path, output_path): - """Reconstruct and save image""" - reconstructed = self.reconstruct_image(image_path) - - # Convert to PIL - to_pil = transforms.ToPILImage() - reconstructed_pil = to_pil(reconstructed.squeeze(0).cpu()) - - # Save - reconstructed_pil.save(output_path) - print(f"Saved reconstruction to {output_path}") - - def compare_reconstruction(self, image_path, output_path): - """Save original and reconstruction side by side""" - # Get reconstruction - reconstructed = self.reconstruct_image(image_path) - - # Load original at same resolution - original = Image.open(image_path).convert('RGB') - original = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(256), - transforms.ToTensor() - ])(original).unsqueeze(0) - - # Concatenate side by side - comparison = torch.cat([original, reconstructed.cpu()], dim=3) - - # Save - to_pil = transforms.ToPILImage() - comparison_pil = to_pil(comparison.squeeze(0)) - comparison_pil.save(output_path) - print(f"Saved comparison to {output_path}") - - def batch_reconstruct(self, image_folder, output_folder, max_images=None): - """Reconstruct all images in a folder""" - image_folder = Path(image_folder) - output_folder = Path(output_folder) - output_folder.mkdir(exist_ok=True, parents=True) - - # Get all images - image_paths = list(image_folder.glob('*.png')) + \ - list(image_folder.glob('*.jpg')) + \ - list(image_folder.glob('*.jpeg')) - - if max_images: - image_paths = image_paths[:max_images] - - print(f"Processing {len(image_paths)} images...") - - for img_path in image_paths: - output_path = output_folder / f"recon_{img_path.name}" - self.save_reconstruction(str(img_path), str(output_path)) - - print("Batch reconstruction complete!") - -def main(): - parser = argparse.ArgumentParser(description='DiTo Image Reconstruction') - parser.add_argument('--checkpoint', type=str, required=True, - help='Path to DiTo checkpoint') - parser.add_argument('--input', type=str, required=True, - help='Input image path or folder') - parser.add_argument('--output', type=str, required=True, - help='Output path') - parser.add_argument('--compare', action='store_true', - help='Save comparison with original') - parser.add_argument('--batch', 
action='store_true', - help='Process entire folder') - parser.add_argument('--device', type=str, default='cuda', - help='Device to use (cuda/cpu)') - parser.add_argument('--max_images', type=int, default=None, - help='Maximum images to process in batch mode') - - args = parser.parse_args() - - # Initialize model - dito = DiToInference(args.checkpoint, device=args.device) - - # Process based on mode - if args.batch: - dito.batch_reconstruct(args.input, args.output, args.max_images) - elif args.compare: - dito.compare_reconstruction(args.input, args.output) - else: - dito.save_reconstruction(args.input, args.output) - -# Example usage function for direct Python use -def reconstruct_single_image(checkpoint_path, image_path, output_path): - """Simple function to reconstruct a single image""" - dito = DiToInference(checkpoint_path) - dito.save_reconstruction(image_path, output_path) - -if __name__ == "__main__": - main() - -# Usage examples: -# 1. Single image reconstruction: -# python reconstruction.py --checkpoint ckpt-best.pth --input image.jpg --output recon.jpg -# -# 2. Single image with comparison: -# python reconstruction.py --checkpoint ckpt-best.pth --input image.jpg --output compare.jpg --compare -# -# 3. Batch processing: -# python reconstruction.py --checkpoint ckpt-best.pth --input input_folder/ --output output_folder/ --batch -# -# 4. Direct Python usage: -# reconstruct_single_image('ckpt-best.pth', 'input.jpg', 'output.jpg') \ No newline at end of file diff --git a/dito/requirements.txt b/dito/requirements.txt deleted file mode 100644 index 531d37316f2a1a85065157d3a1d91df624a07c96..0000000000000000000000000000000000000000 --- a/dito/requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -torch==2.3.0 -torchvision==0.18.0 -torch_fidelity==0.3.0 -omegaconf -pyyaml -wandb -webdataset -timm -einops \ No newline at end of file diff --git a/dito/run.py b/dito/run.py deleted file mode 100644 index 9387ef27d776ab0be9850aa23fd2fa912f5bd583..0000000000000000000000000000000000000000 --- a/dito/run.py +++ /dev/null @@ -1,59 +0,0 @@ -import argparse -import os - -from omegaconf import OmegaConf - -from trainers import trainers_dict - - -def make_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--config', default='configs/_.yaml') - parser.add_argument('--name', '-n', default=None) - parser.add_argument('--tag', '-t', default=None) - parser.add_argument('--resume', '-r', action='store_true') - parser.add_argument('--force-replace', '-f', action='store_true') - parser.add_argument('--wandb', '-w', action='store_true') - parser.add_argument('--save-root', default='save') - parser.add_argument('--eval-only', action='store_true') - args = parser.parse_args() - return args - - -def parse_config(config): - if config.get('__base__') is not None: - filenames = config.pop('__base__') - if isinstance(filenames, str): - filenames = [filenames] - base_config = OmegaConf.merge(*[ - parse_config(OmegaConf.load(_)) - for _ in filenames - ]) - config = OmegaConf.merge(base_config, config) - return config - - -def make_env(args): - env = dict() - - if args.name is None: - exp_name = os.path.splitext(os.path.basename(args.config))[0] - else: - exp_name = args.name - if args.tag is not None: - exp_name += '_' + args.tag - env['exp_name'] = exp_name - - env['save_dir'] = os.path.join(args.save_root, exp_name) - env['wandb'] = args.wandb - env['resume'] = args.resume - env['force_replace'] = args.force_replace - return env - - -if __name__ == '__main__': - args = make_args() - env = make_env(args) - 
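# parse_config resolves __base__ includes recursively, later files overriding earlier ones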
config = parse_config(OmegaConf.load(args.config)) - trainer = trainers_dict[config.trainer](env, config) - trainer.run(eval_only=args.eval_only) diff --git a/dito/utils/__init__.py b/dito/utils/__init__.py deleted file mode 100644 index 16281fe0b66dbac563229823d656ef173736e306..0000000000000000000000000000000000000000 --- a/dito/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .utils import * diff --git a/dito/utils/geometry.py b/dito/utils/geometry.py deleted file mode 100644 index 28e4c140a45d44f8fb9a3d14dbf58f2531a56695..0000000000000000000000000000000000000000 --- a/dito/utils/geometry.py +++ /dev/null @@ -1,37 +0,0 @@ -import torch - - -def make_coord_grid(shape, range=(0, 1), device='cpu', batch_size=None): - """ - Args: - shape: (s_1, ..., s_k), grid shape - range: range for each axis, list or tuple, [minv, maxv] or [[minv_1, maxv_1], ..., [minv_k, maxv_k]] - Returns: - (s_1, ..., s_k, k), coordinate grid - """ - p_lst = [] - for i, n in enumerate(shape): - p = (torch.arange(n, device=device) + 0.5) / n - if isinstance(range[0], list) or isinstance(range[0], tuple): - minv, maxv = range[i] - else: - minv, maxv = range - p = minv + (maxv - minv) * p - p_lst.append(p) - coord = torch.stack(torch.meshgrid(*p_lst, indexing='ij'), dim=-1) - - if batch_size is not None: - coord = coord.unsqueeze(0).expand(batch_size, *([-1] * coord.dim())) - return coord - - -def make_coord_scale_grid(shape, range=(0, 1), device='cpu', batch_size=None): - coord = make_coord_grid(shape, range=range, device=device, batch_size=batch_size) - scale = torch.ones_like(coord) - for i, n in enumerate(shape): - if isinstance(range[0], list) or isinstance(range[0], tuple): - minv, maxv = range[i] - else: - minv, maxv = range - scale[..., i] *= (maxv - minv) / n - return coord, scale diff --git a/dito/utils/utils.py b/dito/utils/utils.py deleted file mode 100644 index f5f879c3cb8415249e362e766cff28a9af87daeb..0000000000000000000000000000000000000000 --- a/dito/utils/utils.py +++ /dev/null @@ -1,95 +0,0 @@ -import os -import shutil -import time -import logging - -from torch.optim import SGD, Adam, AdamW - - -def ensure_path(path, replace=True, force_replace=False): - is_temp = os.path.basename(path.rstrip('/')).startswith('_') - if os.path.exists(path): - if replace and (is_temp or force_replace or input(f'{path} exists, replace? y/[n] ') == 'y'): - shutil.rmtree(path) - os.mkdir(path) - else: - os.makedirs(path) - - -def set_logger(file_path): - logger = logging.getLogger() - logger.setLevel('INFO') - stream_handler = logging.StreamHandler() - file_handler = logging.FileHandler(file_path, 'a') - formatter = logging.Formatter('[%(asctime)s] %(message)s', '%m-%d %H:%M:%S') - for handler in [stream_handler, file_handler]: - handler.setFormatter(formatter) - handler.setLevel('INFO') - logger.addHandler(handler) - return logger - - -def compute_num_params(model, text=True): - tot = sum(p.numel() for p in model.parameters()) - if text: - if tot >= 1e6: - s = '{:.1f}M'.format(tot / 1e6) - else: - s = '{:.1f}K'.format(tot / 1e3) - return f'{s} ({tot})' - else: - return tot - - -def make_optimizer(params, optimizer_spec): - optimizer = { - 'sgd': SGD, - 'adam': Adam, - 'adamw': AdamW, - }[optimizer_spec['name']](params, **optimizer_spec['args']) - return optimizer - - -class Averager(): - - def __init__(self, v=None): - if v is None: - self.n = 0. - self.v = 0. - else: - self.n = 1. 
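# seed the running average with one already-counted observation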
- self.v = v - - def add(self, v, n=1.0): - self.v = self.v * (self.n / (self.n + n)) + v * (n / (self.n + n)) - self.n += n - - def item(self): - return self.v - - -class EpochTimer(): - - def __init__(self, max_epoch): - self.max_epoch = max_epoch - self.epoch = 0 - self.t_start = time.time() - self.t_last = self.t_start - - def epoch_done(self): - t_cur = time.time() - self.epoch += 1 - epoch_time = t_cur - self.t_last - tot_time = t_cur - self.t_start - est_time = tot_time / self.epoch * self.max_epoch - self.t_last = t_cur - return time_text(epoch_time), time_text(tot_time), time_text(est_time) - - -def time_text(sec): - if sec >= 3600: - return f'{sec / 3600:.1f}h' - elif sec >= 60: - return f'{sec / 60:.1f}m' - else: - return f'{sec:.1f}s'
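To make the mechanisms in the deleted files easy to sanity-check, a few self-contained sketches follow. First, the sinusoidal timestep embedding from dit.py's TimestepEmbedder; this is a minimal standalone version rather than an import of the deleted module:

import math
import torch

def timestep_embedding(t, dim, max_period=10000):
    # Half the channels carry cosines, half sines, with frequencies
    # spaced geometrically between 1 and 1/max_period.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = timestep_embedding(torch.tensor([0.0, 500.0, 999.0]), dim=256)
print(emb.shape)     # torch.Size([3, 256])
print(emb[0, :128])  # at t=0 every cosine channel is 1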
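The adaLN-Zero conditioning is easiest to see in isolation: because initialize_weights zeroes the modulation MLPs, every shift, scale, and gate starts at zero, so each residual branch is silent at initialization. A small sketch of that property, reusing the modulate function defined in dit.py:

import torch

def modulate(x, shift, scale):
    # x: (N, T, D) tokens; shift/scale: (N, D), broadcast over tokens.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

x = torch.randn(2, 16, 8)
zero = torch.zeros(2, 8)
# Zero-initialized shift and scale leave the input unchanged, and a zero
# gate multiplies the branch output away, so x + gate * branch == x.
assert torch.allclose(modulate(x, zero, zero), x)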
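The einsum in unpatchify is dense; a standalone round-trip sketch of the same layout shows how a square token grid folds back into an image:

import torch

def unpatchify(x, p, c):
    # x: (N, T, p*p*c) -> (N, c, h*p, w*p), assuming T is a perfect square.
    n, t, _ = x.shape
    h = w = int(t ** 0.5)
    assert h * w == t
    x = x.reshape(n, h, w, p, p, c)
    x = torch.einsum('nhwpqc->nchpwq', x)
    return x.reshape(n, c, h * p, w * p)

tokens = torch.randn(2, 256, 2 * 2 * 4)    # 16x16 grid of 2x2 patches, 4 channels
print(unpatchify(tokens, p=2, c=4).shape)  # torch.Size([2, 4, 32, 32])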
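forward_with_cfg duplicates the batch so one pass yields both conditional and unconditional predictions; the guidance arithmetic itself reduces to one line. A minimal sketch of that combination step (cfg_combine is an illustrative name, not a function from the repo):

import torch

def cfg_combine(cond_eps, uncond_eps, cfg_scale):
    # Extrapolate from the unconditional prediction toward the conditional
    # one; cfg_scale == 1 recovers the plain conditional output.
    return uncond_eps + cfg_scale * (cond_eps - uncond_eps)

cond, uncond = torch.randn(2, 4, 32, 32), torch.randn(2, 4, 32, 32)
guided = cfg_combine(cond, uncond, cfg_scale=4.0)
assert torch.allclose(cfg_combine(cond, uncond, 1.0), cond)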
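run.py's parse_config merges __base__ files with OmegaConf.merge, later configs overriding earlier ones, which is how the experiment YAMLs above only carry their overrides. A small sketch of the override semantics, using in-memory configs in place of the YAML files (the keys are illustrative):

from omegaconf import OmegaConf

base = OmegaConf.create({'trainer': 'dito', 'model': {'args': {'z_channels': 4}}})
experiment = OmegaConf.create({'model': {'args': {'z_channels': 64}}})

# Later arguments win on conflicting keys; everything else is inherited.
merged = OmegaConf.merge(base, experiment)
print(merged.model.args.z_channels)  # 64, the experiment value wins
print(merged.trainer)                # dito, inherited from the base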
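make_coord_grid places one coordinate at each cell center, and make_coord_scale_grid pairs it with the per-axis cell size. A usage sketch, assuming the dito package root is on sys.path so utils.geometry is importable:

from utils.geometry import make_coord_scale_grid  # assumes dito/ on sys.path

coord, scale = make_coord_scale_grid((4, 4), range=(-1, 1), batch_size=2)
print(coord.shape)     # torch.Size([2, 4, 4, 2])
print(coord[0, 0, 0])  # tensor([-0.7500, -0.7500]), the first cell center
print(scale[0, 0, 0])  # tensor([0.5000, 0.5000]), cell size 2/4 per axis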
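Finally, a quick check of Averager's weighted running mean, assuming utils is importable the same way:

from utils import Averager  # assumes dito/ on sys.path

avg = Averager()
avg.add(1.0)       # mean of {1}
avg.add(3.0)       # mean of {1, 3}
avg.add(5.0, n=2)  # 5.0 counted with weight 2
print(avg.item())  # 3.5 == (1 + 3 + 5 + 5) / 4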