"""
Explainability for ScorePredictorModel.
Given a conversation text, shows *which tokens* drive each predicted score
and *how much* they contribute. Two attribution methods are provided:
integrated_gradients – gradient-based (most faithful, slower)
attention_rollout – attention-based (fast, good overview)
Quick start
-----------
from explain_score_predictor import ScorePredictorExplainer
explainer = ScorePredictorExplainer.from_pretrained("path/to/model")
# Get attributions from raw text
result = explainer.explain("User: Hello Assistant: Hi there!")
# Print a readable summary
print(explainer.format(result))
# Save a publication-quality figure
explainer.plot(result, save_path="attributions.pdf")
"""
from __future__ import annotations

import html
import warnings
from dataclasses import dataclass, field
from typing import Dict, Iterable, List, Literal, Optional, Tuple

import numpy as np
import torch
# ---------------------------------------------------------------------------
# Output container
# ---------------------------------------------------------------------------
@dataclass
class ExplainabilityOutput:
    """
    Container for everything ``ScorePredictorExplainer.explain()`` produces.

    Attributes
    ----------
    text : str
        The raw input text that was explained.
    tokens : List[str]
        The tokenizer's sub-word tokens for ``text`` (special tokens included).
    predictions : Dict[str, float]
        Predicted value for each score dimension, keyed by score name.
    attributions : Dict[str, List[float]]
        For each score name, one attribution value per entry in ``tokens``
        (so every inner list has ``len(tokens)`` elements).
    method : str
        Name of the attribution method that produced ``attributions``.
    """

    # Input text, verbatim.
    text: str = ""
    # Sub-word tokens aligned with each attribution list.
    tokens: List[str] = field(default_factory=list)
    # Score name -> predicted value.
    predictions: Dict[str, float] = field(default_factory=dict)
    # Score name -> per-token attribution values.
    attributions: Dict[str, List[float]] = field(default_factory=dict)
    # "integrated_gradients" or "attention_rollout" when built by explain().
    method: str = ""
# ---------------------------------------------------------------------------
# Main explainer
# ---------------------------------------------------------------------------
class ScorePredictorExplainer:
    """
    Wraps a ``ScorePredictorModel`` and provides token-level explanations.

    Parameters
    ----------
    model : ScorePredictorModel
        A loaded model instance. The explainer reads ``model.config.score_names``,
        ``model.num_scores``, ``model.get_input_embeddings()``,
        ``model.backbone``, ``model._pool_hidden_states``,
        ``model.shared_encoder`` and ``model.score_heads``.
    tokenizer
        The matching tokenizer.
    device : str or torch.device, optional
        Defaults to the model's current device. Strings such as ``"cuda:0"``
        are converted to ``torch.device``.
    """

    def __init__(self, model, tokenizer, device: Optional[torch.device] = None):
        self.model = model
        self.tokenizer = tokenizer
        # Normalise strings so self.device is always a torch.device.
        self.device = torch.device(device) if device else next(model.parameters()).device
        self.score_names: List[str] = list(model.config.score_names)
        self.num_scores: int = model.num_scores
        self.model.eval()

    # ------------------------------------------------------------------
    # Convenience constructor
    # ------------------------------------------------------------------
    @classmethod
    def from_pretrained(cls, model_path: str, device: str = "auto") -> "ScorePredictorExplainer":
        """
        Load model + tokenizer from a saved checkpoint in one call.

        Parameters
        ----------
        model_path : str
            Path (or HF hub id) to the saved model directory.
        device : str
            ``"auto"`` picks GPU if available, else CPU.
        """
        from transformers import AutoConfig, AutoModel, AutoTokenizer

        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            model_path, config=config, trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        model.eval()
        return cls(model, tokenizer, torch.device(device))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def explain(
        self,
        text: str,
        *,
        method: Literal["integrated_gradients", "attention_rollout"] = "integrated_gradients",
        n_steps: int = 30,
    ) -> "ExplainabilityOutput":
        """
        Explain a single text input.

        Parameters
        ----------
        text : str
            The conversation / sentence to score and explain.
        method : str
            ``"integrated_gradients"`` (default, most accurate) or
            ``"attention_rollout"`` (faster, attention-based).
        n_steps : int
            Riemann-sum steps for integrated gradients (ignored for rollout).
            Must be >= 1 when integrated gradients is used.

        Returns
        -------
        ExplainabilityOutput
            When an ``Output:``/``Answer:``-style marker is found in the
            tokens, attributions before the marker are zeroed and the rest is
            re-normalised to sum to 1.

        Raises
        ------
        ValueError
            If ``method`` is unknown or ``n_steps`` < 1.
        """
        # Validate before paying for tokenisation and a forward pass.
        if method not in ("integrated_gradients", "attention_rollout"):
            raise ValueError(
                f"Unknown method '{method}'. "
                "Choose 'integrated_gradients' or 'attention_rollout'."
            )

        enc = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self._max_input_length(),
        )
        input_ids = enc["input_ids"].to(self.device)
        attention_mask = enc["attention_mask"].to(self.device)
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

        # Base prediction (no gradients needed).
        with torch.no_grad():
            base_out = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )
        preds = base_out.predictions[0].cpu().tolist()
        predictions = {name: round(v, 4) for name, v in zip(self.score_names, preds)}

        if method == "integrated_gradients":
            raw_attr = self._integrated_gradients(input_ids, attention_mask, n_steps)
        else:
            raw_attr = self._attention_rollout(input_ids, attention_mask)

        # Zero out attributions for Task/Input tokens and keep only the
        # Output section, so task names and input questions don't dominate
        # the explanation.
        output_start = _find_output_token_idx(tokens)
        if output_start is not None:
            for attr in raw_attr.values():
                attr[0, :output_start] = 0.0
                # Re-normalise so the surviving tokens sum to 1.
                total = attr[0].sum()
                if total > 0:
                    attr[0] /= total

        attributions = {
            name: [round(float(v), 6) for v in attr[0].detach().cpu().tolist()]
            for name, attr in raw_attr.items()
        }
        return ExplainabilityOutput(
            text=text,
            tokens=tokens,
            predictions=predictions,
            attributions=attributions,
            method=method,
        )

    # ------------------------------------------------------------------
    # Attribution: Integrated Gradients (Sundararajan et al., 2017)
    # ------------------------------------------------------------------
    def _integrated_gradients(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        n_steps: int,
    ) -> Dict[str, torch.Tensor]:
        """
        Integral of d(score)/d(embedding) along a straight path from a zero
        baseline to the actual input embedding.

        The integral is approximated with a midpoint Riemann sum over
        ``n_steps`` points. Per-token attribution is the L2 norm (over the
        embedding dimension) of ``delta * mean_gradient``. It is masked by
        ``attention_mask`` and normalised to sum to 1.

        Gradients are force-enabled, so this also works when the caller is
        inside ``torch.no_grad()``.

        Returns Dict[score_name -> Tensor[1, seq_len]].

        Raises
        ------
        ValueError
            If ``n_steps`` < 1.
        """
        if n_steps < 1:
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")

        with torch.enable_grad():
            input_emb = self.model.get_input_embeddings()(input_ids).detach()
            baseline_emb = torch.zeros_like(input_emb)
            delta = input_emb - baseline_emb
            # Midpoint rule: alpha_k = (k + 0.5) / n. Each point represents an
            # interval of width 1/n, so averaging the n gradients is a proper
            # Riemann sum. The previous linspace(0, 1, n) had n points over
            # n-1 intervals, and for n == 1 it sampled only the baseline.
            alphas = (
                torch.arange(n_steps, device=input_emb.device, dtype=torch.float32) + 0.5
            ) / n_steps
            accum = {name: torch.zeros_like(input_emb) for name in self.score_names}
            last_idx = len(self.score_names) - 1

            for alpha in alphas:
                interp = (baseline_emb + alpha * delta).requires_grad_(True)
                preds = self._forward_from_embeddings(interp, attention_mask)
                for i, name in enumerate(self.score_names):
                    (grad,) = torch.autograd.grad(
                        preds[:, i].sum(),
                        interp,
                        # Keep the graph alive until the last score's backward pass.
                        retain_graph=i < last_idx,
                    )
                    accum[name] += grad.detach()

        mask = attention_mask.float()
        attributions: Dict[str, torch.Tensor] = {}
        for name in self.score_names:
            # Compute the norm in float32 to avoid fp16 overflow.
            ig = (delta.float() * accum[name].float() / n_steps).norm(dim=-1)  # [1, L]
            ig = ig * mask
            ig = ig / ig.sum(dim=-1, keepdim=True).clamp_min(1e-9)
            attributions[name] = ig
        return attributions

    # ------------------------------------------------------------------
    # Attribution: Attention Rollout (Abnar & Zuidema, 2020)
    # ------------------------------------------------------------------
    def _attention_rollout(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Propagate attention through all layers, accounting for residual
        connections. Token importance = attention flowing from position 0
        (assumed to be a CLS-style token; TODO confirm for the backbone in
        use) to each token in the final rolled-out matrix.

        The same map is returned for every score (attention is not
        score-specific), but each score gets its own tensor so in-place
        edits on one do not affect the others. If no attentions are
        available, all-zero tensors are returned.

        Returns Dict[score_name -> Tensor[batch, seq_len]].
        """
        attentions = self._get_attentions(input_ids, attention_mask)
        batch, seq_len = attention_mask.shape
        dev = attention_mask.device
        if not attentions:
            return {
                name: torch.zeros(batch, seq_len, device=dev)
                for name in self.score_names
            }

        mask = attention_mask.float()
        eye = torch.eye(seq_len, device=dev)
        mask_2d = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        rollout = eye.unsqueeze(0).expand(batch, -1, -1).clone()
        for layer_attn in attentions:
            if layer_attn is None or layer_attn.dim() != 4:
                continue
            # Mean over heads -> [B, L, L]; the identity term models the residual path.
            attn = layer_attn.float().mean(dim=1) + eye
            attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-9)
            attn = attn * mask_2d
            rollout = torch.bmm(attn, rollout)

        final = rollout[:, 0, :] * mask
        final = final / final.sum(dim=-1, keepdim=True).clamp_min(1e-9)
        return {name: final.clone() for name in self.score_names}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _max_input_length(self) -> int:
        """
        Return the maximum number of tokens to feed the model.

        Starts from ``config.max_position_embeddings`` (default 512) and
        lowers it to ``tokenizer.model_max_length`` when that is smaller.
        Some backbones (e.g. RoBERTa-style) reserve position slots, so the
        config value can exceed the usable sequence length. Tokenizers
        without a configured limit may report a very large placeholder,
        which is why the tokenizer value is only ever used to *lower* the
        limit.
        """
        limit = getattr(self.model.config, "max_position_embeddings", 512)
        tok_limit = getattr(self.tokenizer, "model_max_length", None)
        if isinstance(tok_limit, int) and 0 < tok_limit < limit:
            limit = tok_limit
        return limit

    def _forward_from_embeddings(
        self, embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Full forward pass from pre-computed embeddings -> [B, num_scores]."""
        backbone_out = self.model.backbone(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden = backbone_out.last_hidden_state
        pooled = self.model._pool_hidden_states(hidden, attention_mask)
        target_dtype = next(self.model.score_heads[0].parameters()).dtype
        pooled = pooled.to(target_dtype)
        if self.model.shared_encoder is not None:
            features = self.model.shared_encoder(pooled)
        else:
            features = pooled
        # NOTE(review): this hard-codes the head output transform
        # (1 + 4 * sigmoid, i.e. a [1, 5] range). It must match the model's
        # own forward(), otherwise IG explains a different function than the
        # one that produced `predictions`. Confirm against ScorePredictorModel.
        preds = torch.cat(
            [1.0 + 4.0 * torch.sigmoid(head(features)) for head in self.model.score_heads],
            dim=-1,
        )
        return preds

    def _get_attentions(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> Optional[Tuple[torch.Tensor, ...]]:
        """
        Retrieve per-layer attention weights (no-grad).

        Tries the full model first, then ``model.backbone`` directly. Returns
        ``None`` and emits a ``RuntimeWarning`` when neither yields
        attentions, so rollout degrades to all-zero attributions visibly
        instead of silently.
        """
        failures: List[str] = []
        candidates = (
            ("model", self.model),
            ("backbone", getattr(self.model, "backbone", None)),
        )
        for label, module in candidates:
            if module is None:
                continue
            try:
                with torch.no_grad():
                    out = module(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        output_attentions=True,
                        return_dict=True,
                    )
            except Exception as err:  # best-effort: record and try the next source
                failures.append(f"{label}: {err!r}")
                continue
            attentions = getattr(out, "attentions", None)
            if attentions is not None and len(attentions) > 0:
                return attentions
            failures.append(f"{label}: no attentions in output")
        warnings.warn(
            "Could not retrieve attention weights; attention rollout will "
            "return all-zero attributions (" + "; ".join(failures) + ").",
            RuntimeWarning,
            stacklevel=3,
        )
        return None

    # ------------------------------------------------------------------
    # Text formatting
    # ------------------------------------------------------------------
    def format(
        self,
        result: "ExplainabilityOutput",
        top_k: int = 10,
        score_name: Optional[str] = None,
    ) -> str:
        """
        Readable plain-text summary of the explanation.

        Shows whole words (sub-words merged) with percentage attributions.
        Special tokens ([CLS], [SEP], …) are excluded.

        Parameters
        ----------
        result : ExplainabilityOutput
        top_k : int
            How many top words to show per score.
        score_name : str, optional
            Show only this score (default: all).
        """
        lines: List[str] = []
        sep = "-" * 44

        lines.append("Predicted scores:")
        for name, val in result.predictions.items():
            lines.append(f" {name:<20} {val:.4f}")
        lines.append("")

        scores_to_show = [score_name] if score_name else self.score_names
        for sn in scores_to_show:
            if sn not in result.attributions:
                continue
            words = _merge_subwords(result.tokens, result.attributions[sn])
            top = sorted(words, key=lambda p: p[1], reverse=True)[:top_k]
            lines.append(f"-- {sn} ({result.method}) --")
            lines.append(f"{'Word':<28} {'Importance':>12}")
            lines.append(sep)
            for word, pct in top:
                bar = "\u2588" * int(pct / 2)  # one full-block character per 2 %
                lines.append(f"{word:<28} {pct:>5.1f}% {bar}")
            lines.append("")
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # HTML
    # ------------------------------------------------------------------
    def to_html(
        self,
        result: "ExplainabilityOutput",
        score_name: Optional[str] = None,
    ) -> str:
        """
        HTML span-highlighted attribution view.

        Tokens are coloured white -> gold proportional to their importance,
        after min-max scaling within the chosen score. Token text and score
        names are HTML-escaped, including ``&``, ``<``, ``>`` and quotes.
        """
        sn = score_name or self.score_names[0]
        if sn not in result.attributions:
            return f"<p>Score '{html.escape(sn)}' not found.</p>"

        attrs = result.attributions[sn]
        a_min, a_max = (min(attrs), max(attrs)) if attrs else (0.0, 0.0)
        rng = a_max - a_min if abs(a_max - a_min) > 1e-9 else 1.0

        spans: List[str] = []
        for tok, val in zip(result.tokens, attrs):
            w = max(0.0, min(1.0, (val - a_min) / rng))
            # rgb(255,255,255) at w=0 -> rgb(255,214,0) (gold) at w=1.
            r, g, b = 255, int(255 * (1 - 0.16 * w)), int(255 * (1 - w))
            tok_disp = html.escape(_clean_token(tok))
            spans.append(
                f'<span style="background-color: rgb({r}, {g}, {b}); '
                f'padding: 1px 2px; border-radius: 3px;">{tok_disp}</span>'
            )

        pred_str = ""
        if sn in result.predictions:
            pred_str = f"<p><b>{html.escape(sn)}</b>: {result.predictions[sn]:.4f}</p>\n"
        return (
            '<div style="font-family: sans-serif; line-height: 1.8;">'
            f"{pred_str}<p>{' '.join(spans)}</p></div>"
        )

    # ------------------------------------------------------------------
    # Visualisation
    # ------------------------------------------------------------------
    def plot(
        self,
        result: "ExplainabilityOutput",
        top_k: int = 15,
        score_name: Optional[str] = None,
        figsize: Optional[tuple] = None,
        save_path: Optional[str] = None,
    ):
        """
        Horizontal bar chart of the top-k most important **words**
        (sub-words merged, specials removed) per score, shown as percentages.

        Parameters
        ----------
        result : ExplainabilityOutput
        top_k : int
            Words to display per score.
        score_name : str, optional
            Single score only (default: one subplot per score).
        figsize : tuple, optional
            ``(width, height)``; defaults to ``(7, 2.8 * n_subplots)``.
        save_path : str, optional
            Save figure to this path.

        Returns
        -------
        matplotlib.figure.Figure
        """
        import matplotlib.pyplot as plt

        scores = [score_name] if score_name else self.score_names
        n = len(scores)
        colours = ["#4C72B0", "#DD8452", "#55A868", "#C44E52"]
        fig_w = figsize[0] if figsize else 7
        fig_h = figsize[1] if figsize else 2.8 * n
        fig, axes = plt.subplots(n, 1, figsize=(fig_w, fig_h))
        if n == 1:
            axes = [axes]

        # Cycle colours by index; zip with a fixed-length colour list would
        # silently drop subplots beyond its length.
        for idx, (ax, sn) in enumerate(zip(axes, scores)):
            colour = colours[idx % len(colours)]
            if sn not in result.attributions:
                ax.set_visible(False)
                continue
            words = _merge_subwords(result.tokens, result.attributions[sn])
            top = sorted(words, key=lambda p: p[1], reverse=True)[:top_k]
            labels = [word for word, _ in top]
            pcts = np.array([p for _, p in top])

            bars = ax.barh(range(len(pcts)), pcts, color=colour,
                           edgecolor="white", linewidth=0.4, height=0.72)
            ax.set_yticks(range(len(labels)))
            ax.set_yticklabels(labels, fontsize=9)
            ax.invert_yaxis()
            ax.set_xlabel("Importance (%)")
            # Fall back to a fixed range when there is nothing (or only zeros) to show.
            x_max = pcts[0] * 1.25 if len(pcts) and pcts[0] > 0 else 10
            ax.set_xlim(0, x_max)
            pred_val = result.predictions.get(sn, 0)
            ax.set_title(f"{sn.capitalize()} (predicted: {pred_val:.2f})",
                         fontweight="bold", fontsize=10)

            for bar, pct in zip(bars, pcts):
                ax.text(bar.get_width() + pcts[0] * 0.02,
                        bar.get_y() + bar.get_height() / 2,
                        f"{pct:.1f}%", va="center", fontsize=8, color="#333")

            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

        fig.suptitle(f"Word Importance ({result.method.replace('_', ' ').title()})",
                     fontsize=12, fontweight="bold", y=1.01)
        plt.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
        return fig

    def plot_heatmap(
        self,
        result: "ExplainabilityOutput",
        top_k: int = 25,
        figsize: Optional[tuple] = None,
        save_path: Optional[str] = None,
    ):
        """
        Heatmap: scores (rows) x top-k words (columns).

        Each cell shows the relative importance of a word for a given score
        dimension, row-normalised so that each score's max = 1. A word that
        occurs several times contributes the sum of its occurrences.

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        ValueError
            If there are no attributed words to plot.
        """
        import matplotlib.pyplot as plt
        from matplotlib.colors import LinearSegmentedColormap

        cmap = LinearSegmentedColormap.from_list(
            "attr", ["#FFFFFF", "#FFF7CD", "#FFD700", "#FF6B00", "#8B0000"]
        )

        # Merge subwords per score; sum repeated words. A plain dict
        # comprehension would keep only the last occurrence.
        merged: Dict[str, Dict[str, float]] = {}
        for sn in self.score_names:
            if sn not in result.attributions:
                continue
            word_totals: Dict[str, float] = {}
            for word, pct in _merge_subwords(result.tokens, result.attributions[sn]):
                word_totals[word] = word_totals.get(word, 0.0) + pct
            merged[sn] = word_totals

        # Rank words by aggregate (summed) importance across scores.
        all_words: Dict[str, float] = {}
        for word_dict in merged.values():
            for word, pct in word_dict.items():
                all_words[word] = all_words.get(word, 0.0) + pct
        ranked = sorted(all_words, key=all_words.get, reverse=True)[:top_k]

        valid_names = list(merged)  # score_names order, filtered
        if not valid_names or not ranked:
            raise ValueError("No attributed words to plot.")

        matrix = np.array([
            [merged[sn].get(word, 0.0) for word in ranked]
            for sn in valid_names
        ])
        row_max = matrix.max(axis=1, keepdims=True)
        row_max[row_max == 0] = 1
        matrix = matrix / row_max

        fig_w = figsize[0] if figsize else max(10, top_k * 0.38)
        fig_h = figsize[1] if figsize else 2.4
        fig, ax = plt.subplots(figsize=(fig_w, fig_h))
        im = ax.imshow(matrix, aspect="auto", cmap=cmap, vmin=0, vmax=1,
                       interpolation="nearest")
        ax.set_xticks(range(len(ranked)))
        ax.set_xticklabels(ranked, rotation=45, ha="right", fontsize=8)
        ax.set_yticks(range(len(valid_names)))
        ax.set_yticklabels([s.capitalize() for s in valid_names], fontsize=9)
        ax.set_xlabel("Word (ranked by aggregate importance)")
        cb = fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02)
        cb.set_label("Relative importance", fontsize=8)
        ax.set_title("Word Importance Across Score Dimensions",
                     fontsize=10, fontweight="bold", pad=8)
        plt.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
        return fig

    def plot_summary(
        self,
        result: "ExplainabilityOutput",
        top_k: int = 10,
        output_only: bool = True,
        figsize: tuple = (16, 14),
        save_path: Optional[str] = None,
    ):
        """
        Publication-quality composite figure.

        Layout::

            ┌──────────────────────────────────────────────────┐
            │  Title + colour legend bar                        │
            ├──────────────────────────────────────────────────┤
            │  Task / Input context box                         │
            ├──────────────────────┬───────────────────────────┤
            │  Highlighted output   │  Top-k bar chart          │  × n_scores
            └──────────────────────┴───────────────────────────┘

        Parameters
        ----------
        result : ExplainabilityOutput
        top_k : int
            Words per bar chart.
        output_only : bool
            If True (default), only highlight text after the last
            ``Output:`` / ``Answer:`` marker.
        figsize : tuple
            Figure size.
        save_path : str, optional
            Save path.

        Returns
        -------
        matplotlib.figure.Figure
        """
        import textwrap

        import matplotlib.colors as mcolors
        import matplotlib.gridspec as gridspec
        import matplotlib.pyplot as plt
        from matplotlib.patches import FancyBboxPatch

        colours = ["#3B6FA0", "#D07830", "#3D9050", "#BB3B3B"]
        light_bg = ["#EBF0F7", "#FDF3EB", "#EBF5EE", "#F8EBEB"]
        n_scores = len(self.score_names)

        fig = plt.figure(figsize=figsize, facecolor="white")
        outer = gridspec.GridSpec(
            n_scores + 2, 1, figure=fig,
            height_ratios=[0.15, 0.25] + [1] * n_scores,
            hspace=0.28,
        )

        # ── Row 0: title + gradient legend ────────────────────────────
        ax_title = fig.add_subplot(outer[0])
        ax_title.axis("off")
        method_label = result.method.replace("_", " ").title()
        ax_title.text(
            0.5, 0.65,
            f"OmniScore Explanation \u2014 {method_label}",
            transform=ax_title.transAxes, fontsize=14, fontweight="bold",
            ha="center", va="center", color="#222",
        )

        grad = np.linspace(0, 1, 256).reshape(1, -1)
        cmap_legend = mcolors.LinearSegmentedColormap.from_list(
            "_lg", ["#F0F0F0", "#FDDC6C", "#E8792B", "#9E2320"]
        )
        ax_cbar = fig.add_axes([0.32, 0.945, 0.36, 0.012])  # [left, bottom, w, h]
        ax_cbar.imshow(grad, aspect="auto", cmap=cmap_legend)
        ax_cbar.set_xticks([])
        ax_cbar.set_yticks([])
        for spine in ax_cbar.spines.values():
            spine.set_visible(False)
        fig.text(0.31, 0.950, "Low", fontsize=7.5, ha="right", color="#888")
        fig.text(0.69, 0.950, "High", fontsize=7.5, ha="left", color="#888")

        # ── Row 1: Task / Input context ───────────────────────────────
        ax_ctx = fig.add_subplot(outer[1])
        ax_ctx.axis("off")
        raw = result.text
        task_str, input_str = "", ""
        for line in raw.split("\n"):
            s = line.strip()
            if s.lower().startswith("task:"):
                task_str = s[5:].strip()
            elif s.lower().startswith("input:"):
                input_str = s[6:].strip()

        ctx_parts: List[str] = []
        if task_str:
            ctx_parts.append(f"Task: {task_str}")
        if input_str:
            ctx_parts.append(f"Input: {textwrap.fill(input_str, width=105)}")
        ctx_text = "\n".join(ctx_parts) if ctx_parts else raw[:200]
        ax_ctx.text(
            0.02, 0.85, ctx_text,
            transform=ax_ctx.transAxes, fontsize=8.5, va="top",
            fontfamily="monospace", color="#333", linespacing=1.6,
            bbox=dict(
                boxstyle="round,pad=0.6", facecolor="#FAFAFA",
                edgecolor="#D0D0D0", linewidth=0.7,
            ),
        )

        renderer = fig.canvas.get_renderer()

        # ── Per-score rows (highlighted text | bar chart) ─────────────
        for idx, sn in enumerate(self.score_names):
            if sn not in result.attributions:
                continue

            colour = colours[idx % len(colours)]
            bg_colour = light_bg[idx % len(light_bg)]
            base_rgb = np.array([
                int(colour[i:i + 2], 16) / 255 for i in (1, 3, 5)
            ])

            all_words = _merge_subwords(result.tokens, result.attributions[sn])
            display_words = _extract_output_words(all_words) if output_only else list(all_words)
            word_names = [word for word, _ in display_words]
            pcts = np.array([p for _, p in display_words])
            pmax = pcts.max() if len(pcts) and pcts.max() > 0 else 1.0
            norms = pcts / pmax
            pred_val = result.predictions.get(sn, 0)

            inner = gridspec.GridSpecFromSubplotSpec(
                1, 2, subplot_spec=outer[idx + 2],
                width_ratios=[1.6, 1], wspace=0.22,
            )

            # ────────── LEFT: highlighted output text ──────────────────
            ax_text = fig.add_subplot(inner[0])
            ax_text.axis("off")
            ax_text.set_xlim(0, 1)
            ax_text.set_ylim(0, 1)

            ax_text.add_patch(FancyBboxPatch(
                (0, 0), 1, 1, boxstyle="round,pad=0.02",
                facecolor=bg_colour, edgecolor="#ddd", linewidth=0.6,
                transform=ax_text.transAxes, clip_on=False,
            ))
            ax_text.text(
                0.02, 0.96,
                f"{sn.capitalize()} \u2014 predicted {pred_val:.2f} / 5",
                transform=ax_text.transAxes, fontsize=10,
                fontweight="bold", va="top", color=colour,
            )

            # Word highlighting with manual line wrapping in axes coordinates.
            x, y = 0.02, 0.84
            line_h = 0.085
            gap = 0.005
            for word, nv in zip(word_names, norms):
                # Power curve so mid-range values stay visible; clamp guards
                # against negative values, which would give NaN.
                intensity = max(float(nv), 0.0) ** 0.55
                bg = tuple(1.0 + (base_rgb[c] - 1.0) * intensity for c in range(3))
                txt_col = "#222" if intensity < 0.7 else "#fff"
                edge = colour if intensity > 0.35 else "none"

                t = ax_text.text(
                    x, y, f" {word} ",
                    transform=ax_text.transAxes, fontsize=9,
                    va="top", fontfamily="sans-serif", color=txt_col,
                    bbox=dict(
                        boxstyle="round,pad=0.18", facecolor=bg,
                        edgecolor=edge, linewidth=0.6 if edge != "none" else 0,
                    ),
                )
                bb_ax = t.get_window_extent(renderer=renderer).transformed(
                    ax_text.transAxes.inverted()
                )
                x += bb_ax.width + gap
                if x > 0.97:
                    # Word overflowed the right edge: move it to the next line.
                    x = 0.02
                    y -= line_h
                    if y < 0.0:
                        # Out of vertical space: drop the overflowing word
                        # instead of leaving it drawn past the panel edge.
                        t.remove()
                        break
                    t.set_position((x, y))
                    bb_ax = t.get_window_extent(renderer=renderer).transformed(
                        ax_text.transAxes.inverted()
                    )
                    x = 0.02 + bb_ax.width + gap

            # ────────── RIGHT: bar chart ───────────────────────────────
            ax_bar = fig.add_subplot(inner[1])
            top_words = sorted(display_words, key=lambda p: p[1], reverse=True)[:top_k]
            bar_labels = [word for word, _ in top_words]
            bar_pcts = np.array([p for _, p in top_words])
            bar_norms = bar_pcts / pmax if pmax > 0 else bar_pcts
            bar_cols = [
                tuple(1.0 + (base_rgb[c] - 1.0) * max(nv, 0.15) for c in range(3))
                for nv in bar_norms
            ]

            bars = ax_bar.barh(
                range(len(bar_pcts)), bar_pcts, color=bar_cols,
                edgecolor="white", linewidth=0.6, height=0.72,
            )
            ax_bar.set_yticks(range(len(bar_labels)))
            ax_bar.set_yticklabels(bar_labels, fontsize=8.5,
                                   fontfamily="sans-serif")
            ax_bar.invert_yaxis()
            ax_bar.set_xlabel("Importance (%)", fontsize=8)
            x_max = bar_pcts[0] * 1.32 if len(bar_pcts) and bar_pcts[0] > 0 else 10
            ax_bar.set_xlim(0, x_max)
            # Report how many words are actually shown (may be < top_k).
            ax_bar.set_title(
                f"Top-{len(top_words)} words", fontsize=9, color="#555", pad=6,
            )
            for bar, pct in zip(bars, bar_pcts):
                ax_bar.text(
                    bar.get_width() + bar_pcts[0] * 0.015,
                    bar.get_y() + bar.get_height() / 2,
                    f"{pct:.1f}%", va="center", fontsize=7.5, color="#444",
                )
            ax_bar.spines["top"].set_visible(False)
            ax_bar.spines["right"].set_visible(False)
            ax_bar.tick_params(axis="y", length=0)

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches="tight",
                        facecolor="white")
        return fig
# ---------------------------------------------------------------------------
# Utility
# ---------------------------------------------------------------------------
# Tokens to exclude from explanations (model artefacts, not content).
_SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]",
"", "", "", "", ""}
# Markers that signal the start of the model-generated output section.
_OUTPUT_MARKERS = {"Output", "Answer", "Response", "output", "answer", "response"}
def _find_output_token_idx(tokens: List[str]) -> Optional[int]:
"""
Find the token index where the Output/Answer section begins.
Scans for the *last* occurrence of a known output marker token
(e.g. "Output", "▁Output", "output") and returns the index of the
first content token *after* the marker (skipping ":" if present).
Returns ``None`` if no marker is found.
"""
last_marker = -1
for i, tok in enumerate(tokens):
clean = tok.replace("\u2581", "").replace("##", "").strip(":").strip()
if clean in _OUTPUT_MARKERS:
last_marker = i
if last_marker == -1:
return None
# Skip the marker itself, and an optional ":" token right after it
start = last_marker + 1
if start < len(tokens):
next_clean = tokens[start].replace("\u2581", "").replace("##", "").strip()
if next_clean == ":":
start += 1
return start
def _clean_token(tok: str) -> str:
"""Strip SentencePiece / WordPiece artefacts for display."""
return (
tok.replace("\u2581", " ")
.replace("##", "")
.strip()
or tok
)
def _extract_output_words(
words: List[Tuple[str, float]],
) -> List[Tuple[str, float]]:
"""
Return only the words that belong to the Output / Answer section.
Scans the word list for the *last* occurrence of a known output marker
(e.g. "Output", "Answer") and returns everything after it (excluding
the marker word itself and any colon that follows).
If no marker is found the full list is returned unchanged.
"""
last_marker = -1
for i, (w, _) in enumerate(words):
clean = w.strip(":").strip()
if clean in _OUTPUT_MARKERS:
last_marker = i
if last_marker == -1:
return words
# Skip the marker and an optional colon-word after it
start = last_marker + 1
if start < len(words) and words[start][0].strip() == ":":
start += 1
result = words[start:]
# Re-normalise percentages so they sum to ~100
total = sum(p for _, p in result) if result else 1.0
return [(w, p / total * 100.0) for w, p in result]
# Characters that are pure punctuation and should be glued to the
# preceding word rather than stand alone.
_PUNCT_GLUE = set('.,;:!?)]\'\"')
_PUNCT_OPEN = set('([\"\'')
def _merge_subwords(
tokens: List[str],
attributions: List[float],
) -> List[Tuple[str, float]]:
"""
Merge sub-word tokens back into whole words and sum their attributions.
- WordPiece continuations (``##xyz``) are joined to the preceding word.
- SentencePiece tokens starting with ``\u2581`` begin a new word.
- Standalone punctuation (``.``, ``,``, ``)``, …) is glued to the
preceding word so bar-chart labels stay clean.
- Opening brackets/quotes are glued to the *following* word.
- Special tokens ([CLS], [SEP], …) are dropped.
Returns a list of ``(word, importance_percent)`` sorted by position.
Percentages sum to ~100 (before any top-k truncation).
"""
words: List[str] = []
word_scores: List[float] = []
for tok, attr in zip(tokens, attributions):
if tok in _SPECIAL_TOKENS:
continue
# WordPiece continuation
if tok.startswith("##"):
if words:
words[-1] += tok[2:]
word_scores[-1] += attr
continue
# SentencePiece: strip the leading ▁
clean = tok.replace("\u2581", "")
if not clean:
continue
# Pure trailing punctuation → glue to previous word
if clean in _PUNCT_GLUE and words:
words[-1] += clean
word_scores[-1] += attr
continue
# Opening punctuation → start a new word (will be glued to next)
if clean in _PUNCT_OPEN:
words.append(clean)
word_scores.append(attr)
continue
is_new_word = tok.startswith("\u2581") or not words
if is_new_word or not words:
# If previous word is an opening bracket, glue this onto it
if words and words[-1] in _PUNCT_OPEN:
words[-1] += clean
word_scores[-1] += attr
else:
words.append(clean)
word_scores.append(attr)
else:
# sub-word continuation (no ## prefix, no \u2581 prefix)
words[-1] += clean
word_scores[-1] += attr
# Convert raw attribution sums → percentages of total
total = sum(word_scores) if word_scores else 1.0
return [(w, s / total * 100.0) for w, s in zip(words, word_scores)]