File size: 5,949 Bytes
82c705b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import os
import sys
from transformers import pipeline
from typing import Tuple

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config import config
from logging_config import get_logger

# Initialize logger
logger = get_logger('model')


class ModelError(Exception):
    """Custom exception for model-related errors."""
    pass


class SentimentAnalyzer:
    """Sentiment analysis model wrapper with error handling and caching."""
    
    def __init__(self):
        self.pipeline = None
        self.model_name = config.MODEL_NAME
        self._load_model()
    
    def _load_model(self):
        """Load the sentiment analysis model with error handling and fallback."""
        try:
            logger.info(f"Loading sentiment analysis model: {self.model_name}")
            
            # Try loading the primary model first
            try:
                self.pipeline = pipeline(
                    "sentiment-analysis",
                    model=self.model_name,
                    top_k=1
                )
                logger.info("Primary model loaded successfully")
                
            except Exception as primary_error:
                logger.warning(f"Primary model failed to load: {primary_error}")
                logger.info("Trying fallback model: distilbert-base-uncased-finetuned-sst-2-english")
                
                # Fallback to a reliable model
                fallback_model = "distilbert-base-uncased-finetuned-sst-2-english"
                self.pipeline = pipeline(
                    "sentiment-analysis",
                    model=fallback_model,
                    top_k=1
                )
                self.model_name = fallback_model  # Update model name for logging
                logger.info("Fallback model loaded successfully")
            
            # Test the model with a simple prediction
            test_result = self.pipeline("This is a test.")
            logger.debug(f"Model test successful: {test_result}")
            
        except Exception as e:
            logger.error(f"Failed to load any sentiment analysis model: {e}")
            raise ModelError(f"Could not load sentiment analysis model: {e}")
    
    def predict(self, text: str) -> Tuple[str, float]:
        """
        Predict sentiment for given text.
        
        Args:
            text: Input text to analyze
            
        Returns:
            Tuple of (sentiment_label, confidence_score)
            
        Raises:
            ModelError: If prediction fails
        """
        try:
            if not self.pipeline:
                raise ModelError("Model not loaded")
            
            logger.debug(f"Running sentiment prediction on text of length {len(text)}")
            
            # Run prediction
            output = self.pipeline(text)
            
            if not output or len(output) == 0:
                raise ModelError("Model returned empty prediction")
            
            # Handle different output formats from different models
            if isinstance(output[0], list):
                # Some models return nested lists
                result = output[0][0] if output[0] else output[0]
            else:
                # Standard format
                result = output[0]
            
            raw_label = result["label"]
            score = result["score"]
            
            # Map model labels to human-readable labels
            sentiment = self._map_sentiment_label(raw_label)
            
            logger.debug(f"Prediction completed: {sentiment} (confidence: {score:.3f})")
            
            return sentiment, float(score)
            
        except Exception as e:
            logger.error(f"Prediction failed: {e}")
            raise ModelError(f"Sentiment prediction failed: {e}")
    
    def _map_sentiment_label(self, label: str) -> str:
        """
        Map model output labels to human-readable sentiment labels.
        
        Args:
            label: Raw label from model
            
        Returns:
            Human-readable sentiment label
        """
        label_mapping = {
            # Original model labels (fitsblb/YelpReviewsAnalyzer)
            "LABEL_0": "Negative",
            "LABEL_1": "Neutral", 
            "LABEL_2": "Positive",
            # Standard model labels (distilbert-base-uncased-finetuned-sst-2-english)
            "NEGATIVE": "Negative",
            "POSITIVE": "Positive",
            # Generic fallbacks
            "NEUTRAL": "Neutral"
        }
        
        mapped_label = label_mapping.get(label, "Unknown")
        
        if mapped_label == "Unknown":
            logger.warning(f"Unknown label received from model: {label}")
            # If it's an unknown label, try to infer from the label string
            label_lower = label.lower()
            if 'neg' in label_lower:
                mapped_label = "Negative"
            elif 'pos' in label_lower:
                mapped_label = "Positive"
            elif 'neu' in label_lower:
                mapped_label = "Neutral"
            else:
                mapped_label = "Neutral"  # Default fallback
        
        return mapped_label


# Global model instance
_sentiment_analyzer = None


def get_model() -> SentimentAnalyzer:
    """Get or create the global sentiment analyzer instance."""
    global _sentiment_analyzer
    
    if _sentiment_analyzer is None:
        _sentiment_analyzer = SentimentAnalyzer()
    
    return _sentiment_analyzer


def predict(text: str) -> Tuple[str, float]:
    """
    Convenience function for sentiment prediction.
    
    Args:
        text: Input text to analyze
        
    Returns:
        Tuple of (sentiment_label, confidence_score)
        
    Raises:
        ModelError: If prediction fails
    """
    model = get_model()
    return model.predict(text)