Spaces:
Sleeping
Sleeping
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
|
|
|
| 4 |
|
| 5 |
# Load model once
|
| 6 |
print("Loading GPT-2...")
|
|
@@ -59,9 +60,34 @@ def forced_sat_generate(prompt, n_tokens=50, block_size=2):
|
|
| 59 |
|
| 60 |
def compare(prompt, n_tokens):
|
| 61 |
n_tokens = int(n_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
ar_output = ar_generate(prompt, n_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
sat_output = forced_sat_generate(prompt, n_tokens)
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
# Gradio interface
|
| 67 |
with gr.Blocks(title="AR vs Forced SAT") as demo:
|
|
@@ -70,12 +96,6 @@ with gr.Blocks(title="AR vs Forced SAT") as demo:
|
|
| 70 |
|
| 71 |
**Can AR models be forced to output 2 tokens at once?**
|
| 72 |
|
| 73 |
-
- **AR (Autoregressive):** Standard 1-token-at-a-time generation
|
| 74 |
-
- **Forced SAT:** Outputs 2 tokens per step using stale context for token 2
|
| 75 |
-
|
| 76 |
-
Forced SAT runs ~2x faster but produces degraded output because AR hidden states
|
| 77 |
-
don't encode multi-token futures. Joint AR+SAT training is required for quality SAT inference.
|
| 78 |
-
|
| 79 |
Model: GPT-2 (124M params)
|
| 80 |
""")
|
| 81 |
|
|
@@ -85,11 +105,17 @@ with gr.Blocks(title="AR vs Forced SAT") as demo:
|
|
| 85 |
|
| 86 |
btn = gr.Button("Generate", variant="primary")
|
| 87 |
|
|
|
|
|
|
|
| 88 |
with gr.Row():
|
| 89 |
-
|
| 90 |
-
|
| 91 |
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
gr.Examples(
|
| 95 |
examples=[
|
|
@@ -97,20 +123,15 @@ with gr.Blocks(title="AR vs Forced SAT") as demo:
|
|
| 97 |
["In the beginning", 40],
|
| 98 |
["Once upon a time", 40],
|
| 99 |
["Machine learning is", 40],
|
| 100 |
-
["The president announced that", 40],
|
| 101 |
],
|
| 102 |
inputs=[prompt, n_tokens],
|
| 103 |
)
|
| 104 |
|
| 105 |
gr.Markdown("""
|
| 106 |
---
|
| 107 |
-
**Why Forced SAT fails:** AR hidden states
|
| 108 |
-
There's no representation for token N+2. Forcing 2-token output uses stale context,
|
| 109 |
-
creating alternating good/bad tokens.
|
| 110 |
-
|
| 111 |
-
**Solution:** Train AR+SAT jointly from scratch so representations encode multiple future tokens.
|
| 112 |
|
| 113 |
-
See
|
| 114 |
|
| 115 |
*OpenTransformers Ltd - Scott Bisset*
|
| 116 |
""")
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
| 4 |
+
import time
|
| 5 |
|
| 6 |
# Load model once
|
| 7 |
print("Loading GPT-2...")
|
|
|
|
| 60 |
|
| 61 |
def compare(prompt, n_tokens):
|
| 62 |
n_tokens = int(n_tokens)
|
| 63 |
+
|
| 64 |
+
# AR with timing
|
| 65 |
+
if device == "cuda":
|
| 66 |
+
torch.cuda.synchronize()
|
| 67 |
+
start = time.perf_counter()
|
| 68 |
ar_output = ar_generate(prompt, n_tokens)
|
| 69 |
+
if device == "cuda":
|
| 70 |
+
torch.cuda.synchronize()
|
| 71 |
+
ar_time = time.perf_counter() - start
|
| 72 |
+
ar_tps = n_tokens / ar_time
|
| 73 |
+
|
| 74 |
+
# SAT with timing
|
| 75 |
+
if device == "cuda":
|
| 76 |
+
torch.cuda.synchronize()
|
| 77 |
+
start = time.perf_counter()
|
| 78 |
sat_output = forced_sat_generate(prompt, n_tokens)
|
| 79 |
+
if device == "cuda":
|
| 80 |
+
torch.cuda.synchronize()
|
| 81 |
+
sat_time = time.perf_counter() - start
|
| 82 |
+
sat_tps = n_tokens / sat_time
|
| 83 |
+
|
| 84 |
+
speedup = ar_time / sat_time if sat_time > 0 else 0
|
| 85 |
+
|
| 86 |
+
ar_label = f"AR Output - {ar_tps:.1f} tok/s"
|
| 87 |
+
sat_label = f"Forced SAT - {sat_tps:.1f} tok/s"
|
| 88 |
+
speedup_text = f"## Speedup: {speedup:.2f}x"
|
| 89 |
+
|
| 90 |
+
return ar_output, sat_output, ar_label, sat_label, speedup_text
|
| 91 |
|
| 92 |
# Gradio interface
|
| 93 |
with gr.Blocks(title="AR vs Forced SAT") as demo:
|
|
|
|
| 96 |
|
| 97 |
**Can AR models be forced to output 2 tokens at once?**
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
Model: GPT-2 (124M params)
|
| 100 |
""")
|
| 101 |
|
|
|
|
| 105 |
|
| 106 |
btn = gr.Button("Generate", variant="primary")
|
| 107 |
|
| 108 |
+
speedup_display = gr.Markdown("## Speedup: ?x")
|
| 109 |
+
|
| 110 |
with gr.Row():
|
| 111 |
+
ar_label = gr.Markdown("### AR Output - ? tok/s")
|
| 112 |
+
sat_label = gr.Markdown("### Forced SAT - ? tok/s")
|
| 113 |
|
| 114 |
+
with gr.Row():
|
| 115 |
+
ar_output = gr.Textbox(label="", lines=5, show_label=False)
|
| 116 |
+
sat_output = gr.Textbox(label="", lines=5, show_label=False)
|
| 117 |
+
|
| 118 |
+
btn.click(compare, inputs=[prompt, n_tokens], outputs=[ar_output, sat_output, ar_label, sat_label, speedup_display])
|
| 119 |
|
| 120 |
gr.Examples(
|
| 121 |
examples=[
|
|
|
|
| 123 |
["In the beginning", 40],
|
| 124 |
["Once upon a time", 40],
|
| 125 |
["Machine learning is", 40],
|
|
|
|
| 126 |
],
|
| 127 |
inputs=[prompt, n_tokens],
|
| 128 |
)
|
| 129 |
|
| 130 |
gr.Markdown("""
|
| 131 |
---
|
| 132 |
+
**Why Forced SAT fails quality:** AR hidden states only encode "next token". Forcing 2-token output uses stale context.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
+
**Solution:** Joint AR+SAT training from scratch. See [AGILLM-3](https://huggingface.co/OpenTransformer/AGILLM-3-large)
|
| 135 |
|
| 136 |
*OpenTransformers Ltd - Scott Bisset*
|
| 137 |
""")
|