File size: 966 Bytes
32342e9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
def clean_output(decoded_list):
    """Strip whitespace from each decoded string, drop empty results,
    and de-duplicate while preserving first-seen order."""
    seen = set()
    cleaned = []
    for item in decoded_list:
        text = item.strip()
        if text and text not in seen:
            seen.add(text)
            cleaned.append(text)
    return cleaned
def preprocess_context(context):
    """Prepend the question-generation task prefix to the stripped context."""
    stripped = context.strip()
    return "generate question: " + stripped
"""
Compute SHAP token attributions for a given prompt.
"""
# Tokenize input
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
input_ids = inputs['input_ids']
# Define wrapper prediction function
def f(x):
with torch.no_grad():
out = model.generate(
input_ids=torch.tensor(x).to(model.device),
max_length=64,
do_sample=False,
num_beams=2
)
return out[:, 0].detach().cpu().numpy()
explainer = shap.Explainer(f, tokenizer)
shap_values = explainer([prompt])
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
return shap_values.values[0], tokens
|