# explainer.py
import shap
import torch
import torch.nn.functional as F


def generate_shap_explanation(
    cleaned_code: str,
    model: torch.nn.Module,
    tokenizer,
    static_features_tensor: torch.Tensor,
    device: torch.device,
    pred_idx: int,
    label_map: dict
):
    """
    Generates SHAP token importance scores for the predicted complexity class.
    """
    # SHAP prediction wrapper: maps a batch of raw text strings to class
    # probabilities so the explainer can probe the model with masked variants
    def text_prediction_wrapper(texts):
        texts_list = [str(t) for t in texts]
        encodings = tokenizer(
            texts_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(device)
        # Expand static features to match SHAP permutation batch size
        batch_size = encodings['input_ids'].shape[0]
        expanded_static = static_features_tensor.repeat(batch_size, 1)
        with torch.no_grad():
            batch_logits = model(
                input_ids=encodings['input_ids'],
                attention_mask=encodings['attention_mask'],
                static_features=expanded_static
            )
        # Return probabilities so SHAP attributions live in probability space
        return F.softmax(batch_logits, dim=1).cpu().numpy()
    # Configure SHAP Explainer
    masker = shap.maskers.Text(tokenizer, mask_token=tokenizer.mask_token)
    explainer = shap.Explainer(
        text_prediction_wrapper,
        masker,
        output_names=list(label_map.values())
    )
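    # Note: with a Text masker and a plain callable, shap.Explainer typically
    # dispatches to the Permutation algorithm, scoring tokens by how replacing
    # them with the mask token shifts the predicted class probabilities.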
    # Calculate SHAP values (max_evals=100 for API speed)
    shap_values = explainer([cleaned_code], max_evals=100)
    # Extract the specific tokens and their impact scores for the predicted class
    tokens = shap_values.data[0]
    scores = shap_values.values[0, :, pred_idx]
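    # shap_values.values has shape (n_samples, n_tokens, n_classes), so this
    # slice keeps each token's attribution toward the predicted label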
    # Map character offsets for frontend highlighting. Offsets come from a
    # fresh tokenizer pass, so special tokens report (0, 0) spans and the
    # pairing with SHAP's token segments is positional.
    encoding = tokenizer(cleaned_code, return_offsets_mapping=True, truncation=True, max_length=512)
    offsets = encoding["offset_mapping"]
    token_data = []
    for i, (t, s) in enumerate(zip(tokens, scores)):
        start_char = offsets[i][0] if i < len(offsets) else 0
        end_char = offsets[i][1] if i < len(offsets) else 0
        token_data.append({
            "token": t,
            "score": float(s),
            "start_char": start_char,
            "end_char": end_char
        })
    return token_data
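

# Hedged usage sketch (not part of the original module): a minimal example of
# how generate_shap_explanation might be wired up end to end. DemoClassifier,
# the CodeBERT checkpoint, the label_map, and the zeroed static-feature vector
# are illustrative assumptions, not names taken from this project.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    class DemoClassifier(torch.nn.Module):
        """Hypothetical stand-in model: a CodeBERT encoder plus a linear head
        over the [CLS] embedding concatenated with the static features."""
        def __init__(self, n_static=8, n_classes=3):
            super().__init__()
            self.encoder = AutoModel.from_pretrained("microsoft/codebert-base")
            self.head = torch.nn.Linear(self.encoder.config.hidden_size + n_static, n_classes)

        def forward(self, input_ids, attention_mask, static_features):
            cls = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
            return self.head(torch.cat([cls, static_features], dim=1))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
    model = DemoClassifier().to(device).eval()  # head is untrained, outputs illustrative only
    label_map = {0: "O(1)", 1: "O(n)", 2: "O(n^2)"}  # hypothetical classes

    snippet = "def f(xs):\n    total = 0\n    for x in xs:\n        total += x\n    return total"
    static_features = torch.zeros(1, 8, device=device)  # placeholder static features

    # Run the model once to pick the predicted class, then explain it
    enc = tokenizer(snippet, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        logits = model(enc["input_ids"], enc["attention_mask"], static_features)
    pred_idx = int(logits.argmax(dim=1))

    for item in generate_shap_explanation(snippet, model, tokenizer,
                                          static_features, device, pred_idx, label_map):
        print(item)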