Create 5_big_finder_sweep_600_configs.py
Browse files
5_big_finder_sweep_600_configs.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
cell_q_runner.py — Phase Q H2-candidate extended sweep
|
| 3 |
+
|
| 4 |
+
Takes the top 10 configs flagged by the P-sweep analyzer's
|
| 5 |
+
continued-training-potential metric and re-runs each with
|
| 6 |
+
batch_limit=1000 (50× the P sweep's 20-batch budget).
|
| 7 |
+
|
| 8 |
+
Purpose: produce the data needed to assign H2 class ranks —
|
| 9 |
+
actual convergence floors, trajectory shapes at full budget,
|
| 10 |
+
Adam-vs-LBFGS parity question, sharpened discrimination ratios.
|
| 11 |
+
|
| 12 |
+
Output: /content/phaseQ_reports/results_phaseQ.json
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import time
|
| 17 |
+
import traceback
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Directory that receives every Phase Q artifact. It is created eagerly at
# import time (missing parents included; no error if it already exists).
OUTPUT_ROOT = Path("/content/phaseQ_reports")
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)

# Single JSON file holding the accumulated per-config reports for the sweep.
AGGREGATE_PATH = OUTPUT_ROOT.joinpath("results_phaseQ.json")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _zero_noise_mse(report, default):
    """Return the zero-noise entry of ``report['test_mse_per_noise']``.

    The key is the int ``0`` on a freshly produced report but the string
    ``'0'`` once the report has been round-tripped through JSON, so both
    spellings are tried. ``default`` is returned when neither is present.
    """
    per_noise = report.get('test_mse_per_noise') or {}
    value = per_noise.get(0)
    if value is None:
        value = per_noise.get('0', default)
    return value


def _format_metric(value, precision, width=0):
    """Format ``value`` as a fixed-point number, right-aligned to ``width``.

    Anything that cannot be converted with ``float()`` (``None``, ``'N/A'``,
    ...) is rendered with ``str()`` instead of raising, so a single odd
    record cannot abort the printing of the whole report.
    """
    try:
        text = f"{float(value):.{precision}f}"
    except (TypeError, ValueError):
        text = str(value)
    return text.rjust(width)


def _write_checkpoint(results):
    """Write ``results`` to ``AGGREGATE_PATH`` via a temp file and a rename.

    Writing to a sibling file first means an interruption mid-write cannot
    leave a truncated aggregate behind (which would break the next resume).
    """
    tmp_path = AGGREGATE_PATH.with_name(AGGREGATE_PATH.name + '.tmp')
    with open(tmp_path, 'w', encoding='utf-8') as f:
        # default=str stringifies anything json cannot encode natively.
        json.dump(results, f, indent=2, default=str)
    tmp_path.replace(AGGREGATE_PATH)


def run_sweep():
    """Run every Phase Q config, checkpointing the aggregate after each run.

    Configs come from ``get_phaseQ_configs()``; each one is executed through
    ``run_ablation_config`` with the batch limit returned by
    ``phase2_batch_limit(cfg)`` and its own sub-directory of ``OUTPUT_ROOT``.
    If ``AGGREGATE_PATH`` already exists, its records are loaded and every
    variant found there is skipped. An exception raised by a run is recorded
    as an error report (status string, truncated traceback, config) rather
    than propagated.

    Returns:
        list: the previously stored reports (if any) followed by the reports
        produced during this call.
    """
    configs = get_phaseQ_configs()
    # NOTE(review): the "1000" here is a fixed string; the limit actually
    # used comes from phase2_batch_limit(cfg) below — confirm they agree.
    print(f"Phase Q: {len(configs)} configs at 1000 batches each")
    print(f"Output: {OUTPUT_ROOT}\n")

    # Print the config lineup so we know what's running
    print("Config lineup:")
    for cfg in configs:
        ov = cfg['overrides']
        print(f" {cfg['variant']:<45} "
              f"h={ov['hidden']} V={ov['V']} D={ov['D']} "
              f"dp={ov['depth']} nx={ov['n_cross']} opt={ov['optimizer']}")
    print()

    # Resume support
    results = []
    done_variants = set()
    if AGGREGATE_PATH.exists():
        with open(AGGREGATE_PATH, encoding='utf-8') as f:
            results = json.load(f)
        # NOTE(review): records with an error status also count as done, so
        # a failed config is not retried until it is removed from the file.
        done_variants = {r.get('variant') for r in results}
    if done_variants:
        print(f"Resuming: {len(done_variants)} configs already complete")

    sweep_t0 = time.time()
    for i, cfg in enumerate(configs):
        if cfg['variant'] in done_variants:
            print(f" [{i+1}/{len(configs)}] {cfg['variant']} (skipped — already done)")
            continue

        config_output_dir = OUTPUT_ROOT / cfg['variant']
        config_output_dir.mkdir(parents=True, exist_ok=True)

        batch_limit = phase2_batch_limit(cfg)
        t0 = time.time()
        print(f" [{i+1}/{len(configs)}] {cfg['variant']} "
              f"(batch_limit={batch_limit}) ...", end=' ', flush=True)
        try:
            report = run_ablation_config(
                ablation_config=cfg,
                output_dir=str(config_output_dir),
                batch_limit=batch_limit,
                num_epochs=cfg.get('num_epochs', 1),
            )
            report['_sweep_status'] = 'ok'
        except Exception as e:
            # Sweep boundary: one failing config must not stop the others.
            report = {
                '_sweep_status': f'error: {type(e).__name__}: {str(e)[:300]}',
                '_traceback': traceback.format_exc()[:2000],
                'config': cfg,
                'variant': cfg['variant'],
            }
            print(f"ERROR: {type(e).__name__}: {str(e)[:80]}")
        else:
            # Kept out of the try body so a problem while formatting the
            # status line can never replace a successful report.
            elapsed = time.time() - t0
            final_mse = _zero_noise_mse(report, 'N/A')
            cv = report.get('observed_sphere_cv', 0.0)
            finite = report.get('params_finite', False)
            status_ind = "OK " if finite else "NaN"
            print(f"{status_ind} ({elapsed:.0f}s, "
                  f"G-MSE={_format_metric(final_mse, 5)}, "
                  f"CV={_format_metric(cv, 3)})")

        report['variant'] = cfg['variant']
        report['wallclock_outer_s'] = time.time() - t0
        results.append(report)

        # Checkpoint after every run (only 10 configs, cheap)
        _write_checkpoint(results)

    total = time.time() - sweep_t0
    print(f"\nPhase Q complete: {len(results)} reports in {total/60:.1f} min")
    print(f"Aggregate: {AGGREGATE_PATH}")

    # Quick summary
    print("\nQuick summary (by rank):")
    print(f" {'Rank':<6} {'Variant':<45} {'G-MSE':>9} {'CV':>6} {'Finite':>7}")
    print(f" {'-'*75}")
    for r in results:
        v = str(r.get('variant', '?'))
        g_mse = _zero_noise_mse(r, float('nan'))
        cv = r.get('observed_sphere_cv', 0.0)
        finite = r.get('params_finite', False)
        # The first five characters of the variant name fill the Rank column.
        # Metrics are formatted tolerantly: values reloaded from the JSON
        # checkpoint may be strings or None rather than floats.
        print(f" {v[:5]:<6} {v[:45]:<45} "
              f"{_format_metric(g_mse, 5, 9)} {_format_metric(cv, 3, 6)} "
              f"{'YES' if finite else 'NO':>7}")

    return results
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
if __name__ == "__main__":
    # Script entry point: run the sweep and keep the reports in a
    # module-level name.
    results = run_sweep()
|