import html
import logging
import re
from functools import lru_cache
from typing import Dict, List, Tuple

import torch
from lime.lime_text import LimeTextExplainer

from config import config
from models import ModelManager
from utils import handle_errors

logger = logging.getLogger(__name__)


class TextProcessor:
    """Optimized text processing helpers."""

    @staticmethod
    @lru_cache(maxsize=config.CACHE_SIZE)
    def clean_text(text: str) -> Tuple[str, ...]:
        """Single-pass text cleaning.

        Lowercases, keeps words of 3+ word-characters, and drops stop words.
        Returns a tuple (hashable) so results can be memoized by lru_cache.
        """
        words = re.findall(r'\b\w{3,}\b', text.lower())
        return tuple(w for w in words if w not in config.STOP_WORDS)


class SentimentEngine:
    """Streamlined sentiment analysis engine with LIME and SHAP explanations."""

    def __init__(self):
        self.model_manager = ModelManager()
        self.lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])
        # Reserved for a real SHAP explainer; extract_key_words_shap currently
        # uses a leave-one-out approximation instead.
        self.shap_explainer = None

    def predict_proba(self, texts):
        """Return class probabilities, shape (n, 2), for a string or list of strings.

        Column 0 is the negative-class probability, column 1 positive.
        Also serves as the prediction function handed to LIME.
        """
        if isinstance(texts, str):
            texts = [texts]
        inputs = self.model_manager.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=config.MAX_TEXT_LENGTH,
        ).to(self.model_manager.device)
        with torch.no_grad():
            outputs = self.model_manager.model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()
        return probs

    @handle_errors(default_return={'sentiment': 'Unknown', 'confidence': 0.0})
    def analyze_single_fast(self, text: str) -> Dict:
        """Fast single text analysis without keyword extraction.

        Raises ValueError on empty/whitespace-only input (converted to the
        default return by @handle_errors).
        """
        if not text.strip():
            raise ValueError("Empty text")

        probs = self.predict_proba([text])[0]
        sentiment = "Positive" if probs[1] > probs[0] else "Negative"

        return {
            'sentiment': sentiment,
            'confidence': float(probs.max()),
            'pos_prob': float(probs[1]),
            'neg_prob': float(probs[0]),
        }

    def extract_key_words_lime(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """Advanced keyword extraction using LIME.

        Returns up to top_k (word, |score|) pairs sorted by importance;
        an empty list on failure.
        """
        try:
            explanation = self.lime_explainer.explain_instance(
                text,
                self.predict_proba,
                num_features=top_k,
                num_samples=200,
            )

            word_scores = []
            for word, score in explanation.as_list():
                if len(word.strip()) >= config.MIN_WORD_LENGTH:
                    # Keep magnitude only: both strongly-negative and
                    # strongly-positive words count as "key" words.
                    word_scores.append((word.strip().lower(), abs(score)))

            word_scores.sort(key=lambda x: x[1], reverse=True)
            return word_scores[:top_k]
        except Exception as e:
            logger.error(f"LIME extraction failed: {e}")
            return []

    def extract_key_words_shap(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """Keyword extraction using a SHAP-style leave-one-out estimate.

        Importance of a word = |P(pos | full text) - P(pos | text minus word)|.
        NOTE: this is an approximation, not true SHAP values.
        Returns up to top_k (word, importance) pairs; empty list on failure.
        """
        try:
            words = text.split()
            if not words:
                return []

            # Baseline positive-class probability for the full text.
            baseline_prob = self.predict_proba([text])[0][1]

            # Build every leave-one-out variant first, then score them in ONE
            # batched forward pass instead of one model call per word.
            variants: List[str] = []
            variant_words: List[str] = []
            for i, word in enumerate(words):
                modified_text = ' '.join(words[:i] + words[i + 1:])
                if not modified_text.strip():
                    continue
                clean_word = re.sub(r'[^\w]', '', word.lower())
                if len(clean_word) >= config.MIN_WORD_LENGTH:
                    variants.append(modified_text)
                    variant_words.append(clean_word)

            if not variants:
                return []

            variant_probs = self.predict_proba(variants)[:, 1]

            # Deduplicate, keeping the highest importance seen per word.
            unique_scores: Dict[str, float] = {}
            for clean_word, prob in zip(variant_words, variant_probs):
                importance = float(abs(baseline_prob - prob))
                if importance > unique_scores.get(clean_word, -1.0):
                    unique_scores[clean_word] = importance

            sorted_scores = sorted(unique_scores.items(), key=lambda x: x[1], reverse=True)
            return sorted_scores[:top_k]
        except Exception as e:
            logger.error(f"SHAP extraction failed: {e}")
            return []

    def create_heatmap_html(self, text: str, word_scores: Dict[str, float]) -> str:
        """Create an HTML heatmap of *text*: green tint for positive word scores,
        red for negative, transparent for unscored words.

        word_scores maps cleaned lowercase words to importance scores.
        NOTE(review): the original markup strings were corrupted in the file;
        they are reconstructed here so the computed `color` is actually applied.
        """
        words = text.split()
        html_parts = ['<div style="line-height: 2.0; font-size: 16px; padding: 10px;">']

        if word_scores:
            max_score = max(abs(score) for score in word_scores.values())
            min_score = min(word_scores.values())
        else:
            max_score = min_score = 0

        for word in words:
            clean_word = re.sub(r'[^\w]', '', word.lower())
            score = word_scores.get(clean_word, 0)

            if score > 0:
                intensity = min(255, int(180 * (score / max_score) if max_score > 0 else 0))
                color = f"rgba(0, {intensity}, 0, 0.3)"
            elif score < 0:
                intensity = min(255, int(180 * (abs(score) / abs(min_score)) if min_score < 0 else 0))
                color = f"rgba({intensity}, 0, 0, 0.3)"
            else:
                color = "transparent"

            # Escape the displayed word so raw user text cannot inject HTML.
            html_parts.append(
                f'<span style="background-color: {color}; padding: 2px; '
                f'border-radius: 3px;">{html.escape(word)}</span> '
            )

        html_parts.append('</div>')
        return ''.join(html_parts)

    @handle_errors(default_return={'sentiment': 'Unknown', 'confidence': 0.0,
                                   'lime_words': [], 'shap_words': [], 'heatmap_html': ''})
    def analyze_single_advanced(self, text: str) -> Dict:
        """Advanced single text analysis with LIME and SHAP explanations.

        Returns sentiment/probabilities plus keyword lists and an HTML heatmap
        built from the LIME scores. Raises ValueError on empty input
        (converted to the default return by @handle_errors).
        """
        if not text.strip():
            raise ValueError("Empty text")

        probs = self.predict_proba([text])[0]
        sentiment = "Positive" if probs[1] > probs[0] else "Negative"

        # Extract key words using both LIME and SHAP.
        lime_words = self.extract_key_words_lime(text)
        shap_words = self.extract_key_words_shap(text)

        # Create heatmap HTML using LIME results.
        word_scores_dict = dict(lime_words)
        heatmap_html = self.create_heatmap_html(text, word_scores_dict)

        return {
            'sentiment': sentiment,
            'confidence': float(probs.max()),
            'pos_prob': float(probs[1]),
            'neg_prob': float(probs[0]),
            'lime_words': lime_words,
            'shap_words': shap_words,
            'heatmap_html': heatmap_html,
        }

    @handle_errors(default_return=[])
    def analyze_batch(self, texts: List[str], progress_callback=None) -> List[Dict]:
        """Optimized batch processing.

        Truncates the input list to config.BATCH_SIZE_LIMIT, scores texts in
        chunks of config.BATCH_PROCESSING_SIZE, and reports progress (0..1)
        via progress_callback if provided.
        """
        if len(texts) > config.BATCH_SIZE_LIMIT:
            texts = texts[:config.BATCH_SIZE_LIMIT]

        results = []
        batch_size = config.BATCH_PROCESSING_SIZE

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]

            if progress_callback:
                progress_callback((i + len(batch)) / len(texts))

            inputs = self.model_manager.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=config.MAX_TEXT_LENGTH,
            ).to(self.model_manager.device)

            with torch.no_grad():
                outputs = self.model_manager.model(**inputs)
                probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()

            for text, prob in zip(batch, probs):
                sentiment = "Positive" if prob[1] > prob[0] else "Negative"
                results.append({
                    # Preview string for display; full text kept alongside.
                    'text': text[:50] + '...' if len(text) > 50 else text,
                    'full_text': text,
                    'sentiment': sentiment,
                    'confidence': float(prob.max()),
                    'pos_prob': float(prob[1]),
                    'neg_prob': float(prob[0]),
                })

        return results