import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time

# Load model once at import time so both decoding paths share the same weights.
print("Loading GPT-2...")
model = GPT2LMHeadModel.from_pretrained('gpt2').eval()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print(f"Running on {device}")


def ar_generate(prompt, n_tokens=50):
    """Standard AR generation - 1 token at a time (greedy argmax).

    Args:
        prompt: Text to condition on.
        n_tokens: Number of new tokens to generate.

    Returns:
        The decoded continuation (prompt text excluded).
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated = []
    for _ in range(n_tokens):
        with torch.no_grad():
            outputs = model(input_ids)
        # Greedy pick from the logits at the last position only.
        next_logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(next_logits, dim=-1)
        generated.append(next_token.item())
        # Re-feed the full sequence each step (no KV cache) — deliberately
        # simple so the AR and SAT paths are timed under identical conditions.
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
    return tokenizer.decode(generated)


def forced_sat_generate(prompt, n_tokens=50, block_size=2):
    """
    FORCED SAT: Predict 2 tokens at once from AR model
    Token 1: from position -1 (current)
    Token 2: from position -2 (stale context)

    NOTE: each iteration always emits exactly 2 tokens; `block_size` only
    controls the iteration count (n_tokens // block_size), so values other
    than 2 produce fewer/more tokens than requested — TODO confirm intent.

    Args:
        prompt: Text to condition on.
        n_tokens: Requested number of new tokens (rounded down to a
            multiple of `block_size`).
        block_size: Tokens per forward pass (effectively fixed at 2).

    Returns:
        The decoded continuation (prompt text excluded).
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated = []
    for _ in range(n_tokens // block_size):
        with torch.no_grad():
            outputs = model(input_ids)
        # Token 1: predicted from the current (last) position.
        logits1 = outputs.logits[:, -1, :]
        # Token 2: predicted from position -2 — stale context, since it has
        # not seen token 1. This is the deliberate quality-degrading shortcut
        # the demo illustrates. Falls back to logits1 for 1-token prompts.
        logits2 = outputs.logits[:, -2, :] if input_ids.shape[1] > 1 else logits1
        token1 = torch.argmax(logits1, dim=-1)
        token2 = torch.argmax(logits2, dim=-1)
        generated.extend([token1.item(), token2.item()])
        input_ids = torch.cat([
            input_ids,
            token1.unsqueeze(0),
            token2.unsqueeze(0)
        ], dim=1)
    return tokenizer.decode(generated)


def compare(prompt, n_tokens):
    """Run both decoders on `prompt`, timing each, and return UI updates.

    Returns:
        (ar_output, sat_output, ar_label, sat_label, speedup_text) — the two
        generated texts plus Markdown strings for the three label components.
    """
    n_tokens = int(n_tokens)

    # AR with timing. cuda.synchronize() before/after ensures the timer
    # measures completed GPU work, not just kernel-launch latency.
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    ar_output = ar_generate(prompt, n_tokens)
    if device == "cuda":
        torch.cuda.synchronize()
    ar_time = time.perf_counter() - start
    ar_tps = n_tokens / ar_time

    # SAT with timing
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    sat_output = forced_sat_generate(prompt, n_tokens)
    if device == "cuda":
        torch.cuda.synchronize()
    sat_time = time.perf_counter() - start
    sat_tps = n_tokens / sat_time

    speedup = ar_time / sat_time if sat_time > 0 else 0

    # Keep the "### " heading prefix so the updated labels render as the same
    # h3 headings as the initial placeholder Markdown components.
    ar_label = f"### AR Output - {ar_tps:.1f} tok/s"
    sat_label = f"### Forced SAT - {sat_tps:.1f} tok/s"
    speedup_text = f"## Speedup: {speedup:.2f}x"

    return ar_output, sat_output, ar_label, sat_label, speedup_text


# Gradio interface
with gr.Blocks(title="AR vs Forced SAT") as demo:
    gr.Markdown("""
    # AR vs Forced SAT Comparison

    **Can AR models be forced to output 2 tokens at once?**

    Model: GPT-2 (124M params)
    """)

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", value="The scientist discovered that", lines=2)
        n_tokens = gr.Slider(minimum=10, maximum=100, value=40, step=10, label="Tokens to generate")

    btn = gr.Button("Generate", variant="primary")

    speedup_display = gr.Markdown("## Speedup: ?x")

    with gr.Row():
        ar_label = gr.Markdown("### AR Output - ? tok/s")
        sat_label = gr.Markdown("### Forced SAT - ? tok/s")

    with gr.Row():
        ar_output = gr.Textbox(label="", lines=5, show_label=False)
        sat_output = gr.Textbox(label="", lines=5, show_label=False)

    btn.click(compare,
              inputs=[prompt, n_tokens],
              outputs=[ar_output, sat_output, ar_label, sat_label, speedup_display])

    gr.Examples(
        examples=[
            ["The quick brown fox", 40],
            ["In the beginning", 40],
            ["Once upon a time", 40],
            ["Machine learning is", 40],
        ],
        inputs=[prompt, n_tokens],
    )

    gr.Markdown("""
    ---
    **Why Forced SAT fails quality:** AR hidden states only encode "next token".
    Forcing 2-token output uses stale context.

    **Solution:** Joint AR+SAT training from scratch.
    See [AGILLM-3](https://huggingface.co/OpenTransformer/AGILLM-3-large)

    *OpenTransformers Ltd - Scott Bisset*
    """)

demo.launch()