AbstractPhil committed on
Commit
9f3395f
·
verified ·
1 Parent(s): f4ce5fa

Create 12_test_claim_2_deeper.py

Browse files
Files changed (1) hide show
  1. 12_test_claim_2_deeper.py +579 -0
12_test_claim_2_deeper.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ implicit_solver/A1_projective_reprobe_h2a.py
3
+ ============================================
4
+
5
+ Same projective probe as A0, applied to H2a (Q-rank02, V=32, D=4).
6
+
7
+ Tests whether the projective interpretation generalizes:
8
+ - A0 found G-Cand (D=3) has uniform distribution on ℝP² when collapsed.
9
+ - A1 tests whether H2a (D=4) shows the same on ℝP³.
10
+
11
+ Predicted outcomes
12
+ ------------------
13
+ A. UNIFORM ℝP³ ALSO: H2a's rows collapse to N axes uniformly distributed
14
+ on ℝP³ (deviation from baseline < 0.05). Projective reading is
15
+ GENERAL — works at any D. Polygonal omega derivation via sphere
16
+ training is validated as a method, not a D=3 quirk.
17
+
18
+ B. STILL SPHERICAL: H2a shows few antipodal pairs (< 4), and what few
19
+ axes get collapsed don't show uniform ℝP³ distribution. Projective
20
+ reading is D=3-SPECIFIC — sphere-starvation symptom. D=4 genuinely
21
+ lives on S³ as designed.
22
+
23
+ C. INTERMEDIATE: Some collapse but not full uniform. Mixed regime.
24
+
25
+ Cost: ~10 seconds (same checkpoint we already have).
26
+
27
+ Output
28
+ ------
29
+ /content/implicit_solver_reports/A1_projective_reprobe_h2a.json
30
+ /content/implicit_solver_reports/A1_projective_reprobe_h2a.png
31
+ """
32
+
33
+ import json
34
+ import math
35
+ from pathlib import Path
36
+
37
+ import numpy as np
38
+ import torch
39
+ import matplotlib.pyplot as plt
40
+ from mpl_toolkits.mplot3d import Axes3D # noqa
41
+ from sklearn.cluster import KMeans
42
+ from sklearn.metrics import silhouette_score
43
+
44
# Location of the phase-Q run reports and the H2a (rank02) checkpoint.
CKPT_DIR = Path("/content/phaseQ_reports")
RANK02_CKPT = CKPT_DIR / "Q_rank02_h64_V32_D4_dp0_nx0_adam" / "epoch_1_checkpoint.pt"

# Where this probe writes its JSON report and figure.
# NOTE: the directory is created at import time (module-level side effect).
OUTPUT_DIR = Path("/content/implicit_solver_reports")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_PLOT = OUTPUT_DIR / "A1_projective_reprobe_h2a.png"
OUTPUT_JSON = OUTPUT_DIR / "A1_projective_reprobe_h2a.json"
51
+
52
+
53
+ # ════════════════════════════════════════════════════════════════════
54
+ # Loading
55
+ # ════════════════════════════════════════════════════════════════════
56
+
57
def load_h2a():
    """Build the H2a (Q-rank02) model and restore its checkpoint.

    Picks the first phase-Q config whose ``variant`` contains ``'rank02'``,
    builds the run config from it, instantiates ``PatchSVAE_F_Ablation`` and
    loads the weights stored at ``RANK02_CKPT``.

    NOTE(review): ``get_phaseQ_configs``, ``build_run_config`` and
    ``PatchSVAE_F_Ablation`` are not imported in this file; presumably they are
    provided by the surrounding notebook session — confirm.

    Returns:
        tuple: ``(model, cfg)`` — the model in eval mode and the run config
        returned by ``build_run_config``.

    Raises:
        LookupError: if no phase-Q config mentions ``'rank02'``.
    """
    cfgs = get_phaseQ_configs()
    cfg_dict = next((c for c in cfgs if 'rank02' in c['variant']), None)
    if cfg_dict is None:
        # A bare next() would leak an opaque StopIteration here.
        raise LookupError("No phase-Q config with 'rank02' in its variant name")
    cfg = build_run_config(cfg_dict)
    overrides = cfg_dict['overrides']

    # Per-variant overrides win over the built config / hard-coded defaults.
    model = PatchSVAE_F_Ablation(
        matrix_v=cfg.matrix_v, D=cfg.D, patch_size=cfg.patch_size,
        hidden=cfg.hidden, depth=cfg.depth,
        n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads,
        max_alpha=overrides.get('max_alpha', cfg.max_alpha),
        alpha_init=cfg.alpha_init,
        activation=overrides.get('activation', 'gelu'),
        row_norm=overrides.get('row_norm', 'sphere'),
        svd_mode=overrides.get('svd', 'fp64'),
        linear_readout=overrides.get('linear_readout', False),
        match_params=overrides.get('match_params', True),
        init_scheme=overrides.get('init', 'orthogonal'),
    )

    # SECURITY NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    ckpt = torch.load(RANK02_CKPT, map_location='cpu', weights_only=False)
    # The weights may be nested under one of several keys, or the checkpoint
    # may itself be the state dict.
    state_dict = (
        ckpt.get('model_state')
        or ckpt.get('model_state_dict')
        or ckpt.get('state_dict')
        or ckpt
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model, cfg
87
+
88
+
89
def collect_per_sample_M(model, cfg, n_batches=8, batch_size=64):
    """Run the model on generated noise images and gather patch-0 ``M`` tensors.

    Args:
        model: module whose forward output exposes ``out['svd']['M']``.
        cfg: run config; only ``cfg.img_size`` is read here.
        n_batches: number of batches to generate.
        batch_size: samples per batch (dataset size is ``n_batches * batch_size``).

    Returns:
        numpy.ndarray: the per-sample ``M[:, 0]`` slices concatenated along
        axis 0 (presumably shape ``(n_samples, V, D)`` — TODO confirm).

    Side effects:
        Moves ``model`` to CUDA when available (``Module.to`` is in place).

    NOTE(review): ``OmegaNoiseDataset`` is not imported in this file;
    ``allowed_types=[0]`` presumably selects gaussian noise — confirm.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    dataset = OmegaNoiseDataset(
        size=n_batches * batch_size, img_size=cfg.img_size,
        allowed_types=[0])
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False)

    collected = []
    with torch.no_grad():
        for imgs, _ in loader:
            out = model(imgs.to(device))
            # Keep only the first patch of every sample.
            collected.append(out['svd']['M'][:, 0].cpu())
    return torch.cat(collected, dim=0).numpy()
105
+
106
+
107
+ # ════════════════════════════════════════════════════════════════════
108
+ # Antipodal pair identification + projective collapse (carry from A0)
109
+ # ════════════════════════════════════════════════════════════════════
110
+
111
def identify_antipodal_pairs(M_avg, threshold=-0.9):
    """Greedily pair rows of ``M_avg`` that point in near-opposite directions.

    Each row nominates its most anti-aligned row (lowest cosine); nominations
    with cosine below ``threshold`` are processed from most to least antipodal
    and accepted when both rows are still free.

    NOTE: the acceptance test is "mutual strongest OR cosine below threshold".
    Every nomination already satisfies the threshold and the cosine matrix is
    symmetric, so in practice the mutual check never rejects a pair.

    Args:
        M_avg: 2-D array, one vector per row.
        threshold: cosine below which two rows count as antipodal.

    Returns:
        tuple: ``(pairs, unpaired)`` — ``pairs`` is a list of ``(i, j)`` with
        ``i < j``; ``unpaired`` lists the remaining row indices in order.
    """
    norms = np.linalg.norm(M_avg, axis=1, keepdims=True)
    unit = M_avg / np.clip(norms, 1e-12, None)  # guard against zero rows
    cosines = unit @ unit.T
    # A row must never nominate itself, so make the diagonal the maximum.
    np.fill_diagonal(cosines, 1.0)

    n_rows = M_avg.shape[0]
    claimed = [False] * n_rows

    candidates = []
    for i in range(n_rows):
        best_j = int(cosines[i].argmin())
        best_cos = float(cosines[i, best_j])
        if best_cos < threshold:
            candidates.append((best_cos, i, best_j))
    candidates.sort()  # most antipodal first

    pairs = []
    for _cos_val, i, j in candidates:
        if claimed[i] or claimed[j]:
            continue
        if cosines[j].argmin() == i or cosines[j, i] < threshold:
            pairs.append((min(i, j), max(i, j)))
            claimed[i] = True
            claimed[j] = True

    unpaired = [i for i in range(n_rows) if not claimed[i]]
    return pairs, unpaired
140
+
141
+
142
def collapse_to_axes(M_avg, pairs, unpaired):
    """Collapse rows to one sign-canonical unit representative per axis.

    For each pair ``(i, j)`` the representative is the normalized difference of
    the two unit rows (a symmetric merge of antipodal directions). Unpaired
    rows are used as-is after normalization. Every representative is flipped
    so that its first coordinate with magnitude above ``1e-6`` is positive.

    Args:
        M_avg: 2-D array, one vector per row.
        pairs: iterable of ``(i, j)`` row-index pairs.
        unpaired: iterable of row indices not in any pair.

    Returns:
        numpy.ndarray: shape ``(len(pairs) + len(unpaired), D)``; pair
        representatives come first. With no representatives the result has
        shape ``(0, D)`` so callers can still read ``shape[1]``.
    """
    norms = np.linalg.norm(M_avg, axis=1, keepdims=True)
    unit = M_avg / np.clip(norms, 1e-12, None)  # guard against zero rows

    representatives = []
    for i, j in pairs:
        merged = unit[i] - unit[j]
        merged = merged / max(np.linalg.norm(merged), 1e-12)
        representatives.append(_canonicalize_sign(merged))

    for i in unpaired:
        representatives.append(_canonicalize_sign(unit[i].copy()))

    if not representatives:
        return np.empty((0, unit.shape[1]), dtype=unit.dtype)
    return np.array(representatives)


def _canonicalize_sign(vec, tol=1e-6):
    """Return ``vec`` or ``-vec`` so the first entry with ``|x| > tol`` is positive."""
    significant = np.flatnonzero(np.abs(vec) > tol)
    if significant.size and vec[significant[0]] < 0:
        return -vec
    return vec
170
+
171
+
172
+ # ════════════════════════════════════════════════════════════════════
173
+ # Projective metrics
174
+ # ════════════════════════════════════════════════════════════════════
175
+
176
def projective_pairwise_angles(axes):
    """Return pairwise angles between rows, measured on ℝP^(D-1).

    The spherical angle θ in [0, π] is folded to [0, π/2] via min(θ, π-θ), so a
    vector and its negation are treated as the same axis.

    Args:
        axes: 2-D array of unit vectors, one per row (rows are assumed to be
            normalized; cosines are clipped to [-1, 1] before ``arccos``).

    Returns:
        numpy.ndarray: 1-D array with one angle per unordered row pair, in
        upper-triangle (row-major) order.
    """
    n = axes.shape[0]
    cosines = np.clip(axes @ axes.T, -1, 1)
    raw_angles = np.arccos(cosines)
    proj_angles = np.minimum(raw_angles, np.pi - raw_angles)
    upper = np.triu_indices(n, k=1)  # strictly above the diagonal
    return proj_angles[upper]
185
+
186
+
187
def uniform_rp_pairwise_angle_baseline(D, n_axes, n_trials=10, seed=0):
    """Empirical mean projective pairwise angle for uniform axes on ℝP^(D-1).

    Draws ``n_axes`` normalized Gaussian vectors per trial, canonicalizes their
    sign, and averages ``projective_pairwise_angles`` over ``n_trials`` trials.

    Args:
        D: ambient dimension of the sampled vectors.
        n_axes: number of axes sampled per trial.
        n_trials: number of independent trials to average.
        seed: seed for the ``RandomState`` (default 0, so results are
            reproducible between calls).

    Returns:
        float: mean over trials of the mean pairwise projective angle, or
        ``nan`` when ``n_axes < 2`` (there are no pairs to measure).
    """
    if n_axes < 2:
        return float('nan')

    rng = np.random.RandomState(seed)
    rows = np.arange(n_axes)
    means = []
    for _ in range(n_trials):
        x = rng.randn(n_axes, D)
        x = x / np.linalg.norm(x, axis=1, keepdims=True)
        # Canonicalize to the upper hemisphere: flip each row so its first
        # nonzero coordinate is positive. Done per row in one pass; the
        # previous column loop stopped at the first column with no pending
        # rows and could skip rows whose first nonzero entry came later.
        first_nz = np.argmax(x != 0, axis=1)
        signs = np.sign(x[rows, first_nz])
        signs[signs == 0] = 1.0  # all-zero rows are left untouched
        x = x * signs[:, None]
        angles = projective_pairwise_angles(x)
        means.append(angles.mean())
    return float(np.mean(means))
204
+
205
+
206
def test_axis_distribution(axes, label):
    """Print and return distribution metrics for a set of projective axes.

    Computes projective pairwise-angle statistics, the deviation of their mean
    from the sampled uniform baseline, a KMeans/silhouette sweep over
    k = 2..min(8, n)-1, the entropy-based effective rank of the axis matrix,
    and a count of remaining near-antipodal rows (cosine < -0.9).

    NOTE: despite the ``test_`` prefix this is a metrics routine, not a unit
    test; pytest would try to collect it if this module matched its patterns.

    Args:
        axes: 2-D array of unit vectors, shape ``(n, D)``.
        label: heading printed before the metrics.

    Returns:
        dict: JSON-friendly metrics (Python ints/floats/lists/str/None); see
        the keys in the return statement.
    """
    D = axes.shape[1]
    n = axes.shape[0]

    print(f"\n[{label}]")
    print(f" Axes shape: {axes.shape}")

    proj_angles = projective_pairwise_angles(axes)

    print(f" Projective pairwise angles (radians, max π/2={math.pi/2:.3f}):")
    print(f" mean: {proj_angles.mean():.3f}")
    print(f" median: {np.median(proj_angles):.3f}")
    print(f" min: {proj_angles.min():.3f}")
    print(f" max: {proj_angles.max():.3f}")

    uniform_baseline = uniform_rp_pairwise_angle_baseline(D, n)
    deviation = proj_angles.mean() - uniform_baseline
    print(f" Uniform ℝP^{D-1} baseline (n={n}): {uniform_baseline:.3f}")
    print(f" Deviation: {deviation:+.3f} "
          f"({'CLOSE TO UNIFORM' if abs(deviation) < 0.05 else 'NON-UNIFORM'})")

    fraction_clustered = (proj_angles < 0.3).mean()
    print(f" Fraction near-zero (axes parallel): {fraction_clustered:.3f}")

    # Silhouette sweep; a failing k is reported and skipped (best effort).
    sils = []
    for k in range(2, min(8, n)):
        try:
            km = KMeans(n_clusters=k, n_init=10, random_state=42)
            labels = km.fit_predict(axes)
            if len(set(labels)) >= 2:
                # Cast to a Python float: a NumPy float32 score would be
                # stringified by json.dump(default=str) downstream.
                sils.append((k, float(silhouette_score(axes, labels))))
        except Exception as exc:
            print(f" (clustering skipped for k={k}: {exc})")

    if sils:
        best_k, best_sil = max(sils, key=lambda ks: ks[1])
        print(f" Best cluster k={best_k}, silhouette={best_sil:.3f}")
        if best_sil > 0.5:
            cluster_verdict = 'STRONG (real clusters)'
        elif best_sil > 0.3:
            cluster_verdict = 'WEAK (some structure)'
        else:
            cluster_verdict = 'NONE (continuous distribution)'
        print(f" Cluster verdict: {cluster_verdict}")
    else:
        best_k, best_sil = None, None
        cluster_verdict = 'N/A'

    # Effective rank = exp(entropy of the normalized singular values).
    sv = np.linalg.svd(axes, compute_uv=False)
    sv_norm = sv / sv.sum()
    erank = math.exp(-(sv_norm * np.log(sv_norm + 1e-12)).sum())
    print(f" Effective rank: {erank:.2f} of {D} possible "
          f"({erank/D*100:.0f}% utilization)")

    # Rows that still have a near-antipodal partner after the collapse.
    cos_axes = axes @ axes.T
    np.fill_diagonal(cos_axes, 1.0)
    most_anti = cos_axes.min(axis=1)
    secondary_anti = (most_anti < -0.9).sum() // 2
    print(f" Secondary antipodal pairs: {secondary_anti}/{n//2}")

    return {
        'n_axes': int(n),
        'D': int(D),
        'proj_angle_mean': float(proj_angles.mean()),
        'proj_angle_median': float(np.median(proj_angles)),
        'proj_angle_min': float(proj_angles.min()),
        'proj_angle_max': float(proj_angles.max()),
        'uniform_baseline': uniform_baseline,
        'deviation_from_uniform': float(deviation),
        'fraction_clustered': float(fraction_clustered),
        'best_cluster_k': best_k,
        'best_silhouette': best_sil,
        'cluster_verdict': cluster_verdict,
        'effective_rank': float(erank),
        'utilization': float(erank / D),
        'secondary_antipodal_pairs': int(secondary_anti),
        # Only the first 200 angles are stored to keep the report small.
        'proj_angles_subset': proj_angles[:200].tolist(),
    }
283
+
284
+
285
+ # ════════════════════════════════════════════════════════════════════
286
+ # Plotting
287
+ # ════════════════════════════════════════════════════════════════════
288
+
289
def plot_projective(M_avg, axes, pairs, unpaired, results, output_path,
                    g_cand_results=None):
    """Render the 6-panel projective report figure and save it.

    Panels: (1) unit rows of ``M_avg`` with antipodal pairs joined, (2) the
    collapsed axes, (3) projective-angle histogram with the uniform baseline,
    (4) silhouette score across k, (5) singular values of the axis matrix,
    (6) a text verdict plus key metrics. The 3-D panels show only the first
    three coordinates, so the inputs need at least three columns.

    Args:
        M_avg: 2-D array of rows before the collapse.
        axes: 2-D array of collapsed axes.
        pairs: list of ``(i, j)`` antipodal row pairs.
        unpaired: list of row indices not in any pair.
        results: metrics dict for ``axes`` (keys as read below).
        output_path: where the PNG is written.
        g_cand_results: optional metrics dict from the A0 (G-Cand) run, drawn
            as a comparison overlay.

    Side effects:
        Saves the figure to ``output_path`` and calls ``plt.show()``.
    """
    fig = plt.figure(figsize=(18, 12))

    # Panel 1: Original M_avg projected to first 3 dims
    ax1 = fig.add_subplot(2, 3, 1, projection='3d')
    norms = np.linalg.norm(M_avg, axis=1, keepdims=True)
    unit = M_avg / np.clip(norms, 1e-12, None)

    # Unit-sphere wireframe shared by both 3-D panels.
    u = np.linspace(0, 2*np.pi, 20)
    v = np.linspace(0, np.pi, 20)
    x_s = np.outer(np.cos(u), np.sin(v))
    y_s = np.outer(np.sin(u), np.sin(v))
    z_s = np.outer(np.ones_like(u), np.cos(v))
    ax1.plot_wireframe(x_s, y_s, z_s, alpha=0.1, color='gray')

    pair_colors = plt.cm.tab20(np.linspace(0, 1, max(len(pairs), 1)))
    for k, (i, j) in enumerate(pairs):
        color = pair_colors[k]
        for row in (i, j):
            ax1.scatter(unit[row, 0], unit[row, 1], unit[row, 2],
                        c=[color], s=80, edgecolors='black', linewidths=0.5)
        ax1.plot([unit[i, 0], unit[j, 0]],
                 [unit[i, 1], unit[j, 1]],
                 [unit[i, 2], unit[j, 2]],
                 color=color, alpha=0.3, linewidth=0.8)
    for i in unpaired:
        ax1.scatter(unit[i, 0], unit[i, 1], unit[i, 2],
                    c='blue', marker='o', s=80,
                    edgecolors='black', linewidths=0.5, alpha=0.7)
    ax1.set_title(f'H2a M_avg projected to first 3 dims\n'
                  f'{len(pairs)} antipodal pairs (colored), '
                  f'{len(unpaired)} unpaired (blue)')

    # Panel 2: Collapsed axes (first 3 dims)
    ax2 = fig.add_subplot(2, 3, 2, projection='3d')
    ax2.plot_wireframe(x_s, y_s, z_s, alpha=0.1, color='gray')
    for k, axis_vec in enumerate(axes):
        axis_color = plt.cm.tab20(k % 20)
        ax2.scatter(axis_vec[0], axis_vec[1], axis_vec[2], c=[axis_color],
                    s=120, edgecolors='black', linewidths=0.5)
        # Draw the full line through the origin: an axis has no direction.
        ax2.plot([-axis_vec[0], axis_vec[0]],
                 [-axis_vec[1], axis_vec[1]],
                 [-axis_vec[2], axis_vec[2]],
                 color=axis_color, alpha=0.4, linewidth=1.0)
    ax2.set_title(f'Collapsed axes (n={axes.shape[0]})\n'
                  f'D={axes.shape[1]} → projected to first 3 dims')

    # Panel 3: Projective angle distribution + uniform baseline + G-Cand overlay
    ax3 = fig.add_subplot(2, 3, 3)
    proj_angles = results['proj_angles_subset']
    ax3.hist(proj_angles, bins=30, density=True, alpha=0.7,
             color='steelblue', label=f'H2a projective (D={results["D"]})')
    if g_cand_results is not None:
        ax3.hist(g_cand_results['proj_angles_subset'], bins=30, density=True,
                 alpha=0.4, color='red', label='G-Cand projective (D=3)')
    ax3.axvline(results['uniform_baseline'], color='blue', linestyle='--',
                label=f"H2a uniform ℝP³ ({results['uniform_baseline']:.3f})")
    if g_cand_results is not None:
        ax3.axvline(g_cand_results['uniform_baseline'], color='red',
                    linestyle=':', alpha=0.5,
                    label=f"G-Cand uniform ℝP² ({g_cand_results['uniform_baseline']:.3f})")
    ax3.set_xlabel('Projective pairwise angle (radians)')
    ax3.set_ylabel('Density')
    ax3.set_title(f'Projective angle distribution\n'
                  f"H2a deviation: {results['deviation_from_uniform']:+.3f}")
    ax3.legend(fontsize=8)

    # Panel 4: Cluster silhouette across k
    ax4 = fig.add_subplot(2, 3, 4)
    if results['best_cluster_k'] is not None:
        ks_sils = []
        for k in range(2, min(8, axes.shape[0])):
            try:
                km = KMeans(n_clusters=k, n_init=10, random_state=42)
                labels = km.fit_predict(axes)
                if len(set(labels)) >= 2:
                    ks_sils.append((k, silhouette_score(axes, labels)))
            except Exception:
                # Best effort: a k that cannot be clustered is left off the plot.
                continue
        if ks_sils:
            ks, sils = zip(*ks_sils)
            ax4.plot(ks, sils, 'o-', color='purple', markersize=8)
            ax4.axhline(0.5, color='red', linestyle='--', alpha=0.5,
                        label='strong cluster')
            ax4.axhline(0.3, color='orange', linestyle='--', alpha=0.5,
                        label='weak cluster')
            ax4.legend(fontsize=8)
    ax4.set_xlabel('k (number of clusters)')
    ax4.set_ylabel('silhouette score')
    ax4.set_title(f"Axis clustering\n"
                  f"verdict: {results['cluster_verdict']}")
    ax4.grid(alpha=0.3)

    # Panel 5: Singular values
    ax5 = fig.add_subplot(2, 3, 5)
    sv = np.linalg.svd(axes, compute_uv=False)
    ax5.bar([f'σ{i+1}' for i in range(len(sv))], sv,
            color=plt.cm.viridis(np.linspace(0.2, 0.8, len(sv))))
    ax5.set_ylabel('Singular value')
    ax5.set_title(f"Singular values of axis matrix\n"
                  f"effective rank: {results['effective_rank']:.2f} "
                  f"of {results['D']}")

    # Panel 6: Comparison verdict
    ax6 = fig.add_subplot(2, 3, 6)
    ax6.axis('off')

    is_uniform = abs(results['deviation_from_uniform']) < 0.05
    is_clustered = (results['best_silhouette'] or 0) > 0.5
    has_secondary = results['secondary_antipodal_pairs'] >= 3
    full_rank = results['utilization'] > 0.95

    # NOTE(review): the "still spherical" branch tests n_axes >= 6*D here,
    # whereas main() tests len(pairs) <= 4 — the two verdicts can disagree.
    if is_uniform and not is_clustered and not has_secondary and full_rank:
        verdict = "✓ ALSO ℝP³ UNIFORM"
        explanation = (
            "H2a's collapsed axes are uniformly distributed on ℝP³.\n"
            "Projective interpretation GENERALIZES beyond D=3.\n\n"
            "Sphere-solvers in general are projective at the level of\n"
            "their geometric output. Polygonal omega derivation via\n"
            "sphere-trained anchors is validated as a method."
        )
        color = 'lightgreen'
    elif results['n_axes'] >= results['D'] * 6 and full_rank:
        # Many axes, full rank → still strongly spherical
        verdict = "✗ STILL ESSENTIALLY SPHERICAL"
        explanation = (
            f"H2a has {results['n_axes']} axes (vs G-Cand's smaller count),\n"
            f"few antipodal pairs were identified, full rank utilization.\n\n"
            f"Projective collapse barely changes the picture at D=4.\n"
            f"D=3 was a special case — sphere-starvation symptom.\n"
            f"D=4 lives on S³ as designed."
        )
        color = 'lightyellow'
    elif is_uniform:
        verdict = "✓ MOSTLY ℝP³, full rank"
        explanation = (
            "H2a collapses to axes that are roughly uniform on ℝP³.\n"
            "Projective reading IS valid at D=4 too, with caveats."
        )
        color = 'palegreen'
    else:
        verdict = "? MIXED RESULT"
        explanation = (
            "H2a doesn't cleanly fit either ℝP³ uniform or pure spherical.\n"
            "Geometry is more complex than the simple projective hypothesis."
        )
        color = 'lightgray'

    ax6.text(0.5, 0.85, verdict, ha='center', va='top',
             fontsize=18, fontweight='bold',
             bbox=dict(boxstyle='round', facecolor=color, alpha=0.8))
    ax6.text(0.05, 0.55, explanation, ha='left', va='top', fontsize=10,
             wrap=True, family='monospace')

    metrics_summary = (
        f"\n\nKey metrics (H2a):\n"
        f" axes: {results['n_axes']}\n"
        f" proj angle mean: {results['proj_angle_mean']:.3f}\n"
        f" uniform baseline: {results['uniform_baseline']:.3f}\n"
        f" deviation: {results['deviation_from_uniform']:+.3f}\n"
        f" best cluster silhouette: {results['best_silhouette'] or 0:.3f}\n"
        f" effective rank: {results['effective_rank']:.2f}/{results['D']}\n"
        f" secondary antipodal: {results['secondary_antipodal_pairs']}\n"
    )
    if g_cand_results is not None:
        # The stored silhouette may be None (no clustering possible) or a
        # string (older reports were written with default=str); coerce it so
        # the :.3f format cannot raise.
        g_cand_sil = float(g_cand_results.get('best_silhouette') or 0)
        metrics_summary += (
            f"\nG-Cand comparison:\n"
            f" axes: {g_cand_results['n_axes']}\n"
            f" deviation: {g_cand_results['deviation_from_uniform']:+.3f}\n"
            f" best silhouette: {g_cand_sil:.3f}\n"
        )
    ax6.text(0.05, 0.30, metrics_summary, ha='left', va='top',
             fontsize=9, family='monospace')

    plt.tight_layout()
    plt.savefig(output_path, dpi=120, bbox_inches='tight')
    plt.show()
467
+
468
+
469
+ # ════════════════════════════════════════════════════════════════════
470
+ # Main
471
+ # ════════════════════════════════════════════════════════════════════
472
+
473
def main():
    """Run the H2a projective re-probe end to end.

    Loads the H2a checkpoint, averages the patch-0 ``M`` tensor over generated
    samples, pairs antipodal rows, collapses them to axes, computes the
    distribution metrics, writes the JSON report and figure, and prints a
    headline conclusion. If an A0 (G-Cand) report exists in ``OUTPUT_DIR`` it
    is loaded for side-by-side comparison.

    Returns:
        dict: the data written to ``OUTPUT_JSON``.
    """
    print("=" * 70)
    print("Projective re-probe of H2a (Q-rank02, V=32, D=4)")
    print("Tests whether projective interpretation generalizes from D=3 → D=4")
    print("=" * 70)

    print("\nLoading H2a checkpoint...")
    model, cfg = load_h2a()
    print(f" V={cfg.matrix_v}, D={cfg.D}, "
          f"params={sum(p.numel() for p in model.parameters()):,}")

    print("\nCollecting M tensor (512 gaussian samples)...")
    all_M = collect_per_sample_M(model, cfg)
    M_avg = all_M.mean(axis=0)
    print(f" M_avg shape: {M_avg.shape}")

    print("\nIdentifying antipodal pairs (cos < -0.9, mutual-strongest)...")
    pairs, unpaired = identify_antipodal_pairs(M_avg, threshold=-0.9)
    print(f" Found {len(pairs)} antipodal pairs")
    print(f" Unpaired rows: {len(unpaired)}")
    print(f" Total accounted: {2*len(pairs) + len(unpaired)} of {M_avg.shape[0]}")

    print("\nCollapsing to projective axes...")
    axes = collapse_to_axes(M_avg, pairs, unpaired)
    print(f" Axes: {axes.shape[0]} representatives in {axes.shape[1]}-D")

    results = test_axis_distribution(axes, "H2a projective axes")

    # Try to load A0 (G-Cand) results for side-by-side comparison
    g_cand_results = None
    g_cand_data = {}
    g_cand_json = OUTPUT_DIR / "A0_projective_reprobe.json"
    if g_cand_json.exists():
        with open(g_cand_json, encoding='utf-8') as f:
            g_cand_data = json.load(f)
        # A report without metrics just disables the comparison.
        g_cand_results = g_cand_data.get('projective_metrics')
        if g_cand_results is not None:
            print("\n (Loaded A0 G-Cand results for comparison)")

    output_data = {
        'config': {
            # Same name as the checkpoint directory.
            'variant': RANK02_CKPT.parent.name,
            'V': cfg.matrix_v,
            'D': cfg.D,
        },
        'antipodal_pairs_found': len(pairs),
        'unpaired_rows': len(unpaired),
        'total_axes': axes.shape[0],
        'projective_metrics': results,
        'pairs': [list(p) for p in pairs],
        'unpaired': unpaired,
    }

    with open(OUTPUT_JSON, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2, default=_json_default)
    print(f"\nSaved: {OUTPUT_JSON}")

    plot_projective(M_avg, axes, pairs, unpaired, results, OUTPUT_PLOT,
                    g_cand_results=g_cand_results)
    print(f"Saved: {OUTPUT_PLOT}")

    # Headline conclusion
    print("\n" + "=" * 70)
    print("CONCLUSION — generalization test")
    print("=" * 70)

    is_uniform = abs(results['deviation_from_uniform']) < 0.05
    is_clustered = (results['best_silhouette'] or 0) > 0.5
    has_secondary = results['secondary_antipodal_pairs'] >= 3
    full_rank = results['utilization'] > 0.95

    print("\n H2a (D=4, V=32):")
    print(f" {len(pairs)} antipodal pairs, {axes.shape[0]} total axes")
    print(f" Projective angle mean: {results['proj_angle_mean']:.3f}")
    print(f" ℝP³ uniform baseline: {results['uniform_baseline']:.3f}")
    print(f" Deviation: {results['deviation_from_uniform']:+.3f}")

    if g_cand_results is not None:
        print("\n G-Cand (D=3, V=32) for comparison:")
        print(f" {g_cand_data.get('antipodal_pairs_found', '?')} antipodal pairs, "
              f"{g_cand_data.get('total_axes', '?')} total axes")
        print(f" Projective angle mean: {g_cand_results['proj_angle_mean']:.3f}")
        print(f" ℝP² uniform baseline: {g_cand_results['uniform_baseline']:.3f}")
        print(f" Deviation: {g_cand_results['deviation_from_uniform']:+.3f}")

    print("\n" + "-" * 70)
    if is_uniform and not is_clustered and not has_secondary and full_rank:
        print(" ✓ PROJECTIVE READING GENERALIZES")
        print(" H2a also collapses to uniform projective distribution.")
        print(" The polytope-implicit-in-sphere hypothesis is supported")
        print(" at D=4 too. Inference-projection framing is general.")
    elif len(pairs) <= 4 and full_rank:
        print(" ✗ PROJECTIVE READING IS D=3-SPECIFIC")
        print(" H2a has very few antipodal pairs — most rows didn't")
        print(" collapse. The projective reading is a sphere-starvation")
        print(" symptom, not a general property of trained sphere-solvers.")
        print(" D=4 lives on S³ as designed.")
    else:
        print(" ? INTERMEDIATE RESULT")
        print(" H2a shows partial collapse with unclear interpretation.")
        print(" Need to think about whether the metric thresholds")
        print(" (uniform deviation, cluster silhouette) are appropriate")
        print(" at higher D where the unfilled space is much larger.")

    return output_data


def _json_default(obj):
    """JSON fallback: unwrap NumPy values so numbers stay numbers; stringify the rest."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)
576
+
577
+
578
# Script entry point; `results` is kept at module level for notebook inspection.
if __name__ == '__main__':
    results = main()