Spaces:
Sleeping
Sleeping
orrp commited on
Commit ·
9583919
1
Parent(s): e823eac
Refactoring, linting, switching to pyproject.toml and Docker
Browse files- .dockerignore +6 -0
- vampnet/.pre-commit-config.yaml → .pre-commit-config.yaml +0 -0
- Dockerfile +17 -0
- README.md +1 -3
- pyproject.toml +8 -10
- uv.lock +0 -0
- vampnet/app.py +130 -183
- vampnet/scripts/exp/eval.py +34 -21
- vampnet/scripts/exp/experiment.py +82 -79
- vampnet/scripts/exp/fine_tune.py +13 -16
- vampnet/scripts/exp/train.py +55 -62
- vampnet/scripts/utils/data/augment.py +25 -23
- vampnet/scripts/utils/data/maestro-reorg.py +3 -3
- vampnet/scripts/utils/plots.py +27 -12
- vampnet/scripts/utils/remove_quiet_files.py +7 -4
- vampnet/scripts/utils/split.py +14 -20
- vampnet/scripts/utils/split_long_audio_file.py +14 -11
- vampnet/scripts/utils/stage.py +0 -1
- vampnet/scripts/utils/visualize_embeddings.py +22 -16
- vampnet/scripts/utils/xeno-canto-dl.py +15 -7
- vampnet/setup.py +2 -5
- vampnet/vampnet/__init__.py +1 -3
- vampnet/vampnet/beats.py +6 -10
- vampnet/vampnet/interface.py +94 -97
- vampnet/vampnet/mask.py +54 -55
- vampnet/vampnet/modules/__init__.py +1 -1
- vampnet/vampnet/modules/activations.py +5 -4
- vampnet/vampnet/modules/layers.py +5 -8
- vampnet/vampnet/modules/transformer.py +86 -96
- vampnet/vampnet/scheduler.py +1 -4
- vampnet/vampnet/util.py +14 -21
- wham.egg-info/PKG-INFO +248 -0
- wham.egg-info/SOURCES.txt +33 -0
- wham.egg-info/dependency_links.txt +1 -0
- wham.egg-info/requires.txt +27 -0
- wham.egg-info/top_level.txt +1 -0
.dockerignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv
|
| 2 |
+
.git
|
| 3 |
+
__pycache__
|
| 4 |
+
*.pyc
|
| 5 |
+
.ruff_cache
|
| 6 |
+
wham/vampnet/models/*
|
vampnet/.pre-commit-config.yaml → .pre-commit-config.yaml
RENAMED
|
File without changes
|
Dockerfile
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
|
| 3 |
+
|
| 4 |
+
# Install ffmpeg for audio processing
|
| 5 |
+
RUN apt-get update && apt-get install -y ffmpeg git build-essential && rm -rf /var/lib/apt/lists/*
|
| 6 |
+
|
| 7 |
+
WORKDIR /app
|
| 8 |
+
|
| 9 |
+
# Install dependencies using uv
|
| 10 |
+
COPY pyproject.toml .
|
| 11 |
+
RUN uv pip install --system .
|
| 12 |
+
|
| 13 |
+
# Copy your code and run the app
|
| 14 |
+
COPY . .
|
| 15 |
+
EXPOSE 7860
|
| 16 |
+
ENV GRADIO_SERVER_NAME="0.0.0.0"
|
| 17 |
+
CMD ["python", "vampnet/app.py"]
|
README.md
CHANGED
|
@@ -3,11 +3,9 @@ title: WhAM
|
|
| 3 |
emoji: 🐋
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: indigo
|
| 6 |
-
sdk:
|
| 7 |
-
app_file: vampnet/app.py
|
| 8 |
pinned: false
|
| 9 |
hardware: a10g-small
|
| 10 |
-
python_version: "3.10"
|
| 11 |
---
|
| 12 |
|
| 13 |
# WhAM: a Whale Acoustics Model
|
|
|
|
| 3 |
emoji: 🐋
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: indigo
|
| 6 |
+
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
| 8 |
hardware: a10g-small
|
|
|
|
| 9 |
---
|
| 10 |
|
| 11 |
# WhAM: a Whale Acoustics Model
|
pyproject.toml
CHANGED
|
@@ -11,7 +11,7 @@ authors = [
|
|
| 11 |
{ name = "Project CETI" }
|
| 12 |
]
|
| 13 |
license = { text = "MIT" }
|
| 14 |
-
requires-python = ">=3.
|
| 15 |
dependencies = [
|
| 16 |
"torch",
|
| 17 |
"gradio",
|
|
@@ -33,10 +33,10 @@ dependencies = [
|
|
| 33 |
"gdown",
|
| 34 |
"transformers",
|
| 35 |
"fadtk",
|
| 36 |
-
"urllib3=
|
| 37 |
"plotly",
|
| 38 |
"pyharp",
|
| 39 |
-
|
| 40 |
"wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat.git",
|
| 41 |
"lac @ git+https://github.com/hugofloresgarcia/lac.git",
|
| 42 |
"descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git"
|
|
@@ -47,15 +47,13 @@ where = ["."]
|
|
| 47 |
include = ["wham*", "vampnet*"]
|
| 48 |
|
| 49 |
[tool.ruff]
|
| 50 |
-
|
| 51 |
-
target-version = "py39"
|
| 52 |
line-length = 88
|
| 53 |
|
| 54 |
[tool.ruff.lint]
|
| 55 |
-
# Enable Pyflakes (F), pycodestyle (E, W), and isort (I)
|
| 56 |
select = ["E", "F", "W", "I"]
|
| 57 |
-
|
| 58 |
|
| 59 |
-
[tool.ruff.
|
| 60 |
-
|
| 61 |
-
|
|
|
|
| 11 |
{ name = "Project CETI" }
|
| 12 |
]
|
| 13 |
license = { text = "MIT" }
|
| 14 |
+
requires-python = ">=3.10,<3.11"
|
| 15 |
dependencies = [
|
| 16 |
"torch",
|
| 17 |
"gradio",
|
|
|
|
| 33 |
"gdown",
|
| 34 |
"transformers",
|
| 35 |
"fadtk",
|
| 36 |
+
"urllib3>=2.0.2",
|
| 37 |
"plotly",
|
| 38 |
"pyharp",
|
| 39 |
+
"ruff",
|
| 40 |
"wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat.git",
|
| 41 |
"lac @ git+https://github.com/hugofloresgarcia/lac.git",
|
| 42 |
"descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git"
|
|
|
|
| 47 |
include = ["wham*", "vampnet*"]
|
| 48 |
|
| 49 |
[tool.ruff]
|
| 50 |
+
target-version = "py310"
|
|
|
|
| 51 |
line-length = 88
|
| 52 |
|
| 53 |
[tool.ruff.lint]
|
|
|
|
| 54 |
select = ["E", "F", "W", "I"]
|
| 55 |
+
fixable = ["ALL"]
|
| 56 |
|
| 57 |
+
[tool.ruff.lint.isort]
|
| 58 |
+
known-first-party = ["wham", "vampnet"]
|
| 59 |
+
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
vampnet/app.py
CHANGED
|
@@ -1,43 +1,40 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
|
|
|
|
|
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 6 |
os.chdir(SCRIPT_DIR)
|
| 7 |
|
| 8 |
-
|
| 9 |
-
device = "cuda" if torch.cuda.is_available()
|
| 10 |
sys.argv = ["app.py", "--args.load", "conf/interface.yml", "--Interface.device", device]
|
| 11 |
|
| 12 |
-
from pathlib import Path
|
| 13 |
-
from typing import Tuple
|
| 14 |
-
import yaml
|
| 15 |
-
import tempfile
|
| 16 |
-
import uuid
|
| 17 |
-
from dataclasses import dataclass, asdict
|
| 18 |
-
|
| 19 |
-
import numpy as np
|
| 20 |
-
import audiotools as at
|
| 21 |
-
import argbind
|
| 22 |
-
|
| 23 |
-
import gradio as gr
|
| 24 |
-
from vampnet.interface import Interface
|
| 25 |
-
from vampnet import mask as pmask
|
| 26 |
|
| 27 |
Interface = argbind.bind(Interface)
|
| 28 |
|
| 29 |
conf = argbind.parse_args()
|
| 30 |
|
| 31 |
|
| 32 |
-
from torch_pitch_shift import pitch_shift
|
|
|
|
|
|
|
| 33 |
def shift_pitch(signal, interval: int):
|
| 34 |
signal.samples = pitch_shift(
|
| 35 |
-
signal.samples,
|
| 36 |
-
shift=interval,
|
| 37 |
-
sample_rate=signal.sample_rate
|
| 38 |
)
|
| 39 |
return signal
|
| 40 |
|
|
|
|
| 41 |
def load_interface():
|
| 42 |
with argbind.scope(conf):
|
| 43 |
interface = Interface()
|
|
@@ -46,8 +43,6 @@ def load_interface():
|
|
| 46 |
return interface
|
| 47 |
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
interface = load_interface()
|
| 52 |
|
| 53 |
|
|
@@ -59,8 +54,7 @@ def load_audio(file):
|
|
| 59 |
print(file)
|
| 60 |
filepath = file.name
|
| 61 |
sig = at.AudioSignal.salient_excerpt(
|
| 62 |
-
filepath,
|
| 63 |
-
duration=interface.coarse.chunk_size_s
|
| 64 |
)
|
| 65 |
sig = interface.preprocess(sig)
|
| 66 |
|
|
@@ -121,19 +115,10 @@ def _vamp(
|
|
| 121 |
# build the mask
|
| 122 |
mask = pmask.linear_random(z, _rand_mask_intensity)
|
| 123 |
mask = pmask.mask_and(
|
| 124 |
-
mask, pmask.inpaint(
|
| 125 |
-
z,
|
| 126 |
-
interface.s2t(_prefix_s),
|
| 127 |
-
interface.s2t(_suffix_s)
|
| 128 |
-
)
|
| 129 |
)
|
| 130 |
mask = pmask.mask_and(
|
| 131 |
-
mask, pmask.periodic_mask(
|
| 132 |
-
z,
|
| 133 |
-
_periodic_p,
|
| 134 |
-
_periodic_w,
|
| 135 |
-
random_roll=True
|
| 136 |
-
)
|
| 137 |
)
|
| 138 |
if _onset_mask_width > 0:
|
| 139 |
mask = pmask.mask_or(
|
|
@@ -142,7 +127,7 @@ def _vamp(
|
|
| 142 |
if _beat_mask_width > 0:
|
| 143 |
beat_mask = interface.make_beat_mask(
|
| 144 |
sig,
|
| 145 |
-
after_beat_s=(_beat_mask_width/1000),
|
| 146 |
mask_upbeats=not _beat_mask_downbeats,
|
| 147 |
)
|
| 148 |
mask = pmask.mask_and(mask, beat_mask)
|
|
@@ -174,14 +159,14 @@ def _vamp(
|
|
| 174 |
|
| 175 |
_top_p_val = _top_p if _top_p > 0 else None
|
| 176 |
# save the mask as a txt file
|
| 177 |
-
np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
|
| 178 |
|
| 179 |
_seed_val = _seed if _seed > 0 else None
|
| 180 |
zv, mask_z = interface.coarse_vamp(
|
| 181 |
z,
|
| 182 |
mask=mask,
|
| 183 |
sampling_steps=_num_steps,
|
| 184 |
-
mask_temperature=_masktemp*10,
|
| 185 |
sampling_temperature=_sampletemp,
|
| 186 |
return_mask=True,
|
| 187 |
typical_filtering=_typical_filtering,
|
|
@@ -196,7 +181,7 @@ def _vamp(
|
|
| 196 |
if _use_coarse2fine:
|
| 197 |
zv = interface.coarse_to_fine(
|
| 198 |
zv,
|
| 199 |
-
mask_temperature=_masktemp*10,
|
| 200 |
sampling_temperature=_sampletemp,
|
| 201 |
mask=mask,
|
| 202 |
sampling_steps=_num_steps,
|
|
@@ -220,6 +205,7 @@ def _vamp(
|
|
| 220 |
else:
|
| 221 |
return sig.path_to_file
|
| 222 |
|
|
|
|
| 223 |
def _extract_and_call_vamp(data, return_mask):
|
| 224 |
"""Extract plain values from Gradio data dict so only picklable args cross the ZeroGPU boundary."""
|
| 225 |
return _vamp(
|
|
@@ -250,12 +236,15 @@ def _extract_and_call_vamp(data, return_mask):
|
|
| 250 |
return_mask=return_mask,
|
| 251 |
)
|
| 252 |
|
|
|
|
| 253 |
def vamp(data):
|
| 254 |
return _extract_and_call_vamp(data, return_mask=True)
|
| 255 |
|
|
|
|
| 256 |
def api_vamp(data):
|
| 257 |
return _extract_and_call_vamp(data, return_mask=False)
|
| 258 |
|
|
|
|
| 259 |
def save_vamp(data):
|
| 260 |
out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
|
| 261 |
out_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -289,6 +278,7 @@ def save_vamp(data):
|
|
| 289 |
yaml.dump(_data, f)
|
| 290 |
|
| 291 |
import zipfile
|
|
|
|
| 292 |
zip_path = str(out_dir.with_suffix(".zip"))
|
| 293 |
with zipfile.ZipFile(zip_path, "w") as zf:
|
| 294 |
for file in out_dir.iterdir():
|
|
@@ -312,7 +302,7 @@ def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
|
|
| 312 |
if _beat_mask_width > 0:
|
| 313 |
beat_mask = interface.make_beat_mask(
|
| 314 |
sig,
|
| 315 |
-
after_beat_s=(_beat_mask_width/1000),
|
| 316 |
)
|
| 317 |
mask = pmask.mask_and(mask, beat_mask)
|
| 318 |
|
|
@@ -325,7 +315,6 @@ def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
|
|
| 325 |
gen_fn=interface.coarse.generate,
|
| 326 |
)
|
| 327 |
|
| 328 |
-
|
| 329 |
zv = interface.coarse_to_fine(
|
| 330 |
zv,
|
| 331 |
sampling_temperature=_sampletemp,
|
|
@@ -339,8 +328,8 @@ def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
|
|
| 339 |
|
| 340 |
return sig.path_to_file
|
| 341 |
|
| 342 |
-
with gr.Blocks() as demo:
|
| 343 |
|
|
|
|
| 344 |
with gr.Row():
|
| 345 |
with gr.Column():
|
| 346 |
gr.Markdown("# VampNet Audio Vamping")
|
|
@@ -360,11 +349,9 @@ with gr.Blocks() as demo:
|
|
| 360 |
""")
|
| 361 |
with gr.Row():
|
| 362 |
with gr.Column():
|
| 363 |
-
|
| 364 |
-
|
| 365 |
manual_audio_upload = gr.File(
|
| 366 |
label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
|
| 367 |
-
file_types=["audio"]
|
| 368 |
)
|
| 369 |
load_example_audio_button = gr.Button("or load example audio")
|
| 370 |
|
|
@@ -382,71 +369,65 @@ with gr.Blocks() as demo:
|
|
| 382 |
|
| 383 |
# connect widgets
|
| 384 |
load_example_audio_button.click(
|
| 385 |
-
fn=load_example_audio,
|
| 386 |
-
inputs=[],
|
| 387 |
-
outputs=[ input_audio]
|
| 388 |
)
|
| 389 |
|
| 390 |
manual_audio_upload.change(
|
| 391 |
-
fn=load_audio,
|
| 392 |
-
inputs=[manual_audio_upload],
|
| 393 |
-
outputs=[ input_audio]
|
| 394 |
)
|
| 395 |
|
| 396 |
# mask settings
|
| 397 |
with gr.Column():
|
| 398 |
-
|
| 399 |
-
|
| 400 |
presets = {
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
|
| 451 |
preset = gr.Dropdown(
|
| 452 |
label="preset",
|
|
@@ -464,7 +445,6 @@ with gr.Blocks() as demo:
|
|
| 464 |
value=3,
|
| 465 |
)
|
| 466 |
|
| 467 |
-
|
| 468 |
onset_mask_width = gr.Slider(
|
| 469 |
label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
|
| 470 |
minimum=0,
|
|
@@ -480,8 +460,7 @@ with gr.Blocks() as demo:
|
|
| 480 |
value=0,
|
| 481 |
)
|
| 482 |
beat_mask_downbeats = gr.Checkbox(
|
| 483 |
-
label="beat mask downbeats only?",
|
| 484 |
-
value=False
|
| 485 |
)
|
| 486 |
|
| 487 |
n_mask_codebooks = gr.Number(
|
|
@@ -489,7 +468,6 @@ with gr.Blocks() as demo:
|
|
| 489 |
value=9,
|
| 490 |
)
|
| 491 |
|
| 492 |
-
|
| 493 |
with gr.Accordion("extras ", open=False):
|
| 494 |
pitch_shift_amt = gr.Slider(
|
| 495 |
label="pitch shift amount (semitones)",
|
|
@@ -503,7 +481,7 @@ with gr.Blocks() as demo:
|
|
| 503 |
label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
|
| 504 |
minimum=0.0,
|
| 505 |
maximum=1.0,
|
| 506 |
-
value=1.0
|
| 507 |
)
|
| 508 |
|
| 509 |
periodic_w = gr.Slider(
|
|
@@ -538,78 +516,62 @@ with gr.Blocks() as demo:
|
|
| 538 |
return tuple(presets[_preset].values())
|
| 539 |
|
| 540 |
load_preset_button.click(
|
| 541 |
-
fn=load_preset,
|
| 542 |
-
inputs=[preset],
|
| 543 |
-
outputs=preset_outputs
|
| 544 |
)
|
| 545 |
|
| 546 |
-
|
| 547 |
with gr.Accordion("prefix/suffix prompts", open=False):
|
| 548 |
prefix_s = gr.Slider(
|
| 549 |
label="prefix hint length (seconds)",
|
| 550 |
minimum=0.0,
|
| 551 |
maximum=10.0,
|
| 552 |
-
value=0.0
|
| 553 |
)
|
| 554 |
suffix_s = gr.Slider(
|
| 555 |
label="suffix hint length (seconds)",
|
| 556 |
minimum=0.0,
|
| 557 |
maximum=10.0,
|
| 558 |
-
value=0.0
|
| 559 |
)
|
| 560 |
|
| 561 |
masktemp = gr.Slider(
|
| 562 |
-
label="mask temperature",
|
| 563 |
-
minimum=0.0,
|
| 564 |
-
maximum=100.0,
|
| 565 |
-
value=1.5
|
| 566 |
)
|
| 567 |
sampletemp = gr.Slider(
|
| 568 |
label="sample temperature",
|
| 569 |
minimum=0.1,
|
| 570 |
maximum=10.0,
|
| 571 |
value=1.0,
|
| 572 |
-
step=0.001
|
| 573 |
)
|
| 574 |
|
| 575 |
-
|
| 576 |
-
|
| 577 |
with gr.Accordion("sampling settings", open=False):
|
| 578 |
top_p = gr.Slider(
|
| 579 |
-
label="top p (0.0 = off)",
|
| 580 |
-
minimum=0.0,
|
| 581 |
-
maximum=1.0,
|
| 582 |
-
value=0.0
|
| 583 |
-
)
|
| 584 |
-
typical_filtering = gr.Checkbox(
|
| 585 |
-
label="typical filtering ",
|
| 586 |
-
value=False
|
| 587 |
)
|
|
|
|
| 588 |
typical_mass = gr.Slider(
|
| 589 |
label="typical mass (should probably stay between 0.1 and 0.5)",
|
| 590 |
minimum=0.01,
|
| 591 |
maximum=0.99,
|
| 592 |
-
value=0.15
|
| 593 |
)
|
| 594 |
typical_min_tokens = gr.Slider(
|
| 595 |
label="typical min tokens (should probably stay between 1 and 256)",
|
| 596 |
minimum=1,
|
| 597 |
maximum=256,
|
| 598 |
step=1,
|
| 599 |
-
value=64
|
| 600 |
)
|
| 601 |
sample_cutoff = gr.Slider(
|
| 602 |
label="sample cutoff",
|
| 603 |
minimum=0.0,
|
| 604 |
maximum=1.0,
|
| 605 |
value=0.5,
|
| 606 |
-
step=0.01
|
| 607 |
)
|
| 608 |
|
| 609 |
use_coarse2fine = gr.Checkbox(
|
| 610 |
-
label="use coarse2fine",
|
| 611 |
-
value=True,
|
| 612 |
-
visible=False
|
| 613 |
)
|
| 614 |
|
| 615 |
num_steps = gr.Slider(
|
|
@@ -617,29 +579,21 @@ with gr.Blocks() as demo:
|
|
| 617 |
minimum=1,
|
| 618 |
maximum=128,
|
| 619 |
step=1,
|
| 620 |
-
value=36
|
| 621 |
)
|
| 622 |
|
| 623 |
dropout = gr.Slider(
|
| 624 |
-
label="mask dropout",
|
| 625 |
-
minimum=0.0,
|
| 626 |
-
maximum=1.0,
|
| 627 |
-
step=0.01,
|
| 628 |
-
value=0.0
|
| 629 |
)
|
| 630 |
|
| 631 |
-
|
| 632 |
seed = gr.Number(
|
| 633 |
label="seed (0 for random)",
|
| 634 |
value=0,
|
| 635 |
precision=0,
|
| 636 |
)
|
| 637 |
|
| 638 |
-
|
| 639 |
-
|
| 640 |
# mask settings
|
| 641 |
with gr.Column():
|
| 642 |
-
|
| 643 |
# lora_choice = gr.Dropdown(
|
| 644 |
# label="lora choice",
|
| 645 |
# choices=list(loras.keys()),
|
|
@@ -649,51 +603,49 @@ with gr.Blocks() as demo:
|
|
| 649 |
|
| 650 |
vamp_button = gr.Button("generate (vamp)!!!")
|
| 651 |
output_audio = gr.Audio(
|
| 652 |
-
label="output audio",
|
| 653 |
-
interactive=False,
|
| 654 |
-
type="filepath"
|
| 655 |
)
|
| 656 |
|
| 657 |
notes_text = gr.Textbox(
|
| 658 |
label="type any notes about the generated audio here",
|
| 659 |
value="",
|
| 660 |
-
interactive=True
|
| 661 |
)
|
| 662 |
save_button = gr.Button("save vamp")
|
| 663 |
download_file = gr.File(
|
| 664 |
-
label="vamp to download will appear here",
|
| 665 |
-
interactive=False
|
| 666 |
)
|
| 667 |
use_as_input_button = gr.Button("use output as input")
|
| 668 |
|
| 669 |
thank_you = gr.Markdown("")
|
| 670 |
|
| 671 |
-
|
| 672 |
_inputs = {
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
|
|
|
|
|
|
| 697 |
|
| 698 |
# connect widgets
|
| 699 |
vamp_button.click(
|
|
@@ -704,22 +656,17 @@ with gr.Blocks() as demo:
|
|
| 704 |
|
| 705 |
api_vamp_button = gr.Button("api vamp", visible=False)
|
| 706 |
api_vamp_button.click(
|
| 707 |
-
fn=api_vamp,
|
| 708 |
-
inputs=_inputs,
|
| 709 |
-
outputs=[output_audio],
|
| 710 |
-
api_name="vamp"
|
| 711 |
)
|
| 712 |
|
| 713 |
use_as_input_button.click(
|
| 714 |
-
fn=lambda x: x,
|
| 715 |
-
inputs=[output_audio],
|
| 716 |
-
outputs=[input_audio]
|
| 717 |
)
|
| 718 |
|
| 719 |
save_button.click(
|
| 720 |
fn=save_vamp,
|
| 721 |
inputs=_inputs | {notes_text, output_audio},
|
| 722 |
-
outputs=[thank_you, download_file]
|
| 723 |
)
|
| 724 |
|
| 725 |
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
+
import uuid
|
| 4 |
+
from pathlib import Path
|
| 5 |
|
| 6 |
+
import argbind
|
| 7 |
+
import audiotools as at
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import yaml
|
| 12 |
+
|
| 13 |
+
from vampnet import mask as pmask
|
| 14 |
+
from vampnet.interface import Interface
|
| 15 |
|
| 16 |
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 17 |
os.chdir(SCRIPT_DIR)
|
| 18 |
|
| 19 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 20 |
sys.argv = ["app.py", "--args.load", "conf/interface.yml", "--Interface.device", device]
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
Interface = argbind.bind(Interface)
|
| 24 |
|
| 25 |
conf = argbind.parse_args()
|
| 26 |
|
| 27 |
|
| 28 |
+
from torch_pitch_shift import pitch_shift
|
| 29 |
+
|
| 30 |
+
|
| 31 |
def shift_pitch(signal, interval: int):
|
| 32 |
signal.samples = pitch_shift(
|
| 33 |
+
signal.samples, shift=interval, sample_rate=signal.sample_rate
|
|
|
|
|
|
|
| 34 |
)
|
| 35 |
return signal
|
| 36 |
|
| 37 |
+
|
| 38 |
def load_interface():
|
| 39 |
with argbind.scope(conf):
|
| 40 |
interface = Interface()
|
|
|
|
| 43 |
return interface
|
| 44 |
|
| 45 |
|
|
|
|
|
|
|
| 46 |
interface = load_interface()
|
| 47 |
|
| 48 |
|
|
|
|
| 54 |
print(file)
|
| 55 |
filepath = file.name
|
| 56 |
sig = at.AudioSignal.salient_excerpt(
|
| 57 |
+
filepath, duration=interface.coarse.chunk_size_s
|
|
|
|
| 58 |
)
|
| 59 |
sig = interface.preprocess(sig)
|
| 60 |
|
|
|
|
| 115 |
# build the mask
|
| 116 |
mask = pmask.linear_random(z, _rand_mask_intensity)
|
| 117 |
mask = pmask.mask_and(
|
| 118 |
+
mask, pmask.inpaint(z, interface.s2t(_prefix_s), interface.s2t(_suffix_s))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
)
|
| 120 |
mask = pmask.mask_and(
|
| 121 |
+
mask, pmask.periodic_mask(z, _periodic_p, _periodic_w, random_roll=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
)
|
| 123 |
if _onset_mask_width > 0:
|
| 124 |
mask = pmask.mask_or(
|
|
|
|
| 127 |
if _beat_mask_width > 0:
|
| 128 |
beat_mask = interface.make_beat_mask(
|
| 129 |
sig,
|
| 130 |
+
after_beat_s=(_beat_mask_width / 1000),
|
| 131 |
mask_upbeats=not _beat_mask_downbeats,
|
| 132 |
)
|
| 133 |
mask = pmask.mask_and(mask, beat_mask)
|
|
|
|
| 159 |
|
| 160 |
_top_p_val = _top_p if _top_p > 0 else None
|
| 161 |
# save the mask as a txt file
|
| 162 |
+
np.savetxt(out_dir / "mask.txt", mask[:, 0, :].long().cpu().numpy())
|
| 163 |
|
| 164 |
_seed_val = _seed if _seed > 0 else None
|
| 165 |
zv, mask_z = interface.coarse_vamp(
|
| 166 |
z,
|
| 167 |
mask=mask,
|
| 168 |
sampling_steps=_num_steps,
|
| 169 |
+
mask_temperature=_masktemp * 10,
|
| 170 |
sampling_temperature=_sampletemp,
|
| 171 |
return_mask=True,
|
| 172 |
typical_filtering=_typical_filtering,
|
|
|
|
| 181 |
if _use_coarse2fine:
|
| 182 |
zv = interface.coarse_to_fine(
|
| 183 |
zv,
|
| 184 |
+
mask_temperature=_masktemp * 10,
|
| 185 |
sampling_temperature=_sampletemp,
|
| 186 |
mask=mask,
|
| 187 |
sampling_steps=_num_steps,
|
|
|
|
| 205 |
else:
|
| 206 |
return sig.path_to_file
|
| 207 |
|
| 208 |
+
|
| 209 |
def _extract_and_call_vamp(data, return_mask):
|
| 210 |
"""Extract plain values from Gradio data dict so only picklable args cross the ZeroGPU boundary."""
|
| 211 |
return _vamp(
|
|
|
|
| 236 |
return_mask=return_mask,
|
| 237 |
)
|
| 238 |
|
| 239 |
+
|
| 240 |
def vamp(data):
|
| 241 |
return _extract_and_call_vamp(data, return_mask=True)
|
| 242 |
|
| 243 |
+
|
| 244 |
def api_vamp(data):
|
| 245 |
return _extract_and_call_vamp(data, return_mask=False)
|
| 246 |
|
| 247 |
+
|
| 248 |
def save_vamp(data):
|
| 249 |
out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
|
| 250 |
out_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 278 |
yaml.dump(_data, f)
|
| 279 |
|
| 280 |
import zipfile
|
| 281 |
+
|
| 282 |
zip_path = str(out_dir.with_suffix(".zip"))
|
| 283 |
with zipfile.ZipFile(zip_path, "w") as zf:
|
| 284 |
for file in out_dir.iterdir():
|
|
|
|
| 302 |
if _beat_mask_width > 0:
|
| 303 |
beat_mask = interface.make_beat_mask(
|
| 304 |
sig,
|
| 305 |
+
after_beat_s=(_beat_mask_width / 1000),
|
| 306 |
)
|
| 307 |
mask = pmask.mask_and(mask, beat_mask)
|
| 308 |
|
|
|
|
| 315 |
gen_fn=interface.coarse.generate,
|
| 316 |
)
|
| 317 |
|
|
|
|
| 318 |
zv = interface.coarse_to_fine(
|
| 319 |
zv,
|
| 320 |
sampling_temperature=_sampletemp,
|
|
|
|
| 328 |
|
| 329 |
return sig.path_to_file
|
| 330 |
|
|
|
|
| 331 |
|
| 332 |
+
with gr.Blocks() as demo:
|
| 333 |
with gr.Row():
|
| 334 |
with gr.Column():
|
| 335 |
gr.Markdown("# VampNet Audio Vamping")
|
|
|
|
| 349 |
""")
|
| 350 |
with gr.Row():
|
| 351 |
with gr.Column():
|
|
|
|
|
|
|
| 352 |
manual_audio_upload = gr.File(
|
| 353 |
label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
|
| 354 |
+
file_types=["audio"],
|
| 355 |
)
|
| 356 |
load_example_audio_button = gr.Button("or load example audio")
|
| 357 |
|
|
|
|
| 369 |
|
| 370 |
# connect widgets
|
| 371 |
load_example_audio_button.click(
|
| 372 |
+
fn=load_example_audio, inputs=[], outputs=[input_audio]
|
|
|
|
|
|
|
| 373 |
)
|
| 374 |
|
| 375 |
manual_audio_upload.change(
|
| 376 |
+
fn=load_audio, inputs=[manual_audio_upload], outputs=[input_audio]
|
|
|
|
|
|
|
| 377 |
)
|
| 378 |
|
| 379 |
# mask settings
|
| 380 |
with gr.Column():
|
|
|
|
|
|
|
| 381 |
presets = {
|
| 382 |
+
"unconditional": {
|
| 383 |
+
"periodic_p": 0,
|
| 384 |
+
"onset_mask_width": 0,
|
| 385 |
+
"beat_mask_width": 0,
|
| 386 |
+
"beat_mask_downbeats": False,
|
| 387 |
+
},
|
| 388 |
+
"slight periodic variation": {
|
| 389 |
+
"periodic_p": 5,
|
| 390 |
+
"onset_mask_width": 5,
|
| 391 |
+
"beat_mask_width": 0,
|
| 392 |
+
"beat_mask_downbeats": False,
|
| 393 |
+
},
|
| 394 |
+
"moderate periodic variation": {
|
| 395 |
+
"periodic_p": 13,
|
| 396 |
+
"onset_mask_width": 5,
|
| 397 |
+
"beat_mask_width": 0,
|
| 398 |
+
"beat_mask_downbeats": False,
|
| 399 |
+
},
|
| 400 |
+
"strong periodic variation": {
|
| 401 |
+
"periodic_p": 17,
|
| 402 |
+
"onset_mask_width": 5,
|
| 403 |
+
"beat_mask_width": 0,
|
| 404 |
+
"beat_mask_downbeats": False,
|
| 405 |
+
},
|
| 406 |
+
"very strong periodic variation": {
|
| 407 |
+
"periodic_p": 21,
|
| 408 |
+
"onset_mask_width": 5,
|
| 409 |
+
"beat_mask_width": 0,
|
| 410 |
+
"beat_mask_downbeats": False,
|
| 411 |
+
},
|
| 412 |
+
"beat-driven variation": {
|
| 413 |
+
"periodic_p": 0,
|
| 414 |
+
"onset_mask_width": 0,
|
| 415 |
+
"beat_mask_width": 50,
|
| 416 |
+
"beat_mask_downbeats": False,
|
| 417 |
+
},
|
| 418 |
+
"beat-driven variation (downbeats only)": {
|
| 419 |
+
"periodic_p": 0,
|
| 420 |
+
"onset_mask_width": 0,
|
| 421 |
+
"beat_mask_width": 50,
|
| 422 |
+
"beat_mask_downbeats": True,
|
| 423 |
+
},
|
| 424 |
+
"beat-driven variation (downbeats only, strong)": {
|
| 425 |
+
"periodic_p": 0,
|
| 426 |
+
"onset_mask_width": 0,
|
| 427 |
+
"beat_mask_width": 20,
|
| 428 |
+
"beat_mask_downbeats": True,
|
| 429 |
+
},
|
| 430 |
+
}
|
| 431 |
|
| 432 |
preset = gr.Dropdown(
|
| 433 |
label="preset",
|
|
|
|
| 445 |
value=3,
|
| 446 |
)
|
| 447 |
|
|
|
|
| 448 |
onset_mask_width = gr.Slider(
|
| 449 |
label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
|
| 450 |
minimum=0,
|
|
|
|
| 460 |
value=0,
|
| 461 |
)
|
| 462 |
beat_mask_downbeats = gr.Checkbox(
|
| 463 |
+
label="beat mask downbeats only?", value=False
|
|
|
|
| 464 |
)
|
| 465 |
|
| 466 |
n_mask_codebooks = gr.Number(
|
|
|
|
| 468 |
value=9,
|
| 469 |
)
|
| 470 |
|
|
|
|
| 471 |
with gr.Accordion("extras ", open=False):
|
| 472 |
pitch_shift_amt = gr.Slider(
|
| 473 |
label="pitch shift amount (semitones)",
|
|
|
|
| 481 |
label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
|
| 482 |
minimum=0.0,
|
| 483 |
maximum=1.0,
|
| 484 |
+
value=1.0,
|
| 485 |
)
|
| 486 |
|
| 487 |
periodic_w = gr.Slider(
|
|
|
|
| 516 |
return tuple(presets[_preset].values())
|
| 517 |
|
| 518 |
load_preset_button.click(
|
| 519 |
+
fn=load_preset, inputs=[preset], outputs=preset_outputs
|
|
|
|
|
|
|
| 520 |
)
|
| 521 |
|
|
|
|
| 522 |
with gr.Accordion("prefix/suffix prompts", open=False):
|
| 523 |
prefix_s = gr.Slider(
|
| 524 |
label="prefix hint length (seconds)",
|
| 525 |
minimum=0.0,
|
| 526 |
maximum=10.0,
|
| 527 |
+
value=0.0,
|
| 528 |
)
|
| 529 |
suffix_s = gr.Slider(
|
| 530 |
label="suffix hint length (seconds)",
|
| 531 |
minimum=0.0,
|
| 532 |
maximum=10.0,
|
| 533 |
+
value=0.0,
|
| 534 |
)
|
| 535 |
|
| 536 |
masktemp = gr.Slider(
|
| 537 |
+
label="mask temperature", minimum=0.0, maximum=100.0, value=1.5
|
|
|
|
|
|
|
|
|
|
| 538 |
)
|
| 539 |
sampletemp = gr.Slider(
|
| 540 |
label="sample temperature",
|
| 541 |
minimum=0.1,
|
| 542 |
maximum=10.0,
|
| 543 |
value=1.0,
|
| 544 |
+
step=0.001,
|
| 545 |
)
|
| 546 |
|
|
|
|
|
|
|
| 547 |
with gr.Accordion("sampling settings", open=False):
|
| 548 |
top_p = gr.Slider(
|
| 549 |
+
label="top p (0.0 = off)", minimum=0.0, maximum=1.0, value=0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 550 |
)
|
| 551 |
+
typical_filtering = gr.Checkbox(label="typical filtering ", value=False)
|
| 552 |
typical_mass = gr.Slider(
|
| 553 |
label="typical mass (should probably stay between 0.1 and 0.5)",
|
| 554 |
minimum=0.01,
|
| 555 |
maximum=0.99,
|
| 556 |
+
value=0.15,
|
| 557 |
)
|
| 558 |
typical_min_tokens = gr.Slider(
|
| 559 |
label="typical min tokens (should probably stay between 1 and 256)",
|
| 560 |
minimum=1,
|
| 561 |
maximum=256,
|
| 562 |
step=1,
|
| 563 |
+
value=64,
|
| 564 |
)
|
| 565 |
sample_cutoff = gr.Slider(
|
| 566 |
label="sample cutoff",
|
| 567 |
minimum=0.0,
|
| 568 |
maximum=1.0,
|
| 569 |
value=0.5,
|
| 570 |
+
step=0.01,
|
| 571 |
)
|
| 572 |
|
| 573 |
use_coarse2fine = gr.Checkbox(
|
| 574 |
+
label="use coarse2fine", value=True, visible=False
|
|
|
|
|
|
|
| 575 |
)
|
| 576 |
|
| 577 |
num_steps = gr.Slider(
|
|
|
|
| 579 |
minimum=1,
|
| 580 |
maximum=128,
|
| 581 |
step=1,
|
| 582 |
+
value=36,
|
| 583 |
)
|
| 584 |
|
| 585 |
dropout = gr.Slider(
|
| 586 |
+
label="mask dropout", minimum=0.0, maximum=1.0, step=0.01, value=0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
)
|
| 588 |
|
|
|
|
| 589 |
seed = gr.Number(
|
| 590 |
label="seed (0 for random)",
|
| 591 |
value=0,
|
| 592 |
precision=0,
|
| 593 |
)
|
| 594 |
|
|
|
|
|
|
|
| 595 |
# mask settings
|
| 596 |
with gr.Column():
|
|
|
|
| 597 |
# lora_choice = gr.Dropdown(
|
| 598 |
# label="lora choice",
|
| 599 |
# choices=list(loras.keys()),
|
|
|
|
| 603 |
|
| 604 |
vamp_button = gr.Button("generate (vamp)!!!")
|
| 605 |
output_audio = gr.Audio(
|
| 606 |
+
label="output audio", interactive=False, type="filepath"
|
|
|
|
|
|
|
| 607 |
)
|
| 608 |
|
| 609 |
notes_text = gr.Textbox(
|
| 610 |
label="type any notes about the generated audio here",
|
| 611 |
value="",
|
| 612 |
+
interactive=True,
|
| 613 |
)
|
| 614 |
save_button = gr.Button("save vamp")
|
| 615 |
download_file = gr.File(
|
| 616 |
+
label="vamp to download will appear here", interactive=False
|
|
|
|
| 617 |
)
|
| 618 |
use_as_input_button = gr.Button("use output as input")
|
| 619 |
|
| 620 |
thank_you = gr.Markdown("")
|
| 621 |
|
|
|
|
| 622 |
_inputs = {
|
| 623 |
+
input_audio,
|
| 624 |
+
num_steps,
|
| 625 |
+
masktemp,
|
| 626 |
+
sampletemp,
|
| 627 |
+
top_p,
|
| 628 |
+
prefix_s,
|
| 629 |
+
suffix_s,
|
| 630 |
+
rand_mask_intensity,
|
| 631 |
+
periodic_p,
|
| 632 |
+
periodic_w,
|
| 633 |
+
n_conditioning_codebooks,
|
| 634 |
+
dropout,
|
| 635 |
+
use_coarse2fine,
|
| 636 |
+
stretch_factor,
|
| 637 |
+
onset_mask_width,
|
| 638 |
+
typical_filtering,
|
| 639 |
+
typical_mass,
|
| 640 |
+
typical_min_tokens,
|
| 641 |
+
beat_mask_width,
|
| 642 |
+
beat_mask_downbeats,
|
| 643 |
+
seed,
|
| 644 |
+
# lora_choice,
|
| 645 |
+
n_mask_codebooks,
|
| 646 |
+
pitch_shift_amt,
|
| 647 |
+
sample_cutoff,
|
| 648 |
+
}
|
| 649 |
|
| 650 |
# connect widgets
|
| 651 |
vamp_button.click(
|
|
|
|
| 656 |
|
| 657 |
api_vamp_button = gr.Button("api vamp", visible=False)
|
| 658 |
api_vamp_button.click(
|
| 659 |
+
fn=api_vamp, inputs=_inputs, outputs=[output_audio], api_name="vamp"
|
|
|
|
|
|
|
|
|
|
| 660 |
)
|
| 661 |
|
| 662 |
use_as_input_button.click(
|
| 663 |
+
fn=lambda x: x, inputs=[output_audio], outputs=[input_audio]
|
|
|
|
|
|
|
| 664 |
)
|
| 665 |
|
| 666 |
save_button.click(
|
| 667 |
fn=save_vamp,
|
| 668 |
inputs=_inputs | {notes_text, output_audio},
|
| 669 |
+
outputs=[thank_you, download_file],
|
| 670 |
)
|
| 671 |
|
| 672 |
|
vampnet/scripts/exp/eval.py
CHANGED
|
@@ -1,20 +1,18 @@
|
|
| 1 |
from pathlib import Path
|
| 2 |
-
import os
|
| 3 |
-
from functools import partial
|
| 4 |
|
| 5 |
-
from frechet_audio_distance import FrechetAudioDistance
|
| 6 |
-
import pandas
|
| 7 |
import argbind
|
|
|
|
|
|
|
| 8 |
import torch
|
|
|
|
|
|
|
| 9 |
from tqdm import tqdm
|
| 10 |
|
| 11 |
-
import audiotools
|
| 12 |
-
from audiotools import AudioSignal
|
| 13 |
|
| 14 |
@argbind.bind(without_prefix=True)
|
| 15 |
def eval(
|
| 16 |
exp_dir: str = None,
|
| 17 |
-
baseline_key: str = "baseline",
|
| 18 |
audio_ext: str = ".wav",
|
| 19 |
):
|
| 20 |
assert exp_dir is not None
|
|
@@ -26,9 +24,9 @@ def eval(
|
|
| 26 |
# stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
|
| 27 |
mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
|
| 28 |
frechet = FrechetAudioDistance(
|
| 29 |
-
use_pca=False,
|
| 30 |
use_activation=False,
|
| 31 |
-
verbose=True,
|
| 32 |
audio_load_worker=4,
|
| 33 |
)
|
| 34 |
frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -36,19 +34,25 @@ def eval(
|
|
| 36 |
# figure out what conditions we have
|
| 37 |
conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
|
| 38 |
|
| 39 |
-
assert baseline_key in conditions,
|
|
|
|
|
|
|
| 40 |
conditions.remove(baseline_key)
|
| 41 |
|
| 42 |
print(f"Found {len(conditions)} conditions in {exp_dir}")
|
| 43 |
print(f"conditions: {conditions}")
|
| 44 |
|
| 45 |
-
baseline_dir = exp_dir / baseline_key
|
| 46 |
-
baseline_files = sorted(
|
|
|
|
|
|
|
| 47 |
|
| 48 |
metrics = []
|
| 49 |
for condition in tqdm(conditions):
|
| 50 |
cond_dir = exp_dir / condition
|
| 51 |
-
cond_files = sorted(
|
|
|
|
|
|
|
| 52 |
|
| 53 |
print(f"computing fad for {baseline_dir} and {cond_dir}")
|
| 54 |
frechet_score = frechet.score(baseline_dir, cond_dir)
|
|
@@ -57,11 +61,15 @@ def eval(
|
|
| 57 |
num_files = min(len(baseline_files), len(cond_files))
|
| 58 |
baseline_files = baseline_files[:num_files]
|
| 59 |
cond_files = cond_files[:num_files]
|
| 60 |
-
assert len(list(baseline_files)) == len(list(cond_files)),
|
|
|
|
|
|
|
| 61 |
|
| 62 |
def process(baseline_file, cond_file):
|
| 63 |
# make sure the files match (same name)
|
| 64 |
-
assert baseline_file.stem == cond_file.stem,
|
|
|
|
|
|
|
| 65 |
|
| 66 |
# load the files
|
| 67 |
baseline_sig = AudioSignal(str(baseline_file))
|
|
@@ -74,7 +82,9 @@ def eval(
|
|
| 74 |
if "inpaint" in condition:
|
| 75 |
ctx_amt = float(condition.split("_")[-1])
|
| 76 |
ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
|
| 77 |
-
print(
|
|
|
|
|
|
|
| 78 |
cond_sig.trim(ctx_samples, ctx_samples)
|
| 79 |
baseline_sig.trim(ctx_samples, ctx_samples)
|
| 80 |
|
|
@@ -88,15 +98,18 @@ def eval(
|
|
| 88 |
"file": baseline_file.stem,
|
| 89 |
}
|
| 90 |
|
| 91 |
-
print(
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
|
| 95 |
|
| 96 |
-
|
| 97 |
for mk in metric_keys:
|
| 98 |
stat = pandas.DataFrame(metrics)
|
| 99 |
-
stat = stat.groupby([
|
| 100 |
stat.to_csv(exp_dir / f"stats-{mk}.csv")
|
| 101 |
|
| 102 |
df = pandas.DataFrame(metrics)
|
|
@@ -107,4 +120,4 @@ if __name__ == "__main__":
|
|
| 107 |
args = argbind.parse_args()
|
| 108 |
|
| 109 |
with argbind.scope(args):
|
| 110 |
-
eval()
|
|
|
|
| 1 |
from pathlib import Path
|
|
|
|
|
|
|
| 2 |
|
|
|
|
|
|
|
| 3 |
import argbind
|
| 4 |
+
import audiotools
|
| 5 |
+
import pandas
|
| 6 |
import torch
|
| 7 |
+
from audiotools import AudioSignal
|
| 8 |
+
from frechet_audio_distance import FrechetAudioDistance
|
| 9 |
from tqdm import tqdm
|
| 10 |
|
|
|
|
|
|
|
| 11 |
|
| 12 |
@argbind.bind(without_prefix=True)
|
| 13 |
def eval(
|
| 14 |
exp_dir: str = None,
|
| 15 |
+
baseline_key: str = "baseline",
|
| 16 |
audio_ext: str = ".wav",
|
| 17 |
):
|
| 18 |
assert exp_dir is not None
|
|
|
|
| 24 |
# stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
|
| 25 |
mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
|
| 26 |
frechet = FrechetAudioDistance(
|
| 27 |
+
use_pca=False,
|
| 28 |
use_activation=False,
|
| 29 |
+
verbose=True,
|
| 30 |
audio_load_worker=4,
|
| 31 |
)
|
| 32 |
frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 34 |
# figure out what conditions we have
|
| 35 |
conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
|
| 36 |
|
| 37 |
+
assert baseline_key in conditions, (
|
| 38 |
+
f"baseline_key {baseline_key} not found in {exp_dir}"
|
| 39 |
+
)
|
| 40 |
conditions.remove(baseline_key)
|
| 41 |
|
| 42 |
print(f"Found {len(conditions)} conditions in {exp_dir}")
|
| 43 |
print(f"conditions: {conditions}")
|
| 44 |
|
| 45 |
+
baseline_dir = exp_dir / baseline_key
|
| 46 |
+
baseline_files = sorted(
|
| 47 |
+
list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem)
|
| 48 |
+
)
|
| 49 |
|
| 50 |
metrics = []
|
| 51 |
for condition in tqdm(conditions):
|
| 52 |
cond_dir = exp_dir / condition
|
| 53 |
+
cond_files = sorted(
|
| 54 |
+
list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem)
|
| 55 |
+
)
|
| 56 |
|
| 57 |
print(f"computing fad for {baseline_dir} and {cond_dir}")
|
| 58 |
frechet_score = frechet.score(baseline_dir, cond_dir)
|
|
|
|
| 61 |
num_files = min(len(baseline_files), len(cond_files))
|
| 62 |
baseline_files = baseline_files[:num_files]
|
| 63 |
cond_files = cond_files[:num_files]
|
| 64 |
+
assert len(list(baseline_files)) == len(list(cond_files)), (
|
| 65 |
+
f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}"
|
| 66 |
+
)
|
| 67 |
|
| 68 |
def process(baseline_file, cond_file):
|
| 69 |
# make sure the files match (same name)
|
| 70 |
+
assert baseline_file.stem == cond_file.stem, (
|
| 71 |
+
f"baseline file {baseline_file} and cond file {cond_file} do not match"
|
| 72 |
+
)
|
| 73 |
|
| 74 |
# load the files
|
| 75 |
baseline_sig = AudioSignal(str(baseline_file))
|
|
|
|
| 82 |
if "inpaint" in condition:
|
| 83 |
ctx_amt = float(condition.split("_")[-1])
|
| 84 |
ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
|
| 85 |
+
print(
|
| 86 |
+
f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}"
|
| 87 |
+
)
|
| 88 |
cond_sig.trim(ctx_samples, ctx_samples)
|
| 89 |
baseline_sig.trim(ctx_samples, ctx_samples)
|
| 90 |
|
|
|
|
| 98 |
"file": baseline_file.stem,
|
| 99 |
}
|
| 100 |
|
| 101 |
+
print(
|
| 102 |
+
f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}"
|
| 103 |
+
)
|
| 104 |
+
metrics.extend(
|
| 105 |
+
tqdm(map(process, baseline_files, cond_files), total=len(baseline_files))
|
| 106 |
+
)
|
| 107 |
|
| 108 |
metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
|
| 109 |
|
|
|
|
| 110 |
for mk in metric_keys:
|
| 111 |
stat = pandas.DataFrame(metrics)
|
| 112 |
+
stat = stat.groupby(["condition"])[mk].agg(["mean", "count", "std"])
|
| 113 |
stat.to_csv(exp_dir / f"stats-{mk}.csv")
|
| 114 |
|
| 115 |
df = pandas.DataFrame(metrics)
|
|
|
|
| 120 |
args = argbind.parse_args()
|
| 121 |
|
| 122 |
with argbind.scope(args):
|
| 123 |
+
eval()
|
vampnet/scripts/exp/experiment.py
CHANGED
|
@@ -1,48 +1,44 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
import random
|
| 3 |
-
from typing import List
|
| 4 |
-
import tempfile
|
| 5 |
import subprocess
|
|
|
|
|
|
|
| 6 |
|
| 7 |
import argbind
|
| 8 |
-
|
| 9 |
import torch
|
|
|
|
| 10 |
|
| 11 |
-
from vampnet.interface import Interface
|
| 12 |
from vampnet import mask as pmask
|
| 13 |
-
|
| 14 |
|
| 15 |
Interface: Interface = argbind.bind(Interface)
|
| 16 |
|
| 17 |
|
| 18 |
-
|
| 19 |
-
def calculate_bitrate(
|
| 20 |
-
interface, num_codebooks,
|
| 21 |
-
downsample_factor
|
| 22 |
-
):
|
| 23 |
bit_width = 10
|
| 24 |
sr = interface.codec.sample_rate
|
| 25 |
hop = interface.codec.hop_size
|
| 26 |
rate = (sr / hop) * ((bit_width * num_codebooks) / downsample_factor)
|
| 27 |
return rate
|
| 28 |
|
|
|
|
| 29 |
def baseline(sig, interface):
|
| 30 |
return interface.preprocess(sig)
|
| 31 |
|
|
|
|
| 32 |
def reconstructed(sig, interface):
|
| 33 |
-
return interface.to_signal(
|
| 34 |
-
|
| 35 |
-
)
|
| 36 |
|
| 37 |
def coarse2fine(sig, interface):
|
| 38 |
z = interface.encode(sig)
|
| 39 |
-
z = z[:, :interface.c2f.n_conditioning_codebooks, :]
|
| 40 |
|
| 41 |
z = interface.coarse_to_fine(z)
|
| 42 |
return interface.to_signal(z)
|
| 43 |
|
| 44 |
-
class CoarseCond:
|
| 45 |
|
|
|
|
| 46 |
def __init__(self, num_conditioning_codebooks, downsample_factor):
|
| 47 |
self.num_conditioning_codebooks = num_conditioning_codebooks
|
| 48 |
self.downsample_factor = downsample_factor
|
|
@@ -57,83 +53,85 @@ class CoarseCond:
|
|
| 57 |
zv = interface.coarse_to_fine(zv)
|
| 58 |
return interface.to_signal(zv)
|
| 59 |
|
|
|
|
| 60 |
def opus(sig, interface, bitrate=128):
|
| 61 |
sig = interface.preprocess(sig)
|
| 62 |
-
|
| 63 |
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
|
| 64 |
sig.write(f.name)
|
| 65 |
|
| 66 |
opus_name = Path(f.name).with_suffix(".opus")
|
| 67 |
# convert to opus
|
| 68 |
cmd = [
|
| 69 |
-
"ffmpeg",
|
| 70 |
-
"-
|
| 71 |
-
"-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
]
|
| 74 |
subprocess.run(cmd, check=True)
|
| 75 |
|
| 76 |
# convert back to wav
|
| 77 |
output_name = Path(f"{f.name}-opus").with_suffix(".wav")
|
| 78 |
-
cmd = [
|
| 79 |
-
"ffmpeg", "-y", "-i", opus_name,
|
| 80 |
-
output_name
|
| 81 |
-
]
|
| 82 |
|
| 83 |
subprocess.run(cmd, check=True)
|
| 84 |
|
| 85 |
-
sig = at.AudioSignal(
|
| 86 |
-
output_name,
|
| 87 |
-
sample_rate=sig.sample_rate
|
| 88 |
-
)
|
| 89 |
return sig
|
| 90 |
|
|
|
|
| 91 |
def mask_ratio_1_step(ratio=1.0):
|
| 92 |
def wrapper(sig, interface):
|
| 93 |
z = interface.encode(sig)
|
| 94 |
mask = pmask.linear_random(z, ratio)
|
| 95 |
zv = interface.coarse_vamp(
|
| 96 |
-
z,
|
| 97 |
mask,
|
| 98 |
-
sampling_steps=1,
|
| 99 |
)
|
| 100 |
|
| 101 |
return interface.to_signal(zv)
|
|
|
|
| 102 |
return wrapper
|
| 103 |
|
|
|
|
| 104 |
def num_sampling_steps(num_steps=1):
|
| 105 |
def wrapper(sig, interface: Interface):
|
| 106 |
z = interface.encode(sig)
|
| 107 |
mask = pmask.periodic_mask(z, 16)
|
| 108 |
zv = interface.coarse_vamp(
|
| 109 |
-
z,
|
| 110 |
mask,
|
| 111 |
-
sampling_steps=num_steps,
|
| 112 |
)
|
| 113 |
|
| 114 |
zv = interface.coarse_to_fine(zv)
|
| 115 |
return interface.to_signal(zv)
|
|
|
|
| 116 |
return wrapper
|
| 117 |
|
|
|
|
| 118 |
def beat_mask(ctx_time):
|
| 119 |
def wrapper(sig, interface):
|
| 120 |
beat_mask = interface.make_beat_mask(
|
| 121 |
-
sig,
|
| 122 |
-
before_beat_s=ctx_time/2,
|
| 123 |
-
after_beat_s=ctx_time/2,
|
| 124 |
-
invert=True
|
| 125 |
)
|
| 126 |
|
| 127 |
z = interface.encode(sig)
|
| 128 |
|
| 129 |
-
zv = interface.coarse_vamp(
|
| 130 |
-
z, beat_mask
|
| 131 |
-
)
|
| 132 |
|
| 133 |
zv = interface.coarse_to_fine(zv)
|
| 134 |
return interface.to_signal(zv)
|
|
|
|
| 135 |
return wrapper
|
| 136 |
|
|
|
|
| 137 |
def inpaint(ctx_time):
|
| 138 |
def wrapper(sig, interface: Interface):
|
| 139 |
z = interface.encode(sig)
|
|
@@ -141,22 +139,22 @@ def inpaint(ctx_time):
|
|
| 141 |
|
| 142 |
zv = interface.coarse_vamp(z, mask)
|
| 143 |
zv = interface.coarse_to_fine(zv)
|
| 144 |
-
|
| 145 |
return interface.to_signal(zv)
|
|
|
|
| 146 |
return wrapper
|
| 147 |
|
|
|
|
| 148 |
def token_noise(noise_amt):
|
| 149 |
def wrapper(sig, interface: Interface):
|
| 150 |
z = interface.encode(sig)
|
| 151 |
mask = pmask.random(z, noise_amt)
|
| 152 |
-
z = torch.where(
|
| 153 |
-
mask,
|
| 154 |
-
torch.randint_like(z, 0, interface.coarse.vocab_size),
|
| 155 |
-
z
|
| 156 |
-
)
|
| 157 |
return interface.to_signal(z)
|
|
|
|
| 158 |
return wrapper
|
| 159 |
|
|
|
|
| 160 |
EXP_REGISTRY = {}
|
| 161 |
|
| 162 |
EXP_REGISTRY["gen-compression"] = {
|
|
@@ -164,57 +162,63 @@ EXP_REGISTRY["gen-compression"] = {
|
|
| 164 |
"reconstructed": reconstructed,
|
| 165 |
"coarse2fine": coarse2fine,
|
| 166 |
**{
|
| 167 |
-
f"{n}_codebooks_downsampled_{x}x": CoarseCond(
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
)
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
f"token_noise_{x}": mask_ratio_1_step(ratio=x)
|
| 177 |
-
for x in [0.25, 0.5, 0.75]
|
| 178 |
},
|
| 179 |
-
|
| 180 |
}
|
| 181 |
|
| 182 |
|
| 183 |
EXP_REGISTRY["sampling-steps"] = {
|
| 184 |
# "codec": reconstructed,
|
| 185 |
-
**{f"steps_{n}": num_sampling_steps(n)
|
| 186 |
}
|
| 187 |
|
| 188 |
|
| 189 |
EXP_REGISTRY["musical-sampling"] = {
|
| 190 |
-
**{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
|
| 191 |
-
**{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
}
|
| 193 |
|
|
|
|
| 194 |
@argbind.bind(without_prefix=True)
|
| 195 |
def main(
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
at.util.seed(seed)
|
| 206 |
interface = Interface()
|
| 207 |
|
| 208 |
-
output_dir = Path(output_dir)
|
| 209 |
output_dir.mkdir(exist_ok=True, parents=True)
|
| 210 |
|
| 211 |
-
from audiotools.data.datasets import
|
| 212 |
|
| 213 |
loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
|
| 214 |
-
dataset = AudioDataset(
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
|
|
|
| 218 |
without_replacement=True,
|
| 219 |
)
|
| 220 |
|
|
@@ -223,7 +227,6 @@ def main(
|
|
| 223 |
else:
|
| 224 |
raise ValueError(f"Unknown exp_type {exp_type}")
|
| 225 |
|
| 226 |
-
|
| 227 |
indices = list(range(max_excerpts))
|
| 228 |
random.shuffle(indices)
|
| 229 |
for i in tqdm(indices):
|
|
@@ -237,8 +240,7 @@ def main(
|
|
| 237 |
|
| 238 |
sig = dataset[i]["signal"]
|
| 239 |
results = {
|
| 240 |
-
name: cond(sig, interface).cpu()
|
| 241 |
-
for name, cond in SAMPLE_CONDS.items()
|
| 242 |
}
|
| 243 |
|
| 244 |
for name, sig in results.items():
|
|
@@ -247,6 +249,7 @@ def main(
|
|
| 247 |
|
| 248 |
sig.write(o_dir / f"{i}.wav")
|
| 249 |
|
|
|
|
| 250 |
if __name__ == "__main__":
|
| 251 |
args = argbind.parse_args()
|
| 252 |
|
|
|
|
|
|
|
| 1 |
import random
|
|
|
|
|
|
|
| 2 |
import subprocess
|
| 3 |
+
import tempfile
|
| 4 |
+
from pathlib import Path
|
| 5 |
|
| 6 |
import argbind
|
| 7 |
+
import audiotools as at
|
| 8 |
import torch
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
|
|
|
|
| 11 |
from vampnet import mask as pmask
|
| 12 |
+
from vampnet.interface import Interface
|
| 13 |
|
| 14 |
Interface: Interface = argbind.bind(Interface)
|
| 15 |
|
| 16 |
|
| 17 |
+
def calculate_bitrate(interface, num_codebooks, downsample_factor):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
bit_width = 10
|
| 19 |
sr = interface.codec.sample_rate
|
| 20 |
hop = interface.codec.hop_size
|
| 21 |
rate = (sr / hop) * ((bit_width * num_codebooks) / downsample_factor)
|
| 22 |
return rate
|
| 23 |
|
| 24 |
+
|
| 25 |
def baseline(sig, interface):
|
| 26 |
return interface.preprocess(sig)
|
| 27 |
|
| 28 |
+
|
| 29 |
def reconstructed(sig, interface):
|
| 30 |
+
return interface.to_signal(interface.encode(sig))
|
| 31 |
+
|
|
|
|
| 32 |
|
| 33 |
def coarse2fine(sig, interface):
|
| 34 |
z = interface.encode(sig)
|
| 35 |
+
z = z[:, : interface.c2f.n_conditioning_codebooks, :]
|
| 36 |
|
| 37 |
z = interface.coarse_to_fine(z)
|
| 38 |
return interface.to_signal(z)
|
| 39 |
|
|
|
|
| 40 |
|
| 41 |
+
class CoarseCond:
|
| 42 |
def __init__(self, num_conditioning_codebooks, downsample_factor):
|
| 43 |
self.num_conditioning_codebooks = num_conditioning_codebooks
|
| 44 |
self.downsample_factor = downsample_factor
|
|
|
|
| 53 |
zv = interface.coarse_to_fine(zv)
|
| 54 |
return interface.to_signal(zv)
|
| 55 |
|
| 56 |
+
|
| 57 |
def opus(sig, interface, bitrate=128):
|
| 58 |
sig = interface.preprocess(sig)
|
| 59 |
+
|
| 60 |
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
|
| 61 |
sig.write(f.name)
|
| 62 |
|
| 63 |
opus_name = Path(f.name).with_suffix(".opus")
|
| 64 |
# convert to opus
|
| 65 |
cmd = [
|
| 66 |
+
"ffmpeg",
|
| 67 |
+
"-y",
|
| 68 |
+
"-i",
|
| 69 |
+
f.name,
|
| 70 |
+
"-c:a",
|
| 71 |
+
"libopus",
|
| 72 |
+
"-b:a",
|
| 73 |
+
f"{bitrate}",
|
| 74 |
+
opus_name,
|
| 75 |
]
|
| 76 |
subprocess.run(cmd, check=True)
|
| 77 |
|
| 78 |
# convert back to wav
|
| 79 |
output_name = Path(f"{f.name}-opus").with_suffix(".wav")
|
| 80 |
+
cmd = ["ffmpeg", "-y", "-i", opus_name, output_name]
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
subprocess.run(cmd, check=True)
|
| 83 |
|
| 84 |
+
sig = at.AudioSignal(output_name, sample_rate=sig.sample_rate)
|
|
|
|
|
|
|
|
|
|
| 85 |
return sig
|
| 86 |
|
| 87 |
+
|
| 88 |
def mask_ratio_1_step(ratio=1.0):
|
| 89 |
def wrapper(sig, interface):
|
| 90 |
z = interface.encode(sig)
|
| 91 |
mask = pmask.linear_random(z, ratio)
|
| 92 |
zv = interface.coarse_vamp(
|
| 93 |
+
z,
|
| 94 |
mask,
|
| 95 |
+
sampling_steps=1,
|
| 96 |
)
|
| 97 |
|
| 98 |
return interface.to_signal(zv)
|
| 99 |
+
|
| 100 |
return wrapper
|
| 101 |
|
| 102 |
+
|
| 103 |
def num_sampling_steps(num_steps=1):
|
| 104 |
def wrapper(sig, interface: Interface):
|
| 105 |
z = interface.encode(sig)
|
| 106 |
mask = pmask.periodic_mask(z, 16)
|
| 107 |
zv = interface.coarse_vamp(
|
| 108 |
+
z,
|
| 109 |
mask,
|
| 110 |
+
sampling_steps=num_steps,
|
| 111 |
)
|
| 112 |
|
| 113 |
zv = interface.coarse_to_fine(zv)
|
| 114 |
return interface.to_signal(zv)
|
| 115 |
+
|
| 116 |
return wrapper
|
| 117 |
|
| 118 |
+
|
| 119 |
def beat_mask(ctx_time):
|
| 120 |
def wrapper(sig, interface):
|
| 121 |
beat_mask = interface.make_beat_mask(
|
| 122 |
+
sig, before_beat_s=ctx_time / 2, after_beat_s=ctx_time / 2, invert=True
|
|
|
|
|
|
|
|
|
|
| 123 |
)
|
| 124 |
|
| 125 |
z = interface.encode(sig)
|
| 126 |
|
| 127 |
+
zv = interface.coarse_vamp(z, beat_mask)
|
|
|
|
|
|
|
| 128 |
|
| 129 |
zv = interface.coarse_to_fine(zv)
|
| 130 |
return interface.to_signal(zv)
|
| 131 |
+
|
| 132 |
return wrapper
|
| 133 |
|
| 134 |
+
|
| 135 |
def inpaint(ctx_time):
|
| 136 |
def wrapper(sig, interface: Interface):
|
| 137 |
z = interface.encode(sig)
|
|
|
|
| 139 |
|
| 140 |
zv = interface.coarse_vamp(z, mask)
|
| 141 |
zv = interface.coarse_to_fine(zv)
|
| 142 |
+
|
| 143 |
return interface.to_signal(zv)
|
| 144 |
+
|
| 145 |
return wrapper
|
| 146 |
|
| 147 |
+
|
| 148 |
def token_noise(noise_amt):
|
| 149 |
def wrapper(sig, interface: Interface):
|
| 150 |
z = interface.encode(sig)
|
| 151 |
mask = pmask.random(z, noise_amt)
|
| 152 |
+
z = torch.where(mask, torch.randint_like(z, 0, interface.coarse.vocab_size), z)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
return interface.to_signal(z)
|
| 154 |
+
|
| 155 |
return wrapper
|
| 156 |
|
| 157 |
+
|
| 158 |
EXP_REGISTRY = {}
|
| 159 |
|
| 160 |
EXP_REGISTRY["gen-compression"] = {
|
|
|
|
| 162 |
"reconstructed": reconstructed,
|
| 163 |
"coarse2fine": coarse2fine,
|
| 164 |
**{
|
| 165 |
+
f"{n}_codebooks_downsampled_{x}x": CoarseCond(
|
| 166 |
+
num_conditioning_codebooks=n, downsample_factor=x
|
| 167 |
+
)
|
| 168 |
+
for (n, x) in (
|
| 169 |
+
(1, 1), # 1 codebook, no downsampling
|
| 170 |
+
(4, 4), # 4 codebooks, downsampled 4x
|
| 171 |
+
(4, 16), # 4 codebooks, downsampled 16x
|
| 172 |
+
(4, 32), # 4 codebooks, downsampled 16x
|
| 173 |
+
)
|
|
|
|
|
|
|
| 174 |
},
|
| 175 |
+
**{f"token_noise_{x}": mask_ratio_1_step(ratio=x) for x in [0.25, 0.5, 0.75]},
|
| 176 |
}
|
| 177 |
|
| 178 |
|
| 179 |
EXP_REGISTRY["sampling-steps"] = {
|
| 180 |
# "codec": reconstructed,
|
| 181 |
+
**{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 36, 64, 72]},
|
| 182 |
}
|
| 183 |
|
| 184 |
|
| 185 |
EXP_REGISTRY["musical-sampling"] = {
|
| 186 |
+
**{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
|
| 187 |
+
**{
|
| 188 |
+
f"inpaint_{t}": inpaint(t)
|
| 189 |
+
for t in [
|
| 190 |
+
0.5,
|
| 191 |
+
1.0,
|
| 192 |
+
]
|
| 193 |
+
}, # multiply these by 2 (they go left and right)
|
| 194 |
}
|
| 195 |
|
| 196 |
+
|
| 197 |
@argbind.bind(without_prefix=True)
|
| 198 |
def main(
|
| 199 |
+
sources=[
|
| 200 |
+
"/media/CHONK/hugo/spotdl/val",
|
| 201 |
+
],
|
| 202 |
+
output_dir: str = "./samples",
|
| 203 |
+
max_excerpts: int = 2000,
|
| 204 |
+
exp_type: str = "gen-compression",
|
| 205 |
+
seed: int = 0,
|
| 206 |
+
ext: str = [".mp3"],
|
| 207 |
+
):
|
| 208 |
at.util.seed(seed)
|
| 209 |
interface = Interface()
|
| 210 |
|
| 211 |
+
output_dir = Path(output_dir)
|
| 212 |
output_dir.mkdir(exist_ok=True, parents=True)
|
| 213 |
|
| 214 |
+
from audiotools.data.datasets import AudioDataset, AudioLoader
|
| 215 |
|
| 216 |
loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
|
| 217 |
+
dataset = AudioDataset(
|
| 218 |
+
loader,
|
| 219 |
+
sample_rate=interface.codec.sample_rate,
|
| 220 |
+
duration=interface.coarse.chunk_size_s,
|
| 221 |
+
n_examples=max_excerpts,
|
| 222 |
without_replacement=True,
|
| 223 |
)
|
| 224 |
|
|
|
|
| 227 |
else:
|
| 228 |
raise ValueError(f"Unknown exp_type {exp_type}")
|
| 229 |
|
|
|
|
| 230 |
indices = list(range(max_excerpts))
|
| 231 |
random.shuffle(indices)
|
| 232 |
for i in tqdm(indices):
|
|
|
|
| 240 |
|
| 241 |
sig = dataset[i]["signal"]
|
| 242 |
results = {
|
| 243 |
+
name: cond(sig, interface).cpu() for name, cond in SAMPLE_CONDS.items()
|
|
|
|
| 244 |
}
|
| 245 |
|
| 246 |
for name, sig in results.items():
|
|
|
|
| 249 |
|
| 250 |
sig.write(o_dir / f"{i}.wav")
|
| 251 |
|
| 252 |
+
|
| 253 |
if __name__ == "__main__":
|
| 254 |
args = argbind.parse_args()
|
| 255 |
|
vampnet/scripts/exp/fine_tune.py
CHANGED
|
@@ -1,20 +1,21 @@
|
|
| 1 |
-
import argbind
|
| 2 |
from pathlib import Path
|
| 3 |
-
import yaml
|
| 4 |
from typing import List
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
|
| 9 |
"""example output: (yaml)
|
| 10 |
|
| 11 |
"""
|
| 12 |
|
|
|
|
| 13 |
@argbind.bind(without_prefix=True, positional=True)
|
| 14 |
def fine_tune(audio_files_or_folders: List[str], name: str):
|
| 15 |
|
| 16 |
conf_dir = Path("conf")
|
| 17 |
-
assert conf_dir.exists(),
|
|
|
|
|
|
|
| 18 |
|
| 19 |
conf_dir = conf_dir / "generated"
|
| 20 |
conf_dir.mkdir(exist_ok=True)
|
|
@@ -35,7 +36,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
| 35 |
"AudioDataset.duration": 3.0,
|
| 36 |
"AudioDataset.loudness_cutoff": -40.0,
|
| 37 |
"save_path": f"./runs/{name}/c2f",
|
| 38 |
-
"fine_tune_checkpoint": "./models/vampnet/c2f.pth"
|
| 39 |
}
|
| 40 |
|
| 41 |
finetune_coarse_conf = {
|
|
@@ -44,15 +45,13 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
| 44 |
"train/AudioLoader.sources": audio_files_or_folders,
|
| 45 |
"val/AudioLoader.sources": audio_files_or_folders,
|
| 46 |
"save_path": f"./runs/{name}/coarse",
|
| 47 |
-
"fine_tune_checkpoint": "./models/vampnet/coarse.pth"
|
| 48 |
}
|
| 49 |
|
| 50 |
interface_conf = {
|
| 51 |
"Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
|
| 52 |
-
|
| 53 |
"Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
|
| 54 |
"Interface.wavebeat_ckpt": "./models/wavebeat.pth",
|
| 55 |
-
|
| 56 |
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
| 57 |
"AudioLoader.sources": [audio_files_or_folders],
|
| 58 |
}
|
|
@@ -63,19 +62,17 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
| 63 |
|
| 64 |
with open(finetune_dir / "coarse.yml", "w") as f:
|
| 65 |
yaml.dump(finetune_coarse_conf, f)
|
| 66 |
-
|
| 67 |
-
with open(finetune_dir / "interface.yml", "w") as f:
|
| 68 |
yaml.dump(interface_conf, f)
|
| 69 |
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` ")
|
| 72 |
|
| 73 |
if __name__ == "__main__":
|
| 74 |
args = argbind.parse_args()
|
| 75 |
|
| 76 |
with argbind.scope(args):
|
| 77 |
fine_tune()
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
| 1 |
from pathlib import Path
|
|
|
|
| 2 |
from typing import List
|
| 3 |
|
| 4 |
+
import argbind
|
| 5 |
+
import yaml
|
| 6 |
|
| 7 |
"""example output: (yaml)
|
| 8 |
|
| 9 |
"""
|
| 10 |
|
| 11 |
+
|
| 12 |
@argbind.bind(without_prefix=True, positional=True)
|
| 13 |
def fine_tune(audio_files_or_folders: List[str], name: str):
|
| 14 |
|
| 15 |
conf_dir = Path("conf")
|
| 16 |
+
assert conf_dir.exists(), (
|
| 17 |
+
"conf directory not found. are you in the vampnet directory?"
|
| 18 |
+
)
|
| 19 |
|
| 20 |
conf_dir = conf_dir / "generated"
|
| 21 |
conf_dir.mkdir(exist_ok=True)
|
|
|
|
| 36 |
"AudioDataset.duration": 3.0,
|
| 37 |
"AudioDataset.loudness_cutoff": -40.0,
|
| 38 |
"save_path": f"./runs/{name}/c2f",
|
| 39 |
+
"fine_tune_checkpoint": "./models/vampnet/c2f.pth",
|
| 40 |
}
|
| 41 |
|
| 42 |
finetune_coarse_conf = {
|
|
|
|
| 45 |
"train/AudioLoader.sources": audio_files_or_folders,
|
| 46 |
"val/AudioLoader.sources": audio_files_or_folders,
|
| 47 |
"save_path": f"./runs/{name}/coarse",
|
| 48 |
+
"fine_tune_checkpoint": "./models/vampnet/coarse.pth",
|
| 49 |
}
|
| 50 |
|
| 51 |
interface_conf = {
|
| 52 |
"Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
|
|
|
|
| 53 |
"Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
|
| 54 |
"Interface.wavebeat_ckpt": "./models/wavebeat.pth",
|
|
|
|
| 55 |
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
| 56 |
"AudioLoader.sources": [audio_files_or_folders],
|
| 57 |
}
|
|
|
|
| 62 |
|
| 63 |
with open(finetune_dir / "coarse.yml", "w") as f:
|
| 64 |
yaml.dump(finetune_coarse_conf, f)
|
| 65 |
+
|
| 66 |
+
with open(finetune_dir / "interface.yml", "w") as f:
|
| 67 |
yaml.dump(interface_conf, f)
|
| 68 |
|
| 69 |
+
print(
|
| 70 |
+
f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` "
|
| 71 |
+
)
|
| 72 |
|
|
|
|
| 73 |
|
| 74 |
if __name__ == "__main__":
|
| 75 |
args = argbind.parse_args()
|
| 76 |
|
| 77 |
with argbind.scope(args):
|
| 78 |
fine_tune()
|
|
|
|
|
|
|
|
|
|
|
|
vampnet/scripts/exp/train.py
CHANGED
|
@@ -1,36 +1,33 @@
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import warnings
|
|
|
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Optional
|
| 6 |
-
from dataclasses import dataclass
|
| 7 |
|
| 8 |
import argbind
|
| 9 |
import audiotools as at
|
|
|
|
| 10 |
import torch
|
|
|
|
| 11 |
import torch.nn as nn
|
| 12 |
from audiotools import AudioSignal
|
| 13 |
from audiotools.data import transforms
|
|
|
|
| 14 |
from einops import rearrange
|
|
|
|
|
|
|
|
|
|
| 15 |
from rich import pretty
|
| 16 |
from rich.traceback import install
|
| 17 |
from torch.utils.tensorboard import SummaryWriter
|
| 18 |
|
| 19 |
import vampnet
|
| 20 |
-
from vampnet.modules.transformer import VampNet
|
| 21 |
-
from vampnet.util import codebook_unflatten, codebook_flatten
|
| 22 |
from vampnet import mask as pmask
|
| 23 |
-
|
| 24 |
-
from
|
| 25 |
-
|
| 26 |
-
from audiotools.ml.decorators import (
|
| 27 |
-
timer, Tracker, when
|
| 28 |
-
)
|
| 29 |
-
|
| 30 |
-
import loralib as lora
|
| 31 |
|
| 32 |
-
|
| 33 |
-
torch._dynamo.config.verbose=True
|
| 34 |
|
| 35 |
|
| 36 |
# Enable cudnn autotuner to speed up training
|
|
@@ -50,11 +47,15 @@ AdamW = argbind.bind(torch.optim.AdamW)
|
|
| 50 |
NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)
|
| 51 |
|
| 52 |
# transforms
|
| 53 |
-
filter_fn = lambda fn:
|
| 54 |
-
"
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
|
| 59 |
|
| 60 |
# model
|
|
@@ -106,13 +107,14 @@ def flip_coin(shape, p, rng):
|
|
| 106 |
|
| 107 |
|
| 108 |
def num_params_hook(o, p):
|
| 109 |
-
return o + f" {p/1e6:<.3f}M params."
|
| 110 |
|
| 111 |
|
| 112 |
def add_num_params_repr_hook(model):
|
| 113 |
-
import numpy as np
|
| 114 |
from functools import partial
|
| 115 |
|
|
|
|
|
|
|
| 116 |
for n, m in model.named_modules():
|
| 117 |
o = m.extra_repr()
|
| 118 |
p = sum([np.prod(p.size()) for p in m.parameters()])
|
|
@@ -149,6 +151,7 @@ def accuracy(
|
|
| 149 |
|
| 150 |
return accuracy
|
| 151 |
|
|
|
|
| 152 |
def _metrics(z_hat, r, target, flat_mask, output):
|
| 153 |
for r_range in [(0, 0.5), (0.5, 1.0)]:
|
| 154 |
unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
|
|
@@ -219,7 +222,7 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
|
|
| 219 |
mask = pmask.random(z, r)
|
| 220 |
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
| 221 |
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
| 222 |
-
|
| 223 |
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
| 224 |
|
| 225 |
dtype = torch.bfloat16 if accel.amp else None
|
|
@@ -246,13 +249,11 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
|
|
| 246 |
output=output,
|
| 247 |
)
|
| 248 |
|
| 249 |
-
|
| 250 |
accel.backward(output["loss"])
|
| 251 |
|
| 252 |
output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
|
| 253 |
output["other/batch_size"] = z.shape[0]
|
| 254 |
|
| 255 |
-
|
| 256 |
accel.scaler.unscale_(state.optimizer)
|
| 257 |
output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
|
| 258 |
state.model.parameters(), state.grad_clip_val
|
|
@@ -264,7 +265,6 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
|
|
| 264 |
state.scheduler.step()
|
| 265 |
accel.update()
|
| 266 |
|
| 267 |
-
|
| 268 |
return {k: v for k, v in sorted(output.items())}
|
| 269 |
|
| 270 |
|
|
@@ -295,9 +295,7 @@ def val_loop(state: State, batch: dict, accel: Accelerator):
|
|
| 295 |
z[:, vn.n_conditioning_codebooks :, :],
|
| 296 |
)
|
| 297 |
|
| 298 |
-
flat_mask = codebook_flatten(
|
| 299 |
-
mask[:, vn.n_conditioning_codebooks :, :]
|
| 300 |
-
)
|
| 301 |
|
| 302 |
output = {}
|
| 303 |
# replace target with ignore index for masked tokens
|
|
@@ -338,16 +336,16 @@ def checkpoint(state, save_iters, save_path, fine_tune):
|
|
| 338 |
tags.append(f"{state.tracker.step // 1000}k")
|
| 339 |
|
| 340 |
if state.tracker.is_best("val", "loss"):
|
| 341 |
-
state.tracker.print(
|
| 342 |
tags.append("best")
|
| 343 |
|
| 344 |
if fine_tune:
|
| 345 |
-
for tag in tags:
|
| 346 |
-
# save the lora model
|
| 347 |
(Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
|
| 348 |
torch.save(
|
| 349 |
-
lora.lora_state_dict(accel.unwrap(state.model)),
|
| 350 |
-
f"{save_path}/{tag}/lora.pth"
|
| 351 |
)
|
| 352 |
|
| 353 |
for tag in tags:
|
|
@@ -383,7 +381,7 @@ def save_sampled(state, z, writer):
|
|
| 383 |
|
| 384 |
def save_imputation(state, z, val_idx, writer):
|
| 385 |
n_prefix = int(z.shape[-1] * 0.25)
|
| 386 |
-
n_suffix = int(z.shape[-1] *
|
| 387 |
|
| 388 |
vn = accel.unwrap(state.model)
|
| 389 |
|
|
@@ -402,8 +400,8 @@ def save_imputation(state, z, val_idx, writer):
|
|
| 402 |
time_steps=z.shape[-1],
|
| 403 |
start_tokens=z[i][None, ...],
|
| 404 |
mask=mask[i][None, ...],
|
| 405 |
-
)
|
| 406 |
-
)
|
| 407 |
imputed = AudioSignal.batch(imputed)
|
| 408 |
|
| 409 |
for i in range(len(val_idx)):
|
|
@@ -443,7 +441,6 @@ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
|
|
| 443 |
|
| 444 |
r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
|
| 445 |
|
| 446 |
-
|
| 447 |
mask = pmask.random(z, r)
|
| 448 |
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
| 449 |
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
|
@@ -479,7 +476,6 @@ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
|
|
| 479 |
save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
|
| 480 |
|
| 481 |
|
| 482 |
-
|
| 483 |
@argbind.bind(without_prefix=True)
|
| 484 |
def load(
|
| 485 |
args,
|
|
@@ -499,11 +495,12 @@ def load(
|
|
| 499 |
if args["fine_tune"]:
|
| 500 |
assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
|
| 501 |
model = torch.compile(
|
| 502 |
-
VampNet.load(
|
| 503 |
-
|
|
|
|
| 504 |
)
|
| 505 |
)
|
| 506 |
-
|
| 507 |
if resume:
|
| 508 |
kwargs = {
|
| 509 |
"folder": f"{save_path}/{tag}",
|
|
@@ -518,16 +515,11 @@ def load(
|
|
| 518 |
f"Could not find a VampNet checkpoint in {kwargs['folder']}"
|
| 519 |
)
|
| 520 |
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
model = torch.compile(VampNet()) if model is None else model
|
| 525 |
model = accel.prepare_model(model)
|
| 526 |
|
| 527 |
# assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
|
| 528 |
-
assert (
|
| 529 |
-
accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size
|
| 530 |
-
)
|
| 531 |
|
| 532 |
optimizer = AdamW(model.parameters(), use_zero=accel.use_ddp)
|
| 533 |
scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
|
|
@@ -538,13 +530,13 @@ def load(
|
|
| 538 |
scheduler.load_state_dict(v_extra["scheduler.pth"])
|
| 539 |
if "tracker.pth" in v_extra:
|
| 540 |
tracker.load_state_dict(v_extra["tracker.pth"])
|
| 541 |
-
|
| 542 |
criterion = CrossEntropyLoss()
|
| 543 |
|
| 544 |
sample_rate = codec.sample_rate
|
| 545 |
|
| 546 |
# a better rng for sampling from our schedule
|
| 547 |
-
rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
|
| 548 |
|
| 549 |
# log a model summary w/ num params
|
| 550 |
if accel.local_rank == 0:
|
|
@@ -577,13 +569,19 @@ def train(
|
|
| 577 |
codec_ckpt: str = None,
|
| 578 |
save_path: str = "ckpt",
|
| 579 |
num_iters: int = int(1000e6),
|
| 580 |
-
save_iters: list = [
|
| 581 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
val_freq: int = 1000,
|
| 583 |
batch_size: int = 12,
|
| 584 |
val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
| 585 |
num_workers: int = 10,
|
| 586 |
-
fine_tune: bool = False,
|
| 587 |
):
|
| 588 |
assert codec_ckpt is not None, "codec_ckpt is required"
|
| 589 |
|
|
@@ -600,11 +598,7 @@ def train(
|
|
| 600 |
)
|
| 601 |
|
| 602 |
# load the codec model
|
| 603 |
-
state: State = load(
|
| 604 |
-
args=args,
|
| 605 |
-
accel=accel,
|
| 606 |
-
tracker=tracker,
|
| 607 |
-
save_path=save_path)
|
| 608 |
print("initialized state.")
|
| 609 |
|
| 610 |
train_dataloader = accel.prepare_dataloader(
|
|
@@ -624,8 +618,6 @@ def train(
|
|
| 624 |
)
|
| 625 |
print("initialized dataloader.")
|
| 626 |
|
| 627 |
-
|
| 628 |
-
|
| 629 |
if fine_tune:
|
| 630 |
lora.mark_only_lora_as_trainable(state.model)
|
| 631 |
print("marked only lora as trainable.")
|
|
@@ -658,10 +650,11 @@ def train(
|
|
| 658 |
if tracker.step % val_freq == 0 or last_iter:
|
| 659 |
validate(state, val_dataloader, accel)
|
| 660 |
checkpoint(
|
| 661 |
-
state=state,
|
| 662 |
-
save_iters=save_iters,
|
| 663 |
-
save_path=save_path,
|
| 664 |
-
fine_tune=fine_tune
|
|
|
|
| 665 |
|
| 666 |
# Reset validation progress bar, print summary since last validation.
|
| 667 |
tracker.done("val", f"Iteration {tracker.step}")
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
import warnings
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Optional
|
|
|
|
| 7 |
|
| 8 |
import argbind
|
| 9 |
import audiotools as at
|
| 10 |
+
import loralib as lora
|
| 11 |
import torch
|
| 12 |
+
import torch._dynamo
|
| 13 |
import torch.nn as nn
|
| 14 |
from audiotools import AudioSignal
|
| 15 |
from audiotools.data import transforms
|
| 16 |
+
from audiotools.ml.decorators import Tracker, timer, when
|
| 17 |
from einops import rearrange
|
| 18 |
+
|
| 19 |
+
# from dac.model.dac import DAC
|
| 20 |
+
from lac.model.lac import LAC as DAC
|
| 21 |
from rich import pretty
|
| 22 |
from rich.traceback import install
|
| 23 |
from torch.utils.tensorboard import SummaryWriter
|
| 24 |
|
| 25 |
import vampnet
|
|
|
|
|
|
|
| 26 |
from vampnet import mask as pmask
|
| 27 |
+
from vampnet.modules.transformer import VampNet
|
| 28 |
+
from vampnet.util import codebook_flatten, codebook_unflatten
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
torch._dynamo.config.verbose = True
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
# Enable cudnn autotuner to speed up training
|
|
|
|
| 47 |
NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)
|
| 48 |
|
| 49 |
# transforms
|
| 50 |
+
filter_fn = lambda fn: (
|
| 51 |
+
hasattr(fn, "transform")
|
| 52 |
+
and fn.__qualname__
|
| 53 |
+
not in [
|
| 54 |
+
"BaseTransform",
|
| 55 |
+
"Compose",
|
| 56 |
+
"Choose",
|
| 57 |
+
]
|
| 58 |
+
)
|
| 59 |
tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
|
| 60 |
|
| 61 |
# model
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
def num_params_hook(o, p):
|
| 110 |
+
return o + f" {p / 1e6:<.3f}M params."
|
| 111 |
|
| 112 |
|
| 113 |
def add_num_params_repr_hook(model):
|
|
|
|
| 114 |
from functools import partial
|
| 115 |
|
| 116 |
+
import numpy as np
|
| 117 |
+
|
| 118 |
for n, m in model.named_modules():
|
| 119 |
o = m.extra_repr()
|
| 120 |
p = sum([np.prod(p.size()) for p in m.parameters()])
|
|
|
|
| 151 |
|
| 152 |
return accuracy
|
| 153 |
|
| 154 |
+
|
| 155 |
def _metrics(z_hat, r, target, flat_mask, output):
|
| 156 |
for r_range in [(0, 0.5), (0.5, 1.0)]:
|
| 157 |
unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
|
|
|
|
| 222 |
mask = pmask.random(z, r)
|
| 223 |
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
| 224 |
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
| 225 |
+
|
| 226 |
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
| 227 |
|
| 228 |
dtype = torch.bfloat16 if accel.amp else None
|
|
|
|
| 249 |
output=output,
|
| 250 |
)
|
| 251 |
|
|
|
|
| 252 |
accel.backward(output["loss"])
|
| 253 |
|
| 254 |
output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
|
| 255 |
output["other/batch_size"] = z.shape[0]
|
| 256 |
|
|
|
|
| 257 |
accel.scaler.unscale_(state.optimizer)
|
| 258 |
output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
|
| 259 |
state.model.parameters(), state.grad_clip_val
|
|
|
|
| 265 |
state.scheduler.step()
|
| 266 |
accel.update()
|
| 267 |
|
|
|
|
| 268 |
return {k: v for k, v in sorted(output.items())}
|
| 269 |
|
| 270 |
|
|
|
|
| 295 |
z[:, vn.n_conditioning_codebooks :, :],
|
| 296 |
)
|
| 297 |
|
| 298 |
+
flat_mask = codebook_flatten(mask[:, vn.n_conditioning_codebooks :, :])
|
|
|
|
|
|
|
| 299 |
|
| 300 |
output = {}
|
| 301 |
# replace target with ignore index for masked tokens
|
|
|
|
| 336 |
tags.append(f"{state.tracker.step // 1000}k")
|
| 337 |
|
| 338 |
if state.tracker.is_best("val", "loss"):
|
| 339 |
+
state.tracker.print("Best model so far")
|
| 340 |
tags.append("best")
|
| 341 |
|
| 342 |
if fine_tune:
|
| 343 |
+
for tag in tags:
|
| 344 |
+
# save the lora model
|
| 345 |
(Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
|
| 346 |
torch.save(
|
| 347 |
+
lora.lora_state_dict(accel.unwrap(state.model)),
|
| 348 |
+
f"{save_path}/{tag}/lora.pth",
|
| 349 |
)
|
| 350 |
|
| 351 |
for tag in tags:
|
|
|
|
| 381 |
|
| 382 |
def save_imputation(state, z, val_idx, writer):
|
| 383 |
n_prefix = int(z.shape[-1] * 0.25)
|
| 384 |
+
n_suffix = int(z.shape[-1] * 0.25)
|
| 385 |
|
| 386 |
vn = accel.unwrap(state.model)
|
| 387 |
|
|
|
|
| 400 |
time_steps=z.shape[-1],
|
| 401 |
start_tokens=z[i][None, ...],
|
| 402 |
mask=mask[i][None, ...],
|
| 403 |
+
)
|
| 404 |
+
)
|
| 405 |
imputed = AudioSignal.batch(imputed)
|
| 406 |
|
| 407 |
for i in range(len(val_idx)):
|
|
|
|
| 441 |
|
| 442 |
r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
|
| 443 |
|
|
|
|
| 444 |
mask = pmask.random(z, r)
|
| 445 |
mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
|
| 446 |
z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
|
|
|
|
| 476 |
save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
|
| 477 |
|
| 478 |
|
|
|
|
| 479 |
@argbind.bind(without_prefix=True)
|
| 480 |
def load(
|
| 481 |
args,
|
|
|
|
| 495 |
if args["fine_tune"]:
|
| 496 |
assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
|
| 497 |
model = torch.compile(
|
| 498 |
+
VampNet.load(
|
| 499 |
+
location=Path(fine_tune_checkpoint),
|
| 500 |
+
map_location="cpu",
|
| 501 |
)
|
| 502 |
)
|
| 503 |
+
|
| 504 |
if resume:
|
| 505 |
kwargs = {
|
| 506 |
"folder": f"{save_path}/{tag}",
|
|
|
|
| 515 |
f"Could not find a VampNet checkpoint in {kwargs['folder']}"
|
| 516 |
)
|
| 517 |
|
|
|
|
|
|
|
|
|
|
| 518 |
model = torch.compile(VampNet()) if model is None else model
|
| 519 |
model = accel.prepare_model(model)
|
| 520 |
|
| 521 |
# assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
|
| 522 |
+
assert accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size
|
|
|
|
|
|
|
| 523 |
|
| 524 |
optimizer = AdamW(model.parameters(), use_zero=accel.use_ddp)
|
| 525 |
scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
|
|
|
|
| 530 |
scheduler.load_state_dict(v_extra["scheduler.pth"])
|
| 531 |
if "tracker.pth" in v_extra:
|
| 532 |
tracker.load_state_dict(v_extra["tracker.pth"])
|
| 533 |
+
|
| 534 |
criterion = CrossEntropyLoss()
|
| 535 |
|
| 536 |
sample_rate = codec.sample_rate
|
| 537 |
|
| 538 |
# a better rng for sampling from our schedule
|
| 539 |
+
rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
|
| 540 |
|
| 541 |
# log a model summary w/ num params
|
| 542 |
if accel.local_rank == 0:
|
|
|
|
| 569 |
codec_ckpt: str = None,
|
| 570 |
save_path: str = "ckpt",
|
| 571 |
num_iters: int = int(1000e6),
|
| 572 |
+
save_iters: list = [
|
| 573 |
+
10000,
|
| 574 |
+
50000,
|
| 575 |
+
100000,
|
| 576 |
+
300000,
|
| 577 |
+
500000,
|
| 578 |
+
],
|
| 579 |
+
sample_freq: int = 10000,
|
| 580 |
val_freq: int = 1000,
|
| 581 |
batch_size: int = 12,
|
| 582 |
val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
| 583 |
num_workers: int = 10,
|
| 584 |
+
fine_tune: bool = False,
|
| 585 |
):
|
| 586 |
assert codec_ckpt is not None, "codec_ckpt is required"
|
| 587 |
|
|
|
|
| 598 |
)
|
| 599 |
|
| 600 |
# load the codec model
|
| 601 |
+
state: State = load(args=args, accel=accel, tracker=tracker, save_path=save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 602 |
print("initialized state.")
|
| 603 |
|
| 604 |
train_dataloader = accel.prepare_dataloader(
|
|
|
|
| 618 |
)
|
| 619 |
print("initialized dataloader.")
|
| 620 |
|
|
|
|
|
|
|
| 621 |
if fine_tune:
|
| 622 |
lora.mark_only_lora_as_trainable(state.model)
|
| 623 |
print("marked only lora as trainable.")
|
|
|
|
| 650 |
if tracker.step % val_freq == 0 or last_iter:
|
| 651 |
validate(state, val_dataloader, accel)
|
| 652 |
checkpoint(
|
| 653 |
+
state=state,
|
| 654 |
+
save_iters=save_iters,
|
| 655 |
+
save_path=save_path,
|
| 656 |
+
fine_tune=fine_tune,
|
| 657 |
+
)
|
| 658 |
|
| 659 |
# Reset validation progress bar, print summary since last validation.
|
| 660 |
tracker.done("val", f"Iteration {tracker.step}")
|
vampnet/scripts/utils/data/augment.py
CHANGED
|
@@ -1,17 +1,12 @@
|
|
| 1 |
from pathlib import Path
|
| 2 |
|
| 3 |
-
import audiotools as at
|
| 4 |
-
from audiotools import AudioSignal
|
| 5 |
-
|
| 6 |
import argbind
|
| 7 |
-
import
|
| 8 |
import torch
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
from torch_pitch_shift import
|
| 12 |
-
from torch_time_stretch import
|
| 13 |
-
|
| 14 |
-
from audiotools.core.util import sample_from_dist
|
| 15 |
|
| 16 |
|
| 17 |
@argbind.bind(without_prefix=True)
|
|
@@ -20,11 +15,11 @@ def augment(
|
|
| 20 |
dest_folder: Path = None,
|
| 21 |
n_augmentations: int = 10,
|
| 22 |
):
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
"""
|
| 29 |
assert audio_folder is not None
|
| 30 |
assert dest_folder is not None
|
|
@@ -37,24 +32,31 @@ def augment(
|
|
| 37 |
|
| 38 |
src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")
|
| 39 |
|
| 40 |
-
|
| 41 |
for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
|
| 42 |
# apply pedalboard transforms
|
| 43 |
for j in range(n_augmentations):
|
| 44 |
# pitch shift between -7 and 7 semitones
|
| 45 |
import random
|
|
|
|
| 46 |
dst = chunk.clone()
|
| 47 |
dst.samples = pitch_shift(
|
| 48 |
-
dst.samples,
|
| 49 |
-
shift=random.choice(
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
| 52 |
)
|
| 53 |
dst.samples = time_stretch(
|
| 54 |
dst.samples,
|
| 55 |
-
stretch=random.choice(
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
)
|
| 59 |
|
| 60 |
dst.cpu().write(subdir / f"{i}-{j}.wav")
|
|
@@ -64,4 +66,4 @@ if __name__ == "__main__":
|
|
| 64 |
args = argbind.parse_args()
|
| 65 |
|
| 66 |
with argbind.scope(args):
|
| 67 |
-
augment()
|
|
|
|
| 1 |
from pathlib import Path
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
import argbind
|
| 4 |
+
import audiotools as at
|
| 5 |
import torch
|
| 6 |
+
import tqdm
|
| 7 |
+
from audiotools import AudioSignal
|
| 8 |
+
from torch_pitch_shift import get_fast_shifts, pitch_shift
|
| 9 |
+
from torch_time_stretch import get_fast_stretches, time_stretch
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
@argbind.bind(without_prefix=True)
|
|
|
|
| 15 |
dest_folder: Path = None,
|
| 16 |
n_augmentations: int = 10,
|
| 17 |
):
|
| 18 |
+
"""
|
| 19 |
+
Augment a folder of audio files by applying audiotools and pedalboard transforms.
|
| 20 |
|
| 21 |
+
The dest foler will contain a folder for each of the clean dataset's files.
|
| 22 |
+
Under each of these folders, there will be a clean file and many augmented files.
|
| 23 |
"""
|
| 24 |
assert audio_folder is not None
|
| 25 |
assert dest_folder is not None
|
|
|
|
| 32 |
|
| 33 |
src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")
|
| 34 |
|
|
|
|
| 35 |
for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
|
| 36 |
# apply pedalboard transforms
|
| 37 |
for j in range(n_augmentations):
|
| 38 |
# pitch shift between -7 and 7 semitones
|
| 39 |
import random
|
| 40 |
+
|
| 41 |
dst = chunk.clone()
|
| 42 |
dst.samples = pitch_shift(
|
| 43 |
+
dst.samples,
|
| 44 |
+
shift=random.choice(
|
| 45 |
+
get_fast_shifts(
|
| 46 |
+
src.sample_rate, condition=lambda x: x >= 0.25 and x <= 1.0
|
| 47 |
+
)
|
| 48 |
+
),
|
| 49 |
+
sample_rate=src.sample_rate,
|
| 50 |
)
|
| 51 |
dst.samples = time_stretch(
|
| 52 |
dst.samples,
|
| 53 |
+
stretch=random.choice(
|
| 54 |
+
get_fast_stretches(
|
| 55 |
+
src.sample_rate,
|
| 56 |
+
condition=lambda x: x >= 0.667 and x <= 1.5,
|
| 57 |
+
)
|
| 58 |
+
),
|
| 59 |
+
sample_rate=src.sample_rate,
|
| 60 |
)
|
| 61 |
|
| 62 |
dst.cpu().write(subdir / f"{i}-{j}.wav")
|
|
|
|
| 66 |
args = argbind.parse_args()
|
| 67 |
|
| 68 |
with argbind.scope(args):
|
| 69 |
+
augment()
|
vampnet/scripts/utils/data/maestro-reorg.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
import json
|
| 3 |
import os
|
|
|
|
| 4 |
|
| 5 |
maestro_path = Path("/media/CHONK/hugo/maestro-v3.0.0")
|
| 6 |
output_path = Path("/media/CHONK/hugo/maestro-v3.0.0-split")
|
|
@@ -14,7 +14,7 @@ train = []
|
|
| 14 |
validation = []
|
| 15 |
test = []
|
| 16 |
for key, split in maestro["split"].items():
|
| 17 |
-
audio_filename = maestro[
|
| 18 |
if split == "train":
|
| 19 |
train.append(audio_filename)
|
| 20 |
elif split == "test":
|
|
@@ -36,4 +36,4 @@ for audio_filename in validation:
|
|
| 36 |
for audio_filename in test:
|
| 37 |
p = output_path / "test" / audio_filename
|
| 38 |
p.parent.mkdir(parents=True, exist_ok=True)
|
| 39 |
-
os.symlink(maestro_path / audio_filename, p)
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
+
from pathlib import Path
|
| 4 |
|
| 5 |
maestro_path = Path("/media/CHONK/hugo/maestro-v3.0.0")
|
| 6 |
output_path = Path("/media/CHONK/hugo/maestro-v3.0.0-split")
|
|
|
|
| 14 |
validation = []
|
| 15 |
test = []
|
| 16 |
for key, split in maestro["split"].items():
|
| 17 |
+
audio_filename = maestro["audio_filename"][key]
|
| 18 |
if split == "train":
|
| 19 |
train.append(audio_filename)
|
| 20 |
elif split == "test":
|
|
|
|
| 36 |
for audio_filename in test:
|
| 37 |
p = output_path / "test" / audio_filename
|
| 38 |
p.parent.mkdir(parents=True, exist_ok=True)
|
| 39 |
+
os.symlink(maestro_path / audio_filename, p)
|
vampnet/scripts/utils/plots.py
CHANGED
|
@@ -2,16 +2,19 @@ import matplotlib.pyplot as plt
|
|
| 2 |
import seaborn as sns
|
| 3 |
from pandas.api.types import CategoricalDtype
|
| 4 |
|
|
|
|
| 5 |
def plot_metrics(metrics, condition_to_latex, title, color_palette):
|
| 6 |
# Add a new column to your dataframe with the latex representation
|
| 7 |
-
metrics[
|
| 8 |
|
| 9 |
# Order condition_latex as per the condition_to_latex dictionary
|
| 10 |
cat_type = CategoricalDtype(categories=condition_to_latex.values(), ordered=True)
|
| 11 |
-
metrics[
|
| 12 |
|
| 13 |
# Compute mean and std for each condition for each metric
|
| 14 |
-
grouped = metrics.groupby(
|
|
|
|
|
|
|
| 15 |
|
| 16 |
fig, axs = plt.subplots(2, 1, figsize=(7, 5.25))
|
| 17 |
|
|
@@ -22,16 +25,28 @@ def plot_metrics(metrics, condition_to_latex, title, color_palette):
|
|
| 22 |
bar_colors = [color_palette[condition] for condition in grouped.index]
|
| 23 |
|
| 24 |
# Plot mel
|
| 25 |
-
sns.boxplot(
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# Plot frechet
|
| 31 |
-
axs[1].bar(
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
# Adjust the space between plots
|
| 37 |
plt.subplots_adjust(hspace=0.1)
|
|
@@ -40,4 +55,4 @@ def plot_metrics(metrics, condition_to_latex, title, color_palette):
|
|
| 40 |
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
| 41 |
|
| 42 |
# Reduce the space between suptitle and the plot
|
| 43 |
-
plt.subplots_adjust(top=0.92)
|
|
|
|
| 2 |
import seaborn as sns
|
| 3 |
from pandas.api.types import CategoricalDtype
|
| 4 |
|
| 5 |
+
|
| 6 |
def plot_metrics(metrics, condition_to_latex, title, color_palette):
|
| 7 |
# Add a new column to your dataframe with the latex representation
|
| 8 |
+
metrics["condition_latex"] = metrics["condition"].map(condition_to_latex)
|
| 9 |
|
| 10 |
# Order condition_latex as per the condition_to_latex dictionary
|
| 11 |
cat_type = CategoricalDtype(categories=condition_to_latex.values(), ordered=True)
|
| 12 |
+
metrics["condition_latex"] = metrics["condition_latex"].astype(cat_type)
|
| 13 |
|
| 14 |
# Compute mean and std for each condition for each metric
|
| 15 |
+
grouped = metrics.groupby("condition_latex")[["mel", "frechet"]].agg(
|
| 16 |
+
["mean", "std"]
|
| 17 |
+
)
|
| 18 |
|
| 19 |
fig, axs = plt.subplots(2, 1, figsize=(7, 5.25))
|
| 20 |
|
|
|
|
| 25 |
bar_colors = [color_palette[condition] for condition in grouped.index]
|
| 26 |
|
| 27 |
# Plot mel
|
| 28 |
+
sns.boxplot(
|
| 29 |
+
x="condition_latex",
|
| 30 |
+
y="mel",
|
| 31 |
+
data=metrics,
|
| 32 |
+
ax=axs[0],
|
| 33 |
+
palette=color_palette,
|
| 34 |
+
showfliers=False,
|
| 35 |
+
)
|
| 36 |
+
axs[0].set_ylabel("Mel Spectrogram Loss \u2190")
|
| 37 |
+
axs[0].set_xlabel("") # Remove x-axis label
|
| 38 |
+
axs[0].set_xticklabels(grouped.index, rotation=0, ha="center")
|
| 39 |
|
| 40 |
# Plot frechet
|
| 41 |
+
axs[1].bar(
|
| 42 |
+
grouped.index,
|
| 43 |
+
grouped["frechet"]["mean"],
|
| 44 |
+
yerr=grouped["frechet"]["std"],
|
| 45 |
+
color=bar_colors,
|
| 46 |
+
)
|
| 47 |
+
axs[1].set_ylabel("FAD \u2190")
|
| 48 |
+
axs[1].set_xlabel("") # Remove x-axis label
|
| 49 |
+
axs[1].set_xticklabels(grouped.index, rotation=0, ha="center")
|
| 50 |
|
| 51 |
# Adjust the space between plots
|
| 52 |
plt.subplots_adjust(hspace=0.1)
|
|
|
|
| 55 |
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
| 56 |
|
| 57 |
# Reduce the space between suptitle and the plot
|
| 58 |
+
plt.subplots_adjust(top=0.92)
|
vampnet/scripts/utils/remove_quiet_files.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
| 1 |
# removes files with loudness below 24db
|
| 2 |
|
| 3 |
-
from pathlib import Path
|
| 4 |
import shutil
|
| 5 |
-
|
|
|
|
| 6 |
import argbind
|
|
|
|
|
|
|
| 7 |
|
| 8 |
@argbind.bind(without_prefix=True)
|
| 9 |
def remove_quiet_files(
|
|
@@ -14,7 +16,7 @@ def remove_quiet_files(
|
|
| 14 |
# copy src to dest
|
| 15 |
dest_dir.mkdir(parents=True, exist_ok=True)
|
| 16 |
shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
|
| 17 |
-
|
| 18 |
audio_files = at.util.find_audio(dest_dir)
|
| 19 |
for audio_file in audio_files:
|
| 20 |
sig = at.AudioSignal(audio_file)
|
|
@@ -22,8 +24,9 @@ def remove_quiet_files(
|
|
| 22 |
audio_file.unlink()
|
| 23 |
print(f"removed {audio_file}")
|
| 24 |
|
|
|
|
| 25 |
if __name__ == "__main__":
|
| 26 |
args = argbind.parse_args()
|
| 27 |
|
| 28 |
with argbind.scope(args):
|
| 29 |
-
remove_quiet_files()
|
|
|
|
| 1 |
# removes files with loudness below 24db
|
| 2 |
|
|
|
|
| 3 |
import shutil
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
import argbind
|
| 7 |
+
import audiotools as at
|
| 8 |
+
|
| 9 |
|
| 10 |
@argbind.bind(without_prefix=True)
|
| 11 |
def remove_quiet_files(
|
|
|
|
| 16 |
# copy src to dest
|
| 17 |
dest_dir.mkdir(parents=True, exist_ok=True)
|
| 18 |
shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
|
| 19 |
+
|
| 20 |
audio_files = at.util.find_audio(dest_dir)
|
| 21 |
for audio_file in audio_files:
|
| 22 |
sig = at.AudioSignal(audio_file)
|
|
|
|
| 24 |
audio_file.unlink()
|
| 25 |
print(f"removed {audio_file}")
|
| 26 |
|
| 27 |
+
|
| 28 |
if __name__ == "__main__":
|
| 29 |
args = argbind.parse_args()
|
| 30 |
|
| 31 |
with argbind.scope(args):
|
| 32 |
+
remove_quiet_files()
|
vampnet/scripts/utils/split.py
CHANGED
|
@@ -1,29 +1,25 @@
|
|
| 1 |
-
|
| 2 |
-
import random
|
| 3 |
-
import shutil
|
| 4 |
import os
|
| 5 |
-
import
|
|
|
|
| 6 |
|
| 7 |
import argbind
|
| 8 |
from tqdm import tqdm
|
| 9 |
-
from tqdm.contrib.concurrent import thread_map
|
| 10 |
-
|
| 11 |
-
from audiotools.core import util
|
| 12 |
|
| 13 |
|
| 14 |
@argbind.bind(without_prefix=True)
|
| 15 |
def train_test_split(
|
| 16 |
-
audio_folder: str = ".",
|
| 17 |
test_size: float = 0.2,
|
| 18 |
seed: int = 42,
|
| 19 |
pattern: str = "**/*.mp3",
|
| 20 |
):
|
| 21 |
-
print(
|
| 22 |
|
| 23 |
audio_folder = Path(audio_folder)
|
| 24 |
audio_files = list(tqdm(audio_folder.glob(pattern)))
|
| 25 |
print(f"found {len(audio_files)} audio files")
|
| 26 |
-
|
| 27 |
# split according to test_size
|
| 28 |
n_test = int(len(audio_files) * test_size)
|
| 29 |
n_train = len(audio_files) - n_test
|
|
@@ -35,30 +31,28 @@ def train_test_split(
|
|
| 35 |
train_files = audio_files[:n_train]
|
| 36 |
test_files = audio_files[n_train:]
|
| 37 |
|
| 38 |
-
|
| 39 |
print(f"Train files: {len(train_files)}")
|
| 40 |
print(f"Test files: {len(test_files)}")
|
| 41 |
continue_ = input("Continue [yn]? ") or "n"
|
| 42 |
|
| 43 |
if continue_ != "y":
|
| 44 |
return
|
| 45 |
-
|
| 46 |
-
for split, files in (
|
| 47 |
-
("train", train_files), ("test", test_files)
|
| 48 |
-
):
|
| 49 |
for file in tqdm(files):
|
| 50 |
-
out_file =
|
|
|
|
|
|
|
| 51 |
out_file.parent.mkdir(exist_ok=True, parents=True)
|
| 52 |
os.symlink(file, out_file)
|
| 53 |
|
| 54 |
# save split as json
|
| 55 |
with open(Path(audio_folder) / f"{split}.json", "w") as f:
|
| 56 |
json.dump([str(f) for f in files], f)
|
| 57 |
-
|
| 58 |
|
| 59 |
-
|
| 60 |
if __name__ == "__main__":
|
| 61 |
-
args
|
| 62 |
|
| 63 |
with argbind.scope(args):
|
| 64 |
-
train_test_split()
|
|
|
|
| 1 |
+
import json
|
|
|
|
|
|
|
| 2 |
import os
|
| 3 |
+
import random
|
| 4 |
+
from pathlib import Path
|
| 5 |
|
| 6 |
import argbind
|
| 7 |
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
@argbind.bind(without_prefix=True)
|
| 11 |
def train_test_split(
|
| 12 |
+
audio_folder: str = ".",
|
| 13 |
test_size: float = 0.2,
|
| 14 |
seed: int = 42,
|
| 15 |
pattern: str = "**/*.mp3",
|
| 16 |
):
|
| 17 |
+
print("finding audio")
|
| 18 |
|
| 19 |
audio_folder = Path(audio_folder)
|
| 20 |
audio_files = list(tqdm(audio_folder.glob(pattern)))
|
| 21 |
print(f"found {len(audio_files)} audio files")
|
| 22 |
+
|
| 23 |
# split according to test_size
|
| 24 |
n_test = int(len(audio_files) * test_size)
|
| 25 |
n_train = len(audio_files) - n_test
|
|
|
|
| 31 |
train_files = audio_files[:n_train]
|
| 32 |
test_files = audio_files[n_train:]
|
| 33 |
|
|
|
|
| 34 |
print(f"Train files: {len(train_files)}")
|
| 35 |
print(f"Test files: {len(test_files)}")
|
| 36 |
continue_ = input("Continue [yn]? ") or "n"
|
| 37 |
|
| 38 |
if continue_ != "y":
|
| 39 |
return
|
| 40 |
+
|
| 41 |
+
for split, files in (("train", train_files), ("test", test_files)):
|
|
|
|
|
|
|
| 42 |
for file in tqdm(files):
|
| 43 |
+
out_file = (
|
| 44 |
+
audio_folder.parent / f"{audio_folder.name}-{split}" / Path(file).name
|
| 45 |
+
)
|
| 46 |
out_file.parent.mkdir(exist_ok=True, parents=True)
|
| 47 |
os.symlink(file, out_file)
|
| 48 |
|
| 49 |
# save split as json
|
| 50 |
with open(Path(audio_folder) / f"{split}.json", "w") as f:
|
| 51 |
json.dump([str(f) for f in files], f)
|
|
|
|
| 52 |
|
| 53 |
+
|
| 54 |
if __name__ == "__main__":
|
| 55 |
+
args = argbind.parse_args()
|
| 56 |
|
| 57 |
with argbind.scope(args):
|
| 58 |
+
train_test_split()
|
vampnet/scripts/utils/split_long_audio_file.py
CHANGED
|
@@ -1,34 +1,37 @@
|
|
| 1 |
from pathlib import Path
|
| 2 |
-
import argbind
|
| 3 |
|
|
|
|
| 4 |
import audiotools as at
|
| 5 |
import tqdm
|
| 6 |
|
| 7 |
|
| 8 |
@argbind.bind(without_prefix=True)
|
| 9 |
-
def split_long_audio_file(
|
| 10 |
-
file: str = None,
|
| 11 |
-
max_chunk_size_s: int = 60*10
|
| 12 |
-
):
|
| 13 |
file = Path(file)
|
| 14 |
output_dir = file.parent / file.stem
|
| 15 |
output_dir.mkdir()
|
| 16 |
-
|
| 17 |
sig = at.AudioSignal(file)
|
| 18 |
|
| 19 |
# split into chunks
|
| 20 |
-
for i, sig in tqdm.tqdm(
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
):
|
| 24 |
sig.write(output_dir / f"{i}.wav")
|
| 25 |
|
| 26 |
print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
|
| 27 |
-
|
| 28 |
return output_dir
|
| 29 |
|
|
|
|
| 30 |
if __name__ == "__main__":
|
| 31 |
args = argbind.parse_args()
|
| 32 |
|
| 33 |
with argbind.scope(args):
|
| 34 |
-
split_long_audio_file()
|
|
|
|
| 1 |
from pathlib import Path
|
|
|
|
| 2 |
|
| 3 |
+
import argbind
|
| 4 |
import audiotools as at
|
| 5 |
import tqdm
|
| 6 |
|
| 7 |
|
| 8 |
@argbind.bind(without_prefix=True)
|
| 9 |
+
def split_long_audio_file(file: str = None, max_chunk_size_s: int = 60 * 10):
|
|
|
|
|
|
|
|
|
|
| 10 |
file = Path(file)
|
| 11 |
output_dir = file.parent / file.stem
|
| 12 |
output_dir.mkdir()
|
| 13 |
+
|
| 14 |
sig = at.AudioSignal(file)
|
| 15 |
|
| 16 |
# split into chunks
|
| 17 |
+
for i, sig in tqdm.tqdm(
|
| 18 |
+
enumerate(
|
| 19 |
+
sig.windows(
|
| 20 |
+
window_duration=max_chunk_size_s,
|
| 21 |
+
hop_duration=max_chunk_size_s / 2,
|
| 22 |
+
preprocess=True,
|
| 23 |
+
)
|
| 24 |
+
)
|
| 25 |
):
|
| 26 |
sig.write(output_dir / f"{i}.wav")
|
| 27 |
|
| 28 |
print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
|
| 29 |
+
|
| 30 |
return output_dir
|
| 31 |
|
| 32 |
+
|
| 33 |
if __name__ == "__main__":
|
| 34 |
args = argbind.parse_args()
|
| 35 |
|
| 36 |
with argbind.scope(args):
|
| 37 |
+
split_long_audio_file()
|
vampnet/scripts/utils/stage.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
import os
|
| 2 |
-
import subprocess
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
import argbind
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
|
| 4 |
import argbind
|
vampnet/scripts/utils/visualize_embeddings.py
CHANGED
|
@@ -3,19 +3,20 @@ TODO: train a linear probe
|
|
| 3 |
usage:
|
| 4 |
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_audio /path/to/audio/labels --output_dir /path/to/output
|
| 5 |
"""
|
|
|
|
|
|
|
|
|
|
| 6 |
from pathlib import Path
|
| 7 |
from typing import List
|
| 8 |
|
| 9 |
-
import audiotools as at
|
| 10 |
-
from audiotools import AudioSignal
|
| 11 |
import argbind
|
| 12 |
-
import
|
| 13 |
import numpy as np
|
| 14 |
-
import
|
| 15 |
-
import
|
|
|
|
| 16 |
|
| 17 |
from vampnet.interface import Interface
|
| 18 |
-
import tqdm
|
| 19 |
|
| 20 |
# bind the Interface to argbind
|
| 21 |
Interface = argbind.bind(Interface)
|
|
@@ -34,6 +35,7 @@ def smart_plotly_export(fig, save_path: Path):
|
|
| 34 |
# TODO: come back and make this prettier
|
| 35 |
elif img_format == "numpy":
|
| 36 |
import io
|
|
|
|
| 37 |
from PIL import Image
|
| 38 |
|
| 39 |
def plotly_fig2array(fig):
|
|
@@ -72,6 +74,7 @@ def dim_reduce(annotated_embeddings, layer, output_dir, n_components=3, method="
|
|
| 72 |
|
| 73 |
if method == "umap":
|
| 74 |
from umap import UMAP
|
|
|
|
| 75 |
reducer = UMAP(n_components=n_components)
|
| 76 |
elif method == "tsne":
|
| 77 |
from sklearn.manifold import TSNE
|
|
@@ -100,11 +103,16 @@ def dim_reduce(annotated_embeddings, layer, output_dir, n_components=3, method="
|
|
| 100 |
)
|
| 101 |
if n_components == 2:
|
| 102 |
fig = px.scatter(
|
| 103 |
-
df,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
)
|
| 105 |
|
| 106 |
elif n_components == 3:
|
| 107 |
-
df[
|
| 108 |
fig = px.scatter_3d(
|
| 109 |
df, x="x", y="y", z="z", color="label", hover_name="name", title=fig_title
|
| 110 |
)
|
|
@@ -139,15 +147,15 @@ def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
|
|
| 139 |
# [20, 1, 600ish, 768]
|
| 140 |
|
| 141 |
# squeeze batch dim (1 bc layer should be dim 0)
|
| 142 |
-
assert (
|
| 143 |
-
embeddings.shape[
|
| 144 |
-
)
|
| 145 |
embeddings = embeddings.squeeze(1)
|
| 146 |
|
| 147 |
num_layers = embeddings.shape[0]
|
| 148 |
-
assert (
|
| 149 |
-
layer
|
| 150 |
-
)
|
| 151 |
|
| 152 |
# do meanpooling over the time dimension
|
| 153 |
embeddings = embeddings.mean(dim=-2)
|
|
@@ -169,7 +177,6 @@ class AnnotatedEmbedding:
|
|
| 169 |
def save(self, path):
|
| 170 |
"""Save the Embedding object to a given path as a zip file."""
|
| 171 |
with zipfile.ZipFile(path, "w") as archive:
|
| 172 |
-
|
| 173 |
# Save numpy array
|
| 174 |
with archive.open("embedding.npy", "w") as f:
|
| 175 |
np.save(f, self.embedding)
|
|
@@ -187,7 +194,6 @@ class AnnotatedEmbedding:
|
|
| 187 |
def load(cls, path):
|
| 188 |
"""Load the Embedding object from a given zip path."""
|
| 189 |
with zipfile.ZipFile(path, "r") as archive:
|
| 190 |
-
|
| 191 |
# Load numpy array
|
| 192 |
with archive.open("embedding.npy") as f:
|
| 193 |
embedding = np.load(f)
|
|
|
|
| 3 |
usage:
|
| 4 |
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_audio /path/to/audio/labels --output_dir /path/to/output
|
| 5 |
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import zipfile
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import List
|
| 11 |
|
|
|
|
|
|
|
| 12 |
import argbind
|
| 13 |
+
import audiotools as at
|
| 14 |
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import tqdm
|
| 17 |
+
from audiotools import AudioSignal
|
| 18 |
|
| 19 |
from vampnet.interface import Interface
|
|
|
|
| 20 |
|
| 21 |
# bind the Interface to argbind
|
| 22 |
Interface = argbind.bind(Interface)
|
|
|
|
| 35 |
# TODO: come back and make this prettier
|
| 36 |
elif img_format == "numpy":
|
| 37 |
import io
|
| 38 |
+
|
| 39 |
from PIL import Image
|
| 40 |
|
| 41 |
def plotly_fig2array(fig):
|
|
|
|
| 74 |
|
| 75 |
if method == "umap":
|
| 76 |
from umap import UMAP
|
| 77 |
+
|
| 78 |
reducer = UMAP(n_components=n_components)
|
| 79 |
elif method == "tsne":
|
| 80 |
from sklearn.manifold import TSNE
|
|
|
|
| 103 |
)
|
| 104 |
if n_components == 2:
|
| 105 |
fig = px.scatter(
|
| 106 |
+
df,
|
| 107 |
+
x="x",
|
| 108 |
+
y="y",
|
| 109 |
+
color="label",
|
| 110 |
+
hover_name="name",
|
| 111 |
+
title=fig_title,
|
| 112 |
)
|
| 113 |
|
| 114 |
elif n_components == 3:
|
| 115 |
+
df["z"] = projs[:, 2]
|
| 116 |
fig = px.scatter_3d(
|
| 117 |
df, x="x", y="y", z="z", color="label", hover_name="name", title=fig_title
|
| 118 |
)
|
|
|
|
| 147 |
# [20, 1, 600ish, 768]
|
| 148 |
|
| 149 |
# squeeze batch dim (1 bc layer should be dim 0)
|
| 150 |
+
assert embeddings.shape[1] == 1, (
|
| 151 |
+
f"expected batch dim to be 1, got {embeddings.shape[0]}"
|
| 152 |
+
)
|
| 153 |
embeddings = embeddings.squeeze(1)
|
| 154 |
|
| 155 |
num_layers = embeddings.shape[0]
|
| 156 |
+
assert layer < num_layers, (
|
| 157 |
+
f"layer {layer} is out of bounds for model with {num_layers} layers"
|
| 158 |
+
)
|
| 159 |
|
| 160 |
# do meanpooling over the time dimension
|
| 161 |
embeddings = embeddings.mean(dim=-2)
|
|
|
|
| 177 |
def save(self, path):
|
| 178 |
"""Save the Embedding object to a given path as a zip file."""
|
| 179 |
with zipfile.ZipFile(path, "w") as archive:
|
|
|
|
| 180 |
# Save numpy array
|
| 181 |
with archive.open("embedding.npy", "w") as f:
|
| 182 |
np.save(f, self.embedding)
|
|
|
|
| 194 |
def load(cls, path):
|
| 195 |
"""Load the Embedding object from a given zip path."""
|
| 196 |
with zipfile.ZipFile(path, "r") as archive:
|
|
|
|
| 197 |
# Load numpy array
|
| 198 |
with archive.open("embedding.npy") as f:
|
| 199 |
embedding = np.load(f)
|
vampnet/scripts/utils/xeno-canto-dl.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
from xenopy import Query
|
| 2 |
|
| 3 |
-
|
| 4 |
SPECIES = [
|
| 5 |
"American Robin",
|
| 6 |
"Northern Cardinal",
|
|
@@ -208,27 +207,36 @@ SPECIES = [
|
|
| 208 |
"American Woodcock",
|
| 209 |
"Wilson's Phalarope",
|
| 210 |
"Red-necked Phalarope",
|
| 211 |
-
"Red Phalarope"
|
| 212 |
]
|
| 213 |
|
| 214 |
from pathlib import Path
|
| 215 |
|
|
|
|
| 216 |
def remove_spaces(s):
|
| 217 |
return s.replace(" ", "")
|
| 218 |
|
| 219 |
-
|
|
|
|
| 220 |
if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
|
| 221 |
continue
|
| 222 |
try:
|
| 223 |
q = Query(
|
| 224 |
-
name=species,
|
| 225 |
-
|
|
|
|
|
|
|
| 226 |
|
| 227 |
# retrieve metadata
|
| 228 |
metafiles = q.retrieve_meta(verbose=True)
|
| 229 |
# retrieve recordings
|
| 230 |
-
q.retrieve_recordings(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
except:
|
| 233 |
print("Failed to download " + species)
|
| 234 |
-
continue
|
|
|
|
| 1 |
from xenopy import Query
|
| 2 |
|
|
|
|
| 3 |
SPECIES = [
|
| 4 |
"American Robin",
|
| 5 |
"Northern Cardinal",
|
|
|
|
| 207 |
"American Woodcock",
|
| 208 |
"Wilson's Phalarope",
|
| 209 |
"Red-necked Phalarope",
|
| 210 |
+
"Red Phalarope",
|
| 211 |
]
|
| 212 |
|
| 213 |
from pathlib import Path
|
| 214 |
|
| 215 |
+
|
| 216 |
def remove_spaces(s):
|
| 217 |
return s.replace(" ", "")
|
| 218 |
|
| 219 |
+
|
| 220 |
+
for species in SPECIES:
|
| 221 |
if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
|
| 222 |
continue
|
| 223 |
try:
|
| 224 |
q = Query(
|
| 225 |
+
name=species,
|
| 226 |
+
q="A",
|
| 227 |
+
length="10-30",
|
| 228 |
+
)
|
| 229 |
|
| 230 |
# retrieve metadata
|
| 231 |
metafiles = q.retrieve_meta(verbose=True)
|
| 232 |
# retrieve recordings
|
| 233 |
+
q.retrieve_recordings(
|
| 234 |
+
multiprocess=True,
|
| 235 |
+
nproc=10,
|
| 236 |
+
attempts=10,
|
| 237 |
+
outdir="/media/CHONK/hugo/xeno-canto-full/",
|
| 238 |
+
)
|
| 239 |
|
| 240 |
except:
|
| 241 |
print("Failed to download " + species)
|
| 242 |
+
continue
|
vampnet/setup.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
-
from setuptools import find_packages
|
| 2 |
-
from setuptools import setup
|
| 3 |
|
| 4 |
with open("README.md") as f:
|
| 5 |
long_description = f.read()
|
|
@@ -29,7 +28,7 @@ setup(
|
|
| 29 |
"Cython",
|
| 30 |
],
|
| 31 |
install_requires=[
|
| 32 |
-
"Cython", # Added by
|
| 33 |
"torch",
|
| 34 |
"pydantic==2.10.6",
|
| 35 |
"argbind>=0.3.2",
|
|
@@ -40,8 +39,6 @@ setup(
|
|
| 40 |
"gradio",
|
| 41 |
"loralib",
|
| 42 |
"torch_pitch_shift",
|
| 43 |
-
"plotly", # Added by WAM for clustering (see https://github.com/hugofloresgarcia/vampnet/issues/20)
|
| 44 |
"pyharp",
|
| 45 |
-
|
| 46 |
],
|
| 47 |
)
|
|
|
|
| 1 |
+
from setuptools import find_packages, setup
|
|
|
|
| 2 |
|
| 3 |
with open("README.md") as f:
|
| 4 |
long_description = f.read()
|
|
|
|
| 28 |
"Cython",
|
| 29 |
],
|
| 30 |
install_requires=[
|
| 31 |
+
"Cython", # Added by WhAM because it seems to be needed by this repo?
|
| 32 |
"torch",
|
| 33 |
"pydantic==2.10.6",
|
| 34 |
"argbind>=0.3.2",
|
|
|
|
| 39 |
"gradio",
|
| 40 |
"loralib",
|
| 41 |
"torch_pitch_shift",
|
|
|
|
| 42 |
"pyharp",
|
|
|
|
| 43 |
],
|
| 44 |
)
|
vampnet/vampnet/__init__.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
-
|
| 2 |
-
from . import modules
|
| 3 |
-
from . import scheduler
|
| 4 |
from .interface import Interface
|
| 5 |
|
| 6 |
__version__ = "0.0.1"
|
|
|
|
| 1 |
+
from . import modules, scheduler
|
|
|
|
|
|
|
| 2 |
from .interface import Interface
|
| 3 |
|
| 4 |
__version__ = "0.0.1"
|
vampnet/vampnet/beats.py
CHANGED
|
@@ -1,19 +1,14 @@
|
|
| 1 |
import json
|
| 2 |
import logging
|
| 3 |
-
import warnings
|
| 4 |
from dataclasses import dataclass
|
| 5 |
from pathlib import Path
|
| 6 |
-
from typing import
|
| 7 |
-
from typing import List
|
| 8 |
-
from typing import Tuple
|
| 9 |
-
from typing import Union
|
| 10 |
|
| 11 |
import librosa
|
| 12 |
-
import torch
|
| 13 |
import numpy as np
|
|
|
|
| 14 |
from audiotools import AudioSignal
|
| 15 |
|
| 16 |
-
|
| 17 |
logging.basicConfig(level=logging.INFO)
|
| 18 |
|
| 19 |
###################
|
|
@@ -60,7 +55,6 @@ def mkdir(path: Union[Path, str]) -> Path:
|
|
| 60 |
return p
|
| 61 |
|
| 62 |
|
| 63 |
-
|
| 64 |
###################
|
| 65 |
# beat data #
|
| 66 |
###################
|
|
@@ -204,7 +198,9 @@ class WaveBeat(BeatTracker):
|
|
| 204 |
def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
|
| 205 |
from wavebeat.dstcn import dsTCNModel
|
| 206 |
|
| 207 |
-
model = dsTCNModel.load_from_checkpoint(
|
|
|
|
|
|
|
| 208 |
model.eval()
|
| 209 |
|
| 210 |
self.device = device
|
|
@@ -247,4 +243,4 @@ def load_beat_tracker(beat_tracker: str, **kwargs) -> BeatTracker:
|
|
| 247 |
f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}"
|
| 248 |
)
|
| 249 |
|
| 250 |
-
return BEAT_TRACKER_REGISTRY[beat_tracker](**kwargs)
|
|
|
|
| 1 |
import json
|
| 2 |
import logging
|
|
|
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from pathlib import Path
|
| 5 |
+
from typing import List, Tuple, Union
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
import librosa
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
from audiotools import AudioSignal
|
| 11 |
|
|
|
|
| 12 |
logging.basicConfig(level=logging.INFO)
|
| 13 |
|
| 14 |
###################
|
|
|
|
| 55 |
return p
|
| 56 |
|
| 57 |
|
|
|
|
| 58 |
###################
|
| 59 |
# beat data #
|
| 60 |
###################
|
|
|
|
| 198 |
def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
|
| 199 |
from wavebeat.dstcn import dsTCNModel
|
| 200 |
|
| 201 |
+
model = dsTCNModel.load_from_checkpoint(
|
| 202 |
+
ckpt_path, map_location=torch.device(device), weights_only=False
|
| 203 |
+
)
|
| 204 |
model.eval()
|
| 205 |
|
| 206 |
self.device = device
|
|
|
|
| 243 |
f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}"
|
| 244 |
)
|
| 245 |
|
| 246 |
+
return BEAT_TRACKER_REGISTRY[beat_tracker](**kwargs)
|
vampnet/vampnet/interface.py
CHANGED
|
@@ -1,19 +1,17 @@
|
|
| 1 |
-
import os
|
| 2 |
-
from pathlib import Path
|
| 3 |
import math
|
|
|
|
| 4 |
|
| 5 |
-
import torch
|
| 6 |
import numpy as np
|
|
|
|
| 7 |
from audiotools import AudioSignal
|
| 8 |
-
import tqdm
|
| 9 |
-
|
| 10 |
-
from .modules.transformer import VampNet
|
| 11 |
-
from .beats import WaveBeat
|
| 12 |
-
from .mask import *
|
| 13 |
|
| 14 |
# from dac.model.dac import DAC
|
| 15 |
from lac.model.lac import LAC as DAC
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def signal_concat(
|
| 19 |
audio_signals: list,
|
|
@@ -24,7 +22,7 @@ def signal_concat(
|
|
| 24 |
|
| 25 |
|
| 26 |
def _load_model(
|
| 27 |
-
ckpt: str,
|
| 28 |
lora_ckpt: str = None,
|
| 29 |
device: str = "cpu",
|
| 30 |
chunk_size_s: int = 10,
|
|
@@ -41,7 +39,9 @@ def _load_model(
|
|
| 41 |
if should_cont != "y":
|
| 42 |
raise Exception("aborting")
|
| 43 |
else:
|
| 44 |
-
model.load_state_dict(
|
|
|
|
|
|
|
| 45 |
|
| 46 |
model.to(device)
|
| 47 |
model.eval()
|
|
@@ -49,7 +49,6 @@ def _load_model(
|
|
| 49 |
return model
|
| 50 |
|
| 51 |
|
| 52 |
-
|
| 53 |
class Interface(torch.nn.Module):
|
| 54 |
def __init__(
|
| 55 |
self,
|
|
@@ -60,8 +59,8 @@ class Interface(torch.nn.Module):
|
|
| 60 |
codec_ckpt: str = None,
|
| 61 |
wavebeat_ckpt: str = None,
|
| 62 |
device: str = "cpu",
|
| 63 |
-
coarse_chunk_size_s: int =
|
| 64 |
-
coarse2fine_chunk_size_s: int =
|
| 65 |
):
|
| 66 |
super().__init__()
|
| 67 |
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
|
@@ -98,7 +97,7 @@ class Interface(torch.nn.Module):
|
|
| 98 |
self.device = device
|
| 99 |
|
| 100 |
def lora_load(
|
| 101 |
-
self,
|
| 102 |
coarse_ckpt: str = None,
|
| 103 |
c2f_ckpt: str = None,
|
| 104 |
full_ckpts: bool = False,
|
|
@@ -106,7 +105,7 @@ class Interface(torch.nn.Module):
|
|
| 106 |
if full_ckpts:
|
| 107 |
if coarse_ckpt is not None:
|
| 108 |
self.coarse = _load_model(
|
| 109 |
-
ckpt=coarse_ckpt,
|
| 110 |
device=self.device,
|
| 111 |
chunk_size_s=self.coarse.chunk_size_s,
|
| 112 |
)
|
|
@@ -129,7 +128,7 @@ class Interface(torch.nn.Module):
|
|
| 129 |
print(f"loading c2f from {c2f_ckpt}")
|
| 130 |
self.c2f.load_state_dict(state_dict, strict=False)
|
| 131 |
self.c2f.to(self.device)
|
| 132 |
-
|
| 133 |
def s2t(self, seconds: float):
|
| 134 |
"""seconds to tokens"""
|
| 135 |
if isinstance(seconds, np.ndarray):
|
|
@@ -140,7 +139,7 @@ class Interface(torch.nn.Module):
|
|
| 140 |
def s2t2s(self, seconds: float):
|
| 141 |
"""seconds to tokens to seconds"""
|
| 142 |
return self.t2s(self.s2t(seconds))
|
| 143 |
-
|
| 144 |
def t2s(self, tokens: int):
|
| 145 |
"""tokens to seconds"""
|
| 146 |
return tokens * self.codec.hop_length / self.codec.sample_rate
|
|
@@ -159,7 +158,7 @@ class Interface(torch.nn.Module):
|
|
| 159 |
|
| 160 |
def to_signal(self, z: torch.Tensor):
|
| 161 |
return self.coarse.to_signal(z, self.codec)
|
| 162 |
-
|
| 163 |
def preprocess(self, signal: AudioSignal):
|
| 164 |
signal = (
|
| 165 |
signal.clone()
|
|
@@ -169,41 +168,39 @@ class Interface(torch.nn.Module):
|
|
| 169 |
.ensure_max_of_audio(1.0)
|
| 170 |
)
|
| 171 |
return signal
|
| 172 |
-
|
| 173 |
@torch.inference_mode()
|
| 174 |
def encode(self, signal: AudioSignal):
|
| 175 |
signal = self.preprocess(signal).to(self.device)
|
| 176 |
z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
| 177 |
return z
|
| 178 |
|
| 179 |
-
def snap_to_beats(
|
| 180 |
-
self,
|
| 181 |
-
signal: AudioSignal
|
| 182 |
-
):
|
| 183 |
assert hasattr(self, "beat_tracker"), "No beat tracker loaded"
|
| 184 |
beats, downbeats = self.beat_tracker.extract_beats(signal)
|
| 185 |
-
|
| 186 |
# trim the signa around the first beat time
|
| 187 |
-
samples_begin = int(beats[0] * signal.sample_rate
|
| 188 |
samples_end = int(beats[-1] * signal.sample_rate)
|
| 189 |
print(beats[0])
|
| 190 |
signal = signal.clone().trim(samples_begin, signal.length - samples_end)
|
| 191 |
|
| 192 |
return signal
|
| 193 |
|
| 194 |
-
def make_beat_mask(
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
| 204 |
):
|
| 205 |
-
"""make a beat synced mask. that is, make a mask that
|
| 206 |
-
places 1s at and around the beat, and 0s everywhere else.
|
| 207 |
"""
|
| 208 |
assert self.beat_tracker is not None, "No beat tracker loaded"
|
| 209 |
|
|
@@ -214,14 +211,16 @@ class Interface(torch.nn.Module):
|
|
| 214 |
beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
|
| 215 |
|
| 216 |
# remove downbeats from beats
|
| 217 |
-
beats_z = torch.tensor(beats_z)[
|
|
|
|
|
|
|
| 218 |
beats_z = beats_z.tolist()
|
| 219 |
downbeats_z = downbeats_z.tolist()
|
| 220 |
|
| 221 |
-
# make the mask
|
| 222 |
seq_len = self.s2t(signal.duration)
|
| 223 |
mask = torch.zeros(seq_len, device=self.device)
|
| 224 |
-
|
| 225 |
mask_b4 = self.s2t(before_beat_s)
|
| 226 |
mask_after = self.s2t(after_beat_s)
|
| 227 |
|
|
@@ -241,44 +240,39 @@ class Interface(torch.nn.Module):
|
|
| 241 |
downbeats_z = downbeats_z[::downbeat_downsample_factor]
|
| 242 |
print(f"beats_z: {len(beats_z)}")
|
| 243 |
print(f"downbeats_z: {len(downbeats_z)}")
|
| 244 |
-
|
| 245 |
if mask_upbeats:
|
| 246 |
for beat_idx in beats_z:
|
| 247 |
_slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
|
| 248 |
-
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
| 249 |
_m = torch.ones(num_steps, device=self.device)
|
| 250 |
_m_mask = torch.bernoulli(_m * (1 - dropout))
|
| 251 |
_m = _m * _m_mask.long()
|
| 252 |
-
|
| 253 |
-
mask[_slice[0]:_slice[1]] = _m
|
| 254 |
|
| 255 |
if mask_downbeats:
|
| 256 |
for downbeat_idx in downbeats_z:
|
| 257 |
_slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
|
| 258 |
-
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
| 259 |
_m = torch.ones(num_steps, device=self.device)
|
| 260 |
_m_mask = torch.bernoulli(_m * (1 - dropout))
|
| 261 |
_m = _m * _m_mask.long()
|
| 262 |
-
|
| 263 |
-
mask[_slice[0]:_slice[1]] = _m
|
| 264 |
-
|
| 265 |
mask = mask.clamp(0, 1)
|
| 266 |
if invert:
|
| 267 |
mask = 1 - mask
|
| 268 |
-
|
| 269 |
mask = mask[None, None, :].bool().long()
|
| 270 |
if self.c2f is not None:
|
| 271 |
mask = mask.repeat(1, self.c2f.n_codebooks, 1)
|
| 272 |
else:
|
| 273 |
mask = mask.repeat(1, self.coarse.n_codebooks, 1)
|
| 274 |
return mask
|
| 275 |
-
|
| 276 |
-
def coarse_to_fine(
|
| 277 |
-
self,
|
| 278 |
-
z: torch.Tensor,
|
| 279 |
-
mask: torch.Tensor = None,
|
| 280 |
-
**kwargs
|
| 281 |
-
):
|
| 282 |
assert self.c2f is not None, "No coarse2fine model loaded"
|
| 283 |
length = z.shape[-1]
|
| 284 |
chunk_len = self.s2t(self.c2f.chunk_size_s)
|
|
@@ -288,49 +282,57 @@ class Interface(torch.nn.Module):
|
|
| 288 |
if length % chunk_len != 0:
|
| 289 |
pad_len = chunk_len - (length % chunk_len)
|
| 290 |
z = torch.nn.functional.pad(z, (0, pad_len))
|
| 291 |
-
mask =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
|
| 294 |
if n_codebooks_to_append > 0:
|
| 295 |
-
z = torch.cat(
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
# set the mask to 0 for all conditioning codebooks
|
| 301 |
if mask is not None:
|
| 302 |
mask = mask.clone()
|
| 303 |
-
mask[:, :self.c2f.n_conditioning_codebooks, :] = 0
|
| 304 |
|
| 305 |
fine_z = []
|
| 306 |
for i in range(n_chunks):
|
| 307 |
chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
|
| 308 |
-
mask_chunk =
|
| 309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
chunk = self.c2f.generate(
|
| 311 |
codec=self.codec,
|
| 312 |
time_steps=chunk_len,
|
| 313 |
start_tokens=chunk,
|
| 314 |
return_signal=False,
|
| 315 |
mask=mask_chunk,
|
| 316 |
-
**kwargs
|
| 317 |
)
|
| 318 |
fine_z.append(chunk)
|
| 319 |
|
| 320 |
fine_z = torch.cat(fine_z, dim=-1)
|
| 321 |
return fine_z[:, :, :length].clone()
|
| 322 |
-
|
| 323 |
-
def coarse_vamp(
|
| 324 |
-
self,
|
| 325 |
-
z,
|
| 326 |
-
mask,
|
| 327 |
-
return_mask=False,
|
| 328 |
-
gen_fn=None,
|
| 329 |
-
**kwargs
|
| 330 |
-
):
|
| 331 |
# coarse z
|
| 332 |
cz = z[:, : self.coarse.n_codebooks, :].clone()
|
| 333 |
-
assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s),
|
|
|
|
|
|
|
| 334 |
|
| 335 |
mask = mask[:, : self.coarse.n_codebooks, :]
|
| 336 |
|
|
@@ -342,41 +344,39 @@ class Interface(torch.nn.Module):
|
|
| 342 |
codec=self.codec,
|
| 343 |
time_steps=cz.shape[-1],
|
| 344 |
start_tokens=cz,
|
| 345 |
-
mask=mask,
|
| 346 |
return_signal=False,
|
| 347 |
-
**kwargs
|
| 348 |
)
|
| 349 |
|
| 350 |
# add the fine codes back in
|
| 351 |
-
c_vamp = torch.cat(
|
| 352 |
-
[c_vamp, z[:, self.coarse.n_codebooks :, :]],
|
| 353 |
-
dim=1
|
| 354 |
-
)
|
| 355 |
|
| 356 |
if return_mask:
|
| 357 |
return c_vamp, cz_masked
|
| 358 |
-
|
| 359 |
return c_vamp
|
| 360 |
|
| 361 |
|
| 362 |
if __name__ == "__main__":
|
| 363 |
-
import audiotools as at
|
| 364 |
import logging
|
|
|
|
|
|
|
|
|
|
| 365 |
logger = logging.getLogger()
|
| 366 |
logger.setLevel(logging.INFO)
|
| 367 |
torch.set_printoptions(threshold=10000)
|
| 368 |
at.util.seed(42)
|
| 369 |
|
| 370 |
interface = Interface(
|
| 371 |
-
coarse_ckpt="./models/vampnet/coarse.pth",
|
| 372 |
-
coarse2fine_ckpt="./models/vampnet/c2f.pth",
|
| 373 |
codec_ckpt="./models/vampnet/codec.pth",
|
| 374 |
-
device="cuda",
|
| 375 |
-
wavebeat_ckpt="./models/wavebeat.pth"
|
| 376 |
)
|
| 377 |
|
| 378 |
-
|
| 379 |
-
sig = at.AudioSignal('assets/example.wav')
|
| 380 |
|
| 381 |
z = interface.encode(sig)
|
| 382 |
breakpoint()
|
|
@@ -398,19 +398,18 @@ if __name__ == "__main__":
|
|
| 398 |
# mask = codebook_unmask(mask, 0)
|
| 399 |
|
| 400 |
mask = inpaint(z, n_prefix=100, n_suffix=100)
|
| 401 |
-
|
| 402 |
zv, mask_z = interface.coarse_vamp(
|
| 403 |
-
z,
|
| 404 |
mask=mask,
|
| 405 |
sampling_steps=36,
|
| 406 |
temperature=8.0,
|
| 407 |
-
return_mask=True,
|
| 408 |
-
gen_fn=interface.coarse.generate
|
| 409 |
)
|
| 410 |
-
|
| 411 |
|
| 412 |
use_coarse2fine = True
|
| 413 |
-
if use_coarse2fine:
|
| 414 |
zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
|
| 415 |
breakpoint()
|
| 416 |
|
|
@@ -418,5 +417,3 @@ if __name__ == "__main__":
|
|
| 418 |
|
| 419 |
sig = interface.to_signal(zv).cpu()
|
| 420 |
print("done")
|
| 421 |
-
|
| 422 |
-
|
|
|
|
|
|
|
|
|
|
| 1 |
import math
|
| 2 |
+
from pathlib import Path
|
| 3 |
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
from audiotools import AudioSignal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
# from dac.model.dac import DAC
|
| 9 |
from lac.model.lac import LAC as DAC
|
| 10 |
|
| 11 |
+
from .beats import WaveBeat
|
| 12 |
+
from .mask import *
|
| 13 |
+
from .modules.transformer import VampNet
|
| 14 |
+
|
| 15 |
|
| 16 |
def signal_concat(
|
| 17 |
audio_signals: list,
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
def _load_model(
|
| 25 |
+
ckpt: str,
|
| 26 |
lora_ckpt: str = None,
|
| 27 |
device: str = "cpu",
|
| 28 |
chunk_size_s: int = 10,
|
|
|
|
| 39 |
if should_cont != "y":
|
| 40 |
raise Exception("aborting")
|
| 41 |
else:
|
| 42 |
+
model.load_state_dict(
|
| 43 |
+
torch.load(lora_ckpt, map_location="cpu"), strict=False
|
| 44 |
+
)
|
| 45 |
|
| 46 |
model.to(device)
|
| 47 |
model.eval()
|
|
|
|
| 49 |
return model
|
| 50 |
|
| 51 |
|
|
|
|
| 52 |
class Interface(torch.nn.Module):
|
| 53 |
def __init__(
|
| 54 |
self,
|
|
|
|
| 59 |
codec_ckpt: str = None,
|
| 60 |
wavebeat_ckpt: str = None,
|
| 61 |
device: str = "cpu",
|
| 62 |
+
coarse_chunk_size_s: int = 10,
|
| 63 |
+
coarse2fine_chunk_size_s: int = 3,
|
| 64 |
):
|
| 65 |
super().__init__()
|
| 66 |
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
|
|
|
| 97 |
self.device = device
|
| 98 |
|
| 99 |
def lora_load(
|
| 100 |
+
self,
|
| 101 |
coarse_ckpt: str = None,
|
| 102 |
c2f_ckpt: str = None,
|
| 103 |
full_ckpts: bool = False,
|
|
|
|
| 105 |
if full_ckpts:
|
| 106 |
if coarse_ckpt is not None:
|
| 107 |
self.coarse = _load_model(
|
| 108 |
+
ckpt=coarse_ckpt,
|
| 109 |
device=self.device,
|
| 110 |
chunk_size_s=self.coarse.chunk_size_s,
|
| 111 |
)
|
|
|
|
| 128 |
print(f"loading c2f from {c2f_ckpt}")
|
| 129 |
self.c2f.load_state_dict(state_dict, strict=False)
|
| 130 |
self.c2f.to(self.device)
|
| 131 |
+
|
| 132 |
def s2t(self, seconds: float):
|
| 133 |
"""seconds to tokens"""
|
| 134 |
if isinstance(seconds, np.ndarray):
|
|
|
|
| 139 |
def s2t2s(self, seconds: float):
|
| 140 |
"""seconds to tokens to seconds"""
|
| 141 |
return self.t2s(self.s2t(seconds))
|
| 142 |
+
|
| 143 |
def t2s(self, tokens: int):
|
| 144 |
"""tokens to seconds"""
|
| 145 |
return tokens * self.codec.hop_length / self.codec.sample_rate
|
|
|
|
| 158 |
|
| 159 |
def to_signal(self, z: torch.Tensor):
|
| 160 |
return self.coarse.to_signal(z, self.codec)
|
| 161 |
+
|
| 162 |
def preprocess(self, signal: AudioSignal):
|
| 163 |
signal = (
|
| 164 |
signal.clone()
|
|
|
|
| 168 |
.ensure_max_of_audio(1.0)
|
| 169 |
)
|
| 170 |
return signal
|
| 171 |
+
|
| 172 |
@torch.inference_mode()
|
| 173 |
def encode(self, signal: AudioSignal):
|
| 174 |
signal = self.preprocess(signal).to(self.device)
|
| 175 |
z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
|
| 176 |
return z
|
| 177 |
|
| 178 |
+
def snap_to_beats(self, signal: AudioSignal):
|
|
|
|
|
|
|
|
|
|
| 179 |
assert hasattr(self, "beat_tracker"), "No beat tracker loaded"
|
| 180 |
beats, downbeats = self.beat_tracker.extract_beats(signal)
|
| 181 |
+
|
| 182 |
# trim the signa around the first beat time
|
| 183 |
+
samples_begin = int(beats[0] * signal.sample_rate)
|
| 184 |
samples_end = int(beats[-1] * signal.sample_rate)
|
| 185 |
print(beats[0])
|
| 186 |
signal = signal.clone().trim(samples_begin, signal.length - samples_end)
|
| 187 |
|
| 188 |
return signal
|
| 189 |
|
| 190 |
+
def make_beat_mask(
|
| 191 |
+
self,
|
| 192 |
+
signal: AudioSignal,
|
| 193 |
+
before_beat_s: float = 0.0,
|
| 194 |
+
after_beat_s: float = 0.02,
|
| 195 |
+
mask_downbeats: bool = True,
|
| 196 |
+
mask_upbeats: bool = True,
|
| 197 |
+
downbeat_downsample_factor: int = None,
|
| 198 |
+
beat_downsample_factor: int = None,
|
| 199 |
+
dropout: float = 0.0,
|
| 200 |
+
invert: bool = True,
|
| 201 |
):
|
| 202 |
+
"""make a beat synced mask. that is, make a mask that
|
| 203 |
+
places 1s at and around the beat, and 0s everywhere else.
|
| 204 |
"""
|
| 205 |
assert self.beat_tracker is not None, "No beat tracker loaded"
|
| 206 |
|
|
|
|
| 211 |
beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
|
| 212 |
|
| 213 |
# remove downbeats from beats
|
| 214 |
+
beats_z = torch.tensor(beats_z)[
|
| 215 |
+
~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))
|
| 216 |
+
]
|
| 217 |
beats_z = beats_z.tolist()
|
| 218 |
downbeats_z = downbeats_z.tolist()
|
| 219 |
|
| 220 |
+
# make the mask
|
| 221 |
seq_len = self.s2t(signal.duration)
|
| 222 |
mask = torch.zeros(seq_len, device=self.device)
|
| 223 |
+
|
| 224 |
mask_b4 = self.s2t(before_beat_s)
|
| 225 |
mask_after = self.s2t(after_beat_s)
|
| 226 |
|
|
|
|
| 240 |
downbeats_z = downbeats_z[::downbeat_downsample_factor]
|
| 241 |
print(f"beats_z: {len(beats_z)}")
|
| 242 |
print(f"downbeats_z: {len(downbeats_z)}")
|
| 243 |
+
|
| 244 |
if mask_upbeats:
|
| 245 |
for beat_idx in beats_z:
|
| 246 |
_slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
|
| 247 |
+
num_steps = mask[_slice[0] : _slice[1]].shape[0]
|
| 248 |
_m = torch.ones(num_steps, device=self.device)
|
| 249 |
_m_mask = torch.bernoulli(_m * (1 - dropout))
|
| 250 |
_m = _m * _m_mask.long()
|
| 251 |
+
|
| 252 |
+
mask[_slice[0] : _slice[1]] = _m
|
| 253 |
|
| 254 |
if mask_downbeats:
|
| 255 |
for downbeat_idx in downbeats_z:
|
| 256 |
_slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
|
| 257 |
+
num_steps = mask[_slice[0] : _slice[1]].shape[0]
|
| 258 |
_m = torch.ones(num_steps, device=self.device)
|
| 259 |
_m_mask = torch.bernoulli(_m * (1 - dropout))
|
| 260 |
_m = _m * _m_mask.long()
|
| 261 |
+
|
| 262 |
+
mask[_slice[0] : _slice[1]] = _m
|
| 263 |
+
|
| 264 |
mask = mask.clamp(0, 1)
|
| 265 |
if invert:
|
| 266 |
mask = 1 - mask
|
| 267 |
+
|
| 268 |
mask = mask[None, None, :].bool().long()
|
| 269 |
if self.c2f is not None:
|
| 270 |
mask = mask.repeat(1, self.c2f.n_codebooks, 1)
|
| 271 |
else:
|
| 272 |
mask = mask.repeat(1, self.coarse.n_codebooks, 1)
|
| 273 |
return mask
|
| 274 |
+
|
| 275 |
+
def coarse_to_fine(self, z: torch.Tensor, mask: torch.Tensor = None, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
assert self.c2f is not None, "No coarse2fine model loaded"
|
| 277 |
length = z.shape[-1]
|
| 278 |
chunk_len = self.s2t(self.c2f.chunk_size_s)
|
|
|
|
| 282 |
if length % chunk_len != 0:
|
| 283 |
pad_len = chunk_len - (length % chunk_len)
|
| 284 |
z = torch.nn.functional.pad(z, (0, pad_len))
|
| 285 |
+
mask = (
|
| 286 |
+
torch.nn.functional.pad(mask, (0, pad_len))
|
| 287 |
+
if mask is not None
|
| 288 |
+
else None
|
| 289 |
+
)
|
| 290 |
|
| 291 |
n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
|
| 292 |
if n_codebooks_to_append > 0:
|
| 293 |
+
z = torch.cat(
|
| 294 |
+
[
|
| 295 |
+
z,
|
| 296 |
+
torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1])
|
| 297 |
+
.long()
|
| 298 |
+
.to(self.device),
|
| 299 |
+
],
|
| 300 |
+
dim=1,
|
| 301 |
+
)
|
| 302 |
|
| 303 |
# set the mask to 0 for all conditioning codebooks
|
| 304 |
if mask is not None:
|
| 305 |
mask = mask.clone()
|
| 306 |
+
mask[:, : self.c2f.n_conditioning_codebooks, :] = 0
|
| 307 |
|
| 308 |
fine_z = []
|
| 309 |
for i in range(n_chunks):
|
| 310 |
chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
|
| 311 |
+
mask_chunk = (
|
| 312 |
+
mask[:, :, i * chunk_len : (i + 1) * chunk_len]
|
| 313 |
+
if mask is not None
|
| 314 |
+
else None
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
chunk = self.c2f.generate(
|
| 318 |
codec=self.codec,
|
| 319 |
time_steps=chunk_len,
|
| 320 |
start_tokens=chunk,
|
| 321 |
return_signal=False,
|
| 322 |
mask=mask_chunk,
|
| 323 |
+
**kwargs,
|
| 324 |
)
|
| 325 |
fine_z.append(chunk)
|
| 326 |
|
| 327 |
fine_z = torch.cat(fine_z, dim=-1)
|
| 328 |
return fine_z[:, :, :length].clone()
|
| 329 |
+
|
| 330 |
+
def coarse_vamp(self, z, mask, return_mask=False, gen_fn=None, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
# coarse z
|
| 332 |
cz = z[:, : self.coarse.n_codebooks, :].clone()
|
| 333 |
+
assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), (
|
| 334 |
+
f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
|
| 335 |
+
)
|
| 336 |
|
| 337 |
mask = mask[:, : self.coarse.n_codebooks, :]
|
| 338 |
|
|
|
|
| 344 |
codec=self.codec,
|
| 345 |
time_steps=cz.shape[-1],
|
| 346 |
start_tokens=cz,
|
| 347 |
+
mask=mask,
|
| 348 |
return_signal=False,
|
| 349 |
+
**kwargs,
|
| 350 |
)
|
| 351 |
|
| 352 |
# add the fine codes back in
|
| 353 |
+
c_vamp = torch.cat([c_vamp, z[:, self.coarse.n_codebooks :, :]], dim=1)
|
|
|
|
|
|
|
|
|
|
| 354 |
|
| 355 |
if return_mask:
|
| 356 |
return c_vamp, cz_masked
|
| 357 |
+
|
| 358 |
return c_vamp
|
| 359 |
|
| 360 |
|
| 361 |
if __name__ == "__main__":
|
|
|
|
| 362 |
import logging
|
| 363 |
+
|
| 364 |
+
import audiotools as at
|
| 365 |
+
|
| 366 |
logger = logging.getLogger()
|
| 367 |
logger.setLevel(logging.INFO)
|
| 368 |
torch.set_printoptions(threshold=10000)
|
| 369 |
at.util.seed(42)
|
| 370 |
|
| 371 |
interface = Interface(
|
| 372 |
+
coarse_ckpt="./models/vampnet/coarse.pth",
|
| 373 |
+
coarse2fine_ckpt="./models/vampnet/c2f.pth",
|
| 374 |
codec_ckpt="./models/vampnet/codec.pth",
|
| 375 |
+
device="cuda",
|
| 376 |
+
wavebeat_ckpt="./models/wavebeat.pth",
|
| 377 |
)
|
| 378 |
|
| 379 |
+
sig = at.AudioSignal("assets/example.wav")
|
|
|
|
| 380 |
|
| 381 |
z = interface.encode(sig)
|
| 382 |
breakpoint()
|
|
|
|
| 398 |
# mask = codebook_unmask(mask, 0)
|
| 399 |
|
| 400 |
mask = inpaint(z, n_prefix=100, n_suffix=100)
|
| 401 |
+
|
| 402 |
zv, mask_z = interface.coarse_vamp(
|
| 403 |
+
z,
|
| 404 |
mask=mask,
|
| 405 |
sampling_steps=36,
|
| 406 |
temperature=8.0,
|
| 407 |
+
return_mask=True,
|
| 408 |
+
gen_fn=interface.coarse.generate,
|
| 409 |
)
|
|
|
|
| 410 |
|
| 411 |
use_coarse2fine = True
|
| 412 |
+
if use_coarse2fine:
|
| 413 |
zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
|
| 414 |
breakpoint()
|
| 415 |
|
|
|
|
| 417 |
|
| 418 |
sig = interface.to_signal(zv).cpu()
|
| 419 |
print("done")
|
|
|
|
|
|
vampnet/vampnet/mask.py
CHANGED
|
@@ -1,33 +1,34 @@
|
|
| 1 |
-
from typing import Optional
|
| 2 |
-
|
| 3 |
import torch
|
| 4 |
from audiotools import AudioSignal
|
| 5 |
|
| 6 |
from .util import scalar_to_batch_tensor
|
| 7 |
|
|
|
|
| 8 |
def _gamma(r):
|
| 9 |
return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
|
| 10 |
|
|
|
|
| 11 |
def _invgamma(y):
|
| 12 |
if not torch.is_tensor(y):
|
| 13 |
y = torch.tensor(y)[None]
|
| 14 |
return 2 * y.acos() / torch.pi
|
| 15 |
|
|
|
|
| 16 |
def full_mask(x: torch.Tensor):
|
| 17 |
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
| 18 |
return torch.ones_like(x).long()
|
| 19 |
|
|
|
|
| 20 |
def empty_mask(x: torch.Tensor):
|
| 21 |
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
| 22 |
return torch.zeros_like(x).long()
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
mask: torch.Tensor,
|
| 27 |
-
mask_token: int
|
| 28 |
-
):
|
| 29 |
assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
|
| 30 |
-
assert mask.shape == x.shape,
|
|
|
|
|
|
|
| 31 |
assert mask.dtype == torch.long, "mask must be long dtype, but got {mask.dtype}"
|
| 32 |
assert ~torch.any(mask > 1), "mask must be binary"
|
| 33 |
assert ~torch.any(mask < 0), "mask must be binary"
|
|
@@ -37,10 +38,8 @@ def apply_mask(
|
|
| 37 |
|
| 38 |
return x, mask
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
r: torch.Tensor
|
| 43 |
-
):
|
| 44 |
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
| 45 |
if not isinstance(r, torch.Tensor):
|
| 46 |
r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
|
|
@@ -53,6 +52,7 @@ def random(
|
|
| 53 |
|
| 54 |
return mask
|
| 55 |
|
|
|
|
| 56 |
def linear_random(
|
| 57 |
x: torch.Tensor,
|
| 58 |
r: torch.Tensor,
|
|
@@ -71,19 +71,21 @@ def linear_random(
|
|
| 71 |
|
| 72 |
return mask
|
| 73 |
|
| 74 |
-
|
|
|
|
|
|
|
| 75 |
n_prefix,
|
| 76 |
n_suffix,
|
| 77 |
):
|
| 78 |
assert n_prefix is not None
|
| 79 |
assert n_suffix is not None
|
| 80 |
-
|
| 81 |
mask = full_mask(x)
|
| 82 |
|
| 83 |
# if we have a prefix or suffix, set their mask prob to 0
|
| 84 |
if n_prefix > 0:
|
| 85 |
if not isinstance(n_prefix, torch.Tensor):
|
| 86 |
-
n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
|
| 87 |
for i, n in enumerate(n_prefix):
|
| 88 |
if n > 0:
|
| 89 |
mask[i, :, :n] = 0.0
|
|
@@ -94,13 +96,15 @@ def inpaint(x: torch.Tensor,
|
|
| 94 |
if n > 0:
|
| 95 |
mask[i, :, -n:] = 0.0
|
| 96 |
|
| 97 |
-
|
| 98 |
return mask
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
| 104 |
mask = full_mask(x)
|
| 105 |
if period == 0:
|
| 106 |
return mask
|
|
@@ -113,8 +117,8 @@ def periodic_mask(x: torch.Tensor,
|
|
| 113 |
for j in range(mask.shape[-1]):
|
| 114 |
if j % factor == 0:
|
| 115 |
# figure out how wide the mask should be
|
| 116 |
-
j_start = max(0, j - width // 2
|
| 117 |
-
j_end = min(mask.shape[-1] - 1, j + width // 2
|
| 118 |
# flip a coin for each position in the mask
|
| 119 |
j_mask = torch.bernoulli(torch.ones(j_end - j_start))
|
| 120 |
assert torch.all(j_mask == 1)
|
|
@@ -129,10 +133,8 @@ def periodic_mask(x: torch.Tensor,
|
|
| 129 |
|
| 130 |
return mask
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
n_conditioning_codebooks: int
|
| 135 |
-
):
|
| 136 |
if n_conditioning_codebooks == None:
|
| 137 |
return mask
|
| 138 |
# if we have any conditioning codebooks, set their mask to 0
|
|
@@ -140,18 +142,18 @@ def codebook_unmask(
|
|
| 140 |
mask[:, :n_conditioning_codebooks, :] = 0
|
| 141 |
return mask
|
| 142 |
|
|
|
|
| 143 |
def codebook_mask(mask: torch.Tensor, start: int):
|
| 144 |
mask = mask.clone()
|
| 145 |
mask[:, start:, :] = 1
|
| 146 |
return mask
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
mask2: torch.Tensor
|
| 151 |
-
):
|
| 152 |
assert mask1.shape == mask2.shape, "masks must be same shape"
|
| 153 |
return torch.min(mask1, mask2)
|
| 154 |
|
|
|
|
| 155 |
def dropout(
|
| 156 |
mask: torch.Tensor,
|
| 157 |
p: float,
|
|
@@ -164,19 +166,20 @@ def dropout(
|
|
| 164 |
mask = ~mask.round().bool()
|
| 165 |
return mask.long()
|
| 166 |
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
assert mask1.max() <= 1, "mask1 must be binary"
|
| 173 |
assert mask2.max() <= 1, "mask2 must be binary"
|
| 174 |
assert mask1.min() >= 0, "mask1 must be binary"
|
| 175 |
assert mask2.min() >= 0, "mask2 must be binary"
|
| 176 |
return (mask1 + mask2).clamp(0, 1)
|
| 177 |
|
|
|
|
| 178 |
def time_stretch_mask(
|
| 179 |
-
x: torch.Tensor,
|
| 180 |
stretch_factor: int,
|
| 181 |
):
|
| 182 |
assert stretch_factor >= 1, "stretch factor must be >= 1"
|
|
@@ -189,18 +192,19 @@ def time_stretch_mask(
|
|
| 189 |
mask = periodic_mask(x, stretch_factor, width=1)
|
| 190 |
return mask
|
| 191 |
|
|
|
|
| 192 |
def _onset_times_madmom(wav_path, sample_rate, hop_length):
|
| 193 |
-
from madmom.features.onsets import
|
|
|
|
| 194 |
proc = RNNOnsetProcessor(online=False)
|
| 195 |
-
onsetproc = OnsetPeakPickingProcessor(
|
| 196 |
-
threshold=0.3, fps=sample_rate / hop_length
|
| 197 |
-
)
|
| 198 |
act = proc(wav_path)
|
| 199 |
return onsetproc(act)
|
| 200 |
|
| 201 |
|
| 202 |
def _onset_times_librosa(wav_path, sample_rate, hop_length):
|
| 203 |
import librosa
|
|
|
|
| 204 |
y, sr = librosa.load(wav_path, sr=sample_rate)
|
| 205 |
onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
|
| 206 |
onset_frames = librosa.onset.onset_detect(
|
|
@@ -209,18 +213,14 @@ def _onset_times_librosa(wav_path, sample_rate, hop_length):
|
|
| 209 |
return librosa.frames_to_time(onset_frames, sr=sr, hop_length=hop_length)
|
| 210 |
|
| 211 |
|
| 212 |
-
def onset_mask(
|
| 213 |
-
sig: AudioSignal,
|
| 214 |
-
z: torch.Tensor,
|
| 215 |
-
interface,
|
| 216 |
-
width: int = 1
|
| 217 |
-
):
|
| 218 |
-
import librosa
|
| 219 |
import tempfile
|
| 220 |
-
|
|
|
|
| 221 |
|
| 222 |
try:
|
| 223 |
import madmom # noqa: F401
|
|
|
|
| 224 |
_get_onset_times = _onset_times_madmom
|
| 225 |
except ImportError:
|
| 226 |
print("madmom not installed, falling back to librosa for onset detection")
|
|
@@ -228,7 +228,7 @@ def onset_mask(
|
|
| 228 |
|
| 229 |
hop_length = interface.codec.hop_length
|
| 230 |
|
| 231 |
-
with tempfile.NamedTemporaryFile(suffix=
|
| 232 |
sig = sig.clone()
|
| 233 |
sig.write(f.name)
|
| 234 |
|
|
@@ -238,9 +238,9 @@ def onset_mask(
|
|
| 238 |
)
|
| 239 |
|
| 240 |
if onset_indices.shape[0] == 0:
|
| 241 |
-
mask = empty_mask(z)
|
| 242 |
-
print(
|
| 243 |
-
else:
|
| 244 |
torch.set_printoptions(threshold=1000)
|
| 245 |
print("onset indices: ", onset_indices)
|
| 246 |
print("onset times: ", onset_times)
|
|
@@ -251,12 +251,11 @@ def onset_mask(
|
|
| 251 |
for onset_index in onset_indices:
|
| 252 |
onset_index = min(onset_index, n_timesteps - 1)
|
| 253 |
onset_index = max(onset_index, 0)
|
| 254 |
-
mask[:, :, onset_index - width:onset_index + width] = 0.0
|
| 255 |
|
| 256 |
print(mask)
|
| 257 |
-
|
| 258 |
-
return mask
|
| 259 |
|
|
|
|
| 260 |
|
| 261 |
|
| 262 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from audiotools import AudioSignal
|
| 3 |
|
| 4 |
from .util import scalar_to_batch_tensor
|
| 5 |
|
| 6 |
+
|
| 7 |
def _gamma(r):
|
| 8 |
return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
|
| 9 |
|
| 10 |
+
|
| 11 |
def _invgamma(y):
|
| 12 |
if not torch.is_tensor(y):
|
| 13 |
y = torch.tensor(y)[None]
|
| 14 |
return 2 * y.acos() / torch.pi
|
| 15 |
|
| 16 |
+
|
| 17 |
def full_mask(x: torch.Tensor):
|
| 18 |
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
| 19 |
return torch.ones_like(x).long()
|
| 20 |
|
| 21 |
+
|
| 22 |
def empty_mask(x: torch.Tensor):
|
| 23 |
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
| 24 |
return torch.zeros_like(x).long()
|
| 25 |
|
| 26 |
+
|
| 27 |
+
def apply_mask(x: torch.Tensor, mask: torch.Tensor, mask_token: int):
|
|
|
|
|
|
|
|
|
|
| 28 |
assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
|
| 29 |
+
assert mask.shape == x.shape, (
|
| 30 |
+
f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
|
| 31 |
+
)
|
| 32 |
assert mask.dtype == torch.long, "mask must be long dtype, but got {mask.dtype}"
|
| 33 |
assert ~torch.any(mask > 1), "mask must be binary"
|
| 34 |
assert ~torch.any(mask < 0), "mask must be binary"
|
|
|
|
| 38 |
|
| 39 |
return x, mask
|
| 40 |
|
| 41 |
+
|
| 42 |
+
def random(x: torch.Tensor, r: torch.Tensor):
|
|
|
|
|
|
|
| 43 |
assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
|
| 44 |
if not isinstance(r, torch.Tensor):
|
| 45 |
r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
|
|
|
|
| 52 |
|
| 53 |
return mask
|
| 54 |
|
| 55 |
+
|
| 56 |
def linear_random(
|
| 57 |
x: torch.Tensor,
|
| 58 |
r: torch.Tensor,
|
|
|
|
| 71 |
|
| 72 |
return mask
|
| 73 |
|
| 74 |
+
|
| 75 |
+
def inpaint(
|
| 76 |
+
x: torch.Tensor,
|
| 77 |
n_prefix,
|
| 78 |
n_suffix,
|
| 79 |
):
|
| 80 |
assert n_prefix is not None
|
| 81 |
assert n_suffix is not None
|
| 82 |
+
|
| 83 |
mask = full_mask(x)
|
| 84 |
|
| 85 |
# if we have a prefix or suffix, set their mask prob to 0
|
| 86 |
if n_prefix > 0:
|
| 87 |
if not isinstance(n_prefix, torch.Tensor):
|
| 88 |
+
n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
|
| 89 |
for i, n in enumerate(n_prefix):
|
| 90 |
if n > 0:
|
| 91 |
mask[i, :, :n] = 0.0
|
|
|
|
| 96 |
if n > 0:
|
| 97 |
mask[i, :, -n:] = 0.0
|
| 98 |
|
|
|
|
| 99 |
return mask
|
| 100 |
|
| 101 |
+
|
| 102 |
+
def periodic_mask(
|
| 103 |
+
x: torch.Tensor,
|
| 104 |
+
period: int,
|
| 105 |
+
width: int = 1,
|
| 106 |
+
random_roll=False,
|
| 107 |
+
):
|
| 108 |
mask = full_mask(x)
|
| 109 |
if period == 0:
|
| 110 |
return mask
|
|
|
|
| 117 |
for j in range(mask.shape[-1]):
|
| 118 |
if j % factor == 0:
|
| 119 |
# figure out how wide the mask should be
|
| 120 |
+
j_start = max(0, j - width // 2)
|
| 121 |
+
j_end = min(mask.shape[-1] - 1, j + width // 2) + 1
|
| 122 |
# flip a coin for each position in the mask
|
| 123 |
j_mask = torch.bernoulli(torch.ones(j_end - j_start))
|
| 124 |
assert torch.all(j_mask == 1)
|
|
|
|
| 133 |
|
| 134 |
return mask
|
| 135 |
|
| 136 |
+
|
| 137 |
+
def codebook_unmask(mask: torch.Tensor, n_conditioning_codebooks: int):
|
|
|
|
|
|
|
| 138 |
if n_conditioning_codebooks == None:
|
| 139 |
return mask
|
| 140 |
# if we have any conditioning codebooks, set their mask to 0
|
|
|
|
| 142 |
mask[:, :n_conditioning_codebooks, :] = 0
|
| 143 |
return mask
|
| 144 |
|
| 145 |
+
|
| 146 |
def codebook_mask(mask: torch.Tensor, start: int):
|
| 147 |
mask = mask.clone()
|
| 148 |
mask[:, start:, :] = 1
|
| 149 |
return mask
|
| 150 |
|
| 151 |
+
|
| 152 |
+
def mask_and(mask1: torch.Tensor, mask2: torch.Tensor):
|
|
|
|
|
|
|
| 153 |
assert mask1.shape == mask2.shape, "masks must be same shape"
|
| 154 |
return torch.min(mask1, mask2)
|
| 155 |
|
| 156 |
+
|
| 157 |
def dropout(
|
| 158 |
mask: torch.Tensor,
|
| 159 |
p: float,
|
|
|
|
| 166 |
mask = ~mask.round().bool()
|
| 167 |
return mask.long()
|
| 168 |
|
| 169 |
+
|
| 170 |
+
def mask_or(mask1: torch.Tensor, mask2: torch.Tensor):
|
| 171 |
+
assert mask1.shape == mask2.shape, (
|
| 172 |
+
f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
|
| 173 |
+
)
|
| 174 |
assert mask1.max() <= 1, "mask1 must be binary"
|
| 175 |
assert mask2.max() <= 1, "mask2 must be binary"
|
| 176 |
assert mask1.min() >= 0, "mask1 must be binary"
|
| 177 |
assert mask2.min() >= 0, "mask2 must be binary"
|
| 178 |
return (mask1 + mask2).clamp(0, 1)
|
| 179 |
|
| 180 |
+
|
| 181 |
def time_stretch_mask(
|
| 182 |
+
x: torch.Tensor,
|
| 183 |
stretch_factor: int,
|
| 184 |
):
|
| 185 |
assert stretch_factor >= 1, "stretch factor must be >= 1"
|
|
|
|
| 192 |
mask = periodic_mask(x, stretch_factor, width=1)
|
| 193 |
return mask
|
| 194 |
|
| 195 |
+
|
| 196 |
def _onset_times_madmom(wav_path, sample_rate, hop_length):
|
| 197 |
+
from madmom.features.onsets import OnsetPeakPickingProcessor, RNNOnsetProcessor
|
| 198 |
+
|
| 199 |
proc = RNNOnsetProcessor(online=False)
|
| 200 |
+
onsetproc = OnsetPeakPickingProcessor(threshold=0.3, fps=sample_rate / hop_length)
|
|
|
|
|
|
|
| 201 |
act = proc(wav_path)
|
| 202 |
return onsetproc(act)
|
| 203 |
|
| 204 |
|
| 205 |
def _onset_times_librosa(wav_path, sample_rate, hop_length):
|
| 206 |
import librosa
|
| 207 |
+
|
| 208 |
y, sr = librosa.load(wav_path, sr=sample_rate)
|
| 209 |
onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
|
| 210 |
onset_frames = librosa.onset.onset_detect(
|
|
|
|
| 213 |
return librosa.frames_to_time(onset_frames, sr=sr, hop_length=hop_length)
|
| 214 |
|
| 215 |
|
| 216 |
+
def onset_mask(sig: AudioSignal, z: torch.Tensor, interface, width: int = 1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
import tempfile
|
| 218 |
+
|
| 219 |
+
import librosa
|
| 220 |
|
| 221 |
try:
|
| 222 |
import madmom # noqa: F401
|
| 223 |
+
|
| 224 |
_get_onset_times = _onset_times_madmom
|
| 225 |
except ImportError:
|
| 226 |
print("madmom not installed, falling back to librosa for onset detection")
|
|
|
|
| 228 |
|
| 229 |
hop_length = interface.codec.hop_length
|
| 230 |
|
| 231 |
+
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
|
| 232 |
sig = sig.clone()
|
| 233 |
sig.write(f.name)
|
| 234 |
|
|
|
|
| 238 |
)
|
| 239 |
|
| 240 |
if onset_indices.shape[0] == 0:
|
| 241 |
+
mask = empty_mask(z)
|
| 242 |
+
print("no onsets found, returning empty mask")
|
| 243 |
+
else:
|
| 244 |
torch.set_printoptions(threshold=1000)
|
| 245 |
print("onset indices: ", onset_indices)
|
| 246 |
print("onset times: ", onset_times)
|
|
|
|
| 251 |
for onset_index in onset_indices:
|
| 252 |
onset_index = min(onset_index, n_timesteps - 1)
|
| 253 |
onset_index = max(onset_index, 0)
|
| 254 |
+
mask[:, :, onset_index - width : onset_index + width] = 0.0
|
| 255 |
|
| 256 |
print(mask)
|
|
|
|
|
|
|
| 257 |
|
| 258 |
+
return mask
|
| 259 |
|
| 260 |
|
| 261 |
if __name__ == "__main__":
|
vampnet/vampnet/modules/__init__.py
CHANGED
|
@@ -3,4 +3,4 @@ import audiotools
|
|
| 3 |
audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
|
| 4 |
audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
|
| 5 |
|
| 6 |
-
from .transformer import VampNet
|
|
|
|
| 3 |
audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
|
| 4 |
audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
|
| 5 |
|
| 6 |
+
from .transformer import VampNet
|
vampnet/vampnet/modules/activations.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
| 1 |
import math
|
| 2 |
-
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
-
import torch.nn.functional as F
|
| 6 |
-
from einops import rearrange
|
| 7 |
|
| 8 |
|
| 9 |
class NewGELU(nn.Module):
|
|
@@ -25,6 +23,7 @@ class NewGELU(nn.Module):
|
|
| 25 |
)
|
| 26 |
)
|
| 27 |
|
|
|
|
| 28 |
class GatedGELU(nn.Module):
|
| 29 |
def __init__(self):
|
| 30 |
super().__init__()
|
|
@@ -34,6 +33,7 @@ class GatedGELU(nn.Module):
|
|
| 34 |
p1, p2 = x.chunk(2, dim=dim)
|
| 35 |
return p1 * self.gelu(p2)
|
| 36 |
|
|
|
|
| 37 |
class Snake1d(nn.Module):
|
| 38 |
def __init__(self, channels):
|
| 39 |
super().__init__()
|
|
@@ -42,6 +42,7 @@ class Snake1d(nn.Module):
|
|
| 42 |
def forward(self, x):
|
| 43 |
return x + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * x).pow(2)
|
| 44 |
|
|
|
|
| 45 |
def get_activation(name: str = "relu"):
|
| 46 |
if name == "relu":
|
| 47 |
return nn.ReLU
|
|
@@ -52,4 +53,4 @@ def get_activation(name: str = "relu"):
|
|
| 52 |
elif name == "snake":
|
| 53 |
return Snake1d
|
| 54 |
else:
|
| 55 |
-
raise ValueError(f"Unrecognized activation {name}")
|
|
|
|
| 1 |
import math
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
|
|
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class NewGELU(nn.Module):
|
|
|
|
| 23 |
)
|
| 24 |
)
|
| 25 |
|
| 26 |
+
|
| 27 |
class GatedGELU(nn.Module):
|
| 28 |
def __init__(self):
|
| 29 |
super().__init__()
|
|
|
|
| 33 |
p1, p2 = x.chunk(2, dim=dim)
|
| 34 |
return p1 * self.gelu(p2)
|
| 35 |
|
| 36 |
+
|
| 37 |
class Snake1d(nn.Module):
|
| 38 |
def __init__(self, channels):
|
| 39 |
super().__init__()
|
|
|
|
| 42 |
def forward(self, x):
|
| 43 |
return x + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * x).pow(2)
|
| 44 |
|
| 45 |
+
|
| 46 |
def get_activation(name: str = "relu"):
|
| 47 |
if name == "relu":
|
| 48 |
return nn.ReLU
|
|
|
|
| 53 |
elif name == "snake":
|
| 54 |
return Snake1d
|
| 55 |
else:
|
| 56 |
+
raise ValueError(f"Unrecognized activation {name}")
|
vampnet/vampnet/modules/layers.py
CHANGED
|
@@ -1,13 +1,11 @@
|
|
| 1 |
-
import
|
| 2 |
-
from typing import Optional
|
| 3 |
-
from typing import Tuple
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
import torch.nn.functional as F
|
| 8 |
-
from einops import rearrange
|
| 9 |
from torch.nn.utils import weight_norm
|
| 10 |
|
|
|
|
| 11 |
# Scripting this brings model speed up 1.4x
|
| 12 |
@torch.jit.script
|
| 13 |
def snake(x, alpha):
|
|
@@ -132,10 +130,10 @@ class CodebookEmbedding(nn.Module):
|
|
| 132 |
self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
|
| 133 |
|
| 134 |
def from_codes(self, codes: torch.Tensor, codec):
|
| 135 |
-
"""
|
| 136 |
-
get a sequence of continuous embeddings from a sequence of discrete codes.
|
| 137 |
unlike it's counterpart in the original VQ-VAE, this function adds for any special tokens
|
| 138 |
-
necessary for the language model, like <MASK>.
|
| 139 |
"""
|
| 140 |
n_codebooks = codes.shape[1]
|
| 141 |
latent = []
|
|
@@ -161,4 +159,3 @@ class CodebookEmbedding(nn.Module):
|
|
| 161 |
"""
|
| 162 |
x = self.out_proj(latents)
|
| 163 |
return x
|
| 164 |
-
|
|
|
|
| 1 |
+
from typing import Optional, Tuple
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import torch.nn.functional as F
|
|
|
|
| 6 |
from torch.nn.utils import weight_norm
|
| 7 |
|
| 8 |
+
|
| 9 |
# Scripting this brings model speed up 1.4x
|
| 10 |
@torch.jit.script
|
| 11 |
def snake(x, alpha):
|
|
|
|
| 130 |
self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
|
| 131 |
|
| 132 |
def from_codes(self, codes: torch.Tensor, codec):
|
| 133 |
+
"""
|
| 134 |
+
get a sequence of continuous embeddings from a sequence of discrete codes.
|
| 135 |
unlike it's counterpart in the original VQ-VAE, this function adds for any special tokens
|
| 136 |
+
necessary for the language model, like <MASK>.
|
| 137 |
"""
|
| 138 |
n_codebooks = codes.shape[1]
|
| 139 |
latent = []
|
|
|
|
| 159 |
"""
|
| 160 |
x = self.out_proj(latents)
|
| 161 |
return x
|
|
|
vampnet/vampnet/modules/transformer.py
CHANGED
|
@@ -1,22 +1,19 @@
|
|
| 1 |
-
import math
|
| 2 |
import logging
|
| 3 |
-
|
|
|
|
| 4 |
|
|
|
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
| 9 |
from einops import rearrange
|
| 10 |
-
import loralib as lora
|
| 11 |
-
import audiotools as at
|
| 12 |
|
| 13 |
-
from .activations import get_activation
|
| 14 |
-
from .layers import CodebookEmbedding
|
| 15 |
-
from .layers import FiLM
|
| 16 |
-
from .layers import SequentialWithFiLM
|
| 17 |
-
from .layers import WNConv1d
|
| 18 |
-
from ..util import scalar_to_batch_tensor, codebook_flatten, codebook_unflatten
|
| 19 |
from ..mask import _gamma
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
LORA_R = 8
|
| 22 |
|
|
@@ -279,6 +276,7 @@ class TransformerLayer(nn.Module):
|
|
| 279 |
|
| 280 |
if flash_attn:
|
| 281 |
from flash_attn.flash_attention import FlashMHA
|
|
|
|
| 282 |
self.self_attn = FlashMHA(
|
| 283 |
embed_dim=d_model,
|
| 284 |
num_heads=n_heads,
|
|
@@ -410,9 +408,15 @@ class TransformerStack(nn.Module):
|
|
| 410 |
def subsequent_mask(self, size):
|
| 411 |
return torch.ones(1, size, size).tril().bool()
|
| 412 |
|
| 413 |
-
def forward(
|
| 414 |
-
|
| 415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
"""Computes a full transformer stack
|
| 417 |
Parameters
|
| 418 |
----------
|
|
@@ -454,7 +458,6 @@ class TransformerStack(nn.Module):
|
|
| 454 |
if return_activations:
|
| 455 |
activations.append(x.detach())
|
| 456 |
|
| 457 |
-
|
| 458 |
out = self.norm(x) if self.norm is not None else x
|
| 459 |
if return_activations:
|
| 460 |
return out, torch.stack(activations)
|
|
@@ -475,10 +478,12 @@ class VampNet(at.ml.BaseModel):
|
|
| 475 |
vocab_size: int = 1024,
|
| 476 |
flash_attn: bool = True,
|
| 477 |
noise_mode: str = "mask",
|
| 478 |
-
dropout: float = 0.1
|
| 479 |
):
|
| 480 |
super().__init__()
|
| 481 |
-
assert r_cond_dim == 0,
|
|
|
|
|
|
|
| 482 |
self.n_heads = n_heads
|
| 483 |
self.n_layers = n_layers
|
| 484 |
self.r_cond_dim = r_cond_dim
|
|
@@ -530,13 +535,15 @@ class VampNet(at.ml.BaseModel):
|
|
| 530 |
x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
|
| 531 |
|
| 532 |
x = rearrange(x, "b d n -> b n d")
|
| 533 |
-
out = self.transformer(
|
|
|
|
|
|
|
| 534 |
if return_activations:
|
| 535 |
out, activations = out
|
| 536 |
|
| 537 |
out = rearrange(out, "b n d -> b d n")
|
| 538 |
|
| 539 |
-
out = self.classifier(out, None)
|
| 540 |
|
| 541 |
out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
|
| 542 |
|
|
@@ -544,7 +551,7 @@ class VampNet(at.ml.BaseModel):
|
|
| 544 |
return out, activations
|
| 545 |
else:
|
| 546 |
return out
|
| 547 |
-
|
| 548 |
def r_embed(self, r, max_positions=10000):
|
| 549 |
if self.r_cond_dim > 0:
|
| 550 |
dtype = r.dtype
|
|
@@ -564,11 +571,11 @@ class VampNet(at.ml.BaseModel):
|
|
| 564 |
return emb.to(dtype)
|
| 565 |
else:
|
| 566 |
return r
|
| 567 |
-
|
| 568 |
@torch.no_grad()
|
| 569 |
def to_signal(self, z, codec):
|
| 570 |
"""
|
| 571 |
-
convert a sequence of latents to a signal.
|
| 572 |
"""
|
| 573 |
assert z.ndim == 3
|
| 574 |
|
|
@@ -588,7 +595,6 @@ class VampNet(at.ml.BaseModel):
|
|
| 588 |
|
| 589 |
return signal
|
| 590 |
|
| 591 |
-
|
| 592 |
@torch.no_grad()
|
| 593 |
def generate(
|
| 594 |
self,
|
|
@@ -604,16 +610,14 @@ class VampNet(at.ml.BaseModel):
|
|
| 604 |
typical_min_tokens=1,
|
| 605 |
top_p=None,
|
| 606 |
return_signal=True,
|
| 607 |
-
seed: int = None,
|
| 608 |
sample_cutoff: float = 1.0,
|
| 609 |
):
|
| 610 |
if seed is not None:
|
| 611 |
at.util.seed(seed)
|
| 612 |
logging.debug(f"beginning generation with {sampling_steps} steps")
|
| 613 |
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
#####################
|
| 617 |
# resolve initial z #
|
| 618 |
#####################
|
| 619 |
z = start_tokens
|
|
@@ -625,7 +629,6 @@ class VampNet(at.ml.BaseModel):
|
|
| 625 |
|
| 626 |
logging.debug(f"created z with shape {z.shape}")
|
| 627 |
|
| 628 |
-
|
| 629 |
#################
|
| 630 |
# resolve mask #
|
| 631 |
#################
|
|
@@ -636,9 +639,8 @@ class VampNet(at.ml.BaseModel):
|
|
| 636 |
if mask.ndim == 2:
|
| 637 |
mask = mask[:, None, :].repeat(1, z.shape[1], 1)
|
| 638 |
# init_mask = mask.clone()
|
| 639 |
-
|
| 640 |
-
logging.debug(f"created mask with shape {mask.shape}")
|
| 641 |
|
|
|
|
| 642 |
|
| 643 |
###########
|
| 644 |
# set up #
|
|
@@ -663,33 +665,33 @@ class VampNet(at.ml.BaseModel):
|
|
| 663 |
logging.debug(f"step {i} of {sampling_steps}")
|
| 664 |
|
| 665 |
# our current schedule step
|
| 666 |
-
r = scalar_to_batch_tensor(
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
).to(z.device)
|
| 670 |
logging.debug(f"r: {r}")
|
| 671 |
|
| 672 |
# get latents
|
| 673 |
latents = self.embedding.from_codes(z_masked, codec)
|
| 674 |
logging.debug(f"computed latents with shape: {latents.shape}")
|
| 675 |
|
| 676 |
-
|
| 677 |
# infer from latents
|
| 678 |
# NOTE: this collapses the codebook dimension into the sequence dimension
|
| 679 |
-
logits = self.forward(latents)
|
| 680 |
logits = logits.permute(0, 2, 1) # b, seq, prob
|
| 681 |
b = logits.shape[0]
|
| 682 |
|
| 683 |
logging.debug(f"permuted logits with shape: {logits.shape}")
|
| 684 |
|
| 685 |
sampled_z, selected_probs = sample_from_logits(
|
| 686 |
-
logits,
|
| 687 |
-
|
| 688 |
-
),
|
| 689 |
temperature=sampling_temperature,
|
| 690 |
-
typical_filtering=typical_filtering,
|
|
|
|
| 691 |
typical_min_tokens=typical_min_tokens,
|
| 692 |
-
top_k=None,
|
|
|
|
|
|
|
| 693 |
)
|
| 694 |
|
| 695 |
logging.debug(f"sampled z with shape: {sampled_z.shape}")
|
|
@@ -697,46 +699,38 @@ class VampNet(at.ml.BaseModel):
|
|
| 697 |
# flatten z_masked and mask, so we can deal with the sampling logic
|
| 698 |
# we'll unflatten them at the end of the loop for the next forward pass
|
| 699 |
# remove conditioning codebooks, we'll add them back at the end
|
| 700 |
-
z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
|
| 701 |
|
| 702 |
mask = (z_masked == self.mask_token).int()
|
| 703 |
-
|
| 704 |
# update the mask, remove conditioning codebooks from the mask
|
| 705 |
logging.debug(f"updated mask with shape: {mask.shape}")
|
| 706 |
# add z back into sampled z where the mask was false
|
| 707 |
-
sampled_z = torch.where(
|
| 708 |
-
mask.bool(), sampled_z, z_masked
|
| 709 |
-
)
|
| 710 |
logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
|
| 711 |
|
| 712 |
# ignore any tokens that weren't masked
|
| 713 |
-
selected_probs = torch.where(
|
| 714 |
-
mask.bool(), selected_probs, torch.inf
|
| 715 |
-
)
|
| 716 |
|
| 717 |
# get the num tokens to mask, according to the schedule
|
| 718 |
-
num_to_mask =
|
|
|
|
|
|
|
| 719 |
logging.debug(f"num to mask: {num_to_mask}")
|
| 720 |
|
| 721 |
if i != (sampling_steps - 1):
|
| 722 |
num_to_mask = torch.maximum(
|
| 723 |
torch.tensor(1),
|
| 724 |
-
torch.minimum(
|
| 725 |
-
mask.sum(dim=-1, keepdim=True) - 1,
|
| 726 |
-
num_to_mask
|
| 727 |
-
)
|
| 728 |
)
|
| 729 |
|
| 730 |
-
|
| 731 |
# get our new mask
|
| 732 |
mask = mask_by_random_topk(
|
| 733 |
-
num_to_mask, selected_probs, mask_temperature * (1-r)
|
| 734 |
-
)
|
| 735 |
|
| 736 |
# update the mask
|
| 737 |
-
z_masked = torch.where(
|
| 738 |
-
mask.bool(), self.mask_token, sampled_z
|
| 739 |
-
)
|
| 740 |
logging.debug(f"updated z_masked with shape: {z_masked.shape}")
|
| 741 |
|
| 742 |
z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
|
|
@@ -745,35 +739,37 @@ class VampNet(at.ml.BaseModel):
|
|
| 745 |
|
| 746 |
# add conditioning codebooks back to z_masked
|
| 747 |
z_masked = torch.cat(
|
| 748 |
-
(z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
|
|
|
|
|
|
|
|
|
|
| 749 |
)
|
| 750 |
-
logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")
|
| 751 |
-
|
| 752 |
|
| 753 |
# add conditioning codebooks back to sampled_z
|
| 754 |
sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
|
| 755 |
sampled_z = torch.cat(
|
| 756 |
-
(z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
|
| 757 |
)
|
| 758 |
|
| 759 |
-
logging.debug(
|
| 760 |
|
| 761 |
if return_signal:
|
| 762 |
return self.to_signal(sampled_z, codec)
|
| 763 |
else:
|
| 764 |
return sampled_z
|
| 765 |
|
|
|
|
| 766 |
def sample_from_logits(
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
"""Convenience function to sample from a categorial distribution with input as
|
| 778 |
unnormalized logits.
|
| 779 |
|
|
@@ -801,9 +797,8 @@ def sample_from_logits(
|
|
| 801 |
shp = logits.shape[:-1]
|
| 802 |
|
| 803 |
if typical_filtering:
|
| 804 |
-
typical_filter(
|
| 805 |
-
|
| 806 |
-
typical_min_tokens=typical_min_tokens
|
| 807 |
)
|
| 808 |
|
| 809 |
# Apply top_k sampling
|
|
@@ -846,21 +841,20 @@ def sample_from_logits(
|
|
| 846 |
return token, token_probs
|
| 847 |
else:
|
| 848 |
return token
|
| 849 |
-
|
| 850 |
|
| 851 |
|
| 852 |
def mask_by_random_topk(
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
"""
|
| 858 |
Args:
|
| 859 |
num_to_mask (int): number of tokens to mask
|
| 860 |
probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
|
| 861 |
temperature (float, optional): temperature. Defaults to 1.0.
|
| 862 |
"""
|
| 863 |
-
logging.debug(
|
| 864 |
logging.debug(f"num to mask: {num_to_mask}")
|
| 865 |
logging.debug(f"probs shape: {probs.shape}")
|
| 866 |
logging.debug(f"temperature: {temperature}")
|
|
@@ -875,9 +869,7 @@ def mask_by_random_topk(
|
|
| 875 |
logging.debug(f"sorted idx shape: {sorted_idx.shape}")
|
| 876 |
|
| 877 |
# get the cut off threshold, given the mask length
|
| 878 |
-
cut_off = torch.take_along_dim(
|
| 879 |
-
sorted_confidence, num_to_mask, axis=-1
|
| 880 |
-
)
|
| 881 |
logging.debug(f"cut off shape: {cut_off.shape}")
|
| 882 |
|
| 883 |
# mask out the tokens
|
|
@@ -886,10 +878,12 @@ def mask_by_random_topk(
|
|
| 886 |
|
| 887 |
return mask
|
| 888 |
|
|
|
|
| 889 |
def typical_filter(
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
|
|
|
| 893 |
nb, nt, _ = logits.shape
|
| 894 |
x_flat = rearrange(logits, "b t l -> (b t ) l")
|
| 895 |
x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
|
|
@@ -898,9 +892,7 @@ def typical_filter(
|
|
| 898 |
|
| 899 |
c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
|
| 900 |
c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
|
| 901 |
-
x_flat_cumsum = (
|
| 902 |
-
x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
|
| 903 |
-
)
|
| 904 |
|
| 905 |
last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
|
| 906 |
sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
|
|
@@ -933,7 +925,7 @@ if __name__ == "__main__":
|
|
| 933 |
).to(device)
|
| 934 |
|
| 935 |
r = torch.zeros(batch_size).to(device)
|
| 936 |
-
|
| 937 |
z_mask_latent = torch.rand(
|
| 938 |
batch_size, model.latent_dim * model.n_codebooks, seq_len
|
| 939 |
).to(device)
|
|
@@ -942,12 +934,10 @@ if __name__ == "__main__":
|
|
| 942 |
pred = z_hat.argmax(dim=1)
|
| 943 |
pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
|
| 944 |
|
| 945 |
-
print(f"model has {num_params(model)/1e6:<.3f}M parameters")
|
| 946 |
print(f"prediction has shape {pred.shape}")
|
| 947 |
breakpoint()
|
| 948 |
|
| 949 |
args = argbind.parse_args()
|
| 950 |
with argbind.scope(args):
|
| 951 |
try_model()
|
| 952 |
-
|
| 953 |
-
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
import math
|
| 3 |
+
from typing import Optional
|
| 4 |
|
| 5 |
+
import audiotools as at
|
| 6 |
+
import loralib as lora
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
| 9 |
import torch.nn as nn
|
| 10 |
import torch.nn.functional as F
|
| 11 |
from einops import rearrange
|
|
|
|
|
|
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
from ..mask import _gamma
|
| 14 |
+
from ..util import codebook_flatten, codebook_unflatten, scalar_to_batch_tensor
|
| 15 |
+
from .activations import get_activation
|
| 16 |
+
from .layers import CodebookEmbedding, FiLM, SequentialWithFiLM, WNConv1d
|
| 17 |
|
| 18 |
LORA_R = 8
|
| 19 |
|
|
|
|
| 276 |
|
| 277 |
if flash_attn:
|
| 278 |
from flash_attn.flash_attention import FlashMHA
|
| 279 |
+
|
| 280 |
self.self_attn = FlashMHA(
|
| 281 |
embed_dim=d_model,
|
| 282 |
num_heads=n_heads,
|
|
|
|
| 408 |
def subsequent_mask(self, size):
|
| 409 |
return torch.ones(1, size, size).tril().bool()
|
| 410 |
|
| 411 |
+
def forward(
|
| 412 |
+
self,
|
| 413 |
+
x,
|
| 414 |
+
x_mask,
|
| 415 |
+
cond=None,
|
| 416 |
+
src=None,
|
| 417 |
+
src_mask=None,
|
| 418 |
+
return_activations: bool = False,
|
| 419 |
+
):
|
| 420 |
"""Computes a full transformer stack
|
| 421 |
Parameters
|
| 422 |
----------
|
|
|
|
| 458 |
if return_activations:
|
| 459 |
activations.append(x.detach())
|
| 460 |
|
|
|
|
| 461 |
out = self.norm(x) if self.norm is not None else x
|
| 462 |
if return_activations:
|
| 463 |
return out, torch.stack(activations)
|
|
|
|
| 478 |
vocab_size: int = 1024,
|
| 479 |
flash_attn: bool = True,
|
| 480 |
noise_mode: str = "mask",
|
| 481 |
+
dropout: float = 0.1,
|
| 482 |
):
|
| 483 |
super().__init__()
|
| 484 |
+
assert r_cond_dim == 0, (
|
| 485 |
+
f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
|
| 486 |
+
)
|
| 487 |
self.n_heads = n_heads
|
| 488 |
self.n_layers = n_layers
|
| 489 |
self.r_cond_dim = r_cond_dim
|
|
|
|
| 535 |
x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
|
| 536 |
|
| 537 |
x = rearrange(x, "b d n -> b n d")
|
| 538 |
+
out = self.transformer(
|
| 539 |
+
x=x, x_mask=x_mask, return_activations=return_activations
|
| 540 |
+
)
|
| 541 |
if return_activations:
|
| 542 |
out, activations = out
|
| 543 |
|
| 544 |
out = rearrange(out, "b n d -> b d n")
|
| 545 |
|
| 546 |
+
out = self.classifier(out, None) # no cond here!
|
| 547 |
|
| 548 |
out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
|
| 549 |
|
|
|
|
| 551 |
return out, activations
|
| 552 |
else:
|
| 553 |
return out
|
| 554 |
+
|
| 555 |
def r_embed(self, r, max_positions=10000):
|
| 556 |
if self.r_cond_dim > 0:
|
| 557 |
dtype = r.dtype
|
|
|
|
| 571 |
return emb.to(dtype)
|
| 572 |
else:
|
| 573 |
return r
|
| 574 |
+
|
| 575 |
@torch.no_grad()
|
| 576 |
def to_signal(self, z, codec):
|
| 577 |
"""
|
| 578 |
+
convert a sequence of latents to a signal.
|
| 579 |
"""
|
| 580 |
assert z.ndim == 3
|
| 581 |
|
|
|
|
| 595 |
|
| 596 |
return signal
|
| 597 |
|
|
|
|
| 598 |
@torch.no_grad()
|
| 599 |
def generate(
|
| 600 |
self,
|
|
|
|
| 610 |
typical_min_tokens=1,
|
| 611 |
top_p=None,
|
| 612 |
return_signal=True,
|
| 613 |
+
seed: int = None,
|
| 614 |
sample_cutoff: float = 1.0,
|
| 615 |
):
|
| 616 |
if seed is not None:
|
| 617 |
at.util.seed(seed)
|
| 618 |
logging.debug(f"beginning generation with {sampling_steps} steps")
|
| 619 |
|
| 620 |
+
#####################
|
|
|
|
|
|
|
| 621 |
# resolve initial z #
|
| 622 |
#####################
|
| 623 |
z = start_tokens
|
|
|
|
| 629 |
|
| 630 |
logging.debug(f"created z with shape {z.shape}")
|
| 631 |
|
|
|
|
| 632 |
#################
|
| 633 |
# resolve mask #
|
| 634 |
#################
|
|
|
|
| 639 |
if mask.ndim == 2:
|
| 640 |
mask = mask[:, None, :].repeat(1, z.shape[1], 1)
|
| 641 |
# init_mask = mask.clone()
|
|
|
|
|
|
|
| 642 |
|
| 643 |
+
logging.debug(f"created mask with shape {mask.shape}")
|
| 644 |
|
| 645 |
###########
|
| 646 |
# set up #
|
|
|
|
| 665 |
logging.debug(f"step {i} of {sampling_steps}")
|
| 666 |
|
| 667 |
# our current schedule step
|
| 668 |
+
r = scalar_to_batch_tensor((i + 1) / sampling_steps, z.shape[0]).to(
|
| 669 |
+
z.device
|
| 670 |
+
)
|
|
|
|
| 671 |
logging.debug(f"r: {r}")
|
| 672 |
|
| 673 |
# get latents
|
| 674 |
latents = self.embedding.from_codes(z_masked, codec)
|
| 675 |
logging.debug(f"computed latents with shape: {latents.shape}")
|
| 676 |
|
|
|
|
| 677 |
# infer from latents
|
| 678 |
# NOTE: this collapses the codebook dimension into the sequence dimension
|
| 679 |
+
logits = self.forward(latents) # b, prob, seq
|
| 680 |
logits = logits.permute(0, 2, 1) # b, seq, prob
|
| 681 |
b = logits.shape[0]
|
| 682 |
|
| 683 |
logging.debug(f"permuted logits with shape: {logits.shape}")
|
| 684 |
|
| 685 |
sampled_z, selected_probs = sample_from_logits(
|
| 686 |
+
logits,
|
| 687 |
+
sample=((i / sampling_steps) <= sample_cutoff),
|
|
|
|
| 688 |
temperature=sampling_temperature,
|
| 689 |
+
typical_filtering=typical_filtering,
|
| 690 |
+
typical_mass=typical_mass,
|
| 691 |
typical_min_tokens=typical_min_tokens,
|
| 692 |
+
top_k=None,
|
| 693 |
+
top_p=top_p,
|
| 694 |
+
return_probs=True,
|
| 695 |
)
|
| 696 |
|
| 697 |
logging.debug(f"sampled z with shape: {sampled_z.shape}")
|
|
|
|
| 699 |
# flatten z_masked and mask, so we can deal with the sampling logic
|
| 700 |
# we'll unflatten them at the end of the loop for the next forward pass
|
| 701 |
# remove conditioning codebooks, we'll add them back at the end
|
| 702 |
+
z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks :, :])
|
| 703 |
|
| 704 |
mask = (z_masked == self.mask_token).int()
|
| 705 |
+
|
| 706 |
# update the mask, remove conditioning codebooks from the mask
|
| 707 |
logging.debug(f"updated mask with shape: {mask.shape}")
|
| 708 |
# add z back into sampled z where the mask was false
|
| 709 |
+
sampled_z = torch.where(mask.bool(), sampled_z, z_masked)
|
|
|
|
|
|
|
| 710 |
logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
|
| 711 |
|
| 712 |
# ignore any tokens that weren't masked
|
| 713 |
+
selected_probs = torch.where(mask.bool(), selected_probs, torch.inf)
|
|
|
|
|
|
|
| 714 |
|
| 715 |
# get the num tokens to mask, according to the schedule
|
| 716 |
+
num_to_mask = (
|
| 717 |
+
torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
|
| 718 |
+
)
|
| 719 |
logging.debug(f"num to mask: {num_to_mask}")
|
| 720 |
|
| 721 |
if i != (sampling_steps - 1):
|
| 722 |
num_to_mask = torch.maximum(
|
| 723 |
torch.tensor(1),
|
| 724 |
+
torch.minimum(mask.sum(dim=-1, keepdim=True) - 1, num_to_mask),
|
|
|
|
|
|
|
|
|
|
| 725 |
)
|
| 726 |
|
|
|
|
| 727 |
# get our new mask
|
| 728 |
mask = mask_by_random_topk(
|
| 729 |
+
num_to_mask, selected_probs, mask_temperature * (1 - r)
|
| 730 |
+
)
|
| 731 |
|
| 732 |
# update the mask
|
| 733 |
+
z_masked = torch.where(mask.bool(), self.mask_token, sampled_z)
|
|
|
|
|
|
|
| 734 |
logging.debug(f"updated z_masked with shape: {z_masked.shape}")
|
| 735 |
|
| 736 |
z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
|
|
|
|
| 739 |
|
| 740 |
# add conditioning codebooks back to z_masked
|
| 741 |
z_masked = torch.cat(
|
| 742 |
+
(z[:, : self.n_conditioning_codebooks, :], z_masked), dim=1
|
| 743 |
+
)
|
| 744 |
+
logging.debug(
|
| 745 |
+
f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}"
|
| 746 |
)
|
|
|
|
|
|
|
| 747 |
|
| 748 |
# add conditioning codebooks back to sampled_z
|
| 749 |
sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
|
| 750 |
sampled_z = torch.cat(
|
| 751 |
+
(z[:, : self.n_conditioning_codebooks, :], sampled_z), dim=1
|
| 752 |
)
|
| 753 |
|
| 754 |
+
logging.debug("finished sampling")
|
| 755 |
|
| 756 |
if return_signal:
|
| 757 |
return self.to_signal(sampled_z, codec)
|
| 758 |
else:
|
| 759 |
return sampled_z
|
| 760 |
|
| 761 |
+
|
| 762 |
def sample_from_logits(
|
| 763 |
+
logits,
|
| 764 |
+
sample: bool = True,
|
| 765 |
+
temperature: float = 1.0,
|
| 766 |
+
top_k: int = None,
|
| 767 |
+
top_p: float = None,
|
| 768 |
+
typical_filtering: bool = False,
|
| 769 |
+
typical_mass: float = 0.2,
|
| 770 |
+
typical_min_tokens: int = 1,
|
| 771 |
+
return_probs: bool = False,
|
| 772 |
+
):
|
| 773 |
"""Convenience function to sample from a categorial distribution with input as
|
| 774 |
unnormalized logits.
|
| 775 |
|
|
|
|
| 797 |
shp = logits.shape[:-1]
|
| 798 |
|
| 799 |
if typical_filtering:
|
| 800 |
+
typical_filter(
|
| 801 |
+
logits, typical_mass=typical_mass, typical_min_tokens=typical_min_tokens
|
|
|
|
| 802 |
)
|
| 803 |
|
| 804 |
# Apply top_k sampling
|
|
|
|
| 841 |
return token, token_probs
|
| 842 |
else:
|
| 843 |
return token
|
|
|
|
| 844 |
|
| 845 |
|
| 846 |
def mask_by_random_topk(
|
| 847 |
+
num_to_mask: int,
|
| 848 |
+
probs: torch.Tensor,
|
| 849 |
+
temperature: float = 1.0,
|
| 850 |
+
):
|
| 851 |
"""
|
| 852 |
Args:
|
| 853 |
num_to_mask (int): number of tokens to mask
|
| 854 |
probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
|
| 855 |
temperature (float, optional): temperature. Defaults to 1.0.
|
| 856 |
"""
|
| 857 |
+
logging.debug("masking by random topk")
|
| 858 |
logging.debug(f"num to mask: {num_to_mask}")
|
| 859 |
logging.debug(f"probs shape: {probs.shape}")
|
| 860 |
logging.debug(f"temperature: {temperature}")
|
|
|
|
| 869 |
logging.debug(f"sorted idx shape: {sorted_idx.shape}")
|
| 870 |
|
| 871 |
# get the cut off threshold, given the mask length
|
| 872 |
+
cut_off = torch.take_along_dim(sorted_confidence, num_to_mask, axis=-1)
|
|
|
|
|
|
|
| 873 |
logging.debug(f"cut off shape: {cut_off.shape}")
|
| 874 |
|
| 875 |
# mask out the tokens
|
|
|
|
| 878 |
|
| 879 |
return mask
|
| 880 |
|
| 881 |
+
|
| 882 |
def typical_filter(
|
| 883 |
+
logits,
|
| 884 |
+
typical_mass: float = 0.95,
|
| 885 |
+
typical_min_tokens: int = 1,
|
| 886 |
+
):
|
| 887 |
nb, nt, _ = logits.shape
|
| 888 |
x_flat = rearrange(logits, "b t l -> (b t ) l")
|
| 889 |
x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
|
|
|
|
| 892 |
|
| 893 |
c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
|
| 894 |
c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
|
| 895 |
+
x_flat_cumsum = x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
|
|
|
|
|
|
|
| 896 |
|
| 897 |
last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
|
| 898 |
sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
|
|
|
|
| 925 |
).to(device)
|
| 926 |
|
| 927 |
r = torch.zeros(batch_size).to(device)
|
| 928 |
+
|
| 929 |
z_mask_latent = torch.rand(
|
| 930 |
batch_size, model.latent_dim * model.n_codebooks, seq_len
|
| 931 |
).to(device)
|
|
|
|
| 934 |
pred = z_hat.argmax(dim=1)
|
| 935 |
pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
|
| 936 |
|
| 937 |
+
print(f"model has {num_params(model) / 1e6:<.3f}M parameters")
|
| 938 |
print(f"prediction has shape {pred.shape}")
|
| 939 |
breakpoint()
|
| 940 |
|
| 941 |
args = argbind.parse_args()
|
| 942 |
with argbind.scope(args):
|
| 943 |
try_model()
|
|
|
|
|
|
vampnet/vampnet/scheduler.py
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
from typing import List
|
| 3 |
-
|
| 4 |
import torch
|
| 5 |
|
|
|
|
| 6 |
class NoamScheduler:
|
| 7 |
"""OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf
|
| 8 |
Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html
|
|
@@ -44,4 +42,3 @@ class NoamScheduler:
|
|
| 44 |
|
| 45 |
for p in self.optimizer.param_groups:
|
| 46 |
p["lr"] = self.lr
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
+
|
| 4 |
class NoamScheduler:
|
| 5 |
"""OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf
|
| 6 |
Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html
|
|
|
|
| 42 |
|
| 43 |
for p in self.optimizer.param_groups:
|
| 44 |
p["lr"] = self.lr
|
|
|
vampnet/vampnet/util.py
CHANGED
|
@@ -1,43 +1,36 @@
|
|
| 1 |
-
import tqdm
|
| 2 |
-
|
| 3 |
import torch
|
|
|
|
| 4 |
from einops import rearrange
|
| 5 |
|
|
|
|
| 6 |
def scalar_to_batch_tensor(x, batch_size):
|
| 7 |
return torch.tensor(x).repeat(batch_size)
|
| 8 |
|
| 9 |
|
| 10 |
-
def parallelize(
|
| 11 |
-
fn,
|
| 12 |
-
*iterables,
|
| 13 |
-
parallel: str = "thread_map",
|
| 14 |
-
**kwargs
|
| 15 |
-
):
|
| 16 |
if parallel == "thread_map":
|
| 17 |
from tqdm.contrib.concurrent import thread_map
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
*iterables,
|
| 21 |
-
**kwargs
|
| 22 |
-
)
|
| 23 |
elif parallel == "process_map":
|
| 24 |
from tqdm.contrib.concurrent import process_map
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
*iterables,
|
| 28 |
-
**kwargs
|
| 29 |
-
)
|
| 30 |
elif parallel == "single":
|
| 31 |
return [fn(x) for x in tqdm.tqdm(*iterables)]
|
| 32 |
else:
|
| 33 |
-
raise ValueError(
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
| 35 |
def codebook_flatten(tokens: torch.Tensor):
|
| 36 |
-
"""
|
| 37 |
flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)
|
| 38 |
"""
|
| 39 |
return rearrange(tokens, "b c t -> b (t c)")
|
| 40 |
|
|
|
|
| 41 |
def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
|
| 42 |
"""
|
| 43 |
unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
+
import tqdm
|
| 3 |
from einops import rearrange
|
| 4 |
|
| 5 |
+
|
| 6 |
def scalar_to_batch_tensor(x, batch_size):
|
| 7 |
return torch.tensor(x).repeat(batch_size)
|
| 8 |
|
| 9 |
|
| 10 |
+
def parallelize(fn, *iterables, parallel: str = "thread_map", **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
if parallel == "thread_map":
|
| 12 |
from tqdm.contrib.concurrent import thread_map
|
| 13 |
+
|
| 14 |
+
return thread_map(fn, *iterables, **kwargs)
|
|
|
|
|
|
|
|
|
|
| 15 |
elif parallel == "process_map":
|
| 16 |
from tqdm.contrib.concurrent import process_map
|
| 17 |
+
|
| 18 |
+
return process_map(fn, *iterables, **kwargs)
|
|
|
|
|
|
|
|
|
|
| 19 |
elif parallel == "single":
|
| 20 |
return [fn(x) for x in tqdm.tqdm(*iterables)]
|
| 21 |
else:
|
| 22 |
+
raise ValueError(
|
| 23 |
+
f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
def codebook_flatten(tokens: torch.Tensor):
|
| 28 |
+
"""
|
| 29 |
flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)
|
| 30 |
"""
|
| 31 |
return rearrange(tokens, "b c t -> b (t c)")
|
| 32 |
|
| 33 |
+
|
| 34 |
def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
|
| 35 |
"""
|
| 36 |
unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
|
wham.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: wham
|
| 3 |
+
Version: 0.0.1
|
| 4 |
+
Summary: Towards A Translative Model of Sperm Whale Vocalization
|
| 5 |
+
Author: Project CETI
|
| 6 |
+
License: MIT
|
| 7 |
+
Requires-Python: <3.11,>=3.10
|
| 8 |
+
Description-Content-Type: text/markdown
|
| 9 |
+
License-File: LICENSE
|
| 10 |
+
Requires-Dist: torch
|
| 11 |
+
Requires-Dist: gradio
|
| 12 |
+
Requires-Dist: argbind>=0.3.2
|
| 13 |
+
Requires-Dist: numpy<1.24
|
| 14 |
+
Requires-Dist: pydantic<3,>=2.0
|
| 15 |
+
Requires-Dist: huggingface_hub
|
| 16 |
+
Requires-Dist: loralib
|
| 17 |
+
Requires-Dist: torch_pitch_shift
|
| 18 |
+
Requires-Dist: soundfile
|
| 19 |
+
Requires-Dist: pydub
|
| 20 |
+
Requires-Dist: tqdm
|
| 21 |
+
Requires-Dist: Cython
|
| 22 |
+
Requires-Dist: pandas
|
| 23 |
+
Requires-Dist: pathlib
|
| 24 |
+
Requires-Dist: ffmpeg-python
|
| 25 |
+
Requires-Dist: scikit-learn
|
| 26 |
+
Requires-Dist: wandb
|
| 27 |
+
Requires-Dist: gdown
|
| 28 |
+
Requires-Dist: transformers
|
| 29 |
+
Requires-Dist: fadtk
|
| 30 |
+
Requires-Dist: urllib3>=2.0.2
|
| 31 |
+
Requires-Dist: plotly
|
| 32 |
+
Requires-Dist: pyharp
|
| 33 |
+
Requires-Dist: ruff
|
| 34 |
+
Requires-Dist: wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat.git
|
| 35 |
+
Requires-Dist: lac @ git+https://github.com/hugofloresgarcia/lac.git
|
| 36 |
+
Requires-Dist: descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
|
| 37 |
+
Dynamic: license-file
|
| 38 |
+
|
| 39 |
+
---
|
| 40 |
+
title: WhAM
|
| 41 |
+
emoji: 🐋
|
| 42 |
+
colorFrom: blue
|
| 43 |
+
colorTo: indigo
|
| 44 |
+
sdk: docker
|
| 45 |
+
pinned: false
|
| 46 |
+
hardware: a10g-small
|
| 47 |
+
---
|
| 48 |
+
|
| 49 |
+
# WhAM: a Whale Acoustics Model
|
| 50 |
+
[](https://arxiv.org/abs/2512.02206)
|
| 51 |
+
[](https://doi.org/10.5281/zenodo.17633708)
|
| 52 |
+
[](https://huggingface.co/datasets/orrp/DSWP)
|
| 53 |
+

|
| 54 |
+
WhAM is a transformer-based audio-to-audio model designed to synthesize and analyze sperm whale codas. Based on [VampNet](https://github.com/hugofloresgarcia/vampnet), WhAM uses masked acoustic token modeling to capture temporal and spectral features of whale communication. WhAM generates codas from a given audio context, enabling three core capabilities:
|
| 55 |
+
|
| 56 |
+
- Acoustic Translation: The ability to style-transfer arbitrary audio prompts (e.g., human speech, noise) into the acoustic texture of sperm whale codas.
|
| 57 |
+
|
| 58 |
+
- Synthesizing novel "pseudocodas".
|
| 59 |
+
|
| 60 |
+
- Providing audio embeddings for downstream tasks such as social unit and spectral feature ("vowel") classification.
|
| 61 |
+
|
| 62 |
+
See our [NeurIPS 2025](https://openreview.net/pdf?id=IL1wvzOgqD) publication for more details.
|
| 63 |
+
|
| 64 |
+
## Installation
|
| 65 |
+
|
| 66 |
+
1. **Clone the repository:**
|
| 67 |
+
```bash
|
| 68 |
+
git clone https://github.com/Project-CETI/wham.git
|
| 69 |
+
cd wham
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
2. **Set up the environment:**
|
| 73 |
+
```bash
|
| 74 |
+
conda create -n wham python=3.9
|
| 75 |
+
conda activate wham
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
3. **Install dependencies:**
|
| 79 |
+
```bash
|
| 80 |
+
# Install the wham package
|
| 81 |
+
pip install -e .
|
| 82 |
+
|
| 83 |
+
# Install VampNet
|
| 84 |
+
pip install -e ./vampnet
|
| 85 |
+
|
| 86 |
+
# Install madmom
|
| 87 |
+
pip install --no-build-isolation madmom
|
| 88 |
+
|
| 89 |
+
# Install ffmpeg
|
| 90 |
+
conda install -c conda-forge ffmpeg
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
4. **Download model weights:**
|
| 94 |
+
Download the [weights](https://zenodo.org/records/17633708) and extract to `vampnet/models/`.
|
| 95 |
+
|
| 96 |
+
## Generation
|
| 97 |
+
|
| 98 |
+
To run WhAM locally and prompt it in your browser:
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
python vampnet/app.py --args.load conf/interface.yml --Interface.device cuda
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
This will provide you with a Gradio link to test WhAM on inputs of your choice.
|
| 105 |
+
|
| 106 |
+
## Training Data
|
| 107 |
+
|
| 108 |
+

|
| 109 |
+
|
| 110 |
+
You only need to follow these to fine-tune your own version of WhAM. First, obtain the original VampNet weights by following the instructions in the . Download
|
| 111 |
+
c2f.pth and codec.pth and replace the weights you previously downloaded in `vampnet/models`.
|
| 112 |
+
|
| 113 |
+
Second, obtain data:
|
| 114 |
+
|
| 115 |
+
1. **Domain adaptation data:**
|
| 116 |
+
|
| 117 |
+
- Download audio samples from the [WMMS 'Best Of' Cut](https://whoicf2.whoi.edu/science/B/whalesounds/index.cfm). Save them under `vampnet/training_data/domain_adaptation`.
|
| 118 |
+
|
| 119 |
+
- Download audio samples from the [BirdSet Dataset](https://huggingface.co/datasets/DBD-research-group/BirdSet). Save these under the same directory
|
| 120 |
+
|
| 121 |
+
- Finally, download all samples from the [AudioSet Dataset](https://research.google.com/audioset/ontology/index.html) with the label `Animal` and once again save these into the directory
|
| 122 |
+
|
| 123 |
+
3. **Species-specific finetuning:** Finetuning can be performed on the openly available **[Dominica Sperm Whale Project (DSWP)](https://huggingface.co/datasets/orrp/DSWP)** dataset, available on Hugging Face.
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
With data in hand, navigate into `vampnet` and perform Domain Adaptation:
|
| 127 |
+
```bash
|
| 128 |
+
python vampnet/scripts/exp/fine_tune.py "training_data/domain_adaptation" domain_adapted && python vampnet/scripts/exp/train.py --args.load conf/generated/domain_adapted/coarse.yml && python vampnet/scripts/exp/train.py --args.load conf/generated/domain_adapted/c2f.yml
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
Then fine-tune the domain-adapted model. Create the config file with the command:
|
| 132 |
+
|
| 133 |
+
```bash
|
| 134 |
+
python vampnet/scripts/exp/fine_tune.py "training_data/species_specific_finetuning" fine-tuned
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
To select which weights you want to use as a checkpoint, change `fine_tune_checkpoint` in `conf/generated/fine-tuned/[c2f/coarse].yml` to `./runs/domain_adaptation/[coarse/c2f]/[checkpoint]/vampnets/weights.pth`. `[checkpoint]` can be `latest` in order to use the last saved checkpoint from the previous run, though it is recommended to manually verify the quality of generations over various checkpoints as overtraining can often cause degradation in audio quality, especially with smaller datasets. After making that change, run the command:
|
| 138 |
+
|
| 139 |
+
```bash
|
| 140 |
+
python vampnet/scripts/exp/train.py --args.load conf/generated/fine-tuned/coarse.yml && python vampnet/scripts/exp/train.py --args.load conf/generated/fine-tuned/c2f.yml
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
After following these steps, you should be able to generate audio via the browser by running:
|
| 144 |
+
```bash
|
| 145 |
+
python app.py --args.load vampnet/conf/generated/fine-tuned/interface.yml
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
**Note**: The coarse and fine weights can be trained separately if compute allows. In this case, you would call the two scripts:
|
| 149 |
+
|
| 150 |
+
```bash
|
| 151 |
+
python vampnet/scripts/exp/train.py --args.load conf/generated/[fine-tuned/domain_adaptated]/coarse.yml
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
```bash
|
| 155 |
+
python vampnet/scripts/exp/train.py --args.load conf/generated/[fine-tuned/domain_adaptated]/c2f.yml
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
After both are finished running, ensure that both resulting weights are copied into the same copy of WhAM.
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
## Testing Data
|
| 163 |
+
|
| 164 |
+
1. **Marine Mammel Data:**
|
| 165 |
+
Download audio samples from the [WMMS 'Best Of' Cut](https://whoicf2.whoi.edu/science/B/whalesounds/index.cfm). Save them under `data/testing_data/marine_mammals/data/[SPECIES_NAME]`.
|
| 166 |
+
* `[SPECIES_NAME]` must match the species names found in `wham/generation/prompt_configs.py`.
|
| 167 |
+
|
| 168 |
+
2. **Sperm Whale Codas:**
|
| 169 |
+
To evaluate on sperm whale codas, you can use the openly available [DSWP](https://huggingface.co/datasets/orrp/DSWP) dataset.
|
| 170 |
+
|
| 171 |
+
3. Generate artifical beeps for experiments. `data/generate_beeps.sh`
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
## Reproducing Paper Results
|
| 175 |
+
Note: Access to the DSWP+CETI annotated is required to reproduce all results; as of time of publication, only part of this data is publicly available. Still, we include the following code as it may be useful for researchers who may benefit from our evaluation pipeline.
|
| 176 |
+
|
| 177 |
+
### 1. Downstream Classification Tasks
|
| 178 |
+
To reproduce **Table 1** (Classification Accuracies) and **Figure 7** (Ablation Study):
|
| 179 |
+
|
| 180 |
+
**Table 1 Results:**
|
| 181 |
+
```bash
|
| 182 |
+
cd wham/embedding
|
| 183 |
+
./downstream_tasks.sh
|
| 184 |
+
```
|
| 185 |
+
* Runs all downstream classification tasks.
|
| 186 |
+
* **Baselines:** Run once.
|
| 187 |
+
* **Models (AVES, VampNet):** Run over 3 random seeds; reports mean and standard deviation.
|
| 188 |
+
|
| 189 |
+
**Figure 7 Results (Ablation):**
|
| 190 |
+
```bash
|
| 191 |
+
cd wham/embedding
|
| 192 |
+
./downstream_ablation.sh
|
| 193 |
+
```
|
| 194 |
+
* Outputs accuracy scores for ablation variants (averaged across 3 seeds with error bars).
|
| 195 |
+
|
| 196 |
+
### 2. Generative Metrics
|
| 197 |
+
|
| 198 |
+
**Figure 12: Frechet Audio Distance (FAD) Scores**
|
| 199 |
+
Calculate the distance between WhAM's generated results and real codas:
|
| 200 |
+
```bash
|
| 201 |
+
# Calculate for all species
|
| 202 |
+
bash wham/generation/eval/calculate_FAD.sh
|
| 203 |
+
|
| 204 |
+
# Calculate for a single species
|
| 205 |
+
bash wham/generation/eval/calculate_FAD.sh [species_name]
|
| 206 |
+
```
|
| 207 |
+
* *Runtime:* ~3 hours on an NVIDIA A10 GPU.
|
| 208 |
+
|
| 209 |
+
**Figure 3: FAD with Custom/BirdNET Embeddings**
|
| 210 |
+
To compare against other embeddings:
|
| 211 |
+
1. Convert your `.wav` files to `.npy` embeddings.
|
| 212 |
+
2. Place raw coda embeddings in: `data/testing_data/coda_embeddings`
|
| 213 |
+
3. Place comparison embeddings in subfolders within: `data/testing_data/comparison_embeddings`
|
| 214 |
+
4. Run:
|
| 215 |
+
```bash
|
| 216 |
+
python wham/generation/eval/calculate_custom_fad.py
|
| 217 |
+
```
|
| 218 |
+
*For BirdNET embeddings, refer to the [official repo](https://github.com/BirdNET-Team/BirdNET-Analyzer).*
|
| 219 |
+
|
| 220 |
+
**Table 2: Embedding Type Ablation**
|
| 221 |
+
Calculate distances between raw codas, denoised versions, and noise profiles:
|
| 222 |
+
```bash
|
| 223 |
+
bash wham/generation/eval/FAD_ablation.sh
|
| 224 |
+
```
|
| 225 |
+
* *Prerequisites:* Ensure `data/testing_data/ablation/noise` and `data/testing_data/ablation/denoised` are populated.
|
| 226 |
+
* *Runtime:* ~1.5 hours on an NVIDIA A10 GPU.
|
| 227 |
+
|
| 228 |
+
**Figure 13: Tokenizer Reconstruction**
|
| 229 |
+
Test the mean squared reconstruction error:
|
| 230 |
+
```bash
|
| 231 |
+
bash wham/generation/eval/evaluate_tokenizer.sh
|
| 232 |
+
```
|
| 233 |
+
|
| 234 |
+
---
|
| 235 |
+
|
| 236 |
+
## Citation
|
| 237 |
+
|
| 238 |
+
Please use the following citation if you use this code, model or data.
|
| 239 |
+
|
| 240 |
+
```bibtex
|
| 241 |
+
@inproceedings{wham2025,
|
| 242 |
+
title={Towards A Translative Model of Sperm Whale Vocalization},
|
| 243 |
+
author={Orr Paradise, Pranav Muralikrishnan, Liangyuan Chen, Hugo Flores Garcia, Bryan Pardo, Roee Diamant, David F. Gruber, Shane Gero, Shafi Goldwasser},
|
| 244 |
+
booktitle={Advances in Neural Information Processing Systems 39: Annual Conference
|
| 245 |
+
on Neural Information Processing Systems 2025, NeurIPS 2025, San Diego, CA, USA},
|
| 246 |
+
year={2025}
|
| 247 |
+
}
|
| 248 |
+
```
|
wham.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
README.md
|
| 3 |
+
pyproject.toml
|
| 4 |
+
vampnet/app.py
|
| 5 |
+
vampnet/setup.py
|
| 6 |
+
vampnet/scripts/exp/eval.py
|
| 7 |
+
vampnet/scripts/exp/experiment.py
|
| 8 |
+
vampnet/scripts/exp/fine_tune.py
|
| 9 |
+
vampnet/scripts/exp/train.py
|
| 10 |
+
vampnet/scripts/utils/plots.py
|
| 11 |
+
vampnet/scripts/utils/remove_quiet_files.py
|
| 12 |
+
vampnet/scripts/utils/split.py
|
| 13 |
+
vampnet/scripts/utils/split_long_audio_file.py
|
| 14 |
+
vampnet/scripts/utils/stage.py
|
| 15 |
+
vampnet/scripts/utils/visualize_embeddings.py
|
| 16 |
+
vampnet/scripts/utils/xeno-canto-dl.py
|
| 17 |
+
vampnet/scripts/utils/data/augment.py
|
| 18 |
+
vampnet/scripts/utils/data/maestro-reorg.py
|
| 19 |
+
vampnet/vampnet/__init__.py
|
| 20 |
+
vampnet/vampnet/beats.py
|
| 21 |
+
vampnet/vampnet/interface.py
|
| 22 |
+
vampnet/vampnet/mask.py
|
| 23 |
+
vampnet/vampnet/scheduler.py
|
| 24 |
+
vampnet/vampnet/util.py
|
| 25 |
+
vampnet/vampnet/modules/__init__.py
|
| 26 |
+
vampnet/vampnet/modules/activations.py
|
| 27 |
+
vampnet/vampnet/modules/layers.py
|
| 28 |
+
vampnet/vampnet/modules/transformer.py
|
| 29 |
+
wham.egg-info/PKG-INFO
|
| 30 |
+
wham.egg-info/SOURCES.txt
|
| 31 |
+
wham.egg-info/dependency_links.txt
|
| 32 |
+
wham.egg-info/requires.txt
|
| 33 |
+
wham.egg-info/top_level.txt
|
wham.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
wham.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
gradio
|
| 3 |
+
argbind>=0.3.2
|
| 4 |
+
numpy<1.24
|
| 5 |
+
pydantic<3,>=2.0
|
| 6 |
+
huggingface_hub
|
| 7 |
+
loralib
|
| 8 |
+
torch_pitch_shift
|
| 9 |
+
soundfile
|
| 10 |
+
pydub
|
| 11 |
+
tqdm
|
| 12 |
+
Cython
|
| 13 |
+
pandas
|
| 14 |
+
pathlib
|
| 15 |
+
ffmpeg-python
|
| 16 |
+
scikit-learn
|
| 17 |
+
wandb
|
| 18 |
+
gdown
|
| 19 |
+
transformers
|
| 20 |
+
fadtk
|
| 21 |
+
urllib3>=2.0.2
|
| 22 |
+
plotly
|
| 23 |
+
pyharp
|
| 24 |
+
ruff
|
| 25 |
+
wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat.git
|
| 26 |
+
lac @ git+https://github.com/hugofloresgarcia/lac.git
|
| 27 |
+
descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
|
wham.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
vampnet
|