#!/usr/bin/env python3
"""
SAT Retrofit Test: Can we force an AR model to output 2 tokens at once?

Hypothesis: AR models can't be "snapped" to SAT because their hidden states
only encode next-token prediction, not multi-token prediction.

Test: Take GPT-2, force 2-token prediction, measure degradation.
"""

import time

import torch
import torch.nn.functional as F
from transformers import GPT2LMHeadModel, GPT2Tokenizer


def load_model():
    print("Loading GPT-2...")
    model = GPT2LMHeadModel.from_pretrained('gpt2').cuda().eval()
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    return model, tokenizer


def ar_generate(model, tokenizer, prompt, n_tokens=20):
    """Standard AR generation - one token at a time."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
    generated = []
    for _ in range(n_tokens):
        with torch.no_grad():
            outputs = model(input_ids)
        next_logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(next_logits, dim=-1)
        generated.append(next_token.item())
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
    return tokenizer.decode(generated)


def forced_sat_generate(model, tokenizer, prompt, n_tokens=20, block_size=2):
    """
    FORCED SAT: Try to predict 2 tokens at once from an AR model.

    Method: emit two tokens per forward pass using the logits at the last
    two positions. This should fail because the model wasn't trained for it.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
    generated = []
    for _ in range(n_tokens // block_size):
        with torch.no_grad():
            outputs = model(input_ids)

        # Method 1: Just use the same logits twice (obviously wrong):
        # logits = outputs.logits[:, -1, :]
        # token1 = torch.argmax(logits, dim=-1)
        # token2 = torch.argmax(logits, dim=-1)  # Same token!

        # Method 2: Get logits, sample the first token, then... what?
        # The model has NO trained projection for the "2nd next token".

        # Method 3 (used here): take the 2nd-to-last position's logits for
        # token 2. This is degenerate: the logits at position -2 predict the
        # token ALREADY at position -1, so token 2 re-predicts known context.
        logits1 = outputs.logits[:, -1, :]
        logits2 = outputs.logits[:, -2, :] if input_ids.shape[1] > 1 else logits1
        token1 = torch.argmax(logits1, dim=-1)
        token2 = torch.argmax(logits2, dim=-1)

        generated.extend([token1.item(), token2.item()])
        input_ids = torch.cat([
            input_ids,
            token1.unsqueeze(0),
            token2.unsqueeze(0),
        ], dim=1)
    return tokenizer.decode(generated)
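
# The Method 3 caveat above is checkable directly: in a causal LM, the logits
# at position -2 were trained to predict the token that is ALREADY at
# position -1. The helper below is an added illustration, not part of the
# original experiment (the name is hypothetical). It estimates how often
# forced SAT's second token would merely re-emit a token the context already
# contains, by measuring greedy next-token accuracy over the prompt itself.
def check_offset_logits(model, tokenizer, prompt):
    """Fraction of positions where argmax(logits[:, t, :]) equals input_ids[:, t + 1]."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
    with torch.no_grad():
        outputs = model(input_ids)
    # The prediction at position t is for the token at t + 1, which is known here.
    preds = torch.argmax(outputs.logits[:, :-1, :], dim=-1)
    targets = input_ids[:, 1:]
    return (preds == targets).float().mean().item()
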
""" input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda() # Create random (untrained) projection for 2nd token hidden_size = model.config.n_embd vocab_size = model.config.vocab_size random_head = torch.randn(hidden_size, vocab_size).cuda() * 0.02 generated = [] for _ in range(n_tokens // 2): with torch.no_grad(): outputs = model(input_ids, output_hidden_states=True) hidden = outputs.hidden_states[-1][:, -1, :] # Token 1: Use trained head logits1 = outputs.logits[:, -1, :] token1 = torch.argmax(logits1, dim=-1) # Token 2: Use untrained random head logits2 = hidden @ random_head token2 = torch.argmax(logits2, dim=-1) generated.extend([token1.item(), token2.item()]) input_ids = torch.cat([ input_ids, token1.unsqueeze(0), token2.unsqueeze(0) ], dim=1) return tokenizer.decode(generated) def measure_perplexity(model, tokenizer, text): """Measure perplexity of generated text""" input_ids = tokenizer.encode(text, return_tensors='pt').cuda() with torch.no_grad(): outputs = model(input_ids, labels=input_ids) return torch.exp(outputs.loss).item() def benchmark_speed(model, tokenizer, prompt, n_tokens=100, n_runs=5): """Benchmark tokens per second for AR vs SAT""" import time # Warmup ar_generate(model, tokenizer, prompt, n_tokens=10) forced_sat_generate(model, tokenizer, prompt, n_tokens=10) # AR benchmark ar_times = [] for _ in range(n_runs): torch.cuda.synchronize() start = time.perf_counter() ar_generate(model, tokenizer, prompt, n_tokens=n_tokens) torch.cuda.synchronize() ar_times.append(time.perf_counter() - start) ar_avg = sum(ar_times) / len(ar_times) ar_tps = n_tokens / ar_avg # SAT benchmark sat_times = [] for _ in range(n_runs): torch.cuda.synchronize() start = time.perf_counter() forced_sat_generate(model, tokenizer, prompt, n_tokens=n_tokens) torch.cuda.synchronize() sat_times.append(time.perf_counter() - start) sat_avg = sum(sat_times) / len(sat_times) sat_tps = n_tokens / sat_avg return ar_tps, sat_tps def main(): model, tokenizer = load_model() prompts = [ "The quick brown fox", "In the beginning", "Once upon a time", "The scientist discovered that", "Machine learning is", ] print("\n" + "="*80) print("SAT RETROFIT TEST: Can AR models be forced to output 2 tokens?") print("="*80) # Speed benchmark first print("\n\nSPEED BENCHMARK (100 tokens, 5 runs):") print("-"*60) ar_tps, sat_tps = benchmark_speed(model, tokenizer, "The quick brown fox", n_tokens=100, n_runs=5) print(f"AR: {ar_tps:.1f} tokens/sec") print(f"SAT: {sat_tps:.1f} tokens/sec") print(f"Speedup: {sat_tps/ar_tps:.2f}x") for prompt in prompts: print(f"\n\nPrompt: '{prompt}'") print("-"*60) # Standard AR ar_text = ar_generate(model, tokenizer, prompt, n_tokens=20) print(f"AR (baseline): {ar_text}") # Forced SAT methods sat_text = forced_sat_generate(model, tokenizer, prompt, n_tokens=20) print(f"Forced SAT v1: {sat_text}") sat_v2_text = forced_sat_v2(model, tokenizer, prompt, n_tokens=20) print(f"Forced SAT v2: {sat_v2_text}") # Measure perplexity try: ar_ppl = measure_perplexity(model, tokenizer, prompt + ar_text) sat_ppl = measure_perplexity(model, tokenizer, prompt + sat_text) print(f"\nPerplexity - AR: {ar_ppl:.2f}, SAT: {sat_ppl:.2f}, Ratio: {sat_ppl/ar_ppl:.2f}x worse") except: pass print("\n" + "="*80) print("CONCLUSION: AR hidden states don't encode multi-token future.") print("Joint AR+SAT training required to build compatible representations.") print("="*80) if __name__ == "__main__": main()