# File size: 4,392 Bytes
# 3d804a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
TruthSayer: Orthogonal Deception Detection via SAE Probes

Usage:
    from inference import TruthSayer
    detector = TruthSayer()
    scores = detector.score("<scratchpad text>")
"""
import json
import numpy as np
import torch
from pathlib import Path
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM


class _EarlyExit(Exception):
    pass


class JumpReLUSAE(torch.nn.Module):
    def __init__(self, params, device):
        super().__init__()
        self.w_enc = params['w_enc'].float().to(device)
        self.b_enc = params['b_enc'].float().to(device)
        self.threshold = params['threshold'].float().to(device)

    def encode(self, x):
        pre = x @ self.w_enc + self.b_enc
        return torch.nn.functional.relu(pre) * (pre > self.threshold).float()


class TruthSayer:
    """Score scratchpad text for deception with SAE-feature logistic probes.

    Pipeline per input: tokenize -> run Gemma 3 27B forward up to the target
    layer (the forward pass is aborted early via _EarlyExit) -> encode the
    captured residual stream with a GemmaScope JumpReLU SAE -> mean-pool the
    feature activations over tokens -> score with three standardized logistic
    probes ('af', 'truth', 'joint').
    """

    def __init__(self, probe_dir=None, device='cuda', max_len=2048, layer=40):
        """Load probe arrays, the Gemma model, and the GemmaScope SAE.

        Args:
            probe_dir: directory containing ``{name}_probe_weights.npy``,
                ``{name}_probe_bias.npy``, ``{name}_scaler_mean.npy`` and
                ``{name}_scaler_scale.npy`` for each probe; defaults to this
                file's own directory.
            device: device used for tokenized inputs and the SAE. The LM
                itself is placed by ``device_map='auto'`` and may span
                several devices.
            max_len: tokenizer truncation length.
            layer: index of the decoder layer whose output is probed; must
                match the layer the SAE checkpoint was trained on.
        """
        self.device = device
        self.max_len = max_len
        self.layer = layer

        # Load probes
        if probe_dir is None:
            probe_dir = Path(__file__).parent
        probe_dir = Path(probe_dir)

        self.probes = {}
        for name in ['af', 'truth', 'joint']:
            self.probes[name] = {
                'weights': np.load(probe_dir / f'{name}_probe_weights.npy'),
                'bias': np.load(probe_dir / f'{name}_probe_bias.npy'),
                'scaler_mean': np.load(probe_dir / f'{name}_scaler_mean.npy'),
                'scaler_scale': np.load(probe_dir / f'{name}_scaler_scale.npy'),
            }

        # Load model
        print('Loading Gemma 3 27B...')
        self.tokenizer = AutoTokenizer.from_pretrained('google/gemma-3-27b-it')
        self.model = AutoModelForCausalLM.from_pretrained(
            'google/gemma-3-27b-it',
            torch_dtype=torch.bfloat16,
            device_map='auto',
            low_cpu_mem_usage=True,
        )
        self.model.eval()

        # Load SAE: downloads the params file from the Hub (cached locally).
        # Path encodes the SAE config — residual-stream hook point, 16k width.
        print('Loading GemmaScope 2 SAE...')
        sae_path = f'resid_post/layer_{layer}_width_16k_l0_medium'
        params_path = hf_hub_download('google/gemma-scope-2-27b-it',
                                       f'{sae_path}/params.safetensors')
        params = load_file(params_path)
        self.sae = JumpReLUSAE(params, device)
        print('Ready.')

    def extract_features(self, text):
        """Return the mean-pooled SAE feature vector for *text*.

        Runs a single forward pass, capturing the target layer's output via a
        forward hook that then raises _EarlyExit so the remaining layers are
        skipped. Returns a 1-D ``np.ndarray`` of per-feature means over the
        token axis.
        """
        captured = {}

        def hook_fn(module, inp, out):
            # Decoder layers may return a tuple; the hidden states come first.
            h = out[0] if isinstance(out, tuple) else out
            # [0] drops the batch dim (batch size is always 1 here).
            captured['h'] = h[0].detach().float()
            raise _EarlyExit()

        # NOTE(review): attribute path assumes the Gemma3 module nesting
        # (model.model.language_model.layers) — confirm against the installed
        # transformers version.
        handle = self.model.model.language_model.layers[self.layer].register_forward_hook(hook_fn)
        try:
            inputs = self.tokenizer(text, return_tensors='pt', truncation=True,
                                     max_length=self.max_len, padding=False)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            captured.clear()  # defensive: never reuse a stale capture
            try:
                with torch.no_grad():
                    self.model(**inputs)
            except _EarlyExit:
                pass  # expected: the hook aborts the pass after capturing
            # KeyError here means the hook never fired (e.g. bad layer index).
            h = captured['h'].to(self.device)
            f_all = self.sae.encode(h)
            return f_all.mean(dim=0).cpu().numpy()
        finally:
            handle.remove()

    def _probe_score(self, features, probe_name):
        """Return the named probe's probability for a feature vector.

        Standardizes features with the probe's saved scaler (1e-10 guards
        against zero-variance features), then applies a sigmoid to the
        linear logit.
        """
        p = self.probes[probe_name]
        x = (features - p['scaler_mean']) / (p['scaler_scale'] + 1e-10)
        logit = float(x @ p['weights'].T + p['bias'])
        prob = 1.0 / (1.0 + np.exp(-logit))
        return prob

    def score(self, text: str) -> dict:
        """Score *text* once and return all three probe probabilities.

        Returns a dict with keys ``af_score``, ``sandbagging_score`` (the
        'truth' probe) and ``joint_score``, each a float in (0, 1).
        """
        features = self.extract_features(text)
        return {
            'af_score': self._probe_score(features, 'af'),
            'sandbagging_score': self._probe_score(features, 'truth'),
            'joint_score': self._probe_score(features, 'joint'),
        }

    def score_batch(self, texts):
        """Score each text independently; returns a list of score dicts."""
        return [self.score(t) for t in texts]


if __name__ == '__main__':
    import sys

    detector = TruthSayer()
    # Take the scratchpad from argv when provided, otherwise prompt for it.
    text = ' '.join(sys.argv[1:]) if len(sys.argv) > 1 else input('Enter scratchpad text: ')
    scores = detector.score(text)
    print(f'AF score:          {scores["af_score"]:.4f}')
    print(f'Sandbagging score: {scores["sandbagging_score"]:.4f}')
    print(f'Joint score:       {scores["joint_score"]:.4f}')