File size: 7,973 Bytes
"""Visual Steerability — steer PaliGemma's vision model from the inside (ZeroGPU Space).
Pass an image -> see its L10 SAE neuron activations (the model's internal representation) ->
clamp a neuron's latent -> watch PaliGemma's caption change. Companion to the explainer at
https://sumityadav.com.np/study/steerability-spectrum/
"""
import os, json
import numpy as np
import torch, torch.nn as nn
import plotly.graph_objects as go
import gradio as gr
from PIL import Image as PILImage
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
try:
    import spaces  # ZeroGPU runtime
    GPU = spaces.GPU
except Exception:  # not on ZeroGPU (e.g. a local run): use a pass-through decorator
    def GPU(*args, **kwargs):
        """No-op stand-in for ``spaces.GPU``.

        Works both bare (``@GPU``) and parameterised (``@GPU(duration=...)``);
        in either form the decorated function is returned unchanged.
        """
        def _identity(fn):
            return fn
        used_bare = bool(args) and callable(args[0])
        return args[0] if used_bare else _identity

MODEL = "google/paligemma-3b-mix-224"  # Hub id passed to from_pretrained
SAE_PT = "sae_l10_coco.pt"  # SAE checkpoint path read with torch.load
DASH = "dash_data.npz"  # reference coords / latents / PCA arrays read with np.load
HF_TOKEN = os.environ.get("HF_TOKEN")  # None when the variable is unset
class TopKSAE(nn.Module):
    """Top-k sparse autoencoder: encoder plus decoder weight matrix.

    All parameters start at zero and are meant to be filled by
    ``load_state_dict``; the attribute names (``b_pre``, ``W_enc``, ``b_enc``,
    ``W_dec``) are the state-dict keys, so they must not change.

    Args:
        d: input width.
        m: number of latents.
        k: number of latents kept per row by ``encode``.
    """

    def __init__(self, d, m, k):
        super().__init__()
        self.k = k
        # Encoder: subtract b_pre, project d -> m, add b_enc.
        self.b_pre = nn.Parameter(torch.zeros(d))
        self.W_enc = nn.Parameter(torch.zeros(d, m))
        self.b_enc = nn.Parameter(torch.zeros(m))
        # Decoder: one row of width d per latent.
        self.W_dec = nn.Parameter(torch.zeros(m, d))

    def encode(self, x):
        """Map ``x`` of shape ``(..., d)`` to sparse latents of shape ``(..., m)``.

        Only the ``k`` largest pre-activations per row are kept (each through
        ReLU); all other latents are zero.
        """
        logits = (x - self.b_pre) @ self.W_enc + self.b_enc
        kept_vals, kept_idx = torch.topk(logits, self.k, dim=-1)
        return torch.zeros_like(logits).scatter(-1, kept_idx, kept_vals.relu())
# ---- load (CPU at import; moved to GPU inside @GPU fns on ZeroGPU) ----
print("loading SAE + map …")
# weights_only=False unpickles arbitrary objects: acceptable only while SAE_PT
# is a trusted local file — never point it at user-supplied data.
ckpt = torch.load(SAE_PT, map_location="cpu", weights_only=False)
# d = SAE input width, m = number of latents, k = top-k, l = layer number.
# The code reads hidden_states[L] and hooks encoder.layers[L - 1], i.e. it
# assumes hidden_states[0] is the embedding output.
D, M, K, L = ckpt["d"], ckpt["m"], ckpt["k"], ckpt["l"]
SN = float(ckpt["scale_norm"])  # multiplier applied after subtracting mu (see _hook / _latent)
sae = TopKSAE(D, M, K)
sae.load_state_dict(ckpt["state_dict"])
sae.eval()
mu = torch.tensor(ckpt["mu"], dtype=torch.float32)  # offset subtracted from layer activations
# Kept open on purpose: pca_mean / pca_comp are read from it later, per request.
_d = np.load(DASH, allow_pickle=True)
COORDS = _d["coords"]  # map coordinates of the reference points (columns 0 and 1 are plotted)
LAT = _d["latents"].astype(np.float32)  # reference latents, one column per neuron
print("loading PaliGemma … (first GPU call moves it to cuda)")
model = PaliGemmaForConditionalGeneration.from_pretrained(
    MODEL, torch_dtype=torch.bfloat16, token=HF_TOKEN
).eval()
proc = AutoProcessor.from_pretrained(MODEL, token=HF_TOKEN)
tok = proc.tokenizer


def vtower():
    """Return the model's vision tower."""
    return model.model.vision_tower


# Steering hook target: the layer whose output is hidden_states[L].
block = model.model.vision_tower.encoder.layers[L - 1]
# Mutable steering config: written by _caption / _ensure, read by _hook.
_state = {"on": False, "feat": 0, "strength": 0.0, "moved": False}
def _ensure(dev):
    """Move the model, the SAE and ``mu`` to CUDA the first time ``dev == "cuda"``.

    Does nothing for any other device, or once ``_state["moved"]`` is set.
    """
    global mu
    if _state["moved"] or dev != "cuda":
        return
    model.to("cuda")
    sae.to("cuda")
    mu = mu.to("cuda")
    _state["moved"] = True
def _hook(mod, inp, out):
    """Forward hook on ``block`` that clamps one SAE latent while steering is on.

    With ``_state["on"]`` False the layer output is returned untouched.
    Otherwise the hidden states are centred by ``mu``, scaled by ``SN`` and
    SAE-encoded. Every position is then shifted along decoder row
    ``sae.W_dec[feat]`` by ``strength - current activation of feat`` (divided by
    ``SN`` to undo the scaling). The result is cast back to the layer's dtype.
    Tuple outputs keep their trailing elements.
    """
    if not _state["on"]:
        return out
    is_tuple = isinstance(out, tuple)
    hidden = out[0] if is_tuple else out
    feat = _state["feat"]
    target = _state["strength"]
    hidden_f = hidden.float()
    # Rank-3 by construction of the unpacking; presumably (batch, patches, width).
    batch, n_pos, width = hidden_f.shape
    flat = ((hidden_f - mu) * SN).reshape(batch * n_pos, width)
    current = sae.encode(flat)[:, feat]
    delta = (target - current).unsqueeze(-1) * sae.W_dec[feat]
    steered = (hidden_f + (delta / SN).reshape(batch, n_pos, width)).to(hidden.dtype)
    if is_tuple:
        return (steered,) + tuple(out[1:])
    return steered


block.register_forward_hook(_hook)
@torch.no_grad()
def _latent(pil, dev):
    """Return the SAE latents of one image, averaged over positions, as NumPy.

    Runs only the vision tower and takes hidden state ``L`` of the single image
    in the batch. Each position is centred by ``mu`` and scaled by ``SN``, then
    encoded with the SAE. The latents are averaged over the first remaining
    axis (presumably patches).
    """
    pixels = proc.image_processor([pil], return_tensors="pt")["pixel_values"]
    pixels = pixels.to(dev, torch.bfloat16)
    tower_out = vtower()(pixel_values=pixels, output_hidden_states=True)
    per_patch = tower_out.hidden_states[L].float()[0]
    codes = sae.encode((per_patch - mu) * SN)
    return codes.mean(0).detach().cpu().numpy()
@torch.no_grad()
def _caption(pil, dev, feat=0, strength=0.0, steer=False):
    """Caption ``pil`` with PaliGemma, optionally with one SAE latent clamped.

    Args:
        pil: input PIL image.
        dev: device the processor outputs are moved to ("cuda" or "cpu").
        feat: SAE latent index the forward hook clamps when steering.
        strength: value the latent is clamped to.
        steer: enable the steering hook for this one generation.

    Returns:
        The generated continuation (prompt tokens removed), stripped of
        surrounding whitespace.
    """
    # Convert first so a bad value cannot leave the hook half-configured.
    _state["feat"] = int(feat)
    _state["strength"] = float(strength)
    _state["on"] = bool(steer)
    try:
        inp = proc(text="describe the object in the image", images=pil,
                   return_tensors="pt").to(dev, torch.bfloat16)
        n_prompt = inp["input_ids"].shape[1]
        generated = model.generate(**inp, max_new_tokens=28, do_sample=False)
    finally:
        # Always disarm the hook: a failed generate() must not leave steering
        # active for later baseline captions and latent reads.
        _state["on"] = False
    return tok.decode(generated[0][n_prompt:], skip_special_tokens=True).strip()
def scatter(coord, feat):
    """Plot the reference latent map coloured by one neuron, plus the user's image.

    Args:
        coord: map coordinates of the user's image (indexable with [0] and
            [1]), or None to draw only the reference points.
        feat: neuron index whose activation colours the reference points.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    feat = int(feat)
    colour = LAT[:, feat]
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=COORDS[:, 0], y=COORDS[:, 1], mode="markers",
        marker=dict(size=7, color=colour, colorscale="Viridis", showscale=True,
                    colorbar=dict(title="#%d" % feat), opacity=.55),
        hovertemplate="neuron #%d = %%{marker.color:.2f}<extra></extra>" % feat,
        # Label with the real number of plotted reference points, not a fixed "500".
        name="%d ref" % len(COORDS)))
    if coord is not None:
        fig.add_trace(go.Scatter(
            x=[float(coord[0])], y=[float(coord[1])], mode="markers",
            marker=dict(size=22, color="red", symbol="star",
                        line=dict(width=1.4, color="black")),
            name="your image", hovertemplate="your image<extra></extra>"))
    fig.update_layout(
        # The star glyph matches the "your image" marker symbol above.
        title="L10 SAE latent map — your image (★), colored by neuron #%d" % feat,
        xaxis_title="PC1", yaxis_title="PC2", height=420,
        margin=dict(l=8, r=8, t=44, b=8),
        legend=dict(orientation="h", y=1.04, x=0))
    return fig
def neuron_bar(lat, top=10):
    """Horizontal bar chart of the ``top`` most active neurons in ``lat``.

    Args:
        lat: 1-D array of per-neuron activations (as produced by ``_latent``).
        top: number of neurons to show (default 10).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    # Strongest first, then reversed: Plotly lists categorical y values
    # bottom-up, so this puts the strongest neuron at the top of the chart.
    idx = np.argsort(-lat)[:top][::-1]
    fig = go.Figure(go.Bar(
        x=lat[idx], y=["#%d" % i for i in idx], orientation="h",
        marker_color="#4f46e5",
        # The y labels already start with "#"; adding another gave "neuron ##12".
        hovertemplate="neuron %{y} = %{x:.2f}<extra></extra>"))
    fig.update_layout(title="this image's strongest L10 SAE neurons", height=300,
                      xaxis_title="activation", margin=dict(l=8, r=8, t=44, b=8),
                      yaxis=dict(type="category"))
    return fig
@GPU(duration=70)
def analyze(pil):
    """Caption an image and read its SAE representation for the UI.

    Returns a 5-tuple matching the Gradio outputs: the unsteered caption, the
    latent-map figure, the top-neuron bar chart, a radio update listing the
    8 strongest neurons (strongest preselected), and the image's map
    coordinates as a plain list. Without an image it returns a hint message
    and empty placeholders.
    """
    if pil is None:
        return "Upload or pick an image.", None, None, gr.update(choices=[], value=None), None
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    _ensure(dev)
    lat = _latent(pil, dev)
    caption = _caption(pil, dev, steer=False)
    ranked = np.argsort(-lat)[:8]
    lead = int(ranked[0])
    options = [("#%d (act %.1f)" % (int(n), lat[n]), int(n)) for n in ranked]
    coord = (lat - _d["pca_mean"]) @ _d["pca_comp"].T
    map_fig = scatter(coord, lead)
    bar_fig = neuron_bar(lat)
    radio = gr.update(choices=options, value=lead)
    return caption, map_fig, bar_fig, radio, coord.tolist()
@GPU(duration=70)
def steer(pil, neuron, strength):
    """Re-caption ``pil`` with SAE latent ``neuron`` clamped to ``strength``.

    Returns a hint string instead of a caption when no image or neuron is
    selected, or when ``strength`` is 0.
    """
    if pil is None or neuron is None:
        return "Run an image first."
    if float(strength) == 0:
        # Nothing to clamp: answer before touching devices.
        return "Move the slider to clamp this neuron."
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    _ensure(dev)
    return _caption(pil, dev, feat=int(neuron), strength=float(strength), steer=True)
def recolor(coord, neuron):
    """Redraw the latent map coloured by ``neuron``; None until both inputs exist."""
    if coord is not None and neuron is not None:
        return scatter(np.array(coord), int(neuron))
    return None
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), title="Visual Steerability") as demo:
    gr.Markdown("# 👁️ Steer PaliGemma's vision model from the inside\n"
                "Pass an image → see its **L10 SAE neuron activations** (the model's internal representation) → "
                "**clamp a neuron** and watch the caption change. "
                "Read the explainer: [The Steerability Spectrum](https://sumityadav.com.np/study/steerability-spectrum/).")
    st_coord = gr.State()  # map coordinates of the analysed image (list), consumed by recolor
    with gr.Row():
        with gr.Column(scale=4):
            img = gr.Image(type="pil", label="image", height=300)
            run = gr.Button("① Caption + read internal representation", variant="primary")
            base = gr.Textbox(label="caption", lines=2, interactive=False)
            neuron = gr.Radio(choices=[], label="② pick a neuron to steer", interactive=True)
            strength = gr.Slider(-40, 40, value=0, step=1, label="③ clamp strength — re-caption on release")
            out = gr.Textbox(label="steered caption", lines=2, interactive=False)
        with gr.Column(scale=5):
            plot = gr.Plot(label="SAE latent map")
            bar = gr.Plot(label="internal representation")
    # Optional examples: tolerate a missing folder and skip hidden files (.gitkeep, .DS_Store).
    example_dir = "examples"
    example_files = (sorted(f for f in os.listdir(example_dir) if not f.startswith("."))
                     if os.path.isdir(example_dir) else [])
    if example_files:
        gr.Examples(examples=[[os.path.join(example_dir, f)] for f in example_files],
                    inputs=[img], label="examples")
    # After analysing a new image, drop the previous image's steered caption.
    run.click(analyze, [img], [base, plot, bar, neuron, st_coord]).then(lambda: "", None, [out])
    neuron.change(recolor, [st_coord, neuron], [plot])
    strength.release(steer, [img, neuron, strength], [out])
demo.queue(max_size=12).launch()
|