rockerritesh's picture
Visual steerability dashboard (CPU PaliGemma + L10 SAE)
e5e6cca verified
Raw
History Blame Contribute Delete
7.97 kB
"""Visual Steerability — steer PaliGemma's vision model from the inside (ZeroGPU Space).
Pass an image -> see its L10 SAE neuron activations (the model's internal representation) ->
clamp a neuron's latent -> watch PaliGemma's caption change. Companion to the explainer at
https://sumityadav.com.np/study/steerability-spectrum/
"""
import os, json
import numpy as np
import torch, torch.nn as nn
import plotly.graph_objects as go
import gradio as gr
from PIL import Image as PILImage
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
try:
    import spaces  # ZeroGPU

    GPU = spaces.GPU
except Exception:  # local / non-ZeroGPU fallback (no-op decorator)

    def GPU(*args, **kwargs):
        """Stand-in for ``spaces.GPU`` that leaves the decorated function untouched.

        Works both bare (``@GPU``) and parameterised (``@GPU(duration=...)``);
        any keyword arguments are accepted and ignored.
        """
        used_bare = bool(args) and callable(args[0])
        if used_bare:
            # @GPU applied directly: the first positional argument is the function.
            return args[0]
        # @GPU(...) applied with arguments: hand back an identity decorator.
        return lambda fn: fn
# Model id handed to from_pretrained() for both the PaliGemma model and its processor.
MODEL = "google/paligemma-3b-mix-224"
# Path of the SAE checkpoint that is read with torch.load() at import time.
SAE_PT = "sae_l10_coco.pt"
# Path of the .npz archive read with np.load(); supplies the reference-map arrays.
DASH = "dash_data.npz"
# Access token taken from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
class TopKSAE(nn.Module):
    """Sparse autoencoder that keeps only the ``k`` largest pre-activations.

    Parameters are created zero-filled with shapes ``b_pre: (d,)``,
    ``W_enc: (d, m)``, ``b_enc: (m,)`` and ``W_dec: (m, d)``. ``W_dec`` is
    declared here but not read by :meth:`encode`.
    """

    def __init__(self, d, m, k):
        """Create zero-initialised parameters for input width ``d`` and ``m`` latents."""
        super().__init__()
        self.k = k
        self.b_pre = nn.Parameter(torch.zeros(d))
        self.W_enc = nn.Parameter(torch.zeros(d, m))
        self.b_enc = nn.Parameter(torch.zeros(m))
        self.W_dec = nn.Parameter(torch.zeros(m, d))

    def encode(self, x):
        """Return latents for ``x`` with everything outside the top ``k`` zeroed.

        The surviving entries are passed through ReLU, so a selected but
        negative pre-activation also ends up as zero.
        """
        centered = x - self.b_pre
        pre = centered @ self.W_enc + self.b_enc
        top_vals, top_idx = torch.topk(pre, self.k, dim=-1)
        # scatter_ writes in place and returns the same tensor.
        return torch.zeros_like(pre).scatter_(-1, top_idx, torch.relu(top_vals))
# ---- load (CPU at import; moved to GPU inside @GPU fns on ZeroGPU) ----
print("loading SAE + map …")
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects — only OK for a trusted checkpoint file.
ckpt = torch.load(SAE_PT, map_location="cpu", weights_only=False)
# d, m, k are forwarded to TopKSAE(d, m, k); l selects the vision-tower layer (layers[L - 1] / hidden_states[L]).
D, M, K, L = ckpt["d"], ckpt["m"], ckpt["k"], ckpt["l"]
# Scalar multiplied onto (hidden - mu) before activations are fed to the SAE.
SN = float(ckpt["scale_norm"])
sae = TopKSAE(D, M, K); sae.load_state_dict(ckpt["state_dict"]); sae.eval()
# Offset subtracted from hidden states before scaling by SN.
mu = torch.tensor(ckpt["mu"], dtype=torch.float32)
# NOTE(review): allow_pickle=True also permits unpickling — same trust assumption as the checkpoint.
_d = np.load(DASH, allow_pickle=True)
# COORDS[:, 0] / COORDS[:, 1] are the scatter x/y; LAT[:, feat] colours the reference points.
COORDS = _d["coords"]; LAT = _d["latents"].astype(np.float32)
print("loading PaliGemma … (first GPU call moves it to cuda)")
model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL, torch_dtype=torch.bfloat16, token=HF_TOKEN).eval()
proc = AutoProcessor.from_pretrained(MODEL, token=HF_TOKEN)
tok = proc.tokenizer
# Accessor for the vision tower; invoked as vtower() when computing latents.
vtower = lambda: model.model.vision_tower
# Encoder layer whose output the forward hook (_hook) rewrites when steering is on.
block = model.model.vision_tower.encoder.layers[L - 1]
# Shared mutable steering config: "on" enables the hook, "feat"/"strength" pick the latent and its
# target value, "moved" records that _ensure already moved model/SAE/mu to cuda.
_state = {"on": False, "feat": 0, "strength": 0.0, "moved": False}
def _ensure(dev):
    """Move the model, the SAE and ``mu`` to CUDA the first time ``dev`` is "cuda".

    Does nothing for any other device string or once ``_state["moved"]`` is set.
    """
    global mu
    if _state["moved"] or dev != "cuda":
        return
    model.to("cuda")
    sae.to("cuda")
    # mu is a plain tensor, so .to() returns a new tensor and the module global is rebound.
    mu = mu.to("cuda")
    _state["moved"] = True
def _hook(mod, inp, out):
    """Forward hook that clamps one SAE latent to ``_state["strength"]``.

    When ``_state["on"]`` is falsy the layer output is returned unchanged.
    Otherwise the output (or its first element when it is a tuple) is
    normalised with ``mu``/``SN``, encoded by the SAE, and shifted along the
    selected decoder row by the gap between the target strength and the
    current latent value. The result is cast back to the incoming dtype.
    """
    if not _state["on"]:
        return out

    is_tuple = isinstance(out, tuple)
    hidden = out[0] if is_tuple else out
    hidden_f32 = hidden.float()
    normed = (hidden_f32 - mu) * SN

    # The normalised tensor is unpacked as three dims and flattened to 2-D for the SAE.
    batch, patches, width = normed.shape
    feat = _state["feat"]
    latents = sae.encode(normed.reshape(batch * patches, width))

    # Per-row gap between the requested clamp value and the current latent.
    gap = (_state["strength"] - latents[:, feat]).unsqueeze(-1)
    delta = gap * sae.W_dec[feat]

    # delta lives in normalised space; dividing by SN maps it back before adding.
    steered = (hidden_f32 + (delta / SN).reshape(batch, patches, width)).to(hidden.dtype)
    if is_tuple:
        return (steered,) + tuple(out[1:])
    return steered
# Attach the steering hook to the selected encoder layer; _hook passes outputs through unless _state["on"] is set.
block.register_forward_hook(_hook)
@torch.no_grad()
def _latent(pil, dev):
    """Return the image's mean SAE activation vector as a NumPy array.

    The image is preprocessed, run through the vision tower with hidden
    states enabled, and ``hidden_states[L]`` of the first batch item is
    normalised with ``mu``/``SN`` and encoded. Activations are averaged over
    the leading axis of that item.
    """
    batch = proc.image_processor([pil], return_tensors="pt")
    pixels = batch["pixel_values"].to(dev, torch.bfloat16)
    tower_out = vtower()(pixel_values=pixels, output_hidden_states=True)
    hidden = tower_out.hidden_states[L].float()
    normed = (hidden[0] - mu) * SN
    acts = sae.encode(normed)
    return acts.mean(0).detach().cpu().numpy()
@torch.no_grad()
def _caption(pil, dev, feat=0, strength=0.0, steer=False):
    """Generate a caption for ``pil``, optionally with one SAE latent clamped.

    Args:
        pil: Image passed to the processor.
        dev: Device string the processor outputs are moved to.
        feat: Index of the SAE latent to clamp (stored as ``int``).
        strength: Target value for that latent (stored as ``float``).
        steer: When truthy, the forward hook is enabled for this call.

    Returns:
        The decoded text of the newly generated tokens only (the prompt
        tokens are sliced off), stripped of surrounding whitespace.

    The steering flag is always switched off again before returning or
    raising, so a failed generation cannot leave the hook active for later
    forward passes.
    """
    # Coerce first: a bad feat/strength must fail before the hook is armed.
    _state["feat"] = int(feat)
    _state["strength"] = float(strength)
    _state["on"] = steer
    try:
        inp = proc(text="describe the object in the image", images=pil, return_tensors="pt").to(dev, torch.bfloat16)
        n = inp["input_ids"].shape[1]
        # Sampling is disabled; at most 28 new tokens are produced.
        g = model.generate(**inp, max_new_tokens=28, do_sample=False)
    finally:
        _state["on"] = False
    return tok.decode(g[0][n:], skip_special_tokens=True).strip()
def scatter(coord, feat):
    """Build the latent-map figure coloured by SAE neuron ``feat``.

    The reference points come from ``COORDS`` (columns 0 and 1) and are
    coloured by ``LAT[:, feat]``. When ``coord`` is not None, its first two
    entries are drawn on top as a red star.
    """
    feat = int(feat)
    colors = LAT[:, feat]

    reference = go.Scatter(
        x=COORDS[:, 0],
        y=COORDS[:, 1],
        mode="markers",
        marker={
            "size": 7,
            "color": colors,
            "colorscale": "Viridis",
            "showscale": True,
            "colorbar": {"title": f"#{feat}"},
            "opacity": .55,
        },
        # Doubled braces keep Plotly's %{...} placeholder literal inside the f-string.
        hovertemplate=f"neuron #{feat} = %{{marker.color:.2f}}<extra></extra>",
        name="500 ref",
    )
    fig = go.Figure()
    fig.add_trace(reference)

    if coord is not None:
        star = go.Scatter(
            x=[float(coord[0])],
            y=[float(coord[1])],
            mode="markers",
            marker={
                "size": 22,
                "color": "red",
                "symbol": "star",
                "line": {"width": 1.4, "color": "black"},
            },
            name="your image",
            hovertemplate="your image<extra></extra>",
        )
        fig.add_trace(star)

    fig.update_layout(
        title=f"L10 SAE latent map — your image (★), colored by neuron #{feat}",
        xaxis_title="PC1",
        yaxis_title="PC2",
        height=420,
        margin={"l": 8, "r": 8, "t": 44, "b": 8},
        legend={"orientation": "h", "y": 1.04, "x": 0},
    )
    return fig
def neuron_bar(lat):
    """Return a horizontal bar chart of the ten largest entries of ``lat``.

    Args:
        lat: 1-D NumPy array of per-neuron activations.

    Returns:
        A Plotly figure whose bars are labelled ``#<index>``. The top-10
        indices are reversed before plotting, so they are passed in
        ascending order of activation.
    """
    idx = np.argsort(-lat)[:10][::-1]
    labels = ["#%d" % i for i in idx]
    fig = go.Figure(go.Bar(
        x=lat[idx],
        y=labels,
        orientation="h",
        marker_color="#4f46e5",
        # The y labels already start with '#'; adding another here rendered "neuron ##12".
        hovertemplate="neuron %{y} = %{x:.2f}<extra></extra>",
    ))
    fig.update_layout(
        title="this image's strongest L10 SAE neurons",
        height=300,
        xaxis_title="activation",
        margin=dict(l=8, r=8, t=44, b=8),
        # Category axis keeps the "#N" labels from being parsed as numbers.
        yaxis=dict(type="category"),
    )
    return fig
@GPU(duration=70)
def analyze(pil):
    """Caption an image and prepare every output of the first UI step.

    Returns a 5-tuple: the unsteered caption, the latent-map figure coloured
    by the strongest neuron, the top-neuron bar chart, a ``gr.update`` that
    fills the neuron picker with the eight strongest neurons, and the
    image's projected coordinates as a list. With no image, a prompt message
    and empty placeholders are returned instead.
    """
    if pil is None:
        empty_picker = gr.update(choices=[], value=None)
        return "Upload or pick an image.", None, None, empty_picker, None

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    _ensure(dev)

    lat = _latent(pil, dev)
    base = _caption(pil, dev, steer=False)

    top = np.argsort(-lat)[:8]
    strongest = int(top[0])
    choices = []
    for t in top:
        # (label, value) pairs for the picker.
        choices.append(("#%d (act %.1f)" % (int(t), lat[t]), int(t)))

    # Project the latent vector with the stored mean and components.
    coord = (lat - _d["pca_mean"]) @ _d["pca_comp"].T

    picker = gr.update(choices=choices, value=strongest)
    return base, scatter(coord, strongest), neuron_bar(lat), picker, coord.tolist()
@GPU(duration=70)
def steer(pil, neuron, strength):
    """Re-caption ``pil`` with ``neuron`` clamped to ``strength``.

    Returns a hint string instead of a caption when no image or neuron is
    selected, or when the strength is exactly zero.
    """
    if pil is None or neuron is None:
        return "Run an image first."

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    _ensure(dev)

    clamp = float(strength)
    if clamp == 0:
        return "Move the slider to clamp this neuron."
    return _caption(pil, dev, feat=int(neuron), strength=clamp, steer=True)
def recolor(coord, neuron):
    """Redraw the latent map coloured by ``neuron``; return None if either input is missing."""
    if coord is not None and neuron is not None:
        return scatter(np.array(coord), int(neuron))
    return None
# ---- UI: controls on the left, plots on the right, then event wiring ----
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), title="Visual Steerability") as demo:
    gr.Markdown("# 🎛️ Steer PaliGemma's vision model from the inside\n"
                "Pass an image → see its **L10 SAE neuron activations** (the model's internal representation) → "
                "**clamp a neuron** and watch the caption change. "
                "Read the explainer: [The Steerability Spectrum](https://sumityadav.com.np/study/steerability-spectrum/).")

    # Holds the current image's projected coordinates between events.
    st_coord = gr.State()

    with gr.Row():
        with gr.Column(scale=4):
            img = gr.Image(type="pil", label="image", height=300)
            run = gr.Button("① Caption + read internal representation", variant="primary")
            base = gr.Textbox(label="caption", lines=2, interactive=False)
            neuron = gr.Radio(choices=[], label="② pick a neuron to steer", interactive=True)
            strength = gr.Slider(-40, 40, value=0, step=1, label="③ clamp strength → re-caption on release")
            out = gr.Textbox(label="steered caption", lines=2, interactive=False)
        with gr.Column(scale=5):
            plot = gr.Plot(label="SAE latent map")
            bar = gr.Plot(label="internal representation")

    # One example row per file found in the local "examples" directory, in sorted order.
    example_rows = [["examples/" + f] for f in sorted(os.listdir("examples"))]
    gr.Examples(examples=example_rows, inputs=[img], label="examples")

    # Step 1: caption + latent read-out; step 2: recolour the map; step 3: steer on slider release.
    run.click(fn=analyze, inputs=[img], outputs=[base, plot, bar, neuron, st_coord])
    neuron.change(fn=recolor, inputs=[st_coord, neuron], outputs=[plot])
    strength.release(fn=steer, inputs=[img, neuron, strength], outputs=[out])

demo.queue(max_size=12).launch()