OpenTransformer committed on
Commit
e310437
·
verified ·
1 Parent(s): 032c630

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
+
5
# Load model once at import time so every Gradio request reuses the same weights.
print("Loading GPT-2...")
model = GPT2LMHeadModel.from_pretrained('gpt2').eval()
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Move to GPU if available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print(f"Running on {device}")
14
+
15
def ar_generate(prompt, n_tokens=50):
    """Baseline autoregressive decoding: greedily emit one token per forward pass.

    Args:
        prompt: Text used to seed generation.
        n_tokens: Number of new tokens to produce.

    Returns:
        The decoded continuation only (the prompt text is not included).
    """
    context = tokenizer.encode(prompt, return_tensors='pt').to(device)

    new_tokens = []
    for _ in range(n_tokens):
        with torch.no_grad():
            logits = model(context).logits
        # Greedy choice: highest-probability token at the final position.
        chosen = torch.argmax(logits[:, -1, :], dim=-1)
        new_tokens.append(chosen.item())
        # Append the chosen token so the next pass sees the updated context.
        context = torch.cat([context, chosen.unsqueeze(0)], dim=1)

    return tokenizer.decode(new_tokens)
29
+
30
def forced_sat_generate(prompt, n_tokens=50, block_size=2):
    """FORCED SAT: greedily decode ``block_size`` tokens per forward pass.

    Token k of each block (k = 1..block_size) is taken from the logits at
    position -k of a single forward pass: token 1 uses the current position
    (-1), token 2 the stale position (-2), and so on. Only token 1 sees fully
    up-to-date context, which is why output quality degrades.

    Note: the original implementation accepted ``block_size`` but hardcoded
    two tokens per step, so any value other than 2 produced the wrong token
    count. This version honors ``block_size``; the default of 2 reproduces
    the original behavior exactly.

    Args:
        prompt: Text used to seed generation.
        n_tokens: Target number of new tokens (rounded down to a multiple
            of ``block_size``, as before).
        block_size: Tokens emitted per forward pass.

    Returns:
        The decoded continuation only (the prompt text is not included).
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    generated = []
    for _ in range(n_tokens // block_size):
        with torch.no_grad():
            outputs = model(input_ids)

        block = []
        for k in range(block_size):
            # Position -1 is current; -2, -3, ... are increasingly stale.
            # Fall back to the last position when the sequence is too short
            # to look back k+1 steps (matches the original's guard for -2).
            pos = -(k + 1) if input_ids.shape[1] > k else -1
            block.append(torch.argmax(outputs.logits[:, pos, :], dim=-1))

        generated.extend(t.item() for t in block)
        input_ids = torch.cat(
            [input_ids] + [t.unsqueeze(0) for t in block], dim=1
        )

    return tokenizer.decode(generated)
59
+
60
def compare(prompt, n_tokens):
    """Run both decoding strategies on one prompt for side-by-side display."""
    count = int(n_tokens)  # the slider delivers a float; token counts are integral
    baseline = ar_generate(prompt, count)
    forced = forced_sat_generate(prompt, count)
    return baseline, forced
65
+
66
# Gradio interface: side-by-side comparison of standard AR decoding and the
# forced two-token ("SAT") variant, sharing one prompt and token budget.
with gr.Blocks(title="AR vs Forced SAT") as demo:
    gr.Markdown("""
    # AR vs Forced SAT Comparison

    **Can AR models be forced to output 2 tokens at once?**

    - **AR (Autoregressive):** Standard 1-token-at-a-time generation
    - **Forced SAT:** Outputs 2 tokens per step using stale context for token 2

    Forced SAT runs ~2x faster but produces degraded output because AR hidden states
    don't encode multi-token futures. Joint AR+SAT training is required for quality SAT inference.

    Model: GPT-2 (124M params)
    """)

    # Input controls: free-text prompt plus a token-count slider.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", value="The scientist discovered that", lines=2)
        n_tokens = gr.Slider(minimum=10, maximum=100, value=40, step=10, label="Tokens to generate")

    btn = gr.Button("Generate", variant="primary")

    # Output panes: baseline AR on the left, forced SAT on the right.
    with gr.Row():
        ar_output = gr.Textbox(label="AR Output (baseline)", lines=5)
        sat_output = gr.Textbox(label="Forced SAT v1 (2x speed, degraded)", lines=5)

    # Wire the button to `compare`, which runs both generators on the prompt.
    btn.click(compare, inputs=[prompt, n_tokens], outputs=[ar_output, sat_output])

    # Canned prompts for quick experimentation.
    gr.Examples(
        examples=[
            ["The quick brown fox", 40],
            ["In the beginning", 40],
            ["Once upon a time", 40],
            ["Machine learning is", 40],
            ["The president announced that", 40],
        ],
        inputs=[prompt, n_tokens],
    )

    gr.Markdown("""
    ---
    **Why Forced SAT fails:** AR hidden states at position N only encode "next token N+1".
    There's no representation for token N+2. Forcing 2-token output uses stale context,
    creating alternating good/bad tokens.

    **Solution:** Train AR+SAT jointly from scratch so representations encode multiple future tokens.

    See: [AGILLM-3](https://huggingface.co/OpenTransformer/AGILLM-3-large) | [Experiment Code](https://huggingface.co/OpenTransformer/sat-retrofit-experiment)

    *OpenTransformers Ltd - Scott Bisset*
    """)

# Start the web app (blocking call).
demo.launch()