orrp committed on
Commit
9583919
·
1 Parent(s): e823eac

Refactoring, linting, switching to pyproject.toml and Docker

Browse files
.dockerignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ .venv
2
+ .git
3
+ __pycache__
4
+ *.pyc
5
+ .ruff_cache
6
+ wham/vampnet/models/*
vampnet/.pre-commit-config.yaml → .pre-commit-config.yaml RENAMED
File without changes
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
3
+
4
+ # Install ffmpeg for audio processing, plus git (for the git-based dependencies) and build-essential
5
+ RUN apt-get update && apt-get install -y ffmpeg git build-essential && rm -rf /var/lib/apt/lists/*
6
+
7
+ WORKDIR /app
8
+
9
+ # Install dependencies using uv
10
+ COPY pyproject.toml .
11
+ RUN uv pip install --system .
12
+
13
+ # Copy your code and run the app
14
+ COPY . .
15
+ EXPOSE 7860
16
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
17
+ CMD ["python", "vampnet/app.py"]
README.md CHANGED
@@ -3,11 +3,9 @@ title: WhAM
3
  emoji: 🐋
4
  colorFrom: blue
5
  colorTo: indigo
6
- sdk: gradio
7
- app_file: vampnet/app.py
8
  pinned: false
9
  hardware: a10g-small
10
- python_version: "3.10"
11
  ---
12
 
13
  # WhAM: a Whale Acoustics Model
 
3
  emoji: 🐋
4
  colorFrom: blue
5
  colorTo: indigo
6
+ sdk: docker
 
7
  pinned: false
8
  hardware: a10g-small
 
9
  ---
10
 
11
  # WhAM: a Whale Acoustics Model
pyproject.toml CHANGED
@@ -11,7 +11,7 @@ authors = [
11
  { name = "Project CETI" }
12
  ]
13
  license = { text = "MIT" }
14
- requires-python = ">=3.9"
15
  dependencies = [
16
  "torch",
17
  "gradio",
@@ -33,10 +33,10 @@ dependencies = [
33
  "gdown",
34
  "transformers",
35
  "fadtk",
36
- "urllib3==2.0",
37
  "plotly",
38
  "pyharp",
39
- # Git-based dependencies
40
  "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat.git",
41
  "lac @ git+https://github.com/hugofloresgarcia/lac.git",
42
  "descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git"
@@ -47,15 +47,13 @@ where = ["."]
47
  include = ["wham*", "vampnet*"]
48
 
49
  [tool.ruff]
50
- # Target Python 3.9+
51
- target-version = "py39"
52
  line-length = 88
53
 
54
  [tool.ruff.lint]
55
- # Enable Pyflakes (F), pycodestyle (E, W), and isort (I)
56
  select = ["E", "F", "W", "I"]
57
- ignore = []
58
 
59
- [tool.ruff.format]
60
- quote-style = "double"
61
- indent-style = "space"
 
11
  { name = "Project CETI" }
12
  ]
13
  license = { text = "MIT" }
14
+ requires-python = ">=3.10,<3.11"
15
  dependencies = [
16
  "torch",
17
  "gradio",
 
33
  "gdown",
34
  "transformers",
35
  "fadtk",
36
+ "urllib3>=2.0.2",
37
  "plotly",
38
  "pyharp",
39
+ "ruff",
40
  "wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat.git",
41
  "lac @ git+https://github.com/hugofloresgarcia/lac.git",
42
  "descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git"
 
47
  include = ["wham*", "vampnet*"]
48
 
49
  [tool.ruff]
50
+ target-version = "py310"
 
51
  line-length = 88
52
 
53
  [tool.ruff.lint]
 
54
  select = ["E", "F", "W", "I"]
55
+ fixable = ["ALL"]
56
 
57
+ [tool.ruff.lint.isort]
58
+ known-first-party = ["wham", "vampnet"]
59
+ section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
vampnet/app.py CHANGED
@@ -1,43 +1,40 @@
1
  import os
2
  import sys
 
 
3
 
 
 
 
 
 
 
 
 
 
4
 
5
  SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
6
  os.chdir(SCRIPT_DIR)
7
 
8
- import torch
9
- device = "cuda" if torch.cuda.is_available()
10
  sys.argv = ["app.py", "--args.load", "conf/interface.yml", "--Interface.device", device]
11
 
12
- from pathlib import Path
13
- from typing import Tuple
14
- import yaml
15
- import tempfile
16
- import uuid
17
- from dataclasses import dataclass, asdict
18
-
19
- import numpy as np
20
- import audiotools as at
21
- import argbind
22
-
23
- import gradio as gr
24
- from vampnet.interface import Interface
25
- from vampnet import mask as pmask
26
 
27
  Interface = argbind.bind(Interface)
28
 
29
  conf = argbind.parse_args()
30
 
31
 
32
- from torch_pitch_shift import pitch_shift, get_fast_shifts
 
 
33
  def shift_pitch(signal, interval: int):
34
  signal.samples = pitch_shift(
35
- signal.samples,
36
- shift=interval,
37
- sample_rate=signal.sample_rate
38
  )
39
  return signal
40
 
 
41
  def load_interface():
42
  with argbind.scope(conf):
43
  interface = Interface()
@@ -46,8 +43,6 @@ def load_interface():
46
  return interface
47
 
48
 
49
-
50
-
51
  interface = load_interface()
52
 
53
 
@@ -59,8 +54,7 @@ def load_audio(file):
59
  print(file)
60
  filepath = file.name
61
  sig = at.AudioSignal.salient_excerpt(
62
- filepath,
63
- duration=interface.coarse.chunk_size_s
64
  )
65
  sig = interface.preprocess(sig)
66
 
@@ -121,19 +115,10 @@ def _vamp(
121
  # build the mask
122
  mask = pmask.linear_random(z, _rand_mask_intensity)
123
  mask = pmask.mask_and(
124
- mask, pmask.inpaint(
125
- z,
126
- interface.s2t(_prefix_s),
127
- interface.s2t(_suffix_s)
128
- )
129
  )
130
  mask = pmask.mask_and(
131
- mask, pmask.periodic_mask(
132
- z,
133
- _periodic_p,
134
- _periodic_w,
135
- random_roll=True
136
- )
137
  )
138
  if _onset_mask_width > 0:
139
  mask = pmask.mask_or(
@@ -142,7 +127,7 @@ def _vamp(
142
  if _beat_mask_width > 0:
143
  beat_mask = interface.make_beat_mask(
144
  sig,
145
- after_beat_s=(_beat_mask_width/1000),
146
  mask_upbeats=not _beat_mask_downbeats,
147
  )
148
  mask = pmask.mask_and(mask, beat_mask)
@@ -174,14 +159,14 @@ def _vamp(
174
 
175
  _top_p_val = _top_p if _top_p > 0 else None
176
  # save the mask as a txt file
177
- np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
178
 
179
  _seed_val = _seed if _seed > 0 else None
180
  zv, mask_z = interface.coarse_vamp(
181
  z,
182
  mask=mask,
183
  sampling_steps=_num_steps,
184
- mask_temperature=_masktemp*10,
185
  sampling_temperature=_sampletemp,
186
  return_mask=True,
187
  typical_filtering=_typical_filtering,
@@ -196,7 +181,7 @@ def _vamp(
196
  if _use_coarse2fine:
197
  zv = interface.coarse_to_fine(
198
  zv,
199
- mask_temperature=_masktemp*10,
200
  sampling_temperature=_sampletemp,
201
  mask=mask,
202
  sampling_steps=_num_steps,
@@ -220,6 +205,7 @@ def _vamp(
220
  else:
221
  return sig.path_to_file
222
 
 
223
  def _extract_and_call_vamp(data, return_mask):
224
  """Extract plain values from Gradio data dict so only picklable args cross the ZeroGPU boundary."""
225
  return _vamp(
@@ -250,12 +236,15 @@ def _extract_and_call_vamp(data, return_mask):
250
  return_mask=return_mask,
251
  )
252
 
 
253
  def vamp(data):
254
  return _extract_and_call_vamp(data, return_mask=True)
255
 
 
256
  def api_vamp(data):
257
  return _extract_and_call_vamp(data, return_mask=False)
258
 
 
259
  def save_vamp(data):
260
  out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
261
  out_dir.mkdir(parents=True, exist_ok=True)
@@ -289,6 +278,7 @@ def save_vamp(data):
289
  yaml.dump(_data, f)
290
 
291
  import zipfile
 
292
  zip_path = str(out_dir.with_suffix(".zip"))
293
  with zipfile.ZipFile(zip_path, "w") as zf:
294
  for file in out_dir.iterdir():
@@ -312,7 +302,7 @@ def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
312
  if _beat_mask_width > 0:
313
  beat_mask = interface.make_beat_mask(
314
  sig,
315
- after_beat_s=(_beat_mask_width/1000),
316
  )
317
  mask = pmask.mask_and(mask, beat_mask)
318
 
@@ -325,7 +315,6 @@ def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
325
  gen_fn=interface.coarse.generate,
326
  )
327
 
328
-
329
  zv = interface.coarse_to_fine(
330
  zv,
331
  sampling_temperature=_sampletemp,
@@ -339,8 +328,8 @@ def harp_vamp(_input_audio, _beat_mask_width, _sampletemp):
339
 
340
  return sig.path_to_file
341
 
342
- with gr.Blocks() as demo:
343
 
 
344
  with gr.Row():
345
  with gr.Column():
346
  gr.Markdown("# VampNet Audio Vamping")
@@ -360,11 +349,9 @@ with gr.Blocks() as demo:
360
  """)
361
  with gr.Row():
362
  with gr.Column():
363
-
364
-
365
  manual_audio_upload = gr.File(
366
  label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
367
- file_types=["audio"]
368
  )
369
  load_example_audio_button = gr.Button("or load example audio")
370
 
@@ -382,71 +369,65 @@ with gr.Blocks() as demo:
382
 
383
  # connect widgets
384
  load_example_audio_button.click(
385
- fn=load_example_audio,
386
- inputs=[],
387
- outputs=[ input_audio]
388
  )
389
 
390
  manual_audio_upload.change(
391
- fn=load_audio,
392
- inputs=[manual_audio_upload],
393
- outputs=[ input_audio]
394
  )
395
 
396
  # mask settings
397
  with gr.Column():
398
-
399
-
400
  presets = {
401
- "unconditional": {
402
- "periodic_p": 0,
403
- "onset_mask_width": 0,
404
- "beat_mask_width": 0,
405
- "beat_mask_downbeats": False,
406
- },
407
- "slight periodic variation": {
408
- "periodic_p": 5,
409
- "onset_mask_width": 5,
410
- "beat_mask_width": 0,
411
- "beat_mask_downbeats": False,
412
- },
413
- "moderate periodic variation": {
414
- "periodic_p": 13,
415
- "onset_mask_width": 5,
416
- "beat_mask_width": 0,
417
- "beat_mask_downbeats": False,
418
- },
419
- "strong periodic variation": {
420
- "periodic_p": 17,
421
- "onset_mask_width": 5,
422
- "beat_mask_width": 0,
423
- "beat_mask_downbeats": False,
424
- },
425
- "very strong periodic variation": {
426
- "periodic_p": 21,
427
- "onset_mask_width": 5,
428
- "beat_mask_width": 0,
429
- "beat_mask_downbeats": False,
430
- },
431
- "beat-driven variation": {
432
- "periodic_p": 0,
433
- "onset_mask_width": 0,
434
- "beat_mask_width": 50,
435
- "beat_mask_downbeats": False,
436
- },
437
- "beat-driven variation (downbeats only)": {
438
- "periodic_p": 0,
439
- "onset_mask_width": 0,
440
- "beat_mask_width": 50,
441
- "beat_mask_downbeats": True,
442
- },
443
- "beat-driven variation (downbeats only, strong)": {
444
- "periodic_p": 0,
445
- "onset_mask_width": 0,
446
- "beat_mask_width": 20,
447
- "beat_mask_downbeats": True,
448
- },
449
- }
450
 
451
  preset = gr.Dropdown(
452
  label="preset",
@@ -464,7 +445,6 @@ with gr.Blocks() as demo:
464
  value=3,
465
  )
466
 
467
-
468
  onset_mask_width = gr.Slider(
469
  label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
470
  minimum=0,
@@ -480,8 +460,7 @@ with gr.Blocks() as demo:
480
  value=0,
481
  )
482
  beat_mask_downbeats = gr.Checkbox(
483
- label="beat mask downbeats only?",
484
- value=False
485
  )
486
 
487
  n_mask_codebooks = gr.Number(
@@ -489,7 +468,6 @@ with gr.Blocks() as demo:
489
  value=9,
490
  )
491
 
492
-
493
  with gr.Accordion("extras ", open=False):
494
  pitch_shift_amt = gr.Slider(
495
  label="pitch shift amount (semitones)",
@@ -503,7 +481,7 @@ with gr.Blocks() as demo:
503
  label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
504
  minimum=0.0,
505
  maximum=1.0,
506
- value=1.0
507
  )
508
 
509
  periodic_w = gr.Slider(
@@ -538,78 +516,62 @@ with gr.Blocks() as demo:
538
  return tuple(presets[_preset].values())
539
 
540
  load_preset_button.click(
541
- fn=load_preset,
542
- inputs=[preset],
543
- outputs=preset_outputs
544
  )
545
 
546
-
547
  with gr.Accordion("prefix/suffix prompts", open=False):
548
  prefix_s = gr.Slider(
549
  label="prefix hint length (seconds)",
550
  minimum=0.0,
551
  maximum=10.0,
552
- value=0.0
553
  )
554
  suffix_s = gr.Slider(
555
  label="suffix hint length (seconds)",
556
  minimum=0.0,
557
  maximum=10.0,
558
- value=0.0
559
  )
560
 
561
  masktemp = gr.Slider(
562
- label="mask temperature",
563
- minimum=0.0,
564
- maximum=100.0,
565
- value=1.5
566
  )
567
  sampletemp = gr.Slider(
568
  label="sample temperature",
569
  minimum=0.1,
570
  maximum=10.0,
571
  value=1.0,
572
- step=0.001
573
  )
574
 
575
-
576
-
577
  with gr.Accordion("sampling settings", open=False):
578
  top_p = gr.Slider(
579
- label="top p (0.0 = off)",
580
- minimum=0.0,
581
- maximum=1.0,
582
- value=0.0
583
- )
584
- typical_filtering = gr.Checkbox(
585
- label="typical filtering ",
586
- value=False
587
  )
 
588
  typical_mass = gr.Slider(
589
  label="typical mass (should probably stay between 0.1 and 0.5)",
590
  minimum=0.01,
591
  maximum=0.99,
592
- value=0.15
593
  )
594
  typical_min_tokens = gr.Slider(
595
  label="typical min tokens (should probably stay between 1 and 256)",
596
  minimum=1,
597
  maximum=256,
598
  step=1,
599
- value=64
600
  )
601
  sample_cutoff = gr.Slider(
602
  label="sample cutoff",
603
  minimum=0.0,
604
  maximum=1.0,
605
  value=0.5,
606
- step=0.01
607
  )
608
 
609
  use_coarse2fine = gr.Checkbox(
610
- label="use coarse2fine",
611
- value=True,
612
- visible=False
613
  )
614
 
615
  num_steps = gr.Slider(
@@ -617,29 +579,21 @@ with gr.Blocks() as demo:
617
  minimum=1,
618
  maximum=128,
619
  step=1,
620
- value=36
621
  )
622
 
623
  dropout = gr.Slider(
624
- label="mask dropout",
625
- minimum=0.0,
626
- maximum=1.0,
627
- step=0.01,
628
- value=0.0
629
  )
630
 
631
-
632
  seed = gr.Number(
633
  label="seed (0 for random)",
634
  value=0,
635
  precision=0,
636
  )
637
 
638
-
639
-
640
  # mask settings
641
  with gr.Column():
642
-
643
  # lora_choice = gr.Dropdown(
644
  # label="lora choice",
645
  # choices=list(loras.keys()),
@@ -649,51 +603,49 @@ with gr.Blocks() as demo:
649
 
650
  vamp_button = gr.Button("generate (vamp)!!!")
651
  output_audio = gr.Audio(
652
- label="output audio",
653
- interactive=False,
654
- type="filepath"
655
  )
656
 
657
  notes_text = gr.Textbox(
658
  label="type any notes about the generated audio here",
659
  value="",
660
- interactive=True
661
  )
662
  save_button = gr.Button("save vamp")
663
  download_file = gr.File(
664
- label="vamp to download will appear here",
665
- interactive=False
666
  )
667
  use_as_input_button = gr.Button("use output as input")
668
 
669
  thank_you = gr.Markdown("")
670
 
671
-
672
  _inputs = {
673
- input_audio,
674
- num_steps,
675
- masktemp,
676
- sampletemp,
677
- top_p,
678
- prefix_s, suffix_s,
679
- rand_mask_intensity,
680
- periodic_p, periodic_w,
681
- n_conditioning_codebooks,
682
- dropout,
683
- use_coarse2fine,
684
- stretch_factor,
685
- onset_mask_width,
686
- typical_filtering,
687
- typical_mass,
688
- typical_min_tokens,
689
- beat_mask_width,
690
- beat_mask_downbeats,
691
- seed,
692
- # lora_choice,
693
- n_mask_codebooks,
694
- pitch_shift_amt,
695
- sample_cutoff
696
- }
 
 
697
 
698
  # connect widgets
699
  vamp_button.click(
@@ -704,22 +656,17 @@ with gr.Blocks() as demo:
704
 
705
  api_vamp_button = gr.Button("api vamp", visible=False)
706
  api_vamp_button.click(
707
- fn=api_vamp,
708
- inputs=_inputs,
709
- outputs=[output_audio],
710
- api_name="vamp"
711
  )
712
 
713
  use_as_input_button.click(
714
- fn=lambda x: x,
715
- inputs=[output_audio],
716
- outputs=[input_audio]
717
  )
718
 
719
  save_button.click(
720
  fn=save_vamp,
721
  inputs=_inputs | {notes_text, output_audio},
722
- outputs=[thank_you, download_file]
723
  )
724
 
725
 
 
1
  import os
2
  import sys
3
+ import uuid
4
+ from pathlib import Path
5
 
6
+ import argbind
7
+ import audiotools as at
8
+ import gradio as gr
9
+ import numpy as np
10
+ import torch
11
+ import yaml
12
+
13
+ from vampnet import mask as pmask
14
+ from vampnet.interface import Interface
15
 
16
  SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
17
  os.chdir(SCRIPT_DIR)
18
 
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
 
20
  sys.argv = ["app.py", "--args.load", "conf/interface.yml", "--Interface.device", device]
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  Interface = argbind.bind(Interface)
24
 
25
  conf = argbind.parse_args()
26
 
27
 
28
+ from torch_pitch_shift import pitch_shift
29
+
30
+
31
  def shift_pitch(signal, interval: int):
32
  signal.samples = pitch_shift(
33
+ signal.samples, shift=interval, sample_rate=signal.sample_rate
 
 
34
  )
35
  return signal
36
 
37
+
38
  def load_interface():
39
  with argbind.scope(conf):
40
  interface = Interface()
 
43
  return interface
44
 
45
 
 
 
46
  interface = load_interface()
47
 
48
 
 
54
  print(file)
55
  filepath = file.name
56
  sig = at.AudioSignal.salient_excerpt(
57
+ filepath, duration=interface.coarse.chunk_size_s
 
58
  )
59
  sig = interface.preprocess(sig)
60
 
 
115
  # build the mask
116
  mask = pmask.linear_random(z, _rand_mask_intensity)
117
  mask = pmask.mask_and(
118
+ mask, pmask.inpaint(z, interface.s2t(_prefix_s), interface.s2t(_suffix_s))
 
 
 
 
119
  )
120
  mask = pmask.mask_and(
121
+ mask, pmask.periodic_mask(z, _periodic_p, _periodic_w, random_roll=True)
 
 
 
 
 
122
  )
123
  if _onset_mask_width > 0:
124
  mask = pmask.mask_or(
 
127
  if _beat_mask_width > 0:
128
  beat_mask = interface.make_beat_mask(
129
  sig,
130
+ after_beat_s=(_beat_mask_width / 1000),
131
  mask_upbeats=not _beat_mask_downbeats,
132
  )
133
  mask = pmask.mask_and(mask, beat_mask)
 
159
 
160
  _top_p_val = _top_p if _top_p > 0 else None
161
  # save the mask as a txt file
162
+ np.savetxt(out_dir / "mask.txt", mask[:, 0, :].long().cpu().numpy())
163
 
164
  _seed_val = _seed if _seed > 0 else None
165
  zv, mask_z = interface.coarse_vamp(
166
  z,
167
  mask=mask,
168
  sampling_steps=_num_steps,
169
+ mask_temperature=_masktemp * 10,
170
  sampling_temperature=_sampletemp,
171
  return_mask=True,
172
  typical_filtering=_typical_filtering,
 
181
  if _use_coarse2fine:
182
  zv = interface.coarse_to_fine(
183
  zv,
184
+ mask_temperature=_masktemp * 10,
185
  sampling_temperature=_sampletemp,
186
  mask=mask,
187
  sampling_steps=_num_steps,
 
205
  else:
206
  return sig.path_to_file
207
 
208
+
209
  def _extract_and_call_vamp(data, return_mask):
210
  """Extract plain values from Gradio data dict so only picklable args cross the ZeroGPU boundary."""
211
  return _vamp(
 
236
  return_mask=return_mask,
237
  )
238
 
239
+
240
  def vamp(data):
241
  return _extract_and_call_vamp(data, return_mask=True)
242
 
243
+
244
  def api_vamp(data):
245
  return _extract_and_call_vamp(data, return_mask=False)
246
 
247
+
248
  def save_vamp(data):
249
  out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
250
  out_dir.mkdir(parents=True, exist_ok=True)
 
278
  yaml.dump(_data, f)
279
 
280
  import zipfile
281
+
282
  zip_path = str(out_dir.with_suffix(".zip"))
283
  with zipfile.ZipFile(zip_path, "w") as zf:
284
  for file in out_dir.iterdir():
 
302
  if _beat_mask_width > 0:
303
  beat_mask = interface.make_beat_mask(
304
  sig,
305
+ after_beat_s=(_beat_mask_width / 1000),
306
  )
307
  mask = pmask.mask_and(mask, beat_mask)
308
 
 
315
  gen_fn=interface.coarse.generate,
316
  )
317
 
 
318
  zv = interface.coarse_to_fine(
319
  zv,
320
  sampling_temperature=_sampletemp,
 
328
 
329
  return sig.path_to_file
330
 
 
331
 
332
+ with gr.Blocks() as demo:
333
  with gr.Row():
334
  with gr.Column():
335
  gr.Markdown("# VampNet Audio Vamping")
 
349
  """)
350
  with gr.Row():
351
  with gr.Column():
 
 
352
  manual_audio_upload = gr.File(
353
  label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
354
+ file_types=["audio"],
355
  )
356
  load_example_audio_button = gr.Button("or load example audio")
357
 
 
369
 
370
  # connect widgets
371
  load_example_audio_button.click(
372
+ fn=load_example_audio, inputs=[], outputs=[input_audio]
 
 
373
  )
374
 
375
  manual_audio_upload.change(
376
+ fn=load_audio, inputs=[manual_audio_upload], outputs=[input_audio]
 
 
377
  )
378
 
379
  # mask settings
380
  with gr.Column():
 
 
381
  presets = {
382
+ "unconditional": {
383
+ "periodic_p": 0,
384
+ "onset_mask_width": 0,
385
+ "beat_mask_width": 0,
386
+ "beat_mask_downbeats": False,
387
+ },
388
+ "slight periodic variation": {
389
+ "periodic_p": 5,
390
+ "onset_mask_width": 5,
391
+ "beat_mask_width": 0,
392
+ "beat_mask_downbeats": False,
393
+ },
394
+ "moderate periodic variation": {
395
+ "periodic_p": 13,
396
+ "onset_mask_width": 5,
397
+ "beat_mask_width": 0,
398
+ "beat_mask_downbeats": False,
399
+ },
400
+ "strong periodic variation": {
401
+ "periodic_p": 17,
402
+ "onset_mask_width": 5,
403
+ "beat_mask_width": 0,
404
+ "beat_mask_downbeats": False,
405
+ },
406
+ "very strong periodic variation": {
407
+ "periodic_p": 21,
408
+ "onset_mask_width": 5,
409
+ "beat_mask_width": 0,
410
+ "beat_mask_downbeats": False,
411
+ },
412
+ "beat-driven variation": {
413
+ "periodic_p": 0,
414
+ "onset_mask_width": 0,
415
+ "beat_mask_width": 50,
416
+ "beat_mask_downbeats": False,
417
+ },
418
+ "beat-driven variation (downbeats only)": {
419
+ "periodic_p": 0,
420
+ "onset_mask_width": 0,
421
+ "beat_mask_width": 50,
422
+ "beat_mask_downbeats": True,
423
+ },
424
+ "beat-driven variation (downbeats only, strong)": {
425
+ "periodic_p": 0,
426
+ "onset_mask_width": 0,
427
+ "beat_mask_width": 20,
428
+ "beat_mask_downbeats": True,
429
+ },
430
+ }
431
 
432
  preset = gr.Dropdown(
433
  label="preset",
 
445
  value=3,
446
  )
447
 
 
448
  onset_mask_width = gr.Slider(
449
  label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
450
  minimum=0,
 
460
  value=0,
461
  )
462
  beat_mask_downbeats = gr.Checkbox(
463
+ label="beat mask downbeats only?", value=False
 
464
  )
465
 
466
  n_mask_codebooks = gr.Number(
 
468
  value=9,
469
  )
470
 
 
471
  with gr.Accordion("extras ", open=False):
472
  pitch_shift_amt = gr.Slider(
473
  label="pitch shift amount (semitones)",
 
481
  label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
482
  minimum=0.0,
483
  maximum=1.0,
484
+ value=1.0,
485
  )
486
 
487
  periodic_w = gr.Slider(
 
516
  return tuple(presets[_preset].values())
517
 
518
  load_preset_button.click(
519
+ fn=load_preset, inputs=[preset], outputs=preset_outputs
 
 
520
  )
521
 
 
522
  with gr.Accordion("prefix/suffix prompts", open=False):
523
  prefix_s = gr.Slider(
524
  label="prefix hint length (seconds)",
525
  minimum=0.0,
526
  maximum=10.0,
527
+ value=0.0,
528
  )
529
  suffix_s = gr.Slider(
530
  label="suffix hint length (seconds)",
531
  minimum=0.0,
532
  maximum=10.0,
533
+ value=0.0,
534
  )
535
 
536
  masktemp = gr.Slider(
537
+ label="mask temperature", minimum=0.0, maximum=100.0, value=1.5
 
 
 
538
  )
539
  sampletemp = gr.Slider(
540
  label="sample temperature",
541
  minimum=0.1,
542
  maximum=10.0,
543
  value=1.0,
544
+ step=0.001,
545
  )
546
 
 
 
547
  with gr.Accordion("sampling settings", open=False):
548
  top_p = gr.Slider(
549
+ label="top p (0.0 = off)", minimum=0.0, maximum=1.0, value=0.0
 
 
 
 
 
 
 
550
  )
551
+ typical_filtering = gr.Checkbox(label="typical filtering ", value=False)
552
  typical_mass = gr.Slider(
553
  label="typical mass (should probably stay between 0.1 and 0.5)",
554
  minimum=0.01,
555
  maximum=0.99,
556
+ value=0.15,
557
  )
558
  typical_min_tokens = gr.Slider(
559
  label="typical min tokens (should probably stay between 1 and 256)",
560
  minimum=1,
561
  maximum=256,
562
  step=1,
563
+ value=64,
564
  )
565
  sample_cutoff = gr.Slider(
566
  label="sample cutoff",
567
  minimum=0.0,
568
  maximum=1.0,
569
  value=0.5,
570
+ step=0.01,
571
  )
572
 
573
  use_coarse2fine = gr.Checkbox(
574
+ label="use coarse2fine", value=True, visible=False
 
 
575
  )
576
 
577
  num_steps = gr.Slider(
 
579
  minimum=1,
580
  maximum=128,
581
  step=1,
582
+ value=36,
583
  )
584
 
585
  dropout = gr.Slider(
586
+ label="mask dropout", minimum=0.0, maximum=1.0, step=0.01, value=0.0
 
 
 
 
587
  )
588
 
 
589
  seed = gr.Number(
590
  label="seed (0 for random)",
591
  value=0,
592
  precision=0,
593
  )
594
 
 
 
595
  # mask settings
596
  with gr.Column():
 
597
  # lora_choice = gr.Dropdown(
598
  # label="lora choice",
599
  # choices=list(loras.keys()),
 
603
 
604
  vamp_button = gr.Button("generate (vamp)!!!")
605
  output_audio = gr.Audio(
606
+ label="output audio", interactive=False, type="filepath"
 
 
607
  )
608
 
609
  notes_text = gr.Textbox(
610
  label="type any notes about the generated audio here",
611
  value="",
612
+ interactive=True,
613
  )
614
  save_button = gr.Button("save vamp")
615
  download_file = gr.File(
616
+ label="vamp to download will appear here", interactive=False
 
617
  )
618
  use_as_input_button = gr.Button("use output as input")
619
 
620
  thank_you = gr.Markdown("")
621
 
 
622
  _inputs = {
623
+ input_audio,
624
+ num_steps,
625
+ masktemp,
626
+ sampletemp,
627
+ top_p,
628
+ prefix_s,
629
+ suffix_s,
630
+ rand_mask_intensity,
631
+ periodic_p,
632
+ periodic_w,
633
+ n_conditioning_codebooks,
634
+ dropout,
635
+ use_coarse2fine,
636
+ stretch_factor,
637
+ onset_mask_width,
638
+ typical_filtering,
639
+ typical_mass,
640
+ typical_min_tokens,
641
+ beat_mask_width,
642
+ beat_mask_downbeats,
643
+ seed,
644
+ # lora_choice,
645
+ n_mask_codebooks,
646
+ pitch_shift_amt,
647
+ sample_cutoff,
648
+ }
649
 
650
  # connect widgets
651
  vamp_button.click(
 
656
 
657
  api_vamp_button = gr.Button("api vamp", visible=False)
658
  api_vamp_button.click(
659
+ fn=api_vamp, inputs=_inputs, outputs=[output_audio], api_name="vamp"
 
 
 
660
  )
661
 
662
  use_as_input_button.click(
663
+ fn=lambda x: x, inputs=[output_audio], outputs=[input_audio]
 
 
664
  )
665
 
666
  save_button.click(
667
  fn=save_vamp,
668
  inputs=_inputs | {notes_text, output_audio},
669
+ outputs=[thank_you, download_file],
670
  )
671
 
672
 
vampnet/scripts/exp/eval.py CHANGED
@@ -1,20 +1,18 @@
1
  from pathlib import Path
2
- import os
3
- from functools import partial
4
 
5
- from frechet_audio_distance import FrechetAudioDistance
6
- import pandas
7
  import argbind
 
 
8
  import torch
 
 
9
  from tqdm import tqdm
10
 
11
- import audiotools
12
- from audiotools import AudioSignal
13
 
14
  @argbind.bind(without_prefix=True)
15
  def eval(
16
  exp_dir: str = None,
17
- baseline_key: str = "baseline",
18
  audio_ext: str = ".wav",
19
  ):
20
  assert exp_dir is not None
@@ -26,9 +24,9 @@ def eval(
26
  # stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
27
  mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
28
  frechet = FrechetAudioDistance(
29
- use_pca=False,
30
  use_activation=False,
31
- verbose=True,
32
  audio_load_worker=4,
33
  )
34
  frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")
@@ -36,19 +34,25 @@ def eval(
36
  # figure out what conditions we have
37
  conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
38
 
39
- assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
 
 
40
  conditions.remove(baseline_key)
41
 
42
  print(f"Found {len(conditions)} conditions in {exp_dir}")
43
  print(f"conditions: {conditions}")
44
 
45
- baseline_dir = exp_dir / baseline_key
46
- baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
 
 
47
 
48
  metrics = []
49
  for condition in tqdm(conditions):
50
  cond_dir = exp_dir / condition
51
- cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
 
 
52
 
53
  print(f"computing fad for {baseline_dir} and {cond_dir}")
54
  frechet_score = frechet.score(baseline_dir, cond_dir)
@@ -57,11 +61,15 @@ def eval(
57
  num_files = min(len(baseline_files), len(cond_files))
58
  baseline_files = baseline_files[:num_files]
59
  cond_files = cond_files[:num_files]
60
- assert len(list(baseline_files)) == len(list(cond_files)), f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}"
 
 
61
 
62
  def process(baseline_file, cond_file):
63
  # make sure the files match (same name)
64
- assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"
 
 
65
 
66
  # load the files
67
  baseline_sig = AudioSignal(str(baseline_file))
@@ -74,7 +82,9 @@ def eval(
74
  if "inpaint" in condition:
75
  ctx_amt = float(condition.split("_")[-1])
76
  ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
77
- print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
 
 
78
  cond_sig.trim(ctx_samples, ctx_samples)
79
  baseline_sig.trim(ctx_samples, ctx_samples)
80
 
@@ -88,15 +98,18 @@ def eval(
88
  "file": baseline_file.stem,
89
  }
90
 
91
- print(f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}")
92
- metrics.extend(tqdm(map(process, baseline_files, cond_files), total=len(baseline_files)))
 
 
 
 
93
 
94
  metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
95
 
96
-
97
  for mk in metric_keys:
98
  stat = pandas.DataFrame(metrics)
99
- stat = stat.groupby(['condition'])[mk].agg(['mean', 'count', 'std'])
100
  stat.to_csv(exp_dir / f"stats-{mk}.csv")
101
 
102
  df = pandas.DataFrame(metrics)
@@ -107,4 +120,4 @@ if __name__ == "__main__":
107
  args = argbind.parse_args()
108
 
109
  with argbind.scope(args):
110
- eval()
 
1
  from pathlib import Path
 
 
2
 
 
 
3
  import argbind
4
+ import audiotools
5
+ import pandas
6
  import torch
7
+ from audiotools import AudioSignal
8
+ from frechet_audio_distance import FrechetAudioDistance
9
  from tqdm import tqdm
10
 
 
 
11
 
12
  @argbind.bind(without_prefix=True)
13
  def eval(
14
  exp_dir: str = None,
15
+ baseline_key: str = "baseline",
16
  audio_ext: str = ".wav",
17
  ):
18
  assert exp_dir is not None
 
24
  # stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
25
  mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
26
  frechet = FrechetAudioDistance(
27
+ use_pca=False,
28
  use_activation=False,
29
+ verbose=True,
30
  audio_load_worker=4,
31
  )
32
  frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")
 
34
  # figure out what conditions we have
35
  conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
36
 
37
+ assert baseline_key in conditions, (
38
+ f"baseline_key {baseline_key} not found in {exp_dir}"
39
+ )
40
  conditions.remove(baseline_key)
41
 
42
  print(f"Found {len(conditions)} conditions in {exp_dir}")
43
  print(f"conditions: {conditions}")
44
 
45
+ baseline_dir = exp_dir / baseline_key
46
+ baseline_files = sorted(
47
+ list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem)
48
+ )
49
 
50
  metrics = []
51
  for condition in tqdm(conditions):
52
  cond_dir = exp_dir / condition
53
+ cond_files = sorted(
54
+ list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem)
55
+ )
56
 
57
  print(f"computing fad for {baseline_dir} and {cond_dir}")
58
  frechet_score = frechet.score(baseline_dir, cond_dir)
 
61
  num_files = min(len(baseline_files), len(cond_files))
62
  baseline_files = baseline_files[:num_files]
63
  cond_files = cond_files[:num_files]
64
+ assert len(list(baseline_files)) == len(list(cond_files)), (
65
+ f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}"
66
+ )
67
 
68
  def process(baseline_file, cond_file):
69
  # make sure the files match (same name)
70
+ assert baseline_file.stem == cond_file.stem, (
71
+ f"baseline file {baseline_file} and cond file {cond_file} do not match"
72
+ )
73
 
74
  # load the files
75
  baseline_sig = AudioSignal(str(baseline_file))
 
82
  if "inpaint" in condition:
83
  ctx_amt = float(condition.split("_")[-1])
84
  ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
85
+ print(
86
+ f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}"
87
+ )
88
  cond_sig.trim(ctx_samples, ctx_samples)
89
  baseline_sig.trim(ctx_samples, ctx_samples)
90
 
 
98
  "file": baseline_file.stem,
99
  }
100
 
101
+ print(
102
+ f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}"
103
+ )
104
+ metrics.extend(
105
+ tqdm(map(process, baseline_files, cond_files), total=len(baseline_files))
106
+ )
107
 
108
  metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")]
109
 
 
110
  for mk in metric_keys:
111
  stat = pandas.DataFrame(metrics)
112
+ stat = stat.groupby(["condition"])[mk].agg(["mean", "count", "std"])
113
  stat.to_csv(exp_dir / f"stats-{mk}.csv")
114
 
115
  df = pandas.DataFrame(metrics)
 
120
  args = argbind.parse_args()
121
 
122
  with argbind.scope(args):
123
+ eval()
vampnet/scripts/exp/experiment.py CHANGED
@@ -1,48 +1,44 @@
1
- from pathlib import Path
2
  import random
3
- from typing import List
4
- import tempfile
5
  import subprocess
 
 
6
 
7
  import argbind
8
- from tqdm import tqdm
9
  import torch
 
10
 
11
- from vampnet.interface import Interface
12
  from vampnet import mask as pmask
13
- import audiotools as at
14
 
15
  Interface: Interface = argbind.bind(Interface)
16
 
17
 
18
-
19
- def calculate_bitrate(
20
- interface, num_codebooks,
21
- downsample_factor
22
- ):
23
  bit_width = 10
24
  sr = interface.codec.sample_rate
25
  hop = interface.codec.hop_size
26
  rate = (sr / hop) * ((bit_width * num_codebooks) / downsample_factor)
27
  return rate
28
 
 
29
  def baseline(sig, interface):
30
  return interface.preprocess(sig)
31
 
 
32
  def reconstructed(sig, interface):
33
- return interface.to_signal(
34
- interface.encode(sig)
35
- )
36
 
37
  def coarse2fine(sig, interface):
38
  z = interface.encode(sig)
39
- z = z[:, :interface.c2f.n_conditioning_codebooks, :]
40
 
41
  z = interface.coarse_to_fine(z)
42
  return interface.to_signal(z)
43
 
44
- class CoarseCond:
45
 
 
46
  def __init__(self, num_conditioning_codebooks, downsample_factor):
47
  self.num_conditioning_codebooks = num_conditioning_codebooks
48
  self.downsample_factor = downsample_factor
@@ -57,83 +53,85 @@ class CoarseCond:
57
  zv = interface.coarse_to_fine(zv)
58
  return interface.to_signal(zv)
59
 
 
60
  def opus(sig, interface, bitrate=128):
61
  sig = interface.preprocess(sig)
62
-
63
  with tempfile.NamedTemporaryFile(suffix=".wav") as f:
64
  sig.write(f.name)
65
 
66
  opus_name = Path(f.name).with_suffix(".opus")
67
  # convert to opus
68
  cmd = [
69
- "ffmpeg", "-y", "-i", f.name,
70
- "-c:a", "libopus",
71
- "-b:a", f"{bitrate}",
72
- opus_name
 
 
 
 
 
73
  ]
74
  subprocess.run(cmd, check=True)
75
 
76
  # convert back to wav
77
  output_name = Path(f"{f.name}-opus").with_suffix(".wav")
78
- cmd = [
79
- "ffmpeg", "-y", "-i", opus_name,
80
- output_name
81
- ]
82
 
83
  subprocess.run(cmd, check=True)
84
 
85
- sig = at.AudioSignal(
86
- output_name,
87
- sample_rate=sig.sample_rate
88
- )
89
  return sig
90
 
 
91
  def mask_ratio_1_step(ratio=1.0):
92
  def wrapper(sig, interface):
93
  z = interface.encode(sig)
94
  mask = pmask.linear_random(z, ratio)
95
  zv = interface.coarse_vamp(
96
- z,
97
  mask,
98
- sampling_steps=1,
99
  )
100
 
101
  return interface.to_signal(zv)
 
102
  return wrapper
103
 
 
104
  def num_sampling_steps(num_steps=1):
105
  def wrapper(sig, interface: Interface):
106
  z = interface.encode(sig)
107
  mask = pmask.periodic_mask(z, 16)
108
  zv = interface.coarse_vamp(
109
- z,
110
  mask,
111
- sampling_steps=num_steps,
112
  )
113
 
114
  zv = interface.coarse_to_fine(zv)
115
  return interface.to_signal(zv)
 
116
  return wrapper
117
 
 
118
  def beat_mask(ctx_time):
119
  def wrapper(sig, interface):
120
  beat_mask = interface.make_beat_mask(
121
- sig,
122
- before_beat_s=ctx_time/2,
123
- after_beat_s=ctx_time/2,
124
- invert=True
125
  )
126
 
127
  z = interface.encode(sig)
128
 
129
- zv = interface.coarse_vamp(
130
- z, beat_mask
131
- )
132
 
133
  zv = interface.coarse_to_fine(zv)
134
  return interface.to_signal(zv)
 
135
  return wrapper
136
 
 
137
  def inpaint(ctx_time):
138
  def wrapper(sig, interface: Interface):
139
  z = interface.encode(sig)
@@ -141,22 +139,22 @@ def inpaint(ctx_time):
141
 
142
  zv = interface.coarse_vamp(z, mask)
143
  zv = interface.coarse_to_fine(zv)
144
-
145
  return interface.to_signal(zv)
 
146
  return wrapper
147
 
 
148
  def token_noise(noise_amt):
149
  def wrapper(sig, interface: Interface):
150
  z = interface.encode(sig)
151
  mask = pmask.random(z, noise_amt)
152
- z = torch.where(
153
- mask,
154
- torch.randint_like(z, 0, interface.coarse.vocab_size),
155
- z
156
- )
157
  return interface.to_signal(z)
 
158
  return wrapper
159
 
 
160
  EXP_REGISTRY = {}
161
 
162
  EXP_REGISTRY["gen-compression"] = {
@@ -164,57 +162,63 @@ EXP_REGISTRY["gen-compression"] = {
164
  "reconstructed": reconstructed,
165
  "coarse2fine": coarse2fine,
166
  **{
167
- f"{n}_codebooks_downsampled_{x}x": CoarseCond(num_conditioning_codebooks=n, downsample_factor=x)
168
- for (n, x) in (
169
- (1, 1), # 1 codebook, no downsampling
170
- (4, 4), # 4 codebooks, downsampled 4x
171
- (4, 16), # 4 codebooks, downsampled 16x
172
- (4, 32), # 4 codebooks, downsampled 16x
173
- )
174
- },
175
- **{
176
- f"token_noise_{x}": mask_ratio_1_step(ratio=x)
177
- for x in [0.25, 0.5, 0.75]
178
  },
179
-
180
  }
181
 
182
 
183
  EXP_REGISTRY["sampling-steps"] = {
184
  # "codec": reconstructed,
185
- **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 36, 64, 72]},
186
  }
187
 
188
 
189
  EXP_REGISTRY["musical-sampling"] = {
190
- **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
191
- **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
 
 
 
 
 
 
192
  }
193
 
 
194
  @argbind.bind(without_prefix=True)
195
  def main(
196
- sources=[
197
- "/media/CHONK/hugo/spotdl/val",
198
- ],
199
- output_dir: str = "./samples",
200
- max_excerpts: int = 2000,
201
- exp_type: str = "gen-compression",
202
- seed: int = 0,
203
- ext: str = [".mp3"],
204
- ):
205
  at.util.seed(seed)
206
  interface = Interface()
207
 
208
- output_dir = Path(output_dir)
209
  output_dir.mkdir(exist_ok=True, parents=True)
210
 
211
- from audiotools.data.datasets import AudioLoader, AudioDataset
212
 
213
  loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
214
- dataset = AudioDataset(loader,
215
- sample_rate=interface.codec.sample_rate,
216
- duration=interface.coarse.chunk_size_s,
217
- n_examples=max_excerpts,
 
218
  without_replacement=True,
219
  )
220
 
@@ -223,7 +227,6 @@ def main(
223
  else:
224
  raise ValueError(f"Unknown exp_type {exp_type}")
225
 
226
-
227
  indices = list(range(max_excerpts))
228
  random.shuffle(indices)
229
  for i in tqdm(indices):
@@ -237,8 +240,7 @@ def main(
237
 
238
  sig = dataset[i]["signal"]
239
  results = {
240
- name: cond(sig, interface).cpu()
241
- for name, cond in SAMPLE_CONDS.items()
242
  }
243
 
244
  for name, sig in results.items():
@@ -247,6 +249,7 @@ def main(
247
 
248
  sig.write(o_dir / f"{i}.wav")
249
 
 
250
  if __name__ == "__main__":
251
  args = argbind.parse_args()
252
 
 
 
1
  import random
 
 
2
  import subprocess
3
+ import tempfile
4
+ from pathlib import Path
5
 
6
  import argbind
7
+ import audiotools as at
8
  import torch
9
+ from tqdm import tqdm
10
 
 
11
  from vampnet import mask as pmask
12
+ from vampnet.interface import Interface
13
 
14
  Interface: Interface = argbind.bind(Interface)
15
 
16
 
17
def calculate_bitrate(interface, num_codebooks, downsample_factor, bit_width=10):
    """Estimate the token bitrate of a codec configuration.

    Args:
        interface: object exposing ``codec.sample_rate`` and ``codec.hop_size``.
        num_codebooks: number of codebooks counted per token frame.
        downsample_factor: divisor applied to the per-frame bit count.
        bit_width: bits per codebook entry. Defaults to 10, the value that
            used to be hard-coded here, so existing callers are unaffected.

    Returns:
        ``(sample_rate / hop_size) * (bit_width * num_codebooks / downsample_factor)``
        (presumably bits per second when ``sample_rate`` is in Hz).
    """
    codec = interface.codec
    frames_per_second = codec.sample_rate / codec.hop_size
    bits_per_frame = (bit_width * num_codebooks) / downsample_factor
    return frames_per_second * bits_per_frame
23
 
24
+
25
  def baseline(sig, interface):
26
  return interface.preprocess(sig)
27
 
28
+
29
  def reconstructed(sig, interface):
30
+ return interface.to_signal(interface.encode(sig))
31
+
 
32
 
33
  def coarse2fine(sig, interface):
34
  z = interface.encode(sig)
35
+ z = z[:, : interface.c2f.n_conditioning_codebooks, :]
36
 
37
  z = interface.coarse_to_fine(z)
38
  return interface.to_signal(z)
39
 
 
40
 
41
+ class CoarseCond:
42
  def __init__(self, num_conditioning_codebooks, downsample_factor):
43
  self.num_conditioning_codebooks = num_conditioning_codebooks
44
  self.downsample_factor = downsample_factor
 
53
  zv = interface.coarse_to_fine(zv)
54
  return interface.to_signal(zv)
55
 
56
+
57
  def opus(sig, interface, bitrate=128):
58
  sig = interface.preprocess(sig)
59
+
60
  with tempfile.NamedTemporaryFile(suffix=".wav") as f:
61
  sig.write(f.name)
62
 
63
  opus_name = Path(f.name).with_suffix(".opus")
64
  # convert to opus
65
  cmd = [
66
+ "ffmpeg",
67
+ "-y",
68
+ "-i",
69
+ f.name,
70
+ "-c:a",
71
+ "libopus",
72
+ "-b:a",
73
+ f"{bitrate}",
74
+ opus_name,
75
  ]
76
  subprocess.run(cmd, check=True)
77
 
78
  # convert back to wav
79
  output_name = Path(f"{f.name}-opus").with_suffix(".wav")
80
+ cmd = ["ffmpeg", "-y", "-i", opus_name, output_name]
 
 
 
81
 
82
  subprocess.run(cmd, check=True)
83
 
84
+ sig = at.AudioSignal(output_name, sample_rate=sig.sample_rate)
 
 
 
85
  return sig
86
 
87
+
88
  def mask_ratio_1_step(ratio=1.0):
89
  def wrapper(sig, interface):
90
  z = interface.encode(sig)
91
  mask = pmask.linear_random(z, ratio)
92
  zv = interface.coarse_vamp(
93
+ z,
94
  mask,
95
+ sampling_steps=1,
96
  )
97
 
98
  return interface.to_signal(zv)
99
+
100
  return wrapper
101
 
102
+
103
  def num_sampling_steps(num_steps=1):
104
  def wrapper(sig, interface: Interface):
105
  z = interface.encode(sig)
106
  mask = pmask.periodic_mask(z, 16)
107
  zv = interface.coarse_vamp(
108
+ z,
109
  mask,
110
+ sampling_steps=num_steps,
111
  )
112
 
113
  zv = interface.coarse_to_fine(zv)
114
  return interface.to_signal(zv)
115
+
116
  return wrapper
117
 
118
+
119
  def beat_mask(ctx_time):
120
  def wrapper(sig, interface):
121
  beat_mask = interface.make_beat_mask(
122
+ sig, before_beat_s=ctx_time / 2, after_beat_s=ctx_time / 2, invert=True
 
 
 
123
  )
124
 
125
  z = interface.encode(sig)
126
 
127
+ zv = interface.coarse_vamp(z, beat_mask)
 
 
128
 
129
  zv = interface.coarse_to_fine(zv)
130
  return interface.to_signal(zv)
131
+
132
  return wrapper
133
 
134
+
135
  def inpaint(ctx_time):
136
  def wrapper(sig, interface: Interface):
137
  z = interface.encode(sig)
 
139
 
140
  zv = interface.coarse_vamp(z, mask)
141
  zv = interface.coarse_to_fine(zv)
142
+
143
  return interface.to_signal(zv)
144
+
145
  return wrapper
146
 
147
+
148
  def token_noise(noise_amt):
149
  def wrapper(sig, interface: Interface):
150
  z = interface.encode(sig)
151
  mask = pmask.random(z, noise_amt)
152
+ z = torch.where(mask, torch.randint_like(z, 0, interface.coarse.vocab_size), z)
 
 
 
 
153
  return interface.to_signal(z)
154
+
155
  return wrapper
156
 
157
+
158
  EXP_REGISTRY = {}
159
 
160
  EXP_REGISTRY["gen-compression"] = {
 
162
  "reconstructed": reconstructed,
163
  "coarse2fine": coarse2fine,
164
  **{
165
+ f"{n}_codebooks_downsampled_{x}x": CoarseCond(
166
+ num_conditioning_codebooks=n, downsample_factor=x
167
+ )
168
+ for (n, x) in (
169
+ (1, 1), # 1 codebook, no downsampling
170
+ (4, 4), # 4 codebooks, downsampled 4x
171
+ (4, 16), # 4 codebooks, downsampled 16x
172
+ (4, 32), # 4 codebooks, downsampled 32x
173
+ )
 
 
174
  },
175
+ **{f"token_noise_{x}": mask_ratio_1_step(ratio=x) for x in [0.25, 0.5, 0.75]},
176
  }
177
 
178
 
179
  EXP_REGISTRY["sampling-steps"] = {
180
  # "codec": reconstructed,
181
+ **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 36, 64, 72]},
182
  }
183
 
184
 
185
  EXP_REGISTRY["musical-sampling"] = {
186
+ **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
187
+ **{
188
+ f"inpaint_{t}": inpaint(t)
189
+ for t in [
190
+ 0.5,
191
+ 1.0,
192
+ ]
193
+ }, # multiply these by 2 (they go left and right)
194
  }
195
 
196
+
197
  @argbind.bind(without_prefix=True)
198
  def main(
199
+ sources=[
200
+ "/media/CHONK/hugo/spotdl/val",
201
+ ],
202
+ output_dir: str = "./samples",
203
+ max_excerpts: int = 2000,
204
+ exp_type: str = "gen-compression",
205
+ seed: int = 0,
206
+ ext: str = [".mp3"],
207
+ ):
208
  at.util.seed(seed)
209
  interface = Interface()
210
 
211
+ output_dir = Path(output_dir)
212
  output_dir.mkdir(exist_ok=True, parents=True)
213
 
214
+ from audiotools.data.datasets import AudioDataset, AudioLoader
215
 
216
  loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
217
+ dataset = AudioDataset(
218
+ loader,
219
+ sample_rate=interface.codec.sample_rate,
220
+ duration=interface.coarse.chunk_size_s,
221
+ n_examples=max_excerpts,
222
  without_replacement=True,
223
  )
224
 
 
227
  else:
228
  raise ValueError(f"Unknown exp_type {exp_type}")
229
 
 
230
  indices = list(range(max_excerpts))
231
  random.shuffle(indices)
232
  for i in tqdm(indices):
 
240
 
241
  sig = dataset[i]["signal"]
242
  results = {
243
+ name: cond(sig, interface).cpu() for name, cond in SAMPLE_CONDS.items()
 
244
  }
245
 
246
  for name, sig in results.items():
 
249
 
250
  sig.write(o_dir / f"{i}.wav")
251
 
252
+
253
  if __name__ == "__main__":
254
  args = argbind.parse_args()
255
 
vampnet/scripts/exp/fine_tune.py CHANGED
@@ -1,20 +1,21 @@
1
- import argbind
2
  from pathlib import Path
3
- import yaml
4
  from typing import List
5
 
6
-
7
-
8
 
9
  """example output: (yaml)
10
 
11
  """
12
 
 
13
  @argbind.bind(without_prefix=True, positional=True)
14
  def fine_tune(audio_files_or_folders: List[str], name: str):
15
 
16
  conf_dir = Path("conf")
17
- assert conf_dir.exists(), "conf directory not found. are you in the vampnet directory?"
 
 
18
 
19
  conf_dir = conf_dir / "generated"
20
  conf_dir.mkdir(exist_ok=True)
@@ -35,7 +36,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
35
  "AudioDataset.duration": 3.0,
36
  "AudioDataset.loudness_cutoff": -40.0,
37
  "save_path": f"./runs/{name}/c2f",
38
- "fine_tune_checkpoint": "./models/vampnet/c2f.pth"
39
  }
40
 
41
  finetune_coarse_conf = {
@@ -44,15 +45,13 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
44
  "train/AudioLoader.sources": audio_files_or_folders,
45
  "val/AudioLoader.sources": audio_files_or_folders,
46
  "save_path": f"./runs/{name}/coarse",
47
- "fine_tune_checkpoint": "./models/vampnet/coarse.pth"
48
  }
49
 
50
  interface_conf = {
51
  "Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
52
-
53
  "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
54
  "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
55
-
56
  "Interface.codec_ckpt": "./models/vampnet/codec.pth",
57
  "AudioLoader.sources": [audio_files_or_folders],
58
  }
@@ -63,19 +62,17 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
63
 
64
  with open(finetune_dir / "coarse.yml", "w") as f:
65
  yaml.dump(finetune_coarse_conf, f)
66
-
67
- with open(finetune_dir / "interface.yml", "w") as f:
68
  yaml.dump(interface_conf, f)
69
 
 
 
 
70
 
71
- print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` ")
72
 
73
  if __name__ == "__main__":
74
  args = argbind.parse_args()
75
 
76
  with argbind.scope(args):
77
  fine_tune()
78
-
79
-
80
-
81
-
 
 
1
  from pathlib import Path
 
2
  from typing import List
3
 
4
+ import argbind
5
+ import yaml
6
 
7
  """example output: (yaml)
8
 
9
  """
10
 
11
+
12
  @argbind.bind(without_prefix=True, positional=True)
13
  def fine_tune(audio_files_or_folders: List[str], name: str):
14
 
15
  conf_dir = Path("conf")
16
+ assert conf_dir.exists(), (
17
+ "conf directory not found. are you in the vampnet directory?"
18
+ )
19
 
20
  conf_dir = conf_dir / "generated"
21
  conf_dir.mkdir(exist_ok=True)
 
36
  "AudioDataset.duration": 3.0,
37
  "AudioDataset.loudness_cutoff": -40.0,
38
  "save_path": f"./runs/{name}/c2f",
39
+ "fine_tune_checkpoint": "./models/vampnet/c2f.pth",
40
  }
41
 
42
  finetune_coarse_conf = {
 
45
  "train/AudioLoader.sources": audio_files_or_folders,
46
  "val/AudioLoader.sources": audio_files_or_folders,
47
  "save_path": f"./runs/{name}/coarse",
48
+ "fine_tune_checkpoint": "./models/vampnet/coarse.pth",
49
  }
50
 
51
  interface_conf = {
52
  "Interface.coarse_ckpt": f"./runs/{name}/coarse/latest/vampnet/weights.pth",
 
53
  "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/latest/vampnet/weights.pth",
54
  "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
 
55
  "Interface.codec_ckpt": "./models/vampnet/codec.pth",
56
  "AudioLoader.sources": [audio_files_or_folders],
57
  }
 
62
 
63
  with open(finetune_dir / "coarse.yml", "w") as f:
64
  yaml.dump(finetune_coarse_conf, f)
65
+
66
+ with open(finetune_dir / "interface.yml", "w") as f:
67
  yaml.dump(interface_conf, f)
68
 
69
+ print(
70
+ f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` "
71
+ )
72
 
 
73
 
74
  if __name__ == "__main__":
75
  args = argbind.parse_args()
76
 
77
  with argbind.scope(args):
78
  fine_tune()
 
 
 
 
vampnet/scripts/exp/train.py CHANGED
@@ -1,36 +1,33 @@
1
  import os
2
  import sys
3
  import warnings
 
4
  from pathlib import Path
5
  from typing import Optional
6
- from dataclasses import dataclass
7
 
8
  import argbind
9
  import audiotools as at
 
10
  import torch
 
11
  import torch.nn as nn
12
  from audiotools import AudioSignal
13
  from audiotools.data import transforms
 
14
  from einops import rearrange
 
 
 
15
  from rich import pretty
16
  from rich.traceback import install
17
  from torch.utils.tensorboard import SummaryWriter
18
 
19
  import vampnet
20
- from vampnet.modules.transformer import VampNet
21
- from vampnet.util import codebook_unflatten, codebook_flatten
22
  from vampnet import mask as pmask
23
- # from dac.model.dac import DAC
24
- from lac.model.lac import LAC as DAC
25
-
26
- from audiotools.ml.decorators import (
27
- timer, Tracker, when
28
- )
29
-
30
- import loralib as lora
31
 
32
- import torch._dynamo
33
- torch._dynamo.config.verbose=True
34
 
35
 
36
  # Enable cudnn autotuner to speed up training
@@ -50,11 +47,15 @@ AdamW = argbind.bind(torch.optim.AdamW)
50
  NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)
51
 
52
  # transforms
53
- filter_fn = lambda fn: hasattr(fn, "transform") and fn.__qualname__ not in [
54
- "BaseTransform",
55
- "Compose",
56
- "Choose",
57
- ]
 
 
 
 
58
  tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
59
 
60
  # model
@@ -106,13 +107,14 @@ def flip_coin(shape, p, rng):
106
 
107
 
108
  def num_params_hook(o, p):
109
- return o + f" {p/1e6:<.3f}M params."
110
 
111
 
112
  def add_num_params_repr_hook(model):
113
- import numpy as np
114
  from functools import partial
115
 
 
 
116
  for n, m in model.named_modules():
117
  o = m.extra_repr()
118
  p = sum([np.prod(p.size()) for p in m.parameters()])
@@ -149,6 +151,7 @@ def accuracy(
149
 
150
  return accuracy
151
 
 
152
  def _metrics(z_hat, r, target, flat_mask, output):
153
  for r_range in [(0, 0.5), (0.5, 1.0)]:
154
  unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
@@ -219,7 +222,7 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
219
  mask = pmask.random(z, r)
220
  mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
221
  z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
222
-
223
  z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
224
 
225
  dtype = torch.bfloat16 if accel.amp else None
@@ -246,13 +249,11 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
246
  output=output,
247
  )
248
 
249
-
250
  accel.backward(output["loss"])
251
 
252
  output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
253
  output["other/batch_size"] = z.shape[0]
254
 
255
-
256
  accel.scaler.unscale_(state.optimizer)
257
  output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
258
  state.model.parameters(), state.grad_clip_val
@@ -264,7 +265,6 @@ def train_loop(state: State, batch: dict, accel: Accelerator):
264
  state.scheduler.step()
265
  accel.update()
266
 
267
-
268
  return {k: v for k, v in sorted(output.items())}
269
 
270
 
@@ -295,9 +295,7 @@ def val_loop(state: State, batch: dict, accel: Accelerator):
295
  z[:, vn.n_conditioning_codebooks :, :],
296
  )
297
 
298
- flat_mask = codebook_flatten(
299
- mask[:, vn.n_conditioning_codebooks :, :]
300
- )
301
 
302
  output = {}
303
  # replace target with ignore index for masked tokens
@@ -338,16 +336,16 @@ def checkpoint(state, save_iters, save_path, fine_tune):
338
  tags.append(f"{state.tracker.step // 1000}k")
339
 
340
  if state.tracker.is_best("val", "loss"):
341
- state.tracker.print(f"Best model so far")
342
  tags.append("best")
343
 
344
  if fine_tune:
345
- for tag in tags:
346
- # save the lora model
347
  (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
348
  torch.save(
349
- lora.lora_state_dict(accel.unwrap(state.model)),
350
- f"{save_path}/{tag}/lora.pth"
351
  )
352
 
353
  for tag in tags:
@@ -383,7 +381,7 @@ def save_sampled(state, z, writer):
383
 
384
  def save_imputation(state, z, val_idx, writer):
385
  n_prefix = int(z.shape[-1] * 0.25)
386
- n_suffix = int(z.shape[-1] * 0.25)
387
 
388
  vn = accel.unwrap(state.model)
389
 
@@ -402,8 +400,8 @@ def save_imputation(state, z, val_idx, writer):
402
  time_steps=z.shape[-1],
403
  start_tokens=z[i][None, ...],
404
  mask=mask[i][None, ...],
405
- )
406
- )
407
  imputed = AudioSignal.batch(imputed)
408
 
409
  for i in range(len(val_idx)):
@@ -443,7 +441,6 @@ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
443
 
444
  r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
445
 
446
-
447
  mask = pmask.random(z, r)
448
  mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
449
  z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
@@ -479,7 +476,6 @@ def save_samples(state: State, val_idx: int, writer: SummaryWriter):
479
  save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
480
 
481
 
482
-
483
  @argbind.bind(without_prefix=True)
484
  def load(
485
  args,
@@ -499,11 +495,12 @@ def load(
499
  if args["fine_tune"]:
500
  assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
501
  model = torch.compile(
502
- VampNet.load(location=Path(fine_tune_checkpoint),
503
- map_location="cpu",
 
504
  )
505
  )
506
-
507
  if resume:
508
  kwargs = {
509
  "folder": f"{save_path}/{tag}",
@@ -518,16 +515,11 @@ def load(
518
  f"Could not find a VampNet checkpoint in {kwargs['folder']}"
519
  )
520
 
521
-
522
-
523
-
524
  model = torch.compile(VampNet()) if model is None else model
525
  model = accel.prepare_model(model)
526
 
527
  # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
528
- assert (
529
- accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size
530
- )
531
 
532
  optimizer = AdamW(model.parameters(), use_zero=accel.use_ddp)
533
  scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
@@ -538,13 +530,13 @@ def load(
538
  scheduler.load_state_dict(v_extra["scheduler.pth"])
539
  if "tracker.pth" in v_extra:
540
  tracker.load_state_dict(v_extra["tracker.pth"])
541
-
542
  criterion = CrossEntropyLoss()
543
 
544
  sample_rate = codec.sample_rate
545
 
546
  # a better rng for sampling from our schedule
547
- rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
548
 
549
  # log a model summary w/ num params
550
  if accel.local_rank == 0:
@@ -577,13 +569,19 @@ def train(
577
  codec_ckpt: str = None,
578
  save_path: str = "ckpt",
579
  num_iters: int = int(1000e6),
580
- save_iters: list = [10000, 50000, 100000, 300000, 500000,],
581
- sample_freq: int = 10000,
 
 
 
 
 
 
582
  val_freq: int = 1000,
583
  batch_size: int = 12,
584
  val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
585
  num_workers: int = 10,
586
- fine_tune: bool = False,
587
  ):
588
  assert codec_ckpt is not None, "codec_ckpt is required"
589
 
@@ -600,11 +598,7 @@ def train(
600
  )
601
 
602
  # load the codec model
603
- state: State = load(
604
- args=args,
605
- accel=accel,
606
- tracker=tracker,
607
- save_path=save_path)
608
  print("initialized state.")
609
 
610
  train_dataloader = accel.prepare_dataloader(
@@ -624,8 +618,6 @@ def train(
624
  )
625
  print("initialized dataloader.")
626
 
627
-
628
-
629
  if fine_tune:
630
  lora.mark_only_lora_as_trainable(state.model)
631
  print("marked only lora as trainable.")
@@ -658,10 +650,11 @@ def train(
658
  if tracker.step % val_freq == 0 or last_iter:
659
  validate(state, val_dataloader, accel)
660
  checkpoint(
661
- state=state,
662
- save_iters=save_iters,
663
- save_path=save_path,
664
- fine_tune=fine_tune)
 
665
 
666
  # Reset validation progress bar, print summary since last validation.
667
  tracker.done("val", f"Iteration {tracker.step}")
 
1
  import os
2
  import sys
3
  import warnings
4
+ from dataclasses import dataclass
5
  from pathlib import Path
6
  from typing import Optional
 
7
 
8
  import argbind
9
  import audiotools as at
10
+ import loralib as lora
11
  import torch
12
+ import torch._dynamo
13
  import torch.nn as nn
14
  from audiotools import AudioSignal
15
  from audiotools.data import transforms
16
+ from audiotools.ml.decorators import Tracker, timer, when
17
  from einops import rearrange
18
+
19
+ # from dac.model.dac import DAC
20
+ from lac.model.lac import LAC as DAC
21
  from rich import pretty
22
  from rich.traceback import install
23
  from torch.utils.tensorboard import SummaryWriter
24
 
25
  import vampnet
 
 
26
  from vampnet import mask as pmask
27
+ from vampnet.modules.transformer import VampNet
28
+ from vampnet.util import codebook_flatten, codebook_unflatten
 
 
 
 
 
 
29
 
30
+ torch._dynamo.config.verbose = True
 
31
 
32
 
33
  # Enable cudnn autotuner to speed up training
 
47
  NoamScheduler = argbind.bind(vampnet.scheduler.NoamScheduler)
48
 
49
  # transforms
50
def filter_fn(fn):
    """Return True when ``fn`` should be bound as a transform.

    An object qualifies when it has a ``transform`` attribute and its
    ``__qualname__`` is not one of ``BaseTransform``, ``Compose`` or ``Choose``.

    Written as a ``def`` rather than a lambda bound to a name (PEP 8 / E731);
    the call contract is unchanged.
    """
    if not hasattr(fn, "transform"):
        return False
    return fn.__qualname__ not in ("BaseTransform", "Compose", "Choose")
59
  tfm = argbind.bind_module(transforms, "train", "val", filter_fn=filter_fn)
60
 
61
  # model
 
107
 
108
 
109
  def num_params_hook(o, p):
110
+ return o + f" {p / 1e6:<.3f}M params."
111
 
112
 
113
  def add_num_params_repr_hook(model):
 
114
  from functools import partial
115
 
116
+ import numpy as np
117
+
118
  for n, m in model.named_modules():
119
  o = m.extra_repr()
120
  p = sum([np.prod(p.size()) for p in m.parameters()])
 
151
 
152
  return accuracy
153
 
154
+
155
  def _metrics(z_hat, r, target, flat_mask, output):
156
  for r_range in [(0, 0.5), (0.5, 1.0)]:
157
  unmasked_target = target.masked_fill(flat_mask.bool(), IGNORE_INDEX)
 
222
  mask = pmask.random(z, r)
223
  mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
224
  z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
225
+
226
  z_mask_latent = vn.embedding.from_codes(z_mask, state.codec)
227
 
228
  dtype = torch.bfloat16 if accel.amp else None
 
249
  output=output,
250
  )
251
 
 
252
  accel.backward(output["loss"])
253
 
254
  output["other/learning_rate"] = state.optimizer.param_groups[0]["lr"]
255
  output["other/batch_size"] = z.shape[0]
256
 
 
257
  accel.scaler.unscale_(state.optimizer)
258
  output["other/grad_norm"] = torch.nn.utils.clip_grad_norm_(
259
  state.model.parameters(), state.grad_clip_val
 
265
  state.scheduler.step()
266
  accel.update()
267
 
 
268
  return {k: v for k, v in sorted(output.items())}
269
 
270
 
 
295
  z[:, vn.n_conditioning_codebooks :, :],
296
  )
297
 
298
+ flat_mask = codebook_flatten(mask[:, vn.n_conditioning_codebooks :, :])
 
 
299
 
300
  output = {}
301
  # replace target with ignore index for masked tokens
 
336
  tags.append(f"{state.tracker.step // 1000}k")
337
 
338
  if state.tracker.is_best("val", "loss"):
339
+ state.tracker.print("Best model so far")
340
  tags.append("best")
341
 
342
  if fine_tune:
343
+ for tag in tags:
344
+ # save the lora model
345
  (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
346
  torch.save(
347
+ lora.lora_state_dict(accel.unwrap(state.model)),
348
+ f"{save_path}/{tag}/lora.pth",
349
  )
350
 
351
  for tag in tags:
 
381
 
382
  def save_imputation(state, z, val_idx, writer):
383
  n_prefix = int(z.shape[-1] * 0.25)
384
+ n_suffix = int(z.shape[-1] * 0.25)
385
 
386
  vn = accel.unwrap(state.model)
387
 
 
400
  time_steps=z.shape[-1],
401
  start_tokens=z[i][None, ...],
402
  mask=mask[i][None, ...],
403
+ )
404
+ )
405
  imputed = AudioSignal.batch(imputed)
406
 
407
  for i in range(len(val_idx)):
 
441
 
442
  r = torch.linspace(0.1, 0.95, len(val_idx)).to(accel.device)
443
 
 
444
  mask = pmask.random(z, r)
445
  mask = pmask.codebook_unmask(mask, vn.n_conditioning_codebooks)
446
  z_mask, mask = pmask.apply_mask(z, mask, vn.mask_token)
 
476
  save_imputation(state=state, z=z, val_idx=val_idx, writer=writer)
477
 
478
 
 
479
  @argbind.bind(without_prefix=True)
480
  def load(
481
  args,
 
495
  if args["fine_tune"]:
496
  assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
497
  model = torch.compile(
498
+ VampNet.load(
499
+ location=Path(fine_tune_checkpoint),
500
+ map_location="cpu",
501
  )
502
  )
503
+
504
  if resume:
505
  kwargs = {
506
  "folder": f"{save_path}/{tag}",
 
515
  f"Could not find a VampNet checkpoint in {kwargs['folder']}"
516
  )
517
 
 
 
 
518
  model = torch.compile(VampNet()) if model is None else model
519
  model = accel.prepare_model(model)
520
 
521
  # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
522
+ assert accel.unwrap(model).vocab_size == codec.quantizer.quantizers[0].codebook_size
 
 
523
 
524
  optimizer = AdamW(model.parameters(), use_zero=accel.use_ddp)
525
  scheduler = NoamScheduler(optimizer, d_model=accel.unwrap(model).embedding_dim)
 
530
  scheduler.load_state_dict(v_extra["scheduler.pth"])
531
  if "tracker.pth" in v_extra:
532
  tracker.load_state_dict(v_extra["tracker.pth"])
533
+
534
  criterion = CrossEntropyLoss()
535
 
536
  sample_rate = codec.sample_rate
537
 
538
  # a better rng for sampling from our schedule
539
+ rng = torch.quasirandom.SobolEngine(1, scramble=True, seed=args["seed"])
540
 
541
  # log a model summary w/ num params
542
  if accel.local_rank == 0:
 
569
  codec_ckpt: str = None,
570
  save_path: str = "ckpt",
571
  num_iters: int = int(1000e6),
572
+ save_iters: list = [
573
+ 10000,
574
+ 50000,
575
+ 100000,
576
+ 300000,
577
+ 500000,
578
+ ],
579
+ sample_freq: int = 10000,
580
  val_freq: int = 1000,
581
  batch_size: int = 12,
582
  val_idx: list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
583
  num_workers: int = 10,
584
+ fine_tune: bool = False,
585
  ):
586
  assert codec_ckpt is not None, "codec_ckpt is required"
587
 
 
598
  )
599
 
600
  # load the codec model
601
+ state: State = load(args=args, accel=accel, tracker=tracker, save_path=save_path)
 
 
 
 
602
  print("initialized state.")
603
 
604
  train_dataloader = accel.prepare_dataloader(
 
618
  )
619
  print("initialized dataloader.")
620
 
 
 
621
  if fine_tune:
622
  lora.mark_only_lora_as_trainable(state.model)
623
  print("marked only lora as trainable.")
 
650
  if tracker.step % val_freq == 0 or last_iter:
651
  validate(state, val_dataloader, accel)
652
  checkpoint(
653
+ state=state,
654
+ save_iters=save_iters,
655
+ save_path=save_path,
656
+ fine_tune=fine_tune,
657
+ )
658
 
659
  # Reset validation progress bar, print summary since last validation.
660
  tracker.done("val", f"Iteration {tracker.step}")
vampnet/scripts/utils/data/augment.py CHANGED
@@ -1,17 +1,12 @@
1
  from pathlib import Path
2
 
3
- import audiotools as at
4
- from audiotools import AudioSignal
5
-
6
  import argbind
7
- import tqdm
8
  import torch
9
-
10
-
11
- from torch_pitch_shift import pitch_shift, get_fast_shifts
12
- from torch_time_stretch import time_stretch, get_fast_stretches
13
-
14
- from audiotools.core.util import sample_from_dist
15
 
16
 
17
  @argbind.bind(without_prefix=True)
@@ -20,11 +15,11 @@ def augment(
20
  dest_folder: Path = None,
21
  n_augmentations: int = 10,
22
  ):
23
- """
24
- Augment a folder of audio files by applying audiotools and pedalboard transforms.
25
 
26
- The dest foler will contain a folder for each of the clean dataset's files.
27
- Under each of these folders, there will be a clean file and many augmented files.
28
  """
29
  assert audio_folder is not None
30
  assert dest_folder is not None
@@ -37,24 +32,31 @@ def augment(
37
 
38
  src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")
39
 
40
-
41
  for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
42
  # apply pedalboard transforms
43
  for j in range(n_augmentations):
44
  # pitch shift between -7 and 7 semitones
45
  import random
 
46
  dst = chunk.clone()
47
  dst.samples = pitch_shift(
48
- dst.samples,
49
- shift=random.choice(get_fast_shifts(src.sample_rate,
50
- condition=lambda x: x >= 0.25 and x <= 1.0)),
51
- sample_rate=src.sample_rate
 
 
 
52
  )
53
  dst.samples = time_stretch(
54
  dst.samples,
55
- stretch=random.choice(get_fast_stretches(src.sample_rate,
56
- condition=lambda x: x >= 0.667 and x <= 1.5, )),
57
- sample_rate=src.sample_rate,
 
 
 
 
58
  )
59
 
60
  dst.cpu().write(subdir / f"{i}-{j}.wav")
@@ -64,4 +66,4 @@ if __name__ == "__main__":
64
  args = argbind.parse_args()
65
 
66
  with argbind.scope(args):
67
- augment()
 
1
  from pathlib import Path
2
 
 
 
 
3
  import argbind
4
+ import audiotools as at
5
  import torch
6
+ import tqdm
7
+ from audiotools import AudioSignal
8
+ from torch_pitch_shift import get_fast_shifts, pitch_shift
9
+ from torch_time_stretch import get_fast_stretches, time_stretch
 
 
10
 
11
 
12
  @argbind.bind(without_prefix=True)
 
15
  dest_folder: Path = None,
16
  n_augmentations: int = 10,
17
  ):
18
+ """
19
+ Augment a folder of audio files by applying audiotools and pedalboard transforms.
20
 
21
+ The dest foler will contain a folder for each of the clean dataset's files.
22
+ Under each of these folders, there will be a clean file and many augmented files.
23
  """
24
  assert audio_folder is not None
25
  assert dest_folder is not None
 
32
 
33
  src = AudioSignal(audio_file).to("cuda" if torch.cuda.is_available() else "cpu")
34
 
 
35
  for i, chunk in tqdm.tqdm(enumerate(src.windows(10, 10))):
36
  # apply pedalboard transforms
37
  for j in range(n_augmentations):
38
  # pitch shift between -7 and 7 semitones
39
  import random
40
+
41
  dst = chunk.clone()
42
  dst.samples = pitch_shift(
43
+ dst.samples,
44
+ shift=random.choice(
45
+ get_fast_shifts(
46
+ src.sample_rate, condition=lambda x: x >= 0.25 and x <= 1.0
47
+ )
48
+ ),
49
+ sample_rate=src.sample_rate,
50
  )
51
  dst.samples = time_stretch(
52
  dst.samples,
53
+ stretch=random.choice(
54
+ get_fast_stretches(
55
+ src.sample_rate,
56
+ condition=lambda x: x >= 0.667 and x <= 1.5,
57
+ )
58
+ ),
59
+ sample_rate=src.sample_rate,
60
  )
61
 
62
  dst.cpu().write(subdir / f"{i}-{j}.wav")
 
66
  args = argbind.parse_args()
67
 
68
  with argbind.scope(args):
69
+ augment()
vampnet/scripts/utils/data/maestro-reorg.py CHANGED
@@ -1,6 +1,6 @@
1
- from pathlib import Path
2
  import json
3
  import os
 
4
 
5
  maestro_path = Path("/media/CHONK/hugo/maestro-v3.0.0")
6
  output_path = Path("/media/CHONK/hugo/maestro-v3.0.0-split")
@@ -14,7 +14,7 @@ train = []
14
  validation = []
15
  test = []
16
  for key, split in maestro["split"].items():
17
- audio_filename = maestro['audio_filename'][key]
18
  if split == "train":
19
  train.append(audio_filename)
20
  elif split == "test":
@@ -36,4 +36,4 @@ for audio_filename in validation:
36
  for audio_filename in test:
37
  p = output_path / "test" / audio_filename
38
  p.parent.mkdir(parents=True, exist_ok=True)
39
- os.symlink(maestro_path / audio_filename, p)
 
 
1
  import json
2
  import os
3
+ from pathlib import Path
4
 
5
  maestro_path = Path("/media/CHONK/hugo/maestro-v3.0.0")
6
  output_path = Path("/media/CHONK/hugo/maestro-v3.0.0-split")
 
14
  validation = []
15
  test = []
16
  for key, split in maestro["split"].items():
17
+ audio_filename = maestro["audio_filename"][key]
18
  if split == "train":
19
  train.append(audio_filename)
20
  elif split == "test":
 
36
  for audio_filename in test:
37
  p = output_path / "test" / audio_filename
38
  p.parent.mkdir(parents=True, exist_ok=True)
39
+ os.symlink(maestro_path / audio_filename, p)
vampnet/scripts/utils/plots.py CHANGED
@@ -2,16 +2,19 @@ import matplotlib.pyplot as plt
2
  import seaborn as sns
3
  from pandas.api.types import CategoricalDtype
4
 
 
5
  def plot_metrics(metrics, condition_to_latex, title, color_palette):
6
  # Add a new column to your dataframe with the latex representation
7
- metrics['condition_latex'] = metrics['condition'].map(condition_to_latex)
8
 
9
  # Order condition_latex as per the condition_to_latex dictionary
10
  cat_type = CategoricalDtype(categories=condition_to_latex.values(), ordered=True)
11
- metrics['condition_latex'] = metrics['condition_latex'].astype(cat_type)
12
 
13
  # Compute mean and std for each condition for each metric
14
- grouped = metrics.groupby('condition_latex')[['mel', 'frechet']].agg(['mean', 'std'])
 
 
15
 
16
  fig, axs = plt.subplots(2, 1, figsize=(7, 5.25))
17
 
@@ -22,16 +25,28 @@ def plot_metrics(metrics, condition_to_latex, title, color_palette):
22
  bar_colors = [color_palette[condition] for condition in grouped.index]
23
 
24
  # Plot mel
25
- sns.boxplot(x='condition_latex', y='mel', data=metrics, ax=axs[0], palette=color_palette, showfliers=False)
26
- axs[0].set_ylabel('Mel Spectrogram Loss \u2190')
27
- axs[0].set_xlabel('') # Remove x-axis label
28
- axs[0].set_xticklabels(grouped.index, rotation=0, ha='center')
 
 
 
 
 
 
 
29
 
30
  # Plot frechet
31
- axs[1].bar(grouped.index, grouped['frechet']['mean'], yerr=grouped['frechet']['std'], color=bar_colors)
32
- axs[1].set_ylabel('FAD \u2190')
33
- axs[1].set_xlabel('') # Remove x-axis label
34
- axs[1].set_xticklabels(grouped.index, rotation=0, ha='center')
 
 
 
 
 
35
 
36
  # Adjust the space between plots
37
  plt.subplots_adjust(hspace=0.1)
@@ -40,4 +55,4 @@ def plot_metrics(metrics, condition_to_latex, title, color_palette):
40
  plt.tight_layout(rect=[0, 0, 1, 0.96])
41
 
42
  # Reduce the space between suptitle and the plot
43
- plt.subplots_adjust(top=0.92)
 
2
  import seaborn as sns
3
  from pandas.api.types import CategoricalDtype
4
 
5
+
6
  def plot_metrics(metrics, condition_to_latex, title, color_palette):
7
  # Add a new column to your dataframe with the latex representation
8
+ metrics["condition_latex"] = metrics["condition"].map(condition_to_latex)
9
 
10
  # Order condition_latex as per the condition_to_latex dictionary
11
  cat_type = CategoricalDtype(categories=condition_to_latex.values(), ordered=True)
12
+ metrics["condition_latex"] = metrics["condition_latex"].astype(cat_type)
13
 
14
  # Compute mean and std for each condition for each metric
15
+ grouped = metrics.groupby("condition_latex")[["mel", "frechet"]].agg(
16
+ ["mean", "std"]
17
+ )
18
 
19
  fig, axs = plt.subplots(2, 1, figsize=(7, 5.25))
20
 
 
25
  bar_colors = [color_palette[condition] for condition in grouped.index]
26
 
27
  # Plot mel
28
+ sns.boxplot(
29
+ x="condition_latex",
30
+ y="mel",
31
+ data=metrics,
32
+ ax=axs[0],
33
+ palette=color_palette,
34
+ showfliers=False,
35
+ )
36
+ axs[0].set_ylabel("Mel Spectrogram Loss \u2190")
37
+ axs[0].set_xlabel("") # Remove x-axis label
38
+ axs[0].set_xticklabels(grouped.index, rotation=0, ha="center")
39
 
40
  # Plot frechet
41
+ axs[1].bar(
42
+ grouped.index,
43
+ grouped["frechet"]["mean"],
44
+ yerr=grouped["frechet"]["std"],
45
+ color=bar_colors,
46
+ )
47
+ axs[1].set_ylabel("FAD \u2190")
48
+ axs[1].set_xlabel("") # Remove x-axis label
49
+ axs[1].set_xticklabels(grouped.index, rotation=0, ha="center")
50
 
51
  # Adjust the space between plots
52
  plt.subplots_adjust(hspace=0.1)
 
55
  plt.tight_layout(rect=[0, 0, 1, 0.96])
56
 
57
  # Reduce the space between suptitle and the plot
58
+ plt.subplots_adjust(top=0.92)
vampnet/scripts/utils/remove_quiet_files.py CHANGED
@@ -1,9 +1,11 @@
1
  # removes files with loudness below 24db
2
 
3
- from pathlib import Path
4
  import shutil
5
- import audiotools as at
 
6
  import argbind
 
 
7
 
8
  @argbind.bind(without_prefix=True)
9
  def remove_quiet_files(
@@ -14,7 +16,7 @@ def remove_quiet_files(
14
  # copy src to dest
15
  dest_dir.mkdir(parents=True, exist_ok=True)
16
  shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
17
-
18
  audio_files = at.util.find_audio(dest_dir)
19
  for audio_file in audio_files:
20
  sig = at.AudioSignal(audio_file)
@@ -22,8 +24,9 @@ def remove_quiet_files(
22
  audio_file.unlink()
23
  print(f"removed {audio_file}")
24
 
 
25
  if __name__ == "__main__":
26
  args = argbind.parse_args()
27
 
28
  with argbind.scope(args):
29
- remove_quiet_files()
 
1
  # removes files with loudness below 24db
2
 
 
3
  import shutil
4
+ from pathlib import Path
5
+
6
  import argbind
7
+ import audiotools as at
8
+
9
 
10
  @argbind.bind(without_prefix=True)
11
  def remove_quiet_files(
 
16
  # copy src to dest
17
  dest_dir.mkdir(parents=True, exist_ok=True)
18
  shutil.copytree(src_dir, dest_dir, dirs_exist_ok=True)
19
+
20
  audio_files = at.util.find_audio(dest_dir)
21
  for audio_file in audio_files:
22
  sig = at.AudioSignal(audio_file)
 
24
  audio_file.unlink()
25
  print(f"removed {audio_file}")
26
 
27
+
28
  if __name__ == "__main__":
29
  args = argbind.parse_args()
30
 
31
  with argbind.scope(args):
32
+ remove_quiet_files()
vampnet/scripts/utils/split.py CHANGED
@@ -1,29 +1,25 @@
1
- from pathlib import Path
2
- import random
3
- import shutil
4
  import os
5
- import json
 
6
 
7
  import argbind
8
  from tqdm import tqdm
9
- from tqdm.contrib.concurrent import thread_map
10
-
11
- from audiotools.core import util
12
 
13
 
14
  @argbind.bind(without_prefix=True)
15
  def train_test_split(
16
- audio_folder: str = ".",
17
  test_size: float = 0.2,
18
  seed: int = 42,
19
  pattern: str = "**/*.mp3",
20
  ):
21
- print(f"finding audio")
22
 
23
  audio_folder = Path(audio_folder)
24
  audio_files = list(tqdm(audio_folder.glob(pattern)))
25
  print(f"found {len(audio_files)} audio files")
26
-
27
  # split according to test_size
28
  n_test = int(len(audio_files) * test_size)
29
  n_train = len(audio_files) - n_test
@@ -35,30 +31,28 @@ def train_test_split(
35
  train_files = audio_files[:n_train]
36
  test_files = audio_files[n_train:]
37
 
38
-
39
  print(f"Train files: {len(train_files)}")
40
  print(f"Test files: {len(test_files)}")
41
  continue_ = input("Continue [yn]? ") or "n"
42
 
43
  if continue_ != "y":
44
  return
45
-
46
- for split, files in (
47
- ("train", train_files), ("test", test_files)
48
- ):
49
  for file in tqdm(files):
50
- out_file = audio_folder.parent / f"{audio_folder.name}-{split}" / Path(file).name
 
 
51
  out_file.parent.mkdir(exist_ok=True, parents=True)
52
  os.symlink(file, out_file)
53
 
54
  # save split as json
55
  with open(Path(audio_folder) / f"{split}.json", "w") as f:
56
  json.dump([str(f) for f in files], f)
57
-
58
 
59
-
60
  if __name__ == "__main__":
61
- args = argbind.parse_args()
62
 
63
  with argbind.scope(args):
64
- train_test_split()
 
1
+ import json
 
 
2
  import os
3
+ import random
4
+ from pathlib import Path
5
 
6
  import argbind
7
  from tqdm import tqdm
 
 
 
8
 
9
 
10
  @argbind.bind(without_prefix=True)
11
  def train_test_split(
12
+ audio_folder: str = ".",
13
  test_size: float = 0.2,
14
  seed: int = 42,
15
  pattern: str = "**/*.mp3",
16
  ):
17
+ print("finding audio")
18
 
19
  audio_folder = Path(audio_folder)
20
  audio_files = list(tqdm(audio_folder.glob(pattern)))
21
  print(f"found {len(audio_files)} audio files")
22
+
23
  # split according to test_size
24
  n_test = int(len(audio_files) * test_size)
25
  n_train = len(audio_files) - n_test
 
31
  train_files = audio_files[:n_train]
32
  test_files = audio_files[n_train:]
33
 
 
34
  print(f"Train files: {len(train_files)}")
35
  print(f"Test files: {len(test_files)}")
36
  continue_ = input("Continue [yn]? ") or "n"
37
 
38
  if continue_ != "y":
39
  return
40
+
41
+ for split, files in (("train", train_files), ("test", test_files)):
 
 
42
  for file in tqdm(files):
43
+ out_file = (
44
+ audio_folder.parent / f"{audio_folder.name}-{split}" / Path(file).name
45
+ )
46
  out_file.parent.mkdir(exist_ok=True, parents=True)
47
  os.symlink(file, out_file)
48
 
49
  # save split as json
50
  with open(Path(audio_folder) / f"{split}.json", "w") as f:
51
  json.dump([str(f) for f in files], f)
 
52
 
53
+
54
  if __name__ == "__main__":
55
+ args = argbind.parse_args()
56
 
57
  with argbind.scope(args):
58
+ train_test_split()
vampnet/scripts/utils/split_long_audio_file.py CHANGED
@@ -1,34 +1,37 @@
1
  from pathlib import Path
2
- import argbind
3
 
 
4
  import audiotools as at
5
  import tqdm
6
 
7
 
8
  @argbind.bind(without_prefix=True)
9
- def split_long_audio_file(
10
- file: str = None,
11
- max_chunk_size_s: int = 60*10
12
- ):
13
  file = Path(file)
14
  output_dir = file.parent / file.stem
15
  output_dir.mkdir()
16
-
17
  sig = at.AudioSignal(file)
18
 
19
  # split into chunks
20
- for i, sig in tqdm.tqdm(enumerate(sig.windows(
21
- window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
22
- preprocess=True))
 
 
 
 
 
23
  ):
24
  sig.write(output_dir / f"{i}.wav")
25
 
26
  print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
27
-
28
  return output_dir
29
 
 
30
  if __name__ == "__main__":
31
  args = argbind.parse_args()
32
 
33
  with argbind.scope(args):
34
- split_long_audio_file()
 
1
  from pathlib import Path
 
2
 
3
+ import argbind
4
  import audiotools as at
5
  import tqdm
6
 
7
 
8
  @argbind.bind(without_prefix=True)
9
+ def split_long_audio_file(file: str = None, max_chunk_size_s: int = 60 * 10):
 
 
 
10
  file = Path(file)
11
  output_dir = file.parent / file.stem
12
  output_dir.mkdir()
13
+
14
  sig = at.AudioSignal(file)
15
 
16
  # split into chunks
17
+ for i, sig in tqdm.tqdm(
18
+ enumerate(
19
+ sig.windows(
20
+ window_duration=max_chunk_size_s,
21
+ hop_duration=max_chunk_size_s / 2,
22
+ preprocess=True,
23
+ )
24
+ )
25
  ):
26
  sig.write(output_dir / f"{i}.wav")
27
 
28
  print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
29
+
30
  return output_dir
31
 
32
+
33
  if __name__ == "__main__":
34
  args = argbind.parse_args()
35
 
36
  with argbind.scope(args):
37
+ split_long_audio_file()
vampnet/scripts/utils/stage.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import subprocess
3
  from pathlib import Path
4
 
5
  import argbind
 
1
  import os
 
2
  from pathlib import Path
3
 
4
  import argbind
vampnet/scripts/utils/visualize_embeddings.py CHANGED
@@ -3,19 +3,20 @@ TODO: train a linear probe
3
  usage:
4
  python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_audio /path/to/audio/labels --output_dir /path/to/output
5
  """
 
 
 
6
  from pathlib import Path
7
  from typing import List
8
 
9
- import audiotools as at
10
- from audiotools import AudioSignal
11
  import argbind
12
- import torch
13
  import numpy as np
14
- import zipfile
15
- import json
 
16
 
17
  from vampnet.interface import Interface
18
- import tqdm
19
 
20
  # bind the Interface to argbind
21
  Interface = argbind.bind(Interface)
@@ -34,6 +35,7 @@ def smart_plotly_export(fig, save_path: Path):
34
  # TODO: come back and make this prettier
35
  elif img_format == "numpy":
36
  import io
 
37
  from PIL import Image
38
 
39
  def plotly_fig2array(fig):
@@ -72,6 +74,7 @@ def dim_reduce(annotated_embeddings, layer, output_dir, n_components=3, method="
72
 
73
  if method == "umap":
74
  from umap import UMAP
 
75
  reducer = UMAP(n_components=n_components)
76
  elif method == "tsne":
77
  from sklearn.manifold import TSNE
@@ -100,11 +103,16 @@ def dim_reduce(annotated_embeddings, layer, output_dir, n_components=3, method="
100
  )
101
  if n_components == 2:
102
  fig = px.scatter(
103
- df, x="x", y="y", color="label", hover_name="name", title=fig_title,
 
 
 
 
 
104
  )
105
 
106
  elif n_components == 3:
107
- df['z'] = projs[:, 2]
108
  fig = px.scatter_3d(
109
  df, x="x", y="y", z="z", color="label", hover_name="name", title=fig_title
110
  )
@@ -139,15 +147,15 @@ def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
139
  # [20, 1, 600ish, 768]
140
 
141
  # squeeze batch dim (1 bc layer should be dim 0)
142
- assert (
143
- embeddings.shape[1] == 1
144
- ), f"expected batch dim to be 1, got {embeddings.shape[0]}"
145
  embeddings = embeddings.squeeze(1)
146
 
147
  num_layers = embeddings.shape[0]
148
- assert (
149
- layer < num_layers
150
- ), f"layer {layer} is out of bounds for model with {num_layers} layers"
151
 
152
  # do meanpooling over the time dimension
153
  embeddings = embeddings.mean(dim=-2)
@@ -169,7 +177,6 @@ class AnnotatedEmbedding:
169
  def save(self, path):
170
  """Save the Embedding object to a given path as a zip file."""
171
  with zipfile.ZipFile(path, "w") as archive:
172
-
173
  # Save numpy array
174
  with archive.open("embedding.npy", "w") as f:
175
  np.save(f, self.embedding)
@@ -187,7 +194,6 @@ class AnnotatedEmbedding:
187
  def load(cls, path):
188
  """Load the Embedding object from a given zip path."""
189
  with zipfile.ZipFile(path, "r") as archive:
190
-
191
  # Load numpy array
192
  with archive.open("embedding.npy") as f:
193
  embedding = np.load(f)
 
3
  usage:
4
  python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_audio /path/to/audio/labels --output_dir /path/to/output
5
  """
6
+
7
+ import json
8
+ import zipfile
9
  from pathlib import Path
10
  from typing import List
11
 
 
 
12
  import argbind
13
+ import audiotools as at
14
  import numpy as np
15
+ import torch
16
+ import tqdm
17
+ from audiotools import AudioSignal
18
 
19
  from vampnet.interface import Interface
 
20
 
21
  # bind the Interface to argbind
22
  Interface = argbind.bind(Interface)
 
35
  # TODO: come back and make this prettier
36
  elif img_format == "numpy":
37
  import io
38
+
39
  from PIL import Image
40
 
41
  def plotly_fig2array(fig):
 
74
 
75
  if method == "umap":
76
  from umap import UMAP
77
+
78
  reducer = UMAP(n_components=n_components)
79
  elif method == "tsne":
80
  from sklearn.manifold import TSNE
 
103
  )
104
  if n_components == 2:
105
  fig = px.scatter(
106
+ df,
107
+ x="x",
108
+ y="y",
109
+ color="label",
110
+ hover_name="name",
111
+ title=fig_title,
112
  )
113
 
114
  elif n_components == 3:
115
+ df["z"] = projs[:, 2]
116
  fig = px.scatter_3d(
117
  df, x="x", y="y", z="z", color="label", hover_name="name", title=fig_title
118
  )
 
147
  # [20, 1, 600ish, 768]
148
 
149
  # squeeze batch dim (1 bc layer should be dim 0)
150
+ assert embeddings.shape[1] == 1, (
151
+ f"expected batch dim to be 1, got {embeddings.shape[0]}"
152
+ )
153
  embeddings = embeddings.squeeze(1)
154
 
155
  num_layers = embeddings.shape[0]
156
+ assert layer < num_layers, (
157
+ f"layer {layer} is out of bounds for model with {num_layers} layers"
158
+ )
159
 
160
  # do meanpooling over the time dimension
161
  embeddings = embeddings.mean(dim=-2)
 
177
  def save(self, path):
178
  """Save the Embedding object to a given path as a zip file."""
179
  with zipfile.ZipFile(path, "w") as archive:
 
180
  # Save numpy array
181
  with archive.open("embedding.npy", "w") as f:
182
  np.save(f, self.embedding)
 
194
  def load(cls, path):
195
  """Load the Embedding object from a given zip path."""
196
  with zipfile.ZipFile(path, "r") as archive:
 
197
  # Load numpy array
198
  with archive.open("embedding.npy") as f:
199
  embedding = np.load(f)
vampnet/scripts/utils/xeno-canto-dl.py CHANGED
@@ -1,6 +1,5 @@
1
  from xenopy import Query
2
 
3
-
4
  SPECIES = [
5
  "American Robin",
6
  "Northern Cardinal",
@@ -208,27 +207,36 @@ SPECIES = [
208
  "American Woodcock",
209
  "Wilson's Phalarope",
210
  "Red-necked Phalarope",
211
- "Red Phalarope"
212
  ]
213
 
214
  from pathlib import Path
215
 
 
216
  def remove_spaces(s):
217
  return s.replace(" ", "")
218
 
219
- for species in SPECIES:
 
220
  if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
221
  continue
222
  try:
223
  q = Query(
224
- name=species, q="A", length="10-30",
225
- )
 
 
226
 
227
  # retrieve metadata
228
  metafiles = q.retrieve_meta(verbose=True)
229
  # retrieve recordings
230
- q.retrieve_recordings(multiprocess=True, nproc=10, attempts=10, outdir="/media/CHONK/hugo/xeno-canto-full/")
 
 
 
 
 
231
 
232
  except:
233
  print("Failed to download " + species)
234
- continue
 
1
  from xenopy import Query
2
 
 
3
  SPECIES = [
4
  "American Robin",
5
  "Northern Cardinal",
 
207
  "American Woodcock",
208
  "Wilson's Phalarope",
209
  "Red-necked Phalarope",
210
+ "Red Phalarope",
211
  ]
212
 
213
  from pathlib import Path
214
 
215
+
216
  def remove_spaces(s):
217
  return s.replace(" ", "")
218
 
219
+
220
+ for species in SPECIES:
221
  if Path("/media/CHONK/hugo/xeno-canto-full/" + remove_spaces(species)).exists():
222
  continue
223
  try:
224
  q = Query(
225
+ name=species,
226
+ q="A",
227
+ length="10-30",
228
+ )
229
 
230
  # retrieve metadata
231
  metafiles = q.retrieve_meta(verbose=True)
232
  # retrieve recordings
233
+ q.retrieve_recordings(
234
+ multiprocess=True,
235
+ nproc=10,
236
+ attempts=10,
237
+ outdir="/media/CHONK/hugo/xeno-canto-full/",
238
+ )
239
 
240
  except:
241
  print("Failed to download " + species)
242
+ continue
vampnet/setup.py CHANGED
@@ -1,5 +1,4 @@
1
- from setuptools import find_packages
2
- from setuptools import setup
3
 
4
  with open("README.md") as f:
5
  long_description = f.read()
@@ -29,7 +28,7 @@ setup(
29
  "Cython",
30
  ],
31
  install_requires=[
32
- "Cython", # Added by WAM because it seems to be needed by this repo?
33
  "torch",
34
  "pydantic==2.10.6",
35
  "argbind>=0.3.2",
@@ -40,8 +39,6 @@ setup(
40
  "gradio",
41
  "loralib",
42
  "torch_pitch_shift",
43
- "plotly", # Added by WAM for clustering (see https://github.com/hugofloresgarcia/vampnet/issues/20)
44
  "pyharp",
45
-
46
  ],
47
  )
 
1
+ from setuptools import find_packages, setup
 
2
 
3
  with open("README.md") as f:
4
  long_description = f.read()
 
28
  "Cython",
29
  ],
30
  install_requires=[
31
+ "Cython", # Added by WhAM because it seems to be needed by this repo?
32
  "torch",
33
  "pydantic==2.10.6",
34
  "argbind>=0.3.2",
 
39
  "gradio",
40
  "loralib",
41
  "torch_pitch_shift",
 
42
  "pyharp",
 
43
  ],
44
  )
vampnet/vampnet/__init__.py CHANGED
@@ -1,6 +1,4 @@
1
-
2
- from . import modules
3
- from . import scheduler
4
  from .interface import Interface
5
 
6
  __version__ = "0.0.1"
 
1
+ from . import modules, scheduler
 
 
2
  from .interface import Interface
3
 
4
  __version__ = "0.0.1"
vampnet/vampnet/beats.py CHANGED
@@ -1,19 +1,14 @@
1
  import json
2
  import logging
3
- import warnings
4
  from dataclasses import dataclass
5
  from pathlib import Path
6
- from typing import Any
7
- from typing import List
8
- from typing import Tuple
9
- from typing import Union
10
 
11
  import librosa
12
- import torch
13
  import numpy as np
 
14
  from audiotools import AudioSignal
15
 
16
-
17
  logging.basicConfig(level=logging.INFO)
18
 
19
  ###################
@@ -60,7 +55,6 @@ def mkdir(path: Union[Path, str]) -> Path:
60
  return p
61
 
62
 
63
-
64
  ###################
65
  # beat data #
66
  ###################
@@ -204,7 +198,9 @@ class WaveBeat(BeatTracker):
204
  def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
205
  from wavebeat.dstcn import dsTCNModel
206
 
207
- model = dsTCNModel.load_from_checkpoint(ckpt_path, map_location=torch.device(device), weights_only=False)
 
 
208
  model.eval()
209
 
210
  self.device = device
@@ -247,4 +243,4 @@ def load_beat_tracker(beat_tracker: str, **kwargs) -> BeatTracker:
247
  f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}"
248
  )
249
 
250
- return BEAT_TRACKER_REGISTRY[beat_tracker](**kwargs)
 
1
  import json
2
  import logging
 
3
  from dataclasses import dataclass
4
  from pathlib import Path
5
+ from typing import List, Tuple, Union
 
 
 
6
 
7
  import librosa
 
8
  import numpy as np
9
+ import torch
10
  from audiotools import AudioSignal
11
 
 
12
  logging.basicConfig(level=logging.INFO)
13
 
14
  ###################
 
55
  return p
56
 
57
 
 
58
  ###################
59
  # beat data #
60
  ###################
 
198
  def __init__(self, ckpt_path: str = "checkpoints/wavebeat", device: str = "cpu"):
199
  from wavebeat.dstcn import dsTCNModel
200
 
201
+ model = dsTCNModel.load_from_checkpoint(
202
+ ckpt_path, map_location=torch.device(device), weights_only=False
203
+ )
204
  model.eval()
205
 
206
  self.device = device
 
243
  f"Unknown beat tracker {beat_tracker}. Available: {list_beat_trackers()}"
244
  )
245
 
246
+ return BEAT_TRACKER_REGISTRY[beat_tracker](**kwargs)
vampnet/vampnet/interface.py CHANGED
@@ -1,19 +1,17 @@
1
- import os
2
- from pathlib import Path
3
  import math
 
4
 
5
- import torch
6
  import numpy as np
 
7
  from audiotools import AudioSignal
8
- import tqdm
9
-
10
- from .modules.transformer import VampNet
11
- from .beats import WaveBeat
12
- from .mask import *
13
 
14
  # from dac.model.dac import DAC
15
  from lac.model.lac import LAC as DAC
16
 
 
 
 
 
17
 
18
  def signal_concat(
19
  audio_signals: list,
@@ -24,7 +22,7 @@ def signal_concat(
24
 
25
 
26
  def _load_model(
27
- ckpt: str,
28
  lora_ckpt: str = None,
29
  device: str = "cpu",
30
  chunk_size_s: int = 10,
@@ -41,7 +39,9 @@ def _load_model(
41
  if should_cont != "y":
42
  raise Exception("aborting")
43
  else:
44
- model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False)
 
 
45
 
46
  model.to(device)
47
  model.eval()
@@ -49,7 +49,6 @@ def _load_model(
49
  return model
50
 
51
 
52
-
53
  class Interface(torch.nn.Module):
54
  def __init__(
55
  self,
@@ -60,8 +59,8 @@ class Interface(torch.nn.Module):
60
  codec_ckpt: str = None,
61
  wavebeat_ckpt: str = None,
62
  device: str = "cpu",
63
- coarse_chunk_size_s: int = 10,
64
- coarse2fine_chunk_size_s: int = 3,
65
  ):
66
  super().__init__()
67
  assert codec_ckpt is not None, "must provide a codec checkpoint"
@@ -98,7 +97,7 @@ class Interface(torch.nn.Module):
98
  self.device = device
99
 
100
  def lora_load(
101
- self,
102
  coarse_ckpt: str = None,
103
  c2f_ckpt: str = None,
104
  full_ckpts: bool = False,
@@ -106,7 +105,7 @@ class Interface(torch.nn.Module):
106
  if full_ckpts:
107
  if coarse_ckpt is not None:
108
  self.coarse = _load_model(
109
- ckpt=coarse_ckpt,
110
  device=self.device,
111
  chunk_size_s=self.coarse.chunk_size_s,
112
  )
@@ -129,7 +128,7 @@ class Interface(torch.nn.Module):
129
  print(f"loading c2f from {c2f_ckpt}")
130
  self.c2f.load_state_dict(state_dict, strict=False)
131
  self.c2f.to(self.device)
132
-
133
  def s2t(self, seconds: float):
134
  """seconds to tokens"""
135
  if isinstance(seconds, np.ndarray):
@@ -140,7 +139,7 @@ class Interface(torch.nn.Module):
140
  def s2t2s(self, seconds: float):
141
  """seconds to tokens to seconds"""
142
  return self.t2s(self.s2t(seconds))
143
-
144
  def t2s(self, tokens: int):
145
  """tokens to seconds"""
146
  return tokens * self.codec.hop_length / self.codec.sample_rate
@@ -159,7 +158,7 @@ class Interface(torch.nn.Module):
159
 
160
  def to_signal(self, z: torch.Tensor):
161
  return self.coarse.to_signal(z, self.codec)
162
-
163
  def preprocess(self, signal: AudioSignal):
164
  signal = (
165
  signal.clone()
@@ -169,41 +168,39 @@ class Interface(torch.nn.Module):
169
  .ensure_max_of_audio(1.0)
170
  )
171
  return signal
172
-
173
  @torch.inference_mode()
174
  def encode(self, signal: AudioSignal):
175
  signal = self.preprocess(signal).to(self.device)
176
  z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
177
  return z
178
 
179
- def snap_to_beats(
180
- self,
181
- signal: AudioSignal
182
- ):
183
  assert hasattr(self, "beat_tracker"), "No beat tracker loaded"
184
  beats, downbeats = self.beat_tracker.extract_beats(signal)
185
-
186
  # trim the signa around the first beat time
187
- samples_begin = int(beats[0] * signal.sample_rate )
188
  samples_end = int(beats[-1] * signal.sample_rate)
189
  print(beats[0])
190
  signal = signal.clone().trim(samples_begin, signal.length - samples_end)
191
 
192
  return signal
193
 
194
- def make_beat_mask(self,
195
- signal: AudioSignal,
196
- before_beat_s: float = 0.0,
197
- after_beat_s: float = 0.02,
198
- mask_downbeats: bool = True,
199
- mask_upbeats: bool = True,
200
- downbeat_downsample_factor: int = None,
201
- beat_downsample_factor: int = None,
202
- dropout: float = 0.0,
203
- invert: bool = True,
 
204
  ):
205
- """make a beat synced mask. that is, make a mask that
206
- places 1s at and around the beat, and 0s everywhere else.
207
  """
208
  assert self.beat_tracker is not None, "No beat tracker loaded"
209
 
@@ -214,14 +211,16 @@ class Interface(torch.nn.Module):
214
  beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
215
 
216
  # remove downbeats from beats
217
- beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))]
 
 
218
  beats_z = beats_z.tolist()
219
  downbeats_z = downbeats_z.tolist()
220
 
221
- # make the mask
222
  seq_len = self.s2t(signal.duration)
223
  mask = torch.zeros(seq_len, device=self.device)
224
-
225
  mask_b4 = self.s2t(before_beat_s)
226
  mask_after = self.s2t(after_beat_s)
227
 
@@ -241,44 +240,39 @@ class Interface(torch.nn.Module):
241
  downbeats_z = downbeats_z[::downbeat_downsample_factor]
242
  print(f"beats_z: {len(beats_z)}")
243
  print(f"downbeats_z: {len(downbeats_z)}")
244
-
245
  if mask_upbeats:
246
  for beat_idx in beats_z:
247
  _slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
248
- num_steps = mask[_slice[0]:_slice[1]].shape[0]
249
  _m = torch.ones(num_steps, device=self.device)
250
  _m_mask = torch.bernoulli(_m * (1 - dropout))
251
  _m = _m * _m_mask.long()
252
-
253
- mask[_slice[0]:_slice[1]] = _m
254
 
255
  if mask_downbeats:
256
  for downbeat_idx in downbeats_z:
257
  _slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
258
- num_steps = mask[_slice[0]:_slice[1]].shape[0]
259
  _m = torch.ones(num_steps, device=self.device)
260
  _m_mask = torch.bernoulli(_m * (1 - dropout))
261
  _m = _m * _m_mask.long()
262
-
263
- mask[_slice[0]:_slice[1]] = _m
264
-
265
  mask = mask.clamp(0, 1)
266
  if invert:
267
  mask = 1 - mask
268
-
269
  mask = mask[None, None, :].bool().long()
270
  if self.c2f is not None:
271
  mask = mask.repeat(1, self.c2f.n_codebooks, 1)
272
  else:
273
  mask = mask.repeat(1, self.coarse.n_codebooks, 1)
274
  return mask
275
-
276
- def coarse_to_fine(
277
- self,
278
- z: torch.Tensor,
279
- mask: torch.Tensor = None,
280
- **kwargs
281
- ):
282
  assert self.c2f is not None, "No coarse2fine model loaded"
283
  length = z.shape[-1]
284
  chunk_len = self.s2t(self.c2f.chunk_size_s)
@@ -288,49 +282,57 @@ class Interface(torch.nn.Module):
288
  if length % chunk_len != 0:
289
  pad_len = chunk_len - (length % chunk_len)
290
  z = torch.nn.functional.pad(z, (0, pad_len))
291
- mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None
 
 
 
 
292
 
293
  n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
294
  if n_codebooks_to_append > 0:
295
- z = torch.cat([
296
- z,
297
- torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
298
- ], dim=1)
 
 
 
 
 
299
 
300
  # set the mask to 0 for all conditioning codebooks
301
  if mask is not None:
302
  mask = mask.clone()
303
- mask[:, :self.c2f.n_conditioning_codebooks, :] = 0
304
 
305
  fine_z = []
306
  for i in range(n_chunks):
307
  chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
308
- mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None
309
-
 
 
 
 
310
  chunk = self.c2f.generate(
311
  codec=self.codec,
312
  time_steps=chunk_len,
313
  start_tokens=chunk,
314
  return_signal=False,
315
  mask=mask_chunk,
316
- **kwargs
317
  )
318
  fine_z.append(chunk)
319
 
320
  fine_z = torch.cat(fine_z, dim=-1)
321
  return fine_z[:, :, :length].clone()
322
-
323
- def coarse_vamp(
324
- self,
325
- z,
326
- mask,
327
- return_mask=False,
328
- gen_fn=None,
329
- **kwargs
330
- ):
331
  # coarse z
332
  cz = z[:, : self.coarse.n_codebooks, :].clone()
333
- assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
 
 
334
 
335
  mask = mask[:, : self.coarse.n_codebooks, :]
336
 
@@ -342,41 +344,39 @@ class Interface(torch.nn.Module):
342
  codec=self.codec,
343
  time_steps=cz.shape[-1],
344
  start_tokens=cz,
345
- mask=mask,
346
  return_signal=False,
347
- **kwargs
348
  )
349
 
350
  # add the fine codes back in
351
- c_vamp = torch.cat(
352
- [c_vamp, z[:, self.coarse.n_codebooks :, :]],
353
- dim=1
354
- )
355
 
356
  if return_mask:
357
  return c_vamp, cz_masked
358
-
359
  return c_vamp
360
 
361
 
362
  if __name__ == "__main__":
363
- import audiotools as at
364
  import logging
 
 
 
365
  logger = logging.getLogger()
366
  logger.setLevel(logging.INFO)
367
  torch.set_printoptions(threshold=10000)
368
  at.util.seed(42)
369
 
370
  interface = Interface(
371
- coarse_ckpt="./models/vampnet/coarse.pth",
372
- coarse2fine_ckpt="./models/vampnet/c2f.pth",
373
  codec_ckpt="./models/vampnet/codec.pth",
374
- device="cuda",
375
- wavebeat_ckpt="./models/wavebeat.pth"
376
  )
377
 
378
-
379
- sig = at.AudioSignal('assets/example.wav')
380
 
381
  z = interface.encode(sig)
382
  breakpoint()
@@ -398,19 +398,18 @@ if __name__ == "__main__":
398
  # mask = codebook_unmask(mask, 0)
399
 
400
  mask = inpaint(z, n_prefix=100, n_suffix=100)
401
-
402
  zv, mask_z = interface.coarse_vamp(
403
- z,
404
  mask=mask,
405
  sampling_steps=36,
406
  temperature=8.0,
407
- return_mask=True,
408
- gen_fn=interface.coarse.generate
409
  )
410
-
411
 
412
  use_coarse2fine = True
413
- if use_coarse2fine:
414
  zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
415
  breakpoint()
416
 
@@ -418,5 +417,3 @@ if __name__ == "__main__":
418
 
419
  sig = interface.to_signal(zv).cpu()
420
  print("done")
421
-
422
-
 
 
 
1
  import math
2
+ from pathlib import Path
3
 
 
4
  import numpy as np
5
+ import torch
6
  from audiotools import AudioSignal
 
 
 
 
 
7
 
8
  # from dac.model.dac import DAC
9
  from lac.model.lac import LAC as DAC
10
 
11
+ from .beats import WaveBeat
12
+ from .mask import *
13
+ from .modules.transformer import VampNet
14
+
15
 
16
  def signal_concat(
17
  audio_signals: list,
 
22
 
23
 
24
  def _load_model(
25
+ ckpt: str,
26
  lora_ckpt: str = None,
27
  device: str = "cpu",
28
  chunk_size_s: int = 10,
 
39
  if should_cont != "y":
40
  raise Exception("aborting")
41
  else:
42
+ model.load_state_dict(
43
+ torch.load(lora_ckpt, map_location="cpu"), strict=False
44
+ )
45
 
46
  model.to(device)
47
  model.eval()
 
49
  return model
50
 
51
 
 
52
  class Interface(torch.nn.Module):
53
  def __init__(
54
  self,
 
59
  codec_ckpt: str = None,
60
  wavebeat_ckpt: str = None,
61
  device: str = "cpu",
62
+ coarse_chunk_size_s: int = 10,
63
+ coarse2fine_chunk_size_s: int = 3,
64
  ):
65
  super().__init__()
66
  assert codec_ckpt is not None, "must provide a codec checkpoint"
 
97
  self.device = device
98
 
99
  def lora_load(
100
+ self,
101
  coarse_ckpt: str = None,
102
  c2f_ckpt: str = None,
103
  full_ckpts: bool = False,
 
105
  if full_ckpts:
106
  if coarse_ckpt is not None:
107
  self.coarse = _load_model(
108
+ ckpt=coarse_ckpt,
109
  device=self.device,
110
  chunk_size_s=self.coarse.chunk_size_s,
111
  )
 
128
  print(f"loading c2f from {c2f_ckpt}")
129
  self.c2f.load_state_dict(state_dict, strict=False)
130
  self.c2f.to(self.device)
131
+
132
  def s2t(self, seconds: float):
133
  """seconds to tokens"""
134
  if isinstance(seconds, np.ndarray):
 
139
  def s2t2s(self, seconds: float):
140
  """seconds to tokens to seconds"""
141
  return self.t2s(self.s2t(seconds))
142
+
143
  def t2s(self, tokens: int):
144
  """tokens to seconds"""
145
  return tokens * self.codec.hop_length / self.codec.sample_rate
 
158
 
159
  def to_signal(self, z: torch.Tensor):
160
  return self.coarse.to_signal(z, self.codec)
161
+
162
  def preprocess(self, signal: AudioSignal):
163
  signal = (
164
  signal.clone()
 
168
  .ensure_max_of_audio(1.0)
169
  )
170
  return signal
171
+
172
  @torch.inference_mode()
173
  def encode(self, signal: AudioSignal):
174
  signal = self.preprocess(signal).to(self.device)
175
  z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
176
  return z
177
 
178
+ def snap_to_beats(self, signal: AudioSignal):
 
 
 
179
  assert hasattr(self, "beat_tracker"), "No beat tracker loaded"
180
  beats, downbeats = self.beat_tracker.extract_beats(signal)
181
+
182
  # trim the signa around the first beat time
183
+ samples_begin = int(beats[0] * signal.sample_rate)
184
  samples_end = int(beats[-1] * signal.sample_rate)
185
  print(beats[0])
186
  signal = signal.clone().trim(samples_begin, signal.length - samples_end)
187
 
188
  return signal
189
 
190
+ def make_beat_mask(
191
+ self,
192
+ signal: AudioSignal,
193
+ before_beat_s: float = 0.0,
194
+ after_beat_s: float = 0.02,
195
+ mask_downbeats: bool = True,
196
+ mask_upbeats: bool = True,
197
+ downbeat_downsample_factor: int = None,
198
+ beat_downsample_factor: int = None,
199
+ dropout: float = 0.0,
200
+ invert: bool = True,
201
  ):
202
+ """make a beat synced mask. that is, make a mask that
203
+ places 1s at and around the beat, and 0s everywhere else.
204
  """
205
  assert self.beat_tracker is not None, "No beat tracker loaded"
206
 
 
211
  beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)
212
 
213
  # remove downbeats from beats
214
+ beats_z = torch.tensor(beats_z)[
215
+ ~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))
216
+ ]
217
  beats_z = beats_z.tolist()
218
  downbeats_z = downbeats_z.tolist()
219
 
220
+ # make the mask
221
  seq_len = self.s2t(signal.duration)
222
  mask = torch.zeros(seq_len, device=self.device)
223
+
224
  mask_b4 = self.s2t(before_beat_s)
225
  mask_after = self.s2t(after_beat_s)
226
 
 
240
  downbeats_z = downbeats_z[::downbeat_downsample_factor]
241
  print(f"beats_z: {len(beats_z)}")
242
  print(f"downbeats_z: {len(downbeats_z)}")
243
+
244
  if mask_upbeats:
245
  for beat_idx in beats_z:
246
  _slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
247
+ num_steps = mask[_slice[0] : _slice[1]].shape[0]
248
  _m = torch.ones(num_steps, device=self.device)
249
  _m_mask = torch.bernoulli(_m * (1 - dropout))
250
  _m = _m * _m_mask.long()
251
+
252
+ mask[_slice[0] : _slice[1]] = _m
253
 
254
  if mask_downbeats:
255
  for downbeat_idx in downbeats_z:
256
  _slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
257
+ num_steps = mask[_slice[0] : _slice[1]].shape[0]
258
  _m = torch.ones(num_steps, device=self.device)
259
  _m_mask = torch.bernoulli(_m * (1 - dropout))
260
  _m = _m * _m_mask.long()
261
+
262
+ mask[_slice[0] : _slice[1]] = _m
263
+
264
  mask = mask.clamp(0, 1)
265
  if invert:
266
  mask = 1 - mask
267
+
268
  mask = mask[None, None, :].bool().long()
269
  if self.c2f is not None:
270
  mask = mask.repeat(1, self.c2f.n_codebooks, 1)
271
  else:
272
  mask = mask.repeat(1, self.coarse.n_codebooks, 1)
273
  return mask
274
+
275
+ def coarse_to_fine(self, z: torch.Tensor, mask: torch.Tensor = None, **kwargs):
 
 
 
 
 
276
  assert self.c2f is not None, "No coarse2fine model loaded"
277
  length = z.shape[-1]
278
  chunk_len = self.s2t(self.c2f.chunk_size_s)
 
282
  if length % chunk_len != 0:
283
  pad_len = chunk_len - (length % chunk_len)
284
  z = torch.nn.functional.pad(z, (0, pad_len))
285
+ mask = (
286
+ torch.nn.functional.pad(mask, (0, pad_len))
287
+ if mask is not None
288
+ else None
289
+ )
290
 
291
  n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
292
  if n_codebooks_to_append > 0:
293
+ z = torch.cat(
294
+ [
295
+ z,
296
+ torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1])
297
+ .long()
298
+ .to(self.device),
299
+ ],
300
+ dim=1,
301
+ )
302
 
303
  # set the mask to 0 for all conditioning codebooks
304
  if mask is not None:
305
  mask = mask.clone()
306
+ mask[:, : self.c2f.n_conditioning_codebooks, :] = 0
307
 
308
  fine_z = []
309
  for i in range(n_chunks):
310
  chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
311
+ mask_chunk = (
312
+ mask[:, :, i * chunk_len : (i + 1) * chunk_len]
313
+ if mask is not None
314
+ else None
315
+ )
316
+
317
  chunk = self.c2f.generate(
318
  codec=self.codec,
319
  time_steps=chunk_len,
320
  start_tokens=chunk,
321
  return_signal=False,
322
  mask=mask_chunk,
323
+ **kwargs,
324
  )
325
  fine_z.append(chunk)
326
 
327
  fine_z = torch.cat(fine_z, dim=-1)
328
  return fine_z[:, :, :length].clone()
329
+
330
+ def coarse_vamp(self, z, mask, return_mask=False, gen_fn=None, **kwargs):
 
 
 
 
 
 
 
331
  # coarse z
332
  cz = z[:, : self.coarse.n_codebooks, :].clone()
333
+ assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), (
334
+ f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"
335
+ )
336
 
337
  mask = mask[:, : self.coarse.n_codebooks, :]
338
 
 
344
  codec=self.codec,
345
  time_steps=cz.shape[-1],
346
  start_tokens=cz,
347
+ mask=mask,
348
  return_signal=False,
349
+ **kwargs,
350
  )
351
 
352
  # add the fine codes back in
353
+ c_vamp = torch.cat([c_vamp, z[:, self.coarse.n_codebooks :, :]], dim=1)
 
 
 
354
 
355
  if return_mask:
356
  return c_vamp, cz_masked
357
+
358
  return c_vamp
359
 
360
 
361
  if __name__ == "__main__":
 
362
  import logging
363
+
364
+ import audiotools as at
365
+
366
  logger = logging.getLogger()
367
  logger.setLevel(logging.INFO)
368
  torch.set_printoptions(threshold=10000)
369
  at.util.seed(42)
370
 
371
  interface = Interface(
372
+ coarse_ckpt="./models/vampnet/coarse.pth",
373
+ coarse2fine_ckpt="./models/vampnet/c2f.pth",
374
  codec_ckpt="./models/vampnet/codec.pth",
375
+ device="cuda",
376
+ wavebeat_ckpt="./models/wavebeat.pth",
377
  )
378
 
379
+ sig = at.AudioSignal("assets/example.wav")
 
380
 
381
  z = interface.encode(sig)
382
  breakpoint()
 
398
  # mask = codebook_unmask(mask, 0)
399
 
400
  mask = inpaint(z, n_prefix=100, n_suffix=100)
401
+
402
  zv, mask_z = interface.coarse_vamp(
403
+ z,
404
  mask=mask,
405
  sampling_steps=36,
406
  temperature=8.0,
407
+ return_mask=True,
408
+ gen_fn=interface.coarse.generate,
409
  )
 
410
 
411
  use_coarse2fine = True
412
+ if use_coarse2fine:
413
  zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
414
  breakpoint()
415
 
 
417
 
418
  sig = interface.to_signal(zv).cpu()
419
  print("done")
 
 
vampnet/vampnet/mask.py CHANGED
@@ -1,33 +1,34 @@
1
- from typing import Optional
2
-
3
  import torch
4
  from audiotools import AudioSignal
5
 
6
  from .util import scalar_to_batch_tensor
7
 
 
8
  def _gamma(r):
9
  return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
10
 
 
11
  def _invgamma(y):
12
  if not torch.is_tensor(y):
13
  y = torch.tensor(y)[None]
14
  return 2 * y.acos() / torch.pi
15
 
 
16
  def full_mask(x: torch.Tensor):
17
  assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
18
  return torch.ones_like(x).long()
19
 
 
20
  def empty_mask(x: torch.Tensor):
21
  assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
22
  return torch.zeros_like(x).long()
23
 
24
- def apply_mask(
25
- x: torch.Tensor,
26
- mask: torch.Tensor,
27
- mask_token: int
28
- ):
29
  assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
30
- assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
 
 
31
  assert mask.dtype == torch.long, "mask must be long dtype, but got {mask.dtype}"
32
  assert ~torch.any(mask > 1), "mask must be binary"
33
  assert ~torch.any(mask < 0), "mask must be binary"
@@ -37,10 +38,8 @@ def apply_mask(
37
 
38
  return x, mask
39
 
40
- def random(
41
- x: torch.Tensor,
42
- r: torch.Tensor
43
- ):
44
  assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
45
  if not isinstance(r, torch.Tensor):
46
  r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
@@ -53,6 +52,7 @@ def random(
53
 
54
  return mask
55
 
 
56
  def linear_random(
57
  x: torch.Tensor,
58
  r: torch.Tensor,
@@ -71,19 +71,21 @@ def linear_random(
71
 
72
  return mask
73
 
74
- def inpaint(x: torch.Tensor,
 
 
75
  n_prefix,
76
  n_suffix,
77
  ):
78
  assert n_prefix is not None
79
  assert n_suffix is not None
80
-
81
  mask = full_mask(x)
82
 
83
  # if we have a prefix or suffix, set their mask prob to 0
84
  if n_prefix > 0:
85
  if not isinstance(n_prefix, torch.Tensor):
86
- n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
87
  for i, n in enumerate(n_prefix):
88
  if n > 0:
89
  mask[i, :, :n] = 0.0
@@ -94,13 +96,15 @@ def inpaint(x: torch.Tensor,
94
  if n > 0:
95
  mask[i, :, -n:] = 0.0
96
 
97
-
98
  return mask
99
 
100
- def periodic_mask(x: torch.Tensor,
101
- period: int, width: int = 1,
102
- random_roll=False,
103
- ):
 
 
 
104
  mask = full_mask(x)
105
  if period == 0:
106
  return mask
@@ -113,8 +117,8 @@ def periodic_mask(x: torch.Tensor,
113
  for j in range(mask.shape[-1]):
114
  if j % factor == 0:
115
  # figure out how wide the mask should be
116
- j_start = max(0, j - width // 2 )
117
- j_end = min(mask.shape[-1] - 1, j + width // 2 ) + 1
118
  # flip a coin for each position in the mask
119
  j_mask = torch.bernoulli(torch.ones(j_end - j_start))
120
  assert torch.all(j_mask == 1)
@@ -129,10 +133,8 @@ def periodic_mask(x: torch.Tensor,
129
 
130
  return mask
131
 
132
- def codebook_unmask(
133
- mask: torch.Tensor,
134
- n_conditioning_codebooks: int
135
- ):
136
  if n_conditioning_codebooks == None:
137
  return mask
138
  # if we have any conditioning codebooks, set their mask to 0
@@ -140,18 +142,18 @@ def codebook_unmask(
140
  mask[:, :n_conditioning_codebooks, :] = 0
141
  return mask
142
 
 
143
  def codebook_mask(mask: torch.Tensor, start: int):
144
  mask = mask.clone()
145
  mask[:, start:, :] = 1
146
  return mask
147
 
148
- def mask_and(
149
- mask1: torch.Tensor,
150
- mask2: torch.Tensor
151
- ):
152
  assert mask1.shape == mask2.shape, "masks must be same shape"
153
  return torch.min(mask1, mask2)
154
 
 
155
  def dropout(
156
  mask: torch.Tensor,
157
  p: float,
@@ -164,19 +166,20 @@ def dropout(
164
  mask = ~mask.round().bool()
165
  return mask.long()
166
 
167
- def mask_or(
168
- mask1: torch.Tensor,
169
- mask2: torch.Tensor
170
- ):
171
- assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
172
  assert mask1.max() <= 1, "mask1 must be binary"
173
  assert mask2.max() <= 1, "mask2 must be binary"
174
  assert mask1.min() >= 0, "mask1 must be binary"
175
  assert mask2.min() >= 0, "mask2 must be binary"
176
  return (mask1 + mask2).clamp(0, 1)
177
 
 
178
  def time_stretch_mask(
179
- x: torch.Tensor,
180
  stretch_factor: int,
181
  ):
182
  assert stretch_factor >= 1, "stretch factor must be >= 1"
@@ -189,18 +192,19 @@ def time_stretch_mask(
189
  mask = periodic_mask(x, stretch_factor, width=1)
190
  return mask
191
 
 
192
  def _onset_times_madmom(wav_path, sample_rate, hop_length):
193
- from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
 
194
  proc = RNNOnsetProcessor(online=False)
195
- onsetproc = OnsetPeakPickingProcessor(
196
- threshold=0.3, fps=sample_rate / hop_length
197
- )
198
  act = proc(wav_path)
199
  return onsetproc(act)
200
 
201
 
202
  def _onset_times_librosa(wav_path, sample_rate, hop_length):
203
  import librosa
 
204
  y, sr = librosa.load(wav_path, sr=sample_rate)
205
  onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
206
  onset_frames = librosa.onset.onset_detect(
@@ -209,18 +213,14 @@ def _onset_times_librosa(wav_path, sample_rate, hop_length):
209
  return librosa.frames_to_time(onset_frames, sr=sr, hop_length=hop_length)
210
 
211
 
212
- def onset_mask(
213
- sig: AudioSignal,
214
- z: torch.Tensor,
215
- interface,
216
- width: int = 1
217
- ):
218
- import librosa
219
  import tempfile
220
- import numpy as np
 
221
 
222
  try:
223
  import madmom # noqa: F401
 
224
  _get_onset_times = _onset_times_madmom
225
  except ImportError:
226
  print("madmom not installed, falling back to librosa for onset detection")
@@ -228,7 +228,7 @@ def onset_mask(
228
 
229
  hop_length = interface.codec.hop_length
230
 
231
- with tempfile.NamedTemporaryFile(suffix='.wav') as f:
232
  sig = sig.clone()
233
  sig.write(f.name)
234
 
@@ -238,9 +238,9 @@ def onset_mask(
238
  )
239
 
240
  if onset_indices.shape[0] == 0:
241
- mask = empty_mask(z)
242
- print(f"no onsets found, returning empty mask")
243
- else:
244
  torch.set_printoptions(threshold=1000)
245
  print("onset indices: ", onset_indices)
246
  print("onset times: ", onset_times)
@@ -251,12 +251,11 @@ def onset_mask(
251
  for onset_index in onset_indices:
252
  onset_index = min(onset_index, n_timesteps - 1)
253
  onset_index = max(onset_index, 0)
254
- mask[:, :, onset_index - width:onset_index + width] = 0.0
255
 
256
  print(mask)
257
-
258
- return mask
259
 
 
260
 
261
 
262
  if __name__ == "__main__":
 
 
 
1
  import torch
2
  from audiotools import AudioSignal
3
 
4
  from .util import scalar_to_batch_tensor
5
 
6
+
7
  def _gamma(r):
8
  return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
9
 
10
+
11
  def _invgamma(y):
12
  if not torch.is_tensor(y):
13
  y = torch.tensor(y)[None]
14
  return 2 * y.acos() / torch.pi
15
 
16
+
17
  def full_mask(x: torch.Tensor):
18
  assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
19
  return torch.ones_like(x).long()
20
 
21
+
22
  def empty_mask(x: torch.Tensor):
23
  assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
24
  return torch.zeros_like(x).long()
25
 
26
+
27
+ def apply_mask(x: torch.Tensor, mask: torch.Tensor, mask_token: int):
 
 
 
28
  assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
29
+ assert mask.shape == x.shape, (
30
+ f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
31
+ )
32
  assert mask.dtype == torch.long, "mask must be long dtype, but got {mask.dtype}"
33
  assert ~torch.any(mask > 1), "mask must be binary"
34
  assert ~torch.any(mask < 0), "mask must be binary"
 
38
 
39
  return x, mask
40
 
41
+
42
+ def random(x: torch.Tensor, r: torch.Tensor):
 
 
43
  assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
44
  if not isinstance(r, torch.Tensor):
45
  r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device)
 
52
 
53
  return mask
54
 
55
+
56
  def linear_random(
57
  x: torch.Tensor,
58
  r: torch.Tensor,
 
71
 
72
  return mask
73
 
74
+
75
+ def inpaint(
76
+ x: torch.Tensor,
77
  n_prefix,
78
  n_suffix,
79
  ):
80
  assert n_prefix is not None
81
  assert n_suffix is not None
82
+
83
  mask = full_mask(x)
84
 
85
  # if we have a prefix or suffix, set their mask prob to 0
86
  if n_prefix > 0:
87
  if not isinstance(n_prefix, torch.Tensor):
88
+ n_prefix = scalar_to_batch_tensor(n_prefix, x.shape[0]).to(x.device)
89
  for i, n in enumerate(n_prefix):
90
  if n > 0:
91
  mask[i, :, :n] = 0.0
 
96
  if n > 0:
97
  mask[i, :, -n:] = 0.0
98
 
 
99
  return mask
100
 
101
+
102
+ def periodic_mask(
103
+ x: torch.Tensor,
104
+ period: int,
105
+ width: int = 1,
106
+ random_roll=False,
107
+ ):
108
  mask = full_mask(x)
109
  if period == 0:
110
  return mask
 
117
  for j in range(mask.shape[-1]):
118
  if j % factor == 0:
119
  # figure out how wide the mask should be
120
+ j_start = max(0, j - width // 2)
121
+ j_end = min(mask.shape[-1] - 1, j + width // 2) + 1
122
  # flip a coin for each position in the mask
123
  j_mask = torch.bernoulli(torch.ones(j_end - j_start))
124
  assert torch.all(j_mask == 1)
 
133
 
134
  return mask
135
 
136
+
137
+ def codebook_unmask(mask: torch.Tensor, n_conditioning_codebooks: int):
 
 
138
  if n_conditioning_codebooks == None:
139
  return mask
140
  # if we have any conditioning codebooks, set their mask to 0
 
142
  mask[:, :n_conditioning_codebooks, :] = 0
143
  return mask
144
 
145
+
146
  def codebook_mask(mask: torch.Tensor, start: int):
147
  mask = mask.clone()
148
  mask[:, start:, :] = 1
149
  return mask
150
 
151
+
152
+ def mask_and(mask1: torch.Tensor, mask2: torch.Tensor):
 
 
153
  assert mask1.shape == mask2.shape, "masks must be same shape"
154
  return torch.min(mask1, mask2)
155
 
156
+
157
  def dropout(
158
  mask: torch.Tensor,
159
  p: float,
 
166
  mask = ~mask.round().bool()
167
  return mask.long()
168
 
169
+
170
+ def mask_or(mask1: torch.Tensor, mask2: torch.Tensor):
171
+ assert mask1.shape == mask2.shape, (
172
+ f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
173
+ )
174
  assert mask1.max() <= 1, "mask1 must be binary"
175
  assert mask2.max() <= 1, "mask2 must be binary"
176
  assert mask1.min() >= 0, "mask1 must be binary"
177
  assert mask2.min() >= 0, "mask2 must be binary"
178
  return (mask1 + mask2).clamp(0, 1)
179
 
180
+
181
  def time_stretch_mask(
182
+ x: torch.Tensor,
183
  stretch_factor: int,
184
  ):
185
  assert stretch_factor >= 1, "stretch factor must be >= 1"
 
192
  mask = periodic_mask(x, stretch_factor, width=1)
193
  return mask
194
 
195
+
196
  def _onset_times_madmom(wav_path, sample_rate, hop_length):
197
+ from madmom.features.onsets import OnsetPeakPickingProcessor, RNNOnsetProcessor
198
+
199
  proc = RNNOnsetProcessor(online=False)
200
+ onsetproc = OnsetPeakPickingProcessor(threshold=0.3, fps=sample_rate / hop_length)
 
 
201
  act = proc(wav_path)
202
  return onsetproc(act)
203
 
204
 
205
  def _onset_times_librosa(wav_path, sample_rate, hop_length):
206
  import librosa
207
+
208
  y, sr = librosa.load(wav_path, sr=sample_rate)
209
  onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
210
  onset_frames = librosa.onset.onset_detect(
 
213
  return librosa.frames_to_time(onset_frames, sr=sr, hop_length=hop_length)
214
 
215
 
216
+ def onset_mask(sig: AudioSignal, z: torch.Tensor, interface, width: int = 1):
 
 
 
 
 
 
217
  import tempfile
218
+
219
+ import librosa
220
 
221
  try:
222
  import madmom # noqa: F401
223
+
224
  _get_onset_times = _onset_times_madmom
225
  except ImportError:
226
  print("madmom not installed, falling back to librosa for onset detection")
 
228
 
229
  hop_length = interface.codec.hop_length
230
 
231
+ with tempfile.NamedTemporaryFile(suffix=".wav") as f:
232
  sig = sig.clone()
233
  sig.write(f.name)
234
 
 
238
  )
239
 
240
  if onset_indices.shape[0] == 0:
241
+ mask = empty_mask(z)
242
+ print("no onsets found, returning empty mask")
243
+ else:
244
  torch.set_printoptions(threshold=1000)
245
  print("onset indices: ", onset_indices)
246
  print("onset times: ", onset_times)
 
251
  for onset_index in onset_indices:
252
  onset_index = min(onset_index, n_timesteps - 1)
253
  onset_index = max(onset_index, 0)
254
+ mask[:, :, onset_index - width : onset_index + width] = 0.0
255
 
256
  print(mask)
 
 
257
 
258
+ return mask
259
 
260
 
261
  if __name__ == "__main__":
vampnet/vampnet/modules/__init__.py CHANGED
@@ -3,4 +3,4 @@ import audiotools
3
  audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
4
  audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
5
 
6
- from .transformer import VampNet
 
3
  audiotools.ml.BaseModel.INTERN += ["vampnet.modules.**"]
4
  audiotools.ml.BaseModel.EXTERN += ["einops", "flash_attn.flash_attention", "loralib"]
5
 
6
+ from .transformer import VampNet
vampnet/vampnet/modules/activations.py CHANGED
@@ -1,9 +1,7 @@
1
  import math
2
- import numpy as np
3
  import torch
4
  import torch.nn as nn
5
- import torch.nn.functional as F
6
- from einops import rearrange
7
 
8
 
9
  class NewGELU(nn.Module):
@@ -25,6 +23,7 @@ class NewGELU(nn.Module):
25
  )
26
  )
27
 
 
28
  class GatedGELU(nn.Module):
29
  def __init__(self):
30
  super().__init__()
@@ -34,6 +33,7 @@ class GatedGELU(nn.Module):
34
  p1, p2 = x.chunk(2, dim=dim)
35
  return p1 * self.gelu(p2)
36
 
 
37
  class Snake1d(nn.Module):
38
  def __init__(self, channels):
39
  super().__init__()
@@ -42,6 +42,7 @@ class Snake1d(nn.Module):
42
  def forward(self, x):
43
  return x + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * x).pow(2)
44
 
 
45
  def get_activation(name: str = "relu"):
46
  if name == "relu":
47
  return nn.ReLU
@@ -52,4 +53,4 @@ def get_activation(name: str = "relu"):
52
  elif name == "snake":
53
  return Snake1d
54
  else:
55
- raise ValueError(f"Unrecognized activation {name}")
 
1
  import math
2
+
3
  import torch
4
  import torch.nn as nn
 
 
5
 
6
 
7
  class NewGELU(nn.Module):
 
23
  )
24
  )
25
 
26
+
27
  class GatedGELU(nn.Module):
28
  def __init__(self):
29
  super().__init__()
 
33
  p1, p2 = x.chunk(2, dim=dim)
34
  return p1 * self.gelu(p2)
35
 
36
+
37
  class Snake1d(nn.Module):
38
  def __init__(self, channels):
39
  super().__init__()
 
42
  def forward(self, x):
43
  return x + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * x).pow(2)
44
 
45
+
46
  def get_activation(name: str = "relu"):
47
  if name == "relu":
48
  return nn.ReLU
 
53
  elif name == "snake":
54
  return Snake1d
55
  else:
56
+ raise ValueError(f"Unrecognized activation {name}")
vampnet/vampnet/modules/layers.py CHANGED
@@ -1,13 +1,11 @@
1
- import time
2
- from typing import Optional
3
- from typing import Tuple
4
 
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
- from einops import rearrange
9
  from torch.nn.utils import weight_norm
10
 
 
11
  # Scripting this brings model speed up 1.4x
12
  @torch.jit.script
13
  def snake(x, alpha):
@@ -132,10 +130,10 @@ class CodebookEmbedding(nn.Module):
132
  self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
133
 
134
  def from_codes(self, codes: torch.Tensor, codec):
135
- """
136
- get a sequence of continuous embeddings from a sequence of discrete codes.
137
  unlike it's counterpart in the original VQ-VAE, this function adds for any special tokens
138
- necessary for the language model, like <MASK>.
139
  """
140
  n_codebooks = codes.shape[1]
141
  latent = []
@@ -161,4 +159,3 @@ class CodebookEmbedding(nn.Module):
161
  """
162
  x = self.out_proj(latents)
163
  return x
164
-
 
1
+ from typing import Optional, Tuple
 
 
2
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
 
6
  from torch.nn.utils import weight_norm
7
 
8
+
9
  # Scripting this brings model speed up 1.4x
10
  @torch.jit.script
11
  def snake(x, alpha):
 
130
  self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
131
 
132
  def from_codes(self, codes: torch.Tensor, codec):
133
+ """
134
+ get a sequence of continuous embeddings from a sequence of discrete codes.
135
  unlike it's counterpart in the original VQ-VAE, this function adds for any special tokens
136
+ necessary for the language model, like <MASK>.
137
  """
138
  n_codebooks = codes.shape[1]
139
  latent = []
 
159
  """
160
  x = self.out_proj(latents)
161
  return x
 
vampnet/vampnet/modules/transformer.py CHANGED
@@ -1,22 +1,19 @@
1
- import math
2
  import logging
3
- from typing import Optional, Tuple, Union
 
4
 
 
 
5
  import numpy as np
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
  from einops import rearrange
10
- import loralib as lora
11
- import audiotools as at
12
 
13
- from .activations import get_activation
14
- from .layers import CodebookEmbedding
15
- from .layers import FiLM
16
- from .layers import SequentialWithFiLM
17
- from .layers import WNConv1d
18
- from ..util import scalar_to_batch_tensor, codebook_flatten, codebook_unflatten
19
  from ..mask import _gamma
 
 
 
20
 
21
  LORA_R = 8
22
 
@@ -279,6 +276,7 @@ class TransformerLayer(nn.Module):
279
 
280
  if flash_attn:
281
  from flash_attn.flash_attention import FlashMHA
 
282
  self.self_attn = FlashMHA(
283
  embed_dim=d_model,
284
  num_heads=n_heads,
@@ -410,9 +408,15 @@ class TransformerStack(nn.Module):
410
  def subsequent_mask(self, size):
411
  return torch.ones(1, size, size).tril().bool()
412
 
413
- def forward(self, x, x_mask, cond=None, src=None, src_mask=None,
414
- return_activations: bool = False
415
- ):
 
 
 
 
 
 
416
  """Computes a full transformer stack
417
  Parameters
418
  ----------
@@ -454,7 +458,6 @@ class TransformerStack(nn.Module):
454
  if return_activations:
455
  activations.append(x.detach())
456
 
457
-
458
  out = self.norm(x) if self.norm is not None else x
459
  if return_activations:
460
  return out, torch.stack(activations)
@@ -475,10 +478,12 @@ class VampNet(at.ml.BaseModel):
475
  vocab_size: int = 1024,
476
  flash_attn: bool = True,
477
  noise_mode: str = "mask",
478
- dropout: float = 0.1
479
  ):
480
  super().__init__()
481
- assert r_cond_dim == 0, f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
 
 
482
  self.n_heads = n_heads
483
  self.n_layers = n_layers
484
  self.r_cond_dim = r_cond_dim
@@ -530,13 +535,15 @@ class VampNet(at.ml.BaseModel):
530
  x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
531
 
532
  x = rearrange(x, "b d n -> b n d")
533
- out = self.transformer(x=x, x_mask=x_mask, return_activations=return_activations)
 
 
534
  if return_activations:
535
  out, activations = out
536
 
537
  out = rearrange(out, "b n d -> b d n")
538
 
539
- out = self.classifier(out, None) # no cond here!
540
 
541
  out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
542
 
@@ -544,7 +551,7 @@ class VampNet(at.ml.BaseModel):
544
  return out, activations
545
  else:
546
  return out
547
-
548
  def r_embed(self, r, max_positions=10000):
549
  if self.r_cond_dim > 0:
550
  dtype = r.dtype
@@ -564,11 +571,11 @@ class VampNet(at.ml.BaseModel):
564
  return emb.to(dtype)
565
  else:
566
  return r
567
-
568
  @torch.no_grad()
569
  def to_signal(self, z, codec):
570
  """
571
- convert a sequence of latents to a signal.
572
  """
573
  assert z.ndim == 3
574
 
@@ -588,7 +595,6 @@ class VampNet(at.ml.BaseModel):
588
 
589
  return signal
590
 
591
-
592
  @torch.no_grad()
593
  def generate(
594
  self,
@@ -604,16 +610,14 @@ class VampNet(at.ml.BaseModel):
604
  typical_min_tokens=1,
605
  top_p=None,
606
  return_signal=True,
607
- seed: int = None,
608
  sample_cutoff: float = 1.0,
609
  ):
610
  if seed is not None:
611
  at.util.seed(seed)
612
  logging.debug(f"beginning generation with {sampling_steps} steps")
613
 
614
-
615
-
616
- #####################
617
  # resolve initial z #
618
  #####################
619
  z = start_tokens
@@ -625,7 +629,6 @@ class VampNet(at.ml.BaseModel):
625
 
626
  logging.debug(f"created z with shape {z.shape}")
627
 
628
-
629
  #################
630
  # resolve mask #
631
  #################
@@ -636,9 +639,8 @@ class VampNet(at.ml.BaseModel):
636
  if mask.ndim == 2:
637
  mask = mask[:, None, :].repeat(1, z.shape[1], 1)
638
  # init_mask = mask.clone()
639
-
640
- logging.debug(f"created mask with shape {mask.shape}")
641
 
 
642
 
643
  ###########
644
  # set up #
@@ -663,33 +665,33 @@ class VampNet(at.ml.BaseModel):
663
  logging.debug(f"step {i} of {sampling_steps}")
664
 
665
  # our current schedule step
666
- r = scalar_to_batch_tensor(
667
- (i + 1) / sampling_steps,
668
- z.shape[0]
669
- ).to(z.device)
670
  logging.debug(f"r: {r}")
671
 
672
  # get latents
673
  latents = self.embedding.from_codes(z_masked, codec)
674
  logging.debug(f"computed latents with shape: {latents.shape}")
675
 
676
-
677
  # infer from latents
678
  # NOTE: this collapses the codebook dimension into the sequence dimension
679
- logits = self.forward(latents) # b, prob, seq
680
  logits = logits.permute(0, 2, 1) # b, seq, prob
681
  b = logits.shape[0]
682
 
683
  logging.debug(f"permuted logits with shape: {logits.shape}")
684
 
685
  sampled_z, selected_probs = sample_from_logits(
686
- logits, sample=(
687
- (i / sampling_steps) <= sample_cutoff
688
- ),
689
  temperature=sampling_temperature,
690
- typical_filtering=typical_filtering, typical_mass=typical_mass,
 
691
  typical_min_tokens=typical_min_tokens,
692
- top_k=None, top_p=top_p, return_probs=True,
 
 
693
  )
694
 
695
  logging.debug(f"sampled z with shape: {sampled_z.shape}")
@@ -697,46 +699,38 @@ class VampNet(at.ml.BaseModel):
697
  # flatten z_masked and mask, so we can deal with the sampling logic
698
  # we'll unflatten them at the end of the loop for the next forward pass
699
  # remove conditioning codebooks, we'll add them back at the end
700
- z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
701
 
702
  mask = (z_masked == self.mask_token).int()
703
-
704
  # update the mask, remove conditioning codebooks from the mask
705
  logging.debug(f"updated mask with shape: {mask.shape}")
706
  # add z back into sampled z where the mask was false
707
- sampled_z = torch.where(
708
- mask.bool(), sampled_z, z_masked
709
- )
710
  logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
711
 
712
  # ignore any tokens that weren't masked
713
- selected_probs = torch.where(
714
- mask.bool(), selected_probs, torch.inf
715
- )
716
 
717
  # get the num tokens to mask, according to the schedule
718
- num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
 
 
719
  logging.debug(f"num to mask: {num_to_mask}")
720
 
721
  if i != (sampling_steps - 1):
722
  num_to_mask = torch.maximum(
723
  torch.tensor(1),
724
- torch.minimum(
725
- mask.sum(dim=-1, keepdim=True) - 1,
726
- num_to_mask
727
- )
728
  )
729
 
730
-
731
  # get our new mask
732
  mask = mask_by_random_topk(
733
- num_to_mask, selected_probs, mask_temperature * (1-r)
734
- )
735
 
736
  # update the mask
737
- z_masked = torch.where(
738
- mask.bool(), self.mask_token, sampled_z
739
- )
740
  logging.debug(f"updated z_masked with shape: {z_masked.shape}")
741
 
742
  z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
@@ -745,35 +739,37 @@ class VampNet(at.ml.BaseModel):
745
 
746
  # add conditioning codebooks back to z_masked
747
  z_masked = torch.cat(
748
- (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
 
 
 
749
  )
750
- logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")
751
-
752
 
753
  # add conditioning codebooks back to sampled_z
754
  sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
755
  sampled_z = torch.cat(
756
- (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
757
  )
758
 
759
- logging.debug(f"finished sampling")
760
 
761
  if return_signal:
762
  return self.to_signal(sampled_z, codec)
763
  else:
764
  return sampled_z
765
 
 
766
  def sample_from_logits(
767
- logits,
768
- sample: bool = True,
769
- temperature: float = 1.0,
770
- top_k: int = None,
771
- top_p: float = None,
772
- typical_filtering: bool = False,
773
- typical_mass: float = 0.2,
774
- typical_min_tokens: int = 1,
775
- return_probs: bool = False
776
- ):
777
  """Convenience function to sample from a categorial distribution with input as
778
  unnormalized logits.
779
 
@@ -801,9 +797,8 @@ def sample_from_logits(
801
  shp = logits.shape[:-1]
802
 
803
  if typical_filtering:
804
- typical_filter(logits,
805
- typical_mass=typical_mass,
806
- typical_min_tokens=typical_min_tokens
807
  )
808
 
809
  # Apply top_k sampling
@@ -846,21 +841,20 @@ def sample_from_logits(
846
  return token, token_probs
847
  else:
848
  return token
849
-
850
 
851
 
852
  def mask_by_random_topk(
853
- num_to_mask: int,
854
- probs: torch.Tensor,
855
- temperature: float = 1.0,
856
- ):
857
  """
858
  Args:
859
  num_to_mask (int): number of tokens to mask
860
  probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
861
  temperature (float, optional): temperature. Defaults to 1.0.
862
  """
863
- logging.debug(f"masking by random topk")
864
  logging.debug(f"num to mask: {num_to_mask}")
865
  logging.debug(f"probs shape: {probs.shape}")
866
  logging.debug(f"temperature: {temperature}")
@@ -875,9 +869,7 @@ def mask_by_random_topk(
875
  logging.debug(f"sorted idx shape: {sorted_idx.shape}")
876
 
877
  # get the cut off threshold, given the mask length
878
- cut_off = torch.take_along_dim(
879
- sorted_confidence, num_to_mask, axis=-1
880
- )
881
  logging.debug(f"cut off shape: {cut_off.shape}")
882
 
883
  # mask out the tokens
@@ -886,10 +878,12 @@ def mask_by_random_topk(
886
 
887
  return mask
888
 
 
889
  def typical_filter(
890
- logits,
891
- typical_mass: float = 0.95,
892
- typical_min_tokens: int = 1,):
 
893
  nb, nt, _ = logits.shape
894
  x_flat = rearrange(logits, "b t l -> (b t ) l")
895
  x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
@@ -898,9 +892,7 @@ def typical_filter(
898
 
899
  c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
900
  c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
901
- x_flat_cumsum = (
902
- x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
903
- )
904
 
905
  last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
906
  sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
@@ -933,7 +925,7 @@ if __name__ == "__main__":
933
  ).to(device)
934
 
935
  r = torch.zeros(batch_size).to(device)
936
-
937
  z_mask_latent = torch.rand(
938
  batch_size, model.latent_dim * model.n_codebooks, seq_len
939
  ).to(device)
@@ -942,12 +934,10 @@ if __name__ == "__main__":
942
  pred = z_hat.argmax(dim=1)
943
  pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
944
 
945
- print(f"model has {num_params(model)/1e6:<.3f}M parameters")
946
  print(f"prediction has shape {pred.shape}")
947
  breakpoint()
948
 
949
  args = argbind.parse_args()
950
  with argbind.scope(args):
951
  try_model()
952
-
953
-
 
 
1
  import logging
2
+ import math
3
+ from typing import Optional
4
 
5
+ import audiotools as at
6
+ import loralib as lora
7
  import numpy as np
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
  from einops import rearrange
 
 
12
 
 
 
 
 
 
 
13
  from ..mask import _gamma
14
+ from ..util import codebook_flatten, codebook_unflatten, scalar_to_batch_tensor
15
+ from .activations import get_activation
16
+ from .layers import CodebookEmbedding, FiLM, SequentialWithFiLM, WNConv1d
17
 
18
  LORA_R = 8
19
 
 
276
 
277
  if flash_attn:
278
  from flash_attn.flash_attention import FlashMHA
279
+
280
  self.self_attn = FlashMHA(
281
  embed_dim=d_model,
282
  num_heads=n_heads,
 
408
  def subsequent_mask(self, size):
409
  return torch.ones(1, size, size).tril().bool()
410
 
411
+ def forward(
412
+ self,
413
+ x,
414
+ x_mask,
415
+ cond=None,
416
+ src=None,
417
+ src_mask=None,
418
+ return_activations: bool = False,
419
+ ):
420
  """Computes a full transformer stack
421
  Parameters
422
  ----------
 
458
  if return_activations:
459
  activations.append(x.detach())
460
 
 
461
  out = self.norm(x) if self.norm is not None else x
462
  if return_activations:
463
  return out, torch.stack(activations)
 
478
  vocab_size: int = 1024,
479
  flash_attn: bool = True,
480
  noise_mode: str = "mask",
481
+ dropout: float = 0.1,
482
  ):
483
  super().__init__()
484
+ assert r_cond_dim == 0, (
485
+ f"r_cond_dim must be 0 (not supported), but got {r_cond_dim}"
486
+ )
487
  self.n_heads = n_heads
488
  self.n_layers = n_layers
489
  self.r_cond_dim = r_cond_dim
 
535
  x_mask = torch.ones_like(x, dtype=torch.bool)[:, :1, :].squeeze(1)
536
 
537
  x = rearrange(x, "b d n -> b n d")
538
+ out = self.transformer(
539
+ x=x, x_mask=x_mask, return_activations=return_activations
540
+ )
541
  if return_activations:
542
  out, activations = out
543
 
544
  out = rearrange(out, "b n d -> b d n")
545
 
546
+ out = self.classifier(out, None) # no cond here!
547
 
548
  out = rearrange(out, "b (p c) t -> b p (t c)", c=self.n_predict_codebooks)
549
 
 
551
  return out, activations
552
  else:
553
  return out
554
+
555
  def r_embed(self, r, max_positions=10000):
556
  if self.r_cond_dim > 0:
557
  dtype = r.dtype
 
571
  return emb.to(dtype)
572
  else:
573
  return r
574
+
575
  @torch.no_grad()
576
  def to_signal(self, z, codec):
577
  """
578
+ convert a sequence of latents to a signal.
579
  """
580
  assert z.ndim == 3
581
 
 
595
 
596
  return signal
597
 
 
598
  @torch.no_grad()
599
  def generate(
600
  self,
 
610
  typical_min_tokens=1,
611
  top_p=None,
612
  return_signal=True,
613
+ seed: int = None,
614
  sample_cutoff: float = 1.0,
615
  ):
616
  if seed is not None:
617
  at.util.seed(seed)
618
  logging.debug(f"beginning generation with {sampling_steps} steps")
619
 
620
+ #####################
 
 
621
  # resolve initial z #
622
  #####################
623
  z = start_tokens
 
629
 
630
  logging.debug(f"created z with shape {z.shape}")
631
 
 
632
  #################
633
  # resolve mask #
634
  #################
 
639
  if mask.ndim == 2:
640
  mask = mask[:, None, :].repeat(1, z.shape[1], 1)
641
  # init_mask = mask.clone()
 
 
642
 
643
+ logging.debug(f"created mask with shape {mask.shape}")
644
 
645
  ###########
646
  # set up #
 
665
  logging.debug(f"step {i} of {sampling_steps}")
666
 
667
  # our current schedule step
668
+ r = scalar_to_batch_tensor((i + 1) / sampling_steps, z.shape[0]).to(
669
+ z.device
670
+ )
 
671
  logging.debug(f"r: {r}")
672
 
673
  # get latents
674
  latents = self.embedding.from_codes(z_masked, codec)
675
  logging.debug(f"computed latents with shape: {latents.shape}")
676
 
 
677
  # infer from latents
678
  # NOTE: this collapses the codebook dimension into the sequence dimension
679
+ logits = self.forward(latents) # b, prob, seq
680
  logits = logits.permute(0, 2, 1) # b, seq, prob
681
  b = logits.shape[0]
682
 
683
  logging.debug(f"permuted logits with shape: {logits.shape}")
684
 
685
  sampled_z, selected_probs = sample_from_logits(
686
+ logits,
687
+ sample=((i / sampling_steps) <= sample_cutoff),
 
688
  temperature=sampling_temperature,
689
+ typical_filtering=typical_filtering,
690
+ typical_mass=typical_mass,
691
  typical_min_tokens=typical_min_tokens,
692
+ top_k=None,
693
+ top_p=top_p,
694
+ return_probs=True,
695
  )
696
 
697
  logging.debug(f"sampled z with shape: {sampled_z.shape}")
 
699
  # flatten z_masked and mask, so we can deal with the sampling logic
700
  # we'll unflatten them at the end of the loop for the next forward pass
701
  # remove conditioning codebooks, we'll add them back at the end
702
+ z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks :, :])
703
 
704
  mask = (z_masked == self.mask_token).int()
705
+
706
  # update the mask, remove conditioning codebooks from the mask
707
  logging.debug(f"updated mask with shape: {mask.shape}")
708
  # add z back into sampled z where the mask was false
709
+ sampled_z = torch.where(mask.bool(), sampled_z, z_masked)
 
 
710
  logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
711
 
712
  # ignore any tokens that weren't masked
713
+ selected_probs = torch.where(mask.bool(), selected_probs, torch.inf)
 
 
714
 
715
  # get the num tokens to mask, according to the schedule
716
+ num_to_mask = (
717
+ torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
718
+ )
719
  logging.debug(f"num to mask: {num_to_mask}")
720
 
721
  if i != (sampling_steps - 1):
722
  num_to_mask = torch.maximum(
723
  torch.tensor(1),
724
+ torch.minimum(mask.sum(dim=-1, keepdim=True) - 1, num_to_mask),
 
 
 
725
  )
726
 
 
727
  # get our new mask
728
  mask = mask_by_random_topk(
729
+ num_to_mask, selected_probs, mask_temperature * (1 - r)
730
+ )
731
 
732
  # update the mask
733
+ z_masked = torch.where(mask.bool(), self.mask_token, sampled_z)
 
 
734
  logging.debug(f"updated z_masked with shape: {z_masked.shape}")
735
 
736
  z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
 
739
 
740
  # add conditioning codebooks back to z_masked
741
  z_masked = torch.cat(
742
+ (z[:, : self.n_conditioning_codebooks, :], z_masked), dim=1
743
+ )
744
+ logging.debug(
745
+ f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}"
746
  )
 
 
747
 
748
  # add conditioning codebooks back to sampled_z
749
  sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
750
  sampled_z = torch.cat(
751
+ (z[:, : self.n_conditioning_codebooks, :], sampled_z), dim=1
752
  )
753
 
754
+ logging.debug("finished sampling")
755
 
756
  if return_signal:
757
  return self.to_signal(sampled_z, codec)
758
  else:
759
  return sampled_z
760
 
761
+
762
  def sample_from_logits(
763
+ logits,
764
+ sample: bool = True,
765
+ temperature: float = 1.0,
766
+ top_k: int = None,
767
+ top_p: float = None,
768
+ typical_filtering: bool = False,
769
+ typical_mass: float = 0.2,
770
+ typical_min_tokens: int = 1,
771
+ return_probs: bool = False,
772
+ ):
773
  """Convenience function to sample from a categorial distribution with input as
774
  unnormalized logits.
775
 
 
797
  shp = logits.shape[:-1]
798
 
799
  if typical_filtering:
800
+ typical_filter(
801
+ logits, typical_mass=typical_mass, typical_min_tokens=typical_min_tokens
 
802
  )
803
 
804
  # Apply top_k sampling
 
841
  return token, token_probs
842
  else:
843
  return token
 
844
 
845
 
846
  def mask_by_random_topk(
847
+ num_to_mask: int,
848
+ probs: torch.Tensor,
849
+ temperature: float = 1.0,
850
+ ):
851
  """
852
  Args:
853
  num_to_mask (int): number of tokens to mask
854
  probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
855
  temperature (float, optional): temperature. Defaults to 1.0.
856
  """
857
+ logging.debug("masking by random topk")
858
  logging.debug(f"num to mask: {num_to_mask}")
859
  logging.debug(f"probs shape: {probs.shape}")
860
  logging.debug(f"temperature: {temperature}")
 
869
  logging.debug(f"sorted idx shape: {sorted_idx.shape}")
870
 
871
  # get the cut off threshold, given the mask length
872
+ cut_off = torch.take_along_dim(sorted_confidence, num_to_mask, axis=-1)
 
 
873
  logging.debug(f"cut off shape: {cut_off.shape}")
874
 
875
  # mask out the tokens
 
878
 
879
  return mask
880
 
881
+
882
  def typical_filter(
883
+ logits,
884
+ typical_mass: float = 0.95,
885
+ typical_min_tokens: int = 1,
886
+ ):
887
  nb, nt, _ = logits.shape
888
  x_flat = rearrange(logits, "b t l -> (b t ) l")
889
  x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
 
892
 
893
  c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
894
  c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
895
+ x_flat_cumsum = x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
 
 
896
 
897
  last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
898
  sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
 
925
  ).to(device)
926
 
927
  r = torch.zeros(batch_size).to(device)
928
+
929
  z_mask_latent = torch.rand(
930
  batch_size, model.latent_dim * model.n_codebooks, seq_len
931
  ).to(device)
 
934
  pred = z_hat.argmax(dim=1)
935
  pred = model.embedding.unflatten(pred, n_codebooks=model.n_predict_codebooks)
936
 
937
+ print(f"model has {num_params(model) / 1e6:<.3f}M parameters")
938
  print(f"prediction has shape {pred.shape}")
939
  breakpoint()
940
 
941
  args = argbind.parse_args()
942
  with argbind.scope(args):
943
  try_model()
 
 
vampnet/vampnet/scheduler.py CHANGED
@@ -1,8 +1,6 @@
1
- import copy
2
- from typing import List
3
-
4
  import torch
5
 
 
6
  class NoamScheduler:
7
  """OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf
8
  Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html
@@ -44,4 +42,3 @@ class NoamScheduler:
44
 
45
  for p in self.optimizer.param_groups:
46
  p["lr"] = self.lr
47
-
 
 
 
 
1
  import torch
2
 
3
+
4
  class NoamScheduler:
5
  """OG scheduler from transformer paper: https://arxiv.org/pdf/1706.03762.pdf
6
  Implementation from Annotated Transformer: https://nlp.seas.harvard.edu/2018/04/03/attention.html
 
42
 
43
  for p in self.optimizer.param_groups:
44
  p["lr"] = self.lr
 
vampnet/vampnet/util.py CHANGED
@@ -1,43 +1,36 @@
1
- import tqdm
2
-
3
  import torch
 
4
  from einops import rearrange
5
 
 
6
  def scalar_to_batch_tensor(x, batch_size):
7
  return torch.tensor(x).repeat(batch_size)
8
 
9
 
10
- def parallelize(
11
- fn,
12
- *iterables,
13
- parallel: str = "thread_map",
14
- **kwargs
15
- ):
16
  if parallel == "thread_map":
17
  from tqdm.contrib.concurrent import thread_map
18
- return thread_map(
19
- fn,
20
- *iterables,
21
- **kwargs
22
- )
23
  elif parallel == "process_map":
24
  from tqdm.contrib.concurrent import process_map
25
- return process_map(
26
- fn,
27
- *iterables,
28
- **kwargs
29
- )
30
  elif parallel == "single":
31
  return [fn(x) for x in tqdm.tqdm(*iterables)]
32
  else:
33
- raise ValueError(f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}")
34
-
 
 
 
35
  def codebook_flatten(tokens: torch.Tensor):
36
- """
37
  flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)
38
  """
39
  return rearrange(tokens, "b c t -> b (t c)")
40
 
 
41
  def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
42
  """
43
  unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
 
 
 
1
  import torch
2
+ import tqdm
3
  from einops import rearrange
4
 
5
+
6
  def scalar_to_batch_tensor(x, batch_size):
7
  return torch.tensor(x).repeat(batch_size)
8
 
9
 
10
+ def parallelize(fn, *iterables, parallel: str = "thread_map", **kwargs):
 
 
 
 
 
11
  if parallel == "thread_map":
12
  from tqdm.contrib.concurrent import thread_map
13
+
14
+ return thread_map(fn, *iterables, **kwargs)
 
 
 
15
  elif parallel == "process_map":
16
  from tqdm.contrib.concurrent import process_map
17
+
18
+ return process_map(fn, *iterables, **kwargs)
 
 
 
19
  elif parallel == "single":
20
  return [fn(x) for x in tqdm.tqdm(*iterables)]
21
  else:
22
+ raise ValueError(
23
+ f"parallel must be one of 'thread_map', 'process_map', 'single', but got {parallel}"
24
+ )
25
+
26
+
27
  def codebook_flatten(tokens: torch.Tensor):
28
+ """
29
  flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)
30
  """
31
  return rearrange(tokens, "b c t -> b (t c)")
32
 
33
+
34
  def codebook_unflatten(flat_tokens: torch.Tensor, n_c: int = None):
35
  """
36
  unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
wham.egg-info/PKG-INFO ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: wham
3
+ Version: 0.0.1
4
+ Summary: Towards A Translative Model of Sperm Whale Vocalization
5
+ Author: Project CETI
6
+ License: MIT
7
+ Requires-Python: <3.11,>=3.10
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: torch
11
+ Requires-Dist: gradio
12
+ Requires-Dist: argbind>=0.3.2
13
+ Requires-Dist: numpy<1.24
14
+ Requires-Dist: pydantic<3,>=2.0
15
+ Requires-Dist: huggingface_hub
16
+ Requires-Dist: loralib
17
+ Requires-Dist: torch_pitch_shift
18
+ Requires-Dist: soundfile
19
+ Requires-Dist: pydub
20
+ Requires-Dist: tqdm
21
+ Requires-Dist: Cython
22
+ Requires-Dist: pandas
23
+ Requires-Dist: pathlib
24
+ Requires-Dist: ffmpeg-python
25
+ Requires-Dist: scikit-learn
26
+ Requires-Dist: wandb
27
+ Requires-Dist: gdown
28
+ Requires-Dist: transformers
29
+ Requires-Dist: fadtk
30
+ Requires-Dist: urllib3>=2.0.2
31
+ Requires-Dist: plotly
32
+ Requires-Dist: pyharp
33
+ Requires-Dist: ruff
34
+ Requires-Dist: wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat.git
35
+ Requires-Dist: lac @ git+https://github.com/hugofloresgarcia/lac.git
36
+ Requires-Dist: descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
37
+ Dynamic: license-file
38
+
39
+ ---
40
+ title: WhAM
41
+ emoji: 🐋
42
+ colorFrom: blue
43
+ colorTo: indigo
44
+ sdk: docker
45
+ pinned: false
46
+ hardware: a10g-small
47
+ ---
48
+
49
+ # WhAM: a Whale Acoustics Model
50
+ [![arXiv](https://img.shields.io/badge/arXiv-2512.02206-b31b1b.svg)](https://arxiv.org/abs/2512.02206)
51
+ [![Model Weights](https://img.shields.io/badge/Zenodo-Model%20Weights-blue.svg)](https://doi.org/10.5281/zenodo.17633708)
52
+ [![Hugging Face Dataset](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DSWP%20Dataset-yellow)](https://huggingface.co/datasets/orrp/DSWP)
53
+ ![WhAM](assets/inference.png "WhAM")
54
+ WhAM is a transformer-based audio-to-audio model designed to synthesize and analyze sperm whale codas. Based on [VampNet](https://github.com/hugofloresgarcia/vampnet), WhAM uses masked acoustic token modeling to capture temporal and spectral features of whale communication. WhAM generates codas from a given audio context, enabling three core capabilities:
55
+
56
+ - Acoustic Translation: The ability to style-transfer arbitrary audio prompts (e.g., human speech, noise) into the acoustic texture of sperm whale codas.
57
+
58
+ - Synthesizing novel "pseudocodas".
59
+
60
+ - Providing audio embeddings for downstream tasks such as social unit and spectral feature ("vowel") classification.
61
+
62
+ See our [NeurIPS 2025](https://openreview.net/pdf?id=IL1wvzOgqD) publication for more details.
63
+
64
+ ## Installation
65
+
66
+ 1. **Clone the repository:**
67
+ ```bash
68
+ git clone https://github.com/Project-CETI/wham.git
69
+ cd wham
70
+ ```
71
+
72
+ 2. **Set up the environment:**
73
+ ```bash
74
+ conda create -n wham python=3.10
75
+ conda activate wham
76
+ ```
77
+
78
+ 3. **Install dependencies:**
79
+ ```bash
80
+ # Install the wham package
81
+ pip install -e .
82
+
83
+ # Install VampNet
84
+ pip install -e ./vampnet
85
+
86
+ # Install madmom
87
+ pip install --no-build-isolation madmom
88
+
89
+ # Install ffmpeg
90
+ conda install -c conda-forge ffmpeg
91
+ ```
92
+
93
+ 4. **Download model weights:**
94
+ Download the [weights](https://zenodo.org/records/17633708) and extract to `vampnet/models/`.
95
+
96
+ ## Generation
97
+
98
+ To run WhAM locally and prompt it in your browser:
99
+
100
+ ```bash
101
+ python vampnet/app.py --args.load conf/interface.yml --Interface.device cuda
102
+ ```
103
+
104
+ This will provide you with a Gradio link to test WhAM on inputs of your choice.
105
+
106
+ ## Training Data
107
+
108
+ ![Training](assets/training.png "Training")
109
+
110
+ You only need to follow these steps if you want to fine-tune your own version of WhAM. First, obtain the original VampNet weights by following the instructions in the [original repo](https://github.com/hugofloresgarcia/vampnet/tree/ismir-2023). Download
111
+ c2f.pth and codec.pth and replace the weights you previously downloaded in `vampnet/models`.
112
+
113
+ Second, obtain data:
114
+
115
+ 1. **Domain adaptation data:**
116
+
117
+ - Download audio samples from the [WMMS 'Best Of' Cut](https://whoicf2.whoi.edu/science/B/whalesounds/index.cfm). Save them under `vampnet/training_data/domain_adaptation`.
118
+
119
+ - Download audio samples from the [BirdSet Dataset](https://huggingface.co/datasets/DBD-research-group/BirdSet). Save these under the same directory
120
+
121
+ - Finally, download all samples from the [AudioSet Dataset](https://research.google.com/audioset/ontology/index.html) with the label `Animal` and once again save these into the directory
122
+
123
+ 2. **Species-specific finetuning:** Finetuning can be performed on the openly available **[Dominica Sperm Whale Project (DSWP)](https://huggingface.co/datasets/orrp/DSWP)** dataset, hosted on Hugging Face.
124
+
125
+
126
+ With data in hand, navigate into `vampnet` and perform Domain Adaptation:
127
+ ```bash
128
+ python vampnet/scripts/exp/fine_tune.py "training_data/domain_adaptation" domain_adapted && python vampnet/scripts/exp/train.py --args.load conf/generated/domain_adapted/coarse.yml && python vampnet/scripts/exp/train.py --args.load conf/generated/domain_adapted/c2f.yml
129
+ ```
130
+
131
+ Then fine-tune the domain-adapted model. Create the config file with the command:
132
+
133
+ ```bash
134
+ python vampnet/scripts/exp/fine_tune.py "training_data/species_specific_finetuning" fine-tuned
135
+ ```
136
+
137
+ To select which weights you want to use as a checkpoint, change `fine_tune_checkpoint` in `conf/generated/fine-tuned/[c2f/coarse].yml` to `./runs/domain_adaptation/[coarse/c2f]/[checkpoint]/vampnets/weights.pth`. `[checkpoint]` can be `latest` in order to use the last saved checkpoint from the previous run, though it is recommended to manually verify the quality of generations over various checkpoints as overtraining can often cause degradation in audio quality, especially with smaller datasets. After making that change, run the command:
138
+
139
+ ```bash
140
+ python vampnet/scripts/exp/train.py --args.load conf/generated/fine-tuned/coarse.yml && python vampnet/scripts/exp/train.py --args.load conf/generated/fine-tuned/c2f.yml
141
+ ```
142
+
143
+ After following these steps, you should be able to generate audio via the browser by running:
144
+ ```bash
145
+ python app.py --args.load vampnet/conf/generated/fine-tuned/interface.yml
146
+ ```
147
+
148
+ **Note**: The coarse and fine weights can be trained separately if compute allows. In this case, you would call the two scripts:
149
+
150
+ ```bash
151
+ python vampnet/scripts/exp/train.py --args.load conf/generated/[fine-tuned/domain_adapted]/coarse.yml
152
+ ```
153
+
154
+ ```bash
155
+ python vampnet/scripts/exp/train.py --args.load conf/generated/[fine-tuned/domain_adapted]/c2f.yml
156
+ ```
157
+
158
+ After both are finished running, ensure that both resulting weights are copied into the same copy of WhAM.
159
+
160
+
161
+
162
+ ## Testing Data
163
+
164
+ 1. **Marine Mammal Data:**
165
+ Download audio samples from the [WMMS 'Best Of' Cut](https://whoicf2.whoi.edu/science/B/whalesounds/index.cfm). Save them under `data/testing_data/marine_mammals/data/[SPECIES_NAME]`.
166
+ * `[SPECIES_NAME]` must match the species names found in `wham/generation/prompt_configs.py`.
167
+
168
+ 2. **Sperm Whale Codas:**
169
+ To evaluate on sperm whale codas, you can use the openly available [DSWP](https://huggingface.co/datasets/orrp/DSWP) dataset.
170
+
171
+ 3. Generate artificial beeps for experiments: `data/generate_beeps.sh`
172
+
173
+
174
+ ## Reproducing Paper Results
175
+ Note: Access to the annotated DSWP+CETI dataset is required to reproduce all results; as of the time of publication, only part of this data is publicly available. Still, we include the following code as it may be useful for researchers who may benefit from our evaluation pipeline.
176
+
177
+ ### 1. Downstream Classification Tasks
178
+ To reproduce **Table 1** (Classification Accuracies) and **Figure 7** (Ablation Study):
179
+
180
+ **Table 1 Results:**
181
+ ```bash
182
+ cd wham/embedding
183
+ ./downstream_tasks.sh
184
+ ```
185
+ * Runs all downstream classification tasks.
186
+ * **Baselines:** Run once.
187
+ * **Models (AVES, VampNet):** Run over 3 random seeds; reports mean and standard deviation.
188
+
189
+ **Figure 7 Results (Ablation):**
190
+ ```bash
191
+ cd wham/embedding
192
+ ./downstream_ablation.sh
193
+ ```
194
+ * Outputs accuracy scores for ablation variants (averaged across 3 seeds with error bars).
195
+
196
+ ### 2. Generative Metrics
197
+
198
+ **Figure 12: Frechet Audio Distance (FAD) Scores**
199
+ Calculate the distance between WhAM's generated results and real codas:
200
+ ```bash
201
+ # Calculate for all species
202
+ bash wham/generation/eval/calculate_FAD.sh
203
+
204
+ # Calculate for a single species
205
+ bash wham/generation/eval/calculate_FAD.sh [species_name]
206
+ ```
207
+ * *Runtime:* ~3 hours on an NVIDIA A10 GPU.
208
+
209
+ **Figure 3: FAD with Custom/BirdNET Embeddings**
210
+ To compare against other embeddings:
211
+ 1. Convert your `.wav` files to `.npy` embeddings.
212
+ 2. Place raw coda embeddings in: `data/testing_data/coda_embeddings`
213
+ 3. Place comparison embeddings in subfolders within: `data/testing_data/comparison_embeddings`
214
+ 4. Run:
215
+ ```bash
216
+ python wham/generation/eval/calculate_custom_fad.py
217
+ ```
218
+ *For BirdNET embeddings, refer to the [official repo](https://github.com/BirdNET-Team/BirdNET-Analyzer).*
219
+
220
+ **Table 2: Embedding Type Ablation**
221
+ Calculate distances between raw codas, denoised versions, and noise profiles:
222
+ ```bash
223
+ bash wham/generation/eval/FAD_ablation.sh
224
+ ```
225
+ * *Prerequisites:* Ensure `data/testing_data/ablation/noise` and `data/testing_data/ablation/denoised` are populated.
226
+ * *Runtime:* ~1.5 hours on an NVIDIA A10 GPU.
227
+
228
+ **Figure 13: Tokenizer Reconstruction**
229
+ Test the mean squared reconstruction error:
230
+ ```bash
231
+ bash wham/generation/eval/evaluate_tokenizer.sh
232
+ ```
233
+
234
+ ---
235
+
236
+ ## Citation
237
+
238
+ Please use the following citation if you use this code, model or data.
239
+
240
+ ```bibtex
241
+ @inproceedings{wham2025,
242
+ title={Towards A Translative Model of Sperm Whale Vocalization},
243
+ author={Orr Paradise and Pranav Muralikrishnan and Liangyuan Chen and Hugo Flores Garcia and Bryan Pardo and Roee Diamant and David F. Gruber and Shane Gero and Shafi Goldwasser},
244
+ booktitle={Advances in Neural Information Processing Systems 39: Annual Conference
245
+ on Neural Information Processing Systems 2025, NeurIPS 2025, San Diego, CA, USA},
246
+ year={2025}
247
+ }
248
+ ```
wham.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ vampnet/app.py
5
+ vampnet/setup.py
6
+ vampnet/scripts/exp/eval.py
7
+ vampnet/scripts/exp/experiment.py
8
+ vampnet/scripts/exp/fine_tune.py
9
+ vampnet/scripts/exp/train.py
10
+ vampnet/scripts/utils/plots.py
11
+ vampnet/scripts/utils/remove_quiet_files.py
12
+ vampnet/scripts/utils/split.py
13
+ vampnet/scripts/utils/split_long_audio_file.py
14
+ vampnet/scripts/utils/stage.py
15
+ vampnet/scripts/utils/visualize_embeddings.py
16
+ vampnet/scripts/utils/xeno-canto-dl.py
17
+ vampnet/scripts/utils/data/augment.py
18
+ vampnet/scripts/utils/data/maestro-reorg.py
19
+ vampnet/vampnet/__init__.py
20
+ vampnet/vampnet/beats.py
21
+ vampnet/vampnet/interface.py
22
+ vampnet/vampnet/mask.py
23
+ vampnet/vampnet/scheduler.py
24
+ vampnet/vampnet/util.py
25
+ vampnet/vampnet/modules/__init__.py
26
+ vampnet/vampnet/modules/activations.py
27
+ vampnet/vampnet/modules/layers.py
28
+ vampnet/vampnet/modules/transformer.py
29
+ wham.egg-info/PKG-INFO
30
+ wham.egg-info/SOURCES.txt
31
+ wham.egg-info/dependency_links.txt
32
+ wham.egg-info/requires.txt
33
+ wham.egg-info/top_level.txt
wham.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
wham.egg-info/requires.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ argbind>=0.3.2
4
+ numpy<1.24
5
+ pydantic<3,>=2.0
6
+ huggingface_hub
7
+ loralib
8
+ torch_pitch_shift
9
+ soundfile
10
+ pydub
11
+ tqdm
12
+ Cython
13
+ pandas
14
+ pathlib
15
+ ffmpeg-python
16
+ scikit-learn
17
+ wandb
18
+ gdown
19
+ transformers
20
+ fadtk
21
+ urllib3>=2.0.2
22
+ plotly
23
+ pyharp
24
+ ruff
25
+ wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat.git
26
+ lac @ git+https://github.com/hugofloresgarcia/lac.git
27
+ descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
wham.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ vampnet