Create eigh_readme_version_tester.py
Browse files- eigh_readme_version_tester.py +281 -0
eigh_readme_version_tester.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
fl_eigh.py β Hybrid optimal eigendecomposition.
|
| 3 |
+
|
| 4 |
+
Combines FL algebraic superiority with geometric refinement:
|
| 5 |
+
Phase 1: FL characteristic polynomial (fp64) β algebraically exact coefficients
|
| 6 |
+
Phase 2: Laguerre + Newton polish β algebraically optimal eigenvalues
|
| 7 |
+
Phase 3: FL adjugate (fp64 Horner + max-col) β eigenvector extraction
|
| 8 |
+
Phase 4: Newton-Schulz β orthonormal eigenvectors (geometric projection)
|
| 9 |
+
Phase 5: Rayleigh quotient β Ξ»α΅’ = vα΅’α΅Avα΅’ (geometrically optimal eigenvalues)
|
| 10 |
+
|
| 11 |
+
The Rayleigh quotient is the KEY insight: given orthonormal eigenvectors V,
|
| 12 |
+
the eigenvalues Ξ»α΅’ = vα΅’α΅Avα΅’ minimize ||Av - Ξ»v||Β² β the eigenpair residual.
|
| 13 |
+
This fuses FL's algebraic precision with geometric optimality.
|
| 14 |
+
|
| 15 |
+
Result: eigenvalues that minimize residual + eigenvectors that are orthonormal.
|
| 16 |
+
Both reconstruction ||A - VΞVα΅|| and eigenpair ||Av - Ξ»v|| are optimal.
|
| 17 |
+
|
| 18 |
+
Author: AbstractPhil / GeoLIP project
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import math, time, gc, sys
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
from torch import Tensor
|
| 25 |
+
from typing import Tuple
|
| 26 |
+
|
| 27 |
+
# Disable TF32 tensor-core math and force full-precision fp32 matmuls:
# the accuracy comparisons in this file depend on strict IEEE fp32/fp64
# arithmetic, which TF32 would silently degrade.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.set_float32_matmul_precision('highest')
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class FLEigh(nn.Module):
    """Hybrid eigendecomposition for batched symmetric matrices.

    Pipeline (per the module docstring): Faddeev-LeVerrier characteristic
    polynomial (fp64) -> Laguerre root finding + Newton polish -> adjugate
    based eigenvector extraction -> Newton-Schulz orthonormalization ->
    Rayleigh-quotient eigenvalues.  Heavy algebra runs in fp64; returned
    eigenvectors are fp32.
    """

    def forward(self, A: Tensor) -> Tuple[Tensor, Tensor]:
        """Decompose a batch of matrices.

        Args:
            A: [B, n, n] batch.  Assumed symmetric (the trace shortcut in
               Phase 1 and the real-root clamp in Phase 2 rely on it) —
               TODO confirm callers never pass non-symmetric input.

        Returns:
            (eigenvalues [B, n] sorted ascending, eigenvectors [B, n, n]
            stored in the COLUMNS), mirroring torch.linalg.eigh.
        """
        B, n, _ = A.shape
        device = A.device

        # -- Pre-scale to ~unit Frobenius scale so polynomial coefficients
        # stay well-conditioned.  Eigenvectors are scale-invariant, and the
        # final eigenvalues are recovered from the ORIGINAL A in Phase 5,
        # so no un-scaling of eigenvalues is needed.
        scale = (torch.linalg.norm(A.reshape(B, -1), dim=-1) / math.sqrt(n)).clamp(min=1e-12)
        As = A / scale[:, None, None]

        # ------ Phase 1: Faddeev-LeVerrier (fp64) ------
        # n bmm -> characteristic polynomial coefficients c, plus the
        # matrix sequence Mk, which doubles as the Horner basis for the
        # adjugate evaluation in Phase 3.
        Ad = As.double()
        eye_d = torch.eye(n, device=device, dtype=torch.float64).unsqueeze(0).expand(B, -1, -1)
        c = torch.zeros(B, n + 1, device=device, dtype=torch.float64)
        c[:, n] = 1.0  # monic leading coefficient of det(lambda*I - A)
        Mstore = torch.zeros(n + 1, B, n, n, device=device, dtype=torch.float64)
        Mk = torch.zeros(B, n, n, device=device, dtype=torch.float64)
        for k in range(1, n + 1):
            Mk = torch.bmm(Ad, Mk) + c[:, n - k + 1, None, None] * eye_d
            Mstore[k] = Mk
            # (Ad * Mk).sum() == tr(Ad^T Mk); equals tr(Ad Mk) since Ad is symmetric.
            c[:, n - k] = -(Ad * Mk).sum((-2, -1)) / k

        # ------ Phase 2: Laguerre + deflation -> algebraic eigenvalues ------
        use_f64 = n > 6  # fp32 root finding suffices for small n and is faster
        dt = torch.float64 if use_f64 else torch.float32
        cl = c.to(dt).clone()  # working copy, deflated in place per found root
        roots = torch.zeros(B, n, device=device, dtype=dt)
        # Seeds: sorted diagonal entries, slightly spread to break ties.
        zi = As.to(dt).diagonal(dim1=-2, dim2=-1).sort(dim=-1).values
        zi = zi + torch.linspace(-1e-4, 1e-4, n, device=device, dtype=dt).unsqueeze(0)

        for ri in range(n):
            deg = n - ri  # degree of the currently deflated polynomial
            z = zi[:, ri]
            for _ in range(5):  # fixed 5 Laguerre steps per root
                # Horner pass: pv = p(z), dp = p'(z), d2 = p''(z)/2.
                pv = cl[:, deg]; dp = torch.zeros(B, device=device, dtype=dt)
                d2 = torch.zeros(B, device=device, dtype=dt)
                for j in range(deg - 1, -1, -1):
                    d2 = d2 * z + dp; dp = dp * z + pv; pv = pv * z + cl[:, j]
                ok = pv.abs() > 1e-30  # |p(z)| ~ 0 => converged, freeze update
                ps = torch.where(ok, pv, torch.ones_like(pv))
                G = torch.where(ok, dp / ps, torch.zeros_like(dp))
                H = G * G - torch.where(ok, 2.0 * d2 / ps, torch.zeros_like(d2))
                # Laguerre discriminant; clamped at 0 because the roots of a
                # symmetric matrix's characteristic polynomial are real.
                disc = ((deg - 1.0) * (deg * H - G * G)).clamp(min=0.0)
                sq = torch.sqrt(disc); gp = G + sq; gm = G - sq
                # Pick the larger-magnitude denominator -> smaller, stabler step.
                den = torch.where(gp.abs() >= gm.abs(), gp, gm)
                dok = den.abs() > 1e-20
                ds = torch.where(dok, den, torch.ones_like(den))
                z = z - torch.where(dok, float(deg) / ds, torch.zeros_like(den))
            roots[:, ri] = z
            # Synthetic division: deflate the found root out of cl in place.
            b = cl[:, deg]
            for j in range(deg - 1, 0, -1):
                bn = cl[:, j] + z * b; cl[:, j] = b; b = bn
            cl[:, 0] = b

        # Newton polish of ALL roots against the original (undeflated) fp64
        # polynomial c — removes error accumulated by sequential deflation.
        roots = roots.double()
        for _ in range(3):
            pv = torch.ones(B, n, device=device, dtype=torch.float64)
            dp = torch.zeros(B, n, device=device, dtype=torch.float64)
            for j in range(n - 1, -1, -1):
                dp = dp * roots + pv; pv = pv * roots + c[:, j:j + 1]
            ok = dp.abs() > 1e-30
            dps = torch.where(ok, dp, torch.ones_like(dp))
            roots = roots - torch.where(ok, pv / dps, torch.zeros_like(pv))

        # ------ Phase 3: FL adjugate -> eigenvector extraction (fp64) ------
        # Horner evaluation of adj(lambda*I - A) at each eigenvalue using the
        # stored Mk basis; any nonzero column of that adjugate is an
        # eigenvector for the corresponding eigenvalue.
        lam = roots  # [B, n] fp64
        R = Mstore[1].unsqueeze(1).expand(-1, n, -1, -1).clone()
        for k in range(2, n + 1):
            R = R * lam[:, :, None, None] + Mstore[k].unsqueeze(1)

        # Max-norm column extraction (robust for all n): take the column
        # with the largest norm per (batch, eigenvalue) pair.
        cnorms = R.norm(dim=-2)       # [B, n_eig, n_mat]
        best = cnorms.argmax(dim=-1)  # [B, n_eig]
        idx = best.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, n, 1)
        vec = R.gather(-1, idx).squeeze(-1)  # [B, n_eig, n_mat]
        vec = vec / (vec.norm(dim=-1, keepdim=True) + 1e-30)
        V = vec.float().transpose(-2, -1)  # [B, n, n] columns = eigvecs

        # ------ Phase 4: Newton-Schulz -> orthonormal eigenvectors ------
        # Coupled iteration driving V @ X toward an orthonormal frame.
        # NOTE(review): with T = 3I - Y this is not the textbook NS polar
        # iteration (whose correction is 3I - Y^2); to first order in
        # (V^T V - I) it matches (V^T V)^(-1/2), and the author's comment
        # below says 2 iterations are the empirical sweet spot — confirm
        # before changing the iteration count.
        # 2 iterations: stable for all n (3 diverges on near-degenerate cases)
        eye_f = torch.eye(n, device=device, dtype=torch.float32).unsqueeze(0).expand(B, -1, -1)
        Y = torch.bmm(V.transpose(-2, -1), V)
        X = eye_f.clone()
        for _ in range(2):
            T = 3.0 * eye_f - Y
            X = 0.5 * torch.bmm(X, T)
            Y = 0.5 * torch.bmm(T, Y)
        V = torch.bmm(V, X)

        # ------ Phase 5: Rayleigh quotient -> optimal eigenvalues ------
        # lambda_i = v_i^T A v_i minimizes ||A v - lambda v||^2 for fixed v.
        # Uses the ORIGINAL (unscaled) A, which also undoes the pre-scaling.
        AV = torch.bmm(A, V)          # [B, n, n]
        evals = (V * AV).sum(dim=-2)  # [B, n] = diag(V^T A V)

        # -- Sort ascending, permuting eigenvector columns to match --
        se, perm = evals.sort(dim=-1)
        sv = V.gather(-1, perm.unsqueeze(-2).expand_as(V))
        return se, sv
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 136 |
+
# Mathematical purity test
|
| 137 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 138 |
+
|
| 139 |
+
def math_test(A, vals, vecs):
    """Score an eigendecomposition purely against mathematical definitions.

    Computes, in fp64, per-batch max/mean of: eigenpair residual,
    orthogonality defect, reconstruction error, trace and determinant
    identities, and |det(lambda_i*I - A)| (characteristic-polynomial
    residual).  No reference decomposition is consulted.
    """
    batch, dim, _ = A.shape
    device = A.device
    A64 = A.double()
    lam64 = vals.double()
    V64 = vecs.double()

    # ||A||_F per batch element, floored to avoid division by zero.
    a_norm = A64.reshape(batch, -1).norm(dim=-1, keepdim=True).clamp(min=1e-30)

    # Per-eigenvector residual ||A v_i - lambda_i v_i|| / ||A||_F.
    scaled_V = V64 * lam64.unsqueeze(-2)
    residual = (torch.bmm(A64, V64) - scaled_V).norm(dim=-2) / a_norm

    # Orthogonality defect ||V^T V - I||_F.
    identity = torch.eye(dim, device=device, dtype=torch.float64).unsqueeze(0)
    gram = torch.bmm(V64.mT, V64)
    ortho = (gram - identity).reshape(batch, -1).norm(dim=-1)

    # Relative reconstruction error ||A - V Lambda V^T||_F / ||A||_F.
    rebuilt = torch.bmm(scaled_V, V64.mT)
    rebuild_err = (A64 - rebuilt).reshape(batch, -1).norm(dim=-1) / a_norm.squeeze(-1)

    # Trace and determinant identities: tr(A) = sum(lambda), det(A) = prod(lambda).
    trace_err = (A64.diagonal(dim1=-2, dim2=-1).sum(-1) - lam64.sum(-1)).abs()
    detA = torch.linalg.det(A64)
    det_rel_err = (detA - lam64.prod(-1)).abs() / detA.abs().clamp(min=1e-30)

    # Characteristic-polynomial residual |det(lambda_i*I - A)| per eigenvalue.
    charpoly = torch.zeros(batch, dim, device=device, dtype=torch.float64)
    for i in range(dim):
        charpoly[:, i] = torch.linalg.det(lam64[:, i:i + 1, None] * identity - A64).abs()

    return dict(
        res_max=residual.max().item(), res_mean=residual.mean().item(),
        orth_max=ortho.max().item(), orth_mean=ortho.mean().item(),
        recon_max=rebuild_err.max().item(), recon_mean=rebuild_err.mean().item(),
        tr_max=trace_err.max().item(), tr_mean=trace_err.mean().item(),
        det_max=det_rel_err.max().item(), det_mean=det_rel_err.mean().item(),
        cp_max=charpoly.max().item(), cp_mean=charpoly.mean().item(),
    )
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def sync():
    """Block the host thread until all queued CUDA work has completed."""
    torch.cuda.synchronize()
|
| 168 |
+
def gt(fn, w=20, r=300):
    """Mean wall-clock seconds per call of *fn*: *w* warmup calls, then *r* timed reps."""
    for _ in range(w):
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(r):
        fn()
    sync()
    return (time.perf_counter() - start) / r
|
| 173 |
+
def fmt(s):
    """Render a duration *s* (seconds) as a human-readable µs / ms / s string."""
    if s >= 1:
        return f"{s:.3f}s"
    if s >= 1e-3:
        return f"{s * 1e3:.2f}ms"
    return f"{s * 1e6:.1f}µs"
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def main():
    """Benchmark and validate FLEigh against torch.linalg.eigh (cuSOLVER).

    Sections: mathematical-purity metric comparison, accuracy pass/fail,
    throughput (eager and torch.compile), and peak-memory comparison.
    Requires a CUDA device; exits with status 1 otherwise.

    NOTE(review): several non-ASCII glyphs in the banner/marker strings
    appear to be mojibake from a bad encoding round-trip; they are kept
    byte-identical here — confirm intended glyphs against the original file.
    """
    if not torch.cuda.is_available(): sys.exit(1)
    dev = torch.device('cuda')
    p = torch.cuda.get_device_properties(0)

    print("="*72)
    print(" FL Hybrid Eigh β Algebraic + Geometric Optimal")
    print("="*72)
    print(f" {p.name} | PyTorch {torch.__version__}")

    # -- Mathematical purity sweep: definition-based metrics only --
    print("\n" + "="*72)
    print(" MATHEMATICAL PURITY (no reference impl, only definitions)")
    print("="*72)

    for nx in [3, 5, 6, 8, 10, 12]:
        # Smaller batch for larger n to bound time/memory.
        B = 2048 if nx <= 8 else 1024
        # Random symmetric test batch: (R + R^T)/2.
        A = (lambda R:(R+R.mT)/2)(torch.randn(B, nx, nx, device=dev))
        cv, cV = torch.linalg.eigh(A)
        fv, fV = FLEigh()(A)
        mc = math_test(A, cv, cV); mf = math_test(A, fv, fV)

        # Tally which solver wins each of the 12 metrics (lower is better).
        # NOTE(review): "/12" is hard-coded to len(mc) — keep in sync with
        # math_test's return dict.
        wins_c = 0; wins_f = 0
        for key in mc:
            if mf[key] < mc[key]: wins_f += 1
            elif mc[key] < mf[key]: wins_c += 1
        print(f"\n n={nx} B={B}: FL wins {wins_f}/12, cuSOLVER wins {wins_c}/12")

        def row(name, key):
            # One comparison line; closes over this iteration's mc/mf.
            vc=mc[key]; vf=mf[key]
            w="FL" if vf<vc else ("cuS" if vc<vf else "tie")
            m="β" if w=="FL" else ("βΊ" if w=="cuS" else " ")
            print(f" {name:<28} {vc:>10.1e} {vf:>10.1e} {w} {m}")

        print(f" {'Property':<28} {'cuSOLVER':>10} {'FL':>10}")
        row("Eigenpair max", "res_max")
        row("Eigenpair mean", "res_mean")
        row("Orthogonality max", "orth_max")
        row("Orthogonality mean", "orth_mean")
        row("Reconstruction max", "recon_max")
        row("Reconstruction mean", "recon_mean")
        row("Trace max", "tr_max")
        row("Determinant max", "det_max")
        row("Char poly max", "cp_max")
        row("Char poly mean", "cp_mean")
        del A  # free the batch before the next size

    # -- Accuracy pass/fail vs the reference decomposition --
    print("\n" + "="*72)
    print(" ACCURACY PASS/FAIL")
    print("="*72)
    ok_all = True
    for nx in [3,4,5,6,8,10,12,16]:
        A = (lambda R:(R+R.mT)/2)(torch.randn(1024, nx, nx, device=dev))
        rv,rV = torch.linalg.eigh(A); fv,fV = FLEigh()(A)
        ve = (fv-rv).abs().max().item()
        # Worst-case alignment: for each reference eigvec, best |cos| against
        # any FL eigvec, then min over eigvecs (sign/order invariant).
        dots = torch.bmm(rV.double().mT, fV.double()).abs().max(dim=-1).values.min().item()
        ok = ve < 1e-2 and dots > 0.99
        if not ok: ok_all = False
        print(f" [{'OK' if ok else 'NO'}] n={nx:>2} val_diff={ve:.1e} align={dots:.6f}")
    del A

    # -- Throughput benchmark on a single representative size --
    N=6; B=4096
    A = (lambda R:(R+R.mT)/2)(torch.randn(B, N, N, device=dev))
    solver = FLEigh()

    print(f"\n" + "="*72)
    print(f" THROUGHPUT (n={N} B={B})")
    print("="*72)
    for _ in range(5): solver(A); sync()  # warm up kernels/allocator before timing
    tr = gt(lambda: torch.linalg.eigh(A))
    te = gt(lambda: solver(A))
    print(f" cuSOLVER: {fmt(tr)}")
    print(f" FL eager: {fmt(te)} ({tr/te:.2f}Γ)")

    try:
        cs = torch.compile(solver, fullgraph=True)
        print(" Compiling...", end=" ", flush=True)
        for _ in range(3): cs(A); sync()  # trigger compilation outside the timed loop
        print("done.")
        tc = gt(lambda: cs(A))
        print(f" FL compiled: {fmt(tc)} ({tr/tc:.2f}Γ)")
    except Exception as e:
        # Best-effort: compile failure is reported, not fatal.
        print(f" COMPILE FAILED: {str(e)[:100]}")
        tc = None

    # -- Peak-memory comparison for one solve of each implementation --
    print(f"\n MEMORY")
    for l,fn in [("cuSOLVER",lambda:torch.linalg.eigh(A)),("FL",lambda:solver(A))]:
        torch.cuda.empty_cache(); gc.collect(); torch.cuda.reset_peak_memory_stats()
        b=torch.cuda.memory_allocated(); fn(); sync()
        print(f" {l:<10} {(torch.cuda.max_memory_allocated()-b)/1024**2:.1f}MB")

    # -- Summary --
    print(f"\n" + "="*72)
    print(f" All pass: {ok_all}")
    # NOTE(review): truthiness guard — a (practically impossible) tc == 0.0
    # would also be skipped; `tc is not None` would be stricter.
    if tc: print(f" Compiled: {tr/tc:.2f}Γ vs cuSOLVER")
    print("="*72)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if __name__ == '__main__':
|
| 281 |
+
main()
|