diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..4d39f7949baafdf858981168c335acbe6bc9186f 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index b471107db32daa7020643e0a22a5ec45312510f2..4bdd60e166cfc0d1f286e0fd9147851178235253 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,156 @@
+
PrismAudio
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ If you find this project useful,
+ a star ⭐ on GitHub would be greatly appreciated!
+
+
---
-title: PrismAudio
-emoji: 📉
-colorFrom: purple
-colorTo: green
-sdk: gradio
-sdk_version: 6.9.0
-app_file: app.py
-pinned: false
-license: apache-2.0
+
+**PrismAudio** is the first framework to integrate Reinforcement Learning into Video-to-Audio (V2A) generation with specialized Chain-of-Thought (CoT) planning. Building upon [ThinkSound](https://arxiv.org/pdf/2506.21448)'s pioneering CoT-based V2A framework, PrismAudio further decomposes monolithic reasoning into four specialized CoT modules (Semantic, Temporal, Aesthetic, and Spatial), each paired with targeted reward functions, enabling multi-dimensional RL optimization that jointly improves reasoning across all perceptual dimensions.
+
+
+
---
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+## 📰 News
+
+- **2026.03.22** 🔥 We have released **PrismAudio**, our next-generation video-to-audio generation model! For more details, please refer to the [`prismaudio`](https://github.com/liuhuadai/ThinkSound/tree/prismaudio) branch!
+- **2026.01.26** 🎉 PrismAudio has been accepted to the **ICLR 2026 Main Conference**! We plan to release the project in February 2026.
+- **2025.11.25** 🔥 [Online PrismAudio Demo](http://prismaudio-project.github.io/) is live - try it now!
+- **2025.11.25** 🔥 [PrismAudio paper](https://arxiv.org/pdf/2511.18833) released on arXiv, the first multi-dimensional CoT-RL framework for Video-to-Audio Generation!
+- **2025.09.19** 🎉 ThinkSound has been accepted to the **NeurIPS 2025 Main Conference**!
+- **2025.09.01** Our AudioCoT dataset is now open-sourced and available on [Hugging Face](https://huggingface.co/datasets/liuhuadai/AudioCoT)!
+- **2025.07.17** 🧠 Finetuning enabled: training and finetuning code is now publicly available, along with clear usage instructions to help you customize and extend ThinkSound with your own data.
+- **2025.07.15** 📦 Simplified installation and usability: dependencies on PyPI for easy cross-platform setup; Windows `.bat` scripts automate environment creation and script running.
+- **2025.07.08** 🔧 Major update: model lightweighted and optimized memory and GPU usage, now supports high-throughput audio generation at scale!
+- **2025.07.01** Online demo on [Hugging Face Spaces](https://huggingface.co/spaces/FunAudioLLM/ThinkSound) and [ModelScope](https://modelscope.cn/studios/iic/ThinkSound) for interactive experience!
+- **2025.07.01** Released inference scripts and web interface.
+- **2025.06** [ThinkSound paper](https://arxiv.org/pdf/2506.21448) released on arXiv!
+- **2025.06** [Online Demo](http://thinksound-project.github.io/) is live - try it now!
+
+---
+
+## 🚀 Features
+
+- **V2A SOTA**: Achieves state-of-the-art results across all four perceptual dimensions on both VGGSound and AudioCanvas benchmarks.
+- **Decomposed CoT Reasoning**: Four specialized CoT modules (Semantic, Temporal, Aesthetic, Spatial) each providing focused, interpretable reasoning for its corresponding perceptual dimension.
+- **Multi-dimensional RL**: Fast-GRPO enables efficient multi-dimensional reward optimization without compromising generation quality.
+- **New Benchmark**: AudioCanvas — a rigorous V2A benchmark with 300 single-event classes and 501 multi-event samples covering diverse and challenging scenarios.
+- **Efficient**: 518M parameters with faster inference than prior SOTAs.
+
+---
+
+## ✨ Method Overview
+
+PrismAudio consists of three main components:
+
+1. **CoT-Aware Audio Foundation Model**: Built on a Multimodal Diffusion Transformer with flow matching, enhanced with VideoPrism for video understanding and T5-Gemma for structured CoT text encoding.
+2. **Decomposed Multi-Dimensional CoT Reasoning**: Four specialized CoT modules — Semantic, Temporal, Aesthetic, and Spatial — each providing targeted reasoning for its corresponding perceptual dimension.
+3. **Fast-GRPO Multi-Dimensional RL Framework**: A hybrid ODE-SDE sampling strategy that dramatically reduces training overhead while enabling multi-dimensional reward optimization across all perceptual dimensions.
+
+
+---
+
+## ⚡ Quick Start
+
+```bash
+git clone -b prismaudio https://github.com/liuhuadai/ThinkSound.git
+cd ThinkSound
+
+conda create -n prismaudio python=3.10
+conda activate prismaudio
+chmod +x scripts/PrismAudio/setup/build_env.sh
+./scripts/PrismAudio/setup/build_env.sh
+
+# Download pretrained weights to Directory ckpts/
+# From Hugging Face: https://huggingface.co/liuhuadai/ThinkSound
+# From ModelScope: https://www.modelscope.cn/models/iic/ThinkSound
+git lfs install
+git clone https://huggingface.co/liuhuadai/ThinkSound ckpts
+```
+
+---
+
+## ▶️ Run Demo
+
+```bash
+chmod +x scripts/PrismAudio/demo.sh
+./scripts/PrismAudio/demo.sh ""
+```
+
+**Note:**
+- ``: Path to a single input video file.
+- `""`: A structured CoT description of the audio to generate.
+
+---
+
+## 🏋️ Train the Model
+
+See [`Training.md`](docs/PrismAudio/Training.md)
+
+---
+
+## 📄 License
+
+This project is released under the Apache 2.0 License.
+
+> **Note:**
+> The code, models, and dataset are **for research and educational purposes only**.
+> **Commercial use is NOT permitted.**
+> For commercial licensing, please contact the authors.
+
+**📦 Third-Party Components**
+
+- **Stable Audio Open VAE** (by Stability AI): Licensed under the [Stability AI Community License](./third_party/LICENSE_StabilityAI.md). **Commercial use and redistribution require prior permission from Stability AI.**
+- 📘 **All other code and models** are released under the Apache License 2.0.
+
+---
+
+## Acknowledgements
+
+Many thanks to:
+
+- **stable-audio-tools** (by Stability AI): For providing an easy-to-use framework for audio generation, as well as the VAE module and weights.
+- **MMAudio**: For the implementation of the MM-DiT backbone in the audio domain.
+- **ThinkSound**: For the foundational CoT-based V2A generation framework that PrismAudio builds upon.
+
+---
+
+## 📖 Citation
+
+If you find PrismAudio useful in your research or work, please cite our paper:
+
+```bibtex
+@misc{liu2025prismaudiodecomposedchainofthoughtsmultidimensional,
+ title={PrismAudio: Decomposed Chain-of-Thoughts and Multi-dimensional Rewards for Video-to-Audio Generation},
+ author={Huadai Liu and Kaicheng Luo and Wen Wang and Qian Chen and Peiwen Sun and Rongjie Huang and Xiangang Li and Jieping Ye and Wei Xue},
+ year={2025},
+ eprint={2511.18833},
+ archivePrefix={arXiv},
+ primaryClass={cs.SD},
+ url={https://arxiv.org/abs/2511.18833},
+ }
+```
+
+---
+
+## 📬 Contact
+
+✨ Feel free to [open an issue](https://github.com/liuhuadai/ThinkSound/issues) or contact us via email ([huadai.liu@connect.ust.hk](mailto:huadai.liu@connect.ust.hk)) if you have any questions or suggestions!
diff --git a/ThinkSound/__init__.py b/ThinkSound/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bdaa12ea2399cacdbf5859206911173c8ce0c0c
--- /dev/null
+++ b/ThinkSound/__init__.py
@@ -0,0 +1 @@
+from .models.factory import create_model_from_config, create_model_from_config_path
\ No newline at end of file
diff --git a/ThinkSound/configs/model_configs/prismaudio.json b/ThinkSound/configs/model_configs/prismaudio.json
new file mode 100644
index 0000000000000000000000000000000000000000..19d24a09a76e9656cf221c7f1743fef3a7dd69dd
--- /dev/null
+++ b/ThinkSound/configs/model_configs/prismaudio.json
@@ -0,0 +1,141 @@
+{
+ "model_type": "diffusion_cond",
+ "sample_size": 397312,
+ "sample_rate": 44100,
+ "audio_channels": 2,
+ "model": {
+ "pretransform": {
+ "type": "autoencoder",
+ "iterate_batch": true,
+ "config": {
+ "encoder": {
+ "type": "oobleck",
+ "config": {
+ "in_channels": 2,
+ "channels": 128,
+ "c_mults": [1, 2, 4, 8, 16],
+ "strides": [2, 4, 4, 8, 8],
+ "latent_dim": 128,
+ "use_snake": true
+ }
+ },
+ "decoder": {
+ "type": "oobleck",
+ "config": {
+ "out_channels": 2,
+ "channels": 128,
+ "c_mults": [1, 2, 4, 8, 16],
+ "strides": [2, 4, 4, 8, 8],
+ "latent_dim": 64,
+ "use_snake": true,
+ "final_tanh": false
+ }
+ },
+ "bottleneck": {
+ "type": "vae"
+ },
+ "latent_dim": 64,
+ "downsampling_ratio": 2048,
+ "io_channels": 2
+ }
+ },
+ "conditioning": {
+ "configs": [
+ {
+ "id": "video_features",
+ "type": "cond_mlp",
+ "config": {
+ "dim": 1024,
+ "output_dim": 1024
+ }
+ },
+ {
+ "id": "text_features",
+ "type": "cond_mlp",
+ "config": {
+ "dim": 1024,
+ "output_dim": 1024
+ }
+ },
+ {
+ "id": "sync_features",
+ "type": "sync_mlp",
+ "config": {
+ "dim": 768,
+ "output_dim": 1024
+ }
+ }
+ ],
+ "cond_dim": 768
+ },
+ "diffusion": {
+ "cross_attention_cond_ids": ["video_features","text_features"],
+ "add_cond_ids": ["video_features"],
+ "sync_cond_ids": ["sync_features"],
+ "type": "dit",
+ "diffusion_objective": "rectified_flow",
+ "config": {
+ "io_channels": 64,
+ "embed_dim": 1024,
+ "depth": 24,
+ "num_heads": 16,
+ "cond_token_dim": 1024,
+ "add_token_dim": 1024,
+ "sync_token_dim": 1024,
+ "project_cond_tokens": false,
+ "transformer_type": "continuous_transformer",
+ "attn_kwargs":{
+ "qk_norm": "rns"
+ },
+ "use_gated": true,
+ "use_sync_gated": true
+ }
+ },
+ "io_channels": 64
+ },
+ "training": {
+ "use_ema": true,
+ "log_loss_info": false,
+ "cfg_dropout_prob": 0.1,
+ "pre_encoded": true,
+ "timestep_sampler": "trunc_logit_normal",
+ "optimizer_configs": {
+ "diffusion": {
+ "optimizer": {
+ "type": "AdamW",
+ "config": {
+ "lr": 1e-4,
+ "betas": [0.9, 0.999],
+ "weight_decay": 1e-3
+ }
+ },
+ "scheduler": {
+ "type": "InverseLR",
+ "config": {
+ "inv_gamma": 100000,
+ "power": 0.5,
+ "warmup": 0.99
+ }
+ }
+ }
+ },
+ "demo": {
+ "demo_every": 5000,
+ "demo_steps": 24,
+ "num_demos": 10,
+ "demo_cond": [
+ "dataset/videoprism/test/0Cu33yBwAPg_000060.npz",
+ "dataset/videoprism/test/bmKtI808DsU_000009.npz",
+ "dataset/videoprism/test/VC0c22cJTbM_000424.npz",
+ "dataset/videoprism/test/F3gsbUTdc2U_000090.npz",
+ "dataset/videoprism/test/WatvT8A8iug_000100.npz",
+ "dataset/videoprism/test/0nvBTp-q7tU_000112.npz",
+ "dataset/videoprism/test/3-PFuDkTM48_000080.npz",
+ "dataset/videoprism/test/luSAuu-BoPs_000232.npz",
+ "dataset/videoprism/test/__8UJxW0aOQ_000002.npz",
+ "dataset/videoprism/test/_0m_YMpQayA_000168.npz"
+ ],
+ "demo_cfg_scales": [5]
+ }
+ }
+}
\ No newline at end of file
diff --git a/ThinkSound/configs/model_configs/stable_audio_2_0_vae.json b/ThinkSound/configs/model_configs/stable_audio_2_0_vae.json
new file mode 100644
index 0000000000000000000000000000000000000000..95f72495502f5b0a725378a0b9f51a56bb9da910
--- /dev/null
+++ b/ThinkSound/configs/model_configs/stable_audio_2_0_vae.json
@@ -0,0 +1,122 @@
+{
+ "model_type": "autoencoder",
+ "sample_size": 65536,
+ "sample_rate": 44100,
+ "audio_channels": 2,
+ "model": {
+ "encoder": {
+ "type": "oobleck",
+ "config": {
+ "in_channels": 2,
+ "channels": 128,
+ "c_mults": [1, 2, 4, 8, 16],
+ "strides": [2, 4, 4, 8, 8],
+ "latent_dim": 128,
+ "use_snake": true
+ }
+ },
+ "decoder": {
+ "type": "oobleck",
+ "config": {
+ "out_channels": 2,
+ "channels": 128,
+ "c_mults": [1, 2, 4, 8, 16],
+ "strides": [2, 4, 4, 8, 8],
+ "latent_dim": 64,
+ "use_snake": true,
+ "final_tanh": false
+ }
+ },
+ "bottleneck": {
+ "type": "vae"
+ },
+ "latent_dim": 64,
+ "downsampling_ratio": 2048,
+ "io_channels": 2
+ },
+ "training": {
+ "learning_rate": 1.5e-4,
+ "warmup_steps": 0,
+ "use_ema": true,
+ "optimizer_configs": {
+ "autoencoder": {
+ "optimizer": {
+ "type": "AdamW",
+ "config": {
+ "betas": [0.8, 0.99],
+ "lr": 1.5e-4,
+ "weight_decay": 1e-3
+ }
+ },
+ "scheduler": {
+ "type": "InverseLR",
+ "config": {
+ "inv_gamma": 200000,
+ "power": 0.5,
+ "warmup": 0.999
+ }
+ }
+ },
+ "discriminator": {
+ "optimizer": {
+ "type": "AdamW",
+ "config": {
+ "betas": [0.8, 0.99],
+ "lr": 3e-4,
+ "weight_decay": 1e-3
+ }
+ },
+ "scheduler": {
+ "type": "InverseLR",
+ "config": {
+ "inv_gamma": 200000,
+ "power": 0.5,
+ "warmup": 0.999
+ }
+ }
+ }
+ },
+ "loss_configs": {
+ "discriminator": {
+ "type": "encodec",
+ "config": {
+ "filters": 64,
+ "n_ffts": [2048, 1024, 512, 256, 128],
+ "hop_lengths": [512, 256, 128, 64, 32],
+ "win_lengths": [2048, 1024, 512, 256, 128]
+ },
+ "weights": {
+ "adversarial": 0.1,
+ "feature_matching": 5.0
+ }
+ },
+ "spectral": {
+ "type": "mrstft",
+ "config": {
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
+ "perceptual_weighting": true
+ },
+ "weights": {
+ "mrstft": 1.0
+ }
+ },
+ "time": {
+ "type": "l1",
+ "weights": {
+ "l1": 0.0
+ }
+ },
+ "bottleneck": {
+ "type": "kl",
+ "weights": {
+ "kl": 1e-4
+ }
+ }
+ },
+ "demo": {
+ "demo_every": 10000
+ }
+ }
+}
\ No newline at end of file
diff --git a/ThinkSound/configs/model_configs/thinksound.json b/ThinkSound/configs/model_configs/thinksound.json
new file mode 100644
index 0000000000000000000000000000000000000000..1458b0d43350e928b9c68cc8619242b8ff8f87c1
--- /dev/null
+++ b/ThinkSound/configs/model_configs/thinksound.json
@@ -0,0 +1,147 @@
+{
+ "model_type": "mm_diffusion_cond",
+ "sample_size": 397312,
+ "sample_rate": 44100,
+ "audio_channels": 2,
+ "model": {
+ "pretransform": {
+ "type": "autoencoder",
+ "iterate_batch": true,
+ "config": {
+ "encoder": {
+ "type": "oobleck",
+ "config": {
+ "in_channels": 2,
+ "channels": 128,
+ "c_mults": [1, 2, 4, 8, 16],
+ "strides": [2, 4, 4, 8, 8],
+ "latent_dim": 128,
+ "use_snake": true
+ }
+ },
+ "decoder": {
+ "type": "oobleck",
+ "config": {
+ "out_channels": 2,
+ "channels": 128,
+ "c_mults": [1, 2, 4, 8, 16],
+ "strides": [2, 4, 4, 8, 8],
+ "latent_dim": 64,
+ "use_snake": true,
+ "final_tanh": false
+ }
+ },
+ "bottleneck": {
+ "type": "vae"
+ },
+ "latent_dim": 64,
+ "downsampling_ratio": 2048,
+ "io_channels": 2
+ }
+ },
+ "conditioning": {
+ "configs": [
+ {
+ "id": "metaclip_features",
+ "type": "mm_unchang",
+ "config": {
+ "dim": 1024,
+ "output_dim": 1024
+ }
+ },
+ {
+ "id": "metaclip_text_features",
+ "type": "mm_unchang",
+ "config": {
+ "dim": 1024,
+ "output_dim": 1024
+ }
+ },
+ {
+ "id": "sync_features",
+ "type": "mm_unchang",
+ "config": {
+ "dim": 768,
+ "output_dim": 768
+ }
+ },
+ {
+ "id": "t5_features",
+ "type": "mm_unchang",
+ "config": {
+ "dim": 2048,
+ "output_dim": 2048
+ }
+ }
+ ],
+ "cond_dim": 768
+ },
+ "diffusion": {
+ "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"],
+ "type": "mmdit",
+ "diffusion_objective": "rectified_flow",
+ "config": {
+ "latent_dim":64,
+ "clip_dim":1024,
+ "sync_dim":768,
+ "text_dim":2048,
+ "hidden_dim":1024,
+ "depth":21,
+ "fused_depth":14,
+ "num_heads":16,
+ "latent_seq_len":194,
+ "clip_seq_len":72,
+ "sync_seq_len":216,
+ "v2": true,
+ "kernel_size": 3
+ }
+ },
+ "io_channels": 64
+ },
+ "training": {
+ "use_ema": true,
+ "log_loss_info": false,
+ "cfg_dropout_prob": 0.2,
+ "pre_encoded": true,
+ "timestep_sampler": "logit_normal",
+ "optimizer_configs": {
+ "diffusion": {
+ "optimizer": {
+ "type": "AdamW",
+ "config": {
+ "lr": 5e-5,
+ "betas": [0.9, 0.95],
+ "weight_decay": 1e-4,
+ "eps": 1e-6
+ }
+ },
+ "scheduler": {
+ "type": "InverseLR",
+ "config": {
+ "inv_gamma": 1000000,
+ "power": 0.5,
+ "warmup": 0.99
+ }
+ }
+ }
+ },
+ "demo": {
+ "demo_every": 5000,
+ "demo_steps": 24,
+ "num_demos": 10,
+ "demo_cond": [
+ "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz",
+ "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz",
+ "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz",
+ "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz",
+ "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz",
+ "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz",
+ "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz",
+ "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz",
+ "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz",
+ "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz"
+ ],
+ "demo_cfg_scales": [5]
+ }
+ }
+}
\ No newline at end of file
diff --git a/ThinkSound/configs/multimodal_dataset_demo.json b/ThinkSound/configs/multimodal_dataset_demo.json
new file mode 100644
index 0000000000000000000000000000000000000000..b00baa073eba5bd531eb723731ad6ba39d5019e9
--- /dev/null
+++ b/ThinkSound/configs/multimodal_dataset_demo.json
@@ -0,0 +1,52 @@
+{
+ "dataset_type": "multimodal_dir",
+ "video_datasets": [
+ {
+ "id": "vggsound",
+ "path": "dataset/vggsound/video_latents_t5_clip_npz/train",
+ "split_path": "dataset/vggsound/split_txt/train_cot.txt"
+ }
+ ],
+ "audio_datasets": [
+ {
+ "id": "audiostock",
+ "path": "dataset/Laion-Audio-630k/audiostock_latents_npz",
+ "split_path": "dataset/Laion-Audio-630k/split_txt/cot_audiostock_1.txt"
+ },
+ {
+ "id": "freesound_no_overlap",
+ "path": "dataset/Laion-Audio-630k/freesound_no_overlap_latents_npz",
+ "split_path": "dataset/Laion-Audio-630k/split_txt/cot_freesound.txt"
+ },
+ {
+ "id": "audioset_sl",
+ "path": "dataset/wavcaps/audioset_sl_latents_npz",
+ "split_path": "dataset/wavcaps/split_txt/cot_audio_sl_1.txt"
+ },
+ {
+ "id": "audiocaps",
+ "path": "dataset/1_audiocaps/audiocaps_latents_npz",
+ "split_path": "dataset/1_audiocaps/split_txt/train_cot.txt"
+ },
+ {
+ "id": "bbc",
+ "path": "dataset/Laion-Audio-630k/bbc_latents_npz",
+ "split_path": "dataset/Laion-Audio-630k/split_txt/cot_bbc_1.txt"
+ }
+ ],
+ "val_datasets": [
+ {
+ "id": "vggsound",
+ "path": "dataset/vggsound/video_latents_t5_clip_npz/test",
+ "split_path": "dataset/vggsound/split_txt/test_cot.txt"
+ }
+ ],
+ "test_datasets": [
+ {
+ "id": "vggsound",
+ "path": "cot_coarse"
+ }
+ ],
+ "random_crop": true,
+ "input_type": "prompt"
+}
\ No newline at end of file
diff --git a/ThinkSound/configs/multimodal_dataset_demo_prismaudio.json b/ThinkSound/configs/multimodal_dataset_demo_prismaudio.json
new file mode 100644
index 0000000000000000000000000000000000000000..8a3e2320175dd91bd176d4bded1847a46e865e62
--- /dev/null
+++ b/ThinkSound/configs/multimodal_dataset_demo_prismaudio.json
@@ -0,0 +1,27 @@
+{
+ "dataset_type": "video_dataset",
+ "datasets": [
+ {
+ "id": "vggsound",
+ "path": "test",
+ "split_path": "test/test.txt"
+ }
+ ],
+ "val_datasets": [
+ {
+ "id": "vggsound",
+ "path": "test",
+ "split_path": "test/test.txt"
+ }
+ ],
+ "test_datasets": [
+ {
+ "id": "vggsound",
+ "path": "test",
+ "split_path": "test/test.txt"
+ }
+ ],
+ "random_crop": false,
+ "input_type": "video",
+ "fps": 8
+}
\ No newline at end of file
diff --git a/ThinkSound/data/__init__.py b/ThinkSound/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ThinkSound/data/datamodule.py b/ThinkSound/data/datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7503c673511941d7e322fd0b9168c187a121a14
--- /dev/null
+++ b/ThinkSound/data/datamodule.py
@@ -0,0 +1,331 @@
+import lightning as L
+from .dataset import LatentDataset, SampleDataset, VideoDataset, AudioDataset, MultiModalDataset, LocalDatasetConfig, collation_fn
+import importlib
+import torch.distributed as dist
+from torch.utils.data import Dataset
+from torch.utils.data import DataLoader,IterableDataset
+import torch
+from itertools import cycle
+
+class AlternatingLoader(IterableDataset):
+ """
+ 一个可迭代的数据集,它包装了两个数据加载器,并按顺序轮流从它们中产出批次。
+ 它会持续进行直到两个加载器都耗尽。
+
+ Args:
+ loader1 (DataLoader): 第一个数据加载器。
+ loader2 (DataLoader): 第二个数据加载器。
+ loader1_name (str): 第一个加载器的名称 (例如 'video')。
+ loader2_name (str): 第二个加载器的名称 (例如 'audio')。
+ """
+ def __init__(self, loader1, loader2, loader1_name='video', loader2_name='audio'):
+ super().__init__()
+ self.loader1 = loader1
+ self.loader2 = loader2
+ self.loader1_name = loader1_name
+ self.loader2_name = loader2_name
+ self.max_len = max(len(loader1), len(loader2))
+
+ def __iter__(self):
+ # 获取 DDP 信息
+ try:
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ except (RuntimeError, ValueError):
+ # 如果不在分布式环境中,则默认为单进程
+ world_size = 1
+ rank = 0
+
+ # 创建两个无限循环迭代器
+ iter1 = cycle(self.loader1)
+ iter2 = cycle(self.loader2)
+
+ # 核心修改:只 yield 属于当前 rank 的数据
+ # 我们将总的交替流想象成一个大列表,然后对其进行切分
+ # 交替流: [v1, a1, v2, a2, v3, a3, ...]
+
+ # 每个 for 循环迭代产生 2 个 batch (1 个 video, 1 个 audio)
+ # 总共会产生 2 * self.max_len 个 batch
+
+ # for 循环负责驱动迭代
+ for i in range(self.max_len):
+ # 获取下一个 video batch
+ v_batch = next(iter1)
+ # 获取下一个 audio batch
+ a_batch = next(iter2)
+
+ # 这是一个交替对,我们根据索引 i 来决定哪个进程处理它
+ if i % world_size == rank:
+ # 只有当轮次索引 i 属于当前 rank 时,才 yield 数据
+ yield v_batch
+ yield a_batch
+
+ def __len__(self):
+ # 在 DDP 环境下,__len__ 应该返回单个进程处理的 batch 数量
+ # 以便 Lightning 正确显示进度条
+
+ try:
+ world_size = dist.get_world_size()
+ except (RuntimeError, ValueError):
+ world_size = 1
+
+ # 每个进程大致处理 1/world_size 的数据对
+ # 每个数据对包含 2 个 batch
+ num_pairs_per_process = self.max_len // world_size
+
+ # 如果总数不能整除,最后一个 rank 会多处理一些
+ # 为简化起见,我们通常可以用 ceil 来计算
+ # (self.max_len + world_size - 1) // world_size 是一种高效的 ceil 写法
+ num_pairs_per_process = (self.max_len + world_size - 1) // world_size
+
+ return 2 * num_pairs_per_process
+def get_configs(audio_configs):
+ configs = []
+ for config in audio_configs:
+ data_dir_path = config.get("path", None)
+ audio_dir_path = config.get("audio_dir", None)
+ split_path = config.get("split_path", None)
+ assert data_dir_path is not None, "Path must be set for local audio directory configuration"
+
+ custom_metadata_fn = None
+ custom_metadata_module_path = config.get("custom_metadata_module", None)
+
+ if custom_metadata_module_path:
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
+ metadata_module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(metadata_module)
+ custom_metadata_fn = metadata_module.get_custom_metadata
+
+ configs.append(
+ LocalDatasetConfig(
+ id=config["id"],
+ path=data_dir_path,
+ split_path=split_path,
+ custom_metadata_fn=custom_metadata_fn,
+ audio_dir=audio_dir_path
+ )
+ )
+ return configs
+
+class DataModule(L.LightningDataModule):
+ def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):
+ super().__init__()
+ dataset_type = dataset_config.get("dataset_type", None)
+ repeat_num = dataset_config.get("repeat_num", 1)
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.test_batch_size = test_batch_size
+ self.repeat_num = repeat_num
+ assert dataset_type is not None, "Dataset type must be specified in dataset config"
+
+ if audio_channels == 1:
+ force_channels = "mono"
+ elif audio_channels == 2:
+ force_channels = "stereo"
+ else:
+ force_channels = "foa"
+ val_dir_configs = dataset_config.get("val_datasets", None)
+ test_dir_configs = dataset_config.get("test_datasets", None)
+ configs = []
+ val_configs = []
+ test_configs = []
+ if dataset_type == "audio_dir":
+ audio_dir_configs = dataset_config.get("datasets", None)
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
+ configs = get_configs(audio_dir_configs)
+ val_configs = get_configs(val_dir_configs)
+ test_configs = get_configs(test_dir_configs)
+ elif dataset_type == "latent_dir" or dataset_type == "video_dataset" or dataset_type == "audio_dataset":
+ audio_dir_configs = dataset_config.get("datasets", None)
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
+ for i, dataset in enumerate((audio_dir_configs, val_dir_configs, test_dir_configs)):
+ for config in dataset:
+ data_dir_path = config.get("path", None)
+ audio_dir_path = config.get("audio_dir", None)
+ split_path = config.get("split_path", None)
+ assert data_dir_path is not None, "Path must be set for local audio directory configuration"
+
+ content = LocalDatasetConfig(
+ id=config["id"],
+ path=data_dir_path,
+ split_path=split_path,
+ audio_dir=audio_dir_path
+ )
+ if i == 0:
+ configs.append(content)
+ elif i == 1:
+ val_configs.append(content)
+ else:
+ test_configs.append(content)
+ elif dataset_type in ["multimodal_dir", "alternating_multimodal"]:
+ print('##########################')
+ print(f'repeat num is: {self.repeat_num}')
+ self.audio_configs = []
+ self.video_configs = []
+ audio_dir_configs = dataset_config.get("audio_datasets", None)
+ video_dir_configs = dataset_config.get("video_datasets", None)
+ assert audio_dir_configs is not None and video_dir_configs is not None, "Directory configuration must be specified in video_datasets and audio_datasets"
+ for i, dataset in enumerate((audio_dir_configs, video_dir_configs, val_dir_configs, test_dir_configs)):
+ for config in dataset:
+ data_dir_path = config.get("path", None)
+ audio_dir_path = config.get("audio_dir", None)
+ split_path = config.get("split_path", None)
+ assert data_dir_path is not None, "Path must be set for local audio directory configuration"
+
+ content = LocalDatasetConfig(
+ id=config["id"],
+ path=data_dir_path,
+ split_path=split_path,
+ audio_dir=audio_dir_path
+ )
+ if i == 0:
+ self.audio_configs.append(content)
+ elif i == 1:
+ self.video_configs.append(content)
+ elif i == 2:
+ val_configs.append(content)
+ else:
+ test_configs.append(content)
+ self.dataset_type = dataset_type
+ self.configs = configs
+ self.val_configs = val_configs
+ self.test_configs = test_configs
+ self.sample_rate = sample_rate
+ self.sample_size = sample_size
+ self.random_crop = dataset_config.get("random_crop", True)
+ self.input_type = dataset_config.get("input_type", "video")
+ self.fps = dataset_config.get("fps", 4)
+ self.force_channels = force_channels
+
+
+ def setup(self, stage: str):
+ if self.dataset_type == 'audio_dir':
+ dataset_class = SampleDataset
+ elif self.dataset_type == 'latent_dir':
+ dataset_class = LatentDataset
+ elif self.dataset_type == 'video_dataset':
+ dataset_class = VideoDataset
+ elif self.dataset_type == 'audio_dataset':
+ dataset_class = AudioDataset
+ elif self.dataset_type in ["multimodal_dir", "alternating_multimodal"]:
+ dataset_class = VideoDataset
+
+ def create_dataset(configs, random_crop):
+ return dataset_class(
+ configs,
+ sample_rate=self.sample_rate,
+ sample_size=self.sample_size,
+ random_crop=random_crop,
+ input_type=self.input_type,
+ fps=self.input_type,
+ force_channels=self.force_channels
+ )
+
+ if stage == 'fit':
+ if self.dataset_type not in ["multimodal_dir", "alternating_multimodal"]:
+ self.train_set = create_dataset(self.configs, random_crop=self.random_crop)
+ elif self.dataset_type == "multimodal_dir":
+ self.video_set = VideoDataset(
+ self.video_configs,
+ sample_rate=self.sample_rate,
+ sample_size=self.sample_size,
+ random_crop=self.random_crop,
+ input_type=self.input_type,
+ fps=self.input_type,
+ force_channels=self.force_channels
+ )
+ self.audio_set = AudioDataset(
+ self.audio_configs,
+ sample_rate=self.sample_rate,
+ sample_size=self.sample_size,
+ random_crop=self.random_crop,
+ input_type=self.input_type,
+ fps=self.input_type,
+ force_channels=self.force_channels
+ )
+ self.train_set = MultiModalDataset([self.video_set]*self.repeat_num, [self.audio_set])
+ elif self.dataset_type == "alternating_multimodal":
+ self.video_set = VideoDataset(
+ self.video_configs,
+ sample_rate=self.sample_rate,
+ sample_size=self.sample_size,
+ random_crop=self.random_crop,
+ input_type=self.input_type,
+ fps=self.input_type,
+ force_channels=self.force_channels
+ )
+ self.audio_set = AudioDataset(
+ self.audio_configs,
+ sample_rate=self.sample_rate,
+ sample_size=self.sample_size,
+ random_crop=self.random_crop,
+ input_type=self.input_type,
+ fps=self.input_type,
+ force_channels=self.force_channels
+ )
+ self.val_set = create_dataset(self.val_configs, random_crop=False)
+ elif stage == 'validate':
+ self.val_set = create_dataset(self.val_configs, random_crop=False)
+ elif stage == 'predict':
+ self.test_set = create_dataset(self.test_configs, random_crop=False)
+
+
+
+ def train_dataloader(self):
+ if self.dataset_type == "alternating_multimodal":
+ # 视频 DataLoader
+ video_loader = DataLoader(
+ self.video_set,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ drop_last=True,
+ collate_fn=collation_fn
+ )
+
+ # 音频 DataLoader
+ audio_loader = DataLoader(
+ self.audio_set,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ drop_last=True,
+ collate_fn=collation_fn
+ )
+ alternating_loader = AlternatingLoader(
+ video_loader,
+ audio_loader,
+ loader1_name='video',
+ loader2_name='audio'
+ )
+ return DataLoader(alternating_loader, batch_size=None, num_workers=0)
+ else:
+ # 如果不是 alternating_multimodal,保持现有逻辑(仅用于兼容性)
+ return DataLoader(
+ self.train_set,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ persistent_workers=True,
+ pin_memory=True,
+ drop_last=True,
+ collate_fn=collation_fn
+ )
+
+
+ def val_dataloader(self):
+ return DataLoader(self.val_set, self.batch_size, shuffle=False,
+ num_workers=self.num_workers, persistent_workers=False, pin_memory=False, drop_last=False, collate_fn=collation_fn)
+
+ def predict_dataloader(self):
+ return DataLoader(self.test_set, batch_size=self.test_batch_size, shuffle=False,
+ num_workers=self.num_workers, persistent_workers=False, pin_memory=False, drop_last=False, collate_fn=collation_fn)
+
+ # def predict_dataloader(self):
+ # return DataLoader(self.mnist_predict, batch_size=self.batch_size)
+
+ # def teardown(self, stage: str):
+ # # Used to clean-up when the run is finished
+ # ...
\ No newline at end of file
diff --git a/ThinkSound/data/dataset.py b/ThinkSound/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ac3546e73da58cfe7d05dd46cbdf997d168f00a
--- /dev/null
+++ b/ThinkSound/data/dataset.py
@@ -0,0 +1,1319 @@
+import importlib
+import numpy as np
+import io
+import os
+import posixpath
+import random
+import re
+import subprocess
+import time
+import torch
+import torchaudio
+import webdataset as wds
+import pandas as pd
+from aeiou.core import is_silence
+from os import path
+from pathlib import Path
+from pedalboard.io import AudioFile
+from torchaudio import transforms as T
+from typing import Optional, Callable, List
+import bisect
+
+from .utils import FOA, Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, PadCrop_Video_Normalized_T, PadCrop_Video_Hiera_Normalized_T, PadCrop_Video_Image_Normalized_T, PadCrop_DualVideo_Normalized_T
+
+AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
+
+# fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
+
+def fast_scandir(
+ dir:str, # top-level directory at which to begin scanning
+ ext:list, # list of allowed file extensions,
+ #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB
+ ):
+ "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
+ subfolders, files = [], []
+ ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed
+ try: # hope to avoid 'permission denied' by this try
+ for f in os.scandir(dir):
+ try: # 'hope to avoid too many levels of symbolic links' error
+ if f.is_dir():
+ subfolders.append(f.path)
+ elif f.is_file():
+ file_ext = os.path.splitext(f.name)[1].lower()
+ is_hidden = os.path.basename(f.path).startswith(".")
+
+ if file_ext in ext and not is_hidden:
+ files.append(f.path)
+ except:
+ pass
+ except:
+ pass
+
+ for dir in list(subfolders):
+ sf, f = fast_scandir(dir, ext)
+ subfolders.extend(sf)
+ files.extend(f)
+ return subfolders, files
+
+def keyword_scandir(
+ dir: str, # top-level directory at which to begin scanning
+ ext: list, # list of allowed file extensions
+ keywords: list, # list of keywords to search for in the file name
+):
+ "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
+ subfolders, files = [], []
+ # make keywords case insensitive
+ keywords = [keyword.lower() for keyword in keywords]
+ # add starting period to extensions if needed
+ ext = ['.'+x if x[0] != '.' else x for x in ext]
+ banned_words = ["paxheader", "__macosx"]
+ try: # hope to avoid 'permission denied' by this try
+ for f in os.scandir(dir):
+ try: # 'hope to avoid too many levels of symbolic links' error
+ if f.is_dir():
+ subfolders.append(f.path)
+ elif f.is_file():
+ is_hidden = f.name.split("/")[-1][0] == '.'
+ has_ext = os.path.splitext(f.name)[1].lower() in ext
+ name_lower = f.name.lower()
+ has_keyword = any(
+ [keyword in name_lower for keyword in keywords])
+ has_banned = any(
+ [banned_word in name_lower for banned_word in banned_words])
+ if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"):
+ files.append(f.path)
+ except:
+ pass
+ except:
+ pass
+
+ for dir in list(subfolders):
+ sf, f = keyword_scandir(dir, ext, keywords)
+ subfolders.extend(sf)
+ files.extend(f)
+ return subfolders, files
+
+def get_audio_filenames(
+ paths: list, # directories in which to search
+ keywords=None,
+ exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
+):
+ "recursively get a list of audio filenames"
+ filenames = []
+ if type(paths) is str:
+ paths = [paths]
+ for path in paths: # get a list of relevant filenames
+ if keywords is not None:
+ subfolders, files = keyword_scandir(path, exts, keywords)
+ else:
+ subfolders, files = fast_scandir(path, exts)
+ filenames.extend(files)
+ return filenames
+
+
+
+
+
+class LocalDatasetConfig:
+ def __init__(
+ self,
+ id: str,
+ path: str,
+ split_path: str,
+ audio_dir: str = None,
+ custom_metadata_fn: Optional[Callable[[str], str]] = None
+ ):
+ self.id = id
+ self.path = path
+ self.split_path = split_path
+ self.audio_dir = audio_dir
+ self.custom_metadata_fn = custom_metadata_fn
+
+class SampleDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ configs,
+ sample_size=65536,
+ sample_rate=48000,
+ keywords=None,
+ random_crop=True,
+ input_type="prompt",
+ fps=4,
+ force_channels="stereo"
+ ):
+ super().__init__()
+ self.filenames = []
+
+ self.augs = torch.nn.Sequential(
+ PhaseFlipper(),
+ )
+
+ self.root_paths = []
+ if input_type == 'video':
+ self.pad_crop = PadCrop_Video_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
+ elif input_type == 'video_hiera':
+ self.pad_crop = PadCrop_Video_Hiera_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
+ elif input_type == 'video_image':
+ self.pad_crop = PadCrop_Video_Image_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
+ elif input_type == 'dual_video':
+ self.pad_crop = PadCrop_DualVideo_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
+ else:
+ self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)
+
+ self.force_channels = force_channels
+ print('######################')
+ print(f'input channels is: {force_channels}')
+ print('######################')
+ self.encoding = torch.nn.Sequential(
+ FOA() if self.force_channels == "foa" else torch.nn.Identity(),
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
+ )
+ self.input_type = input_type
+ self.sr = sample_rate
+ self.custom_metadata_fns = {}
+
+ for config in configs:
+ self.root_paths.append(config.path)
+ def add_prefix(s):
+ return str(os.path.join(config.path,f'{s.strip()}'))
+ with open(config.split_path,'r') as f:
+ item_names = f.readlines()
+ filenames = list(map(add_prefix, item_names))
+ self.filenames.extend(filenames)
+ # self.filenames.extend(get_audio_filenames(config.path, keywords))
+ if config.custom_metadata_fn is not None:
+ self.custom_metadata_fns[config.path] = config.custom_metadata_fn
+
+ print(f'Found {len(self.filenames)} files')
+
+ def load_file(self, filename):
+ ext = filename.split(".")[-1]
+ if ext == "mp3":
+ with AudioFile(filename) as f:
+ audio = f.read(f.frames)
+ audio = torch.from_numpy(audio)
+ in_sr = f.samplerate
+ else:
+ audio, in_sr = torchaudio.load(filename, format=ext)
+
+ if in_sr != self.sr:
+ try:
+ resample_tf = T.Resample(in_sr, self.sr)
+ audio = resample_tf(audio)
+ except:
+ print(f'{filename} resample errors')
+
+ assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!'
+ return audio
+
+ def __len__(self):
+ return len(self.filenames)
+
+ def __getitem__(self, idx):
+ audio_filename = self.filenames[idx]
+ assert os.path.exists(audio_filename), f'{audio_filename}: file not exists'
+ try:
+ start_time = time.time()
+ audio = self.load_file(audio_filename)
+ info = {}
+ info["path"] = audio_filename
+
+ for root_path in self.root_paths:
+ if root_path in audio_filename:
+ info["relpath"] = path.relpath(audio_filename, root_path)
+
+
+ for custom_md_path in self.custom_metadata_fns.keys():
+ if custom_md_path in audio_filename:
+ custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
+ custom_metadata = custom_metadata_fn(info, audio)
+ info.update(custom_metadata)
+
+ if "__reject__" in info and info["__reject__"]:
+ return self[random.randrange(len(self))]
+ if self.input_type == 'video':
+ audio, video, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['video'])
+ info['video'] = video
+ elif self.input_type == 'dual_video':
+ audio, video_360, video_fov, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['video'], info['video_fov'])
+ info['video_360'] = video_360
+ info['video_fov'] = video_fov
+ else:
+ audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)
+ assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!'
+ # Run augmentations on this sample (including random crop)
+ if self.augs is not None:
+ audio = self.augs(audio)
+
+ audio = audio.clamp(-1, 1)
+
+ # Encode the file to assist in prediction
+ if self.encoding is not None:
+ audio = self.encoding(audio)
+
+
+
+ info["timestamps"] = (t_start, t_end)
+ info["seconds_start"] = seconds_start
+ info["seconds_total"] = seconds_total
+ info["padding_mask"] = padding_mask
+
+ end_time = time.time()
+ info["load_time"] = end_time - start_time
+
+
+ return (audio, info)
+ except Exception as e:
+ print(f'Couldn\'t load file {audio_filename}: {e}')
+ return self[random.randrange(len(self))]
+
+class LatentDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ configs,
+ sample_size=65536,
+ sample_rate=48000,
+ keywords=None,
+ random_crop=True,
+ input_type="prompt",
+ fps=4,
+ force_channels="stereo"
+ ):
+ super().__init__()
+ self.filenames = []
+
+ self.augs = torch.nn.Sequential(
+ PhaseFlipper(),
+ )
+
+ self.root_paths = []
+
+ self.force_channels = force_channels
+ print('######################')
+ print(f'input channels is: {force_channels}')
+ print('######################')
+ self.encoding = torch.nn.Sequential(
+ FOA() if self.force_channels == "foa" else torch.nn.Identity(),
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
+ )
+ self.input_type = input_type
+ self.sr = sample_rate
+ for config in configs:
+ self.root_paths.append(config.path)
+ def add_prefix(s):
+ return str(os.path.join(config.path,f'{s.strip()}'))
+ with open(config.split_path,'r') as f:
+ item_names = f.readlines()
+ filenames = list(map(add_prefix, item_names))
+ self.filenames.extend(filenames)
+ # self.filenames.extend(get_audio_filenames(config.path, keywords))
+
+
+ print(f'Found {len(self.filenames)} files')
+
+ def load_file(self, filename, info):
+ # try:
+ npz_file = filename.replace('.pth','.npz')
+ if os.path.exists(filename) and '.npz' not in filename:
+ data = torch.load(filename, weights_only=False)
+ elif os.path.exists(npz_file):
+ # print(filename)
+ npz_data = np.load(npz_file,allow_pickle=True)
+ data = {key: npz_data[key] for key in npz_data.files}
+ # print("data.keys()",data.keys())
+ for key in data.keys():
+ if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number):
+ data[key] = torch.from_numpy(data[key])
+ else:
+ raise ValueError(f'error load file with file not exists: {filename}')
+ info.update(data)
+ audio = data['latent']
+ # except:
+ # print(f'error load file: {filename}')
+ return audio
+
+ def __len__(self):
+ return len(self.filenames)
+
+ def __getitem__(self, idx):
+ audio_filename = self.filenames[idx]
+ assert os.path.exists(audio_filename) or audio_filename.replace('.pth','.npz'), f'{audio_filename}: file not exists'
+ # try:
+ start_time = time.time()
+ info = {}
+ audio = self.load_file(audio_filename, info)
+ info["path"] = audio_filename
+ info['id'] = Path(audio_filename).stem
+ for root_path in self.root_paths:
+ if root_path in audio_filename:
+ info["relpath"] = path.relpath(audio_filename, root_path)
+
+ return (audio, info)
+
+class AudioDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ configs,
+ sample_size=65536,
+ sample_rate=48000,
+ keywords=None,
+ random_crop=True,
+ input_type="prompt",
+ fps=4,
+ force_channels="stereo"
+ ):
+ super().__init__()
+ self.filenames = []
+
+ self.augs = torch.nn.Sequential(
+ PhaseFlipper(),
+ )
+
+ self.root_paths = []
+
+ self.force_channels = force_channels
+ print('######################')
+ print(f'input channels is: {force_channels}')
+ print('######################')
+ self.encoding = torch.nn.Sequential(
+ FOA() if self.force_channels == "foa" else torch.nn.Identity(),
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
+ )
+ self.fake_clip_features = torch.zeros(72, 1024)
+ self.fake_sync_features = torch.zeros(216, 768)
+ self.video_exist = torch.tensor(0, dtype=torch.bool)
+ self.input_type = input_type
+ self.sr = sample_rate
+ for config in configs:
+ self.root_paths.append(config.path)
+ def add_prefix(s):
+ return str(os.path.join(config.path,f'{s.strip()}'))
+ with open(config.split_path,'r') as f:
+ item_names = f.readlines()
+ filenames = list(map(add_prefix, item_names))
+ self.filenames.extend(filenames)
+ # self.filenames.extend(get_audio_filenames(config.path, keywords))
+
+
+ print(f'Found {len(self.filenames)} files')
+
+ def load_file(self, filename, info):
+ # try:
+ npz_file = filename.replace('.pth','.npz')
+ if os.path.exists(filename) and '.npz' not in filename:
+ data = torch.load(filename, weights_only=False)
+ elif os.path.exists(npz_file):
+ # print(filename)
+ npz_data = np.load(npz_file,allow_pickle=True)
+ data = dict(npz_data)
+ # print("data.keys()",data.keys())
+ for key in data.keys():
+ if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number):
+ data[key] = torch.from_numpy(data[key])
+ else:
+ raise ValueError(f'error load file: {filename}')
+ info.update(data)
+ audio = data['latent']
+ if 'source_latent' not in data.keys():
+ info['source_latent']= audio
+ info['video_features'] = self.fake_clip_features
+ info['sync_features'] = self.fake_sync_features
+ info['video_exist'] = self.video_exist
+ # except:
+ # print(f'error load file: {filename}')
+ return audio
+
+ def __len__(self):
+ return len(self.filenames)
+
+ def __getitem__(self, idx):
+ audio_filename = self.filenames[idx]
+ assert os.path.exists(audio_filename) or audio_filename.replace('.pth','.npz'), f'{audio_filename}: file not exists'
+ # try:
+ start_time = time.time()
+ info = {}
+ audio = self.load_file(audio_filename, info)
+ info["path"] = audio_filename
+ assert audio.shape == (64,194), f'{audio.shape} input error, id: {audio_filename}'
+ info['id'] = Path(audio_filename).stem
+ for root_path in self.root_paths:
+ if root_path in audio_filename:
+ info["relpath"] = path.relpath(audio_filename, root_path)
+
+ return (audio, info)
+
+
+
+
+class VideoDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ configs,
+ sample_size=65536,
+ sample_rate=48000,
+ keywords=None,
+ random_crop=True,
+ input_type="prompt",
+ fps=4,
+ force_channels="stereo"
+ ):
+ super().__init__()
+ self.filenames = []
+
+ self.augs = torch.nn.Sequential(
+ PhaseFlipper(),
+ )
+
+ self.root_paths = []
+ self.sample_size = sample_size
+ self.force_channels = force_channels
+ print('######################')
+ print(f'input channels is: {force_channels}')
+ print('######################')
+ self.encoding = torch.nn.Sequential(
+ FOA() if self.force_channels == "foa" else torch.nn.Identity(),
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
+ )
+ self.input_type = input_type
+ self.sr = sample_rate
+ self.video_exist = torch.tensor(1, dtype=torch.bool)
+ self.audio_files = []
+ for config in configs:
+ self.root_paths.append(config.path)
+ def add_prefix(s):
+ return str(os.path.join(config.path,f'{s.strip()}'))
+ with open(config.split_path,'r') as f:
+ item_names = f.readlines()
+ filenames = list(map(add_prefix, item_names))
+ self.filenames.extend(filenames)
+ if config.audio_dir is not None:
+ def add_prefix(s):
+ return str(os.path.join(config.audio_dir,f'{Path(s).stem}.wav'))
+ filenames = list(map(add_prefix, item_names))
+ self.audio_files.extend(filenames)
+ # self.filenames.extend(get_audio_filenames(config.path, keywords))
+
+ print(f'Found {len(self.filenames)} files')
+
+ def load_audio(self, filename):
+ ext = filename.split(".")[-1]
+ if ext == "mp3":
+ with AudioFile(filename) as f:
+ audio = f.read(f.frames)
+ audio = torch.from_numpy(audio)
+ in_sr = f.samplerate
+ else:
+ audio, in_sr = torchaudio.load(filename, format=ext)
+
+ if in_sr != self.sr:
+ try:
+ resample_tf = T.Resample(in_sr, self.sr)
+ audio = resample_tf(audio)
+ except:
+ print(f'{filename} resample errors')
+
+ assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!'
+ return audio
+
+
+
+
+ def check_audio_file(self, audio_path):
+ # 首先检查原始路径是否存在
+ if os.path.exists(audio_path):
+ return audio_path
+ # 如果不存在,尝试替换为.flac扩展名
+ name, ext = os.path.splitext(audio_path)
+ flac_path = f"{name}.flac"
+
+ if os.path.exists(flac_path):
+ return flac_path
+ raise FileNotFoundError(f"音频文件不存在: {audio_path} 和 {flac_path} 都不存在")
+
+ def load_file(self, filename, info):
+ try:
+ npz_file = filename.replace('.pth','.npz')
+ if os.path.exists(filename) and '.npz' not in filename:
+ data = torch.load(filename, weights_only=False)
+ elif os.path.exists(npz_file):
+ # print(filename)
+ npz_data = np.load(npz_file,allow_pickle=True)
+ data = {key: npz_data[key] for key in npz_data.files}
+ # print("data.keys()",data.keys())
+ for key in data.keys():
+ if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number):
+ data[key] = torch.from_numpy(data[key])
+ else:
+ raise ValueError(f'error load file: {filename}')
+ info.update(data)
+ audio = data['latent']
+ info['video_exist'] = self.video_exist
+ except Exception as e:
+ print(f'error load file: {filename} with error: {e}')
+ return None
+ return audio
+
+ def __len__(self):
+ return len(self.filenames)
+
+ def __getitem__(self, idx):
+ loop = True
+ while loop:
+ filename = self.filenames[idx]
+ if len(self.audio_files) > 0:
+ audio_path = self.audio_files[idx]
+ audio_path = self.check_audio_file(audio_path)
+ waveform = self.load_audio(audio_path)
+ else:
+ waveform = None
+ assert os.path.exists(filename) or filename.replace('.pth','.npz'), f'{filename}: file not exists'
+ # try:
+ start_time = time.time()
+ info = {}
+ audio = self.load_file(filename, info)
+ if audio is not None:
+ loop = False
+ else:
+ idx = (idx+1) % len(self.filenames)
+
+ if waveform is not None:
+ padded_waveform = torch.zeros(waveform.shape[0], self.sample_size, dtype=waveform.dtype)
+ copy_length = min(waveform.shape[1], self.sample_size)
+ padded_waveform[:, :copy_length] = waveform[:, :copy_length]
+
+ waveform = padded_waveform
+ waveform = waveform.clamp(-1, 1)
+ # Encode the file to assist in prediction
+ if self.encoding is not None:
+ waveform = self.encoding(waveform)
+ info['waveform'] = waveform
+ info["path"] = filename
+ info['id'] = Path(filename).stem
+ for root_path in self.root_paths:
+ if root_path in filename:
+ info["relpath"] = path.relpath(filename, root_path)
+
+
+ return (audio, info)
+
+# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
+class MultiModalDataset(torch.utils.data.Dataset):
+ datasets: list[torch.utils.data.Dataset]
+ cumulative_sizes: list[int]
+
+ @staticmethod
+ def cumsum(sequence):
+ r, s = [], 0
+ for e in sequence:
+ l = len(e)
+ r.append(l + s)
+ s += l
+ return r
+
+ def __init__(self, video_datasets: list[torch.utils.data.Dataset], audio_datasets: list[torch.utils.data.Dataset]):
+ super().__init__()
+ self.video_datasets = list(video_datasets)
+ self.audio_datasets = list(audio_datasets)
+ self.datasets = self.video_datasets + self.audio_datasets
+
+ self.cumulative_sizes = self.cumsum(self.datasets)
+ print(f'Found {self.cumulative_sizes[-1]} files')
+
+ def __len__(self):
+ return self.cumulative_sizes[-1]
+
+ def __getitem__(self, idx):
+ if idx < 0:
+ if -idx > len(self):
+ raise ValueError("absolute value of index should not exceed dataset length")
+ idx = len(self) + idx
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx][sample_idx]
+
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
+ return self.video_datasets[0].compute_latent_stats()
+
+
+# class MultiModalDataset(torch.utils.data.Dataset):
+# def __init__(
+# self,
+# configs,
+# sample_size=65536,
+# sample_rate=48000,
+# keywords=None,
+# random_crop=True,
+# input_type="prompt",
+# fps=4,
+# force_channels="stereo"
+# ):
+# super().__init__()
+# self.filenames = []
+# self.captions = []
+# self.caption_t5s = []
+# self.ids = []
+# self.augs = torch.nn.Sequential(
+# PhaseFlipper(),
+# )
+
+# self.root_paths = []
+# if input_type == 'video':
+# self.pad_crop = PadCrop_Video_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
+# elif input_type == 'video_hiera':
+# self.pad_crop = PadCrop_Video_Hiera_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
+# elif input_type == 'video_image':
+# self.pad_crop = PadCrop_Video_Image_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
+# elif input_type == 'dual_video':
+# self.pad_crop = PadCrop_DualVideo_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
+# else:
+# self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)
+
+# self.force_channels = force_channels
+# print('######################')
+# print(f'input channels is: {force_channels}')
+# print('######################')
+# self.encoding = torch.nn.Sequential(
+# FOA() if self.force_channels == "foa" else torch.nn.Identity(),
+# Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
+# Mono() if self.force_channels == "mono" else torch.nn.Identity(),
+# )
+# self.input_type = input_type
+# self.sr = sample_rate
+# self.custom_metadata_fns = {}
+
+# for config in configs:
+# print(config.split_path)
+# self.root_paths.append(config.path)
+# def add_prefix(s):
+# return str(os.path.join(config.path,f'{s.strip()}'))
+# with open(config.split_path,'r') as f:
+# item_names = f.readlines()
+# csv_path = config.split_path.replace('.txt','.csv')
+# df = pd.read_csv(csv_path)
+# # 检查是否存在 'caption_t5' 列,如果不存在则创建并复制 'caption' 的值
+# if 'caption_t5' not in df.columns:
+# df['caption_t5'] = df['caption']
+
+# captions = df['caption'].tolist()
+# caption_t5s = df['caption_t5'].tolist()
+# filenames = list(map(add_prefix, item_names))
+# assert len(captions) == len(caption_t5s) and len(captions) == len(filenames), f'{config.path} has wrong filename and caption'
+# if config.id == 'vggsound':
+# self.filenames.extend(filenames*5)
+# self.captions.extend(captions*5)
+# self.caption_t5s.extend(caption_t5s*5)
+# self.ids.extend(df['id'].tolist()*5)
+# else:
+# self.filenames.extend(filenames)
+# self.captions.extend(captions)
+# self.caption_t5s.extend(caption_t5s)
+# self.ids.extend(df['id'].tolist())
+# # self.filenames.extend(get_audio_filenames(config.path, keywords))
+# if config.custom_metadata_fn is not None:
+# self.custom_metadata_fns[config.path] = config.custom_metadata_fn
+
+# assert len(self.ids) == len(self.captions) and len(self.caption_t5s) == len(self.filenames), 'length need to be same'
+# print(f'Found {len(self.filenames)} files')
+
+
+# def load_file(self, filename):
+# ext = filename.split(".")[-1]
+# if ext == "mp3":
+# with AudioFile(filename) as f:
+# audio = f.read(f.frames)
+# audio = torch.from_numpy(audio)
+# in_sr = f.samplerate
+# else:
+# audio, in_sr = torchaudio.load(filename, format=ext)
+
+# if in_sr != self.sr:
+# try:
+# resample_tf = T.Resample(in_sr, self.sr)
+# audio = resample_tf(audio)
+# except:
+# print(f'{filename} resample errors')
+
+# assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!'
+# return audio
+
+# def __len__(self):
+# return len(self.filenames)
+
+# def __getitem__(self, idx):
+# audio_filename = self.filenames[idx]
+# id = self.ids[idx]
+# assert str(id) == str(Path(audio_filename).stem), f'audio_file: {audio_filename} needs to be same as {id} '
+# assert os.path.exists(audio_filename), f'{audio_filename}: file not exists'
+# try:
+# start_time = time.time()
+# audio = self.load_file(audio_filename)
+# caption = self.captions[idx]
+# caption_t5 = self.caption_t5s[idx]
+# if pd.isna(caption_t5) or caption_t5 == '':
+# caption_t5 = caption
+# info = {}
+# info["path"] = audio_filename
+# info['caption'] = caption
+# info['caption_t5'] = caption_t5
+
+# for root_path in self.root_paths:
+# if root_path in audio_filename:
+# info["relpath"] = path.relpath(audio_filename, root_path)
+
+
+# for custom_md_path in self.custom_metadata_fns.keys():
+# if custom_md_path in audio_filename:
+# custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
+# custom_metadata = custom_metadata_fn(info, audio)
+# info.update(custom_metadata)
+
+# if "__reject__" in info and info["__reject__"]:
+# return self[random.randrange(len(self))]
+# # if self.input_type == 'video':
+# # audio, video, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['clip_features'])
+# # info['clip_features'] = video
+# # else:
+# if info['flag']:
+# audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio,randomize=False)
+# else:
+# audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio,randomize=True)
+# assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!'
+# # Run augmentations on this sample (including random crop)
+# if self.augs is not None:
+# audio = self.augs(audio)
+
+# audio = audio.clamp(-1, 1)
+
+# # Encode the file to assist in prediction
+# if self.encoding is not None:
+# audio = self.encoding(audio)
+
+
+
+# info["timestamps"] = (t_start, t_end)
+# info["seconds_start"] = seconds_start
+# info["seconds_total"] = seconds_total
+# info["padding_mask"] = padding_mask
+
+# end_time = time.time()
+# info["load_time"] = end_time - start_time
+
+
+# return (audio, info)
+# except Exception as e:
+# print(f'Couldn\'t load file {audio_filename}: {e}')
+# return self[random.randrange(len(self))]
+
+def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+ :param keys: function that splits the key into key and extension (base_plus_ext)
+ :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if wds.tariterators.trace:
+ print(
+ prefix,
+ suffix,
+ current_sample.keys() if isinstance(current_sample, dict) else None,
+ )
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ if current_sample is None or prefix != current_sample["__key__"]:
+ if wds.tariterators.valid_sample(current_sample):
+ yield current_sample
+ current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
+ if suffix in current_sample:
+ print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if wds.tariterators.valid_sample(current_sample):
+ yield current_sample
+
+wds.tariterators.group_by_keys = group_by_keys
+
+# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
+
+def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
+ """
+ Returns a list of full S3 paths to files in a given S3 bucket and directory path.
+ """
+ # Ensure dataset_path ends with a trailing slash
+ if dataset_path != '' and not dataset_path.endswith('/'):
+ dataset_path += '/'
+ # Use posixpath to construct the S3 URL path
+ bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
+ # Construct the `aws s3 ls` command
+ cmd = ['aws', 's3', 'ls', bucket_path]
+
+ if profile is not None:
+ cmd.extend(['--profile', profile])
+
+ if recursive:
+ # Add the --recursive flag if requested
+ cmd.append('--recursive')
+
+ # Run the `aws s3 ls` command and capture the output
+ run_ls = subprocess.run(cmd, capture_output=True, check=True)
+ # Split the output into lines and strip whitespace from each line
+ contents = run_ls.stdout.decode('utf-8').split('\n')
+ contents = [x.strip() for x in contents if x]
+ # Remove the timestamp from lines that begin with a timestamp
+ contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x)
+ if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents]
+ # Construct a full S3 path for each file in the contents list
+ contents = [posixpath.join(s3_url_prefix or '', x)
+ for x in contents if not x.endswith('/')]
+ # Apply the filter, if specified
+ if filter:
+ contents = [x for x in contents if filter in x]
+ # Remove redundant directory names in the S3 URL
+ if recursive:
+ # Get the main directory name from the S3 URL
+ main_dir = "/".join(bucket_path.split('/')[3:])
+ # Remove the redundant directory names from each file path
+ contents = [x.replace(f'{main_dir}', '').replace(
+ '//', '/') for x in contents]
+ # Print debugging information, if requested
+ if debug:
+ print("contents = \n", contents)
+ # Return the list of S3 paths to files
+ return contents
+
+
+def get_all_s3_urls(
+ names=[], # list of all valid [LAION AudioDataset] dataset names
+ # list of subsets you want from those datasets, e.g. ['train','valid']
+ subsets=[''],
+ s3_url_prefix=None, # prefix for those dataset names
+ recursive=True, # recursively list all tar files in all subdirs
+ filter_str='tar', # only grab files with this substring
+ # print debugging info -- note: info displayed likely to change at dev's whims
+ debug=False,
+ profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
+):
+ "get urls of shards (tar files) for multiple datasets in one s3 bucket"
+ urls = []
+ for name in names:
+ # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
+ if s3_url_prefix is None:
+ contents_str = name
+ else:
+ # Construct the S3 path using the s3_url_prefix and the current name value
+ contents_str = posixpath.join(s3_url_prefix, name)
+ if debug:
+ print(f"get_all_s3_urls: {contents_str}:")
+ for subset in subsets:
+ subset_str = posixpath.join(contents_str, subset)
+ if debug:
+ print(f"subset_str = {subset_str}")
+ # Get the list of tar files in the current subset directory
+ profile = profiles.get(name, None)
+ tar_list = get_s3_contents(
+ subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
+ for tar in tar_list:
+ # Escape spaces and parentheses in the tar filename for use in the shell command
+ tar = tar.replace(" ", "\ ").replace(
+ "(", "\(").replace(")", "\)")
+ # Construct the S3 path to the current tar file
+ s3_path = posixpath.join(name, subset, tar) + " -"
+ # Construct the AWS CLI command to download the current tar file
+ if s3_url_prefix is None:
+ request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
+ else:
+ request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
+ if profiles.get(name):
+ request_str += f" --profile {profiles.get(name)}"
+ if debug:
+ print("request_str = ", request_str)
+ # Add the constructed URL to the list of URLs
+ urls.append(request_str)
+ return urls
+
+
+def log_and_continue(exn):
+ """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
+ print(f"Handling webdataset error ({repr(exn)}). Ignoring.")
+ return True
+
+
+def is_valid_sample(sample):
+ has_json = "json" in sample
+ has_audio = "audio" in sample
+ is_silent = is_silence(sample["audio"])
+ is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"]
+
+ return has_json and has_audio and not is_silent and not is_rejected
+
+class S3DatasetConfig:
+ def __init__(
+ self,
+ id: str,
+ s3_path: str,
+ custom_metadata_fn: Optional[Callable[[str], str]] = None,
+ profile: Optional[str] = None,
+ ):
+ self.id = id
+ self.path = s3_path
+ self.custom_metadata_fn = custom_metadata_fn
+ self.profile = profile
+ self.urls = []
+
+ def load_data_urls(self):
+ self.urls = get_all_s3_urls(
+ names=[self.path],
+ s3_url_prefix=None,
+ recursive=True,
+ profiles={self.path: self.profile} if self.profile else {},
+ )
+
+ return self.urls
+
+class LocalWebDatasetConfig:
+ def __init__(
+ self,
+ id: str,
+ path: str,
+ custom_metadata_fn: Optional[Callable[[str], str]] = None,
+ profile: Optional[str] = None,
+ ):
+ self.id = id
+ self.path = path
+ self.custom_metadata_fn = custom_metadata_fn
+ self.urls = []
+
+ def load_data_urls(self):
+
+ self.urls = fast_scandir(self.path, ["tar"])[1]
+
+ return self.urls
+
+def audio_decoder(key, value):
+ # Get file extension from key
+ ext = key.split(".")[-1]
+
+ if ext in AUDIO_KEYS:
+ return torchaudio.load(io.BytesIO(value))
+ else:
+ return None
+
+def collation_fn(samples):
+ batched = list(zip(*samples))
+ result = []
+ for b in batched:
+ if isinstance(b[0], (int, float)):
+ b = np.array(b)
+ elif isinstance(b[0], torch.Tensor):
+ b = torch.stack(b)
+ elif isinstance(b[0], np.ndarray):
+ b = np.array(b)
+ else:
+ b = b
+ result.append(b)
+ return result
+
+class WebDatasetDataLoader():
+ def __init__(
+ self,
+ datasets: List[S3DatasetConfig],
+ batch_size,
+ sample_size,
+ sample_rate=48000,
+ num_workers=8,
+ epoch_steps=1000,
+ random_crop=True,
+ force_channels="stereo",
+ augment_phase=True,
+ **data_loader_kwargs
+ ):
+
+ self.datasets = datasets
+
+ self.sample_size = sample_size
+ self.sample_rate = sample_rate
+ self.random_crop = random_crop
+ self.force_channels = force_channels
+ self.augment_phase = augment_phase
+
+ urls = [dataset.load_data_urls() for dataset in datasets]
+
+ # Flatten the list of lists of URLs
+ urls = [url for dataset_urls in urls for url in dataset_urls]
+
+ # Shuffle the urls
+ random.shuffle(urls)
+
+ self.dataset = wds.DataPipeline(
+ wds.ResampledShards(urls),
+ wds.tarfile_to_samples(handler=log_and_continue),
+ wds.decode(audio_decoder, handler=log_and_continue),
+ wds.map(self.wds_preprocess, handler=log_and_continue),
+ wds.select(is_valid_sample),
+ wds.to_tuple("audio", "json", handler=log_and_continue),
+ #wds.shuffle(bufsize=1000, initial=5000),
+ wds.batched(batch_size, partial=False, collation_fn=collation_fn),
+ ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps)
+
+ self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs)
+
+ def wds_preprocess(self, sample):
+
+ found_key, rewrite_key = '', ''
+ for k, v in sample.items(): # print the all entries in dict
+ for akey in AUDIO_KEYS:
+ if k.endswith(akey):
+ # to rename long/weird key with its simpler counterpart
+ found_key, rewrite_key = k, akey
+ break
+ if '' != found_key:
+ break
+ if '' == found_key: # got no audio!
+ return None # try returning None to tell WebDataset to skip this one
+
+ audio, in_sr = sample[found_key]
+ if in_sr != self.sample_rate:
+ resample_tf = T.Resample(in_sr, self.sample_rate)
+ audio = resample_tf(audio)
+
+ if self.sample_size is not None:
+ # Pad/crop and get the relative timestamp
+ pad_crop = PadCrop_Normalized_T(
+ self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate)
+ audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(
+ audio)
+ sample["json"]["seconds_start"] = seconds_start
+ sample["json"]["seconds_total"] = seconds_total
+ sample["json"]["padding_mask"] = padding_mask
+ else:
+ t_start, t_end = 0, 1
+
+ # Check if audio is length zero, initialize to a single zero if so
+ if audio.shape[-1] == 0:
+ audio = torch.zeros(1, 1)
+
+ # Make the audio stereo and augment by randomly inverting phase
+ augs = torch.nn.Sequential(
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
+ PhaseFlipper() if self.augment_phase else torch.nn.Identity()
+ )
+
+ audio = augs(audio)
+
+ sample["json"]["timestamps"] = (t_start, t_end)
+
+ if "text" in sample["json"]:
+ sample["json"]["prompt"] = sample["json"]["text"]
+
+ # Check for custom metadata functions
+ for dataset in self.datasets:
+ if dataset.custom_metadata_fn is None:
+ continue
+
+ if dataset.path in sample["__url__"]:
+ custom_metadata = dataset.custom_metadata_fn(sample["json"], audio)
+ sample["json"].update(custom_metadata)
+
+ if found_key != rewrite_key: # rename long/weird key with its simpler counterpart
+ del sample[found_key]
+
+ sample["audio"] = audio
+
+ # Add audio to the metadata as well for conditioning
+ sample["json"]["audio"] = audio
+
+ return sample
+
+def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4, shuffle=True):
+
+ dataset_type = dataset_config.get("dataset_type", None)
+
+ assert dataset_type is not None, "Dataset type must be specified in dataset config"
+
+ if audio_channels == 1:
+ force_channels = "mono"
+ elif audio_channels == 2:
+ force_channels = "stereo"
+ else:
+ force_channels = "foa"
+
+ if dataset_type == "audio_dir":
+
+ audio_dir_configs = dataset_config.get("datasets", None)
+
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
+
+ configs = []
+
+ for audio_dir_config in audio_dir_configs:
+ audio_dir_path = audio_dir_config.get("path", None)
+ split_path = audio_dir_config.get("split_path", None)
+ assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
+ custom_metadata_fn = None
+ custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
+
+ if custom_metadata_module_path is not None:
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
+ metadata_module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(metadata_module)
+
+ custom_metadata_fn = metadata_module.get_custom_metadata
+
+ configs.append(
+ LocalDatasetConfig(
+ id=audio_dir_config["id"],
+ path=audio_dir_path,
+ split_path=split_path,
+ custom_metadata_fn=custom_metadata_fn
+ )
+ )
+
+ train_set = SampleDataset(
+ configs,
+ sample_rate=sample_rate,
+ sample_size=sample_size,
+ random_crop=dataset_config.get("random_crop", True),
+ input_type=dataset_config.get("input_type", "video"),
+ fps=dataset_config.get("fps", 4),
+ force_channels=force_channels
+ )
+
+ return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
+ num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
+
+ elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility
+
+ wds_configs = []
+
+ for wds_config in dataset_config["datasets"]:
+
+ custom_metadata_fn = None
+ custom_metadata_module_path = wds_config.get("custom_metadata_module", None)
+
+ if custom_metadata_module_path is not None:
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
+ metadata_module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(metadata_module)
+
+ custom_metadata_fn = metadata_module.get_custom_metadata
+
+ if "s3_path" in wds_config:
+
+ wds_configs.append(
+ S3DatasetConfig(
+ id=wds_config["id"],
+ s3_path=wds_config["s3_path"],
+ custom_metadata_fn=custom_metadata_fn,
+ profile=wds_config.get("profile", None),
+ )
+ )
+
+ elif "path" in wds_config:
+
+ wds_configs.append(
+ LocalWebDatasetConfig(
+ id=wds_config["id"],
+ path=wds_config["path"],
+ custom_metadata_fn=custom_metadata_fn
+ )
+ )
+
+ return WebDatasetDataLoader(
+ wds_configs,
+ sample_rate=sample_rate,
+ sample_size=sample_size,
+ batch_size=batch_size,
+ random_crop=dataset_config.get("random_crop", True),
+ num_workers=num_workers,
+ persistent_workers=True,
+ force_channels=force_channels,
+ epoch_steps=dataset_config.get("epoch_steps", 2000)
+ ).data_loader
+
+ elif dataset_type == "latent_dir":
+
+ audio_dir_configs = dataset_config.get("datasets", None)
+
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
+
+ configs = []
+
+ for audio_dir_config in audio_dir_configs:
+ audio_dir_path = audio_dir_config.get("path", None)
+ split_path = audio_dir_config.get("split_path", None)
+ assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
+
+ configs.append(
+ LocalDatasetConfig(
+ id=audio_dir_config["id"],
+ path=audio_dir_path,
+ split_path=split_path,
+ )
+ )
+
+ train_set = LatentDataset(
+ configs,
+ sample_rate=sample_rate,
+ sample_size=sample_size,
+ random_crop=dataset_config.get("random_crop", True),
+ input_type=dataset_config.get("input_type", "video"),
+ fps=dataset_config.get("fps", 4),
+ force_channels=force_channels
+ )
+
+ return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
+ num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
+ elif dataset_type == 'multimodal_dir':
+ audio_dir_configs = dataset_config.get("datasets", None)
+
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
+
+ configs = []
+
+ for audio_dir_config in audio_dir_configs:
+ audio_dir_path = audio_dir_config.get("path", None)
+ split_path = audio_dir_config.get("split_path", None)
+ assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
+ custom_metadata_fn = None
+ custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
+
+ if custom_metadata_module_path is not None:
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
+ metadata_module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(metadata_module)
+
+ custom_metadata_fn = metadata_module.get_custom_metadata
+
+ configs.append(
+ LocalDatasetConfig(
+ id=audio_dir_config["id"],
+ path=audio_dir_path,
+ split_path=split_path,
+ custom_metadata_fn=custom_metadata_fn
+ )
+ )
+
+ train_set = MultiModalDataset(
+ configs,
+ sample_rate=sample_rate,
+ sample_size=sample_size,
+ random_crop=dataset_config.get("random_crop", True),
+ input_type=dataset_config.get("input_type", "video"),
+ fps=dataset_config.get("fps", 4),
+ force_channels=force_channels
+ )
+
+ return torch.utils.data.DataLoader(train_set, batch_size, shuffle=shuffle,
+ num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
\ No newline at end of file
diff --git a/ThinkSound/data/utils.py b/ThinkSound/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a0f0a19a4f9a32ee3424cfd24efe7643ddde0fc
--- /dev/null
+++ b/ThinkSound/data/utils.py
@@ -0,0 +1,378 @@
+import math
+import random
+import torch
+import torch.nn.functional as F
+from torch import nn
+from typing import Tuple
+import numpy as np
+
+class PadCrop(nn.Module):
+ def __init__(self, n_samples, randomize=True):
+ super().__init__()
+ self.n_samples = n_samples
+ self.randomize = randomize
+
+ def __call__(self, signal):
+ n, s = signal.shape
+ start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
+ end = start + self.n_samples
+ output = signal.new_zeros([n, self.n_samples])
+ output[:, :min(s, self.n_samples)] = signal[:, start:end]
+ return output
+
+class PadCrop_Normalized_T(nn.Module):
+
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
+
+ super().__init__()
+
+ self.n_samples = n_samples
+ self.sample_rate = sample_rate
+ self.randomize = randomize
+
+ def __call__(self, source: torch.Tensor, randomize=True) -> Tuple[torch.Tensor, float, float, int, int]:
+
+ n_channels, n_samples = source.shape
+
+ # If the audio is shorter than the desired length, pad it
+ upper_bound = max(0, n_samples - self.n_samples)
+
+ # If randomize is False, always start at the beginning of the audio
+ offset = 0
+ if(randomize and n_samples > self.n_samples):
+ offset = random.randint(0, upper_bound)
+
+ # Calculate the start and end times of the chunk
+ t_start = offset / (upper_bound + self.n_samples)
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
+
+ # Create the chunk
+ chunk = source.new_zeros([n_channels, self.n_samples])
+
+ # Copy the audio into the chunk
+ chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]
+
+ # Calculate the start and end times of the chunk in seconds
+ seconds_start = math.floor(offset / self.sample_rate)
+ seconds_total = math.ceil(n_samples / self.sample_rate)
+
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
+ padding_mask = torch.zeros([self.n_samples])
+ padding_mask[:min(n_samples, self.n_samples)] = 1
+
+
+ return (
+ chunk,
+ t_start,
+ t_end,
+ seconds_start,
+ seconds_total,
+ padding_mask
+ )
+
+class PadCrop_Video_Normalized_T(nn.Module):
+
+ def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True):
+
+ super().__init__()
+
+ self.n_samples = n_samples
+ self.sample_rate = sample_rate
+ self.randomize = randomize
+ self.fps = fps
+ self.n_frames = int(self.fps * self.n_samples / self.sample_rate)
+
+ def __call__(self, audio: torch.Tensor, video: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]:
+ n_channels, n_samples = audio.shape
+ # print(video.shape)
+ n_frames, dim = video.shape
+ if not torch.is_tensor(video):
+ video = torch.from_numpy(video)
+ # If the audio is shorter than the desired length, pad it
+ audio_upper_bound = max(0, n_samples - self.n_samples)
+ video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps)
+ upper_bound = min(audio_upper_bound,video_upper_bound)
+
+ # If randomize is False, always start at the beginning of the audio
+ offset = 0
+ if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames):
+ offset = random.randint(0, upper_bound)
+
+ # Calculate the start and end times of the chunk
+ t_start = offset / (upper_bound + self.n_samples)
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
+ frame_offset = int(self.fps * offset / self.sample_rate)
+ # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate)
+ # Create the chunk
+ chunk = audio.new_zeros([n_channels, self.n_samples])
+ video_chunk = video.new_zeros([self.n_frames, video.shape[1]])
+ # Copy the audio into the chunk
+ chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples]
+ video_chunk[:min(n_frames, self.n_frames)] = video[frame_offset:frame_offset + self.n_frames,:]
+ # Calculate the start and end times of the chunk in seconds
+ seconds_start = math.floor(offset / self.sample_rate)
+ seconds_total = math.ceil(n_samples / self.sample_rate)
+
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
+ padding_mask = torch.zeros([self.n_samples])
+ padding_mask[:min(n_samples, self.n_samples)] = 1
+
+
+ return (
+ chunk,
+ video_chunk,
+ t_start,
+ t_end,
+ seconds_start,
+ seconds_total,
+ padding_mask
+ )
+
+class PadCrop_Video_Image_Normalized_T(nn.Module):
+
+ def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True):
+
+ super().__init__()
+
+ self.n_samples = n_samples
+ self.sample_rate = sample_rate
+ self.randomize = randomize
+ self.fps = fps
+ self.n_frames = int(self.fps * self.n_samples / self.sample_rate)
+
+ def __call__(self, audio: torch.Tensor, video: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]:
+ n_channels, n_samples = audio.shape
+ # import ipdb
+ # ipdb.set_trace()
+ n_frames, channel, width, height= video.shape
+ video = torch.from_numpy(video)
+ # If the audio is shorter than the desired length, pad it
+ audio_upper_bound = max(0, n_samples - self.n_samples)
+ video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps)
+ upper_bound = min(audio_upper_bound,video_upper_bound)
+
+ # If randomize is False, always start at the beginning of the audio
+ offset = 0
+ if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames):
+ offset = random.randint(0, upper_bound)
+
+ # Calculate the start and end times of the chunk
+ t_start = offset / (upper_bound + self.n_samples)
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
+ frame_offset = int(self.fps * offset / self.sample_rate)
+ # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate)
+ # Create the chunk
+ chunk = audio.new_zeros([n_channels, self.n_samples])
+ video_chunk = video.new_zeros([self.n_frames, channel, width, height])
+ # Copy the audio into the chunk
+ chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples]
+ video_chunk[:min(n_frames, self.n_frames)] = video[frame_offset:frame_offset + self.n_frames]
+ # Calculate the start and end times of the chunk in seconds
+ seconds_start = math.floor(offset / self.sample_rate)
+ seconds_total = math.ceil(n_samples / self.sample_rate)
+
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
+ padding_mask = torch.zeros([self.n_samples])
+ padding_mask[:min(n_samples, self.n_samples)] = 1
+
+
+ return (
+ chunk,
+ video_chunk,
+ t_start,
+ t_end,
+ seconds_start,
+ seconds_total,
+ padding_mask
+ )
+
+class PadCrop_Video_Hiera_Normalized_T(nn.Module):
+
+ def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True):
+
+ super().__init__()
+
+ self.n_samples = n_samples
+ self.sample_rate = sample_rate
+ self.randomize = randomize
+ self.fps = fps
+ self.n_frames = int(self.fps * self.n_samples / self.sample_rate)
+
+ def __call__(self, audio: torch.Tensor, video: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]:
+
+ n_channels, n_samples = audio.shape
+ n_frames, heigh, width, channel = video.shape
+ video = torch.from_numpy(video)
+ # If the audio is shorter than the desired length, pad it
+ audio_upper_bound = max(0, n_samples - self.n_samples)
+ video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps)
+ upper_bound = min(audio_upper_bound,video_upper_bound)
+
+ # If randomize is False, always start at the beginning of the audio
+ offset = 0
+ if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames):
+ offset = random.randint(0, upper_bound)
+
+ # Calculate the start and end times of the chunk
+ t_start = offset / (upper_bound + self.n_samples)
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
+ frame_offset = int(self.fps * offset / self.sample_rate)
+ # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate)
+ # Create the chunk
+ chunk = audio.new_zeros([n_channels, self.n_samples])
+ video_chunk = video.new_zeros([self.n_frames, heigh, width, channel])
+ # Copy the audio into the chunk
+ chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples]
+ video_chunk[:min(n_frames, self.n_frames)] = video[frame_offset:frame_offset + self.n_frames]
+ # video_chunk = video_chunk[None].permute(0, 4, 1, 2, 3).contiguous()
+ # print(video_chunk.shape)
+ # video_chunk = F.interpolate(
+ # video_chunk[0],
+ # size=(224, 224, 3), # 输出的空间尺寸
+ # scale_factor=(target_frames / video_tensor.shape[1], 1, 1), # 时间轴的缩放因子
+ # mode='trilinear', # 使用三线性插值
+ # align_corners=False
+ # )
+
+ # video_chunk = F.interpolate(video_chunk, size=(64, 224, 224), mode="trilinear")[0]
+ # video_chunk = video_chunk.view(3,4,16,224,224).transpose(0,1)
+ # Calculate the start and end times of the chunk in seconds
+ seconds_start = math.floor(offset / self.sample_rate)
+ seconds_total = math.ceil(n_samples / self.sample_rate)
+
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
+ padding_mask = torch.zeros([self.n_samples])
+ padding_mask[:min(n_samples, self.n_samples)] = 1
+
+
+ return (
+ chunk,
+ video_chunk,
+ t_start,
+ t_end,
+ seconds_start,
+ seconds_total,
+ padding_mask
+ )
+
+class PadCrop_DualVideo_Normalized_T(nn.Module):
+
+ def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True):
+
+ super().__init__()
+
+ self.n_samples = n_samples
+ self.sample_rate = sample_rate
+ self.randomize = randomize
+ self.fps = fps
+ self.n_frames = int(self.fps * self.n_samples / self.sample_rate)
+
+ def __call__(self, audio: torch.Tensor, video_360: torch.Tensor, video_fov: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]:
+ n_channels, n_samples = audio.shape
+ # print(video.shape)
+ n_frames, dim = video_360.shape
+ video_360 = torch.from_numpy(video_360)
+ video_fov = torch.from_numpy(video_fov)
+ # If the audio is shorter than the desired length, pad it
+ audio_upper_bound = max(0, n_samples - self.n_samples)
+ video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps)
+ upper_bound = min(audio_upper_bound,video_upper_bound)
+
+ # If randomize is False, always start at the beginning of the audio
+ offset = 0
+ if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames):
+ offset = random.randint(0, upper_bound)
+
+ # Calculate the start and end times of the chunk
+ t_start = offset / (upper_bound + self.n_samples)
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
+ frame_offset = int(self.fps * offset / self.sample_rate)
+ # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate)
+ # Create the chunk
+ chunk = audio.new_zeros([n_channels, self.n_samples])
+ video_360_chunk = video_360.new_zeros([self.n_frames, video_360.shape[1]])
+ video_fov_chunk = video_fov.new_zeros([self.n_frames, video_fov.shape[1]])
+ # Copy the audio into the chunk
+ chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples]
+ video_360_chunk[:min(n_frames, self.n_frames)] = video_360[frame_offset:frame_offset + self.n_frames,:]
+ video_fov_chunk[:min(n_frames, self.n_frames)] = video_fov[frame_offset:frame_offset + self.n_frames,:]
+ # Calculate the start and end times of the chunk in seconds
+ seconds_start = math.floor(offset / self.sample_rate)
+ seconds_total = math.ceil(n_samples / self.sample_rate)
+
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
+ padding_mask = torch.zeros([self.n_samples])
+ padding_mask[:min(n_samples, self.n_samples)] = 1
+
+
+ return (
+ chunk,
+ video_360_chunk,
+ video_fov_chunk,
+ t_start,
+ t_end,
+ seconds_start,
+ seconds_total,
+ padding_mask
+ )
+
+class PhaseFlipper(nn.Module):
+ "Randomly invert the phase of a signal"
+ def __init__(self, p=0.5):
+ super().__init__()
+ self.p = p
+ def __call__(self, signal):
+ return -signal if (random.random() < self.p) else signal
+
+class Mono(nn.Module):
+ def __call__(self, signal):
+ return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal
+
+class Stereo(nn.Module):
+ def __call__(self, signal):
+ signal_shape = signal.shape
+ # Check if it's mono
+ if len(signal_shape) == 1: # s -> 2, s
+ signal = signal.unsqueeze(0).repeat(2, 1)
+ elif len(signal_shape) == 2:
+ if signal_shape[0] == 1: #1, s -> 2, s
+ signal = signal.repeat(2, 1)
+ elif signal_shape[0] > 2: #?, s -> 2,s
+ signal = signal[:2, :]
+
+ return signal
+
+class FOA(nn.Module):
+ def __call__(self, signal):
+ signal_shape = signal.shape
+ # Check if it's mono
+ if len(signal_shape) == 1: # s -> (4, s)
+ foa = torch.zeros(4, signal_shape[0], device=signal.device) # 与输入信号一致的设备类型
+ foa[0, :] = signal # W通道: 全方位声源
+ foa[1, :] = 0 # X通道
+ foa[2, :] = 0 # Y通道
+ foa[3, :] = 0 # Z通道
+ elif len(signal_shape) == 2:
+ foa = torch.zeros(4, signal_shape[1], device=signal.device) # 与输入信号一致的设备类型
+ if signal_shape[0] == 1: # (1, s) -> (4, s)
+ foa[0, :] = signal[0] # W通道: 全方位声源
+ foa[1, :] = 0 # X通道
+ foa[2, :] = 0 # Y通道
+ foa[3, :] = 0 # Z通道
+ elif signal_shape[0] == 2: # (2, s) -> (4, s)
+ left = signal[0]
+ right = signal[1]
+ # 将立体声信号映射到FOA信号通道
+ foa[0, :] = (left + right) / np.sqrt(2) # W通道: 全方位声源
+ foa[1, :] = (left - right) / np.sqrt(2) # X通道: 前后方向
+ foa[2, :] = 0 # Y通道: 左右方向,简单实现先置零
+ foa[3, :] = 0 # Z通道: 垂直方向,这里置零
+ else:
+ foa = signal
+
+ else:
+ raise ValueError(f"Unsupported signal shape: {signal_shape}")
+
+ assert foa.shape[0] == 4, f'inputs not FOA format'
+
+ return foa
\ No newline at end of file
diff --git a/ThinkSound/inference/__init__.py b/ThinkSound/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ThinkSound/inference/generation.py b/ThinkSound/inference/generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d6873fe9d19deba0497e5fb4d939472552a261c
--- /dev/null
+++ b/ThinkSound/inference/generation.py
@@ -0,0 +1,274 @@
+import numpy as np
+import torch
+import typing as tp
+import math
+from torchaudio import transforms as T
+
+from .utils import prepare_audio
+from .sampling import sample, sample_k, sample_rf
+from ..data.utils import PadCrop
+
+def generate_diffusion_uncond(
+ model,
+ steps: int = 250,
+ batch_size: int = 1,
+ sample_size: int = 2097152,
+ seed: int = -1,
+ device: str = "cuda",
+ init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
+ init_noise_level: float = 1.0,
+ return_latents = False,
+ **sampler_kwargs
+ ) -> torch.Tensor:
+
+ # The length of the output in audio samples
+ audio_sample_size = sample_size
+
+ # If this is latent diffusion, change sample_size instead to the downsampled latent size
+ if model.pretransform is not None:
+ sample_size = sample_size // model.pretransform.downsampling_ratio
+
+ # Seed
+ # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
+ print(seed)
+ torch.manual_seed(seed)
+ # Define the initial noise immediately after setting the seed
+ noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
+
+ if init_audio is not None:
+ # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
+ in_sr, init_audio = init_audio
+
+ io_channels = model.io_channels
+
+ # For latent models, set the io_channels to the autoencoder's io_channels
+ if model.pretransform is not None:
+ io_channels = model.pretransform.io_channels
+
+ # Prepare the initial audio for use by the model
+ init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
+
+ # For latent models, encode the initial audio into latents
+ if model.pretransform is not None:
+ init_audio = model.pretransform.encode(init_audio)
+
+ init_audio = init_audio.repeat(batch_size, 1, 1)
+ else:
+ # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
+ init_audio = None
+ init_noise_level = None
+
+ # Inpainting mask
+
+ if init_audio is not None:
+ # variations
+ sampler_kwargs["sigma_max"] = init_noise_level
+ mask = None
+ else:
+ mask = None
+
+ # Now the generative AI part:
+
+ diff_objective = model.diffusion_objective
+
+ if diff_objective == "v":
+ # k-diffusion denoising process go!
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device)
+ elif diff_objective == "rectified_flow":
+ sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device)
+
+ # Denoising process done.
+ # If this is latent diffusion, decode latents back into audio
+ if model.pretransform is not None and not return_latents:
+ sampled = model.pretransform.decode(sampled)
+
+ # Return audio
+ return sampled
+
+
+def generate_diffusion_cond(
+ model,
+ steps: int = 250,
+ cfg_scale=6,
+ conditioning: dict = None,
+ conditioning_tensors: tp.Optional[dict] = None,
+ negative_conditioning: dict = None,
+ negative_conditioning_tensors: tp.Optional[dict] = None,
+ batch_size: int = 1,
+ sample_size: int = 2097152,
+ sample_rate: int = 48000,
+ seed: int = -1,
+ device: str = "cuda",
+ init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
+ init_noise_level: float = 1.0,
+ mask_args: dict = None,
+ return_latents = False,
+ **sampler_kwargs
+ ) -> torch.Tensor:
+ """
+ Generate audio from a prompt using a diffusion model.
+
+ Args:
+ model: The diffusion model to use for generation.
+ steps: The number of diffusion steps to use.
+ cfg_scale: Classifier-free guidance scale
+ conditioning: A dictionary of conditioning parameters to use for generation.
+ conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation.
+ batch_size: The batch size to use for generation.
+ sample_size: The length of the audio to generate, in samples.
+ sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly)
+ seed: The random seed to use for generation, or -1 to use a random seed.
+ device: The device to use for generation.
+ init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation.
+ init_noise_level: The noise level to use when generating from an initial audio sample.
+ return_latents: Whether to return the latents used for generation instead of the decoded audio.
+ **sampler_kwargs: Additional keyword arguments to pass to the sampler.
+ """
+
+ # The length of the output in audio samples
+ audio_sample_size = sample_size
+
+ # If this is latent diffusion, change sample_size instead to the downsampled latent size
+ if model.pretransform is not None:
+ sample_size = sample_size // model.pretransform.downsampling_ratio
+
+ # Seed
+ # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
+ print(seed)
+ torch.manual_seed(seed)
+ # Define the initial noise immediately after setting the seed
+ noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
+
+ torch.backends.cuda.matmul.allow_tf32 = False
+ torch.backends.cudnn.allow_tf32 = False
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+ torch.backends.cudnn.benchmark = False
+ import ipdb
+ # ipdb.set_trace()
+ # Conditioning
+ assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors"
+ if conditioning_tensors is None:
+ conditioning_tensors = model.conditioner(conditioning, device)
+ conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors)
+
+ if negative_conditioning is not None or negative_conditioning_tensors is not None:
+
+ if negative_conditioning_tensors is None:
+ negative_conditioning_tensors = model.conditioner(negative_conditioning, device)
+
+ negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True)
+ else:
+ negative_conditioning_tensors = {}
+
+ if init_audio is not None:
+ # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
+ in_sr, init_audio = init_audio
+
+ io_channels = model.io_channels
+
+ # For latent models, set the io_channels to the autoencoder's io_channels
+ if model.pretransform is not None:
+ io_channels = model.pretransform.io_channels
+
+ # Prepare the initial audio for use by the model
+ init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
+
+ # For latent models, encode the initial audio into latents
+ if model.pretransform is not None:
+ init_audio = model.pretransform.encode(init_audio)
+
+ init_audio = init_audio.repeat(batch_size, 1, 1)
+ else:
+ # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
+ init_audio = None
+ init_noise_level = None
+ mask_args = None
+
+ # Inpainting mask
+ if init_audio is not None and mask_args is not None:
+ # Cut and paste init_audio according to cropfrom, pastefrom, pasteto
+ # This is helpful for forward and reverse outpainting
+ cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size)
+ pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size)
+ pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size)
+ assert pastefrom < pasteto, "Paste From should be less than Paste To"
+ croplen = pasteto - pastefrom
+ if cropfrom + croplen > sample_size:
+ croplen = sample_size - cropfrom
+ cropto = cropfrom + croplen
+ pasteto = pastefrom + croplen
+ cutpaste = init_audio.new_zeros(init_audio.shape)
+ cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto]
+ #print(cropfrom, cropto, pastefrom, pasteto)
+ init_audio = cutpaste
+ # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args
+ mask = build_mask(sample_size, mask_args)
+ mask = mask.to(device)
+ elif init_audio is not None and mask_args is None:
+ # variations
+ sampler_kwargs["sigma_max"] = init_noise_level
+ mask = None
+ else:
+ mask = None
+
+ model_dtype = next(model.model.parameters()).dtype
+ noise = noise.type(model_dtype)
+ conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()}
+ # Now the generative AI part:
+ # k-diffusion denoising process go!
+ diff_objective = model.diffusion_objective
+ if diff_objective == "v":
+ # k-diffusion denoising process go!
+ # sampled = sample(model.model, noise, steps, 0, **conditioning_inputs)
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
+ elif diff_objective == "rectified_flow":
+
+ if "sigma_min" in sampler_kwargs:
+ del sampler_kwargs["sigma_min"]
+
+ if "sampler_type" in sampler_kwargs:
+ del sampler_kwargs["sampler_type"]
+
+ sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
+
+ # v-diffusion:
+ #sampled = sample(model.model, noise, steps, 0, **conditioning_tensors, embedding_scale=cfg_scale)
+ del noise
+ del conditioning_tensors
+ del conditioning_inputs
+ torch.cuda.empty_cache()
+ # Denoising process done.
+ # If this is latent diffusion, decode latents back into audio
+ if model.pretransform is not None and not return_latents:
+ #cast sampled latents to pretransform dtype
+ sampled = sampled.to(next(model.pretransform.parameters()).dtype)
+ sampled = model.pretransform.decode(sampled)
+
+ # Return audio
+ return sampled
+
+# builds a softmask given the parameters
+# returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio,
+# and anything between is a mixture of old/new
+# ideally 0.5 is half/half mixture but i haven't figured this out yet
+def build_mask(sample_size, mask_args):
+ maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size)
+ maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size)
+ softnessL = round(mask_args["softnessL"]/100.0 * sample_size)
+ softnessR = round(mask_args["softnessR"]/100.0 * sample_size)
+ marination = mask_args["marination"]
+ # use hann windows for softening the transition (i don't know if this is correct)
+ hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL]
+ hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:]
+ # build the mask.
+ mask = torch.zeros((sample_size))
+ mask[maskstart:maskend] = 1
+ mask[maskstart:maskstart+softnessL] = hannL
+ mask[maskend-softnessR:maskend] = hannR
+ # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds
+ if marination > 0:
+ mask = mask * (1-marination)
+ #print(mask)
+ return mask
diff --git a/ThinkSound/inference/sampling.py b/ThinkSound/inference/sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..3877d7252bcfd08f19bd59df80d48a1528b4e187
--- /dev/null
+++ b/ThinkSound/inference/sampling.py
@@ -0,0 +1,286 @@
+import torch
+import math
+from tqdm import trange, tqdm
+import torch.distributions as dist
+
+import k_diffusion as K
+
+# Define the noise schedule and sampling loop
+def get_alphas_sigmas(t):
+ """Returns the scaling factors for the clean image (alpha) and for the
+ noise (sigma), given a timestep."""
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
+
+def alpha_sigma_to_t(alpha, sigma):
+ """Returns a timestep, given the scaling factors for the clean image and for
+ the noise."""
+ return torch.atan2(sigma, alpha) / math.pi * 2
+
+def t_to_alpha_sigma(t):
+ """Returns the scaling factors for the clean image and for the noise, given
+ a timestep."""
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
+
+def sample_timesteps_logsnr(batch_size, mean_logsnr=-1.2, std_logsnr=2.0):
+ """
+ Sample timesteps for diffusion training by sampling logSNR values and converting to t.
+
+ Args:
+ batch_size (int): Number of timesteps to sample
+ mean_logsnr (float): Mean of the logSNR Gaussian distribution
+ std_logsnr (float): Standard deviation of the logSNR Gaussian distribution
+
+ Returns:
+ torch.Tensor: Tensor of shape (batch_size,) containing timestep values t in [0, 1]
+ """
+ # Sample logSNR from Gaussian distribution
+ logsnr = torch.randn(batch_size) * std_logsnr + mean_logsnr
+
+ # Convert logSNR to timesteps using the logistic function
+ # Since logSNR = ln((1-t)/t), we can solve for t:
+ # t = 1 / (1 + exp(logsnr))
+ t = torch.sigmoid(-logsnr)
+
+ # Clamp values to ensure numerical stability
+ t = t.clamp(1e-4, 1 - 1e-4)
+
+ return t
+def truncated_logistic_normal_rescaled(shape, left_trunc=0.075, right_trunc=1):
+ """
+
+ shape: shape of the output tensor
+ left_trunc: left truncation point, fraction of probability to be discarded
+ right_trunc: right truncation boundary, should be 1 (never seen at test time)
+ """
+
+ # Step 1: Sample from the logistic normal distribution (sigmoid of normal)
+ logits = torch.randn(shape)
+
+ # Step 2: Apply the CDF transformation of the normal distribution
+ normal_dist = dist.Normal(0, 1)
+ cdf_values = normal_dist.cdf(logits)
+
+ # Step 3: Define the truncation bounds on the CDF
+ lower_bound = normal_dist.cdf(torch.logit(torch.tensor(left_trunc)))
+ upper_bound = normal_dist.cdf(torch.logit(torch.tensor(right_trunc)))
+
+ # Step 4: Rescale linear CDF values into the truncated region (between lower_bound and upper_bound)
+ truncated_cdf_values = lower_bound + (upper_bound - lower_bound) * cdf_values
+
+ # Step 5: Map back to logistic-normal space using inverse CDF
+ truncated_samples = torch.sigmoid(normal_dist.icdf(truncated_cdf_values))
+
+ # Step 6: Rescale values so that min is 0 and max is just below 1
+ rescaled_samples = (truncated_samples - left_trunc) / (right_trunc - left_trunc)
+
+ return rescaled_samples
+
+@torch.no_grad()
+def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
+ """Draws samples from a model given starting noise. Euler method"""
+
+ # Make tensor of ones to broadcast the single t values
+ ts = x.new_ones([x.shape[0]])
+
+ # Create the noise schedule
+ t = torch.linspace(sigma_max, 0, steps + 1)
+
+ #alphas, sigmas = 1-t, t
+
+ for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])):
+ # Broadcast the current timestep to the correct shape
+ t_curr_tensor = t_curr * torch.ones(
+ (x.shape[0],), dtype=x.dtype, device=x.device
+ )
+ dt = t_prev - t_curr # we solve backwards in our formulation
+ x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc)
+
+ # If we are on the last timestep, output the denoised image
+ return x
+
+@torch.no_grad()
+def sample(model, x, steps, eta, **extra_args):
+ """Draws samples from a model given starting noise. v-diffusion"""
+ ts = x.new_ones([x.shape[0]])
+
+ # Create the noise schedule
+ t = torch.linspace(1, 0, steps + 1)[:-1]
+
+ alphas, sigmas = get_alphas_sigmas(t)
+
+ # The sampling loop
+ for i in trange(steps):
+
+ # Get the model output (v, the predicted velocity)
+ with torch.cuda.amp.autocast():
+ v = model(x, ts * t[i], **extra_args).float()
+
+ # Predict the noise and the denoised image
+ pred = x * alphas[i] - v * sigmas[i]
+ eps = x * sigmas[i] + v * alphas[i]
+
+ # If we are not on the last timestep, compute the noisy image for the
+ # next timestep.
+ if i < steps - 1:
+ # If eta > 0, adjust the scaling factor for the predicted noise
+ # downward according to the amount of additional noise to add
+ ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
+ (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
+ adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
+
+ # Recombine the predicted noise and predicted denoised image in the
+ # correct proportions for the next step
+ x = pred * alphas[i + 1] + eps * adjusted_sigma
+
+ # Add the correct amount of fresh noise
+ if eta:
+ x += torch.randn_like(x) * ddim_sigma
+
+ # If we are on the last timestep, output the denoised image
+ return pred
+
+# Soft mask inpainting is just shrinking hard (binary) mask inpainting
+# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
+def get_bmask(i, steps, mask):
+ strength = (i+1)/(steps)
+ # convert to binary mask
+ bmask = torch.where(mask<=strength,1,0)
+ return bmask
+
+def make_cond_model_fn(model, cond_fn):
+ def cond_model_fn(x, sigma, **kwargs):
+ with torch.enable_grad():
+ x = x.detach().requires_grad_()
+ denoised = model(x, sigma, **kwargs)
+ cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
+ cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
+ return cond_denoised
+ return cond_model_fn
+
+# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
+# init_data is init_audio as latents (if this is latent diffusion)
+# For sampling, set both init_data and mask to None
+# For variations, set init_data
+# For inpainting, set both init_data & mask
+def sample_k(
+ model_fn,
+ noise,
+ init_data=None,
+ mask=None,
+ steps=100,
+ sampler_type="dpmpp-2m-sde",
+ sigma_min=0.5,
+ sigma_max=50,
+ rho=1.0, device="cuda",
+ callback=None,
+ cond_fn=None,
+ **extra_args
+ ):
+
+ denoiser = K.external.VDenoiser(model_fn)
+
+ if cond_fn is not None:
+ denoiser = make_cond_model_fn(denoiser, cond_fn)
+
+ # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
+ sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
+ # Scale the initial noise by sigma
+ noise = noise * sigmas[0]
+
+ wrapped_callback = callback
+
+ if mask is None and init_data is not None:
+ # VARIATION (no inpainting)
+ # set the initial latent to the init_data, and noise it with initial sigma
+ x = init_data + noise
+ elif mask is not None and init_data is not None:
+ # INPAINTING
+ bmask = get_bmask(0, steps, mask)
+ # initial noising
+ input_noised = init_data + noise
+ # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
+ x = input_noised * bmask + noise * (1-bmask)
+ # define the inpainting callback function (Note: side effects, it mutates x)
+ # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
+ # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
+ def inpainting_callback(args):
+ i = args["i"]
+ x = args["x"]
+ sigma = args["sigma"]
+ #denoised = args["denoised"]
+ # noise the init_data input with this step's appropriate amount of noise
+ input_noised = init_data + torch.randn_like(init_data) * sigma
+ # shrinking hard mask
+ bmask = get_bmask(i, steps, mask)
+ # mix input_noise with x, using binary mask
+ new_x = input_noised * bmask + x * (1-bmask)
+ # mutate x
+ x[:,:,:] = new_x[:,:,:]
+ # wrap together the inpainting callback and the user-submitted callback.
+ if callback is None:
+ wrapped_callback = inpainting_callback
+ else:
+ wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
+ else:
+ # SAMPLING
+ # set the initial latent to noise
+ x = noise
+
+
+ with torch.cuda.amp.autocast():
+ if sampler_type == "k-heun":
+ return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
+ elif sampler_type == "k-lms":
+ return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
+ elif sampler_type == "k-dpmpp-2s-ancestral":
+ return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
+ elif sampler_type == "k-dpm-2":
+ return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
+ elif sampler_type == "k-dpm-fast":
+ return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
+ elif sampler_type == "k-dpm-adaptive":
+ return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
+ elif sampler_type == "dpmpp-2m-sde":
+ return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
+ elif sampler_type == "dpmpp-3m-sde":
+ return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
+
+# Uses discrete Euler sampling for rectified flow models
+# init_data is init_audio as latents (if this is latent diffusion)
+# For sampling, set both init_data and mask to None
+# For variations, set init_data
+# For inpainting, set both init_data & mask
+def sample_rf(
+ model_fn,
+ noise,
+ init_data=None,
+ steps=100,
+ sigma_max=1,
+ device="cuda",
+ callback=None,
+ cond_fn=None,
+ **extra_args
+ ):
+
+ if sigma_max > 1:
+ sigma_max = 1
+
+ if cond_fn is not None:
+ denoiser = make_cond_model_fn(denoiser, cond_fn)
+
+ wrapped_callback = callback
+
+ if init_data is not None:
+ # VARIATION (no inpainting)
+ # Interpolate the init data and the noise for init audio
+ x = init_data * (1 - sigma_max) + noise * sigma_max
+ else:
+ # SAMPLING
+ # set the initial latent to noise
+ x = noise
+
+ with torch.cuda.amp.autocast():
+ # TODO: Add callback support
+ #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args)
+ return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)
\ No newline at end of file
diff --git a/ThinkSound/inference/utils.py b/ThinkSound/inference/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a6c0a57609f68156ad244da9b5819666329772e
--- /dev/null
+++ b/ThinkSound/inference/utils.py
@@ -0,0 +1,35 @@
+from ..data.utils import PadCrop
+
+from torchaudio import transforms as T
+
+def set_audio_channels(audio, target_channels):
+ if target_channels == 1:
+ # Convert to mono
+ audio = audio.mean(1, keepdim=True)
+ elif target_channels == 2:
+ # Convert to stereo
+ if audio.shape[1] == 1:
+ audio = audio.repeat(1, 2, 1)
+ elif audio.shape[1] > 2:
+ audio = audio[:, :2, :]
+ return audio
+
+def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
+
+ audio = audio.to(device)
+
+ if in_sr != target_sr:
+ resample_tf = T.Resample(in_sr, target_sr).to(device)
+ audio = resample_tf(audio)
+
+ audio = PadCrop(target_length, randomize=False)(audio)
+
+ # Add batch dimension
+ if audio.dim() == 1:
+ audio = audio.unsqueeze(0).unsqueeze(0)
+ elif audio.dim() == 2:
+ audio = audio.unsqueeze(0)
+
+ audio = set_audio_channels(audio, target_channels)
+
+ return audio
\ No newline at end of file
diff --git a/ThinkSound/interface/__init__.py b/ThinkSound/interface/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ThinkSound/interface/aeiou.py b/ThinkSound/interface/aeiou.py
new file mode 100644
index 0000000000000000000000000000000000000000..8846c8968e48f9324f28642ade0dbc94a26ee99c
--- /dev/null
+++ b/ThinkSound/interface/aeiou.py
@@ -0,0 +1,278 @@
+# Modified from https://github.com/drscotthawley/aeiou/blob/main/aeiou/viz.py under Apache 2.0 License
+# License can be found in LICENSES/LICENSE_AEIOU.txt
+
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+import matplotlib.cm as cm
+from matplotlib.colors import Normalize
+from matplotlib.figure import Figure
+import numpy as np
+from PIL import Image
+
+import torch
+
+import torchaudio.transforms as T
+from einops import rearrange
+
+import numpy as np
+
+def embeddings_table(tokens):
+ from wandb import Table
+ from pandas import DataFrame
+
+ "make a table of embeddings for use with wandb"
+ features, labels = [], []
+ embeddings = rearrange(tokens, 'b d n -> b n d') # each demo sample is n vectors in d-dim space
+ for i in range(embeddings.size()[0]): # nested for's are slow but sure ;-)
+ for j in range(embeddings.size()[1]):
+ features.append(embeddings[i,j].detach().cpu().numpy())
+ labels.append([f'demo{i}']) # labels does the grouping / color for each point
+ features = np.array(features)
+ labels = np.concatenate(labels, axis=0)
+ cols = [f"dim_{i}" for i in range(features.shape[1])]
+ df = DataFrame(features, columns=cols)
+ df['LABEL'] = labels
+ return Table(columns=df.columns.to_list(), data=df.values)
+
+def project_down(tokens, # batched high-dimensional data with dims (b,d,n)
+ proj_dims=3, # dimensions to project to
+ method='pca', # projection method: 'pca'|'umap'
+ n_neighbors=10, # umap parameter for number of neighbors
+ min_dist=0.3, # umap param for minimum distance
+ debug=False, # print more info while running
+ **kwargs, # other params to pass to umap, cf. https://umap-learn.readthedocs.io/en/latest/parameters.html
+ ):
+ "this projects to lower dimenions, grabbing the first _`proj_dims`_ dimensions"
+ method = method.lower()
+ A = rearrange(tokens, 'b d n -> (b n) d') # put all the vectors into the same d-dim space
+ if A.shape[-1] > proj_dims:
+ if method=='umap':
+ from umap import UMAP
+ proj_data = UMAP(n_components=proj_dims, n_neighbors=n_neighbors, min_dist=min_dist,
+ metric='correlation', **kwargs).fit_transform(A.cpu().numpy())
+ proj_data = torch.from_numpy(proj_data).to(tokens.device)
+ else: # pca
+ (U, S, V) = torch.pca_lowrank(A)
+ proj_data = torch.matmul(A, V[:, :proj_dims]) # this is the actual PCA projection step
+ else:
+ proj_data = A
+ if debug: print("proj_data.shape =",proj_data.shape)
+ return torch.reshape(proj_data, (tokens.size()[0], -1, proj_dims)) # put it in shape [batch, n, proj_dims]
+
+
+def proj_pca(tokens, proj_dims=3):
+ return project_down(do_proj, method='pca', proj_dims=proj_dims)
+
+def point_cloud(
+ tokens, # embeddings / latent vectors. shape = (b, d, n)
+ method='pca', # projection method for 3d mapping: 'pca' | 'umap'
+ color_scheme='batch', # 'batch': group by sample; integer n: n groups, sequentially, otherwise color sequentially by time step
+ output_type='wandbobj', # plotly | points | wandbobj. NOTE: WandB can do 'plotly' directly!
+ mode='markers', # plotly scatter mode. 'lines+markers' or 'markers'
+ size=3, # size of the dots
+ line=dict(color='rgba(10,10,10,0.01)'), # if mode='lines+markers', plotly line specifier. cf. https://plotly.github.io/plotly.py-docs/generated/plotly.graph_objects.scatter3d.html#plotly.graph_objects.scatter3d.Line
+ ds_preproj=1, # EXPERIMENTAL: downsampling factor before projecting (1=no downsampling). Could screw up colors
+ ds_preplot=1, # EXPERIMENTAL: downsampling factor before plotting (1=no downsampling). Could screw up colors
+ debug=False, # print more info
+ colormap=None, # valid color map to use, None=defaults
+ darkmode=False, # dark background, white fonts
+ layout_dict=None, # extra plotly layout options such as camera orientation
+ rgb_float = False, # if True, color_scheme is RGB float values
+ **kwargs, # anything else to pass along
+ ):
+ "returns a 3D point cloud of the tokens"
+ if ds_preproj != 1:
+ tokens = tokens[torch.randperm(tokens.size()[0])] # EXPERIMENTAL: to 'correct' for possible weird effects of downsampling
+ tokens = tokens[::ds_preproj]
+ if debug: print("tokens.shape =",tokens.shape)
+
+ data = project_down(tokens, method=method, debug=debug, **kwargs).cpu().numpy()
+ if debug: print("data.shape =",data.shape)
+ if data.shape[-1] < 3: # for data less than 3D, embed it in 3D
+ data = np.pad(data, ((0,0),(0,0),(0, 3-data.shape[-1])), mode='constant', constant_values=0)
+
+ bytime = False
+ points = []
+ if color_scheme=='batch': # all dots in same batch index same color, each batch-index unique (almost)
+ ncolors = data.shape[0]
+ cmap, norm = cm.tab20, Normalize(vmin=0, vmax=ncolors)
+ elif isinstance(color_scheme, int) or color_scheme.isnumeric(): # n groups, by batch-indices, sequentially
+ ncolors = int(color_scheme)
+ cmap, norm = cm.tab20, Normalize(vmin=0, vmax=ncolors)
+ else: # time steps match up
+ bytime, ncolors = True, data.shape[1]
+ cmap, norm = cm.viridis, Normalize(vmin=0, vmax=ncolors)
+
+ cmap = cmap if colormap is None else colormap # overwrite default cmap with user choice if given
+
+ points = []
+ for bi in range(data.shape[0]): # batch index
+ if color_scheme=='batch':
+ [r, g, b, _] = [int(255*x) for x in cmap(norm(bi+1))]
+ elif isinstance(color_scheme, int) or color_scheme.isnumeric():
+ grouplen = data.shape[0]//(ncolors)
+ #if debug: print(f"bi, grouplen, bi//grouplen = ",bi, grouplen, bi//grouplen)
+ [r, g, b, _] = [int(255*x) for x in cmap(norm(bi//grouplen))]
+ #if debug: print("r,g,b = ",r,g,b)
+
+ if rgb_float: [r, g, b] = [x/255 for x in [r, g, b]]
+
+ for n in range(data.shape[1]): # across time
+ if bytime: [r, g, b, _] = [int(255*x) for x in cmap(norm(n))]
+ points.append([data[bi,n,0], data[bi,n,1], data[bi,n,2], r, g, b]) # include dot colors with point coordinates
+
+ point_cloud = np.array(points)
+
+ if output_type == 'points':
+ return point_cloud
+ elif output_type =='plotly':
+ import plotly.graph_objects as go
+
+ fig = go.Figure(data=[go.Scatter3d(
+ x=point_cloud[::ds_preplot,0], y=point_cloud[::ds_preplot,1], z=point_cloud[::ds_preplot,2],
+ marker=dict(size=size, color=point_cloud[:,3:6]),
+ mode=mode,
+ # show batch index and time index in tooltips:
+ text=[ f'bi: {i*ds_preplot}, ti: {j}' for i in range(data.shape[0]//ds_preplot) for j in range(data.shape[1]) ],
+ line=line,
+ )])
+ fig.update_layout(margin=dict(l=0, r=0, b=0, t=0)) # tight layout
+ if darkmode:
+ fig.layout.template = 'plotly_dark'
+ if isinstance(darkmode, str): # 'rgb(12,15,24)'gradio margins in dark mode
+ fig.update_layout( paper_bgcolor=darkmode)
+ if layout_dict:
+ fig.update_layout( **layout_dict )
+
+ if debug: print("point_cloud: fig made. returning")
+ return fig
+ else:
+ from wandb import Object3D
+ return Object3D(point_cloud)
+
+def pca_point_cloud(
+ tokens, # embeddings / latent vectors. shape = (b, d, n)
+ color_scheme='batch', # 'batch': group by sample, otherwise color sequentially
+ output_type='wandbobj', # plotly | points | wandbobj. NOTE: WandB can do 'plotly' directly!
+ mode='markers', # plotly scatter mode. 'lines+markers' or 'markers'
+ size=3, # size of the dots
+ line=dict(color='rgba(10,10,10,0.01)'), # if mode='lines+markers', plotly line specifier. cf. https://plotly.github.io/plotly.py-docs/generated/plotly.graph_objects.scatter3d.html#plotly.graph_objects.scatter3d.Line
+ **kwargs,
+ ):
+ return point_cloud(tokens, method='pca', color_scheme=color_scheme, output_type=output_type,
+ mode=mode, size=size, line=line, **kwargs)
+
+def power_to_db(spec, *, amin = 1e-10):
+ magnitude = np.asarray(spec)
+
+ log_spec = 10.0 * np.log10(np.maximum(amin, magnitude))
+ log_spec -= 10.0 * np.log10(np.maximum(amin, 1))
+
+ log_spec = np.maximum(log_spec, log_spec.max() - 80)
+
+ return log_spec
+
+def mel_spectrogram(waveform, power=2.0, sample_rate=48000, db=False, n_fft=1024, n_mels=128, debug=False):
+ "calculates data array for mel spectrogram (in however many channels)"
+ win_length = None
+ hop_length = n_fft//2 # 512
+
+ mel_spectrogram_op = T.MelSpectrogram(
+ sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
+ hop_length=hop_length, center=True, pad_mode="reflect", power=power,
+ norm='slaney', onesided=True, n_mels=n_mels, mel_scale="htk")
+
+ melspec = mel_spectrogram_op(waveform.float())
+ if db:
+ amp_to_db_op = T.AmplitudeToDB()
+ melspec = amp_to_db_op(melspec)
+ if debug:
+ print_stats(melspec, print=print)
+ print(f"torch.max(melspec) = {torch.max(melspec)}")
+ print(f"melspec.shape = {melspec.shape}")
+ return melspec
+
+def spectrogram_image(
+ spec,
+ title=None,
+ ylabel='freq_bin',
+ aspect='auto',
+ xmax=None,
+ db_range=[35,120],
+ justimage=False,
+ figsize=(5, 4), # size of plot (if justimage==False)
+ ):
+ "Modified from PyTorch tutorial https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html"
+ fig = Figure(figsize=figsize, dpi=100) if not justimage else Figure(figsize=(4.145, 4.145), dpi=100, tight_layout=True)
+ canvas = FigureCanvasAgg(fig)
+ axs = fig.add_subplot()
+ spec = spec.squeeze()
+ im = axs.imshow(power_to_db(spec), origin='lower', aspect=aspect, vmin=db_range[0], vmax=db_range[1])
+ if xmax:
+ axs.set_xlim((0, xmax))
+ if justimage:
+ import matplotlib.pyplot as plt
+ axs.axis('off')
+ plt.tight_layout()
+ else:
+ axs.set_ylabel(ylabel)
+ axs.set_xlabel('frame')
+ axs.set_title(title or 'Spectrogram (dB)')
+ fig.colorbar(im, ax=axs)
+ canvas.draw()
+ rgba = np.asarray(canvas.buffer_rgba())
+ im = Image.fromarray(rgba)
+ if justimage: # remove tiny white border
+ b = 15 # border size
+ im = im.crop((b,b, im.size[0]-b, im.size[1]-b))
+ #print(f"im.size = {im.size}")
+ return im
+
+def audio_spectrogram_image(waveform, power=2.0, sample_rate=48000, print=print, db=False, db_range=[35,120], justimage=False, log=False, figsize=(5, 4)):
+ "Wrapper for calling above two routines at once, does Mel scale; Modified from PyTorch tutorial https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html"
+ melspec = mel_spectrogram(waveform, power=power, db=db, sample_rate=sample_rate, debug=log)
+ melspec = melspec[0] # TODO: only left channel for now
+ return spectrogram_image(melspec, title="MelSpectrogram", ylabel='mel bins (log freq)', db_range=db_range, justimage=justimage, figsize=figsize)
+
+from matplotlib.ticker import AutoLocator
+def tokens_spectrogram_image(
+ tokens, # the embeddings themselves (in some diffusion codes these are called 'tokens')
+ aspect='auto', # aspect ratio of plot
+ title='Embeddings', # title to put on top
+ ylabel='index', # label for y axis of plot
+ cmap='coolwarm', # colormap to use. (default used to be 'viridis')
+ symmetric=True, # make color scale symmetric about zero, i.e. +/- same extremes
+ figsize=(8, 4), # matplotlib size of the figure
+ dpi=100, # dpi of figure
+ mark_batches=False, # separate batches with dividing lines
+ debug=False, # print debugging info
+ ):
+ "for visualizing embeddings in a spectrogram-like way"
+ batch_size, dim, samples = tokens.shape
+ embeddings = rearrange(tokens, 'b d n -> (b n) d') # expand batches in time
+ vmin, vmax = None, None
+ if symmetric:
+ vmax = torch.abs(embeddings).max()
+ vmin = -vmax
+
+ fig = Figure(figsize=figsize, dpi=dpi)
+ canvas = FigureCanvasAgg(fig)
+ ax = fig.add_subplot()
+ if symmetric:
+ subtitle = f'min={embeddings.min():0.4g}, max={embeddings.max():0.4g}'
+ ax.set_title(title+'\n')
+ ax.text(x=0.435, y=0.9, s=subtitle, fontsize=11, ha="center", transform=fig.transFigure)
+ else:
+ ax.set_title(title)
+ ax.set_ylabel(ylabel)
+ ax.set_xlabel('time frame (samples, in batches)')
+ if mark_batches:
+ intervals = np.arange(batch_size)*samples
+ if debug: print("intervals = ",intervals)
+ ax.vlines(intervals, -10, dim+10, color='black', linestyle='dashed', linewidth=1)
+
+ im = ax.imshow(embeddings.cpu().numpy().T, origin='lower', aspect=aspect, interpolation='none', cmap=cmap, vmin=vmin,vmax=vmax) #.T because numpy is x/y 'backwards'
+ fig.colorbar(im, ax=ax)
+ fig.tight_layout()
+ canvas.draw()
+ rgba = np.asarray(canvas.buffer_rgba())
+ return Image.fromarray(rgba)
diff --git a/ThinkSound/interface/gradio.py b/ThinkSound/interface/gradio.py
new file mode 100644
index 0000000000000000000000000000000000000000..f38468bc34b88ec6bbe5451a8b11b998430888f8
--- /dev/null
+++ b/ThinkSound/interface/gradio.py
@@ -0,0 +1,700 @@
+import gc
+import platform
+
+import numpy as np
+import gradio as gr
+import json
+import torch
+import torchaudio
+
+from aeiou.viz import audio_spectrogram_image
+from einops import rearrange
+from safetensors.torch import load_file
+from torch.nn import functional as F
+from torchaudio import transforms as T
+
+from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond
+from ..models.factory import create_model_from_config
+from ..models.pretrained import get_pretrained_model
+from ..models.utils import load_ckpt_state_dict
+from ..inference.utils import prepare_audio
+from ..training.utils import copy_state_dict
+
+model = None
+sample_rate = 32000
+sample_size = 1920000
+
+def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
+ global model, sample_rate, sample_size
+
+ if pretrained_name is not None:
+ print(f"Loading pretrained model {pretrained_name}")
+ model, model_config = get_pretrained_model(pretrained_name)
+
+ elif model_config is not None and model_ckpt_path is not None:
+ print(f"Creating model from config")
+ model = create_model_from_config(model_config)
+
+ print(f"Loading model checkpoint from {model_ckpt_path}")
+ # Load checkpoint
+ copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))
+ #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
+
+ sample_rate = model_config["sample_rate"]
+ sample_size = model_config["sample_size"]
+
+ if pretransform_ckpt_path is not None:
+ print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}")
+ model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False)
+ print(f"Done loading pretransform")
+
+ model.to(device).eval().requires_grad_(False)
+
+ if model_half:
+ model.to(torch.float16)
+
+ print(f"Done loading model")
+
+ return model, model_config
+
+def generate_cond(
+ prompt,
+ negative_prompt=None,
+ seconds_start=0,
+ seconds_total=30,
+ cfg_scale=6.0,
+ steps=250,
+ preview_every=None,
+ seed=-1,
+ sampler_type="dpmpp-3m-sde",
+ sigma_min=0.03,
+ sigma_max=1000,
+ cfg_rescale=0.0,
+ use_init=False,
+ init_audio=None,
+ init_noise_level=1.0,
+ mask_cropfrom=None,
+ mask_pastefrom=None,
+ mask_pasteto=None,
+ mask_maskstart=None,
+ mask_maskend=None,
+ mask_softnessL=None,
+ mask_softnessR=None,
+ mask_marination=None,
+ batch_size=1
+ ):
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ print(f"Prompt: {prompt}")
+
+ global preview_images
+ preview_images = []
+ if preview_every == 0:
+ preview_every = None
+
+ # Return fake stereo audio
+ conditioning = [{"prompt": prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size
+
+ if negative_prompt:
+ negative_conditioning = [{"prompt": negative_prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size
+ else:
+ negative_conditioning = None
+
+ #Get the device from the model
+ device = next(model.parameters()).device
+
+ seed = int(seed)
+
+ if not use_init:
+ init_audio = None
+
+ input_sample_size = sample_size
+
+ if init_audio is not None:
+ in_sr, init_audio = init_audio
+ # Turn into torch tensor, converting from int16 to float32
+ init_audio = torch.from_numpy(init_audio).float().div(32767)
+
+ if init_audio.dim() == 1:
+ init_audio = init_audio.unsqueeze(0) # [1, n]
+ elif init_audio.dim() == 2:
+ init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n]
+
+ if in_sr != sample_rate:
+ resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
+ init_audio = resample_tf(init_audio)
+
+ audio_length = init_audio.shape[-1]
+
+ if audio_length > sample_size:
+
+ input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
+
+ init_audio = (sample_rate, init_audio)
+
+ def progress_callback(callback_info):
+ global preview_images
+ denoised = callback_info["denoised"]
+ current_step = callback_info["i"]
+ sigma = callback_info["sigma"]
+
+ if (current_step - 1) % preview_every == 0:
+ if model.pretransform is not None:
+ denoised = model.pretransform.decode(denoised)
+ denoised = rearrange(denoised, "b d n -> d (b n)")
+ denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+ audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
+ preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})"))
+
+ # If inpainting, send mask args
+ # This will definitely change in the future
+ if mask_cropfrom is not None:
+ mask_args = {
+ "cropfrom": mask_cropfrom,
+ "pastefrom": mask_pastefrom,
+ "pasteto": mask_pasteto,
+ "maskstart": mask_maskstart,
+ "maskend": mask_maskend,
+ "softnessL": mask_softnessL,
+ "softnessR": mask_softnessR,
+ "marination": mask_marination,
+ }
+ else:
+ mask_args = None
+
+ # Do the audio generation
+ audio = generate_diffusion_cond(
+ model,
+ conditioning=conditioning,
+ negative_conditioning=negative_conditioning,
+ steps=steps,
+ cfg_scale=cfg_scale,
+ batch_size=batch_size,
+ sample_size=input_sample_size,
+ sample_rate=sample_rate,
+ seed=seed,
+ device=device,
+ sampler_type=sampler_type,
+ sigma_min=sigma_min,
+ sigma_max=sigma_max,
+ init_audio=init_audio,
+ init_noise_level=init_noise_level,
+ mask_args = mask_args,
+ callback = progress_callback if preview_every is not None else None,
+ scale_phi = cfg_rescale
+ )
+
+ # Convert to WAV file
+ audio = rearrange(audio, "b d n -> d (b n)")
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+ torchaudio.save("output.wav", audio, sample_rate)
+
+ # Let's look at a nice spectrogram too
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
+
+ return ("output.wav", [audio_spectrogram, *preview_images])
+
+def generate_uncond(
+ steps=250,
+ seed=-1,
+ sampler_type="dpmpp-3m-sde",
+ sigma_min=0.03,
+ sigma_max=1000,
+ use_init=False,
+ init_audio=None,
+ init_noise_level=1.0,
+ batch_size=1,
+ preview_every=None
+ ):
+
+ global preview_images
+
+ preview_images = []
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ #Get the device from the model
+ device = next(model.parameters()).device
+
+ seed = int(seed)
+
+ if not use_init:
+ init_audio = None
+
+ input_sample_size = sample_size
+
+ if init_audio is not None:
+ in_sr, init_audio = init_audio
+ # Turn into torch tensor, converting from int16 to float32
+ init_audio = torch.from_numpy(init_audio).float().div(32767)
+
+ if init_audio.dim() == 1:
+ init_audio = init_audio.unsqueeze(0) # [1, n]
+ elif init_audio.dim() == 2:
+ init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n]
+
+ if in_sr != sample_rate:
+ resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device)
+ init_audio = resample_tf(init_audio)
+
+ audio_length = init_audio.shape[-1]
+
+ if audio_length > sample_size:
+
+ input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length
+
+ init_audio = (sample_rate, init_audio)
+
+ def progress_callback(callback_info):
+ global preview_images
+ denoised = callback_info["denoised"]
+ current_step = callback_info["i"]
+ sigma = callback_info["sigma"]
+
+ if (current_step - 1) % preview_every == 0:
+
+ if model.pretransform is not None:
+ denoised = model.pretransform.decode(denoised)
+
+ denoised = rearrange(denoised, "b d n -> d (b n)")
+
+ denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+
+ audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate)
+
+ preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})"))
+
+ audio = generate_diffusion_uncond(
+ model,
+ steps=steps,
+ batch_size=batch_size,
+ sample_size=input_sample_size,
+ seed=seed,
+ device=device,
+ sampler_type=sampler_type,
+ sigma_min=sigma_min,
+ sigma_max=sigma_max,
+ init_audio=init_audio,
+ init_noise_level=init_noise_level,
+ callback = progress_callback if preview_every is not None else None
+ )
+
+ audio = rearrange(audio, "b d n -> d (b n)")
+
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+
+ torchaudio.save("output.wav", audio, sample_rate)
+
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
+
+ return ("output.wav", [audio_spectrogram, *preview_images])
+
+def generate_lm(
+ temperature=1.0,
+ top_p=0.95,
+ top_k=0,
+ batch_size=1,
+ ):
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ #Get the device from the model
+ device = next(model.parameters()).device
+
+ audio = model.generate_audio(
+ batch_size=batch_size,
+ max_gen_len = sample_size//model.pretransform.downsampling_ratio,
+ conditioning=None,
+ temp=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ use_cache=True
+ )
+
+ audio = rearrange(audio, "b d n -> d (b n)")
+
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+
+ torchaudio.save("output.wav", audio, sample_rate)
+
+ audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate)
+
+ return ("output.wav", [audio_spectrogram])
+
+
+def create_uncond_sampling_ui(model_config):
+ generate_button = gr.Button("Generate", variant='primary', scale=1)
+
+ with gr.Row(equal_height=False):
+ with gr.Column():
+ with gr.Row():
+ # Steps slider
+ steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
+
+ with gr.Accordion("Sampler params", open=False):
+
+ # Seed
+ seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
+
+ # Sampler params
+ with gr.Row():
+ sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde")
+ sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
+ sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max")
+
+ with gr.Accordion("Init audio", open=False):
+ init_audio_checkbox = gr.Checkbox(label="Use init audio")
+ init_audio_input = gr.Audio(label="Init audio")
+ init_noise_level_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.01, value=0.1, label="Init noise level")
+
+ with gr.Column():
+ audio_output = gr.Audio(label="Output audio", interactive=False)
+ audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
+ send_to_init_button = gr.Button("Send to init audio", scale=1)
+ send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])
+
+ generate_button.click(fn=generate_uncond,
+ inputs=[
+ steps_slider,
+ seed_textbox,
+ sampler_type_dropdown,
+ sigma_min_slider,
+ sigma_max_slider,
+ init_audio_checkbox,
+ init_audio_input,
+ init_noise_level_slider,
+ ],
+ outputs=[
+ audio_output,
+ audio_spectrogram_output
+ ],
+ api_name="generate")
+
+def create_sampling_ui(model_config, inpainting=False):
+ with gr.Row():
+ with gr.Column(scale=6):
+ prompt = gr.Textbox(show_label=False, placeholder="Prompt")
+ negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt")
+ generate_button = gr.Button("Generate", variant='primary', scale=1)
+
+ model_conditioning_config = model_config["model"].get("conditioning", None)
+
+ has_seconds_start = False
+ has_seconds_total = False
+
+ if model_conditioning_config is not None:
+ for conditioning_config in model_conditioning_config["configs"]:
+ if conditioning_config["id"] == "seconds_start":
+ has_seconds_start = True
+ if conditioning_config["id"] == "seconds_total":
+ has_seconds_total = True
+
+ with gr.Row(equal_height=False):
+ with gr.Column():
+ with gr.Row(visible = has_seconds_start or has_seconds_total):
+ # Timing controls
+ seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start)
+ seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total)
+
+ with gr.Row():
+ # Steps slider
+ steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
+
+ # Preview Every slider
+ preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every")
+
+ # CFG scale
+ cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG scale")
+
+ with gr.Accordion("Sampler params", open=False):
+
+ # Seed
+ seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1")
+
+ # Sampler params
+ with gr.Row():
+ sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde")
+ sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
+ sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max")
+ cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG rescale amount")
+
+ if inpainting:
+ # Inpainting Tab
+ with gr.Accordion("Inpainting", open=False):
+ sigma_max_slider.maximum=1000
+
+ init_audio_checkbox = gr.Checkbox(label="Do inpainting")
+ init_audio_input = gr.Audio(label="Init audio")
+ init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.1, value=80, label="Init audio noise level", visible=False) # hide this
+
+ mask_cropfrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Crop From %")
+ mask_pastefrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Paste From %")
+ mask_pasteto_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Paste To %")
+
+ mask_maskstart_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=50, label="Mask Start %")
+ mask_maskend_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Mask End %")
+ mask_softnessL_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Left Crossfade Length %")
+ mask_softnessR_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Right Crossfade Length %")
+ mask_marination_slider = gr.Slider(minimum=0.0, maximum=1, step=0.0001, value=0, label="Marination level", visible=False) # still working on the usefulness of this
+
+ inputs = [prompt,
+ negative_prompt,
+ seconds_start_slider,
+ seconds_total_slider,
+ cfg_scale_slider,
+ steps_slider,
+ preview_every_slider,
+ seed_textbox,
+ sampler_type_dropdown,
+ sigma_min_slider,
+ sigma_max_slider,
+ cfg_rescale_slider,
+ init_audio_checkbox,
+ init_audio_input,
+ init_noise_level_slider,
+ mask_cropfrom_slider,
+ mask_pastefrom_slider,
+ mask_pasteto_slider,
+ mask_maskstart_slider,
+ mask_maskend_slider,
+ mask_softnessL_slider,
+ mask_softnessR_slider,
+ mask_marination_slider
+ ]
+ else:
+ # Default generation tab
+ with gr.Accordion("Init audio", open=False):
+ init_audio_checkbox = gr.Checkbox(label="Use init audio")
+ init_audio_input = gr.Audio(label="Init audio")
+ init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init noise level")
+
+ inputs = [prompt,
+ negative_prompt,
+ seconds_start_slider,
+ seconds_total_slider,
+ cfg_scale_slider,
+ steps_slider,
+ preview_every_slider,
+ seed_textbox,
+ sampler_type_dropdown,
+ sigma_min_slider,
+ sigma_max_slider,
+ cfg_rescale_slider,
+ init_audio_checkbox,
+ init_audio_input,
+ init_noise_level_slider
+ ]
+
+ with gr.Column():
+ audio_output = gr.Audio(label="Output audio", interactive=False)
+ audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
+ send_to_init_button = gr.Button("Send to init audio", scale=1)
+ send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input])
+
+ generate_button.click(fn=generate_cond,
+ inputs=inputs,
+ outputs=[
+ audio_output,
+ audio_spectrogram_output
+ ],
+ api_name="generate")
+
+
+def create_txt2audio_ui(model_config):
+ with gr.Blocks() as ui:
+ with gr.Tab("Generation"):
+ create_sampling_ui(model_config)
+ with gr.Tab("Inpainting"):
+ create_sampling_ui(model_config, inpainting=True)
+ return ui
+
+def create_diffusion_uncond_ui(model_config):
+ with gr.Blocks() as ui:
+ create_uncond_sampling_ui(model_config)
+
+ return ui
+
+def autoencoder_process(audio, latent_noise, n_quantizers):
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ #Get the device from the model
+ device = next(model.parameters()).device
+
+ in_sr, audio = audio
+
+ audio = torch.from_numpy(audio).float().div(32767).to(device)
+
+ if audio.dim() == 1:
+ audio = audio.unsqueeze(0)
+ else:
+ audio = audio.transpose(0, 1)
+
+ audio = model.preprocess_audio_for_encoder(audio, in_sr)
+ # Note: If you need to do chunked encoding, to reduce VRAM,
+ # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128
+ # To turn it off, do chunked=False
+ # Optimal overlap and chunk_size values will depend on the model.
+ # See encode_audio & decode_audio in autoencoders.py for more info
+ # Get dtype of model
+ dtype = next(model.parameters()).dtype
+
+ audio = audio.to(dtype)
+
+ if n_quantizers > 0:
+ latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers)
+ else:
+ latents = model.encode_audio(audio, chunked=False)
+
+ if latent_noise > 0:
+ latents = latents + torch.randn_like(latents) * latent_noise
+
+ audio = model.decode_audio(latents, chunked=False)
+
+ audio = rearrange(audio, "b d n -> d (b n)")
+
+ audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+
+ torchaudio.save("output.wav", audio, sample_rate)
+
+ return "output.wav"
+
+def create_autoencoder_ui(model_config):
+
+ is_dac_rvq = "model" in model_config and "bottleneck" in model_config["model"] and model_config["model"]["bottleneck"]["type"] in ["dac_rvq","dac_rvq_vae"]
+
+ if is_dac_rvq:
+ n_quantizers = model_config["model"]["bottleneck"]["config"]["n_codebooks"]
+ else:
+ n_quantizers = 0
+
+ with gr.Blocks() as ui:
+ input_audio = gr.Audio(label="Input audio")
+ output_audio = gr.Audio(label="Output audio", interactive=False)
+ n_quantizers_slider = gr.Slider(minimum=1, maximum=n_quantizers, step=1, value=n_quantizers, label="# quantizers", visible=is_dac_rvq)
+ latent_noise_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.001, value=0.0, label="Add latent noise")
+ process_button = gr.Button("Process", variant='primary', scale=1)
+ process_button.click(fn=autoencoder_process, inputs=[input_audio, latent_noise_slider, n_quantizers_slider], outputs=output_audio, api_name="process")
+
+ return ui
+
+def diffusion_prior_process(audio, steps, sampler_type, sigma_min, sigma_max):
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ #Get the device from the model
+ device = next(model.parameters()).device
+
+ in_sr, audio = audio
+
+ audio = torch.from_numpy(audio).float().div(32767).to(device)
+
+ if audio.dim() == 1:
+ audio = audio.unsqueeze(0) # [1, n]
+ elif audio.dim() == 2:
+ audio = audio.transpose(0, 1) # [n, 2] -> [2, n]
+
+ audio = audio.unsqueeze(0)
+
+ audio = model.stereoize(audio, in_sr, steps, sampler_kwargs={"sampler_type": sampler_type, "sigma_min": sigma_min, "sigma_max": sigma_max})
+
+ audio = rearrange(audio, "b d n -> d (b n)")
+
+ audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+
+ torchaudio.save("output.wav", audio, sample_rate)
+
+ return "output.wav"
+
+def create_diffusion_prior_ui(model_config):
+ with gr.Blocks() as ui:
+ input_audio = gr.Audio(label="Input audio")
+ output_audio = gr.Audio(label="Output audio", interactive=False)
+ # Sampler params
+ with gr.Row():
+ steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps")
+ sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde")
+ sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min")
+ sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max")
+ process_button = gr.Button("Process", variant='primary', scale=1)
+ process_button.click(fn=diffusion_prior_process, inputs=[input_audio, steps_slider, sampler_type_dropdown, sigma_min_slider, sigma_max_slider], outputs=output_audio, api_name="process")
+
+ return ui
+
+def create_lm_ui(model_config):
+ with gr.Blocks() as ui:
+ output_audio = gr.Audio(label="Output audio", interactive=False)
+ audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False)
+
+ # Sampling params
+ with gr.Row():
+ temperature_slider = gr.Slider(minimum=0, maximum=5, step=0.01, value=1.0, label="Temperature")
+ top_p_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top p")
+ top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Top k")
+
+ generate_button = gr.Button("Generate", variant='primary', scale=1)
+ generate_button.click(
+ fn=generate_lm,
+ inputs=[
+ temperature_slider,
+ top_p_slider,
+ top_k_slider
+ ],
+ outputs=[output_audio, audio_spectrogram_output],
+ api_name="generate"
+ )
+
+ return ui
+
+def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False):
+
+ assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both"
+
+ if model_config_path is not None:
+ # Load config from json file
+ with open(model_config_path) as f:
+ model_config = json.load(f)
+ else:
+ model_config = None
+
+ try:
+ has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available()
+ except Exception:
+ # In case this version of Torch doesn't even have `torch.backends.mps`...
+ has_mps = False
+
+ if has_mps:
+ device = torch.device("mps")
+ elif torch.cuda.is_available():
+ device = torch.device("cuda")
+ else:
+ device = torch.device("cpu")
+
+ print("Using device:", device)
+
+ _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device)
+
+ model_type = model_config["model_type"]
+
+ if model_type == "diffusion_cond":
+ ui = create_txt2audio_ui(model_config)
+ elif model_type == "diffusion_uncond":
+ ui = create_diffusion_uncond_ui(model_config)
+ elif model_type == "autoencoder" or model_type == "diffusion_autoencoder":
+ ui = create_autoencoder_ui(model_config)
+ elif model_type == "diffusion_prior":
+ ui = create_diffusion_prior_ui(model_config)
+ elif model_type == "lm":
+ ui = create_lm_ui(model_config)
+
+ return ui
diff --git a/ThinkSound/models/__init__.py b/ThinkSound/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e27bbcb19a00a93e05ed6cf2a3a38895f26975d
--- /dev/null
+++ b/ThinkSound/models/__init__.py
@@ -0,0 +1 @@
+from .factory import create_model_from_config, create_model_from_config_path
\ No newline at end of file
diff --git a/ThinkSound/models/adp.py b/ThinkSound/models/adp.py
new file mode 100644
index 0000000000000000000000000000000000000000..49eb526ab02d16eb4952d346401b1ad2b7e5cb7c
--- /dev/null
+++ b/ThinkSound/models/adp.py
@@ -0,0 +1,1588 @@
+# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
+# License can be found in LICENSES/LICENSE_ADP.txt
+
+import math
+from inspect import isfunction
+from math import ceil, floor, log, pi, log2
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+from packaging import version
+
+import torch
+import torch.nn as nn
+from einops import rearrange, reduce, repeat
+from einops.layers.torch import Rearrange
+from einops_exts import rearrange_many
+from torch import Tensor, einsum
+from torch.backends.cuda import sdp_kernel
+from torch.nn import functional as F
+from dac.nn.layers import Snake1d
+
+"""
+Utils
+"""
+
+
+class ConditionedSequential(nn.Module):
+ def __init__(self, *modules):
+ super().__init__()
+ self.module_list = nn.ModuleList(*modules)
+
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
+ for module in self.module_list:
+ x = module(x, mapping)
+ return x
+
+T = TypeVar("T")
+
+def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+def exists(val: Optional[T]) -> T:
+ return val is not None
+
+def closest_power_2(x: float) -> int:
+ exponent = log2(x)
+ distance_fn = lambda z: abs(x - 2 ** z) # noqa
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
+ return 2 ** int(exponent_closest)
+
+def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
+ for key in d.keys():
+ no_prefix = int(not key.startswith(prefix))
+ return_dicts[no_prefix][key] = d[key]
+ return return_dicts
+
+def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
+ if keep_prefix:
+ return kwargs_with_prefix, kwargs
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
+ return kwargs_no_prefix, kwargs
+
+"""
+Convolutional Blocks
+"""
+import typing as tp
+
+# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License
+# License available in LICENSES/LICENSE_META.txt
+
+def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
+ padding_total: int = 0) -> int:
+ """See `pad_for_conv1d`."""
+ length = x.shape[-1]
+ n_frames = (length - kernel_size + padding_total) / stride + 1
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+ return ideal_length - length
+
+
+def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+ """Pad for a convolution to make sure that the last window is full.
+ Extra padding is added at the end. This is required to ensure that we can rebuild
+ an output of the same length, as otherwise, even with padding, some time steps
+ might get removed.
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
+ 1 2 3 4 # once you removed padding, we are missing one time step !
+ """
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+ return F.pad(x, (0, extra_padding))
+
+
+def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
+ """
+ length = x.shape[-1]
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ if mode == 'reflect':
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ x = F.pad(x, (0, extra_pad))
+ padded = F.pad(x, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+ else:
+ return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ assert (padding_left + padding_right) <= x.shape[-1]
+ end = x.shape[-1] - padding_right
+ return x[..., padding_left: end]
+
+
+class Conv1d(nn.Conv1d):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x: Tensor, causal=False) -> Tensor:
+ kernel_size = self.kernel_size[0]
+ stride = self.stride[0]
+ dilation = self.dilation[0]
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
+ padding_total = kernel_size - stride
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+ if causal:
+ # Left padding for causal
+ x = pad1d(x, (padding_total, extra_padding))
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ x = pad1d(x, (padding_left, padding_right + extra_padding))
+ return super().forward(x)
+
+class ConvTranspose1d(nn.ConvTranspose1d):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x: Tensor, causal=False) -> Tensor:
+ kernel_size = self.kernel_size[0]
+ stride = self.stride[0]
+ padding_total = kernel_size - stride
+
+ y = super().forward(x)
+
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+ # removed at the very end, when keeping only the right length for the output,
+ # as removing it here would require also passing the length at the matching layer
+ # in the encoder.
+ if causal:
+ padding_right = ceil(padding_total)
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ return y
+
+
+def Downsample1d(
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
+) -> nn.Module:
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
+
+ return Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=factor * kernel_multiplier + 1,
+ stride=factor
+ )
+
+
+def Upsample1d(
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
+) -> nn.Module:
+
+ if factor == 1:
+ return Conv1d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3
+ )
+
+ if use_nearest:
+ return nn.Sequential(
+ nn.Upsample(scale_factor=factor, mode="nearest"),
+ Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3
+ ),
+ )
+ else:
+ return ConvTranspose1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=factor * 2,
+ stride=factor
+ )
+
+
+class ConvBlock1d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ *,
+ kernel_size: int = 3,
+ stride: int = 1,
+ dilation: int = 1,
+ num_groups: int = 8,
+ use_norm: bool = True,
+ use_snake: bool = False
+ ) -> None:
+ super().__init__()
+
+ self.groupnorm = (
+ nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
+ if use_norm
+ else nn.Identity()
+ )
+
+ if use_snake:
+ self.activation = Snake1d(in_channels)
+ else:
+ self.activation = nn.SiLU()
+
+ self.project = Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ )
+
+ def forward(
+ self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
+ ) -> Tensor:
+ x = self.groupnorm(x)
+ if exists(scale_shift):
+ scale, shift = scale_shift
+ x = x * (scale + 1) + shift
+ x = self.activation(x)
+ return self.project(x, causal=causal)
+
+
+class MappingToScaleShift(nn.Module):
+ def __init__(
+ self,
+ features: int,
+ channels: int,
+ ):
+ super().__init__()
+
+ self.to_scale_shift = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(in_features=features, out_features=channels * 2),
+ )
+
+ def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
+ scale_shift = self.to_scale_shift(mapping)
+ scale_shift = rearrange(scale_shift, "b c -> b c 1")
+ scale, shift = scale_shift.chunk(2, dim=1)
+ return scale, shift
+
+
+class ResnetBlock1d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ *,
+ kernel_size: int = 3,
+ stride: int = 1,
+ dilation: int = 1,
+ use_norm: bool = True,
+ use_snake: bool = False,
+ num_groups: int = 8,
+ context_mapping_features: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+
+ self.use_mapping = exists(context_mapping_features)
+
+ self.block1 = ConvBlock1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ use_norm=use_norm,
+ num_groups=num_groups,
+ use_snake=use_snake
+ )
+
+ if self.use_mapping:
+ assert exists(context_mapping_features)
+ self.to_scale_shift = MappingToScaleShift(
+ features=context_mapping_features, channels=out_channels
+ )
+
+ self.block2 = ConvBlock1d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ use_norm=use_norm,
+ num_groups=num_groups,
+ use_snake=use_snake
+ )
+
+ self.to_out = (
+ Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
+ if in_channels != out_channels
+ else nn.Identity()
+ )
+
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
+ assert_message = "context mapping required if context_mapping_features > 0"
+ assert not (self.use_mapping ^ exists(mapping)), assert_message
+
+ h = self.block1(x, causal=causal)
+
+ scale_shift = None
+ if self.use_mapping:
+ scale_shift = self.to_scale_shift(mapping)
+
+ h = self.block2(h, scale_shift=scale_shift, causal=causal)
+
+ return h + self.to_out(x)
+
+
+class Patcher(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ patch_size: int,
+ context_mapping_features: Optional[int] = None,
+ use_snake: bool = False,
+ ):
+ super().__init__()
+ assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
+ assert out_channels % patch_size == 0, assert_message
+ self.patch_size = patch_size
+
+ self.block = ResnetBlock1d(
+ in_channels=in_channels,
+ out_channels=out_channels // patch_size,
+ num_groups=1,
+ context_mapping_features=context_mapping_features,
+ use_snake=use_snake
+ )
+
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
+ x = self.block(x, mapping, causal=causal)
+ x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
+ return x
+
+
+class Unpatcher(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ patch_size: int,
+ context_mapping_features: Optional[int] = None,
+ use_snake: bool = False
+ ):
+ super().__init__()
+ assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
+ assert in_channels % patch_size == 0, assert_message
+ self.patch_size = patch_size
+
+ self.block = ResnetBlock1d(
+ in_channels=in_channels // patch_size,
+ out_channels=out_channels,
+ num_groups=1,
+ context_mapping_features=context_mapping_features,
+ use_snake=use_snake
+ )
+
+ def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
+ x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
+ x = self.block(x, mapping, causal=causal)
+ return x
+
+
+"""
+Attention Components
+"""
+def FeedForward(features: int, multiplier: int) -> nn.Module:
+ mid_features = features * multiplier
+ return nn.Sequential(
+ nn.Linear(in_features=features, out_features=mid_features),
+ nn.GELU(),
+ nn.Linear(in_features=mid_features, out_features=features),
+ )
+
+def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
+ b, ndim = sim.shape[0], mask.ndim
+ if ndim == 3:
+ mask = rearrange(mask, "b n m -> b 1 n m")
+ if ndim == 2:
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
+ max_neg_value = -torch.finfo(sim.dtype).max
+ sim = sim.masked_fill(~mask, max_neg_value)
+ return sim
+
+def causal_mask(q: Tensor, k: Tensor) -> Tensor:
+ b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
+ mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
+ mask = repeat(mask, "n m -> b n m", b=b)
+ return mask
+
+class AttentionBase(nn.Module):
+ def __init__(
+ self,
+ features: int,
+ *,
+ head_features: int,
+ num_heads: int,
+ out_features: Optional[int] = None,
+ ):
+ super().__init__()
+ self.scale = head_features**-0.5
+ self.num_heads = num_heads
+ mid_features = head_features * num_heads
+ out_features = default(out_features, features)
+
+ self.to_out = nn.Linear(
+ in_features=mid_features, out_features=out_features
+ )
+
+ self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
+
+ if not self.use_flash:
+ return
+
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+ if device_properties.major == 8 and device_properties.minor == 0:
+ # Use flash attention for A100 GPUs
+ self.sdp_kernel_config = (True, False, False)
+ else:
+ # Don't use flash attention for other GPUs
+ self.sdp_kernel_config = (False, True, True)
+
+ def forward(
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
+ ) -> Tensor:
+ # Split heads
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
+
+ if not self.use_flash:
+ if is_causal and not mask:
+ # Mask out future tokens for causal attention
+ mask = causal_mask(q, k)
+
+ # Compute similarity matrix and add eventual mask
+ sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
+ sim = add_mask(sim, mask) if exists(mask) else sim
+
+ # Get attention matrix with softmax
+ attn = sim.softmax(dim=-1, dtype=torch.float32)
+
+ # Compute values
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
+ else:
+ with sdp_kernel(*self.sdp_kernel_config):
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)
+
+ out = rearrange(out, "b h n d -> b n (h d)")
+ return self.to_out(out)
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ features: int,
+ *,
+ head_features: int,
+ num_heads: int,
+ out_features: Optional[int] = None,
+ context_features: Optional[int] = None,
+ causal: bool = False,
+ ):
+ super().__init__()
+ self.context_features = context_features
+ self.causal = causal
+ mid_features = head_features * num_heads
+ context_features = default(context_features, features)
+
+ self.norm = nn.LayerNorm(features)
+ self.norm_context = nn.LayerNorm(context_features)
+ self.to_q = nn.Linear(
+ in_features=features, out_features=mid_features, bias=False
+ )
+ self.to_kv = nn.Linear(
+ in_features=context_features, out_features=mid_features * 2, bias=False
+ )
+ self.attention = AttentionBase(
+ features,
+ num_heads=num_heads,
+ head_features=head_features,
+ out_features=out_features,
+ )
+
+ def forward(
+ self,
+ x: Tensor, # [b, n, c]
+ context: Optional[Tensor] = None, # [b, m, d]
+ context_mask: Optional[Tensor] = None, # [b, m], false is masked,
+ causal: Optional[bool] = False,
+ ) -> Tensor:
+ assert_message = "You must provide a context when using context_features"
+ assert not self.context_features or exists(context), assert_message
+ # Use context if provided
+ context = default(context, x)
+ # Normalize then compute q from input and k,v from context
+ x, context = self.norm(x), self.norm_context(context)
+
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
+
+ if exists(context_mask):
+ # Mask out cross-attention for padding tokens
+ mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
+ k, v = k * mask, v * mask
+
+ # Compute and return attention
+ return self.attention(q, k, v, is_causal=self.causal or causal)
+
+
+def FeedForward(features: int, multiplier: int) -> nn.Module:
+ mid_features = features * multiplier
+ return nn.Sequential(
+ nn.Linear(in_features=features, out_features=mid_features),
+ nn.GELU(),
+ nn.Linear(in_features=mid_features, out_features=features),
+ )
+
+"""
+Transformer Blocks
+"""
+
+
+class TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ features: int,
+ num_heads: int,
+ head_features: int,
+ multiplier: int,
+ context_features: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.use_cross_attention = exists(context_features) and context_features > 0
+
+ self.attention = Attention(
+ features=features,
+ num_heads=num_heads,
+ head_features=head_features
+ )
+
+ if self.use_cross_attention:
+ self.cross_attention = Attention(
+ features=features,
+ num_heads=num_heads,
+ head_features=head_features,
+ context_features=context_features
+ )
+
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
+
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor:
+ x = self.attention(x, causal=causal) + x
+ if self.use_cross_attention:
+ x = self.cross_attention(x, context=context, context_mask=context_mask) + x
+ x = self.feed_forward(x) + x
+ return x
+
+
+"""
+Transformers
+"""
+
+
+class Transformer1d(nn.Module):
+ def __init__(
+ self,
+ num_layers: int,
+ channels: int,
+ num_heads: int,
+ head_features: int,
+ multiplier: int,
+ context_features: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.to_in = nn.Sequential(
+ nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
+ Conv1d(
+ in_channels=channels,
+ out_channels=channels,
+ kernel_size=1,
+ ),
+ Rearrange("b c t -> b t c"),
+ )
+
+ self.blocks = nn.ModuleList(
+ [
+ TransformerBlock(
+ features=channels,
+ head_features=head_features,
+ num_heads=num_heads,
+ multiplier=multiplier,
+ context_features=context_features,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.to_out = nn.Sequential(
+ Rearrange("b t c -> b c t"),
+ Conv1d(
+ in_channels=channels,
+ out_channels=channels,
+ kernel_size=1,
+ ),
+ )
+
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor:
+ x = self.to_in(x)
+ for block in self.blocks:
+ x = block(x, context=context, context_mask=context_mask, causal=causal)
+ x = self.to_out(x)
+ return x
+
+
+"""
+Time Embeddings
+"""
+
+
+class SinusoidalEmbedding(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x: Tensor) -> Tensor:
+ device, half_dim = x.device, self.dim // 2
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
+
+
+class LearnedPositionalEmbedding(nn.Module):
+ """Used for continuous time"""
+
+ def __init__(self, dim: int):
+ super().__init__()
+ assert (dim % 2) == 0
+ half_dim = dim // 2
+ self.weights = nn.Parameter(torch.randn(half_dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = rearrange(x, "b -> b 1")
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
+ fouriered = torch.cat((x, fouriered), dim=-1)
+ return fouriered
+
+
+def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
+ return nn.Sequential(
+ LearnedPositionalEmbedding(dim),
+ nn.Linear(in_features=dim + 1, out_features=out_features),
+ )
+
+
+"""
+Encoder/Decoder Components
+"""
+
+
+class DownsampleBlock1d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ *,
+ factor: int,
+ num_groups: int,
+ num_layers: int,
+ kernel_multiplier: int = 2,
+ use_pre_downsample: bool = True,
+ use_skip: bool = False,
+ use_snake: bool = False,
+ extract_channels: int = 0,
+ context_channels: int = 0,
+ num_transformer_blocks: int = 0,
+ attention_heads: Optional[int] = None,
+ attention_features: Optional[int] = None,
+ attention_multiplier: Optional[int] = None,
+ context_mapping_features: Optional[int] = None,
+ context_embedding_features: Optional[int] = None,
+ ):
+ super().__init__()
+ self.use_pre_downsample = use_pre_downsample
+ self.use_skip = use_skip
+ self.use_transformer = num_transformer_blocks > 0
+ self.use_extract = extract_channels > 0
+ self.use_context = context_channels > 0
+
+ channels = out_channels if use_pre_downsample else in_channels
+
+ self.downsample = Downsample1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ factor=factor,
+ kernel_multiplier=kernel_multiplier,
+ )
+
+ self.blocks = nn.ModuleList(
+ [
+ ResnetBlock1d(
+ in_channels=channels + context_channels if i == 0 else channels,
+ out_channels=channels,
+ num_groups=num_groups,
+ context_mapping_features=context_mapping_features,
+ use_snake=use_snake
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ if self.use_transformer:
+ assert (
+ (exists(attention_heads) or exists(attention_features))
+ and exists(attention_multiplier)
+ )
+
+ if attention_features is None and attention_heads is not None:
+ attention_features = channels // attention_heads
+
+ if attention_heads is None and attention_features is not None:
+ attention_heads = channels // attention_features
+
+ self.transformer = Transformer1d(
+ num_layers=num_transformer_blocks,
+ channels=channels,
+ num_heads=attention_heads,
+ head_features=attention_features,
+ multiplier=attention_multiplier,
+ context_features=context_embedding_features
+ )
+
+ if self.use_extract:
+ num_extract_groups = min(num_groups, extract_channels)
+ self.to_extracted = ResnetBlock1d(
+ in_channels=out_channels,
+ out_channels=extract_channels,
+ num_groups=num_extract_groups,
+ use_snake=use_snake
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ *,
+ mapping: Optional[Tensor] = None,
+ channels: Optional[Tensor] = None,
+ embedding: Optional[Tensor] = None,
+ embedding_mask: Optional[Tensor] = None,
+ causal: Optional[bool] = False
+ ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
+
+ if self.use_pre_downsample:
+ x = self.downsample(x)
+
+ if self.use_context and exists(channels):
+ x = torch.cat([x, channels], dim=1)
+
+ skips = []
+ for block in self.blocks:
+ x = block(x, mapping=mapping, causal=causal)
+ skips += [x] if self.use_skip else []
+
+ if self.use_transformer:
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
+ skips += [x] if self.use_skip else []
+
+ if not self.use_pre_downsample:
+ x = self.downsample(x)
+
+ if self.use_extract:
+ extracted = self.to_extracted(x)
+ return x, extracted
+
+ return (x, skips) if self.use_skip else x
+
+
+class UpsampleBlock1d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ *,
+ factor: int,
+ num_layers: int,
+ num_groups: int,
+ use_nearest: bool = False,
+ use_pre_upsample: bool = False,
+ use_skip: bool = False,
+ use_snake: bool = False,
+ skip_channels: int = 0,
+ use_skip_scale: bool = False,
+ extract_channels: int = 0,
+ num_transformer_blocks: int = 0,
+ attention_heads: Optional[int] = None,
+ attention_features: Optional[int] = None,
+ attention_multiplier: Optional[int] = None,
+ context_mapping_features: Optional[int] = None,
+ context_embedding_features: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.use_extract = extract_channels > 0
+ self.use_pre_upsample = use_pre_upsample
+ self.use_transformer = num_transformer_blocks > 0
+ self.use_skip = use_skip
+ self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0
+
+ channels = out_channels if use_pre_upsample else in_channels
+
+ self.blocks = nn.ModuleList(
+ [
+ ResnetBlock1d(
+ in_channels=channels + skip_channels,
+ out_channels=channels,
+ num_groups=num_groups,
+ context_mapping_features=context_mapping_features,
+ use_snake=use_snake
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ if self.use_transformer:
+ assert (
+ (exists(attention_heads) or exists(attention_features))
+ and exists(attention_multiplier)
+ )
+
+ if attention_features is None and attention_heads is not None:
+ attention_features = channels // attention_heads
+
+ if attention_heads is None and attention_features is not None:
+ attention_heads = channels // attention_features
+
+ self.transformer = Transformer1d(
+ num_layers=num_transformer_blocks,
+ channels=channels,
+ num_heads=attention_heads,
+ head_features=attention_features,
+ multiplier=attention_multiplier,
+ context_features=context_embedding_features,
+ )
+
+ self.upsample = Upsample1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ factor=factor,
+ use_nearest=use_nearest,
+ )
+
+ if self.use_extract:
+ num_extract_groups = min(num_groups, extract_channels)
+ self.to_extracted = ResnetBlock1d(
+ in_channels=out_channels,
+ out_channels=extract_channels,
+ num_groups=num_extract_groups,
+ use_snake=use_snake
+ )
+
+ def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
+ return torch.cat([x, skip * self.skip_scale], dim=1)
+
+ def forward(
+ self,
+ x: Tensor,
+ *,
+ skips: Optional[List[Tensor]] = None,
+ mapping: Optional[Tensor] = None,
+ embedding: Optional[Tensor] = None,
+ embedding_mask: Optional[Tensor] = None,
+ causal: Optional[bool] = False
+ ) -> Union[Tuple[Tensor, Tensor], Tensor]:
+
+ if self.use_pre_upsample:
+ x = self.upsample(x)
+
+ for block in self.blocks:
+ x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
+ x = block(x, mapping=mapping, causal=causal)
+
+ if self.use_transformer:
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
+
+ if not self.use_pre_upsample:
+ x = self.upsample(x)
+
+ if self.use_extract:
+ extracted = self.to_extracted(x)
+ return x, extracted
+
+ return x
+
+
+class BottleneckBlock1d(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ *,
+ num_groups: int,
+ num_transformer_blocks: int = 0,
+ attention_heads: Optional[int] = None,
+ attention_features: Optional[int] = None,
+ attention_multiplier: Optional[int] = None,
+ context_mapping_features: Optional[int] = None,
+ context_embedding_features: Optional[int] = None,
+ use_snake: bool = False,
+ ):
+ super().__init__()
+ self.use_transformer = num_transformer_blocks > 0
+
+ self.pre_block = ResnetBlock1d(
+ in_channels=channels,
+ out_channels=channels,
+ num_groups=num_groups,
+ context_mapping_features=context_mapping_features,
+ use_snake=use_snake
+ )
+
+ if self.use_transformer:
+ assert (
+ (exists(attention_heads) or exists(attention_features))
+ and exists(attention_multiplier)
+ )
+
+ if attention_features is None and attention_heads is not None:
+ attention_features = channels // attention_heads
+
+ if attention_heads is None and attention_features is not None:
+ attention_heads = channels // attention_features
+
+ self.transformer = Transformer1d(
+ num_layers=num_transformer_blocks,
+ channels=channels,
+ num_heads=attention_heads,
+ head_features=attention_features,
+ multiplier=attention_multiplier,
+ context_features=context_embedding_features,
+ )
+
+ self.post_block = ResnetBlock1d(
+ in_channels=channels,
+ out_channels=channels,
+ num_groups=num_groups,
+ context_mapping_features=context_mapping_features,
+ use_snake=use_snake
+ )
+
+ def forward(
+ self,
+ x: Tensor,
+ *,
+ mapping: Optional[Tensor] = None,
+ embedding: Optional[Tensor] = None,
+ embedding_mask: Optional[Tensor] = None,
+ causal: Optional[bool] = False
+ ) -> Tensor:
+ x = self.pre_block(x, mapping=mapping, causal=causal)
+ if self.use_transformer:
+ x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
+ x = self.post_block(x, mapping=mapping, causal=causal)
+ return x
+
+
+"""
+UNet
+"""
+
+
+class UNet1d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ channels: int,
+ multipliers: Sequence[int],
+ factors: Sequence[int],
+ num_blocks: Sequence[int],
+ attentions: Sequence[int],
+ patch_size: int = 1,
+ resnet_groups: int = 8,
+ use_context_time: bool = True,
+ kernel_multiplier_downsample: int = 2,
+ use_nearest_upsample: bool = False,
+ use_skip_scale: bool = True,
+ use_snake: bool = False,
+ use_stft: bool = False,
+ use_stft_context: bool = False,
+ out_channels: Optional[int] = None,
+ context_features: Optional[int] = None,
+ context_features_multiplier: int = 4,
+ context_channels: Optional[Sequence[int]] = None,
+ context_embedding_features: Optional[int] = None,
+ **kwargs,
+ ):
+ super().__init__()
+ out_channels = default(out_channels, in_channels)
+ context_channels = list(default(context_channels, []))
+ num_layers = len(multipliers) - 1
+ use_context_features = exists(context_features)
+ use_context_channels = len(context_channels) > 0
+ context_mapping_features = None
+
+ attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)
+
+ self.num_layers = num_layers
+ self.use_context_time = use_context_time
+ self.use_context_features = use_context_features
+ self.use_context_channels = use_context_channels
+ self.use_stft = use_stft
+ self.use_stft_context = use_stft_context
+
+ self.context_features = context_features
+ context_channels_pad_length = num_layers + 1 - len(context_channels)
+ context_channels = context_channels + [0] * context_channels_pad_length
+ self.context_channels = context_channels
+ self.context_embedding_features = context_embedding_features
+
+ if use_context_channels:
+ has_context = [c > 0 for c in context_channels]
+ self.has_context = has_context
+ self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
+
+ assert (
+ len(factors) == num_layers
+ and len(attentions) >= num_layers
+ and len(num_blocks) == num_layers
+ )
+
+ if use_context_time or use_context_features:
+ context_mapping_features = channels * context_features_multiplier
+
+ self.to_mapping = nn.Sequential(
+ nn.Linear(context_mapping_features, context_mapping_features),
+ nn.GELU(),
+ nn.Linear(context_mapping_features, context_mapping_features),
+ nn.GELU(),
+ )
+
+ if use_context_time:
+ assert exists(context_mapping_features)
+ self.to_time = nn.Sequential(
+ TimePositionalEmbedding(
+ dim=channels, out_features=context_mapping_features
+ ),
+ nn.GELU(),
+ )
+
+ if use_context_features:
+ assert exists(context_features) and exists(context_mapping_features)
+ self.to_features = nn.Sequential(
+ nn.Linear(
+ in_features=context_features, out_features=context_mapping_features
+ ),
+ nn.GELU(),
+ )
+
+ if use_stft:
+ stft_kwargs, kwargs = groupby("stft_", kwargs)
+ assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
+ stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
+ in_channels *= stft_channels
+ out_channels *= stft_channels
+ context_channels[0] *= stft_channels if use_stft_context else 1
+ assert exists(in_channels) and exists(out_channels)
+ self.stft = STFT(**stft_kwargs)
+
+ assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
+
+ self.to_in = Patcher(
+ in_channels=in_channels + context_channels[0],
+ out_channels=channels * multipliers[0],
+ patch_size=patch_size,
+ context_mapping_features=context_mapping_features,
+ use_snake=use_snake
+ )
+
+ self.downsamples = nn.ModuleList(
+ [
+ DownsampleBlock1d(
+ in_channels=channels * multipliers[i],
+ out_channels=channels * multipliers[i + 1],
+ context_mapping_features=context_mapping_features,
+ context_channels=context_channels[i + 1],
+ context_embedding_features=context_embedding_features,
+ num_layers=num_blocks[i],
+ factor=factors[i],
+ kernel_multiplier=kernel_multiplier_downsample,
+ num_groups=resnet_groups,
+ use_pre_downsample=True,
+ use_skip=True,
+ use_snake=use_snake,
+ num_transformer_blocks=attentions[i],
+ **attention_kwargs,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ self.bottleneck = BottleneckBlock1d(
+ channels=channels * multipliers[-1],
+ context_mapping_features=context_mapping_features,
+ context_embedding_features=context_embedding_features,
+ num_groups=resnet_groups,
+ num_transformer_blocks=attentions[-1],
+ use_snake=use_snake,
+ **attention_kwargs,
+ )
+
+ self.upsamples = nn.ModuleList(
+ [
+ UpsampleBlock1d(
+ in_channels=channels * multipliers[i + 1],
+ out_channels=channels * multipliers[i],
+ context_mapping_features=context_mapping_features,
+ context_embedding_features=context_embedding_features,
+ num_layers=num_blocks[i] + (1 if attentions[i] else 0),
+ factor=factors[i],
+ use_nearest=use_nearest_upsample,
+ num_groups=resnet_groups,
+ use_skip_scale=use_skip_scale,
+ use_pre_upsample=False,
+ use_skip=True,
+ use_snake=use_snake,
+ skip_channels=channels * multipliers[i + 1],
+ num_transformer_blocks=attentions[i],
+ **attention_kwargs,
+ )
+ for i in reversed(range(num_layers))
+ ]
+ )
+
+ self.to_out = Unpatcher(
+ in_channels=channels * multipliers[0],
+ out_channels=out_channels,
+ patch_size=patch_size,
+ context_mapping_features=context_mapping_features,
+ use_snake=use_snake
+ )
+
+ def get_channels(
+ self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
+ ) -> Optional[Tensor]:
+ """Gets context channels at `layer` and checks that shape is correct"""
+ use_context_channels = self.use_context_channels and self.has_context[layer]
+ if not use_context_channels:
+ return None
+ assert exists(channels_list), "Missing context"
+ # Get channels index (skipping zero channel contexts)
+ channels_id = self.channels_ids[layer]
+ # Get channels
+ channels = channels_list[channels_id]
+ message = f"Missing context for layer {layer} at index {channels_id}"
+ assert exists(channels), message
+ # Check channels
+ num_channels = self.context_channels[layer]
+ message = f"Expected context with {num_channels} channels at idx {channels_id}"
+ assert channels.shape[1] == num_channels, message
+ # STFT channels if requested
+ channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
+ return channels
+
+ def get_mapping(
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
+ ) -> Optional[Tensor]:
+ """Combines context time features and features into mapping"""
+ items, mapping = [], None
+ # Compute time features
+ if self.use_context_time:
+ assert_message = "use_context_time=True but no time features provided"
+ assert exists(time), assert_message
+ items += [self.to_time(time)]
+ # Compute features
+ if self.use_context_features:
+ assert_message = "context_features exists but no features provided"
+ assert exists(features), assert_message
+ items += [self.to_features(features)]
+ # Compute joint mapping
+ if self.use_context_time or self.use_context_features:
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
+ mapping = self.to_mapping(mapping)
+ return mapping
+
+ def forward(
+ self,
+ x: Tensor,
+ time: Optional[Tensor] = None,
+ *,
+ features: Optional[Tensor] = None,
+ channels_list: Optional[Sequence[Tensor]] = None,
+ embedding: Optional[Tensor] = None,
+ embedding_mask: Optional[Tensor] = None,
+ causal: Optional[bool] = False,
+ ) -> Tensor:
+ channels = self.get_channels(channels_list, layer=0)
+ # Apply stft if required
+ x = self.stft.encode1d(x) if self.use_stft else x # type: ignore
+ # Concat context channels at layer 0 if provided
+ x = torch.cat([x, channels], dim=1) if exists(channels) else x
+ # Compute mapping from time and features
+ mapping = self.get_mapping(time, features)
+ x = self.to_in(x, mapping, causal=causal)
+ skips_list = [x]
+
+ for i, downsample in enumerate(self.downsamples):
+ channels = self.get_channels(channels_list, layer=i + 1)
+ x, skips = downsample(
+ x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal
+ )
+ skips_list += [skips]
+
+ x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
+
+ for i, upsample in enumerate(self.upsamples):
+ skips = skips_list.pop()
+ x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
+
+ x += skips_list.pop()
+ x = self.to_out(x, mapping, causal=causal)
+ x = self.stft.decode1d(x) if self.use_stft else x
+
+ return x
+
+
+""" Conditioning Modules """
+
+
+class FixedEmbedding(nn.Module):
+ def __init__(self, max_length: int, features: int):
+ super().__init__()
+ self.max_length = max_length
+ self.embedding = nn.Embedding(max_length, features)
+
+ def forward(self, x: Tensor) -> Tensor:
+ batch_size, length, device = *x.shape[0:2], x.device
+ assert_message = "Input sequence length must be <= max_length"
+ assert length <= self.max_length, assert_message
+ position = torch.arange(length, device=device)
+ fixed_embedding = self.embedding(position)
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
+ return fixed_embedding
+
+
+def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
+ if proba == 1:
+ return torch.ones(shape, device=device, dtype=torch.bool)
+ elif proba == 0:
+ return torch.zeros(shape, device=device, dtype=torch.bool)
+ else:
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
+
+
+class UNetCFG1d(UNet1d):
+
+ """UNet1d with Classifier-Free Guidance"""
+
+ def __init__(
+ self,
+ context_embedding_max_length: int,
+ context_embedding_features: int,
+ use_xattn_time: bool = False,
+ **kwargs,
+ ):
+ super().__init__(
+ context_embedding_features=context_embedding_features, **kwargs
+ )
+
+ self.use_xattn_time = use_xattn_time
+
+ if use_xattn_time:
+ assert exists(context_embedding_features)
+ self.to_time_embedding = nn.Sequential(
+ TimePositionalEmbedding(
+ dim=kwargs["channels"], out_features=context_embedding_features
+ ),
+ nn.GELU(),
+ )
+
+ context_embedding_max_length += 1 # Add one for time embedding
+
+ self.fixed_embedding = FixedEmbedding(
+ max_length=context_embedding_max_length, features=context_embedding_features
+ )
+
+ def forward( # type: ignore
+ self,
+ x: Tensor,
+ time: Tensor,
+ *,
+ embedding: Tensor,
+ embedding_mask: Optional[Tensor] = None,
+ embedding_scale: float = 1.0,
+ embedding_mask_proba: float = 0.0,
+ batch_cfg: bool = False,
+ rescale_cfg: bool = False,
+ scale_phi: float = 0.4,
+ negative_embedding: Optional[Tensor] = None,
+ negative_embedding_mask: Optional[Tensor] = None,
+ **kwargs,
+ ) -> Tensor:
+ b, device = embedding.shape[0], embedding.device
+
+ if self.use_xattn_time:
+ embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1)
+
+ if embedding_mask is not None:
+ embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1)
+
+ fixed_embedding = self.fixed_embedding(embedding)
+
+ if embedding_mask_proba > 0.0:
+ # Randomly mask embedding
+ batch_mask = rand_bool(
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
+ )
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
+
+ if embedding_scale != 1.0:
+ if batch_cfg:
+ batch_x = torch.cat([x, x], dim=0)
+ batch_time = torch.cat([time, time], dim=0)
+
+ if negative_embedding is not None:
+ if negative_embedding_mask is not None:
+ negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2)
+
+ negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding)
+
+ batch_embed = torch.cat([embedding, negative_embedding], dim=0)
+
+ else:
+ batch_embed = torch.cat([embedding, fixed_embedding], dim=0)
+
+ batch_mask = None
+ if embedding_mask is not None:
+ batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)
+
+ batch_features = None
+ features = kwargs.pop("features", None)
+ if self.use_context_features:
+ batch_features = torch.cat([features, features], dim=0)
+
+ batch_channels = None
+ channels_list = kwargs.pop("channels_list", None)
+ if self.use_context_channels:
+ batch_channels = []
+ for channels in channels_list:
+ batch_channels += [torch.cat([channels, channels], dim=0)]
+
+ # Compute both normal and fixed embedding outputs
+ batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs)
+ out, out_masked = batch_out.chunk(2, dim=0)
+
+ else:
+ # Compute both normal and fixed embedding outputs
+ out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
+ out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs)
+
+ out_cfg = out_masked + (out - out_masked) * embedding_scale
+
+ if rescale_cfg:
+
+ out_std = out.std(dim=1, keepdim=True)
+ out_cfg_std = out_cfg.std(dim=1, keepdim=True)
+
+ return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg
+
+ else:
+
+ return out_cfg
+
+ else:
+ return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
+
+
+class UNetNCCA1d(UNet1d):
+
+ """UNet1d with Noise Channel Conditioning Augmentation"""
+
+ def __init__(self, context_features: int, **kwargs):
+ super().__init__(context_features=context_features, **kwargs)
+ self.embedder = NumberEmbedder(features=context_features)
+
+ def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
+ x = x if torch.is_tensor(x) else torch.tensor(x)
+ return x.expand(shape)
+
+ def forward( # type: ignore
+ self,
+ x: Tensor,
+ time: Tensor,
+ *,
+ channels_list: Sequence[Tensor],
+ channels_augmentation: Union[
+ bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
+ ] = False,
+ channels_scale: Union[
+ float, Sequence[float], Sequence[Sequence[float]], Tensor
+ ] = 0,
+ **kwargs,
+ ) -> Tensor:
+ b, n = x.shape[0], len(channels_list)
+ channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
+ channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)
+
+ # Augmentation (for each channel list item)
+ for i in range(n):
+ scale = channels_scale[:, i] * channels_augmentation[:, i]
+ scale = rearrange(scale, "b -> b 1 1")
+ item = channels_list[i]
+ channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa
+
+ # Scale embedding (sum reduction if more than one channel list item)
+ channels_scale_emb = self.embedder(channels_scale)
+ channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")
+
+ return super().forward(
+ x=x,
+ time=time,
+ channels_list=channels_list,
+ features=channels_scale_emb,
+ **kwargs,
+ )
+
+
+class UNetAll1d(UNetCFG1d, UNetNCCA1d):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, *args, **kwargs): # type: ignore
+ return UNetCFG1d.forward(self, *args, **kwargs)
+
+
+def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
+ if type == "base":
+ return UNet1d(**kwargs)
+ elif type == "all":
+ return UNetAll1d(**kwargs)
+ elif type == "cfg":
+ return UNetCFG1d(**kwargs)
+ elif type == "ncca":
+ return UNetNCCA1d(**kwargs)
+ else:
+ raise ValueError(f"Unknown XUNet1d type: {type}")
+
+class NumberEmbedder(nn.Module):
+ def __init__(
+ self,
+ features: int,
+ dim: int = 256,
+ ):
+ super().__init__()
+ self.features = features
+ self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
+
+ def forward(self, x: Union[List[float], Tensor]) -> Tensor:
+ if not torch.is_tensor(x):
+ device = next(self.embedding.parameters()).device
+ x = torch.tensor(x, device=device)
+ assert isinstance(x, Tensor)
+ shape = x.shape
+ x = rearrange(x, "... -> (...)")
+ embedding = self.embedding(x)
+ x = embedding.view(*shape, self.features)
+ return x # type: ignore
+
+
+"""
+Audio Transforms
+"""
+
+
+class STFT(nn.Module):
+ """Helper for torch stft and istft"""
+
+ def __init__(
+ self,
+ num_fft: int = 1023,
+ hop_length: int = 256,
+ window_length: Optional[int] = None,
+ length: Optional[int] = None,
+ use_complex: bool = False,
+ ):
+ super().__init__()
+ self.num_fft = num_fft
+ self.hop_length = default(hop_length, floor(num_fft // 4))
+ self.window_length = default(window_length, num_fft)
+ self.length = length
+ self.register_buffer("window", torch.hann_window(self.window_length))
+ self.use_complex = use_complex
+
+ def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
+ b = wave.shape[0]
+ wave = rearrange(wave, "b c t -> (b c) t")
+
+ stft = torch.stft(
+ wave,
+ n_fft=self.num_fft,
+ hop_length=self.hop_length,
+ win_length=self.window_length,
+ window=self.window, # type: ignore
+ return_complex=True,
+ normalized=True,
+ )
+
+ if self.use_complex:
+ # Returns real and imaginary
+ stft_a, stft_b = stft.real, stft.imag
+ else:
+ # Returns magnitude and phase matrices
+ magnitude, phase = torch.abs(stft), torch.angle(stft)
+ stft_a, stft_b = magnitude, phase
+
+ return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)
+
+ def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
+ b, l = stft_a.shape[0], stft_a.shape[-1] # noqa
+ length = closest_power_2(l * self.hop_length)
+
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")
+
+ if self.use_complex:
+ real, imag = stft_a, stft_b
+ else:
+ magnitude, phase = stft_a, stft_b
+ real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)
+
+ stft = torch.stack([real, imag], dim=-1)
+
+ wave = torch.istft(
+ stft,
+ n_fft=self.num_fft,
+ hop_length=self.hop_length,
+ win_length=self.window_length,
+ window=self.window, # type: ignore
+ length=default(self.length, length),
+ normalized=True,
+ )
+
+ return rearrange(wave, "(b c) t -> b c t", b=b)
+
+ def encode1d(
+ self, wave: Tensor, stacked: bool = True
+ ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+ stft_a, stft_b = self.encode(wave)
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
+ return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)
+
+ def decode1d(self, stft_pair: Tensor) -> Tensor:
+ f = self.num_fft // 2 + 1
+ stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
+ stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
+ return self.decode(stft_a, stft_b)
diff --git a/ThinkSound/models/autoencoders.py b/ThinkSound/models/autoencoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..853fdc9712c12284b2073c1cce8c384baa258337
--- /dev/null
+++ b/ThinkSound/models/autoencoders.py
@@ -0,0 +1,800 @@
+import torch
+import math
+import numpy as np
+
+from torch import nn
+from torch.nn import functional as F
+from torchaudio import transforms as T
+from alias_free_torch import Activation1d
+from dac.nn.layers import WNConv1d, WNConvTranspose1d
+from typing import Literal, Dict, Any
+
+from ..inference.sampling import sample
+from ..inference.utils import prepare_audio
+from .blocks import SnakeBeta
+from .bottleneck import Bottleneck, DiscreteBottleneck
+from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
+from .factory import create_pretransform_from_config, create_bottleneck_from_config
+from .pretransforms import Pretransform
+
+def checkpoint(function, *args, **kwargs):
+ kwargs.setdefault("use_reentrant", False)
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
+
+def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
+ if activation == "elu":
+ act = nn.ELU()
+ elif activation == "snake":
+ act = SnakeBeta(channels)
+ elif activation == "none":
+ act = nn.Identity()
+ else:
+ raise ValueError(f"Unknown activation {activation}")
+
+ if antialias:
+ act = Activation1d(act)
+
+ return act
+
+class ResidualUnit(nn.Module):
+ def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
+ super().__init__()
+
+ self.dilation = dilation
+
+ padding = (dilation * (7-1)) // 2
+
+ self.layers = nn.Sequential(
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=7, dilation=dilation, padding=padding),
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
+ WNConv1d(in_channels=out_channels, out_channels=out_channels,
+ kernel_size=1)
+ )
+
+ def forward(self, x):
+ res = x
+
+ #x = checkpoint(self.layers, x)
+ x = self.layers(x)
+
+ return x + res
+
+class EncoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
+ super().__init__()
+
+ self.layers = nn.Sequential(
+ ResidualUnit(in_channels=in_channels,
+ out_channels=in_channels, dilation=1, use_snake=use_snake),
+ ResidualUnit(in_channels=in_channels,
+ out_channels=in_channels, dilation=3, use_snake=use_snake),
+ ResidualUnit(in_channels=in_channels,
+ out_channels=in_channels, dilation=9, use_snake=use_snake),
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+class DecoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
+ super().__init__()
+
+ if use_nearest_upsample:
+ upsample_layer = nn.Sequential(
+ nn.Upsample(scale_factor=stride, mode="nearest"),
+ WNConv1d(in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=2*stride,
+ stride=1,
+ bias=False,
+ padding='same')
+ )
+ else:
+ upsample_layer = WNConvTranspose1d(in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
+
+ self.layers = nn.Sequential(
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
+ upsample_layer,
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
+ dilation=1, use_snake=use_snake),
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
+ dilation=3, use_snake=use_snake),
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
+ dilation=9, use_snake=use_snake),
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+class OobleckEncoder(nn.Module):
+ def __init__(self,
+ in_channels=2,
+ channels=128,
+ latent_dim=32,
+ c_mults = [1, 2, 4, 8],
+ strides = [2, 4, 8, 8],
+ use_snake=False,
+ antialias_activation=False
+ ):
+ super().__init__()
+
+ c_mults = [1] + c_mults
+
+ self.depth = len(c_mults)
+
+ layers = [
+ WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
+ ]
+
+ for i in range(self.depth-1):
+ layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
+
+ layers += [
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
+ WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
+ ]
+
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class OobleckDecoder(nn.Module):
+ def __init__(self,
+ out_channels=2,
+ channels=128,
+ latent_dim=32,
+ c_mults = [1, 2, 4, 8],
+ strides = [2, 4, 8, 8],
+ use_snake=False,
+ antialias_activation=False,
+ use_nearest_upsample=False,
+ final_tanh=True):
+ super().__init__()
+
+ c_mults = [1] + c_mults
+
+ self.depth = len(c_mults)
+
+ layers = [
+ WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
+ ]
+
+ for i in range(self.depth-1, 0, -1):
+ layers += [DecoderBlock(
+ in_channels=c_mults[i]*channels,
+ out_channels=c_mults[i-1]*channels,
+ stride=strides[i-1],
+ use_snake=use_snake,
+ antialias_activation=antialias_activation,
+ use_nearest_upsample=use_nearest_upsample
+ )
+ ]
+
+ layers += [
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
+ WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
+ nn.Tanh() if final_tanh else nn.Identity()
+ ]
+
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class DACEncoderWrapper(nn.Module):
+ def __init__(self, in_channels=1, **kwargs):
+ super().__init__()
+
+ from dac.model.dac import Encoder as DACEncoder
+
+ latent_dim = kwargs.pop("latent_dim", None)
+
+ encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
+ self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
+ self.latent_dim = latent_dim
+
+ # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
+ self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
+
+ if in_channels != 1:
+ self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.proj_out(x)
+ return x
+
+class DACDecoderWrapper(nn.Module):
+ def __init__(self, latent_dim, out_channels=1, **kwargs):
+ super().__init__()
+
+ from dac.model.dac import Decoder as DACDecoder
+
+ self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
+
+ self.latent_dim = latent_dim
+
+ def forward(self, x):
+ return self.decoder(x)
+
+class AudioAutoencoder(nn.Module):
+ def __init__(
+ self,
+ encoder,
+ decoder,
+ latent_dim,
+ downsampling_ratio,
+ sample_rate,
+ io_channels=2,
+ bottleneck: Bottleneck = None,
+ pretransform: Pretransform = None,
+ in_channels = None,
+ out_channels = None,
+ soft_clip = False
+ ):
+ super().__init__()
+
+ self.downsampling_ratio = downsampling_ratio
+ self.sample_rate = sample_rate
+
+ self.latent_dim = latent_dim
+ self.io_channels = io_channels
+ self.in_channels = io_channels
+ self.out_channels = io_channels
+
+ self.min_length = self.downsampling_ratio
+
+ if in_channels is not None:
+ self.in_channels = in_channels
+
+ if out_channels is not None:
+ self.out_channels = out_channels
+
+ self.bottleneck = bottleneck
+
+ self.encoder = encoder
+
+ self.decoder = decoder
+
+ self.pretransform = pretransform
+
+ self.soft_clip = soft_clip
+
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
+
+ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
+
+ info = {}
+ # import ipdb
+ # ipdb.set_trace()
+ if self.pretransform is not None and not skip_pretransform:
+ if self.pretransform.enable_grad:
+ if iterate_batch:
+ audios = []
+ for i in range(audio.shape[0]):
+ audios.append(self.pretransform.encode(audio[i:i+1]))
+ audio = torch.cat(audios, dim=0)
+ else:
+ audio = self.pretransform.encode(audio)
+ else:
+ with torch.no_grad():
+ if iterate_batch:
+ audios = []
+ for i in range(audio.shape[0]):
+ audios.append(self.pretransform.encode(audio[i:i+1]))
+ audio = torch.cat(audios, dim=0)
+ else:
+ audio = self.pretransform.encode(audio)
+
+ if self.encoder is not None:
+ if iterate_batch:
+ latents = []
+ for i in range(audio.shape[0]):
+ latents.append(self.encoder(audio[i:i+1]))
+ latents = torch.cat(latents, dim=0)
+ else:
+ latents = self.encoder(audio)
+ else:
+ latents = audio
+
+ if self.bottleneck is not None:
+ # TODO: Add iterate batch logic, needs to merge the info dicts
+ latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
+
+ info.update(bottleneck_info)
+
+ if return_info:
+ return latents, info
+
+ return latents
+
+ def decode(self, latents, iterate_batch=False, **kwargs):
+
+ if self.bottleneck is not None:
+ if iterate_batch:
+ decoded = []
+ for i in range(latents.shape[0]):
+ decoded.append(self.bottleneck.decode(latents[i:i+1]))
+ latents = torch.cat(decoded, dim=0)
+ else:
+ latents = self.bottleneck.decode(latents)
+
+ if iterate_batch:
+ decoded = []
+ for i in range(latents.shape[0]):
+ decoded.append(self.decoder(latents[i:i+1]))
+ decoded = torch.cat(decoded, dim=0)
+ else:
+ decoded = self.decoder(latents, **kwargs)
+
+ if self.pretransform is not None:
+ if self.pretransform.enable_grad:
+ if iterate_batch:
+ decodeds = []
+ for i in range(decoded.shape[0]):
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
+ decoded = torch.cat(decodeds, dim=0)
+ else:
+ decoded = self.pretransform.decode(decoded)
+ else:
+ with torch.no_grad():
+ if iterate_batch:
+ decodeds = []
+ for i in range(latents.shape[0]):
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
+ decoded = torch.cat(decodeds, dim=0)
+ else:
+ decoded = self.pretransform.decode(decoded)
+
+ if self.soft_clip:
+ decoded = torch.tanh(decoded)
+
+ return decoded
+
+ def decode_tokens(self, tokens, **kwargs):
+ '''
+ Decode discrete tokens to audio
+ Only works with discrete autoencoders
+ '''
+
+ assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
+
+ latents = self.bottleneck.decode_tokens(tokens, **kwargs)
+
+ return self.decode(latents, **kwargs)
+
+
+ def preprocess_audio_for_encoder(self, audio, in_sr):
+ '''
+ Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
+ If the model is mono, stereo audio will be converted to mono.
+ Audio will be silence-padded to be a multiple of the model's downsampling ratio.
+ Audio will be resampled to the model's sample rate.
+ The output will have batch size 1 and be shape (1 x Channels x Length)
+ '''
+ return self.preprocess_audio_list_for_encoder([audio], [in_sr])
+
+ def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
+ '''
+ Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder.
+ The audio in that list can be of different lengths and channels.
+ in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
+ All audio will be resampled to the model's sample rate.
+ Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
+ If the model is mono, all audio will be converted to mono.
+ The output will be a tensor of shape (Batch x Channels x Length)
+ '''
+ batch_size = len(audio_list)
+ if isinstance(in_sr_list, int):
+ in_sr_list = [in_sr_list]*batch_size
+ assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
+ new_audio = []
+ max_length = 0
+ # resample & find the max length
+ for i in range(batch_size):
+ audio = audio_list[i]
+ in_sr = in_sr_list[i]
+ if len(audio.shape) == 3 and audio.shape[0] == 1:
+ # batchsize 1 was given by accident. Just squeeze it.
+ audio = audio.squeeze(0)
+ elif len(audio.shape) == 1:
+ # Mono signal, channel dimension is missing, unsqueeze it in
+ audio = audio.unsqueeze(0)
+ assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
+ # Resample audio
+ if in_sr != self.sample_rate:
+ resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
+ audio = resample_tf(audio)
+ new_audio.append(audio)
+ if audio.shape[-1] > max_length:
+ max_length = audio.shape[-1]
+ # Pad every audio to the same length, multiple of model's downsampling ratio
+ padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
+ for i in range(batch_size):
+ # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
+ new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
+ target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
+ # convert to tensor
+ return torch.stack(new_audio)
+
+ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
+ '''
+ Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder.
+ If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
+ Overlap and chunk_size params are both measured in number of latents (not audio samples)
+ # and therefore you likely could use the same values with decode_audio.
+ A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
+ You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
+ Smaller chunk_size uses less memory, but more compute.
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
+ For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
+ '''
+ if not chunked:
+ # default behavior. Encode the entire audio in parallel
+ return self.encode(audio, **kwargs)
+ else:
+ # CHUNKED ENCODING
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
+ # import ipdb
+ # ipdb.set_trace()
+ samples_per_latent = self.downsampling_ratio
+ total_size = audio.shape[2] # in samples
+ print(f'audio shape: {audio.shape}')
+ batch_size = audio.shape[0]
+ chunk_size *= samples_per_latent # converting metric in latents to samples
+ overlap *= samples_per_latent # converting metric in latents to samples
+ hop_size = chunk_size - overlap
+ chunks = []
+ for i in range(0, total_size - chunk_size + 1, hop_size):
+ chunk = audio[:,:,i:i+chunk_size]
+ chunks.append(chunk)
+ if i+chunk_size != total_size:
+ # Final chunk
+ chunk = audio[:,:,-chunk_size:]
+ chunks.append(chunk)
+ chunks = torch.stack(chunks)
+ num_chunks = chunks.shape[0]
+ # Note: y_size might be a different value from the latent length used in diffusion training
+ # because we can encode audio of varying lengths
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
+ y_size = total_size // samples_per_latent
+ # Create an empty latent, we will populate it with chunks as we encode them
+ y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
+ print(f'y_final shape: {y_final.shape}')
+ for i in range(num_chunks):
+ x_chunk = chunks[i,:]
+ # encode the chunk
+ y_chunk = self.encode(x_chunk)
+ print(f'y_chunk shape: {y_chunk.shape}')
+ # figure out where to put the audio along the time domain
+ if i == num_chunks-1:
+ # final chunk always goes at the end
+ t_end = y_size
+ t_start = t_end - y_chunk.shape[2]
+ else:
+ t_start = i * hop_size // samples_per_latent
+ t_end = t_start + chunk_size // samples_per_latent
+ # remove the edges of the overlaps
+ ol = overlap//samples_per_latent//2
+ chunk_start = 0
+ chunk_end = y_chunk.shape[2]
+ if i > 0:
+ # no overlap for the start of the first chunk
+ t_start += ol
+ chunk_start += ol
+ if i < num_chunks-1:
+ # no overlap for the end of the last chunk
+ t_end -= ol
+ chunk_end -= ol
+ # paste the chunked audio into our y_final output audio
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
+ return y_final
+
+ def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
+ '''
+ Decode latents to audio.
+ If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
+ A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size.
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
+ You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
+ Smaller chunk_size uses less memory, but more compute.
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
+ For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks
+ '''
+ if not chunked:
+ # default behavior. Decode the entire latent in parallel
+ return self.decode(latents, **kwargs)
+ else:
+ # chunked decoding
+ hop_size = chunk_size - overlap
+ total_size = latents.shape[2]
+ batch_size = latents.shape[0]
+ chunks = []
+ for i in range(0, total_size - chunk_size + 1, hop_size):
+ chunk = latents[:,:,i:i+chunk_size]
+ chunks.append(chunk)
+ if i+chunk_size != total_size:
+ # Final chunk
+ chunk = latents[:,:,-chunk_size:]
+ chunks.append(chunk)
+ chunks = torch.stack(chunks)
+ num_chunks = chunks.shape[0]
+ # samples_per_latent is just the downsampling ratio
+ samples_per_latent = self.downsampling_ratio
+ # Create an empty waveform, we will populate it with chunks as decode them
+ y_size = total_size * samples_per_latent
+ y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
+ for i in range(num_chunks):
+ x_chunk = chunks[i,:]
+ # decode the chunk
+ y_chunk = self.decode(x_chunk)
+ # figure out where to put the audio along the time domain
+ if i == num_chunks-1:
+ # final chunk always goes at the end
+ t_end = y_size
+ t_start = t_end - y_chunk.shape[2]
+ else:
+ t_start = i * hop_size * samples_per_latent
+ t_end = t_start + chunk_size * samples_per_latent
+ # remove the edges of the overlaps
+ ol = (overlap//2) * samples_per_latent
+ chunk_start = 0
+ chunk_end = y_chunk.shape[2]
+ if i > 0:
+ # no overlap for the start of the first chunk
+ t_start += ol
+ chunk_start += ol
+ if i < num_chunks-1:
+ # no overlap for the end of the last chunk
+ t_end -= ol
+ chunk_end -= ol
+ # paste the chunked audio into our y_final output audio
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
+ return y_final
+
+
+class DiffusionAutoencoder(AudioAutoencoder):
+ def __init__(
+ self,
+ diffusion: ConditionedDiffusionModel,
+ diffusion_downsampling_ratio,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.diffusion = diffusion
+
+ self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
+
+ if self.encoder is not None:
+ # Shrink the initial encoder parameters to avoid saturated latents
+ with torch.no_grad():
+ for param in self.encoder.parameters():
+ param *= 0.5
+
+ def decode(self, latents, steps=100):
+
+ upsampled_length = latents.shape[2] * self.downsampling_ratio
+
+ if self.bottleneck is not None:
+ latents = self.bottleneck.decode(latents)
+
+ if self.decoder is not None:
+ latents = self.decode(latents)
+
+ # Upsample latents to match diffusion length
+ if latents.shape[2] != upsampled_length:
+ latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
+
+ noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
+ decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
+
+ if self.pretransform is not None:
+ if self.pretransform.enable_grad:
+ decoded = self.pretransform.decode(decoded)
+ else:
+ with torch.no_grad():
+ decoded = self.pretransform.decode(decoded)
+
+ return decoded
+
+# AE factories
+
+def create_encoder_from_config(encoder_config: Dict[str, Any]):
+ encoder_type = encoder_config.get("type", None)
+ assert encoder_type is not None, "Encoder type must be specified"
+
+ if encoder_type == "oobleck":
+ encoder = OobleckEncoder(
+ **encoder_config["config"]
+ )
+
+ elif encoder_type == "seanet":
+ from encodec.modules import SEANetEncoder
+ seanet_encoder_config = encoder_config["config"]
+
+ #SEANet encoder expects strides in reverse order
+ seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
+ encoder = SEANetEncoder(
+ **seanet_encoder_config
+ )
+ elif encoder_type == "dac":
+ dac_config = encoder_config["config"]
+
+ encoder = DACEncoderWrapper(**dac_config)
+ elif encoder_type == "local_attn":
+ from .local_attention import TransformerEncoder1D
+
+ local_attn_config = encoder_config["config"]
+
+ encoder = TransformerEncoder1D(
+ **local_attn_config
+ )
+ else:
+ raise ValueError(f"Unknown encoder type {encoder_type}")
+
+ requires_grad = encoder_config.get("requires_grad", True)
+ if not requires_grad:
+ for param in encoder.parameters():
+ param.requires_grad = False
+
+ return encoder
+
+def create_decoder_from_config(decoder_config: Dict[str, Any]):
+ decoder_type = decoder_config.get("type", None)
+ assert decoder_type is not None, "Decoder type must be specified"
+
+ if decoder_type == "oobleck":
+ decoder = OobleckDecoder(
+ **decoder_config["config"]
+ )
+ elif decoder_type == "seanet":
+ from encodec.modules import SEANetDecoder
+
+ decoder = SEANetDecoder(
+ **decoder_config["config"]
+ )
+ elif decoder_type == "dac":
+ dac_config = decoder_config["config"]
+
+ decoder = DACDecoderWrapper(**dac_config)
+ elif decoder_type == "local_attn":
+ from .local_attention import TransformerDecoder1D
+
+ local_attn_config = decoder_config["config"]
+
+ decoder = TransformerDecoder1D(
+ **local_attn_config
+ )
+ else:
+ raise ValueError(f"Unknown decoder type {decoder_type}")
+
+ requires_grad = decoder_config.get("requires_grad", True)
+ if not requires_grad:
+ for param in decoder.parameters():
+ param.requires_grad = False
+
+ return decoder
+
+def create_autoencoder_from_config(config: Dict[str, Any]):
+
+ ae_config = config["model"]
+
+ encoder = create_encoder_from_config(ae_config["encoder"])
+ decoder = create_decoder_from_config(ae_config["decoder"])
+
+ bottleneck = ae_config.get("bottleneck", None)
+
+ latent_dim = ae_config.get("latent_dim", None)
+ assert latent_dim is not None, "latent_dim must be specified in model config"
+ downsampling_ratio = ae_config.get("downsampling_ratio", None)
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
+ io_channels = ae_config.get("io_channels", None)
+ assert io_channels is not None, "io_channels must be specified in model config"
+ sample_rate = config.get("sample_rate", None)
+ assert sample_rate is not None, "sample_rate must be specified in model config"
+
+ in_channels = ae_config.get("in_channels", None)
+ out_channels = ae_config.get("out_channels", None)
+
+ pretransform = ae_config.get("pretransform", None)
+
+ if pretransform is not None:
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
+
+ if bottleneck is not None:
+ bottleneck = create_bottleneck_from_config(bottleneck)
+
+ soft_clip = ae_config["decoder"].get("soft_clip", False)
+
+ return AudioAutoencoder(
+ encoder,
+ decoder,
+ io_channels=io_channels,
+ latent_dim=latent_dim,
+ downsampling_ratio=downsampling_ratio,
+ sample_rate=sample_rate,
+ bottleneck=bottleneck,
+ pretransform=pretransform,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ soft_clip=soft_clip
+ )
+
+def create_diffAE_from_config(config: Dict[str, Any]):
+
+ diffae_config = config["model"]
+
+ if "encoder" in diffae_config:
+ encoder = create_encoder_from_config(diffae_config["encoder"])
+ else:
+ encoder = None
+
+ if "decoder" in diffae_config:
+ decoder = create_decoder_from_config(diffae_config["decoder"])
+ else:
+ decoder = None
+
+ diffusion_model_type = diffae_config["diffusion"]["type"]
+
+ if diffusion_model_type == "DAU1d":
+ diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
+ elif diffusion_model_type == "adp_1d":
+ diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
+ elif diffusion_model_type == "dit":
+ diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
+
+ latent_dim = diffae_config.get("latent_dim", None)
+ assert latent_dim is not None, "latent_dim must be specified in model config"
+ downsampling_ratio = diffae_config.get("downsampling_ratio", None)
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
+ io_channels = diffae_config.get("io_channels", None)
+ assert io_channels is not None, "io_channels must be specified in model config"
+ sample_rate = config.get("sample_rate", None)
+ assert sample_rate is not None, "sample_rate must be specified in model config"
+
+ bottleneck = diffae_config.get("bottleneck", None)
+
+ pretransform = diffae_config.get("pretransform", None)
+
+ if pretransform is not None:
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
+
+ if bottleneck is not None:
+ bottleneck = create_bottleneck_from_config(bottleneck)
+
+ diffusion_downsampling_ratio = None,
+
+ if diffusion_model_type == "DAU1d":
+ diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
+ elif diffusion_model_type == "adp_1d":
+ diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
+ elif diffusion_model_type == "dit":
+ diffusion_downsampling_ratio = 1
+
+ return DiffusionAutoencoder(
+ encoder=encoder,
+ decoder=decoder,
+ diffusion=diffusion,
+ io_channels=io_channels,
+ sample_rate=sample_rate,
+ latent_dim=latent_dim,
+ downsampling_ratio=downsampling_ratio,
+ diffusion_downsampling_ratio=diffusion_downsampling_ratio,
+ bottleneck=bottleneck,
+ pretransform=pretransform
+ )
diff --git a/ThinkSound/models/blocks.py b/ThinkSound/models/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c827fd2441e643717d123847236d3d6c003ef4f
--- /dev/null
+++ b/ThinkSound/models/blocks.py
@@ -0,0 +1,339 @@
+from functools import reduce
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from torch.backends.cuda import sdp_kernel
+from packaging import version
+
+from dac.nn.layers import Snake1d
+
+class ResidualBlock(nn.Module):
+ def __init__(self, main, skip=None):
+ super().__init__()
+ self.main = nn.Sequential(*main)
+ self.skip = skip if skip else nn.Identity()
+
+ def forward(self, input):
+ return self.main(input) + self.skip(input)
+
+class ResConvBlock(ResidualBlock):
+ def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
+ skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
+ super().__init__([
+ nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
+ nn.GroupNorm(1, c_mid),
+ Snake1d(c_mid) if use_snake else nn.GELU(),
+ nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
+ nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
+ (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
+ ], skip)
+
+class SelfAttention1d(nn.Module):
+ def __init__(self, c_in, n_head=1, dropout_rate=0.):
+ super().__init__()
+ assert c_in % n_head == 0
+ self.norm = nn.GroupNorm(1, c_in)
+ self.n_head = n_head
+ self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
+ self.out_proj = nn.Conv1d(c_in, c_in, 1)
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
+
+ self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
+
+ if not self.use_flash:
+ return
+
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
+
+ if device_properties.major == 8 and device_properties.minor == 0:
+ # Use flash attention for A100 GPUs
+ self.sdp_kernel_config = (True, False, False)
+ else:
+ # Don't use flash attention for other GPUs
+ self.sdp_kernel_config = (False, True, True)
+
+ def forward(self, input):
+ n, c, s = input.shape
+ qkv = self.qkv_proj(self.norm(input))
+ qkv = qkv.view(
+ [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = k.shape[3]**-0.25
+
+ if self.use_flash:
+ with sdp_kernel(*self.sdp_kernel_config):
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
+ else:
+ att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
+ y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
+
+
+ return input + self.dropout(self.out_proj(y))
+
+class SkipBlock(nn.Module):
+ def __init__(self, *main):
+ super().__init__()
+ self.main = nn.Sequential(*main)
+
+ def forward(self, input):
+ return torch.cat([self.main(input), input], dim=1)
+
+class FourierFeatures(nn.Module):
+ def __init__(self, in_features, out_features, std=1.):
+ super().__init__()
+ assert out_features % 2 == 0
+ self.weight = nn.Parameter(torch.randn(
+ [out_features // 2, in_features]) * std)
+
+ def forward(self, input):
+ f = 2 * math.pi * input @ self.weight.T
+ return torch.cat([f.cos(), f.sin()], dim=-1)
+
+def expand_to_planes(input, shape):
+ return input[..., None].repeat([1, 1, shape[2]])
+
+_kernels = {
+ 'linear':
+ [1 / 8, 3 / 8, 3 / 8, 1 / 8],
+ 'cubic':
+ [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
+ 0.43359375, 0.11328125, -0.03515625, -0.01171875],
+ 'lanczos3':
+ [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
+ -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
+ 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
+ -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
+}
+
+class Downsample1d(nn.Module):
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = torch.tensor(_kernels[kernel])
+ self.pad = kernel_1d.shape[0] // 2 - 1
+ self.register_buffer('kernel', kernel_1d)
+ self.channels_last = channels_last
+
+ def forward(self, x):
+ if self.channels_last:
+ x = x.permute(0, 2, 1)
+ x = F.pad(x, (self.pad,) * 2, self.pad_mode)
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
+ indices = torch.arange(x.shape[1], device=x.device)
+ weight[indices, indices] = self.kernel.to(weight)
+ x = F.conv1d(x, weight, stride=2)
+ if self.channels_last:
+ x = x.permute(0, 2, 1)
+ return x
+
+
+class Upsample1d(nn.Module):
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
+ self.pad = kernel_1d.shape[0] // 2 - 1
+ self.register_buffer('kernel', kernel_1d)
+ self.channels_last = channels_last
+
+ def forward(self, x):
+ if self.channels_last:
+ x = x.permute(0, 2, 1)
+ x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
+ indices = torch.arange(x.shape[1], device=x.device)
+ weight[indices, indices] = self.kernel.to(weight)
+ x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
+ if self.channels_last:
+ x = x.permute(0, 2, 1)
+ return x
+
+def Downsample1d_2(
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
+) -> nn.Module:
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
+
+ return nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=factor * kernel_multiplier + 1,
+ stride=factor,
+ padding=factor * (kernel_multiplier // 2),
+ )
+
+
+def Upsample1d_2(
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
+) -> nn.Module:
+
+ if factor == 1:
+ return nn.Conv1d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
+ )
+
+ if use_nearest:
+ return nn.Sequential(
+ nn.Upsample(scale_factor=factor, mode="nearest"),
+ nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ padding=1,
+ ),
+ )
+ else:
+ return nn.ConvTranspose1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=factor * 2,
+ stride=factor,
+ padding=factor // 2 + factor % 2,
+ output_padding=factor % 2,
+ )
+
+def zero_init(layer):
+ nn.init.zeros_(layer.weight)
+ if layer.bias is not None:
+ nn.init.zeros_(layer.bias)
+ return layer
+
+def rms_norm(x, scale, eps):
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
+ return x * scale.to(x.dtype)
+
+#rms_norm = torch.compile(rms_norm)
+
+class AdaRMSNorm(nn.Module):
+ def __init__(self, features, cond_features, eps=1e-6):
+ super().__init__()
+ self.eps = eps
+ self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
+
+ def extra_repr(self):
+ return f"eps={self.eps},"
+
+ def forward(self, x, cond):
+ return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
+
+def normalize(x, eps=1e-4):
+ dim = list(range(1, x.ndim))
+ n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
+ alpha = np.sqrt(n.numel() / x.numel())
+ return x / torch.add(eps, n, alpha=alpha)
+
+class ForcedWNConv1d(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
+
+ def forward(self, x):
+ if self.training:
+ with torch.no_grad():
+ self.weight.copy_(normalize(self.weight))
+
+ fan_in = self.weight[0].numel()
+
+ w = normalize(self.weight) / math.sqrt(fan_in)
+
+ return F.conv1d(x, w, padding='same')
+
+# Kernels
+
+use_compile = True
+
+def compile(function, *args, **kwargs):
+ if not use_compile:
+ return function
+ try:
+ return torch.compile(function, *args, **kwargs)
+ except RuntimeError:
+ return function
+
+
+@compile
+def linear_geglu(x, weight, bias=None):
+ x = x @ weight.mT
+ if bias is not None:
+ x = x + bias
+ x, gate = x.chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+@compile
+def rms_norm(x, scale, eps):
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
+ return x * scale.to(x.dtype)
+
+# Layers
+
+class LinearGEGLU(nn.Linear):
+ def __init__(self, in_features, out_features, bias=True):
+ super().__init__(in_features, out_features * 2, bias=bias)
+ self.out_features = out_features
+
+ def forward(self, x):
+ return linear_geglu(x, self.weight, self.bias)
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, shape, fix_scale = False, eps=1e-6):
+ super().__init__()
+ self.eps = eps
+
+ if fix_scale:
+ self.register_buffer("scale", torch.ones(shape))
+ else:
+ self.scale = nn.Parameter(torch.ones(shape))
+
+ def extra_repr(self):
+ return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
+
+ def forward(self, x):
+ return rms_norm(x, self.scale, self.eps)
+
+def snake_beta(x, alpha, beta):
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
+
+# try:
+# snake_beta = torch.compile(snake_beta)
+# except RuntimeError:
+# pass
+
+# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
+# License available in LICENSES/LICENSE_NVIDIA.txt
+class SnakeBeta(nn.Module):
+
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
+ super(SnakeBeta, self).__init__()
+ self.in_features = in_features
+
+ # initialize alpha
+ self.alpha_logscale = alpha_logscale
+ if self.alpha_logscale: # log scale alphas initialized to zeros
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
+ else: # linear scale alphas initialized to ones
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
+
+ self.alpha.requires_grad = alpha_trainable
+ self.beta.requires_grad = alpha_trainable
+
+ self.no_div_by_zero = 0.000000001
+
+ def forward(self, x):
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
+ if self.alpha_logscale:
+ alpha = torch.exp(alpha)
+ beta = torch.exp(beta)
+ x = snake_beta(x, alpha, beta)
+
+ return x
\ No newline at end of file
diff --git a/ThinkSound/models/bottleneck.py b/ThinkSound/models/bottleneck.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e81cab4bfb16b615ee21d5e9248e3b455f7eb5b
--- /dev/null
+++ b/ThinkSound/models/bottleneck.py
@@ -0,0 +1,355 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from einops import rearrange
+from vector_quantize_pytorch import ResidualVQ, FSQ
+from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
+
+class Bottleneck(nn.Module):
+ def __init__(self, is_discrete: bool = False):
+ super().__init__()
+
+ self.is_discrete = is_discrete
+
+ def encode(self, x, return_info=False, **kwargs):
+ raise NotImplementedError
+
+ def decode(self, x):
+ raise NotImplementedError
+
+class DiscreteBottleneck(Bottleneck):
+ def __init__(self, num_quantizers, codebook_size, tokens_id):
+ super().__init__(is_discrete=True)
+
+ self.num_quantizers = num_quantizers
+ self.codebook_size = codebook_size
+ self.tokens_id = tokens_id
+
+ def decode_tokens(self, codes, **kwargs):
+ raise NotImplementedError
+
+class TanhBottleneck(Bottleneck):
+ def __init__(self):
+ super().__init__(is_discrete=False)
+ self.tanh = nn.Tanh()
+
+ def encode(self, x, return_info=False):
+ info = {}
+
+ x = torch.tanh(x)
+
+ if return_info:
+ return x, info
+ else:
+ return x
+
+ def decode(self, x):
+ return x
+
+def vae_sample(mean, scale):
+ stdev = nn.functional.softplus(scale) + 1e-4
+ var = stdev * stdev
+ logvar = torch.log(var)
+ latents = torch.randn_like(mean) * stdev + mean
+
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
+
+ return latents, kl
+
+class VAEBottleneck(Bottleneck):
+ def __init__(self):
+ super().__init__(is_discrete=False)
+
+ def encode(self, x, return_info=False, **kwargs):
+ info = {}
+
+ mean, scale = x.chunk(2, dim=1)
+
+ x, kl = vae_sample(mean, scale)
+
+ info["kl"] = kl
+
+ if return_info:
+ return x, info
+ else:
+ return x
+
+ def decode(self, x):
+ return x
+
+def compute_mean_kernel(x, y):
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
+ return torch.exp(-kernel_input).mean()
+
+def compute_mmd(latents):
+ latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
+ noise = torch.randn_like(latents_reshaped)
+
+ latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
+ noise_kernel = compute_mean_kernel(noise, noise)
+ latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
+
+ mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
+ return mmd.mean()
+
+class WassersteinBottleneck(Bottleneck):
+ def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
+ super().__init__(is_discrete=False)
+
+ self.noise_augment_dim = noise_augment_dim
+ self.bypass_mmd = bypass_mmd
+
+ def encode(self, x, return_info=False):
+ info = {}
+
+ if self.training and return_info:
+ if self.bypass_mmd:
+ mmd = torch.tensor(0.0)
+ else:
+ mmd = compute_mmd(x)
+
+ info["mmd"] = mmd
+
+ if return_info:
+ return x, info
+
+ return x
+
+ def decode(self, x):
+
+ if self.noise_augment_dim > 0:
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
+ x.shape[-1]).type_as(x)
+ x = torch.cat([x, noise], dim=1)
+
+ return x
+
+class L2Bottleneck(Bottleneck):
+ def __init__(self):
+ super().__init__(is_discrete=False)
+
+ def encode(self, x, return_info=False):
+ info = {}
+
+ x = F.normalize(x, dim=1)
+
+ if return_info:
+ return x, info
+ else:
+ return x
+
+ def decode(self, x):
+ return F.normalize(x, dim=1)
+
+class RVQBottleneck(DiscreteBottleneck):
+ def __init__(self, **quantizer_kwargs):
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
+
+ def encode(self, x, return_info=False, **kwargs):
+ info = {}
+
+ x = rearrange(x, "b c n -> b n c")
+ x, indices, loss = self.quantizer(x)
+ x = rearrange(x, "b n c -> b c n")
+
+ info["quantizer_indices"] = indices
+ info["quantizer_loss"] = loss.mean()
+
+ if return_info:
+ return x, info
+ else:
+ return x
+
+ def decode(self, x):
+ return x
+
+ def decode_tokens(self, codes, **kwargs):
+ latents = self.quantizer.get_outputs_from_indices(codes)
+
+ return self.decode(latents, **kwargs)
+
+class RVQVAEBottleneck(DiscreteBottleneck):
+ def __init__(self, **quantizer_kwargs):
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
+
+ def encode(self, x, return_info=False):
+ info = {}
+
+ x, kl = vae_sample(*x.chunk(2, dim=1))
+
+ info["kl"] = kl
+
+ x = rearrange(x, "b c n -> b n c")
+ x, indices, loss = self.quantizer(x)
+ x = rearrange(x, "b n c -> b c n")
+
+ info["quantizer_indices"] = indices
+ info["quantizer_loss"] = loss.mean()
+
+ if return_info:
+ return x, info
+ else:
+ return x
+
+ def decode(self, x):
+ return x
+
+ def decode_tokens(self, codes, **kwargs):
+ latents = self.quantizer.get_outputs_from_indices(codes)
+
+ return self.decode(latents, **kwargs)
+
+class DACRVQBottleneck(DiscreteBottleneck):
+ def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
+ self.quantize_on_decode = quantize_on_decode
+ self.noise_augment_dim = noise_augment_dim
+
+ def encode(self, x, return_info=False, **kwargs):
+ info = {}
+
+ info["pre_quantizer"] = x
+
+ if self.quantize_on_decode:
+ return x, info if return_info else x
+
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
+
+ output = {
+ "z": z,
+ "codes": codes,
+ "latents": latents,
+ "vq/commitment_loss": commitment_loss,
+ "vq/codebook_loss": codebook_loss,
+ }
+
+ output["vq/commitment_loss"] /= self.num_quantizers
+ output["vq/codebook_loss"] /= self.num_quantizers
+
+ info.update(output)
+
+ if return_info:
+ return output["z"], info
+
+ return output["z"]
+
+ def decode(self, x):
+
+ if self.quantize_on_decode:
+ x = self.quantizer(x)[0]
+
+ if self.noise_augment_dim > 0:
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
+ x.shape[-1]).type_as(x)
+ x = torch.cat([x, noise], dim=1)
+
+ return x
+
+ def decode_tokens(self, codes, **kwargs):
+ latents, _, _ = self.quantizer.from_codes(codes)
+
+ return self.decode(latents, **kwargs)
+
+class DACRVQVAEBottleneck(DiscreteBottleneck):
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
+ self.quantize_on_decode = quantize_on_decode
+
+ def encode(self, x, return_info=False, n_quantizers: int = None):
+ info = {}
+
+ mean, scale = x.chunk(2, dim=1)
+
+ x, kl = vae_sample(mean, scale)
+
+ info["pre_quantizer"] = x
+ info["kl"] = kl
+
+ if self.quantize_on_decode:
+ return x, info if return_info else x
+
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
+
+ output = {
+ "z": z,
+ "codes": codes,
+ "latents": latents,
+ "vq/commitment_loss": commitment_loss,
+ "vq/codebook_loss": codebook_loss,
+ }
+
+ output["vq/commitment_loss"] /= self.num_quantizers
+ output["vq/codebook_loss"] /= self.num_quantizers
+
+ info.update(output)
+
+ if return_info:
+ return output["z"], info
+
+ return output["z"]
+
+ def decode(self, x):
+
+ if self.quantize_on_decode:
+ x = self.quantizer(x)[0]
+
+ return x
+
+ def decode_tokens(self, codes, **kwargs):
+ latents, _, _ = self.quantizer.from_codes(codes)
+
+ return self.decode(latents, **kwargs)
+
+class FSQBottleneck(DiscreteBottleneck):
+ def __init__(self, noise_augment_dim=0, **kwargs):
+ super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices")
+
+ self.noise_augment_dim = noise_augment_dim
+
+ self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])
+
+ def encode(self, x, return_info=False):
+ info = {}
+
+ orig_dtype = x.dtype
+ x = x.float()
+
+ x = rearrange(x, "b c n -> b n c")
+ x, indices = self.quantizer(x)
+ x = rearrange(x, "b n c -> b c n")
+
+ x = x.to(orig_dtype)
+
+ # Reorder indices to match the expected format
+ indices = rearrange(indices, "b n q -> b q n")
+
+ info["quantizer_indices"] = indices
+
+ if return_info:
+ return x, info
+ else:
+ return x
+
+ def decode(self, x):
+
+ if self.noise_augment_dim > 0:
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
+ x.shape[-1]).type_as(x)
+ x = torch.cat([x, noise], dim=1)
+
+ return x
+
+ def decode_tokens(self, tokens, **kwargs):
+ latents = self.quantizer.indices_to_codes(tokens)
+
+ return self.decode(latents, **kwargs)
\ No newline at end of file
diff --git a/ThinkSound/models/codebook_patterns.py b/ThinkSound/models/codebook_patterns.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9bd2a9b837bd77cb40f3b500b02ea491dfb9da0
--- /dev/null
+++ b/ThinkSound/models/codebook_patterns.py
@@ -0,0 +1,545 @@
+# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py under MIT License
+# License available in LICENSES/LICENSE_META.txt
+
+from collections import namedtuple
+from dataclasses import dataclass
+from functools import lru_cache
+import logging
+import typing as tp
+
+from abc import ABC, abstractmethod
+import torch
+
+LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
+PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Pattern:
+ """Base implementation of a pattern over a sequence with multiple codebooks.
+
+ The codebook pattern consists in a layout, defining for each sequence step
+ the list of coordinates of each codebook timestep in the resulting interleaved sequence.
+ The first item of the pattern is always an empty list in order to properly insert a special token
+ to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
+ and ``timesteps`` the number of timesteps corresponding to the original sequence.
+
+ The pattern provides convenient methods to build and revert interleaved sequences from it:
+ ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
+ to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
+ K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
+ for the output sequence. The unfilled positions are replaced with a special token and the built sequence
+ is returned along with a mask indicating valid tokens.
+ ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
+ of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
+ to fill and specify invalid positions if needed.
+ See the dedicated methods for more details.
+ """
+ # Pattern layout, for each sequence step, we have a list of coordinates
+ # corresponding to the original codebook timestep and position.
+ # The first list is always an empty list in order to properly insert
+ # a special token to start with.
+ layout: PatternLayout
+ timesteps: int
+ n_q: int
+
+ def __post_init__(self):
+ assert len(self.layout) > 0
+ self._validate_layout()
+ self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
+ self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
+ logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
+
+ def _validate_layout(self):
+ """Runs checks on the layout to ensure a valid pattern is defined.
+ A pattern is considered invalid if:
+ - Multiple timesteps for a same codebook are defined in the same sequence step
+ - The timesteps for a given codebook are not in ascending order as we advance in the sequence
+ (this would mean that we have future timesteps before past timesteps).
+ """
+ q_timesteps = {q: 0 for q in range(self.n_q)}
+ for s, seq_coords in enumerate(self.layout):
+ if len(seq_coords) > 0:
+ qs = set()
+ for coord in seq_coords:
+ qs.add(coord.q)
+ last_q_timestep = q_timesteps[coord.q]
+ assert coord.t >= last_q_timestep, \
+ f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
+ q_timesteps[coord.q] = coord.t
+ # each sequence step contains at max 1 coordinate per codebook
+ assert len(qs) == len(seq_coords), \
+ f"Multiple entries for a same codebook are found at step {s}"
+
+ @property
+ def num_sequence_steps(self):
+ return len(self.layout) - 1
+
+ @property
+ def max_delay(self):
+ max_t_in_seq_coords = 0
+ for seq_coords in self.layout[1:]:
+ for coords in seq_coords:
+ max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
+ return max_t_in_seq_coords - self.timesteps
+
+ @property
+ def valid_layout(self):
+ valid_step = len(self.layout) - self.max_delay
+ return self.layout[:valid_step]
+
+ def starts_with_special_token(self):
+ return self.layout[0] == []
+
+ def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
+ """Get codebook coordinates in the layout that corresponds to the specified timestep t
+ and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
+ and the actual codebook coordinates.
+ """
+ assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
+ if q is not None:
+ assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
+ coords = []
+ for s, seq_codes in enumerate(self.layout):
+ for code in seq_codes:
+ if code.t == t and (q is None or code.q == q):
+ coords.append((s, code))
+ return coords
+
+ def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
+ return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
+
+ def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
+ steps_with_timesteps = self.get_steps_with_timestep(t, q)
+ return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
+
+ def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
+ device: tp.Union[torch.device, str] = 'cpu'):
+ """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
+
+ Args:
+ timesteps (int): Maximum number of timesteps steps to consider.
+ keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
+ device (torch.device or str): Device for created tensors.
+ Returns:
+ indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
+ """
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+ assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
+ # use the proper layout based on whether we limit ourselves to valid steps only or not,
+ # note that using the valid_layout will result in a truncated sequence up to the valid steps
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+ indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
+ mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
+ # fill indexes with last sequence step value that will correspond to our special token
+ # the last value is n_q * timesteps as we have flattened z and append special token as the last token
+ # which will correspond to the index: n_q * timesteps
+ indexes[:] = n_q * timesteps
+ # iterate over the pattern and fill scattered indexes and mask
+ for s, sequence_coords in enumerate(ref_layout):
+ for coords in sequence_coords:
+ if coords.t < timesteps:
+ indexes[coords.q, s] = coords.t + coords.q * timesteps
+ mask[coords.q, s] = 1
+ indexes = torch.from_numpy(indexes).to(device)
+ mask = torch.from_numpy(mask).to(device)
+ return indexes, mask
+
+ def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+ """Build sequence corresponding to the pattern from the input tensor z.
+ The sequence is built using up to sequence_steps if specified, and non-pattern
+ coordinates are filled with the special token.
+
+ Args:
+ z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
+ special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
+ Returns:
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
+ corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
+ """
+ B, K, T = z.shape
+ indexes, mask = self._build_pattern_sequence_scatter_indexes(
+ T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
+ )
+ z = z.view(B, -1)
+ # we append the special token as the last index of our flattened z tensor
+ z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
+ values = z[:, indexes.view(-1)]
+ values = values.view(B, K, indexes.shape[-1])
+ return values, indexes, mask
+
+ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
+ keep_only_valid_steps: bool = False,
+ is_model_output: bool = False,
+ device: tp.Union[torch.device, str] = 'cpu'):
+ """Builds scatter indexes required to retrieve the original multi-codebook sequence
+ from interleaving pattern.
+
+ Args:
+ sequence_steps (int): Sequence steps.
+ n_q (int): Number of codebooks.
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
+ is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
+ device (torch.device or str): Device for created tensors.
+ Returns:
+ indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+ """
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+ # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
+ timesteps = self.timesteps
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+ assert sequence_steps <= len(ref_layout), \
+ f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
+
+ # ensure we take the appropriate indexes to keep the model output from the first special token as well
+ if is_model_output and self.starts_with_special_token():
+ ref_layout = ref_layout[1:]
+
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+ indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
+ mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
+ # fill indexes with last sequence step value that will correspond to our special token
+ indexes[:] = n_q * sequence_steps
+ for s, sequence_codes in enumerate(ref_layout):
+ if s < sequence_steps:
+ for code in sequence_codes:
+ if code.t < timesteps:
+ indexes[code.q, code.t] = s + code.q * sequence_steps
+ mask[code.q, code.t] = 1
+ indexes = torch.from_numpy(indexes).to(device)
+ mask = torch.from_numpy(mask).to(device)
+ return indexes, mask
+
+ def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+ """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
+ The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
+ are filled with the special token.
+
+ Args:
+ s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
+ Returns:
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
+ corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+ """
+ B, K, S = s.shape
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
+ S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
+ )
+ s = s.view(B, -1)
+ # we append the special token as the last index of our flattened z tensor
+ s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
+ values = s[:, indexes.view(-1)]
+ values = values.view(B, K, indexes.shape[-1])
+ return values, indexes, mask
+
+ def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
+ """Revert model logits obtained on a sequence built from the pattern
+ back to a tensor matching the original sequence.
+
+ This method is similar to ``revert_pattern_sequence`` with the following specificities:
+ 1. It is designed to work with the extra cardinality dimension
+ 2. We return the logits for the first sequence item that matches the special_token and
+ which matching target in the original sequence is the first item of the sequence,
+ while we skip the last logits as there is no matching target
+ """
+ B, card, K, S = logits.shape
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
+ S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
+ )
+ logits = logits.reshape(B, card, -1)
+ # we append the special token as the last index of our flattened z tensor
+ logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
+ values = logits[:, :, indexes.view(-1)]
+ values = values.view(B, card, K, indexes.shape[-1])
+ return values, indexes, mask
+
+
+class CodebooksPatternProvider(ABC):
+ """Abstraction around providing pattern for interleaving codebooks.
+
+ The CodebooksPatternProvider abstraction allows to implement various strategies to
+ define interleaving pattern of sequences composed of multiple codebooks. For a given
+ number of codebooks `n_q`, the pattern provider can generate a specified pattern
+ corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
+ can be used to construct a new sequence from the original codes respecting the specified
+ pattern. The pattern is defined as a list of list of code coordinates, code coordinate
+ being a tuple with the original timestep and codebook to build the new sequence.
+ Note that all patterns must start with an empty list that is then used to insert a first
+ sequence step of special tokens in the newly generated sequence.
+
+ Args:
+ n_q (int): number of codebooks.
+ cached (bool): if True, patterns for a given length are cached. In general
+ that should be true for efficiency reason to avoid synchronization points.
+ """
+ def __init__(self, n_q: int, cached: bool = True):
+ assert n_q > 0
+ self.n_q = n_q
+ self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
+
+ @abstractmethod
+ def get_pattern(self, timesteps: int) -> Pattern:
+ """Builds pattern with specific interleaving between codebooks.
+
+ Args:
+ timesteps (int): Total number of timesteps.
+ """
+ raise NotImplementedError()
+
+
+class DelayedPatternProvider(CodebooksPatternProvider):
+ """Provider for delayed pattern across delayed codebooks.
+ Codebooks are delayed in the sequence and sequence steps will contain codebooks
+ from different timesteps.
+
+ Example:
+ Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ The resulting sequence obtained from the returned pattern is:
+ [[S, 1, 2, 3, 4],
+ [S, S, 1, 2, 3],
+ [S, S, S, 1, 2]]
+ (with S being a special token)
+
+ Args:
+ n_q (int): Number of codebooks.
+ delays (list of int, optional): Delay for each of the codebooks.
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
+ flatten_first (int): Flatten the first N timesteps.
+ empty_initial (int): Prepend with N empty list of coordinates.
+ """
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
+ flatten_first: int = 0, empty_initial: int = 0):
+ super().__init__(n_q)
+ if delays is None:
+ delays = list(range(n_q))
+ self.delays = delays
+ self.flatten_first = flatten_first
+ self.empty_initial = empty_initial
+ assert len(self.delays) == self.n_q
+ assert sorted(self.delays) == self.delays
+
+ def get_pattern(self, timesteps: int) -> Pattern:
+ omit_special_token = self.empty_initial < 0
+ out: PatternLayout = [] if omit_special_token else [[]]
+ max_delay = max(self.delays)
+ if self.empty_initial:
+ out += [[] for _ in range(self.empty_initial)]
+ if self.flatten_first:
+ for t in range(min(timesteps, self.flatten_first)):
+ for q in range(self.n_q):
+ out.append([LayoutCoord(t, q)])
+ for t in range(self.flatten_first, timesteps + max_delay):
+ v = []
+ for q, delay in enumerate(self.delays):
+ t_for_q = t - delay
+ if t_for_q >= self.flatten_first:
+ v.append(LayoutCoord(t_for_q, q))
+ out.append(v)
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class ParallelPatternProvider(DelayedPatternProvider):
+ """Provider for parallel pattern across codebooks.
+ This pattern provider is a special case of the delayed pattern with actually no delay,
+ hence delays=repeat(0, n_q).
+
+ Args:
+ n_q (int): Number of codebooks.
+ empty_initial (int): Prepend with N empty list of coordinates.
+ """
+ def __init__(self, n_q: int, empty_initial: int = 0):
+ super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)
+
+
+class UnrolledPatternProvider(CodebooksPatternProvider):
+ """Provider for unrolling codebooks pattern.
+ This pattern provider enables to represent the codebook flattened completely or only to some extend
+ while also specifying a given delay between the flattened codebooks representation, allowing to
+ unroll the codebooks in the sequence.
+
+ Example:
+ 1. Flattening of the codebooks.
+ By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
+ taking n_q = 3 and timesteps = 4:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ will result into:
+ [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+ 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
+ for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
+ taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ will result into:
+ [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+ 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
+ allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
+ same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
+ and delays = [0, 3, 3]:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ will result into:
+ [[S, S, S, 1, S, 2, S, 3, S, 4],
+ [S, S, S, 1, S, 2, S, 3, S, 4],
+ [1, 2, 3, S, 4, S, 5, S, 6, S]]
+
+ Args:
+ n_q (int): Number of codebooks.
+ flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
+ the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
+ have n_q extra steps for each timestep.
+ delays (list of int, optional): Delay for each of the codebooks. If not defined,
+ no delay is added and therefore will default to [0] * ``n_q``.
+ Note that two codebooks that will be flattened to the same inner step
+ should have the same delay, otherwise the pattern is considered as invalid.
+ """
+ FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
+
+ def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
+ delays: tp.Optional[tp.List[int]] = None):
+ super().__init__(n_q)
+ if flattening is None:
+ flattening = list(range(n_q))
+ if delays is None:
+ delays = [0] * n_q
+ assert len(flattening) == n_q
+ assert len(delays) == n_q
+ assert sorted(flattening) == flattening
+ assert sorted(delays) == delays
+ self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
+ self.max_delay = max(delays)
+
+ def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
+ """Build a flattened codebooks representation as a dictionary of inner step
+ and the actual codebook indices corresponding to the flattened codebook. For convenience, we
+ also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
+ """
+ flattened_codebooks: dict = {}
+ for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
+ if inner_step not in flattened_codebooks:
+ flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
+ else:
+ flat_codebook = flattened_codebooks[inner_step]
+ assert flat_codebook.delay == delay, (
+ "Delay and flattening between codebooks is inconsistent: ",
+ "two codebooks flattened to the same position should have the same delay."
+ )
+ flat_codebook.codebooks.append(q)
+ flattened_codebooks[inner_step] = flat_codebook
+ return flattened_codebooks
+
+ @property
+ def _num_inner_steps(self):
+ """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
+ """
+ return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
+
+ def num_virtual_steps(self, timesteps: int) -> int:
+ return timesteps * self._num_inner_steps + 1
+
+ def get_pattern(self, timesteps: int) -> Pattern:
+ """Builds pattern for delay across codebooks.
+
+ Args:
+ timesteps (int): Total number of timesteps.
+ """
+ # the PatternLayout is built as a tuple of sequence position and list of coordinates
+ # so that it can be reordered properly given the required delay between codebooks of given timesteps
+ indexed_out: list = [(-1, [])]
+ max_timesteps = timesteps + self.max_delay
+ for t in range(max_timesteps):
+ # for each timestep, we unroll the flattened codebooks,
+ # emitting the sequence step with the corresponding delay
+ for step in range(self._num_inner_steps):
+ if step in self._flattened_codebooks:
+ # we have codebooks at this virtual step to emit
+ step_codebooks = self._flattened_codebooks[step]
+ t_for_q = t + step_codebooks.delay
+ coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
+ if t_for_q < max_timesteps and t < max_timesteps:
+ indexed_out.append((t_for_q, coords))
+ else:
+ # there is no codebook in this virtual step so we emit an empty list
+ indexed_out.append((t, []))
+ out = [coords for _, coords in sorted(indexed_out)]
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class CoarseFirstPattern(CodebooksPatternProvider):
+ """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
+ potentially with delays.
+
+ ..Warning:: You must always generate the full training duration at test time, for instance,
+ 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
+ location. This is due to the non causality of the remaining codebooks with respect to
+ the first ones.
+
+ Args:
+ n_q (int): Number of codebooks.
+ delays (list of int, optional): Delay for each of the codebooks.
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
+ """
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
+ super().__init__(n_q)
+ if delays is None:
+ delays = [0] * (n_q - 1)
+ self.delays = delays
+ assert len(self.delays) == self.n_q - 1
+ assert sorted(self.delays) == self.delays
+
+ def get_pattern(self, timesteps: int) -> Pattern:
+ out: PatternLayout = [[]]
+ for t in range(timesteps):
+ out.append([LayoutCoord(t, 0)])
+ max_delay = max(self.delays)
+ for t in range(timesteps + max_delay):
+ v = []
+ for q, delay in enumerate(self.delays):
+ t_for_q = t - delay
+ if t_for_q >= 0:
+ v.append(LayoutCoord(t_for_q, q + 1))
+ out.append(v)
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class MusicLMPattern(CodebooksPatternProvider):
+ """Almost MusicLM style pattern. This is equivalent to full flattening
+ but in a different order.
+
+ Args:
+ n_q (int): Number of codebooks.
+ group_by (int): Number of codebooks to group together.
+ """
+ def __init__(self, n_q: int, group_by: int = 2):
+ super().__init__(n_q)
+ self.group_by = group_by
+
+ def get_pattern(self, timesteps: int) -> Pattern:
+ out: PatternLayout = [[]]
+ for offset in range(0, self.n_q, self.group_by):
+ for t in range(timesteps):
+ for q in range(offset, offset + self.group_by):
+ out.append([LayoutCoord(t, q)])
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
\ No newline at end of file
diff --git a/ThinkSound/models/conditioners.py b/ThinkSound/models/conditioners.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd4efc657b815092710427ae119f449049b1cf18
--- /dev/null
+++ b/ThinkSound/models/conditioners.py
@@ -0,0 +1,1082 @@
+#Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py
+
+import torch
+import logging, warnings
+import string
+import typing as tp
+import gc
+from typing import Literal, Optional
+import os
+from .adp import NumberEmbedder
+from ..inference.utils import set_audio_channels
+from .factory import create_pretransform_from_config
+from .pretransforms import Pretransform
+from ..training.utils import copy_state_dict
+from .utils import load_ckpt_state_dict
+import numpy as np
+from einops import rearrange
+from transformers import AutoProcessor, AutoModel
+from torch import nn
+import torch.nn.functional as F
+from .mmmodules.model.low_level import ConvMLP, MLP
+from torch.nn.utils.rnn import pad_sequence
+
+class Conditioner(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ output_dim: int,
+ project_out: bool = False
+ ):
+
+ super().__init__()
+
+ self.dim = dim
+ self.output_dim = output_dim
+ self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
+
+ def forward(self, x: tp.Any) -> tp.Any:
+ raise NotImplementedError()
+
+class Cond_MLP(Conditioner):
+ def __init__(self, dim, output_dim, dropout = 0.0):
+ super().__init__(dim, output_dim)
+ self.embedder = nn.Sequential(
+ nn.Linear(dim, output_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(output_dim, output_dim, bias=False)
+ )
+ self.dropout = dropout
+ def forward(self, x, device: tp.Any = "cuda"):
+ x = pad_sequence(x, batch_first=True).to(device)
+ # x = torch.stack(x, dim=0).to(device)
+
+ if self.dropout > 0.0:
+ if self.training:
+ null_embed = torch.zeros_like(x, device=device)
+ dropout_mask = torch.bernoulli(torch.full((x.shape[0], 1, 1), self.dropout, device=device)).to(torch.bool)
+ x = torch.where(dropout_mask, null_embed, x)
+ elif x.shape[0] < 16: # default test batch size=1
+ null_embed = torch.zeros_like(x, device=device)
+ x = torch.cat([x, null_embed], dim=0)
+
+ x = self.embedder(x) # B x 117 x C
+ return [x, torch.ones(x.shape[0], 1).to(device)]
+
+class Global_MLP(Conditioner):
+ def __init__(self, dim, output_dim):
+ super().__init__(dim, output_dim)
+ self.embedder = nn.Sequential(
+ nn.Linear(dim, output_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(output_dim, output_dim, bias=False)
+ )
+ def forward(self, x, device: tp.Any = "cuda"):
+ x = torch.stack(x, dim=0).to(device)
+ x = x.mean(dim=1)
+ x = self.embedder(x) # B x 117 x C
+ return [x, torch.ones(x.shape[0], 1).to(device)]
+
+class Cond_MLP_1(Conditioner):
+ def __init__(self, dim, output_dim):
+ super().__init__(dim, output_dim)
+ self.embedder = nn.Sequential(
+ nn.Linear(dim, output_dim),
+ nn.SiLU(),
+ MLP(output_dim, output_dim * 4),
+ )
+ def forward(self, x, device: tp.Any = "cuda"):
+ x = torch.stack(x, dim=0).to(device)
+
+ x = self.embedder(x) # B x 117 x C
+ return [x, torch.ones(x.shape[0], 1).to(device)]
+
+class Cond_MLP_Global(Conditioner):
+ def __init__(self, dim, output_dim, dropout = 0.0):
+ super().__init__(dim, output_dim)
+ self.embedder = nn.Sequential(
+ nn.Linear(dim, output_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(output_dim, output_dim, bias=False)
+ )
+ self.global_embedder = nn.Sequential(
+ nn.Linear(output_dim, output_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(output_dim, output_dim, bias=False)
+ )
+ self.dropout = dropout
+ def forward(self, x, device: tp.Any = "cuda"):
+ x = torch.stack(x, dim=0).to(device)
+ if self.dropout > 0 and self.training:
+ null_embed = torch.zeros_like(x, device=device)
+ dropout_mask = torch.bernoulli(torch.full((x.shape[0], 1, 1), self.dropout, device=device)).to(torch.bool)
+ x = torch.where(dropout_mask, null_embed, x)
+ x = self.embedder(x) # B x 117 x C
+ global_x = self.global_embedder(x[:,0,:])
+ return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)]
+
+class Cond_MLP_Global_1(Conditioner):
+ def __init__(self, dim, output_dim):
+ super().__init__(dim, output_dim)
+ self.embedder = nn.Sequential(
+ nn.Linear(dim, output_dim),
+ nn.SiLU(),
+ MLP(output_dim, output_dim * 4),
+ )
+ self.global_embedder = nn.Sequential(
+ nn.Linear(dim, output_dim),
+ MLP(output_dim, output_dim * 4),
+ )
+ def forward(self, x, device: tp.Any = "cuda"):
+ x = torch.stack(x, dim=0).to(device)
+
+ x = self.embedder(x) # B x 117 x C
+ global_x = self.global_embedder(x.mean(dim=1))
+ return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)]
+
+class Cond_MLP_Global_2(Conditioner):
+ def __init__(self, dim, output_dim):
+ super().__init__(dim, output_dim)
+ self.embedder = nn.Sequential(
+ nn.Linear(dim, output_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(output_dim, output_dim, bias=False)
+ )
+ self.global_embedder = nn.Sequential(
+ nn.Linear(output_dim, output_dim, bias=False),
+ )
+ def forward(self, x, device: tp.Any = "cuda"):
+ x = torch.stack(x, dim=0).to(device)
+
+ x = self.embedder(x) # B x 117 x C
+ global_x = self.global_embedder(x.mean(dim=1))
+ return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)]
+
+class Sync_MLP(Conditioner):
+ def __init__(self, dim, output_dim):
+ super().__init__(dim, output_dim)
+ self.embedder = nn.Sequential(
+ nn.Linear(dim, output_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(output_dim, output_dim, bias=False)
+ )
+ self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, dim)))
+ nn.init.constant_(self.sync_pos_emb, 0)
+ def forward(self, x, device: tp.Any = "cuda"):
+ sync_f = torch.stack(x, dim=0).to(device)
+ bs, length, dim = sync_f.shape
+ #print(sync_f.shape,flush=True)
+ # B * num_segments (24) * 8 * 768
+ num_sync_segments = length // 8
+ sync_f = sync_f.view(bs, num_sync_segments, 8, -1) + self.sync_pos_emb
+ sync_f = sync_f.flatten(1, 2) # (B, VN, D)
+ x = self.embedder(sync_f) # B x 117 x C
+ x = x.transpose(1,2)
+ x = F.interpolate(x, ((int)(194*sync_f.shape[1]/216), ), mode='linear', align_corners=False)
+ x = x.transpose(1,2)
+ return [x, torch.ones(x.shape[0], 1).to(device)]
+
+class Cond_ConvMLP(Conditioner):
+ def __init__(self, dim, output_dim):
+ super().__init__(dim, output_dim)
+ self.embedder = nn.Sequential(
+ nn.Linear(dim, output_dim),
+ nn.SiLU(),
+ ConvMLP(output_dim, output_dim * 4, kernel_size=1, padding=0),
+ )
+ def forward(self, x, device: tp.Any = "cuda"):
+ x = torch.stack(x, dim=0).to(device)
+
+ x = self.embedder(x) # B x 117 x C
+ return [x, torch.ones(x.shape[0], 1).to(device)]
+
+class Video_Global(Conditioner):
+ """ Transform the video feat encoder"""
+
+ def __init__(self, dim, output_dim, global_dim=1536):
+ super().__init__(dim, output_dim)
+ self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
+ self.global_proj = nn.Sequential(nn.Linear(output_dim, global_dim))
+
+ def forward(self, x, device: tp.Any = "cuda"):
+ # import ipdb
+ # ipdb.set_trace()
+ if not isinstance(x[0], torch.Tensor):
+ video_feats = []
+ for path in x:
+ if '.npy' in path:
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
+ elif '.pth' in path:
+ data = torch.load(path)
+ video_feats.append(data['metaclip_features'].to(device))
+ else:
+ video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device))
+ x = torch.stack(video_feats, dim=0).to(device)
+ else:
+ # Revise the shape here:
+ x = torch.stack(x, dim=0).to(device)
+
+ x = self.embedder(x) # B x 117 x C
+ global_x = self.global_proj(x.mean(dim=1))
+ return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)]
+
+class Video_Sync(Conditioner):
+ """ Transform the video feat encoder"""
+
+ def __init__(self, dim, output_dim):
+ super().__init__(dim, output_dim)
+ self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
+
+ def forward(self, x, device: tp.Any = "cuda"):
+ # import ipdb
+ # ipdb.set_trace()
+ if not isinstance(x[0], torch.Tensor):
+ video_feats = []
+ for path in x:
+ if '.npy' in path:
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
+ elif '.pth' in path:
+ video_feats.append(torch.load(path)['sync_features'].to(device))
+ else:
+ video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device))
+ x = torch.stack(video_feats, dim=0).to(device)
+ else:
+ # Revise the shape here:
+ x = torch.stack(x, dim=0).to(device)
+
+ x = self.embedder(x) # B x 117 x C
+ return [x, torch.ones(x.shape[0], 1).to(device)]
+
+class Text_Linear(Conditioner):
+ """ Transform the video feat encoder"""
+
+ def __init__(self, dim, output_dim):
+ super().__init__(dim, output_dim)
+ self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
+
+ def forward(self, x, device: tp.Any = "cuda"):
+ # import ipdb
+ # ipdb.set_trace()
+ if not isinstance(x[0], torch.Tensor):
+ video_feats = []
+ for path in x:
+ if '.npy' in path:
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
+ elif '.pth' in path:
+ video_feats.append(torch.load(path)['metaclip_text_features'].to(device))
+ else:
+ video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device))
+ x = torch.stack(video_feats, dim=0).to(device)
+ else:
+ # Revise the shape here:
+ x = torch.stack(x, dim=0).to(device)
+
+ x = self.embedder(x) # B x 117 x C
+ return [x, torch.ones(x.shape[0], 1).to(device)]
+
+class mm_unchang(Conditioner):
+ """ Transform the video feat encoder"""
+
+ def __init__(self, dim, output_dim):
+ super().__init__(dim, output_dim)
+
+ def forward(self, x, device: tp.Any = "cuda"):
+ # import ipdb
+ # ipdb.set_trace()
+ if not isinstance(x[0], torch.Tensor):
+ video_feats = []
+ for path in x:
+ if '.npy' in path:
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
+ elif '.pth' in path:
+ video_feats.append(torch.load(path)['metaclip_features'].to(device))
+ else:
+ video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device))
+ x = torch.stack(video_feats, dim=0).to(device)
+ else:
+ # Revise the shape here:
+ x = torch.stack(x, dim=0).to(device)
+ return [x]
+
+class CLIPConditioner(Conditioner):
+
+ CLIP_MODELS = ["metaclip-base", "metaclip-b16", "metaclip-large", "metaclip-huge"]
+
+ CLIP_MODEL_DIMS = {
+ "metaclip-base": 512,
+ "metaclip-b16": 512,
+ "metaclip-large": 768,
+ "metaclip-huge": 1024,
+ }
+
+ def __init__(
+ self,
+ dim: int,
+ output_dim: int,
+ clip_model_name: str = "metaclip-huge",
+ enable_grad: bool = False,
+ project_out: bool = False
+ ):
+ assert clip_model_name in self.CLIP_MODELS, f"Unknown CLIP model name: {clip_model_name}"
+ super().__init__(self.CLIP_MODEL_DIMS[clip_model_name], output_dim, project_out=project_out)
+
+ self.enable_grad = enable_grad
+ model = AutoModel.from_pretrained(f"useful_ckpts/{clip_model_name}").train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
+
+
+
+ if self.enable_grad:
+ self.model = model
+ else:
+ self.__dict__["model"] = model
+
+
+ def forward(self, images: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+
+ self.model.to(device)
+ self.proj_out.to(device)
+ # import ipdb
+ # ipdb.set_trace()
+
+ self.model.eval()
+ if not isinstance(images[0], torch.Tensor):
+ video_feats = []
+ for path in images:
+ if '.npy' in path:
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
+ else:
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
+ images = torch.stack(video_feats, dim=0).to(device)
+ else:
+ images = torch.stack(images, dim=0).to(device)
+ bsz, t, c, h, w = images.shape
+ # 使用 rearrange 进行维度合并
+ images = rearrange(images, 'b t c h w -> (b t) c h w')
+ with torch.set_grad_enabled(self.enable_grad):
+ image_features = self.model.get_image_features(images)
+ image_features = rearrange(image_features, '(b t) d -> b t d', b=bsz, t=t)
+ image_features = self.proj_out(image_features)
+
+
+ return [image_features, torch.ones(image_features.shape[0], 1).to(device)]
+
+class IntConditioner(Conditioner):
+ def __init__(self,
+ output_dim: int,
+ min_val: int=0,
+ max_val: int=512
+ ):
+ super().__init__(output_dim, output_dim)
+
+ self.min_val = min_val
+ self.max_val = max_val
+ self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True)
+
+ def forward(self, ints: tp.List[int], device=None) -> tp.Any:
+
+ #self.int_embedder.to(device)
+
+ ints = torch.tensor(ints).to(device)
+ ints = ints.clamp(self.min_val, self.max_val)
+
+ int_embeds = self.int_embedder(ints).unsqueeze(1)
+
+ return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)]
+
+class NumberConditioner(Conditioner):
+ '''
+ Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
+ '''
+ def __init__(self,
+ output_dim: int,
+ min_val: float=0,
+ max_val: float=1
+ ):
+ super().__init__(output_dim, output_dim)
+
+ self.min_val = min_val
+ self.max_val = max_val
+
+ self.embedder = NumberEmbedder(features=output_dim)
+
+ def forward(self, floats: tp.List[float], device=None) -> tp.Any:
+
+ # Cast the inputs to floats
+ floats = [float(x) for x in floats]
+
+ floats = torch.tensor(floats).to(device)
+
+ floats = floats.clamp(self.min_val, self.max_val)
+
+ normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
+
+ # Cast floats to same type as embedder
+ embedder_dtype = next(self.embedder.parameters()).dtype
+ normalized_floats = normalized_floats.to(embedder_dtype)
+
+ float_embeds = self.embedder(normalized_floats).unsqueeze(1)
+
+ return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
+
+class CLAPTextConditioner(Conditioner):
+ def __init__(self,
+ output_dim: int,
+ clap_ckpt_path,
+ use_text_features = False,
+ feature_layer_ix: int = -1,
+ audio_model_type="HTSAT-base",
+ enable_fusion=True,
+ project_out: bool = False,
+ finetune: bool = False):
+ super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out)
+
+ self.use_text_features = use_text_features
+ self.feature_layer_ix = feature_layer_ix
+ self.finetune = finetune
+
+ # Suppress logging from transformers
+ previous_level = logging.root.manager.disable
+ logging.disable(logging.ERROR)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ try:
+ import laion_clap
+ from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
+
+ model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
+
+ if self.finetune:
+ self.model = model
+ else:
+ self.__dict__["model"] = model
+
+ state_dict = clap_load_state_dict(clap_ckpt_path)
+ self.model.model.load_state_dict(state_dict, strict=False)
+
+ if self.finetune:
+ self.model.model.text_branch.requires_grad_(True)
+ self.model.model.text_branch.train()
+ else:
+ self.model.model.text_branch.requires_grad_(False)
+ self.model.model.text_branch.eval()
+
+ finally:
+ logging.disable(previous_level)
+
+ del self.model.model.audio_branch
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
+ prompt_tokens = self.model.tokenizer(prompts)
+ attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True)
+ prompt_features = self.model.model.text_branch(
+ input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True),
+ attention_mask=attention_mask,
+ output_hidden_states=True
+ )["hidden_states"][layer_ix]
+
+ return prompt_features, attention_mask
+
+ def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
+ self.model.to(device)
+
+ if self.use_text_features:
+ if len(texts) == 1:
+ text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device)
+ text_features = text_features[:1, ...]
+ text_attention_mask = text_attention_mask[:1, ...]
+ else:
+ text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device)
+ return [self.proj_out(text_features), text_attention_mask]
+
+ # Fix for CLAP bug when only one text is passed
+ if len(texts) == 1:
+ text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...]
+ else:
+ text_embedding = self.model.get_text_embedding(texts, use_tensor=True)
+
+ text_embedding = text_embedding.unsqueeze(1).to(device)
+
+ return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)]
+
+class CLAPAudioConditioner(Conditioner):
+ def __init__(self,
+ output_dim: int,
+ clap_ckpt_path,
+ audio_model_type="HTSAT-base",
+ enable_fusion=True,
+ project_out: bool = False):
+ super().__init__(512, output_dim, project_out=project_out)
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # Suppress logging from transformers
+ previous_level = logging.root.manager.disable
+ logging.disable(logging.ERROR)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ try:
+ import laion_clap
+ from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
+
+ model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
+
+ self.model = model
+
+ state_dict = clap_load_state_dict(clap_ckpt_path)
+ self.model.model.load_state_dict(state_dict, strict=False)
+
+ self.model.model.audio_branch.requires_grad_(False)
+ self.model.model.audio_branch.eval()
+
+ finally:
+ logging.disable(previous_level)
+
+ del self.model.model.text_branch
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any:
+
+ self.model.to(device)
+
+ if isinstance(audios, list) or isinstance(audios, tuple):
+ audios = torch.cat(audios, dim=0)
+
+ # Convert to mono
+ mono_audios = audios.mean(dim=1)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True)
+
+ audio_embedding = audio_embedding.unsqueeze(1).to(device)
+
+ return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)]
+
+class T5Conditioner(Conditioner):
+
+ T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
+ "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
+ "google/flan-t5-xl", "google/flan-t5-xxl", "t5-v1_1-xl", "google/t5-v1_1-xxl"]
+
+ T5_MODEL_DIMS = {
+ "t5-small": 512,
+ "t5-base": 768,
+ "t5-large": 1024,
+ "t5-3b": 1024,
+ "t5-11b": 1024,
+ "t5-v1_1-xl": 2048,
+ "google/t5-v1_1-xxl": 4096,
+ "google/flan-t5-small": 512,
+ "google/flan-t5-base": 768,
+ "google/flan-t5-large": 1024,
+ "google/flan-t5-3b": 1024,
+ "google/flan-t5-11b": 1024,
+ "google/flan-t5-xl": 2048,
+ "google/flan-t5-xxl": 4096,
+ }
+
+ def __init__(
+ self,
+ output_dim: int,
+ t5_model_name: str = "t5-base",
+ max_length: str = 77,
+ enable_grad: bool = False,
+ project_out: bool = False
+ ):
+ assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
+ super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)
+
+ from transformers import T5EncoderModel, AutoTokenizer
+
+ self.max_length = max_length
+ self.enable_grad = enable_grad
+
+ # Suppress logging from transformers
+ previous_level = logging.root.manager.disable
+ logging.disable(logging.ERROR)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ try:
+ # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length)
+ # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad)
+ self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('useful_ckpts', t5_model_name))
+ model = T5EncoderModel.from_pretrained(os.path.join('useful_ckpts', t5_model_name)).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
+ finally:
+ logging.disable(previous_level)
+
+ if self.enable_grad:
+ self.model = model
+ else:
+ self.__dict__["model"] = model
+
+
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+
+ self.model.to(device)
+ self.proj_out.to(device)
+ encoded = self.tokenizer(
+ texts,
+ truncation=True,
+ max_length=self.max_length,
+ padding="max_length",
+ return_tensors="pt",
+ )
+
+ input_ids = encoded["input_ids"].to(device)
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
+
+ self.model.eval()
+
+ with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):
+ embeddings = self.model(
+ input_ids=input_ids, attention_mask=attention_mask
+ )["last_hidden_state"]
+
+ embeddings = self.proj_out(embeddings.float())
+
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
+
+ return embeddings, attention_mask
+
+def patch_clip(clip_model):
+ # a hack to make it output last hidden states
+ # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
+ def new_encode_text(self, text, normalize: bool = False):
+ cast_dtype = self.transformer.get_cast_dtype()
+
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.to(cast_dtype)
+ x = self.transformer(x, attn_mask=self.attn_mask)
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
+ return F.normalize(x, dim=-1) if normalize else x
+
+ clip_model.encode_text = new_encode_text.__get__(clip_model)
+ return clip_model
+
+class CLIPTextConditioner(Conditioner):
+ def __init__(
+ self,
+ output_dim: int,
+ max_length: str = 77,
+ enable_grad: bool = False,
+ project_out: bool = False
+ ):
+ super().__init__(1024, output_dim, project_out=project_out)
+
+ from transformers import T5EncoderModel, AutoTokenizer
+ import open_clip
+ from open_clip import create_model_from_pretrained
+
+ self.max_length = max_length
+ self.enable_grad = enable_grad
+
+ # Suppress logging from transformers
+ previous_level = logging.root.manager.disable
+ logging.disable(logging.ERROR)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ try:
+ model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',cache_dir='useful_ckpts/DFN5B-CLIP-ViT-H-14-384',
+ return_transform=False).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
+ model = patch_clip(model)
+ self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14'
+ finally:
+ logging.disable(previous_level)
+
+ if self.enable_grad:
+ self.model = model
+ else:
+ self.__dict__["model"] = model
+
+
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+
+ self.model.to(device)
+ self.proj_out.to(device)
+
+ encoded = self.tokenizer(
+ texts
+ ).to(device)
+
+ # input_ids = encoded["input_ids"].to(device)
+ # attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
+
+ self.model.eval()
+
+ with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):
+ embeddings = self.model.encode_text(
+ encoded
+ )
+
+ embeddings = self.proj_out(embeddings.float())
+
+ # embeddings = embeddings * attention_mask.unsqueeze(-1).float()
+
+ return embeddings, torch.ones(embeddings.shape[0], 1).to(device)
+
+def patch_clip(clip_model):
+ # a hack to make it output last hidden states
+ # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
+ def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ text_outputs = self.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = text_outputs[0]
+ # pooled_output = text_outputs[1]
+ # text_features = self.text_projection(pooled_output)
+
+ return last_hidden_state
+
+ clip_model.get_text_features = new_get_text_features.__get__(clip_model)
+ return clip_model
+
+class MetaCLIPTextConditioner(Conditioner):
+ def __init__(
+ self,
+ output_dim: int,
+ max_length: str = 77,
+ enable_grad: bool = False,
+ project_out: bool = False
+ ):
+ super().__init__(1024, output_dim, project_out=project_out)
+
+ from transformers import AutoModel
+ from transformers import AutoProcessor
+
+ self.max_length = max_length
+ self.enable_grad = enable_grad
+
+ # Suppress logging from transformers
+ previous_level = logging.root.manager.disable
+ logging.disable(logging.ERROR)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ try:
+ self.model = AutoModel.from_pretrained("useful_ckpts/metaclip-huge")
+ self.model = patch_clip(self.model)
+ self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge")
+ finally:
+ logging.disable(previous_level)
+
+
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+
+ self.model.to(device)
+ self.proj_out.to(device)
+ encoded = self.clip_processor(text=texts, return_tensors="pt", padding=True).to(device)
+
+ # input_ids = encoded["input_ids"].to(device)
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
+
+ self.model.eval()
+
+ with torch.set_grad_enabled(self.enable_grad):
+ embeddings = self.model.get_text_features(
+ **encoded
+ )
+
+ embeddings = self.proj_out(embeddings.float())
+
+ # embeddings = embeddings * attention_mask.unsqueeze(-1).float()
+
+ return embeddings, torch.ones(embeddings.shape[0],1).to(device)
+
+class PhonemeConditioner(Conditioner):
+ """
+ A conditioner that turns text into phonemes and embeds them using a lookup table
+ Only works for English text
+
+ Args:
+ output_dim: the dimension of the output embeddings
+ max_length: the maximum number of phonemes to embed
+ project_out: whether to add another linear projection to the output embeddings
+ """
+
+ def __init__(
+ self,
+ output_dim: int,
+ max_length: int = 1024,
+ project_out: bool = False,
+ ):
+ super().__init__(output_dim, output_dim, project_out=project_out)
+
+ from g2p_en import G2p
+
+ self.max_length = max_length
+
+ self.g2p = G2p()
+
+ # Reserving 0 for padding, 1 for ignored
+ self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim)
+
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+
+ self.phoneme_embedder.to(device)
+ self.proj_out.to(device)
+
+ batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length]
+
+ phoneme_ignore = [" ", *string.punctuation]
+
+ # Remove ignored phonemes and cut to max length
+ batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes]
+
+ # Convert to ids
+ phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes]
+
+ #Pad to match longest and make a mask tensor for the padding
+ longest = max([len(ids) for ids in phoneme_ids])
+ phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids]
+
+ phoneme_ids = torch.tensor(phoneme_ids).to(device)
+
+ # Convert to embeddings
+ phoneme_embeds = self.phoneme_embedder(phoneme_ids)
+
+ phoneme_embeds = self.proj_out(phoneme_embeds)
+
+ return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device)
+
+class TokenizerLUTConditioner(Conditioner):
+ """
+ A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary
+
+ Args:
+ tokenizer_name: the name of the tokenizer from the Hugging Face transformers library
+ output_dim: the dimension of the output embeddings
+ max_length: the maximum length of the text to embed
+ project_out: whether to add another linear projection to the output embeddings
+ """
+
+ def __init__(
+ self,
+ tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library
+ output_dim: int,
+ max_length: int = 1024,
+ project_out: bool = False,
+ ):
+ super().__init__(output_dim, output_dim, project_out=project_out)
+
+ from transformers import AutoTokenizer
+
+ # Suppress logging from transformers
+ previous_level = logging.root.manager.disable
+ logging.disable(logging.ERROR)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ try:
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+ finally:
+ logging.disable(previous_level)
+
+ self.max_length = max_length
+
+ self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim)
+
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ self.proj_out.to(device)
+
+ encoded = self.tokenizer(
+ texts,
+ truncation=True,
+ max_length=self.max_length,
+ padding="max_length",
+ return_tensors="pt",
+ )
+
+ input_ids = encoded["input_ids"].to(device)
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
+
+ embeddings = self.token_embedder(input_ids)
+
+ embeddings = self.proj_out(embeddings)
+
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
+
+ return embeddings, attention_mask
+
+class PretransformConditioner(Conditioner):
+ """
+ A conditioner that uses a pretransform's encoder for conditioning
+
+ Args:
+ pretransform: an instantiated pretransform to use for conditioning
+ output_dim: the dimension of the output embeddings
+ """
+ def __init__(self, pretransform: Pretransform, output_dim: int):
+ super().__init__(pretransform.encoded_channels, output_dim)
+
+ self.pretransform = pretransform
+
+ def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+
+ self.pretransform.to(device)
+ self.proj_out.to(device)
+
+ if isinstance(audio, list) or isinstance(audio, tuple):
+ audio = torch.cat(audio, dim=0)
+
+ # Convert audio to pretransform input channels
+ audio = set_audio_channels(audio, self.pretransform.io_channels)
+
+ latents = self.pretransform.encode(audio)
+
+ latents = self.proj_out(latents)
+
+ return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]
+
+class MultiConditioner(nn.Module):
+ """
+ A module that applies multiple conditioners to an input dictionary based on the keys
+
+ Args:
+ conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt")
+ default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"})
+ """
+ def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}):
+ super().__init__()
+
+ self.conditioners = nn.ModuleDict(conditioners)
+ self.default_keys = default_keys
+
+ def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]:
+ output = {}
+
+ for key, conditioner in self.conditioners.items():
+ condition_key = key
+
+ conditioner_inputs = []
+
+ for x in batch_metadata:
+
+ if condition_key not in x:
+ if condition_key in self.default_keys:
+ condition_key = self.default_keys[condition_key]
+ else:
+ raise ValueError(f"Conditioner key {condition_key} not found in batch metadata")
+
+ #Unwrap the condition info if it's a single-element list or tuple, this is to support collation functions that wrap everything in a list
+ if isinstance(x[condition_key], list) or isinstance(x[condition_key], tuple) and len(x[condition_key]) == 1:
+ conditioner_input = x[condition_key][0]
+
+ else:
+ conditioner_input = x[condition_key]
+
+ conditioner_inputs.append(conditioner_input)
+
+ cond_output = conditioner(conditioner_inputs, device)
+ if len(cond_output) == 1:
+ output[key] = cond_output[0]
+ elif len(cond_output) == 2:
+ output[key] = cond_output
+ elif len(cond_output) == 4:
+ output[key] = cond_output[:2]
+ output[f'{key}_g'] = cond_output[2:]
+
+ return output
+
+def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner:
+ """
+ Create a MultiConditioner from a conditioning config dictionary
+
+ Args:
+ config: the conditioning config dictionary
+ device: the device to put the conditioners on
+ """
+ conditioners = {}
+ cond_dim = config["cond_dim"]
+
+ default_keys = config.get("default_keys", {})
+
+ for conditioner_info in config["configs"]:
+ id = conditioner_info["id"]
+
+ conditioner_type = conditioner_info["type"]
+
+ conditioner_config = {"output_dim": cond_dim}
+
+ conditioner_config.update(conditioner_info["config"])
+ if conditioner_type == "t5":
+ conditioners[id] = T5Conditioner(**conditioner_config)
+ elif conditioner_type == "clap_text":
+ conditioners[id] = CLAPTextConditioner(**conditioner_config)
+ elif conditioner_type == "clip_text":
+ conditioners[id] = CLIPTextConditioner(**conditioner_config)
+ elif conditioner_type == "metaclip_text":
+ conditioners[id] = MetaCLIPTextConditioner(**conditioner_config)
+ elif conditioner_type == "clap_audio":
+ conditioners[id] = CLAPAudioConditioner(**conditioner_config)
+ elif conditioner_type == "cond_mlp":
+ conditioners[id] = Cond_MLP(**conditioner_config)
+ elif conditioner_type == "global_mlp":
+ conditioners[id] = Global_MLP(**conditioner_config)
+ elif conditioner_type == "sync_mlp":
+ conditioners[id] = Sync_MLP(**conditioner_config)
+ elif conditioner_type == "cond_mlp_1":
+ conditioners[id] = Cond_MLP_1(**conditioner_config)
+ elif conditioner_type == "cond_convmlp":
+ conditioners[id] = Cond_ConvMLP(**conditioner_config)
+ elif conditioner_type == "cond_mlp_global":
+ conditioners[id] = Cond_MLP_Global(**conditioner_config)
+ elif conditioner_type == "cond_mlp_global_1":
+ conditioners[id] = Cond_MLP_Global_1(**conditioner_config)
+ elif conditioner_type == "cond_mlp_global_2":
+ conditioners[id] = Cond_MLP_Global_2(**conditioner_config)
+ elif conditioner_type == "video_linear":
+ conditioners[id] = Video_Linear(**conditioner_config)
+ elif conditioner_type == "video_global":
+ conditioners[id] = Video_Global(**conditioner_config)
+ elif conditioner_type == "video_sync":
+ conditioners[id] = Video_Sync(**conditioner_config)
+ elif conditioner_type == "text_linear":
+ conditioners[id] = Text_Linear(**conditioner_config)
+ elif conditioner_type == "video_clip":
+ conditioners[id] = CLIPConditioner(**conditioner_config)
+ elif conditioner_type == "video_hiera":
+ conditioners[id] = VideoHieraConditioner(**conditioner_config)
+ elif conditioner_type == "meta_query":
+ from .meta_queries.model import MLLMInContext
+ conditioners[id] = MLLMInContext(**conditioner_config)
+ elif conditioner_type == "int":
+ conditioners[id] = IntConditioner(**conditioner_config)
+ elif conditioner_type == "number":
+ conditioners[id] = NumberConditioner(**conditioner_config)
+ elif conditioner_type == "phoneme":
+ conditioners[id] = PhonemeConditioner(**conditioner_config)
+ elif conditioner_type == "lut":
+ conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
+ elif conditioner_type == "pretransform":
+ sample_rate = conditioner_config.pop("sample_rate", None)
+ assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
+
+ pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
+
+ if conditioner_config.get("pretransform_ckpt_path", None) is not None:
+ pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
+
+ conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
+ elif conditioner_type == "mm_unchang":
+ conditioners[id] = mm_unchang(**conditioner_config)
+ else:
+ raise ValueError(f"Unknown conditioner type: {conditioner_type}")
+
+ return MultiConditioner(conditioners, default_keys=default_keys)
\ No newline at end of file
diff --git a/ThinkSound/models/diffusion.py b/ThinkSound/models/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..62638843abe7f1661c26840a31d6529b03ae727d
--- /dev/null
+++ b/ThinkSound/models/diffusion.py
@@ -0,0 +1,957 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+from functools import partial
+import numpy as np
+import typing as tp
+
+from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
+from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
+from .dit import DiffusionTransformer
+#from .mmdit import MMAudio
+from .factory import create_pretransform_from_config
+from .pretransforms import Pretransform
+from ..inference.generation import generate_diffusion_cond
+
+from .adp import UNetCFG1d, UNet1d
+
+from time import time
+
+class Profiler:
+
+ def __init__(self):
+ self.ticks = [[time(), None]]
+
+ def tick(self, msg):
+ self.ticks.append([time(), msg])
+
+ def __repr__(self):
+ rep = 80 * "=" + "\n"
+ for i in range(1, len(self.ticks)):
+ msg = self.ticks[i][1]
+ ellapsed = self.ticks[i][0] - self.ticks[i - 1][0]
+ rep += msg + f": {ellapsed*1000:.2f}ms\n"
+ rep += 80 * "=" + "\n\n\n"
+ return rep
+
+class DiffusionModel(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x, t, **kwargs):
+ raise NotImplementedError()
+
+class DiffusionModelWrapper(nn.Module):
+ def __init__(
+ self,
+ model: DiffusionModel,
+ io_channels,
+ sample_size,
+ sample_rate,
+ min_input_length,
+ pretransform: tp.Optional[Pretransform] = None,
+ ):
+ super().__init__()
+ self.io_channels = io_channels
+ self.sample_size = sample_size
+ self.sample_rate = sample_rate
+ self.min_input_length = min_input_length
+
+ self.model = model
+
+ if pretransform is not None:
+ self.pretransform = pretransform
+ else:
+ self.pretransform = None
+
+ def forward(self, x, t, **kwargs):
+ return self.model(x, t, **kwargs)
+
+class ConditionedDiffusionModel(nn.Module):
+ def __init__(self,
+ *args,
+ supports_cross_attention: bool = False,
+ supports_input_concat: bool = False,
+ supports_global_cond: bool = False,
+ supports_prepend_cond: bool = False,
+ **kwargs):
+ super().__init__(*args, **kwargs)
+ self.supports_cross_attention = supports_cross_attention
+ self.supports_input_concat = supports_input_concat
+ self.supports_global_cond = supports_global_cond
+ self.supports_prepend_cond = supports_prepend_cond
+
+ def forward(self,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ cross_attn_cond: torch.Tensor = None,
+ cross_attn_mask: torch.Tensor = None,
+ input_concat_cond: torch.Tensor = None,
+ global_embed: torch.Tensor = None,
+ prepend_cond: torch.Tensor = None,
+ prepend_cond_mask: torch.Tensor = None,
+ cfg_scale: float = 1.0,
+ cfg_dropout_prob: float = 0.0,
+ batch_cfg: bool = False,
+ rescale_cfg: bool = False,
+ **kwargs):
+ raise NotImplementedError()
+
+class ConditionedDiffusionModelWrapper(nn.Module):
+ """
+ A diffusion model that takes in conditioning
+ """
+ def __init__(
+ self,
+ model: ConditionedDiffusionModel,
+ conditioner: MultiConditioner,
+ io_channels,
+ sample_rate,
+ min_input_length: int,
+ diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
+ zero_init: bool = False,
+ pretransform: tp.Optional[Pretransform] = None,
+ cross_attn_cond_ids: tp.List[str] = [],
+ global_cond_ids: tp.List[str] = [],
+ input_concat_ids: tp.List[str] = [],
+ prepend_cond_ids: tp.List[str] = [],
+ add_cond_ids: tp.List[str] = [],
+ sync_cond_ids: tp.List[str] = [],
+ ):
+ super().__init__()
+
+ self.model = model
+ self.conditioner = conditioner
+ self.io_channels = io_channels
+ self.sample_rate = sample_rate
+ self.diffusion_objective = diffusion_objective
+ self.pretransform = pretransform
+ self.cross_attn_cond_ids = cross_attn_cond_ids
+ self.global_cond_ids = global_cond_ids
+ self.input_concat_ids = input_concat_ids
+ self.prepend_cond_ids = prepend_cond_ids
+ self.add_cond_ids = add_cond_ids
+ self.sync_cond_ids = sync_cond_ids
+ self.min_input_length = min_input_length
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ if zero_init is True:
+ self.conditioner.apply(_basic_init)
+ self.model.model.initialize_weights()
+
+
+ def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
+ cross_attention_input = None
+ cross_attention_masks = None
+ global_cond = None
+ input_concat_cond = None
+ prepend_cond = None
+ prepend_cond_mask = None
+ add_input = None
+ sync_input = None
+
+ if len(self.cross_attn_cond_ids) > 0:
+ # Concatenate all cross-attention inputs over the sequence dimension
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
+ cross_attention_input = []
+ cross_attention_masks = []
+
+ for key in self.cross_attn_cond_ids:
+ cross_attn_in, cross_attn_mask = conditioning_tensors[key]
+
+ # Add sequence dimension if it's not there
+ if len(cross_attn_in.shape) == 2:
+ cross_attn_in = cross_attn_in.unsqueeze(1)
+ # cross_attn_mask = cross_attn_mask.unsqueeze(1)
+
+ cross_attention_input.append(cross_attn_in)
+ cross_attention_masks.append(cross_attn_mask)
+ # import ipdb
+ # ipdb.set_trace()
+ cross_attention_input = torch.cat(cross_attention_input, dim=1)
+ cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
+
+ if len(self.add_cond_ids) > 0:
+ # Concatenate all cross-attention inputs over the sequence dimension
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
+ add_input = []
+
+ for key in self.add_cond_ids:
+ add_in = conditioning_tensors[key][0]
+
+ # Add sequence dimension if it's not there
+ if len(add_in.shape) == 2:
+ add_in = add_in.unsqueeze(1)
+ # add_in = add_in.transpose(1,2)
+ # add_in = F.interpolate(add_in, (194, ), mode='linear', align_corners=False)
+ # add_in = add_in.transpose(1,2)
+ add_input.append(add_in)
+
+ add_input = torch.cat(add_input, dim=2)
+
+ if len(self.sync_cond_ids) > 0:
+ # Concatenate all cross-attention inputs over the sequence dimension
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
+ sync_input = []
+
+ for key in self.sync_cond_ids:
+ sync_in = conditioning_tensors[key][0]
+
+ # Add sequence dimension if it's not there
+ if len(sync_in.shape) == 2:
+ sync_in = sync_in.unsqueeze(1)
+ sync_input.append(sync_in)
+
+ sync_input = torch.cat(sync_input, dim=2)
+
+ if len(self.global_cond_ids) > 0:
+ # Concatenate all global conditioning inputs over the channel dimension
+ # Assumes that the global conditioning inputs are of shape (batch, channels)
+ global_conds = []
+ for key in self.global_cond_ids:
+ global_cond_input = conditioning_tensors[key][0]
+ if len(global_cond_input.shape) == 2:
+ global_cond_input = global_cond_input.unsqueeze(1)
+ global_conds.append(global_cond_input)
+
+ # # Concatenate over the channel dimension
+ # if global_conds[0].shape[-1] == 768:
+ # global_cond = torch.cat(global_conds, dim=-1)
+ # else:
+ # global_cond = sum(global_conds)
+ global_cond = sum(global_conds)
+ # global_cond = torch.cat(global_conds, dim=-1)
+
+ if len(global_cond.shape) == 3:
+ global_cond = global_cond.squeeze(1)
+
+ if len(self.input_concat_ids) > 0:
+ # Concatenate all input concat conditioning inputs over the channel dimension
+ # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
+ input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)
+
+ if len(self.prepend_cond_ids) > 0:
+ # Concatenate all prepend conditioning inputs over the sequence dimension
+ # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
+ prepend_conds = []
+ prepend_cond_masks = []
+
+ for key in self.prepend_cond_ids:
+ prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
+ if len(prepend_cond_input.shape) == 2:
+ prepend_cond_input = prepend_cond_input.unsqueeze(1)
+ prepend_conds.append(prepend_cond_input)
+ prepend_cond_masks.append(prepend_cond_mask)
+
+ prepend_cond = torch.cat(prepend_conds, dim=1)
+ prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)
+
+ if negative:
+ return {
+ "negative_cross_attn_cond": cross_attention_input,
+ "negative_cross_attn_mask": cross_attention_masks,
+ "negative_global_cond": global_cond,
+ "negative_input_concat_cond": input_concat_cond
+ }
+ else:
+ return {
+ "cross_attn_cond": cross_attention_input,
+ "cross_attn_mask": cross_attention_masks,
+ "global_cond": global_cond,
+ "input_concat_cond": input_concat_cond,
+ "prepend_cond": prepend_cond,
+ "prepend_cond_mask": prepend_cond_mask,
+ "add_cond": add_input,
+ "sync_cond": sync_input
+ }
+
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
+ return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
+
+ def generate(self, *args, **kwargs):
+ return generate_diffusion_cond(self, *args, **kwargs)
+
+class UNetCFG1DWrapper(ConditionedDiffusionModel):
+ def __init__(
+ self,
+ *args,
+ **kwargs
+ ):
+ super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
+
+ self.model = UNetCFG1d(*args, **kwargs)
+
+ with torch.no_grad():
+ for param in self.model.parameters():
+ param *= 0.5
+
+ def forward(self,
+ x,
+ t,
+ cross_attn_cond=None,
+ cross_attn_mask=None,
+ input_concat_cond=None,
+ global_cond=None,
+ cfg_scale=1.0,
+ cfg_dropout_prob: float = 0.0,
+ batch_cfg: bool = False,
+ rescale_cfg: bool = False,
+ negative_cross_attn_cond=None,
+ negative_cross_attn_mask=None,
+ negative_global_cond=None,
+ negative_input_concat_cond=None,
+ prepend_cond=None,
+ prepend_cond_mask=None,
+ **kwargs):
+ p = Profiler()
+
+ p.tick("start")
+
+ channels_list = None
+ if input_concat_cond is not None:
+ channels_list = [input_concat_cond]
+
+ outputs = self.model(
+ x,
+ t,
+ embedding=cross_attn_cond,
+ embedding_mask=cross_attn_mask,
+ features=global_cond,
+ channels_list=channels_list,
+ embedding_scale=cfg_scale,
+ embedding_mask_proba=cfg_dropout_prob,
+ batch_cfg=batch_cfg,
+ rescale_cfg=rescale_cfg,
+ negative_embedding=negative_cross_attn_cond,
+ negative_embedding_mask=negative_cross_attn_mask,
+ **kwargs)
+
+ p.tick("UNetCFG1D forward")
+
+ #print(f"Profiler: {p}")
+ return outputs
+
+class UNet1DCondWrapper(ConditionedDiffusionModel):
+ def __init__(
+ self,
+ *args,
+ **kwargs
+ ):
+ super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
+
+ self.model = UNet1d(*args, **kwargs)
+
+ with torch.no_grad():
+ for param in self.model.parameters():
+ param *= 0.5
+
+ def forward(self,
+ x,
+ t,
+ input_concat_cond=None,
+ global_cond=None,
+ cross_attn_cond=None,
+ cross_attn_mask=None,
+ prepend_cond=None,
+ prepend_cond_mask=None,
+ cfg_scale=1.0,
+ cfg_dropout_prob: float = 0.0,
+ batch_cfg: bool = False,
+ rescale_cfg: bool = False,
+ negative_cross_attn_cond=None,
+ negative_cross_attn_mask=None,
+ negative_global_cond=None,
+ negative_input_concat_cond=None,
+ **kwargs):
+
+ channels_list = None
+ if input_concat_cond is not None:
+
+ # Interpolate input_concat_cond to the same length as x
+ if input_concat_cond.shape[2] != x.shape[2]:
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
+
+ channels_list = [input_concat_cond]
+
+ outputs = self.model(
+ x,
+ t,
+ features=global_cond,
+ channels_list=channels_list,
+ **kwargs)
+
+ return outputs
+
+class UNet1DUncondWrapper(DiffusionModel):
+ def __init__(
+ self,
+ in_channels,
+ *args,
+ **kwargs
+ ):
+ super().__init__()
+
+ self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
+
+ self.io_channels = in_channels
+
+ with torch.no_grad():
+ for param in self.model.parameters():
+ param *= 0.5
+
+ def forward(self, x, t, **kwargs):
+ return self.model(x, t, **kwargs)
+
+class DAU1DCondWrapper(ConditionedDiffusionModel):
+ def __init__(
+ self,
+ *args,
+ **kwargs
+ ):
+ super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
+
+ self.model = DiffusionAttnUnet1D(*args, **kwargs)
+
+ with torch.no_grad():
+ for param in self.model.parameters():
+ param *= 0.5
+
+ def forward(self,
+ x,
+ t,
+ input_concat_cond=None,
+ cross_attn_cond=None,
+ cross_attn_mask=None,
+ global_cond=None,
+ cfg_scale=1.0,
+ cfg_dropout_prob: float = 0.0,
+ batch_cfg: bool = False,
+ rescale_cfg: bool = False,
+ negative_cross_attn_cond=None,
+ negative_cross_attn_mask=None,
+ negative_global_cond=None,
+ negative_input_concat_cond=None,
+ prepend_cond=None,
+ **kwargs):
+
+ return self.model(x, t, cond = input_concat_cond)
+
+class DiffusionAttnUnet1D(nn.Module):
+ def __init__(
+ self,
+ io_channels = 2,
+ depth=14,
+ n_attn_layers = 6,
+ channels = [128, 128, 256, 256] + [512] * 10,
+ cond_dim = 0,
+ cond_noise_aug = False,
+ kernel_size = 5,
+ learned_resample = False,
+ strides = [2] * 13,
+ conv_bias = True,
+ use_snake = False
+ ):
+ super().__init__()
+
+ self.cond_noise_aug = cond_noise_aug
+
+ self.io_channels = io_channels
+
+ if self.cond_noise_aug:
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
+
+ self.timestep_embed = FourierFeatures(1, 16)
+
+ attn_layer = depth - n_attn_layers
+
+ strides = [1] + strides
+
+ block = nn.Identity()
+
+ conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
+
+ for i in range(depth, 0, -1):
+ c = channels[i - 1]
+ stride = strides[i-1]
+ if stride > 2 and not learned_resample:
+ raise ValueError("Must have stride 2 without learned resampling")
+
+ if i > 1:
+ c_prev = channels[i - 2]
+ add_attn = i >= attn_layer and n_attn_layers > 0
+ block = SkipBlock(
+ Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
+ conv_block(c_prev, c, c),
+ SelfAttention1d(
+ c, c // 32) if add_attn else nn.Identity(),
+ conv_block(c, c, c),
+ SelfAttention1d(
+ c, c // 32) if add_attn else nn.Identity(),
+ conv_block(c, c, c),
+ SelfAttention1d(
+ c, c // 32) if add_attn else nn.Identity(),
+ block,
+ conv_block(c * 2 if i != depth else c, c, c),
+ SelfAttention1d(
+ c, c // 32) if add_attn else nn.Identity(),
+ conv_block(c, c, c),
+ SelfAttention1d(
+ c, c // 32) if add_attn else nn.Identity(),
+ conv_block(c, c, c_prev),
+ SelfAttention1d(c_prev, c_prev //
+ 32) if add_attn else nn.Identity(),
+ Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
+ )
+ else:
+ cond_embed_dim = 16 if not self.cond_noise_aug else 32
+ block = nn.Sequential(
+ conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
+ conv_block(c, c, c),
+ conv_block(c, c, c),
+ block,
+ conv_block(c * 2, c, c),
+ conv_block(c, c, c),
+ conv_block(c, c, io_channels, is_last=True),
+ )
+ self.net = block
+
+ with torch.no_grad():
+ for param in self.net.parameters():
+ param *= 0.5
+
+ def forward(self, x, t, cond=None, cond_aug_scale=None):
+
+ timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
+
+ inputs = [x, timestep_embed]
+
+ if cond is not None:
+ if cond.shape[2] != x.shape[2]:
+ cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
+
+ if self.cond_noise_aug:
+ # Get a random number between 0 and 1, uniformly sampled
+ if cond_aug_scale is None:
+ aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
+ else:
+ aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
+
+ # Add noise to the conditioning signal
+ cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
+
+ # Get embedding for noise cond level, reusing timestamp_embed
+ aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
+
+ inputs.append(aug_level_embed)
+
+ inputs.append(cond)
+
+ outputs = self.net(torch.cat(inputs, dim=1))
+
+ return outputs
+
+class DiTWrapper(ConditionedDiffusionModel):
+ def __init__(
+ self,
+ *args,
+ **kwargs
+ ):
+ super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
+
+ self.model = DiffusionTransformer(*args, **kwargs)
+ # with torch.no_grad():
+ # for param in self.model.parameters():
+ # param *= 0.5
+
+ def forward(self,
+ x,
+ t,
+ cross_attn_cond=None,
+ cross_attn_mask=None,
+ negative_cross_attn_cond=None,
+ negative_cross_attn_mask=None,
+ input_concat_cond=None,
+ negative_input_concat_cond=None,
+ global_cond=None,
+ negative_global_cond=None,
+ prepend_cond=None,
+ prepend_cond_mask=None,
+ cfg_scale=1.0,
+ cfg_dropout_prob: float = 0.0,
+ batch_cfg: bool = True,
+ rescale_cfg: bool = False,
+ scale_phi: float = 0.0,
+ **kwargs):
+
+ assert batch_cfg, "batch_cfg must be True for DiTWrapper"
+ #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
+
+ return self.model(
+ x,
+ t,
+ cross_attn_cond=cross_attn_cond,
+ cross_attn_cond_mask=cross_attn_mask,
+ negative_cross_attn_cond=negative_cross_attn_cond,
+ negative_cross_attn_mask=negative_cross_attn_mask,
+ input_concat_cond=input_concat_cond,
+ prepend_cond=prepend_cond,
+ prepend_cond_mask=prepend_cond_mask,
+ cfg_scale=cfg_scale,
+ cfg_dropout_prob=cfg_dropout_prob,
+ scale_phi=scale_phi,
+ global_embed=global_cond,
+ **kwargs)
+
+class MMDiTWrapper(ConditionedDiffusionModel):
+ def __init__(
+ self,
+ *args,
+ **kwargs
+ ):
+ super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
+
+ self.model = MMAudio(*args, **kwargs)
+
+ # with torch.no_grad():
+ # for param in self.model.parameters():
+ # param *= 0.5
+
+ def forward(self,
+ x,
+ t,
+ clip_f,
+ sync_f,
+ text_f,
+ inpaint_masked_input=None,
+ t5_features=None,
+ metaclip_global_text_features=None,
+ cfg_scale=1.0,
+ cfg_dropout_prob: float = 0.0,
+ batch_cfg: bool = True,
+ rescale_cfg: bool = False,
+ scale_phi: float = 0.0,
+ **kwargs):
+
+ # breakpoint()
+ assert batch_cfg, "batch_cfg must be True for DiTWrapper"
+ #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
+
+ return self.model(
+ latent=x,
+ t=t,
+ clip_f=clip_f,
+ sync_f=sync_f,
+ text_f=text_f,
+ inpaint_masked_input=inpaint_masked_input,
+ t5_features=t5_features,
+ metaclip_global_text_features=metaclip_global_text_features,
+ cfg_scale=cfg_scale,
+ cfg_dropout_prob=cfg_dropout_prob,
+ scale_phi=scale_phi,
+ **kwargs)
+
+class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
+ """
+ A diffusion model that takes in conditioning
+ """
+ def __init__(
+ self,
+ model,
+ conditioner: MultiConditioner,
+ io_channels,
+ sample_rate,
+ min_input_length: int,
+ diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
+ pretransform: tp.Optional[Pretransform] = None,
+ cross_attn_cond_ids: tp.List[str] = [],
+ global_cond_ids: tp.List[str] = [],
+ input_concat_ids: tp.List[str] = [],
+ prepend_cond_ids: tp.List[str] = [],
+ add_cond_ids: tp.List[str] = [],
+ mm_cond_ids: tp.List[str] = [],
+ ):
+ super().__init__()
+
+ self.model = model
+ self.conditioner = conditioner
+ self.io_channels = io_channels
+ self.sample_rate = sample_rate
+ self.diffusion_objective = diffusion_objective
+ self.pretransform = pretransform
+ self.cross_attn_cond_ids = cross_attn_cond_ids
+ self.global_cond_ids = global_cond_ids
+ self.input_concat_ids = input_concat_ids
+ self.prepend_cond_ids = prepend_cond_ids
+ self.add_cond_ids = add_cond_ids
+ self.min_input_length = min_input_length
+ self.mm_cond_ids = mm_cond_ids
+
+ assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper"
+ assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper"
+ assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper"
+ assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper"
+ assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper"
+ assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper"
+ assert "metaclip_features" in self.mm_cond_ids, "clip_f must be specified in mm_cond_ids for MMDiTWrapper"
+ assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper"
+ assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper"
+ # assert len(self.mm_cond_ids) == 3, "mm_cond_ids must be clip_f sync_f text_f for MMDiTWrapper"
+
+ def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
+ assert negative == False, "negative conditioning is not supported for MMDiTWrapper"
+ cross_attention_input = None
+ cross_attention_masks = None
+ global_cond = None
+ input_concat_cond = None
+ prepend_cond = None
+ prepend_cond_mask = None
+ add_input = None
+ inpaint_masked_input = None
+ t5_features = None
+ metaclip_global_text_features = None
+ clip_f = conditioning_tensors["metaclip_features"]
+ sync_f = conditioning_tensors["sync_features"]
+ text_f = conditioning_tensors["metaclip_text_features"]
+ if 'inpaint_masked_input' in conditioning_tensors.keys():
+ inpaint_masked_input = conditioning_tensors["inpaint_masked_input"]
+ if 't5_features' in conditioning_tensors.keys():
+ t5_features = conditioning_tensors["t5_features"]
+ if 'metaclip_global_text_features' in conditioning_tensors.keys():
+ metaclip_global_text_features = conditioning_tensors["metaclip_global_text_features"]
+ return {
+ "clip_f": clip_f,
+ "sync_f": sync_f,
+ "text_f": text_f,
+ "inpaint_masked_input": inpaint_masked_input,
+ "t5_features": t5_features,
+ "metaclip_global_text_features": metaclip_global_text_features
+ }
+
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
+ # breakpoint()
+ # print(kwargs)
+ return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)
+
+ def generate(self, *args, **kwargs):
+ return generate_diffusion_cond(self, *args, **kwargs)
+
+class DiTUncondWrapper(DiffusionModel):
+ def __init__(
+ self,
+ io_channels,
+ *args,
+ **kwargs
+ ):
+ super().__init__()
+
+ self.model = DiffusionTransformer(io_channels=io_channels, *args, **kwargs)
+
+ self.io_channels = io_channels
+
+ with torch.no_grad():
+ for param in self.model.parameters():
+ param *= 0.5
+
+ def forward(self, x, t, **kwargs):
+ return self.model(x, t, **kwargs)
+
+def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
+ diffusion_uncond_config = config["model"]
+
+ model_type = diffusion_uncond_config.get('type', None)
+
+ diffusion_config = diffusion_uncond_config.get('config', {})
+
+ assert model_type is not None, "Must specify model type in config"
+
+ pretransform = diffusion_uncond_config.get("pretransform", None)
+
+ sample_size = config.get("sample_size", None)
+ assert sample_size is not None, "Must specify sample size in config"
+
+ sample_rate = config.get("sample_rate", None)
+ assert sample_rate is not None, "Must specify sample rate in config"
+
+ if pretransform is not None:
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
+ min_input_length = pretransform.downsampling_ratio
+ else:
+ min_input_length = 1
+
+ if model_type == 'DAU1d':
+
+ model = DiffusionAttnUnet1D(
+ **diffusion_config
+ )
+
+ elif model_type == "adp_uncond_1d":
+
+ model = UNet1DUncondWrapper(
+ **diffusion_config
+ )
+
+ elif model_type == "dit":
+ model = DiTUncondWrapper(
+ **diffusion_config
+ )
+
+ else:
+ raise NotImplementedError(f'Unknown model type: {model_type}')
+
+ return DiffusionModelWrapper(model,
+ io_channels=model.io_channels,
+ sample_size=sample_size,
+ sample_rate=sample_rate,
+ pretransform=pretransform,
+ min_input_length=min_input_length)
+
+def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]):
+ diffusion_uncond_config = config["model"]
+
+
+ diffusion_config = diffusion_uncond_config.get('diffusion', {})
+ model_type = diffusion_config.get('type', None)
+ model_config = diffusion_config.get("config",{})
+ assert model_type is not None, "Must specify model type in config"
+
+ pretransform = diffusion_uncond_config.get("pretransform", None)
+
+ sample_size = config.get("sample_size", None)
+ assert sample_size is not None, "Must specify sample size in config"
+
+ sample_rate = config.get("sample_rate", None)
+ assert sample_rate is not None, "Must specify sample rate in config"
+
+ if pretransform is not None:
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
+ min_input_length = pretransform.downsampling_ratio
+ else:
+ min_input_length = 1
+
+ if model_type == 'DAU1d':
+
+ model = DiffusionAttnUnet1D(
+ **model_config
+ )
+
+ elif model_type == "adp_uncond_1d":
+
+ model = UNet1DUncondWrapper(
+ io_channels = io_channels,
+ **model_config
+ )
+ elif model_type == "dit":
+ model = DiTUncondWrapper(
+ **model_config
+ )
+
+ else:
+ raise NotImplementedError(f'Unknown model type: {model_type}')
+
+ return DiffusionModelWrapper(model,
+ io_channels=model.io_channels,
+ sample_size=sample_size,
+ sample_rate=sample_rate,
+ pretransform=pretransform,
+ min_input_length=min_input_length)
+
+def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
+
+ model_config = config["model"]
+
+ model_type = config["model_type"]
+
+ diffusion_config = model_config.get('diffusion', None)
+ assert diffusion_config is not None, "Must specify diffusion config"
+
+ diffusion_model_type = diffusion_config.get('type', None)
+ assert diffusion_model_type is not None, "Must specify diffusion model type"
+
+ diffusion_model_config = diffusion_config.get('config', None)
+ assert diffusion_model_config is not None, "Must specify diffusion model config"
+
+ if diffusion_model_type == 'adp_cfg_1d':
+ diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
+ elif diffusion_model_type == 'adp_1d':
+ diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
+ elif diffusion_model_type == 'dit':
+ diffusion_model = DiTWrapper(**diffusion_model_config)
+ elif diffusion_model_type == 'mmdit':
+ diffusion_model = MMDiTWrapper(**diffusion_model_config)
+
+ io_channels = model_config.get('io_channels', None)
+ assert io_channels is not None, "Must specify io_channels in model config"
+
+ sample_rate = config.get('sample_rate', None)
+ assert sample_rate is not None, "Must specify sample_rate in config"
+
+ diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
+
+ conditioning_config = model_config.get('conditioning', None)
+
+ conditioner = None
+ if conditioning_config is not None:
+ conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
+
+ cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
+ add_cond_ids = diffusion_config.get('add_cond_ids', [])
+ sync_cond_ids = diffusion_config.get('sync_cond_ids', [])
+ global_cond_ids = diffusion_config.get('global_cond_ids', [])
+ input_concat_ids = diffusion_config.get('input_concat_ids', [])
+ prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
+ mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
+ zero_init = diffusion_config.get('zero_init', False)
+ pretransform = model_config.get("pretransform", None)
+
+ if pretransform is not None:
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
+ min_input_length = pretransform.downsampling_ratio
+ else:
+ min_input_length = 1
+
+ if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
+ min_input_length *= np.prod(diffusion_model_config["factors"])
+ elif diffusion_model_type == "dit":
+ min_input_length *= diffusion_model.model.patch_size
+
+ # Get the proper wrapper class
+
+ extra_kwargs = {}
+
+ if model_type == "mm_diffusion_cond":
+ wrapper_fn = MMConditionedDiffusionModelWrapper
+ extra_kwargs["diffusion_objective"] = diffusion_objective
+ extra_kwargs["mm_cond_ids"] = mm_cond_ids
+
+ if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
+ wrapper_fn = ConditionedDiffusionModelWrapper
+ extra_kwargs["diffusion_objective"] = diffusion_objective
+
+ elif model_type == "diffusion_prior":
+ prior_type = model_config.get("prior_type", None)
+ assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
+
+ if prior_type == "mono_stereo":
+ from .diffusion_prior import MonoToStereoDiffusionPrior
+ wrapper_fn = MonoToStereoDiffusionPrior
+
+ return wrapper_fn(
+ diffusion_model,
+ conditioner,
+ min_input_length=min_input_length,
+ sample_rate=sample_rate,
+ cross_attn_cond_ids=cross_attention_ids,
+ global_cond_ids=global_cond_ids,
+ input_concat_ids=input_concat_ids,
+ prepend_cond_ids=prepend_cond_ids,
+ add_cond_ids=add_cond_ids,
+ sync_cond_ids=sync_cond_ids,
+ pretransform=pretransform,
+ io_channels=io_channels,
+ zero_init=zero_init,
+ **extra_kwargs
+ )
\ No newline at end of file
diff --git a/ThinkSound/models/diffusion_prior.py b/ThinkSound/models/diffusion_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cb15258d7656fb85ee763910dc9b500331de603
--- /dev/null
+++ b/ThinkSound/models/diffusion_prior.py
@@ -0,0 +1,82 @@
+from enum import Enum
+import typing as tp
+
+from .diffusion import ConditionedDiffusionModelWrapper
+from ..inference.generation import generate_diffusion_cond
+from ..inference.utils import prepare_audio
+
+import torch
+from torch.nn import functional as F
+from torchaudio import transforms as T
+
+# Define prior types enum
+class PriorType(Enum):
+ MonoToStereo = 1
+
+class DiffusionPrior(ConditionedDiffusionModelWrapper):
+ def __init__(self, *args, prior_type: PriorType=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.prior_type = prior_type
+
+class MonoToStereoDiffusionPrior(DiffusionPrior):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs)
+
+ def stereoize(
+ self,
+ audio: torch.Tensor, # (batch, channels, time)
+ video: torch.Tensor,
+ in_sr: int,
+ steps: int,
+ sampler_kwargs: dict = {},
+ ):
+ """
+ Generate stereo audio from mono audio using a pre-trained diffusion prior
+
+ Args:
+ audio: The mono audio to convert to stereo
+ in_sr: The sample rate of the input audio
+ steps: The number of diffusion steps to run
+ sampler_kwargs: Keyword arguments to pass to the diffusion sampler
+ """
+
+ device = audio.device
+
+ sample_rate = self.sample_rate
+
+ # Resample input audio if necessary
+ if in_sr != sample_rate:
+ resample_tf = T.Resample(in_sr, sample_rate).to(audio.device)
+ audio = resample_tf(audio)
+
+ audio_length = audio.shape[-1]
+
+ # # Pad input audio to be compatible with the model
+ # min_length = self.min_input_length
+ # padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length
+
+ # # Pad input audio to be compatible with the model
+ # if padded_input_length > audio_length:
+ # audio = F.pad(audio, (0, padded_input_length - audio_length))
+
+ # Make audio mono, duplicate to stereo
+ dual_mono = audio.mean(1, keepdim=True).repeat(1, 2, 1)
+
+ if self.pretransform is not None:
+ dual_mono = self.pretransform.encode(dual_mono)
+
+ conditioning = self.conditioner([{'video':video}], device)
+ # Return fake stereo audio
+ conditioning["source"] = [dual_mono]
+ stereo_audio = generate_diffusion_cond(
+ self,
+ conditioning_tensors=conditioning,
+ steps=steps,
+ sample_size=audio_length,
+ sample_rate=sample_rate,
+ device=device,
+ cfg_scale=1,
+ **sampler_kwargs,
+ )
+
+ return stereo_audio
\ No newline at end of file
diff --git a/ThinkSound/models/discriminators.py b/ThinkSound/models/discriminators.py
new file mode 100644
index 0000000000000000000000000000000000000000..b593168df965bb1f57881ea79edbc2f66478c6c2
--- /dev/null
+++ b/ThinkSound/models/discriminators.py
@@ -0,0 +1,546 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from functools import reduce
+import typing as tp
+from einops import rearrange
+from audiotools import AudioSignal, STFTParams
+from dac.model.discriminator import WNConv1d, WNConv2d
+
+def get_hinge_losses(score_real, score_fake):
+ gen_loss = -score_fake.mean()
+ dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean()
+ return dis_loss, gen_loss
+
+class EncodecDiscriminator(nn.Module):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ from encodec.msstftd import MultiScaleSTFTDiscriminator
+
+ self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs)
+
+ def forward(self, x):
+ logits, features = self.discriminators(x)
+ return logits, features
+
+ def loss(self, x, y):
+ feature_matching_distance = 0.
+ logits_true, feature_true = self.forward(x)
+ logits_fake, feature_fake = self.forward(y)
+
+ dis_loss = torch.tensor(0.)
+ adv_loss = torch.tensor(0.)
+
+ for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)):
+
+ feature_matching_distance = feature_matching_distance + sum(
+ map(
+ lambda x, y: abs(x - y).mean(),
+ scale_true,
+ scale_fake,
+ )) / len(scale_true)
+
+ _dis, _adv = get_hinge_losses(
+ logits_true[i],
+ logits_fake[i],
+ )
+
+ dis_loss = dis_loss + _dis
+ adv_loss = adv_loss + _adv
+
+ return dis_loss, adv_loss, feature_matching_distance
+
+# Discriminators from oobleck
+
+IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]]
+
+TensorDict = tp.Dict[str, torch.Tensor]
+
+class SharedDiscriminatorConvNet(nn.Module):
+
+ def __init__(
+ self,
+ in_size: int,
+ convolution: tp.Union[nn.Conv1d, nn.Conv2d],
+ out_size: int = 1,
+ capacity: int = 32,
+ n_layers: int = 4,
+ kernel_size: int = 15,
+ stride: int = 4,
+ activation: tp.Callable[[], nn.Module] = lambda: nn.SiLU(),
+ normalization: tp.Callable[[nn.Module], nn.Module] = torch.nn.utils.weight_norm,
+ ) -> None:
+ super().__init__()
+ channels = [in_size]
+ channels += list(capacity * 2**np.arange(n_layers))
+
+ if isinstance(stride, int):
+ stride = n_layers * [stride]
+
+ net = []
+ for i in range(n_layers):
+ if isinstance(kernel_size, int):
+ pad = kernel_size // 2
+ s = stride[i]
+ else:
+ pad = kernel_size[0] // 2
+ s = (stride[i], 1)
+
+ net.append(
+ normalization(
+ convolution(
+ channels[i],
+ channels[i + 1],
+ kernel_size,
+ stride=s,
+ padding=pad,
+ )))
+ net.append(activation())
+
+ net.append(convolution(channels[-1], out_size, 1))
+
+ self.net = nn.ModuleList(net)
+
+ def forward(self, x) -> IndividualDiscriminatorOut:
+ features = []
+ for layer in self.net:
+ x = layer(x)
+ if isinstance(layer, nn.modules.conv._ConvNd):
+ features.append(x)
+ score = x.reshape(x.shape[0], -1).mean(-1)
+ return score, features
+
+
+class MultiScaleDiscriminator(nn.Module):
+
+ def __init__(self,
+ in_channels: int,
+ n_scales: int,
+ **conv_kwargs) -> None:
+ super().__init__()
+ layers = []
+ for _ in range(n_scales):
+ layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv1d, **conv_kwargs))
+ self.layers = nn.ModuleList(layers)
+
+ def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut:
+ score = 0
+ features = []
+ for layer in self.layers:
+ s, f = layer(x)
+ score = score + s
+ features.extend(f)
+ x = nn.functional.avg_pool1d(x, 2)
+ return score, features
+
+class MultiPeriodDiscriminator(nn.Module):
+
+ def __init__(self,
+ in_channels: int,
+ periods: tp.Sequence[int],
+ **conv_kwargs) -> None:
+ super().__init__()
+ layers = []
+ self.periods = periods
+
+ for _ in periods:
+ layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv2d, **conv_kwargs))
+
+ self.layers = nn.ModuleList(layers)
+
+ def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut:
+ score = 0
+ features = []
+ for layer, n in zip(self.layers, self.periods):
+ s, f = layer(self.fold(x, n))
+ score = score + s
+ features.extend(f)
+ return score, features
+
+ def fold(self, x: torch.Tensor, n: int) -> torch.Tensor:
+ pad = (n - (x.shape[-1] % n)) % n
+ x = nn.functional.pad(x, (0, pad))
+ return x.reshape(*x.shape[:2], -1, n)
+
+
+class MultiDiscriminator(nn.Module):
+ """
+ Individual discriminators should take a single tensor as input (NxB C T) and
+ return a tuple composed of a score tensor (NxB) and a Sequence of Features
+ Sequence[NxB C' T'].
+ """
+
+ def __init__(self, discriminator_list: tp.Sequence[nn.Module],
+ keys: tp.Sequence[str]) -> None:
+ super().__init__()
+ self.discriminators = nn.ModuleList(discriminator_list)
+ self.keys = keys
+
+ def unpack_tensor_to_dict(self, features: torch.Tensor) -> TensorDict:
+ features = features.chunk(len(self.keys), 0)
+ return {k: features[i] for i, k in enumerate(self.keys)}
+
+ @staticmethod
+ def concat_dicts(dict_a, dict_b):
+ out_dict = {}
+ keys = set(list(dict_a.keys()) + list(dict_b.keys()))
+ for k in keys:
+ out_dict[k] = []
+ if k in dict_a:
+ if isinstance(dict_a[k], list):
+ out_dict[k].extend(dict_a[k])
+ else:
+ out_dict[k].append(dict_a[k])
+ if k in dict_b:
+ if isinstance(dict_b[k], list):
+ out_dict[k].extend(dict_b[k])
+ else:
+ out_dict[k].append(dict_b[k])
+ return out_dict
+
+ @staticmethod
+ def sum_dicts(dict_a, dict_b):
+ out_dict = {}
+ keys = set(list(dict_a.keys()) + list(dict_b.keys()))
+ for k in keys:
+ out_dict[k] = 0.
+ if k in dict_a:
+ out_dict[k] = out_dict[k] + dict_a[k]
+ if k in dict_b:
+ out_dict[k] = out_dict[k] + dict_b[k]
+ return out_dict
+
+ def forward(self, inputs: TensorDict) -> TensorDict:
+ discriminator_input = torch.cat([inputs[k] for k in self.keys], 0)
+ all_scores = []
+ all_features = []
+
+ for discriminator in self.discriminators:
+ score, features = discriminator(discriminator_input)
+ scores = self.unpack_tensor_to_dict(score)
+ scores = {f"score_{k}": scores[k] for k in scores.keys()}
+ all_scores.append(scores)
+
+ features = map(self.unpack_tensor_to_dict, features)
+ features = reduce(self.concat_dicts, features)
+ features = {f"features_{k}": features[k] for k in features.keys()}
+ all_features.append(features)
+
+ all_scores = reduce(self.sum_dicts, all_scores)
+ all_features = reduce(self.concat_dicts, all_features)
+
+ inputs.update(all_scores)
+ inputs.update(all_features)
+
+ return inputs
+
+class OobleckDiscriminator(nn.Module):
+
+ def __init__(
+ self,
+ in_channels=1,
+ ):
+ super().__init__()
+
+ multi_scale_discriminator = MultiScaleDiscriminator(
+ in_channels=in_channels,
+ n_scales=3,
+ )
+
+ multi_period_discriminator = MultiPeriodDiscriminator(
+ in_channels=in_channels,
+ periods=[2, 3, 5, 7, 11]
+ )
+
+ # multi_resolution_discriminator = MultiScaleSTFTDiscriminator(
+ # filters=32,
+ # in_channels = in_channels,
+ # out_channels = 1,
+ # n_ffts = [2048, 1024, 512, 256, 128],
+ # hop_lengths = [512, 256, 128, 64, 32],
+ # win_lengths = [2048, 1024, 512, 256, 128]
+ # )
+
+ self.multi_discriminator = MultiDiscriminator(
+ [multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator],
+ ["reals", "fakes"]
+ )
+
+ def loss(self, reals, fakes):
+ inputs = {
+ "reals": reals,
+ "fakes": fakes,
+ }
+
+ inputs = self.multi_discriminator(inputs)
+
+ scores_real = inputs["score_reals"]
+ scores_fake = inputs["score_fakes"]
+
+ features_real = inputs["features_reals"]
+ features_fake = inputs["features_fakes"]
+
+ dis_loss, gen_loss = get_hinge_losses(scores_real, scores_fake)
+
+ feature_matching_distance = torch.tensor(0.)
+
+ for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)):
+
+ feature_matching_distance = feature_matching_distance + sum(
+ map(
+ lambda real, fake: abs(real - fake).mean(),
+ scale_real,
+ scale_fake,
+ )) / len(scale_real)
+
+ return dis_loss, gen_loss, feature_matching_distance
+
+
+## Discriminators from Descript Audio Codec repo
+## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt
+class MPD(nn.Module):
+ def __init__(self, period, channels=1):
+ super().__init__()
+
+ self.period = period
+ self.convs = nn.ModuleList(
+ [
+ WNConv2d(channels, 32, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
+ ]
+ )
+ self.conv_post = WNConv2d(
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
+ )
+
+ def pad_to_period(self, x):
+ t = x.shape[-1]
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
+ return x
+
+ def forward(self, x):
+ fmap = []
+
+ x = self.pad_to_period(x)
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
+
+ for layer in self.convs:
+ x = layer(x)
+ fmap.append(x)
+
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+class MSD(nn.Module):
+ def __init__(self, rate: int = 1, sample_rate: int = 44100, channels=1):
+ super().__init__()
+
+ self.convs = nn.ModuleList(
+ [
+ WNConv1d(channels, 16, 15, 1, padding=7),
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
+ WNConv1d(1024, 1024, 5, 1, padding=2),
+ ]
+ )
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
+ self.sample_rate = sample_rate
+ self.rate = rate
+
+ def forward(self, x):
+ x = AudioSignal(x, self.sample_rate)
+ x.resample(self.sample_rate // self.rate)
+ x = x.audio_data
+
+ fmap = []
+
+ for l in self.convs:
+ x = l(x)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
+
+
+class MRD(nn.Module):
+ def __init__(
+ self,
+ window_length: int,
+ hop_factor: float = 0.25,
+ sample_rate: int = 44100,
+ bands: list = BANDS,
+ channels: int = 1
+ ):
+ """Complex multi-band spectrogram discriminator.
+ Parameters
+ ----------
+ window_length : int
+ Window length of STFT.
+ hop_factor : float, optional
+ Hop factor of the STFT, defaults to ``0.25 * window_length``.
+ sample_rate : int, optional
+ Sampling rate of audio in Hz, by default 44100
+ bands : list, optional
+ Bands to run discriminator over.
+ """
+ super().__init__()
+
+ self.window_length = window_length
+ self.hop_factor = hop_factor
+ self.sample_rate = sample_rate
+ self.stft_params = STFTParams(
+ window_length=window_length,
+ hop_length=int(window_length * hop_factor),
+ match_stride=True,
+ )
+
+ self.channels = channels
+
+ n_fft = window_length // 2 + 1
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
+ self.bands = bands
+
+ ch = 32
+ convs = lambda: nn.ModuleList(
+ [
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
+ ]
+ )
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
+
+ def spectrogram(self, x):
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
+ x = torch.view_as_real(x.stft())
+ x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels)
+ # Split into bands
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
+ return x_bands
+
+ def forward(self, x):
+ x_bands = self.spectrogram(x)
+ fmap = []
+
+ x = []
+ for band, stack in zip(x_bands, self.band_convs):
+ for layer in stack:
+ band = layer(band)
+ fmap.append(band)
+ x.append(band)
+
+ x = torch.cat(x, dim=-1)
+ x = self.conv_post(x)
+ fmap.append(x)
+
+ return fmap
+
+
+class DACDiscriminator(nn.Module):
+ def __init__(
+ self,
+ channels: int = 1,
+ rates: list = [],
+ periods: list = [2, 3, 5, 7, 11],
+ fft_sizes: list = [2048, 1024, 512],
+ sample_rate: int = 44100,
+ bands: list = BANDS,
+ ):
+ """Discriminator that combines multiple discriminators.
+
+ Parameters
+ ----------
+ rates : list, optional
+ sampling rates (in Hz) to run MSD at, by default []
+ If empty, MSD is not used.
+ periods : list, optional
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
+ fft_sizes : list, optional
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
+ sample_rate : int, optional
+ Sampling rate of audio in Hz, by default 44100
+ bands : list, optional
+ Bands to run MRD at, by default `BANDS`
+ """
+ super().__init__()
+ discs = []
+ discs += [MPD(p, channels=channels) for p in periods]
+ discs += [MSD(r, sample_rate=sample_rate, channels=channels) for r in rates]
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands, channels=channels) for f in fft_sizes]
+ self.discriminators = nn.ModuleList(discs)
+
+ def preprocess(self, y):
+ # Remove DC offset
+ y = y - y.mean(dim=-1, keepdims=True)
+ # Peak normalize the volume of input audio
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
+ return y
+
+ def forward(self, x):
+ x = self.preprocess(x)
+ fmaps = [d(x) for d in self.discriminators]
+ return fmaps
+
+class DACGANLoss(nn.Module):
+ """
+ Computes a discriminator loss, given a discriminator on
+ generated waveforms/spectrograms compared to ground truth
+ waveforms/spectrograms. Computes the loss for both the
+ discriminator and the generator in separate functions.
+ """
+
+ def __init__(self, **discriminator_kwargs):
+ super().__init__()
+ self.discriminator = DACDiscriminator(**discriminator_kwargs)
+
+ def forward(self, fake, real):
+ d_fake = self.discriminator(fake)
+ d_real = self.discriminator(real)
+ return d_fake, d_real
+
+ def discriminator_loss(self, fake, real):
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
+
+ loss_d = 0
+ for x_fake, x_real in zip(d_fake, d_real):
+ loss_d += torch.mean(x_fake[-1] ** 2)
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
+ return loss_d
+
+ def generator_loss(self, fake, real):
+ d_fake, d_real = self.forward(fake, real)
+
+ loss_g = 0
+ for x_fake in d_fake:
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
+
+ loss_feature = 0
+
+ for i in range(len(d_fake)):
+ for j in range(len(d_fake[i]) - 1):
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+ return loss_g, loss_feature
+
+ def loss(self, fake, real):
+ gen_loss, feature_distance = self.generator_loss(fake, real)
+ dis_loss = self.discriminator_loss(fake, real)
+
+ return dis_loss, gen_loss, feature_distance
\ No newline at end of file
diff --git a/ThinkSound/models/dit (1).py b/ThinkSound/models/dit (1).py
new file mode 100644
index 0000000000000000000000000000000000000000..1b9a4095fcf7286f84fd64b017e5f889061d3d05
--- /dev/null
+++ b/ThinkSound/models/dit (1).py
@@ -0,0 +1,430 @@
+import typing as tp
+import math
+import torch
+
+from einops import rearrange
+from torch import nn
+from torch.nn import functional as F
+
+from .blocks import FourierFeatures
+from .transformer import ContinuousTransformer
+
+class DiffusionTransformer(nn.Module):
+ def __init__(self,
+ io_channels=32,
+ patch_size=1,
+ embed_dim=768,
+ cond_token_dim=0,
+ project_cond_tokens=True,
+ global_cond_dim=0,
+ project_global_cond=True,
+ input_concat_dim=0,
+ prepend_cond_dim=0,
+ depth=12,
+ num_heads=8,
+ transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
+ global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
+ timestep_cond_type: tp.Literal["global", "input_concat"] = "global",
+ timestep_embed_dim=None,
+ diffusion_objective: tp.Literal["v", "rectified_flow", "rf_denoiser"] = "v",
+ **kwargs):
+
+ super().__init__()
+
+ self.cond_token_dim = cond_token_dim
+
+ # Timestep embeddings
+ self.timestep_cond_type = timestep_cond_type
+
+ timestep_features_dim = 256
+
+ self.timestep_features = FourierFeatures(1, timestep_features_dim)
+
+ if timestep_cond_type == "global":
+ timestep_embed_dim = embed_dim
+ elif timestep_cond_type == "input_concat":
+ assert timestep_embed_dim is not None, "timestep_embed_dim must be specified if timestep_cond_type is input_concat"
+ input_concat_dim += timestep_embed_dim
+
+ self.to_timestep_embed = nn.Sequential(
+ nn.Linear(timestep_features_dim, timestep_embed_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(timestep_embed_dim, timestep_embed_dim, bias=True),
+ )
+
+ self.diffusion_objective = diffusion_objective
+
+ if cond_token_dim > 0:
+ # Conditioning tokens
+
+ cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
+ self.to_cond_embed = nn.Sequential(
+ nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
+ )
+ else:
+ cond_embed_dim = 0
+
+ if global_cond_dim > 0:
+ # Global conditioning
+ global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
+ self.to_global_embed = nn.Sequential(
+ nn.Linear(global_cond_dim, global_embed_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(global_embed_dim, global_embed_dim, bias=False)
+ )
+
+ if prepend_cond_dim > 0:
+ # Prepend conditioning
+ self.to_prepend_embed = nn.Sequential(
+ nn.Linear(prepend_cond_dim, embed_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=False)
+ )
+
+ self.input_concat_dim = input_concat_dim
+
+ dim_in = io_channels + self.input_concat_dim
+
+ self.patch_size = patch_size
+
+ # Transformer
+
+ self.transformer_type = transformer_type
+
+ self.global_cond_type = global_cond_type
+
+ if self.transformer_type == "continuous_transformer":
+
+ global_dim = None
+
+ if self.global_cond_type == "adaLN":
+ # The global conditioning is projected to the embed_dim already at this point
+ global_dim = embed_dim
+
+ self.transformer = ContinuousTransformer(
+ dim=embed_dim,
+ depth=depth,
+ dim_heads=embed_dim // num_heads,
+ dim_in=dim_in * patch_size,
+ dim_out=io_channels * patch_size,
+ cross_attend = cond_token_dim > 0,
+ cond_token_dim = cond_embed_dim,
+ global_cond_dim=global_dim,
+ **kwargs
+ )
+ else:
+ raise ValueError(f"Unknown transformer type: {self.transformer_type}")
+
+ self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
+ nn.init.zeros_(self.preprocess_conv.weight)
+ self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
+ nn.init.zeros_(self.postprocess_conv.weight)
+
+ def _forward(
+ self,
+ x,
+ t,
+ mask=None,
+ cross_attn_cond=None,
+ cross_attn_cond_mask=None,
+ input_concat_cond=None,
+ global_embed=None,
+ prepend_cond=None,
+ prepend_cond_mask=None,
+ return_info=False,
+ exit_layer_ix=None,
+ **kwargs):
+
+ if cross_attn_cond is not None:
+ cross_attn_cond = self.to_cond_embed(cross_attn_cond)
+
+ if global_embed is not None:
+ # Project the global conditioning to the embedding dimension
+ global_embed = self.to_global_embed(global_embed)
+
+ prepend_inputs = None
+ prepend_mask = None
+ prepend_length = 0
+ if prepend_cond is not None:
+ # Project the prepend conditioning to the embedding dimension
+ prepend_cond = self.to_prepend_embed(prepend_cond)
+
+ prepend_inputs = prepend_cond
+ if prepend_cond_mask is not None:
+ prepend_mask = prepend_cond_mask
+
+ prepend_length = prepend_cond.shape[1]
+
+ if input_concat_cond is not None:
+ # Interpolate input_concat_cond to the same length as x
+ if input_concat_cond.shape[2] != x.shape[2]:
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
+
+ x = torch.cat([x, input_concat_cond], dim=1)
+
+ # Get the batch of timestep embeddings
+ timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
+
+ # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
+
+ if self.timestep_cond_type == "global":
+ if global_embed is not None:
+ global_embed = global_embed + timestep_embed
+ else:
+ global_embed = timestep_embed
+ elif self.timestep_cond_type == "input_concat":
+ x = torch.cat([x, timestep_embed.unsqueeze(1).expand(-1, -1, x.shape[2])], dim=1)
+
+ # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
+ if self.global_cond_type == "prepend" and global_embed is not None:
+ if prepend_inputs is None:
+ # Prepend inputs are just the global embed, and the mask is all ones
+ prepend_inputs = global_embed.unsqueeze(1)
+ prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
+ else:
+ # Prepend inputs are the prepend conditioning + the global embed
+ prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
+ prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
+
+ prepend_length = prepend_inputs.shape[1]
+
+ x = self.preprocess_conv(x) + x
+
+ x = rearrange(x, "b c t -> b t c")
+
+ extra_args = {}
+
+ if self.global_cond_type == "adaLN":
+ extra_args["global_cond"] = global_embed
+
+ if self.patch_size > 1:
+ x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
+
+ if self.transformer_type == "continuous_transformer":
+ # Masks not currently implemented for continuous transformer
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, return_info=return_info, exit_layer_ix=exit_layer_ix, **extra_args, **kwargs)
+
+ if return_info:
+ output, info = output
+
+ # Avoid postprocessing on early exit
+ if exit_layer_ix is not None:
+ if return_info:
+ return output, info
+ else:
+ return output
+
+ output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
+
+ if self.patch_size > 1:
+ output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
+
+ output = self.postprocess_conv(output) + output
+
+ if return_info:
+ return output, info
+
+ return output
+
+ def forward(
+ self,
+ x,
+ t,
+ cross_attn_cond=None,
+ cross_attn_cond_mask=None,
+ negative_cross_attn_cond=None,
+ negative_cross_attn_mask=None,
+ input_concat_cond=None,
+ global_embed=None,
+ negative_global_embed=None,
+ prepend_cond=None,
+ prepend_cond_mask=None,
+ cfg_scale=1.0,
+ cfg_dropout_prob=0.0,
+ cfg_interval = (0, 1),
+ causal=False,
+ scale_phi=0.0,
+ mask=None,
+ return_info=False,
+ exit_layer_ix=None,
+ **kwargs):
+
+ assert causal == False, "Causal mode is not supported for DiffusionTransformer"
+
+ model_dtype = next(self.parameters()).dtype
+
+ x = x.to(model_dtype)
+
+ t = t.to(model_dtype)
+
+ if cross_attn_cond is not None:
+ cross_attn_cond = cross_attn_cond.to(model_dtype)
+
+ if negative_cross_attn_cond is not None:
+ negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype)
+
+ if input_concat_cond is not None:
+ input_concat_cond = input_concat_cond.to(model_dtype)
+
+ if global_embed is not None:
+ global_embed = global_embed.to(model_dtype)
+
+ if negative_global_embed is not None:
+ negative_global_embed = negative_global_embed.to(model_dtype)
+
+ if prepend_cond is not None:
+ prepend_cond = prepend_cond.to(model_dtype)
+
+ if cross_attn_cond_mask is not None:
+ cross_attn_cond_mask = cross_attn_cond_mask.bool()
+
+ cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
+
+ if prepend_cond_mask is not None:
+ prepend_cond_mask = prepend_cond_mask.bool()
+
+ # Early exit bypasses CFG processing
+ if exit_layer_ix is not None:
+ assert self.transformer_type == "continuous_transformer", "exit_layer_ix is only supported for continuous_transformer"
+ return self._forward(
+ x,
+ t,
+ cross_attn_cond=cross_attn_cond,
+ cross_attn_cond_mask=cross_attn_cond_mask,
+ input_concat_cond=input_concat_cond,
+ global_embed=global_embed,
+ prepend_cond=prepend_cond,
+ prepend_cond_mask=prepend_cond_mask,
+ mask=mask,
+ return_info=return_info,
+ exit_layer_ix=exit_layer_ix,
+ **kwargs
+ )
+
+ # CFG dropout
+ if cfg_dropout_prob > 0.0 and cfg_scale == 1.0:
+ if cross_attn_cond is not None:
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
+ dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
+
+ if prepend_cond is not None:
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
+ dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
+
+ if self.diffusion_objective == "v":
+ sigma = torch.sin(t * math.pi / 2)
+ alpha = torch.cos(t * math.pi / 2)
+ elif self.diffusion_objective in ["rectified_flow", "rf_denoiser"]:
+ sigma = t
+
+ if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None) and (cfg_interval[0] <= sigma[0] <= cfg_interval[1]):
+
+ # Classifier-free guidance
+ # Concatenate conditioned and unconditioned inputs on the batch dimension
+ batch_inputs = torch.cat([x, x], dim=0)
+ batch_timestep = torch.cat([t, t], dim=0)
+
+ if global_embed is not None:
+ batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
+ else:
+ batch_global_cond = None
+
+ if input_concat_cond is not None:
+ batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
+ else:
+ batch_input_concat_cond = None
+
+ batch_cond = None
+ batch_cond_masks = None
+
+ # Handle CFG for cross-attention conditioning
+ if cross_attn_cond is not None:
+
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
+
+ # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
+ if negative_cross_attn_cond is not None:
+
+ # If there's a negative cross-attention mask, set the masked tokens to the null embed
+ if negative_cross_attn_mask is not None:
+ negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
+
+ negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
+
+ batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
+
+ else:
+ batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
+
+ if cross_attn_cond_mask is not None:
+ batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
+
+ batch_prepend_cond = None
+ batch_prepend_cond_mask = None
+
+ if prepend_cond is not None:
+
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
+
+ batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
+
+ if prepend_cond_mask is not None:
+ batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
+
+
+ if mask is not None:
+ batch_masks = torch.cat([mask, mask], dim=0)
+ else:
+ batch_masks = None
+
+ batch_output = self._forward(
+ batch_inputs,
+ batch_timestep,
+ cross_attn_cond=batch_cond,
+ cross_attn_cond_mask=batch_cond_masks,
+ mask = batch_masks,
+ input_concat_cond=batch_input_concat_cond,
+ global_embed = batch_global_cond,
+ prepend_cond = batch_prepend_cond,
+ prepend_cond_mask = batch_prepend_cond_mask,
+ return_info = return_info,
+ **kwargs)
+
+ if return_info:
+ batch_output, info = batch_output
+
+ cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
+ cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
+
+ # CFG Rescale
+ if scale_phi != 0.0:
+ cond_out_std = cond_output.std(dim=1, keepdim=True)
+ out_cfg_std = cfg_output.std(dim=1, keepdim=True)
+ output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
+ else:
+ output = cfg_output
+
+ if return_info:
+ info["uncond_output"] = uncond_output
+ return output, info
+
+ return output
+
+ else:
+ return self._forward(
+ x,
+ t,
+ cross_attn_cond=cross_attn_cond,
+ cross_attn_cond_mask=cross_attn_cond_mask,
+ input_concat_cond=input_concat_cond,
+ global_embed=global_embed,
+ prepend_cond=prepend_cond,
+ prepend_cond_mask=prepend_cond_mask,
+ mask=mask,
+ return_info=return_info,
+ **kwargs
+ )
\ No newline at end of file
diff --git a/ThinkSound/models/dit.py b/ThinkSound/models/dit.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb1e96ba044c9520a6ba5fb5f32848d5aa342b7e
--- /dev/null
+++ b/ThinkSound/models/dit.py
@@ -0,0 +1,547 @@
+import typing as tp
+import math
+import torch
+# from beartype.typing import Tuple
+from einops import rearrange
+from torch import nn
+from torch.nn import functional as F
+from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
+from .blocks import FourierFeatures
+from .transformer import ContinuousTransformer
+from .utils import mask_from_frac_lengths, resample
+
+class DiffusionTransformer(nn.Module):
+ def __init__(self,
+ io_channels=32,
+ patch_size=1,
+ embed_dim=768,
+ cond_token_dim=0,
+ project_cond_tokens=True,
+ global_cond_dim=0,
+ project_global_cond=True,
+ input_concat_dim=0,
+ prepend_cond_dim=0,
+ cond_ctx_dim=0,
+ depth=12,
+ num_heads=8,
+ transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
+ global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
+ timestep_cond_type: tp.Literal["global", "input_concat"] = "global",
+ add_token_dim=0,
+ sync_token_dim=0,
+ use_mlp=False,
+ use_zero_init=False,
+ **kwargs):
+
+ super().__init__()
+
+ self.cond_token_dim = cond_token_dim
+
+ # Timestep embeddings
+ timestep_features_dim = 256
+ # Timestep embeddings
+ self.timestep_cond_type = timestep_cond_type
+ self.timestep_features = FourierFeatures(1, timestep_features_dim)
+
+ if timestep_cond_type == "global":
+ timestep_embed_dim = embed_dim
+ elif timestep_cond_type == "input_concat":
+ assert timestep_embed_dim is not None, "timestep_embed_dim must be specified if timestep_cond_type is input_concat"
+ input_concat_dim += timestep_embed_dim
+
+ self.to_timestep_embed = nn.Sequential(
+ nn.Linear(timestep_features_dim, embed_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=True),
+ )
+ self.use_mlp = use_mlp
+ if cond_token_dim > 0:
+ # Conditioning tokens
+ cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
+ self.to_cond_embed = nn.Sequential(
+ nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
+ )
+ else:
+ cond_embed_dim = 0
+
+ if global_cond_dim > 0:
+ # Global conditioning
+ global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
+ self.to_global_embed = nn.Sequential(
+ nn.Linear(global_cond_dim, global_embed_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(global_embed_dim, global_embed_dim, bias=False)
+ )
+ if add_token_dim > 0:
+ # Conditioning tokens
+ add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim
+ self.to_add_embed = nn.Sequential(
+ nn.Linear(add_token_dim, add_embed_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(add_embed_dim, add_embed_dim, bias=False)
+ )
+ else:
+ add_embed_dim = 0
+
+ if sync_token_dim > 0:
+ # Conditioning tokens
+ sync_embed_dim = sync_token_dim if not project_cond_tokens else embed_dim
+ self.to_sync_embed = nn.Sequential(
+ nn.Linear(sync_token_dim, sync_embed_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(sync_embed_dim, sync_embed_dim, bias=False)
+ )
+ else:
+ sync_embed_dim = 0
+
+
+ if prepend_cond_dim > 0:
+ # Prepend conditioning
+ self.to_prepend_embed = nn.Sequential(
+ nn.Linear(prepend_cond_dim, embed_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=False)
+ )
+
+ self.input_concat_dim = input_concat_dim
+
+ dim_in = io_channels + self.input_concat_dim
+
+ self.patch_size = patch_size
+
+ # Transformer
+
+ self.transformer_type = transformer_type
+
+ self.empty_clip_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
+ self.empty_sync_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True)
+ self.global_cond_type = global_cond_type
+ print("######################")
+ print(f'global type: {global_cond_type}')
+ print("######################")
+ if self.transformer_type == "continuous_transformer":
+
+ global_dim = None
+
+ if self.global_cond_type == "adaLN":
+ # The global conditioning is projected to the embed_dim already at this point
+ global_dim = embed_dim
+
+ self.transformer = ContinuousTransformer(
+ dim=embed_dim,
+ depth=depth,
+ dim_heads=embed_dim // num_heads,
+ dim_in=dim_in * patch_size,
+ dim_out=io_channels * patch_size,
+ cross_attend = cond_token_dim > 0,
+ cond_token_dim = cond_embed_dim,
+ global_cond_dim=global_dim,
+ **kwargs
+ )
+ else:
+ raise ValueError(f"Unknown transformer type: {self.transformer_type}")
+
+ self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
+ self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
+ nn.init.zeros_(self.preprocess_conv.weight)
+ nn.init.zeros_(self.postprocess_conv.weight)
+
+
+ def initialize_weights(self):
+ print("######################")
+ print(f'Fine! You are using zero initialization!')
+ print("######################")
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ # if isinstance(module, nn.Conv1d):
+ # if module.bias is not None:
+ # nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.to_timestep_embed[0].weight, std=0.02)
+ nn.init.normal_(self.to_timestep_embed[2].weight, std=0.02)
+
+ # Zero-out output layers:
+ if self.global_cond_type == "adaLN":
+ for block in self.transformer.layers:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ nn.init.constant_(self.empty_clip_feat, 0)
+ nn.init.constant_(self.empty_sync_feat, 0)
+
+ def _forward(
+ self,
+ x,
+ t,
+ mask=None,
+ cross_attn_cond=None,
+ cross_attn_cond_mask=None,
+ input_concat_cond=None,
+ global_embed=None,
+ prepend_cond=None,
+ prepend_cond_mask=None,
+ add_cond=None,
+ add_masks=None,
+ sync_cond=None,
+ return_info=False,
+ **kwargs):
+
+ if cross_attn_cond is not None:
+ cross_attn_cond = self.to_cond_embed(cross_attn_cond)
+
+ if global_embed is not None:
+ # Project the global conditioning to the embedding dimension
+ global_embed = self.to_global_embed(global_embed)
+
+ prepend_inputs = None
+ prepend_mask = None
+ prepend_length = 0
+ if prepend_cond is not None:
+ # Project the prepend conditioning to the embedding dimension
+ prepend_cond = self.to_prepend_embed(prepend_cond)
+
+ prepend_inputs = prepend_cond
+ if prepend_cond_mask is not None:
+ prepend_mask = prepend_cond_mask
+
+ if input_concat_cond is not None:
+ # reshape from (b, n, c) to (b, c, n)
+ if input_concat_cond.shape[1] != x.shape[1]:
+ input_concat_cond = input_concat_cond.transpose(1,2)
+ # Interpolate input_concat_cond to the same length as x
+ # if input_concat_cond.shape[1] != x.shape[2]:
+ # input_concat_cond = input_concat_cond.transpose(1,2)
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
+ # input_concat_cond = input_concat_cond.transpose(1,2)
+ # if len(global_embed.shape) == 2:
+ # global_embed = global_embed.unsqueeze(1)
+ # global_embed = global_embed + input_concat_cond
+ x = torch.cat([x, input_concat_cond], dim=1)
+
+ # Get the batch of timestep embeddings
+ timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
+ # import ipdb
+ # ipdb.set_trace()
+ # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
+ if self.timestep_cond_type == "global":
+ if global_embed is not None:
+ if len(global_embed.shape) == 3:
+ timestep_embed = timestep_embed.unsqueeze(1)
+ global_embed = global_embed + timestep_embed
+ else:
+ global_embed = timestep_embed
+ elif self.timestep_cond_type == "input_concat":
+ x = torch.cat([x, timestep_embed.unsqueeze(1).expand(-1, -1, x.shape[2])], dim=1)
+
+ # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
+ if self.global_cond_type == "prepend" and global_embed is not None:
+ if prepend_inputs is None:
+ # Prepend inputs are just the global embed, and the mask is all ones
+ if len(global_embed.shape) == 2:
+ prepend_inputs = global_embed.unsqueeze(1)
+ else:
+ prepend_inputs = global_embed
+ prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
+ else:
+ # Prepend inputs are the prepend conditioning + the global embed
+ if len(global_embed.shape) == 2:
+ prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
+ else:
+ prepend_inputs = torch.cat([prepend_inputs, global_embed], dim=1)
+ prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
+
+ prepend_length = prepend_inputs.shape[1]
+
+ x = self.preprocess_conv(x) + x
+ x = rearrange(x, "b c t -> b t c")
+
+ extra_args = {}
+
+ if self.global_cond_type == "adaLN":
+ extra_args["global_cond"] = global_embed
+
+ if self.patch_size > 1:
+ b, seq_len, c = x.shape
+
+ # 计算需要填充的数量
+ pad_amount = (self.patch_size - seq_len % self.patch_size) % self.patch_size
+
+ if pad_amount > 0:
+ # 在时间维度上进行填充
+ x = F.pad(x, (0, 0, 0, pad_amount), mode='constant', value=0)
+ x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
+
+ if add_cond is not None:
+ # Interpolate add_cond to the same length as x
+ # if self.use_mlp:
+ add_cond = self.to_add_embed(add_cond)
+ if add_cond.shape[1] != x.shape[1]:
+ add_cond = add_cond.transpose(1,2)
+ add_cond = F.interpolate(add_cond, (x.shape[1], ), mode='linear', align_corners=False)
+ add_cond = add_cond.transpose(1,2)
+ # add_cond = resample(add_cond, x)
+
+ if sync_cond is not None:
+ sync_cond = self.to_sync_embed(sync_cond)
+
+ if self.transformer_type == "continuous_transformer":
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, sync_cond=sync_cond, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
+
+ if return_info:
+ output, info = output
+
+ output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
+
+ if self.patch_size > 1:
+ output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
+ # 移除之前添加的填充
+ if pad_amount > 0:
+ output = output[:, :, :seq_len]
+
+ output = self.postprocess_conv(output) + output
+
+ if return_info:
+ return output, info
+
+ return output
+
+ def forward(
+ self,
+ x,
+ t,
+ cross_attn_cond=None,
+ cross_attn_cond_mask=None,
+ negative_cross_attn_cond=None,
+ negative_cross_attn_mask=None,
+ input_concat_cond=None,
+ global_embed=None,
+ negative_global_embed=None,
+ prepend_cond=None,
+ prepend_cond_mask=None,
+ add_cond=None,
+ sync_cond=None,
+ cfg_scale=1.0,
+ cfg_dropout_prob=0.0,
+ causal=False,
+ scale_phi=0.0,
+ mask=None,
+ return_info=False,
+ **kwargs):
+
+ assert causal == False, "Causal mode is not supported for DiffusionTransformer"
+ bsz, a, b = x.shape
+ model_dtype = next(self.parameters()).dtype
+ x = x.to(model_dtype)
+ t = t.to(model_dtype)
+
+ if cross_attn_cond is not None:
+ cross_attn_cond = cross_attn_cond.to(model_dtype)
+
+ if negative_cross_attn_cond is not None:
+ negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype)
+
+ if input_concat_cond is not None:
+ input_concat_cond = input_concat_cond.to(model_dtype)
+
+ if global_embed is not None:
+ global_embed = global_embed.to(model_dtype)
+
+ if negative_global_embed is not None:
+ negative_global_embed = negative_global_embed.to(model_dtype)
+
+ if prepend_cond is not None:
+ prepend_cond = prepend_cond.to(model_dtype)
+
+ if add_cond is not None:
+ add_cond = add_cond.to(model_dtype)
+
+ if sync_cond is not None:
+ sync_cond = sync_cond.to(model_dtype)
+
+ if cross_attn_cond_mask is not None:
+ cross_attn_cond_mask = cross_attn_cond_mask.bool()
+
+ cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
+
+ if prepend_cond_mask is not None:
+ prepend_cond_mask = prepend_cond_mask.bool()
+
+
+ # CFG dropout
+ if cfg_dropout_prob > 0.0 and cfg_scale == 1.0:
+ if cross_attn_cond is not None:
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
+ dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
+
+ if prepend_cond is not None:
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
+ dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
+
+ if add_cond is not None:
+ null_embed = torch.zeros_like(add_cond, device=add_cond.device)
+ dropout_mask = torch.bernoulli(torch.full((add_cond.shape[0], 1, 1), cfg_dropout_prob, device=add_cond.device)).to(torch.bool)
+ add_cond = torch.where(dropout_mask, null_embed, add_cond)
+
+ if sync_cond is not None:
+ null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
+ dropout_mask = torch.bernoulli(torch.full((sync_cond.shape[0], 1, 1), cfg_dropout_prob, device=sync_cond.device)).to(torch.bool)
+ sync_cond = torch.where(dropout_mask, null_embed, sync_cond)
+
+ if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None):
+ # Classifier-free guidance
+ # Concatenate conditioned and unconditioned inputs on the batch dimension
+ batch_inputs = torch.cat([x, x], dim=0)
+ batch_timestep = torch.cat([t, t], dim=0)
+ if global_embed is not None and global_embed.shape[0] == bsz:
+ batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
+ elif global_embed is not None:
+ batch_global_cond = global_embed
+ else:
+ batch_global_cond = None
+
+ if input_concat_cond is not None and input_concat_cond.shape[0] == bsz:
+ batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
+ elif input_concat_cond is not None:
+ batch_input_concat_cond = input_concat_cond
+ else:
+ batch_input_concat_cond = None
+
+ batch_cond = None
+ batch_cond_masks = None
+
+ # Handle CFG for cross-attention conditioning
+ if cross_attn_cond is not None and cross_attn_cond.shape[0] == bsz:
+
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
+
+ # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
+ if negative_cross_attn_cond is not None:
+
+ # If there's a negative cross-attention mask, set the masked tokens to the null embed
+ if negative_cross_attn_mask is not None:
+ negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
+
+ negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
+
+ batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
+
+ else:
+ batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
+
+ if cross_attn_cond_mask is not None:
+ batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
+ elif cross_attn_cond is not None:
+ batch_cond = cross_attn_cond
+ else:
+ batch_cond = None
+
+ batch_prepend_cond = None
+ batch_prepend_cond_mask = None
+
+ if prepend_cond is not None and prepend_cond.shape[0] == bsz:
+
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
+
+ batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
+
+ if prepend_cond_mask is not None:
+ batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
+ elif prepend_cond is not None:
+ batch_prepend_cond = prepend_cond
+ else:
+ batch_prepend_cond = None
+
+ batch_add_cond = None
+
+ # Handle CFG for cross-attention conditioning
+ if add_cond is not None and add_cond.shape[0] == bsz:
+
+ null_embed = torch.zeros_like(add_cond, device=add_cond.device)
+
+
+ batch_add_cond = torch.cat([add_cond, null_embed], dim=0)
+ elif add_cond is not None:
+ batch_add_cond = add_cond
+ else:
+ batch_add_cond = None
+
+ batch_sync_cond = None
+
+ # Handle CFG for cross-attention conditioning
+ if sync_cond is not None and sync_cond.shape[0] == bsz:
+
+ null_embed = torch.zeros_like(sync_cond, device=sync_cond.device)
+
+
+ batch_sync_cond = torch.cat([sync_cond, null_embed], dim=0)
+ elif sync_cond is not None:
+ batch_sync_cond = sync_cond
+ else:
+ batch_sync_cond = None
+
+ if mask is not None:
+ batch_masks = torch.cat([mask, mask], dim=0)
+ else:
+ batch_masks = None
+
+ batch_output = self._forward(
+ batch_inputs,
+ batch_timestep,
+ cross_attn_cond=batch_cond,
+ cross_attn_cond_mask=batch_cond_masks,
+ mask = batch_masks,
+ input_concat_cond=batch_input_concat_cond,
+ global_embed = batch_global_cond,
+ prepend_cond = batch_prepend_cond,
+ prepend_cond_mask = batch_prepend_cond_mask,
+ add_cond = batch_add_cond,
+ sync_cond = batch_sync_cond,
+ return_info = return_info,
+ **kwargs)
+
+ if return_info:
+ batch_output, info = batch_output
+
+ cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
+ cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
+
+ # CFG Rescale
+ if scale_phi != 0.0:
+ cond_out_std = cond_output.std(dim=1, keepdim=True)
+ out_cfg_std = cfg_output.std(dim=1, keepdim=True)
+ output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
+ else:
+ output = cfg_output
+
+ if return_info:
+ return output, info
+
+ return output
+
+ else:
+ return self._forward(
+ x,
+ t,
+ cross_attn_cond=cross_attn_cond,
+ cross_attn_cond_mask=cross_attn_cond_mask,
+ input_concat_cond=input_concat_cond,
+ global_embed=global_embed,
+ prepend_cond=prepend_cond,
+ prepend_cond_mask=prepend_cond_mask,
+ add_cond=add_cond,
+ sync_cond=sync_cond,
+ mask=mask,
+ return_info=return_info,
+ **kwargs
+ )
\ No newline at end of file
diff --git a/ThinkSound/models/factory.py b/ThinkSound/models/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..20a44e84aa6f58e7fc128a5d82f97e41ccc809b6
--- /dev/null
+++ b/ThinkSound/models/factory.py
@@ -0,0 +1,156 @@
+import json
+
+def create_model_from_config(model_config):
+ model_type = model_config.get('model_type', None)
+
+ assert model_type is not None, 'model_type must be specified in model config'
+
+ if model_type == 'autoencoder':
+ from .autoencoders import create_autoencoder_from_config
+ return create_autoencoder_from_config(model_config)
+ elif model_type == 'diffusion_uncond':
+ from .diffusion import create_diffusion_uncond_from_config
+ return create_diffusion_uncond_from_config(model_config)
+ # elif model_type == 'diffusion_infill':
+ # from .diffusion import create_diffusion_infill_from_config
+ # return create_diffusion_infill_from_config(model_config)
+ elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior" or model_type == "diffusion_infill" or model_type == "mm_diffusion_cond":
+ from .diffusion import create_diffusion_cond_from_config
+ return create_diffusion_cond_from_config(model_config)
+ elif model_type == 'diffusion_autoencoder':
+ from .autoencoders import create_diffAE_from_config
+ return create_diffAE_from_config(model_config)
+ elif model_type == 'lm':
+ from .lm import create_audio_lm_from_config
+ return create_audio_lm_from_config(model_config)
+ else:
+ raise NotImplementedError(f'Unknown model type: {model_type}')
+
+def create_model_from_config_path(model_config_path):
+ with open(model_config_path) as f:
+ model_config = json.load(f)
+
+ return create_model_from_config(model_config)
+
+def create_pretransform_from_config(pretransform_config, sample_rate):
+ pretransform_type = pretransform_config.get('type', None)
+
+ assert pretransform_type is not None, 'type must be specified in pretransform config'
+
+ if pretransform_type == 'autoencoder':
+ from .autoencoders import create_autoencoder_from_config
+ from .pretransforms import AutoencoderPretransform
+
+ # Create fake top-level config to pass sample rate to autoencoder constructor
+ # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
+ autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
+ autoencoder = create_autoencoder_from_config(autoencoder_config)
+
+ scale = pretransform_config.get("scale", 1.0)
+ model_half = pretransform_config.get("model_half", False)
+ iterate_batch = pretransform_config.get("iterate_batch", False)
+ chunked = pretransform_config.get("chunked", False)
+
+ pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
+ elif pretransform_type == 'wavelet':
+ from .pretransforms import WaveletPretransform
+
+ wavelet_config = pretransform_config["config"]
+ channels = wavelet_config["channels"]
+ levels = wavelet_config["levels"]
+ wavelet = wavelet_config["wavelet"]
+
+ pretransform = WaveletPretransform(channels, levels, wavelet)
+ elif pretransform_type == 'pqmf':
+ from .pretransforms import PQMFPretransform
+ pqmf_config = pretransform_config["config"]
+ pretransform = PQMFPretransform(**pqmf_config)
+ elif pretransform_type == 'dac_pretrained':
+ from .pretransforms import PretrainedDACPretransform
+ pretrained_dac_config = pretransform_config["config"]
+ pretransform = PretrainedDACPretransform(**pretrained_dac_config)
+ elif pretransform_type == "audiocraft_pretrained":
+ from .pretransforms import AudiocraftCompressionPretransform
+
+ audiocraft_config = pretransform_config["config"]
+ pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
+ else:
+ raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
+
+ enable_grad = pretransform_config.get('enable_grad', False)
+ pretransform.enable_grad = enable_grad
+
+ pretransform.eval().requires_grad_(pretransform.enable_grad)
+
+ return pretransform
+
+def create_bottleneck_from_config(bottleneck_config):
+ bottleneck_type = bottleneck_config.get('type', None)
+
+ assert bottleneck_type is not None, 'type must be specified in bottleneck config'
+
+ if bottleneck_type == 'tanh':
+ from .bottleneck import TanhBottleneck
+ bottleneck = TanhBottleneck()
+ elif bottleneck_type == 'vae':
+ from .bottleneck import VAEBottleneck
+ bottleneck = VAEBottleneck()
+ elif bottleneck_type == 'rvq':
+ from .bottleneck import RVQBottleneck
+
+ quantizer_params = {
+ "dim": 128,
+ "codebook_size": 1024,
+ "num_quantizers": 8,
+ "decay": 0.99,
+ "kmeans_init": True,
+ "kmeans_iters": 50,
+ "threshold_ema_dead_code": 2,
+ }
+
+ quantizer_params.update(bottleneck_config["config"])
+
+ bottleneck = RVQBottleneck(**quantizer_params)
+ elif bottleneck_type == "dac_rvq":
+ from .bottleneck import DACRVQBottleneck
+
+ bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
+
+ elif bottleneck_type == 'rvq_vae':
+ from .bottleneck import RVQVAEBottleneck
+
+ quantizer_params = {
+ "dim": 128,
+ "codebook_size": 1024,
+ "num_quantizers": 8,
+ "decay": 0.99,
+ "kmeans_init": True,
+ "kmeans_iters": 50,
+ "threshold_ema_dead_code": 2,
+ }
+
+ quantizer_params.update(bottleneck_config["config"])
+
+ bottleneck = RVQVAEBottleneck(**quantizer_params)
+
+ elif bottleneck_type == 'dac_rvq_vae':
+ from .bottleneck import DACRVQVAEBottleneck
+ bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
+ elif bottleneck_type == 'l2_norm':
+ from .bottleneck import L2Bottleneck
+ bottleneck = L2Bottleneck()
+ elif bottleneck_type == "wasserstein":
+ from .bottleneck import WassersteinBottleneck
+ bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
+ elif bottleneck_type == "fsq":
+ from .bottleneck import FSQBottleneck
+ bottleneck = FSQBottleneck(**bottleneck_config["config"])
+ else:
+ raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
+
+ requires_grad = bottleneck_config.get('requires_grad', True)
+ if not requires_grad:
+ for param in bottleneck.parameters():
+ param.requires_grad = False
+
+ return bottleneck
diff --git a/ThinkSound/models/lm.py b/ThinkSound/models/lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1897fa72ab716f69e0c6d71236e47cc50f78592e
--- /dev/null
+++ b/ThinkSound/models/lm.py
@@ -0,0 +1,541 @@
+from dataclasses import dataclass
+import torch
+from tqdm.auto import trange
+import typing as tp
+from einops import rearrange
+from torch import nn
+
+from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
+from .factory import create_pretransform_from_config
+from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone
+from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform
+from .utils import multinomial, sample_top_k, sample_top_p
+
+from .codebook_patterns import (
+ CodebooksPatternProvider,
+ DelayedPatternProvider,
+ MusicLMPattern,
+ ParallelPatternProvider,
+ UnrolledPatternProvider
+)
+
+# Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license
+# License can be found in LICENSES/LICENSE_META.txt
+
+@dataclass
+class LMOutput:
+ # The logits are already re-aligned with the input codes
+ # hence no extra shift is required, e.g. when computing CE
+ logits: torch.Tensor # [B, K, T, card]
+ mask: torch.Tensor # [B, K, T]
+
+# Wrapper for a multi-codebook language model
+# Handles patterns and quantizer heads
+class AudioLanguageModel(nn.Module):
+ def __init__(
+ self,
+ pattern_provider: CodebooksPatternProvider,
+ backbone: AudioLMBackbone,
+ num_quantizers: int,
+ codebook_size: int
+ ):
+ super().__init__()
+
+ self.pattern_provider = pattern_provider
+ self.backbone = backbone
+ self.num_quantizers = num_quantizers
+ self.codebook_size = codebook_size
+
+ self.masked_token_id = codebook_size
+
+ # Per-quantizer embedders
+ # Add one for the mask embed
+ self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)])
+
+ # Per-quantizer output heads
+ self.quantizer_heads = nn.ModuleList([
+ nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers)
+ ])
+
+ def forward(self,
+ sequence: torch.Tensor, #[batch, seq_len,
+ prepend_cond=None, #[batch, seq, channels]
+ prepend_cond_mask=None,
+ cross_attn_cond=None, #[batch, seq, channels],
+ **kwargs
+ ):
+
+ batch, num_quantizers, seq_len = sequence.shape
+
+ assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model"
+
+ backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim]
+
+ dtype = next(self.parameters()).dtype
+
+ if cross_attn_cond is not None:
+ cross_attn_cond = cross_attn_cond.to(dtype)
+
+ if prepend_cond is not None:
+ prepend_cond = prepend_cond.to(dtype)
+
+ if prepend_cond_mask is not None:
+ prepend_cond_mask = prepend_cond_mask.to(dtype)
+
+ backbone_input = backbone_input.to(dtype)
+
+ output = self.backbone(
+ backbone_input,
+ cross_attn_cond=cross_attn_cond,
+ prepend_cond=prepend_cond,
+ prepend_cond_mask=prepend_cond_mask,
+ **kwargs
+ ) # [batch, seq_len, embed_dim]
+
+ # Run output through quantizer heads
+ logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1) # [batch, num_quantizers, seq_len, codebook_size]
+
+ return logits
+
+ def compute_logits(
+ self,
+ codes, #[batch, num_quantizers, seq_len]
+ **kwargs):
+ """
+ Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning
+ Handles translation between input sequence and pattern-shifted sequence
+ Only used during training
+ """
+
+ batch, _, seq_len = codes.shape
+
+ pattern = self.pattern_provider.get_pattern(seq_len)
+
+ # Apply the token pattern to the codes, shifting the codes as needed and masking out invalid steps
+ shifted_codes, _, _ = pattern.build_pattern_sequence(
+ codes,
+ self.masked_token_id,
+ keep_only_valid_steps=True
+ )
+
+ # Run the model to get logits for each quantizer [batch, num_quantizers, seq_len, codebook_size]
+ logits = self(shifted_codes, **kwargs)
+
+ # Rearrange logits to prepare to revert pattern
+ logits = rearrange(logits, "b n s c -> b c n s")
+
+ # Revert sequence logits back to original sequence length, removing masked steps
+ logits, _, logits_mask = pattern.revert_pattern_logits(
+ logits, float('nan'), keep_only_valid_steps=True
+ )
+
+ logits = rearrange(logits, "b c n t -> b n t c")
+
+ logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len]
+
+ return LMOutput(logits=logits, mask=logits_mask)
+
+# Conditioning and generation wrapper for a multi-codebook language model
+# Handles conditioning, CFG, generation, and encoding/decoding
+class AudioLanguageModelWrapper(nn.Module):
+ def __init__(
+ self,
+ pretransform: Pretransform,
+ lm: AudioLanguageModel,
+ sample_rate: int,
+ min_input_length: int,
+ conditioner: MultiConditioner = None,
+ cross_attn_cond_ids: tp.List[str] = [],
+ prepend_cond_ids: tp.List[str] = [],
+ global_cond_ids: tp.List[str] = []
+ ):
+ super().__init__()
+
+ assert pretransform.is_discrete, "Pretransform must be discrete"
+ self.pretransform = pretransform
+
+ self.pretransform.requires_grad_(False)
+ self.pretransform.eval()
+
+ if isinstance(self.pretransform, AutoencoderPretransform):
+ self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers
+ self.codebook_size = self.pretransform.model.bottleneck.codebook_size
+ elif isinstance(self.pretransform, PretrainedDACPretransform):
+ self.num_quantizers = self.pretransform.model.num_quantizers
+ self.codebook_size = self.pretransform.model.codebook_size
+ elif isinstance(self.pretransform, AudiocraftCompressionPretransform):
+ self.num_quantizers = self.pretransform.num_quantizers
+ self.codebook_size = self.pretransform.codebook_size
+ else:
+ raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}")
+
+ self.conditioner = conditioner
+
+ self.lm = lm
+
+ self.sample_rate = sample_rate
+ self.min_input_length = min_input_length
+
+ self.cross_attn_cond_ids = cross_attn_cond_ids
+ self.prepend_cond_ids = prepend_cond_ids
+ self.global_cond_ids = global_cond_ids
+
+ def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False):
+ cross_attention_input = None
+ prepend_cond = None
+ prepend_cond_mask = None
+ global_cond = None
+
+ if len(self.cross_attn_cond_ids) > 0:
+ # Concatenate all cross-attention inputs over the sequence dimension
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
+ cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1)
+
+ if len(self.prepend_cond_ids) > 0:
+ # Concatenate all prepend conditioning inputs over the sequence dimension
+ # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
+ prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1)
+ prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1)
+
+ if len(self.global_cond_ids) > 0:
+ # Concatenate all global conditioning inputs over the channel dimension
+ # Assumes that the global conditioning inputs are of shape (batch, channels)
+ global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1)
+ if len(global_cond.shape) == 3:
+ global_cond = global_cond.squeeze(1)
+
+ if negative:
+ return {
+ "negative_cross_attn_cond": cross_attention_input,
+ "negative_prepend_cond": prepend_cond,
+ "negative_prepend_cond_mask": prepend_cond_mask,
+ "negative_global_cond": global_cond
+ }
+ else:
+ return {
+ "cross_attn_cond": cross_attention_input,
+ "prepend_cond": prepend_cond,
+ "prepend_cond_mask": prepend_cond_mask,
+ "global_cond": global_cond
+ }
+
+ def compute_logits(
+ self,
+ codes,
+ condition_tensors=None,
+ cfg_dropout_prob=0.0,
+ **kwargs
+ ):
+ """
+ Compute logits for a batch of codes, and translates from conditioning inputs to model inputs
+ Handles CFG dropout
+ """
+
+ if condition_tensors is None:
+ condition_tensors = {}
+
+ conditioning_inputs = self.get_conditioning_inputs(condition_tensors)
+
+ cross_attn_cond = conditioning_inputs["cross_attn_cond"]
+ prepend_cond = conditioning_inputs["prepend_cond"]
+ prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
+ global_cond = conditioning_inputs["global_cond"]
+
+ if cfg_dropout_prob > 0.0:
+ if cross_attn_cond is not None:
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
+ dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
+
+ if prepend_cond is not None:
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
+ dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
+
+ if global_cond is not None:
+ null_embed = torch.zeros_like(global_cond, device=global_cond.device)
+ dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool)
+ global_cond = torch.where(dropout_mask, null_embed, global_cond)
+
+ return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
+
+ def _sample_next_token(
+ self,
+ sequence, #[batch, num_quantizers, seq_len]
+ conditioning_tensors=None,
+ cross_attn_use_cfg=True,
+ prepend_use_cfg=True,
+ global_use_cfg=True,
+ cfg_scale=1.0,
+ top_k=250,
+ top_p=0.0,
+ temp=1.0,
+ **kwargs
+ ):
+ """
+ Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs
+ Handles CFG inference
+ """
+
+ if conditioning_tensors is None:
+ conditioning_tensors = {}
+
+ conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors)
+
+ cross_attn_cond = conditioning_inputs["cross_attn_cond"]
+ prepend_cond = conditioning_inputs["prepend_cond"]
+ prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
+ global_cond = conditioning_inputs["global_cond"]
+
+ if cfg_scale != 1.0:
+
+ # Batch size is doubled to account for negative samples
+ sequence = torch.cat([sequence, sequence], dim=0)
+
+ if cross_attn_cond is not None and cross_attn_use_cfg:
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
+
+ cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
+
+ if prepend_cond is not None and prepend_use_cfg:
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
+
+ prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
+
+ if prepend_cond_mask is not None:
+ prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
+
+ if global_cond is not None and global_use_cfg:
+ null_embed = torch.zeros_like(global_cond, device=global_cond.device)
+
+ global_cond = torch.cat([global_cond, null_embed], dim=0)
+
+ logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
+
+ if cfg_scale != 1.0:
+ cond_logits, uncond_logits = logits.chunk(2, dim=0)
+
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
+
+ logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len]
+
+ # Grab the logits for the last step
+ logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size]
+
+ # Apply top-k or top-p sampling
+
+ if temp > 0:
+ probs = torch.softmax(logits / temp, dim=-1)
+
+ if top_p > 0.0:
+ next_token = sample_top_p(probs, p=top_p)
+ elif top_k > 0:
+ next_token = sample_top_k(probs, k=top_k)
+ else:
+ next_token = multinomial(probs, num_samples=1)
+
+ else:
+ next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1]
+
+ return next_token
+
+ @torch.no_grad()
+ def generate(
+ self,
+ max_gen_len: int = 256,
+ batch_size: tp.Optional[int] = None,
+ init_data: tp.Optional[torch.Tensor] = None,
+ conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
+ conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None,
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None,
+ use_cache: bool = True,
+ cfg_scale: float = 1.0,
+ **kwargs
+ ):
+ device = next(self.parameters()).device
+
+ if conditioning_tensors is None and conditioning is not None:
+ # Convert conditioning inputs to conditioning tensors
+ conditioning_tensors = self.conditioner(conditioning, device)
+
+ # Check that batch size is consistent across inputs
+ possible_batch_sizes = []
+
+ if batch_size is not None:
+ possible_batch_sizes.append(batch_size)
+ elif init_data is not None:
+ possible_batch_sizes.append(init_data.shape[0])
+ elif conditioning_tensors is not None:
+ # Assume that the first conditioning tensor has the batch dimension
+ possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0])
+ else:
+ possible_batch_sizes.append(1)
+
+ assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs"
+
+ batch_size = possible_batch_sizes[0]
+
+ if init_data is None:
+ # Initialize with zeros
+ assert batch_size > 0
+ init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long)
+
+ batch_size, num_quantizers, seq_len = init_data.shape
+
+ start_offset = seq_len
+ assert start_offset < max_gen_len, "init data longer than max gen length"
+
+ pattern = self.lm.pattern_provider.get_pattern(max_gen_len)
+
+ unknown_token = -1
+
+ # Initialize the generated codes with the init data, padded with unknown tokens
+ gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long)
+ gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len]
+
+ gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len]
+
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
+ assert start_offset_sequence is not None
+
+ # Generation
+ prev_offset = 0
+ gen_sequence_len = gen_sequence.shape[-1]
+
+ # Reset generation cache
+ if use_cache and self.lm.backbone.use_generation_cache:
+ self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2)
+
+ for offset in trange(start_offset_sequence, gen_sequence_len):
+
+ # Get the full sequence up to the current offset
+ curr_sequence = gen_sequence[..., prev_offset:offset]
+
+ next_token = self._sample_next_token(
+ curr_sequence,
+ conditioning_tensors=conditioning_tensors,
+ use_cache=use_cache,
+ cfg_scale=cfg_scale,
+ **kwargs
+ )
+
+ valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1)
+ next_token[~valid_mask] = self.lm.masked_token_id
+
+ # Update the generated sequence with the next token
+ gen_sequence[..., offset:offset+1] = torch.where(
+ gen_sequence[..., offset:offset+1] == unknown_token,
+ next_token,
+ gen_sequence[..., offset:offset+1]
+ )
+
+ if use_cache and self.lm.backbone.use_generation_cache:
+ # Only update the offset if caching is being used
+ prev_offset = offset
+
+ self.lm.backbone.update_generation_cache(offset)
+
+ if callback is not None:
+ # Callback to report progress
+ # Pass in the offset relative to the start of the sequence, and the length of the current sequence
+ callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
+
+ assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence"
+
+ out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
+
+ # sanity checks over the returned codes and corresponding masks
+ assert (out_codes[..., :max_gen_len] != unknown_token).all()
+ assert (out_mask[..., :max_gen_len] == 1).all()
+
+ #out_codes = out_codes[..., 0:max_gen_len]
+
+ return out_codes
+
+
+ def generate_audio(
+ self,
+ **kwargs
+ ):
+ """
+ Generate audio from a batch of codes
+ """
+
+ codes = self.generate(**kwargs)
+
+ audio = self.pretransform.decode_tokens(codes)
+
+ return audio
+
+
+def create_audio_lm_from_config(config):
+ model_config = config.get('model', None)
+ assert model_config is not None, 'model config must be specified in config'
+
+ sample_rate = config.get('sample_rate', None)
+ assert sample_rate is not None, "Must specify sample_rate in config"
+
+ lm_config = model_config.get('lm', None)
+ assert lm_config is not None, 'lm config must be specified in model config'
+
+ codebook_pattern = lm_config.get("codebook_pattern", "delay")
+
+ pattern_providers = {
+ 'parallel': ParallelPatternProvider,
+ 'delay': DelayedPatternProvider,
+ 'unroll': UnrolledPatternProvider,
+ 'musiclm': MusicLMPattern,
+ }
+
+ pretransform_config = model_config.get("pretransform", None)
+
+ pretransform = create_pretransform_from_config(pretransform_config, sample_rate)
+
+ assert pretransform.is_discrete, "Pretransform must be discrete"
+
+ min_input_length = pretransform.downsampling_ratio
+
+ pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers)
+
+ conditioning_config = model_config.get('conditioning', None)
+
+ conditioner = None
+ if conditioning_config is not None:
+ conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
+
+ cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', [])
+ prepend_cond_ids = lm_config.get('prepend_cond_ids', [])
+ global_cond_ids = lm_config.get('global_cond_ids', [])
+
+ lm_type = lm_config.get("type", None)
+ lm_model_config = lm_config.get("config", None)
+
+ assert lm_type is not None, "Must specify lm type in lm config"
+ assert lm_model_config is not None, "Must specify lm model config in lm config"
+
+ if lm_type == "x-transformers":
+ backbone = XTransformersAudioLMBackbone(**lm_model_config)
+ elif lm_type == "continuous_transformer":
+ backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config)
+ else:
+ raise NotImplementedError(f"Unrecognized lm type {lm_type}")
+
+ lm = AudioLanguageModel(
+ pattern_provider=pattern_provider,
+ backbone=backbone,
+ num_quantizers=pretransform.num_quantizers,
+ codebook_size=pretransform.codebook_size
+ )
+
+ model = AudioLanguageModelWrapper(
+ pretransform=pretransform,
+ lm=lm,
+ conditioner=conditioner,
+ sample_rate=sample_rate,
+ min_input_length=min_input_length,
+ cross_attn_cond_ids=cross_attn_cond_ids,
+ prepend_cond_ids=prepend_cond_ids,
+ global_cond_ids=global_cond_ids
+ )
+
+ return model
\ No newline at end of file
diff --git a/ThinkSound/models/lm_backbone.py b/ThinkSound/models/lm_backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..c80cce60b06d9b367b114188444b0890a1990b61
--- /dev/null
+++ b/ThinkSound/models/lm_backbone.py
@@ -0,0 +1,159 @@
+from torch import nn
+from x_transformers import ContinuousTransformerWrapper, Decoder
+
+from .transformer import ContinuousTransformer
+
+# Interface for backbone of a language model
+# Handles conditioning and cross-attention
+# Does not have to deal with patterns or quantizer heads
+class AudioLMBackbone(nn.Module):
+ def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs):
+ super().__init__()
+
+ self.embed_dim = embed_dim
+ self.use_generation_cache = use_generation_cache
+
+ def forward(
+ self,
+ x,
+ cross_attn_cond=None,
+ prepend_cond=None,
+ prepend_cond_mask=None,
+ global_cond=None,
+ use_cache=False,
+ **kwargs
+ ):
+ raise NotImplementedError
+
+ def reset_generation_cache(
+ self,
+ max_seq_len,
+ batch_size,
+ dtype=None
+ ):
+ pass
+
+ def update_generation_cache(
+ self,
+ seqlen_offset
+ ):
+ pass
+
+class XTransformersAudioLMBackbone(AudioLMBackbone):
+ def __init__(self,
+ embed_dim: int,
+ cross_attn_cond_dim: int = 0,
+ prepend_cond_dim: int = 0,
+ **kwargs):
+ super().__init__(embed_dim=embed_dim)
+
+ # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
+ self.model = ContinuousTransformerWrapper(
+ dim_in=embed_dim,
+ dim_out=embed_dim,
+ max_seq_len=0, #Not relevant without absolute positional embeds,
+ attn_layers=Decoder(
+ dim=embed_dim,
+ attn_flash = True,
+ cross_attend = cross_attn_cond_dim > 0,
+ zero_init_branch_output=True,
+ use_abs_pos_emb = False,
+ rotary_pos_emb=True,
+ ff_swish = True,
+ ff_glu = True,
+ **kwargs
+ )
+ )
+
+ if prepend_cond_dim > 0:
+ # Prepend conditioning
+ self.to_prepend_embed = nn.Sequential(
+ nn.Linear(prepend_cond_dim, embed_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=False)
+ )
+
+ if cross_attn_cond_dim > 0:
+ # Cross-attention conditioning
+ self.to_cross_attn_embed = nn.Sequential(
+ nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=False)
+ )
+
+ def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):
+
+ prepend_length = 0
+ if prepend_cond is not None:
+ # Project the prepend conditioning to the embedding dimension
+ prepend_cond = self.to_prepend_embed(prepend_cond)
+ prepend_length = prepend_cond.shape[1]
+
+ if prepend_cond_mask is not None:
+ # Cast mask to bool
+ prepend_cond_mask = prepend_cond_mask.bool()
+
+ if cross_attn_cond is not None:
+ # Project the cross-attention conditioning to the embedding dimension
+ cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)
+
+ return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]
+
+class ContinuousTransformerAudioLMBackbone(AudioLMBackbone):
+ def __init__(self,
+ embed_dim: int,
+ cross_attn_cond_dim: int = 0,
+ prepend_cond_dim: int = 0,
+ project_cross_attn_cond: bool = False,
+ **kwargs):
+ super().__init__(embed_dim=embed_dim)
+
+ # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer
+ self.model = ContinuousTransformer(
+ dim=embed_dim,
+ dim_in=embed_dim,
+ dim_out=embed_dim,
+ cross_attend = cross_attn_cond_dim > 0,
+ cond_token_dim = embed_dim if project_cross_attn_cond else cross_attn_cond_dim,
+ causal=True,
+ **kwargs
+ )
+
+ if prepend_cond_dim > 0:
+ # Prepend conditioning
+ self.to_prepend_embed = nn.Sequential(
+ nn.Linear(prepend_cond_dim, embed_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=False)
+ )
+
+ if cross_attn_cond_dim > 0 and project_cross_attn_cond:
+ # Cross-attention conditioning
+ self.to_cross_attn_embed = nn.Sequential(
+ nn.Linear(cross_attn_cond_dim, embed_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(embed_dim, embed_dim, bias=False)
+ )
+ else:
+ self.to_cross_attn_embed = nn.Identity()
+
+ def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False):
+
+ prepend_length = 0
+ if prepend_cond is not None:
+ # Project the prepend conditioning to the embedding dimension
+ prepend_cond = self.to_prepend_embed(prepend_cond)
+ prepend_length = prepend_cond.shape[1]
+
+ if prepend_cond_mask is not None:
+ # Cast mask to bool
+ prepend_cond_mask = prepend_cond_mask.bool()
+
+ if cross_attn_cond is not None:
+ # Cast cross_attn_cond to same dtype as self.to_cross_attn_embed
+ cross_attn_cond = cross_attn_cond.to(self.to_cross_attn_embed[0].weight.dtype)
+
+ # Project the cross-attention conditioning to the embedding dimension
+ cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond)
+
+ return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :]
\ No newline at end of file
diff --git a/ThinkSound/models/lm_continuous.py b/ThinkSound/models/lm_continuous.py
new file mode 100644
index 0000000000000000000000000000000000000000..469bb49f32492794345cf76dafbb377778eca81e
--- /dev/null
+++ b/ThinkSound/models/lm_continuous.py
@@ -0,0 +1,525 @@
+from dataclasses import dataclass
+import torch
+from tqdm.auto import trange
+import typing as tp
+from einops import rearrange
+from torch import nn
+
+from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
+from .factory import create_pretransform_from_config
+from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone
+from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform
+from .utils import multinomial, sample_top_k, sample_top_p
+from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper, create_diffusion_cond_from_config
+
+from .codebook_patterns import (
+ CodebooksPatternProvider,
+ DelayedPatternProvider,
+ MusicLMPattern,
+ ParallelPatternProvider,
+ UnrolledPatternProvider
+)
+
+# Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license
+# License can be found in LICENSES/LICENSE_META.txt
+
+@dataclass
+class LMContinuousOutput:
+ # The logits are already re-aligned with the input codes
+ # hence no extra shift is required, e.g. when computing CE
+ logits: torch.Tensor # [B, K, T, card]
+ mask: torch.Tensor # [B, K, T]
+
+# Wrapper for a multi-codebook language model
+# Handles patterns and quantizer heads
+class AudioLMContinuousModel(nn.Module):
+ def __init__(
+ self,
+ backbone: AudioLMBackbone,
+ ):
+ super().__init__()
+
+ self.backbone = backbone
+
+ def sample_orders(self, bsz):
+ # generate a batch of random generation orders
+ orders = []
+ for _ in range(bsz):
+ order = np.array(list(range(self.seq_len)))
+ np.random.shuffle(order)
+ orders.append(order)
+ orders = torch.Tensor(np.array(orders)).cuda().long()
+ return orders
+
+ def random_masking(self, x, orders):
+ # generate token mask
+ bsz, seq_len, embed_dim = x.shape
+ mask_rate = self.mask_ratio_generator.rvs(1)[0]
+ num_masked_tokens = int(np.ceil(seq_len * mask_rate))
+ mask = torch.zeros(bsz, seq_len, device=x.device)
+ mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
+ src=torch.ones(bsz, seq_len, device=x.device))
+ return mask
+
+ def forward(self,
+ sequence: torch.Tensor, #[batch, seq_len,
+ prepend_cond=None, #[batch, seq, channels]
+ prepend_cond_mask=None,
+ cross_attn_cond=None, #[batch, seq, channels],
+ **kwargs
+ ):
+
+
+ batch, seq_len, dim = sequence.shape
+
+ dtype = next(self.parameters()).dtype
+
+ if cross_attn_cond is not None:
+ cross_attn_cond = cross_attn_cond.to(dtype)
+
+ if prepend_cond is not None:
+ prepend_cond = prepend_cond.to(dtype)
+
+ if prepend_cond_mask is not None:
+ prepend_cond_mask = prepend_cond_mask.to(dtype)
+
+ x = sequence.to(dtype)
+ orders = self.sample_orders(bsz=batch)
+ mask = self.random_masking(x, orders)
+
+ output = self.backbone(
+ x,
+ mask = mask,
+ cross_attn_cond=cross_attn_cond,
+ prepend_cond=prepend_cond,
+ prepend_cond_mask=prepend_cond_mask,
+ **kwargs
+ ) # [batch, seq_len, embed_dim]
+
+
+ return output
+
+# Conditioning and generation wrapper for a multi-codebook language model
+# Handles conditioning, CFG, generation, and encoding/decoding
+class AudioLanguageModelWrapper(nn.Module):
+ def __init__(
+ self,
+ pretransform: Pretransform,
+ lm: AudioLanguageModel,
+ diff: ConditionedDiffusionModelWrapper,
+ sample_rate: int,
+ min_input_length: int,
+ conditioner: MultiConditioner = None,
+ diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
+ cross_attn_cond_ids: tp.List[str] = [],
+ prepend_cond_ids: tp.List[str] = [],
+ global_cond_ids: tp.List[str] = []
+ ):
+ super().__init__()
+
+ assert pretransform.is_discrete, "Pretransform must be discrete"
+ self.pretransform = pretransform
+
+ self.pretransform.requires_grad_(False)
+ self.pretransform.eval()
+ self.diffusion_objective = diffusion_objective
+ print(f'Training in the {diffusion_objective} formulation')
+ if isinstance(self.pretransform, AutoencoderPretransform):
+ self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers
+ self.codebook_size = self.pretransform.model.bottleneck.codebook_size
+ elif isinstance(self.pretransform, PretrainedDACPretransform):
+ self.num_quantizers = self.pretransform.model.num_quantizers
+ self.codebook_size = self.pretransform.model.codebook_size
+ elif isinstance(self.pretransform, AudiocraftCompressionPretransform):
+ self.num_quantizers = self.pretransform.num_quantizers
+ self.codebook_size = self.pretransform.codebook_size
+ else:
+ raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}")
+
+ self.conditioner = conditioner
+
+ self.lm = lm
+
+ self.sample_rate = sample_rate
+ self.min_input_length = min_input_length
+
+ self.cross_attn_cond_ids = cross_attn_cond_ids
+ self.prepend_cond_ids = prepend_cond_ids
+ self.global_cond_ids = global_cond_ids
+
+ def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False):
+ cross_attention_input = None
+ prepend_cond = None
+ prepend_cond_mask = None
+ global_cond = None
+
+ if len(self.cross_attn_cond_ids) > 0:
+ # Concatenate all cross-attention inputs over the sequence dimension
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
+ cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1)
+
+ if len(self.prepend_cond_ids) > 0:
+ # Concatenate all prepend conditioning inputs over the sequence dimension
+ # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
+ prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1)
+ prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1)
+
+ if len(self.global_cond_ids) > 0:
+ # Concatenate all global conditioning inputs over the channel dimension
+ # Assumes that the global conditioning inputs are of shape (batch, channels)
+ global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1)
+ if len(global_cond.shape) == 3:
+ global_cond = global_cond.squeeze(1)
+
+ if negative:
+ return {
+ "negative_cross_attn_cond": cross_attention_input,
+ "negative_prepend_cond": prepend_cond,
+ "negative_prepend_cond_mask": prepend_cond_mask,
+ "negative_global_cond": global_cond
+ }
+ else:
+ return {
+ "cross_attn_cond": cross_attention_input,
+ "prepend_cond": prepend_cond,
+ "prepend_cond_mask": prepend_cond_mask,
+ "global_cond": global_cond
+ }
+
+ def compute_logits(
+ self,
+ audios,
+ condition_tensors=None,
+ cfg_dropout_prob=0.0,
+ **kwargs
+ ):
+ """
+ Compute logits for a batch of codes, and translates from conditioning inputs to model inputs
+ Handles CFG dropout
+ """
+
+ if condition_tensors is None:
+ condition_tensors = {}
+
+ conditioning_inputs = self.get_conditioning_inputs(condition_tensors)
+
+ cross_attn_cond = conditioning_inputs["cross_attn_cond"]
+ prepend_cond = conditioning_inputs["prepend_cond"]
+ prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
+ global_cond = conditioning_inputs["global_cond"]
+
+ if cfg_dropout_prob > 0.0:
+ if cross_attn_cond is not None:
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
+ dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
+
+ if prepend_cond is not None:
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
+ dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
+
+ if global_cond is not None:
+ null_embed = torch.zeros_like(global_cond, device=global_cond.device)
+ dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool)
+ global_cond = torch.where(dropout_mask, null_embed, global_cond)
+
+ return self.lm.forward(audios, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
+
+ def _sample_next_token(
+ self,
+ sequence, #[batch, num_quantizers, seq_len]
+ conditioning_tensors=None,
+ cross_attn_use_cfg=True,
+ prepend_use_cfg=True,
+ global_use_cfg=True,
+ cfg_scale=1.0,
+ top_k=250,
+ top_p=0.0,
+ temp=1.0,
+ **kwargs
+ ):
+ """
+ Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs
+ Handles CFG inference
+ """
+
+ if conditioning_tensors is None:
+ conditioning_tensors = {}
+
+ conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors)
+
+ cross_attn_cond = conditioning_inputs["cross_attn_cond"]
+ prepend_cond = conditioning_inputs["prepend_cond"]
+ prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
+ global_cond = conditioning_inputs["global_cond"]
+
+ if cfg_scale != 1.0:
+
+ # Batch size is doubled to account for negative samples
+ sequence = torch.cat([sequence, sequence], dim=0)
+
+ if cross_attn_cond is not None and cross_attn_use_cfg:
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
+
+ cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
+
+ if prepend_cond is not None and prepend_use_cfg:
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
+
+ prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
+
+ if prepend_cond_mask is not None:
+ prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
+
+ if global_cond is not None and global_use_cfg:
+ null_embed = torch.zeros_like(global_cond, device=global_cond.device)
+
+ global_cond = torch.cat([global_cond, null_embed], dim=0)
+
+ logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
+
+ if cfg_scale != 1.0:
+ cond_logits, uncond_logits = logits.chunk(2, dim=0)
+
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
+
+ logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len]
+
+ # Grab the logits for the last step
+ logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size]
+
+ # Apply top-k or top-p sampling
+
+ if temp > 0:
+ probs = torch.softmax(logits / temp, dim=-1)
+
+ if top_p > 0.0:
+ next_token = sample_top_p(probs, p=top_p)
+ elif top_k > 0:
+ next_token = sample_top_k(probs, k=top_k)
+ else:
+ next_token = multinomial(probs, num_samples=1)
+
+ else:
+ next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1]
+
+ return next_token
+
+ @torch.no_grad()
+ def generate(
+ self,
+ max_gen_len: int = 256,
+ batch_size: tp.Optional[int] = None,
+ init_data: tp.Optional[torch.Tensor] = None,
+ conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
+ conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None,
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None,
+ use_cache: bool = True,
+ cfg_scale: float = 1.0,
+ **kwargs
+ ):
+ device = next(self.parameters()).device
+
+ if conditioning_tensors is None and conditioning is not None:
+ # Convert conditioning inputs to conditioning tensors
+ conditioning_tensors = self.conditioner(conditioning, device)
+
+ # Check that batch size is consistent across inputs
+ possible_batch_sizes = []
+
+ if batch_size is not None:
+ possible_batch_sizes.append(batch_size)
+ elif init_data is not None:
+ possible_batch_sizes.append(init_data.shape[0])
+ elif conditioning_tensors is not None:
+ # Assume that the first conditioning tensor has the batch dimension
+ possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0])
+ else:
+ possible_batch_sizes.append(1)
+
+ assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs"
+
+ batch_size = possible_batch_sizes[0]
+
+ if init_data is None:
+ # Initialize with zeros
+ assert batch_size > 0
+ init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long)
+
+ batch_size, num_quantizers, seq_len = init_data.shape
+
+ start_offset = seq_len
+ assert start_offset < max_gen_len, "init data longer than max gen length"
+
+ pattern = self.lm.pattern_provider.get_pattern(max_gen_len)
+
+ unknown_token = -1
+
+ # Initialize the generated codes with the init data, padded with unknown tokens
+ gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long)
+ gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len]
+
+ gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len]
+
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
+ assert start_offset_sequence is not None
+
+ # Generation
+ prev_offset = 0
+ gen_sequence_len = gen_sequence.shape[-1]
+
+ # Reset generation cache
+ if use_cache and self.lm.backbone.use_generation_cache:
+ self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2)
+
+ for offset in trange(start_offset_sequence, gen_sequence_len):
+
+ # Get the full sequence up to the current offset
+ curr_sequence = gen_sequence[..., prev_offset:offset]
+
+ next_token = self._sample_next_token(
+ curr_sequence,
+ conditioning_tensors=conditioning_tensors,
+ use_cache=use_cache,
+ cfg_scale=cfg_scale,
+ **kwargs
+ )
+
+ valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1)
+ next_token[~valid_mask] = self.lm.masked_token_id
+
+ # Update the generated sequence with the next token
+ gen_sequence[..., offset:offset+1] = torch.where(
+ gen_sequence[..., offset:offset+1] == unknown_token,
+ next_token,
+ gen_sequence[..., offset:offset+1]
+ )
+
+ if use_cache and self.lm.backbone.use_generation_cache:
+ # Only update the offset if caching is being used
+ prev_offset = offset
+
+ self.lm.backbone.update_generation_cache(offset)
+
+ if callback is not None:
+ # Callback to report progress
+ # Pass in the offset relative to the start of the sequence, and the length of the current sequence
+ callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
+
+ assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence"
+
+ out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
+
+ # sanity checks over the returned codes and corresponding masks
+ assert (out_codes[..., :max_gen_len] != unknown_token).all()
+ assert (out_mask[..., :max_gen_len] == 1).all()
+
+ #out_codes = out_codes[..., 0:max_gen_len]
+
+ return out_codes
+
+
+ def generate_audio(
+ self,
+ **kwargs
+ ):
+ """
+ Generate audio from a batch of codes
+ """
+
+ codes = self.generate(**kwargs)
+
+ audio = self.pretransform.decode_tokens(codes)
+
+ return audio
+
+
+def create_audio_lm_continuous_from_config(config):
+ model_config = config.get('model', None)
+ assert model_config is not None, 'model config must be specified in config'
+
+ sample_rate = config.get('sample_rate', None)
+ assert sample_rate is not None, "Must specify sample_rate in config"
+
+ lm_config = model_config.get('lm', None)
+ assert lm_config is not None, 'lm config must be specified in model config'
+
+
+
+ pretransform_config = model_config.get("pretransform", None)
+
+ if pretransform is not None:
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
+ min_input_length = pretransform.downsampling_ratio
+ else:
+ min_input_length = 1
+
+
+ conditioning_config = model_config.get('conditioning', None)
+
+ conditioner = None
+ if conditioning_config is not None:
+ conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
+
+ cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', [])
+ prepend_cond_ids = lm_config.get('prepend_cond_ids', [])
+ global_cond_ids = lm_config.get('global_cond_ids', [])
+
+ lm_type = lm_config.get("type", None)
+ lm_model_config = lm_config.get("config", None)
+
+ assert lm_type is not None, "Must specify lm type in lm config"
+ assert lm_model_config is not None, "Must specify lm model config in lm config"
+
+ if lm_type == "x-transformers":
+ backbone = XTransformersAudioLMBackbone(**lm_model_config)
+ elif lm_type == "continuous_transformer":
+ backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config)
+ else:
+ raise NotImplementedError(f"Unrecognized lm type {lm_type}")
+
+ lm = AudioLanguageModel(
+ pattern_provider=pattern_provider,
+ backbone=backbone,
+ num_quantizers=pretransform.num_quantizers,
+ codebook_size=pretransform.codebook_size
+ )
+
+ diff_config = model_config.get("diffusion", None)
+ diffusion_model = DiTWrapper(**diff_config)
+
+ cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
+ add_cond_ids = diffusion_config.get('add_cond_ids', [])
+ global_cond_ids = diffusion_config.get('global_cond_ids', [])
+ input_concat_ids = diffusion_config.get('input_concat_ids', [])
+ prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
+
+ diff = ConditionedDiffusionModelWrapper(
+ diffusion_model,
+ conditioner=None,
+ min_input_length=min_input_length,
+ sample_rate=sample_rate,
+ cross_attn_cond_ids=cross_attention_ids,
+ global_cond_ids=global_cond_ids,
+ input_concat_ids=input_concat_ids,
+ prepend_cond_ids=prepend_cond_ids,
+ add_cond_ids=add_cond_ids,
+ pretransform=pretransform,
+ io_channels=2,
+ )
+
+
+ model = AudioLanguageModelWrapper(
+ pretransform=pretransform,
+ lm=lm,
+ diff=diff,
+ conditioner=conditioner,
+ sample_rate=sample_rate,
+ min_input_length=min_input_length,
+ cross_attn_cond_ids=cross_attn_cond_ids,
+ prepend_cond_ids=prepend_cond_ids,
+ global_cond_ids=global_cond_ids
+ )
+
+ return model
\ No newline at end of file
diff --git a/ThinkSound/models/local_attention.py b/ThinkSound/models/local_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..893ce11fce1f263dd02ff2a2ebe8b5e67426f83f
--- /dev/null
+++ b/ThinkSound/models/local_attention.py
@@ -0,0 +1,278 @@
+import torch
+
+from einops import rearrange
+from torch import nn
+
+from .blocks import AdaRMSNorm
+from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
+
+def checkpoint(function, *args, **kwargs):
+ kwargs.setdefault("use_reentrant", False)
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
+
+# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
+class ContinuousLocalTransformer(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth,
+ dim_in = None,
+ dim_out = None,
+ causal = False,
+ local_attn_window_size = 64,
+ heads = 8,
+ ff_mult = 2,
+ cond_dim = 0,
+ cross_attn_cond_dim = 0,
+ **kwargs
+ ):
+ super().__init__()
+
+ dim_head = dim//heads
+
+ self.layers = nn.ModuleList([])
+
+ self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
+
+ self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
+
+ self.local_attn_window_size = local_attn_window_size
+
+ self.cond_dim = cond_dim
+
+ self.cross_attn_cond_dim = cross_attn_cond_dim
+
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
+
+ for _ in range(depth):
+
+ self.layers.append(nn.ModuleList([
+ AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
+ Attention(
+ dim=dim,
+ dim_heads=dim_head,
+ causal=causal,
+ zero_init_output=True,
+ natten_kernel_size=local_attn_window_size,
+ ),
+ Attention(
+ dim=dim,
+ dim_heads=dim_head,
+ dim_context = cross_attn_cond_dim,
+ zero_init_output=True
+ ) if self.cross_attn_cond_dim > 0 else nn.Identity(),
+ AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
+ FeedForward(dim = dim, mult = ff_mult, no_bias=True)
+ ]))
+
+ def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
+
+ x = checkpoint(self.project_in, x)
+
+ if prepend_cond is not None:
+ x = torch.cat([prepend_cond, x], dim=1)
+
+ pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
+
+ for attn_norm, attn, xattn, ff_norm, ff in self.layers:
+
+ residual = x
+ if cond is not None:
+ x = checkpoint(attn_norm, x, cond)
+ else:
+ x = checkpoint(attn_norm, x)
+
+ x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
+
+ if cross_attn_cond is not None:
+ x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
+
+ residual = x
+
+ if cond is not None:
+ x = checkpoint(ff_norm, x, cond)
+ else:
+ x = checkpoint(ff_norm, x)
+
+ x = checkpoint(ff, x) + residual
+
+ return checkpoint(self.project_out, x)
+
+class TransformerDownsampleBlock1D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ embed_dim = 768,
+ depth = 3,
+ heads = 12,
+ downsample_ratio = 2,
+ local_attn_window_size = 64,
+ **kwargs
+ ):
+ super().__init__()
+
+ self.downsample_ratio = downsample_ratio
+
+ self.transformer = ContinuousLocalTransformer(
+ dim=embed_dim,
+ depth=depth,
+ heads=heads,
+ local_attn_window_size=local_attn_window_size,
+ **kwargs
+ )
+
+ self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
+
+ self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
+
+
+ def forward(self, x):
+
+ x = checkpoint(self.project_in, x)
+
+ # Compute
+ x = self.transformer(x)
+
+ # Trade sequence length for channels
+ x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
+
+ # Project back to embed dim
+ x = checkpoint(self.project_down, x)
+
+ return x
+
+class TransformerUpsampleBlock1D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ embed_dim,
+ depth = 3,
+ heads = 12,
+ upsample_ratio = 2,
+ local_attn_window_size = 64,
+ **kwargs
+ ):
+ super().__init__()
+
+ self.upsample_ratio = upsample_ratio
+
+ self.transformer = ContinuousLocalTransformer(
+ dim=embed_dim,
+ depth=depth,
+ heads=heads,
+ local_attn_window_size = local_attn_window_size,
+ **kwargs
+ )
+
+ self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
+
+ self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
+
+ def forward(self, x):
+
+ # Project to embed dim
+ x = checkpoint(self.project_in, x)
+
+ # Project to increase channel dim
+ x = checkpoint(self.project_up, x)
+
+ # Trade channels for sequence length
+ x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
+
+ # Compute
+ x = self.transformer(x)
+
+ return x
+
+
+class TransformerEncoder1D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ embed_dims = [96, 192, 384, 768],
+ heads = [12, 12, 12, 12],
+ depths = [3, 3, 3, 3],
+ ratios = [2, 2, 2, 2],
+ local_attn_window_size = 64,
+ **kwargs
+ ):
+ super().__init__()
+
+ layers = []
+
+ for layer in range(len(depths)):
+ prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
+
+ layers.append(
+ TransformerDownsampleBlock1D(
+ in_channels = prev_dim,
+ embed_dim = embed_dims[layer],
+ heads = heads[layer],
+ depth = depths[layer],
+ downsample_ratio = ratios[layer],
+ local_attn_window_size = local_attn_window_size,
+ **kwargs
+ )
+ )
+
+ self.layers = nn.Sequential(*layers)
+
+ self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
+ self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
+
+ def forward(self, x):
+ x = rearrange(x, "b c n -> b n c")
+ x = checkpoint(self.project_in, x)
+ x = self.layers(x)
+ x = checkpoint(self.project_out, x)
+ x = rearrange(x, "b n c -> b c n")
+
+ return x
+
+
+class TransformerDecoder1D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ embed_dims = [768, 384, 192, 96],
+ heads = [12, 12, 12, 12],
+ depths = [3, 3, 3, 3],
+ ratios = [2, 2, 2, 2],
+ local_attn_window_size = 64,
+ **kwargs
+ ):
+
+ super().__init__()
+
+ layers = []
+
+ for layer in range(len(depths)):
+ prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
+
+ layers.append(
+ TransformerUpsampleBlock1D(
+ in_channels = prev_dim,
+ embed_dim = embed_dims[layer],
+ heads = heads[layer],
+ depth = depths[layer],
+ upsample_ratio = ratios[layer],
+ local_attn_window_size = local_attn_window_size,
+ **kwargs
+ )
+ )
+
+ self.layers = nn.Sequential(*layers)
+
+ self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
+ self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
+
+ def forward(self, x):
+ x = rearrange(x, "b c n -> b n c")
+ x = checkpoint(self.project_in, x)
+ x = self.layers(x)
+ x = checkpoint(self.project_out, x)
+ x = rearrange(x, "b n c -> b c n")
+ return x
\ No newline at end of file
diff --git a/ThinkSound/models/meta_queries/__init__.py b/ThinkSound/models/meta_queries/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ThinkSound/models/meta_queries/metaquery.py b/ThinkSound/models/meta_queries/metaquery.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ce46be39f25f4b6902ee73147d2793316b12448
--- /dev/null
+++ b/ThinkSound/models/meta_queries/metaquery.py
@@ -0,0 +1,435 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Union, List
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from diffusers.models import AutoencoderKL, AutoencoderDC
+from diffusers.pipelines.pipeline_utils import numpy_to_pil
+from diffusers.schedulers import (
+ DDPMScheduler,
+ FlowMatchEulerDiscreteScheduler,
+ DPMSolverMultistepScheduler,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from transformers import PreTrainedModel
+import PIL
+from tqdm import tqdm
+
+from .model import MLLMInContextConfig, MLLMInContext
+from diffusers.training_utils import (
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+)
+
+
+class MetaQueryConfig(MLLMInContextConfig):
+ model_type = "metaquery"
+
+ def __init__(
+ self,
+ vae_id: str = "Efficient-Large-Model/Sana_1600M_512px_diffusers",
+ input_size: int = 16,
+ in_channels: int = 32,
+ vae_downsample_f: int = 32,
+ noise_scheduler_id: str = "Efficient-Large-Model/Sana_1600M_512px_diffusers",
+ scheduler_id: str = "Efficient-Large-Model/Sana_1600M_512px_diffusers",
+ _gradient_checkpointing: bool = True,
+ loss_type: str = "flow",
+ num_metaqueries: int = 64,
+ modules_to_freeze: tuple[str] = (),
+ modules_to_unfreeze: tuple[str] = (),
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+ self.vae_id = vae_id
+ self.input_size = input_size
+ self.in_channels = in_channels
+ self.vae_downsample_f = vae_downsample_f
+ self.noise_scheduler_id = noise_scheduler_id
+ self.scheduler_id = scheduler_id
+ self._gradient_checkpointing = _gradient_checkpointing
+ self.loss_type = loss_type
+ self.num_metaqueries = num_metaqueries
+ self.modules_to_freeze = modules_to_freeze
+ self.modules_to_unfreeze = modules_to_unfreeze
+
+
+class MetaQuery(PreTrainedModel):
+ config_class = MetaQueryConfig
+
+ def __init__(self, config, *args, **kwargs):
+ super().__init__(config, *args, **kwargs)
+ self.config = config
+
+ self.model = MLLMInContext(MLLMInContextConfig(**config.to_dict()))
+ self.loss_type = config.loss_type
+
+ if "Sana" in config.vae_id:
+ self.vae = AutoencoderDC.from_pretrained(config.vae_id, subfolder="vae")
+ else:
+ try:
+ self.vae = AutoencoderKL.from_pretrained(config.vae_id)
+ except:
+ self.vae = AutoencoderKL.from_pretrained(config.vae_id, subfolder="vae")
+
+ if self.loss_type == "flow":
+ self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ config.noise_scheduler_id, subfolder="scheduler"
+ )
+ elif self.loss_type == "diff":
+ self.noise_scheduler = DDPMScheduler.from_pretrained(
+ config.noise_scheduler_id, subfolder="scheduler"
+ )
+ else:
+ raise ValueError(f"Unknown loss type {self.loss_type}")
+
+ self.scheduler = DPMSolverMultistepScheduler.from_pretrained(
+ config.scheduler_id, subfolder="scheduler"
+ )
+
+ for module_name in config.modules_to_freeze:
+ if "." in module_name:
+ module = self
+ for sub_module_name in module_name.split("."):
+ module = getattr(module, sub_module_name, None)
+ if module is None:
+ break
+ else:
+ module.requires_grad_(False)
+ else:
+ module = getattr(self, module_name, None)
+ if module is not None:
+ module.requires_grad_(False)
+
+ for module_name in config.modules_to_unfreeze:
+ if "." in module_name:
+ module = self
+ for sub_module_name in module_name.split("."):
+ module = getattr(module, sub_module_name, None)
+ if module is None:
+ break
+ else:
+ module.requires_grad_(True)
+ else:
+ module = getattr(self, module_name, None)
+ if module is not None:
+ module.requires_grad_(True)
+
+ def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
+ sigmas = self.noise_scheduler.sigmas.to(device=device, dtype=dtype)
+ schedule_timesteps = self.noise_scheduler.timesteps.to(device)
+ timesteps = timesteps.to(device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ def get_tokenizer(self):
+ return self.model.get_tokenizer()
+
+ def get_tokenize_fn(self):
+ return self.model.get_tokenize_fn()
+
+ def forward(
+ self, target, pixel_values=None, input_ids=None, attention_mask=None, **kwargs
+ ):
+ if self.vae is not None:
+ if isinstance(self.vae, AutoencoderKL):
+ latents = self.vae.encode(target).latent_dist.sample()
+ elif isinstance(self.vae, AutoencoderDC):
+ latents = self.vae.encode(target).latent
+ else:
+ raise ValueError(f"Unknown vae type {type(self.vae)}")
+ if (
+ "shift_factor" in self.vae.config
+ and self.vae.config.shift_factor is not None
+ ):
+ latents = latents - self.vae.config.shift_factor
+ latents = latents * self.vae.config.scaling_factor
+ else:
+ latents = target
+
+ bsz = latents.shape[0]
+
+ if (
+ pixel_values is not None
+ and hasattr(self.model, "mllm_type")
+ and self.model.mllm_type == "qwenvl"
+ ):
+ pixel_values = pixel_values.squeeze(0)
+
+ noise = torch.randn_like(latents, device=latents.device)
+
+ if self.loss_type == "flow":
+ weighting_scheme = "uniform"
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=weighting_scheme,
+ batch_size=bsz,
+ logit_mean=0.0,
+ logit_std=1.0,
+ mode_scale=1.29,
+ )
+ indices = (u * self.noise_scheduler.config.num_train_timesteps).long()
+ timesteps = self.noise_scheduler.timesteps[indices].to(
+ device=latents.device
+ )
+
+ sigmas = self.get_sigmas(
+ timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype
+ )
+ noisy_latents = (1.0 - sigmas) * latents + sigmas * noise
+ prompt_embeds, attention_mask = self.model.encode_condition(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ pixel_values=pixel_values,
+ image_sizes=kwargs.get("image_sizes", None),
+ )
+
+ model_pred = self.model(
+ x=noisy_latents,
+ timestep=timesteps,
+ prompt_embeds=prompt_embeds,
+ attention_mask=attention_mask,
+ )
+
+ target = noise - latents
+ weighting = compute_loss_weighting_for_sd3(
+ weighting_scheme=weighting_scheme, sigmas=sigmas
+ )
+ loss = torch.mean(
+ (
+ weighting.float() * (model_pred.float() - target.float()) ** 2
+ ).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ elif self.loss_type == "diff":
+ # Sample a random timestep for each image
+ timesteps = torch.randint(
+ 0,
+ self.noise_scheduler.config.num_train_timesteps,
+ (bsz,),
+ device=latents.device,
+ )
+ timesteps = timesteps.long()
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
+
+ if self.noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
+ target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(
+ f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
+ )
+
+ prompt_embeds, attention_mask = self.model.encode_condition(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ pixel_values=pixel_values,
+ image_sizes=kwargs.get("image_sizes", None),
+ )
+
+ noise_pred = self.model(
+ x=noisy_latents,
+ timestep=timesteps,
+ prompt_embeds=prompt_embeds,
+ attention_mask=attention_mask,
+ )
+ loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+ return {"loss": loss}
+
+ @torch.no_grad()
+ def decode_latents(self, latents, normalize=True, return_tensor=False):
+ if self.vae is not None:
+ latents = latents / self.vae.config.scaling_factor
+ if (
+ "shift_factor" in self.vae.config
+ and self.vae.config.shift_factor is not None
+ ):
+ latents = latents + self.vae.config.shift_factor
+ samples = self.vae.decode(latents).sample
+ else:
+ samples = latents
+ if normalize:
+ samples = (samples / 2 + 0.5).clamp(0, 1)
+ else:
+ samples = samples.clamp(-1, 1)
+ if return_tensor:
+ return samples
+ samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
+ samples = numpy_to_pil(samples)
+ return samples
+
+ def sample_images(
+ self,
+ caption="",
+ input_images=None,
+ guidance_scale: float = 3.0,
+ image_guidance_scale: float = 1.5,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ num_inference_steps: int = 30,
+ num_images_per_prompt: int = 1,
+ return_tensor=False,
+ negative_prompt="",
+ enable_progress_bar=False,
+ **kwargs,
+ ):
+ device = next(self.parameters()).device
+
+ if not isinstance(caption, list):
+ caption = [caption]
+ if input_images is not None:
+ if isinstance(input_images, list) and not isinstance(input_images[0], list):
+ input_images = [[img] for img in input_images]
+ elif isinstance(input_images, PIL.Image.Image):
+ input_images = [[input_images]]
+ assert isinstance(input_images, list) and all(
+ isinstance(sublist, list) for sublist in input_images
+ ), "input_images needs to be a nested list"
+
+ bsz = len(caption)
+ do_image_classifier_free_guidance = image_guidance_scale > 1.0
+
+ tokenize_func = self.get_tokenize_fn()
+ tokenizer = self.get_tokenizer()
+
+ if input_images is not None:
+ if do_image_classifier_free_guidance:
+ caption = [negative_prompt] * bsz * 2 + caption
+ input_images_null = [
+ (
+ [
+ PIL.Image.new("RGB", (img.size[0], img.size[1]))
+ for img in images
+ ]
+ if images
+ else None
+ )
+ for images in input_images
+ ]
+ input_images = input_images_null + input_images * 2
+ else:
+ caption = [negative_prompt] * bsz + caption
+ input_images = input_images * 2
+ input_ids, attention_mask, pixel_values, image_sizes = tokenize_func(
+ tokenizer, caption, input_images
+ )
+ else:
+ do_image_classifier_free_guidance = False
+ caption = [negative_prompt] * bsz + caption
+ input_ids, attention_mask = tokenize_func(tokenizer, caption)
+ pixel_values = None
+ image_sizes = None
+
+ latent_size = self.config.input_size
+ latent_channels = self.config.in_channels
+
+ latents = randn_tensor(
+ shape=(
+ bsz * num_images_per_prompt,
+ latent_channels,
+ latent_size,
+ latent_size,
+ ),
+ generator=generator,
+ device=device,
+ dtype=torch.float32,
+ )
+
+ # set step values
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ self.scheduler.set_timesteps(num_inference_steps, sigmas=sigmas)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # Repeat pixel_values and conditions for each image per prompt
+ input_ids = input_ids.to(device=device).repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ attention_mask = attention_mask.to(device=device).repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ pixel_values = (
+ pixel_values.to(device=device)
+ .reshape(bsz, -1, *pixel_values.shape[1:])
+ .repeat_interleave(num_images_per_prompt, dim=0)
+ .flatten(0, 1)
+ if pixel_values is not None
+ else None
+ )
+ image_sizes = (
+ image_sizes.to(device=device).repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ if image_sizes is not None
+ else None
+ )
+
+ prompt_embeds, attention_mask = self.model.encode_condition(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ pixel_values=pixel_values,
+ image_sizes=image_sizes,
+ )
+ # Convert to float32 before saving
+ for t in tqdm(
+ self.scheduler.timesteps,
+ desc="Sampling images",
+ disable=not enable_progress_bar,
+ ):
+ latent_model_input = torch.cat([latents] * (len(input_ids) // len(latents)))
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+ if hasattr(self.scheduler, "scale_model_input"):
+ latent_model_input = self.scheduler.scale_model_input(
+ latent_model_input, t
+ )
+
+ # predict noise model_output
+ noise_pred = self.model(
+ x=latent_model_input,
+ timestep=t.unsqueeze(0)
+ .expand(latent_model_input.shape[0])
+ .to(latents.device),
+ prompt_embeds=prompt_embeds,
+ attention_mask=attention_mask,
+ )
+
+ # perform guidance
+ if do_image_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_uncond_text, noise_pred = (
+ noise_pred.chunk(3)
+ )
+ noise_pred = (
+ noise_pred_uncond
+ + image_guidance_scale
+ * (noise_pred_uncond_text - noise_pred_uncond)
+ + guidance_scale * (noise_pred - noise_pred_uncond_text)
+ )
+ else:
+ noise_pred_uncond, noise_pred = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred - noise_pred_uncond
+ )
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ samples = self.decode_latents(
+ latents.to(self.vae.dtype) if self.vae is not None else latents,
+ return_tensor=return_tensor,
+ )
+ return samples
diff --git a/ThinkSound/models/meta_queries/model.py b/ThinkSound/models/meta_queries/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3fd49111724cba138cfa151124d3d91b4877235
--- /dev/null
+++ b/ThinkSound/models/meta_queries/model.py
@@ -0,0 +1,578 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import random
+import math
+from typing import List
+import typing as tp
+import torch
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+USE_AUDIO_IN_VIDEO_RATIO = 1.0
+
+from torch import nn
+from torchvision import transforms as v2
+
+from transformers import PretrainedConfig, PreTrainedModel, AutoProcessor, Qwen2Config
+
+import time
+from diffusers.models.normalization import RMSNorm
+
+from transformers.video_utils import load_video
+
+from .transformer_encoder import Qwen2Encoder
+
+from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
+def to_device(
+ data: Any,
+ device: Union[str, torch.device, int],
+ dtype: Optional[torch.dtype] = None, # 新增
+ non_blocking: bool = False
+) -> Any:
+ """Move inputs to a device and optionally convert dtype"""
+ if isinstance(data, Mapping):
+ return type(data)({
+ k: to_device(v, device, dtype, non_blocking)
+ for k, v in data.items()
+ })
+ elif isinstance(data, (tuple, list)):
+ return type(data)(
+ to_device(v, device, dtype, non_blocking)
+ for v in data
+ )
+ elif isinstance(data, torch.Tensor):
+ tensor = data.to(device=device, non_blocking=non_blocking)
+ if dtype is not None and tensor.is_floating_point():
+ tensor = tensor.to(dtype=dtype)
+ return tensor
+ else:
+ return data
+
+VIDEO_MIN_PIXELS=224*224
+VIDEO_MAX_PIXELS=224*224
+
+import torch.distributed as dist
+def print_memory_summary(prefix: str = ""):
+
+ if not torch.cuda.is_available():
+ return
+
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ device = torch.cuda.current_device()
+
+ allocated = torch.cuda.memory_allocated(device) / 1024**3
+ total = torch.cuda.get_device_properties(device).total_memory / 1024**3
+ usage = (allocated / total * 100) if total > 0 else 0
+
+ print(f"[Rank {rank}] {prefix} | GPU Memory: {allocated:.2f}/{total:.2f} GB ({usage:.1f}%)")
+
+
+class MLLMInContextConfig(PretrainedConfig):
+ model_type = "mllm-in-context"
+
+ def __init__(
+ self,
+ mllm_id: str = "Qwen/Qwen2.5-VL-3B-Instruct",
+ diffusion_model_id: str = None,
+ num_metaqueries: int = None,
+ _gradient_checkpointing: bool = True,
+ max_input_text_tokens: int = 2560,
+ connector_num_hidden_layers: int = None,
+ system_prompt: str = "You will be given an video and its caption. Please describe the content of the video in detail in your own words.",
+ **kwargs,
+ ):
+ super().__init__()
+ self.mllm_id = mllm_id
+ self.diffusion_model_id = diffusion_model_id
+ self.num_metaqueries = num_metaqueries
+ self._gradient_checkpointing = _gradient_checkpointing
+ self.max_input_text_tokens = max_input_text_tokens
+ self.connector_num_hidden_layers = connector_num_hidden_layers
+ self.system_prompt = system_prompt
+
+import numpy as np
+import torchvision.transforms as T
+import torch.nn.functional as F
+
+
+
+default_config = MLLMInContextConfig()
+
+class MLLMInContext(PreTrainedModel):
+
+ def __init__(
+ self,
+ output_dim: int,
+ query_len: int,
+ llm_id = "qwen_omni",
+ connection_layers=12,
+ config: MLLMInContextConfig = default_config,
+ ) -> None:
+ super().__init__(config)
+ self._gradient_checkpointing = config._gradient_checkpointing
+ self.config = config
+ config.num_metaqueries = query_len
+ config.connector_num_hidden_layers = connection_layers
+ print("use meta queries: ",query_len,flush=True)
+
+ if llm_id == "qwen_vl":
+ config.mllm_id = "Qwen/Qwen2.5-VL-3B-Instruct"
+ elif llm_id == "qwen_omni":
+ config.mllm_id = "Qwen/Qwen2.5-Omni-3B"
+ else:
+ raise ValueError(f"Unsupported model: {llm_id}")
+
+ if "Qwen2.5-VL" in config.mllm_id:
+ from .models.qwen25VL import (
+ Qwen2_5_VLForConditionalGeneration
+ )
+ self.mllm_type = "qwenvl"
+ elif "Qwen2.5-Omni" in config.mllm_id:
+ from .models.qwen25omni import (
+ Qwen2_5OmniForConditionalGeneration
+ )
+ self.mllm_type = "qwenomni"
+ elif "Qwen" in config.mllm_id:
+ self.mllm_type = "qwenlm"
+ elif "Llama" in config.mllm_id:
+ self.mllm_type = "llamaml"
+ else:
+ self.mllm_type = "llavaov"
+
+ if self.mllm_type == "qwenvl":
+ self.mllm_backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ config.mllm_id, attn_implementation="flash_attention_2",torch_dtype=torch.bfloat16
+ )
+ self.mllm_backbone.model.config.use_sliding_window = False
+ self.mllm_backbone.model.config.sliding_window = None
+ #print(self.mllm_backbone.model)
+
+
+ self._freeze_mllm_backbone()
+
+ num_embeddings = self.mllm_backbone.get_input_embeddings().num_embeddings
+ self.num_embeddings = num_embeddings
+ if config.num_metaqueries > 0:
+ try:
+ self.mllm_backbone.resize_token_embeddings(
+ num_embeddings + config.num_metaqueries + 2
+ )
+ except:
+ self.mllm_backbone.resize_token_embeddings(
+ num_embeddings + config.num_metaqueries + 2, mean_resizing=False
+ )
+
+ def freeze_hook(grad):
+ grad[: self.num_embeddings].zero_()
+ return grad
+
+ self.mllm_backbone.model.embed_tokens.weight.register_hook(freeze_hook)
+ self.mllm_hidden_size = self.mllm_backbone.config.hidden_size
+ self.mllm_backbone.lm_head = nn.Identity()
+
+ self.tokenizer = AutoProcessor.from_pretrained(
+ config.mllm_id, video_min_pixels=VIDEO_MIN_PIXELS, video_max_pixels=VIDEO_MAX_PIXELS,use_fast=True,min_pixels=224*224,max_pixels=288*288
+ )
+ self.tokenizer.tokenizer.padding_side = "left"
+ self.tokenizer.resize_fn = None
+ #self.tokenizer.image_processor.size = {
+ # "height": 224,
+ # "width": 224
+ #}
+ # 3B 2048
+ # 7B 3584
+ self.tokenizer.system_prompt = config.system_prompt
+ elif self.mllm_type == "qwenomni":
+ self.mllm_backbone = Qwen2_5OmniForConditionalGeneration.from_pretrained(
+ config.mllm_id, attn_implementation="flash_attention_2",torch_dtype=torch.bfloat16
+ )
+ #self.mllm_backbone.disable_talker()
+ self.mllm_backbone.thinker.model.config.use_sliding_window = False
+ self.mllm_backbone.thinker.model.config.sliding_window = None
+ self._freeze_mllm_backbone()
+
+ num_embeddings = self.mllm_backbone.thinker.get_input_embeddings().num_embeddings
+ self.num_embeddings = num_embeddings
+ if config.num_metaqueries > 0:
+ try:
+ self.mllm_backbone.thinker.resize_token_embeddings(
+ num_embeddings + config.num_metaqueries + 2
+ )
+ except:
+ self.mllm_backbone.thinker.resize_token_embeddings(
+ num_embeddings + config.num_metaqueries + 2, mean_resizing=False
+ )
+
+ def freeze_hook(grad):
+ grad[: self.num_embeddings].zero_()
+ return grad
+
+ self.mllm_backbone.thinker.model.embed_tokens.weight.register_hook(freeze_hook)
+ self.mllm_hidden_size = self.mllm_backbone.thinker.model.config.hidden_size
+ self.mllm_backbone.thinker.lm_head = nn.Identity()
+
+ self.tokenizer = AutoProcessor.from_pretrained(
+ config.mllm_id, video_min_pixels=VIDEO_MIN_PIXELS, video_max_pixels=VIDEO_MAX_PIXELS,use_fast=True,min_pixels=224*224,max_pixels=288*288
+ )
+ self.tokenizer.tokenizer.padding_side = "left"
+ self.tokenizer.resize_fn = None
+ self.tokenizer.system_prompt = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."
+
+ else:
+ raise ValueError(f"Unsupported model: {self.mllm_type}")
+
+
+
+ self.tokenizer.mllm_type = self.mllm_type
+ self.tokenizer.max_input_text_tokens = config.max_input_text_tokens
+ self.tokenizer.num_metaqueries = config.num_metaqueries
+
+ self.pad_token_id = getattr(
+ self.tokenizer, "tokenizer", self.tokenizer
+ ).pad_token_id
+ if config.num_metaqueries > 0:
+ tokenizer = getattr(self.tokenizer, "tokenizer", self.tokenizer)
+ tokenizer.add_special_tokens(
+ {
+ "additional_special_tokens": [
+ f""
+ for i in range(num_embeddings - len(tokenizer))
+ ]
+ }
+ )
+ tokenizer.add_special_tokens(
+ {
+ "additional_special_tokens": ["", ""]
+ + [f"