"""Visual Steerability — steer PaliGemma's vision model from the inside (ZeroGPU Space). Pass an image -> see its L10 SAE neuron activations (the model's internal representation) -> clamp a neuron's latent -> watch PaliGemma's caption change. Companion to the explainer at https://sumityadav.com.np/study/steerability-spectrum/ """ import os, json import numpy as np import torch, torch.nn as nn import plotly.graph_objects as go import gradio as gr from PIL import Image as PILImage from transformers import AutoProcessor, PaliGemmaForConditionalGeneration try: import spaces # ZeroGPU GPU = spaces.GPU except Exception: # local / non-ZeroGPU fallback (no-op decorator) def GPU(*a, **k): if a and callable(a[0]): # used bare: @GPU return a[0] def deco(f): # used with args: @GPU(duration=...) return f return deco MODEL = "google/paligemma-3b-mix-224" SAE_PT = "sae_l10_coco.pt" DASH = "dash_data.npz" HF_TOKEN = os.environ.get("HF_TOKEN") class TopKSAE(nn.Module): def __init__(self, d, m, k): super().__init__(); self.k = k self.b_pre = nn.Parameter(torch.zeros(d)); self.W_enc = nn.Parameter(torch.zeros(d, m)) self.b_enc = nn.Parameter(torch.zeros(m)); self.W_dec = nn.Parameter(torch.zeros(m, d)) def encode(self, x): pre = (x - self.b_pre) @ self.W_enc + self.b_enc v, i = pre.topk(self.k, dim=-1) a = torch.zeros_like(pre); a.scatter_(-1, i, torch.relu(v)); return a # ---- load (CPU at import; moved to GPU inside @GPU fns on ZeroGPU) ---- print("loading SAE + map …") ckpt = torch.load(SAE_PT, map_location="cpu", weights_only=False) D, M, K, L = ckpt["d"], ckpt["m"], ckpt["k"], ckpt["l"] SN = float(ckpt["scale_norm"]) sae = TopKSAE(D, M, K); sae.load_state_dict(ckpt["state_dict"]); sae.eval() mu = torch.tensor(ckpt["mu"], dtype=torch.float32) _d = np.load(DASH, allow_pickle=True) COORDS = _d["coords"]; LAT = _d["latents"].astype(np.float32) print("loading PaliGemma … (first GPU call moves it to cuda)") model = PaliGemmaForConditionalGeneration.from_pretrained(MODEL, torch_dtype=torch.bfloat16, token=HF_TOKEN).eval() proc = AutoProcessor.from_pretrained(MODEL, token=HF_TOKEN) tok = proc.tokenizer vtower = lambda: model.model.vision_tower block = model.model.vision_tower.encoder.layers[L - 1] _state = {"on": False, "feat": 0, "strength": 0.0, "moved": False} def _ensure(dev): if not _state["moved"] and dev == "cuda": model.to("cuda"); sae.to("cuda") globals()["mu"] = mu.to("cuda"); _state["moved"] = True def _hook(mod, inp, out): if not _state["on"]: return out h = out[0] if isinstance(out, tuple) else out hf = h.float(); hn = (hf - mu) * SN b, P, d = hn.shape acts = sae.encode(hn.reshape(b * P, d)) cur = acts[:, _state["feat"]] dn = (_state["strength"] - cur).unsqueeze(-1) * sae.W_dec[_state["feat"]] h2 = (hf + (dn / SN).reshape(b, P, d)).to(h.dtype) return (h2,) + tuple(out[1:]) if isinstance(out, tuple) else h2 block.register_forward_hook(_hook) @torch.no_grad() def _latent(pil, dev): pv = proc.image_processor([pil], return_tensors="pt")["pixel_values"].to(dev, torch.bfloat16) hs = vtower()(pixel_values=pv, output_hidden_states=True).hidden_states[L].float() acts = sae.encode(((hs[0] - mu) * SN)) lat = acts.mean(0).detach().cpu().numpy() return lat @torch.no_grad() def _caption(pil, dev, feat=0, strength=0.0, steer=False): _state["on"] = steer; _state["feat"] = int(feat); _state["strength"] = float(strength) inp = proc(text="describe the object in the image", images=pil, return_tensors="pt").to(dev, torch.bfloat16) n = inp["input_ids"].shape[1] g = model.generate(**inp, max_new_tokens=28, do_sample=False) _state["on"] = False return tok.decode(g[0][n:], skip_special_tokens=True).strip() def scatter(coord, feat): feat = int(feat); c = LAT[:, feat] fig = go.Figure() fig.add_trace(go.Scatter(x=COORDS[:, 0], y=COORDS[:, 1], mode="markers", marker=dict(size=7, color=c, colorscale="Viridis", showscale=True, colorbar=dict(title="#%d" % feat), opacity=.55), hovertemplate="neuron #%d = %%{marker.color:.2f}" % feat, name="500 ref")) if coord is not None: fig.add_trace(go.Scatter(x=[float(coord[0])], y=[float(coord[1])], mode="markers", marker=dict(size=22, color="red", symbol="star", line=dict(width=1.4, color="black")), name="your image", hovertemplate="your image")) fig.update_layout(title="L10 SAE latent map — your image (★), colored by neuron #%d" % feat, xaxis_title="PC1", yaxis_title="PC2", height=420, margin=dict(l=8, r=8, t=44, b=8), legend=dict(orientation="h", y=1.04, x=0)) return fig def neuron_bar(lat): idx = np.argsort(-lat)[:10][::-1] fig = go.Figure(go.Bar(x=lat[idx], y=["#%d" % i for i in idx], orientation="h", marker_color="#4f46e5", hovertemplate="neuron #%{y} = %{x:.2f}")) fig.update_layout(title="this image's strongest L10 SAE neurons", height=300, xaxis_title="activation", margin=dict(l=8, r=8, t=44, b=8), yaxis=dict(type="category")) return fig @GPU(duration=70) def analyze(pil): if pil is None: return "Upload or pick an image.", None, None, gr.update(choices=[], value=None), None dev = "cuda" if torch.cuda.is_available() else "cpu"; _ensure(dev) lat = _latent(pil, dev) base = _caption(pil, dev, steer=False) top = np.argsort(-lat)[:8] choices = [("#%d (act %.1f)" % (int(t), lat[t]), int(t)) for t in top] coord = (lat - _d["pca_mean"]) @ _d["pca_comp"].T return base, scatter(coord, int(top[0])), neuron_bar(lat), gr.update(choices=choices, value=int(top[0])), coord.tolist() @GPU(duration=70) def steer(pil, neuron, strength): if pil is None or neuron is None: return "Run an image first." dev = "cuda" if torch.cuda.is_available() else "cpu"; _ensure(dev) if float(strength) == 0: return "Move the slider to clamp this neuron." return _caption(pil, dev, feat=int(neuron), strength=float(strength), steer=True) def recolor(coord, neuron): if coord is None or neuron is None: return None return scatter(np.array(coord), int(neuron)) with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), title="Visual Steerability") as demo: gr.Markdown("# 🎛️ Steer PaliGemma's vision model from the inside\n" "Pass an image → see its **L10 SAE neuron activations** (the model's internal representation) → " "**clamp a neuron** and watch the caption change. " "Read the explainer: [The Steerability Spectrum](https://sumityadav.com.np/study/steerability-spectrum/).") st_coord = gr.State() with gr.Row(): with gr.Column(scale=4): img = gr.Image(type="pil", label="image", height=300) run = gr.Button("① Caption + read internal representation", variant="primary") base = gr.Textbox(label="caption", lines=2, interactive=False) neuron = gr.Radio(choices=[], label="② pick a neuron to steer", interactive=True) strength = gr.Slider(-40, 40, value=0, step=1, label="③ clamp strength → re-caption on release") out = gr.Textbox(label="steered caption", lines=2, interactive=False) with gr.Column(scale=5): plot = gr.Plot(label="SAE latent map") bar = gr.Plot(label="internal representation") gr.Examples(examples=[["examples/" + f] for f in sorted(os.listdir("examples"))], inputs=[img], label="examples") run.click(analyze, [img], [base, plot, bar, neuron, st_coord]) neuron.change(recolor, [st_coord, neuron], [plot]) strength.release(steer, [img, neuron, strength], [out]) demo.queue(max_size=12).launch()