# AIMirror-Backend / emotion_detector.py
# Zayeemk's picture
# Upload 21 files
# 69aa668 verified
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from textblob import TextBlob
import numpy as np
from typing import Dict, List
import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class EmotionDetector:
    """
    Emotion detection using a pre-trained transformer classifier.

    Supports multi-class emotion classification and combines the
    classifier output with a TextBlob polarity score.

    Attributes set by ``__init__``:
        device: torch device the model runs on (CUDA if available, else CPU).
        tokenizer: tokenizer loaded for the emotion model.
        model: sequence-classification model, in eval mode.
        emotion_labels: label names, in the order they are paired with the
            model's output probabilities.
    """

    def __init__(self):
        """
        Initialize the emotion detection model.

        Raises:
            Exception: Whatever the tokenizer/model loading raises; the
                error is logged and then re-raised unchanged.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", self.device)

        # Load pre-trained emotion classification model
        model_name = "j-hartmann/emotion-english-distilroberta-base"
        try:
            logger.info("Loading model: %s", model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            self.model.to(self.device)
            self.model.eval()

            # Emotion labels for this model. They are zipped positionally
            # with the output probabilities in detect_emotion, so the order
            # here must match the model's class order.
            self.emotion_labels = [
                'anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise'
            ]
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error("Error loading model: %s", e)
            raise

    def detect_emotion(self, text: str) -> Dict:
        """
        Detect emotions from text.

        Args:
            text: Input text to analyze.

        Returns:
            Dictionary with keys ``emotions`` (label -> probability),
            ``dominant_emotion``, ``dominant_score``, ``sentiment_score``,
            ``sentiment_label``, ``text_length`` and ``word_count``.
            For empty, whitespace-only or non-string input, or if inference
            fails, the structure from ``_empty_result`` is returned instead.
        """
        # Guard before the try block: blank or non-string input never
        # reaches the tokenizer.
        if not isinstance(text, str) or not text.strip():
            return self._empty_result()

        try:
            # Tokenize input (truncated to at most 512 tokens)
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True,
            ).to(self.device)

            # Get predictions without tracking gradients
            with torch.no_grad():
                outputs = self.model(**inputs)
                predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

            # Convert to probabilities (single input -> first row)
            probs = predictions[0].cpu().numpy()

            # Create emotion dictionary
            emotions = {
                label: float(prob)
                for label, prob in zip(self.emotion_labels, probs)
            }

            # Get dominant emotion
            dominant_idx = int(np.argmax(probs))
            dominant_emotion = self.emotion_labels[dominant_idx]
            dominant_score = float(probs[dominant_idx])

            # Get sentiment using TextBlob
            sentiment = self._get_sentiment(text)

            return {
                "emotions": emotions,
                "dominant_emotion": dominant_emotion,
                "dominant_score": dominant_score,
                "sentiment_score": sentiment['polarity'],
                "sentiment_label": sentiment['label'],
                "text_length": len(text),
                "word_count": len(text.split()),
            }
        except Exception as e:
            # Best-effort API: log (with traceback) and fall back to the
            # empty result rather than propagating to the caller.
            logger.exception("Error detecting emotion: %s", e)
            return self._empty_result()

    def _get_sentiment(self, text: str) -> Dict:
        """
        Get sentiment polarity using TextBlob.

        Args:
            text: Input text.

        Returns:
            Dictionary with ``polarity`` (float) and ``label``
            (``"positive"`` above 0.1, ``"negative"`` below -0.1,
            otherwise ``"neutral"``). Falls back to a neutral result if
            the analysis raises.
        """
        try:
            polarity = float(TextBlob(text).sentiment.polarity)
        except Exception as e:
            # Sentiment is auxiliary; degrade to neutral instead of failing
            # the whole detection, but leave a trace in the logs.
            logger.warning("Sentiment analysis failed: %s", e)
            return {"polarity": 0.0, "label": "neutral"}

        # Classify sentiment
        if polarity > 0.1:
            label = "positive"
        elif polarity < -0.1:
            label = "negative"
        else:
            label = "neutral"

        return {"polarity": polarity, "label": label}

    def aggregate_emotions(self, results: List[Dict]) -> Dict:
        """
        Aggregate emotions from multiple text analyses.

        Args:
            results: List of emotion detection results, as returned by
                ``detect_emotion``.

        Returns:
            Aggregated statistics with keys ``emotions`` (per-label mean),
            ``dominant_emotion``, ``dominant_score``, ``sentiment_score``
            (mean), ``sentiment_label``, ``total_texts``, ``total_words``
            and ``avg_words_per_text``. For an empty ``results`` list the
            empty structure is returned, extended with zeroed aggregate
            keys so the shape matches the non-empty case.
        """
        if not results:
            return {
                **self._empty_result(),
                "total_texts": 0,
                "total_words": 0,
                "avg_words_per_text": 0,
            }

        # Initialize aggregation
        emotion_sums = {label: 0.0 for label in self.emotion_labels}
        sentiment_sum = 0.0
        total_words = 0

        # Aggregate. Only known labels are summed; labels missing from a
        # result count as 0.0 and unknown labels are ignored, so a malformed
        # entry cannot raise KeyError.
        for result in results:
            scores = result.get('emotions') or {}
            for label in emotion_sums:
                emotion_sums[label] += scores.get(label, 0.0)
            sentiment_sum += result.get('sentiment_score', 0.0)
            total_words += result.get('word_count', 0)

        # Calculate averages
        n = len(results)
        emotions_avg = {label: total / n for label, total in emotion_sums.items()}

        # Get dominant emotion
        dominant_emotion = max(emotions_avg, key=emotions_avg.get)
        dominant_score = emotions_avg[dominant_emotion]
        if dominant_score <= 0.0:
            # All-zero scores (e.g. only empty/failed results): max() would
            # pick the first label arbitrarily; report neutral instead, as
            # _empty_result does.
            dominant_emotion = "neutral"

        # Average sentiment
        avg_sentiment = sentiment_sum / n
        if avg_sentiment > 0.1:
            sentiment_label = "positive"
        elif avg_sentiment < -0.1:
            sentiment_label = "negative"
        else:
            sentiment_label = "neutral"

        return {
            "emotions": emotions_avg,
            "dominant_emotion": dominant_emotion,
            "dominant_score": dominant_score,
            "sentiment_score": avg_sentiment,
            "sentiment_label": sentiment_label,
            "total_texts": n,
            "total_words": total_words,
            "avg_words_per_text": total_words / n,
        }

    def _empty_result(self) -> Dict:
        """Return the empty result structure: zero scores and neutral labels."""
        return {
            "emotions": {label: 0.0 for label in self.emotion_labels},
            "dominant_emotion": "neutral",
            "dominant_score": 0.0,
            "sentiment_score": 0.0,
            "sentiment_label": "neutral",
            "text_length": 0,
            "word_count": 0,
        }