Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from textblob import TextBlob | |
| import numpy as np | |
| from typing import Dict, List | |
| import logging | |
# Module-scoped logger; its name is this module's import path.
logger = logging.getLogger(name=__name__)
class EmotionDetector:
    """
    Emotion detection using a pre-trained transformer sequence classifier.

    By default this loads ``j-hartmann/emotion-english-distilroberta-base``
    for multi-class emotion classification and pairs the resulting class
    probabilities with a TextBlob sentiment polarity score.
    """

    # Hugging Face model id loaded when the caller does not pass one.
    DEFAULT_MODEL_NAME = "j-hartmann/emotion-english-distilroberta-base"
    # Fallback label order, used only when the model config carries no usable
    # ``id2label`` names (see ``_labels_from_config``).
    DEFAULT_EMOTION_LABELS = (
        'anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise',
    )
    # The absolute polarity must exceed this for a "positive"/"negative" label.
    SENTIMENT_THRESHOLD = 0.1

    def __init__(self, model_name: str = DEFAULT_MODEL_NAME):
        """
        Load the tokenizer and classification model onto the best device.

        Args:
            model_name: Hugging Face model id or local path of a sequence
                classification model. Defaults to ``DEFAULT_MODEL_NAME``.

        Raises:
            Exception: Any error raised while loading the tokenizer or model
                is logged with its traceback and re-raised unchanged.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info("Using device: %s", self.device)
        self.model_name = model_name
        try:
            logger.info("Loading model: %s", model_name)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            self.model.to(self.device)
            self.model.eval()
            # Take label names from the model config itself so the label
            # list and the logits can never drift out of order.
            self.emotion_labels = self._labels_from_config(self.model.config)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.exception("Error loading model: %s", e)
            raise

    @classmethod
    def _labels_from_config(cls, config) -> List[str]:
        """
        Return lower-cased class names from ``config.id2label``, ordered by index.

        Falls back to ``DEFAULT_EMOTION_LABELS`` when ``config`` has no
        ``id2label`` mapping, when its indices are not contiguous from 0, or
        when every name is a generic ``LABEL_<n>`` placeholder.
        """
        id2label = getattr(config, "id2label", None) or {}
        try:
            by_index = {int(idx): str(name) for idx, name in id2label.items()}
            labels = [by_index[i].lower() for i in range(len(by_index))]
        except (AttributeError, KeyError, TypeError, ValueError):
            return list(cls.DEFAULT_EMOTION_LABELS)
        if not labels or all(label.startswith("label_") for label in labels):
            return list(cls.DEFAULT_EMOTION_LABELS)
        return labels

    def detect_emotion(self, text: str) -> Dict:
        """
        Detect emotions from a single piece of text.

        Args:
            text: Input text to analyze. Empty or whitespace-only text yields
                the empty result without running the model.

        Returns:
            Dictionary with keys ``emotions`` (label -> probability),
            ``dominant_emotion``, ``dominant_score``, ``sentiment_score``,
            ``sentiment_label``, ``text_length`` and ``word_count``.
            Any error during inference is logged and the empty result is
            returned instead (best-effort behavior).
        """
        if not text or not text.strip():
            return self._empty_result()
        try:
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True,
            ).to(self.device)
            with torch.no_grad():
                logits = self.model(**inputs).logits
                predictions = torch.nn.functional.softmax(logits, dim=-1)
            # One input string -> first (only) row of the batch.
            probs = predictions[0].cpu().numpy()
            emotions = {
                label: float(prob) for label, prob in zip(self.emotion_labels, probs)
            }
            dominant_idx = int(np.argmax(probs))
            sentiment = self._get_sentiment(text)
            return {
                "emotions": emotions,
                "dominant_emotion": self.emotion_labels[dominant_idx],
                "dominant_score": float(probs[dominant_idx]),
                "sentiment_score": sentiment['polarity'],
                "sentiment_label": sentiment['label'],
                "text_length": len(text),
                "word_count": len(text.split()),
            }
        except Exception as e:
            logger.exception("Error detecting emotion: %s", e)
            return self._empty_result()

    @classmethod
    def _polarity_label(cls, polarity: float) -> str:
        """Map a polarity score to "positive", "negative" or "neutral"."""
        if polarity > cls.SENTIMENT_THRESHOLD:
            return "positive"
        if polarity < -cls.SENTIMENT_THRESHOLD:
            return "negative"
        return "neutral"

    def _get_sentiment(self, text: str) -> Dict:
        """
        Get sentiment polarity using TextBlob.

        Args:
            text: Input text.

        Returns:
            Dictionary with ``polarity`` (float) and ``label``. If TextBlob
            fails, the failure is logged and a neutral 0.0 result is returned.
        """
        try:
            polarity = float(TextBlob(text).sentiment.polarity)
        except Exception as e:
            # Best-effort: sentiment is auxiliary, so degrade to neutral
            # rather than failing the whole emotion result.
            logger.warning("Sentiment analysis failed: %s", e)
            return {"polarity": 0.0, "label": "neutral"}
        return {"polarity": polarity, "label": self._polarity_label(polarity)}

    def aggregate_emotions(self, results: List[Dict]) -> Dict:
        """
        Aggregate emotions from multiple text analyses.

        Args:
            results: Per-text results as produced by ``detect_emotion``.
                Each must contain ``emotions`` and ``sentiment_score``;
                ``word_count`` is optional (treated as 0).

        Returns:
            Mean emotion probabilities, the dominant emotion, mean sentiment
            with its label, plus ``total_texts``, ``total_words`` and
            ``avg_words_per_text``. An empty ``results`` list yields the
            neutral empty result extended with those totals set to 0.

        Raises:
            KeyError: If a result lacks a required key or reports an emotion
                label not in ``self.emotion_labels``.
        """
        if not results:
            # Keep the aggregate schema even when there is nothing to aggregate.
            return {
                **self._empty_result(),
                "total_texts": 0,
                "total_words": 0,
                "avg_words_per_text": 0,
            }

        emotion_sums = dict.fromkeys(self.emotion_labels, 0.0)
        sentiment_sum = 0.0
        total_words = 0
        for result in results:
            for emotion, score in result['emotions'].items():
                emotion_sums[emotion] += score
            sentiment_sum += result['sentiment_score']
            total_words += result.get('word_count', 0)

        n = len(results)
        emotions_avg = {label: score / n for label, score in emotion_sums.items()}
        dominant_emotion = max(emotions_avg, key=emotions_avg.get)
        avg_sentiment = sentiment_sum / n
        return {
            "emotions": emotions_avg,
            "dominant_emotion": dominant_emotion,
            "dominant_score": emotions_avg[dominant_emotion],
            "sentiment_score": avg_sentiment,
            "sentiment_label": self._polarity_label(avg_sentiment),
            "total_texts": n,
            "total_words": total_words,
            "avg_words_per_text": total_words / n,
        }

    def _empty_result(self) -> Dict:
        """Return an all-zero, neutral result with the per-text schema."""
        return {
            "emotions": {label: 0.0 for label in self.emotion_labels},
            "dominant_emotion": "neutral",
            "dominant_score": 0.0,
            "sentiment_score": 0.0,
            "sentiment_label": "neutral",
            "text_length": 0,
            "word_count": 0,
        }