codelion commited on
Commit
3f9de27
Β·
verified Β·
1 Parent(s): 8afb4ca

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +96 -7
  2. app.py +1222 -0
  3. requirements.txt +8 -0
README.md CHANGED
@@ -1,14 +1,103 @@
1
  ---
2
- title: Pts Visualizer
3
- emoji: πŸ“ˆ
4
- colorFrom: yellow
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 6.1.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- short_description: PTS Visualizer
 
 
 
 
 
 
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: PTS Visualizer
3
+ emoji: πŸ”
4
+ colorFrom: indigo
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ tags:
12
+ - pts
13
+ - pivotal-tokens
14
+ - thought-anchors
15
+ - llm-interpretability
16
+ - reasoning
17
+ - visualization
18
  ---
19
 
20
+ # PTS Visualizer
21
+
22
+ Interactive visualization platform for exploring **Pivotal Tokens**, **Thought Anchors**, and **Reasoning Circuits** in language models.
23
+
24
+ Inspired by [Neuronpedia](https://neuronpedia.org/), this tool helps researchers and practitioners understand how language models reason through complex tasks.
25
+
26
+ ## Features
27
+
28
+ ### πŸ“Š Overview Dashboard
29
+ - Dataset statistics and distributions
30
+ - Quick summary of positive/negative impacts
31
+ - Category and pattern analysis
32
+
33
+ ### πŸ” Token Explorer
34
+ - Highlight pivotal tokens in context
35
+ - Visualize probability changes before/after tokens
36
+ - Explore token-level impacts on success
37
+
38
+ ### πŸ•ΈοΈ Reasoning Graph
39
+ - Interactive dependency graph for thought anchors
40
+ - Visualize causal relationships between reasoning steps
41
+ - Color-coded by impact (green = positive, red = negative)
42
+ - Node size indicates importance
43
+
44
+ ### πŸ—ΊοΈ Embedding Space
45
+ - t-SNE visualization of sentence/token embeddings
46
+ - Color by category, pattern, or impact
47
+ - Explore clusters and patterns in reasoning
48
+
49
+ ### ⚑ Circuit Tracer
50
+ - Step-by-step walkthrough of reasoning traces
51
+ - Probability progression chart
52
+ - Verification scores and error detection
53
+
54
+ ## Supported Datasets
55
+
56
+ Load from HuggingFace Hub:
57
+ - `codelion/Qwen3-0.6B-pts` - Pivotal tokens
58
+ - `codelion/Qwen3-0.6B-pts-thought-anchors` - Thought anchors
59
+ - `codelion/Qwen3-0.6B-pts-steering-vectors` - Steering vectors
60
+ - `codelion/Qwen3-0.6B-pts-dpo-pairs` - DPO training pairs
61
+ - `codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts-thought-anchors`
62
+
63
+ Or upload your own JSONL files!
64
+
65
+ ## How to Use
66
+
67
+ 1. **Select a data source**: Choose HuggingFace Hub or upload a local file
68
+ 2. **Load the dataset**: Click "Load Dataset"
69
+ 3. **Explore**: Navigate through the tabs to visualize different aspects
70
+
71
+ ## Local Development
72
+
73
+ ```bash
74
+ # Clone the repository
75
+ git clone https://github.com/codelion/pts
76
+ cd pts/visualizer
77
+
78
+ # Install dependencies
79
+ pip install -r requirements.txt
80
+
81
+ # Run the app
82
+ python app.py
83
+ ```
84
+
85
+ ## Related Resources
86
+
87
+ - [PTS GitHub Repository](https://github.com/codelion/pts)
88
+ - [Pivotal Token Search Collection](https://huggingface.co/collections/codelion/pivotal-token-search)
89
+ - [OptiLLM](https://github.com/codelion/optillm) - Inference optimization library
90
+
91
+ ## Citation
92
+
93
+ If you use this tool in your research, please cite:
94
+
95
+ ```bibtex
96
+ @software{pts,
97
+ title = {PTS: Pivotal Token Search},
98
+ author = {Asankhaya Sharma},
99
+ year = {2025},
100
+ publisher = {GitHub},
101
+ url = {https://github.com/codelion/pts}
102
+ }
103
+ ```
app.py ADDED
@@ -0,0 +1,1222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PTS Visualizer - Interactive visualization for Pivotal Token Search
3
+
4
+ A Neuronpedia-inspired platform for exploring pivotal tokens, thought anchors,
5
+ and reasoning circuits in language models.
6
+ """
7
+
8
+ import gradio as gr
9
+ import plotly.express as px
10
+ import plotly.graph_objects as go
11
+ from plotly.subplots import make_subplots
12
+ import networkx as nx
13
+ import pandas as pd
14
+ import numpy as np
15
+ import json
16
+ import html as html_lib
17
+ from typing import List, Dict, Any, Optional, Tuple
18
+ from datasets import load_dataset
19
+ from sklearn.manifold import TSNE
20
+ from sklearn.decomposition import PCA
21
+ import re
22
+ from collections import defaultdict
23
+
24
+ # ============================================================================
25
+ # Data Loading Functions
26
+ # ============================================================================
27
+
28
+ def load_hf_dataset(dataset_id: str, split: str = "train") -> pd.DataFrame:
29
+ """Load a dataset from HuggingFace Hub."""
30
+ try:
31
+ dataset = load_dataset(dataset_id, split=split)
32
+ df = pd.DataFrame(dataset)
33
+ return df, f"Loaded {len(df)} items from {dataset_id}"
34
+ except Exception as e:
35
+ return pd.DataFrame(), f"Error loading dataset: {str(e)}"
36
+
37
+
38
+ def load_jsonl_file(file_path: str) -> pd.DataFrame:
39
+ """Load data from a local JSONL file."""
40
+ try:
41
+ data = []
42
+ with open(file_path, 'r') as f:
43
+ for line in f:
44
+ if line.strip():
45
+ data.append(json.loads(line))
46
+ return pd.DataFrame(data), f"Loaded {len(data)} items from file"
47
+ except Exception as e:
48
+ return pd.DataFrame(), f"Error loading file: {str(e)}"
49
+
50
+
51
+ def detect_dataset_type(df: pd.DataFrame) -> str:
52
+ """Detect the type of PTS dataset."""
53
+ columns = set(df.columns)
54
+
55
+ if 'sentence' in columns and 'sentence_id' in columns:
56
+ return 'thought_anchors'
57
+ elif 'steering_vector' in columns:
58
+ return 'steering_vectors'
59
+ elif 'chosen' in columns and 'rejected' in columns:
60
+ return 'dpo_pairs'
61
+ elif 'pivot_token' in columns:
62
+ return 'pivotal_tokens'
63
+ else:
64
+ return 'unknown'
65
+
66
+
67
+ # ============================================================================
68
+ # Visualization Components
69
+ # ============================================================================
70
+
71
+ def create_token_highlight_html(context: str, token: str, prob_delta: float) -> str:
72
+ """Create HTML with highlighted pivotal token showing full context."""
73
+ # Escape HTML characters
74
+ context_escaped = html_lib.escape(str(context))
75
+ token_escaped = html_lib.escape(str(token))
76
+
77
+ # Determine color based on probability delta
78
+ if prob_delta > 0:
79
+ # Positive impact - green gradient
80
+ intensity = min(abs(prob_delta) * 2, 1.0)
81
+ color = f"rgba(34, 197, 94, {intensity})"
82
+ border_color = "#22c55e"
83
+ impact_text = "Positive Impact"
84
+ else:
85
+ # Negative impact - red gradient
86
+ intensity = min(abs(prob_delta) * 2, 1.0)
87
+ color = f"rgba(239, 68, 68, {intensity})"
88
+ border_color = "#ef4444"
89
+ impact_text = "Negative Impact"
90
+
91
+ # Create highlighted token span
92
+ token_span = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; border: 2px solid {border_color}; font-weight: bold; font-size: 1.1em;">{token_escaped}</span>'
93
+
94
+ return f"""
95
+ <div style="background-color: #1a1a2e; border-radius: 10px; padding: 20px;">
96
+ <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 15px;">
97
+ <span style="color: #a0a0a0; font-size: 0.9em;">Context Length: {len(context)} characters</span>
98
+ <span style="background-color: {border_color}; color: white; padding: 4px 12px; border-radius: 5px; font-weight: bold;">
99
+ {impact_text}: {'+' if prob_delta > 0 else ''}{prob_delta:.3f}
100
+ </span>
101
+ </div>
102
+ <div style="font-family: monospace; padding: 15px; background-color: #0d1117; border-radius: 8px; color: #e0e0e0; line-height: 1.8; max-height: 500px; overflow-y: auto; white-space: pre-wrap; word-break: break-word; border: 1px solid #30363d;">
103
+ <span style="color: #8b949e;">{context_escaped}</span>{token_span}
104
+ </div>
105
+ <div style="margin-top: 15px; display: flex; gap: 10px; flex-wrap: wrap;">
106
+ <span style="background-color: #238636; color: white; padding: 5px 10px; border-radius: 5px; font-size: 0.9em;">
107
+ Token: <code style="background-color: rgba(0,0,0,0.3); padding: 2px 5px; border-radius: 3px;">{token_escaped}</code>
108
+ </span>
109
+ </div>
110
+ </div>
111
+ """
112
+
113
+
114
+ def create_probability_chart(prob_before: float, prob_after: float) -> go.Figure:
115
+ """Create a bar chart showing probability change."""
116
+ fig = go.Figure()
117
+
118
+ fig.add_trace(go.Bar(
119
+ x=['Before Token', 'After Token'],
120
+ y=[prob_before, prob_after],
121
+ marker_color=['#6366f1', '#22c55e' if prob_after > prob_before else '#ef4444'],
122
+ text=[f'{prob_before:.3f}', f'{prob_after:.3f}'],
123
+ textposition='outside'
124
+ ))
125
+
126
+ fig.update_layout(
127
+ title="Success Probability Change",
128
+ yaxis_title="Probability",
129
+ yaxis_range=[0, 1],
130
+ template="plotly_dark",
131
+ height=300
132
+ )
133
+
134
+ return fig
135
+
136
+
137
+ def create_pivotal_token_flow(df: pd.DataFrame, selected_query: str = None) -> go.Figure:
138
+ """Create a visualization for pivotal tokens showing token impact flow."""
139
+ if df.empty:
140
+ fig = go.Figure()
141
+ fig.add_annotation(text="No data available",
142
+ xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
143
+ fig.update_layout(template="plotly_dark")
144
+ return fig
145
+
146
+ # Filter by query if specified (handle None, empty string, or actual query)
147
+ if selected_query and isinstance(selected_query, str) and selected_query.strip() and 'query' in df.columns:
148
+ df = df[df['query'] == selected_query].copy()
149
+
150
+ if df.empty:
151
+ fig = go.Figure()
152
+ fig.add_annotation(text="No data for selected query",
153
+ xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
154
+ fig.update_layout(template="plotly_dark")
155
+ return fig
156
+
157
+ # Create scatter plot of tokens by probability delta
158
+ fig = go.Figure()
159
+
160
+ # Separate positive and negative tokens
161
+ positive_df = df[df.get('is_positive', df['prob_delta'] > 0) == True] if 'is_positive' in df.columns else df[df['prob_delta'] > 0]
162
+ negative_df = df[df.get('is_positive', df['prob_delta'] > 0) == False] if 'is_positive' in df.columns else df[df['prob_delta'] <= 0]
163
+
164
+ # Add positive tokens
165
+ if not positive_df.empty:
166
+ hover_text = [
167
+ f"Token: {row.get('pivot_token', 'N/A')}<br>"
168
+ f"Ξ” Prob: +{row.get('prob_delta', 0):.3f}<br>"
169
+ f"Before: {row.get('prob_before', 0):.3f}<br>"
170
+ f"After: {row.get('prob_after', 0):.3f}<br>"
171
+ f"Query: {str(row.get('query', ''))[:50]}..."
172
+ for _, row in positive_df.iterrows()
173
+ ]
174
+ fig.add_trace(go.Scatter(
175
+ x=list(range(len(positive_df))),
176
+ y=positive_df['prob_delta'].values,
177
+ mode='markers',
178
+ name='Positive Impact',
179
+ marker=dict(
180
+ size=10 + positive_df['prob_delta'].abs().values * 30,
181
+ color='#22c55e',
182
+ opacity=0.7
183
+ ),
184
+ hovertext=hover_text,
185
+ hoverinfo='text'
186
+ ))
187
+
188
+ # Add negative tokens
189
+ if not negative_df.empty:
190
+ hover_text = [
191
+ f"Token: {row.get('pivot_token', 'N/A')}<br>"
192
+ f"Ξ” Prob: {row.get('prob_delta', 0):.3f}<br>"
193
+ f"Before: {row.get('prob_before', 0):.3f}<br>"
194
+ f"After: {row.get('prob_after', 0):.3f}<br>"
195
+ f"Query: {str(row.get('query', ''))[:50]}..."
196
+ for _, row in negative_df.iterrows()
197
+ ]
198
+ fig.add_trace(go.Scatter(
199
+ x=list(range(len(negative_df))),
200
+ y=negative_df['prob_delta'].values,
201
+ mode='markers',
202
+ name='Negative Impact',
203
+ marker=dict(
204
+ size=10 + negative_df['prob_delta'].abs().values * 30,
205
+ color='#ef4444',
206
+ opacity=0.7
207
+ ),
208
+ hovertext=hover_text,
209
+ hoverinfo='text'
210
+ ))
211
+
212
+ fig.add_hline(y=0, line_dash="dash", line_color="gray")
213
+
214
+ fig.update_layout(
215
+ title="Pivotal Token Impact Distribution",
216
+ xaxis_title="Token Index",
217
+ yaxis_title="Probability Delta",
218
+ template="plotly_dark",
219
+ height=500,
220
+ showlegend=True
221
+ )
222
+
223
+ return fig
224
+
225
+
226
+ def create_thought_anchor_graph(df: pd.DataFrame, selected_query: str = None) -> go.Figure:
227
+ """Create an interactive graph visualization of thought anchor dependencies."""
228
+ dataset_type = detect_dataset_type(df)
229
+
230
+ # For pivotal tokens and steering vectors, create a token impact visualization
231
+ if dataset_type in ('pivotal_tokens', 'steering_vectors'):
232
+ return create_pivotal_token_flow(df, selected_query)
233
+
234
+ if df.empty or 'sentence_id' not in df.columns:
235
+ fig = go.Figure()
236
+ fig.add_annotation(text="No thought anchor data available. Load a thought anchors dataset to see the reasoning graph.",
237
+ xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
238
+ font=dict(size=14, color="#a0a0a0"))
239
+ fig.update_layout(template="plotly_dark", height=400)
240
+ return fig
241
+
242
+ # Filter by query if specified (handle None, empty string, or actual query)
243
+ if selected_query and isinstance(selected_query, str) and selected_query.strip():
244
+ df = df[df['query'] == selected_query].copy()
245
+
246
+ if df.empty:
247
+ fig = go.Figure()
248
+ fig.add_annotation(text="No data for selected query",
249
+ xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
250
+ fig.update_layout(template="plotly_dark")
251
+ return fig
252
+
253
+ # Create networkx graph
254
+ G = nx.DiGraph()
255
+
256
+ # Add nodes (sentences)
257
+ for idx, row in df.iterrows():
258
+ sentence_id = row.get('sentence_id', idx)
259
+ importance = row.get('importance_score', abs(row.get('prob_delta', 0)))
260
+ is_positive = row.get('is_positive', row.get('prob_delta', 0) > 0)
261
+ sentence = row.get('sentence', '')[:50] + '...' if len(row.get('sentence', '')) > 50 else row.get('sentence', '')
262
+
263
+ G.add_node(sentence_id,
264
+ importance=importance,
265
+ is_positive=is_positive,
266
+ sentence=sentence,
267
+ category=row.get('sentence_category', 'unknown'))
268
+
269
+ # Add edges from causal dependencies
270
+ for idx, row in df.iterrows():
271
+ sentence_id = row.get('sentence_id', idx)
272
+ dependencies = row.get('causal_dependencies', [])
273
+ if isinstance(dependencies, list):
274
+ for dep in dependencies:
275
+ if dep in G.nodes():
276
+ G.add_edge(dep, sentence_id)
277
+
278
+ # If no explicit dependencies, create sequential edges
279
+ if G.number_of_edges() == 0:
280
+ sorted_nodes = sorted(G.nodes())
281
+ for i in range(len(sorted_nodes) - 1):
282
+ G.add_edge(sorted_nodes[i], sorted_nodes[i+1])
283
+
284
+ # Layout
285
+ pos = nx.spring_layout(G, k=2, iterations=50)
286
+
287
+ # Create edge traces
288
+ edge_x = []
289
+ edge_y = []
290
+ for edge in G.edges():
291
+ x0, y0 = pos[edge[0]]
292
+ x1, y1 = pos[edge[1]]
293
+ edge_x.extend([x0, x1, None])
294
+ edge_y.extend([y0, y1, None])
295
+
296
+ edge_trace = go.Scatter(
297
+ x=edge_x, y=edge_y,
298
+ line=dict(width=1, color='#888'),
299
+ hoverinfo='none',
300
+ mode='lines'
301
+ )
302
+
303
+ # Create node traces
304
+ node_x = []
305
+ node_y = []
306
+ node_colors = []
307
+ node_sizes = []
308
+ node_texts = []
309
+
310
+ for node in G.nodes():
311
+ x, y = pos[node]
312
+ node_x.append(x)
313
+ node_y.append(y)
314
+
315
+ node_data = G.nodes[node]
316
+ is_positive = node_data.get('is_positive', True)
317
+ importance = node_data.get('importance', 0.3)
318
+
319
+ node_colors.append('#22c55e' if is_positive else '#ef4444')
320
+ node_sizes.append(20 + importance * 50)
321
+
322
+ hover_text = f"Sentence {node}<br>"
323
+ hover_text += f"Category: {node_data.get('category', 'unknown')}<br>"
324
+ hover_text += f"Importance: {importance:.3f}<br>"
325
+ hover_text += f"Text: {node_data.get('sentence', 'N/A')}"
326
+ node_texts.append(hover_text)
327
+
328
+ node_trace = go.Scatter(
329
+ x=node_x, y=node_y,
330
+ mode='markers+text',
331
+ hoverinfo='text',
332
+ text=[str(n) for n in G.nodes()],
333
+ textposition="top center",
334
+ hovertext=node_texts,
335
+ marker=dict(
336
+ color=node_colors,
337
+ size=node_sizes,
338
+ line=dict(width=2, color='white')
339
+ )
340
+ )
341
+
342
+ # Create figure
343
+ fig = go.Figure(data=[edge_trace, node_trace])
344
+
345
+ fig.update_layout(
346
+ title="Thought Anchor Reasoning Graph",
347
+ showlegend=False,
348
+ hovermode='closest',
349
+ template="plotly_dark",
350
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
351
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
352
+ height=500
353
+ )
354
+
355
+ return fig
356
+
357
+
358
+ def create_probability_space_visualization(df: pd.DataFrame, color_by: str = 'is_positive') -> go.Figure:
359
+ """Create a probability space visualization for pivotal tokens (prob_before vs prob_after)."""
360
+ fig = go.Figure()
361
+
362
+ # Color palette for categorical values
363
+ CATEGORY_COLORS = [
364
+ '#6366f1', '#22c55e', '#ef4444', '#f59e0b', '#8b5cf6',
365
+ '#ec4899', '#14b8a6', '#f97316', '#06b6d4', '#84cc16'
366
+ ]
367
+
368
+ # Determine color column
369
+ use_colorscale = False
370
+ if color_by in df.columns:
371
+ color_col = df[color_by]
372
+ if color_by == 'is_positive':
373
+ colors = ['#22c55e' if v else '#ef4444' for v in color_col]
374
+ else:
375
+ # Convert to list
376
+ values = color_col.tolist() if hasattr(color_col, 'tolist') else list(color_col)
377
+
378
+ if len(values) > 0:
379
+ # Check if numeric
380
+ if isinstance(values[0], (int, float)) and not isinstance(values[0], bool):
381
+ colors = values
382
+ use_colorscale = True
383
+ else:
384
+ # Categorical - map to colors
385
+ unique_vals = list(set(values))
386
+ color_map = {val: CATEGORY_COLORS[i % len(CATEGORY_COLORS)] for i, val in enumerate(unique_vals)}
387
+ colors = [color_map[v] for v in values]
388
+ else:
389
+ colors = ['#6366f1'] * len(df)
390
+ else:
391
+ colors = ['#6366f1'] * len(df)
392
+
393
+ # Create hover text
394
+ hover_texts = []
395
+ for _, row in df.iterrows():
396
+ text = f"Token: {row.get('pivot_token', 'N/A')}<br>"
397
+ text += f"Before: {row.get('prob_before', 0):.3f}<br>"
398
+ text += f"After: {row.get('prob_after', 0):.3f}<br>"
399
+ text += f"Delta: {row.get('prob_delta', 0):+.3f}<br>"
400
+ text += f"Query: {str(row.get('query', ''))[:40]}..."
401
+ hover_texts.append(text)
402
+
403
+ fig.add_trace(go.Scatter(
404
+ x=df['prob_before'],
405
+ y=df['prob_after'],
406
+ mode='markers',
407
+ marker=dict(
408
+ size=8,
409
+ color=colors,
410
+ opacity=0.6,
411
+ colorscale='Viridis' if use_colorscale else None,
412
+ showscale=use_colorscale
413
+ ),
414
+ hovertext=hover_texts,
415
+ hoverinfo='text',
416
+ name='Pivotal Tokens'
417
+ ))
418
+
419
+ # Add diagonal line (no change)
420
+ fig.add_trace(go.Scatter(
421
+ x=[0, 1],
422
+ y=[0, 1],
423
+ mode='lines',
424
+ line=dict(dash='dash', color='gray', width=1),
425
+ name='No Change Line',
426
+ showlegend=True
427
+ ))
428
+
429
+ fig.update_layout(
430
+ title="Probability Space: Before vs After Pivotal Token",
431
+ xaxis_title="Probability Before Token",
432
+ yaxis_title="Probability After Token",
433
+ xaxis=dict(range=[0, 1]),
434
+ yaxis=dict(range=[0, 1]),
435
+ template="plotly_dark",
436
+ height=500
437
+ )
438
+
439
+ # Add annotations
440
+ fig.add_annotation(
441
+ x=0.2, y=0.8,
442
+ text="Positive Impact ↑",
443
+ showarrow=False,
444
+ font=dict(color="#22c55e", size=12)
445
+ )
446
+ fig.add_annotation(
447
+ x=0.8, y=0.2,
448
+ text="Negative Impact ↓",
449
+ showarrow=False,
450
+ font=dict(color="#ef4444", size=12)
451
+ )
452
+
453
+ return fig
454
+
455
+
456
+ def create_embedding_visualization(df: pd.DataFrame, color_by: str = 'is_positive') -> go.Figure:
457
+ """Create UMAP/t-SNE visualization of embeddings or alternative visualization for pivotal tokens."""
458
+ if df.empty:
459
+ fig = go.Figure()
460
+ fig.add_annotation(text="No data available",
461
+ xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
462
+ fig.update_layout(template="plotly_dark")
463
+ return fig
464
+
465
+ dataset_type = detect_dataset_type(df)
466
+
467
+ # Check for embeddings
468
+ embedding_col = None
469
+ for col in ['sentence_embedding', 'steering_vector']:
470
+ if col in df.columns:
471
+ embedding_col = col
472
+ break
473
+
474
+ # For pivotal tokens without embeddings, create a probability space visualization
475
+ if embedding_col is None:
476
+ if dataset_type == 'pivotal_tokens' and 'prob_before' in df.columns and 'prob_after' in df.columns:
477
+ return create_probability_space_visualization(df, color_by)
478
+
479
+ fig = go.Figure()
480
+ fig.add_annotation(
481
+ text="No embedding data found. Embeddings are available in thought_anchors and steering_vectors datasets.",
482
+ xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
483
+ font=dict(size=12, color="#a0a0a0")
484
+ )
485
+ fig.update_layout(template="plotly_dark", height=400)
486
+ return fig
487
+
488
+ # Extract embeddings
489
+ embeddings = []
490
+ valid_indices = []
491
+
492
+ for idx, row in df.iterrows():
493
+ emb = row.get(embedding_col, [])
494
+ if isinstance(emb, list) and len(emb) > 0:
495
+ embeddings.append(emb)
496
+ valid_indices.append(idx)
497
+
498
+ if len(embeddings) < 3:
499
+ fig = go.Figure()
500
+ fig.add_annotation(text="Not enough embeddings for visualization (need at least 3)",
501
+ xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
502
+ fig.update_layout(template="plotly_dark")
503
+ return fig
504
+
505
+ embeddings = np.array(embeddings)
506
+
507
+ # Reduce dimensionality
508
+ n_samples = len(embeddings)
509
+ perplexity = min(30, max(5, n_samples // 3))
510
+
511
+ if embeddings.shape[1] > 50:
512
+ # First reduce with PCA
513
+ pca = PCA(n_components=min(50, n_samples - 1))
514
+ embeddings = pca.fit_transform(embeddings)
515
+
516
+ # Then t-SNE for visualization
517
+ tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
518
+ coords = tsne.fit_transform(embeddings)
519
+
520
+ # Create dataframe for plotting
521
+ plot_df = df.iloc[valid_indices].copy()
522
+ plot_df['x'] = coords[:, 0]
523
+ plot_df['y'] = coords[:, 1]
524
+
525
+ # Handle color column
526
+ if color_by not in plot_df.columns:
527
+ color_by = 'is_positive' if 'is_positive' in plot_df.columns else None
528
+
529
+ if color_by and color_by in plot_df.columns:
530
+ fig = px.scatter(
531
+ plot_df, x='x', y='y',
532
+ color=color_by,
533
+ hover_data=['sentence' if 'sentence' in plot_df.columns else 'pivot_token'],
534
+ title="Embedding Space Visualization (t-SNE)",
535
+ template="plotly_dark"
536
+ )
537
+ else:
538
+ fig = px.scatter(
539
+ plot_df, x='x', y='y',
540
+ hover_data=['sentence' if 'sentence' in plot_df.columns else 'pivot_token'],
541
+ title="Embedding Space Visualization (t-SNE)",
542
+ template="plotly_dark"
543
+ )
544
+
545
+ fig.update_layout(height=500)
546
+
547
+ return fig
548
+
549
+
550
+ def create_pivotal_token_trace(df: pd.DataFrame, selected_query: str) -> Tuple[str, go.Figure]:
551
+ """Create a trace visualization for pivotal tokens in a query."""
552
+ if df.empty:
553
+ return "No tokens found for this query", go.Figure()
554
+
555
+ # Build HTML for token cards
556
+ html_parts = [f"""
557
+ <div style="font-family: sans-serif; padding: 20px; background-color: #1a1a2e; border-radius: 10px;">
558
+ <h3 style="color: #e0e0e0; border-bottom: 2px solid #6366f1; padding-bottom: 10px;">
559
+ Query: {selected_query[:100]}{'...' if len(selected_query) > 100 else ''}
560
+ </h3>
561
+ <p style="color: #a0a0a0; margin: 10px 0;">Found {len(df)} pivotal tokens for this query</p>
562
+ <div style="display: flex; flex-direction: column; gap: 15px; margin-top: 20px;">
563
+ """]
564
+
565
+ prob_deltas = []
566
+ token_indices = []
567
+
568
+ for idx, (_, row) in enumerate(df.iterrows()):
569
+ token = row.get('pivot_token', 'N/A')
570
+ context = row.get('pivot_context', '')
571
+ is_positive = row.get('is_positive', row.get('prob_delta', 0) > 0)
572
+ prob_delta = row.get('prob_delta', 0)
573
+ prob_before = row.get('prob_before', 0)
574
+ prob_after = row.get('prob_after', 0)
575
+ task_type = row.get('task_type', 'unknown')
576
+
577
+ # Color based on impact
578
+ bg_color = "rgba(34, 197, 94, 0.2)" if is_positive else "rgba(239, 68, 68, 0.2)"
579
+ border_color = "#22c55e" if is_positive else "#ef4444"
580
+
581
+ # Show full context in a scrollable container - no truncation
582
+ # Escape HTML characters in context and token
583
+ context_escaped = html_lib.escape(str(context))
584
+ token_escaped = html_lib.escape(str(token))
585
+
586
+ # Build token card with full context (scrollable)
587
+ card_html = f"""
588
+ <div style="background-color: {bg_color}; border-left: 4px solid {border_color};
589
+ padding: 15px; border-radius: 5px; margin-bottom: 5px;">
590
+ <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px;">
591
+ <span style="color: #a0a0a0; font-size: 0.9em;">Token #{idx + 1} | {task_type}</span>
592
+ <span style="color: {border_color}; font-weight: bold; font-size: 1.1em;">
593
+ {'+'if prob_delta > 0 else ''}{prob_delta:.3f}
594
+ </span>
595
+ </div>
596
+ <div style="background-color: #1a1a2e; padding: 10px; border-radius: 5px; max-height: 200px; overflow-y: auto; margin: 10px 0;">
597
+ <span style="color: #888; font-family: monospace; font-size: 0.85em; white-space: pre-wrap; word-break: break-word;">{context_escaped}</span><span style="background-color: {border_color}; color: white; padding: 2px 6px; border-radius: 3px; font-weight: bold; font-family: monospace;">{token_escaped}</span>
598
+ </div>
599
+ <div style="display: flex; gap: 15px; flex-wrap: wrap;">
600
+ <span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #a0a0a0;">
601
+ Before: {prob_before:.3f}
602
+ </span>
603
+ <span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #a0a0a0;">
604
+ After: {prob_after:.3f}
605
+ </span>
606
+ <span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #6366f1;">
607
+ Context: {len(context)} chars
608
+ </span>
609
+ </div>
610
+ </div>
611
+ """
612
+ html_parts.append(card_html)
613
+ prob_deltas.append(prob_delta)
614
+ token_indices.append(idx)
615
+
616
+ html_parts.append("</div></div>")
617
+
618
+ # Create probability delta chart
619
+ fig = go.Figure()
620
+
621
+ colors = ['#22c55e' if d > 0 else '#ef4444' for d in prob_deltas]
622
+
623
+ fig.add_trace(go.Bar(
624
+ x=token_indices,
625
+ y=prob_deltas,
626
+ marker_color=colors,
627
+ name='Probability Delta',
628
+ hovertemplate='Token #%{x}<br>Ξ” Prob: %{y:.3f}<extra></extra>'
629
+ ))
630
+
631
+ fig.add_hline(y=0, line_dash="dash", line_color="gray")
632
+
633
+ fig.update_layout(
634
+ title="Probability Impact per Token",
635
+ xaxis_title="Token Index",
636
+ yaxis_title="Probability Delta",
637
+ template="plotly_dark",
638
+ height=300
639
+ )
640
+
641
+ return "\n".join(html_parts), fig
642
+
643
+
644
+ def create_circuit_visualization(df: pd.DataFrame, query_idx: int = 0) -> Tuple[str, go.Figure]:
645
+ """Create step-by-step circuit visualization for reasoning trace."""
646
+ if df.empty:
647
+ return "No data available", go.Figure()
648
+
649
+ dataset_type = detect_dataset_type(df)
650
+
651
+ # Get unique queries
652
+ queries = df['query'].unique() if 'query' in df.columns else []
653
+ if len(queries) == 0:
654
+ return "No queries found", go.Figure()
655
+
656
+ query_idx = min(query_idx, len(queries) - 1)
657
+ selected_query = queries[query_idx]
658
+
659
+ # Filter to this query
660
+ query_df = df[df['query'] == selected_query].copy()
661
+
662
+ # For pivotal tokens and steering vectors, use the token trace visualization
663
+ if dataset_type in ('pivotal_tokens', 'steering_vectors'):
664
+ return create_pivotal_token_trace(query_df, selected_query)
665
+
666
+ # Sort by sentence_id if available, otherwise keep original order
667
+ if 'sentence_id' in query_df.columns:
668
+ query_df = query_df.sort_values('sentence_id')
669
+ else:
670
+ query_df = query_df.reset_index(drop=True)
671
+
672
+ # Build HTML for step-by-step view
673
+ html_parts = [f"""
674
+ <div style="font-family: sans-serif; padding: 20px; background-color: #1a1a2e; border-radius: 10px;">
675
+ <h3 style="color: #e0e0e0; border-bottom: 2px solid #6366f1; padding-bottom: 10px;">
676
+ Query: {selected_query[:100]}{'...' if len(selected_query) > 100 else ''}
677
+ </h3>
678
+ <div style="display: flex; flex-direction: column; gap: 15px; margin-top: 20px;">
679
+ """]
680
+
681
+ prob_values = []
682
+ sentence_ids = []
683
+
684
+ for idx, row in query_df.iterrows():
685
+ sentence = row.get('sentence', 'N/A')
686
+ sentence_id = row.get('sentence_id', idx)
687
+ is_positive = row.get('is_positive', row.get('prob_delta', 0) > 0)
688
+ prob_delta = row.get('prob_delta', 0)
689
+ category = row.get('sentence_category', 'unknown')
690
+ importance = row.get('importance_score', abs(prob_delta))
691
+
692
+ # Verification info
693
+ verification_score = row.get('verification_score', None)
694
+ arithmetic_errors = row.get('arithmetic_errors', [])
695
+
696
+ # Color based on impact
697
+ bg_color = "rgba(34, 197, 94, 0.2)" if is_positive else "rgba(239, 68, 68, 0.2)"
698
+ border_color = "#22c55e" if is_positive else "#ef4444"
699
+
700
+ # Build step card
701
+ step_html = f"""
702
+ <div style="background-color: {bg_color}; border-left: 4px solid {border_color};
703
+ padding: 15px; border-radius: 5px;">
704
+ <div style="display: flex; justify-content: space-between; align-items: center;">
705
+ <span style="color: #a0a0a0; font-size: 0.9em;">Step {sentence_id} | {category}</span>
706
+ <span style="color: {border_color}; font-weight: bold;">
707
+ {'+'if prob_delta > 0 else ''}{prob_delta:.3f}
708
+ </span>
709
+ </div>
710
+ <p style="color: #e0e0e0; margin: 10px 0;">{sentence}</p>
711
+ <div style="display: flex; gap: 10px; flex-wrap: wrap;">
712
+ <span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #a0a0a0;">
713
+ Importance: {importance:.3f}
714
+ </span>
715
+ """
716
+
717
+ if verification_score is not None:
718
+ v_color = "#22c55e" if verification_score > 0.5 else "#ef4444"
719
+ step_html += f"""
720
+ <span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: {v_color};">
721
+ Verification: {verification_score:.2f}
722
+ </span>
723
+ """
724
+
725
+ if arithmetic_errors:
726
+ step_html += """
727
+ <span style="background-color: #7f1d1d; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #fca5a5;">
728
+ Has Errors
729
+ </span>
730
+ """
731
+
732
+ step_html += """
733
+ </div>
734
+ </div>
735
+ """
736
+
737
+ html_parts.append(step_html)
738
+ prob_values.append(row.get('prob_with_sentence', 0.5))
739
+ sentence_ids.append(sentence_id)
740
+
741
+ html_parts.append("</div></div>")
742
+
743
+ # Create probability progression chart
744
+ fig = go.Figure()
745
+
746
+ colors = ['#22c55e' if p > 0.5 else '#ef4444' for p in prob_values]
747
+
748
+ fig.add_trace(go.Scatter(
749
+ x=sentence_ids,
750
+ y=prob_values,
751
+ mode='lines+markers',
752
+ name='Success Probability',
753
+ line=dict(color='#6366f1', width=2),
754
+ marker=dict(size=10, color=colors)
755
+ ))
756
+
757
+ fig.add_hline(y=0.5, line_dash="dash", line_color="gray",
758
+ annotation_text="50% threshold")
759
+
760
+ fig.update_layout(
761
+ title="Probability Progression Through Reasoning",
762
+ xaxis_title="Sentence ID",
763
+ yaxis_title="Success Probability",
764
+ yaxis_range=[0, 1],
765
+ template="plotly_dark",
766
+ height=300
767
+ )
768
+
769
+ return "\n".join(html_parts), fig
770
+
771
+
772
+ def create_statistics_dashboard(df: pd.DataFrame) -> Tuple[str, go.Figure]:
773
+ """Create statistics dashboard for the dataset."""
774
+ if df.empty:
775
+ return "No data available", go.Figure()
776
+
777
+ dataset_type = detect_dataset_type(df)
778
+
779
+ # Build statistics
780
+ stats = {
781
+ "Total Items": len(df),
782
+ "Dataset Type": dataset_type,
783
+ }
784
+
785
+ if 'is_positive' in df.columns:
786
+ positive_count = df['is_positive'].sum()
787
+ stats["Positive Items"] = int(positive_count)
788
+ stats["Negative Items"] = int(len(df) - positive_count)
789
+
790
+ if 'prob_delta' in df.columns:
791
+ stats["Avg Prob Delta"] = f"{df['prob_delta'].mean():.3f}"
792
+ stats["Max Prob Delta"] = f"{df['prob_delta'].max():.3f}"
793
+
794
+ if 'importance_score' in df.columns:
795
+ stats["Avg Importance"] = f"{df['importance_score'].mean():.3f}"
796
+
797
+ if 'sentence_category' in df.columns:
798
+ category_counts = df['sentence_category'].value_counts()
799
+ stats["Categories"] = len(category_counts)
800
+
801
+ if 'model_id' in df.columns:
802
+ stats["Models"] = df['model_id'].nunique()
803
+
804
+ # Build HTML
805
+ html_parts = ['<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px;">']
806
+
807
+ for key, value in stats.items():
808
+ html_parts.append(f"""
809
+ <div style="background: linear-gradient(135deg, #1e3a5f 0%, #0d1b2a 100%);
810
+ padding: 20px; border-radius: 10px; text-align: center;">
811
+ <div style="color: #6366f1; font-size: 1.5em; font-weight: bold;">{value}</div>
812
+ <div style="color: #a0a0a0; font-size: 0.9em; margin-top: 5px;">{key}</div>
813
+ </div>
814
+ """)
815
+
816
+ html_parts.append('</div>')
817
+
818
+ # Create distribution charts
819
+ fig = make_subplots(rows=1, cols=2,
820
+ subplot_titles=("Probability Delta Distribution", "Category Distribution"))
821
+
822
+ if 'prob_delta' in df.columns:
823
+ fig.add_trace(
824
+ go.Histogram(x=df['prob_delta'], nbinsx=30, name="Prob Delta",
825
+ marker_color='#6366f1'),
826
+ row=1, col=1
827
+ )
828
+
829
+ if 'sentence_category' in df.columns:
830
+ category_counts = df['sentence_category'].value_counts()
831
+ fig.add_trace(
832
+ go.Bar(x=category_counts.index, y=category_counts.values, name="Categories",
833
+ marker_color='#22c55e'),
834
+ row=1, col=2
835
+ )
836
+ elif 'reasoning_pattern' in df.columns:
837
+ pattern_counts = df['reasoning_pattern'].value_counts()
838
+ fig.add_trace(
839
+ go.Bar(x=pattern_counts.index, y=pattern_counts.values, name="Patterns",
840
+ marker_color='#22c55e'),
841
+ row=1, col=2
842
+ )
843
+
844
+ fig.update_layout(
845
+ template="plotly_dark",
846
+ height=350,
847
+ showlegend=False
848
+ )
849
+
850
+ return "\n".join(html_parts), fig
851
+
852
+
853
+ # ============================================================================
854
+ # Gradio Interface
855
+ # ============================================================================
856
+
857
+ # Global state for loaded data
858
+ current_data = {"df": pd.DataFrame(), "type": "unknown"}
859
+
860
+
861
+ def load_dataset_action(source_type: str, dataset_id: str, file_upload) -> Tuple[str, str]:
862
+ """Handle dataset loading."""
863
+ global current_data
864
+
865
+ if source_type == "HuggingFace Hub":
866
+ if not dataset_id:
867
+ return "Please enter a dataset ID", ""
868
+ df, msg = load_hf_dataset(dataset_id)
869
+ else: # Local File
870
+ if file_upload is None:
871
+ return "Please upload a file", ""
872
+ df, msg = load_jsonl_file(file_upload.name)
873
+
874
+ if df.empty:
875
+ return msg, ""
876
+
877
+ current_data["df"] = df
878
+ current_data["type"] = detect_dataset_type(df)
879
+
880
+ columns_info = f"Columns: {', '.join(df.columns[:10])}"
881
+ if len(df.columns) > 10:
882
+ columns_info += f" ... and {len(df.columns) - 10} more"
883
+
884
+ return msg, f"Dataset type: {current_data['type']}\n{columns_info}"
885
+
886
+
887
+ def get_token_details(idx: int) -> Tuple[str, go.Figure]:
888
+ """Get details for a specific pivotal token."""
889
+ df = current_data["df"]
890
+ dataset_type = current_data.get("type", "unknown")
891
+
892
+ if df.empty:
893
+ return "No data available. Please load a dataset first.", go.Figure()
894
+
895
+ # Handle unsupported dataset types
896
+ if dataset_type == 'dpo_pairs':
897
+ html = """
898
+ <div style="padding: 40px; text-align: center; background-color: #1a1a2e; border-radius: 10px;">
899
+ <h3 style="color: #f59e0b;">DPO Pairs Dataset</h3>
900
+ <p style="color: #a0a0a0;">This visualization is not available for DPO pairs datasets.</p>
901
+ <p style="color: #a0a0a0;">DPO pairs contain prompt/chosen/rejected structure without token-level context.</p>
902
+ <p style="color: #6366f1; margin-top: 20px;">
903
+ Try loading a <strong>pivotal_tokens</strong> or <strong>thought_anchors</strong> dataset instead.
904
+ </p>
905
+ </div>
906
+ """
907
+ return html, go.Figure()
908
+
909
+ if idx >= len(df):
910
+ return "Index out of range", go.Figure()
911
+
912
+ row = df.iloc[idx]
913
+
914
+ context = row.get('pivot_context', row.get('prefix_context', ''))
915
+ token = row.get('pivot_token', row.get('sentence', ''))
916
+ prob_delta = row.get('prob_delta', 0)
917
+ prob_before = row.get('prob_before', row.get('prob_with_sentence', 0.5))
918
+ prob_after = row.get('prob_after', row.get('prob_without_sentence', 0.5))
919
+
920
+ # Handle missing data
921
+ if not context and not token:
922
+ html = """
923
+ <div style="padding: 40px; text-align: center; background-color: #1a1a2e; border-radius: 10px;">
924
+ <h3 style="color: #ef4444;">Missing Data</h3>
925
+ <p style="color: #a0a0a0;">This dataset doesn't have the expected fields for token visualization.</p>
926
+ </div>
927
+ """
928
+ return html, go.Figure()
929
+
930
+ html = create_token_highlight_html(context, token, prob_delta)
931
+ chart = create_probability_chart(prob_before, prob_after)
932
+
933
+ return html, chart
934
+
935
+
936
+ def update_graph_visualization(query_dropdown: str = None):
937
+ """Update the thought anchor graph."""
938
+ dataset_type = current_data.get("type", "unknown")
939
+ if dataset_type == 'dpo_pairs':
940
+ fig = go.Figure()
941
+ fig.add_annotation(
942
+ text="Reasoning Graph is not available for DPO pairs datasets.<br>Load a pivotal_tokens or thought_anchors dataset.",
943
+ xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
944
+ font=dict(size=14, color="#a0a0a0")
945
+ )
946
+ fig.update_layout(template="plotly_dark", height=400)
947
+ return fig
948
+ return create_thought_anchor_graph(current_data["df"], query_dropdown)
949
+
950
+
951
+ def update_embedding_visualization(color_by: str):
952
+ """Update the embedding visualization."""
953
+ dataset_type = current_data.get("type", "unknown")
954
+ if dataset_type == 'dpo_pairs':
955
+ fig = go.Figure()
956
+ fig.add_annotation(
957
+ text="Embedding Space is not available for DPO pairs datasets.<br>Load a pivotal_tokens, thought_anchors, or steering_vectors dataset.",
958
+ xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
959
+ font=dict(size=14, color="#a0a0a0")
960
+ )
961
+ fig.update_layout(template="plotly_dark", height=400)
962
+ return fig
963
+ return create_embedding_visualization(current_data["df"], color_by)
964
+
965
+
966
+ def update_circuit_view(query_idx: int):
967
+ """Update the circuit view."""
968
+ dataset_type = current_data.get("type", "unknown")
969
+ if dataset_type == 'dpo_pairs':
970
+ html = """
971
+ <div style="padding: 40px; text-align: center; background-color: #1a1a2e; border-radius: 10px;">
972
+ <h3 style="color: #f59e0b;">DPO Pairs Dataset</h3>
973
+ <p style="color: #a0a0a0;">Circuit Tracer is not available for DPO pairs datasets.</p>
974
+ <p style="color: #6366f1; margin-top: 20px;">
975
+ Load a <strong>pivotal_tokens</strong> or <strong>thought_anchors</strong> dataset to explore reasoning circuits.
976
+ </p>
977
+ </div>
978
+ """
979
+ return html, go.Figure()
980
+ return create_circuit_visualization(current_data["df"], int(query_idx))
981
+
982
+
983
+ def update_statistics():
984
+ """Update the statistics dashboard."""
985
+ return create_statistics_dashboard(current_data["df"])
986
+
987
+
988
+ def get_query_list():
989
+ """Get list of unique queries with truncated display labels."""
990
+ df = current_data["df"]
991
+ if df.empty or 'query' not in df.columns:
992
+ return gr.Dropdown(choices=[], value=None)
993
+
994
+ queries = df['query'].unique().tolist()
995
+ # Return tuples of (truncated_label, full_value) for dropdown
996
+ # Gradio will show the label but pass the value
997
+ truncated_queries = []
998
+ for i, q in enumerate(queries):
999
+ q_str = str(q) if q is not None else ""
1000
+ if len(q_str) > 80:
1001
+ label = f"[{i+1}] {q_str[:77]}..."
1002
+ else:
1003
+ label = f"[{i+1}] {q_str}"
1004
+ truncated_queries.append((label, q_str))
1005
+
1006
+ return gr.Dropdown(choices=truncated_queries, value=None)
1007
+
1008
+
1009
+ def refresh_all():
1010
+ """Refresh all visualizations."""
1011
+ df = current_data["df"]
1012
+ if df.empty:
1013
+ empty_fig = go.Figure()
1014
+ empty_fig.update_layout(template="plotly_dark")
1015
+ return (
1016
+ "No data loaded",
1017
+ empty_fig,
1018
+ empty_fig,
1019
+ empty_fig,
1020
+ "No data loaded",
1021
+ empty_fig
1022
+ )
1023
+
1024
+ stats_html, stats_fig = create_statistics_dashboard(df)
1025
+ graph_fig = create_thought_anchor_graph(df)
1026
+ embed_fig = create_embedding_visualization(df)
1027
+ circuit_html, circuit_fig = create_circuit_visualization(df)
1028
+
1029
+ return stats_html, stats_fig, graph_fig, embed_fig, circuit_html, circuit_fig
1030
+
1031
+
1032
+ # ============================================================================
1033
+ # Build Gradio App
1034
+ # ============================================================================
1035
+
1036
+ # Pre-defined HuggingFace datasets
1037
+ HF_DATASETS = [
1038
+ "codelion/Qwen3-0.6B-pts",
1039
+ "codelion/Qwen3-0.6B-pts-thought-anchors",
1040
+ "codelion/Qwen3-0.6B-pts-steering-vectors",
1041
+ "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts",
1042
+ "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts-thought-anchors",
1043
+ "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts-steering-vectors",
1044
+ ]
1045
+
1046
+ # Theme and CSS configuration
1047
+ THEME = gr.themes.Soft(
1048
+ primary_hue="indigo",
1049
+ secondary_hue="emerald",
1050
+ neutral_hue="slate"
1051
+ )
1052
+ CSS = """
1053
+ .gradio-container { max-width: 1400px !important; }
1054
+ .main-header { text-align: center; margin-bottom: 20px; }
1055
+ """
1056
+
1057
+ # Use try/except for Gradio version compatibility
1058
+ try:
1059
+ # Gradio 4.x style
1060
+ demo_context = gr.Blocks(title="PTS Visualizer", theme=THEME, css=CSS)
1061
+ except TypeError:
1062
+ # Gradio 6.x style (theme/css moved to launch)
1063
+ demo_context = gr.Blocks(title="PTS Visualizer")
1064
+
1065
+ with demo_context as demo:
1066
+
1067
+ # Header
1068
+ gr.Markdown("""
1069
+ # PTS Visualizer
1070
+ ### Interactive Exploration of Pivotal Tokens, Thought Anchors & Reasoning Circuits
1071
+
1072
+ A [Neuronpedia](https://neuronpedia.org/)-inspired platform for understanding how language models reason.
1073
+ Load datasets from HuggingFace Hub or upload your own JSONL files.
1074
+
1075
+ πŸ”— [Browse more PTS datasets on HuggingFace](https://huggingface.co/datasets?other=pts)
1076
+ """)
1077
+
1078
+ # Data Loading Section
1079
+ with gr.Accordion("Load Dataset", open=True):
1080
+ with gr.Row():
1081
+ source_type = gr.Radio(
1082
+ choices=["HuggingFace Hub", "Local File"],
1083
+ value="HuggingFace Hub",
1084
+ label="Data Source"
1085
+ )
1086
+
1087
+ with gr.Row():
1088
+ with gr.Column(scale=3):
1089
+ dataset_dropdown = gr.Dropdown(
1090
+ choices=HF_DATASETS,
1091
+ label="Select Dataset",
1092
+ allow_custom_value=True,
1093
+ info="Choose a pre-defined dataset or enter your own HuggingFace dataset ID"
1094
+ )
1095
+ with gr.Column(scale=1):
1096
+ file_upload = gr.File(
1097
+ label="Or Upload JSONL",
1098
+ file_types=[".jsonl", ".json"]
1099
+ )
1100
+
1101
+ with gr.Row():
1102
+ load_btn = gr.Button("Load Dataset", variant="primary")
1103
+ refresh_btn = gr.Button("Refresh Visualizations", variant="secondary")
1104
+
1105
+ with gr.Row():
1106
+ load_status = gr.Textbox(label="Status", interactive=False)
1107
+ dataset_info = gr.Textbox(label="Dataset Info", interactive=False)
1108
+
1109
+ # Main Visualization Tabs
1110
+ with gr.Tabs():
1111
+
1112
+ # Overview Tab
1113
+ with gr.TabItem("Overview"):
1114
+ gr.Markdown("### Dataset Statistics")
1115
+ stats_html = gr.HTML()
1116
+ stats_chart = gr.Plot()
1117
+
1118
+ # Token Explorer Tab
1119
+ with gr.TabItem("Token Explorer"):
1120
+ gr.Markdown("### Explore Pivotal Tokens")
1121
+ with gr.Row():
1122
+ with gr.Column(scale=1):
1123
+ token_slider = gr.Slider(
1124
+ minimum=0, maximum=100, step=1, value=0,
1125
+ label="Token Index"
1126
+ )
1127
+ with gr.Column(scale=3):
1128
+ token_html = gr.HTML(label="Token in Context")
1129
+ prob_chart = gr.Plot(label="Probability Change")
1130
+
1131
+ # Thought Anchor Graph Tab
1132
+ with gr.TabItem("Reasoning Graph"):
1133
+ gr.Markdown("### Thought Anchor Dependency Graph")
1134
+ gr.Markdown("""
1135
+ *Visualizes causal dependencies between reasoning steps.
1136
+ Green nodes indicate positive impact, red nodes indicate negative impact.
1137
+ Node size reflects importance score.*
1138
+ """)
1139
+ with gr.Row():
1140
+ query_filter = gr.Dropdown(
1141
+ choices=[],
1142
+ label="Filter by Query",
1143
+ allow_custom_value=True
1144
+ )
1145
+ graph_plot = gr.Plot()
1146
+
1147
+ # Embedding Visualization Tab
1148
+ with gr.TabItem("Embedding Space"):
1149
+ gr.Markdown("### Embedding Space Visualization")
1150
+ gr.Markdown("*t-SNE projection of sentence/token embeddings. Explore clusters and patterns.*")
1151
+ with gr.Row():
1152
+ color_dropdown = gr.Dropdown(
1153
+ choices=["is_positive", "sentence_category", "reasoning_pattern", "task_type"],
1154
+ value="is_positive",
1155
+ label="Color By"
1156
+ )
1157
+ embed_plot = gr.Plot()
1158
+
1159
+ # Circuit Tracer Tab
1160
+ with gr.TabItem("Circuit Tracer"):
1161
+ gr.Markdown("### Step-by-Step Reasoning Circuit")
1162
+ gr.Markdown("*Walk through the reasoning process step by step. See how each step affects the probability of success.*")
1163
+ with gr.Row():
1164
+ circuit_query_idx = gr.Slider(
1165
+ minimum=0, maximum=100, step=1, value=0,
1166
+ label="Query Index"
1167
+ )
1168
+ circuit_html = gr.HTML()
1169
+ circuit_chart = gr.Plot()
1170
+
1171
+ # Event handlers
1172
+ load_btn.click(
1173
+ fn=load_dataset_action,
1174
+ inputs=[source_type, dataset_dropdown, file_upload],
1175
+ outputs=[load_status, dataset_info]
1176
+ ).then(
1177
+ fn=refresh_all,
1178
+ outputs=[stats_html, stats_chart, graph_plot, embed_plot, circuit_html, circuit_chart]
1179
+ ).then(
1180
+ fn=lambda: gr.Slider(maximum=max(0, len(current_data["df"]) - 1)),
1181
+ outputs=[token_slider]
1182
+ ).then(
1183
+ fn=get_query_list,
1184
+ outputs=[query_filter]
1185
+ )
1186
+
1187
+ refresh_btn.click(
1188
+ fn=refresh_all,
1189
+ outputs=[stats_html, stats_chart, graph_plot, embed_plot, circuit_html, circuit_chart]
1190
+ )
1191
+
1192
+ token_slider.change(
1193
+ fn=get_token_details,
1194
+ inputs=[token_slider],
1195
+ outputs=[token_html, prob_chart]
1196
+ )
1197
+
1198
+ query_filter.change(
1199
+ fn=update_graph_visualization,
1200
+ inputs=[query_filter],
1201
+ outputs=[graph_plot]
1202
+ )
1203
+
1204
+ color_dropdown.change(
1205
+ fn=update_embedding_visualization,
1206
+ inputs=[color_dropdown],
1207
+ outputs=[embed_plot]
1208
+ )
1209
+
1210
+ circuit_query_idx.change(
1211
+ fn=update_circuit_view,
1212
+ inputs=[circuit_query_idx],
1213
+ outputs=[circuit_html, circuit_chart]
1214
+ )
1215
+
1216
+
1217
+ # ============================================================================
1218
+ # Main Entry Point
1219
+ # ============================================================================
1220
+
1221
+ if __name__ == "__main__":
1222
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ plotly>=5.18.0
3
+ networkx>=3.1
4
+ pandas>=2.0.0
5
+ numpy>=1.24.0
6
+ datasets>=2.14.0
7
+ scikit-learn>=1.3.0
8
+ huggingface_hub>=0.19.0