|
|
|
|
|
def clean_output(decoded_list):
    """Deduplicate decoded strings, preserving first-seen order.

    Each entry is stripped of surrounding whitespace; entries that are
    empty after stripping are dropped.
    """
    seen = set()
    cleaned = []
    for raw in decoded_list:
        text = raw.strip()
        if text and text not in seen:
            seen.add(text)
            cleaned.append(text)
    return cleaned
|
|
|
|
|
|
|
|
def preprocess_context(context):
    """Prefix a whitespace-trimmed context with the question-generation directive."""
    stripped = context.strip()
    return "generate question: " + stripped
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Compute SHAP token attributions for a given prompt. |
|
|
""" |
|
|
|
|
|
inputs = tokenizer(prompt, return_tensors="pt", truncation=True) |
|
|
input_ids = inputs['input_ids'] |
|
|
|
|
|
|
|
|
def f(x): |
|
|
with torch.no_grad(): |
|
|
out = model.generate( |
|
|
input_ids=torch.tensor(x).to(model.device), |
|
|
max_length=64, |
|
|
do_sample=False, |
|
|
num_beams=2 |
|
|
) |
|
|
return out[:, 0].detach().cpu().numpy() |
|
|
|
|
|
explainer = shap.Explainer(f, tokenizer) |
|
|
shap_values = explainer([prompt]) |
|
|
|
|
|
tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) |
|
|
return shap_values.values[0], tokens |
|
|
|