grapheneaffiliates committed on
Commit
4166b93
·
verified ·
1 Parent(s): d035351

Upload python/train_cpu.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. python/train_cpu.py +354 -0
python/train_cpu.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ H4 Polytopic Attention — CPU autoresearch training script.
3
+ This is the ONLY file the agent modifies during autonomous research.
4
+
5
+ Follows the autoresearch pattern: modify → run (2 min budget) → measure → keep/discard.
6
+
7
+ The frozen H4 geometry is off-limits. Only the trainable adapters, hyperparameters,
8
+ training loop details, and architecture of trainable layers may be changed.
9
+ """
10
+
11
+ import os
12
+ import math
13
+ import time
14
+ import json
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import numpy as np
19
+
20
+ import sys
21
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
22
+
23
+ from h4_polytopic_attention import generate_600_cell_vertices, build_coxeter_chambers
24
+ from h4_language_model import H4LanguageModel
25
+ from bitlinear import BitLinear
26
+
27
# ---------------------------------------------------------------------------
# Constants (DO NOT MODIFY the frozen geometry section)
# ---------------------------------------------------------------------------

PHI = (1 + math.sqrt(5)) / 2  # golden ratio, ~1.618 — fundamental to H4 symmetry

# Frozen geometric constants — loaded from existing code.
# These define the fixed H4 geometry and must never be altered by the
# autoresearch agent (see module docstring).
# NOTE(review): presumably VERTICES holds the 120 vertices of the 600-cell
# and SIMPLE_ROOTS the 4 simple roots of the H4 Coxeter group — confirm
# against h4_polytopic_attention.
VERTICES = torch.tensor(generate_600_cell_vertices(), dtype=torch.float32)
CHAMBERS = build_coxeter_chambers(VERTICES.numpy())
SIMPLE_ROOTS = torch.tensor(CHAMBERS['simple_roots'], dtype=torch.float32)

# ---------------------------------------------------------------------------
# Hyperparameters (AGENT MAY MODIFY THESE)
# ---------------------------------------------------------------------------

# Time budget: 2 minutes on CPU (wall-clock training time, warmup excluded)
TIME_BUDGET = 120  # seconds

# Dataset: 'synthetic', 'shakespeare', or 'tinystories'
# Non-synthetic datasets are loaded via prepare_data.load_and_prepare in main().
DATASET = 'synthetic'

# Data
MAX_SEQ_LEN = 128  # training context length; model is built with 2x this
BATCH_SIZE = 8

# Model
D_MODEL = 256   # embedding / residual-stream width
N_HEADS = 8     # attention heads
N_LAYERS = 4    # transformer blocks
D_VALUE = 16    # per-head value dimension
D_FFN = 512     # feed-forward hidden width
TOP_K = 16      # sparse attention top-k (semantics defined by H4LanguageModel)
DROPOUT = 0.0   # no dropout: tiny model, tiny data, short run
USE_BITLINEAR = True  # Set True for ternary {-1,0,+1} weights

# Optimizer (AdamW; see main() for betas and schedule)
LR = 5e-3
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 50  # linear LR warmup before cosine decay
GRAD_CLIP = 1.0    # global grad-norm clip; <=0 disables clipping

# Eval
EVAL_INTERVAL = 25  # run validation every N training steps
EVAL_BATCHES = 5    # batches averaged per interim validation pass
71
+
72
+ # ---------------------------------------------------------------------------
73
+ # Data: Character-level Shakespeare (or synthetic if not available)
74
+ # ---------------------------------------------------------------------------
75
+
76
def load_text_data():
    """Load training text, falling back to synthetic data if no file is available.

    Searches a few conventional locations for a plain-text corpus
    (Shakespeare or generic ``input.txt``). If none exists, generates a
    deterministic synthetic corpus of repeated phrases whose repetition
    pattern follows the Fibonacci sequence, designed to probe the model's
    geometric inductive bias.

    Returns:
        str: the full training text (loaded file contents or synthetic corpus).
    """
    # Candidate corpus files, checked in order.
    data_paths = [
        os.path.join(os.path.dirname(__file__), '..', 'data', 'shakespeare.txt'),
        os.path.join(os.path.dirname(__file__), '..', 'data', 'input.txt'),
        os.path.join(os.path.dirname(__file__), 'data', 'input.txt'),
    ]

    text = None
    for path in data_paths:
        if os.path.exists(path):
            with open(path, 'r', encoding='utf-8') as f:
                text = f.read()
            print(f"Loaded data from {path} ({len(text)} chars)")
            break

    if text is None:
        # Generate synthetic text with mathematical structure:
        # Fibonacci-structured repetitions to test geometric inductive bias.
        print("No data file found, generating synthetic text...")
        base_phrases = [
            "the golden ratio appears in nature ",
            "fibonacci numbers grow exponentially ",
            "symmetry underlies all of physics ",
            "the icosahedron has twenty faces ",
            "phi equals one plus one over phi ",
            "geometry is the language of space ",
            "five fold symmetry cannot tile a plane ",
            "the dodecahedron has twelve faces ",
        ]
        # Accumulate chunks and join once at the end — the original
        # `text += ...` in a loop is quadratic in total length.
        chunks = []
        a, b = 1, 1
        for _ in range(200):
            # Phrase choice and repetition count (1-3) are driven by
            # consecutive Fibonacci numbers, so the stream is deterministic.
            phrase = base_phrases[a % len(base_phrases)]
            chunks.append(phrase * (b % 3 + 1))
            a, b = b, a + b
        text = "".join(chunks)

    return text
116
+
117
+
118
def prepare_char_dataset(text: str):
    """Build a character-level dataset from raw text.

    Assigns each distinct character an integer id (in sorted character
    order), encodes the whole text as a LongTensor, and splits it 90/10
    into train and validation portions.

    Returns:
        (train_data, val_data, vocab_size, stoi, itos) where stoi maps
        char -> id and itos maps id -> char.
    """
    alphabet = sorted(set(text))
    itos = dict(enumerate(alphabet))
    stoi = {ch: idx for idx, ch in itos.items()}

    encoded = torch.tensor([stoi[ch] for ch in text], dtype=torch.long)

    # First 90% trains, the remaining 10% validates.
    split_at = int(0.9 * len(encoded))
    return encoded[:split_at], encoded[split_at:], len(alphabet), stoi, itos
133
+
134
+
135
def get_batch(data: torch.Tensor, batch_size: int, seq_len: int):
    """Sample a random batch of next-token-prediction sequences.

    Args:
        data: 1-D LongTensor of token ids.
        batch_size: number of sequences to sample.
        seq_len: requested sequence length (clamped to len(data) - 1).

    Returns:
        (x, y): tensors of shape (batch_size, L) where y is x shifted by
        one position, L = min(seq_len, len(data) - 1).
    """
    # Clamp so both the input and its one-step-shifted target fit inside
    # `data`. Without this, a corpus shorter than seq_len + 1 produced x and
    # y of *different* lengths via the max_start fallback below (latent bug).
    seq_len = min(seq_len, len(data) - 1)
    max_start = len(data) - seq_len - 1
    if max_start <= 0:
        # Degenerate case: only start index 0 is valid.
        max_start = 1
    ix = torch.randint(0, max_start, (batch_size,))
    x = torch.stack([data[i:i + seq_len] for i in ix])
    y = torch.stack([data[i + 1:i + seq_len + 1] for i in ix])
    return x, y
144
+
145
+
146
+ # ---------------------------------------------------------------------------
147
+ # Training loop (follows autoresearch pattern)
148
+ # ---------------------------------------------------------------------------
149
+
150
def main():
    """Run the full autoresearch training cycle: load data, train the H4
    language model for TIME_BUDGET seconds of measured step time, then print
    an autoresearch-parseable summary of loss, geometric diagnostics, and
    (optionally) ternary-weight statistics.
    """
    t_start = time.time()
    # Fixed seeds so runs are comparable across autoresearch iterations.
    torch.manual_seed(42)
    np.random.seed(42)

    # Load data — either a prepared external dataset or the synthetic corpus.
    if DATASET != 'synthetic':
        from prepare_data import load_and_prepare
        train_data, val_data, vocab_size, stoi, itos = load_and_prepare(DATASET)
    else:
        text = load_text_data()
        train_data, val_data, vocab_size, stoi, itos = prepare_char_dataset(text)
    print(f"Vocab size: {vocab_size}, Train: {len(train_data)}, Val: {len(val_data)}")

    # Create model. max_seq_len is doubled so generation can extend past the
    # training context.
    model = H4LanguageModel(
        vocab_size=vocab_size,
        d_model=D_MODEL,
        n_heads=N_HEADS,
        n_layers=N_LAYERS,
        d_value=D_VALUE,
        d_ffn=D_FFN,
        top_k=TOP_K,
        max_seq_len=MAX_SEQ_LEN * 2,
        dropout=DROPOUT,
        use_bitlinear=USE_BITLINEAR,
    )

    param_info = model.count_params()
    print(f"Model params: {param_info['trainable']:,} trainable, {param_info['buffers']:,} buffer elements")

    # Optimizer: AdamW with cosine schedule.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        betas=(0.9, 0.95),
    )

    # Cosine LR schedule with warmup. Returns a multiplier on LR:
    # linear 0->1 over WARMUP_STEPS, then cosine decay toward 10% of peak
    # over a nominal 500-step horizon (clamped after that).
    def lr_schedule(step):
        if step < WARMUP_STEPS:
            return step / max(WARMUP_STEPS, 1)
        # Cosine decay to 10% of peak
        progress = (step - WARMUP_STEPS) / max(1, 500 - WARMUP_STEPS)
        return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)

    # Training state
    step = 0
    total_training_time = 0.0  # measured step time only (warmup steps excluded)
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    # Use full attention (no tree) for short sequences during training.
    # Tree is beneficial for long sequences; for seq_len=128, full attention
    # is faster. (Tree-attention semantics live in H4LanguageModel.)
    use_tree = MAX_SEQ_LEN > 256

    print(f"\nTraining for {TIME_BUDGET}s budget, seq_len={MAX_SEQ_LEN}, use_tree={use_tree}")
    print(f"{'step':>6} {'loss':>8} {'val_loss':>8} {'lr':>10} {'dt':>6} {'progress':>8}")
    print("-" * 56)

    model.train()

    # Main loop: run until accumulated (post-warmup) step time exceeds the
    # time budget. Exit condition is checked at the bottom so at least 3
    # steps always run.
    while True:
        t0 = time.time()

        # Get batch
        x, y = get_batch(train_data, BATCH_SIZE, MAX_SEQ_LEN)

        # Forward: next-token cross-entropy over the flattened batch.
        logits = model(x, use_tree=use_tree)
        loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping (disabled when GRAD_CLIP <= 0)
        if GRAD_CLIP > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)

        optimizer.step()
        scheduler.step()

        dt = time.time() - t0
        if step > 2:  # skip warmup steps for timing
            total_training_time += dt

        train_losses.append(loss.item())

        # Eval: every EVAL_INTERVAL steps, average loss over EVAL_BATCHES
        # validation batches (always full attention).
        val_loss = None
        if step % EVAL_INTERVAL == 0:
            model.eval()
            with torch.no_grad():
                vl = []
                for _ in range(EVAL_BATCHES):
                    xv, yv = get_batch(val_data, BATCH_SIZE, MAX_SEQ_LEN)
                    vlogits = model(xv, use_tree=False)
                    vl.append(F.cross_entropy(vlogits.view(-1, vocab_size), yv.view(-1)).item())
                val_loss = sum(vl) / len(vl)
                val_losses.append(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss

            current_lr = scheduler.get_last_lr()[0]
            progress = min(total_training_time / TIME_BUDGET, 1.0)
            print(f"{step:6d} {loss.item():8.4f} {val_loss:8.4f} {current_lr:10.6f} {dt:6.3f} {progress:7.1%}")
            model.train()

        step += 1
        if step > 2 and total_training_time >= TIME_BUDGET:
            break

    # ---------------------------------------------------------------------------
    # Final evaluation
    # ---------------------------------------------------------------------------

    model.eval()
    with torch.no_grad():
        # Final val loss over a larger sample (4x the interim eval size).
        vl = []
        for _ in range(EVAL_BATCHES * 4):
            xv, yv = get_batch(val_data, BATCH_SIZE, MAX_SEQ_LEN)
            vlogits = model(xv, use_tree=False)
            vl.append(F.cross_entropy(vlogits.view(-1, vocab_size), yv.view(-1)).item())
        final_val_loss = sum(vl) / len(vl)

        # Bits per byte (for character-level: loss_nats / ln(2))
        val_bpb = final_val_loss / math.log(2)

        # Geometric diagnostics on a single sample sequence.
        # NOTE(review): diagnostic dict keys ('chamber_entropy', 'nudge_rank',
        # 'geo_alignment') are defined by H4LanguageModel — confirm there.
        xd, _ = get_batch(val_data, 1, MAX_SEQ_LEN)
        _, diag_list = model(xd, use_tree=False, return_diagnostics=True)

        # Aggregate diagnostics across layers; infinite nudge ranks are
        # excluded from the mean (0 if nothing remains).
        avg_chamber_entropy = np.mean([d['chamber_entropy'] for d in diag_list])
        nudge_ranks = []
        geo_aligns = []
        for d in diag_list:
            nudge_ranks.extend(d['nudge_rank'])
            geo_aligns.extend(d['geo_alignment'])
        avg_nudge_rank = np.mean([r for r in nudge_ranks if r != float('inf')] or [0])
        avg_geo_alignment = np.mean(geo_aligns)

        # Generate sample text from the first 4 vocabulary characters as seed.
        seed_text = list(stoi.keys())[:4]  # first 4 chars
        seed_ids = torch.tensor([[stoi[c] for c in seed_text]], dtype=torch.long)
        generated = model.generate(seed_ids, max_new_tokens=80, temperature=0.8, top_k_sample=10)
        gen_text = ''.join([itos.get(i.item(), '?') for i in generated[0]])

    # ---------------------------------------------------------------------------
    # Summary (autoresearch-parseable format)
    # ---------------------------------------------------------------------------

    # Ternary diagnostics (if using BitLinear)
    has_bitlinear = any(isinstance(m, BitLinear) for m in model.modules())
    ternary_info = {}
    if has_bitlinear:
        from ternary_diagnostics import chamber_preservation, bitlinear_layer_stats, size_comparison
        cp = chamber_preservation(model)
        mean_cp = sum(cp.values()) / len(cp) if cp else 0.0
        bl_stats = bitlinear_layer_stats(model)
        mean_zero_pct = np.mean([s['zero'] for s in bl_stats.values()]) if bl_stats else 0.0
        sz = size_comparison(model)
        ternary_info = {
            'chamber_preserve': mean_cp,
            'mean_zero_pct': mean_zero_pct,
            'compression': sz['compression'],
            'mixed_kb': sz['mixed_kb'],
        }

    print("\n" + "=" * 60)
    print("GENERATED SAMPLE:")
    print(gen_text[:200])
    print("=" * 60)

    # Key/value lines below are parsed by the autoresearch harness —
    # keep the format stable.
    print("\n---")
    print(f"val_bpb: {val_bpb:.6f}")
    print(f"val_loss: {final_val_loss:.6f}")
    print(f"best_val_loss: {best_val_loss:.6f}")
    print(f"chamber_entropy: {avg_chamber_entropy:.4f}")
    print(f"avg_nudge_rank: {avg_nudge_rank:.4f}")
    print(f"avg_geo_alignment: {avg_geo_alignment:.4f}")
    print(f"training_seconds: {total_training_time:.1f}")
    print(f"total_seconds: {time.time() - t_start:.1f}")
    # Peak memory is not measured on CPU; 0 is a placeholder.
    print(f"peak_memory_mb: {0:.1f}")
    print(f"num_steps: {step}")
    print(f"num_params: {param_info['trainable']}")
    print(f"vocab_size: {vocab_size}")
    print(f"seq_len: {MAX_SEQ_LEN}")
    print(f"ternary: {'yes' if USE_BITLINEAR else 'no'}")
    if ternary_info:
        print(f"chamber_preserve: {ternary_info['chamber_preserve']:.4f}")
        print(f"mean_zero_pct: {ternary_info['mean_zero_pct']:.4f}")
        print(f"compression: {ternary_info['compression']:.1f}x")
        print(f"model_size_kb: {ternary_info['mixed_kb']:.1f}")
351
+
352
+
353
# Script entry point: run one full training + evaluation cycle.
if __name__ == '__main__':
    main()