def clean_output(decoded_list):
    """Strip whitespace and drop empty/duplicate entries, preserving first-seen order.

    Args:
        decoded_list: iterable of decoded strings (e.g. generated questions).

    Returns:
        list[str]: stripped, non-empty, order-preserving-deduplicated strings.
    """
    # dict.fromkeys preserves insertion order (Python 3.7+), giving an
    # order-stable dedup in one pass.
    return list(dict.fromkeys([q.strip() for q in decoded_list if q.strip()]))


def preprocess_context(context):
    """Prefix *context* with the task instruction expected by the QG model.

    The "generate question: " prefix is part of the model's runtime contract;
    it must not be altered.
    """
    return f"generate question: {context.strip()}"


# NOTE(review): the original source was collapsed onto one line and the `def`
# header for the function below was lost in the mangling. The signature is
# reconstructed from the free names used in the body (prompt, model,
# tokenizer) — confirm the original name and parameter order against callers.
def compute_shap_attributions(prompt, model, tokenizer):
    """
    Compute SHAP token attributions for a given prompt.

    Args:
        prompt: input text to explain.
        model: a seq2seq model exposing `.generate()` and `.device`
               (presumably a HuggingFace transformers model — TODO confirm).
        tokenizer: the matching tokenizer; also passed to shap.Explainer
                   as the masker.

    Returns:
        tuple: (shap_values.values[0], tokens) — per-token attribution
        values and the corresponding token strings for the prompt.
    """
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    input_ids = inputs['input_ids']

    # Wrapper prediction function SHAP perturbs: maps token-id arrays to a
    # scalar-ish model output.
    def f(x):
        with torch.no_grad():
            out = model.generate(
                input_ids=torch.tensor(x).to(model.device),
                max_length=64,
                do_sample=False,
                num_beams=2,
            )
        # NOTE(review): out[:, 0] keeps only the FIRST generated token id per
        # sequence as the prediction signal for SHAP — presumably intentional
        # (a cheap scalar target), but verify this is the quantity you want
        # attributions for.
        return out[:, 0].detach().cpu().numpy()

    explainer = shap.Explainer(f, tokenizer)
    shap_values = explainer([prompt])

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return shap_values.values[0], tokens