""" NeuroScope — Token-Layer Activation Grid Heatmap with tokens as columns and layers as rows. Color encodes activation magnitude (L2 norm) per token per layer, revealing how each token's representation evolves through the network. All charts use Plotly with the project dark theme (#1a1a2e bg, #e6b800 accent). """ import numpy as np import plotly.graph_objects as go from extraction import ExtractionResult # --------------------------------------------------------------------------- # Theme constants # --------------------------------------------------------------------------- BG_COLOR = "#1a1a2e" PAPER_COLOR = "#1a1a2e" TEXT_COLOR = "#e0e0e0" ACCENT_COLOR = "#e6b800" GRID_COLOR = "#2a2a4e" # Custom purple-to-gold heatmap colorscale for activation intensity TOKEN_LAYER_COLORSCALE = [ [0.0, "#0d0d1a"], [0.1, "#1a1040"], [0.25, "#2d1b69"], [0.4, "#5e2d8e"], [0.55, "#8e4585"], [0.7, "#c46a3a"], [0.85, "#e6b800"], [1.0, "#ffd633"], ] def create_token_layer_grid( result: ExtractionResult, normalize: str = "global", ) -> go.Figure: """Create a token × layer activation magnitude heatmap. Args: result: Extraction output containing hidden states. normalize: Normalization strategy: - "global": Scale to global min/max across all layers and tokens. - "per_layer": Normalize each row independently (highlights within-layer variation). - "per_token": Normalize each column independently (highlights depth evolution per token). - "none": Raw L2 norms. Returns: Plotly Figure with interactive heatmap. """ # hidden_states: (num_layers+1, seq_len, hidden_dim) hs = result.hidden_states tokens = result.tokens num_layers_total = hs.shape[0] # includes embedding layer seq_len = len(tokens) # Compute L2 norm per token per layer → (num_layers+1, seq_len) magnitudes = np.linalg.norm(hs, axis=-1) # Apply normalization display = magnitudes.copy() if normalize == "global": vmin, vmax = display.min(), display.max() if vmax > vmin: display = (display - vmin) / (vmax - vmin) elif normalize == "per_layer": for i in range(num_layers_total): row = display[i] rmin, rmax = row.min(), row.max() if rmax > rmin: display[i] = (row - rmin) / (rmax - rmin) elif normalize == "per_token": for j in range(seq_len): col = display[:, j] cmin, cmax = col.min(), col.max() if cmax > cmin: display[:, j] = (col - cmin) / (cmax - cmin) # else: "none" — use raw values # Build axis labels x_labels = [t[:12] for t in tokens] y_labels = ["Embed"] + [f"L{i}" for i in range(result.num_layers)] # Build hover text with raw values hover = np.empty((num_layers_total, seq_len), dtype=object) for i in range(num_layers_total): layer_name = "Embedding" if i == 0 else f"Layer {i - 1}" for j in range(seq_len): hover[i, j] = ( f"Token: {tokens[j]}
" f"{layer_name}
" f"L2 Norm: {magnitudes[i, j]:.2f}
" f"Normalized: {display[i, j]:.3f}" ) fig = go.Figure( data=go.Heatmap( z=display, x=x_labels, y=y_labels, text=hover, hoverinfo="text", colorscale=TOKEN_LAYER_COLORSCALE, colorbar=dict( title=dict( text="Activation" if normalize == "none" else "Norm. Activation", font=dict(color=TEXT_COLOR), ), tickfont=dict(color=TEXT_COLOR), ), ) ) fig.update_layout( title=dict( text=f"Token × Layer Activation Grid (norm: {normalize})", font=dict(color=ACCENT_COLOR, size=14), ), xaxis=dict( title=dict(text="Token", font=dict(color=TEXT_COLOR, size=11)), tickfont=dict(color=TEXT_COLOR, size=9), side="top", tickangle=45, ), yaxis=dict( title=dict(text="Layer", font=dict(color=TEXT_COLOR, size=11)), tickfont=dict(color=TEXT_COLOR, size=8), autorange="reversed", ), paper_bgcolor=PAPER_COLOR, plot_bgcolor=BG_COLOR, margin=dict(l=60, r=30, t=80, b=30), height=520, ) return fig