File size: 7,382 Bytes
b7588a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#!/usr/bin/env python3
"""
SAT Retrofit Test: Can we force an AR model to output 2 tokens at once?

Hypothesis: AR models can't be "snapped" to SAT because their hidden states
only encode next-token prediction, not multi-token prediction.

Test: Take GPT-2, force 2-token prediction, measure degradation.
"""

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch.nn.functional as F

def load_model(device=None):
    """Load pretrained GPT-2 and its tokenizer, ready for inference.

    Args:
        device: Device to place the model on. When ``None`` (the default)
            the model goes to ``'cuda'`` if a GPU is available, otherwise
            ``'cpu'`` -- previously ``.cuda()`` was hard-coded, which made the
            script unusable on CPU-only machines.

    Returns:
        ``(model, tokenizer)`` with the model switched to ``eval()`` mode.
    """
    print("Loading GPT-2...")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = GPT2LMHeadModel.from_pretrained('gpt2').to(device).eval()
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    return model, tokenizer

def ar_generate(model, tokenizer, prompt, n_tokens=20):
    """Standard AR generation - 1 token at a time (greedy argmax).

    The whole sequence is re-run through the model at every step; no KV
    cache (``past_key_values``) is used.

    Args:
        model: Causal LM whose forward returns an object with ``.logits``
            (indexed as ``[batch, position, vocab]``) and exposing
            ``.device``. Only batch size 1 is supported (``.item()`` below).
        tokenizer: Provides ``encode(text, return_tensors='pt')`` and
            ``decode(list_of_ids)``.
        prompt: Conditioning text; it is not included in the output.
        n_tokens: Number of tokens to generate.

    Returns:
        The decoded continuation (generated tokens only).
    """
    # Follow the model's device instead of hard-coding CUDA.
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)

    generated = []
    with torch.no_grad():
        for _ in range(n_tokens):
            next_logits = model(input_ids).logits[:, -1, :]
            next_token = torch.argmax(next_logits, dim=-1)  # shape [1]
            generated.append(next_token.item())
            # [1] -> [1, 1] so it can be appended along the sequence axis.
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

    return tokenizer.decode(generated)

def forced_sat_generate(model, tokenizer, prompt, n_tokens=20, block_size=2):
    """
    FORCED SAT: emit ``block_size`` tokens per forward pass from a plain AR model.

    Naive alternatives that do not work:
      * reuse the last-position logits for every token in the block
        (gives the same token repeated);
      * a dedicated "2nd next token" head -- the AR model has no trained
        projection for it (``forced_sat_v2`` tries an untrained one).

    Method used here: token ``j`` of a block (0-based) is the argmax of the
    logits at position ``-(j + 1)``. Token 0 is the ordinary next-token
    prediction. In a causal LM the logits at position ``i`` predict position
    ``i + 1``, so for ``j >= 1`` this re-predicts a token that is already in
    the context from OLDER context -- it is expected to degrade quality,
    which is the point of the experiment. If the context has fewer than
    ``j + 1`` tokens, the last position is used instead.

    Args:
        model: Causal LM returning ``.logits`` (``[batch, position, vocab]``)
            and exposing ``.device``; batch size 1 only.
        tokenizer: Provides ``encode(text, return_tensors='pt')`` and ``decode``.
        prompt: Conditioning text; not included in the output.
        n_tokens: Exact number of tokens to generate. The final block is
            shortened when ``n_tokens`` is not a multiple of ``block_size``.
        block_size: Tokens emitted per forward pass (previously ignored --
            two tokens were always emitted).

    Returns:
        The decoded continuation (generated tokens only).

    Raises:
        ValueError: If ``block_size < 1``.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)

    generated = []
    with torch.no_grad():
        while len(generated) < n_tokens:
            # Shorter final block so exactly n_tokens are produced; the speed
            # benchmark divides by n_tokens.
            k = min(block_size, n_tokens - len(generated))
            logits = model(input_ids).logits
            seq_len = input_ids.shape[1]

            block = []
            for j in range(k):
                pos = -(j + 1) if seq_len > j else -1
                block.append(torch.argmax(logits[:, pos, :], dim=-1))  # shape [1]

            generated.extend(tok.item() for tok in block)
            # Each [1] token becomes [1, 1] and is appended in block order.
            input_ids = torch.cat(
                [input_ids] + [tok.unsqueeze(0) for tok in block], dim=1
            )

    return tokenizer.decode(generated)

def forced_sat_v2(model, tokenizer, prompt, n_tokens=20, seed=None):
    """
    FORCED SAT v2: add an untrained linear projection for the 2nd token.

    Simulates bolting a second-token head onto an AR model without training
    it. Each forward pass emits:
      * token 1 -- argmax of the model's own (trained) last-position logits;
      * token 2 -- argmax of ``hidden @ random_head``, where ``hidden`` is
        the last-position entry of the final ``hidden_states`` layer and
        ``random_head`` is drawn once from N(0, 0.02**2).

    Args:
        model: Causal LM returning ``.logits`` and ``.hidden_states`` when
            called with ``output_hidden_states=True``; must expose
            ``.device`` and ``.config.n_embd`` / ``.config.vocab_size``.
            Batch size 1 only.
        tokenizer: Provides ``encode(text, return_tensors='pt')`` and ``decode``.
        prompt: Conditioning text; not included in the output.
        n_tokens: Exact number of tokens to generate (an odd count ends with
            a single trained-head token instead of silently dropping one).
        seed: Optional seed for the random head, making runs reproducible.
            ``None`` keeps the previous unseeded behaviour.

    Returns:
        The decoded continuation (generated tokens only).
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)

    # Create random (untrained) projection for 2nd token, on the model's device.
    hidden_size = model.config.n_embd
    vocab_size = model.config.vocab_size
    generator = None
    if seed is not None:
        generator = torch.Generator(device=model.device).manual_seed(seed)
    random_head = torch.randn(
        hidden_size, vocab_size, generator=generator, device=model.device
    ) * 0.02

    generated = []
    with torch.no_grad():
        while len(generated) < n_tokens:
            outputs = model(input_ids, output_hidden_states=True)

            # Token 1: trained head.
            token1 = torch.argmax(outputs.logits[:, -1, :], dim=-1)
            block = [token1]

            # Token 2: untrained random head, only if a second token is still owed.
            if n_tokens - len(generated) >= 2:
                hidden = outputs.hidden_states[-1][:, -1, :]
                # Match the hidden dtype (e.g. a half-precision model).
                logits2 = hidden @ random_head.to(hidden.dtype)
                block.append(torch.argmax(logits2, dim=-1))

            generated.extend(tok.item() for tok in block)
            input_ids = torch.cat(
                [input_ids] + [tok.unsqueeze(0) for tok in block], dim=1
            )

    return tokenizer.decode(generated)

def measure_perplexity(model, tokenizer, text):
    """Measure perplexity of *text* under *model*.

    Perplexity is ``exp(loss)``, where ``loss`` is what the model returns
    when called with ``labels=input_ids``. For Hugging Face causal LMs the
    labels are shifted inside the model, so this is the mean next-token
    cross-entropy over the text.

    Args:
        model: Causal LM accepting ``labels=`` and returning ``.loss``;
            must expose ``.device``.
        tokenizer: Provides ``encode(text, return_tensors='pt')``.
        text: Text to score (callers pass prompt + continuation).

    Returns:
        The perplexity as a float, or ``nan`` if *text* encodes to fewer than
        two tokens (there is no next-token target to score).
    """
    input_ids = tokenizer.encode(text, return_tensors='pt').to(model.device)
    if input_ids.shape[1] < 2:
        return float('nan')
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    return torch.exp(outputs.loss).item()

def benchmark_speed(model, tokenizer, prompt, n_tokens=100, n_runs=5):
    """Benchmark tokens per second for AR vs SAT"""
    import time
    
    # Warmup
    ar_generate(model, tokenizer, prompt, n_tokens=10)
    forced_sat_generate(model, tokenizer, prompt, n_tokens=10)
    
    # AR benchmark
    ar_times = []
    for _ in range(n_runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        ar_generate(model, tokenizer, prompt, n_tokens=n_tokens)
        torch.cuda.synchronize()
        ar_times.append(time.perf_counter() - start)
    
    ar_avg = sum(ar_times) / len(ar_times)
    ar_tps = n_tokens / ar_avg
    
    # SAT benchmark
    sat_times = []
    for _ in range(n_runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        forced_sat_generate(model, tokenizer, prompt, n_tokens=n_tokens)
        torch.cuda.synchronize()
        sat_times.append(time.perf_counter() - start)
    
    sat_avg = sum(sat_times) / len(sat_times)
    sat_tps = n_tokens / sat_avg
    
    return ar_tps, sat_tps

def main():
    """Run the SAT retrofit experiment and print results to stdout.

    1. Speed benchmark: greedy AR decoding vs. ``forced_sat_generate``.
    2. For each prompt: the AR, forced-SAT v1 and forced-SAT v2
       continuations, plus their perplexities under the same model.
    """
    model, tokenizer = load_model()

    prompts = [
        "The quick brown fox",
        "In the beginning",
        "Once upon a time",
        "The scientist discovered that",
        "Machine learning is",
    ]
    # Keep the printed header in sync with the actual benchmark settings.
    bench_tokens, bench_runs = 100, 5

    print("\n" + "="*80)
    print("SAT RETROFIT TEST: Can AR models be forced to output 2 tokens?")
    print("="*80)

    # Speed benchmark first
    print(f"\n\nSPEED BENCHMARK ({bench_tokens} tokens, {bench_runs} runs):")
    print("-"*60)
    ar_tps, sat_tps = benchmark_speed(
        model, tokenizer, "The quick brown fox",
        n_tokens=bench_tokens, n_runs=bench_runs,
    )
    print(f"AR:  {ar_tps:.1f} tokens/sec")
    print(f"SAT: {sat_tps:.1f} tokens/sec")
    print(f"Speedup: {sat_tps/ar_tps:.2f}x")

    for prompt in prompts:
        print(f"\n\nPrompt: '{prompt}'")
        print("-"*60)

        # Standard AR
        ar_text = ar_generate(model, tokenizer, prompt, n_tokens=20)
        print(f"AR (baseline):     {ar_text}")

        # Forced SAT methods
        sat_text = forced_sat_generate(model, tokenizer, prompt, n_tokens=20)
        print(f"Forced SAT v1:     {sat_text}")

        sat_v2_text = forced_sat_v2(model, tokenizer, prompt, n_tokens=20)
        print(f"Forced SAT v2:     {sat_v2_text}")

        # Measure perplexity. The scorer is the same model that produced the
        # AR text by greedy argmax, so the AR baseline is expected to score
        # well; read the ratios as relative, not absolute, quality.
        try:
            ar_ppl = measure_perplexity(model, tokenizer, prompt + ar_text)
            sat_ppl = measure_perplexity(model, tokenizer, prompt + sat_text)
            sat_v2_ppl = measure_perplexity(model, tokenizer, prompt + sat_v2_text)
        except Exception as err:
            # Best-effort metric: report the failure (the old bare `except:`
            # hid it and also swallowed Ctrl-C) and continue with the next prompt.
            print(f"\nPerplexity unavailable: {type(err).__name__}: {err}")
        else:
            print(f"\nPerplexity - AR: {ar_ppl:.2f}, SAT: {sat_ppl:.2f}, Ratio: {sat_ppl/ar_ppl:.2f}x worse")
            print(f"Perplexity - SAT v2: {sat_v2_ppl:.2f}, Ratio: {sat_v2_ppl/ar_ppl:.2f}x worse")

    # NOTE: the conclusion below is a fixed message, printed regardless of
    # the measurements above.
    print("\n" + "="*80)
    print("CONCLUSION: AR hidden states don't encode multi-token future.")
    print("Joint AR+SAT training required to build compatible representations.")
    print("="*80)

if __name__ == "__main__":
    main()