Prompt
{prompt_safe}
Unsteered
{unsteered_safe}
Interventions: {intervention_desc}
"""
return markup
def create_top_predictions_comparison(
    clean_tokens: list[dict],
    steered_tokens: list[dict],
    kl_divergence: float,
) -> go.Figure:
    """Build a two-panel horizontal bar chart of next-token predictions.

    The left panel shows the unsteered predictions and the right panel the
    steered ones, so the effect of steering on the distribution can be
    compared at a glance.

    Args:
        clean_tokens: Prediction entries for the unsteered run. Each entry
            is read via its ``"prob"`` and ``"token"`` keys.
        steered_tokens: Prediction entries for the steered run, read via
            the same two keys.
        kl_divergence: Value shown in the figure title, formatted to four
            decimal places.

    Returns:
        The assembled figure with one bar trace per panel.
    """
    # (subplot title, trace name, token entries, bar colour), left to right.
    panels = (
        ("Unsteered Predictions", "Unsteered", clean_tokens, "lightgray"),
        ("Steered Predictions", "Steered", steered_tokens, "steelblue"),
    )

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=[panel_title for panel_title, _, _, _ in panels],
        horizontal_spacing=0.15,
    )

    for column, (_, trace_name, entries, colour) in enumerate(panels, start=1):
        bars = go.Bar(
            x=[entry["prob"] for entry in entries],
            y=[entry["token"] for entry in entries],
            orientation="h",
            marker_color=colour,
            name=trace_name,
        )
        fig.add_trace(bars, row=1, col=column)

    fig.update_layout(
        title=f"Next-Token Predictions (KL Divergence: {kl_divergence:.4f})",
        height=400,
        showlegend=False,
        margin=dict(l=80, r=20, t=60, b=40),
    )

    # Both panels share the same x-axis meaning.
    for column in (1, 2):
        fig.update_xaxes(title_text="Probability", row=1, col=column)

    return fig
def create_layer_activity_plot(
    layer_activations: dict[int, float],
    feature_idx: int,
    description: str = "",
) -> go.Figure:
    """Plot feature activation strength across layers.

    Shows which layers a feature is most active in, giving insight
    into where in the model's computation the feature matters.

    Args:
        layer_activations: Mapping from layer index to the activation value
            plotted for that layer.
        feature_idx: Feature number shown in the figure title.
        description: Optional feature description. When non-empty, its
            first 60 characters are appended to the title on a second line.

    Returns:
        A bar chart with one bar per layer, in ascending layer-index order.
    """
    # Sort the keys so bars appear in ascending layer order regardless of
    # the dict's insertion order.
    layers = sorted(layer_activations)

    fig = go.Figure(
        data=go.Bar(
            x=[f"Layer {layer}" for layer in layers],
            y=[layer_activations[layer] for layer in layers],
            marker_color="steelblue",
        )
    )

    title = f"Feature #{feature_idx} Activity by Layer"
    if description:
        # Plotly renders figure text as pseudo-HTML: a "\n" is not shown as
        # a line break, so "<br>" is required to put the description on its
        # own line under the heading.
        title += f"<br>{description[:60]}"

    fig.update_layout(
        title=title,
        xaxis_title="Layer",
        yaxis_title="Mean Activation",
        height=350,
        margin=dict(l=60, r=20, t=60, b=60),
    )
    return fig
def create_logit_attribution_chart(
top_positive: list[dict],
top_negative: list[dict],
bias: float,
error: float,
target_token: str,
total_logit: float,
descriptions: Optional[dict[int, str]] = None,
) -> go.Figure:
"""Create a horizontal bar chart of per-feature logit contributions.
Positive contributions shown in blue (right), negative in red (left).
Includes bias and reconstruction error as separate bars.
"""
labels = []
values = []
colors = []
# Add positive contributors (largest first)
for feat in top_positive:
idx = feat["feature_idx"]
desc = ""
if descriptions and idx in descriptions:
desc = descriptions[idx][:40]
labels.append(f"#{idx}: {desc}")
values.append(feat["contribution"])
colors.append("#2196F3")
# Add negative contributors (most negative first)
for feat in top_negative:
idx = feat["feature_idx"]
desc = ""
if descriptions and idx in descriptions:
desc = descriptions[idx][:40]
labels.append(f"#{idx}: {desc}")
values.append(feat["contribution"])
colors.append("#F44336")
# Add bias and error
labels.append("SAE bias")
values.append(bias)
colors.append("#9E9E9E")
labels.append("Reconstruction error")
values.append(error)
colors.append("#757575")
fig = go.Figure(
data=go.Bar(
y=labels,
x=values,
orientation="h",
marker_color=colors,
hovertemplate="