"""Token-level activation heatmaps, feature dashboards, and visualization utilities. Generates interactive Plotly visualizations for the dashboard: - Token-level feature activation heatmaps - Feature activation distributions - Steered vs. unsteered comparison displays - Layer-wise feature activity plots """ import numpy as np import plotly.graph_objects as go import plotly.express as px from plotly.subplots import make_subplots from typing import Optional def create_token_heatmap( str_tokens: list[str], activations: list[float], feature_idx: int, description: str = "", colorscale: str = "YlOrRd", ) -> go.Figure: """Create a heatmap showing feature activation per token. Displays tokens along the x-axis with color intensity proportional to the feature's activation on that token. """ # Clean up token strings for display display_tokens = [t.replace("▁", " ").replace("Ġ", " ") for t in str_tokens] # Reshape activations for heatmap (1 x n_tokens) z = np.array(activations).reshape(1, -1) fig = go.Figure( data=go.Heatmap( z=z, x=display_tokens, y=["Activation"], colorscale=colorscale, text=[[f"{v:.3f}" for v in activations]], texttemplate="%{text}", textfont={"size": 10}, hovertemplate="Token: %{x}
Activation: %{z:.4f}", ) ) title = f"Feature #{feature_idx}" if description: title += f": {description[:80]}" fig.update_layout( title=title, xaxis_title="Token", height=150, margin=dict(l=60, r=20, t=40, b=40), xaxis=dict(tickangle=45), ) return fig def create_multi_feature_heatmap( str_tokens: list[str], feature_data: list[dict], max_features: int = 10, colorscale: str = "YlOrRd", ) -> go.Figure: """Create a heatmap showing multiple features' activations across tokens. Each row is a feature, each column is a token. Color intensity shows activation strength. """ display_tokens = [t.replace("▁", " ").replace("Ġ", " ") for t in str_tokens] data = feature_data[:max_features] n_features = len(data) # Build the z-matrix: [n_features x n_tokens] z = np.zeros((n_features, len(str_tokens))) y_labels = [] for i, feat in enumerate(data): acts = feat["per_token_activations"] z[i, : len(acts)] = acts desc = feat["description"][:40] y_labels.append(f"#{feat['feature_idx']}: {desc}") fig = go.Figure( data=go.Heatmap( z=z, x=display_tokens, y=y_labels, colorscale=colorscale, hovertemplate="Token: %{x}
Feature: %{y}
Activation: %{z:.4f}", ) ) fig.update_layout( title="Top Active Features by Token ( token skipped)", xaxis_title="Token", yaxis_title="Feature", height=max(300, 60 * n_features), margin=dict(l=200, r=20, t=40, b=60), xaxis=dict(tickangle=45), ) return fig def create_activation_histogram( activations: list[float], feature_idx: int, description: str = "", n_bins: int = 50, ) -> go.Figure: """Create a histogram of feature activations across tokens.""" acts = np.array(activations) nonzero = acts[acts > 0] fig = make_subplots(rows=1, cols=1) if len(nonzero) > 0: fig.add_trace( go.Histogram( x=nonzero, nbinsx=n_bins, name="Non-zero activations", marker_color="steelblue", ) ) title = f"Feature #{feature_idx} Activation Distribution" if description: title += f"\n{description[:80]}" sparsity = 1.0 - (len(nonzero) / len(acts)) if len(acts) > 0 else 1.0 fig.update_layout( title=title, xaxis_title="Activation Value", yaxis_title="Count", height=300, margin=dict(l=60, r=20, t=60, b=40), annotations=[ dict( text=f"Sparsity: {sparsity:.1%} | Active: {len(nonzero)}/{len(acts)}", xref="paper", yref="paper", x=0.95, y=0.95, showarrow=False, font=dict(size=11), ) ], ) return fig def create_steering_comparison( prompt: str, unsteered: str, steered: str, interventions: list[dict], ) -> str: """Create an HTML comparison of steered vs. unsteered text. Returns formatted HTML string for display in Gradio. """ import html prompt_safe = html.escape(prompt) unsteered_safe = html.escape(unsteered) steered_safe = html.escape(steered) intervention_desc = ", ".join( f"Feature #{i['feature_idx']} (strength={i['strength']:.1f})" for i in interventions ) markup = f"""

Prompt

{prompt_safe}

Unsteered

{unsteered_safe}

Steered

{steered_safe}

Interventions: {intervention_desc}

""" return markup def create_top_predictions_comparison( clean_tokens: list[dict], steered_tokens: list[dict], kl_divergence: float, ) -> go.Figure: """Create a side-by-side bar chart comparing top predicted tokens. Shows how steering changes the model's next-token distribution. """ fig = make_subplots( rows=1, cols=2, subplot_titles=["Unsteered Predictions", "Steered Predictions"], horizontal_spacing=0.15, ) # Clean predictions fig.add_trace( go.Bar( x=[t["prob"] for t in clean_tokens], y=[t["token"] for t in clean_tokens], orientation="h", marker_color="lightgray", name="Unsteered", ), row=1, col=1, ) # Steered predictions fig.add_trace( go.Bar( x=[t["prob"] for t in steered_tokens], y=[t["token"] for t in steered_tokens], orientation="h", marker_color="steelblue", name="Steered", ), row=1, col=2, ) fig.update_layout( title=f"Next-Token Predictions (KL Divergence: {kl_divergence:.4f})", height=400, showlegend=False, margin=dict(l=80, r=20, t=60, b=40), ) fig.update_xaxes(title_text="Probability", row=1, col=1) fig.update_xaxes(title_text="Probability", row=1, col=2) return fig def create_layer_activity_plot( layer_activations: dict[int, float], feature_idx: int, description: str = "", ) -> go.Figure: """Plot feature activation strength across layers. Shows which layers a feature is most active in, giving insight into where in the model's computation the feature matters. """ layers = sorted(layer_activations.keys()) values = [layer_activations[l] for l in layers] fig = go.Figure( data=go.Bar( x=[f"Layer {l}" for l in layers], y=values, marker_color="steelblue", ) ) title = f"Feature #{feature_idx} Activity by Layer" if description: title += f"\n{description[:60]}" fig.update_layout( title=title, xaxis_title="Layer", yaxis_title="Mean Activation", height=350, margin=dict(l=60, r=20, t=60, b=60), ) return fig def create_logit_attribution_chart( top_positive: list[dict], top_negative: list[dict], bias: float, error: float, target_token: str, total_logit: float, descriptions: Optional[dict[int, str]] = None, ) -> go.Figure: """Create a horizontal bar chart of per-feature logit contributions. Positive contributions shown in blue (right), negative in red (left). Includes bias and reconstruction error as separate bars. """ labels = [] values = [] colors = [] # Add positive contributors (largest first) for feat in top_positive: idx = feat["feature_idx"] desc = "" if descriptions and idx in descriptions: desc = descriptions[idx][:40] labels.append(f"#{idx}: {desc}") values.append(feat["contribution"]) colors.append("#2196F3") # Add negative contributors (most negative first) for feat in top_negative: idx = feat["feature_idx"] desc = "" if descriptions and idx in descriptions: desc = descriptions[idx][:40] labels.append(f"#{idx}: {desc}") values.append(feat["contribution"]) colors.append("#F44336") # Add bias and error labels.append("SAE bias") values.append(bias) colors.append("#9E9E9E") labels.append("Reconstruction error") values.append(error) colors.append("#757575") fig = go.Figure( data=go.Bar( y=labels, x=values, orientation="h", marker_color=colors, hovertemplate="%{y}
Contribution: %{x:.4f}", ) ) fig.update_layout( title=f'Feature contributions to "{target_token}" (total logit: {total_logit:.2f})', xaxis_title="Logit Contribution", height=max(400, 30 * len(labels) + 100), margin=dict(l=250, r=20, t=60, b=40), yaxis=dict(autorange="reversed"), ) return fig def create_logit_decomposition_summary( sae_explained: float, bias: float, error: float, total: float, ) -> go.Figure: """Create a stacked bar chart showing SAE-explained vs bias vs error portions.""" feature_sum = sae_explained - bias # isolate pure feature contributions labels = ["Feature contributions", "SAE bias", "Reconstruction error"] values = [feature_sum, bias, error] bar_colors = ["#2196F3", "#9E9E9E", "#757575"] fig = go.Figure( data=go.Bar( x=labels, y=values, marker_color=bar_colors, text=[f"{v:.3f}" for v in values], textposition="auto", ) ) gap = total - (feature_sum + bias + error) fig.update_layout( title=f"Logit Decomposition (total: {total:.3f}, gap: {gap:.4f})", yaxis_title="Logit Value", height=350, margin=dict(l=60, r=20, t=60, b=40), ) return fig