File size: 4,017 Bytes
ac0d5f8
 
 
 
 
 
 
 
22bad87
ac0d5f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# -*- coding: utf-8 -*-
"""
Inference Module - Model Prediction
"""
import os
import torch

# Default location of the fine-tuned model/tokenizer; load_model() falls back
# to this when no explicit path is given.
# NOTE(review): a relative path like this is resolved against the process's
# current working directory, not this file's directory — confirm intended.
MODEL_SAVE_PATH = '../best_model'

# Human-readable emotion names; position i is the name for model class id i.
EMOTION_LABELS = [
    "Neutral", "Anxiety/Fear", "Anger/Frustration",
    "Sadness/Helplessness", "Confusion/Doubt", "Gratitude/Relief",
]

# transformers is an optional dependency. MODEL_LOADED only records whether
# the library could be imported; it says nothing about a model being loaded.
try:
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
except ImportError:
    MODEL_LOADED = False
else:
    MODEL_LOADED = True


class EmotionClassifier:
    """Emotion classification inference wrapper.

    Lazily loads a sequence-classification model and its tokenizer, and maps
    the model's output class ids onto ``EMOTION_LABELS`` by position. Public
    methods report failures as ``{'error': <message>}`` dicts instead of
    raising.
    """

    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.device = None
        self.loaded = False  # set True only by a fully successful load_model()

    def load_model(self, model_path=None):
        """Load tokenizer and model, move the model to GPU if available.

        Args:
            model_path: Path/identifier passed to ``from_pretrained``; defaults
                to ``MODEL_SAVE_PATH``.

        Returns:
            On success: ``{'success', 'device', 'num_labels', 'labels'}``.
            On failure: ``{'error': <message>}``. The instance's previous
            state (including ``loaded``) is left untouched.
        """
        if model_path is None:
            model_path = MODEL_SAVE_PATH

        if not MODEL_LOADED:
            return {'error': 'transformers library not installed'}

        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSequenceClassification.from_pretrained(model_path)

            # Predictions are mapped to EMOTION_LABELS by index, so the
            # classification head must have exactly one output per label;
            # otherwise predict() would hit an IndexError or silently report
            # a truncated probability distribution.
            model_num_labels = getattr(model.config, 'num_labels', None)
            if model_num_labels is not None and model_num_labels != len(EMOTION_LABELS):
                return {
                    'error': (
                        f'Failed to load model: model has {model_num_labels} '
                        f'output labels but {len(EMOTION_LABELS)} are defined'
                    )
                }

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.to(device)
            model.eval()
        except Exception as e:
            return {'error': f'Failed to load model: {str(e)}'}

        # Publish the new state only after everything succeeded, so a failed
        # reload never leaves a mismatched tokenizer/model pair marked loaded.
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
        self.loaded = True

        return {
            'success': True,
            'device': str(device),
            'num_labels': len(EMOTION_LABELS),
            'labels': EMOTION_LABELS,
        }

    def _forward(self, text_input, max_length):
        """Tokenize ``text_input`` and run the model without gradients.

        ``text_input`` is passed to the tokenizer unchanged (a string or a
        list of strings); inputs are padded and truncated to ``max_length``
        tokens. Returns ``(logits, probabilities)``, both on ``self.device``.
        """
        inputs = self.tokenizer(
            text_input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = self.model(**inputs).logits
            probabilities = torch.softmax(logits, dim=-1)
        return logits, probabilities

    @staticmethod
    def _build_result(text, probs_row, predicted_class):
        """Format one prediction (1-D probability tensor + class id) as a result dict."""
        all_probs = probs_row.cpu().tolist()
        label_probs = [
            {'label': label, 'probability': round(all_probs[i], 4)}
            for i, label in enumerate(EMOTION_LABELS)
        ]
        return {
            'text': (text[:100] + '...') if len(text) > 100 else text,
            'predicted_label': EMOTION_LABELS[predicted_class],
            'predicted_id': predicted_class,
            'confidence': round(all_probs[predicted_class], 4),
            'all_probabilities': label_probs,
        }

    def predict(self, text, max_length=512):
        """Predict the emotion of a single text, loading the model on first use.

        Returns:
            ``{'text', 'predicted_label', 'predicted_id', 'confidence',
            'all_probabilities'}``: ``text`` is cut to 100 chars + '...',
            probabilities are rounded to 4 places. Returns
            ``{'error': <message>}`` on load or prediction failure.
        """
        if not self.loaded:
            result = self.load_model()
            if 'error' in result:
                return result

        try:
            logits, probabilities = self._forward(text, max_length)
            # .item() requires exactly one prediction; an input that expands
            # to several sequences raises here and becomes an error dict.
            predicted_class = torch.argmax(logits, dim=-1).item()
            return self._build_result(text, probabilities[0], predicted_class)
        except Exception as e:
            return {'error': f'Prediction failed: {str(e)}'}

    def _predict_chunk(self, chunk, max_length):
        """Predict a list of texts in one forward pass, falling back to per-text predict()."""
        if all(isinstance(item, str) for item in chunk):
            try:
                logits, probabilities = self._forward(chunk, max_length)
                predicted = torch.argmax(logits, dim=-1).tolist()
                return [
                    self._build_result(text, probabilities[row], predicted[row])
                    for row, text in enumerate(chunk)
                ]
            except Exception:
                # Deliberate fallback: re-run text by text below so a failure
                # is attributed only to the text(s) that actually cause it.
                pass
        return [self.predict(text, max_length) for text in chunk]

    def predict_batch(self, texts, max_length=512, batch_size=16):
        """Predict emotions for many texts, ``batch_size`` texts per forward pass.

        Args:
            texts: Iterable of texts (consumed once).
            max_length: Maximum tokens per text; longer texts are truncated.
            batch_size: Texts per forward pass; values below 1 are treated as 1.

        Returns:
            A list of ``predict()``-style results in input order, or a single
            ``{'error': ...}`` dict if the model could not be loaded.

        A chunk with non-string items, or whose batched pass fails, is re-run
        through ``predict()`` one text at a time, so a bad item only yields
        its own error dict. Padding inside a batch may shift probabilities at
        float-rounding level relative to single-text calls.
        """
        if not self.loaded:
            result = self.load_model()
            if 'error' in result:
                return result

        texts = list(texts)
        step = max(1, int(batch_size))
        results = []
        for start in range(0, len(texts), step):
            results.extend(self._predict_chunk(texts[start:start + step], max_length))
        return results

    def is_loaded(self):
        """Return True once a model has been successfully loaded."""
        return self.loaded


# Module-wide shared classifier; created on the first get_classifier() call.
_classifier_instance = None


def get_classifier():
    """Return the shared EmotionClassifier, constructing it on first call.

    Only the wrapper object is created here; loading the model is left to
    the classifier itself. Later calls return the same object.
    """
    global _classifier_instance
    instance = _classifier_instance
    if instance is None:
        instance = EmotionClassifier()
        _classifier_instance = instance
    return instance