Samuel Stevens committed
Commit · 6c9f92c
Parent(s): 852b07a
add mod preds; todo: add legend
app.py
CHANGED
@@ -313,7 +313,7 @@ def get_orig_preds(i: int) -> dict[str, object]:
 
 
 @beartype.beartype
-def unscaled(x: float, max_obs: float) -> float:
+def unscaled(x: float, max_obs: float | int) -> float:
     """Scale from [-10, 10] to [10 * -max_obs, 10 * max_obs]."""
     return map_range(x, (-10.0, 10.0), (-10.0 * max_obs, 10.0 * max_obs))
 
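unscaled delegates the rescaling to map_range, whose body lies outside this diff. A minimal sketch consistent with how unscaled calls it; the parameter names are assumptions, not the repo's:

def map_range(x: float, domain: tuple[float, float], range_: tuple[float, float]) -> float:
    # Hypothetical linear rescale of x from `domain` onto `range_`; the real
    # map_range in app.py may differ in signature and validation.
    (lo, hi), (a, b) = domain, range_
    return a + (x - lo) * (b - a) / (hi - lo)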
@@ -333,12 +333,17 @@ def map_range(
 
 @beartype.beartype
 @torch.inference_mode
-def get_mod_preds(i: int, latents: dict[int, float]) -> dict[str, object]:
-
-
-    x = sample["image"][None, ...].to(device)
-    x_BPD = rest_of_vit.forward_start(x)
+def get_mod_preds(i: int, latents: dict[str, int | float]) -> dict[str, object]:
+    latents = {int(k): float(v) for k, v in latents.items()}
+    img = data.get_img(i)
 
+    split_vit, vit_transform = modeling.load_vit(DEVICE)
+    sae = load_sae(DEVICE)
+    _, top_values, _ = load_tensors()
+    clf = load_clf()
+
+    x_BCWH = vit_transform(img)[None, ...].to(DEVICE)
+    x_BPD = split_vit.forward_start(x_BCWH)
     x_hat_BPD, f_x_BPS, _ = sae(x_BPD)
 
     err_BPD = x_BPD - x_hat_BPD
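forward_start and forward_end split the ViT so activations can be edited mid-network. A toy sketch of that idea; this is not modeling.load_vit, and the class and cut point are purely illustrative:

import torch

class SplitModel(torch.nn.Module):
    """Toy stand-in for the split ViT: run to a cut point, edit, resume."""

    def __init__(self, blocks: list[torch.nn.Module], cut: int):
        super().__init__()
        self.start = torch.nn.Sequential(*blocks[:cut])
        self.end = torch.nn.Sequential(*blocks[cut:])

    def forward_start(self, x: torch.Tensor) -> torch.Tensor:
        return self.start(x)  # activations handed to the SAE (x_BPD above)

    def forward_end(self, x: torch.Tensor) -> torch.Tensor:
        return self.end(x)  # finish the pass on the edited activations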
@@ -346,18 +351,14 @@ def get_mod_preds(i: int, latents: dict[int, float]) -> dict[str, object]:
     values = torch.tensor(
         [
             unscaled(float(value), top_values[latent].max().item())
-            for
-            (value1, latent1),
-            (value2, latent2),
-            (value3, latent3),
-            ]
+            for latent, value in latents.items()
         ],
-        device=
+        device=DEVICE,
     )
-    f_x_BPS[..., torch.tensor(
+    f_x_BPS[..., torch.tensor(list(latents.keys()), device=DEVICE)] = values
 
     # Reproduce the SAE forward pass after f_x
-
+    mod_x_hat_BPD = (
         einops.einsum(
             f_x_BPS,
             sae.W_dec,
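The hunk above overwrites selected SAE latents with user-chosen values before decoding. A self-contained sketch of that indexing step; the shapes and numbers are illustrative, not taken from the app:

import torch

# Toy shapes: 1 image, 4 patches, 8 SAE latents.
f_x_BPS = torch.zeros(1, 4, 8)
latents = {3: 10.0, 5: -2.5}   # latent index -> slider value in [-10, 10]
max_obs = {3: 4.0, 5: 7.0}     # max activation observed per latent

# unscaled() maps [-10, 10] linearly onto [-10 * max_obs, 10 * max_obs],
# which for this symmetric range reduces to value * max_obs.
values = torch.tensor([v * max_obs[k] for k, v in latents.items()])

# The (2,) values broadcast over the batch and patch dims, as in the diff.
f_x_BPS[..., torch.tensor(list(latents.keys()))] = values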
@@ -365,14 +366,19 @@ def get_mod_preds(i: int, latents: dict[int, float]) -> dict[str, object]:
         )
         + sae.b_dec
     )
-
+    mod_BPD = err_BPD + mod_x_hat_BPD
 
-
+    mod_BPD = split_vit.forward_end(mod_BPD)
+    mod_WHD = einops.rearrange(mod_BPD, "() (w h) dim -> w h dim", w=16, h=16)
 
-
-
-    pred_WH = einops.rearrange(pred_P, "(w h) -> w h", w=16, h=16)
-    return
+    logits_WHC = clf(mod_WHD)
+    pred_WH = logits_WHC.argmax(axis=-1)
+    # pred_WH = einops.rearrange(pred_P, "(w h) -> w h", w=16, h=16)
+    return {
+        "index": i,
+        "orig_url": data.img_to_base64(data.to_sized(img)),
+        "seg_url": data.img_to_base64(data.u8_to_img(upsample(pred_WH))),
+    }
 
 
 @jaxtyped(typechecker=beartype.beartype)
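For reference, the "reproduce the SAE forward pass" block decodes the edited latents and adds back the reconstruction error before resuming the ViT. A toy sketch with made-up shapes; the einsum pattern string falls outside the diff, so the "b p s, s d -> b p d" pattern here is an assumption:

import einops
import torch

B, P, S, D = 1, 256, 1024, 768            # batch, patches, SAE latents, ViT dim
f_x_BPS = torch.rand(B, P, S)
W_dec, b_dec = torch.rand(S, D), torch.rand(D)
err_BPD = torch.rand(B, P, D)             # x_BPD - x_hat_BPD from the diff

# Assumed decode: sparse codes times decoder weights, plus decoder bias.
mod_x_hat_BPD = einops.einsum(f_x_BPS, W_dec, "b p s, s d -> b p d") + b_dec
mod_BPD = err_BPD + mod_x_hat_BPD         # keep the SAE's reconstruction error

With the new signature, the function would be called with string keys, e.g. get_mod_preds(0, {"123": 5.0}), which its first line normalizes to {123: 5.0}.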