"""
Plotting utilities for visualizing higher-order feature interactions in a
text-interpretability web app.

Inputs are assumed to come from attribution utilities such as
`shapley_interactions(...)` or `banzhaf_interactions(...)`. Outputs are Plotly
Figure objects (or self-contained HTML strings for the token views) that can be
rendered directly in Gradio/UI.
"""

import json
import logging
import uuid
from collections import defaultdict
from html import escape
from typing import Dict, List, Optional, Tuple

import numpy as np
import plotly.graph_objects as go

try:  # optional dependency
    from shapiq.interaction_values import InteractionValues  # type: ignore
    from shapiq.plot import bar_plot  # type: ignore
except Exception:  # pragma: no cover
    InteractionValues = None
    bar_plot = None

from .utils import format_feature_label, get_color_scale, create_legend, matplotlib_to_plotly

logger = logging.getLogger(__name__)

# Maximum number of strongest edges serialized into the interactive network.
_MAX_NETWORK_EDGES = 120

_TOKEN_VIEW_STYLE = """
<style>
.interaction-token-view { font-family: inherit; display: flex; flex-direction: column; gap: 10px; }
.interaction-token-view .itv-caption { font-size: 13px; color: #5c4c78; }
.interaction-token-view .itv-chips { display: flex; flex-wrap: wrap; gap: 8px; }
.interaction-token-view .itv-chip { border: 1px solid; border-radius: 10px; padding: 4px 8px; max-width: 100%; }
.interaction-token-view .itv-chip > summary { cursor: pointer; display: flex; align-items: baseline; gap: 6px; }
.interaction-token-view .itv-score { font-variant-numeric: tabular-nums; font-weight: 600; }
.interaction-token-view .itv-hint { font-size: 11px; color: #7a6f8f; }
.interaction-token-view .itv-partners { margin: 6px 0 2px; padding-left: 18px; font-size: 12px; }
.interaction-token-view .itv-pos { color: #d35400; }
.interaction-token-view .itv-neg { color: #3867d6; }
.interaction-token-view .itv-empty { font-size: 12px; color: #7a6f8f; }
.interaction-token-view .itv-grid { display: flex; align-items: flex-start; gap: 12px; }
.interaction-token-view .itv-rows { flex: 1 1 auto; min-width: 0; }
.interaction-token-view .itv-row { display: flex; align-items: center; gap: 8px; box-sizing: border-box; }
.interaction-token-view .itv-row .itv-score { border: 1px solid; border-radius: 8px; padding: 2px 6px; }
.interaction-token-view .itv-sentence { flex: 1 1 auto; overflow: hidden; text-overflow: ellipsis; white-space: nowrap; }
.interaction-token-view .itv-badge { font-size: 11px; border-radius: 6px; padding: 1px 5px; background: #f1eef8; margin-left: 4px; }
.interaction-token-view .itv-arcs { flex: 0 0 auto; }
</style>
"""

_NETWORK_VIEW_STYLE = """
<style>
.interaction-network { margin-top: 14px; font-family: inherit; }
.interaction-network .inet-caption { font-size: 13px; color: #5c4c78; margin-bottom: 4px; }
.interaction-network .inet-status { font-size: 12px; color: #7a6f8f; margin-bottom: 6px; }
.interaction-network .inet-static-caption { font-size: 11px; color: #7a6f8f; }
.interaction-network .itv-empty { font-size: 12px; color: #7a6f8f; }
.interaction-network .edge-label { fill: #4a4a4a; paint-order: stroke; stroke: #ffffff; stroke-width: 3px; }
</style>
"""

# Client-side renderer for the interactive network. Plain (non-f) string so the
# JS braces need no escaping; __GRAPH_ID__ and __GRAPH_DATA__ are substituted
# by _render_interaction_network.
_NETWORK_SCRIPT = """
(() => {
  const graphId = "__GRAPH_ID__";
  const data = __GRAPH_DATA__;
  const MAX_EDGES = 60;
  const MAX_GLOBAL_NODES = 18;
  const SVG_NS = "http://www.w3.org/2000/svg";

  const shorten = (label) => (label.length > 36 ? label.slice(0, 35) + "…" : label);
  const signed = (x) => (x >= 0 ? "+" : "") + x.toFixed(2);

  const init = () => {
    const svg = document.getElementById(graphId);
    // init() is scheduled more than once (see bottom); bind handlers only once.
    if (!svg || svg.dataset.ready === "1") return;
    svg.dataset.ready = "1";
    const status = document.getElementById(graphId + "-status");
    const fallback = document.getElementById(graphId + "-fallback");
    const nodesById = new Map(data.nodes.map((n) => [n.id, n]));
    let selected = null;

    const setStatus = (msg) => {
      if (status) status.textContent = msg;
    };
    const make = (tag, attrs) => {
      const el = document.createElementNS(SVG_NS, tag);
      Object.keys(attrs).forEach((key) => el.setAttribute(key, String(attrs[key])));
      return el;
    };
    const size = () => {
      const rect = svg.getBoundingClientRect();
      return { w: Math.max(rect.width || 960, 720), h: Math.max(rect.height || 540, 420) };
    };
    const placeOnRing = (ids, positions, cx, cy, radius) => {
      const step = (2 * Math.PI) / Math.max(ids.length, 1);
      ids.forEach((id, i) => {
        const angle = i * step - Math.PI / 2;  // start from the top
        positions[id] = { x: cx + radius * Math.cos(angle), y: cy + radius * Math.sin(angle) };
      });
    };
    const layout = (nodes, edges, w, h) => {
      const cx = w / 2;
      const cy = h / 2;
      const positions = {};
      if (selected !== null) {
        // Focus mode: selected node in the centre, its neighbours on a ring.
        positions[selected] = { x: cx, y: cy };
        const others = nodes.filter((n) => n.id !== selected).map((n) => n.id);
        placeOnRing(others, positions, cx, cy, Math.min(w, h) * 0.35);
        return positions;
      }
      // Global view: keep the nodes with the largest summed |edge weight|.
      const degree = new Map();
      edges.forEach((e) => {
        degree.set(e.source, (degree.get(e.source) || 0) + Math.abs(e.weight));
        degree.set(e.target, (degree.get(e.target) || 0) + Math.abs(e.weight));
      });
      const keep = nodes
        .slice()
        .sort((a, b) => (degree.get(b.id) || 0) - (degree.get(a.id) || 0))
        .slice(0, MAX_GLOBAL_NODES)
        .map((n) => n.id);
      placeOnRing(keep, positions, cx, cy, Math.min(w, h) * 0.38);
      return positions;
    };

    const render = () => {
      const { w, h } = size();
      svg.setAttribute("viewBox", "0 0 " + w + " " + h);
      while (svg.firstChild) svg.removeChild(svg.firstChild);

      let edges = data.edges
        .slice()
        .sort((a, b) => Math.abs(b.weight) - Math.abs(a.weight))
        .slice(0, MAX_EDGES);
      let nodes = data.nodes;
      const focused = selected !== null;
      if (focused) {
        edges = edges.filter((e) => e.source === selected || e.target === selected);
        const keepIds = new Set([selected]);
        edges.forEach((e) => {
          keepIds.add(e.source);
          keepIds.add(e.target);
        });
        nodes = nodes.filter((n) => keepIds.has(n.id));
      }
      if (!nodes.length) {
        setStatus("No interactions available.");
        return;
      }

      const positions = layout(nodes, edges, w, h);
      const maxEdge = edges.reduce((acc, e) => Math.max(acc, Math.abs(e.weight)), 0) || 1;

      edges.forEach((e) => {
        const src = positions[e.source];
        const tgt = positions[e.target];
        if (!src || !tgt) return;
        const width = (focused ? 2.0 : 1.2) + 4 * (Math.abs(e.weight) / maxEdge);
        svg.appendChild(make("line", {
          x1: src.x, y1: src.y, x2: tgt.x, y2: tgt.y,
          stroke: e.weight >= 0 ? "#d35400" : "#3867d6",
          "stroke-width": width.toFixed(2),
          opacity: focused ? "0.85" : "0.75",
          class: "network-edge",
        }));
        const label = make("text", {
          x: (src.x + tgt.x) / 2,
          y: (src.y + tgt.y) / 2 - 6,
          "text-anchor": "middle",
          "font-size": focused ? "12px" : "11px",
          "font-weight": focused ? "700" : "600",
          class: "edge-label",
        });
        label.textContent = signed(e.weight);
        svg.appendChild(label);
      });

      nodes.forEach((n) => {
        const pos = positions[n.id];
        if (!pos) return;
        const isFocused = selected === n.id;
        const g = make("g", { class: "network-node" });
        g.dataset.nodeId = String(n.id);
        g.style.cursor = "pointer";
        const title = make("title", {});
        title.textContent = n.label;
        g.appendChild(title);
        const circle = make("circle", {
          cx: pos.x, cy: pos.y, r: isFocused ? 28 : 18,
          fill: n.fill, stroke: isFocused ? "#7048e8" : n.stroke,
          "stroke-width": isFocused ? 3 : 2,
        });
        if (isFocused) circle.style.filter = "drop-shadow(0 6px 16px rgba(112, 72, 232, 0.35))";
        g.appendChild(circle);
        const text = make("text", {
          x: pos.x, y: pos.y + 4, "text-anchor": "middle",
          "font-weight": isFocused ? "700" : "600",
          "font-size": isFocused ? "13px" : "12px",
        });
        text.textContent = shorten(n.label);
        g.appendChild(text);
        const score = make("text", {
          x: pos.x, y: pos.y + (isFocused ? 22 : 18), "text-anchor": "middle",
          fill: isFocused ? "#7048e8" : "#5c4c78",
          "font-size": isFocused ? "12px" : "11px",
          "font-weight": isFocused ? "700" : "400",
        });
        score.textContent = signed(n.value);
        g.appendChild(score);
        svg.appendChild(g);
      });

      if (focused) {
        const node = nodesById.get(selected);
        setStatus(
          node
            ? '🎯 Focused on "' + node.label + '" → showing ' + (nodes.length - 1) +
              " neighbor(s) with " + edges.length + " interaction(s). Click background to restore global view."
            : "Click background to restore global view."
        );
      } else {
        setStatus("Global view: " + edges.length + " edges shown. Click any node to see only its interactions.");
      }
    };

    svg.addEventListener("click", (event) => {
      const target = event.target.closest("[data-node-id]");
      selected = target ? Number(target.dataset.nodeId) : null;
      render();
    });
    const onResize = () => {
      // Stop listening once the host has removed this view from the page.
      if (!svg.isConnected) {
        window.removeEventListener("resize", onResize);
        return;
      }
      render();
    };
    window.addEventListener("resize", onResize);

    if (fallback) fallback.style.display = "none";
    svg.style.display = "block";
    render();
  };

  if (document.readyState === "loading") {
    document.addEventListener("DOMContentLoaded", init);
  } else {
    init();
  }
  // Retry shortly after, in case the host (e.g. Gradio) attaches the markup late.
  setTimeout(init, 100);
})();
"""


def _empty_figure(title: str) -> go.Figure:
    """Return a blank ``plotly_white`` figure that only carries *title*."""
    return go.Figure().update_layout(title=title, template="plotly_white")


def _ellipsize(text: str, max_len: int) -> str:
    """Return *text* cut to at most *max_len* characters, ending in "…" if cut."""
    text = text or ""
    return text if len(text) <= max_len else text[: max_len - 1] + "…"


def _json_for_script(payload: object) -> str:
    """Serialize *payload* as JSON that is safe to inline inside a <script> block.

    ``json.dumps`` leaves ``<``, ``>`` and ``&`` untouched, so user text such as
    ``</script>`` would terminate the script element early. Those characters
    (plus U+2028/U+2029) are rewritten as JSON ``\\uXXXX`` escapes, which decode
    to the same strings in JavaScript.
    """
    return (
        json.dumps(payload, ensure_ascii=False)
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
        .replace("\u2028", "\\u2028")
        .replace("\u2029", "\\u2029")
    )


def _interactions_to_shapiq(
    interactions: List[Tuple[Tuple[str, ...], float]],
    method: str,
    order: int,
) -> Tuple[Optional["InteractionValues"], List[str]]:
    """Convert name-keyed interactions into a shapiq ``InteractionValues`` object.

    Feature names become player indices in first-seen order. Entries with an
    empty feature tuple are skipped. If the same feature set appears more than
    once (in any order), the last value wins.

    Returns:
        ``(iv, feature_names)``. ``iv`` is ``None`` when shapiq is unavailable,
        *interactions* is empty, or no entry has any features.
    """
    if InteractionValues is None or not interactions:
        return None, []

    feature_index: Dict[str, int] = {}
    feature_names: List[str] = []

    def _idx(name: str) -> int:
        if name not in feature_index:
            feature_index[name] = len(feature_names)
            feature_names.append(name)
        return feature_index[name]

    lookup: Dict[Tuple[int, ...], int] = {}
    values: List[float] = []
    for feats, value in interactions:
        if not feats:
            continue
        key = tuple(sorted(_idx(f) for f in feats))
        if key in lookup:
            # Overwrite in place so every entry of `values` stays addressable
            # through `lookup` (appending would leave an orphaned value).
            values[lookup[key]] = float(value)
        else:
            lookup[key] = len(values)
            values.append(float(value))

    if not values:
        return None, feature_names

    if order == 1:
        index = "SV" if method == "shapley" else "BV"
    else:
        index = "SII" if method == "shapley" else "BII"
    min_order = 1 if order <= 1 else order

    iv = InteractionValues(
        values=np.array(values, dtype=float),
        index=index,
        max_order=order,
        n_players=len(feature_names),
        min_order=min_order,
        interaction_lookup=lookup,
        estimated=False,
        baseline_value=0.0,
    )
    return iv, feature_names


def _shapiq_bar_figure(
    ranked: List[Tuple[Tuple[str, ...], float]],
    method: str,
    order: int,
    top_k: int,
    title: str,
) -> Optional[go.Figure]:
    """Render *ranked* with shapiq's ``bar_plot``; return ``None`` to request the Plotly fallback."""
    iv, feature_names = _interactions_to_shapiq(ranked, method, order)
    if iv is None:
        return None
    try:
        ax = bar_plot(
            [iv],
            feature_names=feature_names,
            show=False,
            abbreviate=False,
            max_display=top_k,
            global_plot=True,
            plot_base_value=True,
        )
    except Exception:  # optional dependency: degrade to the native Plotly chart
        logger.debug("shapiq bar_plot failed; using Plotly fallback", exc_info=True)
        return None
    fig = getattr(ax, "figure", None)  # also covers ax is None
    if fig is None:
        return None
    # NOTE(review): the Matplotlib figure is not closed here; confirm whether
    # matplotlib_to_plotly releases it, otherwise repeated calls may accumulate figures.
    return matplotlib_to_plotly(fig, title=title)


def plot_top_interactions(
    interactions: List[Tuple[Tuple[str, ...], float]],
    top_k: int = 10,
    order: int = 2,
    method: str = "shapley",
) -> go.Figure:
    """
    Visualize the top-k interactions as a bar chart.

    Uses shapiq's bar plot when shapiq is installed and succeeds; otherwise
    draws a horizontal Plotly bar chart.

    Args:
        interactions: List of interactions from shapley_interactions/banzhaf_interactions.
            Each item is ((feature_name, ...), value).
        top_k: Number of top interactions (by absolute value) to display;
            negative values are treated as 0.
        order: Interaction order (2 or 3).
        method: Attribution method label ("shapley" or "banzhaf").

    Returns:
        Plotly Figure.

    Example:
        # From attribution
        mobius = run_proxyspex(set_function, features, max_order=3)
        interactions = shapley_interactions(mobius, order=2, top_k=10)
        fig = plot_top_interactions(interactions, top_k=10, order=2, method="shapley")
    """
    if not interactions:
        return _empty_figure("No interactions available")

    ranked = sorted(interactions, key=lambda item: abs(item[1]), reverse=True)[: max(top_k, 0)]
    title = f"Top {len(ranked)} order-{order} {method.title()} interactions"

    if bar_plot is not None:
        shapiq_fig = _shapiq_bar_figure(ranked, method, order, top_k, title)
        if shapiq_fig is not None:
            return shapiq_fig

    labels = [format_feature_label(" · ".join(feats), max_length=50) for feats, _ in ranked]
    values = [val for _, val in ranked]
    colors = ["#E24A33" if v >= 0 else "#348ABD" for v in values]

    # Horizontal bars are drawn bottom-up, so reverse to put the strongest on top.
    fig = go.Figure(
        data=[
            go.Bar(
                y=list(reversed(labels)),
                x=list(reversed(values)),
                orientation="h",
                marker=dict(
                    color=list(reversed(colors)),
                    line=dict(color="#f5f5f5", width=1.5),
                ),
                text=[f"{v:.3f}" for v in reversed(values)],
                textposition="outside",
                cliponaxis=False,
            )
        ]
    )
    fig.add_vline(x=0, line_dash="dash", line_color="#8c8c8c", line_width=1)
    fig.update_layout(
        title=title,
        xaxis_title="Contribution",
        yaxis_title="Feature group",
        template="plotly_white",
        hovermode="y",
        legend=create_legend(method, order),
        margin=dict(l=20, r=20, t=70, b=20),
        height=max(360, 40 * len(labels)),
    )
    return fig


def plot_interaction_network(
    features: List[str],
    interactions: List[Tuple[Tuple[str, ...], float]],
    threshold: float = 0.1,
) -> go.Figure:
    """
    Visualize pairwise interactions as a network graph.

    Features are placed evenly on a unit circle; each edge is coloured by sign
    and its width scales with |value|.

    Args:
        features: The full ordered feature list.
        interactions: Pairwise interactions [((feat_i, feat_j), value), ...].
            Entries that are not pairs are ignored.
        threshold: Only show interactions whose absolute value is at least this threshold.

    Returns:
        Plotly Figure (network-style visualization).
    """
    if not interactions:
        return _empty_figure("No pairwise interactions available")

    filtered = [
        (tuple(feats), value)
        for feats, value in interactions
        if len(feats) == 2 and abs(value) >= threshold
    ]
    if not filtered:
        return _empty_figure(f"No interactions exceed threshold ({threshold})")

    angles = np.linspace(0.0, 2 * np.pi, max(len(features), 1), endpoint=False)
    positions = {
        feat: (float(np.cos(theta)), float(np.sin(theta)))
        for feat, theta in zip(features, angles)
    }
    max_abs = max(abs(val) for _, val in filtered) or 1.0

    traces = []
    for (feat_a, feat_b), value in filtered:
        if feat_a not in positions or feat_b not in positions:
            continue
        (x0, y0), (x1, y1) = positions[feat_a], positions[feat_b]
        # "↔" rather than "<->": Plotly parses "<...>" in hover text as markup.
        label = f"{feat_a} ↔ {feat_b}: {value:.3f}"
        traces.append(
            go.Scatter(
                x=[x0, x1],
                y=[y0, y1],
                mode="lines",
                line=dict(
                    color="#d73027" if value >= 0 else "#4575b4",
                    width=1 + 4 * abs(value) / max_abs,
                ),
                hoverinfo="text",
                text=[label, label],
                showlegend=False,
            )
        )

    node_names = list(positions)
    node_trace = go.Scatter(
        x=[positions[f][0] for f in node_names],
        y=[positions[f][1] for f in node_names],
        mode="markers+text",
        marker=dict(size=14, color="#4a4a4a", line=dict(width=2, color="#ffffff")),
        text=[format_feature_label(f, 18) for f in node_names],
        hovertext=node_names,  # full names on hover; labels on the plot are truncated
        textposition="bottom center",
        hoverinfo="text",
    )

    fig = go.Figure(data=traces + [node_trace])
    fig.update_layout(
        title="Pairwise interaction network",
        xaxis=dict(visible=False),
        # Equal axis scaling keeps the node circle round.
        yaxis=dict(visible=False, scaleanchor="x", scaleratio=1),
        template="plotly_white",
        showlegend=False,
        margin=dict(l=20, r=20, t=60, b=20),
    )
    return fig


def plot_interaction_matrix(
    features: List[str],
    interactions: List[Tuple[Tuple[int, int], float]],
) -> go.Figure:
    """
    Visualize pairwise interactions as a symmetric matrix heatmap.

    Args:
        features: Ordered feature list (labels for axes).
        interactions: List of ((i, j), value) where i/j are feature indices.
            Out-of-range indices and keys that are not pairs are ignored.

    Returns:
        Plotly Figure (heatmap).
    """
    n = len(features)
    if n == 0:
        return _empty_figure("Interaction matrix (no features)")

    matrix = np.zeros((n, n), dtype=float)
    for key, value in interactions:
        if len(key) != 2:
            continue
        i, j = key
        if 0 <= i < n and 0 <= j < n:
            matrix[i, j] = value
            matrix[j, i] = value

    # Numeric axes with tick text: truncated labels can collide, and Plotly
    # would merge identical category labels into one row/column.
    ticks = list(range(n))
    labels = [format_feature_label(f, 18) for f in features]
    hover_names = [[[features[r], features[c]] for c in ticks] for r in ticks]

    heatmap = go.Heatmap(
        z=matrix,
        x=ticks,
        y=ticks,
        customdata=hover_names,
        colorscale=get_color_scale("shapley"),
        colorbar=dict(title="Interaction value"),
        hovertemplate="%{customdata[0]} vs %{customdata[1]}<br>value=%{z:.2e}<extra></extra>",
    )
    fig = go.Figure(data=[heatmap])
    fig.update_layout(
        title="Pairwise interaction matrix",
        xaxis=dict(side="top", tickmode="array", tickvals=ticks, ticktext=labels),
        yaxis=dict(autorange="reversed", tickmode="array", tickvals=ticks, ticktext=labels),
        template="plotly_white",
        margin=dict(l=80, r=20, t=80, b=80),
    )
    return fig


def plot_3rd_order_interactions(
    interactions: List[Tuple[Tuple[str, ...], float]],
    top_k: int = 5,
) -> go.Figure:
    """
    Visualize third-order (triplet) interactions.

    Draws one vertical bar per triplet for the top-k triplets by absolute
    value, coloured by sign.

    Args:
        interactions: Triplet interactions [((f1, f2, f3), value), ...].
        top_k: Number of top triplets to plot; negative values are treated as 0.

    Returns:
        Plotly Figure.
    """
    if not interactions:
        return _empty_figure("No third-order interactions available")

    ranked = sorted(interactions, key=lambda item: abs(item[1]), reverse=True)[: max(top_k, 0)]
    full_names = [" · ".join(feats) for feats, _ in ranked]
    labels = [format_feature_label(name, max_length=40) for name in full_names]
    values = [val for _, val in ranked]
    colors = ["#d73027" if v >= 0 else "#4575b4" for v in values]

    fig = go.Figure(
        data=[
            go.Bar(
                x=labels,
                y=values,
                marker=dict(color=colors),
                text=[f"{v:.3f}" for v in values],
                hovertext=full_names,  # axis labels are truncated
                textposition="auto",
            )
        ]
    )
    fig.update_layout(
        title=f"Top {len(labels)} third-order interactions",
        xaxis_title="Feature triplet",
        yaxis_title="Interaction value",
        template="plotly_white",
        margin=dict(l=60, r=20, t=60, b=100),
    )
    return fig


def _token_colors(value: float, max_abs: float) -> Tuple[str, str]:
    """Map a signed score to ``(fill, stroke)`` CSS colours.

    Positive scores blend from a pale neutral towards red, negative towards
    blue, proportionally to ``|value| / max_abs`` (clipped to 1). A
    non-positive *max_abs* yields a fixed neutral pair.
    """
    if max_abs <= 0:
        return "rgba(229, 226, 240, 0.7)", "rgba(193, 189, 209, 0.9)"
    norm = max(-1.0, min(1.0, value / max_abs))
    if norm >= 0:
        base = (222, 86, 61)
    else:
        base = (47, 128, 237)
        norm = -norm
    neutral = (245, 242, 252)
    r = int(round(neutral[0] + (base[0] - neutral[0]) * norm))
    g = int(round(neutral[1] + (base[1] - neutral[1]) * norm))
    b = int(round(neutral[2] + (base[2] - neutral[2]) * norm))
    alpha = 0.35 + 0.4 * norm
    return f"rgba({r}, {g}, {b}, {alpha:.3f})", f"rgb({r}, {g}, {b})"


def _pairwise_edges(
    features: List[str],
    pairwise: List[Tuple[Tuple[str, ...], float]],
) -> List[Tuple[int, int, float]]:
    """Map name-keyed pairwise interactions to de-duplicated index edges.

    Entries that are not pairs, reference unknown features, or link a feature
    to itself are dropped. When the same unordered pair occurs more than once,
    the entry with the largest |value| is kept. Edges are returned sorted by
    descending |value| (ties keep input order).
    """
    feature_index = {feat: idx for idx, feat in enumerate(features)}
    candidates: List[Tuple[int, int, float]] = []
    for feats, val in pairwise:
        if len(feats) != 2:
            continue
        a_idx = feature_index.get(feats[0])
        b_idx = feature_index.get(feats[1])
        if a_idx is None or b_idx is None or a_idx == b_idx:
            continue
        candidates.append((a_idx, b_idx, float(val)))
    candidates.sort(key=lambda item: abs(item[2]), reverse=True)

    seen = set()
    edges: List[Tuple[int, int, float]] = []
    for a_idx, b_idx, val in candidates:
        key = (a_idx, b_idx) if a_idx < b_idx else (b_idx, a_idx)
        if key in seen:
            continue
        seen.add(key)
        edges.append((a_idx, b_idx, val))
    return edges


def create_interaction_token_view(
    features: List[str],
    feature_values: List[float],
    pairwise: List[Tuple[Tuple[str, ...], float]],
    method: str = "shapley",
    max_links: int = 5,
    layout: str = "token",
) -> str:
    """
    Render token interactions as self-contained HTML with two layers:

    1) A primary view that is always visible and needs no JavaScript: token
       chips (``layout="token"``, the default) or one row per sentence with
       connecting arcs (``layout="sentence"``).
    2) An interactive network (needs JS; a static SVG is shown otherwise).

    Args:
        features: Ordered token/sentence labels.
        feature_values: Per-feature attribution scores (list or array);
            missing trailing entries are treated as 0.0.
        pairwise: Pairwise interactions [((feat_a, feat_b), value), ...];
            entries that are not pairs are ignored.
        method: Attribution method label used in captions.
        max_links: Maximum partners listed per chip (token layout only).
        layout: ``"token"`` or ``"sentence"``; other values use the token layout.

    Returns:
        An HTML string.
    """
    if not features:
        return '<div class="interaction-token-view"><div class="itv-empty">No tokens available.</div></div>'

    # `is not None` rather than truthiness: numpy arrays raise on bool().
    values = list(feature_values) if feature_values is not None else []
    if len(values) < len(features):
        values.extend([0.0] * (len(features) - len(values)))
    pairwise = list(pairwise or [])  # iterated more than once below

    # Partner lookup for the always-visible views, strongest first.
    adjacency: Dict[str, List[Tuple[str, float]]] = defaultdict(list)
    for feats, val in pairwise:
        if len(feats) != 2:
            continue
        a, b = feats
        adjacency[a].append((b, float(val)))
        adjacency[b].append((a, float(val)))
    for partners in adjacency.values():
        partners.sort(key=lambda item: abs(item[1]), reverse=True)

    edges = _pairwise_edges(features, pairwise)
    # NOTE(review): with no real pairwise edges, neighbouring features are linked
    # with the mean of their attributions so the network isn't empty. These
    # synthetic scores are not interaction values — confirm this is intended.
    if not edges and len(features) > 1:
        edges = [
            (i, i + 1, float(0.5 * (values[i] + values[i + 1])))
            for i in range(len(features) - 1)
        ]

    if layout == "sentence":
        primary_html = _render_sentence_link_view(features, values, adjacency, pairwise, method)
    else:
        max_abs = max((abs(v) for v in values), default=0.0) or 1.0
        primary_html = _render_token_chip_view(
            features, values, adjacency, method, max_abs, max_links,
        )
    return primary_html + _render_interaction_network(features, values, edges, method)


def _wrap_label(text: str, max_len: int = 24, max_lines: int = 2) -> List[str]:
    """Greedy word-wrap *text* into at most *max_lines* lines of ~*max_len* chars.

    Lines longer than *max_len* (single long words) are ellipsized, and the
    last line ends in "…" when text was dropped.
    """
    lines: List[str] = []
    current = ""
    for word in text.split():
        if len(current) + len(word) + 1 <= max_len:
            current = (current + " " + word).strip()
        else:
            if current:
                lines.append(current)
            current = word
    if current:
        lines.append(current)
    if not lines:
        lines = [text[:max_len]]
    truncated = len(lines) > max_lines
    lines = [_ellipsize(line, max_len) for line in lines[:max_lines]]
    if truncated and not lines[-1].endswith("…"):
        lines[-1] = lines[-1][: max_len - 1].rstrip() + "…"
    return lines


def _render_static_network_svg(
    features: List[str],
    values: List[float],
    edges_sorted: List[Tuple[int, int, float]],
) -> str:
    """Render a script-free radial SVG used when the interactive view cannot run.

    Draws the 10 features with the largest |value| on a ring and, among the 40
    strongest edges in *edges_sorted*, those that join two drawn features.
    """
    n = len(features)
    if n == 0:
        return '<div class="itv-empty">No interactions available.</div>'

    width, height = 820.0, 520.0
    top_nodes = sorted(range(n), key=lambda i: abs(values[i]), reverse=True)[: min(10, n)]
    radius = min(width, height) * 0.36
    cx, cy = width / 2.0, height / 2.0
    step = 2 * np.pi / max(len(top_nodes), 1)
    positions: Dict[int, Tuple[float, float]] = {}
    for rank, node_idx in enumerate(top_nodes):
        angle = rank * step - np.pi / 2  # start from the top
        positions[node_idx] = (
            cx + radius * float(np.cos(angle)),
            cy + radius * float(np.sin(angle)),
        )

    drawn_edges = [
        (a, b, val) for a, b, val in edges_sorted[:40] if a in positions and b in positions
    ]
    max_edge = max((abs(val) for _, _, val in drawn_edges), default=0.0) or 1.0
    max_node = max((abs(values[i]) for i in top_nodes), default=0.0) or 1.0

    parts: List[str] = []
    for a_idx, b_idx, val in drawn_edges:
        (x1, y1), (x2, y2) = positions[a_idx], positions[b_idx]
        color = "#d35400" if val >= 0 else "#3867d6"
        width_px = 1.2 + 3.5 * (abs(val) / max_edge)
        parts.append(
            f'<line x1="{x1:.1f}" y1="{y1:.1f}" x2="{x2:.1f}" y2="{y2:.1f}" '
            f'stroke="{color}" stroke-width="{width_px:.2f}" stroke-opacity="0.75">'
            f"<title>{escape(features[a_idx])} ↔ {escape(features[b_idx])}: {val:+.3f}</title>"
            "</line>"
        )
    for node_idx in top_nodes:
        x, y = positions[node_idx]
        fill, stroke = _token_colors(values[node_idx], max_node)
        tspans = "".join(
            f'<tspan x="{x:.1f}" dy="{0 if line_no == 0 else 14}">{escape(line)}</tspan>'
            for line_no, line in enumerate(_wrap_label(features[node_idx]))
        )
        parts.append(
            f'<g class="network-node"><title>{escape(features[node_idx])}</title>'
            f'<circle cx="{x:.1f}" cy="{y:.1f}" r="20" fill="{fill}" stroke="{stroke}" stroke-width="2" />'
            f'<text x="{x:.1f}" y="{y + 4:.1f}" text-anchor="middle" font-size="11" fill="#3b3355">'
            f"{values[node_idx]:+0.2f}</text>"
            f'<text x="{x:.1f}" y="{y + 36:.1f}" text-anchor="middle" font-size="12" '
            f'font-weight="600" fill="#2b2540">{tspans}</text>'
            "</g>"
        )

    return (
        '<div class="inet-static">'
        '<div class="inet-static-caption">Static view (scripts blocked)</div>'
        f'<svg viewBox="0 0 {width:.0f} {height:.0f}" width="100%" role="img" '
        'aria-label="Static interaction network">'
        + "".join(parts)
        + "</svg></div>"
    )


def _render_interaction_network(
    features: List[str],
    values: List[float],
    edges: List[Tuple[int, int, float]],
    method: str,
) -> str:
    """
    Render the interactive SVG network view plus a static fallback.

    - Click a token/sentence to focus it in the center and hide unrelated nodes.
    - Edge labels carry interaction scores.
    - Click empty space to reset the full view.

    The static SVG is visible by default; the embedded script hides it and
    reveals the interactive SVG once it runs, so the view still shows
    something when scripts are blocked. Only the strongest
    ``_MAX_NETWORK_EDGES`` edges are serialized.
    """
    node_values = [
        float(values[idx]) if idx < len(values) else 0.0 for idx in range(len(features))
    ]
    max_abs_value = max((abs(v) for v in node_values), default=0.0) or 1.0

    nodes_payload = []
    for idx, token in enumerate(features):
        fill, stroke = _token_colors(node_values[idx], max_abs_value)
        nodes_payload.append(
            {
                "id": idx,
                "label": str(token),
                "value": node_values[idx],
                "fill": fill,
                "stroke": stroke,
            }
        )

    edges_sorted = sorted(edges, key=lambda item: abs(item[2]), reverse=True)
    edges_payload = [
        {"source": int(a), "target": int(b), "weight": float(val)}
        for a, b, val in edges_sorted[:_MAX_NETWORK_EDGES]
    ]

    graph_id = f"interaction-net-{uuid.uuid4().hex[:8]}"
    data_json = _json_for_script({"nodes": nodes_payload, "edges": edges_payload})
    # Substitute the id first so label text can never be mistaken for a placeholder.
    script = _NETWORK_SCRIPT.replace("__GRAPH_ID__", graph_id).replace("__GRAPH_DATA__", data_json)
    method_label = escape(method.title())
    static_html = _render_static_network_svg(features, node_values, edges_sorted)

    return (
        _NETWORK_VIEW_STYLE
        + f'<div class="interaction-network" id="{graph_id}-wrap">'
        + f'<div class="inet-caption">{method_label} interaction network · '
        "click a node to focus on it; click empty space to reset.</div>"
        + f'<div class="inet-status" id="{graph_id}-status">'
        "Interactive view unavailable; showing the static fallback.</div>"
        + f'<svg id="{graph_id}" class="inet-svg" xmlns="http://www.w3.org/2000/svg" '
        'style="display:none;width:100%;height:540px;" role="img" '
        f'aria-label="Interactive {method_label} interaction network"></svg>'
        + f'<div id="{graph_id}-fallback">{static_html}</div>'
        + f"<script>{script}</script>"
        + "</div>"
    )


def _render_token_chip_view(
    features: List[str],
    values: List[float],
    adjacency: Dict[str, List[Tuple[str, float]]],
    method: str,
    max_abs: float,
    max_links: int,
) -> str:
    """Render one collapsible chip per token listing its strongest partners.

    Uses native ``<details>``/``<summary>`` so no JavaScript is needed; the
    first chip starts expanded. At most *max_links* partners are listed per chip.
    """
    chips: List[str] = []
    for idx, token in enumerate(features):
        value = values[idx]
        bg, border = _token_colors(value, max_abs)
        partners = adjacency.get(token, [])[: max(max_links, 0)]

        if partners:
            items: List[str] = []
            for partner, val in partners:
                sign_cls = "itv-pos" if val >= 0 else "itv-neg"
                items.append(
                    f'<li><span class="itv-partner">{escape(_ellipsize(str(partner), 80))}</span> '
                    f'<span class="{sign_cls}">{val:+.3f}</span></li>'
                )
            body = f'<ul class="itv-partners">{"".join(items)}</ul>'
        else:
            body = '<div class="itv-empty">No interactions recorded.</div>'

        open_attr = " open" if idx == 0 else ""
        chips.append(
            f'<details class="itv-chip" style="background:{bg};border-color:{border}"{open_attr}>'
            f'<summary title="{escape(token, quote=True)}">'
            f'<span class="itv-score">{value:+.2f}</span>'
            # Shorten before escaping so an HTML entity is never cut in half.
            f'<span class="itv-label">{escape(_ellipsize(token, 80), quote=True)}</span>'
            '<span class="itv-hint">Click to reveal linked tokens</span>'
            "</summary>"
            f"{body}"
            "</details>"
        )

    chips_html = "".join(chips) or '<div class="itv-empty">No tokens available.</div>'
    return "".join([
        _TOKEN_VIEW_STYLE,
        '<div class="interaction-token-view">',
        f'<div class="itv-caption">{escape(method.title())} pairwise interactions · '
        "click a token to inspect its strongest partners.</div>",
        f'<div class="itv-chips">{chips_html}</div>',
        "</div>",
    ])


def _render_sentence_link_view(
    features: List[str],
    values: List[float],
    adjacency: Dict[str, List[Tuple[str, float]]],
    pairwise: List[Tuple[Tuple[str, ...], float]],
    method: str,
    row_gap: float = 70.0,
    max_edges: int = 14,
) -> str:
    """Render one row per sentence with SVG arcs linking interacting rows.

    Each row has a fixed height of *row_gap* px so that arc endpoints (row
    centres) line up with the rows. At most *max_edges* of the strongest
    de-duplicated pairs are drawn; each row also lists its two strongest partners.
    """
    trimmed = _pairwise_edges(features, pairwise)[: max(max_edges, 0)]
    max_edge = max((abs(val) for _, _, val in trimmed), default=0.0) or 1.0
    canvas_height = row_gap * len(features) + 20
    value_max = max((abs(v) for v in values), default=0.0) or 1.0

    rows: List[str] = []
    for idx, token in enumerate(features):
        value = values[idx]
        bg, border = _token_colors(value, value_max)
        badges = "".join(
            f'<span class="itv-badge">{escape(_ellipsize(str(partner), 24))} {val:+.2f}</span>'
            for partner, val in adjacency.get(token, [])[:2]
        )
        links_html = f'<div class="itv-links">{badges}</div>' if badges else ""
        rows.append(
            f'<div class="itv-row" style="height:{row_gap:.0f}px">'
            f'<span class="itv-score" style="background:{bg};border-color:{border}">{value:+.2f}</span>'
            f'<div class="itv-sentence" title="{escape(token, quote=True)}">{escape(token)}</div>'
            f"{links_html}"
            "</div>"
        )

    paths: List[str] = []
    for a_idx, b_idx, val in trimmed:
        y1 = row_gap * a_idx + row_gap / 2
        y2 = row_gap * b_idx + row_gap / 2
        color = "#d35400" if val >= 0 else "#3867d6"
        width = 1.5 + 3.0 * (abs(val) / max_edge)
        control = 120 + abs(a_idx - b_idx) * 12  # farther rows bulge further right
        paths.append(
            f'<path d="M 8 {y1:.1f} C {control} {y1:.1f}, {control} {y2:.1f}, 8 {y2:.1f}" '
            f'fill="none" stroke="{color}" stroke-width="{width:.2f}" stroke-opacity="0.8">'
            f"<title>{escape(features[a_idx])} ↔ {escape(features[b_idx])}: {val:+.3f}</title>"
            "</path>"
        )
        # The cubic's apex sits at x = 2 + 0.75 * control (t = 0.5).
        mid_y = (y1 + y2) / 2
        paths.append(
            f'<text x="{control * 0.75 + 6:.1f}" y="{mid_y + 4:.1f}" font-size="11" '
            f'fill="{color}">{val:+.1f}</text>'
        )

    if paths:
        arc_width = max(120 + abs(a - b) * 12 for a, b, _ in trimmed) * 0.75 + 60
        anchors = "".join(
            f'<circle cx="8" cy="{row_gap * i + row_gap / 2:.1f}" r="4" fill="#5c4c78" />'
            for i in range(len(features))
        )
        links_block = (
            f'<svg class="itv-arcs" width="{arc_width:.0f}" height="{canvas_height:.0f}" '
            f'viewBox="0 0 {arc_width:.0f} {canvas_height:.0f}" role="img" '
            'aria-label="Pairwise interaction arcs">'
            + "".join(paths)
            + anchors
            + "</svg>"
        )
    else:
        links_block = '<div class="itv-empty">No pairwise arcs available.</div>'

    rows_html = "".join(rows)
    return "".join([
        _TOKEN_VIEW_STYLE,
        '<div class="interaction-token-view">',
        f'<div class="itv-caption">{escape(method.title())} pairwise interactions · '
        "one sentence per row.</div>",
        '<div class="itv-grid">',
        f'<div class="itv-rows">{rows_html}</div>',
        links_block,
        "</div>",
        "</div>",
    ])