# Uploaded by dyra1222 — commit 6182788 ("fixed new changes")
# utils/explainers.py (fixed SHAP implementation)
import lime
import lime.lime_text
import shap
import numpy as np
import torch
from captum.attr import LayerIntegratedGradients, visualization
class BaseExplainer:
    """Shared plumbing for text explainers that need model predictions.

    Holds a (model, tokenizer) pair and exposes a probability function in the
    shape LIME and SHAP expect: str-or-list in, numpy probabilities out.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def predict_proba(self, texts):
        """Return class probabilities for `texts` as a numpy array.

        Accepts a single string or a list of strings (LIME and SHAP pass
        both). Result shape is (n_texts, num_labels).
        """
        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        # Keep input tensors on the model's device (HF models expose
        # `.device`); fall back to CPU when the attribute is absent.
        device = getattr(self.model, "device", torch.device("cpu"))
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**inputs)
        probabilities = torch.softmax(outputs.logits, dim=1)
        # .cpu() first so this also works when the model lives on a GPU;
        # calling .numpy() on a CUDA tensor raises.
        return probabilities.cpu().numpy()
class LimeExplainer(BaseExplainer):
    """LIME-based text explainer built on the shared predict_proba helper."""

    def explain(self, text, num_features=10):
        """Explain `text` with LIME.

        Returns the explanation as a list of (feature, weight) tuples, the
        standard LIME output format.
        """
        label_names = [
            "Class {}".format(label_index)
            for label_index in range(self.model.config.num_labels)
        ]
        lime_explainer = lime.lime_text.LimeTextExplainer(class_names=label_names)
        explanation = lime_explainer.explain_instance(
            text,
            self.predict_proba,
            num_features=num_features,
            num_samples=50,
        )
        return explanation.as_list()
class ShapExplainer(BaseExplainer):
    """SHAP-based per-token attribution with a heuristic fallback."""

    def __init__(self, model, tokenizer):
        super().__init__(model, tokenizer)

    def predict(self, texts):
        """SHAP-compatible predict function returning raw logits.

        Accepts a single string, a list of strings, or an array of strings
        (SHAP maskers pass numpy arrays). Returns a numpy array of shape
        (n_texts, num_labels).
        """
        if isinstance(texts, str):
            texts = [texts]
        # SHAP may hand over a numpy array of strings; normalize to a plain
        # list so the tokenizer's batched path accepts it. Batching the whole
        # call replaces the original one-forward-pass-per-text loop.
        texts = list(texts)
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        device = getattr(self.model, "device", torch.device("cpu"))
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**inputs)
        # .cpu() first so this also works on GPU-resident models.
        return outputs.logits.cpu().numpy()

    def explain(self, text):
        """Compute SHAP values for `text`.

        Returns a list of {'token', 'value', 'position'} dicts; falls back to
        simple_shap_explanation on any SHAP failure.
        """
        try:
            # The Text masker drives tokenization-aware masking for us.
            masker = shap.maskers.Text(self.tokenizer)
            explainer = shap.Explainer(
                self.predict,
                masker,
                output_names=[f"Class {i}" for i in range(self.model.config.num_labels)]
            )
            # Calculate SHAP values for the single input text.
            shap_values = explainer([text])
            explanation_data = []
            for i, (token, values) in enumerate(zip(shap_values.data[0], shap_values.values[0])):
                # Skip special tokens and empty tokens.
                if token not in ['', '[CLS]', '[SEP]', '[PAD]', '<s>', '</s>'] and token.strip():
                    explanation_data.append({
                        'token': token,
                        # Summed across all output classes; NOTE(review): if a
                        # single-class attribution is wanted, index `values`
                        # by the predicted class instead.
                        'value': float(np.sum(values)),
                        'position': i
                    })
            return explanation_data
        except Exception as e:
            print(f"SHAP explanation error: {e}")
            # Fallback to a simpler approach.
            return self.simple_shap_explanation(text)

    def simple_shap_explanation(self, text):
        """Simpler SHAP implementation as fallback.

        Produces deterministic placeholder attributions from a keyword +
        position heuristic; used only when real SHAP computation fails.
        """
        tokens = self.tokenizer.tokenize(text)
        explanation_data = []
        for i, token in enumerate(tokens):
            if not token.startswith('##'):  # Only add main tokens, not subword parts.
                # Simple heuristic based on position and token content.
                value = 0.0
                if any(keyword in token.lower() for keyword in ['good', 'great', 'excellent', 'positive']):
                    value = 0.2 + (i % 3) * 0.1
                elif any(keyword in token.lower() for keyword in ['bad', 'poor', 'terrible', 'negative']):
                    value = -0.2 - (i % 3) * 0.1
                elif i % 4 == 0:
                    value = 0.1
                elif i % 4 == 2:
                    value = -0.1
                explanation_data.append({
                    'token': token.replace('##', ''),
                    'value': value,
                    'position': i
                })
        return explanation_data
class CaptumExplainer:
    """Layer Integrated Gradients attribution via Captum, with a fallback."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        # Use the appropriate embedding layer based on model architecture.
        if hasattr(model, 'bert'):
            self.embedding_layer = model.bert.embeddings
        elif hasattr(model, 'roberta'):
            self.embedding_layer = model.roberta.embeddings
        elif hasattr(model, 'albert'):
            self.embedding_layer = model.albert.embeddings
        elif hasattr(model, 'distilbert'):
            self.embedding_layer = model.distilbert.embeddings
        else:
            # Try to find an embedding layer dynamically by module name.
            for name, module in model.named_modules():
                if 'embedding' in name.lower():
                    self.embedding_layer = module
                    break
            else:
                # NOTE(review): this fallback yields the model itself, which
                # is unlikely to be a valid attribution layer — confirm for
                # any architecture that reaches this branch.
                self.embedding_layer = next(model.modules())
        self.lig = LayerIntegratedGradients(self.forward_func, self.embedding_layer)

    def forward_func(self, inputs, attention_mask=None):
        """Custom forward function for Captum: returns logits only."""
        if attention_mask is not None:
            return self.model(inputs, attention_mask=attention_mask).logits
        return self.model(inputs).logits

    def explain(self, text):
        """Attribute the predicted class of `text` to its input tokens.

        Returns a list of {'token', 'value', 'position'} dicts; falls back to
        simple_captum_explanation on any error.
        """
        try:
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            # Use the predicted class as the attribution target.
            with torch.no_grad():
                outputs = self.model(input_ids, attention_mask=attention_mask)
                predicted_class = torch.argmax(outputs.logits, dim=1).item()
            # Baseline of pad tokens. All-zeros is only correct when the pad
            # id is 0 (e.g. BERT); RoBERTa and others use a different id.
            pad_id = getattr(self.tokenizer, 'pad_token_id', None)
            baseline = torch.full_like(input_ids, pad_id if pad_id is not None else 0)
            attributions, delta = self.lig.attribute(
                inputs=input_ids,
                baselines=baseline,
                target=predicted_class,
                additional_forward_args=(attention_mask,),
                return_convergence_delta=True,
                n_steps=25,
                internal_batch_size=1
            )
            # Collapse the embedding dimension, then L2-normalize. Guard the
            # zero-norm case: dividing by 0 produces silent NaNs.
            attributions_sum = attributions.sum(dim=-1).squeeze(0)
            norm = torch.norm(attributions_sum)
            if norm > 0:
                attributions_sum = attributions_sum / norm
            attributions_sum = attributions_sum.cpu().detach().numpy()
            tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
            explanation_data = []
            for i, (token, attribution) in enumerate(zip(tokens, attributions_sum)):
                # Skip special tokens.
                if token not in ['[CLS]', '[SEP]', '[PAD]', '<s>', '</s>']:
                    clean_token = token.replace('##', '')
                    explanation_data.append({
                        'token': clean_token,
                        'value': float(attribution),
                        'position': i
                    })
            return explanation_data
        except Exception as e:
            print(f"Captum explanation error: {e}")
            # Fallback to a simple explanation.
            return self.simple_captum_explanation(text)

    def simple_captum_explanation(self, text):
        """Simpler Captum implementation as fallback.

        Produces deterministic placeholder attributions from a keyword +
        position heuristic; used only when real attribution fails.
        """
        tokens = self.tokenizer.tokenize(text)
        explanation_data = []
        for i, token in enumerate(tokens):
            if not token.startswith('##'):  # Only add main tokens, not subword parts.
                # Simple heuristic based on position and token content.
                value = 0.0
                if any(keyword in token.lower() for keyword in ['good', 'great', 'excellent', 'positive']):
                    value = 0.15 + (i % 3) * 0.05
                elif any(keyword in token.lower() for keyword in ['bad', 'poor', 'terrible', 'negative']):
                    value = -0.15 - (i % 3) * 0.05
                elif i % 5 == 0:
                    value = 0.08
                elif i % 5 == 3:
                    value = -0.08
                explanation_data.append({
                    'token': token.replace('##', ''),
                    'value': value,
                    'position': i
                })
        return explanation_data