|
|
""" |
|
|
TODO: train a linear probe |
|
|
usage: |
|
|
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_gtzan /path/to/gtzan/genres_original --output_dir /path/to/output |
|
|
""" |
|
|
from pathlib import Path |
|
|
from typing import List |
|
|
|
|
|
import audiotools as at |
|
|
from audiotools import AudioSignal |
|
|
import argbind |
|
|
import torch |
|
|
import numpy as np |
|
|
import zipfile |
|
|
import json |
|
|
|
|
|
from vampnet.interface import Interface |
|
|
import tqdm |
|
|
|
|
|
|
|
|
# Bind Interface so its constructor arguments (e.g. --Interface.device)
# are exposed on the command line / yml config via argbind.
Interface = argbind.bind(Interface)

# When True, print a message on every cache hit while embedding.
DEBUG = False
|
|
|
|
|
def smart_plotly_export(fig, save_path):
    """Export a plotly figure according to the extension of ``save_path``.

    parameters:
        fig: a plotly figure
        save_path (str): destination path; the extension picks the export:
            'html'  -> write an interactive html file
            'bytes' -> return the figure as png bytes (nothing written)
            'numpy' -> return the rendered figure as an np.ndarray
            'jpeg' / 'png' / 'webp' -> write a static image file

    returns:
        bytes for 'bytes', np.ndarray for 'numpy', otherwise None

    raises:
        ValueError: for any other extension
    """
    img_format = save_path.split('.')[-1]
    if img_format == 'html':
        fig.write_html(save_path)
    elif img_format == 'bytes':
        return fig.to_image(format='png')

    elif img_format == 'numpy':
        import io
        from PIL import Image

        def plotly_fig2array(fig):
            # render to png in memory, then decode into an array
            fig_bytes = fig.to_image(format="png", width=1200, height=700)
            buf = io.BytesIO(fig_bytes)
            img = Image.open(buf)
            return np.asarray(img)

        return plotly_fig2array(fig)
    elif img_format in ('jpeg', 'png', 'webp'):
        # fix: the original condition `img_format == 'jpeg' or 'png' or 'webp'`
        # was always truthy, so every unknown extension silently fell into
        # write_image and the ValueError below was unreachable
        fig.write_image(save_path)
    else:
        raise ValueError("invalid image format")
|
|
|
|
|
def dim_reduce(emb, labels, save_path, n_components=3, method='tsne', title=''):
    """
    dimensionality reduction for visualization!
    saves an html plotly figure to save_path
    parameters:
        emb (np.ndarray): the samples to be reduced with shape (samples, features)
        labels (list): list of labels for embedding
        save_path (str): path where u wanna save ur figure
        n_components (int): 2 or 3, the dimensionality of the projection
        method (str): umap, tsne, or pca
        title (str): title for ur figure
    returns:
        whatever smart_plotly_export returns for save_path's extension
    """
    import pandas as pd
    import plotly.express as px
    if method == 'umap':
        from umap import UMAP
        # fix: the original called `umap.UMAP(...)` but only the UMAP class
        # is imported here, so the umap branch raised NameError
        reducer = UMAP(n_components=n_components)
    elif method == 'tsne':
        from sklearn.manifold import TSNE
        reducer = TSNE(n_components=n_components)
    elif method == 'pca':
        from sklearn.decomposition import PCA
        reducer = PCA(n_components=n_components)
    else:
        # fix: give the bare ValueError a useful message
        raise ValueError(f"unknown dim reduction method: {method}")

    proj = reducer.fit_transform(emb)

    if n_components == 2:
        df = pd.DataFrame(dict(
            x=proj[:, 0],
            y=proj[:, 1],
            instrument=labels
        ))
        fig = px.scatter(df, x='x', y='y', color='instrument',
                         title=title+f"_{method}")

    elif n_components == 3:
        df = pd.DataFrame(dict(
            x=proj[:, 0],
            y=proj[:, 1],
            z=proj[:, 2],
            instrument=labels
        ))
        # NOTE(review): the 3-D title omits the `_{method}` suffix used in the
        # 2-D branch — kept as-is to preserve existing output file titles
        fig = px.scatter_3d(df, x='x', y='y', z='z',
                            color='instrument',
                            title=title)
    else:
        raise ValueError("cant plot more than 3 components")

    fig.update_traces(marker=dict(size=6,
                                  line=dict(width=1,
                                            color='DarkSlateGrey')),
                      selector=dict(mode='markers'))

    return smart_plotly_export(fig, save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
    """Compute per-layer VampNet activations for a signal, mean-pooled over time.

    parameters:
        sig (AudioSignal): the signal to embed
        interface (Interface): vampnet interface providing the codec and the
            coarse transformer
        layer (int): validated to be a legal layer index; selecting the layer
            from the returned stack is left to the caller

    returns:
        torch.Tensor: activations with shape (num_layers, embedding_dim)
    """
    with torch.inference_mode():
        # preprocess (resample, etc.) the signal for the codec
        sig = interface.preprocess(sig)

        vampnet = interface.coarse

        # encode to discrete codes, keep only the codebooks the coarse model
        # uses, then map codes to continuous latents
        z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
        z_latents = vampnet.embedding.from_codes(z, interface.codec)

        # forward pass returning per-layer activations
        # assumes embeddings is (num_layers, batch, seq_len, dim) — TODO confirm
        _z, embeddings = vampnet(z_latents, return_activations=True)

        # fix: the original message printed shape[0] while the assert checks shape[1]
        assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[1]}"
        embeddings = embeddings.squeeze(1)

        num_layers = embeddings.shape[0]
        assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"

        # mean-pool over the time axis -> (num_layers, dim)
        embeddings = embeddings.mean(dim=-2)

    return embeddings
|
|
|
|
|
from dataclasses import dataclass, fields |
|
|
@dataclass
class Embedding:
    # genre label of the source track
    genre: str
    # original audio file name
    filename: str
    # per-layer embedding array produced by vampnet_embed
    embedding: np.ndarray

    def save(self, path):
        """Serialize this Embedding to ``path`` as a zip archive.

        The numpy array is stored as ``embedding.npy``; every other
        dataclass field is stored as JSON in ``data.json``.
        """
        metadata = {
            field.name: getattr(self, field.name)
            for field in fields(self)
            if field.name != 'embedding'
        }
        with zipfile.ZipFile(path, 'w') as archive:
            with archive.open('embedding.npy', 'w') as npy_file:
                np.save(npy_file, self.embedding)
            with archive.open('data.json', 'w') as json_file:
                json_file.write(json.dumps(metadata).encode('utf-8'))

    @classmethod
    def load(cls, path):
        """Reconstruct an Embedding previously written by :meth:`save`."""
        with zipfile.ZipFile(path, 'r') as archive:
            with archive.open('embedding.npy') as npy_file:
                array = np.load(npy_file)
            with archive.open('data.json') as json_file:
                metadata = json.loads(json_file.read().decode('utf-8'))
        return cls(embedding=array, **metadata)
|
|
|
|
|
|
|
|
@argbind.bind(without_prefix=True)
def main(
    path_to_gtzan: str = None,
    cache_dir: str = "./.gtzan_emb_cache",
    output_dir: str = "./gtzan_vampnet_embeddings",
    layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
):
    """Embed every GTZAN track with VampNet, then save a t-SNE figure per layer.

    parameters:
        path_to_gtzan (str): root of the GTZAN dataset; one subdirectory per
            genre, each containing audio files (required; must exist)
        cache_dir (str): directory where per-track Embedding zips are cached
        output_dir (str): directory where the html figures are written
        layers (List[int]): layer indices to visualize; each must be a valid
            index into the activation stack returned by vampnet_embed
    """
    path_to_gtzan = Path(path_to_gtzan)
    assert path_to_gtzan.exists(), f"{path_to_gtzan} does not exist"

    cache_dir = Path(cache_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # Interface is argbind-bound at module level, so its constructor
    # arguments come from the CLI / yml config.
    interface = Interface()

    # GTZAN layout: one directory per genre.
    genres = [Path(x).name for x in path_to_gtzan.iterdir() if x.is_dir()]
    print(f"Found {len(genres)} genres")
    print(f"genres: {genres}")

    # Collect one Embedding per audio file, loading from cache when possible.
    data = []
    for genre in genres:
        audio_files = list(at.util.find_audio(path_to_gtzan / genre))
        print(f"Found {len(audio_files)} audio files for genre {genre}")

        for audio_file in tqdm.tqdm(audio_files, desc=f"embedding genre {genre}"):
            # cache key combines genre and file stem to avoid collisions
            cached_path = (cache_dir / f"{genre}_{audio_file.stem}.emb")
            if cached_path.exists():
                if DEBUG:
                    print(f"loading cached embedding for {cached_path.stem}")
                embedding = Embedding.load(cached_path)
            else:
                try:
                    sig = AudioSignal(audio_file)
                except Exception as e:
                    # skip unreadable files rather than aborting the whole run
                    print(f"failed to load {audio_file.name} with error {e}")
                    print(f"skipping {audio_file.name}")
                    continue

                # (num_layers, dim) activations, mean-pooled over time
                emb = vampnet_embed(sig, interface).cpu().numpy()

                embedding = Embedding(
                    genre=genre,
                    filename=audio_file.name,
                    embedding=emb
                )

                # write to the cache for future runs
                cached_path.parent.mkdir(exist_ok=True, parents=True)
                embedding.save(cached_path)
            data.append(embedding)

    embeddings = [d.embedding for d in data]
    labels = [d.genre for d in data]

    # stack to (num_files, num_layers, dim)
    embeddings = np.stack(embeddings)

    # one 2-D t-SNE figure per requested layer
    for layer in tqdm.tqdm(layers, desc="dim reduction"):
        dim_reduce(
            embeddings[:, layer, :], labels,
            save_path=str(output_dir / f'vampnet-gtzan-layer={layer}.html'),
            n_components=2, method='tsne',
            title=f'vampnet-gtzan-layer={layer}'
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
args = argbind.parse_args() |
|
|
with argbind.scope(args): |
|
|
main() |