vangru commited on
Commit
67abd79
·
verified ·
1 Parent(s): 326a4f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -9
app.py CHANGED
@@ -1,19 +1,22 @@
1
  import gradio as gr
2
  import shap
 
 
 
3
  from transformers import pipeline
4
 
5
- # Use lighter model (important for HF)
6
  classifier = pipeline(
7
  "sentiment-analysis",
8
  model="distilbert-base-uncased-finetuned-sst-2-english"
9
  )
10
 
11
- # Create SHAP explainer once
12
  explainer = shap.Explainer(classifier)
13
 
14
  def analyze(text):
15
  if not text.strip():
16
- return "Please enter text", ""
17
 
18
  # Prediction
19
  result = classifier(text)[0]
@@ -23,20 +26,25 @@ def analyze(text):
23
  # SHAP values
24
  shap_values = explainer([text])
25
 
26
- # Convert SHAP to HTML
27
- shap_html = shap.plots.text(shap_values[0], display=False)
28
 
29
- return f"Prediction: {label} (Confidence: {score:.2f})", shap_html
 
 
 
 
 
 
30
 
31
  with gr.Blocks() as demo:
32
  gr.Markdown("# Sentiment Analysis with SHAP")
33
 
34
  inp = gr.Textbox(lines=4, placeholder="Enter text here...")
35
  prediction = gr.Textbox(label="Prediction")
36
- shap_output = gr.HTML(label="SHAP Explanation")
37
 
38
  btn = gr.Button("Analyze")
39
-
40
- btn.click(analyze, inp, [prediction, shap_output])
41
 
42
  demo.launch()
 
1
  import gradio as gr
2
  import shap
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
  from transformers import pipeline
7
 
8
+ # Load lightweight model
9
  classifier = pipeline(
10
  "sentiment-analysis",
11
  model="distilbert-base-uncased-finetuned-sst-2-english"
12
  )
13
 
14
+ # Create explainer
15
  explainer = shap.Explainer(classifier)
16
 
17
  def analyze(text):
18
  if not text.strip():
19
+ return "Please enter text", None
20
 
21
  # Prediction
22
  result = classifier(text)[0]
 
26
  # SHAP values
27
  shap_values = explainer([text])
28
 
29
+ tokens = shap_values[0].data
30
+ values = shap_values[0].values
31
 
32
+ # Create bar plot
33
+ plt.figure()
34
+ plt.barh(tokens, values)
35
+ plt.xlabel("SHAP Value")
36
+ plt.title("Word Contribution to Sentiment")
37
+
38
+ return f"Prediction: {label} (Confidence: {score:.2f})", plt.gcf()
39
 
40
  with gr.Blocks() as demo:
41
  gr.Markdown("# Sentiment Analysis with SHAP")
42
 
43
  inp = gr.Textbox(lines=4, placeholder="Enter text here...")
44
  prediction = gr.Textbox(label="Prediction")
45
+ shap_plot = gr.Plot(label="SHAP Explanation")
46
 
47
  btn = gr.Button("Analyze")
48
+ btn.click(analyze, inp, [prediction, shap_plot])
 
49
 
50
  demo.launch()