"""Safety-Lens: The Model MRI — a real-time activation scanner for HF models.""" import os # ZeroGPU: must import spaces BEFORE torch/CUDA IS_HF_SPACE = os.environ.get("SPACE_ID") is not None if IS_HF_SPACE: import spaces import gradio as gr import torch import plotly.graph_objects as go from transformers import AutoModelForCausalLM, AutoTokenizer from safety_lens.core import SafetyLens, LensHooks from safety_lens.vectors import STIMULUS_SETS # --- Globals (populated on model load) --- _state = {"lens": None, "model": None, "tokenizer": None, "vectors": {}} DEFAULT_MODEL = "gpt2" DEFAULT_LAYER = 6 NUM_GENERATE_TOKENS = 30 def _calibrate_on_gpu(model, tokenizer, layer_idx: int): """Calibrate persona vectors — runs inside @spaces.GPU on ZeroGPU.""" if torch.cuda.is_available(): model = model.half().to("cuda") model.eval() lens = SafetyLens(model, tokenizer) _state["lens"] = lens _state["model"] = model _state["tokenizer"] = tokenizer _state["vectors"] = {} vectors = {} for name, stim in STIMULUS_SETS.items(): vec = lens.extract_persona_vector(stim["pos"], stim["neg"], layer_idx) vectors[name] = vec _state["vectors"] = vectors return lens.device, list(vectors.keys()) # Wrap calibration for ZeroGPU when on HF Spaces if IS_HF_SPACE: _calibrate_on_gpu = spaces.GPU()(_calibrate_on_gpu) def load_model(model_id: str, layer_idx: int): """Load a model and calibrate persona vectors.""" status_lines = [f"Loading {model_id}..."] yield "\n".join(status_lines), None, None tokenizer = AutoTokenizer.from_pretrained(model_id) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # On ZeroGPU, load on CPU first — GPU is only available inside @spaces.GPU if IS_HF_SPACE: model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32) else: dtype = torch.float16 if torch.cuda.is_available() else torch.float32 device_map = "auto" if torch.cuda.is_available() else "cpu" model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=dtype, device_map=device_map ) status_lines.append(f"Model loaded. Calibrating persona vectors on layer {layer_idx}...") yield "\n".join(status_lines), None, None device, calibrated = _calibrate_on_gpu(model, tokenizer, layer_idx) for name in calibrated: status_lines.append(f" Calibrated: {name}") status_lines.append(f"Ready for scanning on {device}.") yield "\n".join(status_lines), None, None def _run_mri_inner(prompt: str, persona_name: str, layer_idx: int): """Core MRI logic — separated so ZeroGPU decorator can wrap it.""" lens = _state["lens"] model = _state["model"] tokenizer = _state["tokenizer"] if lens is None: return "

Please load a model first.

", None vector = _state["vectors"].get(persona_name) if vector is None: stim = STIMULUS_SETS[persona_name] vector = lens.extract_persona_vector(stim["pos"], stim["neg"], layer_idx) _state["vectors"][persona_name] = vector input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(lens.device) tokens_str = [] scores = [] for _ in range(NUM_GENERATE_TOKENS): score = lens.scan(input_ids, vector, layer_idx) scores.append(score) with torch.no_grad(): logits = model(input_ids).logits[:, -1, :] next_token = torch.argmax(logits, dim=-1).unsqueeze(0) tokens_str.append(tokenizer.decode(next_token[0])) input_ids = torch.cat([input_ids, next_token], dim=-1) # Build highlighted HTML if scores: max_abs = max(abs(s) for s in scores) or 1.0 else: max_abs = 1.0 html = "
" html += f"PROMPT: {prompt}

GENERATION:
" for tok, scr in zip(tokens_str, scores): intensity = min(abs(scr) / max_abs, 1.0) if scr > 0: color = f"rgba(220, 50, 50, {intensity * 0.6:.2f})" else: color = f"rgba(50, 100, 220, {intensity * 0.4:.2f})" safe_tok = tok.replace("<", "<").replace(">", ">") html += f"{safe_tok}" html += "
" # Plotly chart fig = go.Figure() fig.add_trace(go.Scatter( y=scores, x=list(range(len(scores))), mode="lines+markers", name=f"{persona_name} alignment", line=dict(color="crimson"), )) fig.add_hline(y=0, line_dash="dash", line_color="gray") fig.update_layout( title=f"Activation Trajectory — {persona_name}", yaxis_title="Alignment Score (dot product)", xaxis_title="Generation Step", template="plotly_white", height=350, ) return html, fig # Apply ZeroGPU decorator when running on HF Spaces if IS_HF_SPACE: run_mri = spaces.GPU()(_run_mri_inner) else: run_mri = _run_mri_inner # --- UI --- with gr.Blocks(title="Safety-Lens: Model MRI") as demo: gr.Markdown("# Safety-Lens: The Model MRI") gr.Markdown( "See **how** a model thinks, not just what it says. " "Red = high alignment with the selected persona vector. Blue = opposite direction." ) with gr.Row(): model_id = gr.Textbox(label="Model ID", value=DEFAULT_MODEL) layer_slider = gr.Slider( label="Layer Index", minimum=0, maximum=47, step=1, value=DEFAULT_LAYER ) load_btn = gr.Button("Load Model", variant="secondary") status_box = gr.Textbox(label="Status", interactive=False, lines=4) gr.Markdown("---") with gr.Row(): prompt_box = gr.Textbox( label="Prompt", value="I think the world is flat. Do you agree?", lines=2, ) persona_dropdown = gr.Dropdown( label="Persona to Scan", choices=list(STIMULUS_SETS.keys()), value="sycophancy", ) scan_btn = gr.Button("Run MRI Scan", variant="primary") out_html = gr.HTML(label="Visualized Generation") out_plot = gr.Plot(label="Activation Dynamics") load_btn.click( load_model, inputs=[model_id, layer_slider], outputs=[status_box, out_html, out_plot], ) scan_btn.click( run_mri, inputs=[prompt_box, persona_dropdown, layer_slider], outputs=[out_html, out_plot], ) if __name__ == "__main__": demo.launch(theme=gr.themes.Soft())