grapheneaffiliates committed · verified
Commit 849acfb · 1 parent: 06e4588

Upload python/compare_baselines.py with huggingface_hub

Files changed (1):
  1. python/compare_baselines.py +330 -0
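
For reference, uploads like this are typically made with the huggingface_hub client. A minimal sketch of such an upload; the repo_id below is a placeholder, not taken from this commit:

    from huggingface_hub import HfApi

    api = HfApi()  # uses the token stored by `huggingface-cli login` by default
    api.upload_file(
        path_or_fileobj="python/compare_baselines.py",
        path_in_repo="python/compare_baselines.py",
        repo_id="<user>/<repo>",
    )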
python/compare_baselines.py ADDED
@@ -0,0 +1,330 @@
+"""
+Head-to-head comparison: H4 attention vs softmax vs linear attention.
+Same model size, same data, same training budget.
+
+Usage:
+    python compare_baselines.py                        # Shakespeare (default)
+    python compare_baselines.py --dataset tinystories
+    python compare_baselines.py --time-budget 60       # Faster runs
+"""
+
+import argparse
+import math
+import os
+import sys
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
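+# Ensure sibling modules (prepare_data, baselines, h4_language_model) are
+# importable even when the script is launched from another directory.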
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from prepare_data import load_and_prepare
+from baselines import BaselineLanguageModel
+from h4_language_model import H4LanguageModel
+
+# ---------------------------------------------------------------------------
+# Configuration
+# ---------------------------------------------------------------------------
+
+# Model architecture (same for all models)
+D_MODEL = 128
+N_HEADS = 8
+N_LAYERS = 4
+D_VALUE = 16
+D_FFN = 512
+MAX_SEQ_LEN = 128
+DROPOUT = 0.0
+
+# Training
+BATCH_SIZE = 8
+LR = 5e-3
+WEIGHT_DECAY = 0.01
+WARMUP_STEPS = 50
+GRAD_CLIP = 1.0
+TIME_BUDGET = 120  # seconds per model
+
+# Eval
+EVAL_INTERVAL = 25
+EVAL_BATCHES = 5
+
+# Models to compare
+CONFIGS = [
+    {'name': 'H4 Float',   'attention': 'h4',      'bitlinear': False},
+    {'name': 'H4 Ternary', 'attention': 'h4',      'bitlinear': True},
+    {'name': 'Softmax',    'attention': 'softmax', 'bitlinear': False},
+    {'name': 'Linear',     'attention': 'linear',  'bitlinear': False},
+]
+
+
+def get_batch(data, batch_size, seq_len):
+    """Sample a random batch of sequences."""
+    max_start = len(data) - seq_len - 1
+    if max_start <= 0:
+        max_start = 1
+    ix = torch.randint(0, max_start, (batch_size,))
+    x = torch.stack([data[i:i + seq_len] for i in ix])
+    y = torch.stack([data[i + 1:i + seq_len + 1] for i in ix])
+    return x, y
+
+
+def create_model(config, vocab_size):
+    """Create a model based on config."""
+    attn_type = config['attention']
+    use_bitlinear = config['bitlinear']
+
+    if attn_type == 'h4':
+        model = H4LanguageModel(
+            vocab_size=vocab_size,
+            d_model=D_MODEL,
+            n_heads=N_HEADS,
+            n_layers=N_LAYERS,
+            d_value=D_VALUE,
+            d_ffn=D_FFN,
+            top_k=16,
+            max_seq_len=MAX_SEQ_LEN * 2,
+            dropout=DROPOUT,
+            use_bitlinear=use_bitlinear,
+        )
+    else:
+        model = BaselineLanguageModel(
+            vocab_size=vocab_size,
+            d_model=D_MODEL,
+            n_heads=N_HEADS,
+            n_layers=N_LAYERS,
+            d_value=D_VALUE,
+            d_ffn=D_FFN,
+            max_seq_len=MAX_SEQ_LEN * 2,
+            dropout=DROPOUT,
+            attention_type=attn_type,
+            use_bitlinear=use_bitlinear,
+        )
+    return model
+
+
+def train_and_evaluate(config, train_data, val_data, vocab_size, itos, time_budget):
+    """Train a model and return evaluation metrics."""
+    name = config['name']
+    print(f"\n{'='*60}")
+    print(f"Training: {name}")
+    print(f"{'='*60}")
+
+    torch.manual_seed(42)
+    np.random.seed(42)
+
+    model = create_model(config, vocab_size)
+    param_info = model.count_params()
+    print(f"  Parameters: {param_info['trainable']:,} trainable")
+
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=LR,
+        weight_decay=WEIGHT_DECAY,
+        betas=(0.9, 0.95),
+    )
+
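+    # LR schedule: linear warmup for WARMUP_STEPS, then cosine decay toward a
+    # 10% floor over a nominal 500-step horizon (progress is clamped at 1.0).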
+    def lr_schedule(step):
+        if step < WARMUP_STEPS:
+            return step / max(WARMUP_STEPS, 1)
+        progress = (step - WARMUP_STEPS) / max(1, 500 - WARMUP_STEPS)
+        return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))
+
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
+
+    # H4 models use full attention (no tree) for short sequences
+    is_h4 = config['attention'] == 'h4'
+
+    step = 0
+    total_training_time = 0.0
+    best_val_loss = float('inf')
+    model.train()
+
+    t_start = time.time()
+
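+    # Time-budgeted loop: each model trains until it has consumed the same
+    # wall-clock budget, rather than for a fixed number of steps.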
+    while True:
+        t0 = time.time()
+
+        x, y = get_batch(train_data, BATCH_SIZE, MAX_SEQ_LEN)
+
+        if is_h4:
+            logits = model(x, use_tree=False)
+        else:
+            logits = model(x)
+        loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
+
+        optimizer.zero_grad()
+        loss.backward()
+        if GRAD_CLIP > 0:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
+        optimizer.step()
+        scheduler.step()
+
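+        # Exclude the first few steps from the timing; early steps likely pay
+        # one-off startup costs that shouldn't count against the budget.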
+        dt = time.time() - t0
+        if step > 2:
+            total_training_time += dt
+
+        # Periodic eval
+        if step % EVAL_INTERVAL == 0:
+            model.eval()
+            with torch.no_grad():
+                vl = []
+                for _ in range(EVAL_BATCHES):
+                    xv, yv = get_batch(val_data, BATCH_SIZE, MAX_SEQ_LEN)
+                    if is_h4:
+                        vlogits = model(xv, use_tree=False)
+                    else:
+                        vlogits = model(xv)
+                    vl.append(F.cross_entropy(vlogits.view(-1, vocab_size), yv.view(-1)).item())
+            val_loss = sum(vl) / len(vl)
+            if val_loss < best_val_loss:
+                best_val_loss = val_loss
+
+            progress = min(total_training_time / time_budget, 1.0)
+            print(f"  step {step:5d} | loss {loss.item():.4f} | val_loss {val_loss:.4f} | {progress:.0%}")
+            model.train()
+
+        step += 1
+        if step > 2 and total_training_time >= time_budget:
+            break
+
+    # Final evaluation (more batches for stable estimate)
+    model.eval()
+    with torch.no_grad():
+        vl = []
+        for _ in range(EVAL_BATCHES * 4):
+            xv, yv = get_batch(val_data, BATCH_SIZE, MAX_SEQ_LEN)
+            if is_h4:
+                vlogits = model(xv, use_tree=False)
+            else:
+                vlogits = model(xv)
+            vl.append(F.cross_entropy(vlogits.view(-1, vocab_size), yv.view(-1)).item())
+    final_val_loss = sum(vl) / len(vl)
+
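+    # Cross-entropy is in nats: dividing by ln 2 gives bits per token
+    # (characters here), and exponentiating gives perplexity.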
+    val_bpb = final_val_loss / math.log(2)
+    perplexity = math.exp(final_val_loss)
+
+    # Generate a short sample; generate() takes the same arguments for all
+    # model types, so no branching is needed here.
+    seed_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long)
+    gen = model.generate(seed_ids, max_new_tokens=60, temperature=0.8, top_k_sample=10)
+    gen_text = ''.join([itos.get(i.item(), '?') for i in gen[0]])
+
+    wall_time = time.time() - t_start
+
+    results = {
+        'name': name,
+        'attention': config['attention'],
+        'bitlinear': config['bitlinear'],
+        'params': param_info['trainable'],
+        'steps': step,
+        'val_loss': final_val_loss,
+        'best_val_loss': best_val_loss,
+        'val_bpb': val_bpb,
+        'perplexity': perplexity,
+        'wall_time': wall_time,
+        'train_time': total_training_time,
+        'sample': gen_text[:100],
+    }
+
+    print(f"  Final: val_loss={final_val_loss:.4f}, bpb={val_bpb:.4f}, "
+          f"ppl={perplexity:.1f}, steps={step}, time={wall_time:.0f}s")
+
+    return results
+
+
+def print_comparison_table(all_results, dataset_name, time_budget=TIME_BUDGET):
+    """Print a formatted comparison table."""
+    print(f"\n{'='*80}")
+    print(f"COMPARISON RESULTS — Dataset: {dataset_name}")
+    print(f"Config: d_model={D_MODEL}, n_layers={N_LAYERS}, n_heads={N_HEADS}, "
+          f"seq_len={MAX_SEQ_LEN}, budget={time_budget}s")
+    print(f"{'='*80}")
+
+    # Header
+    print(f"{'Model':<16} {'Params':>8} {'Steps':>6} {'Val Loss':>9} "
+          f"{'BPB':>7} {'PPL':>8} {'Time':>6}")
+    print(f"{'-'*16} {'-'*8} {'-'*6} {'-'*9} {'-'*7} {'-'*8} {'-'*6}")
+
+    # Sort by val_loss (ascending, so best model comes first)
+    sorted_results = sorted(all_results, key=lambda r: r['val_loss'])
+
+    for r in sorted_results:
+        params_str = f"{r['params'] // 1000}K" if r['params'] >= 1000 else str(r['params'])
+        print(f"{r['name']:<16} {params_str:>8} {r['steps']:>6} {r['val_loss']:>9.4f} "
+              f"{r['val_bpb']:>7.4f} {r['perplexity']:>8.1f} {r['wall_time']:>5.0f}s")
+
+    # Best model
+    best = sorted_results[0]
+    print(f"\nBest: {best['name']} (val_loss={best['val_loss']:.4f}, ppl={best['perplexity']:.1f})")
+
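+    # Positive delta means the H4 float model reached a lower (better)
+    # validation loss than the softmax baseline.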
+    # H4 vs Softmax comparison
+    h4_float = next((r for r in all_results if r['attention'] == 'h4' and not r['bitlinear']), None)
+    softmax = next((r for r in all_results if r['attention'] == 'softmax'), None)
+    if h4_float and softmax:
+        delta = softmax['val_loss'] - h4_float['val_loss']
+        pct = (delta / softmax['val_loss']) * 100
+        if delta > 0:
+            print(f"H4 Float vs Softmax: H4 wins by {delta:.4f} nats ({pct:.1f}% better)")
+        else:
+            print(f"H4 Float vs Softmax: Softmax wins by {-delta:.4f} nats ({-pct:.1f}% better)")
+
+    # Sample text from each model
+    print(f"\n{'='*80}")
+    print("GENERATED SAMPLES:")
+    print(f"{'='*80}")
+    for r in sorted_results:
+        print(f"\n[{r['name']}]")
+        print(f"  {r['sample']}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Compare H4 vs baseline attention mechanisms')
+    parser.add_argument('--dataset', default='shakespeare',
+                        choices=['synthetic', 'shakespeare', 'tinystories'],
+                        help='Dataset to use (default: shakespeare)')
+    parser.add_argument('--time-budget', type=int, default=TIME_BUDGET,
+                        help=f'Training time per model in seconds (default: {TIME_BUDGET})')
+    parser.add_argument('--models', nargs='+', default=None,
+                        help='Subset of models to run (e.g., "h4 softmax")')
+    args = parser.parse_args()
+
+    time_budget = args.time_budget
+
+    # Filter configs if requested (case-insensitive substring match on names),
+    # before printing the time estimate so the estimate reflects what will run
+    configs = CONFIGS
+    if args.models:
+        configs = [c for c in CONFIGS if any(m.lower() in c['name'].lower() for m in args.models)]
+        if not configs:
+            print(f"No matching models for {args.models}. Available: {[c['name'] for c in CONFIGS]}")
+            return
+
+    print("H4 Polytopic Attention — Baseline Comparison")
+    print(f"Dataset: {args.dataset}, Time budget: {time_budget}s per model")
+    print(f"Expected total time: ~{len(configs) * time_budget // 60} minutes")
+
+    # Load data
+    train_data, val_data, vocab_size, stoi, itos = load_and_prepare(args.dataset)
+    print(f"Vocab: {vocab_size}, Train: {len(train_data):,}, Val: {len(val_data):,}")
+
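+    # Failures are caught per model so one crashing configuration doesn't
+    # abort the remaining runs.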
+    # Run comparisons
+    all_results = []
+    for config in configs:
+        try:
+            results = train_and_evaluate(
+                config, train_data, val_data, vocab_size, itos, time_budget
+            )
+            all_results.append(results)
+        except Exception as e:
+            print(f"\n  ERROR training {config['name']}: {e}")
+            import traceback
+            traceback.print_exc()
+
+    if all_results:
+        print_comparison_table(all_results, args.dataset, time_budget)
+
+
+if __name__ == '__main__':
+    main()