Spaces:
Build error
Build error
| # FinBERT sentiment analysis module for financial news | |
| """ | |
| This module handles loading the ProsusAI/finbert model and extracting | |
| sentiment predictions with confidence scores from financial news text. | |
| """ | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import streamlit as st | |
| from typing import Dict, Tuple | |
| import logging | |
# Configure logging: basicConfig gives the root logger a default handler at
# INFO level (it is a no-op when the root logger already has handlers).
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
# Module-level logger used by FinBERTAnalyzer for progress and error messages.
logger = logging.getLogger(name=__name__)
class FinBERTAnalyzer:
    """
    A wrapper class for the ProsusAI/finbert model to analyze financial sentiment.

    The tokenizer and model start out as None and are loaded on demand:
    ``analyze_sentiment`` calls ``load_model`` when either one is missing.
    """

    # Numeric direction per sentiment label; unknown labels map to 0.
    _SENTIMENT_DIRECTIONS = {"positive": 1, "negative": -1, "neutral": 0}

    # Class order of ProsusAI/finbert's classification head (the id2label
    # mapping in its config.json). Only used as a fallback when the loaded
    # model's config does not name a class index.
    _FINBERT_LABELS = ("positive", "negative", "neutral")

    def __init__(self, model_name: str = "ProsusAI/finbert"):
        """
        Initialize the FinBERT analyzer without loading any weights.

        Args:
            model_name: The Hugging Face model identifier
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        # Run inference on the GPU when one is visible to torch.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(_self) -> bool:
        """
        Load the tokenizer and model for ``model_name`` and move the model to ``device``.

        The receiver is named ``_self`` because Streamlit's caching decorators do
        not hash parameters whose names start with an underscore. No cache
        decorator is applied in this class, so every call runs
        ``from_pretrained`` again.

        Returns:
            True if both tokenizer and model loaded. False on any error: the
            error (with traceback) and a likely-cause hint are logged, and
            ``tokenizer``/``model`` are reset to None so the instance is never
            left half-loaded.
        """
        try:
            logger.info("Loading FinBERT model: %s", _self.model_name)
            tokenizer = AutoTokenizer.from_pretrained(
                _self.model_name,
                cache_dir=None,  # Use default cache
                local_files_only=False,  # Allow downloading if needed
            )
            logger.info("Tokenizer loaded successfully")
            model = AutoModelForSequenceClassification.from_pretrained(
                _self.model_name,
                cache_dir=None,  # Use default cache
                local_files_only=False,  # Allow downloading if needed
            )
            model.to(_self.device)
            model.eval()  # inference mode (disables dropout)
        except Exception as e:
            _self.tokenizer = None
            _self.model = None
            message = str(e)
            logger.exception("Error loading FinBERT model: %s", message)
            # Best-effort hint about the likely cause, based on the message text.
            lowered = message.lower()
            if "connection" in lowered or "timeout" in lowered:
                logger.error("Network connection issue. Check internet connectivity.")
            elif "disk" in lowered or "space" in lowered:
                logger.error("Insufficient disk space for model download.")
            elif "permission" in lowered:
                logger.error("Permission denied. Check file/directory permissions.")
            return False

        # Publish tokenizer and model together, only after both succeeded.
        _self.tokenizer = tokenizer
        _self.model = model
        logger.info("FinBERT model loaded successfully")
        return True

    def _index_labels(self, num_labels: int) -> list:
        """Return lowercase class names ordered by logit index."""
        config = getattr(self.model, "config", None)
        id2label = getattr(config, "id2label", None) or {}
        labels = []
        for index in range(num_labels):
            # Hugging Face configs use int keys; accept str keys defensively.
            name = id2label.get(index, id2label.get(str(index)))
            if name is None:
                if index < len(self._FINBERT_LABELS):
                    name = self._FINBERT_LABELS[index]
                else:
                    name = f"label_{index}"
            labels.append(str(name).lower())
        return labels

    def analyze_sentiment(self, text: str) -> Dict[str, object]:
        """
        Analyze sentiment of financial news text.

        Loads the model first if it has not been loaded yet. Input longer than
        512 tokens is truncated.

        Args:
            text: The financial news text to analyze

        Returns:
            Dictionary with:
              - "sentiment": the label with the highest probability (lowercase str)
              - "confidence": that label's softmax probability (float)
              - "scores": mapping of every label to its softmax probability

        Raises:
            RuntimeError: if the model cannot be loaded or inference fails.
        """
        if self.model is None or self.tokenizer is None:
            if not self.load_model():
                raise RuntimeError("Failed to load FinBERT model")
        try:
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=512,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # Single input text -> take the first (only) row of the batch.
            scores = probabilities[0].cpu().tolist()
            # Label names come from the model config rather than a hard-coded
            # list: ProsusAI/finbert orders its classes positive/negative/neutral.
            labels = self._index_labels(len(scores))
        except Exception as e:
            logger.exception("Error analyzing sentiment: %s", e)
            raise RuntimeError(f"Sentiment analysis failed: {str(e)}") from e

        # First index wins on ties, matching argmax semantics.
        predicted_idx = max(range(len(scores)), key=scores.__getitem__)
        return {
            "sentiment": labels[predicted_idx],
            "confidence": float(scores[predicted_idx]),
            "scores": {label: float(score) for label, score in zip(labels, scores)},
        }

    def get_sentiment_direction(self, sentiment: str) -> int:
        """
        Convert sentiment label to numerical direction for evaluation.

        Args:
            sentiment: The sentiment label ("positive", "negative", "neutral"),
                matched case-insensitively

        Returns:
            1 for positive, -1 for negative, 0 for neutral or any unknown label
        """
        return self._SENTIMENT_DIRECTIONS.get(sentiment.lower(), 0)