File size: 1,971 Bytes
9be21ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# models/emotion_classifier.py

from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline


class EmotionClassifier:
    """
    Wrapper around a pre-trained GoEmotions RoBERTa model.
    Uses: SamLowe/roberta-base-go_emotions

    The tokenizer, model, and a ``TextClassificationPipeline`` configured to
    score every label are built once in ``__init__`` and reused by
    ``predict_emotions``.
    """

    def __init__(self, model_name: str = "SamLowe/roberta-base-go_emotions"):
        """
        Load the tokenizer and model for ``model_name`` and build the pipeline.

        Args:
            model_name: Model identifier (or local path) passed to
                ``from_pretrained`` for both the tokenizer and the model.
        """
        print("[EmotionClassifier] Loading model... This may take a moment the first time.")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # `top_k=None` preserves the old "all scores" behavior without the deprecation warning.
        self.pipeline = TextClassificationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            top_k=None,
        )
        print("[EmotionClassifier] Model loaded successfully.")

    def predict_emotions(self, text: str, top_k: "int | None" = 3) -> "list[dict]":
        """
        Predict the ``top_k`` highest-scoring emotions for ``text``.

        Args:
            text: Input text. Empty, ``None``, or whitespace-only input returns
                ``[]`` without running the model.
            top_k: Number of labels to return. ``None`` returns every label.

        Returns:
            A list of ``{"label": ..., "score": ...}`` dicts, highest score first.

        Raises:
            ValueError: If ``top_k`` is negative.
        """
        if top_k is not None and top_k < 0:
            # A negative slice bound would silently drop labels from the *end*
            # of the ranking instead of meaning anything sensible.
            raise ValueError(f"top_k must be a non-negative int or None, got {top_k!r}")
        if not text or not text.strip():
            return []

        # truncation=True is forwarded to the tokenizer so over-long inputs are cut
        # to the model's maximum length instead of crashing the forward pass.
        raw = self.pipeline(text, truncation=True)
        # For a single string the pipeline may return the per-label scores wrapped
        # in an extra list ([[...]]) or flat ([...]), depending on the transformers
        # version and on how top_k was supplied; accept both shapes.
        scores = raw[0] if raw and isinstance(raw[0], list) else raw
        # Sort by score descending (defensive: do not rely on pipeline ordering).
        ranked = sorted(scores, key=lambda item: item["score"], reverse=True)
        return ranked[:top_k]


# Simple test code so you can run this file directly
if __name__ == "__main__":
    # Quick smoke test: classify a few sample sentences and show the five
    # strongest emotions found in each one.
    classifier = EmotionClassifier()

    samples = (
        "I feel really scared because my period is very late.",
        "I'm so happy that my cycle is finally regular.",
        "I'm embarrassed to talk about my period with anyone.",
    )

    for sentence in samples:
        print(f"\nText: {sentence}")
        for prediction in classifier.predict_emotions(sentence, top_k=5):
            label, score = prediction["label"], prediction["score"]
            print(f"  {label}: {score:.3f}")