Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| from transformers import pipeline | |
| from typing import Tuple | |
| # Add parent directory to path for imports | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from config import config | |
| from logging_config import get_logger | |
| # Initialize logger | |
| logger = get_logger('model') | |
| class ModelError(Exception): | |
| """Custom exception for model-related errors.""" | |
| pass | |
| class SentimentAnalyzer: | |
| """Sentiment analysis model wrapper with error handling and caching.""" | |
| def __init__(self): | |
| self.pipeline = None | |
| self.model_name = config.MODEL_NAME | |
| self._load_model() | |
| def _load_model(self): | |
| """Load the sentiment analysis model with error handling and fallback.""" | |
| try: | |
| logger.info(f"Loading sentiment analysis model: {self.model_name}") | |
| # Try loading the primary model first | |
| try: | |
| self.pipeline = pipeline( | |
| "sentiment-analysis", | |
| model=self.model_name, | |
| top_k=1 | |
| ) | |
| logger.info("Primary model loaded successfully") | |
| except Exception as primary_error: | |
| logger.warning(f"Primary model failed to load: {primary_error}") | |
| logger.info("Trying fallback model: distilbert-base-uncased-finetuned-sst-2-english") | |
| # Fallback to a reliable model | |
| fallback_model = "distilbert-base-uncased-finetuned-sst-2-english" | |
| self.pipeline = pipeline( | |
| "sentiment-analysis", | |
| model=fallback_model, | |
| top_k=1 | |
| ) | |
| self.model_name = fallback_model # Update model name for logging | |
| logger.info("Fallback model loaded successfully") | |
| # Test the model with a simple prediction | |
| test_result = self.pipeline("This is a test.") | |
| logger.debug(f"Model test successful: {test_result}") | |
| except Exception as e: | |
| logger.error(f"Failed to load any sentiment analysis model: {e}") | |
| raise ModelError(f"Could not load sentiment analysis model: {e}") | |
| def predict(self, text: str) -> Tuple[str, float]: | |
| """ | |
| Predict sentiment for given text. | |
| Args: | |
| text: Input text to analyze | |
| Returns: | |
| Tuple of (sentiment_label, confidence_score) | |
| Raises: | |
| ModelError: If prediction fails | |
| """ | |
| try: | |
| if not self.pipeline: | |
| raise ModelError("Model not loaded") | |
| logger.debug(f"Running sentiment prediction on text of length {len(text)}") | |
| # Run prediction | |
| output = self.pipeline(text) | |
| if not output or len(output) == 0: | |
| raise ModelError("Model returned empty prediction") | |
| # Handle different output formats from different models | |
| if isinstance(output[0], list): | |
| # Some models return nested lists | |
| result = output[0][0] if output[0] else output[0] | |
| else: | |
| # Standard format | |
| result = output[0] | |
| raw_label = result["label"] | |
| score = result["score"] | |
| # Map model labels to human-readable labels | |
| sentiment = self._map_sentiment_label(raw_label) | |
| logger.debug(f"Prediction completed: {sentiment} (confidence: {score:.3f})") | |
| return sentiment, float(score) | |
| except Exception as e: | |
| logger.error(f"Prediction failed: {e}") | |
| raise ModelError(f"Sentiment prediction failed: {e}") | |
| def _map_sentiment_label(self, label: str) -> str: | |
| """ | |
| Map model output labels to human-readable sentiment labels. | |
| Args: | |
| label: Raw label from model | |
| Returns: | |
| Human-readable sentiment label | |
| """ | |
| label_mapping = { | |
| # Original model labels (fitsblb/YelpReviewsAnalyzer) | |
| "LABEL_0": "Negative", | |
| "LABEL_1": "Neutral", | |
| "LABEL_2": "Positive", | |
| # Standard model labels (distilbert-base-uncased-finetuned-sst-2-english) | |
| "NEGATIVE": "Negative", | |
| "POSITIVE": "Positive", | |
| # Generic fallbacks | |
| "NEUTRAL": "Neutral" | |
| } | |
| mapped_label = label_mapping.get(label, "Unknown") | |
| if mapped_label == "Unknown": | |
| logger.warning(f"Unknown label received from model: {label}") | |
| # If it's an unknown label, try to infer from the label string | |
| label_lower = label.lower() | |
| if 'neg' in label_lower: | |
| mapped_label = "Negative" | |
| elif 'pos' in label_lower: | |
| mapped_label = "Positive" | |
| elif 'neu' in label_lower: | |
| mapped_label = "Neutral" | |
| else: | |
| mapped_label = "Neutral" # Default fallback | |
| return mapped_label | |
| # Global model instance | |
| _sentiment_analyzer = None | |
| def get_model() -> SentimentAnalyzer: | |
| """Get or create the global sentiment analyzer instance.""" | |
| global _sentiment_analyzer | |
| if _sentiment_analyzer is None: | |
| _sentiment_analyzer = SentimentAnalyzer() | |
| return _sentiment_analyzer | |
| def predict(text: str) -> Tuple[str, float]: | |
| """ | |
| Convenience function for sentiment prediction. | |
| Args: | |
| text: Input text to analyze | |
| Returns: | |
| Tuple of (sentiment_label, confidence_score) | |
| Raises: | |
| ModelError: If prediction fails | |
| """ | |
| model = get_model() | |
| return model.predict(text) | |