OpenTransformer committed on
Commit
b7588a3
·
verified ·
1 Parent(s): f53ead3

Upload sat_test.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sat_test.py +208 -0
sat_test.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SAT Retrofit Test: Can we force an AR model to output 2 tokens at once?
4
+
5
+ Hypothesis: AR models can't be "snapped" to SAT because their hidden states
6
+ only encode next-token prediction, not multi-token prediction.
7
+
8
+ Test: Take GPT-2, force 2-token prediction, measure degradation.
9
+ """
10
+
11
+ import torch
12
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
13
+ import torch.nn.functional as F
14
+
15
def load_model():
    """Download GPT-2 and its tokenizer; return (model, tokenizer) with the model on CUDA in eval mode."""
    print("Loading GPT-2...")
    tok = GPT2Tokenizer.from_pretrained('gpt2')
    lm = GPT2LMHeadModel.from_pretrained('gpt2').cuda().eval()
    return lm, tok
20
+
21
def ar_generate(model, tokenizer, prompt, n_tokens=20):
    """Baseline autoregressive decoding: greedily emit one token per forward pass.

    Returns the decoded continuation (prompt excluded).
    """
    ids = tokenizer.encode(prompt, return_tensors='pt').cuda()

    emitted = []
    with torch.no_grad():
        for _ in range(n_tokens):
            # Full re-forward each step (no KV cache) — fine for this small experiment.
            last_logits = model(ids).logits[:, -1, :]
            tok = torch.argmax(last_logits, dim=-1)
            emitted.append(tok.item())
            ids = torch.cat([ids, tok.unsqueeze(0)], dim=1)

    return tokenizer.decode(emitted)
35
+
36
+ def forced_sat_generate(model, tokenizer, prompt, n_tokens=20, block_size=2):
37
+ """
38
+ FORCED SAT: Try to predict 2 tokens at once from AR model
39
+
40
+ Method: Use hidden state at position N to predict BOTH N+1 and N+2
41
+ This should fail because the model wasn't trained for this.
42
+ """
43
+ input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
44
+
45
+ generated = []
46
+ for _ in range(n_tokens // block_size):
47
+ with torch.no_grad():
48
+ outputs = model(input_ids, output_hidden_states=True)
49
+
50
+ # Get final hidden state
51
+ hidden = outputs.hidden_states[-1][:, -1, :] # [1, 768]
52
+
53
+ # Method 1: Just use same logits twice (obviously wrong)
54
+ # logits = outputs.logits[:, -1, :]
55
+ # token1 = torch.argmax(logits, dim=-1)
56
+ # token2 = torch.argmax(logits, dim=-1) # Same!
57
+
58
+ # Method 2: Get logits, sample first, then... what?
59
+ # The model has NO trained projection for "2nd next token"
60
+
61
+ # Method 3: Use 2nd-to-last position for token 2?
62
+ # This is using OLDER context which is worse
63
+ logits1 = outputs.logits[:, -1, :]
64
+ logits2 = outputs.logits[:, -2, :] if input_ids.shape[1] > 1 else logits1
65
+
66
+ token1 = torch.argmax(logits1, dim=-1)
67
+ token2 = torch.argmax(logits2, dim=-1)
68
+
69
+ generated.extend([token1.item(), token2.item()])
70
+ input_ids = torch.cat([
71
+ input_ids,
72
+ token1.unsqueeze(0),
73
+ token2.unsqueeze(0)
74
+ ], dim=1)
75
+
76
+ return tokenizer.decode(generated)
77
+
78
def forced_sat_v2(model, tokenizer, prompt, n_tokens=20):
    """
    FORCED SAT v2: add an untrained linear projection for the 2nd token.

    Simulates retrofitting SAT onto an AR model WITHOUT training the new head:
    token 1 comes from the trained LM head, token 2 from a random projection of
    the final hidden state — expected to be near-garbage.

    Fix vs the original draft: the random head is drawn from a fixed-seed local
    torch.Generator, so the "experiment" is reproducible across runs without
    touching the global RNG state.

    Returns the decoded continuation (prompt excluded).
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()

    hidden_size = model.config.n_embd
    vocab_size = model.config.vocab_size
    # Seeded, local generator: the head stays arbitrary (untrained) but identical
    # between runs; 0.02 matches GPT-2's init scale.
    gen = torch.Generator().manual_seed(0)
    random_head = (torch.randn(hidden_size, vocab_size, generator=gen) * 0.02).cuda()

    generated = []
    for _ in range(n_tokens // 2):
        with torch.no_grad():
            outputs = model(input_ids, output_hidden_states=True)
        hidden = outputs.hidden_states[-1][:, -1, :]

        # Token 1: trained LM head.
        token1 = torch.argmax(outputs.logits[:, -1, :], dim=-1)
        # Token 2: untrained random head applied to the same hidden state.
        token2 = torch.argmax(hidden @ random_head, dim=-1)

        generated.extend([token1.item(), token2.item()])
        input_ids = torch.cat(
            [input_ids, token1.unsqueeze(0), token2.unsqueeze(0)], dim=1
        )

    return tokenizer.decode(generated)
114
+
115
def measure_perplexity(model, tokenizer, text):
    """Perplexity of `text` under the model: exp of the mean LM cross-entropy loss."""
    token_ids = tokenizer.encode(text, return_tensors='pt').cuda()
    with torch.no_grad():
        loss = model(token_ids, labels=token_ids).loss
    return torch.exp(loss).item()
121
+
122
def benchmark_speed(model, tokenizer, prompt, n_tokens=100, n_runs=5):
    """Compare decode throughput of AR vs forced-SAT generation.

    Runs each generator `n_runs` times over `n_tokens` tokens and returns
    (ar_tokens_per_sec, sat_tokens_per_sec). CUDA is synchronized around each
    timed region so queued GPU work is actually counted.

    Fix vs the original draft: the two copy-pasted timing loops are factored
    into one helper, so AR and SAT are guaranteed to be measured identically.
    """
    import time

    def timed_avg(fn):
        # Mean wall-clock seconds for fn() over n_runs runs.
        elapsed = []
        for _ in range(n_runs):
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            fn()
            torch.cuda.synchronize()
            elapsed.append(time.perf_counter() - t0)
        return sum(elapsed) / len(elapsed)

    # Warmup: first calls pay one-time kernel/cache setup costs.
    ar_generate(model, tokenizer, prompt, n_tokens=10)
    forced_sat_generate(model, tokenizer, prompt, n_tokens=10)

    ar_avg = timed_avg(lambda: ar_generate(model, tokenizer, prompt, n_tokens=n_tokens))
    sat_avg = timed_avg(lambda: forced_sat_generate(model, tokenizer, prompt, n_tokens=n_tokens))

    return n_tokens / ar_avg, n_tokens / sat_avg
155
+
156
def main():
    """Run the SAT retrofit experiment: speed benchmark, sample generations, perplexity.

    Fix vs the original draft: the perplexity step used a bare `except: pass`,
    which silently hid ALL errors (including KeyboardInterrupt/SystemExit);
    it now catches Exception and reports the failure.
    """
    model, tokenizer = load_model()

    prompts = [
        "The quick brown fox",
        "In the beginning",
        "Once upon a time",
        "The scientist discovered that",
        "Machine learning is",
    ]

    print("\n" + "="*80)
    print("SAT RETROFIT TEST: Can AR models be forced to output 2 tokens?")
    print("="*80)

    # Speed benchmark first
    print("\n\nSPEED BENCHMARK (100 tokens, 5 runs):")
    print("-"*60)
    ar_tps, sat_tps = benchmark_speed(model, tokenizer, "The quick brown fox", n_tokens=100, n_runs=5)
    print(f"AR: {ar_tps:.1f} tokens/sec")
    print(f"SAT: {sat_tps:.1f} tokens/sec")
    print(f"Speedup: {sat_tps/ar_tps:.2f}x")

    for prompt in prompts:
        print(f"\n\nPrompt: '{prompt}'")
        print("-"*60)

        # Standard AR baseline
        ar_text = ar_generate(model, tokenizer, prompt, n_tokens=20)
        print(f"AR (baseline): {ar_text}")

        # Forced SAT methods
        sat_text = forced_sat_generate(model, tokenizer, prompt, n_tokens=20)
        print(f"Forced SAT v1: {sat_text}")

        sat_v2_text = forced_sat_v2(model, tokenizer, prompt, n_tokens=20)
        print(f"Forced SAT v2: {sat_v2_text}")

        # Perplexity comparison; report failures instead of swallowing them.
        try:
            ar_ppl = measure_perplexity(model, tokenizer, prompt + ar_text)
            sat_ppl = measure_perplexity(model, tokenizer, prompt + sat_text)
            print(f"\nPerplexity - AR: {ar_ppl:.2f}, SAT: {sat_ppl:.2f}, Ratio: {sat_ppl/ar_ppl:.2f}x worse")
        except Exception as e:
            print(f"\nPerplexity measurement failed: {e}")

    print("\n" + "="*80)
    print("CONCLUSION: AR hidden states don't encode multi-token future.")
    print("Joint AR+SAT training required to build compatible representations.")
    print("="*80)

if __name__ == "__main__":
    main()