AbstractPhil commited on
Commit
bd4ab66
Β·
verified Β·
1 Parent(s): 947e75a

Create 6_probe_winners_ft1.py

Browse files
Files changed (1) hide show
  1. 6_probe_winners_ft1.py +508 -0
6_probe_winners_ft1.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ cell_p_class_probe.py β€” geometric structure probe for P-Class batteries
3
+
4
+ Loads P-rank09 (h64_V32_D3_dp0_nx0_adam, MSE 0.028, CV 0.03) and asks
5
+ what its 32 row vectors in 3D space actually look like.
6
+
7
+ Four hypothesis tests:
8
+ 1. RANK STRUCTURE β€” SVD on the 32Γ—3 row matrix M.
9
+ - Polynomial basis: rank ≀ 2 (Vandermonde collapses)
10
+ - Trig basis: rank = 2 or 3 with specific singular value ratio
11
+ - Cluster: rank 3, all SVs comparable
12
+ - Collapsed: rank 1, one dominant SV
13
+
14
+ 2. PARAMETRIC ORDERING β€” Try ordering rows by their first coordinate
15
+ (or first principal axis projection). If rows form a smooth curve
16
+ when ordered, we're seeing a parametric structure (polynomial,
17
+ trig, etc). If they're scattered with no order, it's clusters.
18
+ Metric: smoothness of consecutive Ξ” when sorted along PC1.
19
+
20
+ 3. POLYNOMIAL FIT TEST β€” Fit a Vandermonde matrix to the ordered rows.
21
+ If RΒ² > 0.95 with cubic, polynomial hypothesis confirmed.
22
+ Try [1, x, xΒ²], [1, x, xΒ², xΒ³], [1, sin(x), cos(x)].
23
+
24
+ 4. CLUSTER COUNT β€” k-means with k = 2..8 on the 32 rows. If silhouette
25
+ score is high at small k, it's clustered. If silhouette is low for
26
+ all k, the rows are spread continuously (consistent with parametric).
27
+
28
+ Outputs:
29
+ - Console verdict for each hypothesis
30
+ - /content/phaseQ_reports/p_rank09_probe.png β€” 4-panel diagnostic plot
31
+ - /content/phaseQ_reports/p_rank09_probe.json β€” all numerical results
32
+ """
33
+
34
+ import json
35
+ import math
36
+ from pathlib import Path
37
+
38
+ import numpy as np
39
+ import torch
40
+ import torch.nn.functional as F
41
+ import matplotlib.pyplot as plt
42
+ from mpl_toolkits.mplot3d import Axes3D # noqa
43
+ from sklearn.cluster import KMeans
44
+ from sklearn.metrics import silhouette_score
45
+
46
+
47
+ CKPT_DIR = Path("/content/phaseQ_reports")
48
+ RANK09_CKPT = CKPT_DIR / "Q_rank09_h64_V32_D3_dp0_nx0_adam" / "epoch_1_checkpoint.pt"
49
+ OUTPUT_PLOT = CKPT_DIR / "p_rank09_probe.png"
50
+ OUTPUT_JSON = CKPT_DIR / "p_rank09_probe.json"
51
+
52
+
53
+ def load_rank09():
54
+ """Reconstruct P-rank09 model and load its trained weights."""
55
+ cfgs = get_phaseQ_configs()
56
+ rank09_cfg = next(c for c in cfgs if 'rank09' in c['variant'])
57
+ cfg = build_run_config(rank09_cfg)
58
+ overrides = rank09_cfg['overrides']
59
+
60
+ model = PatchSVAE_F_Ablation(
61
+ matrix_v=cfg.matrix_v, D=cfg.D, patch_size=cfg.patch_size,
62
+ hidden=cfg.hidden, depth=cfg.depth,
63
+ n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads,
64
+ max_alpha=overrides.get('max_alpha', cfg.max_alpha),
65
+ alpha_init=cfg.alpha_init,
66
+ activation=overrides.get('activation', 'gelu'),
67
+ row_norm=overrides.get('row_norm', 'sphere'),
68
+ svd_mode=overrides.get('svd', 'fp64'),
69
+ linear_readout=overrides.get('linear_readout', False),
70
+ match_params=overrides.get('match_params', True),
71
+ init_scheme=overrides.get('init', 'orthogonal'),
72
+ )
73
+
74
+ ckpt = torch.load(RANK09_CKPT, map_location='cpu', weights_only=False)
75
+ # Trainer saves model weights under 'model_state'; the older
76
+ # 'model_state_dict' / 'state_dict' fallbacks are kept for compatibility.
77
+ state_dict = (
78
+ ckpt.get('model_state')
79
+ or ckpt.get('model_state_dict')
80
+ or ckpt.get('state_dict')
81
+ or ckpt
82
+ )
83
+ model.load_state_dict(state_dict)
84
+ model.eval()
85
+ return model, cfg
86
+
87
+
88
+ def collect_rows(model, cfg, n_batches=8, batch_size=64):
89
+ """Run gaussian noise through encoder, collect M rows from one canonical
90
+ patch position to get a stable [n_samples, V, D] tensor of row matrices."""
91
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
92
+ model = model.to(device)
93
+
94
+ ds = OmegaNoiseDataset(
95
+ size=n_batches * batch_size,
96
+ img_size=cfg.img_size,
97
+ allowed_types=[0]) # gaussian
98
+ loader = torch.utils.data.DataLoader(
99
+ ds, batch_size=batch_size, shuffle=False)
100
+
101
+ all_M = [] # collect M from patch 0 of every sample
102
+ with torch.no_grad():
103
+ for imgs, _ in loader:
104
+ imgs = imgs.to(device)
105
+ out = model(imgs)
106
+ # M shape: [B, N_patches, V, D]
107
+ M_patch0 = out['svd']['M'][:, 0] # [B, V, D]
108
+ all_M.append(M_patch0.cpu())
109
+
110
+ return torch.cat(all_M, dim=0) # [n_samples, V, D]
111
+
112
+
113
+ # ════════════════════════════════════════════════════════════════════
114
+ # Hypothesis tests
115
+ # ════════════════════════════════════════════════════════════════════
116
+
117
+ def test_rank_structure(M_avg):
118
+ """Test 1: SVD on the canonical row matrix.
119
+
120
+ M_avg: averaged 32Γ—3 row matrix. SVD gives 3 singular values.
121
+ Predictions:
122
+ Polynomial Vandermonde: top-1 SV dominates, rankβ‰ˆ1-2
123
+ Trig basis: balanced top-2 SVs, small 3rd
124
+ Sphere uniform (H2): ~equal SVs, full rank
125
+ Cluster: depends on cluster geometry
126
+ """
127
+ U, S, Vt = np.linalg.svd(M_avg, full_matrices=False)
128
+ S_norm = S / S.sum()
129
+ erank = math.exp(-(S_norm * np.log(S_norm + 1e-12)).sum())
130
+
131
+ return {
132
+ 'singular_values': S.tolist(),
133
+ 'normalized_SV': S_norm.tolist(),
134
+ 'effective_rank': erank,
135
+ 'top1_share': S_norm[0],
136
+ 'top2_share': S_norm[:2].sum(),
137
+ 'verdict': (
138
+ 'rank-1 (collapsed/aligned)' if S_norm[0] > 0.85 else
139
+ 'rank-2 (planar β€” could be polynomial or trig)' if S_norm[:2].sum() > 0.92 else
140
+ 'rank-3 (full, balanced)' if S_norm.std() < 0.05 else
141
+ 'rank-3 (full, imbalanced)'
142
+ ),
143
+ }
144
+
145
+
146
+ def test_parametric_ordering(M_avg):
147
+ """Test 2: Project rows onto first principal axis, sort, check smoothness.
148
+
149
+ If rows lie on a smooth parametric curve (polynomial, trig), sorting
150
+ by PC1 projection should produce a smooth sequence. Smoothness =
151
+ 1 / variance of consecutive Ξ” in PC2/PC3 coords (after sort).
152
+ """
153
+ U, S, Vt = np.linalg.svd(M_avg, full_matrices=False)
154
+ # Project rows onto principal axes
155
+ proj = M_avg @ Vt.T # [V, 3]
156
+
157
+ # Sort by PC1
158
+ sort_idx = np.argsort(proj[:, 0])
159
+ sorted_proj = proj[sort_idx]
160
+
161
+ # Ξ” between consecutive sorted rows in PC2, PC3
162
+ deltas_pc2 = np.diff(sorted_proj[:, 1])
163
+ deltas_pc3 = np.diff(sorted_proj[:, 2])
164
+
165
+ # If smooth curve, Ξ” should be small relative to overall PC2/PC3 spread
166
+ range_pc2 = sorted_proj[:, 1].max() - sorted_proj[:, 1].min()
167
+ range_pc3 = sorted_proj[:, 2].max() - sorted_proj[:, 2].min()
168
+
169
+ smoothness_pc2 = 1.0 - (np.abs(deltas_pc2).mean() / (range_pc2 + 1e-8))
170
+ smoothness_pc3 = 1.0 - (np.abs(deltas_pc3).mean() / (range_pc3 + 1e-8))
171
+
172
+ return {
173
+ 'sort_order': sort_idx.tolist(),
174
+ 'smoothness_pc2': float(smoothness_pc2),
175
+ 'smoothness_pc3': float(smoothness_pc3),
176
+ 'pc1_range': float(proj[:, 0].max() - proj[:, 0].min()),
177
+ 'pc2_range': float(range_pc2),
178
+ 'pc3_range': float(range_pc3),
179
+ 'verdict': (
180
+ 'smooth parametric curve' if min(smoothness_pc2, smoothness_pc3) > 0.85 else
181
+ 'partial structure' if min(smoothness_pc2, smoothness_pc3) > 0.5 else
182
+ 'scattered (cluster-like)'
183
+ ),
184
+ }
185
+
186
+
187
+ def test_polynomial_fit(M_avg):
188
+ """Test 3: Try polynomial bases of various orders.
189
+
190
+ Order rows by PC1 projection. Fit each PC2/PC3 coordinate as a function
191
+ of PC1. Polynomial degrees 1, 2, 3, 4. Best-fit RΒ² tells us the order.
192
+ Also tries [1, sin(x), cos(x)] for trigonometric basis.
193
+ """
194
+ U, S, Vt = np.linalg.svd(M_avg, full_matrices=False)
195
+ proj = M_avg @ Vt.T
196
+ sort_idx = np.argsort(proj[:, 0])
197
+
198
+ x = proj[sort_idx, 0]
199
+ y2 = proj[sort_idx, 1]
200
+ y3 = proj[sort_idx, 2]
201
+
202
+ # Normalize x to [-1, 1] for stable polyfit
203
+ x_norm = 2 * (x - x.min()) / (x.max() - x.min() + 1e-8) - 1
204
+
205
+ def r2(y_true, y_pred):
206
+ ss_res = ((y_true - y_pred) ** 2).sum()
207
+ ss_tot = ((y_true - y_true.mean()) ** 2).sum()
208
+ return 1 - ss_res / (ss_tot + 1e-12)
209
+
210
+ poly_results = {}
211
+ for deg in [1, 2, 3, 4]:
212
+ coef2 = np.polyfit(x_norm, y2, deg)
213
+ coef3 = np.polyfit(x_norm, y3, deg)
214
+ pred2 = np.polyval(coef2, x_norm)
215
+ pred3 = np.polyval(coef3, x_norm)
216
+ poly_results[f'degree_{deg}'] = {
217
+ 'r2_pc2': float(r2(y2, pred2)),
218
+ 'r2_pc3': float(r2(y3, pred3)),
219
+ }
220
+
221
+ # Trigonometric fit: y = a + bΒ·sin(Ο€x) + cΒ·cos(Ο€x) + dΒ·sin(2Ο€x) + eΒ·cos(2Ο€x)
222
+ def trig_basis(x):
223
+ return np.column_stack([
224
+ np.ones_like(x),
225
+ np.sin(np.pi * x), np.cos(np.pi * x),
226
+ np.sin(2 * np.pi * x), np.cos(2 * np.pi * x),
227
+ ])
228
+
229
+ B = trig_basis(x_norm)
230
+ coef2_t, _, _, _ = np.linalg.lstsq(B, y2, rcond=None)
231
+ coef3_t, _, _, _ = np.linalg.lstsq(B, y3, rcond=None)
232
+ trig_r2_pc2 = r2(y2, B @ coef2_t)
233
+ trig_r2_pc3 = r2(y3, B @ coef3_t)
234
+
235
+ # Pick the best fit
236
+ best_poly_deg = max([1, 2, 3, 4],
237
+ key=lambda d: poly_results[f'degree_{d}']['r2_pc2'])
238
+ best_poly_r2 = poly_results[f'degree_{best_poly_deg}']['r2_pc2']
239
+
240
+ return {
241
+ 'polynomial': poly_results,
242
+ 'trigonometric': {
243
+ 'r2_pc2': float(trig_r2_pc2),
244
+ 'r2_pc3': float(trig_r2_pc3),
245
+ 'coefs_pc2': coef2_t.tolist(),
246
+ },
247
+ 'best_poly_degree': best_poly_deg,
248
+ 'best_poly_r2': float(best_poly_r2),
249
+ 'verdict': (
250
+ f'polynomial degree {best_poly_deg} (RΒ²={best_poly_r2:.3f})'
251
+ if best_poly_r2 > 0.95 else
252
+ f'trigonometric (RΒ²={trig_r2_pc2:.3f})'
253
+ if trig_r2_pc2 > 0.95 else
254
+ f'no clean parametric fit (best poly RΒ²={best_poly_r2:.3f}, '
255
+ f'trig RΒ²={trig_r2_pc2:.3f})'
256
+ ),
257
+ }
258
+
259
+
260
+ def test_cluster_structure(M_avg):
261
+ """Test 4: k-means + silhouette across k = 2..8.
262
+
263
+ High silhouette at small k β†’ genuine clusters. Low silhouette across
264
+ all k β†’ continuous spread (consistent with parametric structure).
265
+ """
266
+ results = {}
267
+ best_k = None
268
+ best_score = -1
269
+ for k in range(2, min(9, M_avg.shape[0])):
270
+ km = KMeans(n_clusters=k, n_init=10, random_state=42)
271
+ labels = km.fit_predict(M_avg)
272
+ if len(set(labels)) < 2:
273
+ continue
274
+ score = silhouette_score(M_avg, labels)
275
+ results[f'k={k}'] = {
276
+ 'silhouette': float(score),
277
+ 'inertia': float(km.inertia_),
278
+ }
279
+ if score > best_score:
280
+ best_score = score
281
+ best_k = k
282
+
283
+ return {
284
+ 'per_k': results,
285
+ 'best_k': best_k,
286
+ 'best_silhouette': float(best_score),
287
+ 'verdict': (
288
+ f'strong clusters (k={best_k}, silhouette={best_score:.3f})'
289
+ if best_score > 0.5 else
290
+ f'weak clusters (k={best_k}, silhouette={best_score:.3f})'
291
+ if best_score > 0.25 else
292
+ f'no clear clusters (best silhouette={best_score:.3f}) β€” '
293
+ f'consistent with continuous structure'
294
+ ),
295
+ }
296
+
297
+
298
+ # ════════════════════════════════════════════════════════════════════
299
+ # Plotting
300
+ # ════════════════════════════════════════════════════════════════════
301
+
302
+ def plot_diagnostic(M_avg, all_M, results, output_path):
303
+ """4-panel diagnostic plot."""
304
+ fig = plt.figure(figsize=(16, 12))
305
+
306
+ # Panel 1: 3D scatter of the canonical 32 rows
307
+ ax1 = fig.add_subplot(2, 2, 1, projection='3d')
308
+ U, S, Vt = np.linalg.svd(M_avg, full_matrices=False)
309
+ proj = M_avg @ Vt.T
310
+ sort_idx = np.argsort(proj[:, 0])
311
+ colors = plt.cm.viridis(np.linspace(0, 1, len(sort_idx)))
312
+ for i, idx in enumerate(sort_idx):
313
+ ax1.scatter(M_avg[idx, 0], M_avg[idx, 1], M_avg[idx, 2],
314
+ c=[colors[i]], s=80, edgecolors='black', linewidths=0.5)
315
+ ax1.set_xlabel('D1')
316
+ ax1.set_ylabel('D2')
317
+ ax1.set_zlabel('D3')
318
+ ax1.set_title(f'P-rank09 row matrix M (V=32, D=3)\n'
319
+ f'colored by PC1 sort order\n'
320
+ f'effective rank: {results["rank"]["effective_rank"]:.2f}')
321
+
322
+ # Panel 2: Singular value spectrum
323
+ ax2 = fig.add_subplot(2, 2, 2)
324
+ SVs = np.array(results['rank']['singular_values'])
325
+ ax2.bar(['SV1', 'SV2', 'SV3'], SVs, color=['red', 'orange', 'yellow'])
326
+ ax2.set_ylabel('Singular value')
327
+ ax2.set_title(f'Singular values of M\n'
328
+ f'top1 share: {results["rank"]["top1_share"]:.2%}\n'
329
+ f'verdict: {results["rank"]["verdict"]}')
330
+ for i, sv in enumerate(SVs):
331
+ ax2.text(i, sv, f'{sv:.3f}', ha='center', va='bottom')
332
+
333
+ # Panel 3: PC2 and PC3 vs PC1 (parametric curve test)
334
+ ax3 = fig.add_subplot(2, 2, 3)
335
+ x = proj[sort_idx, 0]
336
+ y2 = proj[sort_idx, 1]
337
+ y3 = proj[sort_idx, 2]
338
+ ax3.plot(x, y2, 'o-', color='blue', label='PC2 vs PC1', markersize=6)
339
+ ax3.plot(x, y3, 's-', color='green', label='PC3 vs PC1', markersize=6)
340
+ ax3.set_xlabel('PC1 projection')
341
+ ax3.set_ylabel('PC2 / PC3 projection')
342
+ ax3.set_title(f'Parametric ordering test\n'
343
+ f'smoothness PC2: {results["parametric"]["smoothness_pc2"]:.3f}, '
344
+ f'PC3: {results["parametric"]["smoothness_pc3"]:.3f}\n'
345
+ f'verdict: {results["parametric"]["verdict"]}')
346
+ ax3.legend()
347
+ ax3.grid(alpha=0.3)
348
+
349
+ # Panel 4: Cluster silhouette across k
350
+ ax4 = fig.add_subplot(2, 2, 4)
351
+ ks = []
352
+ sils = []
353
+ for k_str, r in results['cluster']['per_k'].items():
354
+ ks.append(int(k_str.split('=')[1]))
355
+ sils.append(r['silhouette'])
356
+ ax4.plot(ks, sils, 'o-', color='purple', markersize=8)
357
+ ax4.axhline(0.5, color='red', linestyle='--', alpha=0.5,
358
+ label='strong cluster threshold')
359
+ ax4.axhline(0.25, color='orange', linestyle='--', alpha=0.5,
360
+ label='weak cluster threshold')
361
+ ax4.set_xlabel('k (number of clusters)')
362
+ ax4.set_ylabel('silhouette score')
363
+ ax4.set_title(f'Cluster structure test\n'
364
+ f'best k={results["cluster"]["best_k"]}, '
365
+ f'silhouette={results["cluster"]["best_silhouette"]:.3f}\n'
366
+ f'verdict: {results["cluster"]["verdict"]}')
367
+ ax4.legend(fontsize=8)
368
+ ax4.grid(alpha=0.3)
369
+
370
+ plt.tight_layout()
371
+ plt.savefig(output_path, dpi=120, bbox_inches='tight')
372
+ plt.show()
373
+
374
+
375
+ # ════════════════════════════════════════════════════════════════════
376
+ # Main
377
+ # ════════════════════════════════════════════════════════════════════
378
+
379
+ def main():
380
+ print("Loading P-rank09 model...")
381
+ model, cfg = load_rank09()
382
+ print(f" Architecture: V={cfg.matrix_v}, D={cfg.D}, "
383
+ f"patch_size={cfg.patch_size}, hidden={cfg.hidden}")
384
+ n_params = sum(p.numel() for p in model.parameters())
385
+ print(f" Parameters: {n_params:,}")
386
+
387
+ print("\nCollecting M rows from gaussian inputs...")
388
+ all_M = collect_rows(model, cfg, n_batches=8, batch_size=64)
389
+ print(f" Collected {all_M.shape[0]} samples of M [V={all_M.shape[1]}, "
390
+ f"D={all_M.shape[2]}]")
391
+
392
+ # Average M over samples to get the canonical row matrix
393
+ M_avg = all_M.mean(dim=0).numpy()
394
+ M_std = all_M.std(dim=0).numpy()
395
+ print(f" M_avg shape: {M_avg.shape}")
396
+ print(f" Per-row variability (mean β€–Οƒβ€–β‚‚ across rows): "
397
+ f"{np.linalg.norm(M_std, axis=1).mean():.4f}")
398
+ print(f" Per-row mean magnitude (mean β€–ΞΌβ€–β‚‚): "
399
+ f"{np.linalg.norm(M_avg, axis=1).mean():.4f}")
400
+
401
+ # Sphere-norm verification
402
+ row_norms = np.linalg.norm(M_avg, axis=1)
403
+ print(f" Row norm range: [{row_norms.min():.4f}, {row_norms.max():.4f}]")
404
+ print(f" (sphere-normed rows should all have norm ~1.0)")
405
+
406
+ print("\n" + "═" * 70)
407
+ print("HYPOTHESIS TESTS")
408
+ print("═" * 70)
409
+
410
+ print("\n[1/4] Rank structure (SVD)...")
411
+ rank_results = test_rank_structure(M_avg)
412
+ print(f" Singular values: {[f'{s:.4f}' for s in rank_results['singular_values']]}")
413
+ print(f" Effective rank: {rank_results['effective_rank']:.2f}")
414
+ print(f" Top-1 share: {rank_results['top1_share']:.2%}")
415
+ print(f" VERDICT: {rank_results['verdict']}")
416
+
417
+ print("\n[2/4] Parametric ordering (PC1 sort + smoothness)...")
418
+ param_results = test_parametric_ordering(M_avg)
419
+ print(f" Smoothness PC2: {param_results['smoothness_pc2']:.3f}")
420
+ print(f" Smoothness PC3: {param_results['smoothness_pc3']:.3f}")
421
+ print(f" VERDICT: {param_results['verdict']}")
422
+
423
+ print("\n[3/4] Polynomial / trigonometric fit...")
424
+ fit_results = test_polynomial_fit(M_avg)
425
+ print(f" Polynomial fits (RΒ² for PC2):")
426
+ for deg in [1, 2, 3, 4]:
427
+ r2 = fit_results['polynomial'][f'degree_{deg}']['r2_pc2']
428
+ print(f" degree {deg}: RΒ² = {r2:.4f}")
429
+ print(f" Trigonometric fit (RΒ² for PC2): "
430
+ f"{fit_results['trigonometric']['r2_pc2']:.4f}")
431
+ print(f" VERDICT: {fit_results['verdict']}")
432
+
433
+ print("\n[4/4] Cluster structure (k-means silhouette)...")
434
+ cluster_results = test_cluster_structure(M_avg)
435
+ print(f" Per-k silhouette:")
436
+ for k_str, r in cluster_results['per_k'].items():
437
+ print(f" {k_str}: silhouette = {r['silhouette']:.3f}")
438
+ print(f" VERDICT: {cluster_results['verdict']}")
439
+
440
+ all_results = {
441
+ 'config': {
442
+ 'variant': 'P_rank09_h64_V32_D3_dp0_nx0_adam',
443
+ 'V': cfg.matrix_v, 'D': cfg.D, 'params': n_params,
444
+ 'gaussian_test_mse': 0.02782,
445
+ 'observed_cv': 0.035,
446
+ },
447
+ 'M_avg_shape': list(M_avg.shape),
448
+ 'row_norms_mean': float(row_norms.mean()),
449
+ 'row_norms_std': float(row_norms.std()),
450
+ 'rank': rank_results,
451
+ 'parametric': param_results,
452
+ 'fit': fit_results,
453
+ 'cluster': cluster_results,
454
+ }
455
+
456
+ print("\n" + "═" * 70)
457
+ print("OVERALL INTERPRETATION")
458
+ print("═" * 70)
459
+ print(f" Rank: {rank_results['verdict']}")
460
+ print(f" Parametric: {param_results['verdict']}")
461
+ print(f" Fit: {fit_results['verdict']}")
462
+ print(f" Clusters: {cluster_results['verdict']}")
463
+
464
+ # Composite verdict logic
465
+ is_polynomial = (
466
+ fit_results['best_poly_r2'] > 0.95 and
467
+ rank_results['effective_rank'] < 2.5
468
+ )
469
+ is_trig = (
470
+ fit_results['trigonometric']['r2_pc2'] > 0.95 and
471
+ not is_polynomial
472
+ )
473
+ is_clustered = cluster_results['best_silhouette'] > 0.5
474
+ is_collapsed = rank_results['top1_share'] > 0.85
475
+
476
+ print(f"\n Composite read:")
477
+ if is_polynomial:
478
+ deg = fit_results['best_poly_degree']
479
+ print(f" β†’ POLYNOMIAL CONFIRMED (degree {deg}). "
480
+ f"P-Class naming validated.")
481
+ elif is_trig:
482
+ print(f" β†’ TRIGONOMETRIC structure detected. "
483
+ f"P-Class might be better named F-Class (Fourier).")
484
+ elif is_collapsed:
485
+ print(f" β†’ COLLAPSED β€” rows essentially 1-dimensional. "
486
+ f"Failed differentiation, not a useful battery.")
487
+ elif is_clustered:
488
+ k = cluster_results['best_k']
489
+ print(f" β†’ CLUSTERED into {k} groups. "
490
+ f"P-Class might be better named K-Class "
491
+ f"(k-means / quantization).")
492
+ else:
493
+ print(f" β†’ MIXED structure β€” not cleanly polynomial, trig, or "
494
+ f"clustered. Worth probing further with higher-order bases or "
495
+ f"deeper geometric analysis.")
496
+
497
+ with open(OUTPUT_JSON, 'w') as f:
498
+ json.dump(all_results, f, indent=2, default=str)
499
+ print(f"\n Results saved: {OUTPUT_JSON}")
500
+
501
+ plot_diagnostic(M_avg, all_M, all_results, OUTPUT_PLOT)
502
+ print(f" Plot saved: {OUTPUT_PLOT}")
503
+
504
+ return all_results
505
+
506
+
507
+ if __name__ == '__main__':
508
+ results = main()