segue-meld-multitask / inference.py
andreasvc's picture
Upload folder using huggingface_hub
2cf4246 verified
Raw
History Blame Contribute Delete
11 kB
"""
inference.py — load and run the SEGUE multitask sentiment + emotion model.
Requirements
------------
1. Clone the declare-lab/segue repository and make it importable:
git clone https://github.com/declare-lab/segue
Then either run your script from inside the segue/ directory, or add it to
your path explicitly:
import sys; sys.path.append("/path/to/segue")
2. Install dependencies (matching the versions used for training):
pip install torch torchaudio
pip install transformers==4.35.0 huggingface_hub==0.17.0
pip install "numpy>=1.24,<2.0" "accelerate>=0.20.1,<0.24.0"
Quick start
-----------
import torchaudio
from inference import load_segue_multitask, segue_predict
model, processor = load_segue_multitask("model.pt")
waveform, sr = torchaudio.load("speech.wav")
audio = waveform.mean(0).numpy() # mono, float32
sent_probs, emo_probs = segue_predict(model, processor, [audio], sampling_rate=sr)
# sent_probs: np.ndarray (N, 3) — neutral / positive / negative
# emo_probs: np.ndarray (N, 7) — neutral / surprise / fear / sadness / joy / disgust / anger
"""
import os
from typing import List, Optional, Tuple
import numpy as np
import torch
import torchaudio
# SegueForClassification lives in the segue repo, which must be importable
# (i.e. on sys.path) before this module is loaded.
try:
    from segue.modeling_segue import SegueForClassification
except ImportError as err:
    # Chain the original error so the traceback still shows the real cause:
    # the failure may come from a missing dependency of segue itself rather
    # than from a missing segue checkout.
    raise ImportError(
        "Could not import `segue`. "
        "Clone https://github.com/declare-lab/segue and either run your script "
        "from inside that directory or add it to sys.path:\n"
        " import sys; sys.path.append('/path/to/segue')") from err
# ---------------------------------------------------------------------------
# Label definitions
# ---------------------------------------------------------------------------
# Class index -> label name, in the column order of the model's outputs.
SENTIMENT_LABELS = dict(enumerate(["neutral", "positive", "negative"]))
EMOTION_LABELS = dict(enumerate([
    "neutral",
    "surprise",
    "fear",
    "sadness",
    "joy",
    "disgust",
    "anger",
]))

# Sample rate (Hz) the model expects; audio at other rates gets resampled.
TARGET_SAMPLE_RATE = 16000
# ---------------------------------------------------------------------------
# Model wrapper
# ---------------------------------------------------------------------------
class SegueMultiTask(torch.nn.Module):
    """
    Joint sentiment + emotion classifier made of two SegueForClassification
    heads that share a single speech-encoder module.

    This is the same architecture used during fine-tuning on MELD. After
    construction, `emotion_model.speech_encoder` is the very same module
    object as `sentiment_model.speech_encoder`. The audio processor is taken
    from the sentiment head.
    """

    def __init__(
            self,
            sentiment_model: SegueForClassification,
            emotion_model: SegueForClassification):
        super().__init__()
        self.sentiment_model = sentiment_model
        self.emotion_model = emotion_model
        # Point the emotion head at the sentiment head's encoder so that one
        # backbone module serves both heads.
        shared_encoder = sentiment_model.speech_encoder
        self.emotion_model.speech_encoder = shared_encoder
        self.processor = sentiment_model.processor

    def forward(
            self,
            speech: dict,
            n_speech_tokens: list,
            **kwargs) -> dict:
        """
        Run both heads on the same batch.

        Args:
            speech: dict with key "input_values": FloatTensor (B, T)
            n_speech_tokens: list of ints, length B
            **kwargs: accepted for call compatibility and ignored
        Returns:
            dict with keys:
                "sentiment_predictions": FloatTensor (B, 3) — raw logits
                "emotion_predictions": FloatTensor (B, 7) — raw logits
        """
        values = speech["input_values"]
        # SegueForClassification.forward() always calls labels.unsqueeze(-1),
        # so a placeholder label tensor must be passed; its loss is unused.
        placeholder_labels = torch.zeros(
            values.shape[0], dtype=torch.long, device=values.device)
        head_inputs = dict(
            speech=speech,
            n_speech_tokens=n_speech_tokens,
            labels=placeholder_labels)

        sentiment_output = self.sentiment_model(**head_inputs)
        emotion_output = self.emotion_model(**head_inputs)

        return {
            "sentiment_predictions": sentiment_output["predictions"],
            "emotion_predictions": emotion_output["predictions"],
        }
# ---------------------------------------------------------------------------
# Loading
# ---------------------------------------------------------------------------
def load_segue_multitask(
        weights_path: str,
        base_model: str = "declare-lab/segue-w2v2-base",
        device: Optional[str] = None) -> Tuple[SegueMultiTask, object]:
    """
    Load the fine-tuned SegueMultiTask model from a weights file.

    Args:
        weights_path: path to `model.pt` (the fine-tuned state dict). It is
            read with `torch.load`, which is pickle-based; only load weight
            files from a source you trust.
        base_model: HuggingFace model ID used as the architecture template
        device: "cuda", "cpu", or None (auto-detect: "cuda" when available)
    Returns:
        (model, processor)
        model: SegueMultiTask in eval mode, moved to `device`
        processor: SegueProcessor for pre-processing audio
    Notes:
        Weights are loaded with `strict=False`: missing or unexpected keys are
        printed as warnings instead of raising.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Build the architecture from the pre-trained base (weights are
    # overwritten below). Head sizes come from the label maps so the output
    # columns always match SENTIMENT_LABELS / EMOTION_LABELS;
    # ignore_mismatched_sizes allows the heads to differ from the base.
    sentiment_model = SegueForClassification.from_pretrained(
        base_model, n_classes=len(SENTIMENT_LABELS), ignore_mismatched_sizes=True)
    emotion_model = SegueForClassification.from_pretrained(
        base_model, n_classes=len(EMOTION_LABELS), ignore_mismatched_sizes=True)
    model = SegueMultiTask(sentiment_model, emotion_model)

    # Disable wav2vec2 feature masking — it's a pre-training trick that causes
    # errors on short sequences and is not needed for inference. The encoder
    # is shared, so configuring it once covers both heads.
    encoder_config = model.sentiment_model.speech_encoder.config
    encoder_config.mask_time_prob = 0.0
    encoder_config.mask_feature_prob = 0.0

    # SECURITY: torch.load unpickles the file (older torch versions do so
    # without restriction) — never point this at untrusted weights.
    state_dict = torch.load(weights_path, map_location="cpu")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"Warning — missing keys when loading weights: {missing}")
    if unexpected:
        print(f"Warning — unexpected keys when loading weights: {unexpected}")

    # load_state_dict copies values into the existing parameters in place, so
    # the encoder tie made in SegueMultiTask.__init__ survives. Re-assigning
    # is a cheap safeguard that guarantees both heads use one encoder object.
    model.emotion_model.speech_encoder = model.sentiment_model.speech_encoder

    model = model.to(device)
    model.eval()
    return model, model.processor
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def segue_predict(
        model: SegueMultiTask,
        processor,
        chunks: List[np.ndarray],
        sampling_rate: int = TARGET_SAMPLE_RATE) -> Tuple[np.ndarray, np.ndarray]:
    """
    Run the model on a list of audio chunks and return softmax probabilities.

    Args:
        model: SegueMultiTask returned by load_segue_multitask()
        processor: processor returned by load_segue_multitask()
        chunks: list of 1-D float32 numpy arrays (mono audio). Each chunk is
            processed separately (batch size 1).
        sampling_rate: sample rate of the audio (model expects 16 000 Hz;
            pass the actual rate and it will be resampled if needed)
    Returns:
        sent_probs: np.ndarray (N, 3) softmax probabilities for sentiment
            columns: neutral / positive / negative
        emo_probs: np.ndarray (N, 7) softmax probabilities for emotion
            columns: neutral / surprise / fear / sadness / joy / disgust / anger
        For an empty `chunks`, both arrays have N = 0.
    """
    chunks = list(chunks)
    if not chunks:
        # torch.cat() raises on an empty list; return correctly-shaped empties.
        return (np.zeros((0, len(SENTIMENT_LABELS)), dtype=np.float32),
                np.zeros((0, len(EMOTION_LABELS)), dtype=np.float32))

    device = next(model.parameters()).device

    # Resample if the audio doesn't match the model's expected rate.
    if sampling_rate != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(sampling_rate, TARGET_SAMPLE_RATE)
        # Coerce to contiguous float32 first: torch.from_numpy rejects
        # negatively-strided arrays, and the cached Resample kernel is float32,
        # so float64 input would fail with a dtype mismatch.
        chunks = [
            resampler(
                torch.from_numpy(np.ascontiguousarray(c, dtype=np.float32)).unsqueeze(0)
            ).squeeze(0).numpy()
            for c in chunks]

    all_sent_logits = []
    all_emo_logits = []
    with torch.no_grad():
        for chunk in chunks:
            proc = processor(audio=chunk, sampling_rate=TARGET_SAMPLE_RATE)
            input_values = torch.tensor(
                proc["speech"]["input_values"], dtype=torch.float32
            ).unsqueeze(0).to(device)
            n_speech_tokens = [int(proc["n_speech_tokens"][0])]
            out = model(
                speech={"input_values": input_values},
                n_speech_tokens=n_speech_tokens)
            all_sent_logits.append(out["sentiment_predictions"].cpu())
            all_emo_logits.append(out["emotion_predictions"].cpu())

    sent_probs = torch.softmax(torch.cat(all_sent_logits, dim=0), dim=1).numpy()
    emo_probs = torch.softmax(torch.cat(all_emo_logits, dim=0), dim=1).numpy()
    return sent_probs, emo_probs
# ---------------------------------------------------------------------------
# Convenience: predict a single audio file
# ---------------------------------------------------------------------------
def predict_file(
        audio_path: str,
        weights_path: str = "model.pt",
        device: Optional[str] = None) -> dict:
    """
    Load the model and classify one audio file in a single call.

    Multi-channel audio is averaged down to mono first. The model is loaded
    afresh on every call; for many files, call load_segue_multitask() once
    and reuse it with segue_predict().

    Returns:
        dict with
            "sentiment": label -> probability
            "emotion": label -> probability
            "sentiment_score": float in [-1, 1] (prob_positive - prob_negative)
    """
    model, processor = load_segue_multitask(weights_path, device=device)

    waveform, sr = torchaudio.load(audio_path)
    # Down-mix to a single channel when the file has more than one.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    audio = waveform.squeeze(0).numpy().astype(np.float32)

    sent_probs, emo_probs = segue_predict(model, processor, [audio], sampling_rate=sr)
    sent_row = sent_probs[0]
    emo_row = emo_probs[0]

    sentiment = {name: float(sent_row[idx]) for idx, name in SENTIMENT_LABELS.items()}
    emotion = {name: float(emo_row[idx]) for idx, name in EMOTION_LABELS.items()}
    return {
        "sentiment": sentiment,
        "emotion": emotion,
        "sentiment_score": float(sent_row[1] - sent_row[2]),
    }
# ---------------------------------------------------------------------------
# Quick test when run directly
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python inference.py <audio_file> [model.pt]")
        sys.exit(1)

    audio_path = cli_args[0]
    weights_path = cli_args[1] if len(cli_args) >= 2 else "model.pt"

    print(f"Audio: {audio_path}")
    print(f"Weights: {weights_path}")
    print()

    result = predict_file(audio_path, weights_path)

    print("Sentiment probabilities:")
    for name, p in result["sentiment"].items():
        print(f" {name:10s}: {p:.4f}")
    print(f" → score (pos - neg): {result['sentiment_score']:+.4f}")

    print("\nEmotion probabilities:")
    for name, p in result["emotion"].items():
        print(f" {name:10s}: {p:.4f}")