| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Load a nkululeko MLP model trained on audwav2vec2 embeddings and predict on an audio file.""" |
|
|
| import os |
| from collections import OrderedDict |
|
|
| import audeer |
| import audinterface |
| import audonnx |
| import audiofile |
| import click |
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
|
|
# Archive that is downloaded and extracted when the local w2v2 model
# directory does not exist yet.
W2V2_MODEL_URL = (
    "https://zenodo.org/record/6221127/files/w2v2-L-robust-12.6bc4a7fd-1.1.0.zip"
)

# Emotion labels, paired with the classifier outputs by position.
# NOTE(review): the order presumably has to match the label order used when
# the MLP was trained -- confirm against the training configuration.
CATEGORIES = [
    "happy",
    "angry",
    "sad",
    "scared",
    "neutral",
]
|
|
|
|
def load_mlp_from_checkpoint(path: str, device: str) -> nn.Module:
    """Build an MLP from a saved state dict and load its weights.

    The architecture is not stored separately; it is inferred from the
    checkpoint itself. Every key ending in ``.weight`` is treated as a linear
    layer named ``<prefix>.<index>.weight``: the integer ``<index>`` fixes the
    layer order and the tensor shape gives ``(out_features, in_features)``.
    A ReLU follows every linear layer except the last one.

    Args:
        path: File containing a state dict written with ``torch.save``.
        device: Torch device string used for loading and for the model.

    Returns:
        The model on ``device`` in eval mode. Its ``forward`` squeezes
        dimension 1 of the input and casts it to float before applying the
        layers.

    Raises:
        ValueError: If the checkpoint contains no ``.weight`` entries, so no
            network can be inferred from it.
    """
    # weights_only=True limits unpickling to tensors and plain containers, so
    # a model file from an untrusted source cannot execute arbitrary code.
    state = torch.load(path, map_location=device, weights_only=True)

    def layer_name(key: str) -> str:
        # "<prefix>.<index>.weight" -> "<index>"
        return key.split(".")[1]

    weight_keys = sorted(
        (k for k in state if k.endswith(".weight")),
        key=lambda k: int(layer_name(k)),
    )
    if not weight_keys:
        raise ValueError(f"No linear layer weights found in checkpoint: {path}")

    layers = OrderedDict()
    last = len(weight_keys) - 1
    for position, key in enumerate(weight_keys):
        out_size, in_size = state[key].shape
        name = layer_name(key)
        layers[name] = nn.Linear(in_size, out_size)
        if position < last:
            layers[f"{name}_r"] = nn.ReLU()

    class _MLP(nn.Module):
        def __init__(self):
            super().__init__()
            # load_state_dict() below matches parameters by name, so the
            # checkpoint keys are expected to start with "linear.".
            self.linear = nn.Sequential(layers)

        def forward(self, x):
            return self.linear(x.squeeze(dim=1).float())

    model = _MLP()
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model
|
|
|
|
def load_w2v2(model_root: str, device: str) -> audinterface.Feature:
    """Return a feature extractor backed by the w2v2 ONNX model.

    If ``model_root`` is not an existing directory, the archive at
    ``W2V2_MODEL_URL`` is downloaded into a local ``cache`` folder and
    extracted into ``model_root`` first.

    Args:
        model_root: Directory holding (or receiving) the extracted model.
        device: Device string passed on to ``audonnx.load``.

    Returns:
        An ``audinterface.Feature`` that calls the model for its
        ``"hidden_states"`` output, configured with ``sampling_rate=16000``
        and ``resample=False``.
    """
    if not os.path.isdir(model_root):
        click.echo(f"Downloading w2v2 model to {model_root} ...")
        cache_root = audeer.mkdir("cache")
        # Download before creating model_root. If the download fails, no empty
        # directory is left behind, so the next run retries instead of
        # skipping straight to a load that cannot succeed.
        archive_path = audeer.download_url(W2V2_MODEL_URL, cache_root, verbose=True)
        audeer.mkdir(model_root)
        audeer.extract_archive(archive_path, model_root)

    onnx_model = audonnx.load(model_root, device=device)
    return audinterface.Feature(
        onnx_model.labels("hidden_states"),
        process_func=onnx_model,
        process_func_args={"outputs": "hidden_states"},
        sampling_rate=16000,
        resample=False,
        verbose=False,
    )
|
|
|
|
def predict(audio_path: str, model_path: str, w2v2_root: str) -> dict:
    """Score one audio file against every entry in ``CATEGORIES``.

    Loads the w2v2 feature extractor and the MLP, reads the audio file,
    extracts a flattened feature vector (NaNs replaced via
    ``np.nan_to_num``) and runs it through the MLP as a batch of one.
    Progress messages are printed with ``click.echo``.

    Args:
        audio_path: Audio file to classify.
        model_path: MLP checkpoint passed to ``load_mlp_from_checkpoint``.
        w2v2_root: Model directory passed to ``load_w2v2``.

    Returns:
        ``{"scores": {category: raw model output}, "predicted": category}``
        where ``predicted`` is the category with the highest score.
    """
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    click.echo(f"Device: {device}")

    click.echo(f"Loading w2v2 feature extractor from {w2v2_root} ...")
    extractor = load_w2v2(w2v2_root, device)

    click.echo(f"Loading MLP model from {model_path} ...")
    classifier = load_mlp_from_checkpoint(model_path, device)

    click.echo(f"Reading audio from {audio_path} ...")
    signal, sr = audiofile.read(audio_path)
    click.echo(f" samples: {signal.shape}, sample rate: {sr} Hz")

    # One flat feature vector for the whole file, with NaNs replaced.
    frame = extractor.process_signal(signal, sr)
    embedding = np.nan_to_num(np.asarray(frame.values).flatten())

    # The classifier expects a batch dimension, hence the (1, -1) reshape.
    batch = torch.from_numpy(embedding.reshape(1, -1)).float().to(device)
    with torch.no_grad():
        logits = classifier(batch).cpu().numpy()[0]

    # Outputs are paired with the category names by position.
    scores = {}
    for position, category in enumerate(CATEGORIES):
        scores[category] = float(logits[position])

    best = max(scores, key=lambda category: scores[category])
    return {"scores": scores, "predicted": best}
|
|
|
|
@click.command()
@click.argument("model", type=click.Path(exists=True, dir_okay=False))
@click.argument("audio", type=click.Path(exists=True, dir_okay=False))
@click.option(
    "--w2v2-root",
    default="./audmodel/",
    show_default=True,
    metavar="DIR",
    help="Directory where the w2v2 onnx model is cached or will be downloaded to.",
)
@click.help_option("-h", "--help")
def main(model, audio, w2v2_root):
    """Predict emotion from an audio file using a nkululeko MLP + audwav2vec2 model.

    \b
    MODEL Path to the .model file (torch state dict saved by nkululeko).
    AUDIO Path to the audio file (must be 16 kHz mono WAV).

    \b
    Example:
    uv run test_model.py my_experiment_0_011.model sample.wav
    uv run test_model.py my_experiment_0_011.model sample.wav --w2v2-root /data/audmodel/
    """
    # The docstring above is the --help text shown by click; keep it as is.
    # predict() takes the audio path first, the reverse of the CLI order.
    outcome = predict(audio, model, w2v2_root)
    winner = outcome["predicted"]

    click.echo("\n--- Results ---")
    # Negated score as sort key: highest score is listed first.
    ranking = sorted(outcome["scores"].items(), key=lambda pair: -pair[1])
    for label, value in ranking:
        if label == winner:
            suffix = " <-- predicted"
        else:
            suffix = ""
        click.echo(f" {label:<10} {value:+.4f}{suffix}")
|
|
|
|
# Run the click command only when the file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
|
|