# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "audeer>=1.0.0",
#     "audinterface>=1.0.0",
#     "audonnx>=0.7.0",
#     "audiofile>=1.0.0",
#     "click>=8.0.0",
#     "numpy>=1.20.0",
#     "torch>=1.0.0",
# ]
# ///
"""Load a nkululeko MLP model trained on audwav2vec2 embeddings and predict on an audio file."""

import os
from collections import OrderedDict

import audeer
import audinterface
import audonnx
import audiofile
import click
import numpy as np
import torch
import torch.nn as nn

# Archive with the w2v2 ONNX model; downloaded on first use (see load_w2v2).
W2V2_MODEL_URL = "https://zenodo.org/record/6221127/files/w2v2-L-robust-12.6bc4a7fd-1.1.0.zip"

# Class names, one per MLP output unit, in output order.
# NOTE(review): this must match the label order used when the model was
# trained; verify against the training configuration.
CATEGORIES = ["happy", "angry", "sad", "scared", "neutral"]

# Sampling rate (Hz) the w2v2 feature extractor is configured to run at.
W2V2_SAMPLING_RATE = 16000


def load_mlp_from_checkpoint(path: str, device: str) -> nn.Module:
    """Build an MLP whose architecture is inferred from a checkpoint, then load it.

    Every state-dict key ending in ``.weight`` is treated as a linear layer
    stored under ``<prefix>.<index>.weight``. One ``nn.Linear`` is created
    per key, in ascending ``<index>`` order, sized from the weight's shape.
    A ReLU is inserted after every layer except the last.

    Args:
        path: Path to the checkpoint (a torch state dict).
        device: Torch device used as ``map_location`` and for the model.

    Returns:
        The model with weights loaded, moved to ``device``, in eval mode.

    Raises:
        ValueError: If the checkpoint contains no ``.weight`` entries.
    """
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state = torch.load(path, map_location=device)

    weight_keys = sorted(
        (k for k in state if k.endswith(".weight")),
        key=lambda k: int(k.split(".")[1]),
    )
    if not weight_keys:
        raise ValueError(f"No '.weight' entries found in checkpoint {path!r}")

    layers = OrderedDict()
    last = len(weight_keys) - 1
    for i, key in enumerate(weight_keys):
        out_size, in_size = state[key].shape
        idx = key.split(".")[1]  # "0", "1", "3", ...
        layers[idx] = nn.Linear(in_size, out_size)
        # No activation after the final layer: it emits the raw scores.
        if i < last:
            layers[f"{idx}_r"] = nn.ReLU()

    class _MLP(nn.Module):
        def __init__(self):
            super().__init__()
            # load_state_dict below is strict, so this attribute name must
            # equal the checkpoint's key prefix (assumed to be "linear").
            self.linear = nn.Sequential(layers)

        def forward(self, x):
            # squeeze(dim=1) drops a singleton dim 1 if present, presumably
            # for (batch, 1, features) inputs -- no-op for (batch, features).
            return self.linear(x.squeeze(dim=1).float())

    model = _MLP()
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model


def load_w2v2(model_root: str, device: str) -> audinterface.Feature:
    """Return a feature extractor producing w2v2 ``hidden_states`` embeddings.

    If ``model_root`` is missing or empty, the model archive is downloaded
    from ``W2V2_MODEL_URL`` (cached under ``./cache``) and extracted there.
    Signals not at ``W2V2_SAMPLING_RATE`` are resampled before processing.

    Args:
        model_root: Directory holding, or receiving, the extracted ONNX model.
        device: Device string passed to ``audonnx.load``.

    Returns:
        An ``audinterface.Feature`` whose columns are the model's
        ``hidden_states`` labels.
    """
    # An empty directory is treated like a missing one. Older versions
    # created it before downloading, so a failed download left it empty
    # and blocked every later attempt.
    if not os.path.isdir(model_root) or not os.listdir(model_root):
        click.echo(f"Downloading w2v2 model to {model_root} ...")
        cache_root = audeer.mkdir("cache")
        archive_path = audeer.download_url(W2V2_MODEL_URL, cache_root, verbose=True)
        # Create the target only once the archive is actually available.
        audeer.mkdir(model_root)
        audeer.extract_archive(archive_path, model_root)

    model = audonnx.load(model_root, device=device)
    return audinterface.Feature(
        model.labels("hidden_states"),
        process_func=model,
        process_func_args={"outputs": "hidden_states"},
        sampling_rate=W2V2_SAMPLING_RATE,
        # Resample inputs at other rates instead of rejecting them;
        # 16 kHz input is passed through unchanged.
        resample=True,
        verbose=False,
    )


def predict(audio_path: str, model_path: str, w2v2_root: str) -> dict:
    """Score one audio file with the w2v2 + MLP pipeline.

    Args:
        audio_path: Audio file to classify.
        model_path: MLP checkpoint (torch state dict saved by nkululeko).
        w2v2_root: Directory holding (or receiving) the w2v2 ONNX model.

    Returns:
        ``{"scores": {category: score}, "predicted": category}``. The scores
        are the raw MLP outputs, one per entry of ``CATEGORIES``, and
        ``predicted`` is the category with the highest score.

    Raises:
        ValueError: If the MLP's number of outputs differs from
            ``len(CATEGORIES)``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    click.echo(f"Device: {device}")

    click.echo(f"Loading w2v2 feature extractor from {w2v2_root} ...")
    w2v2 = load_w2v2(w2v2_root, device)

    click.echo(f"Loading MLP model from {model_path} ...")
    mlp = load_mlp_from_checkpoint(model_path, device)

    click.echo(f"Reading audio from {audio_path} ...")
    signal, sr = audiofile.read(audio_path)
    click.echo(f" samples: {signal.shape}, sample rate: {sr} Hz")

    # Flatten the extracted feature frame into one vector; NaNs become 0.
    features = np.asarray(w2v2.process_signal(signal, sr).values).flatten()
    features = np.nan_to_num(features)

    with torch.no_grad():
        x = torch.from_numpy(features.reshape(1, -1)).float().to(device)
        logits = mlp(x).cpu().numpy()[0]

    # Index-based mapping would otherwise raise IndexError (too few outputs)
    # or silently drop outputs (too many).
    if len(logits) != len(CATEGORIES):
        raise ValueError(
            f"Model has {len(logits)} outputs but {len(CATEGORIES)} "
            f"categories are defined: {CATEGORIES}"
        )

    scores = {cat: float(logit) for cat, logit in zip(CATEGORIES, logits)}
    predicted = max(scores, key=scores.get)
    return {"scores": scores, "predicted": predicted}


@click.command()
@click.argument("model", type=click.Path(exists=True, dir_okay=False))
@click.argument("audio", type=click.Path(exists=True, dir_okay=False))
@click.option(
    "--w2v2-root",
    default="./audmodel/",
    show_default=True,
    metavar="DIR",
    help="Directory where the w2v2 onnx model is cached or will be downloaded to.",
)
@click.help_option("-h", "--help")
def main(model, audio, w2v2_root):
    """Predict emotion from an audio file using a nkululeko MLP + audwav2vec2 model.

    \b
    MODEL  Path to the .model file (torch state dict saved by nkululeko).
    AUDIO  Path to the audio file (mono; resampled to 16 kHz if needed).

    \b
    Example:
      uv run test_model.py my_experiment_0_011.model sample.wav
      uv run test_model.py my_experiment_0_011.model sample.wav --w2v2-root /data/audmodel/
    """
    result = predict(audio, model, w2v2_root)

    click.echo("\n--- Results ---")
    # Highest score first.
    for cat, score in sorted(result["scores"].items(), key=lambda x: -x[1]):
        marker = " <-- predicted" if cat == result["predicted"] else ""
        click.echo(f" {cat:<10} {score:+.4f}{marker}")


if __name__ == "__main__":
    main()