Update utils.py
Browse files
utils.py
CHANGED
|
@@ -8,27 +8,26 @@ def preprocess_context(context):
|
|
| 8 |
return f"generate question: {context.strip()}"
|
| 9 |
|
| 10 |
|
| 11 |
-
|
| 12 |
-
"""
|
| 13 |
-
Compute SHAP token attributions for a given prompt.
|
| 14 |
-
"""
|
| 15 |
# Tokenize input
|
| 16 |
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
|
| 17 |
-
input_ids = inputs[
|
| 18 |
|
| 19 |
# Define wrapper prediction function
|
| 20 |
def f(x):
|
|
|
|
| 21 |
with torch.no_grad():
|
| 22 |
out = model.generate(
|
| 23 |
-
input_ids=
|
| 24 |
max_length=64,
|
| 25 |
do_sample=False,
|
| 26 |
num_beams=2
|
| 27 |
)
|
| 28 |
-
return
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
|
|
|
| 32 |
|
| 33 |
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
|
| 34 |
return shap_values.values[0], tokens
|
|
|
|
| 8 |
return f"generate question: {context.strip()}"
|
| 9 |
|
| 10 |
|
| 11 |
+
def get_shap_values(tokenizer, model, prompt):
    """Compute SHAP token attributions for a given prompt.

    Args:
        tokenizer: HF tokenizer used to encode ``prompt`` and map ids back
            to token strings.
        model: seq2seq language model; its generations provide the scalar
            signal that SHAP attributes back to the input tokens.
        prompt: input text to explain.

    Returns:
        Tuple ``(values, tokens)`` where ``values`` is the SHAP value array
        for the prompt's token positions and ``tokens`` the matching token
        strings.
    """
    # Tokenize input (single example, batch dimension of 1).
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]

    # Wrapper prediction function handed to SHAP. SHAP perturbs the token
    # ids as numpy arrays, so convert back to a LongTensor on the model's
    # device before generating.
    def f(x):
        x = torch.tensor(x).long().to(model.device)
        with torch.no_grad():
            out = model.generate(
                input_ids=x,
                max_length=64,
                do_sample=False,
                num_beams=2,
            )
        # BUG FIX: the previous version discarded `out` and returned the
        # constant np.ones((x.shape[0], 1)) — a model function that never
        # varies under perturbation yields all-zero, meaningless SHAP
        # values. Score each perturbed input by its generated length
        # (non-pad token count), a cheap scalar that actually depends on
        # the input.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            # No pad token configured: fall back to the padded width
            # (constant per batch, but defined for every tokenizer).
            scores = [float(out.shape[1])] * out.shape[0]
        else:
            scores = [float((seq != pad_id).sum()) for seq in out]
        return np.array(scores, dtype=float).reshape(-1, 1)

    # SHAP explainer over the raw input ids; the same ids serve as the
    # background/masker data (single-example background — TODO: consider a
    # larger background set for more stable attributions).
    explainer = shap.Explainer(f, input_ids.numpy())
    shap_values = explainer(input_ids.numpy())

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return shap_values.values[0], tokens
|