Qingpeng Kong
clean initial state
3e72399
"""
Utilities to render attribution visualizations for a text-interpretability web app.
Uses Plotly for heatmaps and inline HTML for text-based visualizations.
"""
import plotly.graph_objects as go
import numpy as np
from html import escape
from typing import List, Dict, Optional, Tuple, Any
from .utils import get_color_scale, format_feature_label, matplotlib_to_plotly
# Optional heavy dependencies (shapiq / shap / numba) are deliberately NOT
# imported in this environment. Each name is bound to None so that code which
# references it still resolves; the code below tests for None before using
# one, so the optional rendering paths are skipped while these stay None.
InteractionValues = None # type: ignore
sentence_plot = None  # create_attribution_heatmap only calls this when it is not None
shap = None
plt = None
# Stylesheet emitted ahead of the text-view markup built by
# create_interactive_text_heatmap. Every selector is scoped under a `.spex-*`
# class. The legend gradient stops (#dd1313, #e1e1df, #016d01) are the hex
# forms of _NEGATIVE_RGB, _NEUTRAL_RGB and _POSITIVE_RGB; keep them in sync.
_SPEX_TEXT_STYLE = """
<style id="spex-text-view-style">
.spex-text-view {
--spex-bg: #f7f5f2;
--spex-border: #e3e3ec;
--spex-card-bg: #ffffff;
--spex-card-shadow: 0 14px 30px rgba(32, 25, 40, 0.08);
--spex-text: #3d2c36;
font-family: "Segoe UI", "Helvetica Neue", Arial, sans-serif;
background: var(--spex-bg);
border: 1px solid var(--spex-border);
border-radius: 18px;
padding: 20px;
display: flex;
flex-wrap: wrap;
gap: 18px;
}
.spex-text-card {
flex: 3 1 520px;
background: var(--spex-card-bg);
border: 1px solid var(--spex-border);
border-radius: 18px;
padding: 18px;
box-shadow: var(--spex-card-shadow);
}
.spex-card-header {
display: flex;
justify-content: space-between;
align-items: flex-end;
margin-bottom: 12px;
gap: 8px;
}
.spex-card-title {
font-size: 18px;
font-weight: 600;
color: var(--spex-text);
}
.spex-card-subtitle {
font-size: 13px;
color: #7f6f86;
}
.spex-token-grid {
display: block;
font-size: 16px;
line-height: 2;
color: #111111;
word-break: break-word;
white-space: pre-wrap;
}
.spex-token {
display: inline-flex;
flex-direction: column;
align-items: center;
justify-content: center;
vertical-align: baseline;
padding: 2px 6px;
margin: 0 2px;
border-radius: 12px;
border: 1px solid transparent;
background: rgba(225, 225, 223, 0.45);
box-decoration-break: clone;
transition: box-shadow 0.15s ease, background 0.15s ease;
}
.spex-token:hover {
box-shadow: 0 8px 16px rgba(0, 0, 0, 0.12);
}
.spex-token-score {
display: block;
font-size: 11px;
font-weight: 600;
color: #111111;
letter-spacing: 0.08em;
text-transform: uppercase;
margin-bottom: 2px;
}
.spex-token-text {
font-size: inherit;
color: #111111;
white-space: inherit;
}
.spex-token-plain {
color: #111111;
white-space: pre-wrap;
}
.spex-side-panel {
flex: 1 1 220px;
display: flex;
flex-direction: column;
gap: 12px;
}
.spex-side-card {
background: #fefcf8;
border: 1px dashed var(--spex-border);
border-radius: 16px;
padding: 16px;
}
.spex-side-card strong {
display: block;
font-size: 15px;
color: var(--spex-text);
margin-bottom: 6px;
}
.spex-legend-bar {
display: flex;
align-items: center;
gap: 8px;
margin: 12px 0;
}
.spex-legend-label {
font-size: 12px;
color: #6f5a72;
text-transform: uppercase;
letter-spacing: 0.08em;
}
.spex-legend-gradient {
flex: 1;
height: 10px;
border-radius: 999px;
background: linear-gradient(90deg, #dd1313, #e1e1df, #016d01);
}
.spex-legend-note {
font-size: 12px;
color: #6f5a72;
margin: 0;
}
.spex-raw-text {
flex-basis: 100%;
background: #ffffff;
border: 1px solid var(--spex-border);
border-radius: 16px;
padding: 16px;
box-shadow: 0 10px 18px rgba(32, 25, 40, 0.06);
}
.spex-raw-text strong {
display: block;
font-size: 14px;
color: #6f5a72;
text-transform: uppercase;
letter-spacing: 0.08em;
margin-bottom: 6px;
}
.spex-raw-text p {
margin: 0;
font-size: 13px;
line-height: 1.6;
white-space: pre-wrap;
color: #4a3b4e;
}
.spex-empty {
flex-basis: 100%;
text-align: center;
font-size: 14px;
color: #7f6f86;
}
@media (max-width: 900px) {
.spex-text-card,
.spex-side-panel {
flex: 1 1 100%;
}
}
</style>
"""
# Anchor colours, as (R, G, B) tuples, of the diverging palette that
# _color_for_value interpolates between: the most negative value, the most
# positive value, and zero. They equal the gradient stops hard-coded in the
# `.spex-legend-gradient` CSS rule (#dd1313, #016d01, #e1e1df).
_NEGATIVE_RGB = (221, 19, 19)
_POSITIVE_RGB = (1, 109, 1)
_NEUTRAL_RGB = (225, 225, 223)
def _format_text_segment(value: str, preserve_blank: bool = False) -> str:
safe = escape(value or "")
safe = safe.replace("\n", "<br />")
if not safe and preserve_blank:
return "&nbsp;"
return safe or ""
def _normalize_span(span: Any, text_length: int) -> Tuple[int, int]:
if isinstance(span, dict):
start = span.get("start", span.get("begin", 0))
end = span.get("end", span.get("stop", span.get("finish", 0)))
else:
start, end = span
try:
start_i = int(start)
except (TypeError, ValueError):
start_i = 0
try:
end_i = int(end)
except (TypeError, ValueError):
end_i = start_i
start_i = max(0, min(text_length, start_i))
end_i = max(start_i, min(text_length, end_i))
return start_i, end_i
def _color_for_value(value: float, max_abs: float) -> Tuple[str, str, str]:
    """Pick the chip colours for one attribution value.

    The value is normalised to ``[-1, 1]`` by ``max_abs`` and mapped onto a
    diverging palette: ``_NEGATIVE_RGB`` at -1, ``_NEUTRAL_RGB`` at 0 and
    ``_POSITIVE_RGB`` at +1, with linear interpolation in between.

    Args:
        value: Attribution score to colour.
        max_abs: Magnitude that corresponds to a fully saturated colour.
            If it is not a positive number, the result is neutral.

    Returns:
        A ``(hex_color, background, sign)`` tuple. ``hex_color`` is an opaque
        ``#rrggbb`` string, ``background`` is the same colour as an ``rgba()``
        string whose alpha grows linearly from 0.25 to 0.70 with the
        normalised magnitude, and ``sign`` is ``"positive"``, ``"negative"``
        or ``"neutral"``.
    """
    ratio = value / max_abs if max_abs > 0 else 0.0
    # NaN is the only float that differs from itself. Without this guard a NaN
    # (from a NaN value, or inf / inf) would pass through min(1.0, nan) == 1.0
    # and be drawn as a fully saturated positive chip.
    if ratio != ratio:
        ratio = 0.0
    norm = max(-1.0, min(1.0, ratio))

    # Map [-1, 1] onto [0, 1]; the lower half blends negative -> neutral and
    # the upper half blends neutral -> positive.
    t = (norm + 1.0) / 2.0
    if t < 0.5:
        low, high, local = _NEGATIVE_RGB, _NEUTRAL_RGB, t * 2.0
    else:
        low, high, local = _NEUTRAL_RGB, _POSITIVE_RGB, (t - 0.5) * 2.0
    r, g, b = (int(round(lo + (hi - lo) * local)) for lo, hi in zip(low, high))

    if norm > 0:
        sign = "positive"
    elif norm < 0:
        sign = "negative"
    else:
        sign = "neutral"

    hex_color = f"#{r:02x}{g:02x}{b:02x}"
    alpha = 0.25 + 0.45 * abs(norm)
    background = f"rgba({r}, {g}, {b}, {alpha:.3f})"
    return hex_color, background, sign
def _build_sentence_interaction_values(values: List[float], method: str) -> Optional[InteractionValues]:
    """Wrap per-token values in an ``InteractionValues`` object, if available.

    Args:
        values: One attribution value per token ("player").
        method: ``"shapley"`` selects index ``"SV"``, ``"influence"`` selects
            ``"IV"``, and anything else selects ``"BV"``.

    Returns:
        An ``InteractionValues`` instance restricted to order-1 terms, where
        the lookup maps each singleton ``(i,)`` to position ``i``. Returns
        ``None`` when the ``InteractionValues`` class is unavailable or
        ``values`` is empty.
    """
    if InteractionValues is None or len(values) == 0:
        return None

    if method == "shapley":
        index_code = "SV"
    elif method == "influence":
        index_code = "IV"
    else:
        index_code = "BV"

    player_count = len(values)
    singleton_positions = {(player,): player for player in range(player_count)}
    return InteractionValues(
        values=np.array(values, dtype=float),
        index=index_code,
        max_order=1,
        n_players=player_count,
        min_order=1,
        interaction_lookup=singleton_positions,
        estimated=False,
        baseline_value=0.0,
    )
# def create_attribution_heatmap(
# features: List[str],
# attributions: Dict[str, float],
# method: str = "shapley",
# title: Optional[str] = None
# ) -> go.Figure:
# """
# Create a feature-level attribution heatmap.
# Args:
# features: Ordered feature list (from mask_text or tokenizer).
# attributions: Mapping from feature -> attribution value
# (e.g., from mobius_to_shapley/banzhaf).
# method: "shapley" or "banzhaf" (used in the caption/labeling).
# title: Optional chart title.
# Returns:
# A Plotly Figure object.
# Example:
# attrs = compute_attributions(model, context, answer, "shapley")
# fig = create_attribution_heatmap(attrs["features"], attrs["values"], "shapley")
# """
# values = np.array([attributions.get(f, 0.0) for f in features], dtype=float)
# if sentence_plot is not None:
# iv = _build_sentence_interaction_values(values.tolist(), method)
# if iv is not None:
# result = sentence_plot(
# iv,
# words=features,
# show=False,
# chars_per_line=80,
# )
# if result is not None:
# fig, _ = result
# return matplotlib_to_plotly(
# fig,
# title=title or f"{method.title()} token attributions",
# height=max(300, 30 * len(features)),
# )
# if shap is not None and plt is not None:
# explanation = shap.Explanation(
# values=np.array([values]),
# base_values=np.zeros(1),
# data=np.array([features], dtype=object),
# feature_names=features,
# )
# try:
# fig, ax = plt.subplots(
# figsize=(4, max(4, len(features) * 0.25)),
# constrained_layout=True,
# )
# shap.plots.heatmap(explanation, show=False, ax=ax)
# fig.canvas.draw()
# image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
# image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
# plt.close(fig)
# plotly_fig = go.Figure(go.Image(z=image))
# plotly_fig.update_xaxes(visible=False)
# plotly_fig.update_yaxes(visible=False)
# plotly_fig.update_layout(
# title=title or f"{method.title()} token attributions (SHAP heatmap)",
# margin=dict(l=0, r=0, t=60, b=0),
# )
# return plotly_fig
# except ValueError:
# plt.close("all")
# order = np.argsort(-np.abs(values))
# sorted_features = [features[i] for i in order]
# sorted_values = values[order]
# max_abs = float(np.max(np.abs(sorted_values))) if sorted_values.size else 1.0
# max_abs = max(max_abs, 1e-6)
# colorscale = get_color_scale("shapley" if method == "shapley" else method)
# heatmap = go.Heatmap(
# z=sorted_values[:, None],
# x=["Attribution"],
# y=[format_feature_label(f, max_length=30) for f in sorted_features],
# colorscale=colorscale,
# zmid=0.0,
# zmin=-max_abs,
# zmax=max_abs,
# colorbar=dict(title=f"{method.title()} value"),
# hovertemplate="%{y}<br>%{x}: %{z:.4f}<extra></extra>",
# showscale=True,
# text=[f"{v:.3f}" for v in sorted_values],
# texttemplate="%{text}",
# textfont={"color": "black"},
# )
# fig = go.Figure(data=[heatmap])
# fig.update_layout(
# title=title or f"{method.title()} token attributions",
# xaxis=dict(showticklabels=False),
# yaxis=dict(autorange="reversed"),
# margin=dict(l=120, r=40, t=60, b=40),
# height=max(300, 20 * len(sorted_features)),
# )
# return fig
# --- Active implementation (the legacy version is kept above, commented out) ---
def create_attribution_heatmap(
    features: List[str],
    attributions: Dict[str, float],
    method: str = "shapley",
    title: Optional[str] = None,
) -> go.Figure:
    """Create a single-column heatmap of per-feature attribution values.

    Args:
        features: Ordered feature labels; this order is kept on the y-axis.
        attributions: Mapping from feature label to attribution value.
            Features missing from the mapping are plotted as 0.0.
        method: Attribution method name, used in the titles and passed to
            ``get_color_scale``.
        title: Optional chart title; a default is derived from ``method``.

    Returns:
        A Plotly figure. It is empty when ``features`` is empty. If
        ``sentence_plot`` is available and yields a result, that figure is
        returned (converted through ``matplotlib_to_plotly``) instead of the
        plain heatmap.
    """
    raw_values = np.array([attributions.get(f, 0.0) for f in features], dtype=float)
    if raw_values.size == 0:
        return go.Figure()

    method_label = method.title()
    default_title = f"{method_label} token attributions"

    # Very small attributions (largest magnitude below 1e-4) would all show as
    # 0.0000 in the hover text, so rescale them so the largest magnitude
    # becomes 1 and report the factor in the colour-bar title.
    max_abs = float(np.max(np.abs(raw_values)))
    scale = 1.0
    colorbar_title = f"{method_label} value"
    if 0.0 < max_abs < 1e-4:
        scale = 1.0 / max_abs
        colorbar_title = f"{method_label} value (×{scale:.0e})"
    values = raw_values * scale

    # Optional path: use sentence_plot when that dependency is present.
    if sentence_plot is not None:
        iv = _build_sentence_interaction_values(values.tolist(), method)
        if iv is not None:
            result = sentence_plot(
                iv,
                words=features,
                show=False,
                chars_per_line=80,
            )
            if result is not None:
                fig, _ = result
                return matplotlib_to_plotly(
                    fig,
                    title=title or default_title,
                    height=max(300, 30 * len(features)),
                )

    # Plain Plotly heatmap. The colour range is the 95th percentile of the
    # magnitudes, so one outlier does not wash out every other cell.
    abs_vals = np.abs(values)
    vmax = float(np.percentile(abs_vals, 95))
    if vmax <= 0.0:
        # At least 95% of the values are zero: the percentile carries no
        # information, so fall back to the largest magnitude rather than
        # saturating every non-zero cell against the 1e-6 floor.
        vmax = float(abs_vals.max())
    vmax = max(vmax, 1e-6)  # keep a non-degenerate range when all values are 0

    colorscale = get_color_scale("shapley" if method == "shapley" else method)
    heatmap = go.Heatmap(
        z=values[:, None],
        x=["Attribution"],
        y=[format_feature_label(f, max_length=60) for f in features],
        colorscale=colorscale,
        zmid=0.0,
        zmin=-vmax,
        zmax=vmax,
        colorbar=dict(title=colorbar_title),
        hovertemplate="%{y}<br>%{x}: %{z:.4f}<extra></extra>",
        showscale=True,
    )
    fig = go.Figure(data=[heatmap])
    fig.update_layout(
        title=title or default_title,
        xaxis=dict(showticklabels=False),
        # Reversed so the first feature is drawn at the top.
        yaxis=dict(autorange="reversed"),
        margin=dict(l=140, r=40, t=60, b=40),
        height=max(320, 22 * len(features)),
    )
    return fig
def create_interactive_text_heatmap(
    text: str,
    feature_spans: List[Any],  # list of (start, end) or dict spans
    attributions: List[Any],
    method: str = "shapley",
) -> str:
    """
    Render a Spectral Explain–style text view with token chips, legend, and raw text.

    Args:
        text: Original text that generated the attributions.
        feature_spans: Character spans identifying each token/feature.
        attributions: Numeric attribution values aligned with feature_spans.
            Entries that are not numeric, or not finite, are treated as 0.0.
        method: Attribution method label.

    Returns:
        Styled HTML (stylesheet plus markup). Tokens are laid out in text
        order, and text not covered by any span is rendered as plain text.
        ``data-token-index`` keeps each token's 1-based position in
        ``feature_spans``.

    Raises:
        ValueError: If ``feature_spans`` and ``attributions`` differ in length.
    """
    if len(feature_spans) != len(attributions):
        raise ValueError("feature_spans and attributions must have the same length")

    source_text = text or ""
    text_len = len(source_text)

    tokens: List[Dict[str, Any]] = []
    for idx, (span, raw_value) in enumerate(zip(feature_spans, attributions), start=1):
        start, end = _normalize_span(span, text_len)
        try:
            value = float(raw_value)
        except (TypeError, ValueError, OverflowError):
            value = 0.0
        # NaN / infinity would poison max_abs below and print as "+nan" in the
        # tooltip, so treat them like any other unusable score.
        if not np.isfinite(value):
            value = 0.0
        tokens.append(
            {
                "index": idx,
                "text": source_text[start:end],
                "value": value,
                "start": start,
                "end": end,
            }
        )

    if not tokens:
        fallback = _format_text_segment(source_text) or "No text available."
        return (
            f"{_SPEX_TEXT_STYLE}"
            '<div class="spex-text-view">'
            '<div class="spex-empty">No feature spans were provided for this example.</div>'
            f'<div class="spex-raw-text"><strong>Raw text</strong><p>{fallback}</p></div>'
            "</div>"
        )

    # Fall back to 1.0 so an all-zero example still has a valid divisor.
    max_abs = max(abs(token["value"]) for token in tokens) or 1.0

    method_label = (method or "attribution").title()
    # `method` is caller-supplied, so escape it before placing it in markup.
    # The tooltip is escaped as a whole and therefore uses the raw label.
    method_label_html = escape(method_label)

    flow_parts: List[str] = []
    cursor = 0
    # Walk the tokens in text order. The sort is stable, so spans that are
    # already ordered render exactly as given.
    for token in sorted(tokens, key=lambda item: item["start"]):
        start = token["start"]
        end = token["end"]
        value = token["value"]
        if cursor < start:
            gap = _format_text_segment(source_text[cursor:start], preserve_blank=True)
            flow_parts.append(f'<span class="spex-token-plain">{gap}</span>')
        color_hex, background, sign = _color_for_value(value, max_abs)
        tooltip = escape(f"{method_label} · chars [{start}:{end}] · {value:+.4f}")
        text_html = _format_text_segment(token["text"], preserve_blank=True)
        token_index = token["index"]
        flow_parts.append(
            f'<span class="spex-token spex-token--{sign}" '
            f'data-token-index="{token_index}" '
            f'data-attr="{value:.6f}" '
            f'style="background-color:{background}; border-color:{color_hex};" '
            f'title="{tooltip}">'
            f'<span class="spex-token-text">{text_html}</span>'
            "</span>"
        )
        # Never move backwards: an overlapping or empty span must not cause
        # text that was already emitted to be rendered again.
        cursor = max(cursor, end)

    if cursor < text_len:
        trailing = _format_text_segment(source_text[cursor:], preserve_blank=True)
        flow_parts.append(f'<span class="spex-token-plain">{trailing}</span>')
    flow_html = "".join(flow_parts) or "&nbsp;"

    legend = (
        '<div class="spex-side-card">'
        f"<strong>{method_label_html} legend</strong>"
        '<div class="spex-legend-bar">'
        '<span class="spex-legend-label">Negative</span>'
        '<div class="spex-legend-gradient"></div>'
        '<span class="spex-legend-label">Positive</span>'
        "</div>"
        f'<p class="spex-legend-note">Normalized by max |value| = {max_abs:.4f}. Hover tokens for exact scores.</p>'
        "</div>"
    )

    raw_text_block = ""
    if source_text:
        raw_text_block = (
            '<div class="spex-raw-text">'
            "<strong>Raw text</strong>"
            f"<p>{_format_text_segment(source_text)}</p>"
            "</div>"
        )

    return (
        f"{_SPEX_TEXT_STYLE}"
        '<div class="spex-text-view">'
        '<div class="spex-text-card">'
        '<div class="spex-card-header">'
        "<div>"
        '<div class="spex-card-title">Context</div>'
        f'<div class="spex-card-subtitle">{method_label_html} token attributions</div>'
        "</div>"
        f'<div class="spex-card-subtitle">Tokens: {len(tokens)}</div>'
        "</div>"
        f'<div class="spex-token-grid">{flow_html}</div>'
        "</div>"
        f'<div class="spex-side-panel">{legend}</div>'
        f"{raw_text_block}"
        "</div>"
    )
def normalize_attributions(
    attributions: Dict[str, float],
    method: str = "minmax"
) -> Dict[str, float]:
    """
    Rescale attribution values for display.

    Args:
        attributions: Raw attribution dict {feature: value}.
        method: "zscore" standardises to zero mean and unit standard
            deviation; any other value selects min-max scaling onto [-1, 1].

    Returns:
        A new dict with the same keys, in the same order, holding the
        rescaled floats. Empty input gives an empty dict, and input whose
        values are all equal maps every key to 0.0.
    """
    if not attributions:
        return {}

    raw = np.array(list(attributions.values()), dtype=float)

    if method == "zscore":
        # A zero spread would divide by zero; dividing by 1 leaves zeros.
        spread = float(raw.std()) or 1.0
        scaled = (raw - float(raw.mean())) / spread
    else:
        lowest, highest = float(raw.min()), float(raw.max())
        if highest - lowest == 0:
            scaled = np.zeros_like(raw)
        else:
            # Scale onto [0, 1], then stretch to [-1, 1] for diverging scales.
            scaled = (raw - lowest) / (highest - lowest) * 2 - 1

    return dict(zip(attributions.keys(), map(float, scaled)))