""" SVD Transformer Prototype ================================================================= Standalone prototype matching user-provided API spec. Combines: Claude is a new context window with opus 4.7 so we'll see if the results show. I essentially have to reteach claude how to use the session-scratchpad skill which is annoying but it's fine. - SpectralProbe lineage's three-head SVD readout (S, U, Vt → embed) - Correct geolip imports: geolip_core registers 'geolip' alias, then geolip.linalg as LA, then FLEigh from geolip_core.linalg.eigh - NO row centering (verified bug — gram-based SVD goes degenerate) - Configurable encoder (mlp/transformer/conv/film/ffn/rotary/lstm/gru) - Configurable geometric activation (star=ReLU² default) - Configurable attention layers between SVD passes - Configurable depth (stacked SVD cells) - Three head selection via `target` ({SVD, VD, SV, S, V}) - Three output formats via `token_out` ({all, QKV, SUVt}) - Solver dispatch: svd_solver={auto, torch, triton}, eigh_solver={auto, torch, fl} API parameter interpretations (clarify if wrong): svd=[S, V, D] — S = sequence/slot count, V/D = SVD matrix dims target="SVD" — all three heads active (S, U, Vt) target="VD" — U + Vt only (drop singular values) target="SV" — S + U only (drop right basis) target="S"/"V" — single head token_out="all" — return (B, S, embed_dim) sequence token_out="QKV" — return (Q, K, V) tuple after QKV projection token_out="SUVt" — return (U, S_vals, Vt) of the LAST cell's SVD depth=N — stack N independent SVD cells Lineage: AbstractPhil / SpectralProbe → CIFAR-10 (53.7% with 13.6k params) """ import math from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F # ──────────────────────────────────────────────────────────────────────── # geolip imports — CORRECT ORDER (geolip_core triggers sys.modules alias) # ──────────────────────────────────────────────────────────────────────── try: import geolip_core # noqa: F401 registers 'geolip' alias in sys.modules import geolip # now resolvable import geolip.linalg as LA # main dispatcher from geolip_core.linalg.eigh import FLEigh _HAS_GEOLIP = True print(f"✓ geolip {geolip.__version__} — using LA.svd + FLEigh") LA.backend.status() except ImportError as e: _HAS_GEOLIP = False LA = None FLEigh = None print(f"⚠ geolip_core not installed ({e}) — torch.linalg fallback") # ──────────────────────────────────────────────────────────────────────── # Activations (regular + geometric) # ──────────────────────────────────────────────────────────────────────── class StarActivation(nn.Module): """ReLU² — squared positive activation. 
All-positive output.""" def forward(self, x): return F.relu(x).pow(2) _GEO_ACTS = { 'star': lambda: StarActivation(), 'relu': lambda: nn.ReLU(), 'gelu': lambda: nn.GELU(), 'silu': lambda: nn.SiLU(), 'swilu': lambda: nn.SiLU(), # alias of silu 'tanh': lambda: nn.Tanh(), 'sigmoid': lambda: nn.Sigmoid(), 'leaky_relu': lambda: nn.LeakyReLU(0.01), } _REG_ACTS = { 'gelu': lambda: nn.GELU(), 'relu': lambda: nn.ReLU(), 'silu': lambda: nn.SiLU(), 'tanh': lambda: nn.Tanh(), 'leaky_relu': lambda: nn.LeakyReLU(0.01), } def make_geo_activation(name: str) -> nn.Module: name = (name or 'star').lower() if name not in _GEO_ACTS: raise ValueError(f"Unknown geo_activation: {name!r}; options: {list(_GEO_ACTS)}") return _GEO_ACTS[name]() def make_activation(name: str) -> nn.Module: name = (name or 'gelu').lower() if name not in _REG_ACTS: raise ValueError(f"Unknown activation: {name!r}; options: {list(_REG_ACTS)}") return _REG_ACTS[name]() def _act_name_for_pytorch(name: str) -> str: """nn.TransformerEncoderLayer accepts 'gelu'/'relu' strings; map our names.""" name = (name or 'gelu').lower() return name if name in ('gelu', 'relu') else 'gelu' # ──────────────────────────────────────────────────────────────────────── # Encoder variants — apply per-token before SVD reshape # ──────────────────────────────────────────────────────────────────────── def _fit_heads(d: int, target: int) -> int: """Pick a head count that divides d evenly (target preferred).""" for h in [target, target // 2, target // 4, 8, 4, 2, 1]: if h > 0 and d % h == 0: return h return 1 class MLPEncoder(nn.Module): """encode='mlp' (default) — two-layer MLP per token.""" def __init__(self, in_dim, out_dim, hidden_size, activation, **_): super().__init__() # hidden_size is the API's "Internal MLP hidden size" — small (default 4) # Don't let it bottleneck; ensure at least max(in, out)/2 h = max(hidden_size, max(in_dim, out_dim) // 2, 8) self.net = nn.Sequential( nn.Linear(in_dim, h), make_activation(activation), nn.Linear(h, out_dim), ) def forward(self, x): return self.net(x) class FFNEncoder(nn.Module): """encode='ffn' — transformer-style 4× expansion FFN.""" def __init__(self, in_dim, out_dim, hidden_size, activation, **_): super().__init__() h = max(hidden_size, 4 * out_dim) self.net = nn.Sequential( nn.Linear(in_dim, h), make_activation(activation), nn.Linear(h, out_dim), ) def forward(self, x): return self.net(x) class FiLMEncoder(nn.Module): """encode='film' — feature-wise affine modulation.""" def __init__(self, in_dim, out_dim, hidden_size, activation, **_): super().__init__() self.skip = nn.Linear(in_dim, out_dim) self.gamma = nn.Linear(in_dim, out_dim) self.beta = nn.Linear(in_dim, out_dim) self.act = make_activation(activation) def forward(self, x): skip = self.skip(x) return self.act(skip * (1.0 + self.gamma(x)) + self.beta(x)) class ConvEncoder(nn.Module): """encode='conv' — 1D conv across the token sequence.""" def __init__(self, in_dim, out_dim, hidden_size, activation, **_): super().__init__() self.proj = nn.Linear(in_dim, out_dim) self.conv = nn.Conv1d(out_dim, out_dim, kernel_size=3, padding=1) self.act = make_activation(activation) def forward(self, x): # (B, S, in_dim) x = self.proj(x) x = x.transpose(1, 2) # (B, out_dim, S) x = self.act(self.conv(x)) return x.transpose(1, 2) # (B, S, out_dim) class TransformerEncoder(nn.Module): """encode='transformer' — single transformer encoder layer pre-SVD.""" def __init__(self, in_dim, out_dim, hidden_size, activation, n_heads=4, **_): super().__init__() self.proj = nn.Linear(in_dim, 
out_dim) h = _fit_heads(out_dim, n_heads) self.layer = nn.TransformerEncoderLayer( d_model=out_dim, nhead=h, dim_feedforward=max(hidden_size, 4 * out_dim), activation=_act_name_for_pytorch(activation), batch_first=True, norm_first=True, ) def forward(self, x): return self.layer(self.proj(x)) class LSTMEncoder(nn.Module): """encode='lstm' — sequential LSTM.""" def __init__(self, in_dim, out_dim, hidden_size, activation, **_): super().__init__() self.lstm = nn.LSTM(in_dim, out_dim, batch_first=True) self.act = make_activation(activation) def forward(self, x): out, _ = self.lstm(x) return self.act(out) class GRUEncoder(nn.Module): """encode='gru' — sequential GRU.""" def __init__(self, in_dim, out_dim, hidden_size, activation, **_): super().__init__() self.gru = nn.GRU(in_dim, out_dim, batch_first=True) self.act = make_activation(activation) def forward(self, x): out, _ = self.gru(x) return self.act(out) class RotaryEncoder(nn.Module): """encode='rotary' — projection then rotary positional embedding.""" def __init__(self, in_dim, out_dim, hidden_size, activation, **_): super().__init__() self.proj = nn.Linear(in_dim, out_dim) self.dim = out_dim self.act = make_activation(activation) def forward(self, x): x = self.proj(x) # (B, S, out_dim) B, S, D = x.shape d_half = D // 2 if d_half == 0: return self.act(x) positions = torch.arange(S, device=x.device, dtype=x.dtype).unsqueeze(0) freqs = torch.exp(torch.arange(d_half, device=x.device, dtype=x.dtype) * (-math.log(10000.0) / d_half)) angles = positions.unsqueeze(-1) * freqs.unsqueeze(0) # (1, S, d_half) cos, sin = angles.cos(), angles.sin() x1 = x[..., :d_half] x2 = x[..., d_half:2 * d_half] rotated_1 = x1 * cos - x2 * sin rotated_2 = x1 * sin + x2 * cos if D % 2 == 1: tail = x[..., 2 * d_half:] x = torch.cat([rotated_1, rotated_2, tail], dim=-1) else: x = torch.cat([rotated_1, rotated_2], dim=-1) return self.act(x) _ENCODERS = { 'mlp': MLPEncoder, 'ffn': FFNEncoder, 'film': FiLMEncoder, 'conv': ConvEncoder, 'transformer': TransformerEncoder, 'lstm': LSTMEncoder, 'gru': GRUEncoder, 'rotary': RotaryEncoder, } def build_encoder(encode, in_dim, out_dim, hidden_size, activation): enc = (encode or 'mlp').lower() if enc not in _ENCODERS: raise ValueError(f"Unknown encode={encode!r}; options: {list(_ENCODERS)}") return _ENCODERS[enc](in_dim, out_dim, hidden_size, activation) # ──────────────────────────────────────────────────────────────────────── # SVD dispatch — auto-route to fastest available correct backend # ──────────────────────────────────────────────────────────────────────── def _svd_dispatch(M: torch.Tensor, svd_solver: str = 'auto', eigh_solver: str = 'auto' ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ M: (BS, V, D) — batch of matrices to decompose. Returns: U (BS, V, D), S_vals (BS, D), Vt (BS, D, D) Dispatch logic (ALL paths produce thin SVD with descending singular values): no geolip → torch.linalg.svd in fp64 (rank-deficient-safe) svd_solver='torch' → torch.linalg.svd in fp64 svd_solver='triton' → LA.svd(method='triton', compute_dtype='fp32') eigh_solver='fl' → custom gram + FLEigh (compiles up to D=12) auto/auto → LA.svd with compute_dtype selected by D: D ≤ 3 → fp32 (Triton fused kernel, fast and precise) D ≥ 4 → fp64 (FL eigh, compilable, ~1e-7 orthogonality) CRITICAL: without compute_dtype='fp64' for D≥4, the Blackwell Triton SVD kernels (4×4-6×6 fp32) engage by default and orthogonality drops from ~1e-7 to ~1e-3. fp64 is required for training stability at D=4. 
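
    Minimal dispatch sketch (illustrative shapes; assumes geolip is present):

        M3 = torch.randn(512, 8, 3)    # D=3 → auto picks LA.svd(compute_dtype='fp32')
        M4 = torch.randn(512, 8, 4)    # D=4 → auto picks LA.svd(compute_dtype='fp64')
        U, Sv, Vt = _svd_dispatch(M4)  # U:(512, 8, 4)  Sv:(512, 4)  Vt:(512, 4, 4)
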
NEVER row-center M before this — the gram path produces garbage U for rank-deficient inputs. Verified bug across both this implementation and geolip's gram_eigh path. The production SVDObserver in geolip_core also avoids centering for this reason. """ # --- Fallback: no geolip if not _HAS_GEOLIP: with torch.amp.autocast('cuda', enabled=False): U, Sv, Vt = torch.linalg.svd(M.double(), full_matrices=False) return U.float(), Sv.float(), Vt.float() # --- Explicit torch path if svd_solver == 'torch': with torch.amp.autocast('cuda', enabled=False): U, Sv, Vt = torch.linalg.svd(M.double(), full_matrices=False) return U.float(), Sv.float(), Vt.float() # --- Explicit FL eigh path (custom gram + FLEigh, more accurate than torch.linalg.eigh) if eigh_solver == 'fl': return _gram_fl_eigh_svd(M) # --- Pick compute_dtype based on D --- # Triton fused kernels (2×2, 3×3, and newer 4×4-6×6) run in fp32 for speed. # For D ≤ 3 fp32 is precise enough (Triton fused kernel is numerically clean). # For D ≥ 4 we force fp64 to route through FL eigh rather than fp32 Triton SVD # — the Triton path only gives ~1e-3 orthogonality, FL eigh gives ~1e-7. # Matches SpectralCell's explicit dispatch logic. D = M.shape[-1] dtype_arg = 'fp32' if D <= 3 else 'fp64' # --- Triton path (explicit) — keep user's fp32 choice if they asked for it if svd_solver == 'triton': try: return LA.svd(M, method='triton', compute_dtype='fp32') except Exception as exc: print(f" ⚠ triton SVD failed ({exc}); falling back to LA.svd default") return LA.svd(M, compute_dtype=dtype_arg) # --- Default: LA.svd auto-dispatch with explicit compute_dtype # fp64 for D≥4 → routes to FL eigh (compilable, 70/72 math purity) # fp32 for D≤3 → routes to Triton fused kernel return LA.svd(M, compute_dtype=dtype_arg) def _gram_fl_eigh_svd(M: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Custom gram + FL eigh SVD. Uses geolip_core.linalg.eigh.FLEigh — the Faddeev-LeVerrier polynomial + Laguerre roots + Newton-Schulz pipeline. More accurate than torch.linalg.eigh on ill-conditioned grams. Compiles up to D=12 on CUDA. For larger D, use _svd_dispatch with default auto routing (which will pick torch.linalg.svd). """ if FLEigh is None: raise RuntimeError("FLEigh unavailable — geolip_core not installed") orig_dtype = M.dtype A = M.float() # FL eigh runs in float G = torch.bmm(A.transpose(1, 2), A) # (BS, D, D), symmetric PSD eigenvalues, V = FLEigh()(G) # eigh returns ascending; we want descending singular values eigenvalues = eigenvalues.flip(-1) V = V.flip(-1) Sv = torch.sqrt(eigenvalues.clamp(min=1e-12)) U = torch.bmm(A, V) / Sv.unsqueeze(1).clamp(min=1e-8) Vh = V.transpose(-2, -1).contiguous() return U.to(orig_dtype), Sv.to(orig_dtype), Vh.to(orig_dtype) # ──────────────────────────────────────────────────────────────────────── # Image patcher (helper for image inputs; not part of svd_transformer itself) # ──────────────────────────────────────────────────────────────────────── class TensorPatcher(nn.Module): """(B, C, H, W) → (B, N, C·ph·pw). 
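    E.g. CIFAR-10: input_shape=(3, 32, 32), patch_size=4 → (B, 64, 48).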
Pure reshape, no learned params.""" def __init__(self, input_shape, patch_size): super().__init__() C, H, W = input_shape ph = pw = patch_size assert H % ph == 0 and W % pw == 0 self.C, self.H, self.W = C, H, W self.ph, self.pw = ph, pw self.n_patches = (H // ph) * (W // pw) self.patch_dim = C * ph * pw def forward(self, x): B, C, H, W = x.shape ph, pw = self.ph, self.pw gh, gw = H // ph, W // pw p = x.reshape(B, C, gh, ph, gw, pw) p = p.permute(0, 2, 4, 1, 3, 5).contiguous() return p.reshape(B, gh * gw, -1) # ──────────────────────────────────────────────────────────────────────── # Cayley-Menger validator — degeneracy detector via simplex volume # ──────────────────────────────────────────────────────────────────────── from itertools import combinations class CMValidator(nn.Module): """Batch-friendly Cayley-Menger determinant. Computes pairwise squared distances and simplex volume for (k+1)-point subsets in arbitrary embedding dimension. Used as a degeneracy detector: when sphere-normalized rows collapse toward coincidence, the simplex formed by any (k+1)-point subset flattens and vol² → 0. For k=4: 5 vertices → 10 pairwise d² + 1 vol². Pentachoron config. Adapted from SpectralCell (geolip-core). """ def __init__(self, k): super().__init__() self._k = k self._nv = k + 1 pairs = list(combinations(range(self._nv), 2)) self._npairs = len(pairs) self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long)) self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long)) sign = (-1.0) ** (k + 1) fact = math.factorial(k) self._prefactor = sign / ((2.0 ** k) * (fact ** 2)) def forward(self, verts): """verts: (..., nv, edim) → (d2_pairs: (..., npairs), vol2: (...))""" gram = torch.einsum('...ve,...we->...vw', verts, verts) norms = torch.diagonal(gram, dim1=-2, dim2=-1) d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram d2_mat = F.relu(d2_mat) d2_pairs = d2_mat[..., self._pi, self._pj] shape = d2_mat.shape[:-2] Vn = d2_mat.shape[-1] cm = torch.zeros(*shape, Vn + 1, Vn + 1, device=d2_mat.device, dtype=d2_mat.dtype) cm[..., 0, 1:] = 1.0 cm[..., 1:, 0] = 1.0 cm[..., 1:, 1:] = d2_mat vol2 = self._prefactor * torch.linalg.det(cm.float()) vol2 = vol2.to(d2_pairs.dtype) return d2_pairs, vol2 # ──────────────────────────────────────────────────────────────────────── # SVD Cell — one cycle: encode → sphere-normalize → SVD → heads → attention # ──────────────────────────────────────────────────────────────────────── class SVDCell(nn.Module): """ One cycle of the architecture: tokens (B, S, in_dim) ↓ encode (mlp/conv/transformer/...) [out: (B, S, V·D)] ↓ reshape [out: (B·S, V, D)] ↓ capture row_mag = ||M_row|| [out: (B·S, V)] ← NEW ↓ sphere-normalize rows: F.normalize [unit S^(D-1) rows] ← NEW ↓ (optional) CM validator on 5-point subset [vol² for degen] ← NEW ↓ SVD via geolip (on clean unit matrix) [out: U, S_vals, Vt] ↓ four-head readout: s + u + vt + mag [out: (B·S, embed_dim)] ↓ geo_activation [out: (B·S, embed_dim)] ↓ reshape [out: (B, S, embed_dim)] ↓ attention_layers × TransformerEncoderLayer ↓ LayerNorm → tokens (B, S, embed_dim) Sphere normalization rationale: Without it, M rows have unbounded magnitudes — when any row explodes, the SVD singular values explode, attention softmax sees extreme values, and the gradient destabilizes. SpectralCell's fix: normalize rows to S^(D-1), preserving magnitude as a separate signal (row_mag) that flows through mag_head. 
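    Concretely, row_mag_i = ||M_i|| is captured first; each row is then scaled
    to M_i / max(||M_i||, 1e-12) (F.normalize's default eps).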
This keeps the high-end AND low-end magnitude
    distribution intact while giving the SVD clean inputs to decompose.

    CM validation rationale:
        Sphere-normalized rows can still collapse to near-coincidence, which
        causes the SVD to have near-degenerate singular values (a numerical
        nightmare in gradients). CM computes pentachoron volume from 5 sampled
        rows; vol² near zero signals degeneracy.

    The SVD components and CM signal are cached on self._last_svd and
    self._last_cm_vol2 for token_out="SUVt" extraction, diagnostics, and
    optional loss regularization.
    """

    _TARGET_TO_MASK = {
        'SVD': (True, True, True),    # all three heads
        'VD':  (False, True, True),   # U + Vt only
        'SV':  (True, True, False),   # S + U only
        'S':   (True, False, False),  # singular values only
        'V':   (False, True, False),  # U (left vectors) only
    }

    def __init__(self, *, in_dim, S, V, D, embed_dim, hidden_size,
                 encode, activation, geo_activation, target,
                 attention_layers, heads, svd_solver, eigh_solver,
                 sphere_norm=True, mag_head=True,
                 cm_enabled=True, cm_points=5):
        super().__init__()
        self.S, self.V, self.D = S, V, D
        self.embed_dim = embed_dim
        self.target = (target or 'SVD').upper()
        self.svd_solver = svd_solver
        self.eigh_solver = eigh_solver
        self.sphere_norm = sphere_norm
        self.use_mag_head = mag_head and sphere_norm  # mag_head requires sphere_norm
        if self.target not in self._TARGET_TO_MASK:
            raise ValueError(f"Unknown target={target!r}; options: {list(self._TARGET_TO_MASK)}")

        mat_dim = V * D
        # Encoder: tokens (B, S, in_dim) → (B, S, V*D)
        self.encoder = build_encoder(encode, in_dim, mat_dim, hidden_size, activation)

        # Three SVD head linears (all instantiated; mask gates which contribute)
        self.s_head = nn.Linear(D, embed_dim)
        self.u_head = nn.Linear(V * D, embed_dim)
        self.vt_head = nn.Linear(D * D, embed_dim)

        # Magnitude head (optional, requires sphere_norm to have meaning)
        if self.use_mag_head:
            self.mag_head = nn.Linear(V, embed_dim)
        else:
            self.mag_head = None

        # CM validator: needs sphere_norm, an embedding dim that can hold the
        # k-simplex (D ≥ k), and enough rows to sample nv distinct vertices
        # (V ≥ nv; otherwise the linspace row sampling repeats indices and
        # vol² is identically zero).
        self._cm_nv = cm_points
        self._cm_k = cm_points - 1
        if cm_enabled and sphere_norm and D >= self._cm_k and V >= self._cm_nv:
            self.cm = CMValidator(self._cm_k)
            self._cm_enabled = True
        else:
            self.cm = None
            self._cm_enabled = False

        self.geo_act = make_geo_activation(geo_activation)

        # Attention stack (post-SVD)
        if attention_layers > 0:
            n_h = _fit_heads(embed_dim, heads)
            layer = nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=n_h,
                dim_feedforward=4 * embed_dim,
                activation=_act_name_for_pytorch(activation),
                batch_first=True, norm_first=True,
            )
            self.attention = nn.TransformerEncoder(layer, num_layers=attention_layers)
        else:
            self.attention = nn.Identity()

        self.norm = nn.LayerNorm(embed_dim)
        self._last_svd = None      # (U, S_vals, Vt) cache
        self._last_row_mag = None  # (B*S, V) magnitude cache
        self._last_cm_vol2 = None  # (B*S,) CM volume² cache for diagnostics

    def forward(self, tokens):
        """tokens: (B, S, in_dim) → (B, S, embed_dim)"""
        B, S, _ = tokens.shape
        assert S == self.S, f"Expected S={self.S} tokens, got {S}"

        # Encode → V×D matrix per token
        encoded = self.encoder(tokens)  # (B, S, V*D)
        M = encoded.reshape(B * S, self.V, self.D)

        # Sphere-normalize rows: preserve magnitude as a separate signal, give
        # the SVD clean unit-sphere rows to decompose. This is the fix for
        # magnitude-explosion instability events. M rows were unbounded before;
        # now each row lies on S^(D-1) with ||row||=1. Magnitude flows through
        # row_mag → mag_head as its own feature pathway.
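        # Worked example (illustrative numbers): a row [3.0, 4.0] has norm 5.0;
        # F.normalize maps it to [0.6, 0.8] on the unit circle, and the 5.0 is
        # kept in row_mag for the mag_head pathway rather than entering the SVD.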
if self.sphere_norm: row_mag = M.norm(dim=-1) # (B*S, V) — pre-norm magnitude M = F.normalize(M, dim=-1) # rows now on S^(D-1) self._last_row_mag = row_mag else: row_mag = None self._last_row_mag = None # CM validation: sample 5 rows, compute pentachoron vol² as degeneracy # signal. Low vol² → rows clumped on sphere → SVD near-degenerate → # unstable gradients. Cached for optional loss regularization or logging. if self._cm_enabled: # Pick nv rows evenly spaced across V — linspace indices cm_idx = torch.linspace(0, self.V - 1, self._cm_nv, device=M.device).long() cm_verts = M[:, cm_idx, :] # (B*S, nv, D) _, cm_vol2 = self.cm(cm_verts) self._last_cm_vol2 = cm_vol2 else: self._last_cm_vol2 = None # SVD — now on a clean, bounded, well-conditioned input U, Sv, Vt = _svd_dispatch(M, self.svd_solver, self.eigh_solver) self._last_svd = (U, Sv, Vt) # Four-head readout (S, U, Vt gated by target; mag_head always-on when enabled) use_s, use_u, use_vt = self._TARGET_TO_MASK[self.target] token_feat = torch.zeros(B * S, self.embed_dim, device=tokens.device, dtype=tokens.dtype) if use_s: token_feat = token_feat + self.s_head(Sv) if use_u: token_feat = token_feat + self.u_head(U.reshape(B * S, -1)) if use_vt: token_feat = token_feat + self.vt_head(Vt.reshape(B * S, -1)) if self.mag_head is not None and row_mag is not None: token_feat = token_feat + self.mag_head(row_mag) token_feat = self.geo_act(token_feat) token_feat = token_feat.reshape(B, S, self.embed_dim) # Attention layers token_feat = self.attention(token_feat) return self.norm(token_feat) # ──────────────────────────────────────────────────────────────────────── # SVDTransformer — top-level module (depth × SVDCell) # ──────────────────────────────────────────────────────────────────────── class SVDTransformer(nn.Module): """ Stacked SVD cells with configurable encoder, attention, and head selection. First cell takes in_dim; subsequent cells take embed_dim. Each cell has its own encoder + SVD + attention substack; `depth` cells in sequence. """ def __init__(self, *, in_dim: int, svd: Tuple[int, int, int] = (16, 8, 4), bypass_crash: bool = True, heads: int = 64, hidden_size: int = 4, depth: int = 4, encode: str = 'mlp', attention_layers: int = 2, activation: str = 'gelu', geo_activation: str = 'star', token_out: str = 'all', target: str = 'SVD', svd_solver: str = 'auto', eigh_solver: str = 'auto', embed_dim: Optional[int] = None, sphere_norm: bool = True, mag_head: bool = True, cm_enabled: bool = True, cm_points: int = 5): super().__init__() S, V, D = svd self.S, self.V, self.D = S, V, D if D > 128: msg = f"D={D} > 128 — gram-based SVD will be very slow / OOM-prone." 
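            # Rough cost intuition (not benchmarked here): the gram path builds
            # a D×D matrix per token matrix, and FL eigh compiles only up to
            # D=12, so very large D degrades to slow batched SVD fallbacks.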
if not bypass_crash:
                raise RuntimeError(msg + " Pass bypass_crash=True to override.")
            print(f"⚠ {msg}")

        if embed_dim is None:
            embed_dim = V * D  # default: same dim as flattened SVD matrix
        self.embed_dim = embed_dim
        self.token_out = (token_out or 'all').lower()

        cells = []
        for i in range(depth):
            cell_in = in_dim if i == 0 else embed_dim
            cells.append(SVDCell(
                in_dim=cell_in, S=S, V=V, D=D, embed_dim=embed_dim,
                hidden_size=hidden_size, encode=encode, activation=activation,
                geo_activation=geo_activation, target=target,
                attention_layers=attention_layers, heads=heads,
                svd_solver=svd_solver, eigh_solver=eigh_solver,
                sphere_norm=sphere_norm, mag_head=mag_head,
                cm_enabled=cm_enabled, cm_points=cm_points,
            ))
        self.cells = nn.ModuleList(cells)

        # QKV projection (only used when token_out="QKV")
        if self.token_out == 'qkv':
            self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)

    def forward(self, x: torch.Tensor,
                y: Optional[torch.Tensor] = None,
                z: Optional[Union[torch.Tensor, dict, list]] = None):
        """
        x: (B, S, in_dim) — input token sequence
        y: optional mask tensor (reserved; not yet wired into QKV/SUVt logic)
        z: experimentation hooks (passed through; not yet consumed)

        Returns one of:
            token_out="all" (default) → (B, S, embed_dim)
            token_out="QKV"           → (Q, K, V) tuple, each (B, S, embed_dim)
            token_out="SUVt"/"SUV"    → (U, S_vals, Vt) raw geometric tokens
                                        from the last cell's SVD
        """
        for cell in self.cells:
            x = cell(x)  # (B, S, embed_dim)

        if self.token_out == 'qkv':
            qkv = self.qkv_proj(x)
            q, k, v = qkv.chunk(3, dim=-1)
            return q, k, v

        if self.token_out in ('suvt', 'suv'):
            # Return the raw SVD components cached by the last cell. The SVD
            # is computed on that cell's encoded matrix, before the cell's own
            # attention stack runs; its input has already passed through all
            # earlier cells (including their attention).
            U, Sv, Vt = self.cells[-1]._last_svd
            B, S = x.shape[:2]
            U = U.reshape(B, S, self.V, self.D)
            Sv = Sv.reshape(B, S, self.D)
            Vt = Vt.reshape(B, S, self.D, self.D)
            return U, Sv, Vt

        return x


# ────────────────────────────────────────────────────────────────────────
# Functional wrapper matching the user-provided API spec
# ────────────────────────────────────────────────────────────────────────
def svd_transformer(x: torch.Tensor,
                    y: Optional[torch.Tensor] = None,
                    z: Optional[Union[torch.Tensor, dict, list]] = None,
                    *,
                    svd: Optional[Tuple[int, int, int]] = None,
                    bypass_crash: bool = True,
                    heads: int = 64,
                    hidden_size: int = 4,
                    depth: int = 4,
                    encode: str = 'mlp',
                    attention_layers: int = 2,
                    activation: str = 'gelu',
                    geo_activation: str = 'star',
                    token_out: str = 'all',
                    target: str = 'SVD',
                    svd_solver: str = 'auto',
                    eigh_solver: str = 'auto',
                    embed_dim: Optional[int] = None,
                    sphere_norm: bool = True,
                    mag_head: bool = True,
                    cm_enabled: bool = True,
                    cm_points: int = 5) -> SVDTransformer:
    """
    Functional API matching the user-provided spec. Returns an SVDTransformer
    initialized from x's shape; the caller invokes it via former(x).

    Shape inference for `svd=None`:
        x.shape = (B, S, F)    → svd = (S, V, D) using sqrt(F) if F is a
                                 perfect square, else (S, 8, 4) fallback
        x.shape = (B, C, H, W) → raises (caller must patchify or pass svd=)

    Extra params (beyond the original API spec, for magnitude handling +
    degeneracy):
        sphere_norm : F.normalize(M, dim=-1) before SVD (default True).
                      Critical for training stability — otherwise row
                      magnitudes are unbounded and SVD produces extreme
                      singular values on occasional batches, driving
                      attention softmax collapse.
        mag_head    : add a fourth readout head on captured row magnitudes
                      (default True, requires sphere_norm).
        cm_enabled  : Cayley-Menger validator on 5-point row subset. Caches
                      pentachoron vol² on each cell for degeneracy monitoring
                      / optional loss regularization. Default True.
        cm_points   : vertices per simplex (default 5 = pentachoron).
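
    Usage sketch (mirrors self-test [1] below; `former` is just a local name):

        x = torch.randn(2, 16, 32)                  # (B, S, F)
        former = svd_transformer(x, svd=(16, 8, 4))
        out = former(x)                             # (2, 16, V*D) = (2, 16, 32)
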
Returns the SVDTransformer module. Apply it with `former(x)`.
    """
    if svd is None:
        if x.ndim == 3:
            # NOTE: unpacked as `feat` (not `F`) to avoid shadowing the
            # torch.nn.functional alias imported at the top of this file.
            B, S, feat = x.shape
            sq = int(feat ** 0.5)
            if sq * sq == feat:
                V, D = sq, sq
            else:
                V, D = 8, 4
            svd_param = (S, V, D)
        elif x.ndim == 4:
            raise ValueError(
                "svd_transformer with svd=None requires pre-tokenized input "
                "(B, S, F). For images, patchify first or pass svd=(S, V, D)."
            )
        else:
            raise ValueError(
                f"x.shape must be (B, S, F) or (B, C, H, W); got {tuple(x.shape)}")
    else:
        svd_param = tuple(svd)

    in_dim = x.shape[-1]
    return SVDTransformer(
        in_dim=in_dim, svd=svd_param, bypass_crash=bypass_crash, heads=heads,
        hidden_size=hidden_size, depth=depth, encode=encode,
        attention_layers=attention_layers, activation=activation,
        geo_activation=geo_activation, token_out=token_out, target=target,
        svd_solver=svd_solver, eigh_solver=eigh_solver, embed_dim=embed_dim,
        sphere_norm=sphere_norm, mag_head=mag_head,
        cm_enabled=cm_enabled, cm_points=cm_points,
    )


# ────────────────────────────────────────────────────────────────────────
# Self-test (smoke check; runs only when executed as a script, not on import)
# ────────────────────────────────────────────────────────────────────────
if __name__ == '__main__':
    print("\n" + "=" * 72)
    print("SVDTransformer prototype self-test")
    print("=" * 72)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(0)

    # --- Test 1: default config ---
    print("\n[1] Default config: svd=(16, 8, 4), depth=4, encode='mlp'")
    x = torch.randn(2, 16, 32, device=device)  # (B=2, S=16, F=32)
    former = svd_transformer(x, svd=(16, 8, 4))
    former = former.to(device)
    out = former(x)
    n_params = sum(p.numel() for p in former.parameters())
    print(f"   in shape={tuple(x.shape)}  out shape={tuple(out.shape)}  params={n_params:,}")
    assert out.shape == (2, 16, 32), f"Expected (2,16,32), got {out.shape}"

    # --- Test 2: each encoder type ---
    print("\n[2] All encoder types:")
    for enc in _ENCODERS:
        m = svd_transformer(x, svd=(16, 8, 4), encode=enc, depth=1,
                            attention_layers=1).to(device)
        try:
            o = m(x)
            print(f"   encode={enc:12s} → out={tuple(o.shape)} "
                  f"params={sum(p.numel() for p in m.parameters()):,}")
        except Exception as e:
            print(f"   encode={enc:12s} ✗ {type(e).__name__}: {e}")

    # --- Test 3: each target ---
    print("\n[3] All target options:")
    for tgt in ['SVD', 'VD', 'SV', 'S', 'V']:
        m = svd_transformer(x, svd=(16, 8, 4), target=tgt, depth=1,
                            attention_layers=0).to(device)
        o = m(x)
        # Count how many heads receive nonzero gradient
        loss = o.sum()
        loss.backward()
        head_grads = {
            'S': m.cells[0].s_head.weight.grad.norm().item()
                 if m.cells[0].s_head.weight.grad is not None else 0,
            'U': m.cells[0].u_head.weight.grad.norm().item()
                 if m.cells[0].u_head.weight.grad is not None else 0,
            'Vt': m.cells[0].vt_head.weight.grad.norm().item()
                  if m.cells[0].vt_head.weight.grad is not None else 0,
        }
        active = [k for k, v in head_grads.items() if v > 1e-9]
        print(f"   target={tgt:4s} → active heads={active} out={tuple(o.shape)}")

    # --- Test 4: each token_out format ---
    print("\n[4] All token_out formats:")
    for to in ['all', 'QKV', 'SUVt']:
        m = svd_transformer(x, svd=(16, 8, 4), token_out=to, depth=1,
                            attention_layers=0).to(device)
        o = m(x)
        if isinstance(o, tuple):
            shapes = [tuple(t.shape) for t in o]
            print(f"   token_out={to:5s} → {len(o)} tensors, shapes={shapes}")
        else:
            print(f"   token_out={to:5s} → out={tuple(o.shape)}")

    # --- Test 5: SVD orthogonality on a real model M (post-encoder)
--- print("\n[5] SVD orthogonality check (no row centering):") m = svd_transformer(x, svd=(16, 8, 4), depth=1, attention_layers=0).to(device) with torch.no_grad(): encoded = m.cells[0].encoder(x) BS = 2 * 16 M = encoded.reshape(BS, 8, 4) rm = M[0].mean(dim=-1)[:3].tolist() print(f" M not centered: row_means[0,:3] = [{rm[0]:.4f},{rm[1]:.4f},{rm[2]:.4f}]") U, Sv, Vt = _svd_dispatch(M) I_D = torch.eye(4, device=device).expand(BS, 4, 4) u_orth = (torch.bmm(U.transpose(1, 2), U) - I_D).abs().max().item() v_orth = (torch.bmm(Vt, Vt.transpose(1, 2)) - I_D).abs().max().item() recon = (torch.bmm(U * Sv.unsqueeze(1), Vt) - M).abs().max().item() print(f" ||U^T U - I|| = {u_orth:.2e} {'✓' if u_orth < 1e-3 else '✗'}") print(f" ||Vt Vt^T - I|| = {v_orth:.2e} {'✓' if v_orth < 1e-3 else '✗'}") print(f" reconstruction = {recon:.2e} {'✓' if recon < 1e-4 else '✗'}") # --- Test 6: backward pass (gradient flows through SVD) --- print("\n[6] Backward pass (gradient flow through SVD):") m = svd_transformer(x, svd=(16, 8, 4), depth=2, attention_layers=1).to(device) out = m(x) loss = out.pow(2).mean() loss.backward() enc_grad = sum( p.grad.norm().item() ** 2 for p in m.cells[0].encoder.parameters() if p.grad is not None ) ** 0.5 print(f" loss = {loss.item():.4f}") print(f" cell[0].encoder grad_norm = {enc_grad:.4e} " f"{'✓ flowing through SVD into encoder' if enc_grad > 0 else '✗'}") # --- Test 7: solver dispatch combinations --- print("\n[7] Solver dispatch combinations:") for ssolver, esolver in [('auto', 'auto'), ('torch', 'auto'), ('auto', 'fl'), ('auto', 'torch')]: try: m = svd_transformer(x, svd=(16, 8, 4), svd_solver=ssolver, eigh_solver=esolver, depth=1, attention_layers=0).to(device) o = m(x) print(f" svd={ssolver:6s} eigh={esolver:6s} → ok out={tuple(o.shape)}") except Exception as e: print(f" svd={ssolver:6s} eigh={esolver:6s} → {type(e).__name__}: {e}") # --- Test 8: bypass_crash for D > 128 --- print("\n[8] D-too-large guard:") try: m = svd_transformer(torch.randn(2, 4, 200, device=device), svd=(4, 200, 200), bypass_crash=False, depth=1, attention_layers=0) print(f" bypass_crash=False with D=200 → ✗ (should have raised)") except RuntimeError as e: print(f" bypass_crash=False with D=200 → ✓ raised: {str(e)[:60]}...") print("\n" + "=" * 72) print("All smoke tests complete.") print("=" * 72)
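
    # --- Appendix (beyond the spec'd tests): CM degeneracy cache sketch ---
    # Illustrative read of the `_last_cm_vol2` diagnostic documented in
    # SVDCell; the attribute is internal and may change.
    m = svd_transformer(x, svd=(16, 8, 4), depth=1, attention_layers=0).to(device)
    _ = m(x)
    vol2 = m.cells[0]._last_cm_vol2
    if vol2 is not None:
        print(f"\n[appendix] CM vol² cache: shape={tuple(vol2.shape)} "
              f"min={vol2.min().item():.2e} max={vol2.max().item():.2e}")
    else:
        print("\n[appendix] CM validator disabled for this config")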