Spaces:
Sleeping
Sleeping
File size: 9,815 Bytes
# utils/explainers.py (fixed SHAP implementation)
import lime
import lime.lime_text
import shap
import numpy as np
import torch
from captum.attr import LayerIntegratedGradients, visualization
class BaseExplainer:
    """Shared model/tokenizer holder and probability helper for text explainers."""

    def __init__(self, model, tokenizer):
        # NOTE(review): presumably a Hugging Face sequence-classification model
        # and its tokenizer; this class only relies on `model(**inputs).logits`
        # and the tokenizer's __call__ returning a mapping of tensors.
        self.model = model
        self.tokenizer = tokenizer

    def _model_device(self):
        """Return the device holding the model's parameters (CPU if there are none)."""
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device("cpu")

    def predict_proba(self, texts):
        """Return class probabilities for one or more texts.

        Used as LIME's ``classifier_fn`` by ``LimeExplainer``.

        Args:
            texts: A single string or a list of strings. All texts are
                tokenized together (padded, truncated to 512 tokens) and run
                in a single forward pass.

        Returns:
            ``np.ndarray`` of shape ``(len(texts), num_logits)`` holding the
            softmax of the model's logits, one row per text.
        """
        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        # Tokenizers return CPU tensors; move them next to the model's weights
        # so GPU-hosted models work.
        device = self._model_device()
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=1)
        # .cpu() is required before .numpy() when the model lives on a GPU.
        return probabilities.cpu().numpy()
class LimeExplainer(BaseExplainer):
    """LIME explanations for a text classifier, using ``predict_proba`` as the classifier."""

    def explain(self, text, num_features=10, num_samples=50):
        """Explain ``text`` with LIME.

        Args:
            text: The input string to explain.
            num_features: Maximum number of (word, weight) pairs to return.
            num_samples: Number of perturbed texts LIME generates. The default
                of 50 is kept for speed; larger values give more stable
                weights at proportionally higher cost.

        Returns:
            ``exp.as_list()``: a list of ``(feature, weight)`` tuples. No
            ``labels``/``top_labels`` are passed, so this is for LIME's default
            label (class index 1). NOTE(review): confirm that class is the one
            callers want rather than the predicted class.
        """
        # Create explainer
        explainer = lime.lime_text.LimeTextExplainer(
            class_names=[f"Class {i}" for i in range(self.model.config.num_labels)]
        )
        # Generate explanation
        exp = explainer.explain_instance(
            text,
            self.predict_proba,
            num_features=num_features,
            num_samples=num_samples
        )
        # Return as list of (feature, weight) tuples for consistency with original LIME format
        return exp.as_list()
class ShapExplainer(BaseExplainer):
    """Token attributions computed with SHAP's text masker.

    If SHAP raises for any reason, the error is printed and
    ``simple_shap_explanation`` is returned instead. That fallback produces
    keyword/position heuristics, NOT real SHAP values.
    """

    # Token strings never reported in the explanation (BERT/RoBERTa special
    # tokens plus the empty string).
    _SKIP_TOKENS = frozenset(['', '[CLS]', '[SEP]', '[PAD]', '<s>', '</s>'])

    def __init__(self, model, tokenizer):
        super().__init__(model, tokenizer)

    def predict(self, texts):
        """SHAP-compatible predict function returning raw logits.

        Args:
            texts: A single string or an iterable of strings (SHAP may pass
                a NumPy array of strings).

        Returns:
            ``np.ndarray`` with one row of logits per text. An empty input
            yields an empty ``(0, num_labels)`` array instead of raising.
        """
        # Convert to list if it's a single string
        if isinstance(texts, str):
            texts = [texts]

        # Run on whatever device holds the model's weights (CPU if none).
        try:
            device = next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device("cpu")

        self.model.eval()
        all_logits = []
        with torch.no_grad():
            # One forward pass per text, so no text is padded against another.
            for text in texts:
                inputs = self.tokenizer(
                    text,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512
                )
                inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
                outputs = self.model(**inputs)
                # .cpu() so the NumPy conversion also works for GPU models.
                all_logits.append(outputs.logits.detach().cpu().numpy())
        if not all_logits:
            return np.zeros((0, self.model.config.num_labels), dtype=np.float32)
        return np.vstack(all_logits)

    def explain(self, text):
        """Explain ``text`` with SHAP for the model's predicted class.

        Returns:
            List of ``{'token', 'value', 'position'}`` dicts, one per
            non-special, non-blank token from SHAP's text masker.
            ``value`` is the token's SHAP value for the predicted class's
            logit, which is the same target ``CaptumExplainer`` attributes.
            ``position`` is the token's index in SHAP's token sequence.
            On any exception, the heuristic fallback is returned instead.
        """
        try:
            # Create a masker that handles the tokenization
            masker = shap.maskers.Text(self.tokenizer)
            explainer = shap.Explainer(
                self.predict,
                masker,
                output_names=[f"Class {i}" for i in range(self.model.config.num_labels)]
            )
            shap_values = explainer([text])

            tokens = shap_values.data[0]
            values = np.asarray(shap_values.values[0])
            if values.ndim == 2:
                # One value per (token, output logit): keep the predicted
                # class's column. Summing across classes would mix opposing
                # logits and largely cancel out.
                predicted_class = int(np.argmax(self.predict([text])[0]))
                token_values = values[:, predicted_class]
            else:
                # Single-output model: already one value per token.
                token_values = values

            return [
                {'token': token, 'value': float(value), 'position': i}
                for i, (token, value) in enumerate(zip(tokens, token_values))
                if token not in self._SKIP_TOKENS and token.strip()
            ]
        except Exception as e:
            print(f"SHAP explanation error: {e}")
            # Fallback to a simpler approach
            return self.simple_shap_explanation(text)

    def simple_shap_explanation(self, text):
        """Heuristic placeholder used when SHAP fails.

        Values are derived only from a small keyword list and the token's
        index; they are not computed from the model. Tokens starting with
        ``##`` (WordPiece continuations) are skipped.
        """
        # Tokenize the text
        tokens = self.tokenizer.tokenize(text)
        explanation_data = []
        for i, token in enumerate(tokens):
            if not token.startswith('##'):  # Only add main tokens, not subword parts
                # Simple heuristic based on position and token content
                value = 0.0
                if any(keyword in token.lower() for keyword in ['good', 'great', 'excellent', 'positive']):
                    value = 0.2 + (i % 3) * 0.1
                elif any(keyword in token.lower() for keyword in ['bad', 'poor', 'terrible', 'negative']):
                    value = -0.2 - (i % 3) * 0.1
                elif i % 4 == 0:
                    value = 0.1
                elif i % 4 == 2:
                    value = -0.1
                explanation_data.append({
                    'token': token.replace('##', ''),
                    'value': value,
                    'position': i
                })
        return explanation_data
class CaptumExplainer:
    """Layer Integrated Gradients attributions (via Captum) for a text classifier.

    If attribution raises for any reason, the error is printed and
    ``simple_captum_explanation`` is returned instead. That fallback produces
    keyword/position heuristics, NOT real attributions.
    """

    # Token strings never reported in the explanation.
    _SPECIAL_TOKENS = frozenset(['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>'])

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.embedding_layer = self._find_embedding_layer(model)
        self.lig = LayerIntegratedGradients(self.forward_func, self.embedding_layer)

    @staticmethod
    def _find_embedding_layer(model):
        """Pick the layer whose output Integrated Gradients interpolates."""
        # Known backbones expose their embeddings module directly.
        for backbone in ('bert', 'roberta', 'albert', 'distilbert'):
            if hasattr(model, backbone):
                return getattr(model, backbone).embeddings
        # Otherwise take the first submodule whose name mentions "embedding".
        for name, module in model.named_modules():
            if 'embedding' in name.lower():
                return module
        # Then the model's own input-embedding accessor, if it provides one.
        try:
            embeddings = model.get_input_embeddings()
        except (AttributeError, NotImplementedError):
            embeddings = None
        if embeddings is not None:
            return embeddings
        # Last resort: the root module itself (unlikely to give useful results).
        return next(model.modules())

    def _device(self):
        """Return the device holding the model's parameters (CPU if there are none)."""
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device("cpu")

    def forward_func(self, inputs, attention_mask=None):
        """Forward pass for Captum: token ids (+ optional mask) -> logits."""
        if attention_mask is not None:
            return self.model(inputs, attention_mask=attention_mask).logits
        return self.model(inputs).logits

    def explain(self, text):
        """Attribute the predicted class's logit to the input tokens.

        Returns:
            List of ``{'token', 'value', 'position'}`` dicts, one per
            non-special token. ``value`` is the attribution summed over the
            embedding dimension and L2-normalised across tokens; it is left
            unscaled when the norm is 0. Subword tokens are kept with ``##``
            removed. ``position`` is the token's index in the tokenized input.
            On any exception, the heuristic fallback is returned instead.
        """
        try:
            # Dropout would make the attributions nondeterministic.
            self.model.eval()
            device = self._device()

            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)

            # Get predicted class to use as target
            with torch.no_grad():
                outputs = self.model(input_ids, attention_mask=attention_mask)
            predicted_class = torch.argmax(outputs.logits, dim=1).item()

            # Baseline is an all-padding sequence. Id 0 is not padding for
            # every tokenizer (e.g. RoBERTa uses 0 for <s>), so use
            # pad_token_id when the tokenizer defines one.
            pad_id = getattr(self.tokenizer, 'pad_token_id', None)
            baseline = torch.full_like(input_ids, pad_id if pad_id is not None else 0)

            attributions = self.lig.attribute(
                inputs=input_ids,
                baselines=baseline,
                target=predicted_class,
                additional_forward_args=(attention_mask,),
                n_steps=25,
                internal_batch_size=1
            )

            # Collapse the embedding dimension to one score per token.
            attributions_sum = attributions.sum(dim=-1).squeeze(0)
            norm = torch.norm(attributions_sum)
            if norm > 0:  # avoid 0/0 -> NaN when every attribution is zero
                attributions_sum = attributions_sum / norm
            attributions_sum = attributions_sum.detach().cpu().numpy()

            tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

            explanation_data = []
            for i, (token, attribution) in enumerate(zip(tokens, attributions_sum)):
                # Skip special tokens; keep subword pieces but drop the "##" marker.
                if token not in self._SPECIAL_TOKENS:
                    explanation_data.append({
                        'token': token.replace('##', ''),
                        'value': float(attribution),
                        'position': i
                    })
            return explanation_data
        except Exception as e:
            print(f"Captum explanation error: {e}")
            # Fallback to a simple explanation
            return self.simple_captum_explanation(text)

    def simple_captum_explanation(self, text):
        """Heuristic placeholder used when Captum fails.

        Values are derived only from a small keyword list and the token's
        index; they are not computed from the model. Tokens starting with
        ``##`` (WordPiece continuations) are skipped.
        """
        # Tokenize the text
        tokens = self.tokenizer.tokenize(text)
        explanation_data = []
        for i, token in enumerate(tokens):
            if not token.startswith('##'):  # Only add main tokens, not subword parts
                # Simple heuristic based on position and token content
                value = 0.0
                if any(keyword in token.lower() for keyword in ['good', 'great', 'excellent', 'positive']):
                    value = 0.15 + (i % 3) * 0.05
                elif any(keyword in token.lower() for keyword in ['bad', 'poor', 'terrible', 'negative']):
                    value = -0.15 - (i % 3) * 0.05
                elif i % 5 == 0:
                    value = 0.08
                elif i % 5 == 3:
                    value = -0.08
                explanation_data.append({
                    'token': token.replace('##', ''),
                    'value': value,
                    'position': i
                })
        return explanation_data