PhunvVi committed on
Commit
b738126
·
verified ·
1 Parent(s): 0c561ea

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +8 -9
utils.py CHANGED
@@ -8,27 +8,26 @@ def preprocess_context(context):
8
  return f"generate question: {context.strip()}"
9
 
10
 
11
-
12
- """
13
- Compute SHAP token attributions for a given prompt.
14
- """
15
  # Tokenize input
16
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
17
- input_ids = inputs['input_ids']
18
 
19
  # Define wrapper prediction function
20
  def f(x):
 
21
  with torch.no_grad():
22
  out = model.generate(
23
- input_ids=torch.tensor(x).to(model.device),
24
  max_length=64,
25
  do_sample=False,
26
  num_beams=2
27
  )
28
- return out[:, 0].detach().cpu().numpy()
29
 
30
- explainer = shap.Explainer(f, tokenizer)
31
- shap_values = explainer([prompt])
 
32
 
33
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
34
  return shap_values.values[0], tokens
 
8
  return f"generate question: {context.strip()}"
9
 
10
 
11
def get_shap_values(tokenizer, model, prompt):
    """Compute SHAP token attributions for *prompt* under *model*.

    Parameters
    ----------
    tokenizer : a Hugging Face tokenizer (callable, with ``convert_ids_to_tokens``)
    model : a seq2seq generation model exposing ``generate`` and ``device``
    prompt : str — the input text to explain

    Returns
    -------
    (values, tokens) : SHAP attribution values for the first (only) example,
        and the corresponding input token strings.
    """
    # Tokenize input once; input_ids is shape (1, seq_len).
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]

    # Wrapper prediction function: SHAP hands us numpy arrays of token ids.
    def f(x):
        # model.generate requires a LongTensor on the model's device.
        x = torch.tensor(x).long().to(model.device)
        with torch.no_grad():
            out = model.generate(
                input_ids=x,
                max_length=64,
                do_sample=False,
                num_beams=2,
            )
        # Return the first generated token id per input as the scalar
        # prediction. Returning a constant (e.g. np.ones) would make the
        # predictor independent of the input and every attribution
        # meaningless — the generation output must feed the return value.
        return out[:, 0].detach().cpu().numpy()

    # SHAP explainer using the tokenized prompt itself as background data.
    explainer = shap.Explainer(f, input_ids.numpy())
    shap_values = explainer(input_ids.numpy())

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return shap_values.values[0], tokens