| """ |
| FL Eigh β Performance Benchmark |
| Both modes: Fast (fp32 eigvecs, ~240Β΅s) and Precise (fp64 eigvecs + Rayleigh, ~340Β΅s) |
| vs cuSOLVER baseline. |
| |
| This is a full-graph compile capable higher-accuracy variation of the eigens formula at a small dip in performance overall. |
| """ |
| import math, time, gc, sys |
| import torch |
| import torch.nn as nn |
| from torch import Tensor |
| from typing import Tuple |
|
|
| torch.backends.cuda.matmul.allow_tf32 = False |
| torch.backends.cudnn.allow_tf32 = False |
| torch.set_float32_matmul_precision('highest') |
|
|
|
|
| |
| |
| |
|
|
| def _fl_coefficients(As, B, n, device): |
| """Phase 1: Faddeev-LeVerrier in fp64.""" |
| Ad = As.double() |
| eye_d = torch.eye(n, device=device, dtype=torch.float64).unsqueeze(0).expand(B, -1, -1) |
| c = torch.zeros(B, n + 1, device=device, dtype=torch.float64) |
| c[:, n] = 1.0 |
| Mstore = torch.zeros(n + 1, B, n, n, device=device, dtype=torch.float64) |
| Mk = torch.zeros(B, n, n, device=device, dtype=torch.float64) |
| for k in range(1, n + 1): |
| Mk = torch.bmm(Ad, Mk) + c[:, n - k + 1, None, None] * eye_d |
| Mstore[k] = Mk |
| c[:, n - k] = -(Ad * Mk).sum((-2, -1)) / k |
| return c, Mstore |
|
|
|
|
def _laguerre_polish(c, As, B, n, device):
    """Phase 2: Laguerre + deflation + Newton polish.

    c  : (B, n+1) monic characteristic-polynomial coefficients (fp64),
         lowest degree first, from `_fl_coefficients`.
    As : (B, n, n) scaled input batch, used only for initial guesses.
    Returns (B, n) fp64 root estimates (eigenvalues of the scaled matrix).
    """
    # Working precision for the Laguerre/deflation sweep: fp64 only for
    # larger degrees, where fp32 deflation error would matter.
    use_f64 = n > 6
    dt = torch.float64 if use_f64 else torch.float32
    cl = c.to(dt).clone()  # deflated-coefficient workspace (mutated below)
    roots = torch.zeros(B, n, device=device, dtype=dt)
    # Initial guesses: sorted diagonal entries, spread by tiny offsets so
    # equal guesses don't all collapse onto the same root.
    zi = As.to(dt).diagonal(dim1=-2, dim2=-1).sort(dim=-1).values
    zi = zi + torch.linspace(-1e-4, 1e-4, n, device=device, dtype=dt).unsqueeze(0)
    for ri in range(n):
        deg = n - ri; z = zi[:, ri]  # current deflated degree and start point
        for _ in range(5):  # fixed iteration count (keeps the compiled graph static)
            # Fused Horner: pv = p(z), dp = p'(z), d2 = p''(z)/2 on the
            # current deflated polynomial cl[:, :deg+1].
            pv = cl[:, deg]; dp = torch.zeros(B, device=device, dtype=dt)
            d2 = torch.zeros(B, device=device, dtype=dt)
            for j in range(deg - 1, -1, -1):
                d2 = d2 * z + dp; dp = dp * z + pv; pv = pv * z + cl[:, j]
            # Mask lanes where p(z) ~ 0 (already converged) to avoid 0/0.
            ok = pv.abs() > 1e-30
            ps = torch.where(ok, pv, torch.ones_like(pv))
            G = torch.where(ok, dp / ps, torch.zeros_like(dp))
            H = G * G - torch.where(ok, 2.0 * d2 / ps, torch.zeros_like(d2))
            # Laguerre discriminant; clamped to 0 since symmetric input
            # implies real roots (negatives are round-off).
            disc = ((deg - 1.0) * (deg * H - G * G)).clamp(min=0.0)
            sq = torch.sqrt(disc); gp = G + sq; gm = G - sq
            # Use the larger-magnitude denominator for numerical stability.
            den = torch.where(gp.abs() >= gm.abs(), gp, gm)
            dok = den.abs() > 1e-20
            ds = torch.where(dok, den, torch.ones_like(den))
            z = z - torch.where(dok, float(deg) / ds, torch.zeros_like(den))
        roots[:, ri] = z
        # Synthetic division: deflate the found root out of cl in place.
        b = cl[:, deg]
        for j in range(deg - 1, 0, -1):
            bn = cl[:, j] + z * b; cl[:, j] = b; b = bn
        cl[:, 0] = b
    # Final Newton polish in fp64 against the ORIGINAL (undeflated)
    # coefficients c, removing deflation drift; all n roots at once.
    roots = roots.double()
    for _ in range(3):
        pv = torch.ones(B, n, device=device, dtype=torch.float64)
        dp = torch.zeros(B, n, device=device, dtype=torch.float64)
        for j in range(n - 1, -1, -1):
            dp = dp * roots + pv; pv = pv * roots + c[:, j:j + 1]
        ok = dp.abs() > 1e-30
        dps = torch.where(ok, dp, torch.ones_like(dp))
        roots = roots - torch.where(ok, pv / dps, torch.zeros_like(pv))
    return roots
|
|
|
|
| |
| |
| |
|
|
class FLEighFast(nn.Module):
    """Fast mode: eigenvalues via FL + Laguerre (fp64), eigenvectors in fp32.

    Eigenvectors are assembled from the FL matrix chain with a cheap
    column-sum heuristic, then orthonormalized with two Newton-Schulz-style
    iterations.
    """

    def forward(self, A: Tensor) -> Tuple[Tensor, Tensor]:
        """Return (eigenvalues ascending, eigenvectors as columns) for a
        batch A of shape (B, n, n); A is assumed symmetric."""
        B, n, _ = A.shape; device = A.device
        # Scale by per-matrix Frobenius norm / sqrt(n) so the polynomial
        # coefficients stay well-conditioned; clamp guards zero matrices.
        scale = (torch.linalg.norm(A.reshape(B, -1), dim=-1) / math.sqrt(n)).clamp(min=1e-12)
        As = A / scale[:, None, None]
        c, Mstore = _fl_coefficients(As, B, n, device)
        roots = _laguerre_polish(c, As, B, n, device)
        evals_f = roots.float()

        # Horner-evaluate the FL matrix polynomial at each eigenvalue (fp32);
        # columns of the result are eigenvector candidates.
        Mf = Mstore.float()
        R = Mf[1].unsqueeze(1).expand(-1, n, -1, -1).clone()
        for k in range(2, n + 1):
            R = R * evals_f[:, :, None, None] + Mf[k].unsqueeze(1)
        vec = R.sum(dim=-1)  # column sum: cheap combination of candidates
        vnorm = vec.norm(dim=-1, keepdim=True)
        # If the column sum cancels out, fall back to the first column.
        vec = torch.where(vnorm > 1e-10, vec, R[:, :, :, 0])
        vec = vec / (vec.norm(dim=-1, keepdim=True) + 1e-30)
        V = vec.transpose(-2, -1)

        # Two coupled Newton-Schulz-style sweeps push V's columns toward
        # an orthonormal set (X accumulates the inverse-sqrt of V^T V).
        eye_f = torch.eye(n, device=device, dtype=torch.float32).unsqueeze(0).expand(B, -1, -1)
        Y = torch.bmm(V.mT, V); X = eye_f.clone()
        for _ in range(2):
            T = 3.0 * eye_f - Y; X = 0.5 * torch.bmm(X, T); Y = 0.5 * torch.bmm(T, Y)
        V = torch.bmm(V, X)

        # Rayleigh quotients against the ORIGINAL (unscaled) A recover the
        # eigenvalues at the right scale; sort ascending to match eigh.
        AV = torch.bmm(A, V); evals = (V * AV).sum(dim=-2)
        se, perm = evals.sort(dim=-1)
        return se, V.gather(-1, perm.unsqueeze(-2).expand_as(V))
|
|
|
|
| |
| |
| |
|
|
class FLEighPrecise(nn.Module):
    """Precise mode: same pipeline as FLEighFast, but the eigenvector
    assembly stays in fp64 and uses the largest-norm column of the FL
    matrix polynomial (instead of the cheaper column sum) before the fp32
    orthonormalization step.
    """

    def forward(self, A: Tensor) -> Tuple[Tensor, Tensor]:
        """Return (eigenvalues ascending, eigenvectors as columns) for a
        batch A of shape (B, n, n); A is assumed symmetric."""
        B, n, _ = A.shape; device = A.device
        # Frobenius-norm scaling keeps polynomial coefficients well-behaved.
        scale = (torch.linalg.norm(A.reshape(B, -1), dim=-1) / math.sqrt(n)).clamp(min=1e-12)
        As = A / scale[:, None, None]
        c, Mstore = _fl_coefficients(As, B, n, device)
        roots = _laguerre_polish(c, As, B, n, device)

        lam = roots  # fp64 eigenvalue estimates of the scaled matrix
        # Horner-evaluate the FL matrix polynomial at each eigenvalue (fp64).
        R = Mstore[1].unsqueeze(1).expand(-1, n, -1, -1).clone()
        for k in range(2, n + 1):
            R = R * lam[:, :, None, None] + Mstore[k].unsqueeze(1)
        # Select, per eigenvalue, the column with the largest norm as the
        # eigenvector candidate (more robust than a column sum).
        cnorms = R.norm(dim=-2)
        best = cnorms.argmax(dim=-1)
        idx = best.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, n, 1)
        vec = R.gather(-1, idx).squeeze(-1)
        vec = vec / (vec.norm(dim=-1, keepdim=True) + 1e-30)
        V = vec.float().transpose(-2, -1)  # drop to fp32 for orthonormalization

        # Two coupled Newton-Schulz-style sweeps orthonormalize V's columns.
        eye_f = torch.eye(n, device=device, dtype=torch.float32).unsqueeze(0).expand(B, -1, -1)
        Y = torch.bmm(V.mT, V); X = eye_f.clone()
        for _ in range(2):
            T = 3.0 * eye_f - Y; X = 0.5 * torch.bmm(X, T); Y = 0.5 * torch.bmm(T, Y)
        V = torch.bmm(V, X)

        # Rayleigh quotients against the original A, sorted ascending.
        AV = torch.bmm(A, V); evals = (V * AV).sum(dim=-2)
        se, perm = evals.sort(dim=-1)
        return se, V.gather(-1, perm.unsqueeze(-2).expand_as(V))
|
|
|
|
| |
| |
| |
|
|
def sync():
    """Block the host until all queued CUDA work has completed."""
    torch.cuda.synchronize()
|
|
def gpu_t(fn, w=20, r=200):
    """Average wall-clock seconds per call of `fn` on the GPU.

    Runs `w` warmup calls, synchronizes, then times `r` calls bracketed by
    device syncs so queued kernels are fully accounted for.
    """
    for _ in range(w):
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(r):
        fn()
    sync()
    return (time.perf_counter() - start) / r
|
|
def fmt(s):
    """Render a duration `s` (seconds) as a short human-readable string.

    < 1 ms -> microseconds, < 1 s -> milliseconds, otherwise seconds.
    """
    if s < 1e-3: return f"{s*1e6:.1f}µs"  # fixed: was mojibake "Β΅s" for "µs"
    if s < 1.0: return f"{s*1e3:.2f}ms"
    return f"{s:.3f}s"
|
|
def make(B, n, dev):
    """Return a random symmetric batch of shape (B, n, n) on device `dev`."""
    G = torch.randn(B, n, n, device=dev)
    return (G + G.mT) / 2  # symmetrize
|
|
|
|
def main():
    """Drive the full benchmark suite against torch.linalg.eigh (cuSOLVER):
    primary timings, batch scaling, size scaling, peak memory, throughput.
    Requires a CUDA device."""
    if not torch.cuda.is_available(): sys.exit(1)  # CUDA-only benchmark
    dev = torch.device('cuda')
    p = torch.cuda.get_device_properties(0)  # hardware info for the banner

    print("=" * 72)
    print(" FL Eigh β Performance Benchmark")
    print("=" * 72)
    print(f" {p.name}")
    print(f" PyTorch {torch.__version__}")
    print(f" VRAM: {p.total_memory / 1024**3:.1f} GB")

    # Primary configuration: batch of 4096 symmetric 6x6 matrices.
    N = 6; B = 4096
    A = make(B, N, dev)

    fast = FLEighFast()
    precise = FLEighPrecise()

    # Compile each module as a single graph; three warmup calls ensure
    # compilation cost is excluded from every timed region below.
    print("\n Compiling Fast (fullgraph=True)...", end=" ", flush=True)
    c_fast = torch.compile(fast, fullgraph=True)
    for _ in range(3): c_fast(A); sync()
    print("done.")

    print(" Compiling Precise (fullgraph=True)...", end=" ", flush=True)
    c_precise = torch.compile(precise, fullgraph=True)
    for _ in range(3): c_precise(A); sync()
    print("done.")

    # --- Primary: eager and compiled timings vs the cuSOLVER baseline ---
    print(f"\n" + "=" * 72)
    print(f" PRIMARY: n={N} B={B}")
    print("=" * 72)

    t_ref = gpu_t(lambda: torch.linalg.eigh(A))
    t_fast = gpu_t(lambda: c_fast(A))
    t_prec = gpu_t(lambda: c_precise(A))
    t_fast_e = gpu_t(lambda: fast(A))
    t_prec_e = gpu_t(lambda: precise(A))

    print(f"\n {'Implementation':<28} {'Eager':>10} {'Compiled':>10} {'vs cuSOLVER':>12}")
    print(f" {'β'*28} {'β'*10} {'β'*10} {'β'*12}")
    print(f" {'cuSOLVER':<28} {'β':>10} {fmt(t_ref):>10} {'1.00Γ':>12}")
    print(f" {'FL Fast (fp32 eigvec)':<28} {fmt(t_fast_e):>10} {fmt(t_fast):>10} {t_ref/t_fast:>11.2f}Γ")
    print(f" {'FL Precise (fp64+Rayleigh)':<28} {fmt(t_prec_e):>10} {fmt(t_prec):>10} {t_ref/t_prec:>11.2f}Γ")

    # --- Batch scaling: one dynamic-shape compile reused for every B ---
    print(f"\n" + "=" * 72)
    print(f" BATCH SCALING (n={N}, compiled, dynamic=True)")
    print("=" * 72)

    cd_fast = torch.compile(FLEighFast(), fullgraph=True, dynamic=True)
    cd_prec = torch.compile(FLEighPrecise(), fullgraph=True, dynamic=True)

    cd_fast(A); cd_prec(A); sync()  # one warmup at the base shape

    print(f"\n {'B':>6} {'cuSOLVER':>10} {'FL Fast':>10} {'F/cuS':>7} {'FL Prec':>10} {'P/cuS':>7}")
    print(f" {'β'*6} {'β'*10} {'β'*10} {'β'*7} {'β'*10} {'β'*7}")

    for Bx in [256, 512, 1024, 2048, 4096, 8192, 16384]:
        try:
            Ax = make(Bx, N, dev)
            # Warm this batch size once so any recompilation is not timed.
            cd_fast(Ax); cd_prec(Ax); sync()
            tr = gpu_t(lambda: torch.linalg.eigh(Ax), 10, 100)
            tf = gpu_t(lambda: cd_fast(Ax), 10, 100)
            tp = gpu_t(lambda: cd_prec(Ax), 10, 100)
            print(f" {Bx:>6} {fmt(tr):>10} {fmt(tf):>10} {tr/tf:>6.2f}Γ {fmt(tp):>10} {tr/tp:>6.2f}Γ")
            del Ax
        except RuntimeError:
            # Most likely out of memory at the largest batches.
            print(f" {Bx:>6} OOM"); torch.cuda.empty_cache()

    # --- Size scaling: fresh static compile per matrix size ---
    print(f"\n" + "=" * 72)
    print(f" SIZE SCALING (B={B}, compiled per size)")
    print("=" * 72)

    print(f"\n {'n':>3} {'cuSOLVER':>10} {'FL Fast':>10} {'F/cuS':>7} {'FL Prec':>10} {'P/cuS':>7}")
    print(f" {'β'*3} {'β'*10} {'β'*10} {'β'*7} {'β'*10} {'β'*7}")

    for nx in [3, 4, 5, 6, 8, 10, 12, 16]:
        try:
            Ax = make(B, nx, dev)
            sf = torch.compile(FLEighFast(), fullgraph=True)
            sp = torch.compile(FLEighPrecise(), fullgraph=True)
            for _ in range(3): sf(Ax); sp(Ax); sync()  # compile + warm up
            tr = gpu_t(lambda: torch.linalg.eigh(Ax), 10, 100)
            tf = gpu_t(lambda: sf(Ax), 10, 100)
            tp = gpu_t(lambda: sp(Ax), 10, 100)
            print(f" {nx:>3} {fmt(tr):>10} {fmt(tf):>10} {tr/tf:>6.2f}Γ {fmt(tp):>10} {tr/tp:>6.2f}Γ")
            del Ax, sf, sp
        except Exception as e:
            print(f" {nx:>3} ERR: {str(e)[:40]}")
            torch.cuda.empty_cache()

    # --- Peak allocator memory for a single (eager) forward pass ---
    print(f"\n" + "=" * 72)
    print(" MEMORY (n=6 B=4096)")
    print("=" * 72)

    for label, fn in [("cuSOLVER", lambda: torch.linalg.eigh(A)),
                      ("FL Fast", lambda: fast(A)),
                      ("FL Precise", lambda: precise(A))]:
        torch.cuda.empty_cache(); gc.collect()
        torch.cuda.reset_peak_memory_stats()
        base = torch.cuda.memory_allocated()
        fn(); sync()
        # Delta of peak allocated over the pre-call baseline, in MB.
        delta = (torch.cuda.max_memory_allocated() - base) / 1024**2
        print(f" {label:<16} {delta:.1f} MB")

    # --- Throughput (matrices decomposed per second) at larger workloads ---
    print(f"\n" + "=" * 72)
    print(" THROUGHPUT AT SCALE (matrices/second)")
    print("=" * 72)

    for nx, Bx in [(5, 8192), (6, 8192), (6, 16384), (8, 4096)]:
        try:
            Ax = make(Bx, nx, dev)
            sf = torch.compile(FLEighFast(), fullgraph=True)
            for _ in range(3): sf(Ax); sync()
            tr = gpu_t(lambda: torch.linalg.eigh(Ax), 10, 100)
            tf = gpu_t(lambda: sf(Ax), 10, 100)
            thr_r = Bx / tr; thr_f = Bx / tf  # matrices per second
            print(f" n={nx} B={Bx:>5}: cuSOLVER {thr_r/1e6:.2f}M/s FL {thr_f/1e6:.2f}M/s ({tf/tr:.2f}Γ time)")
            del Ax, sf
        except Exception as e:
            print(f" n={nx} B={Bx:>5}: {str(e)[:40]}")
            torch.cuda.empty_cache()

    # --- Summary of the primary-configuration numbers ---
    print(f"\n" + "=" * 72)
    print(f" Fast compiled (n=6 B=4096): {fmt(t_fast)} ({t_ref/t_fast:.2f}Γ vs cuSOLVER)")
    print(f" Precise compiled: {fmt(t_prec)} ({t_ref/t_prec:.2f}Γ vs cuSOLVER)")
    print(f" cuSOLVER: {fmt(t_ref)}")
    print("=" * 72)
|
|
|
|
# Script entry point: run the benchmark only when executed directly.
if __name__ == '__main__':
    main()