Spaces:
Sleeping
Sleeping
| """ | |
| Dataset loading for the SAE Feature Explorer. | |
| Both regular explorer_data.pt files and brain_meis.pt files are loaded | |
| through a single function that produces the same dict schema. Derived | |
| numpy arrays (freq, log_freq, live_mask, umap_backup, etc.) are | |
| pre-computed at load time so callbacks never recompute them. | |
| """ | |
| import json | |
| import os | |
| from collections import OrderedDict | |
| import numpy as np | |
| import torch | |
| from .args import args | |
| from .state import _all_datasets | |
| # ---------- Helpers ---------- | |
| def _build_basename_index(paths: list) -> dict: | |
| """Map both full basename and stem → image index for fast filename lookup.""" | |
| idx = {} | |
| for i, p in enumerate(paths): | |
| base = os.path.basename(p) | |
| stem = os.path.splitext(base)[0] | |
| idx[base] = i | |
| idx[stem] = i | |
| return idx | |
| # ---------- Core loader ---------- | |
def _read_int_keyed_json(json_path: str) -> dict | None:
    """Read a JSON-object sidecar and coerce its keys to ``int``.

    Returns ``None`` when *json_path* does not exist, so callers can tell a
    missing sidecar apart from an empty one. The file is decoded as UTF-8
    (the encoding JSON is specified in) rather than the platform default,
    so non-ASCII labels load the same way on every OS.
    """
    if not os.path.exists(json_path):
        return None
    with open(json_path, encoding='utf-8') as f:
        return {int(k): v for k, v in json.load(f).items()}


def _load_int_keyed_sidecar(sidecar_path: str, kind: str, noun: str) -> dict:
    """Load an optional ``.pt`` sidecar dict and coerce its keys to ``int``.

    Returns ``{}`` when *sidecar_path* does not exist. *kind* and *noun* are
    used only in the two progress messages.
    """
    if not os.path.exists(sidecar_path):
        return {}
    print(f" Loading {kind} sidecar {os.path.basename(sidecar_path)} ...")
    # weights_only=False allows arbitrary unpickling; only point this at
    # locally produced sidecar files.
    raw = torch.load(sidecar_path, map_location='cpu', weights_only=False)
    cache = {int(k): v for k, v in raw.items()}
    print(f" Cached {len(cache)} {noun}")
    return cache


def _load_dataset(path: str, label: str, *,
                  sae_url: str | None = None,
                  is_brain: bool = False,
                  thumb_dir: str = "") -> dict | None:
    """
    Load one dataset file (.pt) and return a fully-populated dict.
    Both explorer_data.pt and brain_meis.pt files are handled here.
    The returned dict includes:
    - all raw tensors from the file
    - per-dataset derived arrays (freq, log_freq, live_mask, umap_backup, …)
    - feature names and auto-interp labels read from JSON sidecars
    - pre-computed heatmaps read from the _heatmaps.pt sidecar if present
    - phi-map / cortical-profile / GT-brain render caches read from their
      .pt sidecars if present (empty dicts otherwise)

    Returns None (after printing a warning) if the main file cannot be
    loaded. A missing sidecar is not an error: its entries are set to None
    (heatmaps) or {} (render caches).
    """
    print(f"Loading [{label}] from {path} ...")
    try:
        # weights_only=False unpickles arbitrary objects; the path comes from
        # the CLI and is assumed to be a trusted, locally produced file.
        d = torch.load(path, map_location='cpu', weights_only=False)
    except Exception as err:
        print(f" WARNING: failed to load {path}: {err}")
        return None

    # Resolve image paths for brain datasets where paths may be stored as
    # basenames. Only the first path is probed: if it is relative or absent on
    # this machine, every path is re-rooted under thumb_dir by basename.
    raw_paths = d.get('image_paths', [])
    if is_brain and raw_paths and thumb_dir and (
        not os.path.isabs(raw_paths[0]) or not os.path.exists(raw_paths[0])
    ):
        image_paths = [os.path.join(thumb_dir, os.path.basename(p)) for p in raw_paths]
    else:
        image_paths = raw_paths

    d_model = d['d_model']
    # Stand-in for missing per-feature 2-D coordinates: all-NaN, so every
    # feature counts as "not live" in the masks below.
    nan2 = np.full((d_model, 2), np.nan, dtype=np.float32)
    stem = os.path.splitext(path)[0]

    # Feature names (manual labels). The CLI --names-file override applies only
    # to the primary dataset; every other dataset uses its own sidecar.
    names_file = (args.names_file if (path == args.data and args.names_file)
                  else stem + '_feature_names.json')
    feature_names = _read_int_keyed_json(names_file) or {}

    # Auto-interp labels
    auto_interp_file = stem + '_auto_interp.json'
    auto_interp_names = _read_int_keyed_json(auto_interp_file)
    if auto_interp_names is None:
        auto_interp_names = {}
    else:
        print(f" Loaded {len(auto_interp_names)} auto-interp labels")

    # Core tensors
    feature_frequency = d['feature_frequency']
    feature_mean_act = d['feature_mean_act']
    umap_coords = d['umap_coords'].numpy()
    # Treat an explicit None the same as an absent key instead of crashing.
    raw_dict_umap = d.get('dict_umap_coords')
    dict_umap_coords = raw_dict_umap.numpy() if raw_dict_umap is not None else nan2

    # Derived arrays — computed once, stored in the dict
    freq = feature_frequency.numpy()
    mean_act = feature_mean_act.numpy()
    log_freq = np.log10(freq + 1)
    # A feature is "live" when its first UMAP coordinate is not NaN.
    live_mask = ~np.isnan(umap_coords[:, 0])
    live_indices = np.where(live_mask)[0]
    dict_live_mask = ~np.isnan(dict_umap_coords[:, 0])
    dict_live_indices = np.where(dict_live_mask)[0]
    # Indices of features with non-zero frequency, as plain Python ints.
    active_feats = np.flatnonzero(freq > 0).tolist()

    # Plain-list copies of the live UMAP points.
    umap_backup = dict(
        act_x=umap_coords[live_mask, 0].tolist(),
        act_y=umap_coords[live_mask, 1].tolist(),
        act_feat=live_indices.tolist(),
        dict_x=dict_umap_coords[dict_live_mask, 0].tolist(),
        dict_y=dict_umap_coords[dict_live_mask, 1].tolist(),
        dict_feat=dict_live_indices.tolist(),
    )

    entry = {
        'label': label,
        'path': path,
        'image_paths': image_paths,
        'basename_index': _build_basename_index(image_paths),
        'd_model': d_model,
        'n_images': d.get('n_images', len(image_paths)),
        'patch_grid': d.get('patch_grid', 16),
        'image_size': d.get('image_size', 224),
        'backbone': d.get('backbone', 'dinov2'),
        'top_img_idx': d['top_img_idx'],
        'top_img_act': d['top_img_act'],
        # Fall back to the top-activation rankings when mean rankings are absent.
        'mean_img_idx': d.get('mean_img_idx', d['top_img_idx']),
        'mean_img_act': d.get('mean_img_act', d['top_img_act']),
        'nsd_top_img_idx': d.get('nsd_top_img_idx'),
        'nsd_top_img_act': d.get('nsd_top_img_act'),
        'nsd_mean_img_idx': d.get('nsd_mean_img_idx'),
        'nsd_mean_img_act': d.get('nsd_mean_img_act'),
        'feature_frequency': feature_frequency,
        'feature_mean_act': feature_mean_act,
        'umap_coords': umap_coords,
        'dict_umap_coords': dict_umap_coords,
        'clip_embeds': d.get('clip_feature_embeds'),
        'nsd_clip_embeds': d.get('nsd_clip_feature_embeds'),
        'sae_url': sae_url,
        'inference_cache': OrderedDict(),
        'names_file': names_file,
        'auto_interp_file': auto_interp_file,
        'feature_names': feature_names,
        'auto_interp_names': auto_interp_names,
        # Pre-computed derived arrays
        'freq': freq,
        'mean_act': mean_act,
        'log_freq': log_freq,
        'live_mask': live_mask,
        'live_indices': live_indices,
        'dict_live_mask': dict_live_mask,
        'dict_live_indices': dict_live_indices,
        'active_feats': active_feats,
        'umap_backup': umap_backup,
    }

    # Brain MEI sidecar (disabled — re-enable after running precompute_brain_response_meis.py)
    # brain_sidecar = stem + '_brain_meis.pt'
    # if os.path.exists(brain_sidecar):
    #     print(f" Loading brain MEI sidecar {os.path.basename(brain_sidecar)} ...")
    #     bm = torch.load(brain_sidecar, map_location='cpu', weights_only=False)
    #     entry['brain_top_img_idx'] = bm.get('brain_top_img_idx')
    #     entry['brain_top_img_act'] = bm.get('brain_top_img_act')
    # else:
    #     entry['brain_top_img_idx'] = None
    #     entry['brain_top_img_act'] = None

    # Heatmaps sidecar. When it is absent, `hm` stays empty, so every heatmap
    # key becomes None and the patch grid falls back to the main file's value.
    sidecar = stem + '_heatmaps.pt'
    if os.path.exists(sidecar):
        print(f" Loading heatmaps sidecar {os.path.basename(sidecar)} ...")
        # NOTE(review): non-brain heatmap sidecars are loaded with
        # weights_only=True, brain ones with full unpickling — presumably
        # because the brain sidecar stores non-tensor objects; confirm.
        hm = torch.load(sidecar, map_location='cpu', weights_only=not is_brain)
        has_hm = 'yes'
    else:
        hm = {}
        has_hm = 'no'
    for key in ('top_heatmaps', 'mean_heatmaps',
                'nsd_top_heatmaps', 'nsd_mean_heatmaps'):
        entry[key] = hm.get(key)
    entry['heatmap_patch_grid'] = hm.get('patch_grid', d.get('patch_grid', 16))

    # Brain render sidecar (precomputed compact phi map PNGs)
    entry['phi_map_cache'] = _load_int_keyed_sidecar(
        stem + '_brain_renders.pt', 'brain render', 'phi map renders')
    # Cortical profile sidecar (precomputed 4-view cortical profile PNGs)
    entry['cortical_profile_cache'] = _load_int_keyed_sidecar(
        stem + '_cortical_profiles.pt', 'cortical profile', 'cortical profiles')
    # GT brain render sidecar (precomputed fMRI response PNGs per NSD image)
    entry['gt_brain_cache'] = _load_int_keyed_sidecar(
        stem + '_gt_brain_renders.pt', 'GT brain render', 'GT brain renders')

    has_clip = 'yes' if entry['clip_embeds'] is not None else 'no'
    print(f" d={d_model}, n={entry['n_images']}, backbone={entry['backbone']}, "
          f"clip={has_clip}, heatmaps={has_hm}")
    return entry
| # ---------- Public entry point ---------- | |
def load_all_datasets():
    """Load all datasets specified by CLI args into ``_all_datasets``.

    The primary dataset (``args.data``) is mandatory. If it cannot be loaded,
    a RuntimeError is raised rather than appending ``None``, which would only
    surface later as a confusing failure in whichever callback first indexes
    the dataset. The NSD brain dataset (``args.brain_data``) is optional: a
    missing file or a failed load prints a warning and is skipped.

    Raises:
        RuntimeError: if the primary dataset file fails to load.
    """
    # Primary dataset (always required)
    primary = _load_dataset(args.data, args.primary_label, sae_url=args.sae_url)
    if primary is None:
        # _load_dataset has already printed the underlying exception.
        raise RuntimeError(f"failed to load primary dataset: {args.data}")
    _all_datasets.append(primary)

    # Optional NSD brain dataset
    if not args.brain_data:
        return
    if not os.path.exists(args.brain_data):
        print(f"[Brain] WARNING: --brain-data not found: {args.brain_data}")
        return
    entry = _load_dataset(
        args.brain_data, args.brain_label,
        is_brain=True,
        thumb_dir=args.brain_thumbnails or '',
    )
    if entry is not None:
        _all_datasets.append(entry)