Spaces:
Running
Running
| """ | |
| PTS Visualizer - Interactive visualization for Pivotal Token Search | |
| A Neuronpedia-inspired platform for exploring pivotal tokens, thought anchors, | |
| and reasoning circuits in language models. | |
| """ | |
| import gradio as gr | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| import networkx as nx | |
| import pandas as pd | |
| import numpy as np | |
| import json | |
| import html as html_lib | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from datasets import load_dataset | |
| from sklearn.manifold import TSNE | |
| from sklearn.decomposition import PCA | |
| import re | |
| from collections import defaultdict | |
| # ============================================================================ | |
| # Data Loading Functions | |
| # ============================================================================ | |
def load_hf_dataset(dataset_id: str, split: str = "train") -> Tuple[pd.DataFrame, str]:
    """Load a dataset from the HuggingFace Hub.

    Args:
        dataset_id: Hub identifier, e.g. "org/name".
        split: Dataset split to load (defaults to "train").

    Returns:
        Tuple of (dataframe, status message). On failure the dataframe is
        empty and the message describes the error.
    """
    try:
        dataset = load_dataset(dataset_id, split=split)
        df = pd.DataFrame(dataset)
        return df, f"Loaded {len(df)} items from {dataset_id}"
    except Exception as e:
        # Broad catch: Hub loading can fail for network, auth, or schema
        # reasons; the UI surfaces the message instead of crashing.
        return pd.DataFrame(), f"Error loading dataset: {str(e)}"
def load_jsonl_file(file_path: str) -> Tuple[pd.DataFrame, str]:
    """Load data from a local JSONL (one JSON object per line) file.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        Tuple of (dataframe, status message). On failure the dataframe is
        empty and the message describes the error.
    """
    try:
        data = []
        # Explicit UTF-8: JSONL dumps are UTF-8 by convention and the
        # platform default encoding may differ (e.g. cp1252 on Windows).
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():  # skip blank lines
                    data.append(json.loads(line))
        return pd.DataFrame(data), f"Loaded {len(data)} items from file"
    except Exception as e:
        # Covers missing files, permission errors, and malformed JSON lines.
        return pd.DataFrame(), f"Error loading file: {str(e)}"
def detect_dataset_type(df: pd.DataFrame) -> str:
    """Classify a PTS dataframe by the marker columns it contains."""
    present = set(df.columns)
    # Ordered signatures: the first whose required columns are all
    # present wins, mirroring the original if/elif precedence.
    signatures = [
        ('thought_anchors', {'sentence', 'sentence_id'}),
        ('steering_vectors', {'steering_vector'}),
        ('dpo_pairs', {'chosen', 'rejected'}),
        ('pivotal_tokens', {'pivot_token'}),
    ]
    for kind, required in signatures:
        if required <= present:
            return kind
    return 'unknown'
| # ============================================================================ | |
| # Visualization Components | |
| # ============================================================================ | |
def create_token_highlight_html(context: str, token: str, prob_delta: float) -> str:
    """Render an HTML card showing the pivotal token highlighted after its context.

    Green styling marks a positive probability delta, red marks zero or
    negative; highlight opacity scales with |prob_delta|, capped at 1.0.
    """
    # Escape user-provided text so it renders literally inside the HTML.
    context_escaped = html_lib.escape(str(context))
    token_escaped = html_lib.escape(str(token))

    # Opacity grows with the magnitude of the delta (same for both signs).
    intensity = min(abs(prob_delta) * 2, 1.0)
    if prob_delta > 0:
        color = f"rgba(34, 197, 94, {intensity})"
        border_color = "#22c55e"
        impact_text = "Positive Impact"
    else:
        # Zero delta is rendered with the negative styling.
        color = f"rgba(239, 68, 68, {intensity})"
        border_color = "#ef4444"
        impact_text = "Negative Impact"

    sign = '+' if prob_delta > 0 else ''
    token_span = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; border: 2px solid {border_color}; font-weight: bold; font-size: 1.1em;">{token_escaped}</span>'
    return f"""
<div style="background-color: #1a1a2e; border-radius: 10px; padding: 20px;">
<div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 15px;">
<span style="color: #a0a0a0; font-size: 0.9em;">Context Length: {len(context)} characters</span>
<span style="background-color: {border_color}; color: white; padding: 4px 12px; border-radius: 5px; font-weight: bold;">
{impact_text}: {sign}{prob_delta:.3f}
</span>
</div>
<div style="font-family: monospace; padding: 15px; background-color: #0d1117; border-radius: 8px; color: #e0e0e0; line-height: 1.8; max-height: 500px; overflow-y: auto; white-space: pre-wrap; word-break: break-word; border: 1px solid #30363d;">
<span style="color: #8b949e;">{context_escaped}</span>{token_span}
</div>
<div style="margin-top: 15px; display: flex; gap: 10px; flex-wrap: wrap;">
<span style="background-color: #238636; color: white; padding: 5px 10px; border-radius: 5px; font-size: 0.9em;">
Token: <code style="background-color: rgba(0,0,0,0.3); padding: 2px 5px; border-radius: 3px;">{token_escaped}</code>
</span>
</div>
</div>
"""
def create_probability_chart(prob_before: float, prob_after: float) -> go.Figure:
    """Bar chart comparing success probability before/after the pivotal token."""
    # Coerce to plain Python floats; None becomes 0.0 so the chart renders.
    before = 0.0 if prob_before is None else float(prob_before)
    after = 0.0 if prob_after is None else float(prob_after)

    # Second bar is green when the token helped, red otherwise.
    after_color = '#22c55e' if after > before else '#ef4444'

    fig = go.Figure(
        go.Bar(
            x=['Before Token', 'After Token'],
            y=[before, after],
            marker_color=['#6366f1', after_color],
            text=[f'{before:.3f}', f'{after:.3f}'],
            textposition='outside',
        )
    )
    fig.update_layout(
        title="Success Probability Change",
        yaxis_title="Probability",
        yaxis_range=[0, 1],
        template="plotly_dark",
        height=300,
    )
    return fig
def create_pivotal_token_flow(df: pd.DataFrame, selected_query: str = None) -> go.Figure:
    """Scatter plot of pivotal tokens by probability delta.

    Tokens are split into positive- and negative-impact traces (green/red),
    with marker size proportional to |delta|. Optionally narrows the data
    to a single query first.
    """

    def _placeholder(message):
        # Empty dark-themed figure with a centred message.
        fig = go.Figure()
        fig.add_annotation(text=message,
                           xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig

    if df.empty:
        return _placeholder("No data available")

    # Narrow to one query when a non-blank string is supplied.
    if selected_query and isinstance(selected_query, str) and selected_query.strip() and 'query' in df.columns:
        df = df[df['query'] == selected_query].copy()
        if df.empty:
            return _placeholder("No data for selected query")

    # Partition rows by impact, preferring the explicit flag when present.
    # Explicit ==True / ==False comparisons (rather than a negated mask)
    # preserve the original behaviour of excluding NaN flags from both sides.
    if 'is_positive' in df.columns:
        positive_df = df[df['is_positive'] == True]
        negative_df = df[df['is_positive'] == False]
    else:
        positive_df = df[df['prob_delta'] > 0]
        negative_df = df[df['prob_delta'] <= 0]

    fig = go.Figure()
    # (subset, legend label, colour, sign prefix for the hover delta)
    for subset, label, color, sign in (
        (positive_df, 'Positive Impact', '#22c55e', '+'),
        (negative_df, 'Negative Impact', '#ef4444', ''),
    ):
        if subset.empty:
            continue
        hover_text = [
            f"Token: {row.get('pivot_token', 'N/A')}<br>"
            f"Δ Prob: {sign}{row.get('prob_delta', 0):.3f}<br>"
            f"Before: {row.get('prob_before', 0):.3f}<br>"
            f"After: {row.get('prob_after', 0):.3f}<br>"
            f"Query: {str(row.get('query', ''))[:50]}..."
            for _, row in subset.iterrows()
        ]
        deltas = subset['prob_delta'].tolist()
        fig.add_trace(go.Scatter(
            x=list(range(len(subset))),
            y=deltas,
            mode='markers',
            name=label,
            marker=dict(
                # Marker size grows with the magnitude of the delta.
                size=[10 + abs(v) * 30 for v in deltas],
                color=color,
                opacity=0.7
            ),
            hovertext=hover_text,
            hoverinfo='text'
        ))

    fig.add_hline(y=0, line_dash="dash", line_color="gray")
    fig.update_layout(
        title="Pivotal Token Impact Distribution",
        xaxis_title="Token Index",
        yaxis_title="Probability Delta",
        template="plotly_dark",
        height=500,
        showlegend=True
    )
    return fig
def create_thought_anchor_graph(df: pd.DataFrame, selected_query: str = None) -> go.Figure:
    """Create an interactive graph visualization of thought anchor dependencies.

    Nodes are sentences (green/red by impact sign, sized by importance);
    edges come from each row's 'causal_dependencies' list, falling back to
    a simple sequential chain when no dependencies exist. Pivotal-token and
    steering-vector datasets are delegated to create_pivotal_token_flow.
    """
    dataset_type = detect_dataset_type(df)
    # For pivotal tokens and steering vectors, create a token impact visualization
    if dataset_type in ('pivotal_tokens', 'steering_vectors'):
        return create_pivotal_token_flow(df, selected_query)
    if df.empty or 'sentence_id' not in df.columns:
        fig = go.Figure()
        fig.add_annotation(text="No thought anchor data available. Load a thought anchors dataset to see the reasoning graph.",
                           xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
                           font=dict(size=14, color="#a0a0a0"))
        fig.update_layout(template="plotly_dark", height=400)
        return fig
    # Filter by query if specified (handle None, empty string, or actual query)
    if selected_query and isinstance(selected_query, str) and selected_query.strip():
        df = df[df['query'] == selected_query].copy()
        if df.empty:
            fig = go.Figure()
            fig.add_annotation(text="No data for selected query",
                               xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
            fig.update_layout(template="plotly_dark")
            return fig
    # Create networkx graph
    G = nx.DiGraph()
    # Add nodes (sentences); importance falls back to |prob_delta| and
    # impact sign falls back to prob_delta > 0 when flags are absent.
    for idx, row in df.iterrows():
        sentence_id = row.get('sentence_id', idx)
        importance = row.get('importance_score', abs(row.get('prob_delta', 0)))
        is_positive = row.get('is_positive', row.get('prob_delta', 0) > 0)
        # Truncate long sentence text for the hover label.
        sentence = row.get('sentence', '')[:50] + '...' if len(row.get('sentence', '')) > 50 else row.get('sentence', '')
        G.add_node(sentence_id,
                   importance=importance,
                   is_positive=is_positive,
                   sentence=sentence,
                   category=row.get('sentence_category', 'unknown'))
    # Add edges from causal dependencies (only to nodes that already exist).
    for idx, row in df.iterrows():
        sentence_id = row.get('sentence_id', idx)
        dependencies = row.get('causal_dependencies', [])
        if isinstance(dependencies, list):
            for dep in dependencies:
                if dep in G.nodes():
                    G.add_edge(dep, sentence_id)
    # If no explicit dependencies, create sequential edges
    if G.number_of_edges() == 0:
        sorted_nodes = sorted(G.nodes())
        for i in range(len(sorted_nodes) - 1):
            G.add_edge(sorted_nodes[i], sorted_nodes[i+1])
    # Layout: force-directed spring placement.
    pos = nx.spring_layout(G, k=2, iterations=50)
    # Create edge traces. A single scatter trace draws all edges;
    # None entries break the line between segments.
    edge_x = []
    edge_y = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([float(x0), float(x1), None])
        edge_y.extend([float(y0), float(y1), None])
    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=1, color='#888'),
        hoverinfo='none',
        mode='lines'
    )
    # Create node traces
    node_x = []
    node_y = []
    node_colors = []
    node_sizes = []
    node_texts = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(float(x))
        node_y.append(float(y))
        node_data = G.nodes[node]
        is_positive = node_data.get('is_positive', True)
        importance = float(node_data.get('importance', 0.3))
        node_colors.append('#22c55e' if is_positive else '#ef4444')
        # Node size grows linearly with importance.
        node_sizes.append(20 + importance * 50)
        hover_text = f"Sentence {node}<br>"
        hover_text += f"Category: {node_data.get('category', 'unknown')}<br>"
        hover_text += f"Importance: {importance:.3f}<br>"
        hover_text += f"Text: {node_data.get('sentence', 'N/A')}"
        node_texts.append(hover_text)
    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        text=[str(n) for n in G.nodes()],
        textposition="top center",
        hovertext=node_texts,
        marker=dict(
            color=node_colors,
            size=node_sizes,
            line=dict(width=2, color='white')
        )
    )
    # Create figure
    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        title="Thought Anchor Reasoning Graph",
        showlegend=False,
        hovermode='closest',
        template="plotly_dark",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        height=500
    )
    return fig
def create_probability_space_visualization(df: pd.DataFrame, color_by: str = 'is_positive') -> go.Figure:
    """Scatter plot of prob_before vs prob_after for pivotal tokens.

    Args:
        df: Pivotal-token dataframe with 'prob_before' and 'prob_after'
            columns.
        color_by: Column used to colour points. 'is_positive' gets fixed
            green/red; numeric columns get a continuous Viridis scale;
            other categorical columns cycle through a fixed palette.

    Returns:
        Plotly figure with a dashed y=x "no change" reference line and
        quadrant annotations.
    """
    fig = go.Figure()
    # Palette cycled through for categorical colour columns.
    CATEGORY_COLORS = [
        '#6366f1', '#22c55e', '#ef4444', '#f59e0b', '#8b5cf6',
        '#ec4899', '#14b8a6', '#f97316', '#06b6d4', '#84cc16'
    ]
    # Determine per-point colours from the requested column.
    use_colorscale = False
    if color_by in df.columns:
        color_col = df[color_by]
        if color_by == 'is_positive':
            colors = ['#22c55e' if v else '#ef4444' for v in color_col]
        else:
            values = color_col.tolist() if hasattr(color_col, 'tolist') else list(color_col)
            if len(values) > 0:
                # bool is excluded: it is an int subclass but semantically categorical.
                if isinstance(values[0], (int, float)) and not isinstance(values[0], bool):
                    # Numeric column: let plotly map raw values on a colorscale.
                    colors = values
                    use_colorscale = True
                else:
                    # Categorical: assign palette colours in first-appearance
                    # order (dict.fromkeys) so the value->colour mapping is
                    # deterministic across runs, unlike list(set(...)).
                    unique_vals = list(dict.fromkeys(values))
                    color_map = {val: CATEGORY_COLORS[i % len(CATEGORY_COLORS)] for i, val in enumerate(unique_vals)}
                    colors = [color_map[v] for v in values]
            else:
                colors = ['#6366f1'] * len(df)
    else:
        colors = ['#6366f1'] * len(df)
    # Create hover text
    hover_texts = []
    for _, row in df.iterrows():
        text = f"Token: {row.get('pivot_token', 'N/A')}<br>"
        text += f"Before: {row.get('prob_before', 0):.3f}<br>"
        text += f"After: {row.get('prob_after', 0):.3f}<br>"
        text += f"Delta: {row.get('prob_delta', 0):+.3f}<br>"
        text += f"Query: {str(row.get('query', ''))[:40]}..."
        hover_texts.append(text)
    fig.add_trace(go.Scatter(
        x=df['prob_before'].tolist(),
        y=df['prob_after'].tolist(),
        mode='markers',
        marker=dict(
            size=8,
            color=colors,
            opacity=0.6,
            colorscale='Viridis' if use_colorscale else None,
            showscale=use_colorscale
        ),
        hovertext=hover_texts,
        hoverinfo='text',
        name='Pivotal Tokens'
    ))
    # Add diagonal line (no change): points above it improved, below it worsened.
    fig.add_trace(go.Scatter(
        x=[0, 1],
        y=[0, 1],
        mode='lines',
        line=dict(dash='dash', color='gray', width=1),
        name='No Change Line',
        showlegend=True
    ))
    fig.update_layout(
        title="Probability Space: Before vs After Pivotal Token",
        xaxis_title="Probability Before Token",
        yaxis_title="Probability After Token",
        xaxis=dict(range=[0, 1]),
        yaxis=dict(range=[0, 1]),
        template="plotly_dark",
        height=500
    )
    # Quadrant labels for quick reading of the plot.
    fig.add_annotation(
        x=0.2, y=0.8,
        text="Positive Impact ↑",
        showarrow=False,
        font=dict(color="#22c55e", size=12)
    )
    fig.add_annotation(
        x=0.8, y=0.2,
        text="Negative Impact ↓",
        showarrow=False,
        font=dict(color="#ef4444", size=12)
    )
    return fig
def create_embedding_visualization(df: pd.DataFrame, color_by: str = 'is_positive') -> go.Figure:
    """Create a 2-D t-SNE visualization of embeddings, with fallbacks.

    Looks for a 'sentence_embedding' or 'steering_vector' column. Pivotal
    token datasets without embeddings fall back to the probability-space
    scatter; otherwise an informational placeholder figure is returned.

    Args:
        df: Loaded PTS dataframe.
        color_by: Column to colour points by ('is_positive' gets fixed
            green/red; other columns are treated as categories).

    Returns:
        Plotly figure (t-SNE scatter, probability-space fallback, or a
        placeholder message figure).
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available",
                           xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig
    dataset_type = detect_dataset_type(df)
    # Check for embeddings
    embedding_col = None
    for col in ['sentence_embedding', 'steering_vector']:
        if col in df.columns:
            embedding_col = col
            break
    # For pivotal tokens without embeddings, create a probability space visualization
    if embedding_col is None:
        if dataset_type == 'pivotal_tokens' and 'prob_before' in df.columns and 'prob_after' in df.columns:
            return create_probability_space_visualization(df, color_by)
        fig = go.Figure()
        fig.add_annotation(
            text="No embedding data found. Embeddings are available in thought_anchors and steering_vectors datasets.",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
            font=dict(size=12, color="#a0a0a0")
        )
        fig.update_layout(template="plotly_dark", height=400)
        return fig
    # Extract embeddings, skipping rows with missing/empty vectors.
    # Collect POSITIONAL indices (not iterrows labels) so the .iloc lookup
    # below is correct even when df has a non-default index.
    embeddings = []
    valid_indices = []
    for pos, (_, row) in enumerate(df.iterrows()):
        emb = row.get(embedding_col, [])
        # Handle both list and numpy array formats
        if emb is not None:
            if isinstance(emb, np.ndarray) and len(emb) > 0:
                embeddings.append(emb.tolist())
                valid_indices.append(pos)
            elif isinstance(emb, list) and len(emb) > 0:
                embeddings.append(emb)
                valid_indices.append(pos)
    if len(embeddings) < 3:
        fig = go.Figure()
        fig.add_annotation(text="Not enough embeddings for visualization (need at least 3)",
                           xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig
    embeddings = np.array(embeddings)
    # Reduce dimensionality. sklearn's t-SNE requires perplexity < n_samples,
    # so cap it at n_samples - 1; the previous min(30, max(5, n // 3)) gave 5
    # and raised ValueError for the 3-5 sample range this function accepts.
    n_samples = len(embeddings)
    perplexity = min(30, max(5, n_samples // 3), n_samples - 1)
    if embeddings.shape[1] > 50:
        # First reduce with PCA to keep t-SNE tractable on wide embeddings.
        pca = PCA(n_components=min(50, n_samples - 1))
        embeddings = pca.fit_transform(embeddings)
    # Then t-SNE for the 2-D visualization (fixed seed for reproducibility).
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    coords = tsne.fit_transform(embeddings)
    # Create dataframe for plotting (rows that actually had embeddings).
    plot_df = df.iloc[valid_indices].copy()
    plot_df['x'] = coords[:, 0].tolist()
    plot_df['y'] = coords[:, 1].tolist()
    # Fall back to is_positive colouring if the requested column is absent.
    if color_by not in plot_df.columns:
        color_by = 'is_positive' if 'is_positive' in plot_df.columns else None
    fig = go.Figure()
    # Determine text field for hover
    text_field = 'sentence' if 'sentence' in plot_df.columns else 'pivot_token'
    if color_by and color_by in plot_df.columns:
        # Group by colour value so each group gets its own legend entry.
        if color_by == 'is_positive':
            # Special handling for boolean is_positive: fixed green/red.
            for is_pos in [True, False]:
                mask = plot_df[color_by] == is_pos
                subset = plot_df[mask]
                if len(subset) > 0:
                    hover_texts = [str(row.get(text_field, ''))[:100] for _, row in subset.iterrows()]
                    fig.add_trace(go.Scatter(
                        x=subset['x'].tolist(),
                        y=subset['y'].tolist(),
                        mode='markers',
                        name='Positive' if is_pos else 'Negative',
                        marker=dict(
                            size=8,
                            color='#22c55e' if is_pos else '#ef4444',
                            opacity=0.7
                        ),
                        hovertext=hover_texts,
                        hoverinfo='text'
                    ))
        else:
            # Categorical coloring: one trace per unique value.
            unique_vals = plot_df[color_by].unique()
            colors = ['#6366f1', '#22c55e', '#ef4444', '#f59e0b', '#8b5cf6',
                      '#ec4899', '#14b8a6', '#f97316', '#06b6d4', '#84cc16']
            for i, val in enumerate(unique_vals):
                mask = plot_df[color_by] == val
                subset = plot_df[mask]
                if len(subset) > 0:
                    hover_texts = [str(row.get(text_field, ''))[:100] for _, row in subset.iterrows()]
                    fig.add_trace(go.Scatter(
                        x=subset['x'].tolist(),
                        y=subset['y'].tolist(),
                        mode='markers',
                        name=str(val),
                        marker=dict(
                            size=8,
                            color=colors[i % len(colors)],
                            opacity=0.7
                        ),
                        hovertext=hover_texts,
                        hoverinfo='text'
                    ))
    else:
        # No color grouping
        hover_texts = [str(row.get(text_field, ''))[:100] for _, row in plot_df.iterrows()]
        fig.add_trace(go.Scatter(
            x=plot_df['x'].tolist(),
            y=plot_df['y'].tolist(),
            mode='markers',
            name='Embeddings',
            marker=dict(
                size=8,
                color='#6366f1',
                opacity=0.7
            ),
            hovertext=hover_texts,
            hoverinfo='text'
        ))
    fig.update_layout(
        title="Embedding Space Visualization (t-SNE)",
        xaxis_title="t-SNE 1",
        yaxis_title="t-SNE 2",
        template="plotly_dark",
        height=500,
        showlegend=True
    )
    return fig
def create_pivotal_token_trace(df: pd.DataFrame, selected_query: str) -> Tuple[str, go.Figure]:
    """Create a trace visualization for pivotal tokens in a query.

    Builds one HTML card per token (full scrollable context with the token
    highlighted) and a bar chart of per-token probability deltas.

    Returns:
        Tuple of (HTML string, plotly bar chart figure).
    """
    if df.empty:
        return "No tokens found for this query", go.Figure()
    # Build HTML for token cards; header shows the (truncated) query text.
    html_parts = [f"""
<div style="font-family: sans-serif; padding: 20px; background-color: #1a1a2e; border-radius: 10px;">
<h3 style="color: #e0e0e0; border-bottom: 2px solid #6366f1; padding-bottom: 10px;">
Query: {selected_query[:100]}{'...' if len(selected_query) > 100 else ''}
</h3>
<p style="color: #a0a0a0; margin: 10px 0;">Found {len(df)} pivotal tokens for this query</p>
<div style="display: flex; flex-direction: column; gap: 15px; margin-top: 20px;">
"""]
    prob_deltas = []
    token_indices = []
    # idx is the display position; the dataframe label is ignored.
    for idx, (_, row) in enumerate(df.iterrows()):
        token = row.get('pivot_token', 'N/A')
        context = row.get('pivot_context', '')
        # Fall back to the delta's sign when no explicit flag is present.
        is_positive = row.get('is_positive', row.get('prob_delta', 0) > 0)
        prob_delta = row.get('prob_delta', 0)
        prob_before = row.get('prob_before', 0)
        prob_after = row.get('prob_after', 0)
        task_type = row.get('task_type', 'unknown')
        # Color based on impact
        bg_color = "rgba(34, 197, 94, 0.2)" if is_positive else "rgba(239, 68, 68, 0.2)"
        border_color = "#22c55e" if is_positive else "#ef4444"
        # Show full context in a scrollable container - no truncation
        # Escape HTML characters in context and token
        context_escaped = html_lib.escape(str(context))
        token_escaped = html_lib.escape(str(token))
        # Build token card with full context (scrollable)
        card_html = f"""
<div style="background-color: {bg_color}; border-left: 4px solid {border_color};
padding: 15px; border-radius: 5px; margin-bottom: 5px;">
<div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px;">
<span style="color: #a0a0a0; font-size: 0.9em;">Token #{idx + 1} | {task_type}</span>
<span style="color: {border_color}; font-weight: bold; font-size: 1.1em;">
{'+'if prob_delta > 0 else ''}{prob_delta:.3f}
</span>
</div>
<div style="background-color: #1a1a2e; padding: 10px; border-radius: 5px; max-height: 200px; overflow-y: auto; margin: 10px 0;">
<span style="color: #888; font-family: monospace; font-size: 0.85em; white-space: pre-wrap; word-break: break-word;">{context_escaped}</span><span style="background-color: {border_color}; color: white; padding: 2px 6px; border-radius: 3px; font-weight: bold; font-family: monospace;">{token_escaped}</span>
</div>
<div style="display: flex; gap: 15px; flex-wrap: wrap;">
<span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #a0a0a0;">
Before: {prob_before:.3f}
</span>
<span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #a0a0a0;">
After: {prob_after:.3f}
</span>
<span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #6366f1;">
Context: {len(context)} chars
</span>
</div>
</div>
"""
        html_parts.append(card_html)
        prob_deltas.append(prob_delta)
        token_indices.append(idx)
    html_parts.append("</div></div>")
    # Create probability delta chart
    fig = go.Figure()
    # Ensure all values are Python native types (plotly can choke on numpy scalars).
    prob_deltas = [float(d) for d in prob_deltas]
    colors = ['#22c55e' if d > 0 else '#ef4444' for d in prob_deltas]
    fig.add_trace(go.Bar(
        x=token_indices,
        y=prob_deltas,
        marker_color=colors,
        name='Probability Delta',
        hovertemplate='Token #%{x}<br>Δ Prob: %{y:.3f}<extra></extra>'
    ))
    fig.add_hline(y=0, line_dash="dash", line_color="gray")
    fig.update_layout(
        title="Probability Impact per Token",
        xaxis_title="Token Index",
        yaxis_title="Probability Delta",
        template="plotly_dark",
        height=300
    )
    return "\n".join(html_parts), fig
def create_circuit_visualization(df: pd.DataFrame, query_idx: int = 0) -> Tuple[str, go.Figure]:
    """Create step-by-step circuit visualization for reasoning trace.

    Selects the query at position `query_idx` (clamped to the valid range),
    renders one HTML card per reasoning sentence, and charts the success
    probability progression through the trace. Pivotal-token and
    steering-vector datasets are delegated to create_pivotal_token_trace.

    Returns:
        Tuple of (HTML string, plotly line chart figure).
    """
    if df.empty:
        return "No data available", go.Figure()
    dataset_type = detect_dataset_type(df)
    # Get unique queries
    queries = df['query'].unique() if 'query' in df.columns else []
    if len(queries) == 0:
        return "No queries found", go.Figure()
    # Clamp the index so out-of-range selections pick the last query.
    query_idx = min(query_idx, len(queries) - 1)
    selected_query = queries[query_idx]
    # Filter to this query
    query_df = df[df['query'] == selected_query].copy()
    # For pivotal tokens and steering vectors, use the token trace visualization
    if dataset_type in ('pivotal_tokens', 'steering_vectors'):
        return create_pivotal_token_trace(query_df, selected_query)
    # Sort by sentence_id if available, otherwise keep original order
    if 'sentence_id' in query_df.columns:
        query_df = query_df.sort_values('sentence_id')
    else:
        query_df = query_df.reset_index(drop=True)
    # Build HTML for step-by-step view
    html_parts = [f"""
<div style="font-family: sans-serif; padding: 20px; background-color: #1a1a2e; border-radius: 10px;">
<h3 style="color: #e0e0e0; border-bottom: 2px solid #6366f1; padding-bottom: 10px;">
Query: {selected_query[:100]}{'...' if len(selected_query) > 100 else ''}
</h3>
<div style="display: flex; flex-direction: column; gap: 15px; margin-top: 20px;">
"""]
    prob_values = []
    sentence_ids = []
    for idx, row in query_df.iterrows():
        sentence = row.get('sentence', 'N/A')
        sentence_id = row.get('sentence_id', idx)
        # Fall back to the delta's sign when no explicit flag is present.
        is_positive = row.get('is_positive', row.get('prob_delta', 0) > 0)
        prob_delta = row.get('prob_delta', 0)
        category = row.get('sentence_category', 'unknown')
        importance = row.get('importance_score', abs(prob_delta))
        # Verification info (optional columns; rendered only when present)
        verification_score = row.get('verification_score', None)
        arithmetic_errors = row.get('arithmetic_errors', [])
        # Color based on impact
        bg_color = "rgba(34, 197, 94, 0.2)" if is_positive else "rgba(239, 68, 68, 0.2)"
        border_color = "#22c55e" if is_positive else "#ef4444"
        # Build step card
        step_html = f"""
<div style="background-color: {bg_color}; border-left: 4px solid {border_color};
padding: 15px; border-radius: 5px;">
<div style="display: flex; justify-content: space-between; align-items: center;">
<span style="color: #a0a0a0; font-size: 0.9em;">Step {sentence_id} | {category}</span>
<span style="color: {border_color}; font-weight: bold;">
{'+'if prob_delta > 0 else ''}{prob_delta:.3f}
</span>
</div>
<p style="color: #e0e0e0; margin: 10px 0;">{sentence}</p>
<div style="display: flex; gap: 10px; flex-wrap: wrap;">
<span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #a0a0a0;">
Importance: {importance:.3f}
</span>
"""
        if verification_score is not None:
            # Green above the 0.5 threshold, red otherwise.
            v_color = "#22c55e" if verification_score > 0.5 else "#ef4444"
            step_html += f"""
<span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: {v_color};">
Verification: {verification_score:.2f}
</span>
"""
        if arithmetic_errors:
            step_html += """
<span style="background-color: #7f1d1d; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #fca5a5;">
Has Errors
</span>
"""
        step_html += """
</div>
</div>
"""
        html_parts.append(step_html)
        # Default 0.5 keeps the chart sensible when the column is absent.
        prob_values.append(row.get('prob_with_sentence', 0.5))
        sentence_ids.append(sentence_id)
    html_parts.append("</div></div>")
    # Create probability progression chart
    fig = go.Figure()
    colors = ['#22c55e' if p > 0.5 else '#ef4444' for p in prob_values]
    fig.add_trace(go.Scatter(
        # Coerce numpy integers to Python ints for plotly serialization.
        x=[int(s) if isinstance(s, (int, np.integer)) else s for s in sentence_ids],
        y=[float(p) for p in prob_values],
        mode='lines+markers',
        name='Success Probability',
        line=dict(color='#6366f1', width=2),
        marker=dict(size=10, color=colors)
    ))
    fig.add_hline(y=0.5, line_dash="dash", line_color="gray",
                  annotation_text="50% threshold")
    fig.update_layout(
        title="Probability Progression Through Reasoning",
        xaxis_title="Sentence ID",
        yaxis_title="Success Probability",
        yaxis_range=[0, 1],
        template="plotly_dark",
        height=300
    )
    return "\n".join(html_parts), fig
def create_statistics_dashboard(df: pd.DataFrame) -> Tuple[str, go.Figure]:
    """Create statistics dashboard for the dataset.

    Returns a tuple of (HTML string of summary-metric cards, Plotly figure
    holding two distribution charts). For an empty frame, returns a message
    and a blank figure.
    """
    if df.empty:
        return "No data available", go.Figure()

    kind = detect_dataset_type(df)
    cols = set(df.columns)

    # Summary metrics rendered as cards; only columns present contribute.
    metrics = {
        "Total Items": len(df),
        "Dataset Type": kind,
    }
    if 'is_positive' in cols:
        n_pos = df['is_positive'].sum()
        metrics["Positive Items"] = int(n_pos)
        metrics["Negative Items"] = int(len(df) - n_pos)
    if 'prob_delta' in cols:
        metrics["Avg Prob Delta"] = f"{df['prob_delta'].mean():.3f}"
        metrics["Max Prob Delta"] = f"{df['prob_delta'].max():.3f}"
    if 'importance_score' in cols:
        metrics["Avg Importance"] = f"{df['importance_score'].mean():.3f}"
    if 'sentence_category' in cols:
        metrics["Categories"] = len(df['sentence_category'].value_counts())
    if 'model_id' in cols:
        metrics["Models"] = df['model_id'].nunique()

    # Render the metric cards as a responsive HTML grid.
    cards = ['<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px;">']
    for label, value in metrics.items():
        cards.append(f"""
        <div style="background: linear-gradient(135deg, #1e3a5f 0%, #0d1b2a 100%);
                    padding: 20px; border-radius: 10px; text-align: center;">
            <div style="color: #6366f1; font-size: 1.5em; font-weight: bold;">{value}</div>
            <div style="color: #a0a0a0; font-size: 0.9em; margin-top: 5px;">{label}</div>
        </div>
        """)
    cards.append('</div>')

    # Pick the second chart's title from the first matching column.
    second_title = "Category Distribution"
    for column, title in (('sentence_category', "Sentence Category"),
                          ('reasoning_pattern', "Reasoning Pattern"),
                          ('task_type', "Task Type"),
                          ('is_positive', "Positive vs Negative")):
        if column in cols:
            second_title = title
            break

    fig = make_subplots(rows=1, cols=2,
                        subplot_titles=("Probability Delta Distribution", second_title))

    def _histogram_bar(series, trace_name):
        # Bin values with numpy and return a Bar trace that mimics a histogram.
        values = series.dropna().values
        counts, edges = np.histogram(values, bins=30)
        centers = [(edges[i] + edges[i + 1]) / 2 for i in range(len(edges) - 1)]
        return go.Bar(x=centers, y=counts.tolist(), name=trace_name,
                      marker_color='#6366f1', width=(edges[1] - edges[0]) * 0.9)

    # First chart: prob_delta distribution, falling back to prob_after.
    if 'prob_delta' in cols and len(df['prob_delta'].dropna()) > 0:
        fig.add_trace(_histogram_bar(df['prob_delta'], "Prob Delta"), row=1, col=1)
    elif 'prob_after' in cols and len(df['prob_after'].dropna()) > 0:
        fig.add_trace(_histogram_bar(df['prob_after'], "Prob After"), row=1, col=1)

    # Second chart: bar chart of the first available categorical column.
    if 'sentence_category' in cols:
        counts = df['sentence_category'].value_counts()
        fig.add_trace(
            go.Bar(x=counts.index.tolist(), y=counts.values.tolist(),
                   name="Categories", marker_color='#22c55e'),
            row=1, col=2
        )
    elif 'reasoning_pattern' in cols:
        counts = df['reasoning_pattern'].value_counts()
        fig.add_trace(
            go.Bar(x=counts.index.tolist(), y=counts.values.tolist(),
                   name="Patterns", marker_color='#22c55e'),
            row=1, col=2
        )
    elif 'task_type' in cols:
        counts = df['task_type'].value_counts()
        fig.add_trace(
            go.Bar(x=counts.index.tolist(), y=counts.values.tolist(),
                   name="Task Types", marker_color='#22c55e'),
            row=1, col=2
        )
    elif 'is_positive' in cols:
        counts = df['is_positive'].value_counts()
        labels = ['Positive' if v else 'Negative' for v in counts.index.tolist()]
        fig.add_trace(
            go.Bar(x=labels, y=counts.values.tolist(), name="Impact",
                   marker_color=['#22c55e' if l == 'Positive' else '#ef4444' for l in labels]),
            row=1, col=2
        )

    fig.update_layout(template="plotly_dark", height=350, showlegend=False)
    return "\n".join(cards), fig
| # ============================================================================ | |
| # Gradio Interface | |
| # ============================================================================ | |
# Global state for loaded data
# Shared mutable module state: "df" holds the currently loaded dataset and
# "type" the detected PTS kind (both written by load_dataset_action, read by
# every update_* callback).
current_data = {"df": pd.DataFrame(), "type": "unknown"}
def load_dataset_action(source_type: str, dataset_id: str, file_upload):
    """Handle dataset loading and return all visualization updates.

    Returns a 10-tuple matching the Gradio outputs:
    (status message, dataset info, stats HTML, stats figure, graph figure,
     embedding figure, circuit HTML, circuit figure, token-slider update,
     query-dropdown update).
    """
    global current_data

    def _error_outputs(message: str):
        # Uniform "nothing to show" payload for every failure path
        # (previously this tuple was built three times verbatim).
        empty_fig = go.Figure()
        empty_fig.update_layout(template="plotly_dark")
        return (message, "", "No data", empty_fig, empty_fig, empty_fig,
                "No data", empty_fig,
                gr.update(maximum=0), gr.update(choices=[], value=None))

    if source_type == "HuggingFace Hub":
        if not dataset_id:
            return _error_outputs("Please enter a dataset ID")
        df, msg = load_hf_dataset(dataset_id)
    else:  # Local File
        if file_upload is None:
            return _error_outputs("Please upload a file")
        df, msg = load_jsonl_file(file_upload.name)

    if df.empty:
        return _error_outputs(msg)

    current_data["df"] = df
    current_data["type"] = detect_dataset_type(df)

    columns_info = f"Columns: {', '.join(df.columns[:10])}"
    if len(df.columns) > 10:
        columns_info += f" ... and {len(df.columns) - 10} more"

    # Generate all visualizations
    stats_html, stats_fig = create_statistics_dashboard(df)
    graph_fig = create_thought_anchor_graph(df)
    embed_fig = create_embedding_visualization(df)
    circuit_html, circuit_fig = create_circuit_visualization(df)

    # Build truncated "[N] query..." labels for the query dropdown.
    query_choices = []
    if 'query' in df.columns:
        queries = df['query'].unique().tolist()
        for i, q in enumerate(queries):
            q_str = str(q) if q is not None else ""
            if len(q_str) > 80:
                query_choices.append(f"[{i+1}] {q_str[:77]}...")
            else:
                query_choices.append(f"[{i+1}] {q_str}")

    return (msg, f"Dataset type: {current_data['type']}\n{columns_info}",
            stats_html, stats_fig, graph_fig, embed_fig, circuit_html, circuit_fig,
            gr.update(maximum=max(0, len(df) - 1)),
            gr.update(choices=query_choices, value=None))
def get_token_details(idx: int) -> Tuple[str, go.Figure]:
    """Get details for a specific pivotal token.

    Returns (HTML rendering of the token in context, probability-change
    figure). Emits explanatory placeholder HTML for empty data, DPO-pair
    datasets, out-of-range indices, and rows missing the expected fields.
    """
    df = current_data["df"]
    dataset_type = current_data.get("type", "unknown")
    if df.empty:
        return "No data available. Please load a dataset first.", go.Figure()
    # Handle unsupported dataset types
    if dataset_type == 'dpo_pairs':
        html = """
        <div style="padding: 40px; text-align: center; background-color: #1a1a2e; border-radius: 10px;">
            <h3 style="color: #f59e0b;">DPO Pairs Dataset</h3>
            <p style="color: #a0a0a0;">This visualization is not available for DPO pairs datasets.</p>
            <p style="color: #a0a0a0;">DPO pairs contain prompt/chosen/rejected structure without token-level context.</p>
            <p style="color: #6366f1; margin-top: 20px;">
                Try loading a <strong>pivotal_tokens</strong> or <strong>thought_anchors</strong> dataset instead.
            </p>
        </div>
        """
        return html, go.Figure()
    # Guard both ends: a negative index would otherwise silently wrap
    # around via df.iloc and show the wrong row.
    if not 0 <= idx < len(df):
        return "Index out of range", go.Figure()
    row = df.iloc[idx]
    # Field names differ between pivotal_tokens and thought_anchors rows,
    # so fall back to the thought-anchor equivalents.
    context = row.get('pivot_context', row.get('prefix_context', ''))
    token = row.get('pivot_token', row.get('sentence', ''))
    prob_delta = row.get('prob_delta', 0)
    prob_before = row.get('prob_before', row.get('prob_with_sentence', 0.5))
    prob_after = row.get('prob_after', row.get('prob_without_sentence', 0.5))
    # Handle missing data
    if not context and not token:
        html = """
        <div style="padding: 40px; text-align: center; background-color: #1a1a2e; border-radius: 10px;">
            <h3 style="color: #ef4444;">Missing Data</h3>
            <p style="color: #a0a0a0;">This dataset doesn't have the expected fields for token visualization.</p>
        </div>
        """
        return html, go.Figure()
    html = create_token_highlight_html(context, token, prob_delta)
    chart = create_probability_chart(prob_before, prob_after)
    return html, chart
def get_original_query_from_label(label: str) -> Optional[str]:
    """Extract original query from truncated dropdown label like '[1] query...'.

    Returns the full query string, or None when the label is missing,
    not a string, malformed, or no dataset with a 'query' column is loaded.
    (Annotation fixed: the function returns None on several paths, so the
    return type is Optional[str], not str.)
    """
    if not label or not isinstance(label, str):
        return None
    df = current_data["df"]
    if df.empty or 'query' not in df.columns:
        return None
    # Extract index from "[N] query..." format
    match = re.match(r'\[(\d+)\]', label)
    if match:
        idx = int(match.group(1)) - 1  # Convert to 0-based index
        queries = df['query'].unique().tolist()
        if 0 <= idx < len(queries):
            return queries[idx]
    return None
def update_graph_visualization(query_dropdown: str = None):
    """Redraw the thought-anchor graph, optionally filtered by one query."""
    if current_data.get("type", "unknown") == 'dpo_pairs':
        # DPO pairs carry no reasoning-step structure; show a placeholder.
        placeholder = go.Figure()
        placeholder.add_annotation(
            text="Reasoning Graph is not available for DPO pairs datasets.<br>Load a pivotal_tokens or thought_anchors dataset.",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
            font=dict(size=14, color="#a0a0a0")
        )
        placeholder.update_layout(template="plotly_dark", height=400)
        return placeholder
    # Map the truncated dropdown label back to the full query text.
    full_query = get_original_query_from_label(query_dropdown)
    return create_thought_anchor_graph(current_data["df"], full_query)
def update_embedding_visualization(color_by: str):
    """Redraw the embedding projection, colored by the chosen column."""
    if current_data.get("type", "unknown") == 'dpo_pairs':
        # No token/sentence embeddings exist for DPO pairs; show a notice.
        placeholder = go.Figure()
        placeholder.add_annotation(
            text="Embedding Space is not available for DPO pairs datasets.<br>Load a pivotal_tokens, thought_anchors, or steering_vectors dataset.",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
            font=dict(size=14, color="#a0a0a0")
        )
        placeholder.update_layout(template="plotly_dark", height=400)
        return placeholder
    return create_embedding_visualization(current_data["df"], color_by)
def update_circuit_view(query_idx: int):
    """Redraw the step-by-step circuit view for the query at the given index."""
    if current_data.get("type", "unknown") != 'dpo_pairs':
        return create_circuit_visualization(current_data["df"], int(query_idx))
    # DPO pairs have no traceable reasoning steps; return a static notice.
    notice = """
    <div style="padding: 40px; text-align: center; background-color: #1a1a2e; border-radius: 10px;">
        <h3 style="color: #f59e0b;">DPO Pairs Dataset</h3>
        <p style="color: #a0a0a0;">Circuit Tracer is not available for DPO pairs datasets.</p>
        <p style="color: #6366f1; margin-top: 20px;">
            Load a <strong>pivotal_tokens</strong> or <strong>thought_anchors</strong> dataset to explore reasoning circuits.
        </p>
    </div>
    """
    return notice, go.Figure()
def update_statistics() -> Tuple[str, go.Figure]:
    """Update the statistics dashboard.

    Rebuilds the stats HTML and distribution figure from the currently
    loaded dataset in the module-level `current_data` state.
    """
    return create_statistics_dashboard(current_data["df"])
def get_query_list():
    """Get list of unique queries with truncated display labels.

    Returns a gr.update carrying "[N] query..." dropdown choices, or an
    empty choice list when no dataset with a 'query' column is loaded.
    """
    df = current_data["df"]
    if df.empty or 'query' not in df.columns:
        return gr.update(choices=[], value=None)
    labels = []
    # Number entries from 1 so the "[N]" prefix maps back to the query row.
    for position, raw in enumerate(df['query'].unique().tolist(), start=1):
        text = "" if raw is None else str(raw)
        label = f"[{position}] {text[:77]}..." if len(text) > 80 else f"[{position}] {text}"
        labels.append(label)
    return gr.update(choices=labels, value=None)
def refresh_all():
    """Refresh all visualizations from the currently loaded dataset.

    Returns the 6-tuple (stats HTML, stats fig, graph fig, embedding fig,
    circuit HTML, circuit fig) consumed by the refresh button's outputs.
    """
    df = current_data["df"]
    if df.empty:
        blank = go.Figure()
        blank.update_layout(template="plotly_dark")
        return ("No data loaded", blank, blank, blank, "No data loaded", blank)
    stats_html, stats_fig = create_statistics_dashboard(df)
    circuit_html, circuit_fig = create_circuit_visualization(df)
    return (stats_html,
            stats_fig,
            create_thought_anchor_graph(df),
            create_embedding_visualization(df),
            circuit_html,
            circuit_fig)
| # ============================================================================ | |
| # Build Gradio App | |
| # ============================================================================ | |
# Pre-defined HuggingFace datasets
# Default dropdown choices: PTS outputs (pivotal tokens, thought anchors,
# steering vectors) for two small reasoning models.
HF_DATASETS = [
    "codelion/Qwen3-0.6B-pts",
    "codelion/Qwen3-0.6B-pts-thought-anchors",
    "codelion/Qwen3-0.6B-pts-steering-vectors",
    "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts",
    "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts-thought-anchors",
    "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts-steering-vectors",
]
# CSS configuration
# Injected into gr.Blocks below: caps the page width and centers the header.
CSS = """
.gradio-container { max-width: 1400px !important; }
.main-header { text-align: center; margin-bottom: 20px; }
"""
# Declarative UI layout plus event wiring. The component variables created
# here (stats_html, graph_plot, token_slider, ...) are referenced by the
# callbacks' outputs lists below, so their creation order matters only for
# layout, while the outputs lists must match each callback's return tuple.
with gr.Blocks(title="PTS Visualizer", css=CSS) as demo:
    # Header
    gr.Markdown("""
    # PTS Visualizer
    ### Interactive Exploration of Pivotal Tokens, Thought Anchors & Reasoning Circuits
    A [Neuronpedia](https://neuronpedia.org/)-inspired platform for understanding how language models reason.
    Load datasets from HuggingFace Hub or upload your own JSONL files.
    🔗 [Browse more PTS datasets on HuggingFace](https://huggingface.co/datasets?other=pts)
    """)
    # Data Loading Section
    with gr.Accordion("Load Dataset", open=True):
        with gr.Row():
            source_type = gr.Radio(
                choices=["HuggingFace Hub", "Local File"],
                value="HuggingFace Hub",
                label="Data Source"
            )
        with gr.Row():
            with gr.Column(scale=3):
                # allow_custom_value is not set, so users pick from HF_DATASETS.
                dataset_dropdown = gr.Dropdown(
                    choices=HF_DATASETS,
                    value=HF_DATASETS[0],
                    label="Select Dataset",
                    info="Choose a pre-defined dataset or enter your own HuggingFace dataset ID"
                )
            with gr.Column(scale=1):
                file_upload = gr.File(
                    label="Or Upload JSONL",
                    file_types=[".jsonl", ".json"]
                )
        with gr.Row():
            load_btn = gr.Button("Load Dataset", variant="primary")
            refresh_btn = gr.Button("Refresh Visualizations", variant="secondary")
        with gr.Row():
            load_status = gr.Textbox(label="Status", interactive=False)
            dataset_info = gr.Textbox(label="Dataset Info", interactive=False)
    # Main Visualization Tabs
    with gr.Tabs():
        # Overview Tab
        with gr.TabItem("Overview"):
            gr.Markdown("### Dataset Statistics")
            stats_html = gr.HTML()
            stats_chart = gr.Plot()
        # Token Explorer Tab
        with gr.TabItem("Token Explorer"):
            gr.Markdown("### Explore Pivotal Tokens")
            with gr.Row():
                with gr.Column(scale=1):
                    # Maximum is re-set to len(df)-1 by load_dataset_action.
                    token_slider = gr.Slider(
                        minimum=0, maximum=100, step=1, value=0,
                        label="Token Index"
                    )
                with gr.Column(scale=3):
                    token_html = gr.HTML(label="Token in Context")
            prob_chart = gr.Plot(label="Probability Change")
        # Thought Anchor Graph Tab
        with gr.TabItem("Reasoning Graph"):
            gr.Markdown("### Thought Anchor Dependency Graph")
            gr.Markdown("""
            *Visualizes causal dependencies between reasoning steps.
            Green nodes indicate positive impact, red nodes indicate negative impact.
            Node size reflects importance score.*
            """)
            with gr.Row():
                # Choices are populated with "[N] query..." labels on load.
                query_filter = gr.Dropdown(
                    choices=[],
                    value=None,
                    label="Filter by Query"
                )
            graph_plot = gr.Plot()
        # Embedding Visualization Tab
        with gr.TabItem("Embedding Space"):
            gr.Markdown("### Embedding Space Visualization")
            gr.Markdown("*t-SNE projection of sentence/token embeddings. Explore clusters and patterns.*")
            with gr.Row():
                color_dropdown = gr.Dropdown(
                    choices=["is_positive", "sentence_category", "reasoning_pattern", "task_type"],
                    value="is_positive",
                    label="Color By"
                )
            embed_plot = gr.Plot()
        # Circuit Tracer Tab
        with gr.TabItem("Circuit Tracer"):
            gr.Markdown("### Step-by-Step Reasoning Circuit")
            gr.Markdown("*Walk through the reasoning process step by step. See how each step affects the probability of success.*")
            with gr.Row():
                circuit_query_idx = gr.Slider(
                    minimum=0, maximum=100, step=1, value=0,
                    label="Query Index"
                )
            circuit_html = gr.HTML()
            circuit_chart = gr.Plot()
    # Event handlers - using api_name=False to prevent schema generation issues
    load_btn.click(
        fn=load_dataset_action,
        inputs=[source_type, dataset_dropdown, file_upload],
        outputs=[load_status, dataset_info, stats_html, stats_chart, graph_plot, embed_plot, circuit_html, circuit_chart, token_slider, query_filter],
        api_name=False
    )
    refresh_btn.click(
        fn=refresh_all,
        outputs=[stats_html, stats_chart, graph_plot, embed_plot, circuit_html, circuit_chart],
        api_name=False
    )
    token_slider.change(
        fn=get_token_details,
        inputs=[token_slider],
        outputs=[token_html, prob_chart],
        api_name=False
    )
    query_filter.change(
        fn=update_graph_visualization,
        inputs=[query_filter],
        outputs=[graph_plot],
        api_name=False
    )
    color_dropdown.change(
        fn=update_embedding_visualization,
        inputs=[color_dropdown],
        outputs=[embed_plot],
        api_name=False
    )
    circuit_query_idx.change(
        fn=update_circuit_view,
        inputs=[circuit_query_idx],
        outputs=[circuit_html, circuit_chart],
        api_name=False
    )
| # ============================================================================ | |
| # Main Entry Point | |
| # ============================================================================ | |
if __name__ == "__main__":
    # Launch the Gradio app; blocks serving the UI until interrupted.
    demo.launch()