AbstractPhil committed on
Commit 428e0c8 · verified · 1 Parent(s): 150fa15

Create prototype_transformer.py

Files changed (1)
  1. prototype_transformer.py +749 -0
prototype_transformer.py ADDED
@@ -0,0 +1,749 @@
1
+ """
2
+ SVD Transformer Prototype
3
+ =================================================================
4
+ Standalone prototype matching user-provided API spec. Combines:
5
+
6
+ - SpectralProbe lineage's three-head SVD readout (S, U, Vt β†’ embed)
7
+ - Correct geolip imports: geolip_core registers 'geolip' alias, then
8
+ geolip.linalg as LA, then FLEigh from geolip_core.linalg.eigh
9
+ - NO row centering (verified bug β€” gram-based SVD goes degenerate)
10
+ - Configurable encoder (mlp/transformer/conv/film/ffn/rotary/lstm/gru)
11
+ - Configurable geometric activation (star=ReLUΒ² default)
12
+ - Configurable attention layers between SVD passes
13
+ - Configurable depth (stacked SVD cells)
14
+ - Three head selection via `target` ({SVD, VD, SV, S, V})
15
+ - Three output formats via `token_out` ({all, QKV, SUVt})
16
+ - Solver dispatch: svd_solver={auto, torch, triton}, eigh_solver={auto, torch, fl}
17
+
18
+ API parameter interpretations (clarify if wrong):
19
+ svd=[S, V, D] β€” S = sequence/slot count, V/D = SVD matrix dims
20
+ target="SVD" β€” all three heads active (S, U, Vt)
21
+ target="VD" β€” U + Vt only (drop singular values)
22
+ target="SV" β€” S + U only (drop right basis)
23
+ target="S"/"V" β€” single head
24
+ token_out="all" β€” return (B, S, embed_dim) sequence
25
+ token_out="QKV" β€” return (Q, K, V) tuple after QKV projection
26
+ token_out="SUVt" β€” return (U, S_vals, Vt) of the LAST cell's SVD
27
+ depth=N β€” stack N independent SVD cells
28
+
29
+ Lineage: AbstractPhil / SpectralProbe β†’ CIFAR-10 (53.7% with 13.6k params)
30
+ """
31
+
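+ # Quick usage sketch (mirrors the self-test at the bottom; shapes assume the
+ # default token_out='all' and embed_dim=V*D):
+ #   x = torch.randn(2, 16, 32)                   # (B, S, F)
+ #   former = svd_transformer(x, svd=(16, 8, 4))  # build the module from x's shape
+ #   out = former(x)                              # -> (2, 16, 32)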
32
+ import math
33
+ from typing import Optional, Tuple, Union
34
+
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+
39
+
40
+ # ────────────────────────────────────────────────────────────────────────
41
+ # geolip imports - CORRECT ORDER (geolip_core triggers sys.modules alias)
42
+ # ────────────────────────────────────────────────────────────────────────
43
+
44
+ try:
45
+ import geolip_core # noqa: F401 registers 'geolip' alias in sys.modules
46
+ import geolip # now resolvable
47
+ import geolip.linalg as LA # main dispatcher
48
+ from geolip_core.linalg.eigh import FLEigh
49
+ _HAS_GEOLIP = True
50
+ print(f"✓ geolip {geolip.__version__} - using LA.svd + FLEigh")
51
+ LA.backend.status()
52
+ except ImportError as e:
53
+ _HAS_GEOLIP = False
54
+ LA = None
55
+ FLEigh = None
56
+ print(f"⚠ geolip_core not installed ({e}) - torch.linalg fallback")
57
+
58
+
59
+ # ────────────────────────────────────────────────────────────────────────
60
+ # Activations (regular + geometric)
61
+ # ────────────────────────────────────────────────────────────────────────
62
+
63
+ class StarActivation(nn.Module):
64
+ """ReLU² - squared positive activation. All-positive output."""
65
+ def forward(self, x):
66
+ return F.relu(x).pow(2)
67
+
68
+
69
+ _GEO_ACTS = {
70
+ 'star': lambda: StarActivation(),
71
+ 'relu': lambda: nn.ReLU(),
72
+ 'gelu': lambda: nn.GELU(),
73
+ 'silu': lambda: nn.SiLU(),
74
+ 'swilu': lambda: nn.SiLU(), # alias of silu
75
+ 'tanh': lambda: nn.Tanh(),
76
+ 'sigmoid': lambda: nn.Sigmoid(),
77
+ 'leaky_relu': lambda: nn.LeakyReLU(0.01),
78
+ }
79
+
80
+ _REG_ACTS = {
81
+ 'gelu': lambda: nn.GELU(),
82
+ 'relu': lambda: nn.ReLU(),
83
+ 'silu': lambda: nn.SiLU(),
84
+ 'tanh': lambda: nn.Tanh(),
85
+ 'leaky_relu': lambda: nn.LeakyReLU(0.01),
86
+ }
87
+
88
+
89
+ def make_geo_activation(name: str) -> nn.Module:
90
+ name = (name or 'star').lower()
91
+ if name not in _GEO_ACTS:
92
+ raise ValueError(f"Unknown geo_activation: {name!r}; options: {list(_GEO_ACTS)}")
93
+ return _GEO_ACTS[name]()
94
+
95
+
96
+ def make_activation(name: str) -> nn.Module:
97
+ name = (name or 'gelu').lower()
98
+ if name not in _REG_ACTS:
99
+ raise ValueError(f"Unknown activation: {name!r}; options: {list(_REG_ACTS)}")
100
+ return _REG_ACTS[name]()
101
+
102
+
103
+ def _act_name_for_pytorch(name: str) -> str:
104
+ """nn.TransformerEncoderLayer accepts 'gelu'/'relu' strings; map our names."""
105
+ name = (name or 'gelu').lower()
106
+ return name if name in ('gelu', 'relu') else 'gelu'
107
+
108
+
109
+ # ────────────────────────────────────────────────────────────────────────
110
+ # Encoder variants - apply per-token before SVD reshape
111
+ # ────────────────────────────────────────────────────────────────────────
112
+
113
+ def _fit_heads(d: int, target: int) -> int:
114
+ """Pick a head count that divides d evenly (target preferred)."""
115
+ for h in [target, target // 2, target // 4, 8, 4, 2, 1]:
116
+ if h > 0 and d % h == 0:
117
+ return h
118
+ return 1
119
+
120
+
121
+ class MLPEncoder(nn.Module):
122
+ """encode='mlp' (default) - two-layer MLP per token."""
123
+ def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
124
+ super().__init__()
125
+ # hidden_size is the API's "Internal MLP hidden size" - small (default 4)
126
+ # Don't let it bottleneck; ensure at least max(in, out)/2
127
+ h = max(hidden_size, max(in_dim, out_dim) // 2, 8)
128
+ self.net = nn.Sequential(
129
+ nn.Linear(in_dim, h),
130
+ make_activation(activation),
131
+ nn.Linear(h, out_dim),
132
+ )
133
+
134
+ def forward(self, x):
135
+ return self.net(x)
136
+
137
+
138
+ class FFNEncoder(nn.Module):
139
+ """encode='ffn' - transformer-style 4× expansion FFN."""
140
+ def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
141
+ super().__init__()
142
+ h = max(hidden_size, 4 * out_dim)
143
+ self.net = nn.Sequential(
144
+ nn.Linear(in_dim, h),
145
+ make_activation(activation),
146
+ nn.Linear(h, out_dim),
147
+ )
148
+
149
+ def forward(self, x):
150
+ return self.net(x)
151
+
152
+
153
+ class FiLMEncoder(nn.Module):
154
+ """encode='film' - feature-wise affine modulation."""
155
+ def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
156
+ super().__init__()
157
+ self.skip = nn.Linear(in_dim, out_dim)
158
+ self.gamma = nn.Linear(in_dim, out_dim)
159
+ self.beta = nn.Linear(in_dim, out_dim)
160
+ self.act = make_activation(activation)
161
+
162
+ def forward(self, x):
163
+ skip = self.skip(x)
164
+ return self.act(skip * (1.0 + self.gamma(x)) + self.beta(x))
165
+
166
+
167
+ class ConvEncoder(nn.Module):
168
+ """encode='conv' - 1D conv across the token sequence."""
169
+ def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
170
+ super().__init__()
171
+ self.proj = nn.Linear(in_dim, out_dim)
172
+ self.conv = nn.Conv1d(out_dim, out_dim, kernel_size=3, padding=1)
173
+ self.act = make_activation(activation)
174
+
175
+ def forward(self, x): # (B, S, in_dim)
176
+ x = self.proj(x)
177
+ x = x.transpose(1, 2) # (B, out_dim, S)
178
+ x = self.act(self.conv(x))
179
+ return x.transpose(1, 2) # (B, S, out_dim)
180
+
181
+
182
+ class TransformerEncoder(nn.Module):
183
+ """encode='transformer' - single transformer encoder layer pre-SVD."""
184
+ def __init__(self, in_dim, out_dim, hidden_size, activation, n_heads=4, **_):
185
+ super().__init__()
186
+ self.proj = nn.Linear(in_dim, out_dim)
187
+ h = _fit_heads(out_dim, n_heads)
188
+ self.layer = nn.TransformerEncoderLayer(
189
+ d_model=out_dim, nhead=h,
190
+ dim_feedforward=max(hidden_size, 4 * out_dim),
191
+ activation=_act_name_for_pytorch(activation),
192
+ batch_first=True, norm_first=True,
193
+ )
194
+
195
+ def forward(self, x):
196
+ return self.layer(self.proj(x))
197
+
198
+
199
+ class LSTMEncoder(nn.Module):
200
+ """encode='lstm' - sequential LSTM."""
201
+ def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
202
+ super().__init__()
203
+ self.lstm = nn.LSTM(in_dim, out_dim, batch_first=True)
204
+ self.act = make_activation(activation)
205
+
206
+ def forward(self, x):
207
+ out, _ = self.lstm(x)
208
+ return self.act(out)
209
+
210
+
211
+ class GRUEncoder(nn.Module):
212
+ """encode='gru' - sequential GRU."""
213
+ def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
214
+ super().__init__()
215
+ self.gru = nn.GRU(in_dim, out_dim, batch_first=True)
216
+ self.act = make_activation(activation)
217
+
218
+ def forward(self, x):
219
+ out, _ = self.gru(x)
220
+ return self.act(out)
221
+
222
+
223
+ class RotaryEncoder(nn.Module):
224
+ """encode='rotary' - projection then rotary positional embedding."""
225
+ def __init__(self, in_dim, out_dim, hidden_size, activation, **_):
226
+ super().__init__()
227
+ self.proj = nn.Linear(in_dim, out_dim)
228
+ self.dim = out_dim
229
+ self.act = make_activation(activation)
230
+
231
+ def forward(self, x):
232
+ x = self.proj(x) # (B, S, out_dim)
233
+ B, S, D = x.shape
234
+ d_half = D // 2
235
+ if d_half == 0:
236
+ return self.act(x)
237
+ positions = torch.arange(S, device=x.device, dtype=x.dtype).unsqueeze(0)
238
+ freqs = torch.exp(torch.arange(d_half, device=x.device, dtype=x.dtype)
239
+ * (-math.log(10000.0) / d_half))
240
+ angles = positions.unsqueeze(-1) * freqs.unsqueeze(0) # (1, S, d_half)
241
+ cos, sin = angles.cos(), angles.sin()
242
+ x1 = x[..., :d_half]
243
+ x2 = x[..., d_half:2 * d_half]
244
+ rotated_1 = x1 * cos - x2 * sin
245
+ rotated_2 = x1 * sin + x2 * cos
246
+ if D % 2 == 1:
247
+ tail = x[..., 2 * d_half:]
248
+ x = torch.cat([rotated_1, rotated_2, tail], dim=-1)
249
+ else:
250
+ x = torch.cat([rotated_1, rotated_2], dim=-1)
251
+ return self.act(x)
252
+
253
+
254
+ _ENCODERS = {
255
+ 'mlp': MLPEncoder,
256
+ 'ffn': FFNEncoder,
257
+ 'film': FiLMEncoder,
258
+ 'conv': ConvEncoder,
259
+ 'transformer': TransformerEncoder,
260
+ 'lstm': LSTMEncoder,
261
+ 'gru': GRUEncoder,
262
+ 'rotary': RotaryEncoder,
263
+ }
264
+
265
+
266
+ def build_encoder(encode, in_dim, out_dim, hidden_size, activation):
267
+ enc = (encode or 'mlp').lower()
268
+ if enc not in _ENCODERS:
269
+ raise ValueError(f"Unknown encode={encode!r}; options: {list(_ENCODERS)}")
270
+ return _ENCODERS[enc](in_dim, out_dim, hidden_size, activation)
271
+
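+ # Example (sketch): every encoder maps (B, S, in_dim) -> (B, S, out_dim), so the
+ # variants are interchangeable from the SVD cell's point of view:
+ #   enc = build_encoder('mlp', in_dim=32, out_dim=32, hidden_size=4, activation='gelu')
+ #   enc(torch.randn(2, 16, 32)).shape   # -> torch.Size([2, 16, 32])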
272
+
273
+ # ────────────────────────────────────────────────────────────────────────
274
+ # SVD dispatch - auto-route to fastest available correct backend
275
+ # ────────────────────────────────────────────────────────────────────────
276
+
277
+ def _svd_dispatch(M: torch.Tensor,
278
+ svd_solver: str = 'auto',
279
+ eigh_solver: str = 'auto'
280
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
281
+ """
282
+ M: (BS, V, D) - batch of matrices to decompose.
283
+ Returns: U (BS, V, D), S_vals (BS, D), Vt (BS, D, D)
284
+
285
+ Dispatch logic (ALL paths produce thin SVD with descending singular values):
286
+
287
+ no geolip → torch.linalg.svd in fp64 (rank-deficient-safe)
288
+ svd_solver='torch' → torch.linalg.svd in fp64
289
+ svd_solver='triton' → LA.svd(method='triton') - D ≤ 6 fp64 only
290
+ eigh_solver='fl' → custom gram + FLEigh (compiles up to D=12)
291
+ auto/auto → LA.svd default dispatch (best per backend)
292
+
293
+ NEVER row-center M before this - the gram path produces garbage U for
294
+ rank-deficient inputs. Verified bug across both this implementation and
295
+ geolip's gram_eigh path. The production SVDObserver in geolip_core also
296
+ avoids centering for this reason.
297
+ """
298
+ # --- Fallback: no geolip
299
+ if not _HAS_GEOLIP:
300
+ with torch.amp.autocast('cuda', enabled=False):
301
+ U, Sv, Vt = torch.linalg.svd(M.double(), full_matrices=False)
302
+ return U.float(), Sv.float(), Vt.float()
303
+
304
+ # --- Explicit torch path
305
+ if svd_solver == 'torch':
306
+ with torch.amp.autocast('cuda', enabled=False):
307
+ U, Sv, Vt = torch.linalg.svd(M.double(), full_matrices=False)
308
+ return U.float(), Sv.float(), Vt.float()
309
+
310
+ # --- Explicit FL eigh path (custom gram + FLEigh, more accurate than torch.linalg.eigh)
311
+ if eigh_solver == 'fl':
312
+ return _gram_fl_eigh_svd(M)
313
+
314
+ # --- Triton path
315
+ if svd_solver == 'triton':
316
+ try:
317
+ return LA.svd(M, method='triton')
318
+ except Exception as exc:
319
+ print(f" ⚠ triton SVD failed ({exc}); falling back to LA.svd default")
320
+ return LA.svd(M)
321
+
322
+ # --- Default: LA.svd auto-dispatch (FL eigh on CUDA D≤12, torch otherwise)
323
+ return LA.svd(M)
324
+
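+ # Contract check (sketch) for _svd_dispatch on random matrices:
+ #   M = torch.randn(64, 8, 4)                     # (BS, V, D)
+ #   U, Sv, Vt = _svd_dispatch(M)                  # (64, 8, 4), (64, 4), (64, 4, 4)
+ #   err = (U @ torch.diag_embed(Sv) @ Vt - M).abs().max()   # thin-SVD reconstruction, should be small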
325
+
326
+ def _gram_fl_eigh_svd(M: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
327
+ """
328
+ Custom gram + FL eigh SVD. Uses geolip_core.linalg.eigh.FLEigh - the
329
+ Faddeev-LeVerrier polynomial + Laguerre roots + Newton-Schulz pipeline.
330
+ More accurate than torch.linalg.eigh on ill-conditioned grams.
331
+
332
+ Compiles up to D=12 on CUDA. For larger D, use _svd_dispatch with default
333
+ auto routing (which will pick torch.linalg.svd).
334
+ """
335
+ if FLEigh is None:
336
+ raise RuntimeError("FLEigh unavailable - geolip_core not installed")
337
+ orig_dtype = M.dtype
338
+ A = M.float() # FL eigh runs in float
339
+ G = torch.bmm(A.transpose(1, 2), A) # (BS, D, D), symmetric PSD
340
+ eigenvalues, V = FLEigh()(G)
341
+ # eigh returns ascending; we want descending singular values
342
+ eigenvalues = eigenvalues.flip(-1)
343
+ V = V.flip(-1)
344
+ Sv = torch.sqrt(eigenvalues.clamp(min=1e-12))
345
+ U = torch.bmm(A, V) / Sv.unsqueeze(1).clamp(min=1e-8)
346
+ Vh = V.transpose(-2, -1).contiguous()
347
+ return U.to(orig_dtype), Sv.to(orig_dtype), Vh.to(orig_dtype)
348
+
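+ # Identity used above (gram route): if A = U Σ Vᵀ then G = AᵀA = V Σ² Vᵀ, so
+ # eigh(G) yields Σ² and V; singular values are the square roots of the
+ # eigenvalues and U is recovered as A V Σ⁻¹ (clamped near zero for stability).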
349
+
350
+ # ────────────────────────────────────────────────────────────────────────
351
+ # Image patcher (helper for image inputs; not part of svd_transformer itself)
352
+ # ────────────────────────────────────────────────────────────────────────
353
+
354
+ class TensorPatcher(nn.Module):
355
+ """(B, C, H, W) → (B, N, C·ph·pw). Pure reshape, no learned params."""
356
+ def __init__(self, input_shape, patch_size):
357
+ super().__init__()
358
+ C, H, W = input_shape
359
+ ph = pw = patch_size
360
+ assert H % ph == 0 and W % pw == 0
361
+ self.C, self.H, self.W = C, H, W
362
+ self.ph, self.pw = ph, pw
363
+ self.n_patches = (H // ph) * (W // pw)
364
+ self.patch_dim = C * ph * pw
365
+
366
+ def forward(self, x):
367
+ B, C, H, W = x.shape
368
+ ph, pw = self.ph, self.pw
369
+ gh, gw = H // ph, W // pw
370
+ p = x.reshape(B, C, gh, ph, gw, pw)
371
+ p = p.permute(0, 2, 4, 1, 3, 5).contiguous()
372
+ return p.reshape(B, gh * gw, -1)
373
+
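+ # Example (sketch): CIFAR-10-shaped input with 4x4 patches.
+ #   patcher = TensorPatcher((3, 32, 32), patch_size=4)
+ #   patcher(torch.randn(2, 3, 32, 32)).shape   # -> torch.Size([2, 64, 48])
+ #   # 64 patches = (32/4)*(32/4), 48 features = 3*4*4 per patch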
374
+
375
+ # ────────────────────────────────────────────────────────────────────────
376
+ # SVD Cell - one cycle: encode → SVD → three heads → attention
377
+ # ────────────────────────────────────────────────────────────────────────
378
+
379
+ class SVDCell(nn.Module):
380
+ """
381
+ One cycle of the architecture:
382
+
383
+ tokens (B, S, in_dim)
384
+ ↓ encode (mlp/conv/transformer/...) [out: (B, S, V·D)]
385
+ ↓ reshape [out: (B·S, V, D)]
386
+ ↓ SVD via geolip [out: U(BS,V,D), S(BS,D), Vt(BS,D,D)]
387
+ ↓ three-head readout (target masks heads) [out: (B·S, embed_dim)]
388
+ ↓ geo_activation [out: (B·S, embed_dim)]
389
+ ↓ reshape [out: (B, S, embed_dim)]
390
+ ↓ attention_layers × TransformerEncoderLayer
391
+ ↓ LayerNorm
392
+ → tokens (B, S, embed_dim)
393
+
394
+ The SVD components (U, S_vals, Vt) of the last forward pass are cached
395
+ on `self._last_svd` for token_out="SUVt" extraction.
396
+ """
397
+ _TARGET_TO_MASK = {
398
+ 'SVD': (True, True, True), # all three heads
399
+ 'VD': (False, True, True), # U + Vt only
400
+ 'SV': (True, True, False), # S + U only
401
+ 'S': (True, False, False), # singular values only
402
+ 'V': (False, True, False), # U (left vectors) only
403
+ }
404
+
405
+ def __init__(self, *, in_dim, S, V, D, embed_dim, hidden_size,
406
+ encode, activation, geo_activation, target,
407
+ attention_layers, heads, svd_solver, eigh_solver):
408
+ super().__init__()
409
+ self.S, self.V, self.D = S, V, D
410
+ self.embed_dim = embed_dim
411
+ self.target = (target or 'SVD').upper()
412
+ self.svd_solver = svd_solver
413
+ self.eigh_solver = eigh_solver
414
+ if self.target not in self._TARGET_TO_MASK:
415
+ raise ValueError(f"Unknown target={target!r}; options: {list(self._TARGET_TO_MASK)}")
416
+
417
+ mat_dim = V * D
418
+
419
+ # Encoder: tokens (B, S, in_dim) → (B, S, V*D)
420
+ self.encoder = build_encoder(encode, in_dim, mat_dim, hidden_size, activation)
421
+
422
+ # Three head linears (all instantiated; mask gates which contribute)
423
+ self.s_head = nn.Linear(D, embed_dim)
424
+ self.u_head = nn.Linear(V * D, embed_dim)
425
+ self.vt_head = nn.Linear(D * D, embed_dim)
426
+
427
+ self.geo_act = make_geo_activation(geo_activation)
428
+
429
+ # Attention stack (post-SVD)
430
+ if attention_layers > 0:
431
+ n_h = _fit_heads(embed_dim, heads)
432
+ layer = nn.TransformerEncoderLayer(
433
+ d_model=embed_dim, nhead=n_h,
434
+ dim_feedforward=4 * embed_dim,
435
+ activation=_act_name_for_pytorch(activation),
436
+ batch_first=True, norm_first=True,
437
+ )
438
+ self.attention = nn.TransformerEncoder(layer, num_layers=attention_layers)
439
+ else:
440
+ self.attention = nn.Identity()
441
+
442
+ self.norm = nn.LayerNorm(embed_dim)
443
+ self._last_svd = None # (U, S_vals, Vt) cache for token_out="SUVt"
444
+
445
+ def forward(self, tokens):
446
+ """tokens: (B, S, in_dim) → (B, S, embed_dim)"""
447
+ B, S, _ = tokens.shape
448
+ assert S == self.S, f"Expected S={self.S} tokens, got {S}"
449
+
450
+ # Encode → V×D matrix per token (NO row-centering)
451
+ encoded = self.encoder(tokens) # (B, S, V*D)
452
+ M = encoded.reshape(B * S, self.V, self.D)
453
+
454
+ # SVD
455
+ U, Sv, Vt = _svd_dispatch(M, self.svd_solver, self.eigh_solver)
456
+ self._last_svd = (U, Sv, Vt)
457
+
458
+ # Three-head readout (target gates which heads contribute)
459
+ use_s, use_u, use_vt = self._TARGET_TO_MASK[self.target]
460
+ token_feat = torch.zeros(B * S, self.embed_dim,
461
+ device=tokens.device, dtype=tokens.dtype)
462
+ if use_s:
463
+ token_feat = token_feat + self.s_head(Sv)
464
+ if use_u:
465
+ token_feat = token_feat + self.u_head(U.reshape(B * S, -1))
466
+ if use_vt:
467
+ token_feat = token_feat + self.vt_head(Vt.reshape(B * S, -1))
468
+
469
+ token_feat = self.geo_act(token_feat)
470
+ token_feat = token_feat.reshape(B, S, self.embed_dim)
471
+
472
+ # Attention layers
473
+ token_feat = self.attention(token_feat)
474
+ return self.norm(token_feat)
475
+
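+ # Standalone cell check (sketch; the argument values below are illustrative):
+ #   cell = SVDCell(in_dim=32, S=16, V=8, D=4, embed_dim=32, hidden_size=4,
+ #                  encode='mlp', activation='gelu', geo_activation='star',
+ #                  target='SVD', attention_layers=2, heads=64,
+ #                  svd_solver='auto', eigh_solver='auto')
+ #   cell(torch.randn(2, 16, 32)).shape   # -> torch.Size([2, 16, 32])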
476
+
477
+ # ────────────────────────────────────────────────────────────────────────
478
+ # SVDTransformer - top-level module (depth × SVDCell)
479
+ # ────────────────────────────────────────────────────────────────────────
480
+
481
+ class SVDTransformer(nn.Module):
482
+ """
483
+ Stacked SVD cells with configurable encoder, attention, and head selection.
484
+
485
+ First cell takes in_dim; subsequent cells take embed_dim. Each cell has its
486
+ own encoder + SVD + attention substack; `depth` cells in sequence.
487
+ """
488
+ def __init__(self, *,
489
+ in_dim: int,
490
+ svd: Tuple[int, int, int] = (16, 8, 4),
491
+ bypass_crash: bool = True,
492
+ heads: int = 64,
493
+ hidden_size: int = 4,
494
+ depth: int = 4,
495
+ encode: str = 'mlp',
496
+ attention_layers: int = 2,
497
+ activation: str = 'gelu',
498
+ geo_activation: str = 'star',
499
+ token_out: str = 'all',
500
+ target: str = 'SVD',
501
+ svd_solver: str = 'auto',
502
+ eigh_solver: str = 'auto',
503
+ embed_dim: Optional[int] = None):
504
+ super().__init__()
505
+ S, V, D = svd
506
+ self.S, self.V, self.D = S, V, D
507
+
508
+ if D > 128:
509
+ msg = f"D={D} > 128 - gram-based SVD will be very slow / OOM-prone."
510
+ if not bypass_crash:
511
+ raise RuntimeError(msg + " Pass bypass_crash=True to override.")
512
+ print(f"⚠ {msg}")
513
+
514
+ if embed_dim is None:
515
+ embed_dim = V * D # default: same dim as flattened SVD matrix
516
+ self.embed_dim = embed_dim
517
+ self.token_out = (token_out or 'all').lower()
518
+
519
+ cells = []
520
+ for i in range(depth):
521
+ cell_in = in_dim if i == 0 else embed_dim
522
+ cells.append(SVDCell(
523
+ in_dim=cell_in, S=S, V=V, D=D, embed_dim=embed_dim,
524
+ hidden_size=hidden_size, encode=encode,
525
+ activation=activation, geo_activation=geo_activation,
526
+ target=target, attention_layers=attention_layers,
527
+ heads=heads, svd_solver=svd_solver, eigh_solver=eigh_solver,
528
+ ))
529
+ self.cells = nn.ModuleList(cells)
530
+
531
+ # QKV projection (only used when token_out="QKV")
532
+ if self.token_out == 'qkv':
533
+ self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
534
+
535
+ def forward(self, x: torch.Tensor,
536
+ y: Optional[torch.Tensor] = None,
537
+ z: Optional[Union[torch.Tensor, dict, list]] = None):
538
+ """
539
+ x: (B, S, in_dim) - input token sequence
540
+ y: optional mask tensor (reserved; not yet wired into QKV/SUVt logic)
541
+ z: experimentation hooks (passed through; not yet consumed)
542
+
543
+ Returns one of:
544
+ token_out="all" (default) → (B, S, embed_dim)
545
+ token_out="QKV" → (Q, K, V) tuple, each (B, S, embed_dim)
546
+ token_out="SUVt"/"SUV" → (U, S_vals, Vt) raw geometric tokens
547
+ from the last cell's SVD
548
+ """
549
+ for cell in self.cells:
550
+ x = cell(x) # (B, S, embed_dim)
551
+
552
+ if self.token_out == 'qkv':
553
+ qkv = self.qkv_proj(x)
554
+ q, k, v = qkv.chunk(3, dim=-1)
555
+ return q, k, v
556
+
557
+ if self.token_out in ('suvt', 'suv'):
558
+ # Return the raw SVD components cached by the last cell; they were
560
+ # computed on that cell's encoder output, before its attention stack.
560
+ U, Sv, Vt = self.cells[-1]._last_svd
561
+ B, S = x.shape[:2]
562
+ U = U.reshape(B, S, self.V, self.D)
563
+ Sv = Sv.reshape(B, S, self.D)
564
+ Vt = Vt.reshape(B, S, self.D, self.D)
565
+ return U, Sv, Vt
566
+
567
+ return x
568
+
569
+
570
+ # ────────────────────────────────────────────────────────────────────────
571
+ # Functional wrapper matching the user-provided API spec
572
+ # ────────────────────────────────────────────────────────────────────────
573
+
574
+ def svd_transformer(x: torch.Tensor,
575
+ y: Optional[torch.Tensor] = None,
576
+ z: Optional[Union[torch.Tensor, dict, list]] = None,
577
+ *,
578
+ svd: Optional[Tuple[int, int, int]] = None,
579
+ bypass_crash: bool = True,
580
+ heads: int = 64,
581
+ hidden_size: int = 4,
582
+ depth: int = 4,
583
+ encode: str = 'mlp',
584
+ attention_layers: int = 2,
585
+ activation: str = 'gelu',
586
+ geo_activation: str = 'star',
587
+ token_out: str = 'all',
588
+ target: str = 'SVD',
589
+ svd_solver: str = 'auto',
590
+ eigh_solver: str = 'auto',
591
+ embed_dim: Optional[int] = None) -> SVDTransformer:
592
+ """
593
+ Functional API matching user-provided spec. Returns an SVDTransformer
594
+ initialized from x's shape; caller invokes it via former(x).
595
+
596
+ Shape inference for `svd=None`:
597
+ x.shape = (B, S, F) → svd = (S, V, D) using sqrt(F) if F is a perfect
598
+ square, else (S, 8, 4) fallback
599
+ x.shape = (B, C, H, W) → raises (caller must patchify or pass svd=)
600
+
601
+ Returns the SVDTransformer module. Apply it with `former(x)`.
602
+ """
603
+ if svd is None:
604
+ if x.ndim == 3:
605
+ B, S, feat = x.shape # renamed from F to avoid shadowing torch.nn.functional as F
606
+ sq = int(feat ** 0.5)
607
+ if sq * sq == feat:
608
+ V, D = sq, sq
609
+ else:
610
+ V, D = 8, 4
611
+ svd_param = (S, V, D)
612
+ elif x.ndim == 4:
613
+ raise ValueError(
614
+ "svd_transformer with svd=None requires pre-tokenized input "
615
+ "(B, S, F). For images, patchify first or pass svd=(S, V, D)."
616
+ )
617
+ else:
618
+ raise ValueError(f"x.shape must be (B, S, F) or (B, C, H, W); got {tuple(x.shape)}")
619
+ else:
620
+ svd_param = tuple(svd)
621
+
622
+ in_dim = x.shape[-1]
623
+ return SVDTransformer(
624
+ in_dim=in_dim, svd=svd_param, bypass_crash=bypass_crash,
625
+ heads=heads, hidden_size=hidden_size, depth=depth,
626
+ encode=encode, attention_layers=attention_layers,
627
+ activation=activation, geo_activation=geo_activation,
628
+ token_out=token_out, target=target,
629
+ svd_solver=svd_solver, eigh_solver=eigh_solver,
630
+ embed_dim=embed_dim,
631
+ )
632
+
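+ # Shape-inference sketch for svd=None (per the wrapper docstring above):
+ #   x = torch.randn(2, 16, 64)       # F=64 is a perfect square -> V = D = 8
+ #   former = svd_transformer(x)      # equivalent to svd=(16, 8, 8)
+ #   former(x).shape                  # -> torch.Size([2, 16, 64]) since embed_dim = V*D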
633
+
634
+ # ────────────────────────────────────────────────────────────────────────
635
+ # Self-test (runs under __main__ only; smoke check)
636
+ # ────────────────────────────────────────────────────────────────────────
637
+
638
+ if __name__ == '__main__':
639
+ print("\n" + "=" * 72)
640
+ print("SVDTransformer prototype self-test")
641
+ print("=" * 72)
642
+
643
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
644
+ torch.manual_seed(0)
645
+
646
+ # --- Test 1: default config ---
647
+ print("\n[1] Default config: svd=(16, 8, 4), depth=4, encode='mlp'")
648
+ x = torch.randn(2, 16, 32, device=device) # (B=2, S=16, F=32)
649
+ former = svd_transformer(x, svd=(16, 8, 4))
650
+ former = former.to(device)
651
+ out = former(x)
652
+ n_params = sum(p.numel() for p in former.parameters())
653
+ print(f" in shape={tuple(x.shape)} out shape={tuple(out.shape)} params={n_params:,}")
654
+ assert out.shape == (2, 16, 32), f"Expected (2,16,32), got {out.shape}"
655
+
656
+ # --- Test 2: each encoder type ---
657
+ print("\n[2] All encoder types:")
658
+ for enc in _ENCODERS:
659
+ m = svd_transformer(x, svd=(16, 8, 4), encode=enc, depth=1, attention_layers=1).to(device)
660
+ try:
661
+ o = m(x)
662
+ print(f" encode={enc:12s} → out={tuple(o.shape)} params={sum(p.numel() for p in m.parameters()):,}")
663
+ except Exception as e:
664
+ print(f" encode={enc:12s} ✗ {type(e).__name__}: {e}")
665
+
666
+ # --- Test 3: each target ---
667
+ print("\n[3] All target options:")
668
+ for tgt in ['SVD', 'VD', 'SV', 'S', 'V']:
669
+ m = svd_transformer(x, svd=(16, 8, 4), target=tgt, depth=1, attention_layers=0).to(device)
670
+ o = m(x)
671
+ # Count how many heads will receive nonzero gradient
672
+ loss = o.sum()
673
+ loss.backward()
674
+ head_grads = {
675
+ 'S': m.cells[0].s_head.weight.grad.norm().item() if m.cells[0].s_head.weight.grad is not None else 0,
676
+ 'U': m.cells[0].u_head.weight.grad.norm().item() if m.cells[0].u_head.weight.grad is not None else 0,
677
+ 'Vt': m.cells[0].vt_head.weight.grad.norm().item() if m.cells[0].vt_head.weight.grad is not None else 0,
678
+ }
679
+ active = [k for k, v in head_grads.items() if v > 1e-9]
680
+ print(f" target={tgt:4s} → active heads={active} out={tuple(o.shape)}")
681
+
682
+ # --- Test 4: each token_out format ---
683
+ print("\n[4] All token_out formats:")
684
+ for to in ['all', 'QKV', 'SUVt']:
685
+ m = svd_transformer(x, svd=(16, 8, 4), token_out=to, depth=1, attention_layers=0).to(device)
686
+ o = m(x)
687
+ if isinstance(o, tuple):
688
+ shapes = [tuple(t.shape) for t in o]
689
+ print(f" token_out={to:5s} → {len(o)} tensors, shapes={shapes}")
690
+ else:
691
+ print(f" token_out={to:5s} → out={tuple(o.shape)}")
692
+
693
+ # --- Test 5: SVD orthogonality on a real model M (post-encoder) ---
694
+ print("\n[5] SVD orthogonality check (no row centering):")
695
+ m = svd_transformer(x, svd=(16, 8, 4), depth=1, attention_layers=0).to(device)
696
+ with torch.no_grad():
697
+ encoded = m.cells[0].encoder(x)
698
+ BS = 2 * 16
699
+ M = encoded.reshape(BS, 8, 4)
700
+ rm = M[0].mean(dim=-1)[:3].tolist()
701
+ print(f" M not centered: row_means[0,:3] = [{rm[0]:.4f},{rm[1]:.4f},{rm[2]:.4f}]")
702
+ U, Sv, Vt = _svd_dispatch(M)
703
+ I_D = torch.eye(4, device=device).expand(BS, 4, 4)
704
+ u_orth = (torch.bmm(U.transpose(1, 2), U) - I_D).abs().max().item()
705
+ v_orth = (torch.bmm(Vt, Vt.transpose(1, 2)) - I_D).abs().max().item()
706
+ recon = (torch.bmm(U * Sv.unsqueeze(1), Vt) - M).abs().max().item()
707
+ print(f" ||U^T U - I|| = {u_orth:.2e} {'✓' if u_orth < 1e-3 else '✗'}")
708
+ print(f" ||Vt Vt^T - I|| = {v_orth:.2e} {'✓' if v_orth < 1e-3 else '✗'}")
709
+ print(f" reconstruction = {recon:.2e} {'✓' if recon < 1e-4 else '✗'}")
710
+
711
+ # --- Test 6: backward pass (gradient flows through SVD) ---
712
+ print("\n[6] Backward pass (gradient flow through SVD):")
713
+ m = svd_transformer(x, svd=(16, 8, 4), depth=2, attention_layers=1).to(device)
714
+ out = m(x)
715
+ loss = out.pow(2).mean()
716
+ loss.backward()
717
+ enc_grad = sum(
718
+ p.grad.norm().item() ** 2
719
+ for p in m.cells[0].encoder.parameters() if p.grad is not None
720
+ ) ** 0.5
721
+ print(f" loss = {loss.item():.4f}")
722
+ print(f" cell[0].encoder grad_norm = {enc_grad:.4e} "
723
+ f"{'✓ flowing through SVD into encoder' if enc_grad > 0 else '✗'}")
724
+
725
+ # --- Test 7: solver dispatch combinations ---
726
+ print("\n[7] Solver dispatch combinations:")
727
+ for ssolver, esolver in [('auto', 'auto'), ('torch', 'auto'),
728
+ ('auto', 'fl'), ('auto', 'torch')]:
729
+ try:
730
+ m = svd_transformer(x, svd=(16, 8, 4),
731
+ svd_solver=ssolver, eigh_solver=esolver,
732
+ depth=1, attention_layers=0).to(device)
733
+ o = m(x)
734
+ print(f" svd={ssolver:6s} eigh={esolver:6s} → ok out={tuple(o.shape)}")
735
+ except Exception as e:
736
+ print(f" svd={ssolver:6s} eigh={esolver:6s} → {type(e).__name__}: {e}")
737
+
738
+ # --- Test 8: bypass_crash for D > 128 ---
739
+ print("\n[8] D-too-large guard:")
740
+ try:
741
+ m = svd_transformer(torch.randn(2, 4, 200, device=device),
742
+ svd=(4, 200, 200), bypass_crash=False, depth=1, attention_layers=0)
743
+ print(f" bypass_crash=False with D=200 → ✗ (should have raised)")
744
+ except RuntimeError as e:
745
+ print(f" bypass_crash=False with D=200 → ✓ raised: {str(e)[:60]}...")
746
+
747
+ print("\n" + "=" * 72)
748
+ print("All smoke tests complete.")
749
+ print("=" * 72)