from __future__ import annotations

from typing import Dict, List, Optional, Tuple

import plotly.graph_objects as go

from .utils import get_color_scale


def _normalize_row_offsets(token_rows: List[List[str]], row_offsets: List[int]) -> List[int]:
    """Return one global start index per token row.

    ``row_offsets`` is returned unchanged when it is non-empty and has exactly
    one entry per row. Otherwise the offsets are recomputed by laying the rows
    end to end (row 0 starts at 0, row 1 at ``len(row 0)``, and so on).
    """
    if row_offsets and len(row_offsets) == len(token_rows):
        return row_offsets
    offsets: List[int] = []
    cursor = 0
    for row in token_rows:
        offsets.append(cursor)
        cursor += len(row)
    return offsets


def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
    """Convert ``#rgb`` / ``#rrggbb`` (leading ``#`` optional) to an RGB triple.

    Raises:
        ValueError: if the first six characters are not valid hex digits.
    """
    color = color.lstrip("#")
    if len(color) == 3:
        # Expand shorthand: "abc" -> "aabbcc".
        color = "".join(ch * 2 for ch in color)
    return tuple(int(color[i : i + 2], 16) for i in (0, 2, 4))


def _parse_color(color: object) -> Tuple[int, int, int]:
    """Convert one colorscale colour into an ``(r, g, b)`` integer triple.

    Accepts ``rgb(r, g, b)`` and ``rgba(r, g, b, a)`` strings (the alpha
    channel is ignored, fractional channels are truncated) and hex strings.
    Anything else is stringified and parsed as hex.

    Raises:
        ValueError: if the colour cannot be parsed.
    """
    text = str(color).strip()
    if text.startswith("rgb"):
        # Slice between the parentheses rather than at a fixed offset so that
        # both "rgb(" and "rgba(" prefixes work.
        inner = text[text.index("(") + 1 : text.rindex(")")]
        channels = [part.strip() for part in inner.split(",")]
        return tuple(int(float(part)) for part in channels[:3])
    return _hex_to_rgb(text)


def _interpolate_color(left: Tuple[int, int, int], right: Tuple[int, int, int], t: float) -> str:
    """Linearly blend two RGB triples and format the result as ``rgb(r, g, b)``.

    ``t == 0`` yields ``left`` and ``t == 1`` yields ``right``; each channel is
    truncated to an integer.
    """
    r = int(left[0] + (right[0] - left[0]) * t)
    g = int(left[1] + (right[1] - left[1]) * t)
    b = int(left[2] + (right[2] - left[2]) * t)
    return f"rgb({r}, {g}, {b})"


def _colorscale_to_color(colorscale: List, t: float) -> str:
    """Sample ``colorscale`` at position ``t`` (clamped to ``[0, 1]``).

    ``colorscale`` is consumed as a sequence of ``(position, color)`` pairs
    ordered by position. An empty scale yields a neutral grey.
    """
    if not colorscale:
        return "rgb(200, 200, 200)"
    t = max(0.0, min(1.0, t))
    for idx in range(len(colorscale) - 1):
        left_pos, left_color = colorscale[idx]
        right_pos, right_color = colorscale[idx + 1]
        if t <= right_pos:
            # A zero-width segment would divide by zero; treat it as width 1.
            span = right_pos - left_pos or 1.0
            local_t = (t - left_pos) / span
            return _interpolate_color(
                _parse_color(left_color), _parse_color(right_color), local_t
            )
    # t lies beyond the last stop (or the scale has a single stop): use the
    # final colour. rgb-style strings are passed through untouched.
    tail = colorscale[-1][1]
    if isinstance(tail, str) and tail.startswith("rgb"):
        return tail
    tail_rgb = _hex_to_rgb(str(tail))
    return _interpolate_color(tail_rgb, tail_rgb, 0.0)


def _value_to_color(value: float, max_abs: float, colorscale: List) -> str:
    """Map a signed value onto ``colorscale``.

    ``value`` is scaled by ``max_abs`` so that ``-max_abs`` maps to position 0,
    zero to 0.5 and ``+max_abs`` to 1. A non-positive ``max_abs`` yields a
    light grey.
    """
    if max_abs <= 0:
        return "rgb(220, 220, 220)"
    normalized = (value / max_abs + 1.0) / 2.0
    return _colorscale_to_color(colorscale, normalized)


def _strip_occurrence_suffix(text: str) -> str:
    """Drop a trailing `` (<digits>)`` suffix, e.g. ``"the (2)"`` -> ``"the"``.

    ``None`` or empty input yields ``""``; text without such a suffix is
    returned unchanged.
    """
    text = text or ""
    if text.endswith(")") and " (" in text:
        base, _, tail = text.rpartition(" (")
        if tail[:-1].isdigit():
            return base
    return text


def _safe_float(value: object, default: float = 0.0) -> float:
    """Return ``float(value)``, or ``default`` when the conversion fails."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _collect_edges(
    interactions: Optional[list],
    global_to_node: Dict[int, int],
    top_k: Optional[int],
) -> List[Tuple[int, int, float]]:
    """Build ``(node_i, node_j, value)`` edges from raw interaction records.

    Records that are not dicts, whose ``"indices"`` is not a pair of
    int-convertible values, or that reference a token index not present in
    ``global_to_node`` are skipped. An unparsable ``"value"`` counts as 0.0.
    Edges are sorted by descending absolute value and cut to ``top_k``
    (``None`` keeps all, negative keeps none).
    """
    edges: List[Tuple[int, int, float]] = []
    for item in interactions or []:
        if not isinstance(item, dict):
            continue
        indices = item.get("indices")
        try:
            # len() is inside the try so a missing or non-sized "indices"
            # skips the record instead of raising.
            if len(indices) != 2:
                continue
            first = int(indices[0])
            second = int(indices[1])
        except (TypeError, ValueError, OverflowError, KeyError, IndexError):
            continue
        if first not in global_to_node or second not in global_to_node:
            continue
        value = _safe_float(item.get("value", 0.0))
        edges.append((global_to_node[first], global_to_node[second], value))

    edges.sort(key=lambda edge: abs(edge[2]), reverse=True)
    if top_k is not None:
        # Clamp so a negative top_k does not slice from the end of the list.
        edges = edges[: max(0, top_k)]
    return edges


def plot_text_interactions(
    token_rows: list[list[str]],
    marginals_rows: Optional[list[list[float]]],
    interactions: list[dict],
    row_offsets: list[int],
    top_k: int = 30,
    title: str = "Text interaction view",
) -> go.Figure:
    """Draw tokens as coloured nodes on a grid with interaction edges between them.

    Each token becomes a marker at ``(column, -row)`` coloured by its marginal
    value; the ``top_k`` strongest interactions are drawn as lines whose width
    scales with absolute value and whose colour encodes the sign.

    Args:
        token_rows: Rows of token strings; a trailing `` (<digits>)`` suffix
            on a token is hidden in labels.
        marginals_rows: Per-token values parallel to ``token_rows``. Missing
            rows or entries, ``None`` rows and unparsable entries count as 0.0.
        interactions: Dicts with ``"indices"`` (a pair of global token
            indices) and ``"value"``. Malformed records are skipped.
        row_offsets: Global index of the first token of each row; recomputed
            from the row lengths when empty or of the wrong length.
        top_k: Maximum number of edges drawn, strongest first.
        title: Figure title.

    Returns:
        The assembled figure. With no token rows, a placeholder figure with a
        "No tokens available" annotation is returned.
    """
    if not token_rows:
        fig = go.Figure()
        fig.update_layout(
            title=title,
            annotations=[{
                "text": "No tokens available",
                "xref": "paper",
                "yref": "paper",
                "x": 0.5,
                "y": 0.5,
                "showarrow": False,
                "font": {"size": 16, "color": "#666"},
            }],
            template="plotly_white",
            height=240,
        )
        return fig

    offsets = _normalize_row_offsets(token_rows, row_offsets or [])
    # NOTE(review): presumably a Plotly-style list of (position, color) pairs,
    # which is how _colorscale_to_color consumes it -- confirm against utils.
    colorscale = get_color_scale("shapley")
    marginals = marginals_rows or []

    node_x: List[float] = []
    node_y: List[float] = []
    node_labels: List[str] = []
    node_values: List[float] = []
    node_hover: List[str] = []
    global_to_node: Dict[int, int] = {}
    max_cols = 0

    for row_idx, (row, row_offset) in enumerate(zip(token_rows, offsets)):
        max_cols = max(max_cols, len(row))
        row_vals = marginals[row_idx] if row_idx < len(marginals) else None
        if row_vals is None:
            row_vals = []
        for col_idx, token in enumerate(row):
            global_to_node[row_offset + col_idx] = len(node_x)
            node_x.append(float(col_idx))
            # Negative y so that row 0 is drawn at the top.
            node_y.append(float(-row_idx))
            display_token = _strip_occurrence_suffix(str(token))
            node_labels.append(display_token)
            value = _safe_float(row_vals[col_idx]) if col_idx < len(row_vals) else 0.0
            node_values.append(value)
            # Plotly hover text breaks lines on <br>, not on newline characters.
            node_hover.append(f"{display_token}<br>Value: {value:+.3f}")

    max_abs_value = max((abs(v) for v in node_values), default=0.0)
    node_colors = [
        _value_to_color(value, max_abs_value, colorscale) for value in node_values
    ]

    edges = _collect_edges(interactions, global_to_node, top_k)
    max_abs_edge = max((abs(v) for _, _, v in edges), default=0.0) or 1.0

    edge_traces: List[go.Scatter] = []
    for idx_i, idx_j, value in edges:
        # Line width ranges from 1 (weakest) to 7 (strongest edge shown).
        width = 1 + 6 * (abs(value) / max_abs_edge if max_abs_edge > 0 else 0)
        color = "#d35400" if value >= 0 else "#3867d6"
        edge_traces.append(
            go.Scatter(
                x=[node_x[idx_i], node_x[idx_j]],
                y=[node_y[idx_i], node_y[idx_j]],
                mode="lines",
                line=dict(color=color, width=width),
                opacity=0.7,
                hoverinfo="text",
                hovertext=f"{node_labels[idx_i]} x {node_labels[idx_j]} : {value:+.3f}",
                showlegend=False,
            )
        )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        text=node_labels,
        textposition="middle center",
        marker=dict(
            size=28,
            color=node_colors,
            line=dict(width=1, color="#2f2f2f"),
        ),
        hoverinfo="text",
        hovertext=node_hover,
        showlegend=False,
    )

    # Edges are added first so the node markers are drawn on top of them.
    fig = go.Figure(data=edge_traces + [node_trace])

    pad_x = 0.6
    pad_y = 0.6
    rows = len(token_rows)
    y_min = -(rows - 1) - pad_y
    y_max = pad_y
    fig.update_layout(
        title=title,
        showlegend=False,
        hovermode="closest",
        margin=dict(l=20, r=20, t=60, b=20),
        height=max(240, 140 + rows * 80),
        plot_bgcolor="white",
        xaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[-pad_x, max(0, max_cols - 1) + pad_x],
        ),
        yaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[y_min, y_max],
        ),
    )
    return fig


def demo_text_interactions() -> go.Figure:
    """Return a sample figure: one five-token row with two interactions."""
    token_rows = [["Violence", "is", "a", "perfect", "way"]]
    marginals_rows = [[0.2, -0.1, 0.0, 0.4, -0.2]]
    interactions = [
        {"indices": [0, 3], "value": 2.1},
        {"indices": [1, 4], "value": -1.2},
    ]
    return plot_text_interactions(
        token_rows=token_rows,
        marginals_rows=marginals_rows,
        interactions=interactions,
        row_offsets=[0],
        top_k=30,
        title="Text interaction view (demo)",
    )