# SAE_Brain_Semantic_Interface / scripts / add_clip_embeddings.py
# Author: Marlin Lee
# Sync local changes: CLIP scores, NSD image lookup, multi-trial DynaDiff,
# phi_c columns, P75 col, label captions, entrypoint pre-warm (commit fd8ee51)
"""
Post-hoc CLIP text-alignment enrichment for explorer_data.pt files.
Loads an existing explorer_data.pt, computes per-feature CLIP text alignment
scores (via MEI images), and saves them back into the same file under:
'clip_text_scores' : Tensor (n_features, n_vocab) float16
'clip_text_vocab' : list[str]
'clip_feature_embeds' : Tensor (n_features, clip_proj_dim) float16
mean CLIP image embedding of each feature's top MEIs
This script does NOT need to re-run DINOv3 or the SAE — it only needs the
existing explorer_data.pt (for image paths and top-MEI indices) and CLIP.
Usage
-----
python add_clip_embeddings.py \
--data ../explorer_data_d32000_k160.pt \
--vocab-file ../vocab/imagenet_labels.txt \
--n-top-images 4 \
--batch-size 32
# Or use the built-in default vocabulary (ImageNet-1K labels + COCO categories):
python add_clip_embeddings.py \
--data ../explorer_data_d32000_k160.pt
The enriched file is saved to --output-path (defaults to overwriting --data
with a backup copy at <data>.bak).
"""
import argparse
import os
import shutil
import torch
import torch.nn.functional as F
from PIL import Image
# Allow running from scripts/ directory or project root
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from clip_utils import load_clip, compute_text_embeddings, compute_mei_text_alignment
# ---------------------------------------------------------------------------
# Default vocabulary
# ---------------------------------------------------------------------------
# Built-in fallback concept vocabulary, used when --vocab-file is not given.
# COCO object categories plus texture / structure / color / scene descriptors
# chosen to probe what vision features respond to.
DEFAULT_VOCAB = [
    # COCO categories
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
    "truck", "boat", "traffic light", "fire hydrant", "stop sign",
    "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag",
    "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
    "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana",
    "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza",
    "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table",
    "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
    # Texture / scene descriptors
    "grass", "sky", "water", "sand", "snow", "wood", "stone", "metal",
    "fabric", "fur", "feathers", "leaves", "clouds", "fire", "shadow",
    "stripes", "spots", "checkerboard pattern", "geometric pattern",
    # Orientation / structure cues (for patch features)
    "horizontal lines", "vertical lines", "diagonal lines", "curved lines",
    "edges", "corners", "grid", "dots", "concentric circles",
    # Color / illumination
    "red object", "blue object", "green object", "yellow object",
    "black and white", "bright highlight", "dark shadow", "gradient",
    # Scene types
    "indoor scene", "outdoor scene", "urban street", "nature landscape",
    "ocean", "mountain", "forest", "desert", "city buildings", "crowd",
]
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Enrich an explorer_data.pt file with CLIP text-alignment data.

    Loads the existing explorer data, encodes each feature's top MEI
    images with CLIP, averages them into one unit-norm embedding per
    feature, scores those embeddings against a text vocabulary, and
    saves the results back into the .pt file under the keys
    'clip_text_scores', 'clip_feature_embeds', and 'clip_text_vocab'.
    """
    parser = argparse.ArgumentParser(description="Add CLIP text alignment to explorer_data.pt")
    parser.add_argument("--data", type=str, required=True,
                        help="Path to explorer_data.pt")
    parser.add_argument("--output-path", type=str, default=None,
                        help="Output path (default: overwrite --data, keeping .bak)")
    parser.add_argument("--vocab-file", type=str, default=None,
                        help="Plain-text file with one concept per line. "
                             "Default: built-in COCO+texture vocabulary.")
    parser.add_argument("--clip-model", type=str, default="openai/clip-vit-large-patch14",
                        help="HuggingFace CLIP model ID")
    parser.add_argument("--n-top-images", type=int, default=4,
                        help="Number of MEIs to average per feature for CLIP alignment")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size for CLIP image encoding")
    parser.add_argument("--no-backup", action="store_true",
                        help="Skip creating a .bak copy before overwriting")
    parser.add_argument("--image-dir", type=str, default=None,
                        help="Primary image directory for resolving bare filenames")
    parser.add_argument("--extra-image-dir", type=str, action="append", default=[],
                        help="Additional image directory (repeatable)")
    args = parser.parse_args()

    image_bases = [b for b in ([args.image_dir] + args.extra_image_dir) if b]

    def resolve_path(p):
        # Absolute paths (or no configured bases) pass through untouched;
        # otherwise try each base dir in order and fall back to the first.
        if os.path.isabs(p) or not image_bases:
            return p
        for base in image_bases:
            full = os.path.join(base, p)
            if os.path.exists(full):
                return full
        return os.path.join(image_bases[0], p)  # fallback

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # --- Load explorer data ---
    print(f"Loading explorer data from {args.data}...")
    # weights_only=False: the file holds plain Python objects (lists of paths),
    # not just tensors. Only load .pt files you trust.
    data = torch.load(args.data, map_location='cpu', weights_only=False)
    image_paths = [resolve_path(p) for p in data['image_paths']]
    d_model = data['d_model']
    top_img_idx = data['top_img_idx']  # (n_features, n_top)
    n_top_stored = top_img_idx.shape[1]
    print(f"  d_model={d_model}, n_images={data['n_images']}, "
          f"top-{n_top_stored} images stored")

    # --- Load vocabulary ---
    if args.vocab_file:
        with open(args.vocab_file) as f:
            vocab = [line.strip() for line in f if line.strip()]
        print(f"Loaded {len(vocab)} concepts from {args.vocab_file}")
    else:
        vocab = DEFAULT_VOCAB
        print(f"Using default vocabulary ({len(vocab)} concepts)")

    # --- Load CLIP ---
    clip_model, clip_processor = load_clip(device, model_name=args.clip_model)

    # --- Precompute text embeddings ---
    print("Encoding text vocabulary with CLIP...")
    text_embeds = compute_text_embeddings(vocab, clip_model, clip_processor, device)
    # compute_text_embeddings runs on `device` and may return fp16/CUDA
    # tensors; the feature embeddings below are CPU float32, so normalize
    # device/dtype here or the alignment matmul would fail on GPU machines.
    # (No-op when the result is already CPU float32.)
    text_embeds = text_embeds.detach().float().cpu()
    print(f"  text_embeds: {text_embeds.shape}")

    # --- Collect MEI image paths per feature ---
    print("Collecting MEI image paths per feature...")
    n_use = min(args.n_top_images, n_top_stored)
    feature_mei_paths = []
    for feat in range(d_model):
        paths = []
        for j in range(n_use):
            idx = top_img_idx[feat, j].item()
            if idx >= 0:  # negative index marks a dead/missing slot
                paths.append(image_paths[idx])
        feature_mei_paths.append(paths)

    # --- Compute per-feature CLIP image embeddings (mean of MEIs) ---
    print(f"Computing CLIP image embeddings for {d_model} features "
          f"(averaging {n_use} MEIs each)...")
    clip_proj_dim = clip_model.config.projection_dim
    feature_img_embeds = torch.zeros(d_model, clip_proj_dim, dtype=torch.float32)
    dead_count = 0
    for feat_start in range(0, d_model, args.batch_size):
        feat_end = min(feat_start + args.batch_size, d_model)
        for feat in range(feat_start, feat_end):
            paths = feature_mei_paths[feat]
            if not paths:
                dead_count += 1
                continue
            imgs = []
            for p in paths:
                try:
                    # Context manager closes the underlying file handle;
                    # Image.open is lazy and would otherwise leak handles
                    # across thousands of MEIs. convert() returns a new,
                    # fully-loaded image that survives the `with` block.
                    with Image.open(p) as im:
                        imgs.append(im.convert("RGB"))
                except Exception:
                    # Best-effort: skip unreadable/missing images.
                    continue
            if not imgs:
                dead_count += 1
                continue
            inputs = clip_processor(images=imgs, return_tensors="pt")
            pixel_values = inputs['pixel_values'].to(device)
            with torch.inference_mode():
                # Use vision_model + visual_projection directly to avoid
                # version differences in get_image_features() return type.
                vision_out = clip_model.vision_model(pixel_values=pixel_values)
                embeds = clip_model.visual_projection(vision_out.pooler_output)
                embeds = F.normalize(embeds, dim=-1)
                # Mean of unit vectors, re-normalized back onto the sphere.
                mean_embed = F.normalize(embeds.mean(dim=0), dim=-1)
            feature_img_embeds[feat] = mean_embed.cpu().float()
        # Progress line every 100 batches (1-based batch index).
        batch_idx = feat_start // args.batch_size
        if (batch_idx + 1) % 100 == 0:
            print(f"  [{feat_end}/{d_model}] features encoded", flush=True)
    print(f"  Done. Dead/missing features skipped: {dead_count}")

    # --- Compute alignment matrix ---
    print("Computing text alignment matrix...")
    # (n_features, clip_proj_dim) @ (clip_proj_dim, n_vocab) = (n_features, n_vocab)
    clip_text_scores = feature_img_embeds @ text_embeds.T  # float32
    print(f"  clip_text_scores: {clip_text_scores.shape}")

    # --- Save into explorer_data.pt ---
    output_path = args.output_path or args.data
    if output_path == args.data and not args.no_backup:
        bak_path = args.data + ".bak"
        print(f"Creating backup at {bak_path}...")
        shutil.copy2(args.data, bak_path)
    data['clip_text_scores'] = clip_text_scores.half()     # float16 to save space
    data['clip_feature_embeds'] = feature_img_embeds.half()  # float16
    data['clip_text_vocab'] = vocab
    print(f"Saving enriched explorer data to {output_path}...")
    torch.save(data, output_path)
    size_mb = os.path.getsize(output_path) / 1e6
    print(f"Saved ({size_mb:.1f} MB)")
    print("Done.")


if __name__ == "__main__":
    main()