niobures commited on
Commit
49f2b3f
·
verified ·
1 Parent(s): db916fe

Audio-Flamingo (code, models, paper)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +23 -0
  2. Audio Flamingo 2. An Audio-Language Model with Long-Audio Understanding and Expert Reasoning Abilities.pdf +3 -0
  3. Audio Flamingo 3. Advancing Audio Intelligence with Fully Open Large Audio Language Models.pdf +3 -0
  4. Audio Flamingo Sound-CoT Technical Report. Improving Chain-of-Thought Reasoning in Sound Understanding.pdf +3 -0
  5. Audio Flamingo. A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities.pdf +3 -0
  6. NVIDIA представила модель, которая анализирует звук, речь и музыку.pdf +3 -0
  7. code/Audio-Flamingo-3-Pinokio.zip +3 -0
  8. code/Audio-Flamingo-3.zip +3 -0
  9. code/AudioFlamingo.zip +3 -0
  10. code/audio-flamingo-3-chat-hf.zip +3 -0
  11. code/audio-flamingo-3-hf.zip +3 -0
  12. code/audio-flamingo-audio_flamingo_2.zip +3 -0
  13. code/audio-flamingo-audio_flamingo_3.zip +3 -0
  14. code/audio-flamingo-soundCoT.zip +3 -0
  15. code/audio-flamingo.zip +3 -0
  16. code/audio_flamingo.zip +3 -0
  17. code/cog-nvidia-audio-flamingo-3.zip +3 -0
  18. models/audio-flamingo-1/.gitattributes +2 -0
  19. models/audio-flamingo-1/.gitignore +5 -0
  20. models/audio-flamingo-1/LICENSE +21 -0
  21. models/audio-flamingo-1/README.md +64 -0
  22. models/audio-flamingo-1/assets/AudioFlamingo_ICML2024_poster.pdf +3 -0
  23. models/audio-flamingo-1/assets/audio_flamingo_arch.png +3 -0
  24. models/audio-flamingo-1/audio flamingo model card.md +115 -0
  25. models/audio-flamingo-1/chat/README.md +65 -0
  26. models/audio-flamingo-1/chat/clap_modified_code/CLAPWrapper.py +463 -0
  27. models/audio-flamingo-1/chat/configs/chat.yaml +80 -0
  28. models/audio-flamingo-1/chat/data/README.md +19 -0
  29. models/audio-flamingo-1/chat/data/data.py +481 -0
  30. models/audio-flamingo-1/chat/data/prepare_each_dataset.py +253 -0
  31. models/audio-flamingo-1/chat/src/__init__.py +2 -0
  32. models/audio-flamingo-1/chat/src/factory.py +219 -0
  33. models/audio-flamingo-1/chat/src/flamingo.py +260 -0
  34. models/audio-flamingo-1/chat/src/flamingo_lm.py +177 -0
  35. models/audio-flamingo-1/chat/src/helpers.py +380 -0
  36. models/audio-flamingo-1/chat/src/utils.py +54 -0
  37. models/audio-flamingo-1/chat/train/distributed.py +150 -0
  38. models/audio-flamingo-1/chat/train/train.py +376 -0
  39. models/audio-flamingo-1/chat/train/train_utils.py +351 -0
  40. models/audio-flamingo-1/checkpoints/chat_part1.pt +3 -0
  41. models/audio-flamingo-1/checkpoints/chat_part2.pt +3 -0
  42. models/audio-flamingo-1/checkpoints/chat_part3.pt +3 -0
  43. models/audio-flamingo-1/checkpoints/chat_part4.pt +3 -0
  44. models/audio-flamingo-1/checkpoints/chat_part5.pt +3 -0
  45. models/audio-flamingo-1/checkpoints/checkpoint_utils.py +19 -0
  46. models/audio-flamingo-1/checkpoints/foundation_part1.pt +3 -0
  47. models/audio-flamingo-1/checkpoints/foundation_part2.pt +3 -0
  48. models/audio-flamingo-1/checkpoints/foundation_part3.pt +3 -0
  49. models/audio-flamingo-1/checkpoints/foundation_part4.pt +3 -0
  50. models/audio-flamingo-1/checkpoints/foundation_part5.pt +3 -0
.gitattributes CHANGED
@@ -33,3 +33,26 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Audio[[:space:]]Flamingo[[:space:]]2.[[:space:]]An[[:space:]]Audio-Language[[:space:]]Model[[:space:]]with[[:space:]]Long-Audio[[:space:]]Understanding[[:space:]]and[[:space:]]Expert[[:space:]]Reasoning[[:space:]]Abilities.pdf filter=lfs diff=lfs merge=lfs -text
37
+ Audio[[:space:]]Flamingo[[:space:]]3.[[:space:]]Advancing[[:space:]]Audio[[:space:]]Intelligence[[:space:]]with[[:space:]]Fully[[:space:]]Open[[:space:]]Large[[:space:]]Audio[[:space:]]Language[[:space:]]Models.pdf filter=lfs diff=lfs merge=lfs -text
38
+ Audio[[:space:]]Flamingo[[:space:]]Sound-CoT[[:space:]]Technical[[:space:]]Report.[[:space:]]Improving[[:space:]]Chain-of-Thought[[:space:]]Reasoning[[:space:]]in[[:space:]]Sound[[:space:]]Understanding.pdf filter=lfs diff=lfs merge=lfs -text
39
+ Audio[[:space:]]Flamingo.[[:space:]]A[[:space:]]Novel[[:space:]]Audio[[:space:]]Language[[:space:]]Model[[:space:]]with[[:space:]]Few-Shot[[:space:]]Learning[[:space:]]and[[:space:]]Dialogue[[:space:]]Abilities.pdf filter=lfs diff=lfs merge=lfs -text
40
+ models/audio-flamingo-1/assets/audio_flamingo_arch.png filter=lfs diff=lfs merge=lfs -text
41
+ models/audio-flamingo-1/assets/AudioFlamingo_ICML2024_poster.pdf filter=lfs diff=lfs merge=lfs -text
42
+ models/audio-flamingo-1/labeling_machine/AF-AudioSet.json filter=lfs diff=lfs merge=lfs -text
43
+ models/audio-flamingo-3-chat/llm/tokenizer.json filter=lfs diff=lfs merge=lfs -text
44
+ models/audio-flamingo-3-chat/static/af3_main_diagram-1.png filter=lfs diff=lfs merge=lfs -text
45
+ models/audio-flamingo-3-chat/static/af3_radial-1.png filter=lfs diff=lfs merge=lfs -text
46
+ models/audio-flamingo-3-chat/static/af3_sota.png filter=lfs diff=lfs merge=lfs -text
47
+ models/audio-flamingo-3-chat/static/logo-no-bg.png filter=lfs diff=lfs merge=lfs -text
48
+ models/audio-flamingo-3-hf/static/af3_main_diagram-1.png filter=lfs diff=lfs merge=lfs -text
49
+ models/audio-flamingo-3-hf/static/af3_radial-1.png filter=lfs diff=lfs merge=lfs -text
50
+ models/audio-flamingo-3-hf/static/af3_sota.png filter=lfs diff=lfs merge=lfs -text
51
+ models/audio-flamingo-3-hf/static/logo-no-bg.png filter=lfs diff=lfs merge=lfs -text
52
+ models/audio-flamingo-3-hf/tokenizer.json filter=lfs diff=lfs merge=lfs -text
53
+ models/audio-flamingo-3/llm/tokenizer.json filter=lfs diff=lfs merge=lfs -text
54
+ models/audio-flamingo-3/static/af3_main_diagram-1.png filter=lfs diff=lfs merge=lfs -text
55
+ models/audio-flamingo-3/static/af3_radial-1.png filter=lfs diff=lfs merge=lfs -text
56
+ models/audio-flamingo-3/static/af3_sota.png filter=lfs diff=lfs merge=lfs -text
57
+ models/audio-flamingo-3/static/logo-no-bg.png filter=lfs diff=lfs merge=lfs -text
58
+ NVIDIA[[:space:]]представила[[:space:]]модель,[[:space:]]которая[[:space:]]анализирует[[:space:]]звук,[[:space:]]речь[[:space:]]и[[:space:]]музыку.pdf filter=lfs diff=lfs merge=lfs -text
Audio Flamingo 2. An Audio-Language Model with Long-Audio Understanding and Expert Reasoning Abilities.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08d544cf57f324020ee5d9ff916c17d53aced283c09d38be09f9bc020a9ba171
3
+ size 10739247
Audio Flamingo 3. Advancing Audio Intelligence with Fully Open Large Audio Language Models.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f966955628f247ed76c7207b7b86048a1790794cc3b5cea47287ec14417f3508
3
+ size 6985793
Audio Flamingo Sound-CoT Technical Report. Improving Chain-of-Thought Reasoning in Sound Understanding.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20cf54c8128ca96298a342f066239370a122e42642ae3ded1f67b76cd8f80a4d
3
+ size 585189
Audio Flamingo. A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31451a317ebb7d0f4445500134cf5178b63c617d7f6583e0eac4f3a4c3d0000d
3
+ size 1444685
NVIDIA представила модель, которая анализирует звук, речь и музыку.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e96f68efb1fade752565268797ac90417bd142450eb8c8245c89f94994c22d09
3
+ size 2983908
code/Audio-Flamingo-3-Pinokio.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d740f122a84c86b6e55574c8f6ce7145a8518ccb00a874661810124d5bf1f71
3
+ size 1365689
code/Audio-Flamingo-3.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb1f6c63f18a25cc0db01c146783c651286fc53bb064d003b11518b62e7f59c2
3
+ size 6721741
code/AudioFlamingo.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40d2bac296affd39f8f63066c9bc4a2ac0f6ec982c542c3f3c8a961e1ef68ca3
3
+ size 2578624
code/audio-flamingo-3-chat-hf.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:afd4af51c0c1a6e2e3b54323dea6b34872c3221826ea969ff40a6e055e3de0e4
3
+ size 1395827
code/audio-flamingo-3-hf.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3157b02422c02ccf87f00b99d0db9ad6ba78101fe43985461db101f418a4e1b4
3
+ size 1443418
code/audio-flamingo-audio_flamingo_2.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d71ee0ac516346df1cfc497da306b729cbe52c1f88c327a0d32ae36f22111450
3
+ size 5672326
code/audio-flamingo-audio_flamingo_3.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db2c7f1847f5f2380f58bd78aa93326a1262b98bc2ff179206a26f67d7c2b371
3
+ size 3445237
code/audio-flamingo-soundCoT.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7e130b11f7d96aca17b9d6feda693d4b9e65ad7b2a35d374451eb24875ac820
3
+ size 12563876
code/audio-flamingo.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8cc93bf1b574c278112642af6e930485d065b3818e40392fc08b6cbd621f6f1
3
+ size 2484492
code/audio_flamingo.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:971305b4acb2b932be39abe6b376a6d3c52dece06fa6873220631c48a486ba81
3
+ size 11632722
code/cog-nvidia-audio-flamingo-3.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5921a9348dea2c34db7c1542e66461e392dfe92410ad4072001b675ebb87e2eb
3
+ size 17389826
models/audio-flamingo-1/.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.pt filter=lfs diff=lfs merge=lfs -text
2
+ AF-AudioSet.json filter=lfs diff=lfs merge=lfs -text
models/audio-flamingo-1/.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ .DS_Store
4
+ foundation.pt
5
+ chat.pt
models/audio-flamingo-1/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 NVIDIA CORPORATION.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
models/audio-flamingo-1/README.md ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PyTorch Implementation of Audio Flamingo
2
+
3
+ **Zhifeng Kong, Arushi Goel, Rohan Badlani, Wei Ping, Rafael Valle, Bryan Catanzaro**
4
+
5
+ [[Demo website]](https://audioflamingo.github.io/) [[Demo video]](https://www.youtube.com/watch?v=ucttuS28RVE) [[ICML poster]](assets/AudioFlamingo_ICML2024_poster.pdf)
6
+
7
+ This repo contains the PyTorch implementation of [Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities](https://arxiv.org/abs/2402.01831) (ICML 2024). Audio Flamingo is a novel audio-understanding language model with
8
+ - strong audio understanding abilities,
9
+ - the ability to quickly adapt to unseen tasks via in-context learning and retrieval, and
10
+ - strong multi-turn dialogue abilities.
11
+
12
+ We introduce a series of training techniques, architecture design, and data strategies to enhance our model with these abilities. Extensive evaluations across various audio understanding tasks confirm the efficacy of our method, setting new state-of-the-art benchmarks.
13
+
14
+ ![](assets/audio_flamingo_arch.png)
15
+
16
+ ## Code Structure
17
+
18
+ - The folder ```foundation/``` contains training code for the foundation model.
19
+ - The folder ```chat/``` contains training code for the chat model, which can perform multi-turn dialogues.
20
+ - The folder ```inference/``` contains inference code for both the foundation and chat models.
21
+
22
+ Within each folder, the structure is highly based on the [Open Flamingo](https://github.com/mlfoundations/open_flamingo) repo (commit ```a05dcba```). Each folder is self-contained and we expect no cross dependencies between these folders.
23
+
24
+ ## Preparation
25
+
26
+ - Download source code of Laion-CLAP from their [official repo](https://github.com/LAION-AI/CLAP). Rename the folder to ```my_laion_clap/``` and copy the folder to under each of ```foundation/, chat/, inference/```. Download their pretrained checkpoints to ```YOUR_DATA_ROOT_DIR/audio-flamingo-data/laion-clap-pretrained/laion_clap/```.
27
+ - Download source code of Microsoft-CLAP from their [official repo](https://github.com/microsoft/CLAP). Rename the folder to ```my_ms_clap/``` and copy the folder to under each of ```foundation/, chat/, inference/```. In each of these, replace the ```my_ms_clap/msclap/CLAPWrapper.py``` with ```clap_modified_code/CLAPWrapper.py```, which adds some processing functions and removes some bugs for clapcap. Download their pretrained checkpoints to ```YOUR_DATA_ROOT_DIR/audio-flamingo-data/clap/```.
28
+ - Download raw training and evaluation datasets from their original sources. Refer to ```foundation/data/README.md``` and ```chat/data/README.md``` for specific instructions to prepare data.
29
+
30
+ ## Running the Code
31
+
32
+ We refer to ```foundation/README.md```, ```chat/README.md```, and ```inference/README.md``` for the specific instructions for training the foundation model, training the chat model, and running inference, as they require different setups. We used 8 A100 GPUs to train our models.
33
+
34
+ ## Checkpoints
35
+ - The folder ```checkpoints/``` contains foundation and chat model checkpoints.
36
+ - Each model is about 17GB. Due to ```git lfs``` constraints we split each model into 5 parts. After downloading, go to ```checkpoints/``` and run ```python checkpoint_utils.py``` to merge the parts.
37
+ - Alternatively, the model checkpoints are also on HuggingFace (which is easier to download): [https://huggingface.co/nvidia/audio-flamingo](https://huggingface.co/nvidia/audio-flamingo). One can either ```git clone``` this project or use the ```huggingface_hub.hf_hub_download``` function to download: ```checkpoint_path = hf_hub_download(repo_id="nvidia/audio-flamingo", filename="foundation(or chat).pt")```.
38
+ - If you would like to run inference with these checkpoints, remember to modify the absolute paths in ```inference/configs/*.yaml``` and ```inference/inference_examples.py``` to properly load model checkpoints and data (see ```inference/README.md```).
39
+ - The foundation model is pretrained with ```foundation/configs/foundation_pretrain.yaml``` and then finetuned with ```foundation/configs/foundation_sft_8_shot.yaml```.
40
+ - The chat model is pretrained with ```foundation/configs/foundation_pretrain.yaml```, then finetuned with ```foundation/configs/foundation_sft_4_shot.yaml```, and finally finetuned with ```chat/configs/chat.yaml```.
41
+
42
+
43
+ ## Downstream applications
44
+ - We use Audio Flamingo as a data labeling machine for synthetic captions. See ```labeling_machine/``` for details of the synthetic dataset and license descriptions.
45
+
46
+ ## References
47
+
48
+ The main training and inferencing code within each folder (```foundation/```, ```chat/```, ```inference/```), including ```train/```, ```src/```, ```data/```, and ```configs/```, are modified from [Open Flamingo](https://github.com/mlfoundations/open_flamingo) (commit ```a05dcba```) (MIT license), which borrows from [flamingo-pytorch](https://github.com/lucidrains/flamingo-pytorch) (MIT license), [flamingo-mini](https://github.com/dhansmair/flamingo-mini) (MIT license), and [open_clip](https://github.com/mlfoundations/open_clip) (MIT license). ```src/helpers.py``` also includes self-attention implementations based on [attention-is-all-you-need-pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch) (MIT license), which borrows from [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py) (MIT license). Our code also relies on [LAION-AI/CLAP](https://github.com/LAION-AI/CLAP) (CC0-1.0 license) and [microsoft/CLAP](https://github.com/microsoft/CLAP) (MIT license). In ```chat/data/prepare_each_dataset.py```, the filtering keywords are based on the [LLARK](https://arxiv.org/abs/2310.07160) paper (CC-BY-4.0 license) and the [LTU](https://arxiv.org/abs/2305.10790) paper (CC-BY-4.0 license).
49
+
50
+ ## License
51
+
52
+ - The code in this repo is under MIT license (see ```LICENSE```).
53
+ - The checkpoints in this repo (```checkpoints/*.pt```) are for non-commercial use only. They are subject to the [OPT-IML](https://huggingface.co/facebook/opt-iml-1.3b/blob/main/LICENSE.md) license, the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and the original licenses accompanying each training dataset.
54
+
55
+
56
+ ## Citation
57
+ ```
58
+ @article{kong2024audio,
59
+ title={Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities},
60
+ author={Kong, Zhifeng and Goel, Arushi and Badlani, Rohan and Ping, Wei and Valle, Rafael and Catanzaro, Bryan},
61
+ journal={arXiv preprint arXiv:2402.01831},
62
+ year={2024}
63
+ }
64
+ ```
models/audio-flamingo-1/assets/AudioFlamingo_ICML2024_poster.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f63fccc267408123e6119d0293ff81a6dfbe6979293d451ac14a2a3cc9abe98
3
+ size 1170996
models/audio-flamingo-1/assets/audio_flamingo_arch.png ADDED

Git LFS Details

  • SHA256: 12e09cd22361ec76fb00a23da064d6961da4271cc1673046068101d5054db7fc
  • Pointer size: 131 Bytes
  • Size of remote file: 492 kB
models/audio-flamingo-1/audio flamingo model card.md ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Overview
2
+
3
+ ## Description:
4
+ Audio Flamingo is a novel audio-understanding language model for
5
+
6
+ - understanding audio,
7
+ - quickly adapting to unseen tasks via in-context learning and retrieval, and
8
+ - understanding and responding to multi-turn dialogues
9
+
10
+ We introduce a series of training techniques, architecture design, and data strategies to enhance our model with these abilities. Extensive evaluations across various audio understanding tasks confirm the efficacy of our method, setting new state-of-the-art benchmarks.
11
+
12
+ <center><img src="https://github.com/NVIDIA/audio-flamingo/raw/main/assets/audio_flamingo_arch.png" width="800"></center>
13
+
14
+ **This model is ready for non-commercial research use only.**
15
+ <br>
16
+
17
+
18
+ ## References(s):
19
+ * [Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities](https://arxiv.org/abs/2402.01831) <br>
20
+ * [Project Page](https://github.com/NVIDIA/audio-flamingo) <br>
21
+ * [Demo Website](https://audioflamingo.github.io/) <br>
22
+
23
+ ## Model Architecture:
24
+ **Architecture Type:** Transformer <br>
25
+ **Network Architecture:** Audio Flamingo
26
+
27
+ Audio Flamingo is a Flamingo-style architecture with frozen audio feature extractor, trainable transformation layers and xattn-dense layers, and language model layers.
28
+
29
+ ## Input:
30
+ **Input Types:** Audio, Text <br>
31
+ **Input Format:** Wav/MP3/Flac, String <br>
32
+ **Input Parameters:** None <br>
33
+ **Maximum Audio Input Lengths:** 33.25 seconds <br>
34
+ **Maximum Text Input Lengths:** 512 tokens <br>
35
+
36
+ ## Output:
37
+ **Output Type:** Text <br>
38
+ **Output Format:** String <br>
39
+ **Output Parameters:** None <br>
40
+
41
+ ## Software Integration:
42
+ **Runtime Engine(s):** PyTorch
43
+
44
+ **Supported Hardware Microarchitecture Compatibility:**
45
+ * NVIDIA Ampere <br>
46
+ * NVIDIA Hopper <br>
47
+
48
+ ## Preferred/Supported Operating System(s):
49
+ * Linux
50
+
51
+
52
+ ## Model Version(s):
53
+ * v1.0
54
+
55
+ ## Training, Testing, and Evaluation Datasets:
56
+
57
+ ### Training Dataset:
58
+ Audio Flamingo is trained with **publicly available** datasets under various licenses, with the most restricted ones being non-commercial/research-only. The dataset contains diverse audio types including speech, environmental sounds, and music.
59
+
60
+
61
+ * [OpenAQA ](https://github.com/YuanGongND/ltu?tab=readme-ov-file): Data collection method - [Human]; Labeling method - [Synthetic]
62
+ * [Laion630K ](https://github.com/LAION-AI/audio-dataset/blob/main/laion-audio-630k/README.md)
63
+ * [LP-MusicCaps ](https://github.com/seungheondoh/lp-music-caps)
64
+ * [SoundDescs ](https://github.com/akoepke/audio-retrieval-benchmark)
65
+ * [WavCaps](https://github.com/XinhaoMei/WavCaps)
66
+ * [AudioSet ](https://research.google.com/audioset/download.html)
67
+ * [AudioSet Strong Labeled ](https://research.google.com/audioset/download_strong.html)
68
+ * [WavText5K ](https://github.com/microsoft/WavText5K)
69
+ * [MSP-Podcast ](https://ecs.utdallas.edu/research/researchlabs/msp-lab/MSP-Podcast.html)
70
+ * [ClothoAQA ](https://zenodo.org/records/6473207)
71
+ * [Clotho-v2 ](https://github.com/audio-captioning/clotho-dataset/tree/master)
72
+ * [MACS ](https://zenodo.org/records/5114771)
73
+ * [FSD50k ](https://zenodo.org/records/4060432)
74
+ * [CochlScene ](https://github.com/cochlearai/cochlscene)
75
+ * [NonSpeech 7k ](https://zenodo.org/records/6967442)
76
+ * [Chime-home ](https://code.soundsoftware.ac.uk/projects/chime-home-dataset-annotation-and-baseline-evaluation-code)
77
+ * [Sonyc-UST ](https://zenodo.org/records/3966543)
78
+ * [Emov-DB ](https://github.com/numediart/EmoV-DB)
79
+ * [JL-Corpus ](https://github.com/tli725/JL-Corpus)
80
+ * [Tess ](https://www.kaggle.com/datasets/ejlok1/toronto-emotional-speech-set-tess)
81
+ * [OMGEmotion ](https://github.com/knowledgetechnologyuhh/OMGEmotionChallenge)
82
+ * [MELD ](https://github.com/declare-lab/MELD)
83
+ * [MusicAVQA ](https://gewu-lab.github.io/MUSIC-AVQA/)
84
+ * [MusicQA ](https://github.com/shansongliu/MU-LLaMA?tab=readme-ov-file)
85
+ * [MusicCaps ](https://www.kaggle.com/datasets/googleai/musiccaps)
86
+ * [NSynth ](https://magenta.tensorflow.org/datasets/nsynth)
87
+ * [MTG-Jamendo ](https://github.com/MTG/mtg-jamendo-dataset)
88
+ * [MusDB-HQ ](https://zenodo.org/records/3338373)
89
+ * [FMA ](https://github.com/mdeff/fma)
90
+
91
+ For all of these datasets, the data collection method is [human]. For OpenAQA, Laion630k, LP-MusicCaps, WavCaps, MusicQA, the data labeling method is [synthetic]. For the rest, the data labeling method is [human].
92
+
93
+ ### Evaluating Dataset:
94
+ Audio Flamingo is evaluated on the test split of the following datasets.
95
+
96
+ * [ClothoAQA ](https://zenodo.org/records/6473207)
97
+ * [MusicAVQA ](https://gewu-lab.github.io/MUSIC-AVQA/)
98
+ * [Clotho-v2 ](https://github.com/audio-captioning/clotho-dataset/tree/master)
99
+ * [FSD50k ](https://zenodo.org/records/4060432)
100
+ * [CochlScene ](https://github.com/cochlearai/cochlscene)
101
+ * [NonSpeech 7k ](https://zenodo.org/records/6967442)
102
+ * [NSynth ](https://magenta.tensorflow.org/datasets/nsynth)
103
+ * [AudioCaps ](https://github.com/cdjkim/audiocaps)
104
+ * [CREMA-D ](https://github.com/CheyneyComputerScience/CREMA-D)
105
+ * [Ravdess ](https://zenodo.org/records/1188976)
106
+ * [US8K ](https://urbansounddataset.weebly.com/urbansound8k.html)
107
+ * [GTZAN ](https://www.tensorflow.org/datasets/catalog/gtzan)
108
+ * [Medley-solos-DB ](https://zenodo.org/records/3464194)
109
+
110
+ For all of these datasets, the data collection method is [human] and the data labeling method is [human].
111
+
112
+ ## Inference
113
+
114
+ **Engine:** HuggingFace Transformers <br>
115
+ **Test Hardware [Name the specific test hardware model]:** A100 80GB <br>
models/audio-flamingo-1/chat/README.md ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Audio Flamingo Training (Chat Model)
2
+
3
+ ## Get data ready
4
+
5
+ Please read ```data/README.md``` for instructions on data preparation.
6
+
7
+ ## Get paths ready
8
+
9
+ Let ```YOUR_REPO_ROOT_DIR``` be the absolute path to this repo. We use the following structure
10
+
11
+ ```
12
+ YOUR_REPO_ROOT_DIR/
13
+ - foundation/
14
+ - chat/ # you are here
15
+ - inference/
16
+ ```
17
+
18
+ Replace ```YOUR_REPO_ROOT_DIR``` with your absolute path in the following places:
19
+ - ```configs/*.yaml --> clap_config --> config_root```
20
+
21
+
22
+ Let ```YOUR_DATA_ROOT_DIR``` be the absolute path to store all data, checkpoints, etc. We use the following structure
23
+ ```
24
+ YOUR_DATA_ROOT_DIR/
25
+ - datasets/
26
+ - <dataset_name_i>/
27
+ - files: raw data of this dataset, including raw waveforms, metadata, etc.
28
+
29
+ - audio-flamingo-data/
30
+ - dataset_files/
31
+ - <dataset_name_i>-<flamingo_task_i>/
32
+ - files: dataset manifests, precomputed embeddings, etc.
33
+
34
+ - checkpoint/
35
+ - <experiment_name>/ # same as the config file name, and train_config --> run_name in each config
36
+ - tensorboard/
37
+ - checkpoint_xxx.pt
38
+ - other cached files
39
+
40
+ - clap/
41
+ - files: pretrained Microsoft-CLAP checkpoints
42
+
43
+ - laion-clap-pretrained/laion_clap
44
+ - files: pretrained Laion-CLAP checkpoints
45
+
46
+ - LLM_pretrained/.cache/ # place to store HuggingFace cache instead of the default ~/.cache
47
+ ```
48
+
49
+ Replace ```YOUR_DATA_ROOT_DIR``` with your absolute path in the following places:
50
+ - ```configs/*.yaml```
51
+ - ```prepare_each_dataset.py --> __main__```
52
+
53
+ ## Training
54
+
55
+ The following code is tested on 1 node (8 GPUs per node) of A100 (80G) GPUs.
56
+
57
+ Set ```configs/chat.yaml --> sft_config --> pretrained_path``` and ```pretrained_ckpt``` to be the checkpoint of the pretrained model.
58
+ ```
59
+ export NCCL_IB_SL=1
60
+ export CUDA_DEVICE_MAX_CONNECTIONS=1
61
+ cd train/
62
+ torchrun --nproc_per_node 8 train.py -c ../configs/chat.yaml
63
+ ```
64
+
65
+
models/audio-flamingo-1/chat/clap_modified_code/CLAPWrapper.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/microsoft/CLAP under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import warnings
8
+ warnings.filterwarnings("ignore")
9
+ import random
10
+ import torchaudio
11
+ # from torch._six import string_classes
12
+ import collections
13
+ import re
14
+ import numpy as np
15
+ from transformers import AutoTokenizer, logging
16
+ try:
17
+ from models.clap import CLAP
18
+ from models.mapper import get_clapcap
19
+ except:
20
+ from .models.clap import CLAP
21
+ from .models.mapper import get_clapcap
22
+ import math
23
+ import torchaudio.transforms as T
24
+ import os
25
+ import torch
26
+ from importlib_resources import files
27
+ import argparse
28
+ import yaml
29
+ import sys
30
+ logging.set_verbosity_error()
31
+
32
+
33
+ class CLAPWrapper():
34
+ """
35
+ A class for interfacing CLAP model.
36
+ """
37
+
38
def __init__(self, model_fp, config_root, version, use_cuda=False):
    """Wrap a pretrained CLAP (or CLAPCap) model for inference.

    Args:
        model_fp: filesystem path to the pretrained checkpoint.
        config_root: directory holding the ``config_<version>.yml`` files.
        version: model variant; must be one of '2022', '2023', 'clapcap'.
        use_cuda: if True, move the model to GPU when CUDA is available.
    """
    self.supported_versions = ['2022', '2023', 'clapcap']
    # Pattern used to reject string/object numpy arrays during collation.
    self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
    self.file_path = os.path.realpath(__file__)
    self.default_collate_err_msg_format = (
        "default_collate: batch must contain tensors, numpy arrays, numbers, "
        "dicts or lists; found {}")
    self.model_fp = model_fp
    self.config_root = config_root
    self.use_cuda = use_cuda
    self.version = version
    # Resolve the YAML config for this version (raises on unknown versions);
    # requires self.config_root to be set first.
    self.config_as_str = self.get_config_path(version)
    # CLAPCap variants carry a caption decoder head; plain CLAP does not.
    if 'clapcap' in self.version:
        self.clapcap, self.tokenizer, self.args = self.load_clapcap()
    else:
        self.clap, self.tokenizer, self.args = self.load_clap()
55
def get_config_path(self, version):
    """Return the YAML config path for *version*.

    Raises:
        ValueError: if *version* is not one of ``self.supported_versions``.
    """
    # Guard clause: reject unknown versions up front.
    if version not in self.supported_versions:
        raise ValueError(f"The specific version is not supported. The supported versions are {str(self.supported_versions)}")
    return f"{self.config_root}/config_{version}.yml"
60
+
61
def read_config_as_args(self, config_path, args=None, is_config_str=False):
    """Parse a YAML config into an ``argparse.Namespace``.

    Args:
        config_path: path to a YAML file, or the raw YAML text itself when
            ``is_config_str`` is True. May be None to skip loading entirely.
        args: optional existing Namespace; keys from the YAML that match an
            existing attribute overwrite it, unknown keys are reported on
            stderr and ignored.
        is_config_str: treat ``config_path`` as YAML text, not a file path.

    Returns:
        argparse.Namespace holding the merged configuration (the same
        object as *args* when one was supplied).
    """
    return_dict = {}

    if config_path is not None:
        if is_config_str:
            # NOTE(review): FullLoader can construct arbitrary Python objects;
            # only feed it trusted configs.
            yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
        else:
            with open(config_path, "r") as f:
                yml_config = yaml.load(f, Loader=yaml.FullLoader)

        if args is not None:
            # Merge into the provided namespace, warning on unknown keys.
            for k, v in yml_config.items():
                if k in args.__dict__:
                    args.__dict__[k] = v
                else:
                    sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
        else:
            return_dict = dict(yml_config)

    # BUG FIX: the original ended with `argparse.Namespace(**args)` even when
    # `args` was already a Namespace, which always raises TypeError because a
    # Namespace is not a mapping. Return the updated Namespace directly.
    if args is not None:
        return args
    return argparse.Namespace(**return_dict)
83
+
84
def load_clap(self):
    r"""Instantiate CLAP, load pretrained weights, and build its tokenizer.

    Returns:
        (clap, tokenizer, args): the eval-mode CLAP module, the HuggingFace
        tokenizer for its text branch, and the parsed config namespace.
    """
    args = self.read_config_as_args(self.config_as_str, is_config_str=False)

    # roberta/clip/gpt tokenizers emit no token_type_ids; plain BERT does.
    if 'roberta' in args.text_model or 'clip' in args.text_model or 'gpt' in args.text_model:
        self.token_keys = ['input_ids', 'attention_mask']
    elif 'bert' in args.text_model:
        self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']

    clap = CLAP(
        audioenc_name=args.audioenc_name,
        sample_rate=args.sampling_rate,
        window_size=args.window_size,
        hop_size=args.hop_size,
        mel_bins=args.mel_bins,
        fmin=args.fmin,
        fmax=args.fmax,
        classes_num=args.num_classes,
        out_emb=args.out_emb,
        text_model=args.text_model,
        transformer_embed_dim=args.transformer_embed_dim,
        d_proj=args.d_proj
    )

    # Checkpoints were saved after unwrapping the DDP model, so the state
    # dict loads directly without stripping any `module.` prefixes.
    # Reference: https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005
    state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
    clap.load_state_dict(state_dict)
    clap.eval()  # inference only

    tokenizer = AutoTokenizer.from_pretrained(args.text_model)
    if 'gpt' in args.text_model:
        # GPT tokenizers ship without a pad token; reuse '!' as padding.
        tokenizer.add_special_tokens({'pad_token': '!'})

    if self.use_cuda and torch.cuda.is_available():
        clap = clap.cuda()

    return clap, tokenizer, args
125
+
126
    def load_clapcap(self):
        r"""Load the CLAPCap captioning model with args from config file.

        Returns:
            (clapcap, tokenizer, args): the captioning model in eval mode, the
            text-decoder tokenizer, and the parsed config namespace.
        """

        args = self.read_config_as_args(self.config_as_str, is_config_str=False)
        args.prefix_dim = args.d_proj
        # Keep the original text encoder for CLAP; args.text_model is repointed
        # to the text decoder so the tokenizer below is the decoder's.
        text_model = args.text_model
        args.text_model = args.text_decoder
        args.cross_attention = True if 'cross' in args.clapcap_model.lower() else False

        # Tokenizer output keys depend on the (decoder) model family.
        if 'roberta' in args.text_model or 'clip' in args.text_model or 'gpt' in args.text_model:
            self.token_keys = ['input_ids', 'attention_mask']
        elif 'bert' in args.text_model:
            self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']

        clap = CLAP(
            audioenc_name=args.audioenc_name,
            sample_rate=args.sampling_rate,
            window_size=args.window_size,
            hop_size=args.hop_size,
            mel_bins=args.mel_bins,
            fmin=args.fmin,
            fmax=args.fmax,
            classes_num=args.num_classes,
            out_emb=args.out_emb,
            text_model=text_model,
            transformer_embed_dim=args.transformer_embed_dim,
            d_proj=args.d_proj
        )

        clapcap = get_clapcap(args.clapcap_model)(clap, args.text_decoder, args.prefix_length, args.prefix_length_clip, args.prefix_dim,
                                                  args.num_layers, args.normalize_prefix, args.mapping_type, True, True)

        # Checkpoint stores weights under the 'model' key.
        model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
        clapcap.load_state_dict(model_state_dict)

        clapcap.eval()  # set clap in eval mode
        tokenizer = AutoTokenizer.from_pretrained(args.text_model)
        if 'gpt' in args.text_model:
            # GPT-style tokenizers have no pad token by default; '!' is used as padding.
            tokenizer.add_special_tokens({'pad_token': '!'})

        if self.use_cuda and torch.cuda.is_available():
            clapcap = clapcap.cuda()

        return clapcap, tokenizer, args
171
    def default_collate(self, batch):
        r"""Puts each data field into a tensor with outer dimension batch size.

        Mirrors torch.utils.data's default_collate: recursively stacks tensors,
        numpy arrays and numbers, and descends into mappings, namedtuples and
        sequences. Raises TypeError for unsupported element types.
        """
        elem = batch[0]
        elem_type = type(elem)
        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                numel = sum([x.numel() for x in batch])
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage)
            return torch.stack(batch, 0, out=out)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
                # array of string classes and object dtypes cannot be collated
                if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                    raise TypeError(
                        self.default_collate_err_msg_format.format(elem.dtype))

                return self.default_collate([torch.as_tensor(b) for b in batch])
            elif elem.shape == ():  # scalars
                return torch.as_tensor(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float64)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        # elif isinstance(elem, string_classes):
        #     return batch
        elif isinstance(elem, collections.abc.Mapping):
            # Collate each key across the batch independently.
            return {key: self.default_collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
            return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
        elif isinstance(elem, collections.abc.Sequence):
            # check to make sure that the elements in batch have consistent size
            it = iter(batch)
            elem_size = len(next(it))
            if not all(len(elem) == elem_size for elem in it):
                raise RuntimeError(
                    'each element in list of batch should be of equal size')
            transposed = zip(*batch)
            return [self.default_collate(samples) for samples in transposed]

        raise TypeError(self.default_collate_err_msg_format.format(elem_type))
217
+ def read_audio(self, audio_path, resample=False):
218
+ r"""Loads audio file or array and returns a torch tensor"""
219
+ # Randomly sample a segment of audio_duration from the clip or pad to match duration
220
+ audio_time_series, sample_rate = torchaudio.load(audio_path)
221
+
222
+ resample_rate = self.args.sampling_rate
223
+ if resample:
224
+ resampler = T.Resample(sample_rate, resample_rate)
225
+ audio_time_series = resampler(audio_time_series)
226
+ return audio_time_series, sample_rate
227
+
228
+ def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
229
+ r"""Loads audio file and returns raw audio."""
230
+ # Randomly sample a segment of audio_duration from the clip or pad to match duration
231
+ audio_time_series, sample_rate = self.read_audio(audio_path, resample=False)
232
+ audio_time_series = audio_time_series.reshape(-1)
233
+
234
+ # audio_time_series is shorter than predefined audio duration,
235
+ # so audio_time_series is extended
236
+ if audio_duration*sample_rate >= audio_time_series.shape[0]:
237
+ repeat_factor = int(np.ceil((audio_duration*sample_rate) /
238
+ audio_time_series.shape[0]))
239
+ # Repeat audio_time_series by repeat_factor to match audio_duration
240
+ audio_time_series = audio_time_series.repeat(repeat_factor)
241
+ # remove excess part of audio_time_series
242
+ audio_time_series = audio_time_series[0:audio_duration*sample_rate]
243
+ else:
244
+ # audio_time_series is longer than predefined audio duration,
245
+ # so audio_time_series is trimmed
246
+ start_index = random.randrange(
247
+ audio_time_series.shape[0] - audio_duration*sample_rate)
248
+ audio_time_series = audio_time_series[start_index:start_index +
249
+ audio_duration*sample_rate]
250
+ return torch.FloatTensor(audio_time_series)
251
+
252
+ # Modified
253
+ def load_audio_clip_into_tensor(self, audio_clip, audio_duration, resample=False):
254
+ r"""Loads audio clip and returns raw audio."""
255
+ # Randomly sample a segment of audio_duration from the clip or pad to match duration
256
+ sample_rate = 44100
257
+ audio_time_series = audio_clip.reshape(-1)
258
+
259
+ # audio_time_series is shorter than predefined audio duration,
260
+ # so audio_time_series is extended
261
+ assert audio_duration * sample_rate >= audio_time_series.shape[0], \
262
+ 'dur * sr = {} should be larger than len = {}'.format(audio_duration * sample_rate, audio_time_series.shape[0])
263
+ repeat_factor = int(np.ceil((audio_duration*sample_rate) /
264
+ audio_time_series.shape[0]))
265
+ # Repeat audio_time_series by repeat_factor to match audio_duration
266
+ audio_time_series = audio_time_series.repeat(repeat_factor)
267
+ # remove excess part of audio_time_series
268
+ audio_time_series = audio_time_series[0:audio_duration*sample_rate]
269
+
270
+ # return torch.FloatTensor(audio_time_series)
271
+ return audio_time_series # already on cuda device
272
+
273
+ def preprocess_audio(self, audio_files, resample):
274
+ r"""Load list of audio files and return raw audio"""
275
+ audio_tensors = []
276
+ for audio_file in audio_files:
277
+ audio_tensor = self.load_audio_into_tensor(
278
+ audio_file, self.args.duration, resample)
279
+ audio_tensor = audio_tensor.reshape(
280
+ 1, -1).cuda() if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1)
281
+ audio_tensors.append(audio_tensor)
282
+ return self.default_collate(audio_tensors)
283
+
284
+ # Modified
285
+ def preprocess_audio_clips(self, audio_clips, resample=False):
286
+ r"""Load list of audio clips and return raw audio"""
287
+ audio_tensors = []
288
+ for audio_clip in audio_clips:
289
+ audio_tensor = self.load_audio_clip_into_tensor(
290
+ audio_clip, self.args.duration, resample=False)
291
+ audio_tensor = audio_tensor.reshape(
292
+ 1, -1).cuda() if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1)
293
+ audio_tensors.append(audio_tensor)
294
+ return self.default_collate(audio_tensors)
295
+
296
+ def preprocess_text(self, text_queries):
297
+ r"""Load list of class labels and return tokenized text"""
298
+ tokenized_texts = []
299
+ for ttext in text_queries:
300
+ if 'gpt' in self.args.text_model:
301
+ ttext = ttext + ' <|endoftext|>'
302
+ tok = self.tokenizer.encode_plus(
303
+ text=ttext, add_special_tokens=True, max_length=self.args.text_len, padding='max_length', return_tensors="pt")
304
+ for key in self.token_keys:
305
+ tok[key] = tok[key].reshape(-1).cuda() if self.use_cuda and torch.cuda.is_available() else tok[key].reshape(-1)
306
+ tokenized_texts.append(tok)
307
+ return self.default_collate(tokenized_texts)
308
+
309
+ def get_text_embeddings(self, class_labels):
310
+ r"""Load list of class labels and return text embeddings"""
311
+ preprocessed_text = self.preprocess_text(class_labels)
312
+ return self._get_text_embeddings(preprocessed_text)
313
+
314
+ def get_audio_embeddings(self, audio_files, resample):
315
+ r"""Load list of audio files and return a audio embeddings"""
316
+ preprocessed_audio = self.preprocess_audio(audio_files, resample)
317
+ return self._get_audio_embeddings(preprocessed_audio)
318
+
319
+ # Modified
320
+ def get_audio_embeddings_from_clips(self, audio_clips, resample=False):
321
+ r"""Load list of audio files and return a audio embeddings"""
322
+ preprocessed_audio = self.preprocess_audio_clips(audio_clips, resample)
323
+ return self._get_audio_embeddings(preprocessed_audio)
324
+
325
+ def _get_text_embeddings(self, preprocessed_text):
326
+ r"""Load preprocessed text and return text embeddings"""
327
+ with torch.no_grad():
328
+ return self.clap.caption_encoder(preprocessed_text)
329
+
330
    # Modified
    def _get_audio_embeddings(self, preprocessed_audio):
        r"""Encode a preprocessed raw-audio batch into audio embeddings (no gradients)."""
        with torch.no_grad():
            # Collated batch is (B, 1, T); drop the singleton middle dimension.
            preprocessed_audio = preprocessed_audio.reshape(
                preprocessed_audio.shape[0], preprocessed_audio.shape[2])
            # Index [0] is the audio embedding; [1] holds output class probabilities.
            if 'clapcap' in self.version:
                return self.clapcap.clap(preprocessed_audio)[0]
            else:
                return self.clap.audio_encoder(preprocessed_audio)[0]
342
+ def _generic_batch_inference(self, func, *args):
343
+ r"""Process audio and/or text per batch"""
344
+ input_tmp = args[0]
345
+ batch_size = args[-1]
346
+ # args[0] has audio_files, args[1] has class_labels
347
+ inputs = [args[0], args[1]] if len(args) == 3 else [args[0]]
348
+ args0_len = len(args[0])
349
+ # compute text_embeddings once for all the audio_files batches
350
+ if len(inputs) == 2:
351
+ text_embeddings = self.get_text_embeddings(args[1])
352
+ inputs = [args[0], args[1], text_embeddings]
353
+ dataset_idx = 0
354
+ for _ in range(math.ceil(args0_len/batch_size)):
355
+ next_batch_idx = dataset_idx + batch_size
356
+ # batch size is bigger than available audio/text items
357
+ if next_batch_idx >= args0_len:
358
+ inputs[0] = input_tmp[dataset_idx:]
359
+ return func(*tuple(inputs))
360
+ else:
361
+ inputs[0] = input_tmp[dataset_idx:next_batch_idx]
362
+ yield func(*tuple(inputs))
363
+ dataset_idx = next_batch_idx
364
+
365
+ def get_audio_embeddings_per_batch(self, audio_files, batch_size):
366
+ r"""Load preprocessed audio and return a audio embeddings per batch"""
367
+ return self._generic_batch_inference(self.get_audio_embeddings, audio_files, batch_size)
368
+
369
+ def get_text_embeddings_per_batch(self, class_labels, batch_size):
370
+ r"""Load preprocessed text and return text embeddings per batch"""
371
+ return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)
372
+
373
+ def compute_similarity(self, audio_embeddings, text_embeddings):
374
+ r"""Compute similarity between text and audio embeddings"""
375
+ audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
376
+ text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
377
+
378
+ logit_scale = self.clap.logit_scale.exp()
379
+ similarity = logit_scale*text_embeddings @ audio_embeddings.T
380
+ return similarity.T
381
+
382
+ def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
383
+ r"""Compute classification probabilities for each audio recording in a batch and each class label"""
384
+ return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)
385
+
386
    def generate_caption(self, audio_files, resample=True, beam_size: int = 5, entry_length=67, temperature=1.):
        r"""Generate audio captions for each audio recording in a batch.

        Args:
            audio_files: list of audio file paths.
            resample: resample audio to the model rate before encoding.
            beam_size: beam width for decoding.
            entry_length: maximum number of generated tokens per caption.
            temperature: softmax temperature for decoding.

        Returns:
            List of capitalized caption strings, one per input file.
        """
        captions = []
        audio_tensors = self.preprocess_audio(audio_files, resample)

        with torch.no_grad():
            # The CLAP audio embedding serves as the caption prefix.
            prefix = self.clapcap.clap(audio_tensors.squeeze(1))[0]
            if self.args.normalize_prefix:
                prefix = prefix / prefix.norm(2, -1).reshape(-1,1)
            # Project the prefix into the GPT token-embedding space.
            prefix_embed = self.clapcap.clap_project(prefix).view(-1, self.args.prefix_length, self.clapcap.gpt.transformer.wte.weight.shape[1])

            # Decode one caption per input; [0] is the best beam.
            for i in range(len(audio_tensors)):
                gen_caption = self._generate_beam(embed=prefix_embed[i].unsqueeze(0),\
                                                  beam_size=beam_size,\
                                                  entry_length=entry_length,\
                                                  temperature=temperature)[0]
                captions.append(gen_caption.capitalize())
        return captions
405
    def _generate_beam(self, beam_size: int = 5, prompt=None, embed=None,
                       entry_length=67, temperature=1., stop_token: str = ' <|endoftext|>'):
        r"""Generate captions by beam search decoding.

        Either *embed* (a precomputed prefix embedding) or *prompt* (text to
        encode) seeds the generation. Returns the candidate captions sorted by
        average per-token log-probability, best first.
        """
        self.clapcap.eval()
        stop_token_index = self.tokenizer.encode(stop_token)[0]
        tokens = None
        scores = None
        device = next(self.clapcap.parameters()).device
        # Per-beam generated length and "has emitted stop token" flags.
        seq_lengths = torch.ones(beam_size, device=device)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
        with torch.no_grad():
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(self.tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)
                    generated = self.clapcap.gpt.transformer.wte(tokens)
            for i in range(entry_length):
                outputs = self.clapcap.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                logits = logits.softmax(-1).log()
                if scores is None:
                    # First step: fan the single prefix out into beam_size beams.
                    scores, next_tokens = logits.topk(beam_size, -1)
                    generated = generated.expand(beam_size, *generated.shape[1:])
                    next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                    if tokens is None:
                        tokens = next_tokens
                    else:
                        tokens = tokens.expand(beam_size, *tokens.shape[1:])
                        tokens = torch.cat((tokens, next_tokens), dim=1)
                else:
                    # Finished beams may only extend with token 0 at zero cost.
                    logits[is_stopped] = -float(np.inf)
                    logits[is_stopped, 0] = 0
                    scores_sum = scores[:, None] + logits
                    seq_lengths[~is_stopped] += 1
                    # Rank all (beam, token) candidates by length-normalized score.
                    scores_sum_average = scores_sum / seq_lengths[:, None]
                    scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                    # Recover which source beam each flattened candidate came from.
                    next_tokens_source = next_tokens // scores_sum.shape[1]
                    seq_lengths = seq_lengths[next_tokens_source]
                    next_tokens = next_tokens % scores_sum.shape[1]
                    next_tokens = next_tokens.unsqueeze(1)
                    tokens = tokens[next_tokens_source]
                    tokens = torch.cat((tokens, next_tokens), dim=1)
                    generated = generated[next_tokens_source]
                    scores = scores_sum_average * seq_lengths
                    is_stopped = is_stopped[next_tokens_source]
                # Append the embedding of the chosen token to every beam.
                next_token_embed = self.clapcap.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
                if is_stopped.all():
                    break
        scores = scores / seq_lengths
        output_list = tokens.cpu().numpy()
        output_texts = [self.tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
        # Best (highest average log-prob) caption first.
        order = scores.argsort(descending=True)
        output_texts = [output_texts[i] for i in order]
        return output_texts
models/audio-flamingo-1/chat/configs/chat.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ train_config:
2
+ expdir: YOUR_DATA_ROOT_DIR/audio-flamingo-data/checkpoint
3
+ run_name: chat
4
+ delete_previous_checkpoint: false
5
+ batch_size: 4
6
+ gradient_accumulation_steps: 4 # global batchsize = 128
7
+ seed: 42
8
+ learning_rate: 0.00002
9
+ lr_scheduler: constant
10
+ loss_multiplier: 1.0
11
+ warmup_steps: 1875
12
+ weight_decay: 0.1
13
+ precision: fp32
14
+ gradient_checkpointing: False
15
+ num_epochs: 1
16
+ offline: false
17
+ freeze_lm_embeddings: false
18
+ logging_steps: 10
19
+ dist_backend: nccl
20
+ dist_url: env://
21
+ no_set_device_rank: false
22
+ fsdp: true
23
+ fsdp_use_orig_params: false # Passed into the FSDP constructor. Enables param_groups and gradient masking for weight_decay. Does not work with OPT.
24
+ fsdp_sharding_strategy: full # full, hybrid
25
+ horovod: false
26
+
27
+ # Chat SFT hparams
28
+ sft_config:
29
+ pretrained_path: YOUR_DATA_ROOT_DIR/audio-flamingo-data/checkpoint/foundation_sft_4_shot/
30
+ pretrained_ckpt: checkpoint_99.pt
31
+ unfreeze_full_lm: true
32
+
33
+ data_config:
34
+ dataset_blending_global_weight: 1.0
35
+
36
+ dataset_blending_config:
37
+ dialog_AudioSetSL-Dialog/train:
38
+ weight: 1.0
39
+ prefix_prob: 1.0
40
+
41
+ dialog_MusicCaps-Dialog/train:
42
+ weight: 5.0
43
+ prefix_prob: 1.0
44
+
45
+ dataset_file_root: YOUR_DATA_ROOT_DIR/audio-flamingo-data/dataset_files
46
+ data_root: YOUR_DATA_ROOT_DIR/datasets
47
+ dataset_blending_output: dataset_blending.json
48
+ max_tokens: 512
49
+ num_workers: 4
50
+
51
+ clap_config:
52
+ # method: laion-clap
53
+ # audio_embed_dim: 512
54
+ # model_name: 630k-fusion-best
55
+ # checkpoint: YOUR_DATA_ROOT_DIR/audio-flamingo-data/laion-clap-pretrained/laion_clap/630k-fusion-best.pt
56
+
57
+ method: microsoft-clap
58
+ audio_embed_dim: 1024
59
+ config_root: YOUR_REPO_ROOT_DIR/chat/my_ms_clap/src/configs
60
+ model_name: 'clapcap'
61
+ checkpoint: YOUR_DATA_ROOT_DIR/audio-flamingo-data/clap/clapcap_weights_2023.pth
62
+
63
+ window_length: 7.0 # seconds
64
+ window_overlap: 5.25 # seconds
65
+ max_num_window: 16 # total = 33.25 seconds
66
+ max_num_fewshot: 4 # number of fewshot samples
67
+
68
+ model_config:
69
+ cache_dir: YOUR_DATA_ROOT_DIR/audio-flamingo-data/LLM_pretrained/.cache
70
+
71
+ lang_encoder_path: facebook/opt-iml-max-1.3b
72
+ tokenizer_path: facebook/opt-iml-max-1.3b
73
+ cross_attn_every_n_layers: 1
74
+ audio_transformer_kwargs: {
75
+ n_head: 8,
76
+ n_layers: 3,
77
+ d_inner: 2048,
78
+ max_num_media: 128, # must be >= max_num_window * num_fewshot_samples (4)
79
+ max_window_per_audio: 16, # must = max_num_window
80
+ }
models/audio-flamingo-1/chat/data/README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data Preparation
2
+
3
+ Data preparation and loading is a challenging part in this codebase as complex formats are used. Below are the instructions to prepare dataset manifests.
4
+
5
+ ## Step 1: Download raw datasets
6
+
7
+ Download datasets from their original sources, or prepare your own datasets. For simplicity, in this repo, we assume datasets are stored under ```YOUR_DATA_ROOT_DIR/datasets/<dataset_name>```.
8
+
9
+ ## Step 2: Prepare dialogues
10
+
11
+ Follow the instructions in Appendix B in our paper to generate dialogues from rich metadata and filter for quality.
12
+
13
+ ## Step 3: Prepare raw datasets into manifests
14
+
15
+ - Modify the ```prepare_files()``` function in ```prepare_each_dataset.py``` based on your raw dataset files.
16
+ - For each dataset, this function generates manifests for each split (train/val/test). The manifest is stored under ```YOUR_DATA_ROOT_DIR/audio-flamingo-data/dataset_files/```. The filenames are in the format of ```<dataset_name>-Dialog/train.json```.
17
+ - The ```<dataset_name>``` used in Audio Flamingo can be found in ```configs/*.yaml``` --> data_config --> dataset_blending_config.
18
+ - The structure of manifests can be found within the ```prepare_files()``` function.
19
+ - Usage: ```python prepare_each_dataset.py -d <dataset_name>```.
models/audio-flamingo-1/chat/data/data.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import functools
8
+ import io
9
+ import json
10
+ import math
11
+ import os
12
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" # disable the tokenizer parallelism warning
13
+ import random
14
+ import re
15
+ import string
16
+ import subprocess
17
+ import sys
18
+ import yaml
19
+
20
+ import numpy as np
21
+
22
+ from collections import defaultdict
23
+ from copy import deepcopy
24
+ from dataclasses import dataclass
25
+ from functools import partial
26
+ from pydub import AudioSegment
27
+ from tqdm import tqdm
28
+
29
+ import torch
30
+ import torchvision
31
+ from torch.utils.data import DataLoader, Dataset, get_worker_info
32
+ from torch.utils.data.distributed import DistributedSampler
33
+
34
+
35
+ from transformers import AutoTokenizer
36
+
37
+ import librosa
38
+ import soundfile as sf
39
+
40
+
41
def int16_to_float32(x):
    """Convert int16 PCM samples to float32 in [-1, 1]."""
    scaled = x / 32767.0
    return scaled.astype(np.float32)
+
45
def float32_to_int16(x):
    """Clip float samples to [-1, 1] and convert to int16 PCM (truncating)."""
    clipped = np.clip(x, a_min=-1., a_max=1.)
    return (clipped * 32767.).astype(np.int16)
class DataCollator:
    """Collate (filename, audio, audio_mask, input_ids, attention_mask) samples.

    Token ids are right-padded with the tokenizer's pad id and attention masks
    with zeros, up to the longest sequence in the batch; audio tensors are
    stacked along a new batch dimension.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        filenames, audio_clips, audio_embed_mask, input_ids, attention_masks = zip(*batch)

        audio_clips = torch.stack(audio_clips, dim=0)
        audio_embed_mask = torch.stack(audio_embed_mask, dim=0)

        max_length = max(ids.shape[1] for ids in input_ids)

        pad_id = self.tokenizer.pad_token_id
        padded_ids, padded_masks = [], []
        for ids, mask in zip(input_ids, attention_masks):
            deficit = max_length - ids.shape[1]
            if deficit > 0:
                ids = torch.cat([ids, torch.LongTensor([pad_id] * deficit).unsqueeze(0)], dim=1)
                mask = torch.cat([mask, torch.LongTensor([0] * deficit).unsqueeze(0)], dim=1)
            padded_ids.append(ids)
            padded_masks.append(mask)

        return dict(
            filenames=filenames,
            audio_clips=audio_clips,
            audio_embed_mask=audio_embed_mask,
            input_ids=torch.cat(padded_ids, dim=0),
            attention_mask=torch.cat(padded_masks, dim=0).bool(),
        )
89
+ class AudioTextData(torch.utils.data.Dataset):
90
+ def __init__(
91
+ self,
92
+ dataset_file_root: str,
93
+ data_root: str,
94
+ clap_config: dict,
95
+ dataset_blending_global_weight: float,
96
+ dataset_blending_config: dict,
97
+ dataset_blending_output: str,
98
+ tokenizer,
99
+ max_tokens: int,
100
+ split: str = 'train',
101
+ epoch: int = 0,
102
+ force_reblend: bool = False,
103
+ **kwargs
104
+ ):
105
+ self.dataset_file_root = dataset_file_root
106
+ self.data_root = data_root
107
+ self.clap_config = clap_config
108
+ self.dataset_blending_global_weight = dataset_blending_global_weight
109
+ self.dataset_blending_config = dataset_blending_config
110
+
111
+ self.split = split
112
+ self.epoch = epoch
113
+ self.force_reblend = force_reblend
114
+
115
+ assert self.split == 'train'
116
+ self.data = self.blend_dataset(dataset_blending_config, dataset_blending_output)
117
+
118
+ self.tokenizer = tokenizer
119
+ self.tokenizer.padding_side = "right"
120
+ self.max_tokens = max_tokens
121
+
122
+ @staticmethod
123
+ def shuffle_dict_fixed_rand(dic, seed=0):
124
+ print('randomly shuffling key-value pairs')
125
+
126
+ local_random = np.random.default_rng(seed)
127
+ original_keys = list(dic.keys())
128
+ shuffled_keys = deepcopy(original_keys)
129
+ local_random.shuffle(shuffled_keys)
130
+ shuffling_mapping = {x: y for (x, y) in zip(original_keys, shuffled_keys)}
131
+
132
+ shuffled_dic = {}
133
+ for idx in original_keys:
134
+ shuffled_idx = shuffling_mapping[idx]
135
+ shuffled_dic[idx] = dic[shuffled_idx]
136
+ return shuffled_dic
137
+
138
+ @staticmethod
139
+ def is_broken_file(audiopath):
140
+ # write your broken file paths here
141
+ BROKEN_FILES = []
142
+ return audiopath in BROKEN_FILES
143
+
144
+ def _read_dataset_file(self, dataset_file):
145
+ print("reading", dataset_file)
146
+ with open(dataset_file) as f:
147
+ contents = f.read()
148
+ contents = json.loads(contents)
149
+
150
+ assert contents["dataset_path"].startswith(self.data_root)
151
+ rel_path = contents["dataset_path"][len(self.data_root):]
152
+ if rel_path.startswith('/'):
153
+ rel_path = rel_path[1:]
154
+ if contents['split_path'] is not None:
155
+ rel_path = os.path.join(rel_path, contents['split_path'])
156
+
157
+ """
158
+ contents["data"] = {
159
+ "0": {'name': name (xxx.wav), 'dialogue': [
160
+ {"user": question 1, "assistant": answer 1},
161
+ ...
162
+ {"user": question k, "assistant": answer k}
163
+ ]
164
+ },
165
+ "1": {'name': name (xxx.wav), 'dialogue': [
166
+ {"user": question 1, "assistant": answer 1},
167
+ ...
168
+ {"user": question k, "assistant": answer k}
169
+ ]
170
+ },
171
+ ...
172
+ "total_num-1": {'name': name (xxx.wav), 'dialogue': [
173
+ {"user": question 1, "assistant": answer 1},
174
+ ...
175
+ {"user": question k, "assistant": answer k}
176
+ ]
177
+ }
178
+ }
179
+ """
180
+
181
+ for idx in contents["data"]:
182
+ contents["data"][idx]['task'] = contents["flamingo_task"]
183
+ contents["data"][idx]['name'] = os.path.join(
184
+ rel_path, contents["data"][idx]['name']
185
+ )
186
+ return contents
187
+
188
    def blend_dataset(self, dataset_blending_config, dataset_blending_output):
        """Build (or load) the blended sample index across all configured datasets.

        If *dataset_blending_output* already exists and reblending is not
        forced, it is loaded directly. Otherwise each dataset manifest is read,
        deterministically shuffled, sampled over an epoch-dependent window
        proportional to its blending weight, and the result is written back to
        disk (with a '-reblended.json' suffix when force_reblend is set).
        """
        if os.path.exists(dataset_blending_output) and not self.force_reblend:
            print("loading blended dataset file from:", dataset_blending_output)
            with open(dataset_blending_output) as f:
                contents = f.read()
                self_data = json.loads(contents)

        else:
            if not self.force_reblend:
                print("no blended dataset file found; reading all dataset files")
            else:
                print("force reblending dataset at epoch {}; reading all dataset files".format(self.epoch))

            all_data = {}
            for dataset_name in dataset_blending_config:
                dataset_file = os.path.join(self.dataset_file_root, '{}.json'.format(dataset_name))
                contents = self._read_dataset_file(dataset_file)
                # Deterministic shuffle, seeded by the dataset name.
                contents['data'] = self.shuffle_dict_fixed_rand(
                    contents['data'],
                    seed=sum(list(map(ord, dataset_name)))
                )

                # Effective sampling weight = global weight * per-dataset weight.
                weight_global = float(self.dataset_blending_global_weight)
                weight_dataset = float(dataset_blending_config[dataset_name]["weight"])
                weight = weight_global * weight_dataset

                all_data[dataset_name] = {
                    "contents": contents,
                    "weight": weight
                }

            self_data = {
                "dataset_path": self.data_root,
                "split_path": None,
                "total_num": 0,
                "data": {}
            }

            for dataset_name in all_data:
                print('blending {}'.format(dataset_name))

                contents = all_data[dataset_name]["contents"]
                shuffled_contents_data = contents['data']
                weight = all_data[dataset_name]["weight"]
                assert type(weight) == float and weight > 0.0

                # Sample an epoch-dependent window of weight * total_num items,
                # wrapping around (with a fresh shuffle) at each dataset epoch.
                dataset_total_num = contents['total_num']
                start_idx = int(self.epoch * dataset_total_num * weight)
                end_idx = int((self.epoch + 1) * dataset_total_num * weight)

                for idx in range(start_idx, end_idx):
                    if idx > 0 and idx % dataset_total_num == 0:
                        print('force shuffling at new epoch {} for dataset {}'.format(idx // dataset_total_num, dataset_name))
                        shuffled_contents_data = self.shuffle_dict_fixed_rand(
                            contents['data'],
                            seed=sum(list(map(ord, '{}-epoch-{}'.format(dataset_name, idx // dataset_total_num))))
                        )

                    key = str(idx % dataset_total_num)
                    item = shuffled_contents_data[key]

                    # Skip files known to be unreadable.
                    found_broken = False
                    assert type(item['name']) is str
                    audiopath = os.path.join(self.data_root, item['name'])
                    if self.is_broken_file(audiopath):
                        print('cannot read {}'.format(audiopath))
                        found_broken = True

                    if found_broken:
                        continue

                    self_data['data'][self_data['total_num']] = item
                    self_data['total_num'] += 1

            if not self.force_reblend:
                print('writing blended dataset file to:', dataset_blending_output)
                with open(dataset_blending_output, 'w') as json_file:
                    json.dump(self_data, json_file)
            else:
                print('writing reblended dataset file to:', dataset_blending_output.replace('.json', '-reblended.json'))
                with open(dataset_blending_output.replace('.json', '-reblended.json'), 'w') as json_file:
                    json.dump(self_data, json_file)

        return self_data
+ def get_num_windows(self, T, sr):
274
+ clap_config = self.clap_config
275
+ window_length = int(float(clap_config["window_length"]) * sr)
276
+ window_overlap = int(float(clap_config["window_overlap"]) * sr)
277
+ max_num_window = int(clap_config["max_num_window"])
278
+
279
+ num_windows = 1
280
+ if T <= window_length:
281
+ num_windows = 1
282
+ full_length = window_length
283
+ elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap):
284
+ num_windows = max_num_window
285
+ full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap)
286
+ else:
287
+ num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
288
+ full_length = num_windows * window_length - (num_windows - 1) * window_overlap
289
+
290
+ return num_windows, full_length
291
+
292
    def load_audio(self, file_path, target_sr=44100, duration=30.0, start=0.0):
        """Load up to `duration` seconds of mono audio starting at `start` seconds.

        mp3 files go through pydub (sample-accurate seeking is unreliable for mp3);
        everything else goes through soundfile. The result is a 1-D float numpy
        array at `target_sr`, normalized into [-1, 1].
        """
        if file_path.endswith('.mp3'):
            audio = AudioSegment.from_file(file_path)
            # pydub slices are in milliseconds.
            if len(audio) > (start + duration) * 1000:
                audio = audio[start * 1000:(start + duration) * 1000]

            if audio.frame_rate != target_sr:
                audio = audio.set_frame_rate(target_sr)

            # Downmix to mono.
            if audio.channels > 1:
                audio = audio.set_channels(1)

            data = np.array(audio.get_array_of_samples())
            # Scale integer PCM to floats in [-1, 1] based on bit depth.
            if audio.sample_width == 2:
                data = data.astype(np.float32) / np.iinfo(np.int16).max
            elif audio.sample_width == 4:
                data = data.astype(np.float32) / np.iinfo(np.int32).max
            else:
                raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))

        else:
            with sf.SoundFile(file_path) as audio:
                original_sr = audio.samplerate
                channels = audio.channels

                # NOTE(review): frames_to_read counts from the seek position, so
                # using (start + duration) * sr reads `start` seconds extra at most
                # when the file is long enough — confirm intended.
                max_frames = int((start + duration) * original_sr)

                audio.seek(int(start * original_sr))
                frames_to_read = min(max_frames, len(audio))
                data = audio.read(frames_to_read)

                # Clamp out-of-range samples by peak normalization.
                if data.max() > 1 or data.min() < -1:
                    data = data / max(abs(data.max()), abs(data.min()))

            if original_sr != target_sr:
                if channels == 1:
                    data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
                else:
                    # Multichannel: resample then keep only the first channel.
                    data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
            else:
                if channels != 1:
                    data = data.T[0]

        # Final normalization into [-1, 1]; non-negative signals are re-centered.
        if data.min() >= 0:
            data = 2 * data / abs(data.max()) - 1.0
        else:
            data = data / max(abs(data.max()), abs(data.min()))

        assert len(data.shape) == 1, data.shape
        return data
342
+
343
    def compute_sliding_window(self, audio_file, audio_start=0.0):
        """Cut one audio file into overlapping CLAP windows.

        Returns:
            audio_clips: (max_num_window, window_length * sr) float tensor; unused
                trailing windows are zero-filled.
            audio_embed_mask: (max_num_window,) tensor with 1 for real windows.
        """
        if type(audio_start) == str:
            audio_start = float(audio_start)

        clap_config = self.clap_config

        # Sample rate is fixed by the CLAP variant's training setup.
        if clap_config["method"] == 'laion-clap':
            sr = 48000
        elif clap_config["method"] == 'microsoft-clap':
            sr = 44100
        else:
            raise NotImplementedError

        window_length = int(float(clap_config["window_length"]) * sr)
        window_overlap = int(float(clap_config["window_overlap"]) * sr)
        max_num_window = int(clap_config["max_num_window"])
        # Longest duration (seconds) the max window budget can cover.
        duration = max_num_window * (clap_config["window_length"] - clap_config["window_overlap"]) + clap_config["window_overlap"]

        audio_data = self.load_audio(os.path.join(self.data_root, audio_file), sr, duration, audio_start)
        T = len(audio_data)
        num_windows, full_length = self.get_num_windows(T, sr)

        # Zero-pad so the windows tile the signal exactly.
        if full_length > T:
            audio_data = np.append(audio_data, np.zeros(full_length - T))
        audio_data = audio_data.reshape(1, -1)
        # Round-trip through int16 to match CLAP's expected quantization.
        audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()

        audio_clips = []
        audio_embed_mask = torch.zeros(max_num_window)
        for i in range(num_windows):
            start = i * (window_length - window_overlap)
            audio_clips.append(audio_data_tensor[:, start:start+window_length])
            audio_embed_mask[i] = 1

        assert sum(audio_embed_mask) == num_windows

        # Pad with silent windows up to the fixed budget.
        if num_windows < max_num_window:
            for _ in range(max_num_window - num_windows):
                audio_clips.append(torch.zeros_like(audio_clips[-1]))

        audio_clips = torch.cat(audio_clips)  # (max_num_window, window_length * sr) cuda tensor

        return audio_clips, audio_embed_mask
386
+
387
+ def preprocess_string_for_eval(self, x):
388
+ x = x.rstrip().lstrip()
389
+ x = x.lower()
390
+ return x
391
+
392
    def __getitem__(self, i):
        """Build one training sample: fixed-size audio windows plus a tokenized dialogue.

        Returns (name, audio_clips, audio_embed_mask, input_ids, attention_mask).
        """
        # Manifest keys are strings after JSON round-trip, ints before.
        try:
            item = self.data['data'][str(i)]
        except:
            item = self.data['data'][i]

        assert type(item['name']) is str
        audio_files = [os.path.join(self.data_root, item['name'])]
        audio_starts = [0 if 'audio_start' not in item else float(item['audio_start'])]

        audio_clips, audio_embed_mask = [], []
        for audio_file, audio_start in zip(audio_files, audio_starts):
            this_audio_clips, this_audio_embed_mask = self.compute_sliding_window(audio_file, audio_start)
            audio_clips.append(this_audio_clips)
            audio_embed_mask.append(this_audio_embed_mask)

        audio_clips = torch.cat(audio_clips)
        audio_embed_mask = torch.cat(audio_embed_mask)

        # Pad the window dimension to the fixed budget expected by the model
        # (max_num_window windows per audio times max_num_fewshot audios).
        correct_num_windows = int(self.clap_config["max_num_window"]) * int(self.clap_config["max_num_fewshot"])
        if len(audio_clips) < correct_num_windows:
            audio_clips = torch.cat([
                audio_clips,
                torch.zeros(correct_num_windows - len(audio_clips), audio_clips.shape[1])
            ])
            audio_embed_mask = torch.cat([
                audio_embed_mask,
                torch.zeros(correct_num_windows - len(audio_embed_mask))
            ])

        audio_clips.requires_grad = False
        audio_embed_mask.requires_grad = False

        # Flatten the multi-turn dialogue into the Flamingo chat prompt format:
        # <bos><prefix><audio> then "user: ... assistant: <sep>...<|endofchunk|><eos>" per turn.
        assert 'dialogue' in item
        dialogue = item['dialogue']
        prefix = 'The task is dialog. '
        sample = f"{self.tokenizer.bos_token}{prefix}<audio>"
        for each_round in dialogue:
            user_content, assistant_content = each_round['user'], each_round['assistant']
            sample = sample + f"user: {user_content} \nassistant: {self.tokenizer.sep_token}{assistant_content}<|endofchunk|>{self.tokenizer.eos_token}\n"

        text = self.tokenizer(
            sample,
            max_length=self.max_tokens,
            padding="longest",
            truncation="only_first",
            return_tensors="pt"
        )

        return (item['name'], audio_clips, audio_embed_mask, text["input_ids"], text["attention_mask"])
442
+
443
+ def __len__(self):
444
+ return len(list(self.data['data'].keys()))
445
+
446
+
447
@dataclass
class DataInfo:
    """Bundle of a dataset, its dataloader, and (optionally) its distributed sampler."""
    dataset: Dataset
    dataloader: DataLoader
    sampler: DistributedSampler = None

    def set_epoch(self, epoch):
        # Forward the epoch to a DistributedSampler so per-epoch shuffling differs;
        # isinstance(None, ...) is False, so a missing sampler is a no-op.
        if isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)
456
+
457
+
458
def get_audiotext_dataloader(data_config, clap_config, text_tokenizer, batch_size, split='train', epoch=0, force_reblend=False):
    """Construct the training DataLoader over the blended audio-text dataset.

    Only the 'train' split is supported. Shuffling is delegated to the
    DistributedSampler (the DataLoader's own shuffle stays False, as required
    when a sampler is supplied). Returns a DataInfo bundle.
    """
    assert split == 'train'

    data_collator = DataCollator(text_tokenizer)
    dataloader_shuffle = False

    trainset = AudioTextData(
        **data_config,
        clap_config=clap_config,
        tokenizer=text_tokenizer,
        split=split,
        epoch=epoch,
        force_reblend=force_reblend
    )
    sampler = DistributedSampler(trainset, shuffle=True)
    trainloader = DataLoader(
        trainset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=dataloader_shuffle,
        collate_fn=data_collator,
        num_workers=data_config["num_workers"]
    )
    return DataInfo(dataset=trainset, dataloader=trainloader, sampler=sampler)
models/audio-flamingo-1/chat/data/prepare_each_dataset.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import os
5
+ import json
6
+ import csv
7
+ import yaml
8
+ from collections import defaultdict
9
+ import pickle
10
+ import glob
11
+ import math
12
+ from functools import partial
13
+ import sys
14
+ import io
15
+ import warnings
16
+ import random
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ import librosa
22
+ from pydub import AudioSegment
23
+ import soundfile as sf
24
+
25
+ import faiss
26
+
27
+ import multiprocessing
28
+ multiprocessing.set_start_method('spawn', force=True)
29
+
30
try:
    from tqdm import tqdm
except ImportError:
    # tqdm is optional: fall back to a no-op progress wrapper.
    # (Was a bare `except:`, which would also have hidden unrelated errors.)
    def tqdm(iterable):
        return iterable
34
+
35
+
36
def filter_file(file_path, file_list, filename):
    """Return True when `filename` should be skipped.

    A file is skipped when it is absent (from `file_list` if given, otherwise
    from disk under `file_path`) or smaller than 16000 bytes (roughly under
    0.5-1 second of audio).
    """
    full_path = os.path.join(file_path, filename)

    if file_list is not None:
        if filename not in file_list:
            print(filename, 'not exist')
            return True
    elif not os.path.exists(full_path):
        print(filename, 'not exist')
        return True

    if os.path.getsize(full_path) < 16000:
        print(filename, 'less than 0.5 to 1 second')
        return True

    return False
51
+
52
+
53
def filter_response(response):
    """Return True when a generated assistant response contains any blacklisted phrase.

    Phrase lists follow LLARK and LTU data-cleaning heuristics plus a few of our
    own; matching is case-insensitive substring containment. Several entries
    contain fused words ("theprovided") because the source captions were
    extracted from PDFs.
    """
    filter_phrases_LLARK = [
        'metadata', 'is not provided', 'based on theprovided metadata',
        'based on the providedbeat', 'based on the provided chord',
        'basedon the provided information', 'based on theprovided annotations',
        'no specific mood,there is no mention of',
        'there is no specificmention of any', 'as an ai assistant',
        'iam unable to', 'as an ai assistant', 'i donot',
        'it is difficult to determine', 'it isnot possible to determine',
        'no informationis available about the album', 'cannotdetermine',
        'violin 1', 'violin 2', 'violin 3,viola 1', 'viola 2', 'viola 3', 'pack'
    ]

    filter_phrases_LTU = [
        'cannot determine', 'not provided', 'cannot be determined', 'sorry', 'i cannot',
        'without more information', 'enough information',
        'not possible', 'more context', 'enough', 'impossible', 'cannot be determined',
        'without additional information',
        'unclear', 'cannot', 'not clear', 'do not provide sufficient', 'does not provide',
        'difficult to determine', 'no information provided',
        "can't infer", "difficult to infer", "not specified", "no specific", "no information",
        "without additional", 'it is difficult to',
        "no indication"
    ]

    filter_phrases_ours = ["doesn't provide", "doesn't specify", "doesn't indicate", "based on"]

    lowered = response.lower()
    return any(
        phrase in lowered
        for phrase in filter_phrases_LLARK + filter_phrases_LTU + filter_phrases_ours
    )
84
+
85
+
86
+ # !!!Important!!! please write your own code to create dataset manifests based on your stored datasets
87
+ # The list of dataset_name and flamingo_task can be found in configs/*.yaml --> data_config --> dataset_blending_config
88
def prepare_files(dataset_name, dataset_path, split, flamingo_task, output_file):
    """Build a dialogue-dataset manifest JSON for one dataset.

    The output manifest has the shape::

        {
            "dataset_path": YOUR_DATA_ROOT_DIR/datasets/dataset_name/,
            "split": "train" or "test",
            "split_path": "./",
            "flamingo_task": "<dataset_name>-Dialog",
            "total_num": <number of samples>,
            "data": {
                "0": {"name": "xxx.wav",
                      "dialogue": [{"user": ..., "assistant": ...}, ...]},
                ...
            }
        }

    os.path.join(dataset_path, split_path, name) is the absolute path to each
    audio file. Audio is not restricted to wav, but mp3 is not recommended due
    to a different seeking mechanism.

    Raises AssertionError when `output_file` already exists or the expected
    dataset directory/split is missing.
    """
    assert not os.path.exists(output_file)
    dataset_dic = {
        "dataset_path": dataset_path,
        "split": split,
        "split_path": None,
        "flamingo_task": "{}-{}".format(dataset_name, flamingo_task),
        "total_num": 0,
        "data": {}
    }

    # The two supported datasets differ only in their source JSON filename;
    # the previous implementation duplicated the whole branch body verbatim.
    source_json = {
        'dialog_AudioSetSL': 'dialogues_audioset_thresholded.json',
        'dialog_MusicCaps': 'dialogues_musiccaps_thresholded.json',
    }

    if dataset_name in source_json:
        assert flamingo_task == "Dialog"
        assert split == 'train'
        split_path = './'
        file_path = os.path.join(dataset_path, split_path)
        assert os.path.exists(file_path), '{} not exist'.format(file_path)

        dataset_dic["split_path"] = split_path
        file_list = None

        with open(os.path.join(dataset_path, source_json[dataset_name])) as f:
            data_list = json.load(f)

        for data in tqdm(data_list):
            filename = data["audio_id"]
            # Skip missing/too-short audio files.
            if filter_file(file_path, file_list, filename):
                continue

            dialogue = data['dialogue']

            # Discard the whole dialogue if any assistant turn hits the blacklist.
            if any(filter_response(each_round['assistant']) for each_round in dialogue):
                continue

            dataset_dic["data"][dataset_dic["total_num"]] = {
                "name": filename,
                "dialogue": dialogue
            }
            dataset_dic["total_num"] += 1

    with open(output_file, 'w') as json_file:
        json.dump(dataset_dic, json_file)
218
+
219
+
220
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dataset_name', type=str, help='dataset name')
    parser.add_argument('-f', '--flamingo_task', type=str, default='Dialog', help='flamingo task')
    args = parser.parse_args()

    # NOTE: the original `global DATA_ROOT_DIR` at module level was a no-op.
    DATA_ROOT_DIR = "YOUR_DATA_ROOT_DIR"
    dataset_root = os.path.join(DATA_ROOT_DIR, "datasets")
    output_root = os.path.join(DATA_ROOT_DIR, "audio-flamingo-data/dataset_files")
    os.makedirs(output_root, exist_ok=True)

    dataset_name = args.dataset_name  # dialog_AudioSetSL, dialog_MusicCaps
    flamingo_task = args.flamingo_task  # Dialog

    split = 'train'
    dataset_path = os.path.join(dataset_root, dataset_name)

    output_folder = '{}-{}'.format(dataset_name, flamingo_task)
    os.makedirs(os.path.join(output_root, output_folder), exist_ok=True)

    dataset_file = os.path.join(output_root, output_folder, '{}.json'.format(split))
    if not os.path.exists(dataset_file):
        try:
            prepare_files(dataset_name, dataset_path, split, flamingo_task, dataset_file)
        except AssertionError as e:
            # Bug fix: the original had `continue` here, outside any loop,
            # which is a SyntaxError. Logging the failure is sufficient.
            print('split {} not exist for {}: {}'.format(split, dataset_name, e))
    else:
        print('{} exists; exiting'.format(dataset_file))
252
+
253
+
models/audio-flamingo-1/chat/src/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
models/audio-flamingo-1/chat/src/factory.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import sys
8
+ sys.path.append('../')
9
+
10
+ from typing import Optional
11
+ from copy import deepcopy
12
+
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+ from my_laion_clap.CLAP.src.laion_clap.hook import CLAP_Module
15
+ from my_ms_clap.src.CLAPWrapper import CLAPWrapper
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ try:
21
+ from .flamingo import Flamingo
22
+ from .flamingo_lm import FlamingoLMMixin
23
+ from .utils import extend_instance
24
+ except:
25
+ from flamingo import Flamingo
26
+ from flamingo_lm import FlamingoLMMixin
27
+ from utils import extend_instance
28
+
29
+
30
+ class CLAP(nn.Module):
31
+ def __init__(self, clap_config):
32
+ super(CLAP, self).__init__()
33
+ self.method = clap_config["method"]
34
+ device_id = f'cuda:{torch.cuda.current_device()}'
35
+
36
+ if ('finetune' in clap_config) and clap_config['finetune']:
37
+ self.finetune = True
38
+ print('Finetuning CLAP encoder as well!')
39
+ else:
40
+ self.finetune = False
41
+
42
+ if self.method == 'laion-clap':
43
+ # https://github.com/LAION-AI/CLAP
44
+ if clap_config["model_name"] in ['630k-audioset-best', '630k-best', '630k-audioset-fusion-best', '630k-fusion-best']:
45
+ amodel = 'HTSAT-tiny'
46
+ elif clap_config["model_name"] in ['music_speech_audioset_epoch_15_esc_89.98']:
47
+ amodel = 'HTSAT-base'
48
+ else:
49
+ raise NotImplementedError
50
+
51
+ enable_fusion = 'fusion' in clap_config["model_name"].lower()
52
+ self.laion_clap = CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=device_id)
53
+ self.laion_clap.load_ckpt(ckpt=clap_config["checkpoint"])
54
+
55
+
56
+ for param in self.laion_clap.parameters():
57
+ param.requires_grad = self.finetune
58
+
59
+ if self.finetune:
60
+ self.laion_clap.train()
61
+ else:
62
+ self.laion_clap.eval()
63
+
64
+ print('loaded laion-clap model: {}'.format(clap_config["checkpoint"]))
65
+
66
+ elif self.method == 'microsoft-clap':
67
+ # https://github.com/microsoft/CLAP
68
+ self.ms_clap = CLAPWrapper(
69
+ clap_config["checkpoint"],
70
+ config_root=clap_config["config_root"],
71
+ version=clap_config['model_name'],
72
+ use_cuda=True
73
+ )
74
+
75
+ if clap_config['model_name'] in ['2022', '2023']:
76
+ for param in self.ms_clap.clap.parameters():
77
+ param.requires_grad = self.finetune
78
+ if self.finetune:
79
+ self.ms_clap.clap.train()
80
+ else:
81
+ self.ms_clap.clap.eval()
82
+ else:
83
+ for param in self.ms_clap.clapcap.parameters():
84
+ param.requires_grad = self.finetune
85
+ if self.finetune:
86
+ self.ms_clap.clapcap.train()
87
+ else:
88
+ self.ms_clap.clapcap.eval()
89
+
90
+ print('loaded microsoft-clap model: {}'.format(clap_config["checkpoint"]))
91
+
92
+ else:
93
+ raise NotImplementedError
94
+
95
+ def forward(self, audio_clips):
96
+
97
+ if len(audio_clips.shape) == 2:
98
+ audio_clips = audio_clips.unsqueeze(0)
99
+ assert len(audio_clips.shape) == 3
100
+
101
+ audio_embeds = []
102
+ for x in audio_clips:
103
+ if self.method == 'laion-clap':
104
+ audio_embed = self.laion_clap.get_audio_embedding_from_data(x=x, use_tensor=True)
105
+ elif self.method == 'microsoft-clap':
106
+ audio_embed = self.ms_clap.get_audio_embeddings_from_clips(x)
107
+
108
+ audio_embeds.append(audio_embed)
109
+
110
+ audio_embeds = torch.stack(audio_embeds, dim=0)
111
+ audio_embeds.requires_grad = self.finetune
112
+
113
+ return audio_embeds
114
+
115
+
116
+ def create_model_and_transforms(
117
+ clap_config: dict,
118
+ lang_encoder_path: str,
119
+ tokenizer_path: str,
120
+ audio_transformer_kwargs: dict,
121
+ cross_attn_every_n_layers: int = 1,
122
+ use_local_files: bool = False,
123
+ decoder_layers_attr_name: str = None,
124
+ freeze_lm_embeddings: bool = False,
125
+ unfreeze_full_lm: bool = False,
126
+ cache_dir: Optional[str] = None,
127
+ **flamingo_kwargs,
128
+ ):
129
+ clap = CLAP(clap_config)
130
+
131
+ text_tokenizer = AutoTokenizer.from_pretrained(
132
+ tokenizer_path,
133
+ local_files_only=use_local_files,
134
+ trust_remote_code=True,
135
+ cache_dir=cache_dir,
136
+ )
137
+ text_tokenizer.add_special_tokens(
138
+ {"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
139
+ )
140
+ if text_tokenizer.pad_token is None:
141
+ text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
142
+ if text_tokenizer.sep_token is None:
143
+ text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})
144
+
145
+ lang_encoder = AutoModelForCausalLM.from_pretrained(
146
+ lang_encoder_path,
147
+ local_files_only=use_local_files,
148
+ trust_remote_code=True,
149
+ cache_dir=cache_dir,
150
+ )
151
+
152
+ extend_instance(lang_encoder, FlamingoLMMixin)
153
+
154
+ if decoder_layers_attr_name is None:
155
+ decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
156
+ lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
157
+ lang_encoder.resize_token_embeddings(len(text_tokenizer))
158
+
159
+ if ('finetune' in clap_config) and clap_config['finetune']:
160
+ unfreeze_clap = True
161
+ else:
162
+ unfreeze_clap = False
163
+
164
+ model = Flamingo(
165
+ clap,
166
+ unfreeze_clap,
167
+ lang_encoder,
168
+ text_tokenizer.encode("<|endofchunk|>")[-1],
169
+ text_tokenizer.encode("<audio>")[-1],
170
+ text_tokenizer.sep_token_id,
171
+ audio_embed_dim=clap_config["audio_embed_dim"],
172
+ audio_transformer_kwargs=audio_transformer_kwargs,
173
+ cross_attn_every_n_layers=cross_attn_every_n_layers,
174
+ **flamingo_kwargs,
175
+ )
176
+
177
+ model.requires_grad_(False)
178
+ assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
179
+
180
+ model.audio_transformer.requires_grad_(True)
181
+ model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
182
+ if not freeze_lm_embeddings:
183
+ model.lang_encoder.get_input_embeddings().requires_grad_(True)
184
+
185
+ if unfreeze_full_lm:
186
+ model.lang_encoder.requires_grad_(True)
187
+
188
+ if unfreeze_clap:
189
+ model.clap.requires_grad_(True)
190
+
191
+ print("Flamingo model initialized with {:,} trainable parameters (audio transformer has {:,}, LM has {:,})".format(
192
+ sum(p.numel() for p in model.parameters() if p.requires_grad),
193
+ sum(p.numel() for p in model.audio_transformer.parameters() if p.requires_grad),
194
+ sum(p.numel() for p in model.lang_encoder.parameters() if p.requires_grad)
195
+ ))
196
+
197
+ return model, text_tokenizer
198
+
199
+
200
+ def _infer_decoder_layers_attr_name(model):
201
+ for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
202
+ if k.lower() in model.__class__.__name__.lower():
203
+ return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
204
+
205
+ raise ValueError(
206
+ f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
207
+ )
208
+
209
+
210
+ __KNOWN_DECODER_LAYERS_ATTR_NAMES = {
211
+ "opt": "model.decoder.layers",
212
+ "gptj": "transformer.h",
213
+ "gpt-j": "transformer.h",
214
+ "pythia": "gpt_neox.layers",
215
+ "llama": "model.layers",
216
+ "gptneoxforcausallm": "gpt_neox.layers",
217
+ "mpt": "transformer.blocks",
218
+ "mosaicgpt": "transformer.blocks",
219
+ }
models/audio-flamingo-1/chat/src/flamingo.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from torch import nn
10
+
11
+ from torch.distributed.fsdp.wrap import (
12
+ enable_wrap,
13
+ wrap,
14
+ )
15
+ from transformers.modeling_outputs import CausalLMOutputWithPast
16
+ from torch.distributed.fsdp import (
17
+ FullyShardedDataParallel as FSDP,
18
+ )
19
+
20
+ try:
21
+ from .helpers import TransformerEncoder
22
+ from .utils import apply_with_stopping_condition
23
+ except:
24
+ from helpers import TransformerEncoder
25
+ from utils import apply_with_stopping_condition
26
+
27
+
28
class Flamingo(nn.Module):
    """Audio Flamingo: a CLAP audio encoder bridged into a causal LM.

    Audio windows are embedded by CLAP, contextualized by a small transformer,
    and injected into the LM through gated cross-attention layers placed every
    `cross_attn_every_n_layers` decoder blocks.
    """

    def __init__(
        self,
        clap: nn.Module,
        unfreeze_clap: bool,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        sep_token_id: int,
        audio_embed_dim: int,
        audio_transformer_kwargs: dict,
        cross_attn_every_n_layers: int = 1,
        gradient_checkpointing: bool = False,
    ):
        super().__init__()

        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.sep_token_id = sep_token_id
        self.audio_embed_dim = audio_embed_dim
        self.clap = clap # .to(torch.cuda.current_device())
        self.unfreeze_clap = unfreeze_clap
        self.clap.requires_grad_(unfreeze_clap)

        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size

        n_head = audio_transformer_kwargs["n_head"]
        n_layers = audio_transformer_kwargs["n_layers"]
        d_inner = audio_transformer_kwargs["d_inner"]
        max_num_media = audio_transformer_kwargs["max_num_media"]
        max_window_per_audio = audio_transformer_kwargs["max_window_per_audio"]
        assert audio_embed_dim % n_head == 0

        # Small transformer that contextualizes the per-window CLAP embeddings.
        self.audio_transformer = TransformerEncoder(
            d_word_vec=audio_embed_dim,
            n_layers=n_layers,
            n_head=n_head,
            d_k=audio_embed_dim // n_head,
            d_v=audio_embed_dim // n_head,
            d_model=audio_embed_dim,
            d_inner=d_inner,
            dropout=0.0,
            n_position=max_num_media,
            scale_emb=True
        )

        self.lang_encoder = lang_encoder
        self.lang_encoder.init_flamingo(
            media_token_id=media_token_id,
            lang_hidden_size=self.lang_dim,
            audio_hidden_size=self.audio_embed_dim,
            max_window_per_audio=max_window_per_audio,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            gradient_checkpointing=gradient_checkpointing,
        )

        self._use_gradient_checkpointing = gradient_checkpointing
        self.audio_transformer._use_gradient_checkpointing = gradient_checkpointing
        self.clap._use_gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        audio_x: torch.Tensor,
        audio_x_mask: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
    ):
        """Condition the LM on audio (unless pre-cached) and run it on `lang_x`.

        Returns the LM's CausalLMOutputWithPast. When `clear_conditioned_layers`
        is True the cached audio conditioning is dropped afterwards.
        """
        assert (
            self.lang_encoder.initialized_flamingo
        ), "Flamingo layers are not initialized. Please call `init_flamingo` first."

        assert (
            self.lang_encoder._use_cached_audio_x or audio_x is not None
        ), "Must provide either audio_x or have precached media using cache_media()."

        if self.lang_encoder._use_cached_audio_x:
            assert (
                audio_x is None
            ), "Expect audio_x to be None when media has been cached using cache_media(). Try uncache_media() first."
            assert self.lang_encoder.is_conditioned()

        else:
            self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
            self._condition_media_locations(input_ids=lang_x)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output

    def generate(
        self,
        audio_x: torch.Tensor,
        audio_x_mask: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Autoregressive generation conditioned on audio.

        Audio is encoded once and cached across decoding steps; the cache is
        cleared before returning. Defaults eos to the <|endofchunk|> token.
        """
        num_beams = kwargs.pop("num_beams", 1)
        if num_beams > 1:
            # Beam search duplicates each batch element num_beams times.
            audio_x = audio_x.repeat_interleave(num_beams, dim=0)

        self.lang_encoder._use_cached_audio_x = True
        self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)

        eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
        output = self.lang_encoder.generate(
            input_ids=lang_x,
            attention_mask=attention_mask,
            eos_token_id=eos_token_id,
            num_beams=num_beams,
            **kwargs,
        )

        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_audio_x = False
        return output

    def _encode_audio_x(self, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
        """
        rearrange code based on https://github.com/dhansmair/flamingo-mini

        Encode raw audio windows with CLAP (no_grad), contextualize them with
        the audio transformer, and attach the result + mask to every decoder
        layer via condition_audio_x.
        """

        assert audio_x.ndim == 3, "audio_x should be of shape (B, num_window, window_length)"

        # NOTE(review): CLAP runs under no_grad here even when finetuning was
        # requested in the config — confirm this is intended.
        with torch.no_grad():
            audio_embeds = self.clap(audio_x)
        B, L, D = audio_embeds.shape # L is number of windows, D is feature dim
        assert D == self.audio_embed_dim

        assert audio_x_mask.ndim == 2, "audio_x_mask should be of shape (B, L)"

        # Broadcast a single mask row across the batch if needed.
        if B > 1 and audio_x_mask.shape[0] == 1:
            audio_x_mask = audio_x_mask.repeat(B, 1)

        assert audio_x_mask.shape[0] == B and audio_x_mask.shape[1] == L, "{} != ({},{})".format(audio_x_mask.shape, B, L)

        audio_x_out = self.audio_transformer(audio_embeds) # B, L, D
        audio_x_out = audio_x_out.unsqueeze(2) # B, L, n=1, D
        audio_x_mask = audio_x_mask.unsqueeze(2) # B, L, n=1

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_audio_x(audio_x_out, audio_x_mask)

    def wrap_fsdp(self, wrapper_kwargs, device_id):
        """Wrap trainable submodules in FSDP and place the rest on `device_id`."""
        # unfreeze the decoder layers
        for block in self.lang_encoder.old_decoder_blocks:
            block.requires_grad_(True)

        # wrap in FSDP
        with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
            self.audio_transformer = wrap(wrap(self.audio_transformer))
            self.lang_encoder.old_decoder_blocks = nn.ModuleList(
                wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
            )
            self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
                wrap(wrap(layer)) if layer is not None else None
                for layer in self.lang_encoder.gated_cross_attn_layers
            )
            self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
            self.lang_encoder.set_input_embeddings(
                wrap(wrap(self.lang_encoder.get_input_embeddings()))
            )

            if hasattr(self.lang_encoder, 'set_output_embeddings'):
                self.lang_encoder.set_output_embeddings(
                    wrap(wrap(self.lang_encoder.get_output_embeddings()))
                )
            else:
                print('skip wrapping output embeddings')

        # manually move non-FSDP managed parameters to device_id
        # these are all in lang_encoder
        apply_with_stopping_condition(
            module=self.lang_encoder,
            apply_fn=lambda m: m.to(device_id),
            apply_condition=lambda m: len(list(m.children())) == 0,
            stopping_condition=lambda m: isinstance(m, FSDP),
        )

        # clap shouldn't be wrapped; should be on each gpu
        if self.unfreeze_clap:
            apply_with_stopping_condition(
                module=self.clap,
                apply_fn=lambda m: m.to(device_id),
                apply_condition=lambda m: len(list(m.children())) == 0,
                stopping_condition=lambda m: isinstance(m, FSDP),
            )

        # exclude the original decoder layers from the optimizer
        for block in self.lang_encoder.old_decoder_blocks:
            for p in block.parameters():
                p.exclude_from_optimizer = True

        # set up clip_grad_norm_ function
        def clip_grad_norm_(max_norm):
            self.audio_transformer.clip_grad_norm_(max_norm)
            for layer in self.lang_encoder.gated_cross_attn_layers:
                if layer is not None:
                    layer.clip_grad_norm_(max_norm)
            self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)

        self.clip_grad_norm_ = clip_grad_norm_

    def _condition_media_locations(self, input_ids: torch.Tensor):
        """Mark <audio> token positions in every decoder layer."""
        media_locations = (input_ids == self.media_token_id)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_media_locations(media_locations)

    def cache_media(self, input_ids: torch.Tensor, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
        """Pre-compute and cache audio conditioning for subsequent forward calls."""
        self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
        self._condition_media_locations(input_ids=input_ids)
        self.lang_encoder._use_cached_audio_x = True

    def uncache_media(self):
        """Drop any cached audio conditioning."""
        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_audio_x = False
models/audio-flamingo-1/chat/src/flamingo_lm.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import torch.nn as nn
8
+
9
+ try:
10
+ from .helpers import GatedCrossAttentionBlock
11
+ from .utils import getattr_recursive, setattr_recursive
12
+ except:
13
+ from helpers import GatedCrossAttentionBlock
14
+ from utils import getattr_recursive, setattr_recursive
15
+
16
+
17
class FlamingoLayer(nn.Module):
    """
    FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.

    The gated cross-attention block (None for layers without one) runs first,
    attending from text to the conditioned audio features, then the wrapped
    decoder layer runs as usual.
    """

    def __init__(
        self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False
    ):
        super().__init__()
        self.gated_cross_attn_layer = gated_cross_attn_layer
        self.decoder_layer = decoder_layer
        # Conditioning state, filled in by the condition_* setters below.
        self.audio_x = None
        self.audio_x_mask = None
        self.few_shot_mask = None
        self.media_locations = None
        # Fix: initialize explicitly so forward() cannot hit an AttributeError
        # when the cross-attn path runs before condition_use_cached_media()
        # was ever called on this layer.
        self.use_cached_media = False
        if self.gated_cross_attn_layer is not None:
            self.gated_cross_attn_layer._use_gradient_checkpointing = (
                gradient_checkpointing
            )
        self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return (self.audio_x is not None) and (self.audio_x_mask is not None) and (self.media_locations is not None)

    def condition_audio_x(self, audio_x, audio_x_mask):
        self.audio_x = audio_x
        self.audio_x_mask = audio_x_mask

    def condition_media_locations(self, media_locations):
        self.media_locations = media_locations

    def condition_use_cached_media(self, use_cached_media):
        self.use_cached_media = use_cached_media

    def forward(
        self,
        lang_x,
        attention_mask=None,
        **decoder_layer_kwargs,
    ):
        """Run gated cross-attention (if present) then the decoder layer.

        Raises:
            ValueError: if audio features or media locations were not
                conditioned before a cross-attention forward pass.
        """
        if self.gated_cross_attn_layer is not None:
            if self.audio_x is None:
                raise ValueError("audio_x must be conditioned before forward pass")

            if self.media_locations is None:
                raise ValueError(
                    "media_locations must be conditioned before forward pass"
                )

            lang_x = self.gated_cross_attn_layer(
                lang_x,
                self.audio_x,
                self.audio_x_mask,
                media_locations=self.media_locations,
                use_cached_media=self.use_cached_media,
            )

        # Normal decoder layer
        lang_x = self.decoder_layer(
            lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
        )
        return lang_x
80
+
81
+
82
class FlamingoLMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        # Dotted attribute path (e.g. "model.layers") to the decoder layer list.
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_flamingo(
        self,
        media_token_id,
        lang_hidden_size,
        audio_hidden_size,
        max_window_per_audio,
        cross_attn_every_n_layers,
        gradient_checkpointing,
    ):
        """
        Initialize Flamingo by adding a new gated cross attn to the decoder.
        Store the media token id for computing the media locations.
        """
        self.old_decoder_blocks = self._get_decoder_layers()
        # A gated cross-attn block every `cross_attn_every_n_layers` layers,
        # None for every other position.
        cross_attn_layers = []
        for layer_idx in range(len(self.old_decoder_blocks)):
            if (layer_idx + 1) % cross_attn_every_n_layers == 0:
                cross_attn_layers.append(
                    GatedCrossAttentionBlock(
                        dim=lang_hidden_size,
                        dim_audio=audio_hidden_size,
                        max_window_per_audio=max_window_per_audio,
                        only_attend_immediate_media=False,
                    )
                )
            else:
                cross_attn_layers.append(None)
        self.gated_cross_attn_layers = nn.ModuleList(cross_attn_layers)
        self.init_flamingo_layers(gradient_checkpointing)
        self.media_token_id = media_token_id
        self.initialized_flamingo = True
        self._use_cached_audio_x = False

    def init_flamingo_layers(self, gradient_checkpointing):
        """
        Re-initializes the FlamingoLayers.
        Propagates any changes made to self.gated_cross_attn_layers or
        self.old_decoder_blocks.
        """
        wrapped_layers = [
            FlamingoLayer(cross_attn, decoder, gradient_checkpointing)
            for cross_attn, decoder in zip(
                self.gated_cross_attn_layers, self.old_decoder_blocks
            )
        ]
        self._set_decoder_layers(nn.ModuleList(wrapped_layers))

    def forward(self, input_ids, attention_mask, **kwargs):
        """Condition the Flamingo layers on the media locations before forward()"""
        if not self.initialized_flamingo:
            raise ValueError(
                "Flamingo layers are not initialized. Please call `init_flamingo` first."
            )

        media_locations = input_ids == self.media_token_id

        # During cached generation the prompt's media tokens are no longer in
        # input_ids; fall back to the previously conditioned locations then.
        use_cached_media_locations = (
            self._use_cached_audio_x
            and self.is_conditioned()
            and not media_locations.any()
        )

        for layer in self._get_decoder_layers():
            if not use_cached_media_locations:
                layer.condition_media_locations(media_locations)
            layer.condition_use_cached_media(use_cached_media_locations)

        kwargs["input_ids"] = input_ids
        kwargs["attention_mask"] = attention_mask
        return super().forward(**kwargs)

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned() for layer in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        """Reset every layer's conditioning state."""
        for layer in self._get_decoder_layers():
            layer.condition_audio_x(None, None)
            layer.condition_media_locations(None)
            layer.condition_use_cached_media(None)
+ layer.condition_use_cached_media(None)
models/audio-flamingo-1/chat/src/helpers.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Adapted from https://github.com/lucidrains/flamingo-pytorch under the MIT license.
8
+ # LICENSE is in incl_licenses directory.
9
+
10
+ # Adapted from https://github.com/jadore801120/attention-is-all-you-need-pytorch under the MIT license.
11
+ # LICENSE is in incl_licenses directory.
12
+
13
+ from einops import rearrange, repeat
14
+ from einops_exts import rearrange_many
15
+
16
+ import numpy as np
17
+
18
+ import torch
19
+ from torch import einsum, nn
20
+ import torch.nn.functional as F
21
+
22
def exists(val):
    """True iff *val* is something other than None."""
    if val is None:
        return False
    return True
24
+
25
def FeedForward(dim, mult=4):
    """Pre-norm MLP: LayerNorm -> Linear(dim -> dim*mult) -> GELU -> Linear(-> dim)."""
    hidden_dim = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*layers)
33
+
34
+ # Transformer (encoder) https://github.com/jadore801120/attention-is-all-you-need-pytorch
35
+ # Original Copyright 2017 Victor Huang
36
+ # MIT License (https://opensource.org/licenses/MIT)
37
+
38
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(q @ k^T / temperature) @ v."""

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        # (b, h, lq, dk) x (b, h, dk, lk) -> (b, h, lq, lk)
        scores = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            # Positions where mask == 0 are excluded from attention.
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)
        return context, weights
57
+
58
+
59
class MultiHeadAttention(nn.Module):
    """Multi-head attention with residual connection and post-LayerNorm."""

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        # Query/key/value projections and the final head-merging projection.
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        n_head, d_k, d_v = self.n_head, self.d_k, self.d_v
        batch = q.size(0)
        len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)

        residual = q

        # Project, split heads, then move the head axis forward:
        # b x len x (n*d)  ->  b x n x len x d
        q = self.w_qs(q).view(batch, len_q, n_head, d_k).transpose(1, 2)
        k = self.w_ks(k).view(batch, len_k, n_head, d_k).transpose(1, 2)
        v = self.w_vs(v).view(batch, len_v, n_head, d_v).transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast over the head axis

        q, attn = self.attention(q, k, v, mask=mask)

        # Merge heads back: b x n x len x d -> b x len x (n*d)
        q = q.transpose(1, 2).contiguous().view(batch, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual

        q = self.layer_norm(q)
        return q, attn
110
+
111
+
112
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise FFN with residual and post-LayerNorm."""

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        # NOTE: module creation order (w_1, w_2, layer_norm, dropout) is kept
        # so seeded parameter initialization stays reproducible.
        self.w_1 = nn.Linear(d_in, d_hid)  # position-wise
        self.w_2 = nn.Linear(d_hid, d_in)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        hidden = F.relu(self.w_1(x))
        out = self.dropout(self.w_2(hidden))
        out = out + residual
        return self.layer_norm(out)
133
+
134
+
135
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added onto the input embeddings."""

    def __init__(self, d_hid, n_position=200):
        super().__init__()
        # Buffer (not a parameter): moves with the module, never trained.
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        # angle[pos, j] = pos / 10000^(2*(j//2)/d_hid), vectorized over both axes.
        positions = np.arange(n_position)[:, None]
        dims = np.arange(d_hid)[None, :]
        table = positions / np.power(10000, 2 * (dims // 2) / d_hid)
        table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
        table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
        return torch.FloatTensor(table).unsqueeze(0)

    def forward(self, x):
        # Add (and detach) the encodings for the first x.size(1) positions.
        return x + self.pos_table[:, :x.size(1)].clone().detach()
154
+
155
+
156
class EncoderLayer(nn.Module):
    """Transformer encoder layer: self-attention followed by a position-wise FFN."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.0):
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        attended, attn_weights = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask
        )
        return self.pos_ffn(attended), attn_weights
169
+
170
+
171
class TransformerEncoder(nn.Module):
    """An encoder model with self attention mechanism."""

    def __init__(
            self, d_word_vec=512, n_layers=6, n_head=8, d_k=64, d_v=64,
            d_model=512, d_inner=2048, dropout=0.0, n_position=16, scale_emb=True):
        super().__init__()
        # n_position <= 0 disables positional encoding entirely.
        if n_position > 0:
            self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        else:
            self.position_enc = lambda x: x
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList(
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        )
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.scale_emb = scale_emb
        self.d_model = d_model

    def forward(self, src_seq, return_attns=False):
        # Accept (B, D) inputs by treating them as length-1 sequences.
        if src_seq.dim() == 2:
            src_seq = src_seq.unsqueeze(1)

        attn_maps = []

        out = src_seq
        if self.scale_emb:
            out = out * self.d_model ** 0.5
        out = self.dropout(self.position_enc(out))
        out = self.layer_norm(out)

        # No causal masking: full bidirectional self-attention.
        for block in self.layer_stack:
            out, attn = block(out, slf_attn_mask=None)
            if return_attns:
                attn_maps.append(attn)

        if return_attns:
            return out, attn_maps
        return out
214
+
215
+
216
# gated cross attention
class MaskedCrossAttention(nn.Module):
    """Cross-attention from text tokens to windowed audio embeddings.

    Text queries attend to audio keys/values. Padded audio windows are
    masked out via ``media_mask``; ``media_locations`` additionally restricts
    which text span can see which audio windows.
    """

    def __init__(
        self,
        *,
        dim,
        dim_audio,
        max_window_per_audio,
        dim_head=64,
        heads=8,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        self.max_window_per_audio = max_window_per_audio
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        # Queries come from text (dim); keys/values from audio (dim_audio).
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim_audio, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(
        self,
        x,
        media, media_mask,
        media_locations=None,
        use_cached_media=False
    ):
        # x: (B, T_txt, dim) text hidden states.
        # media: (B, L, 1, dim_audio) audio window embeddings (extra dim of 1).
        # media_mask: (B, L, 1) validity mask for padded audio windows.
        # media_locations: (B, T_txt) bool mask of <audio> token positions.

        if not use_cached_media:
            assert (
                media_locations.shape[1] == x.shape[1]
            ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"

        T_txt = x.shape[1]
        B, L = media.shape[:2]
        assert media.shape[2] == 1  # extra dim
        assert L % self.max_window_per_audio == 0  # should be 4 or 8 times
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        media = rearrange(media, "b t n d -> b (t n) d")

        k, v = self.to_kv(media).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)

        q = q * self.scale

        # Similarity logits: (B, h, T_txt, L).
        sim = einsum("... i d, ... j d -> ... i j", q, k)

        # mask padded audio embeddings
        media_mask = rearrange(media_mask, "b i n -> b 1 1 (i n)").bool()  # n = 1 is extra dim
        sim = sim.masked_fill(~media_mask, -torch.finfo(sim.dtype).max)

        # This module is only used with only_attend_immediate_media=False
        # (see the dead branch below, which would not work otherwise).
        assert self.only_attend_immediate_media is False

        # mask media locations
        if exists(media_locations):
            # few_shot_mask[b, t, l] is True when text position t may attend
            # to audio window l. Each text span between consecutive <audio>
            # tokens is paired with the window range of its preceding audio.
            few_shot_mask = torch.zeros(B, T_txt, L).bool().to(sim.device)
            for batch_idx in range(B):
                media_locations_b = media_locations[batch_idx].nonzero()  # locations of <audio>
                if len(media_locations_b.shape) > 1:
                    media_locations_b = media_locations_b.squeeze(-1)

                # i == -1 covers the text before the first <audio> token.
                for i in range(-1, len(media_locations_b)):
                    if i == -1:
                        if len(media_locations_b) == 1:
                            text_start, text_end = 0, T_txt
                        else:
                            text_start, text_end = 0, media_locations_b[i+1]

                    elif i == len(media_locations_b) - 1:
                        text_start, text_end = media_locations_b[i], T_txt

                    else:
                        text_start, text_end = media_locations_b[i], media_locations_b[i+1]

                    if self.only_attend_immediate_media:
                        look_at_window_start = max(i,0) * self.max_window_per_audio
                    else:
                        look_at_window_start = 0
                    look_at_window_end = (max(i,0) + 1) * self.max_window_per_audio

                    few_shot_mask[batch_idx, text_start:text_end, look_at_window_start:look_at_window_end] = True

            sim = sim.masked_fill(~few_shot_mask.unsqueeze(1), -torch.finfo(sim.dtype).max)

        # Subtract the row max for numerical stability before softmax.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # NOTE(review): this branch is unreachable (the assert above forces
        # only_attend_immediate_media is False) and `text_time` is undefined
        # here — it would raise NameError if ever taken. Dead code inherited
        # from open_flamingo; confirm before enabling immediate-media mode.
        if exists(media_locations) and self.only_attend_immediate_media:
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(
                text_without_media_mask, "b i -> b 1 i 1"
            )
            attn = attn.masked_fill(text_without_media_mask, 0.0)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
323
+
324
+
325
class GatedCrossAttentionBlock(nn.Module):
    """Tanh-gated cross-attention + feed-forward block (Flamingo-style).

    Both gates start at zero, so the whole block is an identity mapping at
    initialization and is learned to open gradually during training.
    """

    def __init__(
        self,
        *,
        dim,
        dim_audio,
        max_window_per_audio,
        dim_head=64,
        heads=8,
        ff_mult=4,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        self.attn = MaskedCrossAttention(
            dim=dim,
            dim_audio=dim_audio,
            max_window_per_audio=max_window_per_audio,
            dim_head=dim_head,
            heads=heads,
            only_attend_immediate_media=only_attend_immediate_media,
        )
        self.attn_gate = nn.Parameter(torch.tensor([0.0]))

        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.0]))

    def forward(
        self,
        x,
        media,
        media_mask,
        media_locations=None,
        use_cached_media=False,
    ):
        # Gated cross-attention sublayer with residual connection.
        attn_out = self.attn(
            x,
            media,
            media_mask,
            media_locations=media_locations,
            use_cached_media=use_cached_media,
        )
        x = x + attn_out * self.attn_gate.tanh()

        # Gated feed-forward sublayer with residual connection.
        x = x + self.ff(x) * self.ff_gate.tanh()

        return x
373
+
374
+
375
if __name__ == '__main__':
    # Quick GPU smoke test: a (4, 512) input is treated as a batch of
    # length-1 sequences, so the output shape should be (4, 1, 512).
    # Requires a CUDA device.
    enc = TransformerEncoder().cuda()
    x = torch.randn(4, 512).cuda()
    output = enc(x)
    enc._use_gradient_checkpointing = True
    print(output.shape)
models/audio-flamingo-1/chat/src/utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
def extend_instance(obj, mixin):
    """Apply mixins to a class instance after creation.

    Rebinds obj.__class__ to a freshly created subclass of
    (mixin, original class) that keeps the original class name. The mixin
    is listed first so its methods win in the MRO — needed for our
    forward() override logic.
    """
    original_cls = obj.__class__
    obj.__class__ = type(original_cls.__name__, (mixin, original_cls), {})
14
+
15
+
16
def getattr_recursive(obj, att):
    """
    Return nested attribute of obj
    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
    An empty attribute path returns obj itself.
    """
    current = obj
    if att:
        for part in att.split("."):
            current = getattr(current, part)
    return current
28
+
29
+
30
def setattr_recursive(obj, att, val):
    """
    Set nested attribute of obj
    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
    """
    parent_path, _, leaf = att.rpartition(".")
    # Walk to the parent object only when the path is actually nested.
    target = getattr_recursive(obj, parent_path) if parent_path else obj
    setattr(target, leaf, val)
38
+
39
+
40
def apply_with_stopping_condition(
    module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
):
    """Recursively apply *apply_fn* over a module tree (pre-order DFS).

    A subtree is skipped entirely — including its root — as soon as
    stopping_condition(module) is true; apply_fn runs only on nodes where
    apply_condition(module) holds. Extra kwargs are forwarded to apply_fn.
    """
    if stopping_condition(module):
        return
    if apply_condition(module):
        apply_fn(module, **other_args)
    for submodule in module.children():
        apply_with_stopping_condition(
            submodule,
            apply_fn,
            apply_condition=apply_condition,
            stopping_condition=stopping_condition,
            **other_args
        )
models/audio-flamingo-1/chat/train/distributed.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ # Adapted from https://github.com/mlfoundations/open_clip under the MIT license.
8
+ # LICENSE is in incl_licenses directory.
9
+
10
+ import os
11
+ import torch
12
+
13
+ try:
14
+ import horovod.torch as hvd
15
+ except ImportError:
16
+ hvd = None
17
+
18
+
19
def is_global_master(args):
    """True on the rank-0 process of the whole job."""
    return args.rank == 0


def is_local_master(args):
    """True on the rank-0 process of the current node."""
    return args.local_rank == 0


def is_master(args, local=False):
    """Return node-local master status when *local* is set, else global."""
    if local:
        return is_local_master(args)
    return is_global_master(args)
29
+
30
+
31
def is_using_horovod():
    # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
    # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
    ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
    pmi_vars = ["PMI_RANK", "PMI_SIZE"]
    has_ompi = all(var in os.environ for var in ompi_vars)
    has_pmi = all(var in os.environ for var in pmi_vars)
    return has_ompi or has_pmi
42
+
43
+
44
def is_using_distributed():
    """True when env vars report more than one process.

    WORLD_SIZE takes precedence over SLURM_NTASKS; absent both, assume
    single-process.
    """
    for var in ("WORLD_SIZE", "SLURM_NTASKS"):
        if var in os.environ:
            return int(os.environ[var]) > 1
    return False
50
+
51
+
52
def world_info_from_env():
    """Read (local_rank, global_rank, world_size) from launcher env vars.

    Each value is taken from the first variable present in a fixed priority
    order (torchrun, MPI, SLURM, OpenMPI); defaults are the single-process
    values (0, 0, 1).
    """
    def _first_int(names, default):
        for name in names:
            if name in os.environ:
                return int(os.environ[name])
        return default

    local_rank = _first_int(
        (
            "LOCAL_RANK",
            "MPI_LOCALRANKID",
            "SLURM_LOCALID",
            "OMPI_COMM_WORLD_LOCAL_RANK",
        ),
        0,
    )
    global_rank = _first_int(
        ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"), 0
    )
    world_size = _first_int(
        ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"), 1
    )
    return local_rank, global_rank, world_size
75
+
76
+
77
def init_distributed_device(args):
    """Initialize torch.distributed (or horovod) and select the compute device.

    Mutates *args* in place — sets args.distributed, args.world_size,
    args.rank, args.local_rank and args.device — and returns the selected
    torch.device. Expects args.horovod, args.dist_url, args.dist_backend and
    args.no_set_device_rank to be present on *args*.
    """
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0

    if args.horovod:
        assert hvd is not None, "Horovod is not installed"
        print('using horovod')
        hvd.init()
        args.local_rank = int(hvd.local_rank())
        args.rank = hvd.rank()
        args.world_size = hvd.size()
        args.distributed = True
        # Mirror the horovod topology into the torch.distributed env vars.
        os.environ["LOCAL_RANK"] = str(args.local_rank)
        os.environ["RANK"] = str(args.rank)
        os.environ["WORLD_SIZE"] = str(args.world_size)

    elif is_using_distributed():
        if "SLURM_PROCID" in os.environ:
            print('DDP via SLURM')
            args.local_rank, args.rank, args.world_size = world_info_from_env()

            # SLURM var -> torch.distributed vars in case needed
            os.environ["LOCAL_RANK"] = str(args.local_rank)
            os.environ["RANK"] = str(args.rank)
            os.environ["WORLD_SIZE"] = str(args.world_size)

            init_method = args.dist_url

            # # master_ip = os.getenv('MASTER_ADDR', 'localhost')
            # # master_port = os.getenv('MASTER_PORT', '7000')
            # print("DDP RANK %d WORLD_SIZE %d" % (args.rank, args.world_size))
            # # init_method = f'tcp://{master_ip}:{master_port}'
            # init_method = 'tcp://localhost:54323'
            # print("Init method: %s" % (init_method))

            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=init_method,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            # torchrun / torch.distributed.launch export RANK & WORLD_SIZE,
            # so init_process_group can discover them on its own.
            print('DDP via torchrun, torch.distributed.launch')
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=args.dist_backend, init_method=args.dist_url
            )
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True
    else:
        # Single process: a 1-rank process group is still created so the rest
        # of the code can use torch.distributed collectives unconditionally.
        print('needed to run on single gpu')
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=1,
            rank=0,
        )

    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            # Bind this process to its node-local GPU.
            device = "cuda:%d" % args.local_rank
        else:
            device = "cuda:0"
        torch.cuda.set_device(device)
    else:
        device = "cpu"
    args.device = device
    device = torch.device(device)
    return device
models/audio-flamingo-1/chat/train/train.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import argparse
8
+ import functools
9
+ import glob
10
+ import os
11
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
12
+ import random
13
+ import shutil
14
+ import sys
15
+ sys.path.append('../')
16
+ import yaml
17
+ import time
18
+
19
+ import numpy as np
20
+ import torch
21
+ from torch.utils.tensorboard import SummaryWriter
22
+ from torch.nn.parallel import DistributedDataParallel as DDP
23
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
24
+ from torch.distributed.fsdp import (
25
+ CPUOffload,
26
+ MixedPrecision,
27
+ ShardingStrategy,
28
+ BackwardPrefetch,
29
+ )
30
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
31
+ checkpoint_wrapper,
32
+ CheckpointWrapper,
33
+ CheckpointImpl,
34
+ apply_activation_checkpointing,
35
+ )
36
+ from torch.distributed.fsdp._init_utils import _init_intra_and_inter_node_groups
37
+ from torch.distributed.distributed_c10d import _get_default_group
38
+ torch.cuda.empty_cache()
39
+
40
+ from transformers import (
41
+ get_constant_schedule_with_warmup,
42
+ get_cosine_schedule_with_warmup,
43
+ get_linear_schedule_with_warmup,
44
+ )
45
+
46
+ from data.data import get_audiotext_dataloader # AudioTextData, DataCollator
47
+ from distributed import init_distributed_device, world_info_from_env
48
+ from train_utils import (
49
+ train_one_epoch,
50
+ get_mp_policy_dtype,
51
+ save_checkpoint,
52
+ Dict2Class,
53
+ get_autocast,
54
+ get_cast_dtype
55
+ )
56
+ from src.factory import create_model_and_transforms
57
+
58
+
59
def random_seed(seed=42, rank=0):
    """Seed torch, numpy and Python's RNGs, offset by the process rank.

    The rank offset gives each distributed worker a distinct but
    reproducible random stream.
    """
    effective_seed = seed + rank
    torch.manual_seed(effective_seed)
    np.random.seed(effective_seed)
    random.seed(effective_seed)
+
64
+
65
+ def main():
66
+ parser = argparse.ArgumentParser()
67
+ parser.add_argument('-c', '--config', type=str, default='../config/config.yaml', help='yaml config path')
68
+ parsed_args = parser.parse_args()
69
+
70
+ config = yaml.load(open(parsed_args.config), Loader=yaml.FullLoader)
71
+ data_config = config['data_config']
72
+ model_config = config['model_config']
73
+ clap_config = config["clap_config"]
74
+ args = Dict2Class(config['train_config'])
75
+
76
+ if 'sft_config' in config:
77
+ sft_config = config['sft_config']
78
+ unfreeze_full_lm = sft_config['unfreeze_full_lm']
79
+ else:
80
+ sft_config = None
81
+ unfreeze_full_lm = False
82
+
83
+ # get paths done
84
+ exp_path = os.path.join(args.expdir, args.run_name)
85
+ os.makedirs(exp_path, exist_ok=True)
86
+ print('exp_path:', exp_path)
87
+ shutil.copy(parsed_args.config, os.path.join(exp_path, 'config.yaml'))
88
+ data_config["dataset_blending_output"] = os.path.join(exp_path, data_config["dataset_blending_output"])
89
+
90
+ # Validate args
91
+ if args.fsdp and not args.fsdp_use_orig_params:
92
+ print(
93
+ "Warning: FSDP is running without fsdp_use_orig_params flag. "
94
+ + "This is not recommended because it means we will use uniform weight decay"
95
+ + " and train all embeddings, not just the newly added ones. "
96
+ + "Note: OPT models are not compatible with fsdp_use_orig_params flag."
97
+ )
98
+
99
+ if args.fsdp and args.fsdp_sharding_strategy == "hybrid":
100
+ print(
101
+ "Warning: As of torch=2.0.1, the FSDP logic for optim_state_dict() is broken for hybrid sharding."
102
+ + "To make this method work, we need to modify torch.distributed.fsdp._optim_utils.py"
103
+ + "Copy and paste the code from the _optim_utils.py in this repo into the torch file."
104
+ + "The main issue was the missing group kwarg on line 1596 in _all_gather_optim_state."
105
+ )
106
+
107
+ # Set up distributed training
108
+ print('initializing distributed environment')
109
+ if args.offline:
110
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
111
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
112
+ device_id = init_distributed_device(args)
113
+ random_seed(args.seed)
114
+
115
+ # Initialize model
116
+ print('creating model')
117
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" # disable the tokenizer parallelism warning
118
+ model, tokenizer = create_model_and_transforms(
119
+ **model_config,
120
+ clap_config=clap_config,
121
+ use_local_files=args.offline,
122
+ gradient_checkpointing=args.gradient_checkpointing,
123
+ freeze_lm_embeddings=args.freeze_lm_embeddings,
124
+ unfreeze_full_lm=unfreeze_full_lm
125
+ )
126
+ random_seed(args.seed, args.rank)
127
+
128
+ # Initialize logging
129
+ print(f"Start running training on rank {args.rank}.")
130
+
131
+ # Load model checkpoint on CPU
132
+ checkpoint_list = glob.glob(f"{args.expdir}/{args.run_name}/checkpoint_*.pt")
133
+ if len(checkpoint_list) == 0:
134
+ print(f"Found no checkpoints for run {args.run_name}.")
135
+ resume_from_checkpoint = None
136
+ else:
137
+ resume_from_checkpoint = sorted(
138
+ checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
139
+ )[-1]
140
+ print(
141
+ f"Found checkpoint {resume_from_checkpoint} for run {args.run_name}."
142
+ )
143
+
144
+ # load pretrained model
145
+ resume_from_epoch = 0
146
+ if (resume_from_checkpoint is None) and (sft_config is not None):
147
+ # just started SFT
148
+ pretrained_path = os.path.join(
149
+ sft_config['pretrained_path'],
150
+ sft_config['pretrained_ckpt']
151
+ )
152
+ if args.rank == 0:
153
+ print(f"Loading checkpoint from {pretrained_path}")
154
+ checkpoint = torch.load(pretrained_path, map_location="cpu")
155
+ msd = checkpoint["model_state_dict"]
156
+ msd = {k.replace("module.", ""): v for k, v in msd.items()}
157
+
158
+ # for fsdp, only one rank needs to load the state dict
159
+ if not args.fsdp or args.rank == 0:
160
+ model.load_state_dict(msd, False)
161
+ del checkpoint["model_state_dict"]
162
+ del msd
163
+
164
+
165
+ elif resume_from_checkpoint is not None:
166
+ # continue training (either pretraining or STF)
167
+ if args.rank == 0:
168
+ print(f"Loading checkpoint from {resume_from_checkpoint}")
169
+ checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
170
+ msd = checkpoint["model_state_dict"]
171
+ msd = {k.replace("module.", ""): v for k, v in msd.items()}
172
+ resume_from_epoch = checkpoint["epoch"] + 1
173
+
174
+ # for fsdp, only one rank needs to load the state dict
175
+ if not args.fsdp or args.rank == 0:
176
+ model.load_state_dict(msd, False)
177
+ del checkpoint["model_state_dict"]
178
+ del msd
179
+
180
+ else:
181
+ pass
182
+
183
+ # Initialize FSDP / DDP, and ensure the model is on GPU
184
+ print(f"Initializing distributed training with {args.world_size} GPUs.")
185
+ if args.fsdp:
186
+ print(
187
+ f"Before FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
188
+ )
189
+
190
+ # init MixedPrecision
191
+ if args.precision != "fp32":
192
+ cast_dtype = get_mp_policy_dtype(args.precision)
193
+ mp_policy = MixedPrecision(
194
+ param_dtype=torch.float32,
195
+ reduce_dtype=cast_dtype, # gradient communication
196
+ buffer_dtype=cast_dtype,
197
+ )
198
+ else:
199
+ mp_policy = None
200
+
201
+ # init process groups
202
+ if args.fsdp_sharding_strategy == "hybrid":
203
+ intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
204
+ _get_default_group()
205
+ )
206
+ args.my_group = intra_node_group # for optimizer saving
207
+ process_group = (intra_node_group, inter_node_group) # for FSDP init
208
+ else:
209
+ args.my_group = None # for optimizer saving
210
+ process_group = None # for FSDP init
211
+
212
+ # init FSDP
213
+ wrapper_kwargs = dict(
214
+ process_group=process_group,
215
+ cpu_offload=CPUOffload(offload_params=False),
216
+ device_id=device_id,
217
+ sync_module_states=True, # broadcast loaded ckpt from rank 0 -> all ranks
218
+ sharding_strategy=ShardingStrategy.FULL_SHARD
219
+ if args.fsdp_sharding_strategy == "full"
220
+ else ShardingStrategy.HYBRID_SHARD,
221
+ use_orig_params=args.fsdp_use_orig_params,
222
+ mixed_precision=mp_policy,
223
+ forward_prefetch=True,
224
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
225
+ limit_all_gathers=True,
226
+ )
227
+ model.wrap_fsdp(wrapper_kwargs, device_id)
228
+ ddp_model = model
229
+
230
+ print(
231
+ f"After FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
232
+ )
233
+ print(
234
+ f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
235
+ )
236
+
237
+ else:
238
+ model = model.to(device_id)
239
+ ddp_model = DDP(model, device_ids=[device_id])
240
+
241
+ # Initialize gradient checkpointing
242
+ if args.gradient_checkpointing:
243
+ non_reentrant_wrapper = functools.partial(
244
+ checkpoint_wrapper,
245
+ offload_to_cpu=True,
246
+ checkpoint_impl=CheckpointImpl.NO_REENTRANT,
247
+ )
248
+ apply_activation_checkpointing(
249
+ ddp_model,
250
+ checkpoint_wrapper_fn=non_reentrant_wrapper,
251
+ check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
252
+ and not isinstance(m, FSDP)
253
+ and not isinstance(m, CheckpointWrapper),
254
+ )
255
+
256
+ # Initialize optimizer
257
+ params_to_optimize = ddp_model.named_parameters()
258
+ params_to_optimize = list(
259
+ filter(
260
+ lambda x: x[1].requires_grad
261
+ and not getattr(x[1], "exclude_from_optimizer", False),
262
+ params_to_optimize,
263
+ )
264
+ )
265
+ if not args.fsdp or args.fsdp_use_orig_params:
266
+ # apply weight decay only to params in the xattn layers
267
+ def get_grouped_params(model):
268
+ params_with_wd, params_without_wd = [], []
269
+ for n, p in params_to_optimize:
270
+ if "gated_cross_attn" in n:
271
+ params_with_wd.append(p)
272
+ else:
273
+ params_without_wd.append(p)
274
+ return [
275
+ {"params": params_with_wd, "weight_decay": args.weight_decay},
276
+ {"params": params_without_wd, "weight_decay": 0.0},
277
+ ]
278
+
279
+ optimizer = torch.optim.AdamW(
280
+ get_grouped_params(params_to_optimize), lr=args.learning_rate
281
+ )
282
+ else:
283
+ # unclear if we should be using no weight decay or small weight decay for all parameters
284
+ optimizer = torch.optim.AdamW(
285
+ (p for _, p in params_to_optimize),
286
+ lr=args.learning_rate,
287
+ weight_decay=args.weight_decay,
288
+ )
289
+
290
+ # load optimizer checkpoint
291
+ if resume_from_checkpoint is not None:
292
+ osd = checkpoint["optimizer_state_dict"]
293
+ if args.fsdp:
294
+ osd = FSDP.optim_state_dict_to_load(osd, ddp_model, optimizer)
295
+ optimizer.load_state_dict(osd)
296
+ del checkpoint["optimizer_state_dict"]
297
+ del osd
298
+
299
+ # Initialize data loaders
300
+ AudioTextDataInfo = get_audiotext_dataloader(
301
+ data_config, clap_config, tokenizer, args.batch_size, split='train',
302
+ epoch=0, force_reblend=True
303
+ )
304
+
305
+ total_training_steps = (
306
+ len(AudioTextDataInfo.dataset) // (args.batch_size * args.world_size)
307
+ ) * args.num_epochs
308
+
309
+ if args.rank == 0:
310
+ print(f"Total training steps: {total_training_steps}")
311
+ tb = SummaryWriter(os.path.join(exp_path, 'tensorboard'))
312
+ else:
313
+ tb = None
314
+
315
+ # Initialize lr scheduler
316
+ if args.lr_scheduler == "linear":
317
+ lr_scheduler = get_linear_schedule_with_warmup(
318
+ optimizer,
319
+ num_warmup_steps=args.warmup_steps,
320
+ num_training_steps=total_training_steps,
321
+ )
322
+ elif args.lr_scheduler == "cosine":
323
+ lr_scheduler = get_cosine_schedule_with_warmup(
324
+ optimizer,
325
+ num_warmup_steps=args.warmup_steps,
326
+ num_training_steps=total_training_steps,
327
+ )
328
+ else:
329
+ lr_scheduler = get_constant_schedule_with_warmup(
330
+ optimizer, num_warmup_steps=args.warmup_steps
331
+ )
332
+
333
+ # load lr scheduler checkpoint
334
+ if resume_from_checkpoint is not None:
335
+ lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
336
+ del checkpoint["lr_scheduler_state_dict"]
337
+
338
+ # Start training!
339
+ ddp_model.train()
340
+
341
+ print('start training from epoch {}'.format(resume_from_epoch))
342
+ for epoch in range(resume_from_epoch, args.num_epochs):
343
+ # force reblending dataset for every epoch
344
+ if epoch > 0:
345
+ AudioTextDataInfo = get_audiotext_dataloader(
346
+ data_config, clap_config, tokenizer, args.batch_size, split='train',
347
+ epoch=epoch, force_reblend=True
348
+ )
349
+ AudioTextDataInfo.set_epoch(epoch)
350
+ trainloader = AudioTextDataInfo.dataloader
351
+
352
+ # train one epoch
353
+ train_one_epoch(
354
+ args=args,
355
+ model=ddp_model,
356
+ epoch=epoch,
357
+ tokenizer=tokenizer,
358
+ optimizer=optimizer,
359
+ lr_scheduler=lr_scheduler,
360
+ trainloader=trainloader,
361
+ device_id=device_id,
362
+ tb=tb
363
+ )
364
+
365
+ # save checkpoint
366
+ save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
367
+ time.sleep(1.0)
368
+
369
+ # save final checkpoint
370
+ save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
371
+ if args.rank == 0:
372
+ tb.close()
373
+
374
+
375
+ if __name__ == "__main__":
376
+ main()
models/audio-flamingo-1/chat/train/train_utils.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ import time
8
+ import os
9
+ from tqdm import tqdm
10
+ import sys
11
+ from copy import deepcopy
12
+
13
+ from contextlib import suppress
14
+ import torch
15
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
16
+ from torch.distributed.fsdp import (
17
+ FullStateDictConfig,
18
+ StateDictType,
19
+ )
20
+ from torch.distributed.fsdp.api import FullOptimStateDictConfig
21
+ from einops import rearrange
22
+
23
+
24
class Dict2Class:
    """Expose the entries of a plain dict as object attributes."""

    def __init__(self, data_dict):
        # Bind each key of the mapping as an attribute on this instance.
        for name in data_dict:
            setattr(self, name, data_dict[name])
28
+
29
+
30
class SysLogger(object):
    """Tee-style logger mirroring writes to both the terminal and a log file.

    Intended to replace ``sys.stdout`` so that everything printed is also
    appended to ``filename``. The file is opened in append mode so repeated
    runs accumulate in one log.
    """

    def __init__(self, filename="../log/log.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        # Fix: pass the message through unmodified. The original appended an
        # extra '\n' to the terminal copy, which doubled every newline when
        # this object replaced sys.stdout (print() already emits its own '\n').
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # Fix: a sys.stdout replacement must expose flush(); callers such as
        # print(..., flush=True) or tqdm would otherwise hit AttributeError.
        self.terminal.flush()
        self.log.flush()
38
+
39
+
40
def get_cast_dtype(precision: str):
    """Map a precision string to the torch dtype used for casting inputs.

    Returns torch.bfloat16 for "bf16", torch.float16 for "fp16", and None
    for any other precision string (meaning: keep the default dtype).
    """
    dtype_by_precision = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }
    return dtype_by_precision.get(precision)
47
+
48
+
49
def get_mp_policy_dtype(precision: str):
    """Resolve the dtype used for the FSDP MixedPrecision policy.

    Any precision string mentioning bfloat16 (e.g. "bf16", "amp_bfloat16")
    maps to torch.bfloat16; exactly "fp16" maps to torch.float16; everything
    else falls back to full float32.
    """
    if any(tag in precision for tag in ("bfloat16", "bf16")):
        return torch.bfloat16
    if precision == "fp16":
        return torch.float16
    return torch.float32
56
+
57
+
58
def get_autocast(precision, cache_enabled=True):
    """Return a zero-argument factory producing the autocast context.

    Callers invoke the result as ``with autocast():`` (see train_one_epoch),
    so every branch must return a *callable*, never a context-manager
    instance.

    - "amp": CUDA autocast at the default (fp16) dtype.
    - "amp_bfloat16" / "amp_bf16": CUDA autocast at bfloat16.
    - anything else: contextlib.suppress, a no-op context factory.
    """
    if precision == "amp":
        # Fix: the original returned an autocast *instance* here, so the
        # caller's `autocast()` raised TypeError (calling an autocast object
        # expects a function to decorate). Return a factory like the other
        # branches instead.
        return lambda: torch.cuda.amp.autocast(cache_enabled=cache_enabled)
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        return lambda: torch.cuda.amp.autocast(
            dtype=torch.bfloat16, cache_enabled=cache_enabled
        )
    else:
        return suppress
67
+
68
+
69
def train_one_epoch(
    args,
    model,
    epoch,
    trainloader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    tb
):
    """Run one full pass over ``trainloader``, updating ``model`` in place.

    Handles mixed-precision autocast, label masking so that only assistant
    answers (tokens between <SEP> and <|endofchunk|>) contribute to the loss,
    gradient accumulation, embedding-gradient masking, FSDP-aware gradient
    clipping, and TensorBoard/console logging on rank 0.

    Args:
        args: training options namespace (precision, fsdp, batch_size,
            gradient_accumulation_steps, logging_steps, ...).
        model: the DDP- or FSDP-wrapped Flamingo model.
        epoch: current epoch index; used to offset the global step.
        trainloader: dataloader yielding dicts with "audio_clips",
            "audio_embed_mask", "input_ids", and "attention_mask".
        tokenizer: tokenizer providing the special-token ids used for masking.
        optimizer: stepped once per gradient-accumulation window.
        lr_scheduler: stepped alongside the optimizer.
        device_id: target device for the batch tensors.
        tb: SummaryWriter on rank 0, None on other ranks.
    """
    # setup loaders
    num_batches_per_epoch = len(trainloader)
    total_training_steps = num_batches_per_epoch * args.num_epochs
    print('num_batches_per_epoch={}, total_training_steps={}'.format(num_batches_per_epoch, total_training_steps))

    autocast = get_autocast(
        args.precision, cache_enabled=(not args.fsdp)
    )  # if fsdp, disable cache to save memory
    cast_dtype = get_cast_dtype(args.precision)

    # setup model
    # Special-token ids: <audio> marks where audio embeddings attach in the
    # token stream; <|endofchunk|> terminates an assistant answer.
    media_token_id = tokenizer("<audio>", add_special_tokens=False)["input_ids"][-1]
    assert media_token_id == tokenizer.encode("<audio>")[-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
    model.train()

    # setup logging
    step_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    # loop through dataloader
    for num_steps, batch in tqdm(
        enumerate(trainloader),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch)
    ):

        data_time_m.update(time.time() - end)
        global_step = num_steps + epoch * num_batches_per_epoch

        #### FORWARD PASS ####
        audio_clips = batch["audio_clips"].to(device_id, dtype=cast_dtype, non_blocking=True) # (B, N_WINDOWS, WINDOW_LENGTH)
        audio_embed_mask = batch["audio_embed_mask"].to(device_id, dtype=cast_dtype, non_blocking=True) # (B, N_WINDOWS)
        # NOTE(review): input_ids/attention_mask are cast to cast_dtype
        # (possibly a float dtype) rather than kept integer — presumably the
        # model casts them back internally; confirm against the model's
        # forward() before changing.
        input_ids = batch["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True) # (B, N_TOKENS)
        attention_mask = batch["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True) # (B, N_TOKENS)

        # set up labels; language model is expected to handle shifting
        labels = input_ids.clone()
        # -100 is the ignore_index of the cross-entropy loss: pad tokens,
        # the first two positions, and <audio> placeholders never get a loss.
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, :2] = -100
        labels[labels == tokenizer.encode("<audio>")[-1]] = -100

        # mask all prompts except for between <SEP> and <|endofchunk|>
        sep_locations = labels == tokenizer.sep_token_id
        eoc_locations = labels == endofchunk_token_id

        # A truncated sequence can end mid-answer, leaving an unmatched <SEP>.
        if not all(sep_locations.sum(dim=1) == eoc_locations.sum(dim=1)):
            print("Warning: <SEP>-<EoC> pairing mismatch at step {} due to max_token limit.".format(num_steps))

        for i in range(labels.shape[0]):
            # Walk each sequence with a state flag: mask while outside a
            # <SEP>...<|endofchunk|> answer span, keep tokens while inside it.
            shouldmask = True
            for j in range(labels.shape[1]):
                if shouldmask and (labels[i][j] != tokenizer.eos_token_id):
                    masked_value = -100
                else:
                    masked_value = labels[i][j]

                # State transitions are decided on the ORIGINAL token before
                # it is overwritten below.
                if labels[i][j] == tokenizer.sep_token_id:
                    shouldmask = False
                elif labels[i][j] == endofchunk_token_id:
                    shouldmask = True

                labels[i][j] = masked_value

            # If the sequence was truncated inside an answer (last label is a
            # live token), mask the dangling partial answer from the right.
            if labels[i][-1] not in [-100, tokenizer.eos_token_id, tokenizer.pad_token_id, endofchunk_token_id]:
                for j in range(labels.shape[1]-1, -1, -1):
                    if labels[i][j] not in [-100, tokenizer.eos_token_id, endofchunk_token_id]:
                        labels[i][j] = -100
                    else:
                        break

        labels = labels.to(device_id)

        # gradient accumulation w/ fsdp cpu offloading requires a no_sync context manager
        with autocast():
            output = model(
                audio_x=audio_clips,
                audio_x_mask=audio_embed_mask,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = output.loss

        # Scale the loss for accumulation; backward on every micro-batch,
        # optimizer.step() only once per accumulation window (below).
        divided_loss = loss / args.gradient_accumulation_steps
        train_loss = divided_loss * args.loss_multiplier
        train_loss.backward()

        if (not args.freeze_lm_embeddings) and (
            not args.fsdp or args.fsdp_use_orig_params
        ):
            # Mask gradients for input embeddings s.t. we only update the added tokens <audio> and <|endofchunk|>
            # Under FSDP the model is not wrapped in DDP, so there is no
            # `.module` indirection.
            if args.fsdp:
                embed_grad = model.lang_encoder.get_input_embeddings().weight.grad
            else:
                embed_grad = (
                    model.module.lang_encoder.get_input_embeddings().weight.grad
                )
            zero_mask = torch.zeros_like(embed_grad)
            zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
            zero_mask[endofchunk_token_id] = torch.ones_like(
                zero_mask[endofchunk_token_id]
            )
            if args.fsdp:
                model.lang_encoder.get_input_embeddings().weight.grad = (
                    embed_grad * zero_mask
                )
            else:
                model.module.lang_encoder.get_input_embeddings().weight.grad = (
                    embed_grad * zero_mask
                )

        # clip gradient norm
        if args.fsdp:
            """
            The way we clip gradients with FSDP is different than the non-FSDP case,
            because during FSDP, gradient norms are computed over certain submodules,
            rather than the entire model.
            At least for OPT-125M, this didn't seem to make a difference in performance.
            """
            model.clip_grad_norm_(1.0)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # step optimizer and log
        # Step at the end of each accumulation window, and also on the final
        # (possibly partial) window of the epoch.
        if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (
            num_steps == num_batches_per_epoch - 1
        ):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

            # step time and reset end outside of rank 0
            step_time_m.update(time.time() - end)
            end = time.time()

            # rank 0 logging
            if args.rank == 0:
                samples_per_second = (
                    args.gradient_accumulation_steps
                    * args.batch_size
                    * args.world_size
                    / step_time_m.val
                )
                samples_per_second_per_gpu = (
                    args.gradient_accumulation_steps
                    * args.batch_size
                    / step_time_m.val
                )
                log_dict = {
                    "data_time": data_time_m.avg,
                    "step_time": step_time_m.avg,
                    "samples_per_second": samples_per_second,
                    "samples_per_second_per_gpu": samples_per_second_per_gpu,
                    "lr": optimizer.param_groups[0]["lr"],
                    "loss": loss.item()
                }

                if ((num_steps + 1) % args.logging_steps == 0):
                    for key in log_dict:
                        tb.add_scalar("Train/{}".format(key), log_dict[key], global_step)

                # Meters are reset after every optimizer step, so the logged
                # averages cover only the last accumulation window.
                step_time_m.reset()
                data_time_m.reset()

                # Log loss to console
                if ((num_steps + 1) % args.logging_steps == 0):
                    print(
                        f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss: {loss.item():.3f}\n"
                    )
252
+
253
+
254
class AverageMeter(object):
    """Tracks the latest value, running sum, count, and mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
271
+
272
+
273
def filter_state_dict_to_trainable(model, state_dict):
    """
    Remove non-trainable parameters from model state dict.
    Exception: Embeddings will not be removed, even if frozen.
    This is because we need the new <audio> <|endofchunk|> tokens to
    be consistent across initializations.
    """
    # won't work for fsdp + use_orig_params=False
    for param_name, param in model.named_parameters():
        if "fsdp" in param_name:
            continue
        if "embed" in param_name or isinstance(param, torch.nn.Embedding):
            continue
        if param.requires_grad:
            continue
        # Frozen parameter: drop it from the checkpoint (after stripping the
        # activation-checkpoint wrapper segment from its name).
        key = param_name.replace("._checkpoint_wrapped_module", "")
        if key in state_dict:
            del state_dict[key]
        else:
            print(f"WARNING: filtering but {key} not in state_dict")

    # also remove the keys in state_dict generated from
    # lang_encoder.old_decoder_blocks and lang_encoder.gated_cross_attn_layers
    # because these are already saved in lang_encoder.model...
    redundant = [
        key
        for key in state_dict.keys()
        if "lang_encoder.old_decoder_blocks" in key
        or "lang_encoder.gated_cross_attn_layers" in key
        or "vision_encoder" in key
    ]
    for key in redundant:
        del state_dict[key]
    return state_dict
308
+
309
+
310
def save_checkpoint(model, optimizer, lr_scheduler, epoch, args):
    """
    Save training checkpoint with model, optimizer, and lr_scheduler state.

    Under FSDP, state dicts are gathered as FULL_STATE_DICT with
    rank0_only=True, so only rank 0 materializes (and writes) the full
    checkpoint; other ranks get empty dicts. The file is written to
    ``{args.expdir}/{args.run_name}/checkpoint_{epoch}.pt``.
    """
    if args.fsdp:
        # Configure gathering BEFORE calling state_dict(): full (unsharded)
        # tensors, assembled only on rank 0 and offloaded to CPU to bound
        # GPU memory. args.my_group is the intra-node group set up in main()
        # for hybrid sharding (None otherwise).
        FSDP.set_state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
            FullOptimStateDictConfig(rank0_only=True),
        )
        model_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer, group=args.my_group)

    else:
        model_state = model.state_dict()
        optim_state = optimizer.state_dict()

    if args.rank == 0:
        # With fsdp_use_orig_params=False, named_parameters() does not line
        # up with the flattened state dict, so filtering is skipped.
        if not (args.fsdp and not args.fsdp_use_orig_params):
            model_state = filter_state_dict_to_trainable(model, model_state)

        checkpoint_dir = os.path.join(args.expdir, args.run_name)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        checkpoint_dict = {
            "epoch": epoch,
            "model_state_dict": model_state,
            "optimizer_state_dict": optim_state,
            "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        }

        print(f"Saving checkpoint to {checkpoint_dir}/checkpoint_{epoch}.pt")
        torch.save(checkpoint_dict, f"{checkpoint_dir}/checkpoint_{epoch}.pt")

        if args.delete_previous_checkpoint:
            # Keep every 20th epoch as a permanent snapshot; delete the
            # immediately preceding checkpoint otherwise. Best-effort: the
            # previous file may already be gone.
            if epoch > 0 and epoch % 20 != 0:
                try:
                    os.remove(f"{checkpoint_dir}/checkpoint_{epoch-1}.pt")
                except:
                    pass
models/audio-flamingo-1/checkpoints/chat_part1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5673d1541cd5764d6dcc89b3bdc331b768c1159ef685a373c7f4deb9e1ddaef
3
+ size 3328734458
models/audio-flamingo-1/checkpoints/chat_part2.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea792e60d95deacc75244af0b23c5b75b8de3aef617b392e633cd67a5f20c5aa
3
+ size 3482749306
models/audio-flamingo-1/checkpoints/chat_part3.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8828e7e5db7014259d746025274dd752fe39f959eb6d7e1380796a838c2983c
3
+ size 3898925434
models/audio-flamingo-1/checkpoints/chat_part4.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d0445a8f663774e16df616f8e20abb5ceff61a40b2ee5a8b12a83a641971f8e1
3
+ size 3357325242
models/audio-flamingo-1/checkpoints/chat_part5.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e54e84938e4143a3d2fc8b090a1ab238287a0a41236efcc4a5df5d476291d96f
3
+ size 3591230906
models/audio-flamingo-1/checkpoints/checkpoint_utils.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ import torch
5
+
6
def merge_checkpoints(checkpoint_path, num_parts=5):
    """Reassemble a checkpoint that was split into ``num_parts`` shard files.

    Part files are named by inserting ``_part{i}`` before the ``.pt`` suffix
    of ``checkpoint_path``. Their ``model_state_dict`` entries are merged
    (later parts override duplicate keys) and the combined state dict is
    written back to ``checkpoint_path``.

    Args:
        checkpoint_path: destination path ending in ``.pt``.
        num_parts: number of ``_part{i}.pt`` shards to merge (1-indexed).
    """
    combined_state_dict = {}
    for i in range(1, num_parts + 1):
        part_path = checkpoint_path.replace('.pt', '_part{}.pt'.format(i))
        # Fix: map_location="cpu" so shards saved from GPU workers can be
        # merged on CPU-only machines instead of failing with a CUDA
        # deserialization error.
        part_checkpoint = torch.load(part_path, map_location="cpu")
        part_state_dict = part_checkpoint['model_state_dict']
        combined_state_dict.update(part_state_dict)

    full_checkpoint = {'model_state_dict': combined_state_dict}
    torch.save(full_checkpoint, checkpoint_path)
    print('merging {}: finished'.format(checkpoint_path))
17
+
18
+ merge_checkpoints('foundation.pt', num_parts=5)
19
+ merge_checkpoints('chat.pt', num_parts=5)
models/audio-flamingo-1/checkpoints/foundation_part1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5921b6167a0e0d27a4732dc77899505db454bcb769942da3079c8b821e54711
3
+ size 3328736090
models/audio-flamingo-1/checkpoints/foundation_part2.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa325ee55d333ff4155679af58e901096b89ffd8097092b602cdf54f8b989791
3
+ size 3482750938
models/audio-flamingo-1/checkpoints/foundation_part3.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfec3fcfe53582e71c8ae1bc8832bbce5b15168878cf5e85a32f1c717bacdc78
3
+ size 3898927066
models/audio-flamingo-1/checkpoints/foundation_part4.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b4b3e3011ea05da7ef29e8b906141af05a7eb2fa9604b65f21638db142a192d
3
+ size 3357326874
models/audio-flamingo-1/checkpoints/foundation_part5.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e52ea8848b540816e4b7e46ed705ec84965b3718d77b3e8a4ded0b64668fd168
3
+ size 3591232538