# text_classificators/src/model_interpretation.py
from typing import List, Dict, Any, Optional, Union, Callable, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import torch
from sklearn.base import BaseEstimator
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import warnings
warnings.filterwarnings("ignore")
# Optional-dependency detection: each explainability backend is imported
# eagerly so later functions can use the module object directly; the flag
# records whether the import succeeded.
SHAP_AVAILABLE = True
try:
    import shap
except ImportError:
    SHAP_AVAILABLE = False

LIME_AVAILABLE = True
try:
    import lime
    import lime.lime_text
except ImportError:
    LIME_AVAILABLE = False

CAPTUM_AVAILABLE = True
try:
    import captum
    import captum.attr
except ImportError:
    CAPTUM_AVAILABLE = False

UMAP_AVAILABLE = True
try:
    import umap
except ImportError:
    UMAP_AVAILABLE = False
def get_linear_feature_importance(
    model: BaseEstimator,
    feature_names: Optional[List[str]] = None,
    class_index: int = -1
) -> pd.DataFrame:
    """Rank the features of a linear model by coefficient magnitude.

    Args:
        model: fitted estimator exposing a ``coef_`` attribute.
        feature_names: optional feature names; defaults to
            ``feature_0 .. feature_{n-1}`` when omitted.
        class_index: which row of a 2-D ``coef_`` to use; ``-1`` averages
            the coefficients across all classes.

    Returns:
        DataFrame with ``feature`` and ``weight`` columns, sorted by
        absolute weight in descending order.

    Raises:
        ValueError: if the model has no ``coef_`` attribute.
    """
    if not hasattr(model, "coef_"):
        raise ValueError("Model does not have coef_ attribute")
    coef = model.coef_
    if coef.ndim == 1:
        weights = coef
    elif class_index == -1:
        weights = np.mean(coef, axis=0)
    else:
        weights = coef[class_index]
    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(len(weights))]
    ranked = pd.DataFrame({"feature": feature_names, "weight": weights})
    ranked = ranked.sort_values("weight", key=abs, ascending=False)
    return ranked.reset_index(drop=True)
def analyze_tfidf_class_keywords(
    tfidf_matrix: np.ndarray,
    y: np.ndarray,
    feature_names: List[str],
    top_k: int = 20
) -> Dict[Any, pd.DataFrame]:
    """Find the highest-scoring TF-IDF terms for each class.

    For every distinct label in ``y``, the rows belonging to that class are
    averaged column-wise and the ``top_k`` terms with the largest mean score
    are reported.

    Args:
        tfidf_matrix: TF-IDF scores, dense array or sparse matrix.
        y: class label per row of ``tfidf_matrix``.
        feature_names: vocabulary term per column.
        top_k: number of terms to keep per class.

    Returns:
        Mapping of class label -> DataFrame with ``word`` and
        ``tfidf_score`` columns, strongest terms first.
    """
    keyword_tables: Dict[Any, pd.DataFrame] = {}
    for label in np.unique(y):
        class_mean = np.mean(tfidf_matrix[y == label], axis=0)
        # A sparse-matrix mean comes back as an np.matrix; A1 flattens to 1-D.
        if hasattr(tfidf_matrix, 'A1'):
            class_mean = class_mean.A1
        best = np.argsort(class_mean)[::-1][:top_k]
        keyword_tables[label] = pd.DataFrame({
            "word": [feature_names[i] for i in best],
            "tfidf_score": [class_mean[i] for i in best],
        })
    return keyword_tables
def explain_with_shap(
    model: BaseEstimator,
    X_train: np.ndarray,
    X_test: np.ndarray,
    feature_names: Optional[List[str]] = None,
    plot_type: str = "bar",
    max_display: int = 20
):
    """Render a SHAP summary plot for a fitted model.

    Tree-like models (by class-name heuristic) get the fast TreeExplainer;
    anything else falls back to the model-agnostic KernelExplainer. Both
    the background sample and the explained sample are capped at 100 rows
    to bound runtime.

    Args:
        model: fitted estimator; non-tree models must expose predict_proba.
        X_train: background data for KernelExplainer (first 100 rows used).
        X_test: data to explain (first 100 rows used).
        feature_names: column names; auto-generated ``feat_i`` when None.
        plot_type: shap summary plot style (e.g. "bar", "dot").
        max_display: maximum number of features shown.

    Raises:
        ImportError: if the optional shap dependency is not installed.
    """
    # Fail with a clear message instead of a NameError on `shap` below,
    # consistent with the other optional-dependency functions in this module.
    if not SHAP_AVAILABLE:
        raise ImportError("Install SHAP: pip install shap")
    # Crude heuristic: class names of tree ensembles usually contain "tree".
    if "tree" in str(type(model)).lower():
        explainer = shap.TreeExplainer(model)
    else:
        explainer = shap.KernelExplainer(model.predict_proba, X_train[:100])
    shap_values = explainer.shap_values(X_test[:100])
    if feature_names is None:
        feature_names = [f"feat_{i}" for i in range(X_test.shape[1])]
    plt.figure(figsize=(10, 6))
    # summary_plot natively handles both the multiclass (list of arrays) and
    # single-array shapes of shap_values, so no branching is needed here —
    # the original's two branches were byte-identical.
    shap.summary_plot(
        shap_values,
        X_test[:100],
        feature_names=feature_names,
        plot_type=plot_type,
        max_display=max_display,
        show=False,
    )
    plt.tight_layout()
    plt.show()
def explain_text_with_lime(
    model: Any,
    text: str,
    tokenizer: Callable,
    class_names: List[str],
    num_features: int = 10,
    num_samples: int = 5000
):
    """Explain one text prediction with LIME and render it in a notebook.

    Args:
        model: pipeline-like object exposing ``vectorizer`` and
            ``predict_proba``.
        text: the text to explain.
        tokenizer: kept for interface compatibility; the LIME perturbation
            pipeline vectorizes raw strings directly, so it is unused here
            (the original tokenized every perturbed sample and discarded
            the result).
        class_names: label names for the explainer's display.
        num_features: number of words shown in the explanation.
        num_samples: number of perturbed samples LIME draws.

    Raises:
        ImportError: if the optional lime dependency is not installed.
        NotImplementedError: if the model lacks a ``vectorizer`` attribute.
    """
    # Fail with a clear message instead of a NameError on `lime` below,
    # consistent with the other optional-dependency functions in this module.
    if not LIME_AVAILABLE:
        raise ImportError("Install LIME: pip install lime")

    def predict_fn(texts):
        if hasattr(model, "vectorizer"):
            X = model.vectorizer.transform(texts)
        else:
            raise NotImplementedError("Custom predict_fn needed for your pipeline")
        return model.predict_proba(X.toarray())

    explainer = lime.lime_text.LimeTextExplainer(class_names=class_names)
    exp = explainer.explain_instance(
        text, predict_fn, num_features=num_features, num_samples=num_samples
    )
    exp.show_in_notebook()
def visualize_attention_weights(
    tokens: List[str],
    attention_weights: np.ndarray,
    layer: int = 0,
    head: int = 0,
    figsize: Tuple[int, int] = (10, 2)
):
    """Heatmap of one attention head's token-to-token weights.

    Args:
        tokens: token strings; both heatmap axes are truncated to
            ``len(tokens)``.
        attention_weights: array shaped (layers, heads, seq, seq).
        layer: attention layer to display.
        head: attention head to display.
        figsize: matplotlib figure size.

    Raises:
        ValueError: if ``attention_weights`` is not 4-dimensional.
    """
    if attention_weights.ndim != 4:
        raise ValueError("attention_weights must be 4D: (layers, heads, seq, seq)")
    n_tokens = len(tokens)
    head_matrix = attention_weights[layer, head, :n_tokens, :n_tokens]
    plt.figure(figsize=figsize)
    sns.heatmap(
        head_matrix,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="viridis",
        cbar=True,
    )
    plt.title(f"Attention Layer {layer}, Head {head}")
    plt.xticks(rotation=45, ha="right")
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()
def get_transformer_attention(
    model: 'torch.nn.Module',
    tokenizer: 'transformers.PreTrainedTokenizer',
    text: str,
    device: str = "cpu"
) -> Tuple[List[str], np.ndarray]:
    """Extract all self-attention maps a transformer produces for one text.

    NOTE(review): the original gated this function on CAPTUM_AVAILABLE, but
    captum is never used here — only torch and the model's own
    ``output_attentions`` flag — so that spurious dependency check was removed.

    Args:
        model: transformer model whose forward accepts ``output_attentions``
            and returns an object with an ``attentions`` tuple.
        tokenizer: matching tokenizer (input truncated to 512 tokens).
        text: the text to encode.
        device: torch device string.

    Returns:
        Tuple of (tokens, attn) where attn is shaped
        (layers, heads, seq, seq) after the batch dimension is squeezed out.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs["input_ids"].to(device)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, output_attentions=True)
    attentions = outputs.attentions
    # Stack the per-layer tensors, then squeeze the batch dim (batch size 1):
    # (layers, 1, heads, seq, seq) -> (layers, heads, seq, seq).
    attn = torch.stack(attentions, dim=0).squeeze(1).cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy())
    return tokens, attn
def analyze_errors(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    texts: List[str],
    labels: Optional[List[Any]] = None
) -> pd.DataFrame:
    """Collect every misclassified sample into a DataFrame.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels, aligned with ``y_true``.
        texts: raw text per sample, aligned with the labels.
        labels: accepted for interface compatibility but currently unused.

    Returns:
        DataFrame with ``index``, ``text``, ``true_label`` and
        ``pred_label`` columns; empty (and column-less) when every
        prediction is correct.
    """
    mismatches = [
        {"index": i, "text": doc, "true_label": gold, "pred_label": guess}
        for i, (gold, guess, doc) in enumerate(zip(y_true, y_pred, texts))
        if gold != guess
    ]
    return pd.DataFrame(mismatches)
def compare_model_errors(
    models: Dict[str, BaseEstimator],
    X_test: np.ndarray,
    y_test: np.ndarray,
    texts: List[str]
) -> Dict[str, pd.DataFrame]:
    """Run every model on the same test set and tabulate its mistakes.

    Args:
        models: mapping of model name -> fitted estimator with ``predict``.
        X_test: feature matrix shared by all models.
        y_test: ground-truth labels for ``X_test``.
        texts: raw text per test sample.

    Returns:
        Mapping of model name -> error DataFrame from ``analyze_errors``.
    """
    return {
        name: analyze_errors(y_test, estimator.predict(X_test), texts)
        for name, estimator in models.items()
    }
def plot_embeddings(
    embeddings: np.ndarray,
    labels: np.ndarray,
    method: str = "umap",
    n_components: int = 2,
    figsize: Tuple[int, int] = (12, 8),
    title: str = "Embedding Projection"
):
    """Project high-dimensional embeddings to 2-D and scatter-plot them.

    Args:
        embeddings: (n_samples, dim) array to project.
        labels: per-sample values used to color the points.
        method: "umap" or "tsne".
        n_components: output dimensionality for the reducer.
        figsize: matplotlib figure size.
        title: plot title.

    Raises:
        ImportError: if method is "umap" but umap-learn is missing.
        ValueError: for an unknown method string.
    """
    if method == "umap":
        if not UMAP_AVAILABLE:
            raise ImportError("Install UMAP: pip install umap-learn")
        reducer = umap.UMAP(n_components=n_components, random_state=42, n_jobs=-1)
    elif method == "tsne":
        reducer = TSNE(n_components=n_components, random_state=42, n_jobs=-1)
    else:
        raise ValueError("method must be 'tsne' or 'umap'")
    projected = reducer.fit_transform(embeddings)
    plt.figure(figsize=figsize)
    points = plt.scatter(projected[:, 0], projected[:, 1], c=labels, cmap="tab10", alpha=0.7)
    plt.colorbar(points)
    plt.title(title)
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")
    plt.tight_layout()
    plt.show()
def get_token_importance_captum(
    model: 'torch.nn.Module',
    tokenizer: 'transformers.PreTrainedTokenizer',
    text: str,
    device: str = "cpu"
) -> Tuple[List[str], np.ndarray]:
    """Per-token attribution scores for the predicted class, via Layer
    Integrated Gradients over the model's embedding layer.

    Args:
        model: sequence-classification transformer returning ``.logits``.
        tokenizer: matching tokenizer; must define cls/sep token ids.
        text: input text (truncated to 512 tokens).
        device: torch device string.

    Returns:
        Tuple of (tokens, attributions), aligned 1:1. Positive scores push
        toward the predicted class, negative scores away from it.

    Raises:
        ImportError: if the optional captum dependency is not installed.
    """
    if not CAPTUM_AVAILABLE:
        raise ImportError("Install Captum: pip install captum")
    from captum.attr import LayerIntegratedGradients
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        pred_class = torch.argmax(outputs.logits, dim=1).item()

    def forward_func(ids):
        return model(input_ids=ids, attention_mask=attention_mask).logits

    # Baseline: zeroed ids with the special tokens kept in place.
    baseline_ids = torch.zeros_like(input_ids).to(device)
    baseline_ids[:, 0] = tokenizer.cls_token_id
    baseline_ids[:, -1] = tokenizer.sep_token_id
    # Generalized from the original hard-coded `model.bert.embeddings`:
    # HF classification heads expose the backbone via `base_model`, which
    # covers BERT, RoBERTa, DistilBERT, etc.; fall back to `.bert` for
    # wrappers without that attribute. NOTE(review): assumes the backbone
    # exposes an `embeddings` submodule — confirm for exotic architectures.
    backbone = getattr(model, "base_model", None) or getattr(model, "bert")
    lig = LayerIntegratedGradients(forward_func, backbone.embeddings)
    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=baseline_ids,
        target=pred_class,
        return_convergence_delta=True
    )
    # Sum over the embedding dimension to get one scalar per token.
    attributions = attributions.sum(dim=-1).squeeze(0).cpu().detach().numpy()
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy())
    return tokens, attributions
def plot_token_importance(tokens: List[str], importance: np.ndarray, top_k: int = 20):
    """Horizontal bar chart of the most influential tokens.

    Special tokens ([CLS]/[SEP]/[PAD]) are dropped; the remaining tokens
    are ranked by absolute attribution and the ``top_k`` strongest are
    drawn, positive scores in green and negative in red. Draws nothing
    when no non-special tokens remain.
    """
    special = ("[CLS]", "[SEP]", "[PAD]")
    kept = [(tok, score) for tok, score in zip(tokens, importance) if tok not in special]
    if not kept:
        return
    kept_tokens, kept_scores = zip(*kept)
    order = np.argsort(np.abs(kept_scores))[-top_k:][::-1]
    top_tokens = [kept_tokens[i] for i in order]
    top_scores = [kept_scores[i] for i in order]
    plt.figure(figsize=(10, 6))
    bar_colors = ["red" if s < 0 else "green" for s in top_scores]
    plt.barh(range(len(top_scores)), top_scores, color=bar_colors)
    plt.yticks(range(len(top_scores)), top_tokens)
    plt.gca().invert_yaxis()
    plt.xlabel("Attribution Score")
    plt.title("Token Importance (Green: positive, Red: negative)")
    plt.tight_layout()
    plt.show()