cdpearlman commited on
Commit
1221d69
·
1 Parent(s): 6fac99b

feat: Add top-5 predictions with deltas, certainty meter, and bar charts in accordion panels

Browse files
app.py CHANGED
@@ -506,12 +506,14 @@ def update_check_token_graph(check_token_data):
506
  [State('model-dropdown', 'value')]
507
  )
508
  def create_layer_accordions(activation_data, model_name):
509
- """Create accordion panels for each layer."""
510
  if not activation_data or not model_name:
511
  return html.P("Run analysis to see layer-by-layer predictions.", className="placeholder-text")
512
 
513
  try:
514
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
515
  model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
516
  tokenizer = AutoTokenizer.from_pretrained(model_name)
517
 
@@ -521,38 +523,98 @@ def create_layer_accordions(activation_data, model_name):
521
  if not layer_data:
522
  return html.P("No layer data available.", className="placeholder-text")
523
 
524
- # Create accordion panels
525
  accordions = []
526
- for i, layer in enumerate(layer_data):
527
  layer_num = layer['layer_num']
528
  top_token = layer.get('top_token', 'N/A')
529
  top_prob = layer.get('top_prob', 0.0)
530
- top_3 = layer.get('top_3_tokens', [])
 
 
531
 
532
- # Create summary header
533
  if top_token:
534
- summary_text = f"Layer L{layer_num}: '{top_token}' (p={top_prob:.3f})"
535
  else:
536
  summary_text = f"Layer L{layer_num}: (no prediction)"
537
 
538
- # Create accordion panel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  panel = html.Details([
540
  html.Summary(summary_text, className="layer-summary"),
541
- html.Div([
542
- html.P(f"Layer {layer_num} details (placeholder for future content)")
543
- ], className="layer-content")
544
  ], className="layer-accordion")
545
 
546
  accordions.append(panel)
547
-
548
- # Add token chips between adjacent layers (not after last layer)
549
- if i < len(layer_data) - 1 and top_3:
550
- chips = html.Div([
551
- html.Span("→", className="token-arrow"),
552
- *[html.Span(f"{tok} ({prob:.2f})", className="token-chip")
553
- for tok, prob in top_3]
554
- ], className="token-chips-row")
555
- accordions.append(chips)
556
 
557
  return html.Div(accordions)
558
 
 
506
  [State('model-dropdown', 'value')]
507
  )
508
  def create_layer_accordions(activation_data, model_name):
509
+ """Create accordion panels for each layer with top-5 bar charts, deltas, and certainty."""
510
  if not activation_data or not model_name:
511
  return html.P("Run analysis to see layer-by-layer predictions.", className="placeholder-text")
512
 
513
  try:
514
  from transformers import AutoModelForCausalLM, AutoTokenizer
515
+ import plotly.graph_objs as go
516
+
517
  model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='eager')
518
  tokenizer = AutoTokenizer.from_pretrained(model_name)
519
 
 
523
  if not layer_data:
524
  return html.P("No layer data available.", className="placeholder-text")
525
 
526
+ # Create accordion panels (reversed to show final layer first)
527
  accordions = []
528
+ for i, layer in enumerate(reversed(layer_data)):
529
  layer_num = layer['layer_num']
530
  top_token = layer.get('top_token', 'N/A')
531
  top_prob = layer.get('top_prob', 0.0)
532
+ top_5 = layer.get('top_5_tokens', [])
533
+ deltas = layer.get('deltas', {})
534
+ certainty = layer.get('certainty', 0.0)
535
 
536
+ # Create summary header with certainty
537
  if top_token:
538
+ summary_text = f"Layer L{layer_num}: '{top_token}' (p={top_prob:.3f}, certainty={certainty:.2f})"
539
  else:
540
  summary_text = f"Layer L{layer_num}: (no prediction)"
541
 
542
+ # Create accordion panel content
543
+ content_items = []
544
+
545
+ if top_5:
546
+ # Create horizontal bar chart for top-5 tokens
547
+ tokens = [tok for tok, _ in top_5]
548
+ probs = [prob for _, prob in top_5]
549
+
550
+ # Create delta annotations (▲/▼ with color)
551
+ annotations = []
552
+ for idx, (token, prob) in enumerate(top_5):
553
+ delta = deltas.get(token, 0.0)
554
+ if abs(delta) > 0.001: # Only show meaningful deltas
555
+ symbol = '▲' if delta > 0 else '▼'
556
+ color = '#28a745' if delta > 0 else '#dc3545'
557
+ annotations.append({
558
+ 'x': prob,
559
+ 'y': idx,
560
+ 'text': f'{symbol} {abs(delta):.3f}',
561
+ 'showarrow': False,
562
+ 'xanchor': 'left',
563
+ 'xshift': 10,
564
+ 'font': {'size': 10, 'color': color}
565
+ })
566
+
567
+ # Create Plotly figure
568
+ fig = go.Figure(data=[
569
+ go.Bar(
570
+ x=probs,
571
+ y=tokens,
572
+ orientation='h',
573
+ marker={'color': '#667eea'},
574
+ text=[f'{p:.3f}' for p in probs],
575
+ textposition='auto',
576
+ hovertemplate='%{y}: %{x:.4f}<extra></extra>'
577
+ )
578
+ ])
579
+
580
+ fig.update_layout(
581
+ title={
582
+ 'text': f'Top 5 Predictions (Certainty: {certainty:.2f})',
583
+ 'font': {'size': 14}
584
+ },
585
+ xaxis={'title': 'Probability', 'range': [0, max(probs) * 1.15]},
586
+ yaxis={'title': '', 'autorange': 'reversed'},
587
+ height=250,
588
+ margin={'l': 100, 'r': 80, 't': 50, 'b': 40},
589
+ annotations=annotations,
590
+ hovermode='closest'
591
+ )
592
+
593
+ content_items.append(
594
+ dcc.Graph(
595
+ figure=fig,
596
+ config={'displayModeBar': False},
597
+ style={'marginBottom': '10px'}
598
+ )
599
+ )
600
+
601
+ # Add certainty tooltip explanation
602
+ content_items.append(html.Div([
603
+ html.Small([
604
+ html.I(className="fas fa-info-circle", style={'marginRight': '5px', 'color': '#667eea'}),
605
+ f"Certainty = 1 − H(p_top5)/log(5), where H is Shannon entropy. ",
606
+ "Higher values indicate more confident predictions."
607
+ ], style={'color': '#6c757d', 'fontStyle': 'italic'})
608
+ ], style={'marginTop': '5px'}))
609
+ else:
610
+ content_items.append(html.P("No predictions available"))
611
+
612
  panel = html.Details([
613
  html.Summary(summary_text, className="layer-summary"),
614
+ html.Div(content_items, className="layer-content")
 
 
615
  ], className="layer-accordion")
616
 
617
  accordions.append(panel)
 
 
 
 
 
 
 
 
 
618
 
619
  return html.Div(accordions)
620
 
components/main_panel.py CHANGED
@@ -45,10 +45,19 @@ def create_main_panel():
45
  ], id='check-token-graph-container', style={'flex': '1', 'minWidth': '300px', 'display': 'none'})
46
  ], className="input-container", style={"marginBottom": "1.5rem", "display": "flex", "gap": "1.5rem", "alignItems": "flex-start"}),
47
 
48
- # Layer-based visualization section
49
  html.Div([
50
  html.H3("Layer-by-Layer Predictions", className="section-title"),
51
- html.Div(id='layer-accordions-container', className="layer-accordions")
 
 
 
 
 
 
 
 
 
52
  ], className="visualization-section"),
53
 
54
  # Two-Prompt Comparison section (shown when comparing)
 
45
  ], id='check-token-graph-container', style={'flex': '1', 'minWidth': '300px', 'display': 'none'})
46
  ], className="input-container", style={"marginBottom": "1.5rem", "display": "flex", "gap": "1.5rem", "alignItems": "flex-start"}),
47
 
48
+ # Layer-based visualization section with loading spinner
49
  html.Div([
50
  html.H3("Layer-by-Layer Predictions", className="section-title"),
51
+ dcc.Loading(
52
+ id="layer-accordions-loading",
53
+ type="default",
54
+ children=html.Div(id='layer-accordions-container', className="layer-accordions"),
55
+ overlay_style={"visibility":"visible", "opacity": .7, "backgroundColor": "white"},
56
+ custom_spinner=html.Div([
57
+ html.I(className="fas fa-spinner fa-spin", style={'fontSize': '24px', 'color': '#667eea', 'marginRight': '10px'}),
58
+ html.Span("Loading visuals...", style={'fontSize': '16px', 'color': '#495057'})
59
+ ], style={'display': 'flex', 'alignItems': 'center', 'justifyContent': 'center', 'padding': '2rem'})
60
+ )
61
  ], className="visualization-section"),
62
 
63
  # Two-Prompt Comparison section (shown when comparing)
todo.md CHANGED
@@ -15,12 +15,13 @@ Note: Minimal-change approach. Reuse existing files (`app.py`, `components/main_
15
  - [ ] Add CSS utility classes for compact header + tokens chips row
16
 
17
  ## Feature: Per-layer predictions (top-5), deltas, certainty meter
18
- - [ ] Extend forward pass outputs to include per-layer top-5 tokens + probs (reusing logit lens) in `utils/model_patterns.py`
19
- - [ ] Compute delta vs previous layer for overlapping tokens (prob change, signed)
20
- - [ ] Compute certainty meter using normalized entropy over top-5 probs (0–1)
21
- - [ ] Render a `dcc.Graph` horizontal bar chart (top-5) inside each panel body
22
- - [ ] Show per-token delta as small ▲/▼ with color next to bars
23
- - [ ] Add tooltip explaining certainty: "certainty = 1 − H(p_top5)/log(5)"
 
24
 
25
  ## Feature: Simplified attention view + open full interactive view
26
  - [ ] From `activation_data['attention_outputs']`, compute top-3 attended input tokens for current position (per layer)
 
15
  - [ ] Add CSS utility classes for compact header + tokens chips row
16
 
17
  ## Feature: Per-layer predictions (top-5), deltas, certainty meter
18
+ - [x] Extend forward pass outputs to include per-layer top-5 tokens + probs (reusing logit lens) in `utils/model_patterns.py`
19
+ - [x] Compute delta vs previous layer for overlapping tokens (prob change, signed)
20
+ - [x] Compute certainty meter using normalized entropy over top-5 probs (0–1)
21
+ - [x] Render a `dcc.Graph` horizontal bar chart (top-5) inside each panel body
22
+ - [x] Show per-token delta as small ▲/▼ with color next to bars
23
+ - [x] Add tooltip explaining certainty: "certainty = 1 − H(p_top5)/log(5)"
24
+ - [x] Show a spinning "Loading visuals..." indicator once the data has loaded, until all visualizations have finished rendering
25
 
26
  ## Feature: Simplified attention view + open full interactive view
27
  - [ ] From `activation_data['attention_outputs']`, compute top-3 attended input tokens for current position (per layer)
utils/__pycache__/model_patterns.cpython-311.pyc CHANGED
Binary files a/utils/__pycache__/model_patterns.cpython-311.pyc and b/utils/__pycache__/model_patterns.cpython-311.pyc differ
 
utils/model_patterns.py CHANGED
@@ -203,9 +203,9 @@ def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any])
203
  return result
204
 
205
 
206
- def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, logit_lens_parameter: str, tokenizer, norm_parameter: Optional[str] = None) -> List[Tuple[str, float]]:
207
  """
208
- Transform layer output to top 3 token probabilities using logit lens.
209
 
210
  For standard logit lens, use block/layer outputs (residual stream), not component outputs.
211
  The residual stream contains the full hidden state with all accumulated information.
@@ -220,9 +220,10 @@ def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, lo
220
  logit_lens_parameter: Not used (deprecated)
221
  tokenizer: Tokenizer for decoding
222
  norm_parameter: Parameter path for final norm layer (e.g., "model.norm.weight")
 
223
 
224
  Returns:
225
- List of (token_string, probability) tuples for top 3 tokens
226
  """
227
  with torch.no_grad():
228
  # Convert to tensor and ensure proper shape [batch, seq_len, hidden_dim]
@@ -242,8 +243,8 @@ def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, lo
242
  # Step 3: Get probabilities via softmax
243
  probs = F.softmax(logits[0, -1, :], dim=-1)
244
 
245
- # Step 4: Extract top 3 tokens
246
- top_probs, top_indices = torch.topk(probs, k=3)
247
 
248
  return [
249
  (tokenizer.decode([idx.item()], skip_special_tokens=False), prob.item())
@@ -288,9 +289,9 @@ def get_norm_layer_from_parameter(model, norm_parameter: Optional[str]) -> Optio
288
  return None
289
 
290
 
291
- def _get_top_tokens(activation_data: Dict[str, Any], module_name: str, model, tokenizer) -> Optional[List[Tuple[str, float]]]:
292
  """
293
- Helper: Get top 3 tokens for a layer's block output.
294
 
295
  Uses block outputs (residual stream) which represent the full hidden state
296
  after all layer computations (attention + feedforward + residuals).
@@ -306,7 +307,7 @@ def _get_top_tokens(activation_data: Dict[str, Any], module_name: str, model, to
306
  norm_params = activation_data.get('norm_parameters', [])
307
  norm_parameter = norm_params[0] if norm_params else None
308
 
309
- return logit_lens_transformation(layer_output, [], model, None, tokenizer, norm_parameter)
310
  except Exception as e:
311
  print(f"Warning: Could not compute logit lens for {module_name}: {e}")
312
  return None
@@ -388,12 +389,43 @@ def get_check_token_probabilities(activation_data: Dict[str, Any], model, tokeni
388
  return None
389
 
390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> List[Dict[str, Any]]:
392
  """
393
- Extract layer-by-layer data for accordion display.
394
 
395
  Returns:
396
- List of dicts with: layer_num, top_token, top_prob, top_3_tokens (list of (token, prob))
397
  """
398
  layer_modules = activation_data.get('block_modules', [])
399
  if not layer_modules:
@@ -407,26 +439,47 @@ def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> Lis
407
 
408
  logit_lens_enabled = activation_data.get('logit_lens_parameter') is not None
409
  layer_data = []
 
410
 
411
  for layer_num, module_name in layer_info:
412
- top_tokens = _get_top_tokens(activation_data, module_name, model, tokenizer) if logit_lens_enabled else None
413
 
414
  if top_tokens:
415
  top_token, top_prob = top_tokens[0]
 
 
 
 
 
 
 
 
 
 
 
416
  layer_data.append({
417
  'layer_num': layer_num,
418
  'module_name': module_name,
419
  'top_token': top_token,
420
  'top_prob': top_prob,
421
- 'top_3_tokens': top_tokens[:3] # Get top 3 for chips
 
 
 
422
  })
 
 
 
423
  else:
424
  layer_data.append({
425
  'layer_num': layer_num,
426
  'module_name': module_name,
427
  'top_token': None,
428
  'top_prob': None,
429
- 'top_3_tokens': []
 
 
 
430
  })
431
 
432
  return layer_data
 
203
  return result
204
 
205
 
206
+ def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, logit_lens_parameter: str, tokenizer, norm_parameter: Optional[str] = None, top_k: int = 5) -> List[Tuple[str, float]]:
207
  """
208
+ Transform layer output to top K token probabilities using logit lens.
209
 
210
  For standard logit lens, use block/layer outputs (residual stream), not component outputs.
211
  The residual stream contains the full hidden state with all accumulated information.
 
220
  logit_lens_parameter: Not used (deprecated)
221
  tokenizer: Tokenizer for decoding
222
  norm_parameter: Parameter path for final norm layer (e.g., "model.norm.weight")
223
+ top_k: Number of top tokens to return (default: 5)
224
 
225
  Returns:
226
+ List of (token_string, probability) tuples for top K tokens
227
  """
228
  with torch.no_grad():
229
  # Convert to tensor and ensure proper shape [batch, seq_len, hidden_dim]
 
243
  # Step 3: Get probabilities via softmax
244
  probs = F.softmax(logits[0, -1, :], dim=-1)
245
 
246
+ # Step 4: Extract top K tokens
247
+ top_probs, top_indices = torch.topk(probs, k=top_k)
248
 
249
  return [
250
  (tokenizer.decode([idx.item()], skip_special_tokens=False), prob.item())
 
289
  return None
290
 
291
 
292
+ def _get_top_tokens(activation_data: Dict[str, Any], module_name: str, model, tokenizer, top_k: int = 5) -> Optional[List[Tuple[str, float]]]:
293
  """
294
+ Helper: Get top K tokens for a layer's block output.
295
 
296
  Uses block outputs (residual stream) which represent the full hidden state
297
  after all layer computations (attention + feedforward + residuals).
 
307
  norm_params = activation_data.get('norm_parameters', [])
308
  norm_parameter = norm_params[0] if norm_params else None
309
 
310
+ return logit_lens_transformation(layer_output, [], model, None, tokenizer, norm_parameter, top_k=top_k)
311
  except Exception as e:
312
  print(f"Warning: Could not compute logit lens for {module_name}: {e}")
313
  return None
 
389
  return None
390
 
391
 
392
+ def _compute_certainty(probs: List[float]) -> float:
393
+ """
394
+ Compute normalized certainty from probability distribution.
395
+ Formula: certainty = 1 - H(p)/log(K) where H is Shannon entropy.
396
+
397
+ Args:
398
+ probs: List of probabilities (top-K)
399
+
400
+ Returns:
401
+ Certainty score in [0, 1] where 1 = completely certain
402
+ """
403
+ import math
404
+ if not probs or len(probs) == 0:
405
+ return 0.0
406
+
407
+ # Compute Shannon entropy: H = -Σ(p_i * log(p_i))
408
+ entropy = 0.0
409
+ for p in probs:
410
+ if p > 0:
411
+ entropy -= p * math.log(p)
412
+
413
+ # Normalize by max entropy (log(K))
414
+ max_entropy = math.log(len(probs))
415
+ if max_entropy == 0:
416
+ return 1.0
417
+
418
+ # Certainty = 1 - normalized_entropy
419
+ certainty = 1.0 - (entropy / max_entropy)
420
+ return max(0.0, min(1.0, certainty)) # Clamp to [0, 1]
421
+
422
+
423
  def extract_layer_data(activation_data: Dict[str, Any], model, tokenizer) -> List[Dict[str, Any]]:
424
  """
425
+ Extract layer-by-layer data for accordion display with top-5, deltas, and certainty.
426
 
427
  Returns:
428
+ List of dicts with: layer_num, top_token, top_prob, top_5_tokens, deltas, certainty
429
  """
430
  layer_modules = activation_data.get('block_modules', [])
431
  if not layer_modules:
 
439
 
440
  logit_lens_enabled = activation_data.get('logit_lens_parameter') is not None
441
  layer_data = []
442
+ prev_token_probs = {} # Track previous layer's token probabilities
443
 
444
  for layer_num, module_name in layer_info:
445
+ top_tokens = _get_top_tokens(activation_data, module_name, model, tokenizer, top_k=5) if logit_lens_enabled else None
446
 
447
  if top_tokens:
448
  top_token, top_prob = top_tokens[0]
449
+
450
+ # Compute deltas vs previous layer
451
+ deltas = {}
452
+ for token, prob in top_tokens:
453
+ prev_prob = prev_token_probs.get(token, 0.0)
454
+ deltas[token] = prob - prev_prob
455
+
456
+ # Compute certainty from top-5 probabilities
457
+ probs = [prob for _, prob in top_tokens]
458
+ certainty = _compute_certainty(probs)
459
+
460
  layer_data.append({
461
  'layer_num': layer_num,
462
  'module_name': module_name,
463
  'top_token': top_token,
464
  'top_prob': top_prob,
465
+ 'top_3_tokens': top_tokens[:3], # Keep for backward compatibility
466
+ 'top_5_tokens': top_tokens[:5], # New: top-5 for bar chart
467
+ 'deltas': deltas,
468
+ 'certainty': certainty
469
  })
470
+
471
+ # Update previous layer probabilities
472
+ prev_token_probs = {token: prob for token, prob in top_tokens}
473
  else:
474
  layer_data.append({
475
  'layer_num': layer_num,
476
  'module_name': module_name,
477
  'top_token': None,
478
  'top_prob': None,
479
+ 'top_3_tokens': [],
480
+ 'top_5_tokens': [],
481
+ 'deltas': {},
482
+ 'certainty': 0.0
483
  })
484
 
485
  return layer_data