| from __future__ import annotations |
|
|
| from typing import Dict, List, Tuple, Optional |
|
|
| import plotly.graph_objects as go |
|
|
| from .utils import get_color_scale |
|
|
|
|
| def _normalize_row_offsets(token_rows: List[List[str]], row_offsets: List[int]) -> List[int]: |
| if row_offsets and len(row_offsets) == len(token_rows): |
| return row_offsets |
| offsets: List[int] = [] |
| cursor = 0 |
| for row in token_rows: |
| offsets.append(cursor) |
| cursor += len(row) |
| return offsets |
|
|
|
|
| def _hex_to_rgb(color: str) -> Tuple[int, int, int]: |
| color = color.lstrip("#") |
| if len(color) == 3: |
| color = "".join(ch * 2 for ch in color) |
| return tuple(int(color[i : i + 2], 16) for i in (0, 2, 4)) |
|
|
|
|
| def _interpolate_color(left: Tuple[int, int, int], right: Tuple[int, int, int], t: float) -> str: |
| r = int(left[0] + (right[0] - left[0]) * t) |
| g = int(left[1] + (right[1] - left[1]) * t) |
| b = int(left[2] + (right[2] - left[2]) * t) |
| return f"rgb({r}, {g}, {b})" |
|
|
|
|
| def _colorscale_to_color(colorscale: List, t: float) -> str: |
| if not colorscale: |
| return "rgb(200, 200, 200)" |
| t = max(0.0, min(1.0, t)) |
| for idx in range(len(colorscale) - 1): |
| left_pos, left_color = colorscale[idx] |
| right_pos, right_color = colorscale[idx + 1] |
| if t <= right_pos: |
| if isinstance(left_color, str) and left_color.startswith("rgb"): |
| left_rgb = tuple(int(v) for v in left_color[4:-1].split(",")) |
| else: |
| left_rgb = _hex_to_rgb(str(left_color)) |
| if isinstance(right_color, str) and right_color.startswith("rgb"): |
| right_rgb = tuple(int(v) for v in right_color[4:-1].split(",")) |
| else: |
| right_rgb = _hex_to_rgb(str(right_color)) |
| span = right_pos - left_pos or 1.0 |
| local_t = (t - left_pos) / span |
| return _interpolate_color(left_rgb, right_rgb, local_t) |
| tail = colorscale[-1][1] |
| if isinstance(tail, str) and tail.startswith("rgb"): |
| return tail |
| return _interpolate_color(_hex_to_rgb(str(tail)), _hex_to_rgb(str(tail)), 0.0) |
|
|
|
|
| def _value_to_color(value: float, max_abs: float, colorscale: List) -> str: |
| if max_abs <= 0: |
| return "rgb(220, 220, 220)" |
| normalized = (value / max_abs + 1.0) / 2.0 |
| return _colorscale_to_color(colorscale, normalized) |
|
|
|
|
| def _strip_occurrence_suffix(text: str) -> str: |
| text = text or "" |
| if text.endswith(")") and " (" in text: |
| base, _, tail = text.rpartition(" (") |
| if tail[:-1].isdigit(): |
| return base |
| return text |
|
|
|
|
def _coerce_float(raw, default: float = 0.0) -> float:
    """Best-effort ``float(raw)``; return ``default`` when conversion fails."""
    try:
        return float(raw)
    except (TypeError, ValueError, OverflowError):
        return default


def _select_top_edges(interactions, known_indices, top_k) -> List[Tuple[int, int, float]]:
    """Parse pairwise interactions and keep the ``top_k`` strongest by ``|value|``.

    Items that are not dicts, whose ``"indices"`` is not a pair of
    int-convertible entries, or that reference an index missing from
    ``known_indices`` are skipped.  An unconvertible ``"value"`` counts as 0.0.
    """
    edges: List[Tuple[int, int, float]] = []
    for item in interactions or []:
        if not isinstance(item, dict):
            continue
        indices = item.get("indices")
        try:
            # ``len`` sits inside the ``try`` so that None, scalars and other
            # non-sized payloads are skipped instead of raising.
            if len(indices) != 2:
                continue
            i = int(indices[0])
            j = int(indices[1])
        except (TypeError, ValueError, LookupError, OverflowError):
            continue
        if i not in known_indices or j not in known_indices:
            continue
        edges.append((i, j, _coerce_float(item.get("value", 0.0))))

    edges.sort(key=lambda edge: abs(edge[2]), reverse=True)
    if top_k is None:
        return edges
    # Clamp at zero: a negative bound would otherwise slice from the tail and
    # silently drop only the weakest edges.
    return edges[: max(0, int(top_k))]


def plot_text_interactions(
    token_rows: list[list[str]],
    marginals_rows: Optional[list[list[float]]],
    interactions: list[dict],
    row_offsets: list[int],
    top_k: int = 30,
    title: str = "Text interaction view",
) -> go.Figure:
    """Draw tokens as a grid of colored nodes linked by their interactions.

    Each token becomes a marker at ``(column, -row)``, colored by its marginal
    value on the ``"shapley"`` color scale.  The strongest pairwise
    interactions are drawn as lines whose width grows with ``|value|``; the
    line color encodes the sign.

    Args:
        token_rows: Rows of token strings.  A trailing `` (N)`` occurrence
            counter on a token is hidden in the displayed label.
        marginals_rows: Per-token values parallel to ``token_rows``.  ``None``,
            missing rows, missing entries and unconvertible entries count as 0.
        interactions: Dicts with ``"indices"`` (two global token indices) and
            an optional ``"value"``.  Malformed items are skipped.
        row_offsets: Global index of the first token of each row.  When empty
            or of the wrong length, rows are assumed to be laid end to end.
        top_k: Maximum number of edges to draw.  ``None`` draws all of them;
            values at or below zero draw none.
        title: Figure title.

    Returns:
        The assembled figure, or a placeholder figure carrying a
        "No tokens available" annotation when ``token_rows`` is empty.
    """
    if not token_rows:
        fig = go.Figure()
        fig.update_layout(
            title=title,
            annotations=[{
                "text": "No tokens available",
                "xref": "paper",
                "yref": "paper",
                "x": 0.5,
                "y": 0.5,
                "showarrow": False,
                "font": {"size": 16, "color": "#666"},
            }],
            template="plotly_white",
            height=240,
        )
        return fig

    # The helper always returns exactly one offset per row, so zipping below
    # never truncates ``token_rows``.
    offsets = _normalize_row_offsets(token_rows, row_offsets or [])
    colorscale = get_color_scale("shapley")
    marginals = marginals_rows if marginals_rows is not None else []

    node_x: List[float] = []
    node_y: List[float] = []
    node_labels: List[str] = []
    node_values: List[float] = []
    node_hover: List[str] = []
    global_to_node: Dict[int, int] = {}

    max_cols = 0
    for row_idx, (row, row_offset) in enumerate(zip(token_rows, offsets)):
        max_cols = max(max_cols, len(row))
        row_vals = marginals[row_idx] if row_idx < len(marginals) else None
        if row_vals is None:
            row_vals = []
        for col_idx, token in enumerate(row):
            global_to_node[row_offset + col_idx] = len(node_x)
            node_x.append(float(col_idx))
            # Rows run downwards: row 0 sits at y == 0, later rows below it.
            node_y.append(float(-row_idx))
            display_token = _strip_occurrence_suffix(str(token))
            node_labels.append(display_token)
            value = _coerce_float(row_vals[col_idx]) if col_idx < len(row_vals) else 0.0
            node_values.append(value)
            node_hover.append(f"{display_token}<br>Value: {value:+.3f}")

    max_abs_value = max((abs(v) for v in node_values), default=0.0)
    node_colors = [
        _value_to_color(value, max_abs_value, colorscale) for value in node_values
    ]

    edges = _select_top_edges(interactions, global_to_node, top_k)
    # ``or 1.0`` keeps the width formula defined when every edge value is 0.
    max_abs_edge = max((abs(v) for _, _, v in edges), default=0.0) or 1.0

    edge_traces: List[go.Scatter] = []
    for i, j, value in edges:
        # Both endpoints are guaranteed present by ``_select_top_edges``.
        idx_i = global_to_node[i]
        idx_j = global_to_node[j]
        width = 1 + 6 * (abs(value) / max_abs_edge)
        color = "#d35400" if value >= 0 else "#3867d6"
        edge_traces.append(
            go.Scatter(
                x=[node_x[idx_i], node_x[idx_j]],
                y=[node_y[idx_i], node_y[idx_j]],
                mode="lines",
                line=dict(color=color, width=width),
                opacity=0.7,
                hoverinfo="text",
                hovertext=f"{node_labels[idx_i]} x {node_labels[idx_j]} : {value:+.3f}",
                showlegend=False,
            )
        )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        text=node_labels,
        textposition="middle center",
        marker=dict(
            size=28,
            color=node_colors,
            line=dict(width=1, color="#2f2f2f"),
        ),
        hoverinfo="text",
        hovertext=node_hover,
        showlegend=False,
    )

    # Edges are added first so the node markers are drawn on top of them.
    fig = go.Figure(data=edge_traces + [node_trace])
    pad_x = 0.6
    pad_y = 0.6
    rows = len(token_rows)
    y_min = -(rows - 1) - pad_y
    y_max = pad_y
    fig.update_layout(
        title=title,
        showlegend=False,
        hovermode="closest",
        margin=dict(l=20, r=20, t=60, b=20),
        height=max(240, 140 + rows * 80),
        plot_bgcolor="white",
        xaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[-pad_x, max(0, max_cols - 1) + pad_x],
        ),
        yaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            range=[y_min, y_max],
        ),
    )
    return fig
|
|
|
|
def demo_text_interactions() -> go.Figure:
    """Render a fixed one-row, five-token example via ``plot_text_interactions``.

    The example carries two interactions: a positive one between tokens 0 and
    3, and a negative one between tokens 1 and 4.
    """
    words = ["Violence", "is", "a", "perfect", "way"]
    word_scores = [0.2, -0.1, 0.0, 0.4, -0.2]
    pair_scores = [((0, 3), 2.1), ((1, 4), -1.2)]
    demo_kwargs = dict(
        token_rows=[words],
        marginals_rows=[word_scores],
        interactions=[
            {"indices": list(pair), "value": score} for pair, score in pair_scores
        ],
        row_offsets=[0],
        top_k=30,
        title="Text interaction view (demo)",
    )
    return plot_text_interactions(**demo_kwargs)
|
|