File size: 5,660 Bytes
0e002b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# src/interpretability.py
# Model Interpretability Module — SHAP Explanations
# SupportMind v1.0 — Asmitha

import logging
from typing import Any, Dict, List, Optional

import numpy as np
import shap
import torch
from transformers import pipeline

logger = logging.getLogger(__name__)

class SupportMindExplainer:
    """
    Provides SHAP-based explanations for DistilBERT predictions.

    Helps support agents understand WHY a ticket was routed to a specific
    category by highlighting the most influential words.
    """

    def __init__(self, model, tokenizer, device='cpu'):
        """
        Args:
            model: A fine-tuned HF sequence-classification model.
            tokenizer: The tokenizer paired with ``model``.
            device: ``'cuda'`` to run the pipeline on GPU 0; anything else
                falls back to CPU.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

        # Wrap model + tokenizer in a transformers pipeline so SHAP can call
        # the model like a plain function on raw strings.
        self.pipe = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if device == 'cuda' else -1,
            top_k=None,      # return scores for ALL classes, not just argmax
            framework="pt",  # force PyTorch
        )

        # Hoisted out of predictor(): the project import and class count are
        # invariant, so resolve them once instead of on every SHAP evaluation.
        from confidence_router import CATEGORY_MAP, CATEGORY_REVERSE
        num_classes = len(CATEGORY_MAP)

        def predictor(texts):
            """Adapter: list/array of strings -> [num_samples, num_classes] scores."""
            # transformers pipelines expect plain Python lists, but SHAP may
            # hand us a numpy array of strings.
            if isinstance(texts, np.ndarray):
                texts = texts.tolist()
            outputs = self.pipe(texts, batch_size=32)

            # outputs is a list (per sample) of lists of
            # {'label': ..., 'score': ...} dicts; re-order the scores into
            # columns that match CATEGORY_MAP's indexing.
            results = np.zeros((len(texts), num_classes))
            for i, out in enumerate(outputs):
                for item in out:
                    label = item['label']
                    score = item['score']
                    if label.startswith('LABEL_'):
                        # Generic HF label scheme: 'LABEL_<idx>'
                        idx = int(label.split('_')[1])
                        results[i, idx] = score
                    elif label in CATEGORY_REVERSE:
                        # Model was saved with human-readable category names
                        idx = CATEGORY_REVERSE[label]
                        results[i, idx] = score
            return results

        # Passing the tokenizer lets SHAP build a text masker that perturbs
        # the input at token granularity.
        self.explainer = shap.Explainer(predictor, self.tokenizer)

    def explain(self, text: str, target_class_idx: Optional[int] = None) -> Dict[str, Any]:
        """
        Generate SHAP values for a single ticket.

        Args:
            text: The ticket text.
            target_class_idx: The class index to explain. If None, the class
                with the largest mean |SHAP| value across tokens is used.

        Returns:
            Dict with 'tokens', 'values', 'base_value', 'target_class' and
            'prediction_value', or ``{'error': <message>}`` on failure.
        """
        try:
            # max_evals capped at 500 so demo latency stays low; tickets are
            # short (~128 tokens) so the approximation remains manageable.
            shap_values = self.explainer([text], max_evals=500)

            if target_class_idx is None:
                # shap_values.values has shape [samples, tokens, classes];
                # pick the class with the largest mean absolute attribution.
                target_class_idx = np.argmax(np.mean(np.abs(shap_values.values[0]), axis=0))

            # Slice out the attributions for the chosen class.
            values = shap_values.values[0, :, target_class_idx]
            base_value = float(shap_values.base_values[0, target_class_idx])

            # Re-tokenize to get readable token strings (e.g. '##ing').
            # NOTE(review): assumes this segmentation matches SHAP's own
            # token axis 1:1 — zip() below silently truncates on mismatch;
            # confirm against shap_values.data[0] if alignment issues appear.
            tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))

            # Drop special tokens ([CLS]/[SEP]/[PAD]) from the output while
            # walking both sequences in lockstep to keep alignment. Set built
            # once rather than a fresh list per token.
            special = {self.tokenizer.cls_token, self.tokenizer.sep_token,
                       self.tokenizer.pad_token}
            result_tokens = []
            result_values = []
            for token, val in zip(tokens, values):
                if token in special:
                    continue
                result_tokens.append(token)
                result_values.append(float(val))

            return {
                'tokens': result_tokens,
                'values': result_values,
                'base_value': base_value,
                # int(): np.argmax returns np.int64, which breaks JSON
                # serialization downstream — normalize to a plain int.
                'target_class': int(target_class_idx),
                'prediction_value': float(base_value + np.sum(values)),
            }
        except Exception as e:
            # Best-effort API: callers get an error payload, not an exception.
            # logger.exception also records the traceback, unlike the bare
            # logger.error(f"...") it replaces.
            logger.exception("SHAP explanation failed: %s", e)
            return {'error': str(e)}

if __name__ == '__main__':
    # Smoke test: explain the routing of a sample billing ticket.
    from confidence_router import ConfidenceGatedRouter

    router = ConfidenceGatedRouter()
    explainer = SupportMindExplainer(router.model, router.tokenizer)

    test_text = "My invoice is wrong, please fix the billing error."
    res = explainer.explain(test_text)
    # explain() reports failure via an 'error' key instead of raising;
    # the previous version crashed with KeyError: 'tokens' in that case.
    if 'error' in res:
        print(f"Explanation failed: {res['error']}")
    else:
        print(f"Tokens: {res['tokens']}")
        print(f"Values: {res['values']}")