Spaces:
Sleeping
Sleeping
| import torch | |
| import re | |
| import logging | |
| from typing import List, Dict, Tuple | |
| from functools import lru_cache | |
| from lime.lime_text import LimeTextExplainer | |
| from config import config | |
| from models import ModelManager | |
| from utils import handle_errors | |
| logger = logging.getLogger(__name__) | |
class TextProcessor:
    """Optimized text processing helpers (stateless)."""

    # BUG FIX: the original declared clean_text without `self` and without
    # @staticmethod, so calling it on an instance passed the instance as
    # `text` and crashed. @staticmethod keeps the existing
    # TextProcessor.clean_text(text) call form working and also makes
    # instance calls valid.
    @staticmethod
    def clean_text(text: str) -> Tuple[str, ...]:
        """Single-pass text cleaning.

        Lowercases `text`, keeps only word tokens of 3+ characters, and
        drops configured stop words.

        Args:
            text: raw input string.

        Returns:
            Tuple of cleaned lowercase words, in original order.
        """
        words = re.findall(r'\b\w{3,}\b', text.lower())
        return tuple(w for w in words if w not in config.STOP_WORDS)
class SentimentEngine:
    """Streamlined sentiment analysis engine with LIME and SHAP-style explanations.

    Wraps the ModelManager-provided transformer classifier and exposes fast
    single prediction, explanation-augmented prediction, keyword extraction
    (LIME and leave-one-out "SHAP"), an HTML heatmap renderer, and batched
    processing. Probability layout everywhere is [negative, positive].
    """

    def __init__(self):
        self.model_manager = ModelManager()
        # Class name order matches the probability layout of predict_proba:
        # index 0 = Negative, index 1 = Positive.
        self.lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])
        self.shap_explainer = None  # reserved; leave-one-out scoring is used instead

    def predict_proba(self, texts):
        """Prediction function for LIME and internal callers.

        Args:
            texts: a single string or a list of strings.

        Returns:
            numpy array of shape (n, 2): [neg_prob, pos_prob] per input text.
        """
        if isinstance(texts, str):
            texts = [texts]
        inputs = self.model_manager.tokenizer(
            texts, return_tensors="pt", padding=True,
            truncation=True, max_length=config.MAX_TEXT_LENGTH
        ).to(self.model_manager.device)
        with torch.no_grad():
            outputs = self.model_manager.model(**inputs)
        return torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()

    def analyze_single_fast(self, text: str) -> Dict:
        """Fast single-text analysis without keyword extraction.

        Raises:
            ValueError: if `text` is empty or whitespace-only.
        """
        if not text.strip():
            raise ValueError("Empty text")
        probs = self.predict_proba([text])[0]
        sentiment = "Positive" if probs[1] > probs[0] else "Negative"
        return {
            'sentiment': sentiment,
            'confidence': float(probs.max()),
            'pos_prob': float(probs[1]),
            'neg_prob': float(probs[0])
        }

    def extract_key_words_lime(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """Keyword extraction using LIME.

        Returns:
            Up to `top_k` (word, |score|) pairs sorted by descending
            importance, or [] if the explainer fails.
        """
        try:
            explanation = self.lime_explainer.explain_instance(
                text, self.predict_proba, num_features=top_k, num_samples=200
            )
            word_scores = [
                (word.strip().lower(), abs(score))
                for word, score in explanation.as_list()
                if len(word.strip()) >= config.MIN_WORD_LENGTH
            ]
            word_scores.sort(key=lambda ws: ws[1], reverse=True)
            return word_scores[:top_k]
        except Exception as e:
            logger.error(f"LIME extraction failed: {e}")
            return []

    def extract_key_words_shap(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """SHAP-style keyword extraction via leave-one-out probing.

        Importance of a word = |P(pos | full text) - P(pos | text minus word)|.
        A word appearing multiple times keeps its maximum observed importance.

        Returns:
            Up to `top_k` (word, importance) pairs sorted by descending
            importance, or [] on failure / no usable words.
        """
        try:
            words = text.split()
            # Baseline positive probability of the unmodified text.
            baseline_prob = float(self.predict_proba([text])[0][1])
            # PERF FIX: build every leave-one-out variant first and run the
            # model ONCE on the whole batch, instead of one forward pass per
            # word as before. Words that would be filtered by
            # MIN_WORD_LENGTH are skipped before prediction.
            variants: List[str] = []
            variant_words: List[str] = []
            for i, word in enumerate(words):
                modified_text = ' '.join(words[:i] + words[i + 1:])
                if not modified_text.strip():
                    continue
                clean_word = re.sub(r'[^\w]', '', word.lower())
                if len(clean_word) >= config.MIN_WORD_LENGTH:
                    variants.append(modified_text)
                    variant_words.append(clean_word)
            if not variants:
                return []
            variant_probs = self.predict_proba(variants)[:, 1]
            # Deduplicate: a repeated word keeps its largest impact.
            unique_scores: Dict[str, float] = {}
            for clean_word, prob in zip(variant_words, variant_probs):
                importance = abs(baseline_prob - float(prob))
                unique_scores[clean_word] = max(unique_scores.get(clean_word, 0.0), importance)
            sorted_scores = sorted(unique_scores.items(), key=lambda kv: kv[1], reverse=True)
            return sorted_scores[:top_k]
        except Exception as e:
            logger.error(f"SHAP extraction failed: {e}")
            return []

    def create_heatmap_html(self, text: str, word_scores: Dict[str, float]) -> str:
        """Render `text` as an HTML heatmap of per-word scores.

        Positive scores shade green, negative shade red, unknown words stay
        transparent; intensity is normalized against the extreme scores.

        Args:
            text: original text; split on whitespace for rendering.
            word_scores: mapping of lowercase punctuation-stripped word ->
                score (e.g. dict of the LIME output).
        """
        words = text.split()
        html_parts = ['<div style="font-family: Arial; font-size: 16px; line-height: 1.6;">']
        if word_scores:
            max_score = max(abs(score) for score in word_scores.values())
            min_score = min(word_scores.values())
        else:
            # No scores: every word falls through to transparent.
            max_score = min_score = 0
        for word in words:
            # Scores are keyed by the cleaned (lowercase, word-chars-only) form.
            clean_word = re.sub(r'[^\w]', '', word.lower())
            score = word_scores.get(clean_word, 0)
            if score > 0:
                intensity = min(255, int(180 * (score / max_score) if max_score > 0 else 0))
                color = f"rgba(0, {intensity}, 0, 0.3)"
            elif score < 0:
                intensity = min(255, int(180 * (abs(score) / abs(min_score)) if min_score < 0 else 0))
                color = f"rgba({intensity}, 0, 0, 0.3)"
            else:
                color = "transparent"
            html_parts.append(
                f'<span style="background-color: {color}; padding: 2px; margin: 1px; '
                f'border-radius: 3px;" title="Score: {score:.3f}">{word}</span> '
            )
        html_parts.append('</div>')
        return ''.join(html_parts)

    def analyze_single_advanced(self, text: str) -> Dict:
        """Single-text analysis with LIME + SHAP explanations and heatmap HTML.

        Raises:
            ValueError: if `text` is empty or whitespace-only.
        """
        if not text.strip():
            raise ValueError("Empty text")
        probs = self.predict_proba([text])[0]
        sentiment = "Positive" if probs[1] > probs[0] else "Negative"
        lime_words = self.extract_key_words_lime(text)
        shap_words = self.extract_key_words_shap(text)
        # The heatmap is driven by the LIME scores.
        heatmap_html = self.create_heatmap_html(text, dict(lime_words))
        return {
            'sentiment': sentiment,
            'confidence': float(probs.max()),
            'pos_prob': float(probs[1]),
            'neg_prob': float(probs[0]),
            'lime_words': lime_words,
            'shap_words': shap_words,
            'heatmap_html': heatmap_html
        }

    def analyze_batch(self, texts: List[str], progress_callback=None) -> List[Dict]:
        """Optimized batch processing.

        Input is truncated to config.BATCH_SIZE_LIMIT and processed in chunks
        of config.BATCH_PROCESSING_SIZE. `progress_callback`, if given,
        receives the completed fraction (0..1] once per chunk.
        """
        if len(texts) > config.BATCH_SIZE_LIMIT:
            texts = texts[:config.BATCH_SIZE_LIMIT]
        results = []
        batch_size = config.BATCH_PROCESSING_SIZE
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            if progress_callback:
                progress_callback((i + len(batch)) / len(texts))
            # CONSISTENCY FIX: reuse predict_proba instead of duplicating the
            # tokenize -> forward -> softmax pipeline inline (was copy-pasted).
            probs = self.predict_proba(batch)
            for text, prob in zip(batch, probs):
                sentiment = "Positive" if prob[1] > prob[0] else "Negative"
                results.append({
                    'text': text[:50] + '...' if len(text) > 50 else text,
                    'full_text': text,
                    'sentiment': sentiment,
                    'confidence': float(prob.max()),
                    'pos_prob': float(prob[1]),
                    'neg_prob': float(prob[0])
                })
        return results