Spaces:
Sleeping
Sleeping
| # utils/explainers.py (fixed SHAP implementation) | |
| import lime | |
| import lime.lime_text | |
| import shap | |
| import numpy as np | |
| import torch | |
| from captum.attr import LayerIntegratedGradients, visualization | |
class BaseExplainer:
    """Shared state and prediction helper for the text-classifier explainers.

    Holds a classification ``model`` and its ``tokenizer``; subclasses use
    :meth:`predict_proba` as the black-box prediction function.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def _model_device(self):
        """Return the device holding the model's weights (CPU if unknown)."""
        device = getattr(self.model, "device", None)
        if device is not None:
            return device
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device("cpu")

    def predict_proba(self, texts):
        """Return class probabilities for one or more texts.

        LimeExplainer passes this method to LIME as its classifier function.

        Args:
            texts: a single string, or an iterable of strings (list, tuple,
                or NumPy array of strings).

        Returns:
            ``np.ndarray`` of shape ``(len(texts), num_classes)`` holding the
            softmax of the model's logits (always a CPU array).
        """
        # The tokenizer accepts only str or list[str]; normalise other
        # sequence types (e.g. NumPy string arrays) to a plain list.
        texts = [texts] if isinstance(texts, str) else list(texts)
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        # Inputs must live on the same device as the weights; otherwise a GPU
        # model fails with a device-mismatch error.
        device = self._model_device()
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**inputs)
        probabilities = torch.softmax(outputs.logits, dim=1)
        # .numpy() only works on CPU tensors.
        return probabilities.cpu().numpy()
class LimeExplainer(BaseExplainer):
    """Explain a single prediction with LIME's text explainer."""

    def explain(self, text, num_features=10, num_samples=50, label=1):
        """Return LIME word attributions for ``text``.

        Args:
            text: the input string to explain.
            num_features: maximum number of words in the result.
            num_samples: number of perturbed samples LIME scores with the
                model. The default of 50 keeps the call fast but gives noisy
                weights (LIME's own default is 5000).
            label: index of the class whose weights are returned. The default
                of 1 matches LIME's default explained label.

        Returns:
            list of ``(word, weight)`` tuples as produced by
            ``Explanation.as_list``.
        """
        class_names = [f"Class {i}" for i in range(self.model.config.num_labels)]
        explainer = lime.lime_text.LimeTextExplainer(class_names=class_names)
        exp = explainer.explain_instance(
            text,
            self.predict_proba,
            labels=(label,),
            num_features=num_features,
            num_samples=num_samples,
        )
        # Same (feature, weight) list format as plain LIME, for the chosen label.
        return exp.as_list(label=label)
class ShapExplainer(BaseExplainer):
    """Explain a prediction with SHAP's text masker over the model's logits."""

    # Token strings with no lexical content; they are dropped from the output.
    _SPECIAL_TOKENS = frozenset({"", "[CLS]", "[SEP]", "[PAD]", "<s>", "</s>"})

    def predict(self, texts):
        """Return raw logits for one or more texts (SHAP-compatible).

        Args:
            texts: a single string or an iterable of strings; SHAP's masker
                may pass a NumPy array of strings.

        Returns:
            ``np.ndarray`` of shape ``(len(texts), num_classes)``.
        """
        # The tokenizer accepts only str or list[str].
        texts = [texts] if isinstance(texts, str) else list(texts)
        if not texts:
            return np.empty((0, self.model.config.num_labels), dtype=np.float32)

        # One padded batch instead of one forward pass per text. The
        # attention_mask returned with padding=True lets the model ignore pad
        # positions.
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        device = getattr(self.model, "device", None)
        if device is None:
            device = next(self.model.parameters()).device
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        self.model.eval()
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # .numpy() only works on CPU tensors.
        return logits.detach().cpu().numpy()

    def explain(self, text):
        """Return per-token SHAP values toward the model's predicted class.

        If SHAP raises, the error is printed and the heuristic
        :meth:`simple_shap_explanation` result is returned instead.

        Returns:
            list of dicts with keys ``'token'`` (the masker's token string),
            ``'value'`` (SHAP value for the predicted class's logit) and
            ``'position'`` (index within the masker's token sequence).
        """
        try:
            masker = shap.maskers.Text(self.tokenizer)
            explainer = shap.Explainer(
                self.predict,
                masker,
                output_names=[f"Class {i}" for i in range(self.model.config.num_labels)],
            )
            shap_values = explainer([text])

            # Attribute toward the class the model actually predicts. Summing
            # over all classes would mix contributions of opposite meaning
            # into a single number.
            predicted_class = int(np.argmax(self.predict([text])[0]))
            tokens = shap_values.data[0]
            values = np.asarray(shap_values.values[0])
            if values.ndim == 2:  # one column per model output
                values = values[:, predicted_class]

            explanation_data = []
            for position, (token, value) in enumerate(zip(tokens, values)):
                # Empty / whitespace-only tokens strip to "" and are skipped too.
                if token.strip() in self._SPECIAL_TOKENS:
                    continue
                explanation_data.append({
                    'token': token,
                    'value': float(value),
                    'position': position,
                })
            return explanation_data
        except Exception as e:
            print(f"SHAP explanation error: {e}")
            # Fallback to a simpler approach
            return self.simple_shap_explanation(text)

    def simple_shap_explanation(self, text):
        """Simpler SHAP implementation as fallback.

        NOTE(review): these are heuristic placeholder values derived from
        keyword matches and token position, not SHAP values computed from
        the model. Positions index ``tokenizer.tokenize`` output, which
        differs from the masker token indexing used by :meth:`explain`.
        """
        # Tokenize the text
        tokens = self.tokenizer.tokenize(text)
        # Create a simple explanation with placeholder values
        explanation_data = []
        for i, token in enumerate(tokens):
            if not token.startswith('##'):  # Only add main tokens, not subword parts
                # Simple heuristic based on position and token content
                value = 0.0
                if any(keyword in token.lower() for keyword in ['good', 'great', 'excellent', 'positive']):
                    value = 0.2 + (i % 3) * 0.1
                elif any(keyword in token.lower() for keyword in ['bad', 'poor', 'terrible', 'negative']):
                    value = -0.2 - (i % 3) * 0.1
                elif i % 4 == 0:
                    value = 0.1
                elif i % 4 == 2:
                    value = -0.1
                explanation_data.append({
                    'token': token.replace('##', ''),
                    'value': value,
                    'position': i
                })
        return explanation_data
class CaptumExplainer:
    """Explain a classifier prediction with Captum's Layer Integrated Gradients.

    Attributions are taken through the model's embedding layer toward the
    predicted class and returned as one L2-normalised score per non-special
    token.
    """

    # Token strings dropped from the returned attributions.
    _SPECIAL_TOKENS = ('[CLS]', '[SEP]', '[PAD]', '<s>', '</s>')

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.embedding_layer = self._find_embedding_layer(model)
        self.lig = LayerIntegratedGradients(self.forward_func, self.embedding_layer)

    @staticmethod
    def _find_embedding_layer(model):
        """Return the module whose output LIG attributes over."""
        # Known Hugging Face backbones expose their embeddings block directly.
        for backbone in ('bert', 'roberta', 'albert', 'distilbert'):
            if hasattr(model, backbone):
                return getattr(model, backbone).embeddings
        # Otherwise take the first submodule whose qualified name mentions embeddings.
        for name, module in model.named_modules():
            if 'embedding' in name.lower():
                return module
        # Some models name the embedding table without the word "embedding"
        # (e.g. GPT-2's "wte"); the standard accessor still finds it.
        get_input_embeddings = getattr(model, 'get_input_embeddings', None)
        if callable(get_input_embeddings):
            try:
                layer = get_input_embeddings()
            except NotImplementedError:
                layer = None
            if layer is not None:
                return layer
        # Last resort: the root module. Attribution through it is unlikely to
        # be meaningful; explain() falls back if LIG fails.
        return next(model.modules())

    def _model_device(self):
        """Return the device holding the model's weights (CPU if unknown)."""
        device = getattr(self.model, 'device', None)
        if device is not None:
            return device
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device('cpu')

    def forward_func(self, inputs, attention_mask=None):
        """Forward pass used by Captum: return the model's logits."""
        if attention_mask is not None:
            return self.model(inputs, attention_mask=attention_mask).logits
        return self.model(inputs).logits

    def explain(self, text):
        """Return per-token LIG attributions toward the predicted class.

        If anything raises, the error is printed and the heuristic
        :meth:`simple_captum_explanation` result is returned instead.

        Returns:
            list of dicts with keys ``'token'`` (``##`` markers stripped),
            ``'value'`` (attribution, L2-normalised over the sequence) and
            ``'position'`` (index in the tokenized input, special tokens
            included).
        """
        try:
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            device = self._model_device()
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)

            # Disable dropout so every integration step sees the same network.
            self.model.eval()
            with torch.no_grad():
                logits = self.model(input_ids, attention_mask=attention_mask).logits
            predicted_class = int(torch.argmax(logits, dim=1).item())

            # Reference input: a sequence of padding tokens. Id 0 is not the
            # pad id in every vocabulary (RoBERTa maps 0 to <s>, 1 to <pad>).
            pad_id = getattr(self.tokenizer, 'pad_token_id', None)
            baseline = torch.full_like(input_ids, pad_id if pad_id is not None else 0)

            attributions = self.lig.attribute(
                inputs=input_ids,
                baselines=baseline,
                target=predicted_class,
                additional_forward_args=(attention_mask,),
                n_steps=25,
                internal_batch_size=1,
            )

            # Collapse the embedding dimension to one score per token, then
            # L2-normalise; an all-zero vector is left as is to avoid NaN.
            scores = attributions.sum(dim=-1).squeeze(0)
            norm = torch.norm(scores)
            if norm > 0:
                scores = scores / norm
            scores = scores.detach().cpu().numpy()

            tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
            explanation_data = []
            for position, (token, score) in enumerate(zip(tokens, scores)):
                # Skip special tokens; strip WordPiece "##" continuation markers.
                if token in self._SPECIAL_TOKENS:
                    continue
                explanation_data.append({
                    'token': token.replace('##', ''),
                    'value': float(score),
                    'position': position,
                })
            return explanation_data
        except Exception as e:
            print(f"Captum explanation error: {e}")
            # Fallback to a simple explanation
            return self.simple_captum_explanation(text)

    def simple_captum_explanation(self, text):
        """Simpler Captum implementation as fallback.

        NOTE(review): these are heuristic placeholder values derived from
        keyword matches and token position, not gradients of the model.
        Positions index ``tokenizer.tokenize`` output (no special tokens),
        unlike :meth:`explain`.
        """
        # Tokenize the text
        tokens = self.tokenizer.tokenize(text)
        # Create a simple explanation with placeholder values
        explanation_data = []
        for i, token in enumerate(tokens):
            if not token.startswith('##'):  # Only add main tokens, not subword parts
                # Simple heuristic based on position and token content
                value = 0.0
                if any(keyword in token.lower() for keyword in ['good', 'great', 'excellent', 'positive']):
                    value = 0.15 + (i % 3) * 0.05
                elif any(keyword in token.lower() for keyword in ['bad', 'poor', 'terrible', 'negative']):
                    value = -0.15 - (i % 3) * 0.05
                elif i % 5 == 0:
                    value = 0.08
                elif i % 5 == 3:
                    value = -0.08
                explanation_data.append({
                    'token': token.replace('##', ''),
                    'value': value,
                    'position': i
                })
        return explanation_data