github-actions[bot]
Deploy from GitHub Actions (commit: 8b247ffacd77c0672965b8378f1d52a7dcd187ae)
9366995
| """ | |
| Emotion Evaluator | |
| Analyzes user emotions using j-hartmann/emotion-english-distilroberta-base model. | |
| Calculates negative emotion sum, joy/neutral shift, and tracks emotion change trends. | |
| """ | |
| from typing import List, Dict, Optional | |
| import torch | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| import torch.nn.functional as F | |
| from evaluators.base import Evaluator | |
| from evaluators.registry import register_evaluator | |
| from custom_types import Utterance, EvaluationResult | |
| from utils.evaluation_helpers import create_numerical_score, create_categorical_score, create_utterance_result | |
class EmotionEvaluator(Evaluator):
    """Evaluator for emotion analysis using j-hartmann/emotion-english-distilroberta-base.

    For every user utterance the evaluator predicts a 7-way emotion
    distribution, derives two scalar metrics (sum of negative-emotion
    probabilities and the joy-minus-neutral shift), and finally compares the
    first and second half of the conversation to label the overall trend as
    "improving", "declining", "stable", or "neutral".
    """

    METRIC_NAME = "emotion_analysis"
    MODEL_NAME = "j-hartmann/emotion-english-distilroberta-base"

    # Emotion labels in the order the model outputs them (fallback if the
    # loaded model config does not provide id2label).
    EMOTION_LABELS = ["anger", "disgust", "fear", "joy", "neutral", "sadness", "surprise"]
    # Negative emotions to sum
    NEGATIVE_EMOTIONS = ["anger", "disgust", "fear", "sadness"]
    # User role identifiers (matched case-insensitively against utterance "speaker")
    USER_ROLES = {"patient", "seeker", "client", "user"}

    def __init__(
        self,
        api_keys: Optional[Dict[str, str]] = None,
        api_key: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize Emotion Evaluator.

        Args:
            api_keys: Dict of API keys (not used for local model, kept for interface consistency)
            api_key: Single API key (not used for local model, kept for interface consistency)
            **kwargs: Additional arguments (ignored)

        Raises:
            RuntimeError: If the model or tokenizer cannot be loaded.
        """
        super().__init__()
        self.tokenizer = None
        self.model = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._load_model()

    def _load_model(self):
        """Load the model and tokenizer onto ``self.device`` in eval mode.

        Raises:
            RuntimeError: Wraps any underlying load failure, preserving the
                original exception as ``__cause__``.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.MODEL_NAME)
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(f"Failed to load {self.MODEL_NAME}: {e}") from e

    def _predict_emotions(self, text: str) -> Dict[str, float]:
        """
        Predict emotion scores for a single text.

        Args:
            text: The text to analyze (truncated to 512 tokens).

        Returns:
            Dict mapping emotion labels to plain-float probability scores
            (softmax over the model logits, so values sum to ~1.0).
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            scores = F.softmax(outputs.logits, dim=1)
        scores = scores.cpu().numpy()[0]

        # Prefer the label order declared by the model config; fall back to
        # the hardcoded EMOTION_LABELS constant if the config lacks id2label.
        id2label = getattr(self.model.config, "id2label", None)
        if id2label and len(id2label) == len(scores):
            labels = [id2label[i] for i in range(len(scores))]
        else:
            labels = self.EMOTION_LABELS
        # Cast numpy float32 scalars to Python floats so results are
        # JSON-serializable downstream.
        return {label: float(score) for label, score in zip(labels, scores)}

    def _calculate_metrics(self, emotion_scores: Dict[str, float]) -> Dict[str, float]:
        """
        Calculate negative emotion sum and joy/neutral shift.

        Args:
            emotion_scores: Dict of emotion label -> probability

        Returns:
            Dict with 'sum_negative', 'joy_neutral_shift', and the raw
            'emotion_scores' passed through for downstream inspection.
        """
        sum_negative = sum(emotion_scores[emotion] for emotion in self.NEGATIVE_EMOTIONS)
        # Positive shift means joy dominates neutral; negative means the reverse.
        joy_neutral_shift = emotion_scores["joy"] - emotion_scores["neutral"]
        return {
            "sum_negative": sum_negative,
            "joy_neutral_shift": joy_neutral_shift,
            "emotion_scores": emotion_scores
        }

    @staticmethod
    def _mean(metrics: List[Dict[str, float]], key: str) -> float:
        """Return the mean of ``metrics[i][key]``; caller guarantees non-empty list."""
        return sum(m[key] for m in metrics) / len(metrics)

    def _analyze_trend(self, all_metrics: List[Dict[str, float]]) -> Dict[str, float]:
        """
        Analyze emotion change trend across all user utterances.

        Compares the first half of the conversation against the second half:
        a drop in negative emotions together with a rise in the joy/neutral
        shift (both beyond a 0.1 threshold) counts as "improving"; the
        opposite as "declining"; anything else as "stable". With fewer than
        two utterances no comparison is possible and the trend is "neutral".

        Args:
            all_metrics: List of metric dicts from all user utterances

        Returns:
            Dict with 'avg_sum_negative', 'avg_joy_neutral_shift', and
            'trend_direction' in {"improving", "declining", "stable", "neutral"}.
        """
        if not all_metrics:
            return {
                "avg_sum_negative": 0.0,
                "avg_joy_neutral_shift": 0.0,
                "trend_direction": "neutral"
            }

        avg_sum_negative = self._mean(all_metrics, "sum_negative")
        avg_joy_neutral_shift = self._mean(all_metrics, "joy_neutral_shift")

        mid_point = len(all_metrics) // 2
        if mid_point > 0:
            first_half, second_half = all_metrics[:mid_point], all_metrics[mid_point:]
            negative_change = (self._mean(second_half, "sum_negative")
                               - self._mean(first_half, "sum_negative"))
            shift_change = (self._mean(second_half, "joy_neutral_shift")
                            - self._mean(first_half, "joy_neutral_shift"))

            if negative_change < -0.1 and shift_change > 0.1:
                trend_direction = "improving"
            elif negative_change > 0.1 and shift_change < -0.1:
                trend_direction = "declining"
            else:
                trend_direction = "stable"
        else:
            # Single utterance: nothing to compare against.
            trend_direction = "neutral"

        return {
            "avg_sum_negative": avg_sum_negative,
            "avg_joy_neutral_shift": avg_joy_neutral_shift,
            "trend_direction": trend_direction
        }

    def execute(self, conversation: List[Utterance], **kwargs) -> EvaluationResult:
        """
        Evaluate emotions for each user utterance in the conversation.

        Non-user utterances receive an empty score dict so per-utterance
        indices stay aligned with the conversation.

        Args:
            conversation: List of utterances with 'speaker' and 'text'

        Returns:
            EvaluationResult with per-utterance scores and, when at least one
            user utterance exists, an "overall" entry with averaged metrics
            and the trend direction.
        """
        scores_per_utterance = []
        user_metrics = []  # Metrics from user utterances only, for trend analysis

        for utt in conversation:
            if utt["speaker"].lower() not in self.USER_ROLES:
                # Not a user utterance: keep index alignment with an empty entry.
                scores_per_utterance.append({})
                continue

            emotion_scores = self._predict_emotions(utt["text"])
            metrics = self._calculate_metrics(emotion_scores)
            user_metrics.append(metrics)

            scores_per_utterance.append({
                "emotion_sum_negative": create_numerical_score(
                    value=metrics["sum_negative"],
                    max_value=1.0,
                    label=self._get_label_for_negative(metrics["sum_negative"])
                ),
                "emotion_joy_neutral_shift": create_numerical_score(
                    value=metrics["joy_neutral_shift"],
                    max_value=1.0,
                    label=self._get_label_for_shift(metrics["joy_neutral_shift"])
                )
            })

        trend = self._analyze_trend(user_metrics)
        result = create_utterance_result(conversation, scores_per_utterance)

        # Attach overall trend information only when there was user input.
        if user_metrics:
            result["overall"] = {
                "emotion_avg_sum_negative": create_numerical_score(
                    value=trend["avg_sum_negative"],
                    max_value=1.0,
                    label=trend["trend_direction"]
                ),
                "emotion_avg_joy_neutral_shift": create_numerical_score(
                    value=trend["avg_joy_neutral_shift"],
                    max_value=1.0,
                    label=trend["trend_direction"]
                ),
                "emotion_trend_direction": create_categorical_score(
                    label=trend["trend_direction"],
                    confidence=None
                )
            }
        return result

    def _get_label_for_negative(self, value: float) -> str:
        """Bucket a negative-emotion sum into "Low" (<0.2), "Medium" (<0.5), or "High"."""
        if value < 0.2:
            return "Low"
        elif value < 0.5:
            return "Medium"
        else:
            return "High"

    def _get_label_for_shift(self, value: float) -> str:
        """Bucket a joy/neutral shift into "Positive" (>0.2), "Negative" (<-0.2), or "Neutral"."""
        if value > 0.2:
            return "Positive"
        elif value < -0.2:
            return "Negative"
        else:
            return "Neutral"