Audio-Flamingo (code, models, paper)
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +23 -0
- Audio Flamingo 2. An Audio-Language Model with Long-Audio Understanding and Expert Reasoning Abilities.pdf +3 -0
- Audio Flamingo 3. Advancing Audio Intelligence with Fully Open Large Audio Language Models.pdf +3 -0
- Audio Flamingo Sound-CoT Technical Report. Improving Chain-of-Thought Reasoning in Sound Understanding.pdf +3 -0
- Audio Flamingo. A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities.pdf +3 -0
- NVIDIA представила модель, которая анализирует звук, речь и музыку.pdf +3 -0
- code/Audio-Flamingo-3-Pinokio.zip +3 -0
- code/Audio-Flamingo-3.zip +3 -0
- code/AudioFlamingo.zip +3 -0
- code/audio-flamingo-3-chat-hf.zip +3 -0
- code/audio-flamingo-3-hf.zip +3 -0
- code/audio-flamingo-audio_flamingo_2.zip +3 -0
- code/audio-flamingo-audio_flamingo_3.zip +3 -0
- code/audio-flamingo-soundCoT.zip +3 -0
- code/audio-flamingo.zip +3 -0
- code/audio_flamingo.zip +3 -0
- code/cog-nvidia-audio-flamingo-3.zip +3 -0
- models/audio-flamingo-1/.gitattributes +2 -0
- models/audio-flamingo-1/.gitignore +5 -0
- models/audio-flamingo-1/LICENSE +21 -0
- models/audio-flamingo-1/README.md +64 -0
- models/audio-flamingo-1/assets/AudioFlamingo_ICML2024_poster.pdf +3 -0
- models/audio-flamingo-1/assets/audio_flamingo_arch.png +3 -0
- models/audio-flamingo-1/audio flamingo model card.md +115 -0
- models/audio-flamingo-1/chat/README.md +65 -0
- models/audio-flamingo-1/chat/clap_modified_code/CLAPWrapper.py +463 -0
- models/audio-flamingo-1/chat/configs/chat.yaml +80 -0
- models/audio-flamingo-1/chat/data/README.md +19 -0
- models/audio-flamingo-1/chat/data/data.py +481 -0
- models/audio-flamingo-1/chat/data/prepare_each_dataset.py +253 -0
- models/audio-flamingo-1/chat/src/__init__.py +2 -0
- models/audio-flamingo-1/chat/src/factory.py +219 -0
- models/audio-flamingo-1/chat/src/flamingo.py +260 -0
- models/audio-flamingo-1/chat/src/flamingo_lm.py +177 -0
- models/audio-flamingo-1/chat/src/helpers.py +380 -0
- models/audio-flamingo-1/chat/src/utils.py +54 -0
- models/audio-flamingo-1/chat/train/distributed.py +150 -0
- models/audio-flamingo-1/chat/train/train.py +376 -0
- models/audio-flamingo-1/chat/train/train_utils.py +351 -0
- models/audio-flamingo-1/checkpoints/chat_part1.pt +3 -0
- models/audio-flamingo-1/checkpoints/chat_part2.pt +3 -0
- models/audio-flamingo-1/checkpoints/chat_part3.pt +3 -0
- models/audio-flamingo-1/checkpoints/chat_part4.pt +3 -0
- models/audio-flamingo-1/checkpoints/chat_part5.pt +3 -0
- models/audio-flamingo-1/checkpoints/checkpoint_utils.py +19 -0
- models/audio-flamingo-1/checkpoints/foundation_part1.pt +3 -0
- models/audio-flamingo-1/checkpoints/foundation_part2.pt +3 -0
- models/audio-flamingo-1/checkpoints/foundation_part3.pt +3 -0
- models/audio-flamingo-1/checkpoints/foundation_part4.pt +3 -0
- models/audio-flamingo-1/checkpoints/foundation_part5.pt +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,26 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Audio[[:space:]]Flamingo[[:space:]]2.[[:space:]]An[[:space:]]Audio-Language[[:space:]]Model[[:space:]]with[[:space:]]Long-Audio[[:space:]]Understanding[[:space:]]and[[:space:]]Expert[[:space:]]Reasoning[[:space:]]Abilities.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
Audio[[:space:]]Flamingo[[:space:]]3.[[:space:]]Advancing[[:space:]]Audio[[:space:]]Intelligence[[:space:]]with[[:space:]]Fully[[:space:]]Open[[:space:]]Large[[:space:]]Audio[[:space:]]Language[[:space:]]Models.pdf filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
Audio[[:space:]]Flamingo[[:space:]]Sound-CoT[[:space:]]Technical[[:space:]]Report.[[:space:]]Improving[[:space:]]Chain-of-Thought[[:space:]]Reasoning[[:space:]]in[[:space:]]Sound[[:space:]]Understanding.pdf filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
Audio[[:space:]]Flamingo.[[:space:]]A[[:space:]]Novel[[:space:]]Audio[[:space:]]Language[[:space:]]Model[[:space:]]with[[:space:]]Few-Shot[[:space:]]Learning[[:space:]]and[[:space:]]Dialogue[[:space:]]Abilities.pdf filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
models/audio-flamingo-1/assets/audio_flamingo_arch.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
models/audio-flamingo-1/assets/AudioFlamingo_ICML2024_poster.pdf filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
models/audio-flamingo-1/labeling_machine/AF-AudioSet.json filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
models/audio-flamingo-3-chat/llm/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
models/audio-flamingo-3-chat/static/af3_main_diagram-1.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
models/audio-flamingo-3-chat/static/af3_radial-1.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
models/audio-flamingo-3-chat/static/af3_sota.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
models/audio-flamingo-3-chat/static/logo-no-bg.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
models/audio-flamingo-3-hf/static/af3_main_diagram-1.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
models/audio-flamingo-3-hf/static/af3_radial-1.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
models/audio-flamingo-3-hf/static/af3_sota.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
models/audio-flamingo-3-hf/static/logo-no-bg.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
models/audio-flamingo-3-hf/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
models/audio-flamingo-3/llm/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
models/audio-flamingo-3/static/af3_main_diagram-1.png filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
models/audio-flamingo-3/static/af3_radial-1.png filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
models/audio-flamingo-3/static/af3_sota.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
models/audio-flamingo-3/static/logo-no-bg.png filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
NVIDIA[[:space:]]представила[[:space:]]модель,[[:space:]]которая[[:space:]]анализирует[[:space:]]звук,[[:space:]]речь[[:space:]]и[[:space:]]музыку.pdf filter=lfs diff=lfs merge=lfs -text
|
Audio Flamingo 2. An Audio-Language Model with Long-Audio Understanding and Expert Reasoning Abilities.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:08d544cf57f324020ee5d9ff916c17d53aced283c09d38be09f9bc020a9ba171
|
| 3 |
+
size 10739247
|
Audio Flamingo 3. Advancing Audio Intelligence with Fully Open Large Audio Language Models.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f966955628f247ed76c7207b7b86048a1790794cc3b5cea47287ec14417f3508
|
| 3 |
+
size 6985793
|
Audio Flamingo Sound-CoT Technical Report. Improving Chain-of-Thought Reasoning in Sound Understanding.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:20cf54c8128ca96298a342f066239370a122e42642ae3ded1f67b76cd8f80a4d
|
| 3 |
+
size 585189
|
Audio Flamingo. A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:31451a317ebb7d0f4445500134cf5178b63c617d7f6583e0eac4f3a4c3d0000d
|
| 3 |
+
size 1444685
|
NVIDIA представила модель, которая анализирует звук, речь и музыку.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e96f68efb1fade752565268797ac90417bd142450eb8c8245c89f94994c22d09
|
| 3 |
+
size 2983908
|
code/Audio-Flamingo-3-Pinokio.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7d740f122a84c86b6e55574c8f6ce7145a8518ccb00a874661810124d5bf1f71
|
| 3 |
+
size 1365689
|
code/Audio-Flamingo-3.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bb1f6c63f18a25cc0db01c146783c651286fc53bb064d003b11518b62e7f59c2
|
| 3 |
+
size 6721741
|
code/AudioFlamingo.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:40d2bac296affd39f8f63066c9bc4a2ac0f6ec982c542c3f3c8a961e1ef68ca3
|
| 3 |
+
size 2578624
|
code/audio-flamingo-3-chat-hf.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:afd4af51c0c1a6e2e3b54323dea6b34872c3221826ea969ff40a6e055e3de0e4
|
| 3 |
+
size 1395827
|
code/audio-flamingo-3-hf.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3157b02422c02ccf87f00b99d0db9ad6ba78101fe43985461db101f418a4e1b4
|
| 3 |
+
size 1443418
|
code/audio-flamingo-audio_flamingo_2.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d71ee0ac516346df1cfc497da306b729cbe52c1f88c327a0d32ae36f22111450
|
| 3 |
+
size 5672326
|
code/audio-flamingo-audio_flamingo_3.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:db2c7f1847f5f2380f58bd78aa93326a1262b98bc2ff179206a26f67d7c2b371
|
| 3 |
+
size 3445237
|
code/audio-flamingo-soundCoT.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a7e130b11f7d96aca17b9d6feda693d4b9e65ad7b2a35d374451eb24875ac820
|
| 3 |
+
size 12563876
|
code/audio-flamingo.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d8cc93bf1b574c278112642af6e930485d065b3818e40392fc08b6cbd621f6f1
|
| 3 |
+
size 2484492
|
code/audio_flamingo.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:971305b4acb2b932be39abe6b376a6d3c52dece06fa6873220631c48a486ba81
|
| 3 |
+
size 11632722
|
code/cog-nvidia-audio-flamingo-3.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5921a9348dea2c34db7c1542e66461e392dfe92410ad4072001b675ebb87e2eb
|
| 3 |
+
size 17389826
|
models/audio-flamingo-1/.gitattributes
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
AF-AudioSet.json filter=lfs diff=lfs merge=lfs -text
|
models/audio-flamingo-1/.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.DS_Store
|
| 4 |
+
foundation.pt
|
| 5 |
+
chat.pt
|
models/audio-flamingo-1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 NVIDIA CORPORATION.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
models/audio-flamingo-1/README.md
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PyTorch Implementation of Audio Flamingo
|
| 2 |
+
|
| 3 |
+
**Zhifeng Kong, Arushi Goel, Rohan Badlani, Wei Ping, Rafael Valle, Bryan Catanzaro**
|
| 4 |
+
|
| 5 |
+
[[Demo website]](https://audioflamingo.github.io/) [[Demo video]](https://www.youtube.com/watch?v=ucttuS28RVE) [[ICML poster]](assets/AudioFlamingo_ICML2024_poster.pdf)
|
| 6 |
+
|
| 7 |
+
This repo contains the PyTorch implementation of [Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities](https://arxiv.org/abs/2402.01831) (ICML 2024). Audio Flamingo is a novel audio-understanding language model with
|
| 8 |
+
- strong audio understanding abilities,
|
| 9 |
+
- the ability to quickly adapt to unseen tasks via in-context learning and retrieval, and
|
| 10 |
+
- strong multi-turn dialogue abilities.
|
| 11 |
+
|
| 12 |
+
We introduce a series of training techniques, architecture design, and data strategies to enhance our model with these abilities. Extensive evaluations across various audio understanding tasks confirm the efficacy of our method, setting new state-of-the-art benchmarks.
|
| 13 |
+
|
| 14 |
+

|
| 15 |
+
|
| 16 |
+
## Code Structure
|
| 17 |
+
|
| 18 |
+
- The folder ```foundation/``` contains training code for the foundation model.
|
| 19 |
+
- The folder ```chat/``` contains training code for the chat model, which can perform multi-turn dialogues.
|
| 20 |
+
- The folder ```inference/``` contains inference code for both the foundation and chat models.
|
| 21 |
+
|
| 22 |
+
Within each folder, the structure is highly based on the [Open Flamingo](https://github.com/mlfoundations/open_flamingo) repo (commit ```a05dcba```). Each folder is self-contained and we expect no cross dependencies between these folders.
|
| 23 |
+
|
| 24 |
+
## Preparation
|
| 25 |
+
|
| 26 |
+
- Download source code of Laion-CLAP from their [official repo](https://github.com/LAION-AI/CLAP). Rename the folder to ```my_laion_clap/``` and copy the folder to under each of ```foundation/, chat/, inference/```. Download their pretrained checkpoints to ```YOUR_DATA_ROOT_DIR/audio-flamingo-data/laion-clap-pretrained/laion_clap/```.
|
| 27 |
+
- Download source code of Microsoft-CLAP from their [official repo](https://github.com/microsoft/CLAP). Rename the folder to ```my_ms_clap/``` and copy the folder to under each of ```foundation/, chat/, inference/```. In each of these, replace the ```my_ms_clap/msclap/CLAPWrapper.py``` with ```clap_modified_code/CLAPWrapper.py```, which adds some processing functions and removes some bugs for clapcap. Download their pretrained checkpoints to ```YOUR_DATA_ROOT_DIR/audio-flamingo-data/clap/```.
|
| 28 |
+
- Download raw training and evaluation datasets from their original sources. Refer to ```foundation/data/README.md``` and ```chat/data/README.md``` for specific instructions to prepare data.
|
| 29 |
+
|
| 30 |
+
## Running the Code
|
| 31 |
+
|
| 32 |
+
We refer to ```foundation/README.md```, ```chat/README.md```, and ```inference/README.md``` for the specific instructions for training the foundation model, training the chat model, and inferencing, as they require different setups. We used 8 A100 GPUs to train our models.
|
| 33 |
+
|
| 34 |
+
## Checkpoints
|
| 35 |
+
- The folder ```checkpoints/``` contains foundation and chat model checkpoints.
|
| 36 |
+
- Each model is about 17GB. Due to ```git lfs``` constraints we split each model into 5 parts. After downloading, go to ```checkpoints/``` and ```python checkpoint_utils.py``` to merge the parts.
|
| 37 |
+
- Alternatively, the model checkpoints are also on HuggingFace (which is easier to download): [https://huggingface.co/nvidia/audio-flamingo](https://huggingface.co/nvidia/audio-flamingo). One can either ```git clone``` this project or use the ```huggingface_hub.hf_hub_download``` function to download: ```checkpoint_path = hf_hub_download(repo_id="nvidia/audio-flamingo", filename="foundation(or chat).pt")```.
|
| 38 |
+
- If you would like to run inference with these checkpoints, remember to modify the absolute paths in ```inference/configs/*.yaml``` and ```inference/inference_examples.py``` to properly load model checkpoints and data (see ```inference/README.md```).
|
| 39 |
+
- The foundation model is pretrained with ```foundation/configs/foundation_pretrain.yaml``` and then finetuned with ```foundation/configs/foundation_sft_8_shot.yaml```.
|
| 40 |
+
- The chat model is pretrained with ```foundation/configs/foundation_pretrain.yaml```, then finetuned with ```foundation/configs/foundation_sft_4_shot.yaml```, and finally finetuned with ```chat/configs/chat.yaml```.
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## Downstream applications
|
| 44 |
+
- We use Audio Flamingo as a data labeling machine for synthetic captions. See ```labeling_machine/``` for details of the synthetic dataset and license descriptions.
|
| 45 |
+
|
| 46 |
+
## References
|
| 47 |
+
|
| 48 |
+
The main training and inferencing code within each folder (```foundation/```, ```chat/```, ```inference/```), including ```train/```, ```src/```, ```data/```, and ```configs/```, are modified from [Open Flamingo](https://github.com/mlfoundations/open_flamingo) (commit ```a05dcba```) (MIT license), which borrows from [flamingo-pytorch](https://github.com/lucidrains/flamingo-pytorch) (MIT license), [flamingo-mini](https://github.com/dhansmair/flamingo-mini) (MIT license), and [open_clip](https://github.com/mlfoundations/open_clip) (MIT license). ```src/helpers.py``` also includes self-attention implementations based on [attention-is-all-you-need-pytorch](https://github.com/jadore801120/attention-is-all-you-need-pytorch) (MIT license), which borrows from [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py) (MIT license). Our code also relies on [LAION-AI/CLAP](https://github.com/LAION-AI/CLAP) (CC0-1.0 license) and [microsoft/CLAP](https://github.com/microsoft/CLAP) (MIT license). In ```chat/data/prepare_each_dataset.py```, the filtering keywords are based on the [LLARK](https://arxiv.org/abs/2310.07160) paper (CC-BY-4.0 license) and the [LTU](https://arxiv.org/abs/2305.10790) paper (CC-BY-4.0 license).
|
| 49 |
+
|
| 50 |
+
## License
|
| 51 |
+
|
| 52 |
+
- The code in this repo is under MIT license (see ```LICENSE```).
|
| 53 |
+
- The checkpoints in this repo (```checkpoints/*.pt```) are for non-commercial use only. They are subject to the [OPT-IML](https://huggingface.co/facebook/opt-iml-1.3b/blob/main/LICENSE.md) license, the [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and the original licenses accompanying each training dataset.
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
## Citation
|
| 57 |
+
```
|
| 58 |
+
@article{kong2024audio,
|
| 59 |
+
title={Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities},
|
| 60 |
+
author={Kong, Zhifeng and Goel, Arushi and Badlani, Rohan and Ping, Wei and Valle, Rafael and Catanzaro, Bryan},
|
| 61 |
+
journal={arXiv preprint arXiv:2402.01831},
|
| 62 |
+
year={2024}
|
| 63 |
+
}
|
| 64 |
+
```
|
models/audio-flamingo-1/assets/AudioFlamingo_ICML2024_poster.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f63fccc267408123e6119d0293ff81a6dfbe6979293d451ac14a2a3cc9abe98
|
| 3 |
+
size 1170996
|
models/audio-flamingo-1/assets/audio_flamingo_arch.png
ADDED
|
Git LFS Details
|
models/audio-flamingo-1/audio flamingo model card.md
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model Overview
|
| 2 |
+
|
| 3 |
+
## Description:
|
| 4 |
+
Audio Flamingo is a novel audio-understanding language model for
|
| 5 |
+
|
| 6 |
+
- understanding audio,
|
| 7 |
+
- quickly adapting to unseen tasks via in-context learning and retrieval, and
|
| 8 |
+
- understanding and responding to multi-turn dialogues
|
| 9 |
+
|
| 10 |
+
We introduce a series of training techniques, architecture design, and data strategies to enhance our model with these abilities. Extensive evaluations across various audio understanding tasks confirm the efficacy of our method, setting new state-of-the-art benchmarks.
|
| 11 |
+
|
| 12 |
+
<center><img src="https://github.com/NVIDIA/audio-flamingo/raw/main/assets/audio_flamingo_arch.png" width="800"></center>
|
| 13 |
+
|
| 14 |
+
**This model is ready for non-commercial research-only.**
|
| 15 |
+
<br>
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
## References(s):
|
| 19 |
+
* [Audio Flamingo: A Novel Audio Language Model with Few-Shot Learning and Dialogue Abilities](https://arxiv.org/abs/2402.01831) <br>
|
| 20 |
+
* [Project Page](https://github.com/NVIDIA/audio-flamingo) <br>
|
| 21 |
+
* [Demo Website](https://audioflamingo.github.io/) <br>
|
| 22 |
+
|
| 23 |
+
## Model Architecture:
|
| 24 |
+
**Architecture Type:** Transformer <br>
|
| 25 |
+
**Network Architecture:** Audio Flamingo
|
| 26 |
+
|
| 27 |
+
Audio Flamingo is a Flamingo-style architecture with frozen audio feature extractor, trainable transformation layers and xattn-dense layers, and language model layers.
|
| 28 |
+
|
| 29 |
+
## Input:
|
| 30 |
+
**Input Types:** Audio, Text <br>
|
| 31 |
+
**Input Format:** Wav/MP3/Flac, String <br>
|
| 32 |
+
**Input Parameters:** None <br>
|
| 33 |
+
**Maximum Audio Input Lengths:** 33.25 seconds <br>
|
| 34 |
+
**Maximum Text Input Lengths:** 512 tokens <br>
|
| 35 |
+
|
| 36 |
+
## Output:
|
| 37 |
+
**Output Type:** Text <br>
|
| 38 |
+
**Output Format:** String <br>
|
| 39 |
+
**Output Parameters:** None <br>
|
| 40 |
+
|
| 41 |
+
## Software Integration:
|
| 42 |
+
**Runtime Engine(s):** PyTorch
|
| 43 |
+
|
| 44 |
+
**Supported Hardware Microarchitecture Compatibility:**
|
| 45 |
+
* NVIDIA Ampere <br>
|
| 46 |
+
* NVIDIA Hopper <br>
|
| 47 |
+
|
| 48 |
+
## Preferred/Supported Operating System(s):
|
| 49 |
+
* Linux
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
## Model Version(s):
|
| 53 |
+
* v1.0
|
| 54 |
+
|
| 55 |
+
## Training, Testing, and Evaluation Datasets:
|
| 56 |
+
|
| 57 |
+
### Training Dataset:
|
| 58 |
+
Audio Flamingo is trained with **publicly available** datasets under various licenses, with the most restricted ones being non-commercial/research-only. The dataset contains diverse audio types including speech, environmental sounds, and music.
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
* [OpenAQA ](https://github.com/YuanGongND/ltu?tab=readme-ov-file): Data collection method - [Human]; Labeling method - [Synthetic]
|
| 62 |
+
* [Laion630K ](https://github.com/LAION-AI/audio-dataset/blob/main/laion-audio-630k/README.md)
|
| 63 |
+
* [LP-MusicCaps ](https://github.com/seungheondoh/lp-music-caps)
|
| 64 |
+
* [SoundDescs ](https://github.com/akoepke/audio-retrieval-benchmark)
|
| 65 |
+
* [WavCaps](https://github.com/XinhaoMei/WavCaps)
|
| 66 |
+
* [AudioSet ](https://research.google.com/audioset/download.html)
|
| 67 |
+
* [AudioSet Strong Labeled ](https://research.google.com/audioset/download_strong.html)
|
| 68 |
+
* [WavText5K ](https://github.com/microsoft/WavText5K)
|
| 69 |
+
* [MSP-Podcast ](https://ecs.utdallas.edu/research/researchlabs/msp-lab/MSP-Podcast.html)
|
| 70 |
+
* [ClothoAQA ](https://zenodo.org/records/6473207)
|
| 71 |
+
* [Clotho-v2 ](https://github.com/audio-captioning/clotho-dataset/tree/master)
|
| 72 |
+
* [MACS ](https://zenodo.org/records/5114771)
|
| 73 |
+
* [FSD50k ](https://zenodo.org/records/4060432)
|
| 74 |
+
* [CochlScene ](https://github.com/cochlearai/cochlscene)
|
| 75 |
+
* [NonSpeech 7k ](https://zenodo.org/records/6967442)
|
| 76 |
+
* [Chime-home ](https://code.soundsoftware.ac.uk/projects/chime-home-dataset-annotation-and-baseline-evaluation-code)
|
| 77 |
+
* [Sonyc-UST ](https://zenodo.org/records/3966543)
|
| 78 |
+
* [Emov-DB ](https://github.com/numediart/EmoV-DB)
|
| 79 |
+
* [JL-Corpus ](https://github.com/tli725/JL-Corpus)
|
| 80 |
+
* [Tess ](https://www.kaggle.com/datasets/ejlok1/toronto-emotional-speech-set-tess)
|
| 81 |
+
* [OMGEmotion ](https://github.com/knowledgetechnologyuhh/OMGEmotionChallenge)
|
| 82 |
+
* [MELD ](https://github.com/declare-lab/MELD)
|
| 83 |
+
* [MusicAVQA ](https://gewu-lab.github.io/MUSIC-AVQA/)
|
| 84 |
+
* [MusicQA ](https://github.com/shansongliu/MU-LLaMA?tab=readme-ov-file)
|
| 85 |
+
* [MusicCaps ](https://www.kaggle.com/datasets/googleai/musiccaps)
|
| 86 |
+
* [NSynth ](https://magenta.tensorflow.org/datasets/nsynth)
|
| 87 |
+
* [MTG-Jamendo ](https://github.com/MTG/mtg-jamendo-dataset)
|
| 88 |
+
* [MusDB-HQ ](https://zenodo.org/records/3338373)
|
| 89 |
+
* [FMA ](https://github.com/mdeff/fma)
|
| 90 |
+
|
| 91 |
+
For all of these datasets, the data collection method is [human]. For OpenAQA, Laion630k, LP-MusicCaps, WavCaps, MusicQA, the data labeling method is [synthetic]. For the rest, the data labeling method is [human].
|
| 92 |
+
|
| 93 |
+
### Evaluating Dataset:
|
| 94 |
+
Audio Flamingo is evaluated on the test split of the following datasets.
|
| 95 |
+
|
| 96 |
+
* [ClothoAQA ](https://zenodo.org/records/6473207)
|
| 97 |
+
* [MusicAVQA ](https://gewu-lab.github.io/MUSIC-AVQA/)
|
| 98 |
+
* [Clotho-v2 ](https://github.com/audio-captioning/clotho-dataset/tree/master)
|
| 99 |
+
* [FSD50k ](https://zenodo.org/records/4060432)
|
| 100 |
+
* [CochlScene ](https://github.com/cochlearai/cochlscene)
|
| 101 |
+
* [NonSpeech 7k ](https://zenodo.org/records/6967442)
|
| 102 |
+
* [NSynth ](https://magenta.tensorflow.org/datasets/nsynth)
|
| 103 |
+
* [AudioCaps ](https://github.com/cdjkim/audiocaps)
|
| 104 |
+
* [CREMA-D ](https://github.com/CheyneyComputerScience/CREMA-D)
|
| 105 |
+
* [Ravdess ](https://zenodo.org/records/1188976)
|
| 106 |
+
* [US8K ](https://urbansounddataset.weebly.com/urbansound8k.html)
|
| 107 |
+
* [GTZAN ](https://www.tensorflow.org/datasets/catalog/gtzan)
|
| 108 |
+
* [Medley-solos-DB ](https://zenodo.org/records/3464194)
|
| 109 |
+
|
| 110 |
+
For all of these datasets, the data collection method is [human] and the data labeling method is [human].
|
| 111 |
+
|
| 112 |
+
## Inference
|
| 113 |
+
|
| 114 |
+
**Engine:** HuggingFace Transformers <br>
|
| 115 |
+
**Test Hardware [Name the specific test hardware model]:** A100 80GB <br>
|
models/audio-flamingo-1/chat/README.md
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Audio Flamingo Training (Chat Model)
|
| 2 |
+
|
| 3 |
+
## Get data ready
|
| 4 |
+
|
| 5 |
+
Please read ```data/README.md``` for instructions on data preparation.
|
| 6 |
+
|
| 7 |
+
## Get paths ready
|
| 8 |
+
|
| 9 |
+
Let ```YOUR_REPO_ROOT_DIR``` be the absolute path to this repo. We use the following structure
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
YOUR_REPO_ROOT_DIR/
|
| 13 |
+
- foundation/
|
| 14 |
+
- chat/ # you are here
|
| 15 |
+
- inference/
|
| 16 |
+
```
|
| 17 |
+
|
| 18 |
+
Replace ```YOUR_REPO_ROOT_DIR``` to your absolute path in the following places:
|
| 19 |
+
- ```configs/*.yaml --> clap_config --> config_root```
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
Let ```YOUR_DATA_ROOT_DIR``` be the absolute path to store all data, checkpoints, etc. We use the following structure
|
| 23 |
+
```
|
| 24 |
+
YOUR_DATA_ROOT_DIR/
|
| 25 |
+
- datasets/
|
| 26 |
+
- <dataset_name_i>/
|
| 27 |
+
- files: raw data of this dataset, including raw waveforms, metadata, etc.
|
| 28 |
+
|
| 29 |
+
- audio-flamingo-data/
|
| 30 |
+
- dataset_files/
|
| 31 |
+
- <dataset_name_i>-<flamingo_task_i>/
|
| 32 |
+
- files: dataset manifests, precomputed embeddings, etc.
|
| 33 |
+
|
| 34 |
+
- checkpoint/
|
| 35 |
+
- <experiment_name>/ # same as the config file name, and train_config --> run_name in each config
|
| 36 |
+
- tensorboard/
|
| 37 |
+
- checkpoint_xxx.pt
|
| 38 |
+
- other cached files
|
| 39 |
+
|
| 40 |
+
- clap/
|
| 41 |
+
- files: pretrained Microsoft-CLAP checkpoints
|
| 42 |
+
|
| 43 |
+
- laion-clap-pretrained/laion_clap
|
| 44 |
+
- files: pretrained Laion-CLAP checkpoints
|
| 45 |
+
|
| 46 |
+
- LLM_pretrained/.cache/ # place to store HuggingFace cache instead of the default ~/.cache
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
Replace ```YOUR_DATA_ROOT_DIR``` to your absolute path in the following places:
|
| 50 |
+
- ```configs/*.yaml```
|
| 51 |
+
- ```prepare_each_dataset.py --> __main__```
|
| 52 |
+
|
| 53 |
+
## Training
|
| 54 |
+
|
| 55 |
+
The following code is tested on 1 node (8 GPUs per node) of A100 (80G) GPUs.
|
| 56 |
+
|
| 57 |
+
Set ```configs/chat.yaml --> sft_config --> pretrained_path``` and ```pretrained_ckpt``` to be the checkpoint of the pretrained model.
|
| 58 |
+
```
|
| 59 |
+
export NCCL_IB_SL=1
|
| 60 |
+
export CUDA_DEVICE_MAX_CONNECTIONS=1
|
| 61 |
+
cd train/
|
| 62 |
+
torchrun --nproc_per_node 8 train.py -c ../configs/chat.yaml
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
|
models/audio-flamingo-1/chat/clap_modified_code/CLAPWrapper.py
ADDED
|
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/microsoft/CLAP under the MIT license.
|
| 5 |
+
# LICENSE is in incl_licenses directory.
|
| 6 |
+
|
| 7 |
+
import warnings
|
| 8 |
+
warnings.filterwarnings("ignore")
|
| 9 |
+
import random
|
| 10 |
+
import torchaudio
|
| 11 |
+
# from torch._six import string_classes
|
| 12 |
+
import collections
|
| 13 |
+
import re
|
| 14 |
+
import numpy as np
|
| 15 |
+
from transformers import AutoTokenizer, logging
|
| 16 |
+
try:
|
| 17 |
+
from models.clap import CLAP
|
| 18 |
+
from models.mapper import get_clapcap
|
| 19 |
+
except:
|
| 20 |
+
from .models.clap import CLAP
|
| 21 |
+
from .models.mapper import get_clapcap
|
| 22 |
+
import math
|
| 23 |
+
import torchaudio.transforms as T
|
| 24 |
+
import os
|
| 25 |
+
import torch
|
| 26 |
+
from importlib_resources import files
|
| 27 |
+
import argparse
|
| 28 |
+
import yaml
|
| 29 |
+
import sys
|
| 30 |
+
logging.set_verbosity_error()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class CLAPWrapper():
|
| 34 |
+
"""
|
| 35 |
+
A class for interfacing CLAP model.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
    def __init__(self, model_fp, config_root, version, use_cuda=False):
        """Wrap a pretrained CLAP (or ClapCap) checkpoint for inference.

        Args:
            model_fp: path to the pretrained checkpoint file.
            config_root: directory containing the ``config_<version>.yml`` files.
            version: one of '2022', '2023', 'clapcap'.
            use_cuda: move the model to GPU when one is available.
        """
        self.supported_versions = ['2022', '2023', 'clapcap']
        # Matches numpy string/object dtype codes that default_collate rejects.
        self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
        self.file_path = os.path.realpath(__file__)
        self.default_collate_err_msg_format = (
            "default_collate: batch must contain tensors, numpy arrays, numbers, "
            "dicts or lists; found {}")
        self.config_root = config_root
        # Despite the name, this holds a config file *path* (see get_config_path).
        self.config_as_str = self.get_config_path(version)
        self.model_fp = model_fp
        self.use_cuda = use_cuda
        self.version = version
        # 'clapcap' loads the captioning model; other versions load plain CLAP.
        if 'clapcap' in self.version:
            self.clapcap, self.tokenizer, self.args = self.load_clapcap()
        else:
            self.clap, self.tokenizer, self.args = self.load_clap()
|
| 54 |
+
|
| 55 |
+
def get_config_path(self, version):
|
| 56 |
+
if version in self.supported_versions:
|
| 57 |
+
return f"{self.config_root}/config_{version}.yml"
|
| 58 |
+
else:
|
| 59 |
+
raise ValueError(f"The specific version is not supported. The supported versions are {str(self.supported_versions)}")
|
| 60 |
+
|
| 61 |
+
def read_config_as_args(self,config_path,args=None,is_config_str=False):
|
| 62 |
+
return_dict = {}
|
| 63 |
+
|
| 64 |
+
if config_path is not None:
|
| 65 |
+
if is_config_str:
|
| 66 |
+
yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
|
| 67 |
+
else:
|
| 68 |
+
with open(config_path, "r") as f:
|
| 69 |
+
yml_config = yaml.load(f, Loader=yaml.FullLoader)
|
| 70 |
+
|
| 71 |
+
if args != None:
|
| 72 |
+
for k, v in yml_config.items():
|
| 73 |
+
if k in args.__dict__:
|
| 74 |
+
args.__dict__[k] = v
|
| 75 |
+
else:
|
| 76 |
+
sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
|
| 77 |
+
else:
|
| 78 |
+
for k, v in yml_config.items():
|
| 79 |
+
return_dict[k] = v
|
| 80 |
+
|
| 81 |
+
args = args if args != None else return_dict
|
| 82 |
+
return argparse.Namespace(**args)
|
| 83 |
+
|
| 84 |
+
    def load_clap(self):
        r"""Load CLAP model with args from config file.

        Builds the CLAP dual encoder from the version config, loads the
        checkpoint at ``self.model_fp``, and returns ``(model, tokenizer, args)``.
        """

        args = self.read_config_as_args(self.config_as_str, is_config_str=False)

        # Token fields differ per text backbone: BERT also emits token_type_ids.
        if 'roberta' in args.text_model or 'clip' in args.text_model or 'gpt' in args.text_model:
            self.token_keys = ['input_ids', 'attention_mask']
        elif 'bert' in args.text_model:
            self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']

        clap = CLAP(
            audioenc_name=args.audioenc_name,
            sample_rate=args.sampling_rate,
            window_size=args.window_size,
            hop_size=args.hop_size,
            mel_bins=args.mel_bins,
            fmin=args.fmin,
            fmax=args.fmax,
            classes_num=args.num_classes,
            out_emb=args.out_emb,
            text_model=args.text_model,
            transformer_embed_dim=args.transformer_embed_dim,
            d_proj=args.d_proj
        )

        # Load pretrained weights for model (checkpoint stores them under 'model').
        model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']

        # We unwrap the DDP model and save. If the model is not unwrapped and saved, then the model needs to unwrapped before `load_state_dict`:
        # Reference link: https://discuss.pytorch.org/t/how-to-load-dataparallel-model-which-trained-using-multiple-gpus/146005
        clap.load_state_dict(model_state_dict)

        clap.eval()  # set clap in eval mode
        tokenizer = AutoTokenizer.from_pretrained(args.text_model)
        if 'gpt' in args.text_model:
            # GPT-2 tokenizers ship without a pad token; '!' is reused as one.
            tokenizer.add_special_tokens({'pad_token': '!'})

        if self.use_cuda and torch.cuda.is_available():
            clap = clap.cuda()

        return clap, tokenizer, args
|
| 125 |
+
|
| 126 |
+
    def load_clapcap(self):
        r"""Load the ClapCap captioning model with args from the config file.

        Returns ``(clapcap, tokenizer, args)`` where ``clapcap`` wraps a CLAP
        audio encoder plus a GPT-style text decoder.
        """

        args = self.read_config_as_args(self.config_as_str, is_config_str=False)
        args.prefix_dim = args.d_proj
        # The CLAP encoder keeps the original text_model; args.text_model is
        # repointed at the decoder so tokenizer/token_keys follow the decoder.
        text_model = args.text_model
        args.text_model = args.text_decoder
        args.cross_attention = True if 'cross' in args.clapcap_model.lower() else False

        if 'roberta' in args.text_model or 'clip' in args.text_model or 'gpt' in args.text_model:
            self.token_keys = ['input_ids', 'attention_mask']
        elif 'bert' in args.text_model:
            self.token_keys = ['input_ids', 'token_type_ids', 'attention_mask']

        clap = CLAP(
            audioenc_name=args.audioenc_name,
            sample_rate=args.sampling_rate,
            window_size=args.window_size,
            hop_size=args.hop_size,
            mel_bins=args.mel_bins,
            fmin=args.fmin,
            fmax=args.fmax,
            classes_num=args.num_classes,
            out_emb=args.out_emb,
            text_model=text_model,
            transformer_embed_dim=args.transformer_embed_dim,
            d_proj=args.d_proj
        )

        # Build the captioning wrapper around the CLAP encoder.
        clapcap = get_clapcap(args.clapcap_model)(clap, args.text_decoder, args.prefix_length, args.prefix_length_clip, args.prefix_dim,
                                                  args.num_layers, args.normalize_prefix, args.mapping_type, True, True)

        model_state_dict = torch.load(self.model_fp, map_location=torch.device('cpu'))['model']
        clapcap.load_state_dict(model_state_dict)

        clapcap.eval()  # set clap in eval mode
        tokenizer = AutoTokenizer.from_pretrained(args.text_model)
        if 'gpt' in args.text_model:
            # GPT-2 tokenizers ship without a pad token; '!' is reused as one.
            tokenizer.add_special_tokens({'pad_token': '!'})

        if self.use_cuda and torch.cuda.is_available():
            clapcap = clapcap.cuda()

        return clapcap, tokenizer, args
|
| 170 |
+
|
| 171 |
+
    def default_collate(self, batch):
        r"""Puts each data field into a tensor with outer dimension batch size.

        Recursive re-implementation of torch's ``default_collate``: stacks
        tensors, converts numpy/scalars, and recurses into mappings,
        namedtuples and sequences. Raises TypeError for unsupported element
        types (see ``self.default_collate_err_msg_format``).
        """
        elem = batch[0]
        elem_type = type(elem)
        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                numel = sum([x.numel() for x in batch])
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage)
            return torch.stack(batch, 0, out=out)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
                # array of string classes and object dtypes cannot be collated
                if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                    raise TypeError(
                        self.default_collate_err_msg_format.format(elem.dtype))

                return self.default_collate([torch.as_tensor(b) for b in batch])
            elif elem.shape == ():  # scalars
                return torch.as_tensor(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float64)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        # elif isinstance(elem, string_classes):
        #     return batch
        elif isinstance(elem, collections.abc.Mapping):
            # Collate each dict field independently, keyed like the elements.
            return {key: self.default_collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
            return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
        elif isinstance(elem, collections.abc.Sequence):
            # check to make sure that the elements in batch have consistent size
            it = iter(batch)
            elem_size = len(next(it))
            if not all(len(elem) == elem_size for elem in it):
                raise RuntimeError(
                    'each element in list of batch should be of equal size')
            transposed = zip(*batch)
            return [self.default_collate(samples) for samples in transposed]

        raise TypeError(self.default_collate_err_msg_format.format(elem_type))
|
| 216 |
+
|
| 217 |
+
def read_audio(self, audio_path, resample=False):
|
| 218 |
+
r"""Loads audio file or array and returns a torch tensor"""
|
| 219 |
+
# Randomly sample a segment of audio_duration from the clip or pad to match duration
|
| 220 |
+
audio_time_series, sample_rate = torchaudio.load(audio_path)
|
| 221 |
+
|
| 222 |
+
resample_rate = self.args.sampling_rate
|
| 223 |
+
if resample:
|
| 224 |
+
resampler = T.Resample(sample_rate, resample_rate)
|
| 225 |
+
audio_time_series = resampler(audio_time_series)
|
| 226 |
+
return audio_time_series, sample_rate
|
| 227 |
+
|
| 228 |
+
def load_audio_into_tensor(self, audio_path, audio_duration, resample=False):
|
| 229 |
+
r"""Loads audio file and returns raw audio."""
|
| 230 |
+
# Randomly sample a segment of audio_duration from the clip or pad to match duration
|
| 231 |
+
audio_time_series, sample_rate = self.read_audio(audio_path, resample=False)
|
| 232 |
+
audio_time_series = audio_time_series.reshape(-1)
|
| 233 |
+
|
| 234 |
+
# audio_time_series is shorter than predefined audio duration,
|
| 235 |
+
# so audio_time_series is extended
|
| 236 |
+
if audio_duration*sample_rate >= audio_time_series.shape[0]:
|
| 237 |
+
repeat_factor = int(np.ceil((audio_duration*sample_rate) /
|
| 238 |
+
audio_time_series.shape[0]))
|
| 239 |
+
# Repeat audio_time_series by repeat_factor to match audio_duration
|
| 240 |
+
audio_time_series = audio_time_series.repeat(repeat_factor)
|
| 241 |
+
# remove excess part of audio_time_series
|
| 242 |
+
audio_time_series = audio_time_series[0:audio_duration*sample_rate]
|
| 243 |
+
else:
|
| 244 |
+
# audio_time_series is longer than predefined audio duration,
|
| 245 |
+
# so audio_time_series is trimmed
|
| 246 |
+
start_index = random.randrange(
|
| 247 |
+
audio_time_series.shape[0] - audio_duration*sample_rate)
|
| 248 |
+
audio_time_series = audio_time_series[start_index:start_index +
|
| 249 |
+
audio_duration*sample_rate]
|
| 250 |
+
return torch.FloatTensor(audio_time_series)
|
| 251 |
+
|
| 252 |
+
# Modified
|
| 253 |
+
def load_audio_clip_into_tensor(self, audio_clip, audio_duration, resample=False):
|
| 254 |
+
r"""Loads audio clip and returns raw audio."""
|
| 255 |
+
# Randomly sample a segment of audio_duration from the clip or pad to match duration
|
| 256 |
+
sample_rate = 44100
|
| 257 |
+
audio_time_series = audio_clip.reshape(-1)
|
| 258 |
+
|
| 259 |
+
# audio_time_series is shorter than predefined audio duration,
|
| 260 |
+
# so audio_time_series is extended
|
| 261 |
+
assert audio_duration * sample_rate >= audio_time_series.shape[0], \
|
| 262 |
+
'dur * sr = {} should be larger than len = {}'.format(audio_duration * sample_rate, audio_time_series.shape[0])
|
| 263 |
+
repeat_factor = int(np.ceil((audio_duration*sample_rate) /
|
| 264 |
+
audio_time_series.shape[0]))
|
| 265 |
+
# Repeat audio_time_series by repeat_factor to match audio_duration
|
| 266 |
+
audio_time_series = audio_time_series.repeat(repeat_factor)
|
| 267 |
+
# remove excess part of audio_time_series
|
| 268 |
+
audio_time_series = audio_time_series[0:audio_duration*sample_rate]
|
| 269 |
+
|
| 270 |
+
# return torch.FloatTensor(audio_time_series)
|
| 271 |
+
return audio_time_series # already on cuda device
|
| 272 |
+
|
| 273 |
+
    def preprocess_audio(self, audio_files, resample):
        r"""Load a list of audio files and collate them into one raw-audio batch."""
        audio_tensors = []
        for audio_file in audio_files:
            audio_tensor = self.load_audio_into_tensor(
                audio_file, self.args.duration, resample)
            # Add a leading channel dim; move to GPU when requested and available.
            audio_tensor = audio_tensor.reshape(
                1, -1).cuda() if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1)
            audio_tensors.append(audio_tensor)
        return self.default_collate(audio_tensors)
|
| 283 |
+
|
| 284 |
+
# Modified
|
| 285 |
+
    # Modified
    def preprocess_audio_clips(self, audio_clips, resample=False):
        r"""Collate a list of in-memory audio clips into one raw-audio batch.

        NOTE(review): the ``resample`` parameter is accepted but never
        honored — the inner call hard-codes ``resample=False``; confirm intent.
        """
        audio_tensors = []
        for audio_clip in audio_clips:
            audio_tensor = self.load_audio_clip_into_tensor(
                audio_clip, self.args.duration, resample=False)
            # Add a leading channel dim; move to GPU when requested and available.
            audio_tensor = audio_tensor.reshape(
                1, -1).cuda() if self.use_cuda and torch.cuda.is_available() else audio_tensor.reshape(1, -1)
            audio_tensors.append(audio_tensor)
        return self.default_collate(audio_tensors)
|
| 295 |
+
|
| 296 |
+
    def preprocess_text(self, text_queries):
        r"""Tokenize a list of text queries and collate them into one batch."""
        tokenized_texts = []
        for ttext in text_queries:
            if 'gpt' in self.args.text_model:
                # GPT decoders need an explicit end-of-text marker appended.
                ttext = ttext + ' <|endoftext|>'
            tok = self.tokenizer.encode_plus(
                text=ttext, add_special_tokens=True, max_length=self.args.text_len, padding='max_length', return_tensors="pt")
            # Flatten each token field (set in load_clap/load_clapcap) and
            # optionally move it to the GPU.
            for key in self.token_keys:
                tok[key] = tok[key].reshape(-1).cuda() if self.use_cuda and torch.cuda.is_available() else tok[key].reshape(-1)
            tokenized_texts.append(tok)
        return self.default_collate(tokenized_texts)
|
| 308 |
+
|
| 309 |
+
    def get_text_embeddings(self, class_labels):
        r"""Tokenize a list of class labels / captions and return their text embeddings."""
        preprocessed_text = self.preprocess_text(class_labels)
        return self._get_text_embeddings(preprocessed_text)
|
| 313 |
+
|
| 314 |
+
    def get_audio_embeddings(self, audio_files, resample):
        r"""Load a list of audio files and return their audio embeddings."""
        preprocessed_audio = self.preprocess_audio(audio_files, resample)
        return self._get_audio_embeddings(preprocessed_audio)
|
| 318 |
+
|
| 319 |
+
# Modified
|
| 320 |
+
    # Modified
    def get_audio_embeddings_from_clips(self, audio_clips, resample=False):
        r"""Embed in-memory audio clips (tensors) instead of files on disk."""
        preprocessed_audio = self.preprocess_audio_clips(audio_clips, resample)
        return self._get_audio_embeddings(preprocessed_audio)
|
| 324 |
+
|
| 325 |
+
    def _get_text_embeddings(self, preprocessed_text):
        r"""Run the caption encoder on pre-tokenized text (no gradients)."""
        with torch.no_grad():
            return self.clap.caption_encoder(preprocessed_text)
|
| 329 |
+
|
| 330 |
+
# Modified
|
| 331 |
+
    # Modified
    def _get_audio_embeddings(self, preprocessed_audio):
        r"""Run the audio encoder on a preprocessed batch (no gradients)."""
        with torch.no_grad():
            # Drop the per-clip channel dim added in preprocess_audio*:
            # (batch, 1, samples) -> (batch, samples).
            preprocessed_audio = preprocessed_audio.reshape(
                preprocessed_audio.shape[0], preprocessed_audio.shape[2])
            # Encoder returns a tuple: [0] audio embedding, [1] class probabilities.
            if 'clapcap' in self.version:
                return self.clapcap.clap(preprocessed_audio)[0]
            else:
                return self.clap.audio_encoder(preprocessed_audio)[0]
|
| 341 |
+
|
| 342 |
+
def _generic_batch_inference(self, func, *args):
|
| 343 |
+
r"""Process audio and/or text per batch"""
|
| 344 |
+
input_tmp = args[0]
|
| 345 |
+
batch_size = args[-1]
|
| 346 |
+
# args[0] has audio_files, args[1] has class_labels
|
| 347 |
+
inputs = [args[0], args[1]] if len(args) == 3 else [args[0]]
|
| 348 |
+
args0_len = len(args[0])
|
| 349 |
+
# compute text_embeddings once for all the audio_files batches
|
| 350 |
+
if len(inputs) == 2:
|
| 351 |
+
text_embeddings = self.get_text_embeddings(args[1])
|
| 352 |
+
inputs = [args[0], args[1], text_embeddings]
|
| 353 |
+
dataset_idx = 0
|
| 354 |
+
for _ in range(math.ceil(args0_len/batch_size)):
|
| 355 |
+
next_batch_idx = dataset_idx + batch_size
|
| 356 |
+
# batch size is bigger than available audio/text items
|
| 357 |
+
if next_batch_idx >= args0_len:
|
| 358 |
+
inputs[0] = input_tmp[dataset_idx:]
|
| 359 |
+
return func(*tuple(inputs))
|
| 360 |
+
else:
|
| 361 |
+
inputs[0] = input_tmp[dataset_idx:next_batch_idx]
|
| 362 |
+
yield func(*tuple(inputs))
|
| 363 |
+
dataset_idx = next_batch_idx
|
| 364 |
+
|
| 365 |
+
    def get_audio_embeddings_per_batch(self, audio_files, batch_size):
        r"""Yield audio embeddings batch-by-batch (generator; see _generic_batch_inference)."""
        return self._generic_batch_inference(self.get_audio_embeddings, audio_files, batch_size)
|
| 368 |
+
|
| 369 |
+
    def get_text_embeddings_per_batch(self, class_labels, batch_size):
        r"""Yield text embeddings batch-by-batch (generator; see _generic_batch_inference)."""
        return self._generic_batch_inference(self.get_text_embeddings, class_labels, batch_size)
|
| 372 |
+
|
| 373 |
+
def compute_similarity(self, audio_embeddings, text_embeddings):
|
| 374 |
+
r"""Compute similarity between text and audio embeddings"""
|
| 375 |
+
audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
|
| 376 |
+
text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)
|
| 377 |
+
|
| 378 |
+
logit_scale = self.clap.logit_scale.exp()
|
| 379 |
+
similarity = logit_scale*text_embeddings @ audio_embeddings.T
|
| 380 |
+
return similarity.T
|
| 381 |
+
|
| 382 |
+
    def classify_audio_files_per_batch(self, audio_files, class_labels, batch_size):
        r"""Compute classification probabilities for each audio recording in a batch and each class label.

        NOTE(review): ``self.classify_audio_files`` is not defined anywhere in
        this class, so calling this method raises AttributeError — confirm
        whether the method was dropped when this file was adapted.
        """
        return self._generic_batch_inference(self.classify_audio_files, audio_files, class_labels, batch_size)
|
| 385 |
+
|
| 386 |
+
    def generate_caption(self, audio_files, resample=True, beam_size: int = 5, entry_length=67, temperature=1.):
        r"""Generate one caption per audio file via beam-search decoding.

        Requires the 'clapcap' version. Returns a list of capitalized strings.
        """
        captions = []
        audio_tensors = self.preprocess_audio(audio_files, resample)

        with torch.no_grad():
            # Audio embedding becomes the GPT prefix.
            prefix = self.clapcap.clap(audio_tensors.squeeze(1))[0]
            if self.args.normalize_prefix:
                prefix = prefix / prefix.norm(2, -1).reshape(-1, 1)
            # Project into (batch, prefix_length, gpt_embed_dim) for the decoder.
            prefix_embed = self.clapcap.clap_project(prefix).view(-1, self.args.prefix_length, self.clapcap.gpt.transformer.wte.weight.shape[1])

        # Decode each clip independently; _generate_beam returns candidates
        # sorted best-first, so [0] is the top beam.
        for i in range(len(audio_tensors)):
            gen_caption = self._generate_beam(embed=prefix_embed[i].unsqueeze(0),
                                              beam_size=beam_size,
                                              entry_length=entry_length,
                                              temperature=temperature)[0]
            captions.append(gen_caption.capitalize())
        return captions
|
| 404 |
+
|
| 405 |
+
    def _generate_beam(self, beam_size: int = 5, prompt=None, embed=None,
                       entry_length=67, temperature=1., stop_token: str = ' <|endoftext|>'):
        r"""Generate captions by beam search decoding.

        Exactly one of ``prompt`` (text) or ``embed`` (precomputed GPT prefix
        embedding, shape (1, prefix_len, embed_dim)) seeds the decoder.
        Returns ``beam_size`` candidate strings sorted by descending
        length-normalized log-probability.
        """
        self.clapcap.eval()
        stop_token_index = self.tokenizer.encode(stop_token)[0]
        tokens = None
        scores = None
        device = next(self.clapcap.parameters()).device
        # Per-beam generated length (for length normalization) and stop flags.
        seq_lengths = torch.ones(beam_size, device=device)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
        with torch.no_grad():
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(self.tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)
                    generated = self.clapcap.gpt.transformer.wte(tokens)
            for i in range(entry_length):
                outputs = self.clapcap.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                logits = logits.softmax(-1).log()
                if scores is None:
                    # First step: fan the single prefix out into beam_size beams.
                    scores, next_tokens = logits.topk(beam_size, -1)
                    generated = generated.expand(beam_size, *generated.shape[1:])
                    next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                    if tokens is None:
                        tokens = next_tokens
                    else:
                        tokens = tokens.expand(beam_size, *tokens.shape[1:])
                        tokens = torch.cat((tokens, next_tokens), dim=1)
                else:
                    # Finished beams contribute only token 0 at zero cost so
                    # their score is frozen.
                    logits[is_stopped] = -float(np.inf)
                    logits[is_stopped, 0] = 0
                    scores_sum = scores[:, None] + logits
                    seq_lengths[~is_stopped] += 1
                    # Rank candidates by length-normalized cumulative score
                    # over the flattened (beam, vocab) grid.
                    scores_sum_average = scores_sum / seq_lengths[:, None]
                    scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                    # Recover which parent beam each winner came from.
                    next_tokens_source = next_tokens // scores_sum.shape[1]
                    seq_lengths = seq_lengths[next_tokens_source]
                    next_tokens = next_tokens % scores_sum.shape[1]
                    next_tokens = next_tokens.unsqueeze(1)
                    tokens = tokens[next_tokens_source]
                    tokens = torch.cat((tokens, next_tokens), dim=1)
                    generated = generated[next_tokens_source]
                    # Un-normalize back to cumulative log-prob for the next step.
                    scores = scores_sum_average * seq_lengths
                    is_stopped = is_stopped[next_tokens_source]
                next_token_embed = self.clapcap.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
                if is_stopped.all():
                    break
        # Final ranking uses length-normalized scores.
        scores = scores / seq_lengths
        output_list = tokens.cpu().numpy()
        output_texts = [self.tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
        order = scores.argsort(descending=True)
        output_texts = [output_texts[i] for i in order]
        return output_texts
|
models/audio-flamingo-1/chat/configs/chat.yaml
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
train_config:
|
| 2 |
+
expdir: YOUR_DATA_ROOT_DIR/audio-flamingo-data/checkpoint
|
| 3 |
+
run_name: chat
|
| 4 |
+
delete_previous_checkpoint: false
|
| 5 |
+
batch_size: 4
|
| 6 |
+
gradient_accumulation_steps: 4 # global batchsize = 128
|
| 7 |
+
seed: 42
|
| 8 |
+
learning_rate: 0.00002
|
| 9 |
+
lr_scheduler: constant
|
| 10 |
+
loss_multiplier: 1.0
|
| 11 |
+
warmup_steps: 1875
|
| 12 |
+
weight_decay: 0.1
|
| 13 |
+
precision: fp32
|
| 14 |
+
gradient_checkpointing: False
|
| 15 |
+
num_epochs: 1
|
| 16 |
+
offline: false
|
| 17 |
+
freeze_lm_embeddings: false
|
| 18 |
+
logging_steps: 10
|
| 19 |
+
dist_backend: nccl
|
| 20 |
+
dist_url: env://
|
| 21 |
+
no_set_device_rank: false
|
| 22 |
+
fsdp: true
|
| 23 |
+
fsdp_use_orig_params: false # Passed into the FSDP constructor. Enables param_groups and gradient masking for weight_decay. Does not work with OPT.
|
| 24 |
+
fsdp_sharding_strategy: full # full, hybrid
|
| 25 |
+
horovod: false
|
| 26 |
+
|
| 27 |
+
# Chat SFT hparams
|
| 28 |
+
sft_config:
|
| 29 |
+
pretrained_path: YOUR_DATA_ROOT_DIR/audio-flamingo-data/checkpoint/foundation_sft_4_shot/
|
| 30 |
+
pretrained_ckpt: checkpoint_99.pt
|
| 31 |
+
unfreeze_full_lm: true
|
| 32 |
+
|
| 33 |
+
data_config:
|
| 34 |
+
dataset_blending_global_weight: 1.0
|
| 35 |
+
|
| 36 |
+
dataset_blending_config:
|
| 37 |
+
dialog_AudioSetSL-Dialog/train:
|
| 38 |
+
weight: 1.0
|
| 39 |
+
prefix_prob: 1.0
|
| 40 |
+
|
| 41 |
+
dialog_MusicCaps-Dialog/train:
|
| 42 |
+
weight: 5.0
|
| 43 |
+
prefix_prob: 1.0
|
| 44 |
+
|
| 45 |
+
dataset_file_root: YOUR_DATA_ROOT_DIR/audio-flamingo-data/dataset_files
|
| 46 |
+
data_root: YOUR_DATA_ROOT_DIR/datasets
|
| 47 |
+
dataset_blending_output: dataset_blending.json
|
| 48 |
+
max_tokens: 512
|
| 49 |
+
num_workers: 4
|
| 50 |
+
|
| 51 |
+
clap_config:
|
| 52 |
+
# method: laion-clap
|
| 53 |
+
# audio_embed_dim: 512
|
| 54 |
+
# model_name: 630k-fusion-best
|
| 55 |
+
# checkpoint: YOUR_DATA_ROOT_DIR/audio-flamingo-data/laion-clap-pretrained/laion_clap/630k-fusion-best.pt
|
| 56 |
+
|
| 57 |
+
method: microsoft-clap
|
| 58 |
+
audio_embed_dim: 1024
|
| 59 |
+
config_root: YOUR_REPO_ROOT_DIR/chat/my_ms_clap/src/configs
|
| 60 |
+
model_name: 'clapcap'
|
| 61 |
+
checkpoint: YOUR_DATA_ROOT_DIR/audio-flamingo-data/clap/clapcap_weights_2023.pth
|
| 62 |
+
|
| 63 |
+
window_length: 7.0 # seconds
|
| 64 |
+
window_overlap: 5.25 # seconds
|
| 65 |
+
max_num_window: 16 # total = 33.25 seconds
|
| 66 |
+
max_num_fewshot: 4 # number of fewshot samples
|
| 67 |
+
|
| 68 |
+
model_config:
|
| 69 |
+
cache_dir: YOUR_DATA_ROOT_DIR/audio-flamingo-data/LLM_pretrained/.cache
|
| 70 |
+
|
| 71 |
+
lang_encoder_path: facebook/opt-iml-max-1.3b
|
| 72 |
+
tokenizer_path: facebook/opt-iml-max-1.3b
|
| 73 |
+
cross_attn_every_n_layers: 1
|
| 74 |
+
audio_transformer_kwargs: {
|
| 75 |
+
n_head: 8,
|
| 76 |
+
n_layers: 3,
|
| 77 |
+
d_inner: 2048,
|
| 78 |
+
max_num_media: 128, # must be >= max_num_window * num_fewshot_samples (4)
|
| 79 |
+
max_window_per_audio: 16, # must = max_num_window
|
| 80 |
+
}
|
models/audio-flamingo-1/chat/data/README.md
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data Preparation
|
| 2 |
+
|
| 3 |
+
Data preparation and loading is a challenging part in this codebase as complex formats are used. Below are the instructions to prepare dataset manifests.
|
| 4 |
+
|
| 5 |
+
## Step 1: Download raw datasets
|
| 6 |
+
|
| 7 |
+
Download datasets from their original sources, or prepare your own datasets. For simplicity, in this repo, we assume datasets are stored under ```YOUR_DATA_ROOT_DIR/datasets/<dataset_name>```.
|
| 8 |
+
|
| 9 |
+
## Step 2: Prepare dialogues
|
| 10 |
+
|
| 11 |
+
Follow the instructions in Appendix B in our paper to generate dialogues from rich metadata and filter for quality.
|
| 12 |
+
|
| 13 |
+
## Step 3: Prepare raw datasets into manifests
|
| 14 |
+
|
| 15 |
+
- Modify the ```prepare_files()``` function in ```prepare_each_dataset.py``` based on your raw dataset files.
|
| 16 |
+
- For each dataset, this function generates manifests for each split (train/val/test). The manifest is stored under ```YOUR_DATA_ROOT_DIR/audio-flamingo-data/dataset_files/```. The filenames are in the format of ```<dataset_name>-Dialog/train.json```.
|
| 17 |
+
- The ```<dataset_name>``` used in Audio Flamingo can be found in ```configs/*.yaml``` --> data_config --> dataset_blending_config.
|
| 18 |
+
- The structure of manifests can be found within the ```prepare_files()``` function.
|
| 19 |
+
- Usage: ```python prepare_each_dataset.py -d <dataset_name>```.
|
models/audio-flamingo-1/chat/data/data.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
| 5 |
+
# LICENSE is in incl_licenses directory.
|
| 6 |
+
|
| 7 |
+
import functools
|
| 8 |
+
import io
|
| 9 |
+
import json
|
| 10 |
+
import math
|
| 11 |
+
import os
|
| 12 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false" # disable the tokenizer parallelism warning
|
| 13 |
+
import random
|
| 14 |
+
import re
|
| 15 |
+
import string
|
| 16 |
+
import subprocess
|
| 17 |
+
import sys
|
| 18 |
+
import yaml
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
from collections import defaultdict
|
| 23 |
+
from copy import deepcopy
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
from functools import partial
|
| 26 |
+
from pydub import AudioSegment
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torchvision
|
| 31 |
+
from torch.utils.data import DataLoader, Dataset, get_worker_info
|
| 32 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
from transformers import AutoTokenizer
|
| 36 |
+
|
| 37 |
+
import librosa
|
| 38 |
+
import soundfile as sf
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def int16_to_float32(x):
    """Convert int16-valued samples to float32 in roughly [-1, 1]."""
    scaled = x / 32767.0
    return scaled.astype(np.float32)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def float32_to_int16(x):
    """Clip float samples to [-1, 1] and quantize to full-scale int16."""
    clipped = np.clip(x, a_min=-1., a_max=1.)
    return (clipped * 32767.).astype(np.int16)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class DataCollator:
    """Collate (filename, audio, embed-mask, ids, attention) tuples into a batch.

    Audio tensors are stacked as-is; text ids/masks are right-padded to the
    longest sequence in the batch using the tokenizer's pad token.
    """

    def __init__(self, tokenizer):
        # tokenizer supplies pad_token_id for right-padding the text ids
        self.tokenizer = tokenizer

    def __call__(self, batch):
        filenames, audio_clips, audio_embed_mask, input_ids, attention_masks = zip(*batch)

        # every sample's audio tensors share a shape, so they stack directly
        audio_clips = torch.stack(list(audio_clips), dim=0)
        audio_embed_mask = torch.stack(list(audio_embed_mask), dim=0)

        target_len = max(ids.shape[1] for ids in input_ids)

        padded_input_ids = []
        padded_attention_masks = []
        for ids, mask in zip(input_ids, attention_masks):
            shortfall = target_len - ids.shape[1]
            if shortfall > 0:
                pad_ids = torch.full((1, shortfall), self.tokenizer.pad_token_id, dtype=torch.long)
                pad_mask = torch.zeros((1, target_len - mask.shape[1]), dtype=torch.long)
                ids = torch.cat([ids, pad_ids], dim=1)
                mask = torch.cat([mask, pad_mask], dim=1)
            padded_input_ids.append(ids)
            padded_attention_masks.append(mask)

        return dict(
            filenames=filenames,
            audio_clips=audio_clips,
            audio_embed_mask=audio_embed_mask,
            input_ids=torch.cat(padded_input_ids, dim=0),
            attention_mask=torch.cat(padded_attention_masks, dim=0).bool(),
        )
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class AudioTextData(torch.utils.data.Dataset):
    """Blended audio-dialogue dataset for Audio Flamingo chat training.

    Reads per-dataset manifest JSON files (produced by prepare_each_dataset.py),
    mixes them into one flat index according to per-dataset blending weights,
    and serves (name, audio windows, CLAP embed mask, token ids, attention mask)
    tuples suitable for DataCollator.
    """

    def __init__(
        self,
        dataset_file_root: str,
        data_root: str,
        clap_config: dict,
        dataset_blending_global_weight: float,
        dataset_blending_config: dict,
        dataset_blending_output: str,
        tokenizer,
        max_tokens: int,
        split: str = 'train',
        epoch: int = 0,
        force_reblend: bool = False,
        **kwargs
    ):
        # dataset_file_root: directory holding the per-dataset manifest JSONs
        # data_root: root directory that all audio 'name' paths are relative to
        # clap_config: CLAP windowing settings (window_length/overlap, max_num_window, ...)
        # dataset_blending_output: cache path for the blended manifest
        self.dataset_file_root = dataset_file_root
        self.data_root = data_root
        self.clap_config = clap_config
        self.dataset_blending_global_weight = dataset_blending_global_weight
        self.dataset_blending_config = dataset_blending_config

        self.split = split
        self.epoch = epoch
        self.force_reblend = force_reblend

        # only the training split is supported by this dataset class
        assert self.split == 'train'
        self.data = self.blend_dataset(dataset_blending_config, dataset_blending_output)

        self.tokenizer = tokenizer
        self.tokenizer.padding_side = "right"
        self.max_tokens = max_tokens

    @staticmethod
    def shuffle_dict_fixed_rand(dic, seed=0):
        """Deterministically permute a dict's values (the key order is kept)."""
        print('randomly shuffling key-value pairs')

        # dedicated RNG keeps the shuffle reproducible for a given seed
        local_random = np.random.default_rng(seed)
        original_keys = list(dic.keys())
        shuffled_keys = deepcopy(original_keys)
        local_random.shuffle(shuffled_keys)
        shuffling_mapping = {x: y for (x, y) in zip(original_keys, shuffled_keys)}

        shuffled_dic = {}
        for idx in original_keys:
            shuffled_idx = shuffling_mapping[idx]
            shuffled_dic[idx] = dic[shuffled_idx]
        return shuffled_dic

    @staticmethod
    def is_broken_file(audiopath):
        """Return True for audio paths known to be unreadable (manual denylist)."""
        # write your broken file paths here
        BROKEN_FILES = []
        return audiopath in BROKEN_FILES

    def _read_dataset_file(self, dataset_file):
        """Load one manifest JSON and rewrite sample names to data_root-relative paths."""
        print("reading", dataset_file)
        with open(dataset_file) as f:
            contents = f.read()
        contents = json.loads(contents)

        # manifest stores an absolute dataset_path; strip data_root so each
        # sample 'name' becomes relative to data_root
        assert contents["dataset_path"].startswith(self.data_root)
        rel_path = contents["dataset_path"][len(self.data_root):]
        if rel_path.startswith('/'):
            rel_path = rel_path[1:]
        if contents['split_path'] is not None:
            rel_path = os.path.join(rel_path, contents['split_path'])

        """
        contents["data"] = {
            "0": {'name': name (xxx.wav), 'dialogue': [
                {"user": question 1, "assistant": answer 1},
                ...
                {"user": question k, "assistant": answer k}
            ]
            },
            ...
            "total_num-1": {'name': name (xxx.wav), 'dialogue': [...]}
        }
        """

        for idx in contents["data"]:
            contents["data"][idx]['task'] = contents["flamingo_task"]
            contents["data"][idx]['name'] = os.path.join(
                rel_path, contents["data"][idx]['name']
            )
        return contents

    def blend_dataset(self, dataset_blending_config, dataset_blending_output):
        """Build (or load a cached copy of) the weighted mixture of all manifests.

        Each dataset contributes roughly total_num * global_weight * weight
        samples; the epoch offsets the slice so successive epochs see different
        samples, and a fresh deterministic shuffle is applied on each
        wrap-around past the dataset's end.
        """
        if os.path.exists(dataset_blending_output) and not self.force_reblend:
            print("loading blended dataset file from:", dataset_blending_output)
            with open(dataset_blending_output) as f:
                contents = f.read()
            self_data = json.loads(contents)

        else:
            if not self.force_reblend:
                print("no blended dataset file found; reading all dataset files")
            else:
                print("force reblending dataset at epoch {}; reading all dataset files".format(self.epoch))

            all_data = {}
            for dataset_name in dataset_blending_config:
                dataset_file = os.path.join(self.dataset_file_root, '{}.json'.format(dataset_name))
                contents = self._read_dataset_file(dataset_file)
                # per-dataset deterministic shuffle, seeded by the dataset name
                contents['data'] = self.shuffle_dict_fixed_rand(
                    contents['data'],
                    seed=sum(list(map(ord, dataset_name)))
                )

                weight_global = float(self.dataset_blending_global_weight)
                weight_dataset = float(dataset_blending_config[dataset_name]["weight"])
                weight = weight_global * weight_dataset

                all_data[dataset_name] = {
                    "contents": contents,
                    "weight": weight
                }

            self_data = {
                "dataset_path": self.data_root,
                "split_path": None,
                "total_num": 0,
                "data": {}
            }

            for dataset_name in all_data:
                print('blending {}'.format(dataset_name))

                contents = all_data[dataset_name]["contents"]
                shuffled_contents_data = contents['data']
                weight = all_data[dataset_name]["weight"]
                assert type(weight) == float and weight > 0.0

                # weight-scaled, epoch-offset slice into the (virtually
                # repeated) shuffled dataset
                dataset_total_num = contents['total_num']
                start_idx = int(self.epoch * dataset_total_num * weight)
                end_idx = int((self.epoch + 1) * dataset_total_num * weight)

                for idx in range(start_idx, end_idx):
                    # reshuffle whenever the index wraps past the dataset end
                    if idx > 0 and idx % dataset_total_num == 0:
                        print('force shuffling at new epoch {} for dataset {}'.format(idx // dataset_total_num, dataset_name))
                        shuffled_contents_data = self.shuffle_dict_fixed_rand(
                            contents['data'],
                            seed=sum(list(map(ord, '{}-epoch-{}'.format(dataset_name, idx // dataset_total_num))))
                        )

                    key = str(idx % dataset_total_num)
                    item = shuffled_contents_data[key]

                    # skip denylisted (unreadable) audio files
                    found_broken = False
                    assert type(item['name']) is str
                    audiopath = os.path.join(self.data_root, item['name'])
                    if self.is_broken_file(audiopath):
                        print('cannot read {}'.format(audiopath))
                        found_broken = True

                    if found_broken:
                        continue

                    self_data['data'][self_data['total_num']] = item
                    self_data['total_num'] += 1

            if not self.force_reblend:
                print('writing blended dataset file to:', dataset_blending_output)
                with open(dataset_blending_output, 'w') as json_file:
                    json.dump(self_data, json_file)
            else:
                print('writing reblended dataset file to:', dataset_blending_output.replace('.json', '-reblended.json'))
                with open(dataset_blending_output.replace('.json', '-reblended.json'), 'w') as json_file:
                    json.dump(self_data, json_file)

        return self_data

    def get_num_windows(self, T, sr):
        """Given T samples at rate sr, return (num_windows, padded_full_length)."""
        clap_config = self.clap_config
        window_length = int(float(clap_config["window_length"]) * sr)
        window_overlap = int(float(clap_config["window_overlap"]) * sr)
        max_num_window = int(clap_config["max_num_window"])

        num_windows = 1
        if T <= window_length:
            # short clip: a single window, later zero-padded to window_length
            num_windows = 1
            full_length = window_length
        elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap):
            # long clip: capped at max_num_window overlapping windows
            num_windows = max_num_window
            full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap)
        else:
            # in between: just enough windows to cover T with the configured overlap
            num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
            full_length = num_windows * window_length - (num_windows - 1) * window_overlap

        return num_windows, full_length

    def load_audio(self, file_path, target_sr=44100, duration=30.0, start=0.0):
        """Load up to ~`duration` seconds of mono audio at target_sr as 1-D floats in [-1, 1]."""
        if file_path.endswith('.mp3'):
            # mp3 goes through pydub (soundfile's seeking differs for mp3)
            audio = AudioSegment.from_file(file_path)
            if len(audio) > (start + duration) * 1000:
                # pydub slices are in milliseconds
                audio = audio[start * 1000:(start + duration) * 1000]

            if audio.frame_rate != target_sr:
                audio = audio.set_frame_rate(target_sr)

            if audio.channels > 1:
                audio = audio.set_channels(1)

            data = np.array(audio.get_array_of_samples())
            # scale raw integer samples to [-1, 1] based on sample width
            if audio.sample_width == 2:
                data = data.astype(np.float32) / np.iinfo(np.int16).max
            elif audio.sample_width == 4:
                data = data.astype(np.float32) / np.iinfo(np.int32).max
            else:
                raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))

        else:
            with sf.SoundFile(file_path) as audio:
                original_sr = audio.samplerate
                channels = audio.channels

                # NOTE(review): max_frames counts from the file start, not from
                # `start`, so up to `start` extra seconds may be read — confirm intent
                max_frames = int((start + duration) * original_sr)

                audio.seek(int(start * original_sr))
                frames_to_read = min(max_frames, len(audio))
                data = audio.read(frames_to_read)

                # normalize out-of-range data before resampling
                if data.max() > 1 or data.min() < -1:
                    data = data / max(abs(data.max()), abs(data.min()))

            if original_sr != target_sr:
                if channels == 1:
                    data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
                else:
                    # resample channel-major, then keep only the first channel
                    data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
            else:
                if channels != 1:
                    data = data.T[0]

        # final normalization: recenter all-nonnegative data to [-1, 1],
        # otherwise peak-normalize
        if data.min() >= 0:
            data = 2 * data / abs(data.max()) - 1.0
        else:
            data = data / max(abs(data.max()), abs(data.min()))

        assert len(data.shape) == 1, data.shape
        return data

    def compute_sliding_window(self, audio_file, audio_start=0.0):
        """Split one audio file into up to max_num_window overlapping CLAP windows.

        Returns (audio_clips, audio_embed_mask): audio_clips has shape
        (max_num_window, window_length * sr) with silent padding windows at the
        end; audio_embed_mask is 1 for real windows, 0 for padding.
        """
        if type(audio_start) == str:
            audio_start = float(audio_start)

        clap_config = self.clap_config

        # each supported CLAP variant expects a fixed sample rate
        if clap_config["method"] == 'laion-clap':
            sr = 48000
        elif clap_config["method"] == 'microsoft-clap':
            sr = 44100
        else:
            raise NotImplementedError

        window_length = int(float(clap_config["window_length"]) * sr)
        window_overlap = int(float(clap_config["window_overlap"]) * sr)
        max_num_window = int(clap_config["max_num_window"])
        # seconds of audio needed to fill max_num_window overlapping windows
        duration = max_num_window * (clap_config["window_length"] - clap_config["window_overlap"]) + clap_config["window_overlap"]

        # NOTE(review): __getitem__ already joins data_root into audio_file; this
        # second join is a no-op only while data_root is absolute — verify
        audio_data = self.load_audio(os.path.join(self.data_root, audio_file), sr, duration, audio_start)
        T = len(audio_data)
        num_windows, full_length = self.get_num_windows(T, sr)

        # zero-pad so the last window is full-length
        if full_length > T:
            audio_data = np.append(audio_data, np.zeros(full_length - T))
        audio_data = audio_data.reshape(1, -1)
        # int16 round-trip quantizes the waveform (presumably to match CLAP
        # preprocessing — confirm against the CLAP wrapper)
        audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()

        audio_clips = []
        audio_embed_mask = torch.zeros(max_num_window)
        for i in range(num_windows):
            start = i * (window_length - window_overlap)
            audio_clips.append(audio_data_tensor[:, start:start+window_length])
            audio_embed_mask[i] = 1

        assert sum(audio_embed_mask) == num_windows

        # pad with silent windows up to max_num_window
        if num_windows < max_num_window:
            for _ in range(max_num_window - num_windows):
                audio_clips.append(torch.zeros_like(audio_clips[-1]))

        audio_clips = torch.cat(audio_clips)  # (max_num_window, window_length * sr) tensor (CPU at this point)

        return audio_clips, audio_embed_mask

    def preprocess_string_for_eval(self, x):
        """Lowercase a string and trim surrounding whitespace."""
        x = x.rstrip().lstrip()
        x = x.lower()
        return x

    def __getitem__(self, i):
        # blended manifests use string keys after the JSON round-trip; a freshly
        # reblended in-memory dict may still use int keys, hence the fallback
        try:
            item = self.data['data'][str(i)]
        except:
            item = self.data['data'][i]

        assert type(item['name']) is str
        audio_files = [os.path.join(self.data_root, item['name'])]
        # optional per-sample start offset (seconds)
        audio_starts = [0 if 'audio_start' not in item else float(item['audio_start'])]

        audio_clips, audio_embed_mask = [], []
        for audio_file, audio_start in zip(audio_files, audio_starts):
            this_audio_clips, this_audio_embed_mask = self.compute_sliding_window(audio_file, audio_start)
            audio_clips.append(this_audio_clips)
            audio_embed_mask.append(this_audio_embed_mask)

        audio_clips = torch.cat(audio_clips)
        audio_embed_mask = torch.cat(audio_embed_mask)

        # pad up to max_num_window * max_num_fewshot windows with silence so
        # every sample has a fixed-size audio tensor
        correct_num_windows = int(self.clap_config["max_num_window"]) * int(self.clap_config["max_num_fewshot"])
        if len(audio_clips) < correct_num_windows:
            audio_clips = torch.cat([
                audio_clips,
                torch.zeros(correct_num_windows - len(audio_clips), audio_clips.shape[1])
            ])
            audio_embed_mask = torch.cat([
                audio_embed_mask,
                torch.zeros(correct_num_windows - len(audio_embed_mask))
            ])

        audio_clips.requires_grad = False
        audio_embed_mask.requires_grad = False

        # serialize the multi-turn dialogue into the Flamingo chat prompt format
        assert 'dialogue' in item
        dialogue = item['dialogue']
        prefix = 'The task is dialog. '
        sample = f"{self.tokenizer.bos_token}{prefix}<audio>"
        for each_round in dialogue:
            user_content, assistant_content = each_round['user'], each_round['assistant']
            sample = sample + f"user: {user_content} \nassistant: {self.tokenizer.sep_token}{assistant_content}<|endofchunk|>{self.tokenizer.eos_token}\n"

        text = self.tokenizer(
            sample,
            max_length=self.max_tokens,
            padding="longest",
            truncation="only_first",
            return_tensors="pt"
        )

        return (item['name'], audio_clips, audio_embed_mask, text["input_ids"], text["attention_mask"])

    def __len__(self):
        return len(list(self.data['data'].keys()))
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
@dataclass
class DataInfo:
    """Bundle of a dataset, its dataloader, and an optional distributed sampler."""
    dataset: Dataset
    dataloader: DataLoader
    sampler: DistributedSampler = None

    def set_epoch(self, epoch):
        """Forward the epoch to a DistributedSampler so its shuffle order changes."""
        sampler = self.sampler
        if sampler is None or not isinstance(sampler, DistributedSampler):
            return
        sampler.set_epoch(epoch)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def get_audiotext_dataloader(data_config, clap_config, text_tokenizer, batch_size, split='train', epoch=0, force_reblend=False):
    """Build the blended training dataset plus its distributed dataloader."""
    assert split == 'train'

    collator = DataCollator(text_tokenizer)

    dataset = AudioTextData(
        **data_config,
        clap_config=clap_config,
        tokenizer=text_tokenizer,
        split=split,
        epoch=epoch,
        force_reblend=force_reblend
    )
    # shuffling is delegated to the sampler, so the loader itself never shuffles
    sampler = DistributedSampler(dataset, shuffle=True)
    loader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collator,
        num_workers=data_config["num_workers"]
    )
    return DataInfo(dataset=dataset, dataloader=loader, sampler=sampler)
|
models/audio-flamingo-1/chat/data/prepare_each_dataset.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import csv
|
| 7 |
+
import yaml
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
import pickle
|
| 10 |
+
import glob
|
| 11 |
+
import math
|
| 12 |
+
from functools import partial
|
| 13 |
+
import sys
|
| 14 |
+
import io
|
| 15 |
+
import warnings
|
| 16 |
+
import random
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
import librosa
|
| 22 |
+
from pydub import AudioSegment
|
| 23 |
+
import soundfile as sf
|
| 24 |
+
|
| 25 |
+
import faiss
|
| 26 |
+
|
| 27 |
+
import multiprocessing
|
| 28 |
+
multiprocessing.set_start_method('spawn', force=True)
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from tqdm import tqdm
|
| 32 |
+
except:
|
| 33 |
+
tqdm = lambda x: x
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def filter_file(file_path, file_list, filename):
    """Return True when `filename` should be skipped.

    With a file_list, membership in that list stands in for an existence
    check; otherwise the filesystem is consulted. Files smaller than 16000
    bytes are rejected as too short either way.
    """
    full_path = os.path.join(file_path, filename)

    missing = (filename not in file_list) if file_list is not None else (not os.path.exists(full_path))
    if missing:
        print(filename, 'not exist')
        return True

    if os.path.getsize(full_path) < 16000:
        print(filename, 'less than 0.5 to 1 second')
        return True

    return False
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def filter_response(response):
    """Return True when a generated assistant response should be discarded.

    Checks the (lowercased) response against hedging/refusal phrase lists
    drawn from LLARK, LTU, and this project's own additions.
    """
    filter_phrases_LLARK = [
        'metadata', 'is not provided', 'based on theprovided metadata',
        'based on the providedbeat', 'based on the provided chord',
        'basedon the provided information', 'based on theprovided annotations',
        'no specific mood,there is no mention of',
        'there is no specificmention of any', 'as an ai assistant',
        'iam unable to', 'as an ai assistant', 'i donot',
        'it is difficult to determine', 'it isnot possible to determine',
        'no informationis available about the album', 'cannotdetermine',
        'violin 1', 'violin 2', 'violin 3,viola 1', 'viola 2', 'viola 3', 'pack'
    ]

    filter_phrases_LTU = [
        'cannot determine', 'not provided', 'cannot be determined', 'sorry', 'i cannot',
        'without more information', 'enough information',
        'not possible', 'more context', 'enough', 'impossible', 'cannot be determined',
        'without additional information',
        'unclear', 'cannot', 'not clear', 'do not provide sufficient', 'does not provide',
        'difficult to determine', 'no information provided',
        "can't infer", "difficult to infer", "not specified", "no specific", "no information",
        "without additional", 'it is difficult to',
        "no indication"
    ]

    filter_phrases_ours = ["doesn't provide", "doesn't specify", "doesn't indicate", "based on"]

    lowered = response.lower()
    return any(
        phrase in lowered
        for phrase in filter_phrases_LLARK + filter_phrases_LTU + filter_phrases_ours
    )
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# !!!Important!!! please write your own code to create dataset manifests based on your stored datasets
|
| 87 |
+
# The list of dataset_name and flamingo_task can be found in configs/*.yaml --> data_config --> dataset_blending_config
|
| 88 |
+
def prepare_files(dataset_name, dataset_path, split, flamingo_task, output_file):
    """Build one manifest JSON for `dataset_name`/`split` and write it to `output_file`.

    The manifest has the shape:
        {
            "dataset_path": YOUR_DATA_ROOT_DIR/datasets/<dataset_name>/,
            "split": "train" or "test",
            "split_path": relative path such that
                os.path.join(dataset_path, split_path, name) is the audio file path,
            "flamingo_task": "<dataset_name>-<flamingo_task>",
            "total_num": number of kept samples,
            "data": {
                "0": {"name": xxx.wav,
                      "dialogue": [{"user": q1, "assistant": a1}, ...]},
                ...
            }
        }

    Audio files need not be wav, but mp3 is not recommended due to a different
    seeking mechanism. Raises AssertionError when output_file already exists
    or when the dataset/split combination is unsupported.
    """
    assert not os.path.exists(output_file)
    dataset_dic = {
        "dataset_path": dataset_path,
        "split": split,
        "split_path": None,
        "flamingo_task": "{}-{}".format(dataset_name, flamingo_task),
        "total_num": 0,
        "data": {}
    }

    # Both dialogue datasets share identical processing; only the source
    # dialogue JSON differs, so the former duplicated branches collapse
    # into a single table-driven path.
    dialog_json_filenames = {
        'dialog_AudioSetSL': 'dialogues_audioset_thresholded.json',
        'dialog_MusicCaps': 'dialogues_musiccaps_thresholded.json',
    }

    if dataset_name in dialog_json_filenames:
        assert flamingo_task == "Dialog"
        assert split == 'train'
        map_split = lambda split: './'
        file_path = os.path.join(
            dataset_path,
            map_split(split)
        )
        assert os.path.exists(file_path), '{} not exist'.format(file_path)

        dataset_dic["split_path"] = map_split(split)
        file_list = None  # None -> filter_file checks existence on disk instead

        with open(os.path.join(dataset_path, dialog_json_filenames[dataset_name])) as f:
            data_list = json.loads(f.read())

        for data in tqdm(data_list):
            filename = data["audio_id"]
            if filter_file(file_path, file_list, filename):
                continue

            dialogue = data['dialogue']

            # drop the sample if any assistant turn trips the quality filter
            if any(filter_response(each_round['assistant']) for each_round in dialogue):
                continue

            dataset_dic["data"][dataset_dic["total_num"]] = {
                "name": filename,
                "dialogue": dialogue
            }
            dataset_dic["total_num"] += 1

    with open(output_file, 'w') as json_file:
        json.dump(dataset_dic, json_file)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dataset_name', type=str, help='dataset name')
    parser.add_argument('-f', '--flamingo_task', type=str, default='Dialog', help='flamingo task')
    args = parser.parse_args()

    # root directory layout: datasets live under DATA_ROOT_DIR/datasets,
    # manifests are written under DATA_ROOT_DIR/audio-flamingo-data/dataset_files
    DATA_ROOT_DIR = "YOUR_DATA_ROOT_DIR"
    dataset_root = os.path.join(DATA_ROOT_DIR, "datasets")
    output_root = os.path.join(DATA_ROOT_DIR, "audio-flamingo-data/dataset_files")
    os.makedirs(output_root, exist_ok=True)

    dataset_name = args.dataset_name  # dialog_AudioSetSL, dialog_MusicCaps
    flamingo_task = args.flamingo_task  # Dialog

    split = 'train'
    dataset_path = os.path.join(dataset_root, dataset_name)

    output_folder = '{}-{}'.format(dataset_name, flamingo_task)
    os.makedirs(os.path.join(output_root, output_folder), exist_ok=True)

    dataset_file = os.path.join(output_root, output_folder, '{}.json'.format(split))
    if not os.path.exists(dataset_file):
        try:
            prepare_files(dataset_name, dataset_path, split, flamingo_task, dataset_file)
        except AssertionError as e:
            # bugfix: the original ended this branch with `continue`, which is a
            # SyntaxError outside a loop and made the whole script unrunnable;
            # the redundant module-level `global DATA_ROOT_DIR` is dropped too
            print('split {} not exist for {}: {}'.format(split, dataset_name, e))
    else:
        print('{} exists; exiting'.format(dataset_file))
|
| 252 |
+
|
| 253 |
+
|
models/audio-flamingo-1/chat/src/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
models/audio-flamingo-1/chat/src/factory.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
| 5 |
+
# LICENSE is in incl_licenses directory.
|
| 6 |
+
|
| 7 |
+
import sys
|
| 8 |
+
sys.path.append('../')
|
| 9 |
+
|
| 10 |
+
from typing import Optional
|
| 11 |
+
from copy import deepcopy
|
| 12 |
+
|
| 13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 14 |
+
from my_laion_clap.CLAP.src.laion_clap.hook import CLAP_Module
|
| 15 |
+
from my_ms_clap.src.CLAPWrapper import CLAPWrapper
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import nn
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
from .flamingo import Flamingo
|
| 22 |
+
from .flamingo_lm import FlamingoLMMixin
|
| 23 |
+
from .utils import extend_instance
|
| 24 |
+
except:
|
| 25 |
+
from flamingo import Flamingo
|
| 26 |
+
from flamingo_lm import FlamingoLMMixin
|
| 27 |
+
from utils import extend_instance
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class CLAP(nn.Module):
    """Wrapper exposing a CLAP audio encoder behind a single forward() call.

    Supports two backends selected by ``clap_config["method"]``:
      - 'laion-clap': https://github.com/LAION-AI/CLAP (CLAP_Module)
      - 'microsoft-clap': https://github.com/microsoft/CLAP (CLAPWrapper)

    ``clap_config['finetune']`` (optional) controls whether the encoder's
    parameters receive gradients and whether it runs in train or eval mode.
    """

    def __init__(self, clap_config):
        super(CLAP, self).__init__()
        self.method = clap_config["method"]
        # Pin the backend to the current CUDA device (assumes CUDA is available).
        device_id = f'cuda:{torch.cuda.current_device()}'

        # Finetuning is opt-in; absent key means frozen encoder.
        if ('finetune' in clap_config) and clap_config['finetune']:
            self.finetune = True
            print('Finetuning CLAP encoder as well!')
        else:
            self.finetune = False

        if self.method == 'laion-clap':
            # https://github.com/LAION-AI/CLAP
            # Audio-model architecture is implied by the released checkpoint name.
            if clap_config["model_name"] in ['630k-audioset-best', '630k-best', '630k-audioset-fusion-best', '630k-fusion-best']:
                amodel = 'HTSAT-tiny'
            elif clap_config["model_name"] in ['music_speech_audioset_epoch_15_esc_89.98']:
                amodel = 'HTSAT-base'
            else:
                raise NotImplementedError

            # 'fusion' checkpoints were trained with feature fusion enabled.
            enable_fusion = 'fusion' in clap_config["model_name"].lower()
            self.laion_clap = CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=device_id)
            self.laion_clap.load_ckpt(ckpt=clap_config["checkpoint"])

            # Freeze or unfreeze every encoder parameter according to finetune.
            for param in self.laion_clap.parameters():
                param.requires_grad = self.finetune

            if self.finetune:
                self.laion_clap.train()
            else:
                self.laion_clap.eval()

            print('loaded laion-clap model: {}'.format(clap_config["checkpoint"]))

        elif self.method == 'microsoft-clap':
            # https://github.com/microsoft/CLAP
            self.ms_clap = CLAPWrapper(
                clap_config["checkpoint"],
                config_root=clap_config["config_root"],
                version=clap_config['model_name'],
                use_cuda=True
            )

            # 2022/2023 releases expose the encoder as .clap; later ones
            # (presumably the captioning variants) expose .clapcap.
            if clap_config['model_name'] in ['2022', '2023']:
                for param in self.ms_clap.clap.parameters():
                    param.requires_grad = self.finetune
                if self.finetune:
                    self.ms_clap.clap.train()
                else:
                    self.ms_clap.clap.eval()
            else:
                for param in self.ms_clap.clapcap.parameters():
                    param.requires_grad = self.finetune
                if self.finetune:
                    self.ms_clap.clapcap.train()
                else:
                    self.ms_clap.clapcap.eval()

            print('loaded microsoft-clap model: {}'.format(clap_config["checkpoint"]))

        else:
            raise NotImplementedError

    def forward(self, audio_clips):
        """Embed a batch of windowed audio clips.

        Args:
            audio_clips: tensor of shape (B, num_window, window_length); a 2-D
                input is treated as a single batch element and unsqueezed.

        Returns:
            Stacked embeddings, one row of window embeddings per batch element.
        """
        if len(audio_clips.shape) == 2:
            audio_clips = audio_clips.unsqueeze(0)
        assert len(audio_clips.shape) == 3

        # Embed each batch element's windows through the selected backend.
        audio_embeds = []
        for x in audio_clips:
            if self.method == 'laion-clap':
                audio_embed = self.laion_clap.get_audio_embedding_from_data(x=x, use_tensor=True)
            elif self.method == 'microsoft-clap':
                audio_embed = self.ms_clap.get_audio_embeddings_from_clips(x)

            audio_embeds.append(audio_embed)

        audio_embeds = torch.stack(audio_embeds, dim=0)
        # NOTE(review): setting .requires_grad on a stacked (non-leaf) tensor
        # only succeeds when the value matches its autograd state — confirm
        # this behaves as intended in both frozen and finetuning modes.
        audio_embeds.requires_grad = self.finetune

        return audio_embeds
+
def create_model_and_transforms(
    clap_config: dict,
    lang_encoder_path: str,
    tokenizer_path: str,
    audio_transformer_kwargs: dict,
    cross_attn_every_n_layers: int = 1,
    use_local_files: bool = False,
    decoder_layers_attr_name: str = None,
    freeze_lm_embeddings: bool = False,
    unfreeze_full_lm: bool = False,
    cache_dir: Optional[str] = None,
    **flamingo_kwargs,
):
    """Build the Flamingo model: CLAP audio encoder + LM with gated cross-attention.

    Args:
        clap_config: configuration dict consumed by CLAP (method, checkpoint, ...).
        lang_encoder_path: HF path/name of the causal language model.
        tokenizer_path: HF path/name of the tokenizer.
        audio_transformer_kwargs: kwargs for the audio transformer encoder.
        cross_attn_every_n_layers: insert a gated cross-attn layer every N LM layers.
        use_local_files: pass local_files_only to HF loaders.
        decoder_layers_attr_name: attribute path of the LM decoder layer list;
            inferred from the class name when None.
        freeze_lm_embeddings: keep LM input embeddings frozen when True.
        unfreeze_full_lm: make the whole LM trainable when True.
        cache_dir: HF cache directory.

    Returns:
        (model, text_tokenizer) tuple.
    """
    clap = CLAP(clap_config)

    text_tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )
    # Media marker and end-of-chunk marker used by the Flamingo data format.
    text_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
    )
    if text_tokenizer.pad_token is None:
        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    if text_tokenizer.sep_token is None:
        text_tokenizer.add_special_tokens({"sep_token": "<SEP>"})

    lang_encoder = AutoModelForCausalLM.from_pretrained(
        lang_encoder_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )

    # Graft the Flamingo mixin onto the loaded LM instance.
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    # Resize after adding special tokens so their embeddings exist.
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    if ('finetune' in clap_config) and clap_config['finetune']:
        unfreeze_clap = True
    else:
        unfreeze_clap = False

    model = Flamingo(
        clap,
        unfreeze_clap,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        text_tokenizer.encode("<audio>")[-1],
        text_tokenizer.sep_token_id,
        audio_embed_dim=clap_config["audio_embed_dim"],
        audio_transformer_kwargs=audio_transformer_kwargs,
        cross_attn_every_n_layers=cross_attn_every_n_layers,
        **flamingo_kwargs,
    )

    # Freeze everything first, then selectively unfreeze the trainable parts.
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    model.audio_transformer.requires_grad_(True)
    model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
    if not freeze_lm_embeddings:
        model.lang_encoder.get_input_embeddings().requires_grad_(True)

    if unfreeze_full_lm:
        model.lang_encoder.requires_grad_(True)

    if unfreeze_clap:
        model.clap.requires_grad_(True)

    print("Flamingo model initialized with {:,} trainable parameters (audio transformer has {:,}, LM has {:,})".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad),
        sum(p.numel() for p in model.audio_transformer.parameters() if p.requires_grad),
        sum(p.numel() for p in model.lang_encoder.parameters() if p.requires_grad)
    ))

    return model, text_tokenizer
+
def _infer_decoder_layers_attr_name(model):
    """Infer the attribute path of the decoder's transformer-layer ModuleList.

    Matches the model's class name (case-insensitively) against known model
    families and returns the corresponding attribute path string.
    """
    cls_name = model.__class__.__name__.lower()
    for family, attr_path in __KNOWN_DECODER_LAYERS_ATTR_NAMES.items():
        if family.lower() in cls_name:
            return attr_path

    raise ValueError(
        f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
    )


# Known model families mapped to where their decoder layer list lives.
# Order is preserved from the original table; lookup is first-match.
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
}
|
models/audio-flamingo-1/chat/src/flamingo.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
| 5 |
+
# LICENSE is in incl_licenses directory.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from torch.distributed.fsdp.wrap import (
|
| 12 |
+
enable_wrap,
|
| 13 |
+
wrap,
|
| 14 |
+
)
|
| 15 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 16 |
+
from torch.distributed.fsdp import (
|
| 17 |
+
FullyShardedDataParallel as FSDP,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
from .helpers import TransformerEncoder
|
| 22 |
+
from .utils import apply_with_stopping_condition
|
| 23 |
+
except:
|
| 24 |
+
from helpers import TransformerEncoder
|
| 25 |
+
from utils import apply_with_stopping_condition
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Flamingo(nn.Module):
    """Audio Flamingo: CLAP audio encoder + audio transformer + LM with gated
    cross-attention layers conditioned on the audio features.
    """

    def __init__(
        self,
        clap: nn.Module,
        unfreeze_clap: bool,
        lang_encoder: nn.Module,
        eoc_token_id: int,
        media_token_id: int,
        sep_token_id: int,
        audio_embed_dim: int,
        audio_transformer_kwargs: dict,
        cross_attn_every_n_layers: int = 1,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            clap: CLAP audio encoder module.
            unfreeze_clap: whether the CLAP encoder is trainable.
            lang_encoder: causal LM already extended with FlamingoLMMixin.
            eoc_token_id: token id of "<|endofchunk|>".
            media_token_id: token id of "<audio>".
            sep_token_id: tokenizer separator token id.
            audio_embed_dim: dimensionality of CLAP embeddings.
            audio_transformer_kwargs: n_head / n_layers / d_inner /
                max_num_media / max_window_per_audio for the audio transformer.
            cross_attn_every_n_layers: gated cross-attn layer frequency.
            gradient_checkpointing: enable activation checkpointing.
        """
        super().__init__()

        self.eoc_token_id = eoc_token_id
        self.media_token_id = media_token_id
        self.sep_token_id = sep_token_id
        self.audio_embed_dim = audio_embed_dim
        self.clap = clap  # .to(torch.cuda.current_device())
        self.unfreeze_clap = unfreeze_clap
        self.clap.requires_grad_(unfreeze_clap)

        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size

        n_head = audio_transformer_kwargs["n_head"]
        n_layers = audio_transformer_kwargs["n_layers"]
        d_inner = audio_transformer_kwargs["d_inner"]
        max_num_media = audio_transformer_kwargs["max_num_media"]
        max_window_per_audio = audio_transformer_kwargs["max_window_per_audio"]
        # Head dim must divide the embedding dim evenly.
        assert audio_embed_dim % n_head == 0

        # Contextualizes the per-window CLAP embeddings before cross-attention.
        self.audio_transformer = TransformerEncoder(
            d_word_vec=audio_embed_dim,
            n_layers=n_layers,
            n_head=n_head,
            d_k=audio_embed_dim // n_head,
            d_v=audio_embed_dim // n_head,
            d_model=audio_embed_dim,
            d_inner=d_inner,
            dropout=0.0,
            n_position=max_num_media,
            scale_emb=True
        )

        self.lang_encoder = lang_encoder
        self.lang_encoder.init_flamingo(
            media_token_id=media_token_id,
            lang_hidden_size=self.lang_dim,
            audio_hidden_size=self.audio_embed_dim,
            max_window_per_audio=max_window_per_audio,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            gradient_checkpointing=gradient_checkpointing,
        )

        self._use_gradient_checkpointing = gradient_checkpointing
        self.audio_transformer._use_gradient_checkpointing = gradient_checkpointing
        self.clap._use_gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        audio_x: torch.Tensor,
        audio_x_mask: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
    ):
        """Run the LM conditioned on audio.

        Either `audio_x` is provided and encoded here, or media must have been
        pre-cached via cache_media() (then `audio_x` must be None).
        """
        assert (
            self.lang_encoder.initialized_flamingo
        ), "Flamingo layers are not initialized. Please call `init_flamingo` first."

        assert (
            self.lang_encoder._use_cached_audio_x or audio_x is not None
        ), "Must provide either audio_x or have precached media using cache_media()."

        if self.lang_encoder._use_cached_audio_x:
            assert (
                audio_x is None
            ), "Expect audio_x to be None when media has been cached using cache_media(). Try uncache_media() first."
            assert self.lang_encoder.is_conditioned()

        else:
            # Encode audio and condition all decoder layers on it.
            self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
            self._condition_media_locations(input_ids=lang_x)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        # Drop conditioning state unless the caller wants to reuse it.
        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output

    def generate(
        self,
        audio_x: torch.Tensor,
        audio_x_mask: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        **kwargs,
    ):
        """Autoregressive generation conditioned on audio.

        Accepts standard HF generate kwargs; defaults eos to the
        end-of-chunk token. Returns the generated token ids.
        """
        num_beams = kwargs.pop("num_beams", 1)
        # Beam search expands the batch; replicate the audio accordingly.
        if num_beams > 1:
            audio_x = audio_x.repeat_interleave(num_beams, dim=0)

        # Cache the audio conditioning for the duration of generation.
        self.lang_encoder._use_cached_audio_x = True
        self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)

        eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
        output = self.lang_encoder.generate(
            input_ids=lang_x,
            attention_mask=attention_mask,
            eos_token_id=eos_token_id,
            num_beams=num_beams,
            **kwargs,
        )

        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_audio_x = False
        return output

    def _encode_audio_x(self, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
        """
        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """

        assert audio_x.ndim == 3, "audio_x should be of shape (B, num_window, window_length)"

        # NOTE(review): no_grad blocks gradients into CLAP here even when
        # unfreeze_clap is True — confirm this is intended.
        with torch.no_grad():
            audio_embeds = self.clap(audio_x)
        B, L, D = audio_embeds.shape  # L is number of windows, D is feature dim
        assert D == self.audio_embed_dim

        assert audio_x_mask.ndim == 2, "audio_x_mask should be of shape (B, L)"

        # Broadcast a single shared mask across the batch if needed.
        if B > 1 and audio_x_mask.shape[0] == 1:
            audio_x_mask = audio_x_mask.repeat(B, 1)

        assert audio_x_mask.shape[0] == B and audio_x_mask.shape[1] == L, "{} != ({},{})".format(audio_x_mask.shape, B, L)

        audio_x_out = self.audio_transformer(audio_embeds)  # B, L, D
        audio_x_out = audio_x_out.unsqueeze(2)  # B, L, n=1, D
        audio_x_mask = audio_x_mask.unsqueeze(2)  # B, L, n=1

        # Hand the encoded audio to every decoder layer's cross-attn.
        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_audio_x(audio_x_out, audio_x_mask)

    def wrap_fsdp(self, wrapper_kwargs, device_id):
        """Wrap trainable submodules in FSDP and move the rest to device_id."""
        # unfreeze the decoder layers
        for block in self.lang_encoder.old_decoder_blocks:
            block.requires_grad_(True)

        # wrap in FSDP
        with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
            self.audio_transformer = wrap(wrap(self.audio_transformer))
            self.lang_encoder.old_decoder_blocks = nn.ModuleList(
                wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks
            )
            self.lang_encoder.gated_cross_attn_layers = nn.ModuleList(
                wrap(wrap(layer)) if layer is not None else None
                for layer in self.lang_encoder.gated_cross_attn_layers
            )
            # Rebuild FlamingoLayers so they reference the wrapped modules.
            self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
            self.lang_encoder.set_input_embeddings(
                wrap(wrap(self.lang_encoder.get_input_embeddings()))
            )

            if hasattr(self.lang_encoder, 'set_output_embeddings'):
                self.lang_encoder.set_output_embeddings(
                    wrap(wrap(self.lang_encoder.get_output_embeddings()))
                )
            else:
                print('skip wrapping output embeddings')

        # manually move non-FSDP managed parameters to device_id
        # these are all in lang_encoder
        apply_with_stopping_condition(
            module=self.lang_encoder,
            apply_fn=lambda m: m.to(device_id),
            apply_condition=lambda m: len(list(m.children())) == 0,
            stopping_condition=lambda m: isinstance(m, FSDP),
        )

        # clap shouldn't be wrapped; should be on each gpu
        if self.unfreeze_clap:
            apply_with_stopping_condition(
                module=self.clap,
                apply_fn=lambda m: m.to(device_id),
                apply_condition=lambda m: len(list(m.children())) == 0,
                stopping_condition=lambda m: isinstance(m, FSDP),
            )

        # exclude the original decoder layers from the optimizer
        for block in self.lang_encoder.old_decoder_blocks:
            for p in block.parameters():
                p.exclude_from_optimizer = True

        # set up clip_grad_norm_ function
        def clip_grad_norm_(max_norm):
            self.audio_transformer.clip_grad_norm_(max_norm)
            for layer in self.lang_encoder.gated_cross_attn_layers:
                if layer is not None:
                    layer.clip_grad_norm_(max_norm)
            self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm)

        self.clip_grad_norm_ = clip_grad_norm_

    def _condition_media_locations(self, input_ids: torch.Tensor):
        """Mark positions of the <audio> token for every decoder layer."""
        media_locations = (input_ids == self.media_token_id)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_media_locations(media_locations)

    def cache_media(self, input_ids: torch.Tensor, audio_x: torch.Tensor, audio_x_mask: torch.Tensor):
        """Pre-encode audio so subsequent forward() calls can pass audio_x=None."""
        self._encode_audio_x(audio_x=audio_x, audio_x_mask=audio_x_mask)
        self._condition_media_locations(input_ids=input_ids)
        self.lang_encoder._use_cached_audio_x = True

    def uncache_media(self):
        """Drop any cached audio conditioning."""
        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_audio_x = False
models/audio-flamingo-1/chat/src/flamingo_lm.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
| 5 |
+
# LICENSE is in incl_licenses directory.
|
| 6 |
+
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from .helpers import GatedCrossAttentionBlock
|
| 11 |
+
from .utils import getattr_recursive, setattr_recursive
|
| 12 |
+
except:
|
| 13 |
+
from helpers import GatedCrossAttentionBlock
|
| 14 |
+
from utils import getattr_recursive, setattr_recursive
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class FlamingoLayer(nn.Module):
    """
    FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.

    Holds per-forward conditioning state (audio features, media locations,
    cached-media flag) that is injected by the enclosing Flamingo model.
    """

    def __init__(
        self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False
    ):
        super().__init__()
        self.gated_cross_attn_layer = gated_cross_attn_layer
        self.decoder_layer = decoder_layer
        # Conditioning state, populated via the condition_* methods below.
        self.audio_x = None
        self.audio_x_mask = None
        self.few_shot_mask = None
        self.media_locations = None
        # FIX: previously unset until condition_use_cached_media() was called,
        # so an early forward() raised AttributeError instead of a clear error.
        self.use_cached_media = None
        if self.gated_cross_attn_layer is not None:
            self.gated_cross_attn_layer._use_gradient_checkpointing = (
                gradient_checkpointing
            )
        self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return (self.audio_x is not None) and (self.audio_x_mask is not None) and (self.media_locations is not None)

    def condition_audio_x(self, audio_x, audio_x_mask):
        self.audio_x = audio_x
        self.audio_x_mask = audio_x_mask

    def condition_media_locations(self, media_locations):
        self.media_locations = media_locations

    def condition_use_cached_media(self, use_cached_media):
        self.use_cached_media = use_cached_media

    def forward(
        self,
        lang_x,
        attention_mask=None,
        **decoder_layer_kwargs,
    ):
        # Gated cross-attention first (if this layer has one), then the
        # original decoder layer.
        if self.gated_cross_attn_layer is not None:
            if self.audio_x is None:
                raise ValueError("audio_x must be conditioned before forward pass")

            if self.media_locations is None:
                raise ValueError(
                    "media_locations must be conditioned before forward pass"
                )

            lang_x = self.gated_cross_attn_layer(
                lang_x,
                self.audio_x,
                self.audio_x_mask,
                media_locations=self.media_locations,
                use_cached_media=self.use_cached_media,
            )

        # Normal decoder layer
        lang_x = self.decoder_layer(
            lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
        )
        return lang_x
+
class FlamingoLMMixin(nn.Module):
|
| 83 |
+
"""
|
| 84 |
+
Mixin to add cross-attention layers to a language model.
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
|
| 88 |
+
self.decoder_layers_attr_name = decoder_layers_attr_name
|
| 89 |
+
|
| 90 |
+
def _get_decoder_layers(self):
|
| 91 |
+
return getattr_recursive(self, self.decoder_layers_attr_name)
|
| 92 |
+
|
| 93 |
+
def _set_decoder_layers(self, value):
|
| 94 |
+
setattr_recursive(self, self.decoder_layers_attr_name, value)
|
| 95 |
+
|
| 96 |
+
def init_flamingo(
|
| 97 |
+
self,
|
| 98 |
+
media_token_id,
|
| 99 |
+
lang_hidden_size,
|
| 100 |
+
audio_hidden_size,
|
| 101 |
+
max_window_per_audio,
|
| 102 |
+
cross_attn_every_n_layers,
|
| 103 |
+
gradient_checkpointing,
|
| 104 |
+
):
|
| 105 |
+
"""
|
| 106 |
+
Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
|
| 107 |
+
"""
|
| 108 |
+
self.old_decoder_blocks = self._get_decoder_layers()
|
| 109 |
+
self.gated_cross_attn_layers = nn.ModuleList(
|
| 110 |
+
[
|
| 111 |
+
GatedCrossAttentionBlock(
|
| 112 |
+
dim=lang_hidden_size,
|
| 113 |
+
dim_audio=audio_hidden_size,
|
| 114 |
+
max_window_per_audio=max_window_per_audio,
|
| 115 |
+
only_attend_immediate_media=False,
|
| 116 |
+
)
|
| 117 |
+
if (layer_idx + 1) % cross_attn_every_n_layers == 0
|
| 118 |
+
else None
|
| 119 |
+
for layer_idx, _ in enumerate(self._get_decoder_layers())
|
| 120 |
+
]
|
| 121 |
+
)
|
| 122 |
+
self.init_flamingo_layers(gradient_checkpointing)
|
| 123 |
+
self.media_token_id = media_token_id
|
| 124 |
+
self.initialized_flamingo = True
|
| 125 |
+
self._use_cached_audio_x = False
|
| 126 |
+
|
| 127 |
+
def init_flamingo_layers(self, gradient_checkpointing):
|
| 128 |
+
"""
|
| 129 |
+
Re initializes the FlamingoLayers.
|
| 130 |
+
Propagates any changes made to self.gated_corss_attn_layers or self.old_decoder_blocks
|
| 131 |
+
"""
|
| 132 |
+
self._set_decoder_layers(
|
| 133 |
+
nn.ModuleList(
|
| 134 |
+
[
|
| 135 |
+
FlamingoLayer(
|
| 136 |
+
gated_cross_attn_layer, decoder_layer, gradient_checkpointing
|
| 137 |
+
)
|
| 138 |
+
for gated_cross_attn_layer, decoder_layer in zip(
|
| 139 |
+
self.gated_cross_attn_layers, self.old_decoder_blocks
|
| 140 |
+
)
|
| 141 |
+
]
|
| 142 |
+
)
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
def forward(self, input_ids, attention_mask, **kwargs):
|
| 146 |
+
"""Condition the Flamingo layers on the media locations before forward()"""
|
| 147 |
+
if not self.initialized_flamingo:
|
| 148 |
+
raise ValueError(
|
| 149 |
+
"Flamingo layers are not initialized. Please call `init_flamingo` first."
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
media_locations = input_ids == self.media_token_id
|
| 153 |
+
|
| 154 |
+
use_cached_media_locations = (
|
| 155 |
+
self._use_cached_audio_x
|
| 156 |
+
and self.is_conditioned()
|
| 157 |
+
and not media_locations.any()
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
for layer in self._get_decoder_layers():
|
| 161 |
+
if not use_cached_media_locations:
|
| 162 |
+
layer.condition_media_locations(media_locations)
|
| 163 |
+
layer.condition_use_cached_media(use_cached_media_locations)
|
| 164 |
+
|
| 165 |
+
kwargs["input_ids"] = input_ids
|
| 166 |
+
kwargs["attention_mask"] = attention_mask
|
| 167 |
+
return super().forward(**kwargs)
|
| 168 |
+
|
| 169 |
+
def is_conditioned(self) -> bool:
|
| 170 |
+
"""Check whether all decoder layers are already conditioned."""
|
| 171 |
+
return all(l.is_conditioned() for l in self._get_decoder_layers())
|
| 172 |
+
|
| 173 |
+
def clear_conditioned_layers(self):
    """Drop all cached conditioning from every decoder layer.

    Resets the audio features, media locations, and cached-media flag by
    passing None to each layer's condition_* setters.
    """
    for layer in self._get_decoder_layers():
        layer.condition_audio_x(None, None)
        layer.condition_media_locations(None)
        layer.condition_use_cached_media(None)
|
models/audio-flamingo-1/chat/src/helpers.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
| 5 |
+
# LICENSE is in incl_licenses directory.
|
| 6 |
+
|
| 7 |
+
# Adapted from https://github.com/lucidrains/flamingo-pytorch under the MIT license.
|
| 8 |
+
# LICENSE is in incl_licenses directory.
|
| 9 |
+
|
| 10 |
+
# Adapted from https://github.com/jadore801120/attention-is-all-you-need-pytorch under the MIT license.
|
| 11 |
+
# LICENSE is in incl_licenses directory.
|
| 12 |
+
|
| 13 |
+
from einops import rearrange, repeat
|
| 14 |
+
from einops_exts import rearrange_many
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch import einsum, nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
def exists(val):
    """Return True when *val* is not None (Flamingo-style helper)."""
    return not (val is None)
|
| 24 |
+
|
| 25 |
+
def FeedForward(dim, mult=4):
    """Pre-norm MLP block: LayerNorm -> Linear(dim, dim*mult) -> GELU -> Linear(back to dim).

    Both linear layers are bias-free; ``mult`` sets the hidden expansion factor.
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
|
| 33 |
+
|
| 34 |
+
# Transformer (encoder) https://github.com/jadore801120/attention-is-all-you-need-pytorch
|
| 35 |
+
# Original Copyright 2017 Victor Huang
|
| 36 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
| 37 |
+
|
| 38 |
+
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with dropout applied to the attention weights."""

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """Compute attention over (b, h, len, d) tensors; returns (output, attn)."""
        # (q / T) @ k^T  ->  b x h x len_q x len_k
        scores = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            # zero entries in the mask are excluded from attention
            scores = scores.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(scores, dim=-1))
        return torch.matmul(attn, v), attn
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''
    # Post-norm residual attention: out = LayerNorm(residual + Dropout(W_o @ Attn)).

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        """
        Args:
            n_head: number of attention heads.
            d_model: model (input/output) dimension.
            d_k: per-head key/query dimension.
            d_v: per-head value dimension.
            dropout: dropout applied to the output projection.
        """
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Fused, bias-free projections for all heads at once.
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        # Softmax scaled by sqrt(d_k).
        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        """Return (output, attn) where output has shape (b, len_q, d_model)."""
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        # Residual is the *unprojected* query input.
        residual = q

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)  # For head axis broadcasting.

        q, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual

        # Post-norm (norm after the residual add).
        q = self.layer_norm(q)

        return q, attn
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class PositionwiseFeedForward(nn.Module):
|
| 113 |
+
''' A two-feed-forward-layer module '''
|
| 114 |
+
|
| 115 |
+
def __init__(self, d_in, d_hid, dropout=0.1):
|
| 116 |
+
super().__init__()
|
| 117 |
+
self.w_1 = nn.Linear(d_in, d_hid) # position-wise
|
| 118 |
+
self.w_2 = nn.Linear(d_hid, d_in) # position-wise
|
| 119 |
+
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
|
| 120 |
+
self.dropout = nn.Dropout(dropout)
|
| 121 |
+
|
| 122 |
+
def forward(self, x):
|
| 123 |
+
|
| 124 |
+
residual = x
|
| 125 |
+
|
| 126 |
+
x = self.w_2(F.relu(self.w_1(x)))
|
| 127 |
+
x = self.dropout(x)
|
| 128 |
+
x += residual
|
| 129 |
+
|
| 130 |
+
x = self.layer_norm(x)
|
| 131 |
+
|
| 132 |
+
return x
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding stored as a buffer."""

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()
        # Buffer (not a Parameter): moves with .to()/.cuda() but is never trained.
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Build the (1, n_position, d_hid) sinusoid table from the standard formula."""
        # angle(pos, j) = pos / 10000^(2*(j//2)/d_hid)
        table = np.array(
            [
                [pos / np.power(10000, 2 * (j // 2) / d_hid) for j in range(d_hid)]
                for pos in range(n_position)
            ]
        )
        table[:, 0::2] = np.sin(table[:, 0::2])  # even dims -> sin
        table[:, 1::2] = np.cos(table[:, 1::2])  # odd dims  -> cos
        return torch.FloatTensor(table).unsqueeze(0)

    def forward(self, x):
        """Add the first x.size(1) positional rows; detached so no grad reaches the table."""
        return x + self.pos_table[:, :x.size(1)].clone().detach()
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class EncoderLayer(nn.Module):
    ''' Compose with two layers '''
    # Standard Transformer encoder layer: multi-head self-attention followed by
    # a position-wise feed-forward block (each sub-module handles its own
    # residual connection and LayerNorm).

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.0):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        """Return (enc_output, self-attention weights) for one encoder layer."""
        # Self-attention: query = key = value = enc_input.
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class TransformerEncoder(nn.Module):
    ''' A encoder model with self attention mechanism. '''
    # Stack of EncoderLayers with optional sinusoidal positional encoding and
    # an input pre-norm. Used here to aggregate audio embedding windows.

    def __init__(
            self, d_word_vec=512, n_layers=6, n_head=8, d_k=64, d_v=64,
            d_model=512, d_inner=2048, dropout=0.0, n_position=16, scale_emb=True):
        # d_word_vec: input feature dim (must match d_model for the residual path)
        # n_position: size of the sinusoid table; <= 0 disables positional encoding
        # scale_emb: multiply inputs by sqrt(d_model) before adding positions

        super().__init__()

        if n_position > 0:
            self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        else:
            # identity stand-in so forward() can call position_enc unconditionally
            self.position_enc = lambda x: x
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.scale_emb = scale_emb
        self.d_model = d_model

    def forward(self, src_seq, return_attns=False):
        """Encode src_seq (B, L, D) or (B, D); returns (B, L, D) [plus attn list]."""
        # Accept 2-D input by inserting a singleton sequence dimension.
        if len(src_seq.shape) == 2:
            src_seq = src_seq.unsqueeze(1)
        B, L, D = src_seq.shape  # unpack also enforces 3-D input

        enc_slf_attn_list = []

        # No masking: every position attends to every other (bidirectional).
        causal_mask = None

        enc_output = src_seq
        if self.scale_emb:
            enc_output = enc_output * self.d_model ** 0.5
        enc_output = self.dropout(self.position_enc(enc_output))
        enc_output = self.layer_norm(enc_output)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=causal_mask)
            enc_slf_attn_list += [enc_slf_attn] if return_attns else []

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# gated cross attention
|
| 217 |
+
class MaskedCrossAttention(nn.Module):
    """Text-to-audio cross-attention with padding and media-location masking.

    Queries come from the text stream; keys/values come from the audio
    embeddings (``media``). Padded audio windows are masked out via
    ``media_mask``, and ``media_locations`` restricts which audio windows each
    text span may attend to.
    """

    def __init__(
        self,
        *,
        dim,
        dim_audio,
        max_window_per_audio,
        dim_head=64,
        heads=8,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        # Number of embedding windows contributed by one audio clip; used to map
        # an <audio> token index to its slice of media windows.
        self.max_window_per_audio = max_window_per_audio
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        # Bias-free projections: queries from text, fused key/value from audio.
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim_audio, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # NOTE(review): forward() asserts this is False, so immediate-only
        # attention is effectively unsupported in this code path.
        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(
        self,
        x,
        media, media_mask,
        media_locations=None,
        use_cached_media=False
    ):
        """Attend text tokens x (B, T_txt, dim) over audio media (B, L, 1, dim_audio)."""
        if not use_cached_media:
            assert (
                media_locations.shape[1] == x.shape[1]
            ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"

        T_txt = x.shape[1]
        B, L = media.shape[:2]
        assert media.shape[2] == 1  # extra dim
        assert L % self.max_window_per_audio == 0  # should be 4 or 8 times
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        # Collapse the singleton extra dim: (b, L, 1, d) -> (b, L, d).
        media = rearrange(media, "b t n d -> b (t n) d")

        k, v = self.to_kv(media).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)

        q = q * self.scale

        # Similarity logits: (b, h, T_txt, L).
        sim = einsum("... i d, ... j d -> ... i j", q, k)

        # mask padded audio embeddings
        media_mask = rearrange(media_mask, "b i n -> b 1 1 (i n)").bool()  # n = 1 is extra dim
        sim = sim.masked_fill(~media_mask, -torch.finfo(sim.dtype).max)

        assert self.only_attend_immediate_media is False

        # mask media locations
        if exists(media_locations):
            # few_shot_mask[b, t, l] is True when text position t may attend
            # to media window l.
            few_shot_mask = torch.zeros(B, T_txt, L).bool().to(sim.device)
            for batch_idx in range(B):
                media_locations_b = media_locations[batch_idx].nonzero()  # locations of <audio>
                if len(media_locations_b.shape) > 1:
                    media_locations_b = media_locations_b.squeeze(-1)

                # i == -1 covers the text before the first <audio> token.
                # NOTE(review): assumes at least one <audio> token per sample;
                # an empty media_locations_b would raise an IndexError below.
                for i in range(-1, len(media_locations_b)):
                    if i == -1:
                        if len(media_locations_b) == 1:
                            text_start, text_end = 0, T_txt
                        else:
                            text_start, text_end = 0, media_locations_b[i+1]

                    elif i == len(media_locations_b) - 1:
                        text_start, text_end = media_locations_b[i], T_txt

                    else:
                        text_start, text_end = media_locations_b[i], media_locations_b[i+1]

                    # With only_attend_immediate_media False (enforced above),
                    # each text span may attend to all windows up to and
                    # including those of the i-th audio clip.
                    if self.only_attend_immediate_media:
                        look_at_window_start = max(i,0) * self.max_window_per_audio
                    else:
                        look_at_window_start = 0
                    look_at_window_end = (max(i,0) + 1) * self.max_window_per_audio

                    few_shot_mask[batch_idx, text_start:text_end, look_at_window_start:look_at_window_end] = True

            sim = sim.masked_fill(~few_shot_mask.unsqueeze(1), -torch.finfo(sim.dtype).max)

        # Subtract the row max for numerical stability before softmax.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # NOTE(review): dead branch — the assert above guarantees
        # only_attend_immediate_media is False, and `text_time` is undefined
        # here (would raise NameError if this branch were ever reached).
        if exists(media_locations) and self.only_attend_immediate_media:
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(
                text_without_media_mask, "b i -> b 1 i 1"
            )
            attn = attn.masked_fill(text_without_media_mask, 0.0)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class GatedCrossAttentionBlock(nn.Module):
    """Flamingo-style tanh-gated cross-attention block.

    Cross-attention over audio media and a feed-forward layer are each added
    to the text stream through a learned tanh gate initialized at zero, so the
    block is an identity mapping at initialization.
    """

    def __init__(
        self,
        *,
        dim,
        dim_audio,
        max_window_per_audio,
        dim_head=64,
        heads=8,
        ff_mult=4,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        self.attn = MaskedCrossAttention(
            dim=dim,
            dim_audio=dim_audio,
            max_window_per_audio=max_window_per_audio,
            dim_head=dim_head,
            heads=heads,
            only_attend_immediate_media=only_attend_immediate_media,
        )
        # tanh(0) == 0 -> cross-attention contributes nothing at init.
        self.attn_gate = nn.Parameter(torch.tensor([0.0]))

        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.0]))

    def forward(
        self,
        x,
        media,
        media_mask,
        media_locations=None,
        use_cached_media=False,
    ):
        """Apply gated cross-attention then the gated feed-forward, residually."""
        attended = self.attn(
            x,
            media,
            media_mask,
            media_locations=media_locations,
            use_cached_media=use_cached_media,
        )
        x = attended * self.attn_gate.tanh() + x
        x = self.ff(x) * self.ff_gate.tanh() + x
        return x
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
if __name__ == '__main__':
    # Smoke test: push a random (4, 512) batch through a default
    # TransformerEncoder on GPU and print the output shape.
    enc = TransformerEncoder().cuda()
    x = torch.randn(4, 512).cuda()
    output = enc(x)
    # NOTE(review): set after the forward pass above, so it has no effect on
    # that call; also TransformerEncoder never reads this attribute here.
    enc._use_gradient_checkpointing = True
    print(output.shape)
|
models/audio-flamingo-1/chat/src/utils.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
| 5 |
+
# LICENSE is in incl_licenses directory.
|
| 6 |
+
|
| 7 |
+
def extend_instance(obj, mixin):
    """Apply mixins to a class instance after creation.

    Rebinds obj.__class__ to a dynamically created subclass of
    (mixin, original class), keeping the original class name.
    """
    original_cls = type(obj)
    # mixin needs to go first in the MRO for our forward() logic to work
    obj.__class__ = type(original_cls.__name__, (mixin, original_cls), {})
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def getattr_recursive(obj, att):
    """
    Return nested attribute of obj
    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
    """
    if att == "":
        return obj
    dot = att.find(".")
    if dot < 0:
        # No dots left: plain attribute access.
        return getattr(obj, att)
    # Recurse into the first path segment.
    return getattr_recursive(getattr(obj, att[:dot]), att[dot + 1:])


def setattr_recursive(obj, att, val):
    """
    Set nested attribute of obj
    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
    """
    if "." in att:
        # Resolve everything up to the last segment, then set on that parent.
        parent_path = ".".join(att.split(".")[:-1])
        obj = getattr_recursive(obj, parent_path)
    setattr(obj, att.split(".")[-1], val)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def apply_with_stopping_condition(
    module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
):
    """Recursively apply ``apply_fn`` over ``module`` and its children (pre-order).

    Args:
        module: root object exposing ``children()`` (e.g. an ``nn.Module``).
        apply_fn: callable invoked as ``apply_fn(module, **other_args)``.
        apply_condition: predicate deciding whether ``apply_fn`` runs on a
            given module; ``None`` means apply to every module.
        stopping_condition: predicate deciding whether to prune the subtree
            rooted at a given module (that module is neither applied to nor
            descended into); ``None`` means never stop.
        **other_args: extra keyword arguments forwarded to ``apply_fn``.
    """
    # Fix: the original called both predicates unconditionally, so leaving the
    # defaults (None) raised TypeError. Treat None as never-stop / always-apply.
    if stopping_condition is not None and stopping_condition(module):
        return
    if apply_condition is None or apply_condition(module):
        apply_fn(module, **other_args)
    for child in module.children():
        apply_with_stopping_condition(
            child,
            apply_fn,
            apply_condition=apply_condition,
            stopping_condition=stopping_condition,
            **other_args
        )
|
models/audio-flamingo-1/chat/train/distributed.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
| 5 |
+
# LICENSE is in incl_licenses directory.
|
| 6 |
+
|
| 7 |
+
# Adapted from https://github.com/mlfoundations/open_clip under the MIT license.
|
| 8 |
+
# LICENSE is in incl_licenses directory.
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import horovod.torch as hvd
|
| 15 |
+
except ImportError:
|
| 16 |
+
hvd = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def is_global_master(args):
    """True on the rank-0 process across all nodes."""
    return args.rank == 0


def is_local_master(args):
    """True on the rank-0 process of the current node."""
    return args.local_rank == 0


def is_master(args, local=False):
    """True on the local master when ``local`` is set, else on the global master."""
    if local:
        return is_local_master(args)
    return is_global_master(args)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def is_using_horovod():
    """Detect a horovod launch from the environment.

    NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will
    be set. Differentiating between horovod and DDP use via SLURM may not be
    possible, so the horovod arg is still required.
    """
    ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
    pmi_vars = ["PMI_RANK", "PMI_SIZE"]
    have_ompi = all(v in os.environ for v in ompi_vars)
    have_pmi = all(v in os.environ for v in pmi_vars)
    return have_ompi or have_pmi
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def is_using_distributed():
    """True when env vars indicate more than one process (torchrun or SLURM)."""
    world = os.environ.get("WORLD_SIZE")
    if world is not None:
        return int(world) > 1
    ntasks = os.environ.get("SLURM_NTASKS")
    if ntasks is not None:
        return int(ntasks) > 1
    return False
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def world_info_from_env():
    """Read (local_rank, global_rank, world_size) from launcher env vars.

    For each value the first matching variable wins (torchrun, MPI, SLURM,
    OpenMPI, in that order); sensible single-process defaults otherwise.
    """
    def _first_int(names, default):
        # Return int(value) of the first env var present, else the default.
        for name in names:
            if name in os.environ:
                return int(os.environ[name])
        return default

    local_rank = _first_int(
        ("LOCAL_RANK", "MPI_LOCALRANKID", "SLURM_LOCALID", "OMPI_COMM_WORLD_LOCAL_RANK"),
        0,
    )
    global_rank = _first_int(
        ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"), 0
    )
    world_size = _first_int(
        ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"), 1
    )

    return local_rank, global_rank, world_size
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def init_distributed_device(args):
    """Initialize the distributed process group and pick the device for this rank.

    Mutates ``args`` in place (``distributed``, ``world_size``, ``rank``,
    ``local_rank``, ``device``) and returns the selected ``torch.device``.
    Supports horovod, SLURM-launched DDP, torchrun/launch DDP, and a
    single-GPU fallback.
    """
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0

    if args.horovod:
        assert hvd is not None, "Horovod is not installed"
        print('using horovod')
        hvd.init()
        args.local_rank = int(hvd.local_rank())
        args.rank = hvd.rank()
        args.world_size = hvd.size()
        args.distributed = True
        # Mirror horovod's view into torch.distributed-style env vars.
        os.environ["LOCAL_RANK"] = str(args.local_rank)
        os.environ["RANK"] = str(args.rank)
        os.environ["WORLD_SIZE"] = str(args.world_size)

    elif is_using_distributed():
        if "SLURM_PROCID" in os.environ:
            print('DDP via SLURM')
            args.local_rank, args.rank, args.world_size = world_info_from_env()

            # SLURM var -> torch.distributed vars in case needed
            os.environ["LOCAL_RANK"] = str(args.local_rank)
            os.environ["RANK"] = str(args.rank)
            os.environ["WORLD_SIZE"] = str(args.world_size)

            init_method = args.dist_url

            # # master_ip = os.getenv('MASTER_ADDR', 'localhost')
            # # master_port = os.getenv('MASTER_PORT', '7000')
            # print("DDP RANK %d WORLD_SIZE %d" % (args.rank, args.world_size))
            # # init_method = f'tcp://{master_ip}:{master_port}'
            # init_method = 'tcp://localhost:54323'
            # print("Init method: %s" % (init_method))

            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=init_method,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            print('DDP via torchrun, torch.distributed.launch')
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=args.dist_backend, init_method=args.dist_url
            )
            # Let the initialized process group report the authoritative values.
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        # Covers both the SLURM and torchrun branches above.
        args.distributed = True
    else:
        # Single process: still create a size-1 group so code paths that
        # assume an initialized process group keep working.
        print('needed to run on single gpu')
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=1,
            rank=0,
        )

    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            # One GPU per local rank.
            device = "cuda:%d" % args.local_rank
        else:
            device = "cuda:0"
        torch.cuda.set_device(device)
    else:
        device = "cpu"
    args.device = device
    device = torch.device(device)
    return device
|
models/audio-flamingo-1/chat/train/train.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
| 5 |
+
# LICENSE is in incl_licenses directory.
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import functools
|
| 9 |
+
import glob
|
| 10 |
+
import os
|
| 11 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
|
| 12 |
+
import random
|
| 13 |
+
import shutil
|
| 14 |
+
import sys
|
| 15 |
+
sys.path.append('../')
|
| 16 |
+
import yaml
|
| 17 |
+
import time
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 22 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 23 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 24 |
+
from torch.distributed.fsdp import (
|
| 25 |
+
CPUOffload,
|
| 26 |
+
MixedPrecision,
|
| 27 |
+
ShardingStrategy,
|
| 28 |
+
BackwardPrefetch,
|
| 29 |
+
)
|
| 30 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
| 31 |
+
checkpoint_wrapper,
|
| 32 |
+
CheckpointWrapper,
|
| 33 |
+
CheckpointImpl,
|
| 34 |
+
apply_activation_checkpointing,
|
| 35 |
+
)
|
| 36 |
+
from torch.distributed.fsdp._init_utils import _init_intra_and_inter_node_groups
|
| 37 |
+
from torch.distributed.distributed_c10d import _get_default_group
|
| 38 |
+
torch.cuda.empty_cache()
|
| 39 |
+
|
| 40 |
+
from transformers import (
|
| 41 |
+
get_constant_schedule_with_warmup,
|
| 42 |
+
get_cosine_schedule_with_warmup,
|
| 43 |
+
get_linear_schedule_with_warmup,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
from data.data import get_audiotext_dataloader # AudioTextData, DataCollator
|
| 47 |
+
from distributed import init_distributed_device, world_info_from_env
|
| 48 |
+
from train_utils import (
|
| 49 |
+
train_one_epoch,
|
| 50 |
+
get_mp_policy_dtype,
|
| 51 |
+
save_checkpoint,
|
| 52 |
+
Dict2Class,
|
| 53 |
+
get_autocast,
|
| 54 |
+
get_cast_dtype
|
| 55 |
+
)
|
| 56 |
+
from src.factory import create_model_and_transforms
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def random_seed(seed=42, rank=0):
    """Seed torch, numpy, and the stdlib RNG with a per-rank offset.

    Adding ``rank`` gives each distributed worker a distinct but
    reproducible random stream.
    """
    effective = seed + rank
    torch.manual_seed(effective)
    np.random.seed(effective)
    random.seed(effective)
|
| 64 |
+
|
| 65 |
+
def main():
    """Entry point for (pre-)training / SFT of the audio-language model.

    Reads a YAML config, sets up the distributed environment, builds the
    model, optionally resumes from a checkpoint (or loads a pretrained
    checkpoint when starting SFT), wraps the model in FSDP or DDP, builds
    optimizer/scheduler/dataloader, and runs the epoch loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='../config/config.yaml', help='yaml config path')
    parsed_args = parser.parse_args()

    # All hyperparameters live in one YAML file, split into sub-dicts.
    config = yaml.load(open(parsed_args.config), Loader=yaml.FullLoader)
    data_config = config['data_config']
    model_config = config['model_config']
    clap_config = config["clap_config"]
    args = Dict2Class(config['train_config'])  # attribute-style access to train_config

    # Presence of sft_config switches the run into supervised fine-tuning mode.
    if 'sft_config' in config:
        sft_config = config['sft_config']
        unfreeze_full_lm = sft_config['unfreeze_full_lm']
    else:
        sft_config = None
        unfreeze_full_lm = False

    # get paths done
    exp_path = os.path.join(args.expdir, args.run_name)
    os.makedirs(exp_path, exist_ok=True)
    print('exp_path:', exp_path)
    # Snapshot the config next to the run outputs for reproducibility.
    shutil.copy(parsed_args.config, os.path.join(exp_path, 'config.yaml'))
    data_config["dataset_blending_output"] = os.path.join(exp_path, data_config["dataset_blending_output"])

    # Validate args
    if args.fsdp and not args.fsdp_use_orig_params:
        print(
            "Warning: FSDP is running without fsdp_use_orig_params flag. "
            + "This is not recommended because it means we will use uniform weight decay"
            + " and train all embeddings, not just the newly added ones. "
            + "Note: OPT models are not compatible with fsdp_use_orig_params flag."
        )

    if args.fsdp and args.fsdp_sharding_strategy == "hybrid":
        print(
            "Warning: As of torch=2.0.1, the FSDP logic for optim_state_dict() is broken for hybrid sharding."
            + "To make this method work, we need to modify torch.distributed.fsdp._optim_utils.py"
            + "Copy and paste the code from the _optim_utils.py in this repo into the torch file."
            + "The main issue was the missing group kwarg on line 1596 in _all_gather_optim_state."
        )

    # Set up distributed training
    print('initializing distributed environment')
    if args.offline:
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
    args.local_rank, args.rank, args.world_size = world_info_from_env()
    device_id = init_distributed_device(args)
    random_seed(args.seed)  # same seed on every rank for identical model init

    # Initialize model
    print('creating model')
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning
    model, tokenizer = create_model_and_transforms(
        **model_config,
        clap_config=clap_config,
        use_local_files=args.offline,
        gradient_checkpointing=args.gradient_checkpointing,
        freeze_lm_embeddings=args.freeze_lm_embeddings,
        unfreeze_full_lm=unfreeze_full_lm
    )
    # Re-seed with a rank offset so data/augmentation streams differ per rank.
    random_seed(args.seed, args.rank)

    # Initialize logging
    print(f"Start running training on rank {args.rank}.")

    # Load model checkpoint on CPU
    checkpoint_list = glob.glob(f"{args.expdir}/{args.run_name}/checkpoint_*.pt")
    if len(checkpoint_list) == 0:
        print(f"Found no checkpoints for run {args.run_name}.")
        resume_from_checkpoint = None
    else:
        # Checkpoints are named checkpoint_<epoch>.pt; pick the latest epoch.
        resume_from_checkpoint = sorted(
            checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
        )[-1]
        print(
            f"Found checkpoint {resume_from_checkpoint} for run {args.run_name}."
        )

    # load pretrained model
    resume_from_epoch = 0
    if (resume_from_checkpoint is None) and (sft_config is not None):
        # just started SFT
        pretrained_path = os.path.join(
            sft_config['pretrained_path'],
            sft_config['pretrained_ckpt']
        )
        if args.rank == 0:
            print(f"Loading checkpoint from {pretrained_path}")
        checkpoint = torch.load(pretrained_path, map_location="cpu")
        # Strip any DDP "module." prefixes so keys match the bare model.
        msd = checkpoint["model_state_dict"]
        msd = {k.replace("module.", ""): v for k, v in msd.items()}

        # for fsdp, only one rank needs to load the state dict
        # (sync_module_states=True below broadcasts rank 0's weights).
        if not args.fsdp or args.rank == 0:
            model.load_state_dict(msd, False)  # strict=False: pretrained ckpt may omit SFT-only params
            del checkpoint["model_state_dict"]
            del msd

    elif resume_from_checkpoint is not None:
        # continue training (either pretraining or STF)
        if args.rank == 0:
            print(f"Loading checkpoint from {resume_from_checkpoint}")
        checkpoint = torch.load(resume_from_checkpoint, map_location="cpu")
        msd = checkpoint["model_state_dict"]
        msd = {k.replace("module.", ""): v for k, v in msd.items()}
        resume_from_epoch = checkpoint["epoch"] + 1

        # for fsdp, only one rank needs to load the state dict
        if not args.fsdp or args.rank == 0:
            model.load_state_dict(msd, False)
            del checkpoint["model_state_dict"]
            del msd

    else:
        # Fresh pretraining run: nothing to load.
        pass

    # Initialize FSDP / DDP, and ensure the model is on GPU
    print(f"Initializing distributed training with {args.world_size} GPUs.")
    if args.fsdp:
        print(
            f"Before FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
        )

        # init MixedPrecision
        if args.precision != "fp32":
            cast_dtype = get_mp_policy_dtype(args.precision)
            # Parameters stay fp32; only gradient reduction and buffers are cast.
            mp_policy = MixedPrecision(
                param_dtype=torch.float32,
                reduce_dtype=cast_dtype,  # gradient communication
                buffer_dtype=cast_dtype,
            )
        else:
            mp_policy = None

        # init process groups
        if args.fsdp_sharding_strategy == "hybrid":
            # Hybrid sharding shards within a node and replicates across nodes.
            intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
                _get_default_group()
            )
            args.my_group = intra_node_group  # for optimizer saving
            process_group = (intra_node_group, inter_node_group)  # for FSDP init
        else:
            args.my_group = None  # for optimizer saving
            process_group = None  # for FSDP init

        # init FSDP
        wrapper_kwargs = dict(
            process_group=process_group,
            cpu_offload=CPUOffload(offload_params=False),
            device_id=device_id,
            sync_module_states=True,  # broadcast loaded ckpt from rank 0 -> all ranks
            sharding_strategy=ShardingStrategy.FULL_SHARD
            if args.fsdp_sharding_strategy == "full"
            else ShardingStrategy.HYBRID_SHARD,
            use_orig_params=args.fsdp_use_orig_params,
            mixed_precision=mp_policy,
            forward_prefetch=True,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            limit_all_gathers=True,
        )
        # The model wraps its own submodules in FSDP (custom method on the model).
        model.wrap_fsdp(wrapper_kwargs, device_id)
        ddp_model = model

        print(
            f"After FSDP parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
        )
        print(
            f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
        )

    else:
        model = model.to(device_id)
        ddp_model = DDP(model, device_ids=[device_id])

    # Initialize gradient checkpointing
    if args.gradient_checkpointing:
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper,
            offload_to_cpu=True,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        # Only wrap modules explicitly tagged with _use_gradient_checkpointing,
        # skipping already-wrapped and FSDP modules.
        apply_activation_checkpointing(
            ddp_model,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
            and not isinstance(m, FSDP)
            and not isinstance(m, CheckpointWrapper),
        )

    # Initialize optimizer
    params_to_optimize = ddp_model.named_parameters()
    # Keep only trainable params not explicitly excluded from the optimizer.
    params_to_optimize = list(
        filter(
            lambda x: x[1].requires_grad
            and not getattr(x[1], "exclude_from_optimizer", False),
            params_to_optimize,
        )
    )
    if not args.fsdp or args.fsdp_use_orig_params:
        # apply weight decay only to params in the xattn layers
        def get_grouped_params(model):
            params_with_wd, params_without_wd = [], []
            for n, p in params_to_optimize:
                if "gated_cross_attn" in n:
                    params_with_wd.append(p)
                else:
                    params_without_wd.append(p)
            return [
                {"params": params_with_wd, "weight_decay": args.weight_decay},
                {"params": params_without_wd, "weight_decay": 0.0},
            ]

        optimizer = torch.optim.AdamW(
            get_grouped_params(params_to_optimize), lr=args.learning_rate
        )
    else:
        # unclear if we should be using no weight decay or small weight decay for all parameters
        optimizer = torch.optim.AdamW(
            (p for _, p in params_to_optimize),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
        )

    # load optimizer checkpoint
    if resume_from_checkpoint is not None:
        osd = checkpoint["optimizer_state_dict"]
        if args.fsdp:
            # Convert the full (rank-0 gathered) optim state to this rank's shard.
            osd = FSDP.optim_state_dict_to_load(osd, ddp_model, optimizer)
        optimizer.load_state_dict(osd)
        del checkpoint["optimizer_state_dict"]
        del osd

    # Initialize data loaders
    AudioTextDataInfo = get_audiotext_dataloader(
        data_config, clap_config, tokenizer, args.batch_size, split='train',
        epoch=0, force_reblend=True
    )

    # Steps per epoch are computed from the global batch size across ranks.
    total_training_steps = (
        len(AudioTextDataInfo.dataset) // (args.batch_size * args.world_size)
    ) * args.num_epochs

    if args.rank == 0:
        print(f"Total training steps: {total_training_steps}")
        tb = SummaryWriter(os.path.join(exp_path, 'tensorboard'))
    else:
        tb = None  # non-zero ranks don't log to tensorboard

    # Initialize lr scheduler
    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    else:
        # Default: constant LR after warmup.
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )

    # load lr scheduler checkpoint
    if resume_from_checkpoint is not None:
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        del checkpoint["lr_scheduler_state_dict"]

    # Start training!
    ddp_model.train()

    print('start training from epoch {}'.format(resume_from_epoch))
    for epoch in range(resume_from_epoch, args.num_epochs):
        # force reblending dataset for every epoch
        if epoch > 0:
            AudioTextDataInfo = get_audiotext_dataloader(
                data_config, clap_config, tokenizer, args.batch_size, split='train',
                epoch=epoch, force_reblend=True
            )
        AudioTextDataInfo.set_epoch(epoch)
        trainloader = AudioTextDataInfo.dataloader

        # train one epoch
        train_one_epoch(
            args=args,
            model=ddp_model,
            epoch=epoch,
            tokenizer=tokenizer,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            trainloader=trainloader,
            device_id=device_id,
            tb=tb
        )

        # save checkpoint
        save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
        time.sleep(1.0)

    # save final checkpoint
    save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
    if args.rank == 0:
        tb.close()
| 374 |
+
|
| 375 |
+
# Script entry point: parse the YAML config and launch distributed training.
if __name__ == "__main__":
    main()
|
models/audio-flamingo-1/chat/train/train_utils.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
|
| 5 |
+
# LICENSE is in incl_licenses directory.
|
| 6 |
+
|
| 7 |
+
import time
|
| 8 |
+
import os
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import sys
|
| 11 |
+
from copy import deepcopy
|
| 12 |
+
|
| 13 |
+
from contextlib import suppress
|
| 14 |
+
import torch
|
| 15 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
| 16 |
+
from torch.distributed.fsdp import (
|
| 17 |
+
FullStateDictConfig,
|
| 18 |
+
StateDictType,
|
| 19 |
+
)
|
| 20 |
+
from torch.distributed.fsdp.api import FullOptimStateDictConfig
|
| 21 |
+
from einops import rearrange
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Dict2Class:
    """Lightweight namespace: exposes the keys of a dict as object attributes."""

    def __init__(self, data_dict):
        # Promote every key/value pair to an instance attribute.
        for key in data_dict:
            setattr(self, key, data_dict[key])
|
| 29 |
+
|
| 30 |
+
class SysLogger(object):
    """Tee writes to both the real stdout and an append-mode log file."""

    def __init__(self, filename="../log/log.log"):
        # Capture the current stdout and open the log file for appending.
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        # NOTE(review): the terminal copy is newline-terminated while the log
        # copy is written verbatim — preserved from the original behavior.
        self.terminal.write(f"{message}\n")
        self.log.write(message)
|
| 39 |
+
|
| 40 |
+
def get_cast_dtype(precision: str):
    """Map a precision string to a torch dtype for tensor casting.

    Returns torch.bfloat16 for "bf16", torch.float16 for "fp16", and None
    for anything else (meaning: do not cast).
    """
    return {"bf16": torch.bfloat16, "fp16": torch.float16}.get(precision)
|
| 48 |
+
|
| 49 |
+
def get_mp_policy_dtype(precision: str):
    """Resolve the FSDP MixedPrecision dtype for a precision string.

    Any string containing "bfloat16" or "bf16" (e.g. "amp_bf16") maps to
    torch.bfloat16; "fp16" maps to torch.float16; everything else falls
    back to torch.float32.
    """
    if any(tag in precision for tag in ("bfloat16", "bf16")):
        return torch.bfloat16
    return torch.float16 if precision == "fp16" else torch.float32
|
| 57 |
+
|
| 58 |
+
def get_autocast(precision, cache_enabled=True):
|
| 59 |
+
if precision == "amp":
|
| 60 |
+
return torch.cuda.amp.autocast(cache_enabled=cache_enabled)
|
| 61 |
+
elif precision == "amp_bfloat16" or precision == "amp_bf16":
|
| 62 |
+
return lambda: torch.cuda.amp.autocast(
|
| 63 |
+
dtype=torch.bfloat16, cache_enabled=cache_enabled
|
| 64 |
+
)
|
| 65 |
+
else:
|
| 66 |
+
return suppress
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def train_one_epoch(
    args,
    model,
    epoch,
    trainloader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    tb
):
    """Run one training epoch: forward, label masking, backward, clip, step, log.

    Args:
        args: Dict2Class of train_config (precision, fsdp flags, accumulation, etc.).
        model: FSDP- or DDP-wrapped Flamingo model.
        epoch: current epoch index (0-based).
        trainloader: dataloader yielding dicts with audio_clips, audio_embed_mask,
            input_ids, attention_mask.
        tokenizer: tokenizer providing the <audio>, <SEP>, <|endofchunk|> special ids.
        optimizer / lr_scheduler: stepped every gradient_accumulation_steps batches.
        device_id: target CUDA device.
        tb: SummaryWriter on rank 0, None elsewhere.
    """
    # setup loaders
    num_batches_per_epoch = len(trainloader)
    total_training_steps = num_batches_per_epoch * args.num_epochs
    print('num_batches_per_epoch={}, total_training_steps={}'.format(num_batches_per_epoch, total_training_steps))

    autocast = get_autocast(
        args.precision, cache_enabled=(not args.fsdp)
    )  # if fsdp, disable cache to save memory
    cast_dtype = get_cast_dtype(args.precision)

    # setup model
    # The id of the last token produced for "<audio>" marks audio insertion points.
    media_token_id = tokenizer("<audio>", add_special_tokens=False)["input_ids"][-1]
    assert media_token_id == tokenizer.encode("<audio>")[-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]
    model.train()

    # setup logging
    step_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    # loop through dataloader
    for num_steps, batch in tqdm(
        enumerate(trainloader),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch)
    ):

        data_time_m.update(time.time() - end)
        global_step = num_steps + epoch * num_batches_per_epoch

        #### FORWARD PASS ####
        # NOTE(review): input_ids/attention_mask are also cast to cast_dtype
        # (a float dtype under fp16/bf16, None under fp32) — confirm the model
        # expects float token ids here rather than keeping them integral.
        audio_clips = batch["audio_clips"].to(device_id, dtype=cast_dtype, non_blocking=True)  # (B, N_WINDOWS, WINDOW_LENGTH)
        audio_embed_mask = batch["audio_embed_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)  # (B, N_WINDOWS)
        input_ids = batch["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True)  # (B, N_TOKENS)
        attention_mask = batch["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True)  # (B, N_TOKENS)

        # set up labels; language model is expected to handle shifting
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100  # ignore padding in the loss
        labels[:, :2] = -100  # ignore the first two tokens (BOS region)
        labels[labels == tokenizer.encode("<audio>")[-1]] = -100  # ignore audio placeholders

        # mask all prompts except for between <SEP> and <|endofchunk|>
        sep_locations = labels == tokenizer.sep_token_id
        eoc_locations = labels == endofchunk_token_id

        if not all(sep_locations.sum(dim=1) == eoc_locations.sum(dim=1)):
            print("Warning: <SEP>-<EoC> pairing mismatch at step {} due to max_token limit.".format(num_steps))

        # Walk each sequence: mask tokens until a <SEP> is seen, keep tokens
        # until the matching <|endofchunk|>, then resume masking. EOS tokens
        # are never masked by the running flag.
        for i in range(labels.shape[0]):
            shouldmask = True
            for j in range(labels.shape[1]):
                if shouldmask and (labels[i][j] != tokenizer.eos_token_id):
                    masked_value = -100
                else:
                    masked_value = labels[i][j]

                if labels[i][j] == tokenizer.sep_token_id:
                    shouldmask = False
                elif labels[i][j] == endofchunk_token_id:
                    shouldmask = True

                labels[i][j] = masked_value

            # If a sequence was truncated mid-answer (last token is not a
            # terminator), mask the dangling partial answer from the right.
            if labels[i][-1] not in [-100, tokenizer.eos_token_id, tokenizer.pad_token_id, endofchunk_token_id]:
                for j in range(labels.shape[1]-1, -1, -1):
                    if labels[i][j] not in [-100, tokenizer.eos_token_id, endofchunk_token_id]:
                        labels[i][j] = -100
                    else:
                        break

        labels = labels.to(device_id)

        # gradient accumulation w/ fsdp cpu offloading requires a no_sync context manager
        with autocast():
            output = model(
                audio_x=audio_clips,
                audio_x_mask=audio_embed_mask,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = output.loss

        # Scale for gradient accumulation before backward.
        divided_loss = loss / args.gradient_accumulation_steps
        train_loss = divided_loss * args.loss_multiplier
        train_loss.backward()

        if (not args.freeze_lm_embeddings) and (
            not args.fsdp or args.fsdp_use_orig_params
        ):
            # Mask gradients for input embeddings s.t. we only update the added tokens <audio> and <|endofchunk|>
            if args.fsdp:
                embed_grad = model.lang_encoder.get_input_embeddings().weight.grad
            else:
                # DDP wraps the model, hence the .module indirection.
                embed_grad = (
                    model.module.lang_encoder.get_input_embeddings().weight.grad
                )
            zero_mask = torch.zeros_like(embed_grad)
            zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
            zero_mask[endofchunk_token_id] = torch.ones_like(
                zero_mask[endofchunk_token_id]
            )
            if args.fsdp:
                model.lang_encoder.get_input_embeddings().weight.grad = (
                    embed_grad * zero_mask
                )
            else:
                model.module.lang_encoder.get_input_embeddings().weight.grad = (
                    embed_grad * zero_mask
                )

        # clip gradient norm
        if args.fsdp:
            """
            The way we clip gradients with FSDP is different than the non-FSDP case,
            because during FSDP, gradient norms are computed over certain submodules,
            rather than the entire model.
            At least for OPT-125M, this didn't seem to make a difference in performance.
            """
            model.clip_grad_norm_(1.0)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # step optimizer and log
        # Step once per accumulation window, and always on the last batch.
        if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (
            num_steps == num_batches_per_epoch - 1
        ):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        # step time and reset end outside of rank 0
        step_time_m.update(time.time() - end)
        end = time.time()

        # rank 0 logging
        if args.rank == 0:
            samples_per_second = (
                args.gradient_accumulation_steps
                * args.batch_size
                * args.world_size
                / step_time_m.val
            )
            samples_per_second_per_gpu = (
                args.gradient_accumulation_steps
                * args.batch_size
                / step_time_m.val
            )
            log_dict = {
                "data_time": data_time_m.avg,
                "step_time": step_time_m.avg,
                "samples_per_second": samples_per_second,
                "samples_per_second_per_gpu": samples_per_second_per_gpu,
                "lr": optimizer.param_groups[0]["lr"],
                "loss": loss.item()
            }

            if ((num_steps + 1) % args.logging_steps == 0):
                for key in log_dict:
                    tb.add_scalar("Train/{}".format(key), log_dict[key], global_step)

                # Reset timing meters after each logging interval.
                step_time_m.reset()
                data_time_m.reset()

        # Log loss to console
        if ((num_steps + 1) % args.logging_steps == 0):
            print(
                f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss: {loss.item():.3f}\n"
            )
|
| 253 |
+
|
| 254 |
+
class AverageMeter(object):
    """Tracks the most recent value and the running average of a series."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Zero out all running statistics.
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        # Record a new observation, optionally weighted by n samples.
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
|
| 272 |
+
|
| 273 |
+
def filter_state_dict_to_trainable(model, state_dict):
    """
    Drop non-trainable parameters from *state_dict* (in place) and return it.

    Embedding weights are kept even when frozen, so that the newly added
    <audio> / <|endofchunk|> token rows stay consistent across
    initializations. Keys under lang_encoder.old_decoder_blocks,
    lang_encoder.gated_cross_attn_layers, and vision_encoder are removed
    because they are already saved under lang_encoder.model...
    """
    # Walk parameters by name; won't work for fsdp + use_orig_params=False.
    for name, param in model.named_parameters():
        if "fsdp" in name:
            continue
        if "embed" in name or isinstance(param, torch.nn.Embedding):
            continue
        if param.requires_grad:
            continue
        # Frozen parameter: remove its (checkpoint-wrapper-normalized) key.
        name = name.replace("._checkpoint_wrapped_module", "")
        if name in state_dict:
            del state_dict[name]
        else:
            print(f"WARNING: filtering but {name} not in state_dict")

    # Remove keys duplicated elsewhere in the state dict.
    duplicated_keys = [
        key
        for key in state_dict.keys()
        if ("lang_encoder.old_decoder_blocks" in key)
        or ("lang_encoder.gated_cross_attn_layers" in key)
        or ("vision_encoder" in key)
    ]
    for key in duplicated_keys:
        del state_dict[key]
    return state_dict
|
| 309 |
+
|
| 310 |
+
def save_checkpoint(model, optimizer, lr_scheduler, epoch, args):
    """
    Save training checkpoint with model, optimizer, and lr_scheduler state.

    Under FSDP, state is gathered as a full (unsharded) state dict onto rank 0
    with CPU offload; only rank 0 writes the file. Older checkpoints are
    pruned when args.delete_previous_checkpoint is set, keeping every 20th
    epoch as a milestone.
    """
    if args.fsdp:
        # Gather full state dicts on rank 0 only, offloaded to CPU memory.
        FSDP.set_state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
            FullOptimStateDictConfig(rank0_only=True),
        )
        model_state = model.state_dict()
        # args.my_group is the intra-node group under hybrid sharding, else None.
        optim_state = FSDP.optim_state_dict(model, optimizer, group=args.my_group)

    else:
        model_state = model.state_dict()
        optim_state = optimizer.state_dict()

    if args.rank == 0:
        # With use_orig_params (or no FSDP) we can strip frozen parameters.
        if not (args.fsdp and not args.fsdp_use_orig_params):
            model_state = filter_state_dict_to_trainable(model, model_state)

        checkpoint_dir = os.path.join(args.expdir, args.run_name)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        checkpoint_dict = {
            "epoch": epoch,
            "model_state_dict": model_state,
            "optimizer_state_dict": optim_state,
            "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        }

        print(f"Saving checkpoint to {checkpoint_dir}/checkpoint_{epoch}.pt")
        torch.save(checkpoint_dict, f"{checkpoint_dir}/checkpoint_{epoch}.pt")

        if args.delete_previous_checkpoint:
            # Keep epoch 0 and every 20th epoch as milestones; otherwise drop
            # the previous epoch's file. Best-effort: ignore missing files.
            if epoch > 0 and epoch % 20 != 0:
                try:
                    os.remove(f"{checkpoint_dir}/checkpoint_{epoch-1}.pt")
                except:
                    pass
|
models/audio-flamingo-1/checkpoints/chat_part1.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d5673d1541cd5764d6dcc89b3bdc331b768c1159ef685a373c7f4deb9e1ddaef
|
| 3 |
+
size 3328734458
|
models/audio-flamingo-1/checkpoints/chat_part2.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ea792e60d95deacc75244af0b23c5b75b8de3aef617b392e633cd67a5f20c5aa
|
| 3 |
+
size 3482749306
|
models/audio-flamingo-1/checkpoints/chat_part3.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a8828e7e5db7014259d746025274dd752fe39f959eb6d7e1380796a838c2983c
|
| 3 |
+
size 3898925434
|
models/audio-flamingo-1/checkpoints/chat_part4.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d0445a8f663774e16df616f8e20abb5ceff61a40b2ee5a8b12a83a641971f8e1
|
| 3 |
+
size 3357325242
|
models/audio-flamingo-1/checkpoints/chat_part5.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e54e84938e4143a3d2fc8b090a1ab238287a0a41236efcc4a5df5d476291d96f
|
| 3 |
+
size 3591230906
|
models/audio-flamingo-1/checkpoints/checkpoint_utils.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
def merge_checkpoints(checkpoint_path, num_parts=5):
    """
    Reassemble a checkpoint that was split into ``*_part{i}.pt`` shards.

    Each shard holds a partial ``model_state_dict``; their entries are merged
    (later shards overwrite duplicate keys) and the result is written to
    *checkpoint_path* as ``{'model_state_dict': ...}``.

    Args:
        checkpoint_path: target path ending in ``.pt``; shards are expected at
            the same path with ``_part1.pt`` ... ``_part{num_parts}.pt`` suffixes.
        num_parts: number of shard files to merge.
    """
    combined_state_dict = {}
    for i in range(1, num_parts + 1):
        part_path = checkpoint_path.replace('.pt', '_part{}.pt'.format(i))
        # map_location="cpu" so merging works on machines without the GPUs
        # the shards were saved from (the original bare torch.load would fail
        # on CPU-only hosts if any tensor was saved on a CUDA device).
        part_checkpoint = torch.load(part_path, map_location="cpu")
        part_state_dict = part_checkpoint['model_state_dict']
        combined_state_dict.update(part_state_dict)

    full_checkpoint = {'model_state_dict': combined_state_dict}
    torch.save(full_checkpoint, checkpoint_path)
    print('merging {}: finished'.format(checkpoint_path))
|
| 18 |
+
# Script entry point: reassemble both released checkpoints from their 5
# LFS-tracked shards in this directory (runs at import time — execute this
# file from the checkpoints/ directory).
merge_checkpoints('foundation.pt', num_parts=5)
merge_checkpoints('chat.pt', num_parts=5)
|
models/audio-flamingo-1/checkpoints/foundation_part1.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b5921b6167a0e0d27a4732dc77899505db454bcb769942da3079c8b821e54711
|
| 3 |
+
size 3328736090
|
models/audio-flamingo-1/checkpoints/foundation_part2.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fa325ee55d333ff4155679af58e901096b89ffd8097092b602cdf54f8b989791
|
| 3 |
+
size 3482750938
|
models/audio-flamingo-1/checkpoints/foundation_part3.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dfec3fcfe53582e71c8ae1bc8832bbce5b15168878cf5e85a32f1c717bacdc78
|
| 3 |
+
size 3898927066
|
models/audio-flamingo-1/checkpoints/foundation_part4.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9b4b3e3011ea05da7ef29e8b906141af05a7eb2fa9604b65f21638db142a192d
|
| 3 |
+
size 3357326874
|
models/audio-flamingo-1/checkpoints/foundation_part5.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e52ea8848b540816e4b7e46ed705ec84965b3718d77b3e8a4ded0b64668fd168
|
| 3 |
+
size 3591232538
|