Spaces:
Sleeping
Sleeping
Marlin Lee Claude Sonnet 4.6 commited on
Commit ·
fd8ee51
1
Parent(s): 93e35bf
Sync local changes: CLIP scores, NSD image lookup, multi-trial DynaDiff, phi_c columns, P75 col, label captions, entrypoint pre-warm
Browse files- entrypoint.sh +29 -0
- scripts/add_clip_embeddings.py +224 -0
- scripts/dynadiff_loader.py +19 -1
- scripts/explorer_app.py +104 -54
entrypoint.sh
CHANGED
|
@@ -103,6 +103,35 @@ if [ ! -d "$COCO_THUMBS" ]; then
|
|
| 103 |
fi
|
| 104 |
IMAGE_DIR_ARG=(--image-dir "$COCO_THUMBS")
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
# ── Determine websocket origin ────────────────────────────────────────────────
|
| 107 |
SPACE_HOST="${SPACE_HOST:-localhost}"
|
| 108 |
|
|
|
|
| 103 |
fi
|
| 104 |
IMAGE_DIR_ARG=(--image-dir "$COCO_THUMBS")
|
| 105 |
|
| 106 |
+
# ── Pre-warm DynaDiff before Bokeh starts ────────────────────────────────────
|
| 107 |
+
# torch.load holds the GIL for extended periods; doing this synchronously before
|
| 108 |
+
# Bokeh launches ensures Tornado's event loop isn't starved when users connect.
|
| 109 |
+
if [ -f "$DYNADIFF_CKPT" ] && [ -f "$FMRI_H5" ]; then
|
| 110 |
+
echo "Pre-warming DynaDiff (this may take a few minutes on cold start)..."
|
| 111 |
+
python3 - <<PYEOF
|
| 112 |
+
import sys, time, os
|
| 113 |
+
sys.path.insert(0, '/app')
|
| 114 |
+
sys.path.insert(0, '/app/dynadiff')
|
| 115 |
+
sys.path.insert(0, '/app/dynadiff/diffusers/src')
|
| 116 |
+
os.chdir('/app')
|
| 117 |
+
from scripts.dynadiff_loader import get_loader
|
| 118 |
+
loader = get_loader(
|
| 119 |
+
dynadiff_dir='/app/dynadiff',
|
| 120 |
+
checkpoint=os.environ.get('DYNADIFF_CKPT', '$DYNADIFF_CKPT'),
|
| 121 |
+
h5_path='$FMRI_H5',
|
| 122 |
+
)
|
| 123 |
+
while True:
|
| 124 |
+
status, err = loader.status
|
| 125 |
+
if status == 'ok':
|
| 126 |
+
print('DynaDiff pre-warm complete.')
|
| 127 |
+
break
|
| 128 |
+
elif status == 'error':
|
| 129 |
+
print(f'DynaDiff pre-warm failed: {err}')
|
| 130 |
+
break
|
| 131 |
+
time.sleep(5)
|
| 132 |
+
PYEOF
|
| 133 |
+
fi
|
| 134 |
+
|
| 135 |
# ── Determine websocket origin ────────────────────────────────────────────────
|
| 136 |
SPACE_HOST="${SPACE_HOST:-localhost}"
|
| 137 |
|
scripts/add_clip_embeddings.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Post-hoc CLIP text-alignment enrichment for explorer_data.pt files.
|
| 3 |
+
|
| 4 |
+
Loads an existing explorer_data.pt, computes per-feature CLIP text alignment
|
| 5 |
+
scores (via MEI images), and saves them back into the same file under:
|
| 6 |
+
'clip_text_scores' : Tensor (n_features, n_vocab) float16
|
| 7 |
+
'clip_text_vocab' : list[str]
|
| 8 |
+
'clip_feature_embeds' : Tensor (n_features, clip_proj_dim) float16
|
| 9 |
+
mean CLIP image embedding of each feature's top MEIs
|
| 10 |
+
|
| 11 |
+
This script does NOT need to re-run DINOv3 or the SAE — it only needs the
|
| 12 |
+
existing explorer_data.pt (for image paths and top-MEI indices) and CLIP.
|
| 13 |
+
|
| 14 |
+
Usage
|
| 15 |
+
-----
|
| 16 |
+
python add_clip_embeddings.py \
|
| 17 |
+
--data ../explorer_data_d32000_k160.pt \
|
| 18 |
+
--vocab-file ../vocab/imagenet_labels.txt \
|
| 19 |
+
--n-top-images 4 \
|
| 20 |
+
--batch-size 32
|
| 21 |
+
|
| 22 |
+
# Or use the built-in default vocabulary (ImageNet-1K labels + COCO categories):
|
| 23 |
+
python add_clip_embeddings.py \
|
| 24 |
+
--data ../explorer_data_d32000_k160.pt
|
| 25 |
+
|
| 26 |
+
The enriched file is saved to --output-path (defaults to overwriting --data
|
| 27 |
+
with a backup copy at <data>.bak).
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
import argparse
|
| 31 |
+
import os
|
| 32 |
+
import shutil
|
| 33 |
+
|
| 34 |
+
import torch
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
from PIL import Image
|
| 37 |
+
|
| 38 |
+
# Allow running from scripts/ directory or project root
|
| 39 |
+
import sys
|
| 40 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
| 41 |
+
from clip_utils import load_clip, compute_text_embeddings, compute_mei_text_alignment
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
# Default vocabulary
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
DEFAULT_VOCAB = [
|
| 49 |
+
# COCO categories
|
| 50 |
+
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
|
| 51 |
+
"truck", "boat", "traffic light", "fire hydrant", "stop sign",
|
| 52 |
+
"parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
|
| 53 |
+
"elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag",
|
| 54 |
+
"tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
|
| 55 |
+
"baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
|
| 56 |
+
"bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana",
|
| 57 |
+
"apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
|
| 58 |
+
"donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
|
| 59 |
+
"toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
|
| 60 |
+
"microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock",
|
| 61 |
+
"vase", "scissors", "teddy bear", "hair drier", "toothbrush",
|
| 62 |
+
# Texture / scene descriptors
|
| 63 |
+
"grass", "sky", "water", "sand", "snow", "wood", "stone", "metal",
|
| 64 |
+
"fabric", "fur", "feathers", "leaves", "clouds", "fire", "shadow",
|
| 65 |
+
"stripes", "spots", "checkerboard pattern", "geometric pattern",
|
| 66 |
+
# Orientation / structure cues (for patch features)
|
| 67 |
+
"horizontal lines", "vertical lines", "diagonal lines", "curved lines",
|
| 68 |
+
"edges", "corners", "grid", "dots", "concentric circles",
|
| 69 |
+
# Color / illumination
|
| 70 |
+
"red object", "blue object", "green object", "yellow object",
|
| 71 |
+
"black and white", "bright highlight", "dark shadow", "gradient",
|
| 72 |
+
# Scene types
|
| 73 |
+
"indoor scene", "outdoor scene", "urban street", "nature landscape",
|
| 74 |
+
"ocean", "mountain", "forest", "desert", "city buildings", "crowd",
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
# Main
|
| 80 |
+
# ---------------------------------------------------------------------------
|
| 81 |
+
|
| 82 |
+
def main():
|
| 83 |
+
parser = argparse.ArgumentParser(description="Add CLIP text alignment to explorer_data.pt")
|
| 84 |
+
parser.add_argument("--data", type=str, required=True,
|
| 85 |
+
help="Path to explorer_data.pt")
|
| 86 |
+
parser.add_argument("--output-path", type=str, default=None,
|
| 87 |
+
help="Output path (default: overwrite --data, keeping .bak)")
|
| 88 |
+
parser.add_argument("--vocab-file", type=str, default=None,
|
| 89 |
+
help="Plain-text file with one concept per line. "
|
| 90 |
+
"Default: built-in COCO+texture vocabulary.")
|
| 91 |
+
parser.add_argument("--clip-model", type=str, default="openai/clip-vit-large-patch14",
|
| 92 |
+
help="HuggingFace CLIP model ID")
|
| 93 |
+
parser.add_argument("--n-top-images", type=int, default=4,
|
| 94 |
+
help="Number of MEIs to average per feature for CLIP alignment")
|
| 95 |
+
parser.add_argument("--batch-size", type=int, default=32,
|
| 96 |
+
help="Batch size for CLIP image encoding")
|
| 97 |
+
parser.add_argument("--no-backup", action="store_true",
|
| 98 |
+
help="Skip creating a .bak copy before overwriting")
|
| 99 |
+
parser.add_argument("--image-dir", type=str, default=None,
|
| 100 |
+
help="Primary image directory for resolving bare filenames")
|
| 101 |
+
parser.add_argument("--extra-image-dir", type=str, action="append", default=[],
|
| 102 |
+
help="Additional image directory (repeatable)")
|
| 103 |
+
args = parser.parse_args()
|
| 104 |
+
|
| 105 |
+
image_bases = [b for b in ([args.image_dir] + args.extra_image_dir) if b]
|
| 106 |
+
|
| 107 |
+
def resolve_path(p):
|
| 108 |
+
if os.path.isabs(p) or not image_bases:
|
| 109 |
+
return p
|
| 110 |
+
for base in image_bases:
|
| 111 |
+
full = os.path.join(base, p)
|
| 112 |
+
if os.path.exists(full):
|
| 113 |
+
return full
|
| 114 |
+
return os.path.join(image_bases[0], p) # fallback
|
| 115 |
+
|
| 116 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 117 |
+
print(f"Device: {device}")
|
| 118 |
+
|
| 119 |
+
# --- Load explorer data ---
|
| 120 |
+
print(f"Loading explorer data from {args.data}...")
|
| 121 |
+
data = torch.load(args.data, map_location='cpu', weights_only=False)
|
| 122 |
+
image_paths = [resolve_path(p) for p in data['image_paths']]
|
| 123 |
+
d_model = data['d_model']
|
| 124 |
+
top_img_idx = data['top_img_idx'] # (n_features, n_top)
|
| 125 |
+
n_top_stored = top_img_idx.shape[1]
|
| 126 |
+
print(f" d_model={d_model}, n_images={data['n_images']}, "
|
| 127 |
+
f"top-{n_top_stored} images stored")
|
| 128 |
+
|
| 129 |
+
# --- Load vocabulary ---
|
| 130 |
+
if args.vocab_file:
|
| 131 |
+
with open(args.vocab_file) as f:
|
| 132 |
+
vocab = [line.strip() for line in f if line.strip()]
|
| 133 |
+
print(f"Loaded {len(vocab)} concepts from {args.vocab_file}")
|
| 134 |
+
else:
|
| 135 |
+
vocab = DEFAULT_VOCAB
|
| 136 |
+
print(f"Using default vocabulary ({len(vocab)} concepts)")
|
| 137 |
+
|
| 138 |
+
# --- Load CLIP ---
|
| 139 |
+
clip_model, clip_processor = load_clip(device, model_name=args.clip_model)
|
| 140 |
+
|
| 141 |
+
# --- Precompute text embeddings ---
|
| 142 |
+
print("Encoding text vocabulary with CLIP...")
|
| 143 |
+
text_embeds = compute_text_embeddings(vocab, clip_model, clip_processor, device)
|
| 144 |
+
print(f" text_embeds: {text_embeds.shape}")
|
| 145 |
+
|
| 146 |
+
# --- Collect MEI image paths per feature ---
|
| 147 |
+
print("Collecting MEI image paths per feature...")
|
| 148 |
+
n_use = min(args.n_top_images, n_top_stored)
|
| 149 |
+
feature_mei_paths = []
|
| 150 |
+
for feat in range(d_model):
|
| 151 |
+
paths = []
|
| 152 |
+
for j in range(n_use):
|
| 153 |
+
idx = top_img_idx[feat, j].item()
|
| 154 |
+
if idx >= 0:
|
| 155 |
+
paths.append(image_paths[idx])
|
| 156 |
+
feature_mei_paths.append(paths)
|
| 157 |
+
|
| 158 |
+
# --- Compute per-feature CLIP image embeddings (mean of MEIs) ---
|
| 159 |
+
print(f"Computing CLIP image embeddings for {d_model} features "
|
| 160 |
+
f"(averaging {n_use} MEIs each)...")
|
| 161 |
+
|
| 162 |
+
clip_proj_dim = clip_model.config.projection_dim
|
| 163 |
+
feature_img_embeds = torch.zeros(d_model, clip_proj_dim, dtype=torch.float32)
|
| 164 |
+
dead_count = 0
|
| 165 |
+
|
| 166 |
+
for feat_start in range(0, d_model, args.batch_size):
|
| 167 |
+
feat_end = min(feat_start + args.batch_size, d_model)
|
| 168 |
+
for feat in range(feat_start, feat_end):
|
| 169 |
+
paths = feature_mei_paths[feat]
|
| 170 |
+
if not paths:
|
| 171 |
+
dead_count += 1
|
| 172 |
+
continue
|
| 173 |
+
imgs = []
|
| 174 |
+
for p in paths:
|
| 175 |
+
try:
|
| 176 |
+
imgs.append(Image.open(p).convert("RGB"))
|
| 177 |
+
except Exception:
|
| 178 |
+
continue
|
| 179 |
+
if not imgs:
|
| 180 |
+
dead_count += 1
|
| 181 |
+
continue
|
| 182 |
+
inputs = clip_processor(images=imgs, return_tensors="pt")
|
| 183 |
+
pixel_values = inputs['pixel_values'].to(device)
|
| 184 |
+
with torch.inference_mode():
|
| 185 |
+
# Use vision_model + visual_projection directly to avoid
|
| 186 |
+
# version differences in get_image_features() return type.
|
| 187 |
+
vision_out = clip_model.vision_model(pixel_values=pixel_values)
|
| 188 |
+
embeds = clip_model.visual_projection(vision_out.pooler_output)
|
| 189 |
+
embeds = F.normalize(embeds, dim=-1)
|
| 190 |
+
mean_embed = embeds.mean(dim=0)
|
| 191 |
+
mean_embed = F.normalize(mean_embed, dim=-1)
|
| 192 |
+
feature_img_embeds[feat] = mean_embed.cpu().float()
|
| 193 |
+
|
| 194 |
+
if (feat_start // args.batch_size + 1) % 100 == 0:
|
| 195 |
+
print(f" [{feat_end}/{d_model}] features encoded", flush=True)
|
| 196 |
+
|
| 197 |
+
print(f" Done. Dead/missing features skipped: {dead_count}")
|
| 198 |
+
|
| 199 |
+
# --- Compute alignment matrix ---
|
| 200 |
+
print("Computing text alignment matrix...")
|
| 201 |
+
# (n_features, clip_proj_dim) @ (clip_proj_dim, n_vocab) = (n_features, n_vocab)
|
| 202 |
+
clip_text_scores = feature_img_embeds @ text_embeds.T # float32
|
| 203 |
+
print(f" clip_text_scores: {clip_text_scores.shape}")
|
| 204 |
+
|
| 205 |
+
# --- Save into explorer_data.pt ---
|
| 206 |
+
output_path = args.output_path or args.data
|
| 207 |
+
if output_path == args.data and not args.no_backup:
|
| 208 |
+
bak_path = args.data + ".bak"
|
| 209 |
+
print(f"Creating backup at {bak_path}...")
|
| 210 |
+
shutil.copy2(args.data, bak_path)
|
| 211 |
+
|
| 212 |
+
data['clip_text_scores'] = clip_text_scores.half() # float16 to save space
|
| 213 |
+
data['clip_feature_embeds'] = feature_img_embeds.half() # float16
|
| 214 |
+
data['clip_text_vocab'] = vocab
|
| 215 |
+
|
| 216 |
+
print(f"Saving enriched explorer data to {output_path}...")
|
| 217 |
+
torch.save(data, output_path)
|
| 218 |
+
size_mb = os.path.getsize(output_path) / 1e6
|
| 219 |
+
print(f"Saved ({size_mb:.1f} MB)")
|
| 220 |
+
print("Done.")
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
if __name__ == "__main__":
|
| 224 |
+
main()
|
scripts/dynadiff_loader.py
CHANGED
|
@@ -80,6 +80,7 @@ class DynaDiffLoader:
|
|
| 80 |
self._cfg = None
|
| 81 |
self._beta_std = None
|
| 82 |
self._subject_sample_indices = None
|
|
|
|
| 83 |
self._status = 'loading' # 'loading' | 'ok' | 'error'
|
| 84 |
self._error = ''
|
| 85 |
self._lock = threading.Lock()
|
|
@@ -102,6 +103,15 @@ class DynaDiffLoader:
|
|
| 102 |
idx = self._subject_sample_indices
|
| 103 |
return len(idx) if idx is not None else None
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
def start(self):
|
| 106 |
"""Start background model loading thread."""
|
| 107 |
t = threading.Thread(target=self._load, daemon=True)
|
|
@@ -216,15 +226,23 @@ class DynaDiffLoader:
|
|
| 216 |
# Subject sample index mapping
|
| 217 |
log.info(f'[DynaDiff] Building sample index for subject {self.subject_idx} ...')
|
| 218 |
with h5py.File(self.h5_path, 'r') as hf:
|
| 219 |
-
all_subj
|
|
|
|
| 220 |
sample_indices = np.where(all_subj == self.subject_idx)[0].astype(np.int64)
|
| 221 |
log.info(f'[DynaDiff] {len(sample_indices)} samples for subject {self.subject_idx}')
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
with self._lock:
|
| 224 |
self._model = model
|
| 225 |
self._cfg = cfg
|
| 226 |
self._beta_std = beta_std
|
| 227 |
self._subject_sample_indices = sample_indices
|
|
|
|
| 228 |
self._status = 'ok'
|
| 229 |
log.info('[DynaDiff] Ready.')
|
| 230 |
|
|
|
|
| 80 |
self._cfg = None
|
| 81 |
self._beta_std = None
|
| 82 |
self._subject_sample_indices = None
|
| 83 |
+
self._nsd_to_sample = {}
|
| 84 |
self._status = 'loading' # 'loading' | 'ok' | 'error'
|
| 85 |
self._error = ''
|
| 86 |
self._lock = threading.Lock()
|
|
|
|
| 103 |
idx = self._subject_sample_indices
|
| 104 |
return len(idx) if idx is not None else None
|
| 105 |
|
| 106 |
+
def sample_idxs_for_nsd_img(self, nsd_img_idx):
|
| 107 |
+
"""Return the list of sample_idx values that correspond to a given NSD image index.
|
| 108 |
+
|
| 109 |
+
Returns an empty list if the image has no trials for this subject or the
|
| 110 |
+
mapping is not yet built (model still loading).
|
| 111 |
+
"""
|
| 112 |
+
with self._lock:
|
| 113 |
+
return list(self._nsd_to_sample.get(int(nsd_img_idx), []))
|
| 114 |
+
|
| 115 |
def start(self):
|
| 116 |
"""Start background model loading thread."""
|
| 117 |
t = threading.Thread(target=self._load, daemon=True)
|
|
|
|
| 226 |
# Subject sample index mapping
|
| 227 |
log.info(f'[DynaDiff] Building sample index for subject {self.subject_idx} ...')
|
| 228 |
with h5py.File(self.h5_path, 'r') as hf:
|
| 229 |
+
all_subj = np.array(hf['subject_idx'][:], dtype=np.int64)
|
| 230 |
+
all_imgidx = np.array(hf['image_idx'][:], dtype=np.int64)
|
| 231 |
sample_indices = np.where(all_subj == self.subject_idx)[0].astype(np.int64)
|
| 232 |
log.info(f'[DynaDiff] {len(sample_indices)} samples for subject {self.subject_idx}')
|
| 233 |
|
| 234 |
+
# Build reverse map: NSD image index → list of sample_idx values
|
| 235 |
+
nsd_to_sample: dict[int, list[int]] = {}
|
| 236 |
+
for sample_idx_val, h5_row in enumerate(sample_indices):
|
| 237 |
+
nsd_img = int(all_imgidx[h5_row])
|
| 238 |
+
nsd_to_sample.setdefault(nsd_img, []).append(sample_idx_val)
|
| 239 |
+
|
| 240 |
with self._lock:
|
| 241 |
self._model = model
|
| 242 |
self._cfg = cfg
|
| 243 |
self._beta_std = beta_std
|
| 244 |
self._subject_sample_indices = sample_indices
|
| 245 |
+
self._nsd_to_sample = nsd_to_sample
|
| 246 |
self._status = 'ok'
|
| 247 |
log.info('[DynaDiff] Ready.')
|
| 248 |
|
scripts/explorer_app.py
CHANGED
|
@@ -430,10 +430,10 @@ def _load_brain_dataset_dict(path, label, thumb_dir):
|
|
| 430 |
'feature_p75_val': bd.get('feature_p75_val', torch.zeros(d_model)),
|
| 431 |
'umap_coords': bd['umap_coords'].numpy() if 'umap_coords' in bd else nan2,
|
| 432 |
'dict_umap_coords': bd['dict_umap_coords'].numpy() if 'dict_umap_coords' in bd else nan2,
|
| 433 |
-
'clip_scores': None,
|
| 434 |
-
'clip_vocab': None,
|
| 435 |
-
'clip_embeds': None,
|
| 436 |
-
'clip_scores_f32': None,
|
| 437 |
'inference_cache': OrderedDict(),
|
| 438 |
'names_file': stem + '_feature_names.json',
|
| 439 |
'auto_interp_file': stem + '_auto_interp.json',
|
|
@@ -633,6 +633,9 @@ def _reconstruct_z_from_heatmaps(img_idx, ds):
|
|
| 633 |
idx = ds.get(idx_key) # (d_sae, n_slots) int tensor
|
| 634 |
if hm is None or idx is None:
|
| 635 |
continue
|
|
|
|
|
|
|
|
|
|
| 636 |
if z is None:
|
| 637 |
d_sae, _, n_patches_sq = hm.shape
|
| 638 |
z = np.zeros((n_patches_sq, d_sae), dtype=np.float32)
|
|
@@ -704,6 +707,20 @@ ALPHA_JET = create_alpha_cmap('jet')
|
|
| 704 |
THUMB = args.thumb_size
|
| 705 |
|
| 706 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 707 |
def _resolve_img_path(stored_path):
|
| 708 |
"""Resolve a stored image path, searching image dirs first. Returns None on failure."""
|
| 709 |
if os.path.isabs(stored_path) and os.path.exists(stored_path):
|
|
@@ -952,33 +969,42 @@ def _dynadiff_request(sample_idx, steerings, seed):
|
|
| 952 |
return _dd_loader.reconstruct(sample_idx, steerings, seed)
|
| 953 |
|
| 954 |
|
| 955 |
-
def _make_steering_html(
|
| 956 |
-
"""Build HTML showing GT | Baseline | Steered
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
b64 = resp.get(key)
|
| 962 |
-
if b64 is None:
|
| 963 |
-
img_html = ('<div style="width:200px;height:200px;background:#eee;'
|
| 964 |
-
'display:flex;align-items:center;justify-content:center;'
|
| 965 |
-
'color:#999;font-size:12px">N/A</div>')
|
| 966 |
-
else:
|
| 967 |
-
img_html = (f'<img src="data:image/png;base64,{b64}" '
|
| 968 |
-
'style="width:200px;height:200px;object-fit:contain;'
|
| 969 |
-
'border:1px solid #ddd;border-radius:4px"/>')
|
| 970 |
-
parts.append(
|
| 971 |
-
f'<div style="text-align:center;margin:0 6px">'
|
| 972 |
-
f'{img_html}'
|
| 973 |
-
f'<div style="font-size:11px;color:#555;margin-top:3px">{label}</div>'
|
| 974 |
-
f'</div>'
|
| 975 |
-
)
|
| 976 |
-
imgs_html = '<div style="display:flex;align-items:flex-end">' + ''.join(parts) + '</div>'
|
| 977 |
-
return (
|
| 978 |
f'<h3 style="margin:4px 0 6px 0;color:#333;border-bottom:2px solid #e0e0e0;'
|
| 979 |
f'padding-bottom:4px">DynaDiff Steering — {concept_name}</h3>'
|
| 980 |
-
+ imgs_html
|
| 981 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 982 |
|
| 983 |
|
| 984 |
def make_image_grid_html(images_info, title, cols=9):
|
|
@@ -1391,10 +1417,14 @@ def _build_dynadiff_panel():
|
|
| 1391 |
dd_feat_remove_btn.on_click(_on_remove_feat)
|
| 1392 |
dd_feat_clear_btn.on_click(_on_clear_feats)
|
| 1393 |
|
| 1394 |
-
def _reconstruct_thread(
|
| 1395 |
try:
|
| 1396 |
-
|
| 1397 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1398 |
def _apply(html=html):
|
| 1399 |
dd_output.text = html
|
| 1400 |
dd_status.text = ''
|
|
@@ -1422,13 +1452,25 @@ def _build_dynadiff_panel():
|
|
| 1422 |
if not steerings:
|
| 1423 |
dd_status.text = '<span style="color:#c00">No phi data for selected features.</span>'
|
| 1424 |
return
|
|
|
|
| 1425 |
try:
|
| 1426 |
-
|
| 1427 |
except ValueError:
|
| 1428 |
dd_status.text = '<span style="color:#c00">Invalid sample index.</span>'
|
| 1429 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1430 |
_n = _dd_loader.n_samples
|
| 1431 |
-
if _n is not None and not (0 <=
|
| 1432 |
dd_status.text = f'<span style="color:#c00">sample_idx must be 0–{_n-1}.</span>'
|
| 1433 |
return
|
| 1434 |
try:
|
|
@@ -1438,11 +1480,13 @@ def _build_dynadiff_panel():
|
|
| 1438 |
names = list(dd_source.data['name'])
|
| 1439 |
feat_name = ' + '.join(names) if names else 'unknown'
|
| 1440 |
dd_btn.disabled = True
|
| 1441 |
-
|
|
|
|
|
|
|
| 1442 |
doc = curdoc()
|
| 1443 |
threading.Thread(
|
| 1444 |
target=_reconstruct_thread,
|
| 1445 |
-
args=(
|
| 1446 |
daemon=True,
|
| 1447 |
).start()
|
| 1448 |
|
|
@@ -1587,14 +1631,15 @@ def update_feature_display(feature_idx):
|
|
| 1587 |
else:
|
| 1588 |
hmap = None
|
| 1589 |
|
|
|
|
| 1590 |
if hmap is None:
|
| 1591 |
plain = load_image(img_i).resize((THUMB, THUMB), Image.BILINEAR)
|
| 1592 |
act_val = float(act_tensor[feat, ranking_idx].item())
|
| 1593 |
-
caption = f"act={act_val:.4f}
|
| 1594 |
return (plain, caption)
|
| 1595 |
max_act, mean_act_val = _patch_stats(hmap.flatten())
|
| 1596 |
img_out = render_zoomed_overlay(img_i, hmap, size=THUMB, center=center)
|
| 1597 |
-
caption =
|
| 1598 |
return (img_out, caption)
|
| 1599 |
except Exception as e:
|
| 1600 |
ph = Image.new("RGB", (THUMB, THUMB), "gray")
|
|
@@ -1817,11 +1862,13 @@ feature_list_source = ColumnDataSource(data=dict(
|
|
| 1817 |
name=[_display_name(int(i)) for i in _init_order],
|
| 1818 |
))
|
| 1819 |
|
| 1820 |
-
|
| 1821 |
-
|
| 1822 |
-
|
| 1823 |
-
|
| 1824 |
-
|
|
|
|
|
|
|
| 1825 |
feature_table = DataTable(
|
| 1826 |
source=feature_list_source,
|
| 1827 |
columns=[
|
|
@@ -1830,9 +1877,7 @@ feature_table = DataTable(
|
|
| 1830 |
formatter=NumberFormatter(format="0,0")),
|
| 1831 |
TableColumn(field="mean_act", title="Mean Act", width=80,
|
| 1832 |
formatter=NumberFormatter(format="0.0000")),
|
| 1833 |
-
|
| 1834 |
-
formatter=NumberFormatter(format="0.0000")),
|
| 1835 |
-
] + _phi_col + [
|
| 1836 |
TableColumn(field="name", title="Name", width=200),
|
| 1837 |
],
|
| 1838 |
width=500, height=500, sortable=True, index_position=None,
|
|
@@ -2170,20 +2215,20 @@ load_patch_btn = Button(label="Load Image", width=90, button_type="primary")
|
|
| 2170 |
clear_patch_btn = Button(label="Clear", width=60)
|
| 2171 |
|
| 2172 |
patch_feat_source = ColumnDataSource(data=dict(
|
| 2173 |
-
feature_idx=[], patch_act=[], frequency=[], mean_act=[],
|
| 2174 |
))
|
| 2175 |
patch_feat_table = DataTable(
|
| 2176 |
source=patch_feat_source,
|
| 2177 |
columns=[
|
| 2178 |
-
TableColumn(field="feature_idx", title="Feature",
|
| 2179 |
TableColumn(field="patch_act", title="Patch Act", width=85,
|
| 2180 |
formatter=NumberFormatter(format="0.0000")),
|
| 2181 |
TableColumn(field="frequency", title="Freq", width=65,
|
| 2182 |
formatter=NumberFormatter(format="0,0")),
|
| 2183 |
TableColumn(field="mean_act", title="Mean Act", width=80,
|
| 2184 |
formatter=NumberFormatter(format="0.0000")),
|
| 2185 |
-
],
|
| 2186 |
-
width=310, height=350, index_position=None, sortable=False, visible=False,
|
| 2187 |
)
|
| 2188 |
patch_info_div = Div(
|
| 2189 |
text="<i>Load an image, then click patches to find top features.</i>",
|
|
@@ -2203,7 +2248,7 @@ def _pil_to_bokeh_rgba(pil_img, size):
|
|
| 2203 |
|
| 2204 |
def _do_load_patch_image():
|
| 2205 |
try:
|
| 2206 |
-
img_idx =
|
| 2207 |
except ValueError:
|
| 2208 |
patch_info_div.text = "<b style='color:red'>Invalid image index</b>"
|
| 2209 |
return
|
|
@@ -2292,7 +2337,7 @@ def _on_patch_select(attr, old, new):
|
|
| 2292 |
if _S.patch_img is None:
|
| 2293 |
return
|
| 2294 |
if not new:
|
| 2295 |
-
patch_feat_source.data = dict(feature_idx=[], patch_act=[], frequency=[], mean_act=[])
|
| 2296 |
patch_info_div.text = "<i>Selection cleared.</i>"
|
| 2297 |
return
|
| 2298 |
|
|
@@ -2302,7 +2347,10 @@ def _on_patch_select(attr, old, new):
|
|
| 2302 |
patch_indices = [r * patch_grid + c for r, c in zip(rows, cols)]
|
| 2303 |
|
| 2304 |
feats, acts, freqs, means = _get_top_features_for_patches(patch_indices)
|
| 2305 |
-
patch_feat_source.data = dict(
|
|
|
|
|
|
|
|
|
|
| 2306 |
patch_info_div.text = (
|
| 2307 |
f"{len(new)} patch(es) selected → {len(feats)} feature(s) found. "
|
| 2308 |
f"Click a row below to explore the feature."
|
|
@@ -2347,7 +2395,7 @@ def _build_clip_panel():
|
|
| 2347 |
clip_top_k_input = TextInput(title="Top-K results:", value="20", width=70)
|
| 2348 |
|
| 2349 |
result_source = ColumnDataSource(data=dict(
|
| 2350 |
-
feature_idx=[], clip_score=[], frequency=[], mean_act=[], name=[],
|
| 2351 |
))
|
| 2352 |
clip_result_table = DataTable(
|
| 2353 |
source=result_source,
|
|
@@ -2359,9 +2407,10 @@ def _build_clip_panel():
|
|
| 2359 |
formatter=NumberFormatter(format="0,0")),
|
| 2360 |
TableColumn(field="mean_act", title="Mean Act", width=80,
|
| 2361 |
formatter=NumberFormatter(format="0.0000")),
|
|
|
|
| 2362 |
TableColumn(field="name", title="Name", width=160),
|
| 2363 |
],
|
| 2364 |
-
width=470, height=300, index_position=None, sortable=False,
|
| 2365 |
)
|
| 2366 |
|
| 2367 |
def _do_search():
|
|
@@ -2402,6 +2451,7 @@ def _build_clip_panel():
|
|
| 2402 |
clip_score=[float(scores_vec[i]) for i in top_indices],
|
| 2403 |
frequency=[int(feature_frequency[i].item()) for i in top_indices],
|
| 2404 |
mean_act=[float(feature_mean_act[i].item()) for i in top_indices],
|
|
|
|
| 2405 |
name=[_display_name(int(i)) for i in top_indices],
|
| 2406 |
)
|
| 2407 |
result_div.text = (
|
|
|
|
| 430 |
'feature_p75_val': bd.get('feature_p75_val', torch.zeros(d_model)),
|
| 431 |
'umap_coords': bd['umap_coords'].numpy() if 'umap_coords' in bd else nan2,
|
| 432 |
'dict_umap_coords': bd['dict_umap_coords'].numpy() if 'dict_umap_coords' in bd else nan2,
|
| 433 |
+
'clip_scores': bd.get('clip_text_scores', None),
|
| 434 |
+
'clip_vocab': bd.get('clip_text_vocab', None),
|
| 435 |
+
'clip_embeds': bd.get('clip_feature_embeds', None),
|
| 436 |
+
'clip_scores_f32': bd['clip_text_scores'].float() if 'clip_text_scores' in bd else None,
|
| 437 |
'inference_cache': OrderedDict(),
|
| 438 |
'names_file': stem + '_feature_names.json',
|
| 439 |
'auto_interp_file': stem + '_auto_interp.json',
|
|
|
|
| 633 |
idx = ds.get(idx_key) # (d_sae, n_slots) int tensor
|
| 634 |
if hm is None or idx is None:
|
| 635 |
continue
|
| 636 |
+
# Normalise: flatten 4-D (d_sae, n_slots, H, W) → 3-D (d_sae, n_slots, H*W)
|
| 637 |
+
if hm.ndim == 4:
|
| 638 |
+
hm = hm.reshape(hm.shape[0], hm.shape[1], -1)
|
| 639 |
if z is None:
|
| 640 |
d_sae, _, n_patches_sq = hm.shape
|
| 641 |
z = np.zeros((n_patches_sq, d_sae), dtype=np.float32)
|
|
|
|
| 707 |
THUMB = args.thumb_size
|
| 708 |
|
| 709 |
|
| 710 |
+
def _parse_img_label(value):
|
| 711 |
+
"""Parse an image label into an integer index.
|
| 712 |
+
|
| 713 |
+
Accepts bare integers ('42') or name-prefixed labels ('nsd_00042',
|
| 714 |
+
'COCO_val2014_000000123456') by extracting the trailing integer after
|
| 715 |
+
the last underscore. Raises ValueError on failure.
|
| 716 |
+
"""
|
| 717 |
+
val = value.strip()
|
| 718 |
+
try:
|
| 719 |
+
return int(val)
|
| 720 |
+
except ValueError:
|
| 721 |
+
return int(val.rsplit('_', 1)[-1])
|
| 722 |
+
|
| 723 |
+
|
| 724 |
def _resolve_img_path(stored_path):
|
| 725 |
"""Resolve a stored image path, searching image dirs first. Returns None on failure."""
|
| 726 |
if os.path.isabs(stored_path) and os.path.exists(stored_path):
|
|
|
|
| 969 |
return _dd_loader.reconstruct(sample_idx, steerings, seed)
|
| 970 |
|
| 971 |
|
| 972 |
+
def _make_steering_html(resps, concept_name):
|
| 973 |
+
"""Build HTML showing GT | Baseline | Steered for one or more trials.
|
| 974 |
+
|
| 975 |
+
resps: list of (trial_label, resp_dict) pairs.
|
| 976 |
+
"""
|
| 977 |
+
header = (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 978 |
f'<h3 style="margin:4px 0 6px 0;color:#333;border-bottom:2px solid #e0e0e0;'
|
| 979 |
f'padding-bottom:4px">DynaDiff Steering — {concept_name}</h3>'
|
|
|
|
| 980 |
)
|
| 981 |
+
rows_html = ''
|
| 982 |
+
for trial_label, resp in resps:
|
| 983 |
+
parts = []
|
| 984 |
+
for label, key in [('GT', 'gt_img'),
|
| 985 |
+
('Baseline', 'baseline_img'),
|
| 986 |
+
('Steered', 'steered_img')]:
|
| 987 |
+
b64 = resp.get(key)
|
| 988 |
+
if b64 is None:
|
| 989 |
+
img_html = ('<div style="width:160px;height:160px;background:#eee;'
|
| 990 |
+
'display:flex;align-items:center;justify-content:center;'
|
| 991 |
+
'color:#999;font-size:12px">N/A</div>')
|
| 992 |
+
else:
|
| 993 |
+
img_html = (f'<img src="data:image/png;base64,{b64}" '
|
| 994 |
+
'style="width:160px;height:160px;object-fit:contain;'
|
| 995 |
+
'border:1px solid #ddd;border-radius:4px"/>')
|
| 996 |
+
parts.append(
|
| 997 |
+
f'<div style="text-align:center;margin:0 4px">'
|
| 998 |
+
f'{img_html}'
|
| 999 |
+
f'<div style="font-size:11px;color:#555;margin-top:3px">{label}</div>'
|
| 1000 |
+
f'</div>'
|
| 1001 |
+
)
|
| 1002 |
+
trial_head = (f'<div style="font-size:11px;font-weight:bold;color:#777;'
|
| 1003 |
+
f'margin:6px 0 3px 4px">{trial_label}</div>')
|
| 1004 |
+
rows_html += (trial_head
|
| 1005 |
+
+ '<div style="display:flex;align-items:flex-end;margin-bottom:8px">'
|
| 1006 |
+
+ ''.join(parts) + '</div>')
|
| 1007 |
+
return header + rows_html
|
| 1008 |
|
| 1009 |
|
| 1010 |
def make_image_grid_html(images_info, title, cols=9):
|
|
|
|
| 1417 |
dd_feat_remove_btn.on_click(_on_remove_feat)
|
| 1418 |
dd_feat_clear_btn.on_click(_on_clear_feats)
|
| 1419 |
|
| 1420 |
+
def _reconstruct_thread(sample_idxs, steerings, seed, feat_name, doc):
|
| 1421 |
try:
|
| 1422 |
+
resps = []
|
| 1423 |
+
for i, sidx in enumerate(sample_idxs):
|
| 1424 |
+
trial_label = f'Trial {i+1} (sample {sidx})'
|
| 1425 |
+
resp = _dynadiff_request(sidx, steerings, seed)
|
| 1426 |
+
resps.append((trial_label, resp))
|
| 1427 |
+
html = _make_steering_html(resps, feat_name)
|
| 1428 |
def _apply(html=html):
|
| 1429 |
dd_output.text = html
|
| 1430 |
dd_status.text = ''
|
|
|
|
| 1452 |
if not steerings:
|
| 1453 |
dd_status.text = '<span style="color:#c00">No phi data for selected features.</span>'
|
| 1454 |
return
|
| 1455 |
+
_raw = dd_sample_input.value.strip()
|
| 1456 |
try:
|
| 1457 |
+
_parsed = _parse_img_label(_raw)
|
| 1458 |
except ValueError:
|
| 1459 |
dd_status.text = '<span style="color:#c00">Invalid sample index.</span>'
|
| 1460 |
return
|
| 1461 |
+
# If input looks like an NSD image label (contains '_'), treat _parsed as
|
| 1462 |
+
# an NSD image index and run all trials for that image.
|
| 1463 |
+
if '_' in _raw:
|
| 1464 |
+
sample_idxs = _dd_loader.sample_idxs_for_nsd_img(_parsed)
|
| 1465 |
+
if not sample_idxs:
|
| 1466 |
+
dd_status.text = (
|
| 1467 |
+
f'<span style="color:#c00">NSD image {_parsed} has no trials '
|
| 1468 |
+
f'for this subject.</span>')
|
| 1469 |
+
return
|
| 1470 |
+
else:
|
| 1471 |
+
sample_idxs = [_parsed]
|
| 1472 |
_n = _dd_loader.n_samples
|
| 1473 |
+
if _n is not None and any(not (0 <= s < _n) for s in sample_idxs):
|
| 1474 |
dd_status.text = f'<span style="color:#c00">sample_idx must be 0–{_n-1}.</span>'
|
| 1475 |
return
|
| 1476 |
try:
|
|
|
|
| 1480 |
names = list(dd_source.data['name'])
|
| 1481 |
feat_name = ' + '.join(names) if names else 'unknown'
|
| 1482 |
dd_btn.disabled = True
|
| 1483 |
+
n_trials = len(sample_idxs)
|
| 1484 |
+
dd_status.text = (f'<i style="color:#888">Running DynaDiff reconstruction '
|
| 1485 |
+
f'({n_trials} trial{"s" if n_trials > 1 else ""})…</i>')
|
| 1486 |
doc = curdoc()
|
| 1487 |
threading.Thread(
|
| 1488 |
target=_reconstruct_thread,
|
| 1489 |
+
args=(sample_idxs, steerings, seed, feat_name, doc),
|
| 1490 |
daemon=True,
|
| 1491 |
).start()
|
| 1492 |
|
|
|
|
| 1631 |
else:
|
| 1632 |
hmap = None
|
| 1633 |
|
| 1634 |
+
img_label = os.path.splitext(os.path.basename(image_paths[img_i]))[0]
|
| 1635 |
if hmap is None:
|
| 1636 |
plain = load_image(img_i).resize((THUMB, THUMB), Image.BILINEAR)
|
| 1637 |
act_val = float(act_tensor[feat, ranking_idx].item())
|
| 1638 |
+
caption = f"act={act_val:.4f} {img_label}"
|
| 1639 |
return (plain, caption)
|
| 1640 |
max_act, mean_act_val = _patch_stats(hmap.flatten())
|
| 1641 |
img_out = render_zoomed_overlay(img_i, hmap, size=THUMB, center=center)
|
| 1642 |
+
caption = img_label
|
| 1643 |
return (img_out, caption)
|
| 1644 |
except Exception as e:
|
| 1645 |
ph = Image.new("RGB", (THUMB, THUMB), "gray")
|
|
|
|
| 1862 |
name=[_display_name(int(i)) for i in _init_order],
|
| 1863 |
))
|
| 1864 |
|
| 1865 |
+
def _phi_col():
    """Build the optional φ_c table column.

    Returns a one-element list holding a ``TableColumn`` bound to the
    ``phi_c_val`` field when phi data was loaded (``HAS_PHI``), otherwise
    an empty list — so callers can splice the result into a column list
    unconditionally via ``... + _phi_col() + ...``.
    """
    if HAS_PHI:
        phi_fmt = NumberFormatter(format="0.0000")
        return [TableColumn(field="phi_c_val", title="φ_c",
                            width=65, formatter=phi_fmt)]
    return []
|
| 1871 |
+
|
| 1872 |
feature_table = DataTable(
|
| 1873 |
source=feature_list_source,
|
| 1874 |
columns=[
|
|
|
|
| 1877 |
formatter=NumberFormatter(format="0,0")),
|
| 1878 |
TableColumn(field="mean_act", title="Mean Act", width=80,
|
| 1879 |
formatter=NumberFormatter(format="0.0000")),
|
| 1880 |
+
] + _phi_col() + [
|
|
|
|
|
|
|
| 1881 |
TableColumn(field="name", title="Name", width=200),
|
| 1882 |
],
|
| 1883 |
width=500, height=500, sortable=True, index_position=None,
|
|
|
|
| 2215 |
clear_patch_btn = Button(label="Clear", width=60)
|
| 2216 |
|
| 2217 |
patch_feat_source = ColumnDataSource(data=dict(
|
| 2218 |
+
feature_idx=[], patch_act=[], frequency=[], mean_act=[], phi_c_val=[],
|
| 2219 |
))
|
| 2220 |
patch_feat_table = DataTable(
|
| 2221 |
source=patch_feat_source,
|
| 2222 |
columns=[
|
| 2223 |
+
TableColumn(field="feature_idx", title="Feature", width=65),
|
| 2224 |
TableColumn(field="patch_act", title="Patch Act", width=85,
|
| 2225 |
formatter=NumberFormatter(format="0.0000")),
|
| 2226 |
TableColumn(field="frequency", title="Freq", width=65,
|
| 2227 |
formatter=NumberFormatter(format="0,0")),
|
| 2228 |
TableColumn(field="mean_act", title="Mean Act", width=80,
|
| 2229 |
formatter=NumberFormatter(format="0.0000")),
|
| 2230 |
+
] + _phi_col(),
|
| 2231 |
+
width=310 + (65 if HAS_PHI else 0), height=350, index_position=None, sortable=False, visible=False,
|
| 2232 |
)
|
| 2233 |
patch_info_div = Div(
|
| 2234 |
text="<i>Load an image, then click patches to find top features.</i>",
|
|
|
|
| 2248 |
|
| 2249 |
def _do_load_patch_image():
|
| 2250 |
try:
|
| 2251 |
+
img_idx = _parse_img_label(patch_img_input.value)
|
| 2252 |
except ValueError:
|
| 2253 |
patch_info_div.text = "<b style='color:red'>Invalid image index</b>"
|
| 2254 |
return
|
|
|
|
| 2337 |
if _S.patch_img is None:
|
| 2338 |
return
|
| 2339 |
if not new:
|
| 2340 |
+
patch_feat_source.data = dict(feature_idx=[], patch_act=[], frequency=[], mean_act=[], phi_c_val=[])
|
| 2341 |
patch_info_div.text = "<i>Selection cleared.</i>"
|
| 2342 |
return
|
| 2343 |
|
|
|
|
| 2347 |
patch_indices = [r * patch_grid + c for r, c in zip(rows, cols)]
|
| 2348 |
|
| 2349 |
feats, acts, freqs, means = _get_top_features_for_patches(patch_indices)
|
| 2350 |
+
patch_feat_source.data = dict(
|
| 2351 |
+
feature_idx=feats, patch_act=acts, frequency=freqs, mean_act=means,
|
| 2352 |
+
phi_c_val=_phi_c_vals(feats),
|
| 2353 |
+
)
|
| 2354 |
patch_info_div.text = (
|
| 2355 |
f"{len(new)} patch(es) selected → {len(feats)} feature(s) found. "
|
| 2356 |
f"Click a row below to explore the feature."
|
|
|
|
| 2395 |
clip_top_k_input = TextInput(title="Top-K results:", value="20", width=70)
|
| 2396 |
|
| 2397 |
result_source = ColumnDataSource(data=dict(
|
| 2398 |
+
feature_idx=[], clip_score=[], frequency=[], mean_act=[], phi_c_val=[], name=[],
|
| 2399 |
))
|
| 2400 |
clip_result_table = DataTable(
|
| 2401 |
source=result_source,
|
|
|
|
| 2407 |
formatter=NumberFormatter(format="0,0")),
|
| 2408 |
TableColumn(field="mean_act", title="Mean Act", width=80,
|
| 2409 |
formatter=NumberFormatter(format="0.0000")),
|
| 2410 |
+
] + _phi_col() + [
|
| 2411 |
TableColumn(field="name", title="Name", width=160),
|
| 2412 |
],
|
| 2413 |
+
width=470 + (65 if HAS_PHI else 0), height=300, index_position=None, sortable=False,
|
| 2414 |
)
|
| 2415 |
|
| 2416 |
def _do_search():
|
|
|
|
| 2451 |
clip_score=[float(scores_vec[i]) for i in top_indices],
|
| 2452 |
frequency=[int(feature_frequency[i].item()) for i in top_indices],
|
| 2453 |
mean_act=[float(feature_mean_act[i].item()) for i in top_indices],
|
| 2454 |
+
phi_c_val=_phi_c_vals(top_indices),
|
| 2455 |
name=[_display_name(int(i)) for i in top_indices],
|
| 2456 |
)
|
| 2457 |
result_div.text = (
|