Spaces:
Runtime error
Runtime error
c2f prompts
#4
by
hugggof
- opened
- .gitignore +3 -6
- README.md +2 -11
- app.py +42 -227
- conf/lora/lora.yml +3 -5
- conf/vampnet.yml +1 -1
- requirements.txt +2 -4
- scripts/exp/fine_tune.py +4 -3
- scripts/exp/train.py +15 -23
- scripts/utils/{data/augment.py → augment.py} +24 -38
- scripts/utils/gtzan_embeddings.py +0 -263
- scripts/utils/{data/maestro-reorg.py → maestro-reorg.py} +0 -0
- scripts/utils/remove_quiet_files.py +0 -29
- scripts/utils/split_long_audio_file.py +0 -34
- scripts/utils/xeno-canto-dl.py +0 -234
- setup.py +2 -3
- vampnet/interface.py +6 -5
- vampnet/mask.py +20 -38
- vampnet/modules/transformer.py +51 -138
.gitignore
CHANGED
|
@@ -175,14 +175,11 @@ lyrebird-audio-codec
|
|
| 175 |
samples-*/**
|
| 176 |
|
| 177 |
gradio-outputs/
|
| 178 |
-
models/
|
| 179 |
samples*/
|
| 180 |
models-all/
|
| 181 |
models.zip
|
|
|
|
|
|
|
|
|
|
| 182 |
.git-old
|
| 183 |
conf/generated/*
|
| 184 |
-
runs*/
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
gtzan.zip
|
| 188 |
-
.gtzan_emb_cache
|
|
|
|
| 175 |
samples-*/**
|
| 176 |
|
| 177 |
gradio-outputs/
|
|
|
|
| 178 |
samples*/
|
| 179 |
models-all/
|
| 180 |
models.zip
|
| 181 |
+
audiotools/
|
| 182 |
+
descript-audio-codec/
|
| 183 |
+
# *.pth
|
| 184 |
.git-old
|
| 185 |
conf/generated/*
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -7,7 +7,6 @@ sdk: gradio
|
|
| 7 |
sdk_version: 3.36.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
python_version: 3.9
|
| 11 |
---
|
| 12 |
|
| 13 |
# VampNet
|
|
@@ -19,15 +18,7 @@ you can try vampnet in a co-creative looper called unloop. see this link: https:
|
|
| 19 |
|
| 20 |
# Setting up
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
you'll need a Python 3.9 environment to run VampNet. This is due to a [known issue with madmom](https://github.com/hugofloresgarcia/vampnet/issues/15).
|
| 25 |
-
|
| 26 |
-
(for example, using conda)
|
| 27 |
-
```bash
|
| 28 |
-
conda create -n vampnet python=3.9
|
| 29 |
-
conda activate vampnet
|
| 30 |
-
```
|
| 31 |
|
| 32 |
|
| 33 |
install VampNet
|
|
@@ -100,7 +91,7 @@ python scripts/exp/train.py --args.load conf/<fine_tune_name>/c2f.yml
|
|
| 100 |
|
| 101 |
launch the interface:
|
| 102 |
```bash
|
| 103 |
-
python
|
| 104 |
```
|
| 105 |
|
| 106 |
|
|
|
|
| 7 |
sdk_version: 3.36.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
# VampNet
|
|
|
|
| 18 |
|
| 19 |
# Setting up
|
| 20 |
|
| 21 |
+
Requires Python 3.9 or later.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
install VampNet
|
|
|
|
| 91 |
|
| 92 |
launch the interface:
|
| 93 |
```bash
|
| 94 |
+
python demo.py --args.load conf/generated/<fine_tune_name>/interface.yml
|
| 95 |
```
|
| 96 |
|
| 97 |
|
app.py
CHANGED
|
@@ -1,12 +1,3 @@
|
|
| 1 |
-
# huggingface space exclusive
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
# print("installing pyharp")
|
| 5 |
-
# os.system('pip install "pyharp@git+https://github.com/audacitorch/pyharp.git"')
|
| 6 |
-
# print("installing madmom")
|
| 7 |
-
os.system('pip install cython')
|
| 8 |
-
os.system('pip install madmom')
|
| 9 |
-
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Tuple
|
| 12 |
import yaml
|
|
@@ -24,38 +15,27 @@ import gradio as gr
|
|
| 24 |
from vampnet.interface import Interface
|
| 25 |
from vampnet import mask as pmask
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
# loader = AudioLoader()
|
| 32 |
# AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
shift=interval,
|
| 42 |
-
sample_rate=signal.sample_rate
|
| 43 |
-
)
|
| 44 |
-
return signal
|
| 45 |
-
|
| 46 |
-
def load_interface():
|
| 47 |
-
interface = Interface(
|
| 48 |
-
coarse_ckpt="./models/vampnet/coarse.pth",
|
| 49 |
-
coarse2fine_ckpt="./models/vampnet/c2f.pth",
|
| 50 |
-
codec_ckpt="./models/vampnet/codec.pth",
|
| 51 |
-
wavebeat_ckpt="./models/wavebeat.pth",
|
| 52 |
-
device="cuda" if torch.cuda.is_available() else "cpu",
|
| 53 |
-
)
|
| 54 |
-
return interface
|
| 55 |
-
|
| 56 |
|
| 57 |
-
|
|
|
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
OUT_DIR = Path("gradio-outputs")
|
| 61 |
OUT_DIR.mkdir(exist_ok=True, parents=True)
|
|
@@ -70,7 +50,7 @@ def load_audio(file):
|
|
| 70 |
)
|
| 71 |
sig = interface.preprocess(sig)
|
| 72 |
|
| 73 |
-
out_dir = OUT_DIR /
|
| 74 |
out_dir.mkdir(parents=True, exist_ok=True)
|
| 75 |
sig.write(out_dir / "input.wav")
|
| 76 |
return sig.path_to_file
|
|
@@ -88,10 +68,6 @@ def _vamp(data, return_mask=False):
|
|
| 88 |
out_dir = OUT_DIR / str(uuid.uuid4())
|
| 89 |
out_dir.mkdir()
|
| 90 |
sig = at.AudioSignal(data[input_audio])
|
| 91 |
-
sig = interface.preprocess(sig)
|
| 92 |
-
|
| 93 |
-
if data[pitch_shift_amt] != 0:
|
| 94 |
-
sig = shift_pitch(sig, data[pitch_shift_amt])
|
| 95 |
|
| 96 |
z = interface.encode(sig)
|
| 97 |
|
|
@@ -131,58 +107,24 @@ def _vamp(data, return_mask=False):
|
|
| 131 |
mask = pmask.codebook_unmask(mask, ncc)
|
| 132 |
|
| 133 |
|
| 134 |
-
print(f"dropout {data[dropout]}")
|
| 135 |
-
print(f"masktemp {data[masktemp]}")
|
| 136 |
-
print(f"sampletemp {data[sampletemp]}")
|
| 137 |
-
print(f"top_p {data[top_p]}")
|
| 138 |
-
print(f"prefix_s {data[prefix_s]}")
|
| 139 |
-
print(f"suffix_s {data[suffix_s]}")
|
| 140 |
-
print(f"rand_mask_intensity {data[rand_mask_intensity]}")
|
| 141 |
-
print(f"num_steps {data[num_steps]}")
|
| 142 |
-
print(f"periodic_p {data[periodic_p]}")
|
| 143 |
-
print(f"periodic_w {data[periodic_w]}")
|
| 144 |
-
print(f"n_conditioning_codebooks {data[n_conditioning_codebooks]}")
|
| 145 |
-
print(f"use_coarse2fine {data[use_coarse2fine]}")
|
| 146 |
-
print(f"onset_mask_width {data[onset_mask_width]}")
|
| 147 |
-
print(f"beat_mask_width {data[beat_mask_width]}")
|
| 148 |
-
print(f"beat_mask_downbeats {data[beat_mask_downbeats]}")
|
| 149 |
-
print(f"stretch_factor {data[stretch_factor]}")
|
| 150 |
-
print(f"seed {data[seed]}")
|
| 151 |
-
print(f"pitch_shift_amt {data[pitch_shift_amt]}")
|
| 152 |
-
print(f"sample_cutoff {data[sample_cutoff]}")
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
_top_p = data[top_p] if data[top_p] > 0 else None
|
| 156 |
# save the mask as a txt file
|
| 157 |
np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
|
| 158 |
|
| 159 |
-
_seed = data[seed] if data[seed] > 0 else None
|
| 160 |
zv, mask_z = interface.coarse_vamp(
|
| 161 |
z,
|
| 162 |
mask=mask,
|
| 163 |
sampling_steps=data[num_steps],
|
| 164 |
-
|
| 165 |
-
sampling_temperature=data[sampletemp],
|
| 166 |
return_mask=True,
|
| 167 |
typical_filtering=data[typical_filtering],
|
| 168 |
typical_mass=data[typical_mass],
|
| 169 |
typical_min_tokens=data[typical_min_tokens],
|
| 170 |
-
top_p=_top_p,
|
| 171 |
gen_fn=interface.coarse.generate,
|
| 172 |
-
seed=_seed,
|
| 173 |
-
sample_cutoff=data[sample_cutoff],
|
| 174 |
)
|
| 175 |
|
| 176 |
if use_coarse2fine:
|
| 177 |
-
zv = interface.coarse_to_fine(
|
| 178 |
-
zv,
|
| 179 |
-
mask_temperature=data[masktemp]*10,
|
| 180 |
-
sampling_temperature=data[sampletemp],
|
| 181 |
-
mask=mask,
|
| 182 |
-
sampling_steps=data[num_steps] // 2,
|
| 183 |
-
sample_cutoff=data[sample_cutoff],
|
| 184 |
-
seed=_seed,
|
| 185 |
-
)
|
| 186 |
|
| 187 |
sig = interface.to_signal(zv).cpu()
|
| 188 |
print("done")
|
|
@@ -215,9 +157,7 @@ def save_vamp(data):
|
|
| 215 |
sig_out.write(out_dir / "output.wav")
|
| 216 |
|
| 217 |
_data = {
|
| 218 |
-
"
|
| 219 |
-
"sampletemp": data[sampletemp],
|
| 220 |
-
"top_p": data[top_p],
|
| 221 |
"prefix_s": data[prefix_s],
|
| 222 |
"suffix_s": data[suffix_s],
|
| 223 |
"rand_mask_intensity": data[rand_mask_intensity],
|
|
@@ -228,8 +168,6 @@ def save_vamp(data):
|
|
| 228 |
"n_conditioning_codebooks": data[n_conditioning_codebooks],
|
| 229 |
"use_coarse2fine": data[use_coarse2fine],
|
| 230 |
"stretch_factor": data[stretch_factor],
|
| 231 |
-
"seed": data[seed],
|
| 232 |
-
"samplecutoff": data[sample_cutoff],
|
| 233 |
}
|
| 234 |
|
| 235 |
# save with yaml
|
|
@@ -245,54 +183,13 @@ def save_vamp(data):
|
|
| 245 |
return f"saved! your save code is {out_dir.stem}", zip_path
|
| 246 |
|
| 247 |
|
| 248 |
-
def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
|
| 249 |
-
|
| 250 |
-
out_dir = OUT_DIR / str(uuid.uuid4())
|
| 251 |
-
out_dir.mkdir()
|
| 252 |
-
sig = at.AudioSignal(_input_audio)
|
| 253 |
-
sig = interface.preprocess(sig)
|
| 254 |
-
|
| 255 |
-
z = interface.encode(sig)
|
| 256 |
-
|
| 257 |
-
# build the mask
|
| 258 |
-
mask = pmask.linear_random(z, 1.0)
|
| 259 |
-
if _beat_mask_width > 0:
|
| 260 |
-
beat_mask = interface.make_beat_mask(
|
| 261 |
-
sig,
|
| 262 |
-
after_beat_s=(_beat_mask_width/1000),
|
| 263 |
-
)
|
| 264 |
-
mask = pmask.mask_and(mask, beat_mask)
|
| 265 |
-
|
| 266 |
-
# save the mask as a txt file
|
| 267 |
-
zv, mask_z = interface.coarse_vamp(
|
| 268 |
-
z,
|
| 269 |
-
mask=mask,
|
| 270 |
-
sampling_temperature=_sampletemp,
|
| 271 |
-
return_mask=True,
|
| 272 |
-
gen_fn=interface.coarse.generate,
|
| 273 |
-
)
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
zv = interface.coarse_to_fine(
|
| 277 |
-
zv,
|
| 278 |
-
sampling_temperature=_sampletemp,
|
| 279 |
-
mask=mask,
|
| 280 |
-
)
|
| 281 |
-
|
| 282 |
-
sig = interface.to_signal(zv).cpu()
|
| 283 |
-
print("done")
|
| 284 |
-
|
| 285 |
-
sig.write(out_dir / "output.wav")
|
| 286 |
-
|
| 287 |
-
return sig.path_to_file
|
| 288 |
-
|
| 289 |
with gr.Blocks() as demo:
|
| 290 |
|
| 291 |
with gr.Row():
|
| 292 |
with gr.Column():
|
| 293 |
-
gr.Markdown("# VampNet
|
| 294 |
gr.Markdown("""## Description:
|
| 295 |
-
This is a demo of
|
| 296 |
You can control the extent and nature of variation with a set of manual controls and presets.
|
| 297 |
Use this interface to experiment with different mask settings and explore the audio outputs.
|
| 298 |
""")
|
|
@@ -300,8 +197,8 @@ with gr.Blocks() as demo:
|
|
| 300 |
gr.Markdown("""
|
| 301 |
## Instructions:
|
| 302 |
1. You can start by uploading some audio, or by loading the example audio.
|
| 303 |
-
2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings.
|
| 304 |
-
3. Click the "generate (vamp)!!!" button to
|
| 305 |
4. Optionally, you can add some notes and save the result.
|
| 306 |
5. You can also use the output as the new input and continue experimenting!
|
| 307 |
""")
|
|
@@ -352,25 +249,19 @@ with gr.Blocks() as demo:
|
|
| 352 |
"beat_mask_downbeats": False,
|
| 353 |
},
|
| 354 |
"slight periodic variation": {
|
| 355 |
-
"periodic_p":
|
| 356 |
-
"onset_mask_width":
|
| 357 |
-
"beat_mask_width": 0,
|
| 358 |
-
"beat_mask_downbeats": False,
|
| 359 |
-
},
|
| 360 |
-
"moderate periodic variation": {
|
| 361 |
-
"periodic_p": 13,
|
| 362 |
-
"onset_mask_width": 5,
|
| 363 |
"beat_mask_width": 0,
|
| 364 |
"beat_mask_downbeats": False,
|
| 365 |
},
|
| 366 |
"strong periodic variation": {
|
| 367 |
-
"periodic_p":
|
| 368 |
"onset_mask_width": 5,
|
| 369 |
"beat_mask_width": 0,
|
| 370 |
"beat_mask_downbeats": False,
|
| 371 |
},
|
| 372 |
"very strong periodic variation": {
|
| 373 |
-
"periodic_p":
|
| 374 |
"onset_mask_width": 5,
|
| 375 |
"beat_mask_width": 0,
|
| 376 |
"beat_mask_downbeats": False,
|
|
@@ -378,15 +269,9 @@ with gr.Blocks() as demo:
|
|
| 378 |
"beat-driven variation": {
|
| 379 |
"periodic_p": 0,
|
| 380 |
"onset_mask_width": 0,
|
| 381 |
-
"beat_mask_width":
|
| 382 |
"beat_mask_downbeats": False,
|
| 383 |
},
|
| 384 |
-
"beat-driven variation (downbeats only)": {
|
| 385 |
-
"periodic_p": 0,
|
| 386 |
-
"onset_mask_width": 0,
|
| 387 |
-
"beat_mask_width": 50,
|
| 388 |
-
"beat_mask_downbeats": True,
|
| 389 |
-
},
|
| 390 |
"beat-driven variation (downbeats only, strong)": {
|
| 391 |
"periodic_p": 0,
|
| 392 |
"onset_mask_width": 0,
|
|
@@ -408,20 +293,20 @@ with gr.Blocks() as demo:
|
|
| 408 |
minimum=0,
|
| 409 |
maximum=128,
|
| 410 |
step=1,
|
| 411 |
-
value=
|
| 412 |
)
|
| 413 |
|
| 414 |
|
| 415 |
onset_mask_width = gr.Slider(
|
| 416 |
label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
|
| 417 |
minimum=0,
|
| 418 |
-
maximum=
|
| 419 |
step=1,
|
| 420 |
value=5,
|
| 421 |
)
|
| 422 |
|
| 423 |
beat_mask_width = gr.Slider(
|
| 424 |
-
label="beat
|
| 425 |
minimum=0,
|
| 426 |
maximum=200,
|
| 427 |
value=0,
|
|
@@ -433,14 +318,6 @@ with gr.Blocks() as demo:
|
|
| 433 |
|
| 434 |
|
| 435 |
with gr.Accordion("extras ", open=False):
|
| 436 |
-
pitch_shift_amt = gr.Slider(
|
| 437 |
-
label="pitch shift amount (semitones)",
|
| 438 |
-
minimum=-12,
|
| 439 |
-
maximum=12,
|
| 440 |
-
step=1,
|
| 441 |
-
value=0,
|
| 442 |
-
)
|
| 443 |
-
|
| 444 |
rand_mask_intensity = gr.Slider(
|
| 445 |
label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
|
| 446 |
minimum=0.0,
|
|
@@ -500,34 +377,21 @@ with gr.Blocks() as demo:
|
|
| 500 |
value=0.0
|
| 501 |
)
|
| 502 |
|
| 503 |
-
|
| 504 |
-
label="
|
| 505 |
minimum=0.0,
|
| 506 |
-
maximum=100.0,
|
| 507 |
-
value=1.5
|
| 508 |
-
)
|
| 509 |
-
sampletemp = gr.Slider(
|
| 510 |
-
label="sample temperature",
|
| 511 |
-
minimum=0.1,
|
| 512 |
maximum=10.0,
|
| 513 |
-
value=1.
|
| 514 |
-
step=0.001
|
| 515 |
)
|
| 516 |
-
|
| 517 |
|
| 518 |
|
| 519 |
with gr.Accordion("sampling settings", open=False):
|
| 520 |
-
top_p = gr.Slider(
|
| 521 |
-
label="top p (0.0 = off)",
|
| 522 |
-
minimum=0.0,
|
| 523 |
-
maximum=1.0,
|
| 524 |
-
value=0.9
|
| 525 |
-
)
|
| 526 |
typical_filtering = gr.Checkbox(
|
| 527 |
label="typical filtering ",
|
| 528 |
value=False
|
| 529 |
)
|
| 530 |
-
typical_mass = gr.Slider(
|
| 531 |
label="typical mass (should probably stay between 0.1 and 0.5)",
|
| 532 |
minimum=0.01,
|
| 533 |
maximum=0.99,
|
|
@@ -540,13 +404,6 @@ with gr.Blocks() as demo:
|
|
| 540 |
step=1,
|
| 541 |
value=64
|
| 542 |
)
|
| 543 |
-
sample_cutoff = gr.Slider(
|
| 544 |
-
label="sample cutoff",
|
| 545 |
-
minimum=0.0,
|
| 546 |
-
maximum=1.0,
|
| 547 |
-
value=0.5,
|
| 548 |
-
step=0.01
|
| 549 |
-
)
|
| 550 |
|
| 551 |
use_coarse2fine = gr.Checkbox(
|
| 552 |
label="use coarse2fine",
|
|
@@ -571,24 +428,8 @@ with gr.Blocks() as demo:
|
|
| 571 |
)
|
| 572 |
|
| 573 |
|
| 574 |
-
seed = gr.Number(
|
| 575 |
-
label="seed (0 for random)",
|
| 576 |
-
value=0,
|
| 577 |
-
precision=0,
|
| 578 |
-
)
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
# mask settings
|
| 583 |
with gr.Column():
|
| 584 |
-
|
| 585 |
-
# lora_choice = gr.Dropdown(
|
| 586 |
-
# label="lora choice",
|
| 587 |
-
# choices=list(loras.keys()),
|
| 588 |
-
# value=LORA_NONE,
|
| 589 |
-
# visible=False
|
| 590 |
-
# )
|
| 591 |
-
|
| 592 |
vamp_button = gr.Button("generate (vamp)!!!")
|
| 593 |
output_audio = gr.Audio(
|
| 594 |
label="output audio",
|
|
@@ -614,9 +455,7 @@ with gr.Blocks() as demo:
|
|
| 614 |
_inputs = {
|
| 615 |
input_audio,
|
| 616 |
num_steps,
|
| 617 |
-
|
| 618 |
-
sampletemp,
|
| 619 |
-
top_p,
|
| 620 |
prefix_s, suffix_s,
|
| 621 |
rand_mask_intensity,
|
| 622 |
periodic_p, periodic_w,
|
|
@@ -629,11 +468,7 @@ with gr.Blocks() as demo:
|
|
| 629 |
typical_mass,
|
| 630 |
typical_min_tokens,
|
| 631 |
beat_mask_width,
|
| 632 |
-
beat_mask_downbeats
|
| 633 |
-
seed,
|
| 634 |
-
# lora_choice,
|
| 635 |
-
pitch_shift_amt,
|
| 636 |
-
sample_cutoff
|
| 637 |
}
|
| 638 |
|
| 639 |
# connect widgets
|
|
@@ -663,24 +498,4 @@ with gr.Blocks() as demo:
|
|
| 663 |
outputs=[thank_you, download_file]
|
| 664 |
)
|
| 665 |
|
| 666 |
-
|
| 667 |
-
harp_inputs = [
|
| 668 |
-
input_audio,
|
| 669 |
-
beat_mask_width,
|
| 670 |
-
sampletemp,
|
| 671 |
-
]
|
| 672 |
-
|
| 673 |
-
build_endpoint(
|
| 674 |
-
inputs=harp_inputs,
|
| 675 |
-
output=output_audio,
|
| 676 |
-
process_fn=harp_vamp,
|
| 677 |
-
card=ModelCard(
|
| 678 |
-
name="vampnet",
|
| 679 |
-
description="Generate variations on music input, based on small prompts around the beat. NOTE: vampnet's has a maximum context length of 10 seconds. Please split all audio clips into 10 second chunks, or processing will result in an error. ",
|
| 680 |
-
author="Hugo Flores García",
|
| 681 |
-
tags=["music", "generative"]
|
| 682 |
-
),
|
| 683 |
-
visible=False
|
| 684 |
-
)
|
| 685 |
-
|
| 686 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from pathlib import Path
|
| 2 |
from typing import Tuple
|
| 3 |
import yaml
|
|
|
|
| 15 |
from vampnet.interface import Interface
|
| 16 |
from vampnet import mask as pmask
|
| 17 |
|
| 18 |
+
# Interface = argbind.bind(Interface)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
# AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
|
| 20 |
|
| 21 |
+
interface = Interface(
|
| 22 |
+
coarse_ckpt="./models/vampnet/coarse.pth",
|
| 23 |
+
coarse2fine_ckpt="./models/vampnet/c2f.pth",
|
| 24 |
+
codec_ckpt="./models/vampnet/codec.pth",
|
| 25 |
+
wavebeat_ckpt="./models/wavebeat.pth",
|
| 26 |
+
device="cuda" if torch.cuda.is_available() else "cpu",
|
| 27 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
# loader = AudioLoader()
|
| 30 |
+
print(f"interface device is {interface.device}")
|
| 31 |
|
| 32 |
+
# dataset = at.data.datasets.AudioDataset(
|
| 33 |
+
# loader,
|
| 34 |
+
# sample_rate=interface.codec.sample_rate,
|
| 35 |
+
# duration=interface.coarse.chunk_size_s,
|
| 36 |
+
# n_examples=5000,
|
| 37 |
+
# without_replacement=True,
|
| 38 |
+
# )
|
| 39 |
|
| 40 |
OUT_DIR = Path("gradio-outputs")
|
| 41 |
OUT_DIR.mkdir(exist_ok=True, parents=True)
|
|
|
|
| 50 |
)
|
| 51 |
sig = interface.preprocess(sig)
|
| 52 |
|
| 53 |
+
out_dir = OUT_DIR / str(uuid.uuid4())
|
| 54 |
out_dir.mkdir(parents=True, exist_ok=True)
|
| 55 |
sig.write(out_dir / "input.wav")
|
| 56 |
return sig.path_to_file
|
|
|
|
| 68 |
out_dir = OUT_DIR / str(uuid.uuid4())
|
| 69 |
out_dir.mkdir()
|
| 70 |
sig = at.AudioSignal(data[input_audio])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
z = interface.encode(sig)
|
| 73 |
|
|
|
|
| 107 |
mask = pmask.codebook_unmask(mask, ncc)
|
| 108 |
|
| 109 |
|
| 110 |
+
print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[temp]}, use coarse2fine {data[use_coarse2fine]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
# save the mask as a txt file
|
| 112 |
np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
|
| 113 |
|
|
|
|
| 114 |
zv, mask_z = interface.coarse_vamp(
|
| 115 |
z,
|
| 116 |
mask=mask,
|
| 117 |
sampling_steps=data[num_steps],
|
| 118 |
+
temperature=float(data[temp]*10),
|
|
|
|
| 119 |
return_mask=True,
|
| 120 |
typical_filtering=data[typical_filtering],
|
| 121 |
typical_mass=data[typical_mass],
|
| 122 |
typical_min_tokens=data[typical_min_tokens],
|
|
|
|
| 123 |
gen_fn=interface.coarse.generate,
|
|
|
|
|
|
|
| 124 |
)
|
| 125 |
|
| 126 |
if use_coarse2fine:
|
| 127 |
+
zv = interface.coarse_to_fine(zv, temperature=data[temp], mask=mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
sig = interface.to_signal(zv).cpu()
|
| 130 |
print("done")
|
|
|
|
| 157 |
sig_out.write(out_dir / "output.wav")
|
| 158 |
|
| 159 |
_data = {
|
| 160 |
+
"temp": data[temp],
|
|
|
|
|
|
|
| 161 |
"prefix_s": data[prefix_s],
|
| 162 |
"suffix_s": data[suffix_s],
|
| 163 |
"rand_mask_intensity": data[rand_mask_intensity],
|
|
|
|
| 168 |
"n_conditioning_codebooks": data[n_conditioning_codebooks],
|
| 169 |
"use_coarse2fine": data[use_coarse2fine],
|
| 170 |
"stretch_factor": data[stretch_factor],
|
|
|
|
|
|
|
| 171 |
}
|
| 172 |
|
| 173 |
# save with yaml
|
|
|
|
| 183 |
return f"saved! your save code is {out_dir.stem}", zip_path
|
| 184 |
|
| 185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
with gr.Blocks() as demo:
|
| 187 |
|
| 188 |
with gr.Row():
|
| 189 |
with gr.Column():
|
| 190 |
+
gr.Markdown("# VampNet")
|
| 191 |
gr.Markdown("""## Description:
|
| 192 |
+
This is a demo of VampNet, a masked generative music model capable of doing music variations.
|
| 193 |
You can control the extent and nature of variation with a set of manual controls and presets.
|
| 194 |
Use this interface to experiment with different mask settings and explore the audio outputs.
|
| 195 |
""")
|
|
|
|
| 197 |
gr.Markdown("""
|
| 198 |
## Instructions:
|
| 199 |
1. You can start by uploading some audio, or by loading the example audio.
|
| 200 |
+
2. Choose a preset for the vamp operation, or manually adjust the controls to customize the mask settings. Click the load preset button.
|
| 201 |
+
3. Click the "generate (vamp)!!!" button to generate audio. Listen to the output audio, and the masked audio to hear the mask hints.
|
| 202 |
4. Optionally, you can add some notes and save the result.
|
| 203 |
5. You can also use the output as the new input and continue experimenting!
|
| 204 |
""")
|
|
|
|
| 249 |
"beat_mask_downbeats": False,
|
| 250 |
},
|
| 251 |
"slight periodic variation": {
|
| 252 |
+
"periodic_p": 7,
|
| 253 |
+
"onset_mask_width": 0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
"beat_mask_width": 0,
|
| 255 |
"beat_mask_downbeats": False,
|
| 256 |
},
|
| 257 |
"strong periodic variation": {
|
| 258 |
+
"periodic_p": 13,
|
| 259 |
"onset_mask_width": 5,
|
| 260 |
"beat_mask_width": 0,
|
| 261 |
"beat_mask_downbeats": False,
|
| 262 |
},
|
| 263 |
"very strong periodic variation": {
|
| 264 |
+
"periodic_p": 17,
|
| 265 |
"onset_mask_width": 5,
|
| 266 |
"beat_mask_width": 0,
|
| 267 |
"beat_mask_downbeats": False,
|
|
|
|
| 269 |
"beat-driven variation": {
|
| 270 |
"periodic_p": 0,
|
| 271 |
"onset_mask_width": 0,
|
| 272 |
+
"beat_mask_width": 20,
|
| 273 |
"beat_mask_downbeats": False,
|
| 274 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
"beat-driven variation (downbeats only, strong)": {
|
| 276 |
"periodic_p": 0,
|
| 277 |
"onset_mask_width": 0,
|
|
|
|
| 293 |
minimum=0,
|
| 294 |
maximum=128,
|
| 295 |
step=1,
|
| 296 |
+
value=13,
|
| 297 |
)
|
| 298 |
|
| 299 |
|
| 300 |
onset_mask_width = gr.Slider(
|
| 301 |
label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
|
| 302 |
minimum=0,
|
| 303 |
+
maximum=20,
|
| 304 |
step=1,
|
| 305 |
value=5,
|
| 306 |
)
|
| 307 |
|
| 308 |
beat_mask_width = gr.Slider(
|
| 309 |
+
label="beat mask width (in milliseconds)",
|
| 310 |
minimum=0,
|
| 311 |
maximum=200,
|
| 312 |
value=0,
|
|
|
|
| 318 |
|
| 319 |
|
| 320 |
with gr.Accordion("extras ", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
rand_mask_intensity = gr.Slider(
|
| 322 |
label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
|
| 323 |
minimum=0.0,
|
|
|
|
| 377 |
value=0.0
|
| 378 |
)
|
| 379 |
|
| 380 |
+
temp = gr.Slider(
|
| 381 |
+
label="temperature",
|
| 382 |
minimum=0.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
maximum=10.0,
|
| 384 |
+
value=1.8
|
|
|
|
| 385 |
)
|
| 386 |
+
|
| 387 |
|
| 388 |
|
| 389 |
with gr.Accordion("sampling settings", open=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
typical_filtering = gr.Checkbox(
|
| 391 |
label="typical filtering ",
|
| 392 |
value=False
|
| 393 |
)
|
| 394 |
+
typical_mass = gr.Slider(
|
| 395 |
label="typical mass (should probably stay between 0.1 and 0.5)",
|
| 396 |
minimum=0.01,
|
| 397 |
maximum=0.99,
|
|
|
|
| 404 |
step=1,
|
| 405 |
value=64
|
| 406 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
|
| 408 |
use_coarse2fine = gr.Checkbox(
|
| 409 |
label="use coarse2fine",
|
|
|
|
| 428 |
)
|
| 429 |
|
| 430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
# mask settings
|
| 432 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
vamp_button = gr.Button("generate (vamp)!!!")
|
| 434 |
output_audio = gr.Audio(
|
| 435 |
label="output audio",
|
|
|
|
| 455 |
_inputs = {
|
| 456 |
input_audio,
|
| 457 |
num_steps,
|
| 458 |
+
temp,
|
|
|
|
|
|
|
| 459 |
prefix_s, suffix_s,
|
| 460 |
rand_mask_intensity,
|
| 461 |
periodic_p, periodic_w,
|
|
|
|
| 468 |
typical_mass,
|
| 469 |
typical_min_tokens,
|
| 470 |
beat_mask_width,
|
| 471 |
+
beat_mask_downbeats
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
}
|
| 473 |
|
| 474 |
# connect widgets
|
|
|
|
| 498 |
outputs=[thank_you, download_file]
|
| 499 |
)
|
| 500 |
|
| 501 |
+
demo.queue().launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf/lora/lora.yml
CHANGED
|
@@ -4,16 +4,14 @@ $include:
|
|
| 4 |
fine_tune: True
|
| 5 |
|
| 6 |
train/AudioDataset.n_examples: 100000000
|
| 7 |
-
val/AudioDataset.n_examples:
|
| 8 |
|
| 9 |
|
| 10 |
NoamScheduler.warmup: 500
|
| 11 |
|
| 12 |
-
batch_size:
|
| 13 |
num_workers: 7
|
| 14 |
-
save_iters: [
|
| 15 |
-
sample_freq: 1000
|
| 16 |
-
val_freq: 500
|
| 17 |
|
| 18 |
AdamW.lr: 0.0001
|
| 19 |
|
|
|
|
| 4 |
fine_tune: True
|
| 5 |
|
| 6 |
train/AudioDataset.n_examples: 100000000
|
| 7 |
+
val/AudioDataset.n_examples: 100
|
| 8 |
|
| 9 |
|
| 10 |
NoamScheduler.warmup: 500
|
| 11 |
|
| 12 |
+
batch_size: 7
|
| 13 |
num_workers: 7
|
| 14 |
+
save_iters: [100000, 200000, 300000, 4000000, 500000]
|
|
|
|
|
|
|
| 15 |
|
| 16 |
AdamW.lr: 0.0001
|
| 17 |
|
conf/vampnet.yml
CHANGED
|
@@ -32,7 +32,7 @@ VampNet.n_heads: 20
|
|
| 32 |
VampNet.flash_attn: false
|
| 33 |
VampNet.dropout: 0.1
|
| 34 |
|
| 35 |
-
AudioLoader.relative_path:
|
| 36 |
AudioDataset.loudness_cutoff: -30.0
|
| 37 |
AudioDataset.without_replacement: true
|
| 38 |
AudioLoader.shuffle: true
|
|
|
|
| 32 |
VampNet.flash_attn: false
|
| 33 |
VampNet.dropout: 0.1
|
| 34 |
|
| 35 |
+
AudioLoader.relative_path: /data/
|
| 36 |
AudioDataset.loudness_cutoff: -30.0
|
| 37 |
AudioDataset.without_replacement: true
|
| 38 |
AudioLoader.shuffle: true
|
requirements.txt
CHANGED
|
@@ -1,10 +1,8 @@
|
|
| 1 |
torch
|
| 2 |
argbind>=0.3.2
|
| 3 |
-
numpy==1.
|
| 4 |
gradio
|
| 5 |
loralib
|
| 6 |
wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
|
| 7 |
lac @ git+https://github.com/hugofloresgarcia/lac.git
|
| 8 |
-
|
| 9 |
-
-e git+https://github.com/audacitorch/pyharp.git#egg=pyharp
|
| 10 |
-
torch_pitch_shift
|
|
|
|
| 1 |
torch
|
| 2 |
argbind>=0.3.2
|
| 3 |
+
numpy==1.22
|
| 4 |
gradio
|
| 5 |
loralib
|
| 6 |
wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
|
| 7 |
lac @ git+https://github.com/hugofloresgarcia/lac.git
|
| 8 |
+
audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
|
|
|
|
|
|
scripts/exp/fine_tune.py
CHANGED
|
@@ -48,10 +48,11 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
|
|
| 48 |
}
|
| 49 |
|
| 50 |
interface_conf = {
|
| 51 |
-
"Interface.coarse_ckpt": f"./
|
|
|
|
| 52 |
|
| 53 |
-
"Interface.coarse2fine_ckpt": f"./
|
| 54 |
-
"Interface.
|
| 55 |
|
| 56 |
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
| 57 |
"AudioLoader.sources": [audio_files_or_folders],
|
|
|
|
| 48 |
}
|
| 49 |
|
| 50 |
interface_conf = {
|
| 51 |
+
"Interface.coarse_ckpt": f"./models/vampnet/coarse.pth",
|
| 52 |
+
"Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
|
| 53 |
|
| 54 |
+
"Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
|
| 55 |
+
"Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
|
| 56 |
|
| 57 |
"Interface.codec_ckpt": "./models/vampnet/codec.pth",
|
| 58 |
"AudioLoader.sources": [audio_files_or_folders],
|
scripts/exp/train.py
CHANGED
|
@@ -14,7 +14,7 @@ from audiotools.data import transforms
|
|
| 14 |
from einops import rearrange
|
| 15 |
from rich import pretty
|
| 16 |
from rich.traceback import install
|
| 17 |
-
from
|
| 18 |
|
| 19 |
import vampnet
|
| 20 |
from vampnet.modules.transformer import VampNet
|
|
@@ -29,9 +29,6 @@ from audiotools.ml.decorators import (
|
|
| 29 |
|
| 30 |
import loralib as lora
|
| 31 |
|
| 32 |
-
import torch._dynamo
|
| 33 |
-
torch._dynamo.config.verbose=True
|
| 34 |
-
|
| 35 |
|
| 36 |
# Enable cudnn autotuner to speed up training
|
| 37 |
# (can be altered by the funcs.seed function)
|
|
@@ -224,7 +221,7 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
|
|
| 224 |
|
| 225 |
dtype = torch.bfloat16 if accel.amp else None
|
| 226 |
with accel.autocast(dtype=dtype):
|
| 227 |
-
z_hat = state.model(z_mask_latent)
|
| 228 |
|
| 229 |
target = codebook_flatten(
|
| 230 |
z[:, vn.n_conditioning_codebooks :, :],
|
|
@@ -289,7 +286,7 @@ def val_loop(state: State, batch: dict, accel: Accelerator):
|
|
| 289 |
|
| 290 |
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
| 291 |
|
| 292 |
-
z_hat = state.model(z_mask_latent)
|
| 293 |
|
| 294 |
target = codebook_flatten(
|
| 295 |
z[:, vn.n_conditioning_codebooks :, :],
|
|
@@ -408,19 +405,19 @@ def save_imputation(state, z, val_idx, writer):
|
|
| 408 |
|
| 409 |
for i in range(len(val_idx)):
|
| 410 |
imputed_noisy[i].cpu().write_audio_to_tb(
|
| 411 |
-
f"
|
| 412 |
writer,
|
| 413 |
step=state.tracker.step,
|
| 414 |
plot_fn=None,
|
| 415 |
)
|
| 416 |
imputed[i].cpu().write_audio_to_tb(
|
| 417 |
-
f"
|
| 418 |
writer,
|
| 419 |
step=state.tracker.step,
|
| 420 |
plot_fn=None,
|
| 421 |
)
|
| 422 |
imputed_true[i].cpu().write_audio_to_tb(
|
| 423 |
-
f"
|
| 424 |
writer,
|
| 425 |
step=state.tracker.step,
|
| 426 |
plot_fn=None,
|
|
@@ -450,7 +447,7 @@ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
|
|
| 450 |
|
| 451 |
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
| 452 |
|
| 453 |
-
z_hat = state.model(z_mask_latent)
|
| 454 |
|
| 455 |
z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
|
| 456 |
z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
|
|
@@ -469,7 +466,7 @@ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
|
|
| 469 |
}
|
| 470 |
for k, v in audio_dict.items():
|
| 471 |
v.cpu().write_audio_to_tb(
|
| 472 |
-
f"
|
| 473 |
writer,
|
| 474 |
step=state.tracker.step,
|
| 475 |
plot_fn=None,
|
|
@@ -488,6 +485,7 @@ def load(
|
|
| 488 |
save_path: str,
|
| 489 |
resume: bool = False,
|
| 490 |
tag: str = "latest",
|
|
|
|
| 491 |
fine_tune_checkpoint: Optional[str] = None,
|
| 492 |
grad_clip_val: float = 5.0,
|
| 493 |
) -> State:
|
|
@@ -500,7 +498,7 @@ def load(
|
|
| 500 |
kwargs = {
|
| 501 |
"folder": f"{save_path}/{tag}",
|
| 502 |
"map_location": "cpu",
|
| 503 |
-
"package":
|
| 504 |
}
|
| 505 |
tracker.print(f"Loading checkpoint from {kwargs['folder']}")
|
| 506 |
if (Path(kwargs["folder"]) / "vampnet").exists():
|
|
@@ -513,14 +511,11 @@ def load(
|
|
| 513 |
|
| 514 |
if args["fine_tune"]:
|
| 515 |
assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
|
| 516 |
-
model =
|
| 517 |
-
VampNet.load(location=Path(fine_tune_checkpoint),
|
| 518 |
-
map_location="cpu",
|
| 519 |
-
)
|
| 520 |
-
)
|
| 521 |
|
| 522 |
|
| 523 |
-
model =
|
|
|
|
| 524 |
model = accel.prepare_model(model)
|
| 525 |
|
| 526 |
# assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
|
|
@@ -604,7 +599,7 @@ def train(
|
|
| 604 |
accel=accel,
|
| 605 |
tracker=tracker,
|
| 606 |
save_path=save_path)
|
| 607 |
-
|
| 608 |
|
| 609 |
train_dataloader = accel.prepare_dataloader(
|
| 610 |
state.train_data,
|
|
@@ -619,15 +614,13 @@ def train(
|
|
| 619 |
num_workers=num_workers,
|
| 620 |
batch_size=batch_size,
|
| 621 |
collate_fn=state.val_data.collate,
|
| 622 |
-
persistent_workers=
|
| 623 |
)
|
| 624 |
-
print("initialized dataloader.")
|
| 625 |
|
| 626 |
|
| 627 |
|
| 628 |
if fine_tune:
|
| 629 |
lora.mark_only_lora_as_trainable(state.model)
|
| 630 |
-
print("marked only lora as trainable.")
|
| 631 |
|
| 632 |
# Wrap the functions so that they neatly track in TensorBoard + progress bars
|
| 633 |
# and only run when specific conditions are met.
|
|
@@ -642,7 +635,6 @@ def train(
|
|
| 642 |
save_samples = when(lambda: accel.local_rank == 0)(save_samples)
|
| 643 |
checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
|
| 644 |
|
| 645 |
-
print("starting training loop.")
|
| 646 |
with tracker.live:
|
| 647 |
for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
|
| 648 |
train_loop(state, batch, accel)
|
|
|
|
| 14 |
from einops import rearrange
|
| 15 |
from rich import pretty
|
| 16 |
from rich.traceback import install
|
| 17 |
+
from tensorboardX import SummaryWriter
|
| 18 |
|
| 19 |
import vampnet
|
| 20 |
from vampnet.modules.transformer import VampNet
|
|
|
|
| 29 |
|
| 30 |
import loralib as lora
|
| 31 |
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# Enable cudnn autotuner to speed up training
|
| 34 |
# (can be altered by the funcs.seed function)
|
|
|
|
| 221 |
|
| 222 |
dtype = torch.bfloat16 if accel.amp else None
|
| 223 |
with accel.autocast(dtype=dtype):
|
| 224 |
+
z_hat = state.model(z_mask_latent, r)
|
| 225 |
|
| 226 |
target = codebook_flatten(
|
| 227 |
z[:, vn.n_conditioning_codebooks :, :],
|
|
|
|
| 286 |
|
| 287 |
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
| 288 |
|
| 289 |
+
z_hat = state.model(z_mask_latent, r)
|
| 290 |
|
| 291 |
target = codebook_flatten(
|
| 292 |
z[:, vn.n_conditioning_codebooks :, :],
|
|
|
|
| 405 |
|
| 406 |
for i in range(len(val_idx)):
|
| 407 |
imputed_noisy[i].cpu().write_audio_to_tb(
|
| 408 |
+
f"imputed_noisy/{i}",
|
| 409 |
writer,
|
| 410 |
step=state.tracker.step,
|
| 411 |
plot_fn=None,
|
| 412 |
)
|
| 413 |
imputed[i].cpu().write_audio_to_tb(
|
| 414 |
+
f"imputed/{i}",
|
| 415 |
writer,
|
| 416 |
step=state.tracker.step,
|
| 417 |
plot_fn=None,
|
| 418 |
)
|
| 419 |
imputed_true[i].cpu().write_audio_to_tb(
|
| 420 |
+
f"imputed_true/{i}",
|
| 421 |
writer,
|
| 422 |
step=state.tracker.step,
|
| 423 |
plot_fn=None,
|
|
|
|
| 447 |
|
| 448 |
z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
|
| 449 |
|
| 450 |
+
z_hat = state.model(z_mask_latent, r)
|
| 451 |
|
| 452 |
z_pred = torch.softmax(z_hat, dim=1).argmax(dim=1)
|
| 453 |
z_pred = codebook_unflatten(z_pred, n_c=vn.n_predict_codebooks)
|
|
|
|
| 466 |
}
|
| 467 |
for k, v in audio_dict.items():
|
| 468 |
v.cpu().write_audio_to_tb(
|
| 469 |
+
f"samples/_{i}.r={r[i]:0.2f}/{k}",
|
| 470 |
writer,
|
| 471 |
step=state.tracker.step,
|
| 472 |
plot_fn=None,
|
|
|
|
| 485 |
save_path: str,
|
| 486 |
resume: bool = False,
|
| 487 |
tag: str = "latest",
|
| 488 |
+
load_weights: bool = False,
|
| 489 |
fine_tune_checkpoint: Optional[str] = None,
|
| 490 |
grad_clip_val: float = 5.0,
|
| 491 |
) -> State:
|
|
|
|
| 498 |
kwargs = {
|
| 499 |
"folder": f"{save_path}/{tag}",
|
| 500 |
"map_location": "cpu",
|
| 501 |
+
"package": not load_weights,
|
| 502 |
}
|
| 503 |
tracker.print(f"Loading checkpoint from {kwargs['folder']}")
|
| 504 |
if (Path(kwargs["folder"]) / "vampnet").exists():
|
|
|
|
| 511 |
|
| 512 |
if args["fine_tune"]:
|
| 513 |
assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
|
| 514 |
+
model = VampNet.load(location=Path(fine_tune_checkpoint), map_location="cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
|
| 516 |
|
| 517 |
+
model = VampNet() if model is None else model
|
| 518 |
+
|
| 519 |
model = accel.prepare_model(model)
|
| 520 |
|
| 521 |
# assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
|
|
|
|
| 599 |
accel=accel,
|
| 600 |
tracker=tracker,
|
| 601 |
save_path=save_path)
|
| 602 |
+
|
| 603 |
|
| 604 |
train_dataloader = accel.prepare_dataloader(
|
| 605 |
state.train_data,
|
|
|
|
| 614 |
num_workers=num_workers,
|
| 615 |
batch_size=batch_size,
|
| 616 |
collate_fn=state.val_data.collate,
|
| 617 |
+
persistent_workers=True,
|
| 618 |
)
|
|
|
|
| 619 |
|
| 620 |
|
| 621 |
|
| 622 |
if fine_tune:
|
| 623 |
lora.mark_only_lora_as_trainable(state.model)
|
|
|
|
| 624 |
|
| 625 |
# Wrap the functions so that they neatly track in TensorBoard + progress bars
|
| 626 |
# and only run when specific conditions are met.
|
|
|
|
| 635 |
save_samples = when(lambda: accel.local_rank == 0)(save_samples)
|
| 636 |
checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
|
| 637 |
|
|
|
|
| 638 |
with tracker.live:
|
| 639 |
for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
|
| 640 |
train_loop(state, batch, accel)
|
scripts/utils/{data/augment.py → augment.py}
RENAMED
|
@@ -5,19 +5,34 @@ from audiotools import AudioSignal
|
|
| 5 |
|
| 6 |
import argbind
|
| 7 |
import tqdm
|
| 8 |
-
import torch
|
| 9 |
|
| 10 |
|
| 11 |
-
from
|
| 12 |
-
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
@argbind.bind(without_prefix=True)
|
| 18 |
def augment(
|
| 19 |
-
audio_folder: Path
|
| 20 |
-
dest_folder: Path
|
| 21 |
n_augmentations: int = 10,
|
| 22 |
):
|
| 23 |
"""
|
|
@@ -26,8 +41,7 @@ def augment(
|
|
| 26 |
The dest foler will contain a folder for each of the clean dataset's files.
|
| 27 |
Under each of these folders, there will be a clean file and many augmented files.
|
| 28 |
"""
|
| 29 |
-
|
| 30 |
-
assert dest_folder is not None
|
| 31 |
audio_files = at.util.find_audio(audio_folder)
|
| 32 |
|
| 33 |
for audio_file in tqdm.tqdm(audio_files):
|
|
@@ -35,33 +49,5 @@ def augment(
|
|
| 35 |
subdir = subtree / audio_file.stem
|
| 36 |
subdir.mkdir(parents=True, exist_ok=True)
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
|
| 42 |
-
# apply pedalboard transforms
|
| 43 |
-
for j in range(n_augmentations):
|
| 44 |
-
# pitch shift between -7 and 7 semitones
|
| 45 |
-
import random
|
| 46 |
-
dst = chunk.clone()
|
| 47 |
-
dst.samples = pitch_shift(
|
| 48 |
-
dst.samples,
|
| 49 |
-
shift=random.choice(get_fast_shifts(src.sample_rate,
|
| 50 |
-
condition=lambda x: x >= 0.25 and x <= 1.0)),
|
| 51 |
-
sample_rate=src.sample_rate
|
| 52 |
-
)
|
| 53 |
-
dst.samples = time_stretch(
|
| 54 |
-
dst.samples,
|
| 55 |
-
stretch=random.choice(get_fast_stretches(src.sample_rate,
|
| 56 |
-
condition=lambda x: x >= 0.667 and x <= 1.5, )),
|
| 57 |
-
sample_rate=src.sample_rate,
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
dst.cpu().write(subdir / f"{i}-{j}.wav")
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
if __name__ == "__main__":
|
| 64 |
-
args = argbind.parse_args()
|
| 65 |
-
|
| 66 |
-
with argbind.scope(args):
|
| 67 |
-
augment()
|
|
|
|
| 5 |
|
| 6 |
import argbind
|
| 7 |
import tqdm
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
+
from pedalboard import (
|
| 11 |
+
Compressor, Gain, Chorus, LadderFilter, Phaser, Convolution, Reverb, Pedalboard
|
| 12 |
+
)
|
| 13 |
+
from pedalboard.io import AudioFile
|
| 14 |
|
| 15 |
+
# Read in a whole file, resampling to our desired sample rate:
|
| 16 |
+
samplerate = 44100.0
|
| 17 |
+
with AudioFile('guitar-input.wav').resampled_to(samplerate) as f:
|
| 18 |
+
audio = f.read(f.frames)
|
| 19 |
+
|
| 20 |
+
# Make a pretty interesting sounding guitar pedalboard:
|
| 21 |
+
board = Pedalboard([
|
| 22 |
+
Compressor(threshold_db=-50, ratio=25),
|
| 23 |
+
Gain(gain_db=30),
|
| 24 |
+
Chorus(),
|
| 25 |
+
LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900),
|
| 26 |
+
Phaser(),
|
| 27 |
+
Convolution("./guitar_amp.wav", 1.0),
|
| 28 |
+
Reverb(room_size=0.25),
|
| 29 |
+
])
|
| 30 |
|
| 31 |
|
| 32 |
@argbind.bind(without_prefix=True)
|
| 33 |
def augment(
|
| 34 |
+
audio_folder: Path,
|
| 35 |
+
dest_folder: Path,
|
| 36 |
n_augmentations: int = 10,
|
| 37 |
):
|
| 38 |
"""
|
|
|
|
| 41 |
The dest foler will contain a folder for each of the clean dataset's files.
|
| 42 |
Under each of these folders, there will be a clean file and many augmented files.
|
| 43 |
"""
|
| 44 |
+
|
|
|
|
| 45 |
audio_files = at.util.find_audio(audio_folder)
|
| 46 |
|
| 47 |
for audio_file in tqdm.tqdm(audio_files):
|
|
|
|
| 49 |
subdir = subtree / audio_file.stem
|
| 50 |
subdir.mkdir(parents=True, exist_ok=True)
|
| 51 |
|
| 52 |
+
# apply pedalboard transforms
|
| 53 |
+
for i in range(n_augmentations):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/utils/gtzan_embeddings.py
DELETED
|
@@ -1,263 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
TODO: train a linear probe
|
| 3 |
-
usage:
|
| 4 |
-
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output
|
| 5 |
-
"""
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
from typing import List
|
| 8 |
-
|
| 9 |
-
import audiotools as at
|
| 10 |
-
from audiotools import AudioSignal
|
| 11 |
-
import argbind
|
| 12 |
-
import torch
|
| 13 |
-
import numpy as np
|
| 14 |
-
import zipfile
|
| 15 |
-
import json
|
| 16 |
-
|
| 17 |
-
from vampnet.interface import Interface
|
| 18 |
-
import tqdm
|
| 19 |
-
|
| 20 |
-
# bind the Interface to argbind
|
| 21 |
-
Interface = argbind.bind(Interface)
|
| 22 |
-
|
| 23 |
-
DEBUG = False
|
| 24 |
-
|
| 25 |
-
def smart_plotly_export(fig, save_path):
|
| 26 |
-
img_format = save_path.split('.')[-1]
|
| 27 |
-
if img_format == 'html':
|
| 28 |
-
fig.write_html(save_path)
|
| 29 |
-
elif img_format == 'bytes':
|
| 30 |
-
return fig.to_image(format='png')
|
| 31 |
-
#TODO: come back and make this prettier
|
| 32 |
-
elif img_format == 'numpy':
|
| 33 |
-
import io
|
| 34 |
-
from PIL import Image
|
| 35 |
-
|
| 36 |
-
def plotly_fig2array(fig):
|
| 37 |
-
#convert Plotly fig to an array
|
| 38 |
-
fig_bytes = fig.to_image(format="png", width=1200, height=700)
|
| 39 |
-
buf = io.BytesIO(fig_bytes)
|
| 40 |
-
img = Image.open(buf)
|
| 41 |
-
return np.asarray(img)
|
| 42 |
-
|
| 43 |
-
return plotly_fig2array(fig)
|
| 44 |
-
elif img_format == 'jpeg' or 'png' or 'webp':
|
| 45 |
-
fig.write_image(save_path)
|
| 46 |
-
else:
|
| 47 |
-
raise ValueError("invalid image format")
|
| 48 |
-
|
| 49 |
-
def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''):
|
| 50 |
-
"""
|
| 51 |
-
dimensionality reduction for visualization!
|
| 52 |
-
saves an html plotly figure to save_path
|
| 53 |
-
parameters:
|
| 54 |
-
emb (np.ndarray): the samples to be reduces with shape (samples, features)
|
| 55 |
-
labels (list): list of labels for embedding
|
| 56 |
-
save_path (str): path where u wanna save ur figure
|
| 57 |
-
method (str): umap, tsne, or pca
|
| 58 |
-
title (str): title for ur figure
|
| 59 |
-
returns:
|
| 60 |
-
proj (np.ndarray): projection vector with shape (samples, dimensions)
|
| 61 |
-
"""
|
| 62 |
-
import pandas as pd
|
| 63 |
-
import plotly.express as px
|
| 64 |
-
if method == 'umap':
|
| 65 |
-
reducer = umap.UMAP(n_components=n_components)
|
| 66 |
-
elif method == 'tsne':
|
| 67 |
-
from sklearn.manifold import TSNE
|
| 68 |
-
reducer = TSNE(n_components=n_components)
|
| 69 |
-
elif method == 'pca':
|
| 70 |
-
from sklearn.decomposition import PCA
|
| 71 |
-
reducer = PCA(n_components=n_components)
|
| 72 |
-
else:
|
| 73 |
-
raise ValueError
|
| 74 |
-
|
| 75 |
-
proj = reducer.fit_transform(emb)
|
| 76 |
-
|
| 77 |
-
if n_components == 2:
|
| 78 |
-
df = pd.DataFrame(dict(
|
| 79 |
-
x=proj[:, 0],
|
| 80 |
-
y=proj[:, 1],
|
| 81 |
-
instrument=labels
|
| 82 |
-
))
|
| 83 |
-
fig = px.scatter(df, x='x', y='y', color='instrument',
|
| 84 |
-
title=title+f"_{method}")
|
| 85 |
-
|
| 86 |
-
elif n_components == 3:
|
| 87 |
-
df = pd.DataFrame(dict(
|
| 88 |
-
x=proj[:, 0],
|
| 89 |
-
y=proj[:, 1],
|
| 90 |
-
z=proj[:, 2],
|
| 91 |
-
instrument=labels
|
| 92 |
-
))
|
| 93 |
-
fig = px.scatter_3d(df, x='x', y='y', z='z',
|
| 94 |
-
color='instrument',
|
| 95 |
-
title=title)
|
| 96 |
-
else:
|
| 97 |
-
raise ValueError("cant plot more than 3 components")
|
| 98 |
-
|
| 99 |
-
fig.update_traces(marker=dict(size=6,
|
| 100 |
-
line=dict(width=1,
|
| 101 |
-
color='DarkSlateGrey')),
|
| 102 |
-
selector=dict(mode='markers'))
|
| 103 |
-
|
| 104 |
-
return smart_plotly_export(fig, save_path)
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# per JukeMIR, we want the emebddings from the middle layer?
|
| 109 |
-
def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
|
| 110 |
-
with torch.inference_mode():
|
| 111 |
-
# preprocess the signal
|
| 112 |
-
sig = interface.preprocess(sig)
|
| 113 |
-
|
| 114 |
-
# get the coarse vampnet model
|
| 115 |
-
vampnet = interface.coarse
|
| 116 |
-
|
| 117 |
-
# get the tokens
|
| 118 |
-
z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
|
| 119 |
-
z_latents = vampnet.embedding.from_codes(z, interface.codec)
|
| 120 |
-
|
| 121 |
-
# do a forward pass through the model, get the embeddings
|
| 122 |
-
_z, embeddings = vampnet(z_latents, return_activations=True)
|
| 123 |
-
# print(f"got embeddings with shape {embeddings.shape}")
|
| 124 |
-
# [layer, batch, time, n_dims]
|
| 125 |
-
# [20, 1, 600ish, 768]
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
# squeeze batch dim (1 bc layer should be dim 0)
|
| 129 |
-
assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[0]}"
|
| 130 |
-
embeddings = embeddings.squeeze(1)
|
| 131 |
-
|
| 132 |
-
num_layers = embeddings.shape[0]
|
| 133 |
-
assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"
|
| 134 |
-
|
| 135 |
-
# do meanpooling over the time dimension
|
| 136 |
-
embeddings = embeddings.mean(dim=-2)
|
| 137 |
-
# [20, 768]
|
| 138 |
-
|
| 139 |
-
# return the embeddings
|
| 140 |
-
return embeddings
|
| 141 |
-
|
| 142 |
-
from dataclasses import dataclass, fields
|
| 143 |
-
@dataclass
|
| 144 |
-
class Embedding:
|
| 145 |
-
genre: str
|
| 146 |
-
filename: str
|
| 147 |
-
embedding: np.ndarray
|
| 148 |
-
|
| 149 |
-
def save(self, path):
|
| 150 |
-
"""Save the Embedding object to a given path as a zip file."""
|
| 151 |
-
with zipfile.ZipFile(path, 'w') as archive:
|
| 152 |
-
|
| 153 |
-
# Save numpy array
|
| 154 |
-
with archive.open('embedding.npy', 'w') as f:
|
| 155 |
-
np.save(f, self.embedding)
|
| 156 |
-
|
| 157 |
-
# Save non-numpy data as json
|
| 158 |
-
non_numpy_data = {f.name: getattr(self, f.name) for f in fields(self) if f.name != 'embedding'}
|
| 159 |
-
with archive.open('data.json', 'w') as f:
|
| 160 |
-
f.write(json.dumps(non_numpy_data).encode('utf-8'))
|
| 161 |
-
|
| 162 |
-
@classmethod
|
| 163 |
-
def load(cls, path):
|
| 164 |
-
"""Load the Embedding object from a given zip path."""
|
| 165 |
-
with zipfile.ZipFile(path, 'r') as archive:
|
| 166 |
-
|
| 167 |
-
# Load numpy array
|
| 168 |
-
with archive.open('embedding.npy') as f:
|
| 169 |
-
embedding = np.load(f)
|
| 170 |
-
|
| 171 |
-
# Load non-numpy data from json
|
| 172 |
-
with archive.open('data.json') as f:
|
| 173 |
-
data = json.loads(f.read().decode('utf-8'))
|
| 174 |
-
|
| 175 |
-
return cls(embedding=embedding, **data)
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
@argbind.bind(without_prefix=True)
|
| 179 |
-
def main(
|
| 180 |
-
path_to_gtzan: str = None,
|
| 181 |
-
cache_dir: str = "./.gtzan_emb_cache",
|
| 182 |
-
output_dir: str = "./gtzan_vampnet_embeddings",
|
| 183 |
-
layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
|
| 184 |
-
):
|
| 185 |
-
path_to_gtzan = Path(path_to_gtzan)
|
| 186 |
-
assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist"
|
| 187 |
-
|
| 188 |
-
cache_dir = Path(cache_dir)
|
| 189 |
-
output_dir = Path(output_dir)
|
| 190 |
-
output_dir.mkdir(exist_ok=True, parents=True)
|
| 191 |
-
|
| 192 |
-
# load our interface
|
| 193 |
-
# argbind will automatically load the default config,
|
| 194 |
-
interface = Interface()
|
| 195 |
-
|
| 196 |
-
# gtzan should have a folder for each genre, so let's get the list of genres
|
| 197 |
-
genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()]
|
| 198 |
-
print(f"Found {len(genres)} genres")
|
| 199 |
-
print(f"genres: {genres}")
|
| 200 |
-
|
| 201 |
-
# collect audio files, genres, and embeddings
|
| 202 |
-
data = []
|
| 203 |
-
for genre in genres:
|
| 204 |
-
audio_files = list(at.util.find_audio(path_to_gtzan / genre))
|
| 205 |
-
print(f"Found {len(audio_files)} audio files for genre {genre}")
|
| 206 |
-
|
| 207 |
-
for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"):
|
| 208 |
-
# check if we have a cached embedding for this file
|
| 209 |
-
cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb")
|
| 210 |
-
if cached_path.exists():
|
| 211 |
-
# if so, load it
|
| 212 |
-
if DEBUG:
|
| 213 |
-
print(f"loading cached embedding for {cached_path.stem}")
|
| 214 |
-
embedding = Embedding.load(cached_path)
|
| 215 |
-
data.append(embedding)
|
| 216 |
-
else:
|
| 217 |
-
try:
|
| 218 |
-
sig = AudioSignal(audio_file)
|
| 219 |
-
except Exception as e:
|
| 220 |
-
print(f"failed to load {audio_file.name} with error {e}")
|
| 221 |
-
print(f"skipping {audio_file.name}")
|
| 222 |
-
continue
|
| 223 |
-
|
| 224 |
-
# gets the embedding
|
| 225 |
-
emb = vampnet_embed(sig, interface).cpu().numpy()
|
| 226 |
-
|
| 227 |
-
# create an embedding we can save/load
|
| 228 |
-
embedding = Embedding(
|
| 229 |
-
genre=genre,
|
| 230 |
-
filename=audio_file.name,
|
| 231 |
-
embedding=emb
|
| 232 |
-
)
|
| 233 |
-
|
| 234 |
-
# cache the embeddings
|
| 235 |
-
cached_path.parent.mkdir(exist_ok=True, parents=True)
|
| 236 |
-
embedding.save(cached_path)
|
| 237 |
-
|
| 238 |
-
# now, let's do a dim reduction on the embeddings
|
| 239 |
-
# and visualize them.
|
| 240 |
-
|
| 241 |
-
# collect a list of embeddings and labels
|
| 242 |
-
embeddings = [d.embedding for d in data]
|
| 243 |
-
labels = [d.genre for d in data]
|
| 244 |
-
|
| 245 |
-
# convert the embeddings to a numpy array
|
| 246 |
-
embeddings = np.stack(embeddings)
|
| 247 |
-
|
| 248 |
-
# do dimensionality reduction for each layer we're given
|
| 249 |
-
for layer in tqdm.tqdm(layers, desc="dim reduction"):
|
| 250 |
-
dim_reduce(
|
| 251 |
-
embeddings[:, layer, :], labels,
|
| 252 |
-
save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'),
|
| 253 |
-
n_components=2, method='tsne',
|
| 254 |
-
title=f'vampnet-gtzan-layer={layer}'
|
| 255 |
-
)
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
if __name__ == "__main__":
|
| 261 |
-
args = argbind.parse_args()
|
| 262 |
-
with argbind.scope(args):
|
| 263 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/utils/{data/maestro-reorg.py → maestro-reorg.py}
RENAMED
|
File without changes
|
scripts/utils/remove_quiet_files.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
| 1 |
-
# removes files with loudness below 24db
|
| 2 |
-
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
import shutil
|
| 5 |
-
import audiotools as at
|
| 6 |
-
import argbind
|
| 7 |
-
|
| 8 |
-
@argbind.bind(without_prefix=True)
|
| 9 |
-
def remove_quiet_files(
|
| 10 |
-
src_dir: Path = None,
|
| 11 |
-
dest_dir: Path = None,
|
| 12 |
-
min_loudness: float = -30,
|
| 13 |
-
):
|
| 14 |
-
# copy src to dest
|
| 15 |
-
dest_dir.mkdir(parents=True, exist_ok=True)
|
| 16 |
-
shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
|
| 17 |
-
|
| 18 |
-
audio_files = at.util.find_audio(dest_dir)
|
| 19 |
-
for audio_file in audio_files:
|
| 20 |
-
sig = at.AudioSignal(audio_file)
|
| 21 |
-
if sig.loudness() < min_loudness:
|
| 22 |
-
audio_file.unlink()
|
| 23 |
-
print(f"removed {audio_file}")
|
| 24 |
-
|
| 25 |
-
if __name__ == "__main__":
|
| 26 |
-
args = argbind.parse_args()
|
| 27 |
-
|
| 28 |
-
with argbind.scope(args):
|
| 29 |
-
remove_quiet_files()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/utils/split_long_audio_file.py
DELETED
|
@@ -1,34 +0,0 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
-
import argbind
|
| 3 |
-
|
| 4 |
-
import audiotools as at
|
| 5 |
-
import tqdm
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
@argbind.bind(without_prefix=True)
|
| 9 |
-
def split_long_audio_file(
|
| 10 |
-
file: str = None,
|
| 11 |
-
max_chunk_size_s: int = 60*10
|
| 12 |
-
):
|
| 13 |
-
file = Path(file)
|
| 14 |
-
output_dir = file.parent / file.stem
|
| 15 |
-
output_dir.mkdir()
|
| 16 |
-
|
| 17 |
-
sig = at.AudioSignal(file)
|
| 18 |
-
|
| 19 |
-
# split into chunks
|
| 20 |
-
for i, sig in tqdm.tqdm(enumerate(sig.windows(
|
| 21 |
-
window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
|
| 22 |
-
preprocess=True))
|
| 23 |
-
):
|
| 24 |
-
sig.write(output_dir / f"{i}.wav")
|
| 25 |
-
|
| 26 |
-
print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
|
| 27 |
-
|
| 28 |
-
return output_dir
|
| 29 |
-
|
| 30 |
-
if __name__ == "__main__":
|
| 31 |
-
args = argbind.parse_args()
|
| 32 |
-
|
| 33 |
-
with argbind.scope(args):
|
| 34 |
-
split_long_audio_file()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/utils/xeno-canto-dl.py
DELETED
|
@@ -1,234 +0,0 @@
|
|
| 1 |
-
from xenopy import Query
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
SPECIES = [
|
| 5 |
-
"American Robin",
|
| 6 |
-
"Northern Cardinal",
|
| 7 |
-
"Mourning Dove",
|
| 8 |
-
"American Crow",
|
| 9 |
-
"Baltimore Oriole",
|
| 10 |
-
"Blue Jay",
|
| 11 |
-
"Eastern Bluebird",
|
| 12 |
-
"House Finch",
|
| 13 |
-
"American Goldfinch",
|
| 14 |
-
"House Sparrow",
|
| 15 |
-
"Song Sparrow",
|
| 16 |
-
"Tufted Titmouse",
|
| 17 |
-
"White-breasted Nuthatch",
|
| 18 |
-
"European Starling",
|
| 19 |
-
"American Redstart",
|
| 20 |
-
"Red-winged Blackbird",
|
| 21 |
-
"Brown-headed Cowbird",
|
| 22 |
-
"Common Grackle",
|
| 23 |
-
"Boat-tailed Grackle",
|
| 24 |
-
"Common Yellowthroat",
|
| 25 |
-
"Northern Mockingbird",
|
| 26 |
-
"Carolina Wren",
|
| 27 |
-
"Eastern Meadowlark",
|
| 28 |
-
"Chipping Sparrow",
|
| 29 |
-
"Tree Swallow",
|
| 30 |
-
"Barn Swallow",
|
| 31 |
-
"Cliff Swallow",
|
| 32 |
-
"Pine Siskin",
|
| 33 |
-
"Indigo Bunting",
|
| 34 |
-
"Eastern Towhee",
|
| 35 |
-
"Carolina Chickadee",
|
| 36 |
-
"Great Crested Flycatcher",
|
| 37 |
-
"Eastern Wood-Pewee",
|
| 38 |
-
"Ovenbird",
|
| 39 |
-
"Northern Flicker",
|
| 40 |
-
"Red-eyed Vireo",
|
| 41 |
-
"American Woodcock",
|
| 42 |
-
"Eastern Phoebe",
|
| 43 |
-
"Downy Woodpecker",
|
| 44 |
-
"Scarlet Tanager",
|
| 45 |
-
"Yellow Warbler",
|
| 46 |
-
"White-eyed Vireo",
|
| 47 |
-
"Common Loon",
|
| 48 |
-
"White-throated Sparrow",
|
| 49 |
-
"Yellow-throated Vireo",
|
| 50 |
-
"Great Blue Heron",
|
| 51 |
-
"Belted Kingfisher",
|
| 52 |
-
"Pied-billed Grebe",
|
| 53 |
-
"Wild Turkey",
|
| 54 |
-
"Wood Thrush",
|
| 55 |
-
"Rose-breasted Grosbeak",
|
| 56 |
-
"Field Sparrow",
|
| 57 |
-
"Hooded Warbler",
|
| 58 |
-
"Northern Parula",
|
| 59 |
-
"Chestnut-sided Warbler",
|
| 60 |
-
"Blue-winged Warbler",
|
| 61 |
-
"Red-bellied Woodpecker",
|
| 62 |
-
"Yellow-billed Cuckoo",
|
| 63 |
-
"Gray Catbird",
|
| 64 |
-
"Northern Saw-whet Owl",
|
| 65 |
-
"Osprey",
|
| 66 |
-
"Common Nighthawk",
|
| 67 |
-
"Broad-winged Hawk",
|
| 68 |
-
"Black-throated Green Warbler",
|
| 69 |
-
"Great Horned Owl",
|
| 70 |
-
"Common Raven",
|
| 71 |
-
"Barred Owl",
|
| 72 |
-
"Canada Warbler",
|
| 73 |
-
"Magnolia Warbler",
|
| 74 |
-
"Black-and-white Warbler",
|
| 75 |
-
"Eastern Kingbird",
|
| 76 |
-
"Swainson's Thrush",
|
| 77 |
-
"Worm-eating Warbler",
|
| 78 |
-
"Prairie Warbler",
|
| 79 |
-
"Baltimore Oriole",
|
| 80 |
-
"Black-throated Blue Warbler",
|
| 81 |
-
"Louisiana Waterthrush",
|
| 82 |
-
"Blackburnian Warbler",
|
| 83 |
-
"Black-capped Chickadee",
|
| 84 |
-
"Cerulean Warbler",
|
| 85 |
-
"Red-shouldered Hawk",
|
| 86 |
-
"Cooper's Hawk",
|
| 87 |
-
"Yellow-throated Warbler",
|
| 88 |
-
"Blue-headed Vireo",
|
| 89 |
-
"Blackpoll Warbler",
|
| 90 |
-
"Ruffed Grouse",
|
| 91 |
-
"Kentucky Warbler",
|
| 92 |
-
"Hermit Thrush",
|
| 93 |
-
"Cedar Waxwing",
|
| 94 |
-
"Eastern Screech-Owl",
|
| 95 |
-
"Northern Goshawk",
|
| 96 |
-
"Green Heron",
|
| 97 |
-
"Red-tailed Hawk",
|
| 98 |
-
"Black Vulture",
|
| 99 |
-
"Hairy Woodpecker",
|
| 100 |
-
"Golden-crowned Kinglet",
|
| 101 |
-
"Ruby-crowned Kinglet",
|
| 102 |
-
"Bicknell's Thrush",
|
| 103 |
-
"Blue-gray Gnatcatcher",
|
| 104 |
-
"Veery",
|
| 105 |
-
"Pileated Woodpecker",
|
| 106 |
-
"Purple Finch",
|
| 107 |
-
"White-crowned Sparrow",
|
| 108 |
-
"Snow Bunting",
|
| 109 |
-
"Pine Grosbeak",
|
| 110 |
-
"American Tree Sparrow",
|
| 111 |
-
"Dark-eyed Junco",
|
| 112 |
-
"Snowy Owl",
|
| 113 |
-
"White-winged Crossbill",
|
| 114 |
-
"Red Crossbill",
|
| 115 |
-
"Common Redpoll",
|
| 116 |
-
"Northern Shrike",
|
| 117 |
-
"Northern Harrier",
|
| 118 |
-
"Rough-legged Hawk",
|
| 119 |
-
"Long-eared Owl",
|
| 120 |
-
"Evening Grosbeak",
|
| 121 |
-
"Northern Pintail",
|
| 122 |
-
"American Black Duck",
|
| 123 |
-
"Mallard",
|
| 124 |
-
"Canvasback",
|
| 125 |
-
"Redhead",
|
| 126 |
-
"Ring-necked Duck",
|
| 127 |
-
"Greater Scaup",
|
| 128 |
-
"Lesser Scaup",
|
| 129 |
-
"Bufflehead",
|
| 130 |
-
"Common Goldeneye",
|
| 131 |
-
"Hooded Merganser",
|
| 132 |
-
"Common Merganser",
|
| 133 |
-
"Red-breasted Merganser",
|
| 134 |
-
"Ruddy Duck",
|
| 135 |
-
"Wood Duck",
|
| 136 |
-
"Gadwall",
|
| 137 |
-
"American Wigeon",
|
| 138 |
-
"Northern Shoveler",
|
| 139 |
-
"Green-winged Teal",
|
| 140 |
-
"Blue-winged Teal",
|
| 141 |
-
"Cinnamon Teal",
|
| 142 |
-
"Ringed Teal",
|
| 143 |
-
"Cape Teal",
|
| 144 |
-
"Northern Fulmar",
|
| 145 |
-
"Yellow-billed Loon",
|
| 146 |
-
"Red-throated Loon",
|
| 147 |
-
"Arctic Loon",
|
| 148 |
-
"Pacific Loon",
|
| 149 |
-
"Horned Grebe",
|
| 150 |
-
"Red-necked Grebe",
|
| 151 |
-
"Eared Grebe",
|
| 152 |
-
"Western Grebe",
|
| 153 |
-
"Clark's Grebe",
|
| 154 |
-
"Double-crested Cormorant",
|
| 155 |
-
"Pelagic Cormorant",
|
| 156 |
-
"Great Cormorant",
|
| 157 |
-
"American White Pelican",
|
| 158 |
-
"Brown Pelican",
|
| 159 |
-
"Brandt's Cormorant",
|
| 160 |
-
"Least Bittern",
|
| 161 |
-
"Great Egret",
|
| 162 |
-
"Snowy Egret",
|
| 163 |
-
"Little Blue Heron",
|
| 164 |
-
"Tricolored Heron",
|
| 165 |
-
"Reddish Egret",
|
| 166 |
-
"Black-crowned Night-Heron",
|
| 167 |
-
"Yellow-crowned Night-Heron",
|
| 168 |
-
"White Ibis",
|
| 169 |
-
"Glossy Ibis",
|
| 170 |
-
"Roseate Spoonbill",
|
| 171 |
-
"Wood Stork",
|
| 172 |
-
"Black-bellied Whistling-Duck",
|
| 173 |
-
"Fulvous Whistling-Duck",
|
| 174 |
-
"Greater White-fronted Goose",
|
| 175 |
-
"Snow Goose",
|
| 176 |
-
"Ross's Goose",
|
| 177 |
-
"Canada Goose",
|
| 178 |
-
"Brant",
|
| 179 |
-
"Mute Swan",
|
| 180 |
-
"Tundra Swan",
|
| 181 |
-
"Whooper Swan",
|
| 182 |
-
"Sandhill Crane",
|
| 183 |
-
"Black-necked Stilt",
|
| 184 |
-
"American Avocet",
|
| 185 |
-
"Northern Jacana",
|
| 186 |
-
"Greater Yellowlegs",
|
| 187 |
-
"Lesser Yellowlegs",
|
| 188 |
-
"Willet",
|
| 189 |
-
"Spotted Sandpiper",
|
| 190 |
-
"Upland Sandpiper",
|
| 191 |
-
"Whimbrel",
|
| 192 |
-
"Long-billed Curlew",
|
| 193 |
-
"Marbled Godwit",
|
| 194 |
-
"Ruddy Turnstone",
|
| 195 |
-
"Red Knot",
|
| 196 |
-
"Sanderling",
|
| 197 |
-
"Semipalmated Sandpiper",
|
| 198 |
-
"Western Sandpiper",
|
| 199 |
-
"Least Sandpiper",
|
| 200 |
-
"White-rumped Sandpiper",
|
| 201 |
-
"Baird's Sandpiper",
|
| 202 |
-
"Pectoral Sandpiper",
|
| 203 |
-
"Dunlin",
|
| 204 |
-
"Buff-breasted Sandpiper",
|
| 205 |
-
"Short-billed Dowitcher",
|
| 206 |
-
"Long-billed Dowitcher",
|
| 207 |
-
"Common Snipe",
|
| 208 |
-
"American Woodcock",
|
| 209 |
-
"Wilson's Phalarope",
|
| 210 |
-
"Red-necked Phalarope",
|
| 211 |
-
"Red Phalarope"
|
| 212 |
-
]
|
| 213 |
-
|
| 214 |
-
from pathlib import Path
|
| 215 |
-
|
| 216 |
-
def remove_spaces(s):
|
| 217 |
-
return s.replace(" ", "")
|
| 218 |
-
|
| 219 |
-
for species in SPECIES:
|
| 220 |
-
if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
|
| 221 |
-
continue
|
| 222 |
-
try:
|
| 223 |
-
q = Query(
|
| 224 |
-
name=species, q="A", length="10-30",
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
-
# retrieve metadata
|
| 228 |
-
metafiles = q.retrieve_meta(verbose=True)
|
| 229 |
-
# retrieve recordings
|
| 230 |
-
q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
|
| 231 |
-
|
| 232 |
-
except:
|
| 233 |
-
print("Failed to download " + species)
|
| 234 |
-
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
setup.py
CHANGED
|
@@ -28,13 +28,12 @@ setup(
|
|
| 28 |
install_requires=[
|
| 29 |
"torch",
|
| 30 |
"argbind>=0.3.2",
|
| 31 |
-
"numpy==1.
|
| 32 |
"wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
|
| 33 |
"lac @ git+https://github.com/hugofloresgarcia/lac.git",
|
| 34 |
"descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2",
|
| 35 |
"gradio",
|
|
|
|
| 36 |
"loralib",
|
| 37 |
-
"torch_pitch_shift",
|
| 38 |
-
"madmom",
|
| 39 |
],
|
| 40 |
)
|
|
|
|
| 28 |
install_requires=[
|
| 29 |
"torch",
|
| 30 |
"argbind>=0.3.2",
|
| 31 |
+
"numpy==1.22",
|
| 32 |
"wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat",
|
| 33 |
"lac @ git+https://github.com/hugofloresgarcia/lac.git",
|
| 34 |
"descript-audiotools @ git+https://github.com/descriptinc/audiotools.git@0.7.2",
|
| 35 |
"gradio",
|
| 36 |
+
"tensorboardX",
|
| 37 |
"loralib",
|
|
|
|
|
|
|
| 38 |
],
|
| 39 |
)
|
vampnet/interface.py
CHANGED
|
@@ -65,7 +65,7 @@ class Interface(torch.nn.Module):
|
|
| 65 |
):
|
| 66 |
super().__init__()
|
| 67 |
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
| 68 |
-
self.codec = DAC.load(
|
| 69 |
self.codec.eval()
|
| 70 |
self.codec.to(device)
|
| 71 |
|
|
@@ -120,16 +120,17 @@ class Interface(torch.nn.Module):
|
|
| 120 |
if coarse_ckpt is not None:
|
| 121 |
self.coarse.to("cpu")
|
| 122 |
state_dict = torch.load(coarse_ckpt, map_location="cpu")
|
| 123 |
-
|
| 124 |
self.coarse.load_state_dict(state_dict, strict=False)
|
| 125 |
self.coarse.to(self.device)
|
| 126 |
if c2f_ckpt is not None:
|
| 127 |
self.c2f.to("cpu")
|
| 128 |
state_dict = torch.load(c2f_ckpt, map_location="cpu")
|
| 129 |
-
|
| 130 |
self.c2f.load_state_dict(state_dict, strict=False)
|
| 131 |
self.c2f.to(self.device)
|
| 132 |
|
|
|
|
| 133 |
def s2t(self, seconds: float):
|
| 134 |
"""seconds to tokens"""
|
| 135 |
if isinstance(seconds, np.ndarray):
|
|
@@ -193,8 +194,8 @@ class Interface(torch.nn.Module):
|
|
| 193 |
|
| 194 |
def make_beat_mask(self,
|
| 195 |
signal: AudioSignal,
|
| 196 |
-
before_beat_s: float = 0.
|
| 197 |
-
after_beat_s: float = 0.
|
| 198 |
mask_downbeats: bool = True,
|
| 199 |
mask_upbeats: bool = True,
|
| 200 |
downbeat_downsample_factor: int = None,
|
|
|
|
| 65 |
):
|
| 66 |
super().__init__()
|
| 67 |
assert codec_ckpt is not None, "must provide a codec checkpoint"
|
| 68 |
+
self.codec = DAC.load(codec_ckpt)
|
| 69 |
self.codec.eval()
|
| 70 |
self.codec.to(device)
|
| 71 |
|
|
|
|
| 120 |
if coarse_ckpt is not None:
|
| 121 |
self.coarse.to("cpu")
|
| 122 |
state_dict = torch.load(coarse_ckpt, map_location="cpu")
|
| 123 |
+
|
| 124 |
self.coarse.load_state_dict(state_dict, strict=False)
|
| 125 |
self.coarse.to(self.device)
|
| 126 |
if c2f_ckpt is not None:
|
| 127 |
self.c2f.to("cpu")
|
| 128 |
state_dict = torch.load(c2f_ckpt, map_location="cpu")
|
| 129 |
+
|
| 130 |
self.c2f.load_state_dict(state_dict, strict=False)
|
| 131 |
self.c2f.to(self.device)
|
| 132 |
|
| 133 |
+
|
| 134 |
def s2t(self, seconds: float):
|
| 135 |
"""seconds to tokens"""
|
| 136 |
if isinstance(seconds, np.ndarray):
|
|
|
|
| 194 |
|
| 195 |
def make_beat_mask(self,
|
| 196 |
signal: AudioSignal,
|
| 197 |
+
before_beat_s: float = 0.1,
|
| 198 |
+
after_beat_s: float = 0.1,
|
| 199 |
mask_downbeats: bool = True,
|
| 200 |
mask_upbeats: bool = True,
|
| 201 |
downbeat_downsample_factor: int = None,
|
vampnet/mask.py
CHANGED
|
@@ -191,47 +191,29 @@ def onset_mask(
|
|
| 191 |
width: int = 1
|
| 192 |
):
|
| 193 |
import librosa
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
if onset_indices.shape[0] == 0:
|
| 214 |
-
mask = empty_mask(z)
|
| 215 |
-
print(f"no onsets found, returning empty mask")
|
| 216 |
-
else:
|
| 217 |
-
torch.set_printoptions(threshold=1000)
|
| 218 |
-
print("onset indices: ", onset_indices)
|
| 219 |
-
print("onset times: ", onset_times)
|
| 220 |
-
|
| 221 |
-
# create a mask, set onset
|
| 222 |
-
mask = torch.ones_like(z)
|
| 223 |
-
n_timesteps = z.shape[-1]
|
| 224 |
-
|
| 225 |
-
for onset_index in onset_indices:
|
| 226 |
-
onset_index = min(onset_index, n_timesteps - 1)
|
| 227 |
-
onset_index = max(onset_index, 0)
|
| 228 |
-
mask[:, :, onset_index - width:onset_index + width] = 0.0
|
| 229 |
-
|
| 230 |
-
print(mask)
|
| 231 |
|
| 232 |
return mask
|
| 233 |
|
| 234 |
|
| 235 |
|
| 236 |
if __name__ == "__main__":
|
| 237 |
-
|
|
|
|
|
|
| 191 |
width: int = 1
|
| 192 |
):
|
| 193 |
import librosa
|
| 194 |
+
|
| 195 |
+
onset_indices = librosa.onset.onset_detect(
|
| 196 |
+
y=sig.clone().to_mono().samples.cpu().numpy()[0, 0],
|
| 197 |
+
sr=sig.sample_rate,
|
| 198 |
+
hop_length=interface.codec.hop_length,
|
| 199 |
+
backtrack=True,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# create a mask, set onset
|
| 203 |
+
mask = torch.ones_like(z)
|
| 204 |
+
n_timesteps = z.shape[-1]
|
| 205 |
+
|
| 206 |
+
for onset_index in onset_indices:
|
| 207 |
+
onset_index = min(onset_index, n_timesteps - 1)
|
| 208 |
+
onset_index = max(onset_index, 0)
|
| 209 |
+
mask[:, :, onset_index - width:onset_index + width] = 0.0
|
| 210 |
+
|
| 211 |
+
print(mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
return mask
|
| 214 |
|
| 215 |
|
| 216 |
|
| 217 |
if __name__ == "__main__":
|
| 218 |
+
torch.set_printoptions(threshold=10000)
|
| 219 |
+
|
vampnet/modules/transformer.py
CHANGED
|
@@ -410,9 +410,7 @@ class TransformerStack(nn.Module):
|
|
| 410 |
def subsequent_mask(self, size):
|
| 411 |
return torch.ones(1, size, size).tril().bool()
|
| 412 |
|
| 413 |
-
def forward(self, x, x_mask, cond=None, src=None, src_mask=None
|
| 414 |
-
return_activations: bool = False
|
| 415 |
-
):
|
| 416 |
"""Computes a full transformer stack
|
| 417 |
Parameters
|
| 418 |
----------
|
|
@@ -439,8 +437,6 @@ class TransformerStack(nn.Module):
|
|
| 439 |
encoder_decoder_position_bias = None
|
| 440 |
|
| 441 |
# Compute transformer layers
|
| 442 |
-
if return_activations:
|
| 443 |
-
activations = []
|
| 444 |
for layer in self.layers:
|
| 445 |
x, position_bias, encoder_decoder_position_bias = layer(
|
| 446 |
x=x,
|
|
@@ -451,15 +447,8 @@ class TransformerStack(nn.Module):
|
|
| 451 |
position_bias=position_bias,
|
| 452 |
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
| 453 |
)
|
| 454 |
-
if return_activations:
|
| 455 |
-
activations.append(x.detach())
|
| 456 |
|
| 457 |
-
|
| 458 |
-
out = self.norm(x) if self.norm is not None else x
|
| 459 |
-
if return_activations:
|
| 460 |
-
return out, torch.stack(activations)
|
| 461 |
-
else:
|
| 462 |
-
return out
|
| 463 |
|
| 464 |
|
| 465 |
class VampNet(at.ml.BaseModel):
|
|
@@ -467,7 +456,7 @@ class VampNet(at.ml.BaseModel):
|
|
| 467 |
self,
|
| 468 |
n_heads: int = 20,
|
| 469 |
n_layers: int = 16,
|
| 470 |
-
r_cond_dim: int =
|
| 471 |
n_codebooks: int = 9,
|
| 472 |
n_conditioning_codebooks: int = 0,
|
| 473 |
latent_dim: int = 8,
|
|
@@ -478,7 +467,6 @@ class VampNet(at.ml.BaseModel):
|
|
| 478 |
dropout: float = 0.1
|
| 479 |
):
|
| 480 |
super().__init__()
|
| 481 |
-
assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
|
| 482 |
self.n_heads = n_heads
|
| 483 |
self.n_layers = n_layers
|
| 484 |
self.r_cond_dim = r_cond_dim
|
|
@@ -525,25 +513,21 @@ class VampNet(at.ml.BaseModel):
|
|
| 525 |
),
|
| 526 |
)
|
| 527 |
|
| 528 |
-
def forward(self, x,
|
| 529 |
x = self.embedding(x)
|
| 530 |
x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
|
| 531 |
|
| 532 |
-
|
| 533 |
-
out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
|
| 534 |
-
if return_activations:
|
| 535 |
-
out, activations = out
|
| 536 |
|
|
|
|
|
|
|
| 537 |
out = rearrange(out, "b n d -> b d n")
|
| 538 |
|
| 539 |
-
out = self.classifier(out,
|
| 540 |
|
| 541 |
out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
|
| 542 |
|
| 543 |
-
|
| 544 |
-
return out, activations
|
| 545 |
-
else:
|
| 546 |
-
return out
|
| 547 |
|
| 548 |
def r_embed(self, r, max_positions=10000):
|
| 549 |
if self.r_cond_dim > 0:
|
|
@@ -594,23 +578,22 @@ class VampNet(at.ml.BaseModel):
|
|
| 594 |
self,
|
| 595 |
codec,
|
| 596 |
time_steps: int = 300,
|
| 597 |
-
sampling_steps: int =
|
| 598 |
start_tokens: Optional[torch.Tensor] = None,
|
| 599 |
-
sampling_temperature: float = 1.0,
|
| 600 |
mask: Optional[torch.Tensor] = None,
|
| 601 |
-
|
| 602 |
typical_filtering=False,
|
| 603 |
typical_mass=0.2,
|
| 604 |
typical_min_tokens=1,
|
| 605 |
-
top_p=None,
|
| 606 |
return_signal=True,
|
| 607 |
-
seed: int = None,
|
| 608 |
-
sample_cutoff: float = 1.0,
|
| 609 |
):
|
| 610 |
-
if seed is not None:
|
| 611 |
-
at.util.seed(seed)
|
| 612 |
logging.debug(f"beginning generation with {sampling_steps} steps")
|
| 613 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 614 |
|
| 615 |
|
| 616 |
#####################
|
|
@@ -662,6 +645,9 @@ class VampNet(at.ml.BaseModel):
|
|
| 662 |
for i in range(sampling_steps):
|
| 663 |
logging.debug(f"step {i} of {sampling_steps}")
|
| 664 |
|
|
|
|
|
|
|
|
|
|
| 665 |
# our current schedule step
|
| 666 |
r = scalar_to_batch_tensor(
|
| 667 |
(i + 1) / sampling_steps,
|
|
@@ -676,24 +662,41 @@ class VampNet(at.ml.BaseModel):
|
|
| 676 |
|
| 677 |
# infer from latents
|
| 678 |
# NOTE: this collapses the codebook dimension into the sequence dimension
|
| 679 |
-
logits = self.forward(latents) # b, prob, seq
|
| 680 |
logits = logits.permute(0, 2, 1) # b, seq, prob
|
| 681 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
|
| 683 |
logging.debug(f"permuted logits with shape: {logits.shape}")
|
| 684 |
|
| 685 |
-
sampled_z, selected_probs = sample_from_logits(
|
| 686 |
-
logits, sample=(
|
| 687 |
-
(i / sampling_steps) <= sample_cutoff
|
| 688 |
-
),
|
| 689 |
-
temperature=sampling_temperature,
|
| 690 |
-
typical_filtering=typical_filtering, typical_mass=typical_mass,
|
| 691 |
-
typical_min_tokens=typical_min_tokens,
|
| 692 |
-
top_k=None, top_p=top_p, return_probs=True,
|
| 693 |
-
)
|
| 694 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 695 |
logging.debug(f"sampled z with shape: {sampled_z.shape}")
|
| 696 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 697 |
# flatten z_masked and mask, so we can deal with the sampling logic
|
| 698 |
# we'll unflatten them at the end of the loop for the next forward pass
|
| 699 |
# remove conditioning codebooks, we'll add them back at the end
|
|
@@ -730,7 +733,7 @@ class VampNet(at.ml.BaseModel):
|
|
| 730 |
|
| 731 |
# get our new mask
|
| 732 |
mask = mask_by_random_topk(
|
| 733 |
-
num_to_mask, selected_probs,
|
| 734 |
)
|
| 735 |
|
| 736 |
# update the mask
|
|
@@ -763,97 +766,8 @@ class VampNet(at.ml.BaseModel):
|
|
| 763 |
else:
|
| 764 |
return sampled_z
|
| 765 |
|
| 766 |
-
def sample_from_logits(
|
| 767 |
-
logits,
|
| 768 |
-
sample: bool = True,
|
| 769 |
-
temperature: float = 1.0,
|
| 770 |
-
top_k: int = None,
|
| 771 |
-
top_p: float = None,
|
| 772 |
-
typical_filtering: bool = False,
|
| 773 |
-
typical_mass: float = 0.2,
|
| 774 |
-
typical_min_tokens: int = 1,
|
| 775 |
-
return_probs: bool = False
|
| 776 |
-
):
|
| 777 |
-
"""Convenience function to sample from a categorial distribution with input as
|
| 778 |
-
unnormalized logits.
|
| 779 |
-
|
| 780 |
-
Parameters
|
| 781 |
-
----------
|
| 782 |
-
logits : Tensor[..., vocab_size]
|
| 783 |
-
config: SamplingConfig
|
| 784 |
-
The set of hyperparameters to be used for sampling
|
| 785 |
-
sample : bool, optional
|
| 786 |
-
Whether to perform multinomial sampling, by default True
|
| 787 |
-
temperature : float, optional
|
| 788 |
-
Scaling parameter when multinomial samping, by default 1.0
|
| 789 |
-
top_k : int, optional
|
| 790 |
-
Restricts sampling to only `top_k` values acc. to probability,
|
| 791 |
-
by default None
|
| 792 |
-
top_p : float, optional
|
| 793 |
-
Restricts sampling to only those values with cumulative
|
| 794 |
-
probability = `top_p`, by default None
|
| 795 |
-
|
| 796 |
-
Returns
|
| 797 |
-
-------
|
| 798 |
-
Tensor[...]
|
| 799 |
-
Sampled tokens
|
| 800 |
-
"""
|
| 801 |
-
shp = logits.shape[:-1]
|
| 802 |
-
|
| 803 |
-
if typical_filtering:
|
| 804 |
-
typical_filter(logits,
|
| 805 |
-
typical_mass=typical_mass,
|
| 806 |
-
typical_min_tokens=typical_min_tokens
|
| 807 |
-
)
|
| 808 |
-
|
| 809 |
-
# Apply top_k sampling
|
| 810 |
-
if top_k is not None:
|
| 811 |
-
v, _ = logits.topk(top_k)
|
| 812 |
-
logits[logits < v[..., [-1]]] = -float("inf")
|
| 813 |
-
|
| 814 |
-
# Apply top_p (nucleus) sampling
|
| 815 |
-
if top_p is not None and top_p < 1.0:
|
| 816 |
-
v, sorted_indices = logits.sort(descending=True)
|
| 817 |
-
cumulative_probs = v.softmax(dim=-1).cumsum(dim=-1)
|
| 818 |
-
|
| 819 |
-
sorted_indices_to_remove = cumulative_probs > top_p
|
| 820 |
-
# Right shift indices_to_remove to keep 1st token over threshold
|
| 821 |
-
sorted_indices_to_remove = F.pad(sorted_indices_to_remove, (1, 0), value=False)[
|
| 822 |
-
..., :-1
|
| 823 |
-
]
|
| 824 |
-
|
| 825 |
-
# Compute indices_to_remove in unsorted array
|
| 826 |
-
indices_to_remove = sorted_indices_to_remove.scatter(
|
| 827 |
-
-1, sorted_indices, sorted_indices_to_remove
|
| 828 |
-
)
|
| 829 |
-
|
| 830 |
-
logits[indices_to_remove] = -float("inf")
|
| 831 |
-
|
| 832 |
-
# Perform multinomial sampling after normalizing logits
|
| 833 |
-
probs = (
|
| 834 |
-
F.softmax(logits / temperature, dim=-1)
|
| 835 |
-
if temperature > 0
|
| 836 |
-
else logits.softmax(dim=-1)
|
| 837 |
-
)
|
| 838 |
-
token = (
|
| 839 |
-
probs.view(-1, probs.size(-1)).multinomial(1).squeeze(1).view(*shp)
|
| 840 |
-
if sample
|
| 841 |
-
else logits.argmax(-1)
|
| 842 |
-
)
|
| 843 |
-
|
| 844 |
-
if return_probs:
|
| 845 |
-
token_probs = probs.take_along_dim(token.unsqueeze(-1), dim=-1).squeeze(-1)
|
| 846 |
-
return token, token_probs
|
| 847 |
-
else:
|
| 848 |
-
return token
|
| 849 |
-
|
| 850 |
-
|
| 851 |
|
| 852 |
-
def mask_by_random_topk(
|
| 853 |
-
num_to_mask: int,
|
| 854 |
-
probs: torch.Tensor,
|
| 855 |
-
temperature: float = 1.0,
|
| 856 |
-
):
|
| 857 |
"""
|
| 858 |
Args:
|
| 859 |
num_to_mask (int): number of tokens to mask
|
|
@@ -866,8 +780,7 @@ def mask_by_random_topk(
|
|
| 866 |
logging.debug(f"temperature: {temperature}")
|
| 867 |
logging.debug("")
|
| 868 |
|
| 869 |
-
|
| 870 |
-
confidence = torch.log(probs) + temperature * noise
|
| 871 |
logging.debug(f"confidence shape: {confidence.shape}")
|
| 872 |
|
| 873 |
sorted_confidence, sorted_idx = confidence.sort(dim=-1)
|
|
@@ -937,7 +850,7 @@ if __name__ == "__main__":
|
|
| 937 |
z_mask_latent = torch.rand(
|
| 938 |
batch_size, model.latent_dim * model.n_codebooks, seq_len
|
| 939 |
).to(device)
|
| 940 |
-
z_hat = model(z_mask_latent)
|
| 941 |
|
| 942 |
pred = z_hat.argmax(dim=1)
|
| 943 |
pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
|
|
|
|
| 410 |
def subsequent_mask(self, size):
|
| 411 |
return torch.ones(1, size, size).tril().bool()
|
| 412 |
|
| 413 |
+
def forward(self, x, x_mask, cond=None, src=None, src_mask=None):
|
|
|
|
|
|
|
| 414 |
"""Computes a full transformer stack
|
| 415 |
Parameters
|
| 416 |
----------
|
|
|
|
| 437 |
encoder_decoder_position_bias = None
|
| 438 |
|
| 439 |
# Compute transformer layers
|
|
|
|
|
|
|
| 440 |
for layer in self.layers:
|
| 441 |
x, position_bias, encoder_decoder_position_bias = layer(
|
| 442 |
x=x,
|
|
|
|
| 447 |
position_bias=position_bias,
|
| 448 |
encoder_decoder_position_bias=encoder_decoder_position_bias,
|
| 449 |
)
|
|
|
|
|
|
|
| 450 |
|
| 451 |
+
return self.norm(x) if self.norm is not None else x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
|
| 453 |
|
| 454 |
class VampNet(at.ml.BaseModel):
|
|
|
|
| 456 |
self,
|
| 457 |
n_heads: int = 20,
|
| 458 |
n_layers: int = 16,
|
| 459 |
+
r_cond_dim: int = 64,
|
| 460 |
n_codebooks: int = 9,
|
| 461 |
n_conditioning_codebooks: int = 0,
|
| 462 |
latent_dim: int = 8,
|
|
|
|
| 467 |
dropout: float = 0.1
|
| 468 |
):
|
| 469 |
super().__init__()
|
|
|
|
| 470 |
self.n_heads = n_heads
|
| 471 |
self.n_layers = n_layers
|
| 472 |
self.r_cond_dim = r_cond_dim
|
|
|
|
| 513 |
),
|
| 514 |
)
|
| 515 |
|
| 516 |
+
def forward(self, x, cond):
|
| 517 |
x = self.embedding(x)
|
| 518 |
x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
|
| 519 |
|
| 520 |
+
cond = self.r_embed(cond)
|
|
|
|
|
|
|
|
|
|
| 521 |
|
| 522 |
+
x = rearrange(x, "b d n -> b n d")
|
| 523 |
+
out = self.transformer(x=x, x_mask=x_mask, cond=cond)
|
| 524 |
out = rearrange(out, "b n d -> b d n")
|
| 525 |
|
| 526 |
+
out = self.classifier(out, cond)
|
| 527 |
|
| 528 |
out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
|
| 529 |
|
| 530 |
+
return out
|
|
|
|
|
|
|
|
|
|
| 531 |
|
| 532 |
def r_embed(self, r, max_positions=10000):
|
| 533 |
if self.r_cond_dim > 0:
|
|
|
|
| 578 |
self,
|
| 579 |
codec,
|
| 580 |
time_steps: int = 300,
|
| 581 |
+
sampling_steps: int = 24,
|
| 582 |
start_tokens: Optional[torch.Tensor] = None,
|
|
|
|
| 583 |
mask: Optional[torch.Tensor] = None,
|
| 584 |
+
temperature: Union[float, Tuple[float, float]] = 2.5,
|
| 585 |
typical_filtering=False,
|
| 586 |
typical_mass=0.2,
|
| 587 |
typical_min_tokens=1,
|
|
|
|
| 588 |
return_signal=True,
|
|
|
|
|
|
|
| 589 |
):
|
|
|
|
|
|
|
| 590 |
logging.debug(f"beginning generation with {sampling_steps} steps")
|
| 591 |
|
| 592 |
+
#####################
|
| 593 |
+
# resolve temperature #
|
| 594 |
+
#####################
|
| 595 |
+
assert isinstance(temperature, float)
|
| 596 |
+
logging.debug(f"temperature: {temperature}")
|
| 597 |
|
| 598 |
|
| 599 |
#####################
|
|
|
|
| 645 |
for i in range(sampling_steps):
|
| 646 |
logging.debug(f"step {i} of {sampling_steps}")
|
| 647 |
|
| 648 |
+
# our current temperature
|
| 649 |
+
logging.debug(f"temperature: {temperature}")
|
| 650 |
+
|
| 651 |
# our current schedule step
|
| 652 |
r = scalar_to_batch_tensor(
|
| 653 |
(i + 1) / sampling_steps,
|
|
|
|
| 662 |
|
| 663 |
# infer from latents
|
| 664 |
# NOTE: this collapses the codebook dimension into the sequence dimension
|
| 665 |
+
logits = self.forward(latents, r) # b, prob, seq
|
| 666 |
logits = logits.permute(0, 2, 1) # b, seq, prob
|
| 667 |
+
if typical_filtering:
|
| 668 |
+
typical_filter(logits,
|
| 669 |
+
typical_mass=typical_mass,
|
| 670 |
+
typical_min_tokens=typical_min_tokens
|
| 671 |
+
)
|
| 672 |
+
|
| 673 |
|
| 674 |
logging.debug(f"permuted logits with shape: {logits.shape}")
|
| 675 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 676 |
|
| 677 |
+
# logits2probs
|
| 678 |
+
probs = torch.softmax(logits, dim=-1)
|
| 679 |
+
logging.debug(f"computed probs with shape: {probs.shape}")
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
# sample from logits with multinomial sampling
|
| 683 |
+
b = probs.shape[0]
|
| 684 |
+
probs = rearrange(probs, "b seq prob -> (b seq) prob")
|
| 685 |
+
|
| 686 |
+
sampled_z = torch.multinomial(probs, 1).squeeze(-1)
|
| 687 |
+
|
| 688 |
+
sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
|
| 689 |
+
probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
|
| 690 |
logging.debug(f"sampled z with shape: {sampled_z.shape}")
|
| 691 |
|
| 692 |
+
# get the confidences: which tokens did we sample?
|
| 693 |
+
selected_probs = (
|
| 694 |
+
torch.take_along_dim(
|
| 695 |
+
probs, sampled_z.long().unsqueeze(-1),
|
| 696 |
+
dim=-1
|
| 697 |
+
).squeeze(-1)
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
# flatten z_masked and mask, so we can deal with the sampling logic
|
| 701 |
# we'll unflatten them at the end of the loop for the next forward pass
|
| 702 |
# remove conditioning codebooks, we'll add them back at the end
|
|
|
|
| 733 |
|
| 734 |
# get our new mask
|
| 735 |
mask = mask_by_random_topk(
|
| 736 |
+
num_to_mask, selected_probs, temperature * (1-r)
|
| 737 |
)
|
| 738 |
|
| 739 |
# update the mask
|
|
|
|
| 766 |
else:
|
| 767 |
return sampled_z
|
| 768 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 769 |
|
| 770 |
+
def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 771 |
"""
|
| 772 |
Args:
|
| 773 |
num_to_mask (int): number of tokens to mask
|
|
|
|
| 780 |
logging.debug(f"temperature: {temperature}")
|
| 781 |
logging.debug("")
|
| 782 |
|
| 783 |
+
confidence = torch.log(probs) + temperature * gumbel_noise_like(probs)
|
|
|
|
| 784 |
logging.debug(f"confidence shape: {confidence.shape}")
|
| 785 |
|
| 786 |
sorted_confidence, sorted_idx = confidence.sort(dim=-1)
|
|
|
|
| 850 |
z_mask_latent = torch.rand(
|
| 851 |
batch_size, model.latent_dim * model.n_codebooks, seq_len
|
| 852 |
).to(device)
|
| 853 |
+
z_hat = model(z_mask_latent, r)
|
| 854 |
|
| 855 |
pred = z_hat.argmax(dim=1)
|
| 856 |
pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
|