| | """Model interpretability and visualization tools.""" |
| |
|
| | import numpy as np |
| | import matplotlib |
| | matplotlib.use('Agg') |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | from typing import List, Dict, Any, Optional, Tuple |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| | import warnings |
| |
|
| | |
# Optional dependency: SHAP-based explanations are disabled gracefully when
# the package is absent; attention visualization still works without it.
try:
    import shap
    SHAP_AVAILABLE = True
except ImportError:
    SHAP_AVAILABLE = False
    warnings.warn("SHAP not installed. Install with: pip install shap")
| |
|
| |
|
class AttentionVisualizer:
    """Visualize attention weights from transformer models."""

    def __init__(self, model, tokenizer):
        """
        Initialize attention visualizer.

        Args:
            model: Transformer model (must accept ``output_attentions=True``)
            tokenizer: Corresponding tokenizer
        """
        self.model = model
        self.tokenizer = tokenizer
        # Send inputs to whatever device the model's parameters live on.
        self.device = next(model.parameters()).device

    def get_attention_weights(self, text: str) -> Dict[str, Any]:
        """
        Run the model on ``text`` and collect attention maps for every layer.

        Args:
            text: Input text (truncated to 512 tokens).

        Returns:
            Dict with keys:
                tokens: wordpiece tokens of the (single) input sequence
                attention_weights: list with one numpy array per layer,
                    each shaped (batch, heads, seq, seq)
                input_ids: numpy array of token ids, shape (batch, seq)
                predictions: softmax class probabilities, shape (batch, classes)
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs, output_attentions=True)
            attentions = outputs.attentions

        attention_weights = [att.cpu().numpy() for att in attentions]
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

        return {
            "tokens": tokens,
            "attention_weights": attention_weights,
            "input_ids": inputs["input_ids"].cpu().numpy(),
            "predictions": torch.softmax(outputs.logits, dim=-1).cpu().numpy()
        }

    def plot_attention_heatmap(
        self,
        text: str,
        layer: int = -1,
        head: int = 0,
        save_path: Optional[str] = None
    ):
        """
        Plot attention heatmap for a specific layer and head.

        Args:
            text: Input text
            layer: Layer index (-1 for last layer)
            head: Attention head index
            save_path: Path to save the plot
        """
        attention_data = self.get_attention_weights(text)
        tokens = attention_data["tokens"]
        attention_weights = attention_data["attention_weights"]

        # (batch, heads, seq, seq) -> (seq, seq) for the requested layer/head.
        layer_attention = attention_weights[layer][0, head]

        plt.figure(figsize=(12, 10))

        # Strip the '##' wordpiece-continuation prefix; everything else
        # (including specials like [CLS]/[SEP]/[PAD]) is shown unchanged.
        # The original if/elif/else had two identical branches.
        token_labels = [tok[2:] if tok.startswith('##') else tok for tok in tokens]

        # Cap the matrix size so tick labels stay legible.
        max_tokens = 50
        if len(token_labels) > max_tokens:
            layer_attention = layer_attention[:max_tokens, :max_tokens]
            token_labels = token_labels[:max_tokens]

        sns.heatmap(
            layer_attention,
            xticklabels=token_labels,
            yticklabels=token_labels,
            cmap='Blues',
            cbar=True,
            square=True
        )

        plt.title(f'Attention Weights - Layer {layer}, Head {head}')
        plt.xlabel('Key Tokens')
        plt.ylabel('Query Tokens')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()
        # Fix: with the forced Agg backend show() is a no-op, so without an
        # explicit close() every call leaked a figure.
        plt.close()

    def plot_attention_summary(
        self,
        text: str,
        save_path: Optional[str] = None
    ):
        """
        Plot attention summary across all layers and heads.

        Args:
            text: Input text
            save_path: Path to save the plot
        """
        attention_data = self.get_attention_weights(text)
        attention_weights = attention_data["attention_weights"]
        tokens = attention_data["tokens"]

        num_layers = len(attention_weights)
        num_heads = attention_weights[0].shape[1]

        # Mean attention weight per layer (averaged over heads and positions).
        layer_avg_attention = [np.mean(layer_att[0]) for layer_att in attention_weights]

        # Per-head variance of attention weights, tracked across layers.
        head_attention_variance = [
            [np.var(layer_att[0, head]) for layer_att in attention_weights]
            for head in range(num_heads)
        ]

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

        # Panel 1: average attention per layer.
        ax1.plot(range(num_layers), layer_avg_attention, marker='o')
        ax1.set_title('Average Attention Weight per Layer')
        ax1.set_xlabel('Layer')
        ax1.set_ylabel('Average Attention')
        ax1.grid(True)

        # Panel 2: variance per head (at most 8 heads so the legend stays readable).
        for head in range(min(num_heads, 8)):
            ax2.plot(range(num_layers), head_attention_variance[head],
                     marker='o', label=f'Head {head}')
        ax2.set_title('Attention Variance per Head Across Layers')
        ax2.set_xlabel('Layer')
        ax2.set_ylabel('Attention Variance')
        ax2.legend()
        ax2.grid(True)

        # Panel 3: last layer, head 0, truncated for readability.
        last_layer_att = attention_weights[-1][0, 0]
        max_tokens = 20
        if len(tokens) > max_tokens:
            last_layer_att = last_layer_att[:max_tokens, :max_tokens]
            display_tokens = tokens[:max_tokens]
        else:
            display_tokens = tokens

        ax3.imshow(last_layer_att, cmap='Blues')
        ax3.set_title('Last Layer Attention (Head 0)')
        ax3.set_xticks(range(len(display_tokens)))
        ax3.set_yticks(range(len(display_tokens)))
        ax3.set_xticklabels(display_tokens, rotation=45, ha='right')
        ax3.set_yticklabels(display_tokens)

        # Panel 4: column sums = total attention each token receives as a key.
        token_attention_sum = np.sum(last_layer_att, axis=0)
        ax4.bar(range(len(display_tokens)), token_attention_sum)
        ax4.set_title('Total Attention Received per Token')
        ax4.set_xlabel('Token')
        ax4.set_ylabel('Total Attention')
        ax4.set_xticks(range(len(display_tokens)))
        ax4.set_xticklabels(display_tokens, rotation=45, ha='right')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.show()
        # Fix: release the figure to avoid accumulating open figures under Agg.
        plt.close(fig)
| |
|
| |
|
class SHAPExplainer:
    """SHAP-based explainability for transformer models."""

    def __init__(self, model, tokenizer):
        """
        Set up a SHAP explainer around a sequence-classification model.

        Args:
            model: Transformer model
            tokenizer: Corresponding tokenizer

        Raises:
            ImportError: If the optional ``shap`` package is not installed.
        """
        if not SHAP_AVAILABLE:
            raise ImportError("SHAP is required for this functionality. Install with: pip install shap")

        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

        # SHAP perturbs the text through the tokenizer and scores each
        # perturbation with _predict_fn.
        self.explainer = shap.Explainer(self._predict_fn, self.tokenizer)

    def _predict_fn(self, texts):
        """Prediction function for SHAP: one probability row per input text."""
        rows = []
        for sample in texts:
            encoded = self.tokenizer(
                sample,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            )
            encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

            with torch.no_grad():
                logits = self.model(**encoded).logits
                probabilities = torch.softmax(logits, dim=-1)

            rows.append(probabilities.cpu().numpy()[0])

        return np.array(rows)

    def explain_text(self, text: str, max_evals: int = 100):
        """
        Generate SHAP explanations for a text.

        Args:
            text: Input text to explain
            max_evals: Maximum number of evaluations for SHAP

        Returns:
            SHAP explanation object
        """
        return self.explainer([text], max_evals=max_evals)

    def plot_shap_explanation(
        self,
        text: str,
        class_index: int = 1,
        max_evals: int = 100,
        save_path: Optional[str] = None
    ):
        """
        Plot SHAP explanation for a specific class.

        Args:
            text: Input text
            class_index: Class index to explain
            max_evals: Maximum evaluations for SHAP
            save_path: Path to save the plot
        """
        explanation = self.explain_text(text, max_evals=max_evals)

        # Token-level text plot for the requested output class.
        shap.plots.text(explanation[0, :, class_index])

        if save_path:
            # NOTE(review): shap.plots.text renders HTML/JS rather than drawing
            # on the current matplotlib figure, so this savefig may produce an
            # empty image — confirm; kept as-is to preserve behavior.
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
| |
|
| |
|
class InterpretabilityPipeline:
    """Complete interpretability pipeline combining multiple methods."""

    def __init__(self, model_path: str):
        """
        Initialize interpretability pipeline.

        Args:
            model_path: Path to trained model (local directory or hub id).
        """
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model.eval()

        self.attention_viz = AttentionVisualizer(self.model, self.tokenizer)

        # SHAP is optional; attention analysis still works without it.
        if SHAP_AVAILABLE:
            self.shap_explainer = SHAPExplainer(self.model, self.tokenizer)
        else:
            self.shap_explainer = None
            print("Warning: SHAP not available. Install with: pip install shap")

    def full_analysis(
        self,
        text: str,
        output_dir: str = "./interpretability_output"
    ):
        """
        Perform full interpretability analysis.

        Runs the model prediction, writes attention plots (and a SHAP plot when
        available) into ``output_dir``, and saves a JSON summary report.

        Args:
            text: Text to analyze
            output_dir: Directory to save outputs

        Returns:
            The report dict that was written to ``analysis_report.json``.
        """
        # Fix: hoist both imports to the top of the function — the original
        # re-imported json inside the file-writing `with` block.
        import json
        import os

        os.makedirs(output_dir, exist_ok=True)

        # Fix: the original print strings contained mis-encoded multi-byte
        # characters (mojibake); replaced with plain ASCII messages.
        print(f"Analyzing text: {text[:100]}...")

        # Base prediction: class index plus softmax confidence.
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = torch.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(predictions, dim=-1).item()
            confidence = predictions[0, predicted_class].item()

        print(f"Prediction: Class {predicted_class}, Confidence: {confidence:.3f}")

        print("Generating attention visualizations...")
        self.attention_viz.plot_attention_summary(
            text,
            save_path=os.path.join(output_dir, "attention_summary.png")
        )

        self.attention_viz.plot_attention_heatmap(
            text,
            layer=-1,
            head=0,
            save_path=os.path.join(output_dir, "attention_heatmap.png")
        )

        if self.shap_explainer:
            print("Generating SHAP explanations...")
            # Best-effort: SHAP can fail on unusual inputs; the rest of the
            # analysis is still useful, so log the failure and continue.
            try:
                self.shap_explainer.plot_shap_explanation(
                    text,
                    class_index=predicted_class,
                    save_path=os.path.join(output_dir, "shap_explanation.png")
                )
            except Exception as e:
                print(f"SHAP explanation failed: {e}")

        report = {
            "text": text,
            "predicted_class": int(predicted_class),
            "confidence": float(confidence),
            "model_path": self.model.config._name_or_path,
            "analysis_files": {
                "attention_summary": "attention_summary.png",
                "attention_heatmap": "attention_heatmap.png",
                "shap_explanation": "shap_explanation.png" if self.shap_explainer else None
            }
        }

        report_path = os.path.join(output_dir, "analysis_report.json")
        with open(report_path, "w") as f:
            json.dump(report, f, indent=2)

        # Fix: the original final print was corrupted — a mis-encoded character
        # split the f-string across two physical lines; reconstructed here.
        print(f"Analysis complete! Results saved to: {output_dir}")
        return report
| |
|
| |
|
def main():
    """CLI for interpretability analysis."""
    import argparse

    cli = argparse.ArgumentParser(description="Model interpretability analysis")
    # Table-driven flag registration keeps each option on one line.
    for flag, opts in (
        ("--model", dict(type=str, required=True, help="Path to model")),
        ("--text", dict(type=str, required=True, help="Text to analyze")),
        ("--output", dict(type=str, default="./interpretability_output", help="Output directory")),
        ("--attention-only", dict(action="store_true", help="Only run attention analysis")),
    ):
        cli.add_argument(flag, **opts)

    parsed = cli.parse_args()

    pipeline = InterpretabilityPipeline(parsed.model)

    if parsed.attention_only:
        pipeline.attention_viz.plot_attention_summary(parsed.text)
        return
    pipeline.full_analysis(parsed.text, parsed.output)
| |
|
| |
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()