tekkmaven commited on
Commit
52eb540
·
verified ·
1 Parent(s): 3ec445e

Upload experiment.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. experiment.py +525 -0
experiment.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Representation Learning Dynamics Experiment
3
+ ============================================
4
+ How does a model's internal representation respond to continued training
5
+ on the same task vs. learning a new one?
6
+
7
+ Experiment design:
8
+ Phase 1: Train model on Task A (modular addition) until convergence
9
+ Phase 2: Fork into two branches:
10
+ Branch A→A: Continue training on Task A (same task, more data)
11
+ Branch A→B: Switch to Task B (modular subtraction)
12
+ Track: CKA, subspace angles, gradient alignment, attention entropy,
13
+ representation variance explained, probing accuracy — all per layer,
14
+ at every checkpoint.
15
+
16
+ The key contrast reveals what "learning" looks like at the representation
17
+ level vs. what "forgetting" looks like — and the precise moment they diverge.
18
+ """
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.optim as optim
23
+ import numpy as np
24
+ import json
25
+ import os
26
+ import copy
27
+ import time
28
+ from pathlib import Path
29
+ from collections import defaultdict
30
+ from typing import Dict, List, Optional
31
+
32
+ from model import SmallTransformer, TransformerConfig
33
+ from tasks import (
34
+ ModularArithmeticDataset, get_probe_data, get_dataloaders,
35
+ DEFAULT_P, NUM_SPECIAL
36
+ )
37
+ from representation_tracker import (
38
+ linear_CKA, svcca, subspace_angles, mean_subspace_angle_degrees,
39
+ gradient_alignment, attention_entropy, task_variance_explained,
40
+ parameter_delta_cosine, weight_change_magnitude_per_layer,
41
+ cka_heatmap
42
+ )
43
+
44
+
45
+ def evaluate(model, dataloader, device) -> Dict[str, float]:
46
+ """Evaluate accuracy and loss on a dataset."""
47
+ model.eval()
48
+ total_loss = 0
49
+ total_correct = 0
50
+ total_count = 0
51
+
52
+ with torch.no_grad():
53
+ for batch in dataloader:
54
+ input_ids = batch['input_ids'].to(device)
55
+ labels = batch['labels'].to(device)
56
+ out = model(input_ids, labels=labels)
57
+
58
+ total_loss += out['loss'].item() * input_ids.shape[0]
59
+ # Accuracy: check last position prediction
60
+ preds = out['logits'][:, -1, :].argmax(dim=-1)
61
+ targets = labels[:, -1]
62
+ total_correct += (preds == targets).sum().item()
63
+ total_count += input_ids.shape[0]
64
+
65
+ return {
66
+ 'loss': total_loss / total_count,
67
+ 'accuracy': total_correct / total_count,
68
+ }
69
+
70
+
71
+ def collect_representations(model, probe_input_ids, device,
72
+ token_position: int = -1) -> Dict:
73
+ """
74
+ Collect all representation data from a single forward pass on probe data.
75
+ Returns hidden states, attention patterns, MLP activations.
76
+ """
77
+ model.eval()
78
+ with torch.no_grad():
79
+ out = model(probe_input_ids.to(device), return_internals=True)
80
+
81
+ # Extract at the answer position (last token)
82
+ hidden_states = [hs[:, token_position, :].cpu()
83
+ for hs in out['hidden_states']]
84
+ attn_weights = [aw.cpu() for aw in out['attn_weights']]
85
+ mlp_hidden = [mh[:, token_position, :].cpu()
86
+ for mh in out['mlp_hidden']]
87
+
88
+ return {
89
+ 'hidden_states': hidden_states,
90
+ 'attn_weights': attn_weights,
91
+ 'mlp_hidden': mlp_hidden,
92
+ }
93
+
94
+
95
+ def compute_all_metrics(
96
+ model, model_init_state, model_phase1_state,
97
+ reps_current, reps_at_init, reps_at_phase1_end,
98
+ probe_input_ids_a, probe_labels_a,
99
+ probe_input_ids_b, probe_labels_b,
100
+ device, config
101
+ ) -> Dict:
102
+ """
103
+ Compute the full suite of representation metrics at a single checkpoint.
104
+ """
105
+ metrics = {}
106
+ n_layers = config.n_layers + 1 # +1 for embedding layer
107
+
108
+ # === Per-layer metrics ===
109
+ for layer_idx in range(n_layers):
110
+ prefix = f'layer_{layer_idx}'
111
+ curr = reps_current['hidden_states'][layer_idx]
112
+ init = reps_at_init['hidden_states'][layer_idx]
113
+ p1 = reps_at_phase1_end['hidden_states'][layer_idx]
114
+
115
+ # CKA vs initialization (how far has the representation drifted?)
116
+ metrics[f'{prefix}/cka_vs_init'] = linear_CKA(curr, init)
117
+
118
+ # CKA vs end of Phase 1 (how much has Phase 2 changed things?)
119
+ metrics[f'{prefix}/cka_vs_phase1'] = linear_CKA(curr, p1)
120
+
121
+ # SVCCA for secondary comparison
122
+ metrics[f'{prefix}/svcca_vs_phase1'] = svcca(curr, p1)
123
+
124
+ # Subspace angle vs Phase 1 end
125
+ k = min(10, curr.shape[0] // 2, curr.shape[1])
126
+ if k > 0:
127
+ metrics[f'{prefix}/subspace_angle_vs_phase1'] = \
128
+ mean_subspace_angle_degrees(curr, p1, k=k)
129
+ else:
130
+ metrics[f'{prefix}/subspace_angle_vs_phase1'] = 0.0
131
+
132
+ # === Attention entropy per layer ===
133
+ for layer_idx, aw in enumerate(reps_current['attn_weights']):
134
+ ent = attention_entropy(aw)
135
+ metrics[f'layer_{layer_idx+1}/attn_entropy_mean'] = ent['mean_entropy']
136
+ for h, he in enumerate(ent['per_head_entropy']):
137
+ metrics[f'layer_{layer_idx+1}/head_{h}_entropy'] = he
138
+
139
+ # === Task A representation variance explained ===
140
+ for layer_idx in range(n_layers):
141
+ curr = reps_current['hidden_states'][layer_idx]
142
+ if len(set(probe_labels_a.tolist())) > 1:
143
+ tve = task_variance_explained(
144
+ curr, torch.tensor(probe_labels_a, dtype=torch.float), n_components=10
145
+ )
146
+ metrics[f'layer_{layer_idx}/task_a_var_explained'] = tve['weighted_r2']
147
+
148
+ # === Parameter-space: weight change magnitude ===
149
+ current_state = {k: v.cpu() for k, v in model.state_dict().items()}
150
+ wc_from_init = weight_change_magnitude_per_layer(model_init_state, current_state)
151
+ wc_from_p1 = weight_change_magnitude_per_layer(model_phase1_state, current_state)
152
+
153
+ # Aggregate per block
154
+ for block_idx in range(config.n_layers):
155
+ init_total = sum(v for k, v in wc_from_init.items()
156
+ if f'blocks.{block_idx}' in k)
157
+ p1_total = sum(v for k, v in wc_from_p1.items()
158
+ if f'blocks.{block_idx}' in k)
159
+ metrics[f'block_{block_idx}/weight_change_from_init'] = init_total
160
+ metrics[f'block_{block_idx}/weight_change_from_phase1'] = p1_total
161
+
162
+ return metrics
163
+
164
+
165
+ def train_phase(
166
+ model, optimizer, dataloader, n_epochs: int,
167
+ device, phase_name: str,
168
+ # For metric collection
169
+ model_init_state, model_phase1_state,
170
+ reps_at_init, reps_at_phase1_end,
171
+ probe_input_ids_a, probe_labels_a,
172
+ probe_input_ids_b, probe_labels_b,
173
+ eval_loaders: Dict,
174
+ config: TransformerConfig,
175
+ checkpoint_every: int = 50, # steps between metric collection
176
+ output_dir: str = 'results',
177
+ ) -> List[Dict]:
178
+ """
179
+ Train for n_epochs, collecting representation metrics periodically.
180
+ """
181
+ history = []
182
+ global_step = 0
183
+ os.makedirs(output_dir, exist_ok=True)
184
+
185
+ for epoch in range(n_epochs):
186
+ model.train()
187
+ epoch_loss = 0
188
+ n_batches = 0
189
+
190
+ for batch in dataloader:
191
+ input_ids = batch['input_ids'].to(device)
192
+ labels = batch['labels'].to(device)
193
+
194
+ out = model(input_ids, labels=labels)
195
+ loss = out['loss']
196
+
197
+ optimizer.zero_grad()
198
+ loss.backward()
199
+ optimizer.step()
200
+
201
+ epoch_loss += loss.item()
202
+ n_batches += 1
203
+ global_step += 1
204
+
205
+ # Collect metrics at checkpoint intervals
206
+ if global_step % checkpoint_every == 0:
207
+ model.eval()
208
+
209
+ # Get current representations on probe data
210
+ reps_current = collect_representations(
211
+ model, probe_input_ids_a, device
212
+ )
213
+
214
+ # Compute all metrics
215
+ step_metrics = compute_all_metrics(
216
+ model, model_init_state, model_phase1_state,
217
+ reps_current, reps_at_init, reps_at_phase1_end,
218
+ probe_input_ids_a, probe_labels_a,
219
+ probe_input_ids_b, probe_labels_b,
220
+ device, config
221
+ )
222
+
223
+ # Evaluate on all datasets
224
+ for name, loader in eval_loaders.items():
225
+ eval_res = evaluate(model, loader, device)
226
+ step_metrics[f'eval/{name}_loss'] = eval_res['loss']
227
+ step_metrics[f'eval/{name}_acc'] = eval_res['accuracy']
228
+
229
+ # Gradient alignment between tasks
230
+ # Get one batch from each task
231
+ batch_a = next(iter(eval_loaders['add_test']))
232
+ batch_b = next(iter(eval_loaders['subtract_test']))
233
+
234
+ def loss_fn(m, b):
235
+ return m(b['input_ids'].to(device),
236
+ labels=b['labels'].to(device))['loss']
237
+
238
+ try:
239
+ ga = gradient_alignment(model, batch_a, batch_b, loss_fn)
240
+ step_metrics['gradient_alignment_a_vs_b'] = ga
241
+ except Exception:
242
+ step_metrics['gradient_alignment_a_vs_b'] = 0.0
243
+
244
+ step_metrics['phase'] = phase_name
245
+ step_metrics['epoch'] = epoch
246
+ step_metrics['step'] = global_step
247
+ step_metrics['train_loss'] = epoch_loss / n_batches
248
+
249
+ history.append(step_metrics)
250
+
251
+ print(f"[{phase_name}] Step {global_step} | "
252
+ f"Loss: {epoch_loss/n_batches:.4f} | "
253
+ f"Add acc: {step_metrics.get('eval/add_test_acc', 0):.3f} | "
254
+ f"Sub acc: {step_metrics.get('eval/subtract_test_acc', 0):.3f} | "
255
+ f"CKA(L1 vs P1): {step_metrics.get('layer_1/cka_vs_phase1', 0):.3f} | "
256
+ f"Grad align: {step_metrics.get('gradient_alignment_a_vs_b', 0):.3f}")
257
+
258
+ model.train()
259
+
260
+ # End of epoch eval
261
+ print(f"[{phase_name}] Epoch {epoch+1}/{n_epochs} complete, "
262
+ f"avg loss: {epoch_loss/n_batches:.4f}")
263
+
264
+ return history
265
+
266
+
267
+ def run_experiment(
268
+ p: int = DEFAULT_P,
269
+ n_layers: int = 2,
270
+ d_model: int = 128,
271
+ n_heads: int = 4,
272
+ d_mlp: int = 512,
273
+ phase1_epochs: int = 100,
274
+ phase2_epochs: int = 100,
275
+ lr: float = 1e-3,
276
+ weight_decay: float = 1.0,
277
+ batch_size: int = 512,
278
+ train_frac: float = 0.5,
279
+ checkpoint_every: int = 50,
280
+ output_dir: str = 'results',
281
+ seed: int = 42,
282
+ ):
283
+ """
284
+ Run the full two-phase experiment.
285
+ """
286
+ torch.manual_seed(seed)
287
+ np.random.seed(seed)
288
+
289
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
290
+ print(f"Using device: {device}")
291
+
292
+ # === Setup ===
293
+ config = TransformerConfig(
294
+ vocab_size=p + NUM_SPECIAL,
295
+ n_layers=n_layers,
296
+ d_model=d_model,
297
+ n_heads=n_heads,
298
+ d_mlp=d_mlp,
299
+ max_seq_len=5,
300
+ )
301
+
302
+ model = SmallTransformer(config).to(device)
303
+ print(f"Model parameters: {model.count_parameters():,}")
304
+
305
+ # Save initial state
306
+ model_init_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
307
+
308
+ # Dataloaders
309
+ loaders = get_dataloaders(p=p, batch_size=batch_size,
310
+ train_frac=train_frac, seed=seed)
311
+
312
+ # Probe datasets (fixed subsets for consistent metric computation)
313
+ ds_a = ModularArithmeticDataset('add', p=p, split='test', train_frac=train_frac, seed=seed)
314
+ ds_b = ModularArithmeticDataset('subtract', p=p, split='test', train_frac=train_frac, seed=seed)
315
+ probe_ids_a, probe_labels_a = get_probe_data(ds_a, n_samples=min(500, len(ds_a)))
316
+ probe_ids_b, probe_labels_b = get_probe_data(ds_b, n_samples=min(500, len(ds_b)))
317
+
318
+ # Initial representations
319
+ reps_at_init = collect_representations(model, probe_ids_a, device)
320
+
321
+ # ===========================
322
+ # PHASE 1: Train on Task A
323
+ # ===========================
324
+ print("\n" + "=" * 60)
325
+ print("PHASE 1: Training on Task A (Modular Addition)")
326
+ print("=" * 60)
327
+
328
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
329
+
330
+ # Dummy phase1 state for Phase 1 tracking (use init)
331
+ phase1_history = train_phase(
332
+ model=model,
333
+ optimizer=optimizer,
334
+ dataloader=loaders['add_train'],
335
+ n_epochs=phase1_epochs,
336
+ device=device,
337
+ phase_name='phase1_add',
338
+ model_init_state=model_init_state,
339
+ model_phase1_state=model_init_state, # placeholder
340
+ reps_at_init=reps_at_init,
341
+ reps_at_phase1_end=reps_at_init, # placeholder
342
+ probe_input_ids_a=probe_ids_a,
343
+ probe_labels_a=probe_labels_a,
344
+ probe_input_ids_b=probe_ids_b,
345
+ probe_labels_b=probe_labels_b,
346
+ eval_loaders=loaders,
347
+ config=config,
348
+ checkpoint_every=checkpoint_every,
349
+ output_dir=output_dir,
350
+ )
351
+
352
+ # Save Phase 1 endpoint
353
+ model_phase1_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
354
+ reps_at_phase1_end = collect_representations(model, probe_ids_a, device)
355
+ phase1_final_eval = evaluate(model, loaders['add_test'], device)
356
+ print(f"\nPhase 1 final — Add accuracy: {phase1_final_eval['accuracy']:.3f}")
357
+
358
+ # Save Phase 1 checkpoint
359
+ torch.save(model.state_dict(), os.path.join(output_dir, 'phase1_checkpoint.pt'))
360
+
361
+ # ===========================
362
+ # PHASE 2: Fork into A→A and A→B
363
+ # ===========================
364
+
365
+ # Branch A→A: Continue on same task
366
+ print("\n" + "=" * 60)
367
+ print("PHASE 2a: Branch A→A (Continue training on Addition)")
368
+ print("=" * 60)
369
+
370
+ model_aa = SmallTransformer(config).to(device)
371
+ model_aa.load_state_dict(torch.load(os.path.join(output_dir, 'phase1_checkpoint.pt'),
372
+ weights_only=True))
373
+ optimizer_aa = optim.AdamW(model_aa.parameters(), lr=lr, weight_decay=weight_decay)
374
+
375
+ history_aa = train_phase(
376
+ model=model_aa,
377
+ optimizer=optimizer_aa,
378
+ dataloader=loaders['add_train'],
379
+ n_epochs=phase2_epochs,
380
+ device=device,
381
+ phase_name='phase2_aa',
382
+ model_init_state=model_init_state,
383
+ model_phase1_state=model_phase1_state,
384
+ reps_at_init=reps_at_init,
385
+ reps_at_phase1_end=reps_at_phase1_end,
386
+ probe_input_ids_a=probe_ids_a,
387
+ probe_labels_a=probe_labels_a,
388
+ probe_input_ids_b=probe_ids_b,
389
+ probe_labels_b=probe_labels_b,
390
+ eval_loaders=loaders,
391
+ config=config,
392
+ checkpoint_every=checkpoint_every,
393
+ output_dir=output_dir,
394
+ )
395
+
396
+ # Branch A→B: Switch to new task
397
+ print("\n" + "=" * 60)
398
+ print("PHASE 2b: Branch A→B (Switch to Subtraction)")
399
+ print("=" * 60)
400
+
401
+ model_ab = SmallTransformer(config).to(device)
402
+ model_ab.load_state_dict(torch.load(os.path.join(output_dir, 'phase1_checkpoint.pt'),
403
+ weights_only=True))
404
+ optimizer_ab = optim.AdamW(model_ab.parameters(), lr=lr, weight_decay=weight_decay)
405
+
406
+ history_ab = train_phase(
407
+ model=model_ab,
408
+ optimizer=optimizer_ab,
409
+ dataloader=loaders['subtract_train'],
410
+ n_epochs=phase2_epochs,
411
+ device=device,
412
+ phase_name='phase2_ab',
413
+ model_init_state=model_init_state,
414
+ model_phase1_state=model_phase1_state,
415
+ reps_at_init=reps_at_init,
416
+ reps_at_phase1_end=reps_at_phase1_end,
417
+ probe_input_ids_a=probe_ids_a,
418
+ probe_labels_a=probe_labels_a,
419
+ probe_input_ids_b=probe_ids_b,
420
+ probe_labels_b=probe_labels_b,
421
+ eval_loaders=loaders,
422
+ config=config,
423
+ checkpoint_every=checkpoint_every,
424
+ output_dir=output_dir,
425
+ )
426
+
427
+ # ===========================
428
+ # PHASE 3: Cross-model comparison
429
+ # ===========================
430
+ print("\n" + "=" * 60)
431
+ print("PHASE 3: Cross-model representation comparison")
432
+ print("=" * 60)
433
+
434
+ reps_aa = collect_representations(model_aa, probe_ids_a, device)
435
+ reps_ab = collect_representations(model_ab, probe_ids_a, device)
436
+
437
+ cross_metrics = {}
438
+ for layer_idx in range(config.n_layers + 1):
439
+ ha = reps_aa['hidden_states'][layer_idx]
440
+ hb = reps_ab['hidden_states'][layer_idx]
441
+ hp1 = reps_at_phase1_end['hidden_states'][layer_idx]
442
+
443
+ cross_metrics[f'layer_{layer_idx}/cka_aa_vs_ab'] = linear_CKA(ha, hb)
444
+ cross_metrics[f'layer_{layer_idx}/cka_aa_vs_p1'] = linear_CKA(ha, hp1)
445
+ cross_metrics[f'layer_{layer_idx}/cka_ab_vs_p1'] = linear_CKA(hb, hp1)
446
+ cross_metrics[f'layer_{layer_idx}/subspace_angle_aa_vs_ab'] = \
447
+ mean_subspace_angle_degrees(ha, hb, k=min(10, ha.shape[0] // 2, ha.shape[1]))
448
+
449
+ # CKA heatmaps
450
+ heatmap_aa_vs_ab = cka_heatmap(reps_aa['hidden_states'], reps_ab['hidden_states'])
451
+ heatmap_aa_vs_p1 = cka_heatmap(reps_aa['hidden_states'],
452
+ reps_at_phase1_end['hidden_states'])
453
+ heatmap_ab_vs_p1 = cka_heatmap(reps_ab['hidden_states'],
454
+ reps_at_phase1_end['hidden_states'])
455
+
456
+ # Parameter delta cosine
457
+ params_init = [v for v in model_init_state.values()]
458
+ params_aa = [v.cpu() for v in model_aa.state_dict().values()]
459
+ params_ab = [v.cpu() for v in model_ab.state_dict().values()]
460
+ params_p1 = [v for v in model_phase1_state.values()]
461
+
462
+ cross_metrics['param_delta_cosine_aa_vs_ab'] = \
463
+ parameter_delta_cosine(params_p1, params_aa, params_ab)
464
+ cross_metrics['param_delta_cosine_aa_vs_p1_from_init'] = \
465
+ parameter_delta_cosine(params_init, params_p1, params_aa)
466
+
467
+ print("\n=== Cross-model metrics ===")
468
+ for k, v in sorted(cross_metrics.items()):
469
+ print(f" {k}: {v:.4f}")
470
+
471
+ # ===========================
472
+ # Save all results
473
+ # ===========================
474
+ results = {
475
+ 'config': {
476
+ 'p': p, 'n_layers': n_layers, 'd_model': d_model,
477
+ 'n_heads': n_heads, 'd_mlp': d_mlp,
478
+ 'phase1_epochs': phase1_epochs, 'phase2_epochs': phase2_epochs,
479
+ 'lr': lr, 'weight_decay': weight_decay, 'batch_size': batch_size,
480
+ 'train_frac': train_frac, 'seed': seed,
481
+ 'n_parameters': model.count_parameters(),
482
+ },
483
+ 'phase1_history': phase1_history,
484
+ 'phase2_aa_history': history_aa,
485
+ 'phase2_ab_history': history_ab,
486
+ 'cross_metrics': cross_metrics,
487
+ 'cka_heatmaps': {
488
+ 'aa_vs_ab': heatmap_aa_vs_ab.tolist(),
489
+ 'aa_vs_p1': heatmap_aa_vs_p1.tolist(),
490
+ 'ab_vs_p1': heatmap_ab_vs_p1.tolist(),
491
+ },
492
+ }
493
+
494
+ results_path = os.path.join(output_dir, 'experiment_results.json')
495
+ with open(results_path, 'w') as f:
496
+ json.dump(results, f, indent=2, default=str)
497
+ print(f"\nResults saved to {results_path}")
498
+
499
+ # Save final models
500
+ torch.save(model_aa.state_dict(), os.path.join(output_dir, 'model_aa_final.pt'))
501
+ torch.save(model_ab.state_dict(), os.path.join(output_dir, 'model_ab_final.pt'))
502
+
503
+ return results
504
+
505
+
506
+ if __name__ == '__main__':
507
+ import argparse
508
+ parser = argparse.ArgumentParser()
509
+ parser.add_argument('--p', type=int, default=DEFAULT_P)
510
+ parser.add_argument('--n-layers', type=int, default=2)
511
+ parser.add_argument('--d-model', type=int, default=128)
512
+ parser.add_argument('--n-heads', type=int, default=4)
513
+ parser.add_argument('--d-mlp', type=int, default=512)
514
+ parser.add_argument('--phase1-epochs', type=int, default=100)
515
+ parser.add_argument('--phase2-epochs', type=int, default=100)
516
+ parser.add_argument('--lr', type=float, default=1e-3)
517
+ parser.add_argument('--weight-decay', type=float, default=1.0)
518
+ parser.add_argument('--batch-size', type=int, default=512)
519
+ parser.add_argument('--train-frac', type=float, default=0.5)
520
+ parser.add_argument('--checkpoint-every', type=int, default=50)
521
+ parser.add_argument('--output-dir', type=str, default='results')
522
+ parser.add_argument('--seed', type=int, default=42)
523
+ args = parser.parse_args()
524
+
525
+ run_experiment(**vars(args))