| """ |
| TruthSayer: Orthogonal Deception Detection via SAE Probes |
| |
| Usage: |
| from inference import TruthSayer |
| detector = TruthSayer() |
| scores = detector.score("<scratchpad text>") |
| """ |
| import json |
| import numpy as np |
| import torch |
| from pathlib import Path |
| from huggingface_hub import hf_hub_download |
| from safetensors.torch import load_file |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
class _EarlyExit(Exception):
    """Control-flow sentinel raised by the forward hook in
    TruthSayer.extract_features to abort the model's forward pass as soon
    as the target layer's activations have been captured (the remaining
    layers never run). Caught and discarded by the caller."""
    pass
|
|
|
|
class JumpReLUSAE(torch.nn.Module):
    """Encoder half of a JumpReLU sparse autoencoder.

    Holds the encoder weights, bias, and per-feature activation thresholds
    as plain float32 tensors on the requested device.
    """

    def __init__(self, params, device):
        """Initialize from a parameter dict with 'w_enc', 'b_enc', 'threshold'."""
        super().__init__()
        self.w_enc = params['w_enc'].float().to(device)
        self.b_enc = params['b_enc'].float().to(device)
        self.threshold = params['threshold'].float().to(device)

    def encode(self, x):
        """Return JumpReLU feature activations for residual vectors ``x``.

        Pre-activations pass through only where they exceed the learned
        per-feature threshold; everything else is zeroed.
        """
        pre_acts = x.matmul(self.w_enc) + self.b_enc
        gate = (pre_acts > self.threshold).float()
        return torch.relu(pre_acts) * gate
|
|
|
|
class TruthSayer:
    """Deception detector: logistic-regression probes over SAE features.

    Runs Gemma 3 27B up to ``layer``, encodes that layer's residual stream
    with a GemmaScope JumpReLU SAE, mean-pools the sparse features over
    tokens, and scores the pooled vector with three pre-trained probes
    ('af', 'truth', 'joint').
    """

    def __init__(self, probe_dir=None, device='cuda', max_len=2048, layer=40):
        """Load the probe arrays, the base model, and the SAE.

        Args:
            probe_dir: Directory containing the ``*_probe_*.npy`` files;
                defaults to the directory holding this source file.
            device: Device used for the SAE math and tokenized inputs.
            max_len: Truncation length (tokens) for scored text.
            layer: Index of the residual-stream layer the SAE reads.
        """
        self.device = device
        self.max_len = max_len
        self.layer = layer

        if probe_dir is None:
            probe_dir = Path(__file__).parent
        probe_dir = Path(probe_dir)

        # Each probe is a logistic regression plus the standardization
        # statistics (mean/scale) applied to features at train time.
        self.probes = {}
        for name in ['af', 'truth', 'joint']:
            self.probes[name] = {
                'weights': np.load(probe_dir / f'{name}_probe_weights.npy'),
                'bias': np.load(probe_dir / f'{name}_probe_bias.npy'),
                'scaler_mean': np.load(probe_dir / f'{name}_scaler_mean.npy'),
                'scaler_scale': np.load(probe_dir / f'{name}_scaler_scale.npy'),
            }

        print('Loading Gemma 3 27B...')
        self.tokenizer = AutoTokenizer.from_pretrained('google/gemma-3-27b-it')
        self.model = AutoModelForCausalLM.from_pretrained(
            'google/gemma-3-27b-it',
            torch_dtype=torch.bfloat16,
            device_map='auto',
            low_cpu_mem_usage=True,
        )
        self.model.eval()

        print('Loading GemmaScope 2 SAE...')
        sae_path = f'resid_post/layer_{layer}_width_16k_l0_medium'
        params_path = hf_hub_download('google/gemma-scope-2-27b-it',
                                      f'{sae_path}/params.safetensors')
        params = load_file(params_path)
        self.sae = JumpReLUSAE(params, device)
        print('Ready.')

    def extract_features(self, text):
        """Return mean-pooled SAE features for ``text`` as a 1-D numpy array.

        A forward hook on the target layer captures the residual stream and
        raises _EarlyExit so the layers above it are never computed.

        Raises:
            RuntimeError: if the hook never fired (e.g. the layer path or
                index does not match the loaded model).
        """
        captured = {}

        def hook_fn(module, inp, out):
            h = out[0] if isinstance(out, tuple) else out
            captured['h'] = h[0].detach().float()  # drop the batch dim
            raise _EarlyExit()

        handle = self.model.model.language_model.layers[self.layer].register_forward_hook(hook_fn)
        try:
            inputs = self.tokenizer(text, return_tensors='pt', truncation=True,
                                    max_length=self.max_len, padding=False)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            captured.clear()
            try:
                with torch.no_grad():
                    self.model(**inputs)
            except _EarlyExit:
                pass
            if 'h' not in captured:
                # Fail loudly instead of the opaque KeyError that would
                # otherwise surface when the hook did not run.
                raise RuntimeError(
                    f'Forward hook on layer {self.layer} captured no '
                    f'activations; check the layer index and module path')
            h = captured['h'].to(self.device)
            f_all = self.sae.encode(h)
            return f_all.mean(dim=0).cpu().numpy()
        finally:
            handle.remove()  # always detach the hook, even on error

    def _probe_score(self, features, probe_name):
        """Standardize ``features``, apply one probe, and return sigmoid(logit)."""
        p = self.probes[probe_name]
        # Epsilon guards against zero-variance features in the scaler.
        x = (features - p['scaler_mean']) / (p['scaler_scale'] + 1e-10)
        logit = float(x @ p['weights'].T + p['bias'])
        # Numerically stable sigmoid: only ever exponentiates a non-positive
        # value, so large |logit| cannot overflow np.exp.
        if logit >= 0:
            return 1.0 / (1.0 + np.exp(-logit))
        e = np.exp(logit)
        return e / (1.0 + e)

    def score(self, text):
        """Score ``text`` with all three probes; returns a dict of probabilities."""
        features = self.extract_features(text)
        return {
            'af_score': self._probe_score(features, 'af'),
            'sandbagging_score': self._probe_score(features, 'truth'),
            'joint_score': self._probe_score(features, 'joint'),
        }

    def score_batch(self, texts):
        """Score each text independently; returns a list of score dicts."""
        return [self.score(t) for t in texts]
|
|
|
|
if __name__ == '__main__':
    import sys

    # CLI entry point: score text given on the command line, or prompt
    # interactively when no arguments are supplied.
    detector = TruthSayer()
    args = sys.argv[1:]
    text = ' '.join(args) if args else input('Enter scratchpad text: ')
    scores = detector.score(text)
    print(f'AF score: {scores["af_score"]:.4f}')
    print(f'Sandbagging score: {scores["sandbagging_score"]:.4f}')
    print(f'Joint score: {scores["joint_score"]:.4f}')
|
|