# utils/explainers.py (fixed SHAP implementation)
import lime
import lime.lime_text
import shap
import numpy as np
import torch
from captum.attr import LayerIntegratedGradients, visualization

# Special tokens dropped from explanations. The original skip lists contained
# triplicated empty strings where markers were presumably lost; '<s>'/'</s>'
# (RoBERTa-style) restored here -- TODO confirm against the tokenizer in use.
_SPECIAL_TOKENS = frozenset({'', '[CLS]', '[SEP]', '[PAD]', '<s>', '</s>'})


def _keyword_heuristic(tokens, base, step, mod, neg_rem, small):
    """Shared placeholder-attribution fallback used by both explainers.

    Sentiment keywords get +/-(base + (i % 3) * step); other tokens get
    +/-small on a positional cycle of length `mod` (positive at remainder 0,
    negative at remainder `neg_rem`). Subword pieces ('##...') are skipped.

    Returns a list of {'token', 'value', 'position'} dicts.
    """
    explanation_data = []
    for i, token in enumerate(tokens):
        if token.startswith('##'):  # only main tokens, not subword parts
            continue
        lowered = token.lower()
        value = 0.0
        if any(k in lowered for k in ('good', 'great', 'excellent', 'positive')):
            value = base + (i % 3) * step
        elif any(k in lowered for k in ('bad', 'poor', 'terrible', 'negative')):
            value = -base - (i % 3) * step
        elif i % mod == 0:
            value = small
        elif i % mod == neg_rem:
            value = -small
        explanation_data.append({
            'token': token.replace('##', ''),
            'value': value,
            'position': i,
        })
    return explanation_data


class BaseExplainer:
    """Common model/tokenizer plumbing shared by the LIME and SHAP explainers."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def predict_proba(self, texts):
        """Return softmax class probabilities for LIME and SHAP.

        Accepts a single string or a list of strings; returns a
        (batch, num_labels) numpy array.
        """
        if isinstance(texts, str):
            texts = [texts]
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**inputs)
        probabilities = torch.softmax(outputs.logits, dim=1)
        # .cpu() so this also works when the model parameters live on a GPU.
        return probabilities.cpu().numpy()


class LimeExplainer(BaseExplainer):
    """LIME-based token attribution for a sequence-classification model."""

    def explain(self, text, num_features=10):
        """Explain `text`; returns a list of (feature, weight) tuples."""
        explainer = lime.lime_text.LimeTextExplainer(
            class_names=[f"Class {i}" for i in range(self.model.config.num_labels)]
        )
        # NOTE(review): num_samples=50 is very low for LIME and yields noisy
        # weights; kept as-is to preserve the original runtime cost.
        exp = explainer.explain_instance(
            text,
            self.predict_proba,
            num_features=num_features,
            num_samples=50,
        )
        # List of (feature, weight) tuples, consistent with LIME's format.
        return exp.as_list()


class ShapExplainer(BaseExplainer):
    """SHAP-based token attribution using the Text masker."""

    def predict(self, texts):
        """SHAP-compatible predict: raw logits stacked into one numpy array.

        Handles both a single string and a list/array of strings.
        """
        if isinstance(texts, str):
            texts = [texts]
        self.model.eval()  # hoisted: was redundantly re-set on every iteration
        all_logits = []
        for text in texts:
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            )
            with torch.no_grad():
                outputs = self.model(**inputs)
            all_logits.append(outputs.logits.detach().cpu().numpy())
        return np.vstack(all_logits)

    def explain(self, text):
        """Return [{'token', 'value', 'position'}, ...] attributions for `text`.

        Falls back to the keyword heuristic if the SHAP computation fails.
        """
        try:
            # Masker handles the tokenization/masking of the input text.
            masker = shap.maskers.Text(self.tokenizer)
            explainer = shap.Explainer(
                self.predict,
                masker,
                output_names=[f"Class {i}" for i in range(self.model.config.num_labels)],
            )
            shap_values = explainer([text])

            explanation_data = []
            for i, (token, values) in enumerate(
                zip(shap_values.data[0], shap_values.values[0])
            ):
                # Skip special tokens and whitespace-only tokens.
                if token in _SPECIAL_TOKENS or not token.strip():
                    continue
                explanation_data.append({
                    'token': token,
                    'value': float(np.sum(values)),  # sum across all classes
                    'position': i,
                })
            return explanation_data
        except Exception as e:
            print(f"SHAP explanation error: {e}")
            # Fallback to a simpler approach
            return self.simple_shap_explanation(text)

    def simple_shap_explanation(self, text):
        """Simpler SHAP fallback: placeholder values from a keyword heuristic."""
        tokens = self.tokenizer.tokenize(text)
        return _keyword_heuristic(tokens, base=0.2, step=0.1, mod=4, neg_rem=2, small=0.1)


class CaptumExplainer:
    """Layer Integrated Gradients token attribution via Captum."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # Pick the embedding layer for known architectures, otherwise search
        # the module tree by name.
        for attr in ('bert', 'roberta', 'albert', 'distilbert'):
            if hasattr(model, attr):
                self.embedding_layer = getattr(model, attr).embeddings
                break
        else:
            for name, module in model.named_modules():
                if 'embedding' in name.lower():
                    self.embedding_layer = module
                    break
            else:
                # Last resort: next(model.modules()) is the root module itself;
                # attribution quality is not guaranteed in this case.
                self.embedding_layer = next(model.modules())
        self.lig = LayerIntegratedGradients(self.forward_func, self.embedding_layer)

    def forward_func(self, inputs, attention_mask=None):
        """Forward wrapper returning logits, as Captum expects a tensor output."""
        if attention_mask is not None:
            return self.model(inputs, attention_mask=attention_mask).logits
        return self.model(inputs).logits

    def explain(self, text):
        """Return [{'token', 'value', 'position'}, ...] via integrated gradients.

        Falls back to the keyword heuristic if attribution fails.
        """
        try:
            inputs = self.tokenizer(
                text, return_tensors="pt", truncation=True, max_length=512
            )
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']

            # Use the model's own predicted class as the attribution target.
            self.model.eval()  # added: dropout must be off for a stable target
            with torch.no_grad():
                outputs = self.model(input_ids, attention_mask=attention_mask)
            predicted_class = torch.argmax(outputs.logits, dim=1).item()

            # Baseline of all-zero token ids (usually the pad/unk id).
            baseline = torch.zeros_like(input_ids)
            attributions, delta = self.lig.attribute(
                inputs=input_ids,
                baselines=baseline,
                target=predicted_class,
                additional_forward_args=(attention_mask,),
                return_convergence_delta=True,
                n_steps=25,
                internal_batch_size=1,
            )

            # Collapse the embedding dimension, then L2-normalize.
            attributions_sum = attributions.sum(dim=-1).squeeze(0)
            norm = torch.norm(attributions_sum)
            if norm > 0:  # guard: all-zero attributions would produce NaNs
                attributions_sum = attributions_sum / norm
            attributions_sum = attributions_sum.detach().cpu().numpy()

            tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
            explanation_data = []
            for i, (token, attribution) in enumerate(zip(tokens, attributions_sum)):
                # Skip special tokens; strip subword prefixes from the rest.
                if token in _SPECIAL_TOKENS:
                    continue
                explanation_data.append({
                    'token': token.replace('##', ''),
                    'value': float(attribution),
                    'position': i,
                })
            return explanation_data
        except Exception as e:
            print(f"Captum explanation error: {e}")
            # Fallback to a simple explanation
            return self.simple_captum_explanation(text)

    def simple_captum_explanation(self, text):
        """Simpler Captum fallback: placeholder values from a keyword heuristic."""
        tokens = self.tokenizer.tokenize(text)
        return _keyword_heuristic(tokens, base=0.15, step=0.05, mod=5, neg_rem=3, small=0.08)