lmellory commited on
Commit
75c7b6a
·
verified ·
1 Parent(s): 8606572
Files changed (1) hide show
  1. app.py +1357 -35
app.py CHANGED
@@ -1,49 +1,1371 @@
1
- import gradio as gr
2
- from diffusers import StableDiffusionPipeline
3
- import torch
4
 
5
- # Модель, которая валит даже на CPU (runwayml — публичная, без gated)
6
- pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None)
 
7
 
8
- # Оптимизации, чтобы не падало на CPU
9
- pipe.enable_model_cpu_offload() # Разгружает на CPU, если GPU нет
10
- pipe.enable_vae_slicing()
11
- pipe.enable_attention_slicing()
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- # Если GPU нет — принудительно CPU (чтобы не орало про NVIDIA)
14
- if not torch.cuda.is_available():
15
- pipe = pipe.to("cpu")
16
 
17
- def generate(prompt, negative_prompt="", steps=30, seed=42):
 
18
  try:
19
- negative = negative_prompt or "blurry, low quality, ugly, deformed"
20
- generator = torch.Generator().manual_seed(seed)
 
 
 
 
21
 
22
- image = pipe(prompt, negative_prompt=negative, num_inference_steps=steps, generator=generator).images[0]
23
- return image
 
 
 
 
 
 
 
24
  except Exception as e:
25
- return f"Ошибка, брат: {str(e)}. На бесплатном CPU медленно (1–2 мин), подожди или попробуй проще prompt."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- with gr.Blocks() as demo:
28
- gr.Markdown("# Epic AI Art Generator 🔥")
29
- gr.Markdown("Сделано для моей девушки ❤️ Валим арты нахуй! Попробуй на русском или английском.")
 
 
 
30
 
31
- with gr.Row():
32
- prompt = gr.Textbox(label="Prompt (описание)", placeholder="рыжий кот в стиле Гарри Поттера, волшебная шляпа, замок Хогвартс")
33
- negative = gr.Textbox(label="Negative prompt", placeholder="blurry, low quality")
34
- steps = gr.Slider(20, 50, value=30, step=5, label="Шаги (steps) — меньше для скорости")
35
- seed = gr.Slider(0, 999999, value=42, label="Seed")
 
36
 
37
- btn = gr.Button("Generate 🔥", variant="primary")
 
 
 
 
 
38
 
39
- output = gr.Image(label="Арт, братан")
 
 
 
 
 
40
 
41
- btn.click(generate, inputs=[prompt, negative, steps, seed], outputs=output)
42
 
43
- gr.Markdown("### Примеры, чтобы ахуеть:")
44
- gr.Markdown("- 'рыжий кот космонавт в космосе'")
45
- gr.Markdown("- 'романтическая пара на закате пляжа'")
46
- gr.Markdown("- 'кот в мантии волшебника, стиль Гарри Поттера'")
47
- gr.Markdown("- 'мы с девушкой как супергерои'")
48
 
49
- demo.launch()
 
 
1
+ """
2
+ PTS Visualizer - Interactive visualization for Pivotal Token Search
 
3
 
4
+ A Neuronpedia-inspired platform for exploring pivotal tokens, thought anchors,
5
+ and reasoning circuits in language models.
6
+ """
7
 
8
+ import gradio as gr
9
+ import plotly.express as px
10
+ import plotly.graph_objects as go
11
+ from plotly.subplots import make_subplots
12
+ import networkx as nx
13
+ import pandas as pd
14
+ import numpy as np
15
+ import json
16
+ import html as html_lib
17
+ from typing import List, Dict, Any, Optional, Tuple
18
+ from datasets import load_dataset
19
+ from sklearn.manifold import TSNE
20
+ from sklearn.decomposition import PCA
21
+ import re
22
+ from collections import defaultdict
23
 
24
+ # ============================================================================
25
+ # Data Loading Functions
26
+ # ============================================================================
27
 
28
def load_hf_dataset(dataset_id: str, split: str = "train") -> Tuple[pd.DataFrame, str]:
    """Load a dataset from the HuggingFace Hub.

    Args:
        dataset_id: Hub dataset identifier (e.g. "user/dataset").
        split: Dataset split to load; defaults to "train".

    Returns:
        A ``(dataframe, status_message)`` tuple. On failure the dataframe
        is empty and the message describes the error — the function never
        raises, so the Gradio status box always gets a string to show.
    """
    try:
        dataset = load_dataset(dataset_id, split=split)
        df = pd.DataFrame(dataset)
        return df, f"Loaded {len(df)} items from {dataset_id}"
    except Exception as e:
        # Swallow and report: a bad dataset id must not crash the UI.
        return pd.DataFrame(), f"Error loading dataset: {str(e)}"
+
37
 
38
def load_jsonl_file(file_path: str) -> Tuple[pd.DataFrame, str]:
    """Load data from a local JSONL (one JSON object per line) file.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        A ``(dataframe, status_message)`` tuple. Blank lines are skipped.
        On any failure (missing file, malformed JSON) an empty dataframe
        and an error message are returned instead of raising.
    """
    try:
        data = []
        # Explicit encoding: JSONL is UTF-8 by convention; relying on the
        # platform default breaks on Windows.
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    data.append(json.loads(line))
        return pd.DataFrame(data), f"Loaded {len(data)} items from file"
    except Exception as e:
        return pd.DataFrame(), f"Error loading file: {str(e)}"
49
+
50
+
51
def detect_dataset_type(df: pd.DataFrame) -> str:
    """Classify a loaded PTS dataframe by its column signature.

    Checks column sets in priority order and returns one of:
    'thought_anchors', 'steering_vectors', 'dpo_pairs',
    'pivotal_tokens', or 'unknown' when no signature matches.
    """
    cols = set(df.columns)

    # Order matters: more specific signatures are tested first.
    if {'sentence', 'sentence_id'} <= cols:
        return 'thought_anchors'
    if 'steering_vector' in cols:
        return 'steering_vectors'
    if {'chosen', 'rejected'} <= cols:
        return 'dpo_pairs'
    if 'pivot_token' in cols:
        return 'pivotal_tokens'
    return 'unknown'
65
+
66
+
67
+ # ============================================================================
68
+ # Visualization Components
69
+ # ============================================================================
70
+
71
def create_token_highlight_html(context: str, token: str, prob_delta: float) -> str:
    """Render a pivotal token highlighted inside its full context as HTML.

    The token is shown after the (escaped) context in a colored badge:
    green when ``prob_delta`` is positive, red otherwise. The returned
    string is a self-contained dark-themed HTML card for a Gradio HTML
    component.
    """
    # Escape HTML characters — context/token come from model output and
    # may contain markup-significant characters.
    context_escaped = html_lib.escape(str(context))
    token_escaped = html_lib.escape(str(token))

    # Determine color based on probability delta.
    # NOTE(review): prob_delta == 0 falls into the "Negative Impact"
    # branch — presumably acceptable; confirm intended.
    if prob_delta > 0:
        # Positive impact - green gradient; opacity scales with |delta|,
        # capped at fully opaque.
        intensity = min(abs(prob_delta) * 2, 1.0)
        color = f"rgba(34, 197, 94, {intensity})"
        border_color = "#22c55e"
        impact_text = "Positive Impact"
    else:
        # Negative impact - red gradient
        intensity = min(abs(prob_delta) * 2, 1.0)
        color = f"rgba(239, 68, 68, {intensity})"
        border_color = "#ef4444"
        impact_text = "Negative Impact"

    # Create highlighted token span
    token_span = f'<span style="background-color: {color}; padding: 2px 6px; border-radius: 3px; border: 2px solid {border_color}; font-weight: bold; font-size: 1.1em;">{token_escaped}</span>'

    return f"""
    <div style="background-color: #1a1a2e; border-radius: 10px; padding: 20px;">
        <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 15px;">
            <span style="color: #a0a0a0; font-size: 0.9em;">Context Length: {len(context)} characters</span>
            <span style="background-color: {border_color}; color: white; padding: 4px 12px; border-radius: 5px; font-weight: bold;">
                {impact_text}: {'+' if prob_delta > 0 else ''}{prob_delta:.3f}
            </span>
        </div>
        <div style="font-family: monospace; padding: 15px; background-color: #0d1117; border-radius: 8px; color: #e0e0e0; line-height: 1.8; max-height: 500px; overflow-y: auto; white-space: pre-wrap; word-break: break-word; border: 1px solid #30363d;">
            <span style="color: #8b949e;">{context_escaped}</span>{token_span}
        </div>
        <div style="margin-top: 15px; display: flex; gap: 10px; flex-wrap: wrap;">
            <span style="background-color: #238636; color: white; padding: 5px 10px; border-radius: 5px; font-size: 0.9em;">
                Token: <code style="background-color: rgba(0,0,0,0.3); padding: 2px 5px; border-radius: 3px;">{token_escaped}</code>
            </span>
        </div>
    </div>
    """
112
+
113
+
114
def create_probability_chart(prob_before: float, prob_after: float) -> go.Figure:
    """Create a bar chart showing probability change.

    Two bars: success probability before and after the pivotal token.
    The second bar is green when the probability went up, red otherwise.
    """
    # Coerce to plain Python floats; treat missing values as 0.0 so the
    # chart still renders.
    before = 0.0 if prob_before is None else float(prob_before)
    after = 0.0 if prob_after is None else float(prob_after)

    after_color = '#22c55e' if after > before else '#ef4444'
    chart = go.Figure()
    chart.add_trace(go.Bar(
        x=['Before Token', 'After Token'],
        y=[before, after],
        marker_color=['#6366f1', after_color],
        text=[f'{before:.3f}', f'{after:.3f}'],
        textposition='outside',
    ))

    chart.update_layout(
        title="Success Probability Change",
        yaxis_title="Probability",
        yaxis_range=[0, 1],
        template="plotly_dark",
        height=300,
    )
    return chart
139
+
140
+
141
def create_pivotal_token_flow(df: pd.DataFrame, selected_query: Optional[str] = None) -> go.Figure:
    """Scatter plot of pivotal-token impact (probability delta per token).

    Positive-impact tokens are green, negative red; marker size scales
    with |delta|. Optionally filters to a single query. Always returns a
    Figure — empty/filtered-out data yields an annotated placeholder.
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available",
                           xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig

    # Filter by query if specified (handle None, empty string, or actual query)
    if selected_query and isinstance(selected_query, str) and selected_query.strip() and 'query' in df.columns:
        df = df[df['query'] == selected_query].copy()

    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data for selected query",
                           xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig

    # Create scatter plot of tokens by probability delta
    fig = go.Figure()

    # Separate positive and negative tokens. Uses the explicit
    # 'is_positive' column when present, otherwise falls back to the
    # sign of 'prob_delta'.
    positive_df = df[df.get('is_positive', df['prob_delta'] > 0) == True] if 'is_positive' in df.columns else df[df['prob_delta'] > 0]
    negative_df = df[df.get('is_positive', df['prob_delta'] > 0) == False] if 'is_positive' in df.columns else df[df['prob_delta'] <= 0]

    # Add positive tokens
    if not positive_df.empty:
        hover_text = [
            f"Token: {row.get('pivot_token', 'N/A')}<br>"
            f"Δ Prob: +{row.get('prob_delta', 0):.3f}<br>"
            f"Before: {row.get('prob_before', 0):.3f}<br>"
            f"After: {row.get('prob_after', 0):.3f}<br>"
            f"Query: {str(row.get('query', ''))[:50]}..."
            for _, row in positive_df.iterrows()
        ]
        y_vals = positive_df['prob_delta'].tolist()
        # Marker size grows with impact magnitude (base 10px).
        sizes = [10 + abs(v) * 30 for v in y_vals]
        fig.add_trace(go.Scatter(
            x=list(range(len(positive_df))),
            y=y_vals,
            mode='markers',
            name='Positive Impact',
            marker=dict(
                size=sizes,
                color='#22c55e',
                opacity=0.7
            ),
            hovertext=hover_text,
            hoverinfo='text'
        ))

    # Add negative tokens (same structure, red, independent x index)
    if not negative_df.empty:
        hover_text = [
            f"Token: {row.get('pivot_token', 'N/A')}<br>"
            f"Δ Prob: {row.get('prob_delta', 0):.3f}<br>"
            f"Before: {row.get('prob_before', 0):.3f}<br>"
            f"After: {row.get('prob_after', 0):.3f}<br>"
            f"Query: {str(row.get('query', ''))[:50]}..."
            for _, row in negative_df.iterrows()
        ]
        y_vals = negative_df['prob_delta'].tolist()
        sizes = [10 + abs(v) * 30 for v in y_vals]
        fig.add_trace(go.Scatter(
            x=list(range(len(negative_df))),
            y=y_vals,
            mode='markers',
            name='Negative Impact',
            marker=dict(
                size=sizes,
                color='#ef4444',
                opacity=0.7
            ),
            hovertext=hover_text,
            hoverinfo='text'
        ))

    # Zero line separates helpful from harmful tokens.
    fig.add_hline(y=0, line_dash="dash", line_color="gray")

    fig.update_layout(
        title="Pivotal Token Impact Distribution",
        xaxis_title="Token Index",
        yaxis_title="Probability Delta",
        template="plotly_dark",
        height=500,
        showlegend=True
    )

    return fig
232
+
233
+
234
def create_thought_anchor_graph(df: pd.DataFrame, selected_query: Optional[str] = None) -> go.Figure:
    """Interactive dependency graph of thought-anchor sentences.

    Nodes are sentences (green = positive impact, red = negative; size
    scales with importance); edges come from 'causal_dependencies', or
    sequential order when no dependencies exist. Pivotal-token and
    steering-vector datasets are delegated to the token-flow view.
    """
    dataset_type = detect_dataset_type(df)

    # For pivotal tokens and steering vectors, create a token impact visualization
    if dataset_type in ('pivotal_tokens', 'steering_vectors'):
        return create_pivotal_token_flow(df, selected_query)

    if df.empty or 'sentence_id' not in df.columns:
        fig = go.Figure()
        fig.add_annotation(text="No thought anchor data available. Load a thought anchors dataset to see the reasoning graph.",
                           xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
                           font=dict(size=14, color="#a0a0a0"))
        fig.update_layout(template="plotly_dark", height=400)
        return fig

    # Filter by query if specified (handle None, empty string, or actual query)
    if selected_query and isinstance(selected_query, str) and selected_query.strip():
        df = df[df['query'] == selected_query].copy()

    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data for selected query",
                           xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig

    # Create networkx graph
    G = nx.DiGraph()

    # Add nodes (sentences), truncating text to 50 chars for hover display.
    for idx, row in df.iterrows():
        sentence_id = row.get('sentence_id', idx)
        importance = row.get('importance_score', abs(row.get('prob_delta', 0)))
        is_positive = row.get('is_positive', row.get('prob_delta', 0) > 0)
        sentence = row.get('sentence', '')[:50] + '...' if len(row.get('sentence', '')) > 50 else row.get('sentence', '')

        G.add_node(sentence_id,
                   importance=importance,
                   is_positive=is_positive,
                   sentence=sentence,
                   category=row.get('sentence_category', 'unknown'))

    # Add edges from causal dependencies (second pass so all nodes exist;
    # dependencies pointing at unknown sentences are silently dropped).
    for idx, row in df.iterrows():
        sentence_id = row.get('sentence_id', idx)
        dependencies = row.get('causal_dependencies', [])
        if isinstance(dependencies, list):
            for dep in dependencies:
                if dep in G.nodes():
                    G.add_edge(dep, sentence_id)

    # If no explicit dependencies, create sequential edges
    if G.number_of_edges() == 0:
        sorted_nodes = sorted(G.nodes())
        for i in range(len(sorted_nodes) - 1):
            G.add_edge(sorted_nodes[i], sorted_nodes[i+1])

    # Layout. NOTE(review): spring_layout is randomized by default, so
    # the graph shape changes between renders — confirm acceptable.
    pos = nx.spring_layout(G, k=2, iterations=50)

    # Create edge traces (None breaks the line between segments).
    edge_x = []
    edge_y = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([float(x0), float(x1), None])
        edge_y.extend([float(y0), float(y1), None])

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=1, color='#888'),
        hoverinfo='none',
        mode='lines'
    )

    # Create node traces
    node_x = []
    node_y = []
    node_colors = []
    node_sizes = []
    node_texts = []

    for node in G.nodes():
        x, y = pos[node]
        node_x.append(float(x))
        node_y.append(float(y))

        node_data = G.nodes[node]
        is_positive = node_data.get('is_positive', True)
        importance = float(node_data.get('importance', 0.3))

        node_colors.append('#22c55e' if is_positive else '#ef4444')
        # Node size: 20px base plus up to 50px scaled by importance.
        node_sizes.append(20 + importance * 50)

        hover_text = f"Sentence {node}<br>"
        hover_text += f"Category: {node_data.get('category', 'unknown')}<br>"
        hover_text += f"Importance: {importance:.3f}<br>"
        hover_text += f"Text: {node_data.get('sentence', 'N/A')}"
        node_texts.append(hover_text)

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers+text',
        hoverinfo='text',
        text=[str(n) for n in G.nodes()],
        textposition="top center",
        hovertext=node_texts,
        marker=dict(
            color=node_colors,
            size=node_sizes,
            line=dict(width=2, color='white')
        )
    )

    # Create figure (edges drawn first so nodes render on top).
    fig = go.Figure(data=[edge_trace, node_trace])

    fig.update_layout(
        title="Thought Anchor Reasoning Graph",
        showlegend=False,
        hovermode='closest',
        template="plotly_dark",
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        height=500
    )

    return fig
364
+
365
+
366
def create_probability_space_visualization(df: pd.DataFrame, color_by: str = 'is_positive') -> go.Figure:
    """Scatter of prob_before (x) vs prob_after (y) for pivotal tokens.

    Points above the dashed y=x diagonal increased success probability.
    ``color_by`` selects the coloring column: 'is_positive' gets fixed
    green/red, numeric columns get a Viridis colorscale, other columns
    get a categorical palette; missing columns fall back to a flat color.
    """
    fig = go.Figure()

    # Color palette for categorical values
    CATEGORY_COLORS = [
        '#6366f1', '#22c55e', '#ef4444', '#f59e0b', '#8b5cf6',
        '#ec4899', '#14b8a6', '#f97316', '#06b6d4', '#84cc16'
    ]

    # Determine color column
    use_colorscale = False
    if color_by in df.columns:
        color_col = df[color_by]
        if color_by == 'is_positive':
            colors = ['#22c55e' if v else '#ef4444' for v in color_col]
        else:
            # Convert to list
            values = color_col.tolist() if hasattr(color_col, 'tolist') else list(color_col)

            if len(values) > 0:
                # Check if numeric (bool is excluded: it is an int subclass)
                if isinstance(values[0], (int, float)) and not isinstance(values[0], bool):
                    colors = values
                    use_colorscale = True
                else:
                    # Categorical - map to colors (palette cycles past 10 values)
                    unique_vals = list(set(values))
                    color_map = {val: CATEGORY_COLORS[i % len(CATEGORY_COLORS)] for i, val in enumerate(unique_vals)}
                    colors = [color_map[v] for v in values]
            else:
                colors = ['#6366f1'] * len(df)
    else:
        colors = ['#6366f1'] * len(df)

    # Create hover text
    hover_texts = []
    for _, row in df.iterrows():
        text = f"Token: {row.get('pivot_token', 'N/A')}<br>"
        text += f"Before: {row.get('prob_before', 0):.3f}<br>"
        text += f"After: {row.get('prob_after', 0):.3f}<br>"
        text += f"Delta: {row.get('prob_delta', 0):+.3f}<br>"
        text += f"Query: {str(row.get('query', ''))[:40]}..."
        hover_texts.append(text)

    fig.add_trace(go.Scatter(
        x=df['prob_before'].tolist(),
        y=df['prob_after'].tolist(),
        mode='markers',
        marker=dict(
            size=8,
            color=colors,
            opacity=0.6,
            colorscale='Viridis' if use_colorscale else None,
            showscale=use_colorscale
        ),
        hovertext=hover_texts,
        hoverinfo='text',
        name='Pivotal Tokens'
    ))

    # Add diagonal line (no change)
    fig.add_trace(go.Scatter(
        x=[0, 1],
        y=[0, 1],
        mode='lines',
        line=dict(dash='dash', color='gray', width=1),
        name='No Change Line',
        showlegend=True
    ))

    fig.update_layout(
        title="Probability Space: Before vs After Pivotal Token",
        xaxis_title="Probability Before Token",
        yaxis_title="Probability After Token",
        xaxis=dict(range=[0, 1]),
        yaxis=dict(range=[0, 1]),
        template="plotly_dark",
        height=500
    )

    # Add quadrant annotations for orientation.
    fig.add_annotation(
        x=0.2, y=0.8,
        text="Positive Impact ↑",
        showarrow=False,
        font=dict(color="#22c55e", size=12)
    )
    fig.add_annotation(
        x=0.8, y=0.2,
        text="Negative Impact ↓",
        showarrow=False,
        font=dict(color="#ef4444", size=12)
    )

    return fig
462
+
463
+
464
def create_embedding_visualization(df: pd.DataFrame, color_by: str = 'is_positive') -> go.Figure:
    """2-D t-SNE projection of row embeddings, colored by ``color_by``.

    Uses 'sentence_embedding' or 'steering_vector' (first found). With no
    embedding column, pivotal-token data falls back to the probability
    space view; otherwise an annotated placeholder is returned. Needs at
    least 3 valid embeddings.
    """
    if df.empty:
        fig = go.Figure()
        fig.add_annotation(text="No data available",
                           xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig

    dataset_type = detect_dataset_type(df)

    # Check for embeddings
    embedding_col = None
    for col in ['sentence_embedding', 'steering_vector']:
        if col in df.columns:
            embedding_col = col
            break

    # For pivotal tokens without embeddings, create a probability space visualization
    if embedding_col is None:
        if dataset_type == 'pivotal_tokens' and 'prob_before' in df.columns and 'prob_after' in df.columns:
            return create_probability_space_visualization(df, color_by)

        fig = go.Figure()
        fig.add_annotation(
            text="No embedding data found. Embeddings are available in thought_anchors and steering_vectors datasets.",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
            font=dict(size=12, color="#a0a0a0")
        )
        fig.update_layout(template="plotly_dark", height=400)
        return fig

    # Extract embeddings, remembering which rows were usable.
    embeddings = []
    valid_indices = []

    for idx, row in df.iterrows():
        emb = row.get(embedding_col, [])
        # Handle both list and numpy array formats
        if emb is not None:
            if isinstance(emb, np.ndarray) and len(emb) > 0:
                embeddings.append(emb.tolist())
                valid_indices.append(idx)
            elif isinstance(emb, list) and len(emb) > 0:
                embeddings.append(emb)
                valid_indices.append(idx)

    if len(embeddings) < 3:
        fig = go.Figure()
        fig.add_annotation(text="Not enough embeddings for visualization (need at least 3)",
                           xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False)
        fig.update_layout(template="plotly_dark")
        return fig

    embeddings = np.array(embeddings)

    # Reduce dimensionality. Perplexity must be < n_samples; clamp to [5, 30].
    n_samples = len(embeddings)
    perplexity = min(30, max(5, n_samples // 3))

    if embeddings.shape[1] > 50:
        # First reduce with PCA (t-SNE on very high dims is slow/noisy)
        pca = PCA(n_components=min(50, n_samples - 1))
        embeddings = pca.fit_transform(embeddings)

    # Then t-SNE for visualization (fixed seed for reproducible layout)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    coords = tsne.fit_transform(embeddings)

    # Create dataframe for plotting.
    # NOTE(review): valid_indices holds index *labels* from iterrows(),
    # but iloc expects positions — correct only for a default RangeIndex;
    # verify with non-default indices.
    plot_df = df.iloc[valid_indices].copy()
    plot_df['x'] = coords[:, 0].tolist()
    plot_df['y'] = coords[:, 1].tolist()

    # Handle color column (fall back to 'is_positive' when available)
    if color_by not in plot_df.columns:
        color_by = 'is_positive' if 'is_positive' in plot_df.columns else None

    fig = go.Figure()

    # Determine text field for hover
    text_field = 'sentence' if 'sentence' in plot_df.columns else 'pivot_token'

    if color_by and color_by in plot_df.columns:
        # Group by color column for separate traces
        if color_by == 'is_positive':
            # Special handling for boolean is_positive
            for is_pos in [True, False]:
                mask = plot_df[color_by] == is_pos
                subset = plot_df[mask]
                if len(subset) > 0:
                    hover_texts = [str(row.get(text_field, ''))[:100] for _, row in subset.iterrows()]
                    fig.add_trace(go.Scatter(
                        x=subset['x'].tolist(),
                        y=subset['y'].tolist(),
                        mode='markers',
                        name='Positive' if is_pos else 'Negative',
                        marker=dict(
                            size=8,
                            color='#22c55e' if is_pos else '#ef4444',
                            opacity=0.7
                        ),
                        hovertext=hover_texts,
                        hoverinfo='text'
                    ))
        else:
            # Categorical coloring — one trace per distinct value.
            unique_vals = plot_df[color_by].unique()
            colors = ['#6366f1', '#22c55e', '#ef4444', '#f59e0b', '#8b5cf6',
                      '#ec4899', '#14b8a6', '#f97316', '#06b6d4', '#84cc16']
            for i, val in enumerate(unique_vals):
                mask = plot_df[color_by] == val
                subset = plot_df[mask]
                if len(subset) > 0:
                    hover_texts = [str(row.get(text_field, ''))[:100] for _, row in subset.iterrows()]
                    fig.add_trace(go.Scatter(
                        x=subset['x'].tolist(),
                        y=subset['y'].tolist(),
                        mode='markers',
                        name=str(val),
                        marker=dict(
                            size=8,
                            color=colors[i % len(colors)],
                            opacity=0.7
                        ),
                        hovertext=hover_texts,
                        hoverinfo='text'
                    ))
    else:
        # No color grouping
        hover_texts = [str(row.get(text_field, ''))[:100] for _, row in plot_df.iterrows()]
        fig.add_trace(go.Scatter(
            x=plot_df['x'].tolist(),
            y=plot_df['y'].tolist(),
            mode='markers',
            name='Embeddings',
            marker=dict(
                size=8,
                color='#6366f1',
                opacity=0.7
            ),
            hovertext=hover_texts,
            hoverinfo='text'
        ))

    fig.update_layout(
        title="Embedding Space Visualization (t-SNE)",
        xaxis_title="t-SNE 1",
        yaxis_title="t-SNE 2",
        template="plotly_dark",
        height=500,
        showlegend=True
    )

    return fig
619
+
620
+
621
def create_pivotal_token_trace(df: pd.DataFrame, selected_query: str) -> Tuple[str, go.Figure]:
    """Per-query trace view: HTML token cards plus a delta bar chart.

    Expects ``df`` already filtered to one query. Returns a tuple of
    (HTML string with one card per pivotal token, bar chart of each
    token's probability delta).
    """
    if df.empty:
        return "No tokens found for this query", go.Figure()

    # Build HTML for token cards (header first, then one card per token)
    html_parts = [f"""
    <div style="font-family: sans-serif; padding: 20px; background-color: #1a1a2e; border-radius: 10px;">
        <h3 style="color: #e0e0e0; border-bottom: 2px solid #6366f1; padding-bottom: 10px;">
            Query: {selected_query[:100]}{'...' if len(selected_query) > 100 else ''}
        </h3>
        <p style="color: #a0a0a0; margin: 10px 0;">Found {len(df)} pivotal tokens for this query</p>
        <div style="display: flex; flex-direction: column; gap: 15px; margin-top: 20px;">
    """]

    prob_deltas = []
    token_indices = []

    for idx, (_, row) in enumerate(df.iterrows()):
        token = row.get('pivot_token', 'N/A')
        context = row.get('pivot_context', '')
        is_positive = row.get('is_positive', row.get('prob_delta', 0) > 0)
        prob_delta = row.get('prob_delta', 0)
        prob_before = row.get('prob_before', 0)
        prob_after = row.get('prob_after', 0)
        task_type = row.get('task_type', 'unknown')

        # Color based on impact
        bg_color = "rgba(34, 197, 94, 0.2)" if is_positive else "rgba(239, 68, 68, 0.2)"
        border_color = "#22c55e" if is_positive else "#ef4444"

        # Show full context in a scrollable container - no truncation
        # Escape HTML characters in context and token
        context_escaped = html_lib.escape(str(context))
        token_escaped = html_lib.escape(str(token))

        # Build token card with full context (scrollable)
        card_html = f"""
        <div style="background-color: {bg_color}; border-left: 4px solid {border_color};
                    padding: 15px; border-radius: 5px; margin-bottom: 5px;">
            <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px;">
                <span style="color: #a0a0a0; font-size: 0.9em;">Token #{idx + 1} | {task_type}</span>
                <span style="color: {border_color}; font-weight: bold; font-size: 1.1em;">
                    {'+'if prob_delta > 0 else ''}{prob_delta:.3f}
                </span>
            </div>
            <div style="background-color: #1a1a2e; padding: 10px; border-radius: 5px; max-height: 200px; overflow-y: auto; margin: 10px 0;">
                <span style="color: #888; font-family: monospace; font-size: 0.85em; white-space: pre-wrap; word-break: break-word;">{context_escaped}</span><span style="background-color: {border_color}; color: white; padding: 2px 6px; border-radius: 3px; font-weight: bold; font-family: monospace;">{token_escaped}</span>
            </div>
            <div style="display: flex; gap: 15px; flex-wrap: wrap;">
                <span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #a0a0a0;">
                    Before: {prob_before:.3f}
                </span>
                <span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #a0a0a0;">
                    After: {prob_after:.3f}
                </span>
                <span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #6366f1;">
                    Context: {len(context)} chars
                </span>
            </div>
        </div>
        """
        html_parts.append(card_html)
        prob_deltas.append(prob_delta)
        token_indices.append(idx)

    html_parts.append("</div></div>")

    # Create probability delta chart
    fig = go.Figure()

    # Ensure all values are Python native types (plotly can choke on numpy scalars)
    prob_deltas = [float(d) for d in prob_deltas]
    colors = ['#22c55e' if d > 0 else '#ef4444' for d in prob_deltas]

    fig.add_trace(go.Bar(
        x=token_indices,
        y=prob_deltas,
        marker_color=colors,
        name='Probability Delta',
        hovertemplate='Token #%{x}<br>Δ Prob: %{y:.3f}<extra></extra>'
    ))

    fig.add_hline(y=0, line_dash="dash", line_color="gray")

    fig.update_layout(
        title="Probability Impact per Token",
        xaxis_title="Token Index",
        yaxis_title="Probability Delta",
        template="plotly_dark",
        height=300
    )

    return "\n".join(html_parts), fig
715
+
716
+
717
def create_circuit_visualization(df: pd.DataFrame, query_idx: int = 0) -> Tuple[str, go.Figure]:
    """Create step-by-step circuit visualization for reasoning trace.

    Renders one query's reasoning steps as styled HTML cards (one card per
    sentence, colored by positive/negative impact) plus a Plotly line chart
    of the per-step success probability.

    Args:
        df: Loaded dataset; expected to contain a 'query' column, and
            optionally 'sentence', 'sentence_id', 'prob_delta',
            'prob_with_sentence', 'sentence_category', 'importance_score',
            'verification_score' and 'arithmetic_errors' columns.
        query_idx: Index into the list of unique queries; clamped to the
            valid range rather than raising.

    Returns:
        Tuple of (HTML string, plotly Figure). On empty/invalid input the
        HTML is a short message and the figure is empty.
    """
    if df.empty:
        return "No data available", go.Figure()

    dataset_type = detect_dataset_type(df)

    # Get unique queries
    queries = df['query'].unique() if 'query' in df.columns else []
    if len(queries) == 0:
        return "No queries found", go.Figure()

    # Clamp rather than raise so a stale slider value can't crash the UI.
    query_idx = min(query_idx, len(queries) - 1)
    selected_query = queries[query_idx]

    # Filter to this query
    query_df = df[df['query'] == selected_query].copy()

    # For pivotal tokens and steering vectors, use the token trace visualization
    if dataset_type in ('pivotal_tokens', 'steering_vectors'):
        return create_pivotal_token_trace(query_df, selected_query)

    # Sort by sentence_id if available, otherwise keep original order
    if 'sentence_id' in query_df.columns:
        query_df = query_df.sort_values('sentence_id')
    else:
        query_df = query_df.reset_index(drop=True)

    # Build HTML for step-by-step view
    html_parts = [f"""
    <div style="font-family: sans-serif; padding: 20px; background-color: #1a1a2e; border-radius: 10px;">
        <h3 style="color: #e0e0e0; border-bottom: 2px solid #6366f1; padding-bottom: 10px;">
            Query: {selected_query[:100]}{'...' if len(selected_query) > 100 else ''}
        </h3>
        <div style="display: flex; flex-direction: column; gap: 15px; margin-top: 20px;">
    """]

    # Collected per-row for the probability progression chart below.
    prob_values = []
    sentence_ids = []

    for idx, row in query_df.iterrows():
        sentence = row.get('sentence', 'N/A')
        sentence_id = row.get('sentence_id', idx)
        # Fall back to the sign of prob_delta when 'is_positive' is absent.
        is_positive = row.get('is_positive', row.get('prob_delta', 0) > 0)
        prob_delta = row.get('prob_delta', 0)
        category = row.get('sentence_category', 'unknown')
        importance = row.get('importance_score', abs(prob_delta))

        # Verification info
        verification_score = row.get('verification_score', None)
        arithmetic_errors = row.get('arithmetic_errors', [])

        # Color based on impact
        bg_color = "rgba(34, 197, 94, 0.2)" if is_positive else "rgba(239, 68, 68, 0.2)"
        border_color = "#22c55e" if is_positive else "#ef4444"

        # Build step card
        step_html = f"""
        <div style="background-color: {bg_color}; border-left: 4px solid {border_color};
                    padding: 15px; border-radius: 5px;">
            <div style="display: flex; justify-content: space-between; align-items: center;">
                <span style="color: #a0a0a0; font-size: 0.9em;">Step {sentence_id} | {category}</span>
                <span style="color: {border_color}; font-weight: bold;">
                    {'+'if prob_delta > 0 else ''}{prob_delta:.3f}
                </span>
            </div>
            <p style="color: #e0e0e0; margin: 10px 0;">{sentence}</p>
            <div style="display: flex; gap: 10px; flex-wrap: wrap;">
                <span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #a0a0a0;">
                    Importance: {importance:.3f}
                </span>
        """

        # Optional verification badge (only when the column held a value).
        if verification_score is not None:
            v_color = "#22c55e" if verification_score > 0.5 else "#ef4444"
            step_html += f"""
                <span style="background-color: #333; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: {v_color};">
                    Verification: {verification_score:.2f}
                </span>
            """

        # NOTE(review): if pandas fills missing 'arithmetic_errors' with NaN,
        # this truthiness test would show the badge spuriously — verify the
        # column is always a list in the source datasets.
        if arithmetic_errors:
            step_html += """
                <span style="background-color: #7f1d1d; padding: 3px 8px; border-radius: 3px; font-size: 0.8em; color: #fca5a5;">
                    Has Errors
                </span>
            """

        step_html += """
            </div>
        </div>
        """

        html_parts.append(step_html)
        prob_values.append(row.get('prob_with_sentence', 0.5))
        sentence_ids.append(sentence_id)

    html_parts.append("</div></div>")

    # Create probability progression chart
    fig = go.Figure()

    colors = ['#22c55e' if p > 0.5 else '#ef4444' for p in prob_values]

    fig.add_trace(go.Scatter(
        # Coerce numpy integers to Python ints so Plotly serialization is clean.
        x=[int(s) if isinstance(s, (int, np.integer)) else s for s in sentence_ids],
        y=[float(p) for p in prob_values],
        mode='lines+markers',
        name='Success Probability',
        line=dict(color='#6366f1', width=2),
        marker=dict(size=10, color=colors)
    ))

    fig.add_hline(y=0.5, line_dash="dash", line_color="gray",
                  annotation_text="50% threshold")

    fig.update_layout(
        title="Probability Progression Through Reasoning",
        xaxis_title="Sentence ID",
        yaxis_title="Success Probability",
        yaxis_range=[0, 1],
        template="plotly_dark",
        height=300
    )

    return "\n".join(html_parts), fig
843
+
844
+
845
def create_statistics_dashboard(df: pd.DataFrame) -> Tuple[str, go.Figure]:
    """Create statistics dashboard for the dataset.

    Builds a grid of summary-stat cards (HTML) and a two-panel Plotly figure:
    a probability-delta histogram on the left and a categorical breakdown
    (category / pattern / task type / positive-vs-negative, whichever column
    exists first) on the right.

    Args:
        df: Loaded dataset; all columns are optional and checked before use.

    Returns:
        Tuple of (stats HTML, plotly Figure). Empty input yields a message
        and an empty figure.
    """
    if df.empty:
        return "No data available", go.Figure()

    dataset_type = detect_dataset_type(df)

    # Build statistics — keys become card labels, values the big numbers.
    stats = {
        "Total Items": len(df),
        "Dataset Type": dataset_type,
    }

    if 'is_positive' in df.columns:
        positive_count = df['is_positive'].sum()
        stats["Positive Items"] = int(positive_count)
        stats["Negative Items"] = int(len(df) - positive_count)

    if 'prob_delta' in df.columns:
        stats["Avg Prob Delta"] = f"{df['prob_delta'].mean():.3f}"
        stats["Max Prob Delta"] = f"{df['prob_delta'].max():.3f}"

    if 'importance_score' in df.columns:
        stats["Avg Importance"] = f"{df['importance_score'].mean():.3f}"

    if 'sentence_category' in df.columns:
        category_counts = df['sentence_category'].value_counts()
        stats["Categories"] = len(category_counts)

    if 'model_id' in df.columns:
        stats["Models"] = df['model_id'].nunique()

    # Build HTML
    html_parts = ['<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px;">']

    for key, value in stats.items():
        html_parts.append(f"""
        <div style="background: linear-gradient(135deg, #1e3a5f 0%, #0d1b2a 100%);
                    padding: 20px; border-radius: 10px; text-align: center;">
            <div style="color: #6366f1; font-size: 1.5em; font-weight: bold;">{value}</div>
            <div style="color: #a0a0a0; font-size: 0.9em; margin-top: 5px;">{key}</div>
        </div>
        """)

    html_parts.append('</div>')

    # Determine what to show in second chart (first matching column wins).
    second_chart_title = "Category Distribution"
    if 'sentence_category' in df.columns:
        second_chart_title = "Sentence Category"
    elif 'reasoning_pattern' in df.columns:
        second_chart_title = "Reasoning Pattern"
    elif 'task_type' in df.columns:
        second_chart_title = "Task Type"
    elif 'is_positive' in df.columns:
        second_chart_title = "Positive vs Negative"

    # Create distribution charts
    fig = make_subplots(rows=1, cols=2,
                        subplot_titles=("Probability Delta Distribution", second_chart_title))

    # First chart: Probability Delta histogram (using numpy for binning)
    if 'prob_delta' in df.columns and len(df['prob_delta'].dropna()) > 0:
        prob_data = df['prob_delta'].dropna().values
        # Create histogram manually using numpy
        counts, bin_edges = np.histogram(prob_data, bins=30)
        bin_centers = [(bin_edges[i] + bin_edges[i+1]) / 2 for i in range(len(bin_edges)-1)]
        fig.add_trace(
            go.Bar(x=bin_centers, y=counts.tolist(), name="Prob Delta",
                   marker_color='#6366f1', width=(bin_edges[1]-bin_edges[0])*0.9),
            row=1, col=1
        )
    elif 'prob_after' in df.columns and len(df['prob_after'].dropna()) > 0:
        # Fallback: show prob_after distribution
        prob_data = df['prob_after'].dropna().values
        counts, bin_edges = np.histogram(prob_data, bins=30)
        bin_centers = [(bin_edges[i] + bin_edges[i+1]) / 2 for i in range(len(bin_edges)-1)]
        fig.add_trace(
            go.Bar(x=bin_centers, y=counts.tolist(), name="Prob After",
                   marker_color='#6366f1', width=(bin_edges[1]-bin_edges[0])*0.9),
            row=1, col=1
        )

    # Second chart: Categories, patterns, or task types
    if 'sentence_category' in df.columns:
        category_counts = df['sentence_category'].value_counts()
        fig.add_trace(
            go.Bar(x=category_counts.index.tolist(), y=category_counts.values.tolist(), name="Categories",
                   marker_color='#22c55e'),
            row=1, col=2
        )
    elif 'reasoning_pattern' in df.columns:
        pattern_counts = df['reasoning_pattern'].value_counts()
        fig.add_trace(
            go.Bar(x=pattern_counts.index.tolist(), y=pattern_counts.values.tolist(), name="Patterns",
                   marker_color='#22c55e'),
            row=1, col=2
        )
    elif 'task_type' in df.columns:
        task_counts = df['task_type'].value_counts()
        fig.add_trace(
            go.Bar(x=task_counts.index.tolist(), y=task_counts.values.tolist(), name="Task Types",
                   marker_color='#22c55e'),
            row=1, col=2
        )
    elif 'is_positive' in df.columns:
        pos_neg_counts = df['is_positive'].value_counts()
        labels = ['Positive' if v else 'Negative' for v in pos_neg_counts.index.tolist()]
        fig.add_trace(
            go.Bar(x=labels, y=pos_neg_counts.values.tolist(), name="Impact",
                   marker_color=['#22c55e' if l == 'Positive' else '#ef4444' for l in labels]),
            row=1, col=2
        )

    fig.update_layout(
        template="plotly_dark",
        height=350,
        showlegend=False
    )

    return "\n".join(html_parts), fig
966
+
967
+
968
# ============================================================================
# Gradio Interface
# ============================================================================

# Global state for loaded data. Mutated in place by load_dataset_action and
# read by every event handler below; "type" is one of the strings produced
# by detect_dataset_type (e.g. 'pivotal_tokens', 'dpo_pairs').
current_data = {"df": pd.DataFrame(), "type": "unknown"}
974
+
975
+
976
def load_dataset_action(source_type: str, dataset_id: str, file_upload):
    """Handle dataset loading and return all visualization updates.

    Args:
        source_type: Either "HuggingFace Hub" or "Local File" (radio value).
        dataset_id: HuggingFace dataset repo ID (used for the Hub path).
        file_upload: Gradio file object with a ``.name`` path, or None.

    Returns:
        10-tuple matching load_btn's outputs: (status message, dataset info,
        stats HTML, stats figure, graph figure, embedding figure, circuit
        HTML, circuit figure, token-slider update, query-dropdown update).
    """
    global current_data

    def _failure(message: str):
        # Single source of truth for the 10-element "nothing loaded" tuple.
        # Previously this was duplicated three times, inviting drift if the
        # output list ever changes.
        empty_fig = go.Figure()
        empty_fig.update_layout(template="plotly_dark")
        return (message, "", "No data", empty_fig, empty_fig, empty_fig, "No data", empty_fig,
                gr.update(maximum=0), gr.update(choices=[], value=None))

    if source_type == "HuggingFace Hub":
        if not dataset_id:
            return _failure("Please enter a dataset ID")
        df, msg = load_hf_dataset(dataset_id)
    else:  # Local File
        if file_upload is None:
            return _failure("Please upload a file")
        df, msg = load_jsonl_file(file_upload.name)

    if df.empty:
        # msg carries the loader's error description.
        return _failure(msg)

    current_data["df"] = df
    current_data["type"] = detect_dataset_type(df)

    columns_info = f"Columns: {', '.join(df.columns[:10])}"
    if len(df.columns) > 10:
        columns_info += f" ... and {len(df.columns) - 10} more"

    # Generate all visualizations
    stats_html, stats_fig = create_statistics_dashboard(df)
    graph_fig = create_thought_anchor_graph(df)
    embed_fig = create_embedding_visualization(df)
    circuit_html, circuit_fig = create_circuit_visualization(df)

    # Generate query list (truncated labels; index prefix lets handlers map
    # a label back to the full query string later).
    query_choices = []
    if 'query' in df.columns:
        queries = df['query'].unique().tolist()
        for i, q in enumerate(queries):
            q_str = str(q) if q is not None else ""
            if len(q_str) > 80:
                query_choices.append(f"[{i+1}] {q_str[:77]}...")
            else:
                query_choices.append(f"[{i+1}] {q_str}")

    return (msg, f"Dataset type: {current_data['type']}\n{columns_info}",
            stats_html, stats_fig, graph_fig, embed_fig, circuit_html, circuit_fig,
            gr.update(maximum=max(0, len(df) - 1)),
            gr.update(choices=query_choices, value=None))
1029
+
1030
+
1031
def get_token_details(idx: int) -> Tuple[str, go.Figure]:
    """Get details for a specific pivotal token.

    Reads the globally loaded dataset (``current_data``) and renders the
    token-in-context HTML plus a before/after probability chart for row
    ``idx``. DPO-pair datasets get a static "not available" message.

    Args:
        idx: Row index into the loaded DataFrame (driven by the UI slider).

    Returns:
        Tuple of (HTML string, plotly Figure).
    """
    df = current_data["df"]
    dataset_type = current_data.get("type", "unknown")

    if df.empty:
        return "No data available. Please load a dataset first.", go.Figure()

    # Handle unsupported dataset types
    if dataset_type == 'dpo_pairs':
        html = """
        <div style="padding: 40px; text-align: center; background-color: #1a1a2e; border-radius: 10px;">
            <h3 style="color: #f59e0b;">DPO Pairs Dataset</h3>
            <p style="color: #a0a0a0;">This visualization is not available for DPO pairs datasets.</p>
            <p style="color: #a0a0a0;">DPO pairs contain prompt/chosen/rejected structure without token-level context.</p>
            <p style="color: #6366f1; margin-top: 20px;">
                Try loading a <strong>pivotal_tokens</strong> or <strong>thought_anchors</strong> dataset instead.
            </p>
        </div>
        """
        return html, go.Figure()

    if idx >= len(df):
        return "Index out of range", go.Figure()

    row = df.iloc[idx]

    # Pivotal-token datasets use pivot_* fields; thought-anchor datasets fall
    # back to sentence-level fields.
    context = row.get('pivot_context', row.get('prefix_context', ''))
    token = row.get('pivot_token', row.get('sentence', ''))
    prob_delta = row.get('prob_delta', 0)
    prob_before = row.get('prob_before', row.get('prob_with_sentence', 0.5))
    prob_after = row.get('prob_after', row.get('prob_without_sentence', 0.5))

    # Handle missing data
    if not context and not token:
        html = """
        <div style="padding: 40px; text-align: center; background-color: #1a1a2e; border-radius: 10px;">
            <h3 style="color: #ef4444;">Missing Data</h3>
            <p style="color: #a0a0a0;">This dataset doesn't have the expected fields for token visualization.</p>
        </div>
        """
        return html, go.Figure()

    html = create_token_highlight_html(context, token, prob_delta)
    chart = create_probability_chart(prob_before, prob_after)

    return html, chart
1078
+
1079
+
1080
def get_original_query_from_label(label: str) -> "str | None":
    """Extract original query from a truncated dropdown label like '[1] query...'.

    The dropdown labels are built as ``[N] <truncated query>``, so the 1-based
    index in brackets is used to look the full query back up in the loaded
    DataFrame.

    Args:
        label: Dropdown label string, or None/empty when nothing is selected.

    Returns:
        The full query string, or None when the label is empty/malformed,
        no dataset is loaded, or the index is out of range.
        (Return annotation fixed: the original claimed ``str`` but every
        failure path returns None.)
    """
    if not label or not isinstance(label, str):
        return None

    df = current_data["df"]
    if df.empty or 'query' not in df.columns:
        return None

    # Extract index from "[N] query..." format
    match = re.match(r'\[(\d+)\]', label)
    if match:
        idx = int(match.group(1)) - 1  # Convert to 0-based index
        queries = df['query'].unique().tolist()
        if 0 <= idx < len(queries):
            return queries[idx]

    return None
1098
+
1099
+
1100
def update_graph_visualization(query_dropdown: str = None):
    """Redraw the thought-anchor graph, optionally filtered to one query."""
    if current_data.get("type", "unknown") != 'dpo_pairs':
        # Map the truncated dropdown label back to the full query string.
        selected = get_original_query_from_label(query_dropdown)
        return create_thought_anchor_graph(current_data["df"], selected)

    # DPO-pair datasets carry no step structure — show a placeholder notice.
    placeholder = go.Figure()
    placeholder.add_annotation(
        text="Reasoning Graph is not available for DPO pairs datasets.<br>Load a pivotal_tokens or thought_anchors dataset.",
        xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
        font=dict(size=14, color="#a0a0a0")
    )
    placeholder.update_layout(template="plotly_dark", height=400)
    return placeholder
1116
+
1117
+
1118
def update_embedding_visualization(color_by: str):
    """Redraw the embedding-space projection, colored by the given column."""
    if current_data.get("type", "unknown") != 'dpo_pairs':
        return create_embedding_visualization(current_data["df"], color_by)

    # No embeddings to show for DPO-pair datasets — render a notice instead.
    placeholder = go.Figure()
    placeholder.add_annotation(
        text="Embedding Space is not available for DPO pairs datasets.<br>Load a pivotal_tokens, thought_anchors, or steering_vectors dataset.",
        xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
        font=dict(size=14, color="#a0a0a0")
    )
    placeholder.update_layout(template="plotly_dark", height=400)
    return placeholder
1131
+
1132
+
1133
def update_circuit_view(query_idx: int):
    """Update the circuit view.

    Args:
        query_idx: Query index from the UI slider (may arrive as float,
            hence the int() coercion below).

    Returns:
        Tuple of (HTML string, plotly Figure) for the Circuit Tracer tab.
    """
    dataset_type = current_data.get("type", "unknown")
    if dataset_type == 'dpo_pairs':
        # No step-level structure in DPO pairs — show a static notice.
        html = """
        <div style="padding: 40px; text-align: center; background-color: #1a1a2e; border-radius: 10px;">
            <h3 style="color: #f59e0b;">DPO Pairs Dataset</h3>
            <p style="color: #a0a0a0;">Circuit Tracer is not available for DPO pairs datasets.</p>
            <p style="color: #6366f1; margin-top: 20px;">
                Load a <strong>pivotal_tokens</strong> or <strong>thought_anchors</strong> dataset to explore reasoning circuits.
            </p>
        </div>
        """
        return html, go.Figure()
    return create_circuit_visualization(current_data["df"], int(query_idx))
1148
+
1149
+
1150
def update_statistics():
    """Rebuild the statistics dashboard from the currently loaded dataset."""
    loaded_df = current_data["df"]
    return create_statistics_dashboard(loaded_df)
1153
+
1154
+
1155
def get_query_list():
    """Build dropdown choices: one truncated, index-prefixed label per query."""
    df = current_data["df"]
    if df.empty or 'query' not in df.columns:
        return gr.update(choices=[], value=None)

    def _label(position, query):
        # "[N] <query>", trimming long queries to 77 chars plus an ellipsis.
        text = str(query) if query is not None else ""
        if len(text) > 80:
            return f"[{position+1}] {text[:77]}..."
        return f"[{position+1}] {text}"

    choices = [_label(i, q) for i, q in enumerate(df['query'].unique().tolist())]
    return gr.update(choices=choices, value=None)
1172
+
1173
+
1174
def refresh_all():
    """Regenerate every visualization from the currently loaded dataset.

    Returns the 6-tuple wired to refresh_btn's outputs: stats HTML, stats
    figure, graph figure, embedding figure, circuit HTML, circuit figure.
    """
    loaded_df = current_data["df"]

    if loaded_df.empty:
        # Nothing loaded yet: clear every output with dark-themed blanks.
        blank = go.Figure()
        blank.update_layout(template="plotly_dark")
        return ("No data loaded", blank, blank, blank, "No data loaded", blank)

    overview_html, overview_fig = create_statistics_dashboard(loaded_df)
    anchors_fig = create_thought_anchor_graph(loaded_df)
    embedding_fig = create_embedding_visualization(loaded_df)
    trace_html, trace_fig = create_circuit_visualization(loaded_df)

    return overview_html, overview_fig, anchors_fig, embedding_fig, trace_html, trace_fig
1195
+
1196
+
1197
# ============================================================================
# Build Gradio App
# ============================================================================

# Pre-defined HuggingFace datasets offered in the dropdown; users can also
# type any other dataset repo ID.
HF_DATASETS = [
    "codelion/Qwen3-0.6B-pts",
    "codelion/Qwen3-0.6B-pts-thought-anchors",
    "codelion/Qwen3-0.6B-pts-steering-vectors",
    "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts",
    "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts-thought-anchors",
    "codelion/DeepSeek-R1-Distill-Qwen-1.5B-pts-steering-vectors",
]
1210
+
1211
# CSS configuration injected into the Blocks app (widens the container and
# centers the header).
CSS = """
.gradio-container { max-width: 1400px !important; }
.main-header { text-align: center; margin-bottom: 20px; }
"""
1216
+
1217
# Top-level UI definition: builds the Blocks app `demo` (launched at the
# bottom of the file) and wires each control to its handler above.
with gr.Blocks(title="PTS Visualizer", css=CSS) as demo:

    # Header
    gr.Markdown("""
    # PTS Visualizer
    ### Interactive Exploration of Pivotal Tokens, Thought Anchors & Reasoning Circuits

    A [Neuronpedia](https://neuronpedia.org/)-inspired platform for understanding how language models reason.
    Load datasets from HuggingFace Hub or upload your own JSONL files.

    🔗 [Browse more PTS datasets on HuggingFace](https://huggingface.co/datasets?other=pts)
    """)

    # Data Loading Section
    with gr.Accordion("Load Dataset", open=True):
        with gr.Row():
            source_type = gr.Radio(
                choices=["HuggingFace Hub", "Local File"],
                value="HuggingFace Hub",
                label="Data Source"
            )

        with gr.Row():
            with gr.Column(scale=3):
                dataset_dropdown = gr.Dropdown(
                    choices=HF_DATASETS,
                    value=HF_DATASETS[0],
                    label="Select Dataset",
                    info="Choose a pre-defined dataset or enter your own HuggingFace dataset ID"
                )
            with gr.Column(scale=1):
                file_upload = gr.File(
                    label="Or Upload JSONL",
                    file_types=[".jsonl", ".json"]
                )

        with gr.Row():
            load_btn = gr.Button("Load Dataset", variant="primary")
            refresh_btn = gr.Button("Refresh Visualizations", variant="secondary")

        with gr.Row():
            load_status = gr.Textbox(label="Status", interactive=False)
            dataset_info = gr.Textbox(label="Dataset Info", interactive=False)

    # Main Visualization Tabs
    with gr.Tabs():

        # Overview Tab
        with gr.TabItem("Overview"):
            gr.Markdown("### Dataset Statistics")
            stats_html = gr.HTML()
            stats_chart = gr.Plot()

        # Token Explorer Tab
        with gr.TabItem("Token Explorer"):
            gr.Markdown("### Explore Pivotal Tokens")
            with gr.Row():
                with gr.Column(scale=1):
                    token_slider = gr.Slider(
                        minimum=0, maximum=100, step=1, value=0,
                        label="Token Index"
                    )
                with gr.Column(scale=3):
                    token_html = gr.HTML(label="Token in Context")
                    prob_chart = gr.Plot(label="Probability Change")

        # Thought Anchor Graph Tab
        with gr.TabItem("Reasoning Graph"):
            gr.Markdown("### Thought Anchor Dependency Graph")
            gr.Markdown("""
            *Visualizes causal dependencies between reasoning steps.
            Green nodes indicate positive impact, red nodes indicate negative impact.
            Node size reflects importance score.*
            """)
            with gr.Row():
                query_filter = gr.Dropdown(
                    choices=[],
                    value=None,
                    label="Filter by Query"
                )
            graph_plot = gr.Plot()

        # Embedding Visualization Tab
        with gr.TabItem("Embedding Space"):
            gr.Markdown("### Embedding Space Visualization")
            gr.Markdown("*t-SNE projection of sentence/token embeddings. Explore clusters and patterns.*")
            with gr.Row():
                color_dropdown = gr.Dropdown(
                    choices=["is_positive", "sentence_category", "reasoning_pattern", "task_type"],
                    value="is_positive",
                    label="Color By"
                )
            embed_plot = gr.Plot()

        # Circuit Tracer Tab
        with gr.TabItem("Circuit Tracer"):
            gr.Markdown("### Step-by-Step Reasoning Circuit")
            gr.Markdown("*Walk through the reasoning process step by step. See how each step affects the probability of success.*")
            with gr.Row():
                circuit_query_idx = gr.Slider(
                    minimum=0, maximum=100, step=1, value=0,
                    label="Query Index"
                )
            circuit_html = gr.HTML()
            circuit_chart = gr.Plot()

    # Event handlers - using api_name=False to prevent schema generation issues
    load_btn.click(
        fn=load_dataset_action,
        inputs=[source_type, dataset_dropdown, file_upload],
        outputs=[load_status, dataset_info, stats_html, stats_chart, graph_plot, embed_plot, circuit_html, circuit_chart, token_slider, query_filter],
        api_name=False
    )

    refresh_btn.click(
        fn=refresh_all,
        outputs=[stats_html, stats_chart, graph_plot, embed_plot, circuit_html, circuit_chart],
        api_name=False
    )

    token_slider.change(
        fn=get_token_details,
        inputs=[token_slider],
        outputs=[token_html, prob_chart],
        api_name=False
    )

    query_filter.change(
        fn=update_graph_visualization,
        inputs=[query_filter],
        outputs=[graph_plot],
        api_name=False
    )

    color_dropdown.change(
        fn=update_embedding_visualization,
        inputs=[color_dropdown],
        outputs=[embed_plot],
        api_name=False
    )

    circuit_query_idx.change(
        fn=update_circuit_view,
        inputs=[circuit_query_idx],
        outputs=[circuit_html, circuit_chart],
        api_name=False
    )
 
1365
 
1366
# ============================================================================
# Main Entry Point
# ============================================================================


# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()