cdpearlman committed on
Commit
f33c95a
·
1 Parent(s): 1dd1822

Use final layer output instead of MLP output for simpler computation

Browse files
Files changed (4) hide show
  1. app.py +20 -18
  2. components/sidebar.py +5 -5
  3. utils/model_config.py +27 -25
  4. utils/model_patterns.py +77 -35
app.py CHANGED
@@ -67,11 +67,11 @@ app.layout = html.Div([
67
  @app.callback(
68
  [Output('session-patterns-store', 'data'),
69
  Output('attention-modules-dropdown', 'options'),
70
- Output('mlp-modules-dropdown', 'options'),
71
  Output('norm-params-dropdown', 'options'),
72
  Output('logit-lens-dropdown', 'options'),
73
  Output('attention-modules-dropdown', 'value', allow_duplicate=True),
74
- Output('mlp-modules-dropdown', 'value', allow_duplicate=True),
75
  Output('norm-params-dropdown', 'value', allow_duplicate=True),
76
  Output('logit-lens-dropdown', 'value', allow_duplicate=True),
77
  Output('loading-indicator', 'children')],
@@ -122,8 +122,9 @@ def load_model_patterns(selected_model):
122
  attention_options = create_grouped_options(
123
  module_patterns, ['attn', 'attention'], 'modules'
124
  )
125
- mlp_options = create_grouped_options(
126
- module_patterns, ['mlp'], 'modules'
 
127
  )
128
  norm_options = create_grouped_options(
129
  param_patterns, ['norm', 'layernorm', 'layer_norm'], 'params'
@@ -170,11 +171,11 @@ def load_model_patterns(selected_model):
170
  return (
171
  patterns_data,
172
  attention_options,
173
- mlp_options,
174
  norm_options,
175
  logit_lens_options,
176
  auto_selections.get('attention_selection', []),
177
- auto_selections.get('mlp_selection', []),
178
  auto_selections.get('norm_selection', []),
179
  auto_selections.get('logit_lens_selection'),
180
  loading_content
@@ -207,7 +208,7 @@ def show_loading_spinner(selected_model):
207
  # Callback to clear all selections when Clear button is pressed
208
  @app.callback(
209
  [Output('attention-modules-dropdown', 'value'),
210
- Output('mlp-modules-dropdown', 'value'),
211
  Output('norm-params-dropdown', 'value'),
212
  Output('logit-lens-dropdown', 'value'),
213
  Output('session-activation-store', 'data'),
@@ -228,7 +229,7 @@ def clear_all_selections(n_clicks):
228
 
229
  return (
230
  None, # attention-modules-dropdown value
231
- None, # mlp-modules-dropdown value
232
  None, # norm-params-dropdown value
233
  None, # logit-lens-dropdown value
234
  {}, # session-activation-store data
@@ -260,20 +261,20 @@ def show_analysis_loading_spinner(n_clicks):
260
  [State('model-dropdown', 'value'),
261
  State('prompt-input', 'value'),
262
  State('attention-modules-dropdown', 'value'),
263
- State('mlp-modules-dropdown', 'value'),
264
  State('norm-params-dropdown', 'value'),
265
  State('logit-lens-dropdown', 'value'),
266
  State('session-patterns-store', 'data')],
267
  prevent_initial_call=True
268
  )
269
- def run_analysis(n_clicks, model_name, prompt, attn_patterns, mlp_patterns, norm_patterns, logit_pattern, patterns_data):
270
  """Run forward pass and generate cytoscape visualization."""
271
  print(f"\n=== DEBUG: run_analysis START ===")
272
  print(f"DEBUG: n_clicks={n_clicks}, model_name={model_name}, prompt='{prompt}'")
273
- print(f"DEBUG: mlp_patterns={mlp_patterns}")
274
  print(f"DEBUG: logit_pattern={logit_pattern}")
275
 
276
- if not n_clicks or not model_name or not prompt or not mlp_patterns:
277
  print("DEBUG: Missing required inputs, returning empty")
278
  return [], {}, None
279
 
@@ -289,10 +290,11 @@ def run_analysis(n_clicks, model_name, prompt, attn_patterns, mlp_patterns, norm
289
  param_patterns = patterns_data.get('param_patterns', {})
290
  all_patterns = {**module_patterns, **param_patterns}
291
 
 
292
  config = {
293
  'attention_modules': [mod for pattern in (attn_patterns or []) for mod in module_patterns.get(pattern, [])],
294
- 'mlp_modules': [mod for pattern in mlp_patterns for mod in module_patterns.get(pattern, [])],
295
- 'norm_parameters': [param for pattern in (norm_patterns or []) for param in param_patterns.get(pattern, [])],
296
  'logit_lens_parameter': all_patterns.get(logit_pattern, [None])[0] if logit_pattern else None
297
  }
298
 
@@ -341,11 +343,11 @@ def run_analysis(n_clicks, model_name, prompt, attn_patterns, mlp_patterns, norm
341
  Output('run-analysis-btn', 'disabled'),
342
  [Input('model-dropdown', 'value'),
343
  Input('prompt-input', 'value'),
344
- Input('mlp-modules-dropdown', 'value')]
345
  )
346
- def enable_run_button(model, prompt, mlp_modules):
347
- """Enable Run Analysis button when model, prompt, and MLP modules are selected."""
348
- return not (model and prompt and mlp_modules)
349
 
350
  # Node click callback for analysis results
351
  @app.callback(
 
67
  @app.callback(
68
  [Output('session-patterns-store', 'data'),
69
  Output('attention-modules-dropdown', 'options'),
70
+ Output('block-modules-dropdown', 'options'),
71
  Output('norm-params-dropdown', 'options'),
72
  Output('logit-lens-dropdown', 'options'),
73
  Output('attention-modules-dropdown', 'value', allow_duplicate=True),
74
+ Output('block-modules-dropdown', 'value', allow_duplicate=True),
75
  Output('norm-params-dropdown', 'value', allow_duplicate=True),
76
  Output('logit-lens-dropdown', 'value', allow_duplicate=True),
77
  Output('loading-indicator', 'children')],
 
122
  attention_options = create_grouped_options(
123
  module_patterns, ['attn', 'attention'], 'modules'
124
  )
125
+ # Block options - layer/block modules (residual stream outputs)
126
+ block_options = create_grouped_options(
127
+ module_patterns, ['layers', 'h.', 'blocks', 'decoder.layers'], 'modules'
128
  )
129
  norm_options = create_grouped_options(
130
  param_patterns, ['norm', 'layernorm', 'layer_norm'], 'params'
 
171
  return (
172
  patterns_data,
173
  attention_options,
174
+ block_options,
175
  norm_options,
176
  logit_lens_options,
177
  auto_selections.get('attention_selection', []),
178
+ auto_selections.get('block_selection', []),
179
  auto_selections.get('norm_selection', []),
180
  auto_selections.get('logit_lens_selection'),
181
  loading_content
 
208
  # Callback to clear all selections when Clear button is pressed
209
  @app.callback(
210
  [Output('attention-modules-dropdown', 'value'),
211
+ Output('block-modules-dropdown', 'value'),
212
  Output('norm-params-dropdown', 'value'),
213
  Output('logit-lens-dropdown', 'value'),
214
  Output('session-activation-store', 'data'),
 
229
 
230
  return (
231
  None, # attention-modules-dropdown value
232
+ None, # block-modules-dropdown value
233
  None, # norm-params-dropdown value
234
  None, # logit-lens-dropdown value
235
  {}, # session-activation-store data
 
261
  [State('model-dropdown', 'value'),
262
  State('prompt-input', 'value'),
263
  State('attention-modules-dropdown', 'value'),
264
+ State('block-modules-dropdown', 'value'),
265
  State('norm-params-dropdown', 'value'),
266
  State('logit-lens-dropdown', 'value'),
267
  State('session-patterns-store', 'data')],
268
  prevent_initial_call=True
269
  )
270
+ def run_analysis(n_clicks, model_name, prompt, attn_patterns, block_patterns, norm_patterns, logit_pattern, patterns_data):
271
  """Run forward pass and generate cytoscape visualization."""
272
  print(f"\n=== DEBUG: run_analysis START ===")
273
  print(f"DEBUG: n_clicks={n_clicks}, model_name={model_name}, prompt='{prompt}'")
274
+ print(f"DEBUG: block_patterns={block_patterns}")
275
  print(f"DEBUG: logit_pattern={logit_pattern}")
276
 
277
+ if not n_clicks or not model_name or not prompt or not block_patterns:
278
  print("DEBUG: Missing required inputs, returning empty")
279
  return [], {}, None
280
 
 
290
  param_patterns = patterns_data.get('param_patterns', {})
291
  all_patterns = {**module_patterns, **param_patterns}
292
 
293
+ # Use block patterns (full layer outputs / residual stream) for logit lens
294
  config = {
295
  'attention_modules': [mod for pattern in (attn_patterns or []) for mod in module_patterns.get(pattern, [])],
296
+ 'block_modules': [mod for pattern in block_patterns for mod in module_patterns.get(pattern, [])],
297
+ 'norm_parameters': param_patterns.get(norm_patterns, []) if norm_patterns else [],
298
  'logit_lens_parameter': all_patterns.get(logit_pattern, [None])[0] if logit_pattern else None
299
  }
300
 
 
343
  Output('run-analysis-btn', 'disabled'),
344
  [Input('model-dropdown', 'value'),
345
  Input('prompt-input', 'value'),
346
+ Input('block-modules-dropdown', 'value')]
347
  )
348
+ def enable_run_button(model, prompt, block_modules):
349
+ """Enable Run Analysis button when model, prompt, and layer blocks are selected."""
350
+ return not (model and prompt and block_modules)
351
 
352
  # Node click callback for analysis results
353
  @app.callback(
components/sidebar.py CHANGED
@@ -3,7 +3,7 @@ Sidebar component with module and parameter selection dropdowns.
3
 
4
  This component provides the left sidebar interface for selecting:
5
  - Attention modules
6
- - MLP modules
7
  - Normalization parameters
8
  - Logit lens parameters
9
  """
@@ -31,14 +31,14 @@ def create_sidebar():
31
  )
32
  ], className="dropdown-container"),
33
 
34
- # MLP modules dropdown
35
  html.Div([
36
- html.Label("MLP Modules:", className="dropdown-label"),
37
  dcc.Dropdown(
38
- id='mlp-modules-dropdown',
39
  options=[],
40
  value=None,
41
- placeholder="Select MLP modules...",
42
  multi=True,
43
  className="module-dropdown"
44
  )
 
3
 
4
  This component provides the left sidebar interface for selecting:
5
  - Attention modules
6
+ - Layer blocks (residual stream outputs)
7
  - Normalization parameters
8
  - Logit lens parameters
9
  """
 
31
  )
32
  ], className="dropdown-container"),
33
 
34
+ # Layer blocks dropdown (residual stream outputs)
35
  html.Div([
36
+ html.Label("Layer Blocks:", className="dropdown-label"),
37
  dcc.Dropdown(
38
+ id='block-modules-dropdown',
39
  options=[],
40
  value=None,
41
+ placeholder="Select layer blocks...",
42
  multi=True,
43
  className="module-dropdown"
44
  )
utils/model_config.py CHANGED
@@ -17,7 +17,7 @@ MODEL_FAMILIES: Dict[str, Dict[str, Any]] = {
17
  "mlp_pattern": "model.layers.{N}.mlp",
18
  "block_pattern": "model.layers.{N}",
19
  },
20
- "norm_patterns": ["model.norm.weight"],
21
  "logit_lens_pattern": "lm_head.weight",
22
  "norm_type": "rmsnorm",
23
  },
@@ -30,7 +30,7 @@ MODEL_FAMILIES: Dict[str, Dict[str, Any]] = {
30
  "mlp_pattern": "transformer.h.{N}.mlp",
31
  "block_pattern": "transformer.h.{N}",
32
  },
33
- "norm_patterns": ["transformer.ln_f.weight", "transformer.ln_f.bias"],
34
  "logit_lens_pattern": "lm_head.weight",
35
  "norm_type": "layernorm",
36
  },
@@ -43,7 +43,7 @@ MODEL_FAMILIES: Dict[str, Dict[str, Any]] = {
43
  "mlp_pattern": "model.decoder.layers.{N}.fc2",
44
  "block_pattern": "model.decoder.layers.{N}",
45
  },
46
- "norm_patterns": ["model.decoder.final_layer_norm.weight", "model.decoder.final_layer_norm.bias"],
47
  "logit_lens_pattern": "lm_head.weight",
48
  "norm_type": "layernorm",
49
  },
@@ -56,7 +56,7 @@ MODEL_FAMILIES: Dict[str, Dict[str, Any]] = {
56
  "mlp_pattern": "gpt_neox.layers.{N}.mlp",
57
  "block_pattern": "gpt_neox.layers.{N}",
58
  },
59
- "norm_patterns": ["gpt_neox.final_layer_norm.weight", "gpt_neox.final_layer_norm.bias"],
60
  "logit_lens_pattern": "embed_out.weight",
61
  "norm_type": "layernorm",
62
  },
@@ -69,7 +69,7 @@ MODEL_FAMILIES: Dict[str, Dict[str, Any]] = {
69
  "mlp_pattern": "transformer.h.{N}.mlp",
70
  "block_pattern": "transformer.h.{N}",
71
  },
72
- "norm_patterns": ["transformer.ln_f.weight", "transformer.ln_f.bias"],
73
  "logit_lens_pattern": "lm_head.weight",
74
  "norm_type": "layernorm",
75
  },
@@ -82,7 +82,7 @@ MODEL_FAMILIES: Dict[str, Dict[str, Any]] = {
82
  "mlp_pattern": "transformer.h.{N}.mlp",
83
  "block_pattern": "transformer.h.{N}",
84
  },
85
- "norm_patterns": ["transformer.ln_f.weight", "transformer.ln_f.bias"],
86
  "logit_lens_pattern": "lm_head.weight",
87
  "norm_type": "layernorm",
88
  },
@@ -95,8 +95,8 @@ MODEL_FAMILIES: Dict[str, Dict[str, Any]] = {
95
  "mlp_pattern": "transformer.blocks.{N}.ffn",
96
  "block_pattern": "transformer.blocks.{N}",
97
  },
98
- "norm_patterns": ["transformer.norm_f.weight"],
99
- "logit_lens_parameter": "lm_head.weight",
100
  "norm_type": "layernorm",
101
  },
102
  }
@@ -214,15 +214,15 @@ def get_auto_selections(model_name: str, module_patterns: Dict[str, List[str]],
214
  param_patterns: Available parameter patterns from the model
215
 
216
  Returns:
217
- Dict with keys: attention_selection, mlp_selection, norm_selection, logit_lens_selection
218
  Each value is a list of pattern keys that should be pre-selected
219
  """
220
  family = get_model_family(model_name)
221
  if not family:
222
  return {
223
  'attention_selection': [],
224
- 'mlp_selection': [],
225
- 'norm_selection': [],
226
  'logit_lens_selection': None,
227
  'family_name': None
228
  }
@@ -231,16 +231,16 @@ def get_auto_selections(model_name: str, module_patterns: Dict[str, List[str]],
231
  if not config:
232
  return {
233
  'attention_selection': [],
234
- 'mlp_selection': [],
235
- 'norm_selection': [],
236
  'logit_lens_selection': None,
237
  'family_name': None
238
  }
239
 
240
  # Find matching patterns in the available patterns
241
  attention_matches = []
242
- mlp_matches = []
243
- norm_matches = []
244
  logit_lens_match = None
245
 
246
  # Match attention patterns
@@ -249,17 +249,19 @@ def get_auto_selections(model_name: str, module_patterns: Dict[str, List[str]],
249
  if _pattern_matches_template(pattern_key, attention_template):
250
  attention_matches.append(pattern_key)
251
 
252
- # Match MLP patterns
253
- mlp_template = config['templates'].get('mlp_pattern', '')
254
  for pattern_key in module_patterns.keys():
255
- if _pattern_matches_template(pattern_key, mlp_template):
256
- mlp_matches.append(pattern_key)
257
 
258
- # Match normalization patterns
259
- for norm_pattern in config.get('norm_patterns', []):
 
260
  for pattern_key in param_patterns.keys():
261
- if _pattern_matches_template(pattern_key, norm_pattern):
262
- norm_matches.append(pattern_key)
 
263
 
264
  # Match logit lens pattern - check both parameters AND modules
265
  logit_pattern = config.get('logit_lens_pattern', '')
@@ -277,8 +279,8 @@ def get_auto_selections(model_name: str, module_patterns: Dict[str, List[str]],
277
 
278
  return {
279
  'attention_selection': attention_matches,
280
- 'mlp_selection': mlp_matches,
281
- 'norm_selection': norm_matches,
282
  'logit_lens_selection': logit_lens_match,
283
  'family_name': family,
284
  'family_description': config.get('description', '')
 
17
  "mlp_pattern": "model.layers.{N}.mlp",
18
  "block_pattern": "model.layers.{N}",
19
  },
20
+ "norm_parameter": "model.norm.weight",
21
  "logit_lens_pattern": "lm_head.weight",
22
  "norm_type": "rmsnorm",
23
  },
 
30
  "mlp_pattern": "transformer.h.{N}.mlp",
31
  "block_pattern": "transformer.h.{N}",
32
  },
33
+ "norm_parameter": "transformer.ln_f.weight",
34
  "logit_lens_pattern": "lm_head.weight",
35
  "norm_type": "layernorm",
36
  },
 
43
  "mlp_pattern": "model.decoder.layers.{N}.fc2",
44
  "block_pattern": "model.decoder.layers.{N}",
45
  },
46
+ "norm_parameter": "model.decoder.final_layer_norm.weight",
47
  "logit_lens_pattern": "lm_head.weight",
48
  "norm_type": "layernorm",
49
  },
 
56
  "mlp_pattern": "gpt_neox.layers.{N}.mlp",
57
  "block_pattern": "gpt_neox.layers.{N}",
58
  },
59
+ "norm_parameter": "gpt_neox.final_layer_norm.weight",
60
  "logit_lens_pattern": "embed_out.weight",
61
  "norm_type": "layernorm",
62
  },
 
69
  "mlp_pattern": "transformer.h.{N}.mlp",
70
  "block_pattern": "transformer.h.{N}",
71
  },
72
+ "norm_parameter": "transformer.ln_f.weight",
73
  "logit_lens_pattern": "lm_head.weight",
74
  "norm_type": "layernorm",
75
  },
 
82
  "mlp_pattern": "transformer.h.{N}.mlp",
83
  "block_pattern": "transformer.h.{N}",
84
  },
85
+ "norm_parameter": "transformer.ln_f.weight",
86
  "logit_lens_pattern": "lm_head.weight",
87
  "norm_type": "layernorm",
88
  },
 
95
  "mlp_pattern": "transformer.blocks.{N}.ffn",
96
  "block_pattern": "transformer.blocks.{N}",
97
  },
98
+ "norm_parameter": "transformer.norm_f.weight",
99
+ "logit_lens_pattern": "lm_head.weight",
100
  "norm_type": "layernorm",
101
  },
102
  }
 
214
  param_patterns: Available parameter patterns from the model
215
 
216
  Returns:
217
+ Dict with keys: attention_selection, block_selection, norm_selection, logit_lens_selection
218
  Each value is a list of pattern keys that should be pre-selected
219
  """
220
  family = get_model_family(model_name)
221
  if not family:
222
  return {
223
  'attention_selection': [],
224
+ 'block_selection': [],
225
+ 'norm_selection': None,
226
  'logit_lens_selection': None,
227
  'family_name': None
228
  }
 
231
  if not config:
232
  return {
233
  'attention_selection': [],
234
+ 'block_selection': [],
235
+ 'norm_selection': None,
236
  'logit_lens_selection': None,
237
  'family_name': None
238
  }
239
 
240
  # Find matching patterns in the available patterns
241
  attention_matches = []
242
+ block_matches = []
243
+ norm_match = None
244
  logit_lens_match = None
245
 
246
  # Match attention patterns
 
249
  if _pattern_matches_template(pattern_key, attention_template):
250
  attention_matches.append(pattern_key)
251
 
252
+ # Match block patterns (full layer outputs - residual stream)
253
+ block_template = config['templates'].get('block_pattern', '')
254
  for pattern_key in module_patterns.keys():
255
+ if _pattern_matches_template(pattern_key, block_template):
256
+ block_matches.append(pattern_key)
257
 
258
+ # Match normalization parameter
259
+ norm_parameter = config.get('norm_parameter', '')
260
+ if norm_parameter:
261
  for pattern_key in param_patterns.keys():
262
+ if _pattern_matches_template(pattern_key, norm_parameter):
263
+ norm_match = pattern_key
264
+ break
265
 
266
  # Match logit lens pattern - check both parameters AND modules
267
  logit_pattern = config.get('logit_lens_pattern', '')
 
279
 
280
  return {
281
  'attention_selection': attention_matches,
282
+ 'block_selection': block_matches,
283
+ 'norm_selection': norm_match,
284
  'logit_lens_selection': logit_lens_match,
285
  'family_name': family,
286
  'family_description': config.get('description', '')
utils/model_patterns.py CHANGED
@@ -92,7 +92,7 @@ def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any])
92
  model: Loaded transformer model
93
  tokenizer: Loaded tokenizer
94
  prompt: Input text prompt
95
- config: Dict with module lists like {"attention_modules": [...], "mlp_modules": [...], ...}
96
 
97
  Returns:
98
  JSON-serializable dict with captured activations and metadata
@@ -101,12 +101,11 @@ def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any])
101
 
102
  # Extract module lists from config
103
  attention_modules = config.get("attention_modules", [])
104
- mlp_modules = config.get("mlp_modules", [])
105
- other_modules = config.get("other_modules", [])
106
  norm_parameters = config.get("norm_parameters", [])
107
  logit_lens_parameter = config.get("logit_lens_parameter")
108
 
109
- all_modules = attention_modules + mlp_modules + other_modules
110
  if not all_modules:
111
  print("No modules specified for capture")
112
  return {"error": "No modules specified"}
@@ -119,12 +118,12 @@ def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any])
119
  if not layer_match:
120
  return {"error": f"Invalid module name format: {mod_name}"}
121
 
122
- # Determine component type
123
- if mod_name in mlp_modules:
124
- component = 'mlp_output'
125
- elif mod_name in attention_modules:
126
  component = 'attention_output'
127
  else:
 
 
128
  component = 'block_output'
129
 
130
  intervenable_representations.append(
@@ -162,10 +161,16 @@ def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any])
162
  for hook in hooks:
163
  hook.remove()
164
 
165
- # Separate outputs by type
166
- attention_outputs = {k: v for k, v in captured.items() if k in attention_modules}
167
- mlp_outputs = {k: v for k, v in captured.items() if k in mlp_modules}
168
- other_outputs = {k: v for k, v in captured.items() if k in other_modules}
 
 
 
 
 
 
169
 
170
  # Capture normalization parameters (deprecated - kept for backward compatibility)
171
  all_params = dict(model.named_parameters())
@@ -184,47 +189,49 @@ def execute_forward_pass(model, tokenizer, prompt: str, config: Dict[str, Any])
184
  "model": getattr(model.config, "name_or_path", "unknown"),
185
  "prompt": prompt,
186
  "input_ids": safe_to_serializable(inputs["input_ids"]),
187
- "attention_modules": attention_modules,
188
  "attention_outputs": attention_outputs,
189
- "mlp_modules": mlp_modules,
190
- "mlp_outputs": mlp_outputs,
191
- "other_modules": other_modules,
192
- "other_outputs": other_outputs,
193
  "norm_parameters": norm_parameters,
194
  "norm_data": norm_data,
195
  "logit_lens_parameter": logit_lens_parameter,
196
- "actual_output": actual_output # Store only token and probability, not full output
197
  }
198
 
199
  print(f"Captured {len(captured)} module outputs using PyVene")
200
  return result
201
 
202
 
203
- def logit_lens_transformation(mlp_output: Any, norm_data: List[Any], model, logit_lens_parameter: str, tokenizer) -> List[Tuple[str, float]]:
204
  """
205
  Transform layer output to top 3 token probabilities using logit lens.
206
 
 
 
 
207
  Applies final layer normalization before projection (critical for correctness).
208
  Uses model's built-in functions to minimize computational errors.
209
 
210
  Args:
211
- mlp_output: Hidden state from any layer
212
  norm_data: Not used (deprecated - using model's norm layer directly)
213
  model: HuggingFace model
214
  logit_lens_parameter: Not used (deprecated)
215
  tokenizer: Tokenizer for decoding
 
216
 
217
  Returns:
218
  List of (token_string, probability) tuples for top 3 tokens
219
  """
220
  with torch.no_grad():
221
  # Convert to tensor and ensure proper shape [batch, seq_len, hidden_dim]
222
- hidden = torch.tensor(mlp_output) if not isinstance(mlp_output, torch.Tensor) else mlp_output
223
  if hidden.dim() == 4:
224
  hidden = hidden.squeeze(0)
225
 
226
  # Step 1: Apply final layer normalization (critical for intermediate layers)
227
- final_norm = get_final_norm_layer(model)
228
  if final_norm is not None:
229
  hidden = final_norm(hidden)
230
 
@@ -244,15 +251,31 @@ def logit_lens_transformation(mlp_output: Any, norm_data: List[Any], model, logi
244
  ]
245
 
246
 
247
- def get_final_norm_layer(model):
248
  """
249
- Get the final layer normalization module from the model.
250
- Returns None if not found.
251
 
252
- Supports GPT-2 (transformer.ln_f), LLaMA (model.norm), and similar architectures.
 
 
 
 
 
253
  """
254
- # Try common final norm layer names
255
- for attr_path in ['transformer.ln_f', 'model.norm', 'model.decoder.final_layer_norm',
 
 
 
 
 
 
 
 
 
 
 
 
256
  'gpt_neox.final_layer_norm', 'transformer.norm_f']:
257
  try:
258
  parts = attr_path.split('.')
@@ -274,10 +297,24 @@ def token_to_color(token: str) -> str:
274
 
275
 
276
  def _get_top_tokens(activation_data: Dict[str, Any], module_name: str, model, tokenizer) -> Optional[List[Tuple[str, float]]]:
277
- """Helper: Get top 3 tokens for a layer's output."""
 
 
 
 
 
278
  try:
279
- mlp_output = activation_data['mlp_outputs'][module_name]['output']
280
- return logit_lens_transformation(mlp_output, [], model, None, tokenizer)
 
 
 
 
 
 
 
 
 
281
  except Exception as e:
282
  print(f"Warning: Could not compute logit lens for {module_name}: {e}")
283
  return None
@@ -314,15 +351,20 @@ def _create_edge(src_layer: int, tgt_layer: int, token: str, prob: float, rank:
314
 
315
 
316
  def format_data_for_cytoscape(activation_data: Dict[str, Any], model, tokenizer) -> List[Dict[str, Any]]:
317
- """Convert activation data to Cytoscape format with nodes (layers) and edges (top-3 tokens)."""
318
- mlp_modules = activation_data.get('mlp_modules', [])
319
- if not mlp_modules:
 
 
 
 
 
320
  return []
321
 
322
  # Extract and sort layers by layer number
323
  layer_info = sorted(
324
  [(int(re.findall(r'\d+', name)[0]), name)
325
- for name in mlp_modules if re.findall(r'\d+', name)]
326
  )
327
 
328
  elements = []
 
92
  model: Loaded transformer model
93
  tokenizer: Loaded tokenizer
94
  prompt: Input text prompt
95
+ config: Dict with module lists like {"attention_modules": [...], "block_modules": [...], ...}
96
 
97
  Returns:
98
  JSON-serializable dict with captured activations and metadata
 
101
 
102
  # Extract module lists from config
103
  attention_modules = config.get("attention_modules", [])
104
+ block_modules = config.get("block_modules", [])
 
105
  norm_parameters = config.get("norm_parameters", [])
106
  logit_lens_parameter = config.get("logit_lens_parameter")
107
 
108
+ all_modules = attention_modules + block_modules
109
  if not all_modules:
110
  print("No modules specified for capture")
111
  return {"error": "No modules specified"}
 
118
  if not layer_match:
119
  return {"error": f"Invalid module name format: {mod_name}"}
120
 
121
+ # Determine component type based on module name
122
+ if 'attn' in mod_name or 'attention' in mod_name:
 
 
123
  component = 'attention_output'
124
  else:
125
+ # Layer/block modules (e.g., "model.layers.0", "transformer.h.0")
126
+ # These represent the residual stream (full layer output)
127
  component = 'block_output'
128
 
129
  intervenable_representations.append(
 
161
  for hook in hooks:
162
  hook.remove()
163
 
164
+ # Separate outputs by type based on module name pattern
165
+ attention_outputs = {}
166
+ block_outputs = {}
167
+
168
+ for mod_name, output in captured.items():
169
+ if 'attn' in mod_name or 'attention' in mod_name:
170
+ attention_outputs[mod_name] = output
171
+ else:
172
+ # Block/layer outputs (residual stream - full layer output)
173
+ block_outputs[mod_name] = output
174
 
175
  # Capture normalization parameters (deprecated - kept for backward compatibility)
176
  all_params = dict(model.named_parameters())
 
189
  "model": getattr(model.config, "name_or_path", "unknown"),
190
  "prompt": prompt,
191
  "input_ids": safe_to_serializable(inputs["input_ids"]),
192
+ "attention_modules": list(attention_outputs.keys()),
193
  "attention_outputs": attention_outputs,
194
+ "block_modules": list(block_outputs.keys()),
195
+ "block_outputs": block_outputs,
 
 
196
  "norm_parameters": norm_parameters,
197
  "norm_data": norm_data,
198
  "logit_lens_parameter": logit_lens_parameter,
199
+ "actual_output": actual_output
200
  }
201
 
202
  print(f"Captured {len(captured)} module outputs using PyVene")
203
  return result
204
 
205
 
206
+ def logit_lens_transformation(layer_output: Any, norm_data: List[Any], model, logit_lens_parameter: str, tokenizer, norm_parameter: Optional[str] = None) -> List[Tuple[str, float]]:
207
  """
208
  Transform layer output to top 3 token probabilities using logit lens.
209
 
210
+ For standard logit lens, use block/layer outputs (residual stream), not component outputs.
211
+ The residual stream contains the full hidden state with all accumulated information.
212
+
213
  Applies final layer normalization before projection (critical for correctness).
214
  Uses model's built-in functions to minimize computational errors.
215
 
216
  Args:
217
+ layer_output: Hidden state from any layer (preferably block output / residual stream)
218
  norm_data: Not used (deprecated - using model's norm layer directly)
219
  model: HuggingFace model
220
  logit_lens_parameter: Not used (deprecated)
221
  tokenizer: Tokenizer for decoding
222
+ norm_parameter: Parameter path for final norm layer (e.g., "model.norm.weight")
223
 
224
  Returns:
225
  List of (token_string, probability) tuples for top 3 tokens
226
  """
227
  with torch.no_grad():
228
  # Convert to tensor and ensure proper shape [batch, seq_len, hidden_dim]
229
+ hidden = torch.tensor(layer_output) if not isinstance(layer_output, torch.Tensor) else layer_output
230
  if hidden.dim() == 4:
231
  hidden = hidden.squeeze(0)
232
 
233
  # Step 1: Apply final layer normalization (critical for intermediate layers)
234
+ final_norm = get_norm_layer_from_parameter(model, norm_parameter)
235
  if final_norm is not None:
236
  hidden = final_norm(hidden)
237
 
 
251
  ]
252
 
253
 
254
+ def get_norm_layer_from_parameter(model, norm_parameter: Optional[str]) -> Optional[Any]:
255
  """
256
+ Get the final layer normalization module from the model using the norm parameter path.
 
257
 
258
+ Args:
259
+ model: The transformer model
260
+ norm_parameter: Parameter path (e.g., "model.norm.weight") or None
261
+
262
+ Returns:
263
+ The normalization layer module, or None if not found
264
  """
265
+ if norm_parameter:
266
+ # Convert parameter path to module path (remove .weight/.bias suffix)
267
+ module_path = norm_parameter.replace('.weight', '').replace('.bias', '')
268
+ try:
269
+ parts = module_path.split('.')
270
+ obj = model
271
+ for part in parts:
272
+ obj = getattr(obj, part)
273
+ return obj
274
+ except AttributeError:
275
+ print(f"Warning: Could not find norm layer at {module_path}")
276
+
277
+ # Fallback: Try common final norm layer names if no parameter specified
278
+ for attr_path in ['model.norm', 'transformer.ln_f', 'model.decoder.final_layer_norm',
279
  'gpt_neox.final_layer_norm', 'transformer.norm_f']:
280
  try:
281
  parts = attr_path.split('.')
 
297
 
298
 
299
  def _get_top_tokens(activation_data: Dict[str, Any], module_name: str, model, tokenizer) -> Optional[List[Tuple[str, float]]]:
300
+ """
301
+ Helper: Get top 3 tokens for a layer's block output.
302
+
303
+ Uses block outputs (residual stream) which represent the full hidden state
304
+ after all layer computations (attention + feedforward + residuals).
305
+ """
306
  try:
307
+ # Get block output (residual stream)
308
+ if module_name not in activation_data.get('block_outputs', {}):
309
+ return None
310
+
311
+ layer_output = activation_data['block_outputs'][module_name]['output']
312
+
313
+ # Get norm parameter from activation data (should be a single parameter or list with one item)
314
+ norm_params = activation_data.get('norm_parameters', [])
315
+ norm_parameter = norm_params[0] if norm_params else None
316
+
317
+ return logit_lens_transformation(layer_output, [], model, None, tokenizer, norm_parameter)
318
  except Exception as e:
319
  print(f"Warning: Could not compute logit lens for {module_name}: {e}")
320
  return None
 
351
 
352
 
353
  def format_data_for_cytoscape(activation_data: Dict[str, Any], model, tokenizer) -> List[Dict[str, Any]]:
354
+ """
355
+ Convert activation data to Cytoscape format with nodes (layers) and edges (top-3 tokens).
356
+
357
+ Uses block outputs (full layer outputs / residual stream) for logit lens visualization.
358
+ """
359
+ # Get block modules (full layer outputs)
360
+ layer_modules = activation_data.get('block_modules', [])
361
+ if not layer_modules:
362
  return []
363
 
364
  # Extract and sort layers by layer number
365
  layer_info = sorted(
366
  [(int(re.findall(r'\d+', name)[0]), name)
367
+ for name in layer_modules if re.findall(r'\d+', name)]
368
  )
369
 
370
  elements = []