entropy25 committed on
Commit
326a9a1
·
verified ·
1 Parent(s): 43f768a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -77
app.py CHANGED
@@ -17,7 +17,8 @@ from dataclasses import dataclass
17
  from typing import List, Dict, Optional, Tuple, Any, Callable
18
  from contextlib import contextmanager
19
  import gc
20
- import pandas as pd # Added missing import
 
21
 
22
  @dataclass
23
  class Config:
@@ -155,99 +156,109 @@ class HistoryManager:
155
 
156
  # Core Analysis Engine
157
  class SentimentEngine:
158
- """Streamlined sentiment analysis with attention-based keyword extraction"""
159
  def __init__(self):
160
  self.model_manager = ModelManager()
 
161
 
162
- def extract_key_words(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
163
- """Extract contributing words using BERT attention weights"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  try:
165
- inputs = self.model_manager.tokenizer(
166
- text, return_tensors="pt", padding=True,
167
- truncation=True, max_length=config.MAX_TEXT_LENGTH
168
- ).to(self.model_manager.device)
169
-
170
- # Get model outputs with attention weights
171
- with torch.no_grad():
172
- outputs = self.model_manager.model(**inputs, output_attentions=True)
173
- attention = outputs.attentions # Tuple of attention tensors for each layer
174
-
175
- # Use the last layer's attention, average over all heads
176
- last_attention = attention[-1] # Shape: [batch_size, num_heads, seq_len, seq_len]
177
- avg_attention = last_attention.mean(dim=1) # Average over heads: [batch_size, seq_len, seq_len]
178
-
179
- # Focus on attention to [CLS] token (index 0) as it represents the whole sequence
180
- cls_attention = avg_attention[0, 0, :] # Attention from CLS to all tokens
181
-
182
- # Get tokens and their attention scores
183
- tokens = self.model_manager.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
184
- attention_scores = cls_attention.cpu().numpy()
185
-
186
- # Filter out special tokens and combine subword tokens
187
- word_scores = {}
188
- current_word = ""
189
- current_score = 0.0
190
-
191
- for i, (token, score) in enumerate(zip(tokens, attention_scores)):
192
- if token in ['[CLS]', '[SEP]', '[PAD]']:
193
- continue
194
-
195
- if token.startswith('##'):
196
- # Subword token, add to current word
197
- current_word += token[2:]
198
- current_score = max(current_score, score) # Take max attention
199
- else:
200
- # New word, save previous if exists
201
- if current_word and len(current_word) >= config.MIN_WORD_LENGTH:
202
- word_scores[current_word.lower()] = current_score
203
-
204
- current_word = token
205
- current_score = score
206
-
207
- # Don't forget the last word
208
- if current_word and len(current_word) >= config.MIN_WORD_LENGTH:
209
- word_scores[current_word.lower()] = current_score
210
 
211
- # Filter out stop words and sort by attention score
212
- filtered_words = {
213
- word: score for word, score in word_scores.items()
214
- if word not in config.STOP_WORDS and len(word) >= config.MIN_WORD_LENGTH
215
- }
216
 
217
- # Sort by attention score and return top_k
218
- sorted_words = sorted(filtered_words.items(), key=lambda x: x[1], reverse=True)
219
- return sorted_words[:top_k]
220
 
221
  except Exception as e:
222
- logger.error(f"Key word extraction failed: {e}")
223
  return []
224
 
225
- @handle_errors(default_return={'sentiment': 'Unknown', 'confidence': 0.0, 'key_words': []})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  def analyze_single(self, text: str) -> Dict:
227
- """Analyze single text with key word extraction"""
228
  if not text.strip():
229
  raise ValueError("Empty text")
230
 
231
- inputs = self.model_manager.tokenizer(
232
- text, return_tensors="pt", padding=True,
233
- truncation=True, max_length=config.MAX_TEXT_LENGTH
234
- ).to(self.model_manager.device)
235
-
236
- with torch.no_grad():
237
- outputs = self.model_manager.model(**inputs)
238
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
239
-
240
  sentiment = "Positive" if probs[1] > probs[0] else "Negative"
241
 
242
- # Extract key contributing words
243
- key_words = self.extract_key_words(text)
 
 
 
 
244
 
245
  return {
246
  'sentiment': sentiment,
247
  'confidence': float(probs.max()),
248
  'pos_prob': float(probs[1]),
249
  'neg_prob': float(probs[0]),
250
- 'key_words': key_words
 
251
  }
252
 
253
  @handle_errors(default_return=[])
@@ -585,11 +596,11 @@ class SentimentApp:
585
  ]
586
 
587
 
588
- @handle_errors(default_return=("Please enter text", None, None, None, None))
589
  def analyze_single(self, text: str, theme: str = 'default'):
590
- """Single text analysis with key words"""
591
  if not text.strip():
592
- return "Please enter text", None, None, None, None
593
 
594
  result = self.engine.analyze_single(text)
595
 
@@ -614,7 +625,8 @@ class SentimentApp:
614
  result_text = (f"Sentiment: {result['sentiment']} (Confidence: {result['confidence']:.3f})\n"
615
  f"Key Words: {key_words_str}")
616
 
617
- return result_text, prob_plot, gauge_plot, cloud_plot, keyword_plot
 
618
 
619
  @handle_errors(default_return=None)
620
  def analyze_batch(self, reviews: str, progress=None):
@@ -706,6 +718,7 @@ def create_interface():
706
 
707
  with gr.Column():
708
  result_output = gr.Textbox(label="Result", lines=3)
 
709
 
710
  with gr.Row():
711
  prob_plot = gr.Plot(label="Probabilities")
@@ -749,7 +762,7 @@ def create_interface():
749
  analyze_btn.click(
750
  app.analyze_single,
751
  inputs=[text_input, theme_selector],
752
- outputs=[result_output, prob_plot, gauge_plot, wordcloud_plot, keyword_plot]
753
  )
754
 
755
  load_btn.click(app.data_handler.process_file, inputs=file_upload, outputs=batch_input)
 
17
  from typing import List, Dict, Optional, Tuple, Any, Callable
18
  from contextlib import contextmanager
19
  import gc
20
+ import pandas as pd
21
+ from lime.lime_text import LimeTextExplainer # Added LIME import
22
 
23
  @dataclass
24
  class Config:
 
156
 
157
  # Core Analysis Engine
158
  class SentimentEngine:
159
+ """Streamlined sentiment analysis with LIME-based keyword extraction"""
160
  def __init__(self):
161
  self.model_manager = ModelManager()
162
+ self.lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])
163
 
164
+ def predict_proba(self, texts):
165
+ """Prediction function for LIME"""
166
+ if isinstance(texts, str):
167
+ texts = [texts]
168
+
169
+ inputs = self.model_manager.tokenizer(
170
+ texts, return_tensors="pt", padding=True,
171
+ truncation=True, max_length=config.MAX_TEXT_LENGTH
172
+ ).to(self.model_manager.device)
173
+
174
+ with torch.no_grad():
175
+ outputs = self.model_manager.model(**inputs)
176
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()
177
+
178
+ return probs
179
+
180
+ def extract_key_words_lime(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
181
+ """Fast keyword extraction using LIME"""
182
  try:
183
+ # Get LIME explanation
184
+ explanation = self.lime_explainer.explain_instance(
185
+ text, self.predict_proba, num_features=top_k, num_samples=100
186
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
+ # Extract word importance scores
189
+ word_scores = []
190
+ for word, score in explanation.as_list():
191
+ if len(word.strip()) >= config.MIN_WORD_LENGTH:
192
+ word_scores.append((word.strip().lower(), abs(score)))
193
 
194
+ # Sort by importance and return top_k
195
+ word_scores.sort(key=lambda x: x[1], reverse=True)
196
+ return word_scores[:top_k]
197
 
198
  except Exception as e:
199
+ logger.error(f"LIME extraction failed: {e}")
200
  return []
201
 
202
+ def create_heatmap_html(self, text: str, word_scores: Dict[str, float]) -> str:
203
+ """Create HTML heatmap visualization"""
204
+ words = text.split()
205
+ html_parts = ['<div style="font-family: Arial; font-size: 16px; line-height: 1.6;">']
206
+
207
+ # Normalize scores for color intensity
208
+ if word_scores:
209
+ max_score = max(abs(score) for score in word_scores.values())
210
+ min_score = min(word_scores.values())
211
+ else:
212
+ max_score = min_score = 0
213
+
214
+ for word in words:
215
+ clean_word = re.sub(r'[^\w]', '', word.lower())
216
+ score = word_scores.get(clean_word, 0)
217
+
218
+ if score > 0:
219
+ # Positive contribution - green
220
+ intensity = min(255, int(180 * (score / max_score) if max_score > 0 else 0))
221
+ color = f"rgba(0, {intensity}, 0, 0.3)"
222
+ elif score < 0:
223
+ # Negative contribution - red
224
+ intensity = min(255, int(180 * (abs(score) / abs(min_score)) if min_score < 0 else 0))
225
+ color = f"rgba({intensity}, 0, 0, 0.3)"
226
+ else:
227
+ # Neutral - no highlighting
228
+ color = "transparent"
229
+
230
+ html_parts.append(
231
+ f'<span style="background-color: {color}; padding: 2px; margin: 1px; '
232
+ f'border-radius: 3px;" title="Score: {score:.3f}">{word}</span> '
233
+ )
234
+
235
+ html_parts.append('</div>')
236
+ return ''.join(html_parts)
237
+
238
+ @handle_errors(default_return={'sentiment': 'Unknown', 'confidence': 0.0, 'key_words': [], 'heatmap_html': ''})
239
  def analyze_single(self, text: str) -> Dict:
240
+ """Analyze single text with LIME explanation"""
241
  if not text.strip():
242
  raise ValueError("Empty text")
243
 
244
+ # Get sentiment prediction
245
+ probs = self.predict_proba([text])[0]
 
 
 
 
 
 
 
246
  sentiment = "Positive" if probs[1] > probs[0] else "Negative"
247
 
248
+ # Extract key words using LIME
249
+ key_words = self.extract_key_words_lime(text)
250
+
251
+ # Create heatmap HTML
252
+ word_scores_dict = dict(key_words)
253
+ heatmap_html = self.create_heatmap_html(text, word_scores_dict)
254
 
255
  return {
256
  'sentiment': sentiment,
257
  'confidence': float(probs.max()),
258
  'pos_prob': float(probs[1]),
259
  'neg_prob': float(probs[0]),
260
+ 'key_words': key_words,
261
+ 'heatmap_html': heatmap_html
262
  }
263
 
264
  @handle_errors(default_return=[])
 
596
  ]
597
 
598
 
599
+ @handle_errors(default_return=("Please enter text", None, None, None, None, None))
600
  def analyze_single(self, text: str, theme: str = 'default'):
601
+ """Single text analysis with LIME explanation and heatmap"""
602
  if not text.strip():
603
+ return "Please enter text", None, None, None, None, None
604
 
605
  result = self.engine.analyze_single(text)
606
 
 
625
  result_text = (f"Sentiment: {result['sentiment']} (Confidence: {result['confidence']:.3f})\n"
626
  f"Key Words: {key_words_str}")
627
 
628
+ # Return heatmap HTML as additional output
629
+ return result_text, prob_plot, gauge_plot, cloud_plot, keyword_plot, result['heatmap_html']
630
 
631
  @handle_errors(default_return=None)
632
  def analyze_batch(self, reviews: str, progress=None):
 
718
 
719
  with gr.Column():
720
  result_output = gr.Textbox(label="Result", lines=3)
721
+ heatmap_output = gr.HTML(label="Word Importance Heatmap")
722
 
723
  with gr.Row():
724
  prob_plot = gr.Plot(label="Probabilities")
 
762
  analyze_btn.click(
763
  app.analyze_single,
764
  inputs=[text_input, theme_selector],
765
+ outputs=[result_output, prob_plot, gauge_plot, wordcloud_plot, keyword_plot, heatmap_output]
766
  )
767
 
768
  load_btn.click(app.data_handler.process_file, inputs=file_upload, outputs=batch_input)