# SupportMind / src/interpretability.py
# Uploaded by Asmitha-28 via huggingface_hub (commit 0e002b5, verified)
# src/interpretability.py
# Model Interpretability Module — SHAP Explanations
# SupportMind v1.0 — Asmitha
import torch
import shap
import numpy as np
from transformers import pipeline
import logging
from typing import Dict, List, Any
logger = logging.getLogger(__name__)
class SupportMindExplainer:
    """
    Provides SHAP-based explanations for DistilBERT predictions.

    Helps support agents understand WHY a ticket was routed to a specific
    category by highlighting the most influential words.
    """

    def __init__(self, model, tokenizer, device='cpu'):
        """
        Args:
            model: A HuggingFace sequence-classification model.
            tokenizer: The tokenizer matching ``model``.
            device: ``'cuda'`` or ``'cpu'`` (default ``'cpu'``).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # A transformers pipeline wraps model+tokenizer so SHAP can call it
        # on raw strings. top_k=None returns scores for every class.
        self.pipe = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if device == 'cuda' else -1,
            top_k=None,   # Get all class probabilities
            framework="pt",  # Force PyTorch
        )
        # SHAP Explainer for text; the tokenizer doubles as the masker.
        self.explainer = shap.Explainer(self._predict_proba, self.tokenizer)

    def _predict_proba(self, texts) -> np.ndarray:
        """
        SHAP-compatible prediction function.

        Accepts a list (or numpy array) of strings and returns a
        [num_samples, num_classes] probability matrix whose column order
        matches CATEGORY_MAP.
        """
        # Local import avoids any circular-import issue at module load time
        # (confidence_router may import from this module's package).
        from confidence_router import CATEGORY_MAP, CATEGORY_REVERSE

        # SHAP may hand us a numpy array; the pipeline wants a list.
        if isinstance(texts, np.ndarray):
            texts = texts.tolist()
        outputs = self.pipe(texts, batch_size=32)

        num_classes = len(CATEGORY_MAP)
        results = np.zeros((len(texts), num_classes))
        for i, out in enumerate(outputs):
            for item in out:
                label = item['label']
                score = item['score']
                # HF models emit either generic 'LABEL_N' ids or the actual
                # category names, depending on the model config.
                if label.startswith('LABEL_'):
                    results[i, int(label.split('_')[1])] = score
                elif label in CATEGORY_REVERSE:
                    results[i, CATEGORY_REVERSE[label]] = score
        return results

    def explain(self, text: str, target_class_idx: int = None) -> Dict[str, Any]:
        """
        Generate SHAP values for a single ticket.

        Args:
            text: The ticket text.
            target_class_idx: The class index to explain. If None, the class
                with the largest mean |SHAP| value for this sample is used.

        Returns:
            Dict with 'tokens', 'values', 'base_value', 'target_class' and
            'prediction_value', or {'error': <message>} on failure.
        """
        try:
            # Can be slow for long texts, but tickets (~128 tokens) are
            # manageable. max_evals capped at 500 to keep demo latency low.
            shap_values = self.explainer([text], max_evals=500)

            if target_class_idx is None:
                # shap_values.values has shape [samples, tokens, classes];
                # pick the class with the highest mean |contribution|.
                # BUGFIX: cast to int — np.argmax returns numpy.int64,
                # which is not JSON-serializable downstream.
                target_class_idx = int(
                    np.argmax(np.mean(np.abs(shap_values.values[0]), axis=0))
                )

            values = shap_values.values[0, :, target_class_idx]
            base_value = float(shap_values.base_values[0, target_class_idx])

            # BUGFIX: use the token segments SHAP actually attributed
            # (shap_values.data) instead of re-tokenizing the text with
            # tokenizer.encode(). The SHAP Text masker's segmentation can
            # differ in length from a fresh encode, and zip() would then
            # silently misalign tokens with their values.
            tokens = shap_values.data[0]

            # Filter out special tokens ([CLS]/[SEP]/[PAD]) from the final
            # output. pad_token may be None for some tokenizers, so drop
            # falsy entries from the comparison set.
            special = {t for t in (self.tokenizer.cls_token,
                                   self.tokenizer.sep_token,
                                   self.tokenizer.pad_token) if t}
            result_tokens: List[str] = []
            result_values: List[float] = []
            for token, val in zip(tokens, values):
                # SHAP token segments may carry surrounding whitespace.
                if token.strip() in special:
                    continue
                result_tokens.append(token)
                result_values.append(float(val))

            return {
                'tokens': result_tokens,
                'values': result_values,
                'base_value': base_value,
                'target_class': int(target_class_idx),
                # Additivity: base value + sum of all contributions
                # approximates the model output for this class.
                'prediction_value': float(base_value + np.sum(values)),
            }
        except Exception as e:
            # Best-effort: surface the failure to the caller rather than
            # crashing the request path.
            logger.error(f"SHAP explanation failed: {e}")
            return {'error': str(e)}
if __name__ == '__main__':
    # Smoke test: explain a sample billing ticket using the project's router.
    from confidence_router import ConfidenceGatedRouter

    router = ConfidenceGatedRouter()
    explainer = SupportMindExplainer(router.model, router.tokenizer)
    test_text = "My invoice is wrong, please fix the billing error."
    res = explainer.explain(test_text)
    # BUGFIX: explain() returns {'error': ...} on failure; the original
    # unconditionally indexed res['tokens'] and would raise KeyError
    # instead of reporting the underlying problem.
    if 'error' in res:
        print(f"Explanation failed: {res['error']}")
    else:
        print(f"Tokens: {res['tokens']}")
        print(f"Values: {res['values']}")