| """ |
| Explainability for ScorePredictorModel. |
| |
| Given a conversation text, shows *which tokens* drive each predicted score |
| and *how much* they contribute. Two attribution methods are provided: |
| |
| integrated_gradients – gradient-based (most faithful, slower) |
| attention_rollout – attention-based (fast, good overview) |
| |
| Quick start |
| ----------- |
| from explain_score_predictor import ScorePredictorExplainer |
| |
| explainer = ScorePredictorExplainer.from_pretrained("path/to/model") |
| |
| # Get attributions from raw text |
| result = explainer.explain("User: Hello Assistant: Hi there!") |
| |
| # Print a readable summary |
| print(explainer.format(result)) |
| |
| # Save a publication-quality figure |
| explainer.plot(result, save_path="attributions.pdf") |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
| from typing import Dict, List, Literal, Optional, Tuple |
|
|
| import torch |
| import numpy as np |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class ExplainabilityOutput:
    """
    Result container produced by ``ScorePredictorExplainer.explain()``.

    Every field has an empty default, so a bare ``ExplainabilityOutput()``
    is a valid (empty) result.

    Attributes
    ----------
    text : str
        The input text exactly as it was passed in.
    tokens : List[str]
        The tokenised input as human-readable sub-word strings.
    predictions : Dict[str, float]
        Predicted value for each score dimension
        (e.g. ``{"clarity": 3.8, ...}``).
    attributions : Dict[str, List[float]]
        One attribution value per token, keyed by score dimension; each
        inner list is as long as ``tokens``.
    method : str
        Name of the attribution method that produced ``attributions``.
    """

    # Field order is part of the positional-construction interface.
    text: str = ""
    tokens: List[str] = field(default_factory=list)
    predictions: Dict[str, float] = field(default_factory=dict)
    attributions: Dict[str, List[float]] = field(default_factory=dict)
    method: str = ""
|
|
|
|
| |
| |
| |
|
|
class ScorePredictorExplainer:
    """
    Wraps a ``ScorePredictorModel`` and provides token-level explanations.

    Parameters
    ----------
    model : ScorePredictorModel
        A loaded model instance.  It is switched to eval mode; it is *not*
        moved to ``device`` by the constructor.
    tokenizer
        The matching tokenizer.
    device : str or torch.device, optional
        Device the encoded inputs are moved to.  Defaults to the device of
        the model's first parameter.
    """

    # Attribution methods accepted by ``explain``.
    _METHODS = ("integrated_gradients", "attention_rollout")

    def __init__(self, model, tokenizer, device: Optional[torch.device] = None):
        self.model = model
        self.tokenizer = tokenizer
        resolved = device or next(model.parameters()).device
        # The docstring advertises plain strings ("cuda:0"); normalise them so
        # ``self.device`` is always a ``torch.device``.
        self.device = (
            resolved if isinstance(resolved, torch.device) else torch.device(resolved)
        )
        self.score_names: List[str] = list(model.config.score_names)
        self.num_scores: int = model.num_scores
        self.model.eval()

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        device: str = "auto",
        *,
        trust_remote_code: bool = True,
    ) -> "ScorePredictorExplainer":
        """
        Load model + tokenizer from a saved checkpoint in one call.

        Parameters
        ----------
        model_path : str
            Path (or HF hub id) to the saved model directory.
        device : str
            ``"auto"`` picks GPU if available, else CPU.
        trust_remote_code : bool
            Forwarded to the ``transformers`` loaders.  Defaults to ``True``
            (the historical behaviour of this method).
            SECURITY: with ``True`` the Python code shipped inside the
            checkpoint is executed; pass ``False`` for untrusted sources.
        """
        from transformers import AutoConfig, AutoModel, AutoTokenizer

        config = AutoConfig.from_pretrained(
            model_path, trust_remote_code=trust_remote_code
        )
        model = AutoModel.from_pretrained(
            model_path, config=config, trust_remote_code=trust_remote_code
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=trust_remote_code
        )

        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        model.eval()
        return cls(model, tokenizer, torch.device(device))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def explain(
        self,
        text: str,
        *,
        method: Literal["integrated_gradients", "attention_rollout"] = "integrated_gradients",
        n_steps: int = 30,
    ) -> "ExplainabilityOutput":
        """
        Explain a single text input.

        Parameters
        ----------
        text : str
            The conversation / sentence to score and explain.
        method : str
            ``"integrated_gradients"`` (default) or ``"attention_rollout"``.
        n_steps : int
            Riemann-sum steps for integrated gradients (ignored for rollout).

        Returns
        -------
        ExplainabilityOutput

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported names, or if
            ``n_steps`` is smaller than 1 for integrated gradients.
        """
        # Fail before paying for tokenisation and a forward pass.
        if method not in self._METHODS:
            raise ValueError(
                f"Unknown method '{method}'. "
                "Choose 'integrated_gradients' or 'attention_rollout'."
            )

        enc = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=getattr(self.model.config, "max_position_embeddings", 512),
        )
        input_ids = enc["input_ids"].to(self.device)
        attention_mask = enc["attention_mask"].to(self.device)

        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

        # Plain forward pass for the reported predictions.
        with torch.no_grad():
            base_out = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )
        preds = base_out.predictions[0].cpu().tolist()
        predictions = {name: round(v, 4) for name, v in zip(self.score_names, preds)}

        if method == "integrated_gradients":
            raw_attr = self._integrated_gradients(input_ids, attention_mask, n_steps)
        else:
            raw_attr = self._attention_rollout(input_ids, attention_mask)

        # Restrict attributions to the Output/Answer section when one is found,
        # renormalising the remaining mass to sum to 1.
        output_start = _find_output_token_idx(tokens)
        if output_start is not None:
            for attr in raw_attr.values():
                row = attr[0]  # view: in-place edits below update ``attr``
                tail_mass = row[output_start:].sum()
                # If the output section carries no mass (e.g. the marker is the
                # last token) keep the full-text attribution instead of
                # returning an all-zero vector.
                if tail_mass > 0:
                    row[:output_start] = 0.0
                    row /= tail_mass

        # One bulk device->host transfer per score instead of one per token.
        attributions = {
            name: [round(v, 6) for v in attr[0].detach().float().cpu().tolist()]
            for name, attr in raw_attr.items()
        }

        return ExplainabilityOutput(
            text=text,
            tokens=tokens,
            predictions=predictions,
            attributions=attributions,
            method=method,
        )

    # ------------------------------------------------------------------
    # Attribution methods
    # ------------------------------------------------------------------

    def _integrated_gradients(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        n_steps: int,
    ) -> Dict[str, torch.Tensor]:
        """
        Integral of d(score)/d(embedding) along a straight path from a zero
        baseline to the actual input embedding.

        The per-token value is the L2 norm (over the embedding axis) of
        ``delta * mean_gradient``, masked by ``attention_mask`` and normalised
        to sum to 1 per sequence, so all values are non-negative.

        NOTE: the path is sampled at ``n_steps`` evenly spaced points that
        include both endpoints and the sum is divided by ``n_steps``; with
        ``n_steps=1`` only the gradient at the baseline is used.

        Returns Dict[score_name -> Tensor[batch, seq_len]].

        Raises
        ------
        ValueError
            If ``n_steps`` is smaller than 1.
        """
        if n_steps < 1:
            # n_steps == 0 would divide by zero and yield NaN attributions.
            raise ValueError(f"n_steps must be a positive integer, got {n_steps}.")

        input_emb = self.model.get_input_embeddings()(input_ids).detach()
        baseline_emb = torch.zeros_like(input_emb)
        delta = input_emb - baseline_emb
        alphas = torch.linspace(0.0, 1.0, n_steps, device=self.device)

        accum = {name: torch.zeros_like(input_emb) for name in self.score_names}
        last_idx = len(self.score_names) - 1

        # Gradients are required even if the caller is inside ``no_grad``.
        with torch.enable_grad():
            for alpha in alphas:
                interp = (baseline_emb + alpha * delta).requires_grad_(True)
                preds = self._forward_from_embeddings(interp, attention_mask)

                for i, name in enumerate(self.score_names):
                    (grad,) = torch.autograd.grad(
                        preds[:, i].sum(),
                        interp,
                        # The graph is reused for every score of this step and
                        # released after the last one.
                        retain_graph=(i < last_idx),
                    )
                    accum[name] += grad.detach()

        mask = attention_mask.float()
        attributions: Dict[str, torch.Tensor] = {}
        for name in self.score_names:
            ig = (delta * accum[name] / n_steps).norm(dim=-1)
            ig = ig * mask
            ig = ig / ig.sum(dim=-1, keepdim=True).clamp_min(1e-9)
            attributions[name] = ig

        return attributions

    def _attention_rollout(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Propagate attention through all layers, accounting for residual
        connections. Token importance = row 0 of the final rolled-out matrix
        (attention flowing from the first position, presumably CLS, to each
        token).

        Attention weights are not score-specific, so every score receives an
        identical (but independent) tensor.  If the model yields no
        attentions, all-zero tensors are returned.

        Returns Dict[score_name -> Tensor[batch, seq_len]].
        """
        attentions = self._get_attentions(input_ids, attention_mask)
        batch, seq_len = attention_mask.shape

        if not attentions:
            # One tensor per score: ``explain`` edits these in place, so they
            # must not alias each other.
            return {
                n: torch.zeros(batch, seq_len, device=self.device)
                for n in self.score_names
            }

        eye = torch.eye(seq_len, device=self.device).unsqueeze(0)
        rollout = eye.expand(batch, -1, -1).clone()
        mask = attention_mask.float()
        mask_2d = mask.unsqueeze(-1) * mask.unsqueeze(-2)

        for layer_attn in attentions:
            # Expect [batch, heads, seq, seq]; skip anything else.
            if layer_attn is None or layer_attn.dim() != 4:
                continue
            attn = layer_attn.mean(dim=1)  # average over heads
            attn = attn + eye  # residual connection
            attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9)
            attn = attn * mask_2d
            rollout = torch.bmm(attn, rollout)

        final = rollout[:, 0, :] * mask
        final = final / final.sum(dim=-1, keepdim=True).clamp_min(1e-9)

        return {n: final.clone() for n in self.score_names}

    # ------------------------------------------------------------------
    # Model plumbing
    # ------------------------------------------------------------------

    def _forward_from_embeddings(
        self, embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Full forward pass from pre-computed embeddings -> [B, num_scores].

        Re-implements the model head on top of ``model.backbone`` so gradients
        can flow to ``embeddings``.  Each head output is mapped to the range
        (1, 5) via ``1 + 4 * sigmoid(.)``.
        NOTE(review): this scaling duplicates what the model presumably does
        internally; keep both in sync.
        """
        backbone_out = self.model.backbone(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden = backbone_out.last_hidden_state
        pooled = self.model._pool_hidden_states(hidden, attention_mask)

        # Heads may run in a different dtype than the backbone.
        target_dtype = next(self.model.score_heads[0].parameters()).dtype
        pooled = pooled.to(target_dtype)

        shared_encoder = getattr(self.model, "shared_encoder", None)
        features = shared_encoder(pooled) if shared_encoder is not None else pooled

        return torch.cat(
            [1.0 + 4.0 * torch.sigmoid(head(features)) for head in self.model.score_heads],
            dim=-1,
        )

    def _get_attentions(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> Optional[Tuple[torch.Tensor, ...]]:
        """
        Retrieve attention weights from the model (no-grad, best effort).

        Returns ``None`` when the forward pass fails or the output carries no
        ``attentions``; the failure is logged rather than silently dropped.
        """
        import logging

        try:
            with torch.no_grad():
                out = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_attentions=True,
                    return_dict=True,
                )
            return getattr(out, "attentions", None)
        except Exception as exc:  # deliberate best-effort boundary
            logging.getLogger(__name__).warning(
                "Could not retrieve attention weights (%s: %s); "
                "attention rollout will be all zeros.",
                type(exc).__name__,
                exc,
            )
            return None

    # ------------------------------------------------------------------
    # Text / HTML rendering
    # ------------------------------------------------------------------

    def format(
        self,
        result: "ExplainabilityOutput",
        top_k: int = 10,
        score_name: Optional[str] = None,
    ) -> str:
        """
        Readable plain-text summary of the explanation.

        Shows whole words (sub-words merged) with percentage attributions.
        Special tokens ([CLS], [SEP], ...) are excluded.

        Parameters
        ----------
        result : ExplainabilityOutput
        top_k : int
            How many top words to show per score.
        score_name : str, optional
            Show only this score (default: all).  Scores missing from
            ``result.attributions`` are skipped.
        """
        sep = "-" * 44
        lines: List[str] = ["Predicted scores:"]
        lines.extend(
            f" {name:<20} {val:.4f}" for name, val in result.predictions.items()
        )
        lines.append("")

        scores_to_show = [score_name] if score_name else self.score_names
        for sn in scores_to_show:
            if sn not in result.attributions:
                continue
            words = _merge_subwords(result.tokens, result.attributions[sn])
            top = sorted(words, key=lambda p: p[1], reverse=True)[:top_k]

            lines.append(f"-- {sn} ({result.method}) --")
            lines.append(f"{'Word':<28} {'Importance':>12}")
            lines.append(sep)
            for word, pct in top:
                bar = "\u2588" * int(pct / 2)  # one block per 2 %
                lines.append(f"{word:<28} {pct:>5.1f}% {bar}")
            lines.append("")

        return "\n".join(lines)

    def to_html(
        self,
        result: "ExplainabilityOutput",
        score_name: Optional[str] = None,
    ) -> str:
        """
        HTML span-highlighted attribution view.

        Tokens are coloured white -> gold proportional to their importance
        (min-max scaled within the sequence).  Token text and score names are
        HTML-escaped so tokens such as ``<s>`` render literally.

        Parameters
        ----------
        result : ExplainabilityOutput
        score_name : str, optional
            Score to render (default: the first entry of ``score_names``).
        """
        import html

        sn = score_name or self.score_names[0]
        sn_html = html.escape(str(sn))
        if sn not in result.attributions:
            return f"<p><em>Score '{sn_html}' not found.</em></p>"

        attrs = result.attributions[sn]
        # min()/max() raise on an empty sequence.
        a_min, a_max = (min(attrs), max(attrs)) if attrs else (0.0, 0.0)
        rng = a_max - a_min if abs(a_max - a_min) > 1e-9 else 1.0

        spans: List[str] = []
        for tok, val in zip(result.tokens, attrs):
            w = max(0.0, min(1.0, (val - a_min) / rng))
            r, g, b = 255, int(255 * (1 - 0.16 * w)), int(255 * (1 - w))
            tok_disp = html.escape(_clean_token(tok))
            spans.append(
                f'<span style="background:rgb({r},{g},{b});padding:1px 3px;'
                f'border-radius:3px" title="{val:.4f}">{tok_disp}</span>'
            )

        pred_str = ""
        if sn in result.predictions:
            pred_str = f"<p><b>{sn_html}</b>: {result.predictions[sn]:.4f}</p>"

        return (
            f"<div style='font-family:monospace;line-height:2'>"
            f"{pred_str}<p>{' '.join(spans)}</p></div>"
        )

    # ------------------------------------------------------------------
    # Matplotlib figures
    # ------------------------------------------------------------------

    def plot(
        self,
        result: "ExplainabilityOutput",
        top_k: int = 15,
        score_name: Optional[str] = None,
        figsize: Optional[tuple] = None,
        save_path: Optional[str] = None,
    ):
        """
        Horizontal bar chart of the top-k most important **words**
        (sub-words merged, specials removed) per score, shown as percentages.

        Parameters
        ----------
        result : ExplainabilityOutput
        top_k : int
            Words to display per score.
        score_name : str, optional
            Single score only (default: one subplot per score).
        figsize : tuple, optional
            ``(width, height)``; default is 7 x (2.8 per score).
        save_path : str, optional
            Save figure to this path.

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        ValueError
            If there is no score to plot.
        """
        import itertools

        import matplotlib.pyplot as plt

        scores = [score_name] if score_name else list(self.score_names)
        n = len(scores)
        if n == 0:
            raise ValueError("No scores to plot.")
        colours = ["#4C72B0", "#DD8452", "#55A868", "#C44E52"]

        width = figsize[0] if figsize else 7
        height = figsize[1] if figsize else 2.8 * n
        # squeeze=False always yields a 2-D array of axes, whatever n is.
        fig, axes = plt.subplots(n, 1, figsize=(width, height), squeeze=False)

        # cycle() keeps colouring correct for any number of scores.
        for ax, sn, colour in zip(axes[:, 0], scores, itertools.cycle(colours)):
            if sn not in result.attributions:
                ax.set_visible(False)
                continue
            words = _merge_subwords(result.tokens, result.attributions[sn])
            top = sorted(words, key=lambda p: p[1], reverse=True)[:top_k]
            labels = [word for word, _ in top]
            pcts = np.array([pct for _, pct in top], dtype=float)
            peak = float(pcts[0]) if len(pcts) else 0.0

            bars = ax.barh(range(len(pcts)), pcts, color=colour,
                           edgecolor="white", linewidth=0.4, height=0.72)
            ax.set_yticks(range(len(labels)))
            ax.set_yticklabels(labels, fontsize=9)
            ax.invert_yaxis()
            ax.set_xlabel("Importance (%)")
            # Leave headroom for the value labels; avoid a zero-width axis.
            ax.set_xlim(0, peak * 1.25 if peak > 0 else 10)
            pred_val = result.predictions.get(sn, 0)
            ax.set_title(f"{sn.capitalize()} (predicted: {pred_val:.2f})",
                         fontweight="bold", fontsize=10)

            for bar, pct in zip(bars, pcts):
                ax.text(bar.get_width() + peak * 0.02,
                        bar.get_y() + bar.get_height() / 2,
                        f"{pct:.1f}%", va="center", fontsize=8, color="#333")
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

        fig.suptitle(f"Word Importance ({result.method.replace('_', ' ').title()})",
                     fontsize=12, fontweight="bold", y=1.01)
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
        return fig

    def plot_heatmap(
        self,
        result: "ExplainabilityOutput",
        top_k: int = 25,
        figsize: Optional[tuple] = None,
        save_path: Optional[str] = None,
    ):
        """
        Heatmap: scores (rows) x top-k words (columns).

        Words are ranked by their importance summed over all scores; a word
        occurring several times in the text is aggregated into one column.
        Each cell shows the relative importance of a word for a given score
        dimension, row-normalised so that each score's max = 1.

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        ValueError
            If ``result`` contains no word attributions for any known score.
        """
        import matplotlib.pyplot as plt
        from matplotlib.colors import LinearSegmentedColormap

        cmap = LinearSegmentedColormap.from_list(
            "attr", ["#FFFFFF", "#FFF7CD", "#FFD700", "#FF6B00", "#8B0000"]
        )

        # score -> {word -> importance}; repeated words are summed rather than
        # overwritten by their last occurrence.
        merged: Dict[str, Dict[str, float]] = {}
        for sn in self.score_names:
            if sn not in result.attributions:
                continue
            per_word: Dict[str, float] = {}
            for word, pct in _merge_subwords(result.tokens, result.attributions[sn]):
                per_word[word] = per_word.get(word, 0.0) + pct
            merged[sn] = per_word
        valid_names = [s for s in self.score_names if s in merged]

        # Rank words by aggregate importance across scores.
        all_words: Dict[str, float] = {}
        for word_dict in merged.values():
            for word, pct in word_dict.items():
                all_words[word] = all_words.get(word, 0) + pct
        ranked = sorted(all_words, key=all_words.get, reverse=True)[:top_k]

        if not valid_names or not ranked:
            raise ValueError("No word attributions available to plot.")

        matrix = np.array(
            [[merged[sn].get(word, 0) for word in ranked] for sn in valid_names],
            dtype=float,
        )
        row_max = matrix.max(axis=1, keepdims=True)
        row_max[row_max == 0] = 1  # all-zero rows stay zero
        matrix = matrix / row_max

        width = figsize[0] if figsize else max(10, top_k * 0.38)
        height = figsize[1] if figsize else 2.4
        fig, ax = plt.subplots(figsize=(width, height))

        im = ax.imshow(matrix, aspect="auto", cmap=cmap, vmin=0, vmax=1,
                       interpolation="nearest")
        ax.set_xticks(range(len(ranked)))
        ax.set_xticklabels(ranked, rotation=45, ha="right", fontsize=8)
        ax.set_yticks(range(len(valid_names)))
        ax.set_yticklabels([s.capitalize() for s in valid_names], fontsize=9)
        ax.set_xlabel("Word (ranked by aggregate importance)")
        cb = fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02)
        cb.set_label("Relative importance", fontsize=8)
        ax.set_title("Word Importance Across Score Dimensions",
                     fontsize=10, fontweight="bold", pad=8)
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
        return fig

    def plot_summary(
        self,
        result: "ExplainabilityOutput",
        top_k: int = 10,
        output_only: bool = True,
        figsize: tuple = (16, 14),
        save_path: Optional[str] = None,
    ):
        """
        Publication-quality composite figure.

        Layout::

            +--------------------------------------------------+
            | Title + colour legend bar                        |
            +--------------------------------------------------+
            | Task / Input context box                         |
            +----------------------+---------------------------+
            | Highlighted output   | Top-k bar chart           |  x n_scores
            +----------------------+---------------------------+

        Parameters
        ----------
        result : ExplainabilityOutput
        top_k : int
            Words per bar chart.
        output_only : bool
            If True (default), only highlight text after the last
            ``Output:`` / ``Answer:`` marker.
        figsize : tuple
            Figure size.
        save_path : str, optional
            Save path.

        Returns
        -------
        matplotlib.figure.Figure
        """
        import textwrap

        import matplotlib.colors as mcolors
        import matplotlib.gridspec as gridspec
        import matplotlib.pyplot as plt
        from matplotlib.patches import FancyBboxPatch

        colours = ["#3B6FA0", "#D07830", "#3D9050", "#BB3B3B"]
        light_bg = ["#EBF0F7", "#FDF3EB", "#EBF5EE", "#F8EBEB"]
        n_scores = len(self.score_names)

        fig = plt.figure(figsize=figsize, facecolor="white")

        outer = gridspec.GridSpec(
            n_scores + 2, 1, figure=fig,
            height_ratios=[0.15, 0.25] + [1] * n_scores,
            hspace=0.28,
        )

        # --- Title row + colour legend ---------------------------------
        ax_title = fig.add_subplot(outer[0])
        ax_title.axis("off")
        method_label = result.method.replace("_", " ").title()
        ax_title.text(
            0.5, 0.65,
            f"OmniScore Explanation \u2014 {method_label}",
            transform=ax_title.transAxes, fontsize=14, fontweight="bold",
            ha="center", va="center", color="#222",
        )

        grad = np.linspace(0, 1, 256).reshape(1, -1)
        cmap_legend = mcolors.LinearSegmentedColormap.from_list(
            "_lg", ["#F0F0F0", "#FDDC6C", "#E8792B", "#9E2320"]
        )
        ax_cbar = fig.add_axes([0.32, 0.945, 0.36, 0.012])
        ax_cbar.imshow(grad, aspect="auto", cmap=cmap_legend)
        ax_cbar.set_xticks([])
        ax_cbar.set_yticks([])
        for spine in ax_cbar.spines.values():
            spine.set_visible(False)
        fig.text(0.31, 0.950, "Low", fontsize=7.5, ha="right", color="#888")
        fig.text(0.69, 0.950, "High", fontsize=7.5, ha="left", color="#888")

        # --- Task / Input context box ----------------------------------
        ax_ctx = fig.add_subplot(outer[1])
        ax_ctx.axis("off")

        raw = result.text
        task_str, input_str = "", ""
        for line in raw.split("\n"):
            s = line.strip()
            if s.lower().startswith("task:"):
                task_str = s[5:].strip()
            elif s.lower().startswith("input:"):
                input_str = s[6:].strip()

        ctx_parts: List[str] = []
        if task_str:
            ctx_parts.append(f"Task: {task_str}")
        if input_str:
            ctx_parts.append(f"Input: {textwrap.fill(input_str, width=105)}")
        # Without Task:/Input: lines fall back to the head of the raw text.
        ctx_text = "\n".join(ctx_parts) if ctx_parts else raw[:200]

        ax_ctx.text(
            0.02, 0.85, ctx_text,
            transform=ax_ctx.transAxes, fontsize=8.5, va="top",
            fontfamily="monospace", color="#333", linespacing=1.6,
            bbox=dict(
                boxstyle="round,pad=0.6", facecolor="#FAFAFA",
                edgecolor="#D0D0D0", linewidth=0.7,
            ),
        )

        # Not every backend exposes ``canvas.get_renderer``; with ``None``
        # recent Matplotlib falls back to the figure's own renderer.
        get_renderer = getattr(fig.canvas, "get_renderer", None)
        renderer = get_renderer() if callable(get_renderer) else None

        # --- One row per score -----------------------------------------
        for idx, sn in enumerate(self.score_names):
            if sn not in result.attributions:
                continue

            colour = colours[idx % len(colours)]
            bg_colour = light_bg[idx % len(light_bg)]
            # "#RRGGBB" -> floats in [0, 1]
            base_rgb = np.array([
                int(colour[i:i + 2], 16) / 255 for i in (1, 3, 5)
            ])

            all_words = _merge_subwords(result.tokens, result.attributions[sn])
            display_words = (
                _extract_output_words(all_words) if output_only else list(all_words)
            )

            word_names = [word for word, _ in display_words]
            pcts = np.array([pct for _, pct in display_words], dtype=float)
            pmax = pcts.max() if len(pcts) and pcts.max() > 0 else 1.0
            # Clip so the fractional power below never sees a negative value.
            norms = np.clip(pcts / pmax, 0.0, 1.0)

            pred_val = result.predictions.get(sn, 0)

            inner = gridspec.GridSpecFromSubplotSpec(
                1, 2, subplot_spec=outer[idx + 2],
                width_ratios=[1.6, 1], wspace=0.22,
            )

            # Left panel: highlighted running text.
            ax_text = fig.add_subplot(inner[0])
            ax_text.axis("off")
            ax_text.set_xlim(0, 1)
            ax_text.set_ylim(0, 1)

            ax_text.add_patch(FancyBboxPatch(
                (0, 0), 1, 1, boxstyle="round,pad=0.02",
                facecolor=bg_colour, edgecolor="#ddd", linewidth=0.6,
                transform=ax_text.transAxes, clip_on=False,
            ))

            ax_text.text(
                0.02, 0.96,
                f"{sn.capitalize()} \u2014 predicted {pred_val:.2f} / 5",
                transform=ax_text.transAxes, fontsize=10,
                fontweight="bold", va="top", color=colour,
            )

            # Manual word flow: place each word, measure it, wrap when the
            # right margin is exceeded, stop when the panel is full.
            x, y = 0.02, 0.84
            line_h = 0.085
            gap = 0.005
            to_axes = ax_text.transAxes.inverted()

            for word, nv in zip(word_names, norms):
                # Sub-linear scaling keeps low-importance words visible.
                intensity = nv ** 0.55
                bg = tuple(1.0 + (base_rgb[c] - 1.0) * intensity for c in range(3))
                txt_col = "#222" if intensity < 0.7 else "#fff"
                edge = colour if intensity > 0.35 else "none"

                t = ax_text.text(
                    x, y, f" {word} ",
                    transform=ax_text.transAxes, fontsize=9,
                    va="top", fontfamily="sans-serif", color=txt_col,
                    bbox=dict(
                        boxstyle="round,pad=0.18", facecolor=bg,
                        edgecolor=edge, linewidth=0.6 if edge != "none" else 0,
                    ),
                )
                bb_ax = t.get_window_extent(renderer=renderer).transformed(to_axes)
                x += bb_ax.width + gap

                if x > 0.97:
                    # The word ran past the margin: move it to the next line.
                    x = 0.02
                    y -= line_h
                    if y < 0.0:
                        # Panel is full; drop the word that no longer fits.
                        t.remove()
                        break
                    t.set_position((x, y))
                    bb_ax = t.get_window_extent(renderer=renderer).transformed(to_axes)
                    x = 0.02 + bb_ax.width + gap

            # Right panel: top-k bar chart.
            ax_bar = fig.add_subplot(inner[1])
            top_words = sorted(display_words, key=lambda p: p[1], reverse=True)[:top_k]
            bar_labels = [word for word, _ in top_words]
            bar_pcts = np.array([pct for _, pct in top_words], dtype=float)
            bar_norms = bar_pcts / pmax
            bar_peak = float(bar_pcts[0]) if len(bar_pcts) else 0.0

            # Floor the tint at 0.15 so the weakest bars remain visible.
            bar_cols = [
                tuple(1.0 + (base_rgb[c] - 1.0) * max(norm, 0.15) for c in range(3))
                for norm in bar_norms
            ]

            bars = ax_bar.barh(
                range(len(bar_pcts)), bar_pcts, color=bar_cols,
                edgecolor="white", linewidth=0.6, height=0.72,
            )
            ax_bar.set_yticks(range(len(bar_labels)))
            ax_bar.set_yticklabels(bar_labels, fontsize=8.5,
                                   fontfamily="sans-serif")
            ax_bar.invert_yaxis()
            ax_bar.set_xlabel("Importance (%)", fontsize=8)
            ax_bar.set_xlim(0, bar_peak * 1.32 if bar_peak > 0 else 10)
            ax_bar.set_title(
                f"Top-{top_k} words", fontsize=9, color="#555", pad=6,
            )
            for bar, pct in zip(bars, bar_pcts):
                ax_bar.text(
                    bar.get_width() + bar_peak * 0.015,
                    bar.get_y() + bar.get_height() / 2,
                    f"{pct:.1f}%", va="center", fontsize=7.5, color="#444",
                )
            ax_bar.spines["top"].set_visible(False)
            ax_bar.spines["right"].set_visible(False)
            ax_bar.tick_params(axis="y", length=0)

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight",
                        facecolor="white")
        return fig
|
|
|
|
| |
| |
| |
|
|
| |
| _SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]", |
| "<s>", "</s>", "<pad>", "<unk>", "<mask>"} |
|
|
| |
| _OUTPUT_MARKERS = {"Output", "Answer", "Response", "output", "answer", "response"} |
|
|
|
|
| def _find_output_token_idx(tokens: List[str]) -> Optional[int]: |
| """ |
| Find the token index where the Output/Answer section begins. |
| |
| Scans for the *last* occurrence of a known output marker token |
| (e.g. "Output", "▁Output", "output") and returns the index of the |
| first content token *after* the marker (skipping ":" if present). |
| |
| Returns ``None`` if no marker is found. |
| """ |
| last_marker = -1 |
| for i, tok in enumerate(tokens): |
| clean = tok.replace("\u2581", "").replace("##", "").strip(":").strip() |
| if clean in _OUTPUT_MARKERS: |
| last_marker = i |
|
|
| if last_marker == -1: |
| return None |
|
|
| |
| start = last_marker + 1 |
| if start < len(tokens): |
| next_clean = tokens[start].replace("\u2581", "").replace("##", "").strip() |
| if next_clean == ":": |
| start += 1 |
|
|
| return start |
|
|
|
|
| def _clean_token(tok: str) -> str: |
| """Strip SentencePiece / WordPiece artefacts for display.""" |
| return ( |
| tok.replace("\u2581", " ") |
| .replace("##", "") |
| .strip() |
| or tok |
| ) |
|
|
|
|
| def _extract_output_words( |
| words: List[Tuple[str, float]], |
| ) -> List[Tuple[str, float]]: |
| """ |
| Return only the words that belong to the Output / Answer section. |
| |
| Scans the word list for the *last* occurrence of a known output marker |
| (e.g. "Output", "Answer") and returns everything after it (excluding |
| the marker word itself and any colon that follows). |
| |
| If no marker is found the full list is returned unchanged. |
| """ |
| last_marker = -1 |
| for i, (w, _) in enumerate(words): |
| clean = w.strip(":").strip() |
| if clean in _OUTPUT_MARKERS: |
| last_marker = i |
|
|
| if last_marker == -1: |
| return words |
|
|
| |
| start = last_marker + 1 |
| if start < len(words) and words[start][0].strip() == ":": |
| start += 1 |
|
|
| result = words[start:] |
| |
| total = sum(p for _, p in result) if result else 1.0 |
| return [(w, p / total * 100.0) for w, p in result] |
|
|
|
|
| |
| |
| _PUNCT_GLUE = set('.,;:!?)]\'\"') |
| _PUNCT_OPEN = set('([\"\'') |
|
|
|
|
| def _merge_subwords( |
| tokens: List[str], |
| attributions: List[float], |
| ) -> List[Tuple[str, float]]: |
| """ |
| Merge sub-word tokens back into whole words and sum their attributions. |
| |
| - WordPiece continuations (``##xyz``) are joined to the preceding word. |
| - SentencePiece tokens starting with ``\u2581`` begin a new word. |
| - Standalone punctuation (``.``, ``,``, ``)``, …) is glued to the |
| preceding word so bar-chart labels stay clean. |
| - Opening brackets/quotes are glued to the *following* word. |
| - Special tokens ([CLS], [SEP], …) are dropped. |
| |
| Returns a list of ``(word, importance_percent)`` sorted by position. |
| Percentages sum to ~100 (before any top-k truncation). |
| """ |
| words: List[str] = [] |
| word_scores: List[float] = [] |
|
|
| for tok, attr in zip(tokens, attributions): |
| if tok in _SPECIAL_TOKENS: |
| continue |
|
|
| |
| if tok.startswith("##"): |
| if words: |
| words[-1] += tok[2:] |
| word_scores[-1] += attr |
| continue |
|
|
| |
| clean = tok.replace("\u2581", "") |
| if not clean: |
| continue |
|
|
| |
| if clean in _PUNCT_GLUE and words: |
| words[-1] += clean |
| word_scores[-1] += attr |
| continue |
|
|
| |
| if clean in _PUNCT_OPEN: |
| words.append(clean) |
| word_scores.append(attr) |
| continue |
|
|
| is_new_word = tok.startswith("\u2581") or not words |
|
|
| if is_new_word or not words: |
| |
| if words and words[-1] in _PUNCT_OPEN: |
| words[-1] += clean |
| word_scores[-1] += attr |
| else: |
| words.append(clean) |
| word_scores.append(attr) |
| else: |
| |
| words[-1] += clean |
| word_scores[-1] += attr |
|
|
| |
| total = sum(word_scores) if word_scores else 1.0 |
| return [(w, s / total * 100.0) for w, s in zip(words, word_scores)] |
|
|