Spaces:
Sleeping
Sleeping
| """Token-level activation heatmaps, feature dashboards, and visualization utilities. | |
| Generates interactive Plotly visualizations for the dashboard: | |
| - Token-level feature activation heatmaps | |
| - Feature activation distributions | |
| - Steered vs. unsteered comparison displays | |
| - Layer-wise feature activity plots | |
| """ | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| from plotly.subplots import make_subplots | |
| from typing import Optional | |
def create_token_heatmap(
    str_tokens: list[str],
    activations: list[float],
    feature_idx: int,
    description: str = "",
    colorscale: str = "YlOrRd",
) -> go.Figure:
    """Create a single-row heatmap of one feature's activation per token.

    Tokens run along the x-axis; cell colour is proportional to the feature's
    activation on that token, and each cell is annotated with its value.

    Args:
        str_tokens: Token strings in sequence order. "▁" and "Ġ" space
            markers are rendered as plain spaces.
        activations: One activation value per token.
        feature_idx: Feature index, shown in the title.
        description: Optional feature description; its first 80 characters
            are appended to the title.
        colorscale: Plotly colorscale name.

    Returns:
        A Plotly figure containing one heatmap row.
    """
    # Clean up token strings for display
    display_tokens = [t.replace("▁", " ").replace("Ġ", " ") for t in str_tokens]

    n_cols = len(activations)
    # Exactly one label per heatmap column (pad/truncate on length mismatch)
    # so tick labels and hover data line up with z.
    col_labels = (display_tokens + [""] * n_cols)[:n_cols]
    # Plot at integer positions rather than using token strings as x values:
    # Plotly merges duplicate category labels, so repeated tokens ("the",
    # ",", ...) would otherwise collapse into a single column.
    positions = list(range(n_cols))

    # Reshape activations for heatmap (1 x n_tokens)
    z = np.array(activations).reshape(1, -1)
    fig = go.Figure(
        data=go.Heatmap(
            z=z,
            x=positions,
            y=["Activation"],
            colorscale=colorscale,
            text=[[f"{v:.3f}" for v in activations]],
            texttemplate="%{text}",
            textfont={"size": 10},
            # Token text travels in customdata so hover shows the token, not its index.
            customdata=[col_labels],
            hovertemplate="Token: %{customdata}<br>Activation: %{z:.4f}<extra></extra>",
        )
    )
    title = f"Feature #{feature_idx}"
    if description:
        title += f": {description[:80]}"
    fig.update_layout(
        title=title,
        xaxis_title="Token",
        height=150,
        margin=dict(l=60, r=20, t=40, b=40),
        xaxis=dict(
            tickangle=45,
            tickmode="array",
            tickvals=positions,
            ticktext=col_labels,
        ),
    )
    return fig
def create_multi_feature_heatmap(
    str_tokens: list[str],
    feature_data: list[dict],
    max_features: int = 10,
    colorscale: str = "YlOrRd",
) -> go.Figure:
    """Create a heatmap of several features' activations across tokens.

    Each row is a feature and each column is a token; colour intensity shows
    activation strength.

    Args:
        str_tokens: Token strings in sequence order ("▁"/"Ġ" markers are
            rendered as spaces).
        feature_data: Dicts with keys ``"feature_idx"``,
            ``"per_token_activations"`` and ``"description"``. Only the
            first ``max_features`` entries are plotted.
        max_features: Maximum number of feature rows to draw.
        colorscale: Plotly colorscale name.

    Returns:
        A Plotly heatmap figure with one row per plotted feature.
    """
    display_tokens = [t.replace("▁", " ").replace("Ġ", " ") for t in str_tokens]
    n_tokens = len(display_tokens)
    # Integer x positions: Plotly merges duplicate category labels, so using
    # the token strings directly would collapse repeated tokens into one column.
    positions = list(range(n_tokens))

    data = feature_data[:max_features]
    n_features = len(data)

    # Build the z-matrix: [n_features x n_tokens]; missing trailing values stay 0.
    z = np.zeros((n_features, n_tokens))
    y_labels = []
    for row, feat in enumerate(data):
        # Truncate to the token count; a longer activation list would
        # otherwise raise a broadcast error on assignment.
        acts = feat["per_token_activations"][:n_tokens]
        z[row, : len(acts)] = acts
        desc = (feat.get("description") or "")[:40]
        y_labels.append(f"#{feat['feature_idx']}: {desc}")

    fig = go.Figure(
        data=go.Heatmap(
            z=z,
            x=positions,
            y=y_labels,
            colorscale=colorscale,
            # Per-cell token text so hover shows the token rather than its index.
            customdata=[display_tokens for _ in range(n_features)],
            hovertemplate="Token: %{customdata}<br>Feature: %{y}<br>Activation: %{z:.4f}<extra></extra>",
        )
    )
    fig.update_layout(
        title="Top Active Features by Token (<bos> token skipped)",
        xaxis_title="Token",
        yaxis_title="Feature",
        height=max(300, 60 * n_features),
        margin=dict(l=200, r=20, t=40, b=60),
        xaxis=dict(
            tickangle=45,
            tickmode="array",
            tickvals=positions,
            ticktext=display_tokens,
        ),
    )
    return fig
def create_activation_histogram(
    activations: list[float],
    feature_idx: int,
    description: str = "",
    n_bins: int = 50,
) -> go.Figure:
    """Create a histogram of a feature's strictly positive activations.

    Values <= 0 are excluded from the histogram but still counted in the
    sparsity annotation (fraction of values that are not > 0).

    Args:
        activations: Activation values, one per token.
        feature_idx: Feature index, shown in the title.
        description: Optional description; its first 80 characters are shown
            on a second title line.
        n_bins: Maximum number of histogram bins (Plotly ``nbinsx``).

    Returns:
        A Plotly figure with the histogram and a sparsity annotation.
    """
    acts = np.array(activations)
    nonzero = acts[acts > 0]
    fig = make_subplots(rows=1, cols=1)
    if len(nonzero) > 0:
        fig.add_trace(
            go.Histogram(
                x=nonzero,
                nbinsx=n_bins,
                name="Non-zero activations",
                marker_color="steelblue",
            )
        )
    title = f"Feature #{feature_idx} Activation Distribution"
    if description:
        # Plotly text breaks lines on "<br>"; a literal "\n" is rendered as a space.
        title += f"<br>{description[:80]}"
    # An empty input counts as fully sparse.
    sparsity = 1.0 - (len(nonzero) / len(acts)) if len(acts) > 0 else 1.0
    fig.update_layout(
        title=title,
        xaxis_title="Activation Value",
        yaxis_title="Count",
        height=300,
        margin=dict(l=60, r=20, t=60, b=40),
        annotations=[
            dict(
                text=f"Sparsity: {sparsity:.1%} | Active: {len(nonzero)}/{len(acts)}",
                xref="paper",
                yref="paper",
                x=0.95,
                y=0.95,
                showarrow=False,
                font=dict(size=11),
            )
        ],
    )
    return fig
def create_steering_comparison(
    prompt: str,
    unsteered: str,
    steered: str,
    interventions: list[dict],
) -> str:
    """Create an HTML comparison of steered vs. unsteered text.

    Every interpolated value, including the intervention summary, is
    HTML-escaped. The result is safe to render in Gradio even when the prompt,
    the generations or the intervention fields contain ``<``, ``>`` or ``&``.

    Args:
        prompt: The input prompt, shown above both generations.
        unsteered: Text generated without interventions.
        steered: Text generated with the interventions applied.
        interventions: Dicts with at least ``"feature_idx"`` and a numeric
            ``"strength"`` (formatted to one decimal place).

    Returns:
        An HTML fragment: the prompt, the two generations side by side, and a
        one-line summary of the interventions.
    """
    import html

    prompt_safe = html.escape(prompt)
    unsteered_safe = html.escape(unsteered)
    steered_safe = html.escape(steered)
    # Escape the summary as well: feature_idx is interpolated verbatim, so a
    # non-integer value (e.g. a string label) could otherwise inject markup.
    intervention_desc = html.escape(
        ", ".join(
            f"Feature #{i['feature_idx']} (strength={i['strength']:.1f})"
            for i in interventions
        )
    )
    markup = f"""
<div style="font-family: monospace; padding: 10px;">
  <h3>Prompt</h3>
  <p style="background: #f0f0f0; padding: 10px; border-radius: 5px;">{prompt_safe}</p>
  <div style="display: flex; gap: 20px;">
    <div style="flex: 1;">
      <h3 style="color: #666;">Unsteered</h3>
      <p style="background: #f8f8f8; padding: 10px; border-radius: 5px;
         border-left: 3px solid #ccc; white-space: pre-wrap;">{unsteered_safe}</p>
    </div>
    <div style="flex: 1;">
      <h3 style="color: #2196F3;">Steered</h3>
      <p style="background: #f0f8ff; padding: 10px; border-radius: 5px;
         border-left: 3px solid #2196F3; white-space: pre-wrap;">{steered_safe}</p>
    </div>
  </div>
  <p style="color: #888; font-size: 0.9em;">
    Interventions: {intervention_desc}
  </p>
</div>
"""
    return markup
def create_top_predictions_comparison(
    clean_tokens: list[dict],
    steered_tokens: list[dict],
    kl_divergence: float,
) -> go.Figure:
    """Create side-by-side horizontal bar charts of top predicted tokens.

    Shows how steering changes the model's next-token distribution. Both
    panels share one probability axis range so bar lengths are directly
    comparable.

    Args:
        clean_tokens: Unsteered predictions; dicts with ``"token"`` and
            ``"prob"`` keys.
        steered_tokens: Steered predictions, same format.
        kl_divergence: KL divergence value, shown in the title.

    Returns:
        A two-panel Plotly figure.
    """
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=["Unsteered Predictions", "Steered Predictions"],
        horizontal_spacing=0.15,
    )

    # (predictions, bar colour, trace name, subplot column)
    panels = (
        (clean_tokens, "lightgray", "Unsteered", 1),
        (steered_tokens, "steelblue", "Steered", 2),
    )
    for tokens, color, name, col in panels:
        fig.add_trace(
            go.Bar(
                x=[t["prob"] for t in tokens],
                y=[t["token"] for t in tokens],
                orientation="h",
                marker_color=color,
                name=name,
            ),
            row=1,
            col=col,
        )

    # Use one probability scale for both panels. With independent autoscaling,
    # equal-length bars can stand for very different probabilities, which
    # defeats the side-by-side comparison.
    max_prob = max(
        (float(t["prob"]) for t in (*clean_tokens, *steered_tokens)),
        default=0.0,
    )
    x_max = max_prob * 1.05 if max_prob > 0 else 1.0

    fig.update_layout(
        title=f"Next-Token Predictions (KL Divergence: {kl_divergence:.4f})",
        height=400,
        showlegend=False,
        margin=dict(l=80, r=20, t=60, b=40),
    )
    fig.update_xaxes(title_text="Probability", range=[0, x_max])
    # Plotly draws the first category of a horizontal bar chart at the bottom.
    # Reverse it so the first list entry is on top, as the logit attribution
    # chart does. This assumes callers pass predictions most-likely first
    # (e.g. top-k order) — TODO confirm.
    fig.update_yaxes(autorange="reversed")
    return fig
def create_layer_activity_plot(
    layer_activations: dict[int, float],
    feature_idx: int,
    description: str = "",
) -> go.Figure:
    """Plot a feature's activation strength across layers as a bar chart.

    Shows which layers a feature is most active in, giving insight into where
    in the model's computation the feature matters.

    Args:
        layer_activations: Mapping from layer index to an activation value.
            Values are plotted as given; the y-axis is labelled "Mean
            Activation", so callers presumably pass per-layer means.
        feature_idx: Feature index, shown in the title.
        description: Optional description; its first 60 characters are shown
            on a second title line.

    Returns:
        A Plotly bar chart with one bar per layer, in ascending layer order.
    """
    layers = sorted(layer_activations)
    values = [layer_activations[layer] for layer in layers]
    fig = go.Figure(
        data=go.Bar(
            x=[f"Layer {layer}" for layer in layers],
            y=values,
            marker_color="steelblue",
        )
    )
    title = f"Feature #{feature_idx} Activity by Layer"
    if description:
        # Plotly text breaks lines on "<br>"; a literal "\n" is rendered as a space.
        title += f"<br>{description[:60]}"
    fig.update_layout(
        title=title,
        xaxis_title="Layer",
        yaxis_title="Mean Activation",
        height=350,
        margin=dict(l=60, r=20, t=60, b=60),
    )
    return fig
def create_logit_attribution_chart(
    top_positive: list[dict],
    top_negative: list[dict],
    bias: float,
    error: float,
    target_token: str,
    total_logit: float,
    descriptions: Optional[dict[int, str]] = None,
) -> go.Figure:
    """Build a horizontal bar chart of per-feature logit contributions.

    Bars appear top to bottom in this order: positive contributors (blue),
    negative contributors (red), then the SAE bias and the reconstruction
    error as two grey bars. The y-axis is reversed so the first entry is drawn
    at the top.

    Args:
        top_positive: Dicts with ``"feature_idx"`` and ``"contribution"``.
        top_negative: Same format as ``top_positive``.
        bias: SAE bias contribution.
        error: Reconstruction-error contribution.
        target_token: Token whose logit is decomposed (used in the title).
        total_logit: Total logit value (used in the title).
        descriptions: Optional map from feature index to description; the
            first 40 characters are appended to that feature's label.

    Returns:
        A Plotly horizontal bar figure.
    """
    lookup = descriptions or {}
    # Each entry is (label, contribution, bar colour).
    entries: list[tuple[str, float, str]] = []
    for group, shade in ((top_positive, "#2196F3"), (top_negative, "#F44336")):
        for item in group:
            fidx = item["feature_idx"]
            snippet = lookup[fidx][:40] if fidx in lookup else ""
            entries.append((f"#{fidx}: {snippet}", item["contribution"], shade))
    entries.append(("SAE bias", bias, "#9E9E9E"))
    entries.append(("Reconstruction error", error, "#757575"))

    bar_labels = [entry[0] for entry in entries]
    bar_values = [entry[1] for entry in entries]
    bar_colors = [entry[2] for entry in entries]

    figure = go.Figure(
        data=go.Bar(
            y=bar_labels,
            x=bar_values,
            orientation="h",
            marker_color=bar_colors,
            hovertemplate="<b>%{y}</b><br>Contribution: %{x:.4f}<extra></extra>",
        )
    )
    figure.update_layout(
        title=f'Feature contributions to "{target_token}" (total logit: {total_logit:.2f})',
        xaxis_title="Logit Contribution",
        height=max(400, 30 * len(bar_labels) + 100),
        margin=dict(l=250, r=20, t=60, b=40),
        yaxis=dict(autorange="reversed"),
    )
    return figure
def create_logit_decomposition_summary(
    sae_explained: float,
    bias: float,
    error: float,
    total: float,
) -> go.Figure:
    """Draw three side-by-side bars: feature contributions, SAE bias, error.

    The feature bar is ``sae_explained - bias``, so the three bars are
    disjoint parts of the logit. The title shows ``total`` and the gap between
    ``total`` and the sum of the three parts.

    Args:
        sae_explained: Logit portion explained by the SAE, bias included.
        bias: SAE bias contribution.
        error: Reconstruction-error contribution.
        total: Total logit value.

    Returns:
        A Plotly bar figure with each bar labelled by its value.
    """
    # Take the bias out of the SAE-explained amount to get the pure feature term.
    pure_features = sae_explained - bias
    parts = {
        "Feature contributions": (pure_features, "#2196F3"),
        "SAE bias": (bias, "#9E9E9E"),
        "Reconstruction error": (error, "#757575"),
    }
    names = list(parts)
    heights = [amount for amount, _ in parts.values()]
    palette = [shade for _, shade in parts.values()]

    figure = go.Figure(
        data=go.Bar(
            x=names,
            y=heights,
            marker_color=palette,
            text=[f"{h:.3f}" for h in heights],
            textposition="auto",
        )
    )
    # Whatever the three parts fail to account for.
    residual = total - (pure_features + bias + error)
    figure.update_layout(
        title=f"Logit Decomposition (total: {total:.3f}, gap: {residual:.4f})",
        yaxis_title="Logit Value",
        height=350,
        margin=dict(l=60, r=20, t=60, b=40),
    )
    return figure