File size: 7,973 Bytes
e5e6cca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""Visual Steerability — steer PaliGemma's vision model from the inside (ZeroGPU Space).

Pass an image -> see its L10 SAE neuron activations (the model's internal representation) ->
clamp a neuron's latent -> watch PaliGemma's caption change. Companion to the explainer at
https://sumityadav.com.np/study/steerability-spectrum/
"""
import os, json
import numpy as np
import torch, torch.nn as nn
import plotly.graph_objects as go
import gradio as gr
from PIL import Image as PILImage
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

try:
    import spaces  # ZeroGPU
    GPU = spaces.GPU
except Exception:
    # Local / non-ZeroGPU fallback: a do-nothing stand-in for ``spaces.GPU`` that
    # accepts both the bare ``@GPU`` form and the parameterised ``@GPU(duration=...)`` form.
    def GPU(*args, **kwargs):
        """Return the decorated function unchanged (no GPU scheduling happens locally)."""
        used_bare = bool(args) and callable(args[0])
        if used_bare:
            return args[0]
        return lambda fn: fn

# Hub model id and local artifact paths used throughout this app.
MODEL = "google/paligemma-3b-mix-224"   # PaliGemma checkpoint that is captioned / steered
SAE_PT = "sae_l10_coco.pt"              # SAE checkpoint, loaded with torch.load below
DASH = "dash_data.npz"                  # reference coords/latents and PCA basis for the map
HF_TOKEN = os.getenv("HF_TOKEN")        # passed as ``token=`` to from_pretrained; None if unset


class TopKSAE(nn.Module):
    """Top-k sparse autoencoder whose weights come from a saved state dict.

    Only the encoder path is implemented here. ``W_dec`` is still registered as a
    parameter so the checkpoint's state dict loads cleanly, and callers can read
    individual decoder rows from it.

    Args:
        d: input width.
        m: number of latents.
        k: number of latents kept (per row) by :meth:`encode`.
    """

    def __init__(self, d, m, k):
        super().__init__()
        self.k = k
        # Zero-initialised placeholders; real values arrive via load_state_dict.
        self.b_pre = nn.Parameter(torch.zeros(d))
        self.W_enc = nn.Parameter(torch.zeros(d, m))
        self.b_enc = nn.Parameter(torch.zeros(m))
        self.W_dec = nn.Parameter(torch.zeros(m, d))

    def encode(self, x):
        """Return sparse latents for ``x`` (last dim = d).

        Computes ``(x - b_pre) @ W_enc + b_enc``, keeps the ``k`` largest
        pre-activations along the last dim, ReLUs them, and zeros everything else.
        """
        centered = x - self.b_pre
        pre_acts = centered @ self.W_enc + self.b_enc
        top_vals, top_idx = torch.topk(pre_acts, self.k, dim=-1)
        latents = torch.zeros_like(pre_acts)
        latents.scatter_(-1, top_idx, torch.relu(top_vals))
        return latents


# ---- load (CPU at import; moved to GPU inside @GPU fns on ZeroGPU) ----
print("loading SAE + map …")
# NOTE(security): weights_only=False unpickles arbitrary Python objects. That is only
# acceptable because SAE_PT is a fixed local path, never user-supplied input.
ckpt = torch.load(SAE_PT, map_location="cpu", weights_only=False)
D, M, K, L = ckpt["d"], ckpt["m"], ckpt["k"], ckpt["l"]
SN = float(ckpt["scale_norm"])  # scale applied after mean-centering, before SAE encode
sae = TopKSAE(D, M, K)
sae.load_state_dict(ckpt["state_dict"])
sae.eval()
mu = torch.tensor(ckpt["mu"], dtype=torch.float32)  # mean subtracted before SAE encode
# Read every array out of the .npz up front and close the archive. An NpzFile keeps
# the zip open and re-reads/decompresses an entry on each ``_d[key]`` access, and
# analyze() indexes ``_d`` on every request. A plain dict keeps the same ``_d[key]`` API.
with np.load(DASH, allow_pickle=True) as _npz:
    _d = {key: _npz[key] for key in _npz.files}
COORDS = _d["coords"]  # 2-D positions of the reference images on the map
LAT = _d["latents"].astype(np.float32)  # reference latents; column j colours neuron j

print("loading PaliGemma … (first GPU call moves it to cuda)")
model = PaliGemmaForConditionalGeneration.from_pretrained(
    MODEL, torch_dtype=torch.bfloat16, token=HF_TOKEN
).eval()
proc = AutoProcessor.from_pretrained(MODEL, token=HF_TOKEN)
tok = proc.tokenizer


def vtower():
    """Return PaliGemma's vision tower (looked up on every call)."""
    return model.model.vision_tower


# Encoder layer whose output the steering hook edits. _latent() reads
# hidden_states[L]; assuming HF's convention that hidden_states[0] is the embedding
# output, layers[L - 1] is the layer that produces that same activation.
block = vtower().encoder.layers[L - 1]
# Shared mutable config: _caption() writes on/feat/strength and _hook() reads them;
# "moved" records whether _ensure() has already moved the weights to cuda.
_state = {"on": False, "feat": 0, "strength": 0.0, "moved": False}


def _ensure(dev):
    """Move the model, the SAE and ``mu`` to CUDA, once, the first time ``dev`` is "cuda".

    Does nothing for any other device, or once ``_state["moved"]`` is set.
    NOTE(review): on ZeroGPU the flag assumes the weights stay on cuda between
    @GPU calls — confirm against the ``spaces`` runtime behaviour.
    """
    global mu
    if _state["moved"] or dev != "cuda":
        return
    model.to("cuda")
    sae.to("cuda")
    mu = mu.to("cuda")
    _state["moved"] = True


def _hook(mod, inp, out):
    """Forward hook on ``block`` that clamps one SAE latent on every position.

    Returns ``out`` unchanged while ``_state["on"]`` is false. Otherwise it
    normalises the layer output as ``(h - mu) * SN``, SAE-encodes it, and shifts
    each position along decoder row ``W_dec[feat]`` by ``(strength - current
    activation)``, scaled back by ``1 / SN``. The result keeps the original dtype
    and the original tuple-or-tensor shape of ``out``.
    """
    if not _state["on"]:
        return out
    packed = isinstance(out, tuple)
    hidden = out[0] if packed else out
    feat = _state["feat"]
    x = hidden.float()
    normed = (x - mu) * SN
    # The 3-way unpack requires rank 3 — presumably (batch, patches, width).
    batch, patches, width = normed.shape
    current = sae.encode(normed.reshape(batch * patches, width))[:, feat]
    shift = (_state["strength"] - current).unsqueeze(-1) * sae.W_dec[feat]
    steered = (x + (shift / SN).reshape(batch, patches, width)).to(hidden.dtype)
    return (steered,) + tuple(out[1:]) if packed else steered


block.register_forward_hook(_hook)


@torch.no_grad()
def _latent(pil, dev):
    """Return the patch-averaged SAE latent vector of one image as a CPU numpy array.

    Runs only the vision tower, takes ``hidden_states[L]`` for the first (only)
    image in the batch, normalises it as ``(h - mu) * SN``, SAE-encodes each
    patch, and averages the latents over patches.
    """
    pixels = proc.image_processor([pil], return_tensors="pt")["pixel_values"]
    pixels = pixels.to(dev, torch.bfloat16)
    tower_out = vtower()(pixel_values=pixels, output_hidden_states=True)
    layer_h = tower_out.hidden_states[L].float()[0]
    patch_latents = sae.encode((layer_h - mu) * SN)
    return patch_latents.mean(0).detach().cpu().numpy()


@torch.no_grad()
def _caption(pil, dev, feat=0, strength=0.0, steer=False):
    """Greedy-decode a caption for ``pil``, optionally with one SAE latent clamped.

    Args:
        pil: image handed to the processor together with a fixed prompt.
        dev: device the processed inputs are moved to (via ``.to(dev, torch.bfloat16)``).
        feat: SAE latent index the forward hook clamps when ``steer`` is true.
        strength: value the hook clamps ``feat`` to.
        steer: whether the forward hook on ``block`` is active during generation.

    Returns:
        Only the newly generated text (prompt tokens stripped), whitespace-trimmed.
    """
    _state["on"] = steer
    _state["feat"] = int(feat)
    _state["strength"] = float(strength)
    try:
        inp = proc(text="describe the object in the image", images=pil,
                   return_tensors="pt").to(dev, torch.bfloat16)
        prompt_len = inp["input_ids"].shape[1]
        g = model.generate(**inp, max_new_tokens=28, do_sample=False)
    finally:
        # Always disarm the hook, even if processing/generation raises. Otherwise
        # every later forward pass (_latent, unsteered captions) would stay steered.
        _state["on"] = False
    return tok.decode(g[0][prompt_len:], skip_special_tokens=True).strip()


def scatter(coord, feat):
    """Plot the reference latent map coloured by one SAE neuron, with the user's image on top.

    Args:
        coord: position of the user's image on the PC1/PC2 map (anything
            indexable; only entries 0 and 1 are read), or None to draw the
            reference points alone.
        feat: neuron index; ``LAT[:, feat]`` colours the reference points.

    Returns:
        A plotly ``go.Figure``.
    """
    feat = int(feat)
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=COORDS[:, 0], y=COORDS[:, 1], mode="markers",
        marker=dict(size=7, color=LAT[:, feat], colorscale="Viridis", showscale=True,
                    colorbar=dict(title="#%d" % feat), opacity=.55),
        hovertemplate="neuron #%d = %%{marker.color:.2f}<extra></extra>" % feat,
        name="500 ref"))
    if coord is not None:
        fig.add_trace(go.Scatter(
            x=[float(coord[0])], y=[float(coord[1])], mode="markers",
            marker=dict(size=22, color="red", symbol="star", line=dict(width=1.4, color="black")),
            name="your image", hovertemplate="your image<extra></extra>"))
    fig.update_layout(
        # Title uses a real em dash and star; the literal previously held mis-decoded UTF-8.
        title="L10 SAE latent map — your image (★), colored by neuron #%d" % feat,
        xaxis_title="PC1", yaxis_title="PC2", height=420, margin=dict(l=8, r=8, t=44, b=8),
        legend=dict(orientation="h", y=1.04, x=0))
    return fig


def neuron_bar(lat):
    """Horizontal bar chart of the 10 largest activations in ``lat``.

    Args:
        lat: 1-D sequence/array of SAE latent activations, one per neuron.

    Returns:
        A plotly ``go.Figure`` with one bar per neuron, labelled "#<index>".
    """
    lat = np.asarray(lat)
    # Descending top-10, then reversed so the strongest neuron is the last category.
    # Presumably that places it at the top of plotly's bottom-up category axis.
    idx = np.argsort(-lat)[:10][::-1]
    fig = go.Figure(go.Bar(
        x=lat[idx], y=["#%d" % i for i in idx], orientation="h",
        marker_color="#4f46e5",
        # The y labels already start with "#"; adding another gave "neuron ##123".
        hovertemplate="neuron %{y} = %{x:.2f}<extra></extra>"))
    fig.update_layout(title="this image's strongest L10 SAE neurons", height=300,
        xaxis_title="activation", margin=dict(l=8, r=8, t=44, b=8), yaxis=dict(type="category"))
    return fig


@GPU(duration=70)
def analyze(pil):
    """Caption an image and read its L10 SAE representation.

    Returns a 5-tuple matching the Gradio outputs: caption text, latent-map
    figure, top-neuron bar figure, an update for the neuron radio (top 8 neurons,
    strongest preselected), and the image's 2-D map coordinate as a list. With
    no image it returns a prompt message, no plots, an empty radio and no coordinate.
    """
    if pil is None:
        return "Upload or pick an image.", None, None, gr.update(choices=[], value=None), None
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    _ensure(dev)
    latents = _latent(pil, dev)
    caption = _caption(pil, dev, steer=False)
    ranked = np.argsort(-latents)[:8]
    strongest = int(ranked[0])
    options = [("#%d  (act %.1f)" % (int(n), latents[n]), int(n)) for n in ranked]
    # Project the mean latent onto the stored PCA basis to place it on the map.
    xy = (latents - _d["pca_mean"]) @ _d["pca_comp"].T
    return (caption, scatter(xy, strongest), neuron_bar(latents),
            gr.update(choices=options, value=strongest), xy.tolist())


@GPU(duration=70)
def steer(pil, neuron, strength):
    """Re-caption ``pil`` with SAE neuron ``neuron`` clamped to ``strength``.

    Returns a hint string instead of a caption when no image/neuron is selected,
    or when the strength is 0.
    """
    if pil is None or neuron is None:
        return "Run an image first."
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    _ensure(dev)
    value = float(strength)
    if value == 0:
        return "Move the slider to clamp this neuron."
    return _caption(pil, dev, feat=int(neuron), strength=value, steer=True)


def recolor(coord, neuron):
    """Redraw the latent map coloured by a newly picked neuron.

    Returns None until both a stored coordinate (from analyze) and a neuron exist.
    """
    ready = coord is not None and neuron is not None
    return scatter(np.array(coord), int(neuron)) if ready else None


with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), title="Visual Steerability") as demo:
    gr.Markdown("# 🎛️ Steer PaliGemma's vision model from the inside\n"
        "Pass an image → see its **L10 SAE neuron activations** (the model's internal representation) → "
        "**clamp a neuron** and watch the caption change. "
        "Read the explainer: [The Steerability Spectrum](https://sumityadav.com.np/study/steerability-spectrum/).")
    st_coord = gr.State()  # map coordinate of the analysed image, reused by recolor()
    with gr.Row():
        with gr.Column(scale=4):
            img = gr.Image(type="pil", label="image", height=300)
            run = gr.Button("① Caption + read internal representation", variant="primary")
            base = gr.Textbox(label="caption", lines=2, interactive=False)
            neuron = gr.Radio(choices=[], label="② pick a neuron to steer", interactive=True)
            strength = gr.Slider(-40, 40, value=0, step=1, label="③ clamp strength → re-caption on release")
            out = gr.Textbox(label="steered caption", lines=2, interactive=False)
        with gr.Column(scale=5):
            plot = gr.Plot(label="SAE latent map")
            bar = gr.Plot(label="internal representation")
    # Only list image files, so stray files in examples/ (a README, dotfiles) are not
    # fed to the image input. A missing folder just means no examples.
    example_files = (
        sorted(f for f in os.listdir("examples")
               if f.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif")))
        if os.path.isdir("examples") else []
    )
    if example_files:
        gr.Examples(examples=[[os.path.join("examples", f)] for f in example_files],
                    inputs=[img], label="examples")
    # Wiring: analyse fills caption, plots, neuron choices and the stored coordinate;
    # picking a neuron recolours the map; releasing the slider re-captions with steering.
    run.click(analyze, [img], [base, plot, bar, neuron, st_coord])
    neuron.change(recolor, [st_coord, neuron], [plot])
    strength.release(steer, [img, neuron, strength], [out])

demo.queue(max_size=12).launch()