| """Visual Steerability — steer PaliGemma's vision model from the inside (ZeroGPU Space). |
| |
| Pass an image -> see its L10 SAE neuron activations (the model's internal representation) -> |
| clamp a neuron's latent -> watch PaliGemma's caption change. Companion to the explainer at |
| https://sumityadav.com.np/study/steerability-spectrum/ |
| """ |
| import os, json |
| import numpy as np |
| import torch, torch.nn as nn |
| import plotly.graph_objects as go |
| import gradio as gr |
| from PIL import Image as PILImage |
| from transformers import AutoProcessor, PaliGemmaForConditionalGeneration |
|
|
# Use the ZeroGPU decorator when running inside a Hugging Face Space; anywhere the
# `spaces` package (or its GPU attribute) is unavailable, fall back to a pass-through
# decorator so the app still runs locally.
try:
    import spaces
    GPU = spaces.GPU
except Exception:
    def GPU(*args, **kwargs):
        """No-op stand-in for ``spaces.GPU``.

        Supports both bare usage (``@GPU``) and parameterised usage
        (``@GPU(duration=70)``); in every case the decorated function is returned unchanged.
        """
        def _identity(fn):
            return fn

        # Bare form: the function itself arrives as the single positional argument.
        if args and callable(args[0]):
            return _identity(args[0])
        # Parameterised form: hand back the decorator to be applied next.
        return _identity
|
|
# Hub id of the captioning model whose vision tower is steered.
MODEL = "google/paligemma-3b-mix-224"
# Local artifacts loaded at startup: the SAE checkpoint and the precomputed
# reference-image dashboard data (coords / latents / PCA basis).
SAE_PT = "sae_l10_coco.pt"
DASH = "dash_data.npz"
# Optional Hub token taken from the environment (None when unset); forwarded to from_pretrained.
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
|
class TopKSAE(nn.Module):
    """Top-k sparse autoencoder over d-dimensional activations with m latents.

    Only the encoder is implemented as a method; ``W_dec`` (shape ``(m, d)``) is kept so
    callers can read per-latent decoder directions (``W_dec[j]``). All parameters start at
    zero and are expected to be overwritten via ``load_state_dict``.
    """

    def __init__(self, d, m, k):
        super().__init__()
        self.k = k  # number of latents kept active per input row
        # Registration order (and therefore state_dict key order) matches the checkpoint layout.
        self.b_pre = nn.Parameter(torch.zeros(d))
        self.W_enc = nn.Parameter(torch.zeros(d, m))
        self.b_enc = nn.Parameter(torch.zeros(m))
        self.W_dec = nn.Parameter(torch.zeros(m, d))

    def encode(self, x):
        """Return sparse codes for ``x`` (``(..., d)`` -> ``(..., m)``).

        Per row, the k largest pre-activations are kept (passed through ReLU) and every
        other latent is zero.
        """
        logits = (x - self.b_pre) @ self.W_enc + self.b_enc
        top_vals, top_idx = torch.topk(logits, self.k, dim=-1)
        codes = torch.zeros_like(logits)
        codes.scatter_(-1, top_idx, top_vals.relu())
        return codes
|
|
|
|
| |
print("loading SAE + map …")
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects. That is
# acceptable only because SAE_PT is a file shipped with the Space, never user input.
ckpt = torch.load(SAE_PT, map_location="cpu", weights_only=False)
# SAE hyper-parameters: input width, latent count, top-k, and the hooked layer index.
D = ckpt["d"]
M = ckpt["m"]
K = ckpt["k"]
L = ckpt["l"]
# Scale applied after mean-centering activations before they enter the SAE.
SN = float(ckpt["scale_norm"])
sae = TopKSAE(D, M, K)
sae.load_state_dict(ckpt["state_dict"])
sae.eval()
mu = torch.tensor(ckpt["mu"], dtype=torch.float32)
# Precomputed dashboard data for the reference images (map coordinates + SAE latents).
_d = np.load(DASH, allow_pickle=True)
COORDS = _d["coords"]
LAT = _d["latents"].astype(np.float32)
|
|
print("loading PaliGemma … (first GPU call moves it to cuda)")
model = PaliGemmaForConditionalGeneration.from_pretrained(
    MODEL, torch_dtype=torch.bfloat16, token=HF_TOKEN
).eval()
proc = AutoProcessor.from_pretrained(MODEL, token=HF_TOKEN)
tok = proc.tokenizer


def vtower():
    """Return PaliGemma's vision tower module."""
    return model.model.vision_tower


# Encoder layer whose output the steering hook rewrites.
# NOTE(review): with transformers' convention that hidden_states[0] is the embedding output,
# the output of layers[L - 1] should equal hidden_states[L], which is what _latent reads.
block = vtower().encoder.layers[L - 1]
# Mutable steering state shared by the hook and helpers:
#   on       - whether the forward hook should modify activations
#   feat     - SAE latent index to clamp
#   strength - target value for that latent
#   moved    - whether model/SAE/mu were already moved to CUDA
_state = {"on": False, "feat": 0, "strength": 0.0, "moved": False}
|
|
|
|
def _ensure(dev):
    """Move the model, SAE and the centering vector ``mu`` to CUDA, once.

    Does nothing unless ``dev == "cuda"``; after the first move ``_state["moved"]`` is set so
    later calls are no-ops.
    """
    global mu
    if dev != "cuda" or _state["moved"]:
        return
    model.to("cuda")
    sae.to("cuda")
    # Rebind the module-level tensor so _hook/_latent pick up the CUDA copy.
    mu = mu.to("cuda")
    _state["moved"] = True
|
|
|
|
def _hook(mod, inp, out):
    """Forward hook on the vision-encoder layer that clamps one SAE latent.

    When ``_state["on"]`` is falsy the output passes through untouched. Otherwise the layer
    output is centered/scaled into SAE space, encoded, and every token is shifted by
    ``(strength - current_activation) * W_dec[feat]`` (mapped back by ``1 / SN``). Only this
    delta is added, so the SAE reconstruction error stays in the activations. The result is
    cast back to the original dtype and returned in the same container shape (tuple or tensor).
    """
    if not _state["on"]:
        return out
    is_tuple = isinstance(out, tuple)
    hidden = out[0] if is_tuple else out
    feat = _state["feat"]
    h32 = hidden.float()
    normed = (h32 - mu) * SN
    batch, n_tok, width = normed.shape
    codes = sae.encode(normed.reshape(batch * n_tok, width))
    delta = (_state["strength"] - codes[:, feat]).unsqueeze(-1) * sae.W_dec[feat]
    steered = (h32 + (delta / SN).reshape(batch, n_tok, width)).to(hidden.dtype)
    if is_tuple:
        return (steered,) + tuple(out[1:])
    return steered


block.register_forward_hook(_hook)
|
|
|
|
@torch.no_grad()
def _latent(pil, dev):
    """Return the token-averaged SAE code of one image as a NumPy array.

    Runs only the vision tower, takes hidden state ``L`` of the first (only) batch item,
    centers/scales it like the hook does, encodes it with the SAE and averages over axis 0.
    NOTE(review): assumes hidden_states[L] is (batch, tokens, d), so axis 0 after [0] is tokens.
    """
    pixels = proc.image_processor([pil], return_tensors="pt")["pixel_values"]
    pixels = pixels.to(dev, torch.bfloat16)
    tower_out = vtower()(pixel_values=pixels, output_hidden_states=True)
    tokens = tower_out.hidden_states[L].float()[0]
    codes = sae.encode((tokens - mu) * SN)
    return codes.mean(dim=0).detach().cpu().numpy()
|
|
|
|
@torch.no_grad()
def _caption(pil, dev, feat=0, strength=0.0, steer=False):
    """Greedy-decode a caption for ``pil``, optionally with SAE steering enabled.

    Args:
        pil: input image, passed to the PaliGemma processor.
        dev: device string the processed inputs are moved to ("cuda" or "cpu").
        feat: SAE latent index the forward hook clamps when steering.
        strength: target value for that latent when steering.
        steer: if truthy, arms the forward hook for this generation only.

    Returns:
        The decoded continuation (at most 28 new tokens, greedy), with the prompt tokens
        sliced off and surrounding whitespace stripped.
    """
    # Convert arguments before arming the hook so a bad value cannot leave steering on.
    _state["feat"] = int(feat)
    _state["strength"] = float(strength)
    try:
        _state["on"] = steer
        inp = proc(text="describe the object in the image", images=pil,
                   return_tensors="pt").to(dev, torch.bfloat16)
        n = inp["input_ids"].shape[1]  # prompt length; generated tokens start after it
        g = model.generate(**inp, max_new_tokens=28, do_sample=False)
    finally:
        # Always disarm: if generation raised, a still-armed hook would silently steer the
        # next _latent() / unsteered call as well.
        _state["on"] = False
    return tok.decode(g[0][n:], skip_special_tokens=True).strip()
|
|
|
|
def scatter(coord, feat):
    """Plot the reference-image latent map, colored by one SAE neuron.

    Args:
        coord: position of the user's image on the same map (indexable with [0] and [1]),
            or None to draw only the reference points.
        feat: neuron index used to color the reference points (coerced with int()).

    Returns:
        A plotly Figure.
    """
    feat = int(feat)
    ref_colors = LAT[:, feat]
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=COORDS[:, 0],
            y=COORDS[:, 1],
            mode="markers",
            marker=dict(
                size=7,
                color=ref_colors,
                colorscale="Viridis",
                showscale=True,
                colorbar=dict(title=f"#{feat}"),
                opacity=0.55,
            ),
            # "%{...}" is plotly's hover placeholder; braces are doubled to escape the f-string.
            hovertemplate=f"neuron #{feat} = %{{marker.color:.2f}}<extra></extra>",
            name="500 ref",
        )
    )
    if coord is not None:
        star_x, star_y = float(coord[0]), float(coord[1])
        fig.add_trace(
            go.Scatter(
                x=[star_x],
                y=[star_y],
                mode="markers",
                marker=dict(size=22, color="red", symbol="star",
                            line=dict(width=1.4, color="black")),
                name="your image",
                hovertemplate="your image<extra></extra>",
            )
        )
    fig.update_layout(
        title=f"L10 SAE latent map — your image (★), colored by neuron #{feat}",
        xaxis_title="PC1",
        yaxis_title="PC2",
        height=420,
        margin=dict(l=8, r=8, t=44, b=8),
        legend=dict(orientation="h", y=1.04, x=0),
    )
    return fig
|
|
|
|
def neuron_bar(lat):
    """Horizontal bar chart of the 10 most strongly activated SAE neurons.

    Args:
        lat: 1-D array of per-neuron activations (as produced by _latent).

    Returns:
        A plotly Figure with one bar per neuron, labelled "#<index>".
    """
    # argsort(-lat) is descending. The top-10 is reversed because plotly draws the first
    # category at the bottom of a horizontal bar chart, so the strongest ends up on top.
    idx = np.argsort(-lat)[:10][::-1]
    fig = go.Figure(go.Bar(
        x=lat[idx],
        y=["#%d" % i for i in idx],
        orientation="h",
        marker_color="#4f46e5",
        # The y labels already start with "#"; the old "neuron #%{y}" rendered "neuron ##123".
        hovertemplate="neuron %{y} = %{x:.2f}<extra></extra>",
    ))
    fig.update_layout(title="this image's strongest L10 SAE neurons", height=300,
                      xaxis_title="activation", margin=dict(l=8, r=8, t=44, b=8),
                      yaxis=dict(type="category"))
    return fig
|
|
|
|
@GPU(duration=70)
def analyze(pil):
    """Caption an image and read its L10 SAE representation.

    Runs under the ZeroGPU decorator (duration=70 is the requested time budget).

    Returns a 5-tuple matching the UI outputs:
        (unsteered caption, latent-map figure, top-neuron bar figure,
         radio update with the 8 strongest neurons, the image's map coordinate as a list).
    When no image is given, returns a prompt message and empty values instead.
    """
    if pil is None:
        return "Upload or pick an image.", None, None, gr.update(choices=[], value=None), None
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    _ensure(dev)
    lat = _latent(pil, dev)
    base = _caption(pil, dev, steer=False)
    top = np.argsort(-lat)[:8]
    best = int(top[0])
    choices = [("#%d (act %.1f)" % (int(t), lat[t]), int(t)) for t in top]
    # Project the image's latent onto the precomputed PCA basis of the reference map.
    coord = (lat - _d["pca_mean"]) @ _d["pca_comp"].T
    map_fig = scatter(coord, best)
    bar_fig = neuron_bar(lat)
    radio = gr.update(choices=choices, value=best)
    return base, map_fig, bar_fig, radio, coord.tolist()
|
|
|
|
@GPU(duration=70)
def steer(pil, neuron, strength):
    """Re-caption ``pil`` with SAE latent ``neuron`` clamped to ``strength``.

    Returns a message instead of a caption when no image/neuron is selected, or when the
    strength is exactly 0.
    """
    if pil is None or neuron is None:
        return "Run an image first."
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    _ensure(dev)
    target = float(strength)
    if target == 0:
        return "Move the slider to clamp this neuron."
    return _caption(pil, dev, feat=int(neuron), strength=target, steer=True)
|
|
|
|
def recolor(coord, neuron):
    """Redraw the latent map colored by ``neuron``, keeping the stored image coordinate.

    Returns None (leaving the plot empty) until both a coordinate and a neuron are available.
    """
    if coord is not None and neuron is not None:
        return scatter(np.array(coord), int(neuron))
    return None
|
|
|
|
# Example images shown under the input. Only image files are listed, because stray files
# (e.g. .DS_Store) in the folder could break gr.Examples. A missing folder yields no
# examples instead of crashing at import (os.listdir would raise FileNotFoundError).
_EXAMPLE_DIR = "examples"
_IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif")
_example_files = (
    sorted(name for name in os.listdir(_EXAMPLE_DIR) if name.lower().endswith(_IMAGE_EXTS))
    if os.path.isdir(_EXAMPLE_DIR)
    else []
)

with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), title="Visual Steerability") as demo:
    gr.Markdown("# 🎛️ Steer PaliGemma's vision model from the inside\n"
                "Pass an image → see its **L10 SAE neuron activations** (the model's internal representation) → "
                "**clamp a neuron** and watch the caption change. "
                "Read the explainer: [The Steerability Spectrum](https://sumityadav.com.np/study/steerability-spectrum/).")
    # Holds the analyzed image's map coordinate so the plot can be recolored without re-running the model.
    st_coord = gr.State()
    with gr.Row():
        with gr.Column(scale=4):
            img = gr.Image(type="pil", label="image", height=300)
            run = gr.Button("① Caption + read internal representation", variant="primary")
            base = gr.Textbox(label="caption", lines=2, interactive=False)
            neuron = gr.Radio(choices=[], label="② pick a neuron to steer", interactive=True)
            strength = gr.Slider(-40, 40, value=0, step=1, label="③ clamp strength → re-caption on release")
            out = gr.Textbox(label="steered caption", lines=2, interactive=False)
        with gr.Column(scale=5):
            plot = gr.Plot(label="SAE latent map")
            bar = gr.Plot(label="internal representation")
    if _example_files:
        gr.Examples(examples=[[f"{_EXAMPLE_DIR}/{name}"] for name in _example_files],
                    inputs=[img], label="examples")
    run.click(analyze, [img], [base, plot, bar, neuron, st_coord])
    # Changing the selected neuron only recolors the map; no model call.
    neuron.change(recolor, [st_coord, neuron], [plot])
    # Steering re-captions only when the slider is released, not on every drag step.
    strength.release(steer, [img, neuron, strength], [out])
|
|
# Enable request queuing (at most 12 pending jobs), then start the server.
demo.queue(max_size=12)
demo.launch()
|
|