|
|
""" |
|
|
TODO: train a linear probe |
|
|
usage: |
|
|
python gtzan_embeddings.py --args.load conf/interface.yml --Interface.device cuda --path_to_audio /path/to/audio/labels --output_dir /path/to/output |
|
|
""" |
|
|
from pathlib import Path |
|
|
from typing import List |
|
|
|
|
|
import audiotools as at |
|
|
from audiotools import AudioSignal |
|
|
import argbind |
|
|
import torch |
|
|
import numpy as np |
|
|
import zipfile |
|
|
import json |
|
|
|
|
|
from vampnet.interface import Interface |
|
|
import tqdm |
|
|
|
|
|
|
|
|
# Re-bind Interface through argbind.bind; the usage string in the module
# docstring passes options such as `--Interface.device` on the command line.
Interface = argbind.bind(Interface)

# When True, main() prints a message each time it loads a cached embedding
# instead of recomputing it.
DEBUG = False
|
|
|
|
|
|
|
|
def smart_plotly_export(fig, save_path: Path):
    """Export a plotly figure in the format implied by ``save_path``'s suffix.

    The suffix (case-insensitive, without the leading dot) selects what happens:

    * ``html``: the figure is written to ``save_path`` with ``fig.write_html``.
    * ``bytes``: nothing is written; PNG bytes from ``fig.to_image`` are returned.
    * ``numpy``: nothing is written; the figure is rendered to a 1200x700 PNG
      and returned as a numpy array (requires Pillow).
    * ``jpg``/``jpeg``/``png``/``webp``/``svg``/``pdf``: the figure is written
      to ``save_path`` with ``fig.write_image``.

    Parameters:
        fig: object exposing ``write_html``, ``write_image`` and ``to_image``
            (a plotly figure in this script).
        save_path (Path or str): destination path; only its suffix is used for
            the ``bytes`` and ``numpy`` pseudo-formats.

    Returns:
        ``bytes`` for the ``bytes`` format, ``np.ndarray`` for ``numpy``,
        otherwise ``None``.

    Raises:
        ValueError: if the suffix is not one of the formats listed above.
    """
    save_path = Path(save_path)
    img_format = save_path.suffix[1:].lower()

    if img_format == "html":
        fig.write_html(save_path)
        return None

    if img_format == "bytes":
        return fig.to_image(format="png")

    if img_format == "numpy":
        # Imported lazily so Pillow is only needed for this branch.
        import io
        from PIL import Image

        fig_bytes = fig.to_image(format="png", width=1200, height=700)
        return np.asarray(Image.open(io.BytesIO(fig_bytes)))

    # The original condition (`== 'jpeg' or 'png' or 'webp'`) was always truthy,
    # which made the ValueError below unreachable. Use a real membership test.
    if img_format in ("jpg", "jpeg", "png", "webp", "svg", "pdf"):
        fig.write_image(save_path)
        return None

    raise ValueError("invalid image format")
|
|
|
|
|
|
|
|
def dim_reduce(annotated_embeddings, layer, output_dir, n_components=3, method="tsne"):
    """Project one layer's embeddings to 2 or 3 dimensions and export a plot.

    The figure is saved as ``<output_dir>/vampnet-embeddings-layer=<layer>.html``
    and points are coloured by label, with the filename as hover text.

    Parameters:
        annotated_embeddings (list): items exposing ``label``, ``filename`` and
            ``embedding``. The embeddings are stacked and indexed as
            ``[:, layer, :]``, so each one must be a 2-D array whose first axis
            is indexed by ``layer``.
        layer (int): index along the first axis of each embedding to visualise.
        output_dir (Path or str): directory the HTML figure is written into.
        n_components (int): number of projected dimensions; must be 2 or 3.
        method (str): ``"umap"``, ``"tsne"`` or ``"pca"``.

    Returns:
        Whatever ``smart_plotly_export`` returns for the ``.html`` save path.

    Raises:
        ValueError: if ``method`` is unknown or ``n_components`` is not 2 or 3.
    """
    import pandas as pd
    import plotly.express as px

    fig_name = f"vampnet-embeddings-layer={layer}"
    fig_title = f"{fig_name}_{method}"
    save_path = (Path(output_dir) / fig_name).with_suffix(".html")

    if method == "umap":
        from umap import UMAP

        # Only the class is imported; the original `umap.UMAP(...)` raised NameError.
        reducer = UMAP(n_components=n_components)
    elif method == "tsne":
        from sklearn.manifold import TSNE

        reducer = TSNE(n_components=n_components)
    elif method == "pca":
        from sklearn.decomposition import PCA

        reducer = PCA(n_components=n_components)
    else:
        raise ValueError(f"invalid method: {method}")

    # Reject unplottable dimensionalities before the (potentially slow) fit,
    # rather than after it or via an IndexError when building the dataframe.
    if n_components not in (2, 3):
        raise ValueError(f"can't plot {n_components} components")

    labels = [emb.label for emb in annotated_embeddings]
    names = [emb.filename for emb in annotated_embeddings]
    embs = [emb.embedding for emb in annotated_embeddings]

    # Select the requested layer for every sample, then reduce.
    embs_at_layer = np.stack(embs)[:, layer, :]
    projs = reducer.fit_transform(embs_at_layer)

    df = pd.DataFrame(
        {
            "label": labels,
            "name": names,
            "x": projs[:, 0],
            "y": projs[:, 1],
        }
    )

    if n_components == 2:
        fig = px.scatter(
            df, x="x", y="y", color="label", hover_name="name", title=fig_title,
        )
    else:
        df["z"] = projs[:, 2]
        fig = px.scatter_3d(
            df, x="x", y="y", z="z", color="label", hover_name="name", title=fig_title
        )

    fig.update_traces(
        marker=dict(size=6, line=dict(width=1, color="DarkSlateGrey")),
        selector=dict(mode="markers"),
    )

    return smart_plotly_export(fig, save_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def vampnet_embed(sig: AudioSignal, interface: Interface, layer=10):
    """Return per-layer activations of the coarse model for one signal.

    The signal is preprocessed and encoded by ``interface``, the codes are
    truncated to the coarse model's ``n_codebooks``, converted to latents and
    passed through ``interface.coarse`` with ``return_activations=True``. The
    singleton axis 1 is squeezed out and the result is averaged over its
    second-to-last axis.

    Parameters:
        sig (AudioSignal): input audio.
        interface (Interface): provides ``preprocess``, ``encode``, ``codec``
            and the ``coarse`` model.
        layer (int): only used to check that this index is smaller than the
            number of layers returned; all layers are returned regardless.

    Returns:
        torch.Tensor: activations with axis 0 indexed by layer.
        NOTE(review): presumably shaped (layers, features) after pooling over
        time -- confirm against the model's activation layout.

    Raises:
        AssertionError: if axis 1 of the activations is not of size 1, or if
            ``layer`` is not smaller than the number of layers.
    """
    with torch.inference_mode():
        sig = interface.preprocess(sig)

        vampnet = interface.coarse

        # Keep only the codebooks the coarse model consumes.
        z = interface.encode(sig)[:, :vampnet.n_codebooks, :]
        z_latents = vampnet.embedding.from_codes(z, interface.codec)

        _z, embeddings = vampnet(z_latents, return_activations=True)

        # Axis 1 is treated as the batch axis (axis 0 is the layer axis), so
        # the message must report shape[1], not shape[0] as it did before.
        assert embeddings.shape[1] == 1, f"expected batch dim to be 1, got {embeddings.shape[1]}"
        embeddings = embeddings.squeeze(1)

        num_layers = embeddings.shape[0]
        assert layer < num_layers, f"layer {layer} is out of bounds for model with {num_layers} layers"

        # Average over the second-to-last axis (assumed to be time -- TODO confirm).
        embeddings = embeddings.mean(dim=-2)

        return embeddings
|
|
|
|
|
from dataclasses import dataclass, fields |
|
|
@dataclass
class AnnotatedEmbedding:
    """An embedding array bundled with its label and source filename.

    Instances are persisted as a zip archive holding two members:
    ``embedding.npy`` (the array) and ``data.json`` (every other field).
    """

    label: str
    filename: str
    embedding: np.ndarray

    def save(self, path):
        """Write this object to ``path`` as a zip archive."""
        with zipfile.ZipFile(path, 'w') as archive:
            # The array goes first, stored in numpy's own .npy format.
            with archive.open('embedding.npy', 'w') as npy_member:
                np.save(npy_member, self.embedding)

            # All remaining dataclass fields are stored as UTF-8 JSON.
            metadata = {}
            for field in fields(self):
                if field.name == 'embedding':
                    continue
                metadata[field.name] = getattr(self, field.name)

            with archive.open('data.json', 'w') as json_member:
                json_member.write(json.dumps(metadata).encode('utf-8'))

    @classmethod
    def load(cls, path):
        """Rebuild an object from a zip archive previously written by ``save``."""
        with zipfile.ZipFile(path, 'r') as archive:
            with archive.open('embedding.npy') as npy_member:
                array = np.load(npy_member)

            raw_json = archive.read('data.json')

        metadata = json.loads(raw_json.decode('utf-8'))
        return cls(embedding=array, **metadata)
|
|
|
|
|
|
|
|
@argbind.bind(without_prefix=True)
def main(
    path_to_audio: str = None,
    cache_dir: str = "./.emb_cache",
    output_dir: str = "./vampnet_embeddings",
    layers: List[int] = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19],
    method: str = "tsne",
    n_components: int = 2,
):
    """Embed every audio file under a labelled directory tree and plot the result.

    Each immediate sub-directory of ``path_to_audio`` is treated as a label.
    Embeddings are cached per file in ``cache_dir`` and one figure per entry of
    ``layers`` is written to ``output_dir``.

    Parameters
    ----------
    path_to_audio : str
        Directory whose sub-directories are named after labels and contain
        audio files. Required.
    cache_dir : str
        Directory where per-file embeddings are cached as ``.emb`` archives.
    output_dir : str
        Directory the figures are written to; created if missing.
    layers : List[int]
        Layer indices to visualise. The default list is never mutated here.
    method : str
        Dimensionality-reduction method passed to ``dim_reduce``.
    n_components : int
        Number of projected dimensions passed to ``dim_reduce``.

    Raises
    ------
    ValueError
        If ``path_to_audio`` is not given, or if no embedding could be
        produced while at least one layer was requested.
    AssertionError
        If ``path_to_audio`` does not exist.
    """
    # The default is None, which previously fell into Path(None) and produced
    # an opaque TypeError; report the missing argument explicitly instead.
    if path_to_audio is None:
        raise ValueError("path_to_audio is required (pass --path_to_audio)")

    path_to_audio = Path(path_to_audio)
    assert path_to_audio.exists(), f"{path_to_audio} does not exist"

    cache_dir = Path(cache_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    interface = Interface()

    # Every immediate sub-directory name is a class label.
    labels = [Path(x).name for x in path_to_audio.iterdir() if x.is_dir()]
    print(f"Found {len(labels)} labels")
    print(f"labels: {labels}")

    annotated_embeddings = []
    for label in labels:
        audio_files = list(at.util.find_audio(path_to_audio / label))
        print(f"Found {len(audio_files)} audio files for label {label}")

        for audio_file in tqdm.tqdm(audio_files, desc=f"embedding label {label}"):
            cached_path = cache_dir / f"{label}_{audio_file.stem}.emb"
            if cached_path.exists():
                # Reuse the cached embedding rather than running the model again.
                if DEBUG:
                    print(f"loading cached embedding for {cached_path.stem}")
                embedding = AnnotatedEmbedding.load(cached_path)
            else:
                # Best effort: unreadable files are reported and skipped.
                try:
                    sig = AudioSignal(audio_file)
                except Exception as e:
                    print(f"failed to load {audio_file.name} with error {e}")
                    print(f"skipping {audio_file.name}")
                    continue

                emb = vampnet_embed(sig, interface).cpu().numpy()

                embedding = AnnotatedEmbedding(
                    label=label, filename=audio_file.name, embedding=emb
                )

                cached_path.parent.mkdir(exist_ok=True, parents=True)
                embedding.save(cached_path)

            annotated_embeddings.append(embedding)

    # Without this check an empty dataset fails inside np.stack with an
    # unrelated-looking "need at least one array" message.
    if layers and not annotated_embeddings:
        raise ValueError(f"no embeddings were produced from {path_to_audio}")

    for layer in tqdm.tqdm(layers, desc="dim reduction"):
        dim_reduce(
            annotated_embeddings,
            layer,
            output_dir=output_dir,
            n_components=n_components,
            method=method,
        )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the arguments once and run main() inside the resulting argbind scope.
    with argbind.scope(argbind.parse_args()):
        main()
|
|
|