File size: 1,041 Bytes
32342e9 b738126 32342e9 b738126 32342e9 b738126 32342e9 b738126 32342e9 b738126 32342e9 b738126 32342e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
def clean_output(decoded_list):
    """Strip each entry, drop empty strings, and remove duplicates.

    First occurrence wins; original relative order is preserved.
    """
    seen = {}
    for entry in decoded_list:
        trimmed = entry.strip()
        if trimmed:
            seen[trimmed] = None  # dict preserves insertion order (3.7+)
    return list(seen)
def preprocess_context(context):
    """Prepend the 'generate question:' task prefix to the stripped context."""
    stripped = context.strip()
    return "generate question: " + stripped
def get_shap_values(tokenizer, model, prompt):
    """Compute SHAP attributions over the tokenized *prompt*.

    Parameters:
        tokenizer: callable tokenizer with ``convert_ids_to_tokens`` —
            presumably a HuggingFace tokenizer; TODO confirm.
        model: companion model; currently unused by the predictor (see NOTE).
        prompt (str): raw text to tokenize and explain.

    Returns:
        tuple: (shap values for the single input sample, list of token strings).

    NOTE(review): the wrapper ``f`` returns a constant prediction that does
    not depend on its input, so the resulting SHAP attributions carry no
    information about the model. The original code also ran ``model.generate``
    inside ``f`` and discarded its output; that dead computation (a full
    beam-search decode per SHAP evaluation) has been removed — the returned
    values are unchanged. For meaningful attributions, ``f`` must return a
    score derived from the model's actual output (e.g. sequence
    log-probabilities).
    """
    # Tokenize input (truncated to the tokenizer's max length)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]

    # Wrapper prediction function required by shap.Explainer:
    # maps a batch of token-id arrays to one scalar score per sample.
    def f(x):
        # Dummy predictor kept as in the original: one constant score per
        # sample, so downstream consumers see identical attributions.
        return np.ones((x.shape[0], 1))

    # detach().cpu() makes .numpy() safe even if the tensor lands on a GPU.
    ids_np = input_ids.detach().cpu().numpy()

    # SHAP explainer over the raw token-id matrix
    explainer = shap.Explainer(f, ids_np)
    shap_values = explainer(ids_np)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return shap_values.values[0], tokens
|