Marlin Lee committed on
Commit
fee4ae4
·
1 Parent(s): 38c8638

Sync explorer_app.py and clip_utils.py from main repo

Browse files
Files changed (1) hide show
  1. scripts/explorer_app.py +34 -12
scripts/explorer_app.py CHANGED
@@ -1115,8 +1115,9 @@ name_input.on_change('value', on_name_change)
1115
 
1116
 
1117
  # ---------- Gemini auto-interp button ----------
1118
- _N_GEMINI_IMAGES = 6
1119
- _GEMINI_MODEL = "gemini-2.5-flash"
 
1120
 
1121
  def _resolve_img_path(stored_path):
1122
  """Mirror the path resolution from auto_interp_vlm.py."""
@@ -1132,8 +1133,11 @@ def _resolve_img_path(stored_path):
1132
  return None
1133
 
1134
 
1135
- def _gemini_label_thread(feat, mei_paths, doc):
1136
- """Run in a worker thread: call Gemini and push the result back to the doc."""
 
 
 
1137
  try:
1138
  from google import genai
1139
  from google.genai import types
@@ -1141,10 +1145,14 @@ def _gemini_label_thread(feat, mei_paths, doc):
1141
  SYSTEM_PROMPT = (
1142
  "You are labeling features of a Sparse Autoencoder (SAE) trained on a "
1143
  "vision transformer. Each SAE feature is a sparse direction in activation "
1144
- "space that fires strongly on certain visual patterns."
 
 
1145
  )
1146
  USER_PROMPT = (
1147
- "The images below are the top maximally-activating images for one SAE feature. "
 
 
1148
  "In 2–5 words, give a precise label for the visual concept this feature detects. "
1149
  "Be specific — prefer 'dog snout close-up' over 'dog', or 'brick wall texture' "
1150
  "over 'texture'. "
@@ -1153,12 +1161,23 @@ def _gemini_label_thread(feat, mei_paths, doc):
1153
 
1154
  client = genai.Client(api_key=_gemini_api_key)
1155
  parts = []
1156
- for p in mei_paths[:_N_GEMINI_IMAGES]:
1157
- resolved = _resolve_img_path(p)
1158
  if resolved is None:
1159
  continue
1160
  try:
1161
  img = Image.open(resolved).convert("RGB").resize((224, 224), Image.BILINEAR)
 
 
 
 
 
 
 
 
 
 
 
1162
  buf = io.BytesIO()
1163
  img.save(buf, format="JPEG", quality=85)
1164
  parts.append(types.Part.from_bytes(data=buf.getvalue(), mime_type="image/jpeg"))
@@ -1218,13 +1237,16 @@ def _on_gemini_click():
1218
  return
1219
 
1220
  n_top_stored = top_img_idx.shape[1]
1221
- mei_paths = []
1222
  for j in range(n_top_stored):
1223
  idx = top_img_idx[feat, j].item()
1224
  if idx >= 0:
1225
- mei_paths.append(image_paths[idx])
 
 
 
1226
 
1227
- if not mei_paths:
1228
  gemini_status_div.text = "<span style='color:#c00'>No MEI paths found.</span>"
1229
  return
1230
 
@@ -1234,7 +1256,7 @@ def _on_gemini_click():
1234
  doc = curdoc()
1235
  t = threading.Thread(
1236
  target=_gemini_label_thread,
1237
- args=(feat, mei_paths, doc),
1238
  daemon=True,
1239
  )
1240
  t.start()
 
1115
 
1116
 
1117
  # ---------- Gemini auto-interp button ----------
1118
+ _N_GEMINI_IMAGES = 6
1119
+ _GEMINI_MODEL = "gemini-2.5-flash"
1120
+ _GEMINI_HM_ALPHA = 0.25 # heatmap overlay opacity sent to Gemini
1121
 
1122
  def _resolve_img_path(stored_path):
1123
  """Mirror the path resolution from auto_interp_vlm.py."""
 
1133
  return None
1134
 
1135
 
1136
+ def _gemini_label_thread(feat, mei_items, doc):
1137
+ """Run in a worker thread: call Gemini and push the result back to the doc.
1138
+
1139
+ mei_items: list of (path_str, heatmap_np_or_None) where heatmap is (H, W) float32.
1140
+ """
1141
  try:
1142
  from google import genai
1143
  from google.genai import types
 
1145
  SYSTEM_PROMPT = (
1146
  "You are labeling features of a Sparse Autoencoder (SAE) trained on a "
1147
  "vision transformer. Each SAE feature is a sparse direction in activation "
1148
+ "space that fires strongly on certain visual patterns. "
1149
+ "Each image has a colour heatmap overlay highlighting the patches where "
1150
+ "the feature activates most strongly."
1151
  )
1152
  USER_PROMPT = (
1153
+ "The images below show the top maximally-activating images for one SAE feature, "
1154
+ "with a heatmap overlay showing where in each image the feature fires most strongly. "
1155
+ "Focus on the highlighted regions. "
1156
  "In 2–5 words, give a precise label for the visual concept this feature detects. "
1157
  "Be specific — prefer 'dog snout close-up' over 'dog', or 'brick wall texture' "
1158
  "over 'texture'. "
 
1161
 
1162
  client = genai.Client(api_key=_gemini_api_key)
1163
  parts = []
1164
+ for path, heatmap in mei_items[:_N_GEMINI_IMAGES]:
1165
+ resolved = _resolve_img_path(path)
1166
  if resolved is None:
1167
  continue
1168
  try:
1169
  img = Image.open(resolved).convert("RGB").resize((224, 224), Image.BILINEAR)
1170
+ if heatmap is not None:
1171
+ img_arr = np.array(img).astype(np.float32) / 255.0
1172
+ hm_up = cv2.resize(heatmap.astype(np.float32), (224, 224),
1173
+ interpolation=cv2.INTER_CUBIC)
1174
+ hmax = hm_up.max()
1175
+ if hmax > 0:
1176
+ hm_up /= hmax
1177
+ overlay = ALPHA_JET(hm_up)
1178
+ ov_alpha = overlay[:, :, 3:4] * _GEMINI_HM_ALPHA
1179
+ blended = img_arr * (1 - ov_alpha) + overlay[:, :, :3] * ov_alpha
1180
+ img = Image.fromarray(np.clip(blended * 255, 0, 255).astype(np.uint8))
1181
  buf = io.BytesIO()
1182
  img.save(buf, format="JPEG", quality=85)
1183
  parts.append(types.Part.from_bytes(data=buf.getvalue(), mime_type="image/jpeg"))
 
1237
  return
1238
 
1239
  n_top_stored = top_img_idx.shape[1]
1240
+ mei_items = []
1241
  for j in range(n_top_stored):
1242
  idx = top_img_idx[feat, j].item()
1243
  if idx >= 0:
1244
+ hm = None
1245
+ if top_heatmaps is not None:
1246
+ hm = top_heatmaps[feat, j].float().numpy().reshape(patch_grid, patch_grid)
1247
+ mei_items.append((image_paths[idx], hm))
1248
 
1249
+ if not mei_items:
1250
  gemini_status_div.text = "<span style='color:#c00'>No MEI paths found.</span>"
1251
  return
1252
 
 
1256
  doc = curdoc()
1257
  t = threading.Thread(
1258
  target=_gemini_label_thread,
1259
+ args=(feat, mei_items, doc),
1260
  daemon=True,
1261
  )
1262
  t.start()