"""
PTS Visualizer - Interactive visualization for Pivotal Token Search

A Neuronpedia-inspired platform for exploring pivotal tokens, thought anchors,
and reasoning circuits in language models.
"""

import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import networkx as nx
import pandas as pd
import numpy as np
import json
import html as html_lib
from typing import List, Dict, Any, Optional, Tuple
from datasets import load_dataset
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import re
from collections import defaultdict

# ============================================================================
# Data Loading Functions
# ============================================================================

def load_hf_dataset(dataset_id: str, split: str = "train") -> Tuple[pd.DataFrame, str]:
    """Load a dataset split from the HuggingFace Hub.

    Returns:
        A ``(dataframe, status_message)`` pair. On any failure the dataframe is
        empty and the message carries the error text; the exception is not
        re-raised because the message is shown to the user in the UI.
    """
    try:
        dataset = load_dataset(dataset_id, split=split)
        df = pd.DataFrame(dataset)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return pd.DataFrame(), f"Error loading dataset: {e}"
    return df, f"Loaded {len(df)} items from {dataset_id}"


def load_jsonl_file(file_path: str) -> Tuple[pd.DataFrame, str]:
    """Load records from a local JSONL file (one JSON object per line).

    Blank lines are skipped. The file is read as UTF-8 so the result does not
    depend on the platform's default encoding.

    Returns:
        A ``(dataframe, status_message)`` pair. On any failure (missing file,
        malformed JSON, ...) the dataframe is empty and the message carries the
        error text.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = [json.loads(line) for line in f if line.strip()]
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return pd.DataFrame(), f"Error loading file: {e}"
    return pd.DataFrame(data), f"Loaded {len(data)} items from file"


def detect_dataset_type(df: pd.DataFrame) -> str:
    """Classify a PTS dataframe by the columns it contains.

    Checks are ordered; the first match wins:
    ``sentence`` + ``sentence_id`` -> ``'thought_anchors'``,
    ``steering_vector`` -> ``'steering_vectors'``,
    ``chosen`` + ``rejected`` -> ``'dpo_pairs'``,
    ``pivot_token`` -> ``'pivotal_tokens'``, otherwise ``'unknown'``.
    """
    columns = set(df.columns)

    if 'sentence' in columns and 'sentence_id' in columns:
        return 'thought_anchors'
    elif 'steering_vector' in columns:
        return 'steering_vectors'
    elif 'chosen' in columns and 'rejected' in columns:
        return 'dpo_pairs'
    elif 'pivot_token' in columns:
        return 'pivotal_tokens'
    else:
        return 'unknown'


# ============================================================================
# Visualization Components
# ============================================================================

def create_token_highlight_html(context: str, token: str, prob_delta: float) -> str:
    """Build an HTML snippet showing a context followed by its highlighted pivotal token.

    Args:
        context: Text preceding the token. Coerced with ``str()`` so non-string
            cell values (e.g. NaN from a dataframe) do not raise.
        token: The pivotal token; coerced with ``str()``.
        prob_delta: Change in success probability caused by the token. Values
            greater than zero are styled green ("Positive Impact"), everything
            else red ("Negative Impact").

    Returns:
        HTML text in which both context and token are HTML-escaped and the
        token is wrapped in a ``<span>`` carrying the highlight colours.
    """
    context_text = str(context)
    context_escaped = html_lib.escape(context_text)
    token_escaped = html_lib.escape(str(token))

    # Highlight opacity grows with |delta| and saturates at |delta| >= 0.5.
    intensity = min(abs(prob_delta) * 2, 1.0)
    if prob_delta > 0:
        color = f"rgba(34, 197, 94, {intensity})"
        border_color = "#22c55e"
        impact_text = "Positive Impact"
    else:
        color = f"rgba(239, 68, 68, {intensity})"
        border_color = "#ef4444"
        impact_text = "Negative Impact"

    # Apply the computed colours to the token; without this markup the token
    # is emitted as plain text and the "highlight" never shows.
    token_span = (
        f'<span style="background-color: {color}; '
        f'border: 1px solid {border_color};">{token_escaped}</span>'
    )

    sign = '+' if prob_delta > 0 else ''
    return f"""
Context Length: {len(context_text)} characters {impact_text}: {sign}{prob_delta:.3f}
{context_escaped}{token_span}
Token: {token_escaped}
"""
def create_probability_chart(prob_before: float, prob_after: float) -> go.Figure:
    """Create a two-bar chart comparing success probability before and after a token.

    ``None`` inputs are treated as 0.0. The "after" bar is green when the
    probability increased and red otherwise.
    """
    # Plain Python floats keep the figure JSON-serialisable.
    before = float(prob_before) if prob_before is not None else 0.0
    after = float(prob_after) if prob_after is not None else 0.0

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=['Before Token', 'After Token'],
        y=[before, after],
        marker_color=['#6366f1', '#22c55e' if after > before else '#ef4444'],
        text=[f'{before:.3f}', f'{after:.3f}'],
        textposition='outside'
    ))
    fig.update_layout(
        title="Success Probability Change",
        yaxis_title="Probability",
        yaxis_range=[0, 1],
        template="plotly_dark",
        height=300
    )
    return fig


def _add_token_impact_trace(fig: go.Figure, subset: pd.DataFrame, name: str,
                            color: str, delta_prefix: str) -> None:
    """Add one scatter trace of ``prob_delta`` values for ``subset`` to ``fig``."""
    if subset.empty:
        return

    # Plotly renders hover text as HTML, so lines are separated with <br>.
    hover_text = [
        f"Token: {row.get('pivot_token', 'N/A')}<br>"
        f"Δ Prob: {delta_prefix}{row.get('prob_delta', 0):.3f}<br>"
        f"Before: {row.get('prob_before', 0):.3f}<br>"
        f"After: {row.get('prob_after', 0):.3f}<br>"
        f"Query: {str(row.get('query', ''))[:50]}..."
        for _, row in subset.iterrows()
    ]
    y_vals = subset['prob_delta'].tolist()
    fig.add_trace(go.Scatter(
        x=list(range(len(subset))),
        y=y_vals,
        mode='markers',
        name=name,
        marker=dict(
            size=[10 + abs(v) * 30 for v in y_vals],  # bigger marker = bigger impact
            color=color,
            opacity=0.7
        ),
        hovertext=hover_text,
        hoverinfo='text'
    ))


def create_pivotal_token_flow(df: pd.DataFrame, selected_query: Optional[str] = None) -> go.Figure:
    """Create a scatter plot of pivotal tokens by probability delta.

    Args:
        df: Pivotal-token rows; must contain ``prob_delta``.
        selected_query: When a non-blank string and ``df`` has a ``query``
            column, only rows of that query are plotted.

    Returns:
        A figure with one trace for positive and one for negative tokens, or a
        placeholder figure when there is nothing to plot.
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig

    # Filter by query if specified (None and blank strings mean "no filter").
    if (selected_query and isinstance(selected_query, str)
            and selected_query.strip() and 'query' in df.columns):
        df = df[df['query'] == selected_query].copy()

    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data for selected query", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig

    # Prefer the explicit flag; fall back to the sign of the delta.
    if 'is_positive' in df.columns:
        flags = df['is_positive']
        positive_df = df[flags == True]  # noqa: E712 - element-wise comparison
        negative_df = df[flags == False]  # noqa: E712
    else:
        positive_df = df[df['prob_delta'] > 0]
        negative_df = df[df['prob_delta'] <= 0]

    fig = go.Figure()
    _add_token_impact_trace(fig, positive_df, 'Positive Impact', '#22c55e', '+')
    _add_token_impact_trace(fig, negative_df, 'Negative Impact', '#ef4444', '')

    fig.add_hline(y=0, line_dash="dash", line_color="gray")
    fig.update_layout(
        title="Pivotal Token Impact Distribution",
        xaxis_title="Token Index",
        yaxis_title="Probability Delta",
        template="plotly_dark",
        height=500,
        showlegend=True
    )
    return fig
def create_thought_anchor_graph(df: pd.DataFrame, selected_query: Optional[str] = None) -> go.Figure:
    """Create an interactive graph of thought-anchor dependencies.

    Pivotal-token and steering-vector datasets are delegated to
    ``create_pivotal_token_flow``. Otherwise each row becomes a node keyed by
    ``sentence_id``; edges come from ``causal_dependencies`` and fall back to
    a simple chain over the sorted node ids when no dependency edge exists.

    Args:
        df: Dataset rows.
        selected_query: When a non-blank string and ``df`` has a ``query``
            column, only rows of that query are used.
    """
    dataset_type = detect_dataset_type(df)

    # For pivotal tokens and steering vectors, show the token impact plot instead.
    if dataset_type in ('pivotal_tokens', 'steering_vectors'):
        return create_pivotal_token_flow(df, selected_query)

    if df.empty or 'sentence_id' not in df.columns:
        fig = go.Figure()
        fig.add_annotation(
            text="No thought anchor data available.<br>"
                 "Load a thought anchors dataset to see the reasoning graph.",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
            font=dict(size=14, color="#a0a0a0"))
        fig.update_layout(template="plotly_dark", height=400)
        return fig

    # Filter by query if specified; skip silently when there is no query column
    # (same guard as create_pivotal_token_flow) instead of raising KeyError.
    if (selected_query and isinstance(selected_query, str)
            and selected_query.strip() and 'query' in df.columns):
        df = df[df['query'] == selected_query].copy()

    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data for selected query", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig

    G = nx.DiGraph()

    # Nodes: one per sentence.
    for idx, row in df.iterrows():
        sentence_id = row.get('sentence_id', idx)
        importance = row.get('importance_score', abs(row.get('prob_delta', 0)))
        is_positive = row.get('is_positive', row.get('prob_delta', 0) > 0)
        # str() guards against non-string cells (e.g. NaN) before slicing.
        full_sentence = str(row.get('sentence', ''))
        sentence = full_sentence[:50] + '...' if len(full_sentence) > 50 else full_sentence
        G.add_node(sentence_id, importance=importance, is_positive=is_positive,
                   sentence=sentence, category=row.get('sentence_category', 'unknown'))

    # Edges: dependency -> dependent sentence, only between known nodes.
    for idx, row in df.iterrows():
        sentence_id = row.get('sentence_id', idx)
        dependencies = row.get('causal_dependencies', [])
        if isinstance(dependencies, list):
            for dep in dependencies:
                if dep in G.nodes():
                    G.add_edge(dep, sentence_id)

    # No explicit dependencies: connect sentences sequentially.
    if G.number_of_edges() == 0:
        sorted_nodes = sorted(G.nodes())
        for src, dst in zip(sorted_nodes, sorted_nodes[1:]):
            G.add_edge(src, dst)

    pos = nx.spring_layout(G, k=2, iterations=50)

    # Edge trace: one polyline with None separators between segments.
    edge_x = []
    edge_y = []
    for src, dst in G.edges():
        x0, y0 = pos[src]
        x1, y1 = pos[dst]
        edge_x.extend([float(x0), float(x1), None])
        edge_y.extend([float(y0), float(y1), None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=1, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    # Node trace.
    node_x = []
    node_y = []
    node_colors = []
    node_sizes = []
    node_texts = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(float(x))
        node_y.append(float(y))

        node_data = G.nodes[node]
        is_positive = node_data.get('is_positive', True)
        importance = float(node_data.get('importance', 0.3))

        node_colors.append('#22c55e' if is_positive else '#ef4444')
        node_sizes.append(20 + importance * 50)

        # Plotly renders hover text as HTML, so lines are separated with <br>.
        node_texts.append(
            f"Sentence {node}<br>"
            f"Category: {node_data.get('category', 'unknown')}<br>"
            f"Importance: {importance:.3f}<br>"
            f"Text: {node_data.get('sentence', 'N/A')}"
        )

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        text=[str(n) for n in G.nodes()],
        textposition="top center",
        hovertext=node_texts,
        marker=dict(
            color=node_colors,
            size=node_sizes,
            line=dict(width=2, color='white')
        )
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        title="Thought Anchor Reasoning Graph",
        showlegend=False,
        hovermode='closest',
        template="plotly_dark",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        height=500
    )
    return fig
def create_probability_space_visualization(df: pd.DataFrame, color_by: str = 'is_positive') -> go.Figure:
    """Plot pivotal tokens in probability space (``prob_before`` vs ``prob_after``).

    Args:
        df: Rows containing ``prob_before`` and ``prob_after``.
        color_by: Column used for marker colour. ``is_positive`` maps to
            green/red, numeric columns use a Viridis colour scale, any other
            column is treated as categorical. A missing column yields a single
            default colour.

    Returns:
        Scatter figure with a dashed "no change" diagonal; points above it
        gained probability, points below lost it.
    """
    # Palette for categorical values (cycled when there are more categories).
    CATEGORY_COLORS = [
        '#6366f1', '#22c55e', '#ef4444', '#f59e0b', '#8b5cf6',
        '#ec4899', '#14b8a6', '#f97316', '#06b6d4', '#84cc16'
    ]
    default_colors = ['#6366f1'] * len(df)

    use_colorscale = False
    if color_by not in df.columns:
        colors = default_colors
    elif color_by == 'is_positive':
        colors = ['#22c55e' if v else '#ef4444' for v in df[color_by]]
    else:
        values = df[color_by].tolist()
        if not values:
            colors = default_colors
        elif isinstance(values[0], (int, float)) and not isinstance(values[0], bool):
            # Numeric column: let Plotly map values onto a continuous scale.
            colors = values
            use_colorscale = True
        else:
            # Categorical: assign colours in first-seen order so the mapping
            # is stable across runs (set() ordering is not).
            unique_vals = list(dict.fromkeys(values))
            color_map = {val: CATEGORY_COLORS[i % len(CATEGORY_COLORS)]
                         for i, val in enumerate(unique_vals)}
            colors = [color_map[v] for v in values]

    # Plotly renders hover text as HTML, so lines are separated with <br>.
    hover_texts = [
        f"Token: {row.get('pivot_token', 'N/A')}<br>"
        f"Before: {row.get('prob_before', 0):.3f}<br>"
        f"After: {row.get('prob_after', 0):.3f}<br>"
        f"Delta: {row.get('prob_delta', 0):+.3f}<br>"
        f"Query: {str(row.get('query', ''))[:40]}..."
        for _, row in df.iterrows()
    ]

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=df['prob_before'].tolist(),
        y=df['prob_after'].tolist(),
        mode='markers',
        marker=dict(
            size=8,
            color=colors,
            opacity=0.6,
            colorscale='Viridis' if use_colorscale else None,
            showscale=use_colorscale
        ),
        hovertext=hover_texts,
        hoverinfo='text',
        name='Pivotal Tokens'
    ))

    # Diagonal: probability unchanged by the token.
    fig.add_trace(go.Scatter(
        x=[0, 1], y=[0, 1],
        mode='lines',
        line=dict(dash='dash', color='gray', width=1),
        name='No Change Line',
        showlegend=True
    ))

    fig.update_layout(
        title="Probability Space: Before vs After Pivotal Token",
        xaxis_title="Probability Before Token",
        yaxis_title="Probability After Token",
        xaxis=dict(range=[0, 1]),
        yaxis=dict(range=[0, 1]),
        template="plotly_dark",
        height=500
    )

    fig.add_annotation(
        x=0.2, y=0.8, text="Positive Impact ↑", showarrow=False,
        font=dict(color="#22c55e", size=12)
    )
    fig.add_annotation(
        x=0.8, y=0.2, text="Negative Impact ↓", showarrow=False,
        font=dict(color="#ef4444", size=12)
    )
    return fig
def _embedding_scatter(subset: pd.DataFrame, text_field: str, name: str, color: str) -> go.Scatter:
    """Build one marker trace for the projected points in ``subset``."""
    hover_texts = [str(row.get(text_field, ''))[:100] for _, row in subset.iterrows()]
    return go.Scatter(
        x=subset['x'].tolist(),
        y=subset['y'].tolist(),
        mode='markers',
        name=name,
        marker=dict(size=8, color=color, opacity=0.7),
        hovertext=hover_texts,
        hoverinfo='text'
    )


def create_embedding_visualization(df: pd.DataFrame, color_by: str = 'is_positive') -> go.Figure:
    """Create a 2-D t-SNE projection of the dataset's embeddings.

    The first of ``sentence_embedding`` / ``steering_vector`` present in
    ``df`` is used. Pivotal-token datasets without embeddings fall back to
    ``create_probability_space_visualization``. Vectors wider than 50
    dimensions are reduced with PCA before t-SNE.

    Args:
        df: Dataset rows.
        color_by: Column used to split points into coloured traces; falls back
            to ``is_positive`` (or no grouping) when the column is missing.

    Returns:
        The projection figure, or a placeholder figure when there is no data,
        no embedding column, or fewer than 3 usable embeddings.
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available", xref="paper", yref="paper",
                           x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig

    dataset_type = detect_dataset_type(df)

    embedding_col = next(
        (col for col in ('sentence_embedding', 'steering_vector') if col in df.columns),
        None,
    )

    if embedding_col is None:
        # Pivotal tokens carry no embeddings; show probability space instead.
        if (dataset_type == 'pivotal_tokens'
                and 'prob_before' in df.columns and 'prob_after' in df.columns):
            return create_probability_space_visualization(df, color_by)

        fig = go.Figure()
        fig.add_annotation(
            text="No embedding data found.<br>"
                 "Embeddings are available in thought_anchors and steering_vectors datasets.",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
            font=dict(size=12, color="#a0a0a0")
        )
        fig.update_layout(template="plotly_dark", height=400)
        return fig

    # Collect non-empty embeddings together with their ROW POSITIONS. Positions
    # (not index labels) are required because the rows are selected with .iloc
    # below; labels only coincide with positions on a default RangeIndex.
    embeddings = []
    valid_positions = []
    for position, emb in enumerate(df[embedding_col]):
        if isinstance(emb, np.ndarray) and len(emb) > 0:
            embeddings.append(emb.tolist())
            valid_positions.append(position)
        elif isinstance(emb, list) and len(emb) > 0:
            embeddings.append(emb)
            valid_positions.append(position)

    if len(embeddings) < 3:
        fig = go.Figure()
        fig.add_annotation(text="Not enough embeddings for visualization (need at least 3)",
                           xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig

    embeddings = np.array(embeddings)
    n_samples = len(embeddings)

    # t-SNE requires perplexity < n_samples; the extra cap keeps the 3-5
    # sample case (which passes the guard above) from raising.
    perplexity = min(30, max(5, n_samples // 3), n_samples - 1)

    if embeddings.shape[1] > 50:
        # First reduce with PCA to keep t-SNE fast and stable.
        pca = PCA(n_components=min(50, n_samples - 1))
        embeddings = pca.fit_transform(embeddings)

    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    coords = tsne.fit_transform(embeddings)

    plot_df = df.iloc[valid_positions].copy()
    plot_df['x'] = coords[:, 0].tolist()
    plot_df['y'] = coords[:, 1].tolist()

    if color_by not in plot_df.columns:
        color_by = 'is_positive' if 'is_positive' in plot_df.columns else None

    # Field shown on hover.
    text_field = 'sentence' if 'sentence' in plot_df.columns else 'pivot_token'

    fig = go.Figure()
    if color_by and color_by in plot_df.columns:
        if color_by == 'is_positive':
            # Fixed green/red traces for the boolean flag.
            for is_pos in [True, False]:
                subset = plot_df[plot_df[color_by] == is_pos]
                if len(subset) > 0:
                    fig.add_trace(_embedding_scatter(
                        subset, text_field,
                        'Positive' if is_pos else 'Negative',
                        '#22c55e' if is_pos else '#ef4444',
                    ))
        else:
            # One trace per category, cycling through the palette.
            palette = ['#6366f1', '#22c55e', '#ef4444', '#f59e0b', '#8b5cf6',
                       '#ec4899', '#14b8a6', '#f97316', '#06b6d4', '#84cc16']
            for i, val in enumerate(plot_df[color_by].unique()):
                subset = plot_df[plot_df[color_by] == val]
                if len(subset) > 0:
                    fig.add_trace(_embedding_scatter(
                        subset, text_field, str(val), palette[i % len(palette)]))
    else:
        fig.add_trace(_embedding_scatter(plot_df, text_field, 'Embeddings', '#6366f1'))

    fig.update_layout(
        title="Embedding Space Visualization (t-SNE)",
        xaxis_title="t-SNE 1",
        yaxis_title="t-SNE 2",
        template="plotly_dark",
        height=500,
        showlegend=True
    )
    return fig
def create_pivotal_token_trace(df: pd.DataFrame, selected_query: str) -> Tuple[str, go.Figure]:
    """Create a per-token trace (HTML cards plus delta bar chart) for one query.

    Args:
        df: Pivotal-token rows already filtered to a single query.
        selected_query: The query text; shown truncated to 100 characters.

    Returns:
        ``(html, figure)``. All dataset-provided text (query, task type,
        context, token) is HTML-escaped before being embedded.
    """
    if df.empty:
        return "No tokens found for this query", go.Figure()

    # str() guards against non-string query values before slicing/len().
    query_text = str(selected_query)
    query_display = html_lib.escape(query_text[:100])
    query_ellipsis = '...' if len(query_text) > 100 else ''

    html_parts = [f"""

Query: {query_display}{query_ellipsis}

Found {len(df)} pivotal tokens for this query

"""]

    prob_deltas = []
    token_indices = []

    for idx, (_, row) in enumerate(df.iterrows()):
        token = row.get('pivot_token', 'N/A')
        context = str(row.get('pivot_context', ''))
        is_positive = row.get('is_positive', row.get('prob_delta', 0) > 0)
        prob_delta = row.get('prob_delta', 0)
        prob_before = row.get('prob_before', 0)
        prob_after = row.get('prob_after', 0)
        task_type = html_lib.escape(str(row.get('task_type', 'unknown')))

        # Colour based on impact.
        # NOTE(review): these two values are not referenced by the card text
        # below — confirm whether styling markup that used them was lost.
        bg_color = "rgba(34, 197, 94, 0.2)" if is_positive else "rgba(239, 68, 68, 0.2)"
        border_color = "#22c55e" if is_positive else "#ef4444"

        # Full context is shown (no truncation); escape it and the token.
        context_escaped = html_lib.escape(context)
        token_escaped = html_lib.escape(str(token))

        sign = '+' if prob_delta > 0 else ''
        card_html = f"""
Token #{idx + 1} | {task_type} {sign}{prob_delta:.3f}
{context_escaped}{token_escaped}
Before: {prob_before:.3f} After: {prob_after:.3f} Context: {len(context)} chars
"""
        html_parts.append(card_html)
        prob_deltas.append(prob_delta)
        token_indices.append(idx)

    html_parts.append("\n")

    # Probability delta chart; native floats keep the figure serialisable.
    prob_deltas = [float(d) for d in prob_deltas]
    colors = ['#22c55e' if d > 0 else '#ef4444' for d in prob_deltas]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=token_indices,
        y=prob_deltas,
        marker_color=colors,
        name='Probability Delta',
        # Plotly hover templates are HTML; <br> separates the two lines.
        hovertemplate='Token #%{x}<br>Δ Prob: %{y:.3f}'
    ))
    fig.add_hline(y=0, line_dash="dash", line_color="gray")
    fig.update_layout(
        title="Probability Impact per Token",
        xaxis_title="Token Index",
        yaxis_title="Probability Delta",
        template="plotly_dark",
        height=300
    )
    return "\n".join(html_parts), fig
def _has_entries(value: Any) -> bool:
    """Return True when ``value`` is a non-empty container (or otherwise truthy).

    Uses ``len()`` for sized values so numpy arrays do not raise the
    "truth value of an array is ambiguous" error that ``bool(array)`` does.
    """
    if value is None:
        return False
    try:
        return len(value) > 0
    except TypeError:
        return bool(value)


def create_circuit_visualization(df: pd.DataFrame, query_idx: int = 0) -> Tuple[str, go.Figure]:
    """Create a step-by-step view of the reasoning trace for one query.

    Args:
        df: Dataset rows; must contain a ``query`` column to produce output.
        query_idx: Position of the query among ``df['query'].unique()``;
            clamped into the valid range.

    Returns:
        ``(html, figure)``. Pivotal-token and steering-vector datasets are
        delegated to ``create_pivotal_token_trace``; otherwise the HTML lists
        one card per sentence and the figure plots ``prob_with_sentence``
        (default 0.5) against sentence id. Dataset text is HTML-escaped.
    """
    if df.empty:
        return "No data available", go.Figure()

    dataset_type = detect_dataset_type(df)

    queries = df['query'].unique() if 'query' in df.columns else []
    if len(queries) == 0:
        return "No queries found", go.Figure()

    # Clamp on both sides; a negative index would silently wrap around.
    query_idx = max(0, min(query_idx, len(queries) - 1))
    selected_query = queries[query_idx]

    query_df = df[df['query'] == selected_query].copy()

    # Token-level datasets have their own trace view.
    if dataset_type in ('pivotal_tokens', 'steering_vectors'):
        return create_pivotal_token_trace(query_df, selected_query)

    # Sort by sentence_id if available, otherwise keep original order.
    if 'sentence_id' in query_df.columns:
        query_df = query_df.sort_values('sentence_id')
    else:
        query_df = query_df.reset_index(drop=True)

    # str() guards against non-string query values before slicing/len().
    query_text = str(selected_query)
    query_display = html_lib.escape(query_text[:100])
    query_ellipsis = '...' if len(query_text) > 100 else ''

    html_parts = [f"""

Query: {query_display}{query_ellipsis}

"""]

    prob_values = []
    sentence_ids = []

    for idx, row in query_df.iterrows():
        sentence = html_lib.escape(str(row.get('sentence', 'N/A')))
        sentence_id = row.get('sentence_id', idx)
        is_positive = row.get('is_positive', row.get('prob_delta', 0) > 0)
        prob_delta = row.get('prob_delta', 0)
        category = html_lib.escape(str(row.get('sentence_category', 'unknown')))
        importance = row.get('importance_score', abs(prob_delta))

        # Verification info.
        verification_score = row.get('verification_score', None)
        arithmetic_errors = row.get('arithmetic_errors', [])

        # Colour based on impact.
        # NOTE(review): bg_color, border_color and v_color are not referenced
        # by the card text below — confirm whether styling markup was lost.
        bg_color = "rgba(34, 197, 94, 0.2)" if is_positive else "rgba(239, 68, 68, 0.2)"
        border_color = "#22c55e" if is_positive else "#ef4444"

        sign = '+' if prob_delta > 0 else ''
        step_html = f"""
Step {sentence_id} | {category} {sign}{prob_delta:.3f}

{sentence}

Importance: {importance:.3f} """

        if verification_score is not None:
            v_color = "#22c55e" if verification_score > 0.5 else "#ef4444"
            step_html += f""" Verification: {verification_score:.2f} """

        if _has_entries(arithmetic_errors):
            step_html += """ Has Errors """

        step_html += """
"""
        html_parts.append(step_html)
        prob_values.append(row.get('prob_with_sentence', 0.5))
        sentence_ids.append(sentence_id)

    html_parts.append("\n")

    # Probability progression chart.
    fig = go.Figure()
    colors = ['#22c55e' if p > 0.5 else '#ef4444' for p in prob_values]

    fig.add_trace(go.Scatter(
        # Convert numpy integers to native ints for serialisation.
        x=[int(s) if isinstance(s, (int, np.integer)) else s for s in sentence_ids],
        y=[float(p) for p in prob_values],
        mode='lines+markers',
        name='Success Probability',
        line=dict(color='#6366f1', width=2),
        marker=dict(size=10, color=colors)
    ))
    fig.add_hline(y=0.5, line_dash="dash", line_color="gray",
                  annotation_text="50% threshold")
    fig.update_layout(
        title="Probability Progression Through Reasoning",
        xaxis_title="Sentence ID",
        yaxis_title="Success Probability",
        yaxis_range=[0, 1],
        template="plotly_dark",
        height=300
    )
    return "\n".join(html_parts), fig
def create_statistics_dashboard(df: pd.DataFrame) -> Tuple[str, go.Figure]:
    """Summarise the dataset as HTML stat tiles plus a two-panel figure.

    The left panel is a 30-bin histogram of ``prob_delta`` (or ``prob_after``
    when no delta values exist). The right panel counts rows by the first
    available of ``sentence_category``, ``reasoning_pattern``, ``task_type``
    or ``is_positive``.

    Returns:
        ``(html, figure)``; ``("No data available", empty figure)`` for an
        empty dataframe.
    """
    if df.empty:
        return "No data available", go.Figure()

    columns = df.columns
    dataset_type = detect_dataset_type(df)

    # --- headline numbers -------------------------------------------------
    stats = {"Total Items": len(df), "Dataset Type": dataset_type}

    if 'is_positive' in columns:
        n_positive = df['is_positive'].sum()
        stats["Positive Items"] = int(n_positive)
        stats["Negative Items"] = int(len(df) - n_positive)

    if 'prob_delta' in columns:
        deltas = df['prob_delta']
        stats["Avg Prob Delta"] = f"{deltas.mean():.3f}"
        stats["Max Prob Delta"] = f"{deltas.max():.3f}"

    if 'importance_score' in columns:
        stats["Avg Importance"] = f"{df['importance_score'].mean():.3f}"

    if 'sentence_category' in columns:
        stats["Categories"] = len(df['sentence_category'].value_counts())

    if 'model_id' in columns:
        stats["Models"] = df['model_id'].nunique()

    # --- HTML tiles -------------------------------------------------------
    tiles = [f"""
{value}
{key}
""" for key, value in stats.items()]
    html_parts = ['\n', *tiles, '\n']

    # --- pick the grouping column for the second panel --------------------
    # (column, subplot title, trace name) in priority order.
    grouping_specs = (
        ('sentence_category', "Sentence Category", "Categories"),
        ('reasoning_pattern', "Reasoning Pattern", "Patterns"),
        ('task_type', "Task Type", "Task Types"),
    )
    grouping = next((spec for spec in grouping_specs if spec[0] in columns), None)

    if grouping is not None:
        second_chart_title = grouping[1]
    elif 'is_positive' in columns:
        second_chart_title = "Positive vs Negative"
    else:
        second_chart_title = "Category Distribution"

    fig = make_subplots(rows=1, cols=2,
                        subplot_titles=("Probability Delta Distribution", second_chart_title))

    # --- panel 1: histogram, binned with numpy ----------------------------
    for column, trace_name in (('prob_delta', "Prob Delta"), ('prob_after', "Prob After")):
        if column not in columns:
            continue
        samples = df[column].dropna()
        if len(samples) == 0:
            continue
        counts, bin_edges = np.histogram(samples.values, bins=30)
        bin_centers = [(lo + hi) / 2 for lo, hi in zip(bin_edges[:-1], bin_edges[1:])]
        fig.add_trace(
            go.Bar(x=bin_centers, y=counts.tolist(), name=trace_name,
                   marker_color='#6366f1', width=(bin_edges[1] - bin_edges[0]) * 0.9),
            row=1, col=1
        )
        break  # prob_after is only a fallback

    # --- panel 2: counts per group -----------------------------------------
    if grouping is not None:
        column, _, trace_name = grouping
        group_counts = df[column].value_counts()
        fig.add_trace(
            go.Bar(x=group_counts.index.tolist(), y=group_counts.values.tolist(),
                   name=trace_name, marker_color='#22c55e'),
            row=1, col=2
        )
    elif 'is_positive' in columns:
        pos_neg_counts = df['is_positive'].value_counts()
        labels = ['Positive' if flag else 'Negative' for flag in pos_neg_counts.index.tolist()]
        bar_colors = ['#22c55e' if label == 'Positive' else '#ef4444' for label in labels]
        fig.add_trace(
            go.Bar(x=labels, y=pos_neg_counts.values.tolist(), name="Impact",
                   marker_color=bar_colors),
            row=1, col=2
        )

    fig.update_layout(
        template="plotly_dark",
        height=350,
        showlegend=False
    )
    return "\n".join(html_parts), fig
# ============================================================================
# Gradio Interface
# ============================================================================

# Global state for loaded data
current_data = {"df": pd.DataFrame(), "type": "unknown"}


def _empty_load_result(message: str) -> tuple:
    """Return the 10-item output tuple used when no dataset could be loaded."""
    empty_fig = go.Figure()
    empty_fig.update_layout(template="plotly_dark")
    return (message, "", "No data", empty_fig, empty_fig, empty_fig,
            "No data", empty_fig,
            gr.update(maximum=0), gr.update(choices=[], value=None))


def load_dataset_action(source_type: str, dataset_id: str, file_upload):
    """Load a dataset and return updates for every visualization component.

    Args:
        source_type: ``"HuggingFace Hub"`` to load ``dataset_id`` from the Hub;
            any other value loads the uploaded JSONL file.
        dataset_id: HuggingFace dataset identifier.
        file_upload: Uploaded file. Either an object exposing the path as
            ``.name`` or a plain path string is accepted.

    Returns:
        A 10-tuple: status message, dataset info text, statistics HTML,
        statistics figure, graph figure, embedding figure, circuit HTML,
        circuit figure, token-slider update and query-dropdown update.
        On success the loaded dataframe is stored in ``current_data``.
    """
    if source_type == "HuggingFace Hub":
        if not dataset_id:
            return _empty_load_result("Please enter a dataset ID")
        df, msg = load_hf_dataset(dataset_id)
    else:  # Local File
        if file_upload is None:
            return _empty_load_result("Please upload a file")
        # Depending on the Gradio version/configuration the upload arrives as
        # a file-like object with .name or directly as a path string.
        file_path = getattr(file_upload, "name", file_upload)
        df, msg = load_jsonl_file(file_path)

    if df.empty:
        return _empty_load_result(msg)

    current_data["df"] = df
    current_data["type"] = detect_dataset_type(df)

    # str() keeps the join working for non-string column labels.
    columns_info = f"Columns: {', '.join(str(c) for c in df.columns[:10])}"
    if len(df.columns) > 10:
        columns_info += f" ... and {len(df.columns) - 10} more"

    # Generate all visualizations
    stats_html, stats_fig = create_statistics_dashboard(df)
    graph_fig = create_thought_anchor_graph(df)
    embed_fig = create_embedding_visualization(df)
    circuit_html, circuit_fig = create_circuit_visualization(df)

    # Dropdown labels: "[N] query", truncated to keep the dropdown readable.
    query_choices = []
    if 'query' in df.columns:
        for i, q in enumerate(df['query'].unique().tolist()):
            q_str = str(q) if q is not None else ""
            if len(q_str) > 80:
                query_choices.append(f"[{i+1}] {q_str[:77]}...")
            else:
                query_choices.append(f"[{i+1}] {q_str}")

    return (msg, f"Dataset type: {current_data['type']}\n{columns_info}",
            stats_html, stats_fig, graph_fig, embed_fig, circuit_html, circuit_fig,
            gr.update(maximum=max(0, len(df) - 1)),
            gr.update(choices=query_choices, value=None))
def get_token_details(idx: int) -> Tuple[str, go.Figure]:
    """Return the highlight HTML and probability chart for one dataset row.

    Args:
        idx: Zero-based row position in the loaded dataset. Coerced with
            ``int()`` because slider components may deliver floats.

    Returns:
        ``(html, figure)``. An explanatory message and an empty figure are
        returned when no data is loaded, the dataset is a DPO-pairs dataset,
        the index is out of range, or the row has neither context nor token.
    """
    df = current_data["df"]
    dataset_type = current_data.get("type", "unknown")

    if df.empty:
        return "No data available. Please load a dataset first.", go.Figure()

    # Handle unsupported dataset types
    if dataset_type == 'dpo_pairs':
        message_html = """

DPO Pairs Dataset

This visualization is not available for DPO pairs datasets.

DPO pairs contain prompt/chosen/rejected structure without token-level context.

Try loading a pivotal_tokens or thought_anchors dataset instead.

"""
        return message_html, go.Figure()

    try:
        idx = int(idx)
    except (TypeError, ValueError):
        return "Index out of range", go.Figure()
    # Reject negatives too; .iloc would otherwise count from the end.
    if idx < 0 or idx >= len(df):
        return "Index out of range", go.Figure()

    row = df.iloc[idx]

    # Pivotal-token fields first, thought-anchor fields as fallback.
    context = row.get('pivot_context', row.get('prefix_context', ''))
    token = row.get('pivot_token', row.get('sentence', ''))
    prob_delta = row.get('prob_delta', 0)
    prob_before = row.get('prob_before', row.get('prob_with_sentence', 0.5))
    prob_after = row.get('prob_after', row.get('prob_without_sentence', 0.5))

    # Handle missing data
    if not context and not token:
        message_html = """

Missing Data

This dataset doesn't have the expected fields for token visualization.

"""
        return message_html, go.Figure()

    details_html = create_token_highlight_html(context, token, prob_delta)
    chart = create_probability_chart(prob_before, prob_after)
    return details_html, chart
def get_original_query_from_label(label: str) -> Optional[str]:
    """Map a dropdown label such as ``'[1] query...'`` back to the full query.

    The 1-based number in brackets is used as a position into the unique
    queries of the loaded dataset.

    Returns:
        The original query, or ``None`` when the label is empty or not a
        string, no dataset with a ``query`` column is loaded, the label has no
        ``[N]`` prefix, or ``N`` is out of range.
    """
    if not label or not isinstance(label, str):
        return None

    df = current_data["df"]
    if df.empty or 'query' not in df.columns:
        return None

    # Extract index from "[N] query..." format
    match = re.match(r'\[(\d+)\]', label)
    if match is None:
        return None

    idx = int(match.group(1)) - 1  # Convert to 0-based index
    queries = df['query'].unique().tolist()
    if 0 <= idx < len(queries):
        return queries[idx]
    return None


def update_graph_visualization(query_dropdown: str = None):
    """Rebuild the thought anchor graph for the selected dropdown label.

    DPO-pairs datasets get a placeholder figure; otherwise the label is
    resolved to its original query and passed to
    ``create_thought_anchor_graph``.
    """
    dataset_type = current_data.get("type", "unknown")

    if dataset_type == 'dpo_pairs':
        fig = go.Figure()
        fig.add_annotation(
            # Plotly annotation text is HTML; <br> separates the two lines.
            text="Reasoning Graph is not available for DPO pairs datasets.<br>"
                 "Load a pivotal_tokens or thought_anchors dataset.",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
            font=dict(size=14, color="#a0a0a0")
        )
        fig.update_layout(template="plotly_dark", height=400)
        return fig

    # Convert truncated label back to original query
    original_query = get_original_query_from_label(query_dropdown)
    return create_thought_anchor_graph(current_data["df"], original_query)
def update_embedding_visualization(color_by: str):
    """Rebuild the embedding plot for the loaded dataset, coloured by ``color_by``.

    DPO-pairs datasets get a placeholder figure instead.
    """
    dataset_type = current_data.get("type", "unknown")

    if dataset_type == 'dpo_pairs':
        fig = go.Figure()
        fig.add_annotation(
            # Plotly annotation text is HTML; <br> separates the two lines.
            text="Embedding Space is not available for DPO pairs datasets.<br>"
                 "Load a pivotal_tokens, thought_anchors, or steering_vectors dataset.",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
            font=dict(size=14, color="#a0a0a0")
        )
        fig.update_layout(template="plotly_dark", height=400)
        return fig

    return create_embedding_visualization(current_data["df"], color_by)
def update_circuit_view(query_idx: int):
    """Render the circuit tracer for the query at position ``query_idx``.

    Returns ``(html, figure)``; DPO-pairs datasets yield an explanatory
    message and an empty figure.
    """
    if current_data.get("type", "unknown") != 'dpo_pairs':
        # Supported dataset: delegate, coercing the slider value to int.
        return create_circuit_visualization(current_data["df"], int(query_idx))

    unavailable_html = """

DPO Pairs Dataset

Circuit Tracer is not available for DPO pairs datasets.

Load a pivotal_tokens or thought_anchors dataset to explore reasoning circuits.

"""
    return unavailable_html, go.Figure()
def update_statistics():
    """Rebuild the statistics dashboard from the currently loaded dataset."""
    loaded_df = current_data["df"]
    return create_statistics_dashboard(loaded_df)


def get_query_list():
    """Build the query dropdown update with numbered, truncated labels.

    Labels have the form ``"[N] query"``; queries longer than 80 characters
    are cut to 77 characters followed by ``...``.
    """
    df = current_data["df"]
    has_queries = not df.empty and 'query' in df.columns
    if not has_queries:
        return gr.update(choices=[], value=None)

    labels = []
    for number, query in enumerate(df['query'].unique().tolist(), start=1):
        text = "" if query is None else str(query)
        shown = f"{text[:77]}..." if len(text) > 80 else text
        labels.append(f"[{number}] {shown}")

    return gr.update(choices=labels, value=None)


def refresh_all():
    """Regenerate every visualization from the currently loaded dataset.

    Returns a 6-tuple: statistics HTML, statistics figure, graph figure,
    embedding figure, circuit HTML, circuit figure. Placeholders are returned
    when nothing is loaded.
    """
    df = current_data["df"]

    if not df.empty:
        stats_html, stats_fig = create_statistics_dashboard(df)
        graph_fig = create_thought_anchor_graph(df)
        embed_fig = create_embedding_visualization(df)
        circuit_html, circuit_fig = create_circuit_visualization(df)
        return stats_html, stats_fig, graph_fig, embed_fig, circuit_html, circuit_fig

    placeholder = go.Figure()
    placeholder.update_layout(template="plotly_dark")
    return (
        "No data loaded",
        placeholder,
        placeholder,
        placeholder,
        "No data loaded",
        placeholder
    )


# ============================================================================
# Build Gradio App
# ============================================================================

# Pre-defined HuggingFace datasets
HF_DATASETS = [
    "codelion/Qwen3-0.6B-pts",
    "codelion/Qwen3-0.6B-pts-thought-anchors",
    "codelion/Qwen3-0.6B-pts-steering-vectors",
    "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts",
    "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts-thought-anchors",
    "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts-steering-vectors",
]

# CSS configuration
CSS = """
.gradio-container {
    max-width: 1400px !important;
}
.main-header {
    text-align: center;
    margin-bottom: 20px;
}
"""
gr.Blocks(title="PTS Visualizer", css=CSS) as demo: # Header gr.Markdown(""" # PTS Visualizer ### Interactive Exploration of Pivotal Tokens, Thought Anchors & Reasoning Circuits A [Neuronpedia](https://neuronpedia.org/)-inspired platform for understanding how language models reason. Load datasets from HuggingFace Hub or upload your own JSONL files. 🔗 [Browse more PTS datasets on HuggingFace](https://huggingface.co/datasets?other=pts) """) # Data Loading Section with gr.Accordion("Load Dataset", open=True): with gr.Row(): source_type = gr.Radio( choices=["HuggingFace Hub", "Local File"], value="HuggingFace Hub", label="Data Source" ) with gr.Row(): with gr.Column(scale=3): dataset_dropdown = gr.Dropdown( choices=HF_DATASETS, value=HF_DATASETS[0], label="Select Dataset", info="Choose a pre-defined dataset or enter your own HuggingFace dataset ID" ) with gr.Column(scale=1): file_upload = gr.File( label="Or Upload JSONL", file_types=[".jsonl", ".json"] ) with gr.Row(): load_btn = gr.Button("Load Dataset", variant="primary") refresh_btn = gr.Button("Refresh Visualizations", variant="secondary") with gr.Row(): load_status = gr.Textbox(label="Status", interactive=False) dataset_info = gr.Textbox(label="Dataset Info", interactive=False) # Main Visualization Tabs with gr.Tabs(): # Overview Tab with gr.TabItem("Overview"): gr.Markdown("### Dataset Statistics") stats_html = gr.HTML() stats_chart = gr.Plot() # Token Explorer Tab with gr.TabItem("Token Explorer"): gr.Markdown("### Explore Pivotal Tokens") with gr.Row(): with gr.Column(scale=1): token_slider = gr.Slider( minimum=0, maximum=100, step=1, value=0, label="Token Index" ) with gr.Column(scale=3): token_html = gr.HTML(label="Token in Context") prob_chart = gr.Plot(label="Probability Change") # Thought Anchor Graph Tab with gr.TabItem("Reasoning Graph"): gr.Markdown("### Thought Anchor Dependency Graph") gr.Markdown(""" *Visualizes causal dependencies between reasoning steps. 
Green nodes indicate positive impact, red nodes indicate negative impact. Node size reflects importance score.* """) with gr.Row(): query_filter = gr.Dropdown( choices=[], value=None, label="Filter by Query" ) graph_plot = gr.Plot() # Embedding Visualization Tab with gr.TabItem("Embedding Space"): gr.Markdown("### Embedding Space Visualization") gr.Markdown("*t-SNE projection of sentence/token embeddings. Explore clusters and patterns.*") with gr.Row(): color_dropdown = gr.Dropdown( choices=["is_positive", "sentence_category", "reasoning_pattern", "task_type"], value="is_positive", label="Color By" ) embed_plot = gr.Plot() # Circuit Tracer Tab with gr.TabItem("Circuit Tracer"): gr.Markdown("### Step-by-Step Reasoning Circuit") gr.Markdown("*Walk through the reasoning process step by step. See how each step affects the probability of success.*") with gr.Row(): circuit_query_idx = gr.Slider( minimum=0, maximum=100, step=1, value=0, label="Query Index" ) circuit_html = gr.HTML() circuit_chart = gr.Plot() # Event handlers - using api_name=False to prevent schema generation issues load_btn.click( fn=load_dataset_action, inputs=[source_type, dataset_dropdown, file_upload], outputs=[load_status, dataset_info, stats_html, stats_chart, graph_plot, embed_plot, circuit_html, circuit_chart, token_slider, query_filter], api_name=False ) refresh_btn.click( fn=refresh_all, outputs=[stats_html, stats_chart, graph_plot, embed_plot, circuit_html, circuit_chart], api_name=False ) token_slider.change( fn=get_token_details, inputs=[token_slider], outputs=[token_html, prob_chart], api_name=False ) query_filter.change( fn=update_graph_visualization, inputs=[query_filter], outputs=[graph_plot], api_name=False ) color_dropdown.change( fn=update_embedding_visualization, inputs=[color_dropdown], outputs=[embed_plot], api_name=False ) circuit_query_idx.change( fn=update_circuit_view, inputs=[circuit_query_idx], outputs=[circuit_html, circuit_chart], api_name=False ) # 
============================================================================ # Main Entry Point # ============================================================================ if __name__ == "__main__": demo.launch()