|
|
from pathlib import Path |
|
|
import random |
|
|
from typing import List |
|
|
import tempfile |
|
|
import subprocess |
|
|
|
|
|
import argbind |
|
|
from tqdm import tqdm |
|
|
import torch |
|
|
|
|
|
from vampnet.interface import Interface |
|
|
from vampnet import mask as pmask |
|
|
import audiotools as at |
|
|
|
|
|
# Rebind ``Interface`` to the object returned by ``argbind.bind`` so that the
# ``Interface()`` call in ``main`` goes through argbind (presumably letting the
# constructor arguments come from the parsed args -- confirm against argbind).
Interface: Interface = argbind.bind(Interface)
|
|
|
|
|
|
|
|
|
|
|
def calculate_bitrate(
    interface, num_codebooks,
    downsample_factor,
    bit_width=10,
):
    """Estimate the bitrate of a stream of codec tokens.

    Args:
        interface: object exposing ``codec.sample_rate`` and ``codec.hop_size``.
        num_codebooks: number of codebooks counted per token frame.
        downsample_factor: divisor applied to the per-frame bit count.
        bit_width: bits counted for each code. Defaults to 10, the value that
            was previously hard-coded.

    Returns:
        ``(sample_rate / hop_size) * (bit_width * num_codebooks / downsample_factor)``
        -- bits per second, assuming ``sample_rate`` is in Hz and ``hop_size``
        is in samples (TODO confirm against the codec).

    Raises:
        ZeroDivisionError: if ``hop_size`` or ``downsample_factor`` is zero.
    """
    codec = interface.codec
    frames_per_second = codec.sample_rate / codec.hop_size
    bits_per_frame = (bit_width * num_codebooks) / downsample_factor
    return frames_per_second * bits_per_frame
|
|
|
|
|
def baseline(sig, interface):
    """Reference condition: return ``sig`` after ``interface.preprocess`` only."""
    preprocessed = interface.preprocess(sig)
    return preprocessed
|
|
|
|
|
def reconstructed(sig, interface):
    """Round-trip ``sig`` through the interface: encode it, then decode the result."""
    encoded = interface.encode(sig)
    return interface.decode(encoded)
|
|
|
|
|
def coarse2fine(sig, interface):
    """Encode ``sig``, keep only the conditioning slice, and regenerate the rest.

    The encoded tokens are truncated along their second axis to the first
    ``interface.c2f.n_conditioning_codebooks`` entries (presumably the codebook
    axis -- confirm), passed through ``interface.coarse_to_fine`` and decoded.
    """
    tokens = interface.encode(sig)
    n_cond = interface.c2f.n_conditioning_codebooks
    conditioning = tokens[:, :n_cond, :]

    return interface.decode(interface.coarse_to_fine(conditioning))
|
|
|
|
|
class CoarseCond:
    """Callable sampling condition parameterised by two mask settings.

    Instances are called as ``cond(sig, interface)`` and return the decoded
    output of a coarse generation pass followed by coarse-to-fine.
    """

    def __init__(self, num_conditioning_codebooks, downsample_factor):
        # Second argument handed to ``pmask.codebook_unmask``.
        self.num_conditioning_codebooks = num_conditioning_codebooks
        # Second argument handed to ``pmask.periodic_mask``.
        self.downsample_factor = downsample_factor

    def __call__(self, sig, interface):
        """Encode ``sig``, build the mask, regenerate the tokens and decode them."""
        tokens = interface.encode(sig)

        # The mask is built in three stages: start from ``full_mask``, apply
        # ``codebook_unmask``, then apply ``periodic_mask``.
        mask = pmask.periodic_mask(
            pmask.codebook_unmask(
                pmask.full_mask(tokens),
                self.num_conditioning_codebooks,
            ),
            self.downsample_factor,
        )

        generated = interface.coarse_vamp(tokens, mask)
        refined = interface.coarse_to_fine(generated)
        return interface.decode(refined)
|
|
|
|
|
def opus(sig, interface, bitrate=128):
    """Round-trip ``sig`` through the Opus codec using the ``ffmpeg`` CLI.

    The signal is preprocessed, written to a WAV file, encoded with
    ``libopus``, decoded back to WAV and reloaded as an ``at.AudioSignal``.

    Args:
        sig: signal to process; must be accepted by ``interface.preprocess``.
        interface: object providing ``preprocess``.
        bitrate: value passed verbatim to ffmpeg's ``-b:a`` option.
            NOTE(review): ffmpeg reads a bare number as bits per second, so the
            default of 128 may have been meant as ``"128k"`` -- confirm.

    Returns:
        The decoded audio as an ``at.AudioSignal``.

    Raises:
        subprocess.CalledProcessError: if either ffmpeg invocation exits non-zero.
    """
    sig = interface.preprocess(sig)

    # Keep every intermediate file in one scratch directory. Previously the
    # decoded file's path, ``Path(f"{name}-opus").with_suffix(".wav")``,
    # collapsed back onto the input temp file (the suffix of "x.wav-opus" is
    # ".wav-opus"), and the ".opus" file was never removed.
    with tempfile.TemporaryDirectory() as scratch:
        scratch_dir = Path(scratch)
        wav_in = scratch_dir / "input.wav"
        opus_name = scratch_dir / "encoded.opus"
        output_name = scratch_dir / "decoded.wav"

        sig.write(str(wav_in))

        # wav -> opus
        encode_cmd = [
            "ffmpeg", "-y", "-i", str(wav_in),
            "-c:a", "libopus",
            "-b:a", f"{bitrate}",
            str(opus_name),
        ]
        subprocess.run(encode_cmd, check=True)

        # opus -> wav
        decode_cmd = [
            "ffmpeg", "-y", "-i", str(opus_name),
            str(output_name),
        ]
        subprocess.run(decode_cmd, check=True)

        # Load before the scratch directory (and the file) is deleted.
        decoded = at.AudioSignal(
            output_name,
            sample_rate=sig.sample_rate,
        )

    return decoded
|
|
|
|
|
def mask_ratio_1_step(ratio=1.0):
    """Build a condition that regenerates tokens in a single sampling step.

    The returned callable takes ``(sig, interface)``: it encodes ``sig``, builds
    a mask with ``pmask.linear_random(z, ratio)``, runs ``interface.coarse_vamp``
    with ``sampling_steps=1`` and decodes the result. No coarse-to-fine pass is
    applied before decoding.
    """
    def wrapper(sig, interface):
        tokens = interface.encode(sig)
        generated = interface.coarse_vamp(
            tokens,
            pmask.linear_random(tokens, ratio),
            sampling_steps=1,
        )
        return interface.decode(generated)

    return wrapper
|
|
|
|
|
def num_sampling_steps(num_steps=1):
    """Build a condition that samples with a fixed number of steps.

    The returned callable takes ``(sig, interface)``: it encodes ``sig``, builds
    a mask with ``pmask.periodic_mask(z, 16)``, runs ``interface.coarse_vamp``
    with ``sampling_steps=num_steps``, then applies coarse-to-fine and decodes.
    """
    def wrapper(sig, interface: Interface):
        tokens = interface.encode(sig)
        generated = interface.coarse_vamp(
            tokens,
            pmask.periodic_mask(tokens, 16),
            sampling_steps=num_steps,
        )
        refined = interface.coarse_to_fine(generated)
        return interface.decode(refined)

    return wrapper
|
|
|
|
|
def beat_mask(ctx_time):
    """Build a condition whose mask comes from ``interface.make_beat_mask``.

    ``ctx_time`` is split evenly: half is passed as ``before_beat_s`` and half
    as ``after_beat_s``, with ``invert=True``. The returned callable takes
    ``(sig, interface)`` and returns the decoded output of ``coarse_vamp``
    followed by coarse-to-fine.
    """
    def wrapper(sig, interface):
        half_ctx = ctx_time / 2
        # The mask is computed from the signal before encoding, as in the
        # original statement order.
        mask = interface.make_beat_mask(
            sig,
            before_beat_s=half_ctx,
            after_beat_s=half_ctx,
            invert=True,
        )

        tokens = interface.encode(sig)
        generated = interface.coarse_vamp(tokens, mask)
        refined = interface.coarse_to_fine(generated)
        return interface.decode(refined)

    return wrapper
|
|
|
|
|
def inpaint(ctx_time):
    """Build a condition that uses ``pmask.inpaint`` with symmetric context.

    ``interface.s2t(ctx_time)`` is passed as both the second and third argument
    of ``pmask.inpaint`` (presumably the context kept on each side of the
    inpainted region -- confirm). The returned callable takes
    ``(sig, interface)`` and returns the decoded output of ``coarse_vamp``
    followed by coarse-to-fine.
    """
    def wrapper(sig, interface: Interface):
        tokens = interface.encode(sig)
        mask = pmask.inpaint(
            tokens,
            interface.s2t(ctx_time),
            interface.s2t(ctx_time),
        )

        generated = interface.coarse_vamp(tokens, mask)
        refined = interface.coarse_to_fine(generated)
        return interface.decode(refined)

    return wrapper
|
|
|
|
|
def token_noise(noise_amt):
    """Build a condition that replaces a random subset of tokens with noise.

    The returned callable takes ``(sig, interface)``: it encodes ``sig``, draws
    a mask with ``pmask.random(z, noise_amt)``, swaps the selected positions for
    integers drawn uniformly from ``[0, interface.coarse.vocab_size)`` and
    decodes the result. No generation pass is run.
    """
    def wrapper(sig, interface: Interface):
        z = interface.encode(sig)
        mask = pmask.random(z, noise_amt)

        # ``torch.where`` requires a boolean condition. The dtype returned by
        # ``pmask.random`` is not visible here (it may be an integer 0/1 mask),
        # so cast explicitly; this is a no-op if the mask is already bool.
        noisy = torch.where(
            mask.bool(),
            torch.randint_like(z, 0, interface.coarse.vocab_size),
            z,
        )
        return interface.decode(noisy)

    return wrapper
|
|
|
|
|
# Registry of experiments: maps an experiment type (selected in ``main`` via
# ``exp_type``) to a dict of {condition name: callable(sig, interface)}.
# Each condition name is also used by ``main`` as an output sub-directory.
EXP_REGISTRY = {
    "gen-compression": {
        "baseline": baseline,
        "reconstructed": reconstructed,
        "coarse2fine": coarse2fine,
        **{
            f"{n}_codebooks_downsampled_{x}x": CoarseCond(
                num_conditioning_codebooks=n,
                downsample_factor=x,
            )
            for n, x in [(1, 1), (4, 4), (4, 16), (4, 32)]
        },
        # NOTE: these entries are keyed "token_noise_*" but are built with
        # ``mask_ratio_1_step``, not ``token_noise``.
        **{
            f"token_noise_{x}": mask_ratio_1_step(ratio=x)
            for x in (0.25, 0.5, 0.75)
        },
    },
    "sampling-steps": {
        f"steps_{n}": num_sampling_steps(n)
        for n in (1, 4, 12, 36, 64, 72)
    },
    "musical-sampling": {
        **{f"beat_mask_{t}": beat_mask(t) for t in (0.075,)},
        **{f"inpaint_{t}": inpaint(t) for t in (0.5, 1.0)},
    },
}
|
|
|
|
|
@argbind.bind(without_prefix=True)
def main(
    sources=[
        "/media/CHONK/hugo/spotdl/val",
    ],
    output_dir: str = "./samples",
    max_excerpts: int = 2000,
    exp_type: str = "gen-compression",
    seed: int = 0,
    ext: List[str] = [".mp3"],
):
    """Render every condition of one experiment for a set of audio excerpts.

    For each excerpt index ``i`` (visited in shuffled order), every condition
    registered under ``EXP_REGISTRY[exp_type]`` is applied and its result is
    written to ``<output_dir>/<condition name>/<i>.wav``. An excerpt is skipped
    when all of its output files already exist, so an interrupted run can be
    restarted.

    Args:
        sources: passed to ``AudioLoader`` as ``sources``.
        output_dir: root directory for outputs; created if missing.
        max_excerpts: number of excerpts (``n_examples`` of the dataset).
        exp_type: key into ``EXP_REGISTRY``.
        seed: passed to ``at.util.seed`` and used as the loader's
            ``shuffle_state``.
        ext: list of file extensions passed to ``AudioLoader``. Annotated as
            ``List[str]`` to match its list default; the previous ``str``
            annotation contradicted it.

    Raises:
        ValueError: if ``exp_type`` is not a key of ``EXP_REGISTRY``.
    """
    # Validate up front so a mistyped experiment name fails before the model
    # is loaded and before any directory is created.
    if exp_type not in EXP_REGISTRY:
        raise ValueError(f"Unknown exp_type {exp_type}")
    sample_conds = EXP_REGISTRY[exp_type]

    at.util.seed(seed)
    interface = Interface()

    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    from audiotools.data.datasets import AudioLoader, AudioDataset

    loader = AudioLoader(sources=sources, shuffle_state=seed, ext=ext)
    dataset = AudioDataset(
        loader,
        sample_rate=interface.codec.sample_rate,
        duration=interface.coarse.chunk_size_s,
        n_examples=max_excerpts,
        without_replacement=True,
    )

    indices = list(range(max_excerpts))
    random.shuffle(indices)
    for i in tqdm(indices):
        filename = f"{i}.wav"

        # Skip excerpts whose outputs exist for every condition.
        if all((output_dir / name / filename).exists() for name in sample_conds):
            continue

        sig = dataset[i]["signal"]
        results = {
            name: cond(sig, interface).cpu()
            for name, cond in sample_conds.items()
        }

        for name, result in results.items():
            o_dir = output_dir / name
            o_dir.mkdir(exist_ok=True, parents=True)
            result.write(o_dir / filename)
|
|
|
|
|
# Script entry point: parse the command line with argbind, then run ``main``
# inside the resulting scope so the bound functions see the parsed arguments.
if __name__ == "__main__":
    args = argbind.parse_args()

    with argbind.scope(args):
        main()
|
|
|