|
|
|
|
|
""" |
|
|
SAT Retrofit Test: Can we force an AR model to output 2 tokens at once? |
|
|
|
|
|
Hypothesis: AR models can't be "snapped" to SAT because their hidden states |
|
|
only encode next-token prediction, not multi-token prediction. |
|
|
|
|
|
Test: Take GPT-2, force 2-token prediction, measure degradation. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
|
import torch.nn.functional as F |
|
|
|
|
|
def load_model():
    """Load GPT-2 and its tokenizer for the retrofit experiment.

    Returns:
        (model, tokenizer): the GPT2LMHeadModel in eval mode on CUDA when
        available (CPU otherwise), and the matching GPT2Tokenizer.
    """
    print("Loading GPT-2...")
    # Fall back to CPU so the experiment still runs on machines without a GPU
    # (the original unconditional .cuda() crashed there).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = GPT2LMHeadModel.from_pretrained('gpt2').to(device).eval()
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    return model, tokenizer
|
|
|
|
|
def ar_generate(model, tokenizer, prompt, n_tokens=20):
    """Standard AR generation - 1 token at a time (greedy decoding).

    Args:
        model: causal LM whose forward pass returns an object with `.logits`
            of shape (batch, seq, vocab).
        tokenizer: tokenizer with `encode`/`decode` matching the model.
        prompt: text to condition on.
        n_tokens: number of new tokens to generate.

    Returns:
        The decoded generated continuation (prompt excluded).
    """
    # Derive the device from the model instead of hard-coding .cuda(),
    # so generation works on CPU-only machines too.
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    generated = []
    for _ in range(n_tokens):
        with torch.no_grad():
            outputs = model(input_ids)
        # Greedy pick from the distribution at the last position only.
        next_logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(next_logits, dim=-1)
        generated.append(next_token.item())
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

    return tokenizer.decode(generated)
|
|
|
|
|
def forced_sat_generate(model, tokenizer, prompt, n_tokens=20, block_size=2):
    """
    FORCED SAT: emit `block_size` tokens per forward pass from an AR model.

    Method: reuse the logits at the last `block_size` positions of a single
    forward pass as "predictions" for the next `block_size` tokens.  The model
    was only trained for next-token prediction, so quality should degrade.

    Note: the original implementation accepted `block_size` but always emitted
    exactly 2 tokens per pass (and fetched hidden states it never used); this
    version honors `block_size` while behaving identically for the default 2.

    Args:
        model: causal LM whose forward pass returns an object with `.logits`.
        tokenizer: tokenizer with `encode`/`decode` matching the model.
        prompt: conditioning text.
        n_tokens: target number of new tokens; `n_tokens // block_size` full
            blocks are generated, as in the original.
        block_size: tokens emitted per forward pass.

    Returns:
        Decoded generated continuation (prompt excluded).
    """
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    generated = []
    for _ in range(n_tokens // block_size):
        with torch.no_grad():
            outputs = model(input_ids)

        # Position -(k+1) stands in for the prediction of the k-th new token.
        # When the context is shorter than k+1 tokens, fall back to the last
        # position (matches the original `logits1` fallback).
        seq_len = input_ids.shape[1]
        new_tokens = []
        for k in range(block_size):
            pos = -(k + 1) if seq_len > k else -1
            logits = outputs.logits[:, pos, :]
            new_tokens.append(torch.argmax(logits, dim=-1))

        generated.extend(t.item() for t in new_tokens)
        input_ids = torch.cat(
            [input_ids] + [t.unsqueeze(0) for t in new_tokens], dim=1
        )

    return tokenizer.decode(generated)
|
|
|
|
|
def forced_sat_v2(model, tokenizer, prompt, n_tokens=20):
    """
    FORCED SAT v2: predict the 2nd token with an untrained linear head.

    Simulates bolting a SAT head onto a frozen AR model without training it:
    token 1 comes from the real LM head, token 2 from a random projection of
    the final hidden state at the same position.

    Args:
        model: causal LM; its forward must accept `output_hidden_states=True`
            and it must expose `.config.n_embd` / `.config.vocab_size`.
        tokenizer: matching tokenizer with `encode`/`decode`.
        prompt: conditioning text.
        n_tokens: target number of new tokens (2 per forward pass).

    Returns:
        Decoded generated continuation (prompt excluded).
    """
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # Untrained projection standing in for a SAT head.  Seeded so repeated
    # runs of the experiment are reproducible (original was unseeded).
    hidden_size = model.config.n_embd
    vocab_size = model.config.vocab_size
    gen = torch.Generator().manual_seed(0)
    random_head = (torch.randn(hidden_size, vocab_size, generator=gen) * 0.02).to(device)

    generated = []
    for _ in range(n_tokens // 2):
        with torch.no_grad():
            outputs = model(input_ids, output_hidden_states=True)
        hidden = outputs.hidden_states[-1][:, -1, :]

        # Token 1: the model's real next-token prediction.
        logits1 = outputs.logits[:, -1, :]
        token1 = torch.argmax(logits1, dim=-1)

        # Token 2: the same hidden state pushed through the untrained head.
        logits2 = hidden @ random_head
        token2 = torch.argmax(logits2, dim=-1)

        generated.extend([token1.item(), token2.item()])
        input_ids = torch.cat([
            input_ids,
            token1.unsqueeze(0),
            token2.unsqueeze(0)
        ], dim=1)

    return tokenizer.decode(generated)
|
|
|
|
|
def measure_perplexity(model, tokenizer, text):
    """Measure perplexity of `text` under the model.

    Uses the model's own LM loss (labels=input_ids), i.e. exp of the mean
    next-token negative log-likelihood.

    Returns:
        Perplexity as a Python float.
    """
    # Derive the device from the model instead of hard-coding .cuda().
    device = next(model.parameters()).device
    input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    return torch.exp(outputs.loss).item()
|
|
|
|
|
def benchmark_speed(model, tokenizer, prompt, n_tokens=100, n_runs=5):
    """Benchmark tokens per second for AR vs forced-SAT generation.

    Args:
        model, tokenizer: as in `ar_generate`.
        prompt: conditioning text used for every run.
        n_tokens: tokens generated per timed run.
        n_runs: timed repetitions per method (averaged).

    Returns:
        (ar_tps, sat_tps): average tokens/sec for each method.
    """
    import time

    def _sync():
        # synchronize() raises on CPU-only builds, so guard it; it is only
        # needed to make GPU timings honest anyway.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def _avg_seconds(gen_fn):
        # Average wall-clock seconds per run for one generation method.
        times = []
        for _ in range(n_runs):
            _sync()
            start = time.perf_counter()
            gen_fn(model, tokenizer, prompt, n_tokens=n_tokens)
            _sync()
            times.append(time.perf_counter() - start)
        return sum(times) / len(times)

    # Warm up both paths so first-call overheads don't skew the timings.
    ar_generate(model, tokenizer, prompt, n_tokens=10)
    forced_sat_generate(model, tokenizer, prompt, n_tokens=10)

    ar_tps = n_tokens / _avg_seconds(ar_generate)
    sat_tps = n_tokens / _avg_seconds(forced_sat_generate)
    return ar_tps, sat_tps
|
|
|
|
|
def main():
    """Run the SAT retrofit experiment end to end and print the results."""
    model, tokenizer = load_model()

    prompts = [
        "The quick brown fox",
        "In the beginning",
        "Once upon a time",
        "The scientist discovered that",
        "Machine learning is",
    ]

    print("\n" + "="*80)
    print("SAT RETROFIT TEST: Can AR models be forced to output 2 tokens?")
    print("="*80)

    # Throughput comparison: SAT halves the number of forward passes.
    print("\n\nSPEED BENCHMARK (100 tokens, 5 runs):")
    print("-"*60)
    ar_tps, sat_tps = benchmark_speed(model, tokenizer, "The quick brown fox", n_tokens=100, n_runs=5)
    print(f"AR: {ar_tps:.1f} tokens/sec")
    print(f"SAT: {sat_tps:.1f} tokens/sec")
    print(f"Speedup: {sat_tps/ar_tps:.2f}x")

    for prompt in prompts:
        print(f"\n\nPrompt: '{prompt}'")
        print("-"*60)

        ar_text = ar_generate(model, tokenizer, prompt, n_tokens=20)
        print(f"AR (baseline): {ar_text}")

        sat_text = forced_sat_generate(model, tokenizer, prompt, n_tokens=20)
        print(f"Forced SAT v1: {sat_text}")

        sat_v2_text = forced_sat_v2(model, tokenizer, prompt, n_tokens=20)
        print(f"Forced SAT v2: {sat_v2_text}")

        # Perplexity is best-effort: degenerate SAT output can trip the
        # tokenizer/model.  Was a bare `except: pass` — narrowed so that
        # KeyboardInterrupt/SystemExit propagate, and the failure is visible.
        try:
            ar_ppl = measure_perplexity(model, tokenizer, prompt + ar_text)
            sat_ppl = measure_perplexity(model, tokenizer, prompt + sat_text)
            print(f"\nPerplexity - AR: {ar_ppl:.2f}, SAT: {sat_ppl:.2f}, Ratio: {sat_ppl/ar_ppl:.2f}x worse")
        except Exception as e:
            print(f"(perplexity measurement failed: {e})")

    print("\n" + "="*80)
    print("CONCLUSION: AR hidden states don't encode multi-token future.")
    print("Joint AR+SAT training required to build compatible representations.")
    print("="*80)
|
|
|
|
|
# Script entry point: run the full retrofit experiment when executed directly.
if __name__ == "__main__":
    main()
|
|
|