PhunvVi's picture
Upload 3 files
32342e9 verified
raw
history blame
966 Bytes
def clean_output(decoded_list):
    """Strip whitespace from each decoded string, drop empties, and
    de-duplicate while preserving first-seen order."""
    seen = {}
    for raw in decoded_list:
        text = raw.strip()
        if text:
            # dict keys keep insertion order, so first occurrence wins
            seen.setdefault(text, None)
    return list(seen)
def preprocess_context(context):
    """Prepend the question-generation task prefix to the trimmed context."""
    return "generate question: " + context.strip()
"""
Compute SHAP token attributions for a given prompt.
"""
# Tokenize input
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
input_ids = inputs['input_ids']
# Define wrapper prediction function
def f(x):
with torch.no_grad():
out = model.generate(
input_ids=torch.tensor(x).to(model.device),
max_length=64,
do_sample=False,
num_beams=2
)
return out[:, 0].detach().cpu().numpy()
explainer = shap.Explainer(f, tokenizer)
shap_values = explainer([prompt])
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
return shap_values.values[0], tokens