"""
Utilities to render attribution visualizations for a text-interpretability web app.
Uses Plotly for heatmaps and inline HTML for text-based visualizations.
"""
import plotly.graph_objects as go
import numpy as np
from html import escape
from typing import List, Dict, Optional, Tuple, Any
from .utils import get_color_scale, format_feature_label, matplotlib_to_plotly
# Heavy optional dependencies (shapiq / shap / numba) are deliberately NOT
# imported in this environment. The names below are bound to None so that code
# referencing them still resolves and can check `is None` before using them.
InteractionValues = None # type: ignore
sentence_plot = None
shap = None
plt = None
# NOTE(review): this style string currently contains only a newline —
# presumably inline CSS/markup that was lost; confirm the intended content.
_SPEX_TEXT_STYLE = """
"""
# RGB anchors of the diverging color ramp used by _color_for_value:
# the negative end, the positive end, and the neutral midpoint.
_NEGATIVE_RGB = (221, 19, 19)
_POSITIVE_RGB = (1, 109, 1)
_NEUTRAL_RGB = (225, 225, 223)
def _format_text_segment(value: str, preserve_blank: bool = False) -> str:
safe = escape(value or "")
safe = safe.replace("\n", "
")
if not safe and preserve_blank:
return " "
return safe or ""
def _normalize_span(span: Any, text_length: int) -> Tuple[int, int]:
if isinstance(span, dict):
start = span.get("start", span.get("begin", 0))
end = span.get("end", span.get("stop", span.get("finish", 0)))
else:
start, end = span
try:
start_i = int(start)
except (TypeError, ValueError):
start_i = 0
try:
end_i = int(end)
except (TypeError, ValueError):
end_i = start_i
start_i = max(0, min(text_length, start_i))
end_i = max(start_i, min(text_length, end_i))
return start_i, end_i
def _color_for_value(value: float, max_abs: float) -> Tuple[str, str, str]:
    """Map an attribution value onto the negative/neutral/positive color ramp.

    Args:
        value: The score to colorize.
        max_abs: Normalization scale. When it is ``<= 0`` the neutral color is
            used and the sign is reported as ``"neutral"``.

    Returns:
        A ``(hex_color, background, sign)`` triple: a ``#rrggbb`` string, an
        ``rgba(...)`` string whose alpha runs from 0.25 to 0.70 with the
        normalized magnitude, and one of ``"positive"``, ``"negative"`` or
        ``"neutral"``.
    """

    def _blend(origin, target, fraction):
        # Per-channel linear interpolation, rounded to an integer channel.
        return tuple(
            int(round(lo + (hi - lo) * fraction))
            for lo, hi in zip(origin, target)
        )

    if max_abs <= 0:
        channels = _NEUTRAL_RGB
        sign = "neutral"
    else:
        # Clamp to [-1, 1], then shift to a [0, 1] position along the ramp.
        scaled = max(-1.0, min(1.0, value / max_abs))
        position = (scaled + 1.0) / 2.0
        if position < 0.5:
            # Lower half: negative color fading into neutral.
            channels = _blend(_NEGATIVE_RGB, _NEUTRAL_RGB, position * 2.0)
        else:
            # Upper half: neutral fading into the positive color.
            channels = _blend(_NEUTRAL_RGB, _POSITIVE_RGB, (position - 0.5) * 2.0)
        if scaled > 0:
            sign = "positive"
        elif scaled < 0:
            sign = "negative"
        else:
            sign = "neutral"

    red, green, blue = channels
    strength = min(1.0, abs(value) / max_abs) if max_abs > 0 else 0.0
    opacity = 0.25 + 0.45 * strength
    return (
        f"#{red:02x}{green:02x}{blue:02x}",
        f"rgba({red}, {green}, {blue}, {opacity:.3f})",
        sign,
    )
def _build_sentence_interaction_values(values: List[float], method: str) -> Optional[InteractionValues]:
    """Wrap per-feature scores in an order-one ``InteractionValues`` object.

    Args:
        values: One score per feature (player), in feature order.
        method: Selects the index label passed on: ``"shapley"`` gives
            ``"SV"``, ``"influence"`` gives ``"IV"``, anything else ``"BV"``.

    Returns:
        The constructed object, or ``None`` when ``InteractionValues`` is
        unavailable (bound to ``None``) or ``values`` is empty.
    """
    if InteractionValues is None:
        return None

    player_count = len(values)
    if not player_count:
        return None

    if method == "shapley":
        index_label = "SV"
    elif method == "influence":
        index_label = "IV"
    else:
        index_label = "BV"

    # Each single-player coalition (i,) maps to position i in the value array.
    singleton_lookup = {}
    for position in range(player_count):
        singleton_lookup[(position,)] = position

    return InteractionValues(
        values=np.array(values, dtype=float),
        index=index_label,
        min_order=1,
        max_order=1,
        n_players=player_count,
        interaction_lookup=singleton_lookup,
        estimated=False,
        baseline_value=0.0,
    )
# def create_attribution_heatmap(
# features: List[str],
# attributions: Dict[str, float],
# method: str = "shapley",
# title: Optional[str] = None
# ) -> go.Figure:
# """
# Create a feature-level attribution heatmap.
# Args:
# features: Ordered feature list (from mask_text or tokenizer).
# attributions: Mapping from feature -> attribution value
# (e.g., from mobius_to_shapley/banzhaf).
# method: "shapley" or "banzhaf" (used in the caption/labeling).
# title: Optional chart title.
# Returns:
# A Plotly Figure object.
# Example:
# attrs = compute_attributions(model, context, answer, "shapley")
# fig = create_attribution_heatmap(attrs["features"], attrs["values"], "shapley")
# """
# values = np.array([attributions.get(f, 0.0) for f in features], dtype=float)
# if sentence_plot is not None:
# iv = _build_sentence_interaction_values(values.tolist(), method)
# if iv is not None:
# result = sentence_plot(
# iv,
# words=features,
# show=False,
# chars_per_line=80,
# )
# if result is not None:
# fig, _ = result
# return matplotlib_to_plotly(
# fig,
# title=title or f"{method.title()} token attributions",
# height=max(300, 30 * len(features)),
# )
# if shap is not None and plt is not None:
# explanation = shap.Explanation(
# values=np.array([values]),
# base_values=np.zeros(1),
# data=np.array([features], dtype=object),
# feature_names=features,
# )
# try:
# fig, ax = plt.subplots(
# figsize=(4, max(4, len(features) * 0.25)),
# constrained_layout=True,
# )
# shap.plots.heatmap(explanation, show=False, ax=ax)
# fig.canvas.draw()
# image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
# image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
# plt.close(fig)
# plotly_fig = go.Figure(go.Image(z=image))
# plotly_fig.update_xaxes(visible=False)
# plotly_fig.update_yaxes(visible=False)
# plotly_fig.update_layout(
# title=title or f"{method.title()} token attributions (SHAP heatmap)",
# margin=dict(l=0, r=0, t=60, b=0),
# )
# return plotly_fig
# except ValueError:
# plt.close("all")
# order = np.argsort(-np.abs(values))
# sorted_features = [features[i] for i in order]
# sorted_values = values[order]
# max_abs = float(np.max(np.abs(sorted_values))) if sorted_values.size else 1.0
# max_abs = max(max_abs, 1e-6)
# colorscale = get_color_scale("shapley" if method == "shapley" else method)
# heatmap = go.Heatmap(
# z=sorted_values[:, None],
# x=["Attribution"],
# y=[format_feature_label(f, max_length=30) for f in sorted_features],
# colorscale=colorscale,
# zmid=0.0,
# zmin=-max_abs,
# zmax=max_abs,
# colorbar=dict(title=f"{method.title()} value"),
# hovertemplate="%{y}
%{x}: %{z:.4f}
%{x}: %{z:.4f}
{fallback}
Normalized by max |value| = {max_abs:.4f}. Hover tokens for exact scores.
' "{_format_text_segment(source_text)}
" "