diff --git a/dito/LICENSE b/dito/LICENSE deleted file mode 100644 index 261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64..0000000000000000000000000000000000000000 --- a/dito/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. 
Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/dito/README.md b/dito/README.md deleted file mode 100644 index 93a100c4e116f1627edbae77351708a98f451d68..0000000000000000000000000000000000000000 --- a/dito/README.md +++ /dev/null @@ -1,39 +0,0 @@ -# Diffusion Autoencoders are Scalable Image Tokenizers (DiTo) - -drawing - -## Environment - -```.bash -conda create -n dito python=3.11 -y -conda activate dito -pip install -r requirements.txt -``` - -## Experiments - -- By default, the experiment name is the config name. Experiments are saved in `save/` with corresponding names. Append `-n` to manually set a name. - -- After filling in the information in `load/wandb.yaml`, append `-w` to log to Wandb. - -- Dataset format is image folders. To set up data, fill in `root_path` for configs in `datasets/`. For example, `train` and `val` can be ImageNet training and validation set (as image folders for different classes), `eval_ae`/`eval_zdm` can be a smaller validation subset (image folders, in the paper it is 5K samples in total) that is used to evaluate FID for reconstruction or generation. - -- The commands below are for DiTo-XL, the configs can be changed accordingly (in `configs/experiments/`) for other scales or to enable noise synchronization. If the GPU memory is not sufficient, a multi-node training with `torchrun` is needed (or the batch size can be reduced in `configs/datasets/*.yaml`). In the paper, DiTo at B, L, XL were trained on 1, 2, 4 nodes with 8 A100 per node, and latent diffusion models were trained on 2 nodes. 
- -### Train diffusion tokenizers - -```.bash -torchrun --nnodes=1 --nproc-per-node=8 run.py --config configs/experiments/dito-XL-f8c4.yaml -``` - -### Train latent diffusion models - -```.bash -torchrun --nnodes=1 --nproc-per-node=8 run.py --config configs/experiments/zdm-XL_dito-XL-f8c4.yaml -``` - -### Evaluate latent diffusion with 50K samples - -```.bash -torchrun --nnodes=1 --nproc-per-node=8 run.py --config configs/experiments/eval50k_zdm-XL_dito-XL-f8c4.yaml --eval-only -``` \ No newline at end of file diff --git a/flowae/.gitignore b/flowae/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ed8ebf583f771da9150c35db3955987b7d757904 --- /dev/null +++ b/flowae/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file diff --git a/flowae/configs/datasets/dae.yaml b/flowae/configs/datasets/dae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8fd9018ae2636223cb3c70f192dd58ae397c1c56 --- /dev/null +++ b/flowae/configs/datasets/dae.yaml @@ -0,0 +1,79 @@ +# Datasets +datasets: + train: + name: wrapper_audio_cae + args: + dataset: + name: audio_dataset_from_folders + args: + folders: + Emilia_EN: ["/home/masuser/minimax-audio/dataset/Emilia/EN"] + sample_rate: 24000 + duration: 0.38 + n_examples: 10000000 + shuffle: true + mono: true + sample_rate: 24000 + duration: 0.38 + mono: true + normalize: true + return_coords: true + loader: + batch_size: 64 + num_workers: 8 + drop_last: true + + val: + name: wrapper_audio_cae + args: + dataset: + name: audio_dataset_from_folders + args: + folders: + Emilia_EN: ["/home/masuser/minimax-audio/dataset/libritts"] + sample_rate: 24000 + duration: 5.0 + n_examples: 100 + shuffle: false + mono: true + sample_rate: 24000 + duration: 5.0 + mono: true + normalize: true + return_coords: true + loader: + batch_size: 4 + num_workers: 8 + drop_last: false + + eval_ae: + name: wrapper_audio_cae + args: + dataset: + name: audio_dataset_from_folders + args: + folders: + Emilia_EN: ["/home/masuser/minimax-audio/dataset/libritts"] + sample_rate: 24000 + duration: 10.0 + n_examples: 1000 + shuffle: false + mono: true + sample_rate: 24000 + duration: 10.0 + mono: true + normalize: true + return_coords: true + loader: + batch_size: 1 + num_workers: 8 + drop_last: false + +# Visualization +visualize_ae_dir: /mnt/nvme/dito_audio +visualize_ae_random_n_samples: 32 +eval_ae_max_samples: 100 +val_idx: [0, 1, 2, 3, 4, 5, 6, 7] + +# Enable autoencoder evaluation +evaluate_ae: true \ No newline at end of file diff --git a/flowae/configs/datasets/imagenet_ae.yaml b/flowae/configs/datasets/imagenet_ae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5109901f168de8d1f9b30a301fb087e38fcd440e --- /dev/null +++ b/flowae/configs/datasets/imagenet_ae.yaml @@ -0,0 +1,47 @@ +datasets: + train: + name: wrapper_cae + args: + dataset: + name: class_folder + args: {root_path: /home/masuser/minimax-audio/mnist_png/training, resize: 256, rand_crop: 256, rand_flip: true, image_only: true} + resize_inp: 256 + gt_glores_lb: 256 + gt_glores_ub: 256 + gt_patch_size: 256 + loader: + batch_size: 14 + num_workers: 24 + + val: + name: wrapper_cae + args: + dataset: + name: class_folder + args: {root_path: /home/masuser/minimax-audio/mnist_png/testing, resize: 256, square_crop: true, image_only: true} + resize_inp: 256 + gt_glores_lb: 256 + gt_glores_ub: 256 + gt_patch_size: 256 + loader: + batch_size: 14 + num_workers: 24 + + eval_ae: + name: wrapper_cae + args: + dataset: + name: class_folder + args: {root_path: 
/home/masuser/minimax-audio/mnist_png/testing, resize: 256, square_crop: true, image_only: true} + resize_inp: 256 + gt_glores_lb: 256 + gt_glores_ub: 256 + gt_patch_size: 256 + loader: + batch_size: 14 + num_workers: 24 + drop_last: false + +visualize_ae_dir: /mnt/nvme/dito +visualize_ae_random_n_samples: 32 +eval_ae_max_samples: 5000 \ No newline at end of file diff --git a/flowae/configs/datasets/imagenet_zdm.yaml b/flowae/configs/datasets/imagenet_zdm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2a54041e020d8ede4c9e3e8647c6c05c5339f0af --- /dev/null +++ b/flowae/configs/datasets/imagenet_zdm.yaml @@ -0,0 +1,53 @@ +datasets: + train: + name: wrapper_cae + args: + dataset: + name: class_folder + args: {root_path: /home/masuser/minimax-audio/mnist_png/training, resize: 256, square_crop: true, rand_flip: true, drop_label_p: 0.1} + resize_inp: 256 + gt_glores_lb: 256 + gt_glores_ub: 256 + gt_patch_size: 256 + loader: + batch_size: 64 + num_workers: 24 + + val: + name: wrapper_cae + args: + dataset: + name: class_folder + args: {root_path: /home/masuser/minimax-audio/mnist_png/testing, resize: 256, square_crop: true} + resize_inp: 256 + gt_glores_lb: 256 + gt_glores_ub: 256 + gt_patch_size: 256 + loader: + batch_size: 64 + num_workers: 24 + + eval_zdm: + name: wrapper_cae + args: + dataset: + name: class_folder + args: {root_path: /home/masuser/minimax-audio/mnist_png/testing, resize: 256, square_crop: true} + resize_inp: 256 + gt_glores_lb: 256 + gt_glores_ub: 256 + gt_patch_size: 256 + loader: + batch_size: 64 + num_workers: 24 + drop_last: false + +visualize_zdm_file: null +visualize_zdm_setting: + name: class + n_classes: 1000 +visualize_zdm_random_n_samples: 12 +visualize_zdm_batch_size: 6 +visualize_zdm_guidance_list: [4] +visualize_zdm_denoising_file: null +eval_zdm_max_samples: 5000 \ No newline at end of file diff --git a/flowae/configs/experiments/dito-B-audio.yaml b/flowae/configs/experiments/dito-B-audio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..179a84a947164ba7a687fdbddde627970989fb4b --- /dev/null +++ b/flowae/configs/experiments/dito-B-audio.yaml @@ -0,0 +1,44 @@ +__base__: + - configs/datasets/dae.yaml + - configs/trainers/dito.yaml + +model: + name: dito_audio + args: + # Encoder + encoder: + name: dac_encoder + args: {config_name: snakebeta} + + # Latent configuration - now fully convolutional + z_channels: 64 # Number of latent channels + z_downsample_factor: 320 # Product of encoder_rates: 2*4*5*8 + z_layernorm: true + + # Decoder (identity for DiTo) + decoder: + name: identity + + # Renderer - Fully convolutional for dynamic duration + renderer: + name: audio_renderer_wrapper + args: + net: + name: consistency_decoder_unet # Fully Convolutional Network + args: + in_channels: 1 + z_dec_channels: 64 + c0: 128 + c1: 256 + c2: 512 + pe_dim: 320 + t_dim: 1280 + + # Diffusion configuration + render_diffusion: + name: fm + args: {timescale: 1000.0} + + render_sampler: {name: fm_euler_sampler} + render_n_steps: 50 + diff --git a/flowae/configs/experiments/dito-B-f8c4-noise-sync.yaml b/flowae/configs/experiments/dito-B-f8c4-noise-sync.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5eb1df5b78efb87d3aea6d9850924a8b85cc8da1 --- /dev/null +++ b/flowae/configs/experiments/dito-B-f8c4-noise-sync.yaml @@ -0,0 +1,43 @@ +__base__: + - configs/datasets/imagenet_ae.yaml + - configs/trainers/dito.yaml + +model: + name: dito + args: + encoder: + name: vqgan_encoder + args: {config_name: f8c4} + + z_shape: 
[64, 1, 1] + z_layernorm: true + + zaug_p: 0.1 + zaug_decoding_loss_type: suffix + zaug_zdm_diffusion: + name: fm + args: {timescale: 1000.0} + + decoder: {name: identity} + + renderer: + name: fixres_renderer_wrapper + args: + net: + name: consistency_decoder_unet + args: + in_channels: 3 + z_dec_channels: 64 + c0: 128 + c1: 256 + c2: 512 + pe_dim: 320 + t_dim: 1280 + + render_diffusion: + name: fm + args: {timescale: 1000.0} + render_sampler: {name: fm_euler_sampler} + render_n_steps: 50 + + loss_config: {} diff --git a/flowae/configs/experiments/dito-B-f8c4.yaml b/flowae/configs/experiments/dito-B-f8c4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..feeab68b5dace98a9701bbc16a27f56d81e0dfc7 --- /dev/null +++ b/flowae/configs/experiments/dito-B-f8c4.yaml @@ -0,0 +1,37 @@ +__base__: + - configs/datasets/imagenet_ae.yaml + - configs/trainers/dito.yaml + +model: + name: dito + args: + encoder: + name: vqgan_encoder + args: {config_name: f8c4} + + z_shape: [4, 32, 32] + z_layernorm: true + + decoder: {name: identity} + + renderer: + name: fixres_renderer_wrapper + args: + net: + name: consistency_decoder_unet + args: + in_channels: 3 + z_dec_channels: 4 + c0: 128 + c1: 256 + c2: 512 + pe_dim: 320 + t_dim: 1280 + + render_diffusion: + name: fm + args: {timescale: 1000.0} + render_sampler: {name: fm_euler_sampler} + render_n_steps: 50 + + loss_config: {} diff --git a/flowae/configs/experiments/dito-L-f8c4.yaml b/flowae/configs/experiments/dito-L-f8c4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f242bf6f27755c559e439eba78ae03809f07898 --- /dev/null +++ b/flowae/configs/experiments/dito-L-f8c4.yaml @@ -0,0 +1,37 @@ +__base__: + - configs/datasets/imagenet_ae.yaml + - configs/trainers/dito.yaml + +model: + name: dito + args: + encoder: + name: vqgan_encoder + args: {config_name: f8c4} + + z_shape: [4, 32, 32] + z_layernorm: true + + decoder: {name: identity} + + renderer: + name: fixres_renderer_wrapper + args: + net: + name: consistency_decoder_unet + args: + in_channels: 3 + z_dec_channels: 4 + c0: 192 + c1: 384 + c2: 768 + pe_dim: 320 + t_dim: 1280 + + render_diffusion: + name: fm + args: {timescale: 1000.0} + render_sampler: {name: fm_euler_sampler} + render_n_steps: 50 + + loss_config: {} diff --git a/flowae/configs/experiments/dito-XL-f8c4-noise-sync.yaml b/flowae/configs/experiments/dito-XL-f8c4-noise-sync.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ec7761d6570e5456b8099dd6a1eb28754de8c953 --- /dev/null +++ b/flowae/configs/experiments/dito-XL-f8c4-noise-sync.yaml @@ -0,0 +1,43 @@ +__base__: + - configs/datasets/imagenet_ae.yaml + - configs/trainers/dito.yaml + +model: + name: dito + args: + encoder: + name: vqgan_encoder + args: {config_name: f8c4} + + z_shape: [4, 32, 32] + z_layernorm: true + + zaug_p: 0.1 + zaug_decoding_loss_type: suffix + zaug_zdm_diffusion: + name: fm + args: {timescale: 1000.0} + + decoder: {name: identity} + + renderer: + name: fixres_renderer_wrapper + args: + net: + name: consistency_decoder_unet + args: + in_channels: 3 + z_dec_channels: 4 + c0: 320 + c1: 640 + c2: 1024 + pe_dim: 320 + t_dim: 1280 + + render_diffusion: + name: fm + args: {timescale: 1000.0} + render_sampler: {name: fm_euler_sampler} + render_n_steps: 50 + + loss_config: {} diff --git a/flowae/configs/experiments/dito-XL-f8c4.yaml b/flowae/configs/experiments/dito-XL-f8c4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8610d23c1d1479538b1dd5d568d427f08ab7e9ec --- /dev/null +++ 
b/flowae/configs/experiments/dito-XL-f8c4.yaml @@ -0,0 +1,37 @@ +__base__: + - configs/datasets/imagenet_ae.yaml + - configs/trainers/dito.yaml + +model: + name: dito + args: + encoder: + name: vqgan_encoder + args: {config_name: f8c4} + + z_shape: [4, 32, 32] + z_layernorm: true + + decoder: {name: identity} + + renderer: + name: fixres_renderer_wrapper + args: + net: + name: consistency_decoder_unet + args: + in_channels: 3 + z_dec_channels: 4 + c0: 320 + c1: 640 + c2: 1024 + pe_dim: 320 + t_dim: 1280 + + render_diffusion: + name: fm + args: {timescale: 1000.0} + render_sampler: {name: fm_euler_sampler} + render_n_steps: 50 + + loss_config: {} diff --git a/flowae/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4-noise-sync.yaml b/flowae/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4-noise-sync.yaml new file mode 100644 index 0000000000000000000000000000000000000000..679cc9fa97413b19ca1d1581c0a008c4107b7bb9 --- /dev/null +++ b/flowae/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4-noise-sync.yaml @@ -0,0 +1,44 @@ +__base__: + - configs/datasets/imagenet_zdm.yaml + - configs/models/zdm-XL_imagenet.yaml + - configs/trainers/zdm.yaml + +eval_zdm_max_samples: 50000 + +model: + load_ckpt: save/zdm-XL_dito-XL-f8c4-noise-sync/ckpt-last.pth + name: dito + args: + zdm_force_guidance: 2.0 + renderer_ema_rate: 1 + + encoder: + name: vqgan_encoder + args: {config_name: f8c4} + + z_shape: [4, 32, 32] + z_layernorm: true + + decoder: {name: identity} + + renderer: + name: fixres_renderer_wrapper + args: + net: + name: consistency_decoder_unet + args: + in_channels: 3 + z_dec_channels: 4 + c0: 320 + c1: 640 + c2: 1024 + pe_dim: 320 + t_dim: 1280 + + render_diffusion: + name: fm + args: {timescale: 1000.0} + render_sampler: {name: fm_euler_sampler} + render_n_steps: 50 + + loss_config: {} diff --git a/flowae/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4.yaml b/flowae/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f5d360f8536c7652bf77cdeeaeff2dc33885136e --- /dev/null +++ b/flowae/configs/experiments/eval50k_zdm-XL_dito-XL-f8c4.yaml @@ -0,0 +1,44 @@ +__base__: + - configs/datasets/imagenet_zdm.yaml + - configs/models/zdm-XL_imagenet.yaml + - configs/trainers/zdm.yaml + +eval_zdm_max_samples: 50000 + +model: + load_ckpt: save/zdm-XL_dito-XL-f8c4/ckpt-last.pth + name: dito + args: + zdm_force_guidance: 2.0 + renderer_ema_rate: 1 + + encoder: + name: vqgan_encoder + args: {config_name: f8c4} + + z_shape: [4, 32, 32] + z_layernorm: true + + decoder: {name: identity} + + renderer: + name: fixres_renderer_wrapper + args: + net: + name: consistency_decoder_unet + args: + in_channels: 3 + z_dec_channels: 4 + c0: 320 + c1: 640 + c2: 1024 + pe_dim: 320 + t_dim: 1280 + + render_diffusion: + name: fm + args: {timescale: 1000.0} + render_sampler: {name: fm_euler_sampler} + render_n_steps: 50 + + loss_config: {} diff --git a/flowae/configs/experiments/zdm-XL_dito-XL-f8c4-noise-sync.yaml b/flowae/configs/experiments/zdm-XL_dito-XL-f8c4-noise-sync.yaml new file mode 100644 index 0000000000000000000000000000000000000000..283a17ad9a677a47a152a5e2b5a44a3905183dae --- /dev/null +++ b/flowae/configs/experiments/zdm-XL_dito-XL-f8c4-noise-sync.yaml @@ -0,0 +1,41 @@ +__base__: + - configs/datasets/imagenet_zdm.yaml + - configs/models/zdm-XL_imagenet.yaml + - configs/trainers/zdm.yaml + +model: + load_ckpt: save/dito-XL-f8c4-noise-sync/ckpt-last.pth + name: dito + args: + renderer_ema_rate: 1 + + encoder: + name: vqgan_encoder + args: {config_name: f8c4} 
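+    # NOTE (editor's comment; assumption for clarity): "f8c4" is read here as the
+    # VQGAN encoder with spatial downsampling factor 8 and 4 latent channels, which
+    # matches z_shape [4, 32, 32] below for 256x256 inputs (256 / 8 = 32).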
+ + z_shape: [4, 32, 32] + z_layernorm: true + + decoder: {name: identity} + + renderer: + name: fixres_renderer_wrapper + args: + net: + name: consistency_decoder_unet + args: + in_channels: 3 + z_dec_channels: 4 + c0: 320 + c1: 640 + c2: 1024 + pe_dim: 320 + t_dim: 1280 + + render_diffusion: + name: fm + args: {timescale: 1000.0} + render_sampler: {name: fm_euler_sampler} + render_n_steps: 50 + + loss_config: {} diff --git a/flowae/configs/experiments/zdm-XL_dito-XL-f8c4.yaml b/flowae/configs/experiments/zdm-XL_dito-XL-f8c4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eeee58d880334bb1806ec51b8772f72bd1f34a75 --- /dev/null +++ b/flowae/configs/experiments/zdm-XL_dito-XL-f8c4.yaml @@ -0,0 +1,41 @@ +__base__: + - configs/datasets/imagenet_zdm.yaml + - configs/models/zdm-XL_imagenet.yaml + - configs/trainers/zdm.yaml + +model: + load_ckpt: + name: dito + args: + renderer_ema_rate: 1 + + encoder: + name: vqgan_encoder + args: {config_name: f8c4} + + z_shape: [4, 32, 32] + z_layernorm: true + + decoder: {name: identity} + + renderer: + name: fixres_renderer_wrapper + args: + net: + name: consistency_decoder_unet + args: + in_channels: 3 + z_dec_channels: 4 + c0: 320 + c1: 640 + c2: 1024 + pe_dim: 320 + t_dim: 1280 + + render_diffusion: + name: fm + args: {timescale: 1000.0} + render_sampler: {name: fm_euler_sampler} + render_n_steps: 50 + + loss_config: {} diff --git a/flowae/configs/models/zdm-XL_imagenet.yaml b/flowae/configs/models/zdm-XL_imagenet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..27cb8116834b7f90c09c612f39507f9acf541d4e --- /dev/null +++ b/flowae/configs/models/zdm-XL_imagenet.yaml @@ -0,0 +1,12 @@ +model: + args: + zdm_net: + name: dit_xl_2 + args: {n_classes: 1001} + zdm_diffusion: + name: fm + args: {timescale: 1000.0} + zdm_sampler: {name: fm_euler_sampler} + zdm_n_steps: 200 + zdm_train_normalize: false + zdm_class_cond: 1000 \ No newline at end of file diff --git a/flowae/datasets/__init__.py b/flowae/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52749f4dfe1a9a36afab67115bdce2f243851bd6 --- /dev/null +++ b/flowae/datasets/__init__.py @@ -0,0 +1,3 @@ +from .datasets import register, make +from . import image_folder, class_folder, webdataset +from . 
import wrapper_cae
diff --git a/flowae/datasets/class_folder.py b/flowae/datasets/class_folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..91070b4545e1752c64c011d2f2c07b7e3bbb8b3f
--- /dev/null
+++ b/flowae/datasets/class_folder.py
@@ -0,0 +1,87 @@
+import os
+import random
+from PIL import Image, ImageFile
+
+from datasets import register
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+
+Image.MAX_IMAGE_PIXELS = 933120000
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+IMAGE_EXTS = ('.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.webp')
+
+
+@register('class_folder')
+class ClassFolder(Dataset):
+
+    def __init__(self, root_path, resize=None, square_crop=False, rand_crop=None, rand_flip=False, drop_label_p=0.0, image_only=False):
+        folders = []
+        for folder in sorted(os.listdir(root_path)):
+            if os.path.isdir(os.path.join(root_path, folder)):
+                folders.append(os.path.join(root_path, folder))
+        self.files = []
+        self.labels = []
+        for i, folder in enumerate(folders):
+            # folder is already a full path; joining root_path again would break for relative root paths
+            for file in sorted(os.listdir(folder)):
+                if file.endswith(IMAGE_EXTS):
+                    self.files.append(os.path.join(folder, file))
+                    self.labels.append(i)
+
+        self.resize = resize
+        self.square_crop = square_crop
+        self.rand_crop = rand_crop
+        self.rand_flip = transforms.RandomHorizontalFlip() if rand_flip else None
+
+        self.n_classes = len(folders)
+        self.drop_label_p = drop_label_p
+
+        self.image_only = image_only
+
+    def __len__(self):
+        return len(self.files)
+
+    def __getitem__(self, idx):
+        try:
+            image = Image.open(self.files[idx]).convert('RGB')
+            label = self.labels[idx]
+        except Exception:
+            print('Error loading image:', self.files[idx])
+            return self.__getitem__((idx + 1) % self.__len__())
+
+        if self.resize is not None:
+            r = self.resize
+            if isinstance(r, int):
+                w, h = image.size
+                if w < h:
+                    r = (r, int(h / w * r))
+                else:
+                    r = (int(w / h * r), r)
+            image = image.resize(r, Image.LANCZOS)
+
+        if self.square_crop:
+            w, h = image.size
+            l = min(w, h)
+            left, upper = (w - l) // 2, (h - l) // 2
+            image = image.crop((left, upper, left + l, upper + l))
+
+        if self.rand_crop is not None:
+            w, h = image.size
+            left = random.randint(0, w - self.rand_crop)
+            upper = random.randint(0, h - self.rand_crop)
+            image = image.crop((left, upper, left + self.rand_crop, upper + self.rand_crop))
+
+        if self.rand_flip is not None:
+            image = self.rand_flip(image)
+
+        if self.drop_label_p > 0.0 and random.random() < self.drop_label_p:
+            label = self.n_classes
+
+        if self.image_only:
+            return image
+        else:
+            return {
+                'image': image,
+                'class_labels': label,
+            }
diff --git a/flowae/datasets/datasets.py b/flowae/datasets/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f102233b144c4c04c35204fd97b052236502661
--- /dev/null
+++ b/flowae/datasets/datasets.py
@@ -0,0 +1,16 @@
+datasets = dict()
+
+
+def register(name):
+    def decorator(cls):
+        datasets[name] = cls
+        return cls
+    return decorator
+
+
+def make(spec):
+    args = spec.get('args')
+    if args is None:
+        args = dict()
+    dataset = datasets[spec['name']](**args)
+    return dataset
diff --git a/flowae/datasets/image_folder.py b/flowae/datasets/image_folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7564ef3fa31ea6f6aadd26a2d37b77c213e00095
--- /dev/null
+++ b/flowae/datasets/image_folder.py
@@ -0,0 +1,62 @@
+import os
+import random
+from PIL import 
Image, ImageFile + +from datasets import register +from torch.utils.data import Dataset +from torchvision import transforms + + +Image.MAX_IMAGE_PIXELS = 933120000 +ImageFile.LOAD_TRUNCATED_IMAGES = True +IMAGE_EXTS = ('.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.webp') + + +@register('image_folder') +class ImageFolder(Dataset): + + def __init__(self, root_path, resize=None, square_crop=False, rand_crop=None, rand_flip=False): + files = sorted(os.listdir(root_path)) + self.files = [os.path.join(root_path, _) for _ in files if _.endswith(IMAGE_EXTS)] + + self.resize = resize + self.square_crop = square_crop + self.rand_crop = rand_crop + self.rand_flip = transforms.RandomHorizontalFlip() if rand_flip else None + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + try: + image = Image.open(self.files[idx]).convert('RGB') + except: + print('Error loading image:', self.files[idx]) + return self.__getitem__((idx + 1) % self.__len__()) + + if self.resize is not None: + r = self.resize + if isinstance(r, int): + w, h = image.size + if w < h: + r = (r, int(h / w * r)) + else: + r = (int(w / h * r), r) + image = image.resize(r, Image.LANCZOS) + + if self.square_crop: + w, h = image.size + l = min(w, h) + left, upper = (w - l) // 2, (h - l) // 2 + image = image.crop((left, upper, left + l, upper + l)) + + if self.rand_crop is not None: + w, h = image.size + left = random.randint(0, w - self.rand_crop) + upper = random.randint(0, h - self.rand_crop) + image = image.crop((left, upper, left + self.rand_crop, upper + self.rand_crop)) + + if self.rand_flip is not None: + image = self.rand_flip(image) + + return image diff --git a/flowae/datasets/webdataset.py b/flowae/datasets/webdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7772713a688566eb8ac6fc7c38efcaf1c2993d66 --- /dev/null +++ b/flowae/datasets/webdataset.py @@ -0,0 +1,45 @@ +import json + +import webdataset as wds +from webdataset.handlers import warn_and_continue + +from datasets import register + + +def webdataset_preprocessors(square_crop=True): + def identity(x): + if isinstance(x, bytes): + x = x.decode('utf-8') + return x + + def transform(image): + w, h = image.size + l = min(w, h) + left, upper = (w - l) // 2, (h - l) // 2 + return image.crop((left, upper, left + l, upper + l)) + + ret = [ + ('jpg;png', transform if square_crop else lambda x: x, 'image'), + ('txt', identity, 'caption'), + ] + + return ret + + +@register('webdataset') +def make_webdataset(json_file, **kwargs): + with open(json_file, 'r') as file: + tar_list = json.load(file) + preprocessors = webdataset_preprocessors(**kwargs) + handler = warn_and_continue + dataset = wds.WebDataset( + tar_list, resampled=True, handler=handler + ).shuffle(690, handler=handler).decode( + "pilrgb", handler=handler + ).to_tuple( + *[p[0] for p in preprocessors], handler=handler + ).map_tuple( + *[p[1] for p in preprocessors], handler=handler + ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)}) + + return dataset diff --git a/flowae/datasets/wrapper_cae.py b/flowae/datasets/wrapper_cae.py new file mode 100644 index 0000000000000000000000000000000000000000..1a684e1b959287c6647ad49399d7f5458f0818d0 --- /dev/null +++ b/flowae/datasets/wrapper_cae.py @@ -0,0 +1,308 @@ +import random +from PIL import Image + +import torch +from torch.utils.data import Dataset, IterableDataset +from torchvision import transforms + +import datasets +from datasets import register +from utils.geometry import make_coord_scale_grid + + 
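+# NOTE (editor's comment): the imports below come from the audiotools package vendored
+# under models/ldm/dac; AudioSignal carries waveforms as (batch, channels, samples)
+# tensors, and AudioDataset/AudioLoader back the audio wrappers defined further down.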
+from models.ldm.dac.audiotools import AudioSignal
+import numpy as np
+
+from models.ldm.dac.audiotools.data.datasets import AudioDataset, AudioLoader
+from models.ldm.dac.audiotools import transforms as tfm
+
+
+class BaseWrapperCAE:
+
+    def __init__(
+        self,
+        dataset,
+        resize_inp,
+        return_gt=True,
+        gt_glores_lb=None,
+        gt_glores_ub=None,
+        gt_patch_size=None,
+        p_whole=0.0,
+        p_max=0.0
+    ):
+        self.dataset = datasets.make(dataset)
+        self.resize_inp = resize_inp
+        self.return_gt = return_gt
+        self.gt_glores_lb = gt_glores_lb
+        self.gt_glores_ub = gt_glores_ub
+        self.gt_patch_size = gt_patch_size
+        self.p_whole = p_whole
+        self.p_max = p_max
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(0.5, 0.5),
+        ])
+
+    def process(self, image):
+        assert image.size[0] == image.size[1]
+        ret = {}
+
+        inp = image.resize((self.resize_inp, self.resize_inp), Image.LANCZOS)
+        inp = self.transform(inp)
+        ret.update({'inp': inp})
+        if not self.return_gt:
+            return ret
+
+        if self.gt_glores_lb is None:
+            glo = self.transform(image)
+        else:
+            if random.random() < self.p_whole:
+                r = self.gt_patch_size
+            elif random.random() < self.p_max:
+                r = min(image.size[0], self.gt_glores_ub)
+            else:
+                r = random.randint(
+                    self.gt_glores_lb,
+                    max(self.gt_glores_lb, min(image.size[0], self.gt_glores_ub))
+                )
+            glo = image.resize((r, r), Image.LANCZOS)
+            glo = self.transform(glo)
+
+        p = self.gt_patch_size
+        ii = random.randint(0, glo.shape[1] - p)
+        jj = random.randint(0, glo.shape[2] - p)
+        gt_patch = glo[:, ii: ii + p, jj: jj + p]
+
+        x0, y0 = ii / glo.shape[-2], jj / glo.shape[-1]
+        x1, y1 = (ii + p) / glo.shape[-2], (jj + p) / glo.shape[-1]
+        coord, scale = make_coord_scale_grid((p, p), range=[[x0, x1], [y0, y1]])
+        ret['gt'] = torch.cat([
+            gt_patch,  # 3 p p
+            coord.permute(2, 0, 1),  # 2 p p
+            scale.permute(2, 0, 1),  # 2 p p
+        ], dim=0)
+
+        return ret
+
+
+@register('wrapper_cae')
+class WrapperCAE(BaseWrapperCAE, Dataset):
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, idx):
+        data = self.dataset[idx]
+        if isinstance(data, dict):
+            ret = dict()
+            ret.update(self.process(data.pop('image')))
+            ret.update(data)
+            return ret
+        else:
+            return self.process(data)
+
+
+@register('wrapper_cae_iterable')
+class WrapperCAEIterable(BaseWrapperCAE, IterableDataset):
+
+    def __iter__(self):
+        for data in self.dataset:
+            if isinstance(data, dict):
+                ret = dict()
+                ret.update(self.process(data.pop('image')))
+                ret.update(data)
+                yield ret
+            else:
+                yield self.process(data)
+
+
+
+
+
+
+class BaseWrapperAudioCAE:
+    """Base wrapper for audio Convolutional Autoencoder (CAE) training.
+
+    Similar to the image wrapper, but for audio data.
+    """
+
+    def __init__(
+        self,
+        dataset,
+        sample_rate=24000,
+        duration=0.38,  # Duration in seconds
+        n_samples=None,  # Alternative: specify exact number of samples
+        return_gt=True,
+        gt_sample_rate=None,  # Ground truth sample rate (if different)
+        mono=True,
+        normalize=True,
+        return_coords=True,  # Whether to return coordinate grids
+    ):
+        self.dataset = datasets.make(dataset) if isinstance(dataset, dict) else dataset  # accept a spec (YAML path) or an instance
+        self.sample_rate = sample_rate
+        self.duration = duration
+        self.n_samples = n_samples or int(duration * sample_rate)
+        self.return_gt = return_gt
+        self.gt_sample_rate = gt_sample_rate or sample_rate
+        self.mono = mono
+        self.normalize = normalize
+        self.return_coords = return_coords
+
+    def process(self, audio_data):
+        """Process audio data for DiTo training.
+ + Args: + audio_data: Dictionary with 'signal' key containing AudioSignal + or AudioSignal directly + """ + ret = {} + + # Extract AudioSignal + if isinstance(audio_data, dict): + signal = audio_data['signal'] + else: + signal = audio_data + + # Convert to mono if needed + if self.mono and signal.num_channels > 1: + signal = signal.to_mono() + + # Resample to target sample rate + if signal.sample_rate != self.sample_rate: + signal = signal.resample(self.sample_rate) + + # Extract fixed duration + if signal.duration < self.duration: + # Pad if too short + signal = signal.zero_pad_to(self.n_samples) + else: + # Take random excerpt if too long + max_start = signal.num_samples - self.n_samples + if max_start > 0: + start_idx = random.randint(0, max_start) + signal = signal[..., start_idx:start_idx + self.n_samples] + else: + signal = signal[..., :self.n_samples] + + # Normalize audio + audio_tensor = signal.audio_data # Shape: [channels, samples] + if self.normalize: + # Normalize to [-1, 1] + max_val = audio_tensor.abs().max() + if max_val > 0: + audio_tensor = audio_tensor / max_val + + # Create input tensor + ret['inp'] = audio_tensor + + if not self.return_gt: + return ret + + + ret['gt'] = audio_tensor + + return ret + + +@register('wrapper_audio_cae') +class WrapperAudioCAE(BaseWrapperAudioCAE, Dataset): + """Dataset wrapper for audio CAE training.""" + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + data = self.dataset[idx] + return self.process(data) + + +@register('wrapper_audio_cae_iterable') +class WrapperAudioCAEIterable(BaseWrapperAudioCAE, IterableDataset): + """Iterable dataset wrapper for audio CAE training.""" + + def __iter__(self): + for data in self.dataset: + yield self.process(data) + + +# Example usage with your existing AudioDataset +def create_dito_audio_dataset(config): + """Create DiTo audio dataset from config.""" + + # Create base audio dataset using audiotools + + # Setup audio loaders + train_folders = config.get("train_folders", {}) + + loader = AudioLoader( + sources=list(train_folders.values()), + transform=tfm.Compose( + tfm.VolumeNorm(("uniform", -20, -10)), + tfm.RescaleAudio(), + ), + ext=['.wav', '.flac', '.mp3'], + ) + + # Create base dataset + base_dataset = AudioDataset( + loaders=loader, + sample_rate=config['sample_rate'], + duration=config['duration'], + n_examples=config['n_examples'], + num_channels=1 if config.get('mono', True) else 2, + ) + + # Wrap with DiTo wrapper + dito_dataset = WrapperAudioCAE( + dataset=base_dataset, + sample_rate=config['sample_rate'], + duration=config['duration'], + mono=config.get('mono', True), + normalize=True, + return_coords=True, + ) + + return dito_dataset + + +# For your training config, you would use it like: +""" +datasets: + train: + name: wrapper_audio_cae + args: + dataset: + name: audio_dataset # Your base audio dataset + args: + sources: ["/path/to/audio/files"] + sample_rate: 44100 + duration: 2.0 + n_examples: 10000 + sample_rate: 44100 + duration: 2.0 + mono: true + normalize: true + return_coords: true + loader: + batch_size: 16 + num_workers: 8 + + val: + name: wrapper_audio_cae + args: + dataset: + name: audio_dataset + args: + sources: ["/path/to/val/audio/files"] + sample_rate: 44100 + duration: 2.0 + n_examples: 1000 + sample_rate: 44100 + duration: 2.0 + mono: true + normalize: true + return_coords: true + loader: + batch_size: 16 + num_workers: 8 +""" \ No newline at end of file diff --git a/flowae/load/dito.png b/flowae/load/dito.png new file mode 100644 
index 0000000000000000000000000000000000000000..bbc7f7f11845d568c7c7e4d3439299c87073f4cb
--- /dev/null
+++ b/flowae/load/dito.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25999470ed7eaba155f1e8ab639fd43fe35eb53a53e9770829451b3d0434c467
+size 245154
diff --git a/flowae/load/vgg_lpips.pth b/flowae/load/vgg_lpips.pth
new file mode 100644
index 0000000000000000000000000000000000000000..47e943cfacabf7040b4af8cf4084ab91177f1b88
Binary files /dev/null and b/flowae/load/vgg_lpips.pth differ
diff --git a/flowae/load/wandb.yaml b/flowae/load/wandb.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..792009be6e6f698d184cd8176c463da38b9f9075
--- /dev/null
+++ b/flowae/load/wandb.yaml
@@ -0,0 +1,3 @@
+entity:
+api_key:
+project:
\ No newline at end of file
diff --git a/flowae/models/__init__.py b/flowae/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5a61e9db76e11bad1e378d95caa94e06341e0ee
--- /dev/null
+++ b/flowae/models/__init__.py
@@ -0,0 +1,4 @@
+from .models import register, make
+from . import ldm
+from . import diffusion
+from . import networks
\ No newline at end of file
diff --git a/flowae/models/diffusion/__init__.py b/flowae/models/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6578ad49c27b9797ee3e62c0b5d952089e253ed9
--- /dev/null
+++ b/flowae/models/diffusion/__init__.py
@@ -0,0 +1,2 @@
+from . import fm
+from . import samplers
diff --git a/flowae/models/diffusion/fm.py b/flowae/models/diffusion/fm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c008882977b3e240ee4ecfaf07802e565997fd0
--- /dev/null
+++ b/flowae/models/diffusion/fm.py
@@ -0,0 +1,84 @@
+import torch
+
+from models import register
+
+
+@register('fm')
+class FM:
+
+    def __init__(self, sigma_min=1e-5, timescale=1.0):
+        self.sigma_min = sigma_min
+        self.prediction_type = None
+        self.timescale = timescale
+
+    def alpha(self, t):
+        return 1.0 - t
+
+    def sigma(self, t):
+        return self.sigma_min + t * (1.0 - self.sigma_min)
+
+    def A(self, t):
+        return 1.0
+
+    def B(self, t):
+        return -(1.0 - self.sigma_min)
+
+    def get_betas(self, n_timesteps):
+        return torch.zeros(n_timesteps)  # Not VP and not supported
+
+    def add_noise(self, x, t, noise=None):
+        noise = torch.randn_like(x) if noise is None else noise
+        s = [x.shape[0]] + [1] * (x.dim() - 1)
+        x_t = self.alpha(t).view(*s) * x + self.sigma(t).view(*s) * noise
+        return x_t, noise
+
+    def loss(self, net, x, t=None, net_kwargs=None, return_loss_unreduced=False, return_all=False):
+        if net_kwargs is None:
+            net_kwargs = {}
+
+        if t is None:
+            t = torch.rand(x.shape[0], device=x.device)
+        x_t, noise = self.add_noise(x, t)
+        pred = net(x_t, t=t * self.timescale, **net_kwargs)
+
+        target = self.A(t) * x + self.B(t) * noise  # -dxt/dt
+        if return_loss_unreduced:
+            # reduce over all non-batch dims so this also works for non-image (e.g. audio) tensors
+            loss = ((pred.float() - target.float()) ** 2).mean(dim=list(range(1, pred.dim())))
+            if return_all:
+                return loss, t, x_t, pred
+            else:
+                return loss, t
+        else:
+            loss = ((pred.float() - target.float()) ** 2).mean()
+            if return_all:
+                return loss, x_t, pred
+            else:
+                return loss
+
+    def get_prediction(
+        self,
+        net,
+        x_t,
+        t,
+        net_kwargs=None,
+        uncond_net_kwargs=None,
+        guidance=1.0,
+    ):
+        if net_kwargs is None:
+            net_kwargs = {}
+        pred = 
net(x_t, t=t * self.timescale, **net_kwargs) + if guidance != 1.0: + assert uncond_net_kwargs is not None + uncond_pred = net(x_t, t=t * self.timescale, **uncond_net_kwargs) + pred = uncond_pred + guidance * (pred - uncond_pred) + return pred + + def convert_sample_prediction(self, x_t, t, pred): + M = torch.tensor([ + [self.alpha(t), self.sigma(t)], + [self.A(t), self.B(t)], + ], dtype=torch.float64) + M_inv = torch.linalg.inv(M) + sample_pred = M_inv[0, 0].item() * x_t + M_inv[0, 1].item() * pred + return sample_pred diff --git a/flowae/models/diffusion/samplers.py b/flowae/models/diffusion/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..75fc52249e08e8f5bee516349f9545f0de7c5513 --- /dev/null +++ b/flowae/models/diffusion/samplers.py @@ -0,0 +1,39 @@ +import numpy as np +import torch + +from models import register + + +@register('fm_euler_sampler') +class FMEulerSampler: + + def __init__(self, diffusion): + self.diffusion = diffusion + + def sample( + self, + net, + shape, + n_steps, + net_kwargs=None, + uncond_net_kwargs=None, + guidance=1.0, + noise=None, + ): + device = next(net.parameters()).device + x_t = torch.randn(shape, device=device) if noise is None else noise + t_steps = torch.linspace(1, 0, n_steps + 1, device=device) + + with torch.no_grad(): + for i in range(n_steps): + t = t_steps[i].repeat(x_t.shape[0]) + neg_v = self.diffusion.get_prediction( + net, + x_t, + t, + net_kwargs=net_kwargs, + uncond_net_kwargs=uncond_net_kwargs, + guidance=guidance, + ) + x_t = x_t + neg_v * (t_steps[i] - t_steps[i + 1]) + return x_t diff --git a/flowae/models/ldm/__init__.py b/flowae/models/ldm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8103bdfc078dae107945147b04cb44ed5758f5e --- /dev/null +++ b/flowae/models/ldm/__init__.py @@ -0,0 +1,4 @@ +from . import glpto, dito +from . import renderers +from . import vqgan +from . import dac \ No newline at end of file diff --git a/flowae/models/ldm/dac/__init__.py b/flowae/models/ldm/dac/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..90f60fdd89ad8575faafe45188bd1d968852fc67 --- /dev/null +++ b/flowae/models/ldm/dac/__init__.py @@ -0,0 +1 @@ +from .utils import * \ No newline at end of file diff --git a/flowae/models/ldm/dac/audiotools/__init__.py b/flowae/models/ldm/dac/audiotools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b251ff37628c56a19bb38976fca99d9536e64bbf --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/__init__.py @@ -0,0 +1,10 @@ +__version__ = "0.7.4" +from .core import AudioSignal +from .core import STFTParams +from .core import Meter +from .core import util +from . import metrics +from . import data +from . import ml +from .data import datasets +from .data import transforms diff --git a/flowae/models/ldm/dac/audiotools/core/__init__.py b/flowae/models/ldm/dac/audiotools/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8660c4e67f43d0ded584a38939425e2c28d95cd3 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/core/__init__.py @@ -0,0 +1,4 @@ +from . 
import util +from .audio_signal import AudioSignal +from .audio_signal import STFTParams +from .loudness import Meter diff --git a/flowae/models/ldm/dac/audiotools/core/audio_signal.py b/flowae/models/ldm/dac/audiotools/core/audio_signal.py new file mode 100644 index 0000000000000000000000000000000000000000..fb6d751cb968a003656e3e7874c487b83d94c82e --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/core/audio_signal.py @@ -0,0 +1,1682 @@ +import copy +import functools +import hashlib +import math +import pathlib +import tempfile +import typing +import warnings +from collections import namedtuple +from pathlib import Path + +import julius +import numpy as np +import soundfile +import torch + +from . import util +from .display import DisplayMixin +from .dsp import DSPMixin +from .effects import EffectMixin +from .effects import ImpulseResponseMixin +from .ffmpeg import FFMPEGMixin +from .loudness import LoudnessMixin +from .playback import PlayMixin +from .whisper import WhisperMixin + + +STFTParams = namedtuple( + "STFTParams", + ["window_length", "hop_length", "window_type", "match_stride", "padding_type"], +) +""" +STFTParams object is a container that holds STFT parameters - window_length, +hop_length, and window_type. Not all parameters need to be specified. Ones that +are not specified will be inferred by the AudioSignal parameters. + +Parameters +---------- +window_length : int, optional + Window length of STFT, by default ``0.032 * self.sample_rate``. +hop_length : int, optional + Hop length of STFT, by default ``window_length // 4``. +window_type : str, optional + Type of window to use, by default ``sqrt\_hann``. +match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False +padding_type : str, optional + Type of padding to use, by default 'reflect' +""" +STFTParams.__new__.__defaults__ = (None, None, None, None, None) + + +class AudioSignal( + EffectMixin, + LoudnessMixin, + PlayMixin, + ImpulseResponseMixin, + DSPMixin, + DisplayMixin, + FFMPEGMixin, + WhisperMixin, +): + """This is the core object of this library. Audio is always + loaded into an AudioSignal, which then enables all the features + of this library, including audio augmentations, I/O, playback, + and more. + + The structure of this object is that the base functionality + is defined in ``core/audio_signal.py``, while extensions to + that functionality are defined in the other ``core/*.py`` + files. For example, all the display-based functionality + (e.g. plot spectrograms, waveforms, write to tensorboard) + are in ``core/display.py``. + + Parameters + ---------- + audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray] + Object to create AudioSignal from. Can be a tensor, numpy array, + or a path to a file. The file is always reshaped to + sample_rate : int, optional + Sample rate of the audio. If different from underlying file, resampling is + performed. If passing in an array or tensor, this must be defined, + by default None + stft_params : STFTParams, optional + Parameters of STFT to use. , by default None + offset : float, optional + Offset in seconds to read from file, by default 0 + duration : float, optional + Duration in seconds to read from file, by default None + device : str, optional + Device to load audio onto, by default None + + Examples + -------- + Loading an AudioSignal from an array, at a sample rate of + 44100. 
+ + >>> signal = AudioSignal(torch.randn(5*44100), 44100) + + Note, the signal is reshaped to have a batch size, and one + audio channel: + + >>> print(signal.shape) + (1, 1, 44100) + + You can treat AudioSignals like tensors, and many of the same + functions you might use on tensors are defined for AudioSignals + as well: + + >>> signal.to("cuda") + >>> signal.cuda() + >>> signal.clone() + >>> signal.detach() + + Indexing AudioSignals returns an AudioSignal: + + >>> signal[..., 3*44100:4*44100] + + The above signal is 1 second long, and is also an AudioSignal. + """ + + def __init__( + self, + audio_path_or_array: typing.Union[torch.Tensor, str, Path, np.ndarray], + sample_rate: int = None, + stft_params: STFTParams = None, + offset: float = 0, + duration: float = None, + device: str = None, + ): + audio_path = None + audio_array = None + + if isinstance(audio_path_or_array, str): + audio_path = audio_path_or_array + elif isinstance(audio_path_or_array, pathlib.Path): + audio_path = audio_path_or_array + elif isinstance(audio_path_or_array, np.ndarray): + audio_array = audio_path_or_array + elif torch.is_tensor(audio_path_or_array): + audio_array = audio_path_or_array + else: + raise ValueError( + "audio_path_or_array must be either a Path, " + "string, numpy array, or torch Tensor!" + ) + + self.path_to_file = None + + self.audio_data = None + self.sources = None # List of AudioSignal objects. + self.stft_data = None + if audio_path is not None: + self.load_from_file( + audio_path, offset=offset, duration=duration, device=device + ) + elif audio_array is not None: + assert sample_rate is not None, "Must set sample rate!" + self.load_from_array(audio_array, sample_rate, device=device) + + self.window = None + self.stft_params = stft_params + + self.metadata = { + "offset": offset, + "duration": duration, + } + + @property + def path_to_input_file( + self, + ): + """ + Path to input file, if it exists. + Alias to ``path_to_file`` for backwards compatibility + """ + return self.path_to_file + + @classmethod + def excerpt( + cls, + audio_path: typing.Union[str, Path], + offset: float = None, + duration: float = None, + state: typing.Union[np.random.RandomState, int] = None, + **kwargs, + ): + """Randomly draw an excerpt of ``duration`` seconds from an + audio file specified at ``audio_path``, between ``offset`` seconds + and end of file. ``state`` can be used to seed the random draw. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to audio file to grab excerpt from. + offset : float, optional + Lower bound for the start time, in seconds drawn from + the file, by default None. + duration : float, optional + Duration of excerpt, in seconds, by default None + state : typing.Union[np.random.RandomState, int], optional + RandomState or seed of random state, by default None + + Returns + ------- + AudioSignal + AudioSignal containing excerpt. 
+ + Examples + -------- + >>> signal = AudioSignal.excerpt("path/to/audio", duration=5) + """ + info = util.info(audio_path) + total_duration = info.duration + + state = util.random_state(state) + lower_bound = 0 if offset is None else offset + upper_bound = max(total_duration - duration, 0) + offset = state.uniform(lower_bound, upper_bound) + + signal = cls(audio_path, offset=offset, duration=duration, **kwargs) + signal.metadata["offset"] = offset + signal.metadata["duration"] = duration + + return signal + + @classmethod + def salient_excerpt( + cls, + audio_path: typing.Union[str, Path], + loudness_cutoff: float = None, + num_tries: int = 8, + state: typing.Union[np.random.RandomState, int] = None, + **kwargs, + ): + """Similar to AudioSignal.excerpt, except it extracts excerpts only + if they are above a specified loudness threshold, which is computed via + a fast LUFS routine. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to audio file to grab excerpt from. + loudness_cutoff : float, optional + Loudness threshold in dB. Typical values are ``-40, -60``, + etc, by default None + num_tries : int, optional + Number of tries to grab an excerpt above the threshold + before giving up, by default 8. + state : typing.Union[np.random.RandomState, int], optional + RandomState or seed of random state, by default None + kwargs : dict + Keyword arguments to AudioSignal.excerpt + + Returns + ------- + AudioSignal + AudioSignal containing excerpt. + + + .. warning:: + if ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can + result in an infinite loop if ``audio_path`` does not have + any loud enough excerpts. + + Examples + -------- + >>> signal = AudioSignal.salient_excerpt( + "path/to/audio", + loudness_cutoff=-40, + duration=5 + ) + """ + state = util.random_state(state) + if loudness_cutoff is None: + excerpt = cls.excerpt(audio_path, state=state, **kwargs) + else: + loudness = -np.inf + num_try = 0 + while loudness <= loudness_cutoff: + excerpt = cls.excerpt(audio_path, state=state, **kwargs) + loudness = excerpt.loudness() + num_try += 1 + if num_tries is not None and num_try >= num_tries: + break + return excerpt + + @classmethod + def zeros( + cls, + duration: float, + sample_rate: int, + num_channels: int = 1, + batch_size: int = 1, + **kwargs, + ): + """Helper function create an AudioSignal of all zeros. + + Parameters + ---------- + duration : float + Duration of AudioSignal + sample_rate : int + Sample rate of AudioSignal + num_channels : int, optional + Number of channels, by default 1 + batch_size : int, optional + Batch size, by default 1 + + Returns + ------- + AudioSignal + AudioSignal containing all zeros. + + Examples + -------- + Generate 5 seconds of all zeros at a sample rate of 44100. + + >>> signal = AudioSignal.zeros(5.0, 44100) + """ + n_samples = int(duration * sample_rate) + return cls( + torch.zeros(batch_size, num_channels, n_samples), sample_rate, **kwargs + ) + + @classmethod + def wave( + cls, + frequency: float, + duration: float, + sample_rate: int, + num_channels: int = 1, + shape: str = "sine", + **kwargs, + ): + """ + Generate a waveform of a given frequency and shape. 
+
+        Parameters
+        ----------
+        frequency : float
+            Frequency of the waveform
+        duration : float
+            Duration of the waveform
+        sample_rate : int
+            Sample rate of the waveform
+        num_channels : int, optional
+            Number of channels, by default 1
+        shape : str, optional
+            Shape of the waveform, by default "sine".
+            One of "sawtooth", "square", "sine", "triangle"
+        kwargs : dict
+            Keyword arguments to AudioSignal
+        """
+        n_samples = int(duration * sample_rate)
+        t = torch.linspace(0, duration, n_samples)
+        if shape == "sawtooth":
+            from scipy.signal import sawtooth
+
+            wave_data = sawtooth(2 * np.pi * frequency * t, 0.5)
+        elif shape == "square":
+            from scipy.signal import square
+
+            wave_data = square(2 * np.pi * frequency * t)
+        elif shape == "sine":
+            wave_data = np.sin(2 * np.pi * frequency * t)
+        elif shape == "triangle":
+            from scipy.signal import sawtooth
+
+            # frequency is doubled by the abs call, so omit the 2 in 2pi
+            wave_data = sawtooth(np.pi * frequency * t, 0.5)
+            wave_data = -np.abs(wave_data) * 2 + 1
+        else:
+            raise ValueError(f"Invalid shape {shape}")
+
+        wave_data = torch.tensor(wave_data, dtype=torch.float32)
+        wave_data = wave_data.unsqueeze(0).unsqueeze(0).repeat(1, num_channels, 1)
+        return cls(wave_data, sample_rate, **kwargs)
+
+    @classmethod
+    def batch(
+        cls,
+        audio_signals: list,
+        pad_signals: bool = False,
+        truncate_signals: bool = False,
+        resample: bool = False,
+        dim: int = 0,
+    ):
+        """Creates a batched AudioSignal from a list of AudioSignals.
+
+        Parameters
+        ----------
+        audio_signals : list[AudioSignal]
+            List of AudioSignal objects
+        pad_signals : bool, optional
+            Whether to pad signals to the length of the longest
+            AudioSignal in the list, by default False
+        truncate_signals : bool, optional
+            Whether to truncate signals to the length of the shortest
+            AudioSignal in the list, by default False
+        resample : bool, optional
+            Whether to resample AudioSignals to the sample rate of
+            the first AudioSignal in the list, by default False
+        dim : int, optional
+            Dimension along which to batch the signals.
+
+        Returns
+        -------
+        AudioSignal
+            Batched AudioSignal.
+
+        Raises
+        ------
+        RuntimeError
+            If not all AudioSignals have the same sample rate, and
+            ``resample=False``, an error is raised.
+        RuntimeError
+            If not all AudioSignals have the same length, and
+            both ``pad_signals=False`` and ``truncate_signals=False``,
+            an error is raised.
+
+        Examples
+        --------
+        Batching a bunch of random signals:
+
+        >>> signal_list = [AudioSignal(torch.randn(44100), 44100) for _ in range(10)]
+        >>> signal = AudioSignal.batch(signal_list)
+        >>> print(signal.shape)
+        (10, 1, 44100)
+
+        """
+        signal_lengths = [x.signal_length for x in audio_signals]
+        sample_rates = [x.sample_rate for x in audio_signals]
+
+        if len(set(sample_rates)) != 1:
+            if resample:
+                for x in audio_signals:
+                    x.resample(sample_rates[0])
+            else:
+                raise RuntimeError(
+                    f"Not all signals had the same sample rate! Got {sample_rates}. "
+                    f"All signals must have the same sample rate, or resample must be True. "
+                )
+
+        if len(set(signal_lengths)) != 1:
+            if pad_signals:
+                max_length = max(signal_lengths)
+                for x in audio_signals:
+                    pad_len = max_length - x.signal_length
+                    x.zero_pad(0, pad_len)
+            elif truncate_signals:
+                min_length = min(signal_lengths)
+                for x in audio_signals:
+                    x.truncate_samples(min_length)
+            else:
+                raise RuntimeError(
+                    f"Not all signals had the same length! Got {signal_lengths}. "
+                    f"All signals must be the same length, or pad_signals/truncate_signals "
+                    f"must be True. 
" + ) + # Concatenate along the specified dimension (default 0) + audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim) + audio_paths = [x.path_to_file for x in audio_signals] + + batched_signal = cls( + audio_data, + sample_rate=audio_signals[0].sample_rate, + ) + batched_signal.path_to_file = audio_paths + return batched_signal + + # I/O + def load_from_file( + self, + audio_path: typing.Union[str, Path], + offset: float, + duration: float, + device: str = "cpu", + ): + """Loads data from file. Used internally when AudioSignal + is instantiated with a path to a file. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to file + offset : float + Offset in seconds + duration : float + Duration in seconds + device : str, optional + Device to put AudioSignal on, by default "cpu" + + Returns + ------- + AudioSignal + AudioSignal loaded from file + """ + import librosa + + data, sample_rate = librosa.load( + audio_path, + offset=offset, + duration=duration, + sr=None, + mono=False, + ) + data = util.ensure_tensor(data) + if data.shape[-1] == 0: + raise RuntimeError( + f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!" + ) + + if data.ndim < 2: + data = data.unsqueeze(0) + if data.ndim < 3: + data = data.unsqueeze(0) + self.audio_data = data + + self.original_signal_length = self.signal_length + + self.sample_rate = sample_rate + self.path_to_file = audio_path + return self.to(device) + + def load_from_array( + self, + audio_array: typing.Union[torch.Tensor, np.ndarray], + sample_rate: int, + device: str = "cpu", + ): + """Loads data from array, reshaping it to be exactly 3 + dimensions. Used internally when AudioSignal is called + with a tensor or an array. + + Parameters + ---------- + audio_array : typing.Union[torch.Tensor, np.ndarray] + Array/tensor of audio of samples. + sample_rate : int + Sample rate of audio + device : str, optional + Device to move audio onto, by default "cpu" + + Returns + ------- + AudioSignal + AudioSignal loaded from array + """ + audio_data = util.ensure_tensor(audio_array) + + if audio_data.dtype == torch.double: + audio_data = audio_data.float() + + if audio_data.ndim < 2: + audio_data = audio_data.unsqueeze(0) + if audio_data.ndim < 3: + audio_data = audio_data.unsqueeze(0) + self.audio_data = audio_data + + self.original_signal_length = self.signal_length + + self.sample_rate = sample_rate + return self.to(device) + + def write(self, audio_path: typing.Union[str, Path]): + """Writes audio to a file. Only writes the audio + that is in the very first item of the batch. To write other items + in the batch, index the signal along the batch dimension + before writing. After writing, the signal's ``path_to_file`` + attribute is updated to the new path. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to write audio to. + + Returns + ------- + AudioSignal + Returns original AudioSignal, so you can use this in a fluent + interface. 
+ + Examples + -------- + Creating and writing a signal to disk: + + >>> signal = AudioSignal(torch.randn(10, 1, 44100), 44100) + >>> signal.write("/tmp/out.wav") + + Writing a different element of the batch: + + >>> signal[5].write("/tmp/out.wav") + + Using this in a fluent interface: + + >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav") + + """ + if self.audio_data[0].abs().max() > 1: + warnings.warn("Audio amplitude > 1 clipped when saving") + soundfile.write(str(audio_path), self.audio_data[0].numpy().T, self.sample_rate) + + self.path_to_file = audio_path + return self + + def deepcopy(self): + """Copies the signal and all of its attributes. + + Returns + ------- + AudioSignal + Deep copy of the audio signal. + """ + return copy.deepcopy(self) + + def copy(self): + """Shallow copy of signal. + + Returns + ------- + AudioSignal + Shallow copy of the audio signal. + """ + return copy.copy(self) + + def clone(self): + """Clones all tensors contained in the AudioSignal, + and returns a copy of the signal with everything + cloned. Useful when using AudioSignal within autograd + computation graphs. + + Relevant attributes are the stft data, the audio data, + and the loudness of the file. + + Returns + ------- + AudioSignal + Clone of AudioSignal. + """ + clone = type(self)( + self.audio_data.clone(), + self.sample_rate, + stft_params=self.stft_params, + ) + if self.stft_data is not None: + clone.stft_data = self.stft_data.clone() + if self._loudness is not None: + clone._loudness = self._loudness.clone() + clone.path_to_file = copy.deepcopy(self.path_to_file) + clone.metadata = copy.deepcopy(self.metadata) + return clone + + def detach(self): + """Detaches tensors contained in AudioSignal. + + Relevant attributes are the stft data, the audio data, + and the loudness of the file. + + Returns + ------- + AudioSignal + Same signal, but with all tensors detached. + """ + if self._loudness is not None: + self._loudness = self._loudness.detach() + if self.stft_data is not None: + self.stft_data = self.stft_data.detach() + + self.audio_data = self.audio_data.detach() + return self + + def hash(self): + """Writes the audio data to a temporary file, and then + hashes it using hashlib. Useful for creating a file + name based on the audio content. + + Returns + ------- + str + Hash of audio data. + + Examples + -------- + Creating a signal, and writing it to a unique file name: + + >>> signal = AudioSignal(torch.randn(44100), 44100) + >>> hash = signal.hash() + >>> signal.write(f"{hash}.wav") + + """ + with tempfile.NamedTemporaryFile(suffix=".wav") as f: + self.write(f.name) + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(f.name, "rb", buffering=0) as f: + for n in iter(lambda: f.readinto(mv), 0): + h.update(mv[:n]) + file_hash = h.hexdigest() + return file_hash + + # Signal operations + def to_mono(self): + """Converts audio data to mono audio, by taking the mean + along the channels dimension. + + Returns + ------- + AudioSignal + AudioSignal with mean of channels. + """ + self.audio_data = self.audio_data.mean(1, keepdim=True) + return self + + def resample(self, sample_rate: int): + """Resamples the audio, using sinc interpolation. This works on both + cpu and gpu, and is much faster on gpu. + + Parameters + ---------- + sample_rate : int + Sample rate to resample to. 
+ + Returns + ------- + AudioSignal + Resampled AudioSignal + """ + if sample_rate == self.sample_rate: + return self + self.audio_data = julius.resample_frac( + self.audio_data, self.sample_rate, sample_rate + ) + self.sample_rate = sample_rate + return self + + # Tensor operations + def to(self, device: str): + """Moves all tensors contained in signal to the specified device. + + Parameters + ---------- + device : str + Device to move AudioSignal onto. Typical values are + "cuda", "cpu", or "cuda:n" to specify the nth gpu. + + Returns + ------- + AudioSignal + AudioSignal with all tensors moved to specified device. + """ + if self._loudness is not None: + self._loudness = self._loudness.to(device) + if self.stft_data is not None: + self.stft_data = self.stft_data.to(device) + if self.audio_data is not None: + self.audio_data = self.audio_data.to(device) + return self + + def float(self): + """Calls ``.float()`` on ``self.audio_data``. + + Returns + ------- + AudioSignal + """ + self.audio_data = self.audio_data.float() + return self + + def cpu(self): + """Moves AudioSignal to cpu. + + Returns + ------- + AudioSignal + """ + return self.to("cpu") + + def cuda(self): # pragma: no cover + """Moves AudioSignal to cuda. + + Returns + ------- + AudioSignal + """ + return self.to("cuda") + + def numpy(self): + """Detaches ``self.audio_data``, moves to cpu, and converts to numpy. + + Returns + ------- + np.ndarray + Audio data as a numpy array. + """ + return self.audio_data.detach().cpu().numpy() + + def zero_pad(self, before: int, after: int): + """Zero pads the audio_data tensor before and after. + + Parameters + ---------- + before : int + How many zeros to prepend to audio. + after : int + How many zeros to append to audio. + + Returns + ------- + AudioSignal + AudioSignal with padding applied. + """ + self.audio_data = torch.nn.functional.pad(self.audio_data, (before, after)) + return self + + def zero_pad_to(self, length: int, mode: str = "after"): + """Pad with zeros to a specified length, either before or after + the audio data. + + Parameters + ---------- + length : int + Length to pad to + mode : str, optional + Whether to prepend or append zeros to signal, by default "after" + + Returns + ------- + AudioSignal + AudioSignal with padding applied. + """ + if mode == "before": + self.zero_pad(max(length - self.signal_length, 0), 0) + elif mode == "after": + self.zero_pad(0, max(length - self.signal_length, 0)) + return self + + def trim(self, before: int, after: int): + """Trims the audio_data tensor before and after. + + Parameters + ---------- + before : int + How many samples to trim from beginning. + after : int + How many samples to trim from end. + + Returns + ------- + AudioSignal + AudioSignal with trimming applied. + """ + if after == 0: + self.audio_data = self.audio_data[..., before:] + else: + self.audio_data = self.audio_data[..., before:-after] + return self + + def truncate_samples(self, length_in_samples: int): + """Truncate signal to specified length. + + Parameters + ---------- + length_in_samples : int + Truncate to this many samples. + + Returns + ------- + AudioSignal + AudioSignal with truncation applied. + """ + self.audio_data = self.audio_data[..., :length_in_samples] + return self + + @property + def device(self): + """Get device that AudioSignal is on. + + Returns + ------- + torch.device + Device that AudioSignal is on. 
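+
+        Examples
+        --------
+        An illustrative check; the device is ``cpu`` unless the signal
+        was moved elsewhere:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> signal.device
+        device(type='cpu')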
+ """ + if self.audio_data is not None: + device = self.audio_data.device + elif self.stft_data is not None: + device = self.stft_data.device + return device + + # Properties + @property + def audio_data(self): + """Returns the audio data tensor in the object. + + Audio data is always of the shape + (batch_size, num_channels, num_samples). If value has less + than 3 dims (e.g. is (num_channels, num_samples)), then it will + be reshaped to (1, num_channels, num_samples) - a batch size of 1. + + Parameters + ---------- + data : typing.Union[torch.Tensor, np.ndarray] + Audio data to set. + + Returns + ------- + torch.Tensor + Audio samples. + """ + return self._audio_data + + @audio_data.setter + def audio_data(self, data: typing.Union[torch.Tensor, np.ndarray]): + if data is not None: + assert torch.is_tensor(data), "audio_data should be torch.Tensor" + assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)" + self._audio_data = data + # Old loudness value not guaranteed to be right, reset it. + self._loudness = None + return + + # alias for audio_data + samples = audio_data + + @property + def stft_data(self): + """Returns the STFT data inside the signal. Shape is + (batch, channels, frequencies, time). + + Returns + ------- + torch.Tensor + Complex spectrogram data. + """ + return self._stft_data + + @stft_data.setter + def stft_data(self, data: typing.Union[torch.Tensor, np.ndarray]): + if data is not None: + assert torch.is_tensor(data) and torch.is_complex(data) + if self.stft_data is not None and self.stft_data.shape != data.shape: + warnings.warn("stft_data changed shape") + self._stft_data = data + return + + @property + def batch_size(self): + """Batch size of audio signal. + + Returns + ------- + int + Batch size of signal. + """ + return self.audio_data.shape[0] + + @property + def signal_length(self): + """Length of audio signal. + + Returns + ------- + int + Length of signal in samples. + """ + return self.audio_data.shape[-1] + + # alias for signal_length + length = signal_length + + @property + def shape(self): + """Shape of audio data. + + Returns + ------- + tuple + Shape of audio data. + """ + return self.audio_data.shape + + @property + def signal_duration(self): + """Length of audio signal in seconds. + + Returns + ------- + float + Length of signal in seconds. + """ + return self.signal_length / self.sample_rate + + # alias for signal_duration + duration = signal_duration + + @property + def num_channels(self): + """Number of audio channels. + + Returns + ------- + int + Number of audio channels. + """ + return self.audio_data.shape[1] + + # STFT + @staticmethod + @functools.lru_cache(None) + def get_window(window_type: str, window_length: int, device: str): + """Wrapper around scipy.signal.get_window so one can also get the + popular sqrt-hann window. This function caches for efficiency + using functools.lru\_cache. + + Parameters + ---------- + window_type : str + Type of window to get + window_length : int + Length of the window + device : str + Device to put window onto. + + Returns + ------- + torch.Tensor + Window returned by scipy.signal.get_window, as a tensor. 
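+
+        Examples
+        --------
+        An illustrative sketch; any window name accepted by
+        ``scipy.signal.get_window`` also works here:
+
+        >>> window = AudioSignal.get_window("sqrt_hann", 2048, "cpu")
+        >>> window.shape
+        torch.Size([2048])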
+ """ + from scipy import signal + + if window_type == "average": + window = np.ones(window_length) / window_length + elif window_type == "sqrt_hann": + window = np.sqrt(signal.get_window("hann", window_length)) + else: + window = signal.get_window(window_type, window_length) + window = torch.from_numpy(window).to(device).float() + return window + + @property + def stft_params(self): + """Returns STFTParams object, which can be re-used to other + AudioSignals. + + This property can be set as well. If values are not defined in STFTParams, + they are inferred automatically from the signal properties. The default is to use + 32ms windows, with 8ms hop length, and the square root of the hann window. + + Returns + ------- + STFTParams + STFT parameters for the AudioSignal. + + Examples + -------- + >>> stft_params = STFTParams(128, 32) + >>> signal1 = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params) + >>> signal2 = AudioSignal(torch.randn(44100), 44100, stft_params=signal1.stft_params) + >>> signal1.stft_params = STFTParams() # Defaults + """ + return self._stft_params + + @stft_params.setter + def stft_params(self, value: STFTParams): + default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate)))) + default_hop_len = default_win_len // 4 + default_win_type = "hann" + default_match_stride = False + default_padding_type = "reflect" + + default_stft_params = STFTParams( + window_length=default_win_len, + hop_length=default_hop_len, + window_type=default_win_type, + match_stride=default_match_stride, + padding_type=default_padding_type, + )._asdict() + + value = value._asdict() if value else default_stft_params + + for key in default_stft_params: + if value[key] is None: + value[key] = default_stft_params[key] + + self._stft_params = STFTParams(**value) + self.stft_data = None + + def compute_stft_padding( + self, window_length: int, hop_length: int, match_stride: bool + ): + """Compute how the STFT should be padded, based on match\_stride. + + Parameters + ---------- + window_length : int + Window length of STFT. + hop_length : int + Hop length of STFT. + match_stride : bool + Whether or not to match stride, making the STFT have the same alignment as + convolutional layers. + + Returns + ------- + tuple + Amount to pad on either side of audio. + """ + length = self.signal_length + + if match_stride: + assert ( + hop_length == window_length // 4 + ), "For match_stride, hop must equal n_fft // 4" + right_pad = math.ceil(length / hop_length) * hop_length - length + pad = (window_length - hop_length) // 2 + else: + right_pad = 0 + pad = 0 + + return right_pad, pad + + def stft( + self, + window_length: int = None, + hop_length: int = None, + window_type: str = None, + match_stride: bool = None, + padding_type: str = None, + ): + """Computes the short-time Fourier transform of the audio data, + with specified STFT parameters. + + Parameters + ---------- + window_length : int, optional + Window length of STFT, by default ``0.032 * self.sample_rate``. + hop_length : int, optional + Hop length of STFT, by default ``window_length // 4``. + window_type : str, optional + Type of window to use, by default ``sqrt\_hann``. + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + padding_type : str, optional + Type of padding to use, by default 'reflect' + + Returns + ------- + torch.Tensor + STFT of audio data. 
+
+        Examples
+        --------
+        Compute the STFT of an AudioSignal:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> signal.stft()
+
+        Vary the window and hop length:
+
+        >>> stft_params = [STFTParams(128, 32), STFTParams(512, 128)]
+        >>> for stft_param in stft_params:
+        >>>     signal.stft_params = stft_param
+        >>>     signal.stft()
+
+        """
+        window_length = (
+            self.stft_params.window_length
+            if window_length is None
+            else int(window_length)
+        )
+        hop_length = (
+            self.stft_params.hop_length if hop_length is None else int(hop_length)
+        )
+        window_type = (
+            self.stft_params.window_type if window_type is None else window_type
+        )
+        match_stride = (
+            self.stft_params.match_stride if match_stride is None else match_stride
+        )
+        padding_type = (
+            self.stft_params.padding_type if padding_type is None else padding_type
+        )
+
+        window = self.get_window(window_type, window_length, self.audio_data.device)
+        window = window.to(self.audio_data.device)
+
+        audio_data = self.audio_data
+        right_pad, pad = self.compute_stft_padding(
+            window_length, hop_length, match_stride
+        )
+        audio_data = torch.nn.functional.pad(
+            audio_data, (pad, pad + right_pad), padding_type
+        )
+        stft_data = torch.stft(
+            audio_data.reshape(-1, audio_data.shape[-1]),
+            n_fft=window_length,
+            hop_length=hop_length,
+            window=window,
+            return_complex=True,
+            center=True,
+        )
+        _, nf, nt = stft_data.shape
+        stft_data = stft_data.reshape(self.batch_size, self.num_channels, nf, nt)
+
+        if match_stride:
+            # Drop first two and last two frames, which are added
+            # because of padding. Now num_frames * hop_length = num_samples.
+            stft_data = stft_data[..., 2:-2]
+        self.stft_data = stft_data
+
+        return stft_data
+
+    def istft(
+        self,
+        window_length: int = None,
+        hop_length: int = None,
+        window_type: str = None,
+        match_stride: bool = None,
+        length: int = None,
+    ):
+        """Computes the inverse STFT and sets it to ``audio_data``.
+
+        Parameters
+        ----------
+        window_length : int, optional
+            Window length of STFT, by default ``0.032 * self.sample_rate``.
+        hop_length : int, optional
+            Hop length of STFT, by default ``window_length // 4``.
+        window_type : str, optional
+            Type of window to use, by default ``sqrt\_hann``.
+        match_stride : bool, optional
+            Whether to match the stride of convolutional layers, by default False
+        length : int, optional
+            Original length of signal, by default None
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal with istft applied.
+
+        Raises
+        ------
+        RuntimeError
+            Raises an error if stft was not called prior to istft on the signal,
+            or if stft_data is not set.
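+
+        Examples
+        --------
+        A round trip through the STFT (illustrative; a random 1-second
+        signal is assumed):
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> _ = signal.stft()
+        >>> signal = signal.istft()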
+ """ + if self.stft_data is None: + raise RuntimeError("Cannot do inverse STFT without self.stft_data!") + + window_length = ( + self.stft_params.window_length + if window_length is None + else int(window_length) + ) + hop_length = ( + self.stft_params.hop_length if hop_length is None else int(hop_length) + ) + window_type = ( + self.stft_params.window_type if window_type is None else window_type + ) + match_stride = ( + self.stft_params.match_stride if match_stride is None else match_stride + ) + + window = self.get_window(window_type, window_length, self.stft_data.device) + + nb, nch, nf, nt = self.stft_data.shape + stft_data = self.stft_data.reshape(nb * nch, nf, nt) + right_pad, pad = self.compute_stft_padding( + window_length, hop_length, match_stride + ) + + if length is None: + length = self.original_signal_length + length = length + 2 * pad + right_pad + + if match_stride: + # Zero-pad the STFT on either side, putting back the frames that were + # dropped in stft(). + stft_data = torch.nn.functional.pad(stft_data, (2, 2)) + + audio_data = torch.istft( + stft_data, + n_fft=window_length, + hop_length=hop_length, + window=window, + length=length, + center=True, + ) + audio_data = audio_data.reshape(nb, nch, -1) + if match_stride: + audio_data = audio_data[..., pad : -(pad + right_pad)] + self.audio_data = audio_data + + return self + + @staticmethod + @functools.lru_cache(None) + def get_mel_filters( + sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None + ): + """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins. + + Parameters + ---------- + sr : int + Sample rate of audio + n_fft : int + Number of FFT bins + n_mels : int + Number of mels + fmin : float, optional + Lowest frequency, in Hz, by default 0.0 + fmax : float, optional + Highest frequency, by default None + + Returns + ------- + np.ndarray [shape=(n_mels, 1 + n_fft/2)] + Mel transform matrix + """ + from librosa.filters import mel as librosa_mel_fn + + return librosa_mel_fn( + sr=sr, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=fmax, + ) + + def mel_spectrogram( + self, n_mels: int = 80, mel_fmin: float = 0.0, mel_fmax: float = None, **kwargs + ): + """Computes a Mel spectrogram. + + Parameters + ---------- + n_mels : int, optional + Number of mels, by default 80 + mel_fmin : float, optional + Lowest frequency, in Hz, by default 0.0 + mel_fmax : float, optional + Highest frequency, by default None + kwargs : dict, optional + Keyword arguments to self.stft(). + + Returns + ------- + torch.Tensor [shape=(batch, channels, mels, time)] + Mel spectrogram. + """ + stft = self.stft(**kwargs) + magnitude = torch.abs(stft) + + nf = magnitude.shape[2] + mel_basis = self.get_mel_filters( + sr=self.sample_rate, + n_fft=2 * (nf - 1), + n_mels=n_mels, + fmin=mel_fmin, + fmax=mel_fmax, + ) + mel_basis = torch.from_numpy(mel_basis).to(self.device) + + mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T + mel_spectrogram = mel_spectrogram.transpose(-1, 2) + return mel_spectrogram + + @staticmethod + @functools.lru_cache(None) + def get_dct(n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None): + """Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``), + it can be normalized depending on norm. 
For more information about dct:
+        http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
+
+        Parameters
+        ----------
+        n_mfcc : int
+            Number of mfccs
+        n_mels : int
+            Number of mels
+        norm : str
+            Use "ortho" to get an orthogonal matrix, or None, by default "ortho"
+        device : str, optional
+            Device to load the transformation matrix on, by default None
+
+        Returns
+        -------
+        torch.Tensor [shape=(n_mels, n_mfcc)]
+            The dct transformation matrix.
+        """
+        from torchaudio.functional import create_dct
+
+        return create_dct(n_mfcc, n_mels, norm).to(device)
+
+    def mfcc(
+        self, n_mfcc: int = 40, n_mels: int = 80, log_offset: float = 1e-6, **kwargs
+    ):
+        """Computes mel-frequency cepstral coefficients (MFCCs).
+
+        Parameters
+        ----------
+        n_mfcc : int, optional
+            Number of MFCCs, by default 40
+        n_mels : int, optional
+            Number of mels, by default 80
+        log_offset : float, optional
+            Small value to prevent numerical issues when trying to compute log(0), by default 1e-6
+        kwargs : dict, optional
+            Keyword arguments to self.mel_spectrogram(); note that some of them
+            are passed on to self.stft()
+
+        Returns
+        -------
+        torch.Tensor [shape=(batch, channels, mfccs, time)]
+            MFCCs.
+        """
+
+        mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs)
+        mel_spectrogram = torch.log(mel_spectrogram + log_offset)
+        dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device)
+
+        mfcc = mel_spectrogram.transpose(-1, -2) @ dct_mat
+        mfcc = mfcc.transpose(-1, -2)
+        return mfcc
+
+    @property
+    def magnitude(self):
+        """Computes and returns the absolute value of the STFT, which
+        is the magnitude. This value can also be set to some tensor.
+        When set, ``self.stft_data`` is manipulated so that its magnitude
+        matches what this is set to, and modulated by the phase.
+
+        Returns
+        -------
+        torch.Tensor
+            Magnitude of STFT.
+
+        Examples
+        --------
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> magnitude = signal.magnitude  # Computes stft if not computed
+        >>> magnitude[magnitude < magnitude.mean()] = 0
+        >>> signal.magnitude = magnitude
+        >>> signal.istft()
+        """
+        if self.stft_data is None:
+            self.stft()
+        return torch.abs(self.stft_data)
+
+    @magnitude.setter
+    def magnitude(self, value):
+        self.stft_data = value * torch.exp(1j * self.phase)
+        return
+
+    def log_magnitude(
+        self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0
+    ):
+        """Computes the log-magnitude of the spectrogram.
+
+        Parameters
+        ----------
+        ref_value : float, optional
+            The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``.
+            Zeros in the output correspond to positions where ``S == ref``,
+            by default 1.0
+        amin : float, optional
+            Minimum threshold for ``S`` and ``ref``, by default 1e-5
+        top_db : float, optional
+            Threshold the output at ``top_db`` below the peak:
+            ``max(20 * log10(S/ref)) - top_db``, by default 80.0
+
+        Returns
+        -------
+        torch.Tensor
+            Log-magnitude spectrogram
+        """
+        magnitude = self.magnitude
+
+        amin = amin**2
+        log_spec = 10.0 * torch.log10(magnitude.pow(2).clamp(min=amin))
+        log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
+
+        if top_db is not None:
+            log_spec = torch.maximum(log_spec, log_spec.max() - top_db)
+        return log_spec
+
+    @property
+    def phase(self):
+        """Computes and returns the phase of the STFT.
+        This value can also be set to some tensor.
+        When set, ``self.stft_data`` is manipulated so that its phase
+        matches what this is set to, with the original magnitude.
+
+        Returns
+        -------
+        torch.Tensor
+            Phase of STFT.
+ + Examples + -------- + >>> signal = AudioSignal(torch.randn(44100), 44100) + >>> phase = signal.phase # Computes stft if not computed + >>> phase[phase < phase.mean()] = 0 + >>> signal.phase = phase + >>> signal.istft() + """ + if self.stft_data is None: + self.stft() + return torch.angle(self.stft_data) + + @phase.setter + def phase(self, value): + self.stft_data = self.magnitude * torch.exp(1j * value) + return + + # Operator overloading + def __add__(self, other): + new_signal = self.clone() + new_signal.audio_data += util._get_value(other) + return new_signal + + def __iadd__(self, other): + self.audio_data += util._get_value(other) + return self + + def __radd__(self, other): + return self + other + + def __sub__(self, other): + new_signal = self.clone() + new_signal.audio_data -= util._get_value(other) + return new_signal + + def __isub__(self, other): + self.audio_data -= util._get_value(other) + return self + + def __mul__(self, other): + new_signal = self.clone() + new_signal.audio_data *= util._get_value(other) + return new_signal + + def __imul__(self, other): + self.audio_data *= util._get_value(other) + return self + + def __rmul__(self, other): + return self * other + + # Representation + def _info(self): + dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]" + info = { + "duration": f"{dur} seconds", + "batch_size": self.batch_size, + "path": self.path_to_file if self.path_to_file else "path unknown", + "sample_rate": self.sample_rate, + "num_channels": self.num_channels if self.num_channels else "[unknown]", + "audio_data.shape": self.audio_data.shape, + "stft_params": self.stft_params, + "device": self.device, + } + + return info + + def markdown(self): + """Produces a markdown representation of AudioSignal, in a markdown table. + + Returns + ------- + str + Markdown representation of AudioSignal. 
+ + Examples + -------- + >>> signal = AudioSignal(torch.randn(44100), 44100) + >>> print(signal.markdown()) + | Key | Value + |---|--- + | duration | 1.000 seconds | + | batch_size | 1 | + | path | path unknown | + | sample_rate | 44100 | + | num_channels | 1 | + | audio_data.shape | torch.Size([1, 1, 44100]) | + | stft_params | STFTParams(window_length=2048, hop_length=512, window_type='sqrt_hann', match_stride=False) | + | device | cpu | + """ + info = self._info() + + FORMAT = "| Key | Value \n" "|---|--- \n" + for k, v in info.items(): + row = f"| {k} | {v} |\n" + FORMAT += row + return FORMAT + + def __str__(self): + info = self._info() + + desc = "" + for k, v in info.items(): + desc += f"{k}: {v}\n" + return desc + + def __rich__(self): + from rich.table import Table + + info = self._info() + + table = Table(title=f"{self.__class__.__name__}") + table.add_column("Key", style="green") + table.add_column("Value", style="cyan") + + for k, v in info.items(): + table.add_row(k, str(v)) + return table + + # Comparison + def __eq__(self, other): + for k, v in list(self.__dict__.items()): + if torch.is_tensor(v): + if not torch.allclose(v, other.__dict__[k], atol=1e-6): + max_error = (v - other.__dict__[k]).abs().max() + print(f"Max abs error for {k}: {max_error}") + return False + return True + + # Indexing + def __getitem__(self, key): + if torch.is_tensor(key) and key.ndim == 0 and key.item() is True: + assert self.batch_size == 1 + audio_data = self.audio_data + _loudness = self._loudness + stft_data = self.stft_data + + elif isinstance(key, (bool, int, list, slice, tuple)) or ( + torch.is_tensor(key) and key.ndim <= 1 + ): + # Indexing only on the batch dimension. + # Then let's copy over relevant stuff. + # Future work: make this work for time-indexing + # as well, using the hop length. + audio_data = self.audio_data[key] + _loudness = self._loudness[key] if self._loudness is not None else None + stft_data = self.stft_data[key] if self.stft_data is not None else None + + sources = None + + copy = type(self)(audio_data, self.sample_rate, stft_params=self.stft_params) + copy._loudness = _loudness + copy._stft_data = stft_data + copy.sources = sources + + return copy + + def __setitem__(self, key, value): + if not isinstance(value, type(self)): + self.audio_data[key] = value + return + + if torch.is_tensor(key) and key.ndim == 0 and key.item() is True: + assert self.batch_size == 1 + self.audio_data = value.audio_data + self._loudness = value._loudness + self.stft_data = value.stft_data + return + + elif isinstance(key, (bool, int, list, slice, tuple)) or ( + torch.is_tensor(key) and key.ndim <= 1 + ): + if self.audio_data is not None and value.audio_data is not None: + self.audio_data[key] = value.audio_data + if self._loudness is not None and value._loudness is not None: + self._loudness[key] = value._loudness + if self.stft_data is not None and value.stft_data is not None: + self.stft_data[key] = value.stft_data + return + + def __ne__(self, other): + return not self == other diff --git a/flowae/models/ldm/dac/audiotools/core/display.py b/flowae/models/ldm/dac/audiotools/core/display.py new file mode 100644 index 0000000000000000000000000000000000000000..66cbcf34cb2cf9fdf8d67ec4418a887eba73f184 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/core/display.py @@ -0,0 +1,194 @@ +import inspect +import typing +from functools import wraps + +from . import util + + +def format_figure(func): + """Decorator for formatting figures produced by the code below. 
+ See :py:func:`audiotools.core.util.format_figure` for more. + + Parameters + ---------- + func : Callable + Plotting function that is decorated by this function. + + """ + + @wraps(func) + def wrapper(*args, **kwargs): + f_keys = inspect.signature(util.format_figure).parameters.keys() + f_kwargs = {} + for k, v in list(kwargs.items()): + if k in f_keys: + kwargs.pop(k) + f_kwargs[k] = v + func(*args, **kwargs) + util.format_figure(**f_kwargs) + + return wrapper + + +class DisplayMixin: + @format_figure + def specshow( + self, + preemphasis: bool = False, + x_axis: str = "time", + y_axis: str = "linear", + n_mels: int = 128, + **kwargs, + ): + """Displays a spectrogram, using ``librosa.display.specshow``. + + Parameters + ---------- + preemphasis : bool, optional + Whether or not to apply preemphasis, which makes high + frequency detail easier to see, by default False + x_axis : str, optional + How to label the x axis, by default "time" + y_axis : str, optional + How to label the y axis, by default "linear" + n_mels : int, optional + If displaying a mel spectrogram with ``y_axis = "mel"``, + this controls the number of mels, by default 128. + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.util.format_figure`. + """ + import librosa + import librosa.display + + # Always re-compute the STFT data before showing it, in case + # it changed. + signal = self.clone() + signal.stft_data = None + + if preemphasis: + signal.preemphasis() + + ref = signal.magnitude.max() + log_mag = signal.log_magnitude(ref_value=ref) + + if y_axis == "mel": + log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10() + log_mag -= log_mag.max() + + librosa.display.specshow( + log_mag.numpy()[0].mean(axis=0), + x_axis=x_axis, + y_axis=y_axis, + sr=signal.sample_rate, + **kwargs, + ) + + @format_figure + def waveplot(self, x_axis: str = "time", **kwargs): + """Displays a waveform plot, using ``librosa.display.waveshow``. + + Parameters + ---------- + x_axis : str, optional + How to label the x axis, by default "time" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.util.format_figure`. + """ + import librosa + import librosa.display + + audio_data = self.audio_data[0].mean(dim=0) + audio_data = audio_data.cpu().numpy() + + plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot" + wave_plot_fn = getattr(librosa.display, plot_fn) + wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs) + + @format_figure + def wavespec(self, x_axis: str = "time", **kwargs): + """Displays a waveform plot, using ``librosa.display.waveshow``. + + Parameters + ---------- + x_axis : str, optional + How to label the x axis, by default "time" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`. + """ + import matplotlib.pyplot as plt + from matplotlib.gridspec import GridSpec + + gs = GridSpec(6, 1) + plt.subplot(gs[0, :]) + self.waveplot(x_axis=x_axis) + plt.subplot(gs[1:, :]) + self.specshow(x_axis=x_axis, **kwargs) + + def write_audio_to_tb( + self, + tag: str, + writer, + step: int = None, + plot_fn: typing.Union[typing.Callable, str] = "specshow", + **kwargs, + ): + """Writes a signal and its spectrogram to Tensorboard. Will show up + under the Audio and Images tab in Tensorboard. + + Parameters + ---------- + tag : str + Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be + written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``). 
+ writer : SummaryWriter + A SummaryWriter object from PyTorch library. + step : int, optional + The step to write the signal to, by default None + plot_fn : typing.Union[typing.Callable, str], optional + How to create the image. Set to ``None`` to avoid plotting, by default "specshow" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or + whatever ``plot_fn`` is set to. + """ + import matplotlib.pyplot as plt + + audio_data = self.audio_data[0, 0].detach().cpu() + sample_rate = self.sample_rate + writer.add_audio(tag, audio_data, step, sample_rate) + + if plot_fn is not None: + if isinstance(plot_fn, str): + plot_fn = getattr(self, plot_fn) + fig = plt.figure() + plt.clf() + plot_fn(**kwargs) + writer.add_figure(tag.replace("wav", "png"), fig, step) + + def save_image( + self, + image_path: str, + plot_fn: typing.Union[typing.Callable, str] = "specshow", + **kwargs, + ): + """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to + a specified file. + + Parameters + ---------- + image_path : str + Where to save the file to. + plot_fn : typing.Union[typing.Callable, str], optional + How to create the image. Set to ``None`` to avoid plotting, by default "specshow" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or + whatever ``plot_fn`` is set to. + """ + import matplotlib.pyplot as plt + + if isinstance(plot_fn, str): + plot_fn = getattr(self, plot_fn) + + plt.clf() + plot_fn(**kwargs) + plt.savefig(image_path, bbox_inches="tight", pad_inches=0) + plt.close() diff --git a/flowae/models/ldm/dac/audiotools/core/dsp.py b/flowae/models/ldm/dac/audiotools/core/dsp.py new file mode 100644 index 0000000000000000000000000000000000000000..f9be51a119537b77e497ddc2dac126d569533d7c --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/core/dsp.py @@ -0,0 +1,390 @@ +import typing + +import julius +import numpy as np +import torch + +from . import util + + +class DSPMixin: + _original_batch_size = None + _original_num_channels = None + _padded_signal_length = None + + def _preprocess_signal_for_windowing(self, window_duration, hop_duration): + self._original_batch_size = self.batch_size + self._original_num_channels = self.num_channels + + window_length = int(window_duration * self.sample_rate) + hop_length = int(hop_duration * self.sample_rate) + + if window_length % hop_length != 0: + factor = window_length // hop_length + window_length = factor * hop_length + + self.zero_pad(hop_length, hop_length) + self._padded_signal_length = self.signal_length + + return window_length, hop_length + + def windows( + self, window_duration: float, hop_duration: float, preprocess: bool = True + ): + """Generator which yields windows of specified duration from signal with a specified + hop length. + + Parameters + ---------- + window_duration : float + Duration of every window in seconds. + hop_duration : float + Hop between windows in seconds. + preprocess : bool, optional + Whether to preprocess the signal, so that the first sample is in + the middle of the first window, by default True + + Yields + ------ + AudioSignal + Each window is returned as an AudioSignal. 
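+
+        Examples
+        --------
+        Iterating over 1-second windows with a 0.5-second hop
+        (illustrative values):
+
+        >>> signal = AudioSignal(torch.randn(4 * 44100), 44100)
+        >>> for window in signal.windows(1.0, 0.5):
+        >>>     pass  # each window is a 1-second AudioSignal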
+        """
+        if preprocess:
+            window_length, hop_length = self._preprocess_signal_for_windowing(
+                window_duration, hop_duration
+            )
+
+        self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length)
+
+        for b in range(self.batch_size):
+            i = 0
+            while True:
+                start_idx = i * hop_length
+                i += 1
+                end_idx = start_idx + window_length
+                if end_idx > self.signal_length:
+                    break
+                yield self[b, ..., start_idx:end_idx]
+
+    def collect_windows(
+        self, window_duration: float, hop_duration: float, preprocess: bool = True
+    ):
+        """Reshapes the signal into windows of specified duration with a specified
+        hop length. Windows are placed along the batch dimension. Use with
+        :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the
+        original signal.
+
+        Parameters
+        ----------
+        window_duration : float
+            Duration of every window in seconds.
+        hop_duration : float
+            Hop between windows in seconds.
+        preprocess : bool, optional
+            Whether to preprocess the signal, so that the first sample is in
+            the middle of the first window, by default True
+
+        Returns
+        -------
+        AudioSignal
+            AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)``
+        """
+        if preprocess:
+            window_length, hop_length = self._preprocess_signal_for_windowing(
+                window_duration, hop_duration
+            )
+
+        # self.audio_data: (nb, nch, nt).
+        unfolded = torch.nn.functional.unfold(
+            self.audio_data.reshape(-1, 1, 1, self.signal_length),
+            kernel_size=(1, window_length),
+            stride=(1, hop_length),
+        )
+        # unfolded: (nb * nch, window_length, num_windows).
+        # -> (nb * nch * num_windows, 1, window_length)
+        unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length)
+        self.audio_data = unfolded
+        return self
+
+    def overlap_and_add(self, hop_duration: float):
+        """Takes a batch of windows and overlap-adds them back into a
+        signal the same length as the original signal.
+
+        Parameters
+        ----------
+        hop_duration : float
+            How much to shift for each window
+            (overlap is window_duration - hop_duration) in seconds.
+
+        Returns
+        -------
+        AudioSignal
+            Overlap-and-added signal.
+        """
+        hop_length = int(hop_duration * self.sample_rate)
+        window_length = self.signal_length
+
+        nb, nch = self._original_batch_size, self._original_num_channels
+
+        unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1)
+        folded = torch.nn.functional.fold(
+            unfolded,
+            output_size=(1, self._padded_signal_length),
+            kernel_size=(1, window_length),
+            stride=(1, hop_length),
+        )
+
+        norm = torch.ones_like(unfolded, device=unfolded.device)
+        norm = torch.nn.functional.fold(
+            norm,
+            output_size=(1, self._padded_signal_length),
+            kernel_size=(1, window_length),
+            stride=(1, hop_length),
+        )
+
+        folded = folded / norm
+
+        folded = folded.reshape(nb, nch, -1)
+        self.audio_data = folded
+        self.trim(hop_length, hop_length)
+        return self
+
+    def low_pass(
+        self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
+    ):
+        """Low-passes the signal in-place. Each item in the batch
+        can have a different low-pass cutoff, if the input
+        to this signal is an array or tensor. If a float, all
+        items are given the same low-pass filter.
+
+        Parameters
+        ----------
+        cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
+            Cutoff in Hz of low-pass filter.
+        zeros : int, optional
+            Number of taps to use in low-pass filter, by default 51
+
+        Returns
+        -------
+        AudioSignal
+            Low-passed AudioSignal.
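+
+        Examples
+        --------
+        An illustrative sketch; the cutoff can be one float for the whole
+        batch, or one value per batch item:
+
+        >>> signal = AudioSignal(torch.randn(2, 1, 44100), 44100)
+        >>> signal = signal.low_pass(4000)
+        >>> signal = signal.low_pass(torch.tensor([2000, 8000]))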
+ """ + cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size) + cutoffs = cutoffs / self.sample_rate + filtered = torch.empty_like(self.audio_data) + + for i, cutoff in enumerate(cutoffs): + lp_filter = julius.LowPassFilter(cutoff.cpu(), zeros=zeros).to(self.device) + filtered[i] = lp_filter(self.audio_data[i]) + + self.audio_data = filtered + self.stft_data = None + return self + + def high_pass( + self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51 + ): + """High-passes the signal in-place. Each item in the batch + can have a different high-pass cutoff, if the input + to this signal is an array or tensor. If a float, all + items are given the same high-pass filter. + + Parameters + ---------- + cutoffs : typing.Union[torch.Tensor, np.ndarray, float] + Cutoff in Hz of high-pass filter. + zeros : int, optional + Number of taps to use in high-pass filter, by default 51 + + Returns + ------- + AudioSignal + High-passed AudioSignal. + """ + cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size) + cutoffs = cutoffs / self.sample_rate + filtered = torch.empty_like(self.audio_data) + + for i, cutoff in enumerate(cutoffs): + hp_filter = julius.HighPassFilter(cutoff.cpu(), zeros=zeros).to(self.device) + filtered[i] = hp_filter(self.audio_data[i]) + + self.audio_data = filtered + self.stft_data = None + return self + + def mask_frequencies( + self, + fmin_hz: typing.Union[torch.Tensor, np.ndarray, float], + fmax_hz: typing.Union[torch.Tensor, np.ndarray, float], + val: float = 0.0, + ): + """Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them + with the value specified by ``val``. Useful for implementing SpecAug. + The min and max can be different for every item in the batch. + + Parameters + ---------- + fmin_hz : typing.Union[torch.Tensor, np.ndarray, float] + Lower end of band to mask out. + fmax_hz : typing.Union[torch.Tensor, np.ndarray, float] + Upper end of band to mask out. + val : float, optional + Value to fill in, by default 0.0 + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + # SpecAug + mag, phase = self.magnitude, self.phase + fmin_hz = util.ensure_tensor(fmin_hz, ndim=mag.ndim) + fmax_hz = util.ensure_tensor(fmax_hz, ndim=mag.ndim) + assert torch.all(fmin_hz < fmax_hz) + + # build mask + nbins = mag.shape[-2] + bins_hz = torch.linspace(0, self.sample_rate / 2, nbins, device=self.device) + bins_hz = bins_hz[None, None, :, None].repeat( + self.batch_size, 1, 1, mag.shape[-1] + ) + mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz) + mask = mask.to(self.device) + + mag = mag.masked_fill(mask, val) + phase = phase.masked_fill(mask, val) + self.stft_data = mag * torch.exp(1j * phase) + return self + + def mask_timesteps( + self, + tmin_s: typing.Union[torch.Tensor, np.ndarray, float], + tmax_s: typing.Union[torch.Tensor, np.ndarray, float], + val: float = 0.0, + ): + """Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them + with the value specified by ``val``. Useful for implementing SpecAug. + The min and max can be different for every item in the batch. + + Parameters + ---------- + tmin_s : typing.Union[torch.Tensor, np.ndarray, float] + Lower end of timesteps to mask out. + tmax_s : typing.Union[torch.Tensor, np.ndarray, float] + Upper end of timesteps to mask out. + val : float, optional + Value to fill in, by default 0.0 + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. 
Apply ``.istft()`` to get the + masked audio data. + """ + # SpecAug + mag, phase = self.magnitude, self.phase + tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim) + tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim) + + assert torch.all(tmin_s < tmax_s) + + # build mask + nt = mag.shape[-1] + bins_t = torch.linspace(0, self.signal_duration, nt, device=self.device) + bins_t = bins_t[None, None, None, :].repeat( + self.batch_size, 1, mag.shape[-2], 1 + ) + mask = (tmin_s <= bins_t) & (bins_t < tmax_s) + + mag = mag.masked_fill(mask, val) + phase = phase.masked_fill(mask, val) + self.stft_data = mag * torch.exp(1j * phase) + return self + + def mask_low_magnitudes( + self, db_cutoff: typing.Union[torch.Tensor, np.ndarray, float], val: float = 0.0 + ): + """Mask away magnitudes below a specified threshold, which + can be different for every item in the batch. + + Parameters + ---------- + db_cutoff : typing.Union[torch.Tensor, np.ndarray, float] + Decibel value for which things below it will be masked away. + val : float, optional + Value to fill in for masked portions, by default 0.0 + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + mag = self.magnitude + log_mag = self.log_magnitude() + + db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim) + mask = log_mag < db_cutoff + mag = mag.masked_fill(mask, val) + + self.magnitude = mag + return self + + def shift_phase(self, shift: typing.Union[torch.Tensor, np.ndarray, float]): + """Shifts the phase by a constant value. + + Parameters + ---------- + shift : typing.Union[torch.Tensor, np.ndarray, float] + What to shift the phase by. + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + shift = util.ensure_tensor(shift, ndim=self.phase.ndim) + self.phase = self.phase + shift + return self + + def corrupt_phase(self, scale: typing.Union[torch.Tensor, np.ndarray, float]): + """Corrupts the phase randomly by some scaled value. + + Parameters + ---------- + scale : typing.Union[torch.Tensor, np.ndarray, float] + Standard deviation of noise to add to the phase. + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + scale = util.ensure_tensor(scale, ndim=self.phase.ndim) + self.phase = self.phase + scale * torch.randn_like(self.phase) + return self + + def preemphasis(self, coef: float = 0.85): + """Applies pre-emphasis to audio signal. + + Parameters + ---------- + coef : float, optional + How much pre-emphasis to apply, lower values do less. 0 does nothing. + by default 0.85 + + Returns + ------- + AudioSignal + Pre-emphasized signal. + """ + kernel = torch.tensor([1, -coef, 0]).view(1, 1, -1).to(self.device) + x = self.audio_data.reshape(-1, 1, self.signal_length) + x = torch.nn.functional.conv1d(x, kernel, padding=1) + self.audio_data = x.reshape(*self.audio_data.shape) + return self diff --git a/flowae/models/ldm/dac/audiotools/core/effects.py b/flowae/models/ldm/dac/audiotools/core/effects.py new file mode 100644 index 0000000000000000000000000000000000000000..fb534cbcb2d457575de685fc9248d1716879145b --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/core/effects.py @@ -0,0 +1,647 @@ +import typing + +import julius +import numpy as np +import torch +import torchaudio + +from . 
import util + + +class EffectMixin: + GAIN_FACTOR = np.log(10) / 20 + """Gain factor for converting between amplitude and decibels.""" + CODEC_PRESETS = { + "8-bit": {"format": "wav", "encoding": "ULAW", "bits_per_sample": 8}, + "GSM-FR": {"format": "gsm"}, + "MP3": {"format": "mp3", "compression": -9}, + "Vorbis": {"format": "vorbis", "compression": -1}, + "Ogg": { + "format": "ogg", + "compression": -1, + }, + "Amr-nb": {"format": "amr-nb"}, + } + """Presets for applying codecs via torchaudio.""" + + def mix( + self, + other, + snr: typing.Union[torch.Tensor, np.ndarray, float] = 10, + other_eq: typing.Union[torch.Tensor, np.ndarray] = None, + ): + """Mixes noise with signal at specified + signal-to-noise ratio. Optionally, the + other signal can be equalized in-place. + + + Parameters + ---------- + other : AudioSignal + AudioSignal object to mix with. + snr : typing.Union[torch.Tensor, np.ndarray, float], optional + Signal to noise ratio, by default 10 + other_eq : typing.Union[torch.Tensor, np.ndarray], optional + EQ curve to apply to other signal, if any, by default None + + Returns + ------- + AudioSignal + In-place modification of AudioSignal. + """ + snr = util.ensure_tensor(snr).to(self.device) + + pad_len = max(0, self.signal_length - other.signal_length) + other.zero_pad(0, pad_len) + other.truncate_samples(self.signal_length) + if other_eq is not None: + other = other.equalizer(other_eq) + + tgt_loudness = self.loudness() - snr + other = other.normalize(tgt_loudness) + + self.audio_data = self.audio_data + other.audio_data + return self + + def convolve(self, other, start_at_max: bool = True): + """Convolves self with other. + This function uses FFTs to do the convolution. + + Parameters + ---------- + other : AudioSignal + Signal to convolve with. + start_at_max : bool, optional + Whether to start at the max value of other signal, to + avoid inducing delays, by default True + + Returns + ------- + AudioSignal + Convolved signal, in-place. + """ + from . import AudioSignal + + pad_len = self.signal_length - other.signal_length + + if pad_len > 0: + other.zero_pad(0, pad_len) + else: + other.truncate_samples(self.signal_length) + + if start_at_max: + # Use roll to rotate over the max for every item + # so that the impulse responses don't induce any + # delay. + idx = other.audio_data.abs().argmax(axis=-1) + irs = torch.zeros_like(other.audio_data) + for i in range(other.batch_size): + irs[i] = torch.roll(other.audio_data[i], -idx[i].item(), -1) + other = AudioSignal(irs, other.sample_rate) + + delta = torch.zeros_like(other.audio_data) + delta[..., 0] = 1 + + length = self.signal_length + delta_fft = torch.fft.rfft(delta, length) + other_fft = torch.fft.rfft(other.audio_data, length) + self_fft = torch.fft.rfft(self.audio_data, length) + + convolved_fft = other_fft * self_fft + convolved_audio = torch.fft.irfft(convolved_fft, length) + + delta_convolved_fft = other_fft * delta_fft + delta_audio = torch.fft.irfft(delta_convolved_fft, length) + + # Use the delta to rescale the audio exactly as needed. + delta_max = delta_audio.abs().max(dim=-1, keepdims=True)[0] + scale = 1 / delta_max.clamp(1e-5) + convolved_audio = convolved_audio * scale + + self.audio_data = convolved_audio + + return self + + def apply_ir( + self, + ir, + drr: typing.Union[torch.Tensor, np.ndarray, float] = None, + ir_eq: typing.Union[torch.Tensor, np.ndarray] = None, + use_original_phase: bool = False, + ): + """Applies an impulse response to the signal. 
If ``ir_eq``
+        is specified, the impulse response is equalized before
+        it is applied, using the given curve.
+
+        Parameters
+        ----------
+        ir : AudioSignal
+            Impulse response to convolve with.
+        drr : typing.Union[torch.Tensor, np.ndarray, float], optional
+            Direct-to-reverberant ratio that impulse response will be
+            altered to, if specified, by default None
+        ir_eq : typing.Union[torch.Tensor, np.ndarray], optional
+            Equalization that will be applied to impulse response
+            if specified, by default None
+        use_original_phase : bool, optional
+            Whether to use the original phase, instead of the convolved
+            phase, by default False
+
+        Returns
+        -------
+        AudioSignal
+            Signal with impulse response applied to it
+        """
+        if ir_eq is not None:
+            ir = ir.equalizer(ir_eq)
+        if drr is not None:
+            ir = ir.alter_drr(drr)
+
+        # Save the peak of the input signal before convolution.
+        max_spk = self.audio_data.abs().max(dim=-1, keepdims=True).values
+
+        # Augment the impulse response to simulate microphone effects
+        # and a varying direct-to-reverberant ratio.
+        phase = self.phase
+        self.convolve(ir)
+
+        # Use the input phase
+        if use_original_phase:
+            self.stft()
+            self.stft_data = self.magnitude * torch.exp(1j * phase)
+            self.istft()
+
+        # Rescale to the input's amplitude
+        max_transformed = self.audio_data.abs().max(dim=-1, keepdims=True).values
+        scale_factor = max_spk.clamp(1e-8) / max_transformed.clamp(1e-8)
+        self = self * scale_factor
+
+        return self
+
+    def ensure_max_of_audio(self, max: float = 1.0):
+        """Ensures that ``abs(audio_data) <= max``.
+
+        Parameters
+        ----------
+        max : float, optional
+            Max absolute value of signal, by default 1.0
+
+        Returns
+        -------
+        AudioSignal
+            Signal with values scaled between -max and max.
+        """
+        peak = self.audio_data.abs().max(dim=-1, keepdims=True)[0]
+        peak_gain = torch.ones_like(peak)
+        peak_gain[peak > max] = max / peak[peak > max]
+        self.audio_data = self.audio_data * peak_gain
+        return self
+
+    def normalize(self, db: typing.Union[torch.Tensor, np.ndarray, float] = -24.0):
+        """Normalizes the signal's volume to the specified db, in LUFS.
+        This is GPU-compatible, making for very fast loudness normalization.
+
+        Parameters
+        ----------
+        db : typing.Union[torch.Tensor, np.ndarray, float], optional
+            Loudness to normalize to, by default -24.0
+
+        Returns
+        -------
+        AudioSignal
+            Normalized audio signal.
+        """
+        db = util.ensure_tensor(db).to(self.device)
+        ref_db = self.loudness()
+        gain = db - ref_db
+        gain = torch.exp(gain * self.GAIN_FACTOR)
+
+        self.audio_data = self.audio_data * gain[:, None, None]
+        return self
+
+    def volume_change(self, db: typing.Union[torch.Tensor, np.ndarray, float]):
+        """Change the volume of the signal by some amount, in dB.
+
+        Parameters
+        ----------
+        db : typing.Union[torch.Tensor, np.ndarray, float]
+            Amount to change volume by.
+
+        Returns
+        -------
+        AudioSignal
+            Signal at new volume.
+        """
+        db = util.ensure_tensor(db, ndim=1).to(self.device)
+        gain = torch.exp(db * self.GAIN_FACTOR)
+        self.audio_data = self.audio_data * gain[:, None, None]
+        return self
+
+    def _to_2d(self):
+        waveform = self.audio_data.reshape(-1, self.signal_length)
+        return waveform
+
+    def _to_3d(self, waveform):
+        return waveform.reshape(self.batch_size, self.num_channels, -1)
+
+    def pitch_shift(self, n_semitones: int, quick: bool = True):
+        """Pitch shift the signal. All items in the batch
+        get the same pitch shift.
+
+        Parameters
+        ----------
+        n_semitones : int
+            How many semitones to shift the signal by.
+ quick : bool, optional + Using quick pitch shifting, by default True + + Returns + ------- + AudioSignal + Pitch shifted audio signal. + """ + device = self.device + effects = [ + ["pitch", str(n_semitones * 100)], + ["rate", str(self.sample_rate)], + ] + if quick: + effects[0].insert(1, "-q") + + waveform = self._to_2d().cpu() + waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor( + waveform, self.sample_rate, effects, channels_first=True + ) + self.sample_rate = sample_rate + self.audio_data = self._to_3d(waveform) + return self.to(device) + + def time_stretch(self, factor: float, quick: bool = True): + """Time stretch the audio signal. + + Parameters + ---------- + factor : float + Factor by which to stretch the AudioSignal. Typically + between 0.8 and 1.2. + quick : bool, optional + Whether to use quick time stretching, by default True + + Returns + ------- + AudioSignal + Time-stretched AudioSignal. + """ + device = self.device + effects = [ + ["tempo", str(factor)], + ["rate", str(self.sample_rate)], + ] + if quick: + effects[0].insert(1, "-q") + + waveform = self._to_2d().cpu() + waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor( + waveform, self.sample_rate, effects, channels_first=True + ) + self.sample_rate = sample_rate + self.audio_data = self._to_3d(waveform) + return self.to(device) + + def apply_codec( + self, + preset: str = None, + format: str = "wav", + encoding: str = None, + bits_per_sample: int = None, + compression: int = None, + ): # pragma: no cover + """Applies an audio codec to the signal. + + Parameters + ---------- + preset : str, optional + One of the keys in ``self.CODEC_PRESETS``, by default None + format : str, optional + Format for audio codec, by default "wav" + encoding : str, optional + Encoding to use, by default None + bits_per_sample : int, optional + How many bits per sample, by default None + compression : int, optional + Compression amount of codec, by default None + + Returns + ------- + AudioSignal + AudioSignal with codec applied. + + Raises + ------ + ValueError + If preset is not in ``self.CODEC_PRESETS``, an error + is thrown. + """ + torchaudio_version_070 = "0.7" in torchaudio.__version__ + if torchaudio_version_070: + return self + + kwargs = { + "format": format, + "encoding": encoding, + "bits_per_sample": bits_per_sample, + "compression": compression, + } + + if preset is not None: + if preset in self.CODEC_PRESETS: + kwargs = self.CODEC_PRESETS[preset] + else: + raise ValueError( + f"Unknown preset: {preset}. " + f"Known presets: {list(self.CODEC_PRESETS.keys())}" + ) + + waveform = self._to_2d() + if kwargs["format"] in ["vorbis", "mp3", "ogg", "amr-nb"]: + # Apply it in a for loop + augmented = torch.cat( + [ + torchaudio.functional.apply_codec( + waveform[i][None, :], self.sample_rate, **kwargs + ) + for i in range(waveform.shape[0]) + ], + dim=0, + ) + else: + augmented = torchaudio.functional.apply_codec( + waveform, self.sample_rate, **kwargs + ) + augmented = self._to_3d(augmented) + + self.audio_data = augmented + return self + + def mel_filterbank(self, n_bands: int): + """Breaks signal into mel bands. + + Parameters + ---------- + n_bands : int + Number of mel bands to use. + + Returns + ------- + torch.Tensor + Mel-filtered bands, with last axis being the band index. 
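+
+        Examples
+        --------
+        A minimal sketch (assumes an existing ``signal``; values are
+        illustrative). Because the bands are complementary, summing over
+        the last axis approximately reconstructs the input:
+
+        >>> bands = signal.mel_filterbank(n_bands=6)
+        >>> bands.shape[-1]
+        6
+        >>> recon = bands.sum(-1)  # close to signal.audio_data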
+ """ + filterbank = ( + julius.SplitBands(self.sample_rate, n_bands).float().to(self.device) + ) + filtered = filterbank(self.audio_data) + return filtered.permute(1, 2, 3, 0) + + def equalizer(self, db: typing.Union[torch.Tensor, np.ndarray]): + """Applies a mel-spaced equalizer to the audio signal. + + Parameters + ---------- + db : typing.Union[torch.Tensor, np.ndarray] + EQ curve to apply. + + Returns + ------- + AudioSignal + AudioSignal with equalization applied. + """ + db = util.ensure_tensor(db) + n_bands = db.shape[-1] + fbank = self.mel_filterbank(n_bands) + + # If there's a batch dimension, make sure it's the same. + if db.ndim == 2: + if db.shape[0] != 1: + assert db.shape[0] == fbank.shape[0] + else: + db = db.unsqueeze(0) + + weights = (10**db).to(self.device).float() + fbank = fbank * weights[:, None, None, :] + eq_audio_data = fbank.sum(-1) + self.audio_data = eq_audio_data + return self + + def clip_distortion( + self, clip_percentile: typing.Union[torch.Tensor, np.ndarray, float] + ): + """Clips the signal at a given percentile. The higher it is, + the lower the threshold for clipping. + + Parameters + ---------- + clip_percentile : typing.Union[torch.Tensor, np.ndarray, float] + Values are between 0.0 to 1.0. Typical values are 0.1 or below. + + Returns + ------- + AudioSignal + Audio signal with clipped audio data. + """ + clip_percentile = util.ensure_tensor(clip_percentile, ndim=1) + min_thresh = torch.quantile(self.audio_data, clip_percentile / 2, dim=-1) + max_thresh = torch.quantile(self.audio_data, 1 - (clip_percentile / 2), dim=-1) + + nc = self.audio_data.shape[1] + min_thresh = min_thresh[:, :nc, :] + max_thresh = max_thresh[:, :nc, :] + + self.audio_data = self.audio_data.clamp(min_thresh, max_thresh) + + return self + + def quantization( + self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int] + ): + """Applies quantization to the input waveform. + + Parameters + ---------- + quantization_channels : typing.Union[torch.Tensor, np.ndarray, int] + Number of evenly spaced quantization channels to quantize + to. + + Returns + ------- + AudioSignal + Quantized AudioSignal. + """ + quantization_channels = util.ensure_tensor(quantization_channels, ndim=3) + + x = self.audio_data + x = (x + 1) / 2 + x = x * quantization_channels + x = x.floor() + x = x / quantization_channels + x = 2 * x - 1 + + residual = (self.audio_data - x).detach() + self.audio_data = self.audio_data - residual + return self + + def mulaw_quantization( + self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int] + ): + """Applies mu-law quantization to the input waveform. + + Parameters + ---------- + quantization_channels : typing.Union[torch.Tensor, np.ndarray, int] + Number of mu-law spaced quantization channels to quantize + to. + + Returns + ------- + AudioSignal + Quantized AudioSignal. 
+ """ + mu = quantization_channels - 1.0 + mu = util.ensure_tensor(mu, ndim=3) + + x = self.audio_data + + # quantize + x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu) + x = ((x + 1) / 2 * mu + 0.5).to(torch.int64) + + # unquantize + x = (x / mu) * 2 - 1.0 + x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu + + residual = (self.audio_data - x).detach() + self.audio_data = self.audio_data - residual + return self + + def __matmul__(self, other): + return self.convolve(other) + + +class ImpulseResponseMixin: + """These functions are generally only used with AudioSignals that are derived + from impulse responses, not other sources like music or speech. These methods + are used to replicate the data augmentation described in [1]. + + 1. Bryan, Nicholas J. "Impulse response data augmentation and deep + neural networks for blind room acoustic parameter estimation." + ICASSP 2020-2020 IEEE International Conference on Acoustics, + Speech and Signal Processing (ICASSP). IEEE, 2020. + """ + + def decompose_ir(self): + """Decomposes an impulse response into early and late + field responses. + """ + # Equations 1 and 2 + # ----------------- + # Breaking up into early + # response + late field response. + + td = torch.argmax(self.audio_data, dim=-1, keepdim=True) + t0 = int(self.sample_rate * 0.0025) + + idx = torch.arange(self.audio_data.shape[-1], device=self.device)[None, None, :] + idx = idx.expand(self.batch_size, -1, -1) + early_idx = (idx >= td - t0) * (idx <= td + t0) + + early_response = torch.zeros_like(self.audio_data, device=self.device) + early_response[early_idx] = self.audio_data[early_idx] + + late_idx = ~early_idx + late_field = torch.zeros_like(self.audio_data, device=self.device) + late_field[late_idx] = self.audio_data[late_idx] + + # Equation 4 + # ---------- + # Decompose early response into windowed + # direct path and windowed residual. + + window = torch.zeros_like(self.audio_data, device=self.device) + for idx in range(self.batch_size): + window_idx = early_idx[idx, 0].nonzero() + window[idx, ..., window_idx] = self.get_window( + "hann", window_idx.shape[-1], self.device + ) + return early_response, late_field, window + + def measure_drr(self): + """Measures the direct-to-reverberant ratio of the impulse + response. + + Returns + ------- + float + Direct-to-reverberant ratio + """ + early_response, late_field, _ = self.decompose_ir() + num = (early_response**2).sum(dim=-1) + den = (late_field**2).sum(dim=-1) + drr = 10 * torch.log10(num / den) + return drr + + @staticmethod + def solve_alpha(early_response, late_field, wd, target_drr): + """Used to solve for the alpha value, which is used + to alter the drr. + """ + # Equation 5 + # ---------- + # Apply the good ol' quadratic formula. + + wd_sq = wd**2 + wd_sq_1 = (1 - wd) ** 2 + e_sq = early_response**2 + l_sq = late_field**2 + a = (wd_sq * e_sq).sum(dim=-1) + b = (2 * (1 - wd) * wd * e_sq).sum(dim=-1) + c = (wd_sq_1 * e_sq).sum(dim=-1) - torch.pow(10, target_drr / 10) * l_sq.sum( + dim=-1 + ) + + expr = ((b**2) - 4 * a * c).sqrt() + alpha = torch.maximum( + (-b - expr) / (2 * a), + (-b + expr) / (2 * a), + ) + return alpha + + def alter_drr(self, drr: typing.Union[torch.Tensor, np.ndarray, float]): + """Alters the direct-to-reverberant ratio of the impulse response. 
+
+        Parameters
+        ----------
+        drr : typing.Union[torch.Tensor, np.ndarray, float]
+            Direct-to-reverberant ratio that the impulse response will be
+            altered to.
+
+        Returns
+        -------
+        AudioSignal
+            Altered impulse response.
+        """
+        drr = util.ensure_tensor(drr, 2, self.batch_size).to(self.device)
+
+        early_response, late_field, window = self.decompose_ir()
+        alpha = self.solve_alpha(early_response, late_field, window, drr)
+        min_alpha = (
+            late_field.abs().max(dim=-1)[0] / early_response.abs().max(dim=-1)[0]
+        )
+        alpha = torch.maximum(alpha, min_alpha)[..., None]
+
+        aug_ir_data = (
+            alpha * window * early_response
+            + ((1 - window) * early_response)
+            + late_field
+        )
+        self.audio_data = aug_ir_data
+        self.ensure_max_of_audio()
+        return self
diff --git a/flowae/models/ldm/dac/audiotools/core/ffmpeg.py b/flowae/models/ldm/dac/audiotools/core/ffmpeg.py
new file mode 100644
index 0000000000000000000000000000000000000000..83f9cd197d7dc8748a16be77614cc593a6a33297
--- /dev/null
+++ b/flowae/models/ldm/dac/audiotools/core/ffmpeg.py
@@ -0,0 +1,211 @@
+import json
+import shlex
+import subprocess
+import tempfile
+from pathlib import Path
+from typing import Tuple
+
+import ffmpy
+import numpy as np
+import torch
+
+
+def r128stats(filepath: str, quiet: bool):
+    """Takes a path to an audio file, returns a dict with the loudness
+    stats computed by the ffmpeg ebur128 filter.
+
+    Parameters
+    ----------
+    filepath : str
+        Path to compute loudness stats on.
+    quiet : bool
+        Whether to show FFMPEG output during computation.
+
+    Returns
+    -------
+    dict
+        Dictionary containing loudness stats.
+    """
+    ffargs = [
+        "ffmpeg",
+        "-nostats",
+        "-i",
+        filepath,
+        "-filter_complex",
+        "ebur128",
+        "-f",
+        "null",
+        "-",
+    ]
+    if quiet:
+        ffargs += ["-hide_banner"]
+    proc = subprocess.Popen(ffargs, stderr=subprocess.PIPE, universal_newlines=True)
+    stats = proc.communicate()[1]
+    summary_index = stats.rfind("Summary:")
+
+    summary_list = stats[summary_index:].split()
+    i_lufs = float(summary_list[summary_list.index("I:") + 1])
+    i_thresh = float(summary_list[summary_list.index("I:") + 4])
+    lra = float(summary_list[summary_list.index("LRA:") + 1])
+    lra_thresh = float(summary_list[summary_list.index("LRA:") + 4])
+    lra_low = float(summary_list[summary_list.index("low:") + 1])
+    lra_high = float(summary_list[summary_list.index("high:") + 1])
+    stats_dict = {
+        "I": i_lufs,
+        "I Threshold": i_thresh,
+        "LRA": lra,
+        "LRA Threshold": lra_thresh,
+        "LRA Low": lra_low,
+        "LRA High": lra_high,
+    }
+
+    return stats_dict
+
+
+def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]:
+    """Given a path to a file, returns the start time offset and codec of
+    the first audio stream.
+    """
+    ff = ffmpy.FFprobe(
+        inputs={path: None},
+        global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet",
+    )
+    streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"]
+    seconds_offset = 0.0
+    codec = None
+
+    # Get the offset and codec of the first audio stream we find
+    # and return its start time, if it has one.
+    for stream in streams:
+        if stream["codec_type"] == "audio":
+            seconds_offset = stream.get("start_time", 0.0)
+            codec = stream.get("codec_name")
+            break
+    return float(seconds_offset), codec
+
+
+class FFMPEGMixin:
+    _loudness = None
+
+    def ffmpeg_loudness(self, quiet: bool = True):
+        """Computes loudness of audio file using FFMPEG.
+ + Parameters + ---------- + quiet : bool, optional + Whether to show FFMPEG output during computation, + by default True + + Returns + ------- + torch.Tensor + Loudness of every item in the batch, computed via + FFMPEG. + """ + loudness = [] + + with tempfile.NamedTemporaryFile(suffix=".wav") as f: + for i in range(self.batch_size): + self[i].write(f.name) + loudness_stats = r128stats(f.name, quiet=quiet) + loudness.append(loudness_stats["I"]) + + self._loudness = torch.from_numpy(np.array(loudness)).float() + return self.loudness() + + def ffmpeg_resample(self, sample_rate: int, quiet: bool = True): + """Resamples AudioSignal using FFMPEG. More memory-efficient + than using julius.resample for long audio files. + + Parameters + ---------- + sample_rate : int + Sample rate to resample to. + quiet : bool, optional + Whether to show FFMPEG output during computation, + by default True + + Returns + ------- + AudioSignal + Resampled AudioSignal. + """ + from audiotools import AudioSignal + + if sample_rate == self.sample_rate: + return self + + with tempfile.NamedTemporaryFile(suffix=".wav") as f: + self.write(f.name) + f_out = f.name.replace("wav", "rs.wav") + command = f"ffmpeg -i {f.name} -ar {sample_rate} {f_out}" + if quiet: + command += " -hide_banner -loglevel error" + subprocess.check_call(shlex.split(command)) + resampled = AudioSignal(f_out) + Path.unlink(Path(f_out)) + return resampled + + @classmethod + def load_from_file_with_ffmpeg(cls, audio_path: str, quiet: bool = True, **kwargs): + """Loads AudioSignal object after decoding it to a wav file using FFMPEG. + Useful for loading audio that isn't covered by librosa's loading mechanism. Also + useful for loading mp3 files, without any offset. + + Parameters + ---------- + audio_path : str + Path to load AudioSignal from. + quiet : bool, optional + Whether to show FFMPEG output during computation, + by default True + + Returns + ------- + AudioSignal + AudioSignal loaded from file with FFMPEG. + """ + audio_path = str(audio_path) + with tempfile.TemporaryDirectory() as d: + wav_file = str(Path(d) / "extracted.wav") + padded_wav = str(Path(d) / "padded.wav") + + global_options = "-y" + if quiet: + global_options += " -loglevel error" + + ff = ffmpy.FFmpeg( + inputs={audio_path: None}, + # For inputs that are m4a (and others?), the input audio can + # have samples that don't match the sample rate. This aresample + # option forces ffmpeg to read timing information in the source + # file instead of assuming constant sample rate. + # + # This fixes an issue where an input m4a file might be a + # different length than the output wav file + outputs={wav_file: "-af aresample=async=1000"}, + global_options=global_options, + ) + ff.run() + + # We pad the file using the start time offset in case it's an audio + # stream starting at some offset in a video container. + pad, codec = ffprobe_offset_and_codec(audio_path) + + # For mp3s, don't pad files with discrepancies less than 0.027s - + # it's likely due to codec latency. The amount of latency introduced + # by mp3 is 1152, which is 0.0261 44khz. So we set the threshold + # here slightly above that. + # Source: https://lame.sourceforge.io/tech-FAQ.txt. 
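+            #
+            # Worked out: 1152 samples / 44100 Hz ~= 0.0261 s, so a start
+            # offset under 0.027 s on an mp3 is treated as encoder delay
+            # rather than a real stream offset, and no padding is added.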
+ if codec == "mp3" and pad < 0.027: + pad = 0.0 + ff = ffmpy.FFmpeg( + inputs={wav_file: None}, + outputs={padded_wav: f"-af 'adelay={pad*1000}:all=true'"}, + global_options=global_options, + ) + ff.run() + + signal = cls(padded_wav, **kwargs) + + return signal diff --git a/flowae/models/ldm/dac/audiotools/core/loudness.py b/flowae/models/ldm/dac/audiotools/core/loudness.py new file mode 100644 index 0000000000000000000000000000000000000000..cb3ee2675d7cb71f4c00106b0c1e901b8e51b842 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/core/loudness.py @@ -0,0 +1,320 @@ +import copy + +import julius +import numpy as np +import scipy +import torch +import torch.nn.functional as F +import torchaudio + + +class Meter(torch.nn.Module): + """Tensorized version of pyloudnorm.Meter. Works with batched audio tensors. + + Parameters + ---------- + rate : int + Sample rate of audio. + filter_class : str, optional + Class of weighting filter used. + K-weighting' (default), 'Fenton/Lee 1' + 'Fenton/Lee 2', 'Dash et al.' + by default "K-weighting" + block_size : float, optional + Gating block size in seconds, by default 0.400 + zeros : int, optional + Number of zeros to use in FIR approximation of + IIR filters, by default 512 + use_fir : bool, optional + Whether to use FIR approximation or exact IIR formulation. + If computing on GPU, ``use_fir=True`` will be used, as its + much faster, by default False + """ + + def __init__( + self, + rate: int, + filter_class: str = "K-weighting", + block_size: float = 0.400, + zeros: int = 512, + use_fir: bool = False, + ): + super().__init__() + + self.rate = rate + self.filter_class = filter_class + self.block_size = block_size + self.use_fir = use_fir + + G = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.41, 1.41])) + self.register_buffer("G", G) + + # Compute impulse responses so that filtering is fast via + # a convolution at runtime, on GPU, unlike lfilter. + impulse = np.zeros((zeros,)) + impulse[..., 0] = 1.0 + + firs = np.zeros((len(self._filters), 1, zeros)) + passband_gain = torch.zeros(len(self._filters)) + + for i, (_, filter_stage) in enumerate(self._filters.items()): + firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a, impulse) + passband_gain[i] = filter_stage.passband_gain + + firs = torch.from_numpy(firs[..., ::-1].copy()).float() + + self.register_buffer("firs", firs) + self.register_buffer("passband_gain", passband_gain) + + def apply_filter_gpu(self, data: torch.Tensor): + """Performs FIR approximation of loudness computation. + + Parameters + ---------- + data : torch.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + torch.Tensor + Filtered audio data. + """ + # Data is of shape (nb, nch, nt) + # Reshape to (nb*nch, 1, nt) + nb, nt, nch = data.shape + data = data.permute(0, 2, 1) + data = data.reshape(nb * nch, 1, nt) + + # Apply padding + pad_length = self.firs.shape[-1] + + # Apply filtering in sequence + for i in range(self.firs.shape[0]): + data = F.pad(data, (pad_length, pad_length)) + data = julius.fftconv.fft_conv1d(data, self.firs[i, None, ...]) + data = self.passband_gain[i] * data + data = data[..., 1 : nt + 1] + + data = data.permute(0, 2, 1) + data = data[:, :nt, :] + return data + + def apply_filter_cpu(self, data: torch.Tensor): + """Performs IIR formulation of loudness computation. + + Parameters + ---------- + data : torch.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + torch.Tensor + Filtered audio data. 
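+
+        Examples
+        --------
+        A minimal sketch (assumes ``pyloudnorm`` is installed; note the
+        (nb, nt, nch) layout that ``integrated_loudness`` passes in):
+
+        >>> meter = Meter(44100)
+        >>> out = meter.apply_filter_cpu(torch.randn(1, 44100, 2))
+        >>> out.shape
+        torch.Size([1, 44100, 2])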
+ """ + for _, filter_stage in self._filters.items(): + passband_gain = filter_stage.passband_gain + + a_coeffs = torch.from_numpy(filter_stage.a).float().to(data.device) + b_coeffs = torch.from_numpy(filter_stage.b).float().to(data.device) + + _data = data.permute(0, 2, 1) + filtered = torchaudio.functional.lfilter( + _data, a_coeffs, b_coeffs, clamp=False + ) + data = passband_gain * filtered.permute(0, 2, 1) + return data + + def apply_filter(self, data: torch.Tensor): + """Applies filter on either CPU or GPU, depending + on if the audio is on GPU or is on CPU, or if + ``self.use_fir`` is True. + + Parameters + ---------- + data : torch.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + torch.Tensor + Filtered audio data. + """ + if data.is_cuda or self.use_fir: + data = self.apply_filter_gpu(data) + else: + data = self.apply_filter_cpu(data) + return data + + def forward(self, data: torch.Tensor): + """Computes integrated loudness of data. + + Parameters + ---------- + data : torch.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + torch.Tensor + Filtered audio data. + """ + return self.integrated_loudness(data) + + def _unfold(self, input_data): + T_g = self.block_size + overlap = 0.75 # overlap of 75% of the block duration + step = 1.0 - overlap # step size by percentage + + kernel_size = int(T_g * self.rate) + stride = int(T_g * self.rate * step) + unfolded = julius.core.unfold(input_data.permute(0, 2, 1), kernel_size, stride) + unfolded = unfolded.transpose(-1, -2) + + return unfolded + + def integrated_loudness(self, data: torch.Tensor): + """Computes integrated loudness of data. + + Parameters + ---------- + data : torch.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + torch.Tensor + Filtered audio data. + """ + if not torch.is_tensor(data): + data = torch.from_numpy(data).float() + else: + data = data.float() + + input_data = copy.copy(data) + # Data always has a batch and channel dimension. + # Is of shape (nb, nt, nch) + if input_data.ndim < 2: + input_data = input_data.unsqueeze(-1) + if input_data.ndim < 3: + input_data = input_data.unsqueeze(0) + + nb, nt, nch = input_data.shape + + # Apply frequency weighting filters - account + # for the acoustic respose of the head and auditory system + input_data = self.apply_filter(input_data) + + G = self.G # channel gains + T_g = self.block_size # 400 ms gating block standard + Gamma_a = -70.0 # -70 LKFS = absolute loudness threshold + + unfolded = self._unfold(input_data) + + z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2) + l = -0.691 + 10.0 * torch.log10((G[None, :nch, None] * z).sum(1, keepdim=True)) + l = l.expand_as(z) + + # find gating block indices above absolute threshold + z_avg_gated = z + z_avg_gated[l <= Gamma_a] = 0 + masked = l > Gamma_a + z_avg_gated = z_avg_gated.sum(2) / masked.sum(2) + + # calculate the relative threshold value (see eq. 6) + Gamma_r = ( + -0.691 + 10.0 * torch.log10((z_avg_gated * G[None, :nch]).sum(-1)) - 10.0 + ) + Gamma_r = Gamma_r[:, None, None] + Gamma_r = Gamma_r.expand(nb, nch, l.shape[-1]) + + # find gating block indices above relative and absolute thresholds (end of eq. 
7) + z_avg_gated = z + z_avg_gated[l <= Gamma_a] = 0 + z_avg_gated[l <= Gamma_r] = 0 + masked = (l > Gamma_a) * (l > Gamma_r) + z_avg_gated = z_avg_gated.sum(2) / masked.sum(2) + + # # Cannot use nan_to_num (pytorch 1.8 does not come with GCP-supported cuda version) + # z_avg_gated = torch.nan_to_num(z_avg_gated) + z_avg_gated = torch.where( + z_avg_gated.isnan(), torch.zeros_like(z_avg_gated), z_avg_gated + ) + z_avg_gated[z_avg_gated == float("inf")] = float(np.finfo(np.float32).max) + z_avg_gated[z_avg_gated == -float("inf")] = float(np.finfo(np.float32).min) + + LUFS = -0.691 + 10.0 * torch.log10((G[None, :nch] * z_avg_gated).sum(1)) + return LUFS.float() + + @property + def filter_class(self): + return self._filter_class + + @filter_class.setter + def filter_class(self, value): + from pyloudnorm import Meter + + meter = Meter(self.rate) + meter.filter_class = value + self._filter_class = value + self._filters = meter._filters + + +class LoudnessMixin: + _loudness = None + MIN_LOUDNESS = -70 + """Minimum loudness possible.""" + + def loudness( + self, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs + ): + """Calculates loudness using an implementation of ITU-R BS.1770-4. + Allows control over gating block size and frequency weighting filters for + additional control. Measure the integrated gated loudness of a signal. + + API is derived from PyLoudnorm, but this implementation is ported to PyTorch + and is tensorized across batches. When on GPU, an FIR approximation of the IIR + filters is used to compute loudness for speed. + + Uses the weighting filters and block size defined by the meter + the integrated loudness is measured based upon the gating algorithm + defined in the ITU-R BS.1770-4 specification. + + Parameters + ---------- + filter_class : str, optional + Class of weighting filter used. + K-weighting' (default), 'Fenton/Lee 1' + 'Fenton/Lee 2', 'Dash et al.' + by default "K-weighting" + block_size : float, optional + Gating block size in seconds, by default 0.400 + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.loudness.Meter`. + + Returns + ------- + torch.Tensor + Loudness of audio data. + """ + if self._loudness is not None: + return self._loudness.to(self.device) + original_length = self.signal_length + if self.signal_duration < 0.5: + pad_len = int((0.5 - self.signal_duration) * self.sample_rate) + self.zero_pad(0, pad_len) + + # create BS.1770 meter + meter = Meter( + self.sample_rate, filter_class=filter_class, block_size=block_size, **kwargs + ) + meter = meter.to(self.device) + # measure loudness + loudness = meter.integrated_loudness(self.audio_data.permute(0, 2, 1)) + self.truncate_samples(original_length) + min_loudness = ( + torch.ones_like(loudness, device=loudness.device) * self.MIN_LOUDNESS + ) + self._loudness = torch.maximum(loudness, min_loudness) + + return self._loudness.to(self.device) diff --git a/flowae/models/ldm/dac/audiotools/core/playback.py b/flowae/models/ldm/dac/audiotools/core/playback.py new file mode 100644 index 0000000000000000000000000000000000000000..5d0f21aaa392494f35305c0084c05b87667ea14d --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/core/playback.py @@ -0,0 +1,252 @@ +""" +These are utilities that allow one to embed an AudioSignal +as a playable object in a Jupyter notebook, or to play audio from +the terminal, etc. 
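+
+For example (illustrative; ``example.wav`` is a placeholder path):
+
+    >>> from audiotools import AudioSignal
+    >>> signal = AudioSignal("example.wav")
+    >>> signal.embed()   # playable element in a notebook
+    >>> signal.play()    # blocking playback via ffplay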
+""" # fmt: skip +import base64 +import io +import random +import string +import subprocess +from tempfile import NamedTemporaryFile + +import importlib_resources as pkg_resources + +from . import templates +from .util import _close_temp_files +from .util import format_figure + +headers = pkg_resources.files(templates).joinpath("headers.html").read_text() +widget = pkg_resources.files(templates).joinpath("widget.html").read_text() + +DEFAULT_EXTENSION = ".wav" + + +def _check_imports(): # pragma: no cover + try: + import ffmpy + except: + ffmpy = False + + try: + import IPython + except: + raise ImportError("IPython must be installed in order to use this function!") + return ffmpy, IPython + + +class PlayMixin: + def embed(self, ext: str = None, display: bool = True, return_html: bool = False): + """Embeds audio as a playable audio embed in a notebook, or HTML + document, etc. + + Parameters + ---------- + ext : str, optional + Extension to use when saving the audio, by default ".wav" + display : bool, optional + This controls whether or not to display the audio when called. This + is used when the embed is the last line in a Jupyter cell, to prevent + the audio from being embedded twice, by default True + return_html : bool, optional + Whether to return the data wrapped in an HTML audio element, by default False + + Returns + ------- + str + Either the element for display, or the HTML string of it. + """ + if ext is None: + ext = DEFAULT_EXTENSION + ext = f".{ext}" if not ext.startswith(".") else ext + ffmpy, IPython = _check_imports() + sr = self.sample_rate + tmpfiles = [] + + with _close_temp_files(tmpfiles): + tmp_wav = NamedTemporaryFile(mode="w+", suffix=".wav", delete=False) + tmpfiles.append(tmp_wav) + self.write(tmp_wav.name) + if ext != ".wav" and ffmpy: + tmp_converted = NamedTemporaryFile(mode="w+", suffix=ext, delete=False) + tmpfiles.append(tmp_wav) + ff = ffmpy.FFmpeg( + inputs={tmp_wav.name: None}, + outputs={ + tmp_converted.name: "-write_xing 0 -codec:a libmp3lame -b:a 128k -y -hide_banner -loglevel error" + }, + ) + ff.run() + else: + tmp_converted = tmp_wav + + audio_element = IPython.display.Audio(data=tmp_converted.name, rate=sr) + if display: + IPython.display.display(audio_element) + + if return_html: + audio_element = ( + f" " + ) + return audio_element + + def widget( + self, + title: str = None, + ext: str = ".wav", + add_headers: bool = True, + player_width: str = "100%", + margin: str = "10px", + plot_fn: str = "specshow", + return_html: bool = False, + **kwargs, + ): + """Creates a playable widget with spectrogram. Inspired (heavily) by + https://sjvasquez.github.io/blog/melnet/. + + Parameters + ---------- + title : str, optional + Title of plot, placed in upper right of top-most axis. + ext : str, optional + Extension for embedding, by default ".mp3" + add_headers : bool, optional + Whether or not to add headers (use for first embed, False for later embeds), by default True + player_width : str, optional + Width of the player, as a string in a CSS rule, by default "100%" + margin : str, optional + Margin on all sides of player, by default "10px" + plot_fn : function, optional + Plotting function to use (by default self.specshow). + return_html : bool, optional + Whether to return the data wrapped in an HTML audio element, by default False + kwargs : dict, optional + Keyword arguments to plot_fn (by default self.specshow). + + Returns + ------- + HTML + HTML object. 
+ """ + import matplotlib.pyplot as plt + + def _save_fig_to_tag(): + buffer = io.BytesIO() + + plt.savefig(buffer, bbox_inches="tight", pad_inches=0) + plt.close() + + buffer.seek(0) + data_uri = base64.b64encode(buffer.read()).decode("ascii") + tag = "data:image/png;base64,{0}".format(data_uri) + + return tag + + _, IPython = _check_imports() + + header_html = "" + + if add_headers: + header_html = headers.replace("PLAYER_WIDTH", str(player_width)) + header_html = header_html.replace("MARGIN", str(margin)) + IPython.display.display(IPython.display.HTML(header_html)) + + widget_html = widget + if isinstance(plot_fn, str): + plot_fn = getattr(self, plot_fn) + kwargs["title"] = title + plot_fn(**kwargs) + + fig = plt.gcf() + pixels = fig.get_size_inches() * fig.dpi + + tag = _save_fig_to_tag() + + # Make the source image for the levels + self.specshow() + format_figure((12, 1.5)) + levels_tag = _save_fig_to_tag() + + player_id = "".join(random.choice(string.ascii_uppercase) for _ in range(10)) + + audio_elem = self.embed(ext=ext, display=False) + widget_html = widget_html.replace("AUDIO_SRC", audio_elem.src_attr()) + widget_html = widget_html.replace("IMAGE_SRC", tag) + widget_html = widget_html.replace("LEVELS_SRC", levels_tag) + widget_html = widget_html.replace("PLAYER_ID", player_id) + + # Calculate width/height of figure based on figure size. + widget_html = widget_html.replace("PADDING_AMOUNT", f"{int(pixels[1])}px") + widget_html = widget_html.replace("MAX_WIDTH", f"{int(pixels[0])}px") + + IPython.display.display(IPython.display.HTML(widget_html)) + + if return_html: + html = header_html if add_headers else "" + html += widget_html + return html + + def play(self): + """ + Plays an audio signal if ffplay from the ffmpeg suite of tools is installed. + Otherwise, will fail. The audio signal is written to a temporary file + and then played with ffplay. 
+ """ + tmpfiles = [] + with _close_temp_files(tmpfiles): + tmp_wav = NamedTemporaryFile(suffix=".wav", delete=False) + tmpfiles.append(tmp_wav) + self.write(tmp_wav.name) + print(self) + subprocess.call( + [ + "ffplay", + "-nodisp", + "-autoexit", + "-hide_banner", + "-loglevel", + "error", + tmp_wav.name, + ] + ) + return self + + +if __name__ == "__main__": # pragma: no cover + from audiotools import AudioSignal + + signal = AudioSignal( + "tests/audio/spk/f10_script4_produced.mp3", offset=5, duration=5 + ) + + wave_html = signal.widget( + "Waveform", + plot_fn="waveplot", + return_html=True, + ) + + spec_html = signal.widget("Spectrogram", return_html=True, add_headers=False) + + combined_html = signal.widget( + "Waveform + spectrogram", + plot_fn="wavespec", + return_html=True, + add_headers=False, + ) + + signal.low_pass(8000) + lowpass_html = signal.widget( + "Lowpassed audio", + plot_fn="wavespec", + return_html=True, + add_headers=False, + ) + + with open("/tmp/index.html", "w") as f: + f.write(wave_html) + f.write(spec_html) + f.write(combined_html) + f.write(lowpass_html) diff --git a/flowae/models/ldm/dac/audiotools/core/templates/__init__.py b/flowae/models/ldm/dac/audiotools/core/templates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flowae/models/ldm/dac/audiotools/core/templates/headers.html b/flowae/models/ldm/dac/audiotools/core/templates/headers.html new file mode 100644 index 0000000000000000000000000000000000000000..9eaef4a94d575f7826608ad63dcc77fab13b7b19 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/core/templates/headers.html @@ -0,0 +1,322 @@ + + + + + + diff --git a/flowae/models/ldm/dac/audiotools/core/templates/pandoc.css b/flowae/models/ldm/dac/audiotools/core/templates/pandoc.css new file mode 100644 index 0000000000000000000000000000000000000000..842be7be6d65580dab44c6a8013259644f38e6ee --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/core/templates/pandoc.css @@ -0,0 +1,407 @@ +/* +Copyright (c) 2017 Chris Patuzzo +https://twitter.com/chrispatuzzo + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+*/ + +body { + font-family: Helvetica, arial, sans-serif; + font-size: 14px; + line-height: 1.6; + padding-top: 10px; + padding-bottom: 10px; + background-color: white; + padding: 30px; + color: #333; +} + +body > *:first-child { + margin-top: 0 !important; +} + +body > *:last-child { + margin-bottom: 0 !important; +} + +a { + color: #4183C4; + text-decoration: none; +} + +a.absent { + color: #cc0000; +} + +a.anchor { + display: block; + padding-left: 30px; + margin-left: -30px; + cursor: pointer; + position: absolute; + top: 0; + left: 0; + bottom: 0; +} + +h1, h2, h3, h4, h5, h6 { + margin: 20px 0 10px; + padding: 0; + font-weight: bold; + -webkit-font-smoothing: antialiased; + cursor: text; + position: relative; +} + +h2:first-child, h1:first-child, h1:first-child + h2, h3:first-child, h4:first-child, h5:first-child, h6:first-child { + margin-top: 0; + padding-top: 0; +} + +h1:hover a.anchor, h2:hover a.anchor, h3:hover a.anchor, h4:hover a.anchor, h5:hover a.anchor, h6:hover a.anchor { + text-decoration: none; +} + +h1 tt, h1 code { + font-size: inherit; +} + +h2 tt, h2 code { + font-size: inherit; +} + +h3 tt, h3 code { + font-size: inherit; +} + +h4 tt, h4 code { + font-size: inherit; +} + +h5 tt, h5 code { + font-size: inherit; +} + +h6 tt, h6 code { + font-size: inherit; +} + +h1 { + font-size: 28px; + color: black; +} + +h2 { + font-size: 24px; + border-bottom: 1px solid #cccccc; + color: black; +} + +h3 { + font-size: 18px; +} + +h4 { + font-size: 16px; +} + +h5 { + font-size: 14px; +} + +h6 { + color: #777777; + font-size: 14px; +} + +p, blockquote, ul, ol, dl, li, table, pre { + margin: 15px 0; +} + +hr { + border: 0 none; + color: #cccccc; + height: 4px; + padding: 0; +} + +body > h2:first-child { + margin-top: 0; + padding-top: 0; +} + +body > h1:first-child { + margin-top: 0; + padding-top: 0; +} + +body > h1:first-child + h2 { + margin-top: 0; + padding-top: 0; +} + +body > h3:first-child, body > h4:first-child, body > h5:first-child, body > h6:first-child { + margin-top: 0; + padding-top: 0; +} + +a:first-child h1, a:first-child h2, a:first-child h3, a:first-child h4, a:first-child h5, a:first-child h6 { + margin-top: 0; + padding-top: 0; +} + +h1 p, h2 p, h3 p, h4 p, h5 p, h6 p { + margin-top: 0; +} + +li p.first { + display: inline-block; +} + +ul, ol { + padding-left: 30px; +} + +ul :first-child, ol :first-child { + margin-top: 0; +} + +ul :last-child, ol :last-child { + margin-bottom: 0; +} + +dl { + padding: 0; +} + +dl dt { + font-size: 14px; + font-weight: bold; + font-style: italic; + padding: 0; + margin: 15px 0 5px; +} + +dl dt:first-child { + padding: 0; +} + +dl dt > :first-child { + margin-top: 0; +} + +dl dt > :last-child { + margin-bottom: 0; +} + +dl dd { + margin: 0 0 15px; + padding: 0 15px; +} + +dl dd > :first-child { + margin-top: 0; +} + +dl dd > :last-child { + margin-bottom: 0; +} + +blockquote { + border-left: 4px solid #dddddd; + padding: 0 15px; + color: #777777; +} + +blockquote > :first-child { + margin-top: 0; +} + +blockquote > :last-child { + margin-bottom: 0; +} + +table { + padding: 0; +} +table tr { + border-top: 1px solid #cccccc; + background-color: white; + margin: 0; + padding: 0; +} + +table tr:nth-child(2n) { + background-color: #f8f8f8; +} + +table tr th { + font-weight: bold; + border: 1px solid #cccccc; + text-align: left; + margin: 0; + padding: 6px 13px; +} + +table tr td { + border: 1px solid #cccccc; + text-align: left; + margin: 0; + padding: 6px 13px; +} + +table tr th :first-child, table tr td :first-child { + margin-top: 
0; +} + +table tr th :last-child, table tr td :last-child { + margin-bottom: 0; +} + +img { + max-width: 100%; +} + +span.frame { + display: block; + overflow: hidden; +} + +span.frame > span { + border: 1px solid #dddddd; + display: block; + float: left; + overflow: hidden; + margin: 13px 0 0; + padding: 7px; + width: auto; +} + +span.frame span img { + display: block; + float: left; +} + +span.frame span span { + clear: both; + color: #333333; + display: block; + padding: 5px 0 0; +} + +span.align-center { + display: block; + overflow: hidden; + clear: both; +} + +span.align-center > span { + display: block; + overflow: hidden; + margin: 13px auto 0; + text-align: center; +} + +span.align-center span img { + margin: 0 auto; + text-align: center; +} + +span.align-right { + display: block; + overflow: hidden; + clear: both; +} + +span.align-right > span { + display: block; + overflow: hidden; + margin: 13px 0 0; + text-align: right; +} + +span.align-right span img { + margin: 0; + text-align: right; +} + +span.float-left { + display: block; + margin-right: 13px; + overflow: hidden; + float: left; +} + +span.float-left span { + margin: 13px 0 0; +} + +span.float-right { + display: block; + margin-left: 13px; + overflow: hidden; + float: right; +} + +span.float-right > span { + display: block; + overflow: hidden; + margin: 13px auto 0; + text-align: right; +} + +code, tt { + margin: 0 2px; + padding: 0 5px; + white-space: nowrap; + border-radius: 3px; +} + +pre code { + margin: 0; + padding: 0; + white-space: pre; + border: none; + background: transparent; +} + +.highlight pre { + font-size: 13px; + line-height: 19px; + overflow: auto; + padding: 6px 10px; + border-radius: 3px; +} + +pre { + font-size: 13px; + line-height: 19px; + overflow: auto; + padding: 6px 10px; + border-radius: 3px; +} + +pre code, pre tt { + background-color: transparent; + border: none; +} + +body { + max-width: 600px; +} diff --git a/flowae/models/ldm/dac/audiotools/core/templates/widget.html b/flowae/models/ldm/dac/audiotools/core/templates/widget.html new file mode 100644 index 0000000000000000000000000000000000000000..0b44e8aec64fd1db929da5fa6208dee00247c967 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/core/templates/widget.html @@ -0,0 +1,52 @@ +
+
+
+
+ +
+
+ +
+ + + +
+ +
+ + +
+
+ + diff --git a/flowae/models/ldm/dac/audiotools/core/util.py b/flowae/models/ldm/dac/audiotools/core/util.py new file mode 100644 index 0000000000000000000000000000000000000000..ece1344658d10836aa2eb693f275294ad8cdbb52 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/core/util.py @@ -0,0 +1,671 @@ +import csv +import glob +import math +import numbers +import os +import random +import typing +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Dict +from typing import List + +import numpy as np +import torch +import torchaudio +from flatten_dict import flatten +from flatten_dict import unflatten + + +@dataclass +class Info: + """Shim for torchaudio.info API changes.""" + + sample_rate: float + num_frames: int + + @property + def duration(self) -> float: + return self.num_frames / self.sample_rate + + +def info(audio_path: str): + """Shim for torchaudio.info to make 0.7.2 API match 0.8.0. + + Parameters + ---------- + audio_path : str + Path to audio file. + """ + # try default backend first, then fallback to soundfile + try: + info = torchaudio.info(str(audio_path)) + except: # pragma: no cover + info = torchaudio.backend.soundfile_backend.info(str(audio_path)) + + if isinstance(info, tuple): # pragma: no cover + signal_info = info[0] + info = Info(sample_rate=signal_info.rate, num_frames=signal_info.length) + else: + info = Info(sample_rate=info.sample_rate, num_frames=info.num_frames) + + return info + + +def ensure_tensor( + x: typing.Union[np.ndarray, torch.Tensor, float, int], + ndim: int = None, + batch_size: int = None, +): + """Ensures that the input ``x`` is a tensor of specified + dimensions and batch size. + + Parameters + ---------- + x : typing.Union[np.ndarray, torch.Tensor, float, int] + Data that will become a tensor on its way out. + ndim : int, optional + How many dimensions should be in the output, by default None + batch_size : int, optional + The batch size of the output, by default None + + Returns + ------- + torch.Tensor + Modified version of ``x`` as a tensor. + """ + if not torch.is_tensor(x): + x = torch.as_tensor(x) + if ndim is not None: + assert x.ndim <= ndim + while x.ndim < ndim: + x = x.unsqueeze(-1) + if batch_size is not None: + if x.shape[0] != batch_size: + shape = list(x.shape) + shape[0] = batch_size + x = x.expand(*shape) + return x + + +def _get_value(other): + from . import AudioSignal + + if isinstance(other, AudioSignal): + return other.audio_data + return other + + +def hz_to_bin(hz: torch.Tensor, n_fft: int, sample_rate: int): + """Closest frequency bin given a frequency, number + of bins, and a sampling rate. + + Parameters + ---------- + hz : torch.Tensor + Tensor of frequencies in Hz. + n_fft : int + Number of FFT bins. + sample_rate : int + Sample rate of audio. + + Returns + ------- + torch.Tensor + Closest bins to the data. + """ + shape = hz.shape + hz = hz.flatten() + freqs = torch.linspace(0, sample_rate / 2, 2 + n_fft // 2) + hz[hz > sample_rate / 2] = sample_rate / 2 + + closest = (hz[None, :] - freqs[:, None]).abs() + closest_bins = closest.min(dim=0).indices + + return closest_bins.reshape(*shape) + + +def random_state(seed: typing.Union[int, np.random.RandomState]): + """ + Turn seed into a np.random.RandomState instance. + + Parameters + ---------- + seed : typing.Union[int, np.random.RandomState] or None + If seed is None, return the RandomState singleton used by np.random. + If seed is an int, return a new RandomState instance seeded with seed. 
+ If seed is already a RandomState instance, return it. + Otherwise raise ValueError. + + Returns + ------- + np.random.RandomState + Random state object. + + Raises + ------ + ValueError + If seed is not valid, an error is thrown. + """ + if seed is None or seed is np.random: + return np.random.mtrand._rand + elif isinstance(seed, (numbers.Integral, np.integer, int)): + return np.random.RandomState(seed) + elif isinstance(seed, np.random.RandomState): + return seed + else: + raise ValueError( + "%r cannot be used to seed a numpy.random.RandomState" " instance" % seed + ) + + +def seed(random_seed, set_cudnn=False): + """ + Seeds all random states with the same random seed + for reproducibility. Seeds ``numpy``, ``random`` and ``torch`` + random generators. + For full reproducibility, two further options must be set + according to the torch documentation: + https://pytorch.org/docs/stable/notes/randomness.html + To do this, ``set_cudnn`` must be True. It defaults to + False, since setting it to True results in a performance + hit. + + Args: + random_seed (int): integer corresponding to random seed to + use. + set_cudnn (bool): Whether or not to set cudnn into determinstic + mode and off of benchmark mode. Defaults to False. + """ + + torch.manual_seed(random_seed) + np.random.seed(random_seed) + random.seed(random_seed) + + if set_cudnn: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +@contextmanager +def _close_temp_files(tmpfiles: list): + """Utility function for creating a context and closing all temporary files + once the context is exited. For correct functionality, all temporary file + handles created inside the context must be appended to the ```tmpfiles``` + list. + + This function is taken wholesale from Scaper. + + Parameters + ---------- + tmpfiles : list + List of temporary file handles + """ + + def _close(): + for t in tmpfiles: + try: + t.close() + os.unlink(t.name) + except: + pass + + try: + yield + except: # pragma: no cover + _close() + raise + _close() + + +AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"] + + +def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS): + """Finds all audio files in a directory recursively. + Returns a list. + + Parameters + ---------- + folder : str + Folder to look for audio files in, recursively. + ext : List[str], optional + Extensions to look for without the ., by default + ``['.wav', '.flac', '.mp3', '.mp4']``. + """ + folder = Path(folder) + # Take care of case where user has passed in an audio file directly + # into one of the calling functions. + if str(folder).endswith(tuple(ext)): + # if, however, there's a glob in the path, we need to + # return the glob, not the file. + if "*" in str(folder): + return glob.glob(str(folder), recursive=("**" in str(folder))) + else: + return [folder] + + files = [] + for x in ext: + files += folder.glob(f"**/*{x}") + return files + + +def read_sources( + sources: List[str], + remove_empty: bool = True, + relative_path: str = "", + ext: List[str] = AUDIO_EXTENSIONS, +): + """Reads audio sources that can either be folders + full of audio files, or CSV files that contain paths + to audio files. CSV files that adhere to the expected + format can be generated by + :py:func:`audiotools.data.preprocess.create_csv`. + + Parameters + ---------- + sources : List[str] + List of audio sources to be converted into a + list of lists of audio files. 
+ remove_empty : bool, optional + Whether or not to remove rows with an empty "path" + from each CSV file, by default True. + + Returns + ------- + list + List of lists of rows of CSV files. + """ + files = [] + relative_path = Path(relative_path) + for source in sources: + source = str(source) + _files = [] + if source.endswith(".csv"): + with open(source, "r") as f: + reader = csv.DictReader(f) + for x in reader: + if remove_empty and x["path"] == "": + continue + if x["path"] != "": + x["path"] = str(relative_path / x["path"]) + _files.append(x) + else: + for x in find_audio(source, ext=ext): + x = str(relative_path / x) + _files.append({"path": x}) + files.append(sorted(_files, key=lambda x: x["path"])) + return files + + +def choose_from_list_of_lists( + state: np.random.RandomState, list_of_lists: list, p: float = None +): + """Choose a single item from a list of lists. + + Parameters + ---------- + state : np.random.RandomState + Random state to use when choosing an item. + list_of_lists : list + A list of lists from which items will be drawn. + p : float, optional + Probabilities of each list, by default None + + Returns + ------- + typing.Any + An item from the list of lists. + """ + source_idx = state.choice(list(range(len(list_of_lists))), p=p) + item_idx = state.randint(len(list_of_lists[source_idx])) + return list_of_lists[source_idx][item_idx], source_idx, item_idx + + +@contextmanager +def chdir(newdir: typing.Union[Path, str]): + """ + Context manager for switching directories to run a + function. Useful for when you want to use relative + paths to different runs. + + Parameters + ---------- + newdir : typing.Union[Path, str] + Directory to switch to. + """ + curdir = os.getcwd() + try: + os.chdir(newdir) + yield + finally: + os.chdir(curdir) + + +def prepare_batch(batch: typing.Union[dict, list, torch.Tensor], device: str = "cpu"): + """Moves items in a batch (typically generated by a DataLoader as a list + or a dict) to the specified device. This works even if dictionaries + are nested. + + Parameters + ---------- + batch : typing.Union[dict, list, torch.Tensor] + Batch, typically generated by a dataloader, that will be moved to + the device. + device : str, optional + Device to move batch to, by default "cpu" + + Returns + ------- + typing.Union[dict, list, torch.Tensor] + Batch with all values moved to the specified device. + """ + if isinstance(batch, dict): + batch = flatten(batch) + for key, val in batch.items(): + try: + batch[key] = val.to(device) + except: + pass + batch = unflatten(batch) + elif torch.is_tensor(batch): + batch = batch.to(device) + elif isinstance(batch, list): + for i in range(len(batch)): + try: + batch[i] = batch[i].to(device) + except: + pass + return batch + + +def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None): + """Samples from a distribution defined by a tuple. The first + item in the tuple is the distribution type, and the rest of the + items are arguments to that distribution. The distribution function + is gotten from the ``np.random.RandomState`` object. + + Parameters + ---------- + dist_tuple : tuple + Distribution tuple + state : np.random.RandomState, optional + Random state, or seed to use, by default None + + Returns + ------- + typing.Union[float, int, str] + Draw from the distribution. 
+ + Examples + -------- + Sample from a uniform distribution: + + >>> dist_tuple = ("uniform", 0, 1) + >>> sample_from_dist(dist_tuple) + + Sample from a constant distribution: + + >>> dist_tuple = ("const", 0) + >>> sample_from_dist(dist_tuple) + + Sample from a normal distribution: + + >>> dist_tuple = ("normal", 0, 0.5) + >>> sample_from_dist(dist_tuple) + + """ + if dist_tuple[0] == "const": + return dist_tuple[1] + state = random_state(state) + dist_fn = getattr(state, dist_tuple[0]) + return dist_fn(*dist_tuple[1:]) + + +def collate(list_of_dicts: list, n_splits: int = None): + """Collates a list of dictionaries (e.g. as returned by a + dataloader) into a dictionary with batched values. This routine + uses the default torch collate function for everything + except AudioSignal objects, which are handled by the + :py:func:`audiotools.core.audio_signal.AudioSignal.batch` + function. + + This function takes n_splits to enable splitting a batch + into multiple sub-batches for the purposes of gradient accumulation, + etc. + + Parameters + ---------- + list_of_dicts : list + List of dictionaries to be collated. + n_splits : int + Number of splits to make when creating the batches (split into + sub-batches). Useful for things like gradient accumulation. + + Returns + ------- + dict + Dictionary containing batched data. + """ + + from . import AudioSignal + + batches = [] + list_len = len(list_of_dicts) + + return_list = False if n_splits is None else True + n_splits = 1 if n_splits is None else n_splits + n_items = int(math.ceil(list_len / n_splits)) + + for i in range(0, list_len, n_items): + # Flatten the dictionaries to avoid recursion. + list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]] + dict_of_lists = { + k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0] + } + + batch = {} + for k, v in dict_of_lists.items(): + if isinstance(v, list): + if all(isinstance(s, AudioSignal) for s in v): + batch[k] = AudioSignal.batch(v, pad_signals=True) + else: + # Borrow the default collate fn from torch. + batch[k] = torch.utils.data._utils.collate.default_collate(v) + batches.append(unflatten(batch)) + + batches = batches[0] if not return_list else batches + return batches + + +BASE_SIZE = 864 +DEFAULT_FIG_SIZE = (9, 3) + + +def format_figure( + fig_size: tuple = None, + title: str = None, + fig=None, + format_axes: bool = True, + format: bool = True, + font_color: str = "white", +): + """Prettifies the spectrogram and waveform plots. A title + can be inset into the top right corner, and the axes can be + inset into the figure, allowing the data to take up the entire + image. 
Used in + + - :py:func:`audiotools.core.display.DisplayMixin.specshow` + - :py:func:`audiotools.core.display.DisplayMixin.waveplot` + - :py:func:`audiotools.core.display.DisplayMixin.wavespec` + + Parameters + ---------- + fig_size : tuple, optional + Size of figure, by default (9, 3) + title : str, optional + Title to inset in top right, by default None + fig : matplotlib.figure.Figure, optional + Figure object, if None ``plt.gcf()`` will be used, by default None + format_axes : bool, optional + Format the axes to be inside the figure, by default True + format : bool, optional + This formatting can be skipped entirely by passing ``format=False`` + to any of the plotting functions that use this formater, by default True + font_color : str, optional + Color of font of axes, by default "white" + """ + import matplotlib + import matplotlib.pyplot as plt + + if fig_size is None: + fig_size = DEFAULT_FIG_SIZE + if not format: + return + if fig is None: + fig = plt.gcf() + fig.set_size_inches(*fig_size) + axs = fig.axes + + pixels = (fig.get_size_inches() * fig.dpi)[0] + font_scale = pixels / BASE_SIZE + + if format_axes: + axs = fig.axes + + for ax in axs: + ymin, _ = ax.get_ylim() + xmin, _ = ax.get_xlim() + + ticks = ax.get_yticks() + for t in ticks[2:-1]: + t = axs[0].annotate( + f"{(t / 1000):2.1f}k", + xy=(xmin, t), + xycoords="data", + xytext=(5, -5), + textcoords="offset points", + ha="left", + va="top", + color=font_color, + fontsize=12 * font_scale, + alpha=0.75, + ) + + ticks = ax.get_xticks()[2:] + for t in ticks[:-1]: + t = axs[0].annotate( + f"{t:2.1f}s", + xy=(t, ymin), + xycoords="data", + xytext=(5, 5), + textcoords="offset points", + ha="center", + va="bottom", + color=font_color, + fontsize=12 * font_scale, + alpha=0.75, + ) + + ax.margins(0, 0) + ax.set_axis_off() + ax.xaxis.set_major_locator(plt.NullLocator()) + ax.yaxis.set_major_locator(plt.NullLocator()) + + plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) + + if title is not None: + t = axs[0].annotate( + title, + xy=(1, 1), + xycoords="axes fraction", + fontsize=20 * font_scale, + xytext=(-5, -5), + textcoords="offset points", + ha="right", + va="top", + color="white", + ) + t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black")) + + +def generate_chord_dataset( + max_voices: int = 8, + sample_rate: int = 44100, + num_items: int = 5, + duration: float = 1.0, + min_note: str = "C2", + max_note: str = "C6", + output_dir: Path = "chords", +): + """ + Generates a toy multitrack dataset of chords, synthesized from sine waves. + + + Parameters + ---------- + max_voices : int, optional + Maximum number of voices in a chord, by default 8 + sample_rate : int, optional + Sample rate of audio, by default 44100 + num_items : int, optional + Number of items to generate, by default 5 + duration : float, optional + Duration of each item, by default 1.0 + min_note : str, optional + Minimum note in the dataset, by default "C2" + max_note : str, optional + Maximum note in the dataset, by default "C6" + output_dir : Path, optional + Directory to save the dataset, by default "chords" + + """ + import librosa + from . 
import AudioSignal + from ..data.preprocess import create_csv + + min_midi = librosa.note_to_midi(min_note) + max_midi = librosa.note_to_midi(max_note) + + tracks = [] + for idx in range(num_items): + track = {} + # figure out how many voices to put in this track + num_voices = random.randint(1, max_voices) + for voice_idx in range(num_voices): + # choose some random params + midinote = random.randint(min_midi, max_midi) + dur = random.uniform(0.85 * duration, duration) + + sig = AudioSignal.wave( + frequency=librosa.midi_to_hz(midinote), + duration=dur, + sample_rate=sample_rate, + shape="sine", + ) + track[f"voice_{voice_idx}"] = sig + tracks.append(track) + + # save the tracks to disk + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True) + for idx, track in enumerate(tracks): + track_dir = output_dir / f"track_{idx}" + track_dir.mkdir(exist_ok=True) + for voice_name, sig in track.items(): + sig.write(track_dir / f"{voice_name}.wav") + + all_voices = list(set([k for track in tracks for k in track.keys()])) + voice_lists = {voice: [] for voice in all_voices} + for track in tracks: + for voice_name in all_voices: + if voice_name in track: + voice_lists[voice_name].append(track[voice_name].path_to_file) + else: + voice_lists[voice_name].append("") + + for voice_name, paths in voice_lists.items(): + create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True) + + return output_dir diff --git a/flowae/models/ldm/dac/audiotools/core/whisper.py b/flowae/models/ldm/dac/audiotools/core/whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..46c071f934fc3e2be3138e7596b1c6d2ef79eade --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/core/whisper.py @@ -0,0 +1,97 @@ +import torch + + +class WhisperMixin: + is_initialized = False + + def setup_whisper( + self, + pretrained_model_name_or_path: str = "openai/whisper-base.en", + device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ): + from transformers import WhisperForConditionalGeneration + from transformers import WhisperProcessor + + self.whisper_device = device + self.whisper_processor = WhisperProcessor.from_pretrained( + pretrained_model_name_or_path + ) + self.whisper_model = WhisperForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path + ).to(self.whisper_device) + self.is_initialized = True + + def get_whisper_features(self) -> torch.Tensor: + """Preprocess audio signal as per the whisper model's training config. + + Returns + ------- + torch.Tensor + The prepinput features of the audio signal. Shape: (1, channels, seq_len) + """ + import torch + + if not self.is_initialized: + self.setup_whisper() + + signal = self.to(self.device) + raw_speech = list( + ( + signal.clone() + .resample(self.whisper_processor.feature_extractor.sampling_rate) + .audio_data[:, 0, :] + .numpy() + ) + ) + + with torch.inference_mode(): + input_features = self.whisper_processor( + raw_speech, + sampling_rate=self.whisper_processor.feature_extractor.sampling_rate, + return_tensors="pt", + ).input_features + + return input_features + + def get_whisper_transcript(self) -> str: + """Get the transcript of the audio signal using the whisper model. + + Returns + ------- + str + The transcript of the audio signal, including special tokens such as <|startoftranscript|> and <|endoftext|>. 
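+
+        Example (illustrative; ``speech.wav`` is a placeholder, and the
+        pretrained weights are downloaded on first use):
+
+        >>> signal = AudioSignal("speech.wav")
+        >>> transcript = signal.get_whisper_transcript()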
+ """ + + if not self.is_initialized: + self.setup_whisper() + + input_features = self.get_whisper_features() + + with torch.inference_mode(): + input_features = input_features.to(self.whisper_device) + generated_ids = self.whisper_model.generate(inputs=input_features) + + transcription = self.whisper_processor.batch_decode(generated_ids) + return transcription[0] + + def get_whisper_embeddings(self) -> torch.Tensor: + """Get the last hidden state embeddings of the audio signal using the whisper model. + + Returns + ------- + torch.Tensor + The Whisper embeddings of the audio signal. Shape: (1, seq_len, hidden_size) + """ + import torch + + if not self.is_initialized: + self.setup_whisper() + + input_features = self.get_whisper_features() + encoder = self.whisper_model.get_encoder() + + with torch.inference_mode(): + input_features = input_features.to(self.whisper_device) + embeddings = encoder(input_features) + + return embeddings.last_hidden_state diff --git a/flowae/models/ldm/dac/audiotools/data/__init__.py b/flowae/models/ldm/dac/audiotools/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aead269f26f3782043e68418b4c87ee323cbd015 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/data/__init__.py @@ -0,0 +1,3 @@ +from . import datasets +from . import preprocess +from . import transforms diff --git a/flowae/models/ldm/dac/audiotools/data/datasets.py b/flowae/models/ldm/dac/audiotools/data/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..12e7a60963399aa15ff865de2d06537818ce18ee --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/data/datasets.py @@ -0,0 +1,517 @@ +from pathlib import Path +from typing import Callable +from typing import Dict +from typing import List +from typing import Union + +import numpy as np +from torch.utils.data import SequentialSampler +from torch.utils.data.distributed import DistributedSampler + +from ..core import AudioSignal +from ..core import util + + +class AudioLoader: + """Loads audio endlessly from a list of audio sources + containing paths to audio files. Audio sources can be + folders full of audio files (which are found via file + extension) or by providing a CSV file which contains paths + to audio files. + + Parameters + ---------- + sources : List[str], optional + Sources containing folders, or CSVs with + paths to audio files, by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + relative_path : str, optional + Path audio should be loaded relative to, by default "" + transform : Callable, optional + Transform to instantiate alongside audio sample, + by default None + ext : List[str] + List of extensions to find audio within each source by. Can + also be a file name (e.g. "vocals.wav"). by default + ``['.wav', '.flac', '.mp3', '.mp4']``. + shuffle: bool + Whether to shuffle the files within the dataloader. Defaults to True. + shuffle_state: int + State to use to seed the shuffle of the files. 
+ """ + + def __init__( + self, + sources: List[str] = None, + weights: List[float] = None, + transform: Callable = None, + relative_path: str = "", + ext: List[str] = util.AUDIO_EXTENSIONS, + shuffle: bool = True, + shuffle_state: int = 0, + ): + self.audio_lists = util.read_sources( + sources, relative_path=relative_path, ext=ext + ) + + self.audio_indices = [ + (src_idx, item_idx) + for src_idx, src in enumerate(self.audio_lists) + for item_idx in range(len(src)) + ] + if shuffle: + state = util.random_state(shuffle_state) + state.shuffle(self.audio_indices) + + self.sources = sources + self.weights = weights + self.transform = transform + + def __call__( + self, + state, + sample_rate: int, + duration: float, + loudness_cutoff: float = -40, + num_channels: int = 1, + offset: float = None, + source_idx: int = None, + item_idx: int = None, + global_idx: int = None, + ): + if source_idx is not None and item_idx is not None: + try: + audio_info = self.audio_lists[source_idx][item_idx] + except: + audio_info = {"path": "none"} + elif global_idx is not None: + source_idx, item_idx = self.audio_indices[ + global_idx % len(self.audio_indices) + ] + audio_info = self.audio_lists[source_idx][item_idx] + else: + audio_info, source_idx, item_idx = util.choose_from_list_of_lists( + state, self.audio_lists, p=self.weights + ) + + path = audio_info["path"] + signal = AudioSignal.zeros(duration, sample_rate, num_channels) + + if path != "none": + if offset is None: + signal = AudioSignal.salient_excerpt( + path, + duration=duration, + state=state, + loudness_cutoff=loudness_cutoff, + ) + else: + signal = AudioSignal( + path, + offset=offset, + duration=duration, + ) + + if num_channels == 1: + signal = signal.to_mono() + signal = signal.resample(sample_rate) + + if signal.duration < duration: + signal = signal.zero_pad_to(int(duration * sample_rate)) + + for k, v in audio_info.items(): + signal.metadata[k] = v + + item = { + "signal": signal, + "source_idx": source_idx, + "item_idx": item_idx, + "source": str(self.sources[source_idx]), + "path": str(path), + } + if self.transform is not None: + item["transform_args"] = self.transform.instantiate(state, signal=signal) + return item + + +def default_matcher(x, y): + return Path(x).parent == Path(y).parent + + +def align_lists(lists, matcher: Callable = default_matcher): + longest_list = lists[np.argmax([len(l) for l in lists])] + for i, x in enumerate(longest_list): + for l in lists: + if i >= len(l): + l.append({"path": "none"}) + elif not matcher(l[i]["path"], x["path"]): + l.insert(i, {"path": "none"}) + return lists + + +class AudioDataset: + """Loads audio from multiple loaders (with associated transforms) + for a specified number of samples. Excerpts are drawn randomly + of the specified duration, above a specified loudness threshold + and are resampled on the fly to the desired sample rate + (if it is different from the audio source sample rate). + + This takes either a single AudioLoader object, + a dictionary of AudioLoader objects, or a dictionary of AudioLoader + objects. Each AudioLoader is called by the dataset, and the + result is placed in the output dictionary. A transform can also be + specified for the entire dataset, rather than for each specific + loader. This transform can be applied to the output of all the + loaders if desired. + + AudioLoader objects can be specified as aligned, which means the + loaders correspond to multitrack audio (e.g. a vocals, bass, + drums, and other loader for multitrack music mixtures). 
+ + + Parameters + ---------- + loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]] + AudioLoaders to sample audio from. + sample_rate : int + Desired sample rate. + n_examples : int, optional + Number of examples (length of dataset), by default 1000 + duration : float, optional + Duration of audio samples, by default 0.5 + loudness_cutoff : float, optional + Loudness cutoff threshold for audio samples, by default -40 + num_channels : int, optional + Number of channels in output audio, by default 1 + transform : Callable, optional + Transform to instantiate alongside each dataset item, by default None + aligned : bool, optional + Whether the loaders should be sampled in an aligned manner (e.g. same + offset, duration, and matched file name), by default False + shuffle_loaders : bool, optional + Whether to shuffle the loaders before sampling from them, by default False + matcher : Callable + How to match files from adjacent audio lists (e.g. for a multitrack audio loader), + by default uses the parent directory of each file. + without_replacement : bool + Whether to choose files with or without replacement, by default True. + + + Examples + -------- + >>> from audiotools.data.datasets import AudioLoader + >>> from audiotools.data.datasets import AudioDataset + >>> from audiotools import transforms as tfm + >>> import numpy as np + >>> + >>> loaders = [ + >>> AudioLoader( + >>> sources=[f"tests/audio/spk"], + >>> transform=tfm.Equalizer(), + >>> ext=["wav"], + >>> ) + >>> for i in range(5) + >>> ] + >>> + >>> dataset = AudioDataset( + >>> loaders = loaders, + >>> sample_rate = 44100, + >>> duration = 1.0, + >>> transform = tfm.RescaleAudio(), + >>> ) + >>> + >>> item = dataset[np.random.randint(len(dataset))] + >>> + >>> for i in range(len(loaders)): + >>> item[i]["signal"] = loaders[i].transform( + >>> item[i]["signal"], **item[i]["transform_args"] + >>> ) + >>> item[i]["signal"].widget(i) + >>> + >>> mix = sum([item[i]["signal"] for i in range(len(loaders))]) + >>> mix = dataset.transform(mix, **item["transform_args"]) + >>> mix.widget("mix") + + Below is an example of how one could load MUSDB multitrack data: + + >>> import audiotools as at + >>> from pathlib import Path + >>> from audiotools import transforms as tfm + >>> import numpy as np + >>> import torch + >>> + >>> def build_dataset( + >>> sample_rate: int = 44100, + >>> duration: float = 5.0, + >>> musdb_path: str = "~/.data/musdb/", + >>> ): + >>> musdb_path = Path(musdb_path).expanduser() + >>> loaders = { + >>> src: at.datasets.AudioLoader( + >>> sources=[musdb_path], + >>> transform=tfm.Compose( + >>> tfm.VolumeNorm(("uniform", -20, -10)), + >>> tfm.Silence(prob=0.1), + >>> ), + >>> ext=[f"{src}.wav"], + >>> ) + >>> for src in ["vocals", "bass", "drums", "other"] + >>> } + >>> + >>> dataset = at.datasets.AudioDataset( + >>> loaders=loaders, + >>> sample_rate=sample_rate, + >>> duration=duration, + >>> num_channels=1, + >>> aligned=True, + >>> transform=tfm.RescaleAudio(), + >>> shuffle_loaders=True, + >>> ) + >>> return dataset, list(loaders.keys()) + >>> + >>> train_data, sources = build_dataset() + >>> dataloader = torch.utils.data.DataLoader( + >>> train_data, + >>> batch_size=16, + >>> num_workers=0, + >>> collate_fn=train_data.collate, + >>> ) + >>> batch = next(iter(dataloader)) + >>> + >>> for k in sources: + >>> src = batch[k] + >>> src["transformed"] = train_data.loaders[k].transform( + >>> src["signal"].clone(), **src["transform_args"] + >>> ) + >>> + >>> mixture = sum(batch[k]["transformed"] for 
k in sources) + >>> mixture = train_data.transform(mixture, **batch["transform_args"]) + >>> + >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time). + >>> # Construct the targets: + >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1) + + Similarly, here's example code for loading Slakh data: + + >>> import audiotools as at + >>> from pathlib import Path + >>> from audiotools import transforms as tfm + >>> import numpy as np + >>> import torch + >>> import glob + >>> + >>> def build_dataset( + >>> sample_rate: int = 16000, + >>> duration: float = 10.0, + >>> slakh_path: str = "~/.data/slakh/", + >>> ): + >>> slakh_path = Path(slakh_path).expanduser() + >>> + >>> # Find the max number of sources in Slakh + >>> src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)] + >>> n_sources = len(list(set(src_names))) + >>> + >>> loaders = { + >>> f"S{i:02d}": at.datasets.AudioLoader( + >>> sources=[slakh_path], + >>> transform=tfm.Compose( + >>> tfm.VolumeNorm(("uniform", -20, -10)), + >>> tfm.Silence(prob=0.1), + >>> ), + >>> ext=[f"S{i:02d}.wav"], + >>> ) + >>> for i in range(n_sources) + >>> } + >>> dataset = at.datasets.AudioDataset( + >>> loaders=loaders, + >>> sample_rate=sample_rate, + >>> duration=duration, + >>> num_channels=1, + >>> aligned=True, + >>> transform=tfm.RescaleAudio(), + >>> shuffle_loaders=False, + >>> ) + >>> + >>> return dataset, list(loaders.keys()) + >>> + >>> train_data, sources = build_dataset() + >>> dataloader = torch.utils.data.DataLoader( + >>> train_data, + >>> batch_size=16, + >>> num_workers=0, + >>> collate_fn=train_data.collate, + >>> ) + >>> batch = next(iter(dataloader)) + >>> + >>> for k in sources: + >>> src = batch[k] + >>> src["transformed"] = train_data.loaders[k].transform( + >>> src["signal"].clone(), **src["transform_args"] + >>> ) + >>> + >>> mixture = sum(batch[k]["transformed"] for k in sources) + >>> mixture = train_data.transform(mixture, **batch["transform_args"]) + + """ + + def __init__( + self, + loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]], + sample_rate: int, + n_examples: int = 1000, + duration: float = 0.5, + offset: float = None, + loudness_cutoff: float = -40, + num_channels: int = 1, + transform: Callable = None, + aligned: bool = False, + shuffle_loaders: bool = False, + matcher: Callable = default_matcher, + without_replacement: bool = True, + ): + # Internally we convert loaders to a dictionary + if isinstance(loaders, list): + loaders = {i: l for i, l in enumerate(loaders)} + elif isinstance(loaders, AudioLoader): + loaders = {0: loaders} + + self.loaders = loaders + self.loudness_cutoff = loudness_cutoff + self.num_channels = num_channels + + self.length = n_examples + self.transform = transform + self.sample_rate = sample_rate + self.duration = duration + self.offset = offset + self.aligned = aligned + self.shuffle_loaders = shuffle_loaders + self.without_replacement = without_replacement + + if aligned: + loaders_list = list(loaders.values()) + for i in range(len(loaders_list[0].audio_lists)): + input_lists = [l.audio_lists[i] for l in loaders_list] + # Alignment happens in-place + align_lists(input_lists, matcher) + + def __getitem__(self, idx): + state = util.random_state(idx) + offset = None if self.offset is None else self.offset + item = {} + + keys = list(self.loaders.keys()) + if self.shuffle_loaders: + state.shuffle(keys) + + loader_kwargs = { + "state": state, + "sample_rate": self.sample_rate, + "duration": 
self.duration, + "loudness_cutoff": self.loudness_cutoff, + "num_channels": self.num_channels, + "global_idx": idx if self.without_replacement else None, + } + + # Draw item from first loader + loader = self.loaders[keys[0]] + item[keys[0]] = loader(**loader_kwargs) + + for key in keys[1:]: + loader = self.loaders[key] + if self.aligned: + # Aligned loaders reuse the offset, source index, and item + # index returned by the first loader. + offset = item[keys[0]]["signal"].metadata["offset"] + loader_kwargs.update( + { + "offset": offset, + "source_idx": item[keys[0]]["source_idx"], + "item_idx": item[keys[0]]["item_idx"], + } + ) + item[key] = loader(**loader_kwargs) + + # Sort dictionary back into original order + keys = list(self.loaders.keys()) + item = {k: item[k] for k in keys} + + item["idx"] = idx + if self.transform is not None: + item["transform_args"] = self.transform.instantiate( + state=state, signal=item[keys[0]]["signal"] + ) + + # If there's only one loader, pop it up + # to the main dictionary, instead of keeping it + # nested. + if len(keys) == 1: + item.update(item.pop(keys[0])) + + return item + + def __len__(self): + return self.length + + @staticmethod + def collate(list_of_dicts: Union[list, dict], n_splits: int = None): + """Collates items drawn from this dataset. Uses + :py:func:`audiotools.core.util.collate`. + + Parameters + ---------- + list_of_dicts : typing.Union[list, dict] + Data drawn from each item. + n_splits : int + Number of splits to make when creating the batches (split into + sub-batches). Useful for things like gradient accumulation. + + Returns + ------- + dict + Dictionary of batched data. + """ + return util.collate(list_of_dicts, n_splits=n_splits) + + +class ConcatDataset(AudioDataset): + def __init__(self, datasets: list): + self.datasets = datasets + + def __len__(self): + return sum([len(d) for d in self.datasets]) + + def __getitem__(self, idx): + dataset = self.datasets[idx % len(self.datasets)] + return dataset[idx // len(self.datasets)] + + +class ResumableDistributedSampler(DistributedSampler):  # pragma: no cover + """Distributed sampler that can be resumed from a given start index.""" + + def __init__(self, dataset, start_idx: int = None, **kwargs): + super().__init__(dataset, **kwargs) + # Start index; allows resuming an experiment from where it left off. + self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0 + + def __iter__(self): + for i, idx in enumerate(super().__iter__()): + if i >= self.start_idx: + yield idx + self.start_idx = 0  # reset so the next epoch starts from index 0 + + +class ResumableSequentialSampler(SequentialSampler):  # pragma: no cover + """Sequential sampler that can be resumed from a given start index.""" + + def __init__(self, dataset, start_idx: int = None, **kwargs): + super().__init__(dataset, **kwargs) + # Start index; allows resuming an experiment from where it left off. + self.start_idx = start_idx if start_idx is not None else 0 + + def __iter__(self): + for i, idx in enumerate(super().__iter__()): + if i >= self.start_idx: + yield idx + self.start_idx = 0  # reset so the next epoch starts from index 0 diff --git a/flowae/models/ldm/dac/audiotools/data/preprocess.py b/flowae/models/ldm/dac/audiotools/data/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..d90de210115e45838bc8d69b350f7516ba730406 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/data/preprocess.py @@ -0,0 +1,81 @@ +import csv +import os +from pathlib import Path + +from tqdm import tqdm + +from
..core import AudioSignal + + +def create_csv( + audio_files: list, output_csv: Path, loudness: bool = False, data_path: str = None +): + """Converts a list of audio files to a CSV file. If ``loudness = True``, + the output of this function will create a CSV file that looks something + like: + + .. csv-table:: + :header: path,loudness + + daps/produced/f1_script1_produced.wav,-16.299999237060547 + daps/produced/f1_script2_produced.wav,-16.600000381469727 + daps/produced/f1_script3_produced.wav,-17.299999237060547 + daps/produced/f1_script4_produced.wav,-16.100000381469727 + daps/produced/f1_script5_produced.wav,-16.700000762939453 + daps/produced/f3_script1_produced.wav,-16.5 + + .. note:: + The paths above are written relative to the ``data_path`` argument, + if it is passed to this function; otherwise paths are written + as given. + + You can produce a CSV file from a directory of audio files via: + + >>> import audiotools + >>> directory = ... + >>> audio_files = audiotools.util.find_audio(directory) + >>> output_path = "train.csv" + >>> audiotools.data.preprocess.create_csv( + >>>     audio_files, output_path, loudness=True + >>> ) + + Note that you can create empty rows in the CSV file by passing an empty + string or None in the ``audio_files`` list. This is useful if you want to + sync multiple CSV files in a multitrack setting. The loudness of these + empty rows will be set to -inf. + + Parameters + ---------- + audio_files : list + List of audio files. + output_csv : Path + Output CSV path. Each row contains the path of a file, relative + to ``data_path`` if it is specified. + loudness : bool + Compute loudness of entire file and store alongside path. + data_path : str, optional + Path that the audio file paths are written relative to, by default None. + """ + + info = [] + pbar = tqdm(audio_files) + for af in pbar: + af = Path(af) + pbar.set_description(f"Processing {af.name}") + _info = {} + if af.name == "": + _info["path"] = "" + if loudness: + _info["loudness"] = -float("inf") + else: + _info["path"] = af.relative_to(data_path) if data_path is not None else af + if loudness: + _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item() + + info.append(_info) + + with open(output_csv, "w") as f: + writer = csv.DictWriter(f, fieldnames=list(info[0].keys())) + writer.writeheader() + + for item in info: + writer.writerow(item) diff --git a/flowae/models/ldm/dac/audiotools/data/transforms.py b/flowae/models/ldm/dac/audiotools/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..504e87dc61777e36ba95eb794f497bed4cdc7d2c --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/data/transforms.py @@ -0,0 +1,1592 @@ +import copy +from contextlib import contextmanager +from inspect import signature +from typing import List + +import numpy as np +import torch +from flatten_dict import flatten +from flatten_dict import unflatten +from numpy.random import RandomState + +from .. import ml +from ..core import AudioSignal +from ..core import util +from .datasets import AudioLoader + +tt = torch.tensor +"""Shorthand for converting things to torch.tensor.""" + + +class BaseTransform: + """This is the base class for all transforms that are implemented + in this library. Transforms have two main operations: ``transform`` + and ``instantiate``. + + ``instantiate`` sets the parameters randomly + from distribution tuples for each parameter. 
For example, for the + ``BackgroundNoise`` transform, the signal-to-noise ratio (``snr``) + is chosen randomly by instantiate. By default, it is chosen uniformly + between 10.0 and 30.0 (the tuple is set to ``("uniform", 10.0, 30.0)``). + + ``transform`` applies the transform using the instantiated parameters. + A simple example is as follows: + + >>> seed = 0 + >>> signal = ... + >>> transform = transforms.NoiseFloor(db = ("uniform", -50.0, -30.0)) + >>> kwargs = transform.instantiate() + >>> output = transform(signal.clone(), **kwargs) + + By breaking apart the instantiation of parameters from the actual audio + processing of the transform, we can make things more reproducible, while + also applying the transform on batches of data efficiently on GPU, + rather than on individual audio samples. + + .. note:: + We call ``signal.clone()`` for the input to the ``transform`` function + because signals are modified in-place! If you don't clone the signal, + you will lose the original data. + + Parameters + ---------- + keys : list, optional + Keys that the transform looks for when + calling ``self.transform``, by default []. In general this is + set automatically, and you won't need to manipulate this argument. + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + + Examples + -------- + + >>> seed = 0 + >>> + >>> audio_path = "tests/audio/spk/f10_script4_produced.wav" + >>> signal = AudioSignal(audio_path, offset=10, duration=2) + >>> transform = tfm.Compose( + >>>     [ + >>>         tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), + >>>         tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), + >>>     ], + >>> ) + >>> + >>> kwargs = transform.instantiate(seed, signal) + >>> output = transform(signal, **kwargs) + + """ + + def __init__(self, keys: list = [], name: str = None, prob: float = 1.0): + # Get keys from the _transform signature. + tfm_keys = list(signature(self._transform).parameters.keys()) + + # Filter out signal and kwargs keys. + ignore_keys = ["signal", "kwargs"] + tfm_keys = [k for k in tfm_keys if k not in ignore_keys] + + # Combine keys specified by the child class, the keys found in + # _transform signature, and the mask key. + self.keys = keys + tfm_keys + ["mask"] + + self.prob = prob + + if name is None: + name = self.__class__.__name__ + self.name = name + + def _prepare(self, batch: dict): + sub_batch = batch[self.name] + + for k in self.keys: + assert k in sub_batch.keys(), f"{k} not in batch" + + return sub_batch + + def _transform(self, signal): + return signal + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + return {} + + @staticmethod + def apply_mask(batch: dict, mask: torch.Tensor): + """Applies a mask to the batch. + + Parameters + ---------- + batch : dict + Batch whose values will be masked in the ``transform`` pass. + mask : torch.Tensor + Mask to apply to batch. + + Returns + ------- + dict + A dictionary that contains values only where ``mask = True``. + """ + masked_batch = {k: v[mask] for k, v in flatten(batch).items()} + return unflatten(masked_batch) + + def transform(self, signal: AudioSignal, **kwargs): + """Apply the transform to the audio signal, + with given keyword arguments. + + Parameters + ---------- + signal : AudioSignal + Signal that will be modified by the transforms in-place. 
+ kwargs: dict + Keyword arguments to the specific transform's ``self._transform`` + function. + + Returns + ------- + AudioSignal + Transformed AudioSignal. + + Examples + -------- + + >>> for seed in range(10): + >>>     kwargs = transform.instantiate(seed, signal) + >>>     output = transform(signal.clone(), **kwargs) + + """ + tfm_kwargs = self._prepare(kwargs) + mask = tfm_kwargs["mask"] + + if torch.any(mask): + tfm_kwargs = self.apply_mask(tfm_kwargs, mask) + tfm_kwargs = {k: v for k, v in tfm_kwargs.items() if k != "mask"} + signal[mask] = self._transform(signal[mask], **tfm_kwargs) + + return signal + + def __call__(self, *args, **kwargs): + return self.transform(*args, **kwargs) + + def instantiate( + self, + state: RandomState = None, + signal: AudioSignal = None, + ): + """Instantiates parameters for the transform. + + Parameters + ---------- + state : RandomState, optional + Random state used to seed parameter sampling, by default None + signal : AudioSignal, optional + Signal passed to ``self._instantiate`` if the transform + needs it, by default None + + Returns + ------- + dict + Dictionary containing instantiated arguments for every keyword + argument to ``self._transform``. + + Examples + -------- + + >>> for seed in range(10): + >>>     kwargs = transform.instantiate(seed, signal) + >>>     output = transform(signal.clone(), **kwargs) + + """ + state = util.random_state(state) + + # Not all instantiates need the signal. Check if signal + # is needed before passing it in, so that the end-user + # doesn't need to have variables they're not using flowing + # into their function. + needs_signal = "signal" in set(signature(self._instantiate).parameters.keys()) + kwargs = {} + if needs_signal: + kwargs = {"signal": signal} + + # Instantiate the parameters for the transform. + params = self._instantiate(state, **kwargs) + for k in list(params.keys()): + v = params[k] + if not isinstance(v, (AudioSignal, torch.Tensor, dict)): + params[k] = tt(v) + mask = state.rand() <= self.prob + params["mask"] = tt(mask) + + # Put the params into a nested dictionary that will be + # used later when calling the transform. This is to avoid + # collisions in the dictionary. + params = {self.name: params} + + return params + + def batch_instantiate( + self, + states: list = None, + signal: AudioSignal = None, + ): + """Instantiates arguments for every item in a batch, + given a list of states. Each state in the list + corresponds to one item in the batch. + + Parameters + ---------- + states : list, optional + List of states, by default None + signal : AudioSignal, optional + AudioSignal to pass to the ``self.instantiate`` section + if it is needed for this transform, by default None + + Returns + ------- + dict + Collated dictionary of arguments. + + Examples + -------- + + >>> batch_size = 4 + >>> signal = AudioSignal(audio_path, offset=10, duration=2) + >>> signal_batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)]) + >>> + >>> states = [seed + idx for idx in list(range(batch_size))] + >>> kwargs = transform.batch_instantiate(states, signal_batch) + >>> batch_output = transform(signal_batch, **kwargs) + """ + kwargs = [] + for state in states: + kwargs.append(self.instantiate(state, signal)) + kwargs = util.collate(kwargs) + return kwargs + + +class Identity(BaseTransform): + """This transform just returns the original signal.""" + + pass + + +class SpectralTransform(BaseTransform): + """Spectral transforms operate on the STFT of the signal, so STFT + data must exist before they run. 
This just calls ``stft`` before + the transform is called, and calls ``istft`` after the transform is + called so that the modified spectrogram is written back to the + audio data. + """ + + def transform(self, signal, **kwargs): + signal.stft() + super().transform(signal, **kwargs) + signal.istft() + return signal + + +class Compose(BaseTransform): + """Compose applies transforms in sequence, one after the other. The + transforms are passed in as positional arguments or as a list like so: + + >>> transform = tfm.Compose( + >>>     [ + >>>         tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), + >>>         tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), + >>>     ], + >>> ) + + This will convolve the signal with a room impulse response, and then + add background noise to the signal. Calling ``instantiate`` instantiates + all the parameters for every transform in the transform list so the + interface for using the Compose transform is the same as everything + else: + + >>> kwargs = transform.instantiate() + >>> output = transform(signal.clone(), **kwargs) + + Under the hood, ``Compose`` maps each transform to a unique name + of the form ``{position}.{name}``, where ``position`` + is the index of the transform in the list. ``Compose`` can nest + within other ``Compose`` transforms, like so: + + >>> preprocess = transforms.Compose( + >>>     tfm.GlobalVolumeNorm(), + >>>     tfm.CrossTalk(), + >>>     name="preprocess", + >>> ) + >>> augment = transforms.Compose( + >>>     tfm.RoomImpulseResponse(), + >>>     tfm.BackgroundNoise(), + >>>     name="augment", + >>> ) + >>> postprocess = transforms.Compose( + >>>     tfm.VolumeChange(), + >>>     tfm.RescaleAudio(), + >>>     tfm.ShiftPhase(), + >>>     name="postprocess", + >>> ) + >>> transform = transforms.Compose(preprocess, augment, postprocess) + + This defines 3 composed transforms, and then composes them in sequence + with one another. + + Parameters + ---------- + *transforms : list + List of transforms to apply + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__(self, *transforms: list, name: str = None, prob: float = 1.0): + if isinstance(transforms[0], list): + transforms = transforms[0] + + for i, tfm in enumerate(transforms): + tfm.name = f"{i}.{tfm.name}" + + keys = [tfm.name for tfm in transforms] + super().__init__(keys=keys, name=name, prob=prob) + + self.transforms = transforms + self.transforms_to_apply = keys + + @contextmanager + def filter(self, *names: list): + """This can be used to skip transforms entirely when applying + the sequence of transforms to a signal. For example, take + the following transforms with the names ``preprocess, augment, postprocess``. 
+ + >>> preprocess = transforms.Compose( + >>>     tfm.GlobalVolumeNorm(), + >>>     tfm.CrossTalk(), + >>>     name="preprocess", + >>> ) + >>> augment = transforms.Compose( + >>>     tfm.RoomImpulseResponse(), + >>>     tfm.BackgroundNoise(), + >>>     name="augment", + >>> ) + >>> postprocess = transforms.Compose( + >>>     tfm.VolumeChange(), + >>>     tfm.RescaleAudio(), + >>>     tfm.ShiftPhase(), + >>>     name="postprocess", + >>> ) + >>> transform = transforms.Compose(preprocess, augment, postprocess) + + If we wanted to apply all 3 to a signal, we do: + + >>> kwargs = transform.instantiate() + >>> output = transform(signal.clone(), **kwargs) + + But if we only wanted to apply the ``preprocess`` and ``postprocess`` + transforms to the signal, we do: + + >>> with transform.filter("preprocess", "postprocess"): + >>>     output = transform(signal.clone(), **kwargs) + + Parameters + ---------- + *names : list + List of transforms, identified by name, to apply to signal. + """ + old_transforms = self.transforms_to_apply + self.transforms_to_apply = names + yield + self.transforms_to_apply = old_transforms + + def _transform(self, signal, **kwargs): + for transform in self.transforms: + if any([x in transform.name for x in self.transforms_to_apply]): + signal = transform(signal, **kwargs) + return signal + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + parameters = {} + for transform in self.transforms: + parameters.update(transform.instantiate(state, signal=signal)) + return parameters + + def __getitem__(self, idx): + return self.transforms[idx] + + def __len__(self): + return len(self.transforms) + + def __iter__(self): + for transform in self.transforms: + yield transform + + +class Choose(Compose): + """Choose logic is the same as :py:func:`audiotools.data.transforms.Compose`, + but instead of applying all the transforms in sequence, it applies just a single transform, + which is chosen for each item in the batch. + + Parameters + ---------- + *transforms : list + List of transforms to apply + weights : list + Probability of choosing any specific transform. + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + + Examples + -------- + + >>> transforms.Choose(tfm.LowPass(), tfm.HighPass()) + """ + + def __init__( + self, + *transforms: list, + weights: list = None, + name: str = None, + prob: float = 1.0, + ): + super().__init__(*transforms, name=name, prob=prob) + + if weights is None: + _len = len(self.transforms) + weights = [1 / _len for _ in range(_len)] + self.weights = np.array(weights) + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + kwargs = super()._instantiate(state, signal) + tfm_idx = list(range(len(self.transforms))) + tfm_idx = state.choice(tfm_idx, p=self.weights) + one_hot = [] + for i, t in enumerate(self.transforms): + mask = kwargs[t.name]["mask"] + if mask.item(): + kwargs[t.name]["mask"] = tt(i == tfm_idx) + one_hot.append(kwargs[t.name]["mask"]) + kwargs["one_hot"] = one_hot + return kwargs + + +class Repeat(Compose): + """Repeatedly applies a given transform ``n_repeat`` times. + + Parameters + ---------- + transform : BaseTransform + Transform to repeat. 
+ n_repeat : int, optional + Number of times to repeat transform, by default 1 + """ + + def __init__( + self, + transform, + n_repeat: int = 1, + name: str = None, + prob: float = 1.0, + ): + transforms = [copy.copy(transform) for _ in range(n_repeat)] + super().__init__(transforms, name=name, prob=prob) + + self.n_repeat = n_repeat + + +class RepeatUpTo(Choose): + """Repeatedly applies a given transform up to ``max_repeat - 1`` times. + + Parameters + ---------- + transform : BaseTransform + Transform to repeat. + max_repeat : int, optional + Maximum number of times to repeat transform (exclusive upper + bound), by default 5 + weights : list + Probability of choosing any specific number up to ``max_repeat``. + """ + + def __init__( + self, + transform, + max_repeat: int = 5, + weights: list = None, + name: str = None, + prob: float = 1.0, + ): + transforms = [] + for n in range(1, max_repeat): + transforms.append(Repeat(transform, n_repeat=n)) + super().__init__(transforms, name=name, prob=prob, weights=weights) + + self.max_repeat = max_repeat + + +class ClippingDistortion(BaseTransform): + """Adds clipping distortion to signal. Corresponds + to :py:func:`audiotools.core.effects.EffectMixin.clip_distortion`. + + Parameters + ---------- + perc : tuple, optional + Clipping percentile. Values are between 0.0 and 1.0. + Typical values are 0.1 or below, by default ("uniform", 0.0, 0.1) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + perc: tuple = ("uniform", 0.0, 0.1), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.perc = perc + + def _instantiate(self, state: RandomState): + return {"perc": util.sample_from_dist(self.perc, state)} + + def _transform(self, signal, perc): + return signal.clip_distortion(perc) + + +class Equalizer(BaseTransform): + """Applies an equalization curve to the audio signal. Corresponds + to :py:func:`audiotools.core.effects.EffectMixin.equalizer`. + + Parameters + ---------- + eq_amount : tuple, optional + The maximum cut, in dB, to apply to the audio in any band, + by default ("const", 1.0) + n_bands : int, optional + Number of bands in EQ, by default 6 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + eq_amount: tuple = ("const", 1.0), + n_bands: int = 6, + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.eq_amount = eq_amount + self.n_bands = n_bands + + def _instantiate(self, state: RandomState): + eq_amount = util.sample_from_dist(self.eq_amount, state) + eq = -eq_amount * state.rand(self.n_bands) + return {"eq": eq} + + def _transform(self, signal, eq): + return signal.equalizer(eq) + + +class Quantization(BaseTransform): + """Applies quantization to the input waveform. Corresponds + to :py:func:`audiotools.core.effects.EffectMixin.quantization`. 
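+ + A minimal sketch, following the instantiate-then-transform pattern used + throughout this module (``signal`` is assumed to be an ``AudioSignal`` + loaded elsewhere): + + >>> transform = Quantization(channels=("const", 8)) + >>> kwargs = transform.instantiate(0, signal) + >>> output = transform(signal.clone(), **kwargs)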
+ + Parameters + ---------- + channels : tuple, optional + Number of evenly spaced quantization channels to quantize + to, by default ("choice", [8, 32, 128, 256, 1024]) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + channels: tuple = ("choice", [8, 32, 128, 256, 1024]), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.channels = channels + + def _instantiate(self, state: RandomState): + return {"channels": util.sample_from_dist(self.channels, state)} + + def _transform(self, signal, channels): + return signal.quantization(channels) + + +class MuLawQuantization(BaseTransform): + """Applies mu-law quantization to the input waveform. Corresponds + to :py:func:`audiotools.core.effects.EffectMixin.mulaw_quantization`. + + Parameters + ---------- + channels : tuple, optional + Number of mu-law spaced quantization channels to quantize + to, by default ("choice", [8, 32, 128, 256, 1024]) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + channels: tuple = ("choice", [8, 32, 128, 256, 1024]), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.channels = channels + + def _instantiate(self, state: RandomState): + return {"channels": util.sample_from_dist(self.channels, state)} + + def _transform(self, signal, channels): + return signal.mulaw_quantization(channels) + + +class NoiseFloor(BaseTransform): + """Adds a noise floor of Gaussian noise to the signal at a specified + dB. + + Parameters + ---------- + db : tuple, optional + Level of noise to add to signal, by default ("const", -50.0) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + db: tuple = ("const", -50.0), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.db = db + + def _instantiate(self, state: RandomState, signal: AudioSignal): + db = util.sample_from_dist(self.db, state) + audio_data = state.randn(signal.num_channels, signal.signal_length) + nz_signal = AudioSignal(audio_data, signal.sample_rate) + nz_signal.normalize(db) + return {"nz_signal": nz_signal} + + def _transform(self, signal, nz_signal): + # Mix the instantiated noise floor into the signal; the addition does + # not modify nz_signal, so the transform can be applied repeatedly. + return signal + nz_signal + + +class BackgroundNoise(BaseTransform): + """Adds background noise from audio specified by a set of CSV files. + A valid CSV file, typically generated by + :py:func:`audiotools.data.preprocess.create_csv`, looks like: + + .. csv-table:: + :header: path + + room_tone/m6_script2_clean.wav + room_tone/m6_script2_cleanraw.wav + room_tone/m6_script2_ipad_balcony1.wav + room_tone/m6_script2_ipad_bedroom1.wav + room_tone/m6_script2_ipad_confroom1.wav + room_tone/m6_script2_ipad_confroom2.wav + room_tone/m6_script2_ipad_livingroom1.wav + room_tone/m6_script2_ipad_office1.wav + + .. 
note:: + All paths are relative to an environment variable called ``PATH_TO_DATA``, + so that CSV files are portable across machines where data may be + located in different places. + + This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix` + and :py:func:`audiotools.core.effects.EffectMixin.equalizer` under the + hood. + + Parameters + ---------- + snr : tuple, optional + Signal-to-noise ratio, by default ("uniform", 10.0, 30.0) + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + eq_amount : tuple, optional + Amount of equalization to apply, by default ("const", 1.0) + n_bands : int, optional + Number of bands in equalizer, by default 3 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + loudness_cutoff : float, optional + Loudness cutoff when loading from audio files, by default None + """ + + def __init__( + self, + snr: tuple = ("uniform", 10.0, 30.0), + sources: List[str] = None, + weights: List[float] = None, + eq_amount: tuple = ("const", 1.0), + n_bands: int = 3, + name: str = None, + prob: float = 1.0, + loudness_cutoff: float = None, + ): + super().__init__(name=name, prob=prob) + + self.snr = snr + self.eq_amount = eq_amount + self.n_bands = n_bands + self.loader = AudioLoader(sources, weights) + self.loudness_cutoff = loudness_cutoff + + def _instantiate(self, state: RandomState, signal: AudioSignal): + eq_amount = util.sample_from_dist(self.eq_amount, state) + eq = -eq_amount * state.rand(self.n_bands) + snr = util.sample_from_dist(self.snr, state) + + bg_signal = self.loader( + state, + signal.sample_rate, + duration=signal.signal_duration, + loudness_cutoff=self.loudness_cutoff, + num_channels=signal.num_channels, + )["signal"] + + return {"eq": eq, "bg_signal": bg_signal, "snr": snr} + + def _transform(self, signal, bg_signal, snr, eq): + # Clone bg_signal so that transform can be repeatedly applied + # to different signals with the same effect. + return signal.mix(bg_signal.clone(), snr, eq) + + +class CrossTalk(BaseTransform): + """Adds crosstalk between speakers, whose audio is drawn from a CSV file + that was produced via :py:func:`audiotools.data.preprocess.create_csv`. + + This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix` + under the hood. 
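+ + A minimal sketch, using the instantiate-then-transform pattern from the + :py:class:`BaseTransform` examples (the CSV path and ``signal`` are + assumptions): + + >>> transform = CrossTalk(sources=["tests/audio/spk.csv"]) + >>> kwargs = transform.instantiate(0, signal) + >>> output = transform(signal.clone(), **kwargs)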
+ + Parameters + ---------- + snr : tuple, optional + How loud the cross-talk speaker is relative to the original signal, + in dB, by default ("uniform", 0.0, 10.0) + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + loudness_cutoff : float, optional + Loudness cutoff when loading from audio files, by default -40 + """ + + def __init__( + self, + snr: tuple = ("uniform", 0.0, 10.0), + sources: List[str] = None, + weights: List[float] = None, + name: str = None, + prob: float = 1.0, + loudness_cutoff: float = -40, + ): + super().__init__(name=name, prob=prob) + + self.snr = snr + self.loader = AudioLoader(sources, weights) + self.loudness_cutoff = loudness_cutoff + + def _instantiate(self, state: RandomState, signal: AudioSignal): + snr = util.sample_from_dist(self.snr, state) + crosstalk_signal = self.loader( + state, + signal.sample_rate, + duration=signal.signal_duration, + loudness_cutoff=self.loudness_cutoff, + num_channels=signal.num_channels, + )["signal"] + + return {"crosstalk_signal": crosstalk_signal, "snr": snr} + + def _transform(self, signal, crosstalk_signal, snr): + # Clone crosstalk_signal so that transform can be repeatedly applied + # to different signals with the same effect. + loudness = signal.loudness() + mix = signal.mix(crosstalk_signal.clone(), snr) + mix.normalize(loudness) + return mix + + +class RoomImpulseResponse(BaseTransform): + """Convolves signal with a room impulse response, at a specified + direct-to-reverberant ratio, with equalization applied. Room impulse + response data is drawn from a CSV file that was produced via + :py:func:`audiotools.data.preprocess.create_csv`. + + This transform calls :py:func:`audiotools.core.effects.EffectMixin.apply_ir` + under the hood. 
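+ + A minimal sketch (the CSV path matches the one used in the + :py:class:`BaseTransform` examples; ``signal`` is assumed to be loaded + elsewhere): + + >>> transform = RoomImpulseResponse(sources=["tests/audio/irs.csv"]) + >>> kwargs = transform.instantiate(0, signal) + >>> output = transform(signal.clone(), **kwargs)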
+ + Parameters + ---------- + drr : tuple, optional + Direct-to-reverberant ratio, in dB, by default ("uniform", 0.0, 30.0) + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + eq_amount : tuple, optional + Amount of equalization to apply, by default ("const", 1.0) + n_bands : int, optional + Number of bands in equalizer, by default 6 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + use_original_phase : bool, optional + Whether or not to use the original phase, by default False + offset : float, optional + Offset from each impulse response file to use, by default 0.0 + duration : float, optional + Duration of each impulse response, by default 1.0 + """ + + def __init__( + self, + drr: tuple = ("uniform", 0.0, 30.0), + sources: List[str] = None, + weights: List[float] = None, + eq_amount: tuple = ("const", 1.0), + n_bands: int = 6, + name: str = None, + prob: float = 1.0, + use_original_phase: bool = False, + offset: float = 0.0, + duration: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.drr = drr + self.eq_amount = eq_amount + self.n_bands = n_bands + self.use_original_phase = use_original_phase + + self.loader = AudioLoader(sources, weights) + self.offset = offset + self.duration = duration + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + eq_amount = util.sample_from_dist(self.eq_amount, state) + eq = -eq_amount * state.rand(self.n_bands) + drr = util.sample_from_dist(self.drr, state) + + ir_signal = self.loader( + state, + signal.sample_rate, + offset=self.offset, + duration=self.duration, + loudness_cutoff=None, + num_channels=signal.num_channels, + )["signal"] + ir_signal.zero_pad_to(signal.sample_rate) + + return {"eq": eq, "ir_signal": ir_signal, "drr": drr} + + def _transform(self, signal, ir_signal, drr, eq): + # Clone ir_signal so that transform can be repeatedly applied + # to different signals with the same effect. + return signal.apply_ir( + ir_signal.clone(), drr, eq, use_original_phase=self.use_original_phase + ) + + +class VolumeChange(BaseTransform): + """Changes the volume of the input signal. + + Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`. + + Parameters + ---------- + db : tuple, optional + Change in volume in decibels, by default ("uniform", -12.0, 0.0) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + db: tuple = ("uniform", -12.0, 0.0), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + self.db = db + + def _instantiate(self, state: RandomState): + return {"db": util.sample_from_dist(self.db, state)} + + def _transform(self, signal, db): + return signal.volume_change(db) + + +class VolumeNorm(BaseTransform): + """Normalizes the volume of the excerpt to a specified decibel. + + Uses :py:func:`audiotools.core.effects.EffectMixin.normalize`. 
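+ + For example (a minimal sketch; ``signal`` is assumed to be an + ``AudioSignal`` loaded elsewhere): + + >>> transform = VolumeNorm(db=("const", -24)) + >>> kwargs = transform.instantiate(0, signal) + >>> output = transform(signal.clone(), **kwargs)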
+ + Parameters + ---------- + db : tuple, optional + dB to normalize signal to, by default ("const", -24) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + db: tuple = ("const", -24), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.db = db + + def _instantiate(self, state: RandomState): + return {"db": util.sample_from_dist(self.db, state)} + + def _transform(self, signal, db): + return signal.normalize(db) + + +class GlobalVolumeNorm(BaseTransform): + """Similar to :py:func:`audiotools.data.transforms.VolumeNorm`, this + transform also normalizes the volume of a signal, but it uses + the volume of the entire audio file the loaded excerpt comes from, + rather than the volume of just the excerpt. The volume of the + entire audio file is expected in ``signal.metadata["loudness"]``. + If loading audio from a CSV generated by :py:func:`audiotools.data.preprocess.create_csv` + with ``loudness = True``, like the following: + + .. csv-table:: + :header: path,loudness + + daps/produced/f1_script1_produced.wav,-16.299999237060547 + daps/produced/f1_script2_produced.wav,-16.600000381469727 + daps/produced/f1_script3_produced.wav,-17.299999237060547 + daps/produced/f1_script4_produced.wav,-16.100000381469727 + daps/produced/f1_script5_produced.wav,-16.700000762939453 + daps/produced/f3_script1_produced.wav,-16.5 + + The ``AudioLoader`` will automatically load the loudness column into + the metadata of the signal. + + Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`. + + Parameters + ---------- + db : tuple, optional + dB to normalize signal to, by default ("const", -24) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + db: tuple = ("const", -24), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.db = db + + def _instantiate(self, state: RandomState, signal: AudioSignal): + if "loudness" not in signal.metadata: + db_change = 0.0 + elif float(signal.metadata["loudness"]) == float("-inf"): + db_change = 0.0 + else: + db = util.sample_from_dist(self.db, state) + db_change = db - float(signal.metadata["loudness"]) + + return {"db": db_change} + + def _transform(self, signal, db): + return signal.volume_change(db) + + +class Silence(BaseTransform): + """Zeros out the signal with some probability. + + Parameters + ---------- + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 0.1 + """ + + def __init__(self, name: str = None, prob: float = 0.1): + super().__init__(name=name, prob=prob) + + def _transform(self, signal): + _loudness = signal._loudness + signal = AudioSignal( + torch.zeros_like(signal.audio_data), + sample_rate=signal.sample_rate, + stft_params=signal.stft_params, + ) + # So that the amount of noise added is as if it wasn't silenced. + # TODO: improve this hack + signal._loudness = _loudness + + return signal + + +class LowPass(BaseTransform): + """Applies a LowPass filter. 
+ + Uses :py:func:`audiotools.core.dsp.DSPMixin.low_pass`. + + Parameters + ---------- + cutoff : tuple, optional + Cutoff frequency distribution, + by default ``("choice", [4000, 8000, 16000])`` + zeros : int, optional + Number of zero-crossings in filter, argument to + ``julius.LowPassFilters``, by default 51 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + cutoff: tuple = ("choice", [4000, 8000, 16000]), + zeros: int = 51, + name: str = None, + prob: float = 1, + ): + super().__init__(name=name, prob=prob) + + self.cutoff = cutoff + self.zeros = zeros + + def _instantiate(self, state: RandomState): + return {"cutoff": util.sample_from_dist(self.cutoff, state)} + + def _transform(self, signal, cutoff): + return signal.low_pass(cutoff, zeros=self.zeros) + + +class HighPass(BaseTransform): + """Applies a HighPass filter. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.high_pass`. + + Parameters + ---------- + cutoff : tuple, optional + Cutoff frequency distribution, + by default ``("choice", [50, 100, 250, 500, 1000])`` + zeros : int, optional + Number of zero-crossings in filter, argument to + ``julius.HighPassFilters``, by default 51 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + cutoff: tuple = ("choice", [50, 100, 250, 500, 1000]), + zeros: int = 51, + name: str = None, + prob: float = 1, + ): + super().__init__(name=name, prob=prob) + + self.cutoff = cutoff + self.zeros = zeros + + def _instantiate(self, state: RandomState): + return {"cutoff": util.sample_from_dist(self.cutoff, state)} + + def _transform(self, signal, cutoff): + return signal.high_pass(cutoff, zeros=self.zeros) + + +class RescaleAudio(BaseTransform): + """Rescales the audio so it is in between ``-val`` and ``val`` + only if the original audio exceeds those bounds. Useful if + transforms have caused the audio to clip. + + Uses :py:func:`audiotools.core.effects.EffectMixin.ensure_max_of_audio`. + + Parameters + ---------- + val : float, optional + Max absolute value of signal, by default 1.0 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__(self, val: float = 1.0, name: str = None, prob: float = 1): + super().__init__(name=name, prob=prob) + + self.val = val + + def _transform(self, signal): + return signal.ensure_max_of_audio(self.val) + + +class ShiftPhase(SpectralTransform): + """Shifts the phase of the audio. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`. 
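+ + A minimal sketch (``signal`` is assumed to be an ``AudioSignal`` loaded + elsewhere): + + >>> transform = ShiftPhase(shift=("const", np.pi / 2)) + >>> kwargs = transform.instantiate(0, signal) + >>> output = transform(signal.clone(), **kwargs)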
+ + Parameters + ---------- + shift : tuple, optional + How much to shift phase by, by default ("uniform", -np.pi, np.pi) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + shift: tuple = ("uniform", -np.pi, np.pi), + name: str = None, + prob: float = 1, + ): + super().__init__(name=name, prob=prob) + self.shift = shift + + def _instantiate(self, state: RandomState): + return {"shift": util.sample_from_dist(self.shift, state)} + + def _transform(self, signal, shift): + return signal.shift_phase(shift) + + +class InvertPhase(ShiftPhase): + """Inverts the phase of the audio. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`. + + Parameters + ---------- + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__(self, name: str = None, prob: float = 1): + super().__init__(shift=("const", np.pi), name=name, prob=prob) + + +class CorruptPhase(SpectralTransform): + """Corrupts the phase of the audio. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.corrupt_phase`. + + Parameters + ---------- + scale : tuple, optional + How much to corrupt phase by, by default ("uniform", 0, np.pi) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, scale: tuple = ("uniform", 0, np.pi), name: str = None, prob: float = 1 + ): + super().__init__(name=name, prob=prob) + self.scale = scale + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + scale = util.sample_from_dist(self.scale, state) + corruption = state.normal(scale=scale, size=signal.phase.shape[1:]) + return {"corruption": corruption.astype("float32")} + + def _transform(self, signal, corruption): + return signal.shift_phase(shift=corruption) + + +class FrequencyMask(SpectralTransform): + """Masks a band of frequencies at a center frequency + from the audio. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_frequencies`. 
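+ + A minimal sketch (``signal`` is assumed to be an ``AudioSignal`` loaded + elsewhere): + + >>> transform = FrequencyMask(f_center=("const", 0.5), f_width=("const", 0.1)) + >>> kwargs = transform.instantiate(0, signal) + >>> output = transform(signal.clone(), **kwargs)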
+ + Parameters + ---------- + f_center : tuple, optional + Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0) + f_width : tuple, optional + Width of zero'd out band, by default ("const", 0.1) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + f_center: tuple = ("uniform", 0.0, 1.0), + f_width: tuple = ("const", 0.1), + name: str = None, + prob: float = 1, + ): + super().__init__(name=name, prob=prob) + self.f_center = f_center + self.f_width = f_width + + def _instantiate(self, state: RandomState, signal: AudioSignal): + f_center = util.sample_from_dist(self.f_center, state) + f_width = util.sample_from_dist(self.f_width, state) + + fmin = max(f_center - (f_width / 2), 0.0) + fmax = min(f_center + (f_width / 2), 1.0) + + fmin_hz = (signal.sample_rate / 2) * fmin + fmax_hz = (signal.sample_rate / 2) * fmax + + return {"fmin_hz": fmin_hz, "fmax_hz": fmax_hz} + + def _transform(self, signal, fmin_hz: float, fmax_hz: float): + return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz) + + +class TimeMask(SpectralTransform): + """Masks out contiguous time-steps from signal. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_timesteps`. + + Parameters + ---------- + t_center : tuple, optional + Center time in terms of 0.0 and 1.0 (duration of signal), + by default ("uniform", 0.0, 1.0) + t_width : tuple, optional + Width of dropped out portion, by default ("const", 0.025) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + t_center: tuple = ("uniform", 0.0, 1.0), + t_width: tuple = ("const", 0.025), + name: str = None, + prob: float = 1, + ): + super().__init__(name=name, prob=prob) + self.t_center = t_center + self.t_width = t_width + + def _instantiate(self, state: RandomState, signal: AudioSignal): + t_center = util.sample_from_dist(self.t_center, state) + t_width = util.sample_from_dist(self.t_width, state) + + tmin = max(t_center - (t_width / 2), 0.0) + tmax = min(t_center + (t_width / 2), 1.0) + + tmin_s = signal.signal_duration * tmin + tmax_s = signal.signal_duration * tmax + return {"tmin_s": tmin_s, "tmax_s": tmax_s} + + def _transform(self, signal, tmin_s: float, tmax_s: float): + return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s) + + +class MaskLowMagnitudes(SpectralTransform): + """Masks low magnitude regions out of signal. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_low_magnitudes`. 
+ + Parameters + ---------- + db_cutoff : tuple, optional + Decibel value for which things below it will be masked away, + by default ("uniform", -10, 10) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + db_cutoff: tuple = ("uniform", -10, 10), + name: str = None, + prob: float = 1, + ): + super().__init__(name=name, prob=prob) + self.db_cutoff = db_cutoff + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + return {"db_cutoff": util.sample_from_dist(self.db_cutoff, state)} + + def _transform(self, signal, db_cutoff: float): + return signal.mask_low_magnitudes(db_cutoff) + + +class Smoothing(BaseTransform): + """Convolves the signal with a smoothing window. + + Uses :py:func:`audiotools.core.effects.EffectMixin.convolve`. + + Parameters + ---------- + window_type : tuple, optional + Type of window to use, by default ("const", "average") + window_length : tuple, optional + Length of smoothing window, by + default ("choice", [8, 16, 32, 64, 128, 256, 512]) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + window_type: tuple = ("const", "average"), + window_length: tuple = ("choice", [8, 16, 32, 64, 128, 256, 512]), + name: str = None, + prob: float = 1, + ): + super().__init__(name=name, prob=prob) + self.window_type = window_type + self.window_length = window_length + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + window_type = util.sample_from_dist(self.window_type, state) + window_length = util.sample_from_dist(self.window_length, state) + window = signal.get_window( + window_type=window_type, window_length=window_length, device="cpu" + ) + return {"window": AudioSignal(window, signal.sample_rate)} + + def _transform(self, signal, window): + sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values + sscale[sscale == 0.0] = 1.0 + + out = signal.convolve(window) + + oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values + oscale[oscale == 0.0] = 1.0 + + out = out * (sscale / oscale) + return out + + +class TimeNoise(TimeMask): + """Similar to :py:func:`audiotools.data.transforms.TimeMask`, but + replaces with noise instead of zeros. 
+ + Parameters + ---------- + t_center : tuple, optional + Center time in terms of 0.0 and 1.0 (duration of signal), + by default ("uniform", 0.0, 1.0) + t_width : tuple, optional + Width of dropped out portion, by default ("const", 0.025) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + t_center: tuple = ("uniform", 0.0, 1.0), + t_width: tuple = ("const", 0.025), + name: str = None, + prob: float = 1, + ): + super().__init__(t_center=t_center, t_width=t_width, name=name, prob=prob) + + def _transform(self, signal, tmin_s: float, tmax_s: float): + signal = signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s, val=0.0) + mag, phase = signal.magnitude, signal.phase + + mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase) + mask = (mag == 0.0) * (phase == 0.0) + + mag[mask] = mag_r[mask] + phase[mask] = phase_r[mask] + + signal.magnitude = mag + signal.phase = phase + return signal + + +class FrequencyNoise(FrequencyMask): + """Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but + replaces with noise instead of zeros. + + Parameters + ---------- + f_center : tuple, optional + Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0) + f_width : tuple, optional + Width of zero'd out band, by default ("const", 0.1) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + f_center: tuple = ("uniform", 0.0, 1.0), + f_width: tuple = ("const", 0.1), + name: str = None, + prob: float = 1, + ): + super().__init__(f_center=f_center, f_width=f_width, name=name, prob=prob) + + def _transform(self, signal, fmin_hz: float, fmax_hz: float): + signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz) + mag, phase = signal.magnitude, signal.phase + + mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase) + mask = (mag == 0.0) * (phase == 0.0) + + mag[mask] = mag_r[mask] + phase[mask] = phase_r[mask] + + signal.magnitude = mag + signal.phase = phase + return signal + + +class SpectralDenoising(Equalizer): + """Applies denoising algorithm detailed in + :py:func:`audiotools.ml.layers.spectral_gate.SpectralGate`, + using a randomly generated noise signal for denoising. 
+ + Parameters + ---------- + eq_amount : tuple, optional + Amount of eq to apply to noise signal, by default ("const", 1.0) + denoise_amount : tuple, optional + Amount to denoise by, by default ("uniform", 0.8, 1.0) + nz_volume : float, optional + Volume of noise to denoise with, by default -40 + n_bands : int, optional + Number of bands in equalizer, by default 6 + n_freq : int, optional + Number of frequency bins to smooth by, by default 3 + n_time : int, optional + Number of time bins to smooth by, by default 5 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + eq_amount: tuple = ("const", 1.0), + denoise_amount: tuple = ("uniform", 0.8, 1.0), + nz_volume: float = -40, + n_bands: int = 6, + n_freq: int = 3, + n_time: int = 5, + name: str = None, + prob: float = 1, + ): + super().__init__(eq_amount=eq_amount, n_bands=n_bands, name=name, prob=prob) + + self.nz_volume = nz_volume + self.denoise_amount = denoise_amount + self.spectral_gate = ml.layers.SpectralGate(n_freq, n_time) + + def _transform(self, signal, nz, eq, denoise_amount): + nz = nz.normalize(self.nz_volume).equalizer(eq) + self.spectral_gate = self.spectral_gate.to(signal.device) + signal = self.spectral_gate(signal, nz, denoise_amount) + return signal + + def _instantiate(self, state: RandomState): + kwargs = super()._instantiate(state) + kwargs["denoise_amount"] = util.sample_from_dist(self.denoise_amount, state) + kwargs["nz"] = AudioSignal(state.randn(22050), 44100) + return kwargs diff --git a/flowae/models/ldm/dac/audiotools/metrics/__init__.py b/flowae/models/ldm/dac/audiotools/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c8d2df61f94afae8e39e57abf156e8e4059a9e --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/metrics/__init__.py @@ -0,0 +1,6 @@ +""" +Functions for comparing AudioSignal objects to one another. +""" # fmt: skip +from . import distance +from . import quality +from . import spectral diff --git a/flowae/models/ldm/dac/audiotools/metrics/distance.py b/flowae/models/ldm/dac/audiotools/metrics/distance.py new file mode 100644 index 0000000000000000000000000000000000000000..ce78739bfc29f9ddc39b23063b4243ddac10adaf --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/metrics/distance.py @@ -0,0 +1,131 @@ +import torch +from torch import nn + +from .. import AudioSignal + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. Defaults + to comparing ``audio_data``, but any + attribute of an AudioSignal can be used. + + Parameters + ---------- + attribute : str, optional + Attribute of signal to compare, defaults to ``audio_data``. + weight : float, optional + Weight of this loss, defaults to 1.0. + """ + + def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): + self.attribute = attribute + self.weight = weight + super().__init__(**kwargs) + + def forward(self, x: AudioSignal, y: AudioSignal): + """ + Parameters + ---------- + x : AudioSignal + Estimate AudioSignal + y : AudioSignal + Reference AudioSignal + + Returns + ------- + torch.Tensor + L1 loss between AudioSignal attributes. 
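+
+        Examples
+        --------
+        A minimal sketch (assuming two loaded, equal-length signals):
+
+        >>> loss_fn = L1Loss()
+        >>> loss = loss_fn(estimate_signal, reference_signal)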
+ """ + if isinstance(x, AudioSignal): + x = getattr(x, self.attribute) + y = getattr(y, self.attribute) + return super().forward(x, y) + + +class SISDRLoss(nn.Module): + """ + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch + of estimated and reference audio signals or aligned features. + + Parameters + ---------- + scaling : int, optional + Whether to use scale-invariant (True) or + signal-to-noise ratio (False), by default True + reduction : str, optional + How to reduce across the batch (either 'mean', + 'sum', or none).], by default ' mean' + zero_mean : int, optional + Zero mean the references and estimates before + computing the loss, by default True + clip_min : int, optional + The minimum possible loss value. Helps network + to not focus on making already good examples better, by default None + weight : float, optional + Weight of this loss, defaults to 1.0. + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr diff --git a/flowae/models/ldm/dac/audiotools/metrics/quality.py b/flowae/models/ldm/dac/audiotools/metrics/quality.py new file mode 100644 index 0000000000000000000000000000000000000000..1608f25507082b49ccbf49289025a5a94a422808 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/metrics/quality.py @@ -0,0 +1,159 @@ +import os + +import numpy as np +import torch + +from .. import AudioSignal + + +def stoi( + estimates: AudioSignal, + references: AudioSignal, + extended: int = False, +): + """Short term objective intelligibility + Computes the STOI (See [1][2]) of a denoised signal compared to a clean + signal, The output is expected to have a monotonic relation with the + subjective speech-intelligibility, where a higher score denotes better + speech intelligibility. Uses pystoi under the hood. 
+
+    Parameters
+    ----------
+    estimates : AudioSignal
+        Denoised speech
+    references : AudioSignal
+        Clean original speech
+    extended : bool, optional
+        Whether to use the extended STOI described in [3], by default False
+
+    Returns
+    -------
+    Tensor[float]
+        Short time objective intelligibility measure between clean and
+        denoised speech
+
+    References
+    ----------
+    1. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
+    Objective Intelligibility Measure for Time-Frequency Weighted Noisy
+    Speech', ICASSP 2010, Dallas, Texas.
+    2. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
+    Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
+    IEEE Transactions on Audio, Speech, and Language Processing, 2011.
+    3. Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
+    Intelligibility of Speech Masked by Modulated Noise Maskers',
+    IEEE Transactions on Audio, Speech and Language Processing, 2016.
+    """
+    import pystoi
+
+    estimates = estimates.clone().to_mono()
+    references = references.clone().to_mono()
+
+    stois = []
+    for i in range(estimates.batch_size):
+        _stoi = pystoi.stoi(
+            references.audio_data[i, 0].detach().cpu().numpy(),
+            estimates.audio_data[i, 0].detach().cpu().numpy(),
+            references.sample_rate,
+            extended=extended,
+        )
+        stois.append(_stoi)
+    return torch.from_numpy(np.array(stois))
+
+
+def pesq(
+    estimates: AudioSignal,
+    references: AudioSignal,
+    mode: str = "wb",
+    target_sr: float = 16000,
+):
+    """Perceptual Evaluation of Speech Quality (PESQ). Computes the
+    ITU-T P.862.2 score of a degraded signal against a reference.
+    Uses the ``pesq`` package under the hood.
+
+    Parameters
+    ----------
+    estimates : AudioSignal
+        Degraded AudioSignal
+    references : AudioSignal
+        Reference AudioSignal
+    mode : str, optional
+        'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
+    target_sr : float, optional
+        Target sample rate, by default 16000
+
+    Returns
+    -------
+    Tensor[float]
+        PESQ score: P.862.2 Prediction (MOS-LQO)
+    """
+    from pesq import pesq as pesq_fn
+
+    estimates = estimates.clone().to_mono().resample(target_sr)
+    references = references.clone().to_mono().resample(target_sr)
+
+    pesqs = []
+    for i in range(estimates.batch_size):
+        _pesq = pesq_fn(
+            estimates.sample_rate,
+            references.audio_data[i, 0].detach().cpu().numpy(),
+            estimates.audio_data[i, 0].detach().cpu().numpy(),
+            mode,
+        )
+        pesqs.append(_pesq)
+    return torch.from_numpy(np.array(pesqs))
+
+
+def visqol(
+    estimates: AudioSignal,
+    references: AudioSignal,
+    mode: str = "audio",
+):  # pragma: no cover
+    """ViSQOL score.
+ + Parameters + ---------- + estimates : AudioSignal + Degraded AudioSignal + references : AudioSignal + Reference AudioSignal + mode : str, optional + 'audio' or 'speech', by default 'audio' + + Returns + ------- + Tensor[float] + ViSQOL score (MOS-LQO) + """ + from visqol import visqol_lib_py + from visqol.pb2 import visqol_config_pb2 + from visqol.pb2 import similarity_result_pb2 + + config = visqol_config_pb2.VisqolConfig() + if mode == "audio": + target_sr = 48000 + config.options.use_speech_scoring = False + svr_model_path = "libsvm_nu_svr_model.txt" + elif mode == "speech": + target_sr = 16000 + config.options.use_speech_scoring = True + svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite" + else: + raise ValueError(f"Unrecognized mode: {mode}") + config.audio.sample_rate = target_sr + config.options.svr_model_path = os.path.join( + os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path + ) + + api = visqol_lib_py.VisqolApi() + api.Create(config) + + estimates = estimates.clone().to_mono().resample(target_sr) + references = references.clone().to_mono().resample(target_sr) + + visqols = [] + for i in range(estimates.batch_size): + _visqol = api.Measure( + references.audio_data[i, 0].detach().cpu().numpy().astype(float), + estimates.audio_data[i, 0].detach().cpu().numpy().astype(float), + ) + visqols.append(_visqol.moslqo) + return torch.from_numpy(np.array(visqols)) diff --git a/flowae/models/ldm/dac/audiotools/metrics/spectral.py b/flowae/models/ldm/dac/audiotools/metrics/spectral.py new file mode 100644 index 0000000000000000000000000000000000000000..7ce953882efa4e5b777a0348bee6c1be39279a6c --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/metrics/spectral.py @@ -0,0 +1,247 @@ +import typing +from typing import List + +import numpy as np +from torch import nn + +from .. import AudioSignal +from .. import STFTParams + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. + + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. 
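+
+    Examples
+    --------
+    A minimal sketch (assuming two loaded AudioSignals of equal length):
+
+    >>> loss_fn = MultiScaleSTFTLoss()
+    >>> loss = loss_fn(estimate_signal, reference_signal)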
+ """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. 
+ """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class PhaseLoss(nn.Module): + """Difference between phase spectrograms. + + Parameters + ---------- + window_length : int, optional + Length of STFT window, by default 2048 + hop_length : int, optional + Hop length of STFT window, by default 512 + weight : float, optional + Weight of loss, by default 1.0 + """ + + def __init__( + self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0 + ): + super().__init__() + + self.weight = weight + self.stft_params = STFTParams(window_length, hop_length) + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes phase loss between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Phase loss. + """ + s = self.stft_params + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + + # Take circular difference + diff = x.phase - y.phase + diff[diff < -np.pi] += 2 * np.pi + diff[diff > np.pi] -= -2 * np.pi + + # Scale true magnitude to weights in [0, 1] + x_min, x_max = x.magnitude.min(), x.magnitude.max() + weights = (x.magnitude - x_min) / (x_max - x_min) + + # Take weighted mean of all phase errors + loss = ((weights * diff) ** 2).mean() + return loss diff --git a/flowae/models/ldm/dac/audiotools/ml/__init__.py b/flowae/models/ldm/dac/audiotools/ml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9ca69977bad57e1a92b7551d601d9224ee854ab --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/ml/__init__.py @@ -0,0 +1,5 @@ +from . import decorators +from . import layers +from .accelerator import Accelerator +from .experiment import Experiment +from .layers import BaseModel diff --git a/flowae/models/ldm/dac/audiotools/ml/accelerator.py b/flowae/models/ldm/dac/audiotools/ml/accelerator.py new file mode 100644 index 0000000000000000000000000000000000000000..37c6e8d954f112b8b0aff257894e62add8874e30 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/ml/accelerator.py @@ -0,0 +1,184 @@ +import os +import typing + +import torch +import torch.distributed as dist +from torch.nn.parallel import DataParallel +from torch.nn.parallel import DistributedDataParallel + +from ..data.datasets import ResumableDistributedSampler as DistributedSampler +from ..data.datasets import ResumableSequentialSampler as SequentialSampler + + +class Accelerator: # pragma: no cover + """This class is used to prepare models and dataloaders for + usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to + prepare the respective objects. In the case of models, they are moved to + the appropriate GPU and SyncBatchNorm is applied to them. In the case of + dataloaders, a sampler is created and the dataloader is initialized with + that sampler. + + If the world size is 1, prepare_model and prepare_dataloader are + no-ops. 
If the environment variable ``LOCAL_RANK`` is not set, then the + script was launched without ``torchrun``, and ``DataParallel`` + will be used instead of ``DistributedDataParallel`` (not recommended), if + the world size (number of GPUs) is greater than 1. + + Parameters + ---------- + amp : bool, optional + Whether or not to enable automatic mixed precision, by default False + """ + + def __init__(self, amp: bool = False): + local_rank = os.getenv("LOCAL_RANK", None) + self.world_size = torch.cuda.device_count() + + self.use_ddp = self.world_size > 1 and local_rank is not None + self.use_dp = self.world_size > 1 and local_rank is None + self.device = "cpu" if self.world_size == 0 else "cuda" + + if self.use_ddp: + local_rank = int(local_rank) + dist.init_process_group( + "nccl", + init_method="env://", + world_size=self.world_size, + rank=local_rank, + ) + + self.local_rank = 0 if local_rank is None else local_rank + self.amp = amp + + class DummyScaler: + def __init__(self): + pass + + def step(self, optimizer): + optimizer.step() + + def scale(self, loss): + return loss + + def unscale_(self, optimizer): + return optimizer + + def update(self): + pass + + self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler() + self.device_ctx = ( + torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None + ) + + def __enter__(self): + if self.device_ctx is not None: + self.device_ctx.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.device_ctx is not None: + self.device_ctx.__exit__(exc_type, exc_value, traceback) + + def prepare_model(self, model: torch.nn.Module, **kwargs): + """Prepares model for DDP or DP. The model is moved to + the device of the correct rank. + + Parameters + ---------- + model : torch.nn.Module + Model that is converted for DDP or DP. + + Returns + ------- + torch.nn.Module + Wrapped model, or original model if DDP and DP are turned off. + """ + model = model.to(self.device) + if self.use_ddp: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = DistributedDataParallel( + model, device_ids=[self.local_rank], **kwargs + ) + elif self.use_dp: + model = DataParallel(model, **kwargs) + return model + + # Automatic mixed-precision utilities + def autocast(self, *args, **kwargs): + """Context manager for autocasting. Arguments + go to ``torch.cuda.amp.autocast``. + """ + return torch.cuda.amp.autocast(self.amp, *args, **kwargs) + + def backward(self, loss: torch.Tensor): + """Backwards pass, after scaling the loss if ``amp`` is + enabled. + + Parameters + ---------- + loss : torch.Tensor + Loss value. + """ + self.scaler.scale(loss).backward() + + def step(self, optimizer: torch.optim.Optimizer): + """Steps the optimizer, using a ``scaler`` if ``amp`` is + enabled. + + Parameters + ---------- + optimizer : torch.optim.Optimizer + Optimizer to step forward. + """ + self.scaler.step(optimizer) + + def update(self): + """Updates the scale factor.""" + self.scaler.update() + + def prepare_dataloader( + self, dataset: typing.Iterable, start_idx: int = None, **kwargs + ): + """Wraps a dataset with a DataLoader, using the correct sampler if DDP is + enabled. + + Parameters + ---------- + dataset : typing.Iterable + Dataset to build Dataloader around. 
+
+        start_idx : int, optional
+            Start index of sampler, useful if resuming from some epoch,
+            by default None
+
+        Returns
+        -------
+        torch.utils.data.DataLoader
+            DataLoader wrapped around ``dataset``, using the appropriate
+            sampler and, under DDP, a rescaled batch size and worker count.
+        """
+
+        if self.use_ddp:
+            sampler = DistributedSampler(
+                dataset,
+                start_idx,
+                num_replicas=self.world_size,
+                rank=self.local_rank,
+            )
+            if "num_workers" in kwargs:
+                kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1)
+            kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1)
+        else:
+            sampler = SequentialSampler(dataset, start_idx)
+
+        dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)
+        return dataloader
+
+    @staticmethod
+    def unwrap(model):
+        """Unwraps the model if it was wrapped in DDP or DP, otherwise
+        just returns the model. Use this to unwrap the model returned by
+        :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
+        """
+        if hasattr(model, "module"):
+            return model.module
+        return model
diff --git a/flowae/models/ldm/dac/audiotools/ml/decorators.py b/flowae/models/ldm/dac/audiotools/ml/decorators.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a435b06c47a48dc3600fa54ac092006f5c5bb27
--- /dev/null
+++ b/flowae/models/ldm/dac/audiotools/ml/decorators.py
@@ -0,0 +1,441 @@
+import math
+import os
+import time
+from collections import defaultdict
+from functools import wraps
+
+import torch
+import torch.distributed as dist
+from rich import box
+from rich.console import Console
+from rich.console import Group
+from rich.live import Live
+from rich.markdown import Markdown
+from rich.padding import Padding
+from rich.panel import Panel
+from rich.progress import BarColumn
+from rich.progress import Progress
+from rich.progress import SpinnerColumn
+from rich.progress import TimeElapsedColumn
+from rich.progress import TimeRemainingColumn
+from rich.rule import Rule
+from rich.table import Table
+from torch.utils.tensorboard import SummaryWriter
+
+
+# This is here so that the history can be pickled.
+def default_list():
+    return []
+
+
+class Mean:
+    """Keeps track of the running mean, along with the latest
+    value.
+    """
+
+    def __init__(self):
+        self.reset()
+
+    def __call__(self):
+        mean = self.total / max(self.count, 1)
+        return mean
+
+    def reset(self):
+        self.count = 0
+        self.total = 0
+
+    def update(self, val):
+        if math.isfinite(val):
+            self.count += 1
+            self.total += val
+
+
+def when(condition):
+    """Runs a function only when the condition is met. The condition is
+    a function that is run.
+
+    Parameters
+    ----------
+    condition : Callable
+        Function to run to check whether or not to run the decorated
+        function.
+
+    Example
+    -------
+    Checkpoint only runs every 100 iterations, and only if the
+    local rank is 0.
+
+    >>> i = 0
+    >>> rank = 0
+    >>>
+    >>> @when(lambda: i % 100 == 0 and rank == 0)
+    >>> def checkpoint():
+    >>>     print("Saving to /runs/exp1")
+    >>>
+    >>> for i in range(1000):
+    >>>     checkpoint()
+
+    """
+
+    def decorator(fn):
+        @wraps(fn)
+        def decorated(*args, **kwargs):
+            if condition():
+                return fn(*args, **kwargs)
+
+        return decorated
+
+    return decorator
+
+
+def timer(prefix: str = "time"):
+    """Adds execution time to the output dictionary of the decorated
+    function. The function decorated by this must output a dictionary.
+    The key added will follow the form "[prefix]/[name_of_function]"
+
+    Parameters
+    ----------
+    prefix : str, optional
+        The key added will follow the form "[prefix]/[name_of_function]",
+        by default "time".
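+
+    Example
+    -------
+    A small sketch; the decorated function must return a dictionary:
+
+    >>> @timer()
+    >>> def step():
+    >>>     return {"loss": 0.0}
+    >>>
+    >>> out = step()  # out now also contains the key "time/step"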
+ """ + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + s = time.perf_counter() + output = fn(*args, **kwargs) + assert isinstance(output, dict) + e = time.perf_counter() + output[f"{prefix}/{fn.__name__}"] = e - s + return output + + return decorated + + return decorator + + +class Tracker: + """ + A tracker class that helps to monitor the progress of training and logging the metrics. + + Attributes + ---------- + metrics : dict + A dictionary containing the metrics for each label. + history : dict + A dictionary containing the history of metrics for each label. + writer : SummaryWriter + A SummaryWriter object for logging the metrics. + rank : int + The rank of the current process. + step : int + The current step of the training. + tasks : dict + A dictionary containing the progress bars and tables for each label. + pbar : Progress + A progress bar object for displaying the progress. + consoles : list + A list of console objects for logging. + live : Live + A Live object for updating the display live. + + Methods + ------- + print(msg: str) + Prints the given message to all consoles. + update(label: str, fn_name: str) + Updates the progress bar and table for the given label. + done(label: str, title: str) + Resets the progress bar and table for the given label and prints the final result. + track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ) + A decorator for tracking the progress and metrics of a function. + log(label: str, value_type: str = "value", history: bool = True) + A decorator for logging the metrics of a function. + is_best(label: str, key: str) -> bool + Checks if the latest value of the given key in the label is the best so far. + state_dict() -> dict + Returns a dictionary containing the state of the tracker. + load_state_dict(state_dict: dict) -> Tracker + Loads the state of the tracker from the given state dictionary. + """ + + def __init__( + self, + writer: SummaryWriter = None, + log_file: str = None, + rank: int = 0, + console_width: int = 100, + step: int = 0, + ): + """ + Initializes the Tracker object. + + Parameters + ---------- + writer : SummaryWriter, optional + A SummaryWriter object for logging the metrics, by default None. + log_file : str, optional + The path to the log file, by default None. + rank : int, optional + The rank of the current process, by default 0. + console_width : int, optional + The width of the console, by default 100. + step : int, optional + The current step of the training, by default 0. + """ + self.metrics = {} + self.history = {} + self.writer = writer + self.rank = rank + self.step = step + + # Create progress bars etc. + self.tasks = {} + self.pbar = Progress( + SpinnerColumn(), + "[progress.description]{task.description}", + "{task.completed}/{task.total}", + BarColumn(), + TimeElapsedColumn(), + "/", + TimeRemainingColumn(), + ) + self.consoles = [Console(width=console_width)] + self.live = Live(console=self.consoles[0], refresh_per_second=10) + if log_file is not None: + self.consoles.append(Console(width=console_width, file=open(log_file, "a"))) + + def print(self, msg): + """ + Prints the given message to all consoles. + + Parameters + ---------- + msg : str + The message to be printed. + """ + if self.rank == 0: + for c in self.consoles: + c.log(msg) + + def update(self, label, fn_name): + """ + Updates the progress bar and table for the given label. 
+ + Parameters + ---------- + label : str + The label of the progress bar and table to be updated. + fn_name : str + The name of the function associated with the label. + """ + if self.rank == 0: + self.pbar.advance(self.tasks[label]["pbar"]) + + # Create table + table = Table(title=label, expand=True, box=box.MINIMAL) + table.add_column("key", style="cyan") + table.add_column("value", style="bright_blue") + table.add_column("mean", style="bright_green") + + keys = self.metrics[label]["value"].keys() + for k in keys: + value = self.metrics[label]["value"][k] + mean = self.metrics[label]["mean"][k]() + table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}") + + self.tasks[label]["table"] = table + tables = [t["table"] for t in self.tasks.values()] + group = Group(*tables, self.pbar) + self.live.update( + Group( + Padding("", (0, 0)), + Rule(f"[italic]{fn_name}()", style="white"), + Padding("", (0, 0)), + Panel.fit( + group, padding=(0, 5), title="[b]Progress", border_style="blue" + ), + ) + ) + + def done(self, label: str, title: str): + """ + Resets the progress bar and table for the given label and prints the final result. + + Parameters + ---------- + label : str + The label of the progress bar and table to be reset. + title : str + The title to be displayed when printing the final result. + """ + for label in self.metrics: + for v in self.metrics[label]["mean"].values(): + v.reset() + + if self.rank == 0: + self.pbar.reset(self.tasks[label]["pbar"]) + tables = [t["table"] for t in self.tasks.values()] + group = Group(Markdown(f"# {title}"), *tables, self.pbar) + self.print(group) + + def track( + self, + label: str, + length: int, + completed: int = 0, + op: dist.ReduceOp = dist.ReduceOp.AVG, + ddp_active: bool = "LOCAL_RANK" in os.environ, + ): + """ + A decorator for tracking the progress and metrics of a function. + + Parameters + ---------- + label : str + The label to be associated with the progress and metrics. + length : int + The total number of iterations to be completed. + completed : int, optional + The number of iterations already completed, by default 0. + op : dist.ReduceOp, optional + The reduce operation to be used, by default dist.ReduceOp.AVG. + ddp_active : bool, optional + Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ. 
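+
+        Example
+        -------
+        A sketch of wrapping a training step (assuming a ``tracker``
+        instance and a dataloader named ``loader``):
+
+        >>> @tracker.track("train", length=len(loader))
+        >>> def train_step(batch):
+        >>>     return {"loss": 0.0}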
+ """ + self.tasks[label] = { + "pbar": self.pbar.add_task( + f"[white]Iteration ({label})", total=length, completed=completed + ), + "table": Table(), + } + self.metrics[label] = { + "value": defaultdict(), + "mean": defaultdict(lambda: Mean()), + } + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + output = fn(*args, **kwargs) + if not isinstance(output, dict): + self.update(label, fn.__name__) + return output + # Collect across all DDP processes + scalar_keys = [] + for k, v in output.items(): + if isinstance(v, (int, float)): + v = torch.tensor([v]) + if not torch.is_tensor(v): + continue + if ddp_active and v.is_cuda: # pragma: no cover + dist.all_reduce(v, op=op) + output[k] = v.detach() + if torch.numel(v) == 1: + scalar_keys.append(k) + output[k] = v.item() + + # Save the outputs to tracker + for k, v in output.items(): + if k not in scalar_keys: + continue + self.metrics[label]["value"][k] = v + # Update the running mean + self.metrics[label]["mean"][k].update(v) + + self.update(label, fn.__name__) + return output + + return decorated + + return decorator + + def log(self, label: str, value_type: str = "value", history: bool = True): + """ + A decorator for logging the metrics of a function. + + Parameters + ---------- + label : str + The label to be associated with the logging. + value_type : str, optional + The type of value to be logged, by default "value". + history : bool, optional + Whether to save the history of the metrics, by default True. + """ + assert value_type in ["mean", "value"] + if history: + if label not in self.history: + self.history[label] = defaultdict(default_list) + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + output = fn(*args, **kwargs) + if self.rank == 0: + nonlocal value_type, label + metrics = self.metrics[label][value_type] + for k, v in metrics.items(): + v = v() if isinstance(v, Mean) else v + if self.writer is not None: + # self.writer.add_scalar(f"{k}/{label}", v, self.step) + self.writer.log_metric(f"{k}_{label}", v, step=self.step) + if label in self.history: + self.history[label][k].append(v) + + if label in self.history: + self.history[label]["step"].append(self.step) + + return output + + return decorated + + return decorator + + def is_best(self, label, key): + """ + Checks if the latest value of the given key in the label is the best so far. + + Parameters + ---------- + label : str + The label of the metrics to be checked. + key : str + The key of the metric to be checked. + + Returns + ------- + bool + True if the latest value is the best so far, otherwise False. + """ + return self.history[label][key][-1] == min(self.history[label][key]) + + def state_dict(self): + """ + Returns a dictionary containing the state of the tracker. + + Returns + ------- + dict + A dictionary containing the history and step of the tracker. + """ + return {"history": self.history, "step": self.step} + + def load_state_dict(self, state_dict): + """ + Loads the state of the tracker from the given state dictionary. + + Parameters + ---------- + state_dict : dict + A dictionary containing the history and step of the tracker. + + Returns + ------- + Tracker + The tracker object with the loaded state. 
+ """ + self.history = state_dict["history"] + self.step = state_dict["step"] + return self diff --git a/flowae/models/ldm/dac/audiotools/ml/experiment.py b/flowae/models/ldm/dac/audiotools/ml/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..62833d0f8f80dcdf496a1a5d2785ef666e0a15b6 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/ml/experiment.py @@ -0,0 +1,90 @@ +""" +Useful class for Experiment tracking, and ensuring code is +saved alongside files. +""" # fmt: skip +import datetime +import os +import shlex +import shutil +import subprocess +import typing +from pathlib import Path + +import randomname + + +class Experiment: + """This class contains utilities for managing experiments. + It is a context manager, that when you enter it, changes + your directory to a specified experiment folder (which + optionally can have an automatically generated experiment + name, or a specified one), and changes the CUDA device used + to the specified device (or devices). + + Parameters + ---------- + exp_directory : str + Folder where all experiments are saved, by default "runs/". + exp_name : str, optional + Name of the experiment, by default uses the current time, date, and + hostname to save. + """ + + def __init__( + self, + exp_directory: str = "runs/", + exp_name: str = None, + ): + if exp_name is None: + exp_name = self.generate_exp_name() + exp_dir = Path(exp_directory) / exp_name + exp_dir.mkdir(parents=True, exist_ok=True) + + self.exp_dir = exp_dir + self.exp_name = exp_name + self.git_tracked_files = ( + subprocess.check_output( + shlex.split("git ls-tree --full-tree --name-only -r HEAD") + ) + .decode("utf-8") + .splitlines() + ) + self.parent_directory = Path(".").absolute() + + def __enter__(self): + self.prev_dir = os.getcwd() + os.chdir(self.exp_dir) + return self + + def __exit__(self, exc_type, exc_value, traceback): + os.chdir(self.prev_dir) + + @staticmethod + def generate_exp_name(): + """Generates a random experiment name based on the date + and a randomly generated adjective-noun tuple. + + Returns + ------- + str + Randomly generated experiment name. + """ + date = datetime.datetime.now().strftime("%y%m%d") + name = f"{date}-{randomname.get_name()}" + return name + + def snapshot(self, filter_fn: typing.Callable = lambda f: True): + """Captures a full snapshot of all the files tracked by git at the time + the experiment is run. It also captures the diff against the committed + code as a separate file. 
+ + Parameters + ---------- + filter_fn : typing.Callable, optional + Function that can be used to exclude some files + from the snapshot, by default accepts all files + """ + for f in self.git_tracked_files: + if filter_fn(f): + Path(f).parent.mkdir(parents=True, exist_ok=True) + shutil.copyfile(self.parent_directory / f, f) diff --git a/flowae/models/ldm/dac/audiotools/ml/layers/__init__.py b/flowae/models/ldm/dac/audiotools/ml/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92a016cab2ddf06bf5dadfae241b7e5d9def4878 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/ml/layers/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseModel +from .spectral_gate import SpectralGate diff --git a/flowae/models/ldm/dac/audiotools/ml/layers/base.py b/flowae/models/ldm/dac/audiotools/ml/layers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b82c96cdd7336ca6b8ed6fc7f0192d69a8e998dd --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/ml/layers/base.py @@ -0,0 +1,328 @@ +import inspect +import shutil +import tempfile +import typing +from pathlib import Path + +import torch +from torch import nn + + +class BaseModel(nn.Module): + """This is a class that adds useful save/load functionality to a + ``torch.nn.Module`` object. ``BaseModel`` objects can be saved + as ``torch.package`` easily, making them super easy to port between + machines without requiring a ton of dependencies. Files can also be + saved as just weights, in the standard way. + + >>> class Model(ml.BaseModel): + >>> def __init__(self, arg1: float = 1.0): + >>> super().__init__() + >>> self.arg1 = arg1 + >>> self.linear = nn.Linear(1, 1) + >>> + >>> def forward(self, x): + >>> return self.linear(x) + >>> + >>> model1 = Model() + >>> + >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f: + >>> model1.save( + >>> f.name, + >>> ) + >>> model2 = Model.load(f.name) + >>> out2 = seed_and_run(model2, x) + >>> assert torch.allclose(out1, out2) + >>> + >>> model1.save(f.name, package=True) + >>> model2 = Model.load(f.name) + >>> model2.save(f.name, package=False) + >>> model3 = Model.load(f.name) + >>> out3 = seed_and_run(model3, x) + >>> + >>> with tempfile.TemporaryDirectory() as d: + >>> model1.save_to_folder(d, {"data": 1.0}) + >>> Model.load_from_folder(d) + + """ + + EXTERN = [ + "audiotools.**", + "tqdm", + "__main__", + "numpy.**", + "julius.**", + "torchaudio.**", + "scipy.**", + "einops", + ] + """Names of libraries that are external to the torch.package saving mechanism. + Source code from these libraries will not be packaged into the model. This can + be edited by the user of this class by editing ``model.EXTERN``.""" + INTERN = [] + """Names of libraries that are internal to the torch.package saving mechanism. + Source code from these libraries will be saved alongside the model.""" + + def save( + self, + path: str, + metadata: dict = None, + package: bool = True, + intern: list = [], + extern: list = [], + mock: list = [], + ): + """Saves the model, either as a torch package, or just as + weights, alongside some specified metadata. + + Parameters + ---------- + path : str + Path to save model to. 
+
+        metadata : dict, optional
+            Any metadata to save alongside the model,
+            by default None
+        package : bool, optional
+            Whether to use ``torch.package`` to save the model in
+            a format that is portable, by default True
+        intern : list, optional
+            List of additional libraries that are internal
+            to the model, used with torch.package, by default []
+        extern : list, optional
+            List of additional libraries that are external to
+            the model, used with torch.package, by default []
+        mock : list, optional
+            List of libraries to mock, used with torch.package,
+            by default []
+
+        Returns
+        -------
+        str
+            Path to saved model.
+        """
+        sig = inspect.signature(self.__class__)
+        args = {}
+
+        for key, val in sig.parameters.items():
+            arg_val = val.default
+            if arg_val is not inspect.Parameter.empty:
+                args[key] = arg_val
+
+        # Look up attributes in self, and if any of them are in args,
+        # overwrite them in args.
+        for attribute in dir(self):
+            if attribute in args:
+                args[attribute] = getattr(self, attribute)
+
+        metadata = {} if metadata is None else metadata
+        metadata["kwargs"] = args
+        if not hasattr(self, "metadata"):
+            self.metadata = {}
+        self.metadata.update(metadata)
+
+        if not package:
+            state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
+            torch.save(state_dict, path)
+        else:
+            self._save_package(path, intern=intern, extern=extern, mock=mock)
+
+        return path
+
+    @property
+    def device(self):
+        """Gets the device the model is on by looking at the device of
+        the first parameter. May not be valid if model is split across
+        multiple devices.
+        """
+        return list(self.parameters())[0].device
+
+    @classmethod
+    def load(
+        cls,
+        location: str,
+        *args,
+        package_name: str = None,
+        strict: bool = False,
+        **kwargs,
+    ):
+        """Load model from a path. Tries first to load as a package, and if
+        that fails, tries to load as weights. The arguments to the class are
+        specified inside the model weights file.
+
+        Parameters
+        ----------
+        location : str
+            Path to file.
+        package_name : str, optional
+            Name of package, by default ``cls.__name__``.
+        strict : bool, optional
+            Whether to strictly enforce that the keys in the saved state
+            dict match the model, by default False
+        kwargs : dict
+            Additional keyword arguments to the model instantiation, if
+            not loading from package.
+
+        Returns
+        -------
+        BaseModel
+            A model that inherits from BaseModel.
+        """
+        try:
+            model = cls._load_package(location, package_name=package_name)
+        except Exception:
+            model_dict = torch.load(location, "cpu")
+            metadata = model_dict["metadata"]
+            metadata["kwargs"].update(kwargs)
+
+            sig = inspect.signature(cls)
+            class_keys = list(sig.parameters.keys())
+            for k in list(metadata["kwargs"].keys()):
+                if k not in class_keys:
+                    metadata["kwargs"].pop(k)
+
+            model = cls(*args, **metadata["kwargs"])
+            model.load_state_dict(model_dict["state_dict"], strict=strict)
+            model.metadata = metadata
+
+        return model
+
+    def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
+        package_name = type(self).__name__
+        resource_name = f"{type(self).__name__}.pth"
+
+        # Below is for loading and re-saving a package.
+        if hasattr(self, "importer"):
+            kwargs["importer"] = (self.importer, torch.package.sys_importer)
+            del self.importer
+
+        # Why do we use a tempfile, you ask?
+        # It's so we can load a packaged model and then re-save
+        # it to the same location. torch.package throws an
+        # error if it's loading and writing to the same
+        # file (this is undocumented).
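+        # So the package is first exported to a temporary file, and the
+        # finished package is copied over the requested path afterwards.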
+        with tempfile.NamedTemporaryFile(suffix=".pth") as f:
+            with torch.package.PackageExporter(f.name, **kwargs) as exp:
+                exp.intern(self.INTERN + intern)
+                exp.mock(mock)
+                exp.extern(self.EXTERN + extern)
+                exp.save_pickle(package_name, resource_name, self)
+
+                if hasattr(self, "metadata"):
+                    exp.save_pickle(
+                        package_name, f"{package_name}.metadata", self.metadata
+                    )
+
+            shutil.copyfile(f.name, path)
+
+        # Must reset the importer back to `self` if it existed
+        # so that you can save the model again!
+        if "importer" in kwargs:
+            self.importer = kwargs["importer"][0]
+        return path
+
+    @classmethod
+    def _load_package(cls, path, package_name=None):
+        package_name = cls.__name__ if package_name is None else package_name
+        resource_name = f"{package_name}.pth"
+
+        imp = torch.package.PackageImporter(path)
+        model = imp.load_pickle(package_name, resource_name, "cpu")
+        try:
+            model.metadata = imp.load_pickle(package_name, f"{package_name}.metadata")
+        except Exception:  # pragma: no cover
+            pass
+        model.importer = imp
+
+        return model
+
+    def save_to_folder(
+        self,
+        folder: typing.Union[str, Path],
+        extra_data: dict = None,
+        package: bool = True,
+    ):
+        """Dumps a model into a folder, as both a package
+        and as weights, as well as anything specified in
+        ``extra_data``. ``extra_data`` is a dictionary of other
+        pickleable files, with the keys being the paths
+        to save them in. The model is saved under a subfolder
+        specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
+        if the model name was ``Generator``).
+
+        >>> with tempfile.TemporaryDirectory() as d:
+        >>>     extra_data = {
+        >>>         "optimizer.pth": optimizer.state_dict()
+        >>>     }
+        >>>     model.save_to_folder(d, extra_data)
+        >>>     Model.load_from_folder(d)
+
+        Parameters
+        ----------
+        folder : typing.Union[str, Path]
+            Folder to dump the model into.
+        extra_data : dict, optional
+            Dictionary mapping file names to pickleable objects that are
+            saved alongside the model, by default None
+
+        Returns
+        -------
+        str
+            Path to folder
+        """
+        extra_data = {} if extra_data is None else extra_data
+        model_name = type(self).__name__.lower()
+        target_base = Path(f"{folder}/{model_name}/")
+        target_base.mkdir(exist_ok=True, parents=True)
+
+        if package:
+            package_path = target_base / f"package.pth"
+            self.save(package_path)
+
+        weights_path = target_base / f"weights.pth"
+        self.save(weights_path, package=False)
+
+        for path, obj in extra_data.items():
+            torch.save(obj, target_base / path)
+
+        return target_base
+
+    @classmethod
+    def load_from_folder(
+        cls,
+        folder: typing.Union[str, Path],
+        package: bool = True,
+        strict: bool = False,
+        **kwargs,
+    ):
+        """Loads the model from a folder generated by
+        :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
+        Like that function, this one looks for a subfolder that has
+        the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
+        model name was ``Generator``).
+
+        Parameters
+        ----------
+        folder : typing.Union[str, Path]
+            Folder to load the model from.
+        package : bool, optional
+            Whether to use ``torch.package`` to load the model,
+            loading the model from ``package.pth``.
+        strict : bool, optional
+            Ignore unmatched keys, by default False
+
+        Returns
+        -------
+        tuple
+            tuple of model and extra data as saved by
+            :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
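+
+        Example
+        -------
+        A sketch of the round trip (assuming a ``Generator`` subclass of
+        ``BaseModel`` and an ``optimizer.pth`` entry saved via ``extra_data``):
+
+        >>> model, extra = Generator.load_from_folder("checkpoints/")
+        >>> optimizer.load_state_dict(extra["optimizer.pth"])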
+ """ + folder = Path(folder) / cls.__name__.lower() + model_pth = "package.pth" if package else "weights.pth" + model_pth = folder / model_pth + + model = cls.load(model_pth, strict=strict) + extra_data = {} + excluded = ["package.pth", "weights.pth"] + files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded] + for f in files: + extra_data[f.name] = torch.load(f, **kwargs) + + return model, extra_data diff --git a/flowae/models/ldm/dac/audiotools/ml/layers/spectral_gate.py b/flowae/models/ldm/dac/audiotools/ml/layers/spectral_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ae8b5eab2e56ce13541695f52a11a454759dae --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/ml/layers/spectral_gate.py @@ -0,0 +1,127 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from ...core import AudioSignal +from ...core import STFTParams +from ...core import util + + +class SpectralGate(nn.Module): + """Spectral gating algorithm for noise reduction, + as in Audacity/Ocenaudio. The steps are as follows: + + 1. An FFT is calculated over the noise audio clip + 2. Statistics are calculated over FFT of the the noise + (in frequency) + 3. A threshold is calculated based upon the statistics + of the noise (and the desired sensitivity of the algorithm) + 4. An FFT is calculated over the signal + 5. A mask is determined by comparing the signal FFT to the + threshold + 6. The mask is smoothed with a filter over frequency and time + 7. The mask is appled to the FFT of the signal, and is inverted + + Implementation inspired by Tim Sainburg's noisereduce: + + https://timsainburg.com/noise-reduction-python.html + + Parameters + ---------- + n_freq : int, optional + Number of frequency bins to smooth by, by default 3 + n_time : int, optional + Number of time bins to smooth by, by default 5 + """ + + def __init__(self, n_freq: int = 3, n_time: int = 5): + super().__init__() + + smoothing_filter = torch.outer( + torch.cat( + [ + torch.linspace(0, 1, n_freq + 2)[:-1], + torch.linspace(1, 0, n_freq + 2), + ] + )[..., 1:-1], + torch.cat( + [ + torch.linspace(0, 1, n_time + 2)[:-1], + torch.linspace(1, 0, n_time + 2), + ] + )[..., 1:-1], + ) + smoothing_filter = smoothing_filter / smoothing_filter.sum() + smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0) + self.register_buffer("smoothing_filter", smoothing_filter) + + def forward( + self, + audio_signal: AudioSignal, + nz_signal: AudioSignal, + denoise_amount: float = 1.0, + n_std: float = 3.0, + win_length: int = 2048, + hop_length: int = 512, + ): + """Perform noise reduction. + + Parameters + ---------- + audio_signal : AudioSignal + Audio signal that noise will be removed from. + nz_signal : AudioSignal, optional + Noise signal to compute noise statistics from. + denoise_amount : float, optional + Amount to denoise by, by default 1.0 + n_std : float, optional + Number of standard deviations above which to consider + noise, by default 3.0 + win_length : int, optional + Length of window for STFT, by default 2048 + hop_length : int, optional + Hop length for STFT, by default 512 + + Returns + ------- + AudioSignal + Denoised audio signal. 
+ """ + stft_params = STFTParams(win_length, hop_length, "sqrt_hann") + + audio_signal = audio_signal.clone() + audio_signal.stft_data = None + audio_signal.stft_params = stft_params + + nz_signal = nz_signal.clone() + nz_signal.stft_params = stft_params + + nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10() + nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1) + nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1) + + nz_thresh = nz_freq_mean + nz_freq_std * n_std + + stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10() + nb, nac, nf, nt = stft_db.shape + db_thresh = nz_thresh.expand(nb, nac, -1, nt) + + stft_mask = (stft_db < db_thresh).float() + shape = stft_mask.shape + + stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt) + pad_tuple = ( + self.smoothing_filter.shape[-2] // 2, + self.smoothing_filter.shape[-1] // 2, + ) + stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple) + stft_mask = stft_mask.reshape(*shape) + stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim).to( + audio_signal.device + ) + stft_mask = 1 - stft_mask + + audio_signal.stft_data *= stft_mask + audio_signal.istft() + + return audio_signal diff --git a/flowae/models/ldm/dac/audiotools/post.py b/flowae/models/ldm/dac/audiotools/post.py new file mode 100644 index 0000000000000000000000000000000000000000..6ced2d1e66a4ffda3269685bd45593b01038739f --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/post.py @@ -0,0 +1,140 @@ +import tempfile +import typing +import zipfile +from pathlib import Path + +import markdown2 as md +import matplotlib.pyplot as plt +import torch +from IPython.display import HTML + + +def audio_table( + audio_dict: dict, + first_column: str = None, + format_fn: typing.Callable = None, + **kwargs, +): # pragma: no cover + """Embeds an audio table into HTML, or as the output cell + in a notebook. + + Parameters + ---------- + audio_dict : dict + Dictionary of data to embed. + first_column : str, optional + The label for the first column of the table, by default None + format_fn : typing.Callable, optional + How to format the data, by default None + + Returns + ------- + str + Table as a string + + Examples + -------- + + >>> audio_dict = {} + >>> for i in range(signal_batch.batch_size): + >>> audio_dict[i] = { + >>> "input": signal_batch[i], + >>> "output": output_batch[i] + >>> } + >>> audiotools.post.audio_zip(audio_dict) + + """ + from audiotools import AudioSignal + + output = [] + columns = None + + def _default_format_fn(label, x, **kwargs): + if torch.is_tensor(x): + x = x.tolist() + + if x is None: + return "." + elif isinstance(x, AudioSignal): + return x.embed(display=False, return_html=True, **kwargs) + else: + return str(x) + + if format_fn is None: + format_fn = _default_format_fn + + if first_column is None: + first_column = "." + + for k, v in audio_dict.items(): + if not isinstance(v, dict): + v = {"Audio": v} + + v_keys = list(v.keys()) + if columns is None: + columns = [first_column] + v_keys + output.append(" | ".join(columns)) + + layout = "|---" + len(v_keys) * "|:-:" + output.append(layout) + + formatted_audio = [] + for col in columns[1:]: + formatted_audio.append(format_fn(col, v[col], **kwargs)) + + row = f"| {k} | " + row += " | ".join(formatted_audio) + output.append(row) + + output = "\n" + "\n".join(output) + return output + + +def in_notebook(): # pragma: no cover + """Determines if code is running in a notebook. + + Returns + ------- + bool + Whether or not this is running in a notebook. 
+ """ + try: + from IPython import get_ipython + + if "IPKernelApp" not in get_ipython().config: # pragma: no cover + return False + except ImportError: + return False + except AttributeError: + return False + return True + + +def disp(obj, **kwargs): # pragma: no cover + """Displays an object, depending on if its in a notebook + or not. + + Parameters + ---------- + obj : typing.Any + Any object to display. + + """ + from audiotools import AudioSignal + + IN_NOTEBOOK = in_notebook() + + if isinstance(obj, AudioSignal): + audio_elem = obj.embed(display=False, return_html=True) + if IN_NOTEBOOK: + return HTML(audio_elem) + else: + print(audio_elem) + if isinstance(obj, dict): + table = audio_table(obj, **kwargs) + if IN_NOTEBOOK: + return HTML(md.markdown(table, extras=["tables"])) + else: + print(table) + if isinstance(obj, plt.Figure): + plt.show() diff --git a/flowae/models/ldm/dac/audiotools/preference.py b/flowae/models/ldm/dac/audiotools/preference.py new file mode 100644 index 0000000000000000000000000000000000000000..800a852e8119dd18ea65784cf95182de2470fbc4 --- /dev/null +++ b/flowae/models/ldm/dac/audiotools/preference.py @@ -0,0 +1,600 @@ +############################################################## +### Tools for creating preference tests (MUSHRA, ABX, etc) ### +############################################################## +import copy +import csv +import random +import sys +import traceback +from collections import defaultdict +from pathlib import Path +from typing import List + +import gradio as gr + +from audiotools.core.util import find_audio + +################################################################ +### Logic for audio player, and adding audio / play buttons. ### +################################################################ + +WAVESURFER = """
""" + +CUSTOM_CSS = """ +.gradio-container { + max-width: 840px !important; +} +region.wavesurfer-region:before { + content: attr(data-region-label); +} + +block { + min-width: 0 !important; +} + +#wave-timeline { + background-color: rgba(0, 0, 0, 0.8); +} + +.head.svelte-1cl284s { + display: none; +} +""" + +load_wavesurfer_js = """ +function load_wavesurfer() { + function load_script(url) { + const script = document.createElement('script'); + script.src = url; + document.body.appendChild(script); + + return new Promise((res, rej) => { + script.onload = function() { + res(); + } + script.onerror = function () { + rej(); + } + }); + } + + function create_wavesurfer() { + var options = { + container: '#waveform', + waveColor: '#F2F2F2', // Set a darker wave color + progressColor: 'white', // Set a slightly lighter progress color + loaderColor: 'white', // Set a slightly lighter loader color + cursorColor: 'black', // Set a slightly lighter cursor color + backgroundColor: '#00AAFF', // Set a black background color + barWidth: 4, + barRadius: 3, + barHeight: 1, // the height of the wave + plugins: [ + WaveSurfer.regions.create({ + regionsMinLength: 0.0, + dragSelection: { + slop: 5 + }, + color: 'hsla(200, 50%, 70%, 0.4)', + }), + WaveSurfer.timeline.create({ + container: "#wave-timeline", + primaryLabelInterval: 5.0, + secondaryLabelInterval: 1.0, + primaryFontColor: '#F2F2F2', + secondaryFontColor: '#F2F2F2', + }), + ] + }; + wavesurfer = WaveSurfer.create(options); + wavesurfer.on('region-created', region => { + wavesurfer.regions.clear(); + }); + wavesurfer.on('finish', function () { + var loop = document.getElementById("loop-button").textContent.includes("ON"); + if (loop) { + wavesurfer.play(); + } + else { + var button_elements = document.getElementsByClassName('playpause') + var buttons = Array.from(button_elements); + + for (let j = 0; j < buttons.length; j++) { + buttons[j].classList.remove("primary"); + buttons[j].classList.add("secondary"); + buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") + } + } + }); + + wavesurfer.on('region-out', function () { + var loop = document.getElementById("loop-button").textContent.includes("ON"); + if (!loop) { + var button_elements = document.getElementsByClassName('playpause') + var buttons = Array.from(button_elements); + + for (let j = 0; j < buttons.length; j++) { + buttons[j].classList.remove("primary"); + buttons[j].classList.add("secondary"); + buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") + } + wavesurfer.pause(); + } + }); + + console.log("Created WaveSurfer object.") + } + + load_script('https://unpkg.com/wavesurfer.js@6.6.4') + .then(() => { + load_script("https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.timeline.min.js") + .then(() => { + load_script('https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.regions.min.js') + .then(() => { + console.log("Loaded regions"); + create_wavesurfer(); + document.getElementById("start-survey").click(); + }) + }) + }); +} +""" + +play = lambda i: """ +function play() { + var audio_elements = document.getElementsByTagName('audio'); + var button_elements = document.getElementsByClassName('playpause') + + var audio_array = Array.from(audio_elements); + var buttons = Array.from(button_elements); + + var src_link = audio_array[{i}].getAttribute("src"); + console.log(src_link); + + var loop = document.getElementById("loop-button").textContent.includes("ON"); + var playing = buttons[{i}].textContent.includes("Stop"); + + for (let j = 0; j 
< buttons.length; j++) { + if (j != {i} || playing) { + buttons[j].classList.remove("primary"); + buttons[j].classList.add("secondary"); + buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") + } + else { + buttons[j].classList.remove("secondary"); + buttons[j].classList.add("primary"); + buttons[j].textContent = buttons[j].textContent.replace("Play", "Stop") + } + } + + if (playing) { + wavesurfer.pause(); + wavesurfer.seekTo(0.0); + } + else { + wavesurfer.load(src_link); + wavesurfer.on('ready', function () { + var region = Object.values(wavesurfer.regions.list)[0]; + + if (region != null) { + region.loop = loop; + region.play(); + } else { + wavesurfer.play(); + } + }); + } +} +""".replace( + "{i}", str(i) +) + +clear_regions = """ +function clear_regions() { + wavesurfer.clearRegions(); +} +""" + +reset_player = """ +function reset_player() { + wavesurfer.clearRegions(); + wavesurfer.pause(); + wavesurfer.seekTo(0.0); + + var button_elements = document.getElementsByClassName('playpause') + var buttons = Array.from(button_elements); + + for (let j = 0; j < buttons.length; j++) { + buttons[j].classList.remove("primary"); + buttons[j].classList.add("secondary"); + buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") + } +} +""" + +loop_region = """ +function loop_region() { + var element = document.getElementById("loop-button"); + var loop = element.textContent.includes("OFF"); + console.log(loop); + + try { + var region = Object.values(wavesurfer.regions.list)[0]; + region.loop = loop; + } catch {} + + if (loop) { + element.classList.remove("secondary"); + element.classList.add("primary"); + element.textContent = "Looping ON"; + } else { + element.classList.remove("primary"); + element.classList.add("secondary"); + element.textContent = "Looping OFF"; + } +} +""" + + +class Player: + def __init__(self, app): + self.app = app + + self.app.load(_js=load_wavesurfer_js) + self.app.css = CUSTOM_CSS + + self.wavs = [] + self.position = 0 + + def create(self): + gr.HTML(WAVESURFER) + gr.Markdown( + "Click and drag on the waveform above to select a region for playback. " + "Once created, the region can be moved around and resized. " + "Clear the regions using the button below. Hit play on one of the buttons below to start!" + ) + + with gr.Row(): + clear = gr.Button("Clear region") + loop = gr.Button("Looping OFF", elem_id="loop-button") + + loop.click(None, _js=loop_region) + clear.click(None, _js=clear_regions) + + gr.HTML("
") + + def add(self, name: str = "Play"): + i = self.position + self.wavs.append( + { + "audio": gr.Audio(visible=False), + "button": gr.Button(name, elem_classes=["playpause"]), + "position": i, + } + ) + self.wavs[-1]["button"].click(None, _js=play(i)) + self.position += 1 + return self.wavs[-1] + + def to_list(self): + return [x["audio"] for x in self.wavs] + + +############################################################ +### Keeping track of users, and CSS for the progress bar ### +############################################################ + +load_tracker = lambda name: """ +function load_name() { + function setCookie(name, value, exp_days) { + var d = new Date(); + d.setTime(d.getTime() + (exp_days*24*60*60*1000)); + var expires = "expires=" + d.toGMTString(); + document.cookie = name + "=" + value + ";" + expires + ";path=/"; + } + + function getCookie(name) { + var cname = name + "="; + var decodedCookie = decodeURIComponent(document.cookie); + var ca = decodedCookie.split(';'); + for(var i = 0; i < ca.length; i++){ + var c = ca[i]; + while(c.charAt(0) == ' '){ + c = c.substring(1); + } + if(c.indexOf(cname) == 0){ + return c.substring(cname.length, c.length); + } + } + return ""; + } + + name = getCookie("{name}"); + if (name == "") { + name = Math.random().toString(36).slice(2); + console.log(name); + setCookie("name", name, 30); + } + name = getCookie("{name}"); + return name; +} +""".replace( + "{name}", name +) + +# Progress bar + +progress_template = """ + + + + Progress Bar + + + +
+<div class="progress-bar">
+  <div class="progress" style="width: {PROGRESS}%;"></div>
+  <div class="progress-text">{TEXT}</div>
+</div>
+ + +""" + + +def create_tracker(app, cookie_name="name"): + user = gr.Text(label="user", interactive=True, visible=False, elem_id="user") + app.load(_js=load_tracker(cookie_name), outputs=user) + return user + + +################################################################# +### CSS and HTML for labeling sliders for both ABX and MUSHRA ### +################################################################# + +slider_abx = """ + + + + + Labels Example + + + +
+<div class="slider-labels">
+  <div class="label">Prefer A</div>
+  <div class="label">Toss-up</div>
+  <div class="label">Prefer B</div>
+</div>
+ + +""" + +slider_mushra = """ + + + + + Labels Example + + + +
+<div class="slider-labels">
+  <div class="label">bad</div>
+  <div class="label">poor</div>
+  <div class="label">fair</div>
+  <div class="label">good</div>
+  <div class="label">excellent</div>
+</div>
+ + +""" + +######################################################### +### Handling loading audio and tracking session state ### +######################################################### + + +class Samples: + def __init__(self, folder: str, shuffle: bool = True, n_samples: int = None): + files = find_audio(folder) + samples = defaultdict(lambda: defaultdict()) + + for f in files: + condition = f.parent.stem + samples[f.name][condition] = f + + self.samples = samples + self.names = list(samples.keys()) + self.filtered = False + self.current = 0 + + if shuffle: + random.shuffle(self.names) + + self.n_samples = len(self.names) if n_samples is None else n_samples + + def get_updates(self, idx, order): + key = self.names[idx] + return [gr.update(value=str(self.samples[key][o])) for o in order] + + def progress(self): + try: + pct = self.current / len(self) * 100 + except: # pragma: no cover + pct = 100 + text = f"On {self.current} / {len(self)} samples" + pbar = ( + copy.copy(progress_template) + .replace("{PROGRESS}", str(pct)) + .replace("{TEXT}", str(text)) + ) + return gr.update(value=pbar) + + def __len__(self): + return self.n_samples + + def filter_completed(self, user, save_path): + if not self.filtered: + done = [] + if Path(save_path).exists(): + with open(save_path, "r") as f: + reader = csv.DictReader(f) + done = [r["sample"] for r in reader if r["user"] == user] + self.names = [k for k in self.names if k not in done] + self.names = self.names[: self.n_samples] + self.filtered = True # Avoid filtering more than once per session. + + def get_next_sample(self, reference, conditions): + random.shuffle(conditions) + if reference is not None: + self.order = [reference] + conditions + else: + self.order = conditions + + try: + updates = self.get_updates(self.current, self.order) + self.current += 1 + done = gr.update(interactive=True) + pbar = self.progress() + except: + traceback.print_exc() + updates = [gr.update() for _ in range(len(self.order))] + done = gr.update(value="No more samples!", interactive=False) + self.current = len(self) + pbar = self.progress() + + return updates, done, pbar + + +def save_result(result, save_path): + with open(save_path, mode="a", newline="") as file: + writer = csv.DictWriter(file, fieldnames=sorted(list(result.keys()))) + if file.tell() == 0: + writer.writeheader() + writer.writerow(result) diff --git a/flowae/models/ldm/dac/base.py b/flowae/models/ldm/dac/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ede7e8d87f4ec6ceedc94a4d2b9d75217adfe8fe --- /dev/null +++ b/flowae/models/ldm/dac/base.py @@ -0,0 +1,294 @@ +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import numpy as np +import torch +import tqdm +from audiotools import AudioSignal +from torch import nn + +SUPPORTED_VERSIONS = ["1.0.0"] + + +@dataclass +class DACFile: + codes: torch.Tensor + + # Metadata + chunk_length: int + original_length: int + input_db: float + channels: int + sample_rate: int + padding: bool + dac_version: str + + def save(self, path): + artifacts = { + "codes": self.codes.numpy().astype(np.uint16), + "metadata": { + "input_db": self.input_db.numpy().astype(np.float32), + "original_length": self.original_length, + "sample_rate": self.sample_rate, + "chunk_length": self.chunk_length, + "channels": self.channels, + "padding": self.padding, + "dac_version": SUPPORTED_VERSIONS[-1], + }, + } + path = Path(path).with_suffix(".dac") + with open(path, "wb") as f: + np.save(f, artifacts) + return 
path + + @classmethod + def load(cls, path): + artifacts = np.load(path, allow_pickle=True)[()] + codes = torch.from_numpy(artifacts["codes"].astype(int)) + if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: + raise RuntimeError( + f"Given file {path} can't be loaded with this version of descript-audio-codec." + ) + return cls(codes=codes, **artifacts["metadata"]) + + +class CodecMixin: + @property + def padding(self): + if not hasattr(self, "_padding"): + self._padding = True + return self._padding + + @padding.setter + def padding(self, value): + assert isinstance(value, bool) + + layers = [ + l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) + ] + + for layer in layers: + if value: + if hasattr(layer, "original_padding"): + layer.padding = layer.original_padding + else: + layer.original_padding = layer.padding + layer.padding = tuple(0 for _ in range(len(layer.padding))) + + self._padding = value + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + + @torch.no_grad() + def compress( + self, + audio_path_or_signal: Union[str, Path, AudioSignal], + win_duration: float = 1.0, + verbose: bool = False, + normalize_db: float = -16, + n_quantizers: int = None, + ) -> DACFile: + """Processes an audio signal from a file or AudioSignal object into + discrete codes. This function processes the signal in short windows, + using constant GPU memory. 
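+        Windows of ``win_duration`` seconds are encoded one at a time and the
+        resulting codes are concatenated, so peak memory depends on the window
+        length rather than on the signal length.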
+ + Parameters + ---------- + audio_path_or_signal : Union[str, Path, AudioSignal] + audio signal to reconstruct + win_duration : float, optional + window duration in seconds, by default 5.0 + verbose : bool, optional + by default False + normalize_db : float, optional + normalize db, by default -16 + + Returns + ------- + DACFile + Object containing compressed codes and metadata + required for decompression + """ + audio_signal = audio_path_or_signal + if isinstance(audio_signal, (str, Path)): + audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) + + self.eval() + original_padding = self.padding + original_device = audio_signal.device + + audio_signal = audio_signal.clone() + original_sr = audio_signal.sample_rate + + resample_fn = audio_signal.resample + loudness_fn = audio_signal.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if audio_signal.signal_duration >= 10 * 60 * 60: + resample_fn = audio_signal.ffmpeg_resample + loudness_fn = audio_signal.ffmpeg_loudness + + original_length = audio_signal.signal_length + resample_fn(self.sample_rate) + input_db = loudness_fn() + + if normalize_db is not None: + audio_signal.normalize(normalize_db) + audio_signal.ensure_max_of_audio() + + nb, nac, nt = audio_signal.audio_data.shape + audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) + win_duration = ( + audio_signal.signal_duration if win_duration is None else win_duration + ) + + if audio_signal.signal_duration <= win_duration: + # Unchunked compression (used if signal length < win duration) + self.padding = True + n_samples = nt + hop = nt + else: + # Chunked inference + self.padding = False + # Zero-pad signal on either side by the delay + audio_signal.zero_pad(self.delay, self.delay) + n_samples = int(win_duration * self.sample_rate) + # Round n_samples to nearest hop length multiple + n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) + hop = self.get_output_length(n_samples) + + codes = [] + range_fn = range if not verbose else tqdm.trange + + for i in range_fn(0, nt, hop): + x = audio_signal[..., i : i + n_samples] + x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) + + audio_data = x.audio_data.to(self.device) + audio_data = self.preprocess(audio_data, self.sample_rate) + _, c, _, _, _ = self.encode(audio_data, n_quantizers) + codes.append(c.to(original_device)) + chunk_length = c.shape[-1] + + codes = torch.cat(codes, dim=-1) + + dac_file = DACFile( + codes=codes, + chunk_length=chunk_length, + original_length=original_length, + input_db=input_db, + channels=nac, + sample_rate=original_sr, + padding=self.padding, + dac_version=SUPPORTED_VERSIONS[-1], + ) + + if n_quantizers is not None: + codes = codes[:, :n_quantizers, :] + + self.padding = original_padding + return dac_file + + @torch.no_grad() + def decompress( + self, + obj: Union[str, Path, DACFile], + verbose: bool = False, + ) -> AudioSignal: + """Reconstruct audio from a given .dac file + + Parameters + ---------- + obj : Union[str, Path, DACFile] + .dac file location or corresponding DACFile object. 
+ verbose : bool, optional + Prints progress if True, by default False + + Returns + ------- + AudioSignal + Object with the reconstructed audio + """ + self.eval() + if isinstance(obj, (str, Path)): + obj = DACFile.load(obj) + + original_padding = self.padding + self.padding = obj.padding + + range_fn = range if not verbose else tqdm.trange + codes = obj.codes + original_device = codes.device + chunk_length = obj.chunk_length + recons = [] + + for i in range_fn(0, codes.shape[-1], chunk_length): + c = codes[..., i : i + chunk_length].to(self.device) + z = self.quantizer.from_codes(c)[0] + r = self.decode(z) + recons.append(r.to(original_device)) + + recons = torch.cat(recons, dim=-1) + recons = AudioSignal(recons, self.sample_rate) + + resample_fn = recons.resample + loudness_fn = recons.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if recons.signal_duration >= 10 * 60 * 60: + resample_fn = recons.ffmpeg_resample + loudness_fn = recons.ffmpeg_loudness + + recons.normalize(obj.input_db) + resample_fn(obj.sample_rate) + recons = recons[..., : obj.original_length] + loudness_fn() + recons.audio_data = recons.audio_data.reshape( + -1, obj.channels, obj.original_length + ) + + self.padding = original_padding + return recons \ No newline at end of file diff --git a/flowae/models/ldm/dac/layers.py b/flowae/models/ldm/dac/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..a0cc6fb4021f2b34d2a1c9cee151a8576a8e5285 --- /dev/null +++ b/flowae/models/ldm/dac/layers.py @@ -0,0 +1,80 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) + +def snake_beta(x, alpha, beta): + return x + (1.0 / (beta + 0.000000001)) * torch.pow(torch.sin(x * alpha), 2) +# License available in LICENSES/LICENSE_NVIDIA.txt +class SnakeBeta(nn.Module): + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = snake_beta(x, alpha, beta) + + return x + +def 
get_activation(activation, channels, alpha):
+    if activation == "snake":
+        return Snake1d(channels)
+    elif activation == "relu":
+        return nn.ReLU()
+    elif activation == "leaky_relu":
+        return nn.LeakyReLU()
+    elif activation == "tanh":
+        return nn.Tanh()
+    elif activation == "snake_beta":
+        return SnakeBeta(channels, alpha)
+    else:
+        raise ValueError(f"Activation {activation} not supported")
\ No newline at end of file
diff --git a/flowae/models/ldm/dac/loss.py b/flowae/models/ldm/dac/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a5fc6f38ea44ce666522ba96ec24751a3e4f1ee
--- /dev/null
+++ b/flowae/models/ldm/dac/loss.py
@@ -0,0 +1,374 @@
+import typing
+from typing import List
+
+import torch
+import torch.nn.functional as F
+from audiotools import AudioSignal
+from audiotools import STFTParams
+from torch import nn
+
+
+class L1Loss(nn.L1Loss):
+    """L1 Loss between AudioSignals. Defaults
+    to comparing ``audio_data``, but any
+    attribute of an AudioSignal can be used.
+
+    Parameters
+    ----------
+    attribute : str, optional
+        Attribute of signal to compare, defaults to ``audio_data``.
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
+
+    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
+    """
+
+    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
+        self.attribute = attribute
+        self.weight = weight
+        super().__init__(**kwargs)
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate AudioSignal
+        y : AudioSignal
+            Reference AudioSignal
+
+        Returns
+        -------
+        torch.Tensor
+            L1 loss between AudioSignal attributes.
+        """
+        if isinstance(x, AudioSignal):
+            x = getattr(x, self.attribute)
+            y = getattr(y, self.attribute)
+        return super().forward(x, y)
+
+
+class SISDRLoss(nn.Module):
+    """
+    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+    of estimated and reference audio signals or aligned features.
+
+    Parameters
+    ----------
+    scaling : int, optional
+        Whether to use scale-invariant (True) or
+        signal-to-noise ratio (False), by default True
+    reduction : str, optional
+        How to reduce across the batch (either 'mean',
+        'sum', or 'none'), by default 'mean'
+    zero_mean : int, optional
+        Zero mean the references and estimates before
+        computing the loss, by default True
+    clip_min : int, optional
+        The minimum possible loss value. Helps network
+        to not focus on making already good examples better, by default None
+    weight : float, optional
+        Weight of this loss, defaults to 1.0.
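+
+    Notes
+    -----
+    With ``scaling=True`` this returns the negative SI-SDR,
+    ``-10 * log10(||e_true||^2 / ||e_res||^2)``, where
+    ``e_true = (<est, ref> / ||ref||^2) * ref`` and ``e_res = est - e_true``
+    (a summary of the computation in ``forward`` below).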
+ + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. + + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. 
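+
+    Examples
+    --------
+    A minimal sketch (``est`` and ``ref`` are AudioSignals at the same
+    sample rate):
+
+    >>> stft_loss = MultiScaleSTFTLoss(window_lengths=[2048, 512])
+    >>> loss = stft_loss(est, ref)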
+ + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320], + window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 0.0, + log_weight: float = 1.0, + pow: float = 1.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0], + mel_fmax: List[float] = [None, None, None, None, None, None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, 
y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class GANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature + + +def kl_loss(logs, m): + kl = 0.5 * (m**2 + torch.exp(logs) - logs - 1).sum(dim=1) + kl = torch.mean(kl) + return kl \ No newline at end of file diff --git a/flowae/models/ldm/dac/model.py b/flowae/models/ldm/dac/model.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa0e6c958442c3214bd49c6ce2b2643670db8b3 --- /dev/null +++ b/flowae/models/ldm/dac/model.py @@ -0,0 +1,729 @@ +import math +from typing import List, Union + +import numpy as np +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import weight_norm + +from .audiotools import AudioSignal, STFTParams, ml +from .audiotools.ml import BaseModel +from .base import CodecMixin +from .layers import WNConv1d, WNConvTranspose1d, get_activation + + +def init_weights(m, mean=0.0, std=0.02, init_type="xavier", gain=0.02): + """ + Initialize weights of the entire model using xavier_normal_ or kaiming_normal_. + Args: + m (nn.Module): The module to initialize. + mean (float): Mean for weight initialization. + std (float): Standard deviation for weight initialization. + init_type (str): Type of initialization ('xavier' or 'kaiming'). + gain (float): Gain for xavier initialization. 
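+
+    Example (a hypothetical model)::
+
+        model = nn.Sequential(nn.Conv1d(1, 16, 7), nn.BatchNorm1d(16))
+        model.apply(lambda m: init_weights(m, init_type="xavier", gain=0.02))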
+ """ + classname = m.__class__.__name__ + + if init_type == "xavier": + # Handle convolutional layers + if "Depthwise_Separable" in classname: + nn.init.xavier_normal_(m.depth_conv.weight.data, gain=gain) + nn.init.xavier_normal_(m.point_conv.weight.data, gain=gain) + if hasattr(m.depth_conv, "bias") and m.depth_conv.bias is not None: + nn.init.zeros_(m.depth_conv.bias.data) + if hasattr(m.point_conv, "bias") and m.point_conv.bias is not None: + nn.init.zeros_(m.point_conv.bias.data) + elif classname.find("Conv") != -1: + nn.init.xavier_normal_(m.weight.data, gain=gain) + if hasattr(m, "bias") and m.bias is not None: + nn.init.zeros_(m.bias.data) + + # Handle batch normalization layers + elif classname.find("BatchNorm") != -1: + if hasattr(m, "weight") and m.weight is not None: + nn.init.xavier_normal_(m.weight.data, gain=gain) + if hasattr(m, "bias") and m.bias is not None: + nn.init.zeros_(m.bias.data) + + # Handle custom layers like Snake1d and SnakeBeta + elif classname == "Snake1d": + if hasattr(m, "alpha") and m.alpha is not None: + if m.alpha.data.dim() >= 2: + nn.init.xavier_normal_(m.alpha.data, gain=gain) + else: + nn.init.normal_(m.alpha.data, mean=1.0, std=std) + elif classname == "SnakeBeta": + # Respect the alpha_logscale setting in SnakeBeta + if hasattr(m, "alpha") and m.alpha is not None: + if m.alpha_logscale: + nn.init.constant_(m.alpha.data, 0.0) # Matches SnakeBeta's default + else: + nn.init.constant_(m.alpha.data, 1.0) + if hasattr(m, "beta") and m.beta is not None: + if m.alpha_logscale: + nn.init.constant_(m.beta.data, 0.0) # Matches SnakeBeta's default + else: + nn.init.constant_(m.beta.data, 1.0) + + # Handle residual scaling parameters + elif hasattr(m, "residual_scale") and m.residual_scale is not None: + nn.init.xavier_normal_(m.residual_scale.data, gain=gain) + + else: + # Kaiming initialization + if "Depthwise_Separable" in classname: + nn.init.kaiming_normal_( + m.depth_conv.weight.data, mode="fan_out", nonlinearity="relu" + ) + nn.init.kaiming_normal_( + m.point_conv.weight.data, mode="fan_out", nonlinearity="relu" + ) + elif classname.find("Conv") != -1: + nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu") + if hasattr(m, "bias") and m.bias is not None: + nn.init.zeros_(m.bias.data) + elif classname.find("BatchNorm") != -1: + if hasattr(m, "weight") and m.weight is not None: + nn.init.normal_(m.weight.data, 1.0, std) + if hasattr(m, "bias") and m.bias is not None: + nn.init.zeros_(m.bias.data) + elif classname == "Snake1d": + if hasattr(m, "alpha") and m.alpha is not None: + nn.init.normal_(m.alpha.data, 1.0, std) + elif classname == "SnakeBeta": + if hasattr(m, "beta") and m.beta is not None: + nn.init.normal_(m.beta.data, 1.0, std) + elif ( + hasattr(m, "alpha") and m.alpha is not None + ): # Fallback if SnakeBeta uses alpha + nn.init.normal_(m.alpha.data, 1.0, std) + + elif hasattr(m, "residual_scale") and m.residual_scale is not None: + nn.init.normal_(m.residual_scale.data, 0.1, std) + + +class ResidualUnit(nn.Module): + def __init__( + self, + dim: int = 16, + dilation: int = 1, + activation: str = "snake", + alpha: float = 1.0, + scale_residual: bool = False, + ): + """ + Residual Unit with weight normalization and dilated convolutions. + Args: + dim (int): Number of input and output channels. + dilation (int): Dilation factor for the convolution. + activation (str): Activation function to use. + alpha (float): Scaling factor for the activation function. 
+ """ + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + get_activation(activation=activation, channels=dim, alpha=alpha), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + get_activation(activation=activation, channels=dim, alpha=alpha), + WNConv1d(dim, dim, kernel_size=1), + ) + self.scale_residual = scale_residual + if self.scale_residual: + self.res_scale = nn.Parameter(torch.tensor(0.0)) # start at 0 + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + if self.scale_residual: + y = self.res_scale * y + return x + y + + +class EncoderBlock(nn.Module): + def __init__( + self, + dim: int = 16, + stride: int = 1, + activation: str = "snake", + alpha: float = 1.0, + scale_residual: bool = False, + ): + """ + Encoder block that downsamples the input and applies residual units. + """ + super().__init__() + self.block = nn.Sequential( + ResidualUnit( + dim // 2, + dilation=1, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + ), + ResidualUnit( + dim // 2, + dilation=3, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + ), + ResidualUnit( + dim // 2, + dilation=9, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + ), + get_activation(activation=activation, channels=dim // 2, alpha=alpha), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 64, + d_in: int = 1, + activation: str = "snake", + alpha: float = 1.0, + scale_residual: bool = False, + weight_init: str = "xavier", + gain: float = 1.0, + ): + super().__init__() + # Create first convolution + self.block = [WNConv1d(d_in, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [ + EncoderBlock( + d_model, + stride=stride, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + ) + ] + + # Create last convolution + self.block += [ + get_activation(activation=activation, channels=d_model, alpha=alpha), + WNConv1d(d_model, d_latent, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + self.apply(lambda m: init_weights(m, init_type=weight_init, gain=gain)) + + def forward(self, x): + x = F.leaky_relu(x) + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__( + self, + input_dim: int = 16, + output_dim: int = 8, + stride: int = 1, + norm: bool = False, + activation: str = "snake", + alpha: float = 1.0, + scale_residual: bool = False, + ): + """ + Decoder block that upsamples the input and applies residual units. 
+ """ + super().__init__() + if not norm: + self.block = nn.Sequential( + get_activation(activation=activation, channels=input_dim, alpha=alpha), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=0 if stride % 2 == 0 else 1, + ), + ResidualUnit( + output_dim, + dilation=1, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + ), + ResidualUnit( + output_dim, + dilation=3, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + ), + ResidualUnit( + output_dim, + dilation=9, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + ), + ) + else: + self.block = nn.Sequential( + get_activation(activation=activation, channels=input_dim, alpha=alpha), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=0 if stride % 2 == 0 else 1, + ), + nn.BatchNorm1d(output_dim), + ResidualUnit( + output_dim, + dilation=1, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + ), + nn.BatchNorm1d(output_dim), + ResidualUnit( + output_dim, + dilation=3, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + ), + nn.BatchNorm1d(output_dim), + ResidualUnit( + output_dim, + dilation=9, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + ), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + norm: bool = False, + activation: str = "snake", + alpha: float = 1.0, + scale_residual: bool = False, + use_tanh_as_final: bool = True, + use_bias_at_final: bool = True, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [ + DecoderBlock( + input_dim, + output_dim, + stride, + norm=norm, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + ) + ] + + # Add final conv layer + layers += [ + get_activation(activation=activation, channels=output_dim, alpha=alpha), + WNConv1d( + output_dim, d_out, kernel_size=7, padding=3, bias=use_bias_at_final + ), + nn.Tanh() if use_tanh_as_final else nn.Identity(), + ] + self.use_tanh_as_final = use_tanh_as_final + + self.model = nn.Sequential(*layers) + + def forward(self, x): + x = self.model(x) + if not self.use_tanh_as_final: + x = torch.clamp( + x, min=-1.0, max=1.0 + ) # Ensure output is within [-1, 1] range + return x + + +class DACVAE(BaseModel, CodecMixin): + def __init__( + self, + encoder_dim: int = 64, + encoder_rates: List[int] = [2, 4, 5, 8], + latent_dim: int = 64, + decoder_dim: int = 1536, + decoder_rates: List[int] = [8, 5, 4, 2], + sample_rate: int = 44100, + d_in: int = 2, + d_out: int = 2, + weight_init: str = "xavier", + norm: bool = False, + activation: str = "snake", + alpha: float = 1.0, + gain: float = 0.02, + scale_residual: bool = False, + use_tanh_as_final: bool = True, + use_bias_at_final: bool = True, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + self.d_in = d_in + self.d_out = d_out + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** 
len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder( + encoder_dim, + encoder_rates, + latent_dim, + d_in=d_in, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + ) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + d_out=d_out, + norm=norm, + activation=activation, + alpha=alpha, + scale_residual=scale_residual, + use_tanh_as_final=use_tanh_as_final, + use_bias_at_final=use_bias_at_final, + ) + + self.en_conv_post = WNConv1d( + self.latent_dim, 2 * self.latent_dim, kernel_size=1 + ) + + self.de_conv_pre = WNConv1d(self.latent_dim, self.latent_dim, kernel_size=1) + + self.sample_rate = sample_rate + self.apply(lambda m: init_weights(m, init_type=weight_init, gain=gain)) + self.step = 0 # Initialize step counter for noise decay + + def freeze_encoder(self): + for param in self.encoder.parameters(): + param.requires_grad = False + for param in self.en_conv_post.parameters(): + param.requires_grad = False + print("Encoder and en_conv_post frozen") + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + length = audio_data.shape[-1] + # print(f"Audio length: {length}", "math.ceil(length / self.hop_length) * self.hop_length: ", math.ceil(length / self.hop_length) * self.hop_length) + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + training: bool = True, + ): + x = self.encoder(audio_data) + x = self.en_conv_post(x) + m, logs = torch.split(x, self.latent_dim, dim=1) + logs = torch.clamp(logs, min=-14.0, max=14.0) + + z = m + torch.randn_like(m) * torch.exp(logs) + + return z, m, logs + + def decode(self, z: torch.Tensor): + z = self.de_conv_pre(z) + z = self.decoder(z) + return z + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = 24000, + ): + # print(f"Audio data shape: {audio_data.shape}") + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + z, m, logs = self.encode(audio_data) + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "mu": m, + "logs": logs, + } + + +def WNConv1d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv1d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +def WNConv2d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv2d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +class MPD(nn.Module): + def __init__(self, period): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = 
self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 44100): + super().__init__() + self.convs = nn.ModuleList( + [ + WNConv1d(1, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, + window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. + """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = torch.view_as_real(x.stft()) + x = rearrange(x, "b 1 f t c -> (b 1) c t f") + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class Discriminator(ml.BaseModel): + def __init__( + self, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 44100, + bands: list = BANDS, + d_in: int = 1, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. 
+    periods : list, optional
+        periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
+    fft_sizes : list, optional
+        Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
+    sample_rate : int, optional
+        Sampling rate of audio in Hz, by default 44100
+    bands : list, optional
+        Bands to run MRD at, by default `BANDS`
+    """
+    super().__init__()
+    discs = []
+    discs += [MPD(p) for p in periods]
+    discs += [MSD(r, sample_rate=sample_rate) for r in rates]
+    discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
+    self.discriminators = nn.ModuleList(discs)
+
+    def preprocess(self, y):
+        # Remove DC offset
+        y = y - y.mean(dim=-1, keepdims=True)
+        # Peak normalize the volume of input audio
+        y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+        return y
+
+    def forward(self, x):
+        x = self.preprocess(x)
+        fmaps = [d(x) for d in self.discriminators]
+        return fmaps
+
+
+if __name__ == "__main__":
+    disc = Discriminator()
+    x = torch.zeros(1, 1, 44100)
+    results = disc(x)
+    for i, result in enumerate(results):
+        print(f"disc{i}")
+        for r in result:
+            print(r.shape, r.mean(), r.min(), r.max())
+        print()
diff --git a/flowae/models/ldm/dac/utils.py b/flowae/models/ldm/dac/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7724a9dde8a937e3c4e06146707915625d63595
--- /dev/null
+++ b/flowae/models/ldm/dac/utils.py
@@ -0,0 +1,45 @@
+import torch.nn as nn
+
+
+from models import register
+from .model import Encoder, Decoder, WNConv1d
+
+
+default_configs = {
+    'snake': dict(
+        encoder_dim=64,
+        encoder_rates=[2, 4, 5, 8],
+        latent_dim=64,
+        d_in=1,
+        activation='snake',
+    ),
+    'snakebeta': dict(
+        encoder_dim=64,
+        encoder_rates=[2, 4, 5, 8],
+        latent_dim=64,
+        d_in=1,
+        activation='snakebeta',
+    ),
+}
+
+
+@register('dac_encoder')
+def make_dac_encoder(config_name, **kwargs):
+    encoder_kwargs = default_configs[config_name]
+    encoder_kwargs.update(kwargs)
+    latent_dim = encoder_kwargs['latent_dim']
+    return nn.Sequential(
+        Encoder(**encoder_kwargs),
+        WNConv1d(latent_dim, latent_dim, kernel_size=1),
+    )
+
+
+@register('vqgan_decoder')
+def make_vqgan_decoder(config_name, **kwargs):
+    decoder_kwargs = default_configs[config_name]
+    decoder_kwargs.update(kwargs)
+    latent_dim = decoder_kwargs['latent_dim']
+    return nn.Sequential(
+        WNConv1d(latent_dim, latent_dim, kernel_size=1),
+        Decoder(**decoder_kwargs),
+    )
diff --git a/flowae/models/ldm/dito.py b/flowae/models/ldm/dito.py
new file mode 100644
index 0000000000000000000000000000000000000000..1559d99cb095c56de1991099234f613cb5a93984
--- /dev/null
+++ b/flowae/models/ldm/dito.py
@@ -0,0 +1,180 @@
+import copy
+import math
+
+import torch
+
+import models
+from omegaconf import OmegaConf
+from models import register
+from models.ldm.ldm_base import LDMBase
+from models.ldm.vqgan.lpips import LPIPS
+
+
+@register('dito')
+class DiTo(LDMBase):
+
+    def __init__(self, render_diffusion, render_sampler, render_n_steps, renderer_guidance=1, lpips=False, **kwargs):
+        super().__init__(**kwargs)
+        self.render_diffusion = models.make(render_diffusion)
+
+        if OmegaConf.is_config(render_sampler):
+            render_sampler = OmegaConf.to_container(render_sampler, resolve=True)
+        render_sampler = copy.deepcopy(render_sampler)
+        if render_sampler.get('args') is None:
+            render_sampler['args'] = {}
+        render_sampler['args']['diffusion'] = self.render_diffusion
+        self.render_sampler = models.make(render_sampler)
+        self.render_n_steps = render_n_steps
+        self.renderer_guidance = renderer_guidance
+
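+        # Exponential moving average of the diffusion loss per timestep decile,
+        # used only for logging (see the t-loss monitor block in forward below).
+        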
+        self.t_loss_monitor_v = [0 for _ in range(10)]
+        self.t_loss_monitor_n = [0 for _ in range(10)]
+        self.t_loss_monitor_decay = 0.99
+
+        self.use_lpips = lpips
+        if lpips:
+            self.lpips_loss = LPIPS().eval()
+
+    def render(self, z_dec, coord, scale):
+        shape = (coord.size(0), 3, coord.size(2), coord.size(3))
+        net_kwargs = {'coord': coord, 'scale': scale, 'z_dec': z_dec}
+
+        if self.use_ema_renderer:
+            self.swap_ema_renderer()
+
+        if self.renderer_guidance > 1:
+            uncond_z_dec = self.drop_z_emb.unsqueeze(0).expand(z_dec.shape[0], -1, -1, -1)
+            uncond_net_kwargs = {'coord': coord, 'scale': scale, 'z_dec': uncond_z_dec}
+        else:
+            uncond_net_kwargs = None
+
+        ret = self.render_sampler.sample(
+            net=self.renderer,
+            shape=shape,
+            n_steps=self.render_n_steps,
+            net_kwargs=net_kwargs,
+            uncond_net_kwargs=uncond_net_kwargs,
+            guidance=self.renderer_guidance,
+        )
+
+        if self.use_ema_renderer:
+            self.swap_ema_renderer()
+
+        return ret
+
+    def forward(self, data, mode, has_optimizer=None):
+        if mode in ['z', 'z_dec']:
+            ret_z, _ = super().forward(data, mode=mode, has_optimizer=has_optimizer)
+            return ret_z
+
+        grad = self.get_grad_plan(has_optimizer)
+        loss_config = self.loss_config
+        if mode == 'pred':
+            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
+
+            gt_patch = data['gt'][:, :3, ...]
+            coord = data['gt'][:, 3:5, ...]
+            scale = data['gt'][:, 5:7, ...]
+
+            if grad['renderer']:
+                return self.render(z_dec, coord, scale)
+            else:
+                with torch.no_grad():
+                    return self.render(z_dec, coord, scale)
+
+        elif mode == 'loss':
+            if not grad['renderer']:  # Only training zdm
+                _, ret = super().forward(data, mode='z', has_optimizer=has_optimizer)
+                return ret
+
+            gt_patch = data['gt'][:, :3, ...]
+            coord = data['gt'][:, 3:5, ...]
+            scale = data['gt'][:, 5:7, ...]
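+            # data['gt'] packs targets along the channel axis: channels 0:3
+            # are the RGB ground-truth patch, 3:5 the per-pixel coordinates,
+            # and 5:7 the per-pixel scales (matching the slicing above).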
+
+            z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer)
+            net_kwargs = {'z_dec': z_dec}
+
+            t = torch.rand(gt_patch.shape[0], device=gt_patch.device)
+
+            if self.gt_noise_lb is not None:
+                tmin = torch.ones_like(t) * self.gt_noise_lb
+                tmax = torch.ones_like(t) * 1
+                t = tmin + (tmax - tmin) * torch.rand_like(tmin)
+
+            if (self.zaug_p is not None) and self.training:
+                tz = self._tz
+                mask_aug = self._mask_aug
+
+                typ = self.zaug_decoding_loss_type
+                if typ == 'all':
+                    tmin = torch.ones_like(tz) * 0
+                    tmax = torch.ones_like(tz) * 1
+                elif typ == 'suffix':
+                    tmin = tz
+                    tmax = torch.ones_like(tz) * 1
+                elif typ == 'tz':
+                    tmin = tz
+                    tmax = tz
+                elif typ == 'tmax':
+                    tmin = torch.ones_like(tz) * 1
+                    tmax = torch.ones_like(tz) * 1
+                else:
+                    raise NotImplementedError
+                t_aug = tmin + (tmax - tmin) * torch.rand_like(tmin)
+
+                t = mask_aug * t_aug + (1 - mask_aug) * t
+            if not self.use_lpips:
+                loss, t = self.render_diffusion.loss(
+                    net=self.renderer,
+                    x=gt_patch,
+                    t=t,
+                    net_kwargs=net_kwargs,
+                    return_loss_unreduced=True
+                )
+            else:
+                loss, t, x_t, pred = self.render_diffusion.loss(
+                    net=self.renderer,
+                    x=gt_patch,
+                    t=t,
+                    net_kwargs=net_kwargs,
+                    return_loss_unreduced=True,
+                    return_all=True
+                )
+
+                sample_pred = x_t + t.view(-1, 1, 1, 1) * pred
+                lpips_loss = self.lpips_loss(sample_pred, gt_patch).mean()
+                ret['lpips_loss'] = lpips_loss.item()
+                lpips_loss_w = loss_config.get('lpips_loss', 1)
+                ret['loss'] = ret['loss'] + lpips_loss * lpips_loss_w
+
+            # Visualize diffusion network loss for different timesteps #
+            if self.training:
+                m = len(self.t_loss_monitor_v)
+                for i in range(len(loss)):
+                    q = min(math.floor(t[i].item() * m), m - 1)
+                    self.t_loss_monitor_v[q] = self.t_loss_monitor_v[q] * self.t_loss_monitor_decay + loss[i].item() * (1 - self.t_loss_monitor_decay)
+                    self.t_loss_monitor_n[q] += 1
+                for q in range(m):
+                    if self.t_loss_monitor_n[q] > 0:
+                        if self.t_loss_monitor_n[q] < 500:
+                            r = 1 - math.pow(self.t_loss_monitor_decay, self.t_loss_monitor_n[q])
+                        else:
+                            r = 1
+                        ret[f'_loss_t{q}'] = self.t_loss_monitor_v[q] / r
+            # - #
+
+            dae_loss = loss.mean()
+
+            ret['dae_loss'] = dae_loss.item()
+            dae_loss_w = loss_config.get('dae_loss', 1)
+            ret['loss'] = ret['loss'] + dae_loss * dae_loss_w
+            return ret
diff --git a/flowae/models/ldm/glpto.py b/flowae/models/ldm/glpto.py
new file mode 100644
index 0000000000000000000000000000000000000000..03cddc0747222259b0a2da5c92362c6dff1dde78
--- /dev/null
+++ b/flowae/models/ldm/glpto.py
@@ -0,0 +1,137 @@
+import os
+
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from models import register
+from models.ldm.ldm_base import LDMBase
+from models.ldm.vqgan.lpips import LPIPS
+from models.ldm.vqgan.discriminator import make_discriminator
+
+
+@register('glpto')
+class GLPTo(LDMBase):
+
+    def __init__(self, lpips=True, disc=True, adaptive_gan_weight=True, noise_render=False, **kwargs):
+        super().__init__(**kwargs)
+        if lpips:
+            self.lpips_loss = LPIPS().eval()
+        self.disc = make_discriminator(input_nc=3) if disc else None
+        self.adaptive_gan_weight = adaptive_gan_weight
+        self.noise_render = noise_render
+
+    def get_parameters(self, name):
+        if name == 'disc':
+            return self.disc.parameters()
+        else:
+            return super().get_parameters(name)
+
+    def render(self, z_dec, coord,
scale): + if not self.noise_render: + return self.renderer(z_dec, coord=coord, scale=scale) + else: + shape = (coord.shape[0], 3, coord.shape[2], coord.shape[3]) + noise = torch.randn(shape, device=z_dec.device) + return self.renderer(noise, coord=coord, scale=scale, z_dec=z_dec) + + def forward(self, data, mode, has_optimizer=None, use_gan=False): + if mode in ['z', 'z_dec']: + ret_z, _ = super().forward(data, mode=mode, has_optimizer=has_optimizer) + return ret_z + + grad = self.get_grad_plan(has_optimizer) + loss_config = self.loss_config + + if mode == 'pred': + z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer) + + gt_patch = data['gt'][:, :3, ...] + coord = data['gt'][:, 3:5, ...] + scale = data['gt'][:, 5:7, ...] + + if grad['renderer']: + return self.render(z_dec, coord, scale) + else: + with torch.no_grad(): + return self.render(z_dec, coord, scale) + + elif mode == 'loss': + if not grad['renderer']: # Only training zdm + _, ret = super().forward(data, mode='z', has_optimizer=has_optimizer) + return ret + + gt_patch = data['gt'][:, :3, ...] + coord = data['gt'][:, 3:5, ...] + scale = data['gt'][:, 5:7, ...] + + z_dec, ret = super().forward(data, mode='z_dec', has_optimizer=has_optimizer) + pred = self.render(z_dec, coord, scale) + + l1_loss = torch.abs(pred - gt_patch).mean() + ret['l1_loss'] = l1_loss.item() + l1_loss_w = loss_config.get('l1_loss', 1) + ret['loss'] = ret['loss'] + l1_loss * l1_loss_w + + lpips_loss = self.lpips_loss(pred, gt_patch).mean() + ret['lpips_loss'] = lpips_loss.item() + lpips_loss_w = loss_config.get('lpips_loss', 1) + ret['loss'] = ret['loss'] + lpips_loss * lpips_loss_w + + if use_gan: + logits_fake = self.disc(pred) + + gan_g_loss = -torch.mean(logits_fake) + ret['gan_g_loss'] = gan_g_loss.item() + weight = loss_config.get('gan_g_loss', 1) + + if self.training and self.adaptive_gan_weight: + nll_loss = l1_loss * l1_loss_w + lpips_loss * lpips_loss_w + adaptive_gan_w = self.calculate_adaptive_gan_w(nll_loss, gan_g_loss, self.renderer.get_last_layer_weight()) + ret['adaptive_gan_w'] = adaptive_gan_w.item() + weight = weight * adaptive_gan_w + + ret['loss'] = ret['loss'] + gan_g_loss * weight + + return ret + + elif mode == 'disc_loss': + gt_patch = data['gt'][:, :3, ...] + coord = data['gt'][:, 3:5, ...] + scale = data['gt'][:, 5:7, ...] + + with torch.no_grad(): + z_dec, _ = super().forward(data, mode='z_dec', has_optimizer=None) + pred = self.render(z_dec, coord, scale) + + logits_real = self.disc(gt_patch) + logits_fake = self.disc(pred) + + disc_loss_type = loss_config.get('disc_loss_type', 'hinge') + if disc_loss_type == 'hinge': + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) + loss = (loss_real + loss_fake) / 2 + elif disc_loss_type == 'vanilla': + loss_real = torch.mean(F.softplus(-logits_real)) + loss_fake = torch.mean(F.softplus(logits_fake)) + loss = (loss_real + loss_fake) / 2 + + return { + 'loss': loss, + 'disc_logits_real': logits_real.mean().item(), + 'disc_logits_fake': logits_fake.mean().item(), + } + + def calculate_adaptive_gan_w(self, nll_loss, g_loss, last_layer): + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + world_size = int(os.environ.get('WORLD_SIZE', '1')) + if world_size > 1: + dist.all_reduce(nll_grads, op=dist.ReduceOp.SUM) + nll_grads.div_(world_size) + dist.all_reduce(g_grads, op=dist.ReduceOp.SUM) + g_grads.div_(world_size) + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + return d_weight diff --git a/flowae/models/ldm/ldm_base.py b/flowae/models/ldm/ldm_base.py new file mode 100644 index 0000000000000000000000000000000000000000..69e8650214e2126f4d49dd6a3be59b2728647d2f --- /dev/null +++ b/flowae/models/ldm/ldm_base.py @@ -0,0 +1,444 @@ +import copy +import math + +import numpy as np +import torch +import torch.nn as nn +from omegaconf import OmegaConf + +import models +from models.ldm.vqgan.quantizer import VectorQuantizer + + +class LDMBase(nn.Module): + + def __init__( + self, + encoder, + z_shape, + decoder, + renderer, + encoder_ema_rate=None, + decoder_ema_rate=None, + renderer_ema_rate=None, + z_gaussian=False, + z_gaussian_sample=True, + z_quantizer=False, + z_quantizer_n_embed=8192, + z_quantizer_beta=0.25, + z_layernorm=False, + zaug_p=None, + zaug_tmax=1.0, + zaug_tmax_always=False, + zaug_decoding_loss_type='all', + zaug_zdm_diffusion=None, + gt_noise_lb=None, + drop_z_p=0.0, + zdm_net=None, + zdm_diffusion=None, + zdm_sampler=None, + zdm_n_steps=None, + zdm_ema_rate=0.9999, + zdm_train_normalize=False, + zdm_class_cond=None, + zdm_force_guidance=None, + loss_config=None, + use_ema_encoder=False, + use_ema_decoder=False, + use_ema_renderer=False, + ): + super().__init__() + self.loss_config = loss_config if loss_config is not None else dict() + + self.encoder = models.make(encoder) + self.decoder = models.make(decoder) + self.renderer = models.make(renderer) + + self.z_shape = tuple(z_shape) + + self.z_gaussian = z_gaussian + self.z_gaussian_sample = z_gaussian_sample + + self.z_quantizer = VectorQuantizer( + z_quantizer_n_embed, + z_shape[0], + beta=z_quantizer_beta, + remap=None, + sane_index_shape=False + ) if z_quantizer else None + + self.z_layernorm = nn.LayerNorm( + list(z_shape), + elementwise_affine=False + ) if z_layernorm else None + + self.zaug_p = zaug_p + self.zaug_tmax = zaug_tmax + self.zaug_tmax_always = zaug_tmax_always + self.zaug_decoding_loss_type = zaug_decoding_loss_type + if zaug_zdm_diffusion is not None: + self.zaug_zdm_diffusion = models.make(zaug_zdm_diffusion) + + self.drop_z_p = drop_z_p + if self.drop_z_p > 0: + self.drop_z_emb = nn.Parameter(torch.zeros(z_shape[0], z_shape[1], z_shape[2]), requires_grad=False) + + self.gt_noise_lb = gt_noise_lb + + # EMA models # + self.encoder_ema_rate = encoder_ema_rate + if self.encoder_ema_rate is not None: + self.encoder_ema = copy.deepcopy(self.encoder) + for p in self.encoder_ema.parameters(): + p.requires_grad = False + + self.decoder_ema_rate = decoder_ema_rate + if self.decoder_ema_rate is not None: + self.decoder_ema = copy.deepcopy(self.decoder) + for p 
in self.decoder_ema.parameters():
+                p.requires_grad = False
+
+        self.renderer_ema_rate = renderer_ema_rate
+        if self.renderer_ema_rate is not None:
+            self.renderer_ema = copy.deepcopy(self.renderer)
+            for p in self.renderer_ema.parameters():
+                p.requires_grad = False
+        # - #
+
+        # z DM #
+        if zdm_diffusion is not None:
+            self.zdm_diffusion = models.make(zdm_diffusion)
+
+            if OmegaConf.is_config(zdm_sampler):
+                zdm_sampler = OmegaConf.to_container(zdm_sampler, resolve=True)
+            zdm_sampler = copy.deepcopy(zdm_sampler)
+            if zdm_sampler.get('args') is None:
+                zdm_sampler['args'] = {}
+            zdm_sampler['args']['diffusion'] = self.zdm_diffusion
+            self.zdm_sampler = models.make(zdm_sampler)
+            self.zdm_n_steps = zdm_n_steps
+
+            self.zdm_net = models.make(zdm_net)
+
+            self.zdm_net_ema = copy.deepcopy(self.zdm_net)
+            for p in self.zdm_net_ema.parameters():
+                p.requires_grad = False
+            self.zdm_ema_rate = zdm_ema_rate
+
+            self.zdm_class_cond = zdm_class_cond
+
+            self.zdm_force_guidance = zdm_force_guidance
+        else:
+            self.zdm_diffusion = None
+
+        self.zdm_train_normalize = zdm_train_normalize
+        if zdm_train_normalize:
+            self.register_buffer('zdm_Ez_v', torch.tensor(0.))
+            self.register_buffer('zdm_Ez_n', torch.tensor(0.))
+            self.register_buffer('zdm_Ez2_v', torch.tensor(0.))
+            self.register_buffer('zdm_Ez2_n', torch.tensor(0.))
+        # - #
+
+        self.use_ema_encoder = use_ema_encoder
+        self.use_ema_decoder = use_ema_decoder
+        self.use_ema_renderer = use_ema_renderer
+
+    def get_parameters(self, name):
+        if name == 'encoder':
+            return self.encoder.parameters()
+        elif name == 'decoder':
+            p = list(self.decoder.parameters())
+            if self.z_quantizer is not None:
+                p += list(self.z_quantizer.parameters())
+            return p
+        elif name == 'renderer':
+            return self.renderer.parameters()
+        elif name == 'zdm':
+            return self.zdm_net.parameters()
+
+    def encode(self, x, return_loss=False, ret=None):
+        if self.use_ema_encoder:
+            self.swap_ema_encoder()
+
+        z = self.encoder(x)
+
+        if self.use_ema_encoder:
+            self.swap_ema_encoder()
+
+        if self.z_gaussian:
+            posterior = DiagonalGaussianDistribution(z)
+            if self.z_gaussian_sample:
+                z = posterior.sample()
+            else:
+                z = posterior.mode()
+            kl_loss = posterior.kl().mean()
+
+            if ret is not None:
+                ret['z_gau_mean_abs'] = posterior.mean.abs().mean().item()
+                ret['z_gau_std'] = posterior.std.mean().item()
+        else:
+            kl_loss = None
+
+        if self.z_layernorm is not None:
+            z = self.z_layernorm(z)
+
+        if (self.zaug_p is not None) and self.training:
+            assert self.z_layernorm is not None  # ensure 0 mean 1 std
+            if self.zaug_tmax_always:
+                tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax
+            else:
+                tz = torch.rand(z.shape[0], device=z.device) * self.zaug_tmax
+            zt, _ = self.zaug_zdm_diffusion.add_noise(z, tz)
+            mask_aug = (torch.rand(z.shape[0], device=z.device) < self.zaug_p).float()
+            z = mask_aug.view(-1, 1, 1, 1) * zt + (1 - mask_aug).view(-1, 1, 1, 1) * z
+            self._tz = tz
+            self._mask_aug = mask_aug
+
+        if return_loss:
+            return z, kl_loss
+        else:
+            return z
+
+    def decode(self, z, return_loss=False):
+        if self.z_quantizer is not None:
+            z, quant_loss, _ = self.z_quantizer(z)
+        else:
+            quant_loss = None
+
+        if self.use_ema_decoder:
+            self.swap_ema_decoder()
+
+        z_dec = self.decoder(z)
+
+        if self.use_ema_decoder:
+            self.swap_ema_decoder()
+
+        if return_loss:
+            return z_dec, quant_loss
+        else:
+            return z_dec
+
+    def render(self, z_dec, coord, cell):
+        raise NotImplementedError
+
+    def normalize_for_zdm(self, z):
+        if self.zdm_train_normalize:
+            mean = self.zdm_Ez_v
+            var = self.zdm_Ez2_v - mean ** 2
+            return (z - mean) / torch.sqrt(var)
+        else:
+            return z
+
+    def denormalize_for_zdm(self, z):
+        if self.zdm_train_normalize:
+            mean = self.zdm_Ez_v
+            var = self.zdm_Ez2_v - mean ** 2
+            return z * torch.sqrt(var) + mean
+        else:
+            return z
+
+    def forward(self, data, mode, has_optimizer=None):
+        grad = self.get_grad_plan(has_optimizer)
+        loss = torch.tensor(0., device=data['inp'].device)
+        loss_config = self.loss_config
+        ret = dict()
+
+        # Encoder
+        if grad['encoder']:
+            z, kl_loss = self.encode(data['inp'], return_loss=True, ret=ret)
+
+            # if self.z_gaussian:
+            #     ret['kl_loss'] = kl_loss.item()
+            #     loss = loss + kl_loss * loss_config.get('kl_loss', 0.0)
+        else:
+            with torch.no_grad():
+                z, kl_loss = self.encode(data['inp'], return_loss=True, ret=ret)
+
+        if self.training and self.drop_z_p > 0:
+            drop_mask = (torch.rand(z.shape[0], device=z.device) < self.drop_z_p).to(z.dtype)
+            z = drop_mask.view(-1, 1, 1, 1) * self.drop_z_emb.unsqueeze(0) + (1 - drop_mask).view(-1, 1, 1, 1) * z
+
+        # Z DM
+        if grad['zdm']:
+            if self.zdm_train_normalize and self.training:
+                self.zdm_Ez_v = (
+                    self.zdm_Ez_v * (self.zdm_Ez_n / (self.zdm_Ez_n + 1))
+                    + z.mean().item() / (self.zdm_Ez_n + 1)
+                )
+                self.zdm_Ez_n = self.zdm_Ez_n + 1
+
+                self.zdm_Ez2_v = (
+                    self.zdm_Ez2_v * (self.zdm_Ez2_n / (self.zdm_Ez2_n + 1))
+                    + (z ** 2).mean().item() / (self.zdm_Ez2_n + 1)
+                )
+                self.zdm_Ez2_n = self.zdm_Ez2_n + 1
+
+                ret['normalize_z_mean'] = self.zdm_Ez_v.item()
+                ret['normalize_z_std'] = math.sqrt((self.zdm_Ez2_v - self.zdm_Ez_v ** 2).item())
+
+            z_for_dm = self.normalize_for_zdm(z)
+
+            net_kwargs = dict()
+            if self.zdm_class_cond is not None:
+                net_kwargs['class_labels'] = data['class_labels']
+
+            zdm_loss = self.zdm_diffusion.loss(self.zdm_net, z_for_dm, net_kwargs=net_kwargs)
+            ret['zdm_loss'] = zdm_loss.item()
+            loss = loss + zdm_loss * loss_config.get('zdm_loss', 1.0)
+
+            if not self.training:
+                ret['zdm_ema_loss'] = self.zdm_diffusion.loss(self.zdm_net_ema, z_for_dm, net_kwargs=net_kwargs).item()
+
+        # Decoder
+        if mode == 'z':
+            ret_z = z
+        elif mode == 'z_dec':
+            if grad['decoder']:
+                z_dec, quant_loss = self.decode(z, return_loss=True)
+            else:
+                with torch.no_grad():
+                    z_dec, quant_loss = self.decode(z, return_loss=True)
+            ret_z = z_dec
+
+            # if self.z_quantizer is not None:
+            #     ret['quant_loss'] = quant_loss.item()
+            #     loss = loss + quant_loss * loss_config.get('quant_loss', 1.0)
+
+        ret['loss'] = loss
+        return ret_z, ret
+
+    def get_grad_plan(self, has_optimizer):
+        if has_optimizer is None:
+            has_optimizer = dict()
+        grad = dict()
+        grad['encoder'] = has_optimizer.get('encoder', False)
+        grad['decoder'] = grad['encoder'] or has_optimizer.get('decoder', False)
+        grad['renderer'] = grad['decoder'] or has_optimizer.get('renderer', False)
+        grad['zdm'] = has_optimizer.get('zdm', False)  # not in chain definition
+        return grad
+
+    def update_ema_fn(self, net_ema, net, rate):
+        if rate != 1:
+            for ema_p, cur_p in zip(net_ema.parameters(), net.parameters()):
+                ema_p.data.lerp_(cur_p.data, 1 - rate)
+
+    def update_ema(self):
+        if self.encoder_ema_rate is not None:
+            self.update_ema_fn(self.encoder_ema, self.encoder, self.encoder_ema_rate)
+        if self.decoder_ema_rate is
not None: + self.update_ema_fn(self.decoder_ema, self.decoder, self.decoder_ema_rate) + if self.renderer_ema_rate is not None: + self.update_ema_fn(self.renderer_ema, self.renderer, self.renderer_ema_rate) + if (self.zdm_diffusion is not None) and (self.zdm_ema_rate is not None): + self.update_ema_fn(self.zdm_net_ema, self.zdm_net, self.zdm_ema_rate) + + def generate_samples( + self, + batch_size, + n_steps, + net_kwargs=None, + uncond_net_kwargs=None, + ema=False, + guidance=1.0, + noise=None, + render_res=(256, 256), + return_z=False, + ): + if self.zdm_force_guidance is not None: + guidance = self.zdm_force_guidance + + shape = (batch_size,) + self.z_shape + net = self.zdm_net if not ema else self.zdm_net_ema + + z = self.zdm_sampler.sample( + net, + shape, + n_steps, + net_kwargs=net_kwargs, + uncond_net_kwargs=uncond_net_kwargs, + guidance=guidance, + noise=noise, + ) + + if return_z: + return z + + if (self.zaug_p is not None) and self.zaug_tmax_always: + tz = torch.ones(z.shape[0], device=z.device) * self.zaug_tmax + z, _ = self.zaug_zdm_diffusion.add_noise(z, tz) + + z = self.denormalize_for_zdm(z) + z_dec = self.decode(z) + + coord = torch.zeros(batch_size, 2, render_res[0], render_res[1], device=z_dec.device) + scale = torch.zeros(batch_size, 2, render_res[0], render_res[1], device=z_dec.device) + return self.render(z_dec, coord, scale) + + def swap_ema_encoder(self): + _ = self.encoder + self.encoder = self.encoder_ema + self.encoder_ema = _ + + def swap_ema_decoder(self): + _ = self.decoder + self.decoder = self.decoder_ema + self.decoder_ema = _ + + def swap_ema_renderer(self): + _ = self.renderer + self.renderer = self.renderer_ema + self.renderer_ema = _ + + +class DiagonalGaussianDistribution(object): + + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean diff --git a/flowae/models/ldm/renderers.py b/flowae/models/ldm/renderers.py new file mode 100644 index 0000000000000000000000000000000000000000..c7fcd2189dd0ad26fd313f738860d7d5eda39b08 --- /dev/null +++ b/flowae/models/ldm/renderers.py @@ -0,0 +1,18 @@ +import torch.nn as nn + +import models +from models import register + + +@register('fixres_renderer_wrapper') +class FixresRendererWrapper(nn.Module): + + def __init__(self, net): + super().__init__() + self.net = models.make(net) + + def forward(self, x, coord=None, scale=None, **kwargs): + return self.net(x, **kwargs) + + def 
get_last_layer_weight(self): + return self.net.get_last_layer_weight() diff --git a/flowae/models/ldm/vqgan/__init__.py b/flowae/models/ldm/vqgan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16281fe0b66dbac563229823d656ef173736e306 --- /dev/null +++ b/flowae/models/ldm/vqgan/__init__.py @@ -0,0 +1 @@ +from .utils import * diff --git a/flowae/models/ldm/vqgan/discriminator.py b/flowae/models/ldm/vqgan/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..b420d617d08bc2625944f8f8dead7aa60c866808 --- /dev/null +++ b/flowae/models/ldm/vqgan/discriminator.py @@ -0,0 +1,154 @@ +import functools +import torch +import torch.nn as nn + + +def make_discriminator(**kwargs): + return NLayerDiscriminator(**kwargs).apply(weights_init) + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, + allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + 
) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:,:,None,None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height*width*torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:,:,None,None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h diff --git a/flowae/models/ldm/vqgan/lpips.py b/flowae/models/ldm/vqgan/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..289e50e34c8945a944486d7f0ee04f76de68eb78 --- /dev/null +++ b/flowae/models/ldm/vqgan/lpips.py @@ -0,0 +1,113 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +import torch +import torch.nn as nn +from torchvision import models +from collections import namedtuple + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, ckpt='load/vgg_lpips.pth', use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained(ckpt) + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, ckpt): + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + # print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', 
torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + if pretrained: + vgg_pretrained_features = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features + else: + vgg_pretrained_features = models.vgg16().features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2,3],keepdim=keepdim) diff --git a/flowae/models/ldm/vqgan/model.py b/flowae/models/ldm/vqgan/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6b82e6314bfd43a87ad48c45db6434c60bbc8662 --- /dev/null +++ b/flowae/models/ldm/vqgan/model.py @@ -0,0 +1,845 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from models import register + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. 
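+    Concretely, with half = embedding_dim // 2, the code below computes
+        emb[t, i]        = sin(t * 10000^(-i / (half - 1)))   for i < half
+        emb[t, half + i] = cos(t * 10000^(-i / (half - 1)))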
+ This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, kernel_size=3, conv_shortcut=False, + dropout, temb_channels=512, normalize=True): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) if normalize else torch.nn.Identity() + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) if normalize else torch.nn.Identity() + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + 
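+        # One head whose dim_head equals the full channel count, so this
+        # block matches AttnBlock's input/output shapes.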
super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + 
out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8,16), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + print('ch_mult: ', ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + print('num_resolutions: ', self.num_resolutions) + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + 
out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + z_out = 2*z_channels if double_z else z_channels + print('z_out: ', z_out) + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + z_out, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + print('encoder h shape: ', h.shape) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + # print("Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, 
attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + if not self.give_pre_end: + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +@register('simple_renderer_net') +class SimpleRendererNet(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels=3, kernel_size=3, normalize=True, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, hidden_channels, kernel_size, padding=(kernel_size - 1) // 2), + ResnetBlock(in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=kernel_size, + temb_channels=0, dropout=0.0, normalize=normalize), + ResnetBlock(in_channels=hidden_channels, + out_channels=hidden_channels, + kernel_size=kernel_size, + temb_channels=0, dropout=0.0, normalize=normalize)]) + self.norm_out = Normalize(hidden_channels) if normalize else torch.nn.Identity() + self.conv_out = torch.nn.Conv2d(hidden_channels, out_channels, kernel_size, padding=(kernel_size - 1) // 2) + + def get_last_layer_weight(self): + return self.conv_out.weight + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +@register('vqgan_last_conv') +class VQGANLastConv(nn.Module): + def __init__(self, in_channels, out_channels=3): + super().__init__() + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def get_last_layer_weight(self): + return self.conv_out.weight + + def forward(self, x): + x = self.norm_out(x) + x = nonlinearity(x) + x = self.conv_out(x) + return x + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in 
[1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, 
resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+        super().__init__()
+        tmp_chn = z_channels*ch_mult[-1]
+        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+                               resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+                               ch_mult=ch_mult, resolution=resolution, ch=ch)
+        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+                                       out_channels=tmp_chn, depth=rescale_module_depth)
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Upsampler(nn.Module):
+    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+        super().__init__()
+        assert out_size >= in_size
+        num_blocks = int(np.log2(out_size//in_size))+1
+        factor_up = 1. + (out_size % in_size)
+        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+                                       out_channels=in_channels)
+        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+                               attn_resolutions=[], in_channels=None, ch=in_channels,
+                               ch_mult=[ch_mult for _ in range(num_blocks)])
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Resize(nn.Module):
+    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+        super().__init__()
+        self.with_conv = learned
+        self.mode = mode
+        if self.with_conv:
+            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
+            raise NotImplementedError()
+            assert in_channels is not None
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=4,
+                                        stride=2,
+                                        padding=1)
+
+    def forward(self, x, scale_factor=1.0):
+        if scale_factor==1.0:
+            return x
+        else:
+            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+        return x
diff --git a/flowae/models/ldm/vqgan/quantizer.py b/flowae/models/ldm/vqgan/quantizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..863475a68248d28775708875189144d2704540cc
--- /dev/null
+++ b/flowae/models/ldm/vqgan/quantizer.py
@@ -0,0 +1,123 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+
+class VectorQuantizer(nn.Module):
+    """
+    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+    """
+    # NOTE: due to a bug the beta term was applied to the wrong term. for
+    # backwards compatibility we use the buggy version by default, but you can
+    # specify legacy=False to fix it.
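+    # Concretely (sg = stop-gradient; see the loss computed in forward):
+    #   legacy=True  (buggy): loss = ||sg[z_q] - z||^2 + beta * ||z_q - sg[z]||^2
+    #   legacy=False (fixed): loss = beta * ||sg[z_q] - z||^2 + ||z_q - sg[z]||^2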
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" + assert rescale_logits==False, "Only for interface compatible with Gumbel" + assert return_logits==False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if 
self.remap is not None:
+            indices = indices.reshape(shape[0], -1)  # add batch axis
+            indices = self.unmap_to_all(indices)
+            indices = indices.reshape(-1)  # flatten again
+
+        # get quantized latent vectors
+        z_q = self.embedding(indices)
+
+        if shape is not None:
+            z_q = z_q.view(shape)
+            # reshape back to match original input shape
+            z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+        return z_q
diff --git a/flowae/models/ldm/vqgan/utils.py b/flowae/models/ldm/vqgan/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f895a1ab1d3d32cb6cdae7ef1bfca5765bd16b35
--- /dev/null
+++ b/flowae/models/ldm/vqgan/utils.py
@@ -0,0 +1,57 @@
+import torch.nn as nn
+
+
+from models import register
+from .model import Encoder, Decoder
+
+
+default_configs = {
+    'f8c4': dict(
+        double_z=False,
+        z_channels=64,
+        resolution=256,
+        in_channels=3,
+        out_ch=3,
+        ch=128,
+        ch_mult=[1, 2, 2, 4, 4, 4, 4, 8, 8],
+        num_res_blocks=2,
+        attn_resolutions=[],
+        dropout=0.0,
+        give_pre_end=True,
+    ),
+    'f16c8': dict(
+        double_z=False,
+        z_channels=8,
+        resolution=256,
+        in_channels=3,
+        out_ch=3,
+        ch=128,
+        ch_mult=[1, 2, 4, 4, 4],
+        num_res_blocks=2,
+        attn_resolutions=[],
+        dropout=0.0,
+        give_pre_end=True,
+    ),
+}
+
+
+@register('vqgan_encoder')
+def make_vqgan_encoder(config_name, **kwargs):
+    encoder_kwargs = dict(default_configs[config_name])  # copy, so per-call overrides don't mutate the shared defaults
+    encoder_kwargs.update(kwargs)
+    enc_out_channels = encoder_kwargs['z_channels'] * (2 if encoder_kwargs['double_z'] else 1)
+    return nn.Sequential(
+        Encoder(**encoder_kwargs),
+        nn.Conv2d(enc_out_channels, enc_out_channels, 1),
+    )
+
+
+@register('vqgan_decoder')
+def make_vqgan_decoder(config_name, **kwargs):
+    decoder_kwargs = dict(default_configs[config_name])  # copy, so per-call overrides don't mutate the shared defaults
+    decoder_kwargs.update(kwargs)
+    dec_in_channels = decoder_kwargs['z_channels']
+    return nn.Sequential(
+        nn.Conv2d(dec_in_channels, dec_in_channels, 1),
+        Decoder(**decoder_kwargs),
+    )
diff --git a/flowae/models/models.py b/flowae/models/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..56e8d189cb0418a75de81461dadb472691590cca
--- /dev/null
+++ b/flowae/models/models.py
@@ -0,0 +1,44 @@
+import torch
+
+
+models = dict()
+
+
+def register(name):
+    def decorator(cls):
+        models[name] = cls
+        return cls
+    return decorator
+
+
+def load_sd_from_ckpt(ckpt, keys_only=None):
+    sd = torch.load(ckpt, map_location='cpu')['model']['sd']
+    if keys_only is not None:
+        keys_only_dot = tuple([_ + '.' for _ in keys_only])
+        keys_only = set(keys_only)
+        for k in list(sd.keys()):
+            if not (k in keys_only or k.startswith(keys_only_dot)):
+                sd.pop(k)
+    return sd
+
+
+def make(spec, load_sd=False):
+    args = spec.get('args')
+    if args is None:
+        args = dict()
+    model = models[spec['name']](**args)
+    print('args', args)
+
+    if spec.get('load_ckpt') is not None:
+        sd = load_sd_from_ckpt(spec['load_ckpt'], spec.get('load_ckpt_keys_only'))
+        model.load_state_dict(sd, strict=False)
+
+    if load_sd:
+        model.load_state_dict(spec['sd'])
+
+    return model
+
+
+@register('identity')
+def make_identity():
+    return torch.nn.Identity()
diff --git a/flowae/models/networks/__init__.py b/flowae/models/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8014d9824e1b32c71956711e6dcd307e9a4c5920
--- /dev/null
+++ b/flowae/models/networks/__init__.py
@@ -0,0 +1,2 @@
+from . import consistency_decoder_unet
+from . 
import dit \ No newline at end of file diff --git a/flowae/models/networks/consistency_decoder_unet.py b/flowae/models/networks/consistency_decoder_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..2a729348e88d988901cb7cada77eb699523c603b --- /dev/null +++ b/flowae/models/networks/consistency_decoder_unet.py @@ -0,0 +1,268 @@ +# https://gist.github.com/mrsteyk/74ad3ec2f6f823111ae4c90e168505ac + +import torch +import torch.nn.functional as F +import torch.nn as nn + +from models import register + + +class TimestepEmbedding(nn.Module): + def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None: + super().__init__() + self.emb = nn.Embedding(n_time, n_emb) + self.f_1 = nn.Linear(n_emb, n_out) + self.f_2 = nn.Linear(n_out, n_out) + + def forward(self, x) -> torch.Tensor: + x = self.emb(x) + x = self.f_1(x) + x = F.silu(x) + return self.f_2(x) + + +class PositionalEmbedding(nn.Module): + def __init__(self, pe_dim=320, out_dim=1280, max_positions=10000, endpoint=True): + super().__init__() + self.num_channels = pe_dim + self.max_positions = max_positions + self.endpoint = endpoint + self.f_1 = nn.Linear(pe_dim, out_dim) + self.f_2 = nn.Linear(out_dim, out_dim) + + def forward(self, x): + freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device) + freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0)) + freqs = (1 / self.max_positions) ** freqs + x = x.ger(freqs.to(x.dtype)) + x = torch.cat([x.cos(), x.sin()], dim=1) + + x = self.f_1(x) + x = F.silu(x) + return self.f_2(x) + + +class ImageEmbedding(nn.Module): + def __init__(self, in_channels, out_channels=320) -> None: + super().__init__() + self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + + def forward(self, x) -> torch.Tensor: + return self.f(x) + + +class ImageUnembedding(nn.Module): + def __init__(self, in_channels=320, out_channels=3) -> None: + super().__init__() + self.gn = nn.GroupNorm(32, in_channels) + self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + + def forward(self, x) -> torch.Tensor: + return self.f(F.silu(self.gn(x))) + + +class ConvResblock(nn.Module): + def __init__(self, in_features, out_features, t_dim) -> None: + super().__init__() + self.f_t = nn.Linear(t_dim, out_features * 2) + + self.gn_1 = nn.GroupNorm(32, in_features) + self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1) + + self.gn_2 = nn.GroupNorm(32, out_features) + self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1) + + skip_conv = in_features != out_features + self.f_s = ( + nn.Conv2d(in_features, out_features, kernel_size=1, padding=0) + if skip_conv + else nn.Identity() + ) + + def forward(self, x, t): + x_skip = x + t = self.f_t(F.silu(t)) + t = t.chunk(2, dim=1) + t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1 + t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3) + + gn_1 = F.silu(self.gn_1(x)) + f_1 = self.f_1(gn_1) + + gn_2 = self.gn_2(f_1) + + return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2)) + + +# Also ConvResblock +class Downsample(nn.Module): + def __init__(self, in_channels, t_dim) -> None: + super().__init__() + self.f_t = nn.Linear(t_dim, in_channels * 2) + + self.gn_1 = nn.GroupNorm(32, in_channels) + self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + self.gn_2 = nn.GroupNorm(32, in_channels) + + self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, x, t) -> torch.Tensor: + x_skip = x + + t = 
self.f_t(F.silu(t)) + t_1, t_2 = t.chunk(2, dim=1) + t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1 + t_2 = t_2.unsqueeze(2).unsqueeze(3) + + gn_1 = F.silu(self.gn_1(x)) + avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None) + f_1 = self.f_1(avg_pool2d) + gn_2 = self.gn_2(f_1) + + f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2))) + + return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None) + + +# Also ConvResblock +class Upsample(nn.Module): + def __init__(self, in_channels, t_dim) -> None: + super().__init__() + self.f_t = nn.Linear(t_dim, in_channels * 2) + + self.gn_1 = nn.GroupNorm(32, in_channels) + self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + self.gn_2 = nn.GroupNorm(32, in_channels) + + self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, x, t) -> torch.Tensor: + x_skip = x + + t = self.f_t(F.silu(t)) + t_1, t_2 = t.chunk(2, dim=1) + t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1 + t_2 = t_2.unsqueeze(2).unsqueeze(3) + + gn_1 = F.silu(self.gn_1(x)) + upsample = F.upsample_nearest(gn_1, scale_factor=2) + f_1 = self.f_1(upsample) + gn_2 = self.gn_2(f_1) + + f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2))) + + return f_2 + F.upsample_nearest(x_skip, scale_factor=2) + + +@register('consistency_decoder_unet') +class ConsistencyDecoderUNet(nn.Module): + def __init__(self, in_channels=3, z_dec_channels=None, c0=320, c1=640, c2=1024, pe_dim=320, t_dim=1280) -> None: + super().__init__() + if z_dec_channels is not None: + in_channels += z_dec_channels + self.embed_image = ImageEmbedding(in_channels=in_channels, out_channels=c0) + self.embed_time = PositionalEmbedding(pe_dim=pe_dim, out_dim=t_dim) + + down_0 = nn.ModuleList([ + ConvResblock(c0, c0, t_dim), + ConvResblock(c0, c0, t_dim), + ConvResblock(c0, c0, t_dim), + Downsample(c0, t_dim), + ]) + down_1 = nn.ModuleList([ + ConvResblock(c0, c1, t_dim), + ConvResblock(c1, c1, t_dim), + ConvResblock(c1, c1, t_dim), + Downsample(c1, t_dim), + ]) + down_2 = nn.ModuleList([ + ConvResblock(c1, c2, t_dim), + ConvResblock(c2, c2, t_dim), + ConvResblock(c2, c2, t_dim), + Downsample(c2, t_dim), + ]) + down_3 = nn.ModuleList([ + ConvResblock(c2, c2, t_dim), + ConvResblock(c2, c2, t_dim), + ConvResblock(c2, c2, t_dim), + ]) + self.down = nn.ModuleList([ + down_0, + down_1, + down_2, + down_3, + ]) + + self.mid = nn.ModuleList([ + ConvResblock(c2, c2, t_dim), + ConvResblock(c2, c2, t_dim), + ]) + + up_3 = nn.ModuleList([ + ConvResblock(c2 * 2, c2, t_dim), + ConvResblock(c2 * 2, c2, t_dim), + ConvResblock(c2 * 2, c2, t_dim), + ConvResblock(c2 * 2, c2, t_dim), + Upsample(c2, t_dim), + ]) + up_2 = nn.ModuleList([ + ConvResblock(c2 * 2, c2, t_dim), + ConvResblock(c2 * 2, c2, t_dim), + ConvResblock(c2 * 2, c2, t_dim), + ConvResblock(c2 + c1, c2, t_dim), + Upsample(c2, t_dim), + ]) + up_1 = nn.ModuleList([ + ConvResblock(c2 + c1, c1, t_dim), + ConvResblock(c1 * 2, c1, t_dim), + ConvResblock(c1 * 2, c1, t_dim), + ConvResblock(c0 + c1, c1, t_dim), + Upsample(c1, t_dim), + ]) + up_0 = nn.ModuleList([ + ConvResblock(c0 + c1, c0, t_dim), + ConvResblock(c0 * 2, c0, t_dim), + ConvResblock(c0 * 2, c0, t_dim), + ConvResblock(c0 * 2, c0, t_dim), + ]) + self.up = nn.ModuleList([ + up_0, + up_1, + up_2, + up_3, + ]) + + self.output = ImageUnembedding(in_channels=c0) + + def get_last_layer_weight(self): + return self.output.f.weight + + def forward(self, x, t=None, z_dec=None) -> torch.Tensor: + if z_dec is not None: + if z_dec.shape[-2] != x.shape[-2] or z_dec.shape[-1] != x.shape[-1]: + assert x.shape[-2] 
// z_dec.shape[-2] == x.shape[-1] // z_dec.shape[-1] + z_dec = F.upsample_nearest(z_dec, scale_factor=x.shape[-2] // z_dec.shape[-2]) + x = torch.cat([x, z_dec], dim=1) + + x = self.embed_image(x) + + if t is None: + t = torch.zeros(x.shape[0], device=x.device) + t = self.embed_time(t) + + skips = [x] + for down in self.down: + for block in down: + x = block(x, t) + skips.append(x) + + for mid in self.mid: + x = mid(x, t) + + for up in self.up[::-1]: + for block in up: + if isinstance(block, ConvResblock): + x = torch.concat([x, skips.pop()], dim=1) + x = block(x, t) + + return self.output(x) diff --git a/flowae/models/networks/dit.py b/flowae/models/networks/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..439a6331f10c13a9eeb2c81793ab933475b61c2b --- /dev/null +++ b/flowae/models/networks/dit.py @@ -0,0 +1,384 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# GLIDE: https://github.com/openai/glide-text2im +# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py +# -------------------------------------------------------- +import math + +import torch +import torch.nn as nn +import numpy as np +from timm.models.vision_transformer import PatchEmbed, Attention, Mlp + +from models import register + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 
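+    With dropout_prob > 0, one extra "null" row (index num_classes) is added
+    to the embedding table; token_drop remaps dropped labels to it, giving the
+    model an unconditional pathway for classifier-free guidance at sampling time.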
+ """ + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +################################################################################# +# Core DiT Model # +################################################################################# + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DiT(nn.Module): + """ + Diffusion model with a Transformer backbone. 
+ """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.0, + n_classes=1000, + learn_sigma=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(n_classes, hidden_size, class_dropout_prob) + num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False) + + self.blocks = nn.ModuleList([ + DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth) + ]) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5)) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + return imgs + + def forward(self, x, t, class_labels): + """ + Forward pass of DiT. 
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(t) # (N, D) + y = self.y_embedder(class_labels, self.training) # (N, D) + c = t + y # (N, D) + for block in self.blocks: + x = block(x, c) # (N, T, D) + x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x) # (N, out_channels, H, W) + return x + + def forward_with_cfg(self, x, t, y, cfg_scale): + """ + Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, y) + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +################################################################################# +# DiT Configs # +################################################################################# + +@register('dit_xl_2') +def DiT_XL_2(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) + +@register('dit_xl_4') +def DiT_XL_4(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs) + +@register('dit_xl_8') +def DiT_XL_8(**kwargs): + return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs) + +@register('dit_l_2') +def DiT_L_2(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) + +@register('dit_l_4') +def DiT_L_4(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs) + +@register('dit_l_8') +def DiT_L_8(**kwargs): + return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs) + +@register('dit_b_2') +def DiT_B_2(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) + +@register('dit_b_4') +def DiT_B_4(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs) + +@register('dit_b_8') +def DiT_B_8(**kwargs): + return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs) + +@register('dit_s_2') +def DiT_S_2(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) + +@register('dit_s_4') +def DiT_S_4(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs) + +@register('dit_s_8') +def DiT_S_8(**kwargs): + return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs) + + +DiT_models = { + 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8, + 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, + 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, + 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8, +} \ No newline at end of file diff --git a/flowae/reconstruction.py b/flowae/reconstruction.py new file mode 100644 index 0000000000000000000000000000000000000000..a3fee6a4b8779c9cdbdadba258d2589024b6478d --- /dev/null +++ b/flowae/reconstruction.py @@ -0,0 +1,189 @@ +import torch +import torch.nn as nn +from PIL import Image +from torchvision import transforms +import numpy as np +from pathlib import Path +import argparse + +# You'll need to have the DiTo codebase available +import models +from omegaconf import OmegaConf + +class DiToInference: + def __init__(self, checkpoint_path, device='cuda'): + """Initialize DiTo model from checkpoint""" + self.device = device + + # Load checkpoint + print(f"Loading checkpoint from {checkpoint_path}") + ckpt = torch.load(checkpoint_path, map_location='cpu') + + # Extract config + self.config = OmegaConf.create(ckpt['config']) + + # Create model + self.model = models.make(self.config['model']) + + # Load state dict + self.model.load_state_dict(ckpt['model']['sd']) + + # Move to device and set to eval + self.model = self.model.to(device) + self.model.eval() + + # Setup image transforms based on config + self.transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(256), + transforms.ToTensor(), + 
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + ]) + + print("Model loaded successfully!") + + def reconstruct_image(self, image_path, debug=True): + """Reconstruct a single image""" + # Load and preprocess image + image = Image.open(image_path).convert('RGB') + + if debug: + debug_transform = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(256), + ]) + debug_image = debug_transform(image) + debug_image.save('debug_1_resized_cropped.png') + print("Saved debug_1_resized_cropped.png") + + image_tensor = self.transform(image).unsqueeze(0).to(self.device) + + with torch.no_grad(): + # Step 1: Encode to latent + z = self.model.encode(image_tensor) + + # Step 2: Decode to features (in DiTo this is identity) + z_dec = self.model.decode(z) + print('z_dec.shape:', z_dec.shape) + + # Step 3: Prepare coordinate grids + # Based on the training code, coord and scale are dummy values + b, c, h, w = image_tensor.shape + coord = torch.zeros(b, 2, h, w, device=self.device) + scale = torch.zeros(b, 2, h, w, device=self.device) + + # Step 4: Render using diffusion + reconstructed = self.model.render(z_dec, coord, scale) + + # Denormalize from [-1, 1] to [0, 1] + reconstructed = (reconstructed * 0.5 + 0.5).clamp(0, 1) + + return reconstructed + + def save_reconstruction(self, image_path, output_path): + """Reconstruct and save image""" + reconstructed = self.reconstruct_image(image_path) + + # Convert to PIL + to_pil = transforms.ToPILImage() + reconstructed_pil = to_pil(reconstructed.squeeze(0).cpu()) + + # Save + reconstructed_pil.save(output_path) + print(f"Saved reconstruction to {output_path}") + + def compare_reconstruction(self, image_path, output_path): + """Save original and reconstruction side by side""" + # Get reconstruction + reconstructed = self.reconstruct_image(image_path) + + # Load original at same resolution + original = Image.open(image_path).convert('RGB') + original = transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(256), + transforms.ToTensor() + ])(original).unsqueeze(0) + + # Concatenate side by side + comparison = torch.cat([original, reconstructed.cpu()], dim=3) + + # Save + to_pil = transforms.ToPILImage() + comparison_pil = to_pil(comparison.squeeze(0)) + comparison_pil.save(output_path) + print(f"Saved comparison to {output_path}") + + def batch_reconstruct(self, image_folder, output_folder, max_images=None): + """Reconstruct all images in a folder""" + image_folder = Path(image_folder) + output_folder = Path(output_folder) + output_folder.mkdir(exist_ok=True, parents=True) + + # Get all images + image_paths = list(image_folder.glob('*.png')) + \ + list(image_folder.glob('*.jpg')) + \ + list(image_folder.glob('*.jpeg')) + + if max_images: + image_paths = image_paths[:max_images] + + print(f"Processing {len(image_paths)} images...") + + for img_path in image_paths: + output_path = output_folder / f"recon_{img_path.name}" + self.save_reconstruction(str(img_path), str(output_path)) + + print("Batch reconstruction complete!") + +def main(): + parser = argparse.ArgumentParser(description='DiTo Image Reconstruction') + parser.add_argument('--checkpoint', type=str, required=True, + help='Path to DiTo checkpoint') + parser.add_argument('--input', type=str, required=True, + help='Input image path or folder') + parser.add_argument('--output', type=str, required=True, + help='Output path') + parser.add_argument('--compare', action='store_true', + help='Save comparison with original') + parser.add_argument('--batch', 
action='store_true',
+                        help='Process entire folder')
+    parser.add_argument('--device', type=str, default='cuda',
+                        help='Device to use (cuda/cpu)')
+    parser.add_argument('--max_images', type=int, default=None,
+                        help='Maximum images to process in batch mode')
+
+    args = parser.parse_args()
+
+    # Initialize model
+    dito = DiToInference(args.checkpoint, device=args.device)
+
+    # Process based on mode
+    if args.batch:
+        dito.batch_reconstruct(args.input, args.output, args.max_images)
+    elif args.compare:
+        dito.compare_reconstruction(args.input, args.output)
+    else:
+        dito.save_reconstruction(args.input, args.output)
+
+# Example usage function for direct Python use
+def reconstruct_single_image(checkpoint_path, image_path, output_path):
+    """Simple function to reconstruct a single image"""
+    dito = DiToInference(checkpoint_path)
+    dito.save_reconstruction(image_path, output_path)
+
+if __name__ == "__main__":
+    main()
+
+# Usage examples:
+# 1. Single image reconstruction:
+#    python reconstruction.py --checkpoint ckpt-best.pth --input image.jpg --output recon.jpg
+#
+# 2. Single image with comparison:
+#    python reconstruction.py --checkpoint ckpt-best.pth --input image.jpg --output compare.jpg --compare
+#
+# 3. Batch processing:
+#    python reconstruction.py --checkpoint ckpt-best.pth --input input_folder/ --output output_folder/ --batch
+#
+# 4. Direct Python usage:
+#    reconstruct_single_image('ckpt-best.pth', 'input.jpg', 'output.jpg')
\ No newline at end of file
diff --git a/flowae/requirements.txt b/flowae/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..531d37316f2a1a85065157d3a1d91df624a07c96
--- /dev/null
+++ b/flowae/requirements.txt
@@ -0,0 +1,9 @@
+torch==2.3.0
+torchvision==0.18.0
+torch_fidelity==0.3.0
+omegaconf
+pyyaml
+wandb
+webdataset
+timm
+einops
\ No newline at end of file
diff --git a/flowae/run.py b/flowae/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..9387ef27d776ab0be9850aa23fd2fa912f5bd583
--- /dev/null
+++ b/flowae/run.py
@@ -0,0 +1,59 @@
+import argparse
+import os
+
+from omegaconf import OmegaConf
+
+from trainers import trainers_dict
+
+
+def make_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', default='configs/_.yaml')
+    parser.add_argument('--name', '-n', default=None)
+    parser.add_argument('--tag', '-t', default=None)
+    parser.add_argument('--resume', '-r', action='store_true')
+    parser.add_argument('--force-replace', '-f', action='store_true')
+    parser.add_argument('--wandb', '-w', action='store_true')
+    parser.add_argument('--save-root', default='save')
+    parser.add_argument('--eval-only', action='store_true')
+    args = parser.parse_args()
+    return args
+
+
+def parse_config(config):
+    if config.get('__base__') is not None:
+        filenames = config.pop('__base__')
+        if isinstance(filenames, str):
+            filenames = [filenames]
+        base_config = OmegaConf.merge(*[
+            parse_config(OmegaConf.load(_))
+            for _ in filenames
+        ])
+        config = OmegaConf.merge(base_config, config)
+    return config
+
+
+def make_env(args):
+    env = dict()
+
+    if args.name is None:
+        exp_name = os.path.splitext(os.path.basename(args.config))[0]
+    else:
+        exp_name = args.name
+    if args.tag is not None:
+        exp_name += '_' + args.tag
+    env['exp_name'] = exp_name
+
+    env['save_dir'] = os.path.join(args.save_root, exp_name)
+    env['wandb'] = args.wandb
+    env['resume'] = args.resume
+    env['force_replace'] = args.force_replace
+    return env
+
+
+if __name__ == '__main__':
+    args = make_args()
+    env = make_env(args)
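+    # Configs may inherit via a top-level `__base__` key (one path or a list
+    # of paths); parse_config() above merges those bases recursively, with
+    # fields in the current file taking precedence. Illustrative YAML
+    # (hypothetical paths):
+    #   __base__: configs/base.yaml
+    #   trainer: <trainer-name>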
config = parse_config(OmegaConf.load(args.config)) + trainer = trainers_dict[config.trainer](env, config) + trainer.run(eval_only=args.eval_only) diff --git a/flowae/utils/__init__.py b/flowae/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16281fe0b66dbac563229823d656ef173736e306 --- /dev/null +++ b/flowae/utils/__init__.py @@ -0,0 +1 @@ +from .utils import * diff --git a/flowae/utils/geometry.py b/flowae/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..28e4c140a45d44f8fb9a3d14dbf58f2531a56695 --- /dev/null +++ b/flowae/utils/geometry.py @@ -0,0 +1,37 @@ +import torch + + +def make_coord_grid(shape, range=(0, 1), device='cpu', batch_size=None): + """ + Args: + shape: (s_1, ..., s_k), grid shape + range: range for each axis, list or tuple, [minv, maxv] or [[minv_1, maxv_1], ..., [minv_k, maxv_k]] + Returns: + (s_1, ..., s_k, k), coordinate grid + """ + p_lst = [] + for i, n in enumerate(shape): + p = (torch.arange(n, device=device) + 0.5) / n + if isinstance(range[0], list) or isinstance(range[0], tuple): + minv, maxv = range[i] + else: + minv, maxv = range + p = minv + (maxv - minv) * p + p_lst.append(p) + coord = torch.stack(torch.meshgrid(*p_lst, indexing='ij'), dim=-1) + + if batch_size is not None: + coord = coord.unsqueeze(0).expand(batch_size, *([-1] * coord.dim())) + return coord + + +def make_coord_scale_grid(shape, range=(0, 1), device='cpu', batch_size=None): + coord = make_coord_grid(shape, range=range, device=device, batch_size=batch_size) + scale = torch.ones_like(coord) + for i, n in enumerate(shape): + if isinstance(range[0], list) or isinstance(range[0], tuple): + minv, maxv = range[i] + else: + minv, maxv = range + scale[..., i] *= (maxv - minv) / n + return coord, scale diff --git a/flowae/utils/utils.py b/flowae/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f5f879c3cb8415249e362e766cff28a9af87daeb --- /dev/null +++ b/flowae/utils/utils.py @@ -0,0 +1,95 @@ +import os +import shutil +import time +import logging + +from torch.optim import SGD, Adam, AdamW + + +def ensure_path(path, replace=True, force_replace=False): + is_temp = os.path.basename(path.rstrip('/')).startswith('_') + if os.path.exists(path): + if replace and (is_temp or force_replace or input(f'{path} exists, replace? y/[n] ') == 'y'): + shutil.rmtree(path) + os.mkdir(path) + else: + os.makedirs(path) + + +def set_logger(file_path): + logger = logging.getLogger() + logger.setLevel('INFO') + stream_handler = logging.StreamHandler() + file_handler = logging.FileHandler(file_path, 'a') + formatter = logging.Formatter('[%(asctime)s] %(message)s', '%m-%d %H:%M:%S') + for handler in [stream_handler, file_handler]: + handler.setFormatter(formatter) + handler.setLevel('INFO') + logger.addHandler(handler) + return logger + + +def compute_num_params(model, text=True): + tot = sum(p.numel() for p in model.parameters()) + if text: + if tot >= 1e6: + s = '{:.1f}M'.format(tot / 1e6) + else: + s = '{:.1f}K'.format(tot / 1e3) + return f'{s} ({tot})' + else: + return tot + + +def make_optimizer(params, optimizer_spec): + optimizer = { + 'sgd': SGD, + 'adam': Adam, + 'adamw': AdamW, + }[optimizer_spec['name']](params, **optimizer_spec['args']) + return optimizer + + +class Averager(): + + def __init__(self, v=None): + if v is None: + self.n = 0. + self.v = 0. + else: + self.n = 1. 
+ self.v = v + + def add(self, v, n=1.0): + self.v = self.v * (self.n / (self.n + n)) + v * (n / (self.n + n)) + self.n += n + + def item(self): + return self.v + + +class EpochTimer(): + + def __init__(self, max_epoch): + self.max_epoch = max_epoch + self.epoch = 0 + self.t_start = time.time() + self.t_last = self.t_start + + def epoch_done(self): + t_cur = time.time() + self.epoch += 1 + epoch_time = t_cur - self.t_last + tot_time = t_cur - self.t_start + est_time = tot_time / self.epoch * self.max_epoch + self.t_last = t_cur + return time_text(epoch_time), time_text(tot_time), time_text(est_time) + + +def time_text(sec): + if sec >= 3600: + return f'{sec / 3600:.1f}h' + elif sec >= 60: + return f'{sec / 60:.1f}m' + else: + return f'{sec:.1f}s'
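
A quick illustration of the cell-centered grids produced by `make_coord_grid` in `flowae/utils/geometry.py` (a sketch, not part of the diff; values written out by hand):

    from utils.geometry import make_coord_grid

    coord = make_coord_grid((2, 2))   # default range (0, 1)
    # coord has shape (2, 2, 2); samples sit at cell centers:
    #   coord[0, 0] == (0.25, 0.25), coord[0, 1] == (0.25, 0.75)
    #   coord[1, 0] == (0.75, 0.25), coord[1, 1] == (0.75, 0.75)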