# AbstractPhil's picture
# Rename 10_run_finetune.py to 010_run_finetune.py
# 927d119 verified
"""
cell_r_runner.py — Phase R: sphere-packing prediction test
Trains 3 configs whose (V, D) match natural sphere polytopes:
D=4, V=16: 8-cell (tesseract) vertices on S³
D=4, V=8: 16-cell vertices (alternate half of the 8-cell's vertices) on S³
D=3, V=20: dodecahedron vertices on S²
Hypothesis: each will produce H2-LIKE rows (high stability, low antipodal
pairs, full rank utilization) because V points fit S^(D-1) uniformly for
these counts. The G-Class behavior at (V=32, D=3) is attributed to geometric
frustration — natural V's should reproduce the H2 sphere-solver character.
After training, immediately runs the v3 probe metrics on each model:
- per-sample sphere-norm
- row stability across 512 gaussian inputs
- antipodal pair fraction
- per-sample silhouette
- effective rank
- pairwise angle distribution
Outputs:
/content/phaseR_reports/results_phaseR.json — training results + probes
/content/phaseR_reports/phaseR_summary.png — H2-LIKE / G-LIKE verdicts
"""
import json
import math
import time
import traceback
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
# Root directory for every Phase R artifact; created eagerly at import time.
OUTPUT_ROOT = Path("/content") / "phaseR_reports"
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
# Aggregate JSON of all per-config training + probe reports.
AGGREGATE_PATH = OUTPUT_ROOT.joinpath("results_phaseR.json")
# Path reserved for the verdict summary figure.
SUMMARY_PLOT = OUTPUT_ROOT.joinpath("phaseR_summary.png")
# ════════════════════════════════════════════════════════════════════
# Geometric probe (compact version of v3)
# ════════════════════════════════════════════════════════════════════
def collect_M(model, cfg, n_batches=8, batch_size=64):
    """Feed noise images through ``model`` and stack the patch-0 M matrices.

    Builds an ``OmegaNoiseDataset`` of ``n_batches * batch_size`` images at
    ``cfg.img_size``, restricted to noise type 0 (presumably the gaussian
    inputs the module docstring refers to — TODO confirm). The dataset is
    iterated in order without gradients, and ``out['svd']['M'][:, 0]`` is kept
    from every batch.

    Note: ``model`` is moved to CUDA when available (``nn.Module.to`` acts in
    place, so the caller's model object is moved too).

    Returns:
        numpy array: the per-batch slices concatenated along dim 0; callers
        treat it as (N, V, D).
    """
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(target)
    total_images = n_batches * batch_size
    dataset = OmegaNoiseDataset(
        size=total_images, img_size=cfg.img_size, allowed_types=[0])
    batches = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    chunks = []
    with torch.no_grad():
        for images, _labels in batches:
            svd_out = model(images.to(target))['svd']
            chunks.append(svd_out['M'][:, 0].cpu())
    return torch.cat(chunks, dim=0).numpy()
def _sample_silhouettes(all_M, k, max_samples=20):
    """Silhouette score of a k-means clustering of each of the first max_samples samples' rows.

    Best-effort, as in v3: a sample whose clustering/scoring raises (e.g. fewer
    rows than clusters) or that ends up with a single label is skipped.
    """
    scores = []
    for sample in all_M[:max_samples]:
        try:
            labels = KMeans(n_clusters=k, n_init=10, random_state=42).fit_predict(sample)
            if len(set(labels)) >= 2:
                scores.append(silhouette_score(sample, labels))
        except Exception:
            # Deliberate best-effort skip kept from v3; the metric is optional.
            continue
    return np.array(scores)


def _pairwise_angles(all_M, max_rows=500, seed=42):
    """Angles (radians) between all pairs of a seeded random subset of unit-normalized rows."""
    rows = all_M.reshape(-1, all_M.shape[-1])
    norms = np.linalg.norm(rows, axis=1, keepdims=True)
    unit_rows = rows / np.clip(norms, 1e-12, None)
    n_subset = min(max_rows, unit_rows.shape[0])
    idx = np.random.RandomState(seed).choice(unit_rows.shape[0], n_subset, replace=False)
    cosines = unit_rows[idx] @ unit_rows[idx].T
    return np.arccos(np.clip(cosines[np.triu_indices(n_subset, k=1)], -1, 1))


def _effective_rank(matrix):
    """exp(entropy) of the normalized singular-value spectrum; 0.0 for an all-zero matrix."""
    sv = np.linalg.svd(matrix, compute_uv=False)
    total = sv.sum()
    if total <= 0:
        # All singular values are zero: the normalization would divide by zero.
        return 0.0
    sv_norm = sv / total
    return math.exp(-(sv_norm * np.log(sv_norm + 1e-12)).sum())


def probe_geometry(all_M):
    """Return all v3 probe metrics for a stack of M matrices in one dict.

    Args:
        all_M: array-like of shape (N, V, D): N samples, each a V x D matrix
            whose V rows are the directions being probed.

    Returns:
        dict of JSON-friendly scalars:
          sphere_normed / row_norm_mean: rows unit length (|mean-1| < 0.05 and std < 0.05)?
          stability_*: norm of each row's mean over the N samples (for unit rows,
              1.0 means the row is identical in every sample)
          silhouette_*: per-sample k-means silhouette stats, None if no sample scored
          angular_*: statistics of pairwise angles among up to 500 random unit rows
              (NaN when fewer than two rows exist)
          antipodal_*: rows whose mean direction has another mean direction with
              cosine < -0.9
          effective_rank / utilization: exp-entropy rank of the mean matrix and
              that rank divided by D (0.0 for an all-zero mean matrix)

    Raises:
        ValueError: if all_M is not 3-D.
    """
    all_M = np.asarray(all_M)
    if all_M.ndim != 3:
        raise ValueError(f"probe_geometry expects an (N, V, D) array, got shape {all_M.shape}")
    V, D = all_M.shape[1], all_M.shape[2]

    # sphere-norm
    row_norms = np.linalg.norm(all_M, axis=2)
    sphere_normed = abs(row_norms.mean() - 1.0) < 0.05 and row_norms.std() < 0.05

    # row stability: a row pointing the same way for every input keeps a long mean
    mean_dirs = all_M.mean(axis=0)
    mean_dir_norms = np.linalg.norm(mean_dirs, axis=1)

    # per-sample silhouette: k = 5 for V >= 10, V // 2 for 4 <= V < 10, 2 below that
    sils = _sample_silhouettes(all_M, k=min(5, max(2, V // 2)))

    # angular
    angles = _pairwise_angles(all_M)
    if angles.size:
        angular_mean = float(angles.mean())
        angular_near_pi = float((angles > math.pi - 0.5).mean())
        angular_near_perp = float(
            ((angles > math.pi / 2 - 0.3) & (angles < math.pi / 2 + 0.3)).mean())
    else:
        # Fewer than two rows: no pairs, so report NaN explicitly (no empty-mean warning).
        angular_mean = angular_near_pi = angular_near_perp = float('nan')

    # antipodal: each mean direction's most opposite *other* mean direction
    unit_dirs = mean_dirs / np.clip(
        np.linalg.norm(mean_dirs, axis=1, keepdims=True), 1e-12, None)
    cos_mat = unit_dirs @ unit_dirs.T
    np.fill_diagonal(cos_mat, 1.0)  # exclude self-pairs from the min
    is_antipodal = cos_mat.min(axis=1) < -0.9

    # effective rank of the sample-averaged matrix
    erank = _effective_rank(mean_dirs)

    return {
        'sphere_normed': bool(sphere_normed),
        'row_norm_mean': float(row_norms.mean()),
        'stability_mean': float(mean_dir_norms.mean()),
        'stability_min': float(mean_dir_norms.min()),
        'stability_max': float(mean_dir_norms.max()),
        'silhouette_mean': float(sils.mean()) if len(sils) else None,
        'silhouette_std': float(sils.std()) if len(sils) else None,
        'angular_mean': angular_mean,
        'angular_near_pi': angular_near_pi,
        'angular_near_perp': angular_near_perp,
        'antipodal_frac': float(is_antipodal.mean()),
        'antipodal_pairs': int(is_antipodal.sum() // 2),
        'antipodal_max_pairs': int(V // 2),
        'effective_rank': float(erank),
        'D': int(D),
        'utilization': float(erank / D),
    }
def classify_character(probe):
    """Label a probe dict with its geometric character, using the v3 thresholds.

    H2-LIKE : stability > 0.85, antipodal fraction < 0.55, utilization > 0.95
    G-LIKE  : stability < 0.65, antipodal fraction > 0.80
    DIFFUSE : stability < 0.65, antipodal fraction < 0.55
    HYBRID  : everything else (including NaN metrics)

    Raises KeyError if 'stability_mean', 'antipodal_frac' or 'utilization'
    is missing from ``probe``.
    """
    stability = probe['stability_mean']
    antipodal = probe['antipodal_frac']
    utilization = probe['utilization']
    if stability < 0.65:
        # Unstable rows: split on how strongly they pair up antipodally.
        if antipodal > 0.80:
            return 'G-LIKE'
        if antipodal < 0.55:
            return 'DIFFUSE'
    elif stability > 0.85 and antipodal < 0.55 and utilization > 0.95:
        return 'H2-LIKE'
    return 'HYBRID'
# ════════════════════════════════════════════════════════════════════
# Build trained model from a Q-style report
# ════════════════════════════════════════════════════════════════════
def build_model_from_config(ablation_config):
    """Instantiate an untrained ``PatchSVAE_F_Ablation`` for one ablation config.

    Architecture fields come from ``build_run_config(ablation_config)``.
    Ablation knobs come from ``ablation_config['overrides']``, falling back to
    the defaults written below (``max_alpha`` falls back to the run config).
    No weights are loaded here; ``load_trained`` does that.

    Returns:
        (model, cfg): the freshly built model and the run config it came from.
    """
    cfg = build_run_config(ablation_config)
    knobs = ablation_config['overrides']
    ablation_kwargs = dict(
        max_alpha=knobs.get('max_alpha', cfg.max_alpha),
        activation=knobs.get('activation', 'gelu'),
        row_norm=knobs.get('row_norm', 'sphere'),
        svd_mode=knobs.get('svd', 'fp64'),
        linear_readout=knobs.get('linear_readout', False),
        match_params=knobs.get('match_params', True),
        init_scheme=knobs.get('init', 'orthogonal'),
    )
    model = PatchSVAE_F_Ablation(
        matrix_v=cfg.matrix_v,
        D=cfg.D,
        patch_size=cfg.patch_size,
        hidden=cfg.hidden,
        depth=cfg.depth,
        n_cross_layers=cfg.n_cross_layers,
        n_heads=cfg.n_heads,
        alpha_init=cfg.alpha_init,
        **ablation_kwargs,
    )
    return model, cfg
def load_trained(ablation_config, output_dir):
    """Rebuild a config's model and load its trained weights from ``output_dir``.

    The checkpoint file is ``epoch_<num_epochs>_checkpoint.pt``, where
    ``num_epochs = ablation_config.get('num_epochs', 1)`` is the same value the
    sweep passes to training. This ensures multi-epoch runs are probed at
    their final epoch. If that file does not exist, it falls back to
    ``epoch_1_checkpoint.pt``.

    The checkpoint may be a bare state dict, or a dict wrapping one under
    'model_state', 'model_state_dict' or 'state_dict' (first non-empty wins).

    Returns:
        (model, cfg): the model with weights loaded and in eval mode, plus its run config.

    Raises:
        FileNotFoundError: if neither checkpoint file exists.
        RuntimeError: from ``load_state_dict`` (strict) on a key/shape mismatch.
    """
    model, cfg = build_model_from_config(ablation_config)
    ckpt_dir = Path(output_dir)
    n_epochs = ablation_config.get('num_epochs', 1)
    ckpt_path = ckpt_dir / f"epoch_{n_epochs}_checkpoint.pt"
    if not ckpt_path.exists():
        # NOTE(review): assumes training names checkpoints epoch_<N>_checkpoint.pt;
        # epoch 1 is the name the original code always used.
        ckpt_path = ckpt_dir / "epoch_1_checkpoint.pt"
    # weights_only=False unpickles arbitrary objects: only use on checkpoints
    # this sweep produced itself, never on untrusted files.
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    state_dict = ckpt
    if isinstance(ckpt, dict):
        for key in ('model_state', 'model_state_dict', 'state_dict'):
            if ckpt.get(key):
                state_dict = ckpt[key]
                break
    model.load_state_dict(state_dict)
    model.eval()
    return model, cfg
# ════════════════════════════════════════════════════════════════════
# Main
# ════════════════════════════════════════════════════════════════════
def _noise0_mse(report):
    """Return the noise-type-0 test MSE from a report, or None when absent.

    Tries the int key 0 first, then the str key '0' (presumably because a JSON
    round-trip turns int keys into strings).
    """
    per_noise = report.get('test_mse_per_noise') or {}
    value = per_noise.get(0)
    return per_noise.get('0') if value is None else value


def _fmt_num(value, spec):
    """Format ``value`` with ``spec``, or return 'N/A' when it is missing or non-numeric."""
    try:
        return format(float(value), spec)
    except (TypeError, ValueError):
        return 'N/A'


def _train_and_probe(cfg, config_output_dir):
    """Train one config, then probe its checkpoint; failures are recorded, not raised.

    A training failure returns an error report (status, traceback, config).
    A probe failure stores ``{'error': ...}`` under ``report['probe']``.
    """
    t0 = time.time()
    try:
        report = run_ablation_config(
            ablation_config=cfg,
            output_dir=str(config_output_dir),
            batch_limit=phase2_batch_limit(cfg),
            num_epochs=cfg.get('num_epochs', 1),
        )
        report['_sweep_status'] = 'ok'
    except Exception as e:
        print(f" ERROR: {type(e).__name__}: {str(e)[:80]}")
        return {
            '_sweep_status': f'error: {type(e).__name__}: {str(e)[:300]}',
            '_traceback': traceback.format_exc()[:2000],
            'config': cfg,
            'variant': cfg['variant'],
        }
    # Formatted defensively: a missing metric must not discard a finished training run.
    train_time = time.time() - t0
    print(f" train: {train_time:.0f}s, "
          f"G-MSE={_fmt_num(_noise0_mse(report), '.5f')}, "
          f"CV={_fmt_num(report.get('observed_sphere_cv', 0.0), '.3f')}")

    # ── Probe geometry ──
    print(" probe: collecting M rows + running v3 metrics...", end=' ', flush=True)
    t1 = time.time()
    try:
        model, run_cfg = load_trained(cfg, config_output_dir)
        all_M = collect_M(model, run_cfg)
        probe = probe_geometry(all_M)
        probe['M_shape'] = list(all_M.shape)
        probe['character'] = classify_character(probe)
        report['probe'] = probe
        print(f"{time.time()-t1:.0f}s → {probe['character']}")
        print(f" stability={probe['stability_mean']:.3f}, "
              f"antipodal={probe['antipodal_pairs']}/"
              f"{probe['antipodal_max_pairs']}, "
              f"utilization={probe['utilization']*100:.0f}%")
    except Exception as e:
        report['probe'] = {'error': f'{type(e).__name__}: {str(e)[:300]}'}
        print(f"FAILED: {type(e).__name__}: {str(e)[:80]}")
    return report


def _print_verdicts(results):
    """Print the per-config verdict table and return (n_h2like, n_glike)."""
    print(f"\n{'Variant':<45} {'G-MSE':>9} {'Char':>10} {'Stab':>6} {'Anti':>10}")
    print("-" * 85)
    n_h2like = n_glike = 0
    for r in results:
        name = str(r.get('variant', '?'))[:45]
        if r.get('_sweep_status', 'ok') != 'ok':
            print(f"{name:<45} {'N/A':>9} {'TRAIN_ERR':>10}")
            continue
        g_mse = _fmt_num(_noise0_mse(r), '.5f')
        probe = r.get('probe')
        if not isinstance(probe, dict) or 'error' in probe:
            print(f"{name:<45} {g_mse:>9} {'PROBE_ERR':>10}")
            continue
        char = probe.get('character', '?')
        stab = _fmt_num(probe.get('stability_mean'), '.3f')
        anti = f"{probe.get('antipodal_pairs', 0)}/{probe.get('antipodal_max_pairs', 0)}"
        print(f"{name:<45} {g_mse:>9} {char:>10} {stab:>6} {anti:>10}")
        if char == 'H2-LIKE':
            n_h2like += 1
        elif char == 'G-LIKE':
            n_glike += 1
    return n_h2like, n_glike


def _print_interpretation(n_h2like, n_total):
    """Print what the H2-LIKE count means for the sphere-packing hypothesis."""
    print("\n" + "=" * 70)
    print("INTERPRETATION")
    print("=" * 70)
    if n_total == 0:
        # Guard: with no configs, "all produced H2-LIKE" would be vacuously true.
        print(" No configs were run — nothing to interpret.")
    elif n_h2like == n_total:
        print(f" All {n_total} packed-polytope configs produced H2-LIKE batteries.")
        print(" → Sphere-packing hypothesis CONFIRMED.")
        print(" → G-Class is a SYMPTOM of (V, D) geometric frustration,")
        print(" not a battery family in its own right.")
        print(" → Useful (V, D) pairs follow polytope vertex counts:")
        print(" D=3: 4, 6, 8, 12, 20 (Platonic)")
        print(" D=4: 5, 8, 16, 24, 120, 600 (4D regular polytopes)")
        print(" D≥5: most V's work (high-D sphere-packing flexible)")
    elif n_h2like > 0:
        print(f" Mixed: {n_h2like}/{n_total} produced H2-LIKE.")
        print(" → Hypothesis partially supported but more nuanced.")
        print(" → Some packed-polytope V's work, others don't.")
    else:
        print(" No H2-LIKE batteries produced.")
        print(" → Sphere-packing hypothesis FALSIFIED.")
        print(" → G-Class behavior has a different cause.")


def _plot_summary(results):
    """Best-effort bar chart of stability / antipodal fraction / utilization per probed config.

    Saved to SUMMARY_PLOT; each x-tick label carries the config's verdict.
    Plotting errors are reported and swallowed so they cannot lose the sweep.
    """
    probed = [r for r in results
              if isinstance(r.get('probe'), dict) and 'character' in r['probe']]
    if not probed:
        print("Summary plot: skipped (no successful probes)")
        return
    metrics = ('stability_mean', 'antipodal_frac', 'utilization')
    width = 0.8 / len(metrics)
    x = np.arange(len(probed))
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(max(6.0, 2.5 * len(probed)), 4.5))
        for j, key in enumerate(metrics):
            heights = [float(r['probe'].get(key, float('nan'))) for r in probed]
            ax.bar(x + (j - (len(metrics) - 1) / 2) * width, heights, width, label=key)
        # Stability threshold used by classify_character for H2-LIKE.
        ax.axhline(0.85, color='gray', linestyle='--', linewidth=0.8)
        ax.set_xticks(x)
        ax.set_xticklabels(
            [f"{str(r.get('variant', '?'))[:30]}\n{r['probe']['character']}" for r in probed],
            fontsize=8)
        ax.set_ylabel('metric value')
        ax.set_title('Phase R — probe metrics and verdicts')
        ax.legend(fontsize=8)
        fig.tight_layout()
        fig.savefig(SUMMARY_PLOT, dpi=120)
        print(f"Summary plot: {SUMMARY_PLOT}")
    except Exception as e:
        print(f"Summary plot FAILED: {type(e).__name__}: {str(e)[:80]}")
    finally:
        if fig is not None:
            plt.close(fig)


def run_sweep_with_probes():
    """Train every Phase R config, probe each trained model, and report verdicts.

    For each config from ``get_phaseR_configs()``, this trains the model into
    ``OUTPUT_ROOT / variant``, probes its checkpoint geometry, and rewrites the
    aggregate JSON at ``AGGREGATE_PATH``, so earlier results survive a crash.
    Training and probe failures are recorded in that config's report instead
    of aborting the sweep.

    After all configs, it prints a verdict table and an interpretation of the
    H2-LIKE count, then saves a bar chart to ``SUMMARY_PLOT``.

    Returns:
        list[dict]: one report per config, in config order.
    """
    configs = get_phaseR_configs()
    print(f"Phase R: {len(configs)} packed-polytope test configs")
    print(f"Output: {OUTPUT_ROOT}\n")
    print("Predicted: each config produces H2-LIKE static rows because")
    print("(V, D) matches a natural sphere polytope vertex count.\n")
    print("Config lineup:")
    for cfg in configs:
        ov = cfg.get('overrides', {})
        print(f" {cfg['variant']:<45} V={ov.get('V', '?')} D={ov.get('D', '?')}")
    print()

    results = []
    sweep_t0 = time.time()
    for i, cfg in enumerate(configs):
        print(f"[{i+1}/{len(configs)}] {cfg['variant']}")
        config_output_dir = OUTPUT_ROOT / cfg['variant']
        config_output_dir.mkdir(parents=True, exist_ok=True)
        t0 = time.time()
        report = _train_and_probe(cfg, config_output_dir)
        report['variant'] = cfg['variant']
        report['wallclock_outer_s'] = time.time() - t0
        results.append(report)
        # Rewritten after every config so partial results survive a crash.
        with open(AGGREGATE_PATH, 'w') as f:
            json.dump(results, f, indent=2, default=str)
        print()

    # ── Verdict summary ──
    print("=" * 70)
    print("PHASE R RESULTS — sphere-packing hypothesis test")
    print("=" * 70)
    n_h2like, n_glike = _print_verdicts(results)
    print(f"\n H2-LIKE: {n_h2like}/{len(results)}")
    print(f" G-LIKE: {n_glike}/{len(results)}")
    _print_interpretation(n_h2like, len(results))
    _plot_summary(results)

    total = time.time() - sweep_t0
    print(f"\nTotal time: {total/60:.1f} min")
    print(f"Aggregate: {AGGREGATE_PATH}")
    return results
if __name__ == '__main__':
    # Script entry point: run the full Phase R sweep. The returned per-config
    # reports stay bound to `results` (presumably for interactive inspection
    # after running this as a notebook cell/script — confirm).
    results = run_sweep_with_probes()