AbstractPhil commited on
Commit
bbcb9fa
·
verified ·
1 Parent(s): a7dc87a

Create prototype_advanced_cell_v14.py

Browse files
Files changed (1) hide show
  1. prototype_advanced_cell_v14.py +511 -0
prototype_advanced_cell_v14.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SpectralCell
3
+ ============
4
+ Drop-in layer: (B, N, token_dim) → (B, N, token_dim).
5
+
6
+ Pipeline:
7
+ tokens → Linear → residual MLP → Linear(hidden, V*D) → reshape(V, D)
8
+ → capture row magnitudes (encoder confidence)
9
+ → F.normalize(dim=-1) → SVD(Gram-eigh, fp64) → U, S, Vt
10
+ → CM validation → pairwise distances (cm_d2) + simplex volume (cm_vol2)
11
+ → cross-attention scales S per mode across all N tokens
12
+ → recompose M_hat = U · diag(S_modified) · Vt
13
+ → cat(M_hat, cm_d2, row_magnitudes)
14
+ → Linear → residual MLP → Linear(hidden, token_dim) → output
15
+
16
+ SVD is in the forward pass. Differentiable. Gradients flow through
17
+ U, S, Vt back to the input projection weights.
18
+
19
+ Cross-attention modifies S multiplicatively:
20
+ S_out = S * (1 + α * tanh(attention_output))
21
+ α per mode, bounded [0, max_alpha], initialized ~0.024.
22
+ M_hat ≠ M after this step.
23
+
24
+ Sphere normalization enforces ||row||=1 for all V rows.
25
+ This constrains trace(M^T M) = V (fixed total spectral energy).
26
+ The SVD decomposes how that fixed energy distributes across D axes.
27
+
28
+ Cayley-Menger validation on M rows:
29
+ Sample pentachora (5-point subsets) from the V rows on S^{D-1}.
30
+ CM determinant → squared simplex volume.
31
+ CV = std(vol) / mean(vol) over n_samples subsets.
32
+ Measures geometric uniformity of the representation.
33
+
34
+ Author: AbstractPhil + Claude Opus
35
+ License: Apache 2.0
36
+ """
37
+
38
+ import math
39
+ import torch
40
+ import torch.nn as nn
41
+ import torch.nn.functional as F
42
+ from itertools import combinations
43
+
44
+
45
+ # ── Cayley-Menger ───────────────────────────────────────────────
46
+
47
class CMValidator(nn.Module):
    """Batched Cayley-Menger determinant for (k+1)-vertex simplices.

    Given vertex coordinates, returns the squared distance of every vertex
    pair together with the squared k-volume of the simplex they span.
    Any number of leading batch dimensions and any embedding dimension
    are accepted.

    For k=4: 5 vertices → 10 pairwise d² + 1 vol².

    Args:
        k: simplex order; forward() expects k + 1 vertices per simplex.
    """

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1

        # All vertex pairs (i < j). Stored as buffers so they move with the
        # module when it changes device.
        first, second = [], []
        for i, j in combinations(range(self._nv), 2):
            first.append(i)
            second.append(j)
        self._npairs = len(first)
        self.register_buffer('_pi', torch.tensor(first, dtype=torch.long))
        self.register_buffer('_pj', torch.tensor(second, dtype=torch.long))

        # vol² = (-1)^(k+1) / (2^k · (k!)²) · det(CM)
        denom = (2.0 ** k) * (math.factorial(k) ** 2)
        self._prefactor = ((-1.0) ** (k + 1)) / denom

    def forward(self, verts):
        """verts: (..., nv, edim) → d2_pairs: (..., npairs), vol2: (...)"""
        # Squared distances from the Gram matrix: |a|² + |b|² - 2·a·b.
        inner = torch.einsum('...ve,...we->...vw', verts, verts)
        sq_norm = torch.diagonal(inner, dim1=-2, dim2=-1)
        sq_dist = F.relu(sq_norm.unsqueeze(-1) + sq_norm.unsqueeze(-2) - 2 * inner)
        d2_pairs = sq_dist[..., self._pi, self._pj]

        # Bordered matrix: zero corner, ones along the first row and column,
        # squared distances in the remaining block.
        n_pts = sq_dist.shape[-1]
        bordered = torch.zeros(
            *sq_dist.shape[:-2], n_pts + 1, n_pts + 1,
            device=sq_dist.device, dtype=sq_dist.dtype,
        )
        bordered[..., 0, 1:] = 1.0
        bordered[..., 1:, 0] = 1.0
        bordered[..., 1:, 1:] = sq_dist

        # The determinant is taken in float32, then cast back to the dtype
        # of the distance output.
        vol2 = self._prefactor * torch.linalg.det(bordered.float())
        return d2_pairs, vol2.to(d2_pairs.dtype)
82
+
83
+
84
def cayley_menger_vol2(points: torch.Tensor) -> torch.Tensor:
    """Squared simplex volume from the Cayley-Menger determinant, in fp64.

    Args:
        points: (B, N, D) — B simplices, each given by N vertices in R^D.

    Returns:
        vol2: (B,) float64 — squared (N-1)-volume of each simplex.
    """
    batch, n_pts, _ = points.shape
    coords = points.double()

    # Squared pairwise distances via the Gram matrix, clipped at zero to
    # absorb round-off.
    inner = torch.bmm(coords, coords.transpose(1, 2))
    sq_norm = torch.diagonal(inner, dim1=1, dim2=2)
    sq_dist = F.relu(sq_norm.unsqueeze(2) + sq_norm.unsqueeze(1) - 2 * inner)

    # Bordered distance matrix: zero corner, ones on the border.
    bordered = torch.zeros(
        batch, n_pts + 1, n_pts + 1,
        device=points.device, dtype=torch.float64,
    )
    bordered[:, 0, 1:] = 1.0
    bordered[:, 1:, 0] = 1.0
    bordered[:, 1:, 1:] = sq_dist

    # vol² = (-1)^(k+1) · det(CM) / (2^k · (k!)²), with k = N - 1.
    k = n_pts - 1
    sign = (-1.0) ** (k + 1)
    denom = (2 ** k) * (math.factorial(k) ** 2)
    return sign * torch.linalg.det(bordered) / denom
101
+
102
+
103
def cv_of(emb: torch.Tensor, n_samples: int = 200) -> float:
    """Coefficient of variation of random pentachoron volumes.

    Draws ``n_samples`` random 5-row subsets from the first
    min(rows, 512) rows of ``emb``, computes the Cayley-Menger vol² of
    each, and returns std(vol) / mean(vol) over the subsets whose vol²
    exceeds 1e-20.

    CV ≈ 0.20-0.23 is the empirically observed attractor band.

    Args:
        emb: (V, D) — rows of a sphere-normalized matrix.
        n_samples: number of random 5-point subsets to draw.

    Returns:
        The coefficient of variation, or 0.0 when ``emb`` is not 2-D, has
        fewer than 5 rows, or yields fewer than 10 valid volumes.
    """
    if emb.dim() != 2 or emb.shape[0] < 5:
        return 0.0

    pool = min(emb.shape[0], 512)

    # One independent permutation per sample; its first five entries are
    # five distinct row indices.
    draws = []
    for _ in range(n_samples):
        draws.append(torch.randperm(pool, device=emb.device)[:5])
    subsets = emb[:pool][torch.stack(draws)]

    vol2 = cayley_menger_vol2(subsets)
    keep = vol2 > 1e-20
    if keep.sum() < 10:
        return 0.0

    vols = vol2[keep].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()
126
+
127
+
128
+ # ── SVD via Gram-eigh (fp64 exact) ──────────────────────────────
129
+
130
def gram_eigh_svd(A: torch.Tensor):
    """Thin SVD obtained from an fp64 eigendecomposition of the Gram matrix.

    Forms G = A^T A in float64, adds 1e-12 to its diagonal, runs a
    symmetric eigendecomposition, and derives the singular triplet from
    the eigenpairs. Results are cast back to the dtype of ``A``.

    Args:
        A: (B, V, D) with V >= D

    Returns:
        U:  (B, V, D) left singular vectors
        S:  (B, D)    singular values, descending
        Vh: (B, D, D) right singular vectors, transposed
    """
    _b, _v, _d = A.shape  # unpacking also rejects inputs that are not 3-D
    in_dtype = A.dtype

    # Autocast is disabled for this region so the float64 math is kept.
    with torch.amp.autocast('cuda', enabled=False):
        A64 = A.double()
        gram = torch.bmm(A64.transpose(1, 2), A64)
        gram.diagonal(dim1=-2, dim2=-1).add_(1e-12)

        evals, evecs = torch.linalg.eigh(gram)
        # Reverse the eigenpair order so the spectrum is descending.
        evals = torch.flip(evals, dims=(-1,))
        evecs = torch.flip(evecs, dims=(-1,))

        # sigma_i = sqrt(lambda_i); both clamps guard against division by ~0.
        sigma = evals.clamp(min=1e-24).sqrt()
        left = torch.bmm(A64, evecs) / sigma.unsqueeze(1).clamp(min=1e-16)
        right_t = evecs.transpose(-2, -1).contiguous()

    return left.to(in_dtype), sigma.to(in_dtype), right_t.to(in_dtype)
157
+
158
+
159
+ # ── Spectral Cross-Attention ────────────────────────────────────
160
+
161
class SpectralCrossAttention(nn.Module):
    """Multi-head self-attention over per-token singular-value profiles.

    Input S: (B, N, D) — one D-dim spectral profile per token.
    Every token attends to the profiles of all N tokens, and the attended
    result gates the input multiplicatively:

        output = S * (1 + α * tanh(out_proj(attended)))

    α is per-mode and lies in [0, max_alpha]: it is max_alpha times the
    sigmoid of a learnable logit. With the defaults it starts at
    sigmoid(-2.0) * 0.2 ≈ 0.024 per mode.

    Args:
        D: number of spectral modes (must be divisible by n_heads).
        n_heads: attention heads.
        max_alpha: upper bound on the per-mode gate strength.
        alpha_init: initial value of every α logit.
    """

    def __init__(self, D, n_heads=2, max_alpha=0.2, alpha_init=-2.0):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = D // n_heads
        self.max_alpha = max_alpha
        assert D % n_heads == 0

        self.qkv = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.scale = self.head_dim ** -0.5
        self.alpha_logits = nn.Parameter(torch.full((D,), alpha_init))

    @property
    def alpha(self):
        """Per-mode gate strength, shape (D,), bounded to [0, max_alpha]."""
        return self.max_alpha * torch.sigmoid(self.alpha_logits)

    def forward(self, S):
        """S: (B, N, D) → gated S of the same shape."""
        batch, n_tok, width = S.shape

        # Project the normalized profiles to q/k/v and split into heads:
        # (B, N, 3, H, hd) → (3, B, H, N, hd).
        packed = self.qkv(self.norm(S))
        packed = packed.reshape(batch, n_tok, 3, self.n_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product attention across the N token positions.
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)
        mixed = (weights @ v).transpose(1, 2).reshape(batch, n_tok, width)

        # tanh keeps the gate in [-1, 1], so each mode is rescaled by a
        # factor within [1 - α, 1 + α].
        gate = torch.tanh(self.out_proj(mixed))
        strength = self.alpha.unsqueeze(0).unsqueeze(0)
        return S * (1.0 + strength * gate)
200
+
201
+
202
+ # ── SpectralCell ────────────────────────────────────────────────
203
+
204
class SpectralCell(nn.Module):
    """Token layer built around a sphere-normalized SVD, spectral
    cross-attention, and Cayley-Menger geometric features.

    Maps (B, N, token_dim) → (B, N, token_dim).

    Shapes through the pipeline (defaults V=16, D=4, hidden=128, shown for
    token_dim=64):
        tokens:      (B, N, 64) → flatten (B*N, 64)
        enc_in:      Linear(64, 128) + GELU → (B*N, 128)
        enc_blocks:  depth × residual MLP → (B*N, 128)
        enc_out:     Linear(128, V*D) → reshape (B*N, 16, 4)
        row_mag:     row norms before normalization → (B*N, 16)
        normalize:   F.normalize(dim=-1) over each row
        CM:          k+1 evenly spaced rows → cm_d2 (B*N, npairs), cm_vol2 (B*N,)
        SVD:         gram_eigh_svd → U (B*N,16,4), S (B*N,4), Vt (B*N,4,4)
        cross_attn:  S viewed as (B, N, 4) → attention across N → (B, N, 4)
        recompose:   M_hat = U · diag(S) · Vt → flatten (B*N, 64)
        out_in:      Linear(V*D + npairs + V, 128) + GELU on
                     cat(M_hat, cm_d2, row_mag)
        out_blocks:  depth × residual MLP → (B*N, 128)
        out_proj:    Linear(128, 64) → (B, N, 64)

    CM validation:
        The simplex order is k = min(4, D - 1) (k = 1 when D < 2), so the
        default D=4 uses 4 vertices / 6 pairs and D >= 5 uses pentachora
        (5 vertices / 10 pairs). format() feeds k+1 fixed, evenly spaced
        rows of M to CMValidator; cm_cv() instead averages cv_of() over
        random 5-row subsets.

    Args:
        token_dim: input and output dimension per token
        V: matrix rows (each is normalized to a unit vector in R^D)
        D: matrix columns (spectral modes, singular value count)
        hidden: residual MLP width
        depth: residual blocks in input and output projections
        n_cross: SpectralCrossAttention layers applied to S
        n_heads: attention heads in cross-attention (must divide D)
        max_alpha: upper bound on per-mode multiplicative scaling
    """

    def __init__(
        self,
        token_dim: int,
        V: int = 16,
        D: int = 4,
        hidden: int = 128,
        depth: int = 2,
        n_cross: int = 1,
        n_heads: int = 2,
        max_alpha: float = 0.2,
    ):
        super().__init__()
        self.token_dim = token_dim
        self.V = V
        self.D = D
        self.mat_dim = V * D
        self.hidden = hidden

        # CM validator order: k = min(4, D - 1), falling back to k = 1 for
        # D < 2. A pentachoron (k = 4, 5 vertices) is used only when D >= 5.
        self._cm_k = min(4, D - 1) if D >= 2 else 1
        self.cm = CMValidator(self._cm_k)

        # Input projection: token_dim → hidden → mat_dim
        self.enc_in = nn.Linear(token_dim, hidden)
        self.enc_blocks = self._residual_blocks(hidden, depth)
        self.enc_out = nn.Linear(hidden, self.mat_dim)
        nn.init.orthogonal_(self.enc_out.weight)

        # Cross-attention on singular values across tokens
        self.cross_attn = nn.ModuleList([
            SpectralCrossAttention(D, n_heads=n_heads, max_alpha=max_alpha)
            for _ in range(n_cross)
        ])

        # Output projection: mat_dim + cm_d2 + magnitudes → hidden → token_dim
        #   cm_d2:   pairwise squared distances between the sampled M rows
        #   row_mag: pre-normalization row magnitudes (encoder confidence)
        self._cm_npairs = self.cm._npairs
        self.out_in = nn.Linear(self.mat_dim + self._cm_npairs + self.V, hidden)
        self.out_blocks = self._residual_blocks(hidden, depth)
        self.out_proj = nn.Linear(hidden, token_dim)

    @staticmethod
    def _residual_blocks(hidden: int, depth: int) -> nn.ModuleList:
        """Build ``depth`` LayerNorm → Linear → GELU → Linear branches."""
        return nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden),
                nn.Linear(hidden, hidden),
                nn.GELU(),
                nn.Linear(hidden, hidden),
            ) for _ in range(depth)
        ])

    def format(self, tokens: torch.Tensor) -> dict:
        """Run full pipeline. Returns output tokens, SVD components, and CM metrics.

        Args:
            tokens: (B, N, token_dim)

        Returns:
            dict:
                output:  (B, N, token_dim) — processed tokens
                M:       (B, N, V, D) — row-normalized matrix
                U:       (B, N, V, D) — left singular vectors from SVD
                S_orig:  (B, N, D) — singular values before cross-attention
                S:       (B, N, D) — singular values after cross-attention
                Vt:      (B, N, D, D) — right singular vectors from SVD
                M_hat:   (B, N, V, D) — U · diag(S_modified) · Vt (≠ M)
                cm_d2:   (B*N, npairs) — pairwise squared distances from CM
                cm_vol2: (B*N,) — squared simplex volume from CM
                row_mag: (B, N, V) — pre-normalization row magnitudes
        """
        B, N, _ = tokens.shape

        # Input projection → row-normalized V×D matrix
        flat = tokens.reshape(B * N, -1)
        h = F.gelu(self.enc_in(flat))
        for block in self.enc_blocks:
            h = h + block(h)
        M = self.enc_out(h).reshape(B * N, self.V, self.D)
        row_mag = M.norm(dim=-1)  # (B*N, V) — encoder confidence per row
        M = F.normalize(M, dim=-1)

        # CM validation on k+1 rows per token. The indices are fixed and
        # evenly spaced so the CM features are deterministic; linspace
        # values are truncated (not rounded) to integers.
        nv = self._cm_k + 1
        cm_idx = torch.linspace(0, self.V - 1, nv).long().to(M.device)
        cm_verts = M[:, cm_idx, :]  # (B*N, nv, D)
        cm_d2, cm_vol2 = self.cm(cm_verts)

        # SVD decomposition (part of the autograd graph)
        U, S, Vt = gram_eigh_svd(M)

        # Reshape for cross-attention over N tokens
        U = U.reshape(B, N, self.V, self.D)
        S = S.reshape(B, N, self.D)
        Vt = Vt.reshape(B, N, self.D, self.D)
        M = M.reshape(B, N, self.V, self.D)

        # Cross-attention multiplicatively scales S across tokens
        S_orig = S.clone()
        for layer in self.cross_attn:
            S = layer(S)

        # Recompose with modified S → M_hat ≠ M
        U_flat = U.reshape(B * N, self.V, self.D)
        S_flat = S.reshape(B * N, self.D)
        Vt_flat = Vt.reshape(B * N, self.D, self.D)
        M_hat = torch.bmm(U_flat * S_flat.unsqueeze(1), Vt_flat)

        # Output projection: M_hat + cm_d2 + magnitudes → token_dim
        out_features = torch.cat([
            M_hat.reshape(B * N, -1),  # (B*N, V*D) — recomposed spectral structure
            cm_d2,                     # (B*N, npairs) — geometric arrangement
            row_mag,                   # (B*N, V) — encoder confidence
        ], dim=-1)
        h = F.gelu(self.out_in(out_features))
        for block in self.out_blocks:
            h = h + block(h)
        output = self.out_proj(h).reshape(B, N, self.token_dim)

        return {
            'output': output,
            'M': M,
            'U': U,
            'S_orig': S_orig,
            'S': S,
            'Vt': Vt,
            'M_hat': M_hat.reshape(B, N, self.V, self.D),
            'cm_d2': cm_d2,
            'cm_vol2': cm_vol2,
            'row_mag': row_mag.reshape(B, N, self.V),
        }

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """(B, N, token_dim) → (B, N, token_dim). Drop-in compatible."""
        return self.format(tokens)['output']

    # ── CM Diagnostics ───────────────────────────────────────────

    def cm_cv(self, M: torch.Tensor, n_samples: int = 200) -> float:
        """Mean CV of pentachoron volumes over random 5-point subsets.

        Args:
            M: (B, N, V, D) — row-normalized matrices.
            n_samples: random subsets drawn per matrix.

        Returns:
            Mean of cv_of() over the first min(B*N, 64) matrices, or 0.0
            when there are none.
        """
        flat = M.reshape(-1, self.V, self.D)
        # Only a bounded number of matrices is inspected to keep cost reasonable.
        n_mats = min(flat.shape[0], 64)
        cvs = [cv_of(flat[i], n_samples=n_samples) for i in range(n_mats)]
        return sum(cvs) / len(cvs) if cvs else 0.0

    def cm_vol2_stats(self, cm_vol2: torch.Tensor) -> dict:
        """Statistics on CM vol² from format() output.

        Args:
            cm_vol2: (B*N,) — one vol² per token's sampled simplex.

        Returns:
            dict with keys 'mean', 'std', 'cv', 'frac_valid', computed on
            sqrt(|vol²|) over entries with |vol²| > 1e-20. All four values
            are 0.0 when fewer than two entries are valid.
        """
        valid = cm_vol2.abs() > 1e-20
        if valid.sum() < 2:
            # Same keys as the normal path so callers can index 'cv' directly.
            return {'mean': 0.0, 'std': 0.0, 'cv': 0.0, 'frac_valid': 0.0}
        vols = cm_vol2[valid].abs().sqrt()
        return {
            'mean': vols.mean().item(),
            'std': vols.std().item(),
            'cv': (vols.std() / (vols.mean() + 1e-8)).item(),
            'frac_valid': valid.float().mean().item(),
        }

    # ── SVD Diagnostics ──────────────────────────────────────────

    @staticmethod
    def effective_rank(S: torch.Tensor) -> torch.Tensor:
        """Shannon entropy of normalized singular values, exponentiated.

        erank = exp(-Σ p_i log p_i) where p_i = σ_i / Σσ (p clamped at 1e-8).
        Approaches 1.0 for a rank-1 spectrum and D for a uniform spectrum.
        """
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)
        return (-(p * p.log()).sum(-1)).exp()

    @staticmethod
    def spectral_shift(S_orig, S_coord):
        """Mean |S_coord - S_orig| across all modes and tokens."""
        return (S_coord - S_orig).abs().mean().item()

    @staticmethod
    def trace_check(M):
        """Mean trace(M^T M); equals V when every row of M has unit norm."""
        flat = M.reshape(-1, M.shape[-2], M.shape[-1])
        G = torch.bmm(flat.transpose(1, 2), flat)
        return torch.diagonal(G, dim1=-2, dim2=-1).sum(-1).mean().item()

    def summary(self):
        """Print shapes, param count, DOF ratio, CM config."""
        n_params = sum(p.numel() for p in self.parameters())
        sphere_dof = self.V * (self.D - 1)
        ratio = sphere_dof / self.token_dim
        print("SpectralCell:")
        print(f" token_dim={self.token_dim}, V={self.V}, D={self.D}")
        print(f" mat_dim={self.mat_dim} ({self.V}×{self.D})")
        print(f" sphere DOF={sphere_dof} (V rows × {self.D-1} free per row)")
        print(f" CM: k={self._cm_k} ({self._cm_k+1} vertices, {self._cm_npairs} pairs)")
        print(f" out_in: {self.mat_dim} (M_hat) + {self._cm_npairs} (cm_d2) + {self.V} (mag) = {self.mat_dim + self._cm_npairs + self.V}")
        print(f" hidden={self.hidden}, depth={len(self.enc_blocks)}")
        print(f" cross_attn={len(self.cross_attn)} layers")
        print(f" params: {n_params:,}")
        print(f" DOF ratio: {ratio:.2f}× "
              f"({'expand' if ratio > 1 else 'compress' if ratio < 1 else 'identity'})")
450
+
451
+
452
+ # ── Factory functions ────────────────────────────────────────────
453
+
454
def spectral_cell_tiny(token_dim: int) -> SpectralCell:
    """Smallest preset: V=8, D=4, hidden=64, depth=1, 1 cross-attn layer."""
    preset = dict(V=8, D=4, hidden=64, depth=1, n_cross=1)
    return SpectralCell(token_dim, **preset)
457
+
458
def spectral_cell_small(token_dim: int) -> SpectralCell:
    """Small preset: V=16, D=4, hidden=128, depth=2, 1 cross-attn layer."""
    preset = dict(V=16, D=4, hidden=128, depth=2, n_cross=1)
    return SpectralCell(token_dim, **preset)
461
+
462
def spectral_cell_base(token_dim: int) -> SpectralCell:
    """Base preset: V=16, D=8, hidden=256, depth=2, 2 cross-attn layers, 4 heads."""
    preset = dict(V=16, D=8, hidden=256, depth=2, n_cross=2, n_heads=4)
    return SpectralCell(token_dim, **preset)
465
+
466
def spectral_cell_diamond(token_dim: int) -> SpectralCell:
    """Square preset: V=16, D=16, hidden=256, depth=2, 1 cross-attn layer, 4 heads.

    Best sweep config.
    """
    preset = dict(V=16, D=16, hidden=256, depth=2, n_cross=1, n_heads=4)
    return SpectralCell(token_dim, **preset)
469
+
470
+
471
+ # ── Self-test ───────────────────────────────────────────────────
472
+
473
if __name__ == '__main__':
    # Smoke test: build several presets, run one forward pass on random
    # tokens, print shapes and diagnostics, and confirm every parameter
    # receives a non-zero gradient.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    for _name, factory in [('tiny', spectral_cell_tiny),
                           ('small', spectral_cell_small),
                           ('diamond', spectral_cell_diamond)]:
        print(f"\n{'='*50}")
        cell = factory(token_dim=192).to(device)
        cell.summary()

        tokens = torch.randn(2, 16, 192, device=device)
        result = cell.format(tokens)

        print(f"\n Input: {tokens.shape}")
        print(f" Output: {result['output'].shape}")
        print(f" M: {result['M'].shape}")
        print(f" S: {result['S'].shape}")
        print(f" cm_d2: {result['cm_d2'].shape}")
        print(f" cm_vol2: {result['cm_vol2'].shape}")
        print(f" trace: {cell.trace_check(result['M']):.4f} (expect {cell.V})")
        print(f" erank: {cell.effective_rank(result['S_orig'].reshape(-1, cell.D)).mean():.2f}")
        print(f" shift: {cell.spectral_shift(result['S_orig'], result['S']):.6f}")

        # CM stats on the per-token simplex sampled inside format().
        # .get() tolerates a stats dict that has no 'cv' entry.
        cm_stats = cell.cm_vol2_stats(result['cm_vol2'])
        print(f" cm_vol: mean={cm_stats['mean']:.6f} cv={cm_stats.get('cv', 0):.4f} "
              f"valid={cm_stats['frac_valid']:.1%}")

        # Full CV (slower): 100 random 5-point subsets per matrix.
        with torch.no_grad():
            cv = cell.cm_cv(result['M'], n_samples=100)
        print(f" cm_cv: {cv:.4f}")

        # Gradient check: every trainable parameter must get a non-zero grad.
        loss = result['output'].sum()
        loss.backward()
        grad_ok = all(p.grad is not None and p.grad.abs().sum() > 0
                      for p in cell.parameters() if p.requires_grad)
        print(f" grads: {'✓' if grad_ok else '✗'}")