# geolip-svd-encoder-sweeps/prototype_transformer.py
"""
SVD Transformer Prototype
=================================================================
Standalone prototype matching user-provided API spec. Combines:
Claude is a new context window with opus 4.7 so we'll see if the results show.
I essentially have to reteach claude how to use the session-scratchpad skill which is annoying but it's fine.
- SpectralProbe lineage's three-head SVD readout (S, U, Vt β†’ embed)
- Correct geolip imports: geolip_core registers 'geolip' alias, then
geolip.linalg as LA, then FLEigh from geolip_core.linalg.eigh
- NO row centering (verified bug β€” gram-based SVD goes degenerate)
- Configurable encoder (mlp/transformer/conv/film/ffn/rotary/lstm/gru)
- Configurable geometric activation (star=ReLUΒ² default)
- Configurable attention layers between SVD passes
- Configurable depth (stacked SVD cells)
- Three head selection via `target` ({SVD, VD, SV, S, V})
- Three output formats via `token_out` ({all, QKV, SUVt})
- Solver dispatch: svd_solver={auto, torch, triton}, eigh_solver={auto, torch, fl}
API parameter interpretations (clarify if wrong):
svd=[S, V, D] β€” S = sequence/slot count, V/D = SVD matrix dims
target="SVD" β€” all three heads active (S, U, Vt)
target="VD" β€” U + Vt only (drop singular values)
target="SV" β€” S + U only (drop right basis)
target="S"/"V" β€” single head
token_out="all" β€” return (B, S, embed_dim) sequence
token_out="QKV" β€” return (Q, K, V) tuple after QKV projection
token_out="SUVt" β€” return (U, S_vals, Vt) of the LAST cell's SVD
depth=N β€” stack N independent SVD cells
Lineage: AbstractPhil / SpectralProbe β†’ CIFAR-10 (53.7% with 13.6k params)
"""
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
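# Usage sketch (matches the API documented in the module docstring; shapes are
# illustrative, not required):
#   x = torch.randn(8, 16, 32)                       # (B, S, F) token batch
#   former = svd_transformer(x, svd=(16, 8, 4), depth=2)
#   tokens = former(x)                               # (B, 16, embed_dim=V*D=32)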
# ────────────────────────────────────────────────────────────────────────
# geolip imports — CORRECT ORDER (geolip_core triggers sys.modules alias)
# ────────────────────────────────────────────────────────────────────────
try:
import geolip_core # noqa: F401 registers 'geolip' alias in sys.modules
import geolip # now resolvable
import geolip.linalg as LA # main dispatcher
from geolip_core.linalg.eigh import FLEigh
_HAS_GEOLIP = True
print(f"βœ“ geolip {geolip.__version__} β€” using LA.svd + FLEigh")
LA.backend.status()
except ImportError as e:
_HAS_GEOLIP = False
LA = None
FLEigh = None
print(f"⚠ geolip_core not installed ({e}) β€” torch.linalg fallback")
# ────────────────────────────────────────────────────────────────────────
# Activations (regular + geometric)
# ────────────────────────────────────────────────────────────────────────
class StarActivation(nn.Module):
"""ReLUΒ² β€” squared positive activation. All-positive output."""
def forward(self, x):
return F.relu(x).pow(2)
_GEO_ACTS = {
'star': lambda: StarActivation(),
'relu': lambda: nn.ReLU(),
'gelu': lambda: nn.GELU(),
'silu': lambda: nn.SiLU(),
'swilu': lambda: nn.SiLU(), # alias of silu
'tanh': lambda: nn.Tanh(),
'sigmoid': lambda: nn.Sigmoid(),
'leaky_relu': lambda: nn.LeakyReLU(0.01),
}
_REG_ACTS = {
'gelu': lambda: nn.GELU(),
'relu': lambda: nn.ReLU(),
'silu': lambda: nn.SiLU(),
'tanh': lambda: nn.Tanh(),
'leaky_relu': lambda: nn.LeakyReLU(0.01),
}
def make_geo_activation(name: str) -> nn.Module:
name = (name or 'star').lower()
if name not in _GEO_ACTS:
raise ValueError(f"Unknown geo_activation: {name!r}; options: {list(_GEO_ACTS)}")
return _GEO_ACTS[name]()
def make_activation(name: str) -> nn.Module:
name = (name or 'gelu').lower()
if name not in _REG_ACTS:
raise ValueError(f"Unknown activation: {name!r}; options: {list(_REG_ACTS)}")
return _REG_ACTS[name]()
def _act_name_for_pytorch(name: str) -> str:
"""nn.TransformerEncoderLayer accepts 'gelu'/'relu' strings; map our names."""
name = (name or 'gelu').lower()
return name if name in ('gelu', 'relu') else 'gelu'
# ────────────────────────────────────────────────────────────────────────
# Encoder variants — apply per-token before SVD reshape
# ────────────────────────────────────────────────────────────────────────
def _fit_heads(d: int, target: int) -> int:
"""Pick a head count that divides d evenly (target preferred)."""
for h in [target, target // 2, target // 4, 8, 4, 2, 1]:
if h > 0 and d % h == 0:
return h
return 1
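# e.g. _fit_heads(32, 64) -> 32 (first candidate dividing 32 evenly);
#      _fit_heads(30, 8)  -> 2  (8 and 4 don't divide 30; 2 does)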
class MLPEncoder(nn.Module):
"""encode='mlp' (default) β€” two-layer MLP per token."""
def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
super().__init__()
        # hidden_size is the API's "Internal MLP hidden size" — small (default 4)
# Don't let it bottleneck; ensure at least max(in, out)/2
h = max(hidden_size, max(in_dim, out_dim) // 2, 8)
self.net = nn.Sequential(
nn.Linear(in_dim, h),
make_activation(activation),
nn.Linear(h, out_dim),
)
def forward(self, x):
return self.net(x)
class FFNEncoder(nn.Module):
"""encode='ffn' β€” transformer-style 4Γ— expansion FFN."""
def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
super().__init__()
h = max(hidden_size, 4 * out_dim)
self.net = nn.Sequential(
nn.Linear(in_dim, h),
make_activation(activation),
nn.Linear(h, out_dim),
)
def forward(self, x):
return self.net(x)
class FiLMEncoder(nn.Module):
"""encode='film' β€” feature-wise affine modulation."""
def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
super().__init__()
self.skip = nn.Linear(in_dim, out_dim)
self.gamma = nn.Linear(in_dim, out_dim)
self.beta = nn.Linear(in_dim, out_dim)
self.act = make_activation(activation)
def forward(self, x):
skip = self.skip(x)
return self.act(skip * (1.0 + self.gamma(x)) + self.beta(x))
class ConvEncoder(nn.Module):
"""encode='conv' β€” 1D conv across the token sequence."""
def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
super().__init__()
self.proj = nn.Linear(in_dim, out_dim)
self.conv = nn.Conv1d(out_dim, out_dim, kernel_size=3, padding=1)
self.act = make_activation(activation)
def forward(self, x): # (B, S, in_dim)
x = self.proj(x)
x = x.transpose(1, 2) # (B, out_dim, S)
x = self.act(self.conv(x))
return x.transpose(1, 2) # (B, S, out_dim)
class TransformerEncoder(nn.Module):
"""encode='transformer' β€” single transformer encoder layer pre-SVD."""
def __init__(self, in_dim, out_dim, hidden_size, activation, n_heads=4, **_):
super().__init__()
self.proj = nn.Linear(in_dim, out_dim)
h = _fit_heads(out_dim, n_heads)
self.layer = nn.TransformerEncoderLayer(
d_model=out_dim, nhead=h,
dim_feedforward=max(hidden_size, 4 * out_dim),
activation=_act_name_for_pytorch(activation),
batch_first=True, norm_first=True,
)
def forward(self, x):
return self.layer(self.proj(x))
class LSTMEncoder(nn.Module):
"""encode='lstm' β€” sequential LSTM."""
def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
super().__init__()
self.lstm = nn.LSTM(in_dim, out_dim, batch_first=True)
self.act = make_activation(activation)
def forward(self, x):
out, _ = self.lstm(x)
return self.act(out)
class GRUEncoder(nn.Module):
"""encode='gru' β€” sequential GRU."""
def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
super().__init__()
self.gru = nn.GRU(in_dim, out_dim, batch_first=True)
self.act = make_activation(activation)
def forward(self, x):
out, _ = self.gru(x)
return self.act(out)
class RotaryEncoder(nn.Module):
"""encode='rotary' β€” projection then rotary positional embedding."""
def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
super().__init__()
self.proj = nn.Linear(in_dim, out_dim)
self.dim = out_dim
self.act = make_activation(activation)
def forward(self, x):
x = self.proj(x) # (B, S, out_dim)
B, S, D = x.shape
d_half = D // 2
if d_half == 0:
return self.act(x)
positions = torch.arange(S, device=x.device, dtype=x.dtype).unsqueeze(0)
freqs = torch.exp(torch.arange(d_half, device=x.device, dtype=x.dtype)
* (-math.log(10000.0) / d_half))
angles = positions.unsqueeze(-1) * freqs.unsqueeze(0) # (1, S, d_half)
cos, sin = angles.cos(), angles.sin()
x1 = x[..., :d_half]
x2 = x[..., d_half:2 * d_half]
rotated_1 = x1 * cos - x2 * sin
rotated_2 = x1 * sin + x2 * cos
if D % 2 == 1:
tail = x[..., 2 * d_half:]
x = torch.cat([rotated_1, rotated_2, tail], dim=-1)
else:
x = torch.cat([rotated_1, rotated_2], dim=-1)
return self.act(x)
_ENCODERS = {
'mlp': MLPEncoder,
'ffn': FFNEncoder,
'film': FiLMEncoder,
'conv': ConvEncoder,
'transformer': TransformerEncoder,
'lstm': LSTMEncoder,
'gru': GRUEncoder,
'rotary': RotaryEncoder,
}
def build_encoder(encode, in_dim, out_dim, hidden_size, activation):
enc = (encode or 'mlp').lower()
if enc not in _ENCODERS:
raise ValueError(f"Unknown encode={encode!r}; options: {list(_ENCODERS)}")
return _ENCODERS[enc](in_dim, out_dim, hidden_size, activation)
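# Example: the default per-token encoder that SVDCell builds internally.
#   enc = build_encoder('mlp', in_dim=32, out_dim=8 * 4, hidden_size=4,
#                       activation='gelu')
#   y = enc(torch.randn(2, 16, 32))                  # (2, 16, 32) == (B, S, V*D)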
# ────────────────────────────────────────────────────────────────────────
# SVD dispatch — auto-route to fastest available correct backend
# ────────────────────────────────────────────────────────────────────────
def _svd_dispatch(M: torch.Tensor,
svd_solver: str = 'auto',
eigh_solver: str = 'auto'
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
M: (BS, V, D) β€” batch of matrices to decompose.
Returns: U (BS, V, D), S_vals (BS, D), Vt (BS, D, D)
Dispatch logic (ALL paths produce thin SVD with descending singular values):
no geolip β†’ torch.linalg.svd in fp64 (rank-deficient-safe)
svd_solver='torch' β†’ torch.linalg.svd in fp64
svd_solver='triton' β†’ LA.svd(method='triton', compute_dtype='fp32')
eigh_solver='fl' β†’ custom gram + FLEigh (compiles up to D=12)
auto/auto β†’ LA.svd with compute_dtype selected by D:
D ≀ 3 β†’ fp32 (Triton fused kernel, fast and precise)
D β‰₯ 4 β†’ fp64 (FL eigh, compilable, ~1e-7 orthogonality)
CRITICAL: without compute_dtype='fp64' for Dβ‰₯4, the Blackwell Triton SVD
kernels (4Γ—4-6Γ—6 fp32) engage by default and orthogonality drops from
~1e-7 to ~1e-3. fp64 is required for training stability at D=4.
NEVER row-center M before this β€” the gram path produces garbage U for
rank-deficient inputs. Verified bug across both this implementation and
geolip's gram_eigh path. The production SVDObserver in geolip_core also
avoids centering for this reason.
"""
# --- Fallback: no geolip
if not _HAS_GEOLIP:
with torch.amp.autocast('cuda', enabled=False):
U, Sv, Vt = torch.linalg.svd(M.double(), full_matrices=False)
return U.float(), Sv.float(), Vt.float()
# --- Explicit torch path
if svd_solver == 'torch':
with torch.amp.autocast('cuda', enabled=False):
U, Sv, Vt = torch.linalg.svd(M.double(), full_matrices=False)
return U.float(), Sv.float(), Vt.float()
# --- Explicit FL eigh path (custom gram + FLEigh, more accurate than torch.linalg.eigh)
if eigh_solver == 'fl':
return _gram_fl_eigh_svd(M)
# --- Pick compute_dtype based on D ---
    # Triton fused kernels (2×2, 3×3, and newer 4×4-6×6) run in fp32 for speed.
    # For D ≤ 3, fp32 is precise enough (the Triton fused kernel is numerically clean).
    # For D ≥ 4 we force fp64 to route through FL eigh rather than fp32 Triton SVD
    # — the Triton path only gives ~1e-3 orthogonality; FL eigh gives ~1e-7.
# Matches SpectralCell's explicit dispatch logic.
D = M.shape[-1]
dtype_arg = 'fp32' if D <= 3 else 'fp64'
    # --- Triton path (explicit) — keep the user's fp32 choice if they asked for it
if svd_solver == 'triton':
try:
return LA.svd(M, method='triton', compute_dtype='fp32')
except Exception as exc:
print(f" ⚠ triton SVD failed ({exc}); falling back to LA.svd default")
return LA.svd(M, compute_dtype=dtype_arg)
# --- Default: LA.svd auto-dispatch with explicit compute_dtype
    #     fp64 for D ≥ 4 → routes to FL eigh (compilable, 70/72 math purity)
    #     fp32 for D ≤ 3 → routes to Triton fused kernel
return LA.svd(M, compute_dtype=dtype_arg)
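# Dispatch illustration (shapes follow this prototype's defaults; the geolip
# routing described in the docstring above is assumed when it is installed):
#   M  = torch.randn(32, 8, 4)    # D=4 -> fp64 / FL-eigh route under auto/auto
#   U, Sv, Vt = _svd_dispatch(M)  # U:(32,8,4)  Sv:(32,4)  Vt:(32,4,4)
#   M3 = torch.randn(32, 8, 3)    # D=3 -> fp32 Triton fused kernel under auto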
def _gram_fl_eigh_svd(M: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Custom gram + FL eigh SVD. Uses geolip_core.linalg.eigh.FLEigh β€” the
Faddeev-LeVerrier polynomial + Laguerre roots + Newton-Schulz pipeline.
More accurate than torch.linalg.eigh on ill-conditioned grams.
Compiles up to D=12 on CUDA. For larger D, use _svd_dispatch with default
auto routing (which will pick torch.linalg.svd).
"""
if FLEigh is None:
raise RuntimeError("FLEigh unavailable β€” geolip_core not installed")
orig_dtype = M.dtype
A = M.float() # FL eigh runs in float
G = torch.bmm(A.transpose(1, 2), A) # (BS, D, D), symmetric PSD
eigenvalues, V = FLEigh()(G)
# eigh returns ascending; we want descending singular values
eigenvalues = eigenvalues.flip(-1)
V = V.flip(-1)
Sv = torch.sqrt(eigenvalues.clamp(min=1e-12))
U = torch.bmm(A, V) / Sv.unsqueeze(1).clamp(min=1e-8)
Vh = V.transpose(-2, -1).contiguous()
return U.to(orig_dtype), Sv.to(orig_dtype), Vh.to(orig_dtype)
# ────────────────────────────────────────────────────────────────────────
# Image patcher (helper for image inputs; not part of svd_transformer itself)
# ────────────────────────────────────────────────────────────────────────
class TensorPatcher(nn.Module):
"""(B, C, H, W) β†’ (B, N, CΒ·phΒ·pw). Pure reshape, no learned params."""
def __init__(self, input_shape, patch_size):
super().__init__()
C, H, W = input_shape
ph = pw = patch_size
assert H % ph == 0 and W % pw == 0
self.C, self.H, self.W = C, H, W
self.ph, self.pw = ph, pw
self.n_patches = (H // ph) * (W // pw)
self.patch_dim = C * ph * pw
def forward(self, x):
B, C, H, W = x.shape
ph, pw = self.ph, self.pw
gh, gw = H // ph, W // pw
p = x.reshape(B, C, gh, ph, gw, pw)
p = p.permute(0, 2, 4, 1, 3, 5).contiguous()
return p.reshape(B, gh * gw, -1)
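# Example: CIFAR-10-sized input with 8×8 patches -> 16 tokens of dim 192.
#   patcher = TensorPatcher(input_shape=(3, 32, 32), patch_size=8)
#   tok = patcher(torch.randn(2, 3, 32, 32))         # (2, 16, 3*8*8=192)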
# ────────────────────────────────────────────────────────────────────────
# Cayley-Menger validator — degeneracy detector via simplex volume
# ────────────────────────────────────────────────────────────────────────
from itertools import combinations
class CMValidator(nn.Module):
"""Batch-friendly Cayley-Menger determinant.
Computes pairwise squared distances and simplex volume for (k+1)-point
subsets in arbitrary embedding dimension. Used as a degeneracy detector:
when sphere-normalized rows collapse toward coincidence, the simplex
formed by any (k+1)-point subset flattens and volΒ² β†’ 0.
For k=4: 5 vertices β†’ 10 pairwise dΒ² + 1 volΒ². Pentachoron config.
Adapted from SpectralCell (geolip-core).
"""
def __init__(self, k):
super().__init__()
self._k = k
self._nv = k + 1
pairs = list(combinations(range(self._nv), 2))
self._npairs = len(pairs)
self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))
self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))
sign = (-1.0) ** (k + 1)
fact = math.factorial(k)
self._prefactor = sign / ((2.0 ** k) * (fact ** 2))
def forward(self, verts):
"""verts: (..., nv, edim) β†’ (d2_pairs: (..., npairs), vol2: (...))"""
gram = torch.einsum('...ve,...we->...vw', verts, verts)
norms = torch.diagonal(gram, dim1=-2, dim2=-1)
d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram
d2_mat = F.relu(d2_mat)
d2_pairs = d2_mat[..., self._pi, self._pj]
shape = d2_mat.shape[:-2]
Vn = d2_mat.shape[-1]
cm = torch.zeros(*shape, Vn + 1, Vn + 1, device=d2_mat.device, dtype=d2_mat.dtype)
cm[..., 0, 1:] = 1.0
cm[..., 1:, 0] = 1.0
cm[..., 1:, 1:] = d2_mat
vol2 = self._prefactor * torch.linalg.det(cm.float())
vol2 = vol2.to(d2_pairs.dtype)
return d2_pairs, vol2
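# Sanity sketch (not part of the module API): the unit right triangle has
# area 1/2, so CMValidator(k=2) should report vol² = 1/4.
#   tri = torch.tensor([[0., 0.], [1., 0.], [0., 1.]]).unsqueeze(0)  # (1, 3, 2)
#   d2, vol2 = CMValidator(2)(tri)   # d2: (1, 3) pairwise d²; vol2 ≈ 0.25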
# ────────────────────────────────────────────────────────────────────────
# SVD Cell — one cycle: encode → sphere-normalize → SVD → heads → attention
# ────────────────────────────────────────────────────────────────────────
class SVDCell(nn.Module):
"""
One cycle of the architecture:
tokens (B, S, in_dim)
↓ encode (mlp/conv/transformer/...) [out: (B, S, VΒ·D)]
↓ reshape [out: (BΒ·S, V, D)]
↓ capture row_mag = ||M_row|| [out: (BΒ·S, V)] ← NEW
↓ sphere-normalize rows: F.normalize [unit S^(D-1) rows] ← NEW
↓ (optional) CM validator on 5-point subset [volΒ² for degen] ← NEW
↓ SVD via geolip (on clean unit matrix) [out: U, S_vals, Vt]
↓ four-head readout: s + u + vt + mag [out: (BΒ·S, embed_dim)]
↓ geo_activation [out: (BΒ·S, embed_dim)]
↓ reshape [out: (B, S, embed_dim)]
↓ attention_layers Γ— TransformerEncoderLayer
↓ LayerNorm
β†’ tokens (B, S, embed_dim)
Sphere normalization rationale:
Without it, M rows have unbounded magnitudes β€” when any row explodes,
the SVD singular values explode, attention softmax sees extreme values,
and the gradient destabilizes. SpectralCell's fix: normalize rows to
S^(D-1), preserving magnitude as a separate signal (row_mag) that
flows through mag_head. This keeps the high-end AND low-end magnitude
distribution intact while giving the SVD clean inputs to decompose.
CM validation rationale:
Sphere-normalized rows can still collapse to near-coincidence, which
causes SVD to have near-degenerate singular values (numerical nightmare
in gradients). CM computes pentachoron volume from 5 sampled rows; volΒ²
near zero signals degeneracy. Cached on self._last_cm_vol2 for
optional loss regularization or diagnostics.
The SVD components and CM signal are cached on self._last_svd and
self._last_cm_vol2 for token_out="SUVt" extraction and diagnostics.
"""
_TARGET_TO_MASK = {
'SVD': (True, True, True), # all three heads
'VD': (False, True, True), # U + Vt only
'SV': (True, True, False), # S + U only
'S': (True, False, False), # singular values only
'V': (False, True, False), # U (left vectors) only
}
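    # e.g. target='VD' -> (use_s, use_u, use_vt) = (False, True, True): s_head
    # still exists as a parameter but receives no gradient (see self-test [3]).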
def __init__(self, *, in_dim, S, V, D, embed_dim, hidden_size,
encode, activation, geo_activation, target,
attention_layers, heads, svd_solver, eigh_solver,
sphere_norm=True, mag_head=True,
cm_enabled=True, cm_points=5):
super().__init__()
self.S, self.V, self.D = S, V, D
self.embed_dim = embed_dim
self.target = (target or 'SVD').upper()
self.svd_solver = svd_solver
self.eigh_solver = eigh_solver
self.sphere_norm = sphere_norm
self.use_mag_head = mag_head and sphere_norm # mag_head requires sphere_norm
if self.target not in self._TARGET_TO_MASK:
raise ValueError(f"Unknown target={target!r}; options: {list(self._TARGET_TO_MASK)}")
mat_dim = V * D
        # Encoder: tokens (B, S, in_dim) → (B, S, V*D)
self.encoder = build_encoder(encode, in_dim, mat_dim, hidden_size, activation)
# Three SVD head linears (all instantiated; mask gates which contribute)
self.s_head = nn.Linear(D, embed_dim)
self.u_head = nn.Linear(V * D, embed_dim)
self.vt_head = nn.Linear(D * D, embed_dim)
# Magnitude head (optional, requires sphere_norm to have meaning)
if self.use_mag_head:
self.mag_head = nn.Linear(V, embed_dim)
else:
self.mag_head = None
# CM validator (only meaningful when sphere_norm=True and D supports the simplex)
self._cm_nv = cm_points
self._cm_k = cm_points - 1
if cm_enabled and sphere_norm and D >= self._cm_k:
self.cm = CMValidator(self._cm_k)
self._cm_enabled = True
else:
self.cm = None
self._cm_enabled = False
self.geo_act = make_geo_activation(geo_activation)
# Attention stack (post-SVD)
if attention_layers > 0:
n_h = _fit_heads(embed_dim, heads)
layer = nn.TransformerEncoderLayer(
d_model=embed_dim, nhead=n_h,
dim_feedforward=4 * embed_dim,
activation=_act_name_for_pytorch(activation),
batch_first=True, norm_first=True,
)
self.attention = nn.TransformerEncoder(layer, num_layers=attention_layers)
else:
self.attention = nn.Identity()
self.norm = nn.LayerNorm(embed_dim)
self._last_svd = None # (U, S_vals, Vt) cache
self._last_row_mag = None # (B*S, V) magnitude cache
        self._last_cm_vol2 = None  # (B*S,) CM volume² cache for diagnostics
def forward(self, tokens):
"""tokens: (B, S, in_dim) β†’ (B, S, embed_dim)"""
B, S, _ = tokens.shape
assert S == self.S, f"Expected S={self.S} tokens, got {S}"
        # Encode → V×D matrix per token
encoded = self.encoder(tokens) # (B, S, V*D)
M = encoded.reshape(B * S, self.V, self.D)
        # Sphere-normalize rows: preserve magnitude as a separate signal, give
        # the SVD clean unit-sphere rows to decompose. This is the fix for
        # magnitude-explosion instability events. M rows were unbounded before;
        # now each row lies on S^(D-1) with ||row|| = 1. Magnitude flows through
        # row_mag → mag_head as its own feature pathway.
if self.sphere_norm:
            row_mag = M.norm(dim=-1)  # (B*S, V) — pre-norm magnitude
M = F.normalize(M, dim=-1) # rows now on S^(D-1)
self._last_row_mag = row_mag
else:
row_mag = None
self._last_row_mag = None
        # CM validation: sample 5 rows, compute pentachoron vol² as a degeneracy
        # signal. Low vol² → rows clumped on sphere → SVD near-degenerate →
        # unstable gradients. Cached for optional loss regularization or logging.
if self._cm_enabled:
            # Pick nv rows evenly spaced across V — linspace indices
cm_idx = torch.linspace(0, self.V - 1, self._cm_nv,
device=M.device).long()
cm_verts = M[:, cm_idx, :] # (B*S, nv, D)
_, cm_vol2 = self.cm(cm_verts)
self._last_cm_vol2 = cm_vol2
else:
self._last_cm_vol2 = None
        # SVD — now on a clean, bounded, well-conditioned input
U, Sv, Vt = _svd_dispatch(M, self.svd_solver, self.eigh_solver)
self._last_svd = (U, Sv, Vt)
# Four-head readout (S, U, Vt gated by target; mag_head always-on when enabled)
use_s, use_u, use_vt = self._TARGET_TO_MASK[self.target]
token_feat = torch.zeros(B * S, self.embed_dim,
device=tokens.device, dtype=tokens.dtype)
if use_s:
token_feat = token_feat + self.s_head(Sv)
if use_u:
token_feat = token_feat + self.u_head(U.reshape(B * S, -1))
if use_vt:
token_feat = token_feat + self.vt_head(Vt.reshape(B * S, -1))
if self.mag_head is not None and row_mag is not None:
token_feat = token_feat + self.mag_head(row_mag)
token_feat = self.geo_act(token_feat)
token_feat = token_feat.reshape(B, S, self.embed_dim)
# Attention layers
token_feat = self.attention(token_feat)
return self.norm(token_feat)
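# Diagnostics sketch (reads the caches this cell exposes; argument values are
# illustrative):
#   cell = SVDCell(in_dim=32, S=16, V=8, D=4, embed_dim=32, hidden_size=4,
#                  encode='mlp', activation='gelu', geo_activation='star',
#                  target='SVD', attention_layers=0, heads=4,
#                  svd_solver='auto', eigh_solver='auto')
#   _ = cell(torch.randn(2, 16, 32))
#   U, Sv, Vt = cell._last_svd     # (32, 8, 4), (32, 4), (32, 4, 4)
#   vol2 = cell._last_cm_vol2      # (32,) pentachoron vol²; near 0 == degenerate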
# ────────────────────────────────────────────────────────────────────────
# SVDTransformer — top-level module (depth × SVDCell)
# ────────────────────────────────────────────────────────────────────────
class SVDTransformer(nn.Module):
"""
Stacked SVD cells with configurable encoder, attention, and head selection.
First cell takes in_dim; subsequent cells take embed_dim. Each cell has its
own encoder + SVD + attention substack; `depth` cells in sequence.
"""
def __init__(self, *,
in_dim: int,
svd: Tuple[int, int, int] = (16, 8, 4),
bypass_crash: bool = True,
heads: int = 64,
hidden_size: int = 4,
depth: int = 4,
encode: str = 'mlp',
attention_layers: int = 2,
activation: str = 'gelu',
geo_activation: str = 'star',
token_out: str = 'all',
target: str = 'SVD',
svd_solver: str = 'auto',
eigh_solver: str = 'auto',
embed_dim: Optional[int] = None,
sphere_norm: bool = True,
mag_head: bool = True,
cm_enabled: bool = True,
cm_points: int = 5):
super().__init__()
S, V, D = svd
self.S, self.V, self.D = S, V, D
if D > 128:
msg = f"D={D} > 128 β€” gram-based SVD will be very slow / OOM-prone."
if not bypass_crash:
raise RuntimeError(msg + " Pass bypass_crash=True to override.")
print(f"⚠ {msg}")
if embed_dim is None:
embed_dim = V * D # default: same dim as flattened SVD matrix
self.embed_dim = embed_dim
self.token_out = (token_out or 'all').lower()
cells = []
for i in range(depth):
cell_in = in_dim if i == 0 else embed_dim
cells.append(SVDCell(
in_dim=cell_in, S=S, V=V, D=D, embed_dim=embed_dim,
hidden_size=hidden_size, encode=encode,
activation=activation, geo_activation=geo_activation,
target=target, attention_layers=attention_layers,
heads=heads, svd_solver=svd_solver, eigh_solver=eigh_solver,
sphere_norm=sphere_norm, mag_head=mag_head,
cm_enabled=cm_enabled, cm_points=cm_points,
))
self.cells = nn.ModuleList(cells)
# QKV projection (only used when token_out="QKV")
if self.token_out == 'qkv':
self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
def forward(self, x: torch.Tensor,
y: Optional[torch.Tensor] = None,
z: Optional[Union[torch.Tensor, dict, list]] = None):
"""
x: (B, S, in_dim) β€” input token sequence
y: optional mask tensor (reserved; not yet wired into QKV/SUVt logic)
z: experimentation hooks (passed through; not yet consumed)
Returns one of:
token_out="all" (default) β†’ (B, S, embed_dim)
token_out="QKV" β†’ (Q, K, V) tuple, each (B, S, embed_dim)
token_out="SUVt"/"SUV" β†’ (U, S_vals, Vt) raw geometric tokens
from the last cell's SVD
"""
for cell in self.cells:
x = cell(x) # (B, S, embed_dim)
if self.token_out == 'qkv':
qkv = self.qkv_proj(x)
q, k, v = qkv.chunk(3, dim=-1)
return q, k, v
if self.token_out in ('suvt', 'suv'):
            # Return raw SVD components from the last cell — pre-attention
            # would need to be tapped earlier; this returns post-attention SVD.
U, Sv, Vt = self.cells[-1]._last_svd
B, S = x.shape[:2]
U = U.reshape(B, S, self.V, self.D)
Sv = Sv.reshape(B, S, self.D)
Vt = Vt.reshape(B, S, self.D, self.D)
return U, Sv, Vt
return x
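# token_out variants (illustrative; exercised in self-test [4] below):
#   q, k, v = svd_transformer(x, svd=(16, 8, 4), token_out='QKV')(x)
#   U, Sv, Vt = svd_transformer(x, svd=(16, 8, 4), token_out='SUVt')(x)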
# ────────────────────────────────────────────────────────────────────────
# Functional wrapper matching the user-provided API spec
# ────────────────────────────────────────────────────────────────────────
def svd_transformer(x: torch.Tensor,
y: Optional[torch.Tensor] = None,
z: Optional[Union[torch.Tensor, dict, list]] = None,
*,
svd: Optional[Tuple[int, int, int]] = None,
bypass_crash: bool = True,
heads: int = 64,
hidden_size: int = 4,
depth: int = 4,
encode: str = 'mlp',
attention_layers: int = 2,
activation: str = 'gelu',
geo_activation: str = 'star',
token_out: str = 'all',
target: str = 'SVD',
svd_solver: str = 'auto',
eigh_solver: str = 'auto',
embed_dim: Optional[int] = None,
sphere_norm: bool = True,
mag_head: bool = True,
cm_enabled: bool = True,
cm_points: int = 5) -> SVDTransformer:
"""
Functional API matching user-provided spec. Returns an SVDTransformer
initialized from x's shape; caller invokes it via former(x).
Shape inference for `svd=None`:
x.shape = (B, S, F) β†’ svd = (S, V, D) using sqrt(F) if F is a perfect
square, else (S, 8, 4) fallback
x.shape = (B, C, H, W) β†’ raises (caller must patchify or pass svd=)
Extra params (beyond original API spec, for magnitude handling + degeneracy):
sphere_norm : F.normalize(M, dim=-1) before SVD (default True). Critical
for training stability β€” otherwise row magnitudes are
unbounded and SVD produces extreme singular values on
occasional batches, driving attention softmax collapse.
mag_head : add a fourth readout head on captured row magnitudes
(default True, requires sphere_norm).
cm_enabled : Cayley-Menger validator on 5-point row subset. Caches
pentachoron volΒ² on each cell for degeneracy monitoring
/ optional loss regularization. Default True.
cm_points : vertices per simplex (default 5 = pentachoron).
Returns the SVDTransformer module. Apply it with `former(x)`.
"""
if svd is None:
if x.ndim == 3:
B, S, F = x.shape
sq = int(F ** 0.5)
if sq * sq == F:
V, D = sq, sq
else:
V, D = 8, 4
svd_param = (S, V, D)
elif x.ndim == 4:
raise ValueError(
"svd_transformer with svd=None requires pre-tokenized input "
"(B, S, F). For images, patchify first or pass svd=(S, V, D)."
)
else:
raise ValueError(f"x.shape must be (B, S, F) or (B, C, H, W); got {tuple(x.shape)}")
else:
svd_param = tuple(svd)
in_dim = x.shape[-1]
return SVDTransformer(
in_dim=in_dim, svd=svd_param, bypass_crash=bypass_crash,
heads=heads, hidden_size=hidden_size, depth=depth,
encode=encode, attention_layers=attention_layers,
activation=activation, geo_activation=geo_activation,
token_out=token_out, target=target,
svd_solver=svd_solver, eigh_solver=eigh_solver,
embed_dim=embed_dim,
sphere_norm=sphere_norm, mag_head=mag_head,
cm_enabled=cm_enabled, cm_points=cm_points,
)
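# Hedged sketch: one way to fold the cached CM vol² into a training loss as a
# degeneracy penalty. The margin 1e-4 and weight 1e-3 are illustrative only.
#   former = svd_transformer(x, svd=(16, 8, 4))
#   out = former(x)
#   task_loss = out.pow(2).mean()                    # stand-in objective
#   cm_penalty = sum(F.relu(1e-4 - c._last_cm_vol2).mean()
#                    for c in former.cells if c._last_cm_vol2 is not None)
#   loss = task_loss + 1e-3 * cm_penalty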
# ────────────────────────────────────────────────────────────────────────
# Self-test (smoke check; runs only under `python prototype_transformer.py`)
# ────────────────────────────────────────────────────────────────────────
if __name__ == '__main__':
print("\n" + "=" * 72)
print("SVDTransformer prototype self-test")
print("=" * 72)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(0)
# --- Test 1: default config ---
print("\n[1] Default config: svd=(16, 8, 4), depth=4, encode='mlp'")
x = torch.randn(2, 16, 32, device=device) # (B=2, S=16, F=32)
former = svd_transformer(x, svd=(16, 8, 4))
former = former.to(device)
out = former(x)
n_params = sum(p.numel() for p in former.parameters())
print(f" in shape={tuple(x.shape)} out shape={tuple(out.shape)} params={n_params:,}")
assert out.shape == (2, 16, 32), f"Expected (2,16,32), got {out.shape}"
# --- Test 2: each encoder type ---
print("\n[2] All encoder types:")
for enc in _ENCODERS:
m = svd_transformer(x, svd=(16, 8, 4), encode=enc, depth=1, attention_layers=1).to(device)
try:
o = m(x)
print(f" encode={enc:12s} β†’ out={tuple(o.shape)} params={sum(p.numel() for p in m.parameters()):,}")
except Exception as e:
print(f" encode={enc:12s} βœ— {type(e).__name__}: {e}")
# --- Test 3: each target ---
print("\n[3] All target options:")
for tgt in ['SVD', 'VD', 'SV', 'S', 'V']:
m = svd_transformer(x, svd=(16, 8, 4), target=tgt, depth=1, attention_layers=0).to(device)
o = m(x)
# Count how many heads will receive nonzero gradient
loss = o.sum()
loss.backward()
head_grads = {
'S': m.cells[0].s_head.weight.grad.norm().item() if m.cells[0].s_head.weight.grad is not None else 0,
'U': m.cells[0].u_head.weight.grad.norm().item() if m.cells[0].u_head.weight.grad is not None else 0,
'Vt': m.cells[0].vt_head.weight.grad.norm().item() if m.cells[0].vt_head.weight.grad is not None else 0,
}
active = [k for k, v in head_grads.items() if v > 1e-9]
print(f" target={tgt:4s} β†’ active heads={active} out={tuple(o.shape)}")
# --- Test 4: each token_out format ---
print("\n[4] All token_out formats:")
for to in ['all', 'QKV', 'SUVt']:
m = svd_transformer(x, svd=(16, 8, 4), token_out=to, depth=1, attention_layers=0).to(device)
o = m(x)
if isinstance(o, tuple):
shapes = [tuple(t.shape) for t in o]
print(f" token_out={to:5s} β†’ {len(o)} tensors, shapes={shapes}")
else:
print(f" token_out={to:5s} β†’ out={tuple(o.shape)}")
# --- Test 5: SVD orthogonality on a real model M (post-encoder) ---
print("\n[5] SVD orthogonality check (no row centering):")
m = svd_transformer(x, svd=(16, 8, 4), depth=1, attention_layers=0).to(device)
with torch.no_grad():
encoded = m.cells[0].encoder(x)
BS = 2 * 16
M = encoded.reshape(BS, 8, 4)
rm = M[0].mean(dim=-1)[:3].tolist()
print(f" M not centered: row_means[0,:3] = [{rm[0]:.4f},{rm[1]:.4f},{rm[2]:.4f}]")
U, Sv, Vt = _svd_dispatch(M)
I_D = torch.eye(4, device=device).expand(BS, 4, 4)
u_orth = (torch.bmm(U.transpose(1, 2), U) - I_D).abs().max().item()
v_orth = (torch.bmm(Vt, Vt.transpose(1, 2)) - I_D).abs().max().item()
recon = (torch.bmm(U * Sv.unsqueeze(1), Vt) - M).abs().max().item()
print(f" ||U^T U - I|| = {u_orth:.2e} {'βœ“' if u_orth < 1e-3 else 'βœ—'}")
print(f" ||Vt Vt^T - I|| = {v_orth:.2e} {'βœ“' if v_orth < 1e-3 else 'βœ—'}")
print(f" reconstruction = {recon:.2e} {'βœ“' if recon < 1e-4 else 'βœ—'}")
# --- Test 6: backward pass (gradient flows through SVD) ---
print("\n[6] Backward pass (gradient flow through SVD):")
m = svd_transformer(x, svd=(16, 8, 4), depth=2, attention_layers=1).to(device)
out = m(x)
loss = out.pow(2).mean()
loss.backward()
enc_grad = sum(
p.grad.norm().item() ** 2
for p in m.cells[0].encoder.parameters() if p.grad is not None
) ** 0.5
print(f" loss = {loss.item():.4f}")
print(f" cell[0].encoder grad_norm = {enc_grad:.4e} "
f"{'βœ“ flowing through SVD into encoder' if enc_grad > 0 else 'βœ—'}")
# --- Test 7: solver dispatch combinations ---
print("\n[7] Solver dispatch combinations:")
for ssolver, esolver in [('auto', 'auto'), ('torch', 'auto'),
('auto', 'fl'), ('auto', 'torch')]:
try:
m = svd_transformer(x, svd=(16, 8, 4),
svd_solver=ssolver, eigh_solver=esolver,
depth=1, attention_layers=0).to(device)
o = m(x)
print(f" svd={ssolver:6s} eigh={esolver:6s} β†’ ok out={tuple(o.shape)}")
except Exception as e:
print(f" svd={ssolver:6s} eigh={esolver:6s} β†’ {type(e).__name__}: {e}")
# --- Test 8: bypass_crash for D > 128 ---
print("\n[8] D-too-large guard:")
try:
m = svd_transformer(torch.randn(2, 4, 200, device=device),
svd=(4, 200, 200), bypass_crash=False, depth=1, attention_layers=0)
print(f" bypass_crash=False with D=200 β†’ βœ— (should have raised)")
except RuntimeError as e:
print(f" bypass_crash=False with D=200 β†’ βœ“ raised: {str(e)[:60]}...")
print("\n" + "=" * 72)
print("All smoke tests complete.")
print("=" * 72)