geolip-svae-implicit-solver-experiments / 006_probe_winners_ft1.py
AbstractPhil's picture
Rename 6_probe_winners_ft1.py to 006_probe_winners_ft1.py
2d31199 verified
"""
cell_p_class_probe.py — geometric structure probe for P-Class batteries
Loads P-rank09 (h64_V32_D3_dp0_nx0_adam, MSE 0.028, CV 0.03) and asks
what its 32 row vectors in 3D space actually look like.
Four hypothesis tests:
1. RANK STRUCTURE — SVD on the 32×3 row matrix M.
- Polynomial basis: rank ≤ 2 (Vandermonde collapses)
- Trig basis: rank = 2 or 3 with specific singular value ratio
- Cluster: rank 3, all SVs comparable
- Collapsed: rank 1, one dominant SV
2. PARAMETRIC ORDERING — Try ordering rows by their first coordinate
(or first principal axis projection). If rows form a smooth curve
when ordered, we're seeing a parametric structure (polynomial,
trig, etc). If they're scattered with no order, it's clusters.
Metric: smoothness of consecutive Δ when sorted along PC1.
3. POLYNOMIAL FIT TEST — Fit a Vandermonde matrix to the ordered rows.
If R² > 0.95 with cubic, polynomial hypothesis confirmed.
Try [1, x, x²], [1, x, x², x³], [1, sin(x), cos(x)].
4. CLUSTER COUNT — k-means with k = 2..8 on the 32 rows. If silhouette
score is high at small k, it's clustered. If silhouette is low for
all k, the rows are spread continuously (consistent with parametric).
Outputs:
- Console verdict for each hypothesis
- /content/phaseQ_reports/p_rank09_probe.png — 4-panel diagnostic plot
- /content/phaseQ_reports/p_rank09_probe.json — all numerical results
"""
import json
import math
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # noqa
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
# Root directory for the Phase-Q artifacts: the checkpoint is read from a
# sub-directory of it and the probe's outputs are written directly into it.
CKPT_DIR = Path("/content/phaseQ_reports")

# Checkpoint file of the rank09 run that this probe inspects.
RANK09_CKPT = CKPT_DIR.joinpath(
    "Q_rank09_h64_V32_D3_dp0_nx0_adam", "epoch_1_checkpoint.pt")

# Probe outputs: the diagnostic figure and the numerical results.
OUTPUT_PLOT = CKPT_DIR.joinpath("p_rank09_probe.png")
OUTPUT_JSON = CKPT_DIR.joinpath("p_rank09_probe.json")
def load_rank09():
    """Reconstruct the rank09 model and load its trained weights.

    The architecture is rebuilt from the Phase-Q config whose 'variant'
    contains 'rank09', then the weights stored in RANK09_CKPT are loaded.

    Returns:
        (model, cfg): the model in eval mode and the run config produced by
        build_run_config for the rank09 variant.

    Raises:
        LookupError: if no Phase-Q config has 'rank09' in its variant name.
    """
    cfgs = get_phaseQ_configs()
    # Use a default so a missing variant surfaces as a descriptive error
    # instead of a bare StopIteration escaping from next().
    rank09_cfg = next((c for c in cfgs if 'rank09' in c['variant']), None)
    if rank09_cfg is None:
        raise LookupError("no Phase-Q config with 'rank09' in its variant name")

    cfg = build_run_config(rank09_cfg)
    overrides = rank09_cfg['overrides']
    model = PatchSVAE_F_Ablation(
        matrix_v=cfg.matrix_v, D=cfg.D, patch_size=cfg.patch_size,
        hidden=cfg.hidden, depth=cfg.depth,
        n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads,
        max_alpha=overrides.get('max_alpha', cfg.max_alpha),
        alpha_init=cfg.alpha_init,
        activation=overrides.get('activation', 'gelu'),
        row_norm=overrides.get('row_norm', 'sphere'),
        svd_mode=overrides.get('svd', 'fp64'),
        linear_readout=overrides.get('linear_readout', False),
        match_params=overrides.get('match_params', True),
        init_scheme=overrides.get('init', 'orthogonal'),
    )

    # NOTE(security): weights_only=False performs full unpickling and can
    # execute arbitrary code; only load checkpoints from a trusted source.
    ckpt = torch.load(RANK09_CKPT, map_location='cpu', weights_only=False)

    # Per the original note, the trainer stores weights under 'model_state';
    # 'model_state_dict' / 'state_dict' are kept as compatibility fallbacks.
    # Keys are selected by presence (not truthiness); if none is present the
    # checkpoint object itself is treated as the state dict.
    state_dict = ckpt
    if isinstance(ckpt, dict):
        for key in ('model_state', 'model_state_dict', 'state_dict'):
            candidate = ckpt.get(key)
            if candidate is not None:
                state_dict = candidate
                break

    model.load_state_dict(state_dict)
    model.eval()
    return model, cfg
def collect_rows(model, cfg, n_batches=8, batch_size=64):
    """Feed noise images through the model and gather the M matrix of patch 0.

    A dataset of ``n_batches * batch_size`` samples restricted to noise type 0
    (labelled "gaussian" in the original code) is pushed through ``model``
    with gradients disabled. For every sample, index 0 along the second axis
    of ``out['svd']['M']`` is kept.

    Args:
        model: model whose output dict exposes ``['svd']['M']``; it is moved
            to CUDA when available, otherwise to CPU.
        cfg: run config; only ``cfg.img_size`` is read.
        n_batches: number of batches to generate.
        batch_size: samples per batch.

    Returns:
        CPU tensor concatenated over all samples; the original code documents
        its layout as [n_samples, V, D].
    """
    target = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(target)

    noise_ds = OmegaNoiseDataset(
        size=n_batches * batch_size,
        img_size=cfg.img_size,
        allowed_types=[0],
    )
    batches = torch.utils.data.DataLoader(
        noise_ds, batch_size=batch_size, shuffle=False)

    with torch.no_grad():
        # out['svd']['M'] is documented upstream as [B, N_patches, V, D];
        # slicing [:, 0] keeps the first patch of each sample.
        patch0_rows = [
            model(images.to(target))['svd']['M'][:, 0].cpu()
            for images, _ in batches
        ]
    return torch.cat(patch0_rows, dim=0)
# ════════════════════════════════════════════════════════════════════
# Hypothesis tests
# ════════════════════════════════════════════════════════════════════
def test_rank_structure(M_avg):
    """Test 1: singular-value spectrum of the canonical row matrix.

    Args:
        M_avg: 2-D NumPy array of averaged rows (32×3 for the probed model).

    Returns:
        dict with
          'singular_values' / 'normalized_SV': lists of floats,
          'effective_rank': exp of the entropy of the normalized spectrum,
          'top1_share' / 'top2_share': share of the largest one / two
              singular values, as plain Python floats,
          'verdict': human-readable classification.

    Predictions:
        Polynomial Vandermonde: top-1 SV dominates, rank≈1-2
        Trig basis: balanced top-2 SVs, small 3rd
        Sphere uniform (H2): ~equal SVs, full rank
        Cluster: depends on cluster geometry
    """
    _, S, _ = np.linalg.svd(M_avg, full_matrices=False)
    S_norm = S / S.sum()
    # The small epsilon keeps log() finite for zero singular values.
    erank = math.exp(-(S_norm * np.log(S_norm + 1e-12)).sum())

    # Cast to built-in floats: for float32 input these would otherwise be
    # NumPy scalars, which json.dump(..., default=str) writes as strings.
    top1_share = float(S_norm[0])
    top2_share = float(S_norm[:2].sum())

    if top1_share > 0.85:
        verdict = 'rank-1 (collapsed/aligned)'
    elif top2_share > 0.92:
        verdict = 'rank-2 (planar — could be polynomial or trig)'
    elif S_norm.std() < 0.05:
        verdict = 'rank-3 (full, balanced)'
    else:
        verdict = 'rank-3 (full, imbalanced)'

    return {
        'singular_values': S.tolist(),
        'normalized_SV': S_norm.tolist(),
        'effective_rank': erank,
        'top1_share': top1_share,
        'top2_share': top2_share,
        'verdict': verdict,
    }


# Despite its name this is a probe, not a pytest test; opt out of collection
# so pytest does not try to resolve ``M_avg`` as a fixture.
test_rank_structure.__test__ = False
def test_parametric_ordering(M_avg):
    """Test 2: project rows onto the principal axes, sort by PC1, check smoothness.

    If the rows lie on a smooth parametric curve (polynomial, trig), sorting
    them by their PC1 projection should give a sequence whose PC2 / PC3
    coordinates change gradually. For each of PC2 and PC3:

        smoothness = 1 - mean(|consecutive Δ|) / (coordinate range + 1e-8)

    Args:
        M_avg: 2-D NumPy array of rows with at least 3 columns (the code
            indexes projections 0, 1 and 2).

    Returns:
        dict with 'sort_order' (row indices in ascending PC1 order),
        'smoothness_pc2', 'smoothness_pc3', 'pc1_range', 'pc2_range',
        'pc3_range' (floats) and 'verdict' (str, decided by the smaller of
        the two smoothness values).
    """
    _, _, Vt = np.linalg.svd(M_avg, full_matrices=False)
    # Coordinates of every row in the basis of right singular vectors.
    proj = M_avg @ Vt.T
    sort_idx = np.argsort(proj[:, 0])
    sorted_proj = proj[sort_idx]

    def axis_stats(column):
        """Return (smoothness, range) for one projected coordinate."""
        values = sorted_proj[:, column]
        spread = values.max() - values.min()
        # Mean step size relative to the overall spread; epsilon avoids 0/0
        # when the coordinate is constant.
        mean_step = np.abs(np.diff(values)).mean()
        return 1.0 - (mean_step / (spread + 1e-8)), spread

    smoothness_pc2, range_pc2 = axis_stats(1)
    smoothness_pc3, range_pc3 = axis_stats(2)

    worst = min(smoothness_pc2, smoothness_pc3)
    if worst > 0.85:
        verdict = 'smooth parametric curve'
    elif worst > 0.5:
        verdict = 'partial structure'
    else:
        verdict = 'scattered (cluster-like)'

    return {
        'sort_order': sort_idx.tolist(),
        'smoothness_pc2': float(smoothness_pc2),
        'smoothness_pc3': float(smoothness_pc3),
        'pc1_range': float(proj[:, 0].max() - proj[:, 0].min()),
        'pc2_range': float(range_pc2),
        'pc3_range': float(range_pc3),
        'verdict': verdict,
    }


# Despite its name this is a probe, not a pytest test; opt out of collection
# so pytest does not try to resolve ``M_avg`` as a fixture.
test_parametric_ordering.__test__ = False
def test_polynomial_fit(M_avg):
    """Test 3: fit polynomial and trigonometric bases to the ordered rows.

    Rows are projected onto the principal axes and ordered by PC1. PC2 and
    PC3 are each fitted as a function of PC1 (rescaled to [-1, 1]) with
    polynomials of degree 1-4 and with the trigonometric basis
    [1, sin(πx), cos(πx), sin(2πx), cos(2πx)].

    The reported polynomial degree is the LOWEST degree whose PC2 R² exceeds
    0.95. Polynomial models are nested, so the raw maximum of R² would almost
    always select degree 4 and say nothing about the actual order. If no
    degree clears the threshold, the degree with the highest PC2 R² is
    reported. The verdict is based on the PC2 fits only.

    Args:
        M_avg: 2-D NumPy array of rows with at least 3 columns.

    Returns:
        dict with 'polynomial' (per-degree R² for PC2 and PC3),
        'trigonometric' (R² for PC2 / PC3 and the PC2 coefficients),
        'best_poly_degree', 'best_poly_r2' and 'verdict'.
    """
    r2_threshold = 0.95
    degrees = (1, 2, 3, 4)

    _, _, Vt = np.linalg.svd(M_avg, full_matrices=False)
    proj = M_avg @ Vt.T
    sort_idx = np.argsort(proj[:, 0])
    x = proj[sort_idx, 0]
    y2 = proj[sort_idx, 1]
    y3 = proj[sort_idx, 2]
    # Normalize x to [-1, 1] for a well-conditioned polyfit.
    x_norm = 2 * (x - x.min()) / (x.max() - x.min() + 1e-8) - 1

    def r2(y_true, y_pred):
        """Coefficient of determination; epsilon guards a constant target."""
        ss_res = ((y_true - y_pred) ** 2).sum()
        ss_tot = ((y_true - y_true.mean()) ** 2).sum()
        return 1 - ss_res / (ss_tot + 1e-12)

    poly_results = {}
    for deg in degrees:
        pred2 = np.polyval(np.polyfit(x_norm, y2, deg), x_norm)
        pred3 = np.polyval(np.polyfit(x_norm, y3, deg), x_norm)
        poly_results[f'degree_{deg}'] = {
            'r2_pc2': float(r2(y2, pred2)),
            'r2_pc3': float(r2(y3, pred3)),
        }

    # Trigonometric fit:
    # y = a + b·sin(πx) + c·cos(πx) + d·sin(2πx) + e·cos(2πx)
    B = np.column_stack([
        np.ones_like(x_norm),
        np.sin(np.pi * x_norm), np.cos(np.pi * x_norm),
        np.sin(2 * np.pi * x_norm), np.cos(2 * np.pi * x_norm),
    ])
    coef2_t = np.linalg.lstsq(B, y2, rcond=None)[0]
    coef3_t = np.linalg.lstsq(B, y3, rcond=None)[0]
    trig_r2_pc2 = r2(y2, B @ coef2_t)
    trig_r2_pc3 = r2(y3, B @ coef3_t)

    # Lowest sufficient degree first; otherwise fall back to the best R².
    r2_by_degree = {d: poly_results[f'degree_{d}']['r2_pc2'] for d in degrees}
    best_poly_deg = next(
        (d for d in degrees if r2_by_degree[d] > r2_threshold),
        max(degrees, key=r2_by_degree.get),
    )
    best_poly_r2 = r2_by_degree[best_poly_deg]

    if best_poly_r2 > r2_threshold:
        verdict = f'polynomial degree {best_poly_deg} (R²={best_poly_r2:.3f})'
    elif trig_r2_pc2 > r2_threshold:
        verdict = f'trigonometric (R²={trig_r2_pc2:.3f})'
    else:
        verdict = (
            f'no clean parametric fit (best poly R²={best_poly_r2:.3f}, '
            f'trig R²={trig_r2_pc2:.3f})'
        )

    return {
        'polynomial': poly_results,
        'trigonometric': {
            'r2_pc2': float(trig_r2_pc2),
            'r2_pc3': float(trig_r2_pc3),
            'coefs_pc2': coef2_t.tolist(),
        },
        'best_poly_degree': best_poly_deg,
        'best_poly_r2': float(best_poly_r2),
        'verdict': verdict,
    }


# Despite its name this is a probe, not a pytest test; opt out of collection
# so pytest does not try to resolve ``M_avg`` as a fixture.
test_polynomial_fit.__test__ = False
def test_cluster_structure(M_avg):
    """Test 4: k-means + silhouette across k = 2..8.

    High silhouette at small k → genuine clusters. Low silhouette across
    all k → continuous spread (consistent with parametric structure).

    Args:
        M_avg: 2-D array of rows to cluster. k runs from 2 up to
            min(8, n_rows - 1); with fewer than 3 rows no k is evaluated,
            'best_k' is None and 'best_silhouette' stays at -1.0.

    Returns:
        dict with 'per_k' (silhouette and inertia keyed by 'k=<k>'),
        'best_k', 'best_silhouette' and 'verdict'.
    """
    results = {}
    best_k = None
    best_score = -1
    for k in range(2, min(9, M_avg.shape[0])):
        # Fixed seed so repeated probe runs give the same clustering.
        km = KMeans(n_clusters=k, n_init=10, random_state=42)
        labels = km.fit_predict(M_avg)
        # silhouette_score needs at least two distinct labels.
        if len(set(labels)) < 2:
            continue
        score = silhouette_score(M_avg, labels)
        results[f'k={k}'] = {
            'silhouette': float(score),
            'inertia': float(km.inertia_),
        }
        if score > best_score:
            best_score = score
            best_k = k

    if best_score > 0.5:
        verdict = f'strong clusters (k={best_k}, silhouette={best_score:.3f})'
    elif best_score > 0.25:
        verdict = f'weak clusters (k={best_k}, silhouette={best_score:.3f})'
    else:
        verdict = (
            f'no clear clusters (best silhouette={best_score:.3f}) — '
            f'consistent with continuous structure'
        )

    return {
        'per_k': results,
        'best_k': best_k,
        'best_silhouette': float(best_score),
        'verdict': verdict,
    }


# Despite its name this is a probe, not a pytest test; opt out of collection
# so pytest does not try to resolve ``M_avg`` as a fixture.
test_cluster_structure.__test__ = False
# ════════════════════════════════════════════════════════════════════
# Plotting
# ════════════════════════════════════════════════════════════════════
def plot_diagnostic(M_avg, all_M, results, output_path):
    """Draw the 4-panel diagnostic figure, save it, and show it.

    Panels: (1) 3-D scatter of the rows coloured by PC1 sort order,
    (2) singular-value bars, (3) PC2 / PC3 against PC1, (4) silhouette per k.

    Args:
        M_avg: averaged row matrix; its first three columns are plotted.
        all_M: accepted for interface compatibility; not used here.
        results: dict with the 'rank', 'parametric' and 'cluster' result
            dicts produced by the hypothesis tests.
        output_path: destination of the saved figure.
    """
    rank_info = results['rank']
    ordering_info = results['parametric']
    cluster_info = results['cluster']

    # Rows expressed in the right-singular-vector basis, ordered along PC1.
    _, _, axes_t = np.linalg.svd(M_avg, full_matrices=False)
    projected = M_avg @ axes_t.T
    order = np.argsort(projected[:, 0])

    fig = plt.figure(figsize=(16, 12))

    # Panel 1: 3D scatter of the canonical rows, one colour per sort rank.
    ax_rows = fig.add_subplot(2, 2, 1, projection='3d')
    palette = plt.cm.viridis(np.linspace(0, 1, len(order)))
    for shade, row_idx in zip(palette, order):
        ax_rows.scatter(M_avg[row_idx, 0], M_avg[row_idx, 1], M_avg[row_idx, 2],
                        c=[shade], s=80, edgecolors='black', linewidths=0.5)
    ax_rows.set_xlabel('D1')
    ax_rows.set_ylabel('D2')
    ax_rows.set_zlabel('D3')
    ax_rows.set_title(f'P-rank09 row matrix M (V=32, D=3)\n'
                      f'colored by PC1 sort order\n'
                      f'effective rank: {rank_info["effective_rank"]:.2f}')

    # Panel 2: singular value spectrum with the value printed on each bar.
    ax_sv = fig.add_subplot(2, 2, 2)
    sv_values = np.array(rank_info['singular_values'])
    ax_sv.bar(['SV1', 'SV2', 'SV3'], sv_values, color=['red', 'orange', 'yellow'])
    ax_sv.set_ylabel('Singular value')
    ax_sv.set_title(f'Singular values of M\n'
                    f'top1 share: {rank_info["top1_share"]:.2%}\n'
                    f'verdict: {rank_info["verdict"]}')
    for position, value in enumerate(sv_values):
        ax_sv.text(position, value, f'{value:.3f}', ha='center', va='bottom')

    # Panel 3: PC2 and PC3 as functions of PC1 (parametric curve test).
    ax_curve = fig.add_subplot(2, 2, 3)
    pc1 = projected[order, 0]
    pc2 = projected[order, 1]
    pc3 = projected[order, 2]
    ax_curve.plot(pc1, pc2, 'o-', color='blue', label='PC2 vs PC1', markersize=6)
    ax_curve.plot(pc1, pc3, 's-', color='green', label='PC3 vs PC1', markersize=6)
    ax_curve.set_xlabel('PC1 projection')
    ax_curve.set_ylabel('PC2 / PC3 projection')
    ax_curve.set_title(f'Parametric ordering test\n'
                       f'smoothness PC2: {ordering_info["smoothness_pc2"]:.3f}, '
                       f'PC3: {ordering_info["smoothness_pc3"]:.3f}\n'
                       f'verdict: {ordering_info["verdict"]}')
    ax_curve.legend()
    ax_curve.grid(alpha=0.3)

    # Panel 4: silhouette score for every evaluated k, with the two
    # thresholds used by the cluster verdict drawn as reference lines.
    ax_sil = fig.add_subplot(2, 2, 4)
    per_k = cluster_info['per_k']
    k_values = [int(label.split('=')[1]) for label in per_k]
    silhouettes = [entry['silhouette'] for entry in per_k.values()]
    ax_sil.plot(k_values, silhouettes, 'o-', color='purple', markersize=8)
    ax_sil.axhline(0.5, color='red', linestyle='--', alpha=0.5,
                   label='strong cluster threshold')
    ax_sil.axhline(0.25, color='orange', linestyle='--', alpha=0.5,
                   label='weak cluster threshold')
    ax_sil.set_xlabel('k (number of clusters)')
    ax_sil.set_ylabel('silhouette score')
    ax_sil.set_title(f'Cluster structure test\n'
                     f'best k={cluster_info["best_k"]}, '
                     f'silhouette={cluster_info["best_silhouette"]:.3f}\n'
                     f'verdict: {cluster_info["verdict"]}')
    ax_sil.legend(fontsize=8)
    ax_sil.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=120, bbox_inches='tight')
    plt.show()
# ════════════════════════════════════════════════════════════════════
# Main
# ════════════════════════════════════════════════════════════════════
def _json_default(value):
    """Fallback encoder for json.dump: turn NumPy values into native types.

    NumPy scalars become Python numbers and arrays become lists, so they are
    written as JSON numbers rather than strings; anything else is stringified.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    return str(value)


def main():
    """Run the full probe: load the model, collect rows, test, report.

    Steps: load the rank09 model, collect M matrices from noise inputs,
    average them into one canonical row matrix, run the four hypothesis
    tests, print per-test and composite verdicts, write the results to
    OUTPUT_JSON and the diagnostic figure to OUTPUT_PLOT.

    Returns:
        dict with the config summary, row-norm statistics and the result
        dicts of the four tests ('rank', 'parametric', 'fit', 'cluster').
    """
    print("Loading P-rank09 model...")
    model, cfg = load_rank09()
    print(f" Architecture: V={cfg.matrix_v}, D={cfg.D}, "
          f"patch_size={cfg.patch_size}, hidden={cfg.hidden}")
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {n_params:,}")

    print("\nCollecting M rows from gaussian inputs...")
    all_M = collect_rows(model, cfg, n_batches=8, batch_size=64)
    print(f" Collected {all_M.shape[0]} samples of M [V={all_M.shape[1]}, "
          f"D={all_M.shape[2]}]")

    # Average M over samples to get the canonical row matrix.
    M_avg = all_M.mean(dim=0).numpy()
    M_std = all_M.std(dim=0).numpy()
    print(f" M_avg shape: {M_avg.shape}")
    print(f" Per-row variability (mean ‖σ‖₂ across rows): "
          f"{np.linalg.norm(M_std, axis=1).mean():.4f}")
    print(f" Per-row mean magnitude (mean ‖μ‖₂): "
          f"{np.linalg.norm(M_avg, axis=1).mean():.4f}")

    # Sphere-norm verification on the averaged rows.
    row_norms = np.linalg.norm(M_avg, axis=1)
    print(f" Row norm range: [{row_norms.min():.4f}, {row_norms.max():.4f}]")
    print(" (sphere-normed rows should all have norm ~1.0)")

    print("\n" + "═" * 70)
    print("HYPOTHESIS TESTS")
    print("═" * 70)

    print("\n[1/4] Rank structure (SVD)...")
    rank_results = test_rank_structure(M_avg)
    print(f" Singular values: {[f'{s:.4f}' for s in rank_results['singular_values']]}")
    print(f" Effective rank: {rank_results['effective_rank']:.2f}")
    print(f" Top-1 share: {rank_results['top1_share']:.2%}")
    print(f" VERDICT: {rank_results['verdict']}")

    print("\n[2/4] Parametric ordering (PC1 sort + smoothness)...")
    param_results = test_parametric_ordering(M_avg)
    print(f" Smoothness PC2: {param_results['smoothness_pc2']:.3f}")
    print(f" Smoothness PC3: {param_results['smoothness_pc3']:.3f}")
    print(f" VERDICT: {param_results['verdict']}")

    print("\n[3/4] Polynomial / trigonometric fit...")
    fit_results = test_polynomial_fit(M_avg)
    print(" Polynomial fits (R² for PC2):")
    for deg in [1, 2, 3, 4]:
        r2 = fit_results['polynomial'][f'degree_{deg}']['r2_pc2']
        print(f" degree {deg}: R² = {r2:.4f}")
    print(f" Trigonometric fit (R² for PC2): "
          f"{fit_results['trigonometric']['r2_pc2']:.4f}")
    print(f" VERDICT: {fit_results['verdict']}")

    print("\n[4/4] Cluster structure (k-means silhouette)...")
    cluster_results = test_cluster_structure(M_avg)
    print(" Per-k silhouette:")
    for k_str, r in cluster_results['per_k'].items():
        print(f" {k_str}: silhouette = {r['silhouette']:.3f}")
    print(f" VERDICT: {cluster_results['verdict']}")

    all_results = {
        'config': {
            'variant': 'P_rank09_h64_V32_D3_dp0_nx0_adam',
            'V': cfg.matrix_v, 'D': cfg.D, 'params': n_params,
            # Hard-coded reference numbers, not computed by this script
            # (presumably copied from the training report; verify before
            # reusing this probe for another checkpoint).
            'gaussian_test_mse': 0.02782,
            'observed_cv': 0.035,
        },
        'M_avg_shape': list(M_avg.shape),
        'row_norms_mean': float(row_norms.mean()),
        'row_norms_std': float(row_norms.std()),
        'rank': rank_results,
        'parametric': param_results,
        'fit': fit_results,
        'cluster': cluster_results,
    }

    print("\n" + "═" * 70)
    print("OVERALL INTERPRETATION")
    print("═" * 70)
    print(f" Rank: {rank_results['verdict']}")
    print(f" Parametric: {param_results['verdict']}")
    print(f" Fit: {fit_results['verdict']}")
    print(f" Clusters: {cluster_results['verdict']}")

    # Composite verdict logic: polynomial takes precedence over trig, then
    # collapse, then clustering.
    is_polynomial = (
        fit_results['best_poly_r2'] > 0.95 and
        rank_results['effective_rank'] < 2.5
    )
    is_trig = (
        fit_results['trigonometric']['r2_pc2'] > 0.95 and
        not is_polynomial
    )
    is_clustered = cluster_results['best_silhouette'] > 0.5
    is_collapsed = rank_results['top1_share'] > 0.85

    print("\n Composite read:")
    if is_polynomial:
        deg = fit_results['best_poly_degree']
        print(f" → POLYNOMIAL CONFIRMED (degree {deg}). "
              f"P-Class naming validated.")
    elif is_trig:
        print(" → TRIGONOMETRIC structure detected. "
              "P-Class might be better named F-Class (Fourier).")
    elif is_collapsed:
        print(" → COLLAPSED — rows essentially 1-dimensional. "
              "Failed differentiation, not a useful battery.")
    elif is_clustered:
        k = cluster_results['best_k']
        print(f" → CLUSTERED into {k} groups. "
              f"P-Class might be better named K-Class "
              f"(k-means / quantization).")
    else:
        print(" → MIXED structure — not cleanly polynomial, trig, or "
              "clustered. Worth probing further with higher-order bases or "
              "deeper geometric analysis.")

    # NumPy scalars in the result dicts are converted to JSON numbers by
    # _json_default instead of being silently written as strings.
    with open(OUTPUT_JSON, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2, default=_json_default)
    print(f"\n Results saved: {OUTPUT_JSON}")

    plot_diagnostic(M_avg, all_M, all_results, OUTPUT_PLOT)
    print(f" Plot saved: {OUTPUT_PLOT}")
    return all_results


if __name__ == '__main__':
    results = main()