"""
SpectralCell
============
Drop-in layer: (B, N, token_dim) → (B, N, token_dim).

Pipeline:
    tokens → Linear → residual MLP → Linear(hidden, V*D) → reshape(V, D)
        → capture row magnitudes (encoder confidence)
        → F.normalize(dim=-1)
        → full pairwise d² between all V rows on S^{D-1} → C(V,2) pairs
        → patchwork: round-robin compartment MLPs on d² → structured features
        → [optional] CM validation → simplex volume (cm_vol2) for soft hand
        → SVD: analytical closed-form (D=2) or Gram-eigh fp64 (D>2)
        → cross-attention scales S per mode across all N tokens
        → recompose M_hat = U · diag(S_modified) · Vt
        → cat(M_hat, patchwork_features, row_magnitudes)
        → Linear → residual MLP → Linear(hidden, token_dim) → output

SVD dispatch (via geolip.linalg.svd auto-dispatch):
    D=2:  Fused Triton kernel ~0.02ms (fp32, when available)
    D=3:  Fused Triton kernel ~0.02ms (fp32, when available)
    D≤12: Gram + FL eigh (compilable, 70/72 purity, zero graph breaks)
    D>12: Gram + torch.linalg.eigh (cuSOLVER)
    CPU:  torch.linalg.svd fallback
    Falls back to hand-rolled Gram-eigh if geolip.linalg is not installed.

Cross-attention modifies S multiplicatively:
    S_out = S * (1 + α * tanh(attention_output))
    α per mode, bounded [0, max_alpha], initialized ~0.024.
    M_hat ≠ M after this step.

Sphere normalization enforces ||row||=1 for all V rows.
This constrains trace(M^T M) = V (fixed total spectral energy).
The SVD decomposes how that fixed energy distributes across D axes.

CM validation (configurable):
    cm_enabled: toggle per cell (disable for speed, enable every Nth cell)
    cm_points:  vertices per simplex (5 = pentachoron)
    cm_min:     minimum volume threshold for valid measurements
    CM can be disabled for degenerate (D=2) cells where pentachoron
    geometry is not well-defined on S^1.

Factory configurations:
    spectral_cell_degenerate: D=16 via slice_d=2, no CM, 8× Triton D=2 kernels
    spectral_cell_primary:    D=16, CM enabled, full geometric instrument
    spectral_cell_conduit:    D=16, CM disabled (measured externally)
    spectral_cell_diamond:    D=16, CM enabled, best sweep config
    spectral_cell_tiny/small/base: research configs

Author: AbstractPhil + Claude Opus
License: Apache 2.0
"""
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from itertools import combinations |
| import warnings |
|
|
|
|
| |
|
|
class CMValidator(nn.Module):
    """Cayley-Menger measurement of a batch of k-simplices.

    Each simplex is given by its (k+1) vertices in an embedding space of
    any dimension. The module returns the pairwise squared distances
    between the vertices and the squared simplex volume obtained from
    the Cayley-Menger determinant.

    For k=4: 5 vertices → 10 pairwise d² + 1 vol².
    """

    def __init__(self, k):
        super().__init__()
        n_vertices = k + 1
        vertex_pairs = list(combinations(range(n_vertices), 2))

        self._k = k
        self._nv = n_vertices
        self._npairs = len(vertex_pairs)

        # Index buffers selecting the strict upper triangle of the
        # distance matrix, in itertools.combinations order.
        self.register_buffer(
            '_pi', torch.tensor([a for a, _ in vertex_pairs], dtype=torch.long))
        self.register_buffer(
            '_pj', torch.tensor([b for _, b in vertex_pairs], dtype=torch.long))

        # vol² = (-1)^(k+1) / (2^k * (k!)²) * det(CM)
        parity = -1.0 if (k + 1) % 2 else 1.0
        self._prefactor = parity / ((2.0 ** k) * (math.factorial(k) ** 2))

    def forward(self, verts):
        """verts: (..., nv, edim) → d2_pairs: (..., npairs), vol2: (...)"""
        inner = torch.einsum('...ve,...we->...vw', verts, verts)
        sq_norm = inner.diagonal(dim1=-2, dim2=-1)

        # ||a - b||² = ||a||² + ||b||² - 2<a, b>; relu removes the tiny
        # negative values that round-off can produce.
        dist2 = F.relu(sq_norm.unsqueeze(-1) + sq_norm.unsqueeze(-2) - 2 * inner)
        pair_d2 = dist2[..., self._pi, self._pj]

        # Bordered Cayley-Menger matrix: zero corner, a row and a column
        # of ones, and the squared-distance matrix in the remaining block.
        n = dist2.shape[-1]
        bordered = dist2.new_zeros(*dist2.shape[:-2], n + 1, n + 1)
        bordered[..., 0, 1:] = 1.0
        bordered[..., 1:, 0] = 1.0
        bordered[..., 1:, 1:] = dist2

        # The determinant is evaluated in fp32, then cast back to the
        # dtype of the distances.
        vol2 = (self._prefactor * torch.linalg.det(bordered.float())).to(pair_d2.dtype)
        return pair_d2, vol2
|
|
|
|
def cayley_menger_vol2(points: torch.Tensor) -> torch.Tensor:
    """Squared simplex volume from the Cayley-Menger determinant (fp64).

    points: (B, N, D) — N vertices per simplex, so each simplex has
    order k = N - 1.
    Returns vol2: (B,) in float64.
    """
    batch, n_pts, _ = points.shape
    p64 = points.double()

    inner = torch.bmm(p64, p64.transpose(1, 2))
    sq_norm = inner.diagonal(dim1=1, dim2=2)
    # Squared distances from the Gram matrix, clamped at zero.
    dist2 = F.relu(sq_norm[:, :, None] + sq_norm[:, None, :] - 2 * inner)

    # Bordered matrix: zero corner, ones on the first row and column.
    bordered = torch.zeros(
        batch, n_pts + 1, n_pts + 1,
        device=points.device, dtype=torch.float64)
    bordered[:, 0, 1:] = 1.0
    bordered[:, 1:, 0] = 1.0
    bordered[:, 1:, 1:] = dist2

    k = n_pts - 1
    parity = -1.0 if (k + 1) % 2 else 1.0
    denom = (2 ** k) * (math.factorial(k) ** 2)
    return parity * torch.linalg.det(bordered) / denom
|
|
|
|
def cv_of(emb: torch.Tensor, n_samples: int = 200) -> float:
    """Coefficient of variation of pentachoron volumes.

    emb: (V, D) — rows of a sphere-normalized matrix.
    Draws ``n_samples`` random 5-row subsets (from the first 512 rows at
    most), evaluates the Cayley-Menger vol² of each, and returns
    std(vol) / mean(vol) over the subsets whose vol² exceeds 1e-20.

    CV ≈ 0.20-0.23 is the empirically observed attractor band.
    Returns 0.0 when the input is not 2-D, has fewer than 5 rows, or
    fewer than 10 subsets yield a valid volume.
    """
    unusable = emb.dim() != 2 or emb.shape[0] < 5
    if unusable:
        return 0.0

    pool = min(emb.shape[0], 512)
    draws = [
        torch.randperm(pool, device=emb.device)[:5]
        for _ in range(n_samples)
    ]
    subset_idx = torch.stack(draws)

    candidates = emb[:pool]
    vol2 = cayley_menger_vol2(candidates[subset_idx])

    keep = vol2 > 1e-20
    if keep.sum() < 10:
        return 0.0

    volumes = vol2[keep].sqrt()
    return (volumes.std() / (volumes.mean() + 1e-8)).item()
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Prefer the geolip SVD dispatcher; try both package names it has been
# shipped under, and only define the local fallback if neither imports.
try:
    from geolip_core.linalg import svd as batched_svd
    from geolip_core.linalg import backend as linalg_backend
    _HAS_GEOLIP_LINALG = True
except ImportError:
    try:
        from geolip.linalg import svd as batched_svd
        from geolip.linalg import backend as linalg_backend
        _HAS_GEOLIP_LINALG = True
    except ImportError:
        _HAS_GEOLIP_LINALG = False
        linalg_backend = None

        def batched_svd(A, method='auto', compute_dtype='fp64', **kw):
            """Fallback thin SVD via eigendecomposition of the Gram matrix.

            A: (B, V, D) → U: (B, V, D), S: (B, D) descending, Vh: (B, D, D).
            ``method``, ``compute_dtype`` and extra keywords are accepted
            for signature compatibility and ignored: the computation
            always runs in float64 with autocast disabled, and results
            are cast back to the dtype of ``A``.
            """
            _, _, n_cols = A.shape  # also rejects inputs that are not rank-3
            in_dtype = A.dtype
            with torch.amp.autocast('cuda', enabled=False):
                A64 = A.double()
                gram = torch.bmm(A64.transpose(1, 2), A64)
                # Small diagonal shift before the eigendecomposition.
                gram = gram + 1e-12 * torch.eye(
                    n_cols, dtype=gram.dtype, device=gram.device)
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore', message='.*not converge.*')
                    evals, evecs = torch.linalg.eigh(gram)
                # eigh returns ascending eigenvalues; SVD convention is descending.
                evals, evecs = evals.flip(-1), evecs.flip(-1)
                sing = evals.clamp(min=1e-24).sqrt()
                left = torch.bmm(A64, evecs) / sing.unsqueeze(1).clamp(min=1e-16)
                right_t = evecs.transpose(-2, -1).contiguous()
            return left.to(in_dtype), sing.to(in_dtype), right_t.to(in_dtype)
|
|
|
|
| |
|
|
class SpectralPatchwork(nn.Module):
    """Round-robin compartment MLPs over pairwise squared distances.

    The C(V,2) squared distances between the rows of M are dealt out to
    ``n_comp`` small MLPs by stride slicing: compartment k receives
    distances[k::n_comp], i.e. every n_comp-th distance starting at k.
    Each compartment maps its slice to ``comp_out`` features and the
    results are concatenated.

    Slices shorter than ceil(n_pairs / n_comp) are zero-padded on the
    right so that every compartment has the same input width.
    """

    def __init__(self, n_pairs, n_comp=8, comp_hidden=32, comp_out=8):
        super().__init__()
        self.n_pairs = n_pairs
        self.n_comp = n_comp
        self.comp_out = comp_out

        # Ceiling division: width of the longest strided slice.
        self._pairs_per_comp = -(-n_pairs // n_comp)

        mlps = []
        for _ in range(n_comp):
            mlps.append(nn.Sequential(
                nn.Linear(self._pairs_per_comp, comp_hidden),
                nn.GELU(),
                nn.Linear(comp_hidden, comp_out),
            ))
        self.compartments = nn.ModuleList(mlps)

    @property
    def output_dim(self):
        """Total feature width: n_comp * comp_out."""
        return self.n_comp * self.comp_out

    def forward(self, d2_pairs):
        """d2_pairs: (B, n_pairs) → (B, n_comp * comp_out)"""
        width = self._pairs_per_comp
        features = []
        for offset, mlp in enumerate(self.compartments):
            strided = d2_pairs[:, offset::self.n_comp]
            missing = width - strided.shape[-1]
            if missing > 0:
                # Right-pad the short slices with zeros.
                filler = strided.new_zeros(strided.shape[0], missing)
                strided = torch.cat((strided, filler), dim=-1)
            features.append(mlp(strided))
        return torch.cat(features, dim=-1)
|
|
|
|
|
|
|
|
class SpectralCrossAttention(nn.Module):
    """Multi-head self-attention over per-token singular-value profiles.

    Input S: (B, N, D) — one D-dim spectral profile per token.
    Attention runs across the N positions, so every token's profile is
    mixed with all others. The attended result is turned into a bounded
    multiplicative gate:

        output = S * (1 + α * tanh(out_proj(attended)))

    α is per-mode and lies in [0, max_alpha]: it is max_alpha times the
    sigmoid of learnable logits. With the defaults the initial value is
    sigmoid(-2.0) * 0.2 ≈ 0.024 per mode.
    """

    def __init__(self, D, n_heads=2, max_alpha=0.2, alpha_init=-2.0):
        super().__init__()
        assert D % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = D // n_heads
        self.max_alpha = max_alpha
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.alpha_logits = nn.Parameter(torch.full((D,), alpha_init))

    @property
    def alpha(self):
        """Per-mode gate strength, bounded to [0, max_alpha]."""
        return self.max_alpha * torch.sigmoid(self.alpha_logits)

    def forward(self, S):
        """S: (B, N, D) → rescaled S of the same shape."""
        batch, n_tok, width = S.shape

        # Joint q/k/v projection, split into (3, B, heads, N, head_dim).
        projected = self.qkv(self.norm(S))
        projected = projected.reshape(batch, n_tok, 3, self.n_heads, self.head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)

        weights = torch.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)
        mixed = (weights @ v).transpose(1, 2).reshape(batch, n_tok, width)

        gate = torch.tanh(self.out_proj(mixed))
        return S * (1.0 + self.alpha[None, None, :] * gate)
|
|
|
|
| |
|
|
class SpectralCell(nn.Module):
    """Processes N tokens through sphere-normalized SVD with spectral
    coordination and Cayley-Menger geometric validation.

    Shapes through the pipeline (for default V=16, D=4, hidden=128, token_dim=64):
        tokens:      (B, N, 64)
        enc_in:      Linear(64, 128)              → (B*N, 128)
        enc_blocks:  2× residual MLP              → (B*N, 128)
        enc_out:     Linear(128, 64)              → (B*N, 64) → reshape (B*N, 16, 4)
        normalize:   F.normalize(dim=-1)          → each row has norm 1
        patchwork:   C(16,2)=120 pairwise d²      → (B*N, 64) features
        SVD:         batched_svd                  → U(B*N,16,4), S(B*N,4), Vt(B*N,4,4)
        cross_attn:  S reshaped (B,N,4)           → attention across N → S_coord (B,N,4)
        recompose:   U · diag(S_coord) · Vt       → M_hat (B*N, 16, 4) → flatten (B*N, 64)
        out_in:      Linear(64 + 64 + 16, 128)    → (B*N, 128)
                     (M_hat + patchwork features + row magnitudes)
        out_blocks:  2× residual MLP              → (B*N, 128)
        out_proj:    Linear(128, 64)              → (B, N, 64)

    CM validation:
        M rows are V unit vectors on S^{D-1}.
        A CMValidator of order k = min(cm_points - 1, D - 1) is built when
        cm_enabled is set and D >= cm_points - 1; otherwise CM is switched
        off for this cell. format() measures k+1 rows picked at evenly
        spaced indices. vol² measures simplex volume.
        cm_cv() / cv_of() give the coefficient of variation over random
        5-row subsets.

    Args:
        token_dim: input and output dimension per token
        V: matrix rows (each becomes a unit vector on S^{D-1}); must be >= 2
        D: matrix columns (spectral modes, eigenvalue count)
        hidden: residual MLP width
        depth: residual blocks in input and output projections
        n_cross: SpectralCrossAttention layers applied to S
        n_heads: attention heads in cross-attention (must divide D)
        max_alpha: upper bound on per-mode multiplicative scaling
        cm_enabled: whether to compute CM validation (disable for speed)
        cm_points: vertices per simplex (5 = pentachoron)
        cm_samples: samples for cv_of monitoring (stored on the cell)
        cm_min: minimum volume threshold for valid CM
        degen_threshold: relative threshold used by svd_health() to flag
            degenerate gaps and dead modes
        slice_d: decompose D-dim SVD as D//slice_d independent slice_d-dim SVDs
            slice_d=2 with D=16 → 8 D=2 decompositions for D=16 capacity
            0 = no slicing (full SVD)

    Raises:
        ValueError: if V < 2 (no row pairs exist for the patchwork).
        AssertionError: if slice_d > 0 does not divide D.
    """

    def __init__(
        self,
        token_dim: int,
        V: int = 16,
        D: int = 4,
        hidden: int = 128,
        depth: int = 2,
        n_cross: int = 1,
        n_heads: int = 2,
        max_alpha: float = 0.2,
        cm_enabled: bool = True,
        cm_points: int = 5,
        cm_samples: int = 200,
        cm_min: float = 1e-16,
        degen_threshold: float = 1e-6,
        slice_d: int = 0,
    ):
        super().__init__()
        # With fewer than two rows there are no pairwise distances and the
        # patchwork sizing below would divide by zero.
        if V < 2:
            raise ValueError(f"V={V} must be at least 2 (pairwise distances need two rows)")

        self.token_dim = token_dim
        self.V = V
        self.D = D
        self.mat_dim = V * D
        self.hidden = hidden
        self.cm_enabled = cm_enabled
        self.cm_points = cm_points
        self.cm_samples = cm_samples
        self.cm_min = cm_min
        self.degen_threshold = degen_threshold

        # Sliced SVD: D columns are split into D // slice_d independent
        # groups of slice_d columns each.
        if slice_d > 0:
            assert D % slice_d == 0, f"D={D} must be divisible by slice_d={slice_d}"
            self.slice_d = slice_d
            self.n_slices = D // slice_d
        else:
            self.slice_d = 0
            self.n_slices = 1

        # CM validator: simplex order is capped by D - 1; the validator is
        # only built when D >= cm_points - 1, otherwise CM is turned off.
        self._cm_k = min(cm_points - 1, D - 1) if D >= 2 else 1
        if cm_enabled and D >= cm_points - 1:
            self.cm = CMValidator(self._cm_k)
            self._cm_npairs = self.cm._npairs
        else:
            self.cm = None
            self._cm_npairs = 0
            self.cm_enabled = False

        # Encoder: token → hidden → residual MLP → V*D matrix entries.
        self.enc_in = nn.Linear(token_dim, hidden)
        self.enc_blocks = self._residual_blocks(hidden, depth)
        self.enc_out = nn.Linear(hidden, self.mat_dim)
        nn.init.orthogonal_(self.enc_out.weight)

        # Spectral coordination across tokens.
        self.cross_attn = nn.ModuleList([
            SpectralCrossAttention(D, n_heads=n_heads, max_alpha=max_alpha)
            for _ in range(n_cross)
        ])

        # Patchwork over all C(V,2) pairwise squared distances; the index
        # buffers select the strict upper triangle of the V×V matrix.
        self._n_pairs = V * (V - 1) // 2
        _triu = torch.triu_indices(V, V, offset=1)
        self.register_buffer('_triu_i', _triu[0])
        self.register_buffer('_triu_j', _triu[1])
        n_comp = min(8, self._n_pairs)
        comp_hidden = max(16, self._n_pairs // n_comp * 2)
        self.patchwork = SpectralPatchwork(
            n_pairs=self._n_pairs, n_comp=n_comp,
            comp_hidden=comp_hidden, comp_out=8,
        )

        # Decoder input is cat(M_hat, patchwork features, row magnitudes).
        pw_out = self.patchwork.output_dim
        self.out_in = nn.Linear(self.mat_dim + pw_out + self.V, hidden)
        self.out_blocks = self._residual_blocks(hidden, depth)
        self.out_proj = nn.Linear(hidden, token_dim)

    @staticmethod
    def _residual_blocks(hidden: int, depth: int) -> nn.ModuleList:
        """Build `depth` pre-norm MLP bodies (LayerNorm → Linear → GELU → Linear)."""
        return nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden),
                nn.Linear(hidden, hidden),
                nn.GELU(),
                nn.Linear(hidden, hidden),
            ) for _ in range(depth)
        ])

    def _coordinate(self, S: torch.Tensor):
        """Run S (B, N, D) through every cross-attention layer.

        Returns (S_orig, S): a clone of the input and the rescaled values.
        """
        S_orig = S.clone()
        for layer in self.cross_attn:
            S = layer(S)
        return S_orig, S

    def format(self, tokens: torch.Tensor) -> dict:
        """Run full pipeline. Returns output tokens, SVD components, and CM metrics.

        Args:
            tokens: (B, N, token_dim)

        Returns:
            dict:
                output:      (B, N, token_dim) — processed tokens
                M:           (B, N, V, D) — sphere-normalized matrix (rows on S^{D-1})
                U:           (B, N, V, D) — left singular vectors from SVD
                             (per-slice factors concatenated when slice_d > 0)
                S_orig:      (B, N, D) — singular values before cross-attention
                S:           (B, N, D) — singular values after cross-attention
                Vt:          (B, N, D, D) — right singular vectors from SVD
                             (block-diagonal when slice_d > 0)
                M_hat:       (B, N, V, D) — U · diag(S_modified) · Vt (≠ M)
                cm_d2:       (B*N, cm_npairs) — CM pairwise d², or None if CM is off
                cm_vol2:     (B*N,) — squared simplex volume from CM, or None
                row_mag:     (B, N, V) — pre-normalization row magnitudes
                d2_pairs:    (B*N, C(V,2)) — full pairwise d² between all M rows
                pw_features: (B*N, pw_out) — patchwork structured geometric features
        """
        B, N, _ = tokens.shape
        BN = B * N

        # Encode every token independently into a V×D matrix.
        flat = tokens.reshape(BN, -1)
        h = F.gelu(self.enc_in(flat))
        for block in self.enc_blocks:
            h = h + block(h)
        M = self.enc_out(h).reshape(BN, self.V, self.D)
        row_mag = M.norm(dim=-1)
        M = F.normalize(M, dim=-1)

        # Optional CM measurement on k+1 rows at evenly spaced indices.
        if self.cm_enabled and self.cm is not None:
            nv = self._cm_k + 1
            cm_idx = torch.linspace(0, self.V - 1, nv).long().to(M.device)
            cm_d2, cm_vol2 = self.cm(M[:, cm_idx, :])
        else:
            cm_d2 = None
            cm_vol2 = None

        # Rows are unit vectors, so d² = 2 - 2<a, b>.
        gram = torch.bmm(M, M.transpose(1, 2))
        d2_full = 2.0 - 2.0 * gram
        d2_pairs = d2_full[:, self._triu_i, self._triu_j]
        pw_features = self.patchwork(d2_pairs)

        if self.slice_d > 0:
            # Split the D columns into n_slices groups and decompose each
            # (V, slice_d) sub-matrix on its own.
            M_slices = M.reshape(BN, self.V, self.n_slices, self.slice_d)
            M_slices = M_slices.permute(0, 2, 1, 3).reshape(
                BN * self.n_slices, self.V, self.slice_d)
            U_s, S_s, Vt_s = batched_svd(M_slices, compute_dtype='fp32')

            U_s = U_s.reshape(BN, self.n_slices, self.V, self.slice_d)
            Vt_s = Vt_s.reshape(BN, self.n_slices, self.slice_d, self.slice_d)
            U = U_s.permute(0, 2, 1, 3).reshape(B, N, self.V, self.D)
            S = S_s.reshape(B, N, self.D)

            S_orig, S = self._coordinate(S)

            # Recompose all slices in one batched matmul and put the
            # column groups back side by side.
            S_sliced = S.reshape(BN, self.n_slices, self.slice_d)
            M_hat = torch.matmul(U_s * S_sliced.unsqueeze(-2), Vt_s)
            M_hat = M_hat.permute(0, 2, 1, 3).reshape(BN, self.V, self.D)

            # Expose the per-slice right factors as one block-diagonal Vt.
            Vt = torch.zeros(BN, self.D, self.D, device=M_hat.device, dtype=M_hat.dtype)
            for i in range(self.n_slices):
                lo, hi = i * self.slice_d, (i + 1) * self.slice_d
                Vt[:, lo:hi, lo:hi] = Vt_s[:, i]
            Vt = Vt.reshape(B, N, self.D, self.D)
        else:
            # Full SVD; fp32 is requested for D <= 3, fp64 otherwise.
            dtype_arg = 'fp32' if self.D <= 3 else 'fp64'
            U, S, Vt = batched_svd(M, compute_dtype=dtype_arg)

            U = U.reshape(B, N, self.V, self.D)
            S = S.reshape(B, N, self.D)
            Vt = Vt.reshape(B, N, self.D, self.D)

            S_orig, S = self._coordinate(S)

            M_hat = torch.bmm(
                U.reshape(BN, self.V, self.D) * S.reshape(BN, self.D).unsqueeze(1),
                Vt.reshape(BN, self.D, self.D),
            )

        M = M.reshape(B, N, self.V, self.D)

        # Decode from the recomposed matrix plus the geometric side features.
        out_features = torch.cat([
            M_hat.reshape(BN, -1),
            pw_features,
            row_mag,
        ], dim=-1)
        h = F.gelu(self.out_in(out_features))
        for block in self.out_blocks:
            h = h + block(h)
        output = self.out_proj(h).reshape(B, N, self.token_dim)

        return {
            'output': output,
            'M': M,
            'U': U,
            'S_orig': S_orig,
            'S': S,
            'Vt': Vt,
            'M_hat': M_hat.reshape(B, N, self.V, self.D),
            'cm_d2': cm_d2,
            'cm_vol2': cm_vol2,
            'row_mag': row_mag.reshape(B, N, self.V),
            'd2_pairs': d2_pairs,
            'pw_features': pw_features,
        }

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """(B, N, token_dim) → (B, N, token_dim). Drop-in compatible."""
        return self.format(tokens)['output']

    def cm_cv(self, M: torch.Tensor, n_samples: int = 200) -> float:
        """Compute CV of pentachoron volumes over random 5-point subsets.

        M: (B, N, V, D) — sphere-normalized matrices.
        Returns the mean cv_of() value over at most the first 64 matrices,
        or 0.0 when there are none.
        """
        flat = M.reshape(-1, self.V, self.D)
        n_mats = min(flat.shape[0], 64)
        cvs = [cv_of(flat[i], n_samples=n_samples) for i in range(n_mats)]
        return sum(cvs) / len(cvs) if cvs else 0.0

    def cm_vol2_stats(self, cm_vol2) -> dict:
        """Statistics on CM vol² from format() output.

        cm_vol2: (B*N,) or None if CM disabled.
        Entries with |vol²| <= cm_min are treated as invalid. Every return
        path yields the same four keys ('mean', 'std', 'cv', 'frac_valid');
        they are all 0.0 when CM is disabled or fewer than two entries
        are valid.
        """
        empty = {'mean': 0.0, 'std': 0.0, 'cv': 0.0, 'frac_valid': 0.0}
        if cm_vol2 is None:
            return empty
        valid = cm_vol2.abs() > self.cm_min
        if valid.sum() < 2:
            # std needs at least two samples
            return empty
        vols = cm_vol2[valid].abs().sqrt()
        return {
            'mean': vols.mean().item(),
            'std': vols.std().item(),
            'cv': (vols.std() / (vols.mean() + 1e-8)).item(),
            'frac_valid': valid.float().mean().item(),
        }

    @staticmethod
    def effective_rank(S: torch.Tensor) -> torch.Tensor:
        """Shannon entropy of normalized singular values, exponentiated.

        erank = exp(-Σ p_i log p_i) where p_i = σ_i / Σσ.
        Returns ~1.0 for rank-1, D for uniform spectrum.
        """
        p = S / (S.sum(-1, keepdim=True) + 1e-8)
        p = p.clamp(min=1e-8)  # keeps log() finite for zero modes
        return (-(p * p.log()).sum(-1)).exp()

    @staticmethod
    def spectral_shift(S_orig, S_coord):
        """Mean |S_coord - S_orig| across all modes and tokens."""
        return (S_coord - S_orig).abs().mean().item()

    @staticmethod
    def trace_check(M):
        """trace(M^T M) should equal V (sum of squared unit row norms)."""
        flat = M.reshape(-1, M.shape[-2], M.shape[-1])
        G = torch.bmm(flat.transpose(1, 2), flat)
        return torch.diagonal(G, dim1=-2, dim2=-1).sum(-1).mean().item()

    @staticmethod
    def _pack_stats(x: torch.Tensor, prefix: str) -> dict:
        """Summary of a tensor (flattened). All values are Python floats.

        Keys: {prefix}_mean, _std (population), _p50, _p95, _max; all 0.0
        for an empty tensor.
        """
        x = x.detach().float().reshape(-1)
        if x.numel() == 0:
            return {f"{prefix}_mean": 0.0, f"{prefix}_std": 0.0,
                    f"{prefix}_p50": 0.0, f"{prefix}_p95": 0.0, f"{prefix}_max": 0.0}
        q = torch.quantile(x, torch.tensor([0.50, 0.95], device=x.device, dtype=x.dtype))
        return {
            f"{prefix}_mean": x.mean().item(),
            f"{prefix}_std": x.std(unbiased=False).item(),
            f"{prefix}_p50": q[0].item(),
            f"{prefix}_p95": q[1].item(),
            f"{prefix}_max": x.max().item(),
        }

    @torch.no_grad()
    def svd_health(self, result: dict) -> dict:
        """SVD solver correctness. Is the decomposition numerically trustworthy?

        Takes the dict returned by format(). All tensors detached — no
        gradient path. Reports reconstruction error, orthogonality of U
        and Vt, energy conservation, condition number and degeneracy
        fractions, each summarized by _pack_stats.

        NOTE(review): the orthogonality, condition-number and gap metrics
        treat U/S/Vt as one full decomposition; with slice_d > 0 they are
        per-slice factors placed side by side, so interpret those numbers
        accordingly.
        """
        M = result['M'].detach().float().reshape(-1, self.V, self.D)
        U = result['U'].detach().float().reshape(-1, self.V, self.D)
        S0 = result['S_orig'].detach().float().reshape(-1, self.D)
        S1 = result['S'].detach().float().reshape(-1, self.D)
        Vt = result['Vt'].detach().float().reshape(-1, self.D, self.D)

        # Relative reconstruction error with the unmodified singular values.
        M_recon = torch.bmm(U * S0.unsqueeze(1), Vt)
        recon_rel = (M - M_recon).norm(dim=(-2, -1)) / M.norm(dim=(-2, -1)).clamp(min=1e-12)

        # Deviation of U^T U and Vt Vt^T from the identity.
        I = torch.eye(self.D, device=M.device).unsqueeze(0)
        u_orth = (torch.bmm(U.transpose(1, 2), U) - I).norm(dim=(-2, -1))
        v_orth = (torch.bmm(Vt, Vt.transpose(1, 2)) - I).norm(dim=(-2, -1))

        # ||M||_F² should equal Σσ²; energy_gain is the change caused by
        # cross-attention.
        energy_M = (M * M).sum(dim=(-2, -1))
        energy_S0 = (S0 * S0).sum(dim=-1)
        energy_resid = (energy_M - energy_S0).abs() / energy_M.clamp(min=1e-12)
        energy_gain = (S1 * S1).sum(dim=-1) / energy_S0.clamp(min=1e-12)

        # Conditioning and (near-)degenerate / dead modes relative to σ_0.
        cond = S0[:, 0] / S0[:, -1].clamp(min=1e-12)
        rel_gaps = (S0[:, :-1] - S0[:, 1:]).abs() / S0[:, :1].clamp(min=1e-12)
        degen_frac = (rel_gaps < self.degen_threshold).float().mean(dim=-1)
        dead_frac = (S0 < S0[:, :1] * self.degen_threshold).float().mean(dim=-1)

        out = {}
        out.update(self._pack_stats(recon_rel, 'svd_recon'))
        out.update(self._pack_stats(u_orth, 'svd_u_orth'))
        out.update(self._pack_stats(v_orth, 'svd_v_orth'))
        out.update(self._pack_stats(energy_resid, 'svd_energy_resid'))
        out.update(self._pack_stats(energy_gain, 'svd_energy_gain'))
        out.update(self._pack_stats(cond, 'svd_cond'))
        out.update(self._pack_stats(degen_frac, 'svd_degen_frac'))
        out.update(self._pack_stats(dead_frac, 'svd_dead_frac'))
        return out

    @torch.no_grad()
    def geometry_retention(self, result: dict) -> dict:
        """Did cross-attention preserve sphere geometry?

        Takes the dict returned by format() and reports Gram drift
        (between M and re-normalized M_hat), M_hat row-norm mean and CV,
        and the absolute / relative change of the singular values.
        """
        M = result['M'].detach().float().reshape(-1, self.V, self.D)
        M_hat = result['M_hat'].detach().float().reshape(-1, self.V, self.D)
        S0 = result['S_orig'].detach().float().reshape(-1, self.D)
        S1 = result['S'].detach().float().reshape(-1, self.D)

        M_hat_n = F.normalize(M_hat, dim=-1)
        G0 = torch.bmm(M, M.transpose(1, 2))
        G1 = torch.bmm(M_hat_n, M_hat_n.transpose(1, 2))
        gram_drift = (G0 - G1).abs().mean(dim=(-2, -1))

        row_norm = M_hat.norm(dim=-1)
        row_norm_mean = row_norm.mean(dim=-1)
        row_norm_cv = row_norm.std(dim=-1, unbiased=False) / row_norm_mean.clamp(min=1e-12)

        spectral_delta = (S1 - S0).abs().mean(dim=-1)
        spectral_ratio = S1.sum(dim=-1) / S0.sum(dim=-1).clamp(min=1e-12)

        out = {}
        out.update(self._pack_stats(gram_drift, 'geom_gram_drift'))
        out.update(self._pack_stats(row_norm_mean, 'geom_row_norm'))
        out.update(self._pack_stats(row_norm_cv, 'geom_row_norm_cv'))
        out.update(self._pack_stats(spectral_delta, 'geom_S_delta'))
        out.update(self._pack_stats(spectral_ratio, 'geom_S_ratio'))
        return out

    @torch.no_grad()
    def thin_svd_health(self, result: dict) -> dict:
        """D=2 degenerate cell diagnostics. Empty dict unless D == 2 or slice_d == 2.

        Reports the normalized singular-value gap |σ_0 - σ_1| / (σ_0 + σ_1)
        under the 'thin_isotropy' prefix (0 means a perfectly isotropic
        2×2 spectrum) and the fraction of right factors with negative
        determinant (handedness flips) under 'thin_neg_det'.
        """
        if self.D != 2 and self.slice_d != 2:
            return {}

        S0 = result['S_orig'].detach().float()
        Vt = result['Vt'].detach().float()

        if self.slice_d == 2 and self.D > 2:
            # Pull the 2×2 diagonal blocks back out of the block-diagonal Vt.
            vt_blocks = torch.stack([
                Vt[..., i * 2:(i + 1) * 2, i * 2:(i + 1) * 2]
                for i in range(self.n_slices)
            ], dim=-3)
            S0 = S0.reshape(-1, 2)
            Vt = vt_blocks.reshape(-1, 2, 2)
            is_sliced = True
        else:
            S0 = S0.reshape(-1, 2)
            Vt = Vt.reshape(-1, 2, 2)
            is_sliced = False

        isotropy = (S0[:, 0] - S0[:, 1]).abs() / S0.sum(dim=-1).clamp(min=1e-12)

        det_v = torch.linalg.det(Vt)
        neg_det = (det_v < 0).float()

        out = {}
        out.update(self._pack_stats(isotropy, 'thin_isotropy'))
        out.update(self._pack_stats(neg_det, 'thin_neg_det'))
        out['thin_sliced'] = float(is_sliced)
        out['thin_n_matrices'] = float(S0.shape[0])
        return out

    @torch.no_grad()
    def diagnostics_bundle(self, result: dict) -> dict:
        """Complete diagnostic bundle. All floats, no gradients.

        Merges svd_health, geometry_retention, cm_vol2_stats (only when CM
        is enabled and cm_vol2 is present) and thin_svd_health.
        """
        out = {}
        out.update(self.svd_health(result))
        out.update(self.geometry_retention(result))
        if self.cm_enabled and result.get('cm_vol2') is not None:
            out.update(self.cm_vol2_stats(result['cm_vol2']))
        out.update(self.thin_svd_health(result))
        return out

    def summary(self):
        """Print shapes, param count, DOF ratio, CM config."""
        n_params = sum(p.numel() for p in self.parameters())
        sphere_dof = self.V * (self.D - 1)
        ratio = sphere_dof / self.token_dim
        pw = self.patchwork
        pw_params = sum(p.numel() for p in pw.parameters())

        if self.slice_d > 0:
            backend_label = 'geolip Triton' if _HAS_GEOLIP_LINALG else 'fallback'
            # Report the configured slice width rather than assuming 2.
            svd_path = (f"sliced D={self.slice_d} × {self.n_slices} → D={self.D} "
                        f"({backend_label})")
        elif _HAS_GEOLIP_LINALG:
            if self.D == 2:
                svd_path = "geolip.linalg → Triton kernel (D=2)"
            elif self.D <= 12:
                svd_path = f"geolip.linalg → FL eigh (D={self.D})"
            else:
                svd_path = f"geolip.linalg → torch.linalg.eigh (D={self.D})"
        else:
            svd_path = f"fallback Gram-eigh fp64 (D={self.D})"

        print("SpectralCell:")
        print(f" token_dim={self.token_dim}, V={self.V}, D={self.D}")
        print(f" mat_dim={self.mat_dim} ({self.V}×{self.D})")
        print(f" sphere DOF={sphere_dof} (V rows × {self.D-1} free per row)")
        print(f" SVD: {svd_path}, degen_threshold={self.degen_threshold}")
        if self.cm_enabled and self.cm is not None:
            print(f" CM: enabled, k={self._cm_k} ({self._cm_k+1} vertices, {self._cm_npairs} pairs), min={self.cm_min}")
        else:
            print(" CM: disabled")
        print(f" patchwork: {self._n_pairs} pairs → {pw.n_comp} comps × {pw.comp_out} = {pw.output_dim} features ({pw_params:,} params)")
        print(f" out_in: {self.mat_dim} (M_hat) + {pw.output_dim} (patchwork) + {self.V} (mag) = {self.mat_dim + pw.output_dim + self.V}")
        print(f" hidden={self.hidden}, depth={len(self.enc_blocks)}")
        print(f" cross_attn={len(self.cross_attn)} layers")
        print(f" params: {n_params:,}")
        print(f" DOF ratio: {ratio:.2f}× "
              f"({'expand' if ratio > 1 else 'compress' if ratio < 1 else 'identity'})")
|
|
|
|
| |
|
|
def spectral_cell_tiny(token_dim: int) -> SpectralCell:
    """Smallest research preset: V=8, D=4, hidden=64, depth=1, 1 cross-attn."""
    preset = dict(V=8, D=4, hidden=64, depth=1, n_cross=1)
    return SpectralCell(token_dim, **preset)
|
|
def spectral_cell_small(token_dim: int) -> SpectralCell:
    """Small research preset: V=16, D=4, hidden=128, depth=2, 1 cross-attn."""
    preset = dict(V=16, D=4, hidden=128, depth=2, n_cross=1)
    return SpectralCell(token_dim, **preset)
|
|
def spectral_cell_base(token_dim: int) -> SpectralCell:
    """Base research preset: V=16, D=8, hidden=256, depth=2, 2 cross-attn (4 heads)."""
    preset = dict(V=16, D=8, hidden=256, depth=2, n_cross=2, n_heads=4)
    return SpectralCell(token_dim, **preset)
|
|
def spectral_cell_diamond(token_dim: int) -> SpectralCell:
    """Best sweep config: V=16, D=16, hidden=256, depth=2, 1 cross-attn (4 heads)."""
    preset = dict(V=16, D=16, hidden=256, depth=2, n_cross=1, n_heads=4)
    return SpectralCell(token_dim, **preset)
|
|
|
|
| |
|
|
def spectral_cell_degenerate(token_dim: int, V: int = 16, hidden: int = 256) -> SpectralCell:
    """D=16 cell whose SVD is split into eight independent D=2 slices.

    Uses slice_d=2, so the 16 columns are decomposed two at a time and
    the results are reassembled into a D=16 representation. CM validation
    is disabled for this preset. One residual block, one cross-attention
    layer with 4 heads.
    """
    preset = dict(
        V=V,
        D=16,
        hidden=hidden,
        depth=1,
        n_cross=1,
        n_heads=4,
        max_alpha=0.2,
        cm_enabled=False,
        degen_threshold=1e-6,
        slice_d=2,
    )
    return SpectralCell(token_dim, **preset)
|
|
def spectral_cell_primary(token_dim: int, V: int = 16, hidden: int = 256) -> SpectralCell:
    """D=16 full path with CM validation enabled.

    Unsliced SVD, two residual blocks, two cross-attention layers with
    4 heads, pentachoron CM settings (cm_points=5, cm_samples=200,
    cm_min=1e-16).
    """
    preset = dict(
        V=V,
        D=16,
        hidden=hidden,
        depth=2,
        n_cross=2,
        n_heads=4,
        max_alpha=0.2,
        cm_enabled=True,
        cm_points=5,
        cm_samples=200,
        cm_min=1e-16,
        degen_threshold=1e-6,
    )
    return SpectralCell(token_dim, **preset)
|
|
def spectral_cell_conduit(token_dim: int, V: int = 16, hidden: int = 256) -> SpectralCell:
    """D=16 conduit path: same spectral processing as the primary preset
    (unsliced SVD, depth=2, two 4-head cross-attention layers) but with
    CM validation disabled inside the cell.
    """
    preset = dict(
        V=V,
        D=16,
        hidden=hidden,
        depth=2,
        n_cross=2,
        n_heads=4,
        max_alpha=0.2,
        cm_enabled=False,
        degen_threshold=1e-6,
    )
    return SpectralCell(token_dim, **preset)
|
|
|
|
| |
|
|
# Smoke test: build each factory preset, run one forward pass, print the
# shapes and basic geometry metrics, then check that every parameter
# receives a non-zero gradient.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    configs = [
        ('tiny', spectral_cell_tiny),
        ('small', spectral_cell_small),
        ('diamond', spectral_cell_diamond),
        ('degenerate (D=2)', spectral_cell_degenerate),
        ('primary (D=16)', spectral_cell_primary),
        ('conduit (D=16 no CM)', spectral_cell_conduit),
    ]

    for name, factory in configs:
        print(f"\n{'='*60}")
        print(f" {name}")
        print(f"{'='*60}")
        cell = factory(token_dim=192).to(device)
        cell.summary()

        tokens = torch.randn(2, 16, 192, device=device)
        result = cell.format(tokens)

        print(f"\n Input: {tokens.shape}")
        print(f" Output: {result['output'].shape}")
        print(f" M: {result['M'].shape}")
        print(f" S: {result['S'].shape}")
        if result['cm_d2'] is not None:
            print(f" cm_d2: {result['cm_d2'].shape}")
            print(f" cm_vol2: {result['cm_vol2'].shape}")
        else:
            print(" cm_d2: None (CM disabled)")
            print(" cm_vol2: None (CM disabled)")
        print(f" trace: {cell.trace_check(result['M']):.4f} (expect {cell.V})")
        print(f" erank: {cell.effective_rank(result['S_orig'].reshape(-1, cell.D)).mean():.2f}")
        print(f" shift: {cell.spectral_shift(result['S_orig'], result['S']):.6f}")

        # 'cv' is read with .get() because cm_vol2_stats may omit it when
        # too few volumes are valid.
        cm_stats = cell.cm_vol2_stats(result['cm_vol2'])
        print(f" cm_vol: mean={cm_stats['mean']:.6f} cv={cm_stats.get('cv', 0):.4f} "
              f"valid={cm_stats['frac_valid']:.1%}")

        # Gradient flow: every trainable parameter must receive a
        # non-zero gradient from a scalar loss on the output.
        loss = result['output'].sum()
        loss.backward()
        grad_ok = all(p.grad is not None and p.grad.abs().sum() > 0
                      for p in cell.parameters() if p.requires_grad)
        # The pass and fail marks must be distinct glyphs.
        print(f" grads: {'✓' if grad_ok else '✗'}")