entropy25 commited on
Commit
3644b14
·
verified ·
1 Parent(s): 7f8f92d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -33
app.py CHANGED
@@ -19,6 +19,7 @@ from contextlib import contextmanager
19
  import gc
20
  import pandas as pd
21
  from lime.lime_text import LimeTextExplainer
 
22
 
23
  @dataclass
24
  class Config:
@@ -156,10 +157,11 @@ class HistoryManager:
156
 
157
  # Core Analysis Engine
158
  class SentimentEngine:
159
- """Streamlined sentiment analysis engine"""
160
  def __init__(self):
161
  self.model_manager = ModelManager()
162
  self.lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])
 
163
 
164
  def predict_proba(self, texts):
165
  """Prediction function for LIME"""
@@ -212,6 +214,37 @@ class SentimentEngine:
212
  logger.error(f"LIME extraction failed: {e}")
213
  return []
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  def create_heatmap_html(self, text: str, word_scores: Dict[str, float]) -> str:
216
  """Create HTML heatmap visualization"""
217
  words = text.split()
@@ -244,20 +277,21 @@ class SentimentEngine:
244
  html_parts.append('</div>')
245
  return ''.join(html_parts)
246
 
247
- @handle_errors(default_return={'sentiment': 'Unknown', 'confidence': 0.0, 'key_words': [], 'heatmap_html': ''})
248
  def analyze_single_advanced(self, text: str) -> Dict:
249
- """Advanced single text analysis with LIME explanation"""
250
  if not text.strip():
251
  raise ValueError("Empty text")
252
 
253
  probs = self.predict_proba([text])[0]
254
  sentiment = "Positive" if probs[1] > probs[0] else "Negative"
255
 
256
- # Extract key words using LIME
257
- key_words = self.extract_key_words_lime(text)
 
258
 
259
- # Create heatmap HTML
260
- word_scores_dict = dict(key_words)
261
  heatmap_html = self.create_heatmap_html(text, word_scores_dict)
262
 
263
  return {
@@ -265,7 +299,8 @@ class SentimentEngine:
265
  'confidence': float(probs.max()),
266
  'pos_prob': float(probs[1]),
267
  'neg_prob': float(probs[0]),
268
- 'key_words': key_words,
 
269
  'heatmap_html': heatmap_html
270
  }
271
 
@@ -362,24 +397,54 @@ class PlotFactory:
362
 
363
  @staticmethod
364
  @handle_errors(default_return=None)
365
- def create_keyword_chart(key_words: List[Tuple[str, float]], sentiment: str, theme: ThemeContext) -> Optional[plt.Figure]:
366
- """Create horizontal bar chart for key contributing words"""
367
- if not key_words:
368
  return None
369
 
370
  with managed_figure(figsize=config.FIGURE_SIZE_SINGLE) as fig:
371
  ax = fig.add_subplot(111)
372
 
373
- words = [word for word, score in key_words]
374
- scores = [score for word, score in key_words]
375
 
376
  color = theme.colors['pos'] if sentiment == 'Positive' else theme.colors['neg']
377
 
378
  bars = ax.barh(range(len(words)), scores, color=color, alpha=0.7)
379
  ax.set_yticks(range(len(words)))
380
  ax.set_yticklabels(words)
381
- ax.set_xlabel('Attention Weight')
382
- ax.set_title(f'Top Contributing Words ({sentiment})', fontweight='bold')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
 
384
  for i, (bar, score) in enumerate(zip(bars, scores)):
385
  ax.text(bar.get_width() + 0.001, bar.get_y() + bar.get_height()/2.,
@@ -580,11 +645,11 @@ class SentimentApp:
580
 
581
  return result_text, prob_plot, gauge_plot, cloud_plot
582
 
583
- @handle_errors(default_return=("Please enter text", None, None, None, None, None))
584
  def analyze_single_advanced(self, text: str, theme: str = 'default'):
585
- """Advanced single text analysis with LIME explanation"""
586
  if not text.strip():
587
- return "Please enter text", None, None, None, None, None
588
 
589
  result = self.engine.analyze_single_advanced(text)
590
 
@@ -595,18 +660,18 @@ class SentimentApp:
595
  })
596
 
597
  theme_ctx = ThemeContext(theme)
598
- probs = np.array([result['neg_prob'], result['pos_prob']])
599
 
600
- prob_plot = PlotFactory.create_sentiment_bars(probs, theme_ctx)
601
- gauge_plot = PlotFactory.create_confidence_gauge(result['confidence'], result['sentiment'], theme_ctx)
602
- cloud_plot = PlotFactory.create_wordcloud(text, result['sentiment'], theme_ctx)
603
- keyword_plot = PlotFactory.create_keyword_chart(result['key_words'], result['sentiment'], theme_ctx)
 
604
 
605
- key_words_str = ", ".join([f"{word}({score:.3f})" for word, score in result['key_words'][:5]])
606
  result_text = (f"Sentiment: {result['sentiment']} (Confidence: {result['confidence']:.3f})\n"
607
- f"Key Words: {key_words_str}")
 
608
 
609
- return result_text, prob_plot, gauge_plot, cloud_plot, keyword_plot, result['heatmap_html']
610
 
611
  @handle_errors(default_return=None)
612
  def analyze_batch(self, reviews: str, progress=None):
@@ -726,16 +791,14 @@ def create_interface():
726
  )
727
 
728
  with gr.Column():
729
- adv_result_output = gr.Textbox(label="Analysis Result", lines=3)
730
- heatmap_output = gr.HTML(label="Word Importance Heatmap")
731
 
732
  with gr.Row():
733
- adv_prob_plot = gr.Plot(label="Probabilities")
734
- adv_gauge_plot = gr.Plot(label="Confidence")
735
 
736
  with gr.Row():
737
- adv_wordcloud_plot = gr.Plot(label="Word Cloud")
738
- keyword_plot = gr.Plot(label="Key Contributing Words")
739
 
740
  with gr.Tab("Batch Analysis"):
741
  with gr.Row():
@@ -778,7 +841,7 @@ def create_interface():
778
  adv_analyze_btn.click(
779
  app.analyze_single_advanced,
780
  inputs=[adv_text_input, adv_theme_selector],
781
- outputs=[adv_result_output, adv_prob_plot, adv_gauge_plot, adv_wordcloud_plot, keyword_plot, heatmap_output]
782
  )
783
 
784
  # Event bindings for Batch Analysis
 
19
  import gc
20
  import pandas as pd
21
  from lime.lime_text import LimeTextExplainer
22
+ import shap
23
 
24
  @dataclass
25
  class Config:
 
157
 
158
  # Core Analysis Engine
159
  class SentimentEngine:
160
+ """Streamlined sentiment analysis engine with LIME and SHAP"""
161
  def __init__(self):
162
  self.model_manager = ModelManager()
163
  self.lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])
164
+ self.shap_explainer = None
165
 
166
  def predict_proba(self, texts):
167
  """Prediction function for LIME"""
 
214
  logger.error(f"LIME extraction failed: {e}")
215
  return []
216
 
217
+ def extract_key_words_shap(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
218
+ """Advanced keyword extraction using SHAP"""
219
+ try:
220
+ # Initialize SHAP explainer if not already done
221
+ if self.shap_explainer is None:
222
+ self.shap_explainer = shap.Explainer(self.predict_proba, self.model_manager.tokenizer)
223
+
224
+ # Get SHAP values
225
+ shap_values = self.shap_explainer([text])
226
+
227
+ # Extract word importance
228
+ words = text.split()
229
+ if len(shap_values.values) > 0 and len(shap_values.values[0]) > 0:
230
+ # Get positive class SHAP values
231
+ pos_shap_values = shap_values.values[0][:, 1] if len(shap_values.values[0].shape) > 1 else shap_values.values[0]
232
+
233
+ word_scores = []
234
+ for i, word in enumerate(words[:len(pos_shap_values)]):
235
+ clean_word = re.sub(r'[^\w]', '', word.lower())
236
+ if len(clean_word) >= config.MIN_WORD_LENGTH:
237
+ word_scores.append((clean_word, abs(float(pos_shap_values[i]))))
238
+
239
+ word_scores.sort(key=lambda x: x[1], reverse=True)
240
+ return word_scores[:top_k]
241
+
242
+ return []
243
+
244
+ except Exception as e:
245
+ logger.error(f"SHAP extraction failed: {e}")
246
+ return []
247
+
248
  def create_heatmap_html(self, text: str, word_scores: Dict[str, float]) -> str:
249
  """Create HTML heatmap visualization"""
250
  words = text.split()
 
277
  html_parts.append('</div>')
278
  return ''.join(html_parts)
279
 
280
+ @handle_errors(default_return={'sentiment': 'Unknown', 'confidence': 0.0, 'lime_words': [], 'shap_words': [], 'heatmap_html': ''})
281
  def analyze_single_advanced(self, text: str) -> Dict:
282
+ """Advanced single text analysis with LIME and SHAP explanation"""
283
  if not text.strip():
284
  raise ValueError("Empty text")
285
 
286
  probs = self.predict_proba([text])[0]
287
  sentiment = "Positive" if probs[1] > probs[0] else "Negative"
288
 
289
+ # Extract key words using both LIME and SHAP
290
+ lime_words = self.extract_key_words_lime(text)
291
+ shap_words = self.extract_key_words_shap(text)
292
 
293
+ # Create heatmap HTML using LIME results
294
+ word_scores_dict = dict(lime_words)
295
  heatmap_html = self.create_heatmap_html(text, word_scores_dict)
296
 
297
  return {
 
299
  'confidence': float(probs.max()),
300
  'pos_prob': float(probs[1]),
301
  'neg_prob': float(probs[0]),
302
+ 'lime_words': lime_words,
303
+ 'shap_words': shap_words,
304
  'heatmap_html': heatmap_html
305
  }
306
 
 
397
 
398
  @staticmethod
399
  @handle_errors(default_return=None)
400
+ def create_lime_keyword_chart(lime_words: List[Tuple[str, float]], sentiment: str, theme: ThemeContext) -> Optional[plt.Figure]:
401
+ """Create horizontal bar chart for LIME key contributing words"""
402
+ if not lime_words:
403
  return None
404
 
405
  with managed_figure(figsize=config.FIGURE_SIZE_SINGLE) as fig:
406
  ax = fig.add_subplot(111)
407
 
408
+ words = [word for word, score in lime_words]
409
+ scores = [score for word, score in lime_words]
410
 
411
  color = theme.colors['pos'] if sentiment == 'Positive' else theme.colors['neg']
412
 
413
  bars = ax.barh(range(len(words)), scores, color=color, alpha=0.7)
414
  ax.set_yticks(range(len(words)))
415
  ax.set_yticklabels(words)
416
+ ax.set_xlabel('LIME Attention Weight')
417
+ ax.set_title(f'LIME: Top Contributing Words ({sentiment})', fontweight='bold')
418
+
419
+ for i, (bar, score) in enumerate(zip(bars, scores)):
420
+ ax.text(bar.get_width() + 0.001, bar.get_y() + bar.get_height()/2.,
421
+ f'{score:.3f}', ha='left', va='center', fontsize=9)
422
+
423
+ ax.invert_yaxis()
424
+ ax.grid(axis='x', alpha=0.3)
425
+ fig.tight_layout()
426
+ return fig
427
+
428
+ @staticmethod
429
+ @handle_errors(default_return=None)
430
+ def create_shap_keyword_chart(shap_words: List[Tuple[str, float]], sentiment: str, theme: ThemeContext) -> Optional[plt.Figure]:
431
+ """Create horizontal bar chart for SHAP key contributing words"""
432
+ if not shap_words:
433
+ return None
434
+
435
+ with managed_figure(figsize=config.FIGURE_SIZE_SINGLE) as fig:
436
+ ax = fig.add_subplot(111)
437
+
438
+ words = [word for word, score in shap_words]
439
+ scores = [score for word, score in shap_words]
440
+
441
+ color = theme.colors['pos'] if sentiment == 'Positive' else theme.colors['neg']
442
+
443
+ bars = ax.barh(range(len(words)), scores, color=color, alpha=0.7)
444
+ ax.set_yticks(range(len(words)))
445
+ ax.set_yticklabels(words)
446
+ ax.set_xlabel('SHAP Value')
447
+ ax.set_title(f'SHAP: Top Contributing Words ({sentiment})', fontweight='bold')
448
 
449
  for i, (bar, score) in enumerate(zip(bars, scores)):
450
  ax.text(bar.get_width() + 0.001, bar.get_y() + bar.get_height()/2.,
 
645
 
646
  return result_text, prob_plot, gauge_plot, cloud_plot
647
 
648
+ @handle_errors(default_return=("Please enter text", None, None, None))
649
  def analyze_single_advanced(self, text: str, theme: str = 'default'):
650
+ """Advanced single text analysis with LIME and SHAP explanation"""
651
  if not text.strip():
652
+ return "Please enter text", None, None, None
653
 
654
  result = self.engine.analyze_single_advanced(text)
655
 
 
660
  })
661
 
662
  theme_ctx = ThemeContext(theme)
 
663
 
664
+ lime_plot = PlotFactory.create_lime_keyword_chart(result['lime_words'], result['sentiment'], theme_ctx)
665
+ shap_plot = PlotFactory.create_shap_keyword_chart(result['shap_words'], result['sentiment'], theme_ctx)
666
+
667
+ lime_words_str = ", ".join([f"{word}({score:.3f})" for word, score in result['lime_words'][:5]])
668
+ shap_words_str = ", ".join([f"{word}({score:.3f})" for word, score in result['shap_words'][:5]])
669
 
 
670
  result_text = (f"Sentiment: {result['sentiment']} (Confidence: {result['confidence']:.3f})\n"
671
+ f"LIME Key Words: {lime_words_str}\n"
672
+ f"SHAP Key Words: {shap_words_str}")
673
 
674
+ return result_text, lime_plot, shap_plot, result['heatmap_html']
675
 
676
  @handle_errors(default_return=None)
677
  def analyze_batch(self, reviews: str, progress=None):
 
791
  )
792
 
793
  with gr.Column():
794
+ adv_result_output = gr.Textbox(label="Analysis Result", lines=4)
 
795
 
796
  with gr.Row():
797
+ lime_plot = gr.Plot(label="LIME: Key Contributing Words")
798
+ shap_plot = gr.Plot(label="SHAP: Key Contributing Words")
799
 
800
  with gr.Row():
801
+ heatmap_output = gr.HTML(label="Word Importance Heatmap (LIME-based)")
 
802
 
803
  with gr.Tab("Batch Analysis"):
804
  with gr.Row():
 
841
  adv_analyze_btn.click(
842
  app.analyze_single_advanced,
843
  inputs=[adv_text_input, adv_theme_selector],
844
+ outputs=[adv_result_output, lime_plot, shap_plot, heatmap_output]
845
  )
846
 
847
  # Event bindings for Batch Analysis