"""Safety-Lens: The Model MRI — a real-time activation scanner for HF models.""" import os # ZeroGPU: must import spaces BEFORE torch/CUDA IS_HF_SPACE = os.environ.get("SPACE_ID") is not None if IS_HF_SPACE: import spaces import gradio as gr import torch import plotly.graph_objects as go from transformers import AutoModelForCausalLM, AutoTokenizer from safety_lens.core import SafetyLens, LensHooks from safety_lens.vectors import STIMULUS_SETS # --- Globals (populated on model load) --- _state = {"lens": None, "model": None, "tokenizer": None, "vectors": {}} DEFAULT_MODEL = "gpt2" DEFAULT_LAYER = 6 NUM_GENERATE_TOKENS = 30 def _calibrate_on_gpu(model, tokenizer, layer_idx: int): """Calibrate persona vectors — runs inside @spaces.GPU on ZeroGPU.""" if torch.cuda.is_available(): model = model.half().to("cuda") model.eval() lens = SafetyLens(model, tokenizer) _state["lens"] = lens _state["model"] = model _state["tokenizer"] = tokenizer _state["vectors"] = {} vectors = {} for name, stim in STIMULUS_SETS.items(): vec = lens.extract_persona_vector(stim["pos"], stim["neg"], layer_idx) vectors[name] = vec _state["vectors"] = vectors return lens.device, list(vectors.keys()) # Wrap calibration for ZeroGPU when on HF Spaces if IS_HF_SPACE: _calibrate_on_gpu = spaces.GPU()(_calibrate_on_gpu) def load_model(model_id: str, layer_idx: int): """Load a model and calibrate persona vectors.""" status_lines = [f"Loading {model_id}..."] yield "\n".join(status_lines), None, None tokenizer = AutoTokenizer.from_pretrained(model_id) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # On ZeroGPU, load on CPU first — GPU is only available inside @spaces.GPU if IS_HF_SPACE: model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32) else: dtype = torch.float16 if torch.cuda.is_available() else torch.float32 device_map = "auto" if torch.cuda.is_available() else "cpu" model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=dtype, device_map=device_map ) status_lines.append(f"Model loaded. Calibrating persona vectors on layer {layer_idx}...") yield "\n".join(status_lines), None, None device, calibrated = _calibrate_on_gpu(model, tokenizer, layer_idx) for name in calibrated: status_lines.append(f" Calibrated: {name}") status_lines.append(f"Ready for scanning on {device}.") yield "\n".join(status_lines), None, None def _run_mri_inner(prompt: str, persona_name: str, layer_idx: int): """Core MRI logic — separated so ZeroGPU decorator can wrap it.""" lens = _state["lens"] model = _state["model"] tokenizer = _state["tokenizer"] if lens is None: return "
Please load a model first.
", None vector = _state["vectors"].get(persona_name) if vector is None: stim = STIMULUS_SETS[persona_name] vector = lens.extract_persona_vector(stim["pos"], stim["neg"], layer_idx) _state["vectors"][persona_name] = vector input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(lens.device) tokens_str = [] scores = [] for _ in range(NUM_GENERATE_TOKENS): score = lens.scan(input_ids, vector, layer_idx) scores.append(score) with torch.no_grad(): logits = model(input_ids).logits[:, -1, :] next_token = torch.argmax(logits, dim=-1).unsqueeze(0) tokens_str.append(tokenizer.decode(next_token[0])) input_ids = torch.cat([input_ids, next_token], dim=-1) # Build highlighted HTML if scores: max_abs = max(abs(s) for s in scores) or 1.0 else: max_abs = 1.0 html = "