""" DynaDiff in-process loader. Loads the DynaDiff model and exposes a reconstruct() method that returns the same dict format as the HTTP server's /reconstruct endpoint: { "baseline_img": "", "steered_img": "", "gt_img": " | None", "beta_std": float, } Usage (in explorer_app.py): from dynadiff_loader import DynaDiffLoader loader = DynaDiffLoader(dynadiff_dir, checkpoint, h5_path, nsd_thumb_dir) loader.start() # begins background model load loader.n_samples # None until ready loader.is_ready # True when model is loaded result = loader.reconstruct(sample_idx, steerings, seed) """ import base64 import io import logging import os import threading import numpy as np logging.basicConfig( level=logging.INFO, format='[DynaDiff %(levelname)s %(asctime)s] %(message)s', datefmt='%H:%M:%S', ) log = logging.getLogger(__name__) N_VOXELS = 15724 # ── Process-level singleton ─────────────────────────────────────────────────── # Bokeh re-executes the app script per session, so DynaDiffLoader would be # instantiated multiple times. We keep one loader alive for the whole process # so the model is loaded exactly once and all sessions share it. 
_singleton: "DynaDiffLoader | None" = None
_singleton_lock = threading.Lock()

# Key prefixes removed from checkpoint state dicts before load_state_dict.
# The names suggest evaluation-metric modules (FID, LPIPS, ...) saved next to
# the model weights; the loader never loads them.
_CKPT_DROP_PREFIXES = ('eval_fid', 'eval_inceptionlastconv', 'eval_eff',
                       'eval_swav', 'eval_lpips')


def get_loader(dynadiff_dir, checkpoint, h5_path, nsd_thumb_dir=None,
               subject_idx=0) -> "DynaDiffLoader":
    """Return the process-level loader, creating and starting it if needed.

    Only the first call's arguments take effect: later calls return the same
    instance whatever they pass. A loader whose background load failed is
    returned as-is (inspect ``status``); it is not retried.
    """
    global _singleton
    with _singleton_lock:
        if _singleton is None:
            _singleton = DynaDiffLoader(
                dynadiff_dir, checkpoint, h5_path, nsd_thumb_dir, subject_idx)
            _singleton.start()
        return _singleton


def _img_to_b64(img_np):
    """Encode an (H, W, 3) float image in [0, 1] as a base64 PNG string.

    Values outside [0, 1] are clipped before encoding.
    """
    # pyplot.imsave simply delegates to matplotlib.image.imsave, which writes
    # PNGs without needing a GUI backend. Calling it directly avoids
    # matplotlib.use('Agg'), which changes the process-wide backend (and,
    # per matplotlib docs, closes open figures when pyplot was already
    # imported with a different backend) on every call.
    from matplotlib import image as mpimg

    buf = io.BytesIO()
    mpimg.imsave(buf, np.clip(img_np, 0, 1), format='png')
    return base64.b64encode(buf.getvalue()).decode('utf-8')


class DynaDiffLoader:
    """Load a DynaDiff model in a background thread and run reconstructions.

    State written by the loader thread (model, config, beta_std, sample
    index maps, status) is only read or written while holding ``_lock``.
    """

    def __init__(self, dynadiff_dir, checkpoint, h5_path, nsd_thumb_dir=None,
                 subject_idx=0):
        """Store paths only; no I/O happens until ``start()``.

        Args:
            dynadiff_dir: root of the dynadiff checkout (made absolute).
            checkpoint: ``.pth`` file or ZeRO checkpoint directory; relative
                names are resolved against ``dynadiff_dir`` at load time.
            h5_path: H5 file with ``fmri``, ``subject_idx`` and ``image_idx``
                datasets; relative paths are joined to ``dynadiff_dir``.
            nsd_thumb_dir: optional directory of ``nsd_XXXXX.jpg`` thumbnails.
            subject_idx: subject whose rows of the H5 file are exposed.
        """
        self.dynadiff_dir = os.path.abspath(dynadiff_dir)
        self.checkpoint = checkpoint
        self.h5_path = h5_path if os.path.isabs(h5_path) \
            else os.path.join(self.dynadiff_dir, h5_path)
        self.nsd_thumb_dir = nsd_thumb_dir
        self.subject_idx = subject_idx

        self._model = None
        self._cfg = None
        self._beta_std = None
        self._subject_sample_indices = None  # sample_idx -> H5 row
        self._nsd_to_sample = {}             # NSD image idx -> [sample_idx]
        self._status = 'loading'             # 'loading' | 'ok' | 'error'
        self._error = ''
        self._lock = threading.Lock()

    # ── public properties ────────────────────────────────────────────────────

    @property
    def is_ready(self):
        """True once the background load has finished successfully."""
        with self._lock:
            return self._status == 'ok'

    @property
    def status(self):
        """``(status, error_message)``; status is 'loading', 'ok' or 'error'."""
        with self._lock:
            return self._status, self._error

    @property
    def n_samples(self):
        """Number of samples for this subject, or None while loading."""
        with self._lock:
            idx = self._subject_sample_indices
            return len(idx) if idx is not None else None

    def sample_idxs_for_nsd_img(self, nsd_img_idx):
        """Return the list of sample_idx values that correspond to a given
        NSD image index.

        Returns an empty list if the image has no trials for this subject or
        the mapping is not yet built (model still loading). The returned list
        is a copy, so callers may mutate it freely.
        """
        with self._lock:
            return list(self._nsd_to_sample.get(int(nsd_img_idx), []))

    def start(self):
        """Start loading the model in a daemon background thread."""
        threading.Thread(target=self._load, daemon=True).start()

    # ── model loading ────────────────────────────────────────────────────────

    def _load(self):
        """Background-thread body: load everything, then publish atomically.

        On success ``status`` becomes 'ok'; on any exception it becomes
        'error' and the message is stored for ``status``/``reconstruct``.
        """
        # Read cwd before anything that can raise. It used to be read after
        # the imports, so an import failure made the ``finally`` below raise
        # UnboundLocalError on top of the real error.
        orig_dir = os.getcwd()
        try:
            self._prepare_imports()
            cfg = self._build_config()
            model = self._build_model(cfg)
            beta_std, sample_indices, nsd_to_sample = self._read_h5_metadata()

            with self._lock:
                self._model = model
                self._cfg = cfg
                self._beta_std = beta_std
                self._subject_sample_indices = sample_indices
                self._nsd_to_sample = nsd_to_sample
                self._status = 'ok'
            log.info('[DynaDiff] Ready.')
        except Exception as exc:
            log.exception('[DynaDiff] Model loading failed')
            with self._lock:
                self._status = 'error'
                self._error = str(exc)
        finally:
            # Restore the cwd seen at start in case anything during loading
            # changed it.
            os.chdir(orig_dir)

    def _prepare_imports(self):
        """Put dynadiff's sources on sys.path and pre-import torchvision."""
        import sys

        # Imported before sys.path is modified so they resolve against the
        # installed packages, not dynadiff's directories.
        import torch  # noqa: F401
        import h5py  # noqa: F401

        # Inject dynadiff paths before any imports from those packages
        dynadiff_diffusers = os.path.join(self.dynadiff_dir, 'diffusers', 'src')
        for path in (self.dynadiff_dir, dynadiff_diffusers):
            if path not in sys.path:
                sys.path.insert(0, path)

        # Pre-import torchvision so it is fully initialised before dynadiff's
        # diffusers fork pulls it in. Without this, torchvision.transforms can
        # end up in a partially-initialised state, causing
        # "cannot import name 'InterpolationMode' from partially initialized
        # module 'torchvision.transforms'".
        import torchvision.transforms  # noqa: F401
        import torchvision.transforms.functional  # noqa: F401

    def _build_config(self):
        """Load dynadiff's config.yaml and override it with absolute paths."""
        # Bokeh's code_runner does os.chdir(original_cwd) in its finally
        # block after every session's app script, so we cannot rely on cwd
        # being stable across the slow imports. Build the config entirely
        # from absolute paths so no cwd dependency exists.
        vd_cache = os.path.join(self.dynadiff_dir, 'versatile_diffusion')
        cache_dir = os.path.join(self.dynadiff_dir, 'cache')
        local_infra = {'cluster': None, 'folder': cache_dir}

        print('[DynaDiff] importing dynadiff modules...', flush=True)
        from exca import ConfDict
        print('[DynaDiff] exca imported', flush=True)

        cfg_yaml = os.path.join(self.dynadiff_dir, 'config', 'config.yaml')
        with open(cfg_yaml, 'r') as f:
            cfg = ConfDict.from_yaml(f)
        cfg['versatilediffusion_config.vd_cache_dir'] = vd_cache
        cfg['seed'] = 42
        cfg['data.nsd_dataset_config.seed'] = 42
        cfg['data.nsd_dataset_config.averaged'] = False
        # NOTE(review): hard-coded to subject 0 regardless of
        # self.subject_idx; the config is only stored, not used for data
        # loading here — confirm this is intended.
        cfg['data.nsd_dataset_config.subject_ids'] = [0]
        cfg['infra'] = local_infra
        cfg['data.nsd_dataset_config.infra'] = local_infra
        cfg['image_generation_infra'] = local_infra
        print('[DynaDiff] config loaded', flush=True)
        return cfg

    def _build_model(self, cfg):
        """Construct VersatileDiffusion, load checkpoint weights, move to device."""
        import torch
        from model.models import VersatileDiffusion, VersatileDiffusionConfig
        print('[DynaDiff] model.models imported', flush=True)

        vd_config = VersatileDiffusionConfig(**cfg['versatilediffusion_config'])
        print('[DynaDiff] VersatileDiffusionConfig built', flush=True)

        # Resolve before constructing the (large) model so a bad relative
        # name fails fast.
        ckpt = self._resolve_checkpoint()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        model = VersatileDiffusion(config=vd_config,
                                   brain_n_in_channels=N_VOXELS,
                                   brain_temp_dim=6)
        state_dict = self._clean_state_dict(self._read_checkpoint(ckpt))
        # strict=False tolerates key mismatches; report them instead of
        # hiding them, so a wrong checkpoint is visible in the log.
        result = model.load_state_dict(state_dict, strict=False)
        missing = getattr(result, 'missing_keys', None) or []
        unexpected = getattr(result, 'unexpected_keys', None) or []
        if missing or unexpected:
            log.info(f'[DynaDiff] load_state_dict: {len(missing)} missing, '
                     f'{len(unexpected)} unexpected keys')

        model.sanity_check_blurry = False
        model.to(device)
        model.eval()
        log.info(f'[DynaDiff] Model loaded on {device}')
        return model

    def _resolve_checkpoint(self):
        """Return the checkpoint path, resolving relative names.

        A relative name is tried as a file under ``dynadiff_dir``, then as a
        directory under ``dynadiff_dir/training_checkpoints``. Absolute paths
        are returned unchanged (existence is checked when reading).

        Raises:
            FileNotFoundError: no relative candidate exists.
        """
        ckpt = self.checkpoint
        if os.path.isabs(ckpt):
            return ckpt
        candidate_pth = os.path.join(self.dynadiff_dir, ckpt)
        candidate_ckpt = os.path.join(self.dynadiff_dir,
                                      'training_checkpoints', ckpt)
        if os.path.isfile(candidate_pth):
            return candidate_pth
        if os.path.isdir(candidate_ckpt):
            return candidate_ckpt
        raise FileNotFoundError(
            f'Checkpoint not found: tried {candidate_pth} '
            f'and {candidate_ckpt}')

    @staticmethod
    def _read_checkpoint(ckpt):
        """Read a raw state dict from a ``.pth`` file or a ZeRO directory.

        Raises:
            FileNotFoundError: ``ckpt`` is neither a file nor a directory.
        """
        import torch

        if os.path.isfile(ckpt):
            log.info(f'[DynaDiff] Loading state dict from {ckpt} ...')
            # NOTE: weights_only=False unpickles arbitrary objects — only
            # point this at trusted checkpoints.
            # Map to CPU: load_state_dict copies into the model's existing
            # parameters and the model is moved to the device afterwards, so
            # mapping straight to the GPU only adds a temporary extra copy.
            return torch.load(ckpt, map_location='cpu', weights_only=False)
        if os.path.isdir(ckpt):
            import deepspeed
            log.info(f'[DynaDiff] Consolidating ZeRO checkpoint from {ckpt} ...')
            return deepspeed.utils.zero_to_fp32 \
                .get_fp32_state_dict_from_zero_checkpoint(
                    checkpoint_dir=ckpt, tag='checkpoint',
                    exclude_frozen_parameters=False)
        raise FileNotFoundError(f'Checkpoint not found: {ckpt}')

    @staticmethod
    def _clean_state_dict(state_dict):
        """Strip a leading 'model.' from keys and drop ``_CKPT_DROP_PREFIXES``.

        Prefix stripping happens first, so 'model.eval_fid...' is dropped too.
        """
        cleaned = {}
        for key, value in state_dict.items():
            if key.startswith('model.'):
                key = key[len('model.'):]
            if not key.startswith(_CKPT_DROP_PREFIXES):
                cleaned[key] = value
        return cleaned

    def _read_h5_metadata(self):
        """Read beta_std and this subject's sample index maps from the H5 file.

        Returns:
            (beta_std, sample_indices, nsd_to_sample) where ``sample_indices``
            maps sample_idx -> H5 row and ``nsd_to_sample`` maps an NSD image
            index to the list of sample_idx values showing it.
        """
        import h5py

        log.info(f'[DynaDiff] Computing beta_std from {self.h5_path} ...')
        with h5py.File(self.h5_path, 'r') as hf:
            fmri = hf['fmri']
            # NOTE(review): uses the first 300 rows of the whole file, not
            # only this subject's rows — confirm that is intended.
            n = min(300, fmri.shape[0])
            beta_std = float(np.array(fmri[:n]).std(axis=0).mean())
            log.info(f'[DynaDiff] beta_std = {beta_std:.5f}')

            log.info(f'[DynaDiff] Building sample index for subject '
                     f'{self.subject_idx} ...')
            all_subj = np.array(hf['subject_idx'][:], dtype=np.int64)
            all_imgidx = np.array(hf['image_idx'][:], dtype=np.int64)

        sample_indices = np.where(all_subj == self.subject_idx)[0].astype(np.int64)
        log.info(f'[DynaDiff] {len(sample_indices)} samples for subject '
                 f'{self.subject_idx}')

        # Reverse map: NSD image index → list of sample_idx values
        nsd_to_sample = {}
        for sample_idx_val, nsd_img in enumerate(all_imgidx[sample_indices].tolist()):
            nsd_to_sample.setdefault(int(nsd_img), []).append(sample_idx_val)
        return beta_std, sample_indices, nsd_to_sample

    # ── inference ────────────────────────────────────────────────────────────

    def reconstruct(self, sample_idx, steerings, seed=42):
        """Decode one sample both unsteered and steered.

        Args:
            sample_idx: index into this subject's samples
                (``0 <= sample_idx < n_samples``).
            steerings: iterable of ``(phi_voxel, lam, threshold)`` tuples
                applied in order; ``phi_voxel`` is a numpy array with one
                weight per voxel. ``None`` means no steering.
            seed: diffusion seed, used for both decodes.

        Returns:
            dict with ``baseline_img`` and ``steered_img`` (base64 PNGs),
            ``gt_img`` (base64 PNG or None) and ``beta_std`` (float).

        Raises:
            RuntimeError: the model is still loading or failed to load.
            IndexError: ``sample_idx`` is out of range.
        """
        import torch

        with self._lock:
            model = self._model
            beta_std = self._beta_std
            indices = self._subject_sample_indices
            status, error = self._status, self._error

        if model is None:
            if status == 'error':
                raise RuntimeError(f'Model failed to load: {error}')
            raise RuntimeError('Model not loaded yet')

        # Map sample_idx → h5 row
        if indices is not None:
            if not (0 <= sample_idx < len(indices)):
                raise IndexError(
                    f'sample_idx {sample_idx} out of range '
                    f'(subject has {len(indices)} samples)')
            h5_row = int(indices[sample_idx])
        else:
            h5_row = sample_idx

        import h5py
        with h5py.File(self.h5_path, 'r') as hf:
            fmri = torch.from_numpy(
                np.array(hf['fmri'][h5_row], dtype=np.float32)).unsqueeze(0)
            img_idx = int(hf['image_idx'][h5_row])

        params = next(model.parameters())
        device, dtype = params.device, params.dtype

        # Apply steering perturbations cumulatively
        if steerings is None:
            steerings = ()
        steered_fmri = fmri.clone()
        for phi_voxel, lam, threshold in steerings:
            steered_fmri = self._apply_steering(
                steered_fmri, phi_voxel, lam, beta_std, threshold, device)

        # NOTE(review): no lock is held during inference, so concurrent
        # sessions run the shared model in parallel — confirm the pipeline
        # tolerates that.
        baseline = self._decode(model, fmri, device, dtype, seed)
        steered = self._decode(model, steered_fmri, device, dtype, seed)
        gt_img = self._load_gt_image(img_idx)

        return {
            'baseline_img': _img_to_b64(baseline),
            'steered_img': _img_to_b64(steered),
            'gt_img': _img_to_b64(gt_img) if gt_img is not None else None,
            'beta_std': float(beta_std),
        }

    @staticmethod
    def _apply_steering(fmri_tensor, phi_voxel, lam, beta_std, threshold,
                        device):
        """Return a copy of ``fmri_tensor`` with ``lam``-scaled ``phi_voxel`` added.

        ``phi_voxel`` is rescaled so its largest magnitude equals ``beta_std``.
        If ``threshold < 1``, only the top ``threshold`` fraction of voxels
        by ``|phi_voxel|`` are perturbed. The input tensor is not modified.
        For a 3-D tensor the perturbation is added to every entry of the last
        axis; only batch element 0 is steered.
        """
        import torch

        if lam == 0.0:
            return fmri_tensor.clone()

        phi_np = np.asarray(phi_voxel)
        steered = fmri_tensor.clone().to(device=device)
        phi_t = torch.from_numpy(phi_np).to(dtype=steered.dtype, device=device)

        phi_max = phi_t.abs().max().item()
        scale = (beta_std / phi_max) if phi_max > 1e-12 else 1.0
        perturbation = lam * scale * phi_t

        if threshold < 1.0:
            cutoff = float(np.percentile(np.abs(phi_np), 100 * (1 - threshold)))
            keep = torch.from_numpy(np.abs(phi_np) >= cutoff).to(device)
            perturbation = perturbation.masked_fill(~keep, 0.0)
        # threshold >= 1.0 keeps every voxel; no mask is needed, which also
        # removes the old dependency on phi_voxel having exactly N_VOXELS.

        if steered.dim() == 3:
            steered[0, :, :] += perturbation.unsqueeze(-1)
        else:
            steered[0, :] += perturbation
        return steered

    @staticmethod
    def _decode(model, fmri_tensor, device, dtype, seed=42,
                guidance_scale=3.5, img2img_strength=0.85):
        """Run the diffusion decoder on one fMRI sample.

        Returns an (H, W, C) float numpy image clipped to [0, 1].
        """
        # Local import keeps torch out of module import time, like the rest
        # of this file (the old decorator imported it at class definition).
        import torch

        with torch.no_grad():
            encoding = model.get_condition(
                fmri_tensor.to(device=device, dtype=dtype),
                # NOTE(review): presumably a subject index; always 0 here.
                torch.tensor([0], device=device),
            )
            output = model.reconstruction_from_clipbrainimage(
                encoding, seed=seed, guidance_scale=guidance_scale,
                img2img_strength=img2img_strength)
            recon = output.image[0].cpu().float().permute(1, 2, 0).numpy()
        return np.clip(recon, 0, 1)

    def _load_gt_image(self, image_idx):
        """Load the ground-truth stimulus as float32 in [0, 1], or None.

        Tries ``<nsd_thumb_dir>/nsd_<idx:05d>.jpg`` first, then the
        ``images`` dataset of ``processed_nsd_data/train_unaveraged.h5``.
        Read errors are logged, not raised.
        """
        if self.nsd_thumb_dir:
            thumb = os.path.join(self.nsd_thumb_dir, f'nsd_{image_idx:05d}.jpg')
            try:
                from PIL import Image as _PIL
                # Context manager closes the file handle deterministically.
                with _PIL.open(thumb) as im:
                    return np.array(im.convert('RGB'), dtype=np.float32) / 255.0
            except Exception as e:
                log.warning(f'[DynaDiff] thumb load failed ({thumb}): {e}')

        # H5 fallback — only works if train_unaveraged.h5 is present.
        # NOTE(review): assumes ``images`` is indexed by NSD image index and
        # stored HWC in [0, 1], as _img_to_b64 expects — confirm.
        try:
            import h5py
            train_h5 = os.path.join(self.dynadiff_dir, 'processed_nsd_data',
                                    'train_unaveraged.h5')
            if not os.path.exists(train_h5):
                return None
            with h5py.File(train_h5, 'r') as hf:
                img = np.array(hf['images'][image_idx], dtype=np.float32)
            return np.clip(img, 0, 1)
        except Exception as e:
            log.warning(f'[DynaDiff] GT image load failed (idx={image_idx}): {e}')
            return None