AbstractPhil committed on
Commit
0e9f66f
Β·
verified Β·
1 Parent(s): 8de8161

Create eigh_readme_version_tester.py

Browse files
Files changed (1) hide show
  1. eigh_readme_version_tester.py +281 -0
eigh_readme_version_tester.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ fl_eigh.py β€” Hybrid optimal eigendecomposition.
3
+
4
+ Combines FL algebraic superiority with geometric refinement:
5
+ Phase 1: FL characteristic polynomial (fp64) β†’ algebraically exact coefficients
6
+ Phase 2: Laguerre + Newton polish β†’ algebraically optimal eigenvalues
7
+ Phase 3: FL adjugate (fp64 Horner + max-col) β†’ eigenvector extraction
8
+ Phase 4: Newton-Schulz β†’ orthonormal eigenvectors (geometric projection)
9
+ Phase 5: Rayleigh quotient β†’ Ξ»α΅’ = vα΅’α΅€Avα΅’ (geometrically optimal eigenvalues)
10
+
11
+ The Rayleigh quotient is the KEY insight: given orthonormal eigenvectors V,
12
+ the eigenvalues Ξ»α΅’ = vα΅’α΅€Avα΅’ minimize ||Av - Ξ»v||Β² β€” the eigenpair residual.
13
+ This fuses FL's algebraic precision with geometric optimality.
14
+
15
+ Result: eigenvalues that minimize residual + eigenvectors that are orthonormal.
16
+ Both reconstruction ||A - VΞ›Vα΅€|| and eigenpair ||Av - Ξ»v|| are optimal.
17
+
18
+ Author: AbstractPhil / GeoLIP project
19
+ """
20
+
21
+ import math, time, gc, sys
22
+ import torch
23
+ import torch.nn as nn
24
+ from torch import Tensor
25
+ from typing import Tuple
26
+
27
+ torch.backends.cuda.matmul.allow_tf32 = False
28
+ torch.backends.cudnn.allow_tf32 = False
29
+ torch.set_float32_matmul_precision('highest')
30
+
31
+
32
+ class FLEigh(nn.Module):
33
+
34
+ def forward(self, A: Tensor) -> Tuple[Tensor, Tensor]:
35
+ B, n, _ = A.shape
36
+ device = A.device
37
+
38
+ # ── Pre-scale ──
39
+ scale = (torch.linalg.norm(A.reshape(B, -1), dim=-1) / math.sqrt(n)).clamp(min=1e-12)
40
+ As = A / scale[:, None, None]
41
+
42
+ # ══════ Phase 1: Faddeev-LeVerrier (fp64) ══════
43
+ # n bmm β†’ characteristic polynomial + adjugate basis
44
+ Ad = As.double()
45
+ eye_d = torch.eye(n, device=device, dtype=torch.float64).unsqueeze(0).expand(B, -1, -1)
46
+ c = torch.zeros(B, n + 1, device=device, dtype=torch.float64)
47
+ c[:, n] = 1.0
48
+ Mstore = torch.zeros(n + 1, B, n, n, device=device, dtype=torch.float64)
49
+ Mk = torch.zeros(B, n, n, device=device, dtype=torch.float64)
50
+ for k in range(1, n + 1):
51
+ Mk = torch.bmm(Ad, Mk) + c[:, n - k + 1, None, None] * eye_d
52
+ Mstore[k] = Mk
53
+ c[:, n - k] = -(Ad * Mk).sum((-2, -1)) / k
54
+
55
+ # ══════ Phase 2: Laguerre + Polish β†’ algebraic eigenvalues ══════
56
+ use_f64 = n > 6
57
+ dt = torch.float64 if use_f64 else torch.float32
58
+ cl = c.to(dt).clone()
59
+ roots = torch.zeros(B, n, device=device, dtype=dt)
60
+ zi = As.to(dt).diagonal(dim1=-2, dim2=-1).sort(dim=-1).values
61
+ zi = zi + torch.linspace(-1e-4, 1e-4, n, device=device, dtype=dt).unsqueeze(0)
62
+
63
+ for ri in range(n):
64
+ deg = n - ri
65
+ z = zi[:, ri]
66
+ for _ in range(5):
67
+ pv = cl[:, deg]; dp = torch.zeros(B, device=device, dtype=dt)
68
+ d2 = torch.zeros(B, device=device, dtype=dt)
69
+ for j in range(deg - 1, -1, -1):
70
+ d2 = d2 * z + dp; dp = dp * z + pv; pv = pv * z + cl[:, j]
71
+ ok = pv.abs() > 1e-30
72
+ ps = torch.where(ok, pv, torch.ones_like(pv))
73
+ G = torch.where(ok, dp / ps, torch.zeros_like(dp))
74
+ H = G * G - torch.where(ok, 2.0 * d2 / ps, torch.zeros_like(d2))
75
+ disc = ((deg - 1.0) * (deg * H - G * G)).clamp(min=0.0)
76
+ sq = torch.sqrt(disc); gp = G + sq; gm = G - sq
77
+ den = torch.where(gp.abs() >= gm.abs(), gp, gm)
78
+ dok = den.abs() > 1e-20
79
+ ds = torch.where(dok, den, torch.ones_like(den))
80
+ z = z - torch.where(dok, float(deg) / ds, torch.zeros_like(den))
81
+ roots[:, ri] = z
82
+ b = cl[:, deg]
83
+ for j in range(deg - 1, 0, -1):
84
+ bn = cl[:, j] + z * b; cl[:, j] = b; b = bn
85
+ cl[:, 0] = b
86
+
87
+ # Newton polish on original polynomial (fp64)
88
+ roots = roots.double()
89
+ for _ in range(3):
90
+ pv = torch.ones(B, n, device=device, dtype=torch.float64)
91
+ dp = torch.zeros(B, n, device=device, dtype=torch.float64)
92
+ for j in range(n - 1, -1, -1):
93
+ dp = dp * roots + pv; pv = pv * roots + c[:, j:j + 1]
94
+ ok = dp.abs() > 1e-30
95
+ dps = torch.where(ok, dp, torch.ones_like(dp))
96
+ roots = roots - torch.where(ok, pv / dps, torch.zeros_like(pv))
97
+
98
+ # ══════ Phase 3: FL adjugate β†’ eigenvector extraction (fp64) ══════
99
+ # Horner evaluation of adj(Ξ»I-A) at each eigenvalue
100
+ lam = roots # [B, n] fp64
101
+ R = Mstore[1].unsqueeze(1).expand(-1, n, -1, -1).clone()
102
+ for k in range(2, n + 1):
103
+ R = R * lam[:, :, None, None] + Mstore[k].unsqueeze(1)
104
+
105
+ # Max-norm column extraction (robust for all n)
106
+ cnorms = R.norm(dim=-2) # [B, n_eig, n_mat]
107
+ best = cnorms.argmax(dim=-1) # [B, n_eig]
108
+ idx = best.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, n, 1)
109
+ vec = R.gather(-1, idx).squeeze(-1) # [B, n_eig, n_mat]
110
+ vec = vec / (vec.norm(dim=-1, keepdim=True) + 1e-30)
111
+ V = vec.float().transpose(-2, -1) # [B, n, n] columns = eigvecs
112
+
113
+ # ══════ Phase 4: Newton-Schulz β†’ orthonormal eigenvectors ══════
114
+ # 2 iterations: stable for all n (3 diverges on near-degenerate cases)
115
+ eye_f = torch.eye(n, device=device, dtype=torch.float32).unsqueeze(0).expand(B, -1, -1)
116
+ Y = torch.bmm(V.transpose(-2, -1), V)
117
+ X = eye_f.clone()
118
+ for _ in range(2):
119
+ T = 3.0 * eye_f - Y
120
+ X = 0.5 * torch.bmm(X, T)
121
+ Y = 0.5 * torch.bmm(T, Y)
122
+ V = torch.bmm(V, X)
123
+
124
+ # ══════ Phase 5: Rayleigh quotient β†’ geometrically optimal eigenvalues ══════
125
+ # Ξ»α΅’ = vα΅’α΅€ A vα΅’ β€” minimizes ||Av - Ξ»v||Β² for the given v
126
+ AV = torch.bmm(A, V) # [B, n, n]
127
+ evals = (V * AV).sum(dim=-2) # [B, n] = diag(Vα΅€AV)
128
+
129
+ # ── Sort ──
130
+ se, perm = evals.sort(dim=-1)
131
+ sv = V.gather(-1, perm.unsqueeze(-2).expand_as(V))
132
+ return se, sv
133
+
134
+
135
+ # ═══════════════════════════════════════════════════════════════════════
136
+ # Mathematical purity test
137
+ # ═══════════════════════════════════════════════════════════════════════
138
+
139
+ def math_test(A, vals, vecs):
140
+ B, n, _ = A.shape
141
+ dev = A.device
142
+ Ad = A.double(); vd = vals.double(); Vd = vecs.double()
143
+ AV = torch.bmm(Ad, Vd); VL = Vd * vd.unsqueeze(-2)
144
+ An = Ad.reshape(B, -1).norm(dim=-1, keepdim=True).clamp(min=1e-30)
145
+ res = (AV - VL).norm(dim=-2) / An # per-eigvec residual
146
+ VtV = torch.bmm(Vd.mT, Vd)
147
+ I = torch.eye(n, device=dev, dtype=torch.float64).unsqueeze(0)
148
+ orth = (VtV - I).reshape(B, -1).norm(dim=-1)
149
+ recon = torch.bmm(Vd * vd.unsqueeze(-2), Vd.mT)
150
+ recon_err = (Ad - recon).reshape(B, -1).norm(dim=-1) / An.squeeze(-1)
151
+ tr_err = (Ad.diagonal(dim1=-2,dim2=-1).sum(-1) - vd.sum(-1)).abs()
152
+ det_A = torch.linalg.det(Ad)
153
+ det_err = (det_A - vd.prod(-1)).abs() / det_A.abs().clamp(min=1e-30)
154
+ cp = torch.zeros(B, n, device=dev, dtype=torch.float64)
155
+ for i in range(n):
156
+ cp[:, i] = torch.linalg.det(vd[:, i:i+1, None] * I - Ad).abs()
157
+ return dict(
158
+ res_max=res.max().item(), res_mean=res.mean().item(),
159
+ orth_max=orth.max().item(), orth_mean=orth.mean().item(),
160
+ recon_max=recon_err.max().item(), recon_mean=recon_err.mean().item(),
161
+ tr_max=tr_err.max().item(), tr_mean=tr_err.mean().item(),
162
+ det_max=det_err.max().item(), det_mean=det_err.mean().item(),
163
+ cp_max=cp.max().item(), cp_mean=cp.mean().item(),
164
+ )
165
+
166
+
167
+ def sync(): torch.cuda.synchronize()
168
+ def gt(fn, w=20, r=300):
169
+ for _ in range(w): fn()
170
+ sync(); t=time.perf_counter()
171
+ for _ in range(r): fn()
172
+ sync(); return (time.perf_counter()-t)/r
173
+ def fmt(s):
174
+ if s<1e-3: return f"{s*1e6:.1f}Β΅s"
175
+ if s<1: return f"{s*1e3:.2f}ms"
176
+ return f"{s:.3f}s"
177
+
178
+
179
+ def main():
180
+ if not torch.cuda.is_available(): sys.exit(1)
181
+ dev = torch.device('cuda')
182
+ p = torch.cuda.get_device_properties(0)
183
+
184
+ print("="*72)
185
+ print(" FL Hybrid Eigh β€” Algebraic + Geometric Optimal")
186
+ print("="*72)
187
+ print(f" {p.name} | PyTorch {torch.__version__}")
188
+
189
+ # ── Mathematical purity sweep ──
190
+ print("\n" + "="*72)
191
+ print(" MATHEMATICAL PURITY (no reference impl, only definitions)")
192
+ print("="*72)
193
+
194
+ for nx in [3, 5, 6, 8, 10, 12]:
195
+ B = 2048 if nx <= 8 else 1024
196
+ A = (lambda R:(R+R.mT)/2)(torch.randn(B, nx, nx, device=dev))
197
+ cv, cV = torch.linalg.eigh(A)
198
+ fv, fV = FLEigh()(A)
199
+ mc = math_test(A, cv, cV); mf = math_test(A, fv, fV)
200
+
201
+ wins_c = 0; wins_f = 0
202
+ for key in mc:
203
+ if mf[key] < mc[key]: wins_f += 1
204
+ elif mc[key] < mf[key]: wins_c += 1
205
+ print(f"\n n={nx} B={B}: FL wins {wins_f}/12, cuSOLVER wins {wins_c}/12")
206
+
207
+ def row(name, key):
208
+ vc=mc[key]; vf=mf[key]
209
+ w="FL" if vf<vc else ("cuS" if vc<vf else "tie")
210
+ m="β—„" if w=="FL" else ("β–Ί" if w=="cuS" else " ")
211
+ print(f" {name:<28} {vc:>10.1e} {vf:>10.1e} {w} {m}")
212
+
213
+ print(f" {'Property':<28} {'cuSOLVER':>10} {'FL':>10}")
214
+ row("Eigenpair max", "res_max")
215
+ row("Eigenpair mean", "res_mean")
216
+ row("Orthogonality max", "orth_max")
217
+ row("Orthogonality mean", "orth_mean")
218
+ row("Reconstruction max", "recon_max")
219
+ row("Reconstruction mean", "recon_mean")
220
+ row("Trace max", "tr_max")
221
+ row("Determinant max", "det_max")
222
+ row("Char poly max", "cp_max")
223
+ row("Char poly mean", "cp_mean")
224
+ del A
225
+
226
+ # ── Accuracy pass/fail ──
227
+ print("\n" + "="*72)
228
+ print(" ACCURACY PASS/FAIL")
229
+ print("="*72)
230
+ ok_all = True
231
+ for nx in [3,4,5,6,8,10,12,16]:
232
+ A = (lambda R:(R+R.mT)/2)(torch.randn(1024, nx, nx, device=dev))
233
+ rv,rV = torch.linalg.eigh(A); fv,fV = FLEigh()(A)
234
+ ve = (fv-rv).abs().max().item()
235
+ dots = torch.bmm(rV.double().mT, fV.double()).abs().max(dim=-1).values.min().item()
236
+ ok = ve < 1e-2 and dots > 0.99
237
+ if not ok: ok_all = False
238
+ print(f" [{'OK' if ok else 'NO'}] n={nx:>2} val_diff={ve:.1e} align={dots:.6f}")
239
+ del A
240
+
241
+ # ── Speed ──
242
+ N=6; B=4096
243
+ A = (lambda R:(R+R.mT)/2)(torch.randn(B, N, N, device=dev))
244
+ solver = FLEigh()
245
+
246
+ print(f"\n" + "="*72)
247
+ print(f" THROUGHPUT (n={N} B={B})")
248
+ print("="*72)
249
+ for _ in range(5): solver(A); sync()
250
+ tr = gt(lambda: torch.linalg.eigh(A))
251
+ te = gt(lambda: solver(A))
252
+ print(f" cuSOLVER: {fmt(tr)}")
253
+ print(f" FL eager: {fmt(te)} ({tr/te:.2f}Γ—)")
254
+
255
+ try:
256
+ cs = torch.compile(solver, fullgraph=True)
257
+ print(" Compiling...", end=" ", flush=True)
258
+ for _ in range(3): cs(A); sync()
259
+ print("done.")
260
+ tc = gt(lambda: cs(A))
261
+ print(f" FL compiled: {fmt(tc)} ({tr/tc:.2f}Γ—)")
262
+ except Exception as e:
263
+ print(f" COMPILE FAILED: {str(e)[:100]}")
264
+ tc = None
265
+
266
+ # ── Memory ──
267
+ print(f"\n MEMORY")
268
+ for l,fn in [("cuSOLVER",lambda:torch.linalg.eigh(A)),("FL",lambda:solver(A))]:
269
+ torch.cuda.empty_cache(); gc.collect(); torch.cuda.reset_peak_memory_stats()
270
+ b=torch.cuda.memory_allocated(); fn(); sync()
271
+ print(f" {l:<10} {(torch.cuda.max_memory_allocated()-b)/1024**2:.1f}MB")
272
+
273
+ # ── Summary ──
274
+ print(f"\n" + "="*72)
275
+ print(f" All pass: {ok_all}")
276
+ if tc: print(f" Compiled: {tr/tc:.2f}Γ— vs cuSOLVER")
277
+ print("="*72)
278
+
279
+
280
+ if __name__ == '__main__':
281
+ main()