"""TruthSayer: Orthogonal Deception Detection via SAE Probes.

Usage:
    from inference import TruthSayer

    detector = TruthSayer()
    scores = detector.score("...scratchpad text...")
"""

import json
import numpy as np
import torch
from pathlib import Path
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM


class _EarlyExit(Exception):
    """Raised from the layer hook to abort the forward pass.

    Once the target layer's output has been captured, the remaining layers
    are not needed, so the hook raises this and the caller swallows it.
    """


class JumpReLUSAE(torch.nn.Module):
    """Encoder half of a JumpReLU sparse autoencoder.

    Only the encoder parameters (``w_enc``, ``b_enc``, ``threshold``) are
    loaded; this file only ever calls ``encode``.

    Args:
        params: Mapping of parameter name to tensor, as returned by
            ``safetensors.torch.load_file``.
        device: Device the parameters are moved to.
    """

    def __init__(self, params, device):
        super().__init__()
        # Cast to float32 and store as plain tensor attributes (not
        # nn.Parameter / buffers), so they never require grad.
        self.w_enc = params['w_enc'].float().to(device)
        self.b_enc = params['b_enc'].float().to(device)
        self.threshold = params['threshold'].float().to(device)

    def encode(self, x):
        """Return JumpReLU activations: ``relu(pre)`` where ``pre > threshold``, else 0."""
        pre = x @ self.w_enc + self.b_enc
        return torch.nn.functional.relu(pre) * (pre > self.threshold).float()


class TruthSayer:
    """Score text with linear probes over mean-pooled SAE features.

    Pipeline:
    1. Run ``google/gemma-3-27b-it`` only as far as decoder layer ``layer``.
    2. Encode that layer's output with the matching GemmaScope 2 JumpReLU SAE
       (``width_16k_l0_medium``).
    3. Mean-pool the SAE activations over all tokens.
    4. Apply three standardized logistic probes: ``'af'``, ``'truth'`` and
       ``'joint'``.

    Args:
        probe_dir: Directory containing, for each probe name,
            ``{name}_probe_weights.npy``, ``{name}_probe_bias.npy``,
            ``{name}_scaler_mean.npy`` and ``{name}_scaler_scale.npy``.
            Defaults to the directory of this file.
        device: Device for the SAE and the tokenized inputs.
        max_len: Inputs are truncated to this many tokens.
        layer: Decoder-layer index to hook; also selects the SAE checkpoint.
    """

    def __init__(self, probe_dir=None, device='cuda', max_len=2048, layer=40):
        self.device = device
        self.max_len = max_len
        self.layer = layer

        # Load probes (standardization statistics + linear weights).
        if probe_dir is None:
            probe_dir = Path(__file__).parent
        probe_dir = Path(probe_dir)
        self.probes = {}
        for name in ['af', 'truth', 'joint']:
            self.probes[name] = {
                'weights': np.load(probe_dir / f'{name}_probe_weights.npy'),
                'bias': np.load(probe_dir / f'{name}_probe_bias.npy'),
                'scaler_mean': np.load(probe_dir / f'{name}_scaler_mean.npy'),
                'scaler_scale': np.load(probe_dir / f'{name}_scaler_scale.npy'),
            }

        # Load model
        print('Loading Gemma 3 27B...')
        self.tokenizer = AutoTokenizer.from_pretrained('google/gemma-3-27b-it')
        self.model = AutoModelForCausalLM.from_pretrained(
            'google/gemma-3-27b-it',
            torch_dtype=torch.bfloat16,
            device_map='auto',
            low_cpu_mem_usage=True,
        )
        self.model.eval()

        # Load SAE for the residual stream after the hooked layer.
        print('Loading GemmaScope 2 SAE...')
        sae_path = f'resid_post/layer_{layer}_width_16k_l0_medium'
        params_path = hf_hub_download('google/gemma-scope-2-27b-it', f'{sae_path}/params.safetensors')
        params = load_file(params_path)
        self.sae = JumpReLUSAE(params, device)
        print('Ready.')

    def extract_features(self, text):
        """Return the mean-pooled SAE feature vector for ``text``.

        The text is tokenized (truncated to ``self.max_len`` tokens) and run
        through the model only as far as decoder layer ``self.layer``. A
        forward hook captures that layer's output for the first batch element
        and raises ``_EarlyExit`` to skip the remaining layers. The captured
        hidden states are encoded by the SAE and averaged over tokens.

        Returns:
            1-D float32 ``numpy`` array with one entry per SAE latent.

        Raises:
            RuntimeError: If the hook on the target layer never fired.
        """
        captured = {}

        def hook_fn(module, inp, out):
            # The layer may return a tuple whose first element is the hidden
            # state, or the hidden state directly; handle both.
            h = out[0] if isinstance(out, tuple) else out
            # Keep the first (only) sequence in the batch; presumably
            # (batch, seq, d_model) -> (seq, d_model) — TODO confirm.
            captured['h'] = h[0].detach().float()
            raise _EarlyExit()

        target_layer = self.model.model.language_model.layers[self.layer]
        handle = target_layer.register_forward_hook(hook_fn)
        try:
            inputs = self.tokenizer(
                text,
                return_tensors='pt',
                truncation=True,
                max_length=self.max_len,
                padding=False,
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            try:
                with torch.no_grad():
                    self.model(**inputs)
            except _EarlyExit:
                # Expected: the hook aborts the forward pass after capture.
                pass
            if 'h' not in captured:
                raise RuntimeError(
                    f'forward hook on decoder layer {self.layer} did not fire; '
                    'check the layer index and model structure'
                )
            h = captured['h'].to(self.device)
            f_all = self.sae.encode(h)
            return f_all.mean(dim=0).cpu().numpy()
        finally:
            # Always detach the hook so later forward passes are unaffected.
            handle.remove()

    def _probe_score(self, features, probe_name):
        """Apply one standardized linear probe and return its sigmoid probability.

        Args:
            features: 1-D feature vector from ``extract_features``.
            probe_name: A key of ``self.probes``.

        Returns:
            Probability in [0, 1] as a Python float.

        Raises:
            ValueError: If the probe does not produce exactly one logit.
        """
        p = self.probes[probe_name]
        x = (features - p['scaler_mean']) / (p['scaler_scale'] + 1e-10)
        # Weights may be stored as (d,) or (1, d); either way the result should
        # hold a single logit. Ravel first, because float() on an ndim>0 array
        # is deprecated since NumPy 1.25.
        z = np.ravel(x @ p['weights'].T + p['bias'])
        if z.size != 1:
            raise ValueError(
                f'probe {probe_name!r} produced {z.size} logits; expected exactly 1'
            )
        logit = float(z[0])
        # Numerically stable sigmoid: only ever exponentiate a non-positive
        # number, so np.exp cannot overflow for large-magnitude logits.
        if logit >= 0:
            return float(1.0 / (1.0 + np.exp(-logit)))
        e = np.exp(logit)
        return float(e / (1.0 + e))

    def score(self, text):
        """Return the three probe probabilities for ``text``.

        Returns:
            Dict with keys ``'af_score'``, ``'sandbagging_score'`` and
            ``'joint_score'``, each a probability in [0, 1].
        """
        features = self.extract_features(text)
        return {
            'af_score': self._probe_score(features, 'af'),
            # NOTE(review): the 'truth' probe is reported as the sandbagging
            # score — confirm this mapping is intended.
            'sandbagging_score': self._probe_score(features, 'truth'),
            'joint_score': self._probe_score(features, 'joint'),
        }

    def score_batch(self, texts):
        """Score each text independently (one forward pass per text)."""
        return [self.score(t) for t in texts]


if __name__ == '__main__':
    import sys

    detector = TruthSayer()
    if len(sys.argv) > 1:
        text = ' '.join(sys.argv[1:])
    else:
        text = input('Enter scratchpad text: ')
    scores = detector.score(text)
    print(f'AF score: {scores["af_score"]:.4f}')
    print(f'Sandbagging score: {scores["sandbagging_score"]:.4f}')
    print(f'Joint score: {scores["joint_score"]:.4f}')