Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -17,7 +17,8 @@ from dataclasses import dataclass
|
|
| 17 |
from typing import List, Dict, Optional, Tuple, Any, Callable
|
| 18 |
from contextlib import contextmanager
|
| 19 |
import gc
|
| 20 |
-
import pandas as pd
|
|
|
|
| 21 |
|
| 22 |
@dataclass
|
| 23 |
class Config:
|
|
@@ -155,99 +156,109 @@ class HistoryManager:
|
|
| 155 |
|
| 156 |
# Core Analysis Engine
|
| 157 |
class SentimentEngine:
|
| 158 |
-
"""Streamlined sentiment analysis with
|
| 159 |
def __init__(self):
|
| 160 |
self.model_manager = ModelManager()
|
|
|
|
| 161 |
|
| 162 |
-
def
|
| 163 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
try:
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
# Get model outputs with attention weights
|
| 171 |
-
with torch.no_grad():
|
| 172 |
-
outputs = self.model_manager.model(**inputs, output_attentions=True)
|
| 173 |
-
attention = outputs.attentions # Tuple of attention tensors for each layer
|
| 174 |
-
|
| 175 |
-
# Use the last layer's attention, average over all heads
|
| 176 |
-
last_attention = attention[-1] # Shape: [batch_size, num_heads, seq_len, seq_len]
|
| 177 |
-
avg_attention = last_attention.mean(dim=1) # Average over heads: [batch_size, seq_len, seq_len]
|
| 178 |
-
|
| 179 |
-
# Focus on attention to [CLS] token (index 0) as it represents the whole sequence
|
| 180 |
-
cls_attention = avg_attention[0, 0, :] # Attention from CLS to all tokens
|
| 181 |
-
|
| 182 |
-
# Get tokens and their attention scores
|
| 183 |
-
tokens = self.model_manager.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
|
| 184 |
-
attention_scores = cls_attention.cpu().numpy()
|
| 185 |
-
|
| 186 |
-
# Filter out special tokens and combine subword tokens
|
| 187 |
-
word_scores = {}
|
| 188 |
-
current_word = ""
|
| 189 |
-
current_score = 0.0
|
| 190 |
-
|
| 191 |
-
for i, (token, score) in enumerate(zip(tokens, attention_scores)):
|
| 192 |
-
if token in ['[CLS]', '[SEP]', '[PAD]']:
|
| 193 |
-
continue
|
| 194 |
-
|
| 195 |
-
if token.startswith('##'):
|
| 196 |
-
# Subword token, add to current word
|
| 197 |
-
current_word += token[2:]
|
| 198 |
-
current_score = max(current_score, score) # Take max attention
|
| 199 |
-
else:
|
| 200 |
-
# New word, save previous if exists
|
| 201 |
-
if current_word and len(current_word) >= config.MIN_WORD_LENGTH:
|
| 202 |
-
word_scores[current_word.lower()] = current_score
|
| 203 |
-
|
| 204 |
-
current_word = token
|
| 205 |
-
current_score = score
|
| 206 |
-
|
| 207 |
-
# Don't forget the last word
|
| 208 |
-
if current_word and len(current_word) >= config.MIN_WORD_LENGTH:
|
| 209 |
-
word_scores[current_word.lower()] = current_score
|
| 210 |
|
| 211 |
-
#
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
if
|
| 215 |
-
|
| 216 |
|
| 217 |
-
# Sort by
|
| 218 |
-
|
| 219 |
-
return
|
| 220 |
|
| 221 |
except Exception as e:
|
| 222 |
-
logger.error(f"
|
| 223 |
return []
|
| 224 |
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
def analyze_single(self, text: str) -> Dict:
|
| 227 |
-
"""Analyze single text with
|
| 228 |
if not text.strip():
|
| 229 |
raise ValueError("Empty text")
|
| 230 |
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
truncation=True, max_length=config.MAX_TEXT_LENGTH
|
| 234 |
-
).to(self.model_manager.device)
|
| 235 |
-
|
| 236 |
-
with torch.no_grad():
|
| 237 |
-
outputs = self.model_manager.model(**inputs)
|
| 238 |
-
probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
|
| 239 |
-
|
| 240 |
sentiment = "Positive" if probs[1] > probs[0] else "Negative"
|
| 241 |
|
| 242 |
-
# Extract key
|
| 243 |
-
key_words = self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
|
| 245 |
return {
|
| 246 |
'sentiment': sentiment,
|
| 247 |
'confidence': float(probs.max()),
|
| 248 |
'pos_prob': float(probs[1]),
|
| 249 |
'neg_prob': float(probs[0]),
|
| 250 |
-
'key_words': key_words
|
|
|
|
| 251 |
}
|
| 252 |
|
| 253 |
@handle_errors(default_return=[])
|
|
@@ -585,11 +596,11 @@ class SentimentApp:
|
|
| 585 |
]
|
| 586 |
|
| 587 |
|
| 588 |
-
@handle_errors(default_return=("Please enter text", None, None, None, None))
|
| 589 |
def analyze_single(self, text: str, theme: str = 'default'):
|
| 590 |
-
"""Single text analysis with
|
| 591 |
if not text.strip():
|
| 592 |
-
return "Please enter text", None, None, None, None
|
| 593 |
|
| 594 |
result = self.engine.analyze_single(text)
|
| 595 |
|
|
@@ -614,7 +625,8 @@ class SentimentApp:
|
|
| 614 |
result_text = (f"Sentiment: {result['sentiment']} (Confidence: {result['confidence']:.3f})\n"
|
| 615 |
f"Key Words: {key_words_str}")
|
| 616 |
|
| 617 |
-
|
|
|
|
| 618 |
|
| 619 |
@handle_errors(default_return=None)
|
| 620 |
def analyze_batch(self, reviews: str, progress=None):
|
|
@@ -706,6 +718,7 @@ def create_interface():
|
|
| 706 |
|
| 707 |
with gr.Column():
|
| 708 |
result_output = gr.Textbox(label="Result", lines=3)
|
|
|
|
| 709 |
|
| 710 |
with gr.Row():
|
| 711 |
prob_plot = gr.Plot(label="Probabilities")
|
|
@@ -749,7 +762,7 @@ def create_interface():
|
|
| 749 |
analyze_btn.click(
|
| 750 |
app.analyze_single,
|
| 751 |
inputs=[text_input, theme_selector],
|
| 752 |
-
outputs=[result_output, prob_plot, gauge_plot, wordcloud_plot, keyword_plot]
|
| 753 |
)
|
| 754 |
|
| 755 |
load_btn.click(app.data_handler.process_file, inputs=file_upload, outputs=batch_input)
|
|
|
|
| 17 |
from typing import List, Dict, Optional, Tuple, Any, Callable
|
| 18 |
from contextlib import contextmanager
|
| 19 |
import gc
|
| 20 |
+
import pandas as pd
|
| 21 |
+
from lime.lime_text import LimeTextExplainer # Added LIME import
|
| 22 |
|
| 23 |
@dataclass
|
| 24 |
class Config:
|
|
|
|
| 156 |
|
| 157 |
# Core Analysis Engine
|
| 158 |
class SentimentEngine:
|
| 159 |
+
"""Streamlined sentiment analysis with LIME-based keyword extraction"""
|
| 160 |
def __init__(self):
    """Set up the shared model manager and the LIME text explainer.

    Class index order is Negative (0), Positive (1), matching the
    softmax column order produced by predict_proba.
    """
    # ModelManager is defined elsewhere in this file; it owns the
    # tokenizer, model, and device selection.
    self.model_manager = ModelManager()
    # LIME explainer reused across calls; class_names must stay aligned
    # with the logits ordering above.
    self.lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])
|
| 163 |
|
| 164 |
+
def predict_proba(self, texts):
    """Return class probabilities for a batch of texts (LIME-compatible).

    Accepts either a single string or a sequence of strings and always
    tokenizes a list, so the result is a numpy array shaped
    [n_texts, n_classes] — the contract LIME expects from its
    prediction callback.
    """
    # LIME may hand us a bare string; normalize to a one-element batch.
    batch = [texts] if isinstance(texts, str) else texts

    encoded = self.model_manager.tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=config.MAX_TEXT_LENGTH,
    ).to(self.model_manager.device)

    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        logits = self.model_manager.model(**encoded).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()

    return probabilities
|
| 179 |
+
|
| 180 |
+
def extract_key_words_lime(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
    """Extract the most influential words for the prediction using LIME.

    Returns up to ``top_k`` ``(word, score)`` pairs sorted by absolute
    importance, strongest first. Scores keep their sign: positive values
    push the prediction toward the Positive class, negative toward
    Negative. (Previously the sign was discarded with ``abs()``, which
    made the negative/red branch of ``create_heatmap_html`` unreachable
    — every highlighted word rendered green.)

    Args:
        text: The input text to explain.
        top_k: Maximum number of key words to return.

    Returns:
        List of (lower-cased word, signed importance) tuples; empty list
        on failure.
    """
    try:
        # Perturbation-based explanation; num_samples=100 trades some
        # stability for speed.
        explanation = self.lime_explainer.explain_instance(
            text, self.predict_proba, num_features=top_k, num_samples=100
        )

        # Keep the signed contribution so downstream visualization can
        # distinguish positive from negative evidence.
        word_scores = []
        for word, score in explanation.as_list():
            cleaned = word.strip()
            if len(cleaned) >= config.MIN_WORD_LENGTH:
                word_scores.append((cleaned.lower(), float(score)))

        # Rank by magnitude of influence, strongest first.
        word_scores.sort(key=lambda pair: abs(pair[1]), reverse=True)
        return word_scores[:top_k]

    except Exception as e:
        # LIME can fail on very short or degenerate inputs; degrade
        # gracefully rather than breaking the whole analysis.
        logger.error(f"LIME extraction failed: {e}")
        return []
|
| 201 |
|
| 202 |
+
def create_heatmap_html(self, text: str, word_scores: Dict[str, float]) -> str:
    """Render *text* as an HTML snippet with per-word highlighting.

    Each whitespace-separated token is wrapped in a <span>. Tokens whose
    cleaned (punctuation-stripped, lower-cased) form appears in
    ``word_scores`` are tinted green (score > 0) or red (score < 0),
    scaled by the score's relative magnitude; everything else stays
    transparent. The raw score is exposed in the span's title tooltip.
    """
    # Normalization bounds for the color intensity; fall back to zeros
    # when there are no scores at all.
    if word_scores:
        max_score = max(abs(s) for s in word_scores.values())
        min_score = min(word_scores.values())
    else:
        max_score = min_score = 0

    pieces = ['<div style="font-family: Arial; font-size: 16px; line-height: 1.6;">']

    for token in text.split():
        # Match against the score table using only word characters.
        key = re.sub(r'[^\w]', '', token.lower())
        score = word_scores.get(key, 0)

        if score > 0:
            # Positive evidence -> green tint scaled by relative magnitude.
            level = min(255, int(180 * (score / max_score) if max_score > 0 else 0))
            color = f"rgba(0, {level}, 0, 0.3)"
        elif score < 0:
            # Negative evidence -> red tint.
            level = min(255, int(180 * (abs(score) / abs(min_score)) if min_score < 0 else 0))
            color = f"rgba({level}, 0, 0, 0.3)"
        else:
            # Unknown or zero-importance word: no highlighting.
            color = "transparent"

        pieces.append(
            f'<span style="background-color: {color}; padding: 2px; margin: 1px; '
            f'border-radius: 3px;" title="Score: {score:.3f}">{token}</span> '
        )

    pieces.append('</div>')
    return ''.join(pieces)
|
| 237 |
+
|
| 238 |
+
@handle_errors(default_return={'sentiment': 'Unknown', 'confidence': 0.0,
                               'pos_prob': 0.0, 'neg_prob': 0.0,
                               'key_words': [], 'heatmap_html': ''})
def analyze_single(self, text: str) -> Dict:
    """Classify *text* and attach a LIME explanation.

    Returns a dict with the sentiment label, confidence, both class
    probabilities, the top LIME key words, and an HTML word-importance
    heatmap. The decorator's fallback dict now carries every key the
    success path produces ('pos_prob'/'neg_prob' were previously
    missing), so callers can index the result unconditionally even when
    analysis fails.

    Raises:
        ValueError: if *text* is empty/whitespace-only (intercepted by
            @handle_errors, which returns the fallback dict instead).
    """
    if not text.strip():
        raise ValueError("Empty text")

    # Class probabilities for the single text: index 0 = Negative, 1 = Positive.
    probs = self.predict_proba([text])[0]
    sentiment = "Positive" if probs[1] > probs[0] else "Negative"

    # LIME keyword importances, plus the word-level heatmap built from them.
    key_words = self.extract_key_words_lime(text)
    heatmap_html = self.create_heatmap_html(text, dict(key_words))

    return {
        'sentiment': sentiment,
        'confidence': float(probs.max()),
        'pos_prob': float(probs[1]),
        'neg_prob': float(probs[0]),
        'key_words': key_words,
        'heatmap_html': heatmap_html
    }
|
| 263 |
|
| 264 |
@handle_errors(default_return=[])
|
|
|
|
| 596 |
]
|
| 597 |
|
| 598 |
|
| 599 |
+
@handle_errors(default_return=("Please enter text", None, None, None, None, None))
|
| 600 |
def analyze_single(self, text: str, theme: str = 'default'):
|
| 601 |
+
"""Single text analysis with LIME explanation and heatmap"""
|
| 602 |
if not text.strip():
|
| 603 |
+
return "Please enter text", None, None, None, None, None
|
| 604 |
|
| 605 |
result = self.engine.analyze_single(text)
|
| 606 |
|
|
|
|
| 625 |
result_text = (f"Sentiment: {result['sentiment']} (Confidence: {result['confidence']:.3f})\n"
|
| 626 |
f"Key Words: {key_words_str}")
|
| 627 |
|
| 628 |
+
# Return heatmap HTML as additional output
|
| 629 |
+
return result_text, prob_plot, gauge_plot, cloud_plot, keyword_plot, result['heatmap_html']
|
| 630 |
|
| 631 |
@handle_errors(default_return=None)
|
| 632 |
def analyze_batch(self, reviews: str, progress=None):
|
|
|
|
| 718 |
|
| 719 |
with gr.Column():
|
| 720 |
result_output = gr.Textbox(label="Result", lines=3)
|
| 721 |
+
heatmap_output = gr.HTML(label="Word Importance Heatmap")
|
| 722 |
|
| 723 |
with gr.Row():
|
| 724 |
prob_plot = gr.Plot(label="Probabilities")
|
|
|
|
| 762 |
analyze_btn.click(
|
| 763 |
app.analyze_single,
|
| 764 |
inputs=[text_input, theme_selector],
|
| 765 |
+
outputs=[result_output, prob_plot, gauge_plot, wordcloud_plot, keyword_plot, heatmap_output]
|
| 766 |
)
|
| 767 |
|
| 768 |
load_btn.click(app.data_handler.process_file, inputs=file_upload, outputs=batch_input)
|