Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -19,6 +19,7 @@ from contextlib import contextmanager
|
|
| 19 |
import gc
|
| 20 |
import pandas as pd
|
| 21 |
from lime.lime_text import LimeTextExplainer
|
|
|
|
| 22 |
|
| 23 |
@dataclass
|
| 24 |
class Config:
|
|
@@ -156,10 +157,11 @@ class HistoryManager:
|
|
| 156 |
|
| 157 |
# Core Analysis Engine
|
| 158 |
class SentimentEngine:
|
| 159 |
-
"""Streamlined sentiment analysis engine"""
|
| 160 |
def __init__(self):
|
| 161 |
self.model_manager = ModelManager()
|
| 162 |
self.lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])
|
|
|
|
| 163 |
|
| 164 |
def predict_proba(self, texts):
|
| 165 |
"""Prediction function for LIME"""
|
|
@@ -212,6 +214,37 @@ class SentimentEngine:
|
|
| 212 |
logger.error(f"LIME extraction failed: {e}")
|
| 213 |
return []
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
def create_heatmap_html(self, text: str, word_scores: Dict[str, float]) -> str:
|
| 216 |
"""Create HTML heatmap visualization"""
|
| 217 |
words = text.split()
|
|
@@ -244,20 +277,21 @@ class SentimentEngine:
|
|
| 244 |
html_parts.append('</div>')
|
| 245 |
return ''.join(html_parts)
|
| 246 |
|
| 247 |
-
@handle_errors(default_return={'sentiment': 'Unknown', 'confidence': 0.0, '
|
| 248 |
def analyze_single_advanced(self, text: str) -> Dict:
|
| 249 |
-
"""Advanced single text analysis with LIME explanation"""
|
| 250 |
if not text.strip():
|
| 251 |
raise ValueError("Empty text")
|
| 252 |
|
| 253 |
probs = self.predict_proba([text])[0]
|
| 254 |
sentiment = "Positive" if probs[1] > probs[0] else "Negative"
|
| 255 |
|
| 256 |
-
# Extract key words using LIME
|
| 257 |
-
|
|
|
|
| 258 |
|
| 259 |
-
# Create heatmap HTML
|
| 260 |
-
word_scores_dict = dict(
|
| 261 |
heatmap_html = self.create_heatmap_html(text, word_scores_dict)
|
| 262 |
|
| 263 |
return {
|
|
@@ -265,7 +299,8 @@ class SentimentEngine:
|
|
| 265 |
'confidence': float(probs.max()),
|
| 266 |
'pos_prob': float(probs[1]),
|
| 267 |
'neg_prob': float(probs[0]),
|
| 268 |
-
'
|
|
|
|
| 269 |
'heatmap_html': heatmap_html
|
| 270 |
}
|
| 271 |
|
|
@@ -362,24 +397,54 @@ class PlotFactory:
|
|
| 362 |
|
| 363 |
@staticmethod
|
| 364 |
@handle_errors(default_return=None)
|
| 365 |
-
def
|
| 366 |
-
"""Create horizontal bar chart for key contributing words"""
|
| 367 |
-
if not
|
| 368 |
return None
|
| 369 |
|
| 370 |
with managed_figure(figsize=config.FIGURE_SIZE_SINGLE) as fig:
|
| 371 |
ax = fig.add_subplot(111)
|
| 372 |
|
| 373 |
-
words = [word for word, score in
|
| 374 |
-
scores = [score for word, score in
|
| 375 |
|
| 376 |
color = theme.colors['pos'] if sentiment == 'Positive' else theme.colors['neg']
|
| 377 |
|
| 378 |
bars = ax.barh(range(len(words)), scores, color=color, alpha=0.7)
|
| 379 |
ax.set_yticks(range(len(words)))
|
| 380 |
ax.set_yticklabels(words)
|
| 381 |
-
ax.set_xlabel('Attention Weight')
|
| 382 |
-
ax.set_title(f'Top Contributing Words ({sentiment})', fontweight='bold')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
|
| 384 |
for i, (bar, score) in enumerate(zip(bars, scores)):
|
| 385 |
ax.text(bar.get_width() + 0.001, bar.get_y() + bar.get_height()/2.,
|
|
@@ -580,11 +645,11 @@ class SentimentApp:
|
|
| 580 |
|
| 581 |
return result_text, prob_plot, gauge_plot, cloud_plot
|
| 582 |
|
| 583 |
-
@handle_errors(default_return=("Please enter text", None, None, None
|
| 584 |
def analyze_single_advanced(self, text: str, theme: str = 'default'):
|
| 585 |
-
"""Advanced single text analysis with LIME explanation"""
|
| 586 |
if not text.strip():
|
| 587 |
-
return "Please enter text", None, None, None
|
| 588 |
|
| 589 |
result = self.engine.analyze_single_advanced(text)
|
| 590 |
|
|
@@ -595,18 +660,18 @@ class SentimentApp:
|
|
| 595 |
})
|
| 596 |
|
| 597 |
theme_ctx = ThemeContext(theme)
|
| 598 |
-
probs = np.array([result['neg_prob'], result['pos_prob']])
|
| 599 |
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
|
|
|
| 604 |
|
| 605 |
-
key_words_str = ", ".join([f"{word}({score:.3f})" for word, score in result['key_words'][:5]])
|
| 606 |
result_text = (f"Sentiment: {result['sentiment']} (Confidence: {result['confidence']:.3f})\n"
|
| 607 |
-
f"Key Words: {
|
|
|
|
| 608 |
|
| 609 |
-
return result_text,
|
| 610 |
|
| 611 |
@handle_errors(default_return=None)
|
| 612 |
def analyze_batch(self, reviews: str, progress=None):
|
|
@@ -726,16 +791,14 @@ def create_interface():
|
|
| 726 |
)
|
| 727 |
|
| 728 |
with gr.Column():
|
| 729 |
-
adv_result_output = gr.Textbox(label="Analysis Result", lines=
|
| 730 |
-
heatmap_output = gr.HTML(label="Word Importance Heatmap")
|
| 731 |
|
| 732 |
with gr.Row():
|
| 733 |
-
|
| 734 |
-
|
| 735 |
|
| 736 |
with gr.Row():
|
| 737 |
-
|
| 738 |
-
keyword_plot = gr.Plot(label="Key Contributing Words")
|
| 739 |
|
| 740 |
with gr.Tab("Batch Analysis"):
|
| 741 |
with gr.Row():
|
|
@@ -778,7 +841,7 @@ def create_interface():
|
|
| 778 |
adv_analyze_btn.click(
|
| 779 |
app.analyze_single_advanced,
|
| 780 |
inputs=[adv_text_input, adv_theme_selector],
|
| 781 |
-
outputs=[adv_result_output,
|
| 782 |
)
|
| 783 |
|
| 784 |
# Event bindings for Batch Analysis
|
|
|
|
| 19 |
import gc
|
| 20 |
import pandas as pd
|
| 21 |
from lime.lime_text import LimeTextExplainer
|
| 22 |
+
import shap
|
| 23 |
|
| 24 |
@dataclass
|
| 25 |
class Config:
|
|
|
|
| 157 |
|
| 158 |
# Core Analysis Engine
|
| 159 |
class SentimentEngine:
|
| 160 |
+
"""Streamlined sentiment analysis engine with LIME and SHAP"""
|
| 161 |
def __init__(self):
|
| 162 |
self.model_manager = ModelManager()
|
| 163 |
self.lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])
|
| 164 |
+
self.shap_explainer = None
|
| 165 |
|
| 166 |
def predict_proba(self, texts):
|
| 167 |
"""Prediction function for LIME"""
|
|
|
|
| 214 |
logger.error(f"LIME extraction failed: {e}")
|
| 215 |
return []
|
| 216 |
|
| 217 |
+
def extract_key_words_shap(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
|
| 218 |
+
"""Advanced keyword extraction using SHAP"""
|
| 219 |
+
try:
|
| 220 |
+
# Initialize SHAP explainer if not already done
|
| 221 |
+
if self.shap_explainer is None:
|
| 222 |
+
self.shap_explainer = shap.Explainer(self.predict_proba, self.model_manager.tokenizer)
|
| 223 |
+
|
| 224 |
+
# Get SHAP values
|
| 225 |
+
shap_values = self.shap_explainer([text])
|
| 226 |
+
|
| 227 |
+
# Extract word importance
|
| 228 |
+
words = text.split()
|
| 229 |
+
if len(shap_values.values) > 0 and len(shap_values.values[0]) > 0:
|
| 230 |
+
# Get positive class SHAP values
|
| 231 |
+
pos_shap_values = shap_values.values[0][:, 1] if len(shap_values.values[0].shape) > 1 else shap_values.values[0]
|
| 232 |
+
|
| 233 |
+
word_scores = []
|
| 234 |
+
for i, word in enumerate(words[:len(pos_shap_values)]):
|
| 235 |
+
clean_word = re.sub(r'[^\w]', '', word.lower())
|
| 236 |
+
if len(clean_word) >= config.MIN_WORD_LENGTH:
|
| 237 |
+
word_scores.append((clean_word, abs(float(pos_shap_values[i]))))
|
| 238 |
+
|
| 239 |
+
word_scores.sort(key=lambda x: x[1], reverse=True)
|
| 240 |
+
return word_scores[:top_k]
|
| 241 |
+
|
| 242 |
+
return []
|
| 243 |
+
|
| 244 |
+
except Exception as e:
|
| 245 |
+
logger.error(f"SHAP extraction failed: {e}")
|
| 246 |
+
return []
|
| 247 |
+
|
| 248 |
def create_heatmap_html(self, text: str, word_scores: Dict[str, float]) -> str:
|
| 249 |
"""Create HTML heatmap visualization"""
|
| 250 |
words = text.split()
|
|
|
|
| 277 |
html_parts.append('</div>')
|
| 278 |
return ''.join(html_parts)
|
| 279 |
|
| 280 |
+
@handle_errors(default_return={'sentiment': 'Unknown', 'confidence': 0.0, 'lime_words': [], 'shap_words': [], 'heatmap_html': ''})
|
| 281 |
def analyze_single_advanced(self, text: str) -> Dict:
|
| 282 |
+
"""Advanced single text analysis with LIME and SHAP explanation"""
|
| 283 |
if not text.strip():
|
| 284 |
raise ValueError("Empty text")
|
| 285 |
|
| 286 |
probs = self.predict_proba([text])[0]
|
| 287 |
sentiment = "Positive" if probs[1] > probs[0] else "Negative"
|
| 288 |
|
| 289 |
+
# Extract key words using both LIME and SHAP
|
| 290 |
+
lime_words = self.extract_key_words_lime(text)
|
| 291 |
+
shap_words = self.extract_key_words_shap(text)
|
| 292 |
|
| 293 |
+
# Create heatmap HTML using LIME results
|
| 294 |
+
word_scores_dict = dict(lime_words)
|
| 295 |
heatmap_html = self.create_heatmap_html(text, word_scores_dict)
|
| 296 |
|
| 297 |
return {
|
|
|
|
| 299 |
'confidence': float(probs.max()),
|
| 300 |
'pos_prob': float(probs[1]),
|
| 301 |
'neg_prob': float(probs[0]),
|
| 302 |
+
'lime_words': lime_words,
|
| 303 |
+
'shap_words': shap_words,
|
| 304 |
'heatmap_html': heatmap_html
|
| 305 |
}
|
| 306 |
|
|
|
|
| 397 |
|
| 398 |
@staticmethod
|
| 399 |
@handle_errors(default_return=None)
|
| 400 |
+
def create_lime_keyword_chart(lime_words: List[Tuple[str, float]], sentiment: str, theme: ThemeContext) -> Optional[plt.Figure]:
|
| 401 |
+
"""Create horizontal bar chart for LIME key contributing words"""
|
| 402 |
+
if not lime_words:
|
| 403 |
return None
|
| 404 |
|
| 405 |
with managed_figure(figsize=config.FIGURE_SIZE_SINGLE) as fig:
|
| 406 |
ax = fig.add_subplot(111)
|
| 407 |
|
| 408 |
+
words = [word for word, score in lime_words]
|
| 409 |
+
scores = [score for word, score in lime_words]
|
| 410 |
|
| 411 |
color = theme.colors['pos'] if sentiment == 'Positive' else theme.colors['neg']
|
| 412 |
|
| 413 |
bars = ax.barh(range(len(words)), scores, color=color, alpha=0.7)
|
| 414 |
ax.set_yticks(range(len(words)))
|
| 415 |
ax.set_yticklabels(words)
|
| 416 |
+
ax.set_xlabel('LIME Attention Weight')
|
| 417 |
+
ax.set_title(f'LIME: Top Contributing Words ({sentiment})', fontweight='bold')
|
| 418 |
+
|
| 419 |
+
for i, (bar, score) in enumerate(zip(bars, scores)):
|
| 420 |
+
ax.text(bar.get_width() + 0.001, bar.get_y() + bar.get_height()/2.,
|
| 421 |
+
f'{score:.3f}', ha='left', va='center', fontsize=9)
|
| 422 |
+
|
| 423 |
+
ax.invert_yaxis()
|
| 424 |
+
ax.grid(axis='x', alpha=0.3)
|
| 425 |
+
fig.tight_layout()
|
| 426 |
+
return fig
|
| 427 |
+
|
| 428 |
+
@staticmethod
|
| 429 |
+
@handle_errors(default_return=None)
|
| 430 |
+
def create_shap_keyword_chart(shap_words: List[Tuple[str, float]], sentiment: str, theme: ThemeContext) -> Optional[plt.Figure]:
|
| 431 |
+
"""Create horizontal bar chart for SHAP key contributing words"""
|
| 432 |
+
if not shap_words:
|
| 433 |
+
return None
|
| 434 |
+
|
| 435 |
+
with managed_figure(figsize=config.FIGURE_SIZE_SINGLE) as fig:
|
| 436 |
+
ax = fig.add_subplot(111)
|
| 437 |
+
|
| 438 |
+
words = [word for word, score in shap_words]
|
| 439 |
+
scores = [score for word, score in shap_words]
|
| 440 |
+
|
| 441 |
+
color = theme.colors['pos'] if sentiment == 'Positive' else theme.colors['neg']
|
| 442 |
+
|
| 443 |
+
bars = ax.barh(range(len(words)), scores, color=color, alpha=0.7)
|
| 444 |
+
ax.set_yticks(range(len(words)))
|
| 445 |
+
ax.set_yticklabels(words)
|
| 446 |
+
ax.set_xlabel('SHAP Value')
|
| 447 |
+
ax.set_title(f'SHAP: Top Contributing Words ({sentiment})', fontweight='bold')
|
| 448 |
|
| 449 |
for i, (bar, score) in enumerate(zip(bars, scores)):
|
| 450 |
ax.text(bar.get_width() + 0.001, bar.get_y() + bar.get_height()/2.,
|
|
|
|
| 645 |
|
| 646 |
return result_text, prob_plot, gauge_plot, cloud_plot
|
| 647 |
|
| 648 |
+
@handle_errors(default_return=("Please enter text", None, None, None))
|
| 649 |
def analyze_single_advanced(self, text: str, theme: str = 'default'):
|
| 650 |
+
"""Advanced single text analysis with LIME and SHAP explanation"""
|
| 651 |
if not text.strip():
|
| 652 |
+
return "Please enter text", None, None, None
|
| 653 |
|
| 654 |
result = self.engine.analyze_single_advanced(text)
|
| 655 |
|
|
|
|
| 660 |
})
|
| 661 |
|
| 662 |
theme_ctx = ThemeContext(theme)
|
|
|
|
| 663 |
|
| 664 |
+
lime_plot = PlotFactory.create_lime_keyword_chart(result['lime_words'], result['sentiment'], theme_ctx)
|
| 665 |
+
shap_plot = PlotFactory.create_shap_keyword_chart(result['shap_words'], result['sentiment'], theme_ctx)
|
| 666 |
+
|
| 667 |
+
lime_words_str = ", ".join([f"{word}({score:.3f})" for word, score in result['lime_words'][:5]])
|
| 668 |
+
shap_words_str = ", ".join([f"{word}({score:.3f})" for word, score in result['shap_words'][:5]])
|
| 669 |
|
|
|
|
| 670 |
result_text = (f"Sentiment: {result['sentiment']} (Confidence: {result['confidence']:.3f})\n"
|
| 671 |
+
f"LIME Key Words: {lime_words_str}\n"
|
| 672 |
+
f"SHAP Key Words: {shap_words_str}")
|
| 673 |
|
| 674 |
+
return result_text, lime_plot, shap_plot, result['heatmap_html']
|
| 675 |
|
| 676 |
@handle_errors(default_return=None)
|
| 677 |
def analyze_batch(self, reviews: str, progress=None):
|
|
|
|
| 791 |
)
|
| 792 |
|
| 793 |
with gr.Column():
|
| 794 |
+
adv_result_output = gr.Textbox(label="Analysis Result", lines=4)
|
|
|
|
| 795 |
|
| 796 |
with gr.Row():
|
| 797 |
+
lime_plot = gr.Plot(label="LIME: Key Contributing Words")
|
| 798 |
+
shap_plot = gr.Plot(label="SHAP: Key Contributing Words")
|
| 799 |
|
| 800 |
with gr.Row():
|
| 801 |
+
heatmap_output = gr.HTML(label="Word Importance Heatmap (LIME-based)")
|
|
|
|
| 802 |
|
| 803 |
with gr.Tab("Batch Analysis"):
|
| 804 |
with gr.Row():
|
|
|
|
| 841 |
adv_analyze_btn.click(
|
| 842 |
app.analyze_single_advanced,
|
| 843 |
inputs=[adv_text_input, adv_theme_selector],
|
| 844 |
+
outputs=[adv_result_output, lime_plot, shap_plot, heatmap_output]
|
| 845 |
)
|
| 846 |
|
| 847 |
# Event bindings for Batch Analysis
|