| """ |
Pure steering and patch-exploration logic — no Bokeh dependencies.
| |
| Provides computation functions called by the panels/steering.py UI layer. |
| Functions here depend on dataset state and brain data, but never on Bokeh |
| widgets, callbacks, or the document event loop. |
| """ |
|
|
| import base64 |
| import io |
| import os |
|
|
| import numpy as np |
|
|
| from .args import args |
| from .state import active_ds |
| from .inference import run_gpu_inference |
| from .rendering import load_image |
| from .brain import ( |
| _dd_loader, |
| phi_voxel_row, phi_cv_shape, feat_display_name, |
| apply_steering_fmri, dynadiff_request, get_dd_fmri, |
| ) |
|
|
| _N_VOXELS_DD = 15724 |
|
|
|
|
| |
|
|
def compute_patch_activations(img_idx: int) -> np.ndarray | None:
    """Return SAE patch activations for one image, using an LRU cache.

    The active dataset's ``inference_cache`` (an OrderedDict-style mapping)
    is consulted first. A hit is marked most-recently-used and returned.
    On a miss the image is loaded and run through ``run_gpu_inference``.
    A non-None result is stored, and the cache is then trimmed to at most
    ``args.inference_cache_size`` entries, evicting the least-recently-used
    entries first.

    Returns:
        (n_patches, d_sae) float32 array, or None if GPU inference is
        unavailable. None results are not cached, so a later call retries.
    """
    ds = active_ds()
    cache = ds['inference_cache']
    if img_idx in cache:
        # Refresh recency so frequently revisited images survive eviction.
        cache.move_to_end(img_idx)
        return cache[img_idx]

    z_np = run_gpu_inference(load_image(img_idx))
    if z_np is not None:
        cache[img_idx] = z_np
        # Loop instead of popping a single entry. The cache then converges
        # back to its limit even if it was already over capacity (e.g. the
        # limit was lowered at runtime). The `cache` guard stops at empty.
        while cache and len(cache) > args.inference_cache_size:
            cache.popitem(last=False)
    return z_np
|
|
|
|
| def get_top_features_for_patches(z: np.ndarray | None, |
| patch_indices: list, |
| top_n: int = 20): |
| """Return (feats, act_sums, freqs, means) for top features across patches.""" |
| if z is None: |
| return [], [], [], [] |
| z_sel = z[patch_indices] |
| feat_sums = z_sel.sum(axis=0) |
| top_feats = np.argsort(-feat_sums)[:top_n] |
| top_feats = top_feats[feat_sums[top_feats] > 0] |
| feats = top_feats.tolist() |
| acts = feat_sums[top_feats].tolist() |
| ds = active_ds() |
| freqs = [int(ds['feature_frequency'][f].item()) for f in feats] |
| means = [float(ds['feature_mean_act'][f].item()) for f in feats] |
| print(f"[patch] {len(patch_indices)} patches β {len(feats)} features, " |
| f"max_sum={feat_sums.max():.4f}") |
| return feats, acts, freqs, means |
|
|
|
|
| |
|
|
def resolve_nsd_basename(img_idx: int) -> str | None:
    """Map a dataset image index to its NSD basename, if it has one.

    The candidate is the file name of ``image_paths[img_idx]`` in the
    active dataset, with its extension stripped.

    Returns:
        That stem (e.g. ``'nsd_00123'``) when it starts with ``'nsd_'``;
        None for any other image.
    """
    image_path = active_ds()['image_paths'][img_idx]
    stem, _ext = os.path.splitext(os.path.basename(image_path))
    if not stem.startswith('nsd_'):
        return None
    return stem
|
|
|
|
| def parse_nsd_img_idx(nsd_basename: str) -> int | None: |
| """Extract integer NSD image index from 'nsd_XXXXX' string.""" |
| if not nsd_basename or not nsd_basename.startswith('nsd_'): |
| return None |
| try: |
| return int(nsd_basename.rsplit('_', 1)[-1]) |
| except ValueError: |
| return None |
|
|
|
|
def load_gt_thumbnail_b64(nsd_img_idx: int, size: int = 160) -> str | None:
    """Return a ground-truth brain thumbnail as a base64-encoded PNG.

    Reads ``nsd_{nsd_img_idx:05d}.jpg`` from the directory named by
    ``args.brain_thumbnails``. The image is converted to RGB, resized to
    ``size`` x ``size`` pixels and re-encoded as PNG.

    Args:
        nsd_img_idx: Integer NSD image index used to build the file name.
        size: Edge length in pixels of the square output. Defaults to 160,
            the previously hard-coded value.

    Returns:
        Base64 PNG string. Returns None when any of these hold:
          * no thumbnail directory is configured;
          * the file does not exist;
          * reading or encoding fails for any reason, including Pillow
            being unavailable.
    """
    thumb_dir = getattr(args, 'brain_thumbnails', None)
    if not thumb_dir:
        return None
    path = os.path.join(thumb_dir, f'nsd_{nsd_img_idx:05d}.jpg')
    if not os.path.isfile(path):
        return None
    try:
        # Lazy import: a missing Pillow is swallowed below like any other
        # failure, so this stays optional.
        from PIL import Image
        # `with` releases the underlying file handle deterministically
        # instead of leaving it to garbage collection. convert() returns
        # a new in-memory image, so it stays valid after the source closes.
        with Image.open(path) as src:
            img = src.convert('RGB').resize((size, size))
        buf = io.BytesIO()
        img.save(buf, format='PNG')
        return base64.b64encode(buf.getvalue()).decode()
    except Exception:
        # Deliberately best-effort: an unreadable thumbnail yields None
        # rather than propagating to the caller.
        return None
|
|
|
|
def load_gt_fmri(nsd_basename: str) -> tuple:
    """Fetch ground-truth fMRI for the first trial of an NSD image.

    Returns:
        ``(sample_idx, fmri_array)`` for the first sample index reported by
        ``_dd_loader.sample_idxs_for_nsd_img``. ``(None, None)`` when the
        basename does not parse, ``_dd_loader`` is None, or the image has no
        sample indices.
    """
    nsd_img_idx = parse_nsd_img_idx(nsd_basename)
    if nsd_img_idx is not None and _dd_loader is not None:
        trial_idxs = _dd_loader.sample_idxs_for_nsd_img(nsd_img_idx)
        if trial_idxs:
            first_trial = trial_idxs[0]
            return first_trial, get_dd_fmri(first_trial)
    return None, None
|
|
|
|
| |
|
|
def compute_steering_direction(feats, lams, thresholds):
    """Blend per-feature phi vectors into one float32 steering direction.

    ``feats``, ``lams`` and ``thresholds`` are consumed in lockstep. For each
    ``(feat, lam, thr)`` triple:
      * fetch ``phi_voxel_row(feat)``. Skip the feature when it is None or
        its largest ``|phi|`` is below 1e-12;
      * scale phi so its largest absolute value becomes 1;
      * if ``thr < 1``, zero every voxel whose ``|phi|`` is below the
        ``100 * (1 - thr)`` percentile. This keeps roughly the top ``thr``
        fraction;
      * add ``lam`` times the result into the running sum.

    Returns:
        ``np.float32`` array of length ``_N_VOXELS_DD``.
    """
    direction = np.zeros(_N_VOXELS_DD, dtype=np.float32)
    for feat, weight, keep_frac in zip(feats, lams, thresholds):
        phi = phi_voxel_row(feat)
        if phi is None:
            continue
        # |phi| is used for the peak, the percentile and the mask; compute once.
        magnitude = np.abs(phi)
        peak = float(magnitude.max())
        if peak < 1e-12:
            continue
        contribution = phi / peak
        if keep_frac < 1.0:
            cutoff = float(np.percentile(magnitude,
                                         100.0 * (1.0 - keep_frac)))
            contribution = contribution * (magnitude >= cutoff)
        direction += weight * contribution
    return direction
|
|
|
|
def build_steerings(feats, lams, thresholds):
    """Build ``[(phi_voxel, lam, threshold), ...]`` tuples for dynadiff_request.

    ``feats``, ``lams`` and ``thresholds`` are consumed in lockstep; as with
    ``zip``, extra items in longer sequences are ignored. Features for
    which ``phi_voxel_row`` returns None are skipped. ``lam`` and
    ``threshold`` are coerced to ``float``.
    """
    steerings = []
    for f, lam, thr in zip(feats, lams, thresholds):
        # Call phi_voxel_row once per feature. Reuse the result for both the
        # None check and the stored tuple instead of fetching the row twice.
        phi = phi_voxel_row(f)
        if phi is not None:
            steerings.append((phi, float(lam), float(thr)))
    return steerings
|
|
|
|
def compute_steered_fmri(gt_fmri, feats, lams, thresholds):
    """Return ``gt_fmri`` perturbed by the requested feature steerings.

    The (feat, lam, threshold) triples are converted by ``build_steerings``.
    The resulting tuples are then applied with ``apply_steering_fmri``.
    """
    return apply_steering_fmri(gt_fmri,
                               build_steerings(feats, lams, thresholds))
|
|
|
|
def validate_feature(feat: int) -> str | None:
    """Check that ``feat`` indexes a row of the phi matrix.

    Returns:
        None when ``phi_cv_shape()`` is available and
        ``0 <= feat < shape[0]``. Otherwise, the error message
        ``'No phi data for feature {feat}.'``.
    """
    shape = phi_cv_shape()
    in_range = shape is not None and 0 <= feat < shape[0]
    if in_range:
        return None
    return f'No phi data for feature {feat}.'
|
|
|
|
def make_steering_entry(feat: int, lam: float = 3.0,
                        threshold: float = 0.10) -> dict:
    """Build the dict describing one steering slot.

    Keys are ``feat``, ``name`` (from ``feat_display_name``), ``lam`` and
    ``threshold``. The values are passed through unchanged.
    """
    display_name = feat_display_name(feat)
    return {'feat': feat, 'name': display_name,
            'lam': lam, 'threshold': threshold}
|
|
|
|
| |
|
|
def validate_reconstruction(nsd_basename, feats, lams, thresholds):
    """Validate inputs before running DynaDiff.

    Checks run in this order and stop at the first failure:
      1. at least one feature is given;
      2. at least one feature has phi data;
      3. the basename starts with ``'nsd_'`` and has a parseable index;
      4. a DynaDiff loader is present;
      5. the image has at least one trial;
      6. every trial index lies in ``[0, n_samples)``, when the loader
         reports ``n_samples``;
      7. the loader status is neither ``'loading'`` nor ``'error'``.

    Returns:
        ``(sample_idxs, steerings, error_msg)``. On success ``error_msg`` is
        None. If ``error_msg`` is not None, the other two values are None.
    """
    if not feats:
        return None, None, 'Add at least one feature first.'

    steerings = build_steerings(feats, lams, thresholds)
    if not steerings:
        return None, None, 'No phi data for selected features.'

    if not nsd_basename or not nsd_basename.startswith('nsd_'):
        return None, None, 'Load an NSD image in the patch explorer first.'

    nsd_img_idx = parse_nsd_img_idx(nsd_basename)
    if nsd_img_idx is None:
        return None, None, 'Could not parse NSD image index.'

    # load_gt_fmri already treats a None loader as "unavailable". Do the
    # same here rather than crashing with AttributeError on the calls below.
    # NOTE(review): `_dd_loader` is bound once at import time by
    # `from .brain import _dd_loader`. If brain rebinds it later, this
    # module keeps the old value; confirm that is intended.
    if _dd_loader is None:
        return None, None, 'DynaDiff loader is not available.'

    sample_idxs = _dd_loader.sample_idxs_for_nsd_img(nsd_img_idx)
    if not sample_idxs:
        return None, None, (f'NSD image {nsd_img_idx} has no trials '
                            f'for this subject.')

    n = _dd_loader.n_samples
    if n is not None and any(not (0 <= s < n) for s in sample_idxs):
        return None, None, f'sample_idx must be 0–{n - 1}.'

    status, err = _dd_loader.status
    if status == 'loading':
        return None, None, 'DynaDiff model still loading — try again shortly.'
    if status == 'error':
        return None, None, f'DynaDiff model load failed: {err}'

    return sample_idxs, steerings, None
|
|
|
|
def run_reconstruction(sample_idxs, steerings, seed=42,
                       nsd_img_idx=None):
    """Run a DynaDiff reconstruction for the first sample index.

    Sends ``sample_idxs[0]``, ``steerings`` and ``seed`` through
    ``dynadiff_request``. The fallback only applies when the response's
    ``'gt_img'`` is None and an NSD index was supplied. In that case a
    shallow copy of the response is returned, with ``'gt_img'`` filled from
    ``load_gt_thumbnail_b64`` (which may itself be None). The response
    object returned by ``dynadiff_request`` is never mutated.

    Raises:
        Anything ``dynadiff_request`` raises; nothing is caught here.
    """
    response = dynadiff_request(sample_idxs[0], steerings, seed)
    needs_thumbnail = (response.get('gt_img') is None
                       and nsd_img_idx is not None)
    if not needs_thumbnail:
        return response
    patched = dict(response)
    patched['gt_img'] = load_gt_thumbnail_b64(nsd_img_idx)
    return patched
|
|