Spaces:
Build error
Build error
| import numpy as np | |
| import torch | |
| from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
| import shap | |
| import matplotlib.pyplot as plt | |
# Pre-trained GPT-2 checkpoint name shared by the model and its tokenizer.
model_name = 'gpt2'
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Put the network in evaluation mode; this script only runs inference.
model.eval()
def predict_next_word_probabilities(text):
    """Return the model's probability distribution over the next token.

    Args:
        text: Prompt string the language model is conditioned on.

    Returns:
        A 1-D NumPy array holding one probability per vocabulary entry for
        the token that would follow ``text``.
    """
    token_ids = tokenizer.encode(text)
    if not token_ids:
        # An empty prompt produces no tokens, and the model cannot run on a
        # zero-length sequence; condition on the BOS token instead.
        token_ids = [tokenizer.bos_token_id]
    inputs = torch.tensor([token_ids])
    with torch.no_grad():
        # Only the logits are needed, so no labels are passed: that would
        # make the model compute a language-modelling loss nobody reads.
        logits = model(inputs).logits
    # Batch item 0, last position: the scores for the token after the prompt.
    next_token_logits = logits[0, -1]
    return torch.nn.functional.softmax(next_token_logits, dim=-1).numpy()
def predict(texts):
    """Score each prompt by the probability that its next token is " mat".

    This is the scalar-valued model function handed to SHAP.

    Args:
        texts: Iterable of prompt strings.

    Returns:
        A 1-D NumPy array with one probability per entry of ``texts``.
    """
    # GPT-2's BPE folds the preceding space into the token, so the word that
    # follows a prompt such as "The cat sat on the" is " mat", not "mat".
    # The target never changes, so it is encoded once rather than per text.
    target_token_id = tokenizer.encode(" mat")[0]
    return np.array(
        [predict_next_word_probabilities(text)[target_token_id] for text in texts]
    )
# Prompt whose next-word prediction is being explained.
initial_text = "The cat sat on the"
# Prompt variants explained together with the original.
# NOTE(review): "mat" does not occur in initial_text, so the third entry is
# identical to the first; confirm which substitution was actually intended.
texts = [initial_text, initial_text.replace("the", "", 1), initial_text.replace("mat", "rug")]

# Raw strings need a text masker so SHAP can hide tokens and re-query
# `predict`; a plain list of strings is not a valid masker.
explainer = shap.Explainer(predict, shap.maskers.Text(tokenizer))

# Per-token attributions for every prompt.
shap_values = explainer(texts)

# Text explanations have a different number of tokens per prompt, so they are
# rendered with the text plot rather than the tabular summary plot.
shap.initjs()
shap.plots.text(shap_values)