# src/interpretability.py
# Model Interpretability Module — SHAP Explanations
# SupportMind v1.0 — Asmitha

import logging
from typing import Any, Dict, Optional

import numpy as np
import shap
import torch
from transformers import pipeline

logger = logging.getLogger(__name__)


class SupportMindExplainer:
    """
    Provides SHAP-based explanations for DistilBERT predictions.

    Helps support agents understand WHY a ticket was routed to a specific
    category by highlighting the most influential words.
    """

    def __init__(self, model, tokenizer, device: str = 'cpu'):
        """
        Args:
            model: HuggingFace sequence-classification model to explain.
            tokenizer: Tokenizer matching ``model``.
            device: ``'cuda'`` to run the pipeline on GPU 0; anything else
                runs on CPU.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

        # Create a transformers pipeline for SHAP.
        self.pipe = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if device == 'cuda' else -1,
            top_k=None,      # return scores for ALL classes, not just the top one
            framework="pt",  # force PyTorch
        )

        # Wrapper so the pipeline matches SHAP's expected f(texts) -> matrix API.
        def predictor(texts):
            """Map a list/ndarray of texts to a [num_samples, num_classes] score matrix.

            SHAP calls this repeatedly with masked/perturbed copies of the input.
            """
            # The transformers pipeline wants a plain list of strings.
            if isinstance(texts, np.ndarray):
                texts = texts.tolist()
            outputs = self.pipe(texts, batch_size=32)

            # Imported lazily: confidence_router also imports this project's
            # modules, so a top-level import risks a circular import.
            from confidence_router import CATEGORY_MAP, CATEGORY_REVERSE

            # `outputs` is a list (one per sample) of lists of
            # {'label': ..., 'score': ...} dicts; label order is not
            # guaranteed, so each score is placed into the column dictated
            # by CATEGORY_MAP / CATEGORY_REVERSE.
            num_classes = len(CATEGORY_MAP)
            results = np.zeros((len(texts), num_classes))
            for i, out in enumerate(outputs):
                for item in out:
                    label = item['label']
                    score = item['score']
                    if label.startswith('LABEL_'):
                        # Default HF naming scheme: 'LABEL_<class_index>'.
                        results[i, int(label.split('_')[1])] = score
                    elif label in CATEGORY_REVERSE:
                        # Model configured with human-readable category names.
                        results[i, CATEGORY_REVERSE[label]] = score
            return results

        # Text explainer; passing the tokenizer makes SHAP mask at the
        # token level (small masker, kept for performance).
        self.explainer = shap.Explainer(predictor, self.tokenizer)

    def explain(self, text: str, target_class_idx: Optional[int] = None) -> Dict[str, Any]:
        """
        Generate SHAP values for a single ticket.

        Args:
            text: The ticket text.
            target_class_idx: The class index to explain. If None, the class
                with the largest mean |SHAP| value for this sample is used.

        Returns:
            Dict with 'tokens', 'values', 'base_value', 'target_class' and
            'prediction_value'; on failure, ``{'error': <message>}``.
        """
        try:
            # This can be slow for long texts, but tickets (~128 tokens) are
            # manageable. max_evals capped at 500 to keep demo latency low.
            shap_values = self.explainer([text], max_evals=500)

            if target_class_idx is None:
                # shap_values.values has shape [samples, tokens, classes];
                # pick the class with the largest mean absolute attribution.
                # int() keeps the result JSON-serializable (np.argmax returns
                # a NumPy integer).
                target_class_idx = int(
                    np.argmax(np.mean(np.abs(shap_values.values[0]), axis=0))
                )

            # Slice out this sample's attributions for the target class.
            values = shap_values.values[0, :, target_class_idx]
            base_value = float(shap_values.base_values[0, target_class_idx])

            # BUGFIX: use the token strings SHAP itself produced
            # (shap_values.data[0]) instead of re-tokenizing `text` with the
            # HF tokenizer. The old re-tokenization could yield a different
            # number of tokens than `values` has rows, and zip() would then
            # silently truncate or misalign the token<->value pairing.
            tokens = list(shap_values.data[0])

            # Drop special tokens ([CLS]/[SEP]/[PAD]) if present, preserving
            # token/value alignment for everything kept. strip() because
            # SHAP's masker tokens may carry surrounding whitespace.
            special = {self.tokenizer.cls_token, self.tokenizer.sep_token,
                       self.tokenizer.pad_token}
            result_tokens = []
            result_values = []
            for token, val in zip(tokens, values):
                if token.strip() in special:
                    continue
                result_tokens.append(token)
                result_values.append(float(val))

            return {
                'tokens': result_tokens,
                'values': result_values,
                'base_value': base_value,
                'target_class': int(target_class_idx),
                # base value + sum over ALL tokens (including any filtered
                # special tokens) reconstructs the model output for the class.
                'prediction_value': float(base_value + np.sum(values)),
            }
        except Exception as e:
            # Best-effort feature: log (with traceback) but never crash the
            # caller; the router keeps working without an explanation.
            logger.exception("SHAP explanation failed: %s", e)
            return {'error': str(e)}


if __name__ == '__main__':
    # Smoke test against the production router's model/tokenizer.
    from confidence_router import ConfidenceGatedRouter

    router = ConfidenceGatedRouter()
    explainer = SupportMindExplainer(router.model, router.tokenizer)
    test_text = "My invoice is wrong, please fix the billing error."
    res = explainer.explain(test_text)
    print(f"Tokens: {res['tokens']}")
    print(f"Values: {res['values']}")