Spaces:
Sleeping
Sleeping
File size: 5,949 Bytes
82c705b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import os
import sys
from transformers import pipeline
from typing import Tuple
# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import config
from logging_config import get_logger
# Initialize logger
logger = get_logger('model')
class ModelError(Exception):
    """Raised when the sentiment model cannot be loaded or a prediction fails."""
class SentimentAnalyzer:
    """Wrapper around a ``transformers`` sentiment-analysis pipeline.

    The model named by ``config.MODEL_NAME`` is loaded on construction; if
    that fails, ``FALLBACK_MODEL`` is loaded instead. Raw model labels are
    translated to "Negative" / "Neutral" / "Positive".

    Attributes:
        pipeline: The loaded pipeline callable, or ``None`` before loading.
        model_name: Name of the model that was actually loaded.
    """

    # Model used when the configured one cannot be loaded.
    FALLBACK_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"

    # Upper-cased raw model label -> human-readable sentiment.
    _LABEL_MAPPING = {
        # Original model labels (fitsblb/YelpReviewsAnalyzer)
        "LABEL_0": "Negative",
        "LABEL_1": "Neutral",
        "LABEL_2": "Positive",
        # Standard model labels (distilbert-base-uncased-finetuned-sst-2-english)
        "NEGATIVE": "Negative",
        "POSITIVE": "Positive",
        # Generic fallbacks
        "NEUTRAL": "Neutral",
    }

    # Substring hints tried, in order, for labels missing from the table.
    _LABEL_HINTS = (
        ("neg", "Negative"),
        ("pos", "Positive"),
        ("neu", "Neutral"),
    )

    def __init__(self):
        """Record the configured model name and load the model immediately.

        Raises:
            ModelError: If neither the configured nor the fallback model loads.
        """
        self.pipeline = None
        self.model_name = config.MODEL_NAME
        self._load_model()

    @staticmethod
    def _build_pipeline(model_name):
        """Create a top-1 sentiment-analysis pipeline for ``model_name``."""
        return pipeline(
            "sentiment-analysis",
            model=model_name,
            top_k=1,
        )

    def _load_model(self):
        """Load the configured model, falling back to ``FALLBACK_MODEL``.

        After loading, a single test sentence is run through the pipeline so
        that a broken model is detected here rather than on the first request.

        Raises:
            ModelError: If loading (including the fallback) or the test
                prediction fails; the underlying exception is chained.
        """
        try:
            logger.info(f"Loading sentiment analysis model: {self.model_name}")
            # Try loading the primary model first
            try:
                self.pipeline = self._build_pipeline(self.model_name)
                logger.info("Primary model loaded successfully")
            except Exception as primary_error:
                logger.warning(f"Primary model failed to load: {primary_error}")
                logger.info(f"Trying fallback model: {self.FALLBACK_MODEL}")
                self.pipeline = self._build_pipeline(self.FALLBACK_MODEL)
                # Keep model_name truthful about what is actually loaded
                self.model_name = self.FALLBACK_MODEL
                logger.info("Fallback model loaded successfully")
            # Test the model with a simple prediction
            test_result = self.pipeline("This is a test.")
            logger.debug(f"Model test successful: {test_result}")
        except Exception as e:
            logger.error(f"Failed to load any sentiment analysis model: {e}")
            raise ModelError(f"Could not load sentiment analysis model: {e}") from e

    @staticmethod
    def _extract_top_result(output):
        """Return the first prediction dict from a pipeline output.

        Accepts both a flat list of result dicts and a list whose first item
        is itself a list of result dicts.

        Raises:
            ModelError: If the output, or its nested first list, is empty.
        """
        if not output:
            raise ModelError("Model returned empty prediction")
        first = output[0]
        if isinstance(first, list):
            # Nested format: an empty inner list means there is no prediction
            if not first:
                raise ModelError("Model returned empty prediction")
            return first[0]
        # Standard format
        return first

    def predict(self, text: str) -> Tuple[str, float]:
        """Predict the sentiment of ``text``.

        Args:
            text: Input text to analyze.

        Returns:
            Tuple of (sentiment_label, confidence_score).

        Raises:
            ModelError: If the model is not loaded or prediction fails for
                any reason; the underlying exception is chained.
        """
        try:
            if not self.pipeline:
                raise ModelError("Model not loaded")
            logger.debug(f"Running sentiment prediction on text of length {len(text)}")
            output = self.pipeline(text)
            result = self._extract_top_result(output)
            raw_label = result["label"]
            score = float(result["score"])
            # Map model labels to human-readable labels
            sentiment = self._map_sentiment_label(raw_label)
            logger.debug(f"Prediction completed: {sentiment} (confidence: {score:.3f})")
            return sentiment, score
        except Exception as e:
            logger.error(f"Prediction failed: {e}")
            raise ModelError(f"Sentiment prediction failed: {e}") from e

    def _map_sentiment_label(self, label: str) -> str:
        """Map a raw model label to "Negative", "Neutral" or "Positive".

        The table lookup ignores case and surrounding whitespace. Labels not
        in the table are logged and then inferred from the substrings
        "neg" / "pos" / "neu"; anything else defaults to "Neutral".

        Args:
            label: Raw label from the model.

        Returns:
            Human-readable sentiment label.
        """
        normalized = str(label).strip().upper()
        mapped_label = self._LABEL_MAPPING.get(normalized)
        if mapped_label is not None:
            return mapped_label
        logger.warning(f"Unknown label received from model: {label}")
        # Unknown label: try to infer the sentiment from the label text
        label_lower = normalized.lower()
        for fragment, sentiment in self._LABEL_HINTS:
            if fragment in label_lower:
                return sentiment
        return "Neutral"  # Default fallback
# Process-wide analyzer, created lazily by get_model()
_sentiment_analyzer = None


def get_model() -> SentimentAnalyzer:
    """Return the shared SentimentAnalyzer, constructing it on first use."""
    global _sentiment_analyzer
    if _sentiment_analyzer is not None:
        return _sentiment_analyzer
    # First call: build the analyzer and remember it for later callers
    _sentiment_analyzer = SentimentAnalyzer()
    return _sentiment_analyzer
def predict(text: str) -> Tuple[str, float]:
    """
    Run sentiment prediction through the shared analyzer instance.

    Args:
        text: Input text to analyze

    Returns:
        Tuple of (sentiment_label, confidence_score)

    Raises:
        ModelError: If prediction fails
    """
    return get_model().predict(text)
|