File size: 7,382 Bytes
b7588a3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 | #!/usr/bin/env python3
"""
SAT Retrofit Test: Can we force an AR model to output 2 tokens at once?
Hypothesis: AR models can't be "snapped" to SAT because their hidden states
only encode next-token prediction, not multi-token prediction.
Test: Take GPT-2, force 2-token prediction, measure degradation.
"""
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch.nn.functional as F
def load_model(device="cuda"):
    """Load GPT-2 and its tokenizer, ready for inference.

    Args:
        device: torch device string the model is moved to. Defaults to
            "cuda" to preserve the original behavior (the benchmark code
            in this script calls torch.cuda.synchronize(), so a GPU is
            assumed by the rest of the pipeline).

    Returns:
        (model, tokenizer): the model in eval mode on `device`, plus the
        matching GPT2Tokenizer.
    """
    print("Loading GPT-2...")
    # .to(device) instead of hard-coded .cuda() so callers can override.
    model = GPT2LMHeadModel.from_pretrained('gpt2').to(device).eval()
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    return model, tokenizer
def ar_generate(model, tokenizer, prompt, n_tokens=20):
    """Greedy autoregressive baseline: one token per full forward pass.

    NOTE(review): this deliberately re-runs the entire prefix every step
    (no KV cache) — presumably so its per-step cost is comparable to the
    forced-SAT variants below; confirm before "optimizing" it.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
    out_tokens = []
    for _ in range(n_tokens):
        with torch.no_grad():
            logits = model(input_ids).logits
        nxt = logits[:, -1, :].argmax(dim=-1)
        out_tokens.append(nxt.item())
        # Append the new token ([1] -> [1, 1]) to the running context.
        input_ids = torch.cat([input_ids, nxt.unsqueeze(0)], dim=1)
    return tokenizer.decode(out_tokens)
def forced_sat_generate(model, tokenizer, prompt, n_tokens=20, block_size=2):
    """
    FORCED SAT: emit `block_size` tokens per forward pass from a vanilla
    AR model.

    Token j of each block is decoded greedily from the logits at position
    -(j+1), i.e. progressively *staler* context (falling back to the last
    position when the prompt is too short). GPT-2 has no trained head for
    a "j-th next token", so output quality should degrade — that is the
    point of the experiment.

    Fixes vs. the previous revision: `block_size` was accepted but
    ignored (the loop always emitted exactly 2 tokens per step), and the
    forward pass requested `output_hidden_states=True` for a `hidden`
    local that was never used.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
    generated = []
    for _ in range(n_tokens // block_size):
        with torch.no_grad():
            logits = model(input_ids).logits
        new_tokens = []
        for j in range(block_size):
            # Position -(j+1) holds the logits for the j-th block token;
            # fall back to -1 when the context has fewer than j+1 positions.
            pos = -(j + 1) if input_ids.shape[1] > j else -1
            new_tokens.append(torch.argmax(logits[:, pos, :], dim=-1))
        generated.extend(t.item() for t in new_tokens)
        input_ids = torch.cat(
            [input_ids] + [t.unsqueeze(0) for t in new_tokens], dim=1
        )
    return tokenizer.decode(generated)
def forced_sat_v2(model, tokenizer, prompt, n_tokens=20):
    """
    FORCED SAT v2: pair the trained LM head with a *random, untrained*
    linear projection for the second token of each 2-token block.

    This simulates bolting a SAT head onto a frozen AR model with zero
    training: the second token is decoded from an uninitialized head and
    should be essentially noise.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
    # Random stand-in for an untrained second-token projection head.
    untrained_head = 0.02 * torch.randn(
        model.config.n_embd, model.config.vocab_size
    ).cuda()
    generated = []
    for _ in range(n_tokens // 2):
        with torch.no_grad():
            outputs = model(input_ids, output_hidden_states=True)
            final_hidden = outputs.hidden_states[-1][:, -1, :]
            # Token 1: trained LM head. Token 2: random projection.
            first = outputs.logits[:, -1, :].argmax(dim=-1)
            second = (final_hidden @ untrained_head).argmax(dim=-1)
        generated += [first.item(), second.item()]
        input_ids = torch.cat(
            [input_ids, first.unsqueeze(0), second.unsqueeze(0)], dim=1
        )
    return tokenizer.decode(generated)
def measure_perplexity(model, tokenizer, text):
    """Return the model's perplexity on `text` (exp of mean token NLL)."""
    token_ids = tokenizer.encode(text, return_tensors='pt').cuda()
    with torch.no_grad():
        # Passing labels=input makes the model compute the LM loss itself.
        loss = model(token_ids, labels=token_ids).loss
    return torch.exp(loss).item()
def benchmark_speed(model, tokenizer, prompt, n_tokens=100, n_runs=5):
    """Benchmark tokens/sec for baseline AR vs. forced-SAT generation.

    Runs one short warmup call of each method (CUDA kernel compilation,
    tokenizer caches), then times `n_runs` full generations apiece and
    averages the wall-clock time.

    Returns:
        (ar_tps, sat_tps): average tokens/sec for each method.
    """
    import time

    def _tokens_per_sec(generate_fn):
        # Shared timing loop (was duplicated once per method).
        # synchronize() before/after so queued GPU work is fully counted.
        elapsed = []
        for _ in range(n_runs):
            torch.cuda.synchronize()
            start = time.perf_counter()
            generate_fn(model, tokenizer, prompt, n_tokens=n_tokens)
            torch.cuda.synchronize()
            elapsed.append(time.perf_counter() - start)
        return n_tokens / (sum(elapsed) / len(elapsed))

    # Warmup
    ar_generate(model, tokenizer, prompt, n_tokens=10)
    forced_sat_generate(model, tokenizer, prompt, n_tokens=10)

    return _tokens_per_sec(ar_generate), _tokens_per_sec(forced_sat_generate)
def main():
    """Run the SAT-retrofit experiment end to end.

    Order: load GPT-2, run the speed benchmark, then for each prompt
    print AR vs. forced-SAT generations and their perplexity ratio.
    """
    model, tokenizer = load_model()
    prompts = [
        "The quick brown fox",
        "In the beginning",
        "Once upon a time",
        "The scientist discovered that",
        "Machine learning is",
    ]
    print("\n" + "=" * 80)
    print("SAT RETROFIT TEST: Can AR models be forced to output 2 tokens?")
    print("=" * 80)

    # Speed benchmark first
    print("\n\nSPEED BENCHMARK (100 tokens, 5 runs):")
    print("-" * 60)
    ar_tps, sat_tps = benchmark_speed(
        model, tokenizer, "The quick brown fox", n_tokens=100, n_runs=5
    )
    print(f"AR: {ar_tps:.1f} tokens/sec")
    print(f"SAT: {sat_tps:.1f} tokens/sec")
    print(f"Speedup: {sat_tps/ar_tps:.2f}x")

    for prompt in prompts:
        print(f"\n\nPrompt: '{prompt}'")
        print("-" * 60)
        # Standard AR baseline
        ar_text = ar_generate(model, tokenizer, prompt, n_tokens=20)
        print(f"AR (baseline): {ar_text}")
        # Forced SAT methods
        sat_text = forced_sat_generate(model, tokenizer, prompt, n_tokens=20)
        print(f"Forced SAT v1: {sat_text}")
        sat_v2_text = forced_sat_v2(model, tokenizer, prompt, n_tokens=20)
        print(f"Forced SAT v2: {sat_v2_text}")
        # Perplexity comparison. The previous bare `except: pass` swallowed
        # every error (even KeyboardInterrupt); report failures instead.
        try:
            ar_ppl = measure_perplexity(model, tokenizer, prompt + ar_text)
            sat_ppl = measure_perplexity(model, tokenizer, prompt + sat_text)
            print(f"\nPerplexity - AR: {ar_ppl:.2f}, SAT: {sat_ppl:.2f}, Ratio: {sat_ppl/ar_ppl:.2f}x worse")
        except Exception as exc:
            print(f"(perplexity measurement skipped: {exc})")

    print("\n" + "=" * 80)
    print("CONCLUSION: AR hidden states don't encode multi-token future.")
    print("Joint AR+SAT training required to build compatible representations.")
    print("=" * 80)
# Script entry point: run the full experiment when executed directly.
if __name__ == "__main__":
    main()
|