""" cell_g_class_probe_v3.py — three-way geometric probe Tests the same geometric battery of metrics on three batteries: H2a: Q-rank02 (D=4, V=32, 40K params, 1000 batches Adam) G-Cand: Q-rank09 (D=3, V=32, 29K params, 1000 batches Adam) h2-64: single-noise gaussian battery (D=8, V=64, 57K params, 10 epochs) Key question: is the antipodal+rotational structure found in G-Cand a property of D=3 specifically, or a property of LOW-band attractors at ANY D? h2-64 has D=8 which sits in LOW band naturally (CV ~0.21). Predicted outcomes: - h2-64 looks like H2 (uniform sphere, stable rows): G-class is D=3-specific - h2-64 looks like G (antipodal pairs, rotating frame): G-class is the universal LOW-band character; H2a is the OUTLIER for being so static - h2-64 looks like neither (some third pattern): D=8 has its own geometric character we haven't seen yet Loading h2-64 from `loaded` if defined in session, else fetches from HF. """ import json import math import sys from pathlib import Path import numpy as np import torch import torch.nn.functional as F import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D # noqa from sklearn.cluster import KMeans from sklearn.metrics import silhouette_score CKPT_DIR = Path("/content/phaseQ_reports") RANK09_CKPT = CKPT_DIR / "Q_rank09_h64_V32_D3_dp0_nx0_adam" / "epoch_1_checkpoint.pt" RANK02_CKPT = CKPT_DIR / "Q_rank02_h64_V32_D4_dp0_nx0_adam" / "epoch_1_checkpoint.pt" OUTPUT_PLOT = CKPT_DIR / "g_class_probe_v3.png" OUTPUT_JSON = CKPT_DIR / "g_class_probe_v3.json" # ════════════════════════════════════════════════════════════════════ # Loading # ════════════════════════════════════════════════════════════════════ def load_qsweep_model(variant_str, ckpt_path): """Load Q-sweep model (rank02 or rank09).""" cfgs = get_phaseQ_configs() cfg_dict = next(c for c in cfgs if variant_str in c['variant']) cfg = build_run_config(cfg_dict) overrides = cfg_dict['overrides'] model = PatchSVAE_F_Ablation( matrix_v=cfg.matrix_v, D=cfg.D, 
patch_size=cfg.patch_size, hidden=cfg.hidden, depth=cfg.depth, n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads, max_alpha=overrides.get('max_alpha', cfg.max_alpha), alpha_init=cfg.alpha_init, activation=overrides.get('activation', 'gelu'), row_norm=overrides.get('row_norm', 'sphere'), svd_mode=overrides.get('svd', 'fp64'), linear_readout=overrides.get('linear_readout', False), match_params=overrides.get('match_params', True), init_scheme=overrides.get('init', 'orthogonal'), ) ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False) state_dict = ( ckpt.get('model_state') or ckpt.get('model_state_dict') or ckpt.get('state_dict') or ckpt ) model.load_state_dict(state_dict) model.eval() return model, cfg def load_h2_64_battery(battery_idx=0, phase='final'): """Get one battery from the h2-64 array. Tries `loaded` from globals first (already in Colab session), falls back to AutoModel.from_pretrained. Returns (bank_module, V, D, patch_size, img_size). """ array_model = globals().get('loaded') if array_model is None: print(f" `loaded` not found, fetching from HF...") # Importing geolip_svae.arrays auto-registers BatteryArrayConfig # with HF Auto* — without this, model_type='battery_array' is unknown. 
import geolip_svae.arrays # noqa: F401 from transformers import AutoModel array_model = AutoModel.from_pretrained( "AbstractPhil/geolip-svae-h2-64") print(f" Loaded h2-64 from HF") else: print(f" Using `loaded` from global session") # Get the specific battery bank bank = array_model.bank(battery_idx, phase) bank.eval() # Get architecture from config cfg_dict = array_model.config.batteries[battery_idx] print(f" Battery {battery_idx} ({phase}): " f"subgroup={cfg_dict.get('subgroup')}, " f"variant={cfg_dict.get('variant')}, " f"noise_types={cfg_dict.get('noise_types')}") # Architecture is uniform across h2-64 batteries V = 64 D = 8 patch_size = 2 img_size = 64 return bank, V, D, patch_size, img_size # ════════════════════════════════════════════════════════════════════ # Collect M rows # ════════════════════════════════════════════════════════════════════ def collect_per_sample_M(model, V, D, patch_size, img_size, n_batches=8, batch_size=64, is_h2_64_bank=False): """Collect [n_samples, V, D] M tensors from gaussian inputs.""" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = model.to(device) ds = OmegaNoiseDataset( size=n_batches * batch_size, img_size=img_size, allowed_types=[0]) # gaussian loader = torch.utils.data.DataLoader( ds, batch_size=batch_size, shuffle=False) all_M = [] with torch.no_grad(): for imgs, _ in loader: imgs = imgs.to(device) out = model(imgs) # Both PatchSVAE and h2-64 banks return dict with 'svd' or # similar — the M tensor is at out['svd']['M'][:, 0] if 'svd' in out and 'M' in out['svd']: M_patch0 = out['svd']['M'][:, 0] # [B, V, D] elif 'M' in out: M_patch0 = out['M'][:, 0] else: # Fall back: try to access via internal encode_patches from johanna_F_trainer import extract_patches patches = extract_patches(imgs, patch_size) enc = model.encode_patches(patches) M_patch0 = enc['M'][:, 0] all_M.append(M_patch0.cpu()) return torch.cat(all_M, dim=0).numpy() # [n_samples, V, D] # 
# ════════════════════════════════════════════════════════════════════
# Tests (carry over from v2)
# ════════════════════════════════════════════════════════════════════
#
# Each test_* helper takes `all_M`, assumed to be a [n_samples, V, D] array
# (as returned by collect_per_sample_M): rows along axis 1, row dimension
# along axis 2. Each returns a JSON-serializable dict of metrics.


def test_sphere_norm(all_M):
    """Check whether every M row has unit L2 norm.

    Returns the min/max/mean/std of the row norms and ``sphere_normed``,
    which is True when the mean norm is within 0.05 of 1 and std < 0.05.
    """
    row_norms = np.linalg.norm(all_M, axis=2)
    return {
        'min': float(row_norms.min()),
        'max': float(row_norms.max()),
        'mean': float(row_norms.mean()),
        'std': float(row_norms.std()),
        'sphere_normed': bool(
            abs(row_norms.mean() - 1.0) < 0.05 and row_norms.std() < 0.05),
    }


def test_row_stability(all_M):
    """Measure how consistently each row points the same way across samples.

    Each row is averaged over samples and the norm of that mean is taken.
    For unit rows the value is ~1 when the row has the same direction in
    every sample, and it falls toward 0 as directions spread or cancel.
    Returns summary stats plus the per-row values in ``mean_dir_norms``.
    """
    mean_dirs = all_M.mean(axis=0)
    mean_dir_norms = np.linalg.norm(mean_dirs, axis=1)
    return {
        'mean': float(mean_dir_norms.mean()),
        'min': float(mean_dir_norms.min()),
        'max': float(mean_dir_norms.max()),
        'std': float(mean_dir_norms.std()),
        'mean_dir_norms': mean_dir_norms.tolist(),
    }


def test_per_sample_clustering(all_M, k_test=5, n_samples=20):
    """Score a k-means clustering of each sample's rows.

    For each of the first ``n_samples`` samples, the V rows are clustered
    into ``k_test`` groups (KMeans, fixed seed) and scored with the
    silhouette coefficient. Samples whose clustering or scoring raises are
    skipped and counted in ``n_failed``. ``mean``/``std`` are None when no
    sample could be scored.
    """
    silhouettes = []
    n_failed = 0
    for i in range(min(n_samples, all_M.shape[0])):
        M = all_M[i]
        try:
            km = KMeans(n_clusters=k_test, n_init=10, random_state=42)
            labels = km.fit_predict(M)
            if len(set(labels)) >= 2:
                silhouettes.append(silhouette_score(M, labels))
        except Exception:
            # Best-effort metric: keep going, but record the skip rather than
            # hiding it (e.g. too few rows for k_test clusters).
            n_failed += 1
    silhouettes = np.array(silhouettes)
    return {
        'k_tested': k_test,
        'mean': float(silhouettes.mean()) if len(silhouettes) else None,
        'std': float(silhouettes.std()) if len(silhouettes) else None,
        'silhouettes_per_sample': silhouettes.tolist(),
        'n_failed': n_failed,
    }


def test_angular_distribution(all_M):
    """Compute the distribution of angles between M rows pooled over all samples.

    All rows are normalized, and a fixed-seed subset of up to 500 of them is
    drawn. The angle is computed for every pair in that subset. The result
    reports the mean and median angle, plus the fraction of pairs near 0 and
    near π (within 0.5 rad) and near π/2 (within 0.3 rad).
    ``pairwise_angles_subset`` holds up to 200 of those pair angles, drawn
    uniformly at random (for plotting).
    """
    all_rows = all_M.reshape(-1, all_M.shape[-1])
    norms = np.linalg.norm(all_rows, axis=1, keepdims=True)
    unit_rows = all_rows / np.clip(norms, 1e-12, None)

    rng = np.random.RandomState(42)
    n_subset = min(500, unit_rows.shape[0])
    idx = rng.choice(unit_rows.shape[0], n_subset, replace=False)
    subset = unit_rows[idx]
    cosines = subset @ subset.T
    triu_idx = np.triu_indices(n_subset, k=1)
    pairwise_cos = cosines[triu_idx]
    pairwise_angles = np.arccos(np.clip(pairwise_cos, -1, 1))

    # Sample the plotting subset uniformly over all pairs. The first 200
    # upper-triangle entries are the pairs (0, 1..200), i.e. angles to one
    # anchor row only, which would bias the plotted histogram.
    n_plot = min(200, pairwise_angles.size)
    plot_idx = rng.choice(pairwise_angles.size, n_plot, replace=False)

    return {
        'mean_angle': float(pairwise_angles.mean()),
        'median_angle': float(np.median(pairwise_angles)),
        'fraction_near_zero': float((pairwise_angles < 0.5).mean()),
        'fraction_near_pi': float(
            (pairwise_angles > math.pi - 0.5).mean()),
        'fraction_near_perp': float(
            ((pairwise_angles > math.pi / 2 - 0.3)
             & (pairwise_angles < math.pi / 2 + 0.3)).mean()),
        'pairwise_angles_subset': pairwise_angles[plot_idx].tolist(),
    }


def test_antipodal(all_M):
    """Count rows whose sample-mean direction has a near-opposite partner.

    Each row's mean over samples is normalized, and for every row the most
    negative cosine to any other row is found. Rows with that cosine below
    -0.9 count as having an antipode. ``estimated_pairs`` is that count
    halved; it is an estimate, since two rows may share the same partner.
    """
    mean_dirs = all_M.mean(axis=0)
    norms = np.linalg.norm(mean_dirs, axis=1, keepdims=True)
    unit_dirs = mean_dirs / np.clip(norms, 1e-12, None)
    cosines = unit_dirs @ unit_dirs.T
    np.fill_diagonal(cosines, 1.0)  # exclude self-pairs from the min
    most_anti_cos = cosines.min(axis=1)
    n_pairs = (most_anti_cos < -0.9).sum() // 2
    return {
        'min_cos': float(most_anti_cos.min()),
        'mean_cos': float(most_anti_cos.mean()),
        'fraction_with_antipode': float((most_anti_cos < -0.9).mean()),
        'estimated_pairs': int(n_pairs),
        'max_possible_pairs': all_M.shape[1] // 2,
    }


def test_effective_rank(all_M):
    """Compute the effective rank of the sample-mean M matrix.

    The effective rank is exp(entropy) of the singular values normalized to
    sum to 1. ``utilization`` is the effective rank divided by D, and
    ``top1_share`` is the largest normalized singular value.
    """
    M_avg = all_M.mean(axis=0)
    sv = np.linalg.svd(M_avg, compute_uv=False)
    sv_norm = sv / sv.sum()
    erank = math.exp(-(sv_norm * np.log(sv_norm + 1e-12)).sum())
    return {
        'singular_values': sv.tolist(),
        'normalized_SV': sv_norm.tolist(),
        'effective_rank': float(erank),
        'D': int(all_M.shape[2]),
        'utilization': float(erank / all_M.shape[2]),
        'top1_share': float(sv_norm[0]),
    }


# The helpers above are analysis routines, not pytest tests, despite their
# test_* names. Mark them so pytest never tries to collect them (they would
# fail on a missing `all_M` fixture) if this module is imported by a test file.
for _probe_fn in (test_sphere_norm, test_row_stability,
                  test_per_sample_clustering, test_angular_distribution,
                  test_antipodal, test_effective_rank):
    _probe_fn.__test__ = False
del _probe_fn


def run_all_tests(all_M, label):
    """Run every geometric metric on ``all_M``, print a summary and return it.

    Returns a dict with keys 'sphere_norm', 'stability', 'clustering',
    'angular', 'antipodal' and 'rank', each holding that metric's dict.
    """
    print(f"\n[{label}]")
    print(f" Shape: {all_M.shape}")

    sphere = test_sphere_norm(all_M)
    print(f" Sphere-norm: mean={sphere['mean']:.4f}, "
          f"std={sphere['std']:.4f} → {'YES' if sphere['sphere_normed'] else 'NO'}")

    stability = test_row_stability(all_M)
    print(f" Row stability: mean={stability['mean']:.3f}, "
          f"range=[{stability['min']:.3f}, {stability['max']:.3f}]")

    cluster = test_per_sample_clustering(all_M)
    if cluster['mean'] is not None:
        print(f" Cluster (k=5): silhouette mean={cluster['mean']:.3f}, "
              f"std={cluster['std']:.3f}")
    if cluster['n_failed']:
        print(f" Cluster (k=5): {cluster['n_failed']} sample(s) skipped "
              f"(KMeans/silhouette raised)")

    angular = test_angular_distribution(all_M)
    print(f" Angular: mean={angular['mean_angle']:.3f} "
          f"(uniform=π/2={math.pi/2:.3f})")
    print(f" near-perp: {angular['fraction_near_perp']:.3f}, "
          f"near-π: {angular['fraction_near_pi']:.3f}")

    antipodal = test_antipodal(all_M)
    print(f" Antipodal: {antipodal['estimated_pairs']}/"
          f"{antipodal['max_possible_pairs']} pairs, "
          f"frac with antipode={antipodal['fraction_with_antipode']:.3f}")

    erank = test_effective_rank(all_M)
    print(f" Effective rank: {erank['effective_rank']:.2f} of {erank['D']} "
          f"({erank['utilization']*100:.0f}% utilization)")

    return {
        'sphere_norm': sphere,
        'stability': stability,
        'clustering': cluster,
        'angular': angular,
        'antipodal': antipodal,
        'rank': erank,
    }


# ════════════════════════════════════════════════════════════════════
# Composite character classification
# ════════════════════════════════════════════════════════════════════

def classify_battery_character(results):
    """Label a battery's geometry from its run_all_tests() results.

    The label is chosen by the first matching rule:
      "H2-LIKE ..."  - row stability > 0.85, antipodal fraction < 0.55 and
                       dim utilization > 0.95 (static sphere-solver);
      "G-LIKE ..."   - stability < 0.65 and antipodal fraction > 0.80
                       (rotating antipodal frame);
      "DIFFUSE ..."  - stability < 0.65 and antipodal fraction < 0.55;
      "HYBRID (...)" - anything else, with the two key values embedded.
    """
    stab = results['stability']['mean']
    antipodal_frac = results['antipodal']['fraction_with_antipode']
    rank_util = results['rank']['utilization']

    if stab > 0.85 and antipodal_frac < 0.55 and rank_util > 0.95:
        return "H2-LIKE (static sphere-solver)"
    if stab < 0.65 and antipodal_frac > 0.80:
        return "G-LIKE (rotating antipodal frame)"
    if stab < 0.65 and antipodal_frac < 0.55:
        return "DIFFUSE (low stability, no antipodal structure)"
    return (f"HYBRID (stab={stab:.2f}, antipodal_frac="
            f"{antipodal_frac:.2f})")


# ════════════════════════════════════════════════════════════════════
# Main
# ════════════════════════════════════════════════════════════════════

def main():
    """Load, probe and compare the three batteries; save JSON and a figure.

    Writes OUTPUT_JSON (per-battery metric dicts plus character verdicts)
    and OUTPUT_PLOT (the 3x3 comparison figure), and returns the same dict
    that was written to JSON.
    """
    print("=" * 70)
    print("Loading three batteries for comparative analysis")
    print("=" * 70)

    print("\n[1/3] H2a (Q-rank02, D=4, 1000-batch Adam)")
    h2_model, h2_cfg = load_qsweep_model('rank02', RANK02_CKPT)
    print(f" V={h2_cfg.matrix_v}, D={h2_cfg.D}, "
          f"params={sum(p.numel() for p in h2_model.parameters()):,}")

    print("\n[2/3] G-Class candidate (Q-rank09, D=3, 1000-batch Adam)")
    g_model, g_cfg = load_qsweep_model('rank09', RANK09_CKPT)
    print(f" V={g_cfg.matrix_v}, D={g_cfg.D}, "
          f"params={sum(p.numel() for p in g_model.parameters()):,}")

    print("\n[3/3] h2-64 single-noise gaussian battery (D=8, 10 epochs converged)")
    h264_bank, h264_V, h264_D, h264_ps, h264_img = load_h2_64_battery(
        battery_idx=0, phase='final')
    print(f" V={h264_V}, D={h264_D}, patch_size={h264_ps}, img_size={h264_img}")

    # ════════════════════════════════════════════════════════════════
    # Collect M rows
    # ════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("Collecting M rows (gaussian inputs, 512 samples each)")
    print("=" * 70)

    print("\n H2a...")
    all_M_h2 = collect_per_sample_M(
        h2_model, h2_cfg.matrix_v, h2_cfg.D, h2_cfg.patch_size,
        h2_cfg.img_size)
    print(" G-Cand...")
    all_M_g = collect_per_sample_M(
        g_model, g_cfg.matrix_v, g_cfg.D, g_cfg.patch_size, g_cfg.img_size)
    print(" h2-64 gaussian...")
    all_M_h264 = collect_per_sample_M(
        h264_bank, h264_V, h264_D, h264_ps, h264_img, is_h2_64_bank=True)

    # ════════════════════════════════════════════════════════════════
    # Run tests on each
    # ════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("GEOMETRIC ANALYSIS")
    print("=" * 70)

    results_h2 = run_all_tests(all_M_h2, "H2a (D=4, 1000-batch Adam)")
    results_g = run_all_tests(all_M_g, "G-Cand (D=3, 1000-batch Adam)")
    results_h264 = run_all_tests(
        all_M_h264, "h2-64 gaussian (D=8, 10 epochs)")

    # ════════════════════════════════════════════════════════════════
    # Side-by-side comparison
    # ════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("THREE-WAY COMPARISON")
    print("=" * 70)

    headers = f"{'Metric':<32} {'H2a (D=4)':>12} {'G-Cand (D=3)':>14} {'h2-64 (D=8)':>14}"
    print(f"\n {headers}")
    print(" " + "-" * len(headers))

    rows = [
        ('Effective rank',
         results_h2['rank']['effective_rank'],
         results_g['rank']['effective_rank'],
         results_h264['rank']['effective_rank'], '.2f'),
        ('Dim utilization (%)',
         results_h2['rank']['utilization'] * 100,
         results_g['rank']['utilization'] * 100,
         results_h264['rank']['utilization'] * 100, '.0f'),
        ('Row stability',
         results_h2['stability']['mean'],
         results_g['stability']['mean'],
         results_h264['stability']['mean'], '.3f'),
        ('Per-sample silhouette (k=5)',
         results_h2['clustering']['mean'] or 0,
         results_g['clustering']['mean'] or 0,
         results_h264['clustering']['mean'] or 0, '.3f'),
        ('Mean pairwise angle (rad)',
         results_h2['angular']['mean_angle'],
         results_g['angular']['mean_angle'],
         results_h264['angular']['mean_angle'], '.3f'),
        ('Antipodal pair fraction',
         results_h2['antipodal']['fraction_with_antipode'],
         results_g['antipodal']['fraction_with_antipode'],
         results_h264['antipodal']['fraction_with_antipode'], '.3f'),
        ('Estimated antipodal pairs',
         results_h2['antipodal']['estimated_pairs'],
         results_g['antipodal']['estimated_pairs'],
         results_h264['antipodal']['estimated_pairs'], 'd'),
    ]
    for name, h2v, gv, h264v, fmt in rows:
        # The nested spec covers 'd' as well (">12d" etc.).
        print(f" {name:<32} {h2v:>12{fmt}} {gv:>14{fmt}} {h264v:>14{fmt}}")

    print()
    char_h2 = classify_battery_character(results_h2)
    char_g = classify_battery_character(results_g)
    char_h264 = classify_battery_character(results_h264)
    print(" Character verdict:")
    print(f" H2a: {char_h2}")
    print(f" G-Cand: {char_g}")
    print(f" h2-64: {char_h264}")

    # Headline conclusion
    print("\n" + "=" * 70)
    print("CONCLUSION")
    print("=" * 70)
    if "G-LIKE" in char_h264:
        print(" h2-64 (D=8, fully converged) shows G-CLASS character.")
        print(" → The antipodal+rotational structure is NOT D=3-specific.")
        print(" → It's the LOW-band attractor's natural geometry.")
        print(" → H2a (D=4 at HIGH band) is the OUTLIER — its sphere-solver")
        print(" rigidity is HIGH-band-specific, not the universal pattern.")
    elif "H2-LIKE" in char_h264:
        print(" h2-64 (D=8, fully converged) shows H2 sphere-solver character.")
        print(" → G-Class at D=3 is genuinely different from sphere-solvers.")
        print(" → D=3 specifically can't form a stable static 32-row arrangement,")
        print(" so it falls into the rotating-antipodal regime.")
        print(" → Higher D recovers static sphere-solver behavior even in LOW band.")
    elif "HYBRID" in char_h264:
        print(" h2-64 (D=8) shows mixed character — partial G-like features.")
        print(" → Possible spectrum: HIGH-band → static sphere (H2),")
        print(" LOW-band → progressively more antipodal as D decreases.")
        print(" → D=8 sits in transition; D=3 is fully G-class; D=4 HIGH is fully H2.")
    elif "DIFFUSE" in char_h264:
        # DIFFUSE means low stability WITHOUT antipodal pairing, so it is not
        # "partially G-like": it is the "looks like neither" outcome.
        print(" h2-64 (D=8) shows DIFFUSE character — low row stability "
              "without antipodal pairing.")
        print(" → Neither the H2 sphere-solver nor the G-class pattern.")
        print(" → D=8 may have its own geometric character not seen yet.")

    all_results = {
        'h2a': results_h2,
        'g_class_candidate': results_g,
        'h2_64_gaussian': results_h264,
        'characters': {
            'h2a': char_h2,
            'g_class': char_g,
            'h2_64': char_h264,
        },
    }
    with open(OUTPUT_JSON, 'w') as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f"\n Saved: {OUTPUT_JSON}")

    plot_three_way(all_M_h2, all_M_g, all_M_h264,
                   results_h2, results_g, results_h264, OUTPUT_PLOT)
    print(f" Saved: {OUTPUT_PLOT}")

    return all_results


def plot_three_way(M_h2, M_g, M_h264, r_h2, r_g, r_h264, output_path):
    """Save and show a 9-panel (3x3) comparison figure.

    There is one column per battery (H2a, G-Cand, h2-64). The rows are:
      1. a 3D scatter of the first sample's M rows (first 3 coordinates),
         colored by row index;
      2. per-row stability (mean-direction norm), sorted descending;
      3. a histogram of the subsampled pairwise row angles, with π/2 marked.
    The figure is written to ``output_path``, shown, and then closed.
    """
    # (M, results, color, scatter title, short label) per column.
    columns = [
        (M_h2, r_h2, 'blue',
         'H2a (D=4) — single sample\nrows projected to first 3 dims', 'H2a'),
        (M_g, r_g, 'red',
         'G-Cand (D=3) — single sample\nfull native dims', 'G-Cand'),
        (M_h264, r_h264, 'green',
         'h2-64 gaussian (D=8) — single sample\nrows projected to first 3 dims',
         'h2-64'),
    ]

    fig = plt.figure(figsize=(18, 14))
    for col, (M, r, color, scatter_title, short) in enumerate(columns, start=1):
        # Row 1: single-sample row scatter (first 3 coordinates).
        ax = fig.add_subplot(3, 3, col, projection='3d')
        s = M[0]
        ax.scatter(s[:, 0], s[:, 1], s[:, 2], c=np.arange(len(s)),
                   cmap='viridis', s=80, edgecolors='black', linewidths=0.5)
        ax.set_title(scatter_title)

        # Row 2: per-row stability, sorted descending.
        ax = fig.add_subplot(3, 3, 3 + col)
        ax.plot(sorted(r['stability']['mean_dir_norms'], reverse=True),
                'o-', color=color, markersize=4)
        ax.set_title(f"{short} row stability\nmean={r['stability']['mean']:.3f}")
        ax.set_xlabel('Row index (sorted)')
        ax.set_ylabel('Mean direction norm')
        ax.set_ylim([0, 1.05])
        ax.grid(alpha=0.3)

        # Row 3: pairwise angle distribution.
        ax = fig.add_subplot(3, 3, 6 + col)
        ax.hist(r['angular']['pairwise_angles_subset'], bins=30,
                color=color, alpha=0.7, density=True)
        ax.axvline(math.pi / 2, color='black', linestyle='--', alpha=0.5)
        ax.set_title(f"{short} pairwise angles\nmean={r['angular']['mean_angle']:.3f}")
        ax.set_xlabel('Angle (radians)')

    plt.tight_layout()
    plt.savefig(output_path, dpi=120, bbox_inches='tight')
    plt.show()
    plt.close(fig)


if __name__ == '__main__':
    results = main()