OpenTransformer committed on
Commit
19e9060
·
verified ·
1 Parent(s): 08d6098

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +38 -17
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
4
 
5
  # Load model once
6
  print("Loading GPT-2...")
@@ -59,9 +60,34 @@ def forced_sat_generate(prompt, n_tokens=50, block_size=2):
59
 
60
  def compare(prompt, n_tokens):
61
  n_tokens = int(n_tokens)
 
 
 
 
 
62
  ar_output = ar_generate(prompt, n_tokens)
 
 
 
 
 
 
 
 
 
63
  sat_output = forced_sat_generate(prompt, n_tokens)
64
- return ar_output, sat_output
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # Gradio interface
67
  with gr.Blocks(title="AR vs Forced SAT") as demo:
@@ -70,12 +96,6 @@ with gr.Blocks(title="AR vs Forced SAT") as demo:
70
 
71
  **Can AR models be forced to output 2 tokens at once?**
72
 
73
- - **AR (Autoregressive):** Standard 1-token-at-a-time generation
74
- - **Forced SAT:** Outputs 2 tokens per step using stale context for token 2
75
-
76
- Forced SAT runs ~2x faster but produces degraded output because AR hidden states
77
- don't encode multi-token futures. Joint AR+SAT training is required for quality SAT inference.
78
-
79
  Model: GPT-2 (124M params)
80
  """)
81
 
@@ -85,11 +105,17 @@ with gr.Blocks(title="AR vs Forced SAT") as demo:
85
 
86
  btn = gr.Button("Generate", variant="primary")
87
 
 
 
88
  with gr.Row():
89
- ar_output = gr.Textbox(label="AR Output (baseline)", lines=5)
90
- sat_output = gr.Textbox(label="Forced SAT v1 (2x speed, degraded)", lines=5)
91
 
92
- btn.click(compare, inputs=[prompt, n_tokens], outputs=[ar_output, sat_output])
 
 
 
 
93
 
94
  gr.Examples(
95
  examples=[
@@ -97,20 +123,15 @@ with gr.Blocks(title="AR vs Forced SAT") as demo:
97
  ["In the beginning", 40],
98
  ["Once upon a time", 40],
99
  ["Machine learning is", 40],
100
- ["The president announced that", 40],
101
  ],
102
  inputs=[prompt, n_tokens],
103
  )
104
 
105
  gr.Markdown("""
106
  ---
107
- **Why Forced SAT fails:** AR hidden states at position N only encode "next token N+1".
108
- There's no representation for token N+2. Forcing 2-token output uses stale context,
109
- creating alternating good/bad tokens.
110
-
111
- **Solution:** Train AR+SAT jointly from scratch so representations encode multiple future tokens.
112
 
113
- See: [AGILLM-3](https://huggingface.co/OpenTransformer/AGILLM-3-large) | [Experiment Code](https://huggingface.co/OpenTransformer/sat-retrofit-experiment)
114
 
115
  *OpenTransformers Ltd - Scott Bisset*
116
  """)
 
1
  import gradio as gr
2
  import torch
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
+ import time
5
 
6
  # Load model once
7
  print("Loading GPT-2...")
 
60
 
61
def compare(prompt, n_tokens):
    """Run AR and forced-SAT generation on the same prompt and time both.

    Parameters
    ----------
    prompt : str
        Text prompt fed to both generators.
    n_tokens : int or str
        Number of tokens to generate; coerced to ``int`` (Gradio sliders /
        number inputs may deliver floats or strings).

    Returns
    -------
    tuple
        ``(ar_output, sat_output, ar_label, sat_label, speedup_text)`` —
        the two generated texts plus Markdown strings carrying tokens/sec
        for each method and the overall speedup.
    """
    n_tokens = int(n_tokens)

    def _timed(generate):
        # One place for the measurement protocol instead of two copy-pasted
        # stanzas. Synchronize around the call so queued GPU kernels are
        # fully counted in the wall-clock time (no-op on CPU).
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        output = generate(prompt, n_tokens)
        if device == "cuda":
            torch.cuda.synchronize()
        return output, time.perf_counter() - start

    ar_output, ar_time = _timed(ar_generate)
    sat_output, sat_time = _timed(forced_sat_generate)

    # Guard every division consistently: the original guarded only `speedup`,
    # so a zero elapsed time (coarse clock, trivial n_tokens) would raise
    # ZeroDivisionError on the tok/s computations. Report 0 instead.
    ar_tps = n_tokens / ar_time if ar_time > 0 else 0.0
    sat_tps = n_tokens / sat_time if sat_time > 0 else 0.0
    speedup = ar_time / sat_time if sat_time > 0 else 0

    ar_label = f"AR Output - {ar_tps:.1f} tok/s"
    sat_label = f"Forced SAT - {sat_tps:.1f} tok/s"
    speedup_text = f"## Speedup: {speedup:.2f}x"

    return ar_output, sat_output, ar_label, sat_label, speedup_text
91
 
92
  # Gradio interface
93
  with gr.Blocks(title="AR vs Forced SAT") as demo:
 
96
 
97
  **Can AR models be forced to output 2 tokens at once?**
98
 
 
 
 
 
 
 
99
  Model: GPT-2 (124M params)
100
  """)
101
 
 
105
 
106
  btn = gr.Button("Generate", variant="primary")
107
 
108
+ speedup_display = gr.Markdown("## Speedup: ?x")
109
+
110
  with gr.Row():
111
+ ar_label = gr.Markdown("### AR Output - ? tok/s")
112
+ sat_label = gr.Markdown("### Forced SAT - ? tok/s")
113
 
114
+ with gr.Row():
115
+ ar_output = gr.Textbox(label="", lines=5, show_label=False)
116
+ sat_output = gr.Textbox(label="", lines=5, show_label=False)
117
+
118
+ btn.click(compare, inputs=[prompt, n_tokens], outputs=[ar_output, sat_output, ar_label, sat_label, speedup_display])
119
 
120
  gr.Examples(
121
  examples=[
 
123
  ["In the beginning", 40],
124
  ["Once upon a time", 40],
125
  ["Machine learning is", 40],
 
126
  ],
127
  inputs=[prompt, n_tokens],
128
  )
129
 
130
  gr.Markdown("""
131
  ---
132
+ **Why Forced SAT fails quality:** AR hidden states only encode "next token". Forcing 2-token output uses stale context.
 
 
 
 
133
 
134
+ **Solution:** Joint AR+SAT training from scratch. See [AGILLM-3](https://huggingface.co/OpenTransformer/AGILLM-3-large)
135
 
136
  *OpenTransformers Ltd - Scott Bisset*
137
  """)