AttrLLM / visualization /plotting /text_interaction_graph.py
Qingpeng Kong
clean initial state
3e72399
from __future__ import annotations
from typing import Dict, List, Tuple, Optional
import plotly.graph_objects as go
from .utils import get_color_scale
def _normalize_row_offsets(token_rows: List[List[str]], row_offsets: List[int]) -> List[int]:
if row_offsets and len(row_offsets) == len(token_rows):
return row_offsets
offsets: List[int] = []
cursor = 0
for row in token_rows:
offsets.append(cursor)
cursor += len(row)
return offsets
def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
color = color.lstrip("#")
if len(color) == 3:
color = "".join(ch * 2 for ch in color)
return tuple(int(color[i : i + 2], 16) for i in (0, 2, 4))
def _interpolate_color(left: Tuple[int, int, int], right: Tuple[int, int, int], t: float) -> str:
r = int(left[0] + (right[0] - left[0]) * t)
g = int(left[1] + (right[1] - left[1]) * t)
b = int(left[2] + (right[2] - left[2]) * t)
return f"rgb({r}, {g}, {b})"
def _colorscale_to_color(colorscale: List, t: float) -> str:
if not colorscale:
return "rgb(200, 200, 200)"
t = max(0.0, min(1.0, t))
for idx in range(len(colorscale) - 1):
left_pos, left_color = colorscale[idx]
right_pos, right_color = colorscale[idx + 1]
if t <= right_pos:
if isinstance(left_color, str) and left_color.startswith("rgb"):
left_rgb = tuple(int(v) for v in left_color[4:-1].split(","))
else:
left_rgb = _hex_to_rgb(str(left_color))
if isinstance(right_color, str) and right_color.startswith("rgb"):
right_rgb = tuple(int(v) for v in right_color[4:-1].split(","))
else:
right_rgb = _hex_to_rgb(str(right_color))
span = right_pos - left_pos or 1.0
local_t = (t - left_pos) / span
return _interpolate_color(left_rgb, right_rgb, local_t)
tail = colorscale[-1][1]
if isinstance(tail, str) and tail.startswith("rgb"):
return tail
return _interpolate_color(_hex_to_rgb(str(tail)), _hex_to_rgb(str(tail)), 0.0)
def _value_to_color(value: float, max_abs: float, colorscale: List) -> str:
if max_abs <= 0:
return "rgb(220, 220, 220)"
normalized = (value / max_abs + 1.0) / 2.0
return _colorscale_to_color(colorscale, normalized)
def _strip_occurrence_suffix(text: str) -> str:
text = text or ""
if text.endswith(")") and " (" in text:
base, _, tail = text.rpartition(" (")
if tail[:-1].isdigit():
return base
return text
def _coerce_float(raw, default: float = 0.0) -> float:
    """Return ``float(raw)``, or ``default`` when ``raw`` cannot be converted."""
    try:
        return float(raw)
    except (TypeError, ValueError, OverflowError):
        return default


def _empty_interaction_figure(title: str) -> go.Figure:
    """Build the placeholder figure shown when there are no tokens to draw."""
    fig = go.Figure()
    fig.update_layout(
        title=title,
        annotations=[{
            "text": "No tokens available",
            "xref": "paper",
            "yref": "paper",
            "x": 0.5,
            "y": 0.5,
            "showarrow": False,
            "font": {"size": 16, "color": "#666"},
        }],
        template="plotly_white",
        height=240,
    )
    return fig


def _select_interaction_edges(interactions, known_indices, top_k) -> List[Tuple[int, int, float]]:
    """Validate raw interaction records and keep the strongest ones.

    A record is kept only if it is a dict whose ``"indices"`` entry holds
    exactly two values convertible to ``int`` that are both present in
    ``known_indices``.  An unparseable ``"value"`` is treated as ``0.0``.
    Edges are ordered by descending absolute value and truncated to
    ``top_k`` (``None`` keeps all, negative counts as zero).
    """
    edges: List[Tuple[int, int, float]] = []
    for item in interactions or []:
        if not isinstance(item, dict):
            continue
        indices = item.get("indices")
        try:
            # ``len`` sits inside the ``try`` so a missing or unsized
            # ``indices`` is skipped like any other malformed record.
            if len(indices) != 2:
                continue
            i = int(indices[0])
            j = int(indices[1])
        except (TypeError, ValueError, LookupError, OverflowError):
            continue
        if i not in known_indices or j not in known_indices:
            continue
        edges.append((i, j, _coerce_float(item.get("value", 0.0))))

    edges.sort(key=lambda edge: abs(edge[2]), reverse=True)
    if top_k is None:
        return edges
    # A negative ``top_k`` would otherwise slice from the end and drop only
    # the weakest edges, which is never what a "top k" caller wants.
    return edges[: max(0, top_k)]


def plot_text_interactions(
    token_rows: list[list[str]],
    marginals_rows: Optional[list[list[float]]],
    interactions: list[dict],
    row_offsets: list[int],
    top_k: int = 30,
    title: str = "Text interaction view",
) -> go.Figure:
    """Draw tokens as a grid of nodes with pairwise interactions as edges.

    Each row of ``token_rows`` becomes one horizontal line of circular
    markers (row 0 on top).  Markers are colored by their marginal value
    using the ``"shapley"`` color scale from ``get_color_scale``, scaled
    symmetrically around zero by the largest absolute marginal.  The
    strongest interactions are drawn as straight lines between the two
    tokens: orange for non-negative values, blue for negative ones, with a
    width between 1 and 7 proportional to the absolute value.

    Args:
        token_rows: Rows of token strings.  A trailing `` (N)`` occurrence
            marker on a token is removed from its displayed label.
        marginals_rows: Per-token values parallel to ``token_rows``.  May be
            ``None``; missing rows, missing entries and values that cannot
            be converted to ``float`` are treated as ``0.0``.
        interactions: Dicts with an ``"indices"`` pair of global token
            indices and a numeric ``"value"``.  Malformed records and
            records pointing at unknown tokens are skipped.
        row_offsets: Global index of the first token of each row.  When
            empty or not one-per-row, offsets are derived from row lengths.
        top_k: Maximum number of edges to draw, chosen by absolute value.
            ``None`` draws every valid edge; negative values draw none.
        title: Figure title.

    Returns:
        A Plotly figure; a small "No tokens available" placeholder when
        ``token_rows`` is empty.
    """
    if not token_rows:
        return _empty_interaction_figure(title)

    offsets = _normalize_row_offsets(token_rows, row_offsets or [])
    colorscale = get_color_scale("shapley")
    marginals = marginals_rows or []

    node_x: List[float] = []
    node_y: List[float] = []
    node_labels: List[str] = []
    node_values: List[float] = []
    node_hover: List[str] = []
    # Maps a token's global index to its position in the node_* lists.
    global_to_node: Dict[int, int] = {}
    max_cols = 0

    for row_idx, row in enumerate(token_rows):
        max_cols = max(max_cols, len(row))
        row_offset = offsets[row_idx] if row_idx < len(offsets) else 0
        row_vals = marginals[row_idx] if row_idx < len(marginals) else None
        if row_vals is None:
            row_vals = []
        for col_idx, token in enumerate(row):
            global_to_node[row_offset + col_idx] = len(node_x)
            node_x.append(float(col_idx))
            # Rows grow downward, so later rows get more negative y.
            node_y.append(float(-row_idx))
            display_token = _strip_occurrence_suffix(str(token))
            node_labels.append(display_token)
            value = _coerce_float(row_vals[col_idx]) if col_idx < len(row_vals) else 0.0
            node_values.append(value)
            node_hover.append(f"{display_token}<br>Value: {value:+.3f}")

    max_abs_value = max((abs(v) for v in node_values), default=0.0)
    node_colors = [
        _value_to_color(value, max_abs_value, colorscale) for value in node_values
    ]

    edges = _select_interaction_edges(interactions, global_to_node, top_k)
    # ``or 1.0`` keeps the width formula defined when every edge value is 0.
    max_abs_edge = max((abs(v) for _, _, v in edges), default=0.0) or 1.0

    edge_traces: List[go.Scatter] = []
    for i, j, value in edges:
        # Both endpoints were validated against ``global_to_node`` above.
        idx_i = global_to_node[i]
        idx_j = global_to_node[j]
        # The comparison also shields the width from a NaN maximum.
        width = 1 + 6 * (abs(value) / max_abs_edge if max_abs_edge > 0 else 0)
        color = "#d35400" if value >= 0 else "#3867d6"
        # One trace per edge, because every edge has its own width and color.
        edge_traces.append(
            go.Scatter(
                x=[node_x[idx_i], node_x[idx_j]],
                y=[node_y[idx_i], node_y[idx_j]],
                mode="lines",
                line=dict(color=color, width=width),
                opacity=0.7,
                hoverinfo="text",
                hovertext=f"{node_labels[idx_i]} x {node_labels[idx_j]} : {value:+.3f}",
                showlegend=False,
            )
        )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        text=node_labels,
        textposition="middle center",
        marker=dict(
            size=28,
            color=node_colors,
            line=dict(width=1, color="#2f2f2f"),
        ),
        hoverinfo="text",
        hovertext=node_hover,
        showlegend=False,
    )

    # Edges go first so the node markers are painted on top of them.
    fig = go.Figure(data=edge_traces + [node_trace])

    pad_x = 0.6
    pad_y = 0.6
    rows = len(token_rows)
    fig.update_layout(
        title=title,
        showlegend=False,
        hovermode="closest",
        margin=dict(l=20, r=20, t=60, b=20),
        height=max(240, 140 + rows * 80),
        plot_bgcolor="white",
        xaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[-pad_x, max(0, max_cols - 1) + pad_x],
        ),
        yaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[-(rows - 1) - pad_y, pad_y],
        ),
    )
    return fig
def demo_text_interactions() -> go.Figure:
    """Render a small hard-coded example through ``plot_text_interactions``.

    The sample is a single five-token row with one marginal value per token
    and two interactions (one positive, one negative).
    """
    sentence = ["Violence", "is", "a", "perfect", "way"]
    scores = [0.2, -0.1, 0.0, 0.4, -0.2]
    # (token index pair, interaction strength)
    pairs = [((0, 3), 2.1), ((1, 4), -1.2)]

    return plot_text_interactions(
        token_rows=[sentence],
        marginals_rows=[scores],
        interactions=[
            {"indices": list(pair), "value": strength} for pair, strength in pairs
        ],
        row_offsets=[0],
        top_k=30,
        title="Text interaction view (demo)",
    )