# src/interpretability.py
# Model Interpretability Module — SHAP Explanations
# SupportMind v1.0 — Asmitha
import torch
import shap
import numpy as np
from transformers import pipeline
import logging
from typing import Dict, List, Any
logger = logging.getLogger(__name__)
class SupportMindExplainer:
    """
    Provides SHAP-based explanations for DistilBERT predictions.

    Helps support agents understand WHY a ticket was routed to a specific
    category by highlighting the most influential words.
    """

    def __init__(self, model, tokenizer, device='cpu'):
        """
        Args:
            model: A HuggingFace sequence-classification model.
            tokenizer: The tokenizer matching `model`; also used as SHAP's masker.
            device: 'cpu' or 'cuda'. Anything other than 'cuda' runs on CPU.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # HF pipeline convention: device=0 is the first CUDA device, -1 is CPU.
        self.pipe = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if device == 'cuda' else -1,
            top_k=None,   # return scores for ALL classes, not just the argmax
            framework="pt",  # force PyTorch
        )

        def predictor(texts):
            """Map a batch of raw texts to an [n_samples, n_classes] score matrix."""
            # SHAP hands us numpy arrays; the HF pipeline expects plain lists.
            if isinstance(texts, np.ndarray):
                texts = texts.tolist()
            outputs = self.pipe(texts, batch_size=32)
            # Local import — presumably avoids a circular import at module load;
            # confirm against confidence_router's imports.
            from confidence_router import CATEGORY_MAP, CATEGORY_REVERSE
            num_classes = len(CATEGORY_MAP)
            # Outputs is a list (per sample) of lists of {'label', 'score'} dicts.
            # Re-order scores so column i matches CATEGORY_MAP index i.
            results = np.zeros((len(texts), num_classes))
            for row, sample_scores in enumerate(outputs):
                for item in sample_scores:
                    label = item['label']
                    score = item['score']
                    # HF labels are either generic 'LABEL_N' or the real category names.
                    if label.startswith('LABEL_'):
                        results[row, int(label.split('_')[1])] = score
                    elif label in CATEGORY_REVERSE:
                        results[row, CATEGORY_REVERSE[label]] = score
            return results

        # The tokenizer doubles as SHAP's text masker.
        self.explainer = shap.Explainer(predictor, self.tokenizer)

    def explain(self, text: str, target_class_idx: int = None) -> Dict[str, Any]:
        """
        Generate SHAP values for a single ticket.

        Args:
            text: The ticket text.
            target_class_idx: The class index to explain. If None, uses the class
                with the largest mean absolute SHAP value for this sample.

        Returns:
            Dict with 'tokens', per-token 'values', 'base_value', 'target_class'
            and 'prediction_value'; on failure, {'error': message}.
        """
        try:
            # max_evals capped at 500 to keep demo-time latency manageable
            # (tickets are short, ~128 tokens).
            shap_values = self.explainer([text], max_evals=500)
            if target_class_idx is None:
                # shap_values.values has shape [samples, tokens, classes];
                # pick the class with the highest mean |SHAP| across tokens.
                target_class_idx = int(
                    np.argmax(np.mean(np.abs(shap_values.values[0]), axis=0))
                )
            values = shap_values.values[0, :, target_class_idx]
            base_value = float(shap_values.base_values[0, target_class_idx])
            # BUGFIX: use SHAP's own token segmentation (shap_values.data) rather
            # than re-encoding the text. Re-tokenizing can yield a different number
            # of tokens than SHAP produced, and zip() would then silently
            # misalign/truncate the token-value pairs.
            tokens = list(shap_values.data[0])
            # Drop special tokens ([CLS]/[SEP]/[PAD]) from the final output while
            # keeping token/value alignment for the rest.
            special = {
                self.tokenizer.cls_token,
                self.tokenizer.sep_token,
                self.tokenizer.pad_token,
            }
            result_tokens = []
            result_values = []
            for token, val in zip(tokens, values):
                if token in special:
                    continue
                result_tokens.append(token)
                result_values.append(float(val))
            return {
                'tokens': result_tokens,
                'values': result_values,
                'base_value': base_value,
                # BUGFIX: cast numpy integer to plain int so the payload is
                # JSON-serializable downstream.
                'target_class': int(target_class_idx),
                'prediction_value': float(base_value + np.sum(values)),
            }
        except Exception as e:
            # logger.exception records the full traceback, not just the message.
            logger.exception(f"SHAP explanation failed: {e}")
            return {'error': str(e)}
if __name__ == '__main__':
    # Smoke test: explain one billing ticket and print its token attributions.
    from confidence_router import ConfidenceGatedRouter

    router = ConfidenceGatedRouter()
    explainer = SupportMindExplainer(router.model, router.tokenizer)
    test_text = "My invoice is wrong, please fix the billing error."
    res = explainer.explain(test_text)
    if 'error' in res:
        # BUGFIX: explain() returns {'error': ...} on failure; indexing
        # res['tokens'] unconditionally would raise KeyError here.
        print(f"Explanation failed: {res['error']}")
    else:
        print(f"Tokens: {res['tokens']}")
        print(f"Values: {res['values']}")