"""
Minimal inference for SpeechBrain ECAPA-TDNN (ShEMO fine-tuned).
"""

import os

import torch
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from speechbrain.dataio.dataio import read_audio
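
# Paths to the fine-tuned experiment: hyperparams.yaml plus the "save"
# directory holding the training checkpoints.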
EXP_DIR = (
    "/mnt/c/Users/NoteBook/Documents/fineTuningSpeechbrain/recipes/ShEMO/"
    "emotion_recognition/results(2)/content/results/ECAPA-TDNN/1968"
)
HP_FILE = os.path.join(EXP_DIR, "hyperparams.yaml")
CKPT_DIR = os.path.join(EXP_DIR, "save")
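
# HyperPyYAML instantiates the objects declared in the recipe's YAML
# (feature extraction, normalization, embedding model, classifier).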
with open(HP_FILE) as f:
    hparams = load_hyperpyyaml(f)
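
# Only the modules needed at inference time. The keys must match the names
# used during training so the checkpointer can map saved parameters onto them.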
modules = {
    "compute_features": hparams["compute_features"],
    "mean_var_norm": hparams["mean_var_norm"],
    "embedding_model": hparams["embedding_model"],
    "classifier": hparams["classifier"],
}
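
# Run on GPU when available.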
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
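
# Restore the fine-tuned weights from the latest checkpoint.
# allow_partial_load lets recovery proceed even if some registered
# recoverable has no matching savefile in the checkpoint.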
checkpointer = sb.utils.checkpoints.Checkpointer(
    checkpoints_dir=CKPT_DIR,
    recoverables=modules,
    allow_partial_load=True,
)
checkpointer.recover_if_possible()
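
# Minimal stand-in for a SpeechBrain PaddedBatch: just the (signal,
# relative lengths) pair that compute_forward expects, plus a .to()
# for device moves.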
class SimpleBatch:
    def __init__(self, wav, lens):
        self.sig = (wav, lens)

    def to(self, device):
        wav, lens = self.sig
        self.sig = (wav.to(device), lens.to(device))
        return self
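
# Forward pass mirroring the recipe:
# features -> mean/var normalization -> ECAPA embedding -> classifier.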
class EmoIdBrain(sb.Brain):
    def compute_forward(self, batch, stage):
        wavs, lens = batch.sig
        feats = self.modules.compute_features(wavs)
        feats = self.modules.mean_var_norm(feats, lens)
        emb = self.modules.embedding_model(feats, lens)
        out = self.modules.classifier(emb)
        return out
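
# The Brain constructor moves the modules onto the requested device.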
brain = EmoIdBrain(
    modules=modules,
    hparams=hparams,
    run_opts={"device": device},
    checkpointer=checkpointer,
)
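
# Index-to-label mapping. NOTE: this order must match the label encoder
# created during training (often saved as save/label_encoder.txt); verify
# it before trusting predictions.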
IDX2LAB = [
    "anger", "sadness", "neutral",
    "surprise", "happiness", "fear",
]
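
# Classify a single wav file and return its emotion label.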
def predict(wav_path: str) -> str:
    wav_raw = read_audio(wav_path)
    # read_audio may return a tensor or an array depending on the
    # SpeechBrain version; normalize to a float batch of shape (1, time).
    if isinstance(wav_raw, torch.Tensor):
        wav = wav_raw.clone().detach().float().unsqueeze(0)
    else:
        wav = torch.tensor(wav_raw, dtype=torch.float32).unsqueeze(0)
    lens = torch.tensor([1.0])  # relative length: the signal fills the batch
    batch = SimpleBatch(wav, lens).to(device)
    brain.modules.eval()  # inference mode for dropout/batch norm
    with torch.no_grad():
        logits = brain.compute_forward(batch, stage=sb.Stage.TEST)
    idx = int(logits.argmax(dim=-1))
    return IDX2LAB[idx]
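
# Quick manual check on a local file.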
if __name__ == "__main__":
    WAV_FILE = "shortvoice.wav"
    print("Predicted emotion:", predict(WAV_FILE))