File size: 6,670 Bytes
821536e
 
 
30b5d0d
 
 
 
 
 
821536e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4328168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
821536e
 
 
 
 
 
 
 
 
4328168
 
 
 
 
 
 
 
 
821536e
4328168
821536e
 
4328168
821536e
4328168
 
 
821536e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""Safety-Lens: The Model MRI — a real-time activation scanner for HF models."""

import os

# ZeroGPU: must import spaces BEFORE torch/CUDA
IS_HF_SPACE = os.environ.get("SPACE_ID") is not None
if IS_HF_SPACE:
    import spaces

import gradio as gr
import torch
import plotly.graph_objects as go
from transformers import AutoModelForCausalLM, AutoTokenizer
from safety_lens.core import SafetyLens, LensHooks
from safety_lens.vectors import STIMULUS_SETS

# --- Globals (populated on model load) ---
# Shared mutable state: _calibrate_on_gpu writes every key, _run_mri_inner
# reads them and may add lazily computed entries to "vectors".
_state = {"lens": None, "model": None, "tokenizer": None, "vectors": {}}

# Initial value of the "Model ID" textbox.
DEFAULT_MODEL = "gpt2"
# Initial value of the "Layer Index" slider.
DEFAULT_LAYER = 6
# Number of generation steps performed per scan.
NUM_GENERATE_TOKENS = 30


def _calibrate_on_gpu(model, tokenizer, layer_idx: int):
    """Prepare the model for scanning and extract one vector per stimulus set.

    Runs inside @spaces.GPU on ZeroGPU. When CUDA is visible the model is
    converted to half precision and moved to the GPU; it is then put in eval
    mode and wrapped in a ``SafetyLens``. The lens, model and tokenizer are
    published in ``_state`` (with an empty vector table) before extraction
    starts, and the vector table is replaced once every set has been processed.

    Args:
        model: The loaded causal-LM to scan.
        tokenizer: Tokenizer matching ``model``.
        layer_idx: Layer index forwarded to ``extract_persona_vector``.

    Returns:
        A ``(device, names)`` tuple: ``lens.device`` and the list of persona
        names that were calibrated.
    """
    gpu_available = torch.cuda.is_available()
    if gpu_available:
        model = model.half().to("cuda")
    model.eval()

    scanner = SafetyLens(model, tokenizer)

    # Publish the new lens first so stale vectors never pair with a new model.
    _state.update(lens=scanner, model=model, tokenizer=tokenizer, vectors={})

    persona_vectors = {
        persona: scanner.extract_persona_vector(spec["pos"], spec["neg"], layer_idx)
        for persona, spec in STIMULUS_SETS.items()
    }

    _state["vectors"] = persona_vectors
    return scanner.device, list(persona_vectors)


# Wrap calibration for ZeroGPU when on HF Spaces.
# `spaces` is only imported when SPACE_ID is set, so the rebinding is guarded
# by the same flag; elsewhere the plain function is used unchanged.
if IS_HF_SPACE:
    _calibrate_on_gpu = spaces.GPU()(_calibrate_on_gpu)


def load_model(model_id: str, layer_idx: int):
    """Load a model and calibrate persona vectors, streaming status text.

    This is a generator: every yield is a ``(status_text, None, None)`` tuple.
    The click handler maps those three values onto the status box, the HTML
    output and the plot output, so the two ``None`` entries are what the
    previous scan results are replaced with while a new model loads.

    Args:
        model_id: Hugging Face model identifier. Surrounding whitespace is
            ignored; a blank value produces a status message instead of a
            load attempt.
        layer_idx: Layer index to calibrate on. Coerced to ``int``.

    Yields:
        ``(status_text, None, None)`` after each stage: load started, model
        loaded, calibration finished.
    """
    model_id = (model_id or "").strip()
    if not model_id:
        # A blank ID would otherwise fail inside from_pretrained with an
        # error that does not point at the empty textbox.
        yield "Please enter a model ID.", None, None
        return

    # The value comes from a UI slider; normalise it so the status text and
    # downstream indexing always see an integer.
    layer_idx = int(layer_idx)

    status_lines = [f"Loading {model_id}..."]
    yield "\n".join(status_lines), None, None

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        # Fall back to EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    # On ZeroGPU, load on CPU first — GPU is only available inside @spaces.GPU
    if IS_HF_SPACE:
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    else:
        has_cuda = torch.cuda.is_available()
        dtype = torch.float16 if has_cuda else torch.float32
        device_map = "auto" if has_cuda else "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=dtype, device_map=device_map
        )

    status_lines.append(f"Model loaded. Calibrating persona vectors on layer {layer_idx}...")
    yield "\n".join(status_lines), None, None

    device, calibrated = _calibrate_on_gpu(model, tokenizer, layer_idx)

    status_lines.extend(f"  Calibrated: {name}" for name in calibrated)
    status_lines.append(f"Ready for scanning on {device}.")
    yield "\n".join(status_lines), None, None


def _run_mri_inner(prompt: str, persona_name: str, layer_idx: int):
    """Core MRI logic — separated so ZeroGPU decorator can wrap it.

    Runs ``NUM_GENERATE_TOKENS`` greedy decoding steps after ``prompt``.
    Before each step the current sequence is scored with ``lens.scan`` against
    the selected persona vector; the score is then paired with the token
    chosen in that step (argmax over the last-position logits).

    Args:
        prompt: User-supplied text to continue. It is HTML-escaped before
            being echoed into the result.
        persona_name: Key into ``STIMULUS_SETS`` / the cached vector table.
        layer_idx: Layer index forwarded to ``lens.scan`` (and to
            ``extract_persona_vector`` when the vector is not cached).

    Returns:
        ``(html, figure)``: an HTML fragment with one colour-coded span per
        generated token, and a Plotly figure of the score per step. When no
        model is loaded or the persona is unknown, ``(message_html, None)``.
    """
    # Function-scope import keeps this block self-contained; aliased so the
    # call sites read unambiguously.
    from html import escape as escape_html

    lens = _state["lens"]
    model = _state["model"]
    tokenizer = _state["tokenizer"]

    if lens is None:
        return "<p>Please load a model first.</p>", None

    # NOTE(review): cached vectors are keyed by persona name only, so a vector
    # calibrated at load time is reused even if the layer slider has since
    # been moved — confirm whether that is intended.
    vector = _state["vectors"].get(persona_name)
    if vector is None:
        if persona_name not in STIMULUS_SETS:
            # Cleared dropdown or stale value: report it instead of raising KeyError.
            return "<p>Please select a persona to scan.</p>", None
        stim = STIMULUS_SETS[persona_name]
        vector = lens.extract_persona_vector(stim["pos"], stim["neg"], layer_idx)
        _state["vectors"][persona_name] = vector

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(lens.device)

    tokens_str = []
    scores = []

    for _ in range(NUM_GENERATE_TOKENS):
        # Score the sequence as it stands before the next token is appended.
        scores.append(lens.scan(input_ids, vector, layer_idx))

        with torch.no_grad():
            logits = model(input_ids).logits[:, -1, :]
            next_token = torch.argmax(logits, dim=-1).unsqueeze(0)

        tokens_str.append(tokenizer.decode(next_token[0]))
        input_ids = torch.cat([input_ids, next_token], dim=-1)

    # Normalise by the largest magnitude; fall back to 1.0 when there are no
    # scores or all of them are zero, to avoid dividing by zero.
    max_abs = max((abs(s) for s in scores), default=0.0) or 1.0

    # Build highlighted HTML. Both the prompt and the decoded tokens are
    # arbitrary text, so they are escaped before being placed in markup.
    parts = [
        "<div style='font-family: monospace; font-size: 15px; line-height: 1.8;'>",
        f"<b>PROMPT:</b> {escape_html(prompt)}<br><br><b>GENERATION:</b><br>",
    ]
    for tok, scr in zip(tokens_str, scores):
        intensity = min(abs(scr) / max_abs, 1.0)
        if scr > 0:
            # Positive score: red, stronger with magnitude.
            color = f"rgba(220, 50, 50, {intensity * 0.6:.2f})"
        else:
            # Zero or negative score: blue.
            color = f"rgba(50, 100, 220, {intensity * 0.4:.2f})"
        parts.append(
            f"<span style='background-color:{color}; padding:2px 1px; "
            f"border-radius:3px;'>{escape_html(tok)}</span>"
        )
    parts.append("</div>")
    html_out = "".join(parts)

    # Plotly chart
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        y=scores,
        x=list(range(len(scores))),
        mode="lines+markers",
        name=f"{persona_name} alignment",
        line=dict(color="crimson"),
    ))
    fig.add_hline(y=0, line_dash="dash", line_color="gray")
    fig.update_layout(
        title=f"Activation Trajectory — {persona_name}",
        yaxis_title="Alignment Score (dot product)",
        xaxis_title="Generation Step",
        template="plotly_white",
        height=350,
    )

    return html_out, fig


# Apply ZeroGPU decorator when running on HF Spaces.
# `run_mri` is the name wired to the scan button; off Spaces it is simply an
# alias for the undecorated function.
if IS_HF_SPACE:
    run_mri = spaces.GPU()(_run_mri_inner)
else:
    run_mri = _run_mri_inner


# --- UI ---
# Components are created in display order inside the Blocks context.
with gr.Blocks(title="Safety-Lens: Model MRI") as demo:
    gr.Markdown("# Safety-Lens: The Model MRI")
    gr.Markdown(
        "See **how** a model thinks, not just what it says. "
        "Red = high alignment with the selected persona vector. Blue = opposite direction."
    )

    # Model selection row: model ID, layer to probe, and the load trigger.
    with gr.Row():
        model_id = gr.Textbox(label="Model ID", value=DEFAULT_MODEL)
        layer_slider = gr.Slider(
            label="Layer Index", minimum=0, maximum=47, step=1, value=DEFAULT_LAYER
        )
        load_btn = gr.Button("Load Model", variant="secondary")

    # Read-only box that receives the status text streamed by load_model.
    status_box = gr.Textbox(label="Status", interactive=False, lines=4)

    gr.Markdown("---")

    # Scan inputs: the prompt to continue and the persona vector to score against.
    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            value="I think the world is flat. Do you agree?",
            lines=2,
        )
        persona_dropdown = gr.Dropdown(
            label="Persona to Scan",
            choices=list(STIMULUS_SETS.keys()),
            value="sycophancy",
        )

    scan_btn = gr.Button("Run MRI Scan", variant="primary")

    # Scan outputs: highlighted generation and the per-step score chart.
    out_html = gr.HTML(label="Visualized Generation")
    out_plot = gr.Plot(label="Activation Dynamics")

    # load_model yields (status, None, None), so the HTML and plot outputs
    # are included here as well as the status box.
    load_btn.click(
        load_model,
        inputs=[model_id, layer_slider],
        outputs=[status_box, out_html, out_plot],
    )
    # The scan reuses the layer slider value chosen in the model row.
    scan_btn.click(
        run_mri,
        inputs=[prompt_box, persona_dropdown, layer_slider],
        outputs=[out_html, out_plot],
    )

# Start the Gradio app with the Soft theme when this file is run as a script.
if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft())