# ar-vs-sat / app.py
# Uploaded by OpenTransformer via huggingface_hub (commit 19e9060, verified)
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time
# Load model once at module import so every Gradio request reuses the same weights.
print("Loading GPT-2...")
# eval() disables dropout so greedy decoding is deterministic across calls.
model = GPT2LMHeadModel.from_pretrained('gpt2').eval()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print(f"Running on {device}")
def ar_generate(prompt, n_tokens=50):
    """Greedy autoregressive decoding: one token per forward pass.

    Re-runs the model on the full (growing) context each step — no KV
    cache — so its cost is directly comparable to forced SAT below.

    Args:
        prompt: Text to condition generation on.
        n_tokens: How many tokens to generate greedily.

    Returns:
        The decoded string of only the newly generated tokens.
    """
    context = tokenizer.encode(prompt, return_tensors='pt').to(device)
    new_ids = []
    with torch.no_grad():
        for _ in range(n_tokens):
            logits = model(context).logits
            # Greedy pick from the final position's distribution.
            chosen = logits[:, -1, :].argmax(dim=-1)
            new_ids.append(chosen.item())
            context = torch.cat([context, chosen.unsqueeze(0)], dim=1)
    return tokenizer.decode(new_ids)
def forced_sat_generate(prompt, n_tokens=50, block_size=2):
    """'Forced SAT' decoding: emit two tokens per forward pass.

    The first token comes from the final position's logits (normal AR);
    the second is read from the second-to-last position's logits, i.e. a
    prediction made from stale context. Halves the number of forward
    passes at the cost of quality — which is the point of the demo.

    Args:
        prompt: Text to condition generation on.
        n_tokens: Token budget; n_tokens // block_size passes are run.
        block_size: Tokens emitted per pass (this scheme produces 2).

    Returns:
        The decoded string of only the newly generated tokens.
    """
    context = tokenizer.encode(prompt, return_tensors='pt').to(device)
    new_ids = []
    with torch.no_grad():
        for _ in range(n_tokens // block_size):
            logits = model(context).logits
            head = logits[:, -1, :]
            # Single-token context has no position -2; reuse the head logits.
            stale = logits[:, -2, :] if context.shape[1] > 1 else head
            first = head.argmax(dim=-1)
            second = stale.argmax(dim=-1)
            new_ids.extend([first.item(), second.item()])
            context = torch.cat(
                [context, first.unsqueeze(0), second.unsqueeze(0)], dim=1
            )
    return tokenizer.decode(new_ids)
def _timed(fn, *args):
    """Run fn(*args) and return (result, elapsed_seconds).

    Synchronizes CUDA before starting and before stopping the clock so
    GPU work is actually included in the measurement.
    """
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args)
    if device == "cuda":
        torch.cuda.synchronize()
    return result, time.perf_counter() - start


def compare(prompt, n_tokens):
    """Generate with both decoders, time them, and format the results.

    Args:
        prompt: Conditioning text from the UI textbox.
        n_tokens: Token budget from the UI slider (may arrive as float).

    Returns:
        Tuple of (ar_output, sat_output, ar_label, sat_label, speedup_text)
        matching the Gradio output components.
    """
    n_tokens = int(n_tokens)  # Gradio sliders can deliver floats.
    ar_output, ar_time = _timed(ar_generate, prompt, n_tokens)
    sat_output, sat_time = _timed(forced_sat_generate, prompt, n_tokens)
    # Guard every division: perf_counter deltas can be ~0 on trivial runs.
    ar_tps = n_tokens / ar_time if ar_time > 0 else 0
    sat_tps = n_tokens / sat_time if sat_time > 0 else 0
    speedup = ar_time / sat_time if sat_time > 0 else 0
    ar_label = f"AR Output - {ar_tps:.1f} tok/s"
    sat_label = f"Forced SAT - {sat_tps:.1f} tok/s"
    speedup_text = f"## Speedup: {speedup:.2f}x"
    return ar_output, sat_output, ar_label, sat_label, speedup_text
# Gradio interface: side-by-side comparison of the two decoding schemes.
with gr.Blocks(title="AR vs Forced SAT") as demo:
    gr.Markdown("""
# AR vs Forced SAT Comparison
**Can AR models be forced to output 2 tokens at once?**
Model: GPT-2 (124M params)
""")
    with gr.Row():
        # User inputs: conditioning prompt and token budget.
        prompt = gr.Textbox(label="Prompt", value="The scientist discovered that", lines=2)
        n_tokens = gr.Slider(minimum=10, maximum=100, value=40, step=10, label="Tokens to generate")
    btn = gr.Button("Generate", variant="primary")
    # Placeholder labels; replaced with measured tok/s after each run.
    speedup_display = gr.Markdown("## Speedup: ?x")
    with gr.Row():
        ar_label = gr.Markdown("### AR Output - ? tok/s")
        sat_label = gr.Markdown("### Forced SAT - ? tok/s")
    with gr.Row():
        ar_output = gr.Textbox(label="", lines=5, show_label=False)
        sat_output = gr.Textbox(label="", lines=5, show_label=False)
    # compare() returns values in exactly this outputs order.
    btn.click(compare, inputs=[prompt, n_tokens], outputs=[ar_output, sat_output, ar_label, sat_label, speedup_display])
    # One-click example prompts.
    gr.Examples(
        examples=[
            ["The quick brown fox", 40],
            ["In the beginning", 40],
            ["Once upon a time", 40],
            ["Machine learning is", 40],
        ],
        inputs=[prompt, n_tokens],
    )
    gr.Markdown("""
---
**Why Forced SAT fails quality:** AR hidden states only encode "next token". Forcing 2-token output uses stale context.
**Solution:** Joint AR+SAT training from scratch. See [AGILLM-3](https://huggingface.co/OpenTransformer/AGILLM-3-large)
*OpenTransformers Ltd - Scott Bisset*
""")
demo.launch()