AbstractPhil commited on
Commit
6ddd93d
Β·
verified Β·
1 Parent(s): bd4ab66

Create 7_probe_ft2.py

Browse files
Files changed (1) hide show
  1. 7_probe_ft2.py +566 -0
7_probe_ft2.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ cell_p_class_probe_v2.py β€” deeper geometric probe for P-Class
3
+
4
+ Addresses limitations of v1's averaged-M analysis:
5
+ 1. Verify sphere-norm is enforced per-sample (M rows should be unit-length
6
+ per-sample, even if they average to sub-unit across samples)
7
+ 2. Test structure on PER-SAMPLE M, not averaged
8
+ 3. Check if the 5-cluster finding from v1 is consistent or sample-dependent
9
+ 4. Spherical structure analysis: project rows to SΒ², test for angular
10
+ distribution structure (uniform? clustered? band-like?)
11
+ 5. Reconstruct what the H2 sphere-solver looks like for comparison
12
+
13
+ Key question: are the 32 rows really clustered, or does each sample have
14
+ its own spread of 32 rows on SΒ² that AVERAGE to look clustered?
15
+ """
16
+
17
+ import json
18
+ import math
19
+ from pathlib import Path
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import matplotlib.pyplot as plt
25
+ from mpl_toolkits.mplot3d import Axes3D # noqa
26
+ from sklearn.cluster import KMeans
27
+ from sklearn.metrics import silhouette_score
28
+
29
+
30
+ CKPT_DIR = Path("/content/phaseQ_reports")
31
+ RANK09_CKPT = CKPT_DIR / "Q_rank09_h64_V32_D3_dp0_nx0_adam" / "epoch_1_checkpoint.pt"
32
+ RANK02_CKPT = CKPT_DIR / "Q_rank02_h64_V32_D4_dp0_nx0_adam" / "epoch_1_checkpoint.pt"
33
+ OUTPUT_PLOT = CKPT_DIR / "p_rank09_probe_v2.png"
34
+ OUTPUT_JSON = CKPT_DIR / "p_rank09_probe_v2.json"
35
+
36
+
37
+ def load_model(variant_str, ckpt_path):
38
+ cfgs = get_phaseQ_configs()
39
+ cfg_dict = next(c for c in cfgs if variant_str in c['variant'])
40
+ cfg = build_run_config(cfg_dict)
41
+ overrides = cfg_dict['overrides']
42
+
43
+ model = PatchSVAE_F_Ablation(
44
+ matrix_v=cfg.matrix_v, D=cfg.D, patch_size=cfg.patch_size,
45
+ hidden=cfg.hidden, depth=cfg.depth,
46
+ n_cross_layers=cfg.n_cross_layers, n_heads=cfg.n_heads,
47
+ max_alpha=overrides.get('max_alpha', cfg.max_alpha),
48
+ alpha_init=cfg.alpha_init,
49
+ activation=overrides.get('activation', 'gelu'),
50
+ row_norm=overrides.get('row_norm', 'sphere'),
51
+ svd_mode=overrides.get('svd', 'fp64'),
52
+ linear_readout=overrides.get('linear_readout', False),
53
+ match_params=overrides.get('match_params', True),
54
+ init_scheme=overrides.get('init', 'orthogonal'),
55
+ )
56
+
57
+ ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
58
+ state_dict = (
59
+ ckpt.get('model_state')
60
+ or ckpt.get('model_state_dict')
61
+ or ckpt.get('state_dict')
62
+ or ckpt
63
+ )
64
+ model.load_state_dict(state_dict)
65
+ model.eval()
66
+ return model, cfg
67
+
68
+
69
+ def collect_per_sample_M(model, cfg, n_batches=8, batch_size=64):
70
+ """Same as v1 but does NOT average β€” returns per-sample M tensors."""
71
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
72
+ model = model.to(device)
73
+
74
+ ds = OmegaNoiseDataset(
75
+ size=n_batches * batch_size,
76
+ img_size=cfg.img_size,
77
+ allowed_types=[0])
78
+ loader = torch.utils.data.DataLoader(
79
+ ds, batch_size=batch_size, shuffle=False)
80
+
81
+ all_M = []
82
+ with torch.no_grad():
83
+ for imgs, _ in loader:
84
+ imgs = imgs.to(device)
85
+ out = model(imgs)
86
+ M_patch0 = out['svd']['M'][:, 0]
87
+ all_M.append(M_patch0.cpu())
88
+
89
+ return torch.cat(all_M, dim=0).numpy() # [n_samples, V, D]
90
+
91
+
92
+ # ════════════════════════════════════════════════════════════════════
93
+ # Test 1: Per-sample sphere-norm verification
94
+ # ════════════════════════════════════════════════════════════════════
95
+
96
+ def test_sphere_norm(all_M, label):
97
+ """Verify that per-sample rows are unit-length (sphere-normed)."""
98
+ print(f"\n[{label}] PER-SAMPLE sphere-norm verification:")
99
+
100
+ # all_M shape: [n_samples, V, D]
101
+ row_norms = np.linalg.norm(all_M, axis=2) # [n_samples, V]
102
+
103
+ print(f" Per-sample row norms:")
104
+ print(f" overall min: {row_norms.min():.4f}")
105
+ print(f" overall max: {row_norms.max():.4f}")
106
+ print(f" overall mean: {row_norms.mean():.4f}")
107
+ print(f" overall std: {row_norms.std():.4f}")
108
+
109
+ is_normed = (
110
+ abs(row_norms.mean() - 1.0) < 0.05 and
111
+ row_norms.std() < 0.05
112
+ )
113
+ print(f" Sphere-norm enforced per-sample: {is_normed}")
114
+
115
+ return {
116
+ 'row_norms_min': float(row_norms.min()),
117
+ 'row_norms_max': float(row_norms.max()),
118
+ 'row_norms_mean': float(row_norms.mean()),
119
+ 'row_norms_std': float(row_norms.std()),
120
+ 'sphere_normed_per_sample': bool(is_normed),
121
+ }
122
+
123
+
124
+ # ════════════════════════════════════════════════════════════════════
125
+ # Test 2: Sample-to-sample row stability
126
+ # ══════════════════��═════════════════════════════════════════════════
127
+
128
+ def test_row_stability(all_M, label):
129
+ """For each row index i in [0, V), how much does row i vary across
130
+ samples? If rows are stable (each row index always points the same
131
+ direction), per-sample structure β‰ˆ averaged structure. If unstable,
132
+ averaging blurs structure."""
133
+ print(f"\n[{label}] PER-ROW stability across samples:")
134
+
135
+ # all_M: [n_samples, V, D]
136
+ # For each row index, compute mean direction and variance around it
137
+ n_samples, V, D = all_M.shape
138
+
139
+ # Mean direction per row index (re-normalized to unit)
140
+ mean_dirs = all_M.mean(axis=0) # [V, D]
141
+ mean_dir_norms = np.linalg.norm(mean_dirs, axis=1) # [V]
142
+
143
+ # If sample row directions are tightly clustered around their mean,
144
+ # mean_dir_norm β‰ˆ 1.0. If they're scattered uniformly, mean_dir_norm β‰ˆ 0.
145
+ # This is the "spread index" β€” how concentrated each row index's
146
+ # direction is across samples.
147
+ print(f" Mean direction norms (concentration of row[i] across samples):")
148
+ print(f" min: {mean_dir_norms.min():.4f} (most variable row)")
149
+ print(f" max: {mean_dir_norms.max():.4f} (most stable row)")
150
+ print(f" mean: {mean_dir_norms.mean():.4f}")
151
+
152
+ return {
153
+ 'mean_dir_norms_min': float(mean_dir_norms.min()),
154
+ 'mean_dir_norms_max': float(mean_dir_norms.max()),
155
+ 'mean_dir_norms_mean': float(mean_dir_norms.mean()),
156
+ 'mean_dirs': mean_dirs.tolist(),
157
+ 'mean_dir_norms': mean_dir_norms.tolist(),
158
+ }
159
+
160
+
161
+ # ════════════════════════════════════════════════════════════════════
162
+ # Test 3: Per-sample cluster consistency
163
+ # ════════════════════════════════════════════════════════════════════
164
+
165
+ def test_per_sample_clustering(all_M, k_test=5, n_samples_to_check=20):
166
+ """For each of n_samples_to_check samples, run k-means clustering on its
167
+ own 32 rows. If we consistently get strong clusters at the same k, the
168
+ structure is intrinsic to each sample. If silhouette varies wildly, the
169
+ averaged result was an artifact."""
170
+ print(f"\nPER-SAMPLE k=5 clustering (testing first {n_samples_to_check} samples):")
171
+
172
+ silhouettes = []
173
+ for i in range(min(n_samples_to_check, all_M.shape[0])):
174
+ M = all_M[i] # [V, D]
175
+ try:
176
+ km = KMeans(n_clusters=k_test, n_init=10, random_state=42)
177
+ labels = km.fit_predict(M)
178
+ if len(set(labels)) >= 2:
179
+ sil = silhouette_score(M, labels)
180
+ silhouettes.append(sil)
181
+ except Exception:
182
+ pass
183
+
184
+ silhouettes = np.array(silhouettes)
185
+ print(f" Silhouette across samples (k={k_test}):")
186
+ print(f" mean: {silhouettes.mean():.3f}")
187
+ print(f" std: {silhouettes.std():.3f}")
188
+ print(f" range: [{silhouettes.min():.3f}, {silhouettes.max():.3f}]")
189
+
190
+ return {
191
+ 'k_tested': k_test,
192
+ 'silhouettes_per_sample': silhouettes.tolist(),
193
+ 'mean_silhouette': float(silhouettes.mean()),
194
+ 'std_silhouette': float(silhouettes.std()),
195
+ 'min_silhouette': float(silhouettes.min()) if len(silhouettes) > 0 else None,
196
+ 'max_silhouette': float(silhouettes.max()) if len(silhouettes) > 0 else None,
197
+ }
198
+
199
+
200
+ # ════════════════════════════════════════════════════════════════════
201
+ # Test 4: Angular distribution on the sphere
202
+ # ════════════════════════════════════════════════════════════════════
203
+
204
+ def test_angular_distribution(all_M, label):
205
+ """Project all per-sample row vectors to unit sphere (re-normalize),
206
+ then look at distribution of pairwise angles. Uniform distribution gives
207
+ a specific angular density. Clustered gives bimodal angles. Polar / band
208
+ structures give characteristic patterns."""
209
+ print(f"\n[{label}] ANGULAR DISTRIBUTION:")
210
+
211
+ # Pool all rows from all samples, normalize to unit
212
+ all_rows = all_M.reshape(-1, all_M.shape[-1]) # [n_samples * V, D]
213
+ norms = np.linalg.norm(all_rows, axis=1, keepdims=True)
214
+ unit_rows = all_rows / np.clip(norms, 1e-12, None)
215
+
216
+ # Sample subset for pairwise angle computation
217
+ n_subset = min(500, unit_rows.shape[0])
218
+ idx = np.random.RandomState(42).choice(unit_rows.shape[0], n_subset, replace=False)
219
+ subset = unit_rows[idx]
220
+
221
+ # Pairwise dot products β†’ cosines of pairwise angles
222
+ cosines = subset @ subset.T # [n_subset, n_subset]
223
+ triu_idx = np.triu_indices(n_subset, k=1)
224
+ pairwise_cos = cosines[triu_idx]
225
+ pairwise_angles = np.arccos(np.clip(pairwise_cos, -1, 1)) # radians
226
+
227
+ # For uniform distribution on S^(D-1): angle distribution has known shape
228
+ # For D=3 (S^2): density ∝ sin(ΞΈ), peak at ΞΈ=Ο€/2 (90Β°)
229
+ # For D=4 (S^3): density ∝ sinΒ²(ΞΈ), peak at ΞΈ=Ο€/2
230
+
231
+ mean_angle = float(pairwise_angles.mean())
232
+ median_angle = float(np.median(pairwise_angles))
233
+ expected_uniform_mean = math.pi / 2 # for both D=3 and D=4
234
+
235
+ print(f" Pairwise angle stats (radians):")
236
+ print(f" mean: {mean_angle:.3f} (uniform β‰ˆ Ο€/2 = 1.571)")
237
+ print(f" median: {median_angle:.3f}")
238
+ print(f" deviation from uniform mean: {abs(mean_angle - expected_uniform_mean):.3f}")
239
+
240
+ # Concentrated near small angles β†’ clustered into a few directions
241
+ # Concentrated near Ο€/2 β†’ uniform-like
242
+ # Concentrated near small AND large β†’ bipolar / antipodal pairs
243
+
244
+ near_zero = (pairwise_angles < 0.5).sum() / len(pairwise_angles)
245
+ near_pi = (pairwise_angles > math.pi - 0.5).sum() / len(pairwise_angles)
246
+ near_perp = ((pairwise_angles > math.pi / 2 - 0.3) &
247
+ (pairwise_angles < math.pi / 2 + 0.3)).sum() / len(pairwise_angles)
248
+
249
+ print(f" fraction near 0 (parallel): {near_zero:.3f}")
250
+ print(f" fraction near Ο€ (antiparallel): {near_pi:.3f}")
251
+ print(f" fraction near Ο€/2 (perpendicular): {near_perp:.3f}")
252
+
253
+ return {
254
+ 'mean_angle': mean_angle,
255
+ 'median_angle': median_angle,
256
+ 'expected_uniform_mean': expected_uniform_mean,
257
+ 'fraction_near_zero': float(near_zero),
258
+ 'fraction_near_pi': float(near_pi),
259
+ 'fraction_near_perp': float(near_perp),
260
+ 'pairwise_angles_subset': pairwise_angles[:200].tolist(),
261
+ }
262
+
263
+
264
+ # ════════════════════════════════════════════════════════════════════
265
+ # Test 5: Antipodal structure
266
+ # ════════════════════════════════════════════════════════════════════
267
+
268
+ def test_antipodal(all_M, label):
269
+ """Check if each row has a near-antipodal partner. If 32 rows form
270
+ 16 antipodal pairs, that's a different geometric structure than
271
+ 32 independent points."""
272
+ print(f"\n[{label}] ANTIPODAL STRUCTURE:")
273
+
274
+ mean_dirs = all_M.mean(axis=0) # [V, D]
275
+ norms = np.linalg.norm(mean_dirs, axis=1, keepdims=True)
276
+ unit_dirs = mean_dirs / np.clip(norms, 1e-12, None)
277
+
278
+ # For each row, find nearest negative direction
279
+ cosines = unit_dirs @ unit_dirs.T # [V, V]
280
+ np.fill_diagonal(cosines, 1.0) # exclude self
281
+ most_anti_cos = cosines.min(axis=1) # most negative = closest to antipode
282
+
283
+ # If antipodal structure, each row has a partner with cos β‰ˆ -1
284
+ n_antipodal_pairs = (most_anti_cos < -0.9).sum() // 2
285
+
286
+ print(f" Most-antipodal cos for each row:")
287
+ print(f" min: {most_anti_cos.min():.4f}")
288
+ print(f" mean: {most_anti_cos.mean():.4f}")
289
+ print(f" fraction with antipode (cos < -0.9): "
290
+ f"{(most_anti_cos < -0.9).mean():.3f}")
291
+ print(f" Estimated antipodal pairs: {n_antipodal_pairs} / "
292
+ f"{all_M.shape[1]//2} possible")
293
+
294
+ return {
295
+ 'most_antipodal_cosines_min': float(most_anti_cos.min()),
296
+ 'most_antipodal_cosines_mean': float(most_anti_cos.mean()),
297
+ 'fraction_with_antipode': float((most_anti_cos < -0.9).mean()),
298
+ 'estimated_antipodal_pairs': int(n_antipodal_pairs),
299
+ }
300
+
301
+
302
+ # ════════════════════════════════════════════════════════════════════
303
+ # Test 6: Compare to H2a (Rank 02) on the same metrics
304
+ # ════════════════════════════════════════════════════════════════════
305
+
306
+ def comparison_test(all_M_p, all_M_h2):
307
+ """Side-by-side: P-Class (D=3) vs H2a (D=4). What's the actual
308
+ structural difference?"""
309
+ print("\n" + "═" * 70)
310
+ print("DIRECT COMPARISON: P-Class (D=3) vs H2a (D=4)")
311
+ print("═" * 70)
312
+
313
+ # Effective rank comparison
314
+ M_avg_p = all_M_p.mean(axis=0)
315
+ M_avg_h2 = all_M_h2.mean(axis=0)
316
+
317
+ sv_p = np.linalg.svd(M_avg_p, compute_uv=False)
318
+ sv_h2 = np.linalg.svd(M_avg_h2, compute_uv=False)
319
+
320
+ sv_p_norm = sv_p / sv_p.sum()
321
+ sv_h2_norm = sv_h2 / sv_h2.sum()
322
+
323
+ erank_p = math.exp(-(sv_p_norm * np.log(sv_p_norm + 1e-12)).sum())
324
+ erank_h2 = math.exp(-(sv_h2_norm * np.log(sv_h2_norm + 1e-12)).sum())
325
+
326
+ print(f"\n Effective rank of M_avg:")
327
+ print(f" P-Class (D=3): {erank_p:.2f} of {M_avg_p.shape[1]} possible")
328
+ print(f" H2a (D=4): {erank_h2:.2f} of {M_avg_h2.shape[1]} possible")
329
+ print(f" P uses {erank_p/M_avg_p.shape[1]*100:.0f}% of available dims")
330
+ print(f" H2 uses {erank_h2/M_avg_h2.shape[1]*100:.0f}% of available dims")
331
+
332
+ return {
333
+ 'effective_rank_p': float(erank_p),
334
+ 'effective_rank_h2': float(erank_h2),
335
+ 'p_dim_utilization': float(erank_p / M_avg_p.shape[1]),
336
+ 'h2_dim_utilization': float(erank_h2 / M_avg_h2.shape[1]),
337
+ }
338
+
339
+
340
+ # ════════════════════════════════════════════════════════════════════
341
+ # Plotting
342
+ # ════════════════════════════════════════════════════════════════════
343
+
344
+ def plot_diagnostic(all_M_p, all_M_h2, results, output_path):
345
+ fig = plt.figure(figsize=(18, 12))
346
+
347
+ # Panel 1: Per-sample sphere-norm distribution
348
+ ax1 = fig.add_subplot(2, 3, 1)
349
+ p_norms = np.linalg.norm(all_M_p, axis=2).flatten()
350
+ h2_norms = np.linalg.norm(all_M_h2, axis=2).flatten()
351
+ ax1.hist(p_norms, bins=50, alpha=0.5, label='P-Class', color='red')
352
+ ax1.hist(h2_norms, bins=50, alpha=0.5, label='H2a', color='blue')
353
+ ax1.axvline(1.0, color='black', linestyle='--', alpha=0.7,
354
+ label='unit sphere')
355
+ ax1.set_xlabel('Row norm')
356
+ ax1.set_ylabel('Count')
357
+ ax1.set_title('Per-sample row norms\n'
358
+ '(both should be ~1.0 if sphere-normed)')
359
+ ax1.legend()
360
+
361
+ # Panel 2: P-Class β€” 3D scatter of one sample's rows
362
+ ax2 = fig.add_subplot(2, 3, 2, projection='3d')
363
+ sample_p = all_M_p[0] # one sample, [V=32, D=3]
364
+ ax2.scatter(sample_p[:, 0], sample_p[:, 1], sample_p[:, 2],
365
+ c=np.arange(32), cmap='viridis', s=80,
366
+ edgecolors='black', linewidths=0.5)
367
+ # Wireframe sphere for reference
368
+ u = np.linspace(0, 2 * np.pi, 20)
369
+ v = np.linspace(0, np.pi, 20)
370
+ x_s = np.outer(np.cos(u), np.sin(v))
371
+ y_s = np.outer(np.sin(u), np.sin(v))
372
+ z_s = np.outer(np.ones_like(u), np.cos(v))
373
+ ax2.plot_wireframe(x_s, y_s, z_s, alpha=0.1, color='gray')
374
+ ax2.set_title(f'P-Class (D=3) β€” single sample\n32 rows in 3D')
375
+
376
+ # Panel 3: H2a β€” 3D scatter (project D=4 to first 3 dims)
377
+ ax3 = fig.add_subplot(2, 3, 3, projection='3d')
378
+ sample_h2 = all_M_h2[0] # [V=32, D=4]
379
+ ax3.scatter(sample_h2[:, 0], sample_h2[:, 1], sample_h2[:, 2],
380
+ c=np.arange(32), cmap='viridis', s=80,
381
+ edgecolors='black', linewidths=0.5)
382
+ ax3.plot_wireframe(x_s, y_s, z_s, alpha=0.1, color='gray')
383
+ ax3.set_title(f'H2a (D=4) β€” single sample\n32 rows projected to first 3 dims')
384
+
385
+ # Panel 4: Per-sample silhouette stability (P-Class)
386
+ ax4 = fig.add_subplot(2, 3, 4)
387
+ sils_p = results['per_sample_clustering_p']['silhouettes_per_sample']
388
+ sils_h2 = results['per_sample_clustering_h2']['silhouettes_per_sample']
389
+ ax4.boxplot([sils_p, sils_h2], labels=['P-Class', 'H2a'])
390
+ ax4.axhline(0.5, color='red', linestyle='--', alpha=0.5,
391
+ label='strong cluster threshold')
392
+ ax4.set_ylabel(f'Silhouette score (k=5 per-sample)')
393
+ ax4.set_title('Per-sample cluster stability\n'
394
+ '(consistent silhouette = real cluster structure)')
395
+ ax4.legend(fontsize=8)
396
+ ax4.grid(alpha=0.3)
397
+
398
+ # Panel 5: Pairwise angle distribution
399
+ ax5 = fig.add_subplot(2, 3, 5)
400
+ angles_p = results['angular_p']['pairwise_angles_subset']
401
+ angles_h2 = results['angular_h2']['pairwise_angles_subset']
402
+ ax5.hist(angles_p, bins=40, alpha=0.5, label='P-Class', color='red',
403
+ density=True)
404
+ ax5.hist(angles_h2, bins=40, alpha=0.5, label='H2a', color='blue',
405
+ density=True)
406
+ ax5.axvline(math.pi / 2, color='black', linestyle='--', alpha=0.7,
407
+ label='Ο€/2 (uniform peak)')
408
+ ax5.set_xlabel('Pairwise angle (radians)')
409
+ ax5.set_ylabel('Density')
410
+ ax5.set_title('Pairwise angle distribution\n'
411
+ '(uniform sphere peaks at Ο€/2)')
412
+ ax5.legend(fontsize=8)
413
+
414
+ # Panel 6: Per-row stability (mean direction concentration)
415
+ ax6 = fig.add_subplot(2, 3, 6)
416
+ stab_p = results['stability_p']['mean_dir_norms']
417
+ stab_h2 = results['stability_h2']['mean_dir_norms']
418
+ ax6.plot(sorted(stab_p, reverse=True), 'o-', label='P-Class',
419
+ color='red', markersize=5)
420
+ ax6.plot(sorted(stab_h2, reverse=True), 's-', label='H2a',
421
+ color='blue', markersize=5)
422
+ ax6.set_xlabel('Row index (sorted by stability)')
423
+ ax6.set_ylabel('Mean direction norm\n(1.0 = perfectly stable)')
424
+ ax6.set_title('Per-row stability across 512 samples\n'
425
+ '(low = row direction depends on input)')
426
+ ax6.legend()
427
+ ax6.grid(alpha=0.3)
428
+
429
+ plt.tight_layout()
430
+ plt.savefig(output_path, dpi=120, bbox_inches='tight')
431
+ plt.show()
432
+
433
+
434
+ # ════════════════════════════════��═══════════════════════════════════
435
+ # Main
436
+ # ════════════════════════════════════════════════════════════════════
437
+
438
+ def main():
439
+ print("Loading P-rank09 (D=3 candidate)...")
440
+ p_model, p_cfg = load_model('rank09', RANK09_CKPT)
441
+ print(f" V={p_cfg.matrix_v}, D={p_cfg.D}, params="
442
+ f"{sum(p.numel() for p in p_model.parameters()):,}")
443
+
444
+ print("\nLoading Q-rank02 H2a (D=4 reference)...")
445
+ h2_model, h2_cfg = load_model('rank02', RANK02_CKPT)
446
+ print(f" V={h2_cfg.matrix_v}, D={h2_cfg.D}, params="
447
+ f"{sum(p.numel() for p in h2_model.parameters()):,}")
448
+
449
+ print("\nCollecting M rows from gaussian inputs (P-Class)...")
450
+ all_M_p = collect_per_sample_M(p_model, p_cfg)
451
+ print(f" shape: {all_M_p.shape}")
452
+
453
+ print("Collecting M rows from gaussian inputs (H2a)...")
454
+ all_M_h2 = collect_per_sample_M(h2_model, h2_cfg)
455
+ print(f" shape: {all_M_h2.shape}")
456
+
457
+ print("\n" + "═" * 70)
458
+ print("SPHERE-NORM VERIFICATION")
459
+ print("═" * 70)
460
+
461
+ norms_p = test_sphere_norm(all_M_p, "P-Class (D=3)")
462
+ norms_h2 = test_sphere_norm(all_M_h2, "H2a (D=4)")
463
+
464
+ print("\n" + "═" * 70)
465
+ print("ROW STABILITY ACROSS SAMPLES")
466
+ print("═" * 70)
467
+
468
+ stab_p = test_row_stability(all_M_p, "P-Class (D=3)")
469
+ stab_h2 = test_row_stability(all_M_h2, "H2a (D=4)")
470
+
471
+ print("\n" + "═" * 70)
472
+ print("PER-SAMPLE CLUSTERING")
473
+ print("═" * 70)
474
+
475
+ cluster_p = test_per_sample_clustering(all_M_p, k_test=5)
476
+ cluster_h2 = test_per_sample_clustering(all_M_h2, k_test=5)
477
+
478
+ print("\n" + "═" * 70)
479
+ print("ANGULAR DISTRIBUTION")
480
+ print("═" * 70)
481
+
482
+ angular_p = test_angular_distribution(all_M_p, "P-Class (D=3)")
483
+ angular_h2 = test_angular_distribution(all_M_h2, "H2a (D=4)")
484
+
485
+ print("\n" + "═" * 70)
486
+ print("ANTIPODAL STRUCTURE")
487
+ print("═" * 70)
488
+
489
+ antipodal_p = test_antipodal(all_M_p, "P-Class (D=3)")
490
+ antipodal_h2 = test_antipodal(all_M_h2, "H2a (D=4)")
491
+
492
+ comparison = comparison_test(all_M_p, all_M_h2)
493
+
494
+ all_results = {
495
+ 'sphere_norm_p': norms_p,
496
+ 'sphere_norm_h2': norms_h2,
497
+ 'stability_p': stab_p,
498
+ 'stability_h2': stab_h2,
499
+ 'per_sample_clustering_p': cluster_p,
500
+ 'per_sample_clustering_h2': cluster_h2,
501
+ 'angular_p': angular_p,
502
+ 'angular_h2': angular_h2,
503
+ 'antipodal_p': antipodal_p,
504
+ 'antipodal_h2': antipodal_h2,
505
+ 'comparison': comparison,
506
+ }
507
+
508
+ # ════════════════════════════════════════════════════════════════
509
+ # Final interpretation
510
+ # ════════════════════════════════════════════════════════════════
511
+
512
+ print("\n" + "═" * 70)
513
+ print("INTERPRETATION")
514
+ print("═" * 70)
515
+
516
+ p_normed = norms_p['sphere_normed_per_sample']
517
+ h2_normed = norms_h2['sphere_normed_per_sample']
518
+
519
+ print(f"\nSphere-norm per-sample:")
520
+ print(f" P-Class: {'YES' if p_normed else 'NO'} "
521
+ f"(mean norm {norms_p['row_norms_mean']:.3f})")
522
+ print(f" H2a: {'YES' if h2_normed else 'NO'} "
523
+ f"(mean norm {norms_h2['row_norms_mean']:.3f})")
524
+
525
+ print(f"\nPer-sample cluster strength (k=5 silhouette):")
526
+ print(f" P-Class: mean {cluster_p['mean_silhouette']:.3f}, "
527
+ f"std {cluster_p['std_silhouette']:.3f}")
528
+ print(f" H2a: mean {cluster_h2['mean_silhouette']:.3f}, "
529
+ f"std {cluster_h2['std_silhouette']:.3f}")
530
+
531
+ print(f"\nRow direction stability (1.0 = perfectly stable):")
532
+ print(f" P-Class: {stab_p['mean_dir_norms_mean']:.3f}")
533
+ print(f" H2a: {stab_h2['mean_dir_norms_mean']:.3f}")
534
+
535
+ print(f"\nAngular distribution mean (uniform = Ο€/2 β‰ˆ 1.571):")
536
+ print(f" P-Class: {angular_p['mean_angle']:.3f}")
537
+ print(f" H2a: {angular_h2['mean_angle']:.3f}")
538
+
539
+ print(f"\nDimension utilization:")
540
+ print(f" P-Class: {comparison['p_dim_utilization']*100:.0f}% of {p_cfg.D}-D")
541
+ print(f" H2a: {comparison['h2_dim_utilization']*100:.0f}% of {h2_cfg.D}-D")
542
+
543
+ print(f"\nKEY QUESTIONS ANSWERED:")
544
+
545
+ if p_normed and cluster_p['mean_silhouette'] > 0.5:
546
+ print(f" βœ“ P-Class IS clustered per-sample (real structure)")
547
+ elif p_normed and cluster_p['mean_silhouette'] < 0.3:
548
+ print(f" βœ— P-Class clusters were AVERAGING ARTIFACT")
549
+ print(f" Per-sample silhouette only {cluster_p['mean_silhouette']:.3f}")
550
+
551
+ if antipodal_p['fraction_with_antipode'] > 0.5:
552
+ print(f" βœ“ P-Class has antipodal structure "
553
+ f"({antipodal_p['estimated_antipodal_pairs']} pairs)")
554
+
555
+ with open(OUTPUT_JSON, 'w') as f:
556
+ json.dump(all_results, f, indent=2, default=str)
557
+ print(f"\nSaved: {OUTPUT_JSON}")
558
+
559
+ plot_diagnostic(all_M_p, all_M_h2, all_results, OUTPUT_PLOT)
560
+ print(f"Saved: {OUTPUT_PLOT}")
561
+
562
+ return all_results
563
+
564
+
565
+ if __name__ == '__main__':
566
+ results = main()