| """ |
| SVD Transformer Prototype |
| ================================================================= |
| Standalone prototype matching user-provided API spec. Combines: |
| |
| - SpectralProbe lineage's three-head SVD readout (S, U, Vt → embed) |
| - Correct geolip imports: geolip_core registers 'geolip' alias, then |
| geolip.linalg as LA, then FLEigh from geolip_core.linalg.eigh |
| - NO row centering (verified bug: gram-based SVD goes degenerate) |
| - Configurable encoder (mlp/transformer/conv/film/ffn/rotary/lstm/gru) |
| - Configurable geometric activation (star=ReLU² default) |
| - Configurable attention layers between SVD passes |
| - Configurable depth (stacked SVD cells) |
| - Head selection among the three readouts (S, U, Vt) via `target` ({SVD, VD, SV, S, V}) |
| - Three output formats via `token_out` ({all, QKV, SUVt}) |
| - Solver dispatch: svd_solver={auto, torch, triton}, eigh_solver={auto, torch, fl} |
| |
| API parameter interpretations (clarify if wrong; a usage sketch follows this docstring): |
| svd=[S, V, D] → S = sequence/slot count, V/D = SVD matrix dims |
| target="SVD" → all three heads active (S, U, Vt) |
| target="VD" → U + Vt only (drop singular values) |
| target="SV" → S + U only (drop right basis) |
| target="S"/"V" → single head |
| token_out="all" → return (B, S, embed_dim) sequence |
| token_out="QKV" → return (Q, K, V) tuple after QKV projection |
| token_out="SUVt" → return (U, S_vals, Vt) of the LAST cell's SVD |
| depth=N → stack N independent SVD cells |
| |
| Lineage: AbstractPhil / SpectralProbe → CIFAR-10 (53.7% with 13.6k params) |
| """ |
|
|
| import math |
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
| try: |
| import geolip_core |
| import geolip |
| import geolip.linalg as LA |
| from geolip_core.linalg.eigh import FLEigh |
| _HAS_GEOLIP = True |
| print(f"β geolip {geolip.__version__} β using LA.svd + FLEigh") |
| LA.backend.status() |
| except ImportError as e: |
| _HAS_GEOLIP = False |
| LA = None |
| FLEigh = None |
| print(f"β geolip_core not installed ({e}) β torch.linalg fallback") |
|
|
|
|
| |
| |
| |
|
|
| class StarActivation(nn.Module): |
| """ReLUΒ² β squared positive activation. All-positive output.""" |
| def forward(self, x): |
| return F.relu(x).pow(2) |
|
|
|
|
| _GEO_ACTS = { |
| 'star': lambda: StarActivation(), |
| 'relu': lambda: nn.ReLU(), |
| 'gelu': lambda: nn.GELU(), |
| 'silu': lambda: nn.SiLU(), |
| 'swilu': lambda: nn.SiLU(),      # nonstandard name; kept as an alias for SiLU |
| 'tanh': lambda: nn.Tanh(), |
| 'sigmoid': lambda: nn.Sigmoid(), |
| 'leaky_relu': lambda: nn.LeakyReLU(0.01), |
| } |
|
|
| _REG_ACTS = { |
| 'gelu': lambda: nn.GELU(), |
| 'relu': lambda: nn.ReLU(), |
| 'silu': lambda: nn.SiLU(), |
| 'tanh': lambda: nn.Tanh(), |
| 'leaky_relu': lambda: nn.LeakyReLU(0.01), |
| } |
|
|
|
|
| def make_geo_activation(name: str) -> nn.Module: |
| name = (name or 'star').lower() |
| if name not in _GEO_ACTS: |
| raise ValueError(f"Unknown geo_activation: {name!r}; options: {list(_GEO_ACTS)}") |
| return _GEO_ACTS[name]() |
|
|
|
|
| def make_activation(name: str) -> nn.Module: |
| name = (name or 'gelu').lower() |
| if name not in _REG_ACTS: |
| raise ValueError(f"Unknown activation: {name!r}; options: {list(_REG_ACTS)}") |
| return _REG_ACTS[name]() |
|
|
|
|
| def _act_name_for_pytorch(name: str) -> str: |
| """nn.TransformerEncoderLayer accepts 'gelu'/'relu' strings; map our names.""" |
| name = (name or 'gelu').lower() |
| return name if name in ('gelu', 'relu') else 'gelu' |
|
|
|
|
| |
| |
| |
|
|
| def _fit_heads(d: int, target: int) -> int: |
| """Pick a head count that divides d evenly (target preferred).""" |
| for h in [target, target // 2, target // 4, 8, 4, 2, 1]: |
| if h > 0 and d % h == 0: |
| return h |
| return 1 |
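| # Example values (illustrative): _fit_heads(48, 64) → 16, _fit_heads(32, 64) → 32, |
| # _fit_heads(7, 64) → 1 (no candidate other than 1 divides 7 evenly). |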
|
|
|
|
| class MLPEncoder(nn.Module): |
| """encode='mlp' (default) β two-layer MLP per token.""" |
| def __init__(self, in_dim, out_dim, hidden_size, activation, **_): |
| super().__init__() |
| |
| |
| h = max(hidden_size, max(in_dim, out_dim) // 2, 8) |
| self.net = nn.Sequential( |
| nn.Linear(in_dim, h), |
| make_activation(activation), |
| nn.Linear(h, out_dim), |
| ) |
|
|
| def forward(self, x): |
| return self.net(x) |
|
|
|
|
| class FFNEncoder(nn.Module): |
| """encode='ffn' β transformer-style 4Γ expansion FFN.""" |
| def __init__(self, in_dim, out_dim, hidden_size, activation, **_): |
| super().__init__() |
| h = max(hidden_size, 4 * out_dim) |
| self.net = nn.Sequential( |
| nn.Linear(in_dim, h), |
| make_activation(activation), |
| nn.Linear(h, out_dim), |
| ) |
|
|
| def forward(self, x): |
| return self.net(x) |
|
|
|
|
| class FiLMEncoder(nn.Module): |
| """encode='film' β feature-wise affine modulation.""" |
| def __init__(self, in_dim, out_dim, hidden_size, activation, **_): |
| super().__init__() |
| self.skip = nn.Linear(in_dim, out_dim) |
| self.gamma = nn.Linear(in_dim, out_dim) |
| self.beta = nn.Linear(in_dim, out_dim) |
| self.act = make_activation(activation) |
|
|
| def forward(self, x): |
| skip = self.skip(x) |
| return self.act(skip * (1.0 + self.gamma(x)) + self.beta(x)) |
|
|
|
|
| class ConvEncoder(nn.Module): |
| """encode='conv' β 1D conv across the token sequence.""" |
| def __init__(self, in_dim, out_dim, hidden_size, activation, **_): |
| super().__init__() |
| self.proj = nn.Linear(in_dim, out_dim) |
| self.conv = nn.Conv1d(out_dim, out_dim, kernel_size=3, padding=1) |
| self.act = make_activation(activation) |
|
|
| def forward(self, x): |
| x = self.proj(x) |
| x = x.transpose(1, 2) |
| x = self.act(self.conv(x)) |
| return x.transpose(1, 2) |
|
|
|
|
| class TransformerEncoder(nn.Module): |
| """encode='transformer' β single transformer encoder layer pre-SVD.""" |
| def __init__(self, in_dim, out_dim, hidden_size, activation, n_heads=4, **_): |
| super().__init__() |
| self.proj = nn.Linear(in_dim, out_dim) |
| h = _fit_heads(out_dim, n_heads) |
| self.layer = nn.TransformerEncoderLayer( |
| d_model=out_dim, nhead=h, |
| dim_feedforward=max(hidden_size, 4 * out_dim), |
| activation=_act_name_for_pytorch(activation), |
| batch_first=True, norm_first=True, |
| ) |
|
|
| def forward(self, x): |
| return self.layer(self.proj(x)) |
|
|
|
|
| class LSTMEncoder(nn.Module): |
| """encode='lstm' β sequential LSTM.""" |
| def __init__(self, in_dim, out_dim, hidden_size, activation, **_): |
| super().__init__() |
| self.lstm = nn.LSTM(in_dim, out_dim, batch_first=True) |
| self.act = make_activation(activation) |
|
|
| def forward(self, x): |
| out, _ = self.lstm(x) |
| return self.act(out) |
|
|
|
|
| class GRUEncoder(nn.Module): |
| """encode='gru' β sequential GRU.""" |
| def __init__(self, in_dim, out_dim, hidden_size, activation, **_): |
| super().__init__() |
| self.gru = nn.GRU(in_dim, out_dim, batch_first=True) |
| self.act = make_activation(activation) |
|
|
| def forward(self, x): |
| out, _ = self.gru(x) |
| return self.act(out) |
|
|
|
|
| class RotaryEncoder(nn.Module): |
| """encode='rotary' β projection then rotary positional embedding.""" |
| def __init__(self, in_dim, out_dim, hidden_size, activation, **_): |
| super().__init__() |
| self.proj = nn.Linear(in_dim, out_dim) |
| self.dim = out_dim |
| self.act = make_activation(activation) |
|
|
| def forward(self, x): |
| x = self.proj(x) |
| B, S, D = x.shape |
| d_half = D // 2 |
| if d_half == 0: |
| return self.act(x) |
| positions = torch.arange(S, device=x.device, dtype=x.dtype).unsqueeze(0) |
| freqs = torch.exp(torch.arange(d_half, device=x.device, dtype=x.dtype) |
| * (-math.log(10000.0) / d_half)) |
| angles = positions.unsqueeze(-1) * freqs.unsqueeze(0) |
| cos, sin = angles.cos(), angles.sin() |
| x1 = x[..., :d_half] |
| x2 = x[..., d_half:2 * d_half] |
| rotated_1 = x1 * cos - x2 * sin |
| rotated_2 = x1 * sin + x2 * cos |
| if D % 2 == 1: |
| tail = x[..., 2 * d_half:] |
| x = torch.cat([rotated_1, rotated_2, tail], dim=-1) |
| else: |
| x = torch.cat([rotated_1, rotated_2], dim=-1) |
| return self.act(x) |
|
|
|
|
| _ENCODERS = { |
| 'mlp': MLPEncoder, |
| 'ffn': FFNEncoder, |
| 'film': FiLMEncoder, |
| 'conv': ConvEncoder, |
| 'transformer': TransformerEncoder, |
| 'lstm': LSTMEncoder, |
| 'gru': GRUEncoder, |
| 'rotary': RotaryEncoder, |
| } |
|
|
|
|
| def build_encoder(encode, in_dim, out_dim, hidden_size, activation): |
| enc = (encode or 'mlp').lower() |
| if enc not in _ENCODERS: |
| raise ValueError(f"Unknown encode={encode!r}; options: {list(_ENCODERS)}") |
| return _ENCODERS[enc](in_dim, out_dim, hidden_size, activation) |
|
|
|
|
| |
| |
| |
|
|
| def _svd_dispatch(M: torch.Tensor, |
| svd_solver: str = 'auto', |
| eigh_solver: str = 'auto' |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| """ |
| M: (BS, V, D), the batch of matrices to decompose. |
| Returns: U (BS, V, D), S_vals (BS, D), Vt (BS, D, D) |
| |
| Dispatch logic (ALL paths produce thin SVD with descending singular values): |
| |
| no geolip           → torch.linalg.svd in fp64 (rank-deficient-safe) |
| svd_solver='torch'  → torch.linalg.svd in fp64 |
| svd_solver='triton' → LA.svd(method='triton', compute_dtype='fp32') |
| eigh_solver='fl'    → custom gram + FLEigh (compiles up to D=12) |
| auto/auto           → LA.svd with compute_dtype selected by D: |
| D ≤ 3 → fp32 (Triton fused kernel, fast and precise) |
| D ≥ 4 → fp64 (FL eigh, compilable, ~1e-7 orthogonality) |
| |
| CRITICAL: without compute_dtype='fp64' for D ≥ 4, the Blackwell Triton SVD |
| kernels (4×4-6×6 fp32) engage by default and orthogonality drops from |
| ~1e-7 to ~1e-3. fp64 is required for training stability at D=4. |
| |
| NEVER row-center M before this: the gram path produces garbage U for |
| rank-deficient inputs. Verified bug across both this implementation and |
| geolip's gram_eigh path. The production SVDObserver in geolip_core also |
| avoids centering for this reason. |
| """ |
| |
| if not _HAS_GEOLIP: |
| with torch.amp.autocast('cuda', enabled=False): |
| U, Sv, Vt = torch.linalg.svd(M.double(), full_matrices=False) |
| return U.float(), Sv.float(), Vt.float() |
|
|
| |
| if svd_solver == 'torch': |
| with torch.amp.autocast('cuda', enabled=False): |
| U, Sv, Vt = torch.linalg.svd(M.double(), full_matrices=False) |
| return U.float(), Sv.float(), Vt.float() |
|
|
| |
| if eigh_solver == 'fl': |
| return _gram_fl_eigh_svd(M) |
|
|
| |
| |
| |
| |
| |
| |
| D = M.shape[-1] |
| dtype_arg = 'fp32' if D <= 3 else 'fp64' |
|
|
| |
| if svd_solver == 'triton': |
| try: |
| return LA.svd(M, method='triton', compute_dtype='fp32') |
| except Exception as exc: |
| print(f" β triton SVD failed ({exc}); falling back to LA.svd default") |
| return LA.svd(M, compute_dtype=dtype_arg) |
|
|
| |
| |
| |
| return LA.svd(M, compute_dtype=dtype_arg) |
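| # Illustrative call (shapes only; input values are arbitrary): |
| #   U, Sv, Vt = _svd_dispatch(torch.randn(32, 8, 4)) |
| #   # U: (32, 8, 4), Sv: (32, 4) descending, Vt: (32, 4, 4); |
| #   # torch.bmm(U * Sv.unsqueeze(1), Vt) reconstructs the input to ~float32 precision. |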
|
|
|
|
| def _gram_fl_eigh_svd(M: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| """ |
| Custom gram + FL eigh SVD. Uses geolip_core.linalg.eigh.FLEigh, the |
| Faddeev-LeVerrier polynomial + Laguerre roots + Newton-Schulz pipeline. |
| More accurate than torch.linalg.eigh on ill-conditioned grams. |
| |
| Compiles up to D=12 on CUDA. For larger D, use _svd_dispatch with default |
| auto routing (which will pick torch.linalg.svd). |
| """ |
| if FLEigh is None: |
| raise RuntimeError("FLEigh unavailable β geolip_core not installed") |
| orig_dtype = M.dtype |
| A = M.float() |
| G = torch.bmm(A.transpose(1, 2), A) |
| eigenvalues, V = FLEigh()(G) |
| |
| eigenvalues = eigenvalues.flip(-1) |
| V = V.flip(-1) |
| Sv = torch.sqrt(eigenvalues.clamp(min=1e-12)) |
| U = torch.bmm(A, V) / Sv.unsqueeze(1).clamp(min=1e-8) |
| Vh = V.transpose(-2, -1).contiguous() |
| return U.to(orig_dtype), Sv.to(orig_dtype), Vh.to(orig_dtype) |
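| # Why the gram route recovers the SVD (standard identity, not geolip-specific): |
| # for a thin SVD M = U diag(s) V^T, the gram is M^T M = V diag(s^2) V^T, so the |
| # eigenvectors of the gram give V, square roots of its eigenvalues give s, and |
| # U = M V / s. The flip(-1) calls above assume FLEigh returns eigenpairs in |
| # ascending order (as torch.linalg.eigh does) and reorder them to descending. |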
|
|
|
|
| |
| |
| |
|
|
| class TensorPatcher(nn.Module): |
| """(B, C, H, W) β (B, N, CΒ·phΒ·pw). Pure reshape, no learned params.""" |
| def __init__(self, input_shape, patch_size): |
| super().__init__() |
| C, H, W = input_shape |
| ph = pw = patch_size |
| assert H % ph == 0 and W % pw == 0 |
| self.C, self.H, self.W = C, H, W |
| self.ph, self.pw = ph, pw |
| self.n_patches = (H // ph) * (W // pw) |
| self.patch_dim = C * ph * pw |
|
|
| def forward(self, x): |
| B, C, H, W = x.shape |
| ph, pw = self.ph, self.pw |
| gh, gw = H // ph, W // pw |
| p = x.reshape(B, C, gh, ph, gw, pw) |
| p = p.permute(0, 2, 4, 1, 3, 5).contiguous() |
| return p.reshape(B, gh * gw, -1) |
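| # Illustrative: TensorPatcher((3, 32, 32), patch_size=4) maps (B, 3, 32, 32) |
| # → (B, 64, 48), i.e. 8x8 = 64 patches of dimension 3*4*4 = 48 each. |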
|
|
|
|
| |
| |
| |
|
|
| from itertools import combinations |
|
|
|
|
| class CMValidator(nn.Module): |
| """Batch-friendly Cayley-Menger determinant. |
| |
| Computes pairwise squared distances and simplex volume for (k+1)-point |
| subsets in arbitrary embedding dimension. Used as a degeneracy detector: |
| when sphere-normalized rows collapse toward coincidence, the simplex |
| formed by any (k+1)-point subset flattens and vol² → 0. |
| |
| For k=4: 5 vertices → 10 pairwise d² + 1 vol². Pentachoron config. |
| Adapted from SpectralCell (geolip-core). |
| """ |
| def __init__(self, k): |
| super().__init__() |
| self._k = k |
| self._nv = k + 1 |
| pairs = list(combinations(range(self._nv), 2)) |
| self._npairs = len(pairs) |
| self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long)) |
| self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long)) |
| sign = (-1.0) ** (k + 1) |
| fact = math.factorial(k) |
| self._prefactor = sign / ((2.0 ** k) * (fact ** 2)) |
|
|
| def forward(self, verts): |
| """verts: (..., nv, edim) β (d2_pairs: (..., npairs), vol2: (...))""" |
| gram = torch.einsum('...ve,...we->...vw', verts, verts) |
| norms = torch.diagonal(gram, dim1=-2, dim2=-1) |
| d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram |
| d2_mat = F.relu(d2_mat) |
| d2_pairs = d2_mat[..., self._pi, self._pj] |
| shape = d2_mat.shape[:-2] |
| Vn = d2_mat.shape[-1] |
| cm = torch.zeros(*shape, Vn + 1, Vn + 1, device=d2_mat.device, dtype=d2_mat.dtype) |
| cm[..., 0, 1:] = 1.0 |
| cm[..., 1:, 0] = 1.0 |
| cm[..., 1:, 1:] = d2_mat |
| vol2 = self._prefactor * torch.linalg.det(cm.float()) |
| vol2 = vol2.to(d2_pairs.dtype) |
| return d2_pairs, vol2 |
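| # Degeneracy sketch (illustrative, k=4 pentachoron): |
| #   cm = CMValidator(4) |
| #   _, v_generic = cm(torch.randn(3, 5, 4))   # 5 generic points in R^4: vol² > 0 |
| #   _, v_flat    = cm(torch.zeros(3, 5, 4))   # coincident points: vol² == 0 |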
|
|
|
|
| |
| |
| |
|
|
| class SVDCell(nn.Module): |
| """ |
| One cycle of the architecture: |
| |
| tokens (B, S, in_dim) |
| → encode (mlp/conv/transformer/...) [out: (B, S, V·D)] |
| → reshape [out: (B·S, V, D)] |
| → capture row_mag = ||M_row|| [out: (B·S, V)] ← NEW |
| → sphere-normalize rows: F.normalize [unit S^(D-1) rows] ← NEW |
| → (optional) CM validator on 5-point subset [vol² for degen] ← NEW |
| → SVD via geolip (on clean unit matrix) [out: U, S_vals, Vt] |
| → four-head readout: s + u + vt + mag [out: (B·S, embed_dim)] |
| → geo_activation [out: (B·S, embed_dim)] |
| → reshape [out: (B, S, embed_dim)] |
| → attention_layers × TransformerEncoderLayer |
| → LayerNorm |
| → tokens (B, S, embed_dim) |
| |
| Sphere normalization rationale: |
| Without it, M rows have unbounded magnitudes; when any row explodes, |
| the SVD singular values explode, attention softmax sees extreme values, |
| and the gradient destabilizes. SpectralCell's fix: normalize rows to |
| S^(D-1), preserving magnitude as a separate signal (row_mag) that |
| flows through mag_head. This keeps the high-end AND low-end magnitude |
| distribution intact while giving the SVD clean inputs to decompose. |
| |
| CM validation rationale: |
| Sphere-normalized rows can still collapse to near-coincidence, which |
| causes SVD to have near-degenerate singular values (numerical nightmare |
| in gradients). CM computes pentachoron volume from 5 sampled rows; vol² |
| near zero signals degeneracy. Cached on self._last_cm_vol2 for |
| optional loss regularization or diagnostics. |
| |
| The SVD components and CM signal are cached on self._last_svd and |
| self._last_cm_vol2 for token_out="SUVt" extraction and diagnostics. |
| """ |
| _TARGET_TO_MASK = { |
| 'SVD': (True, True, True), |
| 'VD': (False, True, True), |
| 'SV': (True, True, False), |
| 'S': (True, False, False), |
| 'V': (False, True, False), |
| } |
|
|
| def __init__(self, *, in_dim, S, V, D, embed_dim, hidden_size, |
| encode, activation, geo_activation, target, |
| attention_layers, heads, svd_solver, eigh_solver, |
| sphere_norm=True, mag_head=True, |
| cm_enabled=True, cm_points=5): |
| super().__init__() |
| self.S, self.V, self.D = S, V, D |
| self.embed_dim = embed_dim |
| self.target = (target or 'SVD').upper() |
| self.svd_solver = svd_solver |
| self.eigh_solver = eigh_solver |
| self.sphere_norm = sphere_norm |
| self.use_mag_head = mag_head and sphere_norm |
| if self.target not in self._TARGET_TO_MASK: |
| raise ValueError(f"Unknown target={target!r}; options: {list(self._TARGET_TO_MASK)}") |
|
|
| mat_dim = V * D |
|
|
| |
| self.encoder = build_encoder(encode, in_dim, mat_dim, hidden_size, activation) |
|
|
| |
| self.s_head = nn.Linear(D, embed_dim) |
| self.u_head = nn.Linear(V * D, embed_dim) |
| self.vt_head = nn.Linear(D * D, embed_dim) |
|
|
| |
| if self.use_mag_head: |
| self.mag_head = nn.Linear(V, embed_dim) |
| else: |
| self.mag_head = None |
|
|
| |
| self._cm_nv = cm_points |
| self._cm_k = cm_points - 1 |
| # CM needs at least cm_points distinct rows and embedding dim D ≥ k |
| if cm_enabled and sphere_norm and D >= self._cm_k and V >= self._cm_nv: |
| self.cm = CMValidator(self._cm_k) |
| self._cm_enabled = True |
| else: |
| self.cm = None |
| self._cm_enabled = False |
|
|
| self.geo_act = make_geo_activation(geo_activation) |
|
|
| |
| if attention_layers > 0: |
| n_h = _fit_heads(embed_dim, heads) |
| layer = nn.TransformerEncoderLayer( |
| d_model=embed_dim, nhead=n_h, |
| dim_feedforward=4 * embed_dim, |
| activation=_act_name_for_pytorch(activation), |
| batch_first=True, norm_first=True, |
| ) |
| self.attention = nn.TransformerEncoder(layer, num_layers=attention_layers) |
| else: |
| self.attention = nn.Identity() |
|
|
| self.norm = nn.LayerNorm(embed_dim) |
| self._last_svd = None |
| self._last_row_mag = None |
| self._last_cm_vol2 = None |
|
|
| def forward(self, tokens): |
| """tokens: (B, S, in_dim) β (B, S, embed_dim)""" |
| B, S, _ = tokens.shape |
| assert S == self.S, f"Expected S={self.S} tokens, got {S}" |
|
|
| |
| encoded = self.encoder(tokens) |
| M = encoded.reshape(B * S, self.V, self.D) |
|
|
| |
| |
| |
| |
| |
| if self.sphere_norm: |
| row_mag = M.norm(dim=-1) |
| M = F.normalize(M, dim=-1) |
| self._last_row_mag = row_mag |
| else: |
| row_mag = None |
| self._last_row_mag = None |
|
|
| |
| |
| |
| if self._cm_enabled: |
| |
| cm_idx = torch.linspace(0, self.V - 1, self._cm_nv, |
| device=M.device).long() |
| cm_verts = M[:, cm_idx, :] |
| _, cm_vol2 = self.cm(cm_verts) |
| self._last_cm_vol2 = cm_vol2 |
| else: |
| self._last_cm_vol2 = None |
|
|
| |
| U, Sv, Vt = _svd_dispatch(M, self.svd_solver, self.eigh_solver) |
| self._last_svd = (U, Sv, Vt) |
|
|
| |
| use_s, use_u, use_vt = self._TARGET_TO_MASK[self.target] |
| token_feat = torch.zeros(B * S, self.embed_dim, |
| device=tokens.device, dtype=tokens.dtype) |
| if use_s: |
| token_feat = token_feat + self.s_head(Sv) |
| if use_u: |
| token_feat = token_feat + self.u_head(U.reshape(B * S, -1)) |
| if use_vt: |
| token_feat = token_feat + self.vt_head(Vt.reshape(B * S, -1)) |
| if self.mag_head is not None and row_mag is not None: |
| token_feat = token_feat + self.mag_head(row_mag) |
|
|
| token_feat = self.geo_act(token_feat) |
| token_feat = token_feat.reshape(B, S, self.embed_dim) |
|
|
| |
| token_feat = self.attention(token_feat) |
| return self.norm(token_feat) |
|
|
|
|
| |
| |
| |
|
|
| class SVDTransformer(nn.Module): |
| """ |
| Stacked SVD cells with configurable encoder, attention, and head selection. |
| |
| First cell takes in_dim; subsequent cells take embed_dim. Each cell has its |
| own encoder + SVD + attention substack; `depth` cells in sequence. |
| """ |
| def __init__(self, *, |
| in_dim: int, |
| svd: Tuple[int, int, int] = (16, 8, 4), |
| bypass_crash: bool = True, |
| heads: int = 64, |
| hidden_size: int = 4, |
| depth: int = 4, |
| encode: str = 'mlp', |
| attention_layers: int = 2, |
| activation: str = 'gelu', |
| geo_activation: str = 'star', |
| token_out: str = 'all', |
| target: str = 'SVD', |
| svd_solver: str = 'auto', |
| eigh_solver: str = 'auto', |
| embed_dim: Optional[int] = None, |
| sphere_norm: bool = True, |
| mag_head: bool = True, |
| cm_enabled: bool = True, |
| cm_points: int = 5): |
| super().__init__() |
| S, V, D = svd |
| self.S, self.V, self.D = S, V, D |
|
|
| if D > 128: |
| msg = f"D={D} > 128 β gram-based SVD will be very slow / OOM-prone." |
| if not bypass_crash: |
| raise RuntimeError(msg + " Pass bypass_crash=True to override.") |
| print(f"β {msg}") |
|
|
| if embed_dim is None: |
| embed_dim = V * D |
| self.embed_dim = embed_dim |
| self.token_out = (token_out or 'all').lower() |
|
|
| cells = [] |
| for i in range(depth): |
| cell_in = in_dim if i == 0 else embed_dim |
| cells.append(SVDCell( |
| in_dim=cell_in, S=S, V=V, D=D, embed_dim=embed_dim, |
| hidden_size=hidden_size, encode=encode, |
| activation=activation, geo_activation=geo_activation, |
| target=target, attention_layers=attention_layers, |
| heads=heads, svd_solver=svd_solver, eigh_solver=eigh_solver, |
| sphere_norm=sphere_norm, mag_head=mag_head, |
| cm_enabled=cm_enabled, cm_points=cm_points, |
| )) |
| self.cells = nn.ModuleList(cells) |
|
|
| |
| if self.token_out == 'qkv': |
| self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim) |
|
|
| def forward(self, x: torch.Tensor, |
| y: Optional[torch.Tensor] = None, |
| z: Optional[Union[torch.Tensor, dict, list]] = None): |
| """ |
| x: (B, S, in_dim), the input token sequence |
| y: optional mask tensor (reserved; not yet wired into QKV/SUVt logic) |
| z: experimentation hooks (passed through; not yet consumed) |
| |
| Returns one of: |
| token_out="all" (default) β (B, S, embed_dim) |
| token_out="QKV" β (Q, K, V) tuple, each (B, S, embed_dim) |
| token_out="SUVt"/"SUV" β (U, S_vals, Vt) raw geometric tokens |
| from the last cell's SVD |
| """ |
| for cell in self.cells: |
| x = cell(x) |
|
|
| if self.token_out == 'qkv': |
| qkv = self.qkv_proj(x) |
| q, k, v = qkv.chunk(3, dim=-1) |
| return q, k, v |
|
|
| if self.token_out in ('suvt', 'suv'): |
| |
| |
| U, Sv, Vt = self.cells[-1]._last_svd |
| B, S = x.shape[:2] |
| U = U.reshape(B, S, self.V, self.D) |
| Sv = Sv.reshape(B, S, self.D) |
| Vt = Vt.reshape(B, S, self.D, self.D) |
| return U, Sv, Vt |
|
|
| return x |
|
|
|
|
| |
| |
| |
|
|
| def svd_transformer(x: torch.Tensor, |
| y: Optional[torch.Tensor] = None, |
| z: Optional[Union[torch.Tensor, dict, list]] = None, |
| *, |
| svd: Optional[Tuple[int, int, int]] = None, |
| bypass_crash: bool = True, |
| heads: int = 64, |
| hidden_size: int = 4, |
| depth: int = 4, |
| encode: str = 'mlp', |
| attention_layers: int = 2, |
| activation: str = 'gelu', |
| geo_activation: str = 'star', |
| token_out: str = 'all', |
| target: str = 'SVD', |
| svd_solver: str = 'auto', |
| eigh_solver: str = 'auto', |
| embed_dim: Optional[int] = None, |
| sphere_norm: bool = True, |
| mag_head: bool = True, |
| cm_enabled: bool = True, |
| cm_points: int = 5) -> SVDTransformer: |
| """ |
| Functional API matching user-provided spec. Returns an SVDTransformer |
| initialized from x's shape; caller invokes it via former(x). |
| |
| Shape inference for `svd=None`: |
| x.shape = (B, S, F) → svd = (S, V, D) using sqrt(F) if F is a perfect |
| square, else (S, 8, 4) fallback |
| x.shape = (B, C, H, W) → raises (caller must patchify or pass svd=) |
| |
| Extra params (beyond original API spec, for magnitude handling + degeneracy): |
| sphere_norm : F.normalize(M, dim=-1) before SVD (default True). Critical |
| for training stability; otherwise row magnitudes are |
| unbounded and SVD produces extreme singular values on |
| occasional batches, driving attention softmax collapse. |
| mag_head : add a fourth readout head on captured row magnitudes |
| (default True, requires sphere_norm). |
| cm_enabled : Cayley-Menger validator on 5-point row subset. Caches |
| pentachoron volΒ² on each cell for degeneracy monitoring |
| / optional loss regularization. Default True. |
| cm_points : vertices per simplex (default 5 = pentachoron). |
| |
| Returns the SVDTransformer module. Apply it with `former(x)`. |
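| |
| Shape-inference sketch for svd=None (illustrative values): |
| x.shape = (4, 10, 64) → svd = (10, 8, 8) (64 is a perfect square) |
| x.shape = (4, 10, 50) → svd = (10, 8, 4) (fallback; the encoder maps 50 → V*D) |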
| """ |
| if svd is None: |
| if x.ndim == 3: |
| B, S, n_feat = x.shape  # n_feat: feature dim (avoids shadowing the functional alias F) |
| sq = int(n_feat ** 0.5) |
| if sq * sq == n_feat: |
| V, D = sq, sq |
| else: |
| V, D = 8, 4 |
| svd_param = (S, V, D) |
| elif x.ndim == 4: |
| raise ValueError( |
| "svd_transformer with svd=None requires pre-tokenized input " |
| "(B, S, F). For images, patchify first or pass svd=(S, V, D)." |
| ) |
| else: |
| raise ValueError(f"x.shape must be (B, S, F) or (B, C, H, W); got {tuple(x.shape)}") |
| else: |
| svd_param = tuple(svd) |
|
|
| in_dim = x.shape[-1] |
| return SVDTransformer( |
| in_dim=in_dim, svd=svd_param, bypass_crash=bypass_crash, |
| heads=heads, hidden_size=hidden_size, depth=depth, |
| encode=encode, attention_layers=attention_layers, |
| activation=activation, geo_activation=geo_activation, |
| token_out=token_out, target=target, |
| svd_solver=svd_solver, eigh_solver=eigh_solver, |
| embed_dim=embed_dim, |
| sphere_norm=sphere_norm, mag_head=mag_head, |
| cm_enabled=cm_enabled, cm_points=cm_points, |
| ) |
|
|
|
|
| |
| |
| |
|
|
| if __name__ == '__main__': |
| print("\n" + "=" * 72) |
| print("SVDTransformer prototype self-test") |
| print("=" * 72) |
|
|
| device = 'cuda' if torch.cuda.is_available() else 'cpu' |
| torch.manual_seed(0) |
|
|
| |
| print("\n[1] Default config: svd=(16, 8, 4), depth=4, encode='mlp'") |
| x = torch.randn(2, 16, 32, device=device) |
| former = svd_transformer(x, svd=(16, 8, 4)) |
| former = former.to(device) |
| out = former(x) |
| n_params = sum(p.numel() for p in former.parameters()) |
| print(f" in shape={tuple(x.shape)} out shape={tuple(out.shape)} params={n_params:,}") |
| assert out.shape == (2, 16, 32), f"Expected (2,16,32), got {out.shape}" |
|
|
| |
| print("\n[2] All encoder types:") |
| for enc in _ENCODERS: |
| m = svd_transformer(x, svd=(16, 8, 4), encode=enc, depth=1, attention_layers=1).to(device) |
| try: |
| o = m(x) |
| print(f" encode={enc:12s} β out={tuple(o.shape)} params={sum(p.numel() for p in m.parameters()):,}") |
| except Exception as e: |
| print(f" encode={enc:12s} β {type(e).__name__}: {e}") |
|
|
| |
| print("\n[3] All target options:") |
| for tgt in ['SVD', 'VD', 'SV', 'S', 'V']: |
| m = svd_transformer(x, svd=(16, 8, 4), target=tgt, depth=1, attention_layers=0).to(device) |
| o = m(x) |
| |
| loss = o.sum() |
| loss.backward() |
| head_grads = { |
| 'S': m.cells[0].s_head.weight.grad.norm().item() if m.cells[0].s_head.weight.grad is not None else 0, |
| 'U': m.cells[0].u_head.weight.grad.norm().item() if m.cells[0].u_head.weight.grad is not None else 0, |
| 'Vt': m.cells[0].vt_head.weight.grad.norm().item() if m.cells[0].vt_head.weight.grad is not None else 0, |
| } |
| active = [k for k, v in head_grads.items() if v > 1e-9] |
| print(f" target={tgt:4s} β active heads={active} out={tuple(o.shape)}") |
|
|
| |
| print("\n[4] All token_out formats:") |
| for to in ['all', 'QKV', 'SUVt']: |
| m = svd_transformer(x, svd=(16, 8, 4), token_out=to, depth=1, attention_layers=0).to(device) |
| o = m(x) |
| if isinstance(o, tuple): |
| shapes = [tuple(t.shape) for t in o] |
| print(f" token_out={to:5s} β {len(o)} tensors, shapes={shapes}") |
| else: |
| print(f" token_out={to:5s} β out={tuple(o.shape)}") |
|
|
| |
| print("\n[5] SVD orthogonality check (no row centering):") |
| m = svd_transformer(x, svd=(16, 8, 4), depth=1, attention_layers=0).to(device) |
| with torch.no_grad(): |
| encoded = m.cells[0].encoder(x) |
| BS = 2 * 16 |
| M = encoded.reshape(BS, 8, 4) |
| rm = M[0].mean(dim=-1)[:3].tolist() |
| print(f" M not centered: row_means[0,:3] = [{rm[0]:.4f},{rm[1]:.4f},{rm[2]:.4f}]") |
| U, Sv, Vt = _svd_dispatch(M) |
| I_D = torch.eye(4, device=device).expand(BS, 4, 4) |
| u_orth = (torch.bmm(U.transpose(1, 2), U) - I_D).abs().max().item() |
| v_orth = (torch.bmm(Vt, Vt.transpose(1, 2)) - I_D).abs().max().item() |
| recon = (torch.bmm(U * Sv.unsqueeze(1), Vt) - M).abs().max().item() |
| print(f" ||U^T U - I|| = {u_orth:.2e} {'β' if u_orth < 1e-3 else 'β'}") |
| print(f" ||Vt Vt^T - I|| = {v_orth:.2e} {'β' if v_orth < 1e-3 else 'β'}") |
| print(f" reconstruction = {recon:.2e} {'β' if recon < 1e-4 else 'β'}") |
|
|
| |
| print("\n[6] Backward pass (gradient flow through SVD):") |
| m = svd_transformer(x, svd=(16, 8, 4), depth=2, attention_layers=1).to(device) |
| out = m(x) |
| loss = out.pow(2).mean() |
| loss.backward() |
| enc_grad = sum( |
| p.grad.norm().item() ** 2 |
| for p in m.cells[0].encoder.parameters() if p.grad is not None |
| ) ** 0.5 |
| print(f" loss = {loss.item():.4f}") |
| print(f" cell[0].encoder grad_norm = {enc_grad:.4e} " |
| f"{'β flowing through SVD into encoder' if enc_grad > 0 else 'β'}") |
|
|
| |
| print("\n[7] Solver dispatch combinations:") |
| for ssolver, esolver in [('auto', 'auto'), ('torch', 'auto'), |
| ('auto', 'fl'), ('auto', 'torch')]: |
| try: |
| m = svd_transformer(x, svd=(16, 8, 4), |
| svd_solver=ssolver, eigh_solver=esolver, |
| depth=1, attention_layers=0).to(device) |
| o = m(x) |
| print(f" svd={ssolver:6s} eigh={esolver:6s} β ok out={tuple(o.shape)}") |
| except Exception as e: |
| print(f" svd={ssolver:6s} eigh={esolver:6s} β {type(e).__name__}: {e}") |
|
|
| |
| print("\n[8] D-too-large guard:") |
| try: |
| m = svd_transformer(torch.randn(2, 4, 200, device=device), |
| svd=(4, 200, 200), bypass_crash=False, depth=1, attention_layers=0) |
| print(f" bypass_crash=False with D=200 β β (should have raised)") |
| except RuntimeError as e: |
| print(f" bypass_crash=False with D=200 β β raised: {str(e)[:60]}...") |
|
|
| print("\n" + "=" * 72) |
| print("All smoke tests complete.") |
| print("=" * 72) |