Anvita Pandit committed on
Commit 5f6b40b · 1 Parent(s): 5da7938

Add WhAM Gradio app with ZeroGPU support
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Project CETI
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,211 @@
  ---
- title: Wham
- emoji: 📊
- colorFrom: red
- colorTo: pink
+ title: WhAM
+ emoji: 🐋
+ colorFrom: blue
+ colorTo: indigo
  sdk: gradio
- sdk_version: 6.8.0
- python_version: '3.12'
  app_file: app.py
  pinned: false
+ hardware: zero-a10g
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # WhAM: a Whale Acoustics Model
+ [![arXiv](https://img.shields.io/badge/arXiv-2512.02206-b31b1b.svg)](https://arxiv.org/abs/2512.02206)
+ [![Model Weights](https://img.shields.io/badge/Zenodo-Model%20Weights-blue.svg)](https://doi.org/10.5281/zenodo.17633708)
+ [![Hugging Face Dataset](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DSWP%20Dataset-yellow)](https://huggingface.co/datasets/orrp/DSWP)
+ ![WhAM](assets/inference.png "WhAM")
+ WhAM is a transformer-based audio-to-audio model for synthesizing and analyzing sperm whale codas. Built on [VampNet](https://github.com/hugofloresgarcia/vampnet), WhAM uses masked acoustic token modeling to capture temporal and spectral features of whale communication. WhAM generates codas from a given audio context, enabling three core capabilities:
+
+ - Acoustic translation: style-transferring arbitrary audio prompts (e.g., human speech, noise) into the acoustic texture of sperm whale codas.
+
+ - Synthesis: generating novel "pseudocodas".
+
+ - Embedding: providing audio embeddings for downstream tasks such as social unit and spectral feature ("vowel") classification.
+
+ See our [NeurIPS 2025](https://openreview.net/pdf?id=IL1wvzOgqD) publication for more details.
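As background on the mechanism named above: masked acoustic token modeling regenerates a masked subset of discrete codec tokens over several refinement steps. The following is a minimal illustrative sketch of that loop, not WhAM's actual implementation; the `predict` callable stands in for the transformer and is an assumption of this sketch:

```python
import numpy as np

def masked_infill(tokens: np.ndarray, mask: np.ndarray, predict, steps: int = 4) -> np.ndarray:
    """Iteratively re-generate the masked positions of a token grid.

    tokens:  (codebooks, time) array of discrete codec tokens.
    mask:    boolean array of the same shape; True marks positions to re-generate.
    predict: callable mapping a token grid to (proposed_tokens, confidence) arrays.
    """
    out = tokens.copy()
    remaining = mask.copy()
    for _ in range(steps):
        if not remaining.any():
            break
        proposed, confidence = predict(out)
        # commit roughly the highest-confidence half of the still-masked positions
        k = max(1, int(remaining.sum()) // 2)
        conf = np.where(remaining, confidence, -np.inf)
        flat = np.argsort(conf, axis=None)[::-1][:k]
        idx = np.unravel_index(flat, out.shape)
        out[idx] = proposed[idx]
        remaining[idx] = False
    # any positions left after the loop are committed in one final step
    proposed, _ = predict(out)
    out[remaining] = proposed[remaining]
    return out
```

Unmasked tokens act as the "audio context" the model conditions on; only the masked region is resampled.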
+
+ ## Installation
+
+ 1. **Clone the repository:**
+ ```bash
+ git clone https://github.com/Project-CETI/wham.git
+ cd wham
+ ```
+
+ 2. **Set up the environment:**
+ ```bash
+ conda create -n wham python=3.9
+ conda activate wham
+ ```
+
+ 3. **Install dependencies:**
+ ```bash
+ # Install the wham package
+ pip install -e .
+
+ # Install VampNet
+ pip install -e ./vampnet
+
+ # Install madmom
+ pip install --no-build-isolation madmom
+
+ # Install ffmpeg
+ conda install -c conda-forge ffmpeg
+ ```
+
+ 4. **Download model weights:**
+ Download the [weights](https://zenodo.org/records/17633708) and extract them to `vampnet/models/`.
+
+ ## Generation
+
+ To run WhAM locally and prompt it in your browser:
+
+ ```bash
+ python vampnet/app.py --args.load conf/interface.yml --Interface.device cuda
+ ```
+
+ This will provide you with a Gradio link to test WhAM on inputs of your choice.
+
+ ## Training Data
+
+ ![Training](assets/training.png "Training")
+
+ You only need to follow these steps if you want to fine-tune your own version of WhAM. First, obtain the original VampNet weights by following the instructions in the [original repo](https://github.com/hugofloresgarcia/vampnet/tree/ismir-2023). Download
+ c2f.pth and codec.pth and replace the weights you previously downloaded in `vampnet/models`.
+
+ Second, obtain data:
+
+ 1. **Domain adaptation data:**
+
+ - Download audio samples from the [WMMS 'Best Of' Cut](https://whoicf2.whoi.edu/science/B/whalesounds/index.cfm). Save them under `vampnet/training_data/domain_adaptation`.
+
+ - Download audio samples from the [BirdSet Dataset](https://huggingface.co/datasets/DBD-research-group/BirdSet). Save these under the same directory.
+
+ - Finally, download all samples from the [AudioSet Dataset](https://research.google.com/audioset/ontology/index.html) with the label `Animal`, and once again save these into the same directory.
+
+ 2. **Species-specific finetuning:** Finetuning can be performed on the openly available **[Dominica Sperm Whale Project (DSWP)](https://huggingface.co/datasets/orrp/DSWP)** dataset, available on Hugging Face.
+
+ With data in hand, navigate into `vampnet` and perform domain adaptation:
+ ```bash
+ python vampnet/scripts/exp/fine_tune.py "training_data/domain_adaptation" domain_adapted && \
+   python vampnet/scripts/exp/train.py --args.load conf/generated/domain_adapted/coarse.yml && \
+   python vampnet/scripts/exp/train.py --args.load conf/generated/domain_adapted/c2f.yml
+ ```
+
+ Then fine-tune the domain-adapted model. Create the config file with the command:
+
+ ```bash
+ python vampnet/scripts/exp/fine_tune.py "training_data/species_specific_finetuning" fine-tuned
+ ```
+
+ To select which weights to use as a checkpoint, change `fine_tune_checkpoint` in `conf/generated/fine-tuned/[c2f/coarse].yml` to `./runs/domain_adaptation/[coarse/c2f]/[checkpoint]/vampnets/weights.pth`. `[checkpoint]` can be `latest` to use the last saved checkpoint from the previous run, though it is recommended to manually verify the quality of generations across several checkpoints, as overtraining can degrade audio quality, especially with smaller datasets. After making that change, run:
+
+ ```bash
+ python vampnet/scripts/exp/train.py --args.load conf/generated/fine-tuned/coarse.yml && \
+   python vampnet/scripts/exp/train.py --args.load conf/generated/fine-tuned/c2f.yml
+ ```
+
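As a small illustration of the checkpoint path pattern described above (the `checkpoint_path` helper is hypothetical and not part of the repository; it just spells out how `[coarse/c2f]` and `[checkpoint]` slot into the path):

```python
def checkpoint_path(stage: str, checkpoint: str = "latest") -> str:
    """Build a fine_tune_checkpoint value for a given stage ('coarse' or 'c2f')."""
    if stage not in ("coarse", "c2f"):
        raise ValueError(f"unknown stage: {stage}")
    return f"./runs/domain_adaptation/{stage}/{checkpoint}/vampnets/weights.pth"
```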
+ After following these steps, you should be able to generate audio via the browser by running:
+ ```bash
+ python app.py --args.load vampnet/conf/generated/fine-tuned/interface.yml
+ ```
+
+ **Note**: The coarse and fine weights can be trained separately if compute allows. In this case, you would call the two scripts:
+
+ ```bash
+ python vampnet/scripts/exp/train.py --args.load conf/generated/[fine-tuned/domain_adapted]/coarse.yml
+ ```
+
+ ```bash
+ python vampnet/scripts/exp/train.py --args.load conf/generated/[fine-tuned/domain_adapted]/c2f.yml
+ ```
+
+ After both runs finish, ensure that both resulting weights are copied into the same copy of WhAM.
+
+ ## Testing Data
+
+ 1. **Marine Mammal Data:**
+ Download audio samples from the [WMMS 'Best Of' Cut](https://whoicf2.whoi.edu/science/B/whalesounds/index.cfm). Save them under `data/testing_data/marine_mammals/data/[SPECIES_NAME]`.
+ * `[SPECIES_NAME]` must match the species names found in `wham/generation/prompt_configs.py`.
+
+ 2. **Sperm Whale Codas:**
+ To evaluate on sperm whale codas, you can use the openly available [DSWP](https://huggingface.co/datasets/orrp/DSWP) dataset.
+
+ 3. **Artificial Beeps:**
+ Generate artificial beeps for experiments with `data/generate_beeps.sh`.
+
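`data/generate_beeps.sh` is not reproduced here; as a rough sketch of what such a synthetic test tone looks like (the frequency, duration, and sample rate below are arbitrary illustrative choices, not the script's actual parameters):

```python
import numpy as np

def make_beep(freq_hz: float = 1000.0, duration_s: float = 0.5, sr: int = 44100) -> np.ndarray:
    """Synthesize a pure-tone 'beep' with a short linear fade in/out to avoid clicks."""
    t = np.arange(int(duration_s * sr)) / sr
    tone = np.sin(2 * np.pi * freq_hz * t).astype(np.float32)
    fade = min(len(tone) // 10, int(0.01 * sr))  # up to 10 ms fade on each end
    ramp = np.linspace(0.0, 1.0, fade, dtype=np.float32)
    tone[:fade] *= ramp
    tone[-fade:] *= ramp[::-1]
    return tone
```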
+
+ ## Reproducing Paper Results
+ Note: Access to the DSWP+CETI annotated data is required to reproduce all results; as of the time of publication, only part of this data is publicly available. Still, we include the following code, as researchers may benefit from our evaluation pipeline.
+
+ ### 1. Downstream Classification Tasks
+ To reproduce **Table 1** (Classification Accuracies) and **Figure 7** (Ablation Study):
+
+ **Table 1 Results:**
+ ```bash
+ cd wham/embedding
+ ./downstream_tasks.sh
+ ```
+ * Runs all downstream classification tasks.
+ * **Baselines:** Run once.
+ * **Models (AVES, VampNet):** Run over 3 random seeds; reports mean and standard deviation.
+
+ **Figure 7 Results (Ablation):**
+ ```bash
+ cd wham/embedding
+ ./downstream_ablation.sh
+ ```
+ * Outputs accuracy scores for ablation variants (averaged across 3 seeds with error bars).
+
+ ### 2. Generative Metrics
+
+ **Figure 12: Fréchet Audio Distance (FAD) Scores**
+ Calculate the distance between WhAM's generated results and real codas:
+ ```bash
+ # Calculate for all species
+ bash wham/generation/eval/calculate_FAD.sh
+
+ # Calculate for a single species
+ bash wham/generation/eval/calculate_FAD.sh [species_name]
+ ```
+ * *Runtime:* ~3 hours on an NVIDIA A10 GPU.
+
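For background, FAD is the Fréchet distance between Gaussians fitted to two sets of embeddings (the evaluation scripts rely on the `fadtk` dependency for this). The following NumPy sketch only illustrates the quantity being computed:

```python
import numpy as np

def frechet_audio_distance(emb_a: np.ndarray, emb_b: np.ndarray) -> float:
    """FAD between two embedding sets (rows = clips):
        ||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^{1/2})
    where mu, S are the mean and covariance of each set.
    """
    mu_a, mu_b = emb_a.mean(0), emb_b.mean(0)
    s_a = np.cov(emb_a, rowvar=False)
    s_b = np.cov(emb_b, rowvar=False)
    # symmetric square root of s_a via eigendecomposition
    vals, vecs = np.linalg.eigh(s_a)
    sqrt_a = (vecs * np.sqrt(np.clip(vals, 0, None))) @ vecs.T
    # Tr((s_a s_b)^{1/2}) equals Tr((s_a^{1/2} s_b s_a^{1/2})^{1/2}) (similar matrices)
    inner_vals = np.linalg.eigvalsh(sqrt_a @ s_b @ sqrt_a)
    tr_sqrt = np.sqrt(np.clip(inner_vals, 0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(s_a) + np.trace(s_b) - 2 * tr_sqrt)
```

Identical distributions yield a distance of (numerically) zero; shifting one set only moves the mean term.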
+ **Figure 3: FAD with Custom/BirdNET Embeddings**
+ To compare against other embeddings:
+ 1. Convert your `.wav` files to `.npy` embeddings.
+ 2. Place raw coda embeddings in: `data/testing_data/coda_embeddings`
+ 3. Place comparison embeddings in subfolders within: `data/testing_data/comparison_embeddings`
+ 4. Run:
+ ```bash
+ python wham/generation/eval/calculate_custom_fad.py
+ ```
+ *For BirdNET embeddings, refer to the [official repo](https://github.com/BirdNET-Team/BirdNET-Analyzer).*
+
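Step 1 above depends on whichever embedding model you compare against; the sketch below only illustrates the expected `.wav`-to-`.npy` layout (the log-spectrum "embedding" is a placeholder for illustration, not one of the models evaluated in the paper):

```python
import wave
from pathlib import Path

import numpy as np

def embed_wav(wav_path: str) -> np.ndarray:
    """Placeholder embedding: log-magnitude spectrum of a 16-bit WAV file."""
    with wave.open(wav_path, "rb") as w:
        frames = w.readframes(w.getnframes())
        audio = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
        if w.getnchannels() > 1:          # mix multi-channel audio down to mono
            audio = audio.reshape(-1, w.getnchannels()).mean(axis=1)
    spec = np.abs(np.fft.rfft(audio))
    return np.log1p(spec).astype(np.float32)

def convert_dir(src: str, dst: str) -> None:
    """Write one .npy per .wav, mirroring the per-folder layout described above."""
    out = Path(dst)
    out.mkdir(parents=True, exist_ok=True)
    for wav_file in sorted(Path(src).glob("*.wav")):
        np.save(out / (wav_file.stem + ".npy"), embed_wav(str(wav_file)))
```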
+ **Table 2: Embedding Type Ablation**
+ Calculate distances between raw codas, denoised versions, and noise profiles:
+ ```bash
+ bash wham/generation/eval/FAD_ablation.sh
+ ```
+ * *Prerequisites:* Ensure `data/testing_data/ablation/noise` and `data/testing_data/ablation/denoised` are populated.
+ * *Runtime:* ~1.5 hours on an NVIDIA A10 GPU.
+
+ **Figure 13: Tokenizer Reconstruction**
+ Compute the tokenizer's mean squared reconstruction error:
+ ```bash
+ bash wham/generation/eval/evaluate_tokenizer.sh
+ ```
+
+ ---
+
+ ## Citation
+
+ Please use the following citation if you use this code, model, or data.
+
+ ```bibtex
+ @inproceedings{wham2025,
+ title={Towards A Translative Model of Sperm Whale Vocalization},
+ author={Orr Paradise and Pranav Muralikrishnan and Liangyuan Chen and Hugo Flores Garcia and Bryan Pardo and Roee Diamant and David F. Gruber and Shane Gero and Shafi Goldwasser},
+ booktitle={Advances in Neural Information Processing Systems 39: Annual Conference
+ on Neural Information Processing Systems 2025, NeurIPS 2025, San Diego, CA, USA},
+ year={2025}
+ }
+ ```
app.py ADDED
@@ -0,0 +1,38 @@
+ import os
+ import sys
+
+ from huggingface_hub import hf_hub_download
+
+ # Fetch the model weights into vampnet/models/ on first launch.
+ REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
+ MODEL_DIR = os.path.join(REPO_ROOT, "vampnet", "models")
+ os.makedirs(MODEL_DIR, exist_ok=True)
+
+ MODEL_REPO = "anvitax/wham-weights"
+ WEIGHT_FILES = ["coarse.pth", "c2f.pth", "codec.pth", "wavebeat.pth"]
+
+ for fname in WEIGHT_FILES:
+     target = os.path.join(MODEL_DIR, fname)
+     if not os.path.exists(target):
+         print(f"Downloading {fname} from {MODEL_REPO}...")
+         hf_hub_download(repo_id=MODEL_REPO, filename=fname, local_dir=MODEL_DIR)
+     else:
+         print(f"Found {fname}")
+
+ sys.path.insert(0, os.path.join(REPO_ROOT, "vampnet"))
+
+ os.chdir(os.path.join(REPO_ROOT, "vampnet"))
+
+ # On ZeroGPU Spaces the model must start on CPU; the @spaces.GPU decorator in
+ # vampnet/app.py moves work to the GPU per request. Elsewhere, use CUDA if available.
+ try:
+     import spaces
+     device = "cpu"
+ except ImportError:
+     import torch
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ sys.argv = [
+     "app.py",
+     "--args.load", "conf/interface.yml",
+     "--Interface.device", device,
+ ]
+
+ exec(open("app.py").read())  # runs vampnet/app.py (cwd is now vampnet/)
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ torch
+ gradio
+ argbind>=0.3.2
+ numpy<1.24
+ pydantic==2.10.6
+ huggingface_hub
+ loralib
+ torch_pitch_shift
+ soundfile
+ pydub
+ tqdm
+ Cython
+ wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
+ lac @ git+https://github.com/hugofloresgarcia/lac.git
+ descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
+ pyharp
setup.py ADDED
@@ -0,0 +1,33 @@
+ from setuptools import setup, find_packages
+
+ with open("README.md") as f:
+     long_description = f.read()
+
+ setup(
+     name="wham",
+     version="0.0.1",
+     long_description=long_description,
+     long_description_content_type="text/markdown",
+     url="https://github.com/orrp/wam",
+     license="MIT",
+     packages=find_packages(),
+     package_dir={},
+     install_requires=[
+         "descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git",
+         "argbind",
+         "pandas",
+         "pathlib",
+         "pydub",
+         "ffmpeg-python",
+         "tqdm",
+         "scikit-learn",
+         "wandb",
+         "gdown",  # For fetching large files from Google Drive
+         "soundfile",
+         "transformers",
+         "torch",
+         "Cython",
+         "fadtk",
+         "urllib3==2.0",
+     ],
+ )
vampnet/.gitignore ADDED
@@ -0,0 +1,189 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/env.sh
+ venv/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # Files created by experiments
+ output/
+ snapshot/
+ *.m4a
+ notebooks/scratch.ipynb
+ notebooks/inspect.ipynb
+ notebooks/effects.ipynb
+ notebooks/*.ipynb
+ notebooks/*.gif
+ notebooks/*.wav
+ notebooks/*.mp4
+ *runs/
+ boards/
+ samples/
+ *.ipynb
+
+ results.json
+ metrics.csv
+ mprofile_*
+ mem.png
+
+ results/
+ mprofile*
+ *.png
+ # do not ignore the test wav file
+ !tests/audio/short_test_audio.wav
+ !tests/audio/output.wav
+ */.DS_Store
+ .DS_Store
+ env.sh
+ _codebraid/
+ **/*.html
+ **/*.exec.md
+ flagged/
+ log.txt
+ ckpt/
+ .syncthing*
+ tests/assets/
+ archived/
+
+ scratch/
+
+ runs-archive
+ lyrebird-audiotools
+ lyrebird-audio-codec
+ samples-*/**
+
+ gradio-outputs/
+ models/
+ samples*/
+ models-all/
+ models.zip
+ .git-old
+ conf/generated/*
+ runs*/
+
+
+ gtzan.zip
+ .gtzan_emb_cache
+ runs
vampnet/.pre-commit-config.yaml ADDED
@@ -0,0 +1,15 @@
+ repos:
+ - repo: https://github.com/asottile/reorder_python_imports
+   rev: v2.5.0
+   hooks:
+   - id: reorder-python-imports
+ - repo: https://github.com/psf/black
+   rev: 23.1.0
+   hooks:
+   - id: black
+     language_version: python3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+   rev: v4.0.1
+   hooks:
+   - id: end-of-file-fixer
+   - id: trailing-whitespace
@@ -0,0 +1,21 @@
 
+ MIT License
+
+ Copyright (c) 2023 Hugo Flores García and Prem Seetharaman
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
vampnet/README.md ADDED
@@ -0,0 +1,94 @@
+ # VampNet
+
+ This repository contains recipes for training generative music models on top of the Descript Audio Codec.
+
+ ## try `unloop`
+ you can try vampnet in a co-creative looper called unloop. see this link: https://github.com/hugofloresgarcia/unloop
+
+ # Setting up
+
+ **Requires Python 3.9**.
+
+ you'll need a Python 3.9 environment to run VampNet. This is due to a [known issue with madmom](https://github.com/hugofloresgarcia/vampnet/issues/15).
+
+ (for example, using conda)
+ ```bash
+ conda create -n vampnet python=3.9
+ conda activate vampnet
+ ```
+
+ install VampNet
+
+ ```bash
+ git clone https://github.com/hugofloresgarcia/vampnet.git
+ pip install -e ./vampnet
+ ```
+
+ ## A note on argbind
+ This repository relies on [argbind](https://github.com/pseeth/argbind) to manage CLIs and config files.
+ Config files are stored in the `conf/` folder.
+
+ ## Getting the Pretrained Models
+
+ ### Licensing for Pretrained Models:
+ The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml).
+
+ Download the pretrained models from [this link](https://zenodo.org/record/8136629). Then, extract the models to the `models/` folder.
+
+ # Usage
+
+ ## Launching the Gradio Interface
+ You can launch a gradio UI to play with vampnet.
+
+ ```bash
+ python app.py --args.load conf/interface.yml --Interface.device cuda
+ ```
+
+ # Training / Fine-tuning
+
+ ## Training a model
+
+ To train a model, run the following script:
+
+ ```bash
+ python scripts/exp/train.py --args.load conf/vampnet.yml --save_path /path/to/checkpoints
+ ```
+
+ You can edit `conf/vampnet.yml` to change the dataset paths or any training hyperparameters.
+
+ For coarse2fine models, you can use `conf/c2f.yml` as a starting configuration.
+
+ See `python scripts/exp/train.py -h` for a list of options.
+
+ ## Fine-tuning
+ To fine-tune a model, use the script in `scripts/exp/fine_tune.py` to generate 3 configuration files: `c2f.yml`, `coarse.yml`, and `interface.yml`.
+ The first two are used to fine-tune the coarse and fine models, respectively. The last one is used to launch the gradio interface.
+
+ ```bash
+ python scripts/exp/fine_tune.py "/path/to/audio1.mp3 /path/to/audio2/ /path/to/audio3.wav" <fine_tune_name>
+ ```
+
+ This will create a folder under `conf/<fine_tune_name>/` with the 3 configuration files.
+
+ The save_paths will be set to `runs/<fine_tune_name>/coarse` and `runs/<fine_tune_name>/c2f`.
+
+ launch the coarse job:
+ ```bash
+ python scripts/exp/train.py --args.load conf/<fine_tune_name>/coarse.yml
+ ```
+
+ this will save the coarse model to `runs/<fine_tune_name>/coarse/ckpt/best/`.
+
+ launch the c2f job:
+ ```bash
+ python scripts/exp/train.py --args.load conf/<fine_tune_name>/c2f.yml
+ ```
+
+ launch the interface:
+ ```bash
+ python app.py --args.load conf/generated/<fine_tune_name>/interface.yml
+ ```
+
@@ -0,0 +1,677 @@
 
+ from pathlib import Path
+ from typing import Tuple
+ import yaml
+ import tempfile
+ import uuid
+ from dataclasses import dataclass, asdict
+
+ import numpy as np
+ import audiotools as at
+ import argbind
+
+ import gradio as gr
+ from vampnet.interface import Interface
+ from vampnet import mask as pmask
+
+ try:
+     import spaces
+     ZERO_GPU = True
+ except ImportError:
+     ZERO_GPU = False
+
+ def gpu(fn):
+     if ZERO_GPU:
+         return spaces.GPU(fn)
+     return fn
+
+ Interface = argbind.bind(Interface)
+ # AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
+
+ conf = argbind.parse_args()
+
+
+ from torch_pitch_shift import pitch_shift, get_fast_shifts
+ def shift_pitch(signal, interval: int):
+     signal.samples = pitch_shift(
+         signal.samples,
+         shift=interval,
+         sample_rate=signal.sample_rate
+     )
+     return signal
+
+ def load_interface():
+     with argbind.scope(conf):
+         interface = Interface()
+         # loader = AudioLoader()
+         print(f"interface device is {interface.device}")
+         return interface
+
+
+ interface = load_interface()
+
+
+ OUT_DIR = Path("gradio-outputs")
+ OUT_DIR.mkdir(exist_ok=True, parents=True)
+
+
+ def load_audio(file):
+     print(file)
+     filepath = file.name
+     sig = at.AudioSignal.salient_excerpt(
+         filepath,
+         duration=interface.coarse.chunk_size_s
+     )
+     sig = interface.preprocess(sig)
+
+     out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
+     out_dir.mkdir(parents=True, exist_ok=True)
+     sig.write(out_dir / "input.wav")
+     return sig.path_to_file
+
+
+ def load_example_audio():
+     return "./assets/example.wav"
+
+
+ @gpu
+ def _vamp(data, return_mask=False):
+     interface.to("cuda")
+
+     out_dir = OUT_DIR / str(uuid.uuid4())
+     out_dir.mkdir()
+     sig = at.AudioSignal(data[input_audio])
+     sig = interface.preprocess(sig)
+
+     loudness = sig.loudness()
+     print(f"input loudness is {loudness}")
+
+     if data[pitch_shift_amt] != 0:
+         sig = shift_pitch(sig, data[pitch_shift_amt])
+
+     z = interface.encode(sig)
+
+     ncc = data[n_conditioning_codebooks]
+
+     # build the mask
+     mask = pmask.linear_random(z, data[rand_mask_intensity])
+     mask = pmask.mask_and(
+         mask, pmask.inpaint(
+             z,
+             interface.s2t(data[prefix_s]),
+             interface.s2t(data[suffix_s])
+         )
+     )
+     mask = pmask.mask_and(
+         mask, pmask.periodic_mask(
+             z,
+             data[periodic_p],
+             data[periodic_w],
+             random_roll=True
+         )
+     )
+     if data[onset_mask_width] > 0:
+         mask = pmask.mask_or(
+             mask, pmask.onset_mask(sig, z, interface, width=data[onset_mask_width])
+         )
+     if data[beat_mask_width] > 0:
+         beat_mask = interface.make_beat_mask(
+             sig,
+             after_beat_s=(data[beat_mask_width]/1000),
+             mask_upbeats=not data[beat_mask_downbeats],
+         )
+         mask = pmask.mask_and(mask, beat_mask)
+
+     # these should be the last two mask ops
+     mask = pmask.dropout(mask, data[dropout])
+     mask = pmask.codebook_unmask(mask, ncc)
+     mask = pmask.codebook_mask(mask, int(data[n_mask_codebooks]))
+
+     print(f"dropout {data[dropout]}")
+     print(f"masktemp {data[masktemp]}")
+     print(f"sampletemp {data[sampletemp]}")
+     print(f"top_p {data[top_p]}")
+     print(f"prefix_s {data[prefix_s]}")
+     print(f"suffix_s {data[suffix_s]}")
+     print(f"rand_mask_intensity {data[rand_mask_intensity]}")
+     print(f"num_steps {data[num_steps]}")
+     print(f"periodic_p {data[periodic_p]}")
+     print(f"periodic_w {data[periodic_w]}")
+     print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
+     print(f"use_coarse2fine {data[use_coarse2fine]}")
+     print(f"onset_mask_width {data[onset_mask_width]}")
+     print(f"beat_mask_width {data[beat_mask_width]}")
+     print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
+     print(f"stretch_factor {data[stretch_factor]}")
+     print(f"seed {data[seed]}")
+     print(f"pitch_shift_amt {data[pitch_shift_amt]}")
+     print(f"sample_cutoff {data[sample_cutoff]}")
+
+     _top_p = data[top_p] if data[top_p] > 0 else None
+     # save the mask as a txt file
+     np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
+
+     _seed = data[seed] if data[seed] > 0 else None
+     zv, mask_z = interface.coarse_vamp(
+         z,
+         mask=mask,
+         sampling_steps=data[num_steps],
+         mask_temperature=data[masktemp]*10,
+         sampling_temperature=data[sampletemp],
+         return_mask=True,
+         typical_filtering=data[typical_filtering],
+         typical_mass=data[typical_mass],
+         typical_min_tokens=data[typical_min_tokens],
+         top_p=_top_p,
+         gen_fn=interface.coarse.generate,
+         seed=_seed,
+         sample_cutoff=data[sample_cutoff],
+     )
+
+     if data[use_coarse2fine]:
+         zv = interface.coarse_to_fine(
+             zv,
+             mask_temperature=data[masktemp]*10,
+             sampling_temperature=data[sampletemp],
+             mask=mask,
+             sampling_steps=data[num_steps],
+             sample_cutoff=data[sample_cutoff],
+             seed=_seed,
+         )
+
+     sig = interface.to_signal(zv).cpu()
+     print("done")
+
+     print(f"output loudness is {sig.loudness()}")
+     sig = sig.normalize(loudness)
+     print(f"normalized loudness is {sig.loudness()}")
+
+     sig.write(out_dir / "output.wav")
+
+     if return_mask:
+         mask = interface.to_signal(mask_z).cpu()
+         mask.write(out_dir / "mask.wav")
+         return sig.path_to_file, mask.path_to_file
+     else:
+         return sig.path_to_file
+
+ def vamp(data):
+     return _vamp(data, return_mask=True)
+
+ def api_vamp(data):
+     return _vamp(data, return_mask=False)
+
+ def save_vamp(data):
+     out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     sig_in = at.AudioSignal(data[input_audio])
+     sig_out = at.AudioSignal(data[output_audio])
+
+     sig_in.write(out_dir / "input.wav")
+     sig_out.write(out_dir / "output.wav")
+
+     _data = {
+         "masktemp": data[masktemp],
+         "sampletemp": data[sampletemp],
+         "top_p": data[top_p],
+         "prefix_s": data[prefix_s],
+         "suffix_s": data[suffix_s],
+         "rand_mask_intensity": data[rand_mask_intensity],
+         "num_steps": data[num_steps],
+         "notes": data[notes_text],
+         "periodic_period": data[periodic_p],
+         "periodic_width": data[periodic_w],
+         "n_conditioning_codebooks": data[n_conditioning_codebooks],
+         "use_coarse2fine": data[use_coarse2fine],
+         "stretch_factor": data[stretch_factor],
+         "seed": data[seed],
+         "samplecutoff": data[sample_cutoff],
+     }
+
+     # save with yaml
+     with open(out_dir / "data.yaml", "w") as f:
+         yaml.dump(_data, f)
+
+     import zipfile
+     zip_path = str(out_dir.with_suffix(".zip"))
+     with zipfile.ZipFile(zip_path, "w") as zf:
+         for file in out_dir.iterdir():
+             zf.write(file, file.name)
+
+     return f"saved! your save code is {out_dir.stem}", zip_path
+
+
+ @gpu
+ def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
+     interface.to("cuda")
+
+     out_dir = OUT_DIR / str(uuid.uuid4())
+     out_dir.mkdir()
+     sig = at.AudioSignal(_input_audio)
+     sig = interface.preprocess(sig)
+
+     z = interface.encode(sig)
+
+     # build the mask
+     mask = pmask.linear_random(z, 1.0)
+     if _beat_mask_width > 0:
+         beat_mask = interface.make_beat_mask(
+             sig,
+             after_beat_s=(_beat_mask_width/1000),
+         )
+         mask = pmask.mask_and(mask, beat_mask)
+
+     zv, mask_z = interface.coarse_vamp(
+         z,
+         mask=mask,
+         sampling_temperature=_sampletemp,
+         return_mask=True,
+         gen_fn=interface.coarse.generate,
+     )
+
+     zv = interface.coarse_to_fine(
+         zv,
+         sampling_temperature=_sampletemp,
+         mask=mask,
+     )
+
+     sig = interface.to_signal(zv).cpu()
+     print("done")
+
+     sig.write(out_dir / "output.wav")
+
+     return sig.path_to_file
+
+ with gr.Blocks() as demo:
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("# VampNet Audio Vamping")
+             gr.Markdown("""## Description:
+ This is a demo of VampNet, a generative audio model that transforms the input audio based on the chosen settings.
+ You can control the extent and nature of variation with a set of manual controls and presets.
+ Use this interface to experiment with different mask settings and explore the audio outputs.
+ """)
+
+             gr.Markdown("""
+ ## Instructions:
+ 1. You can start by uploading some audio, or by loading the example audio.
+ 2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
+ 3. Click the "generate (vamp)!!!" button to apply the vamp operation. Listen to the output audio.
308
+ 4. Optionally, you can add some notes and save the result.
309
+ 5. You can also use the output as the new input and continue experimenting!
310
+ """)
311
+ with gr.Row():
312
+ with gr.Column():
313
+
314
+
315
+ manual_audio_upload = gr.File(
316
+ label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
317
+ file_types=["audio"]
318
+ )
319
+ load_example_audio_button = gr.Button("or load example audio")
320
+
321
+ input_audio = gr.Audio(
322
+ label="input audio",
323
+ interactive=False,
324
+ type="filepath",
325
+ )
326
+
327
+ audio_mask = gr.Audio(
328
+ label="audio mask (listen to this to hear the mask hints)",
329
+ interactive=False,
330
+ type="filepath",
331
+ )
332
+
333
+ # connect widgets
334
+ load_example_audio_button.click(
335
+ fn=load_example_audio,
336
+ inputs=[],
337
+ outputs=[ input_audio]
338
+ )
339
+
340
+ manual_audio_upload.change(
341
+ fn=load_audio,
342
+ inputs=[manual_audio_upload],
343
+ outputs=[ input_audio]
344
+ )
345
+
346
+ # mask settings
347
+ with gr.Column():
348
+
349
+
350
+ presets = {
351
+ "unconditional": {
352
+ "periodic_p": 0,
353
+ "onset_mask_width": 0,
354
+ "beat_mask_width": 0,
355
+ "beat_mask_downbeats": False,
356
+ },
357
+ "slight periodic variation": {
358
+ "periodic_p": 5,
359
+ "onset_mask_width": 5,
360
+ "beat_mask_width": 0,
361
+ "beat_mask_downbeats": False,
362
+ },
363
+ "moderate periodic variation": {
364
+ "periodic_p": 13,
365
+ "onset_mask_width": 5,
366
+ "beat_mask_width": 0,
367
+ "beat_mask_downbeats": False,
368
+ },
369
+ "strong periodic variation": {
370
+ "periodic_p": 17,
371
+ "onset_mask_width": 5,
372
+ "beat_mask_width": 0,
373
+ "beat_mask_downbeats": False,
374
+ },
375
+ "very strong periodic variation": {
376
+ "periodic_p": 21,
377
+ "onset_mask_width": 5,
378
+ "beat_mask_width": 0,
379
+ "beat_mask_downbeats": False,
380
+ },
381
+ "beat-driven variation": {
382
+ "periodic_p": 0,
383
+ "onset_mask_width": 0,
384
+ "beat_mask_width": 50,
385
+ "beat_mask_downbeats": False,
386
+ },
387
+ "beat-driven variation (downbeats only)": {
388
+ "periodic_p": 0,
389
+ "onset_mask_width": 0,
390
+ "beat_mask_width": 50,
391
+ "beat_mask_downbeats": True,
392
+ },
393
+ "beat-driven variation (downbeats only, strong)": {
394
+ "periodic_p": 0,
395
+ "onset_mask_width": 0,
396
+ "beat_mask_width": 20,
397
+ "beat_mask_downbeats": True,
398
+ },
399
+ }
400
+
401
+ preset = gr.Dropdown(
402
+ label="preset",
403
+ choices=list(presets.keys()),
404
+ value="strong periodic variation",
405
+ )
406
+ load_preset_button = gr.Button("load_preset")
407
+
408
+ with gr.Accordion("manual controls", open=True):
409
+ periodic_p = gr.Slider(
410
+ label="periodic prompt (0 - unconditional, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
411
+ minimum=0,
412
+ maximum=128,
413
+ step=1,
414
+ value=3,
415
+ )
416
+
417
+
418
+ onset_mask_width = gr.Slider(
419
+ label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
420
+ minimum=0,
421
+ maximum=100,
422
+ step=1,
423
+ value=5,
424
+ )
425
+
426
+ beat_mask_width = gr.Slider(
427
+ label="beat prompt (ms)",
428
+ minimum=0,
429
+ maximum=200,
430
+ value=0,
431
+ )
432
+ beat_mask_downbeats = gr.Checkbox(
433
+ label="beat mask downbeats only?",
434
+ value=False
435
+ )
436
+
437
+ n_mask_codebooks = gr.Number(
438
+ label="first upper codebook level to mask",
439
+ value=9,
440
+ )
441
+
442
+
443
+ with gr.Accordion("extras ", open=False):
444
+ pitch_shift_amt = gr.Slider(
445
+ label="pitch shift amount (semitones)",
446
+ minimum=-12,
447
+ maximum=12,
448
+ step=1,
449
+ value=0,
450
+ )
451
+
452
+ rand_mask_intensity = gr.Slider(
453
+ label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
454
+ minimum=0.0,
455
+ maximum=1.0,
456
+ value=1.0
457
+ )
458
+
459
+ periodic_w = gr.Slider(
460
+ label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
461
+ minimum=1,
462
+ maximum=20,
463
+ step=1,
464
+ value=1,
465
+ )
466
+ n_conditioning_codebooks = gr.Number(
467
+ label="number of conditioning codebooks. probably 0",
468
+ value=0,
469
+ precision=0,
470
+ )
471
+
472
+ stretch_factor = gr.Slider(
473
+ label="time stretch factor",
474
+ minimum=0,
475
+ maximum=64,
476
+ step=1,
477
+ value=1,
478
+ )
479
+
480
+ preset_outputs = {
481
+ periodic_p,
482
+ onset_mask_width,
483
+ beat_mask_width,
484
+ beat_mask_downbeats,
485
+ }
486
+
487
+ def load_preset(_preset):
488
+ return tuple(presets[_preset].values())
489
+
490
+ load_preset_button.click(
491
+ fn=load_preset,
492
+ inputs=[preset],
493
+ outputs=preset_outputs
494
+ )
495
+
496
+
497
+ with gr.Accordion("prefix/suffix prompts", open=False):
498
+ prefix_s = gr.Slider(
499
+ label="prefix hint length (seconds)",
500
+ minimum=0.0,
501
+ maximum=10.0,
502
+ value=0.0
503
+ )
504
+ suffix_s = gr.Slider(
505
+ label="suffix hint length (seconds)",
506
+ minimum=0.0,
507
+ maximum=10.0,
508
+ value=0.0
509
+ )
510
+
511
+ masktemp = gr.Slider(
512
+ label="mask temperature",
513
+ minimum=0.0,
514
+ maximum=100.0,
515
+ value=1.5
516
+ )
517
+ sampletemp = gr.Slider(
518
+ label="sample temperature",
519
+ minimum=0.1,
520
+ maximum=10.0,
521
+ value=1.0,
522
+ step=0.001
523
+ )
524
+
525
+
526
+
527
+ with gr.Accordion("sampling settings", open=False):
528
+ top_p = gr.Slider(
529
+ label="top p (0.0 = off)",
530
+ minimum=0.0,
531
+ maximum=1.0,
532
+ value=0.0
533
+ )
534
+ typical_filtering = gr.Checkbox(
535
+ label="typical filtering ",
536
+ value=False
537
+ )
538
+ typical_mass = gr.Slider(
539
+ label="typical mass (should probably stay between 0.1 and 0.5)",
540
+ minimum=0.01,
541
+ maximum=0.99,
542
+ value=0.15
543
+ )
544
+ typical_min_tokens = gr.Slider(
545
+ label="typical min tokens (should probably stay between 1 and 256)",
546
+ minimum=1,
547
+ maximum=256,
548
+ step=1,
549
+ value=64
550
+ )
551
+ sample_cutoff = gr.Slider(
552
+ label="sample cutoff",
553
+ minimum=0.0,
554
+ maximum=1.0,
555
+ value=0.5,
556
+ step=0.01
557
+ )
558
+
559
+ use_coarse2fine = gr.Checkbox(
560
+ label="use coarse2fine",
561
+ value=True,
562
+ visible=False
563
+ )
564
+
565
+ num_steps = gr.Slider(
566
+ label="number of steps (should normally be between 12 and 36)",
567
+ minimum=1,
568
+ maximum=128,
569
+ step=1,
570
+ value=36
571
+ )
572
+
573
+ dropout = gr.Slider(
574
+ label="mask dropout",
575
+ minimum=0.0,
576
+ maximum=1.0,
577
+ step=0.01,
578
+ value=0.0
579
+ )
580
+
581
+
582
+ seed = gr.Number(
583
+ label="seed (0 for random)",
584
+ value=0,
585
+ precision=0,
586
+ )
587
+
588
+
589
+
590
+ # mask settings
591
+ with gr.Column():
592
+
593
+ # lora_choice = gr.Dropdown(
594
+ # label="lora choice",
595
+ # choices=list(loras.keys()),
596
+ # value=LORA_NONE,
597
+ # visible=False
598
+ # )
599
+
600
+ vamp_button = gr.Button("generate (vamp)!!!")
601
+ output_audio = gr.Audio(
602
+ label="output audio",
603
+ interactive=False,
604
+ type="filepath"
605
+ )
606
+
607
+ notes_text = gr.Textbox(
608
+ label="type any notes about the generated audio here",
609
+ value="",
610
+ interactive=True
611
+ )
612
+ save_button = gr.Button("save vamp")
613
+ download_file = gr.File(
614
+ label="vamp to download will appear here",
615
+ interactive=False
616
+ )
617
+ use_as_input_button = gr.Button("use output as input")
618
+
619
+ thank_you = gr.Markdown("")
620
+
621
+
622
+ _inputs = {
623
+ input_audio,
624
+ num_steps,
625
+ masktemp,
626
+ sampletemp,
627
+ top_p,
628
+ prefix_s, suffix_s,
629
+ rand_mask_intensity,
630
+ periodic_p, periodic_w,
631
+ n_conditioning_codebooks,
632
+ dropout,
633
+ use_coarse2fine,
634
+ stretch_factor,
635
+ onset_mask_width,
636
+ typical_filtering,
637
+ typical_mass,
638
+ typical_min_tokens,
639
+ beat_mask_width,
640
+ beat_mask_downbeats,
641
+ seed,
642
+ # lora_choice,
643
+ n_mask_codebooks,
644
+ pitch_shift_amt,
645
+ sample_cutoff
646
+ }
647
+
648
+ # connect widgets
649
+ vamp_button.click(
650
+ fn=vamp,
651
+ inputs=_inputs,
652
+ outputs=[output_audio, audio_mask],
653
+ )
654
+
655
+ api_vamp_button = gr.Button("api vamp", visible=False)
656
+ api_vamp_button.click(
657
+ fn=api_vamp,
658
+ inputs=_inputs,
659
+ outputs=[output_audio],
660
+ api_name="vamp"
661
+ )
662
+
663
+ use_as_input_button.click(
664
+ fn=lambda x: x,
665
+ inputs=[output_audio],
666
+ outputs=[input_audio]
667
+ )
668
+
669
+ save_button.click(
670
+ fn=save_vamp,
671
+ inputs=_inputs | {notes_text, output_audio},
672
+ outputs=[thank_you, download_file]
673
+ )
674
+
675
+
676
+ # demo.launch(share=True, enable_queue=True, debug=True)
677
+ demo.launch(share=True, debug=True) # from wam: because enable_queue seems to not be supported?
vampnet/conf/c2f.yml ADDED
@@ -0,0 +1,14 @@
+ $include:
+ - conf/vampnet.yml
+
+ VampNet.n_codebooks: 14
+ VampNet.n_conditioning_codebooks: 4
+
+ VampNet.embedding_dim: 1280
+ VampNet.n_layers: 16
+ VampNet.n_heads: 20
+
+ AudioDataset.duration: 3.0
+
+ AudioDataset.loudness_cutoff: -40.0
vampnet/conf/interface.yml ADDED
@@ -0,0 +1,10 @@
+ Interface.coarse_ckpt: ./models/coarse.pth
+ Interface.coarse2fine_ckpt: ./models/c2f.pth
+ Interface.codec_ckpt: ./models/codec.pth
+ Interface.coarse_chunk_size_s: 10
+ Interface.coarse2fine_chunk_size_s: 3
+ Interface.wavebeat_ckpt: ./models/wavebeat.pth
+
+ # AudioLoader.sources:
+ # - /media/CHONK/null
vampnet/conf/lora/lora.yml ADDED
@@ -0,0 +1,22 @@
+ $include:
+ - conf/vampnet.yml
+
+ fine_tune: True
+
+ train/AudioDataset.n_examples: 100000000
+ val/AudioDataset.n_examples: 500
+
+ NoamScheduler.warmup: 500
+
+ batch_size: 6
+ num_workers: 7
+ save_iters: [2000, 4000, 10000, 20000, 40000, 100000]
+ sample_freq: 2000
+ val_freq: 1000
+
+ AdamW.lr: 0.0001
+
+ # lets us organize sound classes into folders and choose from those sound classes uniformly
+ AudioDataset.without_replacement: False
+ num_iters: 500000
vampnet/conf/vampnet.yml ADDED
@@ -0,0 +1,49 @@
+
+ codec_ckpt: ./models/vampnet/codec.pth
+ save_path: ckpt
+
+ num_iters: 1000000000
+ save_iters: [10000, 50000, 100000, 300000, 500000]
+ val_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+ sample_freq: 10000
+ val_freq: 1000
+
+ batch_size: 8
+ num_workers: 10
+
+ # Optimization
+ amp: false
+
+ CrossEntropyLoss.label_smoothing: 0.1
+
+ AdamW.lr: 0.001
+
+ NoamScheduler.factor: 2.0
+ NoamScheduler.warmup: 10000
+
+ VampNet.vocab_size: 1024
+ VampNet.n_codebooks: 4
+ VampNet.n_conditioning_codebooks: 0
+ VampNet.r_cond_dim: 0
+ VampNet.noise_mode: mask
+ VampNet.embedding_dim: 1280
+ VampNet.n_layers: 20
+ VampNet.n_heads: 20
+ VampNet.flash_attn: false
+ VampNet.dropout: 0.1
+
+ AudioLoader.relative_path: ""
+ AudioDataset.loudness_cutoff: -30.0
+ AudioDataset.without_replacement: true
+ AudioLoader.shuffle: true
+
+ AudioDataset.duration: 10.0
+
+ train/AudioDataset.n_examples: 10000000
+ train/AudioLoader.sources:
+ - /media/CHONK/hugo/spotdl/audio-train
+
+ val/AudioDataset.n_examples: 2000
+ val/AudioLoader.sources:
+ - /media/CHONK/hugo/spotdl/audio-val
vampnet/scripts/exp/eval.py ADDED
@@ -0,0 +1,110 @@
+ from pathlib import Path
+ import os
+ from functools import partial
+
+ from frechet_audio_distance import FrechetAudioDistance
+ import pandas
+ import argbind
+ import torch
+ from tqdm import tqdm
+
+ import audiotools
+ from audiotools import AudioSignal
+
+ @argbind.bind(without_prefix=True)
+ def eval(
+ exp_dir: str = None,
+ baseline_key: str = "baseline",
+ audio_ext: str = ".wav",
+ ):
+ assert exp_dir is not None
+ exp_dir = Path(exp_dir)
+ assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"
+
+ # set up our metrics
+ # sisdr_loss = audiotools.metrics.distance.SISDRLoss()
+ # stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
+ mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
+ frechet = FrechetAudioDistance(
+ use_pca=False,
+ use_activation=False,
+ verbose=True,
+ audio_load_worker=4,
+ )
+ frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")
+
+ # figure out what conditions we have
+ conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
+
+ assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
+ conditions.remove(baseline_key)
+
+ print(f"Found {len(conditions)} conditions in {exp_dir}")
+ print(f"conditions: {conditions}")
+
+ baseline_dir = exp_dir / baseline_key
+ baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
+
+ metrics = []
+ for condition in tqdm(conditions):
+ cond_dir = exp_dir / condition
+ cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
+
+ print(f"computing fad for {baseline_dir} and {cond_dir}")
+ frechet_score = frechet.score(baseline_dir, cond_dir)
+
+ # make sure we have the same number of files
+ num_files = min(len(baseline_files), len(cond_files))
+ baseline_files = baseline_files[:num_files]
+ cond_files = cond_files[:num_files]
+ assert len(list(baseline_files)) == len(list(cond_files)), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}"
+
+ def process(baseline_file, cond_file):
+ # make sure the files match (same name)
+ assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
+
+ # load the files
+ baseline_sig = AudioSignal(str(baseline_file))
+ cond_sig = AudioSignal(str(cond_file))
+
+ cond_sig.resample(baseline_sig.sample_rate)
+ cond_sig.truncate_samples(baseline_sig.length)
+
+ # if our condition is inpainting, we need to trim the conditioning off
+ if "inpaint" in condition:
+ ctx_amt = float(condition.split("_")[-1])
+ ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
+ print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
+ cond_sig.trim(ctx_samples, ctx_samples)
+ baseline_sig.trim(ctx_samples, ctx_samples)
+
+ return {
+ # "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
+ # "stft": stft_loss(baseline_sig, cond_sig).item(),
+ "mel": mel_loss(baseline_sig, cond_sig).item(),
+ "frechet": frechet_score,
+ # "visqol": vsq,
+ "condition": condition,
+ "file": baseline_file.stem,
+ }
+
+ print(f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}")
+ metrics.extend(tqdm(map(process, baseline_files, cond_files), total=len(baseline_files)))
+
+ metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
+
+ for mk in metric_keys:
+ stat = pandas.DataFrame(metrics)
+ stat = stat.groupby(['condition'])[mk].agg(['mean', 'count', 'std'])
+ stat.to_csv(exp_dir / f"stats-{mk}.csv")
+
+ df = pandas.DataFrame(metrics)
+ df.to_csv(exp_dir / "metrics-all.csv", index=False)
+
+
+ if __name__ == "__main__":
+ args = argbind.parse_args()
+
+ with argbind.scope(args):
+ eval()
vampnet/scripts/exp/experiment.py ADDED
@@ -0,0 +1,254 @@
+ from pathlib import Path
+ import random
+ from typing import List
+ import tempfile
+ import subprocess
+
+ import argbind
+ from tqdm import tqdm
+ import torch
+
+ from vampnet.interface import Interface
+ from vampnet import mask as pmask
+ import audiotools as at
+
+ Interface: Interface = argbind.bind(Interface)
+
+
+ def calculate_bitrate(
+ interface, num_codebooks,
+ downsample_factor
+ ):
+ bit_width = 10
+ sr = interface.codec.sample_rate
+ hop = interface.codec.hop_size
+ rate = (sr / hop) * ((bit_width * num_codebooks) / downsample_factor)
+ return rate
+
+ def baseline(sig, interface):
+ return interface.preprocess(sig)
+
+ def reconstructed(sig, interface):
+ return interface.to_signal(
+ interface.encode(sig)
+ )
+
+ def coarse2fine(sig, interface):
+ z = interface.encode(sig)
+ z = z[:, :interface.c2f.n_conditioning_codebooks, :]
+
+ z = interface.coarse_to_fine(z)
+ return interface.to_signal(z)
+
+ class CoarseCond:
+
+ def __init__(self, num_conditioning_codebooks, downsample_factor):
+ self.num_conditioning_codebooks = num_conditioning_codebooks
+ self.downsample_factor = downsample_factor
+
+ def __call__(self, sig, interface):
+ z = interface.encode(sig)
+ mask = pmask.full_mask(z)
+ mask = pmask.codebook_unmask(mask, self.num_conditioning_codebooks)
+ mask = pmask.periodic_mask(mask, self.downsample_factor)
+
+ zv = interface.coarse_vamp(z, mask)
+ zv = interface.coarse_to_fine(zv)
+ return interface.to_signal(zv)
+
+ def opus(sig, interface, bitrate=128):
+ sig = interface.preprocess(sig)
+
+ with tempfile.NamedTemporaryFile(suffix=".wav") as f:
+ sig.write(f.name)
+
+ opus_name = Path(f.name).with_suffix(".opus")
+ # convert to opus
+ cmd = [
+ "ffmpeg", "-y", "-i", f.name,
+ "-c:a", "libopus",
+ "-b:a", f"{bitrate}",
+ opus_name
+ ]
+ subprocess.run(cmd, check=True)
+
+ # convert back to wav
+ output_name = Path(f"{f.name}-opus").with_suffix(".wav")
+ cmd = [
+ "ffmpeg", "-y", "-i", opus_name,
+ output_name
+ ]
+
+ subprocess.run(cmd, check=True)
+
+ sig = at.AudioSignal(
+ output_name,
+ sample_rate=sig.sample_rate
+ )
+ return sig
+
+ def mask_ratio_1_step(ratio=1.0):
+ def wrapper(sig, interface):
+ z = interface.encode(sig)
+ mask = pmask.linear_random(z, ratio)
+ zv = interface.coarse_vamp(
+ z,
+ mask,
+ sampling_steps=1,
+ )
+
+ return interface.to_signal(zv)
+ return wrapper
+
+ def num_sampling_steps(num_steps=1):
+ def wrapper(sig, interface: Interface):
+ z = interface.encode(sig)
+ mask = pmask.periodic_mask(z, 16)
+ zv = interface.coarse_vamp(
+ z,
+ mask,
+ sampling_steps=num_steps,
+ )
+
+ zv = interface.coarse_to_fine(zv)
+ return interface.to_signal(zv)
+ return wrapper
+
+ def beat_mask(ctx_time):
+ def wrapper(sig, interface):
+ beat_mask = interface.make_beat_mask(
+ sig,
+ before_beat_s=ctx_time/2,
+ after_beat_s=ctx_time/2,
+ invert=True
+ )
+
+ z = interface.encode(sig)
+
+ zv = interface.coarse_vamp(
+ z, beat_mask
+ )
+
+ zv = interface.coarse_to_fine(zv)
+ return interface.to_signal(zv)
+ return wrapper
+
+ def inpaint(ctx_time):
+ def wrapper(sig, interface: Interface):
+ z = interface.encode(sig)
+ mask = pmask.inpaint(z, interface.s2t(ctx_time), interface.s2t(ctx_time))
+
+ zv = interface.coarse_vamp(z, mask)
+ zv = interface.coarse_to_fine(zv)
+
+ return interface.to_signal(zv)
+ return wrapper
+
+ def token_noise(noise_amt):
+ def wrapper(sig, interface: Interface):
+ z = interface.encode(sig)
+ mask = pmask.random(z, noise_amt)
+ z = torch.where(
+ mask,
+ torch.randint_like(z, 0, interface.coarse.vocab_size),
+ z
+ )
+ return interface.to_signal(z)
+ return wrapper
+
+ EXP_REGISTRY = {}
+
+ EXP_REGISTRY["gen-compression"] = {
+ "baseline": baseline,
+ "reconstructed": reconstructed,
+ "coarse2fine": coarse2fine,
+ **{
+ f"{n}_codebooks_downsampled_{x}x": CoarseCond(num_conditioning_codebooks=n, downsample_factor=x)
+ for (n, x) in (
+ (1, 1), # 1 codebook, no downsampling
+ (4, 4), # 4 codebooks, downsampled 4x
+ (4, 16), # 4 codebooks, downsampled 16x
+ (4, 32), # 4 codebooks, downsampled 32x
+ )
+ },
+ **{
+ f"token_noise_{x}": mask_ratio_1_step(ratio=x)
+ for x in [0.25, 0.5, 0.75]
+ },
+ }
+
+
+ EXP_REGISTRY["sampling-steps"] = {
+ # "codec": reconstructed,
+ **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 36, 64, 72]},
+ }
+
+
+ EXP_REGISTRY["musical-sampling"] = {
+ **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
+ **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
+ }
+
+ @argbind.bind(without_prefix=True)
+ def main(
+ sources=[
+ "/media/CHONK/hugo/spotdl/val",
+ ],
+ output_dir: str = "./samples",
+ max_excerpts: int = 2000,
+ exp_type: str = "gen-compression",
+ seed: int = 0,
+ ext: List[str] = [".mp3"],
+ ):
+ at.util.seed(seed)
+ interface = Interface()
+
+ output_dir = Path(output_dir)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ from audiotools.data.datasets import AudioLoader, AudioDataset
+
+ loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
+ dataset = AudioDataset(loader,
+ sample_rate=interface.codec.sample_rate,
+ duration=interface.coarse.chunk_size_s,
+ n_examples=max_excerpts,
+ without_replacement=True,
+ )
+
+ if exp_type in EXP_REGISTRY:
+ SAMPLE_CONDS = EXP_REGISTRY[exp_type]
+ else:
+ raise ValueError(f"Unknown exp_type {exp_type}")
+
+ indices = list(range(max_excerpts))
+ random.shuffle(indices)
+ for i in tqdm(indices):
+ # if all our files are already there, skip
+ done = []
+ for name in SAMPLE_CONDS:
+ o_dir = Path(output_dir) / name
+ done.append((o_dir / f"{i}.wav").exists())
+ if all(done):
+ continue
+
+ sig = dataset[i]["signal"]
+ results = {
+ name: cond(sig, interface).cpu()
+ for name, cond in SAMPLE_CONDS.items()
+ }
+
+ for name, sig in results.items():
+ o_dir = Path(output_dir) / name
+ o_dir.mkdir(exist_ok=True, parents=True)
+
+ sig.write(o_dir / f"{i}.wav")
+
+ if __name__ == "__main__":
+ args = argbind.parse_args()
+
+ with argbind.scope(args):
+ main()
vampnet/scripts/exp/fine_tune.py ADDED
@@ -0,0 +1,81 @@
+ import argbind
+ from pathlib import Path
+ import yaml
+ from typing import List
+
+
+ """example output: (yaml)
+
+ """
+
+ @argbind.bind(without_prefix=True, positional=True)
+ def fine_tune(audio_files_or_folders: List[str], name: str):
+
+ conf_dir = Path("conf")
+ assert conf_dir.exists(), "conf directory not found. are you in the vampnet directory?"
+
+ conf_dir = conf_dir / "generated"
+ conf_dir.mkdir(exist_ok=True)
+
+ finetune_dir = conf_dir / name
+ finetune_dir.mkdir(exist_ok=True)
+
+ finetune_c2f_conf = {
+ "$include": ["conf/lora/lora.yml"],
+ "fine_tune": True,
+ "train/AudioLoader.sources": audio_files_or_folders,
+ "val/AudioLoader.sources": audio_files_or_folders,
+ "VampNet.n_codebooks": 14,
+ "VampNet.n_conditioning_codebooks": 4,
+ "VampNet.embedding_dim": 1280,
+ "VampNet.n_layers": 16,
+ "VampNet.n_heads": 20,
+ "AudioDataset.duration": 3.0,
+ "AudioDataset.loudness_cutoff": -40.0,
+ "save_path": f"./runs/{name}/c2f",
+ "fine_tune_checkpoint": "./models/vampnet/c2f.pth"
+ }
+
+ finetune_coarse_conf = {
+ "$include": ["conf/lora/lora.yml"],
+ "fine_tune": True,
+ "train/AudioLoader.sources": audio_files_or_folders,
+ "val/AudioLoader.sources": audio_files_or_folders,
+ "save_path": f"./runs/{name}/coarse",
+ "fine_tune_checkpoint": "./models/vampnet/coarse.pth"
+ }
+
+ interface_conf = {
+ "Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
+ "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
+ "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
+ "Interface.codec_ckpt": "./models/vampnet/codec.pth",
+ "AudioLoader.sources": [audio_files_or_folders],
+ }
+
+ # save the confs
+ with open(finetune_dir / "c2f.yml", "w") as f:
+ yaml.dump(finetune_c2f_conf, f)
+
+ with open(finetune_dir / "coarse.yml", "w") as f:
+ yaml.dump(finetune_coarse_conf, f)
+
+ with open(finetune_dir / "interface.yml", "w") as f:
+ yaml.dump(interface_conf, f)
+
+
+ print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` ")
+
+ if __name__ == "__main__":
+ args = argbind.parse_args()
+
+ with argbind.scope(args):
+ fine_tune()
vampnet/scripts/exp/train.py ADDED
@@ -0,0 +1,680 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Optional
6
+ from dataclasses import dataclass
7
+
8
+ import argbind
9
+ import audiotools as at
10
+ import torch
11
+ import torch.nn as nn
12
+ from audiotools import AudioSignal
13
+ from audiotools.data import transforms
14
+ from einops import rearrange
15
+ from rich import pretty
16
+ from rich.traceback import install
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ import vampnet
20
+ from vampnet.modules.transformer import VampNet
21
+ from vampnet.util import codebook_unflatten, codebook_flatten
22
+ from vampnet import mask as pmask
23
+ # from dac.model.dac import DAC
24
+ from lac.model.lac import LAC as DAC
25
+
26
+ from audiotools.ml.decorators import (
27
+ timer, Tracker, when
28
+ )
29
+
30
+ import loralib as lora
31
+
32
+ import torch._dynamo
33
+ torch._dynamo.config.verbose=True
34
+
35
+
36
+ # Enable cudnn autotuner to speed up training
37
+ # (can be altered by the funcs.seed function)
38
+ torch.backends.cudnn.benchmark = bool(int(os.getenv("CUDNN_BENCHMARK", 1)))
39
+ # Uncomment to trade memory for speed.
40
+
41
+ # Install to make things look nice
42
+ warnings.filterwarnings("ignore", category=UserWarning)
43
+ pretty.install()
44
+ install()
45
+
46
+ # optim
47
+ Accelerator = argbind.bind(at.ml.Accelerator, without_prefix=True)
48
+ CrossEntropyLoss = argbind.bind(nn.CrossEntropyLoss)
49
+ AdamW = argbind.bind(torch.optim.AdamW)
50
+ NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)
51
+
52
+ # transforms
53
+ filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
54
+ "BaseTransform",
55
+ "Compose",
56
+ "Choose",
57
+ ]
58
+ tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
59
+
60
+ # model
61
+ VampNet = argbind.bind(VampNet)
62
+
63
+
64
+ # data
65
+ AudioLoader = argbind.bind(at.datasets.AudioLoader)
66
+ AudioDataset = argbind.bind(at.datasets.AudioDataset, "train", "val")
67
+
68
+ IGNORE_INDEX = -100
69
+
70
+
71
+ @argbind.bind("train", "val", without_prefix=True)
+ def build_transform():
+     transform = transforms.Compose(
+         tfm.VolumeNorm(("const", -24)),
+         # tfm.PitchShift(),
+         tfm.RescaleAudio(),
+     )
+     return transform
+
+
+ @torch.no_grad()
+ def apply_transform(transform_fn, batch):
+     sig: AudioSignal = batch["signal"]
+     kwargs = batch["transform_args"]
+
+     sig: AudioSignal = transform_fn(sig.clone(), **kwargs)
+     return sig
+
+
+ def build_datasets(args, sample_rate: int):
+     with argbind.scope(args, "train"):
+         train_data = AudioDataset(
+             AudioLoader(), sample_rate, transform=build_transform()
+         )
+     with argbind.scope(args, "val"):
+         val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
+     return train_data, val_data
+
+
+ def rand_float(shape, low, high, rng):
+     return rng.draw(shape)[:, 0] * (high - low) + low
+
+
+ def flip_coin(shape, p, rng):
+     return rng.draw(shape)[:, 0] < p
+
+
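Aside (standalone sketch, not part of the committed file): `rand_float` and `flip_coin` simply rescale draws from the unit interval. The same arithmetic with the stdlib `random` module standing in for the `SobolEngine`'s `draw`:

```python
import random

# Scale uniform draws in [0, 1) to [low, high) -- the same arithmetic as
# rand_float above, with random.random() standing in for rng.draw().
def rand_float(n, low, high, rng=random):
    return [rng.random() * (high - low) + low for _ in range(n)]

# flip_coin analogue: a draw below p counts as heads.
def flip_coin(n, p, rng=random):
    return [rng.random() < p for _ in range(n)]

vals = rand_float(1000, -24.0, -16.0)
assert all(-24.0 <= v < -16.0 for v in vals)
```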
+ def num_params_hook(o, p):
+     return o + f" {p/1e6:<.3f}M params."
+
+
+ def add_num_params_repr_hook(model):
+     import numpy as np
+     from functools import partial
+
+     for n, m in model.named_modules():
+         o = m.extra_repr()
+         p = sum([np.prod(p.size()) for p in m.parameters()])
+
+         setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
+
+
+ def accuracy(
+     preds: torch.Tensor,
+     target: torch.Tensor,
+     top_k: int = 1,
+     ignore_index: Optional[int] = None,
+ ) -> torch.Tensor:
+     # Flatten the predictions and targets to shape (batch_size * sequence_length, n_class)
+     preds = rearrange(preds, "b p s -> (b s) p")
+     target = rearrange(target, "b s -> (b s)")
+
+     # return torchmetrics.functional.accuracy(preds, target, task='multiclass', top_k=top_k, num_classes=preds.shape[-1], ignore_index=ignore_index)
+     if ignore_index is not None:
+         # Create a mask for the ignored index
+         mask = target != ignore_index
+         # Apply the mask to the target and predictions
+         preds = preds[mask]
+         target = target[mask]
+
+     # Get the top-k predicted classes and their indices
+     _, pred_indices = torch.topk(preds, k=top_k, dim=-1)
+
+     # Determine whether the true target is in the top-k predicted classes
+     correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)
+
+     # Calculate the accuracy
+     accuracy = torch.mean(correct.float())
+
+     return accuracy
+
+
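As a quick sanity check (standalone, pure Python on toy data rather than torch tensors; not part of the committed file), the masked top-k logic above behaves like:

```python
IGNORE_INDEX = -100

def topk_accuracy(preds, target, top_k=1, ignore_index=IGNORE_INDEX):
    # Drop positions whose target equals the ignore index, then count how
    # often the target lands among the top-k scoring classes.
    kept = [(p, t) for p, t in zip(preds, target) if t != ignore_index]
    correct = 0
    for p, t in kept:
        topk = sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:top_k]
        correct += t in topk
    return correct / len(kept)

preds = [[0.1, 0.9, 0.0], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]]
assert topk_accuracy(preds, [1, 0, IGNORE_INDEX], top_k=1) == 1.0   # ignored slot dropped
assert topk_accuracy(preds, [1, 2, IGNORE_INDEX], top_k=1) == 0.5   # one of two kept is wrong
```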
+ def _metrics(z_hat, r, target, flat_mask, output):
+     for r_range in [(0, 0.5), (0.5, 1.0)]:
+         unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
+         masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
+
+         assert target.shape[0] == r.shape[0]
+         # grab the indices of the r values that are in the range
+         r_idx = (r >= r_range[0]) & (r < r_range[1])
+
+         # grab the target and z_hat values that are in the range
+         r_unmasked_target = unmasked_target[r_idx]
+         r_masked_target = masked_target[r_idx]
+         r_z_hat = z_hat[r_idx]
+
+         for topk in (1, 25):
+             s, e = r_range
+             tag = f"accuracy-{s}-{e}/top{topk}"
+
+             output[f"{tag}/unmasked"] = accuracy(
+                 preds=r_z_hat,
+                 target=r_unmasked_target,
+                 ignore_index=IGNORE_INDEX,
+                 top_k=topk,
+             )
+             output[f"{tag}/masked"] = accuracy(
+                 preds=r_z_hat,
+                 target=r_masked_target,
+                 ignore_index=IGNORE_INDEX,
+                 top_k=topk,
+             )
+
+
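Aside (standalone toy, not part of the committed file): the two `masked_fill` calls split one target into complementary views, filling the opposite partition with the ignore index. With plain lists instead of tensors:

```python
IGNORE_INDEX = -100

def masked_fill(values, mask, fill):
    # List analogue of Tensor.masked_fill: fill where mask is True.
    return [fill if m else v for v, m in zip(values, mask)]

target = [3, 7, 2, 9]
flat_mask = [True, False, True, False]  # True = token was masked out

# "unmasked" view keeps only unmasked positions; "masked" view the opposite.
unmasked_target = masked_fill(target, flat_mask, IGNORE_INDEX)
masked_target = masked_fill(target, [not m for m in flat_mask], IGNORE_INDEX)

assert unmasked_target == [IGNORE_INDEX, 7, IGNORE_INDEX, 9]
assert masked_target == [3, IGNORE_INDEX, 2, IGNORE_INDEX]
```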
+ @dataclass
+ class State:
+     model: VampNet
+     codec: DAC
+
+     optimizer: AdamW
+     scheduler: NoamScheduler
+     criterion: CrossEntropyLoss
+     grad_clip_val: float
+
+     rng: torch.quasirandom.SobolEngine
+
+     train_data: AudioDataset
+     val_data: AudioDataset
+
+     tracker: Tracker
+
+
+ @timer()
+ def train_loop(state: State, batch: dict, accel: Accelerator):
+     state.model.train()
+     batch = at.util.prepare_batch(batch, accel.device)
+     signal = apply_transform(state.train_data.transform, batch)
+
+     output = {}
+     vn = accel.unwrap(state.model)
+     with accel.autocast():
+         with torch.inference_mode():
+             state.codec.to(accel.device)
+             z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
+             z = z[:, : vn.n_codebooks, :]
+
+         n_batch = z.shape[0]
+         r = state.rng.draw(n_batch)[:, 0].to(accel.device)
+
+         mask = pmask.random(z, r)
+         mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
+         z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
+
+         z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
+
+         dtype = torch.bfloat16 if accel.amp else None
+         with accel.autocast(dtype=dtype):
+             z_hat = state.model(z_mask_latent)
+
+             target = codebook_flatten(
+                 z[:, vn.n_conditioning_codebooks :, :],
+             )
+
+             flat_mask = codebook_flatten(
+                 mask[:, vn.n_conditioning_codebooks :, :],
+             )
+
+             # replace target with ignore index for masked tokens
+             t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
+             output["loss"] = state.criterion(z_hat, t_masked)
+
+             _metrics(
+                 r=r,
+                 z_hat=z_hat,
+                 target=target,
+                 flat_mask=flat_mask,
+                 output=output,
+             )
+
+     accel.backward(output["loss"])
+
+     output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
+     output["other/batch_size"] = z.shape[0]
+
+     accel.scaler.unscale_(state.optimizer)
+     output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
+         state.model.parameters(), state.grad_clip_val
+     )
+
+     accel.step(state.optimizer)
+     state.optimizer.zero_grad()
+
+     state.scheduler.step()
+     accel.update()
+
+     return {k: v for k, v in sorted(output.items())}
+
+
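Aside (standalone toy, not part of the committed file): the masking step above replaces a random fraction `r` of the tokens with a mask token, in the spirit of `pmask.random` followed by `pmask.apply_mask`. `MASK_TOKEN` and the rng below are illustrative stand-ins:

```python
import random

MASK_TOKEN = 1024  # hypothetical stand-in for vn.mask_token

def apply_random_mask(tokens, r, rng):
    # Each position is masked independently with probability r.
    mask = [rng.random() < r for _ in tokens]
    z_mask = [MASK_TOKEN if m else t for t, m in zip(tokens, mask)]
    return z_mask, mask

rng = random.Random(0)
z_mask, mask = apply_random_mask(list(range(10)), r=0.5, rng=rng)

# Masked positions carry the mask token; the rest are untouched.
assert all(t == MASK_TOKEN for t, m in zip(z_mask, mask) if m)
assert all(t == i for i, (t, m) in enumerate(zip(z_mask, mask)) if not m)
```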
+ @timer()
+ @torch.no_grad()
+ def val_loop(state: State, batch: dict, accel: Accelerator):
+     state.model.eval()
+     state.codec.eval()
+     batch = at.util.prepare_batch(batch, accel.device)
+     signal = apply_transform(state.val_data.transform, batch)
+
+     vn = accel.unwrap(state.model)
+     z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
+     z = z[:, : vn.n_codebooks, :]
+
+     n_batch = z.shape[0]
+     r = state.rng.draw(n_batch)[:, 0].to(accel.device)
+
+     mask = pmask.random(z, r)
+     mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
+     z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
+
+     z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
+
+     z_hat = state.model(z_mask_latent)
+
+     target = codebook_flatten(
+         z[:, vn.n_conditioning_codebooks :, :],
+     )
+
+     flat_mask = codebook_flatten(
+         mask[:, vn.n_conditioning_codebooks :, :]
+     )
+
+     output = {}
+     # replace target with ignore index for masked tokens
+     t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
+     output["loss"] = state.criterion(z_hat, t_masked)
+
+     _metrics(
+         r=r,
+         z_hat=z_hat,
+         target=target,
+         flat_mask=flat_mask,
+         output=output,
+     )
+
+     return output
+
+
+ def validate(state, val_dataloader, accel):
+     for batch in val_dataloader:
+         output = val_loop(state, batch, accel)
+     # Consolidate state dicts if using ZeroRedundancyOptimizer
+     if hasattr(state.optimizer, "consolidate_state_dict"):
+         state.optimizer.consolidate_state_dict()
+     return output
+
+
+ def checkpoint(state, save_iters, save_path, fine_tune):
+     if accel.local_rank != 0:
+         state.tracker.print(f"ERROR: Skipping checkpoint on rank {accel.local_rank}")
+         return
+
+     metadata = {"logs": dict(state.tracker.history)}
+
+     tags = ["latest"]
+     state.tracker.print(f"Saving to {str(Path('.').absolute())}")
+
+     if state.tracker.step in save_iters:
+         tags.append(f"{state.tracker.step // 1000}k")
+
+     if state.tracker.is_best("val", "loss"):
+         state.tracker.print("Best model so far")
+         tags.append("best")
+
+     if fine_tune:
+         for tag in tags:
+             # save the lora model
+             (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
+             torch.save(
+                 lora.lora_state_dict(accel.unwrap(state.model)),
+                 f"{save_path}/{tag}/lora.pth"
+             )
+
+     for tag in tags:
+         model_extra = {
+             "optimizer.pth": state.optimizer.state_dict(),
+             "scheduler.pth": state.scheduler.state_dict(),
+             "tracker.pth": state.tracker.state_dict(),
+             "metadata.pth": metadata,
+         }
+
+         accel.unwrap(state.model).metadata = metadata
+         accel.unwrap(state.model).save_to_folder(
+             f"{save_path}/{tag}", model_extra, package=False
+         )
+
+
+ def save_sampled(state, z, writer):
+     num_samples = z.shape[0]
+
+     for i in range(num_samples):
+         sampled = accel.unwrap(state.model).generate(
+             codec=state.codec,
+             time_steps=z.shape[-1],
+             start_tokens=z[i : i + 1],
+         )
+         sampled.cpu().write_audio_to_tb(
+             f"sampled/{i}",
+             writer,
+             step=state.tracker.step,
+             plot_fn=None,
+         )
+
+
+ def save_imputation(state, z, val_idx, writer):
+     n_prefix = int(z.shape[-1] * 0.25)
+     n_suffix = int(z.shape[-1] * 0.25)
+
+     vn = accel.unwrap(state.model)
+
+     mask = pmask.inpaint(z, n_prefix, n_suffix)
+     mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
+     z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
+
+     imputed_noisy = vn.to_signal(z_mask, state.codec)
+     imputed_true = vn.to_signal(z, state.codec)
+
+     imputed = []
+     for i in range(len(z)):
+         imputed.append(
+             vn.generate(
+                 codec=state.codec,
+                 time_steps=z.shape[-1],
+                 start_tokens=z[i][None, ...],
+                 mask=mask[i][None, ...],
+             )
+         )
+     imputed = AudioSignal.batch(imputed)
+
+     for i in range(len(val_idx)):
+         imputed_noisy[i].cpu().write_audio_to_tb(
+             f"inpainted_prompt/{i}",
+             writer,
+             step=state.tracker.step,
+             plot_fn=None,
+         )
+         imputed[i].cpu().write_audio_to_tb(
+             f"inpainted_middle/{i}",
+             writer,
+             step=state.tracker.step,
+             plot_fn=None,
+         )
+         imputed_true[i].cpu().write_audio_to_tb(
+             f"reconstructed/{i}",
+             writer,
+             step=state.tracker.step,
+             plot_fn=None,
+         )
+
+
+ @torch.no_grad()
+ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
+     state.model.eval()
+     state.codec.eval()
+     vn = accel.unwrap(state.model)
+
+     batch = [state.val_data[i] for i in val_idx]
+     batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device)
+
+     signal = apply_transform(state.val_data.transform, batch)
+
+     z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
+     z = z[:, : vn.n_codebooks, :]
+
+     r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
+
+     mask = pmask.random(z, r)
+     mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
+     z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
+
+     z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
+
+     z_hat = state.model(z_mask_latent)
+
+     z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
+     z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
+     z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
+
+     generated = vn.to_signal(z_pred, state.codec)
+     reconstructed = vn.to_signal(z, state.codec)
+     masked = vn.to_signal(z_mask.squeeze(1), state.codec)
+
+     for i in range(generated.batch_size):
+         audio_dict = {
+             "original": signal[i],
+             "masked": masked[i],
+             "generated": generated[i],
+             "reconstructed": reconstructed[i],
+         }
+         for k, v in audio_dict.items():
+             v.cpu().write_audio_to_tb(
+                 f"onestep/_{i}.r={r[i]:0.2f}/{k}",
+                 writer,
+                 step=state.tracker.step,
+                 plot_fn=None,
+             )
+
+     save_sampled(state=state, z=z, writer=writer)
+     save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
+
+
+ @argbind.bind(without_prefix=True)
+ def load(
+     args,
+     accel: at.ml.Accelerator,
+     tracker: Tracker,
+     save_path: str,
+     resume: bool = False,
+     tag: str = "latest",
+     fine_tune_checkpoint: Optional[str] = None,
+     grad_clip_val: float = 5.0,
+ ) -> State:
+     codec = DAC.load(args["codec_ckpt"], map_location="cpu")
+     codec.eval()
+
+     model, v_extra = None, {}
+
+     if args["fine_tune"]:
+         assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
+         model = torch.compile(
+             VampNet.load(
+                 location=Path(fine_tune_checkpoint),
+                 map_location="cpu",
+             )
+         )
+
+     if resume:
+         kwargs = {
+             "folder": f"{save_path}/{tag}",
+             "map_location": "cpu",
+             "package": False,
+         }
+         tracker.print(f"Loading checkpoint from {kwargs['folder']}")
+         if (Path(kwargs["folder"]) / "vampnet").exists():
+             model, v_extra = VampNet.load_from_folder(**kwargs)
+         else:
+             raise ValueError(
+                 f"Could not find a VampNet checkpoint in {kwargs['folder']}"
+             )
+
+     model = torch.compile(VampNet()) if model is None else model
+     model = accel.prepare_model(model)
+
+     # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
+     assert (
+         accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size
+     )
+
+     optimizer = AdamW(model.parameters(), use_zero=accel.use_ddp)
+     scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
+     scheduler.step()
+
+     if "optimizer.pth" in v_extra:
+         optimizer.load_state_dict(v_extra["optimizer.pth"])
+         scheduler.load_state_dict(v_extra["scheduler.pth"])
+     if "tracker.pth" in v_extra:
+         tracker.load_state_dict(v_extra["tracker.pth"])
+
+     criterion = CrossEntropyLoss()
+
+     sample_rate = codec.sample_rate
+
+     # a better rng for sampling from our schedule
+     rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
+
+     # log a model summary w/ num params
+     if accel.local_rank == 0:
+         add_num_params_repr_hook(accel.unwrap(model))
+         with open(f"{save_path}/model.txt", "w") as f:
+             f.write(repr(accel.unwrap(model)))
+
+     # load the datasets
+     train_data, val_data = build_datasets(args, sample_rate)
+
+     return State(
+         tracker=tracker,
+         model=model,
+         codec=codec,
+         optimizer=optimizer,
+         scheduler=scheduler,
+         criterion=criterion,
+         rng=rng,
+         train_data=train_data,
+         val_data=val_data,
+         grad_clip_val=grad_clip_val,
+     )
+
+
+ @argbind.bind(without_prefix=True)
+ def train(
+     args,
+     accel: at.ml.Accelerator,
+     seed: int = 0,
+     codec_ckpt: str = None,
+     save_path: str = "ckpt",
+     num_iters: int = int(1000e6),
+     save_iters: list = [10000, 50000, 100000, 300000, 500000],
+     sample_freq: int = 10000,
+     val_freq: int = 1000,
+     batch_size: int = 12,
+     val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+     num_workers: int = 10,
+     fine_tune: bool = False,
+ ):
+     assert codec_ckpt is not None, "codec_ckpt is required"
+
+     seed = seed + accel.local_rank
+     at.util.seed(seed)
+     writer = None
+
+     if accel.local_rank == 0:
+         writer = SummaryWriter(log_dir=f"{save_path}/logs/")
+         argbind.dump_args(args, f"{save_path}/args.yml")
+
+     tracker = Tracker(
+         writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank
+     )
+
+     # load the codec model
+     state: State = load(
+         args=args,
+         accel=accel,
+         tracker=tracker,
+         save_path=save_path,
+     )
+     print("initialized state.")
+
+     train_dataloader = accel.prepare_dataloader(
+         state.train_data,
+         start_idx=state.tracker.step * batch_size,
+         num_workers=num_workers,
+         batch_size=batch_size,
+         collate_fn=state.train_data.collate,
+     )
+     val_dataloader = accel.prepare_dataloader(
+         state.val_data,
+         start_idx=0,
+         num_workers=num_workers,
+         batch_size=batch_size,
+         collate_fn=state.val_data.collate,
+         persistent_workers=num_workers > 0,
+     )
+     print("initialized dataloader.")
+
+     if fine_tune:
+         lora.mark_only_lora_as_trainable(state.model)
+         print("marked only lora as trainable.")
+
+     # Wrap the functions so that they neatly track in TensorBoard + progress bars
+     # and only run when specific conditions are met.
+     global train_loop, val_loop, validate, save_samples, checkpoint
+
+     train_loop = tracker.log("train", "value", history=False)(
+         tracker.track("train", num_iters, completed=state.tracker.step)(train_loop)
+     )
+     val_loop = tracker.track("val", len(val_dataloader))(val_loop)
+     validate = tracker.log("val", "mean")(validate)
+
+     save_samples = when(lambda: accel.local_rank == 0)(save_samples)
+     checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
+
+     print("starting training loop.")
+     with tracker.live:
+         for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
+             train_loop(state, batch, accel)
+
+             last_iter = (
+                 tracker.step == num_iters - 1 if num_iters is not None else False
+             )
+
+             if tracker.step % sample_freq == 0 or last_iter:
+                 save_samples(state, val_idx, writer)
+
+             if tracker.step % val_freq == 0 or last_iter:
+                 validate(state, val_dataloader, accel)
+                 checkpoint(
+                     state=state,
+                     save_iters=save_iters,
+                     save_path=save_path,
+                     fine_tune=fine_tune,
+                 )
+
+                 # Reset validation progress bar, print summary since last validation.
+                 tracker.done("val", f"Iteration {tracker.step}")
+
+             if last_iter:
+                 break
+
+
+ if __name__ == "__main__":
+     args = argbind.parse_args()
+     args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
+     with argbind.scope(args):
+         with Accelerator() as accel:
+             if accel.local_rank != 0:
+                 sys.tracebacklimit = 0
+             train(args, accel)
vampnet/scripts/utils/README.md ADDED
@@ -0,0 +1,28 @@
+ # Scripts
+
+ ## process_zip.py
+
+ Some requirements that may not be installed in the docker image:
+ * argbind
+ * wav2wav (`pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git` or `pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git@<branchname>`)
+
+ ### zip folder structure
+
+ The zip folder should have the following internal structure:
+
+ ```
+ base_folder/
+     test_case_1/
+         before.wav
+     test_case_2/
+         before.wav
+     ...
+     test_case_n/
+         before.wav
+ ```
+
+ Note: There can be issues with the output zip if the input zip folder structure is too deep or too shallow. If you want/need to use a zip file with a different folder structure, adjust this:
+ https://github.com/descriptinc/lyrebird-wav2wav/blob/136c923ce19df03876a515ca0ed83854710cfa30/scripts/utils/process_zip.py#L28
+
+ ### Execution
+ `python process_zip.py <path/to/zip> -tag <string>`
vampnet/scripts/utils/data/augment.py ADDED
@@ -0,0 +1,67 @@
+ from pathlib import Path
+ import random
+
+ import audiotools as at
+ from audiotools import AudioSignal
+
+ import argbind
+ import tqdm
+ import torch
+
+ from torch_pitch_shift import pitch_shift, get_fast_shifts
+ from torch_time_stretch import time_stretch, get_fast_stretches
+
+ from audiotools.core.util import sample_from_dist
+
+
+ @argbind.bind(without_prefix=True)
+ def augment(
+     audio_folder: Path = None,
+     dest_folder: Path = None,
+     n_augmentations: int = 10,
+ ):
+     """
+     Augment a folder of audio files by applying pitch-shift and time-stretch transforms.
+
+     The dest folder will contain a folder for each of the clean dataset's files.
+     Under each of these folders, there will be a clean file and many augmented files.
+     """
+     assert audio_folder is not None
+     assert dest_folder is not None
+     audio_files = at.util.find_audio(audio_folder)
+
+     for audio_file in tqdm.tqdm(audio_files):
+         subtree = dest_folder / audio_file.relative_to(audio_folder).parent
+         subdir = subtree / audio_file.stem
+         subdir.mkdir(parents=True, exist_ok=True)
+
+         src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")
+
+         for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
+             for j in range(n_augmentations):
+                 dst = chunk.clone()
+                 # pitch shift by a randomly chosen fast shift ratio
+                 dst.samples = pitch_shift(
+                     dst.samples,
+                     shift=random.choice(get_fast_shifts(
+                         src.sample_rate,
+                         condition=lambda x: x >= 0.25 and x <= 1.0,
+                     )),
+                     sample_rate=src.sample_rate,
+                 )
+                 # time stretch by a randomly chosen fast stretch factor
+                 dst.samples = time_stretch(
+                     dst.samples,
+                     stretch=random.choice(get_fast_stretches(
+                         src.sample_rate,
+                         condition=lambda x: x >= 0.667 and x <= 1.5,
+                     )),
+                     sample_rate=src.sample_rate,
+                 )
+
+                 dst.cpu().write(subdir / f"{i}-{j}.wav")
+
+
+ if __name__ == "__main__":
+     args = argbind.parse_args()
+
+     with argbind.scope(args):
+         augment()
vampnet/scripts/utils/data/maestro-reorg.py ADDED
@@ -0,0 +1,39 @@
+ from pathlib import Path
+ import json
+ import os
+
+ maestro_path = Path("/media/CHONK/hugo/maestro-v3.0.0")
+ output_path = Path("/media/CHONK/hugo/maestro-v3.0.0-split")
+
+ # split
+ with open(maestro_path / "maestro-v3.0.0.json") as f:
+     maestro = json.load(f)
+
+ train = []
+ validation = []
+ test = []
+ for key, split in maestro["split"].items():
+     audio_filename = maestro['audio_filename'][key]
+     if split == "train":
+         train.append(audio_filename)
+     elif split == "test":
+         test.append(audio_filename)
+     elif split == "validation":
+         validation.append(audio_filename)
+     else:
+         raise ValueError(f"Unknown split {split}")
+
+ # symlink all files
+ for audio_filename in train:
+     p = output_path / "train" / audio_filename
+     p.parent.mkdir(parents=True, exist_ok=True)
+     os.symlink(maestro_path / audio_filename, p)
+ for audio_filename in validation:
+     p = output_path / "validation" / audio_filename
+     p.parent.mkdir(parents=True, exist_ok=True)
+     os.symlink(maestro_path / audio_filename, p)
+ for audio_filename in test:
+     p = output_path / "test" / audio_filename
+     p.parent.mkdir(parents=True, exist_ok=True)
+     os.symlink(maestro_path / audio_filename, p)
vampnet/scripts/utils/plots.py ADDED
@@ -0,0 +1,43 @@
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from pandas.api.types import CategoricalDtype
+
+
+ def plot_metrics(metrics, condition_to_latex, title, color_palette):
+     # Add a new column to the dataframe with the latex representation
+     metrics['condition_latex'] = metrics['condition'].map(condition_to_latex)
+
+     # Order condition_latex as per the condition_to_latex dictionary
+     cat_type = CategoricalDtype(categories=condition_to_latex.values(), ordered=True)
+     metrics['condition_latex'] = metrics['condition_latex'].astype(cat_type)
+
+     # Compute mean and std for each condition for each metric
+     grouped = metrics.groupby('condition_latex')[['mel', 'frechet']].agg(['mean', 'std'])
+
+     fig, axs = plt.subplots(2, 1, figsize=(7, 5.25))
+
+     # Set the main title for the figure
+     fig.suptitle(title, fontsize=16)
+
+     # Get color for each bar in the plot
+     bar_colors = [color_palette[condition] for condition in grouped.index]
+
+     # Plot the mel spectrogram loss as a box plot
+     sns.boxplot(x='condition_latex', y='mel', data=metrics, ax=axs[0], palette=color_palette, showfliers=False)
+     axs[0].set_ylabel('Mel Spectrogram Loss \u2190')
+     axs[0].set_xlabel('')  # Remove x-axis label
+     axs[0].set_xticklabels(grouped.index, rotation=0, ha='center')
+
+     # Plot FAD as a bar plot with std error bars
+     axs[1].bar(grouped.index, grouped['frechet']['mean'], yerr=grouped['frechet']['std'], color=bar_colors)
+     axs[1].set_ylabel('FAD \u2190')
+     axs[1].set_xlabel('')  # Remove x-axis label
+     axs[1].set_xticklabels(grouped.index, rotation=0, ha='center')
+
+     # Adjust the space between plots
+     plt.subplots_adjust(hspace=0.1)
+
+     # Remove any unnecessary space around the plot
+     plt.tight_layout(rect=[0, 0, 1, 0.96])
+
+     # Reduce the space between suptitle and the plot
+     plt.subplots_adjust(top=0.92)
vampnet/scripts/utils/remove_quiet_files.py ADDED
@@ -0,0 +1,29 @@
+ # removes files whose loudness falls below a threshold (default: -30 dB)
+
+ from pathlib import Path
+ import shutil
+ import audiotools as at
+ import argbind
+
+
+ @argbind.bind(without_prefix=True)
+ def remove_quiet_files(
+     src_dir: Path = None,
+     dest_dir: Path = None,
+     min_loudness: float = -30,
+ ):
+     # copy src to dest
+     dest_dir.mkdir(parents=True, exist_ok=True)
+     shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
+
+     audio_files = at.util.find_audio(dest_dir)
+     for audio_file in audio_files:
+         sig = at.AudioSignal(audio_file)
+         if sig.loudness() < min_loudness:
+             audio_file.unlink()
+             print(f"removed {audio_file}")
+
+
+ if __name__ == "__main__":
+     args = argbind.parse_args()
+
+     with argbind.scope(args):
+         remove_quiet_files()
vampnet/scripts/utils/split.py ADDED
@@ -0,0 +1,64 @@
+ from pathlib import Path
+ import random
+ import os
+ import json
+
+ import argbind
+ from tqdm import tqdm
+
+
+ @argbind.bind(without_prefix=True)
+ def train_test_split(
+     audio_folder: str = ".",
+     test_size: float = 0.2,
+     seed: int = 42,
+     pattern: str = "**/*.mp3",
+ ):
+     print("finding audio")
+
+     audio_folder = Path(audio_folder)
+     audio_files = list(tqdm(audio_folder.glob(pattern)))
+     print(f"found {len(audio_files)} audio files")
+
+     # split according to test_size
+     n_test = int(len(audio_files) * test_size)
+     n_train = len(audio_files) - n_test
+
+     # shuffle
+     random.seed(seed)
+     random.shuffle(audio_files)
+
+     train_files = audio_files[:n_train]
+     test_files = audio_files[n_train:]
+
+     print(f"Train files: {len(train_files)}")
+     print(f"Test files: {len(test_files)}")
+     continue_ = input("Continue [yn]? ") or "n"
+
+     if continue_ != "y":
+         return
+
+     for split, files in (
+         ("train", train_files), ("test", test_files)
+     ):
+         for file in tqdm(files):
+             out_file = audio_folder.parent / f"{audio_folder.name}-{split}" / Path(file).name
+             out_file.parent.mkdir(exist_ok=True, parents=True)
+             os.symlink(file, out_file)
+
+         # save split as json
+         with open(Path(audio_folder) / f"{split}.json", "w") as f:
+             json.dump([str(f) for f in files], f)
+
+
+ if __name__ == "__main__":
+     args = argbind.parse_args()
+
+     with argbind.scope(args):
+         train_test_split()
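Aside (standalone sketch of the split arithmetic above, with toy filenames instead of real audio; not part of the committed file): shuffling with a fixed seed and carving off the last `test_size` fraction makes the split deterministic and disjoint:

```python
import random

files = [f"{i}.mp3" for i in range(10)]
test_size = 0.2
n_test = int(len(files) * test_size)
n_train = len(files) - n_test

# Deterministic shuffle: the same seed always yields the same split.
random.seed(42)
random.shuffle(files)
train_files, test_files = files[:n_train], files[n_train:]

assert len(train_files) == 8 and len(test_files) == 2
# Disjoint and jointly exhaustive.
assert set(train_files) | set(test_files) == {f"{i}.mp3" for i in range(10)}
assert not set(train_files) & set(test_files)
```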
vampnet/scripts/utils/split_long_audio_file.py ADDED
@@ -0,0 +1,34 @@
+ from pathlib import Path
+ import argbind
+
+ import audiotools as at
+ import tqdm
+
+
+ @argbind.bind(without_prefix=True)
+ def split_long_audio_file(
+     file: str = None,
+     max_chunk_size_s: int = 60 * 10,
+ ):
+     file = Path(file)
+     output_dir = file.parent / file.stem
+     output_dir.mkdir()
+
+     sig = at.AudioSignal(file)
+
+     # split into chunks
+     for i, chunk in tqdm.tqdm(enumerate(sig.windows(
+         window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s / 2,
+         preprocess=True,
+     ))):
+         chunk.write(output_dir / f"{i}.wav")
+
+     print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
+
+     return output_dir
+
+
+ if __name__ == "__main__":
+     args = argbind.parse_args()
+
+     with argbind.scope(args):
+         split_long_audio_file()
vampnet/scripts/utils/stage.py ADDED
@@ -0,0 +1,30 @@
+ import os
+ import subprocess
+ from pathlib import Path
+
+ import argbind
+ import rich
+ from audiotools.ml import Experiment
+
+
+ @argbind.bind(without_prefix=True)
+ def run(
+     run_dir: str = os.getenv("PATH_TO_RUNS", "runs"),
+     name: str = None,
+     recent: bool = False,
+ ):
+     if recent:
+         paths = sorted(Path(run_dir).iterdir(), key=os.path.getmtime)
+         paths = [p.name for p in paths if p.is_dir()]
+         if paths:
+             name = paths[-1]
+
+     with Experiment(run_dir, name) as exp:
+         exp.snapshot()
+         rich.print(f"Created a snapshot of {exp.parent_directory} at {exp.exp_dir}")
+
+
+ if __name__ == "__main__":
+     args = argbind.parse_args()
+     with argbind.scope(args):
+         run()
vampnet/scripts/utils/visualize_embeddings.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """
+ TODO: train a linear probe
+ usage:
+ python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_audio /path/to/audio/labels --output_dir /path/to/output
+ """
+ from pathlib import Path
+ from typing import List
+
+ import audiotools as at
+ from audiotools import AudioSignal
+ import argbind
+ import torch
+ import numpy as np
+ import zipfile
+ import json
+
+ from vampnet.interface import Interface
+ import tqdm
+
+ # bind the Interface to argbind
+ Interface = argbind.bind(Interface)
+
+ PREFIX = "vampnet-embedding-"
+
+ DEBUG = False
+
+
+ def smart_plotly_export(fig, save_path: Path):
+     img_format = save_path.suffix[1:]
+     if img_format == "html":
+         fig.write_html(save_path)
+     elif img_format == "bytes":
+         return fig.to_image(format="png")
+     # TODO: come back and make this prettier
+     elif img_format == "numpy":
+         import io
+         from PIL import Image
+
+         def plotly_fig2array(fig):
+             # convert a Plotly fig to a numpy array
+             fig_bytes = fig.to_image(format="png", width=1200, height=700)
+             buf = io.BytesIO(fig_bytes)
+             img = Image.open(buf)
+             return np.asarray(img)
+
+         return plotly_fig2array(fig)
+     elif img_format in ("jpeg", "png", "webp"):
+         fig.write_image(save_path)
+     else:
+         raise ValueError("invalid image format")
+
+
+ def dim_reduce(annotated_embeddings, layer, output_dir, n_components=3, method="tsne"):
+     """
+     dimensionality reduction for visualization!
+     saves an html plotly figure to output_dir
+     parameters:
+         annotated_embeddings (list): AnnotatedEmbedding objects to reduce, one per sample
+         layer (int): which model layer's embeddings to visualize
+         output_dir (Path): directory where the figure is saved
+         n_components (int): number of dimensions to reduce to (2 or 3)
+         method (str): umap, tsne, or pca
+     returns:
+         the exported figure (see smart_plotly_export)
+     """
+     import pandas as pd
+     import plotly.express as px
+
+     fig_name = f"vampnet-embeddings-layer={layer}"
+     fig_title = f"{fig_name}_{method}"
+     save_path = (output_dir / fig_name).with_suffix(".html")
+
+     if method == "umap":
+         from umap import UMAP
+         reducer = UMAP(n_components=n_components)
+     elif method == "tsne":
+         from sklearn.manifold import TSNE
+
+         reducer = TSNE(n_components=n_components)
+     elif method == "pca":
+         from sklearn.decomposition import PCA
+
+         reducer = PCA(n_components=n_components)
+     else:
+         raise ValueError(f"invalid method: {method}")
+
+     labels = [emb.label for emb in annotated_embeddings]
+     names = [emb.filename for emb in annotated_embeddings]
+     embs = [emb.embedding for emb in annotated_embeddings]
+     embs_at_layer = np.stack(embs)[:, layer, :]
+     projs = reducer.fit_transform(embs_at_layer)
+
+     df = pd.DataFrame(
+         {
+             "label": labels,
+             "name": names,
+             "x": projs[:, 0],
+             "y": projs[:, 1],
+         }
+     )
+     if n_components == 2:
+         fig = px.scatter(
+             df, x="x", y="y", color="label", hover_name="name", title=fig_title,
+         )
+
+     elif n_components == 3:
+         df["z"] = projs[:, 2]
+         fig = px.scatter_3d(
+             df, x="x", y="y", z="z", color="label", hover_name="name", title=fig_title
+         )
+     else:
+         raise ValueError(f"can't plot {n_components} components")
+
+     fig.update_traces(
+         marker=dict(size=6, line=dict(width=1, color="DarkSlateGrey")),
+         selector=dict(mode="markers"),
+     )
+
+     return smart_plotly_export(fig, save_path)
+
+
+ # per JukeMIR, we want the embeddings from the middle layer?
+ def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
+     with torch.inference_mode():
+         # preprocess the signal
+         sig = interface.preprocess(sig)
+
+         # get the coarse vampnet model
+         vampnet = interface.coarse
+
+         # get the tokens
+         z = interface.encode(sig)[:, : vampnet.n_codebooks, :]
+         z_latents = vampnet.embedding.from_codes(z, interface.codec)
+
+         # do a forward pass through the model, get the embeddings
+         _z, embeddings = vampnet(z_latents, return_activations=True)
+         # print(f"got embeddings with shape {embeddings.shape}")
+         # [layer, batch, time, n_dims]
+         # [20, 1, 600ish, 768]
+
+         # squeeze batch dim (dim 1, since layer is dim 0)
+         assert (
+             embeddings.shape[1] == 1
+         ), f"expected batch dim to be 1, got {embeddings.shape[1]}"
+         embeddings = embeddings.squeeze(1)
+
+         num_layers = embeddings.shape[0]
+         assert (
+             layer < num_layers
+         ), f"layer {layer} is out of bounds for model with {num_layers} layers"
+
+         # do meanpooling over the time dimension
+         embeddings = embeddings.mean(dim=-2)
+         # [20, 768]
+
+         # return the embeddings
+         return embeddings
+
+
+ from dataclasses import dataclass, fields
+
+
+ @dataclass
+ class AnnotatedEmbedding:
+     label: str
+     filename: str
+     embedding: np.ndarray
+
+     def save(self, path):
+         """Save the Embedding object to a given path as a zip file."""
+         with zipfile.ZipFile(path, "w") as archive:
+
+             # Save numpy array
+             with archive.open("embedding.npy", "w") as f:
+                 np.save(f, self.embedding)
+
+             # Save non-numpy data as json
+             non_numpy_data = {
+                 f.name: getattr(self, f.name)
+                 for f in fields(self)
+                 if f.name != "embedding"
+             }
+             with archive.open("data.json", "w") as f:
+                 f.write(json.dumps(non_numpy_data).encode("utf-8"))
+
+     @classmethod
+     def load(cls, path):
+         """Load the Embedding object from a given zip path."""
+         with zipfile.ZipFile(path, "r") as archive:
+
+             # Load numpy array
+             with archive.open("embedding.npy") as f:
+                 embedding = np.load(f)
+
+             # Load non-numpy data from json
+             with archive.open("data.json") as f:
+                 data = json.loads(f.read().decode("utf-8"))
+
+         return cls(embedding=embedding, **data)
+
+
+ @argbind.bind(without_prefix=True)
+ def main(
+     path_to_audio: str = None,
+     cache_dir: str = "./.emb_cache",
+     output_dir: str = "./vampnet_embeddings",
+     layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
+     method: str = "tsne",
+     n_components: int = 2,
+ ):
+     path_to_audio = Path(path_to_audio)
+     assert path_to_audio.exists(), f"{path_to_audio} does not exist"
+
+     cache_dir = Path(cache_dir)
+     output_dir = Path(output_dir)
+     output_dir.mkdir(exist_ok=True, parents=True)
+
+     # load our interface
+     # argbind will automatically load the default config
+     interface = Interface()
+
+     # we expect path_to_audio to contain one folder per label, so get the list of labels
+     labels = [Path(x).name for x in path_to_audio.iterdir() if x.is_dir()]
+     print(f"Found {len(labels)} labels")
+     print(f"labels: {labels}")
+
+     # collect audio files, labels, and embeddings
+     annotated_embeddings = []
+     for label in labels:
+         audio_files = list(at.util.find_audio(path_to_audio / label))
+         print(f"Found {len(audio_files)} audio files for label {label}")
+
+         for audio_file in tqdm.tqdm(audio_files, desc=f"embedding label {label}"):
+             # check if we have a cached embedding for this file
+             cached_path = cache_dir / f"{label}_{audio_file.stem}.emb"
+             if cached_path.exists():
+                 # if so, load it
+                 if DEBUG:
+                     print(f"loading cached embedding for {cached_path.stem}")
+                 embedding = AnnotatedEmbedding.load(cached_path)
+             else:
+                 try:
+                     sig = AudioSignal(audio_file)
+                 except Exception as e:
+                     print(f"failed to load {audio_file.name} with error {e}")
+                     print(f"skipping {audio_file.name}")
+                     continue
+
+                 # get the embedding
+                 emb = vampnet_embed(sig, interface).cpu().numpy()
+
+                 # create an embedding we can save/load
+                 embedding = AnnotatedEmbedding(
+                     label=label, filename=audio_file.name, embedding=emb
+                 )
+
+                 # cache the embedding
+                 cached_path.parent.mkdir(exist_ok=True, parents=True)
+                 embedding.save(cached_path)
+             annotated_embeddings.append(embedding)
+
+     # now, let's do a dim reduction on the embeddings and visualize them.
+     for layer in tqdm.tqdm(layers, desc="dim reduction"):
+         dim_reduce(
+             annotated_embeddings,
+             layer,
+             output_dir=output_dir,
+             n_components=n_components,
+             method=method,
+         )
+
+
+ if __name__ == "__main__":
+     args = argbind.parse_args()
+     with argbind.scope(args):
+         main()
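The `.emb` cache format above is just a zip archive holding one `.npy` array plus a `data.json` of metadata. A minimal, self-contained sketch of that round-trip (the `Embedding` class here is a stand-in re-declared for illustration, not an import from the script):

```python
import json
import tempfile
import zipfile
from dataclasses import dataclass, fields
from pathlib import Path

import numpy as np


@dataclass
class Embedding:
    """Stand-in for AnnotatedEmbedding: metadata as json, array as npy, both in one zip."""
    label: str
    filename: str
    embedding: np.ndarray

    def save(self, path):
        with zipfile.ZipFile(path, "w") as archive:
            # write the array into the archive as embedding.npy
            with archive.open("embedding.npy", "w") as f:
                np.save(f, self.embedding)
            # every non-array field goes into data.json
            meta = {fld.name: getattr(self, fld.name) for fld in fields(self) if fld.name != "embedding"}
            with archive.open("data.json", "w") as f:
                f.write(json.dumps(meta).encode("utf-8"))

    @classmethod
    def load(cls, path):
        with zipfile.ZipFile(path, "r") as archive:
            with archive.open("embedding.npy") as f:
                embedding = np.load(f)
            with archive.open("data.json") as f:
                meta = json.loads(f.read().decode("utf-8"))
        return cls(embedding=embedding, **meta)


# round-trip: one (layers, dims) embedding per file
path = Path(tempfile.mkdtemp()) / "blues_demo.emb"
emb = Embedding(label="blues", filename="demo.wav", embedding=np.zeros((20, 768), dtype=np.float32))
emb.save(path)
loaded = Embedding.load(path)
assert loaded.label == "blues" and loaded.embedding.shape == (20, 768)
```

Keeping the array out of the JSON avoids lossy float serialization and keeps large embeddings compact.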
vampnet/scripts/utils/xeno-canto-dl.py ADDED
@@ -0,0 +1,234 @@
+ from xenopy import Query
+
+
+ SPECIES = [
+     "American Robin",
+     "Northern Cardinal",
+     "Mourning Dove",
+     "American Crow",
+     "Baltimore Oriole",
+     "Blue Jay",
+     "Eastern Bluebird",
+     "House Finch",
+     "American Goldfinch",
+     "House Sparrow",
+     "Song Sparrow",
+     "Tufted Titmouse",
+     "White-breasted Nuthatch",
+     "European Starling",
+     "American Redstart",
+     "Red-winged Blackbird",
+     "Brown-headed Cowbird",
+     "Common Grackle",
+     "Boat-tailed Grackle",
+     "Common Yellowthroat",
+     "Northern Mockingbird",
+     "Carolina Wren",
+     "Eastern Meadowlark",
+     "Chipping Sparrow",
+     "Tree Swallow",
+     "Barn Swallow",
+     "Cliff Swallow",
+     "Pine Siskin",
+     "Indigo Bunting",
+     "Eastern Towhee",
+     "Carolina Chickadee",
+     "Great Crested Flycatcher",
+     "Eastern Wood-Pewee",
+     "Ovenbird",
+     "Northern Flicker",
+     "Red-eyed Vireo",
+     "American Woodcock",
+     "Eastern Phoebe",
+     "Downy Woodpecker",
+     "Scarlet Tanager",
+     "Yellow Warbler",
+     "White-eyed Vireo",
+     "Common Loon",
+     "White-throated Sparrow",
+     "Yellow-throated Vireo",
+     "Great Blue Heron",
+     "Belted Kingfisher",
+     "Pied-billed Grebe",
+     "Wild Turkey",
+     "Wood Thrush",
+     "Rose-breasted Grosbeak",
+     "Field Sparrow",
+     "Hooded Warbler",
+     "Northern Parula",
+     "Chestnut-sided Warbler",
+     "Blue-winged Warbler",
+     "Red-bellied Woodpecker",
+     "Yellow-billed Cuckoo",
+     "Gray Catbird",
+     "Northern Saw-whet Owl",
+     "Osprey",
+     "Common Nighthawk",
+     "Broad-winged Hawk",
+     "Black-throated Green Warbler",
+     "Great Horned Owl",
+     "Common Raven",
+     "Barred Owl",
+     "Canada Warbler",
+     "Magnolia Warbler",
+     "Black-and-white Warbler",
+     "Eastern Kingbird",
+     "Swainson's Thrush",
+     "Worm-eating Warbler",
+     "Prairie Warbler",
+     "Baltimore Oriole",
+     "Black-throated Blue Warbler",
+     "Louisiana Waterthrush",
+     "Blackburnian Warbler",
+     "Black-capped Chickadee",
+     "Cerulean Warbler",
+     "Red-shouldered Hawk",
+     "Cooper's Hawk",
+     "Yellow-throated Warbler",
+     "Blue-headed Vireo",
+     "Blackpoll Warbler",
+     "Ruffed Grouse",
+     "Kentucky Warbler",
+     "Hermit Thrush",
+     "Cedar Waxwing",
+     "Eastern Screech-Owl",
+     "Northern Goshawk",
+     "Green Heron",
+     "Red-tailed Hawk",
+     "Black Vulture",
+     "Hairy Woodpecker",
+     "Golden-crowned Kinglet",
+     "Ruby-crowned Kinglet",
+     "Bicknell's Thrush",
+     "Blue-gray Gnatcatcher",
+     "Veery",
+     "Pileated Woodpecker",
+     "Purple Finch",
+     "White-crowned Sparrow",
+     "Snow Bunting",
+     "Pine Grosbeak",
+     "American Tree Sparrow",
+     "Dark-eyed Junco",
+     "Snowy Owl",
+     "White-winged Crossbill",
+     "Red Crossbill",
+     "Common Redpoll",
+     "Northern Shrike",
+     "Northern Harrier",
+     "Rough-legged Hawk",
+     "Long-eared Owl",
+     "Evening Grosbeak",
+     "Northern Pintail",
+     "American Black Duck",
+     "Mallard",
+     "Canvasback",
+     "Redhead",
+     "Ring-necked Duck",
+     "Greater Scaup",
+     "Lesser Scaup",
+     "Bufflehead",
+     "Common Goldeneye",
+     "Hooded Merganser",
+     "Common Merganser",
+     "Red-breasted Merganser",
+     "Ruddy Duck",
+     "Wood Duck",
+     "Gadwall",
+     "American Wigeon",
+     "Northern Shoveler",
+     "Green-winged Teal",
+     "Blue-winged Teal",
+     "Cinnamon Teal",
+     "Ringed Teal",
+     "Cape Teal",
+     "Northern Fulmar",
+     "Yellow-billed Loon",
+     "Red-throated Loon",
+     "Arctic Loon",
+     "Pacific Loon",
+     "Horned Grebe",
+     "Red-necked Grebe",
+     "Eared Grebe",
+     "Western Grebe",
+     "Clark's Grebe",
+     "Double-crested Cormorant",
+     "Pelagic Cormorant",
+     "Great Cormorant",
+     "American White Pelican",
+     "Brown Pelican",
+     "Brandt's Cormorant",
+     "Least Bittern",
+     "Great Egret",
+     "Snowy Egret",
+     "Little Blue Heron",
+     "Tricolored Heron",
+     "Reddish Egret",
+     "Black-crowned Night-Heron",
+     "Yellow-crowned Night-Heron",
+     "White Ibis",
+     "Glossy Ibis",
+     "Roseate Spoonbill",
+     "Wood Stork",
+     "Black-bellied Whistling-Duck",
+     "Fulvous Whistling-Duck",
+     "Greater White-fronted Goose",
+     "Snow Goose",
+     "Ross's Goose",
+     "Canada Goose",
+     "Brant",
+     "Mute Swan",
+     "Tundra Swan",
+     "Whooper Swan",
+     "Sandhill Crane",
+     "Black-necked Stilt",
+     "American Avocet",
+     "Northern Jacana",
+     "Greater Yellowlegs",
+     "Lesser Yellowlegs",
+     "Willet",
+     "Spotted Sandpiper",
+     "Upland Sandpiper",
+     "Whimbrel",
+     "Long-billed Curlew",
+     "Marbled Godwit",
+     "Ruddy Turnstone",
+     "Red Knot",
+     "Sanderling",
+     "Semipalmated Sandpiper",
+     "Western Sandpiper",
+     "Least Sandpiper",
+     "White-rumped Sandpiper",
+     "Baird's Sandpiper",
+     "Pectoral Sandpiper",
+     "Dunlin",
+     "Buff-breasted Sandpiper",
+     "Short-billed Dowitcher",
+     "Long-billed Dowitcher",
+     "Common Snipe",
+     "American Woodcock",
+     "Wilson's Phalarope",
+     "Red-necked Phalarope",
+     "Red Phalarope"
+ ]
+
+ from pathlib import Path
+
+ def remove_spaces(s):
+     return s.replace(" ", "")
+
+ for species in SPECIES:
+     if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
+         continue
+     try:
+         q = Query(
+             name=species, q="A", length="10-30",
+         )
+
+         # retrieve metadata
+         metafiles = q.retrieve_meta(verbose=True)
+         # retrieve recordings
+         q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
+
+     except Exception:
+         print("Failed to download " + species)
+         continue
vampnet/setup.py ADDED
@@ -0,0 +1,47 @@
+ from setuptools import find_packages
+ from setuptools import setup
+
+ with open("README.md") as f:
+     long_description = f.read()
+
+ setup(
+     name="vampnet",
+     version="0.0.1",
+     classifiers=[
+         "Intended Audience :: Developers",
+         "Natural Language :: English",
+         "Programming Language :: Python :: 3.7",
+         "Topic :: Artistic Software",
+         "Topic :: Multimedia",
+         "Topic :: Multimedia :: Sound/Audio",
+         "Topic :: Multimedia :: Sound/Audio :: Editors",
+         "Topic :: Software Development :: Libraries",
+     ],
+     description="Generative Music Modeling.",
+     long_description=long_description,
+     long_description_content_type="text/markdown",
+     author="Hugo Flores García, Prem Seetharaman",
+     author_email="hfgacrcia@descript.com",
+     url="https://github.com/hugofloresgarcia/vampnet",
+     license="MIT",
+     packages=find_packages(),
+     setup_requires=[
+         "Cython",
+     ],
+     install_requires=[
+         "Cython",  # Added by WAM because it seems to be needed by this repo?
+         "torch",
+         "pydantic==2.10.6",
+         "argbind>=0.3.2",
+         "numpy<1.24",
+         "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
+         "lac @ git+https://github.com/hugofloresgarcia/lac.git",
+         "descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git",
+         "gradio",
+         "loralib",
+         "torch_pitch_shift",
+         "plotly",  # Added by WAM for clustering (see https://github.com/hugofloresgarcia/vampnet/issues/20)
+         "pyharp",
+     ],
+ )
vampnet/vampnet/__init__.py ADDED
@@ -0,0 +1,6 @@
+
+ from . import modules
+ from . import scheduler
+ from .interface import Interface
+
+ __version__ = "0.0.1"
vampnet/vampnet/beats.py ADDED
@@ -0,0 +1,250 @@
+ import json
+ import logging
+ import warnings
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any
+ from typing import List
+ from typing import Tuple
+ from typing import Union
+
+ import librosa
+ import torch
+ import numpy as np
+ from audiotools import AudioSignal
+
+
+ logging.basicConfig(level=logging.INFO)
+
+ ###################
+ # beat sync utils #
+ ###################
+
+ AGGREGATOR_REGISTRY = {
+     "mean": np.mean,
+     "median": np.median,
+     "max": np.max,
+     "min": np.min,
+ }
+
+
+ def list_aggregators() -> list:
+     return list(AGGREGATOR_REGISTRY.keys())
+
+
+ @dataclass
+ class TimeSegment:
+     start: float
+     end: float
+
+     @property
+     def duration(self):
+         return self.end - self.start
+
+     def __str__(self) -> str:
+         return f"{self.start} - {self.end}"
+
+     def find_overlapping_segment(
+         self, segments: List["TimeSegment"]
+     ) -> Union["TimeSegment", None]:
+         """Find the first segment that fully contains this segment, or None if there is no such segment"""
+         for s in segments:
+             if s.start <= self.start and s.end >= self.end:
+                 return s
+         return None
+
+
+ def mkdir(path: Union[Path, str]) -> Path:
+     p = Path(path)
+     p.mkdir(parents=True, exist_ok=True)
+     return p
+
+
+
+ ###################
+ # beat data #
+ ###################
+ @dataclass
+ class BeatSegment(TimeSegment):
+     downbeat: bool = False  # whether there's a downbeat at the start time
+
+
+ class Beats:
+     def __init__(self, beat_times, downbeat_times):
+         if isinstance(beat_times, np.ndarray):
+             beat_times = beat_times.tolist()
+         if isinstance(downbeat_times, np.ndarray):
+             downbeat_times = downbeat_times.tolist()
+         self._beat_times = beat_times
+         self._downbeat_times = downbeat_times
+         self._use_downbeats = False
+
+     def use_downbeats(self, use_downbeats: bool = True):
+         """use downbeats instead of beats when calling get_beats"""
+         self._use_downbeats = use_downbeats
+
+     def beat_segments(self, signal: AudioSignal) -> List[BeatSegment]:
+         """
+         segments a song into time segments corresponding to beats.
+         the first segment starts at 0 and ends at the first beat time.
+         the last segment starts at the last beat time and ends at the end of the song.
+         """
+         beat_times = self._beat_times.copy()
+         downbeat_times = self._downbeat_times
+         beat_times.insert(0, 0)
+         beat_times.append(signal.signal_duration)
+
+         downbeat_ids = np.intersect1d(beat_times, downbeat_times, return_indices=True)[
+             1
+         ]
+         is_downbeat = [
+             True if i in downbeat_ids else False for i in range(len(beat_times))
+         ]
+         segments = [
+             BeatSegment(start_time, end_time, downbeat)
+             for start_time, end_time, downbeat in zip(
+                 beat_times[:-1], beat_times[1:], is_downbeat
+             )
+         ]
+         return segments
+
+     def get_beats(self) -> np.ndarray:
+         """returns an array of beat times, in seconds.
+         if use_downbeats has been set, returns an array of downbeat times instead.
+         """
+         return np.array(
+             self._downbeat_times if self._use_downbeats else self._beat_times
+         )
+
+     @property
+     def beat_times(self) -> np.ndarray:
+         """return beat times"""
+         return np.array(self._beat_times)
+
+     @property
+     def downbeat_times(self) -> np.ndarray:
+         """return downbeat times"""
+         return np.array(self._downbeat_times)
+
+     def beat_times_to_feature_frames(
+         self, signal: AudioSignal, features: np.ndarray
+     ) -> np.ndarray:
+         """convert beat times to frames, given an array of time-varying features"""
+         beat_times = self.get_beats()
+         beat_frames = (
+             beat_times * signal.sample_rate / signal.signal_length * features.shape[-1]
+         ).astype(np.int64)
+         return beat_frames
+
+     def sync_features(
+         self, feature_frames: np.ndarray, features: np.ndarray, aggregate="median"
+     ) -> np.ndarray:
+         """sync features to beats"""
+         if aggregate not in AGGREGATOR_REGISTRY:
+             raise ValueError(f"unknown aggregation method {aggregate}")
+
+         return librosa.util.sync(
+             features, feature_frames, aggregate=AGGREGATOR_REGISTRY[aggregate]
+         )
+
+     def to_json(self) -> dict:
+         """return beats and downbeats as a json-serializable dict"""
+         return {
+             "beats": self._beat_times,
+             "downbeats": self._downbeat_times,
+             "use_downbeats": self._use_downbeats,
+         }
+
+     @classmethod
+     def from_dict(cls, data: dict):
+         """load beats and downbeats from a dict"""
+         inst = cls(data["beats"], data["downbeats"])
+         inst.use_downbeats(data["use_downbeats"])
+         return inst
+
+     def save(self, output_dir: Path):
+         """save beats and downbeats to json"""
+         mkdir(output_dir)
+         with open(output_dir / "beats.json", "w") as f:
+             json.dump(self.to_json(), f)
+
+     @classmethod
+     def load(cls, input_dir: Path):
+         """load beats and downbeats from json"""
+         beats_file = Path(input_dir) / "beats.json"
+         with open(beats_file, "r") as f:
+             data = json.load(f)
+         return cls.from_dict(data)
+
+
+ ###################
+ # beat tracking #
+ ###################
+
+
+ class BeatTracker:
+     def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
+         """extract beats from an audio signal"""
+         raise NotImplementedError
+
+     def __call__(self, signal: AudioSignal) -> Beats:
+         """extract beats from an audio signal
+         NOTE: if the first beat (and/or downbeat) is detected within the first 100ms of the audio,
+         it is discarded. This is to avoid empty bins with no beat synced features in the first beat.
+         Args:
+             signal (AudioSignal): signal to beat track
+         Returns:
+             Beats: beat and downbeat times
+         """
+         beats, downbeats = self.extract_beats(signal)
+         return Beats(beats, downbeats)
+
+
+ class WaveBeat(BeatTracker):
+     def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
+         from wavebeat.dstcn import dsTCNModel
+
+         model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
+         model.eval()
+
+         self.device = device
+         self.model = model
+
+     def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
+         """returns beat and downbeat times, in seconds"""
+         # extract beats
+         beats, downbeats = self.model.predict_beats_from_array(
+             audio=signal.audio_data.squeeze(0),
+             sr=signal.sample_rate,
+             use_gpu=self.device != "cpu",
+         )
+
+         return beats, downbeats
+
+
+ class MadmomBeats(BeatTracker):
+     def __init__(self):
+         raise NotImplementedError
+
+     def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
+         """returns beat and downbeat times, in seconds"""
+         pass
+
+
+ BEAT_TRACKER_REGISTRY = {
+     "wavebeat": WaveBeat,
+     "madmom": MadmomBeats,
+ }
+
+
+ def list_beat_trackers() -> list:
+     return list(BEAT_TRACKER_REGISTRY.keys())
+
+
+ def load_beat_tracker(beat_tracker: str, **kwargs) -> BeatTracker:
+     if beat_tracker not in BEAT_TRACKER_REGISTRY:
+         raise ValueError(
+             f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}"
+         )
+
+     return BEAT_TRACKER_REGISTRY[beat_tracker](**kwargs)
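`Beats.beat_segments` pads the detected beat list with the start and end of the signal so that every moment of the song falls into exactly one segment, then flags the segments whose start coincides with a downbeat. A small self-contained sketch of that indexing, with plain lists and a scalar duration standing in for the `AudioSignal`:

```python
import numpy as np

# beat times in seconds; downbeats are a subset of the beats
beat_times = [0.5, 1.0, 1.5, 2.0]
downbeat_times = [0.5, 1.5]
signal_duration = 2.5  # stand-in for signal.signal_duration

# pad with the start and end of the signal, as beat_segments does
times = [0.0] + beat_times + [signal_duration]

# indices (into `times`) of the beats that are also downbeats
downbeat_ids = np.intersect1d(times, downbeat_times, return_indices=True)[1]
is_downbeat = [bool(i in downbeat_ids) for i in range(len(times))]

# consecutive (start, end, downbeat) triples, one per segment
segments = list(zip(times[:-1], times[1:], is_downbeat))

assert segments[0] == (0.0, 0.5, False)   # lead-in before the first beat
assert segments[1] == (0.5, 1.0, True)    # segment starting on a downbeat
assert segments[-1] == (2.0, 2.5, False)  # last beat to end of signal
```

Note that `zip` truncates to the shorter sequence, so the final `is_downbeat` entry (for the appended end time) is never used as a segment start.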
vampnet/vampnet/interface.py ADDED
@@ -0,0 +1,422 @@
+ import os
+ from pathlib import Path
+ import math
+
+ import torch
+ import numpy as np
+ from audiotools import AudioSignal
+ import tqdm
+
+ from .modules.transformer import VampNet
+ from .beats import WaveBeat
+ from .mask import *
+
+ # from dac.model.dac import DAC
+ from lac.model.lac import LAC as DAC
+
+
+ def signal_concat(
+     audio_signals: list,
+ ):
+     audio_data = torch.cat([x.audio_data for x in audio_signals], dim=-1)
+
+     return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
+
+
+ def _load_model(
+     ckpt: str,
+     lora_ckpt: str = None,
+     device: str = "cpu",
+     chunk_size_s: int = 10,
+ ):
+     # we need to set strict to False if the model has lora weights to add later
+     model = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False)
+
+     # load lora weights if needed
+     if lora_ckpt is not None:
+         if not Path(lora_ckpt).exists():
+             should_cont = input(
+                 f"lora checkpoint {lora_ckpt} does not exist. continue? (y/n) "
+             )
+             if should_cont != "y":
+                 raise Exception("aborting")
+         else:
+             model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False)
+
+     model.to(device)
+     model.eval()
+     model.chunk_size_s = chunk_size_s
+     return model
+
+
+
+ class Interface(torch.nn.Module):
+     def __init__(
+         self,
+         coarse_ckpt: str = None,
+         coarse_lora_ckpt: str = None,
+         coarse2fine_ckpt: str = None,
+         coarse2fine_lora_ckpt: str = None,
+         codec_ckpt: str = None,
+         wavebeat_ckpt: str = None,
+         device: str = "cpu",
+         coarse_chunk_size_s: int = 10,
+         coarse2fine_chunk_size_s: int = 3,
+     ):
+         super().__init__()
+         assert codec_ckpt is not None, "must provide a codec checkpoint"
+         self.codec = DAC.load(Path(codec_ckpt))
+         self.codec.eval()
+         self.codec.to(device)
+
+         assert coarse_ckpt is not None, "must provide a coarse checkpoint"
+         self.coarse = _load_model(
+             ckpt=coarse_ckpt,
+             lora_ckpt=coarse_lora_ckpt,
+             device=device,
+             chunk_size_s=coarse_chunk_size_s,
+         )
+
+         # check if we have a coarse2fine ckpt
+         if coarse2fine_ckpt is not None:
+             self.c2f = _load_model(
+                 ckpt=coarse2fine_ckpt,
+                 lora_ckpt=coarse2fine_lora_ckpt,
+                 device=device,
+                 chunk_size_s=coarse2fine_chunk_size_s,
+             )
+         else:
+             self.c2f = None
+
+         if wavebeat_ckpt is not None:
+             print(f"loading wavebeat from {wavebeat_ckpt}")
+             self.beat_tracker = WaveBeat(wavebeat_ckpt)
+             self.beat_tracker.model.to(device)
+         else:
+             self.beat_tracker = None
+
+         self.device = device
+
+     def lora_load(
+         self,
+         coarse_ckpt: str = None,
+         c2f_ckpt: str = None,
+         full_ckpts: bool = False,
+     ):
+         if full_ckpts:
+             if coarse_ckpt is not None:
+                 self.coarse = _load_model(
+                     ckpt=coarse_ckpt,
+                     device=self.device,
+                     chunk_size_s=self.coarse.chunk_size_s,
+                 )
+             if c2f_ckpt is not None:
+                 self.c2f = _load_model(
+                     ckpt=c2f_ckpt,
+                     device=self.device,
+                     chunk_size_s=self.c2f.chunk_size_s,
+                 )
+         else:
+             if coarse_ckpt is not None:
+                 self.coarse.to("cpu")
+                 state_dict = torch.load(coarse_ckpt, map_location="cpu")
+                 print(f"loading coarse from {coarse_ckpt}")
+                 self.coarse.load_state_dict(state_dict, strict=False)
+                 self.coarse.to(self.device)
+             if c2f_ckpt is not None:
+                 self.c2f.to("cpu")
+                 state_dict = torch.load(c2f_ckpt, map_location="cpu")
+                 print(f"loading c2f from {c2f_ckpt}")
+                 self.c2f.load_state_dict(state_dict, strict=False)
+                 self.c2f.to(self.device)
+
+     def s2t(self, seconds: float):
+         """seconds to tokens"""
+         if isinstance(seconds, np.ndarray):
+             return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
+         else:
+             return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
+
+     def s2t2s(self, seconds: float):
+         """seconds to tokens to seconds"""
+         return self.t2s(self.s2t(seconds))
+
+     def t2s(self, tokens: int):
+         """tokens to seconds"""
+         return tokens * self.codec.hop_length / self.codec.sample_rate
+
+     def to(self, device):
+         self.device = device
+         self.coarse.to(device)
+         self.codec.to(device)
+
+         if self.c2f is not None:
+             self.c2f.to(device)
+
+         if self.beat_tracker is not None:
+             self.beat_tracker.model.to(device)
+         return self
+
+     def to_signal(self, z: torch.Tensor):
+         return self.coarse.to_signal(z, self.codec)
+
+     def preprocess(self, signal: AudioSignal):
+         signal = (
+             signal.clone()
+             .resample(self.codec.sample_rate)
+             .to_mono()
+             .normalize(-24)
+             .ensure_max_of_audio(1.0)
+         )
+         return signal
+
+     @torch.inference_mode()
+     def encode(self, signal: AudioSignal):
+         signal = self.preprocess(signal).to(self.device)
+         z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
+         return z
+
+     def snap_to_beats(
+         self,
+         signal: AudioSignal
+     ):
+         assert hasattr(self, "beat_tracker"), "No beat tracker loaded"
+         beats, downbeats = self.beat_tracker.extract_beats(signal)
+
+         # trim the signal to span the first and last beat times
+         samples_begin = int(beats[0] * signal.sample_rate)
+         samples_end = int(beats[-1] * signal.sample_rate)
+         print(beats[0])
+         signal = signal.clone().trim(samples_begin, signal.length - samples_end)
+
+         return signal
+
+     def make_beat_mask(self,
+         signal: AudioSignal,
+         before_beat_s: float = 0.0,
+         after_beat_s: float = 0.02,
+         mask_downbeats: bool = True,
+         mask_upbeats: bool = True,
+         downbeat_downsample_factor: int = None,
+         beat_downsample_factor: int = None,
+         dropout: float = 0.0,
+         invert: bool = True,
+     ):
+         """make a beat-synced mask. that is, make a mask that
+         places 1s at and around the beats, and 0s everywhere else.
+         """
+         assert self.beat_tracker is not None, "No beat tracker loaded"
+
+         # get the beat times
+         beats, downbeats = self.beat_tracker.extract_beats(signal)
+
+         # get the beat indices in z
+         beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
+
+         # remove downbeats from beats
+         beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))]
+         beats_z = beats_z.tolist()
+         downbeats_z = downbeats_z.tolist()
+
+         # make the mask
+         seq_len = self.s2t(signal.duration)
+         mask = torch.zeros(seq_len, device=self.device)
+
+         mask_b4 = self.s2t(before_beat_s)
+         mask_after = self.s2t(after_beat_s)
+
+         if beat_downsample_factor is not None:
+             if beat_downsample_factor < 1:
+                 raise ValueError("beat_downsample_factor must be >= 1 or None")
+         else:
+             beat_downsample_factor = 1
+
+         if downbeat_downsample_factor is not None:
+             if downbeat_downsample_factor < 1:
+                 raise ValueError("downbeat_downsample_factor must be >= 1 or None")
+         else:
+             downbeat_downsample_factor = 1
+
+         beats_z = beats_z[::beat_downsample_factor]
+         downbeats_z = downbeats_z[::downbeat_downsample_factor]
+         print(f"beats_z: {len(beats_z)}")
+         print(f"downbeats_z: {len(downbeats_z)}")
+
+         if mask_upbeats:
+             for beat_idx in beats_z:
+                 _slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
+                 num_steps = mask[_slice[0]:_slice[1]].shape[0]
+                 _m = torch.ones(num_steps, device=self.device)
+                 _m_mask = torch.bernoulli(_m * (1 - dropout))
+                 _m = _m * _m_mask.long()
+
+                 mask[_slice[0]:_slice[1]] = _m
+
+         if mask_downbeats:
+             for downbeat_idx in downbeats_z:
+                 _slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
+                 num_steps = mask[_slice[0]:_slice[1]].shape[0]
+                 _m = torch.ones(num_steps, device=self.device)
+                 _m_mask = torch.bernoulli(_m * (1 - dropout))
+                 _m = _m * _m_mask.long()
+
+                 mask[_slice[0]:_slice[1]] = _m
+
+         mask = mask.clamp(0, 1)
+         if invert:
+             mask = 1 - mask
+
+         mask = mask[None, None, :].bool().long()
+         if self.c2f is not None:
+             mask = mask.repeat(1, self.c2f.n_codebooks, 1)
+         else:
+             mask = mask.repeat(1, self.coarse.n_codebooks, 1)
+         return mask
+
+     def coarse_to_fine(
+         self,
+         z: torch.Tensor,
+         mask: torch.Tensor = None,
+         **kwargs
+     ):
+         assert self.c2f is not None, "No coarse2fine model loaded"
+         length = z.shape[-1]
+         chunk_len = self.s2t(self.c2f.chunk_size_s)
+         n_chunks = math.ceil(z.shape[-1] / chunk_len)
+
+         # zero pad to a multiple of chunk_len
+         if length % chunk_len != 0:
+             pad_len = chunk_len - (length % chunk_len)
+             z = torch.nn.functional.pad(z, (0, pad_len))
+             mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None
+
+         n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
+         if n_codebooks_to_append > 0:
+             z = torch.cat([
+                 z,
+                 torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
+             ], dim=1)
+
+         # set the mask to 0 for all conditioning codebooks
+         if mask is not None:
+             mask = mask.clone()
+             mask[:, :self.c2f.n_conditioning_codebooks, :] = 0
+
+         fine_z = []
+         for i in range(n_chunks):
+             chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
308
+ mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None
309
+
310
+ chunk = self.c2f.generate(
311
+ codec=self.codec,
312
+ time_steps=chunk_len,
313
+ start_tokens=chunk,
314
+ return_signal=False,
315
+ mask=mask_chunk,
316
+ **kwargs
317
+ )
318
+ fine_z.append(chunk)
319
+
320
+ fine_z = torch.cat(fine_z, dim=-1)
321
+ return fine_z[:, :, :length].clone()
322
+
323
+ def coarse_vamp(
324
+ self,
325
+ z,
326
+ mask,
327
+ return_mask=False,
328
+ gen_fn=None,
329
+ **kwargs
330
+ ):
331
+ # coarse z
332
+ cz = z[:, : self.coarse.n_codebooks, :].clone()
333
+ assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the token sequence length must not exceed the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
334
+
335
+ mask = mask[:, : self.coarse.n_codebooks, :]
336
+
337
+ cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token)
338
+ cz_masked = cz_masked[:, : self.coarse.n_codebooks, :]
339
+
340
+ gen_fn = gen_fn or self.coarse.generate
341
+ c_vamp = gen_fn(
342
+ codec=self.codec,
343
+ time_steps=cz.shape[-1],
344
+ start_tokens=cz,
345
+ mask=mask,
346
+ return_signal=False,
347
+ **kwargs
348
+ )
349
+
350
+ # add the fine codes back in
351
+ c_vamp = torch.cat(
352
+ [c_vamp, z[:, self.coarse.n_codebooks :, :]],
353
+ dim=1
354
+ )
355
+
356
+ if return_mask:
357
+ return c_vamp, cz_masked
358
+
359
+ return c_vamp
360
+
361
+
362
+ if __name__ == "__main__":
363
+ import audiotools as at
364
+ import logging
365
+ logger = logging.getLogger()
366
+ logger.setLevel(logging.INFO)
367
+ torch.set_printoptions(threshold=10000)
368
+ at.util.seed(42)
369
+
370
+ interface = Interface(
371
+ coarse_ckpt="./models/vampnet/coarse.pth",
372
+ coarse2fine_ckpt="./models/vampnet/c2f.pth",
373
+ codec_ckpt="./models/vampnet/codec.pth",
374
+ device="cuda",
375
+ wavebeat_ckpt="./models/wavebeat.pth"
376
+ )
377
+
378
+
379
+ sig = at.AudioSignal('assets/example.wav')
380
+
381
+ z = interface.encode(sig)
382
+ breakpoint()
383
+
384
+ # mask = linear_random(z, 1.0)
385
+ # mask = mask_and(
386
+ # mask, periodic_mask(
387
+ # z,
388
+ # 32,
389
+ # 1,
390
+ # random_roll=True
391
+ # )
392
+ # )
393
+
394
+ # mask = interface.make_beat_mask(
395
+ # sig, 0.0, 0.075
396
+ # )
397
+ # mask = dropout(mask, 0.0)
398
+ # mask = codebook_unmask(mask, 0)
399
+
400
+ mask = inpaint(z, n_prefix=100, n_suffix=100)
401
+
402
+ zv, mask_z = interface.coarse_vamp(
403
+ z,
404
+ mask=mask,
405
+ sampling_steps=36,
406
+ temperature=8.0,
407
+ return_mask=True,
408
+ gen_fn=interface.coarse.generate
409
+ )
410
+
411
+
412
+ use_coarse2fine = True
413
+ if use_coarse2fine:
414
+ zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
415
+ breakpoint()
416
+
417
+ mask = interface.to_signal(mask_z).cpu()
418
+
419
+ sig = interface.to_signal(zv).cpu()
420
+ print("done")
421
+
422
+
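As a side note on the chunking logic in `coarse_to_fine` above: a minimal pure-Python sketch (no torch; `pad_to_chunks` is a hypothetical helper name) of how the token sequence is padded to a whole number of `chunk_len`-sized chunks before generation:

```python
import math

def pad_to_chunks(length: int, chunk_len: int):
    # Mirrors coarse_to_fine(): zero-pad the token sequence so it splits
    # evenly into chunk_len-sized pieces, and report how many chunks the
    # coarse2fine model will generate.
    n_chunks = math.ceil(length / chunk_len)
    pad_len = (chunk_len - length % chunk_len) % chunk_len
    return n_chunks, pad_len

print(pad_to_chunks(310, 100))  # (4, 90): pad 310 tokens up to 400
print(pad_to_chunks(400, 100))  # (4, 0): already a whole number of chunks
```

After generation, the output is trimmed back to the original `length`, matching the `fine_z[:, :, :length]` slice above.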
vampnet/vampnet/mask.py ADDED
@@ -0,0 +1,242 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from audiotools import AudioSignal
5
+
6
+ from .util import scalar_to_batch_tensor
7
+
8
+ def _gamma(r):
9
+ return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
10
+
11
+ def _invgamma(y):
12
+ if not torch.is_tensor(y):
13
+ y = torch.tensor(y)[None]
14
+ return 2 * y.acos() / torch.pi
15
+
16
+ def full_mask(x: torch.Tensor):
17
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
18
+ return torch.ones_like(x).long()
19
+
20
+ def empty_mask(x: torch.Tensor):
21
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
22
+ return torch.zeros_like(x).long()
23
+
24
+ def apply_mask(
25
+ x: torch.Tensor,
26
+ mask: torch.Tensor,
27
+ mask_token: int
28
+ ):
29
+ assert mask.ndim == 3, f"mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
30
+ assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
31
+ assert mask.dtype == torch.long, f"mask must be long dtype, but got {mask.dtype}"
32
+ assert not torch.any(mask > 1), "mask must be binary"
33
+ assert not torch.any(mask < 0), "mask must be binary"
34
+
35
+ fill_x = torch.full_like(x, mask_token)
36
+ x = x * (1 - mask) + fill_x * mask
37
+
38
+ return x, mask
39
+
40
+ def random(
41
+ x: torch.Tensor,
42
+ r: torch.Tensor
43
+ ):
44
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
45
+ if not isinstance(r, torch.Tensor):
46
+ r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
47
+
48
+ r = _gamma(r)[:, None, None]
49
+ probs = torch.ones_like(x) * r
50
+
51
+ mask = torch.bernoulli(probs)
52
+ mask = mask.round().long()
53
+
54
+ return mask
55
+
56
+ def linear_random(
57
+ x: torch.Tensor,
58
+ r: torch.Tensor,
59
+ ):
60
+ assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
61
+ if not isinstance(r, torch.Tensor):
62
+ r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
63
+
64
+ probs = torch.ones_like(x).to(x.device).float()
65
+ # expand to batch and codebook dims
66
+ probs = probs.expand(x.shape[0], x.shape[1], -1)
67
+ probs = probs * r[:, None, None]
68
+
69
+ mask = torch.bernoulli(probs)
70
+ mask = mask.round().long()
71
+
72
+ return mask
73
+
74
+ def inpaint(x: torch.Tensor,
75
+ n_prefix,
76
+ n_suffix,
77
+ ):
78
+ assert n_prefix is not None
79
+ assert n_suffix is not None
80
+
81
+ mask = full_mask(x)
82
+
83
+ # if we have a prefix or suffix, set their mask prob to 0
84
+ if n_prefix > 0:
85
+ if not isinstance(n_prefix, torch.Tensor):
86
+ n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
87
+ for i, n in enumerate(n_prefix):
88
+ if n > 0:
89
+ mask[i, :, :n] = 0.0
90
+ if n_suffix > 0:
91
+ if not isinstance(n_suffix, torch.Tensor):
92
+ n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)
93
+ for i, n in enumerate(n_suffix):
94
+ if n > 0:
95
+ mask[i, :, -n:] = 0.0
96
+
97
+
98
+ return mask
99
+
100
+ def periodic_mask(x: torch.Tensor,
101
+ period: int, width: int = 1,
102
+ random_roll=False,
103
+ ):
104
+ mask = full_mask(x)
105
+ if period == 0:
106
+ return mask
107
+
108
+ if not isinstance(period, torch.Tensor):
109
+ period = scalar_to_batch_tensor(period, x.shape[0])
110
+ for i, factor in enumerate(period):
111
+ if factor == 0:
112
+ continue
113
+ for j in range(mask.shape[-1]):
114
+ if j % factor == 0:
115
+ # figure out how wide the mask should be
116
+ j_start = max(0, j - width // 2 )
117
+ j_end = min(mask.shape[-1] - 1, j + width // 2 ) + 1
118
+ # build the fill for this window (bernoulli(1) is always 1, so j_fill is all zeros)
119
+ j_mask = torch.bernoulli(torch.ones(j_end - j_start))
120
+ assert torch.all(j_mask == 1)
121
+ j_fill = torch.ones_like(j_mask) * (1 - j_mask)
122
+ assert torch.all(j_fill == 0)
123
+ # fill
124
+ mask[i, :, j_start:j_end] = j_fill
125
+ if random_roll:
126
+ # add a random offset to the mask
127
+ offset = torch.randint(0, period[0], (1,))
128
+ mask = torch.roll(mask, offset.item(), dims=-1)
129
+
130
+ return mask
131
+
132
+ def codebook_unmask(
133
+ mask: torch.Tensor,
134
+ n_conditioning_codebooks: int
135
+ ):
136
+ if n_conditioning_codebooks is None:
137
+ return mask
138
+ # if we have any conditioning codebooks, set their mask to 0
139
+ mask = mask.clone()
140
+ mask[:, :n_conditioning_codebooks, :] = 0
141
+ return mask
142
+
143
+ def codebook_mask(mask: torch.Tensor, start: int):
144
+ mask = mask.clone()
145
+ mask[:, start:, :] = 1
146
+ return mask
147
+
148
+ def mask_and(
149
+ mask1: torch.Tensor,
150
+ mask2: torch.Tensor
151
+ ):
152
+ assert mask1.shape == mask2.shape, "masks must be same shape"
153
+ return torch.min(mask1, mask2)
154
+
155
+ def dropout(
156
+ mask: torch.Tensor,
157
+ p: float,
158
+ ):
159
+ assert 0 <= p <= 1, "p must be between 0 and 1"
160
+ assert mask.max() <= 1, "mask must be binary"
161
+ assert mask.min() >= 0, "mask must be binary"
162
+ mask = (~mask.bool()).float()
163
+ mask = torch.bernoulli(mask * (1 - p))
164
+ mask = ~mask.round().bool()
165
+ return mask.long()
166
+
167
+ def mask_or(
168
+ mask1: torch.Tensor,
169
+ mask2: torch.Tensor
170
+ ):
171
+ assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
172
+ assert mask1.max() <= 1, "mask1 must be binary"
173
+ assert mask2.max() <= 1, "mask2 must be binary"
174
+ assert mask1.min() >= 0, "mask1 must be binary"
175
+ assert mask2.min() >= 0, "mask2 must be binary"
176
+ return (mask1 + mask2).clamp(0, 1)
177
+
178
+ def time_stretch_mask(
179
+ x: torch.Tensor,
180
+ stretch_factor: int,
181
+ ):
182
+ assert stretch_factor >= 1, "stretch factor must be >= 1"
183
+ c_seq_len = x.shape[-1]
184
+ x = x.repeat_interleave(stretch_factor, dim=-1)
185
+
186
+ # trim x back to the original length
187
+ x = x[:, :, :c_seq_len]
188
+
189
+ mask = periodic_mask(x, stretch_factor, width=1)
190
+ return mask
191
+
192
+ def onset_mask(
193
+ sig: AudioSignal,
194
+ z: torch.Tensor,
195
+ interface,
196
+ width: int = 1
197
+ ):
198
+ import librosa
199
+ import madmom
200
+ from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
201
+ import tempfile
202
+ import numpy as np
203
+
204
+ with tempfile.NamedTemporaryFile(suffix='.wav') as f:
205
+ sig = sig.clone()
206
+ sig.write(f.name)
207
+
208
+ proc = RNNOnsetProcessor(online=False)
209
+ onsetproc = OnsetPeakPickingProcessor(threshold=0.3,
210
+ fps=sig.sample_rate/interface.codec.hop_length)
211
+
212
+ act = proc(f.name)
213
+ onset_times = onsetproc(act)
214
+
215
+ # convert to indices for z array
216
+ onset_indices = librosa.time_to_frames(onset_times, sr=sig.sample_rate, hop_length=interface.codec.hop_length)
217
+
218
+ if onset_indices.shape[0] == 0:
219
+ mask = empty_mask(z)
220
+ print(f"no onsets found, returning empty mask")
221
+ else:
222
+ torch.set_printoptions(threshold=1000)
223
+ print("onset indices: ", onset_indices)
224
+ print("onset times: ", onset_times)
225
+
226
+ # create a mask, set onset
227
+ mask = torch.ones_like(z)
228
+ n_timesteps = z.shape[-1]
229
+
230
+ for onset_index in onset_indices:
231
+ onset_index = min(onset_index, n_timesteps - 1)
232
+ onset_index = max(onset_index, 0)
233
+ mask[:, :, onset_index - width:onset_index + width] = 0.0
234
+
235
+ print(mask)
236
+
237
+ return mask
238
+
239
+
240
+
241
+ if __name__ == "__main__":
242
+ pass
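For reference, the cosine schedule `_gamma`/`_invgamma` at the top of this file can be sanity-checked with plain `math` (a scalar sketch of the torch implementation, not a replacement for it):

```python
import math

def gamma(r: float) -> float:
    # cosine masking schedule used by random(): the Bernoulli probability
    # of masking each token, given a sampled timestep r in [0, 1]
    return max(1e-10, min(1.0, math.cos(r * math.pi / 2)))

def invgamma(y: float) -> float:
    # inverse of the schedule, mapping a mask ratio back to a timestep r
    return 2 * math.acos(y) / math.pi

print(gamma(0.0))                  # 1.0: everything masked at r = 0
print(gamma(1.0) <= 1e-6)          # True: (almost) nothing masked at r = 1
print(round(invgamma(gamma(0.25)), 4))  # 0.25: round-trips
```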
vampnet/vampnet/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ import audiotools
2
+
3
+ audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
4
+ audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
5
+
6
+ from .transformer import VampNet
vampnet/vampnet/modules/activations.py ADDED
@@ -0,0 +1,55 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ class NewGELU(nn.Module):
10
+ """
11
+ Implementation of the GELU activation function currently in Google BERT repo
12
+ (identical to OpenAI GPT). Also see the Gaussian Error Linear Units
13
+ paper: https://arxiv.org/abs/1606.08415
14
+ """
15
+
16
+ def forward(self, x):
17
+ return (
18
+ 0.5
19
+ * x
20
+ * (
21
+ 1.0
22
+ + torch.tanh(
23
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
24
+ )
25
+ )
26
+ )
27
+
28
+ class GatedGELU(nn.Module):
29
+ def __init__(self):
30
+ super().__init__()
31
+ self.gelu = NewGELU()
32
+
33
+ def forward(self, x, dim: int = -1):
34
+ p1, p2 = x.chunk(2, dim=dim)
35
+ return p1 * self.gelu(p2)
36
+
37
+ class Snake1d(nn.Module):
38
+ def __init__(self, channels):
39
+ super().__init__()
40
+ self.alpha = nn.Parameter(torch.ones(channels))
41
+
42
+ def forward(self, x):
43
+ return x + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * x).pow(2)
44
+
45
+ def get_activation(name: str = "relu"):
46
+ if name == "relu":
47
+ return nn.ReLU
48
+ elif name == "gelu":
49
+ return NewGELU
50
+ elif name == "geglu":
51
+ return GatedGELU
52
+ elif name == "snake":
53
+ return Snake1d
54
+ else:
55
+ raise ValueError(f"Unrecognized activation {name}")
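The `NewGELU` above is the tanh approximation of GELU; a scalar sketch (plain `math`, independent of torch) showing it closely tracks the exact `x * Φ(x)` form:

```python
import math

def new_gelu(x: float) -> float:
    # tanh approximation used by NewGELU above
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def exact_gelu(x: float) -> float:
    # reference GELU: x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

# the two forms agree to well under 1e-3 across typical activations
for v in (-2.0, -0.5, 0.0, 0.5, 2.0):
    assert abs(new_gelu(v) - exact_gelu(v)) < 1e-3
```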
vampnet/vampnet/modules/layers.py ADDED
@@ -0,0 +1,164 @@
1
+ import time
2
+ from typing import Optional
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch.nn.utils import weight_norm
10
+
11
+ # Scripting this brings model speed up 1.4x
12
+ @torch.jit.script
13
+ def snake(x, alpha):
14
+ shape = x.shape
15
+ x = x.reshape(shape[0], shape[1], -1)
16
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
17
+ x = x.reshape(shape)
18
+ return x
19
+
20
+
21
+ class Snake1d(nn.Module):
22
+ def __init__(self, channels):
23
+ super().__init__()
24
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
25
+
26
+ def forward(self, x):
27
+ return snake(x, self.alpha)
28
+
29
+
30
+ def num_params(model):
31
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
32
+
33
+
34
+ def recurse_children(module, fn):
35
+ for child in module.children():
36
+ if isinstance(child, nn.ModuleList):
37
+ for c in child:
38
+ yield recurse_children(c, fn)
39
+ if isinstance(child, nn.ModuleDict):
40
+ for c in child.values():
41
+ yield recurse_children(c, fn)
42
+
43
+ yield recurse_children(child, fn)
44
+ yield fn(child)
45
+
46
+
47
+ def WNConv1d(*args, **kwargs):
48
+ return weight_norm(nn.Conv1d(*args, **kwargs))
49
+
50
+
51
+ def WNConvTranspose1d(*args, **kwargs):
52
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
53
+
54
+
55
+ class SequentialWithFiLM(nn.Module):
56
+ """
57
+ handy wrapper for nn.Sequential that allows FiLM layers to be
58
+ inserted in between other layers.
59
+ """
60
+
61
+ def __init__(self, *layers):
62
+ super().__init__()
63
+ self.layers = nn.ModuleList(layers)
64
+
65
+ @staticmethod
66
+ def has_film(module):
67
+ mod_has_film = any(
68
+ [res for res in recurse_children(module, lambda c: isinstance(c, FiLM))]
69
+ )
70
+ return mod_has_film
71
+
72
+ def forward(self, x, cond):
73
+ for layer in self.layers:
74
+ if self.has_film(layer):
75
+ x = layer(x, cond)
76
+ else:
77
+ x = layer(x)
78
+ return x
79
+
80
+
81
+ class FiLM(nn.Module):
82
+ def __init__(self, input_dim: int, output_dim: int):
83
+ super().__init__()
84
+
85
+ self.input_dim = input_dim
86
+ self.output_dim = output_dim
87
+
88
+ if input_dim > 0:
89
+ self.beta = nn.Linear(input_dim, output_dim)
90
+ self.gamma = nn.Linear(input_dim, output_dim)
91
+
92
+ def forward(self, x, r):
93
+ if self.input_dim == 0:
94
+ return x
95
+ else:
96
+ beta, gamma = self.beta(r), self.gamma(r)
97
+ beta, gamma = (
98
+ beta.view(x.size(0), self.output_dim, 1),
99
+ gamma.view(x.size(0), self.output_dim, 1),
100
+ )
101
+ x = x * (gamma + 1) + beta
102
+ return x
103
+
104
+
105
+ class CodebookEmbedding(nn.Module):
106
+ def __init__(
107
+ self,
108
+ vocab_size: int,
109
+ latent_dim: int,
110
+ n_codebooks: int,
111
+ emb_dim: int,
112
+ special_tokens: Optional[Tuple[str]] = None,
113
+ ):
114
+ super().__init__()
115
+ self.n_codebooks = n_codebooks
116
+ self.emb_dim = emb_dim
117
+ self.latent_dim = latent_dim
118
+ self.vocab_size = vocab_size
119
+
120
+ if special_tokens is not None:
121
+ self.special = nn.ParameterDict(
123
+ {
124
+ tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim))
125
+ for tkn in special_tokens
126
+ }
127
+ )
128
+ self.special_idxs = {
129
+ tkn: i + vocab_size for i, tkn in enumerate(special_tokens)
130
+ }
131
+
132
+ self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
133
+
134
+ def from_codes(self, codes: torch.Tensor, codec):
135
+ """
136
+ get a sequence of continuous embeddings from a sequence of discrete codes.
137
+ unlike its counterpart in the original VQ-VAE, this function also handles any special tokens
138
+ necessary for the language model, like <MASK>.
139
+ """
140
+ n_codebooks = codes.shape[1]
141
+ latent = []
142
+ for i in range(n_codebooks):
143
+ c = codes[:, i, :]
144
+
145
+ lookup_table = codec.quantizer.quantizers[i].codebook.weight
146
+ if hasattr(self, "special"):
147
+ special_lookup = torch.cat(
148
+ [self.special[tkn][i : i + 1] for tkn in self.special], dim=0
149
+ )
150
+ lookup_table = torch.cat([lookup_table, special_lookup], dim=0)
151
+
152
+ l = F.embedding(c, lookup_table).transpose(1, 2)
153
+ latent.append(l)
154
+
155
+ latent = torch.cat(latent, dim=1)
156
+ return latent
157
+
158
+ def forward(self, latents: torch.Tensor):
159
+ """
160
+ project a sequence of latents to a sequence of embeddings
161
+ """
162
+ x = self.out_proj(latents)
163
+ return x
164
+
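The FiLM layer above reduces to a per-channel scale-and-shift `x * (gamma + 1) + beta`; a plain-Python sketch (hypothetical `film` helper, no torch) illustrating that zero conditioning outputs give an identity transform:

```python
def film(x, gamma, beta):
    # FiLM conditioning as in the layer above: per-channel scale-and-shift,
    # where gamma == 0 and beta == 0 leaves the input unchanged
    return [xi * (g + 1.0) + b for xi, g, b in zip(x, gamma, beta)]

print(film([1.0, 2.0], [0.0, 0.0], [0.0, 0.0]))   # identity: [1.0, 2.0]
print(film([1.0, 2.0], [1.0, -0.5], [0.5, 0.0]))  # [2.5, 1.0]
```

The `+ 1` offset is the usual trick so that small random initializations of the `gamma` projection start close to the identity.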
vampnet/vampnet/modules/transformer.py ADDED
@@ -0,0 +1,953 @@
1
+ import math
2
+ import logging
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+ import loralib as lora
11
+ import audiotools as at
12
+
13
+ from .activations import get_activation
14
+ from .layers import CodebookEmbedding
15
+ from .layers import FiLM
16
+ from .layers import SequentialWithFiLM
17
+ from .layers import WNConv1d
18
+ from ..util import scalar_to_batch_tensor, codebook_flatten, codebook_unflatten
19
+ from ..mask import _gamma
20
+
21
+ LORA_R = 8
22
+
23
+ # def log(t, eps=1e-20):
24
+ # return torch.log(t + eps)
25
+
26
+
27
+ def gumbel_noise_like(t):
28
+ noise = torch.zeros_like(t).uniform_(1e-20, 1)
29
+ return -torch.log(-torch.log(noise))
30
+
31
+
32
+ def gumbel_sample(t, temperature=1.0, dim=-1):
33
+ return ((t / max(temperature, 1e-10)) + gumbel_noise_like(t)).argmax(dim=dim)
34
+
35
+
36
+ class RMSNorm(nn.Module):
37
+ def __init__(self, hidden_size: int, eps=1e-6):
38
+ super().__init__()
39
+ self.weight = nn.Parameter(torch.ones(hidden_size))
40
+ self.var_eps = eps
41
+
42
+ def forward(self, x):
43
+ """Returns root mean square normalized version of input `x`
44
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known
45
+ # as Root Mean Square Layer Normalization https://arxiv.org/abs/1910.07467
46
+ # thus variance is calculated w/o mean and there is no bias
47
+ Parameters
48
+ ----------
49
+ x : Tensor[B x T x D]
50
+ Returns
51
+ -------
52
+ Tensor[B x T x D]
53
+ """
54
+ var = x.pow(2).mean(-1, keepdim=True)
55
+ x = x * torch.rsqrt(var + self.var_eps)
56
+
57
+ return self.weight * x
58
+
59
+
60
+ class FeedForward(nn.Module):
61
+ def __init__(
62
+ self, d_model: int = 512, dropout: float = 0.1, activation: str = "geglu"
63
+ ):
64
+ super().__init__()
65
+ factor = 2 if activation == "geglu" else 1
66
+ self.w_1 = lora.Linear(d_model, d_model * 4, bias=False, r=LORA_R)
67
+ self.w_2 = lora.Linear(d_model * 4 // factor, d_model, bias=False, r=LORA_R)
68
+ self.drop = nn.Dropout(dropout)
69
+ self.act = get_activation(activation)()
70
+
71
+ def forward(self, x):
72
+ """Computes position-wise feed-forward layer
73
+ Parameters
74
+ ----------
75
+ x : Tensor[B x T x D]
76
+ Returns
77
+ -------
78
+ Tensor[B x T x D]
79
+ """
80
+ x = self.w_1(x)
81
+ x = self.act(x)
82
+ x = self.drop(x)
83
+ x = self.w_2(x)
84
+ return x
85
+
86
+
87
+ class MultiHeadRelativeAttention(nn.Module):
88
+ def __init__(
89
+ self,
90
+ n_head: int = 8,
91
+ d_model: int = 512,
92
+ dropout: float = 0.1,
93
+ bidirectional: bool = True,
94
+ has_relative_attention_bias: bool = True,
95
+ attention_num_buckets: int = 32,
96
+ attention_max_distance: int = 128,
97
+ ):
98
+ super().__init__()
99
+ d_head = d_model // n_head
100
+ self.n_head = n_head
101
+ self.d_head = d_head
102
+ self.bidirectional = bidirectional
103
+ self.has_relative_attention_bias = has_relative_attention_bias
104
+ self.attention_num_buckets = attention_num_buckets
105
+ self.attention_max_distance = attention_max_distance
106
+
107
+ # Create linear query, key, value projections
108
+ self.w_qs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
109
+ self.w_ks = nn.Linear(d_model, d_model, bias=False)
110
+ self.w_vs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
111
+
112
+ # Create linear final output projection
113
+ self.fc = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
114
+
115
+ # Dropout for attention output weights
116
+ self.dropout = nn.Dropout(dropout)
117
+
118
+ # Create relative positional embeddings (if turned on)
119
+ if has_relative_attention_bias:
120
+ self.relative_attention_bias = nn.Embedding(attention_num_buckets, n_head)
121
+
122
+ def _relative_position_bucket(self, relative_position):
123
+ """Converts unbounded relative position into bounded set of buckets
124
+ with half "exact" buckets (1 position = 1 bucket) and half "log-spaced"
125
+ buckets
126
+ Parameters
127
+ ----------
128
+ relative_position : Tensor[T_q x T_kv]
129
+ Relative positions between queries and key_value items
130
+ Returns
131
+ -------
132
+ Tensor[T_q x T_kv]
133
+ Input relative positions converted into buckets
134
+ """
135
+ relative_buckets = 0
136
+ num_buckets = self.attention_num_buckets
137
+ max_distance = self.attention_max_distance
138
+
139
+ # Convert relative position for (-inf, inf) to [0, inf]
140
+ # Negative relative positions correspond to past
141
+ # Positive relative positions correspond to future
142
+ if self.bidirectional:
143
+ # use half buckets for each side (past / future)
144
+ num_buckets //= 2
145
+
146
+ # Shift the positions by `num_buckets` to wrap around
147
+ # negative positions
148
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
149
+ relative_position = torch.abs(relative_position)
150
+ else:
151
+ # If not bidirectional, ignore positive positions and wrap
152
+ # negative positions to positive
153
+ relative_position = -torch.min(
154
+ relative_position, torch.zeros_like(relative_position)
155
+ )
156
+
157
+ # Half of the buckets are for exact increments in positions
158
+ max_exact = num_buckets // 2
159
+ is_small = relative_position < max_exact
160
+
161
+ # The other half of the buckets are for logarithmically bigger bins in
162
+ # positions up to `max_distance`
163
+ relative_position_if_large = max_exact + (
164
+ torch.log(relative_position.float() / max_exact)
165
+ / math.log(max_distance / max_exact)
166
+ * (num_buckets - max_exact)
167
+ ).to(torch.long)
168
+
169
+ # Clip the max relative position to `num_buckets - 1`
170
+ relative_position_if_large = torch.min(
171
+ relative_position_if_large,
172
+ torch.full_like(relative_position_if_large, num_buckets - 1),
173
+ )
174
+
175
+ # Choose relative buckets based on small or large positions
176
+ relative_buckets += torch.where(
177
+ is_small, relative_position, relative_position_if_large
178
+ )
179
+
180
+ return relative_buckets
181
+
182
+ def compute_bias(self, query_length, key_length):
183
+ """Computes a position bias scalar for each index in query_length x key_length
184
+ Parameters
185
+ ----------
186
+ query_length : int
187
+ key_length : int
188
+ Returns
189
+ -------
190
+ Tensor[heads x 1 x T_q x T_kv]
191
+ Position bias to be applied on attention logits
192
+ """
193
+
194
+ query_position = torch.arange(query_length, dtype=torch.long)[:, None]
195
+ key_position = torch.arange(key_length, dtype=torch.long)[None, :]
196
+ relative_position = key_position - query_position
197
+
198
+ # Convert relative position to buckets
199
+ relative_position_bucket = self._relative_position_bucket(relative_position)
200
+ relative_position_bucket = relative_position_bucket.to(
201
+ self.relative_attention_bias.weight.device
202
+ )
203
+
204
+ # Index attention bias values
205
+ values = self.relative_attention_bias(relative_position_bucket)
206
+ values = rearrange(values, "q k h -> h 1 q k")
207
+
208
+ return values
209
+
210
+ def forward(self, q, k, v, mask=None, position_bias=None):
211
+ """Computes attention over (keys, values) for every timestep in query
212
+ Parameters
213
+ ----------
214
+ q : Tensor[B x T_q x d_model]
215
+ Query vectors
216
+ k : Tensor[B x T_kv x d_model]
217
+ Key vectors to compute attention over
218
+ v : Tensor[B x T_kv x d_model]
219
+ Value vectors corresponding to the keys
220
+ mask : Tensor[B x T_q x T_kv], optional
221
+ position_bias: Tensor[head x 1 x T_q x T_kv]
222
+ Returns
223
+ -------
224
+ Tensor[B x T_q x d_model]
225
+ Outputs after attending (key, value) using queries
226
+ """
227
+ # Compute query, key, value projections
228
+ q = rearrange(self.w_qs(q), "b l (head k) -> head b l k", head=self.n_head)
229
+ k = rearrange(self.w_ks(k), "b t (head k) -> head b t k", head=self.n_head)
230
+ v = rearrange(self.w_vs(v), "b t (head k) -> head b t k", head=self.n_head)
231
+
232
+ # Compute attention matrix
233
+ attn = torch.einsum("hblk,hbtk->hblt", [q, k]) / np.sqrt(q.shape[-1])
234
+
235
+ # Add relative position bias to attention scores
236
+ if position_bias is None:
237
+ if self.has_relative_attention_bias:
238
+ position_bias = self.compute_bias(q.size(-2), k.size(-2))
239
+ else:
240
+ position_bias = torch.zeros_like(attn)
241
+ attn += position_bias
242
+
243
+ # Apply mask to attention scores to prevent looking up invalid locations
244
+ if mask is not None:
245
+ attn = attn.masked_fill(mask[None] == 0, -1e9)
246
+
247
+ # Normalize attention scores and add dropout
248
+ attn = torch.softmax(attn, dim=3)
249
+ attn = self.dropout(attn)
250
+
251
+ # Compute attended outputs (product of attention matrix and values)
252
+ output = torch.einsum("hblt,hbtv->hblv", [attn, v])
253
+ output = rearrange(output, "head b l v -> b l (head v)")
254
+ output = self.fc(output)
255
+
256
+ return output, position_bias
257
+
258
+
259
+ class TransformerLayer(nn.Module):
+     def __init__(
+         self,
+         d_model: int = 512,
+         d_cond: int = 64,
+         n_heads: int = 8,
+         bidirectional: bool = True,
+         is_decoder: bool = False,
+         has_relative_attention_bias: bool = False,
+         flash_attn: bool = False,
+         dropout: float = 0.1,
+     ):
+         super().__init__()
+         # Store args
+         self.is_decoder = is_decoder
+
+         # Create self-attention layer
+         self.norm_1 = RMSNorm(d_model)
+         self.film_1 = FiLM(d_cond, d_model)
+         self.flash_attn = flash_attn
+
+         if flash_attn:
+             from flash_attn.flash_attention import FlashMHA
+             self.self_attn = FlashMHA(
+                 embed_dim=d_model,
+                 num_heads=n_heads,
+                 attention_dropout=dropout,
+                 causal=False,
+             )
+         else:
+             self.self_attn = MultiHeadRelativeAttention(
+                 n_heads, d_model, dropout, bidirectional, has_relative_attention_bias
+             )
+
+         # (Optional) Create cross-attention layer
+         if is_decoder:
+             self.norm_2 = RMSNorm(d_model)
+             self.film_2 = FiLM(d_cond, d_model)
+             self.cross_attn = MultiHeadRelativeAttention(
+                 n_heads,
+                 d_model,
+                 dropout,
+                 bidirectional=True,
+                 has_relative_attention_bias=False,
+             )
+
+         # Create last feed-forward layer
+         self.norm_3 = RMSNorm(d_model)
+         self.film_3 = FiLM(d_cond, d_model)
+         self.feed_forward = FeedForward(d_model=d_model, dropout=dropout)
+
+         # Create dropout
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(
+         self,
+         x,
+         x_mask,
+         cond,
+         src=None,
+         src_mask=None,
+         position_bias=None,
+         encoder_decoder_position_bias=None,
+     ):
+         """Computes one transformer layer, consisting of self-attention, (optional)
+         cross-attention, and a feed-forward layer.
+
+         Parameters
+         ----------
+         x : Tensor[B x T_q x D]
+         x_mask : Tensor[B x T_q]
+         src : Tensor[B x T_kv x D], optional
+         src_mask : Tensor[B x T_kv x D], optional
+         position_bias : Tensor[heads x B x T_q x T_q], optional
+             Relative position bias for the self-attention layer
+         encoder_decoder_position_bias : Tensor[heads x B x T_q x T_kv], optional
+             Relative position bias for the cross-attention layer
+
+         Returns
+         -------
+         Tensor[B x T_q x D]
+         """
+         y = self.norm_1(x)
+         y = self.film_1(y.permute(0, 2, 1), cond).permute(0, 2, 1)
+         if self.flash_attn:
+             with torch.autocast(y.device.type, dtype=torch.bfloat16):
+                 y = self.self_attn(y)[0]
+         else:
+             y, position_bias = self.self_attn(y, y, y, x_mask, position_bias)
+         x = x + self.dropout(y)
+
+         if self.is_decoder:
+             y = self.norm_2(x)
+             y = self.film_2(y.permute(0, 2, 1), cond).permute(0, 2, 1)
+             y, encoder_decoder_position_bias = self.cross_attn(
+                 y, src, src, src_mask, encoder_decoder_position_bias
+             )
+             x = x + self.dropout(y)
+
+         y = self.norm_3(x)
+         y = self.film_3(y.permute(0, 2, 1), cond).permute(0, 2, 1)
+         y = self.feed_forward(y)
+         x = x + self.dropout(y)
+
+         return x, position_bias, encoder_decoder_position_bias
+
+
+ class TransformerStack(nn.Module):
+     def __init__(
+         self,
+         d_model: int = 512,
+         d_cond: int = 64,
+         n_heads: int = 8,
+         n_layers: int = 8,
+         last_layer: bool = True,
+         bidirectional: bool = True,
+         flash_attn: bool = False,
+         is_decoder: bool = False,
+         dropout: float = 0.1,
+     ):
+         super().__init__()
+         # Store args
+         self.bidirectional = bidirectional
+         self.is_decoder = is_decoder
+
+         # Create transformer layers
+         # In T5, relative attention bias is shared by all layers in the stack
+         self.layers = nn.ModuleList(
+             [
+                 TransformerLayer(
+                     d_model,
+                     d_cond,
+                     n_heads,
+                     bidirectional,
+                     is_decoder,
+                     has_relative_attention_bias=(i == 0),
+                     flash_attn=flash_attn,
+                     dropout=dropout,
+                 )
+                 for i in range(n_layers)
+             ]
+         )
+
+         # Perform last normalization
+         self.norm = RMSNorm(d_model) if last_layer else None
+
+     def subsequent_mask(self, size):
+         return torch.ones(1, size, size).tril().bool()
+
+     def forward(
+         self, x, x_mask, cond=None, src=None, src_mask=None,
+         return_activations: bool = False,
+     ):
+         """Computes a full transformer stack.
+
+         Parameters
+         ----------
+         x : Tensor[B x T_q x D]
+         x_mask : Tensor[B x T_q]
+         src : Tensor[B x T_kv x D], optional
+         src_mask : Tensor[B x T_kv], optional
+
+         Returns
+         -------
+         Tensor[B x T_q x D]
+         """
+         # Convert `src_mask` to (B x T_q x T_kv) shape for cross-attention masking
+         if self.is_decoder:
+             src_mask = x_mask.unsqueeze(-1) * src_mask.unsqueeze(-2)
+
+         # Convert `x_mask` to (B x T_q x T_q) shape for self-attention masking
+         x_mask = x_mask.unsqueeze(-2)
+         if not self.bidirectional:
+             x_mask = x_mask * self.subsequent_mask(x.size(1)).to(x_mask.device)
+
+         # Initialize position biases
+         position_bias = None
+         encoder_decoder_position_bias = None
+
+         # Compute transformer layers
+         if return_activations:
+             activations = []
+         for layer in self.layers:
+             x, position_bias, encoder_decoder_position_bias = layer(
+                 x=x,
+                 x_mask=x_mask,
+                 cond=cond,
+                 src=src,
+                 src_mask=src_mask,
+                 position_bias=position_bias,
+                 encoder_decoder_position_bias=encoder_decoder_position_bias,
+             )
+             if return_activations:
+                 activations.append(x.detach())
+
+         out = self.norm(x) if self.norm is not None else x
+         if return_activations:
+             return out, torch.stack(activations)
+         else:
+             return out
+
+
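`TransformerStack.subsequent_mask` builds a lower-triangular boolean matrix so that, in the non-bidirectional case, position `i` may attend only to positions `j <= i`. The same mask in plain Python (illustrative sketch, no torch):

```python
def subsequent_mask(size):
    # Lower-triangular boolean matrix: row i is True at columns j <= i,
    # i.e. each position can see itself and everything before it.
    return [[j <= i for j in range(size)] for i in range(size)]
```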
+ class VampNet(at.ml.BaseModel):
+     def __init__(
+         self,
+         n_heads: int = 20,
+         n_layers: int = 16,
+         r_cond_dim: int = 0,
+         n_codebooks: int = 9,
+         n_conditioning_codebooks: int = 0,
+         latent_dim: int = 8,
+         embedding_dim: int = 1280,
+         vocab_size: int = 1024,
+         flash_attn: bool = True,
+         noise_mode: str = "mask",
+         dropout: float = 0.1,
+     ):
+         super().__init__()
+         assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.r_cond_dim = r_cond_dim
+         self.n_codebooks = n_codebooks
+         self.n_conditioning_codebooks = n_conditioning_codebooks
+         self.embedding_dim = embedding_dim
+         self.vocab_size = vocab_size
+         self.latent_dim = latent_dim
+         self.flash_attn = flash_attn
+         self.noise_mode = noise_mode
+
+         assert self.noise_mode == "mask", "deprecated"
+
+         self.embedding = CodebookEmbedding(
+             latent_dim=latent_dim,
+             n_codebooks=n_codebooks,
+             vocab_size=vocab_size,
+             emb_dim=embedding_dim,
+             special_tokens=["MASK"],
+         )
+         self.mask_token = self.embedding.special_idxs["MASK"]
+
+         self.transformer = TransformerStack(
+             d_model=embedding_dim,
+             d_cond=r_cond_dim,
+             n_heads=n_heads,
+             n_layers=n_layers,
+             last_layer=True,
+             bidirectional=True,
+             flash_attn=flash_attn,
+             is_decoder=False,
+             dropout=dropout,
+         )
+
+         # Add final conv layer
+         self.n_predict_codebooks = n_codebooks - n_conditioning_codebooks
+         self.classifier = SequentialWithFiLM(
+             WNConv1d(
+                 embedding_dim,
+                 vocab_size * self.n_predict_codebooks,
+                 kernel_size=1,
+                 padding="same",
+                 # groups=self.n_predict_codebooks,
+             ),
+         )
+
+     def forward(self, x, return_activations: bool = False):
+         x = self.embedding(x)
+         x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
+
+         x = rearrange(x, "b d n -> b n d")
+         out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
+         if return_activations:
+             out, activations = out
+
+         out = rearrange(out, "b n d -> b d n")
+
+         out = self.classifier(out, None)  # no cond here!
+
+         out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
+
+         if return_activations:
+             return out, activations
+         else:
+             return out
+
+     def r_embed(self, r, max_positions=10000):
+         if self.r_cond_dim > 0:
+             dtype = r.dtype
+
+             r = _gamma(r) * max_positions
+             half_dim = self.r_cond_dim // 2
+
+             emb = math.log(max_positions) / (half_dim - 1)
+             emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+
+             emb = r[:, None] * emb[None, :]
+             emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+
+             if self.r_cond_dim % 2 == 1:  # zero pad
+                 emb = nn.functional.pad(emb, (0, 1), mode="constant")
+
+             return emb.to(dtype)
+         else:
+             return r
+
+     @torch.no_grad()
+     def to_signal(self, z, codec):
+         """
+         Convert a sequence of latents to a signal.
+         """
+         assert z.ndim == 3
+
+         signal = at.AudioSignal(
+             codec.decode(
+                 codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
+             )["audio"],
+             codec.sample_rate,
+         )
+
+         # find where the mask token is and replace it with silence in the audio
+         for tstep in range(z.shape[-1]):
+             if torch.any(z[:, :, tstep] == self.mask_token):
+                 sample_idx_0 = tstep * codec.hop_length
+                 sample_idx_1 = sample_idx_0 + codec.hop_length
+                 signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
+
+         return signal
+
+     @torch.no_grad()
+     def generate(
+         self,
+         codec,
+         time_steps: int = 300,
+         sampling_steps: int = 36,
+         start_tokens: Optional[torch.Tensor] = None,
+         sampling_temperature: float = 1.0,
+         mask: Optional[torch.Tensor] = None,
+         mask_temperature: float = 10.5,
+         typical_filtering=False,
+         typical_mass=0.2,
+         typical_min_tokens=1,
+         top_p=None,
+         return_signal=True,
+         seed: int = None,
+         sample_cutoff: float = 1.0,
+     ):
+         if seed is not None:
+             at.util.seed(seed)
+         logging.debug(f"beginning generation with {sampling_steps} steps")
+
+         #####################
+         # resolve initial z #
+         #####################
+         z = start_tokens
+
+         if z is None:
+             z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
+                 self.device
+             )
+
+         logging.debug(f"created z with shape {z.shape}")
+
+         ################
+         # resolve mask #
+         ################
+         if mask is None:
+             mask = torch.ones_like(z).to(self.device).int()
+             mask[:, : self.n_conditioning_codebooks, :] = 0.0
+         if mask.ndim == 2:
+             mask = mask[:, None, :].repeat(1, z.shape[1], 1)
+         # init_mask = mask.clone()
+
+         logging.debug(f"created mask with shape {mask.shape}")
+
+         ##########
+         # set up #
+         ##########
+         # apply the mask to z
+         z_masked = z.masked_fill(mask.bool(), self.mask_token)
+         # logging.debug(f"z_masked: {z_masked}")
+
+         # how many mask tokens to begin with?
+         num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
+         logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")
+
+         # how many codebooks are we inferring vs conditioning on?
+         n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
+         logging.debug(f"n infer codebooks: {n_infer_codebooks}")
+
+         ##################
+         # begin sampling #
+         ##################
+         for i in range(sampling_steps):
+             logging.debug(f"step {i} of {sampling_steps}")
+
+             # our current schedule step
+             r = scalar_to_batch_tensor(
+                 (i + 1) / sampling_steps,
+                 z.shape[0]
+             ).to(z.device)
+             logging.debug(f"r: {r}")
+
+             # get latents
+             latents = self.embedding.from_codes(z_masked, codec)
+             logging.debug(f"computed latents with shape: {latents.shape}")
+
+             # infer from latents
+             # NOTE: this collapses the codebook dimension into the sequence dimension
+             logits = self.forward(latents)  # b, prob, seq
+             logits = logits.permute(0, 2, 1)  # b, seq, prob
+             b = logits.shape[0]
+
+             logging.debug(f"permuted logits with shape: {logits.shape}")
+
+             sampled_z, selected_probs = sample_from_logits(
+                 logits,
+                 sample=((i / sampling_steps) <= sample_cutoff),
+                 temperature=sampling_temperature,
+                 typical_filtering=typical_filtering,
+                 typical_mass=typical_mass,
+                 typical_min_tokens=typical_min_tokens,
+                 top_k=None,
+                 top_p=top_p,
+                 return_probs=True,
+             )
+
+             logging.debug(f"sampled z with shape: {sampled_z.shape}")
+
+             # flatten z_masked and mask, so we can deal with the sampling logic
+             # we'll unflatten them at the end of the loop for the next forward pass
+             # remove conditioning codebooks, we'll add them back at the end
+             z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
+
+             # update the mask, remove conditioning codebooks from the mask
+             mask = (z_masked == self.mask_token).int()
+             logging.debug(f"updated mask with shape: {mask.shape}")
+
+             # add z back into sampled z where the mask was false
+             sampled_z = torch.where(mask.bool(), sampled_z, z_masked)
+             logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
+
+             # ignore any tokens that weren't masked
+             selected_probs = torch.where(mask.bool(), selected_probs, torch.inf)
+
+             # get the num tokens to mask, according to the schedule
+             num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
+             logging.debug(f"num to mask: {num_to_mask}")
+
+             if i != (sampling_steps - 1):
+                 num_to_mask = torch.maximum(
+                     torch.tensor(1),
+                     torch.minimum(
+                         mask.sum(dim=-1, keepdim=True) - 1,
+                         num_to_mask
+                     )
+                 )
+
+             # get our new mask
+             mask = mask_by_random_topk(
+                 num_to_mask, selected_probs, mask_temperature * (1 - r)
+             )
+
+             # update the mask
+             z_masked = torch.where(mask.bool(), self.mask_token, sampled_z)
+             logging.debug(f"updated z_masked with shape: {z_masked.shape}")
+
+             z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
+             mask = codebook_unflatten(mask, n_infer_codebooks)
+             logging.debug(f"unflattened z_masked with shape: {z_masked.shape}")
+
+             # add conditioning codebooks back to z_masked
+             z_masked = torch.cat(
+                 (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
+             )
+             logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")
+
+         # add conditioning codebooks back to sampled_z
+         sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
+         sampled_z = torch.cat(
+             (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
+         )
+
+         logging.debug("finished sampling")
+
+         if return_signal:
+             return self.to_signal(sampled_z, codec)
+         else:
+             return sampled_z
+
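Each iteration of `generate` re-masks `floor(_gamma(r) * num_mask_tokens_at_start)` of the lowest-confidence tokens. `_gamma` is defined elsewhere in this module; assuming the MaskGIT-style cosine schedule `gamma(r) = cos(pi * r / 2)`, the per-step re-mask budget can be sketched as:

```python
import math

def gamma(r):
    # Assumed cosine masking schedule: gamma(0) = 1, gamma(1) = 0.
    return math.cos(r * math.pi / 2)

def mask_schedule(sampling_steps, num_mask_tokens_at_start):
    # Tokens re-masked after step i, with r = (i + 1) / sampling_steps
    # exactly as computed inside generate().
    return [
        math.floor(gamma((i + 1) / sampling_steps) * num_mask_tokens_at_start)
        for i in range(sampling_steps)
    ]
```

The budget decays monotonically to zero, so the final step leaves every token committed.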
+ def sample_from_logits(
+     logits,
+     sample: bool = True,
+     temperature: float = 1.0,
+     top_k: int = None,
+     top_p: float = None,
+     typical_filtering: bool = False,
+     typical_mass: float = 0.2,
+     typical_min_tokens: int = 1,
+     return_probs: bool = False,
+ ):
+     """Convenience function to sample from a categorical distribution with input as
+     unnormalized logits.
+
+     Parameters
+     ----------
+     logits : Tensor[..., vocab_size]
+     sample : bool, optional
+         Whether to perform multinomial sampling, by default True
+     temperature : float, optional
+         Scaling parameter when multinomial sampling, by default 1.0
+     top_k : int, optional
+         Restricts sampling to only `top_k` values acc. to probability,
+         by default None
+     top_p : float, optional
+         Restricts sampling to only those values with cumulative
+         probability <= `top_p`, by default None
+
+     Returns
+     -------
+     Tensor[...]
+         Sampled tokens
+     """
+     shp = logits.shape[:-1]
+
+     if typical_filtering:
+         logits = typical_filter(
+             logits,
+             typical_mass=typical_mass,
+             typical_min_tokens=typical_min_tokens,
+         )
+
+     # Apply top_k sampling
+     if top_k is not None:
+         v, _ = logits.topk(top_k)
+         logits[logits < v[..., [-1]]] = -float("inf")
+
+     # Apply top_p (nucleus) sampling
+     if top_p is not None and top_p < 1.0:
+         v, sorted_indices = logits.sort(descending=True)
+         cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
+
+         sorted_indices_to_remove = cumulative_probs > top_p
+         # Right shift indices_to_remove to keep 1st token over threshold
+         sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
+             ..., :-1
+         ]
+
+         # Compute indices_to_remove in unsorted array
+         indices_to_remove = sorted_indices_to_remove.scatter(
+             -1, sorted_indices, sorted_indices_to_remove
+         )
+
+         logits[indices_to_remove] = -float("inf")
+
+     # Perform multinomial sampling after normalizing logits
+     probs = (
+         F.softmax(logits / temperature, dim=-1)
+         if temperature > 0
+         else logits.softmax(dim=-1)
+     )
+     token = (
+         probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
+         if sample
+         else logits.argmax(-1)
+     )
+
+     if return_probs:
+         token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
+         return token, token_probs
+     else:
+         return token
+
+
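The nucleus (top-p) branch above keeps the smallest prefix of probability-sorted tokens whose cumulative mass covers `top_p`, right-shifting the removal mask so the first token over the threshold survives. The same logic over a single logit vector in plain Python (an illustrative sketch, not the repo's API):

```python
import math

def top_p_filter(logits, top_p):
    # Sort token ids by logit, descending.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    z = sum(math.exp(l) for l in logits)
    probs = [math.exp(logits[i]) / z for i in order]

    out = list(logits)
    cum = 0.0
    for rank, i in enumerate(order):
        # Right-shifted cutoff: a token is removed only if the mass *before* it
        # already exceeds top_p, so the first token over the threshold is kept.
        if rank > 0 and cum > top_p:
            out[i] = float("-inf")
        cum += probs[rank]
    return out
```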
+ def mask_by_random_topk(
+     num_to_mask: int,
+     probs: torch.Tensor,
+     temperature: float = 1.0,
+ ):
+     """
+     Args:
+         num_to_mask (int): number of tokens to mask
+         probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
+         temperature (float, optional): temperature. Defaults to 1.0.
+     """
+     logging.debug("masking by random topk")
+     logging.debug(f"num to mask: {num_to_mask}")
+     logging.debug(f"probs shape: {probs.shape}")
+     logging.debug(f"temperature: {temperature}")
+
+     noise = gumbel_noise_like(probs)
+     confidence = torch.log(probs) + temperature * noise
+     logging.debug(f"confidence shape: {confidence.shape}")
+
+     sorted_confidence, sorted_idx = confidence.sort(dim=-1)
+     logging.debug(f"sorted confidence shape: {sorted_confidence.shape}")
+     logging.debug(f"sorted idx shape: {sorted_idx.shape}")
+
+     # get the cut off threshold, given the mask length
+     cut_off = torch.take_along_dim(sorted_confidence, num_to_mask, dim=-1)
+     logging.debug(f"cut off shape: {cut_off.shape}")
+
+     # mask out the tokens
+     mask = confidence < cut_off
+     logging.debug(f"mask shape: {mask.shape}")
+
+     return mask
+
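`mask_by_random_topk` perturbs the log-probabilities with Gumbel noise (`gumbel_noise_like` is defined elsewhere in the repo), so the `num_to_mask` cutoff is stochastic rather than a hard argsort. Standard Gumbel samples can be drawn by inverse-CDF; a minimal sketch:

```python
import math
import random

def gumbel_noise(n, rng=None):
    # g = -log(-log(u)) with u ~ Uniform(0, 1) is a standard Gumbel sample.
    rng = rng or random.Random(0)
    return [-math.log(-math.log(rng.random())) for _ in range(n)]
```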
+ def typical_filter(
+     logits,
+     typical_mass: float = 0.95,
+     typical_min_tokens: int = 1,
+ ):
+     nb, nt, _ = logits.shape
+     x_flat = rearrange(logits, "b t l -> (b t) l")
+     x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
+     x_flat_norm_p = torch.exp(x_flat_norm)
+     entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
+
+     c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
+     c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
+     x_flat_cumsum = (
+         x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
+     )
+
+     last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
+     sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
+         1, last_ind.view(-1, 1)
+     )
+     if typical_min_tokens > 1:
+         sorted_indices_to_remove[..., :typical_min_tokens] = 0
+     indices_to_remove = sorted_indices_to_remove.scatter(
+         1, x_flat_indices, sorted_indices_to_remove
+     )
+     x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
+     logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
+     return logits
+
+
+ if __name__ == "__main__":
+     import argbind
+
+     from .layers import num_params
+
+     VampNet = argbind.bind(VampNet)
+
+     @argbind.bind(without_prefix=True)
+     def try_model(device: str = "cuda", batch_size: int = 2, seq_len_s: float = 10.0):
+         seq_len = int(32000 / 512 * seq_len_s)
+
+         model = VampNet().to(device)
+
+         z = torch.randint(
+             0, model.vocab_size, size=(batch_size, model.n_codebooks, seq_len)
+         ).to(device)
+
+         r = torch.zeros(batch_size).to(device)
+
+         z_mask_latent = torch.rand(
+             batch_size, model.latent_dim * model.n_codebooks, seq_len
+         ).to(device)
+         z_hat = model(z_mask_latent)
+
+         pred = z_hat.argmax(dim=1)
+         pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
+
+         print(f"model has {num_params(model)/1e6:<.3f}M parameters")
+         print(f"prediction has shape {pred.shape}")
+         breakpoint()
+
+     args = argbind.parse_args()
+     with argbind.scope(args):
+         try_model()
+
vampnet/vampnet/scheduler.py ADDED
@@ -0,0 +1,47 @@
+ import copy
+ from typing import List
+
+ import torch
+
+
+ class NoamScheduler:
+     """OG scheduler from the transformer paper: https://arxiv.org/pdf/1706.03762.pdf
+     Implementation from the Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html
+     """
+
+     def __init__(
+         self,
+         optimizer: torch.optim.Optimizer,
+         d_model: int = 512,
+         factor: float = 1.0,
+         warmup: int = 4000,
+     ):
+         # Store hparams
+         self.warmup = warmup
+         self.factor = factor
+         self.d_model = d_model
+
+         # Initialize variables `lr` and `steps`
+         self.lr = None
+         self.steps = 0
+
+         # Store the optimizer
+         self.optimizer = optimizer
+
+     def state_dict(self):
+         return {
+             key: value for key, value in self.__dict__.items() if key != "optimizer"
+         }
+
+     def load_state_dict(self, state_dict):
+         self.__dict__.update(state_dict)
+
+     def step(self):
+         self.steps += 1
+         self.lr = self.factor * (
+             self.d_model ** (-0.5)
+             * min(self.steps ** (-0.5), self.steps * self.warmup ** (-1.5))
+         )
+
+         for p in self.optimizer.param_groups:
+             p["lr"] = self.lr
+
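The closed form computed in `step()` is `lr = factor * d_model^-0.5 * min(steps^-0.5, steps * warmup^-1.5)`: linear warmup up to `steps == warmup`, then inverse-square-root decay. As a standalone function (a sketch mirroring the class above):

```python
def noam_lr(step, d_model=512, factor=1.0, warmup=4000):
    # Linear warmup for step < warmup, then ~ step ** -0.5 decay;
    # the two branches of min() cross exactly at step == warmup.
    return factor * (d_model ** -0.5) * min(step ** -0.5, step * warmup ** -1.5)
```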
vampnet/vampnet/util.py ADDED
@@ -0,0 +1,46 @@
+ import tqdm
+
+ import torch
+ from einops import rearrange
+
+
+ def scalar_to_batch_tensor(x, batch_size):
+     return torch.tensor(x).repeat(batch_size)
+
+
+ def parallelize(
+     fn,
+     *iterables,
+     parallel: str = "thread_map",
+     **kwargs
+ ):
+     if parallel == "thread_map":
+         from tqdm.contrib.concurrent import thread_map
+         return thread_map(fn, *iterables, **kwargs)
+     elif parallel == "process_map":
+         from tqdm.contrib.concurrent import process_map
+         return process_map(fn, *iterables, **kwargs)
+     elif parallel == "single":
+         return [fn(x) for x in tqdm.tqdm(*iterables)]
+     else:
+         raise ValueError(
+             f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}"
+         )
+
+
+ def codebook_flatten(tokens: torch.Tensor):
+     """
+     Flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time).
+     """
+     return rearrange(tokens, "b c t -> b (t c)")
+
+
+ def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
+     """
+     Unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time).
+     """
+     tokens = rearrange(flat_tokens, "b (t c) -> b c t", c=n_c)
+     return tokens
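`codebook_flatten` interleaves the codebooks within each time step (`"b c t -> b (t c)"`: time is the slower-varying axis), and `codebook_unflatten` inverts it. The same index arithmetic on plain Python lists, as an illustrative sketch (the repo's versions operate on torch tensors via einops):

```python
def flatten_codes(tokens):
    # tokens: [batch][codebook][time] -> [batch][time * codebook], t-major, c-minor
    B, C, T = len(tokens), len(tokens[0]), len(tokens[0][0])
    return [[tokens[b][c][t] for t in range(T) for c in range(C)] for b in range(B)]

def unflatten_codes(flat, n_c):
    # [batch][time * codebook] -> [batch][codebook][time]
    B, T = len(flat), len(flat[0]) // n_c
    return [[[flat[b][t * n_c + c] for t in range(T)] for c in range(n_c)]
            for b in range(B)]
```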