# --- Scraped HuggingFace Space page metadata (not code) ---
# Spaces: Sleeping
# File size: 4,567 Bytes
# Commit hashes: e310437 / 19e9060
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time
# One-time setup at import: load GPT-2 in eval mode onto the best available device.
print("Loading GPT-2...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT2LMHeadModel.from_pretrained('gpt2').eval().to(device)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
print(f"Running on {device}")
def ar_generate(prompt, n_tokens=50):
    """Greedy autoregressive decoding: one token per full forward pass.

    Args:
        prompt: Text to condition the model on.
        n_tokens: Number of new tokens to generate.

    Returns:
        Decoded text of only the generated tokens (prompt excluded).
    """
    ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    produced = []
    # NOTE: no KV cache — the whole sequence is re-encoded each step. That is
    # deliberate here so the timing comparison against forced SAT is apples-to-apples.
    with torch.no_grad():
        for _ in range(n_tokens):
            next_logits = model(ids).logits[:, -1, :]
            tok = torch.argmax(next_logits, dim=-1)
            produced.append(tok.item())
            ids = torch.cat([ids, tok.unsqueeze(0)], dim=1)
    return tokenizer.decode(produced)
def forced_sat_generate(prompt, n_tokens=50, block_size=2):
    """'Forced SAT': emit two tokens per forward pass from a vanilla AR model.

    Token 1 is the normal next-token prediction (logits at position -1).
    Token 2 is read from position -2 — a stale prediction that never saw the
    latest token, which is exactly the quality failure this demo illustrates.

    Args:
        prompt: Text to condition the model on.
        n_tokens: Target number of new tokens (rounded down to a multiple
            of block_size).
        block_size: Tokens emitted per forward pass.

    Returns:
        Decoded text of only the generated tokens (prompt excluded).
    """
    ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    produced = []
    with torch.no_grad():
        for _ in range(n_tokens // block_size):
            logits = model(ids).logits
            head = logits[:, -1, :]
            # Fall back to the head logits when the context is a single token.
            stale = logits[:, -2, :] if ids.shape[1] > 1 else head
            t1 = torch.argmax(head, dim=-1)
            t2 = torch.argmax(stale, dim=-1)
            produced += [t1.item(), t2.item()]
            ids = torch.cat([ids, t1.unsqueeze(0), t2.unsqueeze(0)], dim=1)
    return tokenizer.decode(produced)
def compare(prompt, n_tokens):
    """Run both decoders on the same prompt and report outputs plus timings.

    Returns a 5-tuple matching the Gradio outputs: (ar_text, sat_text,
    ar_label, sat_label, speedup_markdown).
    """
    n_tokens = int(n_tokens)

    def _timed(generate):
        # Bracket the call with CUDA syncs so perf_counter measures the
        # actual GPU work, not just kernel launches.
        if device == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        text = generate(prompt, n_tokens)
        if device == "cuda":
            torch.cuda.synchronize()
        return text, time.perf_counter() - t0

    ar_output, ar_time = _timed(ar_generate)
    sat_output, sat_time = _timed(forced_sat_generate)

    ar_tps = n_tokens / ar_time
    sat_tps = n_tokens / sat_time
    speedup = ar_time / sat_time if sat_time > 0 else 0

    return (
        ar_output,
        sat_output,
        f"AR Output - {ar_tps:.1f} tok/s",
        f"Forced SAT - {sat_tps:.1f} tok/s",
        f"## Speedup: {speedup:.2f}x",
    )
# Gradio interface: prompt/slider inputs, a generate button, a speedup banner,
# and side-by-side label + output panes for the two decoding strategies.
with gr.Blocks(title="AR vs Forced SAT") as demo:
    # Header / explanation.
    gr.Markdown("""
# AR vs Forced SAT Comparison
**Can AR models be forced to output 2 tokens at once?**
Model: GPT-2 (124M params)
""")
    # Input controls.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", value="The scientist discovered that", lines=2)
        n_tokens = gr.Slider(minimum=10, maximum=100, value=40, step=10, label="Tokens to generate")
    btn = gr.Button("Generate", variant="primary")
    # Placeholder banner; overwritten by compare() after each run.
    speedup_display = gr.Markdown("## Speedup: ?x")
    # Per-method throughput labels, also overwritten by compare().
    with gr.Row():
        ar_label = gr.Markdown("### AR Output - ? tok/s")
        sat_label = gr.Markdown("### Forced SAT - ? tok/s")
    # Generated-text panes for each method.
    with gr.Row():
        ar_output = gr.Textbox(label="", lines=5, show_label=False)
        sat_output = gr.Textbox(label="", lines=5, show_label=False)
    # One click runs both decoders and refreshes all five outputs.
    btn.click(compare, inputs=[prompt, n_tokens], outputs=[ar_output, sat_output, ar_label, sat_label, speedup_display])
    # Canned example prompts, each generating 40 tokens.
    gr.Examples(
        examples=[
            ["The quick brown fox", 40],
            ["In the beginning", 40],
            ["Once upon a time", 40],
            ["Machine learning is", 40],
        ],
        inputs=[prompt, n_tokens],
    )
    # Footer notes.
    gr.Markdown("""
---
**Why Forced SAT fails quality:** AR hidden states only encode "next token". Forcing 2-token output uses stale context.
**Solution:** Joint AR+SAT training from scratch. See [AGILLM-3](https://huggingface.co/OpenTransformer/AGILLM-3-large)
*OpenTransformers Ltd - Scott Bisset*
""")
demo.launch()
# (end of scraped page)