File size: 4,567 Bytes
e310437
 
 
19e9060
e310437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19e9060
 
 
 
 
e310437
19e9060
 
 
 
 
 
 
 
 
e310437
19e9060
 
 
 
 
 
 
 
 
 
 
 
e310437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19e9060
 
e310437
19e9060
 
e310437
19e9060
 
 
 
 
e310437
 
 
 
 
 
 
 
 
 
 
 
 
19e9060
e310437
19e9060
e310437
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time

# Load model once at import time so every Gradio request reuses the same
# weights; .eval() disables dropout for deterministic greedy decoding.
print("Loading GPT-2...")
model = GPT2LMHeadModel.from_pretrained('gpt2').eval()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print(f"Running on {device}")

def ar_generate(prompt, n_tokens=50):
    """Greedy autoregressive generation: one token per forward pass.

    Args:
        prompt: Text to condition on.
        n_tokens: Number of tokens to generate.

    Returns:
        The decoded continuation only (prompt excluded).
    """
    context = tokenizer.encode(prompt, return_tensors='pt').to(device)

    new_tokens = []
    with torch.no_grad():
        for _ in range(n_tokens):
            logits = model(context).logits
            # Greedy pick from the distribution at the final position.
            chosen = logits[:, -1, :].argmax(dim=-1)
            new_tokens.append(chosen.item())
            # Append the chosen token and re-run the full context next step
            # (no KV cache on purpose: this is the baseline being timed).
            context = torch.cat([context, chosen.unsqueeze(0)], dim=1)

    return tokenizer.decode(new_tokens)

def forced_sat_generate(prompt, n_tokens=50, block_size=2):
    """FORCED SAT: emit `block_size` tokens per forward pass from a plain AR model.

    Token 1 comes from the logits at the last position (normal AR prediction);
    token 2 comes from the logits at position -2, i.e. a context that is one
    token stale — that staleness is what degrades output quality.

    Args:
        prompt: Text to condition on.
        n_tokens: Total number of tokens to generate. A remainder that does
            not fill a whole block is generated one token at a time, so the
            output always contains exactly `n_tokens` tokens (previously a
            remainder was silently dropped, skewing the tok/s numbers).
        block_size: Tokens emitted per forward pass.

    Returns:
        The decoded continuation only (prompt excluded).
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    generated = []
    with torch.no_grad():
        for _ in range(n_tokens // block_size):
            outputs = model(input_ids)

            # Token 1: predicted from the full current context.
            logits1 = outputs.logits[:, -1, :]
            # Token 2: predicted from one-token-stale context (position -2).
            # Guard: a single-token sequence has no -2 position yet.
            logits2 = outputs.logits[:, -2, :] if input_ids.shape[1] > 1 else logits1

            token1 = torch.argmax(logits1, dim=-1)
            token2 = torch.argmax(logits2, dim=-1)

            generated.extend([token1.item(), token2.item()])
            input_ids = torch.cat([
                input_ids,
                token1.unsqueeze(0),
                token2.unsqueeze(0),
            ], dim=1)

        # Remainder: fall back to plain AR so the caller gets exactly
        # n_tokens tokens even when n_tokens % block_size != 0.
        for _ in range(n_tokens % block_size):
            outputs = model(input_ids)
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)
            generated.append(next_token.item())
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

    return tokenizer.decode(generated)

def compare(prompt, n_tokens):
    """Run both generators on the same prompt and report outputs, tok/s, and speedup.

    Returns a 5-tuple matching the Gradio outputs:
    (ar_text, sat_text, ar_label, sat_label, speedup_markdown).
    """
    n_tokens = int(n_tokens)

    def _timed(generate):
        # CUDA kernels launch asynchronously; synchronize on both sides so
        # perf_counter measures the actual GPU work, not just queueing.
        if device == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        text = generate(prompt, n_tokens)
        if device == "cuda":
            torch.cuda.synchronize()
        return text, time.perf_counter() - t0

    ar_output, ar_time = _timed(ar_generate)
    sat_output, sat_time = _timed(forced_sat_generate)

    ar_tps = n_tokens / ar_time
    sat_tps = n_tokens / sat_time
    speedup = ar_time / sat_time if sat_time > 0 else 0

    return (
        ar_output,
        sat_output,
        f"AR Output - {ar_tps:.1f} tok/s",
        f"Forced SAT - {sat_tps:.1f} tok/s",
        f"## Speedup: {speedup:.2f}x",
    )

# Gradio interface: two side-by-side outputs fed by a single compare() call.
with gr.Blocks(title="AR vs Forced SAT") as demo:
    gr.Markdown("""
    # AR vs Forced SAT Comparison
    
    **Can AR models be forced to output 2 tokens at once?**
    
    Model: GPT-2 (124M params)
    """)
    
    # Inputs: free-text prompt plus token-count slider (step 10 keeps the
    # count a multiple of the SAT block size of 2).
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", value="The scientist discovered that", lines=2)
        n_tokens = gr.Slider(minimum=10, maximum=100, value=40, step=10, label="Tokens to generate")
    
    btn = gr.Button("Generate", variant="primary")
    
    # Placeholders updated with real numbers after each run.
    speedup_display = gr.Markdown("## Speedup: ?x")
    
    with gr.Row():
        ar_label = gr.Markdown("### AR Output - ? tok/s")
        sat_label = gr.Markdown("### Forced SAT - ? tok/s")
    
    with gr.Row():
        ar_output = gr.Textbox(label="", lines=5, show_label=False)
        sat_output = gr.Textbox(label="", lines=5, show_label=False)
    
    # compare() returns exactly these five values, in this order.
    btn.click(compare, inputs=[prompt, n_tokens], outputs=[ar_output, sat_output, ar_label, sat_label, speedup_display])
    
    # Clickable example prompts that prefill the inputs.
    gr.Examples(
        examples=[
            ["The quick brown fox", 40],
            ["In the beginning", 40],
            ["Once upon a time", 40],
            ["Machine learning is", 40],
        ],
        inputs=[prompt, n_tokens],
    )
    
    gr.Markdown("""
    ---
    **Why Forced SAT fails quality:** AR hidden states only encode "next token". Forcing 2-token output uses stale context.
    
    **Solution:** Joint AR+SAT training from scratch. See [AGILLM-3](https://huggingface.co/OpenTransformer/AGILLM-3-large)
    
    *OpenTransformers Ltd - Scott Bisset*
    """)

demo.launch()