Spaces:
Sleeping
Sleeping
import time

import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load model once at import time so every Gradio request reuses the same
# weights instead of reloading ~500 MB per click.
print("Loading GPT-2...")
model = GPT2LMHeadModel.from_pretrained('gpt2').eval()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print(f"Running on {device}")
def ar_generate(prompt, n_tokens=50):
    """Standard AR generation — greedy decoding, one token per forward pass.

    Args:
        prompt: Text to condition the model on.
        n_tokens: Number of new tokens to generate.

    Returns:
        The decoded text of ONLY the newly generated tokens (the prompt is
        not included in the return value).
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated = []
    # Inference only: enter no_grad once rather than once per step.
    with torch.no_grad():
        for _ in range(n_tokens):
            outputs = model(input_ids)
            # The next-token distribution lives at the last sequence position.
            next_logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(next_logits, dim=-1)
            generated.append(next_token.item())
            # Append the chosen token and re-run the full forward pass next
            # iteration (no KV cache — intentionally simple for the demo).
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
    return tokenizer.decode(generated)
def forced_sat_generate(prompt, n_tokens=50, block_size=2):
    """FORCED SAT: make an AR model emit `block_size` tokens per forward pass.

    Token k (k = 0 .. block_size-1) of each block is greedily decoded from
    the logits at sequence position -1-k, so every token after the first is
    predicted from progressively *staler* context. That staleness is the
    quality degradation this demo is built to show.

    Args:
        prompt: Text to condition the model on.
        n_tokens: Target number of new tokens; effectively rounded down to a
            multiple of block_size (n_tokens // block_size full blocks).
        block_size: Tokens emitted per forward pass. Default 2 reproduces the
            original demo exactly.

    Returns:
        The decoded text of ONLY the newly generated tokens.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated = []
    for _ in range(n_tokens // block_size):
        with torch.no_grad():
            outputs = model(input_ids)
        block = []
        for offset in range(block_size):
            # Fall back to the last position when the sequence is too short
            # to look `offset` steps back (matches the original 2-token case,
            # where logits2 defaulted to logits1 for a 1-token prompt).
            pos = -1 - offset if input_ids.shape[1] > offset else -1
            block.append(torch.argmax(outputs.logits[:, pos, :], dim=-1))
        generated.extend(tok.item() for tok in block)
        # Commit the whole block before the next forward pass.
        input_ids = torch.cat(
            [input_ids] + [tok.unsqueeze(0) for tok in block],
            dim=1,
        )
    return tokenizer.decode(generated)
def _timed_generate(generate_fn, prompt, n_tokens):
    """Run `generate_fn(prompt, n_tokens)` and return (output, elapsed_seconds).

    Synchronizes CUDA before starting and before stopping the clock so that
    asynchronously queued GPU kernels are fully counted in the measurement.
    """
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    output = generate_fn(prompt, n_tokens)
    if device == "cuda":
        torch.cuda.synchronize()
    return output, time.perf_counter() - start


def compare(prompt, n_tokens):
    """Gradio callback: run AR and Forced-SAT generation, report text + speed.

    Args:
        prompt: User prompt from the textbox.
        n_tokens: Slider value (arrives as float; coerced to int).

    Returns:
        Tuple of (ar_output, sat_output, ar_label, sat_label, speedup_text)
        matching the five output components wired up in the UI.
    """
    n_tokens = int(n_tokens)

    ar_output, ar_time = _timed_generate(ar_generate, prompt, n_tokens)
    ar_tps = n_tokens / ar_time

    sat_output, sat_time = _timed_generate(forced_sat_generate, prompt, n_tokens)
    sat_tps = n_tokens / sat_time

    # Guard against a degenerate zero elapsed time rather than raising.
    speedup = ar_time / sat_time if sat_time > 0 else 0

    ar_label = f"AR Output - {ar_tps:.1f} tok/s"
    sat_label = f"Forced SAT - {sat_tps:.1f} tok/s"
    speedup_text = f"## Speedup: {speedup:.2f}x"
    return ar_output, sat_output, ar_label, sat_label, speedup_text
# Gradio interface: two side-by-side outputs driven by one Generate button.
with gr.Blocks(title="AR vs Forced SAT") as demo:
    gr.Markdown("""
    # AR vs Forced SAT Comparison
    **Can AR models be forced to output 2 tokens at once?**
    Model: GPT-2 (124M params)
    """)

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", value="The scientist discovered that", lines=2)
        n_tokens = gr.Slider(minimum=10, maximum=100, value=40, step=10, label="Tokens to generate")

    btn = gr.Button("Generate", variant="primary")
    speedup_display = gr.Markdown("## Speedup: ?x")

    # Per-method labels carry the measured tokens/sec after each run.
    with gr.Row():
        ar_label = gr.Markdown("### AR Output - ? tok/s")
        sat_label = gr.Markdown("### Forced SAT - ? tok/s")

    with gr.Row():
        ar_output = gr.Textbox(label="", lines=5, show_label=False)
        sat_output = gr.Textbox(label="", lines=5, show_label=False)

    # Output order must match compare()'s return tuple.
    btn.click(
        compare,
        inputs=[prompt, n_tokens],
        outputs=[ar_output, sat_output, ar_label, sat_label, speedup_display],
    )

    gr.Examples(
        examples=[
            ["The quick brown fox", 40],
            ["In the beginning", 40],
            ["Once upon a time", 40],
            ["Machine learning is", 40],
        ],
        inputs=[prompt, n_tokens],
    )

    gr.Markdown("""
    ---
    **Why Forced SAT fails quality:** AR hidden states only encode "next token". Forcing 2-token output uses stale context.
    **Solution:** Joint AR+SAT training from scratch. See [AGILLM-3](https://huggingface.co/OpenTransformer/AGILLM-3-large)
    *OpenTransformers Ltd - Scott Bisset*
    """)

demo.launch()