Create prototype_transformer.py
prototype_transformer.py (ADDED, +749 -0)
"""
SVD Transformer Prototype
=================================================================
Standalone prototype matching the user-provided API spec. Combines:

- SpectralProbe lineage's three-head SVD readout (S, U, Vt → embed)
- Correct geolip imports: geolip_core registers the 'geolip' alias, then
  geolip.linalg as LA, then FLEigh from geolip_core.linalg.eigh
- NO row centering (verified bug — gram-based SVD goes degenerate)
- Configurable encoder (mlp/transformer/conv/film/ffn/rotary/lstm/gru)
- Configurable geometric activation (star=ReLU² default)
- Configurable attention layers between SVD passes
- Configurable depth (stacked SVD cells)
- Three-head selection via `target` ({SVD, VD, SV, S, V})
- Three output formats via `token_out` ({all, QKV, SUVt})
- Solver dispatch: svd_solver={auto, torch, triton}, eigh_solver={auto, torch, fl}

API parameter interpretations (clarify if wrong):
    svd=[S, V, D]     → S = sequence/slot count, V/D = SVD matrix dims
    target="SVD"      → all three heads active (S, U, Vt)
    target="VD"       → U + Vt only (drop singular values)
    target="SV"       → S + U only (drop right basis)
    target="S"/"V"    → single head
    token_out="all"   → return (B, S, embed_dim) sequence
    token_out="QKV"   → return (Q, K, V) tuple after QKV projection
    token_out="SUVt"  → return (U, S_vals, Vt) of the LAST cell's SVD
    depth=N           → stack N independent SVD cells

Lineage: AbstractPhil / SpectralProbe → CIFAR-10 (53.7% with 13.6k params)
"""
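
# Quick start (an illustrative sketch; it mirrors self-test [1] at the bottom
# of this file, so the shapes below are exercised rather than assumed):
#
#   x = torch.randn(2, 16, 32)                   # (B, S, F)
#   former = svd_transformer(x, svd=(16, 8, 4))  # S=16 slots, 8x4 matrices per token
#   out = former(x)                              # (2, 16, 32); embed_dim defaults to V*D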

import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


# ────────────────────────────────────────────────────────────────────────
# geolip imports — CORRECT ORDER (geolip_core triggers sys.modules alias)
# ────────────────────────────────────────────────────────────────────────

try:
    import geolip_core  # noqa: F401  registers 'geolip' alias in sys.modules
    import geolip  # now resolvable
    import geolip.linalg as LA  # main dispatcher
    from geolip_core.linalg.eigh import FLEigh
    _HAS_GEOLIP = True
    print(f"✓ geolip {geolip.__version__} — using LA.svd + FLEigh")
    LA.backend.status()
except ImportError as e:
    _HAS_GEOLIP = False
    LA = None
    FLEigh = None
    print(f"✗ geolip_core not installed ({e}) — torch.linalg fallback")


# ────────────────────────────────────────────────────────────────────────
# Activations (regular + geometric)
# ────────────────────────────────────────────────────────────────────────

class StarActivation(nn.Module):
    """ReLU² — squared positive activation. All-positive output."""
    def forward(self, x):
        return F.relu(x).pow(2)


_GEO_ACTS = {
    'star': lambda: StarActivation(),
    'relu': lambda: nn.ReLU(),
    'gelu': lambda: nn.GELU(),
    'silu': lambda: nn.SiLU(),
    'swilu': lambda: nn.SiLU(),  # alias of silu
    'tanh': lambda: nn.Tanh(),
    'sigmoid': lambda: nn.Sigmoid(),
    'leaky_relu': lambda: nn.LeakyReLU(0.01),
}

_REG_ACTS = {
    'gelu': lambda: nn.GELU(),
    'relu': lambda: nn.ReLU(),
    'silu': lambda: nn.SiLU(),
    'tanh': lambda: nn.Tanh(),
    'leaky_relu': lambda: nn.LeakyReLU(0.01),
}


def make_geo_activation(name: str) -> nn.Module:
    name = (name or 'star').lower()
    if name not in _GEO_ACTS:
        raise ValueError(f"Unknown geo_activation: {name!r}; options: {list(_GEO_ACTS)}")
    return _GEO_ACTS[name]()


def make_activation(name: str) -> nn.Module:
    name = (name or 'gelu').lower()
    if name not in _REG_ACTS:
        raise ValueError(f"Unknown activation: {name!r}; options: {list(_REG_ACTS)}")
    return _REG_ACTS[name]()


def _act_name_for_pytorch(name: str) -> str:
    """nn.TransformerEncoderLayer accepts 'gelu'/'relu' strings; map our names."""
    name = (name or 'gelu').lower()
    return name if name in ('gelu', 'relu') else 'gelu'


# ────────────────────────────────────────────────────────────────────────
# Encoder variants — apply per-token before SVD reshape
# ────────────────────────────────────────────────────────────────────────

def _fit_heads(d: int, target: int) -> int:
    """Pick a head count that divides d evenly (target preferred)."""
    for h in [target, target // 2, target // 4, 8, 4, 2, 1]:
        if h > 0 and d % h == 0:
            return h
    return 1
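
# Worked example of the divisor search above (values traced by hand, not taken
# from the original source):
#   _fit_heads(32, 64) → 32   (64 does not divide 32; target // 2 = 32 does)
#   _fit_heads(30, 8)  → 2    (8 and 4 both fail; 2 divides 30)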


class MLPEncoder(nn.Module):
    """encode='mlp' (default) — two-layer MLP per token."""
    def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
        super().__init__()
        # hidden_size is the API's "Internal MLP hidden size" — small (default 4).
        # Don't let it bottleneck; ensure at least max(in, out)/2.
        h = max(hidden_size, max(in_dim, out_dim) // 2, 8)
        self.net = nn.Sequential(
            nn.Linear(in_dim, h),
            make_activation(activation),
            nn.Linear(h, out_dim),
        )

    def forward(self, x):
        return self.net(x)


class FFNEncoder(nn.Module):
    """encode='ffn' — transformer-style 4× expansion FFN."""
    def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
        super().__init__()
        h = max(hidden_size, 4 * out_dim)
        self.net = nn.Sequential(
            nn.Linear(in_dim, h),
            make_activation(activation),
            nn.Linear(h, out_dim),
        )

    def forward(self, x):
        return self.net(x)


class FiLMEncoder(nn.Module):
    """encode='film' — feature-wise affine modulation."""
    def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
        super().__init__()
        self.skip = nn.Linear(in_dim, out_dim)
        self.gamma = nn.Linear(in_dim, out_dim)
        self.beta = nn.Linear(in_dim, out_dim)
        self.act = make_activation(activation)

    def forward(self, x):
        skip = self.skip(x)
        return self.act(skip * (1.0 + self.gamma(x)) + self.beta(x))


class ConvEncoder(nn.Module):
    """encode='conv' — 1D conv across the token sequence."""
    def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        self.conv = nn.Conv1d(out_dim, out_dim, kernel_size=3, padding=1)
        self.act = make_activation(activation)

    def forward(self, x):  # (B, S, in_dim)
        x = self.proj(x)
        x = x.transpose(1, 2)      # (B, out_dim, S)
        x = self.act(self.conv(x))
        return x.transpose(1, 2)   # (B, S, out_dim)


class TransformerEncoder(nn.Module):
    """encode='transformer' — single transformer encoder layer pre-SVD."""
    def __init__(self, in_dim, out_dim, hidden_size, activation, n_heads=4, **_):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        h = _fit_heads(out_dim, n_heads)
        self.layer = nn.TransformerEncoderLayer(
            d_model=out_dim, nhead=h,
            dim_feedforward=max(hidden_size, 4 * out_dim),
            activation=_act_name_for_pytorch(activation),
            batch_first=True, norm_first=True,
        )

    def forward(self, x):
        return self.layer(self.proj(x))


class LSTMEncoder(nn.Module):
    """encode='lstm' — sequential LSTM."""
    def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, out_dim, batch_first=True)
        self.act = make_activation(activation)

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.act(out)


class GRUEncoder(nn.Module):
    """encode='gru' — sequential GRU."""
    def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
        super().__init__()
        self.gru = nn.GRU(in_dim, out_dim, batch_first=True)
        self.act = make_activation(activation)

    def forward(self, x):
        out, _ = self.gru(x)
        return self.act(out)


class RotaryEncoder(nn.Module):
    """encode='rotary' — projection then rotary positional embedding."""
    def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        self.dim = out_dim
        self.act = make_activation(activation)

    def forward(self, x):
        x = self.proj(x)  # (B, S, out_dim)
        B, S, D = x.shape
        d_half = D // 2
        if d_half == 0:
            return self.act(x)
        positions = torch.arange(S, device=x.device, dtype=x.dtype).unsqueeze(0)
        freqs = torch.exp(torch.arange(d_half, device=x.device, dtype=x.dtype)
                          * (-math.log(10000.0) / d_half))
        angles = positions.unsqueeze(-1) * freqs.unsqueeze(0)  # (1, S, d_half)
        cos, sin = angles.cos(), angles.sin()
        x1 = x[..., :d_half]
        x2 = x[..., d_half:2 * d_half]
        rotated_1 = x1 * cos - x2 * sin
        rotated_2 = x1 * sin + x2 * cos
        if D % 2 == 1:
            tail = x[..., 2 * d_half:]
            x = torch.cat([rotated_1, rotated_2, tail], dim=-1)
        else:
            x = torch.cat([rotated_1, rotated_2], dim=-1)
        return self.act(x)


_ENCODERS = {
    'mlp': MLPEncoder,
    'ffn': FFNEncoder,
    'film': FiLMEncoder,
    'conv': ConvEncoder,
    'transformer': TransformerEncoder,
    'lstm': LSTMEncoder,
    'gru': GRUEncoder,
    'rotary': RotaryEncoder,
}


def build_encoder(encode, in_dim, out_dim, hidden_size, activation):
    enc = (encode or 'mlp').lower()
    if enc not in _ENCODERS:
        raise ValueError(f"Unknown encode={encode!r}; options: {list(_ENCODERS)}")
    return _ENCODERS[enc](in_dim, out_dim, hidden_size, activation)
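
# Sketch of the shared encoder contract (hypothetical values): every encoder
# maps (B, S, in_dim) → (B, S, out_dim), which is what lets SVDCell swap them
# freely.
#   enc = build_encoder('film', in_dim=32, out_dim=8 * 4, hidden_size=4,
#                       activation='gelu')
#   enc(torch.randn(2, 16, 32)).shape   # → torch.Size([2, 16, 32])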


# ────────────────────────────────────────────────────────────────────────
# SVD dispatch — auto-route to fastest available correct backend
# ────────────────────────────────────────────────────────────────────────

def _svd_dispatch(M: torch.Tensor,
                  svd_solver: str = 'auto',
                  eigh_solver: str = 'auto'
                  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    M: (BS, V, D) — batch of matrices to decompose.
    Returns: U (BS, V, D), S_vals (BS, D), Vt (BS, D, D)

    Dispatch logic (ALL paths produce thin SVD with descending singular values):

        no geolip           → torch.linalg.svd in fp64 (rank-deficient-safe)
        svd_solver='torch'  → torch.linalg.svd in fp64
        svd_solver='triton' → LA.svd(method='triton') — D ≤ 6 fp64 only
        eigh_solver='fl'    → custom gram + FLEigh (compiles up to D=12)
        auto/auto           → LA.svd default dispatch (best per backend)

    NEVER row-center M before this — the gram path produces garbage U for
    rank-deficient inputs. Verified bug across both this implementation and
    geolip's gram_eigh path. The production SVDObserver in geolip_core also
    avoids centering for this reason.
    """
    # --- Fallback: no geolip
    if not _HAS_GEOLIP:
        with torch.amp.autocast('cuda', enabled=False):
            U, Sv, Vt = torch.linalg.svd(M.double(), full_matrices=False)
        return U.float(), Sv.float(), Vt.float()

    # --- Explicit torch path
    if svd_solver == 'torch':
        with torch.amp.autocast('cuda', enabled=False):
            U, Sv, Vt = torch.linalg.svd(M.double(), full_matrices=False)
        return U.float(), Sv.float(), Vt.float()

    # --- Explicit FL eigh path (custom gram + FLEigh, more accurate than torch.linalg.eigh)
    if eigh_solver == 'fl':
        return _gram_fl_eigh_svd(M)

    # --- Triton path
    if svd_solver == 'triton':
        try:
            return LA.svd(M, method='triton')
        except Exception as exc:
            print(f"  ⚠ triton SVD failed ({exc}); falling back to LA.svd default")
            return LA.svd(M)

    # --- Default: LA.svd auto-dispatch (FL eigh on CUDA D≤12, torch otherwise)
    return LA.svd(M)


def _gram_fl_eigh_svd(M: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Custom gram + FL eigh SVD. Uses geolip_core.linalg.eigh.FLEigh — the
    Faddeev-LeVerrier polynomial + Laguerre roots + Newton-Schulz pipeline.
    More accurate than torch.linalg.eigh on ill-conditioned grams.

    Compiles up to D=12 on CUDA. For larger D, use _svd_dispatch with default
    auto routing (which will pick torch.linalg.svd).
    """
    if FLEigh is None:
        raise RuntimeError("FLEigh unavailable — geolip_core not installed")
    orig_dtype = M.dtype
    A = M.float()  # FL eigh runs in float
    G = torch.bmm(A.transpose(1, 2), A)  # (BS, D, D), symmetric PSD
    eigenvalues, V = FLEigh()(G)
    # eigh returns ascending; we want descending singular values
    eigenvalues = eigenvalues.flip(-1)
    V = V.flip(-1)
    Sv = torch.sqrt(eigenvalues.clamp(min=1e-12))
    U = torch.bmm(A, V) / Sv.unsqueeze(1).clamp(min=1e-8)
    Vh = V.transpose(-2, -1).contiguous()
    return U.to(orig_dtype), Sv.to(orig_dtype), Vh.to(orig_dtype)
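
# Why the gram trick recovers the SVD (standard identity, noted here for
# reference): if M = U Σ Vᵀ, then G = MᵀM = V Σ² Vᵀ, so eigh(G) yields V and
# the squared singular values directly, and U follows as U = M V Σ⁻¹; the
# clamps above guard that division for near-zero singular values.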


# ────────────────────────────────────────────────────────────────────────
# Image patcher (helper for image inputs; not part of svd_transformer itself)
# ────────────────────────────────────────────────────────────────────────

class TensorPatcher(nn.Module):
    """(B, C, H, W) → (B, N, C·ph·pw). Pure reshape, no learned params."""
    def __init__(self, input_shape, patch_size):
        super().__init__()
        C, H, W = input_shape
        ph = pw = patch_size
        assert H % ph == 0 and W % pw == 0
        self.C, self.H, self.W = C, H, W
        self.ph, self.pw = ph, pw
        self.n_patches = (H // ph) * (W // pw)
        self.patch_dim = C * ph * pw

    def forward(self, x):
        B, C, H, W = x.shape
        ph, pw = self.ph, self.pw
        gh, gw = H // ph, W // pw
        p = x.reshape(B, C, gh, ph, gw, pw)
        p = p.permute(0, 2, 4, 1, 3, 5).contiguous()
        return p.reshape(B, gh * gw, -1)
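
# Pairing sketch (illustrative numbers, not from the original source): a
# CIFAR-sized image patchified into 16 tokens that the functional API below
# can consume directly.
#   patcher = TensorPatcher((3, 32, 32), patch_size=8)  # 16 patches, patch_dim=192
#   tokens = patcher(torch.randn(2, 3, 32, 32))         # (2, 16, 192)
#   former = svd_transformer(tokens, svd=(16, 8, 4))
#   out = former(tokens)                                # (2, 16, 32)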


# ────────────────────────────────────────────────────────────────────────
# SVD Cell — one cycle: encode → SVD → three heads → attention
# ────────────────────────────────────────────────────────────────────────

class SVDCell(nn.Module):
    """
    One cycle of the architecture:

        tokens (B, S, in_dim)
          → encode (mlp/conv/transformer/...)        [out: (B, S, V·D)]
          → reshape                                  [out: (B·S, V, D)]
          → SVD via geolip                           [out: U(BS,V,D), S(BS,D), Vt(BS,D,D)]
          → three-head readout (target masks heads)  [out: (B·S, embed_dim)]
          → geo_activation                           [out: (B·S, embed_dim)]
          → reshape                                  [out: (B, S, embed_dim)]
          → attention_layers × TransformerEncoderLayer
          → LayerNorm
          → tokens (B, S, embed_dim)

    The SVD components (U, S_vals, Vt) of the last forward pass are cached
    on `self._last_svd` for token_out="SUVt" extraction.
    """
    _TARGET_TO_MASK = {
        'SVD': (True, True, True),    # all three heads
        'VD': (False, True, True),    # U + Vt only
        'SV': (True, True, False),    # S + U only
        'S': (True, False, False),    # singular values only
        'V': (False, True, False),    # U (left vectors) only
    }

    def __init__(self, *, in_dim, S, V, D, embed_dim, hidden_size,
                 encode, activation, geo_activation, target,
                 attention_layers, heads, svd_solver, eigh_solver):
        super().__init__()
        self.S, self.V, self.D = S, V, D
        self.embed_dim = embed_dim
        self.target = (target or 'SVD').upper()
        self.svd_solver = svd_solver
        self.eigh_solver = eigh_solver
        if self.target not in self._TARGET_TO_MASK:
            raise ValueError(f"Unknown target={target!r}; options: {list(self._TARGET_TO_MASK)}")

        mat_dim = V * D

        # Encoder: tokens (B, S, in_dim) → (B, S, V*D)
        self.encoder = build_encoder(encode, in_dim, mat_dim, hidden_size, activation)

        # Three head linears (all instantiated; mask gates which contribute)
        self.s_head = nn.Linear(D, embed_dim)
        self.u_head = nn.Linear(V * D, embed_dim)
        self.vt_head = nn.Linear(D * D, embed_dim)

        self.geo_act = make_geo_activation(geo_activation)

        # Attention stack (post-SVD)
        if attention_layers > 0:
            n_h = _fit_heads(embed_dim, heads)
            layer = nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=n_h,
                dim_feedforward=4 * embed_dim,
                activation=_act_name_for_pytorch(activation),
                batch_first=True, norm_first=True,
            )
            self.attention = nn.TransformerEncoder(layer, num_layers=attention_layers)
        else:
            self.attention = nn.Identity()

        self.norm = nn.LayerNorm(embed_dim)
        self._last_svd = None  # (U, S_vals, Vt) cache for token_out="SUVt"

    def forward(self, tokens):
        """tokens: (B, S, in_dim) → (B, S, embed_dim)"""
        B, S, _ = tokens.shape
        assert S == self.S, f"Expected S={self.S} tokens, got {S}"

        # Encode → V×D matrix per token (NO row-centering)
        encoded = self.encoder(tokens)  # (B, S, V*D)
        M = encoded.reshape(B * S, self.V, self.D)

        # SVD
        U, Sv, Vt = _svd_dispatch(M, self.svd_solver, self.eigh_solver)
        self._last_svd = (U, Sv, Vt)

        # Three-head readout (target gates which heads contribute)
        use_s, use_u, use_vt = self._TARGET_TO_MASK[self.target]
        token_feat = torch.zeros(B * S, self.embed_dim,
                                 device=tokens.device, dtype=tokens.dtype)
        if use_s:
            token_feat = token_feat + self.s_head(Sv)
        if use_u:
            token_feat = token_feat + self.u_head(U.reshape(B * S, -1))
        if use_vt:
            token_feat = token_feat + self.vt_head(Vt.reshape(B * S, -1))

        token_feat = self.geo_act(token_feat)
        token_feat = token_feat.reshape(B, S, self.embed_dim)

        # Attention layers
        token_feat = self.attention(token_feat)
        return self.norm(token_feat)


# ────────────────────────────────────────────────────────────────────────
# SVDTransformer — top-level module (depth × SVDCell)
# ────────────────────────────────────────────────────────────────────────

class SVDTransformer(nn.Module):
    """
    Stacked SVD cells with configurable encoder, attention, and head selection.

    First cell takes in_dim; subsequent cells take embed_dim. Each cell has its
    own encoder + SVD + attention substack; `depth` cells in sequence.
    """
    def __init__(self, *,
                 in_dim: int,
                 svd: Tuple[int, int, int] = (16, 8, 4),
                 bypass_crash: bool = True,
                 heads: int = 64,
                 hidden_size: int = 4,
                 depth: int = 4,
                 encode: str = 'mlp',
                 attention_layers: int = 2,
                 activation: str = 'gelu',
                 geo_activation: str = 'star',
                 token_out: str = 'all',
                 target: str = 'SVD',
                 svd_solver: str = 'auto',
                 eigh_solver: str = 'auto',
                 embed_dim: Optional[int] = None):
        super().__init__()
        S, V, D = svd
        self.S, self.V, self.D = S, V, D

        if D > 128:
            msg = f"D={D} > 128 — gram-based SVD will be very slow / OOM-prone."
            if not bypass_crash:
                raise RuntimeError(msg + " Pass bypass_crash=True to override.")
            print(f"⚠ {msg}")

        if embed_dim is None:
            embed_dim = V * D  # default: same dim as flattened SVD matrix
        self.embed_dim = embed_dim
        self.token_out = (token_out or 'all').lower()

        cells = []
        for i in range(depth):
            cell_in = in_dim if i == 0 else embed_dim
            cells.append(SVDCell(
                in_dim=cell_in, S=S, V=V, D=D, embed_dim=embed_dim,
                hidden_size=hidden_size, encode=encode,
                activation=activation, geo_activation=geo_activation,
                target=target, attention_layers=attention_layers,
                heads=heads, svd_solver=svd_solver, eigh_solver=eigh_solver,
            ))
        self.cells = nn.ModuleList(cells)

        # QKV projection (only used when token_out="QKV")
        if self.token_out == 'qkv':
            self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)

    def forward(self, x: torch.Tensor,
                y: Optional[torch.Tensor] = None,
                z: Optional[Union[torch.Tensor, dict, list]] = None):
        """
        x: (B, S, in_dim) — input token sequence
        y: optional mask tensor (reserved; not yet wired into QKV/SUVt logic)
        z: experimentation hooks (passed through; not yet consumed)

        Returns one of:
            token_out="all" (default) → (B, S, embed_dim)
            token_out="QKV"           → (Q, K, V) tuple, each (B, S, embed_dim)
            token_out="SUVt"/"SUV"    → (U, S_vals, Vt) raw geometric tokens
                                        from the last cell's SVD
        """
        for cell in self.cells:
            x = cell(x)  # (B, S, embed_dim)

        if self.token_out == 'qkv':
            qkv = self.qkv_proj(x)
            q, k, v = qkv.chunk(3, dim=-1)
            return q, k, v

        if self.token_out in ('suvt', 'suv'):
            # Return raw SVD components from the last cell — pre-attention
            # would need to be tapped earlier; this returns post-attention SVD.
            U, Sv, Vt = self.cells[-1]._last_svd
            B, S = x.shape[:2]
            U = U.reshape(B, S, self.V, self.D)
            Sv = Sv.reshape(B, S, self.D)
            Vt = Vt.reshape(B, S, self.D, self.D)
            return U, Sv, Vt

        return x
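
# Output-format sketch (shapes traced from the forward pass above for the
# default svd=(16, 8, 4), so embed_dim = 32; illustrative, not from the
# original source):
#   SVDTransformer(in_dim=32, token_out='QKV')(x)   # (Q, K, V), each (B, 16, 32)
#   SVDTransformer(in_dim=32, token_out='SUVt')(x)  # U (B,16,8,4), Sv (B,16,4), Vt (B,16,4,4)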


# ────────────────────────────────────────────────────────────────────────
# Functional wrapper matching the user-provided API spec
# ────────────────────────────────────────────────────────────────────────

def svd_transformer(x: torch.Tensor,
                    y: Optional[torch.Tensor] = None,
                    z: Optional[Union[torch.Tensor, dict, list]] = None,
                    *,
                    svd: Optional[Tuple[int, int, int]] = None,
                    bypass_crash: bool = True,
                    heads: int = 64,
                    hidden_size: int = 4,
                    depth: int = 4,
                    encode: str = 'mlp',
                    attention_layers: int = 2,
                    activation: str = 'gelu',
                    geo_activation: str = 'star',
                    token_out: str = 'all',
                    target: str = 'SVD',
                    svd_solver: str = 'auto',
                    eigh_solver: str = 'auto',
                    embed_dim: Optional[int] = None) -> SVDTransformer:
    """
    Functional API matching the user-provided spec. Returns an SVDTransformer
    initialized from x's shape; the caller invokes it via former(x).

    Shape inference for `svd=None`:
        x.shape = (B, S, F)    → svd = (S, V, D) using sqrt(F) if F is a
                                 perfect square, else (S, 8, 4) fallback
        x.shape = (B, C, H, W) → raises (caller must patchify or pass svd=)

    Returns the SVDTransformer module. Apply it with `former(x)`.
    """
    if svd is None:
        if x.ndim == 3:
            B, S, F = x.shape
            sq = int(F ** 0.5)
            if sq * sq == F:
                V, D = sq, sq
            else:
                V, D = 8, 4
            svd_param = (S, V, D)
        elif x.ndim == 4:
            raise ValueError(
                "svd_transformer with svd=None requires pre-tokenized input "
                "(B, S, F). For images, patchify first or pass svd=(S, V, D)."
            )
        else:
            raise ValueError(f"x.shape must be (B, S, F) or (B, C, H, W); got {tuple(x.shape)}")
    else:
        svd_param = tuple(svd)

    in_dim = x.shape[-1]
    return SVDTransformer(
        in_dim=in_dim, svd=svd_param, bypass_crash=bypass_crash,
        heads=heads, hidden_size=hidden_size, depth=depth,
        encode=encode, attention_layers=attention_layers,
        activation=activation, geo_activation=geo_activation,
        token_out=token_out, target=target,
        svd_solver=svd_solver, eigh_solver=eigh_solver,
        embed_dim=embed_dim,
    )
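
# Shape-inference sketch (hypothetical input; walks the svd=None branch above):
#   x = torch.randn(2, 10, 16)     # F=16 is a perfect square → V = D = 4
#   former = svd_transformer(x)    # svd inferred as (10, 4, 4)
#   out = former(x)                # (2, 10, 16), since embed_dim defaults to V*D = 16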


# ────────────────────────────────────────────────────────────────────────
# Self-test (smoke check; runs only when this file is executed directly)
# ────────────────────────────────────────────────────────────────────────

if __name__ == '__main__':
    print("\n" + "=" * 72)
    print("SVDTransformer prototype self-test")
    print("=" * 72)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(0)

    # --- Test 1: default config ---
    print("\n[1] Default config: svd=(16, 8, 4), depth=4, encode='mlp'")
    x = torch.randn(2, 16, 32, device=device)  # (B=2, S=16, F=32)
    former = svd_transformer(x, svd=(16, 8, 4))
    former = former.to(device)
    out = former(x)
    n_params = sum(p.numel() for p in former.parameters())
    print(f"  in shape={tuple(x.shape)}  out shape={tuple(out.shape)}  params={n_params:,}")
    assert out.shape == (2, 16, 32), f"Expected (2,16,32), got {out.shape}"

    # --- Test 2: each encoder type ---
    print("\n[2] All encoder types:")
    for enc in _ENCODERS:
        m = svd_transformer(x, svd=(16, 8, 4), encode=enc, depth=1, attention_layers=1).to(device)
        try:
            o = m(x)
            print(f"  encode={enc:12s} → out={tuple(o.shape)} params={sum(p.numel() for p in m.parameters()):,}")
        except Exception as e:
            print(f"  encode={enc:12s} → {type(e).__name__}: {e}")

    # --- Test 3: each target ---
    print("\n[3] All target options:")
    for tgt in ['SVD', 'VD', 'SV', 'S', 'V']:
        m = svd_transformer(x, svd=(16, 8, 4), target=tgt, depth=1, attention_layers=0).to(device)
        o = m(x)
        # Count how many heads will receive nonzero gradient
        loss = o.sum()
        loss.backward()
        head_grads = {
            'S': m.cells[0].s_head.weight.grad.norm().item() if m.cells[0].s_head.weight.grad is not None else 0,
            'U': m.cells[0].u_head.weight.grad.norm().item() if m.cells[0].u_head.weight.grad is not None else 0,
            'Vt': m.cells[0].vt_head.weight.grad.norm().item() if m.cells[0].vt_head.weight.grad is not None else 0,
        }
        active = [k for k, v in head_grads.items() if v > 1e-9]
        print(f"  target={tgt:4s} → active heads={active} out={tuple(o.shape)}")

    # --- Test 4: each token_out format ---
    print("\n[4] All token_out formats:")
    for to in ['all', 'QKV', 'SUVt']:
        m = svd_transformer(x, svd=(16, 8, 4), token_out=to, depth=1, attention_layers=0).to(device)
        o = m(x)
        if isinstance(o, tuple):
            shapes = [tuple(t.shape) for t in o]
            print(f"  token_out={to:5s} → {len(o)} tensors, shapes={shapes}")
        else:
            print(f"  token_out={to:5s} → out={tuple(o.shape)}")

    # --- Test 5: SVD orthogonality on a real model M (post-encoder) ---
    print("\n[5] SVD orthogonality check (no row centering):")
    m = svd_transformer(x, svd=(16, 8, 4), depth=1, attention_layers=0).to(device)
    with torch.no_grad():
        encoded = m.cells[0].encoder(x)
        BS = 2 * 16
        M = encoded.reshape(BS, 8, 4)
        rm = M[0].mean(dim=-1)[:3].tolist()
        print(f"  M not centered: row_means[0,:3] = [{rm[0]:.4f},{rm[1]:.4f},{rm[2]:.4f}]")
        U, Sv, Vt = _svd_dispatch(M)
        I_D = torch.eye(4, device=device).expand(BS, 4, 4)
        u_orth = (torch.bmm(U.transpose(1, 2), U) - I_D).abs().max().item()
        v_orth = (torch.bmm(Vt, Vt.transpose(1, 2)) - I_D).abs().max().item()
        recon = (torch.bmm(U * Sv.unsqueeze(1), Vt) - M).abs().max().item()
    print(f"  ||U^T U - I||   = {u_orth:.2e} {'✓' if u_orth < 1e-3 else '✗'}")
    print(f"  ||Vt Vt^T - I|| = {v_orth:.2e} {'✓' if v_orth < 1e-3 else '✗'}")
    print(f"  reconstruction  = {recon:.2e} {'✓' if recon < 1e-4 else '✗'}")

    # --- Test 6: backward pass (gradient flows through SVD) ---
    print("\n[6] Backward pass (gradient flow through SVD):")
    m = svd_transformer(x, svd=(16, 8, 4), depth=2, attention_layers=1).to(device)
    out = m(x)
    loss = out.pow(2).mean()
    loss.backward()
    enc_grad = sum(
        p.grad.norm().item() ** 2
        for p in m.cells[0].encoder.parameters() if p.grad is not None
    ) ** 0.5
    print(f"  loss = {loss.item():.4f}")
    print(f"  cell[0].encoder grad_norm = {enc_grad:.4e} "
          f"{'✓ flowing through SVD into encoder' if enc_grad > 0 else '✗'}")

    # --- Test 7: solver dispatch combinations ---
    print("\n[7] Solver dispatch combinations:")
    for ssolver, esolver in [('auto', 'auto'), ('torch', 'auto'),
                             ('auto', 'fl'), ('auto', 'torch')]:
        try:
            m = svd_transformer(x, svd=(16, 8, 4),
                                svd_solver=ssolver, eigh_solver=esolver,
                                depth=1, attention_layers=0).to(device)
            o = m(x)
            print(f"  svd={ssolver:6s} eigh={esolver:6s} → ok  out={tuple(o.shape)}")
        except Exception as e:
            print(f"  svd={ssolver:6s} eigh={esolver:6s} → {type(e).__name__}: {e}")

    # --- Test 8: bypass_crash for D > 128 ---
    print("\n[8] D-too-large guard:")
    try:
        m = svd_transformer(torch.randn(2, 4, 200, device=device),
                            svd=(4, 200, 200), bypass_crash=False, depth=1, attention_layers=0)
        print("  bypass_crash=False with D=200 → ✗ (should have raised)")
    except RuntimeError as e:
        print(f"  bypass_crash=False with D=200 → ✓ raised: {str(e)[:60]}...")

    print("\n" + "=" * 72)
    print("All smoke tests complete.")
    print("=" * 72)