""" Plotting utilities for visualizing higher-order feature interactions in a text-interpretability web app. Inputs are assumed to come from attribution utilities such as `shapley_interactions(...)` or `banzhaf_interactions(...)`. Outputs are Plotly Figure objects that can be rendered directly in Gradio/UI. """ import json import math import uuid from collections import defaultdict from html import escape import plotly.graph_objects as go import numpy as np from typing import List, Tuple, Dict, Optional try: # optional dependency from shapiq.interaction_values import InteractionValues # type: ignore from shapiq.plot import bar_plot # type: ignore except Exception: # pragma: no cover InteractionValues = None bar_plot = None from .utils import format_feature_label, get_color_scale, create_legend, matplotlib_to_plotly _TOKEN_VIEW_STYLE = """ """ def _strip_occurrence_suffix(text: str) -> str: text = text or "" if text.endswith(")") and " (" in text: base, _, tail = text.rpartition(" (") if tail[:-1].isdigit(): return base return text _NETWORK_VIEW_STYLE = """ """ def _interactions_to_shapiq( interactions: List[Tuple[Tuple[str, ...], float]], method: str, order: int, ) -> Tuple[Optional["InteractionValues"], List[str]]: if InteractionValues is None or not interactions: return None, [] feature_index: Dict[str, int] = {} feature_names: List[str] = [] def _idx(name: str) -> int: if name not in feature_index: feature_index[name] = len(feature_names) feature_names.append(name) return feature_index[name] lookup: Dict[Tuple[int, ...], int] = {} values: List[float] = [] for feats, value in interactions: if not feats: continue idx_tuple = tuple(sorted(_idx(f) for f in feats)) lookup[idx_tuple] = len(values) values.append(float(value)) if not values: return None, feature_names if order == 1: index = "SV" if method == "shapley" else ("IV" if method == "influence" else "BV") else: index = "SII" if method == "shapley" else ("III" if method == "influence" else "BII") min_order = 1 if order <= 1 else order iv = InteractionValues( values=np.array(values, dtype=float), index=index, max_order=order, n_players=len(feature_names), min_order=min_order, interaction_lookup=lookup, estimated=False, baseline_value=0.0, ) return iv, feature_names def plot_top_interactions( interactions: List[Tuple[Tuple[str, ...], float]], top_k: int = 10, order: int = 2, method: str = "shapley" ) -> go.Figure: """ Visualize the top-k interactions as a bar chart. Args: interactions: List of interactions from shapley_interactions/banzhaf_interactions. Each item is ((feature_name, ...), value). top_k: Number of top interactions to display. order: Interaction order (2 or 3). method: Attribution method label ("shapley" or "banzhaf"). Returns: Plotly Figure. Example: # From attribution mobius = run_proxyspex(set_function, features, max_order=3) interactions = shapley_interactions(mobius, order=2, top_k=10) fig = plot_top_interactions(interactions, top_k=10, order=2, method="shapley") """ if not interactions: return go.Figure().update_layout( title="No interactions available", template="plotly_white" ) ranked = sorted(interactions, key=lambda item: abs(item[1]), reverse=True)[:top_k] if bar_plot is not None: iv, feature_names = _interactions_to_shapiq(ranked, method, order) if iv is not None: ax = bar_plot( [iv], feature_names=feature_names, show=False, abbreviate=False, max_display=top_k, global_plot=True, plot_base_value=True, ) fig = ax.figure if ax is not None else None if fig is not None: return matplotlib_to_plotly( fig, title=f"Top {len(ranked)} order-{order} {method.title()} interactions", ) labels = [ format_feature_label(" · ".join(feats), max_length=50) for feats, _ in ranked ] values = [val for _, val in ranked] # Influence scores are always non-negative (squared Fourier coefficients) is_influence = method.lower() == "influence" if is_influence: values = [abs(v) for v in values] # Create color scale based on value magnitude (importance) max_abs_val = max(abs(v) for v in values) if values else 1.0 def get_color(val: float) -> str: """Map value to color: purple = positive, red = negative (matches legend).""" norm = abs(val) / max_abs_val if max_abs_val > 0 else 0.5 if val >= 0: # Positive: Lavender -> deep violet r = int(76 + (214 - 76) * (1 - norm)) g = int(29 + (190 - 29) * (1 - norm)) b = int(149 + (255 - 149) * (1 - norm)) else: # Negative: Light rose -> deep red r = int(139 + (255 - 139) * (1 - norm)) g = int(0 + (160 - 0) * (1 - norm)) b = int(0 + (122 - 0) * (1 - norm)) return f"rgb({r}, {g}, {b})" colors = [get_color(v) for v in values] fig = go.Figure( data=[ go.Bar( y=list(reversed(labels)), x=list(reversed(values)), orientation="h", marker=dict( color=list(reversed(colors)), line=dict(color="#f5f5f5", width=2), ), text=[f"{v:.3f}" for v in reversed(values)], textposition="outside", textfont=dict(size=14, weight='bold'), cliponaxis=False, hovertemplate="%{y}
Value: %{x:.4f}", ) ] ) if not is_influence: fig.add_vline(x=0, line_dash="dash", line_color="#8c8c8c", line_width=2) x_axis_label = "Influence Magnitude" if is_influence else "Contribution" annotation_text = ( "Influence scores are always non-negative (squared Fourier coefficients). Color intensity shows magnitude." if is_influence else "Color intensity shows interaction strength: red = negative, purple = positive." ) fig.update_layout( title=dict( text=f"Top {len(labels)} order-{order} {method.title()} interactions", font=dict(size=18, weight='bold') ), xaxis_title=dict(text=x_axis_label, font=dict(size=14)), yaxis_title=None, xaxis=dict( tickfont=dict(size=12), gridcolor="rgba(148, 163, 184, 0.18)", zerolinecolor="rgba(148, 163, 184, 0.28)", rangemode="tozero" if is_influence else "normal", ), yaxis=dict(tickfont=dict(size=13), automargin=True), template="none", paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(0,0,0,0)", font=dict(color="#2d1f4a"), hovermode="y", legend=create_legend(method, order), margin=dict(l=140, r=20, t=70, b=120), height=max(800, 70 * len(labels)), # Much larger: 800px minimum, 70px per bar annotations=[ dict( text=annotation_text, xref="paper", yref="paper", x=0.5, y=-0.08, showarrow=False, font=dict(size=11, color="#666"), xanchor='center', yanchor='top', ) ], ) return fig def plot_interaction_network( features: List[str], interactions: List[Tuple[Tuple[str, ...], float]], threshold: float = 0.1 ) -> go.Figure: """ Visualize pairwise interactions as a network graph. Args: features: The full ordered feature list. interactions: Pairwise interactions [((feat_i, feat_j), value), ...]. threshold: Only show interactions with absolute value greater than this threshold. Returns: Plotly Figure (network-style visualization). """ if not interactions: return go.Figure().update_layout( title="No pairwise interactions available", template="plotly_white" ) filtered = [ item for item in interactions if abs(item[1]) >= threshold ] if not filtered: return go.Figure().update_layout( title=f"No interactions exceed threshold ({threshold})", template="plotly_white" ) n = len(features) angles = np.linspace(0, 2 * np.pi, max(n, 1), endpoint=False) positions = { feat: (float(np.cos(theta)), float(np.sin(theta))) for feat, theta in zip(features, angles) } max_abs = max(abs(val) for _, val in filtered) or 1.0 traces = [] for (feat_a, feat_b), value in filtered: if feat_a not in positions or feat_b not in positions: continue (x0, y0), (x1, y1) = positions[feat_a], positions[feat_b] color = "#d73027" if value >= 0 else "#4575b4" width = 1 + 4 * abs(value) / max_abs label = f"{feat_a} <-> {feat_b}: {value:.3f}" traces.append( go.Scatter( x=[x0, x1], y=[y0, y1], mode="lines", line=dict(color=color, width=width), hoverinfo="text", text=[label, label], showlegend=False, ) ) node_trace = go.Scatter( x=[positions[f][0] for f in features if f in positions], y=[positions[f][1] for f in features if f in positions], mode="markers+text", marker=dict(size=14, color="#4a4a4a", line=dict(width=2, color="#ffffff")), text=[format_feature_label(f, 18) for f in features if f in positions], textposition="bottom center", hoverinfo="text", ) fig = go.Figure(data=traces + [node_trace]) fig.update_layout( title="Pairwise interaction network", xaxis=dict(visible=False), yaxis=dict(visible=False), template="plotly_white", showlegend=False, margin=dict(l=20, r=20, t=60, b=20), ) return fig def plot_interaction_matrix( features: List[str], interactions: List[Tuple[Tuple[int, int], float]] ) -> go.Figure: """ Visualize pairwise interactions as a matrix heatmap. Args: features: Ordered feature list (labels for axes). interactions: List of ((i, j), value) where i/j are feature indices. Returns: Plotly Figure (heatmap). """ n = len(features) if n == 0: return go.Figure().update_layout( title="Interaction matrix (no features)", template="plotly_white" ) matrix = np.zeros((n, n), dtype=float) for (i, j), value in interactions: if 0 <= i < n and 0 <= j < n: matrix[i, j] = value matrix[j, i] = value heatmap = go.Heatmap( z=matrix, x=[format_feature_label(f, 18) for f in features], y=[format_feature_label(f, 18) for f in features], colorscale=get_color_scale("shapley"), colorbar=dict(title="Interaction value"), hovertemplate=" %{y} vs %{x}
value=%{z:.2e}", ) fig = go.Figure(data=[heatmap]) fig.update_layout( title="Pairwise interaction matrix", xaxis=dict(side="top"), yaxis=dict(autorange="reversed"), template="plotly_white", margin=dict(l=80, r=20, t=80, b=80), ) return fig def plot_3rd_order_interactions( interactions: List[Tuple[Tuple[str, ...], float]], top_k: int = 5 ) -> go.Figure: """ Visualize third-order (triplet) interactions. Uses grouped or stacked bars to show the top-k triplets. Args: interactions: Triplet interactions [((f1, f2, f3), value), ...]. top_k: Number of top triplets to plot. Returns: Plotly Figure. """ if not interactions: return go.Figure().update_layout( title="No third-order interactions available", template="plotly_white" ) ranked = sorted(interactions, key=lambda item: abs(item[1]), reverse=True)[:top_k] labels = [ format_feature_label(" · ".join(feats), max_length=40) for feats, _ in ranked ] values = [val for _, val in ranked] colors = ["#d73027" if v >= 0 else "#4575b4" for v in values] fig = go.Figure( data=[ go.Bar( x=labels, y=values, marker=dict(color=colors), text=[f"{v:.3f}" for v in values], textposition="auto", ) ] ) fig.update_layout( title=f"Top {len(labels)} third-order interactions", xaxis_title="Feature triplet", yaxis_title="Interaction value", template="plotly_white", margin=dict(l=60, r=20, t=60, b=100), ) return fig def _token_colors(value: float, max_abs: float) -> Tuple[str, str]: if max_abs <= 0: return "rgba(229, 226, 240, 0.7)", "rgba(193, 189, 209, 0.9)" norm = max(-1.0, min(1.0, value / max_abs)) if norm >= 0: base = (222, 86, 61) else: base = (47, 128, 237) norm = -norm neutral = (245, 242, 252) r = int(round(neutral[0] + (base[0] - neutral[0]) * norm)) g = int(round(neutral[1] + (base[1] - neutral[1]) * norm)) b = int(round(neutral[2] + (base[2] - neutral[2]) * norm)) alpha = 0.35 + 0.4 * norm return f"rgba({r}, {g}, {b}, {alpha:.3f})", f"rgb({r}, {g}, {b})" def _wrap(text: str, max_len: int = 20) -> List[str]: """Wrap text into lines of max_len characters""" if not text or len(text) <= max_len: return [text] words = text.split() lines = [] current_line = [] current_length = 0 for word in words: word_len = len(word) if current_length + word_len + len(current_line) > max_len: if current_line: lines.append(' '.join(current_line)) current_line = [word] current_length = word_len else: # Single word longer than max_len lines.append(word[:max_len]) current_line = [] current_length = 0 else: current_line.append(word) current_length += word_len if current_line: lines.append(' '.join(current_line)) return lines if lines else [text[:max_len]] def create_interaction_token_view( features: List[str], feature_values: List[float], pairwise: List[Tuple[Tuple[str, ...], float]], method: str = "shapley", max_links: int = 5, layout: str = "token", ) -> str: """ Render token interactions as a lightweight chip list. """ if not features: return "
No tokens available.
" values = list(feature_values) if feature_values else [] if len(values) < len(features): values.extend([0.0] * (len(features) - len(values))) # partner lookup for the always-visible chip list adjacency: Dict[str, List[Tuple[str, float]]] = defaultdict(list) for feats, val in pairwise: if len(feats) != 2: continue a, b = feats adjacency[a].append((b, float(val))) adjacency[b].append((a, float(val))) for key in adjacency: adjacency[key].sort(key=lambda item: abs(item[1]), reverse=True) feature_index = {feat: idx for idx, feat in enumerate(features)} edges: List[Tuple[int, int, float]] = [] for feats, val in pairwise: if len(feats) != 2: continue a_idx = feature_index.get(feats[0]) b_idx = feature_index.get(feats[1]) if a_idx is None or b_idx is None or a_idx == b_idx: continue edges.append((a_idx, b_idx, float(val))) # Fallback: if no edges are provided, synthesize simple neighbor links so the UI isn't empty. if not edges and len(features) > 1: for i in range(len(features) - 1): score = 0.5 * (values[i] + values[i + 1]) edges.append((i, i + 1, float(score))) max_abs = max((abs(v) for v in values), default=0.0) or 1.0 chips_html = _render_token_chip_view( features, values, adjacency, method, max_abs, max_links, ) return chips_html def _render_interaction_network( features: List[str], values: List[float], edges: List[Tuple[int, int, float]], method: str, ) -> go.Figure: """ Interactive Plotly network view showing pairwise feature interactions. Args: features: List of token/feature labels values: Attribution values for each feature edges: List of (source_idx, target_idx, interaction_weight) tuples method: Attribution method name (for display) Returns: Plotly Figure with interactive network graph """ if not features or not edges: # Return empty figure with message fig = go.Figure() fig.update_layout( title=f"{method.title()} pairwise interactions", annotations=[{ "text": "No interaction data available", "xref": "paper", "yref": "paper", "x": 0.5, "y": 0.5, "showarrow": False, "font": {"size": 16, "color": "#666"} }], template="plotly_white", height=600, ) return fig # Prepare node data max_abs_value = max((abs(v) for v in values), default=0.0) or 1.0 # Sort edges by absolute weight and limit to top 60 edges_sorted = sorted(edges, key=lambda item: abs(item[2]), reverse=True)[:60] # Calculate circular layout positions n = len(features) angle_step = 2 * math.pi / n radius = 100 node_positions = {} for idx in range(n): angle = idx * angle_step - math.pi / 2 # Start from top node_positions[idx] = { 'x': radius * math.cos(angle), 'y': radius * math.sin(angle) } # Create edge traces edge_traces = [] max_edge_weight = max((abs(w) for _, _, w in edges_sorted), default=0.0) or 1.0 for source_idx, target_idx, weight in edges_sorted: x0, y0 = node_positions[source_idx]['x'], node_positions[source_idx]['y'] x1, y1 = node_positions[target_idx]['x'], node_positions[target_idx]['y'] # Color based on sign color = '#d35400' if weight >= 0 else '#3867d6' # Width based on magnitude width = 0.5 + 4.5 * (abs(weight) / max_edge_weight) edge_trace = go.Scatter( x=[x0, x1, None], y=[y0, y1, None], mode='lines', line=dict(color=color, width=width), opacity=0.7, hoverinfo='text', hovertext=( f"{_strip_occurrence_suffix(features[source_idx])} ↔ " f"{_strip_occurrence_suffix(features[target_idx])}
" f"Interaction: {weight:+.3f}" ), showlegend=False, ) edge_traces.append(edge_trace) # Create node trace node_x = [] node_y = [] node_text = [] node_colors = [] node_sizes = [] for idx in range(n): pos = node_positions[idx] node_x.append(pos['x']) node_y.append(pos['y']) # Shorten label if too long label = _strip_occurrence_suffix(features[idx]) if len(label) > 30: label = label[:27] + "..." value = values[idx] node_text.append(f"{label}
Value: {value:+.3f}") # Color based on attribution value fill_color, _ = _token_colors(value, max_abs_value) node_colors.append(fill_color) # Size based on absolute value size = 20 + 25 * (abs(value) / max_abs_value if max_abs_value > 0 else 0) node_sizes.append(size) node_trace = go.Scatter( x=node_x, y=node_y, mode='markers+text', marker=dict( size=node_sizes, color=node_colors, line=dict(color='#5c4c78', width=2), ), text=[f"{_strip_occurrence_suffix(features[i])[:20]}" for i in range(n)], # Short labels on nodes textposition="top center", textfont=dict(size=10, color='#1f1533'), hoverinfo='text', hovertext=node_text, showlegend=False, ) # Create figure fig = go.Figure(data=edge_traces + [node_trace]) # Update layout fig.update_layout( title=dict( text=f"{method.title()} pairwise interactions network", font=dict(size=16, color='#1f1533') ), showlegend=False, hovermode='closest', margin=dict(b=20, l=20, r=20, t=60), xaxis=dict(showgrid=False, zeroline=False, showticklabels=False), yaxis=dict(showgrid=False, zeroline=False, showticklabels=False), plot_bgcolor='rgba(250, 248, 255, 0.5)', paper_bgcolor='white', height=600, template="plotly_white", ) # Add annotation with instructions fig.add_annotation( text="💡 Hover over nodes and edges to see details | Zoom and pan to explore", xref="paper", yref="paper", x=0.5, y=-0.05, showarrow=False, font=dict(size=11, color='#666'), xanchor='center', ) return fig def _render_token_chip_view( features: List[str], values: List[float], adjacency: Dict[str, List[Tuple[str, float]]], method: str, max_abs: float, max_links: int, ) -> str: def _shorten(text: str, max_len: int = 80) -> str: text = text or "" return text if len(text) <= max_len else text[: max_len - 1] + "…" is_influence = method.lower() == "influence" tokens_html: List[str] = [] for idx, token in enumerate(features): value = abs(values[idx]) if is_influence else values[idx] bg, border = _token_colors(value, max_abs) display_token = _strip_occurrence_suffix(token) label_full = escape(display_token, quote=True) label_short = escape(_shorten(display_token), quote=True) partners = adjacency.get(token, [])[:max_links] partner_total = len(adjacency.get(token, [])) partner_label = "interaction" if partner_total == 1 else "interactions" body: List[str] = [] if partners: body.append('") else: body.append("
No interactions recorded.
") score_display = f"{value:.2f}" if is_influence else f"{value:+.2f}" open_attr = " open" if idx == 0 else "" tokens_html.append( f'
' "" f'{score_display}' f'{label_short}' f'{partner_total} {partner_label}' "" + "".join(body) + "
" ) return "".join([ _TOKEN_VIEW_STYLE, '
', '
' f'{escape(method.title())} pairwise interactions · click a token to inspect its strongest partners.' "
", '
', "".join(tokens_html) or "
No tokens available.
", "
", "
", ]) def _render_sentence_link_view( features: List[str], values: List[float], adjacency: Dict[str, List[Tuple[str, float]]], pairwise: List[Tuple[Tuple[str, ...], float]], method: str, row_gap: float = 70.0, max_edges: int = 14, ) -> str: feature_index = {feat: idx for idx, feat in enumerate(features)} edges: List[Tuple[int, int, float]] = [] for feats, val in pairwise: if len(feats) != 2: continue a_idx = feature_index.get(feats[0]) b_idx = feature_index.get(feats[1]) if a_idx is None or b_idx is None or a_idx == b_idx: continue edges.append((a_idx, b_idx, float(val))) edges.sort(key=lambda item: abs(item[2]), reverse=True) trimmed: List[Tuple[int, int, float]] = [] seen = set() for a_idx, b_idx, val in edges: key = tuple(sorted((a_idx, b_idx))) if key in seen: continue seen.add(key) trimmed.append((a_idx, b_idx, val)) if len(trimmed) >= max_edges: break max_edge = max((abs(val) for _, _, val in trimmed), default=0.0) or 1.0 canvas_height = row_gap * len(features) + 20 rows_html: List[str] = [] value_max = max((abs(v) for v in values), default=0.0) or 1.0 for idx, token in enumerate(features): value = values[idx] bg, border = _token_colors(value, value_max) label_text = escape(token) label_attr = escape(token, quote=True) partner_badges: List[str] = [] for partner, val in adjacency.get(token, [])[:2]: partner_badges.append( f"{escape(str(partner))} {val:+.2f}" ) links_html = "" if partner_badges: links_html = "" rows_html.append( f'
' f'{value:+.2f}' f'
{label_text}
' f"{links_html}" "
" ) path_elements: List[str] = [] for a_idx, b_idx, val in trimmed: y1 = row_gap * a_idx + row_gap / 2 y2 = row_gap * b_idx + row_gap / 2 color = "#d35400" if val >= 0 else "#3867d6" width = 1.5 + 3.0 * (abs(val) / max_edge) control = 120 + abs(a_idx - b_idx) * 12 path_elements.append( f'' ) mid_y = (y1 + y2) / 2 path_elements.append( f'{val:+.1f}' ) node_elements = [ f'' for idx in range(len(features)) ] if path_elements: links_block = ( f'' + "".join(path_elements + node_elements) + "" ) else: links_block = "
No pairwise arcs available.
" return "".join([ _TOKEN_VIEW_STYLE, '
', '
' f'{escape(method.title())} pairwise interactions · one sentence per row.' "
", '
', '
', "".join(rows_html), "
", f'", "
", "
", ])