""" SpectralCell ============ Drop-in layer: (B, N, token_dim) → (B, N, token_dim). Pipeline: tokens → Linear → residual MLP → Linear(hidden, V*D) → reshape(V, D) → capture row magnitudes (encoder confidence) → F.normalize(dim=-1) → full pairwise d² between all V rows on S^{D-1} → C(V,2) pairs → patchwork: round-robin compartment MLPs on d² → structured features → [optional] CM validation → simplex volume (cm_vol2) for soft hand → SVD: analytical closed-form (D=2) or Gram-eigh fp64 (D>2) → cross-attention scales S per mode across all N tokens → recompose M_hat = U · diag(S_modified) · Vt → cat(M_hat, patchwork_features, row_magnitudes) → Linear → residual MLP → Linear(hidden, token_dim) → output SVD dispatch (via geolip.linalg.svd auto-dispatch): D=2: Fused Triton kernel ~0.02ms (fp32, when available) D=3: Fused Triton kernel ~0.02ms (fp32, when available) D≤12: Gram + FL eigh (compilable, 70/72 purity, zero graph breaks) D>12: Gram + torch.linalg.eigh (cuSOLVER) CPU: torch.linalg.svd fallback Falls back to hand-rolled Gram-eigh if geolip.linalg not installed. Cross-attention modifies S multiplicatively: S_out = S * (1 + α * tanh(attention_output)) α per mode, bounded [0, max_alpha], initialized ~0.024. M_hat ≠ M after this step. Sphere normalization enforces ||row||=1 for all V rows. This constrains trace(M^T M) = V (fixed total spectral energy). The SVD decomposes how that fixed energy distributes across D axes. CM validation (configurable): cm_enabled: toggle per cell (disable for speed, enable every Nth cell) cm_points: vertices per simplex (5 = pentachoron) cm_min: minimum volume threshold for valid measurements CM can be disabled for degenerate (D=2) cells where pentachoron geometry is not well-defined on S^1. Factory configurations: spectral_cell_degenerate: D=16 via slice_d=2, no CM, 8× Triton D=2 kernels spectral_cell_primary: D=16, CM enabled, full geometric instrument spectral_cell_conduit: D=16, CM disabled (measured externally) spectral_cell_diamond: D=16, CM enabled, best sweep config spectral_cell_tiny/small/base: research configs Author: AbstractPhil + Claude Opus License: Apache 2.0 """ import math import torch import torch.nn as nn import torch.nn.functional as F from itertools import combinations import warnings # ── Cayley-Menger ─────────────────────────────────────────────── class CMValidator(nn.Module): """Batch-friendly Cayley-Menger determinant. Computes pairwise squared distances and simplex volume for (k+1)-point subsets in arbitrary embedding dimension. For k=4: 5 vertices → 10 pairwise d² + 1 vol². """ def __init__(self, k): super().__init__() self._k = k self._nv = k + 1 pairs = list(combinations(range(self._nv), 2)) self._npairs = len(pairs) self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long)) self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long)) sign = (-1.0) ** (k + 1) fact = math.factorial(k) self._prefactor = sign / ((2.0 ** k) * (fact ** 2)) def forward(self, verts): """verts: (..., nv, edim) → d2_pairs: (..., npairs), vol2: (...)""" gram = torch.einsum('...ve,...we->...vw', verts, verts) norms = torch.diagonal(gram, dim1=-2, dim2=-1) d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram d2_mat = F.relu(d2_mat) d2_pairs = d2_mat[..., self._pi, self._pj] shape = d2_mat.shape[:-2] Vn = d2_mat.shape[-1] cm = torch.zeros(*shape, Vn + 1, Vn + 1, device=d2_mat.device, dtype=d2_mat.dtype) cm[..., 0, 1:] = 1.0 cm[..., 1:, 0] = 1.0 cm[..., 1:, 1:] = d2_mat vol2 = self._prefactor * torch.linalg.det(cm.float()) vol2 = vol2.to(d2_pairs.dtype) return d2_pairs, vol2 def cayley_menger_vol2(points: torch.Tensor) -> torch.Tensor: """Squared simplex volume via CM determinant in fp64. points: (B, N, D) → vol2: (B,) """ B, N, D = points.shape pts = points.double() gram = torch.bmm(pts, pts.transpose(1, 2)) norms = torch.diagonal(gram, dim1=1, dim2=2) d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram) cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=torch.float64) cm[:, 0, 1:] = 1.0 cm[:, 1:, 0] = 1.0 cm[:, 1:, 1:] = d2 k = N - 1 sign = (-1.0) ** (k + 1) fact = math.factorial(k) return sign * torch.linalg.det(cm) / ((2 ** k) * (fact ** 2)) def cv_of(emb: torch.Tensor, n_samples: int = 200) -> float: """Coefficient of variation of pentachoron volumes. emb: (V, D) — rows of a sphere-normalized matrix. Samples random 5-point subsets, computes CM vol² for each, returns std(vol) / mean(vol). CV ≈ 0.20-0.23 is the empirically observed attractor band. Returns 0.0 if insufficient valid volumes. """ if emb.dim() != 2 or emb.shape[0] < 5: return 0.0 N, D = emb.shape pool = min(N, 512) indices = torch.stack([ torch.randperm(pool, device=emb.device)[:5] for _ in range(n_samples) ]) vol2 = cayley_menger_vol2(emb[:pool][indices]) valid = vol2 > 1e-20 if valid.sum() < 10: return 0.0 vols = vol2[valid].sqrt() return (vols.std() / (vols.mean() + 1e-8)).item() # ── SVD via geolip.linalg auto-dispatch ────────────────────────── # # Dispatch order (from geolip.linalg.svd.batched_svd): # D=2, Triton available: Fused Triton kernel ~0.02ms (fp32) # D=3, Triton available: Fused Triton kernel ~0.02ms (fp32) # D<=12, CUDA: Gram + FL eigh compilable, 70/72 purity # D>12 or CPU: Gram + torch.linalg.eigh # Fallback: torch.linalg.svd # # geolip.linalg is a drop-in superset of torch.linalg. # Import as LA and use LA.svd() for auto-dispatch. try: from geolip_core.linalg import svd as batched_svd from geolip_core.linalg import backend as linalg_backend _HAS_GEOLIP_LINALG = True except ImportError: try: from geolip.linalg import svd as batched_svd from geolip.linalg import backend as linalg_backend _HAS_GEOLIP_LINALG = True except ImportError: _HAS_GEOLIP_LINALG = False linalg_backend = None def batched_svd(A, method='auto', compute_dtype='fp64', **kw): """Fallback SVD when geolip.linalg is not installed.""" B, V, D = A.shape orig = A.dtype with torch.amp.autocast('cuda', enabled=False): Ad = A.double() G = torch.bmm(Ad.transpose(1, 2), Ad) G.diagonal(dim1=-2, dim2=-1).add_(1e-12) with warnings.catch_warnings(): warnings.filterwarnings('ignore', message='.*not converge.*') eigenvalues, Vecs = torch.linalg.eigh(G) eigenvalues = eigenvalues.flip(-1) Vecs = Vecs.flip(-1) S = torch.sqrt(eigenvalues.clamp(min=1e-24)) U = torch.bmm(Ad, Vecs) / S.unsqueeze(1).clamp(min=1e-16) Vh = Vecs.transpose(-2, -1).contiguous() return U.to(orig), S.to(orig), Vh.to(orig) # ── Spectral Patchwork ────────────────────────────────────────── class SpectralPatchwork(nn.Module): """Compartmentalized round-robin interpretation of pairwise distances. Takes C(V,2) pairwise squared distances between M rows on S^{D-1} and processes them through K compartment MLPs via stride slicing. comp[k] sees distances[k::K] — every K-th distance starting at k. Each compartment interprets its slice independently. Outputs concatenate into structured geometric features. This is the patchwork from geolip-core adapted for the SpectralCell: triangulation distances → compartment MLPs → structured features. """ def __init__(self, n_pairs, n_comp=8, comp_hidden=32, comp_out=8): super().__init__() self.n_pairs = n_pairs self.n_comp = n_comp self.comp_out = comp_out # Each compartment sees ceil(n_pairs / n_comp) distances self._pairs_per_comp = (n_pairs + n_comp - 1) // n_comp self.compartments = nn.ModuleList([ nn.Sequential( nn.Linear(self._pairs_per_comp, comp_hidden), nn.GELU(), nn.Linear(comp_hidden, comp_out), ) for _ in range(n_comp) ]) @property def output_dim(self): return self.n_comp * self.comp_out def forward(self, d2_pairs): """d2_pairs: (B, n_pairs) → (B, n_comp * comp_out)""" outs = [] for k, comp in enumerate(self.compartments): chunk = d2_pairs[:, k::self.n_comp] # Pad if stride slice is shorter than expected if chunk.shape[-1] < self._pairs_per_comp: pad = torch.zeros( chunk.shape[0], self._pairs_per_comp - chunk.shape[-1], device=chunk.device, dtype=chunk.dtype) chunk = torch.cat([chunk, pad], dim=-1) outs.append(comp(chunk)) return torch.cat(outs, dim=-1) class SpectralCrossAttention(nn.Module): """Multi-head attention on singular values across N tokens. Input S: (B, N, D) — one D-dim spectral profile per token. Attends across N positions (each token sees all others' spectra). Output: S * (1 + α * tanh(out_proj(attended))) α is per-mode, bounded [0, max_alpha] via sigmoid on learnable logits. Initialized at sigmoid(-2.0) * 0.2 ≈ 0.024 per mode. """ def __init__(self, D, n_heads=2, max_alpha=0.2, alpha_init=-2.0): super().__init__() self.n_heads = n_heads self.head_dim = D // n_heads self.max_alpha = max_alpha assert D % n_heads == 0 self.qkv = nn.Linear(D, 3 * D) self.out_proj = nn.Linear(D, D) self.norm = nn.LayerNorm(D) self.scale = self.head_dim ** -0.5 self.alpha_logits = nn.Parameter(torch.full((D,), alpha_init)) @property def alpha(self): return self.max_alpha * torch.sigmoid(self.alpha_logits) def forward(self, S): B, N, D = S.shape Sn = self.norm(S) qkv = self.qkv(Sn).reshape(B, N, 3, self.n_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) out = (attn @ v).transpose(1, 2).reshape(B, N, D) gate = torch.tanh(self.out_proj(out)) alpha = self.alpha.unsqueeze(0).unsqueeze(0) return S * (1.0 + alpha * gate) # ── SpectralCell ──────────────────────────────────────────────── class SpectralCell(nn.Module): """Processes N tokens through sphere-normalized SVD with spectral coordination and Cayley-Menger geometric validation. Shapes through the pipeline (for default V=16, D=4, hidden=128, token_dim=64): tokens: (B, N, 64) enc_in: Linear(64, 128) → (B*N, 128) enc_blocks: 2× residual MLP → (B*N, 128) enc_out: Linear(128, 64) → (B*N, 64) → reshape (B*N, 16, 4) normalize: F.normalize(dim=-1) → each row has norm 1 SVD: Gram-eigh in fp64 → U(B*N,16,4), S(B*N,4), Vt(B*N,4,4) cross_attn: S reshaped (B,N,4) → attention across N → S_coord (B,N,4) recompose: U · diag(S_coord) · Vt → M_hat (B*N, 16, 4) → flatten (B*N, 64) out_in: Linear(64, 128) → (B*N, 128) out_blocks: 2× residual MLP → (B*N, 128) out_proj: Linear(128, 64) → (B, N, 64) CM validation: M rows are V unit vectors on S^{D-1}. CMValidator(k=4) samples pentachora from the rows. vol² measures simplex volume. CV measures uniformity. cv_of() returns the coefficient of variation over random subsets. Args: token_dim: input and output dimension per token V: matrix rows (each becomes a unit vector on S^{D-1}) D: matrix columns (spectral modes, eigenvalue count) hidden: residual MLP width depth: residual blocks in input and output projections n_cross: SpectralCrossAttention layers applied to S n_heads: attention heads in cross-attention (must divide D) max_alpha: upper bound on per-mode multiplicative scaling cm_enabled: whether to compute CM validation (disable for speed) cm_points: vertices per simplex (5 = pentachoron) cm_samples: samples for cv_of monitoring cm_min: minimum volume threshold for valid CM degen_threshold: eigenvalue ratio below which SVD is degenerate slice_d: decompose D-dim SVD as D//slice_d independent slice_d-dim SVDs slice_d=2 with D=16 → 8 Triton D=2 kernels for D=16 capacity 0 = no slicing (full SVD) """ def __init__( self, token_dim: int, V: int = 16, D: int = 4, hidden: int = 128, depth: int = 2, n_cross: int = 1, n_heads: int = 2, max_alpha: float = 0.2, cm_enabled: bool = True, cm_points: int = 5, cm_samples: int = 200, cm_min: float = 1e-16, degen_threshold: float = 1e-6, slice_d: int = 0, ): super().__init__() self.token_dim = token_dim self.V = V self.D = D self.mat_dim = V * D self.hidden = hidden self.cm_enabled = cm_enabled self.cm_points = cm_points self.cm_samples = cm_samples self.cm_min = cm_min self.degen_threshold = degen_threshold # Sliced SVD: decompose D-dim as D//slice_d independent slice_d-dim SVDs # slice_d=2, D=16 → 8 independent V×2 SVDs via Triton kernel # Full D capacity, D=2 speed. Degeneracy is expected and handled. if slice_d > 0: assert D % slice_d == 0, f"D={D} must be divisible by slice_d={slice_d}" self.slice_d = slice_d self.n_slices = D // slice_d else: self.slice_d = 0 self.n_slices = 1 # CM validator: k = cm_points - 1 # Only built if cm_enabled and D >= cm_points - 1 self._cm_k = min(cm_points - 1, D - 1) if D >= 2 else 1 if cm_enabled and D >= cm_points - 1: self.cm = CMValidator(self._cm_k) self._cm_npairs = self.cm._npairs else: self.cm = None self._cm_npairs = 0 if cm_enabled and D < cm_points - 1: self.cm_enabled = False # can't do pentachoron on S^1 # Input projection: token_dim → hidden → mat_dim self.enc_in = nn.Linear(token_dim, hidden) self.enc_blocks = nn.ModuleList([ nn.Sequential( nn.LayerNorm(hidden), nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, hidden), ) for _ in range(depth) ]) self.enc_out = nn.Linear(hidden, self.mat_dim) nn.init.orthogonal_(self.enc_out.weight) # Cross-attention on singular values across tokens self.cross_attn = nn.ModuleList([ SpectralCrossAttention(D, n_heads=n_heads, max_alpha=max_alpha) for _ in range(n_cross) ]) # Patchwork: compartmentalized interpretation of ALL pairwise distances # Full V×V pairwise d² on S^{D-1} → C(V,2) pairs → round-robin MLPs self._n_pairs = V * (V - 1) // 2 _triu = torch.triu_indices(V, V, offset=1) self.register_buffer('_triu_i', _triu[0]) self.register_buffer('_triu_j', _triu[1]) n_comp = min(8, self._n_pairs) # compartments comp_hidden = max(16, self._n_pairs // n_comp * 2) self.patchwork = SpectralPatchwork( n_pairs=self._n_pairs, n_comp=n_comp, comp_hidden=comp_hidden, comp_out=8, ) # Output projection: M_hat + patchwork features + magnitudes → hidden → token_dim pw_out = self.patchwork.output_dim self.out_in = nn.Linear(self.mat_dim + pw_out + self.V, hidden) self.out_blocks = nn.ModuleList([ nn.Sequential( nn.LayerNorm(hidden), nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, hidden), ) for _ in range(depth) ]) self.out_proj = nn.Linear(hidden, token_dim) def format(self, tokens: torch.Tensor) -> dict: """Run full pipeline. Returns output tokens, SVD components, and CM metrics. Args: tokens: (B, N, token_dim) Returns: dict: output: (B, N, token_dim) — processed tokens M: (B, N, V, D) — sphere-normalized matrix (rows on S^{D-1}) U: (B, N, V, D) — left singular vectors from SVD S_orig: (B, N, D) — singular values before cross-attention S: (B, N, D) — singular values after cross-attention Vt: (B, N, D, D) — right singular vectors from SVD M_hat: (B, N, V, D) — U · diag(S_modified) · Vt (≠ M) cm_d2: (B*N, cm_npairs) — CM pairwise d² (5-point sample, for soft hand) cm_vol2: (B*N,) — squared simplex volume from CM row_mag: (B, N, V) — pre-normalization row magnitudes d2_pairs: (B*N, C(V,2)) — full pairwise d² between all M rows pw_features: (B*N, pw_out) — patchwork structured geometric features """ B, N, _ = tokens.shape # Input projection → sphere-normalized V×D matrix flat = tokens.reshape(B * N, -1) h = F.gelu(self.enc_in(flat)) for block in self.enc_blocks: h = h + block(h) M = self.enc_out(h).reshape(B * N, self.V, self.D) row_mag = M.norm(dim=-1) # (B*N, V) — encoder confidence per row M = F.normalize(M, dim=-1) # CM validation (conditional — only when enabled and D supports it) if self.cm_enabled and self.cm is not None: nv = self._cm_k + 1 cm_idx = torch.linspace(0, self.V - 1, nv).long().to(M.device) cm_verts = M[:, cm_idx, :] cm_d2, cm_vol2 = self.cm(cm_verts) else: cm_d2 = None cm_vol2 = None # Full pairwise d² between ALL V rows on S^{D-1} # M rows are unit vectors, so d²(i,j) = 2 - 2·cos(i,j) gram = torch.bmm(M, M.transpose(1, 2)) # (B*N, V, V) d2_full = 2.0 - 2.0 * gram # squared Euclidean d2_pairs = d2_full[:, self._triu_i, self._triu_j] # (B*N, C(V,2)) # Patchwork: compartmentalized interpretation of geometric arrangement pw_features = self.patchwork(d2_pairs) # (B*N, n_comp * comp_out) # SVD decomposition — sliced or full if self.slice_d > 0: # Sliced SVD: D=16 capacity via D=2 Triton kernels # Reshape (BN, V, D) → (BN * n_slices, V, slice_d) BN = B * N M_slices = M.reshape(BN, self.V, self.n_slices, self.slice_d) M_slices = M_slices.permute(0, 2, 1, 3).reshape(BN * self.n_slices, self.V, self.slice_d) # Each slice SVDs via Triton D=2 (fp32 — Triton kernels are fp32-only) # compute_dtype='fp32' ensures Triton dispatch, not eigh fallback U_s, S_s, Vt_s = batched_svd(M_slices, compute_dtype='fp32') # Accumulate S across slices: (BN*ns, sd) → (BN, ns*sd) = (BN, D) S = S_s.reshape(BN, self.n_slices * self.slice_d) # Store U, Vt per-slice for recompose U_s = U_s.reshape(BN, self.n_slices, self.V, self.slice_d) Vt_s = Vt_s.reshape(BN, self.n_slices, self.slice_d, self.slice_d) # Reshape for cross-attention: S is already (BN, D) U = U_s.permute(0, 2, 1, 3).reshape(B, N, self.V, self.D) S = S.reshape(B, N, self.D) M = M.reshape(B, N, self.V, self.D) # Cross-attention on full D-dim S profile across tokens S_orig = S.clone() for layer in self.cross_attn: S = layer(S) # Recompose per-slice with modified S S_sliced = S.reshape(BN, self.n_slices, self.slice_d) M_hat_slices = [] for i in range(self.n_slices): u_i = U_s[:, i] # (BN, V, sd) s_i = S_sliced[:, i] # (BN, sd) vt_i = Vt_s[:, i] # (BN, sd, sd) mh_i = torch.bmm(u_i * s_i.unsqueeze(1), vt_i) # (BN, V, sd) M_hat_slices.append(mh_i) M_hat = torch.cat(M_hat_slices, dim=-1) # (BN, V, D) # Vt for return dict — block diagonal from slices Vt = torch.zeros(BN, self.D, self.D, device=M_hat.device, dtype=M_hat.dtype) for i in range(self.n_slices): s, e = i * self.slice_d, (i + 1) * self.slice_d Vt[:, s:e, s:e] = Vt_s[:, i] Vt = Vt.reshape(B, N, self.D, self.D) else: # Full SVD — auto-dispatches via geolip.linalg # D≤3: fp32 to ensure Triton kernel engages (fp32-only gate) # D=4-12: fp64 for FL eigh precision # D>12: fp64 for torch.linalg.eigh dtype_arg = 'fp32' if self.D <= 3 else 'fp64' U, S, Vt = batched_svd(M, compute_dtype=dtype_arg) # Reshape for cross-attention over N tokens U = U.reshape(B, N, self.V, self.D) S = S.reshape(B, N, self.D) Vt = Vt.reshape(B, N, self.D, self.D) M = M.reshape(B, N, self.V, self.D) # Cross-attention multiplicatively scales S across tokens S_orig = S.clone() for layer in self.cross_attn: S = layer(S) # Recompose with modified S → M_hat ≠ M U_flat = U.reshape(B * N, self.V, self.D) S_flat = S.reshape(B * N, self.D) Vt_flat = Vt.reshape(B * N, self.D, self.D) M_hat = torch.bmm(U_flat * S_flat.unsqueeze(1), Vt_flat) # Output projection: M_hat + patchwork + magnitudes → token_dim out_features = torch.cat([ M_hat.reshape(B * N, -1), # (B*N, V*D) — recomposed spectral structure pw_features, # (B*N, pw_out) — structured geometric arrangement row_mag, # (B*N, V) — encoder confidence ], dim=-1) h = F.gelu(self.out_in(out_features)) for block in self.out_blocks: h = h + block(h) output = self.out_proj(h).reshape(B, N, self.token_dim) return { 'output': output, 'M': M, 'U': U, 'S_orig': S_orig, 'S': S, 'Vt': Vt, 'M_hat': M_hat.reshape(B, N, self.V, self.D), 'cm_d2': cm_d2, 'cm_vol2': cm_vol2, 'row_mag': row_mag.reshape(B, N, self.V), 'd2_pairs': d2_pairs, 'pw_features': pw_features, } def forward(self, tokens: torch.Tensor) -> torch.Tensor: """(B, N, token_dim) → (B, N, token_dim). Drop-in compatible.""" return self.format(tokens)['output'] # ── CM Diagnostics ─────────────────────────────────────────── def cm_cv(self, M: torch.Tensor, n_samples: int = 200) -> float: """Compute CV of pentachoron volumes over random 5-point subsets. M: (B, N, V, D) — sphere-normalized matrices. Returns mean CV across all B*N matrices. """ flat = M.reshape(-1, self.V, self.D) # Sample a few matrices to keep cost reasonable n_mats = min(flat.shape[0], 64) cvs = [] for i in range(n_mats): c = cv_of(flat[i], n_samples=n_samples) cvs.append(c) return sum(cvs) / len(cvs) if cvs else 0.0 def cm_vol2_stats(self, cm_vol2) -> dict: """Statistics on CM vol² from format() output. cm_vol2: (B*N,) or None if CM disabled. """ if cm_vol2 is None: return {'mean': 0.0, 'std': 0.0, 'cv': 0.0, 'frac_valid': 0.0} valid = cm_vol2.abs() > self.cm_min if valid.sum() < 2: return {'mean': 0.0, 'std': 0.0, 'frac_valid': 0.0} vols = cm_vol2[valid].abs().sqrt() return { 'mean': vols.mean().item(), 'std': vols.std().item(), 'cv': (vols.std() / (vols.mean() + 1e-8)).item(), 'frac_valid': valid.float().mean().item(), } # ── SVD Diagnostics ────────────────────────────────────────── @staticmethod def effective_rank(S: torch.Tensor) -> torch.Tensor: """Shannon entropy of normalized singular values, exponentiated. erank = exp(-Σ p_i log p_i) where p_i = σ_i / Σσ. Returns 1.0 for rank-1, D for uniform spectrum. """ p = S / (S.sum(-1, keepdim=True) + 1e-8) p = p.clamp(min=1e-8) return (-(p * p.log()).sum(-1)).exp() @staticmethod def spectral_shift(S_orig, S_coord): """Mean |S_coord - S_orig| across all modes and tokens.""" return (S_coord - S_orig).abs().mean().item() @staticmethod def trace_check(M): """trace(M^T M) should equal V (sum of squared unit row norms).""" flat = M.reshape(-1, M.shape[-2], M.shape[-1]) G = torch.bmm(flat.transpose(1, 2), flat) return torch.diagonal(G, dim1=-2, dim2=-1).sum(-1).mean().item() # ── Full Diagnostic Suite ──────────────────────────────────── # All methods: @torch.no_grad, .detach().float(), returns {str: float} # NEVER returns tensors with grad_fn. Safe to call during training. @staticmethod def _pack_stats(x: torch.Tensor, prefix: str) -> dict: """Quantile summary of a 1-D tensor. All values are Python floats.""" x = x.detach().float().reshape(-1) if x.numel() == 0: return {f"{prefix}_mean": 0.0, f"{prefix}_std": 0.0, f"{prefix}_p50": 0.0, f"{prefix}_p95": 0.0, f"{prefix}_max": 0.0} q = torch.quantile(x, torch.tensor([0.50, 0.95], device=x.device, dtype=x.dtype)) return { f"{prefix}_mean": x.mean().item(), f"{prefix}_std": x.std(unbiased=False).item(), f"{prefix}_p50": q[0].item(), f"{prefix}_p95": q[1].item(), f"{prefix}_max": x.max().item(), } @torch.no_grad() def svd_health(self, result: dict) -> dict: """SVD solver correctness. Is the decomposition numerically trustworthy? All tensors detached — no gradient path.""" M = result['M'].detach().float().reshape(-1, self.V, self.D) U = result['U'].detach().float().reshape(-1, self.V, self.D) S0 = result['S_orig'].detach().float().reshape(-1, self.D) S1 = result['S'].detach().float().reshape(-1, self.D) Vt = result['Vt'].detach().float().reshape(-1, self.D, self.D) # Reconstruction: M ≈ U · diag(S0) · Vt M_recon = torch.bmm(U * S0.unsqueeze(1), Vt) recon_rel = (M - M_recon).norm(dim=(-2, -1)) / M.norm(dim=(-2, -1)).clamp(min=1e-12) # Orthogonality: U^T U ≈ I, Vt Vt^T ≈ I I = torch.eye(self.D, device=M.device).unsqueeze(0) u_orth = (torch.bmm(U.transpose(1, 2), U) - I).norm(dim=(-2, -1)) v_orth = (torch.bmm(Vt, Vt.transpose(1, 2)) - I).norm(dim=(-2, -1)) # Energy: ||M||_F² = sum(S²) for normalized rows energy_M = (M * M).sum(dim=(-2, -1)) energy_S0 = (S0 * S0).sum(dim=-1) energy_resid = (energy_M - energy_S0).abs() / energy_M.clamp(min=1e-12) energy_gain = (S1 * S1).sum(dim=-1) / energy_S0.clamp(min=1e-12) # Condition number and degeneracy cond = S0[:, 0] / S0[:, -1].clamp(min=1e-12) rel_gaps = (S0[:, :-1] - S0[:, 1:]).abs() / S0[:, :1].clamp(min=1e-12) degen_frac = (rel_gaps < self.degen_threshold).float().mean(dim=-1) dead_frac = (S0 < S0[:, :1] * self.degen_threshold).float().mean(dim=-1) out = {} out.update(self._pack_stats(recon_rel, 'svd_recon')) out.update(self._pack_stats(u_orth, 'svd_u_orth')) out.update(self._pack_stats(v_orth, 'svd_v_orth')) out.update(self._pack_stats(energy_resid, 'svd_energy_resid')) out.update(self._pack_stats(energy_gain, 'svd_energy_gain')) out.update(self._pack_stats(cond, 'svd_cond')) out.update(self._pack_stats(degen_frac, 'svd_degen_frac')) out.update(self._pack_stats(dead_frac, 'svd_dead_frac')) return out @torch.no_grad() def geometry_retention(self, result: dict) -> dict: """Did cross-attention preserve sphere geometry? Gram drift, row norm changes, spectral delta.""" M = result['M'].detach().float().reshape(-1, self.V, self.D) M_hat = result['M_hat'].detach().float().reshape(-1, self.V, self.D) S0 = result['S_orig'].detach().float().reshape(-1, self.D) S1 = result['S'].detach().float().reshape(-1, self.D) M_hat_n = F.normalize(M_hat, dim=-1) G0 = torch.bmm(M, M.transpose(1, 2)) G1 = torch.bmm(M_hat_n, M_hat_n.transpose(1, 2)) gram_drift = (G0 - G1).abs().mean(dim=(-2, -1)) row_norm = M_hat.norm(dim=-1) row_norm_mean = row_norm.mean(dim=-1) row_norm_cv = row_norm.std(dim=-1, unbiased=False) / row_norm_mean.clamp(min=1e-12) spectral_delta = (S1 - S0).abs().mean(dim=-1) spectral_ratio = S1.sum(dim=-1) / S0.sum(dim=-1).clamp(min=1e-12) out = {} out.update(self._pack_stats(gram_drift, 'geom_gram_drift')) out.update(self._pack_stats(row_norm_mean, 'geom_row_norm')) out.update(self._pack_stats(row_norm_cv, 'geom_row_norm_cv')) out.update(self._pack_stats(spectral_delta, 'geom_S_delta')) out.update(self._pack_stats(spectral_ratio, 'geom_S_ratio')) return out @torch.no_grad() def thin_svd_health(self, result: dict) -> dict: """D=2 degenerate cell diagnostics. Empty dict if D≠2. Isotropy, frame jumps, handedness flips.""" if self.D != 2 and self.slice_d != 2: return {} S0 = result['S_orig'].detach().float() Vt = result['Vt'].detach().float() # For sliced cells, S0 is (B, N, D=16) and Vt is block-diagonal (B, N, 16, 16) if self.slice_d == 2 and self.D > 2: S0 = S0.reshape(*S0.shape[:-1], self.n_slices, 2) # Extract diagonal 2×2 blocks from the block-diagonal Vt vt_blocks = torch.stack([ Vt[..., i*2:(i+1)*2, i*2:(i+1)*2] for i in range(self.n_slices) ], dim=-3) # (B, N, n_slices, 2, 2) S0 = S0.reshape(-1, 2) Vt = vt_blocks.reshape(-1, 2, 2) is_sliced = True else: S0 = S0.reshape(-1, 2) Vt = Vt.reshape(-1, 2, 2) is_sliced = False # Isotropy: S[0] ≈ S[1] means unstable basis isotropy = (S0[:, 0] - S0[:, 1]).abs() / S0.sum(dim=-1).clamp(min=1e-12) # Handedness flips in the 2×2 right frame det_v = torch.linalg.det(Vt) neg_det = (det_v < 0).float() out = {} out.update(self._pack_stats(isotropy, 'thin_isotropy')) out.update(self._pack_stats(neg_det, 'thin_neg_det')) out['thin_sliced'] = float(is_sliced) out['thin_n_matrices'] = float(S0.shape[0]) return out @torch.no_grad() def diagnostics_bundle(self, result: dict) -> dict: """Complete diagnostic bundle. All floats, no gradients.""" out = {} out.update(self.svd_health(result)) out.update(self.geometry_retention(result)) if self.cm_enabled and result.get('cm_vol2') is not None: out.update(self.cm_vol2_stats(result['cm_vol2'])) out.update(self.thin_svd_health(result)) return out def summary(self): """Print shapes, param count, DOF ratio, CM config.""" n_params = sum(p.numel() for p in self.parameters()) sphere_dof = self.V * (self.D - 1) ratio = sphere_dof / self.token_dim pw = self.patchwork pw_params = sum(p.numel() for p in pw.parameters()) if self.slice_d > 0: svd_path = f"sliced D=2 × {self.n_slices} → D={self.D} ({'geolip Triton' if _HAS_GEOLIP_LINALG else 'fallback'})" elif _HAS_GEOLIP_LINALG: if self.D == 2: svd_path = "geolip.linalg → Triton kernel (D=2)" elif self.D <= 12: svd_path = f"geolip.linalg → FL eigh (D={self.D})" else: svd_path = f"geolip.linalg → torch.linalg.eigh (D={self.D})" else: svd_path = f"fallback Gram-eigh fp64 (D={self.D})" print(f"SpectralCell:") print(f" token_dim={self.token_dim}, V={self.V}, D={self.D}") print(f" mat_dim={self.mat_dim} ({self.V}×{self.D})") print(f" sphere DOF={sphere_dof} (V rows × {self.D-1} free per row)") print(f" SVD: {svd_path}, degen_threshold={self.degen_threshold}") if self.cm_enabled and self.cm is not None: print(f" CM: enabled, k={self._cm_k} ({self._cm_k+1} vertices, {self._cm_npairs} pairs), min={self.cm_min}") else: print(f" CM: disabled") print(f" patchwork: {self._n_pairs} pairs → {pw.n_comp} comps × {pw.comp_out} = {pw.output_dim} features ({pw_params:,} params)") print(f" out_in: {self.mat_dim} (M_hat) + {pw.output_dim} (patchwork) + {self.V} (mag) = {self.mat_dim + pw.output_dim + self.V}") print(f" hidden={self.hidden}, depth={len(self.enc_blocks)}") print(f" cross_attn={len(self.cross_attn)} layers") print(f" params: {n_params:,}") print(f" DOF ratio: {ratio:.2f}× " f"({'expand' if ratio > 1 else 'compress' if ratio < 1 else 'identity'})") # ── Factory functions ──────────────────────────────────────────── def spectral_cell_tiny(token_dim: int) -> SpectralCell: """V=8, D=4, hidden=64, depth=1, 1 cross-attn.""" return SpectralCell(token_dim, V=8, D=4, hidden=64, depth=1, n_cross=1) def spectral_cell_small(token_dim: int) -> SpectralCell: """V=16, D=4, hidden=128, depth=2, 1 cross-attn.""" return SpectralCell(token_dim, V=16, D=4, hidden=128, depth=2, n_cross=1) def spectral_cell_base(token_dim: int) -> SpectralCell: """V=16, D=8, hidden=256, depth=2, 2 cross-attn.""" return SpectralCell(token_dim, V=16, D=8, hidden=256, depth=2, n_cross=2, n_heads=4) def spectral_cell_diamond(token_dim: int) -> SpectralCell: """V=16, D=16, hidden=256, depth=2, 1 cross-attn. Best sweep config.""" return SpectralCell(token_dim, V=16, D=16, hidden=256, depth=2, n_cross=1, n_heads=4) # ── Primary and degenerate configurations ──────────────────────── def spectral_cell_degenerate(token_dim: int, V: int = 16, hidden: int = 256) -> SpectralCell: """D=16 capacity via D=2 sliced SVD. 8 independent Triton kernels. Full D=16 spectral representation accumulated from D=2 substructures. CM disabled — geometry is degenerate by design. Near-degenerate eigenvalues are expected and handled by Triton. Use as fast geometric formatting between primary cells. """ return SpectralCell( token_dim, V=V, D=16, hidden=hidden, depth=1, n_cross=1, n_heads=4, max_alpha=0.2, cm_enabled=False, degen_threshold=1e-6, slice_d=2, ) def spectral_cell_primary(token_dim: int, V: int = 16, hidden: int = 256) -> SpectralCell: """D=16 full path. CM enabled, fp64 eigh SVD, patchwork, magnitude. The full geometric instrument. Use every 3rd cell in a stack. """ return SpectralCell( token_dim, V=V, D=16, hidden=hidden, depth=2, n_cross=2, n_heads=4, max_alpha=0.2, cm_enabled=True, cm_points=5, cm_samples=200, cm_min=1e-16, degen_threshold=1e-6, ) def spectral_cell_conduit(token_dim: int, V: int = 16, hidden: int = 256) -> SpectralCell: """D=16 conduit path. CM disabled (measured externally every 3 cells). Full spectral processing without per-cell CM overhead. """ return SpectralCell( token_dim, V=V, D=16, hidden=hidden, depth=2, n_cross=2, n_heads=4, max_alpha=0.2, cm_enabled=False, degen_threshold=1e-6, ) # ── Self-test ─────────────────────────────────────────────────── if __name__ == '__main__': device = 'cuda' if torch.cuda.is_available() else 'cpu' configs = [ ('tiny', spectral_cell_tiny), ('small', spectral_cell_small), ('diamond', spectral_cell_diamond), ('degenerate (D=2)', spectral_cell_degenerate), ('primary (D=16)', spectral_cell_primary), ('conduit (D=16 no CM)', spectral_cell_conduit), ] for name, factory in configs: print(f"\n{'='*60}") print(f" {name}") print(f"{'='*60}") cell = factory(token_dim=192).to(device) cell.summary() tokens = torch.randn(2, 16, 192, device=device) result = cell.format(tokens) print(f"\n Input: {tokens.shape}") print(f" Output: {result['output'].shape}") print(f" M: {result['M'].shape}") print(f" S: {result['S'].shape}") if result['cm_d2'] is not None: print(f" cm_d2: {result['cm_d2'].shape}") print(f" cm_vol2: {result['cm_vol2'].shape}") else: print(f" cm_d2: None (CM disabled)") print(f" cm_vol2: None (CM disabled)") print(f" trace: {cell.trace_check(result['M']):.4f} (expect {cell.V})") print(f" erank: {cell.effective_rank(result['S_orig'].reshape(-1, cell.D)).mean():.2f}") print(f" shift: {cell.spectral_shift(result['S_orig'], result['S']):.6f}") # CM stats cm_stats = cell.cm_vol2_stats(result['cm_vol2']) print(f" cm_vol: mean={cm_stats['mean']:.6f} cv={cm_stats.get('cv', 0):.4f} " f"valid={cm_stats['frac_valid']:.1%}") # Gradient check loss = result['output'].sum() loss.backward() grad_ok = all(p.grad is not None and p.grad.abs().sum() > 0 for p in cell.parameters() if p.requires_grad) print(f" grads: {'✓' if grad_ok else '✗'}")