"""
cell_p_class_probe.py — geometric structure probe for P-Class batteries

Loads P-rank09 (h64_V32_D3_dp0_nx0_adam, MSE 0.028, CV 0.03) and asks what
its 32 row vectors in 3D space actually look like.

Four hypothesis tests:

1. RANK STRUCTURE — SVD on the 32×3 row matrix M.
   - Polynomial basis: rank ≤ 2 (Vandermonde collapses)
   - Trig basis: rank = 2 or 3 with specific singular value ratio
   - Cluster: rank 3, all SVs comparable
   - Collapsed: rank 1, one dominant SV

2. PARAMETRIC ORDERING — Order rows by their projection onto the first
   principal axis (PC1). If rows form a smooth curve when ordered, we're
   seeing a parametric structure (polynomial, trig, etc). If they're
   scattered with no order, it's clusters.
   Metric: size of consecutive PC2/PC3 steps, relative to their range,
   after sorting along PC1.

3. POLYNOMIAL FIT TEST — Fit PC2 and PC3 as functions of PC1 using
   polynomials of degree 1-4 and the trigonometric basis
   [1, sin(πx), cos(πx), sin(2πx), cos(2πx)]. If a polynomial reaches
   R² > 0.95 on PC2, the polynomial hypothesis is confirmed and the lowest
   such degree is reported.

4. CLUSTER COUNT — k-means with k = 2..8 on the 32 rows. If silhouette score
   is high at small k, it's clustered. If silhouette is low for all k, the
   rows are spread continuously (consistent with parametric).

Outputs:
  - Console verdict for each hypothesis
  - /content/phaseQ_reports/p_rank09_probe.png — 4-panel diagnostic plot
  - /content/phaseQ_reports/p_rank09_probe.json — all numerical results

NOTE(review): get_phaseQ_configs, build_run_config, PatchSVAE_F_Ablation and
OmegaNoiseDataset are used but neither defined nor imported here —
presumably they come from earlier notebook cells; confirm before running
this file standalone.
"""
import json
import math
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa  (registers '3d' projection on old matplotlib)
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

CKPT_DIR = Path("/content/phaseQ_reports")
RANK09_CKPT = CKPT_DIR / "Q_rank09_h64_V32_D3_dp0_nx0_adam" / "epoch_1_checkpoint.pt"
OUTPUT_PLOT = CKPT_DIR / "p_rank09_probe.png"
OUTPUT_JSON = CKPT_DIR / "p_rank09_probe.json"

# R² a fit must exceed to count as a clean parametric (polynomial/trig) fit.
FIT_R2_THRESHOLD = 0.95


def load_rank09():
    """Reconstruct the P-rank09 model and load its trained weights.

    Returns:
        ``(model, cfg)``: the model in eval mode with its checkpoint weights
        loaded (on CPU), and the run config returned by ``build_run_config``.

    Raises:
        LookupError: if no phase-Q config has 'rank09' in its variant name.
    """
    cfgs = get_phaseQ_configs()
    rank09_cfg = next((c for c in cfgs if 'rank09' in c['variant']), None)
    if rank09_cfg is None:
        # A bare next() would surface as an unexplained StopIteration.
        raise LookupError("no phase-Q config with 'rank09' in its variant name")
    cfg = build_run_config(rank09_cfg)
    overrides = rank09_cfg['overrides']

    # Per-variant overrides take precedence over these fallback values.
    model = PatchSVAE_F_Ablation(
        matrix_v=cfg.matrix_v,
        D=cfg.D,
        patch_size=cfg.patch_size,
        hidden=cfg.hidden,
        depth=cfg.depth,
        n_cross_layers=cfg.n_cross_layers,
        n_heads=cfg.n_heads,
        max_alpha=overrides.get('max_alpha', cfg.max_alpha),
        alpha_init=cfg.alpha_init,
        activation=overrides.get('activation', 'gelu'),
        row_norm=overrides.get('row_norm', 'sphere'),
        svd_mode=overrides.get('svd', 'fp64'),
        linear_readout=overrides.get('linear_readout', False),
        match_params=overrides.get('match_params', True),
        init_scheme=overrides.get('init', 'orthogonal'),
    )

    # weights_only=False unpickles arbitrary Python objects: only point this
    # at checkpoints produced by our own trainer, never at untrusted files.
    ckpt = torch.load(RANK09_CKPT, map_location='cpu', weights_only=False)
    # Trainer saves model weights under 'model_state'; the older
    # 'model_state_dict' / 'state_dict' fallbacks are kept for compatibility.
    # If none is present (or all are empty) the checkpoint itself is treated
    # as the state dict.
    state_dict = (
        ckpt.get('model_state')
        or ckpt.get('model_state_dict')
        or ckpt.get('state_dict')
        or ckpt
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model, cfg


def collect_rows(model, cfg, n_batches=8, batch_size=64):
    """Run gaussian noise through the model and collect patch-0 row matrices.

    Only patch position 0 is kept, so every sample contributes one row matrix
    from the same canonical position.

    Args:
        model: trained model. It is moved to CUDA when available
            (``nn.Module.to`` moves the module in place, so the caller's
            model moves too).
        cfg: run config; only ``cfg.img_size`` is read.
        n_batches: number of batches to draw.
        batch_size: samples per batch; the dataset holds
            ``n_batches * batch_size`` samples.

    Returns:
        CPU tensor of stacked row matrices, presumably ``[n_samples, V, D]``
        — this assumes ``out['svd']['M']`` is ``[B, N_patches, V, D]``,
        which is not checked here.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    ds = OmegaNoiseDataset(
        size=n_batches * batch_size,
        img_size=cfg.img_size,
        allowed_types=[0],  # gaussian
    )
    loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False)

    all_M = []  # M from patch 0 of every sample
    with torch.no_grad():
        for imgs, _ in loader:
            out = model(imgs.to(device))
            # M shape: [B, N_patches, V, D]
            all_M.append(out['svd']['M'][:, 0].cpu())  # [B, V, D]
    return torch.cat(all_M, dim=0)  # [n_samples, V, D]


# ════════════════════════════════════════════════════════════════════
# Hypothesis tests
# ════════════════════════════════════════════════════════════════════

def _pc_sorted_projection(M_avg):
    """Project the rows of ``M_avg`` onto its right singular vectors.

    Returns:
        ``(proj, sort_idx)`` where ``proj = M_avg @ Vt.T`` (column j is each
        row's coordinate along principal axis j) and ``sort_idx`` orders the
        rows by ascending PC1 coordinate.

    NOTE(review): the SVD is taken on the uncentered matrix, so PC1 is the
    top axis of the second-moment matrix, not of the covariance. If the rows
    share a large common mean direction, PC1 tracks that direction rather
    than the axis of largest spread; consider subtracting the row mean
    before the SVD if the ordering looks wrong.
    """
    _, _, Vt = np.linalg.svd(M_avg, full_matrices=False)
    proj = M_avg @ Vt.T
    return proj, np.argsort(proj[:, 0])


def test_rank_structure(M_avg):
    """Test 1: SVD on the canonical row matrix.

    Args:
        M_avg: averaged ``[V, D]`` row matrix (32×3 for P-rank09).

    Returns:
        dict with the singular values, their shares of the total
        (``normalized_SV``), the entropy-based effective rank
        ``exp(H(S / ΣS))``, the top-1 / top-2 shares and a verdict string.

    Predictions:
        Polynomial Vandermonde: top-1 SV dominates, rank≈1-2
        Trig basis: balanced top-2 SVs, small 3rd
        Sphere uniform (H2): ~equal SVs, full rank
        Cluster: depends on cluster geometry
    """
    S = np.linalg.svd(M_avg, compute_uv=False)
    S_norm = S / S.sum()
    # Entropy of the SV distribution; the epsilon keeps log() finite at 0.
    erank = math.exp(-(S_norm * np.log(S_norm + 1e-12)).sum())
    top1 = float(S_norm[0])
    top2 = float(S_norm[:2].sum())

    if top1 > 0.85:
        verdict = 'rank-1 (collapsed/aligned)'
    elif top2 > 0.92:
        verdict = 'rank-2 (planar — could be polynomial or trig)'
    elif S_norm.std() < 0.05:
        verdict = 'rank-3 (full, balanced)'
    else:
        verdict = 'rank-3 (full, imbalanced)'

    return {
        'singular_values': S.tolist(),
        'normalized_SV': S_norm.tolist(),
        'effective_rank': erank,
        'top1_share': top1,
        'top2_share': top2,
        'verdict': verdict,
    }


def test_parametric_ordering(M_avg):
    """Test 2: sort rows along PC1 and measure how smoothly PC2/PC3 vary.

    If rows lie on a smooth parametric curve (polynomial, trig), sorting by
    the PC1 projection should make consecutive PC2/PC3 values change in
    small steps.

    For each axis k in {PC2, PC3}:
        smoothness_k = 1 - mean|Δ_k| / range_k
    where Δ_k are the differences between consecutive rows after the PC1
    sort. 1 means no jitter; lower values mean larger jumps.

    NOTE(review): if PC2 values are in random order with respect to PC1,
    mean|Δ| is roughly a third of the range for uniformly spread values, so
    smoothness lands near 0.67 — above the 0.5 'partial structure'
    threshold. The 'scattered' verdict therefore needs worse-than-random
    ordering; consider calibrating against a shuffled-order baseline.
    """
    proj, sort_idx = _pc_sorted_projection(M_avg)
    sorted_proj = proj[sort_idx]

    def _smoothness(col):
        # Mean step between consecutive sorted rows, relative to the spread.
        values = sorted_proj[:, col]
        value_range = values.max() - values.min()
        mean_step = np.abs(np.diff(values)).mean()
        return 1.0 - mean_step / (value_range + 1e-8), value_range

    smoothness_pc2, range_pc2 = _smoothness(1)
    smoothness_pc3, range_pc3 = _smoothness(2)
    worst = min(smoothness_pc2, smoothness_pc3)

    if worst > 0.85:
        verdict = 'smooth parametric curve'
    elif worst > 0.5:
        verdict = 'partial structure'
    else:
        verdict = 'scattered (cluster-like)'

    return {
        'sort_order': sort_idx.tolist(),
        'smoothness_pc2': float(smoothness_pc2),
        'smoothness_pc3': float(smoothness_pc3),
        'pc1_range': float(proj[:, 0].max() - proj[:, 0].min()),
        'pc2_range': float(range_pc2),
        'pc3_range': float(range_pc3),
        'verdict': verdict,
    }


def test_polynomial_fit(M_avg, r2_threshold=FIT_R2_THRESHOLD):
    """Test 3: fit PC2/PC3 as functions of PC1 with polynomial and trig bases.

    Rows are ordered by PC1 projection, and PC1 is rescaled to [-1, 1].
    Polynomials of degree 1-4 are fit with ``np.polyfit``. A two-harmonic
    trigonometric basis [1, sin(πx), cos(πx), sin(2πx), cos(2πx)] is fit by
    least squares.

    Args:
        M_avg: averaged ``[V, D]`` row matrix (needs D >= 3).
        r2_threshold: R² on PC2 that a fit must exceed to be called clean.

    Returns:
        dict with per-degree R², the trig R² and PC2 coefficients, the
        selected polynomial degree and its R², and a verdict string.
        ``best_poly_degree`` is the LOWEST degree whose PC2 R² exceeds
        ``r2_threshold``, or the highest-R² degree if none does.
    """
    proj, sort_idx = _pc_sorted_projection(M_avg)
    x = proj[sort_idx, 0]
    y2 = proj[sort_idx, 1]
    y3 = proj[sort_idx, 2]

    # Normalize x to [-1, 1] for stable polyfit
    x_norm = 2 * (x - x.min()) / (x.max() - x.min() + 1e-8) - 1

    def r2(y_true, y_pred):
        ss_res = ((y_true - y_pred) ** 2).sum()
        ss_tot = ((y_true - y_true.mean()) ** 2).sum()
        return 1 - ss_res / (ss_tot + 1e-12)

    degrees = (1, 2, 3, 4)
    poly_results = {}
    for deg in degrees:
        pred2 = np.polyval(np.polyfit(x_norm, y2, deg), x_norm)
        pred3 = np.polyval(np.polyfit(x_norm, y3, deg), x_norm)
        poly_results[f'degree_{deg}'] = {
            'r2_pc2': float(r2(y2, pred2)),
            'r2_pc3': float(r2(y3, pred3)),
        }

    # Trigonometric fit: y = a + b·sin(πx) + c·cos(πx) + d·sin(2πx) + e·cos(2πx)
    B = np.column_stack([
        np.ones_like(x_norm),
        np.sin(np.pi * x_norm),
        np.cos(np.pi * x_norm),
        np.sin(2 * np.pi * x_norm),
        np.cos(2 * np.pi * x_norm),
    ])
    coef2_t, _, _, _ = np.linalg.lstsq(B, y2, rcond=None)
    coef3_t, _, _, _ = np.linalg.lstsq(B, y3, rcond=None)
    trig_r2_pc2 = r2(y2, B @ coef2_t)
    trig_r2_pc3 = r2(y3, B @ coef3_t)

    # In-sample R² of nested least-squares fits never decreases with degree,
    # so argmax(R²) would (up to float noise) always select degree 4. Report
    # the most parsimonious degree that clears the threshold instead.
    def _pc2_r2(d):
        return poly_results[f'degree_{d}']['r2_pc2']

    passing = [d for d in degrees if _pc2_r2(d) > r2_threshold]
    best_poly_deg = passing[0] if passing else max(degrees, key=_pc2_r2)
    best_poly_r2 = _pc2_r2(best_poly_deg)

    if best_poly_r2 > r2_threshold:
        verdict = f'polynomial degree {best_poly_deg} (R²={best_poly_r2:.3f})'
    elif trig_r2_pc2 > r2_threshold:
        verdict = f'trigonometric (R²={trig_r2_pc2:.3f})'
    else:
        verdict = (f'no clean parametric fit (best poly R²={best_poly_r2:.3f}, '
                   f'trig R²={trig_r2_pc2:.3f})')

    return {
        'polynomial': poly_results,
        'trigonometric': {
            'r2_pc2': float(trig_r2_pc2),
            'r2_pc3': float(trig_r2_pc3),
            'coefs_pc2': coef2_t.tolist(),
        },
        'best_poly_degree': best_poly_deg,
        'best_poly_r2': float(best_poly_r2),
        'verdict': verdict,
    }


def test_cluster_structure(M_avg):
    """Test 4: k-means + silhouette across k = 2..8.

    High silhouette at small k → genuine clusters.
    Low silhouette across all k → continuous spread (consistent with
    parametric structure).

    Returns:
        dict with per-k silhouette and inertia, the best k and its
        silhouette, and a verdict string. If no k could be scored,
        ``best_k`` is None and ``best_silhouette`` is -1.0.
    """
    n_rows = M_avg.shape[0]
    results = {}
    best_k = None
    best_score = -1.0
    # silhouette_score requires 2 <= n_labels <= n_samples - 1, so k stops
    # at n_rows - 1.
    for k in range(2, min(9, n_rows)):
        km = KMeans(n_clusters=k, n_init=10, random_state=42)
        labels = km.fit_predict(M_avg)
        if len(set(labels)) < 2:
            continue  # silhouette is undefined for a single cluster
        score = silhouette_score(M_avg, labels)
        results[f'k={k}'] = {
            'silhouette': float(score),
            'inertia': float(km.inertia_),
        }
        if score > best_score:
            best_score = score
            best_k = k

    if best_score > 0.5:
        verdict = f'strong clusters (k={best_k}, silhouette={best_score:.3f})'
    elif best_score > 0.25:
        verdict = f'weak clusters (k={best_k}, silhouette={best_score:.3f})'
    else:
        verdict = (f'no clear clusters (best silhouette={best_score:.3f}) — '
                   f'consistent with continuous structure')

    return {
        'per_k': results,
        'best_k': best_k,
        'best_silhouette': float(best_score),
        'verdict': verdict,
    }


# ════════════════════════════════════════════════════════════════════
# Plotting
# ════════════════════════════════════════════════════════════════════

def plot_diagnostic(M_avg, all_M, results, output_path):
    """Draw the 4-panel diagnostic figure, save it to ``output_path``, show it.

    Panels: (1) 3D scatter of the rows colored by PC1 order, (2) singular
    value spectrum, (3) PC2/PC3 against PC1, (4) silhouette score vs k.

    Args:
        M_avg: averaged ``[V, D]`` row matrix (panels 1 and 3 need D >= 3).
        all_M: per-sample row matrices; currently unused, kept for the
            call signature.
        results: dict with 'rank', 'parametric' and 'cluster' entries as
            produced by the test_* functions.
        output_path: where the PNG is written.
    """
    n_rows, n_dims = M_avg.shape
    proj, sort_idx = _pc_sorted_projection(M_avg)
    fig = plt.figure(figsize=(16, 12))

    # Panel 1: 3D scatter of the canonical rows, colored by PC1 sort rank
    ax1 = fig.add_subplot(2, 2, 1, projection='3d')
    colors = plt.cm.viridis(np.linspace(0, 1, len(sort_idx)))
    ax1.scatter(M_avg[sort_idx, 0], M_avg[sort_idx, 1], M_avg[sort_idx, 2],
                c=colors, s=80, edgecolors='black', linewidths=0.5)
    ax1.set_xlabel('D1')
    ax1.set_ylabel('D2')
    ax1.set_zlabel('D3')
    ax1.set_title(f'P-rank09 row matrix M (V={n_rows}, D={n_dims})\n'
                  f'colored by PC1 sort order\n'
                  f'effective rank: {results["rank"]["effective_rank"]:.2f}')

    # Panel 2: Singular value spectrum (red → yellow, any number of SVs)
    ax2 = fig.add_subplot(2, 2, 2)
    SVs = np.array(results['rank']['singular_values'])
    sv_labels = [f'SV{i + 1}' for i in range(len(SVs))]
    ax2.bar(sv_labels, SVs, color=plt.cm.autumn(np.linspace(0, 1, len(SVs))))
    ax2.set_ylabel('Singular value')
    ax2.set_title(f'Singular values of M\n'
                  f'top1 share: {results["rank"]["top1_share"]:.2%}\n'
                  f'verdict: {results["rank"]["verdict"]}')
    for i, sv in enumerate(SVs):
        ax2.text(i, sv, f'{sv:.3f}', ha='center', va='bottom')

    # Panel 3: PC2 and PC3 vs PC1 (parametric curve test)
    ax3 = fig.add_subplot(2, 2, 3)
    x = proj[sort_idx, 0]
    ax3.plot(x, proj[sort_idx, 1], 'o-', color='blue',
             label='PC2 vs PC1', markersize=6)
    ax3.plot(x, proj[sort_idx, 2], 's-', color='green',
             label='PC3 vs PC1', markersize=6)
    ax3.set_xlabel('PC1 projection')
    ax3.set_ylabel('PC2 / PC3 projection')
    ax3.set_title(f'Parametric ordering test\n'
                  f'smoothness PC2: {results["parametric"]["smoothness_pc2"]:.3f}, '
                  f'PC3: {results["parametric"]["smoothness_pc3"]:.3f}\n'
                  f'verdict: {results["parametric"]["verdict"]}')
    ax3.legend()
    ax3.grid(alpha=0.3)

    # Panel 4: Cluster silhouette across k
    ax4 = fig.add_subplot(2, 2, 4)
    per_k = results['cluster']['per_k']
    ks = [int(k_str.split('=')[1]) for k_str in per_k]
    sils = [r['silhouette'] for r in per_k.values()]
    ax4.plot(ks, sils, 'o-', color='purple', markersize=8)
    ax4.axhline(0.5, color='red', linestyle='--', alpha=0.5,
                label='strong cluster threshold')
    ax4.axhline(0.25, color='orange', linestyle='--', alpha=0.5,
                label='weak cluster threshold')
    ax4.set_xlabel('k (number of clusters)')
    ax4.set_ylabel('silhouette score')
    ax4.set_title(f'Cluster structure test\n'
                  f'best k={results["cluster"]["best_k"]}, '
                  f'silhouette={results["cluster"]["best_silhouette"]:.3f}\n'
                  f'verdict: {results["cluster"]["verdict"]}')
    ax4.legend(fontsize=8)
    ax4.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=120, bbox_inches='tight')
    plt.show()


# ════════════════════════════════════════════════════════════════════
# Main
# ════════════════════════════════════════════════════════════════════

def main():
    """Run the full probe: load, collect rows, run the four tests, save outputs.

    Returns:
        The dict of all results, which is also written to ``OUTPUT_JSON``;
        the diagnostic figure is written to ``OUTPUT_PLOT``.
    """
    print("Loading P-rank09 model...")
    model, cfg = load_rank09()
    print(f" Architecture: V={cfg.matrix_v}, D={cfg.D}, "
          f"patch_size={cfg.patch_size}, hidden={cfg.hidden}")
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {n_params:,}")

    print("\nCollecting M rows from gaussian inputs...")
    all_M = collect_rows(model, cfg, n_batches=8, batch_size=64)
    print(f" Collected {all_M.shape[0]} samples of M [V={all_M.shape[1]}, "
          f"D={all_M.shape[2]}]")

    # Average M over samples to get the canonical row matrix
    M_avg = all_M.mean(dim=0).numpy()
    M_std = all_M.std(dim=0).numpy()
    print(f" M_avg shape: {M_avg.shape}")
    print(f" Per-row variability (mean ‖σ‖₂ across rows): "
          f"{np.linalg.norm(M_std, axis=1).mean():.4f}")
    print(f" Per-row mean magnitude (mean ‖μ‖₂): "
          f"{np.linalg.norm(M_avg, axis=1).mean():.4f}")

    # Sphere-norm verification. Averaging unit vectors can only shrink their
    # norm, so the ~1.0 expectation applies to the per-sample rows; the norms
    # of M_avg's rows measure how consistent each row's direction is.
    row_norms = np.linalg.norm(M_avg, axis=1)
    sample_row_norms = all_M.norm(dim=-1)
    print(f" Row norm range: [{row_norms.min():.4f}, {row_norms.max():.4f}]")
    print(f" Per-sample row norm range: [{sample_row_norms.min().item():.4f}, "
          f"{sample_row_norms.max().item():.4f}]")
    print(" (per-sample sphere-normed rows should have norm ~1.0; "
          "averaged rows can be shorter)")

    print("\n" + "═" * 70)
    print("HYPOTHESIS TESTS")
    print("═" * 70)

    print("\n[1/4] Rank structure (SVD)...")
    rank_results = test_rank_structure(M_avg)
    print(f" Singular values: {[f'{s:.4f}' for s in rank_results['singular_values']]}")
    print(f" Effective rank: {rank_results['effective_rank']:.2f}")
    print(f" Top-1 share: {rank_results['top1_share']:.2%}")
    print(f" VERDICT: {rank_results['verdict']}")

    print("\n[2/4] Parametric ordering (PC1 sort + smoothness)...")
    param_results = test_parametric_ordering(M_avg)
    print(f" Smoothness PC2: {param_results['smoothness_pc2']:.3f}")
    print(f" Smoothness PC3: {param_results['smoothness_pc3']:.3f}")
    print(f" VERDICT: {param_results['verdict']}")

    print("\n[3/4] Polynomial / trigonometric fit...")
    fit_results = test_polynomial_fit(M_avg)
    print(" Polynomial fits (R² for PC2):")
    for deg in [1, 2, 3, 4]:
        r2_val = fit_results['polynomial'][f'degree_{deg}']['r2_pc2']
        print(f" degree {deg}: R² = {r2_val:.4f}")
    print(f" Trigonometric fit (R² for PC2): "
          f"{fit_results['trigonometric']['r2_pc2']:.4f}")
    print(f" VERDICT: {fit_results['verdict']}")

    print("\n[4/4] Cluster structure (k-means silhouette)...")
    cluster_results = test_cluster_structure(M_avg)
    print(" Per-k silhouette:")
    for k_str, r in cluster_results['per_k'].items():
        print(f" {k_str}: silhouette = {r['silhouette']:.3f}")
    print(f" VERDICT: {cluster_results['verdict']}")

    all_results = {
        'config': {
            'variant': 'P_rank09_h64_V32_D3_dp0_nx0_adam',
            'V': cfg.matrix_v,
            'D': cfg.D,
            'params': n_params,
            'gaussian_test_mse': 0.02782,
            'observed_cv': 0.035,
        },
        'M_avg_shape': list(M_avg.shape),
        'row_norms_mean': float(row_norms.mean()),
        'row_norms_std': float(row_norms.std()),
        'rank': rank_results,
        'parametric': param_results,
        'fit': fit_results,
        'cluster': cluster_results,
    }

    print("\n" + "═" * 70)
    print("OVERALL INTERPRETATION")
    print("═" * 70)
    print(f" Rank: {rank_results['verdict']}")
    print(f" Parametric: {param_results['verdict']}")
    print(f" Fit: {fit_results['verdict']}")
    print(f" Clusters: {cluster_results['verdict']}")

    # Composite verdict logic (checked in priority order below)
    is_polynomial = (
        fit_results['best_poly_r2'] > FIT_R2_THRESHOLD
        and rank_results['effective_rank'] < 2.5
    )
    is_trig = (
        fit_results['trigonometric']['r2_pc2'] > FIT_R2_THRESHOLD
        and not is_polynomial
    )
    is_clustered = cluster_results['best_silhouette'] > 0.5
    is_collapsed = rank_results['top1_share'] > 0.85

    print("\n Composite read:")
    if is_polynomial:
        deg = fit_results['best_poly_degree']
        print(f" → POLYNOMIAL CONFIRMED (degree {deg}). "
              f"P-Class naming validated.")
    elif is_trig:
        print(" → TRIGONOMETRIC structure detected. "
              "P-Class might be better named F-Class (Fourier).")
    elif is_collapsed:
        print(" → COLLAPSED — rows essentially 1-dimensional. "
              "Failed differentiation, not a useful battery.")
    elif is_clustered:
        k = cluster_results['best_k']
        print(f" → CLUSTERED into {k} groups. "
              f"P-Class might be better named K-Class "
              f"(k-means / quantization).")
    else:
        print(" → MIXED structure — not cleanly polynomial, trig, or "
              "clustered. Worth probing further with higher-order bases or "
              "deeper geometric analysis.")

    with open(OUTPUT_JSON, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f"\n Results saved: {OUTPUT_JSON}")

    plot_diagnostic(M_avg, all_M, all_results, OUTPUT_PLOT)
    print(f" Plot saved: {OUTPUT_PLOT}")

    return all_results


if __name__ == '__main__':
    results = main()