```python
# explainer.py
import shap
import torch
import torch.nn.functional as F


def generate_shap_explanation(
    cleaned_code: str,
    model: torch.nn.Module,
    tokenizer,
    static_features_tensor: torch.Tensor,
    device: torch.device,
    pred_idx: int,
    label_map: dict,
):
    """
    Generate SHAP token importance scores for the predicted complexity class.
    """
    # SHAP prediction wrapper: maps a batch of masked text permutations
    # to class probabilities.
    def text_prediction_wrapper(texts):
        texts_list = [str(t) for t in texts]
        encodings = tokenizer(
            texts_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(device)
        # Expand static features to match SHAP's permutation batch size.
        batch_size = encodings["input_ids"].shape[0]
        expanded_static = static_features_tensor.repeat(batch_size, 1)
        with torch.no_grad():
            batch_logits = model(
                input_ids=encodings["input_ids"],
                attention_mask=encodings["attention_mask"],
                static_features=expanded_static,
            )
        return F.softmax(batch_logits, dim=1).cpu().numpy()

    # Configure the SHAP explainer with a text masker built on the same tokenizer.
    masker = shap.maskers.Text(tokenizer, mask_token=tokenizer.mask_token)
    explainer = shap.Explainer(
        text_prediction_wrapper,
        masker,
        output_names=list(label_map.values()),
    )

    # Calculate values (max_evals=100 keeps API latency manageable).
    shap_values = explainer([cleaned_code], max_evals=100)

    # Extract the tokens and their impact scores for the predicted class.
    tokens = shap_values.data[0]
    scores = shap_values.values[0, :, pred_idx]

    # Map character offsets for frontend highlighting.
    encoding = tokenizer(
        cleaned_code, return_offsets_mapping=True, truncation=True, max_length=512
    )
    offsets = encoding["offset_mapping"]

    token_data = []
    for i, (t, s) in enumerate(zip(tokens, scores)):
        start_char = offsets[i][0] if i < len(offsets) else 0
        end_char = offsets[i][1] if i < len(offsets) else 0
        token_data.append({
            "token": t,
            "score": float(s),
            "start_char": start_char,
            "end_char": end_char,
        })
    return token_data
```
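
For context, here is a minimal, self-contained sketch of how `generate_shap_explanation` might be wired up by a caller. The `DummyClassifier`, the `microsoft/codebert-base` tokenizer choice, the static feature values, and the label map are all illustrative assumptions, not part of the original code; any `torch.nn.Module` whose forward accepts `input_ids`, `attention_mask`, and `static_features` slots in the same way.

```python
# usage_sketch.py: hypothetical caller; the model, labels, and features
# below are illustrative stand-ins, not the project's real artifacts.
import torch
from transformers import AutoTokenizer

from explainer import generate_shap_explanation


class DummyClassifier(torch.nn.Module):
    """Stand-in with the forward signature the explainer expects."""

    def __init__(self, vocab_size: int, n_static: int, n_classes: int):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, 32)
        self.head = torch.nn.Linear(32 + n_static, n_classes)

    def forward(self, input_ids, attention_mask, static_features):
        # Mean-pool token embeddings, concat static features, classify.
        # (attention_mask is ignored here purely to keep the sketch short.)
        pooled = self.emb(input_ids).mean(dim=1)
        return self.head(torch.cat([pooled, static_features], dim=1))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")  # assumed tokenizer
label_map = {0: "O(1)", 1: "O(n)", 2: "O(n^2)"}  # assumed complexity classes

model = DummyClassifier(len(tokenizer), n_static=3, n_classes=3).to(device).eval()

cleaned_code = "def total(xs):\n    s = 0\n    for x in xs:\n        s += x\n    return s"
static_features = torch.tensor([[1.0, 0.0, 2.0]], device=device)  # shape (1, n_static)

# Pick the predicted class first, then ask SHAP to explain that class.
with torch.no_grad():
    enc = tokenizer(
        cleaned_code, return_tensors="pt", truncation=True, max_length=512
    ).to(device)
    logits = model(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        static_features=static_features,
    )
    pred_idx = int(logits.argmax(dim=1))

token_data = generate_shap_explanation(
    cleaned_code, model, tokenizer, static_features, device, pred_idx, label_map
)
for tok in token_data[:5]:
    print(tok["token"], round(tok["score"], 4), (tok["start_char"], tok["end_char"]))
```

Note that `max_evals=100` trades attribution fidelity for latency; for offline analysis rather than an API path, raising it yields more stable scores at the cost of proportionally more model forward passes.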