"""
inference.py — load and run the SEGUE multitask sentiment + emotion model.

Requirements
------------
1. Clone the declare-lab/segue repository and make it importable:

       git clone https://github.com/declare-lab/segue

   Then either run your script from inside the segue/ directory, or add it
   to your path explicitly:

       import sys; sys.path.append("/path/to/segue")

2. Install dependencies (matching the versions used for training). Quote the
   version specifiers so the shell does not interpret ``<`` / ``>`` as
   redirections:

       pip install torch torchaudio
       pip install transformers==4.35.0 huggingface_hub==0.17.0
       pip install "numpy>=1.24,<2.0" "accelerate>=0.20.1,<0.24.0"

Quick start
-----------
    import torchaudio
    from inference import load_segue_multitask, segue_predict

    model, processor = load_segue_multitask("model.pt")

    waveform, sr = torchaudio.load("speech.wav")
    audio = waveform.mean(0).numpy()            # mono, float32

    sent_probs, emo_probs = segue_predict(model, processor, [audio], sampling_rate=sr)
    # sent_probs: np.ndarray (N, 3) — neutral / positive / negative
    # emo_probs:  np.ndarray (N, 7) — neutral / surprise / fear / sadness / joy / disgust / anger

Command line
------------
    python inference.py <audio_file> [model.pt]
"""

import os
from typing import List, Optional, Tuple

import numpy as np
import torch
import torchaudio

# SegueForClassification lives in the segue repo — must be on sys.path.
try:
    from segue.modeling_segue import SegueForClassification
except ImportError as err:
    raise ImportError(
        "Could not import `segue`.\n"
        "Clone https://github.com/declare-lab/segue and either run your script "
        "from inside that directory or add it to sys.path:\n"
        " import sys; sys.path.append('/path/to/segue')"
    ) from err

# ---------------------------------------------------------------------------
# Label definitions
# ---------------------------------------------------------------------------

# Column index -> label name for the sentiment head's output.
SENTIMENT_LABELS = {0: "neutral", 1: "positive", 2: "negative"}

# Column index -> label name for the emotion head's output.
EMOTION_LABELS = {
    0: "neutral",
    1: "surprise",
    2: "fear",
    3: "sadness",
    4: "joy",
    5: "disgust",
    6: "anger",
}

# Sample rate (Hz) the model expects; other rates are resampled in segue_predict.
TARGET_SAMPLE_RATE = 16_000


# ---------------------------------------------------------------------------
# Model wrapper
# ---------------------------------------------------------------------------

class SegueMultiTask(torch.nn.Module):
    """
    Two SegueForClassification heads (sentiment + emotion) that share a
    single speech encoder backbone.

    This is the same architecture used during fine-tuning on MELD.
    The speech encoder is owned by `sentiment_model`; the speech encoder that
    `emotion_model` was constructed with is replaced by a reference to that
    same module, so both heads run on one shared backbone.
    """

    def __init__(
        self,
        sentiment_model: SegueForClassification,
        emotion_model: SegueForClassification,
    ):
        super().__init__()
        self.sentiment_model = sentiment_model
        self.emotion_model = emotion_model

        # Tie the encoders so the backbone is shared
        self.emotion_model.speech_encoder = self.sentiment_model.speech_encoder

        self.processor = self.sentiment_model.processor

    def forward(self, speech: dict, n_speech_tokens: list, **kwargs) -> dict:
        """
        Run both classification heads on the same speech batch.

        Args:
            speech:          dict with key "input_values": FloatTensor (B, T)
            n_speech_tokens: list of ints, length B
            **kwargs:        accepted and ignored

        Returns:
            dict with keys:
                "sentiment_predictions": FloatTensor (B, 3) — raw logits
                "emotion_predictions":   FloatTensor (B, 7) — raw logits
        """
        # SegueForClassification.forward() unconditionally calls
        # labels.unsqueeze(-1), so we must always supply labels.
        # We pass dummy zeros and ignore the returned loss.
        input_values = speech["input_values"]
        dummy_labels = torch.zeros(
            input_values.shape[0], dtype=torch.long, device=input_values.device)

        sent_out = self.sentiment_model(
            speech=speech, n_speech_tokens=n_speech_tokens, labels=dummy_labels)
        emo_out = self.emotion_model(
            speech=speech, n_speech_tokens=n_speech_tokens, labels=dummy_labels)

        return {
            "sentiment_predictions": sent_out["predictions"],
            "emotion_predictions": emo_out["predictions"],
        }


# ---------------------------------------------------------------------------
# Loading
# ---------------------------------------------------------------------------

def load_segue_multitask(
    weights_path: str,
    base_model: str = "declare-lab/segue-w2v2-base",
    device: Optional[str] = None,
) -> Tuple[SegueMultiTask, object]:
    """
    Load the fine-tuned SegueMultiTask model from a weights file.

    Args:
        weights_path: path to `model.pt` (the fine-tuned state dict)
        base_model:   HuggingFace model ID used as the architecture template
        device:       "cuda", "cpu", or None (auto-detect: CUDA if available)

    Returns:
        (model, processor)
        model:     SegueMultiTask in eval mode, moved to `device`
        processor: the processor attached to the sentiment head, for
                   pre-processing audio

    Missing / unexpected state-dict keys are reported with a printed warning
    rather than raised (the state dict is loaded with strict=False).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load architecture from the pre-trained base (weights will be overwritten).
    # ignore_mismatched_sizes lets the classifier heads take our class counts.
    sentiment_model = SegueForClassification.from_pretrained(
        base_model, n_classes=len(SENTIMENT_LABELS), ignore_mismatched_sizes=True)
    emotion_model = SegueForClassification.from_pretrained(
        base_model, n_classes=len(EMOTION_LABELS), ignore_mismatched_sizes=True)

    model = SegueMultiTask(sentiment_model, emotion_model)

    # Disable wav2vec2 feature masking — it's a pre-training trick that causes
    # errors on short sequences and is not needed for inference. The emotion
    # head shares this exact encoder object, so one change covers both heads.
    encoder_config = model.sentiment_model.speech_encoder.config
    encoder_config.mask_time_prob = 0.0
    encoder_config.mask_feature_prob = 0.0

    # Load fine-tuned weights.
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code — only load checkpoints from trusted sources (or pass
    # weights_only=True on torch versions that support it).
    state_dict = torch.load(weights_path, map_location="cpu")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"Warning — missing keys when loading weights: {missing}")
    if unexpected:
        print(f"Warning — unexpected keys when loading weights: {unexpected}")

    # load_state_dict copies values into the existing parameters in place, so
    # the shared reference should survive; re-tie anyway as a cheap guarantee
    # that both heads use one encoder.
    model.emotion_model.speech_encoder = model.sentiment_model.speech_encoder

    model = model.to(device)
    model.eval()
    return model, model.processor


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

def segue_predict(
    model: SegueMultiTask,
    processor,
    chunks: List[np.ndarray],
    sampling_rate: int = TARGET_SAMPLE_RATE,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Run the model on a list of audio chunks and return softmax probabilities.

    Chunks are processed one at a time (batch size 1), so they may have
    different lengths.

    Args:
        model:         SegueMultiTask returned by load_segue_multitask()
        processor:     processor returned by load_segue_multitask()
        chunks:        list of 1-D numpy arrays (mono audio); they are
                       converted to contiguous float32 before use
        sampling_rate: sample rate of the audio (model expects 16 000 Hz;
                       pass the actual rate and it will be resampled if needed)

    Returns:
        sent_probs: np.ndarray (N, 3) softmax probabilities for sentiment
                    columns: neutral / positive / negative
        emo_probs:  np.ndarray (N, 7) softmax probabilities for emotion
                    columns: neutral / surprise / fear / sadness / joy / disgust / anger
        With an empty `chunks` list, both arrays have zero rows.
    """
    n_sent = len(SENTIMENT_LABELS)
    n_emo = len(EMOTION_LABELS)

    # torch.cat raises on an empty list, so short-circuit with shaped empties.
    if len(chunks) == 0:
        return (np.empty((0, n_sent), dtype=np.float32),
                np.empty((0, n_emo), dtype=np.float32))

    device = next(model.parameters()).device

    # Normalise inputs: the Resample transform caches its kernel as float32 by
    # default, and torch.from_numpy rejects negative-stride views.
    chunks = [np.ascontiguousarray(c, dtype=np.float32) for c in chunks]

    # Resample if the audio doesn't match the model's expected rate
    if sampling_rate != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(sampling_rate, TARGET_SAMPLE_RATE)
        chunks = [
            resampler(torch.from_numpy(c).unsqueeze(0)).squeeze(0).numpy()
            for c in chunks
        ]

    all_sent_logits: List[torch.Tensor] = []
    all_emo_logits: List[torch.Tensor] = []

    with torch.no_grad():
        for chunk in chunks:
            proc = processor(audio=chunk, sampling_rate=TARGET_SAMPLE_RATE)
            # Add a batch dimension of 1: (T,) -> (1, T).
            input_values = torch.tensor(
                proc["speech"]["input_values"], dtype=torch.float32
            ).unsqueeze(0).to(device)
            n_speech_tokens = [int(proc["n_speech_tokens"][0])]

            out = model(
                speech={"input_values": input_values},
                n_speech_tokens=n_speech_tokens)

            all_sent_logits.append(out["sentiment_predictions"].cpu())
            all_emo_logits.append(out["emotion_predictions"].cpu())

    sent_probs = torch.softmax(torch.cat(all_sent_logits, dim=0), dim=1).numpy()
    emo_probs = torch.softmax(torch.cat(all_emo_logits, dim=0), dim=1).numpy()
    return sent_probs, emo_probs


# ---------------------------------------------------------------------------
# Convenience: predict a single audio file
# ---------------------------------------------------------------------------

def predict_file(
    audio_path: str,
    weights_path: str = "model.pt",
    device: Optional[str] = None,
) -> dict:
    """
    Convenience function: load the model and run it on a single audio file.

    The model is loaded on every call; for many files, call
    load_segue_multitask() once and use segue_predict() directly.

    Returns a dict with:
        sentiment:       dict mapping label -> probability
        emotion:         dict mapping label -> probability
        sentiment_score: float in [-1, 1]  (prob_positive - prob_negative)
    """
    model, processor = load_segue_multitask(weights_path, device=device)

    waveform, sr = torchaudio.load(audio_path)
    # Down-mix multi-channel audio to mono by averaging channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    audio = waveform.squeeze(0).numpy().astype(np.float32)

    sent_probs, emo_probs = segue_predict(model, processor, [audio], sampling_rate=sr)

    return {
        "sentiment": {
            SENTIMENT_LABELS[i]: float(sent_probs[0, i])
            for i in range(len(SENTIMENT_LABELS))
        },
        "emotion": {
            EMOTION_LABELS[i]: float(emo_probs[0, i])
            for i in range(len(EMOTION_LABELS))
        },
        "sentiment_score": float(sent_probs[0, 1] - sent_probs[0, 2]),
    }


# ---------------------------------------------------------------------------
# Quick test when run directly
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import sys

    # The audio file is required; the weights path is optional.
    if len(sys.argv) < 2:
        print("Usage: python inference.py <audio_file> [model.pt]")
        sys.exit(1)

    audio_path = sys.argv[1]
    weights_path = sys.argv[2] if len(sys.argv) > 2 else "model.pt"

    print(f"Audio: {audio_path}")
    print(f"Weights: {weights_path}")
    print()

    result = predict_file(audio_path, weights_path)

    print("Sentiment probabilities:")
    for label, prob in result["sentiment"].items():
        print(f" {label:10s}: {prob:.4f}")
    print(f" → score (pos - neg): {result['sentiment_score']:+.4f}")

    print("\nEmotion probabilities:")
    for label, prob in result["emotion"].items():
        print(f" {label:10s}: {prob:.4f}")