Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,6 +2,7 @@ import numpy as np
|
|
| 2 |
import torch
|
| 3 |
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
| 4 |
import shap
|
|
|
|
| 5 |
|
| 6 |
# Load pre-trained model and tokenizer
|
| 7 |
model_name = 'gpt2'
|
|
@@ -18,42 +19,20 @@ def predict_next_word_probabilities(text):
|
|
| 18 |
predicted_probabilities = torch.nn.functional.softmax(logits[:, -1], dim=-1)
|
| 19 |
return predicted_probabilities.squeeze().numpy()
|
| 20 |
|
| 21 |
-
#
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
# Define a function to get SHAP values
|
| 26 |
-
def shap_values(sentence, tokenizer, model):
    """Score two hand-made perturbations of *sentence* by their mean effect.

    Two variants of the sentence are built: one with the first "the"
    removed, and one with every "mat" replaced by "rug".  For each variant
    the next-word probability vector is subtracted from the one obtained
    for the unmodified sentence, and the mean of that difference is stored.

    This is a simplified stand-in, not a real Shapley computation.

    ``tokenizer`` and ``model`` are accepted but not used in this body; the
    prediction helper is called with the text only.

    Returns a dict mapping each perturbed sentence to its mean probability
    change.  If both perturbations produce the same string, the later entry
    overwrites the earlier one.
    """
    # Perturbed versions of the input sentence.
    variants = (
        sentence.replace("the", "", 1),  # Remove the first occurrence of "the"
        sentence.replace("mat", "rug"),  # Replace "mat" with "rug" (if it was present)
    )

    # Reference distribution for the unmodified sentence.
    reference = predict_next_word_probabilities(sentence)

    # Mean shift of the distribution for each variant (simplified).
    effects = {}
    for variant in variants:
        shifted = predict_next_word_probabilities(variant)
        effects[variant] = (reference - shifted).mean()

    return effects
|
| 44 |
|
| 45 |
-
#
|
| 46 |
-
|
| 47 |
-
|
| 48 |
|
| 49 |
-
#
|
| 50 |
-
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
plt.title('SHAP Values for Next-Word Prediction')
|
| 59 |
-
plt.show()
|
|
|
|
| 2 |
import torch
|
| 3 |
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
| 4 |
import shap
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
|
| 7 |
# Load pre-trained model and tokenizer
|
| 8 |
model_name = 'gpt2'
|
|
|
|
| 19 |
predicted_probabilities = torch.nn.functional.softmax(logits[:, -1], dim=-1)
|
| 20 |
return predicted_probabilities.squeeze().numpy()
|
| 21 |
|
| 22 |
+
# Define a function to wrap the model prediction for SHAP
|
| 23 |
+
def predict(texts):
    """Return, for each text, the probability that the next word is "mat".

    This is the scalar-output wrapper handed to SHAP: it takes an iterable
    of strings and returns a 1-D NumPy array with one probability per text.

    Relies on the module-level ``tokenizer`` and on
    ``predict_next_word_probabilities`` defined earlier in this file.
    """
    # GPT-2's byte-level BPE treats the preceding space as part of the token,
    # so the word "mat" following "... on the" is the token for " mat".
    # Encoding "mat" without the space selects a word-continuation piece
    # instead.  The id does not depend on the input, so look it up once
    # rather than re-encoding it for every text.
    target_token_id = tokenizer.encode(" mat")[0]
    return np.array(
        [predict_next_word_probabilities(text)[target_token_id] for text in texts]
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
# Define the initial text and the variants to explain.
initial_text = "The cat sat on the"
# NOTE: str.replace is case-sensitive, so the first variant drops the
# trailing "the" (not the leading "The").  "mat" does not occur in
# initial_text, so the last variant is identical to initial_text.
texts = [initial_text, initial_text.replace("the", "", 1), initial_text.replace("mat", "rug")]

# Create an explainer object.
# `predict` consumes raw strings, so SHAP needs a text masker that knows how
# to hide individual tokens of a string; a plain list of strings is not a
# masker.  Build it from the same tokenizer the model uses.
explainer = shap.Explainer(predict, shap.maskers.Text(tokenizer))

# Generate SHAP values (one attribution per token, per input text).
shap_values = explainer(texts)

# Visualization
shap.initjs()
# Text explanations are ragged (each input has its own token count), so the
# tabular summary plot does not apply; use the text plot, which renders the
# per-token attributions for each input.
shap.plots.text(shap_values)
|
|
|
|
|
|