| """ |
| Utilities to render attribution visualizations for a text-interpretability web app. |
| Uses Plotly for heatmaps and inline HTML for text-based visualizations. |
| """ |
|
|
| import plotly.graph_objects as go |
| import numpy as np |
| from html import escape |
| from typing import List, Dict, Optional, Tuple, Any |
|
|
| from .utils import get_color_scale, format_feature_label, matplotlib_to_plotly |
|
|
| |
| |
# Placeholders for optional third-party objects. They are bound to None here;
# the code below checks `InteractionValues` and `sentence_plot` against None
# before using them. `shap` and `plt` are not referenced in the visible part
# of this module.
InteractionValues = None
sentence_plot = None
shap = None
plt = None
|
|
# Inline stylesheet for the text view. create_interactive_text_heatmap()
# prepends this block to every HTML fragment it returns. The legend gradient
# stops (#dd1313, #e1e1df, #016d01) are the hex forms of _NEGATIVE_RGB,
# _NEUTRAL_RGB and _POSITIVE_RGB; keep both sides in sync when changing colors.
_SPEX_TEXT_STYLE = """
<style id="spex-text-view-style">
.spex-text-view {
--spex-bg: #f7f5f2;
--spex-border: #e3e3ec;
--spex-card-bg: #ffffff;
--spex-card-shadow: 0 14px 30px rgba(32, 25, 40, 0.08);
--spex-text: #3d2c36;
font-family: "Segoe UI", "Helvetica Neue", Arial, sans-serif;
background: var(--spex-bg);
border: 1px solid var(--spex-border);
border-radius: 18px;
padding: 20px;
display: flex;
flex-wrap: wrap;
gap: 18px;
}
.spex-text-card {
flex: 3 1 520px;
background: var(--spex-card-bg);
border: 1px solid var(--spex-border);
border-radius: 18px;
padding: 18px;
box-shadow: var(--spex-card-shadow);
}
.spex-card-header {
display: flex;
justify-content: space-between;
align-items: flex-end;
margin-bottom: 12px;
gap: 8px;
}
.spex-card-title {
font-size: 18px;
font-weight: 600;
color: var(--spex-text);
}
.spex-card-subtitle {
font-size: 13px;
color: #7f6f86;
}
.spex-token-grid {
display: block;
font-size: 16px;
line-height: 2;
color: #111111;
word-break: break-word;
white-space: pre-wrap;
}
.spex-token {
display: inline-flex;
flex-direction: column;
align-items: center;
justify-content: center;
vertical-align: baseline;
padding: 2px 6px;
margin: 0 2px;
border-radius: 12px;
border: 1px solid transparent;
background: rgba(225, 225, 223, 0.45);
box-decoration-break: clone;
transition: box-shadow 0.15s ease, background 0.15s ease;
}
.spex-token:hover {
box-shadow: 0 8px 16px rgba(0, 0, 0, 0.12);
}
.spex-token-score {
display: block;
font-size: 11px;
font-weight: 600;
color: #111111;
letter-spacing: 0.08em;
text-transform: uppercase;
margin-bottom: 2px;
}
.spex-token-text {
font-size: inherit;
color: #111111;
white-space: inherit;
}
.spex-token-plain {
color: #111111;
white-space: pre-wrap;
}
.spex-side-panel {
flex: 1 1 220px;
display: flex;
flex-direction: column;
gap: 12px;
}
.spex-side-card {
background: #fefcf8;
border: 1px dashed var(--spex-border);
border-radius: 16px;
padding: 16px;
}
.spex-side-card strong {
display: block;
font-size: 15px;
color: var(--spex-text);
margin-bottom: 6px;
}
.spex-legend-bar {
display: flex;
align-items: center;
gap: 8px;
margin: 12px 0;
}
.spex-legend-label {
font-size: 12px;
color: #6f5a72;
text-transform: uppercase;
letter-spacing: 0.08em;
}
.spex-legend-gradient {
flex: 1;
height: 10px;
border-radius: 999px;
background: linear-gradient(90deg, #dd1313, #e1e1df, #016d01);
}
.spex-legend-note {
font-size: 12px;
color: #6f5a72;
margin: 0;
}
.spex-raw-text {
flex-basis: 100%;
background: #ffffff;
border: 1px solid var(--spex-border);
border-radius: 16px;
padding: 16px;
box-shadow: 0 10px 18px rgba(32, 25, 40, 0.06);
}
.spex-raw-text strong {
display: block;
font-size: 14px;
color: #6f5a72;
text-transform: uppercase;
letter-spacing: 0.08em;
margin-bottom: 6px;
}
.spex-raw-text p {
margin: 0;
font-size: 13px;
line-height: 1.6;
white-space: pre-wrap;
color: #4a3b4e;
}
.spex-empty {
flex-basis: 100%;
text-align: center;
font-size: 14px;
color: #7f6f86;
}
@media (max-width: 900px) {
.spex-text-card,
.spex-side-panel {
flex: 1 1 100%;
}
}
</style>
"""
|
|
# RGB endpoints of the diverging token color ramp used by _color_for_value():
# negative values blend from red toward the neutral gray, positive values from
# the neutral gray toward green.
_NEGATIVE_RGB = (221, 19, 19)
_POSITIVE_RGB = (1, 109, 1)
_NEUTRAL_RGB = (225, 225, 223)
|
|
|
|
def _format_text_segment(value: str, preserve_blank: bool = False) -> str:
    """Return *value* as an HTML-safe fragment.

    The text is HTML-escaped and every newline becomes a ``<br />`` tag.

    Args:
        value: Raw text; any falsy value (``None``, ``""``) counts as empty.
        preserve_blank: When true, empty input yields a single blank
            character instead of an empty string, so the surrounding element
            still has content.

    Returns:
        The escaped fragment, or the blank/empty placeholder for empty input.
    """
    if not value:
        return " " if preserve_blank else ""
    return escape(value).replace("\n", "<br />")
|
|
|
|
def _coerce_span_index(raw: Any, fallback: int, text_length: int) -> int:
    """Convert one span endpoint to ``int``, tolerating junk and infinities."""
    try:
        return int(raw)
    except OverflowError:
        # int() overflows on +/-infinity; read it as "past the end" or
        # "before the start" rather than crashing on an open-ended span.
        return text_length if raw > 0 else 0
    except (TypeError, ValueError):
        return fallback


def _normalize_span(span: Any, text_length: int) -> Tuple[int, int]:
    """Coerce a feature span into a valid ``(start, end)`` slice of the text.

    Args:
        span: Either a mapping (start read from ``"start"``/``"begin"``, end
            from ``"end"``/``"stop"``/``"finish"``; missing keys default to 0)
            or an iterable that unpacks into exactly two items.
        text_length: Length of the text the span indexes into.

    Returns:
        ``(start, end)`` with ``0 <= start <= end`` and both values capped at
        ``text_length`` (for a non-negative ``text_length``).

    Raises:
        TypeError, ValueError: If a non-mapping ``span`` cannot be unpacked
            into two items.

    An endpoint that cannot be converted to ``int`` falls back to 0 for the
    start and to the start for the end. Infinite endpoints are clamped to the
    text bounds.
    """
    if isinstance(span, dict):
        start = span.get("start", span.get("begin", 0))
        end = span.get("end", span.get("stop", span.get("finish", 0)))
    else:
        start, end = span

    start_i = _coerce_span_index(start, 0, text_length)
    end_i = _coerce_span_index(end, start_i, text_length)

    # Clamp into the text and never let the end precede the start.
    start_i = max(0, min(text_length, start_i))
    end_i = max(start_i, min(text_length, end_i))
    return start_i, end_i
|
|
|
|
def _color_for_value(value: float, max_abs: float) -> Tuple[str, str, str]:
    """Map an attribution value onto the diverging red/gray/green ramp.

    Args:
        value: Attribution value to color.
        max_abs: Magnitude that corresponds to full saturation.

    Returns:
        ``(hex_color, background, sign)`` where ``hex_color`` is ``#rrggbb``,
        ``background`` is an ``rgba(...)`` string whose alpha grows from 0.25
        to 0.70 with the normalized magnitude, and ``sign`` is one of
        ``"positive"``, ``"negative"`` or ``"neutral"``.

    A non-positive or NaN ``max_abs`` and a NaN ``value`` all yield the
    neutral color. Without that guard NaN slips through ``min``/``max`` as
    ``1.0`` and would be drawn as the strongest positive attribution.
    """
    import math  # local import keeps this helper self-contained

    # `not max_abs > 0` is also true for a NaN max_abs.
    if not max_abs > 0 or math.isnan(value):
        norm = 0.0
        rgb = _NEUTRAL_RGB
        sign = "neutral"
    else:
        ratio = value / max_abs
        if math.isnan(ratio):
            # Only inf / inf reaches here: saturate in the value's direction.
            ratio = math.copysign(1.0, value)
        norm = max(-1.0, min(1.0, ratio))

        # t in [0, 1]: 0 is fully negative, 0.5 neutral, 1 fully positive.
        t = (norm + 1.0) / 2.0
        if t < 0.5:
            low, high, local = _NEGATIVE_RGB, _NEUTRAL_RGB, t * 2.0
        else:
            low, high, local = _NEUTRAL_RGB, _POSITIVE_RGB, (t - 0.5) * 2.0
        rgb = tuple(int(round(a + (b - a) * local)) for a, b in zip(low, high))
        sign = "positive" if norm > 0 else "negative" if norm < 0 else "neutral"

    r, g, b = rgb
    hex_color = f"#{r:02x}{g:02x}{b:02x}"
    # abs(norm) equals min(1, |value| / max_abs) on the non-neutral path.
    alpha = 0.25 + 0.45 * abs(norm)
    background = f"rgba({r}, {g}, {b}, {alpha:.3f})"
    return hex_color, background, sign
|
|
def _build_sentence_interaction_values(values: List[float], method: str) -> Optional[InteractionValues]:
    """Wrap per-token scores in an order-1 ``InteractionValues`` object.

    Args:
        values: One score per token ("player"), in token order.
        method: ``"shapley"`` selects index ``"SV"``, ``"influence"`` selects
            ``"IV"``, and anything else selects ``"BV"``.

    Returns:
        ``None`` when the module-level ``InteractionValues`` placeholder is
        ``None`` or ``values`` is empty. Otherwise an instance in which token
        ``i`` is stored under the singleton key ``(i,)`` at position ``i``,
        with a baseline of 0.0 and ``estimated=False``.
    """
    if InteractionValues is None:
        return None

    player_count = len(values)
    if player_count == 0:
        return None

    if method == "shapley":
        index_name = "SV"
    elif method == "influence":
        index_name = "IV"
    else:
        index_name = "BV"

    # Each token is its own singleton coalition, stored at its own position.
    singleton_lookup = {(player,): player for player in range(player_count)}

    return InteractionValues(
        values=np.array(values, dtype=float),
        index=index_name,
        max_order=1,
        n_players=player_count,
        min_order=1,
        interaction_lookup=singleton_lookup,
        estimated=False,
        baseline_value=0.0,
    )
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
def create_attribution_heatmap(
    features: List[str],
    attributions: Dict[str, float],
    method: str = "shapley",
    title: Optional[str] = None,
) -> go.Figure:
    """Build a one-column Plotly heatmap of per-feature attribution values.

    Features missing from ``attributions`` are plotted as ``0.0``. When the
    largest finite magnitude is positive but below ``1e-4``, all values are
    rescaled so that magnitude becomes 1 and the colorbar title records the
    multiplier. The color range is symmetric around zero and bounded by the
    95th percentile of the finite magnitudes (at least ``1e-6``).

    If the module-level ``sentence_plot`` callable is available and returns a
    result, its figure is converted with ``matplotlib_to_plotly`` and returned
    instead of the heatmap.

    Args:
        features: Feature labels, in display order (top to bottom).
        attributions: Mapping from feature label to attribution value.
        method: Attribution method name; used for titles and to pick the
            color scale.
        title: Optional figure title; defaults to
            ``"<Method> token attributions"``.

    Returns:
        A Plotly figure; an empty one when ``features`` is empty.
    """
    raw_values = np.array([attributions.get(f, 0.0) for f in features], dtype=float)
    if raw_values.size == 0:
        return go.Figure()

    # Only finite entries may drive scaling and color bounds: a single
    # NaN/None value would otherwise turn max_abs and vmax into NaN.
    finite_mask = np.isfinite(raw_values)
    max_abs = float(np.abs(raw_values[finite_mask]).max()) if finite_mask.any() else 0.0

    method_label = method.title()
    scale = 1.0
    colorbar_title = f"{method_label} value"
    if 0.0 < max_abs < 1e-4:
        # Tiny values would all print as 0.0000 in the hover text; rescale
        # them and say so on the colorbar.
        scale = 1.0 / max_abs
        colorbar_title = f"{method_label} (×{scale:.0e})"

    values = raw_values * scale
    figure_title = title or f"{method_label} token attributions"

    # Prefer the sentence-style plot when the optional backend is present.
    if sentence_plot is not None:
        iv = _build_sentence_interaction_values(values.tolist(), method)
        if iv is not None:
            result = sentence_plot(
                iv,
                words=features,
                show=False,
                chars_per_line=80,
            )
            if result is not None:
                mpl_fig, _ = result
                return matplotlib_to_plotly(
                    mpl_fig,
                    title=figure_title,
                    height=max(300, 30 * len(features)),
                )

    # Fallback: plain heatmap. The 95th percentile keeps one outlier from
    # washing out every other cell.
    finite_abs = np.abs(values[finite_mask])
    vmax = float(np.percentile(finite_abs, 95)) if finite_abs.size else 0.0
    vmax = max(vmax, 1e-6)

    heatmap = go.Heatmap(
        z=values[:, None],
        x=["Attribution"],
        y=[format_feature_label(f, max_length=60) for f in features],
        colorscale=get_color_scale(method),
        zmid=0.0,
        zmin=-vmax,
        zmax=vmax,
        colorbar=dict(title=colorbar_title),
        hovertemplate="%{y}<br>%{x}: %{z:.4f}<extra></extra>",
        showscale=True,
    )

    fig = go.Figure(data=[heatmap])
    fig.update_layout(
        title=figure_title,
        xaxis=dict(showticklabels=False),
        yaxis=dict(autorange="reversed"),
        margin=dict(l=140, r=40, t=60, b=40),
        height=max(320, 22 * len(features)),
    )
    return fig
|
|
def create_interactive_text_heatmap(
    text: str,
    feature_spans: List[Any],
    attributions: List[Any],
    method: str = "shapley",
) -> str:
    """
    Render a Spectral Explain–style text view with token chips, legend, and raw text.

    Chips are laid out in document order (sorted by span start) regardless of
    the order in which the spans were passed, and the text between chips is
    emitted exactly once even when spans overlap. Each chip keeps the 1-based
    position of its feature in ``feature_spans`` as ``data-token-index``.
    Colors are normalized by the largest finite ``|value|``; a NaN attribution
    is drawn in the neutral color while its tooltip still shows the raw value.

    Args:
        text: Original text that generated the attributions.
        feature_spans: Character spans identifying each token/feature.
        attributions: Numeric attribution values aligned with feature_spans;
            entries that cannot be converted to ``float`` count as 0.0.
        method: Attribution method label.

    Returns:
        Styled HTML that can be injected into the Gradio Text View tab.

    Raises:
        ValueError: If ``feature_spans`` and ``attributions`` differ in length.
    """
    if len(feature_spans) != len(attributions):
        raise ValueError("feature_spans and attributions must have the same length")

    source_text = text or ""
    text_len = len(source_text)

    tokens: List[Dict[str, Any]] = []
    for idx, (span, raw_value) in enumerate(zip(feature_spans, attributions), start=1):
        start, end = _normalize_span(span, text_len)
        try:
            value = float(raw_value)
        except (TypeError, ValueError):
            value = 0.0
        tokens.append(
            {
                "index": idx,
                "text": source_text[start:end],
                "value": value,
                "start": start,
                "end": end,
            }
        )

    if not tokens:
        fallback = _format_text_segment(source_text) or "No text available."
        return (
            f"{_SPEX_TEXT_STYLE}"
            '<div class="spex-text-view">'
            '<div class="spex-empty">No feature spans were provided for this example.</div>'
            f'<div class="spex-raw-text"><strong>Raw text</strong><p>{fallback}</p></div>'
            "</div>"
        )

    # Normalize by finite magnitudes only; a NaN in the list would otherwise
    # make max() order-dependent and could turn the scale itself into NaN.
    finite_magnitudes = [abs(tok["value"]) for tok in tokens if np.isfinite(tok["value"])]
    max_abs = max(finite_magnitudes, default=0.0) or 1.0

    method_label = (method or "attribution").title()
    # The label is interpolated into element content below, so escape it.
    label_html = escape(method_label)

    flow_parts: List[str] = []
    cursor = 0
    # sorted() is stable, so spans sharing a start keep the caller's order.
    for token in sorted(tokens, key=lambda tok: tok["start"]):
        start, end, value = token["start"], token["end"], token["value"]

        if cursor < start:
            plain = _format_text_segment(source_text[cursor:start])
            flow_parts.append(f'<span class="spex-token-plain">{plain}</span>')

        # NaN has no direction; color it as neutral.
        color_value = 0.0 if np.isnan(value) else value
        color_hex, background, sign = _color_for_value(color_value, max_abs)
        tooltip = escape(f"{method_label} · chars [{start}:{end}] · {value:+.4f}")
        text_html = _format_text_segment(token["text"], preserve_blank=True)
        flow_parts.append(
            f'<span class="spex-token spex-token--{sign}" '
            f'data-token-index="{token["index"]}" '
            f'data-attr="{value:.6f}" '
            f'style="background-color:{background}; border-color:{color_hex};" '
            f'title="{tooltip}">'
            f'<span class="spex-token-text">{text_html}</span>'
            "</span>"
        )
        # Never move backwards: overlapping or nested spans must not cause
        # already-rendered text to be emitted again.
        cursor = max(cursor, end)

    if cursor < text_len:
        trailing = _format_text_segment(source_text[cursor:])
        flow_parts.append(f'<span class="spex-token-plain">{trailing}</span>')

    flow_html = "".join(flow_parts)

    legend = (
        '<div class="spex-side-card">'
        f"<strong>{label_html} legend</strong>"
        '<div class="spex-legend-bar">'
        '<span class="spex-legend-label">Negative</span>'
        '<div class="spex-legend-gradient"></div>'
        '<span class="spex-legend-label">Positive</span>'
        "</div>"
        f'<p class="spex-legend-note">Normalized by max |value| = {max_abs:.4f}. Hover tokens for exact scores.</p>'
        "</div>"
    )

    raw_text_block = ""
    if source_text:
        raw_text_block = (
            '<div class="spex-raw-text">'
            "<strong>Raw text</strong>"
            f"<p>{_format_text_segment(source_text)}</p>"
            "</div>"
        )

    return (
        f"{_SPEX_TEXT_STYLE}"
        '<div class="spex-text-view">'
        '<div class="spex-text-card">'
        '<div class="spex-card-header">'
        '<div>'
        '<div class="spex-card-title">Context</div>'
        f'<div class="spex-card-subtitle">{label_html} token attributions</div>'
        "</div>"
        f'<div class="spex-card-subtitle">Tokens: {len(tokens)}</div>'
        "</div>"
        f'<div class="spex-token-grid">{flow_html}</div>'
        "</div>"
        f'<div class="spex-side-panel">{legend}</div>'
        f"{raw_text_block}"
        "</div>"
    )
|
|
|
|
def normalize_attributions(
    attributions: Dict[str, float],
    method: str = "minmax"
) -> Dict[str, float]:
    """
    Rescale attribution values for visualization.

    Args:
        attributions: Raw attribution dict {feature: value}.
        method: ``"zscore"`` standardizes with the mean and population
            standard deviation (a zero deviation is replaced by 1). Any other
            value, including the default ``"minmax"``, maps the values
            linearly onto ``[-1, 1]``; if all values are equal they all
            become 0.

    Returns:
        A new dict with the same keys, in the same order, holding the
        rescaled values as plain floats. Empty input yields an empty dict.
    """
    if not attributions:
        return {}

    raw = np.array(list(attributions.values()), dtype=float)

    if method == "zscore":
        spread = float(raw.std())
        scaled = (raw - float(raw.mean())) / (spread if spread != 0 else 1.0)
    else:
        low, high = float(raw.min()), float(raw.max())
        span = high - low
        if span == 0:
            scaled = np.zeros_like(raw)
        else:
            # [low, high] -> [0, 1] -> [-1, 1]
            scaled = (raw - low) / span * 2 - 1

    return dict(zip(attributions.keys(), map(float, scaled)))
|
|