"""
Plotting utilities for visualizing higher-order feature interactions in a
text-interpretability web app.
Inputs are assumed to come from attribution utilities such as
`shapley_interactions(...)` or `banzhaf_interactions(...)`.
Outputs are Plotly Figure objects that can be rendered directly in Gradio/UI.
"""
import json
import uuid
from collections import defaultdict
from html import escape
import plotly.graph_objects as go
import numpy as np
from typing import List, Tuple, Dict, Optional
try: # optional dependency
from shapiq.interaction_values import InteractionValues # type: ignore
from shapiq.plot import bar_plot # type: ignore
except Exception: # pragma: no cover
InteractionValues = None
bar_plot = None
from .utils import format_feature_label, get_color_scale, create_legend, matplotlib_to_plotly
# Style blocks for the HTML token view and network view (presumably CSS,
# inferred from the names).
# NOTE(review): both strings currently contain only a single newline, i.e. no
# style rules at all — confirm whether the styles are meant to live here.
_TOKEN_VIEW_STYLE = """
"""
_NETWORK_VIEW_STYLE = """
"""
def _interactions_to_shapiq(
    interactions: List[Tuple[Tuple[str, ...], float]],
    method: str,
    order: int,
) -> Tuple[Optional["InteractionValues"], List[str]]:
    """
    Convert name-keyed interactions into a shapiq ``InteractionValues`` object.

    Feature names become integer player indices in first-seen order, and each
    interaction is keyed by its sorted index tuple. Entries with an empty
    feature tuple are skipped.

    Args:
        interactions: ``[((feature_name, ...), value), ...]``.
        method: ``"shapley"`` selects the SV/SII index; any other value
            selects BV/BII.
        order: Interaction order. 1 selects SV/BV, anything else SII/BII.
            Used as ``max_order``, and as ``min_order`` when it is >= 2.

    Returns:
        ``(iv, feature_names)``, where ``feature_names[i]`` is player ``i``.
        ``iv`` is None when shapiq is not importable, when ``interactions`` is
        empty, or when no entry has a non-empty feature tuple.
    """
    if InteractionValues is None or not interactions:
        return None, []

    feature_index: Dict[str, int] = {}
    feature_names: List[str] = []

    def _idx(name: str) -> int:
        # Assign the next player index the first time a name is seen.
        if name not in feature_index:
            feature_index[name] = len(feature_names)
            feature_names.append(name)
        return feature_index[name]

    lookup: Dict[Tuple[int, ...], int] = {}
    values: List[float] = []
    for feats, value in interactions:
        if not feats:
            continue
        idx_tuple = tuple(sorted(_idx(f) for f in feats))
        slot = lookup.get(idx_tuple)
        if slot is None:
            lookup[idx_tuple] = len(values)
            values.append(float(value))
        else:
            # The same player set can recur (e.g. ("a", "b") and ("b", "a"),
            # or repeated token names). Overwrite in place so every entry of
            # `values` is referenced by exactly one lookup key; the last
            # occurrence wins.
            values[slot] = float(value)

    if not values:
        return None, feature_names

    if order == 1:
        index = "SV" if method == "shapley" else "BV"
    else:
        index = "SII" if method == "shapley" else "BII"
    # Order-1 indices start at 1; for higher orders only the requested order
    # is declared as present.
    min_order = 1 if order <= 1 else order
    iv = InteractionValues(
        values=np.array(values, dtype=float),
        index=index,
        max_order=order,
        n_players=len(feature_names),
        min_order=min_order,
        interaction_lookup=lookup,
        estimated=False,
        baseline_value=0.0,
    )
    return iv, feature_names
def plot_top_interactions(
    interactions: List[Tuple[Tuple[str, ...], float]],
    top_k: int = 10,
    order: int = 2,
    method: str = "shapley"
) -> go.Figure:
    """
    Visualize the top-k interactions as a bar chart.

    If shapiq's ``bar_plot`` is importable and the interactions convert to
    ``InteractionValues``, that matplotlib chart is converted to Plotly.
    Otherwise a native horizontal Plotly bar chart is built.

    Args:
        interactions: List of interactions from shapley_interactions/banzhaf_interactions.
            Each item is ((feature_name, ...), value).
        top_k: Number of top interactions (ranked by absolute value) to display.
            Values <= 0 select nothing and yield the empty-state figure.
        order: Interaction order (2 or 3); used in the title and shapiq index.
        method: Attribution method label ("shapley" or "banzhaf").

    Returns:
        Plotly Figure.

    Example:
        # From attribution
        mobius = run_proxyspex(set_function, features, max_order=3)
        interactions = shapley_interactions(mobius, order=2, top_k=10)
        fig = plot_top_interactions(interactions, top_k=10, order=2, method="shapley")
    """
    # A negative slice stop (e.g. [:-3]) would keep everything except the
    # weakest entries rather than the top k, so clamp top_k at zero.
    ranked = sorted(
        interactions or [], key=lambda item: abs(item[1]), reverse=True
    )[: max(top_k, 0)]
    if not ranked:
        return go.Figure().update_layout(
            title="No interactions available",
            template="plotly_white"
        )

    if bar_plot is not None:
        iv, feature_names = _interactions_to_shapiq(ranked, method, order)
        if iv is not None:
            ax = bar_plot(
                [iv],
                feature_names=feature_names,
                show=False,
                abbreviate=False,
                max_display=top_k,
                global_plot=True,
                plot_base_value=True,
            )
            fig = ax.figure if ax is not None else None
            if fig is not None:
                # NOTE(review): the matplotlib figure is not closed here; if
                # bar_plot creates it through pyplot and matplotlib_to_plotly
                # does not close it, figures may accumulate across calls — confirm.
                return matplotlib_to_plotly(
                    fig,
                    title=f"Top {len(ranked)} order-{order} {method.title()} interactions",
                )

    # Plotly merges bars that share a category name (e.g. repeated tokens),
    # so suffix repeated labels to keep one bar per interaction.
    seen: Dict[str, int] = {}
    labels: List[str] = []
    for feats, _ in ranked:
        label = format_feature_label(" · ".join(feats), max_length=50)
        seen[label] = seen.get(label, 0) + 1
        labels.append(label if seen[label] == 1 else f"{label} ({seen[label]})")

    values = [val for _, val in ranked]
    colors = ["#E24A33" if v >= 0 else "#348ABD" for v in values]
    # Reverse so the strongest interaction is drawn at the top of the chart.
    fig = go.Figure(
        data=[
            go.Bar(
                y=list(reversed(labels)),
                x=list(reversed(values)),
                orientation="h",
                marker=dict(
                    color=list(reversed(colors)),
                    line=dict(color="#f5f5f5", width=1.5),
                ),
                text=[f"{v:.3f}" for v in reversed(values)],
                textposition="outside",
                cliponaxis=False,
            )
        ]
    )
    fig.add_vline(x=0, line_dash="dash", line_color="#8c8c8c", line_width=1)
    fig.update_layout(
        title=f"Top {len(labels)} order-{order} {method.title()} interactions",
        xaxis_title="Contribution",
        yaxis_title="Feature group",
        template="plotly_white",
        hovermode="y",
        legend=create_legend(method, order),
        margin=dict(l=20, r=20, t=70, b=20),
        height=max(360, 40 * len(labels)),
    )
    return fig
def plot_interaction_network(
    features: List[str],
    interactions: List[Tuple[Tuple[str, ...], float]],
    threshold: float = 0.1
) -> go.Figure:
    """
    Visualize pairwise interactions as a network graph.

    Features are spaced evenly on the unit circle in the given order. Each
    qualifying interaction is a straight edge, red when the value is >= 0 and
    blue otherwise. Edge width grows with |value| relative to the strongest
    drawn edge.

    Args:
        features: The full ordered feature list.
        interactions: Pairwise interactions [((feat_i, feat_j), value), ...].
            Items whose feature tuple is not exactly a pair are ignored.
        threshold: Only show interactions whose absolute value is at least
            this threshold.

    Returns:
        Plotly Figure (network-style visualization).
    """
    if not interactions:
        return go.Figure().update_layout(
            title="No pairwise interactions available",
            template="plotly_white"
        )

    # Keep only true pairs; a higher-order tuple would otherwise break the
    # two-way unpacking further down.
    filtered = [
        (feats[0], feats[1], value)
        for feats, value in interactions
        if len(feats) == 2 and abs(value) >= threshold
    ]
    if not filtered:
        return go.Figure().update_layout(
            title=f"No interactions exceed threshold ({threshold})",
            template="plotly_white"
        )

    n = len(features)
    angles = np.linspace(0, 2 * np.pi, max(n, 1), endpoint=False)
    positions = {
        feat: (float(np.cos(theta)), float(np.sin(theta)))
        for feat, theta in zip(features, angles)
    }

    # Only edges whose endpoints are both known features can be drawn, so
    # normalise widths over those edges.
    drawable = [
        (feat_a, feat_b, value)
        for feat_a, feat_b, value in filtered
        if feat_a in positions and feat_b in positions
    ]
    max_abs = max((abs(value) for _, _, value in drawable), default=0.0) or 1.0

    traces = []
    for feat_a, feat_b, value in drawable:
        (x0, y0), (x1, y1) = positions[feat_a], positions[feat_b]
        color = "#d73027" if value >= 0 else "#4575b4"
        width = 1 + 4 * abs(value) / max_abs
        label = f"{feat_a} <-> {feat_b}: {value:.3f}"
        traces.append(
            go.Scatter(
                x=[x0, x1],
                y=[y0, y1],
                mode="lines",
                line=dict(color=color, width=width),
                hoverinfo="text",
                text=[label, label],
                showlegend=False,
            )
        )

    node_trace = go.Scatter(
        x=[positions[f][0] for f in features],
        y=[positions[f][1] for f in features],
        mode="markers+text",
        marker=dict(size=14, color="#4a4a4a", line=dict(width=2, color="#ffffff")),
        text=[format_feature_label(f, 18) for f in features],
        # Show the full (untruncated) feature name on hover.
        hovertext=list(features),
        textposition="bottom center",
        hoverinfo="text",
    )
    fig = go.Figure(data=traces + [node_trace])
    fig.update_layout(
        title="Pairwise interaction network",
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        template="plotly_white",
        showlegend=False,
        margin=dict(l=20, r=20, t=60, b=20),
    )
    return fig
def plot_interaction_matrix(
    features: List[str],
    interactions: List[Tuple[Tuple[int, int], float]]
) -> go.Figure:
    """
    Visualize pairwise interactions as a symmetric matrix heatmap.

    For each ((i, j), value), both cell (i, j) and cell (j, i) are set to
    value. Pairs with an index outside [0, len(features)) are ignored, and
    cells with no interaction stay 0. When a pair repeats, the last value wins.

    The axes are positional (0..n-1) and use the formatted feature names as
    tick labels. This keeps a row and column for every repeated token instead
    of letting Plotly merge identical category names.

    Args:
        features: Ordered feature list (labels for axes).
        interactions: List of ((i, j), value) where i/j are feature indices.

    Returns:
        Plotly Figure (heatmap).
    """
    n = len(features)
    if n == 0:
        return go.Figure().update_layout(
            title="Interaction matrix (no features)",
            template="plotly_white"
        )

    matrix = np.zeros((n, n), dtype=float)
    for (i, j), value in interactions:
        if 0 <= i < n and 0 <= j < n:
            matrix[i, j] = value
            matrix[j, i] = value

    labels = [format_feature_label(f, 18) for f in features]
    ticks = list(range(n))
    # Per-cell hover text: row label vs column label, as z[row][col].
    hover_text = [[f"{labels[row]} vs {labels[col]}" for col in ticks] for row in ticks]
    heatmap = go.Heatmap(
        z=matrix,
        x=ticks,
        y=ticks,
        text=hover_text,
        colorscale=get_color_scale("shapley"),
        colorbar=dict(title="Interaction value"),
        hovertemplate=" %{text} value=%{z:.2e}",
    )
    fig = go.Figure(data=[heatmap])
    fig.update_layout(
        title="Pairwise interaction matrix",
        xaxis=dict(side="top", tickmode="array", tickvals=ticks, ticktext=labels),
        yaxis=dict(autorange="reversed", tickmode="array", tickvals=ticks, ticktext=labels),
        template="plotly_white",
        margin=dict(l=80, r=20, t=80, b=80),
    )
    return fig
def plot_3rd_order_interactions(
    interactions: List[Tuple[Tuple[str, ...], float]],
    top_k: int = 5
) -> go.Figure:
    """
    Visualize third-order (triplet) interactions.

    Draws one vertical bar per triplet for the top-k entries ranked by
    absolute value. Bars are red when the value is >= 0 and blue otherwise.

    Args:
        interactions: Triplet interactions [((f1, f2, f3), value), ...].
        top_k: Number of top triplets to plot. Values <= 0 select nothing and
            yield the empty-state figure.

    Returns:
        Plotly Figure.
    """
    # A negative slice stop (e.g. [:-3]) would keep everything except the
    # weakest entries rather than the top k, so clamp top_k at zero.
    ranked = sorted(
        interactions or [], key=lambda item: abs(item[1]), reverse=True
    )[: max(top_k, 0)]
    if not ranked:
        return go.Figure().update_layout(
            title="No third-order interactions available",
            template="plotly_white"
        )

    # Plotly merges bars that share a category name (e.g. repeated tokens),
    # so suffix repeated labels to keep one bar per triplet.
    seen: Dict[str, int] = {}
    labels: List[str] = []
    for feats, _ in ranked:
        label = format_feature_label(" · ".join(feats), max_length=40)
        seen[label] = seen.get(label, 0) + 1
        labels.append(label if seen[label] == 1 else f"{label} ({seen[label]})")

    values = [val for _, val in ranked]
    colors = ["#d73027" if v >= 0 else "#4575b4" for v in values]
    fig = go.Figure(
        data=[
            go.Bar(
                x=labels,
                y=values,
                marker=dict(color=colors),
                text=[f"{v:.3f}" for v in values],
                textposition="auto",
            )
        ]
    )
    fig.update_layout(
        title=f"Top {len(labels)} third-order interactions",
        xaxis_title="Feature triplet",
        yaxis_title="Interaction value",
        template="plotly_white",
        margin=dict(l=60, r=20, t=60, b=100),
    )
    return fig
def _token_colors(value: float, max_abs: float) -> Tuple[str, str]:
if max_abs <= 0:
return "rgba(229, 226, 240, 0.7)", "rgba(193, 189, 209, 0.9)"
norm = max(-1.0, min(1.0, value / max_abs))
if norm >= 0:
base = (222, 86, 61)
else:
base = (47, 128, 237)
norm = -norm
neutral = (245, 242, 252)
r = int(round(neutral[0] + (base[0] - neutral[0]) * norm))
g = int(round(neutral[1] + (base[1] - neutral[1]) * norm))
b = int(round(neutral[2] + (base[2] - neutral[2]) * norm))
alpha = 0.35 + 0.4 * norm
return f"rgba({r}, {g}, {b}, {alpha:.3f})", f"rgb({r}, {g}, {b})"
def create_interaction_token_view(
    features: List[str],
    feature_values: List[float],
    pairwise: List[Tuple[Tuple[str, ...], float]],
    method: str = "shapley",
    max_links: int = 5,
    layout: str = "token",
) -> str:
    """
    Render token interactions as HTML in two layers.

    1) Chip list (always visible; no JS required), produced by
       ``_render_token_chip_view`` from a name-keyed partner lookup.
    2) Interactive network (needs JS; falls back gracefully), produced by
       ``_render_interaction_network`` from index-based edges.

    Args:
        features: Ordered token/feature labels.
        feature_values: Per-feature attribution values, as any sequence
            (including a NumPy array). Missing trailing values are padded with
            0.0, and values beyond ``len(features)`` are ignored.
        pairwise: Pairwise interactions ``[((name_a, name_b), value), ...]``.
            Entries that are not pairs are skipped, and ``None`` counts as empty.
        method: Attribution method label forwarded to both renderers.
        max_links: Partner limit forwarded to the chip renderer.
        layout: Not used by this function; kept for API compatibility.

    Returns:
        The chip-view HTML followed by the network-view HTML, or a plain
        placeholder string when ``features`` is empty.
    """
    if not features:
        return "\nNo tokens available.\n"

    # Use an explicit None check: a NumPy array has no unambiguous truth value.
    values = [] if feature_values is None else list(feature_values)
    # Align values with features so surplus entries cannot skew normalisation.
    values = values[: len(features)]
    values.extend([0.0] * (len(features) - len(values)))

    # NOTE(review): repeated token names share one adjacency key and map to
    # their last index in `feature_index` — confirm this is intended.
    feature_index = {feat: idx for idx, feat in enumerate(features)}
    # Partner lookup for the always-visible chip list (by name) and edges for
    # the network view (by index), built in a single pass.
    adjacency: Dict[str, List[Tuple[str, float]]] = defaultdict(list)
    edges: List[Tuple[int, int, float]] = []
    for feats, val in pairwise or ():
        if len(feats) != 2:
            continue
        a, b = feats
        weight = float(val)
        adjacency[a].append((b, weight))
        if b != a:
            # A self-pair would otherwise be listed twice under the same key.
            adjacency[b].append((a, weight))
        a_idx = feature_index.get(a)
        b_idx = feature_index.get(b)
        if a_idx is None or b_idx is None or a_idx == b_idx:
            continue
        edges.append((a_idx, b_idx, weight))
    for partners in adjacency.values():
        partners.sort(key=lambda item: abs(item[1]), reverse=True)

    # Fallback: if no edges are provided, synthesize simple neighbor links so
    # the UI isn't empty.
    if not edges and len(features) > 1:
        for i in range(len(features) - 1):
            score = 0.5 * (values[i] + values[i + 1])
            edges.append((i, i + 1, float(score)))

    max_abs = max((abs(v) for v in values), default=0.0) or 1.0
    chips_html = _render_token_chip_view(
        features,
        values,
        adjacency,
        method,
        max_abs,
        max_links,
    )
    network_html = _render_interaction_network(features, values, edges, method)
    return chips_html + network_html
def _render_interaction_network(
features: List[str],
values: List[float],
edges: List[Tuple[int, int, float]],
method: str,
) -> str:
"""
Interactive SVG network view:
- Click a token/sentence to focus it in the center and hide unrelated nodes.
- Edge labels carry interaction scores.
- Click empty space to reset the full view.
"""
max_abs_value = max((abs(v) for v in values), default=0.0) or 1.0
nodes_payload = []
for idx, token in enumerate(features):
fill, stroke = _token_colors(values[idx], max_abs_value)
nodes_payload.append(
{
"id": idx,
"label": token,
"value": float(values[idx]),
"fill": fill,
"stroke": stroke,
}
)
edges_sorted = sorted(edges, key=lambda item: abs(item[2]), reverse=True)
edges_payload = [
{"source": int(a), "target": int(b), "weight": float(val)}
for a, b, val in edges_sorted[:120]
]
graph_id = f"interaction-net-{uuid.uuid4().hex[:8]}"
data_json = json.dumps(
{"nodes": nodes_payload, "edges": edges_payload}, ensure_ascii=False
)
method_label = escape(method.title())
def _shorten(label: str, max_len: int = 48) -> str:
label = label or ""
if len(label) <= max_len:
return label
return label[: max_len - 1] + "…"
def _static_svg() -> str:
"""
Fallback rendering (no JS): radial layout with shortened, multiline labels.
"""
width, height = 820.0, 520.0
n = len(features)
if n == 0:
return "
No interactions available.
"
def _wrap(text: str, max_len: int = 24) -> str:
words = text.split()
lines: List[str] = []
current = ""
for w in words:
if len(current) + len(w) + 1 <= max_len:
current = (current + " " + w).strip()
else:
if current:
lines.append(current)
current = w
if current:
lines.append(current)
if not lines:
lines = [text[:max_len]]
return lines[:2] # at most two lines to limit height
# Limit to top nodes to reduce clutter
top_k = min(10, n)
order = sorted(range(n), key=lambda i: abs(values[i]), reverse=True)[:top_k]
radius = min(width, height) * 0.36
cx, cy = width / 2.0, height / 2.0
positions: Dict[int, Tuple[float, float]] = {}
step = 2 * np.pi / max(len(order), 1)
for i, old_idx in enumerate(order):
angle = i * step - np.pi / 2
positions[old_idx] = (
cx + radius * float(np.cos(angle)),
cy + radius * float(np.sin(angle)),
)
filtered_edges = [
(a, b, val) for a, b, val in edges_sorted[:40]
if a in positions and b in positions
]
max_edge = max((abs(val) for _, _, val in filtered_edges), default=0.0) or 1.0
max_abs_val = max((abs(values[i]) for i in order), default=0.0) or 1.0
line_elems: List[str] = []
for a_idx, b_idx, val in filtered_edges:
x1, y1 = positions[a_idx]
x2, y2 = positions[b_idx]
color = "#d35400" if val >= 0 else "#3867d6"
width_px = 1.2 + 3.5 * (abs(val) / max_edge)
line_elems.append(
f''
)
node_elems: List[str] = []
for old_idx in order:
x, y = positions[old_idx]
fill, stroke = _token_colors(values[old_idx], max_abs_val)
lines = _wrap(features[old_idx])
node_elems.append(
f''
f''
f''
+ "".join(f'{escape(line)}' for idx, line in enumerate(lines))
+ ""
f'{values[old_idx]:+0.2f}'
f''
)
return (
"