""" explainability.py — SHAP-based token-level and feature-level explainability. Two explainability layers are provided: 1. Classifier SHAP — KernelExplainer on the clause classifier. Shows which tokens drove each clause type prediction. Output: token-level SHAP bar chart (PNG). 2. Power Imbalance SHAP — KernelExplainer on the power scorer features. Shows which features (sentiment, modal verbs, obligations, assertiveness) drove the bilateral power imbalance score. Output: feature-importance SHAP bar chart (PNG). SHAP KernelExplainer is model-agnostic and works across both models without modification. Token attribution maps directly to legal clause words, making this approach legally interpretable. Usage: from src.explainability import ExplainabilityEngine engine = ExplainabilityEngine() png_path = engine.explain_clause(clause_text, clause_id) """ import sys from pathlib import Path from typing import Dict, List, Optional, Tuple import matplotlib matplotlib.use("Agg") # headless backend — must be set before pyplot import import matplotlib.pyplot as plt import numpy as np import shap import torch from loguru import logger sys.path.insert(0, str(Path(__file__).parent.parent)) import config from src.clause_classifier import ClauseClassifierInference from src.power_scorer import PowerImbalanceScorer logger.remove() logger.add(config.LOGS_DIR / "explainability.log", rotation="10 MB", level="DEBUG") logger.add(sys.stderr, level="INFO") # 1. CLASSIFIER SHAP EXPLAINER class ClassifierSHAPExplainer: """Generates token-level SHAP explanations for clause type predictions. Uses SHAP KernelExplainer with a bag-of-words input representation. The model wrapper converts word-masked inputs back to text before calling the classifier, enabling true token-level attribution. Note: KernelExplainer is slow (~30s per clause). For production, consider caching explanations after first computation. """ def __init__(self, classifier: Optional[ClauseClassifierInference] = None): self.classifier = classifier or ClauseClassifierInference() def _build_word_mask_fn(self, text: str, target_label_idx: int): """Build a SHAP-compatible prediction function using word masking. SHAP KernelExplainer requires a function f(X) → R^n where X is a binary mask matrix over input features (words here). We mask words out of the original text and run the classifier on the masked version. Args: text: Original clause text. target_label_idx: Index of the target clause type in CUAD_CLAUSE_TYPES. Returns: Tuple of (predict_fn, words) where predict_fn maps masks → probabilities. """ words = text.split() def predict_fn(mask_matrix: np.ndarray) -> np.ndarray: """Classify masked clause texts and return target class probability. Args: mask_matrix: Binary matrix of shape (n_samples, n_words). 1 = keep word, 0 = mask to [MASK]. Returns: 1D array of target class probabilities, shape (n_samples,). """ texts_batch = [] for mask_row in mask_matrix: masked_words = [ w if mask_row[i] == 1 else "[MASK]" for i, w in enumerate(words) ] texts_batch.append(" ".join(masked_words)) preds = self.classifier.predict(texts_batch, threshold=0.0) probs = np.array([ p["probabilities"].get(config.CUAD_CLAUSE_TYPES[target_label_idx], 0.0) for p in preds ]) return probs return predict_fn, words def explain( self, clause_text: str, target_clause_type: str, n_background: int = config.SHAP_BACKGROUND_SAMPLES, max_evals: int = config.SHAP_MAX_EVALS, ) -> Tuple[np.ndarray, List[str]]: """Compute token-level SHAP values for a clause. Args: clause_text: Raw clause text to explain. target_clause_type: Name of the CUAD clause type to explain. n_background: Number of background samples for KernelExplainer. max_evals: Maximum model evaluations (controls accuracy vs speed). Returns: Tuple of (shap_values, words) where shap_values is a 1D array of per-word attribution scores aligned with the words list. Raises: ValueError: If target_clause_type is not in CUAD_CLAUSE_TYPES. """ if target_clause_type not in config.CUAD_CLAUSE_TYPES: raise ValueError( f"Unknown clause type: {target_clause_type}. " f"Must be one of {config.CUAD_CLAUSE_TYPES}" ) target_idx = config.CUAD_CLAUSE_TYPES.index(target_clause_type) predict_fn, words = self._build_word_mask_fn(clause_text, target_idx) n_words = len(words) if n_words == 0: return np.array([]), [] # Background data: random binary masks (represent "average" input) rng = np.random.RandomState(config.RANDOM_SEED) background = rng.randint(0, 2, size=(min(n_background, 50), n_words)).astype(float) # Single instance to explain: all words present instance = np.ones((1, n_words)) explainer = shap.KernelExplainer(predict_fn, background) shap_values = explainer.shap_values( instance, nsamples=min(max_evals, 200), silent=True, ) return shap_values[0], words def plot_and_save( self, shap_values: np.ndarray, words: List[str], clause_id: str, clause_type: str, top_n: int = 20, ) -> Path: """Generate and save a token-level SHAP bar chart as PNG. Args: shap_values: SHAP attribution values (1D array, len = n_words). words: Word list aligned with shap_values. clause_id: Unique clause identifier (used in filename). clause_type: Clause type label for the plot title. top_n: Number of top words to display. Returns: Path to the saved PNG file. """ if len(shap_values) == 0 or len(words) == 0: logger.warning(f"Empty SHAP values for clause {clause_id}. Skipping plot.") return None # Select top N words by absolute SHAP value top_indices = np.argsort(np.abs(shap_values))[-top_n:][::-1] top_words = [words[i] for i in top_indices] top_vals = shap_values[top_indices] colors = ["#C0392B" if v > 0 else "#2980B9" for v in top_vals] fig, ax = plt.subplots(figsize=(10, 6), facecolor="#0D1B2A") ax.set_facecolor("#0D1B2A") bars = ax.barh(range(len(top_words)), top_vals, color=colors, edgecolor="none") ax.set_yticks(range(len(top_words))) ax.set_yticklabels(top_words, fontsize=10, color="#F0E68C") ax.set_xlabel("SHAP Value (Token Attribution)", color="#F0E68C", fontsize=11) ax.set_title( f"Token-Level SHAP: {clause_type}\n(Clause ID: {clause_id[:8]}...)", color="#F0E68C", fontsize=13, pad=12, ) ax.tick_params(colors="#F0E68C") ax.spines["bottom"].set_color("#F0E68C") ax.spines["left"].set_color("#F0E68C") ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.axvline(0, color="#F0E68C", linewidth=0.8, alpha=0.5) # Add legend from matplotlib.patches import Patch legend = [ Patch(color="#C0392B", label="Pushes toward Party A"), Patch(color="#2980B9", label="Pushes toward Party B"), ] ax.legend(handles=legend, loc="lower right", facecolor="#0D1B2A", labelcolor="#F0E68C", edgecolor="#F0E68C") plt.tight_layout() output_path = config.SHAP_OUTPUT_DIR / f"shap_classifier_{clause_id}.png" plt.savefig(str(output_path), dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor()) plt.close(fig) logger.info(f"SHAP plot saved: {output_path}") return output_path # 2. POWER IMBALANCE SHAP EXPLAINER class PowerImbalanceSHAPExplainer: """Generates feature-level SHAP explanations for power imbalance scores. Features: sentiment, modal_verbs, obligations, assertiveness, length. Shows which features drove the clause's power imbalance toward Party A or Party B. """ FEATURE_NAMES = [ "sentiment_score", "modal_score", "obligation_score", "assertiveness_score", "length_score", ] def __init__(self, power_scorer: Optional[PowerImbalanceScorer] = None): self.power_scorer = power_scorer or PowerImbalanceScorer() def _build_predict_fn(self, base_text: str): """Build a prediction function that maps feature values → imbalance score. Since we want to explain at the feature level (not token level), we perturb individual feature values and observe the imbalance change. Args: base_text: The clause text being explained. Returns: Tuple of (predict_fn, background, base_features). """ # Get base feature values scores = self.power_scorer.score([base_text]) base = scores[0] base_features = np.array([ base["sentiment_score"], base["modal_score"], base["obligation_score"], base["assertiveness_score"], base["length_score"], ]) def predict_fn(feature_matrix: np.ndarray) -> np.ndarray: """Map perturbed feature vectors to predicted imbalance scores. Args: feature_matrix: Shape (n_samples, n_features). Returns: 1D imbalance scores, shape (n_samples,). """ results = [] for feature_row in feature_matrix: s, m, o, a, ln = feature_row party_a_raw = ( config.POWER_WEIGHT_SENTIMENT * (1.0 - s) + config.POWER_WEIGHT_MODAL_VERBS * m + config.POWER_WEIGHT_OBLIGATIONS * o + config.POWER_WEIGHT_ASSERTIVENESS * a ) party_b_raw = ( config.POWER_WEIGHT_SENTIMENT * s + config.POWER_WEIGHT_MODAL_VERBS * (1.0 - m) + config.POWER_WEIGHT_OBLIGATIONS * (1.0 - o) + config.POWER_WEIGHT_ASSERTIVENESS * (1.0 - a) ) amplifier = 0.8 + 0.4 * ln party_a = float(np.clip(party_a_raw * amplifier * 100, 0, 100)) party_b = float(np.clip(party_b_raw * amplifier * 100, 0, 100)) results.append(party_a - party_b) return np.array(results) # Background: all-0.5 (neutral feature values) background = np.full((1, 5), 0.5) return predict_fn, background, base_features def explain(self, clause_text: str) -> Tuple[np.ndarray, np.ndarray]: """Compute feature-level SHAP values for power imbalance. Args: clause_text: Raw clause text. Returns: Tuple of (shap_values, base_features) both as 1D arrays. """ predict_fn, background, base_features = self._build_predict_fn(clause_text) explainer = shap.KernelExplainer(predict_fn, background) shap_values = explainer.shap_values( base_features.reshape(1, -1), nsamples=100, silent=True, ) return shap_values[0], base_features def plot_and_save( self, shap_values: np.ndarray, base_features: np.ndarray, clause_id: str, ) -> Path: """Generate and save a feature-level SHAP plot as PNG. Args: shap_values: SHAP values for each feature (1D, len=5). base_features: Actual feature values (1D, len=5). clause_id: Clause identifier for filename. Returns: Path to the saved PNG file. """ feature_labels = [ f"Sentiment\n({base_features[0]:.2f})", f"Modal Verbs\n({base_features[1]:.2f})", f"Obligations\n({base_features[2]:.2f})", f"Assertiveness\n({base_features[3]:.2f})", f"Length\n({base_features[4]:.2f})", ] colors = ["#C0392B" if v > 0 else "#2980B9" for v in shap_values] fig, ax = plt.subplots(figsize=(9, 5), facecolor="#0D1B2A") ax.set_facecolor("#0D1B2A") ax.barh(feature_labels, shap_values, color=colors, edgecolor="none") ax.set_xlabel("SHAP Value (→ Party A | ← Party B)", color="#F0E68C", fontsize=11) ax.set_title( f"Feature Contributions to Power Imbalance\n(Clause: {clause_id[:8]}...)", color="#F0E68C", fontsize=13, pad=12, ) ax.tick_params(colors="#F0E68C") ax.spines["bottom"].set_color("#F0E68C") ax.spines["left"].set_color("#F0E68C") ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.axvline(0, color="#F0E68C", linewidth=0.8, alpha=0.5) ax.set_yticklabels(feature_labels, color="#F0E68C", fontsize=10) from matplotlib.patches import Patch legend = [ Patch(color="#C0392B", label="Favours Party A"), Patch(color="#2980B9", label="Favours Party B"), ] ax.legend(handles=legend, facecolor="#0D1B2A", labelcolor="#F0E68C", edgecolor="#F0E68C", loc="lower right") plt.tight_layout() output_path = config.SHAP_OUTPUT_DIR / f"shap_power_{clause_id}.png" plt.savefig(str(output_path), dpi=150, bbox_inches="tight", facecolor=fig.get_facecolor()) plt.close(fig) logger.info(f"Power SHAP plot saved: {output_path}") return output_path # 3. UNIFIED EXPLAINABILITY ENGINE class ExplainabilityEngine: """Unified interface for generating both classifier and power SHAP explanations. Lazily initialises sub-explainers to avoid loading heavy models unnecessarily. """ def __init__(self): self._classifier_explainer: Optional[ClassifierSHAPExplainer] = None self._power_explainer: Optional[PowerImbalanceSHAPExplainer] = None @property def classifier_explainer(self) -> ClassifierSHAPExplainer: """Lazy-load classifier SHAP explainer.""" if self._classifier_explainer is None: self._classifier_explainer = ClassifierSHAPExplainer() return self._classifier_explainer @property def power_explainer(self) -> PowerImbalanceSHAPExplainer: """Lazy-load power imbalance SHAP explainer.""" if self._power_explainer is None: self._power_explainer = PowerImbalanceSHAPExplainer() return self._power_explainer def explain_clause( self, clause_text: str, clause_id: str, clause_type: Optional[str] = None, ) -> Dict: """Generate both classifier and power SHAP explanations for a clause. Args: clause_text: Raw clause text. clause_id: Unique clause identifier. clause_type: Target clause type for classifier explanation. If None, uses the highest-probability predicted type. Returns: Dict with: 'classifier_shap_path': Path to classifier SHAP PNG (or None) 'power_shap_path': Path to power SHAP PNG 'classifier_shap_values': list of (word, shap_value) pairs 'power_shap_values': dict of feature → shap_value """ result: Dict = { "classifier_shap_path": None, "power_shap_path": None, "classifier_shap_values": [], "power_shap_values": {}, } # --- Classifier SHAP --- try: if clause_type is None: # Predict and use top clause type pred = self.classifier_explainer.classifier.predict_single(clause_text) if pred["clause_types"]: clause_type = pred["clause_types"][0] if clause_type: shap_vals, words = self.classifier_explainer.explain( clause_text, clause_type ) png_path = self.classifier_explainer.plot_and_save( shap_vals, words, clause_id, clause_type ) result["classifier_shap_path"] = png_path.as_posix() if png_path else None result["classifier_shap_values"] = [ {"word": w, "shap_value": float(v)} for w, v in zip(words, shap_vals) ] except Exception as exc: logger.warning(f"Classifier SHAP failed for {clause_id}: {exc}") # --- Power Imbalance SHAP --- try: power_vals, base_feats = self.power_explainer.explain(clause_text) png_path = self.power_explainer.plot_and_save(power_vals, base_feats, clause_id) result["power_shap_path"] = png_path.as_posix() if png_path else None result["power_shap_values"] = { name: float(val) for name, val in zip( PowerImbalanceSHAPExplainer.FEATURE_NAMES, power_vals ) } except Exception as exc: logger.warning(f"Power SHAP failed for {clause_id}: {exc}") return result def explain_contract( self, contract_id: str, max_clauses: int = 10 ) -> List[Dict]: """Generate SHAP explanations for up to max_clauses clauses in a contract. Limited to max_clauses to keep generation time reasonable. Args: contract_id: Contract identifier. max_clauses: Maximum number of clauses to explain. Returns: List of explanation dicts (one per explained clause). """ from api.database import Clause, SessionLocal, create_tables create_tables() with SessionLocal() as session: clauses = ( session.query(Clause) .filter(Clause.contract_id == contract_id) .limit(max_clauses) .all() ) explanations = [] for clause in clauses: logger.info(f"Explaining clause {clause.clause_id}...") exp = self.explain_clause( clause_text=clause.clause_text, clause_id=clause.clause_id, clause_type=clause.clause_type.split("|")[0] if clause.clause_type else None, ) # Persist SHAP plot paths to database with SessionLocal() as session: db_clause = session.get(Clause, clause.clause_id) if db_clause and exp.get("classifier_shap_path"): db_clause.shap_plot_path = exp["classifier_shap_path"] session.commit() exp["clause_id"] = clause.clause_id explanations.append(exp) return explanations