AbstractPhil commited on
Commit
947e75a
·
verified ·
1 Parent(s): 7f5d157

Create 5_big_finder_sweep_600_configs.py

Browse files
Files changed (1) hide show
  1. 5_big_finder_sweep_600_configs.py +119 -0
5_big_finder_sweep_600_configs.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ cell_q_runner.py — Phase Q H2-candidate extended sweep
3
+
4
+ Takes the top 10 configs flagged by the P-sweep analyzer's
5
+ continued-training-potential metric and re-runs each with
6
+ batch_limit=1000 (50× the P sweep's 20-batch budget).
7
+
8
+ Purpose: produce the data needed to assign H2 class ranks —
9
+ actual convergence floors, trajectory shapes at full budget,
10
+ Adam-vs-LBFGS parity question, sharpened discrimination ratios.
11
+
12
+ Output: /content/phaseQ_reports/results_phaseQ.json
13
+ """
14
+
15
+ import json
16
+ import time
17
+ import traceback
18
+ from pathlib import Path
19
+
20
+
21
+ OUTPUT_ROOT = Path("/content/phaseQ_reports")
22
+ OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
23
+ AGGREGATE_PATH = OUTPUT_ROOT / "results_phaseQ.json"
24
+
25
+
26
+ def run_sweep():
27
+ configs = get_phaseQ_configs()
28
+ print(f"Phase Q: {len(configs)} configs at 1000 batches each")
29
+ print(f"Output: {OUTPUT_ROOT}\n")
30
+
31
+ # Print the config lineup so we know what's running
32
+ print("Config lineup:")
33
+ for cfg in configs:
34
+ ov = cfg['overrides']
35
+ print(f" {cfg['variant']:<45} "
36
+ f"h={ov['hidden']} V={ov['V']} D={ov['D']} "
37
+ f"dp={ov['depth']} nx={ov['n_cross']} opt={ov['optimizer']}")
38
+ print()
39
+
40
+ # Resume support
41
+ results = []
42
+ done_variants = set()
43
+ if AGGREGATE_PATH.exists():
44
+ with open(AGGREGATE_PATH) as f:
45
+ results = json.load(f)
46
+ done_variants = {r.get('variant') for r in results}
47
+ if done_variants:
48
+ print(f"Resuming: {len(done_variants)} configs already complete")
49
+
50
+ sweep_t0 = time.time()
51
+ for i, cfg in enumerate(configs):
52
+ if cfg['variant'] in done_variants:
53
+ print(f" [{i+1}/{len(configs)}] {cfg['variant']} (skipped — already done)")
54
+ continue
55
+
56
+ config_output_dir = OUTPUT_ROOT / cfg['variant']
57
+ config_output_dir.mkdir(exist_ok=True)
58
+
59
+ batch_limit = phase2_batch_limit(cfg)
60
+ t0 = time.time()
61
+ print(f" [{i+1}/{len(configs)}] {cfg['variant']} "
62
+ f"(batch_limit={batch_limit}) ...", end=' ', flush=True)
63
+ try:
64
+ report = run_ablation_config(
65
+ ablation_config=cfg,
66
+ output_dir=str(config_output_dir),
67
+ batch_limit=batch_limit,
68
+ num_epochs=cfg.get('num_epochs', 1),
69
+ )
70
+ report['_sweep_status'] = 'ok'
71
+ elapsed = time.time() - t0
72
+ final_mse = report.get('test_mse_per_noise', {}).get(0,
73
+ report.get('test_mse_per_noise', {}).get('0', 'N/A'))
74
+ cv = report.get('observed_sphere_cv', 0.0)
75
+ finite = report.get('params_finite', False)
76
+ status_ind = "OK " if finite else "NaN"
77
+ print(f"{status_ind} ({elapsed:.0f}s, "
78
+ f"G-MSE={final_mse if isinstance(final_mse, str) else f'{final_mse:.5f}'}, "
79
+ f"CV={cv:.3f})")
80
+ except Exception as e:
81
+ report = {
82
+ '_sweep_status': f'error: {type(e).__name__}: {str(e)[:300]}',
83
+ '_traceback': traceback.format_exc()[:2000],
84
+ 'config': cfg,
85
+ 'variant': cfg['variant'],
86
+ }
87
+ print(f"ERROR: {type(e).__name__}: {str(e)[:80]}")
88
+
89
+ report['variant'] = cfg['variant']
90
+ report['wallclock_outer_s'] = time.time() - t0
91
+ results.append(report)
92
+
93
+ # Checkpoint after every run (only 10 configs, cheap)
94
+ with open(AGGREGATE_PATH, 'w') as f:
95
+ json.dump(results, f, indent=2, default=str)
96
+
97
+ total = time.time() - sweep_t0
98
+ print(f"\nPhase Q complete: {len(results)} reports in {total/60:.1f} min")
99
+ print(f"Aggregate: {AGGREGATE_PATH}")
100
+
101
+ # Quick summary
102
+ print(f"\nQuick summary (by rank):")
103
+ print(f" {'Rank':<6} {'Variant':<45} {'G-MSE':>9} {'CV':>6} {'Finite':>7}")
104
+ print(f" {'-'*75}")
105
+ for r in results:
106
+ v = r.get('variant', '?')
107
+ g_mse = r.get('test_mse_per_noise', {}).get(0)
108
+ if g_mse is None:
109
+ g_mse = r.get('test_mse_per_noise', {}).get('0', float('nan'))
110
+ cv = r.get('observed_sphere_cv', 0.0)
111
+ finite = r.get('params_finite', False)
112
+ print(f" {v[:5]:<6} {v[:45]:<45} "
113
+ f"{g_mse:>9.5f} {cv:>6.3f} {'YES' if finite else 'NO':>7}")
114
+
115
+ return results
116
+
117
+
118
+ if __name__ == '__main__':
119
+ results = run_sweep()