# truthsayer/inference.py — uploaded to the Hugging Face Hub by vincentoh (commit 3d804a1, verified)
"""
TruthSayer: Orthogonal Deception Detection via SAE Probes
Usage:
from inference import TruthSayer
detector = TruthSayer()
scores = detector.score("<scratchpad text>")
"""
import json
import numpy as np
import torch
from pathlib import Path
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM
class _EarlyExit(Exception):
pass
class JumpReLUSAE(torch.nn.Module):
def __init__(self, params, device):
super().__init__()
self.w_enc = params['w_enc'].float().to(device)
self.b_enc = params['b_enc'].float().to(device)
self.threshold = params['threshold'].float().to(device)
def encode(self, x):
pre = x @ self.w_enc + self.b_enc
return torch.nn.functional.relu(pre) * (pre > self.threshold).float()
class TruthSayer:
    """Deception detector: scores text with linear probes over SAE features.

    Pipeline: tokenize text -> run Gemma 3 27B up to one residual-stream
    layer -> encode that layer's hidden states with a GemmaScope 2 JumpReLU
    SAE -> mean-pool feature activations over tokens -> score with three
    pre-trained logistic-regression probes ('af', 'truth', 'joint') stored
    as .npy arrays next to this file.
    """

    def __init__(self, probe_dir=None, device='cuda', max_len=2048, layer=40):
        """Load probes, the Gemma 3 27B model, and the layer-matched SAE.

        Args:
            probe_dir: Directory containing the probe .npy files; defaults
                to the directory this file lives in.
            device: Device for tokenized inputs and the SAE tensors.
            max_len: Truncation length (tokens) for scored text.
            layer: Residual-stream layer index to hook; must match the SAE
                checkpoint path built below.
        """
        self.device = device
        self.max_len = max_len
        self.layer = layer
        # Load probes: each is a logistic-regression head plus the
        # StandardScaler statistics (mean/scale) used at training time.
        if probe_dir is None:
            probe_dir = Path(__file__).parent
        probe_dir = Path(probe_dir)
        self.probes = {}
        for name in ['af', 'truth', 'joint']:
            self.probes[name] = {
                'weights': np.load(probe_dir / f'{name}_probe_weights.npy'),
                'bias': np.load(probe_dir / f'{name}_probe_bias.npy'),
                'scaler_mean': np.load(probe_dir / f'{name}_scaler_mean.npy'),
                'scaler_scale': np.load(probe_dir / f'{name}_scaler_scale.npy'),
            }
        # Load model (device_map='auto' shards it across available GPUs).
        print('Loading Gemma 3 27B...')
        self.tokenizer = AutoTokenizer.from_pretrained('google/gemma-3-27b-it')
        self.model = AutoModelForCausalLM.from_pretrained(
            'google/gemma-3-27b-it',
            torch_dtype=torch.bfloat16,
            device_map='auto',
            low_cpu_mem_usage=True,
        )
        self.model.eval()
        # Load SAE parameters for the chosen layer from the Hub.
        print('Loading GemmaScope 2 SAE...')
        sae_path = f'resid_post/layer_{layer}_width_16k_l0_medium'
        params_path = hf_hub_download('google/gemma-scope-2-27b-it',
                                      f'{sae_path}/params.safetensors')
        params = load_file(params_path)
        self.sae = JumpReLUSAE(params, device)
        print('Ready.')

    def extract_features(self, text):
        """Return mean SAE feature activations over all tokens of ``text``.

        Returns:
            1-D numpy array of length n_features (mean-pooled over the
            token axis).
        """
        captured = {}

        def hook_fn(module, inp, out):
            # Decoder layers may return a tuple; hidden states come first.
            h = out[0] if isinstance(out, tuple) else out
            # Drop the batch dimension (batch size is always 1 here).
            captured['h'] = h[0].detach().float()
            # Abort the rest of the forward pass — layers past self.layer
            # are not needed for feature extraction.
            raise _EarlyExit()

        # NOTE(review): this module path assumes the multimodal
        # Gemma3 wrapper (model.model.language_model.layers) — confirm it
        # matches the installed transformers version's module layout.
        handle = self.model.model.language_model.layers[self.layer].register_forward_hook(hook_fn)
        try:
            inputs = self.tokenizer(text, return_tensors='pt', truncation=True,
                                    max_length=self.max_len, padding=False)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            captured.clear()
            try:
                with torch.no_grad():
                    self.model(**inputs)
            except _EarlyExit:
                # Expected: raised by the hook once activations are captured.
                pass
            h = captured['h'].to(self.device)
            f_all = self.sae.encode(h)  # (tokens, n_features)
            return f_all.mean(dim=0).cpu().numpy()
        finally:
            # Always detach the hook so repeated calls don't stack hooks.
            handle.remove()

    def _probe_score(self, features, probe_name):
        """Apply one named probe to a feature vector; return P(positive) in [0, 1]."""
        p = self.probes[probe_name]
        # Standardize with training-time scaler stats (epsilon avoids /0).
        x = (features - p['scaler_mean']) / (p['scaler_scale'] + 1e-10)
        logit = float(x @ p['weights'].T + p['bias'])
        prob = 1.0 / (1.0 + np.exp(-logit))
        return prob

    def score(self, text):
        """Score ``text`` with all three probes; returns a dict of probabilities."""
        features = self.extract_features(text)
        return {
            'af_score': self._probe_score(features, 'af'),
            'sandbagging_score': self._probe_score(features, 'truth'),
            'joint_score': self._probe_score(features, 'joint'),
        }

    def score_batch(self, texts):
        """Score an iterable of texts one at a time (no true batching)."""
        return [self.score(t) for t in texts]
if __name__ == '__main__':
    import sys

    detector = TruthSayer()
    # Prefer command-line arguments; otherwise prompt interactively.
    if len(sys.argv) > 1:
        text = ' '.join(sys.argv[1:])
    else:
        text = input('Enter scratchpad text: ')
    scores = detector.score(text)
    for label, key in (('AF score', 'af_score'),
                       ('Sandbagging score', 'sandbagging_score'),
                       ('Joint score', 'joint_score')):
        print(f'{label}: {scores[key]:.4f}')