Create prototype_advanced_cell_v14.py
Browse files- prototype_advanced_cell_v14.py +511 -0
prototype_advanced_cell_v14.py
ADDED
|
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SpectralCell
|
| 3 |
+
============
|
| 4 |
+
Drop-in layer: (B, N, token_dim) → (B, N, token_dim).
|
| 5 |
+
|
| 6 |
+
Pipeline:
|
| 7 |
+
tokens → Linear → residual MLP → Linear(hidden, V*D) → reshape(V, D)
|
| 8 |
+
→ capture row magnitudes (encoder confidence)
|
| 9 |
+
→ F.normalize(dim=-1) → SVD(Gram-eigh, fp64) → U, S, Vt
|
| 10 |
+
→ CM validation → pairwise distances (cm_d2) + simplex volume (cm_vol2)
|
| 11 |
+
→ cross-attention scales S per mode across all N tokens
|
| 12 |
+
→ recompose M_hat = U · diag(S_modified) · Vt
|
| 13 |
+
→ cat(M_hat, cm_d2, row_magnitudes)
|
| 14 |
+
→ Linear → residual MLP → Linear(hidden, token_dim) → output
|
| 15 |
+
|
| 16 |
+
SVD is in the forward pass. Differentiable. Gradients flow through
|
| 17 |
+
U, S, Vt back to the input projection weights.
|
| 18 |
+
|
| 19 |
+
Cross-attention modifies S multiplicatively:
|
| 20 |
+
S_out = S * (1 + α * tanh(attention_output))
|
| 21 |
+
α per mode, bounded [0, max_alpha], initialized ~0.024.
|
| 22 |
+
M_hat deviates from M after this step (only slightly while α is small).
|
| 23 |
+
|
| 24 |
+
Sphere normalization enforces ||row||=1 for all V rows.
|
| 25 |
+
This constrains trace(M^T M) = V (fixed total spectral energy).
|
| 26 |
+
The SVD decomposes how that fixed energy distributes across D axes.
|
| 27 |
+
|
| 28 |
+
Cayley-Menger validation on M rows:
|
| 29 |
+
Sample pentachora (5-point subsets) from the V rows on S^{D-1}.
|
| 30 |
+
CM determinant → squared simplex volume.
|
| 31 |
+
CV = std(vol) / mean(vol) over n_samples subsets.
|
| 32 |
+
Measures geometric uniformity of the representation.
|
| 33 |
+
|
| 34 |
+
Author: AbstractPhil + Claude Opus
|
| 35 |
+
License: Apache 2.0
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
import math
|
| 39 |
+
import torch
|
| 40 |
+
import torch.nn as nn
|
| 41 |
+
import torch.nn.functional as F
|
| 42 |
+
from itertools import combinations
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ── Cayley-Menger ───────────────────────────────────────────────
|
| 46 |
+
|
| 47 |
+
class CMValidator(nn.Module):
    """Batch-friendly Cayley-Menger determinant for k-simplices.

    Given ``k + 1`` vertices in an arbitrary embedding dimension, computes
    the pairwise squared distances and the squared simplex volume

        vol^2 = (-1)^(k+1) / (2^k * (k!)^2) * det(CM)

    where CM is the squared-distance matrix bordered by a row and a column
    of ones (with a zero in the corner).

    For k=4: 5 vertices -> 10 pairwise d^2 values + 1 vol^2.

    Args:
        k: simplex dimension; ``forward`` expects ``k + 1`` vertices.
    """

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1
        pairs = list(combinations(range(self._nv), 2))
        self._npairs = len(pairs)
        # Index buffers for the strict upper triangle (i < j) of the distance
        # matrix; registered as buffers so they move with the module.
        self.register_buffer('_pi', torch.tensor([i for i, _ in pairs], dtype=torch.long))
        self.register_buffer('_pj', torch.tensor([j for _, j in pairs], dtype=torch.long))
        sign = (-1.0) ** (k + 1)
        fact = math.factorial(k)
        self._prefactor = sign / ((2.0 ** k) * (fact ** 2))

    def forward(self, verts):
        """Compute pairwise squared distances and squared simplex volume.

        Args:
            verts: (..., nv, edim) vertex coordinates.

        Returns:
            d2_pairs: (..., npairs) squared distances for every pair i < j.
            vol2: (...) squared volume, returned in the dtype of ``d2_pairs``.
        """
        gram = torch.einsum('...ve,...we->...vw', verts, verts)
        norms = torch.diagonal(gram, dim1=-2, dim2=-1)
        # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b ; relu clips tiny negative round-off.
        d2_mat = F.relu(norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram)
        d2_pairs = d2_mat[..., self._pi, self._pj]

        batch_shape = d2_mat.shape[:-2]
        n_verts = d2_mat.shape[-1]
        cm = torch.zeros(*batch_shape, n_verts + 1, n_verts + 1,
                         device=d2_mat.device, dtype=d2_mat.dtype)
        cm[..., 0, 1:] = 1.0
        cm[..., 1:, 0] = 1.0
        cm[..., 1:, 1:] = d2_mat

        # Low-precision inputs are promoted to fp32 for the determinant, as
        # before. fp64 inputs are no longer truncated to fp32: the determinant
        # is a difference of O(1) terms and loses small volumes in fp32.
        det_dtype = torch.float64 if cm.dtype == torch.float64 else torch.float32
        vol2 = self._prefactor * torch.linalg.det(cm.to(det_dtype))
        return d2_pairs, vol2.to(d2_pairs.dtype)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def cayley_menger_vol2(points: torch.Tensor) -> torch.Tensor:
    """Squared simplex volume from the Cayley-Menger determinant, in fp64.

    Args:
        points: (B, N, D) batch of N-vertex simplices.

    Returns:
        (B,) float64 tensor holding vol^2 of each (N-1)-simplex.
    """
    _, n_verts, _ = points.shape
    pts64 = points.double()

    inner = torch.bmm(pts64, pts64.transpose(1, 2))
    sq_norm = torch.diagonal(inner, dim1=1, dim2=2)
    # Squared distances from the Gram matrix; relu removes negative round-off.
    sq_dist = F.relu(sq_norm.unsqueeze(2) + sq_norm.unsqueeze(1) - 2 * inner)

    # Border the distance matrix with ones on top/left, then zero the corner.
    bordered = F.pad(sq_dist, (1, 0, 1, 0), value=1.0)
    bordered[:, 0, 0] = 0.0

    order = n_verts - 1
    sign = (-1.0) ** (order + 1)
    order_fact = math.factorial(order)
    denom = (2 ** order) * (order_fact ** 2)
    return sign * torch.linalg.det(bordered) / denom
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def cv_of(emb: torch.Tensor, n_samples: int = 200) -> float:
    """Coefficient of variation of pentachoron volumes.

    Draws ``n_samples`` random 5-row subsets from the first ``min(V, 512)``
    rows of ``emb``, computes the Cayley-Menger vol^2 of each, keeps the
    subsets with vol^2 > 1e-20, and returns std(vol) / mean(vol).

    The original author notes CV of roughly 0.20-0.23 as the empirically
    observed attractor band.

    Args:
        emb: (V, D) matrix whose rows are the points (intended to be
            sphere-normalized).
        n_samples: number of random 5-point subsets to draw.

    Returns:
        The CV as a Python float, or 0.0 when the input is not 2-D, has fewer
        than 5 rows, ``n_samples`` is not positive, or fewer than 10 sampled
        volumes are valid.
    """
    # n_samples < 1 previously crashed in torch.stack on an empty list;
    # treat it like every other "not enough data" case.
    if emb.dim() != 2 or emb.shape[0] < 5 or n_samples < 1:
        return 0.0

    pool = min(emb.shape[0], 512)
    indices = torch.stack([
        torch.randperm(pool, device=emb.device)[:5]
        for _ in range(n_samples)
    ])  # (n_samples, 5)
    vol2 = cayley_menger_vol2(emb[:pool][indices])

    valid = vol2 > 1e-20
    if valid.sum() < 10:
        return 0.0
    vols = vol2[valid].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ── SVD via Gram-eigh (fp64 exact) ──────────────────────────────
|
| 129 |
+
|
| 130 |
+
def gram_eigh_svd(A: torch.Tensor):
    """Thin SVD obtained from the eigendecomposition of A^T A in fp64.

    Forms G = A^T A in float64, adds 1e-12 to its diagonal, runs ``eigh``,
    and reorders the result so singular values are descending.

    Args:
        A: (B, V, D) with V >= D

    Returns:
        U:  (B, V, D) left singular vectors
        S:  (B, D)    singular values, descending
        Vh: (B, D, D) right singular vectors, transposed

        All three are cast back to the dtype of ``A``.
    """
    # Unpacking doubles as a rank check: A must be exactly 3-D.
    _batch, _rows, _cols = A.shape
    in_dtype = A.dtype

    with torch.amp.autocast('cuda', enabled=False):
        A64 = A.double()
        gram = torch.bmm(A64.transpose(1, 2), A64)
        diag_view = gram.diagonal(dim1=-2, dim2=-1)
        diag_view += 1e-12  # in-place diagonal perturbation for stability

        evals, evecs = torch.linalg.eigh(gram)
        # eigh returns ascending eigenvalues; reverse to descending order.
        evals = torch.flip(evals, dims=[-1])
        evecs = torch.flip(evecs, dims=[-1])

        S = evals.clamp(min=1e-24).sqrt()
        # U = A V diag(1/S); each column of A V is divided by its S.
        U = torch.bmm(A64, evecs) / S[:, None, :].clamp(min=1e-16)
        Vh = evecs.transpose(-2, -1).contiguous()

    return U.to(in_dtype), S.to(in_dtype), Vh.to(in_dtype)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ── Spectral Cross-Attention ────────────────────────────────────
|
| 160 |
+
|
| 161 |
+
class SpectralCrossAttention(nn.Module):
    """Multi-head attention on singular values across N tokens.

    Input S: (B, N, D) - one D-dim spectral profile per token.
    Attends across the N positions (each token sees all others' spectra).
    Output: S * (1 + alpha * tanh(out_proj(attended)))

    alpha is per mode and bounded to [0, max_alpha] by a sigmoid over
    learnable logits. With the defaults it starts at
    sigmoid(-2.0) * 0.2, about 0.024 per mode.

    Args:
        D: number of spectral modes (feature width of S).
        n_heads: attention heads; must divide D.
        max_alpha: upper bound on the per-mode multiplicative gate.
        alpha_init: initial value of every alpha logit.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``D``.
    """

    def __init__(self, D, n_heads=2, max_alpha=0.2, alpha_init=-2.0):
        super().__init__()
        # Validate before deriving head_dim; a plain `assert` would be
        # stripped under `python -O` and fail later inside a reshape.
        if n_heads <= 0 or D % n_heads != 0:
            raise ValueError(
                f"n_heads ({n_heads}) must be positive and divide D ({D})"
            )
        self.n_heads = n_heads
        self.head_dim = D // n_heads
        self.max_alpha = max_alpha

        self.qkv = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.scale = self.head_dim ** -0.5
        self.alpha_logits = nn.Parameter(torch.full((D,), alpha_init))

    @property
    def alpha(self):
        """Per-mode gate strength, a (D,) tensor in [0, max_alpha]."""
        return self.max_alpha * torch.sigmoid(self.alpha_logits)

    def forward(self, S):
        """Rescale singular values using attention over the token axis.

        Args:
            S: (B, N, D) singular values per token.

        Returns:
            (B, N, D) tensor ``S * (1 + alpha * gate)`` with gate in (-1, 1).
        """
        B, N, D = S.shape
        Sn = self.norm(S)
        # (B, N, 3D) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(Sn).reshape(B, N, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        # Merge heads back: (B, heads, N, head_dim) -> (B, N, D)
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        gate = torch.tanh(self.out_proj(out))
        alpha = self.alpha.unsqueeze(0).unsqueeze(0)
        return S * (1.0 + alpha * gate)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# ── SpectralCell ────────────────────────────────────────────────
|
| 203 |
+
|
| 204 |
+
class SpectralCell(nn.Module):
    """Processes N tokens through sphere-normalized SVD with spectral
    coordination and Cayley-Menger geometric validation.

    Shapes through the pipeline (for default V=16, D=4, hidden=128, token_dim=64):
        tokens:      (B, N, 64)
        enc_in:      Linear(64, 128)            -> (B*N, 128)
        enc_blocks:  2x residual MLP            -> (B*N, 128)
        enc_out:     Linear(128, 64)            -> (B*N, 64) -> reshape (B*N, 16, 4)
        normalize:   F.normalize(dim=-1)        -> each row has norm 1
        CM:          4 evenly spaced rows       -> cm_d2 (B*N, 6), cm_vol2 (B*N,)
        SVD:         Gram-eigh in fp64          -> U(B*N,16,4), S(B*N,4), Vt(B*N,4,4)
        cross_attn:  S reshaped (B,N,4)         -> attention across N -> S (B,N,4)
        recompose:   U . diag(S) . Vt           -> M_hat (B*N, 16, 4) -> flatten (B*N, 64)
        concat:      M_hat (64) + cm_d2 (6) + row_mag (16) -> (B*N, 86)
        out_in:      Linear(86, 128)            -> (B*N, 128)
        out_blocks:  2x residual MLP            -> (B*N, 128)
        out_proj:    Linear(128, 64)            -> (B, N, 64)

    CM validation:
        M rows are V unit vectors on S^{D-1}.
        ``format`` feeds ``k + 1`` fixed, evenly spaced rows to
        CMValidator(k) with ``k = min(4, D - 1)``; pentachora (k=4) are
        therefore used only when D >= 5, and the default D=4 uses k=3.
        ``cv_of`` / ``cm_cv`` always sample random 5-point subsets and
        report the coefficient of variation of their volumes.

    Args:
        token_dim: input and output dimension per token
        V: matrix rows (each becomes a unit vector on S^{D-1})
        D: matrix columns (spectral modes, eigenvalue count)
        hidden: residual MLP width
        depth: residual blocks in input and output projections
        n_cross: SpectralCrossAttention layers applied to S
        n_heads: attention heads in cross-attention (must divide D)
        max_alpha: upper bound on per-mode multiplicative scaling
    """

    def __init__(
        self,
        token_dim: int,
        V: int = 16,
        D: int = 4,
        hidden: int = 128,
        depth: int = 2,
        n_cross: int = 1,
        n_heads: int = 2,
        max_alpha: float = 0.2,
    ):
        super().__init__()
        self.token_dim = token_dim
        self.V = V
        self.D = D
        self.mat_dim = V * D
        self.hidden = hidden

        # CM validator order: k = min(4, D - 1), falling back to 1 for D < 2.
        # NOTE(review): at D=4 this gives k=3 (4 vertices), although the
        # original comment said "k=4 means 5 vertices, requires D >= 4".
        # Left unchanged because k fixes the width of out_in and hence the
        # parameter shapes of existing checkpoints - confirm the intent.
        self._cm_k = min(4, D - 1) if D >= 2 else 1
        self.cm = CMValidator(self._cm_k)

        # Input projection: token_dim -> hidden -> mat_dim
        self.enc_in = nn.Linear(token_dim, hidden)
        self.enc_blocks = self._residual_blocks(hidden, depth)
        self.enc_out = nn.Linear(hidden, self.mat_dim)
        nn.init.orthogonal_(self.enc_out.weight)

        # Cross-attention on singular values across tokens
        self.cross_attn = nn.ModuleList([
            SpectralCrossAttention(D, n_heads=n_heads, max_alpha=max_alpha)
            for _ in range(n_cross)
        ])

        # Output projection: mat_dim + cm_d2 + magnitudes -> hidden -> token_dim
        #   cm_d2:   pairwise distances between sampled M rows
        #   row_mag: pre-normalization magnitudes (encoder confidence)
        self._cm_npairs = self.cm._npairs
        self.out_in = nn.Linear(self.mat_dim + self._cm_npairs + self.V, hidden)
        self.out_blocks = self._residual_blocks(hidden, depth)
        self.out_proj = nn.Linear(hidden, token_dim)

    @staticmethod
    def _residual_blocks(hidden: int, depth: int) -> nn.ModuleList:
        """Build ``depth`` pre-norm MLP bodies; the skip is added by the caller."""
        return nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden),
                nn.Linear(hidden, hidden),
                nn.GELU(),
                nn.Linear(hidden, hidden),
            ) for _ in range(depth)
        ])

    def format(self, tokens: torch.Tensor) -> dict:
        """Run full pipeline. Returns output tokens, SVD components, and CM metrics.

        Args:
            tokens: (B, N, token_dim)

        Returns:
            dict:
                output:  (B, N, token_dim) - processed tokens
                M:       (B, N, V, D) - sphere-normalized matrix (rows on S^{D-1})
                U:       (B, N, V, D) - left singular vectors from SVD
                S_orig:  (B, N, D) - singular values before cross-attention
                S:       (B, N, D) - singular values after cross-attention
                Vt:      (B, N, D, D) - right singular vectors from SVD
                M_hat:   (B, N, V, D) - U . diag(S) . Vt with the modified S
                cm_d2:   (B*N, npairs) - pairwise squared distances from CM
                cm_vol2: (B*N,) - squared simplex volume from CM
                row_mag: (B, N, V) - pre-normalization row magnitudes
        """
        B, N, _ = tokens.shape
        n_flat = B * N

        # Input projection -> sphere-normalized V x D matrix per token
        h = F.gelu(self.enc_in(tokens.reshape(n_flat, -1)))
        for block in self.enc_blocks:
            h = h + block(h)
        M_flat = self.enc_out(h).reshape(n_flat, self.V, self.D)
        row_mag = M_flat.norm(dim=-1)  # (B*N, V), captured before normalizing
        M_flat = F.normalize(M_flat, dim=-1)

        # CM validation on (k + 1) fixed, evenly spaced rows so the result is
        # deterministic for a given M.
        nv = self._cm_k + 1
        cm_idx = torch.linspace(0, self.V - 1, nv).long().to(M_flat.device)
        cm_d2, cm_vol2 = self.cm(M_flat[:, cm_idx, :])  # verts: (B*N, nv, D)

        # SVD decomposition (part of the autograd graph, computed in fp64)
        U_flat, S_flat, Vt_flat = gram_eigh_svd(M_flat)

        # Cross-attention multiplicatively rescales S across the N tokens
        S = S_flat.reshape(B, N, self.D)
        S_orig = S.clone()
        for layer in self.cross_attn:
            S = layer(S)

        # Recompose with the modified singular values
        S_mod = S.reshape(n_flat, self.D)
        M_hat = torch.bmm(U_flat * S_mod.unsqueeze(1), Vt_flat)

        # Output projection: [M_hat | cm_d2 | row_mag] -> hidden -> token_dim
        out_features = torch.cat([
            M_hat.reshape(n_flat, -1),  # (B*N, V*D) recomposed spectral structure
            cm_d2,                      # (B*N, npairs) geometric arrangement
            row_mag,                    # (B*N, V) encoder confidence
        ], dim=-1)
        h = F.gelu(self.out_in(out_features))
        for block in self.out_blocks:
            h = h + block(h)
        output = self.out_proj(h).reshape(B, N, self.token_dim)

        return {
            'output': output,
            'M': M_flat.reshape(B, N, self.V, self.D),
            'U': U_flat.reshape(B, N, self.V, self.D),
            'S_orig': S_orig,
            'S': S,
            'Vt': Vt_flat.reshape(B, N, self.D, self.D),
            'M_hat': M_hat.reshape(B, N, self.V, self.D),
            'cm_d2': cm_d2,
            'cm_vol2': cm_vol2,
            'row_mag': row_mag.reshape(B, N, self.V),
        }

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """(B, N, token_dim) -> (B, N, token_dim). Drop-in compatible."""
        return self.format(tokens)['output']

    # ── CM Diagnostics ───────────────────────────────────────────

    def cm_cv(self, M: torch.Tensor, n_samples: int = 200) -> float:
        """Mean CV of pentachoron volumes over random 5-point subsets.

        Args:
            M: (B, N, V, D) sphere-normalized matrices.
            n_samples: subsets drawn per matrix (passed to ``cv_of``).

        Returns:
            Mean of ``cv_of`` over the first ``min(B*N, 64)`` matrices, or
            0.0 when there are none.
        """
        flat = M.reshape(-1, self.V, self.D)
        # Only the first few matrices are scored to keep the cost reasonable.
        n_mats = min(flat.shape[0], 64)
        cvs = [cv_of(flat[i], n_samples=n_samples) for i in range(n_mats)]
        return sum(cvs) / len(cvs) if cvs else 0.0

    def cm_vol2_stats(self, cm_vol2: torch.Tensor) -> dict:
        """Statistics on CM vol^2 from ``format()`` output.

        Args:
            cm_vol2: (B*N,) one vol^2 per token's sampled simplex.

        Returns:
            dict with keys ``mean``, ``std``, ``cv`` and ``frac_valid``
            computed on sqrt(|vol^2|) of entries with |vol^2| > 1e-20.
            When fewer than two entries are valid every value is 0.0.
        """
        valid = cm_vol2.abs() > 1e-20
        if valid.sum() < 2:
            # 'cv' is included so both branches return the same keys.
            return {'mean': 0.0, 'std': 0.0, 'cv': 0.0, 'frac_valid': 0.0}
        vols = cm_vol2[valid].abs().sqrt()
        return {
            'mean': vols.mean().item(),
            'std': vols.std().item(),
            'cv': (vols.std() / (vols.mean() + 1e-8)).item(),
            'frac_valid': valid.float().mean().item(),
        }

    # ── SVD Diagnostics ──────────────────────────────────────────

    @staticmethod
    def effective_rank(S: torch.Tensor) -> torch.Tensor:
        """Exponentiated Shannon entropy of the normalized singular values.

        erank = exp(-sum_i p_i log p_i) with p_i = sigma_i / sum(sigma).
        Approximately 1.0 for a rank-1 spectrum and D for a uniform one.
        """
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)  # keeps log() finite for zero entries
        return (-(p * p.log()).sum(-1)).exp()

    @staticmethod
    def spectral_shift(S_orig, S_coord):
        """Mean |S_coord - S_orig| across all modes and tokens."""
        return (S_coord - S_orig).abs().mean().item()

    @staticmethod
    def trace_check(M):
        """Mean trace(M^T M); equals V when every row has unit norm."""
        flat = M.reshape(-1, M.shape[-2], M.shape[-1])
        G = torch.bmm(flat.transpose(1, 2), flat)
        return torch.diagonal(G, dim1=-2, dim2=-1).sum(-1).mean().item()

    def summary(self):
        """Print shapes, param count, DOF ratio, CM config."""
        n_params = sum(p.numel() for p in self.parameters())
        sphere_dof = self.V * (self.D - 1)
        ratio = sphere_dof / self.token_dim
        out_in_width = self.mat_dim + self._cm_npairs + self.V
        if ratio > 1:
            regime = 'expand'
        elif ratio < 1:
            regime = 'compress'
        else:
            regime = 'identity'
        print("SpectralCell:")
        print(f" token_dim={self.token_dim}, V={self.V}, D={self.D}")
        print(f" mat_dim={self.mat_dim} ({self.V}×{self.D})")
        print(f" sphere DOF={sphere_dof} (V rows × {self.D-1} free per row)")
        print(f" CM: k={self._cm_k} ({self._cm_k+1} vertices, {self._cm_npairs} pairs)")
        print(f" out_in: {self.mat_dim} (M_hat) + {self._cm_npairs} (cm_d2) + {self.V} (mag) = {out_in_width}")
        print(f" hidden={self.hidden}, depth={len(self.enc_blocks)}")
        print(f" cross_attn={len(self.cross_attn)} layers")
        print(f" params: {n_params:,}")
        print(f" DOF ratio: {ratio:.2f}× ({regime})")
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# ── Factory functions ───────────────────────────────────────────
|
| 453 |
+
|
| 454 |
+
def spectral_cell_tiny(token_dim: int) -> SpectralCell:
    """Tiny preset: V=8, D=4, hidden=64, depth=1, one cross-attention layer."""
    preset = dict(V=8, D=4, hidden=64, depth=1, n_cross=1)
    return SpectralCell(token_dim, **preset)
|
| 457 |
+
|
| 458 |
+
def spectral_cell_small(token_dim: int) -> SpectralCell:
    """Small preset: V=16, D=4, hidden=128, depth=2, one cross-attention layer."""
    preset = dict(V=16, D=4, hidden=128, depth=2, n_cross=1)
    return SpectralCell(token_dim, **preset)
|
| 461 |
+
|
| 462 |
+
def spectral_cell_base(token_dim: int) -> SpectralCell:
    """Base preset: V=16, D=8, hidden=256, depth=2, two cross-attention layers, 4 heads."""
    preset = dict(V=16, D=8, hidden=256, depth=2, n_cross=2, n_heads=4)
    return SpectralCell(token_dim, **preset)
|
| 465 |
+
|
| 466 |
+
def spectral_cell_diamond(token_dim: int) -> SpectralCell:
    """Diamond preset: V=16, D=16, hidden=256, depth=2, one cross-attention
    layer, 4 heads. Described by the author as the best sweep config.
    """
    preset = dict(V=16, D=16, hidden=256, depth=2, n_cross=1, n_heads=4)
    return SpectralCell(token_dim, **preset)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
# ── Self-test ───────────────────────────────────────────────────
|
| 472 |
+
|
| 473 |
+
if __name__ == '__main__':
    # Smoke test: build several presets, run one forward pass on random
    # tokens, print shape/spectral/CM diagnostics, and check that every
    # trainable parameter receives a non-zero gradient.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    for name, factory in [('tiny', spectral_cell_tiny),
                          ('small', spectral_cell_small),
                          ('diamond', spectral_cell_diamond)]:
        print(f"\n{'='*50}")
        cell = factory(token_dim=192).to(device)
        cell.summary()

        tokens = torch.randn(2, 16, 192, device=device)
        result = cell.format(tokens)

        print(f"\n Input: {tokens.shape}")
        print(f" Output: {result['output'].shape}")
        print(f" M: {result['M'].shape}")
        print(f" S: {result['S'].shape}")
        print(f" cm_d2: {result['cm_d2'].shape}")
        print(f" cm_vol2: {result['cm_vol2'].shape}")
        print(f" trace: {cell.trace_check(result['M']):.4f} (expect {cell.V})")
        print(f" erank: {cell.effective_rank(result['S_orig'].reshape(-1, cell.D)).mean():.2f}")
        print(f" shift: {cell.spectral_shift(result['S_orig'], result['S']):.6f}")

        # CM stats; .get() tolerates a stats dict without a 'cv' entry.
        cm_stats = cell.cm_vol2_stats(result['cm_vol2'])
        print(f" cm_vol: mean={cm_stats['mean']:.6f} cv={cm_stats.get('cv', 0):.4f} "
              f"valid={cm_stats['frac_valid']:.1%}")

        # Full CV (slower, samples 100 random 5-point subsets per matrix)
        with torch.no_grad():
            cv = cell.cm_cv(result['M'], n_samples=100)
        print(f" cm_cv: {cv:.4f}")

        # Gradient check: every trainable parameter must get a non-zero grad.
        loss = result['output'].sum()
        loss.backward()
        grad_ok = all(p.grad is not None and p.grad.abs().sum() > 0
                      for p in cell.parameters() if p.requires_grad)
        # Distinct pass/fail marks; the two branches had collapsed to the
        # same mis-encoded character.
        print(f" grads: {'✓' if grad_ok else '✗'}")
|