"""
cell_p_class_probe.py — geometric structure probe for P-Class batteries

Loads P-rank09 (h64_V32_D3_dp0_nx0_adam, MSE 0.028, CV 0.03) and asks
what its 32 row vectors in 3D space actually look like.

Four hypothesis tests:
  1. RANK STRUCTURE — SVD on the 32×3 row matrix M.
     - Polynomial basis: rank ≤ 2 (Vandermonde collapses)
     - Trig basis: rank = 2 or 3 with specific singular value ratio
     - Cluster: rank 3, all SVs comparable
     - Collapsed: rank 1, one dominant SV

  2. PARAMETRIC ORDERING — Try ordering rows by their first coordinate
     (or first principal axis projection). If rows form a smooth curve
     when ordered, we're seeing a parametric structure (polynomial,
     trig, etc). If they're scattered with no order, it's clusters.
     Metric: smoothness of consecutive Δ when sorted along PC1.

  3. POLYNOMIAL FIT TEST — Fit a Vandermonde matrix to the ordered rows.
     If R² > 0.95 with cubic, polynomial hypothesis confirmed.
     Try [1, x, x²], [1, x, x², x³], [1, sin(x), cos(x)].

  4. CLUSTER COUNT — k-means with k = 2..8 on the 32 rows. If silhouette
     score is high at small k, it's clustered. If silhouette is low for
     all k, the rows are spread continuously (consistent with parametric).

Outputs:
  - Console verdict for each hypothesis
  - /content/phaseQ_reports/p_rank09_probe.png — 4-panel diagnostic plot
  - /content/phaseQ_reports/p_rank09_probe.json — all numerical results
"""
|
|
| import json |
| import math |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import matplotlib.pyplot as plt |
| from mpl_toolkits.mplot3d import Axes3D |
| from sklearn.cluster import KMeans |
| from sklearn.metrics import silhouette_score |
|
|
|
|
# Root directory of the phase-Q reports. The checkpoint is read from a
# sub-directory of it, and both probe outputs are written directly into it.
CKPT_DIR = Path("/content/phaseQ_reports")
RANK09_CKPT = CKPT_DIR.joinpath(
    "Q_rank09_h64_V32_D3_dp0_nx0_adam",
    "epoch_1_checkpoint.pt",
)
OUTPUT_PLOT = CKPT_DIR.joinpath("p_rank09_probe.png")
OUTPUT_JSON = CKPT_DIR.joinpath("p_rank09_probe.json")
|
|
|
|
def load_rank09():
    """Reconstruct the rank09 model and load its trained weights.

    The config list, run-config builder and model class
    (``get_phaseQ_configs``, ``build_run_config``, ``PatchSVAE_F_Ablation``)
    are not defined or imported in this file; they must already be in scope
    (presumably from an earlier notebook cell — confirm).

    Returns:
        ``(model, cfg)``: the model in eval mode with the checkpoint weights
        loaded (on CPU), and the run config it was built from.

    Raises:
        LookupError: if no config has 'rank09' in its ``variant`` name.
    """
    cfgs = get_phaseQ_configs()

    # Give next() a default: a bare next() on an exhausted generator raises
    # StopIteration, which says nothing about what was actually missing.
    rank09_cfg = next((c for c in cfgs if 'rank09' in c['variant']), None)
    if rank09_cfg is None:
        raise LookupError(
            "no phase-Q config with 'rank09' in its variant name")

    cfg = build_run_config(rank09_cfg)
    overrides = rank09_cfg['overrides']

    # Per-variant overrides win; otherwise fall back to the run config or to
    # the literal defaults below.
    model = PatchSVAE_F_Ablation(
        matrix_v=cfg.matrix_v, D=cfg.D, patch_size=cfg.patch_size,
        hidden=cfg.hidden, depth=cfg.depth,
        n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads,
        max_alpha=overrides.get('max_alpha', cfg.max_alpha),
        alpha_init=cfg.alpha_init,
        activation=overrides.get('activation', 'gelu'),
        row_norm=overrides.get('row_norm', 'sphere'),
        svd_mode=overrides.get('svd', 'fp64'),
        linear_readout=overrides.get('linear_readout', False),
        match_params=overrides.get('match_params', True),
        init_scheme=overrides.get('init', 'orthogonal'),
    )

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects. Acceptable for a checkpoint this project wrote itself; do not
    # point RANK09_CKPT at a file from an untrusted source.
    ckpt = torch.load(RANK09_CKPT, map_location='cpu', weights_only=False)

    # The weights may be stored under one of several key names; if none of
    # them holds anything, the checkpoint itself is used as the state dict.
    state_dict = ckpt
    for key in ('model_state', 'model_state_dict', 'state_dict'):
        candidate = ckpt.get(key)
        if candidate:
            state_dict = candidate
            break

    model.load_state_dict(state_dict)
    model.eval()
    return model, cfg
|
|
|
|
def collect_rows(model, cfg, n_batches=8, batch_size=64):
    """Push noise images through the model and gather patch-0 row matrices.

    Builds an ``OmegaNoiseDataset`` of ``n_batches * batch_size`` images
    restricted to noise type 0 (gaussian, per the original note — confirm
    against the dataset class), runs it through ``model`` without gradients,
    and keeps ``out['svd']['M'][:, 0]`` from every batch.

    Note that ``model`` is moved to CUDA when it is available.

    Returns:
        One CPU tensor with all batches concatenated along dim 0
        (presumably [n_samples, V, D] — confirm against the model output).
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model = model.to(device)

    dataset = OmegaNoiseDataset(
        size=n_batches * batch_size,
        img_size=cfg.img_size,
        allowed_types=[0],
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False)

    with torch.no_grad():
        # Index 0 on dim 1 picks a single canonical patch position.
        per_batch = [
            model(imgs.to(device))['svd']['M'][:, 0].cpu()
            for imgs, _ in loader
        ]

    return torch.cat(per_batch, dim=0)
|
|
|
|
| |
| |
| |
|
|
def test_rank_structure(M_avg):
    """Test 1: classify the singular-value spectrum of the row matrix.

    Hypotheses the spectrum is meant to separate:
        Polynomial Vandermonde: top-1 SV dominates, rank ≈ 1-2
        Trig basis:             balanced top-2 SVs, small 3rd
        Sphere uniform (H2):    ~equal SVs, full rank
        Cluster:                depends on cluster geometry

    Args:
        M_avg: 2-D array of row vectors (the caller passes the averaged
            32×3 row matrix).

    Returns:
        dict with
          - 'singular_values': singular values of ``M_avg``
          - 'normalized_SV': the singular values divided by their sum
          - 'effective_rank': exp of the Shannon entropy of 'normalized_SV'
          - 'top1_share', 'top2_share': share held by the largest / two
            largest singular values
          - 'verdict': one of four rank labels chosen by fixed thresholds
        Every numeric value is a plain Python float (or list of floats), so
        the dict is JSON-serialisable as numbers whatever dtype ``M_avg`` has.
    """
    _, S, _ = np.linalg.svd(M_avg, full_matrices=False)
    S_norm = S / S.sum()
    # The 1e-12 keeps log() finite when a singular value is exactly zero.
    erank = math.exp(-(S_norm * np.log(S_norm + 1e-12)).sum())

    # Cast numpy scalars: np.float32 is not a float subclass, so a
    # json.dump(..., default=str) would write these shares as strings.
    top1_share = float(S_norm[0])
    top2_share = float(S_norm[:2].sum())

    if top1_share > 0.85:
        verdict = 'rank-1 (collapsed/aligned)'
    elif top2_share > 0.92:
        verdict = 'rank-2 (planar β could be polynomial or trig)'
    elif S_norm.std() < 0.05:
        verdict = 'rank-3 (full, balanced)'
    else:
        verdict = 'rank-3 (full, imbalanced)'

    return {
        'singular_values': S.tolist(),
        'normalized_SV': S_norm.tolist(),
        'effective_rank': erank,
        'top1_share': top1_share,
        'top2_share': top2_share,
        'verdict': verdict,
    }


# This is a probe entry point, not a unit test. The ``test_`` prefix would
# make pytest collect it (and fail on the missing ``M_avg`` fixture) in any
# test module that imports it, so opt out of collection explicitly.
test_rank_structure.__test__ = False
|
|
|
|
def test_parametric_ordering(M_avg):
    """Test 2: sort rows along the first singular axis and score smoothness.

    Rows are projected onto the right singular vectors of ``M_avg`` (the
    matrix is not mean-centred first), sorted by the first projection, and
    the second and third projections are then read in that order. If the
    rows lie on a smooth parametric curve (polynomial, trig), consecutive
    steps are small relative to the coordinate's range.

    Smoothness of a coordinate is ``1 - mean(|Δ|) / (range + 1e-8)``, where
    Δ are the consecutive differences after sorting.

    Args:
        M_avg: 2-D array of row vectors with at least 3 columns (columns 1
            and 2 of the projection are indexed directly).

    Returns:
        dict with 'sort_order' (row indices, ascending PC1), the two
        smoothness scores, the range of each of the three projections, and
        a 'verdict' based on the lower of the two smoothness scores
        (> 0.85 smooth, > 0.5 partial, otherwise scattered).
    """
    _, _, Vt = np.linalg.svd(M_avg, full_matrices=False)
    proj = M_avg @ Vt.T

    sort_idx = np.argsort(proj[:, 0])
    sorted_proj = proj[sort_idx]

    def _span_and_smoothness(values):
        # Range of the coordinate, and one minus its mean absolute step as
        # a fraction of that range (1e-8 guards a zero range).
        span = values.max() - values.min()
        score = 1.0 - (np.abs(np.diff(values)).mean() / (span + 1e-8))
        return span, score

    range_pc2, smoothness_pc2 = _span_and_smoothness(sorted_proj[:, 1])
    range_pc3, smoothness_pc3 = _span_and_smoothness(sorted_proj[:, 2])

    worst = min(smoothness_pc2, smoothness_pc3)
    if worst > 0.85:
        verdict = 'smooth parametric curve'
    elif worst > 0.5:
        verdict = 'partial structure'
    else:
        verdict = 'scattered (cluster-like)'

    return {
        'sort_order': sort_idx.tolist(),
        'smoothness_pc2': float(smoothness_pc2),
        'smoothness_pc3': float(smoothness_pc3),
        'pc1_range': float(proj[:, 0].max() - proj[:, 0].min()),
        'pc2_range': float(range_pc2),
        'pc3_range': float(range_pc3),
        'verdict': verdict,
    }


# This is a probe entry point, not a unit test. The ``test_`` prefix would
# make pytest collect it (and fail on the missing ``M_avg`` fixture) in any
# test module that imports it, so opt out of collection explicitly.
test_parametric_ordering.__test__ = False
|
|
|
|
def test_polynomial_fit(M_avg):
    """Test 3: fit PC2 and PC3 as polynomial / trigonometric functions of PC1.

    Rows are projected onto the right singular vectors of ``M_avg`` and
    ordered by the first projection, which is rescaled to [-1, 1] and used
    as the abscissa. Polynomials of degree 1-4 are fitted to the second and
    third projections, as is the trigonometric basis
    [1, sin(πx), cos(πx), sin(2πx), cos(2πx)].

    The reported best degree is the LOWEST degree whose PC2 R² exceeds 0.95.
    Least-squares R² cannot decrease as the degree grows, so taking the
    arg-max would report the highest degree almost regardless of the data.
    If no degree clears the threshold, the degree with the highest PC2 R²
    is reported instead.

    Args:
        M_avg: 2-D array of row vectors with at least 3 columns.

    Returns:
        dict with
          - 'polynomial': {'degree_d': {'r2_pc2', 'r2_pc3'}} for d in 1..4
          - 'trigonometric': {'r2_pc2', 'r2_pc3', 'coefs_pc2'}
          - 'best_poly_degree', 'best_poly_r2': selected degree and its
            PC2 R²
          - 'verdict': polynomial if best_poly_r2 > 0.95, else trigonometric
            if the trig PC2 R² > 0.95, else no clean parametric fit
    """
    r2_threshold = 0.95
    degrees = (1, 2, 3, 4)

    _, _, Vt = np.linalg.svd(M_avg, full_matrices=False)
    proj = M_avg @ Vt.T
    sort_idx = np.argsort(proj[:, 0])

    x = proj[sort_idx, 0]
    y2 = proj[sort_idx, 1]
    y3 = proj[sort_idx, 2]

    # Rescale the abscissa to [-1, 1]; the 1e-8 guards a zero PC1 range.
    x_norm = 2 * (x - x.min()) / (x.max() - x.min() + 1e-8) - 1

    def r2(y_true, y_pred):
        # Coefficient of determination; 1e-12 guards a constant target.
        ss_res = ((y_true - y_pred) ** 2).sum()
        ss_tot = ((y_true - y_true.mean()) ** 2).sum()
        return 1 - ss_res / (ss_tot + 1e-12)

    poly_results = {}
    for deg in degrees:
        coef2 = np.polyfit(x_norm, y2, deg)
        coef3 = np.polyfit(x_norm, y3, deg)
        poly_results[f'degree_{deg}'] = {
            'r2_pc2': float(r2(y2, np.polyval(coef2, x_norm))),
            'r2_pc3': float(r2(y3, np.polyval(coef3, x_norm))),
        }

    def trig_basis(x):
        # Constant term plus first and second harmonics on [-1, 1].
        return np.column_stack([
            np.ones_like(x),
            np.sin(np.pi * x), np.cos(np.pi * x),
            np.sin(2 * np.pi * x), np.cos(2 * np.pi * x),
        ])

    B = trig_basis(x_norm)
    coef2_t, _, _, _ = np.linalg.lstsq(B, y2, rcond=None)
    coef3_t, _, _, _ = np.linalg.lstsq(B, y3, rcond=None)
    trig_r2_pc2 = r2(y2, B @ coef2_t)
    trig_r2_pc3 = r2(y3, B @ coef3_t)

    def pc2_r2(deg):
        return poly_results[f'degree_{deg}']['r2_pc2']

    # Most parsimonious adequate degree first; arg-max only as a fallback.
    adequate = [deg for deg in degrees if pc2_r2(deg) > r2_threshold]
    best_poly_deg = adequate[0] if adequate else max(degrees, key=pc2_r2)
    best_poly_r2 = pc2_r2(best_poly_deg)

    if best_poly_r2 > r2_threshold:
        verdict = f'polynomial degree {best_poly_deg} (RΒ²={best_poly_r2:.3f})'
    elif trig_r2_pc2 > r2_threshold:
        verdict = f'trigonometric (RΒ²={trig_r2_pc2:.3f})'
    else:
        verdict = (
            f'no clean parametric fit (best poly RΒ²={best_poly_r2:.3f}, '
            f'trig RΒ²={trig_r2_pc2:.3f})'
        )

    return {
        'polynomial': poly_results,
        'trigonometric': {
            'r2_pc2': float(trig_r2_pc2),
            'r2_pc3': float(trig_r2_pc3),
            'coefs_pc2': coef2_t.tolist(),
        },
        'best_poly_degree': best_poly_deg,
        'best_poly_r2': float(best_poly_r2),
        'verdict': verdict,
    }


# This is a probe entry point, not a unit test. The ``test_`` prefix would
# make pytest collect it (and fail on the missing ``M_avg`` fixture) in any
# test module that imports it, so opt out of collection explicitly.
test_polynomial_fit.__test__ = False
|
|
|
|
def test_cluster_structure(M_avg):
    """Test 4: k-means + silhouette across k = 2..8.

    A high silhouette at small k points to genuine clusters; a low
    silhouette for every k points to a continuous spread (consistent with
    parametric structure).

    k runs from 2 up to ``min(8, n_rows - 1)``. With fewer than 3 rows no k
    is evaluated, and the result keeps 'best_k' None and 'best_silhouette'
    -1.0.

    Args:
        M_avg: 2-D array with one row per point to cluster.

    Returns:
        dict with 'per_k' ({'k=K': {'silhouette', 'inertia'}}), 'best_k',
        'best_silhouette', and a 'verdict' (> 0.5 strong, > 0.25 weak,
        otherwise no clear clusters).
    """
    per_k = {}
    best_k = None
    best_score = -1
    upper = min(9, M_avg.shape[0])
    for k in range(2, upper):
        # Fixed seed and n_init so repeated runs give the same labelling.
        km = KMeans(n_clusters=k, n_init=10, random_state=42)
        labels = km.fit_predict(M_avg)
        # silhouette_score needs at least two distinct labels.
        if len(set(labels)) < 2:
            continue
        score = silhouette_score(M_avg, labels)
        per_k[f'k={k}'] = {
            'silhouette': float(score),
            'inertia': float(km.inertia_),
        }
        if score > best_score:
            best_score = score
            best_k = k

    if best_score > 0.5:
        verdict = f'strong clusters (k={best_k}, silhouette={best_score:.3f})'
    elif best_score > 0.25:
        verdict = f'weak clusters (k={best_k}, silhouette={best_score:.3f})'
    else:
        verdict = (
            f'no clear clusters (best silhouette={best_score:.3f}) β '
            f'consistent with continuous structure'
        )

    return {
        'per_k': per_k,
        'best_k': best_k,
        'best_silhouette': float(best_score),
        'verdict': verdict,
    }


# This is a probe entry point, not a unit test. The ``test_`` prefix would
# make pytest collect it (and fail on the missing ``M_avg`` fixture) in any
# test module that imports it, so opt out of collection explicitly.
test_cluster_structure.__test__ = False
|
|
|
|
| |
| |
| |
|
|
def plot_diagnostic(M_avg, all_M, results, output_path):
    """Draw the 4-panel diagnostic figure, save it, and show it.

    Panels: (1) 3-D scatter of the rows coloured by PC1 sort order,
    (2) bar chart of the singular values, (3) PC2 and PC3 against PC1 in
    sorted order, (4) silhouette score against k.

    Args:
        M_avg: 2-D array of row vectors with at least 3 columns.
        all_M: accepted for interface compatibility; not used here.
        results: dict with 'rank', 'parametric' and 'cluster' entries as
            produced by the corresponding probe functions.
        output_path: where the PNG is written.
    """
    fig = plt.figure(figsize=(16, 12))

    _, _, Vt = np.linalg.svd(M_avg, full_matrices=False)
    proj = M_avg @ Vt.T
    sort_idx = np.argsort(proj[:, 0])

    # Panel 1: rows in 3-D, coloured by their rank along PC1.
    ax1 = fig.add_subplot(2, 2, 1, projection='3d')
    colors = plt.cm.viridis(np.linspace(0, 1, len(sort_idx)))
    ordered = M_avg[sort_idx]
    # One scatter call for all points. depthshade=False so the markers are
    # not dimmed by depth and keep exactly their viridis rank colour.
    ax1.scatter(ordered[:, 0], ordered[:, 1], ordered[:, 2],
                c=colors, s=80, edgecolors='black', linewidths=0.5,
                depthshade=False)
    ax1.set_xlabel('D1')
    ax1.set_ylabel('D2')
    ax1.set_zlabel('D3')
    ax1.set_title(f'P-rank09 row matrix M (V=32, D=3)\n'
                  f'colored by PC1 sort order\n'
                  f'effective rank: {results["rank"]["effective_rank"]:.2f}')

    # Panel 2: singular values. Labels and colours follow the number of
    # values instead of assuming exactly three.
    ax2 = fig.add_subplot(2, 2, 2)
    SVs = np.array(results['rank']['singular_values'])
    palette = ['red', 'orange', 'yellow']
    bar_labels = [f'SV{i + 1}' for i in range(len(SVs))]
    bar_colors = [palette[i % len(palette)] for i in range(len(SVs))]
    ax2.bar(bar_labels, SVs, color=bar_colors)
    ax2.set_ylabel('Singular value')
    ax2.set_title(f'Singular values of M\n'
                  f'top1 share: {results["rank"]["top1_share"]:.2%}\n'
                  f'verdict: {results["rank"]["verdict"]}')
    for i, sv in enumerate(SVs):
        ax2.text(i, sv, f'{sv:.3f}', ha='center', va='bottom')

    # Panel 3: PC2 and PC3 read off in PC1 order.
    ax3 = fig.add_subplot(2, 2, 3)
    x = proj[sort_idx, 0]
    y2 = proj[sort_idx, 1]
    y3 = proj[sort_idx, 2]
    ax3.plot(x, y2, 'o-', color='blue', label='PC2 vs PC1', markersize=6)
    ax3.plot(x, y3, 's-', color='green', label='PC3 vs PC1', markersize=6)
    ax3.set_xlabel('PC1 projection')
    ax3.set_ylabel('PC2 / PC3 projection')
    ax3.set_title(f'Parametric ordering test\n'
                  f'smoothness PC2: {results["parametric"]["smoothness_pc2"]:.3f}, '
                  f'PC3: {results["parametric"]["smoothness_pc3"]:.3f}\n'
                  f'verdict: {results["parametric"]["verdict"]}')
    ax3.legend()
    ax3.grid(alpha=0.3)

    # Panel 4: silhouette per k; the keys look like 'k=3'.
    ax4 = fig.add_subplot(2, 2, 4)
    per_k = results['cluster']['per_k']
    ks = [int(label.split('=')[1]) for label in per_k]
    sils = [entry['silhouette'] for entry in per_k.values()]
    ax4.plot(ks, sils, 'o-', color='purple', markersize=8)
    ax4.axhline(0.5, color='red', linestyle='--', alpha=0.5,
                label='strong cluster threshold')
    ax4.axhline(0.25, color='orange', linestyle='--', alpha=0.5,
                label='weak cluster threshold')
    ax4.set_xlabel('k (number of clusters)')
    ax4.set_ylabel('silhouette score')
    ax4.set_title(f'Cluster structure test\n'
                  f'best k={results["cluster"]["best_k"]}, '
                  f'silhouette={results["cluster"]["best_silhouette"]:.3f}\n'
                  f'verdict: {results["cluster"]["verdict"]}')
    ax4.legend(fontsize=8)
    ax4.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=120, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
|
|
|
| |
| |
| |
|
|
def _json_default(obj):
    """Turn numpy scalars/arrays into plain Python values for ``json.dump``."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # Anything else keeps the previous best-effort behaviour.
    return str(obj)


def main():
    """Run the full probe: load the model, collect rows, test, report, save.

    Steps: load the rank09 model, collect per-sample row matrices, average
    them over samples, run the four hypothesis tests on the average, print
    each verdict and a composite reading, write all results to OUTPUT_JSON
    and the diagnostic figure to OUTPUT_PLOT.

    Returns:
        The results dict that was written to OUTPUT_JSON.
    """
    print("Loading P-rank09 model...")
    model, cfg = load_rank09()
    print(f" Architecture: V={cfg.matrix_v}, D={cfg.D}, "
          f"patch_size={cfg.patch_size}, hidden={cfg.hidden}")
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {n_params:,}")

    print("\nCollecting M rows from gaussian inputs...")
    all_M = collect_rows(model, cfg, n_batches=8, batch_size=64)
    print(f" Collected {all_M.shape[0]} samples of M [V={all_M.shape[1]}, "
          f"D={all_M.shape[2]}]")

    # Average over samples to get one canonical row matrix; the per-entry
    # std shows how much the rows move from sample to sample.
    M_avg = all_M.mean(dim=0).numpy()
    M_std = all_M.std(dim=0).numpy()
    print(f" M_avg shape: {M_avg.shape}")
    print(f" Per-row variability (mean βΟββ across rows): "
          f"{np.linalg.norm(M_std, axis=1).mean():.4f}")
    print(f" Per-row mean magnitude (mean βΞΌββ): "
          f"{np.linalg.norm(M_avg, axis=1).mean():.4f}")

    row_norms = np.linalg.norm(M_avg, axis=1)
    print(f" Row norm range: [{row_norms.min():.4f}, {row_norms.max():.4f}]")
    print(" (sphere-normed rows should all have norm ~1.0)")

    print("\n" + "β" * 70)
    print("HYPOTHESIS TESTS")
    print("β" * 70)

    print("\n[1/4] Rank structure (SVD)...")
    rank_results = test_rank_structure(M_avg)
    print(f" Singular values: {[f'{s:.4f}' for s in rank_results['singular_values']]}")
    print(f" Effective rank: {rank_results['effective_rank']:.2f}")
    print(f" Top-1 share: {rank_results['top1_share']:.2%}")
    print(f" VERDICT: {rank_results['verdict']}")

    print("\n[2/4] Parametric ordering (PC1 sort + smoothness)...")
    param_results = test_parametric_ordering(M_avg)
    print(f" Smoothness PC2: {param_results['smoothness_pc2']:.3f}")
    print(f" Smoothness PC3: {param_results['smoothness_pc3']:.3f}")
    print(f" VERDICT: {param_results['verdict']}")

    print("\n[3/4] Polynomial / trigonometric fit...")
    fit_results = test_polynomial_fit(M_avg)
    print(" Polynomial fits (RΒ² for PC2):")
    for deg in [1, 2, 3, 4]:
        deg_r2 = fit_results['polynomial'][f'degree_{deg}']['r2_pc2']
        print(f" degree {deg}: RΒ² = {deg_r2:.4f}")
    print(f" Trigonometric fit (RΒ² for PC2): "
          f"{fit_results['trigonometric']['r2_pc2']:.4f}")
    print(f" VERDICT: {fit_results['verdict']}")

    print("\n[4/4] Cluster structure (k-means silhouette)...")
    cluster_results = test_cluster_structure(M_avg)
    print(" Per-k silhouette:")
    for k_str, r in cluster_results['per_k'].items():
        print(f" {k_str}: silhouette = {r['silhouette']:.3f}")
    print(f" VERDICT: {cluster_results['verdict']}")

    # NOTE(review): 'variant', 'gaussian_test_mse' and 'observed_cv' are
    # hard-coded metadata, not computed here. The variant is spelled
    # 'P_rank09...' while RANK09_CKPT points at a 'Q_rank09...' directory;
    # confirm which spelling is intended.
    all_results = {
        'config': {
            'variant': 'P_rank09_h64_V32_D3_dp0_nx0_adam',
            'V': cfg.matrix_v, 'D': cfg.D, 'params': n_params,
            'gaussian_test_mse': 0.02782,
            'observed_cv': 0.035,
        },
        'M_avg_shape': list(M_avg.shape),
        'row_norms_mean': float(row_norms.mean()),
        'row_norms_std': float(row_norms.std()),
        'rank': rank_results,
        'parametric': param_results,
        'fit': fit_results,
        'cluster': cluster_results,
    }

    print("\n" + "β" * 70)
    print("OVERALL INTERPRETATION")
    print("β" * 70)
    print(f" Rank: {rank_results['verdict']}")
    print(f" Parametric: {param_results['verdict']}")
    print(f" Fit: {fit_results['verdict']}")
    print(f" Clusters: {cluster_results['verdict']}")

    # Composite reading. Precedence below is polynomial, then trig, then
    # collapsed, then clustered; the first flag that is set wins.
    is_polynomial = (
        fit_results['best_poly_r2'] > 0.95 and
        rank_results['effective_rank'] < 2.5
    )
    is_trig = (
        fit_results['trigonometric']['r2_pc2'] > 0.95 and
        not is_polynomial
    )
    is_clustered = cluster_results['best_silhouette'] > 0.5
    is_collapsed = rank_results['top1_share'] > 0.85

    print("\n Composite read:")
    if is_polynomial:
        deg = fit_results['best_poly_degree']
        print(f" β POLYNOMIAL CONFIRMED (degree {deg}). "
              f"P-Class naming validated.")
    elif is_trig:
        print(" β TRIGONOMETRIC structure detected. "
              "P-Class might be better named F-Class (Fourier).")
    elif is_collapsed:
        print(" β COLLAPSED β rows essentially 1-dimensional. "
              "Failed differentiation, not a useful battery.")
    elif is_clustered:
        k = cluster_results['best_k']
        print(f" β CLUSTERED into {k} groups. "
              f"P-Class might be better named K-Class "
              f"(k-means / quantization).")
    else:
        print(" β MIXED structure β not cleanly polynomial, trig, or "
              "clustered. Worth probing further with higher-order bases or "
              "deeper geometric analysis.")

    # default=str would silently write numpy scalars as strings; convert
    # them to real JSON numbers instead.
    with open(OUTPUT_JSON, 'w') as f:
        json.dump(all_results, f, indent=2, default=_json_default)
    print(f"\n Results saved: {OUTPUT_JSON}")

    plot_diagnostic(M_avg, all_M, all_results, OUTPUT_PLOT)
    print(f" Plot saved: {OUTPUT_PLOT}")

    return all_results
|
|
|
|
# Script entry point. The returned dict is bound to a module-level name so
# it stays available for inspection after the run.
if __name__ == "__main__":
    results = main()