Spaces:
Runtime error
Runtime error
| """ | |
| NeuroScope — Token-Layer Activation Grid | |
| Heatmap with tokens as columns and layers as rows. | |
| Color encodes activation magnitude (L2 norm) per token per layer, | |
| revealing how each token's representation evolves through the network. | |
| All charts use Plotly with the project dark theme (#1a1a2e bg, #e6b800 accent). | |
| """ | |
import html

import numpy as np
import plotly.graph_objects as go

from extraction import ExtractionResult
# ---------------------------------------------------------------------------
# Theme constants (project dark theme)
# ---------------------------------------------------------------------------
BG_COLOR = "#1a1a2e"      # plot-area background
PAPER_COLOR = "#1a1a2e"   # figure (paper) background; same shade as BG_COLOR
TEXT_COLOR = "#e0e0e0"    # axis titles, tick labels, colorbar text
ACCENT_COLOR = "#e6b800"  # figure title highlight
GRID_COLOR = "#2a2a4e"

# Purple-to-gold heatmap colorscale for activation intensity, written as
# (position, color) stops; Plotly expects a list of [position, color] pairs
# with positions running from 0.0 to 1.0.
_COLORSCALE_STOPS = (
    (0.0, "#0d0d1a"),
    (0.1, "#1a1040"),
    (0.25, "#2d1b69"),
    (0.4, "#5e2d8e"),
    (0.55, "#8e4585"),
    (0.7, "#c46a3a"),
    (0.85, "#e6b800"),
    (1.0, "#ffd633"),
)
TOKEN_LAYER_COLORSCALE = [[position, color] for position, color in _COLORSCALE_STOPS]
def _normalize_magnitudes(magnitudes: np.ndarray, normalize: str) -> np.ndarray:
    """Return a float64 copy of a 2-D magnitude matrix, min-max scaled per *normalize*.

    "global" scales over the whole matrix, "per_layer" over each row and
    "per_token" over each column. Any other value (including "none") returns
    the values unscaled. A slice whose values are all equal has no range to
    scale by and maps to 0.0, so every normalized mode yields values in [0, 1]
    instead of mixing raw magnitudes into the normalized colorscale.
    """
    display = np.array(magnitudes, dtype=np.float64)
    if normalize == "global":
        axis = None
    elif normalize == "per_layer":
        axis = 1  # rows are layers -> scale along each row
    elif normalize == "per_token":
        axis = 0  # columns are tokens -> scale along each column
    else:
        return display
    if display.size == 0:
        # min()/max() of an empty array raise; there is nothing to scale.
        return display
    lo = display.min(axis=axis, keepdims=True)
    span = display.max(axis=axis, keepdims=True) - lo
    # Divide by 1 where the span is zero (avoids 0/0), then force those cells to 0.
    scaled = (display - lo) / np.where(span > 0, span, 1.0)
    return np.where(span > 0, scaled, 0.0)


def _build_hover_text(tokens, magnitudes: np.ndarray, display: np.ndarray) -> list:
    """Return per-cell hover strings as a list of rows (layers) of columns (tokens)."""
    # Plotly parses hover text as a small HTML subset, so tokens such as
    # "<s>" or "</s>" must be escaped to be shown literally.
    safe_tokens = [html.escape(str(tok), quote=False) for tok in tokens]
    return [
        [
            f"Token: {tok}<br>"
            f"{'Embedding' if i == 0 else f'Layer {i - 1}'}<br>"
            f"L2 Norm: {magnitudes[i, j]:.2f}<br>"
            f"Normalized: {display[i, j]:.3f}"
            for j, tok in enumerate(safe_tokens)
        ]
        for i in range(magnitudes.shape[0])
    ]


def create_token_layer_grid(
    result: ExtractionResult,
    normalize: str = "global",
) -> go.Figure:
    """Create a token × layer activation magnitude heatmap.

    Each cell is the L2 norm of one token's hidden state at one layer. Row 0
    is labelled as the embedding output and row ``i`` as transformer layer
    ``i - 1``.

    Args:
        result: Extraction output; ``hidden_states`` and ``tokens`` are read.
        normalize: Normalization strategy:
            - "global": Scale to global min/max across all layers and tokens.
            - "per_layer": Normalize each row independently (highlights
              within-layer variation).
            - "per_token": Normalize each column independently (highlights
              depth evolution per token).
            - "none": Raw L2 norms.
            Any unrecognized value is treated like "none". In the normalized
            modes, a slice with no variation (all values equal) is shown as 0.

    Returns:
        Plotly Figure with interactive heatmap.
    """
    # Assumed layout: (num_layers + 1, seq_len, hidden_dim) with the embedding
    # output first — TODO confirm against the extraction module.
    hs = np.asarray(result.hidden_states)
    if hs.dtype == np.float16:
        # Summing squares in half precision overflows easily; upcast first.
        hs = hs.astype(np.float32)
    tokens = result.tokens
    num_rows = hs.shape[0]  # includes embedding layer
    seq_len = len(tokens)

    # L2 norm per token per layer -> (num_rows, seq_len)
    magnitudes = np.linalg.norm(hs, axis=-1)
    display = _normalize_magnitudes(magnitudes, normalize)
    is_normalized = normalize in ("global", "per_layer", "per_token")

    # Use integer column positions: token strings repeat (and collide further
    # once truncated), and duplicate categorical x values would collapse
    # distinct columns onto one. The token text goes into ticktext instead.
    x_positions = list(range(seq_len))
    tick_text = [html.escape(str(tok)[:12], quote=False) for tok in tokens]
    # Derive row labels from the array itself so they always match z's rows.
    y_labels = ["Embed" if i == 0 else f"L{i - 1}" for i in range(num_rows)]

    hover = _build_hover_text(tokens, magnitudes, display)

    fig = go.Figure(
        data=go.Heatmap(
            z=display,
            x=x_positions,
            y=y_labels,
            text=hover,
            hoverinfo="text",
            colorscale=TOKEN_LAYER_COLORSCALE,
            colorbar=dict(
                title=dict(
                    text="Norm. Activation" if is_normalized else "Activation",
                    font=dict(color=TEXT_COLOR),
                ),
                tickfont=dict(color=TEXT_COLOR),
            ),
        )
    )
    fig.update_layout(
        title=dict(
            text=f"Token × Layer Activation Grid (norm: {normalize})",
            font=dict(color=ACCENT_COLOR, size=14),
        ),
        xaxis=dict(
            title=dict(text="Token", font=dict(color=TEXT_COLOR, size=11)),
            tickfont=dict(color=TEXT_COLOR, size=9),
            tickmode="array",
            tickvals=x_positions,
            ticktext=tick_text,
            side="top",
            tickangle=45,
        ),
        yaxis=dict(
            title=dict(text="Layer", font=dict(color=TEXT_COLOR, size=11)),
            tickfont=dict(color=TEXT_COLOR, size=8),
            autorange="reversed",
        ),
        paper_bgcolor=PAPER_COLOR,
        plot_bgcolor=BG_COLOR,
        margin=dict(l=60, r=30, t=80, b=30),
        height=520,
    )
    return fig