"""
CLIP text-alignment utilities for SAE feature interpretation.

Key functions:
  - compute_text_embeddings: encode text strings into L2-normalised CLIP embeddings.
  - compute_mei_text_alignment: align SAE features to text via their top MEI images.
  - compute_text_alignment: dot-product similarity between precomputed feature/text embeds.
  - search_features_by_text: find top-k features for a free-text query.

The precomputed scores can be stored in explorer_data.pt under:
  'clip_text_scores'   : Tensor (n_features, n_vocab) float16
  'clip_text_vocab'    : list[str]
  'clip_feature_embeds': Tensor (n_features, clip_proj_dim) float32
"""

import torch
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------

def load_clip(
    device: str | torch.device = "cpu",
    model_name: str = "openai/clip-vit-large-patch14",
) -> tuple[CLIPModel, CLIPProcessor]:
    """
    Load a CLIP model and processor.

    Parameters
    ----------
    device : str or torch.device
    model_name : str
        HuggingFace model ID.  Default matches the ViT-L/14 variant used by
        many vision papers and is a reasonable match for DINOv3-ViT-L/16.

    Returns
    -------
    model : CLIPModel (eval mode, on device)
    processor : CLIPProcessor
    """
    print(f"Loading CLIP ({model_name})...")
    processor = CLIPProcessor.from_pretrained(model_name)
    model = CLIPModel.from_pretrained(model_name, torch_dtype=torch.float32)
    model = model.to(device).eval()
    print(f"  CLIP loaded (d_text={model.config.projection_dim})")
    return model, processor


# ---------------------------------------------------------------------------
# Core alignment computation
# ---------------------------------------------------------------------------

def compute_text_embeddings(
    texts: list[str],
    model: CLIPModel,
    processor: CLIPProcessor,
    device: str | torch.device,
    batch_size: int = 256,
) -> torch.Tensor:
    """
    Encode a list of text strings into L2-normalised CLIP text embeddings.

    Parameters
    ----------
    texts : list[str]
        Strings to encode.  May be empty.
    model, processor : CLIP model components
    device : str or torch.device
        Device the tokenised inputs are moved to (should match the model's).
    batch_size : int
        Number of strings tokenised and encoded per forward pass.

    Returns
    -------
    Tensor of shape (len(texts), clip_proj_dim), float32, on CPU.
    For empty `texts` the result has shape (0, model.config.projection_dim).
    """
    if not texts:
        # torch.cat refuses an empty list; return a well-shaped empty result
        # so downstream matmuls still work.
        return torch.empty(0, model.config.projection_dim, dtype=torch.float32)

    all_embeds = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        inputs = processor(text=batch, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.inference_mode():
            # Go through text_model + text_projection directly to avoid
            # version differences in get_text_features() return type.
            text_out = model.text_model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask'),
            )
            embeds = model.text_projection(text_out.pooler_output)
            embeds = F.normalize(embeds, dim=-1)
        all_embeds.append(embeds.cpu().float())
    return torch.cat(all_embeds, dim=0)  # (n_texts, clip_proj_dim)


def compute_text_alignment(
    feature_vision_embeds: torch.Tensor,
    text_embeds: torch.Tensor,
) -> torch.Tensor:
    """
    Compute pairwise cosine similarity between feature embeddings and text
    embeddings.  Both inputs must already be L2-normalised; this function
    only performs the matrix product and does not normalise anything itself.

    Parameters
    ----------
    feature_vision_embeds : Tensor (n_features, d)
    text_embeds : Tensor (n_texts, d)

    Returns
    -------
    Tensor (n_features, n_texts) of cosine similarities in [-1, 1].
    """
    return feature_vision_embeds @ text_embeds.T  # (n_features, n_texts)


# ---------------------------------------------------------------------------
# MEI-based text alignment (more accurate, more expensive)
# ---------------------------------------------------------------------------

def compute_mei_text_alignment(
    top_img_paths: list[list[str]],
    texts: list[str],
    model: CLIPModel,
    processor: CLIPProcessor,
    device: str | torch.device,
    n_top_images: int = 4,
    batch_size: int = 32,
) -> torch.Tensor:
    """
    For each feature, compute the mean CLIP image embedding of its top-N MEIs,
    then return cosine similarity against each text embedding.

    This is the most principled approach: CLIP operates on actual images, so
    the alignment reflects the true visual concept captured by the feature.

    Parameters
    ----------
    top_img_paths : list of lists
        top_img_paths[i] = list of image file paths for feature i's MEIs.
        Falsy entries (None / "") are skipped.
    texts : list[str]
        Text queries / vocabulary concepts.
    model, processor : CLIP model components
    device : str or torch.device
    n_top_images : int
        How many MEIs to average per feature.
    batch_size : int
        Number of images encoded per vision forward pass.  Images from
        different features are batched together; values below 1 are
        treated as 1.

    Returns
    -------
    Tensor (n_features, n_texts) float32, on CPU.  A feature with no usable
    image path gets an all-zero embedding, hence a score of 0 for every text.
    """
    from PIL import Image

    text_embeds = compute_text_embeddings(texts, model, processor, device)
    # text_embeds: (n_texts, d)

    n_features = len(top_img_paths)

    # Flatten to parallel lists of (owning feature index, image path) so that
    # images can be batched across features instead of one tiny batch each.
    owners: list[int] = []
    flat_paths: list[str] = []
    for feat_idx, feat_paths in enumerate(top_img_paths):
        for path in feat_paths[:n_top_images]:
            if path:
                owners.append(feat_idx)
                flat_paths.append(path)

    # Per-feature running sum of L2-normalised image embeddings.  Rows of
    # features without images stay zero.
    embed_sums = torch.zeros(n_features, model.config.projection_dim)

    step = max(1, batch_size)
    for start in range(0, len(flat_paths), step):
        imgs = []
        for path in flat_paths[start : start + step]:
            # convert() returns an independent in-memory copy, so the file
            # handle can be released immediately.
            with Image.open(path) as img:
                imgs.append(img.convert("RGB"))

        inputs = processor(images=imgs, return_tensors="pt")
        pixel_values = inputs['pixel_values'].to(device)
        with torch.inference_mode():
            vision_out = model.vision_model(pixel_values=pixel_values)
            img_embeds = model.visual_projection(vision_out.pooler_output)  # (n_imgs, d)
            img_embeds = F.normalize(img_embeds, dim=-1)

        owner_idx = torch.tensor(owners[start : start + step], dtype=torch.long)
        embed_sums.index_add_(0, owner_idx, img_embeds.cpu().float())

    # Normalising the sum yields the same direction as normalising the mean;
    # all-zero rows remain zero because F.normalize clamps the denominator.
    feature_img_embeds = F.normalize(embed_sums, dim=-1)  # (n_feat, d)
    return feature_img_embeds @ text_embeds.T             # (n_feat, n_texts)


# ---------------------------------------------------------------------------
# Feature search by free-text query
# ---------------------------------------------------------------------------

def search_features_by_text(
    query: str,
    clip_scores: torch.Tensor,
    vocab: list[str],
    model: CLIPModel,
    processor: CLIPProcessor,
    device: str | torch.device,
    top_k: int = 20,
    feature_embeds: torch.Tensor | None = None,
) -> list[tuple[int, float]]:
    """
    Find the top-k SAE features most aligned with a free-text query.

    If the query is already in `vocab`, use the precomputed scores directly.
    Otherwise encode the query on-the-fly and compute dot products against
    `feature_embeds` (the per-feature MEI image embeddings stored as
    'clip_feature_embeds' in explorer_data.pt).

    Parameters
    ----------
    query : str
    clip_scores : Tensor (n_features, n_vocab)
        Precomputed alignment matrix (L2-normalised features × L2-normalised
        text embeddings).
    vocab : list[str]
    model, processor, device : CLIP model components (used for on-the-fly encoding)
    top_k : int
        Maximum number of results; clamped to [0, n_features].
    feature_embeds : Tensor (n_features, clip_proj_dim) or None
        L2-normalised per-feature MEI image embeddings.  Required for
        free-text queries that are not in `vocab`.

    Returns
    -------
    list of (feature_idx, score) sorted by score descending.

    Raises
    ------
    ValueError
        If `query` is not in `vocab` and `feature_embeds` is None.
    """
    # Single scan of the vocabulary instead of `in` followed by `.index()`.
    try:
        col = vocab.index(query)
    except ValueError:
        col = None

    if col is not None:
        scores_vec = clip_scores[:, col].float()   # (n_features,)
    else:
        if feature_embeds is None:
            raise ValueError(
                "Free-text query requires 'feature_embeds' (clip_feature_embeds "
                "from explorer_data.pt). Pass feature_embeds=data['clip_feature_embeds'] "
                "or restrict queries to vocab terms."
            )
        q_embed = compute_text_embeddings([query], model, processor, device)  # (1, d)
        scores_vec = (feature_embeds.float() @ q_embed.T).squeeze(-1)         # (n_features,)

    # Clamp k so that neither an oversized nor a negative top_k reaches topk.
    k = max(0, min(top_k, scores_vec.shape[0]))
    top = torch.topk(scores_vec, k=k)
    return list(zip(top.indices.tolist(), top.values.tolist()))