""" Explainability for ScorePredictorModel. Given a conversation text, shows *which tokens* drive each predicted score and *how much* they contribute. Two attribution methods are provided: integrated_gradients – gradient-based (most faithful, slower) attention_rollout – attention-based (fast, good overview) Quick start ----------- from explain_score_predictor import ScorePredictorExplainer explainer = ScorePredictorExplainer.from_pretrained("path/to/model") # Get attributions from raw text result = explainer.explain("User: Hello Assistant: Hi there!") # Print a readable summary print(explainer.format(result)) # Save a publication-quality figure explainer.plot(result, save_path="attributions.pdf") """ from __future__ import annotations from dataclasses import dataclass, field from typing import Dict, List, Literal, Optional, Tuple import torch import numpy as np # --------------------------------------------------------------------------- # Output container # --------------------------------------------------------------------------- @dataclass class ExplainabilityOutput: """ Everything ``explain()`` returns. Attributes ---------- text : str Original input text. tokens : List[str] Tokenised input (human-readable sub-words). predictions : Dict[str, float] Predicted score per dimension (e.g. {"clarity": 3.8, …}). attributions : Dict[str, List[float]] Per-token attribution for each score dimension. Length of inner list == len(tokens). method : str Attribution method used. """ text: str = "" tokens: List[str] = field(default_factory=list) predictions: Dict[str, float] = field(default_factory=dict) attributions: Dict[str, List[float]] = field(default_factory=dict) method: str = "" # --------------------------------------------------------------------------- # Main explainer # --------------------------------------------------------------------------- class ScorePredictorExplainer: """ Wraps a ``ScorePredictorModel`` and provides token-level explanations. 
Parameters ---------- model : ScorePredictorModel A loaded model instance. tokenizer The matching tokenizer. device : str or torch.device, optional Defaults to the model's current device. """ def __init__(self, model, tokenizer, device: Optional[torch.device] = None): self.model = model self.tokenizer = tokenizer self.device = device or next(model.parameters()).device self.score_names: List[str] = list(model.config.score_names) self.num_scores: int = model.num_scores self.model.eval() # ------------------------------------------------------------------ # Convenience constructor # ------------------------------------------------------------------ @classmethod def from_pretrained(cls, model_path: str, device: str = "auto") -> "ScorePredictorExplainer": """ Load model + tokenizer from a saved checkpoint in one call. Parameters ---------- model_path : str Path (or HF hub id) to the saved model directory. device : str ``"auto"`` picks GPU if available, else CPU. """ from transformers import AutoConfig, AutoModel, AutoTokenizer config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) model = AutoModel.from_pretrained( model_path, config=config, trust_remote_code=True ) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" model = model.to(device) model.eval() return cls(model, tokenizer, torch.device(device)) # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def explain( self, text: str, *, method: Literal["integrated_gradients", "attention_rollout"] = "integrated_gradients", n_steps: int = 30, ) -> ExplainabilityOutput: """ Explain a single text input. Parameters ---------- text : str The conversation / sentence to score and explain. method : str ``"integrated_gradients"`` (default, most accurate) or ``"attention_rollout"`` (faster, attention-based). 
n_steps : int Riemann-sum steps for integrated gradients (ignored for rollout). Returns ------- ExplainabilityOutput """ # Tokenise enc = self.tokenizer( text, return_tensors="pt", truncation=True, max_length=getattr(self.model.config, "max_position_embeddings", 512), ) input_ids = enc["input_ids"].to(self.device) attention_mask = enc["attention_mask"].to(self.device) # Decode token strings tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0].tolist()) # Base prediction with torch.no_grad(): base_out = self.model( input_ids=input_ids, attention_mask=attention_mask, return_dict=True, ) preds = base_out.predictions[0].cpu().tolist() predictions = {name: round(v, 4) for name, v in zip(self.score_names, preds)} # Attributions if method == "integrated_gradients": raw_attr = self._integrated_gradients(input_ids, attention_mask, n_steps) elif method == "attention_rollout": raw_attr = self._attention_rollout(input_ids, attention_mask) else: raise ValueError( f"Unknown method '{method}'. " "Choose 'integrated_gradients' or 'attention_rollout'." ) # Zero out attributions for Task/Input tokens — keep only the # Output section so that task names and input questions don't # dominate the explanation. 
output_start = _find_output_token_idx(tokens) if output_start is not None: for name in raw_attr: raw_attr[name][0, :output_start] = 0.0 # Re-normalise so surviving tokens sum to 1 total = raw_attr[name][0].sum() if total > 0: raw_attr[name][0] /= total # Convert tensors → plain lists attributions = { name: [round(float(v), 6) for v in attr[0]] for name, attr in raw_attr.items() } return ExplainabilityOutput( text=text, tokens=tokens, predictions=predictions, attributions=attributions, method=method, ) # ------------------------------------------------------------------ # Attribution: Integrated Gradients (Sundararajan et al., 2017) # ------------------------------------------------------------------ def _integrated_gradients( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, n_steps: int, ) -> Dict[str, torch.Tensor]: """ Integral of d(score)/d(embedding) along a straight path from a zero baseline to the actual input embedding. Returns Dict[score_name -> Tensor[1, seq_len]]. """ input_emb = self.model.get_input_embeddings()(input_ids).detach() baseline_emb = torch.zeros_like(input_emb) delta = input_emb - baseline_emb alphas = torch.linspace(0.0, 1.0, n_steps, device=self.device) accum = {name: torch.zeros_like(input_emb) for name in self.score_names} for alpha in alphas: interp = (baseline_emb + alpha * delta).requires_grad_(True) preds = self._forward_from_embeddings(interp, attention_mask) for i, name in enumerate(self.score_names): (grad,) = torch.autograd.grad( preds[:, i].sum(), interp, retain_graph=(i < self.num_scores - 1), ) accum[name] += grad.detach() attributions: Dict[str, torch.Tensor] = {} for name in self.score_names: ig = (delta * accum[name] / n_steps).norm(dim=-1) # [1, L] ig = ig * attention_mask.float() ig = ig / ig.sum(dim=-1, keepdim=True).clamp_min(1e-9) attributions[name] = ig return attributions # ------------------------------------------------------------------ # Attribution: Attention Rollout (Abnar & Zuidema, 2020) # 
------------------------------------------------------------------ def _attention_rollout( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> Dict[str, torch.Tensor]: """ Propagate attention through all layers, accounting for residual connections. Token importance = attention flowing from CLS to each token in the final rolled-out matrix. Returns Dict[score_name -> Tensor[1, seq_len]]. """ attentions = self._get_attentions(input_ids, attention_mask) B, L = attention_mask.shape dummy = torch.zeros(B, L, device=self.device) if not attentions: return {n: dummy for n in self.score_names} rollout = torch.eye(L, device=self.device).unsqueeze(0).expand(B, -1, -1).clone() mask_2d = attention_mask.unsqueeze(-1).float() * attention_mask.unsqueeze(-2).float() for layer_attn in attentions: if layer_attn is None or layer_attn.dim() != 4: continue attn = layer_attn.mean(dim=1) # mean over heads -> [B, L, L] attn = attn + torch.eye(L, device=self.device).unsqueeze(0) # residual attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9) attn = attn * mask_2d rollout = torch.bmm(attn, rollout) final = rollout[:, 0, :] * attention_mask.float() final = final / final.sum(dim=-1, keepdim=True).clamp_min(1e-9) return {n: final.clone() for n in self.score_names} # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _forward_from_embeddings( self, embeddings: torch.Tensor, attention_mask: torch.Tensor ) -> torch.Tensor: """Full forward pass from pre-computed embeddings -> [B, num_scores].""" backbone_out = self.model.backbone( inputs_embeds=embeddings, attention_mask=attention_mask, return_dict=True, ) hidden = backbone_out.last_hidden_state pooled = self.model._pool_hidden_states(hidden, attention_mask) target_dtype = next(self.model.score_heads[0].parameters()).dtype pooled = pooled.to(target_dtype) if self.model.shared_encoder is not None: features = 
self.model.shared_encoder(pooled) else: features = pooled preds = torch.cat( [1.0 + 4.0 * torch.sigmoid(head(features)) for head in self.model.score_heads], dim=-1, ) return preds def _get_attentions( self, input_ids: torch.Tensor, attention_mask: torch.Tensor ) -> Optional[Tuple[torch.Tensor, ...]]: """Retrieve attention weights from the backbone (no-grad).""" try: with torch.no_grad(): out = self.model( input_ids=input_ids, attention_mask=attention_mask, output_attentions=True, return_dict=True, ) return out.attentions except Exception: return None # ------------------------------------------------------------------ # Text formatting # ------------------------------------------------------------------ def format( self, result: ExplainabilityOutput, top_k: int = 10, score_name: Optional[str] = None, ) -> str: """ Readable plain-text summary of the explanation. Shows whole words (sub-words merged) with percentage attributions. Special tokens ([CLS], [SEP], …) are excluded. Parameters ---------- result : ExplainabilityOutput top_k : int How many top words to show per score. score_name : str, optional Show only this score (default: all). 
""" lines: List[str] = [] sep = "-" * 44 # Predictions lines.append("Predicted scores:") for name, val in result.predictions.items(): lines.append(f" {name:<20} {val:.4f}") lines.append("") # Attributions (merged into words, shown as %) scores_to_show = [score_name] if score_name else self.score_names for sn in scores_to_show: if sn not in result.attributions: continue words = _merge_subwords(result.tokens, result.attributions[sn]) top = sorted(words, key=lambda p: p[1], reverse=True)[:top_k] lines.append(f"-- {sn} ({result.method}) --") lines.append(f"{'Word':<28} {'Importance':>12}") lines.append(sep) for word, pct in top: bar = "\u2588" * int(pct / 2) # simple ascii bar lines.append(f"{word:<28} {pct:>5.1f}% {bar}") lines.append("") return "\n".join(lines) # ------------------------------------------------------------------ # HTML # ------------------------------------------------------------------ def to_html( self, result: ExplainabilityOutput, score_name: Optional[str] = None, ) -> str: """ HTML span-highlighted attribution view. Tokens are coloured white -> gold proportional to their importance. """ sn = score_name or self.score_names[0] if sn not in result.attributions: return f"

Score '{sn}' not found.

" attrs = result.attributions[sn] a_min, a_max = min(attrs), max(attrs) rng = a_max - a_min if abs(a_max - a_min) > 1e-9 else 1.0 spans: List[str] = [] for tok, val in zip(result.tokens, attrs): w = max(0.0, min(1.0, (val - a_min) / rng)) r, g, b = 255, int(255 * (1 - 0.16 * w)), int(255 * (1 - w)) tok_disp = _clean_token(tok).replace("<", "<").replace(">", ">") spans.append( f'{tok_disp}' ) pred_str = "" if sn in result.predictions: pred_str = f"

{sn}: {result.predictions[sn]:.4f}

" return ( f"
" f"{pred_str}

{' '.join(spans)}

" ) # ------------------------------------------------------------------ # Visualisation # ------------------------------------------------------------------ def plot( self, result: ExplainabilityOutput, top_k: int = 15, score_name: Optional[str] = None, figsize: Optional[tuple] = None, save_path: Optional[str] = None, ): """ Horizontal bar chart of the top-k most important **words** (sub-words merged, specials removed) per score, shown as percentages. Parameters ---------- result : ExplainabilityOutput top_k : int Words to display per score. score_name : str, optional Single score only (default: one subplot per score). save_path : str, optional Save figure to this path. Returns ------- matplotlib.figure.Figure """ import matplotlib.pyplot as plt scores = [score_name] if score_name else self.score_names n = len(scores) colours = ["#4C72B0", "#DD8452", "#55A868", "#C44E52"] w = figsize[0] if figsize else 7 h = figsize[1] if figsize else 2.8 * n fig, axes = plt.subplots(n, 1, figsize=(w, h)) if n == 1: axes = [axes] for ax, sn, colour in zip(axes, scores, colours * 4): if sn not in result.attributions: ax.set_visible(False) continue words = _merge_subwords(result.tokens, result.attributions[sn]) top = sorted(words, key=lambda p: p[1], reverse=True)[:top_k] labels = [w for w, _ in top] pcts = np.array([p for _, p in top]) bars = ax.barh(range(len(pcts)), pcts, color=colour, edgecolor="white", linewidth=0.4, height=0.72) ax.set_yticks(range(len(labels))) ax.set_yticklabels(labels, fontsize=9) ax.invert_yaxis() ax.set_xlabel("Importance (%)") ax.set_xlim(0, pcts[0] * 1.25 if len(pcts) else 10) pred_val = result.predictions.get(sn, 0) ax.set_title(f"{sn.capitalize()} (predicted: {pred_val:.2f})", fontweight="bold", fontsize=10) # Annotate bars for bar, pct in zip(bars, pcts): ax.text(bar.get_width() + pcts[0] * 0.02, bar.get_y() + bar.get_height() / 2, f"{pct:.1f}%", va="center", fontsize=8, color="#333") ax.spines["top"].set_visible(False) 
ax.spines["right"].set_visible(False) fig.suptitle(f"Word Importance ({result.method.replace('_', ' ').title()})", fontsize=12, fontweight="bold", y=1.01) plt.tight_layout() if save_path: fig.savefig(save_path, dpi=300, bbox_inches="tight") return fig def plot_heatmap( self, result: ExplainabilityOutput, top_k: int = 25, figsize: Optional[tuple] = None, save_path: Optional[str] = None, ): """ Heatmap: scores (rows) x top-k words (columns). Each cell shows the relative importance of a word for a given score dimension, row-normalised so that each score's max = 1. Returns ------- matplotlib.figure.Figure """ import matplotlib.pyplot as plt from matplotlib.colors import LinearSegmentedColormap cmap = LinearSegmentedColormap.from_list( "attr", ["#FFFFFF", "#FFF7CD", "#FFD700", "#FF6B00", "#8B0000"] ) # Merge subwords per score, collect union of top words merged: Dict[str, Dict[str, float]] = {} for sn in self.score_names: if sn not in result.attributions: continue words = _merge_subwords(result.tokens, result.attributions[sn]) merged[sn] = {w: p for w, p in words} # Rank words by average importance across scores all_words: Dict[str, float] = {} for word_dict in merged.values(): for w, p in word_dict.items(): all_words[w] = all_words.get(w, 0) + p ranked = sorted(all_words, key=all_words.get, reverse=True)[:top_k] matrix = np.array([ [merged.get(sn, {}).get(w, 0) for w in ranked] for sn in self.score_names if sn in merged ]) row_max = matrix.max(axis=1, keepdims=True) row_max[row_max == 0] = 1 matrix = matrix / row_max w = figsize[0] if figsize else max(10, top_k * 0.38) h = figsize[1] if figsize else 2.4 fig, ax = plt.subplots(figsize=(w, h)) im = ax.imshow(matrix, aspect="auto", cmap=cmap, vmin=0, vmax=1, interpolation="nearest") ax.set_xticks(range(len(ranked))) ax.set_xticklabels(ranked, rotation=45, ha="right", fontsize=8) valid_names = [s for s in self.score_names if s in merged] ax.set_yticks(range(len(valid_names))) ax.set_yticklabels([s.capitalize() for s in 
valid_names], fontsize=9) ax.set_xlabel("Word (ranked by aggregate importance)") cb = fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02) cb.set_label("Relative importance", fontsize=8) ax.set_title("Word Importance Across Score Dimensions", fontsize=10, fontweight="bold", pad=8) plt.tight_layout() if save_path: fig.savefig(save_path, dpi=300, bbox_inches="tight") return fig def plot_summary( self, result: ExplainabilityOutput, top_k: int = 10, output_only: bool = True, figsize: tuple = (16, 14), save_path: Optional[str] = None, ): """ Publication-quality composite figure. Layout:: ┌──────────────────────────────────────────────────┐ │ Title + colour legend bar │ ├──────────────────────────────────────────────────┤ │ Task / Input context box │ ├──────────────────────┬───────────────────────────┤ │ Highlighted output │ Top-k bar chart │ × n_scores └──────────────────────┴───────────────────────────┘ Parameters ---------- result : ExplainabilityOutput top_k : int Words per bar chart. output_only : bool If True (default), only highlight text after the last ``Output:`` / ``Answer:`` marker. figsize : tuple Figure size. save_path : str, optional Save path. 
Returns ------- matplotlib.figure.Figure """ import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import textwrap colours = ["#3B6FA0", "#D07830", "#3D9050", "#BB3B3B"] light_bg = ["#EBF0F7", "#FDF3EB", "#EBF5EE", "#F8EBEB"] n_scores = len(self.score_names) fig = plt.figure(figsize=figsize, facecolor="white") outer = gridspec.GridSpec( n_scores + 2, 1, figure=fig, height_ratios=[0.15, 0.25] + [1] * n_scores, hspace=0.28, ) # ── Row 0: Title + gradient legend ──────────────────────────── ax_title = fig.add_subplot(outer[0]) ax_title.axis("off") method_label = result.method.replace("_", " ").title() ax_title.text( 0.5, 0.65, f"OmniScore Explanation \u2014 {method_label}", transform=ax_title.transAxes, fontsize=14, fontweight="bold", ha="center", va="center", color="#222", ) # Smooth gradient bar import matplotlib.colors as mcolors grad = np.linspace(0, 1, 256).reshape(1, -1) cmap_legend = mcolors.LinearSegmentedColormap.from_list( "_lg", ["#F0F0F0", "#FDDC6C", "#E8792B", "#9E2320"] ) ax_cbar = fig.add_axes([0.32, 0.945, 0.36, 0.012]) # [left, bottom, w, h] ax_cbar.imshow(grad, aspect="auto", cmap=cmap_legend) ax_cbar.set_xticks([]) ax_cbar.set_yticks([]) for spine in ax_cbar.spines.values(): spine.set_visible(False) fig.text(0.31, 0.950, "Low", fontsize=7.5, ha="right", color="#888") fig.text(0.69, 0.950, "High", fontsize=7.5, ha="left", color="#888") # ── Row 1: Task / Input context ─────────────────────────────── ax_ctx = fig.add_subplot(outer[1]) ax_ctx.axis("off") raw = result.text task_str, input_str = "", "" for line in raw.split("\n"): s = line.strip() if s.lower().startswith("task:"): task_str = s[5:].strip() elif s.lower().startswith("input:"): input_str = s[6:].strip() ctx_parts: List[str] = [] if task_str: ctx_parts.append(f"Task: {task_str}") if input_str: ctx_parts.append(f"Input: {textwrap.fill(input_str, width=105)}") ctx_text = "\n".join(ctx_parts) if ctx_parts else raw[:200] ax_ctx.text( 0.02, 0.85, ctx_text, 
transform=ax_ctx.transAxes, fontsize=8.5, va="top", fontfamily="monospace", color="#333", linespacing=1.6, bbox=dict( boxstyle="round,pad=0.6", facecolor="#FAFAFA", edgecolor="#D0D0D0", linewidth=0.7, ), ) # ── Per-score rows (highlighted text | bar chart) ───────────── for idx, sn in enumerate(self.score_names): if sn not in result.attributions: continue colour = colours[idx % len(colours)] bg_colour = light_bg[idx % len(light_bg)] base_rgb = np.array([ int(colour[i:i+2], 16) / 255 for i in (1, 3, 5) ]) all_words = _merge_subwords(result.tokens, result.attributions[sn]) display_words = _extract_output_words(all_words) if output_only else list(all_words) word_names = [w for w, _ in display_words] pcts = np.array([p for _, p in display_words]) pmax = pcts.max() if len(pcts) and pcts.max() > 0 else 1.0 norms = pcts / pmax pred_val = result.predictions.get(sn, 0) inner = gridspec.GridSpecFromSubplotSpec( 1, 2, subplot_spec=outer[idx + 2], width_ratios=[1.6, 1], wspace=0.22, ) # ────────── LEFT: highlighted output text ────────────────── ax_text = fig.add_subplot(inner[0]) ax_text.axis("off") ax_text.set_xlim(0, 1) ax_text.set_ylim(0, 1) # Light background panel from matplotlib.patches import FancyBboxPatch ax_text.add_patch(FancyBboxPatch( (0, 0), 1, 1, boxstyle="round,pad=0.02", facecolor=bg_colour, edgecolor="#ddd", linewidth=0.6, transform=ax_text.transAxes, clip_on=False, )) # Score label ax_text.text( 0.02, 0.96, f"{sn.capitalize()} \u2014 predicted {pred_val:.2f} / 5", transform=ax_text.transAxes, fontsize=10, fontweight="bold", va="top", color=colour, ) # Word highlighting with proper wrapping renderer = fig.canvas.get_renderer() x, y = 0.02, 0.84 line_h = 0.085 gap = 0.005 for w, nv in zip(word_names, norms): # Apply a power curve so mid-range values are more visible intensity = nv ** 0.55 bg = tuple(1.0 + (base_rgb[c] - 1.0) * intensity for c in range(3)) # Text colour: dark on light bg, white on dark bg txt_col = "#222" if intensity < 0.7 else "#fff" edge = 
colour if intensity > 0.35 else "none" t = ax_text.text( x, y, f" {w} ", transform=ax_text.transAxes, fontsize=9, va="top", fontfamily="sans-serif", color=txt_col, bbox=dict( boxstyle="round,pad=0.18", facecolor=bg, edgecolor=edge, linewidth=0.6 if edge != "none" else 0, ), ) bb = t.get_window_extent(renderer=renderer) bb_ax = bb.transformed(ax_text.transAxes.inverted()) word_w = bb_ax.width + gap x += word_w if x > 0.97: x = 0.02 y -= line_h if y < 0.0: break t.set_position((x, y)) bb = t.get_window_extent(renderer=renderer) bb_ax = bb.transformed(ax_text.transAxes.inverted()) x = 0.02 + bb_ax.width + gap # ────────── RIGHT: bar chart ─────────────────────────────── ax_bar = fig.add_subplot(inner[1]) top_words = sorted(display_words, key=lambda p: p[1], reverse=True)[:top_k] bar_labels = [w for w, _ in top_words] bar_pcts = np.array([p for _, p in top_words]) bar_norms = bar_pcts / pmax if pmax > 0 else bar_pcts bar_cols = [ tuple(1.0 + (base_rgb[c] - 1.0) * max(n, 0.15) for c in range(3)) for n in bar_norms ] bars = ax_bar.barh( range(len(bar_pcts)), bar_pcts, color=bar_cols, edgecolor="white", linewidth=0.6, height=0.72, ) ax_bar.set_yticks(range(len(bar_labels))) ax_bar.set_yticklabels(bar_labels, fontsize=8.5, fontfamily="sans-serif") ax_bar.invert_yaxis() ax_bar.set_xlabel("Importance (%)", fontsize=8) ax_bar.set_xlim(0, bar_pcts[0] * 1.32 if len(bar_pcts) else 10) ax_bar.set_title( f"Top-{top_k} words", fontsize=9, color="#555", pad=6, ) for bar, pct in zip(bars, bar_pcts): ax_bar.text( bar.get_width() + bar_pcts[0] * 0.015, bar.get_y() + bar.get_height() / 2, f"{pct:.1f}%", va="center", fontsize=7.5, color="#444", ) ax_bar.spines["top"].set_visible(False) ax_bar.spines["right"].set_visible(False) ax_bar.tick_params(axis="y", length=0) if save_path: fig.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white") return fig # --------------------------------------------------------------------------- # Utility # 
# ---------------------------------------------------------------------------
# Utility
# ---------------------------------------------------------------------------

# Tokens to exclude from explanations (model artefacts, not content).
# BUGFIX: the RoBERTa/SentencePiece specials had been stripped to empty
# strings; restored the angle-bracket forms alongside the BERT ones.
_SPECIAL_TOKENS = {
    "[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]",
    "<s>", "</s>", "<pad>", "<unk>", "<mask>",
}

# Markers that signal the start of the model-generated output section.
_OUTPUT_MARKERS = {"Output", "Answer", "Response", "output", "answer", "response"}


def _find_output_token_idx(tokens: List[str]) -> Optional[int]:
    """
    Find the token index where the Output/Answer section begins.

    Scans for the *last* occurrence of a known output marker token
    (e.g. "Output", "▁Output", "output") and returns the index of the
    first content token *after* the marker (skipping ":" if present).

    Returns ``None`` if no marker is found.
    """
    last_marker = -1
    for i, tok in enumerate(tokens):
        clean = tok.replace("\u2581", "").replace("##", "").strip(":").strip()
        if clean in _OUTPUT_MARKERS:
            last_marker = i
    if last_marker == -1:
        return None
    # Skip the marker itself, and an optional ":" token right after it.
    start = last_marker + 1
    if start < len(tokens):
        next_clean = tokens[start].replace("\u2581", "").replace("##", "").strip()
        if next_clean == ":":
            start += 1
    return start


def _clean_token(tok: str) -> str:
    """Strip SentencePiece / WordPiece artefacts for display."""
    return (
        tok.replace("\u2581", " ")
        .replace("##", "")
        .strip()
        or tok
    )


def _extract_output_words(
    words: List[Tuple[str, float]],
) -> List[Tuple[str, float]]:
    """
    Return only the words that belong to the Output / Answer section.

    Scans the word list for the *last* occurrence of a known output
    marker (e.g. "Output", "Answer") and returns everything after it
    (excluding the marker word itself and any colon that follows).
    If no marker is found the full list is returned unchanged.
    """
    last_marker = -1
    for i, (w, _) in enumerate(words):
        clean = w.strip(":").strip()
        if clean in _OUTPUT_MARKERS:
            last_marker = i
    if last_marker == -1:
        return words
    # Skip the marker and an optional colon-word after it.
    start = last_marker + 1
    if start < len(words) and words[start][0].strip() == ":":
        start += 1
    result = words[start:]
    # Re-normalise percentages so they sum to ~100.
    # BUGFIX: guard against a zero total (all-zero attributions) which
    # previously raised ZeroDivisionError.
    total = sum(p for _, p in result) or 1.0
    return [(w, p / total * 100.0) for w, p in result]


# Characters that are pure punctuation and should be glued to the
# preceding word rather than stand alone.
_PUNCT_GLUE = set('.,;:!?)]\'\"')
_PUNCT_OPEN = set('([\"\'')


def _merge_subwords(
    tokens: List[str],
    attributions: List[float],
) -> List[Tuple[str, float]]:
    """
    Merge sub-word tokens back into whole words and sum their attributions.

    - WordPiece continuations (``##xyz``) are joined to the preceding word.
    - SentencePiece tokens starting with ``\u2581`` begin a new word.
    - Standalone punctuation (``.``, ``,``, ``)``, …) is glued to the
      preceding word so bar-chart labels stay clean.
    - Opening brackets/quotes are glued to the *following* word.
    - Special tokens ([CLS], [SEP], …) are dropped.

    Returns a list of ``(word, importance_percent)`` sorted by position.
    Percentages sum to ~100 (before any top-k truncation).
    """
    words: List[str] = []
    word_scores: List[float] = []

    # BUGFIX: detect the tokenizer family once.  In pure WordPiece output
    # no token carries the SentencePiece "\u2581" marker, and every
    # non-"##" token starts a new word; previously such tokens were
    # treated as continuations and distinct words got glued together.
    uses_sentencepiece = any(t.startswith("\u2581") for t in tokens)

    for tok, attr in zip(tokens, attributions):
        if tok in _SPECIAL_TOKENS:
            continue

        # WordPiece continuation.
        if tok.startswith("##"):
            if words:
                words[-1] += tok[2:]
                word_scores[-1] += attr
            continue

        # SentencePiece: strip the leading ▁.
        clean = tok.replace("\u2581", "")
        if not clean:
            continue

        # Pure trailing punctuation → glue to previous word.
        if clean in _PUNCT_GLUE and words:
            words[-1] += clean
            word_scores[-1] += attr
            continue

        # Opening punctuation → start a new word (will be glued to next).
        if clean in _PUNCT_OPEN:
            words.append(clean)
            word_scores.append(attr)
            continue

        is_new_word = (
            tok.startswith("\u2581")
            or not uses_sentencepiece
            or not words
        )
        if is_new_word:
            # If previous word is an opening bracket, glue this onto it.
            if words and words[-1] in _PUNCT_OPEN:
                words[-1] += clean
                word_scores[-1] += attr
            else:
                words.append(clean)
                word_scores.append(attr)
        else:
            # SentencePiece sub-word continuation (no ▁ prefix).
            words[-1] += clean
            word_scores[-1] += attr

    # Convert raw attribution sums → percentages of total.
    # BUGFIX: ``or 1.0`` also covers an all-zero total, which previously
    # raised ZeroDivisionError.
    total = sum(word_scores) or 1.0
    return [(w, s / total * 100.0) for w, s in zip(words, word_scores)]