File size: 1,041 Bytes
32342e9
 
 
 
 
 
 
 
 
 
b738126
32342e9
 
b738126
32342e9
 
 
b738126
32342e9
 
b738126
32342e9
 
 
 
b738126
32342e9
b738126
 
 
32342e9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

def clean_output(decoded_list):
    """Strip whitespace from each decoded string, drop empties, and
    de-duplicate while preserving first-seen order."""
    cleaned = (item.strip() for item in decoded_list)
    non_empty = (text for text in cleaned if text)
    # dict.fromkeys keeps insertion order (Python 3.7+), giving an
    # order-preserving de-dup in one pass.
    return list(dict.fromkeys(non_empty))


def preprocess_context(context):
    """Prefix the (whitespace-trimmed) context with the T5-style task tag."""
    trimmed = context.strip()
    return "generate question: " + trimmed


def get_shap_values(tokenizer, model, prompt):
    """Run a SHAP explainer over the token ids of *prompt*.

    Args:
        tokenizer: HF-style tokenizer — assumed to support ``__call__`` with
            ``return_tensors="pt"`` and ``convert_ids_to_tokens`` (TODO confirm).
        model: model exposing ``.device`` and ``.generate`` (seq2seq style,
            presumably — verify against caller).
        prompt: raw input text to explain.

    Returns:
        Tuple of (shap values for the first/only example, token strings
        aligned with those values).

    NOTE(review): the prediction wrapper ``f`` below ignores the model's
    generated output and returns a constant ``np.ones`` array, so the
    explainer is attributing a constant function — the resulting SHAP
    values are expected to be ~0 / meaningless, and the ``generate`` call
    is pure wasted compute. Confirm whether a real scoring function
    (e.g. sequence log-prob) was intended here.
    """
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]

    # Define wrapper prediction function
    def f(x):
        # SHAP hands us a numpy array of perturbed id rows; cast back to a
        # LongTensor on the model's device before generation.
        x = torch.tensor(x).long().to(model.device)
        with torch.no_grad():
            out = model.generate(
                input_ids=x,
                max_length=64,
                do_sample=False,
                num_beams=2
            )
        # NOTE(review): ``out`` is unused — dummy constant prediction only.
        return np.ones((x.shape[0], 1))  # dummy prediction

    # SHAP explainer
    # NOTE(review): passing raw token-id rows as the background/masker data;
    # shap will pick a default masker for tabular numpy input — confirm this
    # is the intended masking strategy for token ids.
    explainer = shap.Explainer(f, input_ids.numpy())
    shap_values = explainer(input_ids.numpy())

    # Values for the single (batch-of-one) example, plus readable tokens.
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return shap_values.values[0], tokens