SteerTheShip / visualization.py
benbatman's picture
feature attribution, bug fixes
7b90f10
"""Token-level activation heatmaps, feature dashboards, and visualization utilities.
Generates interactive Plotly visualizations for the dashboard:
- Token-level feature activation heatmaps
- Feature activation distributions
- Steered vs. unsteered comparison displays
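- Layer-wise feature activity plots
- Next-token prediction comparisons (steered vs. unsteered)
- Per-feature logit attribution and logit decomposition charts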
"""
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from typing import Optional
def create_token_heatmap(
    str_tokens: list[str],
    activations: list[float],
    feature_idx: int,
    description: str = "",
    colorscale: str = "YlOrRd",
) -> go.Figure:
    """Create a one-row heatmap showing a feature's activation on each token.

    Columns are placed at integer positions and labelled with the token text
    through the x-axis tick labels. Passing the token strings directly as x
    values would make Plotly treat them as categories, so a token occurring
    more than once in the sequence would collapse into a single column.

    Args:
        str_tokens: Token strings, one per column.
        activations: Activation value of the feature on each token.
        feature_idx: Feature index, shown in the figure title.
        description: Optional feature description; the first 80 characters
            are appended to the title.
        colorscale: Name of the Plotly colorscale used for the cells.

    Returns:
        A Plotly figure with a single ``Heatmap`` trace of shape
        ``1 x len(activations)``.
    """
    # Replace the "▁" and "Ġ" marker characters with spaces for display.
    display_tokens = [t.replace("▁", " ").replace("Ġ", " ") for t in str_tokens]

    # Reshape activations for the heatmap (1 x n_tokens).
    z = np.array(activations).reshape(1, -1)
    n_cols = z.shape[1]
    positions = list(range(n_cols))

    # One label per column: pad with empty strings or truncate so the labels
    # always line up with the activation values.
    labels = (display_tokens + [""] * n_cols)[:n_cols]

    fig = go.Figure(
        data=go.Heatmap(
            z=z,
            x=positions,
            y=["Activation"],
            colorscale=colorscale,
            text=[[f"{v:.3f}" for v in activations]],
            texttemplate="%{text}",
            textfont={"size": 10},
            # x is now a numeric position, so the token text for the hover
            # label travels in customdata (same shape as z).
            customdata=[labels],
            hovertemplate="Token: %{customdata}<br>Activation: %{z:.4f}<extra></extra>",
        )
    )

    title = f"Feature #{feature_idx}"
    if description:
        title += f": {description[:80]}"

    fig.update_layout(
        title=title,
        xaxis_title="Token",
        height=150,
        margin=dict(l=60, r=20, t=40, b=40),
        xaxis=dict(
            tickangle=45,
            tickmode="array",
            tickvals=positions,
            ticktext=labels,
        ),
    )
    return fig
def create_multi_feature_heatmap(
    str_tokens: list[str],
    feature_data: list[dict],
    max_features: int = 10,
    colorscale: str = "YlOrRd",
) -> go.Figure:
    """Create a heatmap of several features' activations across tokens.

    Each row is a feature and each column is a token; color intensity shows
    activation strength. Columns are placed at integer positions and labelled
    with the token text via tick labels, because using the token strings as
    x values would make Plotly merge repeated tokens into one category.

    Args:
        str_tokens: Token strings, one per column.
        feature_data: One dict per feature, read through the keys
            ``"per_token_activations"``, ``"description"`` and
            ``"feature_idx"``.
        max_features: Only the first ``max_features`` entries are plotted.
        colorscale: Name of the Plotly colorscale used for the cells.

    Returns:
        A Plotly figure with a single ``Heatmap`` trace of shape
        ``n_features x len(str_tokens)``.
    """
    display_tokens = [t.replace("▁", " ").replace("Ġ", " ") for t in str_tokens]
    n_tokens = len(display_tokens)
    positions = list(range(n_tokens))

    shown = feature_data[:max_features]
    n_features = len(shown)

    # Build the z-matrix: [n_features x n_tokens]. Rows shorter than the token
    # list stay zero-padded on the right; rows longer than it are truncated so
    # the assignment cannot fail on a length mismatch.
    z = np.zeros((n_features, n_tokens))
    y_labels = []
    for row, feat in enumerate(shown):
        acts = feat["per_token_activations"][:n_tokens]
        z[row, : len(acts)] = acts
        desc = feat["description"][:40]
        y_labels.append(f"#{feat['feature_idx']}: {desc}")

    fig = go.Figure(
        data=go.Heatmap(
            z=z,
            x=positions,
            y=y_labels,
            colorscale=colorscale,
            # x is a numeric position, so the token text for the hover label
            # travels in customdata (one row of labels per feature row).
            customdata=[display_tokens] * n_features,
            hovertemplate="Token: %{customdata}<br>Feature: %{y}<br>Activation: %{z:.4f}<extra></extra>",
        )
    )
    fig.update_layout(
        title="Top Active Features by Token (<bos> token skipped)",
        xaxis_title="Token",
        yaxis_title="Feature",
        height=max(300, 60 * n_features),
        margin=dict(l=200, r=20, t=40, b=60),
        xaxis=dict(
            tickangle=45,
            tickmode="array",
            tickvals=positions,
            ticktext=display_tokens,
        ),
    )
    return fig
def create_activation_histogram(
    activations: list[float],
    feature_idx: int,
    description: str = "",
    n_bins: int = 50,
) -> go.Figure:
    """Create a histogram of a feature's positive activations across tokens.

    Only values strictly greater than zero are binned; everything else counts
    as inactive for the sparsity figure shown in the corner annotation.

    Args:
        activations: Activation values of the feature.
        feature_idx: Feature index, shown in the figure title.
        description: Optional feature description; the first 80 characters
            are placed on a second title line.
        n_bins: Maximum number of histogram bins (Plotly ``nbinsx``).

    Returns:
        A Plotly figure. It contains a ``Histogram`` trace only when at least
        one activation is positive; otherwise the axes are left empty.
    """
    acts = np.array(activations)
    nonzero = acts[acts > 0]

    fig = make_subplots(rows=1, cols=1)
    if len(nonzero) > 0:
        fig.add_trace(
            go.Histogram(
                x=nonzero,
                nbinsx=n_bins,
                name="Non-zero activations",
                marker_color="steelblue",
            )
        )

    title = f"Feature #{feature_idx} Activation Distribution"
    if description:
        # Plotly renders "\n" in text as a space; "<br>" is its line break.
        title += f"<br>{description[:80]}"

    # Fraction of inactive entries; defined as fully sparse for empty input.
    sparsity = 1.0 - (len(nonzero) / len(acts)) if len(acts) > 0 else 1.0

    fig.update_layout(
        title=title,
        xaxis_title="Activation Value",
        yaxis_title="Count",
        height=300,
        margin=dict(l=60, r=20, t=60, b=40),
        annotations=[
            dict(
                text=f"Sparsity: {sparsity:.1%} | Active: {len(nonzero)}/{len(acts)}",
                xref="paper",
                yref="paper",
                x=0.95,
                y=0.95,
                showarrow=False,
                font=dict(size=11),
            )
        ],
    )
    return fig
def create_steering_comparison(
    prompt: str,
    unsteered: str,
    steered: str,
    interventions: list[dict],
) -> str:
    """Create an HTML comparison of steered vs. unsteered text.

    Every interpolated value is HTML-escaped before being placed in the
    markup, including the intervention summary line.

    Args:
        prompt: The prompt text shown at the top.
        unsteered: Text shown in the "Unsteered" column.
        steered: Text shown in the "Steered" column.
        interventions: One dict per intervention, read through the keys
            ``"feature_idx"`` and ``"strength"`` (formatted with one decimal).

    Returns:
        Formatted HTML string for display in Gradio.
    """
    import html

    prompt_safe = html.escape(prompt)
    unsteered_safe = html.escape(unsteered)
    steered_safe = html.escape(steered)

    # Escape the summary as well so a non-numeric feature identifier cannot
    # inject markup; for integer indices the text is unchanged.
    intervention_desc = html.escape(
        ", ".join(
            f"Feature #{item['feature_idx']} (strength={item['strength']:.1f})"
            for item in interventions
        )
    )

    markup = f"""
<div style="font-family: monospace; padding: 10px;">
<h3>Prompt</h3>
<p style="background: #f0f0f0; padding: 10px; border-radius: 5px;">{prompt_safe}</p>
<div style="display: flex; gap: 20px;">
<div style="flex: 1;">
<h3 style="color: #666;">Unsteered</h3>
<p style="background: #f8f8f8; padding: 10px; border-radius: 5px;
border-left: 3px solid #ccc; white-space: pre-wrap;">{unsteered_safe}</p>
</div>
<div style="flex: 1;">
<h3 style="color: #2196F3;">Steered</h3>
<p style="background: #f0f8ff; padding: 10px; border-radius: 5px;
border-left: 3px solid #2196F3; white-space: pre-wrap;">{steered_safe}</p>
</div>
</div>
<p style="color: #888; font-size: 0.9em;">
Interventions: {intervention_desc}
</p>
</div>
"""
    return markup
def create_top_predictions_comparison(
    clean_tokens: list[dict],
    steered_tokens: list[dict],
    kl_divergence: float,
) -> go.Figure:
    """Build two side-by-side horizontal bar charts of top predicted tokens.

    The left panel shows the unsteered predictions and the right panel the
    steered ones, so the shift in the next-token distribution is visible at a
    glance. The KL divergence is reported in the figure title.

    Args:
        clean_tokens: Unsteered predictions; each dict is read through the
            keys ``"prob"`` and ``"token"``.
        steered_tokens: Steered predictions, same dict layout.
        kl_divergence: Value shown in the title with four decimals.

    Returns:
        A Plotly figure with one ``Bar`` trace per panel.
    """
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=["Unsteered Predictions", "Steered Predictions"],
        horizontal_spacing=0.15,
    )

    # (predictions, bar color, trace name) for column 1 and column 2.
    panels = (
        (clean_tokens, "lightgray", "Unsteered"),
        (steered_tokens, "steelblue", "Steered"),
    )
    for column, (predictions, bar_color, trace_name) in enumerate(panels, start=1):
        probabilities = [entry["prob"] for entry in predictions]
        token_labels = [entry["token"] for entry in predictions]
        bars = go.Bar(
            x=probabilities,
            y=token_labels,
            orientation="h",
            marker_color=bar_color,
            name=trace_name,
        )
        fig.add_trace(bars, row=1, col=column)

    fig.update_layout(
        title=f"Next-Token Predictions (KL Divergence: {kl_divergence:.4f})",
        height=400,
        showlegend=False,
        margin=dict(l=80, r=20, t=60, b=40),
    )
    for column in (1, 2):
        fig.update_xaxes(title_text="Probability", row=1, col=column)
    return fig
def create_layer_activity_plot(
    layer_activations: dict[int, float],
    feature_idx: int,
    description: str = "",
) -> go.Figure:
    """Plot a feature's activation strength across layers as a bar chart.

    Shows which layers a feature is most active in, giving insight into where
    in the model's computation the feature matters.

    Args:
        layer_activations: Mapping from layer index to the activation value
            plotted for that layer.
        feature_idx: Feature index, shown in the figure title.
        description: Optional feature description; the first 60 characters
            are placed on a second title line.

    Returns:
        A Plotly figure with one bar per layer, ordered by ascending layer
        index.
    """
    layers = sorted(layer_activations.keys())
    values = [layer_activations[layer] for layer in layers]

    fig = go.Figure(
        data=go.Bar(
            x=[f"Layer {layer}" for layer in layers],
            y=values,
            marker_color="steelblue",
        )
    )

    title = f"Feature #{feature_idx} Activity by Layer"
    if description:
        # Plotly renders "\n" in text as a space; "<br>" is its line break.
        title += f"<br>{description[:60]}"

    fig.update_layout(
        title=title,
        xaxis_title="Layer",
        yaxis_title="Mean Activation",
        height=350,
        margin=dict(l=60, r=20, t=60, b=60),
    )
    return fig
def create_logit_attribution_chart(
    top_positive: list[dict],
    top_negative: list[dict],
    bias: float,
    error: float,
    target_token: str,
    total_logit: float,
    descriptions: Optional[dict[int, str]] = None,
) -> go.Figure:
    """Build a horizontal bar chart of per-feature logit contributions.

    Bars appear in this order from the top: the entries of ``top_positive``
    (blue), the entries of ``top_negative`` (red), then the SAE bias and the
    reconstruction error (two shades of gray).

    Args:
        top_positive: Feature dicts read through the keys ``"feature_idx"``
            and ``"contribution"``; drawn in blue.
        top_negative: Same layout as ``top_positive``; drawn in red.
        bias: Value of the "SAE bias" bar.
        error: Value of the "Reconstruction error" bar.
        target_token: Token named in the figure title.
        total_logit: Total logit shown in the title with two decimals.
        descriptions: Optional mapping from feature index to description; the
            first 40 characters are appended to the bar label when present.

    Returns:
        A Plotly figure with a single horizontal ``Bar`` trace.
    """
    labels = []
    values = []
    colors = []

    # Positive group first, then negative; each keeps the caller's ordering.
    colored_groups = ((top_positive, "#2196F3"), (top_negative, "#F44336"))
    for group, bar_color in colored_groups:
        for entry in group:
            feature_id = entry["feature_idx"]
            if descriptions and feature_id in descriptions:
                snippet = descriptions[feature_id][:40]
            else:
                snippet = ""
            labels.append(f"#{feature_id}: {snippet}")
            values.append(entry["contribution"])
            colors.append(bar_color)

    # Bias and reconstruction error always close the list.
    labels += ["SAE bias", "Reconstruction error"]
    values += [bias, error]
    colors += ["#9E9E9E", "#757575"]

    bars = go.Bar(
        y=labels,
        x=values,
        orientation="h",
        marker_color=colors,
        hovertemplate="<b>%{y}</b><br>Contribution: %{x:.4f}<extra></extra>",
    )
    fig = go.Figure(data=bars)

    # Grow the figure with the number of bars, never below 400 px.
    chart_height = max(400, 30 * len(labels) + 100)
    fig.update_layout(
        title=f'Feature contributions to "{target_token}" (total logit: {total_logit:.2f})',
        xaxis_title="Logit Contribution",
        height=chart_height,
        margin=dict(l=250, r=20, t=60, b=40),
        # Reversed so the first label is drawn at the top of the chart.
        yaxis=dict(autorange="reversed"),
    )
    return fig
def create_logit_decomposition_summary(
    sae_explained: float,
    bias: float,
    error: float,
    total: float,
) -> go.Figure:
    """Build a three-bar chart splitting a logit into its components.

    The bars are the pure feature contributions (``sae_explained - bias``),
    the SAE bias, and the reconstruction error. The title reports ``total``
    and the gap between ``total`` and the sum of the three bars.

    Args:
        sae_explained: SAE-explained portion of the logit; the bias is
            subtracted from it to obtain the feature-only bar.
        bias: Value of the "SAE bias" bar.
        error: Value of the "Reconstruction error" bar.
        total: Total logit value used for the title and the gap.

    Returns:
        A Plotly figure with a single ``Bar`` trace of three bars.
    """
    # Isolate the pure feature contributions from the SAE-explained part.
    feature_sum = sae_explained - bias

    components = {
        "Feature contributions": feature_sum,
        "SAE bias": bias,
        "Reconstruction error": error,
    }
    names = list(components)
    amounts = list(components.values())

    # Whatever part of the total the three components do not account for.
    gap = total - (feature_sum + bias + error)

    bars = go.Bar(
        x=names,
        y=amounts,
        marker_color=["#2196F3", "#9E9E9E", "#757575"],
        text=[f"{amount:.3f}" for amount in amounts],
        textposition="auto",
    )
    fig = go.Figure(data=bars)
    fig.update_layout(
        title=f"Logit Decomposition (total: {total:.3f}, gap: {gap:.4f})",
        yaxis_title="Logit Value",
        height=350,
        margin=dict(l=60, r=20, t=60, b=40),
    )
    return fig