pratikshahp commited on
Commit
72a8b0c
·
verified ·
1 Parent(s): 427d981

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -35
app.py CHANGED
@@ -2,6 +2,7 @@ import numpy as np
2
  import torch
3
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
4
  import shap
 
5
 
6
  # Load pre-trained model and tokenizer
7
  model_name = 'gpt2'
@@ -18,42 +19,20 @@ def predict_next_word_probabilities(text):
18
  predicted_probabilities = torch.nn.functional.softmax(logits[:, -1], dim=-1)
19
  return predicted_probabilities.squeeze().numpy()
20
 
21
- # Example sentence
22
- sentence = "The cat sat on the"
23
- predicted_prob = predict_next_word_probabilities(sentence)
24
-
25
- # Define a function to get SHAP values
26
- def shap_values(sentence, tokenizer, model):
27
- # Create a set of perturbed versions of the input sentence
28
- perturbed_sentences = [
29
- sentence.replace("the", "", 1), # Remove the first occurrence of "the"
30
- sentence.replace("mat", "rug"), # Replace "mat" with "rug" (if it was present)
31
- ]
32
-
33
- # Calculate the predicted probabilities for each perturbed sentence
34
- base_prob = predict_next_word_probabilities(sentence)
35
- perturbed_probs = [predict_next_word_probabilities(s) for s in perturbed_sentences]
36
-
37
- # Compute SHAP values (simplified)
38
- shap_values = {}
39
- for idx, perturbed_sentence in enumerate(perturbed_sentences):
40
- change = base_prob - perturbed_probs[idx]
41
- shap_values[perturbed_sentence] = change.mean()
42
-
43
- return shap_values
44
 
45
- # Compute SHAP values
46
- shap_values = shap_values(sentence, tokenizer, model)
47
- print("SHAP values:", shap_values)
48
 
49
- # Visualization using matplotlib (optional)
50
- import matplotlib.pyplot as plt
51
 
52
- words = list(shap_values.keys())
53
- values = list(shap_values.values())
54
 
55
- plt.bar(words, values)
56
- plt.xlabel('Perturbed Sentences')
57
- plt.ylabel('SHAP Value')
58
- plt.title('SHAP Values for Next-Word Prediction')
59
- plt.show()
 
2
  import torch
3
  from transformers import GPT2Tokenizer, GPT2LMHeadModel
4
  import shap
5
+ import matplotlib.pyplot as plt
6
 
7
  # Load pre-trained model and tokenizer
8
  model_name = 'gpt2'
 
19
  predicted_probabilities = torch.nn.functional.softmax(logits[:, -1], dim=-1)
20
  return predicted_probabilities.squeeze().numpy()
21
 
22
def predict(texts, target_word="mat"):
    """Wrap the model so SHAP can call it as a plain ``texts -> scores`` function.

    For each input text, returns the model's predicted probability that the
    next token is the first BPE token of ``target_word``.

    Parameters
    ----------
    texts : iterable of str
        Input sentences to score.
    target_word : str, optional
        Word whose next-token probability is extracted. Defaults to
        ``"mat"``, preserving the original hard-coded behavior.

    Returns
    -------
    numpy.ndarray
        1-D array with one probability per input text.
    """
    # Hoist the loop-invariant token lookup out of the comprehension —
    # the original re-encoded the target word once per input text.
    # NOTE(review): GPT-2's BPE encodes "mat" and " mat" (leading space)
    # as different tokens; after "...sat on the" the natural continuation
    # is " mat" — confirm which token id is actually intended here.
    target_id = tokenizer.encode(target_word)[0]
    return np.array(
        [predict_next_word_probabilities(text)[target_id] for text in texts]
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
# Define the initial text and a few hand-made perturbations of it.
initial_text = "The cat sat on the"
# NOTE(review): "mat" does not occur in initial_text, so the third entry is
# identical to the first — confirm whether that duplicate is intended.
texts = [initial_text, initial_text.replace("the", "", 1), initial_text.replace("mat", "rug")]

# Create an explainer object around the `predict` wrapper, using `texts`
# itself as the background data.
# NOTE(review): for raw-string inputs SHAP normally expects a text masker
# (e.g. shap.maskers.Text(tokenizer)) rather than a plain list of strings —
# verify this call succeeds with the installed shap version.
explainer = shap.Explainer(predict, texts)

# Generate SHAP values for the same texts used as background.
shap_values = explainer(texts)

# Visualization
# NOTE(review): shap.initjs() targets Jupyter notebooks and summary_plot
# expects matrix-style SHAP values; shap.plots.text is the usual choice for
# text explanations — confirm this renders in the intended environment.
shap.initjs()
shap.summary_plot(shap_values, texts)