"""
Utilities to render attribution visualizations for a text-interpretability web app.

Uses Plotly for heatmaps and inline HTML for text-based visualizations.
"""

import math

import plotly.graph_objects as go
import numpy as np
from html import escape
from typing import List, Dict, Optional, Tuple, Any

from .utils import get_color_scale, format_feature_label, matplotlib_to_plotly

# Dummy placeholders so functions that reference these names still type-check,
# but we do NOT import heavy deps like shapiq / shap / numba in this environment.
# While ``sentence_plot`` / ``InteractionValues`` are None, the shapiq branch of
# ``create_attribution_heatmap`` is skipped and the Plotly heatmap is used.
InteractionValues = None  # type: ignore
sentence_plot = None
shap = None
plt = None

# Stylesheet prepended to every HTML fragment produced by the text view.
# NOTE(review): class names and rules are a reconstruction; confirm them
# against the app's theme.
_SPEX_TEXT_STYLE = """
<style>
.spex-text-view { font-family: system-ui, -apple-system, "Segoe UI", sans-serif; line-height: 1.7; }
.spex-header { display: flex; justify-content: space-between; align-items: flex-end; margin-bottom: 8px; }
.spex-eyebrow { font-size: 0.75em; letter-spacing: 0.05em; text-transform: uppercase; opacity: 0.7; }
.spex-title { font-weight: 600; }
.spex-meta { font-size: 0.85em; opacity: 0.8; }
.spex-flow { white-space: pre-wrap; padding: 12px; border: 1px solid rgba(0, 0, 0, 0.12); border-radius: 8px; }
.spex-token { border: 1px solid transparent; border-radius: 4px; padding: 0 2px; cursor: help; }
.spex-legend { margin-top: 12px; font-size: 0.85em; }
.spex-legend-bar { display: flex; align-items: center; gap: 8px; margin: 4px 0; }
.spex-legend-gradient { flex: 1; height: 10px; border-radius: 5px; }
.spex-legend-note, .spex-empty { opacity: 0.8; }
.spex-raw { margin-top: 12px; }
.spex-raw-text { white-space: pre-wrap; font-family: ui-monospace, monospace; font-size: 0.9em; }
</style>
"""

# Endpoints of the diverging colour ramp used by the text view (RGB, 0-255).
_NEGATIVE_RGB = (221, 19, 19)
_POSITIVE_RGB = (1, 109, 1)
_NEUTRAL_RGB = (225, 225, 223)


def _format_text_segment(value: str, preserve_blank: bool = False) -> str:
    """Return *value* HTML-escaped, with line breaks rendered as ``<br>``.

    Args:
        value: Raw text; ``None`` is treated as an empty string.
        preserve_blank: When True, an empty result is replaced by ``&nbsp;`` so
            the enclosing element keeps a visible width.

    Returns:
        Text that is safe to embed in HTML element content or attribute values.
    """
    # Normalise Windows / old-Mac line endings so each break becomes one <br>.
    text = (value or "").replace("\r\n", "\n").replace("\r", "\n")
    safe = escape(text).replace("\n", "<br>")
    if not safe and preserve_blank:
        return "&nbsp;"
    return safe


def _normalize_span(span: Any, text_length: int) -> Tuple[int, int]:
    """Convert a loosely-typed span into a clamped ``(start, end)`` pair.

    Accepts a dict (``start``/``begin`` and ``end``/``stop``/``finish`` keys) or
    any iterable whose first two items are the bounds. Unparseable bounds fall
    back to 0 (start) or to the start (end). Malformed spans (``None``, a
    scalar, fewer than two items) become the empty span at 0 rather than raising.

    Returns:
        ``(start, end)`` with ``0 <= start <= end <= text_length``.
    """
    if isinstance(span, dict):
        start = span.get("start", span.get("begin", 0))
        end = span.get("end", span.get("stop", span.get("finish", 0)))
    else:
        try:
            start, end, *_ = span
        except (TypeError, ValueError):
            start, end = 0, 0
    try:
        start_i = int(start)
    except (TypeError, ValueError, OverflowError):  # OverflowError: int(inf)
        start_i = 0
    try:
        end_i = int(end)
    except (TypeError, ValueError, OverflowError):
        end_i = start_i
    start_i = max(0, min(text_length, start_i))
    end_i = max(start_i, min(text_length, end_i))
    return start_i, end_i


def _color_for_value(value: float, max_abs: float) -> Tuple[str, str, str]:
    """Map an attribution onto the diverging red / grey / green ramp.

    ``value / max_abs`` is clipped to [-1, 1]. Zero maps to ``_NEUTRAL_RGB``,
    and -1 / +1 map to ``_NEGATIVE_RGB`` / ``_POSITIVE_RGB``, with linear
    interpolation in between. Non-finite inputs, or ``max_abs <= 0``, render as
    neutral.

    Returns:
        ``(hex_color, rgba_background, sign)``. ``sign`` is ``"positive"``,
        ``"negative"`` or ``"neutral"``. The background alpha grows from 0.25
        to 0.70 with the clipped magnitude.
    """
    if max_abs <= 0 or not (math.isfinite(value) and math.isfinite(max_abs)):
        norm = 0.0
    else:
        norm = max(-1.0, min(1.0, value / max_abs))
    weight = abs(norm)
    target = _NEGATIVE_RGB if norm < 0 else _POSITIVE_RGB
    r, g, b = (
        int(round(neutral + (end - neutral) * weight))
        for neutral, end in zip(_NEUTRAL_RGB, target)
    )
    sign = "positive" if norm > 0 else "negative" if norm < 0 else "neutral"
    alpha = 0.25 + 0.45 * weight
    return f"#{r:02x}{g:02x}{b:02x}", f"rgba({r}, {g}, {b}, {alpha:.3f})", sign


def _build_sentence_interaction_values(values: List[float], method: str) -> Optional[InteractionValues]:
    """Wrap first-order attributions in a shapiq ``InteractionValues`` object.

    Returns None when the shapiq class is unavailable (``InteractionValues`` is
    the None placeholder) or when *values* is empty.
    """
    if InteractionValues is None:
        return None
    n_players = len(values)
    if n_players == 0:
        return None
    lookup = {(i,): i for i in range(n_players)}
    index = "SV" if method == "shapley" else ("IV" if method == "influence" else "BV")
    return InteractionValues(
        values=np.array(values, dtype=float),
        index=index,
        max_order=1,
        n_players=n_players,
        min_order=1,
        interaction_lookup=lookup,
        estimated=False,
        baseline_value=0.0,
    )


def create_attribution_heatmap(
    features: List[str],
    attributions: Dict[str, float],
    method: str = "shapley",
    title: Optional[str] = None,
) -> go.Figure:
    """
    Create a one-column heatmap of per-feature attributions.

    Args:
        features: Ordered feature list; rows are drawn in this order, top to bottom.
        attributions: Mapping from feature -> attribution value. Missing, None,
            NaN and infinite values are drawn as 0.
        method: Attribution method name, used for the title and colour scale.
        title: Optional chart title.

    Returns:
        A Plotly Figure; an empty Figure when *features* is empty.
    """
    raw_values = np.array([attributions.get(f, 0.0) for f in features], dtype=float)
    if raw_values.size == 0:
        return go.Figure()
    # Treat undefined scores as "no contribution": a single NaN would otherwise
    # make the colour range (zmin/zmax) NaN.
    raw_values[~np.isfinite(raw_values)] = 0.0

    max_abs = float(np.max(np.abs(raw_values)))
    scale = 1.0
    colorbar_title = f"{method.title()} value"
    if 0.0 < max_abs < 1e-4:
        # Tiny attributions would show as 0.0000 in the hover text. Rescale by an
        # exact power of ten so the "(×1e+06)" label states the true factor.
        scale = 10.0 ** (-math.floor(math.log10(max_abs)))
        colorbar_title = f"{method.title()} (×{scale:.0e})"
    values = raw_values * scale

    # Optional shapiq rendering (disabled while sentence_plot is the None placeholder).
    if sentence_plot is not None:
        iv = _build_sentence_interaction_values(values.tolist(), method)
        if iv is not None:
            result = sentence_plot(iv, words=features, show=False, chars_per_line=80)
            if result is not None:
                fig, _ = result
                return matplotlib_to_plotly(
                    fig,
                    title=title or f"{method.title()} token attributions",
                    height=max(300, 30 * len(features)),
                )

    # Position rows by index, not by label: Plotly merges rows that share a
    # categorical label, and repeated tokens ("the", ",") are common.
    row_ids = list(range(len(features)))
    labels = [format_feature_label(f, max_length=60) for f in features]
    # Clip the colour range at the 95th percentile of |value| so one outlier does
    # not wash out every other row. The floor keeps zmin < zmax when all are 0.
    vmax = max(float(np.percentile(np.abs(values), 95)), 1e-6)

    heatmap = go.Heatmap(
        z=values[:, None],
        x=["Attribution"],
        y=row_ids,
        customdata=[[label] for label in labels],
        colorscale=get_color_scale(method),
        zmid=0.0,
        zmin=-vmax,
        zmax=vmax,
        colorbar=dict(title=colorbar_title),
        hovertemplate="%{customdata}<br>%{x}: %{z:.4f}<extra></extra>",
        showscale=True,
    )
    fig = go.Figure(data=[heatmap])
    fig.update_layout(
        title=title or f"{method.title()} token attributions",
        xaxis=dict(showticklabels=False),
        yaxis=dict(
            autorange="reversed",
            tickmode="array",
            tickvals=row_ids,
            ticktext=labels,
        ),
        margin=dict(l=140, r=40, t=60, b=40),
        height=max(320, 22 * len(features)),
    )
    return fig


def _coerce_score(raw: Any) -> float:
    """Return *raw* as a finite float, or 0.0 if it is non-numeric, NaN or infinite."""
    try:
        value = float(raw)
    except (TypeError, ValueError, OverflowError):  # OverflowError: float(10**400)
        return 0.0
    return value if math.isfinite(value) else 0.0


def _render_token_flow(
    source_text: str,
    tokens: List[Dict[str, Any]],
    max_abs: float,
    method_label: str,
) -> str:
    """Render *source_text* as HTML, wrapping each token span in a coloured chip."""
    parts: List[str] = []
    cursor = 0
    # Walk spans in text order. Overlapping or nested spans are clipped to the
    # part not rendered yet, and fully covered ones are skipped, so no character
    # of the source is emitted twice.
    for token in sorted(tokens, key=lambda t: (t["start"], t["end"])):
        start, end = token["start"], token["end"]
        if start < cursor:
            if end <= cursor:
                continue
            start = cursor
        if cursor < start:
            plain = _format_text_segment(source_text[cursor:start], preserve_blank=True)
            parts.append(f'<span class="spex-plain">{plain}</span>')
        color_hex, background, sign = _color_for_value(token["value"], max_abs)
        # The tooltip reports the token's full (unclipped) span.
        tooltip = escape(
            f"{method_label} · chars [{token['start']}:{token['end']}] · {token['value']:+.4f}"
        )
        chip_text = _format_text_segment(source_text[start:end], preserve_blank=True)
        parts.append(
            f'<span class="spex-token spex-{sign}" title="{tooltip}" '
            f'style="background-color: {background}; border-color: {color_hex};">'
            f"{chip_text}"
            "</span>"
        )
        cursor = end
    if cursor < len(source_text):
        trailing = _format_text_segment(source_text[cursor:], preserve_blank=True)
        parts.append(f'<span class="spex-plain">{trailing}</span>')
    return "".join(parts) or "&nbsp;"


def _render_legend(method_label_html: str, max_abs: float) -> str:
    """Render the negative-to-positive colour legend shown under the token flow."""
    gradient = ", ".join(
        f"rgb({r}, {g}, {b})" for r, g, b in (_NEGATIVE_RGB, _NEUTRAL_RGB, _POSITIVE_RGB)
    )
    return (
        '<div class="spex-legend">'
        f'<div class="spex-legend-title">{method_label_html} legend</div>'
        '<div class="spex-legend-bar">'
        '<span class="spex-legend-label">Negative</span>'
        f'<span class="spex-legend-gradient" style="background: linear-gradient(90deg, {gradient});"></span>'
        '<span class="spex-legend-label">Positive</span>'
        "</div>"
        f'<div class="spex-legend-note">Normalized by max |value| = {max_abs:.4f}. '
        "Hover tokens for exact scores.</div>"
        "</div>"
    )


def _render_raw_text(source_text: str) -> str:
    """Render a collapsible block with the unannotated text ('' when there is none)."""
    if not source_text:
        return ""
    return (
        '<details class="spex-raw">'
        "<summary>Raw text</summary>"
        f'<div class="spex-raw-text">{_format_text_segment(source_text)}</div>'
        "</details>"
    )


def create_interactive_text_heatmap(
    text: str,
    feature_spans: List[Any],  # list of (start, end) or dict spans
    attributions: List[Any],
    method: str = "shapley",
) -> str:
    """
    Render a Spectral Explain–style text view with token chips, legend, and raw text.

    Args:
        text: Original text that generated the attributions.
        feature_spans: Character spans identifying each token/feature, as
            ``(start, end)`` pairs or dicts (see ``_normalize_span``). Spans may be
            unsorted; overlapping regions are rendered only once.
        attributions: Numeric attribution values aligned with feature_spans.
            Non-numeric, NaN and infinite entries are treated as 0.
        method: Attribution method label.

    Returns:
        Styled HTML that can be injected into the Gradio Text View tab.

    Raises:
        ValueError: If feature_spans and attributions differ in length.
    """
    if len(feature_spans) != len(attributions):
        raise ValueError("feature_spans and attributions must have the same length")

    source_text = text or ""
    text_len = len(source_text)

    tokens: List[Dict[str, Any]] = []
    for span, raw_value in zip(feature_spans, attributions):
        start, end = _normalize_span(span, text_len)
        tokens.append({"start": start, "end": end, "value": _coerce_score(raw_value)})

    if not tokens:
        fallback = _format_text_segment(source_text) or "No text available."
        return (
            f"{_SPEX_TEXT_STYLE}"
            '<div class="spex-text-view">'
            '<div class="spex-empty">No feature spans were provided for this example.</div>'
            '<details class="spex-raw" open>'
            "<summary>Raw text</summary>"
            f'<div class="spex-raw-text">{fallback}</div>'
            "</details>"
            "</div>"
        )

    # Colours are scaled by the largest |score|; an all-zero input uses 1.0 so
    # every chip renders neutral instead of dividing by zero.
    max_abs = max(abs(t["value"]) for t in tokens) or 1.0
    method_label = (method or "attribution").title()
    # The label may come from the UI, so escape it before placing it in markup.
    method_label_html = escape(method_label)

    flow_html = _render_token_flow(source_text, tokens, max_abs, method_label)
    return (
        f"{_SPEX_TEXT_STYLE}"
        '<div class="spex-text-view">'
        '<div class="spex-header">'
        "<div>"
        '<div class="spex-eyebrow">Context</div>'
        f'<div class="spex-title">{method_label_html} token attributions</div>'
        "</div>"
        f'<div class="spex-meta">Tokens: {len(tokens)}</div>'
        "</div>"
        f'<div class="spex-flow">{flow_html}</div>'
        f"{_render_legend(method_label_html, max_abs)}"
        f"{_render_raw_text(source_text)}"
        "</div>"
    )


def normalize_attributions(
    attributions: Dict[str, float], method: str = "minmax"
) -> Dict[str, float]:
    """
    Normalize attribution values for visualization.

    Args:
        attributions: Raw attribution dict {feature: value}.
        method: "zscore" for (x - mean) / std. Any other value (default
            "minmax") rescales linearly onto [-1, 1] for diverging colour scales.

    Returns:
        A dict with the same keys (same order) and normalized values. Constant
        inputs map to 0.0 under both modes (zscore treats std == 0 as 1).
        Non-finite inputs (NaN / inf) are excluded from the statistics and
        mapped to 0.0.
    """
    if not attributions:
        return {}

    values = np.array(list(attributions.values()), dtype=float)
    finite = np.isfinite(values)
    normalized = np.zeros_like(values)

    if finite.any():
        data = values[finite]
        if method == "zscore":
            std = float(data.std()) or 1.0
            normalized[finite] = (data - float(data.mean())) / std
        else:  # default to min-max
            v_min = float(data.min())
            v_span = float(data.max()) - v_min
            if v_span != 0:
                normalized[finite] = (data - v_min) / v_span * 2 - 1  # center at 0

    return {key: float(val) for key, val in zip(attributions, normalized)}