"""
DynaDiff in-process loader.
Loads the DynaDiff model and exposes a reconstruct() method that returns the
same dict format as the HTTP server's /reconstruct endpoint:
{
"baseline_img": "<base64 PNG>",
"steered_img": "<base64 PNG>",
"gt_img": "<base64 PNG> | None",
"beta_std": float,
}
Usage (in explorer_app.py):
from dynadiff_loader import DynaDiffLoader
loader = DynaDiffLoader(dynadiff_dir, checkpoint, h5_path, nsd_thumb_dir)
loader.start() # begins background model load
loader.n_samples # None until ready
loader.is_ready # True when model is loaded
result = loader.reconstruct(sample_idx, steerings, seed)
"""
import base64
import io
import logging
import os
import threading
import numpy as np
# Module-level logging: timestamped INFO messages so model-load progress is
# visible in the server console.
logging.basicConfig(
    level=logging.INFO,
    format='[DynaDiff %(levelname)s %(asctime)s] %(message)s',
    datefmt='%H:%M:%S',
)
log = logging.getLogger(__name__)

# Voxel count of one fMRI sample; passed to the model as brain_n_in_channels
# and used as the steering mask length in _apply_steering.
N_VOXELS = 15724
# -- Process-level singleton --------------------------------------------------
# Bokeh re-executes the app script per session, so DynaDiffLoader would be
# instantiated multiple times. We keep one loader alive for the whole process
# so the model is loaded exactly once and all sessions share it.
_singleton: "DynaDiffLoader | None" = None
# Guards creation of _singleton across concurrent Bokeh sessions.
_singleton_lock = threading.Lock()
def get_loader(dynadiff_dir, checkpoint, h5_path,
               nsd_thumb_dir=None, subject_idx=0) -> "DynaDiffLoader":
    """Return the shared process-wide loader, creating and starting it lazily.

    The first call constructs the loader and kicks off its background model
    load; every later call (from any Bokeh session) gets the same instance.
    Arguments are only honoured on the first call.
    """
    global _singleton
    with _singleton_lock:
        if _singleton is None:
            loader = DynaDiffLoader(
                dynadiff_dir, checkpoint, h5_path, nsd_thumb_dir, subject_idx)
            loader.start()
            _singleton = loader
        return _singleton
def _img_to_b64(img_np):
    """Encode an (H, W, 3) float32 image in [0, 1] as a base64 PNG string.

    Uses Pillow (already a dependency of this module — see
    DynaDiffLoader._load_gt_image) rather than matplotlib: ``plt.imsave``
    pulled in the entire matplotlib stack and mutated the process-global
    backend (``matplotlib.use('Agg')``) on every call just to write a PNG.
    """
    from PIL import Image
    arr = (np.clip(img_np, 0, 1) * 255).round().astype(np.uint8)
    buf = io.BytesIO()
    Image.fromarray(arr).save(buf, format='PNG')
    return base64.b64encode(buf.getvalue()).decode('utf-8')
class DynaDiffLoader:
    """Loads the DynaDiff model in a background thread and serves reconstructions.

    Lifecycle: construct -> start() -> poll is_ready/status -> reconstruct().
    All state shared with the loader thread is published/read under self._lock.
    """

    def __init__(self, dynadiff_dir, checkpoint, h5_path,
                 nsd_thumb_dir=None, subject_idx=0):
        # Resolve paths up front; h5_path may be relative to dynadiff_dir.
        self.dynadiff_dir = os.path.abspath(dynadiff_dir)
        self.checkpoint = checkpoint
        self.h5_path = h5_path if os.path.isabs(h5_path) \
            else os.path.join(self.dynadiff_dir, h5_path)
        self.nsd_thumb_dir = nsd_thumb_dir
        self.subject_idx = subject_idx
        # Populated by the background loader thread (_load).
        self._model = None
        self._cfg = None
        self._beta_std = None
        self._subject_sample_indices = None  # h5 rows belonging to subject_idx
        self._nsd_to_sample = {}             # NSD image idx -> [sample_idx, ...]
        self._status = 'loading'  # 'loading' | 'ok' | 'error'
        self._error = ''
        self._lock = threading.Lock()
        self._load_started = False  # makes start() idempotent

    # -- public properties ---------------------------------------------------

    @property
    def is_ready(self):
        """True once the model is loaded and reconstruct() may be called."""
        with self._lock:
            return self._status == 'ok'

    @property
    def status(self):
        """(status, error) tuple; error is '' unless status == 'error'."""
        with self._lock:
            return self._status, self._error

    @property
    def n_samples(self):
        """Number of samples for this subject, or None while still loading."""
        with self._lock:
            idx = self._subject_sample_indices
            return len(idx) if idx is not None else None

    def sample_idxs_for_nsd_img(self, nsd_img_idx):
        """Return the list of sample_idx values that correspond to a given NSD image index.

        Returns an empty list if the image has no trials for this subject or the
        mapping is not yet built (model still loading).
        """
        with self._lock:
            return list(self._nsd_to_sample.get(int(nsd_img_idx), []))

    def start(self):
        """Start the background model-loading thread (idempotent).

        Previously a second call would spawn a second loader thread; now only
        the first call launches one.
        """
        with self._lock:
            if self._load_started:
                return
            self._load_started = True
        threading.Thread(target=self._load, daemon=True).start()

    # -- model loading -------------------------------------------------------

    @staticmethod
    def _clean_state_dict(sd):
        """Strip a 'model.' key prefix and drop eval-metric buffers from *sd*.

        The prefix strip is per-key, so state dicts without the prefix pass
        through unchanged. Eval metrics (FID/LPIPS/etc.) are training-time
        buffers the inference model does not define.
        """
        sd = {(k[6:] if k.startswith('model.') else k): v
              for k, v in sd.items()}
        drop = ('eval_fid', 'eval_inceptionlastconv',
                'eval_eff', 'eval_swav', 'eval_lpips')
        return {k: v for k, v in sd.items()
                if not any(k.startswith(p) for p in drop)}

    def _load(self):
        """Background thread body: import deps, build config, load weights,
        compute beta_std, and index this subject's samples.

        Publishes results under self._lock and flips _status to 'ok'/'error'.
        """
        # Captured BEFORE the try block: the finally clause below restores it,
        # and it must exist even when an early import fails. (Previously it was
        # assigned inside the try, so an exception raised before the assignment
        # turned into a NameError in finally.)
        orig_dir = os.getcwd()
        try:
            import sys
            import torch
            import h5py

            # Inject dynadiff paths before any imports from those packages.
            dynadiff_diffusers = os.path.join(self.dynadiff_dir, 'diffusers', 'src')
            for p in [self.dynadiff_dir, dynadiff_diffusers]:
                if p not in sys.path:
                    sys.path.insert(0, p)

            # Pre-import torchvision so it is fully initialised before
            # dynadiff's diffusers fork pulls it in. Without this,
            # torchvision.transforms can end up partially initialised
            # ("cannot import name 'InterpolationMode' from partially
            # initialized module 'torchvision.transforms'").
            import torchvision.transforms  # noqa: F401
            import torchvision.transforms.functional  # noqa: F401

            # Bokeh's code_runner does os.chdir(original_cwd) after every
            # session's app script, so cwd is not stable across the slow
            # imports below. Build the config entirely from absolute paths.
            _vd_cache = os.path.join(self.dynadiff_dir, 'versatile_diffusion')
            _cache_dir = os.path.join(self.dynadiff_dir, 'cache')
            _local_infra = {'cluster': None, 'folder': _cache_dir}

            print('[DynaDiff] importing dynadiff modules...', flush=True)
            from exca import ConfDict
            print('[DynaDiff] exca imported', flush=True)

            _cfg_yaml = os.path.join(self.dynadiff_dir, 'config', 'config.yaml')
            with open(_cfg_yaml, 'r') as f:
                cfg = ConfDict.from_yaml(f)
            # Fixed seed + unaveraged single-subject data; run exca locally.
            cfg['versatilediffusion_config.vd_cache_dir'] = _vd_cache
            cfg['seed'] = 42
            cfg['data.nsd_dataset_config.seed'] = 42
            cfg['data.nsd_dataset_config.averaged'] = False
            cfg['data.nsd_dataset_config.subject_ids'] = [0]
            cfg['infra'] = _local_infra
            cfg['data.nsd_dataset_config.infra'] = _local_infra
            cfg['image_generation_infra'] = _local_infra
            print('[DynaDiff] config loaded', flush=True)

            vd_cfg = cfg['versatilediffusion_config']
            from model.models import VersatileDiffusion, VersatileDiffusionConfig
            print('[DynaDiff] model.models imported', flush=True)
            vd_config = VersatileDiffusionConfig(**vd_cfg)
            print('[DynaDiff] VersatileDiffusionConfig built', flush=True)

            # Resolve checkpoint: absolute path, plain file under dynadiff_dir,
            # or a DeepSpeed ZeRO checkpoint directory under training_checkpoints.
            ckpt = self.checkpoint
            if not os.path.isabs(ckpt):
                candidate_pth = os.path.join(self.dynadiff_dir, ckpt)
                candidate_ckpt = os.path.join(self.dynadiff_dir,
                                              'training_checkpoints', ckpt)
                if os.path.isfile(candidate_pth):
                    ckpt = candidate_pth
                elif os.path.isdir(candidate_ckpt):
                    ckpt = candidate_ckpt
                else:
                    raise FileNotFoundError(
                        f'Checkpoint not found: tried {candidate_pth} '
                        f'and {candidate_ckpt}')

            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model_args = dict(config=vd_config,
                              brain_n_in_channels=N_VOXELS, brain_temp_dim=6)
            model = VersatileDiffusion(**model_args)

            if os.path.isfile(ckpt):
                # Plain .pth state dict.
                log.info(f'[DynaDiff] Loading state dict from {ckpt} ...')
                sd = torch.load(ckpt, map_location=device, weights_only=False)
                model.load_state_dict(self._clean_state_dict(sd), strict=False)
            elif os.path.isdir(ckpt):
                # DeepSpeed ZeRO checkpoint directory: consolidate to fp32 first.
                import deepspeed
                log.info(f'[DynaDiff] Consolidating ZeRO checkpoint from {ckpt} ...')
                sd = deepspeed.utils.zero_to_fp32 \
                    .get_fp32_state_dict_from_zero_checkpoint(
                        checkpoint_dir=ckpt, tag='checkpoint',
                        exclude_frozen_parameters=False)
                model.load_state_dict(self._clean_state_dict(sd), strict=False)
            else:
                raise FileNotFoundError(f'Checkpoint not found: {ckpt}')

            model.sanity_check_blurry = False
            model.to(device)
            model.eval()
            log.info(f'[DynaDiff] Model loaded on {device}')

            # Beta std (over at most 300 samples) and the subject sample index,
            # both from the same h5 file — opened once.
            log.info(f'[DynaDiff] Computing beta_std from {self.h5_path} ...')
            with h5py.File(self.h5_path, 'r') as hf:
                n = min(300, hf['fmri'].shape[0])
                beta_std = float(np.array(hf['fmri'][:n]).std(axis=0).mean())
                log.info(f'[DynaDiff] beta_std = {beta_std:.5f}')
                log.info(f'[DynaDiff] Building sample index for subject {self.subject_idx} ...')
                all_subj = np.array(hf['subject_idx'][:], dtype=np.int64)
                all_imgidx = np.array(hf['image_idx'][:], dtype=np.int64)
            sample_indices = np.where(all_subj == self.subject_idx)[0].astype(np.int64)
            log.info(f'[DynaDiff] {len(sample_indices)} samples for subject {self.subject_idx}')

            # Build reverse map: NSD image index -> list of sample_idx values.
            nsd_to_sample: dict[int, list[int]] = {}
            for sample_idx_val, h5_row in enumerate(sample_indices):
                nsd_img = int(all_imgidx[h5_row])
                nsd_to_sample.setdefault(nsd_img, []).append(sample_idx_val)

            # Publish atomically so readers never see a half-initialised state.
            with self._lock:
                self._model = model
                self._cfg = cfg
                self._beta_std = beta_std
                self._subject_sample_indices = sample_indices
                self._nsd_to_sample = nsd_to_sample
                self._status = 'ok'
            log.info('[DynaDiff] Ready.')
        except Exception as exc:
            log.exception('[DynaDiff] Model loading failed')
            with self._lock:
                self._status = 'error'
                self._error = str(exc)
        finally:
            os.chdir(orig_dir)

    # -- inference -----------------------------------------------------------

    def reconstruct(self, sample_idx, steerings, seed=42):
        """
        steerings: list of (phi_voxel np.ndarray float32, lam float, threshold float)
        Returns dict with baseline_img, steered_img, gt_img (base64 PNGs), beta_std.
        """
        import torch
        with self._lock:
            model = self._model
            beta_std = self._beta_std
            indices = self._subject_sample_indices
        if model is None:
            raise RuntimeError('Model not loaded yet')

        # Map sample_idx -> h5 row (identity if the subject index is absent).
        if indices is not None:
            if not (0 <= sample_idx < len(indices)):
                raise IndexError(
                    f'sample_idx {sample_idx} out of range '
                    f'(subject has {len(indices)} samples)')
            h5_row = int(indices[sample_idx])
        else:
            h5_row = sample_idx

        import h5py
        with h5py.File(self.h5_path, 'r') as hf:
            fmri = torch.from_numpy(
                np.array(hf['fmri'][h5_row], dtype=np.float32)).unsqueeze(0)
            img_idx = int(hf['image_idx'][h5_row])

        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype

        # Apply steering perturbations sequentially on a copy of the betas.
        steered_fmri = fmri.clone()
        for phi_voxel, lam, threshold in steerings:
            steered_fmri = self._apply_steering(
                steered_fmri, phi_voxel, lam, beta_std, threshold, device)

        baseline = self._decode(model, fmri, device, dtype, seed)
        steered = self._decode(model, steered_fmri, device, dtype, seed)
        gt_img = self._load_gt_image(img_idx)
        return {
            'baseline_img': _img_to_b64(baseline),
            'steered_img': _img_to_b64(steered),
            'gt_img': _img_to_b64(gt_img) if gt_img is not None else None,
            'beta_std': float(beta_std),
        }

    @staticmethod
    def _apply_steering(fmri_tensor, phi_voxel, lam, beta_std, threshold, device):
        """Add lam * (beta_std / max|phi|) * phi to the fMRI betas.

        threshold < 1.0 keeps only the top-(threshold) fraction of voxels by
        |phi| (others get a zero perturbation). Returns a new tensor; the
        input is never modified in place.
        """
        import torch
        if lam == 0.0:
            return fmri_tensor.clone()
        steered = fmri_tensor.clone().to(device=device)
        phi_t = torch.from_numpy(phi_voxel).to(dtype=steered.dtype, device=device)
        # Scale phi so its peak magnitude matches the data's typical std.
        phi_max = phi_t.abs().max().item()
        scale = (beta_std / phi_max) if phi_max > 1e-12 else 1.0
        if threshold < 1.0:
            cutoff = float(np.percentile(np.abs(phi_voxel), 100 * (1 - threshold)))
            mask = torch.from_numpy(np.abs(phi_voxel) >= cutoff).to(device)
        else:
            mask = torch.ones(N_VOXELS, dtype=torch.bool, device=device)
        perturbation = lam * scale * phi_t
        perturbation[~mask] = 0.0
        # (1, V, T) gets the perturbation broadcast over time; (1, V) adds directly.
        if steered.dim() == 3:
            steered[0, :, :] += perturbation.unsqueeze(-1)
        else:
            steered[0, :] += perturbation
        return steered

    @staticmethod
    def _decode(model, fmri_tensor, device, dtype, seed=42,
                guidance_scale=3.5, img2img_strength=0.85):
        """Run the diffusion decoder on one fMRI sample -> (H, W, 3) float in [0, 1].

        torch is imported locally and no_grad applied as a context manager:
        the previous ``@__import__('torch').no_grad()`` decorator forced a
        torch import at class-definition time, defeating this module's
        lazy-import design.
        """
        import torch
        with torch.no_grad():
            encoding = model.get_condition(
                fmri_tensor.to(device=device, dtype=dtype),
                torch.tensor([0], device=device),
            )
            output = model.reconstruction_from_clipbrainimage(
                encoding, seed=seed, guidance_scale=guidance_scale,
                img2img_strength=img2img_strength)
            recon = output.image[0].cpu().float().permute(1, 2, 0).numpy()
        return np.clip(recon, 0, 1)

    def _load_gt_image(self, image_idx):
        """Load GT stimulus: thumbnail first, raw H5 fallback."""
        if self.nsd_thumb_dir:
            thumb = os.path.join(self.nsd_thumb_dir, f'nsd_{image_idx:05d}.jpg')
            try:
                from PIL import Image as _PIL
                return np.array(_PIL.open(thumb).convert('RGB'),
                                dtype=np.float32) / 255.0
            except Exception as e:
                # Best-effort: log and fall through to the H5 fallback.
                log.warning(f'[DynaDiff] thumb load failed ({thumb}): {e}')
        # H5 fallback - only works if train_unaveraged.h5 is present
        try:
            import h5py
            train_h5 = os.path.join(self.dynadiff_dir,
                                    'processed_nsd_data', 'train_unaveraged.h5')
            if not os.path.exists(train_h5):
                return None
            with h5py.File(train_h5, 'r') as hf:
                img = np.array(hf['images'][image_idx], dtype=np.float32)
            return np.clip(img, 0, 1)
        except Exception as e:
            log.warning(f'[DynaDiff] GT image load failed (idx={image_idx}): {e}')
            return None