| """ |
| inference.py — load and run the SEGUE multitask sentiment + emotion model. |
| |
| Requirements |
| ------------ |
| 1. Clone the declare-lab/segue repository and make it importable: |
| git clone https://github.com/declare-lab/segue |
| Then either run your script from inside the segue/ directory, or add it to |
| your path explicitly: |
| import sys; sys.path.append("/path/to/segue") |
| |
| 2. Install dependencies (matching the versions used for training): |
| pip install torch torchaudio |
| pip install transformers==4.35.0 huggingface_hub==0.17.0 |
| pip install "numpy>=1.24,<2.0" "accelerate>=0.20.1,<0.24.0" |
| |
| Quick start |
| ----------- |
| import torchaudio |
| from inference import load_segue_multitask, segue_predict |
| |
| model, processor = load_segue_multitask("model.pt") |
| |
| waveform, sr = torchaudio.load("speech.wav") |
| audio = waveform.mean(0).numpy() # mono, float32 |
| |
| sent_probs, emo_probs = segue_predict(model, processor, [audio], sampling_rate=sr) |
| # sent_probs: np.ndarray (N, 3) — neutral / positive / negative |
| # emo_probs: np.ndarray (N, 7) — neutral / surprise / fear / sadness / joy / disgust / anger |
| """ |
|
|
| import os |
| from typing import List, Optional, Tuple |
| import numpy as np |
| import torch |
| import torchaudio |
|
|
| |
# Import the SEGUE classification model; give an actionable message if the
# segue repository is not importable.
try:
    from segue.modeling_segue import SegueForClassification
except ImportError as err:
    # Chain the original exception: an ImportError raised *inside* segue's own
    # imports (e.g. a missing or incompatible dependency) would otherwise be
    # hidden behind the generic "could not import segue" message.
    raise ImportError(
        "Could not import `segue`. "
        "Clone https://github.com/declare-lab/segue and either run your script "
        "from inside that directory or add it to sys.path:\n"
        " import sys; sys.path.append('/path/to/segue')") from err
|
|
| |
| |
| |
|
|
# Column index -> label name for the sentiment head's 3 outputs.
SENTIMENT_LABELS = dict(enumerate(("neutral", "positive", "negative")))

# Column index -> label name for the emotion head's 7 outputs.
EMOTION_LABELS = dict(enumerate((
    "neutral",
    "surprise",
    "fear",
    "sadness",
    "joy",
    "disgust",
    "anger",
)))

# Sample rate (Hz) fed to the processor; segue_predict resamples any other
# input rate to this value first.
TARGET_SAMPLE_RATE = 16000
|
|
|
|
| |
| |
| |
|
|
class SegueMultiTask(torch.nn.Module):
    """
    Joint sentiment + emotion model: two SegueForClassification heads that
    run audio through one shared speech encoder.

    During construction the emotion head's ``speech_encoder`` attribute is
    pointed at the sentiment head's encoder, so both heads use the same
    module object (and therefore the same parameters). Any other submodules
    of each head remain their own. ``self.processor`` is the sentiment
    head's processor. The original author notes this mirrors the
    architecture used for fine-tuning on MELD.
    """

    def __init__(
            self,
            sentiment_model: SegueForClassification,
            emotion_model: SegueForClassification):
        super().__init__()
        self.sentiment_model = sentiment_model
        self.emotion_model = emotion_model

        # Share a single encoder between both heads.
        shared_encoder = sentiment_model.speech_encoder
        self.emotion_model.speech_encoder = shared_encoder
        self.processor = sentiment_model.processor

    def forward(
            self,
            speech: dict,
            n_speech_tokens: list,
            **kwargs) -> dict:
        """
        Run both heads on the same batch of audio.

        Args:
            speech: dict whose "input_values" entry is a (B, T) float tensor.
            n_speech_tokens: list of B ints, one per batch item.
            **kwargs: accepted so extra call-site arguments do not fail;
                they are ignored.

        Returns:
            dict with
                "sentiment_predictions": (B, 3) unnormalised scores
                "emotion_predictions":   (B, 7) unnormalised scores
            (segue_predict applies softmax to these).
        """
        input_values = speech["input_values"]
        # Each head is called with `labels`; zeros are passed as placeholders
        # because only the "predictions" output is used here.
        # NOTE(review): presumably the segue forward expects labels — confirm.
        placeholder_labels = torch.zeros(
            input_values.shape[0],
            dtype=torch.long,
            device=input_values.device)

        # Sentiment head first, then emotion head (same order as before).
        heads = (
            ("sentiment_predictions", self.sentiment_model),
            ("emotion_predictions", self.emotion_model),
        )
        return {
            key: head(
                speech=speech,
                n_speech_tokens=n_speech_tokens,
                labels=placeholder_labels)["predictions"]
            for key, head in heads}
|
|
|
|
| |
| |
| |
|
|
def load_segue_multitask(
        weights_path: str,
        base_model: str = "declare-lab/segue-w2v2-base",
        device: Optional[str] = None) -> Tuple[SegueMultiTask, object]:
    """
    Build a SegueMultiTask from `base_model` and load fine-tuned weights into it.

    Args:
        weights_path: path to `model.pt` (the fine-tuned state dict). A
            file-like object accepted by `torch.load` also works.
        base_model: HuggingFace model ID used as the architecture template
            for both heads.
        device: "cuda", "cpu", or None to pick CUDA when available.

    Returns:
        (model, processor)
            model: SegueMultiTask in eval mode, moved to `device`
            processor: the sentiment head's processor, used by
                segue_predict to pre-process audio

    Raises:
        FileNotFoundError: if `weights_path` is a path that does not name an
            existing file. This is checked before any base model is loaded.
    """
    # Fail fast: the two from_pretrained calls below can download and
    # instantiate large checkpoints, so a bad path should surface before that
    # work rather than after it. File-like objects skip this check.
    if isinstance(weights_path, (str, bytes, os.PathLike)) and not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Weights file not found: {weights_path!r}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Two heads from the same base checkpoint, sized for 3 sentiment and
    # 7 emotion classes. ignore_mismatched_sizes lets from_pretrained accept
    # checkpoint weights whose shapes differ from these class counts.
    sentiment_model = SegueForClassification.from_pretrained(
        base_model, n_classes=3, ignore_mismatched_sizes=True)
    emotion_model = SegueForClassification.from_pretrained(
        base_model, n_classes=7, ignore_mismatched_sizes=True)

    model = SegueMultiTask(sentiment_model, emotion_model)

    # Zero the shared encoder's time/feature masking probabilities so no
    # random masking is configured for inference.
    model.sentiment_model.speech_encoder.config.mask_time_prob = 0.0
    model.sentiment_model.speech_encoder.config.mask_feature_prob = 0.0

    # SECURITY NOTE(review): torch.load unpickles the file. Only load
    # checkpoints from trusted sources. If model.pt is a plain state dict,
    # consider passing weights_only=True to restrict what can be deserialised.
    state_dict = torch.load(weights_path, map_location="cpu")
    # strict=False: mismatches are reported rather than raised.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"Warning — missing keys when loading weights: {missing}")
    if unexpected:
        print(f"Warning — unexpected keys when loading weights: {unexpected}")

    # Defensively re-assert the encoder tie after loading.
    model.emotion_model.speech_encoder = model.sentiment_model.speech_encoder

    model = model.to(device)
    model.eval()

    return model, model.processor
|
|
|
|
| |
| |
| |
|
|
def segue_predict(
        model: SegueMultiTask,
        processor,
        chunks: List[np.ndarray],
        sampling_rate: int = TARGET_SAMPLE_RATE) -> Tuple[np.ndarray, np.ndarray]:
    """
    Run the model on a sequence of audio chunks and return softmax probabilities.

    Args:
        model: SegueMultiTask returned by load_segue_multitask()
        processor: processor returned by load_segue_multitask()
        chunks: sequence (or iterable) of 1-D mono numpy arrays. Float32 is
            expected. Other dtypes are cast to float32 before resampling.
        sampling_rate: sample rate of the audio. The model expects 16 000 Hz;
            pass the actual rate and chunks will be resampled if needed.

    Returns:
        sent_probs: np.ndarray (N, 3) softmax probabilities for sentiment
            columns: neutral / positive / negative
        emo_probs: np.ndarray (N, 7) softmax probabilities for emotion
            columns: neutral / surprise / fear / sadness / joy / disgust / anger
        For an empty `chunks`, arrays of shape (0, 3) and (0, 7) are returned.
    """
    # Materialise once: chunks may be a generator, and the resampling path
    # below would otherwise consume it before the inference loop.
    chunks = list(chunks)
    if len(chunks) == 0:
        # torch.cat([]) would raise. Return correctly-shaped empty results.
        return (np.empty((0, len(SENTIMENT_LABELS)), dtype=np.float32),
                np.empty((0, len(EMOTION_LABELS)), dtype=np.float32))

    device = next(model.parameters()).device

    if sampling_rate != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(sampling_rate, TARGET_SAMPLE_RATE)
        # The resampling kernel is float32, so float64 input would fail with a
        # dtype mismatch. torch.from_numpy also rejects negative-stride
        # arrays. Cast each chunk to a contiguous float32 array first; this
        # does not copy when the chunk already is one.
        chunks = [
            resampler(
                torch.from_numpy(np.ascontiguousarray(c, dtype=np.float32)).unsqueeze(0)
            ).squeeze(0).numpy()
            for c in chunks]

    all_sent_logits = []
    all_emo_logits = []

    with torch.no_grad():
        # One chunk per forward pass (batch size 1).
        for chunk in chunks:
            proc = processor(audio=chunk, sampling_rate=TARGET_SAMPLE_RATE)
            # unsqueeze(0) adds the batch dimension.
            # NOTE(review): assumes proc["speech"]["input_values"] is 1-D.
            input_values = torch.tensor(
                proc["speech"]["input_values"], dtype=torch.float32
            ).unsqueeze(0).to(device)
            n_speech_tokens = [int(proc["n_speech_tokens"][0])]

            out = model(
                speech={"input_values": input_values},
                n_speech_tokens=n_speech_tokens)

            all_sent_logits.append(out["sentiment_predictions"].cpu())
            all_emo_logits.append(out["emotion_predictions"].cpu())

    sent_probs = torch.softmax(torch.cat(all_sent_logits, dim=0), dim=1).numpy()
    emo_probs = torch.softmax(torch.cat(all_emo_logits, dim=0), dim=1).numpy()

    return sent_probs, emo_probs
|
|
|
|
| |
| |
| |
|
|
def predict_file(
        audio_path: str,
        weights_path: str = "model.pt",
        device: Optional[str] = None) -> dict:
    """
    Load the model and classify one audio file in a single call.

    The model is loaded anew on every call. Multi-channel audio is averaged
    to mono before inference.

    Returns:
        dict with
            "sentiment": {label: probability} over SENTIMENT_LABELS
            "emotion": {label: probability} over EMOTION_LABELS
            "sentiment_score": P(positive) - P(negative), a float in [-1, 1]
    """
    model, processor = load_segue_multitask(weights_path, device=device)

    signal, rate = torchaudio.load(audio_path)
    # Average channels into one when the file is not already mono.
    if signal.shape[0] > 1:
        signal = signal.mean(dim=0, keepdim=True)
    mono = signal.squeeze(0).numpy().astype(np.float32)

    sent_probs, emo_probs = segue_predict(model, processor, [mono], sampling_rate=rate)
    sent_row = sent_probs[0]
    emo_row = emo_probs[0]

    sentiment = {name: float(sent_row[idx]) for idx, name in SENTIMENT_LABELS.items()}
    emotion = {name: float(emo_row[idx]) for idx, name in EMOTION_LABELS.items()}
    # Columns 1 and 2 are "positive" and "negative" respectively.
    score = float(sent_row[1] - sent_row[2])

    return {
        "sentiment": sentiment,
        "emotion": emotion,
        "sentiment_score": score}
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    import sys

    # Command line: <audio_file> [weights file, default "model.pt"]
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python inference.py <audio_file> [model.pt]")
        sys.exit(1)

    audio_file = cli_args[0]
    weights_file = cli_args[1] if len(cli_args) > 1 else "model.pt"

    print(f"Audio: {audio_file}")
    print(f"Weights: {weights_file}")
    print()

    report = predict_file(audio_file, weights_file)

    def _print_probs(header, probs):
        # One aligned "label: probability" line per entry, under a header.
        print(header)
        for name, p in probs.items():
            print(f" {name:10s}: {p:.4f}")

    _print_probs("Sentiment probabilities:", report["sentiment"])
    print(f" → score (pos - neg): {report['sentiment_score']:+.4f}")

    _print_probs("\nEmotion probabilities:", report["emotion"])
|
|