# Hugging Face Space app.py by pratikshahp (commit 72a8b0c, verified)
import numpy as np
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import shap
import matplotlib.pyplot as plt
# Load pre-trained model and tokenizer
# (downloads the GPT-2 weights from the Hugging Face hub on first run).
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# Inference only: switch off dropout and other training-mode behavior.
model.eval()
# Define the function to predict next word probabilities
def predict_next_word_probabilities(text):
    """Return the model's next-token probability distribution for `text`.

    Encodes `text`, runs one forward pass, and softmaxes the logits at
    the final position (i.e. the prediction for the token that would
    follow `text`).

    Returns a 1-D numpy array of length vocab_size.
    """
    inputs = tokenizer.encode(text, return_tensors='pt')
    with torch.no_grad():
        # Only the logits are needed; passing `labels=` (as the original
        # did) would trigger an unnecessary loss computation.
        outputs = model(inputs)
    logits = outputs.logits
    # Softmax over the vocabulary at the last sequence position.
    predicted_probabilities = torch.nn.functional.softmax(logits[:, -1], dim=-1)
    return predicted_probabilities.squeeze().numpy()
# Define a function to wrap the model prediction for SHAP
def predict(texts):
    """SHAP-compatible batch wrapper around the language model.

    For each input string, returns the model's predicted probability
    that the next token is the target token (first BPE piece of "mat").

    Returns a 1-D np.ndarray, one probability per input text.
    """
    # Hoist the loop-invariant target-token lookup out of the loop
    # (the original re-encoded "mat" once per text).
    # NOTE(review): GPT-2 BPE assigns different ids to "mat" and " mat";
    # a next-word continuation of "The cat sat on the" would normally be
    # " mat" (with the leading space) — confirm which id is intended.
    target_id = tokenizer.encode("mat")[0]
    return np.array(
        [predict_next_word_probabilities(text)[target_id] for text in texts]
    )
# Define the initial text
initial_text = "The cat sat on the"
# Perturbed variants of the prompt used as inputs/background for SHAP.
# NOTE(review): `initial_text` contains no "mat", so the third entry's
# .replace("mat", "rug") is a no-op and simply duplicates initial_text —
# likely a different perturbation was intended; confirm.
texts = [initial_text, initial_text.replace("the", "", 1), initial_text.replace("mat", "rug")]
# Create an explainer object
# NOTE(review): a plain list of strings is passed as the masker/background;
# shap normally expects shap.maskers.Text for text models — verify this
# produces meaningful attributions.
explainer = shap.Explainer(predict, texts)
# Generate SHAP values
shap_values = explainer(texts)
# Visualization
# Loads SHAP's JS plotting library (only has an effect in notebook contexts).
shap.initjs()
# Summary plot of the SHAP values across the input texts (matplotlib-backed).
shap.summary_plot(shap_values, texts)