File size: 8,215 Bytes
884a21e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""
CLIP text-alignment utilities for SAE feature interpretation.

Key functions:
- compute_text_embeddings: encode text strings into L2-normalised CLIP embeddings.
- compute_mei_text_alignment: align SAE features to text via their top MEI images.
- compute_text_alignment: dot-product similarity between precomputed feature/text embeds.
- search_features_by_text: find top-k features for a free-text query.

The precomputed scores can be stored in explorer_data.pt under:
    'clip_text_scores'   : Tensor (n_features, n_vocab)  float16
    'clip_text_vocab'    : list[str]
    'clip_feature_embeds': Tensor (n_features, clip_proj_dim)  float32
"""

import torch
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------

def load_clip(device: str | torch.device = "cpu", model_name: str = "openai/clip-vit-large-patch14"):
    """
    Construct a CLIP model/processor pair ready for inference.

    Parameters
    ----------
    device : str or torch.device
        Device the model is moved to.
    model_name : str
        HuggingFace model ID.  Default matches the ViT-L/14 variant used by
        many vision papers and is a reasonable match for DINOv3-ViT-L/16.

    Returns
    -------
    model : CLIPModel (eval mode, on device)
    processor : CLIPProcessor
    """
    print(f"Loading CLIP ({model_name})...")
    clip_processor = CLIPProcessor.from_pretrained(model_name)
    clip_model = (
        CLIPModel.from_pretrained(model_name, torch_dtype=torch.float32)
        .to(device)
        .eval()
    )
    print(f"  CLIP loaded (d_text={clip_model.config.projection_dim})")
    return clip_model, clip_processor


# ---------------------------------------------------------------------------
# Core alignment computation
# ---------------------------------------------------------------------------

def compute_text_embeddings(
    texts: list[str],
    model: CLIPModel,
    processor: CLIPProcessor,
    device: str | torch.device,
    batch_size: int = 256,
) -> torch.Tensor:
    """
    Encode a list of text strings into L2-normalised CLIP text embeddings.

    Parameters
    ----------
    texts : list[str]
        Strings to encode.  May be empty.
    model : CLIPModel
    processor : CLIPProcessor
    device : str or torch.device
        Device the model lives on; tokenised inputs are moved there.
    batch_size : int
        Number of texts encoded per forward pass.

    Returns
    -------
    Tensor of shape (len(texts), clip_proj_dim), float32, on CPU.
    """
    # torch.cat over an empty list raises RuntimeError; return a correctly
    # shaped empty tensor instead so callers can handle "no texts" uniformly.
    if not texts:
        return torch.empty(0, model.config.projection_dim, dtype=torch.float32)

    all_embeds = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start : start + batch_size]
        inputs = processor(text=batch, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.inference_mode():
            # Go through text_model + text_projection directly to avoid
            # version differences in get_text_features() return type.
            text_out = model.text_model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs.get('attention_mask'),
            )
            embeds = model.text_projection(text_out.pooler_output)
            embeds = F.normalize(embeds, dim=-1)
        all_embeds.append(embeds.cpu().float())
    return torch.cat(all_embeds, dim=0)  # (n_texts, clip_proj_dim)


def compute_text_alignment(
    feature_vision_embeds: torch.Tensor,
    text_embeds: torch.Tensor,
) -> torch.Tensor:
    """
    Pairwise cosine similarity between feature and text embeddings.

    Both inputs are expected to be L2-normalised already, so a plain matrix
    product yields cosine similarities directly.

    Parameters
    ----------
    feature_vision_embeds : Tensor (n_features, d)
    text_embeds : Tensor (n_texts, d)

    Returns
    -------
    Tensor (n_features, n_texts) of cosine similarities in [-1, 1].
    """
    similarity = torch.matmul(feature_vision_embeds, text_embeds.transpose(0, 1))
    return similarity  # (n_features, n_texts)


# ---------------------------------------------------------------------------
# MEI-based text alignment (more accurate, more expensive)
# ---------------------------------------------------------------------------

def compute_mei_text_alignment(
    top_img_paths: list[list[str]],
    texts: list[str],
    model: CLIPModel,
    processor: CLIPProcessor,
    device: str | torch.device,
    n_top_images: int = 4,
    batch_size: int = 32,
) -> torch.Tensor:
    """
    For each feature, compute the mean CLIP image embedding of its top-N MEIs,
    then return cosine similarity against each text embedding.

    This is the most principled approach: CLIP operates on actual images, so
    the alignment reflects the true visual concept captured by the feature.

    Parameters
    ----------
    top_img_paths : list of lists
        top_img_paths[i] = list of image file paths for feature i's MEIs.
        Features with no valid paths get an all-zero embedding (zero
        similarity to every text).
    texts : list[str]
        Text queries / vocabulary concepts.
    model : CLIPModel
    processor : CLIPProcessor
    device : str or torch.device
    n_top_images : int
        How many MEIs to average per feature.
    batch_size : int
        Maximum number of images per vision forward pass.

    Returns
    -------
    Tensor (n_features, n_texts) float32, on CPU.
    """
    from PIL import Image

    text_embeds = compute_text_embeddings(texts, model, processor, device)
    # text_embeds: (n_texts, d)

    # torch.stack([]) raises; handle a degenerate empty feature list up front.
    if not top_img_paths:
        return torch.empty(0, len(texts), dtype=torch.float32)

    feature_img_embeds = []
    for feat_paths in top_img_paths:
        paths = [p for p in feat_paths[:n_top_images] if p]
        if not paths:
            feature_img_embeds.append(torch.zeros(model.config.projection_dim))
            continue

        # Encode in chunks of `batch_size` so a large n_top_images cannot
        # blow up a single forward pass.  (Previously `batch_size` was
        # accepted but never used.)
        chunk_embeds = []
        for start in range(0, len(paths), batch_size):
            imgs = []
            for p in paths[start : start + batch_size]:
                # Close the file-backed image promptly; .convert() returns an
                # in-memory copy, so no file handles leak across features.
                with Image.open(p) as im:
                    imgs.append(im.convert("RGB"))
            inputs = processor(images=imgs, return_tensors="pt")
            pixel_values = inputs['pixel_values'].to(device)
            with torch.inference_mode():
                vision_out = model.vision_model(pixel_values=pixel_values)
                img_embeds = model.visual_projection(vision_out.pooler_output)  # (chunk, d)
                chunk_embeds.append(F.normalize(img_embeds, dim=-1))

        img_embeds = torch.cat(chunk_embeds, dim=0)          # (n_imgs, d)
        mean_embed = F.normalize(img_embeds.mean(dim=0), dim=-1)
        feature_img_embeds.append(mean_embed.cpu().float())

    feature_img_embeds = torch.stack(feature_img_embeds, dim=0)  # (n_feat, d)
    return feature_img_embeds @ text_embeds.T                     # (n_feat, n_texts)


# ---------------------------------------------------------------------------
# Feature search by free-text query
# ---------------------------------------------------------------------------

def search_features_by_text(
    query: str,
    clip_scores: torch.Tensor,
    vocab: list[str],
    model: CLIPModel,
    processor: CLIPProcessor,
    device: str | torch.device,
    top_k: int = 20,
    feature_embeds: torch.Tensor | None = None,
) -> list[tuple[int, float]]:
    """
    Find the top-k SAE features most aligned with a free-text query.

    If the query is already in `vocab`, use the precomputed scores directly.
    Otherwise encode the query on-the-fly and compute dot products against
    `feature_embeds` (the per-feature MEI image embeddings stored as
    'clip_feature_embeds' in explorer_data.pt).

    Parameters
    ----------
    query : str
    clip_scores : Tensor (n_features, n_vocab)
        Precomputed alignment matrix (L2-normalised features × L2-normalised
        text embeddings).
    vocab : list[str]
    model, processor, device : CLIP model components (used for on-the-fly encoding)
    top_k : int
    feature_embeds : Tensor (n_features, clip_proj_dim) or None
        L2-normalised per-feature MEI image embeddings.  Required for
        free-text queries that are not in `vocab`.

    Returns
    -------
    list of (feature_idx, score) sorted by score descending.

    Raises
    ------
    ValueError
        If the query is not in `vocab` and `feature_embeds` is None.
    """
    # EAFP: a single vocab scan instead of `in` + `.index()`.
    try:
        col = vocab.index(query)
    except ValueError:
        col = None

    if col is not None:
        # Fast path: precomputed column for a known vocabulary term.
        scores_vec = clip_scores[:, col].float()                 # (n_features,)
    elif feature_embeds is None:
        raise ValueError(
            "Free-text query requires 'feature_embeds' (clip_feature_embeds "
            "from explorer_data.pt).  Pass feature_embeds=data['clip_feature_embeds'] "
            "or restrict queries to vocab terms."
        )
    else:
        # Slow path: encode the query text and score against MEI embeddings.
        q_embed = compute_text_embeddings([query], model, processor, device)  # (1, d)
        scores_vec = (feature_embeds.float() @ q_embed.T).squeeze(-1)         # (n_features,)

    k = min(top_k, len(scores_vec))
    top_values, top_indices = torch.topk(scores_vec, k=k)
    return [(int(idx), float(val)) for idx, val in zip(top_indices, top_values)]