Spaces:
Running
Running
Marlin Lee commited on
Commit ·
fee4ae4
1
Parent(s): 38c8638
Sync explorer_app.py and clip_utils.py from main repo
Browse files- scripts/explorer_app.py +34 -12
scripts/explorer_app.py
CHANGED
|
@@ -1115,8 +1115,9 @@ name_input.on_change('value', on_name_change)
|
|
| 1115 |
|
| 1116 |
|
| 1117 |
# ---------- Gemini auto-interp button ----------
|
| 1118 |
-
_N_GEMINI_IMAGES
|
| 1119 |
-
_GEMINI_MODEL
|
|
|
|
| 1120 |
|
| 1121 |
def _resolve_img_path(stored_path):
|
| 1122 |
"""Mirror the path resolution from auto_interp_vlm.py."""
|
|
@@ -1132,8 +1133,11 @@ def _resolve_img_path(stored_path):
|
|
| 1132 |
return None
|
| 1133 |
|
| 1134 |
|
| 1135 |
-
def _gemini_label_thread(feat,
|
| 1136 |
-
"""Run in a worker thread: call Gemini and push the result back to the doc.
|
|
|
|
|
|
|
|
|
|
| 1137 |
try:
|
| 1138 |
from google import genai
|
| 1139 |
from google.genai import types
|
|
@@ -1141,10 +1145,14 @@ def _gemini_label_thread(feat, mei_paths, doc):
|
|
| 1141 |
SYSTEM_PROMPT = (
|
| 1142 |
"You are labeling features of a Sparse Autoencoder (SAE) trained on a "
|
| 1143 |
"vision transformer. Each SAE feature is a sparse direction in activation "
|
| 1144 |
-
"space that fires strongly on certain visual patterns."
|
|
|
|
|
|
|
| 1145 |
)
|
| 1146 |
USER_PROMPT = (
|
| 1147 |
-
"The images below
|
|
|
|
|
|
|
| 1148 |
"In 2–5 words, give a precise label for the visual concept this feature detects. "
|
| 1149 |
"Be specific — prefer 'dog snout close-up' over 'dog', or 'brick wall texture' "
|
| 1150 |
"over 'texture'. "
|
|
@@ -1153,12 +1161,23 @@ def _gemini_label_thread(feat, mei_paths, doc):
|
|
| 1153 |
|
| 1154 |
client = genai.Client(api_key=_gemini_api_key)
|
| 1155 |
parts = []
|
| 1156 |
-
for
|
| 1157 |
-
resolved = _resolve_img_path(
|
| 1158 |
if resolved is None:
|
| 1159 |
continue
|
| 1160 |
try:
|
| 1161 |
img = Image.open(resolved).convert("RGB").resize((224, 224), Image.BILINEAR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1162 |
buf = io.BytesIO()
|
| 1163 |
img.save(buf, format="JPEG", quality=85)
|
| 1164 |
parts.append(types.Part.from_bytes(data=buf.getvalue(), mime_type="image/jpeg"))
|
|
@@ -1218,13 +1237,16 @@ def _on_gemini_click():
|
|
| 1218 |
return
|
| 1219 |
|
| 1220 |
n_top_stored = top_img_idx.shape[1]
|
| 1221 |
-
|
| 1222 |
for j in range(n_top_stored):
|
| 1223 |
idx = top_img_idx[feat, j].item()
|
| 1224 |
if idx >= 0:
|
| 1225 |
-
|
|
|
|
|
|
|
|
|
|
| 1226 |
|
| 1227 |
-
if not
|
| 1228 |
gemini_status_div.text = "<span style='color:#c00'>No MEI paths found.</span>"
|
| 1229 |
return
|
| 1230 |
|
|
@@ -1234,7 +1256,7 @@ def _on_gemini_click():
|
|
| 1234 |
doc = curdoc()
|
| 1235 |
t = threading.Thread(
|
| 1236 |
target=_gemini_label_thread,
|
| 1237 |
-
args=(feat,
|
| 1238 |
daemon=True,
|
| 1239 |
)
|
| 1240 |
t.start()
|
|
|
|
| 1115 |
|
| 1116 |
|
| 1117 |
# ---------- Gemini auto-interp button ----------
|
| 1118 |
+
_N_GEMINI_IMAGES = 6
|
| 1119 |
+
_GEMINI_MODEL = "gemini-2.5-flash"
|
| 1120 |
+
_GEMINI_HM_ALPHA = 0.25 # heatmap overlay opacity sent to Gemini
|
| 1121 |
|
| 1122 |
def _resolve_img_path(stored_path):
|
| 1123 |
"""Mirror the path resolution from auto_interp_vlm.py."""
|
|
|
|
| 1133 |
return None
|
| 1134 |
|
| 1135 |
|
| 1136 |
+
def _gemini_label_thread(feat, mei_items, doc):
|
| 1137 |
+
"""Run in a worker thread: call Gemini and push the result back to the doc.
|
| 1138 |
+
|
| 1139 |
+
mei_items: list of (path_str, heatmap_np_or_None) where heatmap is (H, W) float32.
|
| 1140 |
+
"""
|
| 1141 |
try:
|
| 1142 |
from google import genai
|
| 1143 |
from google.genai import types
|
|
|
|
| 1145 |
SYSTEM_PROMPT = (
|
| 1146 |
"You are labeling features of a Sparse Autoencoder (SAE) trained on a "
|
| 1147 |
"vision transformer. Each SAE feature is a sparse direction in activation "
|
| 1148 |
+
"space that fires strongly on certain visual patterns. "
|
| 1149 |
+
"Each image has a colour heatmap overlay highlighting the patches where "
|
| 1150 |
+
"the feature activates most strongly."
|
| 1151 |
)
|
| 1152 |
USER_PROMPT = (
|
| 1153 |
+
"The images below show the top maximally-activating images for one SAE feature, "
|
| 1154 |
+
"with a heatmap overlay showing where in each image the feature fires most strongly. "
|
| 1155 |
+
"Focus on the highlighted regions. "
|
| 1156 |
"In 2–5 words, give a precise label for the visual concept this feature detects. "
|
| 1157 |
"Be specific — prefer 'dog snout close-up' over 'dog', or 'brick wall texture' "
|
| 1158 |
"over 'texture'. "
|
|
|
|
| 1161 |
|
| 1162 |
client = genai.Client(api_key=_gemini_api_key)
|
| 1163 |
parts = []
|
| 1164 |
+
for path, heatmap in mei_items[:_N_GEMINI_IMAGES]:
|
| 1165 |
+
resolved = _resolve_img_path(path)
|
| 1166 |
if resolved is None:
|
| 1167 |
continue
|
| 1168 |
try:
|
| 1169 |
img = Image.open(resolved).convert("RGB").resize((224, 224), Image.BILINEAR)
|
| 1170 |
+
if heatmap is not None:
|
| 1171 |
+
img_arr = np.array(img).astype(np.float32) / 255.0
|
| 1172 |
+
hm_up = cv2.resize(heatmap.astype(np.float32), (224, 224),
|
| 1173 |
+
interpolation=cv2.INTER_CUBIC)
|
| 1174 |
+
hmax = hm_up.max()
|
| 1175 |
+
if hmax > 0:
|
| 1176 |
+
hm_up /= hmax
|
| 1177 |
+
overlay = ALPHA_JET(hm_up)
|
| 1178 |
+
ov_alpha = overlay[:, :, 3:4] * _GEMINI_HM_ALPHA
|
| 1179 |
+
blended = img_arr * (1 - ov_alpha) + overlay[:, :, :3] * ov_alpha
|
| 1180 |
+
img = Image.fromarray(np.clip(blended * 255, 0, 255).astype(np.uint8))
|
| 1181 |
buf = io.BytesIO()
|
| 1182 |
img.save(buf, format="JPEG", quality=85)
|
| 1183 |
parts.append(types.Part.from_bytes(data=buf.getvalue(), mime_type="image/jpeg"))
|
|
|
|
| 1237 |
return
|
| 1238 |
|
| 1239 |
n_top_stored = top_img_idx.shape[1]
|
| 1240 |
+
mei_items = []
|
| 1241 |
for j in range(n_top_stored):
|
| 1242 |
idx = top_img_idx[feat, j].item()
|
| 1243 |
if idx >= 0:
|
| 1244 |
+
hm = None
|
| 1245 |
+
if top_heatmaps is not None:
|
| 1246 |
+
hm = top_heatmaps[feat, j].float().numpy().reshape(patch_grid, patch_grid)
|
| 1247 |
+
mei_items.append((image_paths[idx], hm))
|
| 1248 |
|
| 1249 |
+
if not mei_items:
|
| 1250 |
gemini_status_div.text = "<span style='color:#c00'>No MEI paths found.</span>"
|
| 1251 |
return
|
| 1252 |
|
|
|
|
| 1256 |
doc = curdoc()
|
| 1257 |
t = threading.Thread(
|
| 1258 |
target=_gemini_label_thread,
|
| 1259 |
+
args=(feat, mei_items, doc),
|
| 1260 |
daemon=True,
|
| 1261 |
)
|
| 1262 |
t.start()
|