diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..4d39f7949baafdf858981168c335acbe6bc9186f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index b471107db32daa7020643e0a22a5ec45312510f2..4bdd60e166cfc0d1f286e0fd9147851178235253 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,156 @@ +

PrismAudio

+ + +

+ ICLR 2026 +

+ +

+ + arXiv + +   + + Online Demo + + +

+ +

+ If you find this project useful,
+ a star ⭐ on GitHub would be greatly appreciated! +

+ --- -title: PrismAudio -emoji: 📉 -colorFrom: purple -colorTo: green -sdk: gradio -sdk_version: 6.9.0 -app_file: app.py -pinned: false -license: apache-2.0 + +**PrismAudio** is the first framework to integrate Reinforcement Learning into Video-to-Audio (V2A) generation with specialized Chain-of-Thought (CoT) planning. Building upon [ThinkSound](https://arxiv.org/pdf/2506.21448)'s pioneering CoT-based V2A framework, PrismAudio further decomposes monolithic reasoning into four specialized CoT modules (Semantic, Temporal, Aesthetic, and Spatial), each paired with targeted reward functions, enabling multi-dimensional RL optimization that jointly improves reasoning across all perceptual dimensions. + + + --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +## 📰 News + +- **2026.03.22**   🔥 We have released **PrismAudio**, our next-generation video-to-audio generation model! For more details, please refer to the [`prismaudio`](https://github.com/liuhuadai/ThinkSound/tree/prismaudio) branch! +- **2026.01.26**   🎉 PrismAudio has been accepted to the **ICLR 2026 Main Conference**! We plan to release the project in February 2026. +- **2025.11.25**   🔥 [Online PrismAudio Demo](http://prismaudio-project.github.io/) is live - try it now! +- **2025.11.25**   🔥 [PrismAudio paper](https://arxiv.org/pdf/2511.18833) released on arXiv, the first multi-dimensional CoT-RL framework for Video-to-Audio Generation! +- **2025.09.19**   🎉 ThinkSound has been accepted to the **NeurIPS 2025 Main Conference**! +- **2025.09.01**   Our AudioCoT dataset is now open-sourced and available on [Hugging Face](https://huggingface.co/datasets/liuhuadai/AudioCoT)! +- **2025.07.17**   🧠 Finetuning enabled: training and finetuning code is now publicly available, along with clear usage instructions to help you customize and extend ThinkSound with your own data. +- **2025.07.15**   📦 Simplified installation and usability: dependencies on PyPI for easy cross-platform setup; Windows `.bat` scripts automate environment creation and script running. +- **2025.07.08**   🔧 Major update: model lightweighted and optimized memory and GPU usage, now supports high-throughput audio generation at scale! +- **2025.07.01**   Online demo on [Hugging Face Spaces](https://huggingface.co/spaces/FunAudioLLM/ThinkSound) and [ModelScope](https://modelscope.cn/studios/iic/ThinkSound) for interactive experience! +- **2025.07.01**   Released inference scripts and web interface. +- **2025.06**   [ThinkSound paper](https://arxiv.org/pdf/2506.21448) released on arXiv! +- **2025.06**   [Online Demo](http://thinksound-project.github.io/) is live - try it now! + +--- + +## 🚀 Features + +- **V2A SOTA**: Achieves state-of-the-art results across all four perceptual dimensions on both VGGSound and AudioCanvas benchmarks. +- **Decomposed CoT Reasoning**: Four specialized CoT modules (Semantic, Temporal, Aesthetic, Spatial) each providing focused, interpretable reasoning for its corresponding perceptual dimension. +- **Multi-dimensional RL**: Fast-GRPO enables efficient multi-dimensional reward optimization without compromising generation quality. +- **New Benchmark**: AudioCanvas — a rigorous V2A benchmark with 300 single-event classes and 501 multi-event samples covering diverse and challenging scenarios. +- **Efficient**: 518M parameters with faster inference than prior SOTAs. + +--- + +## ✨ Method Overview + +PrismAudio consists of three main components: + +1. **CoT-Aware Audio Foundation Model**: Built on a Multimodal Diffusion Transformer with flow matching, enhanced with VideoPrism for video understanding and T5-Gemma for structured CoT text encoding. +2. **Decomposed Multi-Dimensional CoT Reasoning**: Four specialized CoT modules — Semantic, Temporal, Aesthetic, and Spatial — each providing targeted reasoning for its corresponding perceptual dimension. +3. **Fast-GRPO Multi-Dimensional RL Framework**: A hybrid ODE-SDE sampling strategy that dramatically reduces training overhead while enabling multi-dimensional reward optimization across all perceptual dimensions. + + +--- + +## ⚡ Quick Start + +```bash +git clone -b prismaudio https://github.com/liuhuadai/ThinkSound.git +cd ThinkSound + +conda create -n prismaudio python=3.10 +conda activate prismaudio +chmod +x scripts/PrismAudio/setup/build_env.sh +./scripts/PrismAudio/setup/build_env.sh + +# Download pretrained weights to Directory ckpts/ +# From Hugging Face: https://huggingface.co/liuhuadai/ThinkSound +# From ModelScope: https://www.modelscope.cn/models/iic/ThinkSound +git lfs install +git clone https://huggingface.co/liuhuadai/ThinkSound ckpts +``` + +--- + +## ▶️ Run Demo + +```bash +chmod +x scripts/PrismAudio/demo.sh +./scripts/PrismAudio/demo.sh "" +``` + +**Note:** +- ``: Path to a single input video file. +- `""`: A structured CoT description of the audio to generate. + +--- + +## 🏋️ Train the Model + +See [`Training.md`](docs/PrismAudio/Training.md) + +--- + +## 📄 License + +This project is released under the Apache 2.0 License. + +> **Note:** +> The code, models, and dataset are **for research and educational purposes only**. +> **Commercial use is NOT permitted.** +> For commercial licensing, please contact the authors. + +**📦 Third-Party Components** + +- **Stable Audio Open VAE** (by Stability AI): Licensed under the [Stability AI Community License](./third_party/LICENSE_StabilityAI.md). **Commercial use and redistribution require prior permission from Stability AI.** +- 📘 **All other code and models** are released under the Apache License 2.0. + +--- + +## Acknowledgements + +Many thanks to: + +- **stable-audio-tools** (by Stability AI): For providing an easy-to-use framework for audio generation, as well as the VAE module and weights. +- **MMAudio**: For the implementation of the MM-DiT backbone in the audio domain. +- **ThinkSound**: For the foundational CoT-based V2A generation framework that PrismAudio builds upon. + +--- + +## 📖 Citation + +If you find PrismAudio useful in your research or work, please cite our paper: + +```bibtex +@misc{liu2025prismaudiodecomposedchainofthoughtsmultidimensional, + title={PrismAudio: Decomposed Chain-of-Thoughts and Multi-dimensional Rewards for Video-to-Audio Generation}, + author={Huadai Liu and Kaicheng Luo and Wen Wang and Qian Chen and Peiwen Sun and Rongjie Huang and Xiangang Li and Jieping Ye and Wei Xue}, + year={2025}, + eprint={2511.18833}, + archivePrefix={arXiv}, + primaryClass={cs.SD}, + url={https://arxiv.org/abs/2511.18833}, + } +``` + +--- + +## 📬 Contact + +✨ Feel free to [open an issue](https://github.com/liuhuadai/ThinkSound/issues) or contact us via email ([huadai.liu@connect.ust.hk](mailto:huadai.liu@connect.ust.hk)) if you have any questions or suggestions! diff --git a/ThinkSound/__init__.py b/ThinkSound/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bdaa12ea2399cacdbf5859206911173c8ce0c0c --- /dev/null +++ b/ThinkSound/__init__.py @@ -0,0 +1 @@ +from .models.factory import create_model_from_config, create_model_from_config_path \ No newline at end of file diff --git a/ThinkSound/configs/model_configs/prismaudio.json b/ThinkSound/configs/model_configs/prismaudio.json new file mode 100644 index 0000000000000000000000000000000000000000..19d24a09a76e9656cf221c7f1743fef3a7dd69dd --- /dev/null +++ b/ThinkSound/configs/model_configs/prismaudio.json @@ -0,0 +1,141 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "video_features", + "type": "cond_mlp", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "text_features", + "type": "cond_mlp", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "sync_mlp", + "config": { + "dim": 768, + "output_dim": 1024 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "cross_attention_cond_ids": ["video_features","text_features"], + "add_cond_ids": ["video_features"], + "sync_cond_ids": ["sync_features"], + "type": "dit", + "diffusion_objective": "rectified_flow", + "config": { + "io_channels": 64, + "embed_dim": 1024, + "depth": 24, + "num_heads": 16, + "cond_token_dim": 1024, + "add_token_dim": 1024, + "sync_token_dim": 1024, + "project_cond_tokens": false, + "transformer_type": "continuous_transformer", + "attn_kwargs":{ + "qk_norm": "rns" + }, + "use_gated": true, + "use_sync_gated": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.1, + "pre_encoded": true, + "timestep_sampler": "trunc_logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 100000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/videoprism/test/0Cu33yBwAPg_000060.npz", + "dataset/videoprism/test/bmKtI808DsU_000009.npz", + "dataset/videoprism/test/VC0c22cJTbM_000424.npz", + "dataset/videoprism/test/F3gsbUTdc2U_000090.npz", + "dataset/videoprism/test/WatvT8A8iug_000100.npz", + "dataset/videoprism/test/0nvBTp-q7tU_000112.npz", + "dataset/videoprism/test/3-PFuDkTM48_000080.npz", + "dataset/videoprism/test/luSAuu-BoPs_000232.npz", + "dataset/videoprism/test/__8UJxW0aOQ_000002.npz", + "dataset/videoprism/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/ThinkSound/configs/model_configs/stable_audio_2_0_vae.json b/ThinkSound/configs/model_configs/stable_audio_2_0_vae.json new file mode 100644 index 0000000000000000000000000000000000000000..95f72495502f5b0a725378a0b9f51a56bb9da910 --- /dev/null +++ b/ThinkSound/configs/model_configs/stable_audio_2_0_vae.json @@ -0,0 +1,122 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + }, + "training": { + "learning_rate": 1.5e-4, + "warmup_steps": 0, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1.5e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 3e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 64, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-4 + } + } + }, + "demo": { + "demo_every": 10000 + } + } +} \ No newline at end of file diff --git a/ThinkSound/configs/model_configs/thinksound.json b/ThinkSound/configs/model_configs/thinksound.json new file mode 100644 index 0000000000000000000000000000000000000000..1458b0d43350e928b9c68cc8619242b8ff8f87c1 --- /dev/null +++ b/ThinkSound/configs/model_configs/thinksound.json @@ -0,0 +1,147 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024, + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "kernel_size": 3 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/ThinkSound/configs/multimodal_dataset_demo.json b/ThinkSound/configs/multimodal_dataset_demo.json new file mode 100644 index 0000000000000000000000000000000000000000..b00baa073eba5bd531eb723731ad6ba39d5019e9 --- /dev/null +++ b/ThinkSound/configs/multimodal_dataset_demo.json @@ -0,0 +1,52 @@ +{ + "dataset_type": "multimodal_dir", + "video_datasets": [ + { + "id": "vggsound", + "path": "dataset/vggsound/video_latents_t5_clip_npz/train", + "split_path": "dataset/vggsound/split_txt/train_cot.txt" + } + ], + "audio_datasets": [ + { + "id": "audiostock", + "path": "dataset/Laion-Audio-630k/audiostock_latents_npz", + "split_path": "dataset/Laion-Audio-630k/split_txt/cot_audiostock_1.txt" + }, + { + "id": "freesound_no_overlap", + "path": "dataset/Laion-Audio-630k/freesound_no_overlap_latents_npz", + "split_path": "dataset/Laion-Audio-630k/split_txt/cot_freesound.txt" + }, + { + "id": "audioset_sl", + "path": "dataset/wavcaps/audioset_sl_latents_npz", + "split_path": "dataset/wavcaps/split_txt/cot_audio_sl_1.txt" + }, + { + "id": "audiocaps", + "path": "dataset/1_audiocaps/audiocaps_latents_npz", + "split_path": "dataset/1_audiocaps/split_txt/train_cot.txt" + }, + { + "id": "bbc", + "path": "dataset/Laion-Audio-630k/bbc_latents_npz", + "split_path": "dataset/Laion-Audio-630k/split_txt/cot_bbc_1.txt" + } + ], + "val_datasets": [ + { + "id": "vggsound", + "path": "dataset/vggsound/video_latents_t5_clip_npz/test", + "split_path": "dataset/vggsound/split_txt/test_cot.txt" + } + ], + "test_datasets": [ + { + "id": "vggsound", + "path": "cot_coarse" + } + ], + "random_crop": true, + "input_type": "prompt" +} \ No newline at end of file diff --git a/ThinkSound/configs/multimodal_dataset_demo_prismaudio.json b/ThinkSound/configs/multimodal_dataset_demo_prismaudio.json new file mode 100644 index 0000000000000000000000000000000000000000..8a3e2320175dd91bd176d4bded1847a46e865e62 --- /dev/null +++ b/ThinkSound/configs/multimodal_dataset_demo_prismaudio.json @@ -0,0 +1,27 @@ +{ + "dataset_type": "video_dataset", + "datasets": [ + { + "id": "vggsound", + "path": "test", + "split_path": "test/test.txt" + } + ], + "val_datasets": [ + { + "id": "vggsound", + "path": "test", + "split_path": "test/test.txt" + } + ], + "test_datasets": [ + { + "id": "vggsound", + "path": "test", + "split_path": "test/test.txt" + } + ], + "random_crop": false, + "input_type": "video", + "fps": 8 +} \ No newline at end of file diff --git a/ThinkSound/data/__init__.py b/ThinkSound/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ThinkSound/data/datamodule.py b/ThinkSound/data/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..e7503c673511941d7e322fd0b9168c187a121a14 --- /dev/null +++ b/ThinkSound/data/datamodule.py @@ -0,0 +1,331 @@ +import lightning as L +from .dataset import LatentDataset, SampleDataset, VideoDataset, AudioDataset, MultiModalDataset, LocalDatasetConfig, collation_fn +import importlib +import torch.distributed as dist +from torch.utils.data import Dataset +from torch.utils.data import DataLoader,IterableDataset +import torch +from itertools import cycle + +class AlternatingLoader(IterableDataset): + """ + 一个可迭代的数据集,它包装了两个数据加载器,并按顺序轮流从它们中产出批次。 + 它会持续进行直到两个加载器都耗尽。 + + Args: + loader1 (DataLoader): 第一个数据加载器。 + loader2 (DataLoader): 第二个数据加载器。 + loader1_name (str): 第一个加载器的名称 (例如 'video')。 + loader2_name (str): 第二个加载器的名称 (例如 'audio')。 + """ + def __init__(self, loader1, loader2, loader1_name='video', loader2_name='audio'): + super().__init__() + self.loader1 = loader1 + self.loader2 = loader2 + self.loader1_name = loader1_name + self.loader2_name = loader2_name + self.max_len = max(len(loader1), len(loader2)) + + def __iter__(self): + # 获取 DDP 信息 + try: + world_size = dist.get_world_size() + rank = dist.get_rank() + except (RuntimeError, ValueError): + # 如果不在分布式环境中,则默认为单进程 + world_size = 1 + rank = 0 + + # 创建两个无限循环迭代器 + iter1 = cycle(self.loader1) + iter2 = cycle(self.loader2) + + # 核心修改:只 yield 属于当前 rank 的数据 + # 我们将总的交替流想象成一个大列表,然后对其进行切分 + # 交替流: [v1, a1, v2, a2, v3, a3, ...] + + # 每个 for 循环迭代产生 2 个 batch (1 个 video, 1 个 audio) + # 总共会产生 2 * self.max_len 个 batch + + # for 循环负责驱动迭代 + for i in range(self.max_len): + # 获取下一个 video batch + v_batch = next(iter1) + # 获取下一个 audio batch + a_batch = next(iter2) + + # 这是一个交替对,我们根据索引 i 来决定哪个进程处理它 + if i % world_size == rank: + # 只有当轮次索引 i 属于当前 rank 时,才 yield 数据 + yield v_batch + yield a_batch + + def __len__(self): + # 在 DDP 环境下,__len__ 应该返回单个进程处理的 batch 数量 + # 以便 Lightning 正确显示进度条 + + try: + world_size = dist.get_world_size() + except (RuntimeError, ValueError): + world_size = 1 + + # 每个进程大致处理 1/world_size 的数据对 + # 每个数据对包含 2 个 batch + num_pairs_per_process = self.max_len // world_size + + # 如果总数不能整除,最后一个 rank 会多处理一些 + # 为简化起见,我们通常可以用 ceil 来计算 + # (self.max_len + world_size - 1) // world_size 是一种高效的 ceil 写法 + num_pairs_per_process = (self.max_len + world_size - 1) // world_size + + return 2 * num_pairs_per_process +def get_configs(audio_configs): + configs = [] + for config in audio_configs: + data_dir_path = config.get("path", None) + audio_dir_path = config.get("audio_dir", None) + split_path = config.get("split_path", None) + assert data_dir_path is not None, "Path must be set for local audio directory configuration" + + custom_metadata_fn = None + custom_metadata_module_path = config.get("custom_metadata_module", None) + + if custom_metadata_module_path: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + custom_metadata_fn = metadata_module.get_custom_metadata + + configs.append( + LocalDatasetConfig( + id=config["id"], + path=data_dir_path, + split_path=split_path, + custom_metadata_fn=custom_metadata_fn, + audio_dir=audio_dir_path + ) + ) + return configs + +class DataModule(L.LightningDataModule): + def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4): + super().__init__() + dataset_type = dataset_config.get("dataset_type", None) + repeat_num = dataset_config.get("repeat_num", 1) + self.batch_size = batch_size + self.num_workers = num_workers + self.test_batch_size = test_batch_size + self.repeat_num = repeat_num + assert dataset_type is not None, "Dataset type must be specified in dataset config" + + if audio_channels == 1: + force_channels = "mono" + elif audio_channels == 2: + force_channels = "stereo" + else: + force_channels = "foa" + val_dir_configs = dataset_config.get("val_datasets", None) + test_dir_configs = dataset_config.get("test_datasets", None) + configs = [] + val_configs = [] + test_configs = [] + if dataset_type == "audio_dir": + audio_dir_configs = dataset_config.get("datasets", None) + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + configs = get_configs(audio_dir_configs) + val_configs = get_configs(val_dir_configs) + test_configs = get_configs(test_dir_configs) + elif dataset_type == "latent_dir" or dataset_type == "video_dataset" or dataset_type == "audio_dataset": + audio_dir_configs = dataset_config.get("datasets", None) + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + for i, dataset in enumerate((audio_dir_configs, val_dir_configs, test_dir_configs)): + for config in dataset: + data_dir_path = config.get("path", None) + audio_dir_path = config.get("audio_dir", None) + split_path = config.get("split_path", None) + assert data_dir_path is not None, "Path must be set for local audio directory configuration" + + content = LocalDatasetConfig( + id=config["id"], + path=data_dir_path, + split_path=split_path, + audio_dir=audio_dir_path + ) + if i == 0: + configs.append(content) + elif i == 1: + val_configs.append(content) + else: + test_configs.append(content) + elif dataset_type in ["multimodal_dir", "alternating_multimodal"]: + print('##########################') + print(f'repeat num is: {self.repeat_num}') + self.audio_configs = [] + self.video_configs = [] + audio_dir_configs = dataset_config.get("audio_datasets", None) + video_dir_configs = dataset_config.get("video_datasets", None) + assert audio_dir_configs is not None and video_dir_configs is not None, "Directory configuration must be specified in video_datasets and audio_datasets" + for i, dataset in enumerate((audio_dir_configs, video_dir_configs, val_dir_configs, test_dir_configs)): + for config in dataset: + data_dir_path = config.get("path", None) + audio_dir_path = config.get("audio_dir", None) + split_path = config.get("split_path", None) + assert data_dir_path is not None, "Path must be set for local audio directory configuration" + + content = LocalDatasetConfig( + id=config["id"], + path=data_dir_path, + split_path=split_path, + audio_dir=audio_dir_path + ) + if i == 0: + self.audio_configs.append(content) + elif i == 1: + self.video_configs.append(content) + elif i == 2: + val_configs.append(content) + else: + test_configs.append(content) + self.dataset_type = dataset_type + self.configs = configs + self.val_configs = val_configs + self.test_configs = test_configs + self.sample_rate = sample_rate + self.sample_size = sample_size + self.random_crop = dataset_config.get("random_crop", True) + self.input_type = dataset_config.get("input_type", "video") + self.fps = dataset_config.get("fps", 4) + self.force_channels = force_channels + + + def setup(self, stage: str): + if self.dataset_type == 'audio_dir': + dataset_class = SampleDataset + elif self.dataset_type == 'latent_dir': + dataset_class = LatentDataset + elif self.dataset_type == 'video_dataset': + dataset_class = VideoDataset + elif self.dataset_type == 'audio_dataset': + dataset_class = AudioDataset + elif self.dataset_type in ["multimodal_dir", "alternating_multimodal"]: + dataset_class = VideoDataset + + def create_dataset(configs, random_crop): + return dataset_class( + configs, + sample_rate=self.sample_rate, + sample_size=self.sample_size, + random_crop=random_crop, + input_type=self.input_type, + fps=self.input_type, + force_channels=self.force_channels + ) + + if stage == 'fit': + if self.dataset_type not in ["multimodal_dir", "alternating_multimodal"]: + self.train_set = create_dataset(self.configs, random_crop=self.random_crop) + elif self.dataset_type == "multimodal_dir": + self.video_set = VideoDataset( + self.video_configs, + sample_rate=self.sample_rate, + sample_size=self.sample_size, + random_crop=self.random_crop, + input_type=self.input_type, + fps=self.input_type, + force_channels=self.force_channels + ) + self.audio_set = AudioDataset( + self.audio_configs, + sample_rate=self.sample_rate, + sample_size=self.sample_size, + random_crop=self.random_crop, + input_type=self.input_type, + fps=self.input_type, + force_channels=self.force_channels + ) + self.train_set = MultiModalDataset([self.video_set]*self.repeat_num, [self.audio_set]) + elif self.dataset_type == "alternating_multimodal": + self.video_set = VideoDataset( + self.video_configs, + sample_rate=self.sample_rate, + sample_size=self.sample_size, + random_crop=self.random_crop, + input_type=self.input_type, + fps=self.input_type, + force_channels=self.force_channels + ) + self.audio_set = AudioDataset( + self.audio_configs, + sample_rate=self.sample_rate, + sample_size=self.sample_size, + random_crop=self.random_crop, + input_type=self.input_type, + fps=self.input_type, + force_channels=self.force_channels + ) + self.val_set = create_dataset(self.val_configs, random_crop=False) + elif stage == 'validate': + self.val_set = create_dataset(self.val_configs, random_crop=False) + elif stage == 'predict': + self.test_set = create_dataset(self.test_configs, random_crop=False) + + + + def train_dataloader(self): + if self.dataset_type == "alternating_multimodal": + # 视频 DataLoader + video_loader = DataLoader( + self.video_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + drop_last=True, + collate_fn=collation_fn + ) + + # 音频 DataLoader + audio_loader = DataLoader( + self.audio_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + drop_last=True, + collate_fn=collation_fn + ) + alternating_loader = AlternatingLoader( + video_loader, + audio_loader, + loader1_name='video', + loader2_name='audio' + ) + return DataLoader(alternating_loader, batch_size=None, num_workers=0) + else: + # 如果不是 alternating_multimodal,保持现有逻辑(仅用于兼容性) + return DataLoader( + self.train_set, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + persistent_workers=True, + pin_memory=True, + drop_last=True, + collate_fn=collation_fn + ) + + + def val_dataloader(self): + return DataLoader(self.val_set, self.batch_size, shuffle=False, + num_workers=self.num_workers, persistent_workers=False, pin_memory=False, drop_last=False, collate_fn=collation_fn) + + def predict_dataloader(self): + return DataLoader(self.test_set, batch_size=self.test_batch_size, shuffle=False, + num_workers=self.num_workers, persistent_workers=False, pin_memory=False, drop_last=False, collate_fn=collation_fn) + + # def predict_dataloader(self): + # return DataLoader(self.mnist_predict, batch_size=self.batch_size) + + # def teardown(self, stage: str): + # # Used to clean-up when the run is finished + # ... \ No newline at end of file diff --git a/ThinkSound/data/dataset.py b/ThinkSound/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac3546e73da58cfe7d05dd46cbdf997d168f00a --- /dev/null +++ b/ThinkSound/data/dataset.py @@ -0,0 +1,1319 @@ +import importlib +import numpy as np +import io +import os +import posixpath +import random +import re +import subprocess +import time +import torch +import torchaudio +import webdataset as wds +import pandas as pd +from aeiou.core import is_silence +from os import path +from pathlib import Path +from pedalboard.io import AudioFile +from torchaudio import transforms as T +from typing import Optional, Callable, List +import bisect + +from .utils import FOA, Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, PadCrop_Video_Normalized_T, PadCrop_Video_Hiera_Normalized_T, PadCrop_Video_Image_Normalized_T, PadCrop_DualVideo_Normalized_T + +AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus") + +# fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py + +def fast_scandir( + dir:str, # top-level directory at which to begin scanning + ext:list, # list of allowed file extensions, + #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB + ): + "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243" + subfolders, files = [], [] + ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed + try: # hope to avoid 'permission denied' by this try + for f in os.scandir(dir): + try: # 'hope to avoid too many levels of symbolic links' error + if f.is_dir(): + subfolders.append(f.path) + elif f.is_file(): + file_ext = os.path.splitext(f.name)[1].lower() + is_hidden = os.path.basename(f.path).startswith(".") + + if file_ext in ext and not is_hidden: + files.append(f.path) + except: + pass + except: + pass + + for dir in list(subfolders): + sf, f = fast_scandir(dir, ext) + subfolders.extend(sf) + files.extend(f) + return subfolders, files + +def keyword_scandir( + dir: str, # top-level directory at which to begin scanning + ext: list, # list of allowed file extensions + keywords: list, # list of keywords to search for in the file name +): + "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243" + subfolders, files = [], [] + # make keywords case insensitive + keywords = [keyword.lower() for keyword in keywords] + # add starting period to extensions if needed + ext = ['.'+x if x[0] != '.' else x for x in ext] + banned_words = ["paxheader", "__macosx"] + try: # hope to avoid 'permission denied' by this try + for f in os.scandir(dir): + try: # 'hope to avoid too many levels of symbolic links' error + if f.is_dir(): + subfolders.append(f.path) + elif f.is_file(): + is_hidden = f.name.split("/")[-1][0] == '.' + has_ext = os.path.splitext(f.name)[1].lower() in ext + name_lower = f.name.lower() + has_keyword = any( + [keyword in name_lower for keyword in keywords]) + has_banned = any( + [banned_word in name_lower for banned_word in banned_words]) + if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"): + files.append(f.path) + except: + pass + except: + pass + + for dir in list(subfolders): + sf, f = keyword_scandir(dir, ext, keywords) + subfolders.extend(sf) + files.extend(f) + return subfolders, files + +def get_audio_filenames( + paths: list, # directories in which to search + keywords=None, + exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus'] +): + "recursively get a list of audio filenames" + filenames = [] + if type(paths) is str: + paths = [paths] + for path in paths: # get a list of relevant filenames + if keywords is not None: + subfolders, files = keyword_scandir(path, exts, keywords) + else: + subfolders, files = fast_scandir(path, exts) + filenames.extend(files) + return filenames + + + + + +class LocalDatasetConfig: + def __init__( + self, + id: str, + path: str, + split_path: str, + audio_dir: str = None, + custom_metadata_fn: Optional[Callable[[str], str]] = None + ): + self.id = id + self.path = path + self.split_path = split_path + self.audio_dir = audio_dir + self.custom_metadata_fn = custom_metadata_fn + +class SampleDataset(torch.utils.data.Dataset): + def __init__( + self, + configs, + sample_size=65536, + sample_rate=48000, + keywords=None, + random_crop=True, + input_type="prompt", + fps=4, + force_channels="stereo" + ): + super().__init__() + self.filenames = [] + + self.augs = torch.nn.Sequential( + PhaseFlipper(), + ) + + self.root_paths = [] + if input_type == 'video': + self.pad_crop = PadCrop_Video_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) + elif input_type == 'video_hiera': + self.pad_crop = PadCrop_Video_Hiera_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) + elif input_type == 'video_image': + self.pad_crop = PadCrop_Video_Image_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) + elif input_type == 'dual_video': + self.pad_crop = PadCrop_DualVideo_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) + else: + self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop) + + self.force_channels = force_channels + print('######################') + print(f'input channels is: {force_channels}') + print('######################') + self.encoding = torch.nn.Sequential( + FOA() if self.force_channels == "foa" else torch.nn.Identity(), + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + ) + self.input_type = input_type + self.sr = sample_rate + self.custom_metadata_fns = {} + + for config in configs: + self.root_paths.append(config.path) + def add_prefix(s): + return str(os.path.join(config.path,f'{s.strip()}')) + with open(config.split_path,'r') as f: + item_names = f.readlines() + filenames = list(map(add_prefix, item_names)) + self.filenames.extend(filenames) + # self.filenames.extend(get_audio_filenames(config.path, keywords)) + if config.custom_metadata_fn is not None: + self.custom_metadata_fns[config.path] = config.custom_metadata_fn + + print(f'Found {len(self.filenames)} files') + + def load_file(self, filename): + ext = filename.split(".")[-1] + if ext == "mp3": + with AudioFile(filename) as f: + audio = f.read(f.frames) + audio = torch.from_numpy(audio) + in_sr = f.samplerate + else: + audio, in_sr = torchaudio.load(filename, format=ext) + + if in_sr != self.sr: + try: + resample_tf = T.Resample(in_sr, self.sr) + audio = resample_tf(audio) + except: + print(f'{filename} resample errors') + + assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' + return audio + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + audio_filename = self.filenames[idx] + assert os.path.exists(audio_filename), f'{audio_filename}: file not exists' + try: + start_time = time.time() + audio = self.load_file(audio_filename) + info = {} + info["path"] = audio_filename + + for root_path in self.root_paths: + if root_path in audio_filename: + info["relpath"] = path.relpath(audio_filename, root_path) + + + for custom_md_path in self.custom_metadata_fns.keys(): + if custom_md_path in audio_filename: + custom_metadata_fn = self.custom_metadata_fns[custom_md_path] + custom_metadata = custom_metadata_fn(info, audio) + info.update(custom_metadata) + + if "__reject__" in info and info["__reject__"]: + return self[random.randrange(len(self))] + if self.input_type == 'video': + audio, video, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['video']) + info['video'] = video + elif self.input_type == 'dual_video': + audio, video_360, video_fov, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['video'], info['video_fov']) + info['video_360'] = video_360 + info['video_fov'] = video_fov + else: + audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio) + assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' + # Run augmentations on this sample (including random crop) + if self.augs is not None: + audio = self.augs(audio) + + audio = audio.clamp(-1, 1) + + # Encode the file to assist in prediction + if self.encoding is not None: + audio = self.encoding(audio) + + + + info["timestamps"] = (t_start, t_end) + info["seconds_start"] = seconds_start + info["seconds_total"] = seconds_total + info["padding_mask"] = padding_mask + + end_time = time.time() + info["load_time"] = end_time - start_time + + + return (audio, info) + except Exception as e: + print(f'Couldn\'t load file {audio_filename}: {e}') + return self[random.randrange(len(self))] + +class LatentDataset(torch.utils.data.Dataset): + def __init__( + self, + configs, + sample_size=65536, + sample_rate=48000, + keywords=None, + random_crop=True, + input_type="prompt", + fps=4, + force_channels="stereo" + ): + super().__init__() + self.filenames = [] + + self.augs = torch.nn.Sequential( + PhaseFlipper(), + ) + + self.root_paths = [] + + self.force_channels = force_channels + print('######################') + print(f'input channels is: {force_channels}') + print('######################') + self.encoding = torch.nn.Sequential( + FOA() if self.force_channels == "foa" else torch.nn.Identity(), + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + ) + self.input_type = input_type + self.sr = sample_rate + for config in configs: + self.root_paths.append(config.path) + def add_prefix(s): + return str(os.path.join(config.path,f'{s.strip()}')) + with open(config.split_path,'r') as f: + item_names = f.readlines() + filenames = list(map(add_prefix, item_names)) + self.filenames.extend(filenames) + # self.filenames.extend(get_audio_filenames(config.path, keywords)) + + + print(f'Found {len(self.filenames)} files') + + def load_file(self, filename, info): + # try: + npz_file = filename.replace('.pth','.npz') + if os.path.exists(filename) and '.npz' not in filename: + data = torch.load(filename, weights_only=False) + elif os.path.exists(npz_file): + # print(filename) + npz_data = np.load(npz_file,allow_pickle=True) + data = {key: npz_data[key] for key in npz_data.files} + # print("data.keys()",data.keys()) + for key in data.keys(): + if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number): + data[key] = torch.from_numpy(data[key]) + else: + raise ValueError(f'error load file with file not exists: {filename}') + info.update(data) + audio = data['latent'] + # except: + # print(f'error load file: {filename}') + return audio + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + audio_filename = self.filenames[idx] + assert os.path.exists(audio_filename) or audio_filename.replace('.pth','.npz'), f'{audio_filename}: file not exists' + # try: + start_time = time.time() + info = {} + audio = self.load_file(audio_filename, info) + info["path"] = audio_filename + info['id'] = Path(audio_filename).stem + for root_path in self.root_paths: + if root_path in audio_filename: + info["relpath"] = path.relpath(audio_filename, root_path) + + return (audio, info) + +class AudioDataset(torch.utils.data.Dataset): + def __init__( + self, + configs, + sample_size=65536, + sample_rate=48000, + keywords=None, + random_crop=True, + input_type="prompt", + fps=4, + force_channels="stereo" + ): + super().__init__() + self.filenames = [] + + self.augs = torch.nn.Sequential( + PhaseFlipper(), + ) + + self.root_paths = [] + + self.force_channels = force_channels + print('######################') + print(f'input channels is: {force_channels}') + print('######################') + self.encoding = torch.nn.Sequential( + FOA() if self.force_channels == "foa" else torch.nn.Identity(), + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + ) + self.fake_clip_features = torch.zeros(72, 1024) + self.fake_sync_features = torch.zeros(216, 768) + self.video_exist = torch.tensor(0, dtype=torch.bool) + self.input_type = input_type + self.sr = sample_rate + for config in configs: + self.root_paths.append(config.path) + def add_prefix(s): + return str(os.path.join(config.path,f'{s.strip()}')) + with open(config.split_path,'r') as f: + item_names = f.readlines() + filenames = list(map(add_prefix, item_names)) + self.filenames.extend(filenames) + # self.filenames.extend(get_audio_filenames(config.path, keywords)) + + + print(f'Found {len(self.filenames)} files') + + def load_file(self, filename, info): + # try: + npz_file = filename.replace('.pth','.npz') + if os.path.exists(filename) and '.npz' not in filename: + data = torch.load(filename, weights_only=False) + elif os.path.exists(npz_file): + # print(filename) + npz_data = np.load(npz_file,allow_pickle=True) + data = dict(npz_data) + # print("data.keys()",data.keys()) + for key in data.keys(): + if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number): + data[key] = torch.from_numpy(data[key]) + else: + raise ValueError(f'error load file: {filename}') + info.update(data) + audio = data['latent'] + if 'source_latent' not in data.keys(): + info['source_latent']= audio + info['video_features'] = self.fake_clip_features + info['sync_features'] = self.fake_sync_features + info['video_exist'] = self.video_exist + # except: + # print(f'error load file: {filename}') + return audio + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + audio_filename = self.filenames[idx] + assert os.path.exists(audio_filename) or audio_filename.replace('.pth','.npz'), f'{audio_filename}: file not exists' + # try: + start_time = time.time() + info = {} + audio = self.load_file(audio_filename, info) + info["path"] = audio_filename + assert audio.shape == (64,194), f'{audio.shape} input error, id: {audio_filename}' + info['id'] = Path(audio_filename).stem + for root_path in self.root_paths: + if root_path in audio_filename: + info["relpath"] = path.relpath(audio_filename, root_path) + + return (audio, info) + + + + +class VideoDataset(torch.utils.data.Dataset): + def __init__( + self, + configs, + sample_size=65536, + sample_rate=48000, + keywords=None, + random_crop=True, + input_type="prompt", + fps=4, + force_channels="stereo" + ): + super().__init__() + self.filenames = [] + + self.augs = torch.nn.Sequential( + PhaseFlipper(), + ) + + self.root_paths = [] + self.sample_size = sample_size + self.force_channels = force_channels + print('######################') + print(f'input channels is: {force_channels}') + print('######################') + self.encoding = torch.nn.Sequential( + FOA() if self.force_channels == "foa" else torch.nn.Identity(), + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + ) + self.input_type = input_type + self.sr = sample_rate + self.video_exist = torch.tensor(1, dtype=torch.bool) + self.audio_files = [] + for config in configs: + self.root_paths.append(config.path) + def add_prefix(s): + return str(os.path.join(config.path,f'{s.strip()}')) + with open(config.split_path,'r') as f: + item_names = f.readlines() + filenames = list(map(add_prefix, item_names)) + self.filenames.extend(filenames) + if config.audio_dir is not None: + def add_prefix(s): + return str(os.path.join(config.audio_dir,f'{Path(s).stem}.wav')) + filenames = list(map(add_prefix, item_names)) + self.audio_files.extend(filenames) + # self.filenames.extend(get_audio_filenames(config.path, keywords)) + + print(f'Found {len(self.filenames)} files') + + def load_audio(self, filename): + ext = filename.split(".")[-1] + if ext == "mp3": + with AudioFile(filename) as f: + audio = f.read(f.frames) + audio = torch.from_numpy(audio) + in_sr = f.samplerate + else: + audio, in_sr = torchaudio.load(filename, format=ext) + + if in_sr != self.sr: + try: + resample_tf = T.Resample(in_sr, self.sr) + audio = resample_tf(audio) + except: + print(f'{filename} resample errors') + + assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' + return audio + + + + + def check_audio_file(self, audio_path): + # 首先检查原始路径是否存在 + if os.path.exists(audio_path): + return audio_path + # 如果不存在,尝试替换为.flac扩展名 + name, ext = os.path.splitext(audio_path) + flac_path = f"{name}.flac" + + if os.path.exists(flac_path): + return flac_path + raise FileNotFoundError(f"音频文件不存在: {audio_path} 和 {flac_path} 都不存在") + + def load_file(self, filename, info): + try: + npz_file = filename.replace('.pth','.npz') + if os.path.exists(filename) and '.npz' not in filename: + data = torch.load(filename, weights_only=False) + elif os.path.exists(npz_file): + # print(filename) + npz_data = np.load(npz_file,allow_pickle=True) + data = {key: npz_data[key] for key in npz_data.files} + # print("data.keys()",data.keys()) + for key in data.keys(): + if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number): + data[key] = torch.from_numpy(data[key]) + else: + raise ValueError(f'error load file: {filename}') + info.update(data) + audio = data['latent'] + info['video_exist'] = self.video_exist + except Exception as e: + print(f'error load file: {filename} with error: {e}') + return None + return audio + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + loop = True + while loop: + filename = self.filenames[idx] + if len(self.audio_files) > 0: + audio_path = self.audio_files[idx] + audio_path = self.check_audio_file(audio_path) + waveform = self.load_audio(audio_path) + else: + waveform = None + assert os.path.exists(filename) or filename.replace('.pth','.npz'), f'{filename}: file not exists' + # try: + start_time = time.time() + info = {} + audio = self.load_file(filename, info) + if audio is not None: + loop = False + else: + idx = (idx+1) % len(self.filenames) + + if waveform is not None: + padded_waveform = torch.zeros(waveform.shape[0], self.sample_size, dtype=waveform.dtype) + copy_length = min(waveform.shape[1], self.sample_size) + padded_waveform[:, :copy_length] = waveform[:, :copy_length] + + waveform = padded_waveform + waveform = waveform.clamp(-1, 1) + # Encode the file to assist in prediction + if self.encoding is not None: + waveform = self.encoding(waveform) + info['waveform'] = waveform + info["path"] = filename + info['id'] = Path(filename).stem + for root_path in self.root_paths: + if root_path in filename: + info["relpath"] = path.relpath(filename, root_path) + + + return (audio, info) + +# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset +class MultiModalDataset(torch.utils.data.Dataset): + datasets: list[torch.utils.data.Dataset] + cumulative_sizes: list[int] + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, video_datasets: list[torch.utils.data.Dataset], audio_datasets: list[torch.utils.data.Dataset]): + super().__init__() + self.video_datasets = list(video_datasets) + self.audio_datasets = list(audio_datasets) + self.datasets = self.video_datasets + self.audio_datasets + + self.cumulative_sizes = self.cumsum(self.datasets) + print(f'Found {self.cumulative_sizes[-1]} files') + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]: + return self.video_datasets[0].compute_latent_stats() + + +# class MultiModalDataset(torch.utils.data.Dataset): +# def __init__( +# self, +# configs, +# sample_size=65536, +# sample_rate=48000, +# keywords=None, +# random_crop=True, +# input_type="prompt", +# fps=4, +# force_channels="stereo" +# ): +# super().__init__() +# self.filenames = [] +# self.captions = [] +# self.caption_t5s = [] +# self.ids = [] +# self.augs = torch.nn.Sequential( +# PhaseFlipper(), +# ) + +# self.root_paths = [] +# if input_type == 'video': +# self.pad_crop = PadCrop_Video_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) +# elif input_type == 'video_hiera': +# self.pad_crop = PadCrop_Video_Hiera_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) +# elif input_type == 'video_image': +# self.pad_crop = PadCrop_Video_Image_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) +# elif input_type == 'dual_video': +# self.pad_crop = PadCrop_DualVideo_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) +# else: +# self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop) + +# self.force_channels = force_channels +# print('######################') +# print(f'input channels is: {force_channels}') +# print('######################') +# self.encoding = torch.nn.Sequential( +# FOA() if self.force_channels == "foa" else torch.nn.Identity(), +# Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), +# Mono() if self.force_channels == "mono" else torch.nn.Identity(), +# ) +# self.input_type = input_type +# self.sr = sample_rate +# self.custom_metadata_fns = {} + +# for config in configs: +# print(config.split_path) +# self.root_paths.append(config.path) +# def add_prefix(s): +# return str(os.path.join(config.path,f'{s.strip()}')) +# with open(config.split_path,'r') as f: +# item_names = f.readlines() +# csv_path = config.split_path.replace('.txt','.csv') +# df = pd.read_csv(csv_path) +# # 检查是否存在 'caption_t5' 列,如果不存在则创建并复制 'caption' 的值 +# if 'caption_t5' not in df.columns: +# df['caption_t5'] = df['caption'] + +# captions = df['caption'].tolist() +# caption_t5s = df['caption_t5'].tolist() +# filenames = list(map(add_prefix, item_names)) +# assert len(captions) == len(caption_t5s) and len(captions) == len(filenames), f'{config.path} has wrong filename and caption' +# if config.id == 'vggsound': +# self.filenames.extend(filenames*5) +# self.captions.extend(captions*5) +# self.caption_t5s.extend(caption_t5s*5) +# self.ids.extend(df['id'].tolist()*5) +# else: +# self.filenames.extend(filenames) +# self.captions.extend(captions) +# self.caption_t5s.extend(caption_t5s) +# self.ids.extend(df['id'].tolist()) +# # self.filenames.extend(get_audio_filenames(config.path, keywords)) +# if config.custom_metadata_fn is not None: +# self.custom_metadata_fns[config.path] = config.custom_metadata_fn + +# assert len(self.ids) == len(self.captions) and len(self.caption_t5s) == len(self.filenames), 'length need to be same' +# print(f'Found {len(self.filenames)} files') + + +# def load_file(self, filename): +# ext = filename.split(".")[-1] +# if ext == "mp3": +# with AudioFile(filename) as f: +# audio = f.read(f.frames) +# audio = torch.from_numpy(audio) +# in_sr = f.samplerate +# else: +# audio, in_sr = torchaudio.load(filename, format=ext) + +# if in_sr != self.sr: +# try: +# resample_tf = T.Resample(in_sr, self.sr) +# audio = resample_tf(audio) +# except: +# print(f'{filename} resample errors') + +# assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' +# return audio + +# def __len__(self): +# return len(self.filenames) + +# def __getitem__(self, idx): +# audio_filename = self.filenames[idx] +# id = self.ids[idx] +# assert str(id) == str(Path(audio_filename).stem), f'audio_file: {audio_filename} needs to be same as {id} ' +# assert os.path.exists(audio_filename), f'{audio_filename}: file not exists' +# try: +# start_time = time.time() +# audio = self.load_file(audio_filename) +# caption = self.captions[idx] +# caption_t5 = self.caption_t5s[idx] +# if pd.isna(caption_t5) or caption_t5 == '': +# caption_t5 = caption +# info = {} +# info["path"] = audio_filename +# info['caption'] = caption +# info['caption_t5'] = caption_t5 + +# for root_path in self.root_paths: +# if root_path in audio_filename: +# info["relpath"] = path.relpath(audio_filename, root_path) + + +# for custom_md_path in self.custom_metadata_fns.keys(): +# if custom_md_path in audio_filename: +# custom_metadata_fn = self.custom_metadata_fns[custom_md_path] +# custom_metadata = custom_metadata_fn(info, audio) +# info.update(custom_metadata) + +# if "__reject__" in info and info["__reject__"]: +# return self[random.randrange(len(self))] +# # if self.input_type == 'video': +# # audio, video, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['clip_features']) +# # info['clip_features'] = video +# # else: +# if info['flag']: +# audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio,randomize=False) +# else: +# audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio,randomize=True) +# assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' +# # Run augmentations on this sample (including random crop) +# if self.augs is not None: +# audio = self.augs(audio) + +# audio = audio.clamp(-1, 1) + +# # Encode the file to assist in prediction +# if self.encoding is not None: +# audio = self.encoding(audio) + + + +# info["timestamps"] = (t_start, t_end) +# info["seconds_start"] = seconds_start +# info["seconds_total"] = seconds_total +# info["padding_mask"] = padding_mask + +# end_time = time.time() +# info["load_time"] = end_time - start_time + + +# return (audio, info) +# except Exception as e: +# print(f'Couldn\'t load file {audio_filename}: {e}') +# return self[random.randrange(len(self))] + +def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if wds.tariterators.trace: + print( + prefix, + suffix, + current_sample.keys() if isinstance(current_sample, dict) else None, + ) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + if current_sample is None or prefix != current_sample["__key__"]: + if wds.tariterators.valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffix in current_sample: + print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}") + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if wds.tariterators.valid_sample(current_sample): + yield current_sample + +wds.tariterators.group_by_keys = group_by_keys + +# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py + +def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None): + """ + Returns a list of full S3 paths to files in a given S3 bucket and directory path. + """ + # Ensure dataset_path ends with a trailing slash + if dataset_path != '' and not dataset_path.endswith('/'): + dataset_path += '/' + # Use posixpath to construct the S3 URL path + bucket_path = posixpath.join(s3_url_prefix or '', dataset_path) + # Construct the `aws s3 ls` command + cmd = ['aws', 's3', 'ls', bucket_path] + + if profile is not None: + cmd.extend(['--profile', profile]) + + if recursive: + # Add the --recursive flag if requested + cmd.append('--recursive') + + # Run the `aws s3 ls` command and capture the output + run_ls = subprocess.run(cmd, capture_output=True, check=True) + # Split the output into lines and strip whitespace from each line + contents = run_ls.stdout.decode('utf-8').split('\n') + contents = [x.strip() for x in contents if x] + # Remove the timestamp from lines that begin with a timestamp + contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x) + if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents] + # Construct a full S3 path for each file in the contents list + contents = [posixpath.join(s3_url_prefix or '', x) + for x in contents if not x.endswith('/')] + # Apply the filter, if specified + if filter: + contents = [x for x in contents if filter in x] + # Remove redundant directory names in the S3 URL + if recursive: + # Get the main directory name from the S3 URL + main_dir = "/".join(bucket_path.split('/')[3:]) + # Remove the redundant directory names from each file path + contents = [x.replace(f'{main_dir}', '').replace( + '//', '/') for x in contents] + # Print debugging information, if requested + if debug: + print("contents = \n", contents) + # Return the list of S3 paths to files + return contents + + +def get_all_s3_urls( + names=[], # list of all valid [LAION AudioDataset] dataset names + # list of subsets you want from those datasets, e.g. ['train','valid'] + subsets=[''], + s3_url_prefix=None, # prefix for those dataset names + recursive=True, # recursively list all tar files in all subdirs + filter_str='tar', # only grab files with this substring + # print debugging info -- note: info displayed likely to change at dev's whims + debug=False, + profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'} +): + "get urls of shards (tar files) for multiple datasets in one s3 bucket" + urls = [] + for name in names: + # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list + if s3_url_prefix is None: + contents_str = name + else: + # Construct the S3 path using the s3_url_prefix and the current name value + contents_str = posixpath.join(s3_url_prefix, name) + if debug: + print(f"get_all_s3_urls: {contents_str}:") + for subset in subsets: + subset_str = posixpath.join(contents_str, subset) + if debug: + print(f"subset_str = {subset_str}") + # Get the list of tar files in the current subset directory + profile = profiles.get(name, None) + tar_list = get_s3_contents( + subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile) + for tar in tar_list: + # Escape spaces and parentheses in the tar filename for use in the shell command + tar = tar.replace(" ", "\ ").replace( + "(", "\(").replace(")", "\)") + # Construct the S3 path to the current tar file + s3_path = posixpath.join(name, subset, tar) + " -" + # Construct the AWS CLI command to download the current tar file + if s3_url_prefix is None: + request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}" + else: + request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}" + if profiles.get(name): + request_str += f" --profile {profiles.get(name)}" + if debug: + print("request_str = ", request_str) + # Add the constructed URL to the list of URLs + urls.append(request_str) + return urls + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + print(f"Handling webdataset error ({repr(exn)}). Ignoring.") + return True + + +def is_valid_sample(sample): + has_json = "json" in sample + has_audio = "audio" in sample + is_silent = is_silence(sample["audio"]) + is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"] + + return has_json and has_audio and not is_silent and not is_rejected + +class S3DatasetConfig: + def __init__( + self, + id: str, + s3_path: str, + custom_metadata_fn: Optional[Callable[[str], str]] = None, + profile: Optional[str] = None, + ): + self.id = id + self.path = s3_path + self.custom_metadata_fn = custom_metadata_fn + self.profile = profile + self.urls = [] + + def load_data_urls(self): + self.urls = get_all_s3_urls( + names=[self.path], + s3_url_prefix=None, + recursive=True, + profiles={self.path: self.profile} if self.profile else {}, + ) + + return self.urls + +class LocalWebDatasetConfig: + def __init__( + self, + id: str, + path: str, + custom_metadata_fn: Optional[Callable[[str], str]] = None, + profile: Optional[str] = None, + ): + self.id = id + self.path = path + self.custom_metadata_fn = custom_metadata_fn + self.urls = [] + + def load_data_urls(self): + + self.urls = fast_scandir(self.path, ["tar"])[1] + + return self.urls + +def audio_decoder(key, value): + # Get file extension from key + ext = key.split(".")[-1] + + if ext in AUDIO_KEYS: + return torchaudio.load(io.BytesIO(value)) + else: + return None + +def collation_fn(samples): + batched = list(zip(*samples)) + result = [] + for b in batched: + if isinstance(b[0], (int, float)): + b = np.array(b) + elif isinstance(b[0], torch.Tensor): + b = torch.stack(b) + elif isinstance(b[0], np.ndarray): + b = np.array(b) + else: + b = b + result.append(b) + return result + +class WebDatasetDataLoader(): + def __init__( + self, + datasets: List[S3DatasetConfig], + batch_size, + sample_size, + sample_rate=48000, + num_workers=8, + epoch_steps=1000, + random_crop=True, + force_channels="stereo", + augment_phase=True, + **data_loader_kwargs + ): + + self.datasets = datasets + + self.sample_size = sample_size + self.sample_rate = sample_rate + self.random_crop = random_crop + self.force_channels = force_channels + self.augment_phase = augment_phase + + urls = [dataset.load_data_urls() for dataset in datasets] + + # Flatten the list of lists of URLs + urls = [url for dataset_urls in urls for url in dataset_urls] + + # Shuffle the urls + random.shuffle(urls) + + self.dataset = wds.DataPipeline( + wds.ResampledShards(urls), + wds.tarfile_to_samples(handler=log_and_continue), + wds.decode(audio_decoder, handler=log_and_continue), + wds.map(self.wds_preprocess, handler=log_and_continue), + wds.select(is_valid_sample), + wds.to_tuple("audio", "json", handler=log_and_continue), + #wds.shuffle(bufsize=1000, initial=5000), + wds.batched(batch_size, partial=False, collation_fn=collation_fn), + ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps) + + self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs) + + def wds_preprocess(self, sample): + + found_key, rewrite_key = '', '' + for k, v in sample.items(): # print the all entries in dict + for akey in AUDIO_KEYS: + if k.endswith(akey): + # to rename long/weird key with its simpler counterpart + found_key, rewrite_key = k, akey + break + if '' != found_key: + break + if '' == found_key: # got no audio! + return None # try returning None to tell WebDataset to skip this one + + audio, in_sr = sample[found_key] + if in_sr != self.sample_rate: + resample_tf = T.Resample(in_sr, self.sample_rate) + audio = resample_tf(audio) + + if self.sample_size is not None: + # Pad/crop and get the relative timestamp + pad_crop = PadCrop_Normalized_T( + self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate) + audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop( + audio) + sample["json"]["seconds_start"] = seconds_start + sample["json"]["seconds_total"] = seconds_total + sample["json"]["padding_mask"] = padding_mask + else: + t_start, t_end = 0, 1 + + # Check if audio is length zero, initialize to a single zero if so + if audio.shape[-1] == 0: + audio = torch.zeros(1, 1) + + # Make the audio stereo and augment by randomly inverting phase + augs = torch.nn.Sequential( + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + PhaseFlipper() if self.augment_phase else torch.nn.Identity() + ) + + audio = augs(audio) + + sample["json"]["timestamps"] = (t_start, t_end) + + if "text" in sample["json"]: + sample["json"]["prompt"] = sample["json"]["text"] + + # Check for custom metadata functions + for dataset in self.datasets: + if dataset.custom_metadata_fn is None: + continue + + if dataset.path in sample["__url__"]: + custom_metadata = dataset.custom_metadata_fn(sample["json"], audio) + sample["json"].update(custom_metadata) + + if found_key != rewrite_key: # rename long/weird key with its simpler counterpart + del sample[found_key] + + sample["audio"] = audio + + # Add audio to the metadata as well for conditioning + sample["json"]["audio"] = audio + + return sample + +def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4, shuffle=True): + + dataset_type = dataset_config.get("dataset_type", None) + + assert dataset_type is not None, "Dataset type must be specified in dataset config" + + if audio_channels == 1: + force_channels = "mono" + elif audio_channels == 2: + force_channels = "stereo" + else: + force_channels = "foa" + + if dataset_type == "audio_dir": + + audio_dir_configs = dataset_config.get("datasets", None) + + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + + configs = [] + + for audio_dir_config in audio_dir_configs: + audio_dir_path = audio_dir_config.get("path", None) + split_path = audio_dir_config.get("split_path", None) + assert audio_dir_path is not None, "Path must be set for local audio directory configuration" + custom_metadata_fn = None + custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None) + + if custom_metadata_module_path is not None: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + + custom_metadata_fn = metadata_module.get_custom_metadata + + configs.append( + LocalDatasetConfig( + id=audio_dir_config["id"], + path=audio_dir_path, + split_path=split_path, + custom_metadata_fn=custom_metadata_fn + ) + ) + + train_set = SampleDataset( + configs, + sample_rate=sample_rate, + sample_size=sample_size, + random_crop=dataset_config.get("random_crop", True), + input_type=dataset_config.get("input_type", "video"), + fps=dataset_config.get("fps", 4), + force_channels=force_channels + ) + + return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, + num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn) + + elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility + + wds_configs = [] + + for wds_config in dataset_config["datasets"]: + + custom_metadata_fn = None + custom_metadata_module_path = wds_config.get("custom_metadata_module", None) + + if custom_metadata_module_path is not None: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + + custom_metadata_fn = metadata_module.get_custom_metadata + + if "s3_path" in wds_config: + + wds_configs.append( + S3DatasetConfig( + id=wds_config["id"], + s3_path=wds_config["s3_path"], + custom_metadata_fn=custom_metadata_fn, + profile=wds_config.get("profile", None), + ) + ) + + elif "path" in wds_config: + + wds_configs.append( + LocalWebDatasetConfig( + id=wds_config["id"], + path=wds_config["path"], + custom_metadata_fn=custom_metadata_fn + ) + ) + + return WebDatasetDataLoader( + wds_configs, + sample_rate=sample_rate, + sample_size=sample_size, + batch_size=batch_size, + random_crop=dataset_config.get("random_crop", True), + num_workers=num_workers, + persistent_workers=True, + force_channels=force_channels, + epoch_steps=dataset_config.get("epoch_steps", 2000) + ).data_loader + + elif dataset_type == "latent_dir": + + audio_dir_configs = dataset_config.get("datasets", None) + + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + + configs = [] + + for audio_dir_config in audio_dir_configs: + audio_dir_path = audio_dir_config.get("path", None) + split_path = audio_dir_config.get("split_path", None) + assert audio_dir_path is not None, "Path must be set for local audio directory configuration" + + configs.append( + LocalDatasetConfig( + id=audio_dir_config["id"], + path=audio_dir_path, + split_path=split_path, + ) + ) + + train_set = LatentDataset( + configs, + sample_rate=sample_rate, + sample_size=sample_size, + random_crop=dataset_config.get("random_crop", True), + input_type=dataset_config.get("input_type", "video"), + fps=dataset_config.get("fps", 4), + force_channels=force_channels + ) + + return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, + num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn) + elif dataset_type == 'multimodal_dir': + audio_dir_configs = dataset_config.get("datasets", None) + + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + + configs = [] + + for audio_dir_config in audio_dir_configs: + audio_dir_path = audio_dir_config.get("path", None) + split_path = audio_dir_config.get("split_path", None) + assert audio_dir_path is not None, "Path must be set for local audio directory configuration" + custom_metadata_fn = None + custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None) + + if custom_metadata_module_path is not None: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + + custom_metadata_fn = metadata_module.get_custom_metadata + + configs.append( + LocalDatasetConfig( + id=audio_dir_config["id"], + path=audio_dir_path, + split_path=split_path, + custom_metadata_fn=custom_metadata_fn + ) + ) + + train_set = MultiModalDataset( + configs, + sample_rate=sample_rate, + sample_size=sample_size, + random_crop=dataset_config.get("random_crop", True), + input_type=dataset_config.get("input_type", "video"), + fps=dataset_config.get("fps", 4), + force_channels=force_channels + ) + + return torch.utils.data.DataLoader(train_set, batch_size, shuffle=shuffle, + num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn) \ No newline at end of file diff --git a/ThinkSound/data/utils.py b/ThinkSound/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0a0f0a19a4f9a32ee3424cfd24efe7643ddde0fc --- /dev/null +++ b/ThinkSound/data/utils.py @@ -0,0 +1,378 @@ +import math +import random +import torch +import torch.nn.functional as F +from torch import nn +from typing import Tuple +import numpy as np + +class PadCrop(nn.Module): + def __init__(self, n_samples, randomize=True): + super().__init__() + self.n_samples = n_samples + self.randomize = randomize + + def __call__(self, signal): + n, s = signal.shape + start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() + end = start + self.n_samples + output = signal.new_zeros([n, self.n_samples]) + output[:, :min(s, self.n_samples)] = signal[:, start:end] + return output + +class PadCrop_Normalized_T(nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + + def __call__(self, source: torch.Tensor, randomize=True) -> Tuple[torch.Tensor, float, float, int, int]: + + n_channels, n_samples = source.shape + + # If the audio is shorter than the desired length, pad it + upper_bound = max(0, n_samples - self.n_samples) + + # If randomize is False, always start at the beginning of the audio + offset = 0 + if(randomize and n_samples > self.n_samples): + offset = random.randint(0, upper_bound) + + # Calculate the start and end times of the chunk + t_start = offset / (upper_bound + self.n_samples) + t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) + + # Create the chunk + chunk = source.new_zeros([n_channels, self.n_samples]) + + # Copy the audio into the chunk + chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples] + + # Calculate the start and end times of the chunk in seconds + seconds_start = math.floor(offset / self.sample_rate) + seconds_total = math.ceil(n_samples / self.sample_rate) + + # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't + padding_mask = torch.zeros([self.n_samples]) + padding_mask[:min(n_samples, self.n_samples)] = 1 + + + return ( + chunk, + t_start, + t_end, + seconds_start, + seconds_total, + padding_mask + ) + +class PadCrop_Video_Normalized_T(nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + self.fps = fps + self.n_frames = int(self.fps * self.n_samples / self.sample_rate) + + def __call__(self, audio: torch.Tensor, video: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: + n_channels, n_samples = audio.shape + # print(video.shape) + n_frames, dim = video.shape + if not torch.is_tensor(video): + video = torch.from_numpy(video) + # If the audio is shorter than the desired length, pad it + audio_upper_bound = max(0, n_samples - self.n_samples) + video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps) + upper_bound = min(audio_upper_bound,video_upper_bound) + + # If randomize is False, always start at the beginning of the audio + offset = 0 + if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames): + offset = random.randint(0, upper_bound) + + # Calculate the start and end times of the chunk + t_start = offset / (upper_bound + self.n_samples) + t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) + frame_offset = int(self.fps * offset / self.sample_rate) + # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate) + # Create the chunk + chunk = audio.new_zeros([n_channels, self.n_samples]) + video_chunk = video.new_zeros([self.n_frames, video.shape[1]]) + # Copy the audio into the chunk + chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples] + video_chunk[:min(n_frames, self.n_frames)] = video[frame_offset:frame_offset + self.n_frames,:] + # Calculate the start and end times of the chunk in seconds + seconds_start = math.floor(offset / self.sample_rate) + seconds_total = math.ceil(n_samples / self.sample_rate) + + # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't + padding_mask = torch.zeros([self.n_samples]) + padding_mask[:min(n_samples, self.n_samples)] = 1 + + + return ( + chunk, + video_chunk, + t_start, + t_end, + seconds_start, + seconds_total, + padding_mask + ) + +class PadCrop_Video_Image_Normalized_T(nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + self.fps = fps + self.n_frames = int(self.fps * self.n_samples / self.sample_rate) + + def __call__(self, audio: torch.Tensor, video: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: + n_channels, n_samples = audio.shape + # import ipdb + # ipdb.set_trace() + n_frames, channel, width, height= video.shape + video = torch.from_numpy(video) + # If the audio is shorter than the desired length, pad it + audio_upper_bound = max(0, n_samples - self.n_samples) + video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps) + upper_bound = min(audio_upper_bound,video_upper_bound) + + # If randomize is False, always start at the beginning of the audio + offset = 0 + if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames): + offset = random.randint(0, upper_bound) + + # Calculate the start and end times of the chunk + t_start = offset / (upper_bound + self.n_samples) + t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) + frame_offset = int(self.fps * offset / self.sample_rate) + # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate) + # Create the chunk + chunk = audio.new_zeros([n_channels, self.n_samples]) + video_chunk = video.new_zeros([self.n_frames, channel, width, height]) + # Copy the audio into the chunk + chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples] + video_chunk[:min(n_frames, self.n_frames)] = video[frame_offset:frame_offset + self.n_frames] + # Calculate the start and end times of the chunk in seconds + seconds_start = math.floor(offset / self.sample_rate) + seconds_total = math.ceil(n_samples / self.sample_rate) + + # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't + padding_mask = torch.zeros([self.n_samples]) + padding_mask[:min(n_samples, self.n_samples)] = 1 + + + return ( + chunk, + video_chunk, + t_start, + t_end, + seconds_start, + seconds_total, + padding_mask + ) + +class PadCrop_Video_Hiera_Normalized_T(nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + self.fps = fps + self.n_frames = int(self.fps * self.n_samples / self.sample_rate) + + def __call__(self, audio: torch.Tensor, video: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: + + n_channels, n_samples = audio.shape + n_frames, heigh, width, channel = video.shape + video = torch.from_numpy(video) + # If the audio is shorter than the desired length, pad it + audio_upper_bound = max(0, n_samples - self.n_samples) + video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps) + upper_bound = min(audio_upper_bound,video_upper_bound) + + # If randomize is False, always start at the beginning of the audio + offset = 0 + if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames): + offset = random.randint(0, upper_bound) + + # Calculate the start and end times of the chunk + t_start = offset / (upper_bound + self.n_samples) + t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) + frame_offset = int(self.fps * offset / self.sample_rate) + # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate) + # Create the chunk + chunk = audio.new_zeros([n_channels, self.n_samples]) + video_chunk = video.new_zeros([self.n_frames, heigh, width, channel]) + # Copy the audio into the chunk + chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples] + video_chunk[:min(n_frames, self.n_frames)] = video[frame_offset:frame_offset + self.n_frames] + # video_chunk = video_chunk[None].permute(0, 4, 1, 2, 3).contiguous() + # print(video_chunk.shape) + # video_chunk = F.interpolate( + # video_chunk[0], + # size=(224, 224, 3), # 输出的空间尺寸 + # scale_factor=(target_frames / video_tensor.shape[1], 1, 1), # 时间轴的缩放因子 + # mode='trilinear', # 使用三线性插值 + # align_corners=False + # ) + + # video_chunk = F.interpolate(video_chunk, size=(64, 224, 224), mode="trilinear")[0] + # video_chunk = video_chunk.view(3,4,16,224,224).transpose(0,1) + # Calculate the start and end times of the chunk in seconds + seconds_start = math.floor(offset / self.sample_rate) + seconds_total = math.ceil(n_samples / self.sample_rate) + + # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't + padding_mask = torch.zeros([self.n_samples]) + padding_mask[:min(n_samples, self.n_samples)] = 1 + + + return ( + chunk, + video_chunk, + t_start, + t_end, + seconds_start, + seconds_total, + padding_mask + ) + +class PadCrop_DualVideo_Normalized_T(nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + self.fps = fps + self.n_frames = int(self.fps * self.n_samples / self.sample_rate) + + def __call__(self, audio: torch.Tensor, video_360: torch.Tensor, video_fov: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: + n_channels, n_samples = audio.shape + # print(video.shape) + n_frames, dim = video_360.shape + video_360 = torch.from_numpy(video_360) + video_fov = torch.from_numpy(video_fov) + # If the audio is shorter than the desired length, pad it + audio_upper_bound = max(0, n_samples - self.n_samples) + video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps) + upper_bound = min(audio_upper_bound,video_upper_bound) + + # If randomize is False, always start at the beginning of the audio + offset = 0 + if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames): + offset = random.randint(0, upper_bound) + + # Calculate the start and end times of the chunk + t_start = offset / (upper_bound + self.n_samples) + t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) + frame_offset = int(self.fps * offset / self.sample_rate) + # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate) + # Create the chunk + chunk = audio.new_zeros([n_channels, self.n_samples]) + video_360_chunk = video_360.new_zeros([self.n_frames, video_360.shape[1]]) + video_fov_chunk = video_fov.new_zeros([self.n_frames, video_fov.shape[1]]) + # Copy the audio into the chunk + chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples] + video_360_chunk[:min(n_frames, self.n_frames)] = video_360[frame_offset:frame_offset + self.n_frames,:] + video_fov_chunk[:min(n_frames, self.n_frames)] = video_fov[frame_offset:frame_offset + self.n_frames,:] + # Calculate the start and end times of the chunk in seconds + seconds_start = math.floor(offset / self.sample_rate) + seconds_total = math.ceil(n_samples / self.sample_rate) + + # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't + padding_mask = torch.zeros([self.n_samples]) + padding_mask[:min(n_samples, self.n_samples)] = 1 + + + return ( + chunk, + video_360_chunk, + video_fov_chunk, + t_start, + t_end, + seconds_start, + seconds_total, + padding_mask + ) + +class PhaseFlipper(nn.Module): + "Randomly invert the phase of a signal" + def __init__(self, p=0.5): + super().__init__() + self.p = p + def __call__(self, signal): + return -signal if (random.random() < self.p) else signal + +class Mono(nn.Module): + def __call__(self, signal): + return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal + +class Stereo(nn.Module): + def __call__(self, signal): + signal_shape = signal.shape + # Check if it's mono + if len(signal_shape) == 1: # s -> 2, s + signal = signal.unsqueeze(0).repeat(2, 1) + elif len(signal_shape) == 2: + if signal_shape[0] == 1: #1, s -> 2, s + signal = signal.repeat(2, 1) + elif signal_shape[0] > 2: #?, s -> 2,s + signal = signal[:2, :] + + return signal + +class FOA(nn.Module): + def __call__(self, signal): + signal_shape = signal.shape + # Check if it's mono + if len(signal_shape) == 1: # s -> (4, s) + foa = torch.zeros(4, signal_shape[0], device=signal.device) # 与输入信号一致的设备类型 + foa[0, :] = signal # W通道: 全方位声源 + foa[1, :] = 0 # X通道 + foa[2, :] = 0 # Y通道 + foa[3, :] = 0 # Z通道 + elif len(signal_shape) == 2: + foa = torch.zeros(4, signal_shape[1], device=signal.device) # 与输入信号一致的设备类型 + if signal_shape[0] == 1: # (1, s) -> (4, s) + foa[0, :] = signal[0] # W通道: 全方位声源 + foa[1, :] = 0 # X通道 + foa[2, :] = 0 # Y通道 + foa[3, :] = 0 # Z通道 + elif signal_shape[0] == 2: # (2, s) -> (4, s) + left = signal[0] + right = signal[1] + # 将立体声信号映射到FOA信号通道 + foa[0, :] = (left + right) / np.sqrt(2) # W通道: 全方位声源 + foa[1, :] = (left - right) / np.sqrt(2) # X通道: 前后方向 + foa[2, :] = 0 # Y通道: 左右方向,简单实现先置零 + foa[3, :] = 0 # Z通道: 垂直方向,这里置零 + else: + foa = signal + + else: + raise ValueError(f"Unsupported signal shape: {signal_shape}") + + assert foa.shape[0] == 4, f'inputs not FOA format' + + return foa \ No newline at end of file diff --git a/ThinkSound/inference/__init__.py b/ThinkSound/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ThinkSound/inference/generation.py b/ThinkSound/inference/generation.py new file mode 100644 index 0000000000000000000000000000000000000000..8d6873fe9d19deba0497e5fb4d939472552a261c --- /dev/null +++ b/ThinkSound/inference/generation.py @@ -0,0 +1,274 @@ +import numpy as np +import torch +import typing as tp +import math +from torchaudio import transforms as T + +from .utils import prepare_audio +from .sampling import sample, sample_k, sample_rf +from ..data.utils import PadCrop + +def generate_diffusion_uncond( + model, + steps: int = 250, + batch_size: int = 1, + sample_size: int = 2097152, + seed: int = -1, + device: str = "cuda", + init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, + init_noise_level: float = 1.0, + return_latents = False, + **sampler_kwargs + ) -> torch.Tensor: + + # The length of the output in audio samples + audio_sample_size = sample_size + + # If this is latent diffusion, change sample_size instead to the downsampled latent size + if model.pretransform is not None: + sample_size = sample_size // model.pretransform.downsampling_ratio + + # Seed + # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. + seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) + print(seed) + torch.manual_seed(seed) + # Define the initial noise immediately after setting the seed + noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) + + if init_audio is not None: + # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio. + in_sr, init_audio = init_audio + + io_channels = model.io_channels + + # For latent models, set the io_channels to the autoencoder's io_channels + if model.pretransform is not None: + io_channels = model.pretransform.io_channels + + # Prepare the initial audio for use by the model + init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) + + # For latent models, encode the initial audio into latents + if model.pretransform is not None: + init_audio = model.pretransform.encode(init_audio) + + init_audio = init_audio.repeat(batch_size, 1, 1) + else: + # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. + init_audio = None + init_noise_level = None + + # Inpainting mask + + if init_audio is not None: + # variations + sampler_kwargs["sigma_max"] = init_noise_level + mask = None + else: + mask = None + + # Now the generative AI part: + + diff_objective = model.diffusion_objective + + if diff_objective == "v": + # k-diffusion denoising process go! + sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device) + elif diff_objective == "rectified_flow": + sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device) + + # Denoising process done. + # If this is latent diffusion, decode latents back into audio + if model.pretransform is not None and not return_latents: + sampled = model.pretransform.decode(sampled) + + # Return audio + return sampled + + +def generate_diffusion_cond( + model, + steps: int = 250, + cfg_scale=6, + conditioning: dict = None, + conditioning_tensors: tp.Optional[dict] = None, + negative_conditioning: dict = None, + negative_conditioning_tensors: tp.Optional[dict] = None, + batch_size: int = 1, + sample_size: int = 2097152, + sample_rate: int = 48000, + seed: int = -1, + device: str = "cuda", + init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, + init_noise_level: float = 1.0, + mask_args: dict = None, + return_latents = False, + **sampler_kwargs + ) -> torch.Tensor: + """ + Generate audio from a prompt using a diffusion model. + + Args: + model: The diffusion model to use for generation. + steps: The number of diffusion steps to use. + cfg_scale: Classifier-free guidance scale + conditioning: A dictionary of conditioning parameters to use for generation. + conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation. + batch_size: The batch size to use for generation. + sample_size: The length of the audio to generate, in samples. + sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly) + seed: The random seed to use for generation, or -1 to use a random seed. + device: The device to use for generation. + init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation. + init_noise_level: The noise level to use when generating from an initial audio sample. + return_latents: Whether to return the latents used for generation instead of the decoded audio. + **sampler_kwargs: Additional keyword arguments to pass to the sampler. + """ + + # The length of the output in audio samples + audio_sample_size = sample_size + + # If this is latent diffusion, change sample_size instead to the downsampled latent size + if model.pretransform is not None: + sample_size = sample_size // model.pretransform.downsampling_ratio + + # Seed + # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. + seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) + print(seed) + torch.manual_seed(seed) + # Define the initial noise immediately after setting the seed + noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) + + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + torch.backends.cudnn.benchmark = False + import ipdb + # ipdb.set_trace() + # Conditioning + assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors" + if conditioning_tensors is None: + conditioning_tensors = model.conditioner(conditioning, device) + conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors) + + if negative_conditioning is not None or negative_conditioning_tensors is not None: + + if negative_conditioning_tensors is None: + negative_conditioning_tensors = model.conditioner(negative_conditioning, device) + + negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True) + else: + negative_conditioning_tensors = {} + + if init_audio is not None: + # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio. + in_sr, init_audio = init_audio + + io_channels = model.io_channels + + # For latent models, set the io_channels to the autoencoder's io_channels + if model.pretransform is not None: + io_channels = model.pretransform.io_channels + + # Prepare the initial audio for use by the model + init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) + + # For latent models, encode the initial audio into latents + if model.pretransform is not None: + init_audio = model.pretransform.encode(init_audio) + + init_audio = init_audio.repeat(batch_size, 1, 1) + else: + # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. + init_audio = None + init_noise_level = None + mask_args = None + + # Inpainting mask + if init_audio is not None and mask_args is not None: + # Cut and paste init_audio according to cropfrom, pastefrom, pasteto + # This is helpful for forward and reverse outpainting + cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size) + pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size) + pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size) + assert pastefrom < pasteto, "Paste From should be less than Paste To" + croplen = pasteto - pastefrom + if cropfrom + croplen > sample_size: + croplen = sample_size - cropfrom + cropto = cropfrom + croplen + pasteto = pastefrom + croplen + cutpaste = init_audio.new_zeros(init_audio.shape) + cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto] + #print(cropfrom, cropto, pastefrom, pasteto) + init_audio = cutpaste + # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args + mask = build_mask(sample_size, mask_args) + mask = mask.to(device) + elif init_audio is not None and mask_args is None: + # variations + sampler_kwargs["sigma_max"] = init_noise_level + mask = None + else: + mask = None + + model_dtype = next(model.model.parameters()).dtype + noise = noise.type(model_dtype) + conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()} + # Now the generative AI part: + # k-diffusion denoising process go! + diff_objective = model.diffusion_objective + if diff_objective == "v": + # k-diffusion denoising process go! + # sampled = sample(model.model, noise, steps, 0, **conditioning_inputs) + sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) + elif diff_objective == "rectified_flow": + + if "sigma_min" in sampler_kwargs: + del sampler_kwargs["sigma_min"] + + if "sampler_type" in sampler_kwargs: + del sampler_kwargs["sampler_type"] + + sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) + + # v-diffusion: + #sampled = sample(model.model, noise, steps, 0, **conditioning_tensors, embedding_scale=cfg_scale) + del noise + del conditioning_tensors + del conditioning_inputs + torch.cuda.empty_cache() + # Denoising process done. + # If this is latent diffusion, decode latents back into audio + if model.pretransform is not None and not return_latents: + #cast sampled latents to pretransform dtype + sampled = sampled.to(next(model.pretransform.parameters()).dtype) + sampled = model.pretransform.decode(sampled) + + # Return audio + return sampled + +# builds a softmask given the parameters +# returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio, +# and anything between is a mixture of old/new +# ideally 0.5 is half/half mixture but i haven't figured this out yet +def build_mask(sample_size, mask_args): + maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size) + maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size) + softnessL = round(mask_args["softnessL"]/100.0 * sample_size) + softnessR = round(mask_args["softnessR"]/100.0 * sample_size) + marination = mask_args["marination"] + # use hann windows for softening the transition (i don't know if this is correct) + hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL] + hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:] + # build the mask. + mask = torch.zeros((sample_size)) + mask[maskstart:maskend] = 1 + mask[maskstart:maskstart+softnessL] = hannL + mask[maskend-softnessR:maskend] = hannR + # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds + if marination > 0: + mask = mask * (1-marination) + #print(mask) + return mask diff --git a/ThinkSound/inference/sampling.py b/ThinkSound/inference/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..3877d7252bcfd08f19bd59df80d48a1528b4e187 --- /dev/null +++ b/ThinkSound/inference/sampling.py @@ -0,0 +1,286 @@ +import torch +import math +from tqdm import trange, tqdm +import torch.distributions as dist + +import k_diffusion as K + +# Define the noise schedule and sampling loop +def get_alphas_sigmas(t): + """Returns the scaling factors for the clean image (alpha) and for the + noise (sigma), given a timestep.""" + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + +def alpha_sigma_to_t(alpha, sigma): + """Returns a timestep, given the scaling factors for the clean image and for + the noise.""" + return torch.atan2(sigma, alpha) / math.pi * 2 + +def t_to_alpha_sigma(t): + """Returns the scaling factors for the clean image and for the noise, given + a timestep.""" + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + +def sample_timesteps_logsnr(batch_size, mean_logsnr=-1.2, std_logsnr=2.0): + """ + Sample timesteps for diffusion training by sampling logSNR values and converting to t. + + Args: + batch_size (int): Number of timesteps to sample + mean_logsnr (float): Mean of the logSNR Gaussian distribution + std_logsnr (float): Standard deviation of the logSNR Gaussian distribution + + Returns: + torch.Tensor: Tensor of shape (batch_size,) containing timestep values t in [0, 1] + """ + # Sample logSNR from Gaussian distribution + logsnr = torch.randn(batch_size) * std_logsnr + mean_logsnr + + # Convert logSNR to timesteps using the logistic function + # Since logSNR = ln((1-t)/t), we can solve for t: + # t = 1 / (1 + exp(logsnr)) + t = torch.sigmoid(-logsnr) + + # Clamp values to ensure numerical stability + t = t.clamp(1e-4, 1 - 1e-4) + + return t +def truncated_logistic_normal_rescaled(shape, left_trunc=0.075, right_trunc=1): + """ + + shape: shape of the output tensor + left_trunc: left truncation point, fraction of probability to be discarded + right_trunc: right truncation boundary, should be 1 (never seen at test time) + """ + + # Step 1: Sample from the logistic normal distribution (sigmoid of normal) + logits = torch.randn(shape) + + # Step 2: Apply the CDF transformation of the normal distribution + normal_dist = dist.Normal(0, 1) + cdf_values = normal_dist.cdf(logits) + + # Step 3: Define the truncation bounds on the CDF + lower_bound = normal_dist.cdf(torch.logit(torch.tensor(left_trunc))) + upper_bound = normal_dist.cdf(torch.logit(torch.tensor(right_trunc))) + + # Step 4: Rescale linear CDF values into the truncated region (between lower_bound and upper_bound) + truncated_cdf_values = lower_bound + (upper_bound - lower_bound) * cdf_values + + # Step 5: Map back to logistic-normal space using inverse CDF + truncated_samples = torch.sigmoid(normal_dist.icdf(truncated_cdf_values)) + + # Step 6: Rescale values so that min is 0 and max is just below 1 + rescaled_samples = (truncated_samples - left_trunc) / (right_trunc - left_trunc) + + return rescaled_samples + +@torch.no_grad() +def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args): + """Draws samples from a model given starting noise. Euler method""" + + # Make tensor of ones to broadcast the single t values + ts = x.new_ones([x.shape[0]]) + + # Create the noise schedule + t = torch.linspace(sigma_max, 0, steps + 1) + + #alphas, sigmas = 1-t, t + + for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])): + # Broadcast the current timestep to the correct shape + t_curr_tensor = t_curr * torch.ones( + (x.shape[0],), dtype=x.dtype, device=x.device + ) + dt = t_prev - t_curr # we solve backwards in our formulation + x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc) + + # If we are on the last timestep, output the denoised image + return x + +@torch.no_grad() +def sample(model, x, steps, eta, **extra_args): + """Draws samples from a model given starting noise. v-diffusion""" + ts = x.new_ones([x.shape[0]]) + + # Create the noise schedule + t = torch.linspace(1, 0, steps + 1)[:-1] + + alphas, sigmas = get_alphas_sigmas(t) + + # The sampling loop + for i in trange(steps): + + # Get the model output (v, the predicted velocity) + with torch.cuda.amp.autocast(): + v = model(x, ts * t[i], **extra_args).float() + + # Predict the noise and the denoised image + pred = x * alphas[i] - v * sigmas[i] + eps = x * sigmas[i] + v * alphas[i] + + # If we are not on the last timestep, compute the noisy image for the + # next timestep. + if i < steps - 1: + # If eta > 0, adjust the scaling factor for the predicted noise + # downward according to the amount of additional noise to add + ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ + (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() + adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() + + # Recombine the predicted noise and predicted denoised image in the + # correct proportions for the next step + x = pred * alphas[i + 1] + eps * adjusted_sigma + + # Add the correct amount of fresh noise + if eta: + x += torch.randn_like(x) * ddim_sigma + + # If we are on the last timestep, output the denoised image + return pred + +# Soft mask inpainting is just shrinking hard (binary) mask inpainting +# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step +def get_bmask(i, steps, mask): + strength = (i+1)/(steps) + # convert to binary mask + bmask = torch.where(mask<=strength,1,0) + return bmask + +def make_cond_model_fn(model, cond_fn): + def cond_model_fn(x, sigma, **kwargs): + with torch.enable_grad(): + x = x.detach().requires_grad_() + denoised = model(x, sigma, **kwargs) + cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach() + cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim) + return cond_denoised + return cond_model_fn + +# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion +# init_data is init_audio as latents (if this is latent diffusion) +# For sampling, set both init_data and mask to None +# For variations, set init_data +# For inpainting, set both init_data & mask +def sample_k( + model_fn, + noise, + init_data=None, + mask=None, + steps=100, + sampler_type="dpmpp-2m-sde", + sigma_min=0.5, + sigma_max=50, + rho=1.0, device="cuda", + callback=None, + cond_fn=None, + **extra_args + ): + + denoiser = K.external.VDenoiser(model_fn) + + if cond_fn is not None: + denoiser = make_cond_model_fn(denoiser, cond_fn) + + # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has + sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device) + # Scale the initial noise by sigma + noise = noise * sigmas[0] + + wrapped_callback = callback + + if mask is None and init_data is not None: + # VARIATION (no inpainting) + # set the initial latent to the init_data, and noise it with initial sigma + x = init_data + noise + elif mask is not None and init_data is not None: + # INPAINTING + bmask = get_bmask(0, steps, mask) + # initial noising + input_noised = init_data + noise + # set the initial latent to a mix of init_data and noise, based on step 0's binary mask + x = input_noised * bmask + noise * (1-bmask) + # define the inpainting callback function (Note: side effects, it mutates x) + # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105 + # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)` + def inpainting_callback(args): + i = args["i"] + x = args["x"] + sigma = args["sigma"] + #denoised = args["denoised"] + # noise the init_data input with this step's appropriate amount of noise + input_noised = init_data + torch.randn_like(init_data) * sigma + # shrinking hard mask + bmask = get_bmask(i, steps, mask) + # mix input_noise with x, using binary mask + new_x = input_noised * bmask + x * (1-bmask) + # mutate x + x[:,:,:] = new_x[:,:,:] + # wrap together the inpainting callback and the user-submitted callback. + if callback is None: + wrapped_callback = inpainting_callback + else: + wrapped_callback = lambda args: (inpainting_callback(args), callback(args)) + else: + # SAMPLING + # set the initial latent to noise + x = noise + + + with torch.cuda.amp.autocast(): + if sampler_type == "k-heun": + return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-lms": + return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpmpp-2s-ancestral": + return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-2": + return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-fast": + return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-adaptive": + return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "dpmpp-2m-sde": + return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "dpmpp-3m-sde": + return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + +# Uses discrete Euler sampling for rectified flow models +# init_data is init_audio as latents (if this is latent diffusion) +# For sampling, set both init_data and mask to None +# For variations, set init_data +# For inpainting, set both init_data & mask +def sample_rf( + model_fn, + noise, + init_data=None, + steps=100, + sigma_max=1, + device="cuda", + callback=None, + cond_fn=None, + **extra_args + ): + + if sigma_max > 1: + sigma_max = 1 + + if cond_fn is not None: + denoiser = make_cond_model_fn(denoiser, cond_fn) + + wrapped_callback = callback + + if init_data is not None: + # VARIATION (no inpainting) + # Interpolate the init data and the noise for init audio + x = init_data * (1 - sigma_max) + noise * sigma_max + else: + # SAMPLING + # set the initial latent to noise + x = noise + + with torch.cuda.amp.autocast(): + # TODO: Add callback support + #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args) + return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args) \ No newline at end of file diff --git a/ThinkSound/inference/utils.py b/ThinkSound/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6c0a57609f68156ad244da9b5819666329772e --- /dev/null +++ b/ThinkSound/inference/utils.py @@ -0,0 +1,35 @@ +from ..data.utils import PadCrop + +from torchaudio import transforms as T + +def set_audio_channels(audio, target_channels): + if target_channels == 1: + # Convert to mono + audio = audio.mean(1, keepdim=True) + elif target_channels == 2: + # Convert to stereo + if audio.shape[1] == 1: + audio = audio.repeat(1, 2, 1) + elif audio.shape[1] > 2: + audio = audio[:, :2, :] + return audio + +def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): + + audio = audio.to(device) + + if in_sr != target_sr: + resample_tf = T.Resample(in_sr, target_sr).to(device) + audio = resample_tf(audio) + + audio = PadCrop(target_length, randomize=False)(audio) + + # Add batch dimension + if audio.dim() == 1: + audio = audio.unsqueeze(0).unsqueeze(0) + elif audio.dim() == 2: + audio = audio.unsqueeze(0) + + audio = set_audio_channels(audio, target_channels) + + return audio \ No newline at end of file diff --git a/ThinkSound/interface/__init__.py b/ThinkSound/interface/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ThinkSound/interface/aeiou.py b/ThinkSound/interface/aeiou.py new file mode 100644 index 0000000000000000000000000000000000000000..8846c8968e48f9324f28642ade0dbc94a26ee99c --- /dev/null +++ b/ThinkSound/interface/aeiou.py @@ -0,0 +1,278 @@ +# Modified from https://github.com/drscotthawley/aeiou/blob/main/aeiou/viz.py under Apache 2.0 License +# License can be found in LICENSES/LICENSE_AEIOU.txt + +from matplotlib.backends.backend_agg import FigureCanvasAgg +import matplotlib.cm as cm +from matplotlib.colors import Normalize +from matplotlib.figure import Figure +import numpy as np +from PIL import Image + +import torch + +import torchaudio.transforms as T +from einops import rearrange + +import numpy as np + +def embeddings_table(tokens): + from wandb import Table + from pandas import DataFrame + + "make a table of embeddings for use with wandb" + features, labels = [], [] + embeddings = rearrange(tokens, 'b d n -> b n d') # each demo sample is n vectors in d-dim space + for i in range(embeddings.size()[0]): # nested for's are slow but sure ;-) + for j in range(embeddings.size()[1]): + features.append(embeddings[i,j].detach().cpu().numpy()) + labels.append([f'demo{i}']) # labels does the grouping / color for each point + features = np.array(features) + labels = np.concatenate(labels, axis=0) + cols = [f"dim_{i}" for i in range(features.shape[1])] + df = DataFrame(features, columns=cols) + df['LABEL'] = labels + return Table(columns=df.columns.to_list(), data=df.values) + +def project_down(tokens, # batched high-dimensional data with dims (b,d,n) + proj_dims=3, # dimensions to project to + method='pca', # projection method: 'pca'|'umap' + n_neighbors=10, # umap parameter for number of neighbors + min_dist=0.3, # umap param for minimum distance + debug=False, # print more info while running + **kwargs, # other params to pass to umap, cf. https://umap-learn.readthedocs.io/en/latest/parameters.html + ): + "this projects to lower dimenions, grabbing the first _`proj_dims`_ dimensions" + method = method.lower() + A = rearrange(tokens, 'b d n -> (b n) d') # put all the vectors into the same d-dim space + if A.shape[-1] > proj_dims: + if method=='umap': + from umap import UMAP + proj_data = UMAP(n_components=proj_dims, n_neighbors=n_neighbors, min_dist=min_dist, + metric='correlation', **kwargs).fit_transform(A.cpu().numpy()) + proj_data = torch.from_numpy(proj_data).to(tokens.device) + else: # pca + (U, S, V) = torch.pca_lowrank(A) + proj_data = torch.matmul(A, V[:, :proj_dims]) # this is the actual PCA projection step + else: + proj_data = A + if debug: print("proj_data.shape =",proj_data.shape) + return torch.reshape(proj_data, (tokens.size()[0], -1, proj_dims)) # put it in shape [batch, n, proj_dims] + + +def proj_pca(tokens, proj_dims=3): + return project_down(do_proj, method='pca', proj_dims=proj_dims) + +def point_cloud( + tokens, # embeddings / latent vectors. shape = (b, d, n) + method='pca', # projection method for 3d mapping: 'pca' | 'umap' + color_scheme='batch', # 'batch': group by sample; integer n: n groups, sequentially, otherwise color sequentially by time step + output_type='wandbobj', # plotly | points | wandbobj. NOTE: WandB can do 'plotly' directly! + mode='markers', # plotly scatter mode. 'lines+markers' or 'markers' + size=3, # size of the dots + line=dict(color='rgba(10,10,10,0.01)'), # if mode='lines+markers', plotly line specifier. cf. https://plotly.github.io/plotly.py-docs/generated/plotly.graph_objects.scatter3d.html#plotly.graph_objects.scatter3d.Line + ds_preproj=1, # EXPERIMENTAL: downsampling factor before projecting (1=no downsampling). Could screw up colors + ds_preplot=1, # EXPERIMENTAL: downsampling factor before plotting (1=no downsampling). Could screw up colors + debug=False, # print more info + colormap=None, # valid color map to use, None=defaults + darkmode=False, # dark background, white fonts + layout_dict=None, # extra plotly layout options such as camera orientation + rgb_float = False, # if True, color_scheme is RGB float values + **kwargs, # anything else to pass along + ): + "returns a 3D point cloud of the tokens" + if ds_preproj != 1: + tokens = tokens[torch.randperm(tokens.size()[0])] # EXPERIMENTAL: to 'correct' for possible weird effects of downsampling + tokens = tokens[::ds_preproj] + if debug: print("tokens.shape =",tokens.shape) + + data = project_down(tokens, method=method, debug=debug, **kwargs).cpu().numpy() + if debug: print("data.shape =",data.shape) + if data.shape[-1] < 3: # for data less than 3D, embed it in 3D + data = np.pad(data, ((0,0),(0,0),(0, 3-data.shape[-1])), mode='constant', constant_values=0) + + bytime = False + points = [] + if color_scheme=='batch': # all dots in same batch index same color, each batch-index unique (almost) + ncolors = data.shape[0] + cmap, norm = cm.tab20, Normalize(vmin=0, vmax=ncolors) + elif isinstance(color_scheme, int) or color_scheme.isnumeric(): # n groups, by batch-indices, sequentially + ncolors = int(color_scheme) + cmap, norm = cm.tab20, Normalize(vmin=0, vmax=ncolors) + else: # time steps match up + bytime, ncolors = True, data.shape[1] + cmap, norm = cm.viridis, Normalize(vmin=0, vmax=ncolors) + + cmap = cmap if colormap is None else colormap # overwrite default cmap with user choice if given + + points = [] + for bi in range(data.shape[0]): # batch index + if color_scheme=='batch': + [r, g, b, _] = [int(255*x) for x in cmap(norm(bi+1))] + elif isinstance(color_scheme, int) or color_scheme.isnumeric(): + grouplen = data.shape[0]//(ncolors) + #if debug: print(f"bi, grouplen, bi//grouplen = ",bi, grouplen, bi//grouplen) + [r, g, b, _] = [int(255*x) for x in cmap(norm(bi//grouplen))] + #if debug: print("r,g,b = ",r,g,b) + + if rgb_float: [r, g, b] = [x/255 for x in [r, g, b]] + + for n in range(data.shape[1]): # across time + if bytime: [r, g, b, _] = [int(255*x) for x in cmap(norm(n))] + points.append([data[bi,n,0], data[bi,n,1], data[bi,n,2], r, g, b]) # include dot colors with point coordinates + + point_cloud = np.array(points) + + if output_type == 'points': + return point_cloud + elif output_type =='plotly': + import plotly.graph_objects as go + + fig = go.Figure(data=[go.Scatter3d( + x=point_cloud[::ds_preplot,0], y=point_cloud[::ds_preplot,1], z=point_cloud[::ds_preplot,2], + marker=dict(size=size, color=point_cloud[:,3:6]), + mode=mode, + # show batch index and time index in tooltips: + text=[ f'bi: {i*ds_preplot}, ti: {j}' for i in range(data.shape[0]//ds_preplot) for j in range(data.shape[1]) ], + line=line, + )]) + fig.update_layout(margin=dict(l=0, r=0, b=0, t=0)) # tight layout + if darkmode: + fig.layout.template = 'plotly_dark' + if isinstance(darkmode, str): # 'rgb(12,15,24)'gradio margins in dark mode + fig.update_layout( paper_bgcolor=darkmode) + if layout_dict: + fig.update_layout( **layout_dict ) + + if debug: print("point_cloud: fig made. returning") + return fig + else: + from wandb import Object3D + return Object3D(point_cloud) + +def pca_point_cloud( + tokens, # embeddings / latent vectors. shape = (b, d, n) + color_scheme='batch', # 'batch': group by sample, otherwise color sequentially + output_type='wandbobj', # plotly | points | wandbobj. NOTE: WandB can do 'plotly' directly! + mode='markers', # plotly scatter mode. 'lines+markers' or 'markers' + size=3, # size of the dots + line=dict(color='rgba(10,10,10,0.01)'), # if mode='lines+markers', plotly line specifier. cf. https://plotly.github.io/plotly.py-docs/generated/plotly.graph_objects.scatter3d.html#plotly.graph_objects.scatter3d.Line + **kwargs, + ): + return point_cloud(tokens, method='pca', color_scheme=color_scheme, output_type=output_type, + mode=mode, size=size, line=line, **kwargs) + +def power_to_db(spec, *, amin = 1e-10): + magnitude = np.asarray(spec) + + log_spec = 10.0 * np.log10(np.maximum(amin, magnitude)) + log_spec -= 10.0 * np.log10(np.maximum(amin, 1)) + + log_spec = np.maximum(log_spec, log_spec.max() - 80) + + return log_spec + +def mel_spectrogram(waveform, power=2.0, sample_rate=48000, db=False, n_fft=1024, n_mels=128, debug=False): + "calculates data array for mel spectrogram (in however many channels)" + win_length = None + hop_length = n_fft//2 # 512 + + mel_spectrogram_op = T.MelSpectrogram( + sample_rate=sample_rate, n_fft=n_fft, win_length=win_length, + hop_length=hop_length, center=True, pad_mode="reflect", power=power, + norm='slaney', onesided=True, n_mels=n_mels, mel_scale="htk") + + melspec = mel_spectrogram_op(waveform.float()) + if db: + amp_to_db_op = T.AmplitudeToDB() + melspec = amp_to_db_op(melspec) + if debug: + print_stats(melspec, print=print) + print(f"torch.max(melspec) = {torch.max(melspec)}") + print(f"melspec.shape = {melspec.shape}") + return melspec + +def spectrogram_image( + spec, + title=None, + ylabel='freq_bin', + aspect='auto', + xmax=None, + db_range=[35,120], + justimage=False, + figsize=(5, 4), # size of plot (if justimage==False) + ): + "Modified from PyTorch tutorial https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html" + fig = Figure(figsize=figsize, dpi=100) if not justimage else Figure(figsize=(4.145, 4.145), dpi=100, tight_layout=True) + canvas = FigureCanvasAgg(fig) + axs = fig.add_subplot() + spec = spec.squeeze() + im = axs.imshow(power_to_db(spec), origin='lower', aspect=aspect, vmin=db_range[0], vmax=db_range[1]) + if xmax: + axs.set_xlim((0, xmax)) + if justimage: + import matplotlib.pyplot as plt + axs.axis('off') + plt.tight_layout() + else: + axs.set_ylabel(ylabel) + axs.set_xlabel('frame') + axs.set_title(title or 'Spectrogram (dB)') + fig.colorbar(im, ax=axs) + canvas.draw() + rgba = np.asarray(canvas.buffer_rgba()) + im = Image.fromarray(rgba) + if justimage: # remove tiny white border + b = 15 # border size + im = im.crop((b,b, im.size[0]-b, im.size[1]-b)) + #print(f"im.size = {im.size}") + return im + +def audio_spectrogram_image(waveform, power=2.0, sample_rate=48000, print=print, db=False, db_range=[35,120], justimage=False, log=False, figsize=(5, 4)): + "Wrapper for calling above two routines at once, does Mel scale; Modified from PyTorch tutorial https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html" + melspec = mel_spectrogram(waveform, power=power, db=db, sample_rate=sample_rate, debug=log) + melspec = melspec[0] # TODO: only left channel for now + return spectrogram_image(melspec, title="MelSpectrogram", ylabel='mel bins (log freq)', db_range=db_range, justimage=justimage, figsize=figsize) + +from matplotlib.ticker import AutoLocator +def tokens_spectrogram_image( + tokens, # the embeddings themselves (in some diffusion codes these are called 'tokens') + aspect='auto', # aspect ratio of plot + title='Embeddings', # title to put on top + ylabel='index', # label for y axis of plot + cmap='coolwarm', # colormap to use. (default used to be 'viridis') + symmetric=True, # make color scale symmetric about zero, i.e. +/- same extremes + figsize=(8, 4), # matplotlib size of the figure + dpi=100, # dpi of figure + mark_batches=False, # separate batches with dividing lines + debug=False, # print debugging info + ): + "for visualizing embeddings in a spectrogram-like way" + batch_size, dim, samples = tokens.shape + embeddings = rearrange(tokens, 'b d n -> (b n) d') # expand batches in time + vmin, vmax = None, None + if symmetric: + vmax = torch.abs(embeddings).max() + vmin = -vmax + + fig = Figure(figsize=figsize, dpi=dpi) + canvas = FigureCanvasAgg(fig) + ax = fig.add_subplot() + if symmetric: + subtitle = f'min={embeddings.min():0.4g}, max={embeddings.max():0.4g}' + ax.set_title(title+'\n') + ax.text(x=0.435, y=0.9, s=subtitle, fontsize=11, ha="center", transform=fig.transFigure) + else: + ax.set_title(title) + ax.set_ylabel(ylabel) + ax.set_xlabel('time frame (samples, in batches)') + if mark_batches: + intervals = np.arange(batch_size)*samples + if debug: print("intervals = ",intervals) + ax.vlines(intervals, -10, dim+10, color='black', linestyle='dashed', linewidth=1) + + im = ax.imshow(embeddings.cpu().numpy().T, origin='lower', aspect=aspect, interpolation='none', cmap=cmap, vmin=vmin,vmax=vmax) #.T because numpy is x/y 'backwards' + fig.colorbar(im, ax=ax) + fig.tight_layout() + canvas.draw() + rgba = np.asarray(canvas.buffer_rgba()) + return Image.fromarray(rgba) diff --git a/ThinkSound/interface/gradio.py b/ThinkSound/interface/gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..f38468bc34b88ec6bbe5451a8b11b998430888f8 --- /dev/null +++ b/ThinkSound/interface/gradio.py @@ -0,0 +1,700 @@ +import gc +import platform + +import numpy as np +import gradio as gr +import json +import torch +import torchaudio + +from aeiou.viz import audio_spectrogram_image +from einops import rearrange +from safetensors.torch import load_file +from torch.nn import functional as F +from torchaudio import transforms as T + +from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond +from ..models.factory import create_model_from_config +from ..models.pretrained import get_pretrained_model +from ..models.utils import load_ckpt_state_dict +from ..inference.utils import prepare_audio +from ..training.utils import copy_state_dict + +model = None +sample_rate = 32000 +sample_size = 1920000 + +def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False): + global model, sample_rate, sample_size + + if pretrained_name is not None: + print(f"Loading pretrained model {pretrained_name}") + model, model_config = get_pretrained_model(pretrained_name) + + elif model_config is not None and model_ckpt_path is not None: + print(f"Creating model from config") + model = create_model_from_config(model_config) + + print(f"Loading model checkpoint from {model_ckpt_path}") + # Load checkpoint + copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path)) + #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) + + sample_rate = model_config["sample_rate"] + sample_size = model_config["sample_size"] + + if pretransform_ckpt_path is not None: + print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}") + model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False) + print(f"Done loading pretransform") + + model.to(device).eval().requires_grad_(False) + + if model_half: + model.to(torch.float16) + + print(f"Done loading model") + + return model, model_config + +def generate_cond( + prompt, + negative_prompt=None, + seconds_start=0, + seconds_total=30, + cfg_scale=6.0, + steps=250, + preview_every=None, + seed=-1, + sampler_type="dpmpp-3m-sde", + sigma_min=0.03, + sigma_max=1000, + cfg_rescale=0.0, + use_init=False, + init_audio=None, + init_noise_level=1.0, + mask_cropfrom=None, + mask_pastefrom=None, + mask_pasteto=None, + mask_maskstart=None, + mask_maskend=None, + mask_softnessL=None, + mask_softnessR=None, + mask_marination=None, + batch_size=1 + ): + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + print(f"Prompt: {prompt}") + + global preview_images + preview_images = [] + if preview_every == 0: + preview_every = None + + # Return fake stereo audio + conditioning = [{"prompt": prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size + + if negative_prompt: + negative_conditioning = [{"prompt": negative_prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size + else: + negative_conditioning = None + + #Get the device from the model + device = next(model.parameters()).device + + seed = int(seed) + + if not use_init: + init_audio = None + + input_sample_size = sample_size + + if init_audio is not None: + in_sr, init_audio = init_audio + # Turn into torch tensor, converting from int16 to float32 + init_audio = torch.from_numpy(init_audio).float().div(32767) + + if init_audio.dim() == 1: + init_audio = init_audio.unsqueeze(0) # [1, n] + elif init_audio.dim() == 2: + init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] + + if in_sr != sample_rate: + resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device) + init_audio = resample_tf(init_audio) + + audio_length = init_audio.shape[-1] + + if audio_length > sample_size: + + input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length + + init_audio = (sample_rate, init_audio) + + def progress_callback(callback_info): + global preview_images + denoised = callback_info["denoised"] + current_step = callback_info["i"] + sigma = callback_info["sigma"] + + if (current_step - 1) % preview_every == 0: + if model.pretransform is not None: + denoised = model.pretransform.decode(denoised) + denoised = rearrange(denoised, "b d n -> d (b n)") + denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu() + audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate) + preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})")) + + # If inpainting, send mask args + # This will definitely change in the future + if mask_cropfrom is not None: + mask_args = { + "cropfrom": mask_cropfrom, + "pastefrom": mask_pastefrom, + "pasteto": mask_pasteto, + "maskstart": mask_maskstart, + "maskend": mask_maskend, + "softnessL": mask_softnessL, + "softnessR": mask_softnessR, + "marination": mask_marination, + } + else: + mask_args = None + + # Do the audio generation + audio = generate_diffusion_cond( + model, + conditioning=conditioning, + negative_conditioning=negative_conditioning, + steps=steps, + cfg_scale=cfg_scale, + batch_size=batch_size, + sample_size=input_sample_size, + sample_rate=sample_rate, + seed=seed, + device=device, + sampler_type=sampler_type, + sigma_min=sigma_min, + sigma_max=sigma_max, + init_audio=init_audio, + init_noise_level=init_noise_level, + mask_args = mask_args, + callback = progress_callback if preview_every is not None else None, + scale_phi = cfg_rescale + ) + + # Convert to WAV file + audio = rearrange(audio, "b d n -> d (b n)") + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + torchaudio.save("output.wav", audio, sample_rate) + + # Let's look at a nice spectrogram too + audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) + + return ("output.wav", [audio_spectrogram, *preview_images]) + +def generate_uncond( + steps=250, + seed=-1, + sampler_type="dpmpp-3m-sde", + sigma_min=0.03, + sigma_max=1000, + use_init=False, + init_audio=None, + init_noise_level=1.0, + batch_size=1, + preview_every=None + ): + + global preview_images + + preview_images = [] + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + seed = int(seed) + + if not use_init: + init_audio = None + + input_sample_size = sample_size + + if init_audio is not None: + in_sr, init_audio = init_audio + # Turn into torch tensor, converting from int16 to float32 + init_audio = torch.from_numpy(init_audio).float().div(32767) + + if init_audio.dim() == 1: + init_audio = init_audio.unsqueeze(0) # [1, n] + elif init_audio.dim() == 2: + init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] + + if in_sr != sample_rate: + resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device) + init_audio = resample_tf(init_audio) + + audio_length = init_audio.shape[-1] + + if audio_length > sample_size: + + input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length + + init_audio = (sample_rate, init_audio) + + def progress_callback(callback_info): + global preview_images + denoised = callback_info["denoised"] + current_step = callback_info["i"] + sigma = callback_info["sigma"] + + if (current_step - 1) % preview_every == 0: + + if model.pretransform is not None: + denoised = model.pretransform.decode(denoised) + + denoised = rearrange(denoised, "b d n -> d (b n)") + + denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate) + + preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})")) + + audio = generate_diffusion_uncond( + model, + steps=steps, + batch_size=batch_size, + sample_size=input_sample_size, + seed=seed, + device=device, + sampler_type=sampler_type, + sigma_min=sigma_min, + sigma_max=sigma_max, + init_audio=init_audio, + init_noise_level=init_noise_level, + callback = progress_callback if preview_every is not None else None + ) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) + + return ("output.wav", [audio_spectrogram, *preview_images]) + +def generate_lm( + temperature=1.0, + top_p=0.95, + top_k=0, + batch_size=1, + ): + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + audio = model.generate_audio( + batch_size=batch_size, + max_gen_len = sample_size//model.pretransform.downsampling_ratio, + conditioning=None, + temp=temperature, + top_p=top_p, + top_k=top_k, + use_cache=True + ) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) + + return ("output.wav", [audio_spectrogram]) + + +def create_uncond_sampling_ui(model_config): + generate_button = gr.Button("Generate", variant='primary', scale=1) + + with gr.Row(equal_height=False): + with gr.Column(): + with gr.Row(): + # Steps slider + steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") + + with gr.Accordion("Sampler params", open=False): + + # Seed + seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1") + + # Sampler params + with gr.Row(): + sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde") + sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") + sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max") + + with gr.Accordion("Init audio", open=False): + init_audio_checkbox = gr.Checkbox(label="Use init audio") + init_audio_input = gr.Audio(label="Init audio") + init_noise_level_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.01, value=0.1, label="Init noise level") + + with gr.Column(): + audio_output = gr.Audio(label="Output audio", interactive=False) + audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) + send_to_init_button = gr.Button("Send to init audio", scale=1) + send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input]) + + generate_button.click(fn=generate_uncond, + inputs=[ + steps_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, + sigma_max_slider, + init_audio_checkbox, + init_audio_input, + init_noise_level_slider, + ], + outputs=[ + audio_output, + audio_spectrogram_output + ], + api_name="generate") + +def create_sampling_ui(model_config, inpainting=False): + with gr.Row(): + with gr.Column(scale=6): + prompt = gr.Textbox(show_label=False, placeholder="Prompt") + negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt") + generate_button = gr.Button("Generate", variant='primary', scale=1) + + model_conditioning_config = model_config["model"].get("conditioning", None) + + has_seconds_start = False + has_seconds_total = False + + if model_conditioning_config is not None: + for conditioning_config in model_conditioning_config["configs"]: + if conditioning_config["id"] == "seconds_start": + has_seconds_start = True + if conditioning_config["id"] == "seconds_total": + has_seconds_total = True + + with gr.Row(equal_height=False): + with gr.Column(): + with gr.Row(visible = has_seconds_start or has_seconds_total): + # Timing controls + seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start) + seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total) + + with gr.Row(): + # Steps slider + steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") + + # Preview Every slider + preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every") + + # CFG scale + cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG scale") + + with gr.Accordion("Sampler params", open=False): + + # Seed + seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1") + + # Sampler params + with gr.Row(): + sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde") + sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") + sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max") + cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG rescale amount") + + if inpainting: + # Inpainting Tab + with gr.Accordion("Inpainting", open=False): + sigma_max_slider.maximum=1000 + + init_audio_checkbox = gr.Checkbox(label="Do inpainting") + init_audio_input = gr.Audio(label="Init audio") + init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.1, value=80, label="Init audio noise level", visible=False) # hide this + + mask_cropfrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Crop From %") + mask_pastefrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Paste From %") + mask_pasteto_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Paste To %") + + mask_maskstart_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=50, label="Mask Start %") + mask_maskend_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Mask End %") + mask_softnessL_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Left Crossfade Length %") + mask_softnessR_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Right Crossfade Length %") + mask_marination_slider = gr.Slider(minimum=0.0, maximum=1, step=0.0001, value=0, label="Marination level", visible=False) # still working on the usefulness of this + + inputs = [prompt, + negative_prompt, + seconds_start_slider, + seconds_total_slider, + cfg_scale_slider, + steps_slider, + preview_every_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, + sigma_max_slider, + cfg_rescale_slider, + init_audio_checkbox, + init_audio_input, + init_noise_level_slider, + mask_cropfrom_slider, + mask_pastefrom_slider, + mask_pasteto_slider, + mask_maskstart_slider, + mask_maskend_slider, + mask_softnessL_slider, + mask_softnessR_slider, + mask_marination_slider + ] + else: + # Default generation tab + with gr.Accordion("Init audio", open=False): + init_audio_checkbox = gr.Checkbox(label="Use init audio") + init_audio_input = gr.Audio(label="Init audio") + init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init noise level") + + inputs = [prompt, + negative_prompt, + seconds_start_slider, + seconds_total_slider, + cfg_scale_slider, + steps_slider, + preview_every_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, + sigma_max_slider, + cfg_rescale_slider, + init_audio_checkbox, + init_audio_input, + init_noise_level_slider + ] + + with gr.Column(): + audio_output = gr.Audio(label="Output audio", interactive=False) + audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) + send_to_init_button = gr.Button("Send to init audio", scale=1) + send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input]) + + generate_button.click(fn=generate_cond, + inputs=inputs, + outputs=[ + audio_output, + audio_spectrogram_output + ], + api_name="generate") + + +def create_txt2audio_ui(model_config): + with gr.Blocks() as ui: + with gr.Tab("Generation"): + create_sampling_ui(model_config) + with gr.Tab("Inpainting"): + create_sampling_ui(model_config, inpainting=True) + return ui + +def create_diffusion_uncond_ui(model_config): + with gr.Blocks() as ui: + create_uncond_sampling_ui(model_config) + + return ui + +def autoencoder_process(audio, latent_noise, n_quantizers): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + in_sr, audio = audio + + audio = torch.from_numpy(audio).float().div(32767).to(device) + + if audio.dim() == 1: + audio = audio.unsqueeze(0) + else: + audio = audio.transpose(0, 1) + + audio = model.preprocess_audio_for_encoder(audio, in_sr) + # Note: If you need to do chunked encoding, to reduce VRAM, + # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128 + # To turn it off, do chunked=False + # Optimal overlap and chunk_size values will depend on the model. + # See encode_audio & decode_audio in autoencoders.py for more info + # Get dtype of model + dtype = next(model.parameters()).dtype + + audio = audio.to(dtype) + + if n_quantizers > 0: + latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers) + else: + latents = model.encode_audio(audio, chunked=False) + + if latent_noise > 0: + latents = latents + torch.randn_like(latents) * latent_noise + + audio = model.decode_audio(latents, chunked=False) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + return "output.wav" + +def create_autoencoder_ui(model_config): + + is_dac_rvq = "model" in model_config and "bottleneck" in model_config["model"] and model_config["model"]["bottleneck"]["type"] in ["dac_rvq","dac_rvq_vae"] + + if is_dac_rvq: + n_quantizers = model_config["model"]["bottleneck"]["config"]["n_codebooks"] + else: + n_quantizers = 0 + + with gr.Blocks() as ui: + input_audio = gr.Audio(label="Input audio") + output_audio = gr.Audio(label="Output audio", interactive=False) + n_quantizers_slider = gr.Slider(minimum=1, maximum=n_quantizers, step=1, value=n_quantizers, label="# quantizers", visible=is_dac_rvq) + latent_noise_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.001, value=0.0, label="Add latent noise") + process_button = gr.Button("Process", variant='primary', scale=1) + process_button.click(fn=autoencoder_process, inputs=[input_audio, latent_noise_slider, n_quantizers_slider], outputs=output_audio, api_name="process") + + return ui + +def diffusion_prior_process(audio, steps, sampler_type, sigma_min, sigma_max): + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + in_sr, audio = audio + + audio = torch.from_numpy(audio).float().div(32767).to(device) + + if audio.dim() == 1: + audio = audio.unsqueeze(0) # [1, n] + elif audio.dim() == 2: + audio = audio.transpose(0, 1) # [n, 2] -> [2, n] + + audio = audio.unsqueeze(0) + + audio = model.stereoize(audio, in_sr, steps, sampler_kwargs={"sampler_type": sampler_type, "sigma_min": sigma_min, "sigma_max": sigma_max}) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + return "output.wav" + +def create_diffusion_prior_ui(model_config): + with gr.Blocks() as ui: + input_audio = gr.Audio(label="Input audio") + output_audio = gr.Audio(label="Output audio", interactive=False) + # Sampler params + with gr.Row(): + steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") + sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde") + sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") + sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max") + process_button = gr.Button("Process", variant='primary', scale=1) + process_button.click(fn=diffusion_prior_process, inputs=[input_audio, steps_slider, sampler_type_dropdown, sigma_min_slider, sigma_max_slider], outputs=output_audio, api_name="process") + + return ui + +def create_lm_ui(model_config): + with gr.Blocks() as ui: + output_audio = gr.Audio(label="Output audio", interactive=False) + audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) + + # Sampling params + with gr.Row(): + temperature_slider = gr.Slider(minimum=0, maximum=5, step=0.01, value=1.0, label="Temperature") + top_p_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top p") + top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Top k") + + generate_button = gr.Button("Generate", variant='primary', scale=1) + generate_button.click( + fn=generate_lm, + inputs=[ + temperature_slider, + top_p_slider, + top_k_slider + ], + outputs=[output_audio, audio_spectrogram_output], + api_name="generate" + ) + + return ui + +def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False): + + assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both" + + if model_config_path is not None: + # Load config from json file + with open(model_config_path) as f: + model_config = json.load(f) + else: + model_config = None + + try: + has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available() + except Exception: + # In case this version of Torch doesn't even have `torch.backends.mps`... + has_mps = False + + if has_mps: + device = torch.device("mps") + elif torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + + print("Using device:", device) + + _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device) + + model_type = model_config["model_type"] + + if model_type == "diffusion_cond": + ui = create_txt2audio_ui(model_config) + elif model_type == "diffusion_uncond": + ui = create_diffusion_uncond_ui(model_config) + elif model_type == "autoencoder" or model_type == "diffusion_autoencoder": + ui = create_autoencoder_ui(model_config) + elif model_type == "diffusion_prior": + ui = create_diffusion_prior_ui(model_config) + elif model_type == "lm": + ui = create_lm_ui(model_config) + + return ui diff --git a/ThinkSound/models/__init__.py b/ThinkSound/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e27bbcb19a00a93e05ed6cf2a3a38895f26975d --- /dev/null +++ b/ThinkSound/models/__init__.py @@ -0,0 +1 @@ +from .factory import create_model_from_config, create_model_from_config_path \ No newline at end of file diff --git a/ThinkSound/models/adp.py b/ThinkSound/models/adp.py new file mode 100644 index 0000000000000000000000000000000000000000..49eb526ab02d16eb4952d346401b1ad2b7e5cb7c --- /dev/null +++ b/ThinkSound/models/adp.py @@ -0,0 +1,1588 @@ +# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License +# License can be found in LICENSES/LICENSE_ADP.txt + +import math +from inspect import isfunction +from math import ceil, floor, log, pi, log2 +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from packaging import version + +import torch +import torch.nn as nn +from einops import rearrange, reduce, repeat +from einops.layers.torch import Rearrange +from einops_exts import rearrange_many +from torch import Tensor, einsum +from torch.backends.cuda import sdp_kernel +from torch.nn import functional as F +from dac.nn.layers import Snake1d + +""" +Utils +""" + + +class ConditionedSequential(nn.Module): + def __init__(self, *modules): + super().__init__() + self.module_list = nn.ModuleList(*modules) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None): + for module in self.module_list: + x = module(x, mapping) + return x + +T = TypeVar("T") + +def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: + if exists(val): + return val + return d() if isfunction(d) else d + +def exists(val: Optional[T]) -> T: + return val is not None + +def closest_power_2(x: float) -> int: + exponent = log2(x) + distance_fn = lambda z: abs(x - 2 ** z) # noqa + exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) + return 2 ** int(exponent_closest) + +def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: + return_dicts: Tuple[Dict, Dict] = ({}, {}) + for key in d.keys(): + no_prefix = int(not key.startswith(prefix)) + return_dicts[no_prefix][key] = d[key] + return return_dicts + +def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: + kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) + if keep_prefix: + return kwargs_with_prefix, kwargs + kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} + return kwargs_no_prefix, kwargs + +""" +Convolutional Blocks +""" +import typing as tp + +# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License +# License available in LICENSES/LICENSE_META.txt + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class Conv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, causal=False) -> Tensor: + kernel_size = self.kernel_size[0] + stride = self.stride[0] + dilation = self.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding)) + return super().forward(x) + +class ConvTranspose1d(nn.ConvTranspose1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, causal=False) -> Tensor: + kernel_size = self.kernel_size[0] + stride = self.stride[0] + padding_total = kernel_size - stride + + y = super().forward(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. + if causal: + padding_right = ceil(padding_total) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y + + +def Downsample1d( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor + ) + + +def Upsample1d( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3 + ), + ) + else: + return ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor + ) + + +class ConvBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + num_groups: int = 8, + use_norm: bool = True, + use_snake: bool = False + ) -> None: + super().__init__() + + self.groupnorm = ( + nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) + if use_norm + else nn.Identity() + ) + + if use_snake: + self.activation = Snake1d(in_channels) + else: + self.activation = nn.SiLU() + + self.project = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + def forward( + self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False + ) -> Tensor: + x = self.groupnorm(x) + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + x = self.activation(x) + return self.project(x, causal=causal) + + +class MappingToScaleShift(nn.Module): + def __init__( + self, + features: int, + channels: int, + ): + super().__init__() + + self.to_scale_shift = nn.Sequential( + nn.SiLU(), + nn.Linear(in_features=features, out_features=channels * 2), + ) + + def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]: + scale_shift = self.to_scale_shift(mapping) + scale_shift = rearrange(scale_shift, "b c -> b c 1") + scale, shift = scale_shift.chunk(2, dim=1) + return scale, shift + + +class ResnetBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + use_norm: bool = True, + use_snake: bool = False, + num_groups: int = 8, + context_mapping_features: Optional[int] = None, + ) -> None: + super().__init__() + + self.use_mapping = exists(context_mapping_features) + + self.block1 = ConvBlock1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + use_norm=use_norm, + num_groups=num_groups, + use_snake=use_snake + ) + + if self.use_mapping: + assert exists(context_mapping_features) + self.to_scale_shift = MappingToScaleShift( + features=context_mapping_features, channels=out_channels + ) + + self.block2 = ConvBlock1d( + in_channels=out_channels, + out_channels=out_channels, + use_norm=use_norm, + num_groups=num_groups, + use_snake=use_snake + ) + + self.to_out = ( + Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + assert_message = "context mapping required if context_mapping_features > 0" + assert not (self.use_mapping ^ exists(mapping)), assert_message + + h = self.block1(x, causal=causal) + + scale_shift = None + if self.use_mapping: + scale_shift = self.to_scale_shift(mapping) + + h = self.block2(h, scale_shift=scale_shift, causal=causal) + + return h + self.to_out(x) + + +class Patcher(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int, + context_mapping_features: Optional[int] = None, + use_snake: bool = False, + ): + super().__init__() + assert_message = f"out_channels must be divisible by patch_size ({patch_size})" + assert out_channels % patch_size == 0, assert_message + self.patch_size = patch_size + + self.block = ResnetBlock1d( + in_channels=in_channels, + out_channels=out_channels // patch_size, + num_groups=1, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + x = self.block(x, mapping, causal=causal) + x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) + return x + + +class Unpatcher(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int, + context_mapping_features: Optional[int] = None, + use_snake: bool = False + ): + super().__init__() + assert_message = f"in_channels must be divisible by patch_size ({patch_size})" + assert in_channels % patch_size == 0, assert_message + self.patch_size = patch_size + + self.block = ResnetBlock1d( + in_channels=in_channels // patch_size, + out_channels=out_channels, + num_groups=1, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) + x = self.block(x, mapping, causal=causal) + return x + + +""" +Attention Components +""" +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + nn.Linear(in_features=mid_features, out_features=features), + ) + +def add_mask(sim: Tensor, mask: Tensor) -> Tensor: + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + +def causal_mask(q: Tensor, k: Tensor) -> Tensor: + b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device + mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) + mask = repeat(mask, "n m -> b n m", b=b) + return mask + +class AttentionBase(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + ): + super().__init__() + self.scale = head_features**-0.5 + self.num_heads = num_heads + mid_features = head_features * num_heads + out_features = default(out_features, features) + + self.to_out = nn.Linear( + in_features=mid_features, out_features=out_features + ) + + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False + ) -> Tensor: + # Split heads + q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) + + if not self.use_flash: + if is_causal and not mask: + # Mask out future tokens for causal attention + mask = causal_mask(q, k) + + # Compute similarity matrix and add eventual mask + sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale + sim = add_mask(sim, mask) if exists(mask) else sim + + # Get attention matrix with softmax + attn = sim.softmax(dim=-1, dtype=torch.float32) + + # Compute values + out = einsum("... n m, ... m d -> ... n d", attn, v) + else: + with sdp_kernel(*self.sdp_kernel_config): + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + +class Attention(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + context_features: Optional[int] = None, + causal: bool = False, + ): + super().__init__() + self.context_features = context_features + self.causal = causal + mid_features = head_features * num_heads + context_features = default(context_features, features) + + self.norm = nn.LayerNorm(features) + self.norm_context = nn.LayerNorm(context_features) + self.to_q = nn.Linear( + in_features=features, out_features=mid_features, bias=False + ) + self.to_kv = nn.Linear( + in_features=context_features, out_features=mid_features * 2, bias=False + ) + self.attention = AttentionBase( + features, + num_heads=num_heads, + head_features=head_features, + out_features=out_features, + ) + + def forward( + self, + x: Tensor, # [b, n, c] + context: Optional[Tensor] = None, # [b, m, d] + context_mask: Optional[Tensor] = None, # [b, m], false is masked, + causal: Optional[bool] = False, + ) -> Tensor: + assert_message = "You must provide a context when using context_features" + assert not self.context_features or exists(context), assert_message + # Use context if provided + context = default(context, x) + # Normalize then compute q from input and k,v from context + x, context = self.norm(x), self.norm_context(context) + + q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) + + if exists(context_mask): + # Mask out cross-attention for padding tokens + mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1]) + k, v = k * mask, v * mask + + # Compute and return attention + return self.attention(q, k, v, is_causal=self.causal or causal) + + +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + nn.Linear(in_features=mid_features, out_features=features), + ) + +""" +Transformer Blocks +""" + + +class TransformerBlock(nn.Module): + def __init__( + self, + features: int, + num_heads: int, + head_features: int, + multiplier: int, + context_features: Optional[int] = None, + ): + super().__init__() + + self.use_cross_attention = exists(context_features) and context_features > 0 + + self.attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features + ) + + if self.use_cross_attention: + self.cross_attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features, + context_features=context_features + ) + + self.feed_forward = FeedForward(features=features, multiplier=multiplier) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor: + x = self.attention(x, causal=causal) + x + if self.use_cross_attention: + x = self.cross_attention(x, context=context, context_mask=context_mask) + x + x = self.feed_forward(x) + x + return x + + +""" +Transformers +""" + + +class Transformer1d(nn.Module): + def __init__( + self, + num_layers: int, + channels: int, + num_heads: int, + head_features: int, + multiplier: int, + context_features: Optional[int] = None, + ): + super().__init__() + + self.to_in = nn.Sequential( + nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True), + Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=1, + ), + Rearrange("b c t -> b t c"), + ) + + self.blocks = nn.ModuleList( + [ + TransformerBlock( + features=channels, + head_features=head_features, + num_heads=num_heads, + multiplier=multiplier, + context_features=context_features, + ) + for i in range(num_layers) + ] + ) + + self.to_out = nn.Sequential( + Rearrange("b t c -> b c t"), + Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=1, + ), + ) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor: + x = self.to_in(x) + for block in self.blocks: + x = block(x, context=context, context_mask=context_mask, causal=causal) + x = self.to_out(x) + return x + + +""" +Time Embeddings +""" + + +class SinusoidalEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device, half_dim = x.device, self.dim // 2 + emb = torch.tensor(log(10000) / (half_dim - 1), device=device) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") + return torch.cat((emb.sin(), emb.cos()), dim=-1) + + +class LearnedPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x: Tensor) -> Tensor: + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: + return nn.Sequential( + LearnedPositionalEmbedding(dim), + nn.Linear(in_features=dim + 1, out_features=out_features), + ) + + +""" +Encoder/Decoder Components +""" + + +class DownsampleBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + factor: int, + num_groups: int, + num_layers: int, + kernel_multiplier: int = 2, + use_pre_downsample: bool = True, + use_skip: bool = False, + use_snake: bool = False, + extract_channels: int = 0, + context_channels: int = 0, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + ): + super().__init__() + self.use_pre_downsample = use_pre_downsample + self.use_skip = use_skip + self.use_transformer = num_transformer_blocks > 0 + self.use_extract = extract_channels > 0 + self.use_context = context_channels > 0 + + channels = out_channels if use_pre_downsample else in_channels + + self.downsample = Downsample1d( + in_channels=in_channels, + out_channels=out_channels, + factor=factor, + kernel_multiplier=kernel_multiplier, + ) + + self.blocks = nn.ModuleList( + [ + ResnetBlock1d( + in_channels=channels + context_channels if i == 0 else channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + for i in range(num_layers) + ] + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features + ) + + if self.use_extract: + num_extract_groups = min(num_groups, extract_channels) + self.to_extracted = ResnetBlock1d( + in_channels=out_channels, + out_channels=extract_channels, + num_groups=num_extract_groups, + use_snake=use_snake + ) + + def forward( + self, + x: Tensor, + *, + mapping: Optional[Tensor] = None, + channels: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: + + if self.use_pre_downsample: + x = self.downsample(x) + + if self.use_context and exists(channels): + x = torch.cat([x, channels], dim=1) + + skips = [] + for block in self.blocks: + x = block(x, mapping=mapping, causal=causal) + skips += [x] if self.use_skip else [] + + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + skips += [x] if self.use_skip else [] + + if not self.use_pre_downsample: + x = self.downsample(x) + + if self.use_extract: + extracted = self.to_extracted(x) + return x, extracted + + return (x, skips) if self.use_skip else x + + +class UpsampleBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + factor: int, + num_layers: int, + num_groups: int, + use_nearest: bool = False, + use_pre_upsample: bool = False, + use_skip: bool = False, + use_snake: bool = False, + skip_channels: int = 0, + use_skip_scale: bool = False, + extract_channels: int = 0, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + ): + super().__init__() + + self.use_extract = extract_channels > 0 + self.use_pre_upsample = use_pre_upsample + self.use_transformer = num_transformer_blocks > 0 + self.use_skip = use_skip + self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 + + channels = out_channels if use_pre_upsample else in_channels + + self.blocks = nn.ModuleList( + [ + ResnetBlock1d( + in_channels=channels + skip_channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + for _ in range(num_layers) + ] + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features, + ) + + self.upsample = Upsample1d( + in_channels=in_channels, + out_channels=out_channels, + factor=factor, + use_nearest=use_nearest, + ) + + if self.use_extract: + num_extract_groups = min(num_groups, extract_channels) + self.to_extracted = ResnetBlock1d( + in_channels=out_channels, + out_channels=extract_channels, + num_groups=num_extract_groups, + use_snake=use_snake + ) + + def add_skip(self, x: Tensor, skip: Tensor) -> Tensor: + return torch.cat([x, skip * self.skip_scale], dim=1) + + def forward( + self, + x: Tensor, + *, + skips: Optional[List[Tensor]] = None, + mapping: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Union[Tuple[Tensor, Tensor], Tensor]: + + if self.use_pre_upsample: + x = self.upsample(x) + + for block in self.blocks: + x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x + x = block(x, mapping=mapping, causal=causal) + + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + + if not self.use_pre_upsample: + x = self.upsample(x) + + if self.use_extract: + extracted = self.to_extracted(x) + return x, extracted + + return x + + +class BottleneckBlock1d(nn.Module): + def __init__( + self, + channels: int, + *, + num_groups: int, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + use_snake: bool = False, + ): + super().__init__() + self.use_transformer = num_transformer_blocks > 0 + + self.pre_block = ResnetBlock1d( + in_channels=channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features, + ) + + self.post_block = ResnetBlock1d( + in_channels=channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward( + self, + x: Tensor, + *, + mapping: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Tensor: + x = self.pre_block(x, mapping=mapping, causal=causal) + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + x = self.post_block(x, mapping=mapping, causal=causal) + return x + + +""" +UNet +""" + + +class UNet1d(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + multipliers: Sequence[int], + factors: Sequence[int], + num_blocks: Sequence[int], + attentions: Sequence[int], + patch_size: int = 1, + resnet_groups: int = 8, + use_context_time: bool = True, + kernel_multiplier_downsample: int = 2, + use_nearest_upsample: bool = False, + use_skip_scale: bool = True, + use_snake: bool = False, + use_stft: bool = False, + use_stft_context: bool = False, + out_channels: Optional[int] = None, + context_features: Optional[int] = None, + context_features_multiplier: int = 4, + context_channels: Optional[Sequence[int]] = None, + context_embedding_features: Optional[int] = None, + **kwargs, + ): + super().__init__() + out_channels = default(out_channels, in_channels) + context_channels = list(default(context_channels, [])) + num_layers = len(multipliers) - 1 + use_context_features = exists(context_features) + use_context_channels = len(context_channels) > 0 + context_mapping_features = None + + attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True) + + self.num_layers = num_layers + self.use_context_time = use_context_time + self.use_context_features = use_context_features + self.use_context_channels = use_context_channels + self.use_stft = use_stft + self.use_stft_context = use_stft_context + + self.context_features = context_features + context_channels_pad_length = num_layers + 1 - len(context_channels) + context_channels = context_channels + [0] * context_channels_pad_length + self.context_channels = context_channels + self.context_embedding_features = context_embedding_features + + if use_context_channels: + has_context = [c > 0 for c in context_channels] + self.has_context = has_context + self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] + + assert ( + len(factors) == num_layers + and len(attentions) >= num_layers + and len(num_blocks) == num_layers + ) + + if use_context_time or use_context_features: + context_mapping_features = channels * context_features_multiplier + + self.to_mapping = nn.Sequential( + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + ) + + if use_context_time: + assert exists(context_mapping_features) + self.to_time = nn.Sequential( + TimePositionalEmbedding( + dim=channels, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_context_features: + assert exists(context_features) and exists(context_mapping_features) + self.to_features = nn.Sequential( + nn.Linear( + in_features=context_features, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_stft: + stft_kwargs, kwargs = groupby("stft_", kwargs) + assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True" + stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 + in_channels *= stft_channels + out_channels *= stft_channels + context_channels[0] *= stft_channels if use_stft_context else 1 + assert exists(in_channels) and exists(out_channels) + self.stft = STFT(**stft_kwargs) + + assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" + + self.to_in = Patcher( + in_channels=in_channels + context_channels[0], + out_channels=channels * multipliers[0], + patch_size=patch_size, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + self.downsamples = nn.ModuleList( + [ + DownsampleBlock1d( + in_channels=channels * multipliers[i], + out_channels=channels * multipliers[i + 1], + context_mapping_features=context_mapping_features, + context_channels=context_channels[i + 1], + context_embedding_features=context_embedding_features, + num_layers=num_blocks[i], + factor=factors[i], + kernel_multiplier=kernel_multiplier_downsample, + num_groups=resnet_groups, + use_pre_downsample=True, + use_skip=True, + use_snake=use_snake, + num_transformer_blocks=attentions[i], + **attention_kwargs, + ) + for i in range(num_layers) + ] + ) + + self.bottleneck = BottleneckBlock1d( + channels=channels * multipliers[-1], + context_mapping_features=context_mapping_features, + context_embedding_features=context_embedding_features, + num_groups=resnet_groups, + num_transformer_blocks=attentions[-1], + use_snake=use_snake, + **attention_kwargs, + ) + + self.upsamples = nn.ModuleList( + [ + UpsampleBlock1d( + in_channels=channels * multipliers[i + 1], + out_channels=channels * multipliers[i], + context_mapping_features=context_mapping_features, + context_embedding_features=context_embedding_features, + num_layers=num_blocks[i] + (1 if attentions[i] else 0), + factor=factors[i], + use_nearest=use_nearest_upsample, + num_groups=resnet_groups, + use_skip_scale=use_skip_scale, + use_pre_upsample=False, + use_skip=True, + use_snake=use_snake, + skip_channels=channels * multipliers[i + 1], + num_transformer_blocks=attentions[i], + **attention_kwargs, + ) + for i in reversed(range(num_layers)) + ] + ) + + self.to_out = Unpatcher( + in_channels=channels * multipliers[0], + out_channels=out_channels, + patch_size=patch_size, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def get_channels( + self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 + ) -> Optional[Tensor]: + """Gets context channels at `layer` and checks that shape is correct""" + use_context_channels = self.use_context_channels and self.has_context[layer] + if not use_context_channels: + return None + assert exists(channels_list), "Missing context" + # Get channels index (skipping zero channel contexts) + channels_id = self.channels_ids[layer] + # Get channels + channels = channels_list[channels_id] + message = f"Missing context for layer {layer} at index {channels_id}" + assert exists(channels), message + # Check channels + num_channels = self.context_channels[layer] + message = f"Expected context with {num_channels} channels at idx {channels_id}" + assert channels.shape[1] == num_channels, message + # STFT channels if requested + channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa + return channels + + def get_mapping( + self, time: Optional[Tensor] = None, features: Optional[Tensor] = None + ) -> Optional[Tensor]: + """Combines context time features and features into mapping""" + items, mapping = [], None + # Compute time features + if self.use_context_time: + assert_message = "use_context_time=True but no time features provided" + assert exists(time), assert_message + items += [self.to_time(time)] + # Compute features + if self.use_context_features: + assert_message = "context_features exists but no features provided" + assert exists(features), assert_message + items += [self.to_features(features)] + # Compute joint mapping + if self.use_context_time or self.use_context_features: + mapping = reduce(torch.stack(items), "n b m -> b m", "sum") + mapping = self.to_mapping(mapping) + return mapping + + def forward( + self, + x: Tensor, + time: Optional[Tensor] = None, + *, + features: Optional[Tensor] = None, + channels_list: Optional[Sequence[Tensor]] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False, + ) -> Tensor: + channels = self.get_channels(channels_list, layer=0) + # Apply stft if required + x = self.stft.encode1d(x) if self.use_stft else x # type: ignore + # Concat context channels at layer 0 if provided + x = torch.cat([x, channels], dim=1) if exists(channels) else x + # Compute mapping from time and features + mapping = self.get_mapping(time, features) + x = self.to_in(x, mapping, causal=causal) + skips_list = [x] + + for i, downsample in enumerate(self.downsamples): + channels = self.get_channels(channels_list, layer=i + 1) + x, skips = downsample( + x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal + ) + skips_list += [skips] + + x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + + for i, upsample in enumerate(self.upsamples): + skips = skips_list.pop() + x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + + x += skips_list.pop() + x = self.to_out(x, mapping, causal=causal) + x = self.stft.decode1d(x) if self.use_stft else x + + return x + + +""" Conditioning Modules """ + + +class FixedEmbedding(nn.Module): + def __init__(self, max_length: int, features: int): + super().__init__() + self.max_length = max_length + self.embedding = nn.Embedding(max_length, features) + + def forward(self, x: Tensor) -> Tensor: + batch_size, length, device = *x.shape[0:2], x.device + assert_message = "Input sequence length must be <= max_length" + assert length <= self.max_length, assert_message + position = torch.arange(length, device=device) + fixed_embedding = self.embedding(position) + fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) + return fixed_embedding + + +def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: + if proba == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif proba == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) + + +class UNetCFG1d(UNet1d): + + """UNet1d with Classifier-Free Guidance""" + + def __init__( + self, + context_embedding_max_length: int, + context_embedding_features: int, + use_xattn_time: bool = False, + **kwargs, + ): + super().__init__( + context_embedding_features=context_embedding_features, **kwargs + ) + + self.use_xattn_time = use_xattn_time + + if use_xattn_time: + assert exists(context_embedding_features) + self.to_time_embedding = nn.Sequential( + TimePositionalEmbedding( + dim=kwargs["channels"], out_features=context_embedding_features + ), + nn.GELU(), + ) + + context_embedding_max_length += 1 # Add one for time embedding + + self.fixed_embedding = FixedEmbedding( + max_length=context_embedding_max_length, features=context_embedding_features + ) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + embedding: Tensor, + embedding_mask: Optional[Tensor] = None, + embedding_scale: float = 1.0, + embedding_mask_proba: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + scale_phi: float = 0.4, + negative_embedding: Optional[Tensor] = None, + negative_embedding_mask: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + b, device = embedding.shape[0], embedding.device + + if self.use_xattn_time: + embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1) + + if embedding_mask is not None: + embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1) + + fixed_embedding = self.fixed_embedding(embedding) + + if embedding_mask_proba > 0.0: + # Randomly mask embedding + batch_mask = rand_bool( + shape=(b, 1, 1), proba=embedding_mask_proba, device=device + ) + embedding = torch.where(batch_mask, fixed_embedding, embedding) + + if embedding_scale != 1.0: + if batch_cfg: + batch_x = torch.cat([x, x], dim=0) + batch_time = torch.cat([time, time], dim=0) + + if negative_embedding is not None: + if negative_embedding_mask is not None: + negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2) + + negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding) + + batch_embed = torch.cat([embedding, negative_embedding], dim=0) + + else: + batch_embed = torch.cat([embedding, fixed_embedding], dim=0) + + batch_mask = None + if embedding_mask is not None: + batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0) + + batch_features = None + features = kwargs.pop("features", None) + if self.use_context_features: + batch_features = torch.cat([features, features], dim=0) + + batch_channels = None + channels_list = kwargs.pop("channels_list", None) + if self.use_context_channels: + batch_channels = [] + for channels in channels_list: + batch_channels += [torch.cat([channels, channels], dim=0)] + + # Compute both normal and fixed embedding outputs + batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs) + out, out_masked = batch_out.chunk(2, dim=0) + + else: + # Compute both normal and fixed embedding outputs + out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) + out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs) + + out_cfg = out_masked + (out - out_masked) * embedding_scale + + if rescale_cfg: + + out_std = out.std(dim=1, keepdim=True) + out_cfg_std = out_cfg.std(dim=1, keepdim=True) + + return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg + + else: + + return out_cfg + + else: + return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) + + +class UNetNCCA1d(UNet1d): + + """UNet1d with Noise Channel Conditioning Augmentation""" + + def __init__(self, context_features: int, **kwargs): + super().__init__(context_features=context_features, **kwargs) + self.embedder = NumberEmbedder(features=context_features) + + def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor: + x = x if torch.is_tensor(x) else torch.tensor(x) + return x.expand(shape) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + channels_list: Sequence[Tensor], + channels_augmentation: Union[ + bool, Sequence[bool], Sequence[Sequence[bool]], Tensor + ] = False, + channels_scale: Union[ + float, Sequence[float], Sequence[Sequence[float]], Tensor + ] = 0, + **kwargs, + ) -> Tensor: + b, n = x.shape[0], len(channels_list) + channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x) + channels_scale = self.expand(channels_scale, shape=(b, n)).to(x) + + # Augmentation (for each channel list item) + for i in range(n): + scale = channels_scale[:, i] * channels_augmentation[:, i] + scale = rearrange(scale, "b -> b 1 1") + item = channels_list[i] + channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa + + # Scale embedding (sum reduction if more than one channel list item) + channels_scale_emb = self.embedder(channels_scale) + channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum") + + return super().forward( + x=x, + time=time, + channels_list=channels_list, + features=channels_scale_emb, + **kwargs, + ) + + +class UNetAll1d(UNetCFG1d, UNetNCCA1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, *args, **kwargs): # type: ignore + return UNetCFG1d.forward(self, *args, **kwargs) + + +def XUNet1d(type: str = "base", **kwargs) -> UNet1d: + if type == "base": + return UNet1d(**kwargs) + elif type == "all": + return UNetAll1d(**kwargs) + elif type == "cfg": + return UNetCFG1d(**kwargs) + elif type == "ncca": + return UNetNCCA1d(**kwargs) + else: + raise ValueError(f"Unknown XUNet1d type: {type}") + +class NumberEmbedder(nn.Module): + def __init__( + self, + features: int, + dim: int = 256, + ): + super().__init__() + self.features = features + self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) + + def forward(self, x: Union[List[float], Tensor]) -> Tensor: + if not torch.is_tensor(x): + device = next(self.embedding.parameters()).device + x = torch.tensor(x, device=device) + assert isinstance(x, Tensor) + shape = x.shape + x = rearrange(x, "... -> (...)") + embedding = self.embedding(x) + x = embedding.view(*shape, self.features) + return x # type: ignore + + +""" +Audio Transforms +""" + + +class STFT(nn.Module): + """Helper for torch stft and istft""" + + def __init__( + self, + num_fft: int = 1023, + hop_length: int = 256, + window_length: Optional[int] = None, + length: Optional[int] = None, + use_complex: bool = False, + ): + super().__init__() + self.num_fft = num_fft + self.hop_length = default(hop_length, floor(num_fft // 4)) + self.window_length = default(window_length, num_fft) + self.length = length + self.register_buffer("window", torch.hann_window(self.window_length)) + self.use_complex = use_complex + + def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: + b = wave.shape[0] + wave = rearrange(wave, "b c t -> (b c) t") + + stft = torch.stft( + wave, + n_fft=self.num_fft, + hop_length=self.hop_length, + win_length=self.window_length, + window=self.window, # type: ignore + return_complex=True, + normalized=True, + ) + + if self.use_complex: + # Returns real and imaginary + stft_a, stft_b = stft.real, stft.imag + else: + # Returns magnitude and phase matrices + magnitude, phase = torch.abs(stft), torch.angle(stft) + stft_a, stft_b = magnitude, phase + + return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) + + def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: + b, l = stft_a.shape[0], stft_a.shape[-1] # noqa + length = closest_power_2(l * self.hop_length) + + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") + + if self.use_complex: + real, imag = stft_a, stft_b + else: + magnitude, phase = stft_a, stft_b + real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) + + stft = torch.stack([real, imag], dim=-1) + + wave = torch.istft( + stft, + n_fft=self.num_fft, + hop_length=self.hop_length, + win_length=self.window_length, + window=self.window, # type: ignore + length=default(self.length, length), + normalized=True, + ) + + return rearrange(wave, "(b c) t -> b c t", b=b) + + def encode1d( + self, wave: Tensor, stacked: bool = True + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + stft_a, stft_b = self.encode(wave) + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") + return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) + + def decode1d(self, stft_pair: Tensor) -> Tensor: + f = self.num_fft // 2 + 1 + stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) + return self.decode(stft_a, stft_b) diff --git a/ThinkSound/models/autoencoders.py b/ThinkSound/models/autoencoders.py new file mode 100644 index 0000000000000000000000000000000000000000..853fdc9712c12284b2073c1cce8c384baa258337 --- /dev/null +++ b/ThinkSound/models/autoencoders.py @@ -0,0 +1,800 @@ +import torch +import math +import numpy as np + +from torch import nn +from torch.nn import functional as F +from torchaudio import transforms as T +from alias_free_torch import Activation1d +from dac.nn.layers import WNConv1d, WNConvTranspose1d +from typing import Literal, Dict, Any + +from ..inference.sampling import sample +from ..inference.utils import prepare_audio +from .blocks import SnakeBeta +from .bottleneck import Bottleneck, DiscreteBottleneck +from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper +from .factory import create_pretransform_from_config, create_bottleneck_from_config +from .pretransforms import Pretransform + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: + if activation == "elu": + act = nn.ELU() + elif activation == "snake": + act = SnakeBeta(channels) + elif activation == "none": + act = nn.Identity() + else: + raise ValueError(f"Unknown activation {activation}") + + if antialias: + act = Activation1d(act) + + return act + +class ResidualUnit(nn.Module): + def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False): + super().__init__() + + self.dilation = dilation + + padding = (dilation * (7-1)) // 2 + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=7, dilation=dilation, padding=padding), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=out_channels, out_channels=out_channels, + kernel_size=1) + ) + + def forward(self, x): + res = x + + #x = checkpoint(self.layers, x) + x = self.layers(x) + + return x + res + +class EncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False): + super().__init__() + + self.layers = nn.Sequential( + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=9, use_snake=use_snake), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)), + ) + + def forward(self, x): + return self.layers(x) + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False): + super().__init__() + + if use_nearest_upsample: + upsample_layer = nn.Sequential( + nn.Upsample(scale_factor=stride, mode="nearest"), + WNConv1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, + stride=1, + bias=False, + padding='same') + ) + else: + upsample_layer = WNConvTranspose1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + upsample_layer, + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=9, use_snake=use_snake), + ) + + def forward(self, x): + return self.layers(x) + +class OobleckEncoder(nn.Module): + def __init__(self, + in_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False + ): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3) + ] + + for i in range(self.depth-1): + layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels), + WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1) + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class OobleckDecoder(nn.Module): + def __init__(self, + out_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False, + use_nearest_upsample=False, + final_tanh=True): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3), + ] + + for i in range(self.depth-1, 0, -1): + layers += [DecoderBlock( + in_channels=c_mults[i]*channels, + out_channels=c_mults[i-1]*channels, + stride=strides[i-1], + use_snake=use_snake, + antialias_activation=antialias_activation, + use_nearest_upsample=use_nearest_upsample + ) + ] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels), + WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False), + nn.Tanh() if final_tanh else nn.Identity() + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class DACEncoderWrapper(nn.Module): + def __init__(self, in_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Encoder as DACEncoder + + latent_dim = kwargs.pop("latent_dim", None) + + encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"])) + self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs) + self.latent_dim = latent_dim + + # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility + self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity() + + if in_channels != 1: + self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3) + + def forward(self, x): + x = self.encoder(x) + x = self.proj_out(x) + return x + +class DACDecoderWrapper(nn.Module): + def __init__(self, latent_dim, out_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Decoder as DACDecoder + + self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels) + + self.latent_dim = latent_dim + + def forward(self, x): + return self.decoder(x) + +class AudioAutoencoder(nn.Module): + def __init__( + self, + encoder, + decoder, + latent_dim, + downsampling_ratio, + sample_rate, + io_channels=2, + bottleneck: Bottleneck = None, + pretransform: Pretransform = None, + in_channels = None, + out_channels = None, + soft_clip = False + ): + super().__init__() + + self.downsampling_ratio = downsampling_ratio + self.sample_rate = sample_rate + + self.latent_dim = latent_dim + self.io_channels = io_channels + self.in_channels = io_channels + self.out_channels = io_channels + + self.min_length = self.downsampling_ratio + + if in_channels is not None: + self.in_channels = in_channels + + if out_channels is not None: + self.out_channels = out_channels + + self.bottleneck = bottleneck + + self.encoder = encoder + + self.decoder = decoder + + self.pretransform = pretransform + + self.soft_clip = soft_clip + + self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete + + def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs): + + info = {} + # import ipdb + # ipdb.set_trace() + if self.pretransform is not None and not skip_pretransform: + if self.pretransform.enable_grad: + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + else: + with torch.no_grad(): + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + + if self.encoder is not None: + if iterate_batch: + latents = [] + for i in range(audio.shape[0]): + latents.append(self.encoder(audio[i:i+1])) + latents = torch.cat(latents, dim=0) + else: + latents = self.encoder(audio) + else: + latents = audio + + if self.bottleneck is not None: + # TODO: Add iterate batch logic, needs to merge the info dicts + latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs) + + info.update(bottleneck_info) + + if return_info: + return latents, info + + return latents + + def decode(self, latents, iterate_batch=False, **kwargs): + + if self.bottleneck is not None: + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.bottleneck.decode(latents[i:i+1])) + latents = torch.cat(decoded, dim=0) + else: + latents = self.bottleneck.decode(latents) + + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.decoder(latents[i:i+1])) + decoded = torch.cat(decoded, dim=0) + else: + decoded = self.decoder(latents, **kwargs) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + if iterate_batch: + decodeds = [] + for i in range(decoded.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + if iterate_batch: + decodeds = [] + for i in range(latents.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + + if self.soft_clip: + decoded = torch.tanh(decoded) + + return decoded + + def decode_tokens(self, tokens, **kwargs): + ''' + Decode discrete tokens to audio + Only works with discrete autoencoders + ''' + + assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders" + + latents = self.bottleneck.decode_tokens(tokens, **kwargs) + + return self.decode(latents, **kwargs) + + + def preprocess_audio_for_encoder(self, audio, in_sr): + ''' + Preprocess single audio tensor (Channels x Length) to be compatible with the encoder. + If the model is mono, stereo audio will be converted to mono. + Audio will be silence-padded to be a multiple of the model's downsampling ratio. + Audio will be resampled to the model's sample rate. + The output will have batch size 1 and be shape (1 x Channels x Length) + ''' + return self.preprocess_audio_list_for_encoder([audio], [in_sr]) + + def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list): + ''' + Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder. + The audio in that list can be of different lengths and channels. + in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio. + All audio will be resampled to the model's sample rate. + Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio. + If the model is mono, all audio will be converted to mono. + The output will be a tensor of shape (Batch x Channels x Length) + ''' + batch_size = len(audio_list) + if isinstance(in_sr_list, int): + in_sr_list = [in_sr_list]*batch_size + assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list" + new_audio = [] + max_length = 0 + # resample & find the max length + for i in range(batch_size): + audio = audio_list[i] + in_sr = in_sr_list[i] + if len(audio.shape) == 3 and audio.shape[0] == 1: + # batchsize 1 was given by accident. Just squeeze it. + audio = audio.squeeze(0) + elif len(audio.shape) == 1: + # Mono signal, channel dimension is missing, unsqueeze it in + audio = audio.unsqueeze(0) + assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension" + # Resample audio + if in_sr != self.sample_rate: + resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device) + audio = resample_tf(audio) + new_audio.append(audio) + if audio.shape[-1] > max_length: + max_length = audio.shape[-1] + # Pad every audio to the same length, multiple of model's downsampling ratio + padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length + for i in range(batch_size): + # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model + new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length, + target_channels=self.in_channels, device=new_audio[i].device).squeeze(0) + # convert to tensor + return torch.stack(new_audio) + + def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder. + If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap. + Overlap and chunk_size params are both measured in number of latents (not audio samples) + # and therefore you likely could use the same values with decode_audio. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. Encode the entire audio in parallel + return self.encode(audio, **kwargs) + else: + # CHUNKED ENCODING + # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio) + # import ipdb + # ipdb.set_trace() + samples_per_latent = self.downsampling_ratio + total_size = audio.shape[2] # in samples + print(f'audio shape: {audio.shape}') + batch_size = audio.shape[0] + chunk_size *= samples_per_latent # converting metric in latents to samples + overlap *= samples_per_latent # converting metric in latents to samples + hop_size = chunk_size - overlap + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = audio[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = audio[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # Note: y_size might be a different value from the latent length used in diffusion training + # because we can encode audio of varying lengths + # However, the audio should've been padded to a multiple of samples_per_latent by now. + y_size = total_size // samples_per_latent + # Create an empty latent, we will populate it with chunks as we encode them + y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device) + print(f'y_final shape: {y_final.shape}') + for i in range(num_chunks): + x_chunk = chunks[i,:] + # encode the chunk + y_chunk = self.encode(x_chunk) + print(f'y_chunk shape: {y_chunk.shape}') + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size // samples_per_latent + t_end = t_start + chunk_size // samples_per_latent + # remove the edges of the overlaps + ol = overlap//samples_per_latent//2 + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Decode latents to audio. + If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. Decode the entire latent in parallel + return self.decode(latents, **kwargs) + else: + # chunked decoding + hop_size = chunk_size - overlap + total_size = latents.shape[2] + batch_size = latents.shape[0] + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = latents[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = latents[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # samples_per_latent is just the downsampling ratio + samples_per_latent = self.downsampling_ratio + # Create an empty waveform, we will populate it with chunks as decode them + y_size = total_size * samples_per_latent + y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # decode the chunk + y_chunk = self.decode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size * samples_per_latent + t_end = t_start + chunk_size * samples_per_latent + # remove the edges of the overlaps + ol = (overlap//2) * samples_per_latent + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + +class DiffusionAutoencoder(AudioAutoencoder): + def __init__( + self, + diffusion: ConditionedDiffusionModel, + diffusion_downsampling_ratio, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + + self.diffusion = diffusion + + self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio + + if self.encoder is not None: + # Shrink the initial encoder parameters to avoid saturated latents + with torch.no_grad(): + for param in self.encoder.parameters(): + param *= 0.5 + + def decode(self, latents, steps=100): + + upsampled_length = latents.shape[2] * self.downsampling_ratio + + if self.bottleneck is not None: + latents = self.bottleneck.decode(latents) + + if self.decoder is not None: + latents = self.decode(latents) + + # Upsample latents to match diffusion length + if latents.shape[2] != upsampled_length: + latents = F.interpolate(latents, size=upsampled_length, mode='nearest') + + noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device) + decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + decoded = self.pretransform.decode(decoded) + + return decoded + +# AE factories + +def create_encoder_from_config(encoder_config: Dict[str, Any]): + encoder_type = encoder_config.get("type", None) + assert encoder_type is not None, "Encoder type must be specified" + + if encoder_type == "oobleck": + encoder = OobleckEncoder( + **encoder_config["config"] + ) + + elif encoder_type == "seanet": + from encodec.modules import SEANetEncoder + seanet_encoder_config = encoder_config["config"] + + #SEANet encoder expects strides in reverse order + seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2]))) + encoder = SEANetEncoder( + **seanet_encoder_config + ) + elif encoder_type == "dac": + dac_config = encoder_config["config"] + + encoder = DACEncoderWrapper(**dac_config) + elif encoder_type == "local_attn": + from .local_attention import TransformerEncoder1D + + local_attn_config = encoder_config["config"] + + encoder = TransformerEncoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown encoder type {encoder_type}") + + requires_grad = encoder_config.get("requires_grad", True) + if not requires_grad: + for param in encoder.parameters(): + param.requires_grad = False + + return encoder + +def create_decoder_from_config(decoder_config: Dict[str, Any]): + decoder_type = decoder_config.get("type", None) + assert decoder_type is not None, "Decoder type must be specified" + + if decoder_type == "oobleck": + decoder = OobleckDecoder( + **decoder_config["config"] + ) + elif decoder_type == "seanet": + from encodec.modules import SEANetDecoder + + decoder = SEANetDecoder( + **decoder_config["config"] + ) + elif decoder_type == "dac": + dac_config = decoder_config["config"] + + decoder = DACDecoderWrapper(**dac_config) + elif decoder_type == "local_attn": + from .local_attention import TransformerDecoder1D + + local_attn_config = decoder_config["config"] + + decoder = TransformerDecoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown decoder type {decoder_type}") + + requires_grad = decoder_config.get("requires_grad", True) + if not requires_grad: + for param in decoder.parameters(): + param.requires_grad = False + + return decoder + +def create_autoencoder_from_config(config: Dict[str, Any]): + + ae_config = config["model"] + + encoder = create_encoder_from_config(ae_config["encoder"]) + decoder = create_decoder_from_config(ae_config["decoder"]) + + bottleneck = ae_config.get("bottleneck", None) + + latent_dim = ae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = ae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = ae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + in_channels = ae_config.get("in_channels", None) + out_channels = ae_config.get("out_channels", None) + + pretransform = ae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = create_bottleneck_from_config(bottleneck) + + soft_clip = ae_config["decoder"].get("soft_clip", False) + + return AudioAutoencoder( + encoder, + decoder, + io_channels=io_channels, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + sample_rate=sample_rate, + bottleneck=bottleneck, + pretransform=pretransform, + in_channels=in_channels, + out_channels=out_channels, + soft_clip=soft_clip + ) + +def create_diffAE_from_config(config: Dict[str, Any]): + + diffae_config = config["model"] + + if "encoder" in diffae_config: + encoder = create_encoder_from_config(diffae_config["encoder"]) + else: + encoder = None + + if "decoder" in diffae_config: + decoder = create_decoder_from_config(diffae_config["decoder"]) + else: + decoder = None + + diffusion_model_type = diffae_config["diffusion"]["type"] + + if diffusion_model_type == "DAU1d": + diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"]) + elif diffusion_model_type == "adp_1d": + diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"]) + elif diffusion_model_type == "dit": + diffusion = DiTWrapper(**diffae_config["diffusion"]["config"]) + + latent_dim = diffae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = diffae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = diffae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + bottleneck = diffae_config.get("bottleneck", None) + + pretransform = diffae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = create_bottleneck_from_config(bottleneck) + + diffusion_downsampling_ratio = None, + + if diffusion_model_type == "DAU1d": + diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"]) + elif diffusion_model_type == "adp_1d": + diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"]) + elif diffusion_model_type == "dit": + diffusion_downsampling_ratio = 1 + + return DiffusionAutoencoder( + encoder=encoder, + decoder=decoder, + diffusion=diffusion, + io_channels=io_channels, + sample_rate=sample_rate, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + diffusion_downsampling_ratio=diffusion_downsampling_ratio, + bottleneck=bottleneck, + pretransform=pretransform + ) diff --git a/ThinkSound/models/blocks.py b/ThinkSound/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..3c827fd2441e643717d123847236d3d6c003ef4f --- /dev/null +++ b/ThinkSound/models/blocks.py @@ -0,0 +1,339 @@ +from functools import reduce +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from torch.backends.cuda import sdp_kernel +from packaging import version + +from dac.nn.layers import Snake1d + +class ResidualBlock(nn.Module): + def __init__(self, main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return self.main(input) + self.skip(input) + +class ResConvBlock(ResidualBlock): + def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): + skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) + super().__init__([ + nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_mid), + Snake1d(c_mid) if use_snake else nn.GELU(), + nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), + (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), + ], skip) + +class SelfAttention1d(nn.Module): + def __init__(self, c_in, n_head=1, dropout_rate=0.): + super().__init__() + assert c_in % n_head == 0 + self.norm = nn.GroupNorm(1, c_in) + self.n_head = n_head + self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) + self.out_proj = nn.Conv1d(c_in, c_in, 1) + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward(self, input): + n, c, s = input.shape + qkv = self.qkv_proj(self.norm(input)) + qkv = qkv.view( + [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) + q, k, v = qkv.chunk(3, dim=1) + scale = k.shape[3]**-0.25 + + if self.use_flash: + with sdp_kernel(*self.sdp_kernel_config): + y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s]) + else: + att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) + y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) + + + return input + self.dropout(self.out_proj(y)) + +class SkipBlock(nn.Module): + def __init__(self, *main): + super().__init__() + self.main = nn.Sequential(*main) + + def forward(self, input): + return torch.cat([self.main(input), input], dim=1) + +class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1.): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter(torch.randn( + [out_features // 2, in_features]) * std) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + +def expand_to_planes(input, shape): + return input[..., None].repeat([1, 1, shape[2]]) + +_kernels = { + 'linear': + [1 / 8, 3 / 8, 3 / 8, 1 / 8], + 'cubic': + [-0.01171875, -0.03515625, 0.11328125, 0.43359375, + 0.43359375, 0.11328125, -0.03515625, -0.01171875], + 'lanczos3': + [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, + -0.066637322306633, 0.13550527393817902, 0.44638532400131226, + 0.44638532400131226, 0.13550527393817902, -0.066637322306633, + -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] +} + +class Downsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, (self.pad,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv1d(x, weight, stride=2) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + + +class Upsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + +def Downsample1d_2( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor, + padding=factor * (kernel_multiplier // 2), + ) + + +def Upsample1d_2( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return nn.Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ), + ) + else: + return nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor, + padding=factor // 2 + factor % 2, + output_padding=factor % 2, + ) + +def zero_init(layer): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + return layer + +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +#rms_norm = torch.compile(rms_norm) + +class AdaRMSNorm(nn.Module): + def __init__(self, features, cond_features, eps=1e-6): + super().__init__() + self.eps = eps + self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) + + def extra_repr(self): + return f"eps={self.eps}," + + def forward(self, x, cond): + return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) + +def normalize(x, eps=1e-4): + dim = list(range(1, x.ndim)) + n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) + alpha = np.sqrt(n.numel() / x.numel()) + return x / torch.add(eps, n, alpha=alpha) + +class ForcedWNConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1): + super().__init__() + self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size])) + + def forward(self, x): + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(self.weight)) + + fan_in = self.weight[0].numel() + + w = normalize(self.weight) / math.sqrt(fan_in) + + return F.conv1d(x, w, padding='same') + +# Kernels + +use_compile = True + +def compile(function, *args, **kwargs): + if not use_compile: + return function + try: + return torch.compile(function, *args, **kwargs) + except RuntimeError: + return function + + +@compile +def linear_geglu(x, weight, bias=None): + x = x @ weight.mT + if bias is not None: + x = x + bias + x, gate = x.chunk(2, dim=-1) + return x * F.gelu(gate) + + +@compile +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +# Layers + +class LinearGEGLU(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features * 2, bias=bias) + self.out_features = out_features + + def forward(self, x): + return linear_geglu(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, shape, fix_scale = False, eps=1e-6): + super().__init__() + self.eps = eps + + if fix_scale: + self.register_buffer("scale", torch.ones(shape)) + else: + self.scale = nn.Parameter(torch.ones(shape)) + + def extra_repr(self): + return f"shape={tuple(self.scale.shape)}, eps={self.eps}" + + def forward(self, x): + return rms_norm(x, self.scale, self.eps) + +def snake_beta(x, alpha, beta): + return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) + +# try: +# snake_beta = torch.compile(snake_beta) +# except RuntimeError: +# pass + +# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license +# License available in LICENSES/LICENSE_NVIDIA.txt +class SnakeBeta(nn.Module): + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = snake_beta(x, alpha, beta) + + return x \ No newline at end of file diff --git a/ThinkSound/models/bottleneck.py b/ThinkSound/models/bottleneck.py new file mode 100644 index 0000000000000000000000000000000000000000..5e81cab4bfb16b615ee21d5e9248e3b455f7eb5b --- /dev/null +++ b/ThinkSound/models/bottleneck.py @@ -0,0 +1,355 @@ +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from einops import rearrange +from vector_quantize_pytorch import ResidualVQ, FSQ +from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ + +class Bottleneck(nn.Module): + def __init__(self, is_discrete: bool = False): + super().__init__() + + self.is_discrete = is_discrete + + def encode(self, x, return_info=False, **kwargs): + raise NotImplementedError + + def decode(self, x): + raise NotImplementedError + +class DiscreteBottleneck(Bottleneck): + def __init__(self, num_quantizers, codebook_size, tokens_id): + super().__init__(is_discrete=True) + + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + self.tokens_id = tokens_id + + def decode_tokens(self, codes, **kwargs): + raise NotImplementedError + +class TanhBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + self.tanh = nn.Tanh() + + def encode(self, x, return_info=False): + info = {} + + x = torch.tanh(x) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + +def vae_sample(mean, scale): + stdev = nn.functional.softplus(scale) + 1e-4 + var = stdev * stdev + logvar = torch.log(var) + latents = torch.randn_like(mean) * stdev + mean + + kl = (mean * mean + var - logvar - 1).sum(1).mean() + + return latents, kl + +class VAEBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False, **kwargs): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["kl"] = kl + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + +def compute_mean_kernel(x, y): + kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] + return torch.exp(-kernel_input).mean() + +def compute_mmd(latents): + latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1]) + noise = torch.randn_like(latents_reshaped) + + latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) + noise_kernel = compute_mean_kernel(noise, noise) + latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) + + mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel + return mmd.mean() + +class WassersteinBottleneck(Bottleneck): + def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False): + super().__init__(is_discrete=False) + + self.noise_augment_dim = noise_augment_dim + self.bypass_mmd = bypass_mmd + + def encode(self, x, return_info=False): + info = {} + + if self.training and return_info: + if self.bypass_mmd: + mmd = torch.tensor(0.0) + else: + mmd = compute_mmd(x) + + info["mmd"] = mmd + + if return_info: + return x, info + + return x + + def decode(self, x): + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + +class L2Bottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False): + info = {} + + x = F.normalize(x, dim=1) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return F.normalize(x, dim=1) + +class RVQBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False, **kwargs): + info = {} + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + +class RVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False): + info = {} + + x, kl = vae_sample(*x.chunk(2, dim=1)) + + info["kl"] = kl + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + +class DACRVQBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + self.noise_augment_dim = noise_augment_dim + + def encode(self, x, return_info=False, **kwargs): + info = {} + + info["pre_quantizer"] = x + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + +class DACRVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + + def encode(self, x, return_info=False, n_quantizers: int = None): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["pre_quantizer"] = x + info["kl"] = kl + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + +class FSQBottleneck(DiscreteBottleneck): + def __init__(self, noise_augment_dim=0, **kwargs): + super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices") + + self.noise_augment_dim = noise_augment_dim + + self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64]) + + def encode(self, x, return_info=False): + info = {} + + orig_dtype = x.dtype + x = x.float() + + x = rearrange(x, "b c n -> b n c") + x, indices = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + x = x.to(orig_dtype) + + # Reorder indices to match the expected format + indices = rearrange(indices, "b n q -> b q n") + + info["quantizer_indices"] = indices + + if return_info: + return x, info + else: + return x + + def decode(self, x): + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + + def decode_tokens(self, tokens, **kwargs): + latents = self.quantizer.indices_to_codes(tokens) + + return self.decode(latents, **kwargs) \ No newline at end of file diff --git a/ThinkSound/models/codebook_patterns.py b/ThinkSound/models/codebook_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..f9bd2a9b837bd77cb40f3b500b02ea491dfb9da0 --- /dev/null +++ b/ThinkSound/models/codebook_patterns.py @@ -0,0 +1,545 @@ +# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py under MIT License +# License available in LICENSES/LICENSE_META.txt + +from collections import namedtuple +from dataclasses import dataclass +from functools import lru_cache +import logging +import typing as tp + +from abc import ABC, abstractmethod +import torch + +LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index) +PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates +logger = logging.getLogger(__name__) + + +@dataclass +class Pattern: + """Base implementation of a pattern over a sequence with multiple codebooks. + + The codebook pattern consists in a layout, defining for each sequence step + the list of coordinates of each codebook timestep in the resulting interleaved sequence. + The first item of the pattern is always an empty list in order to properly insert a special token + to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern + and ``timesteps`` the number of timesteps corresponding to the original sequence. + + The pattern provides convenient methods to build and revert interleaved sequences from it: + ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T] + to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size, + K being the number of codebooks, T the number of original timesteps and S the number of sequence steps + for the output sequence. The unfilled positions are replaced with a special token and the built sequence + is returned along with a mask indicating valid tokens. + ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment + of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask + to fill and specify invalid positions if needed. + See the dedicated methods for more details. + """ + # Pattern layout, for each sequence step, we have a list of coordinates + # corresponding to the original codebook timestep and position. + # The first list is always an empty list in order to properly insert + # a special token to start with. + layout: PatternLayout + timesteps: int + n_q: int + + def __post_init__(self): + assert len(self.layout) > 0 + self._validate_layout() + self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes) + self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes) + logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout)) + + def _validate_layout(self): + """Runs checks on the layout to ensure a valid pattern is defined. + A pattern is considered invalid if: + - Multiple timesteps for a same codebook are defined in the same sequence step + - The timesteps for a given codebook are not in ascending order as we advance in the sequence + (this would mean that we have future timesteps before past timesteps). + """ + q_timesteps = {q: 0 for q in range(self.n_q)} + for s, seq_coords in enumerate(self.layout): + if len(seq_coords) > 0: + qs = set() + for coord in seq_coords: + qs.add(coord.q) + last_q_timestep = q_timesteps[coord.q] + assert coord.t >= last_q_timestep, \ + f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}" + q_timesteps[coord.q] = coord.t + # each sequence step contains at max 1 coordinate per codebook + assert len(qs) == len(seq_coords), \ + f"Multiple entries for a same codebook are found at step {s}" + + @property + def num_sequence_steps(self): + return len(self.layout) - 1 + + @property + def max_delay(self): + max_t_in_seq_coords = 0 + for seq_coords in self.layout[1:]: + for coords in seq_coords: + max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1) + return max_t_in_seq_coords - self.timesteps + + @property + def valid_layout(self): + valid_step = len(self.layout) - self.max_delay + return self.layout[:valid_step] + + def starts_with_special_token(self): + return self.layout[0] == [] + + def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None): + """Get codebook coordinates in the layout that corresponds to the specified timestep t + and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step + and the actual codebook coordinates. + """ + assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps" + if q is not None: + assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks" + coords = [] + for s, seq_codes in enumerate(self.layout): + for code in seq_codes: + if code.t == t and (q is None or code.q == q): + coords.append((s, code)) + return coords + + def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]: + return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)] + + def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]: + steps_with_timesteps = self.get_steps_with_timestep(t, q) + return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None + + def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool, + device: tp.Union[torch.device, str] = 'cpu'): + """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps. + + Args: + timesteps (int): Maximum number of timesteps steps to consider. + keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps. + device (torch.device or str): Device for created tensors. + Returns: + indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S]. + """ + assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" + assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern" + # use the proper layout based on whether we limit ourselves to valid steps only or not, + # note that using the valid_layout will result in a truncated sequence up to the valid steps + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy() + mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + # the last value is n_q * timesteps as we have flattened z and append special token as the last token + # which will correspond to the index: n_q * timesteps + indexes[:] = n_q * timesteps + # iterate over the pattern and fill scattered indexes and mask + for s, sequence_coords in enumerate(ref_layout): + for coords in sequence_coords: + if coords.t < timesteps: + indexes[coords.q, s] = coords.t + coords.q * timesteps + mask[coords.q, s] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Build sequence corresponding to the pattern from the input tensor z. + The sequence is built using up to sequence_steps if specified, and non-pattern + coordinates are filled with the special token. + + Args: + z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T]. + special_token (int): Special token used to fill non-pattern coordinates in the new sequence. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S + corresponding either to the sequence_steps if provided, otherwise to the length of the pattern. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S]. + """ + B, K, T = z.shape + indexes, mask = self._build_pattern_sequence_scatter_indexes( + T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device) + ) + z = z.view(B, -1) + # we append the special token as the last index of our flattened z tensor + z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1) + values = z[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int, + keep_only_valid_steps: bool = False, + is_model_output: bool = False, + device: tp.Union[torch.device, str] = 'cpu'): + """Builds scatter indexes required to retrieve the original multi-codebook sequence + from interleaving pattern. + + Args: + sequence_steps (int): Sequence steps. + n_q (int): Number of codebooks. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not. + device (torch.device or str): Device for created tensors. + Returns: + indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # TODO(jade): Do we want to further truncate to only valid timesteps here as well? + timesteps = self.timesteps + assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" + assert sequence_steps <= len(ref_layout), \ + f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" + + # ensure we take the appropriate indexes to keep the model output from the first special token as well + if is_model_output and self.starts_with_special_token(): + ref_layout = ref_layout[1:] + + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy() + mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + indexes[:] = n_q * sequence_steps + for s, sequence_codes in enumerate(ref_layout): + if s < sequence_steps: + for code in sequence_codes: + if code.t < timesteps: + indexes[code.q, code.t] = s + code.q * sequence_steps + mask[code.q, code.t] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving. + The sequence is reverted using up to timesteps if specified, and non-pattern coordinates + are filled with the special token. + + Args: + s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S]. + special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence. + Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T + corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + B, K, S = s.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device) + ) + s = s.view(B, -1) + # we append the special token as the last index of our flattened z tensor + s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1) + values = s[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False): + """Revert model logits obtained on a sequence built from the pattern + back to a tensor matching the original sequence. + + This method is similar to ``revert_pattern_sequence`` with the following specificities: + 1. It is designed to work with the extra cardinality dimension + 2. We return the logits for the first sequence item that matches the special_token and + which matching target in the original sequence is the first item of the sequence, + while we skip the last logits as there is no matching target + """ + B, card, K, S = logits.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=True, device=logits.device + ) + logits = logits.reshape(B, card, -1) + # we append the special token as the last index of our flattened z tensor + logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S] + values = logits[:, :, indexes.view(-1)] + values = values.view(B, card, K, indexes.shape[-1]) + return values, indexes, mask + + +class CodebooksPatternProvider(ABC): + """Abstraction around providing pattern for interleaving codebooks. + + The CodebooksPatternProvider abstraction allows to implement various strategies to + define interleaving pattern of sequences composed of multiple codebooks. For a given + number of codebooks `n_q`, the pattern provider can generate a specified pattern + corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern + can be used to construct a new sequence from the original codes respecting the specified + pattern. The pattern is defined as a list of list of code coordinates, code coordinate + being a tuple with the original timestep and codebook to build the new sequence. + Note that all patterns must start with an empty list that is then used to insert a first + sequence step of special tokens in the newly generated sequence. + + Args: + n_q (int): number of codebooks. + cached (bool): if True, patterns for a given length are cached. In general + that should be true for efficiency reason to avoid synchronization points. + """ + def __init__(self, n_q: int, cached: bool = True): + assert n_q > 0 + self.n_q = n_q + self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore + + @abstractmethod + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern with specific interleaving between codebooks. + + Args: + timesteps (int): Total number of timesteps. + """ + raise NotImplementedError() + + +class DelayedPatternProvider(CodebooksPatternProvider): + """Provider for delayed pattern across delayed codebooks. + Codebooks are delayed in the sequence and sequence steps will contain codebooks + from different timesteps. + + Example: + Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + The resulting sequence obtained from the returned pattern is: + [[S, 1, 2, 3, 4], + [S, S, 1, 2, 3], + [S, S, S, 1, 2]] + (with S being a special token) + + Args: + n_q (int): Number of codebooks. + delays (list of int, optional): Delay for each of the codebooks. + If delays not defined, each codebook is delayed by 1 compared to the previous one. + flatten_first (int): Flatten the first N timesteps. + empty_initial (int): Prepend with N empty list of coordinates. + """ + def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None, + flatten_first: int = 0, empty_initial: int = 0): + super().__init__(n_q) + if delays is None: + delays = list(range(n_q)) + self.delays = delays + self.flatten_first = flatten_first + self.empty_initial = empty_initial + assert len(self.delays) == self.n_q + assert sorted(self.delays) == self.delays + + def get_pattern(self, timesteps: int) -> Pattern: + omit_special_token = self.empty_initial < 0 + out: PatternLayout = [] if omit_special_token else [[]] + max_delay = max(self.delays) + if self.empty_initial: + out += [[] for _ in range(self.empty_initial)] + if self.flatten_first: + for t in range(min(timesteps, self.flatten_first)): + for q in range(self.n_q): + out.append([LayoutCoord(t, q)]) + for t in range(self.flatten_first, timesteps + max_delay): + v = [] + for q, delay in enumerate(self.delays): + t_for_q = t - delay + if t_for_q >= self.flatten_first: + v.append(LayoutCoord(t_for_q, q)) + out.append(v) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class ParallelPatternProvider(DelayedPatternProvider): + """Provider for parallel pattern across codebooks. + This pattern provider is a special case of the delayed pattern with actually no delay, + hence delays=repeat(0, n_q). + + Args: + n_q (int): Number of codebooks. + empty_initial (int): Prepend with N empty list of coordinates. + """ + def __init__(self, n_q: int, empty_initial: int = 0): + super().__init__(n_q, [0] * n_q, empty_initial=empty_initial) + + +class UnrolledPatternProvider(CodebooksPatternProvider): + """Provider for unrolling codebooks pattern. + This pattern provider enables to represent the codebook flattened completely or only to some extend + while also specifying a given delay between the flattened codebooks representation, allowing to + unroll the codebooks in the sequence. + + Example: + 1. Flattening of the codebooks. + By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q), + taking n_q = 3 and timesteps = 4: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, S, 1, S, S, 2, S, S, 3, S, S, 4], + [S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [1, S, S, 2, S, S, 3, S, S, 4, S, S]] + 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step + for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example + taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [1, S, S, 2, S, S, 3, S, S, 4, S, S]] + 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks + allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the + same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1] + and delays = [0, 3, 3]: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, S, S, 1, S, 2, S, 3, S, 4], + [S, S, S, 1, S, 2, S, 3, S, 4], + [1, 2, 3, S, 4, S, 5, S, 6, S]] + + Args: + n_q (int): Number of codebooks. + flattening (list of int, optional): Flattening schema over the codebooks. If not defined, + the codebooks will be flattened to 1 codebook per step, meaning that the sequence will + have n_q extra steps for each timestep. + delays (list of int, optional): Delay for each of the codebooks. If not defined, + no delay is added and therefore will default to [0] * ``n_q``. + Note that two codebooks that will be flattened to the same inner step + should have the same delay, otherwise the pattern is considered as invalid. + """ + FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay']) + + def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None, + delays: tp.Optional[tp.List[int]] = None): + super().__init__(n_q) + if flattening is None: + flattening = list(range(n_q)) + if delays is None: + delays = [0] * n_q + assert len(flattening) == n_q + assert len(delays) == n_q + assert sorted(flattening) == flattening + assert sorted(delays) == delays + self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening) + self.max_delay = max(delays) + + def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]): + """Build a flattened codebooks representation as a dictionary of inner step + and the actual codebook indices corresponding to the flattened codebook. For convenience, we + also store the delay associated to the flattened codebook to avoid maintaining an extra mapping. + """ + flattened_codebooks: dict = {} + for q, (inner_step, delay) in enumerate(zip(flattening, delays)): + if inner_step not in flattened_codebooks: + flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay) + else: + flat_codebook = flattened_codebooks[inner_step] + assert flat_codebook.delay == delay, ( + "Delay and flattening between codebooks is inconsistent: ", + "two codebooks flattened to the same position should have the same delay." + ) + flat_codebook.codebooks.append(q) + flattened_codebooks[inner_step] = flat_codebook + return flattened_codebooks + + @property + def _num_inner_steps(self): + """Number of inner steps to unroll between timesteps in order to flatten the codebooks. + """ + return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1 + + def num_virtual_steps(self, timesteps: int) -> int: + return timesteps * self._num_inner_steps + 1 + + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern for delay across codebooks. + + Args: + timesteps (int): Total number of timesteps. + """ + # the PatternLayout is built as a tuple of sequence position and list of coordinates + # so that it can be reordered properly given the required delay between codebooks of given timesteps + indexed_out: list = [(-1, [])] + max_timesteps = timesteps + self.max_delay + for t in range(max_timesteps): + # for each timestep, we unroll the flattened codebooks, + # emitting the sequence step with the corresponding delay + for step in range(self._num_inner_steps): + if step in self._flattened_codebooks: + # we have codebooks at this virtual step to emit + step_codebooks = self._flattened_codebooks[step] + t_for_q = t + step_codebooks.delay + coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks] + if t_for_q < max_timesteps and t < max_timesteps: + indexed_out.append((t_for_q, coords)) + else: + # there is no codebook in this virtual step so we emit an empty list + indexed_out.append((t, [])) + out = [coords for _, coords in sorted(indexed_out)] + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class CoarseFirstPattern(CodebooksPatternProvider): + """First generates all the codebooks #1 (e.g. coarser), then the remaining ones, + potentially with delays. + + ..Warning:: You must always generate the full training duration at test time, for instance, + 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected + location. This is due to the non causality of the remaining codebooks with respect to + the first ones. + + Args: + n_q (int): Number of codebooks. + delays (list of int, optional): Delay for each of the codebooks. + If delays not defined, each codebook is delayed by 1 compared to the previous one. + """ + def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None): + super().__init__(n_q) + if delays is None: + delays = [0] * (n_q - 1) + self.delays = delays + assert len(self.delays) == self.n_q - 1 + assert sorted(self.delays) == self.delays + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + for t in range(timesteps): + out.append([LayoutCoord(t, 0)]) + max_delay = max(self.delays) + for t in range(timesteps + max_delay): + v = [] + for q, delay in enumerate(self.delays): + t_for_q = t - delay + if t_for_q >= 0: + v.append(LayoutCoord(t_for_q, q + 1)) + out.append(v) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class MusicLMPattern(CodebooksPatternProvider): + """Almost MusicLM style pattern. This is equivalent to full flattening + but in a different order. + + Args: + n_q (int): Number of codebooks. + group_by (int): Number of codebooks to group together. + """ + def __init__(self, n_q: int, group_by: int = 2): + super().__init__(n_q) + self.group_by = group_by + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + for offset in range(0, self.n_q, self.group_by): + for t in range(timesteps): + for q in range(offset, offset + self.group_by): + out.append([LayoutCoord(t, q)]) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) \ No newline at end of file diff --git a/ThinkSound/models/conditioners.py b/ThinkSound/models/conditioners.py new file mode 100644 index 0000000000000000000000000000000000000000..dd4efc657b815092710427ae119f449049b1cf18 --- /dev/null +++ b/ThinkSound/models/conditioners.py @@ -0,0 +1,1082 @@ +#Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py + +import torch +import logging, warnings +import string +import typing as tp +import gc +from typing import Literal, Optional +import os +from .adp import NumberEmbedder +from ..inference.utils import set_audio_channels +from .factory import create_pretransform_from_config +from .pretransforms import Pretransform +from ..training.utils import copy_state_dict +from .utils import load_ckpt_state_dict +import numpy as np +from einops import rearrange +from transformers import AutoProcessor, AutoModel +from torch import nn +import torch.nn.functional as F +from .mmmodules.model.low_level import ConvMLP, MLP +from torch.nn.utils.rnn import pad_sequence + +class Conditioner(nn.Module): + def __init__( + self, + dim: int, + output_dim: int, + project_out: bool = False + ): + + super().__init__() + + self.dim = dim + self.output_dim = output_dim + self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity() + + def forward(self, x: tp.Any) -> tp.Any: + raise NotImplementedError() + +class Cond_MLP(Conditioner): + def __init__(self, dim, output_dim, dropout = 0.0): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim, bias=False), + nn.SiLU(), + nn.Linear(output_dim, output_dim, bias=False) + ) + self.dropout = dropout + def forward(self, x, device: tp.Any = "cuda"): + x = pad_sequence(x, batch_first=True).to(device) + # x = torch.stack(x, dim=0).to(device) + + if self.dropout > 0.0: + if self.training: + null_embed = torch.zeros_like(x, device=device) + dropout_mask = torch.bernoulli(torch.full((x.shape[0], 1, 1), self.dropout, device=device)).to(torch.bool) + x = torch.where(dropout_mask, null_embed, x) + elif x.shape[0] < 16: # default test batch size=1 + null_embed = torch.zeros_like(x, device=device) + x = torch.cat([x, null_embed], dim=0) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Global_MLP(Conditioner): + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim, bias=False), + nn.SiLU(), + nn.Linear(output_dim, output_dim, bias=False) + ) + def forward(self, x, device: tp.Any = "cuda"): + x = torch.stack(x, dim=0).to(device) + x = x.mean(dim=1) + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Cond_MLP_1(Conditioner): + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim), + nn.SiLU(), + MLP(output_dim, output_dim * 4), + ) + def forward(self, x, device: tp.Any = "cuda"): + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Cond_MLP_Global(Conditioner): + def __init__(self, dim, output_dim, dropout = 0.0): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim, bias=False), + nn.SiLU(), + nn.Linear(output_dim, output_dim, bias=False) + ) + self.global_embedder = nn.Sequential( + nn.Linear(output_dim, output_dim, bias=False), + nn.SiLU(), + nn.Linear(output_dim, output_dim, bias=False) + ) + self.dropout = dropout + def forward(self, x, device: tp.Any = "cuda"): + x = torch.stack(x, dim=0).to(device) + if self.dropout > 0 and self.training: + null_embed = torch.zeros_like(x, device=device) + dropout_mask = torch.bernoulli(torch.full((x.shape[0], 1, 1), self.dropout, device=device)).to(torch.bool) + x = torch.where(dropout_mask, null_embed, x) + x = self.embedder(x) # B x 117 x C + global_x = self.global_embedder(x[:,0,:]) + return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] + +class Cond_MLP_Global_1(Conditioner): + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim), + nn.SiLU(), + MLP(output_dim, output_dim * 4), + ) + self.global_embedder = nn.Sequential( + nn.Linear(dim, output_dim), + MLP(output_dim, output_dim * 4), + ) + def forward(self, x, device: tp.Any = "cuda"): + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + global_x = self.global_embedder(x.mean(dim=1)) + return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] + +class Cond_MLP_Global_2(Conditioner): + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim, bias=False), + nn.SiLU(), + nn.Linear(output_dim, output_dim, bias=False) + ) + self.global_embedder = nn.Sequential( + nn.Linear(output_dim, output_dim, bias=False), + ) + def forward(self, x, device: tp.Any = "cuda"): + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + global_x = self.global_embedder(x.mean(dim=1)) + return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] + +class Sync_MLP(Conditioner): + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim, bias=False), + nn.SiLU(), + nn.Linear(output_dim, output_dim, bias=False) + ) + self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, dim))) + nn.init.constant_(self.sync_pos_emb, 0) + def forward(self, x, device: tp.Any = "cuda"): + sync_f = torch.stack(x, dim=0).to(device) + bs, length, dim = sync_f.shape + #print(sync_f.shape,flush=True) + # B * num_segments (24) * 8 * 768 + num_sync_segments = length // 8 + sync_f = sync_f.view(bs, num_sync_segments, 8, -1) + self.sync_pos_emb + sync_f = sync_f.flatten(1, 2) # (B, VN, D) + x = self.embedder(sync_f) # B x 117 x C + x = x.transpose(1,2) + x = F.interpolate(x, ((int)(194*sync_f.shape[1]/216), ), mode='linear', align_corners=False) + x = x.transpose(1,2) + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Cond_ConvMLP(Conditioner): + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential( + nn.Linear(dim, output_dim), + nn.SiLU(), + ConvMLP(output_dim, output_dim * 4, kernel_size=1, padding=0), + ) + def forward(self, x, device: tp.Any = "cuda"): + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Video_Global(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim, global_dim=1536): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) + self.global_proj = nn.Sequential(nn.Linear(output_dim, global_dim)) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + data = torch.load(path) + video_feats.append(data['metaclip_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + global_x = self.global_proj(x.mean(dim=1)) + return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] + +class Video_Sync(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + video_feats.append(torch.load(path)['sync_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Text_Linear(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + video_feats.append(torch.load(path)['metaclip_text_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class mm_unchang(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + video_feats.append(torch.load(path)['metaclip_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + return [x] + +class CLIPConditioner(Conditioner): + + CLIP_MODELS = ["metaclip-base", "metaclip-b16", "metaclip-large", "metaclip-huge"] + + CLIP_MODEL_DIMS = { + "metaclip-base": 512, + "metaclip-b16": 512, + "metaclip-large": 768, + "metaclip-huge": 1024, + } + + def __init__( + self, + dim: int, + output_dim: int, + clip_model_name: str = "metaclip-huge", + enable_grad: bool = False, + project_out: bool = False + ): + assert clip_model_name in self.CLIP_MODELS, f"Unknown CLIP model name: {clip_model_name}" + super().__init__(self.CLIP_MODEL_DIMS[clip_model_name], output_dim, project_out=project_out) + + self.enable_grad = enable_grad + model = AutoModel.from_pretrained(f"useful_ckpts/{clip_model_name}").train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + + + + if self.enable_grad: + self.model = model + else: + self.__dict__["model"] = model + + + def forward(self, images: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + # import ipdb + # ipdb.set_trace() + + self.model.eval() + if not isinstance(images[0], torch.Tensor): + video_feats = [] + for path in images: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + images = torch.stack(video_feats, dim=0).to(device) + else: + images = torch.stack(images, dim=0).to(device) + bsz, t, c, h, w = images.shape + # 使用 rearrange 进行维度合并 + images = rearrange(images, 'b t c h w -> (b t) c h w') + with torch.set_grad_enabled(self.enable_grad): + image_features = self.model.get_image_features(images) + image_features = rearrange(image_features, '(b t) d -> b t d', b=bsz, t=t) + image_features = self.proj_out(image_features) + + + return [image_features, torch.ones(image_features.shape[0], 1).to(device)] + +class IntConditioner(Conditioner): + def __init__(self, + output_dim: int, + min_val: int=0, + max_val: int=512 + ): + super().__init__(output_dim, output_dim) + + self.min_val = min_val + self.max_val = max_val + self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True) + + def forward(self, ints: tp.List[int], device=None) -> tp.Any: + + #self.int_embedder.to(device) + + ints = torch.tensor(ints).to(device) + ints = ints.clamp(self.min_val, self.max_val) + + int_embeds = self.int_embedder(ints).unsqueeze(1) + + return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)] + +class NumberConditioner(Conditioner): + ''' + Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings + ''' + def __init__(self, + output_dim: int, + min_val: float=0, + max_val: float=1 + ): + super().__init__(output_dim, output_dim) + + self.min_val = min_val + self.max_val = max_val + + self.embedder = NumberEmbedder(features=output_dim) + + def forward(self, floats: tp.List[float], device=None) -> tp.Any: + + # Cast the inputs to floats + floats = [float(x) for x in floats] + + floats = torch.tensor(floats).to(device) + + floats = floats.clamp(self.min_val, self.max_val) + + normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val) + + # Cast floats to same type as embedder + embedder_dtype = next(self.embedder.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + float_embeds = self.embedder(normalized_floats).unsqueeze(1) + + return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)] + +class CLAPTextConditioner(Conditioner): + def __init__(self, + output_dim: int, + clap_ckpt_path, + use_text_features = False, + feature_layer_ix: int = -1, + audio_model_type="HTSAT-base", + enable_fusion=True, + project_out: bool = False, + finetune: bool = False): + super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out) + + self.use_text_features = use_text_features + self.feature_layer_ix = feature_layer_ix + self.finetune = finetune + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + import laion_clap + from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict + + model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') + + if self.finetune: + self.model = model + else: + self.__dict__["model"] = model + + state_dict = clap_load_state_dict(clap_ckpt_path) + self.model.model.load_state_dict(state_dict, strict=False) + + if self.finetune: + self.model.model.text_branch.requires_grad_(True) + self.model.model.text_branch.train() + else: + self.model.model.text_branch.requires_grad_(False) + self.model.model.text_branch.eval() + + finally: + logging.disable(previous_level) + + del self.model.model.audio_branch + + gc.collect() + torch.cuda.empty_cache() + + def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"): + prompt_tokens = self.model.tokenizer(prompts) + attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True) + prompt_features = self.model.model.text_branch( + input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True), + attention_mask=attention_mask, + output_hidden_states=True + )["hidden_states"][layer_ix] + + return prompt_features, attention_mask + + def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any: + self.model.to(device) + + if self.use_text_features: + if len(texts) == 1: + text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device) + text_features = text_features[:1, ...] + text_attention_mask = text_attention_mask[:1, ...] + else: + text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device) + return [self.proj_out(text_features), text_attention_mask] + + # Fix for CLAP bug when only one text is passed + if len(texts) == 1: + text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...] + else: + text_embedding = self.model.get_text_embedding(texts, use_tensor=True) + + text_embedding = text_embedding.unsqueeze(1).to(device) + + return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)] + +class CLAPAudioConditioner(Conditioner): + def __init__(self, + output_dim: int, + clap_ckpt_path, + audio_model_type="HTSAT-base", + enable_fusion=True, + project_out: bool = False): + super().__init__(512, output_dim, project_out=project_out) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + import laion_clap + from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict + + model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') + + self.model = model + + state_dict = clap_load_state_dict(clap_ckpt_path) + self.model.model.load_state_dict(state_dict, strict=False) + + self.model.model.audio_branch.requires_grad_(False) + self.model.model.audio_branch.eval() + + finally: + logging.disable(previous_level) + + del self.model.model.text_branch + + gc.collect() + torch.cuda.empty_cache() + + def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any: + + self.model.to(device) + + if isinstance(audios, list) or isinstance(audios, tuple): + audios = torch.cat(audios, dim=0) + + # Convert to mono + mono_audios = audios.mean(dim=1) + + with torch.cuda.amp.autocast(enabled=False): + audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True) + + audio_embedding = audio_embedding.unsqueeze(1).to(device) + + return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)] + +class T5Conditioner(Conditioner): + + T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", + "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", + "google/flan-t5-xl", "google/flan-t5-xxl", "t5-v1_1-xl", "google/t5-v1_1-xxl"] + + T5_MODEL_DIMS = { + "t5-small": 512, + "t5-base": 768, + "t5-large": 1024, + "t5-3b": 1024, + "t5-11b": 1024, + "t5-v1_1-xl": 2048, + "google/t5-v1_1-xxl": 4096, + "google/flan-t5-small": 512, + "google/flan-t5-base": 768, + "google/flan-t5-large": 1024, + "google/flan-t5-3b": 1024, + "google/flan-t5-11b": 1024, + "google/flan-t5-xl": 2048, + "google/flan-t5-xxl": 4096, + } + + def __init__( + self, + output_dim: int, + t5_model_name: str = "t5-base", + max_length: str = 77, + enable_grad: bool = False, + project_out: bool = False + ): + assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}" + super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out) + + from transformers import T5EncoderModel, AutoTokenizer + + self.max_length = max_length + self.enable_grad = enable_grad + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length) + # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad) + self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('useful_ckpts', t5_model_name)) + model = T5EncoderModel.from_pretrained(os.path.join('useful_ckpts', t5_model_name)).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + finally: + logging.disable(previous_level) + + if self.enable_grad: + self.model = model + else: + self.__dict__["model"] = model + + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + encoded = self.tokenizer( + texts, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + self.model.eval() + + with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): + embeddings = self.model( + input_ids=input_ids, attention_mask=attention_mask + )["last_hidden_state"] + + embeddings = self.proj_out(embeddings.float()) + + embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, attention_mask + +def patch_clip(clip_model): + # a hack to make it output last hidden states + # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269 + def new_encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = self.transformer(x, attn_mask=self.attn_mask) + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + return F.normalize(x, dim=-1) if normalize else x + + clip_model.encode_text = new_encode_text.__get__(clip_model) + return clip_model + +class CLIPTextConditioner(Conditioner): + def __init__( + self, + output_dim: int, + max_length: str = 77, + enable_grad: bool = False, + project_out: bool = False + ): + super().__init__(1024, output_dim, project_out=project_out) + + from transformers import T5EncoderModel, AutoTokenizer + import open_clip + from open_clip import create_model_from_pretrained + + self.max_length = max_length + self.enable_grad = enable_grad + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',cache_dir='useful_ckpts/DFN5B-CLIP-ViT-H-14-384', + return_transform=False).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + model = patch_clip(model) + self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14' + finally: + logging.disable(previous_level) + + if self.enable_grad: + self.model = model + else: + self.__dict__["model"] = model + + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + + encoded = self.tokenizer( + texts + ).to(device) + + # input_ids = encoded["input_ids"].to(device) + # attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + self.model.eval() + + with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): + embeddings = self.model.encode_text( + encoded + ) + + embeddings = self.proj_out(embeddings.float()) + + # embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, torch.ones(embeddings.shape[0], 1).to(device) + +def patch_clip(clip_model): + # a hack to make it output last hidden states + # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269 + def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = text_outputs[0] + # pooled_output = text_outputs[1] + # text_features = self.text_projection(pooled_output) + + return last_hidden_state + + clip_model.get_text_features = new_get_text_features.__get__(clip_model) + return clip_model + +class MetaCLIPTextConditioner(Conditioner): + def __init__( + self, + output_dim: int, + max_length: str = 77, + enable_grad: bool = False, + project_out: bool = False + ): + super().__init__(1024, output_dim, project_out=project_out) + + from transformers import AutoModel + from transformers import AutoProcessor + + self.max_length = max_length + self.enable_grad = enable_grad + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + self.model = AutoModel.from_pretrained("useful_ckpts/metaclip-huge") + self.model = patch_clip(self.model) + self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge") + finally: + logging.disable(previous_level) + + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + encoded = self.clip_processor(text=texts, return_tensors="pt", padding=True).to(device) + + # input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + self.model.eval() + + with torch.set_grad_enabled(self.enable_grad): + embeddings = self.model.get_text_features( + **encoded + ) + + embeddings = self.proj_out(embeddings.float()) + + # embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, torch.ones(embeddings.shape[0],1).to(device) + +class PhonemeConditioner(Conditioner): + """ + A conditioner that turns text into phonemes and embeds them using a lookup table + Only works for English text + + Args: + output_dim: the dimension of the output embeddings + max_length: the maximum number of phonemes to embed + project_out: whether to add another linear projection to the output embeddings + """ + + def __init__( + self, + output_dim: int, + max_length: int = 1024, + project_out: bool = False, + ): + super().__init__(output_dim, output_dim, project_out=project_out) + + from g2p_en import G2p + + self.max_length = max_length + + self.g2p = G2p() + + # Reserving 0 for padding, 1 for ignored + self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim) + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.phoneme_embedder.to(device) + self.proj_out.to(device) + + batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length] + + phoneme_ignore = [" ", *string.punctuation] + + # Remove ignored phonemes and cut to max length + batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes] + + # Convert to ids + phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes] + + #Pad to match longest and make a mask tensor for the padding + longest = max([len(ids) for ids in phoneme_ids]) + phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids] + + phoneme_ids = torch.tensor(phoneme_ids).to(device) + + # Convert to embeddings + phoneme_embeds = self.phoneme_embedder(phoneme_ids) + + phoneme_embeds = self.proj_out(phoneme_embeds) + + return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device) + +class TokenizerLUTConditioner(Conditioner): + """ + A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary + + Args: + tokenizer_name: the name of the tokenizer from the Hugging Face transformers library + output_dim: the dimension of the output embeddings + max_length: the maximum length of the text to embed + project_out: whether to add another linear projection to the output embeddings + """ + + def __init__( + self, + tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library + output_dim: int, + max_length: int = 1024, + project_out: bool = False, + ): + super().__init__(output_dim, output_dim, project_out=project_out) + + from transformers import AutoTokenizer + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + finally: + logging.disable(previous_level) + + self.max_length = max_length + + self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim) + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + self.proj_out.to(device) + + encoded = self.tokenizer( + texts, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + embeddings = self.token_embedder(input_ids) + + embeddings = self.proj_out(embeddings) + + embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, attention_mask + +class PretransformConditioner(Conditioner): + """ + A conditioner that uses a pretransform's encoder for conditioning + + Args: + pretransform: an instantiated pretransform to use for conditioning + output_dim: the dimension of the output embeddings + """ + def __init__(self, pretransform: Pretransform, output_dim: int): + super().__init__(pretransform.encoded_channels, output_dim) + + self.pretransform = pretransform + + def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.pretransform.to(device) + self.proj_out.to(device) + + if isinstance(audio, list) or isinstance(audio, tuple): + audio = torch.cat(audio, dim=0) + + # Convert audio to pretransform input channels + audio = set_audio_channels(audio, self.pretransform.io_channels) + + latents = self.pretransform.encode(audio) + + latents = self.proj_out(latents) + + return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)] + +class MultiConditioner(nn.Module): + """ + A module that applies multiple conditioners to an input dictionary based on the keys + + Args: + conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt") + default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"}) + """ + def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}): + super().__init__() + + self.conditioners = nn.ModuleDict(conditioners) + self.default_keys = default_keys + + def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]: + output = {} + + for key, conditioner in self.conditioners.items(): + condition_key = key + + conditioner_inputs = [] + + for x in batch_metadata: + + if condition_key not in x: + if condition_key in self.default_keys: + condition_key = self.default_keys[condition_key] + else: + raise ValueError(f"Conditioner key {condition_key} not found in batch metadata") + + #Unwrap the condition info if it's a single-element list or tuple, this is to support collation functions that wrap everything in a list + if isinstance(x[condition_key], list) or isinstance(x[condition_key], tuple) and len(x[condition_key]) == 1: + conditioner_input = x[condition_key][0] + + else: + conditioner_input = x[condition_key] + + conditioner_inputs.append(conditioner_input) + + cond_output = conditioner(conditioner_inputs, device) + if len(cond_output) == 1: + output[key] = cond_output[0] + elif len(cond_output) == 2: + output[key] = cond_output + elif len(cond_output) == 4: + output[key] = cond_output[:2] + output[f'{key}_g'] = cond_output[2:] + + return output + +def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner: + """ + Create a MultiConditioner from a conditioning config dictionary + + Args: + config: the conditioning config dictionary + device: the device to put the conditioners on + """ + conditioners = {} + cond_dim = config["cond_dim"] + + default_keys = config.get("default_keys", {}) + + for conditioner_info in config["configs"]: + id = conditioner_info["id"] + + conditioner_type = conditioner_info["type"] + + conditioner_config = {"output_dim": cond_dim} + + conditioner_config.update(conditioner_info["config"]) + if conditioner_type == "t5": + conditioners[id] = T5Conditioner(**conditioner_config) + elif conditioner_type == "clap_text": + conditioners[id] = CLAPTextConditioner(**conditioner_config) + elif conditioner_type == "clip_text": + conditioners[id] = CLIPTextConditioner(**conditioner_config) + elif conditioner_type == "metaclip_text": + conditioners[id] = MetaCLIPTextConditioner(**conditioner_config) + elif conditioner_type == "clap_audio": + conditioners[id] = CLAPAudioConditioner(**conditioner_config) + elif conditioner_type == "cond_mlp": + conditioners[id] = Cond_MLP(**conditioner_config) + elif conditioner_type == "global_mlp": + conditioners[id] = Global_MLP(**conditioner_config) + elif conditioner_type == "sync_mlp": + conditioners[id] = Sync_MLP(**conditioner_config) + elif conditioner_type == "cond_mlp_1": + conditioners[id] = Cond_MLP_1(**conditioner_config) + elif conditioner_type == "cond_convmlp": + conditioners[id] = Cond_ConvMLP(**conditioner_config) + elif conditioner_type == "cond_mlp_global": + conditioners[id] = Cond_MLP_Global(**conditioner_config) + elif conditioner_type == "cond_mlp_global_1": + conditioners[id] = Cond_MLP_Global_1(**conditioner_config) + elif conditioner_type == "cond_mlp_global_2": + conditioners[id] = Cond_MLP_Global_2(**conditioner_config) + elif conditioner_type == "video_linear": + conditioners[id] = Video_Linear(**conditioner_config) + elif conditioner_type == "video_global": + conditioners[id] = Video_Global(**conditioner_config) + elif conditioner_type == "video_sync": + conditioners[id] = Video_Sync(**conditioner_config) + elif conditioner_type == "text_linear": + conditioners[id] = Text_Linear(**conditioner_config) + elif conditioner_type == "video_clip": + conditioners[id] = CLIPConditioner(**conditioner_config) + elif conditioner_type == "video_hiera": + conditioners[id] = VideoHieraConditioner(**conditioner_config) + elif conditioner_type == "meta_query": + from .meta_queries.model import MLLMInContext + conditioners[id] = MLLMInContext(**conditioner_config) + elif conditioner_type == "int": + conditioners[id] = IntConditioner(**conditioner_config) + elif conditioner_type == "number": + conditioners[id] = NumberConditioner(**conditioner_config) + elif conditioner_type == "phoneme": + conditioners[id] = PhonemeConditioner(**conditioner_config) + elif conditioner_type == "lut": + conditioners[id] = TokenizerLUTConditioner(**conditioner_config) + elif conditioner_type == "pretransform": + sample_rate = conditioner_config.pop("sample_rate", None) + assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners" + + pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate) + + if conditioner_config.get("pretransform_ckpt_path", None) is not None: + pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path"))) + + conditioners[id] = PretransformConditioner(pretransform, **conditioner_config) + elif conditioner_type == "mm_unchang": + conditioners[id] = mm_unchang(**conditioner_config) + else: + raise ValueError(f"Unknown conditioner type: {conditioner_type}") + + return MultiConditioner(conditioners, default_keys=default_keys) \ No newline at end of file diff --git a/ThinkSound/models/diffusion.py b/ThinkSound/models/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..62638843abe7f1661c26840a31d6529b03ae727d --- /dev/null +++ b/ThinkSound/models/diffusion.py @@ -0,0 +1,957 @@ +import torch +from torch import nn +from torch.nn import functional as F +from functools import partial +import numpy as np +import typing as tp + +from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes +from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config +from .dit import DiffusionTransformer +#from .mmdit import MMAudio +from .factory import create_pretransform_from_config +from .pretransforms import Pretransform +from ..inference.generation import generate_diffusion_cond + +from .adp import UNetCFG1d, UNet1d + +from time import time + +class Profiler: + + def __init__(self): + self.ticks = [[time(), None]] + + def tick(self, msg): + self.ticks.append([time(), msg]) + + def __repr__(self): + rep = 80 * "=" + "\n" + for i in range(1, len(self.ticks)): + msg = self.ticks[i][1] + ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] + rep += msg + f": {ellapsed*1000:.2f}ms\n" + rep += 80 * "=" + "\n\n\n" + return rep + +class DiffusionModel(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x, t, **kwargs): + raise NotImplementedError() + +class DiffusionModelWrapper(nn.Module): + def __init__( + self, + model: DiffusionModel, + io_channels, + sample_size, + sample_rate, + min_input_length, + pretransform: tp.Optional[Pretransform] = None, + ): + super().__init__() + self.io_channels = io_channels + self.sample_size = sample_size + self.sample_rate = sample_rate + self.min_input_length = min_input_length + + self.model = model + + if pretransform is not None: + self.pretransform = pretransform + else: + self.pretransform = None + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +class ConditionedDiffusionModel(nn.Module): + def __init__(self, + *args, + supports_cross_attention: bool = False, + supports_input_concat: bool = False, + supports_global_cond: bool = False, + supports_prepend_cond: bool = False, + **kwargs): + super().__init__(*args, **kwargs) + self.supports_cross_attention = supports_cross_attention + self.supports_input_concat = supports_input_concat + self.supports_global_cond = supports_global_cond + self.supports_prepend_cond = supports_prepend_cond + + def forward(self, + x: torch.Tensor, + t: torch.Tensor, + cross_attn_cond: torch.Tensor = None, + cross_attn_mask: torch.Tensor = None, + input_concat_cond: torch.Tensor = None, + global_embed: torch.Tensor = None, + prepend_cond: torch.Tensor = None, + prepend_cond_mask: torch.Tensor = None, + cfg_scale: float = 1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + **kwargs): + raise NotImplementedError() + +class ConditionedDiffusionModelWrapper(nn.Module): + """ + A diffusion model that takes in conditioning + """ + def __init__( + self, + model: ConditionedDiffusionModel, + conditioner: MultiConditioner, + io_channels, + sample_rate, + min_input_length: int, + diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", + zero_init: bool = False, + pretransform: tp.Optional[Pretransform] = None, + cross_attn_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [], + input_concat_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + add_cond_ids: tp.List[str] = [], + sync_cond_ids: tp.List[str] = [], + ): + super().__init__() + + self.model = model + self.conditioner = conditioner + self.io_channels = io_channels + self.sample_rate = sample_rate + self.diffusion_objective = diffusion_objective + self.pretransform = pretransform + self.cross_attn_cond_ids = cross_attn_cond_ids + self.global_cond_ids = global_cond_ids + self.input_concat_ids = input_concat_ids + self.prepend_cond_ids = prepend_cond_ids + self.add_cond_ids = add_cond_ids + self.sync_cond_ids = sync_cond_ids + self.min_input_length = min_input_length + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + if zero_init is True: + self.conditioner.apply(_basic_init) + self.model.model.initialize_weights() + + + def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False): + cross_attention_input = None + cross_attention_masks = None + global_cond = None + input_concat_cond = None + prepend_cond = None + prepend_cond_mask = None + add_input = None + sync_input = None + + if len(self.cross_attn_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + cross_attention_input = [] + cross_attention_masks = [] + + for key in self.cross_attn_cond_ids: + cross_attn_in, cross_attn_mask = conditioning_tensors[key] + + # Add sequence dimension if it's not there + if len(cross_attn_in.shape) == 2: + cross_attn_in = cross_attn_in.unsqueeze(1) + # cross_attn_mask = cross_attn_mask.unsqueeze(1) + + cross_attention_input.append(cross_attn_in) + cross_attention_masks.append(cross_attn_mask) + # import ipdb + # ipdb.set_trace() + cross_attention_input = torch.cat(cross_attention_input, dim=1) + cross_attention_masks = torch.cat(cross_attention_masks, dim=1) + + if len(self.add_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + add_input = [] + + for key in self.add_cond_ids: + add_in = conditioning_tensors[key][0] + + # Add sequence dimension if it's not there + if len(add_in.shape) == 2: + add_in = add_in.unsqueeze(1) + # add_in = add_in.transpose(1,2) + # add_in = F.interpolate(add_in, (194, ), mode='linear', align_corners=False) + # add_in = add_in.transpose(1,2) + add_input.append(add_in) + + add_input = torch.cat(add_input, dim=2) + + if len(self.sync_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + sync_input = [] + + for key in self.sync_cond_ids: + sync_in = conditioning_tensors[key][0] + + # Add sequence dimension if it's not there + if len(sync_in.shape) == 2: + sync_in = sync_in.unsqueeze(1) + sync_input.append(sync_in) + + sync_input = torch.cat(sync_input, dim=2) + + if len(self.global_cond_ids) > 0: + # Concatenate all global conditioning inputs over the channel dimension + # Assumes that the global conditioning inputs are of shape (batch, channels) + global_conds = [] + for key in self.global_cond_ids: + global_cond_input = conditioning_tensors[key][0] + if len(global_cond_input.shape) == 2: + global_cond_input = global_cond_input.unsqueeze(1) + global_conds.append(global_cond_input) + + # # Concatenate over the channel dimension + # if global_conds[0].shape[-1] == 768: + # global_cond = torch.cat(global_conds, dim=-1) + # else: + # global_cond = sum(global_conds) + global_cond = sum(global_conds) + # global_cond = torch.cat(global_conds, dim=-1) + + if len(global_cond.shape) == 3: + global_cond = global_cond.squeeze(1) + + if len(self.input_concat_ids) > 0: + # Concatenate all input concat conditioning inputs over the channel dimension + # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq) + input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1) + + if len(self.prepend_cond_ids) > 0: + # Concatenate all prepend conditioning inputs over the sequence dimension + # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) + prepend_conds = [] + prepend_cond_masks = [] + + for key in self.prepend_cond_ids: + prepend_cond_input, prepend_cond_mask = conditioning_tensors[key] + if len(prepend_cond_input.shape) == 2: + prepend_cond_input = prepend_cond_input.unsqueeze(1) + prepend_conds.append(prepend_cond_input) + prepend_cond_masks.append(prepend_cond_mask) + + prepend_cond = torch.cat(prepend_conds, dim=1) + prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1) + + if negative: + return { + "negative_cross_attn_cond": cross_attention_input, + "negative_cross_attn_mask": cross_attention_masks, + "negative_global_cond": global_cond, + "negative_input_concat_cond": input_concat_cond + } + else: + return { + "cross_attn_cond": cross_attention_input, + "cross_attn_mask": cross_attention_masks, + "global_cond": global_cond, + "input_concat_cond": input_concat_cond, + "prepend_cond": prepend_cond, + "prepend_cond_mask": prepend_cond_mask, + "add_cond": add_input, + "sync_cond": sync_input + } + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs): + return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs) + + def generate(self, *args, **kwargs): + return generate_diffusion_cond(self, *args, **kwargs) + +class UNetCFG1DWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True) + + self.model = UNetCFG1d(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + cross_attn_cond=None, + cross_attn_mask=None, + input_concat_cond=None, + global_cond=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + **kwargs): + p = Profiler() + + p.tick("start") + + channels_list = None + if input_concat_cond is not None: + channels_list = [input_concat_cond] + + outputs = self.model( + x, + t, + embedding=cross_attn_cond, + embedding_mask=cross_attn_mask, + features=global_cond, + channels_list=channels_list, + embedding_scale=cfg_scale, + embedding_mask_proba=cfg_dropout_prob, + batch_cfg=batch_cfg, + rescale_cfg=rescale_cfg, + negative_embedding=negative_cross_attn_cond, + negative_embedding_mask=negative_cross_attn_mask, + **kwargs) + + p.tick("UNetCFG1D forward") + + #print(f"Profiler: {p}") + return outputs + +class UNet1DCondWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True) + + self.model = UNet1d(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + input_concat_cond=None, + global_cond=None, + cross_attn_cond=None, + cross_attn_mask=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + **kwargs): + + channels_list = None + if input_concat_cond is not None: + + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + + channels_list = [input_concat_cond] + + outputs = self.model( + x, + t, + features=global_cond, + channels_list=channels_list, + **kwargs) + + return outputs + +class UNet1DUncondWrapper(DiffusionModel): + def __init__( + self, + in_channels, + *args, + **kwargs + ): + super().__init__() + + self.model = UNet1d(in_channels=in_channels, *args, **kwargs) + + self.io_channels = in_channels + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +class DAU1DCondWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True) + + self.model = DiffusionAttnUnet1D(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + input_concat_cond=None, + cross_attn_cond=None, + cross_attn_mask=None, + global_cond=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + prepend_cond=None, + **kwargs): + + return self.model(x, t, cond = input_concat_cond) + +class DiffusionAttnUnet1D(nn.Module): + def __init__( + self, + io_channels = 2, + depth=14, + n_attn_layers = 6, + channels = [128, 128, 256, 256] + [512] * 10, + cond_dim = 0, + cond_noise_aug = False, + kernel_size = 5, + learned_resample = False, + strides = [2] * 13, + conv_bias = True, + use_snake = False + ): + super().__init__() + + self.cond_noise_aug = cond_noise_aug + + self.io_channels = io_channels + + if self.cond_noise_aug: + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.timestep_embed = FourierFeatures(1, 16) + + attn_layer = depth - n_attn_layers + + strides = [1] + strides + + block = nn.Identity() + + conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake) + + for i in range(depth, 0, -1): + c = channels[i - 1] + stride = strides[i-1] + if stride > 2 and not learned_resample: + raise ValueError("Must have stride 2 without learned resampling") + + if i > 1: + c_prev = channels[i - 2] + add_attn = i >= attn_layer and n_attn_layers > 0 + block = SkipBlock( + Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"), + conv_block(c_prev, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + block, + conv_block(c * 2 if i != depth else c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c_prev), + SelfAttention1d(c_prev, c_prev // + 32) if add_attn else nn.Identity(), + Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic") + ) + else: + cond_embed_dim = 16 if not self.cond_noise_aug else 32 + block = nn.Sequential( + conv_block((io_channels + cond_dim) + cond_embed_dim, c, c), + conv_block(c, c, c), + conv_block(c, c, c), + block, + conv_block(c * 2, c, c), + conv_block(c, c, c), + conv_block(c, c, io_channels, is_last=True), + ) + self.net = block + + with torch.no_grad(): + for param in self.net.parameters(): + param *= 0.5 + + def forward(self, x, t, cond=None, cond_aug_scale=None): + + timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape) + + inputs = [x, timestep_embed] + + if cond is not None: + if cond.shape[2] != x.shape[2]: + cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False) + + if self.cond_noise_aug: + # Get a random number between 0 and 1, uniformly sampled + if cond_aug_scale is None: + aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond) + else: + aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond) + + # Add noise to the conditioning signal + cond = cond + torch.randn_like(cond) * aug_level[:, None, None] + + # Get embedding for noise cond level, reusing timestamp_embed + aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape) + + inputs.append(aug_level_embed) + + inputs.append(cond) + + outputs = self.net(torch.cat(inputs, dim=1)) + + return outputs + +class DiTWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False) + + self.model = DiffusionTransformer(*args, **kwargs) + # with torch.no_grad(): + # for param in self.model.parameters(): + # param *= 0.5 + + def forward(self, + x, + t, + cross_attn_cond=None, + cross_attn_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + negative_input_concat_cond=None, + global_cond=None, + negative_global_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = True, + rescale_cfg: bool = False, + scale_phi: float = 0.0, + **kwargs): + + assert batch_cfg, "batch_cfg must be True for DiTWrapper" + #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper" + + return self.model( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_mask, + negative_cross_attn_cond=negative_cross_attn_cond, + negative_cross_attn_mask=negative_cross_attn_mask, + input_concat_cond=input_concat_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + cfg_scale=cfg_scale, + cfg_dropout_prob=cfg_dropout_prob, + scale_phi=scale_phi, + global_embed=global_cond, + **kwargs) + +class MMDiTWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False) + + self.model = MMAudio(*args, **kwargs) + + # with torch.no_grad(): + # for param in self.model.parameters(): + # param *= 0.5 + + def forward(self, + x, + t, + clip_f, + sync_f, + text_f, + inpaint_masked_input=None, + t5_features=None, + metaclip_global_text_features=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = True, + rescale_cfg: bool = False, + scale_phi: float = 0.0, + **kwargs): + + # breakpoint() + assert batch_cfg, "batch_cfg must be True for DiTWrapper" + #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper" + + return self.model( + latent=x, + t=t, + clip_f=clip_f, + sync_f=sync_f, + text_f=text_f, + inpaint_masked_input=inpaint_masked_input, + t5_features=t5_features, + metaclip_global_text_features=metaclip_global_text_features, + cfg_scale=cfg_scale, + cfg_dropout_prob=cfg_dropout_prob, + scale_phi=scale_phi, + **kwargs) + +class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel): + """ + A diffusion model that takes in conditioning + """ + def __init__( + self, + model, + conditioner: MultiConditioner, + io_channels, + sample_rate, + min_input_length: int, + diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", + pretransform: tp.Optional[Pretransform] = None, + cross_attn_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [], + input_concat_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + add_cond_ids: tp.List[str] = [], + mm_cond_ids: tp.List[str] = [], + ): + super().__init__() + + self.model = model + self.conditioner = conditioner + self.io_channels = io_channels + self.sample_rate = sample_rate + self.diffusion_objective = diffusion_objective + self.pretransform = pretransform + self.cross_attn_cond_ids = cross_attn_cond_ids + self.global_cond_ids = global_cond_ids + self.input_concat_ids = input_concat_ids + self.prepend_cond_ids = prepend_cond_ids + self.add_cond_ids = add_cond_ids + self.min_input_length = min_input_length + self.mm_cond_ids = mm_cond_ids + + assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper" + assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper" + assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper" + assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper" + assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper" + assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper" + assert "metaclip_features" in self.mm_cond_ids, "clip_f must be specified in mm_cond_ids for MMDiTWrapper" + assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper" + assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper" + # assert len(self.mm_cond_ids) == 3, "mm_cond_ids must be clip_f sync_f text_f for MMDiTWrapper" + + def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False): + assert negative == False, "negative conditioning is not supported for MMDiTWrapper" + cross_attention_input = None + cross_attention_masks = None + global_cond = None + input_concat_cond = None + prepend_cond = None + prepend_cond_mask = None + add_input = None + inpaint_masked_input = None + t5_features = None + metaclip_global_text_features = None + clip_f = conditioning_tensors["metaclip_features"] + sync_f = conditioning_tensors["sync_features"] + text_f = conditioning_tensors["metaclip_text_features"] + if 'inpaint_masked_input' in conditioning_tensors.keys(): + inpaint_masked_input = conditioning_tensors["inpaint_masked_input"] + if 't5_features' in conditioning_tensors.keys(): + t5_features = conditioning_tensors["t5_features"] + if 'metaclip_global_text_features' in conditioning_tensors.keys(): + metaclip_global_text_features = conditioning_tensors["metaclip_global_text_features"] + return { + "clip_f": clip_f, + "sync_f": sync_f, + "text_f": text_f, + "inpaint_masked_input": inpaint_masked_input, + "t5_features": t5_features, + "metaclip_global_text_features": metaclip_global_text_features + } + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs): + # breakpoint() + # print(kwargs) + return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs) + + def generate(self, *args, **kwargs): + return generate_diffusion_cond(self, *args, **kwargs) + +class DiTUncondWrapper(DiffusionModel): + def __init__( + self, + io_channels, + *args, + **kwargs + ): + super().__init__() + + self.model = DiffusionTransformer(io_channels=io_channels, *args, **kwargs) + + self.io_channels = io_channels + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]): + diffusion_uncond_config = config["model"] + + model_type = diffusion_uncond_config.get('type', None) + + diffusion_config = diffusion_uncond_config.get('config', {}) + + assert model_type is not None, "Must specify model type in config" + + pretransform = diffusion_uncond_config.get("pretransform", None) + + sample_size = config.get("sample_size", None) + assert sample_size is not None, "Must specify sample size in config" + + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "Must specify sample rate in config" + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if model_type == 'DAU1d': + + model = DiffusionAttnUnet1D( + **diffusion_config + ) + + elif model_type == "adp_uncond_1d": + + model = UNet1DUncondWrapper( + **diffusion_config + ) + + elif model_type == "dit": + model = DiTUncondWrapper( + **diffusion_config + ) + + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + + return DiffusionModelWrapper(model, + io_channels=model.io_channels, + sample_size=sample_size, + sample_rate=sample_rate, + pretransform=pretransform, + min_input_length=min_input_length) + +def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]): + diffusion_uncond_config = config["model"] + + + diffusion_config = diffusion_uncond_config.get('diffusion', {}) + model_type = diffusion_config.get('type', None) + model_config = diffusion_config.get("config",{}) + assert model_type is not None, "Must specify model type in config" + + pretransform = diffusion_uncond_config.get("pretransform", None) + + sample_size = config.get("sample_size", None) + assert sample_size is not None, "Must specify sample size in config" + + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "Must specify sample rate in config" + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if model_type == 'DAU1d': + + model = DiffusionAttnUnet1D( + **model_config + ) + + elif model_type == "adp_uncond_1d": + + model = UNet1DUncondWrapper( + io_channels = io_channels, + **model_config + ) + elif model_type == "dit": + model = DiTUncondWrapper( + **model_config + ) + + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + + return DiffusionModelWrapper(model, + io_channels=model.io_channels, + sample_size=sample_size, + sample_rate=sample_rate, + pretransform=pretransform, + min_input_length=min_input_length) + +def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): + + model_config = config["model"] + + model_type = config["model_type"] + + diffusion_config = model_config.get('diffusion', None) + assert diffusion_config is not None, "Must specify diffusion config" + + diffusion_model_type = diffusion_config.get('type', None) + assert diffusion_model_type is not None, "Must specify diffusion model type" + + diffusion_model_config = diffusion_config.get('config', None) + assert diffusion_model_config is not None, "Must specify diffusion model config" + + if diffusion_model_type == 'adp_cfg_1d': + diffusion_model = UNetCFG1DWrapper(**diffusion_model_config) + elif diffusion_model_type == 'adp_1d': + diffusion_model = UNet1DCondWrapper(**diffusion_model_config) + elif diffusion_model_type == 'dit': + diffusion_model = DiTWrapper(**diffusion_model_config) + elif diffusion_model_type == 'mmdit': + diffusion_model = MMDiTWrapper(**diffusion_model_config) + + io_channels = model_config.get('io_channels', None) + assert io_channels is not None, "Must specify io_channels in model config" + + sample_rate = config.get('sample_rate', None) + assert sample_rate is not None, "Must specify sample_rate in config" + + diffusion_objective = diffusion_config.get('diffusion_objective', 'v') + + conditioning_config = model_config.get('conditioning', None) + + conditioner = None + if conditioning_config is not None: + conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) + + cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', []) + add_cond_ids = diffusion_config.get('add_cond_ids', []) + sync_cond_ids = diffusion_config.get('sync_cond_ids', []) + global_cond_ids = diffusion_config.get('global_cond_ids', []) + input_concat_ids = diffusion_config.get('input_concat_ids', []) + prepend_cond_ids = diffusion_config.get('prepend_cond_ids', []) + mm_cond_ids = diffusion_config.get('mm_cond_ids', []) + zero_init = diffusion_config.get('zero_init', False) + pretransform = model_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d": + min_input_length *= np.prod(diffusion_model_config["factors"]) + elif diffusion_model_type == "dit": + min_input_length *= diffusion_model.model.patch_size + + # Get the proper wrapper class + + extra_kwargs = {} + + if model_type == "mm_diffusion_cond": + wrapper_fn = MMConditionedDiffusionModelWrapper + extra_kwargs["diffusion_objective"] = diffusion_objective + extra_kwargs["mm_cond_ids"] = mm_cond_ids + + if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill': + wrapper_fn = ConditionedDiffusionModelWrapper + extra_kwargs["diffusion_objective"] = diffusion_objective + + elif model_type == "diffusion_prior": + prior_type = model_config.get("prior_type", None) + assert prior_type is not None, "Must specify prior_type in diffusion prior model config" + + if prior_type == "mono_stereo": + from .diffusion_prior import MonoToStereoDiffusionPrior + wrapper_fn = MonoToStereoDiffusionPrior + + return wrapper_fn( + diffusion_model, + conditioner, + min_input_length=min_input_length, + sample_rate=sample_rate, + cross_attn_cond_ids=cross_attention_ids, + global_cond_ids=global_cond_ids, + input_concat_ids=input_concat_ids, + prepend_cond_ids=prepend_cond_ids, + add_cond_ids=add_cond_ids, + sync_cond_ids=sync_cond_ids, + pretransform=pretransform, + io_channels=io_channels, + zero_init=zero_init, + **extra_kwargs + ) \ No newline at end of file diff --git a/ThinkSound/models/diffusion_prior.py b/ThinkSound/models/diffusion_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb15258d7656fb85ee763910dc9b500331de603 --- /dev/null +++ b/ThinkSound/models/diffusion_prior.py @@ -0,0 +1,82 @@ +from enum import Enum +import typing as tp + +from .diffusion import ConditionedDiffusionModelWrapper +from ..inference.generation import generate_diffusion_cond +from ..inference.utils import prepare_audio + +import torch +from torch.nn import functional as F +from torchaudio import transforms as T + +# Define prior types enum +class PriorType(Enum): + MonoToStereo = 1 + +class DiffusionPrior(ConditionedDiffusionModelWrapper): + def __init__(self, *args, prior_type: PriorType=None, **kwargs): + super().__init__(*args, **kwargs) + self.prior_type = prior_type + +class MonoToStereoDiffusionPrior(DiffusionPrior): + def __init__(self, *args, **kwargs): + super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs) + + def stereoize( + self, + audio: torch.Tensor, # (batch, channels, time) + video: torch.Tensor, + in_sr: int, + steps: int, + sampler_kwargs: dict = {}, + ): + """ + Generate stereo audio from mono audio using a pre-trained diffusion prior + + Args: + audio: The mono audio to convert to stereo + in_sr: The sample rate of the input audio + steps: The number of diffusion steps to run + sampler_kwargs: Keyword arguments to pass to the diffusion sampler + """ + + device = audio.device + + sample_rate = self.sample_rate + + # Resample input audio if necessary + if in_sr != sample_rate: + resample_tf = T.Resample(in_sr, sample_rate).to(audio.device) + audio = resample_tf(audio) + + audio_length = audio.shape[-1] + + # # Pad input audio to be compatible with the model + # min_length = self.min_input_length + # padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length + + # # Pad input audio to be compatible with the model + # if padded_input_length > audio_length: + # audio = F.pad(audio, (0, padded_input_length - audio_length)) + + # Make audio mono, duplicate to stereo + dual_mono = audio.mean(1, keepdim=True).repeat(1, 2, 1) + + if self.pretransform is not None: + dual_mono = self.pretransform.encode(dual_mono) + + conditioning = self.conditioner([{'video':video}], device) + # Return fake stereo audio + conditioning["source"] = [dual_mono] + stereo_audio = generate_diffusion_cond( + self, + conditioning_tensors=conditioning, + steps=steps, + sample_size=audio_length, + sample_rate=sample_rate, + device=device, + cfg_scale=1, + **sampler_kwargs, + ) + + return stereo_audio \ No newline at end of file diff --git a/ThinkSound/models/discriminators.py b/ThinkSound/models/discriminators.py new file mode 100644 index 0000000000000000000000000000000000000000..b593168df965bb1f57881ea79edbc2f66478c6c2 --- /dev/null +++ b/ThinkSound/models/discriminators.py @@ -0,0 +1,546 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from functools import reduce +import typing as tp +from einops import rearrange +from audiotools import AudioSignal, STFTParams +from dac.model.discriminator import WNConv1d, WNConv2d + +def get_hinge_losses(score_real, score_fake): + gen_loss = -score_fake.mean() + dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean() + return dis_loss, gen_loss + +class EncodecDiscriminator(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + from encodec.msstftd import MultiScaleSTFTDiscriminator + + self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs) + + def forward(self, x): + logits, features = self.discriminators(x) + return logits, features + + def loss(self, x, y): + feature_matching_distance = 0. + logits_true, feature_true = self.forward(x) + logits_fake, feature_fake = self.forward(y) + + dis_loss = torch.tensor(0.) + adv_loss = torch.tensor(0.) + + for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)): + + feature_matching_distance = feature_matching_distance + sum( + map( + lambda x, y: abs(x - y).mean(), + scale_true, + scale_fake, + )) / len(scale_true) + + _dis, _adv = get_hinge_losses( + logits_true[i], + logits_fake[i], + ) + + dis_loss = dis_loss + _dis + adv_loss = adv_loss + _adv + + return dis_loss, adv_loss, feature_matching_distance + +# Discriminators from oobleck + +IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]] + +TensorDict = tp.Dict[str, torch.Tensor] + +class SharedDiscriminatorConvNet(nn.Module): + + def __init__( + self, + in_size: int, + convolution: tp.Union[nn.Conv1d, nn.Conv2d], + out_size: int = 1, + capacity: int = 32, + n_layers: int = 4, + kernel_size: int = 15, + stride: int = 4, + activation: tp.Callable[[], nn.Module] = lambda: nn.SiLU(), + normalization: tp.Callable[[nn.Module], nn.Module] = torch.nn.utils.weight_norm, + ) -> None: + super().__init__() + channels = [in_size] + channels += list(capacity * 2**np.arange(n_layers)) + + if isinstance(stride, int): + stride = n_layers * [stride] + + net = [] + for i in range(n_layers): + if isinstance(kernel_size, int): + pad = kernel_size // 2 + s = stride[i] + else: + pad = kernel_size[0] // 2 + s = (stride[i], 1) + + net.append( + normalization( + convolution( + channels[i], + channels[i + 1], + kernel_size, + stride=s, + padding=pad, + ))) + net.append(activation()) + + net.append(convolution(channels[-1], out_size, 1)) + + self.net = nn.ModuleList(net) + + def forward(self, x) -> IndividualDiscriminatorOut: + features = [] + for layer in self.net: + x = layer(x) + if isinstance(layer, nn.modules.conv._ConvNd): + features.append(x) + score = x.reshape(x.shape[0], -1).mean(-1) + return score, features + + +class MultiScaleDiscriminator(nn.Module): + + def __init__(self, + in_channels: int, + n_scales: int, + **conv_kwargs) -> None: + super().__init__() + layers = [] + for _ in range(n_scales): + layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv1d, **conv_kwargs)) + self.layers = nn.ModuleList(layers) + + def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: + score = 0 + features = [] + for layer in self.layers: + s, f = layer(x) + score = score + s + features.extend(f) + x = nn.functional.avg_pool1d(x, 2) + return score, features + +class MultiPeriodDiscriminator(nn.Module): + + def __init__(self, + in_channels: int, + periods: tp.Sequence[int], + **conv_kwargs) -> None: + super().__init__() + layers = [] + self.periods = periods + + for _ in periods: + layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv2d, **conv_kwargs)) + + self.layers = nn.ModuleList(layers) + + def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: + score = 0 + features = [] + for layer, n in zip(self.layers, self.periods): + s, f = layer(self.fold(x, n)) + score = score + s + features.extend(f) + return score, features + + def fold(self, x: torch.Tensor, n: int) -> torch.Tensor: + pad = (n - (x.shape[-1] % n)) % n + x = nn.functional.pad(x, (0, pad)) + return x.reshape(*x.shape[:2], -1, n) + + +class MultiDiscriminator(nn.Module): + """ + Individual discriminators should take a single tensor as input (NxB C T) and + return a tuple composed of a score tensor (NxB) and a Sequence of Features + Sequence[NxB C' T']. + """ + + def __init__(self, discriminator_list: tp.Sequence[nn.Module], + keys: tp.Sequence[str]) -> None: + super().__init__() + self.discriminators = nn.ModuleList(discriminator_list) + self.keys = keys + + def unpack_tensor_to_dict(self, features: torch.Tensor) -> TensorDict: + features = features.chunk(len(self.keys), 0) + return {k: features[i] for i, k in enumerate(self.keys)} + + @staticmethod + def concat_dicts(dict_a, dict_b): + out_dict = {} + keys = set(list(dict_a.keys()) + list(dict_b.keys())) + for k in keys: + out_dict[k] = [] + if k in dict_a: + if isinstance(dict_a[k], list): + out_dict[k].extend(dict_a[k]) + else: + out_dict[k].append(dict_a[k]) + if k in dict_b: + if isinstance(dict_b[k], list): + out_dict[k].extend(dict_b[k]) + else: + out_dict[k].append(dict_b[k]) + return out_dict + + @staticmethod + def sum_dicts(dict_a, dict_b): + out_dict = {} + keys = set(list(dict_a.keys()) + list(dict_b.keys())) + for k in keys: + out_dict[k] = 0. + if k in dict_a: + out_dict[k] = out_dict[k] + dict_a[k] + if k in dict_b: + out_dict[k] = out_dict[k] + dict_b[k] + return out_dict + + def forward(self, inputs: TensorDict) -> TensorDict: + discriminator_input = torch.cat([inputs[k] for k in self.keys], 0) + all_scores = [] + all_features = [] + + for discriminator in self.discriminators: + score, features = discriminator(discriminator_input) + scores = self.unpack_tensor_to_dict(score) + scores = {f"score_{k}": scores[k] for k in scores.keys()} + all_scores.append(scores) + + features = map(self.unpack_tensor_to_dict, features) + features = reduce(self.concat_dicts, features) + features = {f"features_{k}": features[k] for k in features.keys()} + all_features.append(features) + + all_scores = reduce(self.sum_dicts, all_scores) + all_features = reduce(self.concat_dicts, all_features) + + inputs.update(all_scores) + inputs.update(all_features) + + return inputs + +class OobleckDiscriminator(nn.Module): + + def __init__( + self, + in_channels=1, + ): + super().__init__() + + multi_scale_discriminator = MultiScaleDiscriminator( + in_channels=in_channels, + n_scales=3, + ) + + multi_period_discriminator = MultiPeriodDiscriminator( + in_channels=in_channels, + periods=[2, 3, 5, 7, 11] + ) + + # multi_resolution_discriminator = MultiScaleSTFTDiscriminator( + # filters=32, + # in_channels = in_channels, + # out_channels = 1, + # n_ffts = [2048, 1024, 512, 256, 128], + # hop_lengths = [512, 256, 128, 64, 32], + # win_lengths = [2048, 1024, 512, 256, 128] + # ) + + self.multi_discriminator = MultiDiscriminator( + [multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator], + ["reals", "fakes"] + ) + + def loss(self, reals, fakes): + inputs = { + "reals": reals, + "fakes": fakes, + } + + inputs = self.multi_discriminator(inputs) + + scores_real = inputs["score_reals"] + scores_fake = inputs["score_fakes"] + + features_real = inputs["features_reals"] + features_fake = inputs["features_fakes"] + + dis_loss, gen_loss = get_hinge_losses(scores_real, scores_fake) + + feature_matching_distance = torch.tensor(0.) + + for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)): + + feature_matching_distance = feature_matching_distance + sum( + map( + lambda real, fake: abs(real - fake).mean(), + scale_real, + scale_fake, + )) / len(scale_real) + + return dis_loss, gen_loss, feature_matching_distance + + +## Discriminators from Descript Audio Codec repo +## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt +class MPD(nn.Module): + def __init__(self, period, channels=1): + super().__init__() + + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(channels, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 44100, channels=1): + super().__init__() + + self.convs = nn.ModuleList( + [ + WNConv1d(channels, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, + window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 44100, + bands: list = BANDS, + channels: int = 1 + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. + """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + self.channels = channels + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = torch.view_as_real(x.stft()) + x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels) + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class DACDiscriminator(nn.Module): + def __init__( + self, + channels: int = 1, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. + periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super().__init__() + discs = [] + discs += [MPD(p, channels=channels) for p in periods] + discs += [MSD(r, sample_rate=sample_rate, channels=channels) for r in rates] + discs += [MRD(f, sample_rate=sample_rate, bands=bands, channels=channels) for f in fft_sizes] + self.discriminators = nn.ModuleList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + +class DACGANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, **discriminator_kwargs): + super().__init__() + self.discriminator = DACDiscriminator(**discriminator_kwargs) + + def forward(self, fake, real): + d_fake = self.discriminator(fake) + d_real = self.discriminator(real) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature + + def loss(self, fake, real): + gen_loss, feature_distance = self.generator_loss(fake, real) + dis_loss = self.discriminator_loss(fake, real) + + return dis_loss, gen_loss, feature_distance \ No newline at end of file diff --git a/ThinkSound/models/dit (1).py b/ThinkSound/models/dit (1).py new file mode 100644 index 0000000000000000000000000000000000000000..1b9a4095fcf7286f84fd64b017e5f889061d3d05 --- /dev/null +++ b/ThinkSound/models/dit (1).py @@ -0,0 +1,430 @@ +import typing as tp +import math +import torch + +from einops import rearrange +from torch import nn +from torch.nn import functional as F + +from .blocks import FourierFeatures +from .transformer import ContinuousTransformer + +class DiffusionTransformer(nn.Module): + def __init__(self, + io_channels=32, + patch_size=1, + embed_dim=768, + cond_token_dim=0, + project_cond_tokens=True, + global_cond_dim=0, + project_global_cond=True, + input_concat_dim=0, + prepend_cond_dim=0, + depth=12, + num_heads=8, + transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer", + global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", + timestep_cond_type: tp.Literal["global", "input_concat"] = "global", + timestep_embed_dim=None, + diffusion_objective: tp.Literal["v", "rectified_flow", "rf_denoiser"] = "v", + **kwargs): + + super().__init__() + + self.cond_token_dim = cond_token_dim + + # Timestep embeddings + self.timestep_cond_type = timestep_cond_type + + timestep_features_dim = 256 + + self.timestep_features = FourierFeatures(1, timestep_features_dim) + + if timestep_cond_type == "global": + timestep_embed_dim = embed_dim + elif timestep_cond_type == "input_concat": + assert timestep_embed_dim is not None, "timestep_embed_dim must be specified if timestep_cond_type is input_concat" + input_concat_dim += timestep_embed_dim + + self.to_timestep_embed = nn.Sequential( + nn.Linear(timestep_features_dim, timestep_embed_dim, bias=True), + nn.SiLU(), + nn.Linear(timestep_embed_dim, timestep_embed_dim, bias=True), + ) + + self.diffusion_objective = diffusion_objective + + if cond_token_dim > 0: + # Conditioning tokens + + cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim + self.to_cond_embed = nn.Sequential( + nn.Linear(cond_token_dim, cond_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(cond_embed_dim, cond_embed_dim, bias=False) + ) + else: + cond_embed_dim = 0 + + if global_cond_dim > 0: + # Global conditioning + global_embed_dim = global_cond_dim if not project_global_cond else embed_dim + self.to_global_embed = nn.Sequential( + nn.Linear(global_cond_dim, global_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(global_embed_dim, global_embed_dim, bias=False) + ) + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + self.input_concat_dim = input_concat_dim + + dim_in = io_channels + self.input_concat_dim + + self.patch_size = patch_size + + # Transformer + + self.transformer_type = transformer_type + + self.global_cond_type = global_cond_type + + if self.transformer_type == "continuous_transformer": + + global_dim = None + + if self.global_cond_type == "adaLN": + # The global conditioning is projected to the embed_dim already at this point + global_dim = embed_dim + + self.transformer = ContinuousTransformer( + dim=embed_dim, + depth=depth, + dim_heads=embed_dim // num_heads, + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + cross_attend = cond_token_dim > 0, + cond_token_dim = cond_embed_dim, + global_cond_dim=global_dim, + **kwargs + ) + else: + raise ValueError(f"Unknown transformer type: {self.transformer_type}") + + self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False) + nn.init.zeros_(self.preprocess_conv.weight) + self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False) + nn.init.zeros_(self.postprocess_conv.weight) + + def _forward( + self, + x, + t, + mask=None, + cross_attn_cond=None, + cross_attn_cond_mask=None, + input_concat_cond=None, + global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + return_info=False, + exit_layer_ix=None, + **kwargs): + + if cross_attn_cond is not None: + cross_attn_cond = self.to_cond_embed(cross_attn_cond) + + if global_embed is not None: + # Project the global conditioning to the embedding dimension + global_embed = self.to_global_embed(global_embed) + + prepend_inputs = None + prepend_mask = None + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + + prepend_inputs = prepend_cond + if prepend_cond_mask is not None: + prepend_mask = prepend_cond_mask + + prepend_length = prepend_cond.shape[1] + + if input_concat_cond is not None: + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + + x = torch.cat([x, input_concat_cond], dim=1) + + # Get the batch of timestep embeddings + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) + + # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists + + if self.timestep_cond_type == "global": + if global_embed is not None: + global_embed = global_embed + timestep_embed + else: + global_embed = timestep_embed + elif self.timestep_cond_type == "input_concat": + x = torch.cat([x, timestep_embed.unsqueeze(1).expand(-1, -1, x.shape[2])], dim=1) + + # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer + if self.global_cond_type == "prepend" and global_embed is not None: + if prepend_inputs is None: + # Prepend inputs are just the global embed, and the mask is all ones + prepend_inputs = global_embed.unsqueeze(1) + prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool) + else: + # Prepend inputs are the prepend conditioning + the global embed + prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) + prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1) + + prepend_length = prepend_inputs.shape[1] + + x = self.preprocess_conv(x) + x + + x = rearrange(x, "b c t -> b t c") + + extra_args = {} + + if self.global_cond_type == "adaLN": + extra_args["global_cond"] = global_embed + + if self.patch_size > 1: + x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) + + if self.transformer_type == "continuous_transformer": + # Masks not currently implemented for continuous transformer + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, return_info=return_info, exit_layer_ix=exit_layer_ix, **extra_args, **kwargs) + + if return_info: + output, info = output + + # Avoid postprocessing on early exit + if exit_layer_ix is not None: + if return_info: + return output, info + else: + return output + + output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:] + + if self.patch_size > 1: + output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) + + output = self.postprocess_conv(output) + output + + if return_info: + return output, info + + return output + + def forward( + self, + x, + t, + cross_attn_cond=None, + cross_attn_cond_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + global_embed=None, + negative_global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob=0.0, + cfg_interval = (0, 1), + causal=False, + scale_phi=0.0, + mask=None, + return_info=False, + exit_layer_ix=None, + **kwargs): + + assert causal == False, "Causal mode is not supported for DiffusionTransformer" + + model_dtype = next(self.parameters()).dtype + + x = x.to(model_dtype) + + t = t.to(model_dtype) + + if cross_attn_cond is not None: + cross_attn_cond = cross_attn_cond.to(model_dtype) + + if negative_cross_attn_cond is not None: + negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype) + + if input_concat_cond is not None: + input_concat_cond = input_concat_cond.to(model_dtype) + + if global_embed is not None: + global_embed = global_embed.to(model_dtype) + + if negative_global_embed is not None: + negative_global_embed = negative_global_embed.to(model_dtype) + + if prepend_cond is not None: + prepend_cond = prepend_cond.to(model_dtype) + + if cross_attn_cond_mask is not None: + cross_attn_cond_mask = cross_attn_cond_mask.bool() + + cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.bool() + + # Early exit bypasses CFG processing + if exit_layer_ix is not None: + assert self.transformer_type == "continuous_transformer", "exit_layer_ix is only supported for continuous_transformer" + return self._forward( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_cond_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + mask=mask, + return_info=return_info, + exit_layer_ix=exit_layer_ix, + **kwargs + ) + + # CFG dropout + if cfg_dropout_prob > 0.0 and cfg_scale == 1.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if self.diffusion_objective == "v": + sigma = torch.sin(t * math.pi / 2) + alpha = torch.cos(t * math.pi / 2) + elif self.diffusion_objective in ["rectified_flow", "rf_denoiser"]: + sigma = t + + if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None) and (cfg_interval[0] <= sigma[0] <= cfg_interval[1]): + + # Classifier-free guidance + # Concatenate conditioned and unconditioned inputs on the batch dimension + batch_inputs = torch.cat([x, x], dim=0) + batch_timestep = torch.cat([t, t], dim=0) + + if global_embed is not None: + batch_global_cond = torch.cat([global_embed, global_embed], dim=0) + else: + batch_global_cond = None + + if input_concat_cond is not None: + batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0) + else: + batch_input_concat_cond = None + + batch_cond = None + batch_cond_masks = None + + # Handle CFG for cross-attention conditioning + if cross_attn_cond is not None: + + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning + if negative_cross_attn_cond is not None: + + # If there's a negative cross-attention mask, set the masked tokens to the null embed + if negative_cross_attn_mask is not None: + negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) + + negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed) + + batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0) + + else: + batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if cross_attn_cond_mask is not None: + batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0) + + batch_prepend_cond = None + batch_prepend_cond_mask = None + + if prepend_cond is not None: + + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + + if mask is not None: + batch_masks = torch.cat([mask, mask], dim=0) + else: + batch_masks = None + + batch_output = self._forward( + batch_inputs, + batch_timestep, + cross_attn_cond=batch_cond, + cross_attn_cond_mask=batch_cond_masks, + mask = batch_masks, + input_concat_cond=batch_input_concat_cond, + global_embed = batch_global_cond, + prepend_cond = batch_prepend_cond, + prepend_cond_mask = batch_prepend_cond_mask, + return_info = return_info, + **kwargs) + + if return_info: + batch_output, info = batch_output + + cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0) + cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale + + # CFG Rescale + if scale_phi != 0.0: + cond_out_std = cond_output.std(dim=1, keepdim=True) + out_cfg_std = cfg_output.std(dim=1, keepdim=True) + output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output + else: + output = cfg_output + + if return_info: + info["uncond_output"] = uncond_output + return output, info + + return output + + else: + return self._forward( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_cond_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + mask=mask, + return_info=return_info, + **kwargs + ) \ No newline at end of file diff --git a/ThinkSound/models/dit.py b/ThinkSound/models/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..bb1e96ba044c9520a6ba5fb5f32848d5aa342b7e --- /dev/null +++ b/ThinkSound/models/dit.py @@ -0,0 +1,547 @@ +import typing as tp +import math +import torch +# from beartype.typing import Tuple +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP +from .blocks import FourierFeatures +from .transformer import ContinuousTransformer +from .utils import mask_from_frac_lengths, resample + +class DiffusionTransformer(nn.Module): + def __init__(self, + io_channels=32, + patch_size=1, + embed_dim=768, + cond_token_dim=0, + project_cond_tokens=True, + global_cond_dim=0, + project_global_cond=True, + input_concat_dim=0, + prepend_cond_dim=0, + cond_ctx_dim=0, + depth=12, + num_heads=8, + transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer", + global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", + timestep_cond_type: tp.Literal["global", "input_concat"] = "global", + add_token_dim=0, + sync_token_dim=0, + use_mlp=False, + use_zero_init=False, + **kwargs): + + super().__init__() + + self.cond_token_dim = cond_token_dim + + # Timestep embeddings + timestep_features_dim = 256 + # Timestep embeddings + self.timestep_cond_type = timestep_cond_type + self.timestep_features = FourierFeatures(1, timestep_features_dim) + + if timestep_cond_type == "global": + timestep_embed_dim = embed_dim + elif timestep_cond_type == "input_concat": + assert timestep_embed_dim is not None, "timestep_embed_dim must be specified if timestep_cond_type is input_concat" + input_concat_dim += timestep_embed_dim + + self.to_timestep_embed = nn.Sequential( + nn.Linear(timestep_features_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + self.use_mlp = use_mlp + if cond_token_dim > 0: + # Conditioning tokens + cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim + self.to_cond_embed = nn.Sequential( + nn.Linear(cond_token_dim, cond_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(cond_embed_dim, cond_embed_dim, bias=False) + ) + else: + cond_embed_dim = 0 + + if global_cond_dim > 0: + # Global conditioning + global_embed_dim = global_cond_dim if not project_global_cond else embed_dim + self.to_global_embed = nn.Sequential( + nn.Linear(global_cond_dim, global_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(global_embed_dim, global_embed_dim, bias=False) + ) + if add_token_dim > 0: + # Conditioning tokens + add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim + self.to_add_embed = nn.Sequential( + nn.Linear(add_token_dim, add_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(add_embed_dim, add_embed_dim, bias=False) + ) + else: + add_embed_dim = 0 + + if sync_token_dim > 0: + # Conditioning tokens + sync_embed_dim = sync_token_dim if not project_cond_tokens else embed_dim + self.to_sync_embed = nn.Sequential( + nn.Linear(sync_token_dim, sync_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(sync_embed_dim, sync_embed_dim, bias=False) + ) + else: + sync_embed_dim = 0 + + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + self.input_concat_dim = input_concat_dim + + dim_in = io_channels + self.input_concat_dim + + self.patch_size = patch_size + + # Transformer + + self.transformer_type = transformer_type + + self.empty_clip_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True) + self.empty_sync_feat = nn.Parameter(torch.zeros(1, embed_dim), requires_grad=True) + self.global_cond_type = global_cond_type + print("######################") + print(f'global type: {global_cond_type}') + print("######################") + if self.transformer_type == "continuous_transformer": + + global_dim = None + + if self.global_cond_type == "adaLN": + # The global conditioning is projected to the embed_dim already at this point + global_dim = embed_dim + + self.transformer = ContinuousTransformer( + dim=embed_dim, + depth=depth, + dim_heads=embed_dim // num_heads, + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + cross_attend = cond_token_dim > 0, + cond_token_dim = cond_embed_dim, + global_cond_dim=global_dim, + **kwargs + ) + else: + raise ValueError(f"Unknown transformer type: {self.transformer_type}") + + self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False) + self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False) + nn.init.zeros_(self.preprocess_conv.weight) + nn.init.zeros_(self.postprocess_conv.weight) + + + def initialize_weights(self): + print("######################") + print(f'Fine! You are using zero initialization!') + print("######################") + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + # if isinstance(module, nn.Conv1d): + # if module.bias is not None: + # nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.to_timestep_embed[0].weight, std=0.02) + nn.init.normal_(self.to_timestep_embed[2].weight, std=0.02) + + # Zero-out output layers: + if self.global_cond_type == "adaLN": + for block in self.transformer.layers: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + nn.init.constant_(self.empty_clip_feat, 0) + nn.init.constant_(self.empty_sync_feat, 0) + + def _forward( + self, + x, + t, + mask=None, + cross_attn_cond=None, + cross_attn_cond_mask=None, + input_concat_cond=None, + global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + add_cond=None, + add_masks=None, + sync_cond=None, + return_info=False, + **kwargs): + + if cross_attn_cond is not None: + cross_attn_cond = self.to_cond_embed(cross_attn_cond) + + if global_embed is not None: + # Project the global conditioning to the embedding dimension + global_embed = self.to_global_embed(global_embed) + + prepend_inputs = None + prepend_mask = None + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + + prepend_inputs = prepend_cond + if prepend_cond_mask is not None: + prepend_mask = prepend_cond_mask + + if input_concat_cond is not None: + # reshape from (b, n, c) to (b, c, n) + if input_concat_cond.shape[1] != x.shape[1]: + input_concat_cond = input_concat_cond.transpose(1,2) + # Interpolate input_concat_cond to the same length as x + # if input_concat_cond.shape[1] != x.shape[2]: + # input_concat_cond = input_concat_cond.transpose(1,2) + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + # input_concat_cond = input_concat_cond.transpose(1,2) + # if len(global_embed.shape) == 2: + # global_embed = global_embed.unsqueeze(1) + # global_embed = global_embed + input_concat_cond + x = torch.cat([x, input_concat_cond], dim=1) + + # Get the batch of timestep embeddings + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) + # import ipdb + # ipdb.set_trace() + # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists + if self.timestep_cond_type == "global": + if global_embed is not None: + if len(global_embed.shape) == 3: + timestep_embed = timestep_embed.unsqueeze(1) + global_embed = global_embed + timestep_embed + else: + global_embed = timestep_embed + elif self.timestep_cond_type == "input_concat": + x = torch.cat([x, timestep_embed.unsqueeze(1).expand(-1, -1, x.shape[2])], dim=1) + + # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer + if self.global_cond_type == "prepend" and global_embed is not None: + if prepend_inputs is None: + # Prepend inputs are just the global embed, and the mask is all ones + if len(global_embed.shape) == 2: + prepend_inputs = global_embed.unsqueeze(1) + else: + prepend_inputs = global_embed + prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool) + else: + # Prepend inputs are the prepend conditioning + the global embed + if len(global_embed.shape) == 2: + prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) + else: + prepend_inputs = torch.cat([prepend_inputs, global_embed], dim=1) + prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1) + + prepend_length = prepend_inputs.shape[1] + + x = self.preprocess_conv(x) + x + x = rearrange(x, "b c t -> b t c") + + extra_args = {} + + if self.global_cond_type == "adaLN": + extra_args["global_cond"] = global_embed + + if self.patch_size > 1: + b, seq_len, c = x.shape + + # 计算需要填充的数量 + pad_amount = (self.patch_size - seq_len % self.patch_size) % self.patch_size + + if pad_amount > 0: + # 在时间维度上进行填充 + x = F.pad(x, (0, 0, 0, pad_amount), mode='constant', value=0) + x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) + + if add_cond is not None: + # Interpolate add_cond to the same length as x + # if self.use_mlp: + add_cond = self.to_add_embed(add_cond) + if add_cond.shape[1] != x.shape[1]: + add_cond = add_cond.transpose(1,2) + add_cond = F.interpolate(add_cond, (x.shape[1], ), mode='linear', align_corners=False) + add_cond = add_cond.transpose(1,2) + # add_cond = resample(add_cond, x) + + if sync_cond is not None: + sync_cond = self.to_sync_embed(sync_cond) + + if self.transformer_type == "continuous_transformer": + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, sync_cond=sync_cond, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs) + + if return_info: + output, info = output + + output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:] + + if self.patch_size > 1: + output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) + # 移除之前添加的填充 + if pad_amount > 0: + output = output[:, :, :seq_len] + + output = self.postprocess_conv(output) + output + + if return_info: + return output, info + + return output + + def forward( + self, + x, + t, + cross_attn_cond=None, + cross_attn_cond_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + global_embed=None, + negative_global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + add_cond=None, + sync_cond=None, + cfg_scale=1.0, + cfg_dropout_prob=0.0, + causal=False, + scale_phi=0.0, + mask=None, + return_info=False, + **kwargs): + + assert causal == False, "Causal mode is not supported for DiffusionTransformer" + bsz, a, b = x.shape + model_dtype = next(self.parameters()).dtype + x = x.to(model_dtype) + t = t.to(model_dtype) + + if cross_attn_cond is not None: + cross_attn_cond = cross_attn_cond.to(model_dtype) + + if negative_cross_attn_cond is not None: + negative_cross_attn_cond = negative_cross_attn_cond.to(model_dtype) + + if input_concat_cond is not None: + input_concat_cond = input_concat_cond.to(model_dtype) + + if global_embed is not None: + global_embed = global_embed.to(model_dtype) + + if negative_global_embed is not None: + negative_global_embed = negative_global_embed.to(model_dtype) + + if prepend_cond is not None: + prepend_cond = prepend_cond.to(model_dtype) + + if add_cond is not None: + add_cond = add_cond.to(model_dtype) + + if sync_cond is not None: + sync_cond = sync_cond.to(model_dtype) + + if cross_attn_cond_mask is not None: + cross_attn_cond_mask = cross_attn_cond_mask.bool() + + cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.bool() + + + # CFG dropout + if cfg_dropout_prob > 0.0 and cfg_scale == 1.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if add_cond is not None: + null_embed = torch.zeros_like(add_cond, device=add_cond.device) + dropout_mask = torch.bernoulli(torch.full((add_cond.shape[0], 1, 1), cfg_dropout_prob, device=add_cond.device)).to(torch.bool) + add_cond = torch.where(dropout_mask, null_embed, add_cond) + + if sync_cond is not None: + null_embed = torch.zeros_like(sync_cond, device=sync_cond.device) + dropout_mask = torch.bernoulli(torch.full((sync_cond.shape[0], 1, 1), cfg_dropout_prob, device=sync_cond.device)).to(torch.bool) + sync_cond = torch.where(dropout_mask, null_embed, sync_cond) + + if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None): + # Classifier-free guidance + # Concatenate conditioned and unconditioned inputs on the batch dimension + batch_inputs = torch.cat([x, x], dim=0) + batch_timestep = torch.cat([t, t], dim=0) + if global_embed is not None and global_embed.shape[0] == bsz: + batch_global_cond = torch.cat([global_embed, global_embed], dim=0) + elif global_embed is not None: + batch_global_cond = global_embed + else: + batch_global_cond = None + + if input_concat_cond is not None and input_concat_cond.shape[0] == bsz: + batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0) + elif input_concat_cond is not None: + batch_input_concat_cond = input_concat_cond + else: + batch_input_concat_cond = None + + batch_cond = None + batch_cond_masks = None + + # Handle CFG for cross-attention conditioning + if cross_attn_cond is not None and cross_attn_cond.shape[0] == bsz: + + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning + if negative_cross_attn_cond is not None: + + # If there's a negative cross-attention mask, set the masked tokens to the null embed + if negative_cross_attn_mask is not None: + negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) + + negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed) + + batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0) + + else: + batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if cross_attn_cond_mask is not None: + batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0) + elif cross_attn_cond is not None: + batch_cond = cross_attn_cond + else: + batch_cond = None + + batch_prepend_cond = None + batch_prepend_cond_mask = None + + if prepend_cond is not None and prepend_cond.shape[0] == bsz: + + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + elif prepend_cond is not None: + batch_prepend_cond = prepend_cond + else: + batch_prepend_cond = None + + batch_add_cond = None + + # Handle CFG for cross-attention conditioning + if add_cond is not None and add_cond.shape[0] == bsz: + + null_embed = torch.zeros_like(add_cond, device=add_cond.device) + + + batch_add_cond = torch.cat([add_cond, null_embed], dim=0) + elif add_cond is not None: + batch_add_cond = add_cond + else: + batch_add_cond = None + + batch_sync_cond = None + + # Handle CFG for cross-attention conditioning + if sync_cond is not None and sync_cond.shape[0] == bsz: + + null_embed = torch.zeros_like(sync_cond, device=sync_cond.device) + + + batch_sync_cond = torch.cat([sync_cond, null_embed], dim=0) + elif sync_cond is not None: + batch_sync_cond = sync_cond + else: + batch_sync_cond = None + + if mask is not None: + batch_masks = torch.cat([mask, mask], dim=0) + else: + batch_masks = None + + batch_output = self._forward( + batch_inputs, + batch_timestep, + cross_attn_cond=batch_cond, + cross_attn_cond_mask=batch_cond_masks, + mask = batch_masks, + input_concat_cond=batch_input_concat_cond, + global_embed = batch_global_cond, + prepend_cond = batch_prepend_cond, + prepend_cond_mask = batch_prepend_cond_mask, + add_cond = batch_add_cond, + sync_cond = batch_sync_cond, + return_info = return_info, + **kwargs) + + if return_info: + batch_output, info = batch_output + + cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0) + cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale + + # CFG Rescale + if scale_phi != 0.0: + cond_out_std = cond_output.std(dim=1, keepdim=True) + out_cfg_std = cfg_output.std(dim=1, keepdim=True) + output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output + else: + output = cfg_output + + if return_info: + return output, info + + return output + + else: + return self._forward( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_cond_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + add_cond=add_cond, + sync_cond=sync_cond, + mask=mask, + return_info=return_info, + **kwargs + ) \ No newline at end of file diff --git a/ThinkSound/models/factory.py b/ThinkSound/models/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..20a44e84aa6f58e7fc128a5d82f97e41ccc809b6 --- /dev/null +++ b/ThinkSound/models/factory.py @@ -0,0 +1,156 @@ +import json + +def create_model_from_config(model_config): + model_type = model_config.get('model_type', None) + + assert model_type is not None, 'model_type must be specified in model config' + + if model_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + return create_autoencoder_from_config(model_config) + elif model_type == 'diffusion_uncond': + from .diffusion import create_diffusion_uncond_from_config + return create_diffusion_uncond_from_config(model_config) + # elif model_type == 'diffusion_infill': + # from .diffusion import create_diffusion_infill_from_config + # return create_diffusion_infill_from_config(model_config) + elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior" or model_type == "diffusion_infill" or model_type == "mm_diffusion_cond": + from .diffusion import create_diffusion_cond_from_config + return create_diffusion_cond_from_config(model_config) + elif model_type == 'diffusion_autoencoder': + from .autoencoders import create_diffAE_from_config + return create_diffAE_from_config(model_config) + elif model_type == 'lm': + from .lm import create_audio_lm_from_config + return create_audio_lm_from_config(model_config) + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + +def create_model_from_config_path(model_config_path): + with open(model_config_path) as f: + model_config = json.load(f) + + return create_model_from_config(model_config) + +def create_pretransform_from_config(pretransform_config, sample_rate): + pretransform_type = pretransform_config.get('type', None) + + assert pretransform_type is not None, 'type must be specified in pretransform config' + + if pretransform_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + from .pretransforms import AutoencoderPretransform + + # Create fake top-level config to pass sample rate to autoencoder constructor + # This is a bit of a hack but it keeps us from re-defining the sample rate in the config + autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]} + autoencoder = create_autoencoder_from_config(autoencoder_config) + + scale = pretransform_config.get("scale", 1.0) + model_half = pretransform_config.get("model_half", False) + iterate_batch = pretransform_config.get("iterate_batch", False) + chunked = pretransform_config.get("chunked", False) + + pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) + elif pretransform_type == 'wavelet': + from .pretransforms import WaveletPretransform + + wavelet_config = pretransform_config["config"] + channels = wavelet_config["channels"] + levels = wavelet_config["levels"] + wavelet = wavelet_config["wavelet"] + + pretransform = WaveletPretransform(channels, levels, wavelet) + elif pretransform_type == 'pqmf': + from .pretransforms import PQMFPretransform + pqmf_config = pretransform_config["config"] + pretransform = PQMFPretransform(**pqmf_config) + elif pretransform_type == 'dac_pretrained': + from .pretransforms import PretrainedDACPretransform + pretrained_dac_config = pretransform_config["config"] + pretransform = PretrainedDACPretransform(**pretrained_dac_config) + elif pretransform_type == "audiocraft_pretrained": + from .pretransforms import AudiocraftCompressionPretransform + + audiocraft_config = pretransform_config["config"] + pretransform = AudiocraftCompressionPretransform(**audiocraft_config) + else: + raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') + + enable_grad = pretransform_config.get('enable_grad', False) + pretransform.enable_grad = enable_grad + + pretransform.eval().requires_grad_(pretransform.enable_grad) + + return pretransform + +def create_bottleneck_from_config(bottleneck_config): + bottleneck_type = bottleneck_config.get('type', None) + + assert bottleneck_type is not None, 'type must be specified in bottleneck config' + + if bottleneck_type == 'tanh': + from .bottleneck import TanhBottleneck + bottleneck = TanhBottleneck() + elif bottleneck_type == 'vae': + from .bottleneck import VAEBottleneck + bottleneck = VAEBottleneck() + elif bottleneck_type == 'rvq': + from .bottleneck import RVQBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + bottleneck = RVQBottleneck(**quantizer_params) + elif bottleneck_type == "dac_rvq": + from .bottleneck import DACRVQBottleneck + + bottleneck = DACRVQBottleneck(**bottleneck_config["config"]) + + elif bottleneck_type == 'rvq_vae': + from .bottleneck import RVQVAEBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + bottleneck = RVQVAEBottleneck(**quantizer_params) + + elif bottleneck_type == 'dac_rvq_vae': + from .bottleneck import DACRVQVAEBottleneck + bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"]) + elif bottleneck_type == 'l2_norm': + from .bottleneck import L2Bottleneck + bottleneck = L2Bottleneck() + elif bottleneck_type == "wasserstein": + from .bottleneck import WassersteinBottleneck + bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {})) + elif bottleneck_type == "fsq": + from .bottleneck import FSQBottleneck + bottleneck = FSQBottleneck(**bottleneck_config["config"]) + else: + raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') + + requires_grad = bottleneck_config.get('requires_grad', True) + if not requires_grad: + for param in bottleneck.parameters(): + param.requires_grad = False + + return bottleneck diff --git a/ThinkSound/models/lm.py b/ThinkSound/models/lm.py new file mode 100644 index 0000000000000000000000000000000000000000..1897fa72ab716f69e0c6d71236e47cc50f78592e --- /dev/null +++ b/ThinkSound/models/lm.py @@ -0,0 +1,541 @@ +from dataclasses import dataclass +import torch +from tqdm.auto import trange +import typing as tp +from einops import rearrange +from torch import nn + +from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config +from .factory import create_pretransform_from_config +from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone +from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform +from .utils import multinomial, sample_top_k, sample_top_p + +from .codebook_patterns import ( + CodebooksPatternProvider, + DelayedPatternProvider, + MusicLMPattern, + ParallelPatternProvider, + UnrolledPatternProvider +) + +# Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license +# License can be found in LICENSES/LICENSE_META.txt + +@dataclass +class LMOutput: + # The logits are already re-aligned with the input codes + # hence no extra shift is required, e.g. when computing CE + logits: torch.Tensor # [B, K, T, card] + mask: torch.Tensor # [B, K, T] + +# Wrapper for a multi-codebook language model +# Handles patterns and quantizer heads +class AudioLanguageModel(nn.Module): + def __init__( + self, + pattern_provider: CodebooksPatternProvider, + backbone: AudioLMBackbone, + num_quantizers: int, + codebook_size: int + ): + super().__init__() + + self.pattern_provider = pattern_provider + self.backbone = backbone + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + + self.masked_token_id = codebook_size + + # Per-quantizer embedders + # Add one for the mask embed + self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)]) + + # Per-quantizer output heads + self.quantizer_heads = nn.ModuleList([ + nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers) + ]) + + def forward(self, + sequence: torch.Tensor, #[batch, seq_len, + prepend_cond=None, #[batch, seq, channels] + prepend_cond_mask=None, + cross_attn_cond=None, #[batch, seq, channels], + **kwargs + ): + + batch, num_quantizers, seq_len = sequence.shape + + assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model" + + backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim] + + dtype = next(self.parameters()).dtype + + if cross_attn_cond is not None: + cross_attn_cond = cross_attn_cond.to(dtype) + + if prepend_cond is not None: + prepend_cond = prepend_cond.to(dtype) + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.to(dtype) + + backbone_input = backbone_input.to(dtype) + + output = self.backbone( + backbone_input, + cross_attn_cond=cross_attn_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + **kwargs + ) # [batch, seq_len, embed_dim] + + # Run output through quantizer heads + logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1) # [batch, num_quantizers, seq_len, codebook_size] + + return logits + + def compute_logits( + self, + codes, #[batch, num_quantizers, seq_len] + **kwargs): + """ + Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning + Handles translation between input sequence and pattern-shifted sequence + Only used during training + """ + + batch, _, seq_len = codes.shape + + pattern = self.pattern_provider.get_pattern(seq_len) + + # Apply the token pattern to the codes, shifting the codes as needed and masking out invalid steps + shifted_codes, _, _ = pattern.build_pattern_sequence( + codes, + self.masked_token_id, + keep_only_valid_steps=True + ) + + # Run the model to get logits for each quantizer [batch, num_quantizers, seq_len, codebook_size] + logits = self(shifted_codes, **kwargs) + + # Rearrange logits to prepare to revert pattern + logits = rearrange(logits, "b n s c -> b c n s") + + # Revert sequence logits back to original sequence length, removing masked steps + logits, _, logits_mask = pattern.revert_pattern_logits( + logits, float('nan'), keep_only_valid_steps=True + ) + + logits = rearrange(logits, "b c n t -> b n t c") + + logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len] + + return LMOutput(logits=logits, mask=logits_mask) + +# Conditioning and generation wrapper for a multi-codebook language model +# Handles conditioning, CFG, generation, and encoding/decoding +class AudioLanguageModelWrapper(nn.Module): + def __init__( + self, + pretransform: Pretransform, + lm: AudioLanguageModel, + sample_rate: int, + min_input_length: int, + conditioner: MultiConditioner = None, + cross_attn_cond_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [] + ): + super().__init__() + + assert pretransform.is_discrete, "Pretransform must be discrete" + self.pretransform = pretransform + + self.pretransform.requires_grad_(False) + self.pretransform.eval() + + if isinstance(self.pretransform, AutoencoderPretransform): + self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers + self.codebook_size = self.pretransform.model.bottleneck.codebook_size + elif isinstance(self.pretransform, PretrainedDACPretransform): + self.num_quantizers = self.pretransform.model.num_quantizers + self.codebook_size = self.pretransform.model.codebook_size + elif isinstance(self.pretransform, AudiocraftCompressionPretransform): + self.num_quantizers = self.pretransform.num_quantizers + self.codebook_size = self.pretransform.codebook_size + else: + raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}") + + self.conditioner = conditioner + + self.lm = lm + + self.sample_rate = sample_rate + self.min_input_length = min_input_length + + self.cross_attn_cond_ids = cross_attn_cond_ids + self.prepend_cond_ids = prepend_cond_ids + self.global_cond_ids = global_cond_ids + + def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False): + cross_attention_input = None + prepend_cond = None + prepend_cond_mask = None + global_cond = None + + if len(self.cross_attn_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1) + + if len(self.prepend_cond_ids) > 0: + # Concatenate all prepend conditioning inputs over the sequence dimension + # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) + prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1) + prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1) + + if len(self.global_cond_ids) > 0: + # Concatenate all global conditioning inputs over the channel dimension + # Assumes that the global conditioning inputs are of shape (batch, channels) + global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1) + if len(global_cond.shape) == 3: + global_cond = global_cond.squeeze(1) + + if negative: + return { + "negative_cross_attn_cond": cross_attention_input, + "negative_prepend_cond": prepend_cond, + "negative_prepend_cond_mask": prepend_cond_mask, + "negative_global_cond": global_cond + } + else: + return { + "cross_attn_cond": cross_attention_input, + "prepend_cond": prepend_cond, + "prepend_cond_mask": prepend_cond_mask, + "global_cond": global_cond + } + + def compute_logits( + self, + codes, + condition_tensors=None, + cfg_dropout_prob=0.0, + **kwargs + ): + """ + Compute logits for a batch of codes, and translates from conditioning inputs to model inputs + Handles CFG dropout + """ + + if condition_tensors is None: + condition_tensors = {} + + conditioning_inputs = self.get_conditioning_inputs(condition_tensors) + + cross_attn_cond = conditioning_inputs["cross_attn_cond"] + prepend_cond = conditioning_inputs["prepend_cond"] + prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] + global_cond = conditioning_inputs["global_cond"] + + if cfg_dropout_prob > 0.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if global_cond is not None: + null_embed = torch.zeros_like(global_cond, device=global_cond.device) + dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool) + global_cond = torch.where(dropout_mask, null_embed, global_cond) + + return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + + def _sample_next_token( + self, + sequence, #[batch, num_quantizers, seq_len] + conditioning_tensors=None, + cross_attn_use_cfg=True, + prepend_use_cfg=True, + global_use_cfg=True, + cfg_scale=1.0, + top_k=250, + top_p=0.0, + temp=1.0, + **kwargs + ): + """ + Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs + Handles CFG inference + """ + + if conditioning_tensors is None: + conditioning_tensors = {} + + conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors) + + cross_attn_cond = conditioning_inputs["cross_attn_cond"] + prepend_cond = conditioning_inputs["prepend_cond"] + prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] + global_cond = conditioning_inputs["global_cond"] + + if cfg_scale != 1.0: + + # Batch size is doubled to account for negative samples + sequence = torch.cat([sequence, sequence], dim=0) + + if cross_attn_cond is not None and cross_attn_use_cfg: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if prepend_cond is not None and prepend_use_cfg: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + if global_cond is not None and global_use_cfg: + null_embed = torch.zeros_like(global_cond, device=global_cond.device) + + global_cond = torch.cat([global_cond, null_embed], dim=0) + + logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + + if cfg_scale != 1.0: + cond_logits, uncond_logits = logits.chunk(2, dim=0) + + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale + + logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len] + + # Grab the logits for the last step + logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size] + + # Apply top-k or top-p sampling + + if temp > 0: + probs = torch.softmax(logits / temp, dim=-1) + + if top_p > 0.0: + next_token = sample_top_p(probs, p=top_p) + elif top_k > 0: + next_token = sample_top_k(probs, k=top_k) + else: + next_token = multinomial(probs, num_samples=1) + + else: + next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1] + + return next_token + + @torch.no_grad() + def generate( + self, + max_gen_len: int = 256, + batch_size: tp.Optional[int] = None, + init_data: tp.Optional[torch.Tensor] = None, + conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None, + callback: tp.Optional[tp.Callable[[int, int], None]] = None, + use_cache: bool = True, + cfg_scale: float = 1.0, + **kwargs + ): + device = next(self.parameters()).device + + if conditioning_tensors is None and conditioning is not None: + # Convert conditioning inputs to conditioning tensors + conditioning_tensors = self.conditioner(conditioning, device) + + # Check that batch size is consistent across inputs + possible_batch_sizes = [] + + if batch_size is not None: + possible_batch_sizes.append(batch_size) + elif init_data is not None: + possible_batch_sizes.append(init_data.shape[0]) + elif conditioning_tensors is not None: + # Assume that the first conditioning tensor has the batch dimension + possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0]) + else: + possible_batch_sizes.append(1) + + assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs" + + batch_size = possible_batch_sizes[0] + + if init_data is None: + # Initialize with zeros + assert batch_size > 0 + init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long) + + batch_size, num_quantizers, seq_len = init_data.shape + + start_offset = seq_len + assert start_offset < max_gen_len, "init data longer than max gen length" + + pattern = self.lm.pattern_provider.get_pattern(max_gen_len) + + unknown_token = -1 + + # Initialize the generated codes with the init data, padded with unknown tokens + gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long) + gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len] + + gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len] + + start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) + assert start_offset_sequence is not None + + # Generation + prev_offset = 0 + gen_sequence_len = gen_sequence.shape[-1] + + # Reset generation cache + if use_cache and self.lm.backbone.use_generation_cache: + self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2) + + for offset in trange(start_offset_sequence, gen_sequence_len): + + # Get the full sequence up to the current offset + curr_sequence = gen_sequence[..., prev_offset:offset] + + next_token = self._sample_next_token( + curr_sequence, + conditioning_tensors=conditioning_tensors, + use_cache=use_cache, + cfg_scale=cfg_scale, + **kwargs + ) + + valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1) + next_token[~valid_mask] = self.lm.masked_token_id + + # Update the generated sequence with the next token + gen_sequence[..., offset:offset+1] = torch.where( + gen_sequence[..., offset:offset+1] == unknown_token, + next_token, + gen_sequence[..., offset:offset+1] + ) + + if use_cache and self.lm.backbone.use_generation_cache: + # Only update the offset if caching is being used + prev_offset = offset + + self.lm.backbone.update_generation_cache(offset) + + if callback is not None: + # Callback to report progress + # Pass in the offset relative to the start of the sequence, and the length of the current sequence + callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) + + assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence" + + out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) + + # sanity checks over the returned codes and corresponding masks + assert (out_codes[..., :max_gen_len] != unknown_token).all() + assert (out_mask[..., :max_gen_len] == 1).all() + + #out_codes = out_codes[..., 0:max_gen_len] + + return out_codes + + + def generate_audio( + self, + **kwargs + ): + """ + Generate audio from a batch of codes + """ + + codes = self.generate(**kwargs) + + audio = self.pretransform.decode_tokens(codes) + + return audio + + +def create_audio_lm_from_config(config): + model_config = config.get('model', None) + assert model_config is not None, 'model config must be specified in config' + + sample_rate = config.get('sample_rate', None) + assert sample_rate is not None, "Must specify sample_rate in config" + + lm_config = model_config.get('lm', None) + assert lm_config is not None, 'lm config must be specified in model config' + + codebook_pattern = lm_config.get("codebook_pattern", "delay") + + pattern_providers = { + 'parallel': ParallelPatternProvider, + 'delay': DelayedPatternProvider, + 'unroll': UnrolledPatternProvider, + 'musiclm': MusicLMPattern, + } + + pretransform_config = model_config.get("pretransform", None) + + pretransform = create_pretransform_from_config(pretransform_config, sample_rate) + + assert pretransform.is_discrete, "Pretransform must be discrete" + + min_input_length = pretransform.downsampling_ratio + + pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers) + + conditioning_config = model_config.get('conditioning', None) + + conditioner = None + if conditioning_config is not None: + conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) + + cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', []) + prepend_cond_ids = lm_config.get('prepend_cond_ids', []) + global_cond_ids = lm_config.get('global_cond_ids', []) + + lm_type = lm_config.get("type", None) + lm_model_config = lm_config.get("config", None) + + assert lm_type is not None, "Must specify lm type in lm config" + assert lm_model_config is not None, "Must specify lm model config in lm config" + + if lm_type == "x-transformers": + backbone = XTransformersAudioLMBackbone(**lm_model_config) + elif lm_type == "continuous_transformer": + backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config) + else: + raise NotImplementedError(f"Unrecognized lm type {lm_type}") + + lm = AudioLanguageModel( + pattern_provider=pattern_provider, + backbone=backbone, + num_quantizers=pretransform.num_quantizers, + codebook_size=pretransform.codebook_size + ) + + model = AudioLanguageModelWrapper( + pretransform=pretransform, + lm=lm, + conditioner=conditioner, + sample_rate=sample_rate, + min_input_length=min_input_length, + cross_attn_cond_ids=cross_attn_cond_ids, + prepend_cond_ids=prepend_cond_ids, + global_cond_ids=global_cond_ids + ) + + return model \ No newline at end of file diff --git a/ThinkSound/models/lm_backbone.py b/ThinkSound/models/lm_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..c80cce60b06d9b367b114188444b0890a1990b61 --- /dev/null +++ b/ThinkSound/models/lm_backbone.py @@ -0,0 +1,159 @@ +from torch import nn +from x_transformers import ContinuousTransformerWrapper, Decoder + +from .transformer import ContinuousTransformer + +# Interface for backbone of a language model +# Handles conditioning and cross-attention +# Does not have to deal with patterns or quantizer heads +class AudioLMBackbone(nn.Module): + def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs): + super().__init__() + + self.embed_dim = embed_dim + self.use_generation_cache = use_generation_cache + + def forward( + self, + x, + cross_attn_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + global_cond=None, + use_cache=False, + **kwargs + ): + raise NotImplementedError + + def reset_generation_cache( + self, + max_seq_len, + batch_size, + dtype=None + ): + pass + + def update_generation_cache( + self, + seqlen_offset + ): + pass + +class XTransformersAudioLMBackbone(AudioLMBackbone): + def __init__(self, + embed_dim: int, + cross_attn_cond_dim: int = 0, + prepend_cond_dim: int = 0, + **kwargs): + super().__init__(embed_dim=embed_dim) + + # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer + self.model = ContinuousTransformerWrapper( + dim_in=embed_dim, + dim_out=embed_dim, + max_seq_len=0, #Not relevant without absolute positional embeds, + attn_layers=Decoder( + dim=embed_dim, + attn_flash = True, + cross_attend = cross_attn_cond_dim > 0, + zero_init_branch_output=True, + use_abs_pos_emb = False, + rotary_pos_emb=True, + ff_swish = True, + ff_glu = True, + **kwargs + ) + ) + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + if cross_attn_cond_dim > 0: + # Cross-attention conditioning + self.to_cross_attn_embed = nn.Sequential( + nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): + + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + prepend_length = prepend_cond.shape[1] + + if prepend_cond_mask is not None: + # Cast mask to bool + prepend_cond_mask = prepend_cond_mask.bool() + + if cross_attn_cond is not None: + # Project the cross-attention conditioning to the embedding dimension + cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) + + return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] + +class ContinuousTransformerAudioLMBackbone(AudioLMBackbone): + def __init__(self, + embed_dim: int, + cross_attn_cond_dim: int = 0, + prepend_cond_dim: int = 0, + project_cross_attn_cond: bool = False, + **kwargs): + super().__init__(embed_dim=embed_dim) + + # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer + self.model = ContinuousTransformer( + dim=embed_dim, + dim_in=embed_dim, + dim_out=embed_dim, + cross_attend = cross_attn_cond_dim > 0, + cond_token_dim = embed_dim if project_cross_attn_cond else cross_attn_cond_dim, + causal=True, + **kwargs + ) + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + if cross_attn_cond_dim > 0 and project_cross_attn_cond: + # Cross-attention conditioning + self.to_cross_attn_embed = nn.Sequential( + nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + else: + self.to_cross_attn_embed = nn.Identity() + + def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): + + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + prepend_length = prepend_cond.shape[1] + + if prepend_cond_mask is not None: + # Cast mask to bool + prepend_cond_mask = prepend_cond_mask.bool() + + if cross_attn_cond is not None: + # Cast cross_attn_cond to same dtype as self.to_cross_attn_embed + cross_attn_cond = cross_attn_cond.to(self.to_cross_attn_embed[0].weight.dtype) + + # Project the cross-attention conditioning to the embedding dimension + cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) + + return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] \ No newline at end of file diff --git a/ThinkSound/models/lm_continuous.py b/ThinkSound/models/lm_continuous.py new file mode 100644 index 0000000000000000000000000000000000000000..469bb49f32492794345cf76dafbb377778eca81e --- /dev/null +++ b/ThinkSound/models/lm_continuous.py @@ -0,0 +1,525 @@ +from dataclasses import dataclass +import torch +from tqdm.auto import trange +import typing as tp +from einops import rearrange +from torch import nn + +from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config +from .factory import create_pretransform_from_config +from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone +from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform +from .utils import multinomial, sample_top_k, sample_top_p +from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper, create_diffusion_cond_from_config + +from .codebook_patterns import ( + CodebooksPatternProvider, + DelayedPatternProvider, + MusicLMPattern, + ParallelPatternProvider, + UnrolledPatternProvider +) + +# Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license +# License can be found in LICENSES/LICENSE_META.txt + +@dataclass +class LMContinuousOutput: + # The logits are already re-aligned with the input codes + # hence no extra shift is required, e.g. when computing CE + logits: torch.Tensor # [B, K, T, card] + mask: torch.Tensor # [B, K, T] + +# Wrapper for a multi-codebook language model +# Handles patterns and quantizer heads +class AudioLMContinuousModel(nn.Module): + def __init__( + self, + backbone: AudioLMBackbone, + ): + super().__init__() + + self.backbone = backbone + + def sample_orders(self, bsz): + # generate a batch of random generation orders + orders = [] + for _ in range(bsz): + order = np.array(list(range(self.seq_len))) + np.random.shuffle(order) + orders.append(order) + orders = torch.Tensor(np.array(orders)).cuda().long() + return orders + + def random_masking(self, x, orders): + # generate token mask + bsz, seq_len, embed_dim = x.shape + mask_rate = self.mask_ratio_generator.rvs(1)[0] + num_masked_tokens = int(np.ceil(seq_len * mask_rate)) + mask = torch.zeros(bsz, seq_len, device=x.device) + mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens], + src=torch.ones(bsz, seq_len, device=x.device)) + return mask + + def forward(self, + sequence: torch.Tensor, #[batch, seq_len, + prepend_cond=None, #[batch, seq, channels] + prepend_cond_mask=None, + cross_attn_cond=None, #[batch, seq, channels], + **kwargs + ): + + + batch, seq_len, dim = sequence.shape + + dtype = next(self.parameters()).dtype + + if cross_attn_cond is not None: + cross_attn_cond = cross_attn_cond.to(dtype) + + if prepend_cond is not None: + prepend_cond = prepend_cond.to(dtype) + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.to(dtype) + + x = sequence.to(dtype) + orders = self.sample_orders(bsz=batch) + mask = self.random_masking(x, orders) + + output = self.backbone( + x, + mask = mask, + cross_attn_cond=cross_attn_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + **kwargs + ) # [batch, seq_len, embed_dim] + + + return output + +# Conditioning and generation wrapper for a multi-codebook language model +# Handles conditioning, CFG, generation, and encoding/decoding +class AudioLanguageModelWrapper(nn.Module): + def __init__( + self, + pretransform: Pretransform, + lm: AudioLanguageModel, + diff: ConditionedDiffusionModelWrapper, + sample_rate: int, + min_input_length: int, + conditioner: MultiConditioner = None, + diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", + cross_attn_cond_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [] + ): + super().__init__() + + assert pretransform.is_discrete, "Pretransform must be discrete" + self.pretransform = pretransform + + self.pretransform.requires_grad_(False) + self.pretransform.eval() + self.diffusion_objective = diffusion_objective + print(f'Training in the {diffusion_objective} formulation') + if isinstance(self.pretransform, AutoencoderPretransform): + self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers + self.codebook_size = self.pretransform.model.bottleneck.codebook_size + elif isinstance(self.pretransform, PretrainedDACPretransform): + self.num_quantizers = self.pretransform.model.num_quantizers + self.codebook_size = self.pretransform.model.codebook_size + elif isinstance(self.pretransform, AudiocraftCompressionPretransform): + self.num_quantizers = self.pretransform.num_quantizers + self.codebook_size = self.pretransform.codebook_size + else: + raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}") + + self.conditioner = conditioner + + self.lm = lm + + self.sample_rate = sample_rate + self.min_input_length = min_input_length + + self.cross_attn_cond_ids = cross_attn_cond_ids + self.prepend_cond_ids = prepend_cond_ids + self.global_cond_ids = global_cond_ids + + def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False): + cross_attention_input = None + prepend_cond = None + prepend_cond_mask = None + global_cond = None + + if len(self.cross_attn_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1) + + if len(self.prepend_cond_ids) > 0: + # Concatenate all prepend conditioning inputs over the sequence dimension + # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) + prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1) + prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1) + + if len(self.global_cond_ids) > 0: + # Concatenate all global conditioning inputs over the channel dimension + # Assumes that the global conditioning inputs are of shape (batch, channels) + global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1) + if len(global_cond.shape) == 3: + global_cond = global_cond.squeeze(1) + + if negative: + return { + "negative_cross_attn_cond": cross_attention_input, + "negative_prepend_cond": prepend_cond, + "negative_prepend_cond_mask": prepend_cond_mask, + "negative_global_cond": global_cond + } + else: + return { + "cross_attn_cond": cross_attention_input, + "prepend_cond": prepend_cond, + "prepend_cond_mask": prepend_cond_mask, + "global_cond": global_cond + } + + def compute_logits( + self, + audios, + condition_tensors=None, + cfg_dropout_prob=0.0, + **kwargs + ): + """ + Compute logits for a batch of codes, and translates from conditioning inputs to model inputs + Handles CFG dropout + """ + + if condition_tensors is None: + condition_tensors = {} + + conditioning_inputs = self.get_conditioning_inputs(condition_tensors) + + cross_attn_cond = conditioning_inputs["cross_attn_cond"] + prepend_cond = conditioning_inputs["prepend_cond"] + prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] + global_cond = conditioning_inputs["global_cond"] + + if cfg_dropout_prob > 0.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if global_cond is not None: + null_embed = torch.zeros_like(global_cond, device=global_cond.device) + dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool) + global_cond = torch.where(dropout_mask, null_embed, global_cond) + + return self.lm.forward(audios, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + + def _sample_next_token( + self, + sequence, #[batch, num_quantizers, seq_len] + conditioning_tensors=None, + cross_attn_use_cfg=True, + prepend_use_cfg=True, + global_use_cfg=True, + cfg_scale=1.0, + top_k=250, + top_p=0.0, + temp=1.0, + **kwargs + ): + """ + Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs + Handles CFG inference + """ + + if conditioning_tensors is None: + conditioning_tensors = {} + + conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors) + + cross_attn_cond = conditioning_inputs["cross_attn_cond"] + prepend_cond = conditioning_inputs["prepend_cond"] + prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] + global_cond = conditioning_inputs["global_cond"] + + if cfg_scale != 1.0: + + # Batch size is doubled to account for negative samples + sequence = torch.cat([sequence, sequence], dim=0) + + if cross_attn_cond is not None and cross_attn_use_cfg: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if prepend_cond is not None and prepend_use_cfg: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + if global_cond is not None and global_use_cfg: + null_embed = torch.zeros_like(global_cond, device=global_cond.device) + + global_cond = torch.cat([global_cond, null_embed], dim=0) + + logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + + if cfg_scale != 1.0: + cond_logits, uncond_logits = logits.chunk(2, dim=0) + + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale + + logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len] + + # Grab the logits for the last step + logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size] + + # Apply top-k or top-p sampling + + if temp > 0: + probs = torch.softmax(logits / temp, dim=-1) + + if top_p > 0.0: + next_token = sample_top_p(probs, p=top_p) + elif top_k > 0: + next_token = sample_top_k(probs, k=top_k) + else: + next_token = multinomial(probs, num_samples=1) + + else: + next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1] + + return next_token + + @torch.no_grad() + def generate( + self, + max_gen_len: int = 256, + batch_size: tp.Optional[int] = None, + init_data: tp.Optional[torch.Tensor] = None, + conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None, + callback: tp.Optional[tp.Callable[[int, int], None]] = None, + use_cache: bool = True, + cfg_scale: float = 1.0, + **kwargs + ): + device = next(self.parameters()).device + + if conditioning_tensors is None and conditioning is not None: + # Convert conditioning inputs to conditioning tensors + conditioning_tensors = self.conditioner(conditioning, device) + + # Check that batch size is consistent across inputs + possible_batch_sizes = [] + + if batch_size is not None: + possible_batch_sizes.append(batch_size) + elif init_data is not None: + possible_batch_sizes.append(init_data.shape[0]) + elif conditioning_tensors is not None: + # Assume that the first conditioning tensor has the batch dimension + possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0]) + else: + possible_batch_sizes.append(1) + + assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs" + + batch_size = possible_batch_sizes[0] + + if init_data is None: + # Initialize with zeros + assert batch_size > 0 + init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long) + + batch_size, num_quantizers, seq_len = init_data.shape + + start_offset = seq_len + assert start_offset < max_gen_len, "init data longer than max gen length" + + pattern = self.lm.pattern_provider.get_pattern(max_gen_len) + + unknown_token = -1 + + # Initialize the generated codes with the init data, padded with unknown tokens + gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long) + gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len] + + gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len] + + start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) + assert start_offset_sequence is not None + + # Generation + prev_offset = 0 + gen_sequence_len = gen_sequence.shape[-1] + + # Reset generation cache + if use_cache and self.lm.backbone.use_generation_cache: + self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2) + + for offset in trange(start_offset_sequence, gen_sequence_len): + + # Get the full sequence up to the current offset + curr_sequence = gen_sequence[..., prev_offset:offset] + + next_token = self._sample_next_token( + curr_sequence, + conditioning_tensors=conditioning_tensors, + use_cache=use_cache, + cfg_scale=cfg_scale, + **kwargs + ) + + valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1) + next_token[~valid_mask] = self.lm.masked_token_id + + # Update the generated sequence with the next token + gen_sequence[..., offset:offset+1] = torch.where( + gen_sequence[..., offset:offset+1] == unknown_token, + next_token, + gen_sequence[..., offset:offset+1] + ) + + if use_cache and self.lm.backbone.use_generation_cache: + # Only update the offset if caching is being used + prev_offset = offset + + self.lm.backbone.update_generation_cache(offset) + + if callback is not None: + # Callback to report progress + # Pass in the offset relative to the start of the sequence, and the length of the current sequence + callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) + + assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence" + + out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) + + # sanity checks over the returned codes and corresponding masks + assert (out_codes[..., :max_gen_len] != unknown_token).all() + assert (out_mask[..., :max_gen_len] == 1).all() + + #out_codes = out_codes[..., 0:max_gen_len] + + return out_codes + + + def generate_audio( + self, + **kwargs + ): + """ + Generate audio from a batch of codes + """ + + codes = self.generate(**kwargs) + + audio = self.pretransform.decode_tokens(codes) + + return audio + + +def create_audio_lm_continuous_from_config(config): + model_config = config.get('model', None) + assert model_config is not None, 'model config must be specified in config' + + sample_rate = config.get('sample_rate', None) + assert sample_rate is not None, "Must specify sample_rate in config" + + lm_config = model_config.get('lm', None) + assert lm_config is not None, 'lm config must be specified in model config' + + + + pretransform_config = model_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + + conditioning_config = model_config.get('conditioning', None) + + conditioner = None + if conditioning_config is not None: + conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) + + cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', []) + prepend_cond_ids = lm_config.get('prepend_cond_ids', []) + global_cond_ids = lm_config.get('global_cond_ids', []) + + lm_type = lm_config.get("type", None) + lm_model_config = lm_config.get("config", None) + + assert lm_type is not None, "Must specify lm type in lm config" + assert lm_model_config is not None, "Must specify lm model config in lm config" + + if lm_type == "x-transformers": + backbone = XTransformersAudioLMBackbone(**lm_model_config) + elif lm_type == "continuous_transformer": + backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config) + else: + raise NotImplementedError(f"Unrecognized lm type {lm_type}") + + lm = AudioLanguageModel( + pattern_provider=pattern_provider, + backbone=backbone, + num_quantizers=pretransform.num_quantizers, + codebook_size=pretransform.codebook_size + ) + + diff_config = model_config.get("diffusion", None) + diffusion_model = DiTWrapper(**diff_config) + + cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', []) + add_cond_ids = diffusion_config.get('add_cond_ids', []) + global_cond_ids = diffusion_config.get('global_cond_ids', []) + input_concat_ids = diffusion_config.get('input_concat_ids', []) + prepend_cond_ids = diffusion_config.get('prepend_cond_ids', []) + + diff = ConditionedDiffusionModelWrapper( + diffusion_model, + conditioner=None, + min_input_length=min_input_length, + sample_rate=sample_rate, + cross_attn_cond_ids=cross_attention_ids, + global_cond_ids=global_cond_ids, + input_concat_ids=input_concat_ids, + prepend_cond_ids=prepend_cond_ids, + add_cond_ids=add_cond_ids, + pretransform=pretransform, + io_channels=2, + ) + + + model = AudioLanguageModelWrapper( + pretransform=pretransform, + lm=lm, + diff=diff, + conditioner=conditioner, + sample_rate=sample_rate, + min_input_length=min_input_length, + cross_attn_cond_ids=cross_attn_cond_ids, + prepend_cond_ids=prepend_cond_ids, + global_cond_ids=global_cond_ids + ) + + return model \ No newline at end of file diff --git a/ThinkSound/models/local_attention.py b/ThinkSound/models/local_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..893ce11fce1f263dd02ff2a2ebe8b5e67426f83f --- /dev/null +++ b/ThinkSound/models/local_attention.py @@ -0,0 +1,278 @@ +import torch + +from einops import rearrange +from torch import nn + +from .blocks import AdaRMSNorm +from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py +class ContinuousLocalTransformer(nn.Module): + def __init__( + self, + *, + dim, + depth, + dim_in = None, + dim_out = None, + causal = False, + local_attn_window_size = 64, + heads = 8, + ff_mult = 2, + cond_dim = 0, + cross_attn_cond_dim = 0, + **kwargs + ): + super().__init__() + + dim_head = dim//heads + + self.layers = nn.ModuleList([]) + + self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity() + + self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity() + + self.local_attn_window_size = local_attn_window_size + + self.cond_dim = cond_dim + + self.cross_attn_cond_dim = cross_attn_cond_dim + + self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32)) + + for _ in range(depth): + + self.layers.append(nn.ModuleList([ + AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), + Attention( + dim=dim, + dim_heads=dim_head, + causal=causal, + zero_init_output=True, + natten_kernel_size=local_attn_window_size, + ), + Attention( + dim=dim, + dim_heads=dim_head, + dim_context = cross_attn_cond_dim, + zero_init_output=True + ) if self.cross_attn_cond_dim > 0 else nn.Identity(), + AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), + FeedForward(dim = dim, mult = ff_mult, no_bias=True) + ])) + + def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None): + + x = checkpoint(self.project_in, x) + + if prepend_cond is not None: + x = torch.cat([prepend_cond, x], dim=1) + + pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + + for attn_norm, attn, xattn, ff_norm, ff in self.layers: + + residual = x + if cond is not None: + x = checkpoint(attn_norm, x, cond) + else: + x = checkpoint(attn_norm, x) + + x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual + + if cross_attn_cond is not None: + x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x + + residual = x + + if cond is not None: + x = checkpoint(ff_norm, x, cond) + else: + x = checkpoint(ff_norm, x) + + x = checkpoint(ff, x) + residual + + return checkpoint(self.project_out, x) + +class TransformerDownsampleBlock1D(nn.Module): + def __init__( + self, + in_channels, + embed_dim = 768, + depth = 3, + heads = 12, + downsample_ratio = 2, + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + self.downsample_ratio = downsample_ratio + + self.transformer = ContinuousLocalTransformer( + dim=embed_dim, + depth=depth, + heads=heads, + local_attn_window_size=local_attn_window_size, + **kwargs + ) + + self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() + + self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False) + + + def forward(self, x): + + x = checkpoint(self.project_in, x) + + # Compute + x = self.transformer(x) + + # Trade sequence length for channels + x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio) + + # Project back to embed dim + x = checkpoint(self.project_down, x) + + return x + +class TransformerUpsampleBlock1D(nn.Module): + def __init__( + self, + in_channels, + embed_dim, + depth = 3, + heads = 12, + upsample_ratio = 2, + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + self.upsample_ratio = upsample_ratio + + self.transformer = ContinuousLocalTransformer( + dim=embed_dim, + depth=depth, + heads=heads, + local_attn_window_size = local_attn_window_size, + **kwargs + ) + + self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() + + self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False) + + def forward(self, x): + + # Project to embed dim + x = checkpoint(self.project_in, x) + + # Project to increase channel dim + x = checkpoint(self.project_up, x) + + # Trade channels for sequence length + x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio) + + # Compute + x = self.transformer(x) + + return x + + +class TransformerEncoder1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dims = [96, 192, 384, 768], + heads = [12, 12, 12, 12], + depths = [3, 3, 3, 3], + ratios = [2, 2, 2, 2], + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + layers = [] + + for layer in range(len(depths)): + prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] + + layers.append( + TransformerDownsampleBlock1D( + in_channels = prev_dim, + embed_dim = embed_dims[layer], + heads = heads[layer], + depth = depths[layer], + downsample_ratio = ratios[layer], + local_attn_window_size = local_attn_window_size, + **kwargs + ) + ) + + self.layers = nn.Sequential(*layers) + + self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) + self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) + + def forward(self, x): + x = rearrange(x, "b c n -> b n c") + x = checkpoint(self.project_in, x) + x = self.layers(x) + x = checkpoint(self.project_out, x) + x = rearrange(x, "b n c -> b c n") + + return x + + +class TransformerDecoder1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dims = [768, 384, 192, 96], + heads = [12, 12, 12, 12], + depths = [3, 3, 3, 3], + ratios = [2, 2, 2, 2], + local_attn_window_size = 64, + **kwargs + ): + + super().__init__() + + layers = [] + + for layer in range(len(depths)): + prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] + + layers.append( + TransformerUpsampleBlock1D( + in_channels = prev_dim, + embed_dim = embed_dims[layer], + heads = heads[layer], + depth = depths[layer], + upsample_ratio = ratios[layer], + local_attn_window_size = local_attn_window_size, + **kwargs + ) + ) + + self.layers = nn.Sequential(*layers) + + self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) + self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) + + def forward(self, x): + x = rearrange(x, "b c n -> b n c") + x = checkpoint(self.project_in, x) + x = self.layers(x) + x = checkpoint(self.project_out, x) + x = rearrange(x, "b n c -> b c n") + return x \ No newline at end of file diff --git a/ThinkSound/models/meta_queries/__init__.py b/ThinkSound/models/meta_queries/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ThinkSound/models/meta_queries/metaquery.py b/ThinkSound/models/meta_queries/metaquery.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce46be39f25f4b6902ee73147d2793316b12448 --- /dev/null +++ b/ThinkSound/models/meta_queries/metaquery.py @@ -0,0 +1,435 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Union, List + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.models import AutoencoderKL, AutoencoderDC +from diffusers.pipelines.pipeline_utils import numpy_to_pil +from diffusers.schedulers import ( + DDPMScheduler, + FlowMatchEulerDiscreteScheduler, + DPMSolverMultistepScheduler, +) +from diffusers.utils.torch_utils import randn_tensor +from transformers import PreTrainedModel +import PIL +from tqdm import tqdm + +from .model import MLLMInContextConfig, MLLMInContext +from diffusers.training_utils import ( + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, +) + + +class MetaQueryConfig(MLLMInContextConfig): + model_type = "metaquery" + + def __init__( + self, + vae_id: str = "Efficient-Large-Model/Sana_1600M_512px_diffusers", + input_size: int = 16, + in_channels: int = 32, + vae_downsample_f: int = 32, + noise_scheduler_id: str = "Efficient-Large-Model/Sana_1600M_512px_diffusers", + scheduler_id: str = "Efficient-Large-Model/Sana_1600M_512px_diffusers", + _gradient_checkpointing: bool = True, + loss_type: str = "flow", + num_metaqueries: int = 64, + modules_to_freeze: tuple[str] = (), + modules_to_unfreeze: tuple[str] = (), + **kwargs, + ): + super().__init__(**kwargs) + for key, value in kwargs.items(): + setattr(self, key, value) + self.vae_id = vae_id + self.input_size = input_size + self.in_channels = in_channels + self.vae_downsample_f = vae_downsample_f + self.noise_scheduler_id = noise_scheduler_id + self.scheduler_id = scheduler_id + self._gradient_checkpointing = _gradient_checkpointing + self.loss_type = loss_type + self.num_metaqueries = num_metaqueries + self.modules_to_freeze = modules_to_freeze + self.modules_to_unfreeze = modules_to_unfreeze + + +class MetaQuery(PreTrainedModel): + config_class = MetaQueryConfig + + def __init__(self, config, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.config = config + + self.model = MLLMInContext(MLLMInContextConfig(**config.to_dict())) + self.loss_type = config.loss_type + + if "Sana" in config.vae_id: + self.vae = AutoencoderDC.from_pretrained(config.vae_id, subfolder="vae") + else: + try: + self.vae = AutoencoderKL.from_pretrained(config.vae_id) + except: + self.vae = AutoencoderKL.from_pretrained(config.vae_id, subfolder="vae") + + if self.loss_type == "flow": + self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + config.noise_scheduler_id, subfolder="scheduler" + ) + elif self.loss_type == "diff": + self.noise_scheduler = DDPMScheduler.from_pretrained( + config.noise_scheduler_id, subfolder="scheduler" + ) + else: + raise ValueError(f"Unknown loss type {self.loss_type}") + + self.scheduler = DPMSolverMultistepScheduler.from_pretrained( + config.scheduler_id, subfolder="scheduler" + ) + + for module_name in config.modules_to_freeze: + if "." in module_name: + module = self + for sub_module_name in module_name.split("."): + module = getattr(module, sub_module_name, None) + if module is None: + break + else: + module.requires_grad_(False) + else: + module = getattr(self, module_name, None) + if module is not None: + module.requires_grad_(False) + + for module_name in config.modules_to_unfreeze: + if "." in module_name: + module = self + for sub_module_name in module_name.split("."): + module = getattr(module, sub_module_name, None) + if module is None: + break + else: + module.requires_grad_(True) + else: + module = getattr(self, module_name, None) + if module is not None: + module.requires_grad_(True) + + def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32): + sigmas = self.noise_scheduler.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = self.noise_scheduler.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + def get_tokenizer(self): + return self.model.get_tokenizer() + + def get_tokenize_fn(self): + return self.model.get_tokenize_fn() + + def forward( + self, target, pixel_values=None, input_ids=None, attention_mask=None, **kwargs + ): + if self.vae is not None: + if isinstance(self.vae, AutoencoderKL): + latents = self.vae.encode(target).latent_dist.sample() + elif isinstance(self.vae, AutoencoderDC): + latents = self.vae.encode(target).latent + else: + raise ValueError(f"Unknown vae type {type(self.vae)}") + if ( + "shift_factor" in self.vae.config + and self.vae.config.shift_factor is not None + ): + latents = latents - self.vae.config.shift_factor + latents = latents * self.vae.config.scaling_factor + else: + latents = target + + bsz = latents.shape[0] + + if ( + pixel_values is not None + and hasattr(self.model, "mllm_type") + and self.model.mllm_type == "qwenvl" + ): + pixel_values = pixel_values.squeeze(0) + + noise = torch.randn_like(latents, device=latents.device) + + if self.loss_type == "flow": + weighting_scheme = "uniform" + u = compute_density_for_timestep_sampling( + weighting_scheme=weighting_scheme, + batch_size=bsz, + logit_mean=0.0, + logit_std=1.0, + mode_scale=1.29, + ) + indices = (u * self.noise_scheduler.config.num_train_timesteps).long() + timesteps = self.noise_scheduler.timesteps[indices].to( + device=latents.device + ) + + sigmas = self.get_sigmas( + timesteps, latents.device, n_dim=latents.ndim, dtype=latents.dtype + ) + noisy_latents = (1.0 - sigmas) * latents + sigmas * noise + prompt_embeds, attention_mask = self.model.encode_condition( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_sizes=kwargs.get("image_sizes", None), + ) + + model_pred = self.model( + x=noisy_latents, + timestep=timesteps, + prompt_embeds=prompt_embeds, + attention_mask=attention_mask, + ) + + target = noise - latents + weighting = compute_loss_weighting_for_sd3( + weighting_scheme=weighting_scheme, sigmas=sigmas + ) + loss = torch.mean( + ( + weighting.float() * (model_pred.float() - target.float()) ** 2 + ).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + elif self.loss_type == "diff": + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + + if self.noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif self.noise_scheduler.config.prediction_type == "v_prediction": + target = self.noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError( + f"Unknown prediction type {self.noise_scheduler.config.prediction_type}" + ) + + prompt_embeds, attention_mask = self.model.encode_condition( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_sizes=kwargs.get("image_sizes", None), + ) + + noise_pred = self.model( + x=noisy_latents, + timestep=timesteps, + prompt_embeds=prompt_embeds, + attention_mask=attention_mask, + ) + loss = F.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + return {"loss": loss} + + @torch.no_grad() + def decode_latents(self, latents, normalize=True, return_tensor=False): + if self.vae is not None: + latents = latents / self.vae.config.scaling_factor + if ( + "shift_factor" in self.vae.config + and self.vae.config.shift_factor is not None + ): + latents = latents + self.vae.config.shift_factor + samples = self.vae.decode(latents).sample + else: + samples = latents + if normalize: + samples = (samples / 2 + 0.5).clamp(0, 1) + else: + samples = samples.clamp(-1, 1) + if return_tensor: + return samples + samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() + samples = numpy_to_pil(samples) + return samples + + def sample_images( + self, + caption="", + input_images=None, + guidance_scale: float = 3.0, + image_guidance_scale: float = 1.5, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 30, + num_images_per_prompt: int = 1, + return_tensor=False, + negative_prompt="", + enable_progress_bar=False, + **kwargs, + ): + device = next(self.parameters()).device + + if not isinstance(caption, list): + caption = [caption] + if input_images is not None: + if isinstance(input_images, list) and not isinstance(input_images[0], list): + input_images = [[img] for img in input_images] + elif isinstance(input_images, PIL.Image.Image): + input_images = [[input_images]] + assert isinstance(input_images, list) and all( + isinstance(sublist, list) for sublist in input_images + ), "input_images needs to be a nested list" + + bsz = len(caption) + do_image_classifier_free_guidance = image_guidance_scale > 1.0 + + tokenize_func = self.get_tokenize_fn() + tokenizer = self.get_tokenizer() + + if input_images is not None: + if do_image_classifier_free_guidance: + caption = [negative_prompt] * bsz * 2 + caption + input_images_null = [ + ( + [ + PIL.Image.new("RGB", (img.size[0], img.size[1])) + for img in images + ] + if images + else None + ) + for images in input_images + ] + input_images = input_images_null + input_images * 2 + else: + caption = [negative_prompt] * bsz + caption + input_images = input_images * 2 + input_ids, attention_mask, pixel_values, image_sizes = tokenize_func( + tokenizer, caption, input_images + ) + else: + do_image_classifier_free_guidance = False + caption = [negative_prompt] * bsz + caption + input_ids, attention_mask = tokenize_func(tokenizer, caption) + pixel_values = None + image_sizes = None + + latent_size = self.config.input_size + latent_channels = self.config.in_channels + + latents = randn_tensor( + shape=( + bsz * num_images_per_prompt, + latent_channels, + latent_size, + latent_size, + ), + generator=generator, + device=device, + dtype=torch.float32, + ) + + # set step values + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + self.scheduler.set_timesteps(num_inference_steps, sigmas=sigmas) + else: + self.scheduler.set_timesteps(num_inference_steps) + + # Repeat pixel_values and conditions for each image per prompt + input_ids = input_ids.to(device=device).repeat_interleave( + num_images_per_prompt, dim=0 + ) + attention_mask = attention_mask.to(device=device).repeat_interleave( + num_images_per_prompt, dim=0 + ) + pixel_values = ( + pixel_values.to(device=device) + .reshape(bsz, -1, *pixel_values.shape[1:]) + .repeat_interleave(num_images_per_prompt, dim=0) + .flatten(0, 1) + if pixel_values is not None + else None + ) + image_sizes = ( + image_sizes.to(device=device).repeat_interleave( + num_images_per_prompt, dim=0 + ) + if image_sizes is not None + else None + ) + + prompt_embeds, attention_mask = self.model.encode_condition( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_sizes=image_sizes, + ) + # Convert to float32 before saving + for t in tqdm( + self.scheduler.timesteps, + desc="Sampling images", + disable=not enable_progress_bar, + ): + latent_model_input = torch.cat([latents] * (len(input_ids) // len(latents))) + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + # predict noise model_output + noise_pred = self.model( + x=latent_model_input, + timestep=t.unsqueeze(0) + .expand(latent_model_input.shape[0]) + .to(latents.device), + prompt_embeds=prompt_embeds, + attention_mask=attention_mask, + ) + + # perform guidance + if do_image_classifier_free_guidance: + noise_pred_uncond, noise_pred_uncond_text, noise_pred = ( + noise_pred.chunk(3) + ) + noise_pred = ( + noise_pred_uncond + + image_guidance_scale + * (noise_pred_uncond_text - noise_pred_uncond) + + guidance_scale * (noise_pred - noise_pred_uncond_text) + ) + else: + noise_pred_uncond, noise_pred = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred - noise_pred_uncond + ) + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + samples = self.decode_latents( + latents.to(self.vae.dtype) if self.vae is not None else latents, + return_tensor=return_tensor, + ) + return samples diff --git a/ThinkSound/models/meta_queries/model.py b/ThinkSound/models/meta_queries/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c3fd49111724cba138cfa151124d3d91b4877235 --- /dev/null +++ b/ThinkSound/models/meta_queries/model.py @@ -0,0 +1,578 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +import random +import math +from typing import List +import typing as tp +import torch +import os +os.environ["TOKENIZERS_PARALLELISM"] = "true" +USE_AUDIO_IN_VIDEO_RATIO = 1.0 + +from torch import nn +from torchvision import transforms as v2 + +from transformers import PretrainedConfig, PreTrainedModel, AutoProcessor, Qwen2Config + +import time +from diffusers.models.normalization import RMSNorm + +from transformers.video_utils import load_video + +from .transformer_encoder import Qwen2Encoder + +from typing import Any, Dict, List, Mapping, Optional, Tuple, Union +def to_device( + data: Any, + device: Union[str, torch.device, int], + dtype: Optional[torch.dtype] = None, # 新增 + non_blocking: bool = False +) -> Any: + """Move inputs to a device and optionally convert dtype""" + if isinstance(data, Mapping): + return type(data)({ + k: to_device(v, device, dtype, non_blocking) + for k, v in data.items() + }) + elif isinstance(data, (tuple, list)): + return type(data)( + to_device(v, device, dtype, non_blocking) + for v in data + ) + elif isinstance(data, torch.Tensor): + tensor = data.to(device=device, non_blocking=non_blocking) + if dtype is not None and tensor.is_floating_point(): + tensor = tensor.to(dtype=dtype) + return tensor + else: + return data + +VIDEO_MIN_PIXELS=224*224 +VIDEO_MAX_PIXELS=224*224 + +import torch.distributed as dist +def print_memory_summary(prefix: str = ""): + + if not torch.cuda.is_available(): + return + + rank = dist.get_rank() if dist.is_initialized() else 0 + device = torch.cuda.current_device() + + allocated = torch.cuda.memory_allocated(device) / 1024**3 + total = torch.cuda.get_device_properties(device).total_memory / 1024**3 + usage = (allocated / total * 100) if total > 0 else 0 + + print(f"[Rank {rank}] {prefix} | GPU Memory: {allocated:.2f}/{total:.2f} GB ({usage:.1f}%)") + + +class MLLMInContextConfig(PretrainedConfig): + model_type = "mllm-in-context" + + def __init__( + self, + mllm_id: str = "Qwen/Qwen2.5-VL-3B-Instruct", + diffusion_model_id: str = None, + num_metaqueries: int = None, + _gradient_checkpointing: bool = True, + max_input_text_tokens: int = 2560, + connector_num_hidden_layers: int = None, + system_prompt: str = "You will be given an video and its caption. Please describe the content of the video in detail in your own words.", + **kwargs, + ): + super().__init__() + self.mllm_id = mllm_id + self.diffusion_model_id = diffusion_model_id + self.num_metaqueries = num_metaqueries + self._gradient_checkpointing = _gradient_checkpointing + self.max_input_text_tokens = max_input_text_tokens + self.connector_num_hidden_layers = connector_num_hidden_layers + self.system_prompt = system_prompt + +import numpy as np +import torchvision.transforms as T +import torch.nn.functional as F + + + +default_config = MLLMInContextConfig() + +class MLLMInContext(PreTrainedModel): + + def __init__( + self, + output_dim: int, + query_len: int, + llm_id = "qwen_omni", + connection_layers=12, + config: MLLMInContextConfig = default_config, + ) -> None: + super().__init__(config) + self._gradient_checkpointing = config._gradient_checkpointing + self.config = config + config.num_metaqueries = query_len + config.connector_num_hidden_layers = connection_layers + print("use meta queries: ",query_len,flush=True) + + if llm_id == "qwen_vl": + config.mllm_id = "Qwen/Qwen2.5-VL-3B-Instruct" + elif llm_id == "qwen_omni": + config.mllm_id = "Qwen/Qwen2.5-Omni-3B" + else: + raise ValueError(f"Unsupported model: {llm_id}") + + if "Qwen2.5-VL" in config.mllm_id: + from .models.qwen25VL import ( + Qwen2_5_VLForConditionalGeneration + ) + self.mllm_type = "qwenvl" + elif "Qwen2.5-Omni" in config.mllm_id: + from .models.qwen25omni import ( + Qwen2_5OmniForConditionalGeneration + ) + self.mllm_type = "qwenomni" + elif "Qwen" in config.mllm_id: + self.mllm_type = "qwenlm" + elif "Llama" in config.mllm_id: + self.mllm_type = "llamaml" + else: + self.mllm_type = "llavaov" + + if self.mllm_type == "qwenvl": + self.mllm_backbone = Qwen2_5_VLForConditionalGeneration.from_pretrained( + config.mllm_id, attn_implementation="flash_attention_2",torch_dtype=torch.bfloat16 + ) + self.mllm_backbone.model.config.use_sliding_window = False + self.mllm_backbone.model.config.sliding_window = None + #print(self.mllm_backbone.model) + + + self._freeze_mllm_backbone() + + num_embeddings = self.mllm_backbone.get_input_embeddings().num_embeddings + self.num_embeddings = num_embeddings + if config.num_metaqueries > 0: + try: + self.mllm_backbone.resize_token_embeddings( + num_embeddings + config.num_metaqueries + 2 + ) + except: + self.mllm_backbone.resize_token_embeddings( + num_embeddings + config.num_metaqueries + 2, mean_resizing=False + ) + + def freeze_hook(grad): + grad[: self.num_embeddings].zero_() + return grad + + self.mllm_backbone.model.embed_tokens.weight.register_hook(freeze_hook) + self.mllm_hidden_size = self.mllm_backbone.config.hidden_size + self.mllm_backbone.lm_head = nn.Identity() + + self.tokenizer = AutoProcessor.from_pretrained( + config.mllm_id, video_min_pixels=VIDEO_MIN_PIXELS, video_max_pixels=VIDEO_MAX_PIXELS,use_fast=True,min_pixels=224*224,max_pixels=288*288 + ) + self.tokenizer.tokenizer.padding_side = "left" + self.tokenizer.resize_fn = None + #self.tokenizer.image_processor.size = { + # "height": 224, + # "width": 224 + #} + # 3B 2048 + # 7B 3584 + self.tokenizer.system_prompt = config.system_prompt + elif self.mllm_type == "qwenomni": + self.mllm_backbone = Qwen2_5OmniForConditionalGeneration.from_pretrained( + config.mllm_id, attn_implementation="flash_attention_2",torch_dtype=torch.bfloat16 + ) + #self.mllm_backbone.disable_talker() + self.mllm_backbone.thinker.model.config.use_sliding_window = False + self.mllm_backbone.thinker.model.config.sliding_window = None + self._freeze_mllm_backbone() + + num_embeddings = self.mllm_backbone.thinker.get_input_embeddings().num_embeddings + self.num_embeddings = num_embeddings + if config.num_metaqueries > 0: + try: + self.mllm_backbone.thinker.resize_token_embeddings( + num_embeddings + config.num_metaqueries + 2 + ) + except: + self.mllm_backbone.thinker.resize_token_embeddings( + num_embeddings + config.num_metaqueries + 2, mean_resizing=False + ) + + def freeze_hook(grad): + grad[: self.num_embeddings].zero_() + return grad + + self.mllm_backbone.thinker.model.embed_tokens.weight.register_hook(freeze_hook) + self.mllm_hidden_size = self.mllm_backbone.thinker.model.config.hidden_size + self.mllm_backbone.thinker.lm_head = nn.Identity() + + self.tokenizer = AutoProcessor.from_pretrained( + config.mllm_id, video_min_pixels=VIDEO_MIN_PIXELS, video_max_pixels=VIDEO_MAX_PIXELS,use_fast=True,min_pixels=224*224,max_pixels=288*288 + ) + self.tokenizer.tokenizer.padding_side = "left" + self.tokenizer.resize_fn = None + self.tokenizer.system_prompt = "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech." + + else: + raise ValueError(f"Unsupported model: {self.mllm_type}") + + + + self.tokenizer.mllm_type = self.mllm_type + self.tokenizer.max_input_text_tokens = config.max_input_text_tokens + self.tokenizer.num_metaqueries = config.num_metaqueries + + self.pad_token_id = getattr( + self.tokenizer, "tokenizer", self.tokenizer + ).pad_token_id + if config.num_metaqueries > 0: + tokenizer = getattr(self.tokenizer, "tokenizer", self.tokenizer) + tokenizer.add_special_tokens( + { + "additional_special_tokens": [ + f"" + for i in range(num_embeddings - len(tokenizer)) + ] + } + ) + tokenizer.add_special_tokens( + { + "additional_special_tokens": ["", ""] + + [f"" for i in range(self.tokenizer.num_metaqueries)] + } + ) + self.boi_token_id = tokenizer.convert_tokens_to_ids("") + self.eoi_token_id = tokenizer.convert_tokens_to_ids("") + + #self.mllm_backbone = torch.compile(self.mllm_backbone) + + self.connector_in_dim = self.mllm_hidden_size + self.connector_out_dim = output_dim + + norm = RMSNorm(self.connector_out_dim, eps=1e-5, elementwise_affine=True) + with torch.no_grad(): + norm.weight.fill_(1.0) + + encoder = Qwen2Encoder( + Qwen2Config( + hidden_size=self.connector_in_dim, + intermediate_size=self.connector_in_dim * 4, + num_hidden_layers=config.connector_num_hidden_layers, + num_attention_heads=self.connector_in_dim // 64, + num_key_value_heads=self.connector_in_dim // 64, + initializer_range=0.014, + use_cache=False, + rope=True, + qk_norm=True, + ), + ) + self.connector = nn.Sequential( + encoder, + nn.Linear(self.connector_in_dim, self.connector_out_dim), + nn.GELU(approximate="tanh"), + nn.Linear(self.connector_out_dim, self.connector_out_dim), + norm, + ) + + if config._gradient_checkpointing: + try: + self.mllm_backbone.gradient_checkpointing_enable( + {"use_reentrant": False} + ) + except: + pass + if not isinstance(self.connector, nn.Identity): + for module in self.connector: + if isinstance(module, Qwen2Encoder): + module.gradient_checkpointing_enable({"use_reentrant": False}) + + def _freeze_mllm_backbone(self): + + print("\nFreeze MLLM backbone...") + + for param in self.mllm_backbone.parameters(): + param.requires_grad = False + + if self.config.num_metaqueries > 0: + if hasattr(self.mllm_backbone,"model"): + embed_tokens = self.mllm_backbone.model.embed_tokens + embed_tokens.weight.requires_grad = True + elif hasattr(self.mllm_backbone,"thinker"): + embed_tokens = self.mllm_backbone.thinker.model.embed_tokens + embed_tokens.weight.requires_grad = True + + + + + def get_tokenizer(self): + return self.tokenizer + + def get_tokenize_fn(self): + return self.tokenize + + def get_resize_fn(self): + return self.resize_fn + + @staticmethod + @torch.no_grad() + def tokenize( + tokenizer, caption, video = None,audio = None, text_response=None, add_generation_prompt=True + ): + #print(video) + if not isinstance(caption, List): + caption = [caption] + + if video is not None and not isinstance(video, List): + video = [video] + if audio is not None and not isinstance(audio, List): + audio = [audio] + + prefix = ( + [ + { + "role": "system", + "content": ( + tokenizer.system_prompt + if tokenizer.mllm_type == "qwenlm" + else [{"type": "text", "text": tokenizer.system_prompt}] + ), + }, + ] + if tokenizer.system_prompt is not None + else [] + ) + + if not add_generation_prompt or tokenizer.num_metaqueries <= 0: + suffix = "" + elif tokenizer.mllm_type=="qwenvl": + suffix = ( + "\n" + + "".join([f"" for i in range(tokenizer.num_metaqueries)]) + + "<|im_end|>" + ) + elif tokenizer.mllm_type=="qwenomni": + suffix = ( + "\n" + + "".join([f"" for i in range(tokenizer.num_metaqueries)]) + + "<|im_end|>" + ) + + caption = [ + tokenizer.decode( + tokenizer( + text=cap, + return_tensors="pt", + padding="max_length", + max_length=tokenizer.max_input_text_tokens, + truncation=True + ).input_ids[0] + ) + for cap in caption + ] + + if audio is not None: + #print("audio",audio[0].shape,audio) + # If each batch item is not a list, wrap it in a single-element list (or empty list if None) + for i, aud in enumerate(audio): + if aud is not None and not isinstance(aud, list): + audio[i] = [aud] + + if video is not None: + # If each batch item is not a list, wrap it in a single-element list (or empty list if None) + for i, vid in enumerate(video): + if vid is not None and not isinstance(vid, list): + #print("vid shape",vid.shape,flush=True) + video[i] = [vid] + + # Resize each image in each batch if resize_fn is not None + if tokenizer.resize_fn is not None: + video = [ + [tokenizer.resize_fn(sub_img) for sub_img in imgs] if imgs else None + for imgs in video + ] + if tokenizer.mllm_type == "qwenvl": + conversations = [ + prefix + + [ + { + "role": "user", + "content": ( + [{"type": "video"} for _ in vids] + + [{"type": "text", "text": cap}] + if vids + else [{"type": "text", "text": cap}] + ), + }, + ] + for cap, vids in zip(caption, video) + ] + kwargs = {"videos": [imgs for imgs in video if imgs]} + if tokenizer.mllm_type == "qwenomni": + conversations = [ + prefix + + [ + { + "role": "user", + "content": ( + [{"type": "video"} for vid in vids] if vids else [] + + [{"type": "text", "text": cap}] + ), + }, + ] + for cap, vids, auds in zip(caption, video, audio) + ] + kwargs = {"videos": [vid for vids in video for vid in vids], + "audio": [aud for auds in audio for aud in auds]} + #print("conversations",conversations) + elif tokenizer.mllm_type in ["qwenlm", "llamaml"]: + conversations = [ + prefix + + [ + { + "role": "user", + "content": cap, + }, + ] + for cap in caption + ] + kwargs = dict() + + else: + conversations = [ + prefix + + [ + { + "role": "user", + "content": [{"type": "text", "text": cap}], + }, + ] + for cap in caption + ] + kwargs = dict() + + + prompts = [ + tokenizer.apply_chat_template(conv, add_generation_prompt=True) + for conv in conversations + ] + if tokenizer.mllm_type=="qwenomni": + prompts = [item for prompt in prompts for item in prompt] + #print(prompts,flush=True) + + if text_response is not None: + prompts = [p + t.strip() for p, t in zip(prompts, text_response)] + if tokenizer.num_metaqueries > 0: + prompts = [p + suffix for p in prompts] + + #print("prompts",prompts) + #print("kwargs",kwargs) + use_audio_in_video = random.random() < USE_AUDIO_IN_VIDEO_RATIO + #use_audio_in_video = True + text_inputs = tokenizer( + text=prompts, + return_tensors="pt", + padding=True, + videos_kwargs={"fps": 1, "use_audio_in_video": use_audio_in_video}, + **kwargs, + ) + #print("text_inputs",text_inputs,flush=True) + #print("input_ids",text_inputs["input_ids"].tolist(),flush=True) + + return text_inputs + + + def encode_condition( + self, input_ids, attention_mask, **kwargs + ): + if self.mllm_type == "llavaov": + prompt_embeds = self.mllm_backbone( + input_ids=input_ids, + **kwargs, + attention_mask=attention_mask, + ).logits + elif self.mllm_type in ["qwenvl"]: + + prompt_embeds = self.mllm_backbone( + input_ids=input_ids, + **kwargs, + attention_mask=attention_mask, + ).logits + + elif self.mllm_type in ["qwenomni"]: + prompt_embeds = self.mllm_backbone.thinker( + input_ids=input_ids, + **kwargs, + attention_mask=attention_mask, + ).logits + elif self.mllm_type in ["qwenlm", "llamaml"]: + prompt_embeds = self.mllm_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + ).logits + else: + raise ValueError(f"Unsupported model: {self.mllm_type}") + + if self.tokenizer.num_metaqueries > 0: + # Get positions for all sequences in batch at once + boi_pos = torch.where(input_ids == self.boi_token_id)[1] + eoi_pos = torch.where(input_ids == self.eoi_token_id)[1] + + # Create mask for selecting tokens between BOI and EOI + batch_size, seq_len = input_ids.shape + indices = torch.arange(seq_len, device=input_ids.device)[None, :].expand( + batch_size, -1 + ) + + + + if boi_pos.shape[0] == batch_size and eoi_pos.shape[0] == batch_size: + mask = (indices > boi_pos[:, None]) & (indices < eoi_pos[:, None]) + prompt_embeds = prompt_embeds[mask].view( + batch_size, -1, prompt_embeds.size(-1) + ) + attention_mask = attention_mask[mask].view(batch_size, -1) + else: + print(f"[DEBUG] boi_pos.shape[0]={boi_pos.shape[0]}, eoi_pos.shape[0]={eoi_pos.shape[0]}") + print(f"[DEBUG] boi_pos={boi_pos}") + print(f"[DEBUG] eoi_pos={eoi_pos}",flush=True) + prompt_embeds = torch.zeros( + batch_size, + self.tokenizer.num_metaqueries, + prompt_embeds.size(-1), + device=prompt_embeds.device, + dtype=prompt_embeds.dtype, + requires_grad=True + ) + attention_mask = None + + return self.connector(prompt_embeds), attention_mask + + + def forward(self, conversations, device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + self.mllm_backbone = self.mllm_backbone.to(device) + + tokenize_func = self.get_tokenize_fn() + tokenizer = self.get_tokenizer() + conversations = [con.item() for con in conversations] + caption = [con["text"] for con in conversations] + video = [con["video"] for con in conversations] + audio = [con["audio"] for con in conversations if "audio" in con] + #start_time = time.time() + inputs = tokenize_func( + tokenizer, caption, video, audio + ) + + inputs = to_device(inputs,device,dtype = torch.bfloat16) + + + prompt_embeds, attention_mask = self.encode_condition(**inputs) + #print("prompt_embeds.shape:",prompt_embeds.shape,flush=True) + + return [prompt_embeds, torch.ones(prompt_embeds.shape[0], 1).to(device)] diff --git a/ThinkSound/models/meta_queries/models/__init__.py b/ThinkSound/models/meta_queries/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ThinkSound/models/meta_queries/models/process_audio_info.py b/ThinkSound/models/meta_queries/models/process_audio_info.py new file mode 100644 index 0000000000000000000000000000000000000000..b246ed8ae322ad37907e124f4f4fec0725f7a4c1 --- /dev/null +++ b/ThinkSound/models/meta_queries/models/process_audio_info.py @@ -0,0 +1,94 @@ +import base64 +from io import BytesIO + +import audioread +import av +import librosa +import numpy as np + + +SAMPLE_RATE=16000 +def _check_if_video_has_audio(video_path): + container = av.open(video_path) + audio_streams = [stream for stream in container.streams if stream.type == "audio"] + if not audio_streams: + return False + return True + + +def process_audio_info(conversations: list[dict] | list[list[dict]], use_audio_in_video: bool): + """ + Read and process audio info + + Support dict keys: + + type = audio + - audio + - audio_start + - audio_end + + type = video + - video + - video_start + - video_end + """ + audios = [] + if isinstance(conversations[0], dict): + conversations = [conversations] + for conversation in conversations: + for message in conversation: + if not isinstance(message["content"], list): + continue + for ele in message["content"]: + if ele["type"] == "audio": + if "audio" in ele or "audio_url" in ele: + path = ele.get("audio", ele.get("audio_url")) + audio_start = ele.get("audio_start", 0.0) + audio_end = ele.get("audio_end", None) + if isinstance(path, np.ndarray): + if path.ndim > 1: + raise ValueError("Support only mono audio") + audios.append( + path[int(SAMPLE_RATE * audio_start) : None if audio_end is None else int(SAMPLE_RATE * audio_end)] + ) + continue + elif path.startswith("data:audio"): + _, base64_data = path.split("base64,", 1) + data = BytesIO(base64.b64decode(base64_data)) + elif path.startswith("http://") or path.startswith("https://"): + data = audioread.ffdec.FFmpegAudioFile(path) + elif path.startswith("file://"): + data = path[len("file://") :] + else: + data = path + else: + raise ValueError("Unknown audio {}".format(ele)) + elif use_audio_in_video and ele["type"] == "video": + if "video" in ele or "video_url" in ele: + path = ele.get("video", ele.get("video_url")) + audio_start = ele.get("video_start", 0.0) + audio_end = ele.get("video_end", None) + assert _check_if_video_has_audio( + path + ), "Video must has audio track when use_audio_in_video=True" + if path.startswith("http://") or path.startswith("https://"): + data = audioread.ffdec.FFmpegAudioFile(path) + elif path.startswith("file://"): + data = path[len("file://") :] + else: + data = path + else: + raise ValueError("Unknown video {}".format(ele)) + else: + continue + audios.append( + librosa.load( + data, + sr=SAMPLE_RATE, + offset=audio_start, + duration=(audio_end - audio_start) if audio_end is not None else None, + )[0] + ) + if len(audios) == 0: + audios = None + return audios diff --git a/ThinkSound/models/meta_queries/models/qwen25VL.py b/ThinkSound/models/meta_queries/models/qwen25VL.py new file mode 100644 index 0000000000000000000000000000000000000000..a1428efc1e00a1fae338c6eac8f169aad19e7a49 --- /dev/null +++ b/ThinkSound/models/meta_queries/models/qwen25VL.py @@ -0,0 +1,2112 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_vl.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, + logging, + replace_return_docstrings, +) +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.layers.rotary import apply_rotary_emb + +else: + flash_attn_varlen_func = None + apply_rotary_emb = None + + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward +else: + flash_attn_varlen_func = None + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + + +class Qwen2_5_VLMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2_5_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VLPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +def apply_rotary_pos_emb_flashatt( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + q_embed = apply_rotary_emb(q.float(), cos, sin).type_as(q) + k_embed = apply_rotary_emb(k.float(), cos, sin).type_as(k) + return q_embed, k_embed + + +class Qwen2_5_VLVisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) + q = q.squeeze(0) + k = k.squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + orig_q_dtype = q.dtype + orig_k_dtype = k.dtype + q, k = q.float(), k.float() + cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + q_embed = q_embed.to(orig_q_dtype) + k_embed = k_embed.to(orig_k_dtype) + return q_embed, k_embed + + +class Qwen2_5_VLVisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5_VLVisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " + "removed and `position_embeddings` will be mandatory." + ) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + cos = emb.cos().float() + sin = emb.sin().float() + else: + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_vision(q, k, cos, sin) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +QWEN2_5_VL_VISION_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLVisionAttention, + "flash_attention_2": Qwen2_5_VLVisionFlashAttention2, + "sdpa": Qwen2_5_VLVisionSdpaAttention, +} + + +class Qwen2_5_VLVisionBlock(nn.Module): + def __init__(self, config, attn_implementation: str = "sdpa") -> None: + super().__init__() + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.attn = QWEN2_5_VL_VISION_ATTENTION_CLASSES[attn_implementation]( + config.hidden_size, num_heads=config.num_heads + ) + self.mlp = Qwen2_5_VLMLP(config, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + position_embeddings=position_embeddings, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +Qwen2_5_VL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2_5_VLConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + Qwen2_5_VL_START_DOCSTRING, +) +class Qwen2_5_VLPreTrainedModel(PreTrainedModel): + config_class = Qwen2_5_VLConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): + config_class = Qwen2_5_VLVisionConfig + _no_split_modules = ["Qwen2_5_VLVisionBlock"] + + def __init__(self, config, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList( + [Qwen2_5_VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)] + ) + self.merger = Qwen2_5_VLPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings + ) + else: + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) + + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +class Qwen2_5_VLRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for thw grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2_5_VLAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2_5_VLFlashAttention2(Qwen2_5_VLAttention): + """ + Qwen2_5_VL flash attention module, following Qwen2_5_VL attention module. This module inherits from `Qwen2_5_VLAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2_5_VLSdpaAttention(Qwen2_5_VLAttention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2_5_VLModel is using Qwen2_5_VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_5_VL_ATTENTION_CLASSES = { + "eager": Qwen2_5_VLAttention, + "flash_attention_2": Qwen2_5_VLFlashAttention2, + "sdpa": Qwen2_5_VLSdpaAttention, +} + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_5_VL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@add_start_docstrings( + "The bare Qwen2_5_VL Model outputting raw hidden-states without any specific head on top.", + Qwen2_5_VL_START_DOCSTRING, +) +class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): + def __init__(self, config: Qwen2_5_VLConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5_VL. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5_VLConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2_5_VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +QWEN2_5_VL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + pixel_values (`torch.FloatTensor` of shape `(seq_length, num_channels * image_size * image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing images. + pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses + [`Qwen2_5_VLImageProcessor`] for processing videos. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. +""" + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2_5_VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embeddin for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + #self.config.vision_config.tokens_per_second = self.config.vision_config.tokens_per_second.to(self.device) + time_tensor = expanded_range * second_per_grid_t.to("cpu") * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. + # (we can't check exception 3 while compiling) + # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and + # generate the first token for each sequence. Later use the generated Input ids for continuation. + if past_key_values is not None: + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif ( + inputs_embeds is not None # Exception 1 + or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3 + ): + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = inputs_embeds.shape + device = inputs_embeds.device + else: + batch_size, sequence_length = input_ids.shape + device = input_ids.device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "cache_position": cache_position, + "second_per_grid_ts": second_per_grid_ts, + } + ) + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." + ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] diff --git a/ThinkSound/models/meta_queries/models/qwen25omni.py b/ThinkSound/models/meta_queries/models/qwen25omni.py new file mode 100644 index 0000000000000000000000000000000000000000..8d1a31c3277dae84cd603bc65957903f4e2f7c62 --- /dev/null +++ b/ThinkSound/models/meta_queries/models/qwen25omni.py @@ -0,0 +1,4649 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2_5_omni.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Parameter + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.utils import ( + auto_docstring, + check_torch_load_is_safe, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + is_torch_flex_attn_available, + logging, +) +from transformers.utils.hub import cached_file +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniAudioEncoderConfig, + Qwen2_5OmniBigVGANConfig, + Qwen2_5OmniConfig, + Qwen2_5OmniDiTConfig, + Qwen2_5OmniTalkerConfig, + Qwen2_5OmniTextConfig, + Qwen2_5OmniThinkerConfig, + Qwen2_5OmniToken2WavConfig, + Qwen2_5OmniVisionEncoderConfig, +) + + +if is_flash_attn_2_available(): + from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func + from flash_attn.layers.rotary import apply_rotary_emb +else: + flash_attn_varlen_func = None + apply_rotary_emb = None + + +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from transformers.integrations.flex_attention import make_flex_block_causal_mask + + +if is_flash_attn_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +@auto_docstring +class Qwen2_5OmniPreTrainedModel(PreTrainedModel): + config_class = Qwen2_5OmniConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2_5OmniDecoderLayer", "Qwen2_5OmniVisionBlock"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + + def _init_weights(self, module): + # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only + # inference and fine-tuning - so the proper init weights code has been removed + std = self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02 + + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + if module.weight is not None: + module.weight.data.fill_(1.0) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, Qwen2RMSNorm): + module.weight.data.fill_(1.0) + + +class Qwen2_5OmniPreTrainedModelForConditionalGeneration(Qwen2_5OmniPreTrainedModel): + def _prepare_4d_causal_attention_mask_with_cache_position( + self, + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + def get_llm_pos_ids_for_vision( + self, + start_idx: int, + vision_idx: int, + spatial_merge_size: int, + t_index: List[int], + grid_hs: List[int], + grid_ws: List[int], + ): + llm_pos_ids_list = [] + llm_grid_h = grid_hs[vision_idx] // spatial_merge_size + llm_grid_w = grid_ws[vision_idx] // spatial_merge_size + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten() + t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long() + _llm_pos_ids = torch.stack([t_index, h_index, w_index]) + llm_pos_ids_list.append(_llm_pos_ids + start_idx) # + 1 ) # 12.09 by malinhan + llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1) + return llm_pos_ids + + def get_chunked_index( + self, token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int + ) -> list[tuple[int, int]]: + """ + Splits token index list into chunks based on token value ranges. + + Given a list of token indices, returns a list of (start, end) index tuples representing + slices of the list where the token values fall within successive ranges of `t_ntoken_per_chunk`. + + For example, if `t_ntoken_per_chunk` is 1000, the function will create chunks such that: + - the first chunk contains token values < 1000, + - the second chunk contains values >= 1000 and < 2000, and so on. + + Parameters: + token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of + token index values. + t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold). + remove_index (`int`) An index id to subtract from `token_indices` before chunking + + Returns: + `List[Tuple[int, int]]`: A list of tuples, each representing the start (inclusive) + and end (exclusive) indices of a chunk in `token_indices`. + """ + + def _iter(): + i, start_idx = 0, 0 # skip bos token + current_chunk = 1 + while i < len(token_indices): # skip eos token + if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk: + yield (start_idx, i) + start_idx = i + current_chunk += 1 + i += 1 + yield (start_idx, len(token_indices)) + + return list(_iter()) + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + use_audio_in_video: bool = False, + audio_seqlens: Optional[torch.LongTensor] = None, + second_per_grids: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. + vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + use_audio_in_video (`bool`, *optional*): + If set to `True`, use the audio in video. + audio_seqlens (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + second_per_grids (`torch.LongTensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + audio_token_id = self.config.audio_token_id + vision_start_token_id = self.config.vision_start_token_id + audio_start_token_id = self.config.audio_start_token_id + position_id_per_seconds = self.config.position_id_per_seconds + seconds_per_chunk = self.config.seconds_per_chunk + + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_idx, video_idx, audio_idx = 0, 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums, audio_nums = 0, 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + audio_nums = torch.sum(input_ids == audio_start_token_id) + image_nums = (vision_tokens == image_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums + multimodal_nums = ( + image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums + ) + for _ in range(multimodal_nums): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if audio_token_id in input_tokens and remain_audios > 0: + ed_audio = input_tokens.index(audio_token_id, st) + else: + ed_audio = len(input_tokens) + 1 + min_ed = min(ed_image, ed_video, ed_audio) + if min_ed == ed_audio: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 + llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + audio_len + eos_len + audio_idx += 1 + remain_audios -= 1 + + elif min_ed == ed_image: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + grid_t = image_grid_thw[image_idx][0] + grid_hs = image_grid_thw[:, 1] + grid_ws = image_grid_thw[:, 2] + t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).long() + llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + image_len + eos_len + image_idx += 1 + remain_images -= 1 + + elif min_ed == ed_video and not use_audio_in_video: + text_len = min_ed - st - 1 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + t_index = ( + torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds + ).long() + llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + llm_pos_ids_list.append(llm_pos_ids) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len + video_len + eos_len + video_idx += 1 + remain_videos -= 1 + + elif min_ed == ed_video and use_audio_in_video: + text_len = min_ed - st - 2 + if text_len != 0: + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + bos_len = 1 + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 + audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + grid_t = video_grid_thw[video_idx][0] + grid_hs = video_grid_thw[:, 1] + grid_ws = video_grid_thw[:, 2] + + t_index = ( + torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds + ).long() + video_llm_pos_ids = self.get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + + t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk) + video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx) + audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx) + sub_len = 0 + for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))): + video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None + audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None + if video_chunk_index is not None: + sub_len += video_chunk_index[1] - video_chunk_index[0] + + llm_pos_ids_list.append( + video_llm_pos_ids[:, video_chunk_index[0] : video_chunk_index[1]] + ) + if audio_chunk_index is not None: + sub_len += audio_chunk_index[1] - audio_chunk_index[0] + + llm_pos_ids_list.append( + audio_llm_pos_ids[:, audio_chunk_index[0] : audio_chunk_index[1]] + ) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + eos_len = 1 + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + + st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 + + audio_idx += 1 + video_idx += 1 + remain_videos -= 1 + remain_audios -= 1 + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids)) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + + return position_ids, mrope_position_deltas + else: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) + + return position_ids, mrope_position_deltas + + +############################ +# Start Thinker # +############################ + + +@dataclass +class Qwen2_5OmniThinkerCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2.5OmniThinker causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2_5OmniAudioAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Qwen2_5OmniAudioEncoderConfig, + ): + super().__init__() + self.embed_dim = config.d_model + self.num_heads = config.encoder_attention_heads + self.dropout = config.attention_dropout + self.head_dim = self.embed_dim // self.num_heads + self.config = config + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = False + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + seq_length, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + + query_states = query_states.transpose(0, 1) + key_states = key_states.transpose(0, 1) + value_states = value_states.transpose(0, 1) + attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim) + + attention_mask = torch.full( + [1, seq_length, key_states.shape[1]], + torch.finfo(query_states.dtype).min, + device=query_states.device, + dtype=query_states.dtype, + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype) + + attn_output = torch.matmul(attn_weights, value_states).transpose(0, 1).reshape(seq_length, self.embed_dim) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Qwen2_5OmniAudioFlashAttention2(Qwen2_5OmniAudioAttention): + """ + Qwen2.5OmniThinker flash attention module. This module inherits from `Qwen2_5OmniAudioAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + seq_length, all_dim = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.reshape(seq_length, self.num_heads, -1) + + key_states = self.k_proj(hidden_states) + key_states = key_states.reshape(seq_length, self.num_heads, -1) + value_states = self.v_proj(hidden_states) + value_states = value_states.reshape(seq_length, self.num_heads, -1) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func( + query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0 + ) + attn_output = attn_output.reshape(seq_length, all_dim) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class Qwen2_5OmniAudioSdpaAttention(Qwen2_5OmniAudioAttention): + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + seq_length, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) + + attention_mask = torch.zeros( + [1, seq_length, key_states.shape[0]], device=query_states.device, dtype=torch.bool + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. + query_states = query_states.transpose(0, 1) + key_states = key_states.transpose(0, 1) + value_states = value_states.transpose(0, 1) + + # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, + # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + ) + attn_output = attn_output.transpose(0, 1) + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(seq_length, self.embed_dim) + attn_output = self.out_proj(attn_output) + return attn_output + + +QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES = { + "eager": Qwen2_5OmniAudioAttention, + "flash_attention_2": Qwen2_5OmniAudioFlashAttention2, + "sdpa": Qwen2_5OmniAudioSdpaAttention, +} + + +class Qwen2_5OmniAudioEncoderLayer(nn.Module): + def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): + super().__init__() + self.embed_dim = config.d_model + self.self_attn = QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES[config._attn_implementation](config) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + if hidden_states.dtype == torch.float16: + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + return outputs + + +class SinusoidsPositionEmbedding(nn.Module): + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + if channels % 2 != 0: + raise ValueError("SinusoidsPositionEmbedding needs even channels input") + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + self.register_buffer( + "positional_embedding", + torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), + persistent=False, + ) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +@auto_docstring( + custom_intro=""" + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`Qwen2_5OmniAudioEncoderLayer`]. + """ +) +class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): + config_class = Qwen2_5OmniAudioEncoderConfig + main_input_name = "input_features" + _no_split_modules = ["Qwen2_5OmniAudioEncoderLayer"] + _supports_sdpa = True + + def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): + super().__init__(config) + self.dropout = config.dropout + + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + self.n_window = config.n_window + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim) + self.audio_bos_eos_token = nn.Embedding(2, config.output_dim) + self.layers = nn.ModuleList([Qwen2_5OmniAudioEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.ln_post = nn.LayerNorm(config.d_model) + self.avg_pooler = nn.AvgPool1d(2, stride=2) + self.proj = nn.Linear(config.d_model, config.output_dim) + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + @auto_docstring + def forward( + self, + input_features, + feature_lens=None, + aftercnn_lens=None, + ): + r""" + input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a + `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding + and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + feature_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length + aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`): + mel length after cnn + """ + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum(), + dtype=torch.long, + device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths) + + chunk_list = input_features.split(chunk_lengths.tolist(), dim=1) + padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function( + chunk_list, chunk_lengths, padding_value=0, padding_side="right" + ) + padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask + padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2) + + padded_embed = padded_embed + self.positional_embedding.positional_embedding[ + : padded_embed.shape[1], : + ].unsqueeze(0).to(padded_embed.dtype) + hidden_states = padded_embed[padded_mask_after_cnn] + cu_seqlens = torch.cat( + ( + torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32), + padded_mask_after_cnn.sum(1).cumsum(0), + ) + ).to(torch.int32) + + for idx, encoder_layer in enumerate(self.layers): + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + cu_seqlens, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens, + ) + + hidden_states = layer_outputs[0] + + hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) + token_audio_list = [] + for each_audio_states in hidden_states_list: + each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1) + each_audio_states = self.ln_post(each_audio_states) + each_audio_states = self.proj(each_audio_states) + token_audio_list.append(each_audio_states) + token_audio = torch.cat(token_audio_list, dim=0) + return BaseModelOutput(last_hidden_state=token_audio) + + def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): + """ + Pads a sequence of tensors to their maximum length on indicated `padding_side`. + Then prepares a mask so that pad tokens are not attended to. + """ + max_len = tensor_len.max() + dim = tensor_list[0].shape[0] + padded_tensor = torch.full( + size=(len(tensor_list), dim, max_len), + fill_value=padding_value, + dtype=self.dtype, + device=tensor_list[0].device, + ) + + batch_mask = torch.zeros( + (len(tensor_len), max_len), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(tensor_len): + batch_mask[i, :length] = 1 + padded_tensor[i, :, :length] = tensor_list[i] + + feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 + max_len_after_cnn = feature_lens_after_cnn.max() + batch_mask_after_cnn = torch.zeros( + (len(tensor_len), max_len_after_cnn), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(feature_lens_after_cnn): + batch_mask_after_cnn[i, :length] = 1 + return ( + padded_tensor, + batch_mask.unsqueeze(1), + batch_mask_after_cnn.bool(), + ) + + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + cos = freqs.cos() + sin = freqs.sin() + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + output = (tensor * cos) + (rotate_half(tensor) * sin) + output = output.to(orig_dtype) + return output + + +class Qwen2_5OmniVisionAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.full( + [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype + ) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 + + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5OmniVisionFlashAttention2(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.proj = nn.Linear(dim, dim) + + def _apply_rotary_pos_emb_flashatt(self, tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + tensor_ = tensor.float() + cos = freqs.cos() # .type_as(tensor_) + sin = freqs.sin() # .type_as(tensor_) + output = apply_rotary_emb(tensor_, cos, sin).type_as(tensor) + return output + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + q = self._apply_rotary_pos_emb_flashatt(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = self._apply_rotary_pos_emb_flashatt(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5OmniVisionSdpaAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) + k = self.k(hidden_states).reshape(seq_length, self.num_heads, -1) + v = self.v(hidden_states).reshape(seq_length, self.num_heads, -1) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + q = q.transpose(0, 1) + k = k.transpose(0, 1) + v = v.transpose(0, 1) + attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = attn_output.transpose(0, 1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + return attn_output + + +class Qwen2_5OmniMLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = { + "eager": Qwen2_5OmniVisionAttention, + "flash_attention_2": Qwen2_5OmniVisionFlashAttention2, + "sdpa": Qwen2_5OmniVisionSdpaAttention, +} + + +class Qwen2_5OmniVisionBlock(nn.Module): + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None: + super().__init__() + self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) + self.attn = QWEN2_5_OMNI_VISION_ATTENTION_CLASSES[config._attn_implementation]( + config.hidden_size, num_heads=config.num_heads + ) + self.mlp = Qwen2_5OmniMLP(config, bias=True) + + def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Qwen2_5_VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Qwen2_5OmniPatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.hidden_size, self.hidden_size), + nn.GELU(), + nn.Linear(self.hidden_size, dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + return x + + +class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): + config_class = Qwen2_5OmniVisionEncoderConfig + _no_split_modules = ["Qwen2_5OmniVisionBlock"] + + def __init__(self, config: Qwen2_5OmniVisionEncoderConfig, *inputs, **kwargs) -> None: + super().__init__(config, *inputs, **kwargs) + self.spatial_merge_size = config.spatial_merge_size + self.patch_size = config.patch_size + self.fullatt_block_indexes = config.fullatt_block_indexes + self.window_size = config.window_size + self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=config.patch_size, + temporal_patch_size=config.temporal_patch_size, + in_channels=config.in_channels, + embed_dim=config.hidden_size, + ) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + self.blocks = nn.ModuleList([Qwen2_5OmniVisionBlock(config) for _ in range(config.depth)]) + self.merger = Qwen2_5OmniPatchMerger( + dim=config.out_hidden_size, + context_dim=config.hidden_size, + spatial_merge_size=config.spatial_merge_size, + ) + self.gradient_checkpointing = False + + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.spatial_merge_size, + grid_w // self.spatial_merge_size, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: + """ + Args: + hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`): + The final hidden states of the model. + grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`): + The temporal, height and width of feature shape of each image in LLM. + + Returns: + `torch.Tensor`: hidden_states. + """ + hidden_states = self.patch_embed(hidden_states) + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # Modification here + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb + ) + else: + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + ) + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + + return hidden_states + + +class Qwen2_5OmniRotaryEmbedding(nn.Module): + def __init__(self, config: Qwen2_5OmniThinkerConfig, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen2_5Omni has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/). + + Explanation: + Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding + sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For + vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately. + Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding. + For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal, + height and width) of text embedding is always the same, so the text embedding rotary position embedding has no + difference with modern LLMs. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + mrope_section(`List(int)`): + Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Qwen2_5OmniAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2_5OmniConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2MLP(nn.Module): + def __init__(self, config, bias: bool = False): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class Qwen2_5OmniFlashAttention2(Qwen2_5OmniAttention): + """ + Qwen2_5Omni flash attention module, following Qwen2_5Omni attention module. This module inherits from `Qwen2_5OmniAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + else: + sliding_window = None + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2_5OmniSdpaAttention(Qwen2_5OmniAttention): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2_5OmniModel is using Qwen2_5OmniSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +QWEN2_5_OMNI_ATTENTION_CLASSES = { + "eager": Qwen2_5OmniAttention, + "flash_attention_2": Qwen2_5OmniFlashAttention2, + "sdpa": Qwen2_5OmniSdpaAttention, +} + + +class Qwen2_5OmniDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5OmniTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + self.self_attn = QWEN2_5_OMNI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@auto_docstring +class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): + config_class = Qwen2_5OmniTextConfig + _no_split_modules = ["Qwen2_5OmniDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2_5OmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen25OmniThinker. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5OmniConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen25OmniThinkerConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 + ) + text_config = config.get_text_config() + if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( + cache_position.reshape(-1, 1) - text_config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +@auto_docstring( + custom_intro=""" + The Qwen2.5OmniThinker model which consists of a audio backbone and a language model. + """ +) +class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin): + config_class = Qwen2_5OmniThinkerConfig + base_model_prefix = "thinker" + _no_split_modules = ["Qwen2_5OmniAudioEncoder", "Qwen2_5OmniVisionEncoder"] + + def __init__(self, config: Qwen2_5OmniThinkerConfig): + super().__init__(config) + self.audio_tower = Qwen2_5OmniAudioEncoder._from_config( + config.audio_config, attn_implementation=config._attn_implementation + ) + + self.visual = Qwen2_5OmniVisionEncoder._from_config( + config.vision_config, attn_implementation=config._attn_implementation + ) + + self.vocab_size = config.text_config.vocab_size + self.model = Qwen2_5OmniThinkerTextModel._from_config( + config.text_config, attn_implementation=config._attn_implementation + ) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.spatial_merge_size = config.vision_config.spatial_merge_size + self.rope_deltas = None + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + return video_embeds + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + return image_embeds + + def get_audio_features( + self, + input_features: torch.FloatTensor, + feature_attention_mask: Optional[torch.LongTensor] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + ): + """ + Encodes audios into continuous embeddings that can be forwarded to the language model. + + Args: + input_features (`torch.FloatTensor`): + The tensors corresponding to the input audios. + feature_attention_mask (`torch.LongTensor`, *optional*): + Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: + audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + """ + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + else: + audio_feature_lengths = None + + audio_feat_lengths, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths( + audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + ) + feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + audio_outputs = self.audio_tower( + input_features, + feature_lens=feature_lens, + aftercnn_lens=audio_feat_lengths, + ) + audio_features = audio_outputs.last_hidden_state + + if audio_features.shape[0] != sum(audio_output_lengths.tolist()): + raise ValueError("length of audio_features should match audio_output_lengths") + + return audio_features + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + feature_attention_mask: Optional[torch.Tensor] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_audio_in_video: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + video_second_per_grid: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]: + r""" + input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, feature_sequence_length)`): + Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a + tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size), *optional*): + The tensors corresponding to the input videos. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`NewTaskModelProcessor`] uses + [`SiglipImageProcessor`] for processing videos). + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*): + Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_audio_in_video (`bool`, *optional*): + Whether or not use audio track in video, should same as the parameter in `process_audio_info`. + video_second_per_grid (`torch.LongTensor` of shape `(num_videos)`, *optional*): + Number of seconds per grid for each video, used for temporal feature mapping. + + Example: + + ```python + >>> from io import BytesIO + >>> from urllib.request import urlopen + >>> import librosa + >>> from qwen_vl_utils import process_vision_info + >>> from transformers import Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration + + >>> thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-Omni-7B") + >>> processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B") + + >>> conversations = [ + >>> {'role': 'system', 'content': 'You are a helpful voice chat bot, and please respond to me in a casual conversation manner using random voice.'}, + >>> {"role": "user", "content": [ + >>> {"type": "image", "image_url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + >>> {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"}, + >>> ]}, + >>> ] + + >>> text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) + >>> audios = [ librosa.load(BytesIO(urlopen( conversations[1]['content'][1]['audio_url'] ).read()), sr=self.processor.feature_extractor.sampling_rate) ] + >>> images, videos = process_vision_info(conversations) + >>> inputs = processor(text=text, audios=audios, images=images, videos=videos, return_tensors="pt", padding=True) + + >>> # Generate + >>> inputs['use_audio_in_video'] = `True` or `False` + >>> generation = thinker.generate(**inputs, max_new_tokens=2048) + >>> generate_ids = generation[:, inputs.input_ids.size(1):] + + >>> response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + # 1. Extract the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. Merge text , audios , image and video + if input_ids is not None and input_ids.shape[1] != 1: # Prefill stage + if input_features is not None: + audio_features = self.get_audio_features( + input_features, + feature_attention_mask=feature_attention_mask, + audio_feature_lengths=audio_feature_lengths, + ) + audio_mask = ( + (input_ids == self.config.audio_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) + #print("audio_mask",audio_mask.shape ,torch.sum(audio_mask).item()) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + if feature_attention_mask is not None: + audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) + else: + audio_feature_lengths = None + + if attention_mask is not None and position_ids is None: + if ( + cache_position is None + or (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + ): + delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1) + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask, + use_audio_in_video, + audio_feature_lengths, + video_second_per_grid, + ) + rope_deltas = rope_deltas - delta0 + self.rope_deltas = rope_deltas + else: + batch_size, seq_length = input_ids.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.get_text_config().vocab_size + ) + + if not return_dict: + output = (logits,) + outputs + return (loss,) + output if loss is not None else output + + return Qwen2_5OmniThinkerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + input_features=None, + feature_attention_mask=None, + use_audio_in_video=False, + video_second_per_grid=None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + input_features=input_features, + feature_attention_mask=feature_attention_mask, + use_audio_in_video=use_audio_in_video, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) + + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + + return model_inputs + + +############################ +# Start Talker # +############################ + + +@dataclass +class Qwen2_5OmniTalkerCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2.5OmniTalker causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + thinker_reply_part: torch.FloatTensor = None + + +@auto_docstring +class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): + config_class = Qwen2_5OmniTalkerConfig + _no_split_modules = ["Qwen2_5OmniTalkerDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniTalkerConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.embedding_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2_5OmniDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2_5Omni. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + config: Qwen2_5OmniConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`Qwen2_5OmniConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape( + -1, 1 + ) + text_config = config.get_text_config() + if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= ( + cache_position.reshape(-1, 1) - text_config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class Qwen2_5OmniTalkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForConditionalGeneration, GenerationMixin): + config_class = Qwen2_5OmniTalkerConfig + base_model_prefix = "talker" + + def __init__(self, config: Qwen2_5OmniTalkerConfig): + super().__init__(config) + + self.thinker_to_talker_proj = nn.Linear(config.embedding_size, config.hidden_size) + + self.model = Qwen2_5OmniTalkerModel(config) + self.codebook_size = config.vocab_size + self.codec_head = nn.Linear(config.hidden_size, self.codebook_size, bias=False) + + self.codec_bos_token = config.tts_codec_start_token_id + self.codec_eos_token = config.tts_codec_end_token_id + self.codec_pad_token = config.tts_codec_pad_token_id + self.codec_mask_token = config.tts_codec_mask_token_id + + self.text_bos_token = config.tts_text_start_token_id + self.text_eos_token = config.tts_text_end_token_id + self.text_pad_token = config.tts_text_pad_token_id + + self.spatial_merge_size = self.config.spatial_merge_size + self.rope_deltas = None + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + thinker_reply_part: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + input_text_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + use_audio_in_video: Optional[bool] = None, + audio_feature_lengths: Optional[torch.LongTensor] = None, + video_second_per_grid: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Qwen2_5OmniTalkerCausalLMOutputWithPast]: + r""" + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*): + The length of feature shape of each audio in LLM. + thinker_reply_part (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Hidden states from the thinker model's output that represent the text reply part to be processed. + input_text_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Input token IDs for text-only content, used for position calculation in multimodal contexts. + use_audio_in_video (`bool`, *optional*): + Whether or not use audio track in video, should same as the parameter in `process_audio_info`. + video_second_per_grid (`torch.LongTensor` of shape `(num_videos)`, *optional*): + Number of seconds per grid for each video, used for temporal feature mapping. + + Example: + + ```python + >>> from io import BytesIO + >>> from urllib.request import urlopen + >>> import librosa + >>> from transformers import AutoProcessor, Qwen2_5OmniTalkerForConditionalGeneration + + >>> model = Qwen2_5OmniTalkerForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B") + + >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:" + >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3" + >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=self.processor.feature_extractor.sampling_rate) + + >>> inputs = processor(text=prompt, audios=audio, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Generate the caption in English: Glass is breaking." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attention_mask is not None and position_ids is None: + if ( + cache_position is None + or (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + ): + position_ids, rope_deltas = self.get_rope_index( + input_text_ids, + image_grid_thw, + video_grid_thw, + attention_mask, + use_audio_in_video, + audio_feature_lengths, + video_second_per_grid, + ) + + inputs_embeds[:, -1, :] += self.get_input_embeddings()( + torch.tensor([self.codec_bos_token], dtype=torch.long, device=inputs_embeds.device) + ) + inputs_embeds[:, -2, :] += self.get_input_embeddings()( + torch.tensor([self.codec_pad_token], dtype=torch.long, device=inputs_embeds.device) + ) + self.rope_deltas = rope_deltas + + else: + batch_size, seq_length = input_ids.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=input_ids.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + if inputs_embeds is None: + # 1. Inference tokens after second token + codec_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds = codec_embeds + thinker_reply_part[:, :1, :] + if thinker_reply_part.shape[1] > 1: + thinker_reply_part = thinker_reply_part[:, 1:, :] + + talker_lm_input = self.thinker_to_talker_proj(inputs_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=talker_lm_input, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.codec_head(hidden_states) + logits = logits.float() + + loss = None + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5OmniTalkerCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + thinker_reply_part=thinker_reply_part, + ) + + def _get_initial_cache_position(self, seq_length, device, model_kwargs): + # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily + inputs_embeds = model_kwargs.pop("inputs_embeds") + model_kwargs = super()._get_initial_cache_position(seq_length, device, model_kwargs) + model_kwargs["inputs_embeds"] = inputs_embeds + return model_kwargs + + # prepare inputs for talker lm generation + def prepare_inputs_for_generation( + self, + input_ids, + input_text_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + thinker_reply_part=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + input_audio_features=None, + audio_feature_attention_mask=None, + audio_feature_lengths=None, + use_audio_in_video=False, + video_second_per_grid=None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values, + attention_mask, + inputs_embeds, + cache_position, + use_cache=use_cache, + thinker_reply_part=thinker_reply_part, + input_text_ids=input_text_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_audio_in_video=use_audio_in_video, + audio_feature_lengths=audio_feature_lengths, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) + + model_inputs["position_ids"] = None + + return model_inputs + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder, num_new_tokens + ) + + if getattr(outputs, "thinker_reply_part", None) is not None: + model_kwargs["thinker_reply_part"] = outputs.thinker_reply_part + + return model_kwargs + + +############################ +# Start Token2Wav # +############################ + + +# Using custom RoPE, will use LlamaRotaryEmbedding next version +class Qwen2_5OmniDiTRotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000): + super().__init__() + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + def forward(self, x): + batch_size, seq_len = x.shape[0], x.shape[1] + t = torch.arange(seq_len, device=x.device) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = t.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float() + freqs = torch.stack((freqs, freqs), dim=-1) + freqs = freqs.reshape(*freqs.shape[:-2], -1) + freqs = freqs.repeat(batch_size, *([1] * freqs.dim())) + cos = freqs.cos() + sin = freqs.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class TimeDelayNetBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + dilation, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + padding="same", + padding_mode="reflect", + ) + self.activation = nn.ReLU() + + def forward(self, hidden_states: torch.Tensor): + return self.activation(self.conv(hidden_states)) + + +class Res2NetBlock(torch.nn.Module): + def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1): + super().__init__() + + in_channel = in_channels // scale + hidden_channel = out_channels // scale + + self.blocks = nn.ModuleList( + [ + TimeDelayNetBlock( + in_channel, + hidden_channel, + kernel_size=kernel_size, + dilation=dilation, + ) + for i in range(scale - 1) + ] + ) + self.scale = scale + + def forward(self, hidden_states): + outputs = [] + for i, hidden_part in enumerate(torch.chunk(hidden_states, self.scale, dim=1)): + if i == 0: + output_part = hidden_part + elif i == 1: + output_part = self.blocks[i - 1](hidden_part) + else: + output_part = self.blocks[i - 1](hidden_part + output_part) + outputs.append(output_part) + output = torch.cat(outputs, dim=1) + return output + + +class SqueezeExcitationBlock(nn.Module): + def __init__(self, in_channels, se_channels, out_channels): + super().__init__() + + self.conv1 = nn.Conv1d( + in_channels=in_channels, + out_channels=se_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv1d( + in_channels=se_channels, + out_channels=out_channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, hidden_states): + hidden_states_mean = hidden_states.mean(dim=2, keepdim=True) + + hidden_states_mean = self.relu(self.conv1(hidden_states_mean)) + hidden_states_mean = self.sigmoid(self.conv2(hidden_states_mean)) + + return hidden_states * hidden_states_mean + + +class AttentiveStatisticsPooling(nn.Module): + """This class implements an attentive statistic pooling layer for each channel. + It returns the concatenated mean and std of the input tensor. + """ + + def __init__(self, channels, attention_channels=128): + super().__init__() + + self.eps = 1e-12 + self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1) + self.tanh = nn.Tanh() + self.conv = nn.Conv1d( + in_channels=attention_channels, + out_channels=channels, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def _length_to_mask(self, length, max_len=None, dtype=None, device=None): + """Creates a binary mask for each sequence. + + Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3 + + Arguments + --------- + length : torch.LongTensor + Containing the length of each sequence in the batch. Must be 1D. + max_len : int + Max length for the mask, also the size of the second dimension. + dtype : torch.dtype, default: None + The dtype of the generated mask. + device: torch.device, default: None + The device to put the mask variable. + + Returns + ------- + mask : tensor + The binary mask. + """ + + if max_len is None: + max_len = length.max().long().item() # using arange to generate mask + mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand( + len(length), max_len + ) < length.unsqueeze(1) + + mask = torch.as_tensor(mask, dtype=dtype, device=device) + return mask + + def _compute_statistics(self, x, m, dim=2): + mean = (m * x).sum(dim) + std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps)) + return mean, std + + def forward(self, hidden_states): + seq_length = hidden_states.shape[-1] + lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device) + + # Make binary mask of shape [N, 1, L] + mask = self._length_to_mask( + lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device + ) + mask = mask.unsqueeze(1) + + # Expand the temporal context of the pooling layer by allowing the + # self-attention to look at global properties of the utterance. + total = mask.sum(dim=2, keepdim=True) + + mean, std = self._compute_statistics(hidden_states, mask / total) + mean = mean.unsqueeze(2).repeat(1, 1, seq_length) + std = std.unsqueeze(2).repeat(1, 1, seq_length) + attention = torch.cat([hidden_states, mean, std], dim=1) + + # Apply layers + attention = self.conv(self.tanh(self.tdnn(attention))) + + # Filter out zero-paddings + attention = attention.masked_fill(mask == 0, float("-inf")) + + attention = F.softmax(attention, dim=2) + mean, std = self._compute_statistics(hidden_states, attention) + # Append mean and std of the batch + pooled_stats = torch.cat((mean, std), dim=1) + pooled_stats = pooled_stats.unsqueeze(2) + + return pooled_stats + + +class SqueezeExcitationRes2NetBlock(nn.Module): + """An implementation of building block in ECAPA-TDNN, i.e., + TDNN-Res2Net-TDNN-SqueezeExcitationBlock. + """ + + def __init__( + self, + in_channels, + out_channels, + res2net_scale=8, + se_channels=128, + kernel_size=1, + dilation=1, + ): + super().__init__() + self.out_channels = out_channels + self.tdnn1 = TimeDelayNetBlock( + in_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation) + self.tdnn2 = TimeDelayNetBlock( + out_channels, + out_channels, + kernel_size=1, + dilation=1, + ) + self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels) + + def forward(self, hidden_state): + residual = hidden_state + + hidden_state = self.tdnn1(hidden_state) + hidden_state = self.res2net_block(hidden_state) + hidden_state = self.tdnn2(hidden_state) + hidden_state = self.se_block(hidden_state) + + return hidden_state + residual + + +class ECAPA_TimeDelayNet(torch.nn.Module): + """An implementation of the speaker embedding model in a paper. + "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in + TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143). + """ + + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len( + config.enc_dilations + ): + raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length") + self.channels = config.enc_channels + self.blocks = nn.ModuleList() + + # The initial TDNN layer + self.blocks.append( + TimeDelayNetBlock( + config.mel_dim, + config.enc_channels[0], + config.enc_kernel_sizes[0], + config.enc_dilations[0], + ) + ) + + # SE-Res2Net layers + for i in range(1, len(config.enc_channels) - 1): + self.blocks.append( + SqueezeExcitationRes2NetBlock( + config.enc_channels[i - 1], + config.enc_channels[i], + res2net_scale=config.enc_res2net_scale, + se_channels=config.enc_se_channels, + kernel_size=config.enc_kernel_sizes[i], + dilation=config.enc_dilations[i], + ) + ) + + # Multi-layer feature aggregation + self.mfa = TimeDelayNetBlock( + config.enc_channels[-1], + config.enc_channels[-1], + config.enc_kernel_sizes[-1], + config.enc_dilations[-1], + ) + + # Attentive Statistical Pooling + self.asp = AttentiveStatisticsPooling( + config.enc_channels[-1], + attention_channels=config.enc_attention_channels, + ) + + # Final linear transformation + self.fc = nn.Conv1d( + in_channels=config.enc_channels[-1] * 2, + out_channels=config.enc_dim, + kernel_size=1, + padding="same", + padding_mode="reflect", + ) + + def forward(self, hidden_states): + # Minimize transpose for efficiency + hidden_states = hidden_states.transpose(1, 2) + + hidden_states_list = [] + for layer in self.blocks: + hidden_states = layer(hidden_states) + hidden_states_list.append(hidden_states) + + # Multi-layer feature aggregation + hidden_states = torch.cat(hidden_states_list[1:], dim=1) + hidden_states = self.mfa(hidden_states) + + # Attentive Statistical Pooling + hidden_states = self.asp(hidden_states) + + # Final linear transformation + hidden_states = self.fc(hidden_states) + + hidden_states = hidden_states.squeeze(-1) + return hidden_states + + +class DiTInputEmbedding(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + self.proj = nn.Linear( + config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim, + config.hidden_size, + ) + self.spk_encoder = ECAPA_TimeDelayNet(config) + + def forward( + self, + hidden_states: torch.Tensor, + speaker_embedding: torch.Tensor, + condition_vector: torch.Tensor, + code_embed: torch.Tensor, + drop_audio_cond: Optional[bool] = False, + code_embed_uncond: Optional[bool] = None, + apply_cfg: Optional[bool] = True, + ): + if apply_cfg: + hidden_states = torch.cat([hidden_states, hidden_states], dim=0) + speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0) + condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0) + code_embed = torch.cat([code_embed, code_embed_uncond], dim=0) + elif drop_audio_cond: # cfg for cond audio + condition_vector = torch.zeros_like(condition_vector) + speaker_embedding = torch.zeros_like(speaker_embedding) + condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1) + hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1)) + + return hidden_states + + +# Transformer backbone using DiT blocks +class DiTCodecEmbedding(nn.Module): + def __init__(self, codec_num_embeds, codec_dim, repeats): + super().__init__() + self.repeats = repeats + self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim) + + def forward(self, code, drop_code=False): + if drop_code: + code = torch.zeros_like(code) + code_embed = self.codec_embed(code) + + code_embed = torch.repeat_interleave(code_embed, repeats=self.repeats, dim=1) + return code_embed + + +# AdaLayerNormZero +# return with modulated x for attn input, and params for later mlp modulation +class Qwen2_5_OmniAdaLayerNormZero(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 6) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, hidden_states, emb=None): + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1) + + hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +# AdaLayerNormZero for final layer +# return only with modulated x for attn input, cuz no more mlp modulation +class Qwen2_5_OmniAdaLayerNormZero_Final(nn.Module): + def __init__(self, dim): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2) + + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + + def forward(self, hidden_states, emb): + emb = self.linear(self.silu(emb)) + scale, shift = torch.chunk(emb, 2, dim=1) + + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + return hidden_states + + +# FeedForward +class DiTMLP(nn.Module): + def __init__(self, dim, mult=4, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + + self.ff = nn.ModuleList( + [ + nn.Linear(dim, inner_dim), + nn.GELU(approximate="tanh"), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + ] + ) + + def forward(self, hidden_states): + for layer in self.ff: + hidden_states = layer(hidden_states) + return hidden_states + + +# Modified from Llama with a different rotate function, will fixed in next release +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + + def rotate_half_codec(x): + # x = rearrange(x, "... (d r) -> ... d r", r=2) + x = x.reshape(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.reshape(*x.shape[:-2], -1) + + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half_codec(q) * sin) + k_embed = (k * cos) + (rotate_half_codec(k) * sin) + return q_embed, k_embed + + +class DiTAttention(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__() + + self.config = config + self.dim = config.hidden_size + self.heads = config.num_attention_heads + self.inner_dim = config.head_dim * config.num_attention_heads + self.dropout = config.dropout + self._attn_implementation = config._attn_implementation + self.is_causal = False + + self.to_q = nn.Linear(config.hidden_size, self.inner_dim) + self.to_k = nn.Linear(config.hidden_size, self.inner_dim) + self.to_v = nn.Linear(config.hidden_size, self.inner_dim) + + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)]) + + def forward( + self, + hidden_states, # noised input x + position_embeddings=None, # rotary position embedding for x + attention_mask=None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + # attention + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # apply rotary position embedding + # Due to training process, only first head is applied with RoPE, will be fixed at next release + cos, sin = position_embeddings + query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) + + attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation] + attention_weights, _ = attention_interface( + self, + query, + key, + value, + attention_mask=attention_mask, + is_causal=False, + ) + + # mask. e.g. inference got a batch with different target durations, mask out the padding + attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim) + attention_weights = attention_weights.to(query.dtype) + + # linear proj + attention_output = self.to_out[0](attention_weights) + attention_output = self.to_out[1](attention_output) + + return attention_output + + +# time step conditioning embedding +class SinusPositionEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, hidden_states, scale=1000): + device = hidden_states.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) + emb = scale * hidden_states.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb.type_as(hidden_states) + + +class DiTTimestepEmbedding(nn.Module): + def __init__(self, dim, freq_embed_dim=256): + super().__init__() + self.time_embed = SinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)]) + + def forward(self, timestep): # noqa: F821 + time_hidden = self.time_embed(timestep) + time_hidden = time_hidden.to(timestep.dtype) + for layer in self.time_mlp: + time_hidden = layer(time_hidden) # b d + return time_hidden + + +class DiTDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5OmniDiTConfig, look_ahead_block=0, look_backward_block=0): + super().__init__() + self.attn_norm = Qwen2_5_OmniAdaLayerNormZero(config.hidden_size) + + self.attn = DiTAttention(config) + self.look_ahead_block = look_ahead_block + self.look_backward_block = look_backward_block + self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout) + + def forward( + self, hidden_states, timestep, position_embeddings=None, block_diff=None + ): # x: noised input, t: time embedding + # pre-norm & modulation for attention input + norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep) + + # attention + attn_output = self.attn( + hidden_states=norm, + position_embeddings=position_embeddings, + attention_mask=(block_diff >= -float(self.look_backward_block)) + & (block_diff <= float(self.look_ahead_block)), + ) + + # process attention output for input x + hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output + + norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ff_output = self.ff(norm) + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + + return hidden_states + + +class SnakeBeta(nn.Module): + """ + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + """ + + def __init__(self, in_features, alpha=1.0): + super().__init__() + self.in_features = in_features + + # initialize alpha + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + + self.no_div_by_zero = 0.000000001 + + def forward(self, hidden_states): + """ + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + """ + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + alpha = torch.exp(alpha) + beta = torch.exp(beta) + hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow( + torch.sin(hidden_states * alpha), 2 + ) + + return hidden_states + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + """Generates a 1D Kaiser-windowed sinc filter. + + Args: + cutoff (float): Normalized cutoff frequency (0 to 0.5). + half_width (float): Transition bandwidth. + kernel_size (int): Number of filter taps. + + Returns: + torch.Tensor: A tensor of shape (1, 1, kernel_size) representing the filter. + """ + is_even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + + # Compute Kaiser window parameters + delta_f = 4 * half_width + attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + + if attenuation > 50.0: + beta = 0.1102 * (attenuation - 8.7) + elif attenuation >= 21.0: + beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0) + else: + beta = 0.0 + + kaiser_window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32) + + # Compute time indices + if is_even: + time_indices = torch.arange(-half_size, half_size) + 0.5 + else: + time_indices = torch.arange(kernel_size) - half_size + + # Compute sinc filter + if cutoff == 0: + return torch.zeros((1, 1, kernel_size), dtype=torch.float32) # Ensures correct shape + + sinc_filter = torch.sinc(2 * cutoff * time_indices) + normalized_filter = 2 * cutoff * kaiser_window * sinc_filter + + # Normalize to ensure sum = 1 (avoid leakage of constant component) + normalized_filter /= normalized_filter.sum() + + return normalized_filter.view(1, 1, kernel_size) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + channels = hidden_states.shape[1] + + hidden_states = F.pad(hidden_states, (self.pad, self.pad), mode="replicate") + hidden_states = self.ratio * F.conv_transpose1d( + hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels + ) + hidden_states = hidden_states[..., self.pad_left : -self.pad_right] + + return hidden_states + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + cutoff = 0.5 / ratio + half_width = 0.6 / ratio + + if cutoff < 0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = ratio + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter, persistent=False) + + def forward(self, hidden_states): + channels = hidden_states.shape[1] + hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate") + out = F.conv1d(hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels) + return out + + +class TorchActivation1d(nn.Module): + def __init__( + self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + ): + super().__init__() + if not callable(activation): + raise ValueError("Activation function must be callable") + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, hidden_states): + hidden_states = self.upsample(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.downsample(hidden_states) + + return hidden_states + + +class AMPBlock(torch.nn.Module): + def __init__( + self, + channels, + kernel_size=3, + dilation=(1, 3, 5), + ): + super().__init__() + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=self._get_padding(kernel_size, dilation[0]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=self._get_padding(kernel_size, dilation[1]), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=self._get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=self._get_padding(kernel_size, 1), + ), + ] + ) + + self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers + + self.activations = nn.ModuleList( + [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)] + ) + + def _get_padding(self, kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + def forward(self, hidden_states): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2): + residual = hidden_states + hidden_states = act1(hidden_states) + hidden_states = conv1(hidden_states) + hidden_states = act2(hidden_states) + hidden_states = conv2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2WavBigVGAN model. Which take mel spectrogram as input and predict waveform. + """ +) +class Qwen2_5OmniToken2WavBigVGANModel(Qwen2_5OmniPreTrainedModel): + config_class = Qwen2_5OmniBigVGANConfig + + def __init__(self, config: Qwen2_5OmniBigVGANConfig): + super().__init__(config) + self.num_residual_blocks = len(config.resblock_kernel_sizes) + self.num_upsample_layers = len(config.upsample_rates) + + self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 7, 1, padding=3) + + # Removing extra ModuleList breaks official state dict + ups = [ + nn.ModuleList( + [ + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**layer_idx), + config.upsample_initial_channel // (2 ** (layer_idx + 1)), + kernel_size, + stride, + padding=(kernel_size - stride) // 2, + ) + ] + ) + for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)) + ] + self.ups = nn.ModuleList(ups) + + self.resblocks = nn.ModuleList( + [ + AMPBlock(config.upsample_initial_channel // (2 ** (layer_idx + 1)), kernel_size, dilation) + for layer_idx in range(self.num_upsample_layers) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes) + ] + ) + + self.activation_post = TorchActivation1d( + activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers)) + ) + self.conv_post = nn.Conv1d( + config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False + ) + + def normalize_spectrogram(self, spectrogram, max_value, min_db): + return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value) + + def amplitude_to_db(self, amplitude, min_db_level): + min_level = torch.exp( + torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype) + ) + return 20 * torch.log10(torch.clamp(amplitude, min=min_level)) + + def process_mel_spectrogram(self, mel_spectrogram): + amplitude_spectrum = torch.exp(mel_spectrogram) + decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20 + return self.normalize_spectrogram(decibel_spectrum, 1, -115) + + def forward(self, mel_spectrogram): + processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram) + hidden_representation = self.conv_pre(processed_spectrogram) + + for layer_index in range(self.num_upsample_layers): + hidden_representation = self.ups[layer_index][0](hidden_representation) + residual_output = sum( + self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation) + for block_index in range(self.num_residual_blocks) + ) + residual_output = residual_output / self.num_residual_blocks + hidden_representation = residual_output + + hidden_representation = self.activation_post(hidden_representation) + output_waveform = self.conv_post(hidden_representation) + return torch.clamp(output_waveform, min=-1.0, max=1.0).squeeze().cpu() + + +class RungeKutta4ODESolver: + def __init__(self, function, initial_value): + self.function = function + self.initial_value = initial_value + + self._one_third = 1 / 3 + self._two_thirds = 2 / 3 + + def _rk4_step(self, function, time_start, time_step, time_end, value_start, function_value_start=None): + k1 = function_value_start if function_value_start is not None else function(time_start, value_start) + k2 = function(time_start + time_step * self._one_third, value_start + time_step * k1 * self._one_third) + k3 = function(time_start + time_step * self._two_thirds, value_start + time_step * (k2 - k1 * self._one_third)) + k4 = function(time_end, value_start + time_step * (k1 - k2 + k3)) + return (k1 + 3 * (k2 + k3) + k4) * time_step / 8 + + def _compute_step(self, function, time_start, time_step, time_end, value_start): + function_value_start = function(time_start, value_start) + return self._rk4_step( + function, time_start, time_step, time_end, value_start, function_value_start=function_value_start + ), function_value_start + + def _linear_interpolation(self, time_start, time_end, value_start, value_end, time_point): + if time_point == time_start: + return value_start + if time_point == time_end: + return value_end + weight = (time_point - time_start) / (time_end - time_start) + return value_start + weight * (value_end - value_start) + + def integrate(self, time_points): + solution = torch.empty( + len(time_points), + *self.initial_value.shape, + dtype=self.initial_value.dtype, + device=self.initial_value.device, + ) + solution[0] = self.initial_value + + current_index = 1 + current_value = self.initial_value + for time_start, time_end in zip(time_points[:-1], time_points[1:]): + time_step = time_end - time_start + delta_value, _ = self._compute_step(self.function, time_start, time_step, time_end, current_value) + next_value = current_value + delta_value + + while current_index < len(time_points) and time_end >= time_points[current_index]: + solution[current_index] = self._linear_interpolation( + time_start, time_end, current_value, next_value, time_points[current_index] + ) + current_index += 1 + + current_value = next_value + + return solution + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2WavDiT model. Which take speech tokens as input and predict mel spectrogram. + """ +) +class Qwen2_5OmniToken2WavDiTModel(Qwen2_5OmniPreTrainedModel): + config_class = Qwen2_5OmniDiTConfig + _no_split_modules = ["DiTDecoderLayer"] + + def __init__(self, config: Qwen2_5OmniDiTConfig): + super().__init__(config) + self.mel_dim = config.mel_dim + self.repeats = config.repeats + self.time_embed = DiTTimestepEmbedding(config.hidden_size) + + self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats) + self.input_embed = DiTInputEmbedding(config) + + self.rotary_embed = Qwen2_5OmniDiTRotaryEmbedding(config.head_dim) + + self.hidden_size = config.hidden_size + self.layers = config.num_hidden_layers + self.block_size = config.block_size + self.num_attention_heads = config.num_attention_heads + + self.transformer_blocks = nn.ModuleList() + for i in range(config.num_hidden_layers): + self.transformer_blocks.append( + DiTDecoderLayer( + config, + look_ahead_block=1 if i in config.look_ahead_layers else 0, + look_backward_block=1 if i in config.look_backward_layers else 0, + ) + ) + + self.norm_out = Qwen2_5_OmniAdaLayerNormZero_Final(config.hidden_size) # final modulation + self.proj_out = nn.Linear(config.hidden_size, config.mel_dim) + + def _create_block_diff(self, hidden_states): + batch, seq_len = hidden_states.shape[0], hidden_states.shape[1] + block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size # [seq_length] + + block_i = block_indices.unsqueeze(1) # [seq_length, 1] + block_j = block_indices.unsqueeze(0) # [1, seq_length] + block_diff = block_j - block_i # (n, n) + + return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len) + + def forward( + self, + hidden_states, + condition_vector, + speaker_embedding, + quantized_code, + time_step, + drop_audio_conditioning=False, + drop_code=False, + apply_cfg=True, + ): + batch_size = hidden_states.shape[0] + if time_step.ndim == 0: + time_step = time_step.repeat(batch_size) + + # Compute embeddings + time_embedding = self.time_embed(time_step) + text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code) + text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None + + hidden_states = self.input_embed( + hidden_states, + speaker_embedding, + condition_vector, + text_embedding, + drop_audio_cond=drop_audio_conditioning, + code_embed_uncond=text_embedding_unconditioned, + apply_cfg=apply_cfg, + ) + + # Compute positional encodings + position_embeddings = self.rotary_embed(hidden_states) + blockwise_difference = self._create_block_diff(hidden_states) + + # Transformer blocks + for transformer_block in self.transformer_blocks: + hidden_states = transformer_block( + hidden_states, + time_embedding, + position_embeddings=position_embeddings, + block_diff=blockwise_difference, + ) + + hidden_states = self.norm_out(hidden_states, time_embedding) + output = self.proj_out(hidden_states) + + return output + + @torch.no_grad() + def sample( + self, + conditioning_vector, + reference_mel_spectrogram, + quantized_code, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + ): + noise_initialization = torch.randn([1, 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype) + maximum_duration = quantized_code.shape[1] * self.repeats + initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device) + batch_size = reference_mel_spectrogram.shape[0] + conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1) + + if batch_size != 1: + raise ValueError("Only batch size = 1 is currently supported") + + def ode_function(time_step, hidden_states): + if guidance_scale < 1e-5: + prediction = self( + hidden_states=hidden_states, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + quantized_code=quantized_code, + time_step=time_step, + drop_audio_conditioning=False, + drop_code=False, + ) + return prediction + + model_output = self( + hidden_states=hidden_states, + quantized_code=quantized_code, + speaker_embedding=conditioning_vector, + condition_vector=reference_mel_spectrogram, + time_step=time_step, + apply_cfg=True, + ) + guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0) + return guided_prediction + (guided_prediction - null_prediction) * guidance_scale + + initial_time = 0 + time_embedding = torch.linspace( + initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype + ) + + if sway_coefficient is not None: + time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding) + + ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state) + solution_trajectory = ode_solver.integrate(time_embedding) + + generated_waveform = solution_trajectory[-1] + generated_mel_spectrogram = generated_waveform.permute(0, 2, 1) + return generated_mel_spectrogram + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni Token2Wav model. Consists a DiT model take speech tokens as input and predict mel spectrogram and a BigVGAN vocoder take mel spectrogram as input and predict waveform. + """ +) +class Qwen2_5OmniToken2WavModel(Qwen2_5OmniPreTrainedModel): + config_class = Qwen2_5OmniToken2WavConfig + base_model_prefix = "model" + _no_split_modules = ["Qwen2_5OmniToken2WavDiTModel", "Qwen2_5OmniToken2WavBigVGANModel"] + + def __init__(self, config: Qwen2_5OmniToken2WavConfig): + super().__init__(config) + attn_impl = config._attn_implementation + if config._attn_implementation == "flash_attention_2": + logger.warning_once( + "Qwen2_5OmniToken2WavModel must inference with fp32, but flash_attention_2 only supports fp16 and bf16, " + "attention implementation of Qwen2_5OmniToken2WavModel will fallback to sdpa." + ) + attn_impl = "sdpa" + elif config._attn_implementation == "eager": + logger.warning_once( + "Qwen2_5OmniToken2WavModel does not support eager attention implementation, fall back to sdpa" + ) + attn_impl = "sdpa" + self.code2wav_dit_model = Qwen2_5OmniToken2WavDiTModel._from_config( + config.dit_config, attn_implementation=attn_impl + ) + self.code2wav_bigvgan_model = Qwen2_5OmniToken2WavBigVGANModel._from_config( + config.bigvgan_config, attn_implementation=attn_impl + ) + + def forward( + self, + code, + conditioning, + reference_mel, + num_steps=10, + guidance_scale=0.5, + sway_coefficient=-1.0, + **kwargs, + ): + """Generates a waveform from input code and conditioning parameters.""" + + mel_spectrogram = self.code2wav_dit_model.sample( + conditioning, + reference_mel, + code, + num_steps=num_steps, + guidance_scale=guidance_scale, + sway_coefficient=sway_coefficient, + ) + + waveform = self.code2wav_bigvgan_model(mel_spectrogram) + + return waveform + + +############################ +# Start Qwen2.5Omni # +############################ + + +@auto_docstring( + custom_intro=""" + The full Qwen2.5Omni model, a multimodal model composed of 3 sub-models: + - [`Qwen2_5OmniThinkerForConditionalGeneration`]: + a causal auto-regressive transformer takes text, audio, image, video as input and predict text tokens. + - [`Qwen2_5OmniTalkerForConditionalGeneration`]: + a causal auto-regressive transformer takes thinker hidden states and response as input and predict speech tokens. + - [`Qwen2_5OmniToken2WavModel`]: + a DiT model take speech tokens as input and predict mel spectrogram and a BigVGAN vocoder take mel spectrogram as input and predict waveform. + """ +) +class Qwen2_5OmniForConditionalGeneration(Qwen2_5OmniPreTrainedModel, GenerationMixin): + config_class = Qwen2_5OmniConfig + _no_split_modules = [ + "Qwen2_5OmniTalkerForConditionalGeneration", + "Qwen2_5OmniToken2WavModel", + ] + + def __init__(self, config): + super().__init__(config) + + self.thinker = Qwen2_5OmniThinkerForConditionalGeneration(config.thinker_config) + + self.has_talker = config.enable_audio_output + self.speaker_map = {} + #if config.enable_audio_output: + # self.enable_talker() + self.post_init() + + def enable_talker(self): + self.talker = Qwen2_5OmniTalkerForConditionalGeneration(self.config.talker_config) + self.token2wav = Qwen2_5OmniToken2WavModel(self.config.token2wav_config) + self.token2wav.float() + self.has_talker = True + + def load_speakers(self, path): + check_torch_load_is_safe() + for key, value in torch.load(path, weights_only=True).items(): + self.speaker_map[key] = value + logger.info("Speaker {} loaded".format(list(self.speaker_map.keys()))) + + def disable_talker(self): + if hasattr(self, "talker"): + del self.talker + if hasattr(self, "token2wav"): + del self.token2wav + self.has_talker = False + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + *model_args, + config=None, + cache_dir=None, + ignore_mismatched_sizes=False, + force_download=False, + local_files_only=False, + token=None, + revision="main", + use_safetensors=None, + weights_only=True, + **kwargs, + ): + model = super().from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + weights_only=weights_only, + **kwargs, + ) + spk_path = cached_file( + pretrained_model_name_or_path, + "spk_dict.pt", + subfolder=kwargs.pop("subfolder", None), + cache_dir=kwargs.pop("cache_dir", None), + force_download=kwargs.pop("force_download", False), + proxies=kwargs.pop("proxies", None), + resume_download=kwargs.pop("resume_download", None), + local_files_only=kwargs.pop("local_files_only", False), + token=kwargs.pop("use_auth_token", None), + revision=kwargs.pop("revision", None), + ) + if spk_path is None: + raise ValueError(f"""{pretrained_model_name_or_path}/{spk_path} not exists""") + model.load_speakers(spk_path) + + return model + + @torch.no_grad() + # TODO: raushan, defaults should be saved in generation config + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + speaker: str = "Chelsie", + use_audio_in_video: bool = False, + return_audio: Optional[bool] = None, + thinker_max_new_tokens: int = 1024, + talker_max_new_tokens: int = 4096, + talker_do_sample: bool = True, + talker_top_k: int = 40, + talker_top_p: float = 0.8, + talker_temperature: float = 0.9, + talker_eos_token_id: list[int] = [8292, 8294], + talker_repetition_penalty: float = 1.05, + **kwargs, + ): + r""" + Generate text response and audio from input. + + Args: + input_ids (`Optional[torch.Tensor]`, *optional*): + Input ids, should obtain from processor. + speaker (`str` , defaults to "Chelsie"): + Which speaker should be used in audio response. + use_audio_in_video (`bool`, defaults to False): + Whether or not use audio track in video, should same as the parameter in `process_audio_info`. + return_audio (`Optional[bool]`, *optional*): + Whether or not return response in audio format. When `return_audio=None`, this parameter is same as `config.enable_audio_output`. + kwargs (*optional*): + - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model. + - With a *thinker_*, *talker_*, *token2wav_* prefix, they will be input for the `generate` method of the + thinker, talker and token2wav respectively. It has the priority over the keywords without a prefix. + Returns: + When `return_audio=False`: + - **Text** (`torch.Tensor`): Generated text token sequence. + When `return_audio=True`: + - **Text** (`torch.Tensor`): Generated text token sequence. + - **Audio waveform** (`torch.Tensor`): Generated audio waveform. + """ + if speaker not in self.speaker_map: + raise ValueError(f"{speaker} is not available, available speakers: {self.speaker_map.keys()}") + if return_audio and not self.has_talker: + raise ValueError( + "Cannot use talker when talker module not initialized. Use `enable_talker` method or set enable_talker in config to enable talker." + ) + if return_audio is None: + return_audio = self.has_talker + if input_ids.shape[0] != 1 and return_audio: + raise NotImplementedError("Qwen2.5-Omni currently does not support batched inference with audio output") + + shared_kwargs = {"use_audio_in_video": use_audio_in_video} + thinker_kwargs = { + "max_new_tokens": thinker_max_new_tokens, + } + talker_kwargs = { + "max_new_tokens": talker_max_new_tokens, + "do_sample": talker_do_sample, + "top_k": talker_top_k, + "top_p": talker_top_p, + "temperature": talker_temperature, + "eos_token_id": talker_eos_token_id, + "repetition_penalty": talker_repetition_penalty, + } + token2wav_kwargs = {} + + for key, value in kwargs.items(): + if key.startswith("thinker_"): + thinker_kwargs[key[len("thinker_") :]] = value + elif key.startswith("talker_"): + talker_kwargs[key[len("talker_") :]] = value + elif key.startswith("token2wav_"): + token2wav_kwargs[key[len("token2wav_") :]] = value + # Process special input values + elif key == "feature_attention_mask": + thinker_kwargs[key] = value + talker_kwargs["audio_feature_lengths"] = torch.sum(value, dim=1) + elif key == "input_features" or key == "attention_mask": + thinker_kwargs[key] = value + # Put other key to shared kwargs + else: + shared_kwargs[key] = value + + # Merge kwargs + for key, value in shared_kwargs.items(): + if key not in thinker_kwargs: + thinker_kwargs[key] = value + if key not in talker_kwargs: + talker_kwargs[key] = value + if key not in token2wav_kwargs: + token2wav_kwargs[key] = value + speaker_params = self.speaker_map[speaker] + + # 1. Generate from thinker module + generate_audio = return_audio and self.has_talker + if generate_audio: + thinker_kwargs["output_hidden_states"] = True + thinker_kwargs["return_dict_in_generate"] = True + + thinker_result = self.thinker.generate(input_ids=input_ids, **thinker_kwargs) + + if not generate_audio: + return thinker_result + + # 2. Generate speech tokens from talker module + embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(self.talker.device) + if thinker_kwargs.get("input_features", None) is not None: + audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index + audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + audio_mask_tensor = torch.zeros( + [audio_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(audio_mask, audio_mask_tensor) + if thinker_kwargs.get("pixel_values", None) is not None: + image_ids_mask = input_ids == self.config.thinker_config.image_token_index + image_mask = image_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + image_mask_tensor = torch.zeros( + [image_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(image_mask, image_mask_tensor) + if thinker_kwargs.get("pixel_values_videos", None) is not None: + video_ids_mask = input_ids == self.config.thinker_config.video_token_index + video_mask = video_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) + video_mask_tensor = torch.zeros( + [video_ids_mask.sum(), embeds_to_talker.shape[-1]], + dtype=embeds_to_talker.dtype, + device=self.talker.device, + ) + embeds_to_talker.masked_scatter_(video_mask, video_mask_tensor) + + processed_thinker_hidden = ( + (embeds_to_talker,) + thinker_result.hidden_states[0][1:], + ) + thinker_result.hidden_states[1:] + thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device) + thinker_token_embeds = [ + token_hidden_states[0].to(self.talker.device) for token_hidden_states in processed_thinker_hidden + ] + thinker_hidden_states = [ + token_hidden_states[-1].to(self.talker.device) for token_hidden_states in processed_thinker_hidden + ] + + talker_text_bos_token = speaker_params["bos_token"] + talker_input_text_ids = torch.cat( + [ + input_ids.to(self.talker.device), + torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.talker.device), + thinker_generate_ids[:, :1], + ], + dim=-1, + ) + + talker_input_ids = torch.cat( + [ + torch.full_like(input_ids, fill_value=self.talker.codec_mask_token, device=self.talker.device), + torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=self.talker.device), + torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=self.talker.device), + ], + dim=1, + ) + + thinker_embed_tokens = self.thinker.get_input_embeddings() + thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1) + talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0] + talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device) + talker_text_bos_embed = thinker_embed_tokens(talker_text_bos_token).to(self.talker.device) + talker_inputs_embeds = torch.cat( + [ + talker_inputs_embeds, + talker_text_bos_embed, + thinker_reply_part[:, :1, :], + ], + dim=1, + ) + + eos_embedding = thinker_embed_tokens( + torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device) + ).to(self.talker.device) + + pad_embedding = thinker_embed_tokens( + torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=self.thinker.device) + ).to(self.talker.device) + + thinker_reply_part = torch.cat( + [ + thinker_reply_part[:, 1:, :], + eos_embedding, + pad_embedding, + ], + dim=1, + ) + + talker_attention_mask = None + if "attention_mask" in kwargs: + talker_attention_mask = torch.cat( + [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1 + ).to(self.talker.device) + + talker_result = self.talker.generate( + input_ids=talker_input_ids, + input_text_ids=talker_input_text_ids, + thinker_reply_part=thinker_reply_part, + inputs_embeds=talker_inputs_embeds, + attention_mask=talker_attention_mask, + suppress_tokens=[self.talker.codec_bos_token], + **{k: (v.to(self.talker.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()}, + ) + talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1] + + # 3. Generate wavs from code + if self.token2wav.dtype != torch.float: + self.token2wav.float() + + wav = self.token2wav( + talker_generate_codes.to(self.token2wav.device), + conditioning=speaker_params["cond"].to(self.token2wav.device).float(), + reference_mel=speaker_params["ref_mel"].to(self.token2wav.device).float(), + **token2wav_kwargs, + ) + + return thinker_result.sequences, wav.float() + + +__all__ = [ + "Qwen2_5OmniForConditionalGeneration", + "Qwen2_5OmniThinkerTextModel", + "Qwen2_5OmniThinkerForConditionalGeneration", + "Qwen2_5OmniTalkerModel", + "Qwen2_5OmniTalkerForConditionalGeneration", + "Qwen2_5OmniToken2WavDiTModel", + "Qwen2_5OmniToken2WavBigVGANModel", + "Qwen2_5OmniToken2WavModel", + "Qwen2_5OmniPreTrainedModel", + "Qwen2_5OmniPreTrainedModelForConditionalGeneration", +] diff --git a/ThinkSound/models/meta_queries/transformer_encoder.py b/ThinkSound/models/meta_queries/transformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4f4968824ca0580a1f4712d693ed8481b775ddd2 --- /dev/null +++ b/ThinkSound/models/meta_queries/transformer_encoder.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from typing import Optional, Tuple + +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2PreTrainedModel, + Qwen2Attention, + Qwen2MLP, + Qwen2RMSNorm, + Qwen2RotaryEmbedding, + repeat_kv, + apply_rotary_pos_emb, +) +from transformers.integrations.sdpa_attention import sdpa_attention_forward +from torch.nn import functional as F + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim, heads=1): + super().__init__() + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(heads, 1, dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.gamma * self.scale + + +class Qwen2BidirectionalSdpaAttention(Qwen2Attention): + """ + An SDPA-based attention that does NOT apply causal masking. + Inherits from Qwen2Attention, but sets self.is_causal = False. + """ + + def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.is_causal = False + self.qk_norm = config.qk_norm + if self.qk_norm: + self.q_norm = MultiHeadRMSNorm( + config.hidden_size // config.num_attention_heads, + config.num_attention_heads, + ) + self.k_norm = MultiHeadRMSNorm( + config.hidden_size // config.num_attention_heads, + config.num_key_value_heads, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ): + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if self.qk_norm: + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + attn_output, attn_weights = sdpa_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask=None, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + is_causal=False, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + +class Qwen2EncoderLayer(nn.Module): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Qwen2BidirectionalSdpaAttention(config, layer_idx) + self.mlp = Qwen2MLP(config) + + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ): + # Norm + Self-Attn + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Norm + MLP + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen2Encoder(Qwen2PreTrainedModel): + supports_gradient_checkpointing = True + + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.layers = nn.ModuleList( + [Qwen2EncoderLayer(config, i) for i in range(self.config.num_hidden_layers)] + ) + if config.rope: + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + else: + self.rotary_emb = None + if hasattr(config, "norm") and config.norm: + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = None + self.gradient_checkpointing = True + self.post_init() + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + + def forward(self, hidden_states): + bsz, seq_len, _ = hidden_states.size() + position_ids = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0) + + # Compute RoPE embeddings once, shared across layers + if self.rotary_emb is not None: + position_embeddings = self.rotary_emb(hidden_states, position_ids) + else: + position_embeddings = None + + for layer in self.layers: + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + position_embeddings, + ) + else: + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + ) + if self.norm: + hidden_states = self.norm(hidden_states) + return hidden_states diff --git a/ThinkSound/models/mmdit.py b/ThinkSound/models/mmdit.py new file mode 100644 index 0000000000000000000000000000000000000000..599ca9985778062540c6da2064d30dcec468cd22 --- /dev/null +++ b/ThinkSound/models/mmdit.py @@ -0,0 +1,555 @@ +import logging +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import sys +from .mmmodules.ext.rotary_embeddings import compute_rope_rotations +from .mmmodules.model.embeddings import TimestepEmbedder +from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP +from .mmmodules.model.transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock) +from .utils import resample + +log = logging.getLogger() + + +@dataclass +class PreprocessedConditions: + clip_f: torch.Tensor + sync_f: torch.Tensor + text_f: torch.Tensor + clip_f_c: torch.Tensor + text_f_c: torch.Tensor + + +# Partially from https://github.com/facebookresearch/DiT +class MMAudio(nn.Module): + + def __init__(self, + *, + latent_dim: int, + clip_dim: int, + sync_dim: int, + text_dim: int, + hidden_dim: int, + depth: int, + fused_depth: int, + num_heads: int, + mlp_ratio: float = 4.0, + latent_seq_len: int, + clip_seq_len: int, + sync_seq_len: int, + text_seq_len: int = 77, + latent_mean: Optional[torch.Tensor] = None, + latent_std: Optional[torch.Tensor] = None, + empty_string_feat: Optional[torch.Tensor] = None, + v2: bool = False, + kernel_size: int = 7, + sync_kernel: int = 7, + use_inpaint: bool = False, + use_mlp: bool = False, + cross_attend: bool = False, + add_video: bool = False, + triple_fusion: bool = False, + gated_video: bool = False) -> None: + super().__init__() + + self.v2 = v2 + self.latent_dim = latent_dim + self._latent_seq_len = latent_seq_len + self._clip_seq_len = clip_seq_len + self._sync_seq_len = sync_seq_len + self._text_seq_len = text_seq_len + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.cross_attend = cross_attend + self.add_video = add_video + self.gated_video = gated_video + self.triple_fusion = triple_fusion + self.use_inpaint = use_inpaint + if self.gated_video: + self.gated_mlp = nn.Sequential( + nn.LayerNorm(hidden_dim * 2), + nn.Linear(hidden_dim*2, hidden_dim * 4, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim * 4, hidden_dim, bias=False), + nn.Sigmoid() + ) + # 初始化最后一层权重为零,促进初始均匀融合 + nn.init.zeros_(self.gated_mlp[3].weight) + if self.triple_fusion: + self.gated_mlp_v = nn.Sequential( + nn.LayerNorm(hidden_dim * 3), + nn.Linear(hidden_dim*3, hidden_dim * 4, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim * 4, hidden_dim, bias=False), + nn.Sigmoid() + ) + self.gated_mlp_t = nn.Sequential( + nn.LayerNorm(hidden_dim * 3), + nn.Linear(hidden_dim*3, hidden_dim * 4, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim * 4, hidden_dim, bias=False), + nn.Sigmoid() + ) + # 初始化最后一层权重为零,促进初始均匀融合 + nn.init.zeros_(self.gated_mlp_v[3].weight) + nn.init.zeros_(self.gated_mlp_t[3].weight) + if v2: + padding_size = (kernel_size - 1) // 2 + if use_inpaint: + self.audio_input_proj = nn.Sequential( + ChannelLastConv1d(latent_dim*2, hidden_dim, kernel_size=kernel_size, padding=padding_size), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=kernel_size, padding=padding_size), + ) + else: + self.audio_input_proj = nn.Sequential( + ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=kernel_size, padding=padding_size), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=kernel_size, padding=padding_size), + ) + + self.clip_input_proj = nn.Sequential( + nn.Linear(clip_dim, hidden_dim), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + sync_pad = (sync_kernel - 1) // 2 + self.sync_input_proj = nn.Sequential( + ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=sync_kernel, padding=sync_pad), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.text_input_proj = nn.Sequential( + nn.Linear(text_dim, hidden_dim), + nn.SiLU(), + MLP(hidden_dim, hidden_dim * 4), + ) + else: + self.audio_input_proj = nn.Sequential( + ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3), + nn.SELU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3), + ) + + self.clip_input_proj = nn.Sequential( + nn.Linear(clip_dim, hidden_dim), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.sync_input_proj = nn.Sequential( + ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3), + nn.SELU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.text_input_proj = nn.Sequential( + nn.Linear(text_dim, hidden_dim), + MLP(hidden_dim, hidden_dim * 4), + ) + + self.clip_cond_proj = nn.Linear(hidden_dim, hidden_dim) + if use_mlp: + self.text_cond_proj = nn.Sequential( + nn.Linear(1024, hidden_dim), + MLP(hidden_dim, hidden_dim * 4), + ) + else: + self.text_cond_proj = nn.Linear(1024, hidden_dim) + self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4) + # each synchformer output segment has 8 feature frames + self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, sync_dim))) + + self.final_layer = FinalBlock(hidden_dim, latent_dim) + + if v2: + self.t_embed = TimestepEmbedder(hidden_dim, + frequency_embedding_size=hidden_dim, + max_period=1) + else: + self.t_embed = TimestepEmbedder(hidden_dim, + frequency_embedding_size=256, + max_period=10000) + self.joint_blocks = nn.ModuleList([ + JointBlock(hidden_dim, + num_heads, + mlp_ratio=mlp_ratio, + pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth) + ]) + + self.fused_blocks = nn.ModuleList([ + MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1, cross_attend=cross_attend) + for i in range(fused_depth) + ]) + + if empty_string_feat is None: + empty_string_feat = torch.zeros((77, 1024)) + + empty_t5_feat = torch.zeros((77, 2048)) + + self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False) + self.empty_t5_feat = nn.Parameter(empty_t5_feat, requires_grad=False) + self.empty_audio_feat = nn.Parameter(torch.zeros(1, latent_dim), requires_grad=True) + self.empty_clip_feat = nn.Parameter(torch.zeros(1, clip_dim), requires_grad=True) + self.empty_sync_feat = nn.Parameter(torch.zeros(1, sync_dim), requires_grad=True) + + self.initialize_weights() + self.initialize_rotations() + + def initialize_rotations(self): + base_freq = 1.0 + latent_rot = compute_rope_rotations(self._latent_seq_len, + self.hidden_dim // self.num_heads, + 10000, + freq_scaling=base_freq, + device=self.device) + clip_rot = compute_rope_rotations(self._clip_seq_len, + self.hidden_dim // self.num_heads, + 10000, + freq_scaling=base_freq * self._latent_seq_len / + self._clip_seq_len, + device=self.device) + + self.latent_rot = nn.Buffer(latent_rot, persistent=False) + self.clip_rot = nn.Buffer(clip_rot, persistent=False) + + def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None: + self._latent_seq_len = latent_seq_len + self._clip_seq_len = clip_seq_len + self._sync_seq_len = sync_seq_len + self.initialize_rotations() + + def initialize_weights(self): + + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.joint_blocks: + nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.clip_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.clip_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0) + for block in self.fused_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.conv.weight, 0) + nn.init.constant_(self.final_layer.conv.bias, 0) + + # empty string feat shall be initialized by a CLIP encoder + nn.init.constant_(self.sync_pos_emb, 0) + nn.init.constant_(self.empty_clip_feat, 0) + nn.init.constant_(self.empty_sync_feat, 0) + + def preprocess_conditions(self, clip_f: torch.Tensor, sync_f: torch.Tensor, + text_f: torch.Tensor, t5_features: torch.Tensor, metaclip_global_text_features: torch.Tensor) -> PreprocessedConditions: + """ + cache computations that do not depend on the latent/time step + i.e., the features are reused over steps during inference + """ + # breakpoint() + assert clip_f.shape[1] == self._clip_seq_len, f'{clip_f.shape=} {self._clip_seq_len=}' + assert sync_f.shape[1] == self._sync_seq_len, f'{sync_f.shape=} {self._sync_seq_len=}' + assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}' + + bs = clip_f.shape[0] + + # B * num_segments (24) * 8 * 768 + num_sync_segments = self._sync_seq_len // 8 + sync_f = sync_f.view(bs, num_sync_segments, 8, -1) + self.sync_pos_emb + sync_f = sync_f.flatten(1, 2) # (B, VN, D) + + # extend vf to match x + clip_f = self.clip_input_proj(clip_f) # (B, VN, D) + sync_f = self.sync_input_proj(sync_f) # (B, VN, D) + + if t5_features is not None: + + if metaclip_global_text_features is not None: + text_f_c = self.text_cond_proj(metaclip_global_text_features) # (B, D) + else: + text_f_c = self.text_cond_proj(text_f.mean(dim=1)) # (B, D) + # 计算填充长度 + padding_size = t5_features.size(2) - text_f.size(2) # 渴望填充的数量 + # 当确实需要填充的时候,确保填充是正数 + if padding_size > 0: + # 填充 text_f 的特征维度两侧 + text_f = F.pad(text_f, pad=(0, padding_size), mode='constant', value=0) # 在最后一个维度上进行填充 + else: + text_f = text_f # 如果填充长度不是正数,则不需要填充 + text_concat = torch.cat((text_f, t5_features), dim=1) + text_f = self.text_input_proj(text_concat) # (B, VN, D) + else: + text_f = self.text_input_proj(text_f) # (B, VN, D) + if metaclip_global_text_features is not None: + text_f_c = self.text_cond_proj(metaclip_global_text_features) # (B, D) + else: + text_f_c = self.text_cond_proj(text_f.mean(dim=1)) # (B, D) + + # upsample the sync features to match the audio + sync_f = sync_f.transpose(1, 2) # (B, D, VN) + # sync_f = resample(sync_f, self._latent_seq_len) + sync_f = F.interpolate(sync_f, size=self._latent_seq_len, mode='nearest-exact') + sync_f = sync_f.transpose(1, 2) # (B, N, D) + + # get conditional features from the clip side + clip_f_c = self.clip_cond_proj(clip_f.mean(dim=1)) # (B, D) + + return PreprocessedConditions(clip_f=clip_f, + sync_f=sync_f, + text_f=text_f, + clip_f_c=clip_f_c, + text_f_c=text_f_c) + + def predict_flow(self, latent: torch.Tensor, t: torch.Tensor, + conditions: PreprocessedConditions, inpaint_masked_input=None, cfg_scale:float=1.0,cfg_dropout_prob:float=0.0,scale_phi:float=0.0 + ) -> torch.Tensor: + """ + for non-cacheable computations + """ + # print(f'cfg_scale: {cfg_scale}, cfg_dropout_prob: {cfg_dropout_prob}, scale_phi: {scale_phi}') + assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}' + empty_conditions = None + + clip_f = conditions.clip_f + sync_f = conditions.sync_f + text_f = conditions.text_f + clip_f_c = conditions.clip_f_c + text_f_c = conditions.text_f_c + + # breakpoint() + if inpaint_masked_input is not None: + if inpaint_masked_input.shape[1] != latent.shape[1]: + inpaint_masked_input = inpaint_masked_input.transpose(1,2) + latent = torch.cat([latent,inpaint_masked_input],dim=2) + latent = self.audio_input_proj(latent) # (B, N, D) + global_c = self.global_cond_mlp(clip_f_c + text_f_c) # (B, D) + # global_c = text_f_c + global_c = self.t_embed(t).unsqueeze(1) + global_c.unsqueeze(1) # (B, D) + extended_c = global_c + sync_f + + for block in self.joint_blocks: + latent, clip_f, text_f = block(latent, clip_f, text_f, global_c, extended_c, + self.latent_rot, self.clip_rot) # (B, N, D) + if self.add_video: + if clip_f.shape[1] != latent.shape[1]: + clip_f = resample(clip_f, latent) + + if self.triple_fusion: + text_f = torch.mean(text_f, dim=1, keepdim=True) # (bsz, 1, D) + text_f = text_f.expand(-1,latent.shape[1], -1) # (T_audio, D) + fusion = torch.concat((latent, clip_f, text_f),dim=-1) + gate_v = self.gated_mlp_v(fusion) + gate_t = self.gated_mlp_t(fusion) + # modulated_latent = gate * latent # 非对称设计 + latent = latent + gate_v * clip_f + gate_t * text_f + elif self.gated_video: + fusion = torch.concat((latent, clip_f),dim=-1) + gate = self.gated_mlp(fusion) + modulated_latent = gate * latent # 非对称设计 + latent = latent + modulated_latent + else: + latent = latent + clip_f + + for block in self.fused_blocks: + if self.cross_attend: + latent = block(latent, extended_c, self.latent_rot, context=text_f) + else: + latent = block(latent, extended_c, self.latent_rot) + + # should be extended_c; this is a minor implementation error #55 + flow = self.final_layer(latent, extended_c) # (B, N, out_dim), remove t + return flow + + def forward(self, latent: torch.Tensor, t: torch.Tensor, clip_f: torch.Tensor, sync_f: torch.Tensor, + text_f: torch.Tensor, inpaint_masked_input, t5_features, metaclip_global_text_features, cfg_scale:float,cfg_dropout_prob:float,scale_phi:float,video_dropout_prob:float=0.2) -> torch.Tensor: + """ + latent: (B, N, C) + vf: (B, T, C_V) + t: (B,) + """ + # breakpoint() + # print(f'cfg_scale: {cfg_scale}, cfg_dropout_prob: {cfg_dropout_prob}, scale_phi: {scale_phi}') + if self.use_inpaint and inpaint_masked_input is None: + inpaint_masked_input = torch.zeros_like(latent, device=latent.device) + latent = latent.permute(0, 2, 1) + + if cfg_dropout_prob > 0.0: + bsz = latent.shape[0] + if inpaint_masked_input is not None: + # samples = torch.rand(bsz, device=latent.device) + # null_audio = (samples < cfg_dropout_prob) + # inpaint_masked_input = inpaint_masked_input.transpose(1,2) + # inpaint_masked_input[null_audio] = self.empty_audio_feat + inpaint_masked_input = inpaint_masked_input.transpose(1,2) + null_embed = torch.zeros_like(inpaint_masked_input,device=latent.device) + dropout_mask = torch.bernoulli(torch.full((inpaint_masked_input.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + # inpaint_masked_input = torch.where(dropout_mask, self.empty_audio_feat, inpaint_masked_input) + inpaint_masked_input = torch.where(dropout_mask, null_embed, inpaint_masked_input) + + # samples = torch.rand(bsz, device=latent.device) + # null_video = (samples < cfg_dropout_prob) + # clip_f[null_video] = self.empty_clip_feat + # sync_f[null_video] = self.empty_sync_feat + null_embed = torch.zeros_like(clip_f,device=latent.device) + dropout_mask = torch.bernoulli(torch.full((clip_f.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + # clip_f = torch.where(dropout_mask, null_embed, clip_f) + clip_f = torch.where(dropout_mask, self.empty_clip_feat, clip_f) + null_embed = torch.zeros_like(sync_f,device=latent.device) + dropout_mask = torch.bernoulli(torch.full((sync_f.shape[0], 1, 1), video_dropout_prob, device=latent.device)).to(torch.bool) + # sync_f = torch.where(dropout_mask, null_embed, sync_f) + sync_f = torch.where(dropout_mask, self.empty_sync_feat, sync_f) + # samples = torch.rand(bsz, device=latent.device) + # null_text = (samples < cfg_dropout_prob) + # text_f[null_text] = self.empty_string_feat + null_embed = torch.zeros_like(text_f,device=latent.device) + dropout_mask = torch.bernoulli(torch.full((text_f.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + # text_f = torch.where(dropout_mask, null_embed, text_f) + text_f = torch.where(dropout_mask, self.empty_string_feat, text_f) + if t5_features is not None: + # t5_features[null_text] = self.empty_t5_feat + null_embed = torch.zeros_like(t5_features,device=latent.device) + dropout_mask = torch.bernoulli(torch.full((t5_features.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + # t5_features = torch.where(dropout_mask, null_embed, t5_features) + t5_features = torch.where(dropout_mask, self.empty_t5_feat, t5_features) + if metaclip_global_text_features is not None: + null_embed = torch.zeros_like(metaclip_global_text_features,device=latent.device) + dropout_mask = torch.bernoulli(torch.full((metaclip_global_text_features.shape[0], 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + metaclip_global_text_features = torch.where(dropout_mask, null_embed, metaclip_global_text_features) + # null_embed = torch.zeros_like(clip_f_c,device=latent.device) + # dropout_mask = torch.bernoulli(torch.full((clip_f_c.shape[0], 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + # clip_f_c = torch.where(dropout_mask, null_embed, clip_f_c) + # null_embed = torch.zeros_like(text_f_c,device=latent.device) + # dropout_mask = torch.bernoulli(torch.full((text_f_c.shape[0], 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + # text_f_c = torch.where(dropout_mask, null_embed, text_f_c) + + if cfg_scale != 1.0: + # empty_conditions = self.get_empty_conditions(latent.shape[0]) + # breakpoint() + bsz = latent.shape[0] + latent = torch.cat([latent,latent], dim=0) + if inpaint_masked_input is not None: + inpaint_masked_input = inpaint_masked_input.transpose(1,2) + empty_inpaint_masked_input = torch.zeros_like(inpaint_masked_input, device=latent.device) + # inpaint_masked_input = torch.cat([inpaint_masked_input,self.get_empty_audio_sequence(bsz)], dim=0) + inpaint_masked_input = torch.cat([inpaint_masked_input,empty_inpaint_masked_input], dim=0) + t = torch.cat([t, t], dim=0) + # empty_clip_f = torch.zeros_like(clip_f, device=latent.device) + # empty_sync_f = torch.zeros_like(sync_f, device=latent.device) + # empty_text_f = torch.zeros_like(text_f, device=latent.device) + + # clip_f = torch.cat([clip_f,empty_clip_f], dim=0) + # sync_f = torch.cat([sync_f,empty_sync_f], dim=0) + # text_f = torch.cat([text_f,empty_text_f], dim=0) + clip_f = torch.cat([clip_f,self.get_empty_clip_sequence(bsz)], dim=0) + # sync_f = torch.cat([sync_f,sync_f], dim=0) + sync_f = torch.cat([sync_f,self.get_empty_sync_sequence(bsz)], dim=0) + text_f = torch.cat([text_f,self.get_empty_string_sequence(bsz)], dim=0) + if t5_features is not None: + empty_t5_features = torch.zeros_like(t5_features, device=latent.device) + # t5_features = torch.cat([t5_features,empty_t5_features], dim=0) + t5_features = torch.cat([t5_features,self.get_empty_t5_sequence(bsz)], dim=0) + if metaclip_global_text_features is not None: + empty_metaclip_global_text_features = torch.zeros_like(metaclip_global_text_features, device=latent.device) + metaclip_global_text_features = torch.cat([metaclip_global_text_features,empty_metaclip_global_text_features], dim=0) + # metaclip_global_text_features = torch.cat([metaclip_global_text_features,metaclip_global_text_features], dim=0) + # clip_f_c = torch.cat([clip_f_c,empty_clip_f_c], dim=0) + # text_f_c = torch.cat([text_f_c,empty_text_f_c], dim=0) + + + conditions = self.preprocess_conditions(clip_f, sync_f, text_f, t5_features, metaclip_global_text_features) + flow = self.predict_flow(latent, t, conditions, inpaint_masked_input, cfg_scale,cfg_dropout_prob,scale_phi) + if cfg_scale != 1.0: + cond_output, uncond_output = torch.chunk(flow, 2, dim=0) + cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale + if scale_phi != 0.0: + cond_out_std = cond_output.std(dim=1, keepdim=True) + out_cfg_std = cfg_output.std(dim=1, keepdim=True) + flow = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output + else: + flow = cfg_output + flow = flow.permute(0, 2, 1) + return flow + + def get_empty_string_sequence(self, bs: int) -> torch.Tensor: + return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1) + + def get_empty_t5_sequence(self, bs: int) -> torch.Tensor: + return self.empty_t5_feat.unsqueeze(0).expand(bs, -1, -1) + + def get_empty_audio_sequence(self, bs: int) -> torch.Tensor: + return self.empty_audio_feat.unsqueeze(0).expand(bs, self._latent_seq_len, -1) + + def get_empty_clip_sequence(self, bs: int) -> torch.Tensor: + return self.empty_clip_feat.unsqueeze(0).expand(bs, self._clip_seq_len, -1) + + def get_empty_sync_sequence(self, bs: int) -> torch.Tensor: + return self.empty_sync_feat.unsqueeze(0).expand(bs, self._sync_seq_len, -1) + + def get_empty_conditions( + self, + bs: int, + *, + negative_text_features: Optional[torch.Tensor] = None) -> PreprocessedConditions: + if negative_text_features is not None: + empty_text = negative_text_features + else: + empty_text = self.get_empty_string_sequence(1) + + empty_clip = self.get_empty_clip_sequence(1) + empty_sync = self.get_empty_sync_sequence(1) + conditions = self.preprocess_conditions(empty_clip, empty_sync, empty_text) + conditions.clip_f = conditions.clip_f.expand(bs, -1, -1) + conditions.sync_f = conditions.sync_f.expand(bs, -1, -1) + conditions.clip_f_c = conditions.clip_f_c.expand(bs, -1) + if negative_text_features is None: + conditions.text_f = conditions.text_f.expand(bs, -1, -1) + conditions.text_f_c = conditions.text_f_c.expand(bs, -1) + + return conditions + + def load_weights(self, src_dict) -> None: + if 't_embed.freqs' in src_dict: + del src_dict['t_embed.freqs'] + if 'latent_rot' in src_dict: + del src_dict['latent_rot'] + if 'clip_rot' in src_dict: + del src_dict['clip_rot'] + + self.load_state_dict(src_dict, strict=True) + + @property + def device(self) -> torch.device: + return self.empty_clip_feat.device + + @property + def latent_seq_len(self) -> int: + return self._latent_seq_len + + @property + def clip_seq_len(self) -> int: + return self._clip_seq_len + + @property + def sync_seq_len(self) -> int: + return self._sync_seq_len + diff --git a/ThinkSound/models/mmmodules/__init__.py b/ThinkSound/models/mmmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ThinkSound/models/mmmodules/ext/__init__.py b/ThinkSound/models/mmmodules/ext/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/ThinkSound/models/mmmodules/ext/__init__.py @@ -0,0 +1 @@ + diff --git a/ThinkSound/models/mmmodules/ext/rotary_embeddings.py b/ThinkSound/models/mmmodules/ext/rotary_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea9d56278cb68b7577ed13148227c30ed98fd02 --- /dev/null +++ b/ThinkSound/models/mmmodules/ext/rotary_embeddings.py @@ -0,0 +1,35 @@ +from typing import Union + +import torch +from einops import rearrange +from torch import Tensor + +# Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py +# Ref: https://github.com/lucidrains/rotary-embedding-torch + + +def compute_rope_rotations(length: int, + dim: int, + theta: int, + *, + freq_scaling: float = 1.0, + device: Union[torch.device, str] = 'cpu') -> Tensor: + assert dim % 2 == 0 + + with torch.amp.autocast(device_type='cuda', enabled=False): + pos = torch.arange(length, dtype=torch.float32, device=device) + freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + freqs *= freq_scaling + + rot = torch.einsum('..., f -> ... f', pos, freqs) + rot = torch.stack([torch.cos(rot), -torch.sin(rot), torch.sin(rot), torch.cos(rot)], dim=-1) + rot = rearrange(rot, 'n d (i j) -> 1 n d i j', i=2, j=2) + return rot + + +def apply_rope(x: Tensor, rot: Tensor) -> tuple[Tensor, Tensor]: + with torch.amp.autocast(device_type='cuda', enabled=False): + _x = x.float() + _x = _x.view(*_x.shape[:-1], -1, 1, 2) + x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1] + return x_out.reshape(*x.shape).to(dtype=x.dtype) diff --git a/ThinkSound/models/mmmodules/ext/stft_converter.py b/ThinkSound/models/mmmodules/ext/stft_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..62922067ef3b1d3b8727ec39e7d664ccb304d9fe --- /dev/null +++ b/ThinkSound/models/mmmodules/ext/stft_converter.py @@ -0,0 +1,183 @@ +# Reference: # https://github.com/bytedance/Make-An-Audio-2 + +import torch +import torch.nn as nn +import torchaudio +from einops import rearrange +from librosa.filters import mel as librosa_mel_fn + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, norm_fn=torch.log10): + return norm_fn(torch.clamp(x, min=clip_val) * C) + + +def spectral_normalize_torch(magnitudes, norm_fn): + output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn) + return output + + +class STFTConverter(nn.Module): + + def __init__( + self, + *, + sampling_rate: float = 16_000, + n_fft: int = 1024, + num_mels: int = 128, + hop_size: int = 256, + win_size: int = 1024, + fmin: float = 0, + fmax: float = 8_000, + norm_fn=torch.log, + ): + super().__init__() + self.sampling_rate = sampling_rate + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.norm_fn = norm_fn + + mel = librosa_mel_fn(sr=self.sampling_rate, + n_fft=self.n_fft, + n_mels=self.num_mels, + fmin=self.fmin, + fmax=self.fmax) + mel_basis = torch.from_numpy(mel).float() + hann_window = torch.hann_window(self.win_size) + + self.register_buffer('mel_basis', mel_basis) + self.register_buffer('hann_window', hann_window) + + @property + def device(self): + return self.hann_window.device + + def forward(self, waveform: torch.Tensor) -> torch.Tensor: + # input: batch_size * length + bs = waveform.shape[0] + waveform = waveform.clamp(min=-1., max=1.) + + spec = torch.stft(waveform, + self.n_fft, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=True, + pad_mode='reflect', + normalized=False, + onesided=True, + return_complex=True) + + spec = torch.view_as_real(spec) + # print('After stft', spec.shape, spec.min(), spec.max(), spec.mean()) + + power = spec.pow(2).sum(-1) + angle = torch.atan2(spec[..., 1], spec[..., 0]) + + print('power', power.shape, power.min(), power.max(), power.mean()) + print('angle', angle.shape, angle.min(), angle.max(), angle.mean()) + + # print('mel', self.mel_basis.shape, self.mel_basis.min(), self.mel_basis.max(), + # self.mel_basis.mean()) + + # spec = rearrange(spec, 'b f t c -> (b c) f t') + + # spec = self.mel_transform(spec) + + # spec = torch.matmul(self.mel_basis, spec) + + # print('After mel', spec.shape, spec.min(), spec.max(), spec.mean()) + + # spec = spectral_normalize_torch(spec, self.norm_fn) + + # print('After norm', spec.shape, spec.min(), spec.max(), spec.mean()) + + # compute magnitude + # magnitude = torch.sqrt((spec**2).sum(-1)) + # normalize by magnitude + # scaled_magnitude = torch.log10(magnitude.clamp(min=1e-5)) * 10 + # spec = spec / magnitude.unsqueeze(-1) * scaled_magnitude.unsqueeze(-1) + + # power = torch.log10(power.clamp(min=1e-5)) * 10 + power = torch.log10(power.clamp(min=1e-5)) + + print('After scaling', power.shape, power.min(), power.max(), power.mean()) + + spec = torch.stack([power, angle], dim=-1) + + # spec = rearrange(spec, '(b c) f t -> b c f t', b=bs) + spec = rearrange(spec, 'b f t c -> b c f t', b=bs) + + # spec[:, :, 400:] = 0 + + return spec + + def invert(self, spec: torch.Tensor, length: int) -> torch.Tensor: + bs = spec.shape[0] + + # spec = rearrange(spec, 'b c f t -> (b c) f t') + # print(spec.shape, self.mel_basis.shape) + # spec = torch.linalg.lstsq(self.mel_basis.unsqueeze(0), spec).solution + # spec = torch.linalg.pinv(self.mel_basis.unsqueeze(0)) @ spec + + # spec = self.invmel_transform(spec) + + spec = rearrange(spec, 'b c f t -> b f t c', b=bs).contiguous() + + # spec[..., 0] = 10**(spec[..., 0] / 10) + + power = spec[..., 0] + power = 10**power + + # print('After unscaling', spec[..., 0].shape, spec[..., 0].min(), spec[..., 0].max(), + # spec[..., 0].mean()) + + unit_vector = torch.stack([ + torch.cos(spec[..., 1]), + torch.sin(spec[..., 1]), + ], dim=-1) + + spec = torch.sqrt(power) * unit_vector + + # spec = rearrange(spec, '(b c) f t -> b f t c', b=bs).contiguous() + spec = torch.view_as_complex(spec) + + waveform = torch.istft( + spec, + self.n_fft, + length=length, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=True, + normalized=False, + onesided=True, + return_complex=False, + ) + + return waveform + + +if __name__ == '__main__': + + converter = STFTConverter(sampling_rate=16000) + + signal = torchaudio.load('./output/ZZ6GRocWW38_000090.wav')[0] + # resample signal at 44100 Hz + # signal = torchaudio.transforms.Resample(16_000, 44_100)(signal) + + L = signal.shape[1] + print('Input signal', signal.shape) + spec = converter(signal) + + print('Final spec', spec.shape) + + signal_recon = converter.invert(spec, length=L) + print('Output signal', signal_recon.shape, signal_recon.min(), signal_recon.max(), + signal_recon.mean()) + + print('MSE', torch.nn.functional.mse_loss(signal, signal_recon)) + torchaudio.save('./output/ZZ6GRocWW38_000090_recon.wav', signal_recon, 16000) diff --git a/ThinkSound/models/mmmodules/ext/stft_converter_mel.py b/ThinkSound/models/mmmodules/ext/stft_converter_mel.py new file mode 100644 index 0000000000000000000000000000000000000000..f6b32d4cb9a23cd74f723e7d8307fd82fa1abba0 --- /dev/null +++ b/ThinkSound/models/mmmodules/ext/stft_converter_mel.py @@ -0,0 +1,234 @@ +# Reference: # https://github.com/bytedance/Make-An-Audio-2 + +import torch +import torch.nn as nn +import torchaudio +from einops import rearrange +from librosa.filters import mel as librosa_mel_fn + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, norm_fn=torch.log10): + return norm_fn(torch.clamp(x, min=clip_val) * C) + + +def spectral_normalize_torch(magnitudes, norm_fn): + output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn) + return output + + +class STFTConverter(nn.Module): + + def __init__( + self, + *, + sampling_rate: float = 16_000, + n_fft: int = 1024, + num_mels: int = 128, + hop_size: int = 256, + win_size: int = 1024, + fmin: float = 0, + fmax: float = 8_000, + norm_fn=torch.log, + ): + super().__init__() + self.sampling_rate = sampling_rate + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.norm_fn = norm_fn + + mel = librosa_mel_fn(sr=self.sampling_rate, + n_fft=self.n_fft, + n_mels=self.num_mels, + fmin=self.fmin, + fmax=self.fmax) + mel_basis = torch.from_numpy(mel).float() + hann_window = torch.hann_window(self.win_size) + + self.register_buffer('mel_basis', mel_basis) + self.register_buffer('hann_window', hann_window) + + @property + def device(self): + return self.hann_window.device + + def forward(self, waveform: torch.Tensor) -> torch.Tensor: + # input: batch_size * length + bs = waveform.shape[0] + waveform = waveform.clamp(min=-1., max=1.) + + spec = torch.stft(waveform, + self.n_fft, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=True, + pad_mode='reflect', + normalized=False, + onesided=True, + return_complex=True) + + spec = torch.view_as_real(spec) + # print('After stft', spec.shape, spec.min(), spec.max(), spec.mean()) + + power = (spec.pow(2).sum(-1))**(0.5) + angle = torch.atan2(spec[..., 1], spec[..., 0]) + + print('power 1', power.shape, power.min(), power.max(), power.mean()) + print('angle 1', angle.shape, angle.min(), angle.max(), angle.mean(), angle[:, :2, :2]) + + # print('mel', self.mel_basis.shape, self.mel_basis.min(), self.mel_basis.max(), + # self.mel_basis.mean()) + + # spec = self.mel_transform(spec) + + # power = torch.matmul(self.mel_basis, power) + + spec = rearrange(spec, 'b f t c -> (b c) f t') + spec = self.mel_basis.unsqueeze(0) @ spec + spec = rearrange(spec, '(b c) f t -> b f t c', b=bs) + + power = (spec.pow(2).sum(-1))**(0.5) + angle = torch.atan2(spec[..., 1], spec[..., 0]) + + print('power', power.shape, power.min(), power.max(), power.mean()) + print('angle', angle.shape, angle.min(), angle.max(), angle.mean(), angle[:, :2, :2]) + + # print('After mel', spec.shape, spec.min(), spec.max(), spec.mean()) + + # spec = spectral_normalize_torch(spec, self.norm_fn) + + # print('After norm', spec.shape, spec.min(), spec.max(), spec.mean()) + + # compute magnitude + # magnitude = torch.sqrt((spec**2).sum(-1)) + # normalize by magnitude + # scaled_magnitude = torch.log10(magnitude.clamp(min=1e-5)) * 10 + # spec = spec / magnitude.unsqueeze(-1) * scaled_magnitude.unsqueeze(-1) + + # power = torch.log10(power.clamp(min=1e-5)) * 10 + power = torch.log10(power.clamp(min=1e-8)) + + print('After scaling', power.shape, power.min(), power.max(), power.mean()) + + # spec = torch.stack([power, angle], dim=-1) + + # spec = rearrange(spec, '(b c) f t -> b c f t', b=bs) + # spec = rearrange(spec, 'b f t c -> b c f t', b=bs) + + # spec[:, :, 400:] = 0 + + return power, angle + # return spec[..., 0], spec[..., 1] + + def invert(self, spec: torch.Tensor, length: int) -> torch.Tensor: + + power, angle = spec + + bs = power.shape[0] + + # spec = rearrange(spec, 'b c f t -> (b c) f t') + # print(spec.shape, self.mel_basis.shape) + # spec = torch.linalg.lstsq(self.mel_basis.unsqueeze(0), spec).solution + # spec = torch.linalg.pinv(self.mel_basis.unsqueeze(0)) @ spec + + # spec = self.invmel_transform(spec) + + # spec = rearrange(spec, 'b c f t -> b f t c', b=bs).contiguous() + + # spec[..., 0] = 10**(spec[..., 0] / 10) + + # power = spec[..., 0] + power = 10**power + + # print('After unscaling', spec[..., 0].shape, spec[..., 0].min(), spec[..., 0].max(), + # spec[..., 0].mean()) + + unit_vector = torch.stack([ + torch.cos(angle), + torch.sin(angle), + ], dim=-1) + + spec = power.unsqueeze(-1) * unit_vector + + # power = torch.linalg.lstsq(self.mel_basis.unsqueeze(0), power).solution + spec = rearrange(spec, 'b f t c -> (b c) f t') + spec = torch.linalg.pinv(self.mel_basis.unsqueeze(0)) @ spec + # spec = torch.linalg.lstsq(self.mel_basis.unsqueeze(0), spec).solution + spec = rearrange(spec, '(b c) f t -> b f t c', b=bs).contiguous() + + power = (spec.pow(2).sum(-1))**(0.5) + angle = torch.atan2(spec[..., 1], spec[..., 0]) + + print('power 2', power.shape, power.min(), power.max(), power.mean()) + print('angle 2', angle.shape, angle.min(), angle.max(), angle.mean(), angle[:, :2, :2]) + + # spec = rearrange(spec, '(b c) f t -> b f t c', b=bs).contiguous() + spec = torch.view_as_complex(spec) + + waveform = torch.istft( + spec, + self.n_fft, + length=length, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=True, + normalized=False, + onesided=True, + return_complex=False, + ) + + return waveform + + +if __name__ == '__main__': + + converter = STFTConverter(sampling_rate=16000) + + signal = torchaudio.load('./output/ZZ6GRocWW38_000090.wav')[0] + # resample signal at 44100 Hz + # signal = torchaudio.transforms.Resample(16_000, 44_100)(signal) + + L = signal.shape[1] + print('Input signal', signal.shape) + spec = converter(signal) + + power, angle = spec + + # print(power.shape, angle.shape) + # print(power, power.min(), power.max(), power.mean()) + # power = power.clamp(-1, 1) + # angle = angle.clamp(-1, 1) + + import matplotlib.pyplot as plt + + # Visualize power + plt.figure() + plt.imshow(power[0].detach().numpy(), aspect='auto', origin='lower') + plt.colorbar() + plt.title('Power') + plt.xlabel('Time') + plt.ylabel('Frequency') + plt.savefig('./output/power.png') + + # Visualize angle + plt.figure() + plt.imshow(angle[0].detach().numpy(), aspect='auto', origin='lower') + plt.colorbar() + plt.title('Angle') + plt.xlabel('Time') + plt.ylabel('Frequency') + plt.savefig('./output/angle.png') + + # print('Final spec', spec.shape) + + signal_recon = converter.invert(spec, length=L) + print('Output signal', signal_recon.shape, signal_recon.min(), signal_recon.max(), + signal_recon.mean()) + + print('MSE', torch.nn.functional.mse_loss(signal, signal_recon)) + torchaudio.save('./output/ZZ6GRocWW38_000090_recon.wav', signal_recon, 16000) diff --git a/ThinkSound/models/mmmodules/model/__init__.py b/ThinkSound/models/mmmodules/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ThinkSound/models/mmmodules/model/embeddings.py b/ThinkSound/models/mmmodules/model/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..297feb4d2c79d306771f5436dbd4ada1a976b3bc --- /dev/null +++ b/ThinkSound/models/mmmodules/model/embeddings.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn + +# https://github.com/facebookresearch/DiT + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, dim, frequency_embedding_size, max_period): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, dim), + nn.SiLU(), + nn.Linear(dim, dim), + ) + self.dim = dim + self.max_period = max_period + assert dim % 2 == 0, 'dim must be even.' + + with torch.autocast('cuda', enabled=False): + self.freqs = nn.Buffer( + 1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) / + frequency_embedding_size)), + persistent=False) + freq_scale = 10000 / max_period + self.freqs = freq_scale * self.freqs + + def timestep_embedding(self, t): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + + args = t[:, None].float() * self.freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t).to(t.dtype) + t_emb = self.mlp(t_freq) + return t_emb diff --git a/ThinkSound/models/mmmodules/model/flow_matching.py b/ThinkSound/models/mmmodules/model/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..e7c65dece6dec746db999092606f4384d084d119 --- /dev/null +++ b/ThinkSound/models/mmmodules/model/flow_matching.py @@ -0,0 +1,71 @@ +import logging +from typing import Callable, Optional + +import torch +from torchdiffeq import odeint + +log = logging.getLogger() + + +# Partially from https://github.com/gle-bellier/flow-matching +class FlowMatching: + + def __init__(self, min_sigma: float = 0.0, inference_mode='euler', num_steps: int = 25): + # inference_mode: 'euler' or 'adaptive' + # num_steps: number of steps in the euler inference mode + super().__init__() + self.min_sigma = min_sigma + self.inference_mode = inference_mode + self.num_steps = num_steps + + # self.fm = ExactOptimalTransportConditionalFlowMatcher(sigma=min_sigma) + + assert self.inference_mode in ['euler', 'adaptive'] + if self.inference_mode == 'adaptive' and num_steps > 0: + log.info('The number of steps is ignored in adaptive inference mode ') + + def get_conditional_flow(self, x0: torch.Tensor, x1: torch.Tensor, + t: torch.Tensor) -> torch.Tensor: + # which is psi_t(x), eq 22 in flow matching for generative models + t = t[:, None, None].expand_as(x0) + return (1 - (1 - self.min_sigma) * t) * x0 + t * x1 + + def loss(self, predicted_v: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: + # return the mean error without reducing the batch dimension + reduce_dim = list(range(1, len(predicted_v.shape))) + target_v = x1 - (1 - self.min_sigma) * x0 + return (predicted_v - target_v).pow(2).mean(dim=reduce_dim) + + def get_x0_xt_c( + self, + x1: torch.Tensor, + t: torch.Tensor, + Cs: list[torch.Tensor], + generator: Optional[torch.Generator] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x0 = torch.empty_like(x1).normal_(generator=generator) + + xt = self.get_conditional_flow(x0, x1, t) + return x0, x1, xt, Cs + + def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor: + return self.run_t0_to_t1(fn, x1, 1, 0) + + def to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor: + return self.run_t0_to_t1(fn, x0, 0, 1) + + def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor: + # fn: a function that takes (t, x) and returns the direction x0->x1 + + if self.inference_mode == 'adaptive': + return odeint(fn, x0, torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype)) + elif self.inference_mode == 'euler': + x = x0 + steps = torch.linspace(t0, t1 - self.min_sigma, self.num_steps + 1) + for ti, t in enumerate(steps[:-1]): + flow = fn(t, x) + next_t = steps[ti + 1] + dt = next_t - t + x = x + dt * flow + + return x diff --git a/ThinkSound/models/mmmodules/model/low_level.py b/ThinkSound/models/mmmodules/model/low_level.py new file mode 100644 index 0000000000000000000000000000000000000000..c8326a8bec99f1be08b92e76fda4b59e777b39d2 --- /dev/null +++ b/ThinkSound/models/mmmodules/model/low_level.py @@ -0,0 +1,95 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class ChannelLastConv1d(nn.Conv1d): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 1) + x = super().forward(x) + x = x.permute(0, 2, 1) + return x + + +# https://github.com/Stability-AI/sd3-ref +class MLP(nn.Module): + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class ConvMLP(nn.Module): + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + kernel_size: int = 3, + padding: int = 1, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ChannelLastConv1d(dim, + hidden_dim, + bias=False, + kernel_size=kernel_size, + padding=padding) + self.w2 = ChannelLastConv1d(hidden_dim, + dim, + bias=False, + kernel_size=kernel_size, + padding=padding) + self.w3 = ChannelLastConv1d(dim, + hidden_dim, + bias=False, + kernel_size=kernel_size, + padding=padding) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) diff --git a/ThinkSound/models/mmmodules/model/networks.py b/ThinkSound/models/mmmodules/model/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..8272a896b358f5db681d1462c4189d671b916d76 --- /dev/null +++ b/ThinkSound/models/mmmodules/model/networks.py @@ -0,0 +1,470 @@ +import logging +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmaudio.ext.rotary_embeddings import compute_rope_rotations +from mmaudio.model.embeddings import TimestepEmbedder +from mmaudio.model.low_level import MLP, ChannelLastConv1d, ConvMLP +from mmaudio.model.transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock) + +log = logging.getLogger() + + +@dataclass +class PreprocessedConditions: + clip_f: torch.Tensor + sync_f: torch.Tensor + text_f: torch.Tensor + clip_f_c: torch.Tensor + text_f_c: torch.Tensor + + +# Partially from https://github.com/facebookresearch/DiT +class MMAudio(nn.Module): + + def __init__(self, + *, + latent_dim: int, + clip_dim: int, + sync_dim: int, + text_dim: int, + hidden_dim: int, + depth: int, + fused_depth: int, + num_heads: int, + mlp_ratio: float = 4.0, + latent_seq_len: int, + clip_seq_len: int, + sync_seq_len: int, + text_seq_len: int = 77, + latent_mean: Optional[torch.Tensor] = None, + latent_std: Optional[torch.Tensor] = None, + empty_string_feat: Optional[torch.Tensor] = None, + v2: bool = False) -> None: + super().__init__() + + self.v2 = v2 + self.latent_dim = latent_dim + self._latent_seq_len = latent_seq_len + self._clip_seq_len = clip_seq_len + self._sync_seq_len = sync_seq_len + self._text_seq_len = text_seq_len + self.hidden_dim = hidden_dim + self.num_heads = num_heads + + if v2: + self.audio_input_proj = nn.Sequential( + ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3), + ) + + self.clip_input_proj = nn.Sequential( + nn.Linear(clip_dim, hidden_dim), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.sync_input_proj = nn.Sequential( + ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.text_input_proj = nn.Sequential( + nn.Linear(text_dim, hidden_dim), + nn.SiLU(), + MLP(hidden_dim, hidden_dim * 4), + ) + else: + self.audio_input_proj = nn.Sequential( + ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3), + nn.SELU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3), + ) + + self.clip_input_proj = nn.Sequential( + nn.Linear(clip_dim, hidden_dim), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.sync_input_proj = nn.Sequential( + ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3), + nn.SELU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.text_input_proj = nn.Sequential( + nn.Linear(text_dim, hidden_dim), + MLP(hidden_dim, hidden_dim * 4), + ) + + self.clip_cond_proj = nn.Linear(hidden_dim, hidden_dim) + self.text_cond_proj = nn.Linear(hidden_dim, hidden_dim) + self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4) + # each synchformer output segment has 8 feature frames + self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, sync_dim))) + + self.final_layer = FinalBlock(hidden_dim, latent_dim) + + if v2: + self.t_embed = TimestepEmbedder(hidden_dim, + frequency_embedding_size=hidden_dim, + max_period=1) + else: + self.t_embed = TimestepEmbedder(hidden_dim, + frequency_embedding_size=256, + max_period=10000) + self.joint_blocks = nn.ModuleList([ + JointBlock(hidden_dim, + num_heads, + mlp_ratio=mlp_ratio, + pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth) + ]) + + self.fused_blocks = nn.ModuleList([ + MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1) + for i in range(fused_depth) + ]) + + if latent_mean is None: + # these values are not meant to be used + # if you don't provide mean/std here, we should load them later from a checkpoint + assert latent_std is None + latent_mean = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan')) + latent_std = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan')) + else: + assert latent_std is not None + assert latent_mean.numel() == latent_dim, f'{latent_mean.numel()=} != {latent_dim=}' + if empty_string_feat is None: + empty_string_feat = torch.zeros((text_seq_len, text_dim)) + self.latent_mean = nn.Parameter(latent_mean.view(1, 1, -1), requires_grad=False) + self.latent_std = nn.Parameter(latent_std.view(1, 1, -1), requires_grad=False) + + self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False) + self.empty_clip_feat = nn.Parameter(torch.zeros(1, clip_dim), requires_grad=True) + self.empty_sync_feat = nn.Parameter(torch.zeros(1, sync_dim), requires_grad=True) + + self.initialize_weights() + self.initialize_rotations() + + def initialize_rotations(self): + base_freq = 1.0 + latent_rot = compute_rope_rotations(self._latent_seq_len, + self.hidden_dim // self.num_heads, + 10000, + freq_scaling=base_freq, + device=self.device) + clip_rot = compute_rope_rotations(self._clip_seq_len, + self.hidden_dim // self.num_heads, + 10000, + freq_scaling=base_freq * self._latent_seq_len / + self._clip_seq_len, + device=self.device) + + self.latent_rot = nn.Buffer(latent_rot, persistent=False) + self.clip_rot = nn.Buffer(clip_rot, persistent=False) + + def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None: + self._latent_seq_len = latent_seq_len + self._clip_seq_len = clip_seq_len + self._sync_seq_len = sync_seq_len + self.initialize_rotations() + + def initialize_weights(self): + + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.joint_blocks: + nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.clip_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.clip_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0) + for block in self.fused_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.conv.weight, 0) + nn.init.constant_(self.final_layer.conv.bias, 0) + + # empty string feat shall be initialized by a CLIP encoder + nn.init.constant_(self.sync_pos_emb, 0) + nn.init.constant_(self.empty_clip_feat, 0) + nn.init.constant_(self.empty_sync_feat, 0) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + # return (x - self.latent_mean) / self.latent_std + return x.sub_(self.latent_mean).div_(self.latent_std) + + def unnormalize(self, x: torch.Tensor) -> torch.Tensor: + # return x * self.latent_std + self.latent_mean + return x.mul_(self.latent_std).add_(self.latent_mean) + + def preprocess_conditions(self, clip_f: torch.Tensor, sync_f: torch.Tensor, + text_f: torch.Tensor) -> PreprocessedConditions: + """ + cache computations that do not depend on the latent/time step + i.e., the features are reused over steps during inference + """ + assert clip_f.shape[1] == self._clip_seq_len, f'{clip_f.shape=} {self._clip_seq_len=}' + assert sync_f.shape[1] == self._sync_seq_len, f'{sync_f.shape=} {self._sync_seq_len=}' + assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}' + + bs = clip_f.shape[0] + + # B * num_segments (24) * 8 * 768 + num_sync_segments = self._sync_seq_len // 8 + sync_f = sync_f.view(bs, num_sync_segments, 8, -1) + self.sync_pos_emb + sync_f = sync_f.flatten(1, 2) # (B, VN, D) + + # extend vf to match x + clip_f = self.clip_input_proj(clip_f) # (B, VN, D) + sync_f = self.sync_input_proj(sync_f) # (B, VN, D) + text_f = self.text_input_proj(text_f) # (B, VN, D) + + # upsample the sync features to match the audio + sync_f = sync_f.transpose(1, 2) # (B, D, VN) + sync_f = F.interpolate(sync_f, size=self._latent_seq_len, mode='nearest-exact') + sync_f = sync_f.transpose(1, 2) # (B, N, D) + + # get conditional features from the clip side + clip_f_c = self.clip_cond_proj(clip_f.mean(dim=1)) # (B, D) + text_f_c = self.text_cond_proj(text_f.mean(dim=1)) # (B, D) + + return PreprocessedConditions(clip_f=clip_f, + sync_f=sync_f, + text_f=text_f, + clip_f_c=clip_f_c, + text_f_c=text_f_c) + + def predict_flow(self, latent: torch.Tensor, t: torch.Tensor, + conditions: PreprocessedConditions) -> torch.Tensor: + """ + for non-cacheable computations + """ + assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}' + + clip_f = conditions.clip_f + sync_f = conditions.sync_f + text_f = conditions.text_f + clip_f_c = conditions.clip_f_c + text_f_c = conditions.text_f_c + + latent = self.audio_input_proj(latent) # (B, N, D) + global_c = self.global_cond_mlp(clip_f_c + text_f_c) # (B, D) + + global_c = self.t_embed(t).unsqueeze(1) + global_c.unsqueeze(1) # (B, D) + extended_c = global_c + sync_f + + for block in self.joint_blocks: + latent, clip_f, text_f = block(latent, clip_f, text_f, global_c, extended_c, + self.latent_rot, self.clip_rot) # (B, N, D) + + for block in self.fused_blocks: + latent = block(latent, extended_c, self.latent_rot) + + # should be extended_c; this is a minor implementation error #55 + flow = self.final_layer(latent, extended_c) # (B, N, out_dim), remove t + return flow + + def forward(self, latent: torch.Tensor, clip_f: torch.Tensor, sync_f: torch.Tensor, + text_f: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + latent: (B, N, C) + vf: (B, T, C_V) + t: (B,) + """ + conditions = self.preprocess_conditions(clip_f, sync_f, text_f) + flow = self.predict_flow(latent, t, conditions) + return flow + + def get_empty_string_sequence(self, bs: int) -> torch.Tensor: + return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1) + + def get_empty_clip_sequence(self, bs: int) -> torch.Tensor: + return self.empty_clip_feat.unsqueeze(0).expand(bs, self._clip_seq_len, -1) + + def get_empty_sync_sequence(self, bs: int) -> torch.Tensor: + return self.empty_sync_feat.unsqueeze(0).expand(bs, self._sync_seq_len, -1) + + def get_empty_conditions( + self, + bs: int, + *, + negative_text_features: Optional[torch.Tensor] = None) -> PreprocessedConditions: + if negative_text_features is not None: + empty_text = negative_text_features + else: + empty_text = self.get_empty_string_sequence(1) + + empty_clip = self.get_empty_clip_sequence(1) + empty_sync = self.get_empty_sync_sequence(1) + conditions = self.preprocess_conditions(empty_clip, empty_sync, empty_text) + conditions.clip_f = conditions.clip_f.expand(bs, -1, -1) + conditions.sync_f = conditions.sync_f.expand(bs, -1, -1) + conditions.clip_f_c = conditions.clip_f_c.expand(bs, -1) + if negative_text_features is None: + conditions.text_f = conditions.text_f.expand(bs, -1, -1) + conditions.text_f_c = conditions.text_f_c.expand(bs, -1) + + return conditions + + def ode_wrapper(self, t: torch.Tensor, latent: torch.Tensor, conditions: PreprocessedConditions, + empty_conditions: PreprocessedConditions, cfg_strength: float) -> torch.Tensor: + t = t * torch.ones(len(latent), device=latent.device, dtype=latent.dtype) + + if cfg_strength < 1.0: + return self.predict_flow(latent, t, conditions) + else: + return (cfg_strength * self.predict_flow(latent, t, conditions) + + (1 - cfg_strength) * self.predict_flow(latent, t, empty_conditions)) + + def load_weights(self, src_dict) -> None: + if 't_embed.freqs' in src_dict: + del src_dict['t_embed.freqs'] + if 'latent_rot' in src_dict: + del src_dict['latent_rot'] + if 'clip_rot' in src_dict: + del src_dict['clip_rot'] + + self.load_state_dict(src_dict, strict=True) + + @property + def device(self) -> torch.device: + return self.latent_mean.device + + @property + def latent_seq_len(self) -> int: + return self._latent_seq_len + + @property + def clip_seq_len(self) -> int: + return self._clip_seq_len + + @property + def sync_seq_len(self) -> int: + return self._sync_seq_len + + +def small_16k(**kwargs) -> MMAudio: + num_heads = 7 + return MMAudio(latent_dim=20, + clip_dim=1024, + sync_dim=768, + text_dim=1024, + hidden_dim=64 * num_heads, + depth=12, + fused_depth=8, + num_heads=num_heads, + latent_seq_len=250, + clip_seq_len=64, + sync_seq_len=192, + **kwargs) + + +def small_44k(**kwargs) -> MMAudio: + num_heads = 7 + return MMAudio(latent_dim=40, + clip_dim=1024, + sync_dim=768, + text_dim=1024, + hidden_dim=64 * num_heads, + depth=12, + fused_depth=8, + num_heads=num_heads, + latent_seq_len=345, + clip_seq_len=64, + sync_seq_len=192, + **kwargs) + + +def medium_44k(**kwargs) -> MMAudio: + num_heads = 14 + return MMAudio(latent_dim=40, + clip_dim=1024, + sync_dim=768, + text_dim=1024, + hidden_dim=64 * num_heads, + depth=12, + fused_depth=8, + num_heads=num_heads, + latent_seq_len=345, + clip_seq_len=64, + sync_seq_len=192, + **kwargs) + + +def large_44k(**kwargs) -> MMAudio: + num_heads = 14 + return MMAudio(latent_dim=40, + clip_dim=1024, + sync_dim=768, + text_dim=1024, + hidden_dim=64 * num_heads, + depth=21, + fused_depth=14, + num_heads=num_heads, + latent_seq_len=345, + clip_seq_len=64, + sync_seq_len=192, + **kwargs) + + +def large_44k_v2(**kwargs) -> MMAudio: + num_heads = 14 + return MMAudio(latent_dim=40, + clip_dim=1024, + sync_dim=768, + text_dim=1024, + hidden_dim=64 * num_heads, + depth=21, + fused_depth=14, + num_heads=num_heads, + latent_seq_len=345, + clip_seq_len=64, + sync_seq_len=192, + v2=True, + **kwargs) + + +def get_my_mmaudio(name: str, **kwargs) -> MMAudio: + if name == 'small_16k': + return small_16k(**kwargs) + if name == 'small_44k': + return small_44k(**kwargs) + if name == 'medium_44k': + return medium_44k(**kwargs) + if name == 'large_44k': + return large_44k(**kwargs) + if name == 'large_44k_v2': + return large_44k_v2(**kwargs) + + raise ValueError(f'Unknown model name: {name}') + + +if __name__ == '__main__': + network = get_my_mmaudio('small_16k') + + # print the number of parameters in terms of millions + num_params = sum(p.numel() for p in network.parameters()) / 1e6 + print(f'Number of parameters: {num_params:.2f}M') diff --git a/ThinkSound/models/mmmodules/model/sequence_config.py b/ThinkSound/models/mmmodules/model/sequence_config.py new file mode 100644 index 0000000000000000000000000000000000000000..14269014dc401b4751d172466813a935fddda6c1 --- /dev/null +++ b/ThinkSound/models/mmmodules/model/sequence_config.py @@ -0,0 +1,58 @@ +import dataclasses +import math + + +@dataclasses.dataclass +class SequenceConfig: + # general + duration: float + + # audio + sampling_rate: int + spectrogram_frame_rate: int + latent_downsample_rate: int = 2 + + # visual + clip_frame_rate: int = 8 + sync_frame_rate: int = 25 + sync_num_frames_per_segment: int = 16 + sync_step_size: int = 8 + sync_downsample_rate: int = 2 + + @property + def num_audio_frames(self) -> int: + # we need an integer number of latents + return self.latent_seq_len * self.spectrogram_frame_rate * self.latent_downsample_rate + + @property + def latent_seq_len(self) -> int: + return int( + math.ceil(self.duration * self.sampling_rate / self.spectrogram_frame_rate / + self.latent_downsample_rate)) + + @property + def clip_seq_len(self) -> int: + return int(self.duration * self.clip_frame_rate) + + @property + def sync_seq_len(self) -> int: + num_frames = self.duration * self.sync_frame_rate + num_segments = (num_frames - self.sync_num_frames_per_segment) // self.sync_step_size + 1 + return int(num_segments * self.sync_num_frames_per_segment / self.sync_downsample_rate) + + +CONFIG_16K = SequenceConfig(duration=8.0, sampling_rate=16000, spectrogram_frame_rate=256) +CONFIG_44K = SequenceConfig(duration=8.0, sampling_rate=44100, spectrogram_frame_rate=512) + +if __name__ == '__main__': + assert CONFIG_16K.latent_seq_len == 250 + assert CONFIG_16K.clip_seq_len == 64 + assert CONFIG_16K.sync_seq_len == 192 + assert CONFIG_16K.num_audio_frames == 128000 + + assert CONFIG_44K.latent_seq_len == 345 + assert CONFIG_44K.clip_seq_len == 64 + assert CONFIG_44K.sync_seq_len == 192 + assert CONFIG_44K.num_audio_frames == 353280 + + print('Passed') diff --git a/ThinkSound/models/mmmodules/model/transformer_layers.py b/ThinkSound/models/mmmodules/model/transformer_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..bcfcaa57baed828b22071584c8d8fb1a29d0b4dd --- /dev/null +++ b/ThinkSound/models/mmmodules/model/transformer_layers.py @@ -0,0 +1,271 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from einops.layers.torch import Rearrange + +from mmaudio.ext.rotary_embeddings import apply_rope +from mmaudio.model.low_level import MLP, ChannelLastConv1d, ConvMLP +try: + from flash_attn import flash_attn_func, flash_attn_kvpacked_func + print('flash_attn installed, using Flash Attention') +except ImportError as e: + print(e) + print('flash_attn not installed, disabling Flash Attention') + flash_attn_kvpacked_func = None + flash_attn_func = None + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return x * (1 + scale) + shift + + +def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + # training will crash without these contiguous calls and the CUDNN limitation + # I believe this is related to https://github.com/pytorch/pytorch/issues/133974 + # unresolved at the time of writing + fa_dtype_in = q.dtype + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + # out = F.scaled_dot_product_attention(q, k, v) + # out = rearrange(out, 'b h n d -> b n (h d)').contiguous() + # return out + q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.bfloat16), (q, k, v)) + # print(f"q dtype: {q.dtype}") + # print(f"k dtype: {k.dtype}") + # print(f"v dtype: {v.dtype}") + # breakpoint() + out = flash_attn_func(q, k, v) + out = rearrange(out.to(fa_dtype_in), 'b n h d -> b n (h d)') + # out = rearrange(out.to(fa_dtype_in), 'b h n d -> b n (h d)').contiguous() + return out + + +class SelfAttention(nn.Module): + + def __init__(self, dim: int, nheads: int): + super().__init__() + self.dim = dim + self.nheads = nheads + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.q_norm = nn.RMSNorm(dim // nheads) + self.k_norm = nn.RMSNorm(dim // nheads) + + self.split_into_heads = Rearrange('b n (h d j) -> b h n d j', + h=nheads, + d=dim // nheads, + j=3) + + def pre_attention( + self, x: torch.Tensor, + rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # x: batch_size * n_tokens * n_channels + qkv = self.qkv(x) + q, k, v = self.split_into_heads(qkv).chunk(3, dim=-1) + q = q.squeeze(-1) + k = k.squeeze(-1) + v = v.squeeze(-1) + q = self.q_norm(q) + k = self.k_norm(k) + + if rot is not None: + q = apply_rope(q, rot) + k = apply_rope(k, rot) + + return q, k, v + + def forward( + self, + x: torch.Tensor, # batch_size * n_tokens * n_channels + ) -> torch.Tensor: + q, v, k = self.pre_attention(x) + out = attention(q, k, v) + return out + +class CrossAttention(nn.Module): + + def __init__(self, dim: int, nheads: int): + super().__init__() + self.dim = dim + self.nheads = nheads + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim, dim * 2, bias=False) + self.q_norm = nn.RMSNorm(dim // nheads) + self.k_norm = nn.RMSNorm(dim // nheads) + + self.split_q_into_heads = Rearrange('b n (h d) -> b h n d', + h=nheads, + d=dim // nheads) + self.split_kv_into_heads = Rearrange('b n (h d j) -> b h n d j', + h=nheads, + d=dim // nheads, + j=2) + + def pre_attention( + self, x: torch.Tensor, + context: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # x: batch_size * n_tokens * n_channels + q = self.to_q(x) + kv = self.to_kv(context) + q = self.split_q_into_heads(q) + k, v = self.split_kv_into_heads(kv).chunk(2, dim=-1) + q = q.squeeze(-1) + k = k.squeeze(-1) + v = v.squeeze(-1) + q = self.q_norm(q) + k = self.k_norm(k) + + + return q, k, v + + def forward( + self, + x: torch.Tensor, context=None + ) -> torch.Tensor: + q, v, k = self.pre_attention(x, context=context) + out = attention(q, k, v) + return out + + +class MMDitSingleBlock(nn.Module): + + def __init__(self, + dim: int, + nhead: int, + mlp_ratio: float = 4.0, + pre_only: bool = False, + kernel_size: int = 7, + padding: int = 3, + cross_attend: bool = False): + super().__init__() + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False) + self.attn = SelfAttention(dim, nhead) + if cross_attend: + self.cross_attn = CrossAttention(dim, nhead) + self.pre_only = pre_only + if pre_only: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True)) + else: + if kernel_size == 1: + self.linear1 = nn.Linear(dim, dim) + else: + self.linear1 = ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding) + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False) + + if kernel_size == 1: + self.ffn = MLP(dim, int(dim * mlp_ratio)) + else: + self.ffn = ConvMLP(dim, + int(dim * mlp_ratio), + kernel_size=kernel_size, + padding=padding) + + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True)) + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]): + # x: BS * N * D + # cond: BS * D + modulation = self.adaLN_modulation(c) + if self.pre_only: + (shift_msa, scale_msa) = modulation.chunk(2, dim=-1) + gate_msa = shift_mlp = scale_mlp = gate_mlp = None + else: + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, + gate_mlp) = modulation.chunk(6, dim=-1) + + x = modulate(self.norm1(x), shift_msa, scale_msa) + q, k, v = self.attn.pre_attention(x, rot) + return (q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp) + + def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor, c: tuple[torch.Tensor], context=None): + if self.pre_only: + return x + + (gate_msa, shift_mlp, scale_mlp, gate_mlp) = c + x = x + self.linear1(attn_out) * gate_msa + + if context is not None: + x = x + self.cross_attn(x, context=context) + + r = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = x + self.ffn(r) * gate_mlp + + return x + + def forward(self, x: torch.Tensor, cond: torch.Tensor, + rot: Optional[torch.Tensor], context: torch.Tensor = None) -> torch.Tensor: + # x: BS * N * D + # cond: BS * D + x_qkv, x_conditions = self.pre_attention(x, cond, rot) + attn_out = attention(*x_qkv) + x = self.post_attention(x, attn_out, x_conditions, context = context) + + return x + + +class JointBlock(nn.Module): + + def __init__(self, dim: int, nhead: int, mlp_ratio: float = 4.0, pre_only: bool = False): + super().__init__() + self.pre_only = pre_only + self.latent_block = MMDitSingleBlock(dim, + nhead, + mlp_ratio, + pre_only=False, + kernel_size=3, + padding=1) + self.clip_block = MMDitSingleBlock(dim, + nhead, + mlp_ratio, + pre_only=pre_only, + kernel_size=3, + padding=1) + self.text_block = MMDitSingleBlock(dim, nhead, mlp_ratio, pre_only=pre_only, kernel_size=1) + + def forward(self, latent: torch.Tensor, clip_f: torch.Tensor, text_f: torch.Tensor, + global_c: torch.Tensor, extended_c: torch.Tensor, latent_rot: torch.Tensor, + clip_rot: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # latent: BS * N1 * D + # clip_f: BS * N2 * D + # c: BS * (1/N) * D + x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot) + c_qkv, c_mod = self.clip_block.pre_attention(clip_f, global_c, clip_rot) + t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None) + + latent_len = latent.shape[1] + clip_len = clip_f.shape[1] + text_len = text_f.shape[1] + + joint_qkv = [torch.cat([x_qkv[i], c_qkv[i], t_qkv[i]], dim=2) for i in range(3)] + + attn_out = attention(*joint_qkv) + x_attn_out = attn_out[:, :latent_len] + c_attn_out = attn_out[:, latent_len:latent_len + clip_len] + t_attn_out = attn_out[:, latent_len + clip_len:] + + latent = self.latent_block.post_attention(latent, x_attn_out, x_mod) + if not self.pre_only: + clip_f = self.clip_block.post_attention(clip_f, c_attn_out, c_mod) + text_f = self.text_block.post_attention(text_f, t_attn_out, t_mod) + + return latent, clip_f, text_f + + +class FinalBlock(nn.Module): + + def __init__(self, dim, out_dim): + super().__init__() + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True)) + self.norm = nn.LayerNorm(dim, elementwise_affine=False) + self.conv = ChannelLastConv1d(dim, out_dim, kernel_size=7, padding=3) + + def forward(self, latent, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + latent = modulate(self.norm(latent), shift, scale) + latent = self.conv(latent) + return latent diff --git a/ThinkSound/models/mmmodules/runner.py b/ThinkSound/models/mmmodules/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..755ee76bea7de3f31a14a5512710c39743dc9239 --- /dev/null +++ b/ThinkSound/models/mmmodules/runner.py @@ -0,0 +1,609 @@ +""" +trainer.py - wrapper and utility functions for network training +Compute loss, back-prop, update parameters, logging, etc. +""" +import os +from pathlib import Path +from typing import Optional, Union + +import torch +import torch.distributed +import torch.optim as optim +from av_bench.evaluate import evaluate +from av_bench.extract import extract +from nitrous_ema import PostHocEMA +from omegaconf import DictConfig +from torch.nn.parallel import DistributedDataParallel as DDP + +from mmaudio.model.flow_matching import FlowMatching +from mmaudio.model.networks import get_my_mmaudio +from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K +from mmaudio.model.utils.features_utils import FeaturesUtils +from mmaudio.model.utils.parameter_groups import get_parameter_groups +from mmaudio.model.utils.sample_utils import log_normal_sample +from mmaudio.utils.dist_utils import (info_if_rank_zero, local_rank, string_if_rank_zero) +from mmaudio.utils.log_integrator import Integrator +from mmaudio.utils.logger import TensorboardLogger +from mmaudio.utils.time_estimator import PartialTimeEstimator, TimeEstimator +from mmaudio.utils.video_joiner import VideoJoiner + + +class Runner: + + def __init__(self, + cfg: DictConfig, + log: TensorboardLogger, + run_path: Union[str, Path], + for_training: bool = True, + latent_mean: Optional[torch.Tensor] = None, + latent_std: Optional[torch.Tensor] = None): + self.exp_id = cfg.exp_id + self.use_amp = cfg.amp + self.enable_grad_scaler = cfg.enable_grad_scaler + self.for_training = for_training + self.cfg = cfg + + if cfg.model.endswith('16k'): + self.seq_cfg = CONFIG_16K + mode = '16k' + elif cfg.model.endswith('44k'): + self.seq_cfg = CONFIG_44K + mode = '44k' + else: + raise ValueError(f'Unknown model: {cfg.model}') + + self.sample_rate = self.seq_cfg.sampling_rate + self.duration_sec = self.seq_cfg.duration + + # setting up the model + empty_string_feat = torch.load('./ext_weights/empty_string.pth', weights_only=True)[0] + self.network = DDP(get_my_mmaudio(cfg.model, + latent_mean=latent_mean, + latent_std=latent_std, + empty_string_feat=empty_string_feat).cuda(), + device_ids=[local_rank], + broadcast_buffers=False) + if cfg.compile: + # NOTE: though train_fn and val_fn are very similar + # (early on they are implemented as a single function) + # keeping them separate and compiling them separately are CRUCIAL for high performance + self.train_fn = torch.compile(self.train_fn) + self.val_fn = torch.compile(self.val_fn) + + self.fm = FlowMatching(cfg.sampling.min_sigma, + inference_mode=cfg.sampling.method, + num_steps=cfg.sampling.num_steps) + + # ema profile + if for_training and cfg.ema.enable and local_rank == 0: + self.ema = PostHocEMA(self.network.module, + sigma_rels=cfg.ema.sigma_rels, + update_every=cfg.ema.update_every, + checkpoint_every_num_steps=cfg.ema.checkpoint_every, + checkpoint_folder=cfg.ema.checkpoint_folder, + step_size_correction=True).cuda() + self.ema_start = cfg.ema.start + else: + self.ema = None + + self.rng = torch.Generator(device='cuda') + self.rng.manual_seed(cfg['seed'] + local_rank) + + # setting up feature extractors and VAEs + if mode == '16k': + self.features = FeaturesUtils( + tod_vae_ckpt=cfg['vae_16k_ckpt'], + bigvgan_vocoder_ckpt=cfg['bigvgan_vocoder_ckpt'], + synchformer_ckpt=cfg['synchformer_ckpt'], + enable_conditions=True, + mode=mode, + need_vae_encoder=False, + ) + elif mode == '44k': + self.features = FeaturesUtils( + tod_vae_ckpt=cfg['vae_44k_ckpt'], + synchformer_ckpt=cfg['synchformer_ckpt'], + enable_conditions=True, + mode=mode, + need_vae_encoder=False, + ) + self.features = self.features.cuda().eval() + + if cfg.compile: + self.features.compile() + + # hyperparameters + self.log_normal_sampling_mean = cfg.sampling.mean + self.log_normal_sampling_scale = cfg.sampling.scale + self.null_condition_probability = cfg.null_condition_probability + self.cfg_strength = cfg.cfg_strength + + # setting up logging + self.log = log + self.run_path = Path(run_path) + vgg_cfg = cfg.data.VGGSound + if for_training: + self.val_video_joiner = VideoJoiner(vgg_cfg.root, self.run_path / 'val-sampled-videos', + self.sample_rate, self.duration_sec) + else: + self.test_video_joiner = VideoJoiner(vgg_cfg.root, + self.run_path / 'test-sampled-videos', + self.sample_rate, self.duration_sec) + string_if_rank_zero(self.log, 'model_size', + f'{sum([param.nelement() for param in self.network.parameters()])}') + string_if_rank_zero( + self.log, 'number_of_parameters_that_require_gradient: ', + str( + sum([ + param.nelement() + for param in filter(lambda p: p.requires_grad, self.network.parameters()) + ]))) + info_if_rank_zero(self.log, 'torch version: ' + torch.__version__) + self.train_integrator = Integrator(self.log, distributed=True) + self.val_integrator = Integrator(self.log, distributed=True) + + # setting up optimizer and loss + if for_training: + self.enter_train() + parameter_groups = get_parameter_groups(self.network, cfg, print_log=(local_rank == 0)) + self.optimizer = optim.AdamW(parameter_groups, + lr=cfg['learning_rate'], + weight_decay=cfg['weight_decay'], + betas=[0.9, 0.95], + eps=1e-6 if self.use_amp else 1e-8, + fused=True) + if self.enable_grad_scaler: + self.scaler = torch.amp.GradScaler(init_scale=2048) + self.clip_grad_norm = cfg['clip_grad_norm'] + + # linearly warmup learning rate + linear_warmup_steps = cfg['linear_warmup_steps'] + + def warmup(currrent_step: int): + return (currrent_step + 1) / (linear_warmup_steps + 1) + + warmup_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup) + + # setting up learning rate scheduler + if cfg['lr_schedule'] == 'constant': + next_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda _: 1) + elif cfg['lr_schedule'] == 'poly': + total_num_iter = cfg['iterations'] + next_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, + lr_lambda=lambda x: + (1 - (x / total_num_iter))**0.9) + elif cfg['lr_schedule'] == 'step': + next_scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, + cfg['lr_schedule_steps'], + cfg['lr_schedule_gamma']) + else: + raise NotImplementedError + + self.scheduler = optim.lr_scheduler.SequentialLR(self.optimizer, + [warmup_scheduler, next_scheduler], + [linear_warmup_steps]) + + # Logging info + self.log_text_interval = cfg['log_text_interval'] + self.log_extra_interval = cfg['log_extra_interval'] + self.save_weights_interval = cfg['save_weights_interval'] + self.save_checkpoint_interval = cfg['save_checkpoint_interval'] + self.save_copy_iterations = cfg['save_copy_iterations'] + self.num_iterations = cfg['num_iterations'] + if cfg['debug']: + self.log_text_interval = self.log_extra_interval = 1 + + # update() is called when we log metrics, within the logger + self.log.batch_timer = TimeEstimator(self.num_iterations, self.log_text_interval) + # update() is called every iteration, in this script + self.log.data_timer = PartialTimeEstimator(self.num_iterations, 1, ema_alpha=0.9) + else: + self.enter_val() + + def train_fn( + self, + clip_f: torch.Tensor, + sync_f: torch.Tensor, + text_f: torch.Tensor, + a_mean: torch.Tensor, + a_std: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # sample + a_randn = torch.empty_like(a_mean).normal_(generator=self.rng) + x1 = a_mean + a_std * a_randn + bs = x1.shape[0] # batch_size * seq_len * num_channels + + # normalize the latents + x1 = self.network.module.normalize(x1) + + t = log_normal_sample(x1, + generator=self.rng, + m=self.log_normal_sampling_mean, + s=self.log_normal_sampling_scale) + x0, x1, xt, (clip_f, sync_f, text_f) = self.fm.get_x0_xt_c(x1, + t, + Cs=[clip_f, sync_f, text_f], + generator=self.rng) + + # classifier-free training + samples = torch.rand(bs, device=x1.device, generator=self.rng) + null_video = (samples < self.null_condition_probability) + clip_f[null_video] = self.network.module.empty_clip_feat + sync_f[null_video] = self.network.module.empty_sync_feat + + samples = torch.rand(bs, device=x1.device, generator=self.rng) + null_text = (samples < self.null_condition_probability) + text_f[null_text] = self.network.module.empty_string_feat + + pred_v = self.network(xt, clip_f, sync_f, text_f, t) + loss = self.fm.loss(pred_v, x0, x1) + mean_loss = loss.mean() + return x1, loss, mean_loss, t + + def val_fn( + self, + clip_f: torch.Tensor, + sync_f: torch.Tensor, + text_f: torch.Tensor, + x1: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + bs = x1.shape[0] # batch_size * seq_len * num_channels + # normalize the latents + x1 = self.network.module.normalize(x1) + t = log_normal_sample(x1, + generator=self.rng, + m=self.log_normal_sampling_mean, + s=self.log_normal_sampling_scale) + x0, x1, xt, (clip_f, sync_f, text_f) = self.fm.get_x0_xt_c(x1, + t, + Cs=[clip_f, sync_f, text_f], + generator=self.rng) + + # classifier-free training + samples = torch.rand(bs, device=x1.device, generator=self.rng) + # null mask is for when a video is provided but we decided to ignore it + null_video = (samples < self.null_condition_probability) + # complete mask is for when a video is not provided or we decided to ignore it + clip_f[null_video] = self.network.module.empty_clip_feat + sync_f[null_video] = self.network.module.empty_sync_feat + + samples = torch.rand(bs, device=x1.device, generator=self.rng) + null_text = (samples < self.null_condition_probability) + text_f[null_text] = self.network.module.empty_string_feat + + pred_v = self.network(xt, clip_f, sync_f, text_f, t) + + loss = self.fm.loss(pred_v, x0, x1) + mean_loss = loss.mean() + return loss, mean_loss, t + + def train_pass(self, data, it: int = 0): + + if not self.for_training: + raise ValueError('train_pass() should not be called when not training.') + + self.enter_train() + with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16): + clip_f = data['clip_features'].cuda(non_blocking=True) + sync_f = data['sync_features'].cuda(non_blocking=True) + text_f = data['text_features'].cuda(non_blocking=True) + video_exist = data['video_exist'].cuda(non_blocking=True) + text_exist = data['text_exist'].cuda(non_blocking=True) + a_mean = data['a_mean'].cuda(non_blocking=True) + a_std = data['a_std'].cuda(non_blocking=True) + + # these masks are for non-existent data; masking for CFG training is in train_fn + clip_f[~video_exist] = self.network.module.empty_clip_feat + sync_f[~video_exist] = self.network.module.empty_sync_feat + text_f[~text_exist] = self.network.module.empty_string_feat + + self.log.data_timer.end() + if it % self.log_extra_interval == 0: + unmasked_clip_f = clip_f.clone() + unmasked_sync_f = sync_f.clone() + unmasked_text_f = text_f.clone() + x1, loss, mean_loss, t = self.train_fn(clip_f, sync_f, text_f, a_mean, a_std) + + self.train_integrator.add_dict({'loss': mean_loss}) + + if it % self.log_text_interval == 0 and it != 0: + self.train_integrator.add_scalar('lr', self.scheduler.get_last_lr()[0]) + self.train_integrator.add_binned_tensor('binned_loss', loss, t) + self.train_integrator.finalize('train', it) + self.train_integrator.reset_except_hooks() + + # Backward pass + self.optimizer.zero_grad(set_to_none=True) + if self.enable_grad_scaler: + self.scaler.scale(mean_loss).backward() + self.scaler.unscale_(self.optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(), + self.clip_grad_norm) + self.scaler.step(self.optimizer) + self.scaler.update() + else: + mean_loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(), + self.clip_grad_norm) + self.optimizer.step() + + if self.ema is not None and it >= self.ema_start: + self.ema.update() + self.scheduler.step() + self.integrator.add_scalar('grad_norm', grad_norm) + + self.enter_val() + with torch.amp.autocast('cuda', enabled=self.use_amp, + dtype=torch.bfloat16), torch.inference_mode(): + try: + if it % self.log_extra_interval == 0: + # save GT audio + # unnormalize the latents + x1 = self.network.module.unnormalize(x1[0:1]) + mel = self.features.decode(x1) + audio = self.features.vocode(mel).cpu()[0] # 1 * num_samples + self.log.log_spectrogram('train', f'spec-gt-r{local_rank}', mel.cpu()[0], it) + self.log.log_audio('train', + f'audio-gt-r{local_rank}', + audio, + it, + sample_rate=self.sample_rate) + + # save audio from sampling + x0 = torch.empty_like(x1[0:1]).normal_(generator=self.rng) + clip_f = unmasked_clip_f[0:1] + sync_f = unmasked_sync_f[0:1] + text_f = unmasked_text_f[0:1] + conditions = self.network.module.preprocess_conditions(clip_f, sync_f, text_f) + empty_conditions = self.network.module.get_empty_conditions(x0.shape[0]) + cfg_ode_wrapper = lambda t, x: self.network.module.ode_wrapper( + t, x, conditions, empty_conditions, self.cfg_strength) + x1_hat = self.fm.to_data(cfg_ode_wrapper, x0) + x1_hat = self.network.module.unnormalize(x1_hat) + mel = self.features.decode(x1_hat) + audio = self.features.vocode(mel).cpu()[0] + self.log.log_spectrogram('train', f'spec-r{local_rank}', mel.cpu()[0], it) + self.log.log_audio('train', + f'audio-r{local_rank}', + audio, + it, + sample_rate=self.sample_rate) + except Exception as e: + self.log.warning(f'Error in extra logging: {e}') + if self.cfg.debug: + raise + + # Save network weights and checkpoint if needed + save_copy = it in self.save_copy_iterations + + if (it % self.save_weights_interval == 0 and it != 0) or save_copy: + self.save_weights(it) + + if it % self.save_checkpoint_interval == 0 and it != 0: + self.save_checkpoint(it, save_copy=save_copy) + + self.log.data_timer.start() + + @torch.inference_mode() + def validation_pass(self, data, it: int = 0): + self.enter_val() + with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16): + clip_f = data['clip_features'].cuda(non_blocking=True) + sync_f = data['sync_features'].cuda(non_blocking=True) + text_f = data['text_features'].cuda(non_blocking=True) + video_exist = data['video_exist'].cuda(non_blocking=True) + text_exist = data['text_exist'].cuda(non_blocking=True) + a_mean = data['a_mean'].cuda(non_blocking=True) + a_std = data['a_std'].cuda(non_blocking=True) + + clip_f[~video_exist] = self.network.module.empty_clip_feat + sync_f[~video_exist] = self.network.module.empty_sync_feat + text_f[~text_exist] = self.network.module.empty_string_feat + a_randn = torch.empty_like(a_mean).normal_(generator=self.rng) + x1 = a_mean + a_std * a_randn + + self.log.data_timer.end() + loss, mean_loss, t = self.val_fn(clip_f.clone(), sync_f.clone(), text_f.clone(), x1) + + self.val_integrator.add_binned_tensor('binned_loss', loss, t) + self.val_integrator.add_dict({'loss': mean_loss}) + + self.log.data_timer.start() + + @torch.inference_mode() + def inference_pass(self, + data, + it: int, + data_cfg: DictConfig, + *, + save_eval: bool = True) -> Path: + self.enter_val() + with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16): + clip_f = data['clip_features'].cuda(non_blocking=True) + sync_f = data['sync_features'].cuda(non_blocking=True) + text_f = data['text_features'].cuda(non_blocking=True) + video_exist = data['video_exist'].cuda(non_blocking=True) + text_exist = data['text_exist'].cuda(non_blocking=True) + a_mean = data['a_mean'].cuda(non_blocking=True) # for the shape only + + clip_f[~video_exist] = self.network.module.empty_clip_feat + sync_f[~video_exist] = self.network.module.empty_sync_feat + text_f[~text_exist] = self.network.module.empty_string_feat + + # sample + x0 = torch.empty_like(a_mean).normal_(generator=self.rng) + conditions = self.network.module.preprocess_conditions(clip_f, sync_f, text_f) + empty_conditions = self.network.module.get_empty_conditions(x0.shape[0]) + cfg_ode_wrapper = lambda t, x: self.network.module.ode_wrapper( + t, x, conditions, empty_conditions, self.cfg_strength) + x1_hat = self.fm.to_data(cfg_ode_wrapper, x0) + x1_hat = self.network.module.unnormalize(x1_hat) + mel = self.features.decode(x1_hat) + audio = self.features.vocode(mel).cpu() + for i in range(audio.shape[0]): + video_id = data['id'][i] + if (not self.for_training) and i == 0: + # save very few videos + self.test_video_joiner.join(video_id, f'{video_id}', audio[i].transpose(0, 1)) + + if data_cfg.output_subdir is not None: + # validation + if save_eval: + iter_naming = f'{it:09d}' + else: + iter_naming = 'val-cache' + audio_dir = self.log.log_audio(iter_naming, + f'{video_id}', + audio[i], + it=None, + sample_rate=self.sample_rate, + subdir=Path(data_cfg.output_subdir)) + if save_eval and i == 0: + self.val_video_joiner.join(video_id, f'{iter_naming}-{video_id}', + audio[i].transpose(0, 1)) + else: + # full test set, usually + audio_dir = self.log.log_audio(f'{data_cfg.tag}-sampled', + f'{video_id}', + audio[i], + it=None, + sample_rate=self.sample_rate) + + return Path(audio_dir) + + @torch.inference_mode() + def eval(self, audio_dir: Path, it: int, data_cfg: DictConfig) -> dict[str, float]: + with torch.amp.autocast('cuda', enabled=False): + if local_rank == 0: + extract(audio_path=audio_dir, + output_path=audio_dir / 'cache', + device='cuda', + batch_size=32, + audio_length=8) + output_metrics = evaluate(gt_audio_cache=Path(data_cfg.gt_cache), + pred_audio_cache=audio_dir / 'cache') + for k, v in output_metrics.items(): + # pad k to 10 characters + # pad v to 10 decimal places + self.log.log_scalar(f'{data_cfg.tag}/{k}', v, it) + self.log.info(f'{data_cfg.tag}/{k:<10}: {v:.10f}') + else: + output_metrics = None + + return output_metrics + + def save_weights(self, it, save_copy=False): + if local_rank != 0: + return + + os.makedirs(self.run_path, exist_ok=True) + if save_copy: + model_path = self.run_path / f'{self.exp_id}_{it}.pth' + torch.save(self.network.module.state_dict(), model_path) + self.log.info(f'Network weights saved to {model_path}.') + + # if last exists, move it to a shadow copy + model_path = self.run_path / f'{self.exp_id}_last.pth' + if model_path.exists(): + shadow_path = model_path.with_name(model_path.name.replace('last', 'shadow')) + model_path.replace(shadow_path) + self.log.info(f'Network weights shadowed to {shadow_path}.') + + torch.save(self.network.module.state_dict(), model_path) + self.log.info(f'Network weights saved to {model_path}.') + + def save_checkpoint(self, it, save_copy=False): + if local_rank != 0: + return + + checkpoint = { + 'it': it, + 'weights': self.network.module.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'scheduler': self.scheduler.state_dict(), + 'ema': self.ema.state_dict() if self.ema is not None else None, + } + + os.makedirs(self.run_path, exist_ok=True) + if save_copy: + model_path = self.run_path / f'{self.exp_id}_ckpt_{it}.pth' + torch.save(checkpoint, model_path) + self.log.info(f'Checkpoint saved to {model_path}.') + + # if ckpt_last exists, move it to a shadow copy + model_path = self.run_path / f'{self.exp_id}_ckpt_last.pth' + if model_path.exists(): + shadow_path = model_path.with_name(model_path.name.replace('last', 'shadow')) + model_path.replace(shadow_path) # moves the file + self.log.info(f'Checkpoint shadowed to {shadow_path}.') + + torch.save(checkpoint, model_path) + self.log.info(f'Checkpoint saved to {model_path}.') + + def get_latest_checkpoint_path(self): + ckpt_path = self.run_path / f'{self.exp_id}_ckpt_last.pth' + if not ckpt_path.exists(): + info_if_rank_zero(self.log, f'No checkpoint found at {ckpt_path}.') + return None + return ckpt_path + + def get_latest_weight_path(self): + weight_path = self.run_path / f'{self.exp_id}_last.pth' + if not weight_path.exists(): + self.log.info(f'No weight found at {weight_path}.') + return None + return weight_path + + def get_final_ema_weight_path(self): + weight_path = self.run_path / f'{self.exp_id}_ema_final.pth' + if not weight_path.exists(): + self.log.info(f'No weight found at {weight_path}.') + return None + return weight_path + + def load_checkpoint(self, path): + # This method loads everything and should be used to resume training + map_location = 'cuda:%d' % local_rank + checkpoint = torch.load(path, map_location={'cuda:0': map_location}, weights_only=True) + + it = checkpoint['it'] + weights = checkpoint['weights'] + optimizer = checkpoint['optimizer'] + scheduler = checkpoint['scheduler'] + if self.ema is not None: + self.ema.load_state_dict(checkpoint['ema']) + self.log.info(f'EMA states loaded from step {self.ema.step}') + + map_location = 'cuda:%d' % local_rank + self.network.module.load_state_dict(weights) + self.optimizer.load_state_dict(optimizer) + self.scheduler.load_state_dict(scheduler) + + self.log.info(f'Global iteration {it} loaded.') + self.log.info('Network weights, optimizer states, and scheduler states loaded.') + + return it + + def load_weights_in_memory(self, src_dict): + self.network.module.load_weights(src_dict) + self.log.info('Network weights loaded from memory.') + + def load_weights(self, path): + # This method loads only the network weight and should be used to load a pretrained model + map_location = 'cuda:%d' % local_rank + src_dict = torch.load(path, map_location={'cuda:0': map_location}, weights_only=True) + + self.log.info(f'Importing network weights from {path}...') + self.load_weights_in_memory(src_dict) + + def weights(self): + return self.network.module.state_dict() + + def enter_train(self): + self.integrator = self.train_integrator + self.network.train() + return self + + def enter_val(self): + self.network.eval() + return self diff --git a/ThinkSound/models/mmmodules/sample.py b/ThinkSound/models/mmmodules/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..72b83389d7dbb55bed02991f51731b0d1e346a2b --- /dev/null +++ b/ThinkSound/models/mmmodules/sample.py @@ -0,0 +1,90 @@ +import json +import logging +import os +import random + +import numpy as np +import torch +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, open_dict +from tqdm import tqdm + +from mmaudio.data.data_setup import setup_test_datasets +from mmaudio.runner import Runner +from mmaudio.utils.dist_utils import info_if_rank_zero +from mmaudio.utils.logger import TensorboardLogger + +local_rank = int(os.environ['LOCAL_RANK']) +world_size = int(os.environ['WORLD_SIZE']) + + +def sample(cfg: DictConfig): + # initial setup + num_gpus = world_size + run_dir = HydraConfig.get().run.dir + + # wrap python logger with a tensorboard logger + log = TensorboardLogger(cfg.exp_id, + run_dir, + logging.getLogger(), + is_rank0=(local_rank == 0), + enable_email=cfg.enable_email and not cfg.debug) + + info_if_rank_zero(log, f'All configuration: {cfg}') + info_if_rank_zero(log, f'Number of GPUs detected: {num_gpus}') + + # cuda setup + torch.cuda.set_device(local_rank) + torch.backends.cudnn.benchmark = cfg.cudnn_benchmark + + # number of dataloader workers + info_if_rank_zero(log, f'Number of dataloader workers (per GPU): {cfg.num_workers}') + + # Set seeds to ensure the same initialization + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + random.seed(cfg.seed) + + # setting up configurations + info_if_rank_zero(log, f'Configuration: {cfg}') + info_if_rank_zero(log, f'Batch size (per GPU): {cfg.batch_size}') + + # construct the trainer + runner = Runner(cfg, log=log, run_path=run_dir, for_training=False).enter_val() + + # load the last weights if needed + if cfg['weights'] is not None: + info_if_rank_zero(log, f'Loading weights from the disk: {cfg["weights"]}') + runner.load_weights(cfg['weights']) + cfg['weights'] = None + else: + weights = runner.get_final_ema_weight_path() + if weights is not None: + info_if_rank_zero(log, f'Automatically finding weight: {weights}') + runner.load_weights(weights) + + # setup datasets + dataset, sampler, loader = setup_test_datasets(cfg) + data_cfg = cfg.data.ExtractedVGG_test + with open_dict(data_cfg): + if cfg.output_name is not None: + # append to the tag + data_cfg.tag = f'{data_cfg.tag}-{cfg.output_name}' + + # loop + audio_path = None + for curr_iter, data in enumerate(tqdm(loader)): + new_audio_path = runner.inference_pass(data, curr_iter, data_cfg) + if audio_path is None: + audio_path = new_audio_path + else: + assert audio_path == new_audio_path, 'Different audio path detected' + + info_if_rank_zero(log, f'Inference completed. Audio path: {audio_path}') + output_metrics = runner.eval(audio_path, curr_iter, data_cfg) + + if local_rank == 0: + # write the output metrics to run_dir + output_metrics_path = os.path.join(run_dir, f'{data_cfg.tag}-output_metrics.json') + with open(output_metrics_path, 'w') as f: + json.dump(output_metrics, f, indent=4) diff --git a/ThinkSound/models/pqmf.py b/ThinkSound/models/pqmf.py new file mode 100644 index 0000000000000000000000000000000000000000..007fdb51ec797554c1cdd4d9363894d743d970bf --- /dev/null +++ b/ThinkSound/models/pqmf.py @@ -0,0 +1,393 @@ +import math +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from scipy.optimize import fmin +from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord + +class PQMF(nn.Module): + """ + Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction. + Uses polyphase representation which is computationally more efficient for real-time. + + Parameters: + - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB. + - num_bands (int): Number of desired frequency bands. It must be a power of 2. + """ + + def __init__(self, attenuation, num_bands): + super(PQMF, self).__init__() + + # Ensure num_bands is a power of 2 + is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands))) + assert is_power_of_2, "'num_bands' must be a power of 2." + + # Create the prototype filter + prototype_filter = design_prototype_filter(attenuation, num_bands) + filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands) + padded_filter_bank = pad_to_nearest_power_of_two(filter_bank) + + # Register filters and settings + self.register_buffer("filter_bank", padded_filter_bank) + self.register_buffer("prototype", prototype_filter) + self.num_bands = num_bands + + def forward(self, signal): + """Decompose the signal into multiple frequency bands.""" + # If signal is not a pytorch tensor of Batch x Channels x Length, convert it + signal = prepare_signal_dimensions(signal) + # The signal length must be a multiple of num_bands. Pad it with zeros. + signal = pad_signal(signal, self.num_bands) + # run it + signal = polyphase_analysis(signal, self.filter_bank) + return apply_alias_cancellation(signal) + + def inverse(self, bands): + """Reconstruct the original signal from the frequency bands.""" + bands = apply_alias_cancellation(bands) + return polyphase_synthesis(bands, self.filter_bank) + + +def prepare_signal_dimensions(signal): + """ + Rearrange signal into Batch x Channels x Length. + + Parameters + ---------- + signal : torch.Tensor or numpy.ndarray + The input signal. + + Returns + ------- + torch.Tensor + Preprocessed signal tensor. + """ + # Convert numpy to torch tensor + if isinstance(signal, np.ndarray): + signal = torch.from_numpy(signal) + + # Ensure tensor + if not isinstance(signal, torch.Tensor): + raise ValueError("Input should be either a numpy array or a PyTorch tensor.") + + # Modify dimension of signal to Batch x Channels x Length + if signal.dim() == 1: + # This is just a mono signal. Unsqueeze to 1 x 1 x Length + signal = signal.unsqueeze(0).unsqueeze(0) + elif signal.dim() == 2: + # This is a multi-channel signal (e.g. stereo) + # Rearrange so that larger dimension (Length) is last + if signal.shape[0] > signal.shape[1]: + signal = signal.T + # Unsqueeze to 1 x Channels x Length + signal = signal.unsqueeze(0) + return signal + +def pad_signal(signal, num_bands): + """ + Pads the signal to make its length divisible by the given number of bands. + + Parameters + ---------- + signal : torch.Tensor + The input signal tensor, where the last dimension represents the signal length. + + num_bands : int + The number of bands by which the signal length should be divisible. + + Returns + ------- + torch.Tensor + The padded signal tensor. If the original signal length was already divisible + by num_bands, returns the original signal unchanged. + """ + remainder = signal.shape[-1] % num_bands + if remainder > 0: + padding_size = num_bands - remainder + signal = nn.functional.pad(signal, (0, padding_size)) + return signal + +def generate_modulated_filter_bank(prototype_filter, num_bands): + """ + Generate a QMF bank of cosine modulated filters based on a given prototype filter. + + Parameters + ---------- + prototype_filter : torch.Tensor + The prototype filter used as the basis for modulation. + num_bands : int + The number of desired subbands or filters. + + Returns + ------- + torch.Tensor + A bank of cosine modulated filters. + """ + + # Initialize indices for modulation. + subband_indices = torch.arange(num_bands).reshape(-1, 1) + + # Calculate the length of the prototype filter. + filter_length = prototype_filter.shape[-1] + + # Generate symmetric time indices centered around zero. + time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1) + + # Calculate phase offsets to ensure orthogonality between subbands. + phase_offsets = (-1)**subband_indices * np.pi / 4 + + # Compute the cosine modulation function. + modulation = torch.cos( + (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets + ) + + # Apply modulation to the prototype filter. + modulated_filters = 2 * prototype_filter * modulation + + return modulated_filters + + +def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None): + """ + Design a lowpass filter using the Kaiser window. + + Parameters + ---------- + angular_cutoff : float + The angular frequency cutoff of the filter. + attenuation : float + The desired stopband attenuation in decibels (dB). + filter_length : int, optional + Desired length of the filter. If not provided, it's computed based on the given specs. + + Returns + ------- + ndarray + The designed lowpass filter coefficients. + """ + + estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi) + + # Ensure the estimated length is odd. + estimated_length = 2 * (estimated_length // 2) + 1 + + if filter_length is None: + filter_length = estimated_length + + return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi) + + +def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length): + """ + Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427 + + Parameters + ---------- + angular_cutoff : float + Angular frequency cutoff of the filter. + attenuation : float + Desired stopband attenuation in dB. + num_bands : int + Number of bands for the multiband filter system. + filter_length : int, optional + Desired length of the filter. + + Returns + ------- + float + The computed objective (loss) value for the given filter specs. + """ + + filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length) + convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full") + + return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:])) + + +def design_prototype_filter(attenuation, num_bands, filter_length=None): + """ + Design the optimal prototype filter for a multiband system given the desired specs. + + Parameters + ---------- + attenuation : float + The desired stopband attenuation in dB. + num_bands : int + Number of bands for the multiband filter system. + filter_length : int, optional + Desired length of the filter. If not provided, it's computed based on the given specs. + + Returns + ------- + ndarray + The optimal prototype filter coefficients. + """ + + optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length), + 1 / num_bands, disp=0)[0] + + prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length) + return torch.tensor(prototype_filter, dtype=torch.float32) + +def pad_to_nearest_power_of_two(x): + """ + Pads the input tensor 'x' on both sides such that its last dimension + becomes the nearest larger power of two. + + Parameters: + ----------- + x : torch.Tensor + The input tensor to be padded. + + Returns: + -------- + torch.Tensor + The padded tensor. + """ + current_length = x.shape[-1] + target_length = 2**math.ceil(math.log2(current_length)) + + total_padding = target_length - current_length + left_padding = total_padding // 2 + right_padding = total_padding - left_padding + + return nn.functional.pad(x, (left_padding, right_padding)) + +def apply_alias_cancellation(x): + """ + Applies alias cancellation by inverting the sign of every + second element of every second row, starting from the second + row's first element in a tensor. + + This operation helps ensure that the aliasing introduced in + each band during the decomposition will be counteracted during + the reconstruction. + + Parameters: + ----------- + x : torch.Tensor + The input tensor. + + Returns: + -------- + torch.Tensor + Tensor with specific elements' sign inverted for alias cancellation. + """ + + # Create a mask of the same shape as 'x', initialized with all ones + mask = torch.ones_like(x) + + # Update specific elements in the mask to -1 to perform inversion + mask[..., 1::2, ::2] = -1 + + # Apply the mask to the input tensor 'x' + return x * mask + +def ensure_odd_length(tensor): + """ + Pads the last dimension of a tensor to ensure its size is odd. + + Parameters: + ----------- + tensor : torch.Tensor + Input tensor whose last dimension might need padding. + + Returns: + -------- + torch.Tensor + The original tensor if its last dimension was already odd, + or the padded tensor with an odd-sized last dimension. + """ + + last_dim_size = tensor.shape[-1] + + if last_dim_size % 2 == 0: + tensor = nn.functional.pad(tensor, (0, 1)) + + return tensor + +def polyphase_analysis(signal, filter_bank): + """ + Applies the polyphase method to efficiently analyze the signal using a filter bank. + + Parameters: + ----------- + signal : torch.Tensor + Input signal tensor with shape (Batch x Channels x Length). + + filter_bank : torch.Tensor + Filter bank tensor with shape (Bands x Length). + + Returns: + -------- + torch.Tensor + Signal split into sub-bands. (Batch x Channels x Bands x Length) + """ + + num_bands = filter_bank.shape[0] + num_channels = signal.shape[1] + + # Rearrange signal for polyphase processing. + # Also combine Batch x Channel into one dimension for now. + #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands) + signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands) + + # Rearrange the filter bank for matching signal shape + filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands) + + # Apply convolution with appropriate padding to maintain spatial dimensions + padding = filter_bank.shape[-1] // 2 + filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding) + + # Truncate the last dimension post-convolution to adjust the output shape + filtered_signal = filtered_signal[..., :-1] + # Rearrange the first dimension back into Batch x Channels + filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels) + + return filtered_signal + +def polyphase_synthesis(signal, filter_bank): + """ + Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal. + + Parameters + ---------- + signal : torch.Tensor + Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length). + + filter_bank : torch.Tensor + Analysis filter bank (shape: Bands x Length). + + should_rearrange : bool, optional + Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True. + + Returns + ------- + torch.Tensor + Reconstructed signal (shape: Batch x Channels X Length) + """ + + num_bands = filter_bank.shape[0] + num_channels = signal.shape[1] + + # Rearrange the filter bank + filter_bank = filter_bank.flip(-1) + filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands) + + # Combine Batch x Channels into one dimension for now. + signal = rearrange(signal, "b c n t -> (b c) n t") + + # Apply convolution with appropriate padding + padding_amount = filter_bank.shape[-1] // 2 + 1 + reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount)) + + # Scale the result + reconstructed_signal = reconstructed_signal[..., :-1] * num_bands + + # Reorganize the output and truncate + reconstructed_signal = reconstructed_signal.flip(1) + reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands) + reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:] + + return reconstructed_signal \ No newline at end of file diff --git a/ThinkSound/models/pretrained.py b/ThinkSound/models/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..e83af343587da91af92218f309c969c5a975b5ed --- /dev/null +++ b/ThinkSound/models/pretrained.py @@ -0,0 +1,25 @@ +import json + +from .factory import create_model_from_config +from .utils import load_ckpt_state_dict + +from huggingface_hub import hf_hub_download + +def get_pretrained_model(name: str): + + model_config_path = hf_hub_download(name, filename="model_config.json", repo_type='model') + + with open(model_config_path) as f: + model_config = json.load(f) + + model = create_model_from_config(model_config) + + # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file + try: + model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model') + except Exception as e: + model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model') + + model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) + + return model, model_config \ No newline at end of file diff --git a/ThinkSound/models/pretransforms.py b/ThinkSound/models/pretransforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c9942db59908ce8135f44e45090e928f6c60393a --- /dev/null +++ b/ThinkSound/models/pretransforms.py @@ -0,0 +1,258 @@ +import torch +from einops import rearrange +from torch import nn + +class Pretransform(nn.Module): + def __init__(self, enable_grad, io_channels, is_discrete): + super().__init__() + + self.is_discrete = is_discrete + self.io_channels = io_channels + self.encoded_channels = None + self.downsampling_ratio = None + + self.enable_grad = enable_grad + + def encode(self, x): + raise NotImplementedError + + def decode(self, z): + raise NotImplementedError + + def tokenize(self, x): + raise NotImplementedError + + def decode_tokens(self, tokens): + raise NotImplementedError + +class AutoencoderPretransform(Pretransform): + def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): + super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) + self.model = model + self.model.requires_grad_(False).eval() + self.scale=scale + self.downsampling_ratio = model.downsampling_ratio + self.io_channels = model.io_channels + self.sample_rate = model.sample_rate + + self.model_half = model_half + self.iterate_batch = iterate_batch + + self.encoded_channels = model.latent_dim + + self.chunked = chunked + self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None + self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None + + if self.model_half: + self.model.half() + + def encode(self, x, **kwargs): + + if self.model_half: + x = x.half() + self.model.to(torch.float16) + + encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + encoded = encoded.float() + + return encoded / self.scale + + def decode(self, z, **kwargs): + z = z * self.scale + + if self.model_half: + z = z.half() + self.model.to(torch.float16) + + decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + decoded = decoded.float() + + return decoded + + def tokenize(self, x, **kwargs): + assert self.model.is_discrete, "Cannot tokenize with a continuous model" + + _, info = self.model.encode(x, return_info = True, **kwargs) + + return info[self.model.bottleneck.tokens_id] + + def decode_tokens(self, tokens, **kwargs): + assert self.model.is_discrete, "Cannot decode tokens with a continuous model" + + return self.model.decode_tokens(tokens, **kwargs) + + def load_state_dict(self, state_dict, strict=True): + self.model.load_state_dict(state_dict, strict=strict) + +class WaveletPretransform(Pretransform): + def __init__(self, channels, levels, wavelet): + super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) + + from .wavelets import WaveletEncode1d, WaveletDecode1d + + self.encoder = WaveletEncode1d(channels, levels, wavelet) + self.decoder = WaveletDecode1d(channels, levels, wavelet) + + self.downsampling_ratio = 2 ** levels + self.io_channels = channels + self.encoded_channels = channels * self.downsampling_ratio + + def encode(self, x): + return self.encoder(x) + + def decode(self, z): + return self.decoder(z) + +class PQMFPretransform(Pretransform): + def __init__(self, attenuation=100, num_bands=16): + # TODO: Fix PQMF to take in in-channels + super().__init__(enable_grad=False, io_channels=1, is_discrete=False) + from .pqmf import PQMF + self.pqmf = PQMF(attenuation, num_bands) + + + def encode(self, x): + # x is (Batch x Channels x Time) + x = self.pqmf.forward(x) + # pqmf.forward returns (Batch x Channels x Bands x Time) + # but Pretransform needs Batch x Channels x Time + # so concatenate channels and bands into one axis + return rearrange(x, "b c n t -> b (c n) t") + + def decode(self, x): + # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) + x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands) + # returns (Batch x Channels x Time) + return self.pqmf.inverse(x) + +class PretrainedDACPretransform(Pretransform): + def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + import dac + + model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) + + self.model = dac.DAC.load(model_path) + + self.quantize_on_decode = quantize_on_decode + + if model_type == "44khz": + self.downsampling_ratio = 512 + else: + self.downsampling_ratio = 320 + + self.io_channels = 1 + + self.scale = scale + + self.chunked = chunked + + self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.n_codebooks + + self.codebook_size = self.model.codebook_size + + def encode(self, x): + + latents = self.model.encoder(x) + + if self.quantize_on_decode: + output = latents + else: + z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + output = z + + if self.scale != 1.0: + output = output / self.scale + + return output + + def decode(self, z): + + if self.scale != 1.0: + z = z * self.scale + + if self.quantize_on_decode: + z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + return self.model.decode(z) + + def tokenize(self, x): + return self.model.encode(x)[1] + + def decode_tokens(self, tokens): + latents = self.model.quantizer.from_codes(tokens) + return self.model.decode(latents) + +class AudiocraftCompressionPretransform(Pretransform): + def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + try: + from audiocraft.models import CompressionModel + except ImportError: + raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.") + + self.model = CompressionModel.get_pretrained(model_type) + + self.quantize_on_decode = quantize_on_decode + + self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate) + + self.sample_rate = self.model.sample_rate + + self.io_channels = self.model.channels + + self.scale = scale + + #self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.num_codebooks + + self.codebook_size = self.model.cardinality + + self.model.to(torch.float16).eval().requires_grad_(False) + + def encode(self, x): + + assert False, "Audiocraft compression models do not support continuous encoding" + + # latents = self.model.encoder(x) + + # if self.quantize_on_decode: + # output = latents + # else: + # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + # output = z + + # if self.scale != 1.0: + # output = output / self.scale + + # return output + + def decode(self, z): + + assert False, "Audiocraft compression models do not support continuous decoding" + + # if self.scale != 1.0: + # z = z * self.scale + + # if self.quantize_on_decode: + # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + # return self.model.decode(z) + + def tokenize(self, x): + with torch.cuda.amp.autocast(enabled=False): + return self.model.encode(x.to(torch.float16))[0] + + def decode_tokens(self, tokens): + with torch.cuda.amp.autocast(enabled=False): + return self.model.decode(tokens) diff --git a/ThinkSound/models/transformer (1).py b/ThinkSound/models/transformer (1).py new file mode 100644 index 0000000000000000000000000000000000000000..444420f34d77129b74748fcf5e3409c26e4ad2ee --- /dev/null +++ b/ThinkSound/models/transformer (1).py @@ -0,0 +1,862 @@ +from functools import reduce + +from einops import rearrange +from einops.layers.torch import Rearrange +import torch +import torch.nn.functional as F +from torch import nn, einsum +from torch.amp import autocast +from typing import Callable, Literal +from torch.nn.attention.flex_attention import flex_attention + +try: + from flash_attn import flash_attn_func +except ImportError as e: + print(e) + print('flash_attn not installed, disabling Flash Attention') + flash_attn_kvpacked_func = None + flash_attn_func = None + +from .utils import compile + +try: + torch._dynamo.config.cache_size_limit = 5000 + flex_attention_compiled = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs") +except: + flex_attention_compiled = flex_attention + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + + +# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License +# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt + +def create_causal_mask(i, j, device): + return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) + +def or_reduce(masks): + head, *body = masks + for rest in body: + head = head | rest + return head + +# positional embeddings + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim ** -0.5 + self.max_seq_len = max_seq_len + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = (pos - seq_start_pos[..., None]).clamp(min = 0) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return pos_emb + +class ScaledSinusoidalEmbedding(nn.Module): + def __init__(self, dim, theta = 10000): + super().__init__() + assert (dim % 2) == 0, 'dimension must be divisible by 2' + self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) + + half_dim = dim // 2 + freq_seq = torch.arange(half_dim).float() / half_dim + inv_freq = theta ** -freq_seq + self.register_buffer('inv_freq', inv_freq, persistent = False) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = pos - seq_start_pos[..., None] + + emb = einsum('i, j -> i j', pos, self.inv_freq) + emb = torch.cat((emb.sin(), emb.cos()), dim = -1) + return emb * self.scale + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + use_xpos = False, + scale_base = 512, + interpolation_factor = 1., + base = 10000, + base_rescale_factor = 1. + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + base *= base_rescale_factor ** (dim / (dim - 2)) + + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + assert interpolation_factor >= 1. + self.interpolation_factor = interpolation_factor + + if not use_xpos: + self.register_buffer('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + + self.scale_base = scale_base + self.register_buffer('scale', scale) + + def forward_from_seq_len(self, seq_len): + device = self.inv_freq.device + + t = torch.arange(seq_len, device = device) + return self.forward(t) + + @autocast("cuda", enabled = False) + def forward(self, t): + device = self.inv_freq.device + + t = t.to(torch.float32) + + t = t / self.interpolation_factor + + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim = -1) + + if self.scale is None: + return freqs, 1. + + power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + return freqs, scale + +def rotate_half(x): + x = rearrange(x, '... (j d) -> ... j d', j = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + +@autocast("cuda", enabled = False) +def apply_rotary_pos_emb(t, freqs, scale = 1): + out_dtype = t.dtype + + # cast to float32 if necessary for numerical stability + dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32)) + rot_dim, seq_len = freqs.shape[-1], t.shape[-2] + freqs, t = freqs.to(dtype), t.to(dtype) + freqs = freqs[-seq_len:, :] + + if t.ndim == 4 and freqs.ndim == 3: + freqs = rearrange(freqs, 'b n d -> b 1 n d') + + # partial rotary embeddings, Wang et al. GPT-J + t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] + + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + + t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) + + return torch.cat((t, t_unrotated), dim = -1) + +# norms +class DynamicTanh(nn.Module): + def __init__(self, dim, init_alpha=10.0): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * init_alpha) + self.gamma = nn.Parameter(torch.ones(dim)) + self.beta = nn.Parameter(torch.zeros(dim)) + + def forward(self, x): + x = F.tanh(self.alpha * x) + return self.gamma * x + self.beta + +class RunningInstanceNorm(nn.Module): + def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True): + super().__init__() + self.register_buffer("running_mean", torch.zeros(1,1,dim)) + self.register_buffer("running_std", torch.ones(1,1,dim)) + self.saturate = saturate + self.eps = eps + self.momentum = momentum + self.dim = dim + self.trainable_gain = trainable_gain + if self.trainable_gain: + self.gain = nn.Parameter(torch.ones(1)) + + def _update_stats(self, x): + self.running_mean = self.running_mean * self.momentum + x.detach().mean(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum) + self.running_std = (self.running_std * self.momentum + x.detach().std(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)).clip(min = self.eps) + + def forward(self, x): + if self.training: + self._update_stats(x) + x = (x - self.running_mean) / self.running_std + if self.saturate: + x = torch.asinh(x) + if self.trainable_gain: + x = x * self.gain + return x + +class LayerNorm(nn.Module): + def __init__(self, dim, bias=False, fix_scale=False, force_fp32=False, eps=1e-5): + """ + bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less + """ + super().__init__() + + if fix_scale: + self.register_buffer("gamma", torch.ones(dim)) + else: + self.gamma = nn.Parameter(torch.ones(dim)) + + if bias: + self.beta = nn.Parameter(torch.zeros(dim)) + else: + self.register_buffer("beta", torch.zeros(dim)) + + self.eps = eps + + self.force_fp32 = force_fp32 + + def forward(self, x): + if not self.force_fp32: + return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta, eps=self.eps) + else: + output = F.layer_norm(x.float(), x.shape[-1:], weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps) + return output.to(x.dtype) + +class LayerScale(nn.Module): + def __init__(self, dim, init_val = 1e-5): + super().__init__() + self.scale = nn.Parameter(torch.full([dim], init_val)) + def forward(self, x): + return x * self.scale + +class GLU(nn.Module): + def __init__( + self, + dim_in, + dim_out, + activation: Callable, + use_conv = False, + conv_kernel_size = 3, + ): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2)) + self.use_conv = use_conv + + def forward(self, x): + if self.use_conv: + x = rearrange(x, 'b n d -> b d n') + x = self.proj(x) + x = rearrange(x, 'b d n -> b n d') + else: + x = self.proj(x) + + x, gate = x.chunk(2, dim = -1) + return x * self.act(gate) + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out = None, + mult = 4, + no_bias = False, + glu = True, + use_conv = False, + conv_kernel_size = 3, + zero_init_output = True, + ): + super().__init__() + inner_dim = int(dim * mult) + + # Default to SwiGLU + + activation = nn.SiLU() + + dim_out = dim if dim_out is None else dim_out + + if glu: + linear_in = GLU(dim, inner_dim, activation) + else: + linear_in = nn.Sequential( + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias), + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + activation + ) + + linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias) + + # init last linear layer to 0 + if zero_init_output: + nn.init.zeros_(linear_out.weight) + if not no_bias: + nn.init.zeros_(linear_out.bias) + + + self.ff = nn.Sequential( + linear_in, + Rearrange('b d n -> b n d') if use_conv else nn.Identity(), + linear_out, + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + ) + + #@compile + def forward(self, x): + return self.ff(x) + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + dim_context = None, + causal = False, + zero_init_output=True, + qk_norm: Literal['l2', 'ln', 'dyt', 'none'] = 'none', + differential = False, + feat_scale = False + ): + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.differential = differential + + dim_kv = dim_context if dim_context is not None else dim + + self.num_heads = dim // dim_heads + self.kv_heads = dim_kv // dim_heads + + if dim_context is not None: + if differential: + self.to_q = nn.Linear(dim, dim * 2, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False) + else: + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False) + else: + if differential: + self.to_qkv = nn.Linear(dim, dim * 5, bias=False) + else: + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.to_out = nn.Linear(dim, dim, bias=False) + + if zero_init_output: + nn.init.zeros_(self.to_out.weight) + + if qk_norm not in ['l2', 'ln', 'dyt','none']: + raise ValueError(f'qk_norm must be one of ["l2", "ln", "none"], got {qk_norm}') + + self.qk_norm = qk_norm + + if self.qk_norm == "ln": + self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + elif self.qk_norm == 'dyt': + self.q_norm = DynamicTanh(dim_heads) + self.k_norm = DynamicTanh(dim_heads) + + self.sdp_kwargs = dict( + enable_flash = True, + enable_math = True, + enable_mem_efficient = True + ) + + self.feat_scale = feat_scale + + if self.feat_scale: + self.lambda_dc = nn.Parameter(torch.zeros(dim)) + self.lambda_hf = nn.Parameter(torch.zeros(dim)) + + self.causal = causal + if causal: + print('Using `causal` argument disables FlexAttention. If you want to use them together, incorporate causal masking into `flex_attention_block_mask`.') + + @compile + def apply_qk_layernorm(self, q, k): + q_type = q.dtype + k_type = k.dtype + q = self.q_norm(q).to(q_type) + k = self.k_norm(k).to(k_type) + return q, k + + + def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None): + + if self.num_heads != self.kv_heads: + # Repeat interleave kv_heads to match q_heads for grouped query attention + heads_per_kv_head = self.num_heads // self.kv_heads + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + + flash_attn_available = flash_attn_func is not None + if flash_attn_sliding_window is not None and (not flash_attn_available): + print(f"Cannot use FlashAttention sliding window as FlashAttention is disabled or not available") + + if (flex_attention_block_mask is not None or flex_attention_score_mod is not None) and flash_attn_sliding_window is not None: + print(f"cannot use both FlashAttention and FlexAttention, favouring FlexAttention") + + if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None): + print(f"Disabling FlexAttention because causal is set") + flex_attention_block_mask = None + flex_attention_score_mod = None + + if flex_attention_block_mask is not None or flex_attention_score_mod is not None: + out = flex_attention_compiled(q,k,v, + block_mask = flex_attention_block_mask, + score_mod = flex_attention_score_mod) + elif flash_attn_available: + fa_dtype_in = q.dtype + q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v)) + + if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16: + q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v)) + + out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1]) + + out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d') + else: + out = F.scaled_dot_product_attention(q, k, v, is_causal = causal) + return out + + + #@compile + def forward( + self, + x, + context = None, + rotary_pos_emb = None, + causal = None, + flex_attention_block_mask = None, + flex_attention_score_mod = None, + flash_attn_sliding_window = None + ): + h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None + + kv_input = context if has_context else x + + if hasattr(self, 'to_q'): + # Use separate linear projections for q and k/v + if self.differential: + q, q_diff = self.to_q(x).chunk(2, dim=-1) + q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff)) + q = torch.stack([q, q_diff], dim = 1) + k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1) + k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v)) + k = torch.stack([k, k_diff], dim = 1) + else: + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + else: + # Use fused linear projection + if self.differential: + q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1) + q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff)) + q = torch.stack([q, q_diff], dim = 1) + k = torch.stack([k, k_diff], dim = 1) + else: + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + + # Normalize q and k for cosine sim attention + if self.qk_norm == "l2": + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + elif self.qk_norm != "none": + q, k = self.apply_qk_layernorm(q, k) + + if rotary_pos_emb is not None: + freqs, _ = rotary_pos_emb + q_dtype = q.dtype + k_dtype = k.dtype + q = q.to(torch.float32) + k = k.to(torch.float32) + freqs = freqs.to(torch.float32) + if q.shape[-2] >= k.shape[-2]: + ratio = q.shape[-2] / k.shape[-2] + q_freqs, k_freqs = freqs, ratio * freqs + else: + ratio = k.shape[-2] / q.shape[-2] + q_freqs, k_freqs = ratio * freqs, freqs + q = apply_rotary_pos_emb(q, q_freqs) + k = apply_rotary_pos_emb(k, k_freqs) + q = q.to(v.dtype) + k = k.to(v.dtype) + + n, device = q.shape[-2], q.device + + causal = self.causal if causal is None else causal + + if n == 1 and causal: + causal = False + + if self.differential: + q, q_diff = q.unbind(dim = 1) + k, k_diff = k.unbind(dim = 1) + out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) + out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) + out = out - out_diff + else: + out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) + + # merge heads + out = rearrange(out, ' b h n d -> b n (h d)') + + # Communicate between heads + + # with autocast(enabled = False): + # out_dtype = out.dtype + # out = out.to(torch.float32) + # out = self.to_out(out).to(out_dtype) + out = self.to_out(out) + + if self.feat_scale: + out_dc = out.mean(dim=-2, keepdim=True) + out_hf = out - out_dc + + # Selectively modulate DC and high frequency components + out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf + + return out + +class ConformerModule(nn.Module): + def __init__( + self, + dim, + norm_kwargs = {}, + ): + + super().__init__() + + self.dim = dim + + self.in_norm = LayerNorm(dim, **norm_kwargs) + self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + self.glu = GLU(dim, dim, nn.SiLU()) + self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) + self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm + self.swish = nn.SiLU() + self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + + #@compile + def forward(self, x): + x = self.in_norm(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.glu(x) + x = rearrange(x, 'b n d -> b d n') + x = self.depthwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.mid_norm(x) + x = self.swish(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv_2(x) + x = rearrange(x, 'b d n -> b n d') + + return x + +class TransformerBlock(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + cross_attend = False, + dim_context = None, + global_cond_dim = None, + causal = False, + zero_init_branch_outputs = True, + conformer = False, + layer_ix = -1, + remove_norms = False, + add_rope = False, + layer_scale = False, + attn_kwargs = {}, + ff_kwargs = {}, + norm_kwargs = {} + ): + + super().__init__() + self.dim = dim + self.dim_heads = min(dim_heads,dim) + self.cross_attend = cross_attend + self.dim_context = dim_context + self.causal = causal + + if layer_scale and zero_init_branch_outputs: + print('zero_init_branch_outputs is redundant with layer_scale, setting zero_init_branch_outputs to False') + zero_init_branch_outputs = False + + self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim) + + self.add_rope = add_rope + + self.self_attn = Attention( + dim, + dim_heads = self.dim_heads, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + + self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.cross_attend = cross_attend + if cross_attend: + self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim) + self.cross_attn = Attention( + dim, + dim_heads = self.dim_heads, + dim_context=dim_context, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim) + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs) + self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.layer_ix = layer_ix + + self.conformer = None + if conformer: + self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) + self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.global_cond_dim = global_cond_dim + + if global_cond_dim is not None: + self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5) + + self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None + + @compile + def forward( + self, + x, + context = None, + global_cond=None, + rotary_pos_emb = None, + self_attention_block_mask = None, + self_attention_score_mod = None, + cross_attention_block_mask = None, + cross_attention_score_mod = None, + self_attention_flash_sliding_window = None, + cross_attention_flash_sliding_window = None + ): + if rotary_pos_emb is None and self.add_rope: + rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2]) + + if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: + + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1) + + # self-attention with adaLN + residual = x + x = self.pre_norm(x) + x = x * (1 + scale_self) + shift_self + x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window) + x = x * torch.sigmoid(1 - gate_self) + x = self.self_attn_scale(x) + x = x + residual + + if context is not None and self.cross_attend: + x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window)) + + if self.conformer is not None: + x = x + self.conformer_scale(self.conformer(x)) + + # feedforward with adaLN + residual = x + x = self.ff_norm(x) + x = x * (1 + scale_ff) + shift_ff + x = self.ff(x) + x = x * torch.sigmoid(1 - gate_ff) + x = self.ff_scale(x) + x = x + residual + + else: + x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)) + + if context is not None and self.cross_attend: + x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window)) + + if self.conformer is not None: + x = x + self.conformer_scale(self.conformer(x)) + + x = x + self.ff_scale(self.ff(self.ff_norm(x))) + return x + +class ContinuousTransformer(nn.Module): + def __init__( + self, + dim, + depth, + *, + dim_in = None, + dim_out = None, + dim_heads = 64, + cross_attend=False, + cond_token_dim=None, + final_cross_attn_ix=-1, + global_cond_dim=None, + causal=False, + rotary_pos_emb=True, + zero_init_branch_outputs=True, + conformer=False, + use_sinusoidal_emb=False, + use_abs_pos_emb=False, + abs_pos_emb_max_length=10000, + num_memory_tokens=0, + sliding_window=None, + **kwargs + ): + + super().__init__() + + self.dim = dim + self.depth = depth + self.causal = causal + self.layers = nn.ModuleList([]) + + self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity() + self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity() + + if rotary_pos_emb: + self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) + else: + self.rotary_pos_emb = None + + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + self.use_sinusoidal_emb = use_sinusoidal_emb + if use_sinusoidal_emb: + self.pos_emb = ScaledSinusoidalEmbedding(dim) + + self.use_abs_pos_emb = use_abs_pos_emb + if use_abs_pos_emb: + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens) + + self.global_cond_embedder = None + if global_cond_dim is not None: + self.global_cond_embedder = nn.Sequential( + nn.Linear(global_cond_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim * 6) + ) + + self.final_cross_attn_ix = final_cross_attn_ix + + self.sliding_window = sliding_window + + for i in range(depth): + should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i <= (self.final_cross_attn_ix)) + self.layers.append( + TransformerBlock( + dim, + dim_heads = dim_heads, + cross_attend = should_cross_attend, + dim_context = cond_token_dim, + global_cond_dim = global_cond_dim, + causal = causal, + zero_init_branch_outputs = zero_init_branch_outputs, + conformer=conformer, + layer_ix=i, + **kwargs + ) + ) + + def forward( + self, + x, + prepend_embeds = None, + global_cond = None, + return_info = False, + use_checkpointing = True, + exit_layer_ix = None, + **kwargs + ): + batch, seq, device = *x.shape[:2], x.device + + model_dtype = next(self.parameters()).dtype + x = x.to(model_dtype) + + info = { + "hidden_states": [], + } + + x = self.project_in(x) + + if prepend_embeds is not None: + prepend_length, prepend_dim = prepend_embeds.shape[1:] + + assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' + + x = torch.cat((prepend_embeds, x), dim = -2) + + if self.num_memory_tokens > 0: + memory_tokens = self.memory_tokens.expand(batch, -1, -1) + x = torch.cat((memory_tokens, x), dim=1) + + if self.rotary_pos_emb is not None: + rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + else: + rotary_pos_emb = None + + if self.use_sinusoidal_emb or self.use_abs_pos_emb: + x = x + self.pos_emb(x) + + if global_cond is not None and self.global_cond_embedder is not None: + global_cond = self.global_cond_embedder(global_cond) + + # Iterate over the transformer layers + for layer_ix, layer in enumerate(self.layers): + + if use_checkpointing: + x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, self_attention_flash_sliding_window = self.sliding_window, **kwargs) + else: + x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, self_attention_flash_sliding_window = self.sliding_window, **kwargs) + + if return_info: + info["hidden_states"].append(x) + + if exit_layer_ix is not None and layer_ix == exit_layer_ix: + x = x[:, self.num_memory_tokens:, :] + + if return_info: + return x, info + + return x + + x = x[:, self.num_memory_tokens:, :] + + x = self.project_out(x) + + if return_info: + return x, info + + return x diff --git a/ThinkSound/models/transformer.py b/ThinkSound/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7eccd0ada04dc4587b02193ac3ab730f648cc0 --- /dev/null +++ b/ThinkSound/models/transformer.py @@ -0,0 +1,993 @@ +from functools import reduce, partial +from packaging import version + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +import torch +import torch.nn.functional as F +from torch import nn, einsum +from torch.cuda.amp import autocast +from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP +from typing import Callable, Literal + +try: + from flash_attn import flash_attn_func, flash_attn_kvpacked_func +except ImportError as e: + print(e) + print('flash_attn not installed, disabling Flash Attention') + flash_attn_kvpacked_func = None + flash_attn_func = None + +from .utils import compile +try: + import natten +except ImportError: + natten = None + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return x * (1 + scale) + shift + +# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License +# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt + +def create_causal_mask(i, j, device): + return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) + +def or_reduce(masks): + head, *body = masks + for rest in body: + head = head | rest + return head + +# positional embeddings + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim ** -0.5 + self.max_seq_len = max_seq_len + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = (pos - seq_start_pos[..., None]).clamp(min = 0) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return pos_emb + +class ScaledSinusoidalEmbedding(nn.Module): + def __init__(self, dim, theta = 10000): + super().__init__() + assert (dim % 2) == 0, 'dimension must be divisible by 2' + self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) + + half_dim = dim // 2 + freq_seq = torch.arange(half_dim).float() / half_dim + inv_freq = theta ** -freq_seq + self.register_buffer('inv_freq', inv_freq, persistent = False) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = pos - seq_start_pos[..., None] + + emb = einsum('i, j -> i j', pos, self.inv_freq) + emb = torch.cat((emb.sin(), emb.cos()), dim = -1) + return emb * self.scale + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + use_xpos = False, + scale_base = 512, + interpolation_factor = 1., + base = 10000, + base_rescale_factor = 1. + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + base *= base_rescale_factor ** (dim / (dim - 2)) + + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + assert interpolation_factor >= 1. + self.interpolation_factor = interpolation_factor + + if not use_xpos: + self.register_buffer('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + + self.scale_base = scale_base + self.register_buffer('scale', scale) + + def forward_from_seq_len(self, seq_len): + device = self.inv_freq.device + + t = torch.arange(seq_len, device = device) + return self.forward(t) + + @autocast(enabled = False) + def forward(self, t): + device = self.inv_freq.device + + t = t.to(torch.float32) + + t = t / self.interpolation_factor + + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim = -1) + + if self.scale is None: + return freqs, 1. + + power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + return freqs, scale + +def rotate_half(x): + x = rearrange(x, '... (j d) -> ... j d', j = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + +@autocast(enabled = False) +def apply_rotary_pos_emb(t, freqs, scale = 1): + out_dtype = t.dtype + + # cast to float32 if necessary for numerical stability + dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32)) + rot_dim, seq_len = freqs.shape[-1], t.shape[-2] + freqs, t = freqs.to(dtype), t.to(dtype) + freqs = freqs[-seq_len:, :] + + if t.ndim == 4 and freqs.ndim == 3: + freqs = rearrange(freqs, 'b n d -> b 1 n d') + + # partial rotary embeddings, Wang et al. GPT-J + t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + + t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) + + return torch.cat((t, t_unrotated), dim = -1) + +# norms +class DynamicTanh(nn.Module): + def __init__(self, dim, init_alpha=10.0): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1) * init_alpha) + self.gamma = nn.Parameter(torch.ones(dim)) + self.beta = nn.Parameter(torch.zeros(dim)) + + def forward(self, x): + x = F.tanh(self.alpha * x) + return self.gamma * x + self.beta + +class RunningInstanceNorm(nn.Module): + def __init__(self, dim, momentum = 0.99, eps = 1e-4, saturate = True, trainable_gain = True): + super().__init__() + self.register_buffer("running_mean", torch.zeros(1,1,dim)) + self.register_buffer("running_std", torch.ones(1,1,dim)) + self.saturate = saturate + self.eps = eps + self.momentum = momentum + self.dim = dim + self.trainable_gain = trainable_gain + if self.trainable_gain: + self.gain = nn.Parameter(torch.ones(1)) + + def _update_stats(self, x): + self.running_mean = self.running_mean * self.momentum + x.detach().mean(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum) + self.running_std = (self.running_std * self.momentum + x.detach().std(dim = [0,1]).view(1, 1, self.dim) * (1 - self.momentum)).clip(min = self.eps) + + def forward(self, x): + if self.training: + self._update_stats(x) + x = (x - self.running_mean) / self.running_std + if self.saturate: + x = torch.asinh(x) + if self.trainable_gain: + x = x * self.gain + return x + +class LayerNorm(nn.Module): + def __init__(self, dim, bias = False, fix_scale=False, force_fp32=False, eps=1e-5): + """ + bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less + """ + super().__init__() + + if fix_scale: + self.register_buffer("gamma", torch.ones(dim)) + else: + self.gamma = nn.Parameter(torch.ones(dim)) + + if bias: + self.beta = nn.Parameter(torch.zeros(dim)) + else: + self.register_buffer("beta", torch.zeros(dim)) + + self.eps = eps + + self.force_fp32 = force_fp32 + + def forward(self, x): + if not self.force_fp32: + return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta, eps=self.eps) + else: + output = F.layer_norm(x.float(), x.shape[-1:], weight=self.gamma.float(), bias=self.beta.float(), eps=self.eps) + return output.to(x.dtype) + +class LayerScale(nn.Module): + def __init__(self, dim, init_val = 1e-5): + super().__init__() + self.scale = nn.Parameter(torch.full([dim], init_val)) + def forward(self, x): + return x * self.scale + +class GLU(nn.Module): + def __init__( + self, + dim_in, + dim_out, + activation: Callable, + use_conv = False, + conv_kernel_size = 3, + ): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2)) + self.use_conv = use_conv + + def forward(self, x): + if self.use_conv: + x = rearrange(x, 'b n d -> b d n') + x = self.proj(x) + x = rearrange(x, 'b d n -> b n d') + else: + x = self.proj(x) + + x, gate = x.chunk(2, dim = -1) + return x * self.act(gate) + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out = None, + mult = 4, + no_bias = False, + glu = True, + use_conv = False, + conv_kernel_size = 3, + zero_init_output = True, + ): + super().__init__() + inner_dim = int(dim * mult) + + # Default to SwiGLU + + activation = nn.SiLU() + + dim_out = dim if dim_out is None else dim_out + + if glu: + linear_in = GLU(dim, inner_dim, activation) + else: + linear_in = nn.Sequential( + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias), + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + activation + ) + + linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias) + + # init last linear layer to 0 + if zero_init_output: + nn.init.zeros_(linear_out.weight) + if not no_bias: + nn.init.zeros_(linear_out.bias) + + + self.ff = nn.Sequential( + linear_in, + Rearrange('b d n -> b n d') if use_conv else nn.Identity(), + linear_out, + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + ) + + def forward(self, x): + return self.ff(x) + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + dim_context = None, + causal = False, + zero_init_output=True, + qk_norm: Literal['l2', 'ln', 'rns', 'dyt', 'none'] = 'none', + differential = False, + feat_scale = False + ): + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.differential = differential + + dim_kv = dim_context if dim_context is not None else dim + + self.num_heads = dim // dim_heads + self.kv_heads = dim_kv // dim_heads + + if dim_context is not None: + if differential: + self.to_q = nn.Linear(dim, dim * 2, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 3, bias=False) + else: + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False) + else: + if differential: + self.to_qkv = nn.Linear(dim, dim * 5, bias=False) + else: + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.to_out = nn.Linear(dim, dim, bias=False) + + if zero_init_output: + nn.init.zeros_(self.to_out.weight) + + if qk_norm not in ['l2', 'ln', 'rns', 'dyt','none']: + raise ValueError(f'qk_norm must be one of ["l2", "ln", "none"], got {qk_norm}') + + self.qk_norm = qk_norm + if self.qk_norm == "ln": + self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + elif self.qk_norm == 'rns': + self.q_norm = nn.RMSNorm(dim_heads) + self.k_norm = nn.RMSNorm(dim_heads) + elif self.qk_norm == 'dyt': + self.q_norm = DynamicTanh(dim_heads) + self.k_norm = DynamicTanh(dim_heads) + + self.sdp_kwargs = dict( + enable_flash = True, + enable_math = True, + enable_mem_efficient = True + ) + + self.feat_scale = feat_scale + + if self.feat_scale: + self.lambda_dc = nn.Parameter(torch.zeros(dim)) + self.lambda_hf = nn.Parameter(torch.zeros(dim)) + + self.causal = causal + if causal: + print('Using `causal` argument disables FlexAttention. If you want to use them together, incorporate causal masking into `flex_attention_block_mask`.') + + @compile + def apply_qk_layernorm(self, q, k): + q_type = q.dtype + k_type = k.dtype + q = self.q_norm(q).to(q_type) + k = self.k_norm(k).to(k_type) + return q, k + + + def apply_attn(self, q, k, v, causal = None, flex_attention_block_mask = None, flex_attention_score_mod = None, flash_attn_sliding_window = None): + + if self.num_heads != self.kv_heads: + # Repeat interleave kv_heads to match q_heads for grouped query attention + heads_per_kv_head = self.num_heads // self.kv_heads + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + + flash_attn_available = flash_attn_func is not None + if flash_attn_sliding_window is not None and (not flash_attn_available): + print(f"Cannot use FlashAttention sliding window as FlashAttention is disabled or not available") + + if (flex_attention_block_mask is not None or flex_attention_score_mod is not None) and flash_attn_sliding_window is not None: + print(f"cannot use both FlashAttention and FlexAttention, favouring FlexAttention") + + if causal and (flex_attention_block_mask is not None or flex_attention_score_mod is not None): + print(f"Disabling FlexAttention because causal is set") + flex_attention_block_mask = None + flex_attention_score_mod = None + + if flex_attention_block_mask is not None or flex_attention_score_mod is not None: + out = flex_attention_compiled(q,k,v, + block_mask = flex_attention_block_mask, + score_mod = flex_attention_score_mod) + elif flash_attn_available: + fa_dtype_in = q.dtype + q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d'), (q, k, v)) + + if fa_dtype_in != torch.float16 and fa_dtype_in != torch.bfloat16: + q, k, v = map(lambda t: t.to(torch.bfloat16), (q, k, v)) + + out = flash_attn_func(q, k, v, causal = causal, window_size=flash_attn_sliding_window if (flash_attn_sliding_window is not None) else [-1,-1]) + + out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d') + else: + out = F.scaled_dot_product_attention(q, k, v, is_causal = causal) + return out + + + #@compile + def forward( + self, + x, + context = None, + rotary_pos_emb = None, + causal = None, + flex_attention_block_mask = None, + flex_attention_score_mod = None, + flash_attn_sliding_window = None + ): + h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None + + kv_input = context if has_context else x + + if hasattr(self, 'to_q'): + # Use separate linear projections for q and k/v + if self.differential: + q, q_diff = self.to_q(x).chunk(2, dim=-1) + q, q_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, q_diff)) + q = torch.stack([q, q_diff], dim = 1) + k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1) + k, k_diff, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, k_diff, v)) + k = torch.stack([k, k_diff], dim = 1) + else: + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + else: + # Use fused linear projection + if self.differential: + q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1) + q, k, v, q_diff, k_diff = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v, q_diff, k_diff)) + q = torch.stack([q, q_diff], dim = 1) + k = torch.stack([k, k_diff], dim = 1) + else: + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + + # Normalize q and k for cosine sim attention + if self.qk_norm == "l2": + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + elif self.qk_norm != "none": + q, k = self.apply_qk_layernorm(q, k) + + if rotary_pos_emb is not None: + freqs, _ = rotary_pos_emb + q_dtype = q.dtype + k_dtype = k.dtype + q = q.to(torch.float32) + k = k.to(torch.float32) + freqs = freqs.to(torch.float32) + if q.shape[-2] >= k.shape[-2]: + ratio = q.shape[-2] / k.shape[-2] + q_freqs, k_freqs = freqs, ratio * freqs + else: + ratio = k.shape[-2] / q.shape[-2] + q_freqs, k_freqs = ratio * freqs, freqs + q = apply_rotary_pos_emb(q, q_freqs) + k = apply_rotary_pos_emb(k, k_freqs) + q = q.to(v.dtype) + k = k.to(v.dtype) + + n, device = q.shape[-2], q.device + + causal = self.causal if causal is None else causal + + if n == 1 and causal: + causal = False + + if self.differential: + q, q_diff = q.unbind(dim = 1) + k, k_diff = k.unbind(dim = 1) + out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) + out_diff = self.apply_attn(q_diff, k_diff, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) + out = out - out_diff + else: + out = self.apply_attn(q, k, v, causal = causal, flex_attention_block_mask = flex_attention_block_mask, flex_attention_score_mod = flex_attention_score_mod, flash_attn_sliding_window = flash_attn_sliding_window) + + # merge heads + out = rearrange(out, ' b h n d -> b n (h d)') + + # Communicate between heads + + # with autocast(enabled = False): + # out_dtype = out.dtype + # out = out.to(torch.float32) + # out = self.to_out(out).to(out_dtype) + out = self.to_out(out) + + if self.feat_scale: + out_dc = out.mean(dim=-2, keepdim=True) + out_hf = out - out_dc + + # Selectively modulate DC and high frequency components + out = out + self.lambda_dc * out_dc + self.lambda_hf * out_hf + + return out + +class ConformerModule(nn.Module): + def __init__( + self, + dim, + norm_kwargs = {}, + ): + + super().__init__() + + self.dim = dim + + self.in_norm = LayerNorm(dim, **norm_kwargs) + self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + self.glu = GLU(dim, dim, nn.SiLU()) + self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) + self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm + self.swish = nn.SiLU() + self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + + #@compile + def forward(self, x): + x = self.in_norm(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.glu(x) + x = rearrange(x, 'b n d -> b d n') + x = self.depthwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.mid_norm(x) + x = self.swish(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv_2(x) + x = rearrange(x, 'b d n -> b n d') + + return x + +class TransformerBlock(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + cross_attend = False, + dim_context = None, + global_cond_dim = None, + causal = False, + zero_init_branch_outputs = True, + conformer = False, + layer_ix = -1, + remove_norms = False, + add_rope = False, + layer_scale = False, + use_sync_block_film = False, + attn_kwargs = {}, + ff_kwargs = {}, + norm_kwargs = {} + ): + + super().__init__() + self.dim = dim + self.dim_heads = min(dim_heads,dim) + self.cross_attend = cross_attend + self.dim_context = dim_context + self.causal = causal + if layer_scale and zero_init_branch_outputs: + print('zero_init_branch_outputs is redundant with layer_scale, setting zero_init_branch_outputs to False') + zero_init_branch_outputs = False + + self.pre_norm = LayerNorm(dim,**norm_kwargs) if not remove_norms else DynamicTanh(dim) + + self.add_rope = add_rope + + self.self_attn = Attention( + dim, + dim_heads = self.dim_heads, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + + self.self_attn_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.cross_attend = cross_attend + if cross_attend: + self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim) + self.cross_attn = Attention( + dim, + dim_heads = self.dim_heads, + dim_context=dim_context, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + self.cross_attn_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else DynamicTanh(dim) + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs) + self.ff_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.layer_ix = layer_ix + + self.conformer = None + if conformer: + self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) + self.conformer_scale = LayerScale(dim) if layer_scale else nn.Identity() + + self.global_cond_dim = global_cond_dim + if global_cond_dim is not None: + self.to_scale_shift_gate = nn.Parameter(torch.randn(6*dim)/dim**0.5) + + self.rope = RotaryEmbedding(self.dim_heads // 2) if add_rope else None + + if use_sync_block_film: + self.sync_film_generator = nn.Sequential( + nn.Linear(dim, dim, bias=False), + nn.SiLU(), + nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift + ) + + @compile + def forward( + self, + x, + context = None, + global_cond=None, + rotary_pos_emb = None, + self_attention_block_mask = None, + self_attention_score_mod = None, + cross_attention_block_mask = None, + cross_attention_score_mod = None, + self_attention_flash_sliding_window = None, + cross_attention_flash_sliding_window = None, + sync_cond = None, + prepend_length=0 + ): + if rotary_pos_emb is None and self.add_rope: + rotary_pos_emb = self.rope.forward_from_seq_len(x.shape[-2]) + + if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: + if len(global_cond.shape) == 2: + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).unsqueeze(1).chunk(6, dim=-1) + else: + + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = (self.to_scale_shift_gate + global_cond).chunk(6, dim=-1) + + # self-attention with adaLN + residual = x + x = self.pre_norm(x) + x = x * (1 + scale_self) + shift_self + x = self.self_attn(x, rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window) + x = x * torch.sigmoid(1 - gate_self) + x = self.self_attn_scale(x) + x = x + residual + + if context is not None and self.cross_attend: + x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window)) + + if self.conformer is not None: + x = x + self.conformer_scale(self.conformer(x)) + + if sync_cond is not None and hasattr(self, 'sync_film_generator'): + scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1) + x = x * (1 + scale) + shift + + # feedforward with adaLN + residual = x + x = self.ff_norm(x) + x = x * (1 + scale_ff) + shift_ff + x = self.ff(x) + x = x * torch.sigmoid(1 - gate_ff) + x = self.ff_scale(x) + x = x + residual + + else: + x = x + self.self_attn_scale(self.self_attn(self.pre_norm(x), rotary_pos_emb = rotary_pos_emb, flex_attention_block_mask = self_attention_block_mask, flex_attention_score_mod = self_attention_score_mod, flash_attn_sliding_window = self_attention_flash_sliding_window)) + + if context is not None and self.cross_attend: + x = x + self.cross_attn_scale(self.cross_attn(self.cross_attend_norm(x), context = context, flex_attention_block_mask = cross_attention_block_mask, flex_attention_score_mod = cross_attention_score_mod, flash_attn_sliding_window = cross_attention_flash_sliding_window)) + + if self.conformer is not None: + x = x + self.conformer_scale(self.conformer(x)) + + if sync_cond is not None and hasattr(self, 'sync_film_generator'): + prepend_part = x[:, :prepend_length, :] + audio_part = x[:, prepend_length:, :] + scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1) + modulated_audio_part = audio_part * (1 + scale) + shift + x = torch.cat([prepend_part, modulated_audio_part], dim=1) + + x = x + self.ff_scale(self.ff(self.ff_norm(x))) + return x + +class ContinuousTransformer(nn.Module): + def __init__( + self, + dim, + depth, + *, + dim_in = None, + dim_out = None, + dim_heads = 64, + cross_attend=False, + cond_token_dim=None, + pre_cross_attn_ix=-1, + final_cross_attn_ix=-1, + global_cond_dim=None, + causal=False, + rotary_pos_emb=True, + zero_init_branch_outputs=True, + conformer=False, + use_sinusoidal_emb=False, + use_abs_pos_emb=False, + abs_pos_emb_max_length=10000, + num_memory_tokens=0, + sliding_window=None, + use_mlp=False, + use_add_norm=False, + use_gated=False, + use_final_layer=False, + use_zeros=False, + use_conv=False, + use_fusion_mlp=False, + use_film=False, + use_sync_film=False, + use_sync_gated=False, + **kwargs + ): + + super().__init__() + + self.dim = dim + self.depth = depth + self.causal = causal + self.layers = nn.ModuleList([]) + if use_mlp: + self.project_in = nn.Sequential( + nn.Linear(dim_in, dim, bias=False), + nn.SiLU(), + nn.Linear(dim, dim, bias=False) + ) + else: + self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity() + self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity() + self.video_temporal_conv = None + self.audio_temporal_conv = None + self.fusion_mlp = None + if use_conv: + self.video_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1) + self.audio_temporal_conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1) + if use_fusion_mlp: + self.fusion_mlp = nn.Sequential( + nn.Linear(dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + + if rotary_pos_emb: + self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) + else: + self.rotary_pos_emb = None + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + self.use_sinusoidal_emb = use_sinusoidal_emb + if use_sinusoidal_emb: + self.pos_emb = ScaledSinusoidalEmbedding(dim) + + self.use_abs_pos_emb = use_abs_pos_emb + if use_abs_pos_emb: + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + self.num_memory_tokens) + + self.adaLN_modulation = None + if global_cond_dim is not None: + if use_final_layer: + self.norm_final = LayerNorm(dim) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + dim, 2 * dim, bias=True + ), + ) + + if use_zeros: + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.project_out.weight, 0) + self.global_cond_embedder = nn.Sequential( + nn.Linear(global_cond_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim * 6) + ) + if use_zeros: + nn.init.constant_(self.global_cond_embedder[-1].weight, 0) + nn.init.constant_(self.global_cond_embedder[-1].bias, 0) + nn.init.constant_(self.global_cond_embedder[0].weight, 0) + nn.init.constant_(self.global_cond_embedder[0].bias, 0) + + self.final_cross_attn_ix = final_cross_attn_ix + self.use_gated = use_gated + self.use_film = use_film + self.use_add_norm = use_add_norm + if self.use_add_norm: + self.add_norm = nn.LayerNorm(dim) + if use_gated: + self.gate = nn.Parameter(torch.ones(1, 1, dim)) + + if use_film: + self.film_generator = nn.Sequential( + nn.Linear(dim, dim, bias=False), + nn.SiLU(), + nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift + ) + else: + self.film_generator = None + + if use_sync_film: + self.sync_film_generator = nn.Sequential( + nn.Linear(dim, dim, bias=False), + nn.SiLU(), + nn.Linear(dim, dim * 2, bias=False) # 从sync_cond生成scale和shift + ) + else: + self.sync_film_generator = None + if use_sync_gated: + self.sync_gate = nn.Parameter(torch.zeros(1, 1, dim)) + else: + self.sync_gate = None + + self.sliding_window = sliding_window + + for i in range(depth): + should_cross_attend = cross_attend and (self.final_cross_attn_ix == -1 or i < (self.final_cross_attn_ix)) and (pre_cross_attn_ix == -1 or i >= (pre_cross_attn_ix)) + # print(f"Layer {i} cross attends: {should_cross_attend}") + self.layers.append( + TransformerBlock( + dim, + dim_heads = dim_heads, + cross_attend = should_cross_attend, + dim_context = cond_token_dim, + global_cond_dim = global_cond_dim, + causal = causal, + zero_init_branch_outputs = zero_init_branch_outputs, + conformer=conformer, + layer_ix=i, + **kwargs + ) + ) + + def forward( + self, + x, + mask = None, + prepend_embeds = None, + prepend_mask = None, + add_cond = None, + sync_cond = None, + global_cond = None, + return_info = False, + use_checkpointing = True, + exit_layer_ix = None, + video_dropout_prob = 0.0, + **kwargs + ): + batch, seq, device = *x.shape[:2], x.device + model_dtype = next(self.parameters()).dtype + x = x.to(model_dtype) + + info = { + "hidden_states": [], + } + + x = self.project_in(x) + if add_cond is not None: + if self.use_gated: + gate = torch.sigmoid(self.gate) + x = x + gate * add_cond + elif self.use_film: + scale, shift = self.film_generator(add_cond).chunk(2, dim=-1) + x = x * (1 + scale) + shift + else: + x = x + add_cond + + if self.use_add_norm: + x = self.add_norm(x) + if self.fusion_mlp is not None: + x = self.fusion_mlp(x) + + if sync_cond is not None: + if self.sync_film_generator is not None: + scale, shift = self.sync_film_generator(sync_cond).chunk(2, dim=-1) + x = x * (1 + scale) + shift + elif self.sync_gate is not None: + gate_value = torch.sigmoid(self.sync_gate) + x = x + gate_value * sync_cond + # else: + # x = x + sync_cond + + if prepend_embeds is not None: + prepend_length, prepend_dim = prepend_embeds.shape[1:] + + assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' + + x = torch.cat((prepend_embeds, x), dim = -2) + + if self.num_memory_tokens > 0: + memory_tokens = self.memory_tokens.expand(batch, -1, -1) + x = torch.cat((memory_tokens, x), dim=1) + + if self.rotary_pos_emb is not None: + rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + else: + rotary_pos_emb = None + + if self.use_sinusoidal_emb or self.use_abs_pos_emb: + x = x + self.pos_emb(x) + + if global_cond is not None and self.global_cond_embedder is not None: + global_cond_embed = self.global_cond_embedder(global_cond) + else: + global_cond_embed = global_cond + # Iterate over the transformer layers + for layer_ix, layer in enumerate(self.layers): + if use_checkpointing: + x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs) + else: + x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond_embed, self_attention_flash_sliding_window = self.sliding_window, sync_cond=sync_cond, prepend_length=prepend_length, **kwargs) + + if return_info: + info["hidden_states"].append(x) + + if exit_layer_ix is not None and layer_ix == exit_layer_ix: + x = x[:, self.num_memory_tokens:, :] + + if return_info: + return x, info + + return x + + x = x[:, self.num_memory_tokens:, :] + if global_cond is not None and self.adaLN_modulation is not None: + if len(global_cond.shape) == 2: + global_cond = global_cond.unsqueeze(1) + shift, scale = self.adaLN_modulation(global_cond).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.project_out(x) + + if return_info: + return x, info + + return x diff --git a/ThinkSound/models/utils.py b/ThinkSound/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a90f1e273dd239b1dc43b80f838782bde7411c3b --- /dev/null +++ b/ThinkSound/models/utils.py @@ -0,0 +1,177 @@ +import torch +from safetensors.torch import load_file +from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor +#from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline +from torch.nn.utils import remove_weight_norm + +def load_ckpt_state_dict(ckpt_path, prefix=None): + if ckpt_path.endswith(".safetensors"): + state_dict = load_file(ckpt_path) + else: + state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + + # 过滤特定前缀的state_dict + filtered_state_dict = {k.replace(f'{prefix}',''): v for k, v in state_dict.items() if k.startswith(prefix)} if prefix is not None else state_dict + + return filtered_state_dict + +def remove_weight_norm_from_model(model): + for module in model.modules(): + if hasattr(module, "weight"): + print(f"Removing weight norm from {module}") + remove_weight_norm(module) + + return model + +# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license +# License can be found in LICENSES/LICENSE_META.txt + +def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): + """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. + + Args: + input (torch.Tensor): The input tensor containing probabilities. + num_samples (int): Number of samples to draw. + replacement (bool): Whether to draw with replacement or not. + Keywords args: + generator (torch.Generator): A pseudorandom number generator for sampling. + Returns: + torch.Tensor: Last dimension contains num_samples indices + sampled from the multinomial probability distribution + located in the last dimension of tensor input. + """ + + if num_samples == 1: + q = torch.empty_like(input).exponential_(1, generator=generator) + return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64) + + input_ = input.reshape(-1, input.shape[-1]) + output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) + output = output_.reshape(*list(input.shape[:-1]), -1) + return output + + +def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: + """Sample next token from top K values along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + k (int): The k in “top-k”. + Returns: + torch.Tensor: Sampled tokens. + """ + top_k_value, _ = torch.topk(probs, k, dim=-1) + min_value_top_k = top_k_value[..., [-1]] + probs *= (probs >= min_value_top_k).float() + probs.div_(probs.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs, num_samples=1) + return next_token + + +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + """Sample next token from top P probabilities along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + p (int): The p in “top-p”. + Returns: + torch.Tensor: Sampled tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort *= (~mask).float() + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + +def next_power_of_two(n): + return 2 ** (n - 1).bit_length() + +def next_multiple_of_64(n): + return ((n + 63) // 64) * 64 + + +# mask construction helpers + +def mask_from_start_end_indices( + seq_len: int, + start: Tensor, + end: Tensor +): + assert start.shape == end.shape + device = start.device + + seq = torch.arange(seq_len, device = device, dtype = torch.long) + seq = seq.reshape(*((-1,) * start.ndim), seq_len) + seq = seq.expand(*start.shape, seq_len) + + mask = seq >= start[..., None].long() + mask &= seq < end[..., None].long() + return mask + +def mask_from_frac_lengths( + seq_len: int, + frac_lengths: Tensor +): + device = frac_lengths.device + + lengths = (frac_lengths * seq_len).long() + max_start = seq_len - lengths + + rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1) + start = (max_start * rand).clamp(min = 0) + end = start + lengths + + return mask_from_start_end_indices(seq_len, start, end) + +def _build_spline(video_feat, video_t, target_t): + # 三次样条插值核心实现 + coeffs = natural_cubic_spline_coeffs(video_t, video_feat.permute(0,2,1)) + spline = NaturalCubicSpline(coeffs) + return spline.evaluate(target_t).permute(0,2,1) + +def resample(video_feat, audio_latent): + """ + 9s + video_feat: [B, 72, D] + audio_latent: [B, D', 194] or int + """ + B, Tv, D = video_feat.shape + + if isinstance(audio_latent, torch.Tensor): + # audio_latent is a tensor + if audio_latent.shape[1] != 64: + Ta = audio_latent.shape[1] + else: + Ta = audio_latent.shape[2] + elif isinstance(audio_latent, int): + # audio_latent is an int + Ta = audio_latent + else: + raise TypeError("audio_latent must be either a tensor or an int") + + # 构建时间戳 (关键改进点) + video_time = torch.linspace(0, 9, Tv, device=video_feat.device) + audio_time = torch.linspace(0, 9, Ta, device=video_feat.device) + + # 三维化处理 (Batch, Feature, Time) + video_feat = video_feat.permute(0, 2, 1) # [B, D, Tv] + + # 三次样条插值 + aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta] + return aligned_video.permute(0, 2, 1) # [B, Ta, D] + +import os +enable_torch_compile = os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1" + +def compile(function, *args, **kwargs): + + if enable_torch_compile: + try: + return torch.compile(function, *args, **kwargs) + except RuntimeError: + return function + + return function \ No newline at end of file diff --git a/ThinkSound/models/wavelets.py b/ThinkSound/models/wavelets.py new file mode 100644 index 0000000000000000000000000000000000000000..a359e39110c168aab960d3f79262b464a660e55e --- /dev/null +++ b/ThinkSound/models/wavelets.py @@ -0,0 +1,82 @@ +"""The 1D discrete wavelet transform for PyTorch.""" + +from einops import rearrange +import pywt +import torch +from torch import nn +from torch.nn import functional as F +from typing import Literal + + +def get_filter_bank(wavelet): + filt = torch.tensor(pywt.Wavelet(wavelet).filter_bank) + if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0): + filt = filt[:, 1:] + return filt + +class WaveletEncode1d(nn.Module): + def __init__(self, + channels, + levels, + wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): + super().__init__() + self.wavelet = wavelet + self.channels = channels + self.levels = levels + filt = get_filter_bank(wavelet) + assert filt.shape[-1] % 2 == 1 + kernel = filt[:2, None] + kernel = torch.flip(kernel, dims=(-1,)) + index_i = torch.repeat_interleave(torch.arange(2), channels) + index_j = torch.tile(torch.arange(channels), (2,)) + kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) + kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] + self.register_buffer("kernel", kernel_final) + + def forward(self, x): + for i in range(self.levels): + low, rest = x[:, : self.channels], x[:, self.channels :] + pad = self.kernel.shape[-1] // 2 + low = F.pad(low, (pad, pad), "reflect") + low = F.conv1d(low, self.kernel, stride=2) + rest = rearrange( + rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels + ) + x = torch.cat([low, rest], dim=1) + return x + + +class WaveletDecode1d(nn.Module): + def __init__(self, + channels, + levels, + wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): + super().__init__() + self.wavelet = wavelet + self.channels = channels + self.levels = levels + filt = get_filter_bank(wavelet) + assert filt.shape[-1] % 2 == 1 + kernel = filt[2:, None] + index_i = torch.repeat_interleave(torch.arange(2), channels) + index_j = torch.tile(torch.arange(channels), (2,)) + kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) + kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] + self.register_buffer("kernel", kernel_final) + + def forward(self, x): + for i in range(self.levels): + low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :] + pad = self.kernel.shape[-1] // 2 + 2 + low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2) + low = F.pad(low, (pad, pad), "reflect") + low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2) + low = F.conv_transpose1d( + low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2 + ) + low = low[..., pad - 1 : -pad] + rest = rearrange( + rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels + ) + x = torch.cat([low, rest], dim=1) + return x \ No newline at end of file diff --git a/ThinkSound/training/__init__.py b/ThinkSound/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f77486b07a478bc88359bf2ece8b9c860df1b054 --- /dev/null +++ b/ThinkSound/training/__init__.py @@ -0,0 +1 @@ +from .factory import create_training_wrapper_from_config, create_demo_callback_from_config diff --git a/ThinkSound/training/autoencoders.py b/ThinkSound/training/autoencoders.py new file mode 100644 index 0000000000000000000000000000000000000000..4cd10dbdf6feadfa83e5563121af7986fb7f56aa --- /dev/null +++ b/ThinkSound/training/autoencoders.py @@ -0,0 +1,502 @@ +import torch +import torchaudio +import wandb +from einops import rearrange +from safetensors.torch import save_file, save_model +from ema_pytorch import EMA +from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss, SpatialSTFTLoss +# import pytorch_lightning as pl +import lightning as L +from lightning.pytorch.callbacks import Callback +from ..models.autoencoders import AudioAutoencoder +from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss +from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck +from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss +from .utils import create_optimizer_from_config, create_scheduler_from_config, log_audio, log_image, log_metric, log_point_cloud + + +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image + +class AutoencoderTrainingWrapper(L.LightningModule): + def __init__( + self, + autoencoder: AudioAutoencoder, + lr: float = 1e-4, + warmup_steps: int = 0, + encoder_freeze_on_warmup: bool = False, + sample_rate=48000, + loss_config: dict = None, + optimizer_configs: dict = None, + use_ema: bool = True, + ema_copy = None, + force_input_mono = False, + latent_mask_ratio = 0.0, + teacher_model: AudioAutoencoder = None + ): + super().__init__() + + self.automatic_optimization = False + + self.autoencoder = autoencoder + + self.warmed_up = False + self.warmup_steps = warmup_steps + self.encoder_freeze_on_warmup = encoder_freeze_on_warmup + self.lr = lr + + self.force_input_mono = force_input_mono + + self.teacher_model = teacher_model + + if optimizer_configs is None: + optimizer_configs ={ + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (.8, .99) + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (.8, .99) + } + } + } + + } + + self.optimizer_configs = optimizer_configs + + if loss_config is None: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + loss_config = { + "discriminator": { + "type": "encodec", + "config": { + "n_ffts": scales, + "hop_lengths": hop_sizes, + "win_lengths": win_lengths, + "filters": 32 + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0, + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + }, + "weights": { + "mrstft": 1.0, + } + }, + "time": { + "type": "l1", + "config": {}, + "weights": { + "l1": 0.0, + } + } + } + + self.loss_config = loss_config + + # Spectral reconstruction loss + + stft_loss_args = loss_config['spectral']['config'] + + if self.autoencoder.out_channels == 2: + self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + self.lrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + elif self.autoencoder.out_channels == 4: + # self.sdstft = SpatialSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + else: + self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + # Discriminator + + if loss_config['discriminator']['type'] == 'oobleck': + self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'encodec': + self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'dac': + self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config']) + + self.gen_loss_modules = [] + + # Adversarial and feature matching losses + self.gen_loss_modules += [ + ValueLoss(key='loss_adv', weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'), + ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching'), + ] + + if self.teacher_model is not None: + # Distillation losses + + stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25 + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight), # Reconstruction loss + AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight), # Distilled model's decoder is compatible with teacher's decoder + AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight), # Distilled model's encoder is compatible with teacher's decoder + AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight) # Teacher's encoder is compatible with distilled model's decoder + ] + + else: + + # Reconstruction loss + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), + ] + + if self.autoencoder.out_channels == 2: + + # Add left and right channel reconstruction losses in addition to the sum and difference + self.gen_loss_modules += [ + AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2), + AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2), + ] + elif self.autoencoder.out_channels == 4: + # self.gen_loss_modules += [ + # AuralossLoss(self.lrstft, 'reals', 'decoded', name='stft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), + # ] + # Add left and right channel reconstruction losses in addition to the sum and difference + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals_w', 'decoded_w', name='stft_loss_w', weight=self.loss_config['spectral']['weights']['mrstft']/4), + AuralossLoss(self.sdstft, 'reals_x', 'decoded_x', name='stft_loss_x', weight=self.loss_config['spectral']['weights']['mrstft']/4), + AuralossLoss(self.sdstft, 'reals_y', 'decoded_y', name='stft_loss_y', weight=self.loss_config['spectral']['weights']['mrstft']/4), + AuralossLoss(self.sdstft, 'reals_z', 'decoded_z', name='stft_loss_z', weight=self.loss_config['spectral']['weights']['mrstft']/4), + ] + + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), + ] + + if self.loss_config['time']['weights']['l1'] > 0.0: + self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=self.loss_config['time']['weights']['l1'], name='l1_time_loss')) + + if self.autoencoder.bottleneck is not None: + self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config) + + self.losses_gen = MultiLoss(self.gen_loss_modules) + + self.disc_loss_modules = [ + ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'), + ] + + self.losses_disc = MultiLoss(self.disc_loss_modules) + + # Set up EMA for model weights + self.autoencoder_ema = None + + self.use_ema = use_ema + + if self.use_ema: + self.autoencoder_ema = EMA( + self.autoencoder, + ema_model=ema_copy, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1 + ) + + self.latent_mask_ratio = latent_mask_ratio + + def configure_optimizers(self): + + opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters()) + opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters()) + + if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']: + sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen) + sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc) + return [opt_gen, opt_disc], [sched_gen, sched_disc] + + return [opt_gen, opt_disc] + + def training_step(self, batch, batch_idx): + reals, _ = batch + + # Remove extra dimension added by WebDataset + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + if self.global_step >= self.warmup_steps: + self.warmed_up = True + + loss_info = {} + + loss_info["reals"] = reals + + encoder_input = reals + + if self.force_input_mono and encoder_input.shape[1] > 1: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + loss_info["encoder_input"] = encoder_input + + data_std = encoder_input.std() + + if self.warmed_up and self.encoder_freeze_on_warmup: + with torch.no_grad(): + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + else: + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + + loss_info["latents"] = latents + + loss_info.update(encoder_info) + + # Encode with teacher model for distillation + if self.teacher_model is not None: + with torch.no_grad(): + teacher_latents = self.teacher_model.encode(encoder_input, return_info=False) + loss_info['teacher_latents'] = teacher_latents + + # Optionally mask out some latents for noise resistance + if self.latent_mask_ratio > 0.0: + mask = torch.rand_like(latents) < self.latent_mask_ratio + latents = torch.where(mask, torch.zeros_like(latents), latents) + decoded = self.autoencoder.decode(latents) + + loss_info["decoded"] = decoded + + if self.autoencoder.out_channels == 2: + loss_info["decoded_left"] = decoded[:, 0:1, :] + loss_info["decoded_right"] = decoded[:, 1:2, :] + loss_info["reals_left"] = reals[:, 0:1, :] + loss_info["reals_right"] = reals[:, 1:2, :] + elif self.autoencoder.out_channels == 4: + loss_info["decoded_w"] = decoded[:, 0:1, :] + loss_info["decoded_x"] = decoded[:, 1:2, :] + loss_info["decoded_y"] = decoded[:, 2:3, :] + loss_info["decoded_z"] = decoded[:, 3:4, :] + loss_info["reals_w"] = reals[:, 0:1, :] + loss_info["reals_x"] = reals[:, 1:2, :] + loss_info["reals_y"] = reals[:, 2:3, :] + loss_info["reals_z"] = reals[:, 3:4, :] + + # Distillation + if self.teacher_model is not None: + with torch.no_grad(): + teacher_decoded = self.teacher_model.decode(teacher_latents) + own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher + teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model + + loss_info['teacher_decoded'] = teacher_decoded + loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded + loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded + + + if self.warmed_up: + loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded) + else: + loss_dis = torch.tensor(0.).to(reals) + loss_adv = torch.tensor(0.).to(reals) + feature_matching_distance = torch.tensor(0.).to(reals) + + loss_info["loss_dis"] = loss_dis + loss_info["loss_adv"] = loss_adv + loss_info["feature_matching_distance"] = feature_matching_distance + + opt_gen, opt_disc = self.optimizers() + + lr_schedulers = self.lr_schedulers() + + sched_gen = None + sched_disc = None + + if lr_schedulers is not None: + sched_gen, sched_disc = lr_schedulers + + # Train the discriminator + if self.global_step % 2 and self.warmed_up: + loss, losses = self.losses_disc(loss_info) + + log_dict = { + 'train/disc_lr': opt_disc.param_groups[0]['lr'] + } + + opt_disc.zero_grad() + self.manual_backward(loss) + + + opt_disc.step() + + if sched_disc is not None: + # sched step every step + sched_disc.step() + + # Train the generator + else: + + # import ipdb + # ipdb.set_trace() + loss, losses = self.losses_gen(loss_info) + + if self.use_ema: + self.autoencoder_ema.update() + + opt_gen.zero_grad() + self.manual_backward(loss) + opt_gen.step() + + if sched_gen is not None: + # scheduler step every step + sched_gen.step() + + log_dict = { + 'train/loss': loss.detach(), + 'train/latent_std': latents.std().detach(), + 'train/data_std': data_std.detach(), + 'train/gen_lr': opt_gen.param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f'train/{loss_name}'] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + + return loss + + def export_model(self, path, use_safetensors=False): + if self.autoencoder_ema is not None: + model = self.autoencoder_ema.ema_model + else: + model = self.autoencoder + + if use_safetensors: + save_model(model, path) + else: + torch.save({"state_dict": model.state_dict()}, path) + + +class AutoencoderDemoCallback(Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + sample_size=65536, + sample_rate=48000 + ): + super().__init__() + self.demo_every = demo_every + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + module.eval() + + try: + demo_reals, _ = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + encoder_input = demo_reals + + encoder_input = encoder_input.to(module.device) + + if module.force_input_mono: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + demo_reals = demo_reals.to(module.device) + + with torch.no_grad(): + if module.use_ema: + + latents = module.autoencoder_ema.ema_model.encode(encoder_input) + + fakes = module.autoencoder_ema.ema_model.decode(latents) + else: + latents = module.autoencoder.encode(encoder_input) + + fakes = module.autoencoder.decode(latents) + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demos/recon_{trainer.global_step:08}.wav' + reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_audio(trainer.logger, 'recon', filename, self.sample_rate) + + log_point_cloud(trainer.logger, 'embeddings_3dpca', latents) + log_image(trainer.logger, 'embeddings_spec', tokens_spectrogram_image(latents)) + log_image(trainer.logger, 'recon_melspec_left', audio_spectrogram_image(reals_fakes)) + + # trainer.logger.experiment.log(log_dict) + except Exception as e: + print(f'{type(e).__name__}: {e}') + raise e + finally: + module.train() + +def create_loss_modules_from_bottleneck(bottleneck, loss_config): + losses = [] + + if isinstance(bottleneck, VAEBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): + try: + kl_weight = loss_config['bottleneck']['weights']['kl'] + except: + kl_weight = 1e-6 + + kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss') + losses.append(kl_loss) + + if isinstance(bottleneck, RVQBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): + quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss') + losses.append(quantizer_loss) + + if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck): + codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss') + commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss') + losses.append(codebook_loss) + losses.append(commitment_loss) + + if isinstance(bottleneck, WassersteinBottleneck): + try: + mmd_weight = loss_config['bottleneck']['weights']['mmd'] + except: + mmd_weight = 100 + + mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss') + losses.append(mmd_loss) + + return losses \ No newline at end of file diff --git a/ThinkSound/training/autoencoders_1.py b/ThinkSound/training/autoencoders_1.py new file mode 100644 index 0000000000000000000000000000000000000000..8f8381011ac16b6d862a0d432e540b1f66dc4757 --- /dev/null +++ b/ThinkSound/training/autoencoders_1.py @@ -0,0 +1,671 @@ +import os +import torch +import torchaudio +import wandb +import pytorch_lightning as pl + +from copy import deepcopy +from typing import Optional, Literal + +from ..models.autoencoders import AudioAutoencoder, fold_channels_into_batch, unfold_channels_from_batch +from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss, BigVGANDiscriminator +from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck +from .losses import MelSpectrogramLoss, MultiLoss, AuralossLoss, ValueLoss, TargetValueLoss, L1Loss, LossWithTarget, MSELoss, HubertLoss, PESQMetric +from .losses import auraloss as auraloss +from .utils import create_optimizer_from_config, create_scheduler_from_config, log_audio, log_image, log_metric, log_point_cloud, logger_project_name + +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from ..interface.aeiou import audio_spectrogram_image, tokens_spectrogram_image +from ema_pytorch import EMA +from einops import rearrange +from safetensors.torch import save_model + +def trim_to_shortest(a, b): + """Trim the longer of two tensors to the length of the shorter one.""" + if a.shape[-1] > b.shape[-1]: + return a[:,:,:b.shape[-1]], b + elif b.shape[-1] > a.shape[-1]: + return a, b[:,:,:a.shape[-1]] + return a, b + +class AutoencoderTrainingWrapper(pl.LightningModule): + def __init__( + self, + autoencoder: AudioAutoencoder, + sample_rate=48000, + loss_config: Optional[dict] = None, + eval_loss_config: Optional[dict] = None, + optimizer_configs: Optional[dict] = None, + lr: float = 1e-4, + warmup_steps: int = 0, + warmup_mode: Literal["adv", "full"] = "adv", + encoder_freeze_on_warmup: bool = False, + use_ema: bool = True, + ema_copy = None, + force_input_mono = False, + latent_mask_ratio = 0.0, + teacher_model: Optional[AudioAutoencoder] = None, + clip_grad_norm = 0.0 + ): + super().__init__() + + self.automatic_optimization = False + self.autoencoder = autoencoder + + self.warmed_up = False + self.warmup_steps = warmup_steps + self.warmup_mode = warmup_mode + self.encoder_freeze_on_warmup = encoder_freeze_on_warmup + self.lr = lr + self.clip_grad_norm = clip_grad_norm + + self.force_input_mono = force_input_mono + + self.teacher_model = teacher_model + + if optimizer_configs is None: + optimizer_configs ={ + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (.8, .99) + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (.8, .99) + } + } + } + + } + + self.optimizer_configs = optimizer_configs + + if loss_config is None: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + loss_config = { + "discriminator": { + "type": "encodec", + "config": { + "n_ffts": scales, + "hop_lengths": hop_sizes, + "win_lengths": win_lengths, + "filters": 32 + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0, + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + }, + "weights": { + "mrstft": 1.0, + } + }, + "time": { + "type": "l1", + "config": {}, + "weights": { + "l1": 0.0, + } + } + } + + self.loss_config = loss_config + + # Spectral reconstruction loss + stft_loss_args = loss_config['spectral']['config'] + + self.use_disc = 'discriminator' in loss_config + + if self.autoencoder.out_channels == 2: + self.sdstft = auraloss.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + self.lrstft = auraloss.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + else: + self.sdstft = auraloss.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + # Discriminator + if self.use_disc: + if loss_config['discriminator']['type'] == 'oobleck': + self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'encodec': + self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'dac': + self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'big_vgan': + self.discriminator = BigVGANDiscriminator(channels=self.autoencoder.out_channels, sample_rate=sample_rate,**loss_config['discriminator']['config']) + + else: + self.discriminator = None + + self.gen_loss_modules = [] + + # Adversarial and feature matching losses + if self.use_disc: + self.gen_loss_modules += [ + ValueLoss(key='loss_adv', weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'), + ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching_loss'), + ] + stft_loss_decay = self.loss_config['spectral'].get('decay', 1.0) + if self.teacher_model is not None: + # Distillation losses + stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25 + self.gen_loss_modules += [ + MSELoss(key_a='teacher_latents', key_b='latents', weight=stft_loss_weight, name='latent_distill_loss', decay = stft_loss_decay), # Latent space distillation + AuralossLoss(self.sdstft, target_key = 'reals', input_key = 'decoded', name='mrstft_loss', weight=stft_loss_weight, decay = stft_loss_decay), # Reconstruction loss + AuralossLoss(self.sdstft, input_key = 'decoded', target_key = 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight, decay = stft_loss_decay), # Distilled model's decoder is compatible with teacher's decoder + AuralossLoss(self.sdstft, target_key = 'reals', input_key = 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight, decay = stft_loss_decay), # Distilled model's encoder is compatible with teacher's decoder + AuralossLoss(self.sdstft, target_key = 'reals', input_key = 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight, decay = stft_loss_decay) # Teacher's encoder is compatible with distilled model's decoder + ] + + else: + + # Reconstruction loss + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, target_key = 'reals', input_key = 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft'], decay = stft_loss_decay), + ] + + if self.autoencoder.out_channels == 2: + # Add left and right channel reconstruction losses in addition to the sum and difference + self.gen_loss_modules += [ + AuralossLoss(self.lrstft, target_key = 'reals_left', input_key = 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2, decay = stft_loss_decay), + AuralossLoss(self.lrstft, target_key = 'reals_right', input_key = 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2, decay = stft_loss_decay), + ] + + if "mrmel" in loss_config: + mrmel_weight = loss_config["mrmel"]["weights"]["mrmel"] + if mrmel_weight > 0: + mrmel_config = loss_config["mrmel"]["config"] + self.mrmel = MelSpectrogramLoss(sample_rate, + n_mels=mrmel_config["n_mels"], + window_lengths=mrmel_config["window_lengths"], + pow=mrmel_config["pow"], + log_weight=mrmel_config["log_weight"], + mag_weight=mrmel_config["mag_weight"], + ) + self.gen_loss_modules.append(LossWithTarget( + self.mrmel, "reals", "decoded", + name="mrmel_loss", weight=mrmel_weight, + )) + + if "hubert" in loss_config: + hubert_weight = loss_config["hubert"]["weights"]["hubert"] + if hubert_weight > 0: + hubert_cfg = ( + loss_config["hubert"]["config"] + if "config" in loss_config["hubert"] else dict()) + self.hubert = HubertLoss(weight=1.0, **hubert_cfg) + + self.gen_loss_modules.append(LossWithTarget( + self.hubert, target_key = "reals", input_key = "decoded", + name="hubert_loss", weight=hubert_weight, + decay = loss_config["hubert"].get("decay", 1.0) + )) + + if "l1" in loss_config["time"]["weights"]: + if self.loss_config['time']['weights']['l1'] > 0.0: + self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', + weight=self.loss_config['time']['weights']['l1'], + name='l1_time_loss', + decay = self.loss_config['time'].get('decay', 1.0))) + + if "l2" in loss_config["time"]["weights"]: + if self.loss_config['time']['weights']['l2'] > 0.0: + self.gen_loss_modules.append(MSELoss(key_a='reals', key_b='decoded', + weight=self.loss_config['time']['weights']['l2'], + name='l2_time_loss', + decay = self.loss_config['time'].get('decay', 1.0))) + + if self.autoencoder.bottleneck is not None: + self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config) + + self.losses_gen = MultiLoss(self.gen_loss_modules) + + if self.use_disc: + self.disc_loss_modules = [ + ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'), + ] + + self.losses_disc = MultiLoss(self.disc_loss_modules) + + # Set up EMA for model weights + self.autoencoder_ema = None + + self.use_ema = use_ema + if self.use_ema: + self.autoencoder_ema = EMA( + self.autoencoder, + ema_model=ema_copy, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1 + ) + + self.latent_mask_ratio = latent_mask_ratio + + # evaluation losses & metrics + self.eval_losses = torch.nn.ModuleDict() + if eval_loss_config is not None: + if "pesq" in eval_loss_config: + self.eval_losses["pesq"] = PESQMetric(sample_rate) + if "stft"in eval_loss_config: + self.eval_losses["stft"] = auraloss.STFTLoss(**eval_loss_config["stft"]) + if "sisdr" in eval_loss_config: + self.eval_losses["sisdr"] = auraloss.SISDRLoss(**eval_loss_config["sisdr"]) + if "mel" in eval_loss_config: + self.eval_losses["mel"] = auraloss.MelSTFTLoss( + sample_rate, **eval_loss_config["mel"]) + + self.validation_step_outputs = [] + + + def configure_optimizers(self): + gen_params = list(self.autoencoder.parameters()) + + if self.use_disc: + opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], gen_params) + opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters()) + if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']: + sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen) + sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc) + return [opt_gen, opt_disc], [sched_gen, sched_disc] + return [opt_gen, opt_disc] + else: + opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], gen_params) + if "scheduler" in self.optimizer_configs['autoencoder']: + sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen) + return [opt_gen], [sched_gen] + return [opt_gen] + + def forward(self, reals): + latents, encoder_info = self.autoencoder.encode(reals, return_info=True) + decoded = self.autoencoder.decode(latents) + return decoded + + def validation_step(self, batch, batch_idx): + reals, _ = batch + # Remove extra dimension added by WebDataset + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + if len(reals.shape) == 2: + reals = reals.unsqueeze(1) + + loss_info = {} + + loss_info["reals"] = reals + + encoder_input = reals + + if self.force_input_mono and encoder_input.shape[1] > 1: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + loss_info["encoder_input"] = encoder_input + + data_std = encoder_input.std() + + with torch.no_grad(): + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + loss_info["latents"] = latents + loss_info.update(encoder_info) + + decoded = self.autoencoder.decode(latents) + #Trim output to remove post-padding. + decoded, reals = trim_to_shortest(decoded, reals) + + # Run evaluation metrics. + val_loss_dict = {} + for eval_key, eval_fn in self.eval_losses.items(): + loss_value = eval_fn(decoded, reals) + if eval_key == "sisdr": loss_value = -loss_value + if isinstance(loss_value, torch.Tensor): + loss_value = loss_value.item() + + val_loss_dict[eval_key] = loss_value + + self.validation_step_outputs.append(val_loss_dict) + return val_loss_dict + + def on_validation_epoch_end(self): + sum_loss_dict = {} + for loss_dict in self.validation_step_outputs: + for key, value in loss_dict.items(): + if key not in sum_loss_dict: + sum_loss_dict[key] = value + else: + sum_loss_dict[key] += value + + for key, value in sum_loss_dict.items(): + val_loss = value / len(self.validation_step_outputs) + val_loss = self.all_gather(val_loss).mean().item() + log_metric(self.logger, f"val/{key}", val_loss) + + self.validation_step_outputs.clear() # free memory + + def training_step(self, batch, batch_idx): + reals, _ = batch + + log_dict = {} + # Remove extra dimension added by WebDataset + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + if len(reals.shape) == 2: + reals = reals.unsqueeze(1) + + if self.global_step >= self.warmup_steps: + self.warmed_up = True + + loss_info = {} + + loss_info["reals"] = reals + + encoder_input = reals + + if self.force_input_mono and encoder_input.shape[1] > 1: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + loss_info["encoder_input"] = encoder_input + + data_std = encoder_input.std() + + if self.warmed_up and self.encoder_freeze_on_warmup: + with torch.no_grad(): + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + else: + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + + loss_info["latents"] = latents + + loss_info.update(encoder_info) + + # Encode with teacher model for distillation + if self.teacher_model is not None: + with torch.no_grad(): + teacher_latents = self.teacher_model.encode(encoder_input, return_info=False) + loss_info['teacher_latents'] = teacher_latents + + # Optionally mask out some latents for noise resistance + if self.latent_mask_ratio > 0.0: + mask = torch.rand_like(latents) < self.latent_mask_ratio + latents = torch.where(mask, torch.zeros_like(latents), latents) + + decoded = self.autoencoder.decode(latents) + + #Trim output to remove post-padding + decoded, reals = trim_to_shortest(decoded, reals) + + loss_info["decoded"] = decoded + loss_info["reals"] = reals + + if self.autoencoder.out_channels == 2: + loss_info["decoded_left"] = decoded[:, 0:1, :] + loss_info["decoded_right"] = decoded[:, 1:2, :] + loss_info["reals_left"] = reals[:, 0:1, :] + loss_info["reals_right"] = reals[:, 1:2, :] + + # Distillation + if self.teacher_model is not None: + with torch.no_grad(): + teacher_decoded = self.teacher_model.decode(teacher_latents) + own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher + teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model + + loss_info['teacher_decoded'] = teacher_decoded + loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded + loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded + + if self.use_disc: + if self.warmed_up: + loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals=reals, fakes=decoded) + else: + loss_adv = torch.tensor(0.).to(reals) + feature_matching_distance = torch.tensor(0.).to(reals) + + if self.warmup_mode == "adv": + loss_dis, _, _ = self.discriminator.loss(reals=reals, fakes=decoded) + else: + loss_dis = torch.tensor(0.0).to(reals) + + loss_info["loss_dis"] = loss_dis + loss_info["loss_adv"] = loss_adv + loss_info["feature_matching_distance"] = feature_matching_distance + + opt_gen = None + opt_disc = None + + if self.use_disc: + opt_gen, opt_disc = self.optimizers() + else: + opt_gen = self.optimizers() + + lr_schedulers = self.lr_schedulers() + + sched_gen = None + sched_disc = None + + if lr_schedulers is not None: + if self.use_disc: + sched_gen, sched_disc = lr_schedulers + else: + sched_gen = lr_schedulers + + # Train the discriminator + use_disc = ( + self.use_disc + and self.global_step % 2 + # Check warmup mode and if it is time to use discriminator. + and ( + (self.warmup_mode == "full" and self.warmed_up) + or self.warmup_mode == "adv") + ) + if use_disc: + loss, losses = self.losses_disc(loss_info) + + log_dict['train/disc_lr'] = opt_disc.param_groups[0]['lr'] + + opt_disc.zero_grad() + self.manual_backward(loss) + if self.clip_grad_norm > 0.0: + torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.clip_grad_norm) + opt_disc.step() + + if sched_disc is not None: + # sched step every step + sched_disc.step() + + # Train the generator + else: + + loss, losses = self.losses_gen(loss_info) + + if self.use_ema: + self.autoencoder_ema.update() + + opt_gen.zero_grad() + self.manual_backward(loss) + if self.clip_grad_norm > 0.0: + torch.nn.utils.clip_grad_norm_(self.autoencoder.parameters(), self.clip_grad_norm) + opt_gen.step() + + if sched_gen is not None: + # scheduler step every step + sched_gen.step() + + log_dict['train/loss'] = loss.detach().item() + log_dict['train/latent_std'] = latents.std().detach().item() + log_dict['train/data_std'] = data_std.detach().item() + log_dict['train/gen_lr'] = opt_gen.param_groups[0]['lr'] + + for loss_name, loss_value in losses.items(): + log_dict[f'train/{loss_name}'] = loss_value.detach().item() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + + return loss + + def export_model(self, path, use_safetensors=False): + if self.autoencoder_ema is not None: + model = self.autoencoder_ema.ema_model + else: + model = self.autoencoder + + if use_safetensors: + save_model(model, path) + else: + torch.save({"state_dict": model.state_dict()}, path) + +class AutoencoderDemoCallback(pl.Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + sample_size=65536, + sample_rate=44100, + max_demos = 8 + ): + super().__init__() + self.demo_every = demo_every + self.demo_samples = sample_size + self.demo_dl = iter(deepcopy(demo_dl)) + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.max_demos = max_demos + + + @rank_zero_only + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + module.eval() + + try: + demo_reals, _ = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + # Limit the number of demo samples + if demo_reals.shape[0] > self.max_demos: + demo_reals = demo_reals[:self.max_demos,...] + + encoder_input = demo_reals + encoder_input = encoder_input.to(module.device) + + if module.force_input_mono: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + demo_reals = demo_reals.to(module.device) + + with torch.no_grad(): + if module.use_ema: + latents = module.autoencoder_ema.ema_model.encode(encoder_input) + fakes = module.autoencoder_ema.ema_model.decode(latents) + else: + latents = module.autoencoder.encode(encoder_input) + fakes = module.autoencoder.decode(latents) + + #Trim output to remove post-padding. + fakes, demo_reals = trim_to_shortest(fakes.detach(), demo_reals) + log_dict = {} + + if module.discriminator is not None: + window = torch.kaiser_window(512).to(fakes.device) + fakes_stft = torch.stft(fold_channels_into_batch(fakes), n_fft=512, hop_length=128, win_length=512, window = window, center=True, return_complex=True) + fakes_stft.requires_grad = True + fakes_signal = unfold_channels_from_batch(torch.istft(fakes_stft, n_fft=512, hop_length=128, win_length=512, window = window, center=True), fakes.shape[1]) + real_stft = torch.stft(fold_channels_into_batch(demo_reals), n_fft=512, hop_length=128, win_length=512, window = window, center=True, return_complex=True) + reals_signal = unfold_channels_from_batch(torch.istft(real_stft, n_fft=512, hop_length=128, win_length=512, window = window, center=True), demo_reals.shape[1]) + _, loss, _ = module.discriminator.loss(reals_signal,fakes_signal) + fakes_stft.retain_grad() + loss.backward() + grads = unfold_channels_from_batch(fakes_stft.grad.detach().abs(),fakes.shape[1]) + log_dict[f'disciminator_sensitivity'] = wandb.Image(tokens_spectrogram_image(grads.mean(dim=1).log10(), title = 'Discriminator Sensitivity', symmetric = False)) + opts = module.optimizers() + opts[0].zero_grad() + opts[1].zero_grad() + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + try: + data_dir = os.path.join( + trainer.logger.save_dir, logger_project_name(trainer.logger), + trainer.logger.experiment.id, "media") + os.makedirs(data_dir, exist_ok=True) + filename = os.path.join(data_dir, f'recon_{trainer.global_step:08}.wav') + except: + filename = f'recon_{trainer.global_step:08}.wav' + + reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_audio(trainer.logger, 'recon', filename, self.sample_rate) + log_point_cloud(trainer.logger, 'embeddings_3dpca', latents) + log_image(trainer.logger, 'embeddings_spec', tokens_spectrogram_image(latents)) + log_image(trainer.logger, 'recon_melspec_left', audio_spectrogram_image(reals_fakes)) + except Exception as e: + print(f'{type(e).__name__}: {e}') + raise e + finally: + module.train() + +def create_loss_modules_from_bottleneck(bottleneck, loss_config): + losses = [] + + if isinstance(bottleneck, VAEBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): + try: + kl_weight = loss_config['bottleneck']['weights']['kl'] + except: + kl_weight = 1e-6 + + kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss') + losses.append(kl_loss) + + if isinstance(bottleneck, RVQBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): + quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss') + losses.append(quantizer_loss) + + if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck): + codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss') + commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss') + losses.append(codebook_loss) + losses.append(commitment_loss) + + if isinstance(bottleneck, WassersteinBottleneck): + try: + mmd_weight = loss_config['bottleneck']['weights']['mmd'] + except: + mmd_weight = 100 + + mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss') + losses.append(mmd_loss) + + return losses diff --git a/ThinkSound/training/diffusion.py b/ThinkSound/training/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..c17ad9e1cb5bb526eac6a4f22725f1ade1614fe7 --- /dev/null +++ b/ThinkSound/training/diffusion.py @@ -0,0 +1,2188 @@ +# import pytorch_lightning as pl +import lightning as L +from lightning.pytorch.callbacks import Callback +import sys, gc +import random +import torch +import torchaudio +import typing as tp +import wandb +# from beartype.typing import Tuple +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image +import auraloss +from ema_pytorch import EMA +from einops import rearrange +from safetensors.torch import save_file +from torch import optim +from torch.nn import functional as F +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler, truncated_logistic_normal_rescaled, sample_timesteps_logsnr +from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper +from ..models.autoencoders import DiffusionAutoencoder +from ..models.diffusion_prior import PriorType +from .autoencoders import create_loss_modules_from_bottleneck +from .losses import AuralossLoss, MSELoss, MultiLoss +from .utils import create_optimizer_from_config, create_scheduler_from_config, mask_from_frac_lengths, generate_mask, generate_channel_mask, log_audio, log_image, log_metric, log_point_cloud +import os +from pathlib import Path +from time import time +import numpy as np + +class Profiler: + + def __init__(self): + self.ticks = [[time(), None]] + + def tick(self, msg): + self.ticks.append([time(), msg]) + + def __repr__(self): + rep = 80 * "=" + "\n" + for i in range(1, len(self.ticks)): + msg = self.ticks[i][1] + ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] + rep += msg + f": {ellapsed*1000:.2f}ms\n" + rep += 80 * "=" + "\n\n\n" + return rep + +class DiffusionUncondTrainingWrapper(L.LightningModule): + ''' + Wrapper for training an unconditional audio diffusion model (like Dance Diffusion). + ''' + def __init__( + self, + model: DiffusionModelWrapper, + lr: float = 1e-4, + pre_encoded: bool = False + ): + super().__init__() + + self.diffusion = model + + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1 + ) + + self.lr = lr + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.losses = MultiLoss(loss_modules) + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + return optim.Adam([*self.diffusion.parameters()], lr=self.lr) + + def training_step(self, batch, batch_idx): + reals = batch[0] + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + diffusion_input = reals + + loss_info = {} + + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + if self.diffusion.pretransform is not None: + if not self.pre_encoded: + with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + loss_info["reals"] = diffusion_input + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + targets = noise * alphas - diffusion_input * sigmas + + with torch.amp.autocast('cuda'): + v = self.diffusion(noised_inputs, t) + + loss_info.update({ + "v": v, + "targets": targets + }) + + loss, losses = self.losses(loss_info) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + +class DiffusionUncondDemoCallback(Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + demo_steps=250, + sample_rate=48000 + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_steps = demo_steps + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + demo_samples = module.diffusion.sample_size + + if module.diffusion.pretransform is not None: + demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio + + noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device) + + try: + with torch.amp.autocast('cuda'): + fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_{trainer.global_step:08}.wav' + fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + del fakes + + except Exception as e: + print(f'{type(e).__name__}: {e}') + finally: + gc.collect() + torch.cuda.empty_cache() + +class DiffusionInfillTrainingWrapper(L.LightningModule): + ''' + Wrapper for training an unconditional audio diffusion model (like Dance Diffusion). + ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = 1e-4, + optimizer_configs: dict = None, + pre_encoded: bool = False, + frac_lengths_mask = (0.7, 1.), + min_span_len = 10, + timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", + diffusion_objective = 'rectified_flow', + ctx_drop: float = 0.1, + r_drop: float = 0.0, + ): + super().__init__() + + self.diffusion = model + + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1 + ) + + if optimizer_configs is None: + optimizer_configs = { + "diffusion": { + "optimizer": { + "type": "Adam", + "config": { + "lr": lr + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + self.frac_lengths_mask = frac_lengths_mask + self.min_span_len = min_span_len + self.timestep_sampler = timestep_sampler + self.ctx_drop = ctx_drop + self.r_drop = r_drop + self.diffusion_objective = diffusion_objective + print(f'Training in the {diffusion_objective} formulation') + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss", + mask_key="mask" + ) + ] + + self.losses = MultiLoss(loss_modules) + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + diffusion_opt_config = self.optimizer_configs['diffusion'] + opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) + + if "scheduler" in diffusion_opt_config: + sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) + sched_diff_config = { + "scheduler": sched_diff, + "interval": "step" + } + return [opt_diff], [sched_diff_config] + + return [opt_diff] + + def training_step(self, batch, batch_idx): + reals, metadata = batch + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + # import ipdb + # ipdb.set_trace() + p_drop = torch.rand(1).item() + # r_drop = torch.rand(1).item() + # if p_drop >= self.ctx_drop and self.r_drop > 0.0 and r_drop < self.r_drop: + # generate_channel_mask(reals) + + diffusion_input = reals + assert torch.all(torch.isfinite(diffusion_input)), "Non-finite values detected in diffusion_input" + p = Profiler() + loss_info = {} + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + p.tick("setup") + + conditioning = {} + + p.tick("conditioning") + if self.diffusion.pretransform is not None: + if not self.pre_encoded: + with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + loss_info["reals"] = diffusion_input + if self.timestep_sampler == "uniform": + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + elif self.timestep_sampler == "logit_normal": + t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) + elif self.timestep_sampler == "trunc_logit_normal": + # Draw from logistic truncated normal distribution + t = truncated_logistic_normal_rescaled(reals.shape[0]).to(self.device) + + # Flip the distribution + t = 1 - t + + # # Calculate the noise schedule parameters for those timesteps + # alphas, sigmas = get_alphas_sigmas(t) + # Calculate the noise schedule parameters for those timesteps + if self.diffusion_objective == "v": + alphas, sigmas = get_alphas_sigmas(t) + elif self.diffusion_objective == "rectified_flow": + alphas, sigmas = 1-t, t + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + # x_ctx = diffusion_input.detach().clone().transpose(1,2) + bsz, dim, seq_len = diffusion_input.shape + + + if p_drop < self.ctx_drop: + ctx_mask = torch.ones((bsz, seq_len), device = diffusion_input.device, dtype = torch.bool) + # elif self.r_drop > 0.0 and r_drop < self.r_drop: + # ctx_mask = torch.zeros((bsz, seq_len), device=diffusion_input.device, dtype=torch.bool) + else: + # 计算 frac_lengths 提前使用 + frac_lengths = torch.zeros((bsz,), device=diffusion_input.device).uniform_(*self.frac_lengths_mask) + # if self.r_drop > 0.0 and r_drop < self.r_drop: + # import ipdb + # ipdb.set_trace() + + # ctx_mask = torch.zeros((bsz, seq_len), device=diffusion_input.device, dtype=torch.bool) + # else: + ctx_mask = generate_mask(bsz, seq_len, frac_lengths, self.min_span_len) + + if ctx_mask.dim() == 2: + ctx_mask = ctx_mask.unsqueeze(1) + masked_sequence = diffusion_input * ~ctx_mask + conditioning['x_ctx'] = [masked_sequence] + if self.diffusion_objective == "v": + targets = noise * alphas - diffusion_input * sigmas + elif self.diffusion_objective == "rectified_flow": + targets = noise - diffusion_input + with torch.amp.autocast('cuda'): + p.tick("amp") + v = self.diffusion(noised_inputs, t, cond=conditioning) + p.tick("diffusion") + loss_info.update({ + "v": v, + "targets": targets, + "mask": ctx_mask.squeeze(-1) + }) + # import ipdb + # ipdb.set_trace() + loss, losses = self.losses(loss_info) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + p.tick("log") + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + +class DiffusionInfillDemoCallback(Callback): + def __init__(self, + demo_dl, + demo_every=2000, + num_demos=8, + demo_steps=250, + sample_rate=48000 + ): + super().__init__() + self.demo_dl = iter(demo_dl) + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_steps = demo_steps + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + + try: + demo_reals, _ = next(self.demo_dl) + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + demo_reals = demo_reals.to(module.device) + reals = demo_reals + log_dict = {} + + if not module.pre_encoded: + # Log the real audio + log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu())) + # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals") + + if module.diffusion.pretransform is not None: + module.diffusion.pretransform.to(module.device) + with torch.amp.autocast('cuda'): + demo_reals = module.diffusion.pretransform.encode(demo_reals) + + demo_samples = demo_reals.shape[2] + + # Get conditioning + conditioning = {} + + noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device) + frac_lengths = torch.zeros((demo_reals.shape[0],), device = module.device).uniform_(*(0.3,0.5)) + ctx_mask = generate_mask(demo_reals.shape[0],demo_reals.shape[2], frac_lengths, module.min_span_len) + # x_ctx = (demo_reals * ~ctx_mask.unsqueeze(1)).transpose(1,2) + x_ctx = demo_reals * ~ctx_mask.unsqueeze(1) + + conditioning['x_ctx'] = [x_ctx] + # x_ctx_mask = x_ctx * ~ctx_mask.unsqueeze(-1) + if module.diffusion.pretransform is not None: + log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(x_ctx.cpu())) + else: + log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(x_ctx, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu())) + cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) + with torch.amp.autocast('cuda'): + if module.diffusion_objective == "v": + fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0) + elif module.diffusion_objective == "rectified_flow": + fakes = sample_discrete_euler(module.diffusion_ema, noise, self.demo_steps, **cond_inputs) + # fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + # #Interleave reals and fakes + # reals_fakes = rearrange([reals, fakes], 'i b d n -> (b i) d n') + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + + filename = f'results/audio_ssl/demo_ssl_{trainer.global_step:08}.wav' + os.makedirs(Path(filename).parent,exist_ok=True) + fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + log_audio(trainer.logger, f'demo_ssl_{trainer.global_step:08}', filename, self.sample_rate) + log_image(trainer.logger, f'demo_melspec_left', audio_spectrogram_image(fakes)) + + del fakes + + except Exception as e: + print(f'{type(e).__name__}: {e}') + finally: + gc.collect() + torch.cuda.empty_cache() + +class DiffusionCondTrainingWrapper(L.LightningModule): + ''' + Wrapper for training a conditional audio diffusion model. + ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = None, + mask_padding: bool = False, + mask_padding_dropout: float = 0.0, + use_ema: bool = True, + log_loss_info: bool = False, + optimizer_configs: dict = None, + diffusion_objective: tp.Literal["rectified_flow", "v"] = "v", + pre_encoded: bool = False, + cfg_dropout_prob = 0.1, + video_dropout_prob = 0.2, + timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", + validation_timesteps = [0.1, 0.3, 0.5, 0.7, 0.9], + max_mask_segments = 0, + ): + super().__init__() + + self.diffusion = model + + if use_ema: + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + else: + self.diffusion_ema = None + + self.mask_padding = mask_padding + self.mask_padding_dropout = mask_padding_dropout + + self.cfg_dropout_prob = cfg_dropout_prob + self.video_dropout_prob = video_dropout_prob + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.timestep_sampler = timestep_sampler + if self.timestep_sampler == "log_snr": + self.mean_logsnr = -1.2 + self.std_logsnr = 2.0 + self.diffusion_objective = model.diffusion_objective + print(f'Training in the {self.diffusion_objective} formulation with timestep sampler: {timestep_sampler}') + + self.max_mask_segments = max_mask_segments + + self.loss_modules = [ + MSELoss("output", + "targets", + weight=1.0, + mask_key="padding_mask" if self.mask_padding else None, + name="mse_loss" + ) + ] + + self.losses = MultiLoss(self.loss_modules) + + self.log_loss_info = log_loss_info + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "diffusion": { + "optimizer": { + "type": "Adam", + "config": { + "lr": lr + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + self.pre_encoded = pre_encoded + + # Validation + self.validation_timesteps = validation_timesteps + + self.validation_step_outputs = {} + + for validation_timestep in self.validation_timesteps: + self.validation_step_outputs[f'val/loss_{validation_timestep:.1f}'] = [] + + def configure_optimizers(self): + diffusion_opt_config = self.optimizer_configs['diffusion'] + opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) + + if "scheduler" in diffusion_opt_config: + sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) + sched_diff_config = { + "scheduler": sched_diff, + "interval": "step" + } + return [opt_diff], [sched_diff_config] + + return [opt_diff] + + def training_step(self, batch, batch_idx): + reals, metadata = batch + # import ipdb + # ipdb.set_trace() + p = Profiler() + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + if not self.pre_encoded: + if reals.shape[1] != 64: + diffusion_input = reals + else: + waveform = torch.stack([item['waveform'] for item in metadata],dim=0) + diffusion_input = waveform + loss_info["audio_reals"] = diffusion_input + else: + diffusion_input = reals + #print(diffusion_input,flush=True) + p.tick("setup") + + conditioning = self.diffusion.conditioner(metadata, self.device) + + # video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0) + # if isinstance(conditioning['metaclip_features'], list): + # empty_clip_feat = self.diffusion.model.model.empty_clip_feat.to(conditioning['metaclip_features'][0].dtype) + # conditioning['metaclip_features'][0][~video_exist] = empty_clip_feat + # if 'sync_features' in conditioning.keys(): + # empty_sync_feat = self.diffusion.model.model.empty_sync_feat.to(conditioning['sync_features'][0].dtype) + # conditioning['sync_features'][0][~video_exist] = empty_sync_feat + # else: + # conditioning['metaclip_features'][~video_exist] = self.diffusion.model.model.empty_clip_feat + # if 'sync_features' in conditioning.keys(): + # conditioning['sync_features'][~video_exist] = self.diffusion.model.model.empty_sync_feat + + # If mask_padding is on, randomly drop the padding masks to allow for learning silence padding + use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout + + # Create batch tensor of attention masks from the "mask" field of the metadata array + if use_padding_mask: + padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) # Shape (batch_size, sequence_length) + + p.tick("conditioning") + + if self.diffusion.pretransform is not None: + self.diffusion.pretransform.to(self.device) + + if not self.pre_encoded: + with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) + + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + p.tick("pretransform") + + # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input + if use_padding_mask: + padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + if self.max_mask_segments > 0: + # Max mask size is the full sequence length + max_mask_length = diffusion_input.shape[2] + + # Create a mask of random length for a random slice of the input + masked_input, mask = self.random_mask(diffusion_input, max_mask_length) + + conditioning['inpaint_mask'] = [mask] + conditioning['inpaint_masked_input'] = masked_input + elif self.max_mask_segments == -1: + source_latent = torch.stack([item['source_latent'] for item in metadata],dim=0) + conditioning['inpaint_masked_input'] = source_latent + + if self.timestep_sampler == "uniform": + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + elif self.timestep_sampler == "logit_normal": + t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) + elif self.timestep_sampler == "trunc_logit_normal": + # Draw from logistic truncated normal distribution + t = truncated_logistic_normal_rescaled(reals.shape[0]).to(self.device) + + # Flip the distribution + t = 1 - t + elif self.timestep_sampler == "log_snr": + t = sample_timesteps_logsnr(reals.shape[0], mean_logsnr=self.mean_logsnr, std_logsnr=self.std_logsnr).to(self.device) + else: + raise ValueError(f"Invalid timestep_sampler: {self.timestep_sampler}") + # import ipdb + # ipdb.set_trace() + # Calculate the noise schedule parameters for those timesteps + if self.diffusion_objective == "v": + alphas, sigmas = get_alphas_sigmas(t) + elif self.diffusion_objective == "rectified_flow": + alphas, sigmas = 1-t, t + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + + if self.diffusion_objective == "v": + targets = noise * alphas - diffusion_input * sigmas + elif self.diffusion_objective == "rectified_flow": + targets = noise - diffusion_input + + p.tick("noise") + + extra_args = {} + + if use_padding_mask: + extra_args["mask"] = padding_masks + + output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, video_dropout_prob=self.video_dropout_prob, **extra_args) + p.tick("diffusion") + + loss_info.update({ + "output": output, + "targets": targets, + "padding_mask": padding_masks if use_padding_mask else None, + }) + + loss, losses = self.losses(loss_info) + + p.tick("loss") + + if self.log_loss_info: + # Loss debugging logs + num_loss_buckets = 10 + bucket_size = 1 / num_loss_buckets + loss_all = F.mse_loss(output, targets, reduction="none") + + sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() + + # gather loss_all across all GPUs + loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") + + # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size + loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + + # Log bucketed losses with corresponding sigma bucket values, if it's not NaN + debug_log_dict = { + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + } + + self.log_dict(debug_log_dict) + + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + p.tick("log") + #print(f"Profiler: {p}") + return loss + + def validation_step(self, batch, batch_idx): + reals, metadata = batch + # breakpoint() + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + diffusion_input = reals + + if not self.pre_encoded: + if reals.shape[1]!=64: + diffusion_input = reals + else: + waveform = torch.stack([item['waveform'] for item in metadata],dim=0) + diffusion_input = waveform + loss_info["audio_reals"] = diffusion_input + else: + diffusion_input = reals + + conditioning = self.diffusion.conditioner(metadata, self.device) + + # video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0) + # if isinstance(conditioning['metaclip_features'], list): + # empty_clip_feat = self.diffusion.model.model.empty_clip_feat.to(conditioning['metaclip_features'][0].dtype) + # conditioning['metaclip_features'][0][~video_exist] = empty_clip_feat + # if 'sync_features' in conditioning.keys(): + # empty_sync_feat = self.diffusion.model.model.empty_sync_feat.to(conditioning['sync_features'][0].dtype) + # conditioning['sync_features'][0][~video_exist] = empty_sync_feat + # else: + # conditioning['metaclip_features'][~video_exist] = self.diffusion.model.model.empty_clip_feat + # if 'sync_features' in conditioning.keys(): + # conditioning['sync_features'][~video_exist] = self.diffusion.model.model.empty_sync_feat + + if self.diffusion.pretransform is not None: + + if not self.pre_encoded: + self.diffusion.pretransform.to(self.device) + with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) + + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + if self.max_mask_segments > 0: + # Max mask size is the full sequence length + max_mask_length = diffusion_input.shape[2] + + # Create a mask of random length for a random slice of the input + masked_input, mask = self.random_mask(diffusion_input, max_mask_length) + + conditioning['inpaint_mask'] = [mask] + conditioning['inpaint_masked_input'] = masked_input + # elif 'source_latent' in conditioning.keys(): + # conditioning['inpaint_masked_input'] = conditioning['source_latent'] + for validation_timestep in self.validation_timesteps: + + t = torch.full((reals.shape[0],), validation_timestep, device=self.device) + # Calculate the noise schedule parameters for those timesteps + if self.diffusion_objective == "v": + alphas, sigmas = get_alphas_sigmas(t) + elif self.diffusion_objective == "rectified_flow": + alphas, sigmas = 1-t, t + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + #print("diffusion_input",diffusion_input.dtype) + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + + if self.diffusion_objective == "v": + targets = noise * alphas - diffusion_input * sigmas + elif self.diffusion_objective == "rectified_flow": + targets = noise - diffusion_input + + + with torch.no_grad(): + output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = 0.0, video_dropout_prob=0.0) + + val_loss = F.mse_loss(output, targets) + + self.validation_step_outputs[f'val/loss_{validation_timestep:.1f}'].append(val_loss.item()) + + def on_validation_epoch_end(self): + log_dict = {} + for validation_timestep in self.validation_timesteps: + + outputs_key = f'val/loss_{validation_timestep:.1f}' + val_loss = sum(self.validation_step_outputs[outputs_key]) / len(self.validation_step_outputs[outputs_key]) + + # Gather losses across all GPUs + val_loss = self.all_gather(val_loss).mean().item() + log_dict[outputs_key] = val_loss + # log_metric(self.logger, outputs_key, val_loss, step=self.global_step) + + # Get average over all timesteps + val_loss = torch.tensor([val for val in self.validation_step_outputs.values()]).mean() + + # Gather losses across all GPUs + val_loss = self.all_gather(val_loss).mean().item() + log_dict['val/avg_loss'] = val_loss + # log_metric(self.logger, 'val/avg_loss', val_loss, step=self.global_step) + self.log_dict(log_dict, prog_bar=True, on_epoch=True, sync_dist=True) + # Reset validation losses + for validation_timestep in self.validation_timesteps: + self.validation_step_outputs[f'val/loss_{validation_timestep:.1f}'] = [] + + def predict_step(self, batch, batch_idx): + reals, metadata = batch + + ids = [item['id'] for item in metadata] + # batch_size, length = reals.shape[0], reals.shape[2] + batch_size, length = reals.shape[0], 194 + conditioning = self.diffusion.conditioner(metadata, self.device) + + # video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0) + # if isinstance(conditioning['metaclip_features'], list): + # empty_clip_feat = self.diffusion.model.model.empty_clip_feat.to(conditioning['metaclip_features'][0].dtype) + # conditioning['metaclip_features'][0][~video_exist] = empty_clip_feat + # if 'sync_features' in conditioning.keys(): + # empty_sync_feat = self.diffusion.model.model.empty_sync_feat.to(conditioning['sync_features'][0].dtype) + # conditioning['sync_features'][0][~video_exist] = empty_sync_feat + # else: + # conditioning['metaclip_features'][~video_exist] = self.diffusion.model.model.empty_clip_feat + # if 'sync_features' in conditioning.keys(): + # conditioning['sync_features'][~video_exist] = self.diffusion.model.model.empty_sync_feat + + # if self.max_mask_segments > 0: + # bsz, dim, seq_len = reals.shape + # mask_start = torch.randint(0, seq_len-43+1,(1,)).item() + # inpaint_mask = torch.ones((bsz, seq_len), dtype=torch.bool, device=self.device) + # inpaint_mask[:, mask_start:mask_start+43] = 0 + # inpaint_mask = inpaint_mask.unsqueeze(1).expand(-1, dim, -1) + # masked_input = reals * inpaint_mask + # conditioning['inpaint_masked_input'] = masked_input + if self.max_mask_segments > 0: + # 读取音频文件 + # import ipdb + # ipdb.set_trace() + # audio_path = 'interactive/iter1/_8JGpCc5RI8_000160_stage2.mp4' + # reals, sample_rate = torchaudio.load(audio_path) + # reals = reals.unsqueeze(0).to(self.device) + # reals = self.diffusion.pretransform.encode(reals) + bsz, dim, seq_len = reals.shape + inpaint_mask = torch.ones((bsz, seq_len), dtype=torch.bool, device=self.device) + inpaint_mask[:, 100:140] = 0 + inpaint_mask = inpaint_mask.unsqueeze(1).expand(-1, dim, -1) + masked_input = reals * inpaint_mask + conditioning['inpaint_masked_input'] = masked_input + elif self.max_mask_segments ==-1: + source_latent = torch.stack([item['source_latent'] for item in metadata],dim=0) + conditioning['inpaint_masked_input'] = source_latent + + cond_inputs = self.diffusion.get_conditioning_inputs(conditioning) + if batch_size > 1: + noise_list = [] + for _ in range(batch_size): + noise_1 = torch.randn([1, self.diffusion.io_channels, length]).to(self.device) # 每次生成推进RNG状态 + noise_list.append(noise_1) + noise = torch.cat(noise_list, dim=0) + else: + noise = torch.randn([batch_size, self.diffusion.io_channels, length]).to(self.device) + + model = self.diffusion.model + if self.diffusion_objective == "v": + fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True) + elif self.diffusion_objective == "rectified_flow": + import time + start_time = time.time() + fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True) + # import ipdb + # ipdb.set_trace() + # fakes_list = [] + # for i in range(5): + # noise = torch.randn([batch_size, self.diffusion.io_channels, length]).to(self.device) + # fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True) + # if self.max_mask_segments > 0: + # masked_input = torch.zeros_like(fakes, dtype=fakes.dtype, device=self.device) + # masked_input[:,:,:100] = fakes[:,:,-100:] + # conditioning['inpaint_masked_input'] = masked_input + + # if i == 0: + # # 第一次迭代,保存完整的音频 + # fakes_list.append(fakes) + # else: + # # 后续迭代,只保存最后94长度的音频 + # fakes_list.append(fakes[:, :, -94:]) + + # # 组合所有音频片段 + # fakes = torch.cat(fakes_list, dim=2) + end_time = time.time() + execution_time = end_time - start_time + #print(f"执行时间: {execution_time:.2f} 秒") + if self.diffusion.pretransform is not None: + fakes = self.diffusion.pretransform.decode(fakes) + if self.max_mask_segments > 0: + masked_input = self.diffusion.pretransform.decode(masked_input) + elif self.max_mask_segments == -1: + masked_input = self.diffusion.pretransform.decode(source_latent) + + + audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + if self.max_mask_segments > 0: + masked_audios = masked_input.to(torch.float32).div(torch.max(torch.abs(masked_input))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + return (masked_audios, audios) + elif self.max_mask_segments==-1: + masked_audios = masked_input.to(torch.float32).div(torch.max(torch.abs(masked_input))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + return (masked_audios, audios) + return audios + + def predict_step_inpaint(self, batch, batch_idx): + reals, metadata = batch + # import ipdb + # ipdb.set_trace() + ids = [item['id'] for item in metadata] + batch_size, length = reals.shape[0], reals.shape[2] + with torch.amp.autocast('cuda'): + conditioning = self.diffusion.conditioner(metadata, self.device) + + video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0) + conditioning['metaclip_features'][~video_exist] = self.diffusion.model.model.empty_clip_feat + conditioning['sync_features'][~video_exist] = self.diffusion.model.model.empty_sync_feat + max_mask_length = reals.shape[2] + # Create a mask of random length for a random slice of the input + bsz, dim, seq_len = reals.shape + + cond_inputs = self.diffusion.get_conditioning_inputs(conditioning) + if batch_size > 1: + noise_list = [] + for _ in range(batch_size): + noise_1 = torch.randn([1, self.diffusion.io_channels, length]).to(self.device) # 每次生成推进RNG状态 + noise_list.append(noise_1) + noise = torch.cat(noise_list, dim=0) + else: + noise = torch.randn([batch_size, self.diffusion.io_channels, length]).to(self.device) + with torch.amp.autocast('cuda'): + + model = self.diffusion.model + if self.diffusion_objective == "v": + fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True) + elif self.diffusion_objective == "rectified_flow": + import time + start_time = time.time() + fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True) + end_time = time.time() + execution_time = end_time - start_time + print(f"执行时间: {execution_time:.2f} 秒") + if self.diffusion.pretransform is not None: + fakes = self.diffusion.pretransform.decode(fakes) + masked_input = self.diffusion.pretransform.decode(masked_input) + + masked_audios = masked_input.to(torch.float32).div(torch.max(torch.abs(masked_input))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + # return audios + return (masked_audios, audios) + + + def random_mask(self, sequence, max_mask_length): + b, _, sequence_length = sequence.size() + + # Create a mask tensor for each batch element + masks = [] + + for i in range(b): + mask_type = random.randint(0, 2) + + if mask_type == 0: # Random mask with multiple segments + num_segments = random.randint(1, self.max_mask_segments) + max_segment_length = max_mask_length // num_segments + + segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments) + + mask = torch.ones((1, 1, sequence_length)) + for length in segment_lengths: + mask_start = random.randint(0, sequence_length - length) + mask[:, :, mask_start:mask_start + length] = 0 + + elif mask_type == 1: # Full mask + mask = torch.zeros((1, 1, sequence_length)) + + elif mask_type == 2: # Causal mask + mask = torch.ones((1, 1, sequence_length)) + mask_length = random.randint(1, max_mask_length) + mask[:, :, -mask_length:] = 0 + + mask = mask.to(sequence.device) + masks.append(mask) + + # Concatenate the mask tensors into a single tensor + mask = torch.cat(masks, dim=0).to(sequence.device) + + # Apply the mask to the sequence tensor for each batch element + masked_sequence = sequence * mask + + return masked_sequence, mask + + def on_before_zero_grad(self, *args, **kwargs): + if self.diffusion_ema is not None: + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + try: + if self.diffusion_ema is not None: + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + except Exception as e: + print(f"failed to export model to: {path}. {e} occured") + + +class DiffusionCondDemoCallback(Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + sample_size=65536, + demo_steps=250, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = {}, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + demo_cond_from_batch: bool = False, + display_audio_cond: bool = False + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_samples = sample_size + self.demo_steps = demo_steps + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.demo_conditioning = demo_conditioning + self.demo_cfg_scales = demo_cfg_scales + + # If true, the callback will use the metadata from the batch to generate the demo conditioning + self.demo_cond_from_batch = demo_cond_from_batch + + # If true, the callback will display the audio conditioning + self.display_audio_cond = display_audio_cond + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + module.eval() + + print(f"Generating demo") + self.last_demo_step = trainer.global_step + + demo_samples = self.demo_samples + + demo_cond = self.demo_conditioning + + if self.demo_cond_from_batch: + # Get metadata from the batch + demo_cond = batch[1][:self.num_demos] + + if '.pth' in demo_cond[0]: + demo_cond_data = [] + for path in demo_cond: + # info = {} + data = torch.load(path, weights_only=True) + if 'caption_t5' not in data.keys(): + data['caption_t5'] = data['caption'] + data['seconds_start'] = 0 + data['seconds_total'] = 10 + demo_cond_data.append(data) + demo_cond = demo_cond_data + elif '.npz' in demo_cond[0]: + demo_cond_data = [] + for path in demo_cond: + # info = {} + npz_data = np.load(path,allow_pickle=True) + data = {key: npz_data[key] for key in npz_data.files} + for key in data.keys(): + # print(key) + if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number): + data[key] = torch.from_numpy(data[key]) + + demo_cond_data.append(data) + demo_cond = demo_cond_data + if module.diffusion.pretransform is not None: + demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio + + noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device) + + try: + print("Getting conditioning") + conditioning = module.diffusion.conditioner(demo_cond, module.device) + + cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) + + log_dict = {} + + if self.display_audio_cond: + audio_inputs = torch.cat([cond["audio"] for cond in demo_cond], dim=0) + audio_inputs = rearrange(audio_inputs, 'b d n -> d (b n)') + + filename = f'demo_audio_cond_{trainer.global_step:08}.wav' + audio_inputs = audio_inputs.to(torch.float32).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, audio_inputs, self.sample_rate) + log_dict[f'demo_audio_cond'] = wandb.Audio(filename, sample_rate=self.sample_rate, caption="Audio conditioning") + log_dict[f"demo_audio_cond_melspec_left"] = wandb.Image(audio_spectrogram_image(audio_inputs)) + trainer.logger.experiment.log(log_dict) + + for cfg_scale in self.demo_cfg_scales: + + print(f"Generating demo for cfg scale {cfg_scale}") + + # model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model + model = module.diffusion.model + + if module.diffusion_objective == "v": + fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + elif module.diffusion_objective == "rectified_flow": + fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demos/demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes.div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + log_audio(trainer.logger, f'demo_cfg_{cfg_scale}', filename, self.sample_rate) + log_image(trainer.logger, f'demo_melspec_left_cfg_{cfg_scale}', audio_spectrogram_image(fakes)) + + del fakes + + except Exception as e: + raise e + finally: + gc.collect() + torch.cuda.empty_cache() + module.train() + +class DiffusionCondInpaintTrainingWrapper(L.LightningModule): + ''' + Wrapper for training a conditional audio diffusion model. + ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = 1e-4, + max_mask_segments = 10, + log_loss_info: bool = False, + optimizer_configs: dict = None, + use_ema: bool = True, + pre_encoded: bool = False, + cfg_dropout_prob = 0.1, + timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", + ): + super().__init__() + + self.diffusion = model + + self.use_ema = use_ema + + if self.use_ema: + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + else: + self.diffusion_ema = None + + self.cfg_dropout_prob = cfg_dropout_prob + + self.lr = lr + self.max_mask_segments = max_mask_segments + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.timestep_sampler = timestep_sampler + + self.diffusion_objective = model.diffusion_objective + + self.loss_modules = [ + MSELoss("output", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.losses = MultiLoss(self.loss_modules) + + self.log_loss_info = log_loss_info + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "diffusion": { + "optimizer": { + "type": "Adam", + "config": { + "lr": lr + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + diffusion_opt_config = self.optimizer_configs['diffusion'] + opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) + + if "scheduler" in diffusion_opt_config: + sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) + sched_diff_config = { + "scheduler": sched_diff, + "interval": "step" + } + return [opt_diff], [sched_diff_config] + + return [opt_diff] + + def random_mask(self, sequence, max_mask_length): + b, _, sequence_length = sequence.size() + + # Create a mask tensor for each batch element + masks = [] + + for i in range(b): + mask_type = random.randint(0, 2) + + if mask_type == 0: # Random mask with multiple segments + num_segments = random.randint(1, self.max_mask_segments) + max_segment_length = max_mask_length // num_segments + + segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments) + + mask = torch.ones((1, 1, sequence_length)) + for length in segment_lengths: + mask_start = random.randint(0, sequence_length - length) + mask[:, :, mask_start:mask_start + length] = 0 + + elif mask_type == 1: # Full mask + mask = torch.zeros((1, 1, sequence_length)) + + elif mask_type == 2: # Causal mask + mask = torch.ones((1, 1, sequence_length)) + mask_length = random.randint(1, max_mask_length) + mask[:, :, -mask_length:] = 0 + + mask = mask.to(sequence.device) + masks.append(mask) + + # Concatenate the mask tensors into a single tensor + mask = torch.cat(masks, dim=0).to(sequence.device) + + # Apply the mask to the sequence tensor for each batch element + masked_sequence = sequence * mask + + return masked_sequence, mask + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + p = Profiler() + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + diffusion_input = reals + + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + p.tick("setup") + + with torch.amp.autocast('cuda'): + conditioning = self.diffusion.conditioner(metadata, self.device) + + p.tick("conditioning") + + if self.diffusion.pretransform is not None: + self.diffusion.pretransform.to(self.device) + + if not self.pre_encoded: + with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + p.tick("pretransform") + + # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input + # if use_padding_mask: + # padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + # Max mask size is the full sequence length + max_mask_length = diffusion_input.shape[2] + + # Create a mask of random length for a random slice of the input + masked_input, mask = self.random_mask(diffusion_input, max_mask_length) + + # conditioning['inpaint_mask'] = [mask] + conditioning['inpaint_masked_input'] = [masked_input] + + if self.timestep_sampler == "uniform": + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + elif self.timestep_sampler == "logit_normal": + t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) + + # Calculate the noise schedule parameters for those timesteps + if self.diffusion_objective == "v": + alphas, sigmas = get_alphas_sigmas(t) + elif self.diffusion_objective == "rectified_flow": + alphas, sigmas = 1-t, t + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + + if self.diffusion_objective == "v": + targets = noise * alphas - diffusion_input * sigmas + elif self.diffusion_objective == "rectified_flow": + targets = noise - diffusion_input + + p.tick("noise") + + extra_args = {} + + with torch.amp.autocast('cuda'): + p.tick("amp") + output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) + p.tick("diffusion") + + loss_info.update({ + "output": output, + "targets": targets, + }) + + loss, losses = self.losses(loss_info) + + if self.log_loss_info: + # Loss debugging logs + num_loss_buckets = 10 + bucket_size = 1 / num_loss_buckets + loss_all = F.mse_loss(output, targets, reduction="none") + + sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() + + # gather loss_all across all GPUs + loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") + + # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size + loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + + # Log bucketed losses with corresponding sigma bucket values, if it's not NaN + debug_log_dict = { + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + } + + self.log_dict(debug_log_dict) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + p.tick("log") + #print(f"Profiler: {p}") + return loss + + def on_before_zero_grad(self, *args, **kwargs): + if self.diffusion_ema is not None: + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + if self.diffusion_ema is not None: + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + +class DiffusionCondInpaintDemoCallback(Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + demo_steps=250, + sample_size=65536, + sample_rate=48000, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7] + ): + super().__init__() + self.demo_every = demo_every + self.demo_steps = demo_steps + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.demo_cfg_scales = demo_cfg_scales + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + try: + log_dict = {} + + demo_reals, metadata = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + demo_reals = demo_reals.to(module.device) + + if not module.pre_encoded: + # Log the real audio + log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu())) + # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals") + + if module.diffusion.pretransform is not None: + module.diffusion.pretransform.to(module.device) + with torch.amp.autocast('cuda'): + demo_reals = module.diffusion.pretransform.encode(demo_reals) + + demo_samples = demo_reals.shape[2] + + # Get conditioning + conditioning = module.diffusion.conditioner(metadata, module.device) + + masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2]) + + conditioning['inpaint_mask'] = [mask] + conditioning['inpaint_masked_input'] = [masked_input] + + if module.diffusion.pretransform is not None: + log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu())) + else: + log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu())) + + cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) + + noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device) + + trainer.logger.experiment.log(log_dict) + + for cfg_scale in self.demo_cfg_scales: + model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model + print(f"Generating demo for cfg scale {cfg_scale}") + + if module.diffusion_objective == "v": + fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + elif module.diffusion_objective == "rectified_flow": + fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + + if module.diffusion.pretransform is not None: + with torch.amp.autocast('cuda'): + fakes = module.diffusion.pretransform.decode(fakes) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + except Exception as e: + print(f'{type(e).__name__}: {e}') + raise e + +class DiffusionAutoencoderTrainingWrapper(L.LightningModule): + ''' + Wrapper for training a diffusion autoencoder + ''' + def __init__( + self, + model: DiffusionAutoencoder, + lr: float = 1e-4, + ema_copy = None, + use_reconstruction_loss: bool = False + ): + super().__init__() + + self.diffae = model + + self.diffae_ema = EMA( + self.diffae, + ema_model=ema_copy, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + + self.lr = lr + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + if model.bottleneck is not None: + # TODO: Use loss config for configurable bottleneck weights and reconstruction losses + loss_modules += create_loss_modules_from_bottleneck(model.bottleneck, {}) + + self.use_reconstruction_loss = use_reconstruction_loss + + if use_reconstruction_loss: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + sample_rate = model.sample_rate + + stft_loss_args = { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + } + + out_channels = model.out_channels + + if model.pretransform is not None: + out_channels = model.pretransform.io_channels + + if out_channels == 2: + self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + else: + self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + loss_modules.append( + AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss + ) + + self.losses = MultiLoss(loss_modules) + + def configure_optimizers(self): + return optim.Adam([*self.diffae.parameters()], lr=self.lr) + + def training_step(self, batch, batch_idx): + reals = batch[0] + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + loss_info["audio_reals"] = reals + + if self.diffae.pretransform is not None: + with torch.no_grad(): + reals = self.diffae.pretransform.encode(reals) + + loss_info["reals"] = reals + + #Encode reals, skipping the pretransform since it was already applied + latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True) + + loss_info["latents"] = latents + loss_info.update(encoder_info) + + if self.diffae.decoder is not None: + latents = self.diffae.decoder(latents) + + # Upsample latents to match diffusion length + if latents.shape[2] != reals.shape[2]: + latents = F.interpolate(latents, size=reals.shape[2], mode='nearest') + + loss_info["latents_upsampled"] = latents + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(reals) + noised_reals = reals * alphas + noise * sigmas + targets = noise * alphas - reals * sigmas + + with torch.amp.autocast('cuda'): + v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents) + + loss_info.update({ + "v": v, + "targets": targets + }) + + if self.use_reconstruction_loss: + pred = noised_reals * alphas - v * sigmas + + loss_info["pred"] = pred + + if self.diffae.pretransform is not None: + pred = self.diffae.pretransform.decode(pred) + loss_info["audio_pred"] = pred + + loss, losses = self.losses(loss_info) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': reals.std(), + 'train/latent_std': latents.std(), + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffae_ema.update() + + def export_model(self, path, use_safetensors=False): + + model = self.diffae_ema.ema_model + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + +class DiffusionAutoencoderDemoCallback(Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + demo_steps=250, + sample_size=65536, + sample_rate=48000 + ): + super().__init__() + self.demo_every = demo_every + self.demo_steps = demo_steps + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + demo_reals, _ = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + encoder_input = demo_reals + + encoder_input = encoder_input.to(module.device) + + demo_reals = demo_reals.to(module.device) + + with torch.no_grad() and torch.amp.autocast('cuda'): + latents = module.diffae_ema.ema_model.encode(encoder_input).float() + fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps) + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'recon_{trainer.global_step:08}.wav' + reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_dict[f'recon'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents) + log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents)) + + log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + + if module.diffae_ema.ema_model.pretransform is not None: + with torch.no_grad() and torch.amp.autocast('cuda'): + initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input) + first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents) + first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)') + first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu() + first_stage_filename = f'first_stage_{trainer.global_step:08}.wav' + torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate) + + log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents)) + + log_dict[f'first_stage'] = wandb.Audio(first_stage_filename, + sample_rate=self.sample_rate, + caption=f'First Stage Reconstructed') + + log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes)) + + + trainer.logger.experiment.log(log_dict) + +def create_source_mixture(reals, num_sources=2): + # Create a fake mixture source by mixing elements from the training batch together with random offsets + source = torch.zeros_like(reals) + for i in range(reals.shape[0]): + sources_added = 0 + + js = list(range(reals.shape[0])) + random.shuffle(js) + for j in js: + if i == j or (i != j and sources_added < num_sources): + # Randomly offset the mixed element between 0 and the length of the source + seq_len = reals.shape[2] + offset = random.randint(0, seq_len-1) + source[i, :, offset:] += reals[j, :, :-offset] + if i == j: + # If this is the real one, shift the reals as well to ensure alignment + new_reals = torch.zeros_like(reals[i]) + new_reals[:, offset:] = reals[i, :, :-offset] + reals[i] = new_reals + sources_added += 1 + + return source + +class DiffusionPriorTrainingWrapper(L.LightningModule): + ''' + Wrapper for training a diffusion prior for inverse problems + Prior types: + mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version + ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = 1e-4, + ema_copy = None, + prior_type: PriorType = PriorType.MonoToStereo, + use_reconstruction_loss: bool = False, + log_loss_info: bool = False, + ): + super().__init__() + + self.diffusion = model + + self.diffusion_ema = EMA( + self.diffusion, + ema_model=ema_copy, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + + self.lr = lr + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.log_loss_info = log_loss_info + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.use_reconstruction_loss = use_reconstruction_loss + + if use_reconstruction_loss: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + sample_rate = model.sample_rate + + stft_loss_args = { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + } + + out_channels = model.io_channels + + + if model.pretransform is not None: + out_channels = model.pretransform.io_channels + self.audio_out_channels = out_channels + + if self.audio_out_channels == 2: + self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + # Add left and right channel reconstruction losses in addition to the sum and difference + loss_modules += [ + AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05), + AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05), + ] + + else: + self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + loss_modules.append( + AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss + ) + + self.losses = MultiLoss(loss_modules) + + self.prior_type = prior_type + + def configure_optimizers(self): + return optim.Adam([*self.diffusion.parameters()], lr=self.lr) + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + loss_info["audio_reals"] = reals + + if self.prior_type == PriorType.MonoToStereo: + source = reals.mean(dim=1, keepdim=True).repeat(1, reals.shape[1], 1).to(self.device) + loss_info["audio_reals_mono"] = source + else: + raise ValueError(f"Unknown prior type {self.prior_type}") + + if self.diffusion.pretransform is not None: + with torch.no_grad(): + reals = self.diffusion.pretransform.encode(reals) + + if self.prior_type in [PriorType.MonoToStereo]: + source = self.diffusion.pretransform.encode(source) + + if self.diffusion.conditioner is not None: + with torch.amp.autocast('cuda'): + conditioning = self.diffusion.conditioner(metadata, self.device) + else: + conditioning = {} + + loss_info["reals"] = reals + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(reals) + noised_reals = reals * alphas + noise * sigmas + targets = noise * alphas - reals * sigmas + + with torch.amp.autocast('cuda'): + + conditioning['source'] = [source] + + v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1) + + loss_info.update({ + "v": v, + "targets": targets + }) + + if self.use_reconstruction_loss: + pred = noised_reals * alphas - v * sigmas + + loss_info["pred"] = pred + + if self.diffusion.pretransform is not None: + pred = self.diffusion.pretransform.decode(pred) + loss_info["audio_pred"] = pred + + if self.audio_out_channels == 2: + loss_info["pred_left"] = pred[:, 0:1, :] + loss_info["pred_right"] = pred[:, 1:2, :] + loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :] + loss_info["audio_reals_right"] = loss_info["audio_reals"][:, 1:2, :] + + loss, losses = self.losses(loss_info) + + if self.log_loss_info: + # Loss debugging logs + num_loss_buckets = 10 + bucket_size = 1 / num_loss_buckets + loss_all = F.mse_loss(v, targets, reduction="none") + + sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() + + # gather loss_all across all GPUs + loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") + + # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size + loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + + # Log bucketed losses with corresponding sigma bucket values, if it's not NaN + debug_log_dict = { + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + } + + self.log_dict(debug_log_dict) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': reals.std() + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + + #model = self.diffusion_ema.ema_model + model = self.diffusion + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + +class DiffusionPriorDemoCallback(Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + demo_steps=250, + sample_size=65536, + sample_rate=48000 + ): + super().__init__() + + self.demo_every = demo_every + self.demo_steps = demo_steps + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + demo_reals, metadata = next(self.demo_dl) + # import ipdb + # ipdb.set_trace() + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + demo_reals = demo_reals.to(module.device) + + encoder_input = demo_reals + + if module.diffusion.conditioner is not None: + with torch.amp.autocast('cuda'): + conditioning_tensors = module.diffusion.conditioner(metadata, module.device) + + else: + conditioning_tensors = {} + + + with torch.no_grad() and torch.amp.autocast('cuda'): + if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1: + source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device) + + if module.diffusion.pretransform is not None: + encoder_input = module.diffusion.pretransform.encode(encoder_input) + source_input = module.diffusion.pretransform.encode(source) + else: + source_input = source + + conditioning_tensors['source'] = [source_input] + + fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'recon_mono_{trainer.global_step:08}.wav' + reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_dict[f'recon'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + + #Log the source + filename = f'source_{trainer.global_step:08}.wav' + source = rearrange(source, 'b d n -> d (b n)') + source = source.to(torch.float32).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, source, self.sample_rate) + + log_dict[f'source'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Source') + + log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source)) + + trainer.logger.experiment.log(log_dict) \ No newline at end of file diff --git a/ThinkSound/training/factory.py b/ThinkSound/training/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc2ac52221c87c17eca5f17ff9bd744cda43cc2 --- /dev/null +++ b/ThinkSound/training/factory.py @@ -0,0 +1,263 @@ +import torch +from torch.nn import Parameter +from ..models.factory import create_model_from_config + +def create_training_wrapper_from_config(model_config, model): + model_type = model_config.get('model_type', None) + assert model_type is not None, 'model_type must be specified in model config' + + training_config = model_config.get('training', None) + assert training_config is not None, 'training config must be specified in model config' + if model_type == 'autoencoder': + from .autoencoders import AutoencoderTrainingWrapper + + ema_copy = None + + if training_config.get("use_ema", False): + ema_copy = create_model_from_config(model_config) + ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + use_ema = training_config.get("use_ema", False) + + latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0) + + teacher_model = training_config.get("teacher_model", None) + if teacher_model is not None: + teacher_model = create_model_from_config(teacher_model) + teacher_model = teacher_model.eval().requires_grad_(False) + + teacher_model_ckpt = training_config.get("teacher_model_ckpt", None) + if teacher_model_ckpt is not None: + teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"]) + else: + raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified") + + return AutoencoderTrainingWrapper( + model, + lr=training_config["learning_rate"], + warmup_steps=training_config.get("warmup_steps", 0), + encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False), + sample_rate=model_config["sample_rate"], + loss_config=training_config.get("loss_configs", None), + optimizer_configs=training_config.get("optimizer_configs", None), + use_ema=use_ema, + ema_copy=ema_copy if use_ema else None, + force_input_mono=training_config.get("force_input_mono", False), + latent_mask_ratio=latent_mask_ratio, + teacher_model=teacher_model + ) + elif model_type == 'diffusion_uncond': + from .diffusion import DiffusionUncondTrainingWrapper + return DiffusionUncondTrainingWrapper( + model, + lr=training_config["learning_rate"], + pre_encoded=training_config.get("pre_encoded", False), + ) + elif model_type == 'diffusion_infill': + from .diffusion import DiffusionInfillTrainingWrapper + return DiffusionInfillTrainingWrapper( + model, + lr=training_config.get("learning_rate", None), + optimizer_configs=training_config.get("optimizer_configs", None), + pre_encoded=training_config.get("pre_encoded", False), + frac_lengths_mask=training_config.get("frac_lengths_mask", (0.7, 1.)), + min_span_len=training_config.get("min_span_len", 10), + timestep_sampler = training_config.get("timestep_sampler", "uniform"), + ctx_drop = training_config.get("ctx_drop", 0.1), + r_drop = training_config.get("r_drop", 0.0) + ) + elif model_type == 'diffusion_cond' or model_type == 'mm_diffusion_cond': + from .diffusion import DiffusionCondTrainingWrapper + return DiffusionCondTrainingWrapper( + model, + lr=training_config.get("learning_rate", None), + mask_padding=training_config.get("mask_padding", False), + mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0), + use_ema = training_config.get("use_ema", True), + log_loss_info=training_config.get("log_loss_info", False), + optimizer_configs=training_config.get("optimizer_configs", None), + pre_encoded=training_config.get("pre_encoded", False), + diffusion_objective=training_config.get("diffusion_objective","v"), + cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1), + video_dropout_prob = training_config.get("video_dropout_prob", 0.2), + timestep_sampler = training_config.get("timestep_sampler", "uniform"), + max_mask_segments = training_config.get("max_mask_segments", 0) + ) + elif model_type == 'diffusion_prior': + from .diffusion import DiffusionPriorTrainingWrapper + from ..models.diffusion_prior import PriorType + + ema_copy = create_model_from_config(model_config) + + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + prior_type = training_config.get("prior_type", "mono_stereo") + + if prior_type == "mono_stereo": + prior_type_enum = PriorType.MonoToStereo + else: + raise ValueError(f"Unknown prior type: {prior_type}") + + return DiffusionPriorTrainingWrapper( + model, + lr=training_config["learning_rate"], + ema_copy=ema_copy, + prior_type=prior_type_enum, + log_loss_info=training_config.get("log_loss_info", False), + use_reconstruction_loss=training_config.get("use_reconstruction_loss", False), + ) + elif model_type == 'diffusion_cond_inpaint': + from .diffusion import DiffusionCondInpaintTrainingWrapper + return DiffusionCondInpaintTrainingWrapper( + model, + lr=training_config.get("learning_rate", None), + max_mask_segments = training_config.get("max_mask_segments", 10), + log_loss_info=training_config.get("log_loss_info", False), + optimizer_configs=training_config.get("optimizer_configs", None), + use_ema=training_config.get("use_ema", True), + pre_encoded=training_config.get("pre_encoded", False), + cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1), + timestep_sampler = training_config.get("timestep_sampler", "uniform") + ) + elif model_type == 'diffusion_autoencoder' : + from .diffusion import DiffusionAutoencoderTrainingWrapper + + ema_copy = create_model_from_config(model_config) + + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + return DiffusionAutoencoderTrainingWrapper( + model, + ema_copy=ema_copy, + lr=training_config["learning_rate"], + use_reconstruction_loss=training_config.get("use_reconstruction_loss", False) + ) + elif model_type == 'lm': + from .lm import AudioLanguageModelTrainingWrapper + + ema_copy = create_model_from_config(model_config) + + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + return AudioLanguageModelTrainingWrapper( + model, + ema_copy=ema_copy, + lr=training_config.get("learning_rate", None), + use_ema=training_config.get("use_ema", False), + optimizer_configs=training_config.get("optimizer_configs", None), + pre_encoded=training_config.get("pre_encoded", False), + ) + + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + +def create_demo_callback_from_config(model_config, **kwargs): + model_type = model_config.get('model_type', None) + assert model_type is not None, 'model_type must be specified in model config' + + training_config = model_config.get('training', None) + assert training_config is not None, 'training config must be specified in model config' + + demo_config = training_config.get("demo", {}) + + if model_type == 'autoencoder': + from .autoencoders import AutoencoderDemoCallback + return AutoencoderDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == 'diffusion_uncond': + from .diffusion import DiffusionUncondDemoCallback + return DiffusionUncondDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_rate=model_config["sample_rate"] + ) + elif model_type == 'diffusion_infill': + from .diffusion import DiffusionInfillDemoCallback + return DiffusionInfillDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == "diffusion_autoencoder": + from .diffusion import DiffusionAutoencoderDemoCallback + return DiffusionAutoencoderDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == "diffusion_prior": + from .diffusion import DiffusionPriorDemoCallback + return DiffusionPriorDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == "diffusion_cond" or model_type == 'mm_diffusion_cond': + from .diffusion import DiffusionCondDemoCallback + + return DiffusionCondDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_steps=demo_config.get("demo_steps", 250), + num_demos=demo_config["num_demos"], + demo_cfg_scales=demo_config["demo_cfg_scales"], + demo_conditioning=demo_config.get("demo_cond", {}), + demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False), + display_audio_cond=demo_config.get("display_audio_cond", False), + ) + elif model_type == "diffusion_cond_inpaint": + from .diffusion import DiffusionCondInpaintDemoCallback + + return DiffusionCondInpaintDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_steps=demo_config.get("demo_steps", 250), + demo_cfg_scales=demo_config["demo_cfg_scales"], + **kwargs + ) + + elif model_type == "lm": + from .lm import AudioLanguageModelDemoCallback + + return AudioLanguageModelDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]), + demo_conditioning=demo_config.get("demo_cond", None), + num_demos=demo_config.get("num_demos", 8), + **kwargs + ) + else: + raise NotImplementedError(f'Unknown model type: {model_type}') \ No newline at end of file diff --git a/ThinkSound/training/lm.py b/ThinkSound/training/lm.py new file mode 100644 index 0000000000000000000000000000000000000000..e1fa9f71c805f8d4083919d5c46422c5b7eeb4a8 --- /dev/null +++ b/ThinkSound/training/lm.py @@ -0,0 +1,267 @@ +import pytorch_lightning as pl +import sys, gc +import random +import torch +import torchaudio +import typing as tp +import wandb + +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image +from ema_pytorch import EMA +from einops import rearrange +from safetensors.torch import save_file +from torch import optim +from torch.nn import functional as F +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +from ..models.lm import AudioLanguageModelWrapper +from .utils import create_optimizer_from_config, create_scheduler_from_config + +class AudioLanguageModelTrainingWrapper(pl.LightningModule): + def __init__( + self, + model: AudioLanguageModelWrapper, + lr = 1e-4, + use_ema=False, + ema_copy=None, + optimizer_configs: dict = None, + pre_encoded=False + ): + super().__init__() + + self.model = model + + self.model.pretransform.requires_grad_(False) + + self.model_ema = None + if use_ema: + self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10) + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "lm": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (0.9, 0.95), + "weight_decay": 0.1 + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + lm_opt_config = self.optimizer_configs['lm'] + opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters()) + + if "scheduler" in lm_opt_config: + sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm) + sched_lm_config = { + "scheduler": sched_lm, + "interval": "step" + } + return [opt_lm], [sched_lm_config] + + return [opt_lm] + + # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license + # License can be found in LICENSES/LICENSE_META.txt + + def _compute_cross_entropy( + self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: + """Compute cross entropy between multi-codebook targets and model's logits. + The cross entropy is computed per codebook to provide codebook-level cross entropy. + Valid timesteps for each of the codebook are pulled from the mask, where invalid + timesteps are set to 0. + + Args: + logits (torch.Tensor): Model's logits of shape [B, K, T, card]. + targets (torch.Tensor): Target codes, of shape [B, K, T]. + mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. + Returns: + ce (torch.Tensor): Cross entropy averaged over the codebooks + ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). + """ + B, K, T = targets.shape + assert logits.shape[:-1] == targets.shape + assert mask.shape == targets.shape + ce = torch.zeros([], device=targets.device) + ce_per_codebook: tp.List[torch.Tensor] = [] + for k in range(K): + logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] + targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] + mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] + ce_targets = targets_k[mask_k] + ce_logits = logits_k[mask_k] + q_ce = F.cross_entropy(ce_logits, ce_targets) + ce += q_ce + ce_per_codebook.append(q_ce.detach()) + # average cross entropy across codebooks + ce = ce / K + return ce, ce_per_codebook + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + if not self.pre_encoded: + codes = self.model.pretransform.tokenize(reals) + else: + codes = reals + + padding_masks = [] + for md in metadata: + if md["padding_mask"].ndim == 1: + padding_masks.append(md["padding_mask"]) + else: + padding_masks.append(md["padding_mask"][0]) + + padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length) + + # Interpolate padding masks to the same length as the codes + padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool() + + condition_tensors = None + + # If the model is conditioned, get the conditioning tensors + if self.model.conditioner is not None: + condition_tensors = self.model.conditioner(metadata, self.device) + + lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1) + + logits = lm_output.logits # [b, k, t, c] + logits_mask = lm_output.mask # [b, k, t] + + logits_mask = logits_mask & padding_masks + + cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask) + + loss = cross_entropy + + log_dict = { + 'train/loss': loss.detach(), + 'train/cross_entropy': cross_entropy.detach(), + 'train/perplexity': torch.exp(cross_entropy).detach(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for k, ce_q in enumerate(cross_entropy_per_codebook): + log_dict[f'cross_entropy_q{k + 1}'] = ce_q + log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q) + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + if self.model_ema is not None: + self.model_ema.update() + + def export_model(self, path, use_safetensors=False): + + model = self.model_ema.ema_model if self.model_ema is not None else self.model + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + + +class AudioLanguageModelDemoCallback(pl.Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + sample_size=65536, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + **kwargs + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_samples = sample_size + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.demo_conditioning = demo_conditioning + self.demo_cfg_scales = demo_cfg_scales + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + module.eval() + + print(f"Generating demo") + self.last_demo_step = trainer.global_step + + demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio + + #demo_reals = batch[0][:self.num_demos] + + # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + # demo_reals = demo_reals[0] + + #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) + + ##Limit to first 50 tokens + #demo_reals_tokens = demo_reals_tokens[:, :, :50] + + try: + print("Getting conditioning") + + for cfg_scale in self.demo_cfg_scales: + + model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model + + print(f"Generating demo for cfg scale {cfg_scale}") + fakes = model.generate_audio( + batch_size=self.num_demos, + max_gen_len=demo_length_tokens, + conditioning=self.demo_conditioning, + #init_data = demo_reals_tokens, + cfg_scale=cfg_scale, + temp=1.0, + top_p=0.95 + ) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes / fakes.abs().max() + fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + except Exception as e: + raise e + finally: + gc.collect() + torch.cuda.empty_cache() + module.train() \ No newline at end of file diff --git a/ThinkSound/training/lm_continuous.py b/ThinkSound/training/lm_continuous.py new file mode 100644 index 0000000000000000000000000000000000000000..0ecc1a92336a0623f3f9b1c455a1f8198e4cacb8 --- /dev/null +++ b/ThinkSound/training/lm_continuous.py @@ -0,0 +1,294 @@ +import pytorch_lightning as pl +import sys, gc +import random +import torch +import torchaudio +import typing as tp +import wandb + +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image +from ema_pytorch import EMA +from einops import rearrange +from safetensors.torch import save_file +from torch import optim +from torch.nn import functional as F +from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper + +from ..models.lm import AudioLMContinuousModelWrapper +from .utils import create_optimizer_from_config, create_scheduler_from_config + +class AudioLMContinuousModelTrainingWrapper(pl.LightningModule): + def __init__( + self, + model: AudioLanguageModelWrapper, + lr = 1e-4, + diffusion_objective: tp.Literal["rectified_flow", "v"] = "v", + timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", + use_ema=False, + ema_copy=None, + optimizer_configs: dict = None, + diffusion_batch_mul=4, + pre_encoded=False + ): + super().__init__() + + self.model = model + self.diffusion = diffusion + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.model.pretransform.requires_grad_(False) + + self.timestep_sampler = timestep_sampler + + self.diffusion_objective = model.diffusion_objective + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.losses = MultiLoss(loss_modules) + + self.model_ema = None + if use_ema: + self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10) + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "lm": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (0.9, 0.95), + "weight_decay": 0.1 + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + + + self.optimizer_configs = optimizer_configs + + self.diffusion_batch_mul = diffusion_batch_mul + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + lm_opt_config = self.optimizer_configs['lm'] + opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters()) + + if "scheduler" in lm_opt_config: + sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm) + sched_lm_config = { + "scheduler": sched_lm, + "interval": "step" + } + return [opt_lm], [sched_lm_config] + + return [opt_lm] + + # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license + # License can be found in LICENSES/LICENSE_META.txt + + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + diffusion_input = reals + + loss_info = {} + + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + if self.diffusion.pretransform is not None: + if not self.pre_encoded: + with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + loss_info["reals"] = diffusion_input + + padding_masks = [] + for md in metadata: + if md["padding_mask"].ndim == 1: + padding_masks.append(md["padding_mask"]) + else: + padding_masks.append(md["padding_mask"][0]) + + padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length) + + condition_tensors = None + + # If the model is conditioned, get the conditioning tensors + if self.model.conditioner is not None: + with torch.cuda.amp.autocast(): + condition_tensors = self.model.conditioner(metadata, self.device) + + z = self.model.compute_logits(diffusion_input, condition_tensors=condition_tensors, cfg_dropout_prob=0.1) + bsz, seq_len, _ = z.shape + gt_inputs = diffusion_input.clone().detach() + gt_inputs = gt_inputs.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1) + z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1) + mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul) + if self.timestep_sampler == "uniform": + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(z.shape[0])[:, 0].to(self.device) + elif self.timestep_sampler == "logit_normal": + t = torch.sigmoid(torch.randn(z.shape[0], device=self.device)) + + # Calculate the noise schedule parameters for those timesteps + if self.diffusion_objective == "v": + alphas, sigmas = get_alphas_sigmas(t) + elif self.diffusion_objective == "rectified_flow": + alphas, sigmas = 1-t, t + + # Combine the ground truth data and the noise + alphas = alphas[:, None] + sigmas = sigmas[:, None] + + noise = torch.randn_like(gt_inputs) + noised_inputs = gt_inputs * alphas + noise * sigmas + if self.diffusion_objective == "v": + targets = noise * alphas - gt_inputs * sigmas + elif self.diffusion_objective == "rectified_flow": + targets = noise - gt_inputs + cond = {} + cond['z'] = z + with torch.cuda.amp.autocast(): + v = self.diffusion(noised_inputs, t, cond=cond) + + loss_info.update({ + "v": v, + "targets": targets + }) + + loss, losses = self.losses() + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + if self.model_ema is not None: + self.model_ema.update() + + def export_model(self, path, use_safetensors=False): + + model = self.model_ema.ema_model if self.model_ema is not None else self.model + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + + +class AudioLanguageModelDemoCallback(pl.Callback):loss_info + def __init__(self, + demo_every=2000, + num_demos=8, + sample_size=65536, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + **kwargs + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_samples = sample_size + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.demo_conditioning = demo_conditioning + self.demo_cfg_scales = demo_cfg_scales + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + module.eval() + + print(f"Generating demo") + self.last_demo_step = trainer.global_step + + demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio + + #demo_reals = batch[0][:self.num_demos] + + # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + # demo_reals = demo_reals[0] + + #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) + + ##Limit to first 50 tokens + #demo_reals_tokens = demo_reals_tokens[:, :, :50] + + try: + print("Getting conditioning") + + for cfg_scale in self.demo_cfg_scales: + + model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model + + print(f"Generating demo for cfg scale {cfg_scale}") + fakes = model.generate_audio( + batch_size=self.num_demos, + max_gen_len=demo_length_tokens, + conditioning=self.demo_conditioning, + #init_data = demo_reals_tokens, + cfg_scale=cfg_scale, + temp=1.0, + top_p=0.95 + ) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes / fakes.abs().max() + fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + except Exception as e: + raise e + finally: + gc.collect() + torch.cuda.empty_cache() + module.train() \ No newline at end of file diff --git a/ThinkSound/training/losses/__init__.py b/ThinkSound/training/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..37fdea0eb6c3190e7001567cfe17dc296bf811e8 --- /dev/null +++ b/ThinkSound/training/losses/__init__.py @@ -0,0 +1 @@ +from .losses import * \ No newline at end of file diff --git a/ThinkSound/training/losses/auraloss.py b/ThinkSound/training/losses/auraloss.py new file mode 100644 index 0000000000000000000000000000000000000000..933ec0ab66cb6f8fbb26e249f96b09fb1982a132 --- /dev/null +++ b/ThinkSound/training/losses/auraloss.py @@ -0,0 +1,691 @@ +# Copied and modified from https://github.com/csteinmetz1/auraloss/blob/main/auraloss/freq.py under Apache License 2.0 +# You can find the license at LICENSES/LICENSE_AURALOSS.txt + +import torch +import numpy as np +from typing import List, Any +import scipy.signal + +def apply_reduction(losses, reduction="none"): + """Apply reduction to collection of losses.""" + if reduction == "mean": + losses = losses.mean() + elif reduction == "sum": + losses = losses.sum() + return losses + +def compute_direction(w, x, y, z): + # 计算各个声道的权重 + phi = torch.atan2(y, x) + theta = torch.atan2(torch.sqrt(x**2 + y**2), z) + return phi.unsqueeze(1), theta.unsqueeze(1) + +def get_window(win_type: str, win_length: int): + """Return a window function. + + Args: + win_type (str): Window type. Can either be one of the window function provided in PyTorch + ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window'] + or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html). + win_length (int): Window length + + Returns: + win: The window as a 1D torch tensor + """ + + try: + win = getattr(torch, win_type)(win_length) + except: + win = torch.from_numpy(scipy.signal.windows.get_window(win_type, win_length)) + + return win + +class SumAndDifference(torch.nn.Module): + """Sum and difference signal extraction module.""" + + def __init__(self): + """Initialize sum and difference extraction module.""" + super(SumAndDifference, self).__init__() + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, #channels, #samples). + Returns: + Tensor: Sum signal. + Tensor: Difference signal. + """ + if not (x.size(1) == 2): # inputs must be stereo + raise ValueError(f"Input must be stereo: {x.size(1)} channel(s).") + + sum_sig = self.sum(x).unsqueeze(1) + diff_sig = self.diff(x).unsqueeze(1) + + return sum_sig, diff_sig + + @staticmethod + def sum(x): + return x[:, 0, :] + x[:, 1, :] + + @staticmethod + def diff(x): + return x[:, 0, :] - x[:, 1, :] + + +class FIRFilter(torch.nn.Module): + """FIR pre-emphasis filtering module. + + Args: + filter_type (str): Shape of the desired FIR filter ("hp", "fd", "aw"). Default: "hp" + coef (float): Coefficient value for the filter tap (only applicable for "hp" and "fd"). Default: 0.85 + ntaps (int): Number of FIR filter taps for constructing A-weighting filters. Default: 101 + plot (bool): Plot the magnitude respond of the filter. Default: False + + Based upon the perceptual loss pre-empahsis filters proposed by + [Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922). + + A-weighting filter - "aw" + First-order highpass - "hp" + Folded differentiator - "fd" + + Note that the default coefficeint value of 0.85 is optimized for + a sampling rate of 44.1 kHz, considering adjusting this value at differnt sampling rates. + """ + + def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False): + """Initilize FIR pre-emphasis filtering module.""" + super(FIRFilter, self).__init__() + self.filter_type = filter_type + self.coef = coef + self.fs = fs + self.ntaps = ntaps + self.plot = plot + + import scipy.signal + + if ntaps % 2 == 0: + raise ValueError(f"ntaps must be odd (ntaps={ntaps}).") + + if filter_type == "hp": + self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1) + self.fir.weight.requires_grad = False + self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1) + elif filter_type == "fd": + self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1) + self.fir.weight.requires_grad = False + self.fir.weight.data = torch.tensor([1, 0, -coef]).view(1, 1, -1) + elif filter_type == "aw": + # Definition of analog A-weighting filter according to IEC/CD 1672. + f1 = 20.598997 + f2 = 107.65265 + f3 = 737.86223 + f4 = 12194.217 + A1000 = 1.9997 + + NUMs = [(2 * np.pi * f4) ** 2 * (10 ** (A1000 / 20)), 0, 0, 0, 0] + DENs = np.polymul( + [1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2], + [1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2], + ) + DENs = np.polymul( + np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2] + ) + + # convert analog filter to digital filter + b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs) + + # compute the digital filter frequency response + w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs) + + # then we fit to 101 tap FIR filter with least squares + taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs) + + # now implement this digital FIR filter as a Conv1d layer + self.fir = torch.nn.Conv1d( + 1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2 + ) + self.fir.weight.requires_grad = False + self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1) + + if plot: + from .plotting import compare_filters + compare_filters(b, a, taps, fs=fs) + + def forward(self, input, target): + """Calculate forward propagation. + Args: + input (Tensor): Predicted signal (B, #channels, #samples). + target (Tensor): Groundtruth signal (B, #channels, #samples). + Returns: + Tensor: Filtered signal. + """ + input = torch.nn.functional.conv1d( + input, self.fir.weight.data, padding=self.ntaps // 2 + ) + target = torch.nn.functional.conv1d( + target, self.fir.weight.data, padding=self.ntaps // 2 + ) + return input, target + +class SpectralConvergenceLoss(torch.nn.Module): + """Spectral convergence loss module. + + See [Arik et al., 2018](https://arxiv.org/abs/1808.06719). + """ + + def __init__(self): + super(SpectralConvergenceLoss, self).__init__() + + def forward(self, x_mag, y_mag): + return (torch.norm(y_mag - x_mag, p="fro", dim=[-1, -2]) / torch.norm(y_mag, p="fro", dim=[-1, -2])).mean() + +class STFTMagnitudeLoss(torch.nn.Module): + """STFT magnitude loss module. + + See [Arik et al., 2018](https://arxiv.org/abs/1808.06719) + and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1) + + Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the + compression strength (larger value results in more compression), and `log_eps` can be used + to control the range of the compressed output values (e.g., `log_eps>=1` ensures positive + output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression. + + Args: + log (bool, optional): Log-scale the STFT magnitudes, + or use linear scale. Default: True + log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm. + Default: 0.0 + log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm. + Default: 1.0 + distance (str, optional): Distance function ["L1", "L2"]. Default: "L1" + reduction (str, optional): Reduction of the loss elements. Default: "mean" + """ + + def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"): + super(STFTMagnitudeLoss, self).__init__() + + self.log = log + self.log_eps = log_eps + self.log_fac = log_fac + + if distance == "L1": + self.distance = torch.nn.L1Loss(reduction=reduction) + elif distance == "L2": + self.distance = torch.nn.MSELoss(reduction=reduction) + else: + raise ValueError(f"Invalid distance: '{distance}'.") + + def forward(self, x_mag, y_mag): + if self.log: + x_mag = torch.log(self.log_fac * x_mag + self.log_eps) + y_mag = torch.log(self.log_fac * y_mag + self.log_eps) + return self.distance(x_mag, y_mag) + + +class STFTLoss(torch.nn.Module): + """STFT loss module. + + See [Yamamoto et al. 2019](https://arxiv.org/abs/1904.04472). + + Args: + fft_size (int, optional): FFT size in samples. Default: 1024 + hop_size (int, optional): Hop size of the FFT in samples. Default: 256 + win_length (int, optional): Length of the FFT analysis window. Default: 1024 + window (str, optional): Window to apply before FFT, can either be one of the window function provided in PyTorch + ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window'] + or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html). + Default: 'hann_window' + w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0 + w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0 + w_lin_mag_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0 + w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0 + sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None + scale (str, optional): Optional frequency scaling method, options include: + ['mel', 'chroma'] + Default: None + n_bins (int, optional): Number of scaling frequency bins. Default: None. + perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False + scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False + eps (float, optional): Small epsilon value for stablity. Default: 1e-8 + output (str, optional): Format of the loss returned. + 'loss' : Return only the raw, aggregate loss term. + 'full' : Return the raw loss, plus intermediate loss terms. + Default: 'loss' + reduction (str, optional): Specifies the reduction to apply to the output: + 'none': no reduction will be applied, + 'mean': the sum of the output will be divided by the number of elements in the output, + 'sum': the output will be summed. + Default: 'mean' + mag_distance (str, optional): Distance function ["L1", "L2"] for the magnitude loss terms. + device (str, optional): Place the filterbanks on specified device. Default: None + + Returns: + loss: + Aggreate loss term. Only returned if output='loss'. By default. + loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss: + Aggregate and intermediate loss terms. Only returned if output='full'. + """ + + def __init__( + self, + fft_size: int = 1024, + hop_size: int = 256, + win_length: int = 1024, + window: str = "hann_window", + w_sc: float = 1.0, + w_log_mag: float = 1.0, + w_lin_mag: float = 0.0, + w_phs: float = 0.0, + sample_rate: float = None, + scale: str = None, + n_bins: int = None, + perceptual_weighting: bool = False, + scale_invariance: bool = False, + eps: float = 1e-8, + output: str = "loss", + reduction: str = "mean", + mag_distance: str = "L1", + device: Any = None, + **kwargs + ): + super().__init__() + self.fft_size = fft_size + self.hop_size = hop_size + self.win_length = win_length + self.window = get_window(window, win_length) + self.w_sc = w_sc + self.w_log_mag = w_log_mag + self.w_lin_mag = w_lin_mag + self.w_phs = w_phs + self.sample_rate = sample_rate + self.scale = scale + self.n_bins = n_bins + self.perceptual_weighting = perceptual_weighting + self.scale_invariance = scale_invariance + self.eps = eps + self.output = output + self.reduction = reduction + self.mag_distance = mag_distance + self.device = device + + self.phs_used = bool(self.w_phs) + + self.spectralconv = SpectralConvergenceLoss() + self.logstft = STFTMagnitudeLoss( + log=True, + reduction=reduction, + distance=mag_distance, + **kwargs + ) + self.linstft = STFTMagnitudeLoss( + log=False, + reduction=reduction, + distance=mag_distance, + **kwargs + ) + + # setup mel filterbank + if scale is not None: + try: + import librosa.filters + except Exception as e: + print(e) + print("Try `pip install auraloss[all]`.") + + if self.scale == "mel": + assert sample_rate != None # Must set sample rate to use mel scale + assert n_bins <= fft_size # Must be more FFT bins than Mel bins + fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins) + fb = torch.tensor(fb).unsqueeze(0) + + elif self.scale == "chroma": + assert sample_rate != None # Must set sample rate to use chroma scale + assert n_bins <= fft_size # Must be more FFT bins than chroma bins + fb = librosa.filters.chroma( + sr=sample_rate, n_fft=fft_size, n_chroma=n_bins + ) + + else: + raise ValueError( + f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'." + ) + + self.register_buffer("fb", fb) + + if scale is not None and device is not None: + self.fb = self.fb.to(self.device) # move filterbank to device + + if self.perceptual_weighting: + if sample_rate is None: + raise ValueError( + f"`sample_rate` must be supplied when `perceptual_weighting = True`." + ) + self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate) + + def stft(self, x): + """Perform STFT. + Args: + x (Tensor): Input signal tensor (B, T). + + Returns: + Tensor: x_mag, x_phs + Magnitude and phase spectra (B, fft_size // 2 + 1, frames). + """ + x_stft = torch.stft( + x, + self.fft_size, + self.hop_size, + self.win_length, + self.window, + return_complex=True, + ) + x_mag = torch.sqrt( + torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps) + ) + + # torch.angle is expensive, so it is only evaluated if the values are used in the loss + if self.phs_used: + x_phs = torch.angle(x_stft) + else: + x_phs = None + + return x_mag, x_phs + + def forward(self, input: torch.Tensor, target: torch.Tensor): + bs, chs, seq_len = input.size() + + if self.perceptual_weighting: # apply optional A-weighting via FIR filter + # since FIRFilter only support mono audio we will move channels to batch dim + input = input.view(bs * chs, 1, -1) + target = target.view(bs * chs, 1, -1) + + # now apply the filter to both + self.prefilter.to(input.device) + input, target = self.prefilter(input, target) + + # now move the channels back + input = input.view(bs, chs, -1) + target = target.view(bs, chs, -1) + + # compute the magnitude and phase spectra of input and target + self.window = self.window.to(input.device) + + x_mag, x_phs = self.stft(input.view(-1, input.size(-1))) + y_mag, y_phs = self.stft(target.view(-1, target.size(-1))) + + # apply relevant transforms + if self.scale is not None: + self.fb = self.fb.to(input.device) + x_mag = torch.matmul(self.fb, x_mag) + y_mag = torch.matmul(self.fb, y_mag) + + # normalize scales + if self.scale_invariance: + alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag**2).sum([-2, -1])) + y_mag = y_mag * alpha.unsqueeze(-1) + + # compute loss terms + sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0 + log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0 + lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0 + phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0 + + # combine loss terms + loss = ( + (self.w_sc * sc_mag_loss) + + (self.w_log_mag * log_mag_loss) + + (self.w_lin_mag * lin_mag_loss) + + (self.w_phs * phs_loss) + ) + + loss = apply_reduction(loss, reduction=self.reduction) + + if self.output == "loss": + return loss + elif self.output == "full": + return loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss + +class MultiResolutionSTFTLoss(torch.nn.Module): + """Multi resolution STFT loss module. + + See [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480) + + Args: + fft_sizes (list): List of FFT sizes. + hop_sizes (list): List of hop sizes. + win_lengths (list): List of window lengths. + window (str, optional): Window to apply before FFT, options include: + 'hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window'] + Default: 'hann_window' + w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0 + w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0 + w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0 + w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0 + sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None + scale (str, optional): Optional frequency scaling method, options include: + ['mel', 'chroma'] + Default: None + n_bins (int, optional): Number of mel frequency bins. Required when scale = 'mel'. Default: None. + scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False + """ + + def __init__( + self, + fft_sizes: List[int] = [1024, 2048, 512], + hop_sizes: List[int] = [120, 240, 50], + win_lengths: List[int] = [600, 1200, 240], + window: str = "hann_window", + w_sc: float = 1.0, + w_log_mag: float = 1.0, + w_lin_mag: float = 0.0, + w_phs: float = 0.0, + sample_rate: float = None, + scale: str = None, + n_bins: int = None, + perceptual_weighting: bool = False, + scale_invariance: bool = False, + **kwargs, + ): + super().__init__() + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all + self.fft_sizes = fft_sizes + self.hop_sizes = hop_sizes + self.win_lengths = win_lengths + + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses += [ + STFTLoss( + fs, + ss, + wl, + window, + w_sc, + w_log_mag, + w_lin_mag, + w_phs, + sample_rate, + scale, + n_bins, + perceptual_weighting, + scale_invariance, + **kwargs, + ) + ] + + def forward(self, x, y): + mrstft_loss = 0.0 + sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss = [], [], [], [] + # import ipdb + # ipdb.set_trace() + for f in self.stft_losses: + if f.output == "full": # extract just first term + tmp_loss = f(x, y) + mrstft_loss += tmp_loss[0] + sc_mag_loss.append(tmp_loss[1]) + log_mag_loss.append(tmp_loss[2]) + lin_mag_loss.append(tmp_loss[3]) + phs_loss.append(tmp_loss[4]) + else: + mrstft_loss += f(x, y) + + mrstft_loss /= len(self.stft_losses) + + if f.output == "loss": + return mrstft_loss + else: + return mrstft_loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss + + +class SumAndDifferenceSTFTLoss(torch.nn.Module): + """Sum and difference sttereo STFT loss module. + + See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291) + + Args: + fft_sizes (List[int]): List of FFT sizes. + hop_sizes (List[int]): List of hop sizes. + win_lengths (List[int]): List of window lengths. + window (str, optional): Window function type. + w_sum (float, optional): Weight of the sum loss component. Default: 1.0 + w_diff (float, optional): Weight of the difference loss component. Default: 1.0 + perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False + mel_stft (bool, optional): Use Multi-resoltuion mel spectrograms. Default: False + n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128 + sample_rate (float, optional): Audio sample rate. Default: None + output (str, optional): Format of the loss returned. + 'loss' : Return only the raw, aggregate loss term. + 'full' : Return the raw loss, plus intermediate loss terms. + Default: 'loss' + """ + + def __init__( + self, + fft_sizes: List[int], + hop_sizes: List[int], + win_lengths: List[int], + window: str = "hann_window", + w_sum: float = 1.0, + w_diff: float = 1.0, + output: str = "loss", + **kwargs, + ): + super().__init__() + self.sd = SumAndDifference() + self.w_sum = w_sum + self.w_diff = w_diff + self.output = output + self.mrstft = MultiResolutionSTFTLoss( + fft_sizes, + hop_sizes, + win_lengths, + window, + **kwargs, + ) + + def forward(self, input: torch.Tensor, target: torch.Tensor): + """This loss function assumes batched input of stereo audio in the time domain. + + Args: + input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len). + target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len). + + Returns: + loss (torch.Tensor): Aggreate loss term. Only returned if output='loss'. + loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor): + Aggregate and intermediate loss terms. Only returned if output='full'. + """ + assert input.shape == target.shape # must have same shape + bs, chs, seq_len = input.size() + + # compute sum and difference signals for both + input_sum, input_diff = self.sd(input) + target_sum, target_diff = self.sd(target) + + # compute error in STFT domain + sum_loss = self.mrstft(input_sum, target_sum) + diff_loss = self.mrstft(input_diff, target_diff) + loss = ((self.w_sum * sum_loss) + (self.w_diff * diff_loss)) / 2 + + if self.output == "loss": + return loss + elif self.output == "full": + return loss, sum_loss, diff_loss + +class SpatialSTFTLoss(torch.nn.Module): + """Sum and difference sttereo STFT loss module. + + See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291) + + Args: + fft_sizes (List[int]): List of FFT sizes. + hop_sizes (List[int]): List of hop sizes. + win_lengths (List[int]): List of window lengths. + window (str, optional): Window function type. + w_sum (float, optional): Weight of the sum loss component. Default: 1.0 + w_diff (float, optional): Weight of the difference loss component. Default: 1.0 + perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False + mel_stft (bool, optional): Use Multi-resoltuion mel spectrograms. Default: False + n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128 + sample_rate (float, optional): Audio sample rate. Default: None + output (str, optional): Format of the loss returned. + 'loss' : Return only the raw, aggregate loss term. + 'full' : Return the raw loss, plus intermediate loss terms. + Default: 'loss' + """ + + def __init__( + self, + fft_sizes: List[int], + hop_sizes: List[int], + win_lengths: List[int], + window: str = "hann_window", + w_phi: float = 1.0, + w_theta: float = 1.0, + output: str = "loss", + **kwargs, + ): + super().__init__() + self.w_phi = w_phi + self.w_theta = w_theta + self.output = output + self.mrstft = MultiResolutionSTFTLoss( + fft_sizes, + hop_sizes, + win_lengths, + window, + **kwargs, + ) + + + def forward(self, input: torch.Tensor, target: torch.Tensor): + """This loss function assumes batched input of stereo audio in the time domain. + + Args: + input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len). + target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len). + + Returns: + loss (torch.Tensor): Aggreate loss term. Only returned if output='loss'. + loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor): + Aggregate and intermediate loss terms. Only returned if output='full'. + """ + assert input.shape == target.shape # must have same shape + bs, chs, seq_len = input.size() + + w_o, x_o, y_o, z_o = input[:, 0], input[:, 1], input[:, 2], input[:, 3] + w_r, x_r, y_r, z_r = target[:, 0], target[:, 1], target[:, 2], target[:, 3] + + phi_o, theta_o = compute_direction(w_o, x_o, y_o, z_o) + phi_r, theta_r = compute_direction(w_r, x_r, y_r, z_r) + + # compute error in STFT domain + phi_loss = self.mrstft(phi_o, phi_r) + theta_loss = self.mrstft(theta_o, theta_r) + loss = ((self.w_phi * phi_loss) + (self.w_theta * theta_loss)) / 2 + + if self.output == "loss": + return loss + elif self.output == "full": + return loss, sum_loss, diff_loss \ No newline at end of file diff --git a/ThinkSound/training/losses/losses.py b/ThinkSound/training/losses/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..7285850c3ff873e0dda6a83265536dcb0bcb5b4f --- /dev/null +++ b/ThinkSound/training/losses/losses.py @@ -0,0 +1,100 @@ +import typing as tp + +from torch.nn import functional as F +from torch import nn + +class LossModule(nn.Module): + def __init__(self, name: str, weight: float = 1.0): + super().__init__() + + self.name = name + self.weight = weight + + def forward(self, info, *args, **kwargs): + raise NotImplementedError + +class ValueLoss(LossModule): + def __init__(self, key: str, name, weight: float = 1.0): + super().__init__(name=name, weight=weight) + + self.key = key + + def forward(self, info): + return self.weight * info[self.key] + +class L1Loss(LossModule): + def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss'): + super().__init__(name=name, weight=weight) + + self.key_a = key_a + self.key_b = key_b + + self.mask_key = mask_key + + def forward(self, info): + mse_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none') + + if self.mask_key is not None and self.mask_key in info: + mse_loss = mse_loss[info[self.mask_key]] + + mse_loss = mse_loss.mean() + + return self.weight * mse_loss + +class MSELoss(LossModule): + def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'): + super().__init__(name=name, weight=weight) + + self.key_a = key_a + self.key_b = key_b + + self.mask_key = mask_key + + def forward(self, info): + mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none') + if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None: + mask = info[self.mask_key] + + if mask.ndim == 2 and mse_loss.ndim == 3: + mask = mask.unsqueeze(1) + + if mask.shape[1] != mse_loss.shape[1]: + mask = mask.repeat(1, mse_loss.shape[1], 1) + + mse_loss = mse_loss[mask] + + mse_loss = mse_loss.mean() + + return self.weight * mse_loss + +class AuralossLoss(LossModule): + def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1): + super().__init__(name, weight) + + self.auraloss_module = auraloss_module + + self.input_key = input_key + self.target_key = target_key + + def forward(self, info): + loss = self.auraloss_module(info[self.input_key], info[self.target_key]) + + return self.weight * loss + +class MultiLoss(nn.Module): + def __init__(self, losses: tp.List[LossModule]): + super().__init__() + + self.losses = nn.ModuleList(losses) + + def forward(self, info): + total_loss = 0 + + losses = {} + + for loss_module in self.losses: + module_loss = loss_module(info) + total_loss += module_loss + losses[loss_module.name] = module_loss + + return total_loss, losses \ No newline at end of file diff --git a/ThinkSound/training/utils.py b/ThinkSound/training/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..16d5f5616be13407d99424956328dbebc5ff29bd --- /dev/null +++ b/ThinkSound/training/utils.py @@ -0,0 +1,232 @@ +import torch +import os +from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor +import random +import wandb +from lightning.pytorch.loggers import WandbLogger, CometLogger, TensorBoardLogger +from ..interface.aeiou import pca_point_cloud +def get_rank(): + """Get rank of current process.""" + + print(os.environ.keys()) + + if "SLURM_PROCID" in os.environ: + return int(os.environ["SLURM_PROCID"]) + + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return 0 + + return torch.distributed.get_rank() + +class InverseLR(torch.optim.lr_scheduler._LRScheduler): + """Implements an inverse decay learning rate schedule with an optional exponential + warmup. When last_epoch=-1, sets initial lr as lr. + inv_gamma is the number of steps/epochs required for the learning rate to decay to + (1 / 2)**power of its original value. + Args: + optimizer (Optimizer): Wrapped optimizer. + inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. + power (float): Exponential factor of learning rate decay. Default: 1. + warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) + Default: 0. + final_lr (float): The final learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., + last_epoch=-1, verbose=False): + self.inv_gamma = inv_gamma + self.power = power + if not 0. <= warmup < 1: + raise ValueError('Invalid value for warmup') + self.warmup = warmup + self.final_lr = final_lr + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + import warnings + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + warmup = 1 - self.warmup ** (self.last_epoch + 1) + lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power + return [warmup * max(self.final_lr, base_lr * lr_mult) + for base_lr in self.base_lrs] + +def copy_state_dict(model, state_dict): + """Load state_dict to model, but only for keys that match exactly. + + Args: + model (nn.Module): model to load state_dict. + state_dict (OrderedDict): state_dict to load. + """ + model_state_dict = model.state_dict() + # 创建一个列表存储不匹配的参数 + missing_keys = [] + unexpected_keys = [] + # 手动加载并检查不匹配的参数 + for key in state_dict: + if key not in model_state_dict: + unexpected_keys.append(key) + elif state_dict[key].shape != model_state_dict[key].shape: + unexpected_keys.append(key) + + for key in model_state_dict: + if key not in state_dict: + missing_keys.append(key) + + # 打印不匹配的参数 + print("Missing keys in state_dict:", missing_keys) + print("Unexpected keys in state_dict:", unexpected_keys) + for key in state_dict: + if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape: + if isinstance(state_dict[key], torch.nn.Parameter): + # backwards compatibility for serialized parameters + state_dict[key] = state_dict[key].data + model_state_dict[key] = state_dict[key] + + model.load_state_dict(model_state_dict, strict=False) + +def create_optimizer_from_config(optimizer_config, parameters): + """Create optimizer from config. + + Args: + parameters (iterable): parameters to optimize. + optimizer_config (dict): optimizer config. + + Returns: + torch.optim.Optimizer: optimizer. + """ + + optimizer_type = optimizer_config["type"] + + if optimizer_type == "FusedAdam": + from deepspeed.ops.adam import FusedAdam + optimizer = FusedAdam(parameters, **optimizer_config["config"]) + else: + optimizer_fn = getattr(torch.optim, optimizer_type) + optimizer = optimizer_fn(parameters, **optimizer_config["config"]) + return optimizer + +def create_scheduler_from_config(scheduler_config, optimizer): + """Create scheduler from config. + + Args: + scheduler_config (dict): scheduler config. + optimizer (torch.optim.Optimizer): optimizer. + + Returns: + torch.optim.lr_scheduler._LRScheduler: scheduler. + """ + if scheduler_config["type"] == "InverseLR": + scheduler_fn = InverseLR + else: + scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"]) + scheduler = scheduler_fn(optimizer, **scheduler_config["config"]) + return scheduler + +# mask construction helpers + +def mask_from_start_end_indices( + seq_len: int, + start: Tensor, + end: Tensor +): + assert start.shape == end.shape + device = start.device + + seq = torch.arange(seq_len, device = device, dtype = torch.long) + seq = seq.reshape(*((-1,) * start.ndim), seq_len) + seq = seq.expand(*start.shape, seq_len) + + mask = seq >= start[..., None].long() + mask &= seq < end[..., None].long() + return mask + +def mask_from_frac_lengths( + seq_len: int, + frac_lengths: Tensor +): + device = frac_lengths.device + + lengths = (frac_lengths * seq_len).long() + max_start = seq_len - lengths + + rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1) + start = (max_start * rand).clamp(min = 0) + end = start + lengths + + return mask_from_start_end_indices(seq_len, start, end) + +def generate_mask(batch_size, seq_len, frac_lengths, min_span_len): + # 计算需要掩盖的起始数量 + n_mask = (frac_lengths * seq_len // min_span_len).long() # 每个 span 为 10 + # 初始化掩码张量,初始为全 0(未掩盖) + mask_tensor = torch.zeros((batch_size, seq_len), device=frac_lengths.device, dtype=torch.bool) + + for b in range(batch_size): + # 随机挑选起始帧 + start_frames = random.sample(range(0, seq_len - min_span_len + 1), n_mask[b]) # 0 到 seq_len-10 的范围 + + for start in start_frames: + # 将 span 为 10 的区域标记为 1(掩盖) + mask_tensor[b, start:start + 10] = 1.0 + + return mask_tensor + +def generate_channel_mask(diffusion_input): + + # 如果 r_drop 小于 threshold,则对每个样本选择一个随机声道进行完全 mask + batchsize, num_channels, dim = diffusion_input.shape + for i in range(batchsize): + channel_means = torch.mean(torch.abs(diffusion_input[i]), dim=1) # Mean of the absolute values for each channel + # Determine if any channel is 'small enough' + if torch.all(channel_means > 0.01): + # If all channels are not 'small enough', apply the mask + channel = torch.randint(num_channels, (1,)).item() + diffusion_input[i, channel, :] = 1e-8 # Mask the channel by setting its values + else: + # Optionally log that at least one channel is 'small enough' and no mask is applied + print(f"Sample {i}: At least one channel is 'small enough', skipping masking.") + + return diffusion_input + +def logger_project_name(logger) -> str: + if isinstance(logger, WandbLogger): + return logger.experiment.project + elif isinstance(logger, CometLogger): + return logger.name + +def log_metric(logger, key, value, step=None): + from pytorch_lightning.loggers import WandbLogger, CometLogger + if isinstance(logger, WandbLogger): + logger.experiment.log({key: value}) + elif isinstance(logger, CometLogger): + logger.experiment.log_metric({key: value}, step=step) + +def log_audio(logger, key, audio_path, sample_rate, caption=None): + if isinstance(logger, WandbLogger): + logger.experiment.log({key: wandb.Audio(audio_path, sample_rate=sample_rate, caption=caption)}) + elif isinstance(logger, CometLogger): + logger.experiment.log_audio(audio_path, file_name=key, sample_rate=sample_rate) + +def log_image(logger, key, img_data): + if isinstance(logger, WandbLogger): + logger.experiment.log({key: wandb.Image(img_data)}) + elif isinstance(logger, CometLogger): + logger.experiment.log_image(img_data, name=key) + +def log_point_cloud(logger, key, tokens, caption=None): + if isinstance(logger, WandbLogger): + point_cloud = pca_point_cloud(tokens) + logger.experiment.log({key: point_cloud}) + elif isinstance(logger, CometLogger): + point_cloud = pca_point_cloud(tokens, rgb_float=True, output_type="points") + #logger.experiment.log_points_3d(scene_name=key, points=point_cloud) + diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..81b9011b066a970b993bc713d45565e3a18ee503 --- /dev/null +++ b/app.py @@ -0,0 +1,642 @@ +import os +# ⭐ Must be set before importing gradio +import subprocess +import sys + +subprocess.run(["bash", "setup.sh"], check=True) + +os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".gradio_tmp") +os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True) +os.environ["JAX_PLATFORMS"] = "cpu" +import gradio as gr +import logging +import sys +import json +import torch +import torchaudio +import numpy as np +import tempfile +import shutil +import subprocess +from pathlib import Path +import torch.nn.functional as F +import mediapy +from torio.io import StreamingMediaDecoder +from torchvision.transforms import v2 +import time +import random +seed=42 +random.seed(seed) +np.random.seed(seed) +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) + + +try: + from moviepy import VideoFileClip +except ImportError: + from moviepy.editor import VideoFileClip + +# ==================== Logging ==================== +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +log = logging.getLogger() + +# ==================== Constants ==================== +_CLIP_FPS = 4 +_CLIP_SIZE = 288 +_SYNC_FPS = 25 +_SYNC_SIZE = 224 +SAMPLE_RATE = 44100 + +# ==================== Model Path Configuration ==================== +from huggingface_hub import snapshot_download +snapshot_download(repo_id="FunAudioLLM/PrismAudio", local_dir="./ckpts") + +MODEL_CONFIG_PATH = "ThinkSound/configs/model_configs/prismaudio.json" +CKPT_PATH = "ckpts/prismaudio.ckpt" +VAE_CKPT_PATH = "ckpts/vae.ckpt" +VAE_CONFIG_PATH = "ThinkSound/configs/model_configs/stable_audio_2_0_vae.json" +SYNCHFORMER_CKPT_PATH = "ckpts/synchformer_state_dict.pth" +DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' + +# ==================== Global Model Registry ==================== +_MODELS = { + "feature_extractor": None, + "diffusion": None, + "model_config": None, + "sync_transform": None, +} + + +def load_all_models(): + """Load all models once at application startup.""" + global _MODELS + + log.info("=" * 50) + log.info("Loading all models...") + + # ---- 1. Sync video transform ---- + _MODELS["sync_transform"] = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + log.info("✅ sync_transform ready") + + # ---- 2. FeaturesUtils ---- + from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils + + feature_extractor = FeaturesUtils( + vae_ckpt=None, + vae_config=VAE_CONFIG_PATH, + enable_conditions=True, + synchformer_ckpt=SYNCHFORMER_CKPT_PATH, + ) + feature_extractor = feature_extractor.eval().to(DEVICE) + _MODELS["feature_extractor"] = feature_extractor + log.info("✅ FeaturesUtils loaded") + + # ---- 3. Diffusion model ---- + from ThinkSound.models import create_model_from_config + from ThinkSound.models.utils import load_ckpt_state_dict + + with open(MODEL_CONFIG_PATH) as f: + model_config = json.load(f) + _MODELS["model_config"] = model_config + + diffusion = create_model_from_config(model_config) + diffusion.load_state_dict(torch.load(CKPT_PATH, map_location='cpu')) + + vae_state = load_ckpt_state_dict(VAE_CKPT_PATH, prefix='autoencoder.') + diffusion.pretransform.load_state_dict(vae_state) + + diffusion = diffusion.eval().to(DEVICE) + _MODELS["diffusion"] = diffusion + log.info("✅ Diffusion model loaded") + + log.info("=" * 50) + log.info("All models ready. Waiting for inference requests.") + + +# ==================== Video Utilities ==================== + +def get_video_duration(video_path: str) -> float: + video = VideoFileClip(str(video_path)) + duration = video.duration + video.close() + return duration + + +def convert_to_mp4(src: str, dst: str) -> tuple[bool, str]: + """Re-encode any video format to h264/aac mp4 via ffmpeg.""" + result = subprocess.run( + [ + "ffmpeg", "-y", "-i", src, + "-c:v", "libx264", "-preset", "fast", + "-c:a", "aac", "-strict", "experimental", + dst, + ], + capture_output=True, + text=True, + ) + return result.returncode == 0, result.stderr + + +def combine_audio_video(video_path: str, audio_path: str, output_path: str) -> tuple[bool, str]: + """Mux generated audio into the original silent video via ffmpeg.""" + result = subprocess.run( + [ + "ffmpeg", "-y", + "-i", video_path, + "-i", audio_path, + "-c:v", "copy", + "-c:a", "aac", "-strict", "experimental", + "-map", "0:v:0", + "-map", "1:a:0", + "-shortest", + output_path, + ], + capture_output=True, + text=True, + ) + return result.returncode == 0, result.stderr + + +def pad_to_square(video_tensor: torch.Tensor) -> torch.Tensor: + """(L, C, H, W) -> (L, C, _CLIP_SIZE, _CLIP_SIZE)""" + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (L, C, H, W)") + l, c, h, w = video_tensor.shape + max_side = max(h, w) + pad_h = max_side - h + pad_w = max_side - w + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + return F.interpolate( + video_padded, size=(_CLIP_SIZE, _CLIP_SIZE), + mode='bilinear', align_corners=False, + ) + + +def extract_video_frames(video_path: str): + """ + Decode clip_chunk and sync_chunk from video entirely in memory. + + Returns: + clip_chunk : (L, H, W, C) float32 [0, 1] + sync_chunk : (L, C, H, W) float32 normalized + duration : float (seconds) + """ + sync_transform = _MODELS["sync_transform"] + assert sync_transform is not None, "Call load_all_models() first." + + duration_sec = get_video_duration(video_path) + + reader = StreamingMediaDecoder(video_path) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + + if clip_chunk is None: + raise RuntimeError("CLIP video stream returned None") + if sync_chunk is None: + raise RuntimeError("Sync video stream returned None") + + # ---- clip_chunk ---- + clip_expected = int(_CLIP_FPS * duration_sec) + clip_chunk = clip_chunk[:clip_expected] + if clip_chunk.shape[0] < clip_expected: + pad_n = clip_expected - clip_chunk.shape[0] + clip_chunk = torch.cat( + [clip_chunk, clip_chunk[-1:].repeat(pad_n, 1, 1, 1)], dim=0 + ) + clip_chunk = pad_to_square(clip_chunk) + clip_chunk = clip_chunk.permute(0, 2, 3, 1) + clip_chunk = mediapy.to_float01(clip_chunk) + + # ---- sync_chunk ---- + sync_expected = int(_SYNC_FPS * duration_sec) + sync_chunk = sync_chunk[:sync_expected] + if sync_chunk.shape[0] < sync_expected: + pad_n = sync_expected - sync_chunk.shape[0] + sync_chunk = torch.cat( + [sync_chunk, sync_chunk[-1:].repeat(pad_n, 1, 1, 1)], dim=0 + ) + sync_chunk = sync_transform(sync_chunk) + + log.info(f"clip_chunk: {clip_chunk.shape}, sync_chunk: {sync_chunk.shape}") + return clip_chunk, sync_chunk, duration_sec + + +# ==================== Feature Extraction ==================== + +def extract_features(clip_chunk: torch.Tensor, sync_chunk: torch.Tensor, caption: str) -> dict: + """Reuses globally loaded FeaturesUtils — no reload per call.""" + model = _MODELS["feature_extractor"] + assert model is not None, "FeaturesUtils not initialized." + + info = {} + with torch.no_grad(): + text_features = model.encode_t5_text([caption]) + info['text_features'] = text_features[0].cpu() + + clip_input = torch.from_numpy(clip_chunk).unsqueeze(0) + video_feat, frame_embed, _, text_feat = \ + model.encode_video_and_text_with_videoprism(clip_input, [caption]) + + info['global_video_features'] = torch.tensor(np.array(video_feat)).squeeze(0).cpu() + info['video_features'] = torch.tensor(np.array(frame_embed)).squeeze(0).cpu() + info['global_text_features'] = torch.tensor(np.array(text_feat)).squeeze(0).cpu() + + sync_input = sync_chunk.unsqueeze(0).to(DEVICE) + info['sync_features'] = model.encode_video_with_sync(sync_input)[0].cpu() + + return info + + +# ==================== Build Meta ==================== + +def build_meta(info: dict, duration: float, caption: str): + latent_length = round(SAMPLE_RATE * duration / 2048) + audio_latent = torch.zeros((1, 64, latent_length), dtype=torch.float32) + + meta = dict(info) + meta['id'] = 'demo' + meta['relpath'] = 'demo.npz' + meta['path'] = 'demo.npz' + meta['caption_cot'] = caption + meta['video_exist'] = torch.tensor(True) + + return audio_latent, meta + + +# ==================== Diffusion Sampling ==================== + +def run_diffusion(audio_latent: torch.Tensor, meta: dict, duration: float) -> torch.Tensor: + """Reuses globally loaded diffusion model — no reload per call.""" + from ThinkSound.inference.sampling import sample, sample_discrete_euler + import time + + diffusion = _MODELS["diffusion"] + model_config = _MODELS["model_config"] + assert diffusion is not None, "Diffusion model not initialized." + + diffusion_objective = model_config["model"]["diffusion"]["diffusion_objective"] + latent_length = round(SAMPLE_RATE * duration / 2048) + + meta_on_device = { + k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v + for k, v in meta.items() + } + metadata = (meta_on_device,) + + with torch.no_grad(): + with torch.amp.autocast('cuda'): + conditioning = diffusion.conditioner(metadata, DEVICE) + + video_exist = torch.stack([item['video_exist'] for item in metadata], dim=0) + if 'metaclip_features' in conditioning: + conditioning['metaclip_features'][~video_exist] = \ + diffusion.model.model.empty_clip_feat + if 'sync_features' in conditioning: + conditioning['sync_features'][~video_exist] = \ + diffusion.model.model.empty_sync_feat + + cond_inputs = diffusion.get_conditioning_inputs(conditioning) + noise = torch.randn([1, diffusion.io_channels, latent_length]).to(DEVICE) + + with torch.amp.autocast('cuda'): + if diffusion_objective == "v": + fakes = sample( + diffusion.model, noise, 24, 0, + **cond_inputs, cfg_scale=5, batch_cfg=True, + ) + elif diffusion_objective == "rectified_flow": + t0 = time.time() + fakes = sample_discrete_euler( + diffusion.model, noise, 24, + **cond_inputs, cfg_scale=5, batch_cfg=True, + ) + log.info(f"Sampling time: {time.time() - t0:.2f}s") + + if diffusion.pretransform is not None: + fakes = diffusion.pretransform.decode(fakes) + + return ( + fakes.to(torch.float32) + .div(torch.max(torch.abs(fakes))) + .clamp(-1, 1) + .mul(32767) + .to(torch.int16) + .cpu() + ) + + +# ==================== Full Inference Pipeline ==================== + +def generate_audio(video_file, caption: str): + start_time =time.time() + + """ + Gradio generator function (yields status + result progressively). + + Yields: + (status_str, combined_video_path_or_None) + """ + # ---- Basic validation ---- + if video_file is None: + yield "❌ Please upload a video file first.", None + return + if not caption or caption.strip() == "": + yield "❌ Please enter a caption / prompt.", None + return + + caption = caption.strip() + logs = [] + + def log_step(msg: str): + log.info(msg) + logs.append(msg) + return "\n".join(logs) + + # ---- Working directory (auto-cleaned on exit) ---- + work_dir = tempfile.mkdtemp(dir=os.environ["GRADIO_TEMP_DIR"], prefix="thinksound_") + + try: + # ---- Step 1: Convert / copy to mp4 ---- + status = log_step("📹 Step 1: Preparing video...") + + yield status, None + + src_ext = os.path.splitext(video_file)[1].lower() + mp4_path = os.path.join(work_dir, "input.mp4") + + if src_ext != ".mp4": + log_step(" Converting to mp4...") + ok, err = convert_to_mp4(video_file, mp4_path) + if not ok: + yield log_step(f"❌ Video conversion failed:\n{err}"), None + return + else: + shutil.copy(video_file, mp4_path) + log_step(" Video ready.") + + # ---- Step 2: Validate duration ---- + status = log_step("📹 Step 2: Checking video duration...") + yield status, None + + duration = get_video_duration(mp4_path) + log_step(f" Duration: {duration:.2f}s") + + # ---- Step 3: Extract video frames ---- + status = log_step("🎞️ Step 3: Extracting video frames (clip & sync)...") + yield status, None + + clip_chunk, sync_chunk, duration = extract_video_frames(mp4_path) + log_step(f" clip_chunk : {tuple(clip_chunk.shape)}") + log_step(f" sync_chunk : {tuple(sync_chunk.shape)}") + + # ---- Step 4: Extract model features ---- + status = log_step("🧠 Step 4: Extracting text / video / sync features...") + yield status, None + + info = extract_features(clip_chunk, sync_chunk, caption) + log_step(f" text_features : {tuple(info['text_features'].shape)}") + log_step(f" global_video_features : {tuple(info['global_video_features'].shape)}") + log_step(f" video_features : {tuple(info['video_features'].shape)}") + log_step(f" global_text_features : {tuple(info['global_text_features'].shape)}") + log_step(f" sync_features : {tuple(info['sync_features'].shape)}") + + # ---- Step 5: Build inference batch ---- + status = log_step("📦 Step 5: Building inference batch...") + yield status, None + + audio_latent, meta = build_meta(info, duration, caption) + log_step(f" audio_latent : {tuple(audio_latent.shape)}") + + # ---- Step 6: Diffusion sampling ---- + status = log_step("🎵 Step 6: Running diffusion sampling...") + yield status, None + + generated_audio = run_diffusion(audio_latent, meta, duration) + log_step(f" Generated audio shape : {tuple(generated_audio.shape)}") + + # ---- Step 7: Save generated audio (temp) ---- + status = log_step("💾 Step 7: Saving generated audio...") + yield status, None + + audio_path = os.path.join(work_dir, "generated_audio.wav") + torchaudio.save( + audio_path, + generated_audio[0], # (1, T) + SAMPLE_RATE, + ) + log_step(f" Audio saved: {audio_path}") + + # ---- Step 8: Mux audio into original video ---- + status = log_step("🎬 Step 8: Merging audio into video...") + yield status, None + + combined_path = os.path.join(work_dir, "output_with_audio.mp4") + ok, err = combine_audio_video(mp4_path, audio_path, combined_path) + if not ok: + yield log_step(f"❌ Failed to combine audio and video:\n{err}"), None + return + + log_step("✅ Done! Audio and video merged successfully.") + yield "\n".join(logs), combined_path + + except Exception as e: + log_step(f"❌ Unexpected error: {str(e)}") + log.exception(e) + yield "\n".join(logs), None + + end_time =time.time() + print("cost: ",end_time-start_time) + + # Note: work_dir is NOT deleted here so Gradio can serve the output file. + # Gradio manages its own GRADIO_TEMP_DIR cleanup on restart. + + +# ==================== Gradio UI ==================== + +def build_ui() -> gr.Blocks: + with gr.Blocks( + title="ThinkSound - Video to Audio Generation", + theme=gr.themes.Soft(), + css=""" + .title { text-align:center; font-size:2em; font-weight:bold; margin-bottom:.2em; } + .sub { text-align:center; color:#666; margin-bottom:1.5em; } + .mono { font-family:monospace; font-size:.85em; } + """, + ) as demo: + + gr.HTML('
🎵 ThinkSound
') + gr.HTML( + '
' + 'Upload a video and a text prompt — ' + 'the generated audio will be merged back into your video.' + '
' + ) + + # ====================================================== + # Row 1 — Inputs + # ====================================================== + with gr.Row(): + + # ---------- Left: inputs ---------- + with gr.Column(scale=1): + gr.Markdown("### 📥 Input") + + video_input = gr.Video( + label="Upload Video", + sources=["upload"], + height=300, + ) + caption_input = gr.Textbox( + label="Caption / Prompt", + placeholder=( + "Describe the audio you want to generate, e.g.:\n" + "A dog barking in the park with wind blowing" + ), + lines=4, + max_lines=8, + ) + with gr.Row(): + clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1) + submit_btn = gr.Button("🚀 Generate Audio", variant="primary", scale=2) + + # ---------- Right: live log ---------- + with gr.Column(scale=1): + gr.Markdown("### 📋 Run Log") + log_output = gr.Textbox( + label="", + lines=18, + max_lines=30, + interactive=False, + elem_classes=["mono"], + ) + + # ====================================================== + # Row 2 — Output video (original video + generated audio) + # ====================================================== + gr.Markdown("---") + gr.Markdown("### 📤 Output — Original video with generated audio") + + video_output = gr.Video( + label="Video + Generated Audio", + interactive=False, + height=400, + ) + + # ====================================================== + # Example prompts + # ====================================================== + gr.Markdown("---") + gr.Markdown("### 💡 Example Prompts (click to fill)") + + # ====================================================== + # Instructions + # ====================================================== + with gr.Accordion("📖 Instructions", open=False): + gr.Markdown(f""" +**Steps** +1. Upload a video file (mp4 / avi / mov / etc.). +2. Enter a text prompt describing the desired audio content. +3. Click **🚀 Generate Audio** and watch the log on the right for progress. +4. The output video (original visuals + generated audio) appears below when done. + +**Notes** +- All models are pre-loaded at startup — no warm-up delay on the first request. +- Everything stays in memory; only the final wav and merged mp4 are written to disk. +- A CUDA GPU is strongly recommended; CPU inference will be very slow. +- Queue depth is limited to 3 concurrent requests to avoid OOM. + +**Current model paths** +``` +MODEL_CONFIG_PATH = {MODEL_CONFIG_PATH} +CKPT_PATH = {CKPT_PATH} +VAE_CKPT_PATH = {VAE_CKPT_PATH} +SYNCHFORMER_CKPT_PATH = {SYNCHFORMER_CKPT_PATH} +``` + """) + + # ====================================================== + # Event bindings + # ====================================================== + submit_btn.click( + fn=generate_audio, + inputs=[video_input, caption_input], + outputs=[log_output, video_output], + show_progress=True, + ) + + def clear_all(): + return None, "", "", None + + clear_btn.click( + fn=clear_all, + inputs=[], + outputs=[video_input, caption_input, log_output, video_output], + ) + + return demo + + +# ==================== Entry Point ==================== + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="ThinkSound Gradio App") + parser.add_argument("--server_name", type=str, default="0.0.0.0", + help="Gradio server host") + parser.add_argument("--server_port", type=int, default=7860, + help="Gradio server port") + parser.add_argument("--share", action="store_true", + help="Create a public Gradio share link") + args = parser.parse_args() + + # ---- Check model files ---- + missing = [] + for name, path in [ + ("Model Config", MODEL_CONFIG_PATH), + ("Checkpoint", CKPT_PATH), + ("VAE Checkpoint", VAE_CKPT_PATH), + ("Synchformer", SYNCHFORMER_CKPT_PATH), + ]: + if not os.path.exists(path): + missing.append(f" ⚠️ {name}: {path}") + + if missing: + log.warning("The following model files were not found — please check your paths:") + for m in missing: + log.warning(m) + else: + log.info("✅ All model files found.") + + # ⭐ Load all models once at startup + load_all_models() + + demo = build_ui() + demo.queue(max_size=3) + demo.launch( + server_name=args.server_name, + server_port=args.server_port, + share=args.share, + show_error=True, + ) diff --git a/data_utils/__init__.py b/data_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data_utils/ext/synchformer/LICENSE b/data_utils/ext/synchformer/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2f70bf24b6f45f458998bdf5746376c4832352ea --- /dev/null +++ b/data_utils/ext/synchformer/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Vladimir Iashin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/data_utils/ext/synchformer/__init__.py b/data_utils/ext/synchformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9eff0160aa046d712d9330c4201b0ccd4c0c51b0 --- /dev/null +++ b/data_utils/ext/synchformer/__init__.py @@ -0,0 +1 @@ +from data_utils.ext.synchformer.synchformer import Synchformer diff --git a/data_utils/ext/synchformer/divided_224_16x4.yaml b/data_utils/ext/synchformer/divided_224_16x4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9d20b76302a8af7928391643bd4b2d184e970aa --- /dev/null +++ b/data_utils/ext/synchformer/divided_224_16x4.yaml @@ -0,0 +1,84 @@ +TRAIN: + ENABLE: True + DATASET: Ssv2 + BATCH_SIZE: 32 + EVAL_PERIOD: 5 + CHECKPOINT_PERIOD: 5 + AUTO_RESUME: True + CHECKPOINT_EPOCH_RESET: True + CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth +DATA: + NUM_FRAMES: 16 + SAMPLING_RATE: 4 + TRAIN_JITTER_SCALES: [256, 320] + TRAIN_CROP_SIZE: 224 + TEST_CROP_SIZE: 224 + INPUT_CHANNEL_NUM: [3] + MEAN: [0.5, 0.5, 0.5] + STD: [0.5, 0.5, 0.5] + PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2 + PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames + INV_UNIFORM_SAMPLE: True + RANDOM_FLIP: False + REVERSE_INPUT_CHANNEL: True + USE_RAND_AUGMENT: True + RE_PROB: 0.0 + USE_REPEATED_AUG: False + USE_RANDOM_RESIZE_CROPS: False + COLORJITTER: False + GRAYSCALE: False + GAUSSIAN: False +SOLVER: + BASE_LR: 1e-4 + LR_POLICY: steps_with_relative_lrs + LRS: [1, 0.1, 0.01] + STEPS: [0, 20, 30] + MAX_EPOCH: 35 + MOMENTUM: 0.9 + WEIGHT_DECAY: 5e-2 + WARMUP_EPOCHS: 0.0 + OPTIMIZING_METHOD: adamw + USE_MIXED_PRECISION: True + SMOOTHING: 0.2 +SLOWFAST: + ALPHA: 8 +VIT: + PATCH_SIZE: 16 + PATCH_SIZE_TEMP: 2 + CHANNELS: 3 + EMBED_DIM: 768 + DEPTH: 12 + NUM_HEADS: 12 + MLP_RATIO: 4 + QKV_BIAS: True + VIDEO_INPUT: True + TEMPORAL_RESOLUTION: 8 + USE_MLP: True + DROP: 0.0 + POS_DROPOUT: 0.0 + DROP_PATH: 0.2 + IM_PRETRAINED: True + HEAD_DROPOUT: 0.0 + HEAD_ACT: tanh + PRETRAINED_WEIGHTS: vit_1k + ATTN_LAYER: divided +MODEL: + NUM_CLASSES: 174 + ARCH: slow + MODEL_NAME: VisionTransformer + LOSS_FUNC: cross_entropy +TEST: + ENABLE: True + DATASET: Ssv2 + BATCH_SIZE: 64 + NUM_ENSEMBLE_VIEWS: 1 + NUM_SPATIAL_CROPS: 3 +DATA_LOADER: + NUM_WORKERS: 4 + PIN_MEMORY: True +NUM_GPUS: 8 +NUM_SHARDS: 4 +RNG_SEED: 0 +OUTPUT_DIR: . +TENSORBOARD: + ENABLE: True diff --git a/data_utils/ext/synchformer/motionformer.py b/data_utils/ext/synchformer/motionformer.py new file mode 100644 index 0000000000000000000000000000000000000000..148b5d3c7f021a8dfe38f7134a919c21d35e6bab --- /dev/null +++ b/data_utils/ext/synchformer/motionformer.py @@ -0,0 +1,400 @@ +import logging +from pathlib import Path + +import einops +import torch +from omegaconf import OmegaConf +from timm.layers import trunc_normal_ +from torch import nn + +from data_utils.ext.synchformer.utils import check_if_file_exists_else_download +from data_utils.ext.synchformer.video_model_builder import VisionTransformer + +FILE2URL = { + # cfg + 'motionformer_224_16x4.yaml': + 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/motionformer_224_16x4.yaml', + 'joint_224_16x4.yaml': + 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/joint_224_16x4.yaml', + 'divided_224_16x4.yaml': + 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/divided_224_16x4.yaml', + # ckpt + 'ssv2_motionformer_224_16x4.pyth': + 'https://dl.fbaipublicfiles.com/motionformer/ssv2_motionformer_224_16x4.pyth', + 'ssv2_joint_224_16x4.pyth': + 'https://dl.fbaipublicfiles.com/motionformer/ssv2_joint_224_16x4.pyth', + 'ssv2_divided_224_16x4.pyth': + 'https://dl.fbaipublicfiles.com/motionformer/ssv2_divided_224_16x4.pyth', +} + + +class MotionFormer(VisionTransformer): + ''' This class serves three puposes: + 1. Renames the class to MotionFormer. + 2. Downloads the cfg from the original repo and patches it if needed. + 3. Takes care of feature extraction by redefining .forward() + - if `extract_features=True` and `factorize_space_time=False`, + the output is of shape (B, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8 + - if `extract_features=True` and `factorize_space_time=True`, the output is of shape (B*S, D) + and spatial and temporal transformer encoder layers are used. + - if `extract_features=True` and `factorize_space_time=True` as well as `add_global_repr=True` + the output is of shape (B, D) and spatial and temporal transformer encoder layers + are used as well as the global representation is extracted from segments (extra pos emb + is added). + ''' + + def __init__( + self, + extract_features: bool = False, + ckpt_path: str = None, + factorize_space_time: bool = None, + agg_space_module: str = None, + agg_time_module: str = None, + add_global_repr: bool = True, + agg_segments_module: str = None, + max_segments: int = None, + ): + self.extract_features = extract_features + self.ckpt_path = ckpt_path + self.factorize_space_time = factorize_space_time + + if self.ckpt_path is not None: + check_if_file_exists_else_download(self.ckpt_path, FILE2URL) + ckpt = torch.load(self.ckpt_path, map_location='cpu') + mformer_ckpt2cfg = { + 'ssv2_motionformer_224_16x4.pyth': 'motionformer_224_16x4.yaml', + 'ssv2_joint_224_16x4.pyth': 'joint_224_16x4.yaml', + 'ssv2_divided_224_16x4.pyth': 'divided_224_16x4.yaml', + } + # init from motionformer ckpt or from our Stage I ckpt + # depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not, we need to + # load the state dict differently + was_pt_on_avclip = self.ckpt_path.endswith( + '.pt') # checks if it is a stage I ckpt (FIXME: a bit generic) + if self.ckpt_path.endswith(tuple(mformer_ckpt2cfg.keys())): + cfg_fname = mformer_ckpt2cfg[Path(self.ckpt_path).name] + elif was_pt_on_avclip: + # TODO: this is a hack, we should be able to get the cfg from the ckpt (earlier ckpt didn't have it) + s1_cfg = ckpt.get('args', None) # Stage I cfg + if s1_cfg is not None: + s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path + # if the stage I ckpt was initialized from a motionformer ckpt or train from scratch + if s1_vfeat_extractor_ckpt_path is not None: + cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name] + else: + cfg_fname = 'divided_224_16x4.yaml' + else: + cfg_fname = 'divided_224_16x4.yaml' + else: + raise ValueError(f'ckpt_path {self.ckpt_path} is not supported.') + else: + was_pt_on_avclip = False + cfg_fname = 'divided_224_16x4.yaml' + # logging.info(f'No ckpt_path provided, using {cfg_fname} config.') + + if cfg_fname in ['motionformer_224_16x4.yaml', 'divided_224_16x4.yaml']: + pos_emb_type = 'separate' + elif cfg_fname == 'joint_224_16x4.yaml': + pos_emb_type = 'joint' + + self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname + + check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL) + mformer_cfg = OmegaConf.load(self.mformer_cfg_path) + logging.info(f'Loading MotionFormer config from {self.mformer_cfg_path.absolute()}') + + # patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`) + mformer_cfg.VIT.ATTN_DROPOUT = 0.0 + mformer_cfg.VIT.POS_EMBED = pos_emb_type + mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True + mformer_cfg.VIT.APPROX_ATTN_TYPE = 'none' # guessing + mformer_cfg.VIT.APPROX_ATTN_DIM = 64 # from ckpt['cfg'] + + # finally init VisionTransformer with the cfg + super().__init__(mformer_cfg) + + # load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt + if (self.ckpt_path is not None) and (not was_pt_on_avclip): + _ckpt_load_status = self.load_state_dict(ckpt['model_state'], strict=False) + if len(_ckpt_load_status.missing_keys) > 0 or len( + _ckpt_load_status.unexpected_keys) > 0: + logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed.' \ + f'Missing keys: {_ckpt_load_status.missing_keys}, ' \ + f'Unexpected keys: {_ckpt_load_status.unexpected_keys}') + else: + logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.') + + if self.extract_features: + assert isinstance(self.norm, + nn.LayerNorm), 'early x[:, 1:, :] may not be safe for per-tr weights' + # pre-logits are Sequential(nn.Linear(emb, emd), act) and `act` is tanh but see the logger + self.pre_logits = nn.Identity() + # we don't need the classification head (saving memory) + self.head = nn.Identity() + self.head_drop = nn.Identity() + # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer) + transf_enc_layer_kwargs = dict( + d_model=self.embed_dim, + nhead=self.num_heads, + activation=nn.GELU(), + batch_first=True, + dim_feedforward=self.mlp_ratio * self.embed_dim, + dropout=self.drop_rate, + layer_norm_eps=1e-6, + norm_first=True, + ) + # define adapters if needed + if self.factorize_space_time: + if agg_space_module == 'TransformerEncoderLayer': + self.spatial_attn_agg = SpatialTransformerEncoderLayer( + **transf_enc_layer_kwargs) + elif agg_space_module == 'AveragePooling': + self.spatial_attn_agg = AveragePooling(avg_pattern='BS D t h w -> BS D t', + then_permute_pattern='BS D t -> BS t D') + if agg_time_module == 'TransformerEncoderLayer': + self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs) + elif agg_time_module == 'AveragePooling': + self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D') + elif 'Identity' in agg_time_module: + self.temp_attn_agg = nn.Identity() + # define a global aggregation layer (aggregarate over segments) + self.add_global_repr = add_global_repr + if add_global_repr: + if agg_segments_module == 'TransformerEncoderLayer': + # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D) + # we need to add pos emb (PE) because previously we added the same PE for each segment + pos_max_len = max_segments if max_segments is not None else 16 # 16 = 10sec//0.64sec + 1 + self.global_attn_agg = TemporalTransformerEncoderLayer( + add_pos_emb=True, + pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT, + pos_max_len=pos_max_len, + **transf_enc_layer_kwargs) + elif agg_segments_module == 'AveragePooling': + self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D') + + if was_pt_on_avclip: + # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors) + # and keep only the state_dict of the feat extractor + ckpt_weights = dict() + for k, v in ckpt['state_dict'].items(): + if k.startswith(('module.v_encoder.', 'v_encoder.')): + k = k.replace('module.', '').replace('v_encoder.', '') + ckpt_weights[k] = v + _load_status = self.load_state_dict(ckpt_weights, strict=False) + if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0: + logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n' \ + f'Missing keys ({len(_load_status.missing_keys)}): ' \ + f'{_load_status.missing_keys}, \n' \ + f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \ + f'{_load_status.unexpected_keys} \n' \ + f'temp_attn_agg are expected to be missing if ckpt was pt contrastively.') + else: + logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.') + + # patch_embed is not used in MotionFormer, only patch_embed_3d, because cfg.VIT.PATCH_SIZE_TEMP > 1 + # but it used to calculate the number of patches, so we need to set keep it + self.patch_embed.requires_grad_(False) + + def forward(self, x): + ''' + x is of shape (B, S, C, T, H, W) where S is the number of segments. + ''' + # Batch, Segments, Channels, T=frames, Height, Width + B, S, C, T, H, W = x.shape + # Motionformer expects a tensor of shape (1, B, C, T, H, W). + # The first dimension (1) is a dummy dimension to make the input tensor and won't be used: + # see `video_model_builder.video_input`. + # x = x.unsqueeze(0) # (1, B, S, C, T, H, W) + + orig_shape = (B, S, C, T, H, W) + x = x.view(B * S, C, T, H, W) # flatten batch and segments + x = self.forward_segments(x, orig_shape=orig_shape) + # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D)) + x = x.view(B, S, *x.shape[1:]) + # x is now of shape (B*S, D) or (B*S, t, D) if `self.temp_attn_agg` is `Identity` + + return x # x is (B, S, ...) + + def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor: + '''x is of shape (1, BS, C, T, H, W) where S is the number of segments.''' + x, x_mask = self.forward_features(x) + + assert self.extract_features + + # (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8 + x = x[:, + 1:, :] # without the CLS token for efficiency (should be safe for LayerNorm and FC) + x = self.norm(x) + x = self.pre_logits(x) + if self.factorize_space_time: + x = self.restore_spatio_temp_dims(x, orig_shape) # (B*S, D, t, h, w) <- (B*S, t*h*w, D) + + x = self.spatial_attn_agg(x, x_mask) # (B*S, t, D) + x = self.temp_attn_agg( + x) # (B*S, D) or (BS, t, D) if `self.temp_attn_agg` is `Identity` + + return x + + def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor: + ''' + feats are of shape (B*S, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8 + Our goal is to make them of shape (B*S, t, h, w, D) where h, w are the spatial dimensions. + From `self.patch_embed_3d`, it follows that we could reshape feats with: + `feats.transpose(1, 2).view(B*S, D, t, h, w)` + ''' + B, S, C, T, H, W = orig_shape + D = self.embed_dim + + # num patches in each dimension + t = T // self.patch_embed_3d.z_block_size + h = self.patch_embed_3d.height + w = self.patch_embed_3d.width + + feats = feats.permute(0, 2, 1) # (B*S, D, T) + feats = feats.view(B * S, D, t, h, w) # (B*S, D, t, h, w) + + return feats + + +class BaseEncoderLayer(nn.TransformerEncoderLayer): + ''' + This is a wrapper around nn.TransformerEncoderLayer that adds a CLS token + to the sequence and outputs the CLS token's representation. + This base class parents both SpatialEncoderLayer and TemporalEncoderLayer for the RGB stream + and the FrequencyEncoderLayer and TemporalEncoderLayer for the audio stream stream. + We also, optionally, add a positional embedding to the input sequence which + allows to reuse it for global aggregation (of segments) for both streams. + ''' + + def __init__(self, + add_pos_emb: bool = False, + pos_emb_drop: float = None, + pos_max_len: int = None, + *args_transformer_enc, + **kwargs_transformer_enc): + super().__init__(*args_transformer_enc, **kwargs_transformer_enc) + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim)) + trunc_normal_(self.cls_token, std=.02) + + # add positional embedding + self.add_pos_emb = add_pos_emb + if add_pos_emb: + self.pos_max_len = 1 + pos_max_len # +1 (for CLS) + self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim)) + self.pos_drop = nn.Dropout(pos_emb_drop) + trunc_normal_(self.pos_emb, std=.02) + + self.apply(self._init_weights) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None): + ''' x is of shape (B, N, D); if provided x_mask is of shape (B, N)''' + batch_dim = x.shape[0] + + # add CLS token + cls_tokens = self.cls_token.expand(batch_dim, -1, -1) # expanding to match batch dimension + x = torch.cat((cls_tokens, x), dim=-2) # (batch_dim, 1+seq_len, D) + if x_mask is not None: + cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool, + device=x_mask.device) # 1=keep; 0=mask + x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1) # (batch_dim, 1+seq_len) + B, N = x_mask_w_cls.shape + # torch expects (N, N) or (B*num_heads, N, N) mask (sadness ahead); torch masks + x_mask_w_cls = x_mask_w_cls.reshape(B, 1, 1, N)\ + .expand(-1, self.self_attn.num_heads, N, -1)\ + .reshape(B * self.self_attn.num_heads, N, N) + assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, 'x_mask_w_cls.dtype != bool' + x_mask_w_cls = ~x_mask_w_cls # invert mask (1=mask) + else: + x_mask_w_cls = None + + # add positional embedding + if self.add_pos_emb: + seq_len = x.shape[ + 1] # (don't even think about moving it before the CLS token concatenation) + assert seq_len <= self.pos_max_len, f'Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})' + x = x + self.pos_emb[:, :seq_len, :] + x = self.pos_drop(x) + + # apply encoder layer (calls nn.TransformerEncoderLayer.forward); + x = super().forward(src=x, src_mask=x_mask_w_cls) # (batch_dim, 1+seq_len, D) + + # CLS token is expected to hold spatial information for each frame + x = x[:, 0, :] # (batch_dim, D) + + return x + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_token', 'pos_emb'} + + +class SpatialTransformerEncoderLayer(BaseEncoderLayer): + ''' Aggregates spatial dimensions by applying attention individually to each frame. ''' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor: + ''' x is of shape (B*S, D, t, h, w) where S is the number of segments. + if specified x_mask (B*S, t, h, w), 0=masked, 1=kept + Returns a tensor of shape (B*S, t, D) pooling spatial information for each frame. ''' + BS, D, t, h, w = x.shape + + # time as a batch dimension and flatten spatial dimensions as sequence + x = einops.rearrange(x, 'BS D t h w -> (BS t) (h w) D') + # similar to mask + if x_mask is not None: + x_mask = einops.rearrange(x_mask, 'BS t h w -> (BS t) (h w)') + + # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation + x = super().forward(x=x, x_mask=x_mask) # (B*S*t, D) + + # reshape back to (B*S, t, D) + x = einops.rearrange(x, '(BS t) D -> BS t D', BS=BS, t=t) + + # (B*S, t, D) + return x + + +class TemporalTransformerEncoderLayer(BaseEncoderLayer): + ''' Aggregates temporal dimension with attention. Also used with pos emb as global aggregation + in both streams. ''' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + ''' x is of shape (B*S, t, D) where S is the number of segments. + Returns a tensor of shape (B*S, D) pooling temporal information. ''' + BS, t, D = x.shape + + # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation + x = super().forward(x) # (B*S, D) + + return x # (B*S, D) + + +class AveragePooling(nn.Module): + + def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None: + ''' patterns are e.g. "bs t d -> bs d" ''' + super().__init__() + # TODO: need to register them as buffers (but fails because these are strings) + self.reduce_fn = 'mean' + self.avg_pattern = avg_pattern + self.then_permute_pattern = then_permute_pattern + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor: + x = einops.reduce(x, self.avg_pattern, self.reduce_fn) + if self.then_permute_pattern is not None: + x = einops.rearrange(x, self.then_permute_pattern) + return x diff --git a/data_utils/ext/synchformer/synchformer.py b/data_utils/ext/synchformer/synchformer.py new file mode 100644 index 0000000000000000000000000000000000000000..fd580fa6cc1701eedeeebc5fcc3951755207df96 --- /dev/null +++ b/data_utils/ext/synchformer/synchformer.py @@ -0,0 +1,55 @@ +import logging +from typing import Any, Mapping + +import torch +from torch import nn + +from data_utils.ext.synchformer.motionformer import MotionFormer + + +class Synchformer(nn.Module): + + def __init__(self): + super().__init__() + + self.vfeat_extractor = MotionFormer(extract_features=True, + factorize_space_time=True, + agg_space_module='TransformerEncoderLayer', + agg_time_module='torch.nn.Identity', + add_global_repr=False) + + # self.vfeat_extractor = instantiate_from_config(vfeat_extractor) + # self.afeat_extractor = instantiate_from_config(afeat_extractor) + # # bridging the s3d latent dim (1024) into what is specified in the config + # # to match e.g. the transformer dim + # self.vproj = instantiate_from_config(vproj) + # self.aproj = instantiate_from_config(aproj) + # self.transformer = instantiate_from_config(transformer) + + def forward(self, vis): + B, S, Tv, C, H, W = vis.shape + vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) + # feat extractors return a tuple of segment-level and global features (ignored for sync) + # (B, S, tv, D), e.g. (B, 7, 8, 768) + vis = self.vfeat_extractor(vis) + return vis + + def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True): + # discard all entries except vfeat_extractor + sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')} + + return super().load_state_dict(sd, strict) + + +if __name__ == "__main__": + model = Synchformer().cuda().eval() + sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True) + model.load_state_dict(sd) + + vid = torch.randn(2, 7, 16, 3, 224, 224).cuda() + features = model.extract_vfeats(vid, for_loop=False).detach().cpu() + print(features.shape) + + # extract and save the state dict only + # sd = torch.load('./ext_weights/sync_model_audioset.pt')['model'] + # torch.save(sd, './ext_weights/synchformer_state_dict.pth') diff --git a/data_utils/ext/synchformer/utils.py b/data_utils/ext/synchformer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a797eb9c66f04b7c29934bfc384c935cdf441a62 --- /dev/null +++ b/data_utils/ext/synchformer/utils.py @@ -0,0 +1,92 @@ +from hashlib import md5 +from pathlib import Path + +import requests +from tqdm import tqdm + +PARENT_LINK = 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a' +FNAME2LINK = { + # S3: Synchability: AudioSet (run 2) + '24-01-22T20-34-52.pt': + f'{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt', + 'cfg-24-01-22T20-34-52.yaml': + f'{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml', + # S2: Synchformer: AudioSet (run 2) + '24-01-04T16-39-21.pt': + f'{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt', + 'cfg-24-01-04T16-39-21.yaml': + f'{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml', + # S2: Synchformer: AudioSet (run 1) + '23-08-28T11-23-23.pt': + f'{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt', + 'cfg-23-08-28T11-23-23.yaml': + f'{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml', + # S2: Synchformer: LRS3 (run 2) + '23-12-23T18-33-57.pt': + f'{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt', + 'cfg-23-12-23T18-33-57.yaml': + f'{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml', + # S2: Synchformer: VGS (run 2) + '24-01-02T10-00-53.pt': + f'{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt', + 'cfg-24-01-02T10-00-53.yaml': + f'{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml', + # SparseSync: ft VGGSound-Full + '22-09-21T21-00-52.pt': + f'{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt', + 'cfg-22-09-21T21-00-52.yaml': + f'{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml', + # SparseSync: ft VGGSound-Sparse + '22-07-28T15-49-45.pt': + f'{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt', + 'cfg-22-07-28T15-49-45.yaml': + f'{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml', + # SparseSync: only pt on LRS3 + '22-07-13T22-25-49.pt': + f'{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt', + 'cfg-22-07-13T22-25-49.yaml': + f'{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml', + # SparseSync: feature extractors + 'ResNetAudio-22-08-04T09-51-04.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt', # 2s + 'ResNetAudio-22-08-03T23-14-49.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt', # 3s + 'ResNetAudio-22-08-03T23-14-28.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt', # 4s + 'ResNetAudio-22-06-24T08-10-33.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt', # 5s + 'ResNetAudio-22-06-24T17-31-07.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt', # 6s + 'ResNetAudio-22-06-24T23-57-11.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt', # 7s + 'ResNetAudio-22-06-25T04-35-42.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt', # 8s +} + + +def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024): + '''Checks if file exists, if not downloads it from the link to the path''' + path = Path(path) + if not path.exists(): + path.parent.mkdir(exist_ok=True, parents=True) + link = fname2link.get(path.name, None) + if link is None: + raise ValueError(f'Cant find the checkpoint file: {path}.', + f'Please download it manually and ensure the path exists.') + with requests.get(fname2link[path.name], stream=True) as r: + total_size = int(r.headers.get('content-length', 0)) + with tqdm(total=total_size, unit='B', unit_scale=True) as pbar: + with open(path, 'wb') as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def get_md5sum(path): + hash_md5 = md5() + with open(path, 'rb') as f: + for chunk in iter(lambda: f.read(4096 * 8), b''): + hash_md5.update(chunk) + md5sum = hash_md5.hexdigest() + return md5sum diff --git a/data_utils/ext/synchformer/video_model_builder.py b/data_utils/ext/synchformer/video_model_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..da6decd3ab8a2f938e7df66f046451fff6413b5f --- /dev/null +++ b/data_utils/ext/synchformer/video_model_builder.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# Copyright 2020 Ross Wightman +# Modified Model definition + +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn +from timm.layers import trunc_normal_ + +from data_utils.ext.synchformer import vit_helper + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage """ + + def __init__(self, cfg): + super().__init__() + self.img_size = cfg.DATA.TRAIN_CROP_SIZE + self.patch_size = cfg.VIT.PATCH_SIZE + self.in_chans = cfg.VIT.CHANNELS + if cfg.TRAIN.DATASET == "Epickitchens": + self.num_classes = [97, 300] + else: + self.num_classes = cfg.MODEL.NUM_CLASSES + self.embed_dim = cfg.VIT.EMBED_DIM + self.depth = cfg.VIT.DEPTH + self.num_heads = cfg.VIT.NUM_HEADS + self.mlp_ratio = cfg.VIT.MLP_RATIO + self.qkv_bias = cfg.VIT.QKV_BIAS + self.drop_rate = cfg.VIT.DROP + self.drop_path_rate = cfg.VIT.DROP_PATH + self.head_dropout = cfg.VIT.HEAD_DROPOUT + self.video_input = cfg.VIT.VIDEO_INPUT + self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION + self.use_mlp = cfg.VIT.USE_MLP + self.num_features = self.embed_dim + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT + self.head_act = cfg.VIT.HEAD_ACT + self.cfg = cfg + + # Patch Embedding + self.patch_embed = vit_helper.PatchEmbed(img_size=224, + patch_size=self.patch_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim) + + # 3D Patch Embedding + self.patch_embed_3d = vit_helper.PatchEmbed3D(img_size=self.img_size, + temporal_resolution=self.temporal_resolution, + patch_size=self.patch_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim, + z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP) + self.patch_embed_3d.proj.weight.data = torch.zeros_like( + self.patch_embed_3d.proj.weight.data) + + # Number of patches + if self.video_input: + num_patches = self.patch_embed.num_patches * self.temporal_resolution + else: + num_patches = self.patch_embed.num_patches + self.num_patches = num_patches + + # CLS token + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + trunc_normal_(self.cls_token, std=.02) + + # Positional embedding + self.pos_embed = nn.Parameter( + torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim)) + self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT) + trunc_normal_(self.pos_embed, std=.02) + + if self.cfg.VIT.POS_EMBED == "joint": + self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim)) + trunc_normal_(self.st_embed, std=.02) + elif self.cfg.VIT.POS_EMBED == "separate": + self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim)) + + # Layer Blocks + dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] + if self.cfg.VIT.ATTN_LAYER == "divided": + self.blocks = nn.ModuleList([ + vit_helper.DividedSpaceTimeBlock( + attn_type=cfg.VIT.ATTN_LAYER, + dim=self.embed_dim, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + drop=self.drop_rate, + attn_drop=self.attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + ) for i in range(self.depth) + ]) + else: + self.blocks = nn.ModuleList([ + vit_helper.Block(attn_type=cfg.VIT.ATTN_LAYER, + dim=self.embed_dim, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + drop=self.drop_rate, + attn_drop=self.attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + use_original_code=self.cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE) + for i in range(self.depth) + ]) + self.norm = norm_layer(self.embed_dim) + + # MLP head + if self.use_mlp: + hidden_dim = self.embed_dim + if self.head_act == 'tanh': + # logging.info("Using TanH activation in MLP") + act = nn.Tanh() + elif self.head_act == 'gelu': + # logging.info("Using GELU activation in MLP") + act = nn.GELU() + else: + # logging.info("Using ReLU activation in MLP") + act = nn.ReLU() + self.pre_logits = nn.Sequential( + OrderedDict([ + ('fc', nn.Linear(self.embed_dim, hidden_dim)), + ('act', act), + ])) + else: + self.pre_logits = nn.Identity() + + # Classifier Head + self.head_drop = nn.Dropout(p=self.head_dropout) + if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1: + for a, i in enumerate(range(len(self.num_classes))): + setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i])) + else: + self.head = nn.Linear(self.embed_dim, + self.num_classes) if self.num_classes > 0 else nn.Identity() + + # Initialize weights + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + if self.cfg.VIT.POS_EMBED == "joint": + return {'pos_embed', 'cls_token', 'st_embed'} + else: + return {'pos_embed', 'cls_token', 'temp_embed'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = (nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()) + + def forward_features(self, x): + # if self.video_input: + # x = x[0] + B = x.shape[0] + + # Tokenize input + # if self.cfg.VIT.PATCH_SIZE_TEMP > 1: + # for simplicity of mapping between content dimensions (input x) and token dims (after patching) + # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details): + + # apply patching on input + x = self.patch_embed_3d(x) + tok_mask = None + + # else: + # tok_mask = None + # # 2D tokenization + # if self.video_input: + # x = x.permute(0, 2, 1, 3, 4) + # (B, T, C, H, W) = x.shape + # x = x.reshape(B * T, C, H, W) + + # x = self.patch_embed(x) + + # if self.video_input: + # (B2, T2, D2) = x.shape + # x = x.reshape(B, T * T2, D2) + + # Append CLS token + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + # if tok_mask is not None: + # # prepend 1(=keep) to the mask to account for the CLS token as well + # tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1) + + # Interpolate positinoal embeddings + # if self.cfg.DATA.TRAIN_CROP_SIZE != 224: + # pos_embed = self.pos_embed + # N = pos_embed.shape[1] - 1 + # npatch = int((x.size(1) - 1) / self.temporal_resolution) + # class_emb = pos_embed[:, 0] + # pos_embed = pos_embed[:, 1:] + # dim = x.shape[-1] + # pos_embed = torch.nn.functional.interpolate( + # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + # scale_factor=math.sqrt(npatch / N), + # mode='bicubic', + # ) + # pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + # new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1) + # else: + new_pos_embed = self.pos_embed + npatch = self.patch_embed.num_patches + + # Add positional embeddings to input + if self.video_input: + if self.cfg.VIT.POS_EMBED == "separate": + cls_embed = self.pos_embed[:, 0, :].unsqueeze(1) + tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1) + tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1) + total_pos_embed = tile_pos_embed + tile_temporal_embed + total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) + x = x + total_pos_embed + elif self.cfg.VIT.POS_EMBED == "joint": + x = x + self.st_embed + else: + # image input + x = x + new_pos_embed + + # Apply positional dropout + x = self.pos_drop(x) + + # Encoding using transformer layers + for i, blk in enumerate(self.blocks): + x = blk(x, + seq_len=npatch, + num_frames=self.temporal_resolution, + approx=self.cfg.VIT.APPROX_ATTN_TYPE, + num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM, + tok_mask=tok_mask) + + ### v-iashin: I moved it to the forward pass + # x = self.norm(x)[:, 0] + # x = self.pre_logits(x) + ### + return x, tok_mask + + # def forward(self, x): + # x = self.forward_features(x) + # ### v-iashin: here. This should leave the same forward output as before + # x = self.norm(x)[:, 0] + # x = self.pre_logits(x) + # ### + # x = self.head_drop(x) + # if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1: + # output = [] + # for head in range(len(self.num_classes)): + # x_out = getattr(self, "head%d" % head)(x) + # if not self.training: + # x_out = torch.nn.functional.softmax(x_out, dim=-1) + # output.append(x_out) + # return output + # else: + # x = self.head(x) + # if not self.training: + # x = torch.nn.functional.softmax(x, dim=-1) + # return x diff --git a/data_utils/ext/synchformer/vit_helper.py b/data_utils/ext/synchformer/vit_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..6af730a135bf49240ec439c81c9ad0aa5c9a505e --- /dev/null +++ b/data_utils/ext/synchformer/vit_helper.py @@ -0,0 +1,399 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# Copyright 2020 Ross Wightman +# Modified Model definition +"""Video models.""" + +import math + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from timm.layers import to_2tuple +from torch import einsum +from torch.nn import functional as F + +default_cfgs = { + 'vit_1k': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + 'vit_1k_large': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', +} + + +def qkv_attn(q, k, v, tok_mask: torch.Tensor = None): + sim = einsum('b i d, b j d -> b i j', q, k) + # apply masking if provided, tok_mask is (B*S*H, N): 1s - keep; sim is (B*S*H, H, N, N) + if tok_mask is not None: + BSH, N = tok_mask.shape + sim = sim.masked_fill(tok_mask.view(BSH, 1, N) == 0, + float('-inf')) # 1 - broadcasts across N + attn = sim.softmax(dim=-1) + out = einsum('b i j, b j d -> b i d', attn, v) + return out + + +class DividedAttention(nn.Module): + + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + # init to zeros + self.qkv.weight.data.fill_(0) + self.qkv.bias.data.fill_(0) + self.proj.weight.data.fill_(1) + self.proj.bias.data.fill_(0) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims): + # num of heads variable + h = self.num_heads + + # project x to q, k, v vaalues + q, k, v = self.qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + if tok_mask is not None: + # replicate token mask across heads (b, n) -> (b, h, n) -> (b*h, n) -- same as qkv but w/o d + assert len(tok_mask.shape) == 2 + tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1]) + + # Scale q + q *= self.scale + + # Take out cls_q, cls_k, cls_v + (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v)) + # the same for masking + if tok_mask is not None: + cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:] + else: + cls_mask, mask_ = None, None + + # let CLS token attend to key / values of all patches across time and space + cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask) + + # rearrange across time or space + q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), + (q_, k_, v_)) + + # expand CLS token keys and values across time or space and concat + r = q_.shape[0] // cls_k.shape[0] + cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v)) + + k_ = torch.cat((cls_k, k_), dim=1) + v_ = torch.cat((cls_v, v_), dim=1) + + # the same for masking (if provided) + if tok_mask is not None: + # since mask does not have the latent dim (d), we need to remove it from einops dims + mask_ = rearrange(mask_, f'{einops_from} -> {einops_to}'.replace(' d', ''), + **einops_dims) + cls_mask = repeat(cls_mask, 'b () -> (b r) ()', + r=r) # expand cls_mask across time or space + mask_ = torch.cat((cls_mask, mask_), dim=1) + + # attention + out = qkv_attn(q_, k_, v_, tok_mask=mask_) + + # merge back time or space + out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims) + + # concat back the cls token + out = torch.cat((cls_out, out), dim=1) + + # merge back the heads + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + + ## to out + x = self.proj(out) + x = self.proj_drop(x) + return x + + +class DividedSpaceTimeBlock(nn.Module): + + def __init__(self, + dim=768, + num_heads=12, + attn_type='divided', + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + + self.einops_from_space = 'b (f n) d' + self.einops_to_space = '(b f) n d' + self.einops_from_time = 'b (f n) d' + self.einops_to_time = '(b n) f d' + + self.norm1 = norm_layer(dim) + + self.attn = DividedAttention(dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop) + + self.timeattn = DividedAttention(dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop) + + # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path = nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + self.norm3 = norm_layer(dim) + + def forward(self, + x, + seq_len=196, + num_frames=8, + approx='none', + num_landmarks=128, + tok_mask: torch.Tensor = None): + time_output = self.timeattn(self.norm3(x), + self.einops_from_time, + self.einops_to_time, + n=seq_len, + tok_mask=tok_mask) + time_residual = x + time_output + + space_output = self.attn(self.norm1(time_residual), + self.einops_from_space, + self.einops_to_space, + f=num_frames, + tok_mask=tok_mask) + space_residual = time_residual + self.drop_path(space_output) + + x = space_residual + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = img_size if type(img_size) is tuple else to_2tuple(img_size) + patch_size = img_size if type(patch_size) is tuple else to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PatchEmbed3D(nn.Module): + """ Image to Patch Embedding """ + + def __init__(self, + img_size=224, + temporal_resolution=4, + in_chans=3, + patch_size=16, + z_block_size=2, + embed_dim=768, + flatten=True): + super().__init__() + self.height = (img_size // patch_size) + self.width = (img_size // patch_size) + ### v-iashin: these two are incorrect + # self.frames = (temporal_resolution // z_block_size) + # self.num_patches = self.height * self.width * self.frames + self.z_block_size = z_block_size + ### + self.proj = nn.Conv3d(in_chans, + embed_dim, + kernel_size=(z_block_size, patch_size, patch_size), + stride=(z_block_size, patch_size, patch_size)) + self.flatten = flatten + + def forward(self, x): + B, C, T, H, W = x.shape + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + return x + + +class HeadMLP(nn.Module): + + def __init__(self, n_input, n_classes, n_hidden=512, p=0.1): + super(HeadMLP, self).__init__() + self.n_input = n_input + self.n_classes = n_classes + self.n_hidden = n_hidden + if n_hidden is None: + # use linear classifier + self.block_forward = nn.Sequential(nn.Dropout(p=p), + nn.Linear(n_input, n_classes, bias=True)) + else: + # use simple MLP classifier + self.block_forward = nn.Sequential(nn.Dropout(p=p), + nn.Linear(n_input, n_hidden, bias=True), + nn.BatchNorm1d(n_hidden), nn.ReLU(inplace=True), + nn.Dropout(p=p), + nn.Linear(n_hidden, n_classes, bias=True)) + print(f"Dropout-NLP: {p}") + + def forward(self, x): + return self.block_forward(x) + + +def _conv_filter(state_dict, patch_size=16): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + return out_dict + + +def adapt_input_conv(in_chans, conv_weight, agg='sum'): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + if agg == 'sum': + print("Summing conv1 weights") + conv_weight = conv_weight.sum(dim=1, keepdim=True) + else: + print("Averaging conv1 weights") + conv_weight = conv_weight.mean(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + if agg == 'sum': + print("Summing conv1 weights") + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + else: + print("Averaging conv1 weights") + conv_weight = conv_weight.mean(dim=1, keepdim=True) + conv_weight = conv_weight.repeat(1, in_chans, 1, 1) + conv_weight = conv_weight.to(conv_type) + return conv_weight + + +def load_pretrained(model, + cfg=None, + num_classes=1000, + in_chans=3, + filter_fn=None, + strict=True, + progress=False): + # Load state dict + assert (f"{cfg.VIT.PRETRAINED_WEIGHTS} not in [vit_1k, vit_1k_large]") + state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS]) + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + input_convs = 'patch_embed.proj' + if input_convs is not None and in_chans != 3: + if isinstance(input_convs, str): + input_convs = (input_convs, ) + for input_conv_name in input_convs: + weight_name = input_conv_name + '.weight' + try: + state_dict[weight_name] = adapt_input_conv(in_chans, + state_dict[weight_name], + agg='avg') + print( + f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)' + ) + except NotImplementedError as e: + del state_dict[weight_name] + strict = False + print( + f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.' + ) + + classifier_name = 'head' + label_offset = cfg.get('label_offset', 0) + pretrain_classes = 1000 + if num_classes != pretrain_classes: + # completely discard fully connected if model num_classes doesn't match pretrained weights + del state_dict[classifier_name + '.weight'] + del state_dict[classifier_name + '.bias'] + strict = False + elif label_offset > 0: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] + + loaded_state = state_dict + self_state = model.state_dict() + all_names = set(self_state.keys()) + saved_names = set([]) + for name, param in loaded_state.items(): + param = param + if 'module.' in name: + name = name.replace('module.', '') + if name in self_state.keys() and param.shape == self_state[name].shape: + saved_names.add(name) + self_state[name].copy_(param) + else: + print(f"didnt load: {name} of shape: {param.shape}") + print("Missing Keys:") + print(all_names - saved_names) diff --git a/data_utils/extract_training_audio.py b/data_utils/extract_training_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..ed6e4a23fe64b844b5022bda921621e88fd91a66 --- /dev/null +++ b/data_utils/extract_training_audio.py @@ -0,0 +1,137 @@ +import argparse +import os +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +from tqdm import tqdm # 导入 tqdm +import logging # 导入 logging +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from data_utils.v2a_utils.audio_text_dataset import Audio_Text +from data_utils.v2a_utils.feature_utils_224_audio import FeaturesUtils +import torchaudio +from einops import rearrange +from torch.utils.data.dataloader import default_collate + + +# 设置日志配置 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +def setup(rank, world_size): + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + +def cleanup(): + dist.destroy_process_group() + +def error_avoidance_collate(batch): + batch = list(filter(lambda x: x is not None, batch)) + if len(batch) == 0: + return None # 或 return {} + return default_collate(batch) +def main(args): + rank = int(os.environ["RANK"]) # 从环境变量中获取 rank + world_size = int(os.environ["WORLD_SIZE"]) # 从环境变量中获取 world size + setup(rank, world_size) + tsv_path = args.tsv_path + save_dir = args.save_dir + root = args.root + dataset = Audio_Text( + root=root, + tsv_path=tsv_path, + sample_rate=args.sample_rate, + duration_sec=args.duration_sec, + audio_samples=args.audio_samples, + start_row=args.start_row, + end_row=args.end_row, + save_dir=save_dir + ) + os.makedirs(save_dir, exist_ok=True) + # 使用 DataLoader 加载数据集,增加 batch_size 和 num_workers + train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank) + dataloader = DataLoader(dataset, batch_size=1, sampler=train_sampler, num_workers=16, drop_last=False,collate_fn=error_avoidance_collate) + + feature_extractor = FeaturesUtils( + vae_ckpt=args.vae_ckpt, + vae_config=args.vae_config, + enable_conditions=True, + synchformer_ckpt=args.synchformer_ckpt + ).eval().cuda(rank) + + # 使用 DistributedDataParallel 支持多显卡 + feature_extractor = torch.nn.parallel.DistributedDataParallel(feature_extractor, device_ids=[rank]) + + # 使用 tqdm 显示进度条 + for i, data in enumerate(tqdm(dataloader, desc="Processing", unit="batch")): + # 使用 torch.no_grad() 来加快推理速度 + if data is None: + continue + ids = data['id'] # 获取当前批次的所有 ID + with torch.no_grad(): + audio = data['audio'].cuda(rank, non_blocking=True) + if audio.size(0) == 0: + continue + output = { + 'caption': data['caption'], + 'caption_cot': data['caption_cot'] + } + # logging.info(f'Processing batch {i} with IDs: {ids}') # 添加日志记录 + + # latent = feature_extractor.module.encode_audio(audio) + # output['latent'] = latent.detach().cpu() + + caption = data['caption'] + # print(caption,'debug!!!!!!!!!') + # metaclip_global_text_features, metaclip_text_features = feature_extractor.module.encode_text(caption) + # output['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu() + # output['metaclip_text_features'] = metaclip_text_features.detach().cpu() + + caption_cot = data['caption_cot'] + t5_features = feature_extractor.module.encode_t5_text(caption_cot) + output['t5_features'] = t5_features.detach().cpu() + + # 保存每个样本的输出 + for j in range(audio.size(0)): # 遍历当前批次的每个样本 + sample_output = { + 'id': ids[j], + 'caption': output['caption'][j], + 'caption_cot': output['caption_cot'][j], + # 'latent': output['latent'][j], + # 'metaclip_global_text_features': output['metaclip_global_text_features'][j], + # 'metaclip_text_features': output['metaclip_text_features'][j], + 't5_features': output['t5_features'][j] + } + torch.save(sample_output, f'{save_dir}/{ids[j]}.pth') + + ## test the sync between videos and audios + # torchaudio.save(f'input_{i}.wav',data['audio'],sample_rate=44100) + # recon_audio = feature_extractor.decode_audio(latent) + # recon_audio = rearrange(recon_audio, "b d n -> d (b n)") + # id = data['id'] + # torchaudio.save(f'recon_{i}.wav',recon_audio.cpu(),sample_rate=44100) + # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i recon_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest out_{i}.mp4') + # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i input_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest input_{i}.mp4') + + cleanup() + +if __name__ == '__main__': + # print('i am rank',os.environ["RANK"]) + parser = argparse.ArgumentParser(description='Extract Audio Training Latents') + parser.add_argument('--root', type=str, default='dataset/vggsound/raw_audios/test', help='Root directory of the audio dataset') + parser.add_argument('--tsv_path', type=str, default='cot_vgg_test_mix_coarse.csv', help='Path to the TSV file') + parser.add_argument('--save-dir', type=str, default='vgg_cot_extra/cot_coarse', help='Save Directory') + parser.add_argument('--sample_rate', type=int, default=44100, help='Sample rate of the audio') + parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds') + parser.add_argument('--audio_samples', type=int, default=397312, help='Number of audio samples') + parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint') + parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file') + parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint') + parser.add_argument('--start-row', type=int, default=0, help='start row') + parser.add_argument('--end-row', type=int, default=None, help='end row') + + args = parser.parse_args() + + # 直接使用 torch.distributed.launch 启动 + main(args=args) # 这里的 rank 需要在命令行中指定 + + diff --git a/data_utils/extract_training_video.py b/data_utils/extract_training_video.py new file mode 100644 index 0000000000000000000000000000000000000000..92ea01c61e08d49bd4572d28ce5d30d4e5283b19 --- /dev/null +++ b/data_utils/extract_training_video.py @@ -0,0 +1,141 @@ +import argparse +import os +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +from tqdm import tqdm # 导入 tqdm +import logging # 导入 logging +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from data_utils.v2a_utils.vggsound_224 import VGGSound +from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils +import torchaudio +from einops import rearrange +from torch.utils.data.dataloader import default_collate + + +# 设置日志配置 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + +def setup(rank, world_size): + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + +def cleanup(): + dist.destroy_process_group() + +def error_avoidance_collate(batch): + batch = list(filter(lambda x: x is not None, batch)) + return default_collate(batch) + +def main(args): + rank = int(os.environ["RANK"]) # 从环境变量中获取 rank + world_size = int(os.environ["WORLD_SIZE"]) # 从环境变量中获取 world size + setup(rank, world_size) + + dataset = VGGSound( + root=args.root, + tsv_path=args.tsv_path, + sample_rate=args.sample_rate, + duration_sec=args.duration_sec, + audio_samples=args.audio_samples, + start_row=args.start_row, + end_row=args.end_row, + save_dir=args.save_dir + ) + save_dir = args.save_dir + os.makedirs(save_dir, exist_ok=True) + # 使用 DataLoader 加载数据集,增加 batch_size 和 num_workers + train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank) + dataloader = DataLoader(dataset, batch_size=1, sampler=train_sampler, num_workers=8, drop_last=False,collate_fn=error_avoidance_collate) + + feature_extractor = FeaturesUtils( + vae_ckpt=args.vae_ckpt, + vae_config=args.vae_config, + enable_conditions=True, + synchformer_ckpt=args.synchformer_ckpt + ).eval().cuda(rank) + + # 使用 DistributedDataParallel 支持多显卡 + feature_extractor = torch.nn.parallel.DistributedDataParallel(feature_extractor, device_ids=[rank]) + + # 使用 tqdm 显示进度条 + for i, data in enumerate(tqdm(dataloader, desc="Processing", unit="batch")): + # 使用 torch.no_grad() 来加快推理速度 + ids = data['id'] # 获取当前批次的所有 ID + with torch.no_grad(): + audio = data['audio'].cuda(rank, non_blocking=True) + + output = { + 'caption': data['caption'], + 'caption_cot': data['caption_cot'] + } + # logging.info(f'Processing batch {i} with IDs: {ids}') # 添加日志记录 + + latent = feature_extractor.module.encode_audio(audio) + output['latent'] = latent.detach().cpu() + + clip_video = data['clip_video'].cuda(rank, non_blocking=True) + # logging.info(f'Processing batch {i} with shape: {clip_video.shape}') # 添加日志记录 + clip_features = feature_extractor.module.encode_video_with_clip(clip_video) + output['metaclip_features'] = clip_features.detach().cpu() + + sync_video = data['sync_video'].cuda(rank, non_blocking=True) + sync_features = feature_extractor.module.encode_video_with_sync(sync_video) + output['sync_features'] = sync_features.detach().cpu() + + caption = data['caption'] + metaclip_global_text_features, metaclip_text_features = feature_extractor.module.encode_text(caption) + output['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu() + output['metaclip_text_features'] = metaclip_text_features.detach().cpu() + + caption_cot = data['caption_cot'] + t5_features = feature_extractor.module.encode_t5_text(caption_cot) + output['t5_features'] = t5_features.detach().cpu() + + + # 保存每个样本的输出 + for j in range(audio.size(0)): # 遍历当前批次的每个样本 + sample_output = { + 'id': ids[j], + 'caption': output['caption'][j], + 'caption_cot': output['caption_cot'][j], + 'latent': output['latent'][j], + 'metaclip_features': output['metaclip_features'][j], + 'sync_features': output['sync_features'][j], + 'metaclip_global_text_features': output['metaclip_global_text_features'][j], + 'metaclip_text_features': output['metaclip_text_features'][j], + 't5_features': output['t5_features'][j], + } + torch.save(sample_output, f'{save_dir}/{ids[j]}.pth') + + ## test the sync between videos and audios + # torchaudio.save(f'input_{i}.wav',data['audio'],sample_rate=44100) + # recon_audio = feature_extractor.decode_audio(latent) + # recon_audio = rearrange(recon_audio, "b d n -> d (b n)") + # id = data['id'] + # torchaudio.save(f'recon_{i}.wav',recon_audio.cpu(),sample_rate=44100) + # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i recon_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest out_{i}.mp4') + # os.system(f'ffmpeg -y -i dataset/vggsound/video/train/{id}.mp4 -i input_{i}.wav -t 9 -map 0:v -map 1:a -c:v copy -c:a aac -strict experimental -shortest input_{i}.mp4') + cleanup() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Extract Video Training Latents') + parser.add_argument('--root', type=str, default='dataset/vggsound/video/test', help='Root directory of the video dataset') + parser.add_argument('--tsv_path', type=str, default='dataset/vggsound/split_txt/test_caption.csv', help='Path to the TSV file') + parser.add_argument('--save-dir', type=str, default='dataset/vggsound/video_text_latents/test', help='Save Directory') + parser.add_argument('--sample_rate', type=int, default=44100, help='Sample rate of the audio') + parser.add_argument('--duration_sec', type=float, default=9.0, help='Duration of the audio in seconds') + parser.add_argument('--audio_samples', type=int, default=397312, help='Number of audio samples') + parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint') + parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file') + parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint') + parser.add_argument('--start-row', type=int, default=0, help='start row') + parser.add_argument('--end-row', type=int, default=None, help='end row') + + args = parser.parse_args() + + # 直接使用 torch.distributed.launch 启动 + main(args=args) # 这里的 rank 需要在命令行中指定 + diff --git a/data_utils/prismaudio_data_process.py b/data_utils/prismaudio_data_process.py new file mode 100644 index 0000000000000000000000000000000000000000..c4491e3620bac5801883049f34ccacb43a028ab3 --- /dev/null +++ b/data_utils/prismaudio_data_process.py @@ -0,0 +1,151 @@ +import argparse +import os +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" +import sys +import torch +import logging +import torch.distributed as dist +from torch.utils.data import DataLoader, distributed +from tqdm import tqdm +import time +import numpy as np +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from data_utils.v2a_utils.feature_utils_288 import FeaturesUtils +from data_utils.v2a_utils.thinksound_288_al import VGGSound +from torch.utils.data.dataloader import default_collate + +def setup_distributed(rank, world_size): + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + +def cleanup_distributed(): + dist.destroy_process_group() + + +def error_avoidance_collate(batch): + batch = list(filter(lambda x: x is not None, batch)) + if len(batch) == 0: + return None # 或 return {} + return default_collate(batch) + +def process_batch(data, model, rank, inference_mode=False): + output = { + 'caption_cot': data['caption_cot'], + 'latent': [], + 'global_video_features': [], + 'video_features': [], + 'global_text_features': [], + 'text_features': [], + 'sync_features': [], + } + #start_time = time.time() + + with torch.no_grad(): + + text_features = model.module.encode_t5_text(data['caption_cot']) + output['text_features'] = text_features.detach().cpu().numpy() + if not inference_mode: + latent = model.module.encode_audio(data['audio'].cuda(rank, non_blocking=True)) + output['latent'] = latent.detach().cpu().numpy() + else: + output['latent'] = [None] * len(text_features) + + video_feat,frame_embed,_,text_feat= model.module.encode_video_and_text_with_videoprism(data['clip_video'], data['caption_cot']) + + output['global_video_features'].append(np.array(video_feat)) + output['video_features'].append(np.array(frame_embed)) + output['global_text_features'].append(np.array(text_feat)) + + sync_video = data['sync_video'].cuda(rank, non_blocking=True) + sync_features = model.module.encode_video_with_sync(sync_video) + output['sync_features'] = sync_features.detach().cpu().numpy() + + + + return output + + +def save_outputs(output, ids, save_dir, add_audio_path=None, add_video_path=None): + for i, sample_id in enumerate(ids): + np.savez( + os.path.join(save_dir, f"{sample_id}.npz"), + id=sample_id, + audio_path=os.path.join(add_audio_path,f"{sample_id}.wav") if add_audio_path is not None else None, + video_path=os.path.join(add_video_path,f"{sample_id}.mp4") if add_video_path is not None else None, + caption_cot=output['caption_cot'][i], + latent=output['latent'][i], + global_video_features=output['global_video_features'][i], + video_features=output['video_features'][i], + global_text_features=output['global_text_features'][i], + text_features=output['text_features'][i], + sync_features=output['sync_features'][i], + ) + + +def main(args): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + setup_distributed(rank, world_size) + + dataset = VGGSound( + root=args.root, + tsv_path=args.tsv_path, + sample_rate=args.sample_rate, + start_row=args.start_row, + end_row=args.end_row, + save_dir=args.save_dir, + inference_mode = args.inference_mode + ) + + os.makedirs(args.save_dir, exist_ok=True) + + sampler = distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank) + dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=32, + drop_last=False, collate_fn=error_avoidance_collate, pin_memory=True) + + model = FeaturesUtils( + vae_ckpt=args.vae_ckpt if not args.inference_mode else None, + vae_config=args.vae_config, + enable_conditions=True, + synchformer_ckpt=args.synchformer_ckpt + ) + model = model.eval().cuda(rank) + + torch.compile(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) + + for data in tqdm(dataloader, desc="Processing", unit="batch"): + if data is None: + continue + ids = data['id'] + try: + output = process_batch(data, model, rank, args.inference_mode) + save_outputs(output, ids, args.save_dir, args.add_audio_path, args.add_video_path) + except Exception as e: + logging.error(f"Error processing sample IDs {ids}: {e}") + + cleanup_distributed() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') + parser = argparse.ArgumentParser(description='Extract Video Training Latents') + parser.add_argument('--root', default='videos') + parser.add_argument('--tsv_path', default='cot_coarse/cot.csv') + parser.add_argument('--save-dir', default='results') + parser.add_argument('--sample_rate', type=int, default=44100, help='Audio sample rate') + parser.add_argument('--vae_ckpt', type=str, default='ckpts/vae.ckpt', help='Path to the VAE checkpoint') + parser.add_argument('--vae_config', type=str, default='ThinkSound/configs/model_configs/stable_audio_2_0_vae.json', help='Path to the VAE configuration file') + parser.add_argument('--synchformer_ckpt', type=str, default='ckpts/synchformer_state_dict.pth', help='Path to the Synchformer checkpoint') + parser.add_argument('--start-row', '-s', type=int, default=0, help='Start row index') + parser.add_argument('--end-row', '-e', type=int, default=None, help='End row index') + parser.add_argument('--add_audio_path', default=None, + help='Provide the original audio file path required for ITD reward in GRPO') + parser.add_argument('--add_video_path', default=None, + help='Provide the video path file required for Synchformer reward in GRPO') + parser.add_argument('--inference_mode', default=False, help='inference_mode') + + args = parser.parse_args() + main(args) diff --git a/data_utils/v2a_utils/__init__.py b/data_utils/v2a_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data_utils/v2a_utils/audio_text_dataset.py b/data_utils/v2a_utils/audio_text_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5a692380ef635f1ea44902b8f13a9eac5a7b7de4 --- /dev/null +++ b/data_utils/v2a_utils/audio_text_dataset.py @@ -0,0 +1,173 @@ +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import numpy as np + +import logging +log = logging.getLogger() + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + + +class Audio_Text(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.labels = [] + self.videos = [] + self.cots = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.pth'): continue + label = record['caption'] + # if id in videos: + self.labels.append(label) + # print(label,'debug1!!!!!!!!!') + self.cots.append(record['caption_cot']) + # self.labels[id] = label + self.videos.append(id) + # else: + # missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.resampler = {} + + def sample(self, idx: int): + video_id = self.videos[idx] + label = self.labels[idx] + cot = self.cots[idx] + audio_path = os.path.join(self.root, f'{video_id}.wav') + if not os.path.exists(audio_path): + audio_path = os.path.join(self.root, f'{video_id}.flac') + if not os.path.exists(audio_path): + raise RuntimeError(f'Audio is not exist {audio_path}') + audio_chunk, sample_rate = torchaudio.load(audio_path) + if len(audio_chunk.shape) != 2: + raise RuntimeError(f'error audio shape {video_id}') + + abs_max = audio_chunk[0].abs().max() + + if abs_max <= 1e-6: + if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6: + audio_chunk = audio_chunk[1:2] + else: + raise RuntimeError(f'Audio is silent {video_id}') + + # ensure the stereo audio + if audio_chunk.shape[0] < 2: + audio_chunk = audio_chunk.repeat(2, 1) + elif audio_chunk.shape[0] > 2: + audio_chunk = audio_chunk[:2] + + # resample + if sample_rate == self.sample_rate: + audio_chunk = audio_chunk + else: + if sample_rate not in self.resampler: + # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best + self.resampler[sample_rate] = torchaudio.transforms.Resample( + sample_rate, + self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + audio_chunk = self.resampler[sample_rate](audio_chunk) + + if audio_chunk.shape[1] < self.expected_audio_length: + # zero-padding audio + padding_length = self.expected_audio_length - audio_chunk.shape[1] + # 创建 padding 张量,大小为 [batch_size, padding_length],值为0 + padding = torch.zeros(audio_chunk.shape[0], padding_length) + # 将原始音频和 padding 沿第 1 维度拼接在一起 + audio_chunk = torch.cat((audio_chunk, padding), dim=1) + # raise RuntimeError(f'Audio too short {video_id}') + audio_chunk = audio_chunk[:,:self.expected_audio_length] + assert audio_chunk.shape == (2, 397312), f'error shape:{video_id},{audio_chunk.shape}' + # print(label,'debug2!!!!!!!!!') + data = { + 'id': video_id, + 'caption': label, + 'caption_cot': cot, + 'audio': audio_chunk, + } + + return data + + def __getitem__(self, idx: int): + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = VGGSound( +# root="data/vggsound/video/train", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_224_latents_text/train" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/audioset_224.py b/data_utils/v2a_utils/audioset_224.py new file mode 100644 index 0000000000000000000000000000000000000000..f363d157ebc72d3b929937f6de031eec63315cf4 --- /dev/null +++ b/data_utils/v2a_utils/audioset_224.py @@ -0,0 +1,315 @@ +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import numpy as np + +import logging +log = logging.getLogger() + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + +def save_tensor_as_image(tensor, save_path): + """ + 将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。 + + :param tensor: 输入的 NumPy 数组 (1, 3, H, W)。 + :param save_path: 图片保存路径。 + """ + # # 移除批次维度,变成 (3, H, W) + # tensor = tensor.squeeze(0) + + # 交换轴顺序,变为 (H, W, 3) + image_array = np.transpose(tensor, (1, 2, 0)) + + # 检查数组是否为合适的数据类型 + if image_array.dtype != np.uint8: + # 如果不是 uint8,首先标准化,然后转换 + image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255 + image_array = image_array.astype(np.uint8) + + # 创建图像对象 + image = Image.fromarray(image_array) + + # 保存图片 + image.save(save_path) + print(f"Image saved to {save_path}") + +def pad_to_square(video_tensor): + # 验证输入的形状 + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (l, c, h, w)") + + l, c, h, w = video_tensor.shape + max_side = max(h, w) + + # 计算每一维度需要的填充量:(left, right, top, bottom) + pad_h = max_side - h + pad_w = max_side - w + + # 创建padding tuple (left, right, top, bottom) + # 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充 + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + + # 使用F.pad对视频张量进行填充操作 + # 填充参数为 (left, right, top, bottom) + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + + return video_padded + +class Audioset(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.labels = [] + self.videos = [] + self.caption_t5s = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.pth'): continue + label = record['label'] + caption_t5 = record['caption_t5'] + # if id in videos: + self.labels.append(label) + # self.labels[id] = label + self.videos.append(id) + self.caption_t5s.append(caption_t5) + # else: + # missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Lambda(pad_to_square), # 先填充为正方形 + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge") + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[idx] + caption_t5 = self.caption_t5s[idx] + + reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + # reader.add_basic_audio_stream(frames_per_chunk=2**30,) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + audio_path = os.path.join("dataset/3_Audioset/audios/sound",video_id+'.wav') + assert os.path.exists(audio_path), f'{audio_path} not exists' + audio_chunk, sr = torchaudio.load(audio_path) + # audio_chunk = data_chunk[2] + if len(audio_chunk.shape) != 2: + raise RuntimeError(f'error audio shape {video_id}') + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + sample_rate = int(sr) + # audio_chunk = audio_chunk.transpose(0, 1) + abs_max = audio_chunk[0].abs().max() + # audio_chunk = audio_chunk.mean(dim=0) # mono + # if self.normalize_audio: + # abs_max = audio_chunk.abs().max() + # audio_chunk = audio_chunk / abs_max * 0.95 + if abs_max <= 1e-6: + if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6: + audio_chunk = audio_chunk[1:2] + else: + raise RuntimeError(f'Audio is silent {video_id}') + + # ensure the stereo audio + if audio_chunk.shape[0] < 2: + audio_chunk = audio_chunk.repeat(2, 1) + + # resample + if sample_rate == self.sample_rate: + audio_chunk = audio_chunk + else: + if sample_rate not in self.resampler: + # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best + self.resampler[sample_rate] = torchaudio.transforms.Resample( + sample_rate, + self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + audio_chunk = self.resampler[sample_rate](audio_chunk) + + if audio_chunk.shape[1] < self.expected_audio_length: + # zero-padding audio + padding_length = self.expected_audio_length - audio_chunk.shape[1] + # 创建 padding 张量,大小为 [batch_size, padding_length],值为0 + padding = torch.zeros(audio_chunk.shape[0], padding_length) + # 将原始音频和 padding 沿第 1 维度拼接在一起 + audio_chunk = torch.cat((audio_chunk, padding), dim=1) + # raise RuntimeError(f'Audio too short {video_id}') + audio_chunk = audio_chunk[:,:self.expected_audio_length] + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + # import ipdb + # ipdb.set_trace() + if clip_chunk.shape[0] != self.clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = self.clip_expected_length - current_length + + # Check that padding needed is no more than 2 + assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed' + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + log.info(last_frame.shape) + # Repeat the last frame to reach the expected length + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + # raise RuntimeError(f'CLIP video wrong length {video_id}, ' + # f'expected {self.clip_expected_length}, ' + # f'got {clip_chunk.shape[0]}') + + # save_image(clip_chunk[0] / 255.0,'ori.png') + clip_chunk = pad_to_square(clip_chunk) + # save_image(clip_chunk[0] / 255.0,'square.png') + # clip_chunk = self.clip_transform(clip_chunk) + # import ipdb + # ipdb.set_trace() + clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"] + # log.info(clip_chunk.shape) + # save_tensor_as_image(clip_chunk[0].numpy(),'scale.png') + # log.info(clip_chunk[0]) + # clip_chunk = outputs + # text_ids = outputs["input_ids"] + # temp_img = clip_chunk[0].permute(1, 2, 0) * 255 + # save_image(clip_chunk[0],'scale.png') + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + # padding using the last frame, but no more than 2 + current_length = sync_chunk.shape[0] + last_frame = sync_chunk[-1] + # 重复最后一帧以进行填充 + padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1) + assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}' + sync_chunk = torch.cat((sync_chunk, padding), dim=0) + # raise RuntimeError(f'Sync video wrong length {video_id}, ' + # f'expected {self.sync_expected_length}, ' + # f'got {sync_chunk.shape[0]}') + + sync_chunk = self.sync_transform(sync_chunk) + assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \ + and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape' + data = { + 'id': video_id, + 'caption': label, + 'caption_t5': caption_t5, + 'audio': audio_chunk, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = Audioset( +# root="dataset/3_Audioset/video/sound", +# tsv_path="dataset/3_Audioset/split_txt/unbalanced_sound_filtered_aligned_novgg_noout.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="dataset/3_Audioset/video_text_latents/" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/audioset_video_224.py b/data_utils/v2a_utils/audioset_video_224.py new file mode 100644 index 0000000000000000000000000000000000000000..ecc9f94332c74b5bd03790954f539e6ab570deb9 --- /dev/null +++ b/data_utils/v2a_utils/audioset_video_224.py @@ -0,0 +1,268 @@ +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import numpy as np + +import logging +log = logging.getLogger() + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + +def save_tensor_as_image(tensor, save_path): + """ + 将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。 + + :param tensor: 输入的 NumPy 数组 (1, 3, H, W)。 + :param save_path: 图片保存路径。 + """ + # # 移除批次维度,变成 (3, H, W) + # tensor = tensor.squeeze(0) + + # 交换轴顺序,变为 (H, W, 3) + image_array = np.transpose(tensor, (1, 2, 0)) + + # 检查数组是否为合适的数据类型 + if image_array.dtype != np.uint8: + # 如果不是 uint8,首先标准化,然后转换 + image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255 + image_array = image_array.astype(np.uint8) + + # 创建图像对象 + image = Image.fromarray(image_array) + + # 保存图片 + image.save(save_path) + print(f"Image saved to {save_path}") + +def pad_to_square(video_tensor): + # 验证输入的形状 + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (l, c, h, w)") + + l, c, h, w = video_tensor.shape + max_side = max(h, w) + + # 计算每一维度需要的填充量:(left, right, top, bottom) + pad_h = max_side - h + pad_w = max_side - w + + # 创建padding tuple (left, right, top, bottom) + # 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充 + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + + # 使用F.pad对视频张量进行填充操作 + # 填充参数为 (left, right, top, bottom) + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + + return video_padded + +class Audioset(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + duration_sec: float = 10.0, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.captions = [] + self.videos = [] + self.caption_t5s = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + with open(tsv_path.replace('.csv','.txt')) as file: + paths = file.readlines() + for record, path in zip(df_list,paths): + id = Path(record['id']).stem + # if os.path.exists(f'{save_dir}/{id}.pth'): continue + caption = record['caption'] + caption_t5 = record['caption_t5'] + path = path.strip() + part = Path(path).parent + video_id = Path(path).stem[1:] + video_path = os.path.join('dataset/3_Audioset/video',part,f'{video_id}.mp4') + assert os.path.exists(video_path), 'video must exist' + # if id in videos: + self.captions.append(caption) + self.caption_t5s.append(caption_t5) + # self.labels[id] = label + self.videos.append(video_path) + # else: + # missing_videos.append(id) + assert len(self.captions) == len(self.caption_t5s) and len(self.captions) == len(self.videos), 'error length' + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.duration_sec = duration_sec + + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Lambda(pad_to_square), # 先填充为正方形 + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge") + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_path = self.videos[idx] + video_id = 'Y'+str(Path(video_path).stem) + caption = self.captions[idx] + caption_t5 = self.caption_t5s[idx] + + reader = StreamingMediaDecoder(video_path) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + # if clip_chunk.shape[0] < self.clip_expected_length: + # raise RuntimeError( + # f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}' + # ) + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + # if sync_chunk.shape[0] < self.sync_expected_length: + # raise RuntimeError( + # f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}' + # ) + + + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + # import ipdb + # ipdb.set_trace() + if clip_chunk.shape[0] != self.clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = self.clip_expected_length - current_length + + # Check that padding needed is no more than 2 + assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed' + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + log.info(clip_chunk.shape) + # Repeat the last frame to reach the expected length + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + # raise RuntimeError(f'CLIP video wrong length {video_id}, ' + # f'expected {self.clip_expected_length}, ' + # f'got {clip_chunk.shape[0]}') + + # save_image(clip_chunk[0] / 255.0,'ori.png') + clip_chunk = pad_to_square(clip_chunk) + # save_image(clip_chunk[0] / 255.0,'square.png') + # clip_chunk = self.clip_transform(clip_chunk) + # import ipdb + # ipdb.set_trace() + clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"] + # log.info(clip_chunk.shape) + # save_tensor_as_image(clip_chunk[0].numpy(),'scale.png') + # log.info(clip_chunk[0]) + # clip_chunk = outputs + # text_ids = outputs["input_ids"] + # temp_img = clip_chunk[0].permute(1, 2, 0) * 255 + # save_image(clip_chunk[0],'scale.png') + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + # padding using the last frame, but no more than 2 + current_length = sync_chunk.shape[0] + last_frame = sync_chunk[-1] + # 重复最后一帧以进行填充 + padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1) + assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}' + sync_chunk = torch.cat((sync_chunk, padding), dim=0) + # raise RuntimeError(f'Sync video wrong length {video_id}, ' + # f'expected {self.sync_expected_length}, ' + # f'got {sync_chunk.shape[0]}') + + sync_chunk = self.sync_transform(sync_chunk) + assert clip_chunk.shape[0] == self.clip_expected_length and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape' + data = { + 'id': video_id, + 'caption': caption, + 'caption_t5': caption_t5, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.captions) + + +# dataset = VGGSound( +# root="data/vggsound/video/train", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_224_latents_text/train" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/feature_utils_224.py b/data_utils/v2a_utils/feature_utils_224.py new file mode 100644 index 0000000000000000000000000000000000000000..3bcb012470a9ebbdbd34e8e37a920aa359d44b40 --- /dev/null +++ b/data_utils/v2a_utils/feature_utils_224.py @@ -0,0 +1,182 @@ +from typing import Literal, Optional +import json +import open_clip +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from open_clip import create_model_from_pretrained +from torchvision.transforms import Normalize +from ThinkSound.models.factory import create_model_from_config +from ThinkSound.models.utils import load_ckpt_state_dict +from ThinkSound.models.utils import copy_state_dict +from transformers import AutoModel +from transformers import AutoProcessor +from transformers import T5EncoderModel, AutoTokenizer +import logging +from data_utils.ext.synchformer import Synchformer + +log = logging.getLogger() + +def patch_clip(clip_model): + # a hack to make it output last hidden states + # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269 + def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = text_outputs[0] + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features, last_hidden_state + + clip_model.get_text_features = new_get_text_features.__get__(clip_model) + return clip_model + + +class FeaturesUtils(nn.Module): + + def __init__( + self, + *, + vae_ckpt: Optional[str] = None, + vae_config: Optional[str] = None, + synchformer_ckpt: Optional[str] = None, + enable_conditions: bool = True, + need_vae_encoder: bool = True, + ): + super().__init__() + + if enable_conditions: + self.clip_model = AutoModel.from_pretrained("facebook/metaclip-h14-fullcc2.5b") + self.clip_model = patch_clip(self.clip_model) + self.t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl") + self.t5_model = T5EncoderModel.from_pretrained("google/t5-v1_1-xl") + self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b") + # self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073], + # std=[0.26862954, 0.26130258, 0.27577711]) + self.synchformer = Synchformer() + self.synchformer.load_state_dict( + torch.load(synchformer_ckpt, weights_only=True, map_location='cpu')) + + # self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14' + else: + self.clip_model = None + self.synchformer = None + self.tokenizer = None + + if vae_ckpt is not None: + with open(vae_config) as f: + vae_config = json.load(f) + self.vae = create_model_from_config(vae_config) + print(f"Loading model checkpoint from {vae_ckpt}") + # Load checkpoint + copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt,prefix='autoencoder.'))#,prefix='autoencoder.' + else: + self.tod = None + + def compile(self): + if self.clip_model is not None: + self.clip_model.encode_image = torch.compile(self.clip_model.encode_image) + self.clip_model.encode_text = torch.compile(self.clip_model.encode_text) + if self.synchformer is not None: + self.synchformer = torch.compile(self.synchformer) + + + def train(self, mode: bool) -> None: + return super().train(False) + + @torch.inference_mode() + def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor: + assert self.clip_model is not None, 'CLIP is not loaded' + # x: (B, T, C, H, W) H/W: 384 + b, t, c, h, w = x.shape + + assert c == 3 and h == 224 and w == 224 + # x = self.clip_preprocess(x) + x = rearrange(x, 'b t c h w -> (b t) c h w') + outputs = [] + if batch_size < 0: + batch_size = b * t + for i in range(0, b * t, batch_size): + outputs.append(self.clip_model.get_image_features(x[i:i + batch_size])) + x = torch.cat(outputs, dim=0) + # x = self.clip_model.encode_image(x, normalize=True) + x = rearrange(x, '(b t) d -> b t d', b=b) + return x + + @torch.inference_mode() + def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor: + assert self.synchformer is not None, 'Synchformer is not loaded' + # x: (B, T, C, H, W) H/W: 384 + b, t, c, h, w = x.shape + # import ipdb + # ipdb.set_trace() + assert c == 3 and h == 224 and w == 224 + + # partition the video + segment_size = 16 + step_size = 8 + num_segments = (t - segment_size) // step_size + 1 + segments = [] + for i in range(num_segments): + segments.append(x[:, i * step_size:i * step_size + segment_size]) + x = torch.stack(segments, dim=1) # (B, S, T, C, H, W) + + outputs = [] + if batch_size < 0: + batch_size = b + x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w') + for i in range(0, b * num_segments, batch_size): + outputs.append(self.synchformer(x[i:i + batch_size])) + x = torch.cat(outputs, dim=0) + x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b) + return x + + @torch.inference_mode() + def encode_text(self, text: list[str]) -> torch.Tensor: + assert self.clip_model is not None, 'CLIP is not loaded' + # assert self.tokenizer is not None, 'Tokenizer is not loaded' + # x: (B, L) + tokens = self.clip_processor(text=text, truncation=True, max_length=77, padding="max_length",return_tensors="pt").to(self.device) + return self.clip_model.get_text_features(**tokens) + + @torch.inference_mode() + def encode_t5_text(self, text: list[str]) -> torch.Tensor: + assert self.t5_model is not None, 'T5 model is not loaded' + assert self.t5_tokenizer is not None, 'T5 Tokenizer is not loaded' + # x: (B, L) + inputs = self.t5_tokenizer(text, + truncation=True, + max_length=77, + padding="max_length", + return_tensors="pt").to(self.device) + return self.t5_model(**inputs).last_hidden_state + + @torch.inference_mode() + def encode_audio(self, x) -> torch.Tensor: + x = self.vae.encode(x) + return x + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype diff --git a/data_utils/v2a_utils/feature_utils_224_audio.py b/data_utils/v2a_utils/feature_utils_224_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..b869edf4620f2ae6c1b928d3bf844292b3dcba45 --- /dev/null +++ b/data_utils/v2a_utils/feature_utils_224_audio.py @@ -0,0 +1,133 @@ +from typing import Literal, Optional +import json +# import open_clip +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +# from open_clip import create_model_from_pretrained +from torchvision.transforms import Normalize +from ThinkSound.models.factory import create_model_from_config +from ThinkSound.models.utils import load_ckpt_state_dict +from ThinkSound.training.utils import copy_state_dict +from transformers import AutoModel +from transformers import AutoProcessor +from transformers import T5EncoderModel, AutoTokenizer +import logging + + +from data_utils.ext.synchformer import Synchformer + +log = logging.getLogger() + +def patch_clip(clip_model): + # a hack to make it output last hidden states + # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269 + def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = text_outputs[0] + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features, last_hidden_state + + clip_model.get_text_features = new_get_text_features.__get__(clip_model) + return clip_model + + +class FeaturesUtils(nn.Module): + + def __init__( + self, + *, + vae_ckpt: Optional[str] = None, + vae_config: Optional[str] = None, + synchformer_ckpt: Optional[str] = None, + enable_conditions: bool = True, + need_vae_encoder: bool = True, + ): + super().__init__() + + if enable_conditions: + self.clip_model = AutoModel.from_pretrained("metaclip-h14-fullcc2.5b") + self.clip_model = patch_clip(self.clip_model) + self.t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl") + self.t5_model = T5EncoderModel.from_pretrained("google/t5-v1_1-xl") + self.clip_processor = AutoProcessor.from_pretrained("metaclip-h14-fullcc2.5b") + # self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073], + # std=[0.26862954, 0.26130258, 0.27577711]) + # self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14' + else: + self.clip_model = None + self.synchformer = None + self.tokenizer = None + + if vae_ckpt is not None: + with open(vae_config) as f: + vae_config = json.load(f) + self.vae = create_model_from_config(vae_config) + print(f"Loading model checkpoint from {vae_ckpt}") + # Load checkpoint + copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt,prefix='autoencoder.'))#,prefix='autoencoder.' + else: + self.tod = None + + def compile(self): + if self.clip_model is not None: + self.clip_model.encode_image = torch.compile(self.clip_model.encode_image) + self.clip_model.encode_text = torch.compile(self.clip_model.encode_text) + if self.synchformer is not None: + self.synchformer = torch.compile(self.synchformer) + + + def train(self, mode: bool) -> None: + return super().train(False) + + @torch.inference_mode() + def encode_text(self, text: list[str]) -> torch.Tensor: + assert self.clip_model is not None, 'CLIP is not loaded' + # assert self.tokenizer is not None, 'Tokenizer is not loaded' + # x: (B, L) + tokens = self.clip_processor(text=text, truncation=True, max_length=77, padding="max_length",return_tensors="pt").to(self.device) + return self.clip_model.get_text_features(**tokens) + + @torch.inference_mode() + def encode_t5_text(self, text: list[str]) -> torch.Tensor: + assert self.t5_model is not None, 'T5 model is not loaded' + assert self.t5_tokenizer is not None, 'T5 Tokenizer is not loaded' + # x: (B, L) + inputs = self.t5_tokenizer(text, + truncation=True, + max_length=77, + padding="max_length", + return_tensors="pt").to(self.device) + return self.t5_model(**inputs).last_hidden_state + + @torch.inference_mode() + def encode_audio(self, x) -> torch.Tensor: + x = self.vae.encode(x) + return x + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype diff --git a/data_utils/v2a_utils/feature_utils_288.py b/data_utils/v2a_utils/feature_utils_288.py new file mode 100644 index 0000000000000000000000000000000000000000..9b12e224087cc0afad7dccae78bc9c96fbcf5260 --- /dev/null +++ b/data_utils/v2a_utils/feature_utils_288.py @@ -0,0 +1,201 @@ +from typing import Literal, Optional +import json +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torchvision.transforms import Normalize +from ThinkSound.models.factory import create_model_from_config +from ThinkSound.models.utils import load_ckpt_state_dict +import einshape +import sys +import os +from transformers import AutoTokenizer,AutoModelForSeq2SeqLM,AutoModel,T5EncoderModel +import logging +import os +import numpy as np +log = logging.getLogger() + +import jax +import jax.numpy as jnp +from videoprism import models as vp + +from data_utils.ext.synchformer import Synchformer + + +def copy_state_dict(model, state_dict): + """Load state_dict to model, but only for keys that match exactly. + + Args: + model (nn.Module): model to load state_dict. + state_dict (OrderedDict): state_dict to load. + """ + model_state_dict = model.state_dict() + missing_keys = [] + unexpected_keys = [] + for key in state_dict: + if key not in model_state_dict: + unexpected_keys.append(key) + elif state_dict[key].shape != model_state_dict[key].shape: + unexpected_keys.append(key) + + for key in model_state_dict: + if key not in state_dict: + missing_keys.append(key) + + print("Missing keys in state_dict:", missing_keys) + print("Unexpected keys in state_dict:", unexpected_keys) + for key in state_dict: + if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape: + if isinstance(state_dict[key], torch.nn.Parameter): + # backwards compatibility for serialized parameters + state_dict[key] = state_dict[key].data + model_state_dict[key] = state_dict[key] + + model.load_state_dict(model_state_dict, strict=False) + + +class FeaturesUtils(nn.Module): + + def __init__( + self, + *, + vae_ckpt: Optional[str] = None, + vae_config: Optional[str] = None, + synchformer_ckpt: Optional[str] = None, + enable_conditions: bool = True, + need_vae_encoder: bool = True, + ): + super().__init__() + + if enable_conditions: + + self.t5 = AutoModelForSeq2SeqLM.from_pretrained("google/t5gemma-l-l-ul2-it").get_encoder() + self.t5tokenizer = AutoTokenizer.from_pretrained("google/t5gemma-l-l-ul2-it") + + self.synchformer = Synchformer() + self.synchformer.load_state_dict( + torch.load(synchformer_ckpt, weights_only=True, map_location='cpu')) + + + else: + + self.synchformer = None + self.tokenizer = None + + if vae_ckpt is not None: + with open(vae_config) as f: + vae_config = json.load(f) + self.vae = create_model_from_config(vae_config) + print(f"Loading model checkpoint from {vae_ckpt}") + # Load checkpoint + copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt,prefix='autoencoder.'))#,prefix='autoencoder.' + + + def _init_jax(self): + if hasattr(self, "flax_model"): + return # already init + backend = jax.default_backend() + if backend != 'gpu': + log.warning( + f"JAX is running on {backend.upper()} instead of GPU! " + f"Performance will be significantly degraded." + ) + self.jax_dev = jax.devices()[0] # CPU只有一个设备 + else: + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + devices = jax.devices() + device_idx = local_rank % len(devices) + self.jax_dev = devices[device_idx] + + model_name = 'videoprism_lvt_public_v1_large' + self.flax_model = vp.get_model(model_name) + state = vp.load_pretrained_weights(model_name) + self.loaded_state = jax.device_put(state, device=self.jax_dev) + self.text_tokenizer = vp.load_text_tokenizer('c4_en') + + self.apply_jit = jax.jit(lambda x, y, z: self.flax_model.apply( + self.loaded_state, x, y, z, train=False, return_intermediate=True + ), device=self.jax_dev) + + # def train(self, mode: bool) -> None: + # return super().train(False) + + def encode_video_and_text_with_videoprism(self, x: torch.Tensor, cot: str, batch_size: int = -1) -> torch.Tensor: + self._init_jax() + + b, t, h, w, c = x.shape + assert c == 3 and h == 288 and w == 288 + text_ids, text_paddings = vp.tokenize_texts(self.text_tokenizer, cot) + + x = jax.device_put(x.cpu().numpy(), device=self.jax_dev) + + text_ids = jax.device_put(text_ids, device=self.jax_dev) + text_paddings = jax.device_put(text_paddings, device=self.jax_dev) + + video_embeddings, text_embeddings, outputs = self.apply_jit( + x, text_ids, text_paddings + ) + + frame_embed = outputs['frame_embeddings'] + spatialtemporal_embed = einshape.jax_einshape( + 'b(ts)d->btsd', outputs['spatiotemporal_features'], t=frame_embed.shape[0] + ) + + return video_embeddings[0],frame_embed[0],spatialtemporal_embed[0][0],text_embeddings + + @torch.inference_mode() + def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor: + assert self.synchformer is not None, 'Synchformer is not loaded' + + b, t, c, h, w = x.shape + + + assert c == 3 and h == 224 and w == 224 + + + segment_size = 16 + step_size = 8 + num_segments = (t - segment_size) // step_size + 1 + segments = [] + for i in range(num_segments): + segments.append(x[:, i * step_size:i * step_size + segment_size]) + x = torch.stack(segments, dim=1) # (B, S, T, C, H, W) + + outputs = [] + if batch_size < 0: + batch_size = b + x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w') + for i in range(0, b * num_segments, batch_size): + outputs.append(self.synchformer(x[i:i + batch_size])) + x = torch.cat(outputs, dim=0) + x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b) + return x + + @torch.inference_mode() + def encode_t5_text(self, text: list[str]) -> torch.Tensor: + assert self.t5 is not None, 'T5 model is not loaded' + assert self.t5tokenizer is not None, 'T5 Tokenizer is not loaded' + # x: (B, L) + inputs = self.t5tokenizer(text, + padding=True, + truncation=False, + return_tensors="pt").to(self.device) + + text_features = self.t5(**inputs).last_hidden_state + return text_features + + @torch.inference_mode() + def encode_audio(self, x) -> torch.Tensor: + x = self.vae.encode(x) + return x + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + diff --git a/data_utils/v2a_utils/thinksound_288_al.py b/data_utils/v2a_utils/thinksound_288_al.py new file mode 100644 index 0000000000000000000000000000000000000000..f99991c7d03aeafb2e3ba38821c5d8555d54dfa3 --- /dev/null +++ b/data_utils/v2a_utils/thinksound_288_al.py @@ -0,0 +1,221 @@ +import logging +log = logging.getLogger() +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image +from transformers import AutoProcessor +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +import mediapy +import torch.nn.functional as F +import numpy as np +import subprocess +from torchvision.utils import save_image +try: + from moviepy import VideoFileClip +except ImportError: + from moviepy.editor import VideoFileClip + +_CLIP_FPS = 4 +_CLIP_SIZE = 288 +_SYNC_FPS = 25 +_SYNC_SIZE = 224 + +def pad_to_square(video_tensor): + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (l, c, h, w)") + + l, c, h, w = video_tensor.shape + max_side = max(h, w) + + pad_h = max_side - h + pad_w = max_side - w + + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + + video_tensor = F.interpolate(video_padded, size=(_CLIP_SIZE, _CLIP_SIZE), mode='bilinear', align_corners=False) + return video_tensor + + +def get_video_duration(video_path): + video = VideoFileClip(str(video_path)) + return video.duration + + +class VGGSound(Dataset): + def __init__( + self, + root: Path, + *, + tsv_path: Path, + sample_rate: int = 44100, + normalize_audio: bool = False, + start_row: int = None, + end_row: int = None, + save_dir: str = '', + use_variable_length: bool = False, + video_encoder: str = 'videoprism', + video_resolution: int = _CLIP_SIZE, + inference_mode: bool = False, + video_fps: int = _CLIP_FPS + ): + self.inference_mode = inference_mode + self.sample_rate=sample_rate + self.root = Path(root) + self.normalize_audio = normalize_audio + self.use_variable_length = use_variable_length + self.video_encoder = video_encoder + self.video_resolution = video_resolution + self.video_fps = video_fps + + + self.videos = [] + self.caption_cot = [] + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.npz'): continue + + caption_cot = record['caption_cot'] + + if not os.path.exists(os.path.join(self.root, id)+".mp4"): + continue + + self.videos.append(id) + self.caption_cot.append(caption_cot) + + + log.info(f'processing {len(self.videos)} videos') + + + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + + video_id = self.videos[idx] + caption_cot = self.caption_cot[idx] + duration_sec= get_video_duration(self.root / (video_id + '.mp4')) + reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + if not self.inference_mode: + reader.add_basic_audio_stream(frames_per_chunk=2**30,) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + + sync_chunk = data_chunk[1] + if not self.inference_mode: + audio_chunk = data_chunk[2] + audio_chunk = audio_chunk.transpose(0, 1) + else: + num_samples = int(self.sample_rate * duration_sec) + audio_chunk = torch.randn((2, num_samples)) + + if len(audio_chunk.shape) != 2: + raise RuntimeError(f'error audio shape {video_id}') + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + if not self.inference_mode: + sample_rate = int(reader.get_out_stream_info(2).sample_rate) + else: + sample_rate = self.sample_rate + + + abs_max = audio_chunk[0].abs().max() + + if self.normalize_audio: + abs_max = audio_chunk.abs().max() + audio_chunk = audio_chunk / abs_max * 0.95 + + clip_expected_length = int(_CLIP_FPS * duration_sec) + sync_expected_length = int(_SYNC_FPS * duration_sec) + + clip_chunk = clip_chunk[:clip_expected_length] + + if clip_chunk.shape[0] != clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = clip_expected_length - current_length + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + + + + clip_chunk = pad_to_square(clip_chunk) + clip_chunk = clip_chunk.permute(0, 2, 3, 1) + + clip_chunk = mediapy.to_float01(clip_chunk) + + sync_chunk = sync_chunk[:sync_expected_length] + if sync_chunk.shape[0] != sync_expected_length: + # padding using the last frame, but no more than 2 + current_length = sync_chunk.shape[0] + last_frame = sync_chunk[-1] + + padding = last_frame.repeat(sync_expected_length - current_length, 1, 1, 1) + sync_chunk = torch.cat((sync_chunk, padding), dim=0) + + + sync_chunk = self.sync_transform(sync_chunk) + + data = { + 'id': video_id, + 'caption_cot': caption_cot, + 'audio': audio_chunk, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + } + + return data + + + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + logging.error(f'Error loading {self.videos[idx]}: {e}') + return None + + def __len__(self) -> int: + return len(self.videos) + + + diff --git a/data_utils/v2a_utils/vggsound.py b/data_utils/v2a_utils/vggsound.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ae03f7af3eb263c649f7d8855b9069db988b2b --- /dev/null +++ b/data_utils/v2a_utils/vggsound.py @@ -0,0 +1,259 @@ +import logging +import os +from pathlib import Path +from typing import Optional, Union + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image + +log = logging.getLogger() + +_CLIP_SIZE = 384 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + + +class VGGSound(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + videos = sorted(os.listdir(self.root)) + videos = set([Path(v).stem for v in videos]) # remove extensions + # videos = [] + self.labels = [] + self.videos = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.pth'): continue + label = record['caption'] + if id in videos: + # self.labels.append(label) + self.labels[id] = label + self.videos.append(id) + else: + missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[idx] + + reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + reader.add_basic_audio_stream(frames_per_chunk=2**30,) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + audio_chunk = data_chunk[2] + if len(audio_chunk.shape) != 2: + raise RuntimeError(f'error audio shape {video_id}') + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + # if clip_chunk.shape[0] < self.clip_expected_length: + # raise RuntimeError( + # f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}' + # ) + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + # if sync_chunk.shape[0] < self.sync_expected_length: + # raise RuntimeError( + # f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}' + # ) + # import ipdb + # ipdb.set_trace() + # process audio + sample_rate = int(reader.get_out_stream_info(2).sample_rate) + audio_chunk = audio_chunk.transpose(0, 1) + abs_max = audio_chunk[0].abs().max() + # audio_chunk = audio_chunk.mean(dim=0) # mono + # if self.normalize_audio: + # abs_max = audio_chunk.abs().max() + # audio_chunk = audio_chunk / abs_max * 0.95 + if abs_max <= 1e-6: + if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6: + audio_chunk = audio_chunk[1:2] + else: + raise RuntimeError(f'Audio is silent {video_id}') + + + # if abs_max <= 1e-6: + # raise RuntimeError(f'Audio is silent {video_id}') + + # ensure the stereo audio + if audio_chunk.shape[0] < 2: + audio_chunk = audio_chunk.repeat(2, 1) + + # resample + if sample_rate == self.sample_rate: + audio_chunk = audio_chunk + else: + if sample_rate not in self.resampler: + # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best + self.resampler[sample_rate] = torchaudio.transforms.Resample( + sample_rate, + self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + audio_chunk = self.resampler[sample_rate](audio_chunk) + + if audio_chunk.shape[1] < self.expected_audio_length: + # zero-padding audio + padding_length = self.expected_audio_length - audio_chunk.shape[1] + # 创建 padding 张量,大小为 [batch_size, padding_length],值为0 + padding = torch.zeros(audio_chunk.shape[0], padding_length) + # 将原始音频和 padding 沿第 1 维度拼接在一起 + audio_chunk = torch.cat((audio_chunk, padding), dim=1) + # raise RuntimeError(f'Audio too short {video_id}') + audio_chunk = audio_chunk[:,:self.expected_audio_length] + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + # import ipdb + # ipdb.set_trace() + if clip_chunk.shape[0] != self.clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = self.clip_expected_length - current_length + + # Check that padding needed is no more than 2 + assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed' + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + log.info(last_frame.shape) + # Repeat the last frame to reach the expected length + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + # raise RuntimeError(f'CLIP video wrong length {video_id}, ' + # f'expected {self.clip_expected_length}, ' + # f'got {clip_chunk.shape[0]}') + # save_image(clip_chunk[0] / 255.0,'ori.png') + clip_chunk = self.clip_transform(clip_chunk) + # temp_img = clip_chunk[0].permute(1, 2, 0) * 255 + # save_image(clip_chunk[0],'scale.png') + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + # padding using the last frame, but no more than 2 + current_length = sync_chunk.shape[0] + last_frame = sync_chunk[-1] + # 重复最后一帧以进行填充 + padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1) + assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}' + sync_chunk = torch.cat((sync_chunk, padding), dim=0) + # raise RuntimeError(f'Sync video wrong length {video_id}, ' + # f'expected {self.sync_expected_length}, ' + # f'got {sync_chunk.shape[0]}') + + sync_chunk = self.sync_transform(sync_chunk) + assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \ + and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape' + data = { + 'id': video_id, + 'caption': label, + 'audio': audio_chunk, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = VGGSound( +# root="data/vggsound/video/test", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_latents_text/test" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/vggsound_224.py b/data_utils/v2a_utils/vggsound_224.py new file mode 100644 index 0000000000000000000000000000000000000000..ff971b2ba4b9ea3251eb41ad1d48a0e956bf924c --- /dev/null +++ b/data_utils/v2a_utils/vggsound_224.py @@ -0,0 +1,325 @@ +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import numpy as np + +import logging +log = logging.getLogger() + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + +def save_tensor_as_image(tensor, save_path): + """ + 将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。 + + :param tensor: 输入的 NumPy 数组 (1, 3, H, W)。 + :param save_path: 图片保存路径。 + """ + # # 移除批次维度,变成 (3, H, W) + # tensor = tensor.squeeze(0) + + # 交换轴顺序,变为 (H, W, 3) + image_array = np.transpose(tensor, (1, 2, 0)) + + # 检查数组是否为合适的数据类型 + if image_array.dtype != np.uint8: + # 如果不是 uint8,首先标准化,然后转换 + image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255 + image_array = image_array.astype(np.uint8) + + # 创建图像对象 + image = Image.fromarray(image_array) + + # 保存图片 + image.save(save_path) + print(f"Image saved to {save_path}") + +def pad_to_square(video_tensor): + # 验证输入的形状 + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (l, c, h, w)") + + l, c, h, w = video_tensor.shape + max_side = max(h, w) + + # 计算每一维度需要的填充量:(left, right, top, bottom) + pad_h = max_side - h + pad_w = max_side - w + + # 创建padding tuple (left, right, top, bottom) + # 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充 + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + + # 使用F.pad对视频张量进行填充操作 + # 填充参数为 (left, right, top, bottom) + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + + return video_padded + +class VGGSound(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + self.caption_cot = [] + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.labels = [] + self.videos = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.pth'): continue + label = record['caption'] + caption_cot = record['caption_cot'] + # if id in videos: + self.labels.append(label) + # self.labels[id] = label + self.videos.append(id) + self.caption_cot.append(caption_cot) + # else: + # missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Lambda(pad_to_square), # 先填充为正方形 + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b") + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[idx] + caption_cot = self.caption_cot[idx] + + reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + reader.add_basic_audio_stream(frames_per_chunk=2**30,) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + audio_chunk = data_chunk[2] + if len(audio_chunk.shape) != 2: + raise RuntimeError(f'error audio shape {video_id}') + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + # if clip_chunk.shape[0] < self.clip_expected_length: + # raise RuntimeError( + # f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}' + # ) + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + # if sync_chunk.shape[0] < self.sync_expected_length: + # raise RuntimeError( + # f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}' + # ) + # import ipdb + # ipdb.set_trace() + # process audio + # import ipdb + # ipdb.set_trace() + sample_rate = int(reader.get_out_stream_info(2).sample_rate) + audio_chunk = audio_chunk.transpose(0, 1) + abs_max = audio_chunk[0].abs().max() + # audio_chunk = audio_chunk.mean(dim=0) # mono + # if self.normalize_audio: + # abs_max = audio_chunk.abs().max() + # audio_chunk = audio_chunk / abs_max * 0.95 + if abs_max <= 1e-6: + if audio_chunk.shape[0] > 1 and audio_chunk[1].abs().max() > 1e-6: + audio_chunk = audio_chunk[1:2] + else: + raise RuntimeError(f'Audio is silent {video_id}') + + # ensure the stereo audio + if audio_chunk.shape[0] < 2: + audio_chunk = audio_chunk.repeat(2, 1) + + # resample + if sample_rate == self.sample_rate: + audio_chunk = audio_chunk + else: + if sample_rate not in self.resampler: + # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best + self.resampler[sample_rate] = torchaudio.transforms.Resample( + sample_rate, + self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + audio_chunk = self.resampler[sample_rate](audio_chunk) + + if audio_chunk.shape[1] < self.expected_audio_length: + # zero-padding audio + padding_length = self.expected_audio_length - audio_chunk.shape[1] + # 创建 padding 张量,大小为 [batch_size, padding_length],值为0 + padding = torch.zeros(audio_chunk.shape[0], padding_length) + # 将原始音频和 padding 沿第 1 维度拼接在一起 + audio_chunk = torch.cat((audio_chunk, padding), dim=1) + # raise RuntimeError(f'Audio too short {video_id}') + audio_chunk = audio_chunk[:,:self.expected_audio_length] + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + # import ipdb + # ipdb.set_trace() + if clip_chunk.shape[0] != self.clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = self.clip_expected_length - current_length + + # Check that padding needed is no more than 2 + assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed' + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + log.info(last_frame.shape) + # Repeat the last frame to reach the expected length + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + # raise RuntimeError(f'CLIP video wrong length {video_id}, ' + # f'expected {self.clip_expected_length}, ' + # f'got {clip_chunk.shape[0]}') + + # save_image(clip_chunk[0] / 255.0,'ori.png') + clip_chunk = pad_to_square(clip_chunk) + # save_image(clip_chunk[0] / 255.0,'square.png') + # clip_chunk = self.clip_transform(clip_chunk) + # import ipdb + # ipdb.set_trace() + clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"] + # log.info(clip_chunk.shape) + # save_tensor_as_image(clip_chunk[0].numpy(),'scale.png') + # log.info(clip_chunk[0]) + # clip_chunk = outputs + # text_ids = outputs["input_ids"] + # temp_img = clip_chunk[0].permute(1, 2, 0) * 255 + # save_image(clip_chunk[0],'scale.png') + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + # padding using the last frame, but no more than 2 + current_length = sync_chunk.shape[0] + last_frame = sync_chunk[-1] + # 重复最后一帧以进行填充 + padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1) + assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}' + sync_chunk = torch.cat((sync_chunk, padding), dim=0) + # raise RuntimeError(f'Sync video wrong length {video_id}, ' + # f'expected {self.sync_expected_length}, ' + # f'got {sync_chunk.shape[0]}') + + sync_chunk = self.sync_transform(sync_chunk) + assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \ + and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape' + data = { + 'id': video_id, + 'caption': label, + 'caption_cot': caption_cot, + 'audio': audio_chunk, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = VGGSound( +# root="data/vggsound/video/train", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_224_latents_text/train" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/vggsound_224_no_audio.py b/data_utils/v2a_utils/vggsound_224_no_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..51b6ad69f3afca5d10764cc9eebc0daf848f810a --- /dev/null +++ b/data_utils/v2a_utils/vggsound_224_no_audio.py @@ -0,0 +1,275 @@ +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import numpy as np + +import logging +log = logging.getLogger() + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + +def save_tensor_as_image(tensor, save_path): + """ + 将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。 + + :param tensor: 输入的 NumPy 数组 (1, 3, H, W)。 + :param save_path: 图片保存路径。 + """ + # # 移除批次维度,变成 (3, H, W) + # tensor = tensor.squeeze(0) + + # 交换轴顺序,变为 (H, W, 3) + image_array = np.transpose(tensor, (1, 2, 0)) + + # 检查数组是否为合适的数据类型 + if image_array.dtype != np.uint8: + # 如果不是 uint8,首先标准化,然后转换 + image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255 + image_array = image_array.astype(np.uint8) + + # 创建图像对象 + image = Image.fromarray(image_array) + + # 保存图片 + image.save(save_path) + print(f"Image saved to {save_path}") + +def pad_to_square(video_tensor): + # 验证输入的形状 + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (l, c, h, w)") + + l, c, h, w = video_tensor.shape + max_side = max(h, w) + + # 计算每一维度需要的填充量:(left, right, top, bottom) + pad_h = max_side - h + pad_w = max_side - w + + # 创建padding tuple (left, right, top, bottom) + # 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充 + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + + # 使用F.pad对视频张量进行填充操作 + # 填充参数为 (left, right, top, bottom) + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + + return video_padded + +class VGGSound(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.labels = [] + self.videos = [] + self.caption_cot = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.pth'): continue + label = record['caption'] + caption_cot = record['caption_cot'] + # if id in videos: + self.labels.append(label) + # self.labels[id] = label + self.videos.append(id) + self.caption_cot.append(caption_cot) + # else: + # missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Lambda(pad_to_square), # 先填充为正方形 + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b") + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[idx] + caption_cot = self.caption_cot[idx] + + reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + # reader.add_basic_audio_stream(frames_per_chunk=2**30,) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + # audio_chunk = data_chunk[2] + # if len(audio_chunk.shape) != 2: + # raise RuntimeError(f'error audio shape {video_id}') + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + # if clip_chunk.shape[0] < self.clip_expected_length: + # raise RuntimeError( + # f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}' + # ) + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + # import ipdb + # ipdb.set_trace() + if clip_chunk.shape[0] != self.clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = self.clip_expected_length - current_length + + # Check that padding needed is no more than 2 + # assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed' + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + log.info(last_frame.shape) + # Repeat the last frame to reach the expected length + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + # raise RuntimeError(f'CLIP video wrong length {video_id}, ' + # f'expected {self.clip_expected_length}, ' + # f'got {clip_chunk.shape[0]}') + + # save_image(clip_chunk[0] / 255.0,'ori.png') + clip_chunk = pad_to_square(clip_chunk) + # save_image(clip_chunk[0] / 255.0,'square.png') + # clip_chunk = self.clip_transform(clip_chunk) + # import ipdb + # ipdb.set_trace() + clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"] + # log.info(clip_chunk.shape) + # save_tensor_as_image(clip_chunk[0].numpy(),'scale.png') + # log.info(clip_chunk[0]) + # clip_chunk = outputs + # text_ids = outputs["input_ids"] + # temp_img = clip_chunk[0].permute(1, 2, 0) * 255 + # save_image(clip_chunk[0],'scale.png') + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + # padding using the last frame, but no more than 2 + current_length = sync_chunk.shape[0] + last_frame = sync_chunk[-1] + # 重复最后一帧以进行填充 + padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1) + # assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}' + sync_chunk = torch.cat((sync_chunk, padding), dim=0) + # raise RuntimeError(f'Sync video wrong length {video_id}, ' + # f'expected {self.sync_expected_length}, ' + # f'got {sync_chunk.shape[0]}') + + sync_chunk = self.sync_transform(sync_chunk) + # assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \ + # and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape' + data = { + 'id': video_id, + 'caption': label, + # 'audio': audio_chunk, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + 'caption_cot': caption_cot, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = VGGSound( +# root="data/vggsound/video/train", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_224_latents_text/train" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/vggsound_224_no_sync.py b/data_utils/v2a_utils/vggsound_224_no_sync.py new file mode 100644 index 0000000000000000000000000000000000000000..87b301e8e1f4e09ced92a01614c78b93435cc327 --- /dev/null +++ b/data_utils/v2a_utils/vggsound_224_no_sync.py @@ -0,0 +1,223 @@ +import os +from pathlib import Path +from typing import Optional, Union +from PIL import Image + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import numpy as np + +import logging +log = logging.getLogger() + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + +def save_tensor_as_image(tensor, save_path): + """ + 将形状为 (1, 3, H, W) 的 RGB 图像数组保存为图片文件。 + + :param tensor: 输入的 NumPy 数组 (1, 3, H, W)。 + :param save_path: 图片保存路径。 + """ + # # 移除批次维度,变成 (3, H, W) + # tensor = tensor.squeeze(0) + + # 交换轴顺序,变为 (H, W, 3) + image_array = np.transpose(tensor, (1, 2, 0)) + + # 检查数组是否为合适的数据类型 + if image_array.dtype != np.uint8: + # 如果不是 uint8,首先标准化,然后转换 + image_array = (image_array - image_array.min()) / (image_array.max() - image_array.min()) * 255 + image_array = image_array.astype(np.uint8) + + # 创建图像对象 + image = Image.fromarray(image_array) + + # 保存图片 + image.save(save_path) + print(f"Image saved to {save_path}") + +def pad_to_square(video_tensor): + # 验证输入的形状 + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (l, c, h, w)") + + l, c, h, w = video_tensor.shape + max_side = max(h, w) + + # 计算每一维度需要的填充量:(left, right, top, bottom) + pad_h = max_side - h + pad_w = max_side - w + + # 创建padding tuple (left, right, top, bottom) + # 因为图像的填充是作用在最后两个维度 h 和 w 上,所以我们需要指定这两个维度的填充 + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + + # 使用F.pad对视频张量进行填充操作 + # 填充参数为 (left, right, top, bottom) + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + + return video_padded + +class VGGSound(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.labels = [] + self.videos = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + if os.path.exists(f'{save_dir}/{id}.pth'): continue + label = record['label'] + # if id in videos: + self.labels.append(label) + # self.labels[id] = label + self.videos.append(id) + # else: + # missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Lambda(pad_to_square), # 先填充为正方形 + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge") + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[idx] + + reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + + + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + # import ipdb + # ipdb.set_trace() + if clip_chunk.shape[0] != self.clip_expected_length: + current_length = clip_chunk.shape[0] + padding_needed = self.clip_expected_length - current_length + + # Check that padding needed is no more than 2 + assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed' + + # If assertion passes, proceed with padding + if padding_needed > 0: + last_frame = clip_chunk[-1] + log.info(last_frame.shape) + # Repeat the last frame to reach the expected length + padding = last_frame.repeat(padding_needed, 1, 1, 1) + clip_chunk = torch.cat((clip_chunk, padding), dim=0) + # raise RuntimeError(f'CLIP video wrong length {video_id}, ' + # f'expected {self.clip_expected_length}, ' + # f'got {clip_chunk.shape[0]}') + + # save_image(clip_chunk[0] / 255.0,'ori.png') + clip_chunk = pad_to_square(clip_chunk) + # save_image(clip_chunk[0] / 255.0,'square.png') + # clip_chunk = self.clip_transform(clip_chunk) + # import ipdb + # ipdb.set_trace() + clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"] + + data = { + 'id': video_id, + 'caption': label, + 'clip_video': clip_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = VGGSound( +# root="data/vggsound/video/train", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_224_latents_text/train" +# ) +# dataset[0] \ No newline at end of file diff --git a/data_utils/v2a_utils/vggsound_text.py b/data_utils/v2a_utils/vggsound_text.py new file mode 100644 index 0000000000000000000000000000000000000000..8f097f19136b3084df1434d4f2b2c6e5ddec88bf --- /dev/null +++ b/data_utils/v2a_utils/vggsound_text.py @@ -0,0 +1,109 @@ +import logging +import os +from pathlib import Path +from typing import Optional, Union + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image + +log = logging.getLogger() + +_CLIP_SIZE = 384 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + + +class VGGSound(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'dataset/vggsound/split_txt/train_caption.csv', + start_row: Optional[int] = None, + end_row: Optional[int] = None, + save_dir: str = 'data/vggsound/video_latents_text/train' + ): + self.root = Path(root) + + # videos = sorted(os.listdir(self.root)) + # videos = set([Path(v).stem for v in videos]) # remove extensions + videos = [] + self.labels = [] + self.cots = [] + self.videos = [] + missing_videos = [] + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep=',', dtype={'id': str}).to_dict('records') + + # 控制处理的行范围 + if start_row is not None and end_row is not None: + df_list = df_list[start_row:end_row] + + for record in df_list: + id = record['id'] + # if os.path.exists(f'{save_dir}/{id}.pth'): + # continue + # try: + # torch.load(f'{save_dir}/{id}.pth') + # continue + # except: + # print(f'error load file: {save_dir}/{id}.pth') + # os.system(f'rm -f {save_dir}/{id}.pth') + label = record['caption'] + # if id in videos: + self.labels.append(label) + self.cots.append(record['caption_cot']) + # self.labels[id] = label + self.videos.append(id) + # else: + # missing_videos.append(id) + + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + + + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[idx] + cot = self.cots[idx] + data = { + 'id': video_id, + 'caption': label, + 'caption_cot': cot + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) + + +# dataset = VGGSound( +# root="data/vggsound/video/test", +# tsv_path="data/vggsound/split_txt/temp.csv", +# sample_rate=44100, +# duration_sec=9.0, +# audio_samples=397312, +# start_row=0, +# end_row=None, +# save_dir="data/vggsound/video_latents_text/test" +# ) +# dataset[0] \ No newline at end of file diff --git a/defaults.ini b/defaults.ini new file mode 100644 index 0000000000000000000000000000000000000000..46d5673cc164cf9ca46ed3c12bb4921a3c5a232b --- /dev/null +++ b/defaults.ini @@ -0,0 +1,71 @@ + +[DEFAULTS] + +#name of the run +name = ThinkSound + +# the batch size +batch_size = 1 +test_batch_size = 1 + +# predict ckpt directory +ckpt_dir = "ckpts/thinksound_light.ckpt" + +# number of GPUs to use for training +num_gpus = 1 + +# number of nodes to use for training +num_nodes = 1 + +# Multi-GPU strategy for PyTorch Lightning +strategy = "" + +# Precision to use for training +precision = "bf16-mixed" + +# number of CPU workers for the DataLoader +num_workers = 32 + +# the random seed +seed = 42 + +# Batches for gradient accumulation +accum_batches = 1 + +# Number of steps between checkpoints +checkpoint_every = 2000 + +# trainer checkpoint file to restart training from +ckpt_path = '' + +# model checkpoint file to start a new training run from +pretrained_ckpt_path = '' + +# Checkpoint path for the pretransform model if needed +pretransform_ckpt_path = 'ckpts/vae.ckpt' + +# configuration model specifying model hyperparameters +model_config = '' + +# configuration for datasets +dataset_config = '' + +# directory to save the checkpoints in +save_dir = '' + +# gradient_clip_val passed into PyTorch Lightning Trainer +gradient_clip_val = 0.0 + +# remove the weight norm from the pretransform model +remove_pretransform_weight_norm = '' + +compile = False + +repeat_num = 5 + +duration_sec = '9' + +results_dir = 'results' + + + diff --git a/docs/PrismAudio/Dataset.md b/docs/PrismAudio/Dataset.md new file mode 100644 index 0000000000000000000000000000000000000000..5821829685fc0950b0b4d49e69263eaaa4a27140 --- /dev/null +++ b/docs/PrismAudio/Dataset.md @@ -0,0 +1,104 @@ +# Dataset Preparation Guide + +This guide provides step-by-step instructions for preparing datasets to train models in this repository. + +## 0. Pre-requisites + +Ensure the following checkpoint files exist in the `ckpts/` directory before continuing: + +* `ckpts/vae.ckpt` +* `ckpts/synchformer_state_dict.pth` + +## 1. Preparing Video-Text Datasets + +To convert raw videos and CoT annotations into training features, use the following command: + +```bash +torchrun --nproc_per_node=8 data_utils/extract_training_video.py \ + --root \ + --tsv_path \ + --save-dir \ + --add_video_path \ + --add_audio_path +``` + +### Arguments + +* `--root `: Path to the root directory containing all `.mp4` videos to be processed. +* `--tsv_path `: Path to the TSV/CSV file containing `id` and `caption_cot` columns. +* `--save-dir `: Directory where extracted feature `.npz` files will be saved. +* `--add_video_path ` *(optional)*: Reference video directory required for enabling the **Synchformer reward** in GRPO. Typically the same as `--root`. +* `--add_audio_path ` *(optional)*: Reference audio directory required for enabling the **ITD reward** in GRPO. + +> **Note:** `--add_video_path` and `--add_audio_path` are optional. Only provide them if you intend to use the corresponding reward functions during GRPO training. + +--- + +## 2. Organizing Feature Files + +After extraction, create a `.txt` file listing all generated feature file names (one per line), for example: + +``` +item1.npz +item2.npz +item3.npz +... +``` + +This file acts as the dataset split index and will be referenced in the dataset configuration. + +--- + +## 3. Creating the Dataset Configuration JSON + +Create a JSON file following the structure below (adapted from `ThinkSound/configs/multimodal_dataset_demo_prismaudio.json`): + +```json +{ + "dataset_type": "video_dataset", + "datasets": [ + { + "id": "your_dataset_id", + "path": "path_to_feature_dir", + "split_path": "path_to_train_split_txt" + } + ], + "val_datasets": [ + { + "id": "your_val_dataset_id", + "path": "path_to_val_feature_dir", + "split_path": "path_to_val_split_txt" + } + ], + "test_datasets": [ + { + "id": "your_test_dataset_id", + "path": "path_to_test_feature_dir", + "split_path": "path_to_test_split_txt" + } + ], + "random_crop": false, + "input_type": "video", + "fps": 8 +} +``` + +### Field Descriptions + +| Field | Description | +|-------|-------------| +| `dataset_type` | Fixed as `"video_dataset"` | +| `datasets` | List of training feature directories with their split `.txt` files | +| `val_datasets` | Validation set, same structure as `datasets` | +| `test_datasets` | Test set, same structure as `datasets` | +| `random_crop` | Whether to apply random cropping, typically `false` | +| `input_type` | Fixed as `"video"` | + + +You can include multiple datasets under `datasets` by appending additional dictionary blocks to the list. + +--- + +## 4. Proceed to Training + +Refer to [`Training.md`](./Training.md) for detailed training instructions once the dataset configuration is complete. \ No newline at end of file diff --git a/docs/PrismAudio/Training.md b/docs/PrismAudio/Training.md new file mode 100644 index 0000000000000000000000000000000000000000..ce1b99d00cb9354fc53afac005708c96047ac2ff --- /dev/null +++ b/docs/PrismAudio/Training.md @@ -0,0 +1,71 @@ +# PrismAudio Training Guide + +This guide walks you through data preparation, configuration, and launching GRPO training for PrismAudio. We recommend reading through all steps before starting. + +--- + +## Step 1: Prepare the Dataset + +Before training, you must preprocess the dataset following the instructions in [Dataset.md](./Dataset.md). This includes: + +1. Converting raw videos and CoT annotations into structured feature `.npz` files. +2. Constructing a valid dataset configuration JSON that points to all precomputed features. + +Make sure your extracted dataset includes all required modalities and is organized correctly. + +--- + +## Step 2: Configure Training Script + +Open `grpo/config/grpo.py` and select or customize the appropriate config function: + +- `general_thinksound_8gpus()` — for training **PrismAudio** + + +The key parameters to modify are: + +| Parameter | Description | +|-----------|-------------| +| `model_config` | Path to model architecture config, e.g. `ThinkSound/configs/model_configs/prismaudio.json` | +| `pretransform_ckpt_path` | Path to VAE checkpoint, e.g. `ckpts/vae.ckpt` | +| `ref_model` | Path to reference model checkpoint, e.g. `ckpts/prismaudio.ckpt` | +| `ckpt_dir` | Path to the checkpoint to fine-tune from | +| `dataset_config` | Path to your dataset configuration JSON prepared in Step 1 | + +Reward function weights can be adjusted under `reward_fn`: + +| Reward | Key | Default Weight | Description | +|--------|-----|---------------|-------------| +| **Temporal** | `synch_reward` | `0.6` | Audio-visual synchronization via Synchformer | +| **Semantic** | `ms_clap` | `1.0` | Audio-text alignment via MS-CLAP | +| **Spatial** | `itd_reward` | `0.4` | Inter-track distance reward using reference audio | +| **Aesthetic** | `meta_reward` | `0.1` | Aesthetic quality via Meta Audiobox Aesthetics | + +GRPO-specific parameters: + +| Parameter | Default | Description | +|-----------|---------|-------------| +| `train.beta` | `0.04` | KL penalty weight to mitigate reward hacking | +| `sample.num_audio_per_prompt` | `16` | Number of audio candidates per prompt for group-relative advantage computation | +| `sample.num_steps` | `24` | Number of denoising steps during training sampling | +| `sample.train_num_steps` | `2` | Number of SDE steps in Fast-GRPO hybrid sampling | +| `train.learning_rate` | `1e-5` | Learning rate for GRPO optimization | +| `train.ema` | `True` | Whether to use exponential moving average | + + + + +## Step 5: Launch Training + +Make the script executable (if not already) and start training: + +```bash +chmod +x scripts/PrismAudio/grpo_1node8gpus.sh +./scripts/PrismAudio/grpo_1node8gpus.sh +``` + + +--- + +Happy training! 🚀 +If you run into any issues, consider opening an issue or checking the documentation for detailed help. diff --git a/docs/ThinkSound/Dataset.md b/docs/ThinkSound/Dataset.md new file mode 100644 index 0000000000000000000000000000000000000000..e41d8a5f49e333150b6e8827b08c3fd82f7397d8 --- /dev/null +++ b/docs/ThinkSound/Dataset.md @@ -0,0 +1,100 @@ +# Dataset Preparation Guide + +This guide provides step-by-step instructions for preparing datasets to train models in this repository. + +## 0. Pre-requisites + +Ensure the following checkpoint files exist in the `ckpts/` directory before continuing: + +* `ckpts/vae.ckpt` +* `ckpts/synchformer_state_dict.pth` + +## 1. Preparing Video-Text Datasets + +To convert raw videos and CoT into training features, use the following command: + +```bash +torchrun --nproc_per_node=8 data_utils/extract_training_video.py \ + --root \ + --tsv_path \ + --save-dir \ + --duration_sec \ + --audio_samples duration_sec*44100 +``` + +* ``: Path to the root directory containing all .mp4 videos to be processed (all videos must be of equal duration). +* ``: Path to the TSV/CSV file that lists video-text pairs.(see `demo_test.csv` for format). +* ``: Directory where extracted video features will be saved. +* ``: Duration to which all videos will be uniformly trimmed or padded. + +## 2. Preparing Audio-Text Datasets + +You can also include audio-text pairs for training. Use the following command to extract features: + +```bash +torchrun --nproc_per_node=8 data_utils/extract_training_audio.py \ + --root \ + --tsv_path \ + --save-dir \ + --duration_sec \ + --audio_samples duration_sec*44100 +``` + +* ``: Path to the raw audio files. +* ``: Path to the TSV/CSV file that lists audio-text pairs. +* ``: Directory where extracted audio features will be saved. +* ``: Duration to which all audios will be uniformly trimmed or padded. +* Note that the audio input for feature extraction must be trimmed to match the duration of the video-text datasets. + +## 3. Organizing Feature Files + +For each dataset (video or audio), create a `.txt` file listing all feature file names (one per line), for example: + +``` +item1.pth +item2.pth +item3.pth +... +``` + +This file acts as the training split and will be referenced in the dataset config. + +## 4. Creating the Dataset Configuration JSON + +Create a JSON file following the structure below (adapted from `ThinkSound/configs/multimodal_dataset_demo.json`): + +```json +{ + "dataset_type": "multimodal_dir", + "video_datasets": [ + { + "id": "video_dataset_id", + "path": "path_to_video_feature_dir", + "split_path": "path_to_video_split_txt" + } + ], + "audio_datasets": [ + { + "id": "audio_dataset_id", + "path": "path_to_audio_feature_dir", + "split_path": "path_to_audio_split_txt" + } + ], + "val_datasets": [ + { + "id": "val_dataset_id", + "path": "path_to_val_feature_dir", + "split_path": "path_to_val_split_txt" + } + ], + "random_crop": true, + "input_type": "prompt" +} +``` + +You can include multiple datasets under `video_datasets` and `audio_datasets` by appending additional dictionary blocks to each list. The `val_datasets` is encouraged and must be a video-text dataset. + +## 5. Proceed to Training + +Refer to [`docs/ThinkSound/Training.md`](./Training.md) for detailed training instructions once the dataset configuration is complete. + diff --git a/docs/ThinkSound/Training.md b/docs/ThinkSound/Training.md new file mode 100644 index 0000000000000000000000000000000000000000..69f5d6ee4f40abf14d157c5afe9560f690fe405a --- /dev/null +++ b/docs/ThinkSound/Training.md @@ -0,0 +1,84 @@ +# Training Guide + +This guide will walk you through the process of preparing data, configuring your training setup, and launching GRPO training for the ThinkSound model. For best results, we recommend reading through all steps before starting. + +--- + +## Step 1: Prepare the Dataset + +Before training, you must preprocess the dataset following the instructions in [Dataset.md](./Dataset.md). This includes: + +1. Converting raw videos and CoT annotations into structured feature `.npz` files. +2. Constructing a valid dataset configuration JSON that points to all precomputed features. + +Make sure your extracted dataset includes all required modalities and is organized correctly. + +--- + +## Step 2: Configure Training Script + +Open `scripts/PrismAudio/grpo_1node8gpus.sh` and customize the following items: + +Under the `grpo/config` section, set the paths to your model and configuration files: + +* `model_config`: Path to the model architecture config (e.g., `ThinkSound/configs/model_configs/prismaudio.json`) +* `pretransform_ckpt_path`: Path to the pretrained model checkpoint (e.g., `ckpts/prismaudio.ckpt`) +* `dataset_config`: Path to your dataset configuration JSON prepared in Step 1 + +Also modify distributed training settings as needed: + +* `num_gpus`, `num_nodes`, `node_rank`, `MASTER_PORT`, etc. + +* (Optional) Enable debug mode by adding the `--debug` flag when running the script. + +### 🔍 Tip + +If you're using a multi-GPU setup, ensure the `WORLD_SIZE`, `NODE_RANK`, and `MASTER_PORT` are correctly set for your environment. These are critical for DistributedDataParallel (DDP) training. + +--- + +## Step 3: Configure Reward Functions *(Optional)* + +ThinkSound supports two optional reward functions during GRPO training. To enable them, provide the corresponding reference paths when extracting features (see [Dataset.md](./Dataset.md)): + +| Reward | Required Argument | Description | +|--------|------------------|-------------| +| **Synchformer** | `--add_video_path` | Enables audio-visual synchronization reward | +| **ITD** | `--add_audio_path` | Enables inter-track distance reward using reference audio | + +These paths are embedded into the extracted `.pth` feature files during dataset preparation and will be automatically used during GRPO training if present. + +--- + +## Step 4: Launch Training + +Make the script executable (if not already) and start training: + +```bash +chmod +x scripts/PrismAudio/grpo_1node8gpus.sh +./scripts/PrismAudio/grpo_1node8gpus.sh +``` + +Logs will be written to the specified log directory (`log_dir`). + +--- + +## Step 5: Customize Model and Training Parameters + +To modify model architecture or training strategy, open the model config file specified in `grpo/config`. +You can adjust a wide range of parameters, such as: + +* Number of model parameters +* Optimizer type +* Learning rate +* Latent dimension +* GRPO-specific reward weights + +Be sure to keep a backup of your config for reproducibility. + +--- + + + +Happy training! 🚀 +If you run into any issues, consider opening an issue or checking the documentation for detailed help. \ No newline at end of file diff --git a/predict.py b/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..44fbe7276fd289c69f403bab1270adbdad0d1443 --- /dev/null +++ b/predict.py @@ -0,0 +1,173 @@ +from prefigure.prefigure import get_all_args, push_wandb_config +import json +import os +import re +import torch +import torchaudio +from lightning.pytorch import seed_everything +import random +from datetime import datetime +import numpy as np +from ThinkSound.models import create_model_from_config +from ThinkSound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model +from ThinkSound.inference.sampling import sample, sample_discrete_euler +from pathlib import Path + + + +def predict_step(diffusion, batch, diffusion_objective, device='cuda:0'): + diffusion = diffusion.to(device) + + reals, metadata = batch + ids = [item['id'] for item in metadata] + batch_size, length = reals.shape[0], reals.shape[2] + print(f"Predicting {batch_size} samples with length {length} for ids: {ids}") + with torch.amp.autocast('cuda'): + conditioning = diffusion.conditioner(metadata, device) + + video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0) + if 'metaclip_features' in conditioning: + conditioning['metaclip_features'][~video_exist] = diffusion.model.model.empty_clip_feat + + if 'sync_features' in conditioning: + conditioning['sync_features'][~video_exist] = diffusion.model.model.empty_sync_feat + + + cond_inputs = diffusion.get_conditioning_inputs(conditioning) + if batch_size > 1: + noise_list = [] + for _ in range(batch_size): + noise_1 = torch.randn([1, diffusion.io_channels, length]).to(device) # 每次生成推进RNG状态 + noise_list.append(noise_1) + noise = torch.cat(noise_list, dim=0) + else: + noise = torch.randn([batch_size, diffusion.io_channels, length]).to(device) + + with torch.amp.autocast('cuda'): + + model = diffusion.model + if diffusion_objective == "v": + fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True) + elif diffusion_objective == "rectified_flow": + import time + start_time = time.time() + fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True) + end_time = time.time() + execution_time = end_time - start_time + print(f"执行时间: {execution_time:.2f} 秒") + if diffusion.pretransform is not None: + fakes = diffusion.pretransform.decode(fakes) + + audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + return audios + + + +def load_file(filename, info, latent_length): + # try: + npz_file = filename + if os.path.exists(npz_file): + # print(filename) + npz_data = np.load(npz_file,allow_pickle=True) + data = {key: npz_data[key] for key in npz_data.files} + # print("data.keys()",data.keys()) + for key in data.keys(): + if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number): + data[key] = torch.from_numpy(data[key]) + else: + raise ValueError(f'error load file: {filename}') + info.update(data) + audio = torch.zeros((1, 64, latent_length), dtype=torch.float32) + info['video_exist'] = torch.tensor(True) + # except: + # print(f'error load file: {filename}') + return audio, info + +def load(filename,duration): + assert os.path.exists(filename) + info = {} + audio, info = load_file(filename, info, round(44100/64/32*duration)) + info["path"] = filename + + info['id'] = Path(filename).stem + info["relpath"] = 'demo.npz' + + return (audio, info) + +def main(): + + args = get_all_args() + + if (args.save_dir == ''): + args.save_dir=args.results_dir + + + seed = args.seed + + # Set a different seed for each process if using SLURM + if os.environ.get("SLURM_PROCID") is not None: + seed += int(os.environ.get("SLURM_PROCID")) + + # random.seed(seed) + # torch.manual_seed(seed) + seed_everything(seed, workers=True) + + #Get JSON config from args.model_config + if args.model_config == '': + args.model_config = "ThinkSound/configs/model_configs/thinksound.json" + with open(args.model_config) as f: + model_config = json.load(f) + + + duration=(float)(args.duration_sec) + + model_config["sample_size"] = duration * model_config["sample_rate"] + if "sync_seq_len" in model_config["model"]["diffusion"]["config"]: + model_config["model"]["diffusion"]["config"]["sync_seq_len"] = 24 * int(duration) + + if "clip_seq_len" in model_config["model"]["diffusion"]["config"]: + model_config["model"]["diffusion"]["config"]["clip_seq_len"] = 8 * int(duration) + + if "latent_seq_len" in model_config["model"]["diffusion"]["config"]: + model_config["model"]["diffusion"]["config"]["latent_seq_len"] = round(44100 / 64 / 32 * duration) + + + model = create_model_from_config(model_config) + + ## speed by torch.compile + if args.compile: + model = torch.compile(model) + + + model.load_state_dict(torch.load(args.ckpt_dir)) + + + load_vae_state = load_ckpt_state_dict(args.pretransform_ckpt_path, prefix='autoencoder.') + model.pretransform.load_state_dict(load_vae_state) + + audio,meta=load(os.path.join(args.results_dir, "demo.npz") , duration) + + for k, v in meta.items(): + if isinstance(v, torch.Tensor): + meta[k] = v.to('cuda:0') + + audio=predict_step(model, + batch=[audio,(meta,)], + diffusion_objective=model_config["model"]["diffusion"]["diffusion_objective"], + device='cuda:0' + ) + + current_date = datetime.now() + formatted_date = current_date.strftime('%m%d') + + audio_dir = os.path.join(args.save_dir,f'{formatted_date}_batch_size'+str(args.test_batch_size)) + os.makedirs(audio_dir,exist_ok=True) + torchaudio.save(os.path.join(audio_dir,"demo.wav"), audio[0], 44100) + + + #trainer.predict(training_wrapper, dm, return_predictions=False) + + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/scripts/PrismAudio/demo.sh b/scripts/PrismAudio/demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..9465ddfeae45651b37d67def2def6aad6446091c --- /dev/null +++ b/scripts/PrismAudio/demo.sh @@ -0,0 +1,97 @@ + +PYTHON=$(which python) +TORCHRUN=$(which torchrun) +export CUDA_LAUNCH_BLOCKING=1 +export TF_GPU_ALLOCATOR=cuda_malloc_async +export PATH=$CUDA_PATH/bin:$PATH +export LD_LIBRARY_PATH=$CUDA_PATH/lib64:$LD_LIBRARY_PATH +export TF_CPP_MIN_LOG_LEVEL=2 + +# Check number of arguments +if [ "$#" -lt 2 ]; then + echo "Usage: $0 " + exit 1 +fi + +VIDEO_PATH="$1" +DESCRIPTION="$2" + + +model_config="ThinkSound/configs/model_configs/prismaudio.json" +ckpt_dir="ckpts/prismaudio.ckpt" +# Generate unique ID +UNIQUE_ID=$(uuidgen | cut -c 1-8) + +# Create necessary directories +mkdir -p videos cot_coarse results + +# Get video filename and extension +VIDEO_FILE=$(basename "$VIDEO_PATH") +VIDEO_EXT="${VIDEO_FILE##*.}" +VIDEO_ID="${VIDEO_FILE%.*}" +TEMP_VIDEO_PATH="videos/demo.mp4" + +# Convert video to MP4 format if needed +if [ "${VIDEO_EXT,,}" != "mp4" ]; then + echo "⏳ Converting video to MP4 format..." + ffmpeg -y -i "$VIDEO_PATH" -c:v libx264 -preset fast -c:a aac -strict experimental "$TEMP_VIDEO_PATH" >/dev/null 2>&1 + if [ $? -ne 0 ]; then + echo "❌ Video conversion failed" + exit 2 + fi +else + cp "$VIDEO_PATH" "$TEMP_VIDEO_PATH" +fi + +# Calculate video duration +DURATION=$(ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 "$TEMP_VIDEO_PATH") +DURATION_SEC=${DURATION%.*} +echo "Duration is: $DURATION_SEC" + +# Create cot.csv file +CAPTION_COT=$(echo "$DESCRIPTION" | tr '"' "'") +CSV_PATH="cot_coarse/cot.csv" +echo "id,caption_cot" > "$CSV_PATH" +echo "demo,\"$CAPTION_COT\"" >> "$CSV_PATH" + +# Run feature extraction +echo "⏳ Extracting features..." +$TORCHRUN --nproc_per_node=1 data_utils/prismaudio_data_process.py --inference_mode True + +if [ $? -ne 0 ]; then + echo "❌ Feature extraction failed" + rm -f "$TEMP_VIDEO_PATH" + exit 3 +fi + +# Run inference +echo "⏳ Running model inference..." +$PYTHON predict.py \ + --model-config "$model_config" \ + --duration-sec "$DURATION" \ + --ckpt-dir "$ckpt_dir" \ + --results-dir "results"\ + +if [ $? -ne 0 ]; then + echo "❌ Inference failed" + rm -f "$TEMP_VIDEO_PATH" + exit 4 +fi + +# Get generated audio file +CURRENT_DATE=$(date +"%m%d") +AUDIO_PATH="results/${CURRENT_DATE}_batch_size1/demo.wav" + +# Check if audio file exists +if [ ! -f "$AUDIO_PATH" ]; then + echo "❌ Generated audio file not found" + rm -f "$TEMP_VIDEO_PATH" + exit 5 +fi + +# Clean up temporary video file +rm -f "$TEMP_VIDEO_PATH" + + +echo "✅ Audio generated successfully!" +echo "Audio file path: $AUDIO_PATH" \ No newline at end of file diff --git a/scripts/PrismAudio/grpo_1node4gpus.sh b/scripts/PrismAudio/grpo_1node4gpus.sh new file mode 100644 index 0000000000000000000000000000000000000000..e360e3c449bd3c413bd04ad54d411378731a282d --- /dev/null +++ b/scripts/PrismAudio/grpo_1node4gpus.sh @@ -0,0 +1,16 @@ + +export NCCL_IB_DISABLE=1 +export NCCL_IB_HCA=mlx5 +export NCCL_DEBUG=WARN +export NCCL_IB_GID_INDEX=3 +export HF_ENDPOINT=https://hf-mirror.com + +MASTER_PORT=29501 + +# Launch command (parameters automatically read from accelerate_multi_node.yaml) +PYTHONPATH=. accelerate launch \ + --config_file grpo/scripts/accelerate_configs/multi_gpu.yaml \ + --num_machines 1 --num_processes 4 \ + --main_process_port ${MASTER_PORT} \ + grpo/scripts/train_audio_fast.py \ + --config grpo/config/grpo.py:general_thinksound_4gpus diff --git a/scripts/PrismAudio/grpo_1node8gpus.sh b/scripts/PrismAudio/grpo_1node8gpus.sh new file mode 100644 index 0000000000000000000000000000000000000000..d7644705f385d601806d15385172d303edca1a09 --- /dev/null +++ b/scripts/PrismAudio/grpo_1node8gpus.sh @@ -0,0 +1,16 @@ + +export NCCL_IB_DISABLE=1 +export NCCL_IB_HCA=mlx5 +export NCCL_DEBUG=WARN +export NCCL_IB_GID_INDEX=3 +export HF_ENDPOINT=https://hf-mirror.com + +MASTER_PORT=29501 + +# Launch command (parameters automatically read from accelerate_multi_node.yaml) +PYTHONPATH=. accelerate launch \ + --config_file grpo/scripts/accelerate_configs/multi_gpu.yaml \ + --num_machines 1 --num_processes 8 \ + --main_process_port ${MASTER_PORT} \ + grpo/scripts/train_audio_fast.py \ + --config grpo/config/grpo.py:general_thinksound_8gpus diff --git a/scripts/PrismAudio/infer.sh b/scripts/PrismAudio/infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..0361f1b16281079820d1e02c5def2e6a889c777f --- /dev/null +++ b/scripts/PrismAudio/infer.sh @@ -0,0 +1,75 @@ + +# 变量定义 +ckpt_dir="ckpts/prismaudio.ckpt" +test_batch_size=1 +dataset_config="ThinkSound/configs/multimodal_dataset_demo_prismaudio.json" +model_config="ThinkSound/configs/model_configs/prismaudio.json" +pretransform_ckpt_path="ckpts/vae.ckpt" +# 默认值 +debug_mode="true" +node_rank=0 + +result_path="results" + +while [[ $# -gt 0 ]]; do + case "$1" in + --duration-sec) + if [[ -n "$2" && "$2" != --* ]]; then + duration_sec="$2" + shift 2 + else + echo "❌ Argument --duration-sec requires a value" + exit 1 + fi + ;; + --result-path) + if [[ -n "$2" && "$2" != --* ]]; then + result_path="$2" + shift 2 + else + echo "❌ Argument --result-path requires a path" + exit 1 + fi + ;; + *) + echo "❌ Unknown argument: $1" + exit 1 + ;; + esac +done + +export NODE_RANK=$node_rank +export RANK=$node_rank + +num_gpus=1 +num_nodes=1 + +export WORLD_SIZE=$((num_gpus * num_nodes)) +# 打印配置信息 +echo "Training Configuration:" +echo "Checkpoint Directory: $ckpt_dir" +echo "Dataset Config: $dataset_config" +echo "Model Config: $model_config" +echo "Pretransform Checkpoint Path: $pretransform_ckpt_path" +echo "Num GPUs: $num_gpus" +echo "Num Nodes: $num_nodes" +echo "Test Batch Size: $test_batch_size" +echo "Num Workers: 20" +echo "Node Rank: $node_rank" +echo "WORLD SIZE: $WORLD_SIZE" + + +python predict.py \ + --dataset-config "$dataset_config" \ + --model-config "$model_config" \ + --ckpt-dir "$ckpt_dir" \ + --pretransform-ckpt-path "$pretransform_ckpt_path" \ + --checkpoint-every 2000 \ + --num-gpus "$num_gpus" \ + --num-nodes "$num_nodes" \ + --batch-size 1 \ + --test-batch-size $test_batch_size \ + --num-workers 32 \ + --duration-sec $duration_sec \ + --results-dir $result_path \ + diff --git a/scripts/PrismAudio/preprocess_videos.sh b/scripts/PrismAudio/preprocess_videos.sh new file mode 100644 index 0000000000000000000000000000000000000000..a90f586fe070a6c3250f62c9ec11f471dca29e33 --- /dev/null +++ b/scripts/PrismAudio/preprocess_videos.sh @@ -0,0 +1,17 @@ + + +NPROC_PER_NODE=4 # Number of GPUs per node +MASTER_PORT=29605 # Communication port for distributed training + +ROOT="path_to_your_videos" # Root directory of input videos +SAVE_DIR="path_to_your_dataset" # Output directory for processed results +TSV_PATH="demo_test.csv" # Path to video metadata CSV file (must contain video id and caption_cot) + +# ===== Run ===== +torchrun \ + --nproc_per_node=$NPROC_PER_NODE \ + --master-port=$MASTER_PORT \ + data_utils/prismaudio_data_process.py \ + --root "$ROOT" \ + --save-dir "$SAVE_DIR" \ + --tsv_path "$TSV_PATH" diff --git a/scripts/PrismAudio/setup/requirements.txt b/scripts/PrismAudio/setup/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f56524687ce8ce880c9ccff52f1595645e613fac --- /dev/null +++ b/scripts/PrismAudio/setup/requirements.txt @@ -0,0 +1,108 @@ +# Core ML Framework +torch==2.6.0 +torchaudio==2.6.0 +torchvision==0.21.0 +triton==3.2.0 + +# Training & Optimization +accelerate==1.4.0 +deepspeed==0.16.4 +peft==0.10.0 +diffusers==0.33.1 +transformers + +safetensors==0.6.2 +trl==0.23.1 + +# Attention +xformers==0.0.29.post3 + +# Audio +descript-audio-codec==1.0.0 + +soundfile==0.13.1 +librosa==0.10.2.post1 +torchaudio==2.6.0 +pyloudnorm==0.1.1 +torch-stoi==0.2.3 +pystoi==0.4.1 +pedalboard==0.9.17 +laion_clap==1.1.7 +msclap==1.3.4 + +# Video +av==15.0.0 +k-diffusion==0.1.1 + +# Vision +timm==1.0.19 +clip-anytorch==2.6.0 +open_clip_torch==2.31.0 + +kornia==0.8.1 +albumentations==1.4.10 + +# Data +datasets==3.3.2 +webdataset==1.0.2 +pandas==2.2.3 +numpy==1.26.4 +scipy==1.15.2 +scikit-learn==1.6.1 +scikit-image==0.25.2 +h5py==3.14.0 +pyarrow==21.0.0 + +# Experiment Tracking +wandb==0.18.7 + + +# Config & Utils +omegaconf==2.3.0 +hydra-core==1.3.2 +einops==0.8.1 +einops-exts==0.0.4 +huggingface-hub==0.34.4 +lightning==2.5.1.post0 +torchmetrics==1.8.1 +torchdiffeq==0.2.5 +torchsde==0.2.6 +vector-quantize-pytorch==1.9.14 +alias-free-torch==0.0.6 +prefigure==0.0.10 +randomname==0.2.1 +aeiou==0.0.21 +argbind==0.3.9 +gin-config==0.5.0 +ml_collections==1.1.0 +jsonmerge==1.9.2 +flatten-dict==0.4.2 + +# API & Serving +openai==1.99.9 +gradio==5.42.0 +fastapi==0.115.11 +uvicorn==0.34.0 + +# General +tqdm==4.67.1 +requests==2.32.3 +pydantic==2.10.6 +pillow +matplotlib==3.10.0 +rich==14.1.0 +click==8.2.1 +PyYAML==6.0.2 +ftfy==6.3.1 +tiktoken==0.11.0 +sentencepiece==0.2.1 +regex==2025.7.34 +numba==0.61.2 +mediapy +moviepy +fire==0.7.0 +psutil==7.0.0 +jax[cuda12]>=0.5 +tensorflow +setuptools<=81 +seaborn==0.13.2 \ No newline at end of file diff --git a/scripts/ThinkSound/demo.bat b/scripts/ThinkSound/demo.bat new file mode 100644 index 0000000000000000000000000000000000000000..e4c49d0e9dd69589838beb715d9fcacf59acd545 --- /dev/null +++ b/scripts/ThinkSound/demo.bat @@ -0,0 +1,95 @@ +@echo off +setlocal enabledelayedexpansion + +:: Check number of arguments +if "%~3"=="" ( + echo Usage: %~nx0 ^ ^ ^ [use-half] + exit /b 1 +) + +set "VIDEO_PATH=%~1" +set "TITLE=%~2" +set "DESCRIPTION=%~3" +set "USE_HALF_FLAG=%~4" + +set "MODEL_CONFIG=ThinkSound\configs\model_configs\thinksound.json" + +:: Generate unique ID +for /f %%i in ('powershell -Command "[guid]::NewGuid().ToString().Substring(0,8)"') do set "UNIQUE_ID=%%i" + +:: Create necessary directories +if not exist videos mkdir videos +if not exist cot_coarse mkdir cot_coarse +if not exist results mkdir results + +:: Extract file info +for %%f in ("%VIDEO_PATH%") do ( + set "VIDEO_FILE=%%~nxf" + set "VIDEO_ID=%%~nf" + set "VIDEO_EXT=%%~xf" +) + +:: Normalize extension +set "VIDEO_EXT=!VIDEO_EXT:.=!" +set "TEMP_VIDEO_PATH=videos\demo.mp4" + +:: Convert to mp4 if needed +echo VIDEO_EXT is: !VIDEO_EXT! + +if /i not "!VIDEO_EXT!"=="mp4" ( + echo Converting to mp4... + ffmpeg -y -i "%VIDEO_PATH%" -c:v libx264 -preset fast -c:a aac "%TEMP_VIDEO_PATH%" >nul 2>&1 + if errorlevel 1 ( + echo Video conversion failed. + exit /b 2 + ) +) else ( + echo Copying "%VIDEO_PATH%" to "%TEMP_VIDEO_PATH%" + copy "%VIDEO_PATH%" "%TEMP_VIDEO_PATH%" +) + +:: Get duration (in seconds) +for /f %%i in ('ffprobe -v error -show_entries format^=duration -of default^=noprint_wrappers^=1:nokey^=1 "%TEMP_VIDEO_PATH%"') do set "DURATION=%%i" +for /f "tokens=1 delims=." %%a in ("%DURATION%") do set "DURATION_SEC=%%a" +echo Duration is: %DURATION_SEC% + +:: Create cot.csv +set "CSV_PATH=cot_coarse\cot.csv" +echo id,caption,caption_cot> "%CSV_PATH%" +echo demo,"%TITLE%","%DESCRIPTION:"='%" >> "%CSV_PATH%" + +:: Run feature extraction +echo Extracting features... +set "CMD=python extract_latents.py --duration_sec %DURATION_SEC%" +if "%USE_HALF_FLAG%"=="use-half" ( + set "CMD=%CMD% --use_half" +) +call %CMD% +if errorlevel 1 ( + echo Feature extraction failed. + del /f "%TEMP_VIDEO_PATH%" + exit /b 3 +) + +:: Run inference +echo Running inference... +python predict.py --model-config "%MODEL_CONFIG%" --duration-sec %DURATION_SEC% --results-dir "results" +if errorlevel 1 ( + echo Inference failed. + del /f "%TEMP_VIDEO_PATH%" + exit /b 4 +) + +:: Locate audio output +for /f %%i in ('powershell -Command "(Get-Date).ToString('MMdd')"') do set "CURRENT_DATE=%%i" +set "AUDIO_PATH=results\%CURRENT_DATE%_batch_size1\demo.wav" + +if not exist "%AUDIO_PATH%" ( + echo Audio file not found. + del /f "%TEMP_VIDEO_PATH%" + exit /b 5 +) + +del /f "%TEMP_VIDEO_PATH%" +echo Audio successfully generated: %AUDIO_PATH% +exit /b 0 diff --git a/scripts/ThinkSound/demo.sh b/scripts/ThinkSound/demo.sh new file mode 100755 index 0000000000000000000000000000000000000000..cbc681b98a866d69d34db5c227c406dbd80b721f --- /dev/null +++ b/scripts/ThinkSound/demo.sh @@ -0,0 +1,94 @@ +#!/bin/bash + +# Check number of arguments +if [ "$#" -lt 3 ] || [ "$#" -gt 4 ]; then + echo "Usage: $0 <description> [use-half]" + exit 1 +fi + +VIDEO_PATH="$1" +TITLE="$2" +DESCRIPTION="$3" +USE_HALF_FLAG="$4" + +model_config="ThinkSound/configs/model_configs/thinksound.json" + +# Generate unique ID +UNIQUE_ID=$(uuidgen | cut -c 1-8) + +# Create necessary directories +mkdir -p videos cot_coarse results + +# Get video filename and extension +VIDEO_FILE=$(basename "$VIDEO_PATH") +VIDEO_EXT="${VIDEO_FILE##*.}" +VIDEO_ID="${VIDEO_FILE%.*}" +TEMP_VIDEO_PATH="videos/demo.mp4" + +# Convert video to MP4 format if needed +if [ "${VIDEO_EXT,,}" != "mp4" ]; then + echo "⏳ Converting video to MP4 format..." + ffmpeg -y -i "$VIDEO_PATH" -c:v libx264 -preset fast -c:a aac -strict experimental "$TEMP_VIDEO_PATH" >/dev/null 2>&1 + if [ $? -ne 0 ]; then + echo "❌ Video conversion failed" + exit 2 + fi +else + cp "$VIDEO_PATH" "$TEMP_VIDEO_PATH" +fi + +# Calculate video duration +DURATION=$(ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 "$TEMP_VIDEO_PATH") +DURATION_SEC=${DURATION%.*} +echo "Duration is: $DURATION_SEC" + +# Create cot.csv file +CAPTION_COT=$(echo "$DESCRIPTION" | tr '"' "'") +CSV_PATH="cot_coarse/cot.csv" +echo "id,caption,caption_cot" > "$CSV_PATH" +echo "demo,$TITLE,\"$CAPTION_COT\"" >> "$CSV_PATH" + +# Run feature extraction +echo "⏳ Extracting features..." +EXTRACT_CMD=("python" "extract_latents.py" "--duration_sec" "$DURATION_SEC") +if [ "$USE_HALF_FLAG" = "use-half" ]; then + EXTRACT_CMD+=("--use_half") +fi + +"${EXTRACT_CMD[@]}" 2>&1 +if [ $? -ne 0 ]; then + echo "❌ Feature extraction failed" + rm -f "$TEMP_VIDEO_PATH" + exit 3 +fi + +# Run inference +echo "⏳ Running model inference..." +python predict.py \ + --model-config "$model_config" \ + --duration-sec "$DURATION_SEC" \ + --results-dir "results"\ + +if [ $? -ne 0 ]; then + echo "❌ Inference failed" + rm -f "$TEMP_VIDEO_PATH" + exit 4 +fi + +# Get generated audio file +CURRENT_DATE=$(date +"%m%d") +AUDIO_PATH="results/${CURRENT_DATE}_batch_size1/demo.wav" + +# Check if audio file exists +if [ ! -f "$AUDIO_PATH" ]; then + echo "❌ Generated audio file not found" + rm -f "$TEMP_VIDEO_PATH" + exit 5 +fi + +# Clean up temporary video file +rm -f "$TEMP_VIDEO_PATH" + + +echo "✅ Audio generated successfully!" +echo "Audio file path: $AUDIO_PATH" \ No newline at end of file diff --git a/scripts/ThinkSound/eval_batch.bat b/scripts/ThinkSound/eval_batch.bat new file mode 100644 index 0000000000000000000000000000000000000000..cc787cf04a40b5ed9a53c0668be041d95e943d77 --- /dev/null +++ b/scripts/ThinkSound/eval_batch.bat @@ -0,0 +1,74 @@ +@echo off +setlocal enabledelayedexpansion + +set ARG_COUNT=0 +if not "%~1"=="" set /a ARG_COUNT+=1 +if not "%~2"=="" set /a ARG_COUNT+=1 +if not "%~3"=="" set /a ARG_COUNT+=1 +if not "%~4"=="" set /a ARG_COUNT+=1 + +if !ARG_COUNT! LSS 2 ( + echo Usage: %~nx0 ^<video_folder_path^> ^<csv_path^> [save_path] [use-half] + exit /b 1 +) + +set "VIDEO_PATH=%~1" +set "CSV_PATH=%~2" +set "SAVE_PATH=%~3" +set "USE_HALF_FLAG=%~4" + +if "!SAVE_PATH!"=="" ( + set "SAVE_PATH=results\features" +) + +set "DATASET_CONFIG=ThinkSound\configs\multimodal_dataset_demo.json" +set "MODEL_CONFIG=ThinkSound\configs\model_configs\thinksound.json" + +if not exist results mkdir results +if not exist results\features mkdir results\features +if not exist "!SAVE_PATH!" mkdir "!SAVE_PATH!" + +set "FIRST_VIDEO=" +for %%f in ("!VIDEO_PATH!\*.mp4") do ( + if not defined FIRST_VIDEO set "FIRST_VIDEO=%%~ff" +) + +if not defined FIRST_VIDEO ( + echo ❌ No .mp4 video file found in folder "!VIDEO_PATH!" + exit /b 1 +) + +echo First video found: !FIRST_VIDEO! + +for /f %%i in ('ffprobe -v error -show_entries format^=duration -of default^=noprint_wrappers^=1:nokey^=1 "!FIRST_VIDEO!"') do set "DURATION=%%i" +for /f "tokens=1 delims=." %%a in ("!DURATION!") do set "DURATION_SEC=%%a" +echo Video duration: !DURATION_SEC! seconds + +echo ⏳ Extracting features... +set "CMD=python extract_latents.py --root !VIDEO_PATH! --tsv_path !CSV_PATH! --save-dir results\features --duration_sec !DURATION_SEC!" +if /i "!USE_HALF_FLAG!"=="use-half" ( + set "CMD=!CMD! --use_half" +) +echo Running: !CMD! +call !CMD! +if errorlevel 1 ( + echo ❌ Feature extraction failed. + exit /b 3 +) + +echo ⏳ Running model inference... +set "CMD=python eval_batch.py --dataset-config !DATASET_CONFIG! --model-config !MODEL_CONFIG! --duration-sec !DURATION_SEC! --results-dir results\features --save-dir !SAVE_PATH!" +echo Running: !CMD! +call !CMD! +if errorlevel 1 ( + echo ❌ Inference failed. + exit /b 4 +) + +for /f %%i in ('powershell -Command "Get-Date -Format MMdd"') do set "CURRENT_DATE=%%i" +set "AUDIO_PATH=!SAVE_PATH!\!CURRENT_DATE!_batch_size1" + +echo ✅ Audio files saved in: !AUDIO_PATH! + +endlocal +exit /b 0 diff --git a/scripts/ThinkSound/eval_batch.sh b/scripts/ThinkSound/eval_batch.sh new file mode 100644 index 0000000000000000000000000000000000000000..74317102c1dd5fea9c60199992804c9ff6cadc02 --- /dev/null +++ b/scripts/ThinkSound/eval_batch.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# Check number of arguments +if [ "$#" -lt 2 ] || [ "$#" -gt 3 ] || [ "$#" -gt 4 ]; then + echo "Usage: $0 <video_path> <csv_path> <save_path (optional)> [use-half]" + exit 1 +fi + +VIDEO_PATH="$1" +CSV_PATH="$2" +SAVE_PATH="$3" +USE_HALF_FLAG="$4" + +dataset_config="ThinkSound/configs/multimodal_dataset_demo.json" +model_config="ThinkSound/configs/model_configs/thinksound.json" + +# Create necessary directories +mkdir -p results results/features + +SAVE_PATH=${SAVE_PATH:-"results/features"} + + +FIRST_VIDEO=$(find "$VIDEO_PATH" -type f \( -iname "*.mp4" \) | head -n 1) + +if [ -z "$FIRST_VIDEO" ]; then + echo "❌ No .mp4 video file found in $VIDEO_PATH" + exit 1 +fi + +DURATION=$(ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 "$FIRST_VIDEO") +DURATION_SEC=${DURATION%.*} + + +# Run feature extraction +echo "⏳ Extracting features..." +EXTRACT_CMD=("python" "extract_latents.py" "--root" "$VIDEO_PATH" "--tsv_path" "$CSV_PATH" "--save-dir" "results/features" "--duration_sec" "$DURATION_SEC") +if [ "$USE_HALF_FLAG" = "use-half" ]; then + EXTRACT_CMD+=("--use_half") +fi + +"${EXTRACT_CMD[@]}" 2>&1 + +if [ $? -ne 0 ]; then + echo "❌ Feature extraction failed" + exit 3 +fi + +# Run inference +echo "⏳ Running model inference..." +python eval_batch.py --dataset-config "$dataset_config" \ + --model-config "$model_config" \ + --duration-sec "$DURATION_SEC" \ + --results-dir "results/features"\ + --save-dir "$SAVE_PATH" 2>&1 \ + +if [ $? -ne 0 ]; then + echo "❌ Inference failed" + exit 4 +fi + +# Get generated audio file +CURRENT_DATE=$(date +"%m%d") +AUDIO_PATH=$SAVE_PATH"/${CURRENT_DATE}_batch_size1" + + +echo "Audio files path: $AUDIO_PATH" \ No newline at end of file diff --git a/scripts/ThinkSound/infer.sh b/scripts/ThinkSound/infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..46b1ecf3f095a50da4705aa0a041579bd4dc36db --- /dev/null +++ b/scripts/ThinkSound/infer.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +# 变量定义 +ckpt_dir="ckpts/thinksound.ckpt" +test_batch_size=1 +dataset_config="ThinkSound/configs/multimodal_dataset_demo.json" +model_config="ThinkSound/configs/model_configs/thinksound.json" +pretransform_ckpt_path="ckpts/vae.ckpt" +# 默认值 +debug_mode="true" +node_rank=0 + +result_path="results" + +while [[ $# -gt 0 ]]; do + case "$1" in + --duration-sec) + if [[ -n "$2" && "$2" != --* ]]; then + duration_sec="$2" + shift 2 + else + echo "❌ Argument --duration-sec requires a value" + exit 1 + fi + ;; + --result-path) + if [[ -n "$2" && "$2" != --* ]]; then + result_path="$2" + shift 2 + else + echo "❌ Argument --result-path requires a path" + exit 1 + fi + ;; + *) + echo "❌ Unknown argument: $1" + exit 1 + ;; + esac +done + +export NODE_RANK=$node_rank +export RANK=$node_rank + +num_gpus=1 +num_nodes=1 + +export WORLD_SIZE=$((num_gpus * num_nodes)) +# 打印配置信息 +echo "Training Configuration:" +echo "Checkpoint Directory: $ckpt_dir" +echo "Dataset Config: $dataset_config" +echo "Model Config: $model_config" +echo "Pretransform Checkpoint Path: $pretransform_ckpt_path" +echo "Num GPUs: $num_gpus" +echo "Num Nodes: $num_nodes" +echo "Test Batch Size: $test_batch_size" +echo "Num Workers: 20" +echo "Node Rank: $node_rank" +echo "WORLD SIZE: $WORLD_SIZE" + + +python predict.py \ + --dataset-config "$dataset_config" \ + --model-config "$model_config" \ + --ckpt-dir "$ckpt_dir" \ + --pretransform-ckpt-path "$pretransform_ckpt_path" \ + --checkpoint-every 2000 \ + --num-gpus "$num_gpus" \ + --num-nodes "$num_nodes" \ + --batch-size 1 \ + --test-batch-size $test_batch_size \ + --num-workers 32 \ + --duration-sec $duration_sec \ + --results-dir $result_path \ + diff --git a/scripts/ThinkSound/train.sh b/scripts/ThinkSound/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..f9151f4c251b4b9078751e0d92108d8ffe87b1d6 --- /dev/null +++ b/scripts/ThinkSound/train.sh @@ -0,0 +1,89 @@ +#!/bin/bash + + +ckpt_dir="your_project_name" +log_dir="logs/$ckpt_dir" +dataset_config="ThinkSound/configs/multimodal_dataset_demo.json" +model_config="ThinkSound/configs/model_configs/thinksound.json" +pretransform_ckpt_path="ckpts/vae.ckpt" +#export MASTER_ADDR="10.32.3.240" +export MASTER_PORT="9511" +# pip install git+https://github.com/patrick-kidger/torchcubicspline.git + +debug_mode="false" +node_rank=0 + + +while [[ "$#" -gt 0 ]]; do + case $1 in + --debug) debug_mode="true"; shift ;; + --node-rank) node_rank="$2"; shift; shift ;; + *) echo "Unknown parameter passed: $1"; exit 1 ;; + esac +done + +export NODE_RANK=$node_rank +export WORLD_SIZE=8 + +mkdir demos + +if [ "$debug_mode" != "true" ]; then + mkdir -p "$log_dir" + + cp "$dataset_config" "$log_dir/" + cp "$model_config" "$log_dir/" + cp "$0" "$log_dir/" +fi + + + +if [ "$debug_mode" == "true" ]; then + num_gpus=1 + num_nodes=1 +else + num_gpus=8 + num_nodes=1 +fi + + +echo "Training Configuration:" +echo "Checkpoint Directory: $ckpt_dir" +echo "Log Directory: $log_dir" +echo "Dataset Config: $dataset_config" +echo "Model Config: $model_config" +echo "Pretransform Checkpoint Path: $pretransform_ckpt_path" +echo "Num GPUs: $num_gpus" +echo "Num Nodes: $num_nodes" +echo "Batch Size: 32" +echo "Num Workers: 24" +echo "Node Rank: $node_rank" + + +if [ "$debug_mode" == "true" ]; then + nohup python train.py \ + --dataset-config "$dataset_config" \ + --model-config "$model_config" \ + --name "$ckpt_dir" \ + --save-dir "logs/" \ + --pretransform-ckpt-path "$pretransform_ckpt_path" \ + --checkpoint-every 2000 \ + --num-gpus "$num_gpus" \ + --num-nodes "$num_nodes" \ + --batch-size 32 \ + --num-workers 24 +else + nohup python train.py \ + --dataset-config "$dataset_config" \ + --model-config "$model_config" \ + --name "$ckpt_dir" \ + --save-dir "logs/" \ + --pretransform-ckpt-path "$pretransform_ckpt_path" \ + --checkpoint-every 4000 \ + --num-gpus "$num_gpus" \ + --num-nodes "$num_nodes" \ + --batch-size 32 \ + --num-workers 24 \ + > "$log_dir/train.log" 2>&1 & + + echo "Training started. Logs can be found in $log_dir/train.log" +fi diff --git a/set_up.sh b/set_up.sh new file mode 100644 index 0000000000000000000000000000000000000000..d328f1a822a81e305ba1647909a0774c93ee3118 --- /dev/null +++ b/set_up.sh @@ -0,0 +1,9 @@ +git clone https://github.com/google-deepmind/videoprism.git +cd videoprism +pip install . +cd .. +pip install -r scripts/PrismAudio/setup/requirements.txt +pip install tensorflow-cpu==2.15.0 +pip install facenet_pytorch==2.6.0 --no-deps + +conda install -y -c conda-forge 'ffmpeg<7' diff --git a/third_party/LICENSE_StabilityAI.md b/third_party/LICENSE_StabilityAI.md new file mode 100644 index 0000000000000000000000000000000000000000..1d9ce2ee1067327543544de197291726e4fc57a4 --- /dev/null +++ b/third_party/LICENSE_StabilityAI.md @@ -0,0 +1,58 @@ +STABILITY AI COMMUNITY LICENSE AGREEMENT + +Last Updated: July 5, 2024 + +1. INTRODUCTION + +This Agreement applies to any individual person or entity (“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or Derivative Works thereof for any Research & Non-Commercial or Commercial purpose. Capitalized terms not otherwise defined herein are defined in Section V below. + +This Agreement is intended to allow research, non-commercial, and limited commercial uses of the Models free of charge. In order to ensure that certain limited commercial uses of the Models continue to be allowed, this Agreement preserves free access to the Models for people or organizations generating annual revenue of less than US $1,000,000 (or local currency equivalent). + +By clicking “I Accept” or by using or distributing or using any portion or element of the Stability Materials or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement. If You are acting on behalf of a company, organization or other entity, then “You” includes you and that entity, and You agree that You: (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and (ii) You agree to the terms of this Agreement on that entity’s behalf. + +2. RESEARCH & NON-COMMERCIAL USE LICENSE + +Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose. “Research Purpose” means academic or scientific advancement, and in each case, is not primarily intended for commercial advantage or monetary compensation to You or others. “Non-Commercial Purpose” means any purpose other than a Research Purpose that is not primarily intended for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist) or evaluation and testing. + +3. COMMERCIAL USE LICENSE + +Subject to the terms of this Agreement (including the remainder of this Section III), Stability AI grants You a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under Stability AI’s intellectual property or other rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the Stability AI Materials for any Commercial Purpose. “Commercial Purpose” means any purpose other than a Research Purpose or Non-Commercial Purpose that is primarily intended for commercial advantage or monetary compensation to You or others, including but not limited to, (i) creating, modifying, or distributing Your product or service, including via a hosted service or application programming interface, and (ii) for Your business’s or organization’s internal operations. +If You are using or distributing the Stability AI Materials for a Commercial Purpose, You must register with Stability AI at (https://stability.ai/community-license). If at any time You or Your Affiliate(s), either individually or in aggregate, generate more than USD $1,000,000 in annual revenue (or the equivalent thereof in Your local currency), regardless of whether that revenue is generated directly or indirectly from the Stability AI Materials or Derivative Works, any licenses granted to You under this Agreement shall terminate as of such date. You must request a license from Stability AI at (https://stability.ai/enterprise) , which Stability AI may grant to You in its sole discretion. If you receive Stability AI Materials, or any Derivative Works thereof, from a Licensee as part of an integrated end user product, then Section III of this Agreement will not apply to you. + +4. GENERAL TERMS + +Your Research, Non-Commercial, and Commercial License(s) under this Agreement are subject to the following terms. +a. Distribution & Attribution. If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product or service that uses any portion of them, You shall: (i) provide a copy of this Agreement to that third party, (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved”, and (iii) prominently display “Powered by Stability AI” on a related website, user interface, blogpost, about page, or product documentation. If You create a Derivative Work, You may add your own attribution notice(s) to the “Notice” text file included with that Derivative Work, provided that You clearly indicate which attributions apply to the Stability AI Materials and state in the “Notice” text file that You changed the Stability AI Materials and how it was modified. +b. Use Restrictions. Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control Laws and equivalent regulations) and adhere to the Documentation and Stability AI’s AUP, which is hereby incorporated by reference. Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model (excluding the Models or Derivative Works). +c. Intellectual Property. +(i) Trademark License. No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of its Affiliates, except as required under Section IV(a) herein. +(ii) Ownership of Derivative Works. As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI’s ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI. +(iii) Ownership of Outputs. As between You and Stability AI, You own any outputs generated from the Models or Derivative Works to the extent permitted by applicable law. +(iv) Disputes. If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by You, then any licenses granted to You under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of this Agreement. +(v) Feedback. From time to time, You may provide Stability AI with verbal and/or written suggestions, comments or other feedback related to Stability AI’s existing or prospective technology, products or services (collectively, “Feedback”). You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant Stability AI a perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive, worldwide right and license to exploit the Feedback in any manner without restriction. Your Feedback is provided “AS IS” and You make no warranties whatsoever about any Feedback. +d. Disclaimer Of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE STABILITY AI MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OR LAWFULNESS OF USING OR REDISTRIBUTING THE STABILITY AI MATERIALS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE STABILITY AI MATERIALS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS. +e. Limitation Of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. +f. Term And Termination. The term of this Agreement will commence upon Your acceptance of this Agreement or access to the Stability AI Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You shall delete and cease use of any Stability AI Materials or Derivative Works. Section IV(d), (e), and (g) shall survive the termination of this Agreement. +g. Governing Law. This Agreement will be governed by and constructed in accordance with the laws of the United States and the State of California without regard to choice of law principles, and the UN Convention on Contracts for International Sale of Goods does not apply to this Agreement. + +5. DEFINITIONS + +“Affiliate(s)” means any entity that directly or indirectly controls, is controlled by, or is under common control with the subject entity; for purposes of this definition, “control” means direct or indirect ownership or control of more than 50% of the voting interests of the subject entity. + +"Agreement" means this Stability AI Community License Agreement. + +“AUP” means the Stability AI Acceptable Use Policy available at (https://stability.ai/use-policy), as may be updated from time to time. + +"Derivative Work(s)” means (a) any derivative work of the Stability AI Materials as recognized by U.S. copyright laws and (b) any modifications to a Model, and any other model created which is based on or derived from the Model or the Model’s output, including “fine tune” and “low-rank adaptation” models derived from a Model or a Model’s output, but do not include the output of any Model. + +“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software or Models. + +“Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing listed on Stability’s Core Models Webpage available at (https://stability.ai/core-models), as may be updated from time to time. + +"Stability AI" or "we" means Stability AI Ltd. and its Affiliates. + +"Software" means Stability AI’s proprietary software made available under this Agreement now or in the future. + +“Stability AI Materials” means, collectively, Stability’s proprietary Models, Software and Documentation (and any portion or combination thereof) made available under this Agreement. + +“Trade Control Laws” means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations. \ No newline at end of file