| """Safety-Lens: The Model MRI β a real-time activation scanner for HF models.""" | |
| import os | |
| # ZeroGPU: must import spaces BEFORE torch/CUDA | |
| IS_HF_SPACE = os.environ.get("SPACE_ID") is not None | |
| if IS_HF_SPACE: | |
| import spaces | |
| import gradio as gr | |
| import torch | |
| import plotly.graph_objects as go | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from safety_lens.core import SafetyLens, LensHooks | |
| from safety_lens.vectors import STIMULUS_SETS | |
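# Assumed interface of safety_lens (a hedged sketch of what this app relies on,
# inferred from the calls below rather than confirmed against the package):
#   SafetyLens(model, tokenizer)                  -> hooked wrapper exposing `.device`
#   lens.extract_persona_vector(pos, neg, layer)  -> persona direction (torch.Tensor)
#       built by contrasting the "pos" vs. "neg" stimulus prompts at `layer`
#   lens.scan(input_ids, vector, layer)           -> float alignment score (dot product)
#       of the current activation at `layer` with that direction
#   STIMULUS_SETS                                 -> {persona: {"pos": [...], "neg": [...]}}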

# --- Globals (populated on model load) ---
_state = {"lens": None, "model": None, "tokenizer": None, "vectors": {}}

DEFAULT_MODEL = "gpt2"
DEFAULT_LAYER = 6
NUM_GENERATE_TOKENS = 30


def _calibrate_on_gpu(model, tokenizer, layer_idx: int):
    """Calibrate persona vectors; runs inside @spaces.GPU on ZeroGPU."""
    if torch.cuda.is_available():
        model = model.half().to("cuda")
    model.eval()
    lens = SafetyLens(model, tokenizer)
    _state["lens"] = lens
    _state["model"] = model
    _state["tokenizer"] = tokenizer
    _state["vectors"] = {}
    vectors = {}
    for name, stim in STIMULUS_SETS.items():
        vec = lens.extract_persona_vector(stim["pos"], stim["neg"], layer_idx)
        vectors[name] = vec
    _state["vectors"] = vectors
    return lens.device, list(vectors.keys())


# Wrap calibration for ZeroGPU when on HF Spaces
if IS_HF_SPACE:
    _calibrate_on_gpu = spaces.GPU()(_calibrate_on_gpu)
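# Note (assumption about ZeroGPU behaviour): the GPU is attached only for the
# duration of a @spaces.GPU-decorated call, so every CUDA-touching step (the
# calibration above, the scan below) has to run inside one of these wrapped
# functions; module-level code runs on CPU only.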


def load_model(model_id: str, layer_idx: int):
    """Load a model and calibrate persona vectors."""
    status_lines = [f"Loading {model_id}..."]
    yield "\n".join(status_lines), None, None

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # On ZeroGPU, load on CPU first; the GPU is only available inside @spaces.GPU
    if IS_HF_SPACE:
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    else:
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        device_map = "auto" if torch.cuda.is_available() else "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=dtype, device_map=device_map
        )

    status_lines.append(f"Model loaded. Calibrating persona vectors on layer {layer_idx}...")
    yield "\n".join(status_lines), None, None

    device, calibrated = _calibrate_on_gpu(model, tokenizer, layer_idx)
    for name in calibrated:
        status_lines.append(f" Calibrated: {name}")
    status_lines.append(f"Ready for scanning on {device}.")
    yield "\n".join(status_lines), None, None


def _run_mri_inner(prompt: str, persona_name: str, layer_idx: int):
    """Core MRI logic, separated so the ZeroGPU decorator can wrap it."""
    lens = _state["lens"]
    model = _state["model"]
    tokenizer = _state["tokenizer"]
    if lens is None:
        return "<p>Please load a model first.</p>", None

    vector = _state["vectors"].get(persona_name)
    if vector is None:
        stim = STIMULUS_SETS[persona_name]
        vector = lens.extract_persona_vector(stim["pos"], stim["neg"], layer_idx)
        _state["vectors"][persona_name] = vector

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(lens.device)
    tokens_str = []
    scores = []
    for _ in range(NUM_GENERATE_TOKENS):
        score = lens.scan(input_ids, vector, layer_idx)
        scores.append(score)
        with torch.no_grad():
            logits = model(input_ids).logits[:, -1, :]
        next_token = torch.argmax(logits, dim=-1).unsqueeze(0)
        tokens_str.append(tokenizer.decode(next_token[0]))
        input_ids = torch.cat([input_ids, next_token], dim=-1)

    # Build highlighted HTML
    if scores:
        max_abs = max(abs(s) for s in scores) or 1.0
    else:
        max_abs = 1.0
    html = "<div style='font-family: monospace; font-size: 15px; line-height: 1.8;'>"
    html += f"<b>PROMPT:</b> {prompt}<br><br><b>GENERATION:</b><br>"
    for tok, scr in zip(tokens_str, scores):
        intensity = min(abs(scr) / max_abs, 1.0)
        if scr > 0:
            color = f"rgba(220, 50, 50, {intensity * 0.6:.2f})"
        else:
            color = f"rgba(50, 100, 220, {intensity * 0.4:.2f})"
        # Escape HTML so literal angle brackets in tokens don't break the markup
        safe_tok = tok.replace("<", "&lt;").replace(">", "&gt;")
| html += f"<span style='background-color:{color}; padding:2px 1px; border-radius:3px;'>{safe_tok}</span>" | |
| html += "</div>" | |
| # Plotly chart | |
| fig = go.Figure() | |
| fig.add_trace(go.Scatter( | |
| y=scores, | |
| x=list(range(len(scores))), | |
| mode="lines+markers", | |
| name=f"{persona_name} alignment", | |
| line=dict(color="crimson"), | |
| )) | |
| fig.add_hline(y=0, line_dash="dash", line_color="gray") | |
| fig.update_layout( | |
| title=f"Activation Trajectory β {persona_name}", | |
| yaxis_title="Alignment Score (dot product)", | |
| xaxis_title="Generation Step", | |
| template="plotly_white", | |
| height=350, | |
| ) | |
| return html, fig | |
| # Apply ZeroGPU decorator when running on HF Spaces | |
| if IS_HF_SPACE: | |
| run_mri = spaces.GPU()(_run_mri_inner) | |
| else: | |
| run_mri = _run_mri_inner | |
| # --- UI --- | |
with gr.Blocks(title="Safety-Lens: Model MRI", theme=gr.themes.Soft()) as demo:
| gr.Markdown("# Safety-Lens: The Model MRI") | |
| gr.Markdown( | |
| "See **how** a model thinks, not just what it says. " | |
| "Red = high alignment with the selected persona vector. Blue = opposite direction." | |
| ) | |
| with gr.Row(): | |
| model_id = gr.Textbox(label="Model ID", value=DEFAULT_MODEL) | |
| layer_slider = gr.Slider( | |
| label="Layer Index", minimum=0, maximum=47, step=1, value=DEFAULT_LAYER | |
| ) | |
| load_btn = gr.Button("Load Model", variant="secondary") | |
| status_box = gr.Textbox(label="Status", interactive=False, lines=4) | |
| gr.Markdown("---") | |
| with gr.Row(): | |
| prompt_box = gr.Textbox( | |
| label="Prompt", | |
| value="I think the world is flat. Do you agree?", | |
| lines=2, | |
| ) | |
| persona_dropdown = gr.Dropdown( | |
| label="Persona to Scan", | |
| choices=list(STIMULUS_SETS.keys()), | |
| value="sycophancy", | |
| ) | |
| scan_btn = gr.Button("Run MRI Scan", variant="primary") | |
| out_html = gr.HTML(label="Visualized Generation") | |
| out_plot = gr.Plot(label="Activation Dynamics") | |
| load_btn.click( | |
| load_model, | |
| inputs=[model_id, layer_slider], | |
| outputs=[status_box, out_html, out_plot], | |
| ) | |
| scan_btn.click( | |
| run_mri, | |
| inputs=[prompt_box, persona_dropdown, layer_slider], | |
| outputs=[out_html, out_plot], | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch(theme=gr.themes.Soft()) | |