Marlin Lee Claude Sonnet 4.6 committed on
Commit
fd8ee51
·
1 Parent(s): 93e35bf

Sync local changes: CLIP scores, NSD image lookup, multi-trial DynaDiff, phi_c columns, P75 col, label captions, entrypoint pre-warm

Browse files
entrypoint.sh CHANGED
@@ -103,6 +103,35 @@ if [ ! -d "$COCO_THUMBS" ]; then
103
  fi
104
  IMAGE_DIR_ARG=(--image-dir "$COCO_THUMBS")
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  # ── Determine websocket origin ────────────────────────────────────────────────
107
  SPACE_HOST="${SPACE_HOST:-localhost}"
108
 
 
103
  fi
104
  IMAGE_DIR_ARG=(--image-dir "$COCO_THUMBS")
105
 
106
+ # ── Pre-warm DynaDiff before Bokeh starts ────────────────────────────────────
107
+ # torch.load holds the GIL for extended periods; doing this synchronously before
108
+ # Bokeh launches ensures Tornado's event loop isn't starved when users connect.
109
+ if [ -f "$DYNADIFF_CKPT" ] && [ -f "$FMRI_H5" ]; then
110
+ echo "Pre-warming DynaDiff (this may take a few minutes on cold start)..."
111
+ python3 - <<PYEOF
112
+ import sys, time, os
113
+ sys.path.insert(0, '/app')
114
+ sys.path.insert(0, '/app/dynadiff')
115
+ sys.path.insert(0, '/app/dynadiff/diffusers/src')
116
+ os.chdir('/app')
117
+ from scripts.dynadiff_loader import get_loader
118
+ loader = get_loader(
119
+ dynadiff_dir='/app/dynadiff',
120
+ checkpoint=os.environ.get('DYNADIFF_CKPT', '$DYNADIFF_CKPT'),
121
+ h5_path='$FMRI_H5',
122
+ )
123
+ while True:
124
+ status, err = loader.status
125
+ if status == 'ok':
126
+ print('DynaDiff pre-warm complete.')
127
+ break
128
+ elif status == 'error':
129
+ print(f'DynaDiff pre-warm failed: {err}')
130
+ break
131
+ time.sleep(5)
132
+ PYEOF
133
+ fi
134
+
135
  # ── Determine websocket origin ────────────────────────────────────────────────
136
  SPACE_HOST="${SPACE_HOST:-localhost}"
137
 
scripts/add_clip_embeddings.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Post-hoc CLIP text-alignment enrichment for explorer_data.pt files.
3
+
4
+ Loads an existing explorer_data.pt, computes per-feature CLIP text alignment
5
+ scores (via MEI images), and saves them back into the same file under:
6
+ 'clip_text_scores' : Tensor (n_features, n_vocab) float16
7
+ 'clip_text_vocab' : list[str]
8
+ 'clip_feature_embeds' : Tensor (n_features, clip_proj_dim) float16
9
+ mean CLIP image embedding of each feature's top MEIs
10
+
11
+ This script does NOT need to re-run DINOv3 or the SAE — it only needs the
12
+ existing explorer_data.pt (for image paths and top-MEI indices) and CLIP.
13
+
14
+ Usage
15
+ -----
16
+ python add_clip_embeddings.py \
17
+ --data ../explorer_data_d32000_k160.pt \
18
+ --vocab-file ../vocab/imagenet_labels.txt \
19
+ --n-top-images 4 \
20
+ --batch-size 32
21
+
22
+ # Or use the built-in default vocabulary (ImageNet-1K labels + COCO categories):
23
+ python add_clip_embeddings.py \
24
+ --data ../explorer_data_d32000_k160.pt
25
+
26
+ The enriched file is saved to --output-path (defaults to overwriting --data
27
+ with a backup copy at <data>.bak).
28
+ """
29
+
30
+ import argparse
31
+ import os
32
+ import shutil
33
+
34
+ import torch
35
+ import torch.nn.functional as F
36
+ from PIL import Image
37
+
38
+ # Allow running from scripts/ directory or project root
39
+ import sys
40
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
41
+ from clip_utils import load_clip, compute_text_embeddings, compute_mei_text_alignment
42
+
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # Default vocabulary
46
+ # ---------------------------------------------------------------------------
47
+
48
+ DEFAULT_VOCAB = [
49
+ # COCO categories
50
+ "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
51
+ "truck", "boat", "traffic light", "fire hydrant", "stop sign",
52
+ "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
53
+ "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag",
54
+ "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
55
+ "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
56
+ "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana",
57
+ "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
58
+ "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
59
+ "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
60
+ "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock",
61
+ "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
62
+ # Texture / scene descriptors
63
+ "grass", "sky", "water", "sand", "snow", "wood", "stone", "metal",
64
+ "fabric", "fur", "feathers", "leaves", "clouds", "fire", "shadow",
65
+ "stripes", "spots", "checkerboard pattern", "geometric pattern",
66
+ # Orientation / structure cues (for patch features)
67
+ "horizontal lines", "vertical lines", "diagonal lines", "curved lines",
68
+ "edges", "corners", "grid", "dots", "concentric circles",
69
+ # Color / illumination
70
+ "red object", "blue object", "green object", "yellow object",
71
+ "black and white", "bright highlight", "dark shadow", "gradient",
72
+ # Scene types
73
+ "indoor scene", "outdoor scene", "urban street", "nature landscape",
74
+ "ocean", "mountain", "forest", "desert", "city buildings", "crowd",
75
+ ]
76
+
77
+
78
+ # ---------------------------------------------------------------------------
79
+ # Main
80
+ # ---------------------------------------------------------------------------
81
+
82
+ def main():
83
+ parser = argparse.ArgumentParser(description="Add CLIP text alignment to explorer_data.pt")
84
+ parser.add_argument("--data", type=str, required=True,
85
+ help="Path to explorer_data.pt")
86
+ parser.add_argument("--output-path", type=str, default=None,
87
+ help="Output path (default: overwrite --data, keeping .bak)")
88
+ parser.add_argument("--vocab-file", type=str, default=None,
89
+ help="Plain-text file with one concept per line. "
90
+ "Default: built-in COCO+texture vocabulary.")
91
+ parser.add_argument("--clip-model", type=str, default="openai/clip-vit-large-patch14",
92
+ help="HuggingFace CLIP model ID")
93
+ parser.add_argument("--n-top-images", type=int, default=4,
94
+ help="Number of MEIs to average per feature for CLIP alignment")
95
+ parser.add_argument("--batch-size", type=int, default=32,
96
+ help="Batch size for CLIP image encoding")
97
+ parser.add_argument("--no-backup", action="store_true",
98
+ help="Skip creating a .bak copy before overwriting")
99
+ parser.add_argument("--image-dir", type=str, default=None,
100
+ help="Primary image directory for resolving bare filenames")
101
+ parser.add_argument("--extra-image-dir", type=str, action="append", default=[],
102
+ help="Additional image directory (repeatable)")
103
+ args = parser.parse_args()
104
+
105
+ image_bases = [b for b in ([args.image_dir] + args.extra_image_dir) if b]
106
+
107
+ def resolve_path(p):
108
+ if os.path.isabs(p) or not image_bases:
109
+ return p
110
+ for base in image_bases:
111
+ full = os.path.join(base, p)
112
+ if os.path.exists(full):
113
+ return full
114
+ return os.path.join(image_bases[0], p) # fallback
115
+
116
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
117
+ print(f"Device: {device}")
118
+
119
+ # --- Load explorer data ---
120
+ print(f"Loading explorer data from {args.data}...")
121
+ data = torch.load(args.data, map_location='cpu', weights_only=False)
122
+ image_paths = [resolve_path(p) for p in data['image_paths']]
123
+ d_model = data['d_model']
124
+ top_img_idx = data['top_img_idx'] # (n_features, n_top)
125
+ n_top_stored = top_img_idx.shape[1]
126
+ print(f" d_model={d_model}, n_images={data['n_images']}, "
127
+ f"top-{n_top_stored} images stored")
128
+
129
+ # --- Load vocabulary ---
130
+ if args.vocab_file:
131
+ with open(args.vocab_file) as f:
132
+ vocab = [line.strip() for line in f if line.strip()]
133
+ print(f"Loaded {len(vocab)} concepts from {args.vocab_file}")
134
+ else:
135
+ vocab = DEFAULT_VOCAB
136
+ print(f"Using default vocabulary ({len(vocab)} concepts)")
137
+
138
+ # --- Load CLIP ---
139
+ clip_model, clip_processor = load_clip(device, model_name=args.clip_model)
140
+
141
+ # --- Precompute text embeddings ---
142
+ print("Encoding text vocabulary with CLIP...")
143
+ text_embeds = compute_text_embeddings(vocab, clip_model, clip_processor, device)
144
+ print(f" text_embeds: {text_embeds.shape}")
145
+
146
+ # --- Collect MEI image paths per feature ---
147
+ print("Collecting MEI image paths per feature...")
148
+ n_use = min(args.n_top_images, n_top_stored)
149
+ feature_mei_paths = []
150
+ for feat in range(d_model):
151
+ paths = []
152
+ for j in range(n_use):
153
+ idx = top_img_idx[feat, j].item()
154
+ if idx >= 0:
155
+ paths.append(image_paths[idx])
156
+ feature_mei_paths.append(paths)
157
+
158
+ # --- Compute per-feature CLIP image embeddings (mean of MEIs) ---
159
+ print(f"Computing CLIP image embeddings for {d_model} features "
160
+ f"(averaging {n_use} MEIs each)...")
161
+
162
+ clip_proj_dim = clip_model.config.projection_dim
163
+ feature_img_embeds = torch.zeros(d_model, clip_proj_dim, dtype=torch.float32)
164
+ dead_count = 0
165
+
166
+ for feat_start in range(0, d_model, args.batch_size):
167
+ feat_end = min(feat_start + args.batch_size, d_model)
168
+ for feat in range(feat_start, feat_end):
169
+ paths = feature_mei_paths[feat]
170
+ if not paths:
171
+ dead_count += 1
172
+ continue
173
+ imgs = []
174
+ for p in paths:
175
+ try:
176
+ imgs.append(Image.open(p).convert("RGB"))
177
+ except Exception:
178
+ continue
179
+ if not imgs:
180
+ dead_count += 1
181
+ continue
182
+ inputs = clip_processor(images=imgs, return_tensors="pt")
183
+ pixel_values = inputs['pixel_values'].to(device)
184
+ with torch.inference_mode():
185
+ # Use vision_model + visual_projection directly to avoid
186
+ # version differences in get_image_features() return type.
187
+ vision_out = clip_model.vision_model(pixel_values=pixel_values)
188
+ embeds = clip_model.visual_projection(vision_out.pooler_output)
189
+ embeds = F.normalize(embeds, dim=-1)
190
+ mean_embed = embeds.mean(dim=0)
191
+ mean_embed = F.normalize(mean_embed, dim=-1)
192
+ feature_img_embeds[feat] = mean_embed.cpu().float()
193
+
194
+ if (feat_start // args.batch_size + 1) % 100 == 0:
195
+ print(f" [{feat_end}/{d_model}] features encoded", flush=True)
196
+
197
+ print(f" Done. Dead/missing features skipped: {dead_count}")
198
+
199
+ # --- Compute alignment matrix ---
200
+ print("Computing text alignment matrix...")
201
+ # (n_features, clip_proj_dim) @ (clip_proj_dim, n_vocab) = (n_features, n_vocab)
202
+ clip_text_scores = feature_img_embeds @ text_embeds.T # float32
203
+ print(f" clip_text_scores: {clip_text_scores.shape}")
204
+
205
+ # --- Save into explorer_data.pt ---
206
+ output_path = args.output_path or args.data
207
+ if output_path == args.data and not args.no_backup:
208
+ bak_path = args.data + ".bak"
209
+ print(f"Creating backup at {bak_path}...")
210
+ shutil.copy2(args.data, bak_path)
211
+
212
+ data['clip_text_scores'] = clip_text_scores.half() # float16 to save space
213
+ data['clip_feature_embeds'] = feature_img_embeds.half() # float16
214
+ data['clip_text_vocab'] = vocab
215
+
216
+ print(f"Saving enriched explorer data to {output_path}...")
217
+ torch.save(data, output_path)
218
+ size_mb = os.path.getsize(output_path) / 1e6
219
+ print(f"Saved ({size_mb:.1f} MB)")
220
+ print("Done.")
221
+
222
+
223
+ if __name__ == "__main__":
224
+ main()
scripts/dynadiff_loader.py CHANGED
@@ -80,6 +80,7 @@ class DynaDiffLoader:
80
  self._cfg = None
81
  self._beta_std = None
82
  self._subject_sample_indices = None
 
83
  self._status = 'loading' # 'loading' | 'ok' | 'error'
84
  self._error = ''
85
  self._lock = threading.Lock()
@@ -102,6 +103,15 @@ class DynaDiffLoader:
102
  idx = self._subject_sample_indices
103
  return len(idx) if idx is not None else None
104
 
 
 
 
 
 
 
 
 
 
105
  def start(self):
106
  """Start background model loading thread."""
107
  t = threading.Thread(target=self._load, daemon=True)
@@ -216,15 +226,23 @@ class DynaDiffLoader:
216
  # Subject sample index mapping
217
  log.info(f'[DynaDiff] Building sample index for subject {self.subject_idx} ...')
218
  with h5py.File(self.h5_path, 'r') as hf:
219
- all_subj = np.array(hf['subject_idx'][:], dtype=np.int64)
 
220
  sample_indices = np.where(all_subj == self.subject_idx)[0].astype(np.int64)
221
  log.info(f'[DynaDiff] {len(sample_indices)} samples for subject {self.subject_idx}')
222
 
 
 
 
 
 
 
223
  with self._lock:
224
  self._model = model
225
  self._cfg = cfg
226
  self._beta_std = beta_std
227
  self._subject_sample_indices = sample_indices
 
228
  self._status = 'ok'
229
  log.info('[DynaDiff] Ready.')
230
 
 
80
  self._cfg = None
81
  self._beta_std = None
82
  self._subject_sample_indices = None
83
+ self._nsd_to_sample = {}
84
  self._status = 'loading' # 'loading' | 'ok' | 'error'
85
  self._error = ''
86
  self._lock = threading.Lock()
 
103
  idx = self._subject_sample_indices
104
  return len(idx) if idx is not None else None
105
 
106
+ def sample_idxs_for_nsd_img(self, nsd_img_idx):
107
+ """Return the list of sample_idx values that correspond to a given NSD image index.
108
+
109
+ Returns an empty list if the image has no trials for this subject or the
110
+ mapping is not yet built (model still loading).
111
+ """
112
+ with self._lock:
113
+ return list(self._nsd_to_sample.get(int(nsd_img_idx), []))
114
+
115
  def start(self):
116
  """Start background model loading thread."""
117
  t = threading.Thread(target=self._load, daemon=True)
 
226
  # Subject sample index mapping
227
  log.info(f'[DynaDiff] Building sample index for subject {self.subject_idx} ...')
228
  with h5py.File(self.h5_path, 'r') as hf:
229
+ all_subj = np.array(hf['subject_idx'][:], dtype=np.int64)
230
+ all_imgidx = np.array(hf['image_idx'][:], dtype=np.int64)
231
  sample_indices = np.where(all_subj == self.subject_idx)[0].astype(np.int64)
232
  log.info(f'[DynaDiff] {len(sample_indices)} samples for subject {self.subject_idx}')
233
 
234
+ # Build reverse map: NSD image index → list of sample_idx values
235
+ nsd_to_sample: dict[int, list[int]] = {}
236
+ for sample_idx_val, h5_row in enumerate(sample_indices):
237
+ nsd_img = int(all_imgidx[h5_row])
238
+ nsd_to_sample.setdefault(nsd_img, []).append(sample_idx_val)
239
+
240
  with self._lock:
241
  self._model = model
242
  self._cfg = cfg
243
  self._beta_std = beta_std
244
  self._subject_sample_indices = sample_indices
245
+ self._nsd_to_sample = nsd_to_sample
246
  self._status = 'ok'
247
  log.info('[DynaDiff] Ready.')
248
 
scripts/explorer_app.py CHANGED
@@ -430,10 +430,10 @@ def _load_brain_dataset_dict(path, label, thumb_dir):
430
  'feature_p75_val': bd.get('feature_p75_val', torch.zeros(d_model)),
431
  'umap_coords': bd['umap_coords'].numpy() if 'umap_coords' in bd else nan2,
432
  'dict_umap_coords': bd['dict_umap_coords'].numpy() if 'dict_umap_coords' in bd else nan2,
433
- 'clip_scores': None,
434
- 'clip_vocab': None,
435
- 'clip_embeds': None,
436
- 'clip_scores_f32': None,
437
  'inference_cache': OrderedDict(),
438
  'names_file': stem + '_feature_names.json',
439
  'auto_interp_file': stem + '_auto_interp.json',
@@ -633,6 +633,9 @@ def _reconstruct_z_from_heatmaps(img_idx, ds):
633
  idx = ds.get(idx_key) # (d_sae, n_slots) int tensor
634
  if hm is None or idx is None:
635
  continue
 
 
 
636
  if z is None:
637
  d_sae, _, n_patches_sq = hm.shape
638
  z = np.zeros((n_patches_sq, d_sae), dtype=np.float32)
@@ -704,6 +707,20 @@ ALPHA_JET = create_alpha_cmap('jet')
704
  THUMB = args.thumb_size
705
 
706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
707
  def _resolve_img_path(stored_path):
708
  """Resolve a stored image path, searching image dirs first. Returns None on failure."""
709
  if os.path.isabs(stored_path) and os.path.exists(stored_path):
@@ -952,33 +969,42 @@ def _dynadiff_request(sample_idx, steerings, seed):
952
  return _dd_loader.reconstruct(sample_idx, steerings, seed)
953
 
954
 
955
- def _make_steering_html(resp, concept_name):
956
- """Build HTML showing GT | Baseline | Steered side by side."""
957
- parts = []
958
- for label, key in [('Ground Truth', 'gt_img'),
959
- ('Baseline (λ=0)', 'baseline_img'),
960
- (f'Steered', 'steered_img')]:
961
- b64 = resp.get(key)
962
- if b64 is None:
963
- img_html = ('<div style="width:200px;height:200px;background:#eee;'
964
- 'display:flex;align-items:center;justify-content:center;'
965
- 'color:#999;font-size:12px">N/A</div>')
966
- else:
967
- img_html = (f'<img src="data:image/png;base64,{b64}" '
968
- 'style="width:200px;height:200px;object-fit:contain;'
969
- 'border:1px solid #ddd;border-radius:4px"/>')
970
- parts.append(
971
- f'<div style="text-align:center;margin:0 6px">'
972
- f'{img_html}'
973
- f'<div style="font-size:11px;color:#555;margin-top:3px">{label}</div>'
974
- f'</div>'
975
- )
976
- imgs_html = '<div style="display:flex;align-items:flex-end">' + ''.join(parts) + '</div>'
977
- return (
978
  f'<h3 style="margin:4px 0 6px 0;color:#333;border-bottom:2px solid #e0e0e0;'
979
  f'padding-bottom:4px">DynaDiff Steering — {concept_name}</h3>'
980
- + imgs_html
981
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
982
 
983
 
984
  def make_image_grid_html(images_info, title, cols=9):
@@ -1391,10 +1417,14 @@ def _build_dynadiff_panel():
1391
  dd_feat_remove_btn.on_click(_on_remove_feat)
1392
  dd_feat_clear_btn.on_click(_on_clear_feats)
1393
 
1394
- def _reconstruct_thread(sample_idx, steerings, seed, feat_name, doc):
1395
  try:
1396
- resp = _dynadiff_request(sample_idx, steerings, seed)
1397
- html = _make_steering_html(resp, feat_name)
 
 
 
 
1398
  def _apply(html=html):
1399
  dd_output.text = html
1400
  dd_status.text = ''
@@ -1422,13 +1452,25 @@ def _build_dynadiff_panel():
1422
  if not steerings:
1423
  dd_status.text = '<span style="color:#c00">No phi data for selected features.</span>'
1424
  return
 
1425
  try:
1426
- sample_idx = int(dd_sample_input.value)
1427
  except ValueError:
1428
  dd_status.text = '<span style="color:#c00">Invalid sample index.</span>'
1429
  return
 
 
 
 
 
 
 
 
 
 
 
1430
  _n = _dd_loader.n_samples
1431
- if _n is not None and not (0 <= sample_idx < _n):
1432
  dd_status.text = f'<span style="color:#c00">sample_idx must be 0–{_n-1}.</span>'
1433
  return
1434
  try:
@@ -1438,11 +1480,13 @@ def _build_dynadiff_panel():
1438
  names = list(dd_source.data['name'])
1439
  feat_name = ' + '.join(names) if names else 'unknown'
1440
  dd_btn.disabled = True
1441
- dd_status.text = '<i style="color:#888">Running DynaDiff reconstruction…</i>'
 
 
1442
  doc = curdoc()
1443
  threading.Thread(
1444
  target=_reconstruct_thread,
1445
- args=(sample_idx, steerings, seed, feat_name, doc),
1446
  daemon=True,
1447
  ).start()
1448
 
@@ -1587,14 +1631,15 @@ def update_feature_display(feature_idx):
1587
  else:
1588
  hmap = None
1589
 
 
1590
  if hmap is None:
1591
  plain = load_image(img_i).resize((THUMB, THUMB), Image.BILINEAR)
1592
  act_val = float(act_tensor[feat, ranking_idx].item())
1593
- caption = f"act={act_val:.4f} img {img_i}"
1594
  return (plain, caption)
1595
  max_act, mean_act_val = _patch_stats(hmap.flatten())
1596
  img_out = render_zoomed_overlay(img_i, hmap, size=THUMB, center=center)
1597
- caption = f"img {img_i}"
1598
  return (img_out, caption)
1599
  except Exception as e:
1600
  ph = Image.new("RGB", (THUMB, THUMB), "gray")
@@ -1817,11 +1862,13 @@ feature_list_source = ColumnDataSource(data=dict(
1817
  name=[_display_name(int(i)) for i in _init_order],
1818
  ))
1819
 
1820
- _phi_col = (
1821
- [TableColumn(field="phi_c_val", title="φ_c", width=65,
1822
- formatter=NumberFormatter(format="0.0000"))]
1823
- if HAS_PHI else []
1824
- )
 
 
1825
  feature_table = DataTable(
1826
  source=feature_list_source,
1827
  columns=[
@@ -1830,9 +1877,7 @@ feature_table = DataTable(
1830
  formatter=NumberFormatter(format="0,0")),
1831
  TableColumn(field="mean_act", title="Mean Act", width=80,
1832
  formatter=NumberFormatter(format="0.0000")),
1833
- TableColumn(field="p75_val", title="P75", width=70,
1834
- formatter=NumberFormatter(format="0.0000")),
1835
- ] + _phi_col + [
1836
  TableColumn(field="name", title="Name", width=200),
1837
  ],
1838
  width=500, height=500, sortable=True, index_position=None,
@@ -2170,20 +2215,20 @@ load_patch_btn = Button(label="Load Image", width=90, button_type="primary")
2170
  clear_patch_btn = Button(label="Clear", width=60)
2171
 
2172
  patch_feat_source = ColumnDataSource(data=dict(
2173
- feature_idx=[], patch_act=[], frequency=[], mean_act=[],
2174
  ))
2175
  patch_feat_table = DataTable(
2176
  source=patch_feat_source,
2177
  columns=[
2178
- TableColumn(field="feature_idx", title="Feature", width=65),
2179
  TableColumn(field="patch_act", title="Patch Act", width=85,
2180
  formatter=NumberFormatter(format="0.0000")),
2181
  TableColumn(field="frequency", title="Freq", width=65,
2182
  formatter=NumberFormatter(format="0,0")),
2183
  TableColumn(field="mean_act", title="Mean Act", width=80,
2184
  formatter=NumberFormatter(format="0.0000")),
2185
- ],
2186
- width=310, height=350, index_position=None, sortable=False, visible=False,
2187
  )
2188
  patch_info_div = Div(
2189
  text="<i>Load an image, then click patches to find top features.</i>",
@@ -2203,7 +2248,7 @@ def _pil_to_bokeh_rgba(pil_img, size):
2203
 
2204
  def _do_load_patch_image():
2205
  try:
2206
- img_idx = int(patch_img_input.value)
2207
  except ValueError:
2208
  patch_info_div.text = "<b style='color:red'>Invalid image index</b>"
2209
  return
@@ -2292,7 +2337,7 @@ def _on_patch_select(attr, old, new):
2292
  if _S.patch_img is None:
2293
  return
2294
  if not new:
2295
- patch_feat_source.data = dict(feature_idx=[], patch_act=[], frequency=[], mean_act=[])
2296
  patch_info_div.text = "<i>Selection cleared.</i>"
2297
  return
2298
 
@@ -2302,7 +2347,10 @@ def _on_patch_select(attr, old, new):
2302
  patch_indices = [r * patch_grid + c for r, c in zip(rows, cols)]
2303
 
2304
  feats, acts, freqs, means = _get_top_features_for_patches(patch_indices)
2305
- patch_feat_source.data = dict(feature_idx=feats, patch_act=acts, frequency=freqs, mean_act=means)
 
 
 
2306
  patch_info_div.text = (
2307
  f"{len(new)} patch(es) selected → {len(feats)} feature(s) found. "
2308
  f"Click a row below to explore the feature."
@@ -2347,7 +2395,7 @@ def _build_clip_panel():
2347
  clip_top_k_input = TextInput(title="Top-K results:", value="20", width=70)
2348
 
2349
  result_source = ColumnDataSource(data=dict(
2350
- feature_idx=[], clip_score=[], frequency=[], mean_act=[], name=[],
2351
  ))
2352
  clip_result_table = DataTable(
2353
  source=result_source,
@@ -2359,9 +2407,10 @@ def _build_clip_panel():
2359
  formatter=NumberFormatter(format="0,0")),
2360
  TableColumn(field="mean_act", title="Mean Act", width=80,
2361
  formatter=NumberFormatter(format="0.0000")),
 
2362
  TableColumn(field="name", title="Name", width=160),
2363
  ],
2364
- width=470, height=300, index_position=None, sortable=False,
2365
  )
2366
 
2367
  def _do_search():
@@ -2402,6 +2451,7 @@ def _build_clip_panel():
2402
  clip_score=[float(scores_vec[i]) for i in top_indices],
2403
  frequency=[int(feature_frequency[i].item()) for i in top_indices],
2404
  mean_act=[float(feature_mean_act[i].item()) for i in top_indices],
 
2405
  name=[_display_name(int(i)) for i in top_indices],
2406
  )
2407
  result_div.text = (
 
430
  'feature_p75_val': bd.get('feature_p75_val', torch.zeros(d_model)),
431
  'umap_coords': bd['umap_coords'].numpy() if 'umap_coords' in bd else nan2,
432
  'dict_umap_coords': bd['dict_umap_coords'].numpy() if 'dict_umap_coords' in bd else nan2,
433
+ 'clip_scores': bd.get('clip_text_scores', None),
434
+ 'clip_vocab': bd.get('clip_text_vocab', None),
435
+ 'clip_embeds': bd.get('clip_feature_embeds', None),
436
+ 'clip_scores_f32': bd['clip_text_scores'].float() if 'clip_text_scores' in bd else None,
437
  'inference_cache': OrderedDict(),
438
  'names_file': stem + '_feature_names.json',
439
  'auto_interp_file': stem + '_auto_interp.json',
 
633
  idx = ds.get(idx_key) # (d_sae, n_slots) int tensor
634
  if hm is None or idx is None:
635
  continue
636
+ # Normalise: flatten 4-D (d_sae, n_slots, H, W) → 3-D (d_sae, n_slots, H*W)
637
+ if hm.ndim == 4:
638
+ hm = hm.reshape(hm.shape[0], hm.shape[1], -1)
639
  if z is None:
640
  d_sae, _, n_patches_sq = hm.shape
641
  z = np.zeros((n_patches_sq, d_sae), dtype=np.float32)
 
707
  THUMB = args.thumb_size
708
 
709
 
710
+ def _parse_img_label(value):
711
+ """Parse an image label into an integer index.
712
+
713
+ Accepts bare integers ('42') or name-prefixed labels ('nsd_00042',
714
+ 'COCO_val2014_000000123456') by extracting the trailing integer after
715
+ the last underscore. Raises ValueError on failure.
716
+ """
717
+ val = value.strip()
718
+ try:
719
+ return int(val)
720
+ except ValueError:
721
+ return int(val.rsplit('_', 1)[-1])
722
+
723
+
724
  def _resolve_img_path(stored_path):
725
  """Resolve a stored image path, searching image dirs first. Returns None on failure."""
726
  if os.path.isabs(stored_path) and os.path.exists(stored_path):
 
969
  return _dd_loader.reconstruct(sample_idx, steerings, seed)
970
 
971
 
972
+ def _make_steering_html(resps, concept_name):
973
+ """Build HTML showing GT | Baseline | Steered for one or more trials.
974
+
975
+ resps: list of (trial_label, resp_dict) pairs.
976
+ """
977
+ header = (
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
978
  f'<h3 style="margin:4px 0 6px 0;color:#333;border-bottom:2px solid #e0e0e0;'
979
  f'padding-bottom:4px">DynaDiff Steering — {concept_name}</h3>'
 
980
  )
981
+ rows_html = ''
982
+ for trial_label, resp in resps:
983
+ parts = []
984
+ for label, key in [('GT', 'gt_img'),
985
+ ('Baseline', 'baseline_img'),
986
+ ('Steered', 'steered_img')]:
987
+ b64 = resp.get(key)
988
+ if b64 is None:
989
+ img_html = ('<div style="width:160px;height:160px;background:#eee;'
990
+ 'display:flex;align-items:center;justify-content:center;'
991
+ 'color:#999;font-size:12px">N/A</div>')
992
+ else:
993
+ img_html = (f'<img src="data:image/png;base64,{b64}" '
994
+ 'style="width:160px;height:160px;object-fit:contain;'
995
+ 'border:1px solid #ddd;border-radius:4px"/>')
996
+ parts.append(
997
+ f'<div style="text-align:center;margin:0 4px">'
998
+ f'{img_html}'
999
+ f'<div style="font-size:11px;color:#555;margin-top:3px">{label}</div>'
1000
+ f'</div>'
1001
+ )
1002
+ trial_head = (f'<div style="font-size:11px;font-weight:bold;color:#777;'
1003
+ f'margin:6px 0 3px 4px">{trial_label}</div>')
1004
+ rows_html += (trial_head
1005
+ + '<div style="display:flex;align-items:flex-end;margin-bottom:8px">'
1006
+ + ''.join(parts) + '</div>')
1007
+ return header + rows_html
1008
 
1009
 
1010
  def make_image_grid_html(images_info, title, cols=9):
 
1417
  dd_feat_remove_btn.on_click(_on_remove_feat)
1418
  dd_feat_clear_btn.on_click(_on_clear_feats)
1419
 
1420
+ def _reconstruct_thread(sample_idxs, steerings, seed, feat_name, doc):
1421
  try:
1422
+ resps = []
1423
+ for i, sidx in enumerate(sample_idxs):
1424
+ trial_label = f'Trial {i+1} (sample {sidx})'
1425
+ resp = _dynadiff_request(sidx, steerings, seed)
1426
+ resps.append((trial_label, resp))
1427
+ html = _make_steering_html(resps, feat_name)
1428
  def _apply(html=html):
1429
  dd_output.text = html
1430
  dd_status.text = ''
 
1452
  if not steerings:
1453
  dd_status.text = '<span style="color:#c00">No phi data for selected features.</span>'
1454
  return
1455
+ _raw = dd_sample_input.value.strip()
1456
  try:
1457
+ _parsed = _parse_img_label(_raw)
1458
  except ValueError:
1459
  dd_status.text = '<span style="color:#c00">Invalid sample index.</span>'
1460
  return
1461
+ # If input looks like an NSD image label (contains '_'), treat _parsed as
1462
+ # an NSD image index and run all trials for that image.
1463
+ if '_' in _raw:
1464
+ sample_idxs = _dd_loader.sample_idxs_for_nsd_img(_parsed)
1465
+ if not sample_idxs:
1466
+ dd_status.text = (
1467
+ f'<span style="color:#c00">NSD image {_parsed} has no trials '
1468
+ f'for this subject.</span>')
1469
+ return
1470
+ else:
1471
+ sample_idxs = [_parsed]
1472
  _n = _dd_loader.n_samples
1473
+ if _n is not None and any(not (0 <= s < _n) for s in sample_idxs):
1474
  dd_status.text = f'<span style="color:#c00">sample_idx must be 0–{_n-1}.</span>'
1475
  return
1476
  try:
 
1480
  names = list(dd_source.data['name'])
1481
  feat_name = ' + '.join(names) if names else 'unknown'
1482
  dd_btn.disabled = True
1483
+ n_trials = len(sample_idxs)
1484
+ dd_status.text = (f'<i style="color:#888">Running DynaDiff reconstruction '
1485
+ f'({n_trials} trial{"s" if n_trials > 1 else ""})…</i>')
1486
  doc = curdoc()
1487
  threading.Thread(
1488
  target=_reconstruct_thread,
1489
+ args=(sample_idxs, steerings, seed, feat_name, doc),
1490
  daemon=True,
1491
  ).start()
1492
 
 
1631
  else:
1632
  hmap = None
1633
 
1634
+ img_label = os.path.splitext(os.path.basename(image_paths[img_i]))[0]
1635
  if hmap is None:
1636
  plain = load_image(img_i).resize((THUMB, THUMB), Image.BILINEAR)
1637
  act_val = float(act_tensor[feat, ranking_idx].item())
1638
+ caption = f"act={act_val:.4f} {img_label}"
1639
  return (plain, caption)
1640
  max_act, mean_act_val = _patch_stats(hmap.flatten())
1641
  img_out = render_zoomed_overlay(img_i, hmap, size=THUMB, center=center)
1642
+ caption = img_label
1643
  return (img_out, caption)
1644
  except Exception as e:
1645
  ph = Image.new("RGB", (THUMB, THUMB), "gray")
 
1862
  name=[_display_name(int(i)) for i in _init_order],
1863
  ))
1864
 
1865
+ def _phi_col():
1866
+ """Return phi_c column definition list (single element) if phi data is loaded, else []."""
1867
+ if not HAS_PHI:
1868
+ return []
1869
+ return [TableColumn(field="phi_c_val", title="φ_c", width=65,
1870
+ formatter=NumberFormatter(format="0.0000"))]
1871
+
1872
  feature_table = DataTable(
1873
  source=feature_list_source,
1874
  columns=[
 
1877
  formatter=NumberFormatter(format="0,0")),
1878
  TableColumn(field="mean_act", title="Mean Act", width=80,
1879
  formatter=NumberFormatter(format="0.0000")),
1880
+ ] + _phi_col() + [
 
 
1881
  TableColumn(field="name", title="Name", width=200),
1882
  ],
1883
  width=500, height=500, sortable=True, index_position=None,
 
2215
  clear_patch_btn = Button(label="Clear", width=60)
2216
 
2217
  patch_feat_source = ColumnDataSource(data=dict(
2218
+ feature_idx=[], patch_act=[], frequency=[], mean_act=[], phi_c_val=[],
2219
  ))
2220
  patch_feat_table = DataTable(
2221
  source=patch_feat_source,
2222
  columns=[
2223
+ TableColumn(field="feature_idx", title="Feature", width=65),
2224
  TableColumn(field="patch_act", title="Patch Act", width=85,
2225
  formatter=NumberFormatter(format="0.0000")),
2226
  TableColumn(field="frequency", title="Freq", width=65,
2227
  formatter=NumberFormatter(format="0,0")),
2228
  TableColumn(field="mean_act", title="Mean Act", width=80,
2229
  formatter=NumberFormatter(format="0.0000")),
2230
+ ] + _phi_col(),
2231
+ width=310 + (65 if HAS_PHI else 0), height=350, index_position=None, sortable=False, visible=False,
2232
  )
2233
  patch_info_div = Div(
2234
  text="<i>Load an image, then click patches to find top features.</i>",
 
2248
 
2249
  def _do_load_patch_image():
2250
  try:
2251
+ img_idx = _parse_img_label(patch_img_input.value)
2252
  except ValueError:
2253
  patch_info_div.text = "<b style='color:red'>Invalid image index</b>"
2254
  return
 
2337
  if _S.patch_img is None:
2338
  return
2339
  if not new:
2340
+ patch_feat_source.data = dict(feature_idx=[], patch_act=[], frequency=[], mean_act=[], phi_c_val=[])
2341
  patch_info_div.text = "<i>Selection cleared.</i>"
2342
  return
2343
 
 
2347
  patch_indices = [r * patch_grid + c for r, c in zip(rows, cols)]
2348
 
2349
  feats, acts, freqs, means = _get_top_features_for_patches(patch_indices)
2350
+ patch_feat_source.data = dict(
2351
+ feature_idx=feats, patch_act=acts, frequency=freqs, mean_act=means,
2352
+ phi_c_val=_phi_c_vals(feats),
2353
+ )
2354
  patch_info_div.text = (
2355
  f"{len(new)} patch(es) selected → {len(feats)} feature(s) found. "
2356
  f"Click a row below to explore the feature."
 
2395
  clip_top_k_input = TextInput(title="Top-K results:", value="20", width=70)
2396
 
2397
  result_source = ColumnDataSource(data=dict(
2398
+ feature_idx=[], clip_score=[], frequency=[], mean_act=[], phi_c_val=[], name=[],
2399
  ))
2400
  clip_result_table = DataTable(
2401
  source=result_source,
 
2407
  formatter=NumberFormatter(format="0,0")),
2408
  TableColumn(field="mean_act", title="Mean Act", width=80,
2409
  formatter=NumberFormatter(format="0.0000")),
2410
+ ] + _phi_col() + [
2411
  TableColumn(field="name", title="Name", width=160),
2412
  ],
2413
+ width=470 + (65 if HAS_PHI else 0), height=300, index_position=None, sortable=False,
2414
  )
2415
 
2416
  def _do_search():
 
2451
  clip_score=[float(scores_vec[i]) for i in top_indices],
2452
  frequency=[int(feature_frequency[i].item()) for i in top_indices],
2453
  mean_act=[float(feature_mean_act[i].item()) for i in top_indices],
2454
+ phi_c_val=_phi_c_vals(top_indices),
2455
  name=[_display_name(int(i)) for i in top_indices],
2456
  )
2457
  result_div.text = (