Spaces:
Sleeping
Sleeping
Anvita Pandit committed on
Commit ·
5f6b40b
1
Parent(s): 5da7938
Add WhAM Gradio app with ZeroGPU support
Browse files- LICENSE +21 -0
- README.md +205 -7
- app.py +38 -0
- packages.txt +1 -0
- requirements.txt +16 -0
- setup.py +33 -0
- vampnet/.gitignore +189 -0
- vampnet/.pre-commit-config.yaml +15 -0
- vampnet/LICENSE +21 -0
- vampnet/README.md +94 -0
- vampnet/app.py +677 -0
- vampnet/conf/c2f.yml +14 -0
- vampnet/conf/interface.yml +10 -0
- vampnet/conf/lora/lora.yml +22 -0
- vampnet/conf/vampnet.yml +49 -0
- vampnet/scripts/exp/eval.py +110 -0
- vampnet/scripts/exp/experiment.py +254 -0
- vampnet/scripts/exp/fine_tune.py +81 -0
- vampnet/scripts/exp/train.py +680 -0
- vampnet/scripts/utils/README.md +28 -0
- vampnet/scripts/utils/data/augment.py +67 -0
- vampnet/scripts/utils/data/maestro-reorg.py +39 -0
- vampnet/scripts/utils/plots.py +43 -0
- vampnet/scripts/utils/remove_quiet_files.py +29 -0
- vampnet/scripts/utils/split.py +64 -0
- vampnet/scripts/utils/split_long_audio_file.py +34 -0
- vampnet/scripts/utils/stage.py +30 -0
- vampnet/scripts/utils/visualize_embeddings.py +276 -0
- vampnet/scripts/utils/xeno-canto-dl.py +234 -0
- vampnet/setup.py +47 -0
- vampnet/vampnet/__init__.py +6 -0
- vampnet/vampnet/beats.py +250 -0
- vampnet/vampnet/interface.py +422 -0
- vampnet/vampnet/mask.py +242 -0
- vampnet/vampnet/modules/__init__.py +6 -0
- vampnet/vampnet/modules/activations.py +55 -0
- vampnet/vampnet/modules/layers.py +164 -0
- vampnet/vampnet/modules/transformer.py +953 -0
- vampnet/vampnet/scheduler.py +47 -0
- vampnet/vampnet/util.py +46 -0
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Project CETI
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,13 +1,211 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.8.0
|
| 8 |
-
python_version: '3.12'
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: WhAM
|
| 3 |
+
emoji: 🐋
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: gradio
|
|
|
|
|
|
|
| 7 |
app_file: app.py
|
| 8 |
pinned: false
|
| 9 |
+
hardware: zero-a10g
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# WhAM: a Whale Acoustics Model
|
| 13 |
+
[](https://arxiv.org/abs/2512.02206)
|
| 14 |
+
[](https://doi.org/10.5281/zenodo.17633708)
|
| 15 |
+
[](https://huggingface.co/datasets/orrp/DSWP)
|
| 16 |
+

|
| 17 |
+
WhAM is a transformer-based audio-to-audio model designed to synthesize and analyze sperm whale codas. Based on [VampNet](https://github.com/hugofloresgarcia/vampnet), WhAM uses masked acoustic token modeling to capture temporal and spectral features of whale communication. WhAM generates codas from a given audio context, enabling three core capabilities:
|
| 18 |
+
|
| 19 |
+
- Acoustic Translation: The ability to style-transfer arbitrary audio prompts (e.g., human speech, noise) into the acoustic texture of sperm whale codas.
|
| 20 |
+
|
| 21 |
+
- Synthesizing novel "pseudocodas".
|
| 22 |
+
|
| 23 |
+
- Providing audio embeddings for downstream tasks such as social unit and spectral feature ("vowel") classification.
|
| 24 |
+
|
| 25 |
+
See our [NeurIPS 2025](https://openreview.net/pdf?id=IL1wvzOgqD) publication for more details.
|
| 26 |
+
|
| 27 |
+
## Installation
|
| 28 |
+
|
| 29 |
+
1. **Clone the repository:**
|
| 30 |
+
```bash
|
| 31 |
+
git clone https://github.com/Project-CETI/wham.git
|
| 32 |
+
cd wham
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
2. **Set up the environment:**
|
| 36 |
+
```bash
|
| 37 |
+
conda create -n wham python=3.9
|
| 38 |
+
conda activate wham
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
3. **Install dependencies:**
|
| 42 |
+
```bash
|
| 43 |
+
# Install the wham package
|
| 44 |
+
pip install -e .
|
| 45 |
+
|
| 46 |
+
# Install VampNet
|
| 47 |
+
pip install -e ./vampnet
|
| 48 |
+
|
| 49 |
+
# Install madmom
|
| 50 |
+
pip install --no-build-isolation madmom
|
| 51 |
+
|
| 52 |
+
# Install ffmpeg
|
| 53 |
+
conda install -c conda-forge ffmpeg
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
4. **Download model weights:**
|
| 57 |
+
Download the [weights](https://zenodo.org/records/17633708) and extract to `vampnet/models/`.
|
| 58 |
+
|
| 59 |
+
## Generation
|
| 60 |
+
|
| 61 |
+
To run WhAM locally and prompt it in your browser:
|
| 62 |
+
|
| 63 |
+
```bash
|
| 64 |
+
python vampnet/app.py --args.load conf/interface.yml --Interface.device cuda
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
This will provide you with a Gradio link to test WhAM on inputs of your choice.
|
| 68 |
+
|
| 69 |
+
## Training Data
|
| 70 |
+
|
| 71 |
+

|
| 72 |
+
|
| 73 |
+
You only need to follow these steps to fine-tune your own version of WhAM. First, obtain the original VampNet weights by following the instructions in the [VampNet README](https://github.com/hugofloresgarcia/vampnet). Download
|
| 74 |
+
c2f.pth and codec.pth and replace the weights you previously downloaded in `vampnet/models`.
|
| 75 |
+
|
| 76 |
+
Second, obtain data:
|
| 77 |
+
|
| 78 |
+
1. **Domain adaptation data:**
|
| 79 |
+
|
| 80 |
+
- Download audio samples from the [WMMS 'Best Of' Cut](https://whoicf2.whoi.edu/science/B/whalesounds/index.cfm). Save them under `vampnet/training_data/domain_adaptation`.
|
| 81 |
+
|
| 82 |
+
- Download audio samples from the [BirdSet Dataset](https://huggingface.co/datasets/DBD-research-group/BirdSet). Save these under the same directory
|
| 83 |
+
|
| 84 |
+
- Finally, download all samples from the [AudioSet Dataset](https://research.google.com/audioset/ontology/index.html) with the label `Animal` and once again save these into the directory
|
| 85 |
+
|
| 86 |
+
2. **Species-specific finetuning:** Finetuning can be performed on the openly available **[Dominica Sperm Whale Project (DSWP)](https://huggingface.co/datasets/orrp/DSWP)** dataset, available on Hugging Face.
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
With data in hand, navigate into `vampnet` and perform Domain Adaptation:
|
| 90 |
+
```bash
|
| 91 |
+
python vampnet/scripts/exp/fine_tune.py "training_data/domain_adaptation" domain_adapted && python vampnet/scripts/exp/train.py --args.load conf/generated/domain_adapted/coarse.yml && python vampnet/scripts/exp/train.py --args.load conf/generated/domain_adapted/c2f.yml
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
Then fine-tune the domain-adapted model. Create the config file with the command:
|
| 95 |
+
|
| 96 |
+
```bash
|
| 97 |
+
python vampnet/scripts/exp/fine_tune.py "training_data/species_specific_finetuning" fine-tuned
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
To select which weights you want to use as a checkpoint, change `fine_tune_checkpoint` in `conf/generated/fine-tuned/[c2f/coarse].yml` to `./runs/domain_adaptation/[coarse/c2f]/[checkpoint]/vampnets/weights.pth`. `[checkpoint]` can be `latest` in order to use the last saved checkpoint from the previous run, though it is recommended to manually verify the quality of generations over various checkpoints as overtraining can often cause degradation in audio quality, especially with smaller datasets. After making that change, run the command:
|
| 101 |
+
|
| 102 |
+
```bash
|
| 103 |
+
python vampnet/scripts/exp/train.py --args.load conf/generated/fine-tuned/coarse.yml && python vampnet/scripts/exp/train.py --args.load conf/generated/fine-tuned/c2f.yml
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
After following these steps, you should be able to generate audio via the browser by running:
|
| 107 |
+
```bash
|
| 108 |
+
python app.py --args.load vampnet/conf/generated/fine-tuned/interface.yml
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
**Note**: The coarse and fine weights can be trained separately if compute allows. In this case, you would call the two scripts:
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
python vampnet/scripts/exp/train.py --args.load conf/generated/[fine-tuned/domain_adapted]/coarse.yml
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
```bash
|
| 118 |
+
python vampnet/scripts/exp/train.py --args.load conf/generated/[fine-tuned/domain_adapted]/c2f.yml
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
After both are finished running, ensure that both resulting weights are copied into the same copy of WhAM.
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
## Testing Data
|
| 126 |
+
|
| 127 |
+
1. **Marine Mammal Data:**
|
| 128 |
+
Download audio samples from the [WMMS 'Best Of' Cut](https://whoicf2.whoi.edu/science/B/whalesounds/index.cfm). Save them under `data/testing_data/marine_mammals/data/[SPECIES_NAME]`.
|
| 129 |
+
* `[SPECIES_NAME]` must match the species names found in `wham/generation/prompt_configs.py`.
|
| 130 |
+
|
| 131 |
+
2. **Sperm Whale Codas:**
|
| 132 |
+
To evaluate on sperm whale codas, you can use the openly available [DSWP](https://huggingface.co/datasets/orrp/DSWP) dataset.
|
| 133 |
+
|
| 134 |
+
3. Generate artificial beeps for the experiments by running `data/generate_beeps.sh`.
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
## Reproducing Paper Results
|
| 138 |
+
Note: Access to the DSWP+CETI annotated dataset is required to reproduce all results; as of time of publication, only part of this data is publicly available. Still, we include the following code as it may be useful for researchers who may benefit from our evaluation pipeline.
|
| 139 |
+
|
| 140 |
+
### 1. Downstream Classification Tasks
|
| 141 |
+
To reproduce **Table 1** (Classification Accuracies) and **Figure 7** (Ablation Study):
|
| 142 |
+
|
| 143 |
+
**Table 1 Results:**
|
| 144 |
+
```bash
|
| 145 |
+
cd wham/embedding
|
| 146 |
+
./downstream_tasks.sh
|
| 147 |
+
```
|
| 148 |
+
* Runs all downstream classification tasks.
|
| 149 |
+
* **Baselines:** Run once.
|
| 150 |
+
* **Models (AVES, VampNet):** Run over 3 random seeds; reports mean and standard deviation.
|
| 151 |
+
|
| 152 |
+
**Figure 7 Results (Ablation):**
|
| 153 |
+
```bash
|
| 154 |
+
cd wham/embedding
|
| 155 |
+
./downstream_ablation.sh
|
| 156 |
+
```
|
| 157 |
+
* Outputs accuracy scores for ablation variants (averaged across 3 seeds with error bars).
|
| 158 |
+
|
| 159 |
+
### 2. Generative Metrics
|
| 160 |
+
|
| 161 |
+
**Figure 12: Frechet Audio Distance (FAD) Scores**
|
| 162 |
+
Calculate the distance between WhAM's generated results and real codas:
|
| 163 |
+
```bash
|
| 164 |
+
# Calculate for all species
|
| 165 |
+
bash wham/generation/eval/calculate_FAD.sh
|
| 166 |
+
|
| 167 |
+
# Calculate for a single species
|
| 168 |
+
bash wham/generation/eval/calculate_FAD.sh [species_name]
|
| 169 |
+
```
|
| 170 |
+
* *Runtime:* ~3 hours on an NVIDIA A10 GPU.
|
| 171 |
+
|
| 172 |
+
**Figure 3: FAD with Custom/BirdNET Embeddings**
|
| 173 |
+
To compare against other embeddings:
|
| 174 |
+
1. Convert your `.wav` files to `.npy` embeddings.
|
| 175 |
+
2. Place raw coda embeddings in: `data/testing_data/coda_embeddings`
|
| 176 |
+
3. Place comparison embeddings in subfolders within: `data/testing_data/comparison_embeddings`
|
| 177 |
+
4. Run:
|
| 178 |
+
```bash
|
| 179 |
+
python wham/generation/eval/calculate_custom_fad.py
|
| 180 |
+
```
|
| 181 |
+
*For BirdNET embeddings, refer to the [official repo](https://github.com/BirdNET-Team/BirdNET-Analyzer).*
|
| 182 |
+
|
| 183 |
+
**Table 2: Embedding Type Ablation**
|
| 184 |
+
Calculate distances between raw codas, denoised versions, and noise profiles:
|
| 185 |
+
```bash
|
| 186 |
+
bash wham/generation/eval/FAD_ablation.sh
|
| 187 |
+
```
|
| 188 |
+
* *Prerequisites:* Ensure `data/testing_data/ablation/noise` and `data/testing_data/ablation/denoised` are populated.
|
| 189 |
+
* *Runtime:* ~1.5 hours on an NVIDIA A10 GPU.
|
| 190 |
+
|
| 191 |
+
**Figure 13: Tokenizer Reconstruction**
|
| 192 |
+
Test the mean squared reconstruction error:
|
| 193 |
+
```bash
|
| 194 |
+
bash wham/generation/eval/evaluate_tokenizer.sh
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
---
|
| 198 |
+
|
| 199 |
+
## Citation
|
| 200 |
+
|
| 201 |
+
Please use the following citation if you use this code, model or data.
|
| 202 |
+
|
| 203 |
+
```bibtex
|
| 204 |
+
@inproceedings{wham2025,
|
| 205 |
+
title={Towards A Translative Model of Sperm Whale Vocalization},
|
| 206 |
+
author={Orr Paradise and Pranav Muralikrishnan and Liangyuan Chen and Hugo Flores Garcia and Bryan Pardo and Roee Diamant and David F. Gruber and Shane Gero and Shafi Goldwasser},
|
| 207 |
+
booktitle={Advances in Neural Information Processing Systems 39: Annual Conference
|
| 208 |
+
on Neural Information Processing Systems 2025, NeurIPS 2025, San Diego, CA, USA},
|
| 209 |
+
year={2025}
|
| 210 |
+
}
|
| 211 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face Spaces entry point for WhAM.

Downloads the model weights from the Hub (if they are not already on
disk), then delegates to the vampnet Gradio app by executing
``vampnet/app.py`` with a synthesized command line.
"""
import os
import sys

from huggingface_hub import hf_hub_download

# Absolute paths so the script behaves the same regardless of the CWD
# it was launched from.
REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(REPO_ROOT, "vampnet", "models")
os.makedirs(MODEL_DIR, exist_ok=True)

MODEL_REPO = "anvitax/wham-weights"
WEIGHT_FILES = ["coarse.pth", "c2f.pth", "codec.pth", "wavebeat.pth"]

# Fetch each weight file only when it is not already present.
for fname in WEIGHT_FILES:
    target = os.path.join(MODEL_DIR, fname)
    if os.path.exists(target):
        print(f"Found {fname}")
    else:
        print(f"Downloading {fname} from {MODEL_REPO}...")
        hf_hub_download(repo_id=MODEL_REPO, filename=fname, local_dir=MODEL_DIR)

# The inner vampnet app expects to be imported and run from its own
# directory (it loads conf/ files by relative path).
sys.path.insert(0, os.path.join(REPO_ROOT, "vampnet"))
os.chdir(os.path.join(REPO_ROOT, "vampnet"))

# On ZeroGPU Spaces the `spaces` package is importable and the GPU is
# attached lazily per request, so the interface is loaded on CPU there;
# elsewhere we use CUDA when available.
try:
    import spaces  # noqa: F401  -- presence check only
    device = "cpu"
except ImportError:
    import torch
    device = "cuda" if torch.cuda.is_available() else "cpu"

# vampnet/app.py parses sys.argv (via argbind); fake the CLI it expects.
sys.argv = [
    "app.py",
    "--args.load", "conf/interface.yml",
    "--Interface.device", device,
]

# Execute the inner app in this module's namespace. Using a context
# manager closes the file handle (the original `exec(open(...).read())`
# leaked it), and compile() with an explicit filename makes tracebacks
# point at the real file instead of "<string>".
inner_app = os.path.join(os.getcwd(), "app.py")
with open(inner_app, encoding="utf-8") as f:
    exec(compile(f.read(), inner_app, "exec"))
|
packages.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
ffmpeg
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
gradio
|
| 3 |
+
argbind>=0.3.2
|
| 4 |
+
numpy<1.24
|
| 5 |
+
pydantic==2.10.6
|
| 6 |
+
huggingface_hub
|
| 7 |
+
loralib
|
| 8 |
+
torch_pitch_shift
|
| 9 |
+
soundfile
|
| 10 |
+
pydub
|
| 11 |
+
tqdm
|
| 12 |
+
Cython
|
| 13 |
+
wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
|
| 14 |
+
lac @ git+https://github.com/hugofloresgarcia/lac.git
|
| 15 |
+
descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
|
| 16 |
+
pyharp
|
setup.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages

# Read the long description with an explicit encoding so the build does
# not depend on the platform's default locale.
with open("README.md", encoding="utf-8") as f:
    long_description = f.read()

setup(
    name="wham",
    version="0.0.1",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/orrp/wam",
    license="MIT",
    packages=find_packages(),
    install_requires=[
        "descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git",
        "argbind",
        "pandas",
        # NOTE: "pathlib" was removed from this list. pathlib has been in
        # the standard library since Python 3.4; the PyPI package by that
        # name is an abandoned Python-2 backport that shadows the stdlib
        # module and breaks installs on modern Python.
        "pydub",
        "ffmpeg-python",
        "tqdm",
        "scikit-learn",
        "wandb",
        "gdown",  # For fetching large files from Google Drive
        "soundfile",
        "transformers",
        "torch",
        "Cython",
        "fadtk",
        "urllib3==2.0",
    ],
)
|
vampnet/.gitignore
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
# Usually these files are written by a python script from a template
|
| 32 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 33 |
+
*.manifest
|
| 34 |
+
*.spec
|
| 35 |
+
|
| 36 |
+
# Installer logs
|
| 37 |
+
pip-log.txt
|
| 38 |
+
pip-delete-this-directory.txt
|
| 39 |
+
|
| 40 |
+
# Unit test / coverage reports
|
| 41 |
+
htmlcov/
|
| 42 |
+
.tox/
|
| 43 |
+
.nox/
|
| 44 |
+
.coverage
|
| 45 |
+
.coverage.*
|
| 46 |
+
.cache
|
| 47 |
+
nosetests.xml
|
| 48 |
+
coverage.xml
|
| 49 |
+
*.cover
|
| 50 |
+
*.py,cover
|
| 51 |
+
.hypothesis/
|
| 52 |
+
.pytest_cache/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
target/
|
| 76 |
+
|
| 77 |
+
# Jupyter Notebook
|
| 78 |
+
.ipynb_checkpoints
|
| 79 |
+
|
| 80 |
+
# IPython
|
| 81 |
+
profile_default/
|
| 82 |
+
ipython_config.py
|
| 83 |
+
|
| 84 |
+
# pyenv
|
| 85 |
+
.python-version
|
| 86 |
+
|
| 87 |
+
# pipenv
|
| 88 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 89 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 90 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 91 |
+
# install all needed dependencies.
|
| 92 |
+
#Pipfile.lock
|
| 93 |
+
|
| 94 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 95 |
+
__pypackages__/
|
| 96 |
+
|
| 97 |
+
# Celery stuff
|
| 98 |
+
celerybeat-schedule
|
| 99 |
+
celerybeat.pid
|
| 100 |
+
|
| 101 |
+
# SageMath parsed files
|
| 102 |
+
*.sage.py
|
| 103 |
+
|
| 104 |
+
# Environments
|
| 105 |
+
.env
|
| 106 |
+
.venv
|
| 107 |
+
env/env.sh
|
| 108 |
+
venv/
|
| 109 |
+
env.bak/
|
| 110 |
+
venv.bak/
|
| 111 |
+
|
| 112 |
+
# Spyder project settings
|
| 113 |
+
.spyderproject
|
| 114 |
+
.spyproject
|
| 115 |
+
|
| 116 |
+
# Rope project settings
|
| 117 |
+
.ropeproject
|
| 118 |
+
|
| 119 |
+
# mkdocs documentation
|
| 120 |
+
/site
|
| 121 |
+
|
| 122 |
+
# mypy
|
| 123 |
+
.mypy_cache/
|
| 124 |
+
.dmypy.json
|
| 125 |
+
dmypy.json
|
| 126 |
+
|
| 127 |
+
# Pyre type checker
|
| 128 |
+
.pyre/
|
| 129 |
+
|
| 130 |
+
# Files created by experiments
|
| 131 |
+
output/
|
| 132 |
+
snapshot/
|
| 133 |
+
*.m4a
|
| 134 |
+
notebooks/scratch.ipynb
|
| 135 |
+
notebooks/inspect.ipynb
|
| 136 |
+
notebooks/effects.ipynb
|
| 137 |
+
notebooks/*.ipynb
|
| 138 |
+
notebooks/*.gif
|
| 139 |
+
notebooks/*.wav
|
| 140 |
+
notebooks/*.mp4
|
| 141 |
+
*runs/
|
| 142 |
+
boards/
|
| 143 |
+
samples/
|
| 144 |
+
*.ipynb
|
| 145 |
+
|
| 146 |
+
results.json
|
| 147 |
+
metrics.csv
|
| 148 |
+
mprofile_*
|
| 149 |
+
mem.png
|
| 150 |
+
|
| 151 |
+
results/
|
| 152 |
+
mprofile*
|
| 153 |
+
*.png
|
| 154 |
+
# do not ignore the test wav file
|
| 155 |
+
!tests/audio/short_test_audio.wav
|
| 156 |
+
!tests/audio/output.wav
|
| 157 |
+
*/.DS_Store
|
| 158 |
+
.DS_Store
|
| 159 |
+
env.sh
|
| 160 |
+
_codebraid/
|
| 161 |
+
**/*.html
|
| 162 |
+
**/*.exec.md
|
| 163 |
+
flagged/
|
| 164 |
+
log.txt
|
| 165 |
+
ckpt/
|
| 166 |
+
.syncthing*
|
| 167 |
+
tests/assets/
|
| 168 |
+
archived/
|
| 169 |
+
|
| 170 |
+
scratch/
|
| 171 |
+
|
| 172 |
+
runs-archive
|
| 173 |
+
lyrebird-audiotools
|
| 174 |
+
lyrebird-audio-codec
|
| 175 |
+
samples-*/**
|
| 176 |
+
|
| 177 |
+
gradio-outputs/
|
| 178 |
+
models/
|
| 179 |
+
samples*/
|
| 180 |
+
models-all/
|
| 181 |
+
models.zip
|
| 182 |
+
.git-old
|
| 183 |
+
conf/generated/*
|
| 184 |
+
runs*/
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
gtzan.zip
|
| 188 |
+
.gtzan_emb_cache
|
| 189 |
+
runs
|
vampnet/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/asottile/reorder_python_imports
|
| 3 |
+
rev: v2.5.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: reorder-python-imports
|
| 6 |
+
- repo: https://github.com/psf/black
|
| 7 |
+
rev: 23.1.0
|
| 8 |
+
hooks:
|
| 9 |
+
- id: black
|
| 10 |
+
language_version: python3
|
| 11 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 12 |
+
rev: v4.0.1
|
| 13 |
+
hooks:
|
| 14 |
+
- id: end-of-file-fixer
|
| 15 |
+
- id: trailing-whitespace
|
vampnet/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 Hugo Flores García and Prem Seetharaman
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
vampnet/README.md
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VampNet
|
| 2 |
+
|
| 3 |
+
This repository contains recipes for training generative music models on top of the Descript Audio Codec.
|
| 4 |
+
|
| 5 |
+
## try `unloop`
|
| 6 |
+
you can try vampnet in a co-creative looper called unloop. see this link: https://github.com/hugofloresgarcia/unloop
|
| 7 |
+
|
| 8 |
+
# Setting up
|
| 9 |
+
|
| 10 |
+
**Requires Python 3.9**.
|
| 11 |
+
|
| 12 |
+
you'll need a Python 3.9 environment to run VampNet. This is due to a [known issue with madmom](https://github.com/hugofloresgarcia/vampnet/issues/15).
|
| 13 |
+
|
| 14 |
+
(for example, using conda)
|
| 15 |
+
```bash
|
| 16 |
+
conda create -n vampnet python=3.9
|
| 17 |
+
conda activate vampnet
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
install VampNet
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
git clone https://github.com/hugofloresgarcia/vampnet.git
|
| 25 |
+
pip install -e ./vampnet
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## A note on argbind
|
| 29 |
+
This repository relies on [argbind](https://github.com/pseeth/argbind) to manage CLIs and config files.
|
| 30 |
+
Config files are stored in the `conf/` folder.
|
| 31 |
+
|
| 32 |
+
## Getting the Pretrained Models
|
| 33 |
+
|
| 34 |
+
### Licensing for Pretrained Models:
|
| 35 |
+
The weights for the models are licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml). Likewise, any VampNet models fine-tuned on the pretrained models are also licensed [`CC BY-NC-SA 4.0`](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.ml).
|
| 36 |
+
|
| 37 |
+
Download the pretrained models from [this link](https://zenodo.org/record/8136629). Then, extract the models to the `models/` folder.
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Usage
|
| 41 |
+
|
| 42 |
+
## Launching the Gradio Interface
|
| 43 |
+
You can launch a gradio UI to play with vampnet.
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
python app.py --args.load conf/interface.yml --Interface.device cuda
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
# Training / Fine-tuning
|
| 50 |
+
|
| 51 |
+
## Training a model
|
| 52 |
+
|
| 53 |
+
To train a model, run the following script:
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
python scripts/exp/train.py --args.load conf/vampnet.yml --save_path /path/to/checkpoints
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
You can edit `conf/vampnet.yml` to change the dataset paths or any training hyperparameters.
|
| 60 |
+
|
| 61 |
+
For coarse2fine models, you can use `conf/c2f.yml` as a starting configuration.
|
| 62 |
+
|
| 63 |
+
See `python scripts/exp/train.py -h` for a list of options.
|
| 64 |
+
|
| 65 |
+
## Fine-tuning
|
| 66 |
+
To fine-tune a model, use the script in `scripts/exp/fine_tune.py` to generate 3 configuration files: `c2f.yml`, `coarse.yml`, and `interface.yml`.
|
| 67 |
+
The first two are used to fine-tune the coarse and fine models, respectively. The last one is used to launch the gradio interface.
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
python scripts/exp/fine_tune.py "/path/to/audio1.mp3 /path/to/audio2/ /path/to/audio3.wav" <fine_tune_name>
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
This will create a folder under `conf/<fine_tune_name>/` with the 3 configuration files.
|
| 74 |
+
|
| 75 |
+
The save_paths will be set to `runs/<fine_tune_name>/coarse` and `runs/<fine_tune_name>/c2f`.
|
| 76 |
+
|
| 77 |
+
launch the coarse job:
|
| 78 |
+
```bash
|
| 79 |
+
python scripts/exp/train.py --args.load conf/<fine_tune_name>/coarse.yml
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
this will save the coarse model to `runs/<fine_tune_name>/coarse/ckpt/best/`.
|
| 83 |
+
|
| 84 |
+
launch the c2f job:
|
| 85 |
+
```bash
|
| 86 |
+
python scripts/exp/train.py --args.load conf/<fine_tune_name>/c2f.yml
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
launch the interface:
|
| 90 |
+
```bash
|
| 91 |
+
python app.py --args.load conf/generated/<fine_tune_name>/interface.yml
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
|
vampnet/app.py
ADDED
|
@@ -0,0 +1,677 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
import yaml
|
| 4 |
+
import tempfile
|
| 5 |
+
import uuid
|
| 6 |
+
from dataclasses import dataclass, asdict
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import audiotools as at
|
| 10 |
+
import argbind
|
| 11 |
+
|
| 12 |
+
import gradio as gr
|
| 13 |
+
from vampnet.interface import Interface
|
| 14 |
+
from vampnet import mask as pmask
|
| 15 |
+
|
| 16 |
+
# Detect whether we are running on a Hugging Face ZeroGPU Space: the
# `spaces` package is only importable in that environment.
try:
    import spaces
except ImportError:
    ZERO_GPU = False
else:
    ZERO_GPU = True


def gpu(fn):
    """Decorator: wrap *fn* with ``spaces.GPU`` on ZeroGPU hosts, else return it unchanged."""
    return spaces.GPU(fn) if ZERO_GPU else fn
|
| 26 |
+
|
| 27 |
+
# Rebind Interface so its constructor arguments can be supplied from the
# CLI / a yaml config (e.g. `--args.load conf/interface.yml`).
Interface = argbind.bind(Interface)
# AudioLoader = argbind.bind(at.data.datasets.AudioLoader)

# Parse CLI args once at import time; consumed below via argbind.scope.
conf = argbind.parse_args()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
from torch_pitch_shift import pitch_shift, get_fast_shifts
|
| 34 |
+
def shift_pitch(signal, interval: int):
    """Pitch-shift *signal* in place by *interval* semitones and return it.

    Delegates to ``torch_pitch_shift.pitch_shift``; the signal's sample
    buffer is replaced, so the caller's object is mutated as well.
    """
    shifted = pitch_shift(
        signal.samples,
        shift=interval,
        sample_rate=signal.sample_rate,
    )
    signal.samples = shifted
    return signal
|
| 41 |
+
|
| 42 |
+
def load_interface():
    # Build the VampNet Interface inside the argbind scope so values parsed
    # from the CLI / config file (module-level `conf`) reach the constructor.
    with argbind.scope(conf):
        interface = Interface()
        # loader = AudioLoader()
        print(f"interface device is {interface.device}")
        return interface
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Single module-level model interface shared by every handler below.
interface = load_interface()


# Root directory for everything the app writes (inputs, outputs, saves).
OUT_DIR = Path("gradio-outputs")
OUT_DIR.mkdir(exist_ok=True, parents=True)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def load_audio(file):
    """Excerpt, preprocess and stage an uploaded audio file; return the staged path."""
    print(file)
    # Take a salient excerpt no longer than the coarse model's chunk size.
    excerpt = at.AudioSignal.salient_excerpt(
        file.name, duration=interface.coarse.chunk_size_s
    )
    excerpt = interface.preprocess(excerpt)

    # Stage the preprocessed input under a unique temporary directory.
    tmp_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
    tmp_dir.mkdir(parents=True, exist_ok=True)
    excerpt.write(tmp_dir / "input.wav")
    return excerpt.path_to_file
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def load_example_audio():
    """Return the path of the bundled example clip."""
    example_path = "./assets/example.wav"
    return example_path
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@gpu
def _vamp(data, return_mask=False):
    """Run one vamp (masked-token resynthesis) pass over the input audio.

    `data` is the Gradio inputs dict, keyed by the component objects
    declared at module level (input_audio, num_steps, ...).  Writes the
    result under a fresh OUT_DIR subfolder and returns the output file
    path, plus the sonified-mask path when `return_mask` is True.
    """
    interface.to("cuda")

    out_dir = OUT_DIR / str(uuid.uuid4())
    out_dir.mkdir()
    sig = at.AudioSignal(data[input_audio])
    sig = interface.preprocess(sig)

    # Remember input loudness so the output can be matched to it below.
    loudness = sig.loudness()
    print(f"input loudness is {loudness}")

    if data[pitch_shift_amt] != 0:
        sig = shift_pitch(sig, data[pitch_shift_amt])

    z = interface.encode(sig)

    ncc = data[n_conditioning_codebooks]

    # Build the mask: random base, AND'd with inpaint/periodic hints,
    # optionally OR'd with onset hints and AND'd with beat hints.
    mask = pmask.linear_random(z, data[rand_mask_intensity])
    mask = pmask.mask_and(
        mask, pmask.inpaint(
            z,
            interface.s2t(data[prefix_s]),
            interface.s2t(data[suffix_s])
        )
    )
    mask = pmask.mask_and(
        mask, pmask.periodic_mask(
            z,
            data[periodic_p],
            data[periodic_w],
            random_roll=True
        )
    )
    if data[onset_mask_width] > 0:
        mask = pmask.mask_or(
            mask, pmask.onset_mask(sig, z, interface, width=data[onset_mask_width])
        )
    if data[beat_mask_width] > 0:
        beat_mask = interface.make_beat_mask(
            sig,
            after_beat_s=(data[beat_mask_width] / 1000),
            mask_upbeats=not data[beat_mask_downbeats],
        )
        mask = pmask.mask_and(mask, beat_mask)

    # these should be the last two mask ops
    mask = pmask.dropout(mask, data[dropout])
    mask = pmask.codebook_unmask(mask, ncc)
    mask = pmask.codebook_mask(mask, int(data[n_mask_codebooks]))

    print(f"dropout {data[dropout]}")
    print(f"masktemp {data[masktemp]}")
    print(f"sampletemp {data[sampletemp]}")
    print(f"top_p {data[top_p]}")
    print(f"prefix_s {data[prefix_s]}")
    print(f"suffix_s {data[suffix_s]}")
    print(f"rand_mask_intensity {data[rand_mask_intensity]}")
    print(f"num_steps {data[num_steps]}")
    print(f"periodic_p {data[periodic_p]}")
    print(f"periodic_w {data[periodic_w]}")
    print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
    print(f"use_coarse2fine {data[use_coarse2fine]}")
    print(f"onset_mask_width {data[onset_mask_width]}")
    print(f"beat_mask_width {data[beat_mask_width]}")
    print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
    print(f"stretch_factor {data[stretch_factor]}")
    print(f"seed {data[seed]}")
    print(f"pitch_shift_amt {data[pitch_shift_amt]}")
    print(f"sample_cutoff {data[sample_cutoff]}")

    # 0 means "off" for both top_p and seed.
    _top_p = data[top_p] if data[top_p] > 0 else None
    # save the mask as a txt file
    np.savetxt(out_dir / "mask.txt", mask[:, 0, :].long().cpu().numpy())

    _seed = data[seed] if data[seed] > 0 else None
    zv, mask_z = interface.coarse_vamp(
        z,
        mask=mask,
        sampling_steps=data[num_steps],
        mask_temperature=data[masktemp] * 10,
        sampling_temperature=data[sampletemp],
        return_mask=True,
        typical_filtering=data[typical_filtering],
        typical_mass=data[typical_mass],
        typical_min_tokens=data[typical_min_tokens],
        top_p=_top_p,
        gen_fn=interface.coarse.generate,
        seed=_seed,
        sample_cutoff=data[sample_cutoff],
    )

    # BUG FIX: the original tested `if use_coarse2fine:` -- the truthiness
    # of the gradio Checkbox *component*, which is always True -- instead
    # of its submitted value, so the checkbox could never disable c2f.
    if data[use_coarse2fine]:
        zv = interface.coarse_to_fine(
            zv,
            mask_temperature=data[masktemp] * 10,
            sampling_temperature=data[sampletemp],
            mask=mask,
            sampling_steps=data[num_steps],
            sample_cutoff=data[sample_cutoff],
            seed=_seed,
        )

    sig = interface.to_signal(zv).cpu()
    print("done")

    # Match output loudness to the input's.
    print(f"output loudness is {sig.loudness()}")
    sig = sig.normalize(loudness)
    print(f"normalized loudness is {sig.loudness()}")

    sig.write(out_dir / "output.wav")

    if return_mask:
        mask = interface.to_signal(mask_z).cpu()
        mask.write(out_dir / "mask.wav")
        return sig.path_to_file, mask.path_to_file
    else:
        return sig.path_to_file
|
| 201 |
+
|
| 202 |
+
def vamp(data):
    # UI entry point: run a vamp and also return the sonified mask audio.
    return _vamp(data, return_mask=True)
|
| 204 |
+
|
| 205 |
+
def api_vamp(data):
    # API entry point: run a vamp without returning the mask audio.
    return _vamp(data, return_mask=False)
|
| 207 |
+
|
| 208 |
+
def save_vamp(data):
    # Persist the current input/output pair plus the generation settings to
    # a uniquely-named folder, then zip that folder for download.
    # Returns (status message with the save code, path to the zip).
    out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
    out_dir.mkdir(parents=True, exist_ok=True)

    sig_in = at.AudioSignal(data[input_audio])
    sig_out = at.AudioSignal(data[output_audio])

    sig_in.write(out_dir / "input.wav")
    sig_out.write(out_dir / "output.wav")

    # Snapshot of the UI settings, re-keyed with plain strings so the
    # yaml file is human readable.
    _data = {
        "masktemp": data[masktemp],
        "sampletemp": data[sampletemp],
        "top_p": data[top_p],
        "prefix_s": data[prefix_s],
        "suffix_s": data[suffix_s],
        "rand_mask_intensity": data[rand_mask_intensity],
        "num_steps": data[num_steps],
        "notes": data[notes_text],
        "periodic_period": data[periodic_p],
        "periodic_width": data[periodic_w],
        "n_conditioning_codebooks": data[n_conditioning_codebooks],
        "use_coarse2fine": data[use_coarse2fine],
        "stretch_factor": data[stretch_factor],
        "seed": data[seed],
        "samplecutoff": data[sample_cutoff],
    }

    # save with yaml
    with open(out_dir / "data.yaml", "w") as f:
        yaml.dump(_data, f)

    # Bundle the folder (input, output, settings) into <uuid>.zip.
    import zipfile
    zip_path = str(out_dir.with_suffix(".zip"))
    with zipfile.ZipFile(zip_path, "w") as zf:
        for file in out_dir.iterdir():
            zf.write(file, file.name)

    return f"saved! your save code is {out_dir.stem}", zip_path
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
@gpu
def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
    # Simplified vamp entry point (for the HARP integration): fully random
    # mask, optional beat-aligned hints, fixed coarse->fine pipeline.
    interface.to("cuda")

    out_dir = OUT_DIR / str(uuid.uuid4())
    out_dir.mkdir()
    sig = at.AudioSignal(_input_audio)
    sig = interface.preprocess(sig)

    z = interface.encode(sig)

    # build the mask: mask everything, then carve out beat-aligned hints
    mask = pmask.linear_random(z, 1.0)
    if _beat_mask_width > 0:
        # keep hints for `_beat_mask_width` milliseconds after each beat
        beat_mask = interface.make_beat_mask(
            sig,
            after_beat_s=(_beat_mask_width/1000),
        )
        mask = pmask.mask_and(mask, beat_mask)

    # coarse pass: resample the masked coarse codebooks
    zv, mask_z = interface.coarse_vamp(
        z,
        mask=mask,
        sampling_temperature=_sampletemp,
        return_mask=True,
        gen_fn=interface.coarse.generate,
    )


    # fine pass: fill in the upper codebooks
    zv = interface.coarse_to_fine(
        zv,
        sampling_temperature=_sampletemp,
        mask=mask,
    )

    sig = interface.to_signal(zv).cpu()
    print("done")

    sig.write(out_dir / "output.wav")

    return sig.path_to_file
|
| 291 |
+
|
| 292 |
+
with gr.Blocks() as demo:
|
| 293 |
+
|
| 294 |
+
with gr.Row():
|
| 295 |
+
with gr.Column():
|
| 296 |
+
gr.Markdown("# VampNet Audio Vamping")
|
| 297 |
+
gr.Markdown("""## Description:
|
| 298 |
+
This is a demo of the VampNet, a generative audio model that transforms the input audio based on the chosen settings.
|
| 299 |
+
You can control the extent and nature of variation with a set of manual controls and presets.
|
| 300 |
+
Use this interface to experiment with different mask settings and explore the audio outputs.
|
| 301 |
+
""")
|
| 302 |
+
|
| 303 |
+
gr.Markdown("""
|
| 304 |
+
## Instructions:
|
| 305 |
+
1. You can start by uploading some audio, or by loading the example audio.
|
| 306 |
+
2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
|
| 307 |
+
3. Click the "generate (vamp)!!!" button to apply the vamp operation. Listen to the output audio.
|
| 308 |
+
4. Optionally, you can add some notes and save the result.
|
| 309 |
+
5. You can also use the output as the new input and continue experimenting!
|
| 310 |
+
""")
|
| 311 |
+
with gr.Row():
|
| 312 |
+
with gr.Column():
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
manual_audio_upload = gr.File(
|
| 316 |
+
label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
|
| 317 |
+
file_types=["audio"]
|
| 318 |
+
)
|
| 319 |
+
load_example_audio_button = gr.Button("or load example audio")
|
| 320 |
+
|
| 321 |
+
input_audio = gr.Audio(
|
| 322 |
+
label="input audio",
|
| 323 |
+
interactive=False,
|
| 324 |
+
type="filepath",
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
audio_mask = gr.Audio(
|
| 328 |
+
label="audio mask (listen to this to hear the mask hints)",
|
| 329 |
+
interactive=False,
|
| 330 |
+
type="filepath",
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
# connect widgets
|
| 334 |
+
load_example_audio_button.click(
|
| 335 |
+
fn=load_example_audio,
|
| 336 |
+
inputs=[],
|
| 337 |
+
outputs=[ input_audio]
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
manual_audio_upload.change(
|
| 341 |
+
fn=load_audio,
|
| 342 |
+
inputs=[manual_audio_upload],
|
| 343 |
+
outputs=[ input_audio]
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
# mask settings
|
| 347 |
+
with gr.Column():
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
presets = {
|
| 351 |
+
"unconditional": {
|
| 352 |
+
"periodic_p": 0,
|
| 353 |
+
"onset_mask_width": 0,
|
| 354 |
+
"beat_mask_width": 0,
|
| 355 |
+
"beat_mask_downbeats": False,
|
| 356 |
+
},
|
| 357 |
+
"slight periodic variation": {
|
| 358 |
+
"periodic_p": 5,
|
| 359 |
+
"onset_mask_width": 5,
|
| 360 |
+
"beat_mask_width": 0,
|
| 361 |
+
"beat_mask_downbeats": False,
|
| 362 |
+
},
|
| 363 |
+
"moderate periodic variation": {
|
| 364 |
+
"periodic_p": 13,
|
| 365 |
+
"onset_mask_width": 5,
|
| 366 |
+
"beat_mask_width": 0,
|
| 367 |
+
"beat_mask_downbeats": False,
|
| 368 |
+
},
|
| 369 |
+
"strong periodic variation": {
|
| 370 |
+
"periodic_p": 17,
|
| 371 |
+
"onset_mask_width": 5,
|
| 372 |
+
"beat_mask_width": 0,
|
| 373 |
+
"beat_mask_downbeats": False,
|
| 374 |
+
},
|
| 375 |
+
"very strong periodic variation": {
|
| 376 |
+
"periodic_p": 21,
|
| 377 |
+
"onset_mask_width": 5,
|
| 378 |
+
"beat_mask_width": 0,
|
| 379 |
+
"beat_mask_downbeats": False,
|
| 380 |
+
},
|
| 381 |
+
"beat-driven variation": {
|
| 382 |
+
"periodic_p": 0,
|
| 383 |
+
"onset_mask_width": 0,
|
| 384 |
+
"beat_mask_width": 50,
|
| 385 |
+
"beat_mask_downbeats": False,
|
| 386 |
+
},
|
| 387 |
+
"beat-driven variation (downbeats only)": {
|
| 388 |
+
"periodic_p": 0,
|
| 389 |
+
"onset_mask_width": 0,
|
| 390 |
+
"beat_mask_width": 50,
|
| 391 |
+
"beat_mask_downbeats": True,
|
| 392 |
+
},
|
| 393 |
+
"beat-driven variation (downbeats only, strong)": {
|
| 394 |
+
"periodic_p": 0,
|
| 395 |
+
"onset_mask_width": 0,
|
| 396 |
+
"beat_mask_width": 20,
|
| 397 |
+
"beat_mask_downbeats": True,
|
| 398 |
+
},
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
preset = gr.Dropdown(
|
| 402 |
+
label="preset",
|
| 403 |
+
choices=list(presets.keys()),
|
| 404 |
+
value="strong periodic variation",
|
| 405 |
+
)
|
| 406 |
+
load_preset_button = gr.Button("load_preset")
|
| 407 |
+
|
| 408 |
+
with gr.Accordion("manual controls", open=True):
|
| 409 |
+
periodic_p = gr.Slider(
|
| 410 |
+
label="periodic prompt (0 - unconditional, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
|
| 411 |
+
minimum=0,
|
| 412 |
+
maximum=128,
|
| 413 |
+
step=1,
|
| 414 |
+
value=3,
|
| 415 |
+
)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
onset_mask_width = gr.Slider(
|
| 419 |
+
label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
|
| 420 |
+
minimum=0,
|
| 421 |
+
maximum=100,
|
| 422 |
+
step=1,
|
| 423 |
+
value=5,
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
beat_mask_width = gr.Slider(
|
| 427 |
+
label="beat prompt (ms)",
|
| 428 |
+
minimum=0,
|
| 429 |
+
maximum=200,
|
| 430 |
+
value=0,
|
| 431 |
+
)
|
| 432 |
+
beat_mask_downbeats = gr.Checkbox(
|
| 433 |
+
label="beat mask downbeats only?",
|
| 434 |
+
value=False
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
n_mask_codebooks = gr.Number(
|
| 438 |
+
label="first upper codebook level to mask",
|
| 439 |
+
value=9,
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
with gr.Accordion("extras ", open=False):
|
| 444 |
+
pitch_shift_amt = gr.Slider(
|
| 445 |
+
label="pitch shift amount (semitones)",
|
| 446 |
+
minimum=-12,
|
| 447 |
+
maximum=12,
|
| 448 |
+
step=1,
|
| 449 |
+
value=0,
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
rand_mask_intensity = gr.Slider(
|
| 453 |
+
label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
|
| 454 |
+
minimum=0.0,
|
| 455 |
+
maximum=1.0,
|
| 456 |
+
value=1.0
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
periodic_w = gr.Slider(
|
| 460 |
+
label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
|
| 461 |
+
minimum=1,
|
| 462 |
+
maximum=20,
|
| 463 |
+
step=1,
|
| 464 |
+
value=1,
|
| 465 |
+
)
|
| 466 |
+
n_conditioning_codebooks = gr.Number(
|
| 467 |
+
label="number of conditioning codebooks. probably 0",
|
| 468 |
+
value=0,
|
| 469 |
+
precision=0,
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
stretch_factor = gr.Slider(
|
| 473 |
+
label="time stretch factor",
|
| 474 |
+
minimum=0,
|
| 475 |
+
maximum=64,
|
| 476 |
+
step=1,
|
| 477 |
+
value=1,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
preset_outputs = {
|
| 481 |
+
periodic_p,
|
| 482 |
+
onset_mask_width,
|
| 483 |
+
beat_mask_width,
|
| 484 |
+
beat_mask_downbeats,
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
def load_preset(_preset):
|
| 488 |
+
return tuple(presets[_preset].values())
|
| 489 |
+
|
| 490 |
+
load_preset_button.click(
|
| 491 |
+
fn=load_preset,
|
| 492 |
+
inputs=[preset],
|
| 493 |
+
outputs=preset_outputs
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
with gr.Accordion("prefix/suffix prompts", open=False):
|
| 498 |
+
prefix_s = gr.Slider(
|
| 499 |
+
label="prefix hint length (seconds)",
|
| 500 |
+
minimum=0.0,
|
| 501 |
+
maximum=10.0,
|
| 502 |
+
value=0.0
|
| 503 |
+
)
|
| 504 |
+
suffix_s = gr.Slider(
|
| 505 |
+
label="suffix hint length (seconds)",
|
| 506 |
+
minimum=0.0,
|
| 507 |
+
maximum=10.0,
|
| 508 |
+
value=0.0
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
masktemp = gr.Slider(
|
| 512 |
+
label="mask temperature",
|
| 513 |
+
minimum=0.0,
|
| 514 |
+
maximum=100.0,
|
| 515 |
+
value=1.5
|
| 516 |
+
)
|
| 517 |
+
sampletemp = gr.Slider(
|
| 518 |
+
label="sample temperature",
|
| 519 |
+
minimum=0.1,
|
| 520 |
+
maximum=10.0,
|
| 521 |
+
value=1.0,
|
| 522 |
+
step=0.001
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
with gr.Accordion("sampling settings", open=False):
|
| 528 |
+
top_p = gr.Slider(
|
| 529 |
+
label="top p (0.0 = off)",
|
| 530 |
+
minimum=0.0,
|
| 531 |
+
maximum=1.0,
|
| 532 |
+
value=0.0
|
| 533 |
+
)
|
| 534 |
+
typical_filtering = gr.Checkbox(
|
| 535 |
+
label="typical filtering ",
|
| 536 |
+
value=False
|
| 537 |
+
)
|
| 538 |
+
typical_mass = gr.Slider(
|
| 539 |
+
label="typical mass (should probably stay between 0.1 and 0.5)",
|
| 540 |
+
minimum=0.01,
|
| 541 |
+
maximum=0.99,
|
| 542 |
+
value=0.15
|
| 543 |
+
)
|
| 544 |
+
typical_min_tokens = gr.Slider(
|
| 545 |
+
label="typical min tokens (should probably stay between 1 and 256)",
|
| 546 |
+
minimum=1,
|
| 547 |
+
maximum=256,
|
| 548 |
+
step=1,
|
| 549 |
+
value=64
|
| 550 |
+
)
|
| 551 |
+
sample_cutoff = gr.Slider(
|
| 552 |
+
label="sample cutoff",
|
| 553 |
+
minimum=0.0,
|
| 554 |
+
maximum=1.0,
|
| 555 |
+
value=0.5,
|
| 556 |
+
step=0.01
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
use_coarse2fine = gr.Checkbox(
|
| 560 |
+
label="use coarse2fine",
|
| 561 |
+
value=True,
|
| 562 |
+
visible=False
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
num_steps = gr.Slider(
|
| 566 |
+
label="number of steps (should normally be between 12 and 36)",
|
| 567 |
+
minimum=1,
|
| 568 |
+
maximum=128,
|
| 569 |
+
step=1,
|
| 570 |
+
value=36
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
dropout = gr.Slider(
|
| 574 |
+
label="mask dropout",
|
| 575 |
+
minimum=0.0,
|
| 576 |
+
maximum=1.0,
|
| 577 |
+
step=0.01,
|
| 578 |
+
value=0.0
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
seed = gr.Number(
|
| 583 |
+
label="seed (0 for random)",
|
| 584 |
+
value=0,
|
| 585 |
+
precision=0,
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
# mask settings
|
| 591 |
+
with gr.Column():
|
| 592 |
+
|
| 593 |
+
# lora_choice = gr.Dropdown(
|
| 594 |
+
# label="lora choice",
|
| 595 |
+
# choices=list(loras.keys()),
|
| 596 |
+
# value=LORA_NONE,
|
| 597 |
+
# visible=False
|
| 598 |
+
# )
|
| 599 |
+
|
| 600 |
+
vamp_button = gr.Button("generate (vamp)!!!")
|
| 601 |
+
output_audio = gr.Audio(
|
| 602 |
+
label="output audio",
|
| 603 |
+
interactive=False,
|
| 604 |
+
type="filepath"
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
notes_text = gr.Textbox(
|
| 608 |
+
label="type any notes about the generated audio here",
|
| 609 |
+
value="",
|
| 610 |
+
interactive=True
|
| 611 |
+
)
|
| 612 |
+
save_button = gr.Button("save vamp")
|
| 613 |
+
download_file = gr.File(
|
| 614 |
+
label="vamp to download will appear here",
|
| 615 |
+
interactive=False
|
| 616 |
+
)
|
| 617 |
+
use_as_input_button = gr.Button("use output as input")
|
| 618 |
+
|
| 619 |
+
thank_you = gr.Markdown("")
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
_inputs = {
|
| 623 |
+
input_audio,
|
| 624 |
+
num_steps,
|
| 625 |
+
masktemp,
|
| 626 |
+
sampletemp,
|
| 627 |
+
top_p,
|
| 628 |
+
prefix_s, suffix_s,
|
| 629 |
+
rand_mask_intensity,
|
| 630 |
+
periodic_p, periodic_w,
|
| 631 |
+
n_conditioning_codebooks,
|
| 632 |
+
dropout,
|
| 633 |
+
use_coarse2fine,
|
| 634 |
+
stretch_factor,
|
| 635 |
+
onset_mask_width,
|
| 636 |
+
typical_filtering,
|
| 637 |
+
typical_mass,
|
| 638 |
+
typical_min_tokens,
|
| 639 |
+
beat_mask_width,
|
| 640 |
+
beat_mask_downbeats,
|
| 641 |
+
seed,
|
| 642 |
+
# lora_choice,
|
| 643 |
+
n_mask_codebooks,
|
| 644 |
+
pitch_shift_amt,
|
| 645 |
+
sample_cutoff
|
| 646 |
+
}
|
| 647 |
+
|
| 648 |
+
# connect widgets
|
| 649 |
+
vamp_button.click(
|
| 650 |
+
fn=vamp,
|
| 651 |
+
inputs=_inputs,
|
| 652 |
+
outputs=[output_audio, audio_mask],
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
api_vamp_button = gr.Button("api vamp", visible=False)
|
| 656 |
+
api_vamp_button.click(
|
| 657 |
+
fn=api_vamp,
|
| 658 |
+
inputs=_inputs,
|
| 659 |
+
outputs=[output_audio],
|
| 660 |
+
api_name="vamp"
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
use_as_input_button.click(
|
| 664 |
+
fn=lambda x: x,
|
| 665 |
+
inputs=[output_audio],
|
| 666 |
+
outputs=[input_audio]
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
save_button.click(
|
| 670 |
+
fn=save_vamp,
|
| 671 |
+
inputs=_inputs | {notes_text, output_audio},
|
| 672 |
+
outputs=[thank_you, download_file]
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
# demo.launch(share=True, enable_queue=True, debug=True)
|
| 677 |
+
demo.launch(share=True, debug=True) # from wam: because enable_queue seems to not be supported?
|
vampnet/conf/c2f.yml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
$include:
|
| 2 |
+
- conf/vampnet.yml
|
| 3 |
+
|
| 4 |
+
VampNet.n_codebooks: 14
|
| 5 |
+
VampNet.n_conditioning_codebooks: 4
|
| 6 |
+
|
| 7 |
+
VampNet.embedding_dim: 1280
|
| 8 |
+
VampNet.n_layers: 16
|
| 9 |
+
VampNet.n_heads: 20
|
| 10 |
+
|
| 11 |
+
AudioDataset.duration: 3.0
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
AudioDataset.loudness_cutoff: -40.0
|
vampnet/conf/interface.yml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Interface.coarse_ckpt: ./models/coarse.pth
|
| 2 |
+
Interface.coarse2fine_ckpt: ./models/c2f.pth
|
| 3 |
+
Interface.codec_ckpt: ./models/codec.pth
|
| 4 |
+
Interface.coarse_chunk_size_s: 10
|
| 5 |
+
Interface.coarse2fine_chunk_size_s: 3
|
| 6 |
+
Interface.wavebeat_ckpt: ./models/wavebeat.pth
|
| 7 |
+
|
| 8 |
+
# AudioLoader.sources:
|
| 9 |
+
# - /media/CHONK/null
|
| 10 |
+
|
vampnet/conf/lora/lora.yml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
$include:
|
| 2 |
+
- conf/vampnet.yml
|
| 3 |
+
|
| 4 |
+
fine_tune: True
|
| 5 |
+
|
| 6 |
+
train/AudioDataset.n_examples: 100000000
|
| 7 |
+
val/AudioDataset.n_examples: 500
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
NoamScheduler.warmup: 500
|
| 11 |
+
|
| 12 |
+
batch_size: 6
|
| 13 |
+
num_workers: 7
|
| 14 |
+
save_iters: [2000, 4000, 10000,20000, 40000, 100000]
|
| 15 |
+
sample_freq: 2000
|
| 16 |
+
val_freq: 1000
|
| 17 |
+
|
| 18 |
+
AdamW.lr: 0.0001
|
| 19 |
+
|
| 20 |
+
# let's us organize sound classes into folders and choose from those sound classes uniformly
|
| 21 |
+
AudioDataset.without_replacement: False
|
| 22 |
+
num_iters: 500000
|
vampnet/conf/vampnet.yml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
codec_ckpt: ./models/vampnet/codec.pth
|
| 3 |
+
save_path: ckpt
|
| 4 |
+
|
| 5 |
+
num_iters: 1000000000
|
| 6 |
+
save_iters: [10000, 50000, 100000, 300000, 500000]
|
| 7 |
+
val_idx: [0,1,2,3,4,5,6,7,8,9]
|
| 8 |
+
sample_freq: 10000
|
| 9 |
+
val_freq: 1000
|
| 10 |
+
|
| 11 |
+
batch_size: 8
|
| 12 |
+
num_workers: 10
|
| 13 |
+
|
| 14 |
+
# Optimization
|
| 15 |
+
amp: false
|
| 16 |
+
|
| 17 |
+
CrossEntropyLoss.label_smoothing: 0.1
|
| 18 |
+
|
| 19 |
+
AdamW.lr: 0.001
|
| 20 |
+
|
| 21 |
+
NoamScheduler.factor: 2.0
|
| 22 |
+
NoamScheduler.warmup: 10000
|
| 23 |
+
|
| 24 |
+
VampNet.vocab_size: 1024
|
| 25 |
+
VampNet.n_codebooks: 4
|
| 26 |
+
VampNet.n_conditioning_codebooks: 0
|
| 27 |
+
VampNet.r_cond_dim: 0
|
| 28 |
+
VampNet.noise_mode: mask
|
| 29 |
+
VampNet.embedding_dim: 1280
|
| 30 |
+
VampNet.n_layers: 20
|
| 31 |
+
VampNet.n_heads: 20
|
| 32 |
+
VampNet.flash_attn: false
|
| 33 |
+
VampNet.dropout: 0.1
|
| 34 |
+
|
| 35 |
+
AudioLoader.relative_path: ""
|
| 36 |
+
AudioDataset.loudness_cutoff: -30.0
|
| 37 |
+
AudioDataset.without_replacement: true
|
| 38 |
+
AudioLoader.shuffle: true
|
| 39 |
+
|
| 40 |
+
AudioDataset.duration: 10.0
|
| 41 |
+
|
| 42 |
+
train/AudioDataset.n_examples: 10000000
|
| 43 |
+
train/AudioLoader.sources:
|
| 44 |
+
- /media/CHONK/hugo/spotdl/audio-train
|
| 45 |
+
|
| 46 |
+
val/AudioDataset.n_examples: 2000
|
| 47 |
+
val/AudioLoader.sources:
|
| 48 |
+
- /media/CHONK/hugo/spotdl/audio-val
|
| 49 |
+
|
vampnet/scripts/exp/eval.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import os
|
| 3 |
+
from functools import partial
|
| 4 |
+
|
| 5 |
+
from frechet_audio_distance import FrechetAudioDistance
|
| 6 |
+
import pandas
|
| 7 |
+
import argbind
|
| 8 |
+
import torch
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
import audiotools
|
| 12 |
+
from audiotools import AudioSignal
|
| 13 |
+
|
| 14 |
+
@argbind.bind(without_prefix=True)
|
| 15 |
+
def eval(
|
| 16 |
+
exp_dir: str = None,
|
| 17 |
+
baseline_key: str = "baseline",
|
| 18 |
+
audio_ext: str = ".wav",
|
| 19 |
+
):
|
| 20 |
+
assert exp_dir is not None
|
| 21 |
+
exp_dir = Path(exp_dir)
|
| 22 |
+
assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"
|
| 23 |
+
|
| 24 |
+
# set up our metrics
|
| 25 |
+
# sisdr_loss = audiotools.metrics.distance.SISDRLoss()
|
| 26 |
+
# stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
|
| 27 |
+
mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
|
| 28 |
+
frechet = FrechetAudioDistance(
|
| 29 |
+
use_pca=False,
|
| 30 |
+
use_activation=False,
|
| 31 |
+
verbose=True,
|
| 32 |
+
audio_load_worker=4,
|
| 33 |
+
)
|
| 34 |
+
frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")
|
| 35 |
+
|
| 36 |
+
# figure out what conditions we have
|
| 37 |
+
conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
|
| 38 |
+
|
| 39 |
+
assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
|
| 40 |
+
conditions.remove(baseline_key)
|
| 41 |
+
|
| 42 |
+
print(f"Found {len(conditions)} conditions in {exp_dir}")
|
| 43 |
+
print(f"conditions: {conditions}")
|
| 44 |
+
|
| 45 |
+
baseline_dir = exp_dir / baseline_key
|
| 46 |
+
baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
|
| 47 |
+
|
| 48 |
+
metrics = []
|
| 49 |
+
for condition in tqdm(conditions):
|
| 50 |
+
cond_dir = exp_dir / condition
|
| 51 |
+
cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
|
| 52 |
+
|
| 53 |
+
print(f"computing fad for {baseline_dir} and {cond_dir}")
|
| 54 |
+
frechet_score = frechet.score(baseline_dir, cond_dir)
|
| 55 |
+
|
| 56 |
+
# make sure we have the same number of files
|
| 57 |
+
num_files = min(len(baseline_files), len(cond_files))
|
| 58 |
+
baseline_files = baseline_files[:num_files]
|
| 59 |
+
cond_files = cond_files[:num_files]
|
| 60 |
+
assert len(list(baseline_files)) == len(list(cond_files)), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}"
|
| 61 |
+
|
| 62 |
+
def process(baseline_file, cond_file):
|
| 63 |
+
# make sure the files match (same name)
|
| 64 |
+
assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
|
| 65 |
+
|
| 66 |
+
# load the files
|
| 67 |
+
baseline_sig = AudioSignal(str(baseline_file))
|
| 68 |
+
cond_sig = AudioSignal(str(cond_file))
|
| 69 |
+
|
| 70 |
+
cond_sig.resample(baseline_sig.sample_rate)
|
| 71 |
+
cond_sig.truncate_samples(baseline_sig.length)
|
| 72 |
+
|
| 73 |
+
# if our condition is inpainting, we need to trim the conditioning off
|
| 74 |
+
if "inpaint" in condition:
|
| 75 |
+
ctx_amt = float(condition.split("_")[-1])
|
| 76 |
+
ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
|
| 77 |
+
print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
|
| 78 |
+
cond_sig.trim(ctx_samples, ctx_samples)
|
| 79 |
+
baseline_sig.trim(ctx_samples, ctx_samples)
|
| 80 |
+
|
| 81 |
+
return {
|
| 82 |
+
# "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
|
| 83 |
+
# "stft": stft_loss(baseline_sig, cond_sig).item(),
|
| 84 |
+
"mel": mel_loss(baseline_sig, cond_sig).item(),
|
| 85 |
+
"frechet": frechet_score,
|
| 86 |
+
# "visqol": vsq,
|
| 87 |
+
"condition": condition,
|
| 88 |
+
"file": baseline_file.stem,
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
print(f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}")
|
| 92 |
+
metrics.extend(tqdm(map(process, baseline_files, cond_files), total=len(baseline_files)))
|
| 93 |
+
|
| 94 |
+
metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
for mk in metric_keys:
|
| 98 |
+
stat = pandas.DataFrame(metrics)
|
| 99 |
+
stat = stat.groupby(['condition'])[mk].agg(['mean', 'count', 'std'])
|
| 100 |
+
stat.to_csv(exp_dir / f"stats-{mk}.csv")
|
| 101 |
+
|
| 102 |
+
df = pandas.DataFrame(metrics)
|
| 103 |
+
df.to_csv(exp_dir / "metrics-all.csv", index=False)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
if __name__ == "__main__":
    # Parse CLI flags with argbind, then run the (argbind-bound) eval entry point
    # inside their scope.
    parsed = argbind.parse_args()
    with argbind.scope(parsed):
        eval()
|
vampnet/scripts/exp/experiment.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import random
|
| 3 |
+
from typing import List
|
| 4 |
+
import tempfile
|
| 5 |
+
import subprocess
|
| 6 |
+
|
| 7 |
+
import argbind
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from vampnet.interface import Interface
|
| 12 |
+
from vampnet import mask as pmask
|
| 13 |
+
import audiotools as at
|
| 14 |
+
|
| 15 |
+
Interface: Interface = argbind.bind(Interface)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def calculate_bitrate(
    interface, num_codebooks,
    downsample_factor
):
    """Bitrate (bits/second) of the token stream.

    Frame rate of the codec (sample_rate / hop_size) times the bits carried
    per frame: 10 bits per codebook token (presumably a 1024-entry vocab —
    TODO confirm), across `num_codebooks`, divided by the temporal
    downsampling factor.
    """
    bits_per_token = 10
    frame_rate = interface.codec.sample_rate / interface.codec.hop_size
    return frame_rate * ((bits_per_token * num_codebooks) / downsample_factor)
|
| 28 |
+
|
| 29 |
+
def baseline(sig, interface):
    """Reference condition: the input signal after preprocessing only."""
    return interface.preprocess(sig)
|
| 31 |
+
|
| 32 |
+
def reconstructed(sig, interface):
    """Codec round-trip: encode to tokens, then decode straight back to audio."""
    tokens = interface.encode(sig)
    return interface.to_signal(tokens)
|
| 36 |
+
|
| 37 |
+
def coarse2fine(sig, interface):
    """Keep only the c2f conditioning codebooks and regenerate the fine ones."""
    codes = interface.encode(sig)
    coarse = codes[:, :interface.c2f.n_conditioning_codebooks, :]

    fine = interface.coarse_to_fine(coarse)
    return interface.to_signal(fine)
|
| 43 |
+
|
| 44 |
+
class CoarseCond:
    """Callable condition: keep `num_conditioning_codebooks` intact plus a periodic
    grid of tokens (every `downsample_factor`-th step), then vamp the rest."""

    def __init__(self, num_conditioning_codebooks, downsample_factor):
        self.num_conditioning_codebooks = num_conditioning_codebooks
        self.downsample_factor = downsample_factor

    def __call__(self, sig, interface):
        codes = interface.encode(sig)
        # start from a full mask, then unmask the conditioning codebooks and
        # a periodic subset of time steps
        keep = pmask.full_mask(codes)
        keep = pmask.codebook_unmask(keep, self.num_conditioning_codebooks)
        keep = pmask.periodic_mask(keep, self.downsample_factor)

        vamped = interface.coarse_vamp(codes, keep)
        vamped = interface.coarse_to_fine(vamped)
        return interface.to_signal(vamped)
|
| 59 |
+
|
| 60 |
+
def opus(sig, interface, bitrate=128):
    """Round-trip `sig` through the Opus codec via ffmpeg.

    Args:
        sig: input signal; preprocessed with the interface first.
        interface: used only for preprocessing.
        bitrate: target bitrate in kbps.

    Returns:
        AudioSignal loaded from the re-decoded wav, at the preprocessed
        signal's sample rate.

    Raises:
        subprocess.CalledProcessError: if either ffmpeg invocation fails.
    """
    sig = interface.preprocess(sig)

    with tempfile.NamedTemporaryFile(suffix=".wav") as f:
        sig.write(f.name)

        opus_name = Path(f.name).with_suffix(".opus")
        # Encode to opus. BUGFIX: ffmpeg reads a bare number for -b:a as
        # bits/second; the original passed f"{bitrate}" (i.e. 128 *bits*/s).
        # Append "k" so `bitrate` is interpreted as kbps as intended.
        cmd = [
            "ffmpeg", "-y", "-i", f.name,
            "-c:a", "libopus",
            "-b:a", f"{bitrate}k",
            str(opus_name),
        ]
        subprocess.run(cmd, check=True)

        # decode back to wav so it can be loaded as an AudioSignal
        output_name = Path(f"{f.name}-opus").with_suffix(".wav")
        cmd = [
            "ffmpeg", "-y", "-i", str(opus_name),
            str(output_name),
        ]
        subprocess.run(cmd, check=True)

        sig = at.AudioSignal(
            output_name,
            sample_rate=sig.sample_rate
        )

        # the intermediate .opus file is no longer needed; the original leaked it.
        # NOTE(review): output_name is left on disk in case AudioSignal reads
        # the file lazily — TODO confirm and clean up if safe.
        opus_name.unlink(missing_ok=True)
    return sig
|
| 90 |
+
|
| 91 |
+
def mask_ratio_1_step(ratio=1.0):
    """Condition factory: randomly mask `ratio` of tokens and resample them
    with a single sampling step (no coarse-to-fine pass)."""
    def wrapper(sig, interface):
        codes = interface.encode(sig)
        noise_mask = pmask.linear_random(codes, ratio)
        resampled = interface.coarse_vamp(
            codes,
            noise_mask,
            sampling_steps=1,
        )
        return interface.to_signal(resampled)
    return wrapper
|
| 103 |
+
|
| 104 |
+
def num_sampling_steps(num_steps=1):
    """Condition factory: fixed periodic mask (period 16), vamped with
    `num_steps` iterative sampling steps, then coarse-to-fine."""
    def wrapper(sig, interface: Interface):
        codes = interface.encode(sig)
        keep = pmask.periodic_mask(codes, 16)
        vamped = interface.coarse_vamp(
            codes,
            keep,
            sampling_steps=num_steps,
        )
        vamped = interface.coarse_to_fine(vamped)
        return interface.to_signal(vamped)
    return wrapper
|
| 117 |
+
|
| 118 |
+
def beat_mask(ctx_time):
    """Condition factory: keep `ctx_time` seconds of audio centered on each
    detected beat (half before, half after) and regenerate everything else."""
    def wrapper(sig, interface):
        keep = interface.make_beat_mask(
            sig,
            before_beat_s=ctx_time / 2,
            after_beat_s=ctx_time / 2,
            invert=True,
        )

        codes = interface.encode(sig)
        vamped = interface.coarse_vamp(codes, keep)
        vamped = interface.coarse_to_fine(vamped)
        return interface.to_signal(vamped)
    return wrapper
|
| 136 |
+
|
| 137 |
+
def inpaint(ctx_time):
    """Condition factory: keep `ctx_time` seconds at both edges of the clip
    and regenerate the middle (inpainting)."""
    def wrapper(sig, interface: Interface):
        codes = interface.encode(sig)
        ctx_tokens = interface.s2t(ctx_time)
        keep = pmask.inpaint(codes, ctx_tokens, ctx_tokens)

        vamped = interface.coarse_vamp(codes, keep)
        vamped = interface.coarse_to_fine(vamped)
        return interface.to_signal(vamped)
    return wrapper
|
| 147 |
+
|
| 148 |
+
def token_noise(noise_amt):
    """Condition factory: replace a random `noise_amt` fraction of tokens with
    uniformly random vocabulary tokens and decode directly (no resampling)."""
    def wrapper(sig, interface: Interface):
        codes = interface.encode(sig)
        noise_mask = pmask.random(codes, noise_amt)
        noised = torch.where(
            noise_mask,
            torch.randint_like(codes, 0, interface.coarse.vocab_size),
            codes,
        )
        return interface.to_signal(noised)
    return wrapper
|
| 159 |
+
|
| 160 |
+
# Registry of experiment suites: maps experiment name -> {condition name -> fn(sig, interface)}.
EXP_REGISTRY = {}

# Generative-compression suite: codec baselines plus coarse-conditioned
# reconstructions at several codebook/downsampling operating points.
EXP_REGISTRY["gen-compression"] = {
    "baseline": baseline,
    "reconstructed": reconstructed,
    "coarse2fine": coarse2fine,
    **{
        f"{n}_codebooks_downsampled_{x}x": CoarseCond(num_conditioning_codebooks=n, downsample_factor=x)
        for (n, x) in (
            (1, 1), # 1 codebook, no downsampling
            (4, 4), # 4 codebooks, downsampled 4x
            (4, 16), # 4 codebooks, downsampled 16x
            (4, 32), # 4 codebooks, downsampled 32x
        )
    },
    # NOTE: despite the "token_noise" name, these use mask_ratio_1_step
    # (mask-and-resample in one step), not the token_noise() factory above.
    **{
        f"token_noise_{x}": mask_ratio_1_step(ratio=x)
        for x in [0.25, 0.5, 0.75]
    },

}

# Ablation over the number of iterative sampling steps.
EXP_REGISTRY["sampling-steps"] = {
    # "codec": reconstructed,
    **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 36, 64, 72]},
}

# Musically-structured masking: beat-centered context and edge inpainting.
EXP_REGISTRY["musical-sampling"] = {
    **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
    **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
}
|
| 193 |
+
|
| 194 |
+
@argbind.bind(without_prefix=True)
def main(
    sources=[
        "/media/CHONK/hugo/spotdl/val",
    ],
    output_dir: str = "./samples",
    max_excerpts: int = 2000,
    exp_type: str = "gen-compression",
    seed: int = 0,
    ext: List[str] = [".mp3"],  # annotation fixed: the default (and AudioLoader) expect a list
):
    """Render every condition of EXP_REGISTRY[exp_type] for up to `max_excerpts`
    random excerpts drawn from `sources`, writing one wav per condition under
    `output_dir/<condition>/<i>.wav`. Already-rendered excerpts are skipped, so
    the script is safe to resume.

    NOTE: mutable defaults (`sources`, `ext`) are never mutated here and are
    part of the argbind CLI surface, so they are left as-is.
    """
    at.util.seed(seed)
    interface = Interface()

    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    from audiotools.data.datasets import AudioLoader, AudioDataset

    # fixed-duration excerpts matching the coarse model's chunk size
    loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
    dataset = AudioDataset(loader,
                           sample_rate=interface.codec.sample_rate,
                           duration=interface.coarse.chunk_size_s,
                           n_examples=max_excerpts,
                           without_replacement=True,
                           )

    if exp_type in EXP_REGISTRY:
        SAMPLE_CONDS = EXP_REGISTRY[exp_type]
    else:
        raise ValueError(f"Unknown exp_type {exp_type}")


    # shuffle so multiple concurrent runs tend to work on different excerpts
    indices = list(range(max_excerpts))
    random.shuffle(indices)
    for i in tqdm(indices):
        # if all our files are already there, skip
        done = []
        for name in SAMPLE_CONDS:
            o_dir = Path(output_dir) / name
            done.append((o_dir / f"{i}.wav").exists())
        if all(done):
            continue

        sig = dataset[i]["signal"]
        # render every condition for this excerpt before writing anything
        results = {
            name: cond(sig, interface).cpu()
            for name, cond in SAMPLE_CONDS.items()
        }

        for name, sig in results.items():
            o_dir = Path(output_dir) / name
            o_dir.mkdir(exist_ok=True, parents=True)

            sig.write(o_dir / f"{i}.wav")
|
| 249 |
+
|
| 250 |
+
if __name__ == "__main__":
    # Bind CLI arguments with argbind, then run the experiment in their scope.
    parsed = argbind.parse_args()
    with argbind.scope(parsed):
        main()
|
vampnet/scripts/exp/fine_tune.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argbind
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import yaml
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
"""example output: (yaml)
|
| 10 |
+
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
@argbind.bind(without_prefix=True, positional=True)
def fine_tune(audio_files_or_folders: List[str], name: str):
    """Generate argbind config files for a named LoRA fine-tuning run.

    Writes three YAML files under conf/generated/<name>/: one for the
    coarse model, one for the coarse-to-fine (c2f) model, and one for the
    inference Interface. Must be run from the vampnet repo root (expects
    a ./conf directory).
    """

    conf_dir = Path("conf")
    assert conf_dir.exists(), "conf directory not found. are you in the vampnet directory?"

    conf_dir = conf_dir / "generated"
    conf_dir.mkdir(exist_ok=True)

    finetune_dir = conf_dir / name
    finetune_dir.mkdir(exist_ok=True)

    # c2f fine-tune: overrides the base lora config with the larger c2f
    # architecture and a shorter training excerpt duration
    finetune_c2f_conf = {
        "$include": ["conf/lora/lora.yml"],
        "fine_tune": True,
        "train/AudioLoader.sources": audio_files_or_folders,
        "val/AudioLoader.sources": audio_files_or_folders,
        "VampNet.n_codebooks": 14,
        "VampNet.n_conditioning_codebooks": 4,
        "VampNet.embedding_dim": 1280,
        "VampNet.n_layers": 16,
        "VampNet.n_heads": 20,
        "AudioDataset.duration": 3.0,
        "AudioDataset.loudness_cutoff": -40.0,
        "save_path": f"./runs/{name}/c2f",
        "fine_tune_checkpoint": "./models/vampnet/c2f.pth"
    }

    # coarse fine-tune: uses the base lora architecture unchanged
    finetune_coarse_conf = {
        "$include": ["conf/lora/lora.yml"],
        "fine_tune": True,
        "train/AudioLoader.sources": audio_files_or_folders,
        "val/AudioLoader.sources": audio_files_or_folders,
        "save_path": f"./runs/{name}/coarse",
        "fine_tune_checkpoint": "./models/vampnet/coarse.pth"
    }

    # inference config pointing at the checkpoints the training runs will write
    interface_conf = {
        "Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",

        "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
        "Interface.wavebeat_ckpt": "./models/wavebeat.pth",

        "Interface.codec_ckpt": "./models/vampnet/codec.pth",
        # NOTE(review): this wraps the list in another list, unlike the
        # train/val keys above — looks like a double-nesting; TODO confirm
        # what AudioLoader.sources expects here.
        "AudioLoader.sources": [audio_files_or_folders],
    }

    # save the confs
    with open(finetune_dir / "c2f.yml", "w") as f:
        yaml.dump(finetune_c2f_conf, f)

    with open(finetune_dir / "coarse.yml", "w") as f:
        yaml.dump(finetune_coarse_conf, f)

    with open(finetune_dir / "interface.yml", "w") as f:
        yaml.dump(interface_conf, f)


    print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` ")
|
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
    # Parse positional CLI args via argbind and generate the configs.
    parsed = argbind.parse_args()
    with argbind.scope(parsed):
        fine_tune()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
vampnet/scripts/exp/train.py
ADDED
|
@@ -0,0 +1,680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import warnings
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
|
| 8 |
+
import argbind
|
| 9 |
+
import audiotools as at
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from audiotools import AudioSignal
|
| 13 |
+
from audiotools.data import transforms
|
| 14 |
+
from einops import rearrange
|
| 15 |
+
from rich import pretty
|
| 16 |
+
from rich.traceback import install
|
| 17 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 18 |
+
|
| 19 |
+
import vampnet
|
| 20 |
+
from vampnet.modules.transformer import VampNet
|
| 21 |
+
from vampnet.util import codebook_unflatten, codebook_flatten
|
| 22 |
+
from vampnet import mask as pmask
|
| 23 |
+
# from dac.model.dac import DAC
|
| 24 |
+
from lac.model.lac import LAC as DAC
|
| 25 |
+
|
| 26 |
+
from audiotools.ml.decorators import (
|
| 27 |
+
timer, Tracker, when
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
import loralib as lora
|
| 31 |
+
|
| 32 |
+
import torch._dynamo
|
| 33 |
+
torch._dynamo.config.verbose=True
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Enable cudnn autotuner to speed up training
|
| 37 |
+
# (can be altered by the funcs.seed function)
|
| 38 |
+
torch.backends.cudnn.benchmark = bool(int(os.getenv("CUDNN_BENCHMARK", 1)))
|
| 39 |
+
# Uncomment to trade memory for speed.
|
| 40 |
+
|
| 41 |
+
# Install to make things look nice
|
| 42 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 43 |
+
pretty.install()
|
| 44 |
+
install()
|
| 45 |
+
|
| 46 |
+
# optim
|
| 47 |
+
Accelerator = argbind.bind(at.ml.Accelerator, without_prefix=True)
|
| 48 |
+
CrossEntropyLoss = argbind.bind(nn.CrossEntropyLoss)
|
| 49 |
+
AdamW = argbind.bind(torch.optim.AdamW)
|
| 50 |
+
NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)
|
| 51 |
+
|
| 52 |
+
# transforms
|
| 53 |
+
filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
|
| 54 |
+
"BaseTransform",
|
| 55 |
+
"Compose",
|
| 56 |
+
"Choose",
|
| 57 |
+
]
|
| 58 |
+
tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
|
| 59 |
+
|
| 60 |
+
# model
|
| 61 |
+
VampNet = argbind.bind(VampNet)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# data
|
| 65 |
+
AudioLoader = argbind.bind(at.datasets.AudioLoader)
|
| 66 |
+
AudioDataset = argbind.bind(at.datasets.AudioDataset, "train", "val")
|
| 67 |
+
|
| 68 |
+
IGNORE_INDEX = -100
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@argbind.bind("train", "val", without_prefix=True)
def build_transform():
    """Build the data-augmentation pipeline: normalize loudness to a constant
    -24 (LUFS — TODO confirm units), then rescale to avoid clipping."""
    return transforms.Compose(
        tfm.VolumeNorm(("const", -24)),
        # tfm.PitchShift(),
        tfm.RescaleAudio(),
    )
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@torch.no_grad()
def apply_transform(transform_fn, batch):
    """Apply an audiotools transform to the batch's signal, gradient-free.

    The signal is cloned first so the batch's cached signal is never
    mutated in place.
    """
    source: AudioSignal = batch["signal"]
    transform_kwargs = batch["transform_args"]

    return transform_fn(source.clone(), **transform_kwargs)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def build_datasets(args, sample_rate: int):
    """Build train and val AudioDatasets at `sample_rate`.

    Each dataset is constructed under its own argbind scope so that
    "train/..." and "val/..." config keys select the right loader sources
    and transform parameters.
    """
    with argbind.scope(args, "train"):
        train_data = AudioDataset(
            AudioLoader(), sample_rate, transform=build_transform()
        )
    with argbind.scope(args, "val"):
        val_data = AudioDataset(AudioLoader(), sample_rate, transform=build_transform())
    return train_data, val_data
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def rand_float(shape, low, high, rng):
    """Draw `shape` quasi-random floats uniformly in [low, high).

    `rng` is a Sobol engine; draw() returns (shape, 1), so the single
    dimension is squeezed before scaling.
    """
    unit = rng.draw(shape)[:, 0]
    return unit * (high - low) + low
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def flip_coin(shape, p, rng):
    """Quasi-random Bernoulli draw: True where a Sobol sample in [0, 1)
    falls below probability `p`. Returns a bool tensor of length `shape`."""
    draws = rng.draw(shape)[:, 0]
    return draws < p
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def num_params_hook(o, p):
    """Append a parameter count (in millions, 3 decimals) to an extra_repr string."""
    return f"{o} {p/1e6:<.3f}M params."
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def add_num_params_repr_hook(model):
    """Monkey-patch every submodule's extra_repr so printing `model` also
    shows each module's parameter count (formatted by num_params_hook)."""
    import numpy as np
    from functools import partial

    for n, m in model.named_modules():
        # snapshot the current extra_repr text and the parameter count now;
        # the bound partial will return them verbatim on every later call
        o = m.extra_repr()
        p = sum([np.prod(p.size()) for p in m.parameters()])

        # the instance attribute shadows the class method, so repr() picks it up
        setattr(m, "extra_repr", partial(num_params_hook, o=o, p=p))
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def accuracy(
    preds: torch.Tensor,
    target: torch.Tensor,
    top_k: int = 1,
    ignore_index: Optional[int] = None,
) -> torch.Tensor:
    """Top-k classification accuracy over a batch of sequences.

    Args:
        preds: logits of shape (batch, n_class, seq).
        target: integer class indices of shape (batch, seq).
        top_k: a position counts as correct when its target class is among
            the top-k highest logits.
        ignore_index: positions whose target equals this value are excluded
            from the average (e.g. IGNORE_INDEX for unmasked tokens).

    Returns:
        Scalar float tensor: fraction of (non-ignored) positions correct.
    """
    # Flatten the predictions and targets to (batch * seq, n_class) / (batch * seq,).
    # Plain permute+reshape instead of einops.rearrange — same layout, no
    # extra dependency for a trivial transpose.
    preds = preds.permute(0, 2, 1).reshape(-1, preds.shape[1])
    target = target.reshape(-1)

    if ignore_index is not None:
        # Drop positions marked with the ignore index.
        keep = target != ignore_index
        preds = preds[keep]
        target = target[keep]

    # Top-k predicted class indices per position.
    _, pred_indices = torch.topk(preds, k=top_k, dim=-1)

    # A position is correct if the true class appears among its top-k predictions.
    correct = torch.sum(torch.eq(pred_indices, target.unsqueeze(1)), dim=1)

    # (local no longer shadows the function name, unlike the original)
    return torch.mean(correct.float())
|
| 151 |
+
|
| 152 |
+
def _metrics(z_hat, r, target, flat_mask, output):
    """Record top-1/top-25 accuracies into `output`, split by mask ratio bucket
    (r in [0, 0.5) vs [0.5, 1.0)) and by masked vs unmasked positions.

    Args:
        z_hat: per-example logits (indexed along dim 0 by example).
        r: per-example mask ratio, shape (batch,).
        target: flattened token targets, one row per example.
        flat_mask: flattened mask (nonzero where the token was masked).
        output: dict mutated in place with "accuracy-<s>-<e>/top<k>/..." keys.
    """
    for r_range in [(0, 0.5), (0.5, 1.0)]:
        # unmasked view: masked positions ignored; masked view: the opposite
        unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
        masked_target = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)

        assert target.shape[0] == r.shape[0]
        # grab the indices of the r values that are in the range
        r_idx = (r >= r_range[0]) & (r < r_range[1])

        # grab the target and z_hat values that are in the range
        r_unmasked_target = unmasked_target[r_idx]
        r_masked_target = masked_target[r_idx]
        r_z_hat = z_hat[r_idx]

        for topk in (1, 25):
            s, e = r_range
            tag = f"accuracy-{s}-{e}/top{topk}"

            output[f"{tag}/unmasked"] = accuracy(
                preds=r_z_hat,
                target=r_unmasked_target,
                ignore_index=IGNORE_INDEX,
                top_k=topk,
            )
            output[f"{tag}/masked"] = accuracy(
                preds=r_z_hat,
                target=r_masked_target,
                ignore_index=IGNORE_INDEX,
                top_k=topk,
            )
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@dataclass
class State:
    """Everything the train/val loops need, bundled in one object."""
    model: VampNet
    codec: DAC  # audio tokenizer; used for encoding only (kept in eval/inference)

    optimizer: AdamW
    scheduler: NoamScheduler
    criterion: CrossEntropyLoss
    grad_clip_val: float  # max gradient norm passed to clip_grad_norm_

    # quasi-random source for per-example mask ratios r in [0, 1)
    rng: torch.quasirandom.SobolEngine

    train_data: AudioDataset
    val_data: AudioDataset

    tracker: Tracker  # logging / best-checkpoint bookkeeping
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@timer()
def train_loop(state: State, batch: dict, accel: Accelerator):
    """Run one optimization step on `batch` and return sorted metrics.

    Pipeline: transform audio -> codec-encode to tokens -> mask a quasi-random
    fraction r of tokens -> predict masked tokens -> cross-entropy on masked
    positions only -> backprop, clip, step.
    """
    state.model.train()
    batch = at.util.prepare_batch(batch, accel.device)
    signal = apply_transform(state.train_data.transform, batch)

    output = {}
    vn = accel.unwrap(state.model)
    with accel.autocast():
        # codec is frozen here: encode under inference_mode for speed/memory
        with torch.inference_mode():
            state.codec.to(accel.device)
            z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
            z = z[:, : vn.n_codebooks, :]

        n_batch = z.shape[0]
        # per-example mask ratio from the Sobol engine
        r = state.rng.draw(n_batch)[:, 0].to(accel.device)

        # random mask at ratio r, but never mask the conditioning codebooks
        mask = pmask.random(z, r)
        mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
        z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

        z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)

        # NOTE(review): nested autocast — the inner one forces bfloat16 for
        # the forward pass when AMP is on; presumably intentional, confirm.
        dtype = torch.bfloat16 if accel.amp else None
        with accel.autocast(dtype=dtype):
            z_hat = state.model(z_mask_latent)

        # targets exclude the conditioning codebooks the model never predicts
        target = codebook_flatten(
            z[:, vn.n_conditioning_codebooks :, :],
        )

        flat_mask = codebook_flatten(
            mask[:, vn.n_conditioning_codebooks :, :],
        )

        # loss only on masked positions: unmasked targets are set to IGNORE_INDEX
        t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
        output["loss"] = state.criterion(z_hat, t_masked)

        _metrics(
            r=r,
            z_hat=z_hat,
            target=target,
            flat_mask=flat_mask,
            output=output,
        )


    accel.backward(output["loss"])

    output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
    output["other/batch_size"] = z.shape[0]


    # unscale before clipping so the clip threshold applies to true gradients
    accel.scaler.unscale_(state.optimizer)
    output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
        state.model.parameters(), state.grad_clip_val
    )

    accel.step(state.optimizer)
    state.optimizer.zero_grad()

    state.scheduler.step()
    accel.update()


    return {k: v for k, v in sorted(output.items())}
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
@timer()
@torch.no_grad()
def val_loop(state: State, batch: dict, accel: Accelerator):
    """Run one validation step on `batch`; mirrors train_loop without
    backprop, AMP casting, or optimizer updates. Returns a metrics dict."""
    state.model.eval()
    state.codec.eval()
    batch = at.util.prepare_batch(batch, accel.device)
    signal = apply_transform(state.val_data.transform, batch)

    vn = accel.unwrap(state.model)
    z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
    z = z[:, : vn.n_codebooks, :]

    n_batch = z.shape[0]
    # per-example mask ratio from the Sobol engine (shared with training)
    r = state.rng.draw(n_batch)[:, 0].to(accel.device)

    # random mask at ratio r; conditioning codebooks are never masked
    mask = pmask.random(z, r)
    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

    z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)

    z_hat = state.model(z_mask_latent)

    target = codebook_flatten(
        z[:, vn.n_conditioning_codebooks :, :],
    )

    flat_mask = codebook_flatten(
        mask[:, vn.n_conditioning_codebooks :, :]
    )

    output = {}
    # loss only on masked positions: unmasked targets are set to IGNORE_INDEX
    t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
    output["loss"] = state.criterion(z_hat, t_masked)

    _metrics(
        r=r,
        z_hat=z_hat,
        target=target,
        flat_mask=flat_mask,
        output=output,
    )

    return output
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def validate(state, val_dataloader, accel):
    """Run val_loop over the whole dataloader.

    NOTE: only the *last* batch's metrics dict is returned; per-batch
    aggregation is presumably handled by the tracker wrapping val_loop —
    TODO confirm against the caller.
    """
    for batch in val_dataloader:
        output = val_loop(state, batch, accel)
    # Consolidate state dicts if using ZeroRedundancyOptimizer
    if hasattr(state.optimizer, "consolidate_state_dict"):
        state.optimizer.consolidate_state_dict()
    return output
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def checkpoint(state, save_iters, save_path, fine_tune):
    """Save the model (and optimizer/scheduler/tracker state) under tag dirs.

    Tags: always "latest"; "<N>k" when the step is in `save_iters`; "best"
    when validation loss improved. When fine-tuning, the LoRA weights are
    additionally saved per tag.

    NOTE(review): relies on the module-global `accel` (set in `__main__`),
    not a parameter — confirm this is intentional before refactoring.
    """
    if accel.local_rank != 0:
        state.tracker.print(f"ERROR:Skipping checkpoint on rank {accel.local_rank}")
        return

    metadata = {"logs": dict(state.tracker.history)}

    tags = ["latest"]
    state.tracker.print(f"Saving to {str(Path('.').absolute())}")

    if state.tracker.step in save_iters:
        tags.append(f"{state.tracker.step // 1000}k")

    if state.tracker.is_best("val", "loss"):
        state.tracker.print(f"Best model so far")
        tags.append("best")

    if fine_tune:
        # Save the LoRA adapter weights separately for each tag.
        for tag in tags:
            (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
            torch.save(
                lora.lora_state_dict(accel.unwrap(state.model)),
                f"{save_path}/{tag}/lora.pth"
            )

    for tag in tags:
        extra_state = {
            "optimizer.pth": state.optimizer.state_dict(),
            "scheduler.pth": state.scheduler.state_dict(),
            "tracker.pth": state.tracker.state_dict(),
            "metadata.pth": metadata,
        }
        unwrapped = accel.unwrap(state.model)
        unwrapped.metadata = metadata
        unwrapped.save_to_folder(
            f"{save_path}/{tag}", extra_state, package=False
        )
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def save_sampled(state, z, writer):
    """Generate audio from each item's tokens and log it to TensorBoard.

    NOTE(review): uses the module-global `accel`; one generate() call per
    batch item.
    """
    for idx in range(z.shape[0]):
        sampled = accel.unwrap(state.model).generate(
            codec=state.codec,
            time_steps=z.shape[-1],
            start_tokens=z[idx : idx + 1],
        )
        sampled.cpu().write_audio_to_tb(
            f"sampled/{idx}",
            writer,
            step=state.tracker.step,
            plot_fn=None,
        )
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def save_imputation(state, z, val_idx, writer):
    """Log inpainting examples: prompt (masked), generated middle, and
    codec reconstruction, for each validation index.

    Masks out the middle 50% of each sequence (25% prefix and 25% suffix
    are kept) and asks the model to fill it in.
    """
    n_prefix = int(z.shape[-1] * 0.25)
    n_suffix = int(z.shape[-1] * 0.25)

    vn = accel.unwrap(state.model)

    mask = pmask.inpaint(z, n_prefix, n_suffix)
    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

    imputed_noisy = vn.to_signal(z_mask, state.codec)
    imputed_true = vn.to_signal(z, state.codec)

    # Generate each item separately, then re-batch the resulting signals.
    imputed = AudioSignal.batch([
        vn.generate(
            codec=state.codec,
            time_steps=z.shape[-1],
            start_tokens=z[i][None, ...],
            mask=mask[i][None, ...],
        )
        for i in range(len(z))
    ])

    for i in range(len(val_idx)):
        for tb_tag, sig in (
            ("inpainted_prompt", imputed_noisy),
            ("inpainted_middle", imputed),
            ("reconstructed", imputed_true),
        ):
            sig[i].cpu().write_audio_to_tb(
                f"{tb_tag}/{i}",
                writer,
                step=state.tracker.step,
                plot_fn=None,
            )
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
@torch.no_grad()
def save_samples(state: State, val_idx: int, writer: SummaryWriter):
    """Log qualitative audio examples for a fixed set of validation items.

    For each item: the original signal, the masked signal, the one-step
    (argmax) generation, and the codec reconstruction — followed by full
    iterative sampling and inpainting examples.
    """
    state.model.eval()
    state.codec.eval()
    vn = accel.unwrap(state.model)

    batch = [state.val_data[i] for i in val_idx]
    batch = at.util.prepare_batch(state.val_data.collate(batch), accel.device)

    signal = apply_transform(state.val_data.transform, batch)

    z = state.codec.encode(signal.samples, signal.sample_rate)["codes"]
    z = z[:, : vn.n_codebooks, :]

    # Sweep masking ratios across the batch so the logs cover easy -> hard.
    r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)

    mask = pmask.random(z, r)
    mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
    z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)

    z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
    z_hat = state.model(z_mask_latent)

    # One-step decode: argmax over the vocab, then restore codebook layout.
    z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
    z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
    z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)

    generated = vn.to_signal(z_pred, state.codec)
    reconstructed = vn.to_signal(z, state.codec)
    masked = vn.to_signal(z_mask.squeeze(1), state.codec)

    for i in range(generated.batch_size):
        audio_dict = {
            "original": signal[i],
            "masked": masked[i],
            "generated": generated[i],
            "reconstructed": reconstructed[i],
        }
        for k, v in audio_dict.items():
            v.cpu().write_audio_to_tb(
                f"onestep/_{i}.r={r[i]:0.2f}/{k}",
                writer,
                step=state.tracker.step,
                plot_fn=None,
            )

    save_sampled(state=state, z=z, writer=writer)
    save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
@argbind.bind(without_prefix=True)
def load(
    args,
    accel: at.ml.Accelerator,
    tracker: Tracker,
    save_path: str,
    resume: bool = False,
    tag: str = "latest",
    fine_tune_checkpoint: Optional[str] = None,
    grad_clip_val: float = 5.0,
) -> State:
    """Build the training State: codec, model, optimizer, scheduler, data.

    Precedence for the model weights: a resumed checkpoint under
    `{save_path}/{tag}` wins over a fine-tune checkpoint, which wins over a
    freshly initialized VampNet.
    """
    codec = DAC.load(args["codec_ckpt"], map_location="cpu")
    codec.eval()

    model, v_extra = None, {}

    if args["fine_tune"]:
        assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
        model = torch.compile(
            VampNet.load(
                location=Path(fine_tune_checkpoint),
                map_location="cpu",
            )
        )

    if resume:
        kwargs = {
            "folder": f"{save_path}/{tag}",
            "map_location": "cpu",
            "package": False,
        }
        tracker.print(f"Loading checkpoint from {kwargs['folder']}")
        if (Path(kwargs["folder"]) / "vampnet").exists():
            model, v_extra = VampNet.load_from_folder(**kwargs)
        else:
            raise ValueError(
                f"Could not find a VampNet checkpoint in {kwargs['folder']}"
            )

    model = torch.compile(VampNet()) if model is None else model
    model = accel.prepare_model(model)

    # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
    assert (
        accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size
    )

    optimizer = AdamW(model.parameters(), use_zero=accel.use_ddp)
    scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
    scheduler.step()

    # Restore optimizer/scheduler/tracker state when resuming.
    if "optimizer.pth" in v_extra:
        optimizer.load_state_dict(v_extra["optimizer.pth"])
        scheduler.load_state_dict(v_extra["scheduler.pth"])
    if "tracker.pth" in v_extra:
        tracker.load_state_dict(v_extra["tracker.pth"])

    criterion = CrossEntropyLoss()

    sample_rate = codec.sample_rate

    # a better rng for sampling from our schedule
    rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])

    # log a model summary w/ num params (rank 0 only)
    if accel.local_rank == 0:
        add_num_params_repr_hook(accel.unwrap(model))
        with open(f"{save_path}/model.txt", "w") as f:
            f.write(repr(accel.unwrap(model)))

    # load the datasets
    train_data, val_data = build_datasets(args, sample_rate)

    return State(
        tracker=tracker,
        model=model,
        codec=codec,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        rng=rng,
        train_data=train_data,
        val_data=val_data,
        grad_clip_val=grad_clip_val,
    )
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
@argbind.bind(without_prefix=True)
def train(
    args,
    accel: at.ml.Accelerator,
    seed: int = 0,
    codec_ckpt: str = None,
    save_path: str = "ckpt",
    num_iters: int = int(1000e6),
    save_iters: list = [10000, 50000, 100000, 300000, 500000],
    sample_freq: int = 10000,
    val_freq: int = 1000,
    batch_size: int = 12,
    val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
    num_workers: int = 10,
    fine_tune: bool = False,
):
    """Main training entry point: builds state/dataloaders and runs the loop.

    Every `val_freq` steps runs validation + checkpointing; every
    `sample_freq` steps logs audio samples (rank 0 only).
    """
    assert codec_ckpt is not None, "codec_ckpt is required"

    # Decorrelate ranks while keeping runs reproducible.
    seed = seed + accel.local_rank
    at.util.seed(seed)

    writer = None
    if accel.local_rank == 0:
        writer = SummaryWriter(log_dir=f"{save_path}/logs/")
        argbind.dump_args(args, f"{save_path}/args.yml")

    tracker = Tracker(
        writer=writer, log_file=f"{save_path}/log.txt", rank=accel.local_rank
    )

    # load the codec model
    state: State = load(
        args=args,
        accel=accel,
        tracker=tracker,
        save_path=save_path)
    print("initialized state.")

    train_dataloader = accel.prepare_dataloader(
        state.train_data,
        start_idx=state.tracker.step * batch_size,
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=state.train_data.collate,
    )
    val_dataloader = accel.prepare_dataloader(
        state.val_data,
        start_idx=0,
        num_workers=num_workers,
        batch_size=batch_size,
        collate_fn=state.val_data.collate,
        persistent_workers=num_workers > 0,
    )
    print("initialized dataloader.")

    if fine_tune:
        lora.mark_only_lora_as_trainable(state.model)
        print("marked only lora as trainable.")

    # Wrap the loop functions so that they track progress / log to
    # TensorBoard and only run when their conditions are met.
    global train_loop, val_loop, validate, save_samples, checkpoint

    train_loop = tracker.log("train", "value", history=False)(
        tracker.track("train", num_iters, completed=state.tracker.step)(train_loop)
    )
    val_loop = tracker.track("val", len(val_dataloader))(val_loop)
    validate = tracker.log("val", "mean")(validate)

    save_samples = when(lambda: accel.local_rank == 0)(save_samples)
    checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)

    print("starting training loop.")
    with tracker.live:
        for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
            train_loop(state, batch, accel)

            last_iter = (
                tracker.step == num_iters - 1 if num_iters is not None else False
            )

            if tracker.step % sample_freq == 0 or last_iter:
                save_samples(state, val_idx, writer)

            if tracker.step % val_freq == 0 or last_iter:
                validate(state, val_dataloader, accel)
                checkpoint(
                    state=state,
                    save_iters=save_iters,
                    save_path=save_path,
                    fine_tune=fine_tune)

                # Reset validation progress bar, print summary since last validation.
                tracker.done("val", f"Iteration {tracker.step}")

            if last_iter:
                break
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
if __name__ == "__main__":
    # Parse CLI args, enable debug logging on the primary rank only, and
    # silence tracebacks on non-primary ranks to keep multi-GPU logs clean.
    args = argbind.parse_args()
    args["args.debug"] = int(os.getenv("LOCAL_RANK", 0)) == 0
    with argbind.scope(args), Accelerator() as accel:
        if accel.local_rank != 0:
            sys.tracebacklimit = 0
        train(args, accel)
|
vampnet/scripts/utils/README.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Scripts
|
| 2 |
+
|
| 3 |
+
## process_zip.py
|
| 4 |
+
|
| 5 |
+
Some requirements that may not be installed in the docker image:
|
| 6 |
+
* argbind
|
| 7 |
+
* wav2wav (pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git or `pip install git+https://github.com/descriptinc/lyrebird-wav2wav.git@<branchname>`)
|
| 8 |
+
|
| 9 |
+
### zip folder structure
|
| 10 |
+
|
| 11 |
+
The zip folder should have the following internal structure:
|
| 12 |
+
|
| 13 |
+
```
|
| 14 |
+
base_folder/
|
| 15 |
+
test_case_1/
|
| 16 |
+
before.wav
|
| 17 |
+
test_case_2/
|
| 18 |
+
before.wav
|
| 19 |
+
...
|
| 20 |
+
test_case_n/
|
| 21 |
+
before.wav
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
Note: There can be issues with the output zip if the input zip's folder structure is too deep or too shallow. If you need to use a zip file with a different folder structure, adjust this:
|
| 25 |
+
https://github.com/descriptinc/lyrebird-wav2wav/blob/136c923ce19df03876a515ca0ed83854710cfa30/scripts/utils/process_zip.py#L28
|
| 26 |
+
|
| 27 |
+
### Execution
|
| 28 |
+
`python process_zip.py <path/to/zip> -tag <string>`
|
vampnet/scripts/utils/data/augment.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import audiotools as at
|
| 4 |
+
from audiotools import AudioSignal
|
| 5 |
+
|
| 6 |
+
import argbind
|
| 7 |
+
import tqdm
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from torch_pitch_shift import pitch_shift, get_fast_shifts
|
| 12 |
+
from torch_time_stretch import time_stretch, get_fast_stretches
|
| 13 |
+
|
| 14 |
+
from audiotools.core.util import sample_from_dist
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@argbind.bind(without_prefix=True)
def augment(
    audio_folder: Path = None,
    dest_folder: Path = None,
    n_augmentations: int = 10,
):
    """
    Augment a folder of audio files with random pitch-shift and time-stretch.

    The dest folder will contain a folder for each of the clean dataset's
    files. Under each of these folders there will be the augmented 10-second
    chunks, named "<chunk>-<augmentation>.wav".

    Args:
        audio_folder: root folder of the clean audio files (required).
        dest_folder: where augmented files are written (required).
        n_augmentations: number of augmented copies per 10 s chunk.
    """
    # Hoisted out of the inner loop: importing per iteration was wasteful.
    import random

    assert audio_folder is not None
    assert dest_folder is not None
    audio_files = at.util.find_audio(audio_folder)

    for audio_file in tqdm.tqdm(audio_files):
        # Mirror the source tree under dest_folder, one subdir per file.
        subtree = dest_folder / audio_file.relative_to(audio_folder).parent
        subdir = subtree / audio_file.stem
        subdir.mkdir(parents=True, exist_ok=True)

        src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")

        # Non-overlapping 10 s windows.
        for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
            for j in range(n_augmentations):
                dst = chunk.clone()
                # Random pitch shift (fast ratios in [0.25, 1.0]).
                dst.samples = pitch_shift(
                    dst.samples,
                    shift=random.choice(get_fast_shifts(src.sample_rate,
                        condition=lambda x: x >= 0.25 and x <= 1.0)),
                    sample_rate=src.sample_rate
                )
                # Random time stretch (fast ratios in [0.667, 1.5]).
                dst.samples = time_stretch(
                    dst.samples,
                    stretch=random.choice(get_fast_stretches(src.sample_rate,
                        condition=lambda x: x >= 0.667 and x <= 1.5, )),
                    sample_rate=src.sample_rate,
                )

                dst.cpu().write(subdir / f"{i}-{j}.wav")


if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        augment()
|
vampnet/scripts/utils/data/maestro-reorg.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reorganize the MAESTRO v3 dataset into train/validation/test folders of
symlinks, using the split recorded in the dataset's JSON metadata."""
from pathlib import Path
import json
import os

maestro_path = Path("/media/CHONK/hugo/maestro-v3.0.0")
output_path = Path("/media/CHONK/hugo/maestro-v3.0.0-split")

# Read the official split assignment from the metadata file.
with open(maestro_path / "maestro-v3.0.0.json") as f:
    maestro = json.load(f)

# NOTE: a stray `breakpoint()` debug call was removed here — it halted the
# script whenever it was run non-interactively.
splits = {"train": [], "validation": [], "test": []}
for key, split in maestro["split"].items():
    audio_filename = maestro['audio_filename'][key]
    if split not in splits:
        raise ValueError(f"Unknown split {split}")
    splits[split].append(audio_filename)

# Symlink every file into its split directory, preserving relative paths.
for split, filenames in splits.items():
    for audio_filename in filenames:
        p = output_path / split / audio_filename
        p.parent.mkdir(parents=True, exist_ok=True)
        os.symlink(maestro_path / audio_filename, p)
|
vampnet/scripts/utils/plots.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import seaborn as sns
|
| 3 |
+
from pandas.api.types import CategoricalDtype
|
| 4 |
+
|
| 5 |
+
def plot_metrics(metrics, condition_to_latex, title, color_palette):
    """Plot per-condition mel loss (boxplot) and FAD (bar chart) stacked
    vertically in one figure.

    Args:
        metrics: DataFrame with at least 'condition', 'mel', 'frechet'.
        condition_to_latex: mapping condition name -> LaTeX label; also
            fixes the x-axis ordering.
        title: figure suptitle.
        color_palette: mapping LaTeX label -> color (assumed keyed by the
            latex labels — confirm against callers).
    """
    # Map conditions to their LaTeX labels and pin the ordering.
    metrics['condition_latex'] = metrics['condition'].map(condition_to_latex)
    ordered = CategoricalDtype(categories=condition_to_latex.values(), ordered=True)
    metrics['condition_latex'] = metrics['condition_latex'].astype(ordered)

    # Per-condition mean/std, used for the FAD error bars.
    grouped = metrics.groupby('condition_latex')[['mel', 'frechet']].agg(['mean', 'std'])

    fig, axs = plt.subplots(2, 1, figsize=(7, 5.25))
    fig.suptitle(title, fontsize=16)

    bar_colors = [color_palette[condition] for condition in grouped.index]

    # Top panel: mel spectrogram loss distribution per condition.
    sns.boxplot(x='condition_latex', y='mel', data=metrics, ax=axs[0], palette=color_palette, showfliers=False)
    axs[0].set_ylabel('Mel Spectrogram Loss \u2190')
    axs[0].set_xlabel('')
    axs[0].set_xticklabels(grouped.index, rotation=0, ha='center')

    # Bottom panel: FAD mean with std error bars.
    axs[1].bar(grouped.index, grouped['frechet']['mean'], yerr=grouped['frechet']['std'], color=bar_colors)
    axs[1].set_ylabel('FAD \u2190')
    axs[1].set_xlabel('')
    axs[1].set_xticklabels(grouped.index, rotation=0, ha='center')

    # Tighten spacing between the panels and around the suptitle.
    plt.subplots_adjust(hspace=0.1)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.subplots_adjust(top=0.92)
|
vampnet/scripts/utils/remove_quiet_files.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Removes audio files whose loudness is below a threshold (default -30 dB; see `min_loudness`).
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import shutil
|
| 5 |
+
import audiotools as at
|
| 6 |
+
import argbind
|
| 7 |
+
|
| 8 |
+
@argbind.bind(without_prefix=True)
|
| 9 |
+
def remove_quiet_files(
|
| 10 |
+
src_dir: Path = None,
|
| 11 |
+
dest_dir: Path = None,
|
| 12 |
+
min_loudness: float = -30,
|
| 13 |
+
):
|
| 14 |
+
# copy src to dest
|
| 15 |
+
dest_dir.mkdir(parents=True, exist_ok=True)
|
| 16 |
+
shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
|
| 17 |
+
|
| 18 |
+
audio_files = at.util.find_audio(dest_dir)
|
| 19 |
+
for audio_file in audio_files:
|
| 20 |
+
sig = at.AudioSignal(audio_file)
|
| 21 |
+
if sig.loudness() < min_loudness:
|
| 22 |
+
audio_file.unlink()
|
| 23 |
+
print(f"removed {audio_file}")
|
| 24 |
+
|
| 25 |
+
if __name__ == "__main__":
|
| 26 |
+
args = argbind.parse_args()
|
| 27 |
+
|
| 28 |
+
with argbind.scope(args):
|
| 29 |
+
remove_quiet_files()
|
vampnet/scripts/utils/split.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import random
|
| 3 |
+
import shutil
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
import argbind
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from tqdm.contrib.concurrent import thread_map
|
| 10 |
+
|
| 11 |
+
from audiotools.core import util
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@argbind.bind(without_prefix=True)
def train_test_split(
    audio_folder: str = ".",
    test_size: float = 0.2,
    seed: int = 42,
    pattern: str = "**/*.mp3",
):
    """Split the audio files under `audio_folder` into train/test symlink
    folders (`<name>-train`, `<name>-test`) and record each split as JSON.

    Prompts for confirmation before touching the filesystem.
    """
    print(f"finding audio")

    audio_folder = Path(audio_folder)
    audio_files = list(tqdm(audio_folder.glob(pattern)))
    print(f"found {len(audio_files)} audio files")

    # Deterministic shuffle, then partition by test_size.
    n_test = int(len(audio_files) * test_size)
    n_train = len(audio_files) - n_test

    random.seed(seed)
    random.shuffle(audio_files)

    train_files, test_files = audio_files[:n_train], audio_files[n_train:]

    print(f"Train files: {len(train_files)}")
    print(f"Test files: {len(test_files)}")
    continue_ = input("Continue [yn]? ") or "n"

    if continue_ != "y":
        return

    for split, files in (
        ("train", train_files), ("test", test_files)
    ):
        for file in tqdm(files):
            # Symlink into a sibling folder named "<audio_folder>-<split>".
            out_file = audio_folder.parent / f"{audio_folder.name}-{split}" / Path(file).name
            out_file.parent.mkdir(exist_ok=True, parents=True)
            os.symlink(file, out_file)

        # Record the split membership next to the source folder.
        with open(Path(audio_folder) / f"{split}.json", "w") as f:
            json.dump([str(f) for f in files], f)


if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        train_test_split()
|
vampnet/scripts/utils/split_long_audio_file.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import argbind
|
| 3 |
+
|
| 4 |
+
import audiotools as at
|
| 5 |
+
import tqdm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@argbind.bind(without_prefix=True)
def split_long_audio_file(
    file: str = None,
    max_chunk_size_s: int = 60*10
):
    """Split one long audio file into 50%-overlapping chunks of at most
    `max_chunk_size_s` seconds, written to a sibling directory named after
    the file's stem. Returns that directory.

    Note: mkdir() is intentionally strict — it fails if the output
    directory already exists, so reruns never overwrite prior output.
    """
    file = Path(file)
    output_dir = file.parent / file.stem
    output_dir.mkdir()

    sig = at.AudioSignal(file)

    # Windows overlap by half a chunk (hop = window / 2).
    # Renamed the loop variable (was `sig`, shadowing the signal above).
    for i, chunk in tqdm.tqdm(enumerate(sig.windows(
        window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
        preprocess=True))
    ):
        chunk.write(output_dir / f"{i}.wav")

    print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")

    return output_dir

if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        split_long_audio_file()
|
vampnet/scripts/utils/stage.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import argbind
|
| 6 |
+
import rich
|
| 7 |
+
from audiotools.ml import Experiment
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@argbind.bind(without_prefix=True)
def run(
    run_dir: str = os.getenv("PATH_TO_RUNS", "runs"),
    name: str = None,
    recent: bool = False,
):
    """Create a code snapshot for an experiment under `run_dir`.

    With `recent=True`, picks the most recently modified run directory
    instead of requiring an explicit `name`.
    """
    if recent:
        # Most recently modified directory wins.
        candidates = sorted(Path(run_dir).iterdir(), key=os.path.getmtime)
        candidates = [p.name for p in candidates if p.is_dir()]
        if candidates:
            name = candidates[-1]

    with Experiment(run_dir, name) as exp:
        exp.snapshot()
        rich.print(f"Created a snapshot of {exp.parent_directory} at {exp.exp_dir}")


if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        run()
|
vampnet/scripts/utils/visualize_embeddings.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TODO: train a linear probe
|
| 3 |
+
usage:
|
| 4 |
+
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_audio /path/to/audio/labels --output_dir /path/to/output
|
| 5 |
+
"""
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import List
|
| 8 |
+
|
| 9 |
+
import audiotools as at
|
| 10 |
+
from audiotools import AudioSignal
|
| 11 |
+
import argbind
|
| 12 |
+
import torch
|
| 13 |
+
import numpy as np
|
| 14 |
+
import zipfile
|
| 15 |
+
import json
|
| 16 |
+
|
| 17 |
+
from vampnet.interface import Interface
|
| 18 |
+
import tqdm
|
| 19 |
+
|
| 20 |
+
# bind the Interface to argbind
|
| 21 |
+
Interface = argbind.bind(Interface)
|
| 22 |
+
|
| 23 |
+
PREFIX = "vampnet-embedding-"
|
| 24 |
+
|
| 25 |
+
DEBUG = False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def smart_plotly_export(fig, save_path: Path):
    """Export a Plotly figure according to `save_path`'s extension.

    Supported suffixes:
      - ".html": writes an interactive HTML file (returns None).
      - ".bytes": returns the PNG image as bytes.
      - ".numpy": returns the rendered PNG as a numpy array.
      - ".jpeg" / ".png" / ".webp": writes a static image file.

    Raises:
        ValueError: for any other suffix.
    """
    img_format = save_path.suffix[1:]
    if img_format == "html":
        fig.write_html(save_path)
    elif img_format == "bytes":
        return fig.to_image(format="png")
    # TODO: come back and make this prettier
    elif img_format == "numpy":
        import io
        from PIL import Image

        def plotly_fig2array(fig):
            # convert Plotly fig to an array
            fig_bytes = fig.to_image(format="png", width=1200, height=700)
            buf = io.BytesIO(fig_bytes)
            img = Image.open(buf)
            return np.asarray(img)

        return plotly_fig2array(fig)
    # BUG FIX: was `img_format == "jpeg" or "png" or "webp"`, which is always
    # truthy ("png" is a non-empty string), so every unknown format fell into
    # this branch and the ValueError below was unreachable.
    elif img_format in ("jpeg", "png", "webp"):
        fig.write_image(save_path)
    else:
        raise ValueError("invalid image format")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def dim_reduce(annotated_embeddings, layer, output_dir, n_components=3, method="tsne"):
    """
    Dimensionality reduction for visualization.

    Projects the embeddings of `annotated_embeddings` at one model layer
    down to 2 or 3 dimensions and saves an interactive Plotly scatter under
    `output_dir` (via `smart_plotly_export`).

    parameters:
        annotated_embeddings: objects with .label, .filename, .embedding
            (embedding indexed as [layer, features])
        layer (int): which model layer's embeddings to project
        output_dir (Path): directory for the output figure
        n_components (int): 2 or 3 output dimensions
        method (str): "umap", "tsne", or "pca"
    returns:
        whatever smart_plotly_export returns for an .html path
    """
    import pandas as pd
    import plotly.express as px

    fig_name = f"vampnet-embeddings-layer={layer}"
    fig_title = f"{fig_name}_{method}"
    save_path = (output_dir / fig_name).with_suffix(".html")

    # Lazily import only the chosen reducer.
    if method == "umap":
        from umap import UMAP
        reducer = UMAP(n_components=n_components)
    elif method == "tsne":
        from sklearn.manifold import TSNE
        reducer = TSNE(n_components=n_components)
    elif method == "pca":
        from sklearn.decomposition import PCA
        reducer = PCA(n_components=n_components)
    else:
        raise ValueError(f"invalid method: {method}")

    labels = [emb.label for emb in annotated_embeddings]
    names = [emb.filename for emb in annotated_embeddings]
    embs = [emb.embedding for emb in annotated_embeddings]

    # Select one layer's features across all samples, then project.
    embs_at_layer = np.stack(embs)[:, layer, :]
    projs = reducer.fit_transform(embs_at_layer)

    df = pd.DataFrame(
        {
            "label": labels,
            "name": names,
            "x": projs[:, 0],
            "y": projs[:, 1],
        }
    )
    if n_components == 2:
        fig = px.scatter(
            df, x="x", y="y", color="label", hover_name="name", title=fig_title,
        )
    elif n_components == 3:
        df['z'] = projs[:, 2]
        fig = px.scatter_3d(
            df, x="x", y="y", z="z", color="label", hover_name="name", title=fig_title
        )
    else:
        raise ValueError(f"can't plot {n_components} components")

    fig.update_traces(
        marker=dict(size=6, line=dict(width=1, color="DarkSlateGrey")),
        selector=dict(mode="markers"),
    )

    return smart_plotly_export(fig, save_path)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# per JukeMIR, we want the emebddings from the middle layer?
def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
    """Embed an audio signal with the coarse VampNet model.

    Encodes the signal with the codec, runs the coarse transformer with
    ``return_activations=True``, and mean-pools each layer's activations
    over time.

    Args:
        sig: input audio signal.
        interface: vampnet Interface providing the codec and coarse model.
        layer: only used to sanity-check that this index is in range; the
            returned tensor still contains *all* layers (callers slice later).

    Returns:
        Tensor of shape (num_layers, n_dims), e.g. (20, 768).
    """
    with torch.inference_mode():
        # preprocess the signal
        sig = interface.preprocess(sig)

        # get the coarse vampnet model
        vampnet = interface.coarse

        # get the tokens
        z = interface.encode(sig)[:, : vampnet.n_codebooks, :]
        z_latents = vampnet.embedding.from_codes(z, interface.codec)

        # do a forward pass through the model, get the embeddings
        _z, embeddings = vampnet(z_latents, return_activations=True)
        # activations come back as [layer, batch, time, n_dims], e.g. [20, 1, ~600, 768]

        # squeeze batch dim (1 bc layer should be dim 0)
        assert (
            embeddings.shape[1] == 1
        ), f"expected batch dim to be 1, got {embeddings.shape[1]}"  # FIX: message used shape[0] while the check is on shape[1]
        embeddings = embeddings.squeeze(1)

        num_layers = embeddings.shape[0]
        assert (
            layer < num_layers
        ), f"layer {layer} is out of bounds for model with {num_layers} layers"

        # do meanpooling over the time dimension
        embeddings = embeddings.mean(dim=-2)
        # [20, 768]

        # return the embeddings
        return embeddings
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
from dataclasses import dataclass, fields
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@dataclass
class AnnotatedEmbedding:
    """A labeled embedding for one audio file, serializable to a zip archive."""

    label: str          # class/folder label the file belongs to
    filename: str       # original audio file name
    embedding: np.ndarray  # (num_layers, features) activation array

    def save(self, path):
        """Save the Embedding object to a given path as a zip file."""
        # everything except the array becomes a small json sidecar
        metadata = {
            field.name: getattr(self, field.name)
            for field in fields(self)
            if field.name != "embedding"
        }
        with zipfile.ZipFile(path, "w") as archive:
            with archive.open("embedding.npy", "w") as npy_file:
                np.save(npy_file, self.embedding)
            with archive.open("data.json", "w") as json_file:
                json_file.write(json.dumps(metadata).encode("utf-8"))

    @classmethod
    def load(cls, path):
        """Load the Embedding object from a given zip path."""
        with zipfile.ZipFile(path, "r") as archive:
            with archive.open("embedding.npy") as npy_file:
                array = np.load(npy_file)
            with archive.open("data.json") as json_file:
                metadata = json.loads(json_file.read().decode("utf-8"))
        return cls(embedding=array, **metadata)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
@argbind.bind(without_prefix=True)
def main(
    path_to_audio: str = None,
    cache_dir: str = "./.emb_cache",
    output_dir: str = "./vampnet_embeddings",
    layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],  # NOTE(review): mutable default — harmless here (never mutated), but consider a tuple
    method: str = "tsne",
    n_components: int = 2,
):
    """Embed a labeled audio folder with VampNet and plot dim-reduced projections.

    Expects `path_to_audio` to contain one subfolder per label. Per-file
    embeddings are cached under `cache_dir`; one interactive plot per
    requested layer is written to `output_dir`.
    """
    path_to_audio = Path(path_to_audio)
    assert path_to_audio.exists(), f"{path_to_audio} does not exist"

    cache_dir = Path(cache_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # load our interface
    # argbind will automatically load the default config,
    interface = Interface()

    # we expect path_to_audio to consist of a folder for each label, so let's get the list of labels
    labels = [Path(x).name for x in path_to_audio.iterdir() if x.is_dir()]
    print(f"Found {len(labels)} labels")
    print(f"labels: {labels}")

    # collect audio files, labels, and embeddings
    annotated_embeddings = []
    for label in labels:
        audio_files = list(at.util.find_audio(path_to_audio / label))
        print(f"Found {len(audio_files)} audio files for label {label}")

        for audio_file in tqdm.tqdm(audio_files, desc=f"embedding label {label}"):
            # check if we have a cached embedding for this file
            cached_path = cache_dir / f"{label}_{audio_file.stem}.emb"
            if cached_path.exists():
                # if so, load it
                if DEBUG:
                    print(f"loading cached embedding for {cached_path.stem}")
                embedding = AnnotatedEmbedding.load(cached_path)
            else:
                try:
                    sig = AudioSignal(audio_file)
                except Exception as e:
                    # unreadable/corrupt audio is skipped, not fatal
                    print(f"failed to load {audio_file.name} with error {e}")
                    print(f"skipping {audio_file.name}")
                    continue

                # gets the embedding
                emb = vampnet_embed(sig, interface).cpu().numpy()

                # create an embedding we can save/load
                embedding = AnnotatedEmbedding(
                    label=label, filename=audio_file.name, embedding=emb
                )

                # cache the embeddings
                cached_path.parent.mkdir(exist_ok=True, parents=True)
                embedding.save(cached_path)
            annotated_embeddings.append(embedding)

    # now, let's do a dim reduction on the embeddings and visualize them.
    for layer in tqdm.tqdm(layers, desc="dim reduction"):
        dim_reduce(
            annotated_embeddings,
            layer,
            output_dir=output_dir,
            n_components=n_components,
            method=method,
        )
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
if __name__ == "__main__":
    # parse CLI flags into an argbind namespace, then run main() under that scope
    args = argbind.parse_args()
    with argbind.scope(args):
        main()
|
vampnet/scripts/utils/xeno-canto-dl.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Bulk-download high-quality (q=A) bird recordings from xeno-canto via xenopy."""
from pathlib import Path

from xenopy import Query

# Where downloads land; xenopy creates one subfolder per species
# (name with spaces stripped), which is what the resume check below relies on.
OUTDIR = "/media/CHONK/hugo/xeno-canto-full/"

SPECIES = [
    "American Robin",
    "Northern Cardinal",
    "Mourning Dove",
    "American Crow",
    "Baltimore Oriole",
    "Blue Jay",
    "Eastern Bluebird",
    "House Finch",
    "American Goldfinch",
    "House Sparrow",
    "Song Sparrow",
    "Tufted Titmouse",
    "White-breasted Nuthatch",
    "European Starling",
    "American Redstart",
    "Red-winged Blackbird",
    "Brown-headed Cowbird",
    "Common Grackle",
    "Boat-tailed Grackle",
    "Common Yellowthroat",
    "Northern Mockingbird",
    "Carolina Wren",
    "Eastern Meadowlark",
    "Chipping Sparrow",
    "Tree Swallow",
    "Barn Swallow",
    "Cliff Swallow",
    "Pine Siskin",
    "Indigo Bunting",
    "Eastern Towhee",
    "Carolina Chickadee",
    "Great Crested Flycatcher",
    "Eastern Wood-Pewee",
    "Ovenbird",
    "Northern Flicker",
    "Red-eyed Vireo",
    "American Woodcock",
    "Eastern Phoebe",
    "Downy Woodpecker",
    "Scarlet Tanager",
    "Yellow Warbler",
    "White-eyed Vireo",
    "Common Loon",
    "White-throated Sparrow",
    "Yellow-throated Vireo",
    "Great Blue Heron",
    "Belted Kingfisher",
    "Pied-billed Grebe",
    "Wild Turkey",
    "Wood Thrush",
    "Rose-breasted Grosbeak",
    "Field Sparrow",
    "Hooded Warbler",
    "Northern Parula",
    "Chestnut-sided Warbler",
    "Blue-winged Warbler",
    "Red-bellied Woodpecker",
    "Yellow-billed Cuckoo",
    "Gray Catbird",
    "Northern Saw-whet Owl",
    "Osprey",
    "Common Nighthawk",
    "Broad-winged Hawk",
    "Black-throated Green Warbler",
    "Great Horned Owl",
    "Common Raven",
    "Barred Owl",
    "Canada Warbler",
    "Magnolia Warbler",
    "Black-and-white Warbler",
    "Eastern Kingbird",
    "Swainson's Thrush",
    "Worm-eating Warbler",
    "Prairie Warbler",
    "Baltimore Oriole",
    "Black-throated Blue Warbler",
    "Louisiana Waterthrush",
    "Blackburnian Warbler",
    "Black-capped Chickadee",
    "Cerulean Warbler",
    "Red-shouldered Hawk",
    "Cooper's Hawk",
    "Yellow-throated Warbler",
    "Blue-headed Vireo",
    "Blackpoll Warbler",
    "Ruffed Grouse",
    "Kentucky Warbler",
    "Hermit Thrush",
    "Cedar Waxwing",
    "Eastern Screech-Owl",
    "Northern Goshawk",
    "Green Heron",
    "Red-tailed Hawk",
    "Black Vulture",
    "Hairy Woodpecker",
    "Golden-crowned Kinglet",
    "Ruby-crowned Kinglet",
    "Bicknell's Thrush",
    "Blue-gray Gnatcatcher",
    "Veery",
    "Pileated Woodpecker",
    "Purple Finch",
    "White-crowned Sparrow",
    "Snow Bunting",
    "Pine Grosbeak",
    "American Tree Sparrow",
    "Dark-eyed Junco",
    "Snowy Owl",
    "White-winged Crossbill",
    "Red Crossbill",
    "Common Redpoll",
    "Northern Shrike",
    "Northern Harrier",
    "Rough-legged Hawk",
    "Long-eared Owl",
    "Evening Grosbeak",
    "Northern Pintail",
    "American Black Duck",
    "Mallard",
    "Canvasback",
    "Redhead",
    "Ring-necked Duck",
    "Greater Scaup",
    "Lesser Scaup",
    "Bufflehead",
    "Common Goldeneye",
    "Hooded Merganser",
    "Common Merganser",
    "Red-breasted Merganser",
    "Ruddy Duck",
    "Wood Duck",
    "Gadwall",
    "American Wigeon",
    "Northern Shoveler",
    "Green-winged Teal",
    "Blue-winged Teal",
    "Cinnamon Teal",
    "Ringed Teal",
    "Cape Teal",
    "Northern Fulmar",
    "Yellow-billed Loon",
    "Red-throated Loon",
    "Arctic Loon",
    "Pacific Loon",
    "Horned Grebe",
    "Red-necked Grebe",
    "Eared Grebe",
    "Western Grebe",
    "Clark's Grebe",
    "Double-crested Cormorant",
    "Pelagic Cormorant",
    "Great Cormorant",
    "American White Pelican",
    "Brown Pelican",
    "Brandt's Cormorant",
    "Least Bittern",
    "Great Egret",
    "Snowy Egret",
    "Little Blue Heron",
    "Tricolored Heron",
    "Reddish Egret",
    "Black-crowned Night-Heron",
    "Yellow-crowned Night-Heron",
    "White Ibis",
    "Glossy Ibis",
    "Roseate Spoonbill",
    "Wood Stork",
    "Black-bellied Whistling-Duck",
    "Fulvous Whistling-Duck",
    "Greater White-fronted Goose",
    "Snow Goose",
    "Ross's Goose",
    "Canada Goose",
    "Brant",
    "Mute Swan",
    "Tundra Swan",
    "Whooper Swan",
    "Sandhill Crane",
    "Black-necked Stilt",
    "American Avocet",
    "Northern Jacana",
    "Greater Yellowlegs",
    "Lesser Yellowlegs",
    "Willet",
    "Spotted Sandpiper",
    "Upland Sandpiper",
    "Whimbrel",
    "Long-billed Curlew",
    "Marbled Godwit",
    "Ruddy Turnstone",
    "Red Knot",
    "Sanderling",
    "Semipalmated Sandpiper",
    "Western Sandpiper",
    "Least Sandpiper",
    "White-rumped Sandpiper",
    "Baird's Sandpiper",
    "Pectoral Sandpiper",
    "Dunlin",
    "Buff-breasted Sandpiper",
    "Short-billed Dowitcher",
    "Long-billed Dowitcher",
    "Common Snipe",
    "American Woodcock",
    "Wilson's Phalarope",
    "Red-necked Phalarope",
    "Red Phalarope"
]


def remove_spaces(s):
    """Strip all spaces from a species name (matches xenopy's output folder names)."""
    return s.replace(" ", "")


# dict.fromkeys keeps order while skipping duplicate entries in SPECIES
# ("Baltimore Oriole" and "American Woodcock" each appear twice in the list).
for species in dict.fromkeys(SPECIES):
    # resume support: skip species whose output folder already exists
    if Path(OUTDIR + remove_spaces(species)).exists():
        continue
    try:
        q = Query(
            name=species, q="A", length="10-30",
        )

        # retrieve metadata
        metafiles = q.retrieve_meta(verbose=True)
        # retrieve recordings
        q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir=OUTDIR)

    except Exception as e:
        # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
        # and hid the failure cause entirely. Report and move on to the next species.
        print("Failed to download " + species + f": {e}")
        continue
|
vampnet/setup.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import find_packages
from setuptools import setup

# PyPI long description comes straight from the README
with open("README.md") as f:
    long_description = f.read()

setup(
    name="vampnet",
    version="0.0.1",
    classifiers=[
        "Intended Audience :: Developers",
        "Natural Language :: English",
        "Programming Language :: Python :: 3.7",
        "Topic :: Artistic Software",
        "Topic :: Multimedia",
        "Topic :: Multimedia :: Sound/Audio",
        "Topic :: Multimedia :: Sound/Audio :: Editors",
        "Topic :: Software Development :: Libraries",
    ],
    description="Generative Music Modeling.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="Hugo Flores García, Prem Seetharaman",
    author_email="hfgacrcia@descript.com",  # NOTE(review): "hfgacrcia" looks like a typo for "hfgarcia" — confirm before changing
    url="https://github.com/hugofloresgarcia/vampnet",
    license="MIT",
    packages=find_packages(),
    setup_requires=[
        "Cython",
    ],
    install_requires=[
        "Cython",  # Added by WAM because it seems to be needed by this repo?
        "torch",
        "pydantic==2.10.6",
        "argbind>=0.3.2",
        "numpy<1.24",
        "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
        "lac @ git+https://github.com/hugofloresgarcia/lac.git",
        "descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git",
        "gradio",
        "loralib",
        "torch_pitch_shift",
        "plotly",  # Added by WAM for clustering (see https://github.com/hugofloresgarcia/vampnet/issues/20)
        "pyharp",

    ],
)
|
vampnet/vampnet/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from . import modules
from . import scheduler
from .interface import Interface  # re-exported so callers can do `vampnet.Interface`

__version__ = "0.0.1"  # matches the version declared in setup.py
|
vampnet/vampnet/beats.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import warnings
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
from typing import List
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
from typing import Union
|
| 10 |
+
|
| 11 |
+
import librosa
|
| 12 |
+
import torch
|
| 13 |
+
import numpy as np
|
| 14 |
+
from audiotools import AudioSignal
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logging.basicConfig(level=logging.INFO)
|
| 18 |
+
|
| 19 |
+
###################
|
| 20 |
+
# beat sync utils #
|
| 21 |
+
###################
|
| 22 |
+
|
| 23 |
+
# Maps aggregator names to the numpy reduction used for beat-synced pooling.
AGGREGATOR_REGISTRY = {
    "mean": np.mean,
    "median": np.median,
    "max": np.max,
    "min": np.min,
}


def list_aggregators() -> list:
    """Return the names of all available feature aggregators."""
    return [*AGGREGATOR_REGISTRY]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
class TimeSegment:
    """A span of time in seconds, from `start` to `end`."""

    start: float
    end: float

    @property
    def duration(self):
        """Length of the segment in seconds."""
        return self.end - self.start

    def __str__(self) -> str:
        return f"{self.start} - {self.end}"

    def find_overlapping_segment(
        self, segments: List["TimeSegment"]
    ) -> Union["TimeSegment", None]:
        """Find the first segment that overlaps with this segment, or None if no segment overlaps"""
        # "overlaps" here means the candidate fully contains this segment
        return next(
            (
                candidate
                for candidate in segments
                if candidate.start <= self.start and candidate.end >= self.end
            ),
            None,
        )
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def mkdir(path: Union[Path, str]) -> Path:
    """Ensure `path` exists as a directory (creating parents) and return it as a Path."""
    result = Path(path)
    result.mkdir(parents=True, exist_ok=True)
    return result
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
###################
|
| 65 |
+
# beat data #
|
| 66 |
+
###################
|
| 67 |
+
@dataclass
class BeatSegment(TimeSegment):
    """A beat-to-beat time span; `downbeat` marks whether its start lands on a downbeat."""

    downbeat: bool = False  # if there's a downbeat on the start_time
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class Beats:
    """Container for beat and downbeat times (in seconds) of one recording.

    Also provides beat-synchronous pooling of time-varying features and
    json (de)serialization.
    """

    def __init__(self, beat_times, downbeat_times):
        # coerce numpy arrays to plain lists so to_json() stays trivially serializable
        if isinstance(beat_times, np.ndarray):
            beat_times = beat_times.tolist()
        if isinstance(downbeat_times, np.ndarray):
            downbeat_times = downbeat_times.tolist()
        self._beat_times = beat_times          # beat times in seconds
        self._downbeat_times = downbeat_times  # downbeat times in seconds
        self._use_downbeats = False            # toggled via use_downbeats()

    def use_downbeats(self, use_downbeats: bool = True):
        """use downbeats instead of beats when calling beat_times"""
        self._use_downbeats = use_downbeats

    def beat_segments(self, signal: AudioSignal) -> List[BeatSegment]:
        """
        segments a song into time segments corresponding to beats.
        the first segment starts at 0 and ends at the first beat time.
        the last segment starts at the last beat time and ends at the end of the song.
        """
        beat_times = self._beat_times.copy()
        downbeat_times = self._downbeat_times
        # pad with the start and end of the signal so segments cover the whole song
        beat_times.insert(0, 0)
        beat_times.append(signal.signal_duration)

        # indices (into the padded beat_times) of beats that are also downbeats;
        # relies on exact float equality between beat and downbeat times
        downbeat_ids = np.intersect1d(beat_times, downbeat_times, return_indices=True)[
            1
        ]
        is_downbeat = [
            True if i in downbeat_ids else False for i in range(len(beat_times))
        ]
        segments = [
            BeatSegment(start_time, end_time, downbeat)
            for start_time, end_time, downbeat in zip(
                beat_times[:-1], beat_times[1:], is_downbeat
            )
        ]
        return segments

    def get_beats(self) -> np.ndarray:
        """returns an array of beat times, in seconds
        if downbeats is True, returns an array of downbeat times, in seconds
        """
        return np.array(
            self._downbeat_times if self._use_downbeats else self._beat_times
        )

    @property
    def beat_times(self) -> np.ndarray:
        """return beat times"""
        return np.array(self._beat_times)

    @property
    def downbeat_times(self) -> np.ndarray:
        """return downbeat times"""
        return np.array(self._downbeat_times)

    def beat_times_to_feature_frames(
        self, signal: AudioSignal, features: np.ndarray
    ) -> np.ndarray:
        """convert beat times to frames, given an array of time-varying features"""
        beat_times = self.get_beats()
        # seconds -> samples -> fraction of signal -> frame index into `features`
        beat_frames = (
            beat_times * signal.sample_rate / signal.signal_length * features.shape[-1]
        ).astype(np.int64)
        return beat_frames

    def sync_features(
        self, feature_frames: np.ndarray, features: np.ndarray, aggregate="median"
    ) -> np.ndarray:
        """sync features to beats"""
        if aggregate not in AGGREGATOR_REGISTRY:
            raise ValueError(f"unknown aggregation method {aggregate}")

        # pool each inter-beat span of frames with the chosen numpy reduction
        return librosa.util.sync(
            features, feature_frames, aggregate=AGGREGATOR_REGISTRY[aggregate]
        )

    def to_json(self) -> dict:
        """return beats and downbeats as json"""
        # (returns a plain dict ready for json.dump, not a json string)
        return {
            "beats": self._beat_times,
            "downbeats": self._downbeat_times,
            "use_downbeats": self._use_downbeats,
        }

    @classmethod
    def from_dict(cls, data: dict):
        """load beats and downbeats from json"""
        inst = cls(data["beats"], data["downbeats"])
        inst.use_downbeats(data["use_downbeats"])
        return inst

    def save(self, output_dir: Path):
        """save beats and downbeats to json"""
        mkdir(output_dir)
        with open(output_dir / "beats.json", "w") as f:
            json.dump(self.to_json(), f)

    @classmethod
    def load(cls, input_dir: Path):
        """load beats and downbeats from json"""
        beats_file = Path(input_dir) / "beats.json"
        with open(beats_file, "r") as f:
            data = json.load(f)
        return cls.from_dict(data)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
###################
|
| 181 |
+
# beat tracking #
|
| 182 |
+
###################
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class BeatTracker:
    """Abstract base class for beat trackers; subclasses implement extract_beats()."""

    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
        """extract beats from an audio signal"""
        raise NotImplementedError

    def __call__(self, signal: AudioSignal) -> Beats:
        """extract beats from an audio signal
        NOTE: if the first beat (and/or downbeat) is detected within the first 100ms of the audio,
        it is discarded. This is to avoid empty bins with no beat synced features in the first beat.
        (NOTE(review): that 100ms trimming is not visible in this base class — confirm it
        happens inside a subclass's extract_beats.)
        Args:
            signal (AudioSignal): signal to beat track
        Returns:
            Beats: container wrapping the extracted beat and downbeat times
        """
        beats, downbeats = self.extract_beats(signal)
        return Beats(beats, downbeats)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class WaveBeat(BeatTracker):
    """Beat tracker backed by the WaveBeat dsTCN model."""

    def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
        # imported lazily so the wavebeat dependency is only required when used
        from wavebeat.dstcn import dsTCNModel

        model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device))
        model.eval()

        self.device = device
        self.model = model

    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
        """returns beat and downbeat times, in seconds"""
        # extract beats
        beats, downbeats = self.model.predict_beats_from_array(
            audio=signal.audio_data.squeeze(0),
            sr=signal.sample_rate,
            use_gpu=self.device != "cpu",
        )

        return beats, downbeats
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class MadmomBeats(BeatTracker):
    """Placeholder for a madmom-based beat tracker; not implemented yet."""

    def __init__(self):
        raise NotImplementedError

    def extract_beats(self, signal: AudioSignal) -> Tuple[np.ndarray, np.ndarray]:
        """returns beat and downbeat times, in seconds"""
        pass
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# Known beat tracker implementations, keyed by CLI-friendly name.
BEAT_TRACKER_REGISTRY = {
    "wavebeat": WaveBeat,
    "madmom": MadmomBeats,
}


def list_beat_trackers() -> list:
    """Return the names of all registered beat trackers."""
    return [*BEAT_TRACKER_REGISTRY]


def load_beat_tracker(beat_tracker: str, **kwargs) -> BeatTracker:
    """Instantiate the named beat tracker, forwarding `kwargs` to its constructor."""
    try:
        tracker_cls = BEAT_TRACKER_REGISTRY[beat_tracker]
    except KeyError:
        raise ValueError(
            f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}"
        ) from None
    return tracker_cls(**kwargs)
|
vampnet/vampnet/interface.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
from audiotools import AudioSignal
|
| 8 |
+
import tqdm
|
| 9 |
+
|
| 10 |
+
from .modules.transformer import VampNet
|
| 11 |
+
from .beats import WaveBeat
|
| 12 |
+
from .mask import *
|
| 13 |
+
|
| 14 |
+
# from dac.model.dac import DAC
|
| 15 |
+
from lac.model.lac import LAC as DAC
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def signal_concat(
    audio_signals: list,
):
    """Concatenate a list of AudioSignals along the time axis.

    The result takes its sample rate from the first signal; callers are
    responsible for passing signals with matching sample rates and channel
    counts (assumed — not checked here).
    """
    joined = torch.cat([sig.audio_data for sig in audio_signals], dim=-1)
    return AudioSignal(joined, sample_rate=audio_signals[0].sample_rate)
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _load_model(
    ckpt: str,
    lora_ckpt: str = None,
    device: str = "cpu",
    chunk_size_s: int = 10,
):
    """Load a VampNet checkpoint, optionally overlaying LoRA weights.

    The model is moved to `device`, put in eval mode, and tagged with
    `chunk_size_s` (seconds of audio processed per generation chunk).

    NOTE(review): if `lora_ckpt` is given but missing on disk, this prompts on
    stdin — interactive behavior inside a loader; confirm this is acceptable
    for non-interactive deployments.
    """
    # we need to set strict to False if the model has lora weights to add later
    model = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False)

    # load lora weights if needed
    if lora_ckpt is not None:
        if not Path(lora_ckpt).exists():
            should_cont = input(
                f"lora checkpoint {lora_ckpt} does not exist. continue? (y/n) "
            )
            if should_cont != "y":
                raise Exception("aborting")
        else:
            # strict=False: LoRA state dicts only cover a subset of parameters
            model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False)

    model.to(device)
    model.eval()
    # stash the chunk size on the model so Interface can chunk long inputs
    model.chunk_size_s = chunk_size_s
    return model
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Interface(torch.nn.Module):
    """High-level entry point tying together the codec (DAC/LAC), the coarse
    VampNet, an optional coarse-to-fine VampNet, and an optional WaveBeat
    beat tracker.

    All generation methods operate on codec token tensors of shape
    (batch, n_codebooks, seq). Time units: "s" = seconds, "t" = token steps.
    """

    def __init__(
        self,
        coarse_ckpt: str = None,
        coarse_lora_ckpt: str = None,
        coarse2fine_ckpt: str = None,
        coarse2fine_lora_ckpt: str = None,
        codec_ckpt: str = None,
        wavebeat_ckpt: str = None,
        device: str = "cpu",
        coarse_chunk_size_s: int = 10,
        coarse2fine_chunk_size_s: int = 3,
    ):
        """Load codec + models onto `device`.

        `codec_ckpt` and `coarse_ckpt` are required (asserted below);
        coarse2fine and wavebeat are optional and their attributes are set to
        None when absent.
        """
        super().__init__()
        assert codec_ckpt is not None, "must provide a codec checkpoint"
        self.codec = DAC.load(Path(codec_ckpt))
        self.codec.eval()
        self.codec.to(device)

        assert coarse_ckpt is not None, "must provide a coarse checkpoint"
        self.coarse = _load_model(
            ckpt=coarse_ckpt,
            lora_ckpt=coarse_lora_ckpt,
            device=device,
            chunk_size_s=coarse_chunk_size_s,
        )

        # check if we have a coarse2fine ckpt
        if coarse2fine_ckpt is not None:
            self.c2f = _load_model(
                ckpt=coarse2fine_ckpt,
                lora_ckpt=coarse2fine_lora_ckpt,
                device=device,
                chunk_size_s=coarse2fine_chunk_size_s,
            )
        else:
            self.c2f = None

        if wavebeat_ckpt is not None:
            print(f"loading wavebeat from {wavebeat_ckpt}")
            self.beat_tracker = WaveBeat(wavebeat_ckpt)
            self.beat_tracker.model.to(device)
        else:
            self.beat_tracker = None

        self.device = device

    def lora_load(
        self,
        coarse_ckpt: str = None,
        c2f_ckpt: str = None,
        full_ckpts: bool = False,
    ):
        """Hot-swap model weights.

        With `full_ckpts=True` the checkpoints are treated as complete models
        and reloaded via `_load_model` (keeping each model's chunk size).
        Otherwise they are treated as partial (LoRA-style) state dicts and
        merged with strict=False; models are bounced through CPU for loading.
        """
        if full_ckpts:
            if coarse_ckpt is not None:
                self.coarse = _load_model(
                    ckpt=coarse_ckpt,
                    device=self.device,
                    chunk_size_s=self.coarse.chunk_size_s,
                )
            if c2f_ckpt is not None:
                self.c2f = _load_model(
                    ckpt=c2f_ckpt,
                    device=self.device,
                    chunk_size_s=self.c2f.chunk_size_s,
                )
        else:
            if coarse_ckpt is not None:
                self.coarse.to("cpu")
                state_dict = torch.load(coarse_ckpt, map_location="cpu")
                print(f"loading coarse from {coarse_ckpt}")
                self.coarse.load_state_dict(state_dict, strict=False)
                self.coarse.to(self.device)
            if c2f_ckpt is not None:
                self.c2f.to("cpu")
                state_dict = torch.load(c2f_ckpt, map_location="cpu")
                print(f"loading c2f from {c2f_ckpt}")
                self.c2f.load_state_dict(state_dict, strict=False)
                self.c2f.to(self.device)

    def s2t(self, seconds: float):
        """seconds to tokens (ceil); accepts scalars or numpy arrays."""
        if isinstance(seconds, np.ndarray):
            return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
        else:
            return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)

    def s2t2s(self, seconds: float):
        """seconds to tokens to seconds (snaps a duration to the token grid)."""
        return self.t2s(self.s2t(seconds))

    def t2s(self, tokens: int):
        """tokens to seconds."""
        return tokens * self.codec.hop_length / self.codec.sample_rate

    def to(self, device):
        """Move codec and all loaded models to `device`; returns self."""
        self.device = device
        self.coarse.to(device)
        self.codec.to(device)

        if self.c2f is not None:
            self.c2f.to(device)

        if self.beat_tracker is not None:
            self.beat_tracker.model.to(device)
        return self

    def to_signal(self, z: torch.Tensor):
        """Decode a token tensor (batch, n_codebooks, seq) to an AudioSignal."""
        return self.coarse.to_signal(z, self.codec)

    def preprocess(self, signal: AudioSignal):
        """Normalize a signal for encoding: codec sample rate, mono,
        -24 dB loudness, peak-limited to 1.0. Operates on a clone."""
        signal = (
            signal.clone()
            .resample(self.codec.sample_rate)
            .to_mono()
            .normalize(-24)
            .ensure_max_of_audio(1.0)
        )
        return signal

    @torch.inference_mode()
    def encode(self, signal: AudioSignal):
        """Encode an AudioSignal into codec tokens (batch, n_codebooks, seq)."""
        signal = self.preprocess(signal).to(self.device)
        z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
        return z

    def snap_to_beats(
        self,
        signal: AudioSignal
    ):
        """Trim a signal so it starts on the first detected beat and ends on
        the last one. Requires a loaded beat tracker."""
        assert hasattr(self, "beat_tracker"), "No beat tracker loaded"
        beats, downbeats = self.beat_tracker.extract_beats(signal)

        # trim the signal around the first and last beat times
        samples_begin = int(beats[0] * signal.sample_rate )
        samples_end = int(beats[-1] * signal.sample_rate)
        print(beats[0])
        signal = signal.clone().trim(samples_begin, signal.length - samples_end)

        return signal

    def make_beat_mask(self,
        signal: AudioSignal,
        before_beat_s: float = 0.0,
        after_beat_s: float = 0.02,
        mask_downbeats: bool = True,
        mask_upbeats: bool = True,
        downbeat_downsample_factor: int = None,
        beat_downsample_factor: int = None,
        dropout: float = 0.0,
        invert: bool = True,
    ):
        """make a beat synced mask. that is, make a mask that
        places 1s at and around the beat, and 0s everywhere else.

        `before_beat_s`/`after_beat_s` control the window (in seconds) kept
        around each beat; `dropout` randomly re-masks some of those positions;
        `invert=True` flips the mask so beats are *preserved* (0 = keep).
        Returns a (1, n_codebooks, seq) long mask.
        """
        assert self.beat_tracker is not None, "No beat tracker loaded"

        # get the beat times
        beats, downbeats = self.beat_tracker.extract_beats(signal)

        # get the beat indices in z
        beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)

        # remove downbeats from beats
        beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))]
        beats_z = beats_z.tolist()
        downbeats_z = downbeats_z.tolist()

        # make the mask
        seq_len = self.s2t(signal.duration)
        mask = torch.zeros(seq_len, device=self.device)

        mask_b4 = self.s2t(before_beat_s)
        mask_after = self.s2t(after_beat_s)

        if beat_downsample_factor is not None:
            if beat_downsample_factor < 1:
                raise ValueError("mask_beat_downsample_factor must be >= 1 or None")
        else:
            beat_downsample_factor = 1

        if downbeat_downsample_factor is not None:
            if downbeat_downsample_factor < 1:
                raise ValueError("mask_beat_downsample_factor must be >= 1 or None")
        else:
            downbeat_downsample_factor = 1

        # keep only every k-th (down)beat
        beats_z = beats_z[::beat_downsample_factor]
        downbeats_z = downbeats_z[::downbeat_downsample_factor]
        print(f"beats_z: {len(beats_z)}")
        print(f"downbeats_z: {len(downbeats_z)}")

        if mask_upbeats:
            for beat_idx in beats_z:
                _slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
                num_steps = mask[_slice[0]:_slice[1]].shape[0]
                _m = torch.ones(num_steps, device=self.device)
                # bernoulli dropout: each kept position survives with prob 1-dropout
                _m_mask = torch.bernoulli(_m * (1 - dropout))
                _m = _m * _m_mask.long()

                mask[_slice[0]:_slice[1]] = _m

        if mask_downbeats:
            for downbeat_idx in downbeats_z:
                _slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
                num_steps = mask[_slice[0]:_slice[1]].shape[0]
                _m = torch.ones(num_steps, device=self.device)
                _m_mask = torch.bernoulli(_m * (1 - dropout))
                _m = _m * _m_mask.long()

                mask[_slice[0]:_slice[1]] = _m

        mask = mask.clamp(0, 1)
        if invert:
            mask = 1 - mask

        # broadcast to (1, n_codebooks, seq); c2f codebook count wins if present
        mask = mask[None, None, :].bool().long()
        if self.c2f is not None:
            mask = mask.repeat(1, self.c2f.n_codebooks, 1)
        else:
            mask = mask.repeat(1, self.coarse.n_codebooks, 1)
        return mask

    def coarse_to_fine(
        self,
        z: torch.Tensor,
        mask: torch.Tensor = None,
        **kwargs
    ):
        """Run the coarse-to-fine model over `z` in fixed-size chunks.

        Pads `z` (and `mask`) to a multiple of the c2f chunk length, appends
        zero codebooks up to c2f.n_codebooks, generates per chunk, then trims
        back to the original length. Extra kwargs are passed to c2f.generate.
        """
        assert self.c2f is not None, "No coarse2fine model loaded"
        length = z.shape[-1]
        chunk_len = self.s2t(self.c2f.chunk_size_s)
        n_chunks = math.ceil(z.shape[-1] / chunk_len)

        # zero pad to chunk_len
        if length % chunk_len != 0:
            pad_len = chunk_len - (length % chunk_len)
            z = torch.nn.functional.pad(z, (0, pad_len))
            mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None

        n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
        if n_codebooks_to_append > 0:
            z = torch.cat([
                z,
                torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
            ], dim=1)

        # set the mask to 0 for all conditioning codebooks
        if mask is not None:
            mask = mask.clone()
            mask[:, :self.c2f.n_conditioning_codebooks, :] = 0

        fine_z = []
        for i in range(n_chunks):
            chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
            mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None

            chunk = self.c2f.generate(
                codec=self.codec,
                time_steps=chunk_len,
                start_tokens=chunk,
                return_signal=False,
                mask=mask_chunk,
                **kwargs
            )
            fine_z.append(chunk)

        fine_z = torch.cat(fine_z, dim=-1)
        return fine_z[:, :, :length].clone()

    def coarse_vamp(
        self,
        z,
        mask,
        return_mask=False,
        gen_fn=None,
        **kwargs
    ):
        """Regenerate the masked positions of the coarse codebooks of `z`.

        The fine codebooks (beyond coarse.n_codebooks) are passed through
        unchanged. `gen_fn` defaults to the coarse model's generate method.
        Returns the vamped tokens, plus the masked input if `return_mask`.
        """
        # coarse z
        cz = z[:, : self.coarse.n_codebooks, :].clone()
        assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"

        mask = mask[:, : self.coarse.n_codebooks, :]

        cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token)
        cz_masked = cz_masked[:, : self.coarse.n_codebooks, :]

        gen_fn = gen_fn or self.coarse.generate
        c_vamp = gen_fn(
            codec=self.codec,
            time_steps=cz.shape[-1],
            start_tokens=cz,
            mask=mask,
            return_signal=False,
            **kwargs
        )

        # add the fine codes back in
        c_vamp = torch.cat(
            [c_vamp, z[:, self.coarse.n_codebooks :, :]],
            dim=1
        )

        if return_mask:
            return c_vamp, cz_masked

        return c_vamp
| 360 |
+
|
| 361 |
+
|
| 362 |
+
if __name__ == "__main__":
    # Interactive smoke test / demo for the Interface class.
    # Requires local model checkpoints and a CUDA device; contains
    # breakpoint() calls for manual inspection.
    import audiotools as at
    import logging
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    torch.set_printoptions(threshold=10000)
    at.util.seed(42)

    interface = Interface(
        coarse_ckpt="./models/vampnet/coarse.pth",
        coarse2fine_ckpt="./models/vampnet/c2f.pth",
        codec_ckpt="./models/vampnet/codec.pth",
        device="cuda",
        wavebeat_ckpt="./models/wavebeat.pth"
    )


    sig = at.AudioSignal('assets/example.wav')

    z = interface.encode(sig)
    breakpoint()

    # alternative masking strategies, kept for experimentation:
    # mask = linear_random(z, 1.0)
    # mask = mask_and(
    #     mask, periodic_mask(
    #         z,
    #         32,
    #         1,
    #         random_roll=True
    #     )
    # )

    # mask = interface.make_beat_mask(
    #     sig, 0.0, 0.075
    # )
    # mask = dropout(mask, 0.0)
    # mask = codebook_unmask(mask, 0)

    # inpaint the middle, keeping 100 tokens at each end
    mask = inpaint(z, n_prefix=100, n_suffix=100)

    zv, mask_z = interface.coarse_vamp(
        z,
        mask=mask,
        sampling_steps=36,
        temperature=8.0,
        return_mask=True,
        gen_fn=interface.coarse.generate
    )


    use_coarse2fine = True
    if use_coarse2fine:
        zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
        breakpoint()

    # decode both the masked input and the vamped result for listening
    mask = interface.to_signal(mask_z).cpu()

    sig = interface.to_signal(zv).cpu()
    print("done")
| 422 |
+
|
vampnet/vampnet/mask.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from audiotools import AudioSignal
|
| 5 |
+
|
| 6 |
+
from .util import scalar_to_batch_tensor
|
| 7 |
+
|
| 8 |
+
def _gamma(r):
|
| 9 |
+
return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
|
| 10 |
+
|
| 11 |
+
def _invgamma(y):
|
| 12 |
+
if not torch.is_tensor(y):
|
| 13 |
+
y = torch.tensor(y)[None]
|
| 14 |
+
return 2 * y.acos() / torch.pi
|
| 15 |
+
|
| 16 |
+
def full_mask(x: torch.Tensor):
    """Return an all-ones (everything masked) long mask shaped like x."""
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    return torch.ones(x.shape, dtype=torch.long, device=x.device)
|
| 19 |
+
|
| 20 |
+
def empty_mask(x: torch.Tensor):
    """Return an all-zeros (nothing masked) long mask shaped like x."""
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    return torch.zeros(x.shape, dtype=torch.long, device=x.device)
|
| 23 |
+
|
| 24 |
+
def apply_mask(
    x: torch.Tensor,
    mask: torch.Tensor,
    mask_token: int
):
    """Replace masked positions of x with mask_token.

    Args:
        x: token tensor, (batch, n_codebooks, seq).
        mask: binary long tensor, same shape as x; 1 = replace with mask_token.
        mask_token: fill value for masked positions.

    Returns:
        (masked_x, mask) — masked_x has mask_token wherever mask == 1.
    """
    # NOTE: the {…} placeholders below were previously plain strings (missing
    # the f prefix) and never interpolated; fixed to actual f-strings.
    assert mask.ndim == 3, f"mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
    assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
    assert mask.dtype == torch.long, f"mask must be long dtype, but got {mask.dtype}"
    assert ~torch.any(mask > 1), "mask must be binary"
    assert ~torch.any(mask < 0), "mask must be binary"

    fill_x = torch.full_like(x, mask_token)
    x = x * (1 - mask) + fill_x * mask

    return x, mask
|
| 39 |
+
|
| 40 |
+
def random(
    x: torch.Tensor,
    r: torch.Tensor
):
    """Bernoulli mask: each position is masked with probability _gamma(r),
    broadcast per batch item. Returns a binary long tensor shaped like x."""
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)

    keep_prob = _gamma(r)[:, None, None]
    probs = torch.ones_like(x) * keep_prob

    sampled = torch.bernoulli(probs)
    return sampled.round().long()
|
| 55 |
+
|
| 56 |
+
def linear_random(
    x: torch.Tensor,
    r: torch.Tensor,
):
    """Bernoulli mask with probability r directly (no schedule applied).

    NOTE(review): for a batched 1-D r, `probs * r` broadcasts against the
    *time* axis — verify intended for batch sizes > 1.
    """
    assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
    if not isinstance(r, torch.Tensor):
        r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()

    probs = torch.ones_like(x).to(x.device).float()
    # expand to batch and codebook dims
    probs = probs.expand(x.shape[0], x.shape[1], -1) * r

    return torch.bernoulli(probs).round().long()
|
| 73 |
+
|
| 74 |
+
def inpaint(x: torch.Tensor,
            n_prefix,
            n_suffix,
            ):
    """Mask everything except the first n_prefix and last n_suffix time steps.

    n_prefix / n_suffix may be scalars (broadcast to the batch) or per-item
    tensors. Returns a binary long mask shaped like x.
    """
    assert n_prefix is not None
    assert n_suffix is not None

    mask = full_mask(x)

    # unmask (zero) the prefix region per batch item
    if n_prefix > 0:
        if not isinstance(n_prefix, torch.Tensor):
            n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
        for idx, count in enumerate(n_prefix):
            if count > 0:
                mask[idx, :, :count] = 0.0

    # unmask (zero) the suffix region per batch item
    if n_suffix > 0:
        if not isinstance(n_suffix, torch.Tensor):
            n_suffix = scalar_to_batch_tensor(n_suffix, x.shape[0]).to(x.device)
        for idx, count in enumerate(n_suffix):
            if count > 0:
                mask[idx, :, -count:] = 0.0

    return mask
|
| 99 |
+
|
| 100 |
+
def periodic_mask(x: torch.Tensor,
                  period: int, width: int = 1,
                  random_roll=False,
                  ):
    """Mask everything except a window of `width` steps every `period` steps.

    Args:
        x: token tensor, (batch, n_codebooks, seq).
        period: spacing between unmasked windows; scalar or per-batch tensor.
            0 means "mask everything" (full mask returned unchanged).
        width: size of each unmasked window, centered on the periodic position.
        random_roll: if True, shift the mask by a random offset along time.

    Returns:
        Binary long mask shaped like x (1 = masked, 0 = kept).

    Note: the original implementation filled each window via
    torch.bernoulli(torch.ones(...)) and asserted the draw was all ones —
    i.e. the fill value was provably all zeros. Replaced with a direct
    zero-fill; output is identical.
    """
    mask = full_mask(x)
    if period == 0:
        return mask

    if not isinstance(period, torch.Tensor):
        period = scalar_to_batch_tensor(period, x.shape[0])
    for i, factor in enumerate(period):
        if factor == 0:
            continue
        for j in range(mask.shape[-1]):
            if j % factor == 0:
                # figure out how wide the mask should be
                j_start = max(0, j - width // 2 )
                j_end = min(mask.shape[-1] - 1, j + width // 2 ) + 1
                # unmask (zero) the window around this periodic position
                mask[i, :, j_start:j_end] = 0
        if random_roll:
            # add a random offset to the mask
            offset = torch.randint(0, period[0], (1,))
            mask = torch.roll(mask, offset.item(), dims=-1)

    return mask
|
| 131 |
+
|
| 132 |
+
def codebook_unmask(
    mask: torch.Tensor,
    n_conditioning_codebooks: int
):
    """Zero out (unmask) the first n_conditioning_codebooks codebooks.

    Returns the mask unchanged (same object) when
    n_conditioning_codebooks is None; otherwise returns a modified clone.
    """
    # idiom fix: `is None` instead of `== None`
    if n_conditioning_codebooks is None:
        return mask
    # if we have any conditioning codebooks, set their mask to 0
    mask = mask.clone()
    mask[:, :n_conditioning_codebooks, :] = 0
    return mask
|
| 142 |
+
|
| 143 |
+
def codebook_mask(mask: torch.Tensor, start: int):
    """Force-mask (set to 1) every codebook from index `start` onward.

    Operates on a clone; the input mask is left untouched.
    """
    out = mask.clone()
    out[:, start:, :] = 1
    return out
|
| 147 |
+
|
| 148 |
+
def mask_and(
    mask1: torch.Tensor,
    mask2: torch.Tensor
):
    """Elementwise AND of two binary masks (implemented as the minimum)."""
    assert mask1.shape == mask2.shape, "masks must be same shape"
    return torch.minimum(mask1, mask2)
|
| 154 |
+
|
| 155 |
+
def dropout(
    mask: torch.Tensor,
    p: float,
):
    """With probability p, flip each *unmasked* (0) position to masked (1).

    Masked positions always stay masked. p=0 returns the mask unchanged;
    p=1 returns an all-ones mask.
    """
    assert 0 <= p <= 1, "p must be between 0 and 1"
    assert mask.max() <= 1, "mask must be binary"
    assert mask.min() >= 0, "mask must be binary"
    # work on the inverted mask: 1 where currently unmasked
    unmasked = (~mask.bool()).float()
    # each unmasked position survives with probability 1 - p
    survived = torch.bernoulli(unmasked * (1 - p))
    result = ~survived.round().bool()
    return result.long()
|
| 166 |
+
|
| 167 |
+
def mask_or(
    mask1: torch.Tensor,
    mask2: torch.Tensor
):
    """Elementwise OR of two binary masks (sum clamped to [0, 1])."""
    assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
    assert mask1.max() <= 1, "mask1 must be binary"
    assert mask2.max() <= 1, "mask2 must be binary"
    assert mask1.min() >= 0, "mask1 must be binary"
    assert mask2.min() >= 0, "mask2 must be binary"
    combined = mask1 + mask2
    return torch.clamp(combined, 0, 1)
|
| 177 |
+
|
| 178 |
+
def time_stretch_mask(
    x: torch.Tensor,
    stretch_factor: int,
):
    """Build a periodic mask for time-stretching by an integer factor.

    Repeats each token `stretch_factor` times (trimmed back to the original
    length) and masks everything except every stretch_factor-th step.
    """
    assert stretch_factor >= 1, "stretch factor must be >= 1"
    original_len = x.shape[-1]
    stretched = x.repeat_interleave(stretch_factor, dim=-1)[:, :, :original_len]

    return periodic_mask(stretched, stretch_factor, width=1)
|
| 191 |
+
|
| 192 |
+
def onset_mask(
    sig: AudioSignal,
    z: torch.Tensor,
    interface,
    width: int = 1
):
    """Build a mask that unmasks (zeros) a window around each detected onset.

    Uses madmom's RNN onset detector on a temporary wav render of `sig`,
    converts onset times to token-frame indices via the codec hop length,
    and zeros a window of +/- `width` frames around each onset in an
    otherwise all-ones mask. Returns an empty (all-zeros) mask when no
    onsets are found.

    NOTE(review): heavy deps (madmom, librosa) are imported lazily here so
    the module loads without them.
    """
    import librosa
    import madmom
    from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
    import tempfile
    import numpy as np

    with tempfile.NamedTemporaryFile(suffix='.wav') as f:
        # render the (cloned) signal to disk for madmom's file-based API
        sig = sig.clone()
        sig.write(f.name)

        proc = RNNOnsetProcessor(online=False)
        # fps chosen so picked onsets line up with codec token frames
        onsetproc = OnsetPeakPickingProcessor(threshold=0.3,
                                              fps=sig.sample_rate/interface.codec.hop_length)

        act = proc(f.name)
        onset_times = onsetproc(act)

        # convert to indices for z array
        onset_indices = librosa.time_to_frames(onset_times, sr=sig.sample_rate, hop_length=interface.codec.hop_length)

        if onset_indices.shape[0] == 0:
            mask = empty_mask(z)
            print(f"no onsets found, returning empty mask")
        else:
            torch.set_printoptions(threshold=1000)
            print("onset indices: ", onset_indices)
            print("onset times: ", onset_times)

            # create a mask, set onset
            mask = torch.ones_like(z)
            n_timesteps = z.shape[-1]

            for onset_index in onset_indices:
                # clamp the onset frame into valid range before windowing
                onset_index = min(onset_index, n_timesteps - 1)
                onset_index = max(onset_index, 0)
                mask[:, :, onset_index - width:onset_index + width] = 0.0

            print(mask)

    return mask
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
if __name__ == "__main__":
    # no demo / smoke test for this module yet
    pass
|
vampnet/vampnet/modules/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import audiotools

# Register module patterns with audiotools' BaseModel serialization.
# NOTE(review): presumably INTERN lists packages whose source is bundled into
# saved checkpoints and EXTERN lists packages resolved from the environment —
# confirm against the audiotools.ml.BaseModel documentation.
audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]

from .transformer import VampNet
|
vampnet/vampnet/modules/activations.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class NewGELU(nn.Module):
    """
    Tanh-approximation GELU, as used in Google BERT and OpenAI GPT.
    See the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, x):
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
        return 0.5 * x * (1.0 + torch.tanh(inner))
|
| 27 |
+
|
| 28 |
+
class GatedGELU(nn.Module):
    """Gated GELU (GEGLU): split the input in two halves along `dim` and
    gate the first half with the GELU of the second."""

    def __init__(self):
        super().__init__()
        self.gelu = NewGELU()

    def forward(self, x, dim: int = -1):
        value, gate = x.chunk(2, dim=dim)
        return value * self.gelu(gate)
|
| 36 |
+
|
| 37 |
+
class Snake1d(nn.Module):
    """Snake activation: x + (1/alpha) * sin^2(alpha * x), with one
    learnable alpha per channel (initialized to 1)."""

    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(channels))

    def forward(self, x):
        # 1e-9 guards against division by a zero-valued alpha
        inv_alpha = (self.alpha + 1e-9).reciprocal()
        return x + inv_alpha * torch.sin(self.alpha * x).pow(2)
|
| 44 |
+
|
| 45 |
+
def get_activation(name: str = "relu"):
    """Return the activation module class registered under `name`.

    Known names: "relu", "gelu", "geglu", "snake".
    Raises ValueError for anything else.
    """
    if name == "relu":
        return nn.ReLU
    if name == "gelu":
        return NewGELU
    if name == "geglu":
        return GatedGELU
    if name == "snake":
        return Snake1d
    raise ValueError(f"Unrecognized activation {name}")
|
vampnet/vampnet/modules/layers.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from torch.nn.utils import weight_norm
|
| 10 |
+
|
| 11 |
+
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    """Snake activation kernel: x + (1/alpha) * sin^2(alpha * x).

    Flattens trailing dims so alpha (shaped (1, C, 1)) broadcasts over
    channels, then restores the original shape.
    """
    orig_shape = x.shape
    flat = x.reshape(orig_shape[0], orig_shape[1], -1)
    flat = flat + (alpha + 1e-9).reciprocal() * torch.sin(alpha * flat).pow(2)
    return flat.reshape(orig_shape)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Snake1d(nn.Module):
    """Channel-wise Snake activation for 1-D feature maps.

    Holds one learnable `alpha` per channel, shaped (1, C, 1) so it
    broadcasts over batch and time, and applies the scripted `snake` kernel.
    """

    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def num_params(model):
    """Count the trainable (requires_grad) parameters of a torch module."""
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def recurse_children(module, fn):
    """Recursively walk `module`'s children, yielding fn(child) for each.

    NOTE(review): for ModuleList/ModuleDict members and for the recursive
    descent, this yields *generator objects* rather than their items (there
    is no `yield from`). Any consumer that tests truthiness of the yielded
    values — e.g. SequentialWithFiLM.has_film's `any(...)` — will see truthy
    generator objects whenever a child itself has children. Confirm this is
    the intended behavior before relying on the yielded values directly.
    """
    for child in module.children():
        if isinstance(child, nn.ModuleList):
            for c in child:
                yield recurse_children(c, fn)
        if isinstance(child, nn.ModuleDict):
            for c in child.values():
                yield recurse_children(c, fn)

        yield recurse_children(child, fn)
        yield fn(child)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def WNConv1d(*args, **kwargs):
    """nn.Conv1d wrapped with weight normalization; args pass through unchanged."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def WNConvTranspose1d(*args, **kwargs):
    """nn.ConvTranspose1d wrapped with weight normalization; args pass through unchanged."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class SequentialWithFiLM(nn.Module):
    """
    handy wrapper for nn.Sequential that allows FiLM layers to be
    inserted in between other layers.
    """

    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    @staticmethod
    def has_film(module):
        # truthy if the recursive walk over `module` produces any truthy item
        found = [res for res in recurse_children(module, lambda c: isinstance(c, FiLM))]
        return any(found)

    def forward(self, x, cond):
        out = x
        for layer in self.layers:
            # FiLM-bearing layers receive the conditioning signal; others do not
            out = layer(out, cond) if self.has_film(layer) else layer(out)
        return out
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class FiLM(nn.Module):
    """Feature-wise Linear Modulation: x * (gamma(r) + 1) + beta(r).

    With input_dim == 0 the layer is an identity and ignores `r`.
    gamma/beta are linear projections of the conditioning vector `r`,
    reshaped to broadcast over the time axis of x (batch, output_dim, time).
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        if input_dim > 0:
            self.beta = nn.Linear(input_dim, output_dim)
            self.gamma = nn.Linear(input_dim, output_dim)

    def forward(self, x, r):
        if self.input_dim == 0:
            # unconditioned: pass through untouched
            return x
        batch = x.size(0)
        shift = self.beta(r).view(batch, self.output_dim, 1)
        scale = self.gamma(r).view(batch, self.output_dim, 1)
        return x * (scale + 1) + shift
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class CodebookEmbedding(nn.Module):
    """Embeds multi-codebook token sequences into a single model dimension.

    Per-codebook latents are looked up from the codec quantizer's codebooks
    (plus learned latents for special tokens such as <MASK>), concatenated,
    and projected to ``emb_dim`` with a 1x1 convolution.
    """

    def __init__(
        self,
        vocab_size: int,
        latent_dim: int,
        n_codebooks: int,
        emb_dim: int,
        special_tokens: Optional[Tuple[str]] = None,
    ):
        super().__init__()
        self.n_codebooks = n_codebooks
        self.emb_dim = emb_dim
        self.latent_dim = latent_dim
        self.vocab_size = vocab_size

        if special_tokens is not None:
            # Fix: the original rebuilt this ParameterDict inside a redundant
            # outer `for tkn in special_tokens` loop, re-creating (and
            # re-registering) the parameters once per token. Build it once.
            self.special = nn.ParameterDict(
                {
                    tkn: nn.Parameter(torch.randn(n_codebooks, self.latent_dim))
                    for tkn in special_tokens
                }
            )
            # Special tokens take the ids just past the codec vocabulary.
            self.special_idxs = {
                tkn: i + vocab_size for i, tkn in enumerate(special_tokens)
            }

        self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)

    def from_codes(self, codes: torch.Tensor, codec):
        """
        get a sequence of continuous embeddings from a sequence of discrete codes.
        unlike its counterpart in the original VQ-VAE, this function adds for any special tokens
        necessary for the language model, like <MASK>.

        Parameters
        ----------
        codes : Tensor[B x n_codebooks x T]
            Discrete token ids; ids >= vocab_size index the special tokens.
        codec : object
            Assumes a DAC-style codec exposing
            ``codec.quantizer.quantizers[i].codebook.weight`` — TODO confirm.
        """
        n_codebooks = codes.shape[1]
        latent = []
        for i in range(n_codebooks):
            c = codes[:, i, :]

            lookup_table = codec.quantizer.quantizers[i].codebook.weight
            if hasattr(self, "special"):
                # Append the learned special-token latents for this codebook
                # so ids >= vocab_size resolve correctly.
                special_lookup = torch.cat(
                    [self.special[tkn][i : i + 1] for tkn in self.special], dim=0
                )
                lookup_table = torch.cat([lookup_table, special_lookup], dim=0)

            l = F.embedding(c, lookup_table).transpose(1, 2)
            latent.append(l)

        # Concatenate per-codebook latents along the channel axis:
        # (B, n_codebooks * latent_dim, T)
        latent = torch.cat(latent, dim=1)
        return latent

    def forward(self, latents: torch.Tensor):
        """
        project a sequence of latents to a sequence of embeddings

        Parameters
        ----------
        latents : Tensor[B x (n_codebooks * latent_dim) x T]

        Returns
        -------
        Tensor[B x emb_dim x T]
        """
        x = self.out_proj(latents)
        return x
|
| 164 |
+
|
vampnet/vampnet/modules/transformer.py
ADDED
|
@@ -0,0 +1,953 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Optional, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
import loralib as lora
|
| 11 |
+
import audiotools as at
|
| 12 |
+
|
| 13 |
+
from .activations import get_activation
|
| 14 |
+
from .layers import CodebookEmbedding
|
| 15 |
+
from .layers import FiLM
|
| 16 |
+
from .layers import SequentialWithFiLM
|
| 17 |
+
from .layers import WNConv1d
|
| 18 |
+
from ..util import scalar_to_batch_tensor, codebook_flatten, codebook_unflatten
|
| 19 |
+
from ..mask import _gamma
|
| 20 |
+
|
| 21 |
+
LORA_R = 8
|
| 22 |
+
|
| 23 |
+
# def log(t, eps=1e-20):
|
| 24 |
+
# return torch.log(t + eps)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def gumbel_noise_like(t):
    """Sample standard Gumbel(0, 1) noise shaped like *t*.

    Uses the inverse-CDF trick ``-log(-log(u))`` with ``u ~ U(1e-20, 1)``;
    the 1e-20 floor avoids ``log(0)``.
    """
    u = torch.zeros_like(t).uniform_(1e-20, 1)
    return -(-u.log()).log()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def gumbel_sample(t, temperature=1.0, dim=-1):
    """Sample category indices from logits *t* via the Gumbel-max trick.

    Temperature is floored at 1e-10 to avoid division by zero; lower
    temperatures approach greedy argmax sampling.
    """
    scaled = t / max(temperature, 1e-10)
    return (scaled + gumbel_noise_like(t)).argmax(dim=dim)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (https://arxiv.org/abs/1910.07467).

    As in T5's layer norm: rescales by the feature-wise RMS only — no mean
    subtraction and no bias — then applies a learned per-feature gain.
    """

    def __init__(self, hidden_size: int, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.var_eps = eps

    def forward(self, x):
        """Normalize ``x`` (B x T x D) by its per-position RMS and apply the gain.

        Returns
        -------
        Tensor[B x T x D]
        """
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x * torch.rsqrt(mean_square + self.var_eps)
        return self.weight * normed
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class FeedForward(nn.Module):
    """Position-wise feed-forward block with LoRA-adapted linear layers.

    The hidden width is ``4 * d_model``; with the default "geglu" activation
    the gate halves the hidden features before the output projection, so
    ``w_2`` takes ``4 * d_model // 2`` inputs.
    """

    def __init__(
        self, d_model: int = 512, dropout: float = 0.1, activation: str = "geglu"
    ):
        super().__init__()
        # geglu splits the hidden features into value/gate halves.
        factor = 2 if activation == "geglu" else 1
        self.w_1 = lora.Linear(d_model, d_model * 4, bias=False, r=LORA_R)
        self.w_2 = lora.Linear(d_model * 4 // factor, d_model, bias=False, r=LORA_R)
        self.drop = nn.Dropout(dropout)
        self.act = get_activation(activation)()

    def forward(self, x):
        """Apply ``w_1 -> activation -> dropout -> w_2`` to ``x`` (B x T x D).

        Returns
        -------
        Tensor[B x T x D]
        """
        hidden = self.drop(self.act(self.w_1(x)))
        return self.w_2(hidden)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class MultiHeadRelativeAttention(nn.Module):
    """Multi-head attention with T5-style bucketed relative position biases.

    When ``has_relative_attention_bias`` is on, a learned per-head bias indexed
    by a bucketed relative distance is added to the attention logits
    (https://arxiv.org/abs/1910.10683). The computed bias is returned so
    subsequent layers in a stack can reuse it.
    """

    def __init__(
        self,
        n_head: int = 8,
        d_model: int = 512,
        dropout: float = 0.1,
        bidirectional: bool = True,
        has_relative_attention_bias: bool = True,
        attention_num_buckets: int = 32,
        attention_max_distance: int = 128,
    ):
        super().__init__()
        d_head = d_model // n_head
        self.n_head = n_head
        self.d_head = d_head
        self.bidirectional = bidirectional
        self.has_relative_attention_bias = has_relative_attention_bias
        self.attention_num_buckets = attention_num_buckets
        self.attention_max_distance = attention_max_distance

        # Create linear query, key, value projections.
        # NOTE: queries/values (and the output) are LoRA-adapted; keys use a
        # plain Linear — matching the common recipe of adapting W_q/W_v only.
        self.w_qs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)
        self.w_ks = nn.Linear(d_model, d_model, bias=False)
        self.w_vs = lora.Linear(d_model, d_model, bias=False, r=LORA_R)

        # Create linear final output projection
        self.fc = lora.Linear(d_model, d_model, bias=False, r=LORA_R)

        # Dropout for attention output weights
        self.dropout = nn.Dropout(dropout)

        # Create relative positional embeddings (if turned on):
        # one learned scalar per (bucket, head).
        if has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(attention_num_buckets, n_head)

    def _relative_position_bucket(self, relative_position):
        """Converts unbounded relative position into bounded set of buckets
        with half "exact" buckets (1 position = 1 bucket) and half "log-spaced"
        buckets
        Parameters
        ----------
        relative_position : Tensor[T_q x T_kv]
            Relative positions between queries and key_value items
        Returns
        -------
        Tensor[T_q x T_kv]
            Input relative positions converted into buckets
        """
        relative_buckets = 0
        num_buckets = self.attention_num_buckets
        max_distance = self.attention_max_distance

        # Convert relative position for (-inf, inf) to [0, inf]
        # Negative relative positions correspond to past
        # Positive relative positions correspond to future
        if self.bidirectional:
            # use half buckets for each side (past / future)
            num_buckets //= 2

            # Shift the position positions by `num_buckets` to wrap around
            # negative positions
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            # If not bidirectional, ignore positive positions and wrap
            # negative positions to positive
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position)
            )

        # Allocate half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in
        # positions up to `max_distance`
        relative_postion_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)

        # Clip the max relative position to `num_buckets - 1`
        relative_postion_if_large = torch.min(
            relative_postion_if_large,
            torch.full_like(relative_postion_if_large, num_buckets - 1),
        )

        # Choose relative buckets based on small or large positions
        relative_buckets += torch.where(
            is_small, relative_position, relative_postion_if_large
        )

        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Computes a position bias scalar for each index in query_length x key_length
        Parameters
        ----------
        query_length : int
        key_length : int
        Returns
        -------
        Tensor[heads x 1 x T_q x T_kv]
            Position bias to be applied on attention logits
        """

        # Pairwise signed distance: key position minus query position.
        query_position = torch.arange(query_length, dtype=torch.long)[:, None]
        key_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = key_position - query_position

        # Convert relative position to buckets
        relative_position_bucket = self._relative_position_bucket(relative_position)
        relative_position_bucket = relative_position_bucket.to(
            self.relative_attention_bias.weight.device
        )

        # Index attention bias values
        values = self.relative_attention_bias(relative_position_bucket)
        values = rearrange(values, "q k h -> h 1 q k")

        return values

    def forward(self, q, k, v, mask=None, position_bias=None):
        """Computes attention over (keys, values) for every timestep in query
        Parameters
        ----------
        q : Tensor[B x T_q x d_model]
            Query vectors
        k : Tensor[B x T_kv x d_model]
            Key vectors to compute attention over
        v : Tensor[B x T_kv x d_model]
            Value vectors corresponding to the keys
        mask : Tensor[B x T_q x T_kv], optional
        position_bias: Tensor[head x 1 x T_q x T_kv]
            Precomputed bias; when None it is computed here (or zeroed if
            this layer has no relative-attention embedding).
        Returns
        -------
        Tensor[B x T_q x d_model]
            Outputs after attending (key, value) using queries
        """
        # Compute query, key, value projections; split heads: (H, B, T, d_head)
        q = rearrange(self.w_qs(q), "b l (head k) -> head b l k", head=self.n_head)
        k = rearrange(self.w_ks(k), "b t (head k) -> head b t k", head=self.n_head)
        v = rearrange(self.w_vs(v), "b t (head k) -> head b t k", head=self.n_head)

        # Compute attention matrix, scaled by sqrt(d_head)
        attn = torch.einsum("hblk,hbtk->hblt", [q, k]) / np.sqrt(q.shape[-1])

        # Add relative position bias to attention scores
        if position_bias is None:
            if self.has_relative_attention_bias:
                position_bias = self.compute_bias(q.size(-2), k.size(-2))
            else:
                position_bias = torch.zeros_like(attn)
        attn += position_bias

        # Apply mask to attention scores to prevent looking up invalid locations
        # (masked logits get -1e9 so they vanish under softmax).
        if mask is not None:
            attn = attn.masked_fill(mask[None] == 0, -1e9)

        # Normalize attention scores and add dropout
        attn = torch.softmax(attn, dim=3)
        attn = self.dropout(attn)

        # Compute attended outputs (product of attention matrix and values)
        output = torch.einsum("hblt,hbtv->hblv", [attn, v])
        output = rearrange(output, "head b l v -> b l (head v)")
        output = self.fc(output)

        # Bias is returned so callers can pass it to the next layer unchanged.
        return output, position_bias
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class TransformerLayer(nn.Module):
    """One pre-norm transformer block.

    Sublayers: self-attention, optional cross-attention (decoder only), and a
    feed-forward network — each preceded by RMSNorm + FiLM conditioning and
    followed by a dropout residual connection.
    """

    def __init__(
        self,
        d_model: int = 512,
        d_cond: int = 64,
        n_heads: int = 8,
        bidirectional: bool = True,
        is_decoder: bool = False,
        has_relative_attention_bias: bool = False,
        flash_attn: bool = False,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Store args
        self.is_decoder = is_decoder

        # Create self-attention layer
        self.norm_1 = RMSNorm(d_model)
        self.film_1 = FiLM(d_cond, d_model)
        self.flash_attn = flash_attn

        if flash_attn:
            # Optional fast path. NOTE: FlashMHA carries no relative position
            # bias and ignores the attention mask in forward() below.
            from flash_attn.flash_attention import FlashMHA
            self.self_attn = FlashMHA(
                embed_dim=d_model,
                num_heads=n_heads,
                attention_dropout=dropout,
                causal=False,
            )
        else:
            self.self_attn = MultiHeadRelativeAttention(
                n_heads, d_model, dropout, bidirectional, has_relative_attention_bias
            )

        # (Optional) Create cross-attention layer
        if is_decoder:
            self.norm_2 = RMSNorm(d_model)
            self.film_2 = FiLM(d_cond, d_model)
            self.cross_attn = MultiHeadRelativeAttention(
                n_heads,
                d_model,
                dropout,
                bidirectional=True,
                has_relative_attention_bias=False,
            )

        # Create last feed-forward layer
        self.norm_3 = RMSNorm(d_model)
        self.film_3 = FiLM(d_cond, d_model)
        self.feed_forward = FeedForward(d_model=d_model, dropout=dropout)

        # Create dropout
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
        x_mask,
        cond,
        src=None,
        src_mask=None,
        position_bias=None,
        encoder_decoder_position_bias=None,
    ):
        """Computes one transformer layer consisting of self attention, (op) cross attention
        and feedforward layer
        Parameters
        ----------
        x : Tensor[B x T_q x D]
        x_mask : Tensor[B x T_q]
        cond : conditioning tensor consumed by the FiLM layers (identity when
            d_cond == 0)
        src : Tensor[B x T_kv x D], optional
        src_mask : Tensor[B x T_kv x D], optional
        position_bias : Tensor[heads x B x T_q x T_q], optional
            Relative position bias for self attention layer
        encoder_decoder_position_bias : Tensor[heads x B x T_q x T_kv], optional
            Relative position bias for cross attention layer
        Returns
        -------
        Tensor[B x T_q x D]
        """
        # --- self-attention sublayer (pre-norm + FiLM) ---
        y = self.norm_1(x)
        # FiLM operates channels-first, so permute (B,T,D) -> (B,D,T) and back.
        y = self.film_1(y.permute(0, 2, 1), cond).permute(0, 2, 1)
        if self.flash_attn:
            # Flash attention runs under bf16 autocast; mask/bias are not used.
            with torch.autocast(y.device.type, dtype=torch.bfloat16):
                y = self.self_attn(y)[0]
        else:
            y, position_bias = self.self_attn(y, y, y, x_mask, position_bias)
        x = x + self.dropout(y)

        # --- optional cross-attention sublayer (decoder only) ---
        if self.is_decoder:
            y = self.norm_2(x)
            y = self.film_2(y.permute(0, 2, 1), cond).permute(0, 2, 1)
            y, encoder_decoder_position_bias = self.cross_attn(
                y, src, src, src_mask, encoder_decoder_position_bias
            )
            x = x + self.dropout(y)

        # --- feed-forward sublayer ---
        y = self.norm_3(x)
        y = self.film_3(
            y.permute(
                0,
                2,
                1,
            ),
            cond,
        ).permute(0, 2, 1)
        y = self.feed_forward(y)
        x = x + self.dropout(y)

        # Biases are returned so the enclosing stack can reuse them downstream.
        return x, position_bias, encoder_decoder_position_bias
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
class TransformerStack(nn.Module):
    """A stack of TransformerLayers sharing T5-style relative position biases.

    Only the first layer owns a relative-attention embedding; its computed
    bias is threaded through the remaining layers via the forward loop.
    """

    def __init__(
        self,
        d_model: int = 512,
        d_cond: int = 64,
        n_heads: int = 8,
        n_layers: int = 8,
        last_layer: bool = True,
        bidirectional: bool = True,
        flash_attn: bool = False,
        is_decoder: bool = False,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Store args
        self.bidirectional = bidirectional
        self.is_decoder = is_decoder

        # Create transformer layers
        # In T5, relative attention bias is shared by all layers in the stack
        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    d_model,
                    d_cond,
                    n_heads,
                    bidirectional,
                    is_decoder,
                    has_relative_attention_bias=True if (i == 0) else False,
                    flash_attn=flash_attn,
                    dropout=dropout,
                )
                for i in range(n_layers)
            ]
        )

        # Perform last normalization
        self.norm = RMSNorm(d_model) if last_layer else None

    def subsequent_mask(self, size):
        # Lower-triangular causal mask: position i may attend only to j <= i.
        return torch.ones(1, size, size).tril().bool()

    def forward(self, x, x_mask, cond=None, src=None, src_mask=None,
        return_activations: bool = False
    ):
        """Computes a full transformer stack
        Parameters
        ----------
        x : Tensor[B x T_q x D]
        x_mask : Tensor[B x T_q]
        src : Tensor[B x T_kv x D], optional
        src_mask : Tensor[B x T_kv], optional
        return_activations : bool
            When True, also return the per-layer (detached) activations,
            stacked along a new leading dimension.
        Returns
        -------
        Tensor[B x T_q x D]
        """

        # Convert `src_mask` to (B x T_q x T_kv) shape for cross attention masking
        # NOTE(review): assumes src_mask is not None whenever is_decoder — confirm.
        if self.is_decoder:
            src_mask = x_mask.unsqueeze(-1) * src_mask.unsqueeze(-2)

        # Convert `x_mask` to (B x T_q x T_q) shape for self attention masking
        x_mask = x_mask.unsqueeze(-2)
        if not self.bidirectional:
            x_mask = x_mask * self.subsequent_mask(x.size(1)).to(x_mask.device)

        # Initialize position biases; layer 0 computes them, later layers reuse.
        position_bias = None
        encoder_decoder_position_bias = None

        # Compute transformer layers
        if return_activations:
            activations = []
        for layer in self.layers:
            x, position_bias, encoder_decoder_position_bias = layer(
                x=x,
                x_mask=x_mask,
                cond=cond,
                src=src,
                src_mask=src_mask,
                position_bias=position_bias,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
            )
            if return_activations:
                activations.append(x.detach())


        out = self.norm(x) if self.norm is not None else x
        if return_activations:
            return out, torch.stack(activations)
        else:
            return out
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
class VampNet(at.ml.BaseModel):
|
| 466 |
+
def __init__(
    self,
    n_heads: int = 20,
    n_layers: int = 16,
    r_cond_dim: int = 0,
    n_codebooks: int = 9,
    n_conditioning_codebooks: int = 0,
    latent_dim: int = 8,
    embedding_dim: int = 1280,
    vocab_size: int = 1024,
    flash_attn: bool = True,
    noise_mode: str = "mask",
    dropout: float = 0.1
):
    """Build the VampNet masked acoustic token model.

    Only r_cond_dim == 0 and noise_mode == "mask" are supported (the
    alternatives are asserted away below as deprecated/unsupported).
    """
    super().__init__()
    assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
    self.n_heads = n_heads
    self.n_layers = n_layers
    self.r_cond_dim = r_cond_dim
    self.n_codebooks = n_codebooks
    self.n_conditioning_codebooks = n_conditioning_codebooks
    self.embedding_dim = embedding_dim
    self.vocab_size = vocab_size
    self.latent_dim = latent_dim
    self.flash_attn = flash_attn
    self.noise_mode = noise_mode

    assert self.noise_mode == "mask", "deprecated"

    # Token embedding with a learned <MASK> latent per codebook; MASK's
    # token id is vocab_size (first id past the codec vocabulary).
    self.embedding = CodebookEmbedding(
        latent_dim=latent_dim,
        n_codebooks=n_codebooks,
        vocab_size=vocab_size,
        emb_dim=embedding_dim,
        special_tokens=["MASK"],
    )
    self.mask_token = self.embedding.special_idxs["MASK"]

    # Bidirectional (non-causal) encoder-only stack; d_cond=0 makes every
    # FiLM layer an identity.
    self.transformer = TransformerStack(
        d_model=embedding_dim,
        d_cond=r_cond_dim,
        n_heads=n_heads,
        n_layers=n_layers,
        last_layer=True,
        bidirectional=True,
        flash_attn=flash_attn,
        is_decoder=False,
        dropout=dropout,
    )

    # Add final conv layer projecting to per-codebook vocabularies; only the
    # non-conditioning codebooks are predicted.
    self.n_predict_codebooks = n_codebooks - n_conditioning_codebooks
    self.classifier = SequentialWithFiLM(
        WNConv1d(
            embedding_dim,
            vocab_size * self.n_predict_codebooks,
            kernel_size=1,
            padding="same",
            # groups=self.n_predict_codebooks,
        ),
    )
|
| 527 |
+
|
| 528 |
+
def forward(self, x, return_activations: bool = False):
    """Predict vocabulary logits for a sequence of codebook latents.

    Parameters
    ----------
    x : Tensor
        Latents as produced by ``self.embedding.from_codes`` — presumably
        shaped (B, n_codebooks * latent_dim, T); TODO confirm against callers.
    return_activations : bool
        When True, also return the stacked per-layer transformer activations.

    Returns
    -------
    Tensor[B x vocab_size x (T * n_predict_codebooks)]
        Logits with codebooks flattened into the time axis.
    """
    x = self.embedding(x)
    # All-ones (no padding) attention mask over time: shape (B, T).
    x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)

    # Transformer expects time-major features: (B, D, T) -> (B, T, D).
    x = rearrange(x, "b d n -> b n d")
    out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
    if return_activations:
        out, activations = out

    out = rearrange(out, "b n d -> b d n")

    out = self.classifier(out, None) # no cond here!

    # Split channel dim into (vocab, codebook) and flatten codebooks into time.
    out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)

    if return_activations:
        return out, activations
    else:
        return out
|
| 547 |
+
|
| 548 |
+
def r_embed(self, r, max_positions=10000):
    """Sinusoidally embed the masking-schedule ratio ``r``.

    Only active when r_cond_dim > 0 (which __init__ currently asserts
    against); otherwise ``r`` is returned unchanged.
    """
    if self.r_cond_dim > 0:
        dtype = r.dtype

        # Map r through the cosine schedule gamma, scaled into [0, max_positions].
        r = _gamma(r) * max_positions
        half_dim = self.r_cond_dim // 2

        # Standard sinusoidal position-embedding frequencies.
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()

        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)

        if self.r_cond_dim % 2 == 1: # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode="constant")

        return emb.to(dtype)
    else:
        return r
|
| 567 |
+
|
| 568 |
+
@torch.no_grad()
def to_signal(self, z, codec):
    """
    convert a sequence of latents to a signal.

    Parameters
    ----------
    z : Tensor[B x n_codebooks x T]
        Discrete codes; entries equal to ``self.mask_token`` are rendered
        as silence in the output audio.
    codec : object
        Assumes a DAC-style codec with ``decode``, ``quantizer.from_latents``,
        ``sample_rate`` and ``hop_length`` — TODO confirm.
    """
    assert z.ndim == 3

    signal = at.AudioSignal(
        codec.decode(
            codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
        )["audio"],
        codec.sample_rate,
    )

    # find where the mask token is and replace it with silence in the audio
    # (one hop of samples per masked timestep, zeroed across the whole batch
    # if any batch element is masked at that step).
    for tstep in range(z.shape[-1]):
        if torch.any(z[:, :, tstep] == self.mask_token):
            sample_idx_0 = tstep * codec.hop_length
            sample_idx_1 = sample_idx_0 + codec.hop_length
            signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0

    return signal
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
    @torch.no_grad()
    def generate(
        self,
        codec,
        time_steps: int = 300,
        sampling_steps: int = 36,
        start_tokens: Optional[torch.Tensor] = None,
        sampling_temperature: float = 1.0,
        mask: Optional[torch.Tensor] = None,
        mask_temperature: float = 10.5,
        typical_filtering=False,
        typical_mass=0.2,
        typical_min_tokens=1,
        top_p=None,
        return_signal=True,
        seed: int = None,
        sample_cutoff: float = 1.0,
    ):
        """Generate tokens by iterative confidence-based remasking
        (MaskGIT-style parallel decoding).

        Starting from ``start_tokens`` (or a fully masked sequence of length
        ``time_steps``), each of the ``sampling_steps`` iterations samples all
        masked positions in parallel, keeps the most confident predictions,
        and re-masks the rest according to the ``_gamma`` schedule.

        Parameters
        ----------
        codec : object
            Codec used to turn tokens into latents (and audio if
            ``return_signal``).
        time_steps : int
            Token-frame length of the generated sequence when
            ``start_tokens`` is None.
        sampling_steps : int
            Number of parallel-decoding iterations.
        start_tokens : Optional[torch.Tensor]
            Initial (batch, codebook, time) token grid; None means all-masked.
        sampling_temperature : float
            Softmax temperature used when sampling tokens.
        mask : Optional[torch.Tensor]
            1 where tokens may be (re)generated, 0 where they are kept.
            A 2-D mask is broadcast across codebooks.
        mask_temperature : float
            Scales the Gumbel noise used to pick which tokens to re-mask;
            annealed by ``(1 - r)`` over the schedule.
        typical_filtering, typical_mass, typical_min_tokens, top_p
            Logit-filtering options forwarded to ``sample_from_logits``.
        return_signal : bool
            If True, decode the final tokens to an AudioSignal via
            ``to_signal``; otherwise return the token tensor.
        seed : int
            Optional RNG seed for reproducible sampling.
        sample_cutoff : float
            Fraction of the schedule during which multinomial sampling is
            used; after it, argmax decoding takes over.
        """
        if seed is not None:
            at.util.seed(seed)
        logging.debug(f"beginning generation with {sampling_steps} steps")

        #####################
        # resolve initial z #
        #####################
        z = start_tokens

        if z is None:
            # no prompt given: start from an all-mask canvas, batch size 1
            z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
                self.device
            )

        logging.debug(f"created z with shape {z.shape}")

        #################
        # resolve mask #
        #################

        if mask is None:
            # default: regenerate everything except the conditioning codebooks
            mask = torch.ones_like(z).to(self.device).int()
            mask[:, : self.n_conditioning_codebooks, :] = 0.0
        if mask.ndim == 2:
            # broadcast a (batch, time) mask across all codebooks
            mask = mask[:, None, :].repeat(1, z.shape[1], 1)
        # init_mask = mask.clone()

        logging.debug(f"created mask with shape {mask.shape}")

        ###########
        # set up #
        ##########
        # apply the mask to z
        z_masked = z.masked_fill(mask.bool(), self.mask_token)
        # logging.debug(f"z_masked: {z_masked}")

        # how many mask tokens to begin with?
        num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
        logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")

        # how many codebooks are we inferring vs conditioning on?
        n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
        logging.debug(f"n infer codebooks: {n_infer_codebooks}")

        #################
        # begin sampling #
        #################

        for i in range(sampling_steps):
            logging.debug(f"step {i} of {sampling_steps}")

            # our current schedule step, r in (0, 1]
            r = scalar_to_batch_tensor(
                (i + 1) / sampling_steps,
                z.shape[0]
            ).to(z.device)
            logging.debug(f"r: {r}")

            # get latents
            latents = self.embedding.from_codes(z_masked, codec)
            logging.debug(f"computed latents with shape: {latents.shape}")

            # infer from latents
            # NOTE: this collapses the codebook dimension into the sequence dimension
            logits = self.forward(latents)  # b, prob, seq
            logits = logits.permute(0, 2, 1)  # b, seq, prob
            b = logits.shape[0]

            logging.debug(f"permuted logits with shape: {logits.shape}")

            # sample every position in parallel; after sample_cutoff of the
            # schedule, sampling switches to greedy argmax decoding
            sampled_z, selected_probs = sample_from_logits(
                logits, sample=(
                    (i / sampling_steps) <= sample_cutoff
                ),
                temperature=sampling_temperature,
                typical_filtering=typical_filtering, typical_mass=typical_mass,
                typical_min_tokens=typical_min_tokens,
                top_k=None, top_p=top_p, return_probs=True,
            )

            logging.debug(f"sampled z with shape: {sampled_z.shape}")

            # flatten z_masked and mask, so we can deal with the sampling logic
            # we'll unflatten them at the end of the loop for the next forward pass
            # remove conditioning codebooks, we'll add them back at the end
            z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])

            mask = (z_masked == self.mask_token).int()

            # update the mask, remove conditioning codebooks from the mask
            logging.debug(f"updated mask with shape: {mask.shape}")
            # add z back into sampled z where the mask was false
            sampled_z = torch.where(
                mask.bool(), sampled_z, z_masked
            )
            logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")

            # ignore any tokens that weren't masked (inf confidence keeps them)
            selected_probs = torch.where(
                mask.bool(), selected_probs, torch.inf
            )

            # get the num tokens to mask, according to the schedule
            num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
            logging.debug(f"num to mask: {num_to_mask}")

            # on all but the final step, keep at least one token masked and
            # leave at least one unmasked so the process keeps converging
            if i != (sampling_steps - 1):
                num_to_mask = torch.maximum(
                    torch.tensor(1),
                    torch.minimum(
                        mask.sum(dim=-1, keepdim=True) - 1,
                        num_to_mask
                    )
                )

            # get our new mask
            mask = mask_by_random_topk(
                num_to_mask, selected_probs, mask_temperature * (1-r)
            )

            # update the mask
            z_masked = torch.where(
                mask.bool(), self.mask_token, sampled_z
            )
            logging.debug(f"updated z_masked with shape: {z_masked.shape}")

            z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
            mask = codebook_unflatten(mask, n_infer_codebooks)
            logging.debug(f"unflattened z_masked with shape: {z_masked.shape}")

            # add conditioning codebooks back to z_masked
            z_masked = torch.cat(
                (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
            )
            logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")

        # add conditioning codebooks back to sampled_z
        sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
        sampled_z = torch.cat(
            (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
        )

        logging.debug(f"finished sampling")

        if return_signal:
            return self.to_signal(sampled_z, codec)
        else:
            return sampled_z
|
| 765 |
+
|
| 766 |
+
def sample_from_logits(
    logits,
    sample: bool = True,
    temperature: float = 1.0,
    top_k: int = None,
    top_p: float = None,
    typical_filtering: bool = False,
    typical_mass: float = 0.2,
    typical_min_tokens: int = 1,
    return_probs: bool = False
):
    """Convenience function to sample from a categorical distribution with input as
    unnormalized logits.

    Parameters
    ----------
    logits : Tensor[..., vocab_size]
        Unnormalized scores over the vocabulary.
    sample : bool, optional
        Whether to perform multinomial sampling, by default True.
        When False, greedy argmax decoding is used.
    temperature : float, optional
        Scaling parameter when multinomial sampling, by default 1.0
    top_k : int, optional
        Restricts sampling to only `top_k` values acc. to probability,
        by default None
    top_p : float, optional
        Restricts sampling to only those values with cumulative
        probability <= `top_p`, by default None
    typical_filtering : bool, optional
        Apply locally-typical filtering (see `typical_filter`) before
        sampling, by default False
    typical_mass : float, optional
        Probability mass retained by typical filtering, by default 0.2
    typical_min_tokens : int, optional
        Minimum number of tokens typical filtering must keep, by default 1
    return_probs : bool, optional
        Also return the probability assigned to each sampled token.

    Returns
    -------
    Tensor[...]
        Sampled tokens (and, if `return_probs`, their probabilities).
    """
    shp = logits.shape[:-1]

    if typical_filtering:
        # BUGFIX: typical_filter is not in-place (it uses masked_fill, which
        # returns a new tensor), so its result must be re-assigned; previously
        # the filtered logits were discarded and the flag had no effect.
        logits = typical_filter(logits,
                                typical_mass=typical_mass,
                                typical_min_tokens=typical_min_tokens
                                )

    # Apply top_k sampling
    if top_k is not None:
        v, _ = logits.topk(top_k)
        logits[logits < v[..., [-1]]] = -float("inf")

    # Apply top_p (nucleus) sampling
    if top_p is not None and top_p < 1.0:
        v, sorted_indices = logits.sort(descending=True)
        cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        # Right shift indices_to_remove to keep 1st token over threshold
        sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
            ..., :-1
        ]

        # Compute indices_to_remove in unsorted array
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )

        logits[indices_to_remove] = -float("inf")

    # Perform multinomial sampling after normalizing logits
    probs = (
        F.softmax(logits / temperature, dim=-1)
        if temperature > 0
        else logits.softmax(dim=-1)
    )
    token = (
        probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
        if sample
        else logits.argmax(-1)
    )

    if return_probs:
        # probability the distribution assigned to each chosen token
        token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
        return token, token_probs
    else:
        return token
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
def mask_by_random_topk(
    num_to_mask: int,
    probs: torch.Tensor,
    temperature: float = 1.0,
):
    """Pick the lowest-confidence tokens to re-mask, with Gumbel-noise tie-breaking.

    Args:
        num_to_mask (int): number of tokens to mask
        probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
        temperature (float, optional): weight of the Gumbel noise. Defaults to 1.0.
    """
    logging.debug(f"masking by random topk")
    logging.debug(f"num to mask: {num_to_mask}")
    logging.debug(f"probs shape: {probs.shape}")
    logging.debug(f"temperature: {temperature}")
    logging.debug("")

    # perturb the log-confidence with scaled Gumbel noise for stochastic selection
    scores = torch.log(probs) + temperature * gumbel_noise_like(probs)
    logging.debug(f"confidence shape: {scores.shape}")

    ordered, order_idx = scores.sort(dim=-1)
    logging.debug(f"sorted confidence shape: {ordered.shape}")
    logging.debug(f"sorted idx shape: {order_idx.shape}")

    # the num_to_mask-th smallest score becomes the masking threshold
    threshold = torch.take_along_dim(
        ordered, num_to_mask, axis=-1
    )
    logging.debug(f"cut off shape: {threshold.shape}")

    # everything strictly below the threshold gets re-masked
    remask = scores < threshold
    logging.debug(f"mask shape: {remask.shape}")

    return remask
|
| 888 |
+
|
| 889 |
+
def typical_filter(
    logits,
    typical_mass: float = 0.95,
    typical_min_tokens: int = 1,
):
    """Locally-typical filtering: keep tokens whose surprisal is closest to the
    distribution's entropy until ``typical_mass`` probability is covered, and
    set the rest to -inf. Returns new logits of the same (b, t, l) shape.
    """
    nb, nt, _ = logits.shape
    # collapse batch and time: (b, t, l) -> (b*t, l)
    flat = logits.reshape(nb * nt, -1)

    log_p = torch.nn.functional.log_softmax(flat, dim=-1)
    p = torch.exp(log_p)
    entropy = -(log_p * p).nansum(-1, keepdim=True)

    # distance of each token's surprisal from the entropy ("atypicality")
    atypicality = torch.abs((-log_p) - entropy)
    atypicality_sorted, token_order = torch.sort(atypicality, descending=False)
    cum_mass = (
        flat.gather(-1, token_order).softmax(dim=-1).cumsum(dim=-1)
    )

    # index of the last token needed to reach the typical mass
    last_ind = (cum_mass < typical_mass).sum(dim=-1)
    drop_sorted = atypicality_sorted > atypicality_sorted.gather(
        1, last_ind.view(-1, 1)
    )
    if typical_min_tokens > 1:
        # never drop the typical_min_tokens most typical candidates
        drop_sorted[..., :typical_min_tokens] = 0
    drop = drop_sorted.scatter(
        1, token_order, drop_sorted
    )
    flat = flat.masked_fill(drop, -float("Inf"))
    # restore (b, t, l)
    return flat.reshape(nb, nt, -1)
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
if __name__ == "__main__":
    # BUGFIX: this import was commented out while argbind is used below,
    # which made the smoke test fail with a NameError.
    import argbind

    from .layers import num_params

    VampNet = argbind.bind(VampNet)

    @argbind.bind(without_prefix=True)
    def try_model(device: str = "cuda", batch_size: int = 2, seq_len_s: float = 10.0):
        """Smoke test: build a VampNet, run one forward pass, report size and shape."""
        # 32 kHz audio with a 512-sample codec hop -> token frames per second
        seq_len = int(32000 / 512 * seq_len_s)

        model = VampNet().to(device)

        z = torch.randint(
            0, model.vocab_size, size=(batch_size, model.n_codebooks, seq_len)
        ).to(device)

        r = torch.zeros(batch_size).to(device)

        z_mask_latent = torch.rand(
            batch_size, model.latent_dim * model.n_codebooks, seq_len
        ).to(device)
        z_hat = model(z_mask_latent)

        pred = z_hat.argmax(dim=1)
        pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)

        print(f"model has {num_params(model)/1e6:<.3f}M parameters")
        print(f"prediction has shape {pred.shape}")
        # (removed a leftover breakpoint() that dropped the script into pdb)

    args = argbind.parse_args()
    with argbind.scope(args):
        try_model()
|
| 952 |
+
|
| 953 |
+
|
vampnet/vampnet/scheduler.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
class NoamScheduler:
    """Noam learning-rate schedule from "Attention Is All You Need"
    (https://arxiv.org/pdf/1706.03762.pdf), following the Annotated
    Transformer (https://nlp.seas.harvard.edu/2018/04/03/attention.html).

    The rate warms up linearly for ``warmup`` steps, then decays with the
    inverse square root of the step count, scaled by ``factor`` and
    ``d_model ** -0.5``.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        d_model: int = 512,
        factor: float = 1.0,
        warmup: int = 4000,
    ):
        # Schedule hyperparameters.
        self.d_model = d_model
        self.factor = factor
        self.warmup = warmup

        # Current learning rate (None until the first step) and step counter.
        self.lr = None
        self.steps = 0

        # Optimizer whose param groups receive the computed rate.
        self.optimizer = optimizer

    def state_dict(self):
        """Serializable state; the optimizer itself is deliberately excluded."""
        state = dict(self.__dict__)
        del state["optimizer"]
        return state

    def load_state_dict(self, state_dict):
        """Restore state previously produced by :meth:`state_dict`."""
        self.__dict__.update(state_dict)

    def step(self):
        """Advance one step and push the new learning rate to the optimizer."""
        self.steps += 1
        ramp = min(self.steps ** (-0.5), self.steps * self.warmup ** (-1.5))
        self.lr = self.factor * (self.d_model ** (-0.5) * ramp)

        for group in self.optimizer.param_groups:
            group["lr"] = self.lr
|
| 47 |
+
|
vampnet/vampnet/util.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tqdm
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
|
| 6 |
+
def scalar_to_batch_tensor(x, batch_size):
    """Tile the scalar ``x`` into a 1-D tensor with ``batch_size`` entries."""
    value = torch.tensor(x)
    return value.repeat(batch_size)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def parallelize(
    fn,
    *iterables,
    parallel: str = "thread_map",
    **kwargs
):
    """Map ``fn`` over one or more iterables with a selectable executor.

    Args:
        fn: callable applied to one element from each iterable (in lockstep).
        *iterables: input sequences.
        parallel: one of "thread_map" (thread pool), "process_map"
            (process pool), or "single" (plain sequential loop).
        **kwargs: forwarded to the tqdm executor (e.g. ``max_workers``).

    Returns:
        list of results, in input order.

    Raises:
        ValueError: if ``parallel`` is not a recognized mode.
    """
    if parallel == "thread_map":
        from tqdm.contrib.concurrent import thread_map
        return thread_map(
            fn,
            *iterables,
            **kwargs
        )
    elif parallel == "process_map":
        from tqdm.contrib.concurrent import process_map
        return process_map(
            fn,
            *iterables,
            **kwargs
        )
    elif parallel == "single":
        # BUGFIX: iterate the iterables in lockstep, matching the semantics of
        # thread_map/process_map. Previously `tqdm.tqdm(*iterables)` passed any
        # second iterable as tqdm's `desc` argument, so multi-iterable calls
        # broke in "single" mode.
        return [fn(*args) for args in tqdm.tqdm(list(zip(*iterables)))]
    else:
        raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")
|
| 34 |
+
|
| 35 |
+
def codebook_flatten(tokens: torch.Tensor):
    """
    flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time),
    time-major so that all codebooks of one timestep are adjacent.
    """
    batch, n_codebooks, n_time = tokens.shape
    # (b, c, t) -> (b, t, c) -> (b, t*c); equivalent to einops "b c t -> b (t c)"
    return tokens.permute(0, 2, 1).reshape(batch, n_time * n_codebooks)
|
| 40 |
+
|
| 41 |
+
def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
    """
    unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
    """
    batch, flat_len = flat_tokens.shape
    # (b, t*c) -> (b, t, c) -> (b, c, t); inverse of codebook_flatten,
    # equivalent to einops "b (t c) -> b c t" with c=n_c
    unflat = flat_tokens.reshape(batch, flat_len // n_c, n_c).permute(0, 2, 1)
    return unflat
|