# AttrLLM / visualization/plotting/interactions.py.backup
# Qingpeng Kong — "clean initial state" (commit 3e72399)
"""
Plotting utilities for visualizing higher-order feature interactions in a
text-interpretability web app.
Inputs are assumed to come from attribution utilities such as
`shapley_interactions(...)` or `banzhaf_interactions(...)`.
Outputs are Plotly Figure objects that can be rendered directly in Gradio/UI.
"""
import json
import uuid
from collections import defaultdict
from html import escape
import plotly.graph_objects as go
import numpy as np
from typing import List, Tuple, Dict, Optional
try: # optional dependency
from shapiq.interaction_values import InteractionValues # type: ignore
from shapiq.plot import bar_plot # type: ignore
except Exception: # pragma: no cover
InteractionValues = None
bar_plot = None
from .utils import format_feature_label, get_color_scale, create_legend, matplotlib_to_plotly
_TOKEN_VIEW_STYLE = """
<style>
.token-interaction-view {
--token-border: #e0d9f0;
--token-active: #7048e8;
--token-bg: #faf8ff;
--panel-bg: linear-gradient(135deg, #f6f3ff 0%, #f0f4ff 100%);
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "Helvetica Neue", Arial, sans-serif;
background: var(--token-bg);
border: 2px solid var(--token-border);
border-radius: 20px;
padding: 24px;
display: flex;
flex-direction: column;
gap: 20px;
box-shadow: 0 4px 20px rgba(88, 60, 140, 0.08);
margin: 16px 0;
}
.token-interaction-panel {
background: var(--panel-bg);
border-radius: 16px;
padding: 16px 20px;
border: 1.5px solid rgba(112, 72, 232, 0.15);
font-size: 14px;
font-weight: 500;
color: #2d1f4a;
line-height: 1.6;
box-shadow: 0 2px 12px rgba(112, 72, 232, 0.06);
}
.token-interaction-grid {
display: flex;
flex-wrap: wrap;
gap: 16px;
}
.interaction-token {
border-radius: 16px;
border: 2px solid #e7e2f5;
padding: 14px 16px;
width: min(260px, 100%);
background: linear-gradient(135deg, #ffffff 0%, #faf8ff 100%);
transition: all 0.3s cubic-bezier(0.34, 1.56, 0.64, 1);
cursor: pointer;
box-shadow: 0 2px 8px rgba(88, 60, 140, 0.06);
}
.interaction-token:hover {
transform: translateY(-2px);
box-shadow: 0 8px 24px rgba(112, 72, 232, 0.15);
border-color: rgba(112, 72, 232, 0.3);
}
.interaction-token[open] {
border-color: var(--token-active);
box-shadow: 0 12px 32px rgba(112, 72, 232, 0.22);
transform: translateY(-3px);
background: linear-gradient(135deg, #ffffff 0%, #f6f3ff 100%);
}
.interaction-token summary {
list-style: none;
cursor: pointer;
display: flex;
flex-direction: column;
gap: 6px;
}
.interaction-token summary::-webkit-details-marker {
display: none;
}
.interaction-token__score {
font-size: 13px;
font-weight: 700;
letter-spacing: 0.03em;
color: #7048e8;
background: linear-gradient(135deg, #7048e8 0%, #9b6dff 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.interaction-token__text {
font-size: 14px;
font-weight: 500;
color: #2d1f4a;
white-space: normal;
overflow: hidden;
text-overflow: ellipsis;
line-height: 1.5;
}
.interaction-token__hint {
font-size: 12px;
color: #7a6b99;
font-weight: 500;
margin-top: 2px;
}
.token-link-list {
list-style: none;
padding: 12px 0 0 0;
margin: 12px 0 0 0;
border-top: 2px solid rgba(112, 72, 232, 0.1);
}
.token-link-list li {
display: flex;
justify-content: space-between;
align-items: center;
padding: 8px 0;
font-size: 13px;
}
.token-link-list li + li {
border-top: 1.5px dashed rgba(112, 72, 232, 0.12);
}
.token-link-name {
color: #3a2f50;
font-weight: 500;
margin-right: 12px;
flex: 1;
}
.token-link-value {
font-weight: 700;
font-size: 13px;
color: #7048e8;
background: rgba(112, 72, 232, 0.08);
padding: 4px 10px;
border-radius: 999px;
}
.token-interaction-empty {
font-size: 13px;
font-weight: 500;
color: #9a8bb5;
margin-top: 8px;
font-style: italic;
}
.sentence-interaction-view {
--token-border: #d9d5e0;
--token-bg: #fbf8ff;
display: flex;
flex-direction: column;
gap: 16px;
padding: 18px;
border: 1px solid var(--token-border);
border-radius: 18px;
background: var(--token-bg);
font-family: "Segoe UI", "Helvetica Neue", Arial, sans-serif;
}
.sentence-interaction-header {
font-size: 14px;
font-weight: 600;
color: #4a3c71;
}
.sentence-interaction-body {
display: grid;
grid-template-columns: minmax(0, 1fr) 240px;
gap: 24px;
}
.sentence-list {
display: flex;
flex-direction: column;
gap: 10px;
}
.sentence-row {
display: grid;
grid-template-columns: 90px minmax(0, 1fr);
gap: 16px;
align-items: center;
padding: 10px 14px;
border-radius: 14px;
border: 1px solid var(--token-border);
background: #fff;
box-shadow: 0 8px 18px rgba(64, 58, 95, 0.08);
height: 64px;
}
.sentence-row__score {
font-size: 13px;
font-weight: 600;
color: #2a1f44;
}
.sentence-row__text {
font-size: 13px;
color: #2c233b;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.sentence-row__links {
grid-column: 1 / -1;
font-size: 11px;
color: #6d6483;
display: flex;
gap: 6px;
flex-wrap: wrap;
}
.sentence-row__badge {
background: rgba(112, 72, 232, 0.08);
border-radius: 999px;
padding: 2px 8px;
font-weight: 600;
color: #5f43c2;
}
.sentence-links {
position: relative;
min-height: 160px;
}
.sentence-links svg {
width: 100%;
height: var(--canvas-height, 200px);
overflow: visible;
}
.sentence-link-path {
fill: none;
opacity: 0.9;
}
.sentence-link-label {
font-size: 11px;
fill: #4b3f66;
}
.sentence-node {
fill: #fff;
stroke: #c3bed7;
stroke-width: 1.5;
}
</style>
"""
_NETWORK_VIEW_STYLE = """
<style>
.interaction-network {
position: relative;
background: linear-gradient(135deg, #faf8ff 0%, #f0f4ff 100%);
border: 2px solid #e0d9f0;
border-radius: 20px;
padding: 20px 24px 16px;
box-shadow: 0 8px 32px rgba(88, 60, 140, 0.12), 0 2px 8px rgba(88, 60, 140, 0.08);
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "Helvetica Neue", Arial, sans-serif;
margin: 16px 0;
}
.network-toolbar {
display: flex;
justify-content: space-between;
align-items: center;
gap: 16px;
margin-bottom: 12px;
color: #2d1f4a;
}
.network-title {
font-weight: 700;
font-size: 18px;
letter-spacing: -0.01em;
color: #2d1f4a;
background: linear-gradient(135deg, #7048e8 0%, #9b6dff 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
.network-hint {
font-size: 13px;
font-weight: 500;
color: #5b4a7a;
background: linear-gradient(135deg, rgba(112, 72, 232, 0.12) 0%, rgba(155, 109, 255, 0.12) 100%);
padding: 8px 16px;
border-radius: 999px;
border: 1.5px solid rgba(112, 72, 232, 0.18);
box-shadow: 0 2px 8px rgba(112, 72, 232, 0.08);
}
.network-svg {
width: 100%;
height: 580px;
border-radius: 16px;
background: #ffffff;
border: 2px solid #ebe7f5;
cursor: pointer;
box-shadow: inset 0 2px 8px rgba(112, 72, 232, 0.04);
}
@keyframes fadeIn {
from { opacity: 0; transform: scale(0.96) translateY(4px); }
to { opacity: 1; transform: scale(1) translateY(0); }
}
@keyframes pulse {
0%, 100% { transform: scale(1); }
50% { transform: scale(1.05); }
}
.network-node {
animation: fadeIn 0.4s cubic-bezier(0.34, 1.56, 0.64, 1);
cursor: pointer;
transition: all 0.3s cubic-bezier(0.34, 1.56, 0.64, 1);
}
.network-node:hover {
animation: pulse 0.6s ease-in-out infinite;
}
.network-edge {
animation: fadeIn 0.5s ease-out;
stroke-linecap: round;
opacity: 0.88;
transition: all 0.3s ease;
}
.edge-label {
font-size: 12px;
font-weight: 700;
fill: #3a2858;
paint-order: stroke;
stroke: #ffffff;
stroke-width: 4px;
text-anchor: middle;
opacity: 0.95;
}
.network-node text {
font-size: 13px;
fill: #1a0f2e;
pointer-events: none;
font-weight: 600;
}
.network-node circle {
stroke-width: 2.5px;
filter: drop-shadow(0 4px 16px rgba(112, 72, 232, 0.2));
transition: all 0.3s cubic-bezier(0.34, 1.56, 0.64, 1);
}
.network-node:hover circle {
filter: drop-shadow(0 8px 24px rgba(112, 72, 232, 0.35));
stroke-width: 3px;
}
.network-status {
font-size: 13px;
font-weight: 500;
color: #5b4a7a;
margin-top: 12px;
padding: 10px 16px;
background: linear-gradient(135deg, rgba(112, 72, 232, 0.06) 0%, rgba(155, 109, 255, 0.06) 100%);
border-radius: 12px;
border: 1.5px solid rgba(112, 72, 232, 0.12);
text-align: center;
box-shadow: 0 2px 8px rgba(112, 72, 232, 0.06);
}
.network-empty {
padding: 20px 16px;
border-radius: 14px;
background: linear-gradient(135deg, #fff5f5 0%, #ffe8e8 100%);
border: 2px solid #ffd0d0;
color: #9a2a42;
font-size: 14px;
font-weight: 500;
text-align: center;
}
.interaction-fallback {
position: relative;
}
</style>
"""
def _interactions_to_shapiq(
    interactions: List[Tuple[Tuple[str, ...], float]],
    method: str,
    order: int,
) -> Tuple[Optional["InteractionValues"], List[str]]:
    """
    Convert ``((feature, ...), value)`` items into a shapiq ``InteractionValues``.

    Feature names are mapped to integer player ids in order of first
    appearance; the returned name list is indexed by those ids, so it can be
    passed as ``feature_names`` to shapiq plotting helpers.

    Args:
        interactions: ``((feature_name, ...), value)`` items. Entries with an
            empty feature tuple are skipped. If the same feature set appears
            more than once (in any order), the last value wins.
        method: ``"shapley"`` selects the ``SV``/``SII`` index; anything else
            selects ``BV``/``BII``.
        order: Interaction order; ``1`` selects the value index (SV/BV),
            higher orders the interaction index (SII/BII). Also used as
            ``max_order``; ``min_order`` is ``max(1, order)``.

    Returns:
        ``(iv, feature_names)``. ``iv`` is ``None`` when shapiq is not
        installed, ``interactions`` is empty, or no entry has a non-empty
        feature tuple.
    """
    if InteractionValues is None or not interactions:
        return None, []
    feature_index: Dict[str, int] = {}
    feature_names: List[str] = []

    def _idx(name: str) -> int:
        """Return the player id for ``name``, assigning the next id if new."""
        if name not in feature_index:
            feature_index[name] = len(feature_names)
            feature_names.append(name)
        return feature_index[name]

    lookup: Dict[Tuple[int, ...], int] = {}
    values: List[float] = []
    for feats, value in interactions:
        if not feats:
            continue
        idx_tuple = tuple(sorted(_idx(f) for f in feats))
        if idx_tuple in lookup:
            # Same feature set seen again (e.g. ("a", "b") then ("b", "a")):
            # overwrite in place so every entry of ``values`` stays reachable
            # through ``lookup`` (appending would leave an orphaned value).
            values[lookup[idx_tuple]] = float(value)
            continue
        lookup[idx_tuple] = len(values)
        values.append(float(value))
    if not values:
        return None, feature_names
    if order == 1:
        index = "SV" if method == "shapley" else "BV"
    else:
        index = "SII" if method == "shapley" else "BII"
    min_order = 1 if order <= 1 else order
    iv = InteractionValues(
        values=np.array(values, dtype=float),
        index=index,
        max_order=order,
        n_players=len(feature_names),
        min_order=min_order,
        interaction_lookup=lookup,
        estimated=False,
        baseline_value=0.0,
    )
    return iv, feature_names
def plot_top_interactions(
    interactions: List[Tuple[Tuple[str, ...], float]],
    top_k: int = 10,
    order: int = 2,
    method: str = "shapley"
) -> go.Figure:
    """
    Visualize the top-k interactions as a horizontal bar chart.

    Interactions are ranked by absolute value. If the optional ``shapiq``
    dependency is importable, its ``bar_plot`` output is converted to Plotly
    via ``matplotlib_to_plotly``. When shapiq is missing, yields no figure, or
    raises, a native Plotly bar chart is drawn instead (positive values red,
    negative blue, strongest at the top).

    Args:
        interactions: List of interactions from shapley_interactions/banzhaf_interactions.
            Each item is ((feature_name, ...), value).
        top_k: Number of top interactions to display. Values below 0 are
            treated as 0.
        order: Interaction order (2 or 3); used in the title, the legend and
            the shapiq index selection.
        method: Attribution method label ("shapley" or "banzhaf").

    Returns:
        Plotly Figure.

    Example:
        # From attribution
        mobius = run_proxyspex(set_function, features, max_order=3)
        interactions = shapley_interactions(mobius, order=2, top_k=10)
        fig = plot_top_interactions(interactions, top_k=10, order=2, method="shapley")
    """
    if not interactions:
        return go.Figure().update_layout(
            title="No interactions available",
            template="plotly_white"
        )
    # A negative slice bound would silently drop items from the *end* instead.
    top_k = max(int(top_k), 0)
    ranked = sorted(interactions, key=lambda item: abs(item[1]), reverse=True)[:top_k]
    title = f"Top {len(ranked)} order-{order} {method.title()} interactions"

    if bar_plot is not None:
        try:
            iv, feature_names = _interactions_to_shapiq(ranked, method, order)
            if iv is not None:
                ax = bar_plot(
                    [iv],
                    feature_names=feature_names,
                    show=False,
                    abbreviate=False,
                    max_display=top_k,
                    global_plot=True,
                    plot_base_value=True,
                )
                fig = ax.figure if ax is not None else None
                if fig is not None:
                    return matplotlib_to_plotly(fig, title=title)
        except Exception as exc:  # optional dependency: degrade, don't break the UI
            import warnings

            warnings.warn(
                f"shapiq bar_plot failed ({exc!r}); using the Plotly fallback.",
                RuntimeWarning,
                stacklevel=2,
            )

    def _unique_labels(raw_labels: List[str]) -> List[str]:
        """Suffix repeated labels with " (2)", " (3)", ... so none collide."""
        used = set()
        out: List[str] = []
        for label in raw_labels:
            candidate, k = label, 1
            while candidate in used:
                k += 1
                candidate = f"{label} ({k})"
            used.add(candidate)
            out.append(candidate)
        return out

    # Plotly draws bars that share a category label on top of each other, so
    # truncated or repeated feature groups need distinct labels.
    labels = _unique_labels([
        format_feature_label(" · ".join(feats), max_length=50)
        for feats, _ in ranked
    ])
    values = [val for _, val in ranked]
    colors = ["#E24A33" if v >= 0 else "#348ABD" for v in values]
    # Lists are reversed so the strongest interaction ends up at the top of
    # the (bottom-to-top) categorical y axis.
    fig = go.Figure(
        data=[
            go.Bar(
                y=list(reversed(labels)),
                x=list(reversed(values)),
                orientation="h",
                marker=dict(
                    color=list(reversed(colors)),
                    line=dict(color="#f5f5f5", width=1.5),
                ),
                text=[f"{v:.3f}" for v in reversed(values)],
                textposition="outside",
                cliponaxis=False,
            )
        ]
    )
    fig.add_vline(x=0, line_dash="dash", line_color="#8c8c8c", line_width=1)
    fig.update_layout(
        title=title,
        xaxis_title="Contribution",
        yaxis_title="Feature group",
        # Force categorical so numeric-looking labels are not placed on a
        # linear axis.
        yaxis_type="category",
        template="plotly_white",
        hovermode="y",
        legend=create_legend(method, order),
        margin=dict(l=20, r=20, t=70, b=20),
        height=max(360, 40 * len(labels)),
    )
    return fig
def plot_interaction_network(
    features: List[str],
    interactions: List[Tuple[Tuple[str, ...], float]],
    threshold: float = 0.1
) -> go.Figure:
    """
    Visualize pairwise interactions as a network graph.

    Features are placed evenly on the unit circle in the order given. Each
    displayed interaction is a straight edge: red for positive values, blue
    for negative, with width scaled by ``|value|`` relative to the strongest
    value that passed the threshold.

    Args:
        features: The full ordered feature list. Repeated names share one node.
        interactions: Pairwise interactions [((feat_i, feat_j), value), ...].
            Entries that are not pairs, self-pairs, or pairs referencing
            features absent from ``features`` are skipped.
        threshold: Only show interactions whose absolute value is greater than
            or equal to this threshold.

    Returns:
        Plotly Figure (network-style visualization).
    """
    if not interactions:
        return go.Figure().update_layout(
            title="No pairwise interactions available",
            template="plotly_white"
        )
    # Non-pair entries are dropped here; unpacking them below would raise.
    filtered = [
        (feats, value)
        for feats, value in interactions
        if len(feats) == 2 and abs(value) >= threshold
    ]
    if not filtered:
        return go.Figure().update_layout(
            title=f"No interactions exceed threshold ({threshold})",
            template="plotly_white"
        )
    angles = np.linspace(0, 2 * np.pi, max(len(features), 1), endpoint=False)
    # Keyed by name: a repeated feature keeps the slot of its last occurrence.
    positions = {
        feat: (float(np.cos(theta)), float(np.sin(theta)))
        for feat, theta in zip(features, angles)
    }
    max_abs = max(abs(val) for _, val in filtered) or 1.0
    traces = []
    for (feat_a, feat_b), value in filtered:
        if feat_a == feat_b or feat_a not in positions or feat_b not in positions:
            continue
        (x0, y0), (x1, y1) = positions[feat_a], positions[feat_b]
        color = "#d73027" if value >= 0 else "#4575b4"
        width = 1 + 4 * abs(value) / max_abs
        label = f"{feat_a} <-> {feat_b}: {value:.3f}"
        traces.append(
            go.Scatter(
                x=[x0, x1],
                y=[y0, y1],
                mode="lines",
                line=dict(color=color, width=width),
                hoverinfo="text",
                text=[label, label],
                showlegend=False,
            )
        )
    node_names = list(positions)  # unique, in first-occurrence order
    node_trace = go.Scatter(
        x=[positions[f][0] for f in node_names],
        y=[positions[f][1] for f in node_names],
        mode="markers+text",
        marker=dict(size=14, color="#4a4a4a", line=dict(width=2, color="#ffffff")),
        text=[format_feature_label(f, 18) for f in node_names],
        # Hover shows the full name; the on-chart text is truncated.
        hovertext=[str(f) for f in node_names],
        textposition="bottom center",
        hoverinfo="text",
    )
    fig = go.Figure(data=traces + [node_trace])
    fig.update_layout(
        title="Pairwise interaction network",
        xaxis=dict(visible=False),
        # Lock the aspect ratio so the circular layout is not squashed.
        yaxis=dict(visible=False, scaleanchor="x", scaleratio=1),
        template="plotly_white",
        showlegend=False,
        margin=dict(l=20, r=20, t=60, b=20),
    )
    return fig
def plot_interaction_matrix(
    features: List[str],
    interactions: List[Tuple[Tuple[int, int], float]]
) -> go.Figure:
    """
    Visualize pairwise interactions as a symmetric matrix heatmap.

    Args:
        features: Ordered feature list (labels for axes). Repeated or
            truncation-colliding labels are suffixed (" (2)", ...) so every
            row/column stays distinct.
        interactions: List of ((i, j), value) where i/j are feature indices.
            Each value is written to both (i, j) and (j, i); out-of-range
            indices and non-pair entries are ignored, and later entries
            overwrite earlier ones.

    Returns:
        Plotly Figure (heatmap).
    """
    n = len(features)
    if n == 0:
        return go.Figure().update_layout(
            title="Interaction matrix (no features)",
            template="plotly_white"
        )
    matrix = np.zeros((n, n), dtype=float)
    for pair, value in interactions:
        if len(pair) != 2:
            continue
        i, j = pair
        if 0 <= i < n and 0 <= j < n:
            matrix[i, j] = value
            matrix[j, i] = value

    def _unique_labels(raw_labels: List[str]) -> List[str]:
        """Suffix repeated labels with " (2)", " (3)", ... so none collide."""
        used = set()
        out: List[str] = []
        for label in raw_labels:
            candidate, k = label, 1
            while candidate in used:
                k += 1
                candidate = f"{label} ({k})"
            used.add(candidate)
            out.append(candidate)
        return out

    # Repeated tokens (e.g. "the") would otherwise map to the same heatmap
    # category and collapse rows/columns onto each other.
    axis_labels = _unique_labels([format_feature_label(f, 18) for f in features])
    heatmap = go.Heatmap(
        z=matrix,
        x=axis_labels,
        y=axis_labels,
        colorscale=get_color_scale("shapley"),
        colorbar=dict(title="Interaction value"),
        hovertemplate=" %{y} vs %{x}<br>value=%{z:.2e}<extra></extra>",
    )
    fig = go.Figure(data=[heatmap])
    fig.update_layout(
        title="Pairwise interaction matrix",
        # Categorical axes: numeric-looking tokens must not switch to linear.
        xaxis=dict(side="top", type="category"),
        yaxis=dict(autorange="reversed", type="category"),
        template="plotly_white",
        margin=dict(l=80, r=20, t=80, b=80),
    )
    return fig
def plot_3rd_order_interactions(
    interactions: List[Tuple[Tuple[str, ...], float]],
    top_k: int = 5
) -> go.Figure:
    """
    Visualize the strongest third-order (triplet) interactions.

    Triplets are ranked by absolute value and the top ``top_k`` are drawn as a
    single vertical bar trace in rank order (red = positive, blue = negative),
    each bar annotated with its value.

    Args:
        interactions: Triplet interactions [((f1, f2, f3), value), ...].
        top_k: Number of top triplets to plot. Values below 0 are treated as 0.

    Returns:
        Plotly Figure.
    """
    if not interactions:
        return go.Figure().update_layout(
            title="No third-order interactions available",
            template="plotly_white"
        )
    # A negative slice bound would silently drop items from the *end* instead.
    top_k = max(int(top_k), 0)
    ranked = sorted(interactions, key=lambda item: abs(item[1]), reverse=True)[:top_k]

    def _unique_labels(raw_labels: List[str]) -> List[str]:
        """Suffix repeated labels with " (2)", " (3)", ... so none collide."""
        used = set()
        out: List[str] = []
        for label in raw_labels:
            candidate, k = label, 1
            while candidate in used:
                k += 1
                candidate = f"{label} ({k})"
            used.add(candidate)
            out.append(candidate)
        return out

    # Bars sharing a category label would be drawn on top of each other.
    labels = _unique_labels([
        format_feature_label(" · ".join(feats), max_length=40)
        for feats, _ in ranked
    ])
    values = [val for _, val in ranked]
    colors = ["#d73027" if v >= 0 else "#4575b4" for v in values]
    fig = go.Figure(
        data=[
            go.Bar(
                x=labels,
                y=values,
                marker=dict(color=colors),
                text=[f"{v:.3f}" for v in values],
                textposition="auto",
            )
        ]
    )
    fig.update_layout(
        title=f"Top {len(labels)} third-order interactions",
        xaxis_title="Feature triplet",
        xaxis_type="category",
        yaxis_title="Interaction value",
        template="plotly_white",
        margin=dict(l=60, r=20, t=60, b=100),
    )
    return fig
def _token_colors(value: float, max_abs: float) -> Tuple[str, str]:
if max_abs <= 0:
return "rgba(229, 226, 240, 0.7)", "rgba(193, 189, 209, 0.9)"
norm = max(-1.0, min(1.0, value / max_abs))
if norm >= 0:
base = (222, 86, 61)
else:
base = (47, 128, 237)
norm = -norm
neutral = (245, 242, 252)
r = int(round(neutral[0] + (base[0] - neutral[0]) * norm))
g = int(round(neutral[1] + (base[1] - neutral[1]) * norm))
b = int(round(neutral[2] + (base[2] - neutral[2]) * norm))
alpha = 0.35 + 0.4 * norm
return f"rgba({r}, {g}, {b}, {alpha:.3f})", f"rgb({r}, {g}, {b})"
def create_interaction_token_view(
    features: List[str],
    feature_values: List[float],
    pairwise: List[Tuple[Tuple[str, ...], float]],
    method: str = "shapley",
    max_links: int = 5,
    layout: str = "token",
) -> str:
    """
    Render token interactions as HTML with two layers.

    1) Chip list (always visible; no JS required).
    2) Interactive network (needs JS; falls back gracefully).

    Args:
        features: Token (or sentence) labels in display order.
        feature_values: Per-feature attribution scores; padded with ``0.0``
            when shorter than ``features``.
        pairwise: ``((feat_a, feat_b), value)`` items. Entries whose feature
            tuple is not a pair are ignored; ``None`` is treated as empty.
        method: Attribution method label forwarded to both renderers.
        max_links: Maximum number of partners listed per chip.
        layout: Currently unused; kept for API compatibility.

    Returns:
        HTML string: the chip list followed by the network view, or a short
        placeholder when ``features`` is empty.
    """
    if not features:
        return "<div class='token-interaction-empty'>No tokens available.</div>"
    values = list(feature_values) if feature_values else []
    if len(values) < len(features):
        values.extend([0.0] * (len(features) - len(values)))

    feature_index = {feat: idx for idx, feat in enumerate(features)}
    # Chip partners keyed by name (every pair), plus index-based network edges
    # (only pairs of two distinct, known features).
    adjacency: Dict[str, List[Tuple[str, float]]] = defaultdict(list)
    edges: List[Tuple[int, int, float]] = []
    for feats, val in pairwise or ():
        if len(feats) != 2:
            continue
        a, b = feats
        weight = float(val)
        adjacency[a].append((b, weight))
        adjacency[b].append((a, weight))
        a_idx = feature_index.get(a)
        b_idx = feature_index.get(b)
        if a_idx is None or b_idx is None or a_idx == b_idx:
            continue
        edges.append((a_idx, b_idx, weight))
    for partners in adjacency.values():
        partners.sort(key=lambda item: abs(item[1]), reverse=True)

    # Fallback: if no edges are provided, synthesize simple neighbor links so
    # the UI isn't empty. NOTE: these weights are the mean of the two
    # neighbours' attribution scores, not measured interaction values.
    if not edges and len(features) > 1:
        for i in range(len(features) - 1):
            score = 0.5 * (values[i] + values[i + 1])
            edges.append((i, i + 1, float(score)))

    max_abs = max((abs(v) for v in values), default=0.0) or 1.0
    chips_html = _render_token_chip_view(
        features,
        values,
        adjacency,
        method,
        max_abs,
        max_links,
    )
    network_html = _render_interaction_network(features, values, edges, method)
    return chips_html + network_html
# Client-side renderer for the interaction network. Kept as a plain string
# (not an f-string) so JS braces and quotes need no escaping; the
# __GRAPH_ID__ / __GRAPH_DATA__ placeholders are substituted in Python.
_NETWORK_SCRIPT_TEMPLATE = """
(function () {
  const graphId = "__GRAPH_ID__";
  const data = __GRAPH_DATA__;
  const SVG_NS = "http://www.w3.org/2000/svg";
  const MAX_EDGES = 60;

  function initGraph() {
    const svg = document.getElementById(graphId);
    // Markup not attached yet: keep the static fallback and retry later.
    if (!svg) return;
    // initGraph runs on DOM readiness and again from a timer; bind only once.
    if (svg.getAttribute("data-ready") === "1") return;
    svg.setAttribute("data-ready", "1");

    const panel = document.getElementById(graphId + "-interactive");
    const status = document.getElementById(graphId + "-status");
    const fallback = document.getElementById(graphId + "-fallback");
    if (panel) panel.style.display = "block";
    if (fallback) fallback.style.display = "none";

    const nodes = data.nodes.map((n) => ({
      ...n,
      labelShort: n.label.length > 36 ? n.label.slice(0, 35) + "…" : n.label,
    }));
    const allEdges = data.edges
      .slice()
      .sort((a, b) => Math.abs(b.weight) - Math.abs(a.weight));
    let selected = null;

    const setStatus = (msg) => {
      if (status) status.textContent = msg;
    };

    const size = () => {
      const rect = svg.getBoundingClientRect();
      return {
        w: Math.max(rect.width || 960, 720),
        h: Math.max(rect.height || 540, 420),
      };
    };

    const ring = (ids, cx, cy, radius, positions) => {
      const step = (2 * Math.PI) / Math.max(ids.length, 1);
      ids.forEach((id, idx) => {
        const angle = idx * step - Math.PI / 2;  // start from the top
        positions[id] = {
          x: cx + radius * Math.cos(angle),
          y: cy + radius * Math.sin(angle),
        };
      });
    };

    const layoutPositions = (visibleNodes, edges, selectedId, w, h) => {
      const cx = w / 2;
      const cy = h / 2;
      const positions = {};
      if (selectedId !== null) {
        // Focus mode: selected node in the centre, neighbours on a ring
        // ordered by interaction strength.
        positions[selectedId] = { x: cx, y: cy };
        const visibleIds = new Set(visibleNodes.map((n) => n.id));
        const neighbors = [];
        edges.forEach((e) => {
          const other = e.source === selectedId ? e.target : e.source;
          if (other !== selectedId && visibleIds.has(other) && !neighbors.includes(other)) {
            neighbors.push(other);
          }
        });
        ring(neighbors, cx, cy, Math.min(w, h) * 0.35, positions);
        return positions;
      }
      // Global view: the 18 nodes with the largest summed |weight| on a ring.
      const deg = new Map();
      edges.forEach((e) => {
        deg.set(e.source, (deg.get(e.source) || 0) + Math.abs(e.weight));
        deg.set(e.target, (deg.get(e.target) || 0) + Math.abs(e.weight));
      });
      const keep = visibleNodes
        .slice()
        .sort((a, b) => (deg.get(b.id) || 0) - (deg.get(a.id) || 0))
        .slice(0, 18)
        .map((n) => n.id);
      ring(keep, cx, cy, Math.min(w, h) * 0.38, positions);
      return positions;
    };

    const render = () => {
      const { w, h } = size();
      svg.setAttribute("viewBox", "0 0 " + w + " " + h);
      while (svg.firstChild) svg.removeChild(svg.firstChild);

      const focused = selected !== null;
      let edges;
      let visibleNodes;
      if (!focused) {
        edges = allEdges.slice(0, MAX_EDGES);
        visibleNodes = nodes;
      } else {
        // Use every edge of the focused node, not only the global top-N.
        edges = allEdges
          .filter((e) => e.source === selected || e.target === selected)
          .slice(0, MAX_EDGES);
        const keepIds = new Set([selected]);
        edges.forEach((e) => {
          keepIds.add(e.source);
          keepIds.add(e.target);
        });
        visibleNodes = nodes.filter((n) => keepIds.has(n.id));
      }
      if (!visibleNodes.length) {
        setStatus("No interactions available.");
        return;
      }

      const positions = layoutPositions(visibleNodes, edges, selected, w, h);
      const maxEdge = edges.reduce((acc, e) => Math.max(acc, Math.abs(e.weight)), 0) || 1;
      let drawn = 0;
      edges.forEach((edge) => {
        const src = positions[edge.source];
        const tgt = positions[edge.target];
        if (!src || !tgt) return;
        drawn += 1;
        const line = document.createElementNS(SVG_NS, "line");
        line.setAttribute("x1", src.x);
        line.setAttribute("y1", src.y);
        line.setAttribute("x2", tgt.x);
        line.setAttribute("y2", tgt.y);
        line.setAttribute("stroke", edge.weight >= 0 ? "#d35400" : "#3867d6");
        // Edges are more prominent in focus mode.
        const width = (focused ? 2.0 : 1.2) + 4 * (Math.abs(edge.weight) / maxEdge);
        line.setAttribute("stroke-width", width.toFixed(2));
        line.setAttribute("class", "network-edge");
        line.setAttribute("opacity", focused ? "0.85" : "0.75");
        svg.appendChild(line);

        const label = document.createElementNS(SVG_NS, "text");
        label.setAttribute("class", "edge-label");
        label.setAttribute("x", (src.x + tgt.x) / 2);
        label.setAttribute("y", (src.y + tgt.y) / 2 - 6);
        label.setAttribute("font-size", focused ? "12px" : "11px");
        label.setAttribute("font-weight", focused ? "700" : "600");
        label.textContent = (edge.weight >= 0 ? "+" : "") + edge.weight.toFixed(2);
        svg.appendChild(label);
      });

      visibleNodes.forEach((node) => {
        const pos = positions[node.id];
        if (!pos) return;
        const isFocused = selected === node.id;
        const g = document.createElementNS(SVG_NS, "g");
        g.setAttribute("class", "network-node");
        g.setAttribute("data-node-id", String(node.id));
        g.style.cursor = "pointer";

        // A <title> child is what SVG renders as a hover tooltip.
        const tooltip = document.createElementNS(SVG_NS, "title");
        tooltip.textContent = node.label;
        g.appendChild(tooltip);

        const circle = document.createElementNS(SVG_NS, "circle");
        circle.setAttribute("cx", pos.x);
        circle.setAttribute("cy", pos.y);
        circle.setAttribute("r", isFocused ? 28 : 18);
        circle.setAttribute("fill", node.fill);
        circle.setAttribute("stroke", isFocused ? "#7048e8" : node.stroke);
        circle.setAttribute("stroke-width", isFocused ? 3 : 2);
        if (isFocused) {
          circle.style.filter = "drop-shadow(0 6px 16px rgba(112, 72, 232, 0.35))";
        }
        g.appendChild(circle);

        const text = document.createElementNS(SVG_NS, "text");
        text.setAttribute("x", pos.x);
        text.setAttribute("y", pos.y + 4);
        text.setAttribute("text-anchor", "middle");
        text.setAttribute("font-weight", isFocused ? "700" : "600");
        text.setAttribute("font-size", isFocused ? "13px" : "12px");
        text.textContent = node.labelShort;
        g.appendChild(text);

        const score = document.createElementNS(SVG_NS, "text");
        score.setAttribute("x", pos.x);
        score.setAttribute("y", pos.y + (isFocused ? 22 : 18));
        score.setAttribute("text-anchor", "middle");
        score.setAttribute("fill", isFocused ? "#7048e8" : "#5c4c78");
        score.setAttribute("font-size", isFocused ? "12px" : "11px");
        score.setAttribute("font-weight", isFocused ? "700" : "400");
        score.textContent = (node.value >= 0 ? "+" : "") + node.value.toFixed(2);
        g.appendChild(score);
        svg.appendChild(g);
      });

      if (!focused) {
        setStatus("Global view: " + drawn + " edges shown. Click any node to see only its interactions.");
      } else {
        const node = nodes.find((n) => n.id === selected);
        const neighborCount = Math.max(Object.keys(positions).length - 1, 0);
        setStatus(
          node
            ? '🎯 Focused on "' + node.label + '" → showing ' + neighborCount +
              " neighbor(s) with " + drawn + " interaction(s). Click background to restore global view."
            : "Click background to restore global view."
        );
      }
    };

    svg.addEventListener("click", (event) => {
      const hit = event.target && event.target.closest
        ? event.target.closest("[data-node-id]")
        : null;
      selected = hit ? Number(hit.getAttribute("data-node-id")) : null;
      render();
    });
    window.addEventListener("resize", render);
    render();
  }

  if (document.readyState === "loading") {
    document.addEventListener("DOMContentLoaded", initGraph);
  } else {
    initGraph();
  }
  // Gradio may attach this markup after the script runs; retry once.
  setTimeout(initGraph, 100);
})();
"""


def _render_interaction_network(
    features: List[str],
    values: List[float],
    edges: List[Tuple[int, int, float]],
    method: str,
) -> str:
    """
    Build the HTML for the SVG interaction network.

    The markup contains ``_NETWORK_VIEW_STYLE``, a hidden interactive panel
    (``<svg>`` plus a status line) driven by an inline script, and a static
    SVG fallback. When the script runs it reveals the panel and hides the
    fallback; when scripts do not run, only the static fallback is visible.

    Interactive behaviour:
      - Click a token/sentence to focus it in the center and hide unrelated nodes.
      - Edge labels carry interaction scores.
      - Click empty space to reset the full view.

    Args:
        features: Node labels; node ids are their positions in this list.
        values: Per-node attribution scores; missing trailing entries count
            as ``0.0`` and extra entries are ignored.
        edges: ``(source_idx, target_idx, weight)`` triples indexing
            ``features``. Only the 120 strongest (by ``|weight|``) are sent
            to the browser; the static view uses at most the strongest 40
            among its top-10 nodes.
        method: Attribution method label shown in the title.

    Returns:
        HTML string.
    """
    if not features:
        return (
            _NETWORK_VIEW_STYLE
            + "<div class='network-empty'>No interactions available.</div>"
        )

    n = len(features)
    vals = [float(v) for v in list(values or [])[:n]]
    vals.extend([0.0] * (n - len(vals)))
    labels = [str(f) for f in features]

    max_abs_value = max((abs(v) for v in vals), default=0.0) or 1.0
    nodes_payload = []
    for idx, (label, value) in enumerate(zip(labels, vals)):
        fill, stroke = _token_colors(value, max_abs_value)
        nodes_payload.append(
            {"id": idx, "label": label, "value": value, "fill": fill, "stroke": stroke}
        )
    edges_sorted = sorted(edges, key=lambda item: abs(item[2]), reverse=True)
    edges_payload = [
        {"source": int(a), "target": int(b), "weight": float(val)}
        for a, b, val in edges_sorted[:120]
    ]
    graph_id = f"interaction-net-{uuid.uuid4().hex[:8]}"
    data_json = json.dumps(
        {"nodes": nodes_payload, "edges": edges_payload}, ensure_ascii=False
    )
    # The JSON is inlined into a <script>: escape characters that could close
    # the tag or (pre-ES2019) end a JS string, so token text cannot inject
    # markup. These escapes are equivalent inside JSON strings.
    for raw, safe in (
        ("<", "\\u003c"),
        (">", "\\u003e"),
        ("&", "\\u0026"),
        ("\u2028", "\\u2028"),
        ("\u2029", "\\u2029"),
    ):
        data_json = data_json.replace(raw, safe)
    method_label = escape(method.title())

    def _shorten(label: str, max_len: int = 48) -> str:
        """Cut ``label`` to ``max_len`` characters, ending with an ellipsis."""
        label = label or ""
        if len(label) <= max_len:
            return label
        return label[: max_len - 1] + "…"

    def _wrap(text: str, max_len: int = 24) -> List[str]:
        """Greedy word-wrap into at most two lines of <= ``max_len`` chars."""
        lines: List[str] = []
        current = ""
        for word in text.split():
            if len(current) + len(word) + 1 <= max_len:
                current = (current + " " + word).strip()
            else:
                if current:
                    lines.append(current)
                current = word
        if current:
            lines.append(current)
        if not lines:
            lines = [text[:max_len]]
        overflow = len(lines) > 2
        # Over-long single words are cut as well so labels cannot spill.
        lines = [_shorten(line, max_len) for line in lines[:2]]
        if overflow and not lines[-1].endswith("…"):
            lines[-1] = lines[-1][: max_len - 1] + "…"
        return lines

    def _static_svg() -> str:
        """
        Fallback rendering (no JS): top-10 nodes by |value| on a ring, with
        wrapped labels and the strongest edges among them.
        """
        width, height = 820.0, 520.0
        top = sorted(range(n), key=lambda i: abs(vals[i]), reverse=True)[: min(10, n)]
        radius = min(width, height) * 0.36
        cx, cy = width / 2.0, height / 2.0
        step = 2 * np.pi / max(len(top), 1)
        positions: Dict[int, Tuple[float, float]] = {}
        for slot, node_idx in enumerate(top):
            angle = slot * step - np.pi / 2  # start from the top
            positions[node_idx] = (
                cx + radius * float(np.cos(angle)),
                cy + radius * float(np.sin(angle)),
            )
        shown_edges = [
            (a, b, val) for a, b, val in edges_sorted[:40]
            if a in positions and b in positions
        ]
        max_edge = max((abs(val) for _, _, val in shown_edges), default=0.0) or 1.0
        max_abs_top = max((abs(vals[i]) for i in top), default=0.0) or 1.0

        parts: List[str] = []
        # Edges first so that nodes are painted on top of them.
        for a_idx, b_idx, val in shown_edges:
            x1, y1 = positions[a_idx]
            x2, y2 = positions[b_idx]
            color = "#d35400" if val >= 0 else "#3867d6"
            width_px = 1.2 + 3.5 * (abs(val) / max_edge)
            parts.append(
                f'<line x1="{x1:.1f}" y1="{y1:.1f}" x2="{x2:.1f}" y2="{y2:.1f}" '
                f'stroke="{color}" stroke-width="{width_px:.2f}" opacity="0.85" />'
            )
        for node_idx in top:
            x, y = positions[node_idx]
            fill, stroke = _token_colors(vals[node_idx], max_abs_top)
            tspans = "".join(
                f'<tspan x="{x:.1f}" dy="{12 if i else 0}">{escape(line)}</tspan>'
                for i, line in enumerate(_wrap(labels[node_idx]))
            )
            parts.append(
                '<g class="network-node">'
                f"<title>{escape(labels[node_idx])}</title>"
                f'<circle cx="{x:.1f}" cy="{y:.1f}" r="18" fill="{fill}" '
                f'stroke="{stroke}" stroke-width="2" />'
                f'<text x="{x:.1f}" y="{y - 4:.1f}" font-size="12" font-weight="600" '
                f'fill="#1f1533" text-anchor="middle">{tspans}</text>'
                f'<text x="{x:.1f}" y="{y + 20:.1f}" font-size="11" fill="#5c4c78" '
                f'text-anchor="middle">{vals[node_idx]:+0.2f}</text>'
                "</g>"
            )
        return (
            f"<div id='{graph_id}-fallback' class='interaction-fallback'>"
            "<div class='network-status'>Static view (scripts blocked)</div>"
            f'<svg class="network-svg" viewBox="0 0 {width:.0f} {height:.0f}" '
            'preserveAspectRatio="xMidYMid meet">'
            + "".join(parts)
            + "</svg></div>"
        )

    # Substitute the id first so token text in the data can never be
    # rewritten by the second replacement.
    script = (
        _NETWORK_SCRIPT_TEMPLATE
        .replace("__GRAPH_ID__", graph_id)
        .replace("__GRAPH_DATA__", data_json)
    )
    return (
        _NETWORK_VIEW_STYLE
        + "<div class='interaction-network'>"
        + "<div class='network-toolbar'>"
        + f"<div class='network-title'>{method_label} interaction network</div>"
        + "<div class='network-hint'>Click a node to focus · click background to reset</div>"
        + "</div>"
        # Hidden until the script runs, so a blocked script shows no empty canvas.
        + f"<div id='{graph_id}-interactive' style='display:none'>"
        + f"<svg id='{graph_id}' class='network-svg' xmlns='http://www.w3.org/2000/svg'></svg>"
        + f"<div id='{graph_id}-status' class='network-status'></div>"
        + "</div>"
        + _static_svg()
        + "</div>"
        + f"<script>{script}</script>"
    )
def _render_token_chip_view(
    features: List[str],
    values: List[float],
    adjacency: Dict[str, List[Tuple[str, float]]],
    method: str,
    max_abs: float,
    max_links: int,
) -> str:
    """Render each feature as a collapsible ``<details>`` chip listing its partners.

    Args:
        features: Token labels, in display order.
        values: Attribution score per token, index-aligned with ``features``.
        adjacency: Token -> list of ``(partner, value)``. The first
            ``max_links`` entries are listed inside the chip.
        method: Attribution method name, shown title-cased in the header panel.
        max_abs: Scale passed to ``_token_colors`` to derive the chip colors.
        max_links: How many partner entries to show per token (slice bound).

    Returns:
        An HTML string made of ``_TOKEN_VIEW_STYLE`` followed by the chip grid.
        Only the first chip starts expanded.
    """
    ellipsis_limit = 80

    def _truncate(raw: str, limit: int = ellipsis_limit) -> str:
        # Long labels are cut to ``limit`` characters, the last being an ellipsis.
        raw = raw or ""
        if len(raw) > limit:
            return raw[: limit - 1] + "…"
        return raw

    chips: List[str] = []
    for position, token in enumerate(features):
        score = values[position]
        fill, outline = _token_colors(score, max_abs)
        title_attr = escape(token, quote=True)
        visible_text = escape(_truncate(token), quote=True)
        links = adjacency.get(token, [])[:max_links]

        if links:
            list_items = "".join(
                "<li>"
                f"<span class='token-link-name'>{escape(str(name))}</span>"
                f"<span class='token-link-value'>{weight:+.3f}</span>"
                "</li>"
                for name, weight in links
            )
            inner = '<ul class="token-link-list">' + list_items + "</ul>"
        else:
            inner = "<div class='token-interaction-empty'>No interactions recorded.</div>"

        # Expand only the first chip so the view shows one example by default.
        opened = " open" if position == 0 else ""
        chips.append(
            f'<details class="interaction-token"{opened} '
            f'style="background-color:{fill}; border-color:{outline};">'
            "<summary>"
            f'<span class="interaction-token__score">{score:+.2f}</span>'
            f'<span class="interaction-token__text" title="{title_attr}">{visible_text}</span>'
            '<span class="interaction-token__hint">Click to reveal linked tokens</span>'
            "</summary>"
            f"{inner}"
            "</details>"
        )

    header = (
        '<div class="token-interaction-panel">'
        f'{escape(method.title())} pairwise interactions · click a token to inspect its strongest partners.'
        "</div>"
    )
    grid = "".join(chips) or "<div class='token-interaction-empty'>No tokens available.</div>"
    return (
        _TOKEN_VIEW_STYLE
        + '<div class="token-interaction-view">'
        + header
        + '<div class="token-interaction-grid">'
        + grid
        + "</div>"
        + "</div>"
    )
def _render_sentence_link_view(
    features: List[str],
    values: List[float],
    adjacency: Dict[str, List[Tuple[str, float]]],
    pairwise: List[Tuple[Tuple[str, ...], float]],
    method: str,
    row_gap: float = 70.0,
    max_edges: int = 14,
) -> str:
    """Render sentence-level interactions as a row list beside an SVG arc panel.

    Each entry of ``features`` becomes one row. A row shows its score (from
    ``values`` at the same index), the escaped text, and up to two partner
    badges taken from ``adjacency``. Pairwise terms whose two features both
    appear in ``features`` are drawn as cubic Bezier arcs. Only the
    ``max_edges`` strongest distinct unordered pairs (by absolute value) are
    kept.

    Args:
        features: Row labels, in display order. If a label appears more than
            once, arcs attach to its last occurrence.
        values: Per-feature attribution, index-aligned with ``features``.
        adjacency: Feature -> list of ``(partner, value)``. The first two
            entries are shown as badges (presumably pre-sorted by strength;
            verify against the caller).
        pairwise: ``(feature_tuple, value)`` interaction terms. Entries are
            skipped when they are not pairs, reference unknown features, pair
            a feature with itself, or have a non-finite value.
        method: Attribution method name, shown title-cased in the header.
        row_gap: Vertical distance between row centers, in SVG user units.
        max_edges: Maximum number of arcs drawn. Values <= 0 draw none.

    Returns:
        An HTML string made of ``_TOKEN_VIEW_STYLE`` followed by the markup.
    """
    import math

    feature_index = {feat: idx for idx, feat in enumerate(features)}
    edges: List[Tuple[int, int, float]] = []
    for feats, val in pairwise:
        if len(feats) != 2:
            continue
        a_idx = feature_index.get(feats[0])
        b_idx = feature_index.get(feats[1])
        if a_idx is None or b_idx is None or a_idx == b_idx:
            continue
        weight = float(val)
        # NaN/inf would break the abs()-keyed sort and produce invalid SVG widths.
        if not math.isfinite(weight):
            continue
        edges.append((a_idx, b_idx, weight))
    edges.sort(key=lambda item: abs(item[2]), reverse=True)

    trimmed: List[Tuple[int, int, float]] = []
    seen = set()
    for a_idx, b_idx, val in edges:
        # Check the cap before appending so max_edges <= 0 yields no arcs.
        if len(trimmed) >= max_edges:
            break
        # (a, b) and (b, a) are the same arc; keep only the strongest one.
        key = tuple(sorted((a_idx, b_idx)))
        if key in seen:
            continue
        seen.add(key)
        trimmed.append((a_idx, b_idx, val))

    # "or 1.0" avoids dividing by zero when every kept value is 0.
    max_edge = max((abs(val) for _, _, val in trimmed), default=0.0) or 1.0
    canvas_height = row_gap * len(features) + 20

    rows_html: List[str] = []
    value_max = max((abs(v) for v in values), default=0.0) or 1.0
    for idx, token in enumerate(features):
        value = values[idx]
        bg, border = _token_colors(value, value_max)
        label_text = escape(token)
        label_attr = escape(token, quote=True)
        partner_badges: List[str] = [
            f"<span class='sentence-row__badge'>{escape(str(partner))} {val:+.2f}</span>"
            for partner, val in adjacency.get(token, [])[:2]
        ]
        links_html = ""
        if partner_badges:
            links_html = "<div class='sentence-row__links'>Top links: " + " ".join(partner_badges) + "</div>"
        rows_html.append(
            f'<div class="sentence-row" style="border-color:{border};" title="{label_attr}">'
            f'<span class="sentence-row__score">{value:+.2f}</span>'
            f'<div class="sentence-row__text">{label_text}</div>'
            f"{links_html}"
            "</div>"
        )

    path_elements: List[str] = []
    for a_idx, b_idx, val in trimmed:
        # Arc endpoints sit at the vertical centers of the two rows.
        y1 = row_gap * a_idx + row_gap / 2
        y2 = row_gap * b_idx + row_gap / 2
        color = "#d35400" if val >= 0 else "#3867d6"
        width = 1.5 + 3.0 * (abs(val) / max_edge)
        # Rows that are farther apart get control points further right, so the arc bulges more.
        control = 120 + abs(a_idx - b_idx) * 12
        path_elements.append(
            f'<path class="sentence-link-path" d="M10 {y1:.1f} C {control:.1f} {y1:.1f}, {control + 60:.1f} {y2:.1f}, 220 {y2:.1f}" '
            f'stroke="{color}" stroke-width="{width:.2f}" data-label="{escape(features[a_idx])} ↔ {escape(features[b_idx])}" />'
        )
        mid_y = (y1 + y2) / 2
        path_elements.append(
            f'<text class="sentence-link-label" x="160" y="{mid_y:.1f}">{val:+.1f}</text>'
        )

    node_elements = [
        f'<circle class="sentence-node" cx="230" cy="{(row_gap * idx + row_gap / 2):.1f}" r="5" />'
        for idx in range(len(features))
    ]
    if path_elements:
        links_block = (
            f'<svg viewBox="0 0 240 {canvas_height:.0f}" preserveAspectRatio="none">'
            + "".join(path_elements + node_elements) +
            "</svg>"
        )
    else:
        links_block = "<div class='token-interaction-empty'>No pairwise arcs available.</div>"

    return "".join([
        _TOKEN_VIEW_STYLE,
        '<div class="sentence-interaction-view">',
        '<div class="sentence-interaction-header">'
        f'{escape(method.title())} pairwise interactions · one sentence per row.'
        "</div>",
        '<div class="sentence-interaction-body">',
        '<div class="sentence-list">',
        "".join(rows_html),
        "</div>",
        f'<div class="sentence-links" style="--canvas-height:{canvas_height:.0f}px;">'
        f"{links_block}</div>",
        "</div>",
        "</div>",
    ])