# eigh-triton / eigh_pytorch_version_27.py
# Author: AbstractPhil — "Create eigh_pytorch_version_27.py" (commit 8fd0bfa, verified)
"""
FL Eigh β€” Performance Benchmark
Both modes: Fast (fp32 eigvecs, ~240Β΅s) and Precise (fp64 eigvecs + Rayleigh, ~340Β΅s)
vs cuSOLVER baseline.
This is a full-graph compile capable higher-accuracy variation of the eigens formula at a small dip in performance overall.
"""
import math, time, gc, sys
import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.set_float32_matmul_precision('highest')
# ═══════════════════════════════════════════════════════════════════════
# Core: FL phases shared by both modes
# ═══════════════════════════════════════════════════════════════════════
def _fl_coefficients(As, B, n, device):
"""Phase 1: Faddeev-LeVerrier in fp64."""
Ad = As.double()
eye_d = torch.eye(n, device=device, dtype=torch.float64).unsqueeze(0).expand(B, -1, -1)
c = torch.zeros(B, n + 1, device=device, dtype=torch.float64)
c[:, n] = 1.0
Mstore = torch.zeros(n + 1, B, n, n, device=device, dtype=torch.float64)
Mk = torch.zeros(B, n, n, device=device, dtype=torch.float64)
for k in range(1, n + 1):
Mk = torch.bmm(Ad, Mk) + c[:, n - k + 1, None, None] * eye_d
Mstore[k] = Mk
c[:, n - k] = -(Ad * Mk).sum((-2, -1)) / k
return c, Mstore
def _laguerre_polish(c, As, B, n, device):
    """Phase 2: Laguerre iteration + deflation, then a Newton polish.

    Extracts the n real roots (eigenvalues of the scaled matrix) of the monic
    characteristic polynomial produced by _fl_coefficients.

    Args:
        c: (B, n+1) fp64 coefficients in ascending powers; c[:, n] == 1.
        As: (B, n, n) scaled batch; only its diagonal is read, as root seeds.
        B: batch size.
        n: polynomial degree (matrix dimension).
        device: device for intermediate tensors.

    Returns:
        (B, n) fp64 tensor of polished roots, one per (sorted) seed.
    """
    # fp32 is fast enough for small degrees; switch to fp64 above n = 6.
    use_f64 = n > 6
    dt = torch.float64 if use_f64 else torch.float32
    cl = c.to(dt).clone()  # working copy: destroyed in-place by deflation
    roots = torch.zeros(B, n, device=device, dtype=dt)
    # Seeds: sorted diagonal entries plus a tiny linspace jitter to break ties.
    zi = As.to(dt).diagonal(dim1=-2, dim2=-1).sort(dim=-1).values
    zi = zi + torch.linspace(-1e-4, 1e-4, n, device=device, dtype=dt).unsqueeze(0)
    for ri in range(n):
        deg = n - ri; z = zi[:, ri]  # current (deflated) degree and seed
        # Five fixed Laguerre steps; masks turn degenerate updates into no-ops
        # instead of producing NaN/Inf (keeps the graph fully compilable).
        for _ in range(5):
            # Horner: pv = p(z), dp = p'(z), d2 = p''(z)/2 on the deflated poly.
            pv = cl[:, deg]; dp = torch.zeros(B, device=device, dtype=dt)
            d2 = torch.zeros(B, device=device, dtype=dt)
            for j in range(deg - 1, -1, -1):
                d2 = d2 * z + dp; dp = dp * z + pv; pv = pv * z + cl[:, j]
            ok = pv.abs() > 1e-30  # guard: at a root already, skip the update
            ps = torch.where(ok, pv, torch.ones_like(pv))
            G = torch.where(ok, dp / ps, torch.zeros_like(dp))
            H = G * G - torch.where(ok, 2.0 * d2 / ps, torch.zeros_like(d2))
            # clamp(min=0): all roots are real, so negative discriminants are noise
            disc = ((deg - 1.0) * (deg * H - G * G)).clamp(min=0.0)
            sq = torch.sqrt(disc); gp = G + sq; gm = G - sq
            # Pick the larger-magnitude denominator for numerical stability.
            den = torch.where(gp.abs() >= gm.abs(), gp, gm)
            dok = den.abs() > 1e-20
            ds = torch.where(dok, den, torch.ones_like(den))
            z = z - torch.where(dok, float(deg) / ds, torch.zeros_like(den))
        roots[:, ri] = z
        # Deflate: synthetic division of cl by (x - z), in place.
        b = cl[:, deg]
        for j in range(deg - 1, 0, -1):
            bn = cl[:, j] + z * b; cl[:, j] = b; b = bn
        cl[:, 0] = b
    # Polish: three Newton steps on the ORIGINAL (undeflated) fp64 polynomial,
    # removing error accumulated through successive deflations.
    roots = roots.double()
    for _ in range(3):
        # Horner with implicit leading coefficient 1 (pv starts at ones).
        pv = torch.ones(B, n, device=device, dtype=torch.float64)
        dp = torch.zeros(B, n, device=device, dtype=torch.float64)
        for j in range(n - 1, -1, -1):
            dp = dp * roots + pv; pv = pv * roots + c[:, j:j + 1]
        ok = dp.abs() > 1e-30  # guard against flat derivative
        dps = torch.where(ok, dp, torch.ones_like(dp))
        roots = roots - torch.where(ok, pv / dps, torch.zeros_like(pv))
    return roots
# ═══════════════════════════════════════════════════════════════════════
# Fast Mode: fp32 eigvecs + sum-of-columns (best speed)
# ═══════════════════════════════════════════════════════════════════════
class FLEighFast(nn.Module):
    """Fast mode: FL coefficients + Laguerre roots, fp32 eigenvectors.

    Eigenvectors are reconstructed from the FL auxiliary matrices via a
    broadcast Horner evaluation with a sum-of-columns extraction, then
    orthogonalized with two Newton-Schulz steps; eigenvalues are refined
    with a Rayleigh quotient against the original (unscaled) A.
    """

    def forward(self, A: Tensor) -> Tuple[Tensor, Tensor]:
        """Return (eigenvalues, eigenvectors) of a symmetric batch, ascending.

        Args:
            A: (B, n, n) batch of symmetric matrices.

        Returns:
            (B, n) sorted eigenvalues and (B, n, n) eigenvector columns,
            matching torch.linalg.eigh's layout.
        """
        B, n, _ = A.shape; device = A.device
        # Frobenius-norm scaling keeps the characteristic polynomial conditioned.
        scale = (torch.linalg.norm(A.reshape(B, -1), dim=-1) / math.sqrt(n)).clamp(min=1e-12)
        As = A / scale[:, None, None]
        c, Mstore = _fl_coefficients(As, B, n, device)
        roots = _laguerre_polish(c, As, B, n, device)
        evals_f = roots.float()
        # fp32 eigvecs: broadcast Horner + sum-of-columns
        Mf = Mstore.float()
        R = Mf[1].unsqueeze(1).expand(-1, n, -1, -1).clone()
        for k in range(2, n + 1):
            # Horner in lambda: R <- R * lambda_i + M_k, per eigenvalue i.
            R = R * evals_f[:, :, None, None] + Mf[k].unsqueeze(1)
        # Column sum gives an (unnormalized) eigenvector candidate.
        vec = R.sum(dim=-1)
        vnorm = vec.norm(dim=-1, keepdim=True)
        # Fall back to the first column where the sum cancelled to ~zero.
        vec = torch.where(vnorm > 1e-10, vec, R[:, :, :, 0])
        vec = vec / (vec.norm(dim=-1, keepdim=True) + 1e-30)
        V = vec.transpose(-2, -1)
        # NS orthogonalization: two coupled Newton-Schulz steps toward V (V^T V)^{-1/2}.
        eye_f = torch.eye(n, device=device, dtype=torch.float32).unsqueeze(0).expand(B, -1, -1)
        Y = torch.bmm(V.mT, V); X = eye_f.clone()
        for _ in range(2):
            T = 3.0 * eye_f - Y; X = 0.5 * torch.bmm(X, T); Y = 0.5 * torch.bmm(T, Y)
        V = torch.bmm(V, X)
        # Rayleigh quotient against the ORIGINAL A undoes the scaling exactly.
        AV = torch.bmm(A, V); evals = (V * AV).sum(dim=-2)
        se, perm = evals.sort(dim=-1)
        return se, V.gather(-1, perm.unsqueeze(-2).expand_as(V))
# ═══════════════════════════════════════════════════════════════════════
# Precise Mode: fp64 eigvecs + max-col + Rayleigh (best accuracy)
# ═══════════════════════════════════════════════════════════════════════
class FLEighPrecise(nn.Module):
    """Precise mode: fp64 eigenvector reconstruction + Rayleigh refinement.

    Same pipeline as the fast mode, but the Horner eigenvector reconstruction
    stays in fp64 and the candidate vector is the largest-norm column
    (max-col) instead of the column sum — a little slower, more accurate.
    """

    def forward(self, A: Tensor) -> Tuple[Tensor, Tensor]:
        """Return (eigenvalues, eigenvectors) of a symmetric batch, ascending.

        Args:
            A: (B, n, n) batch of symmetric matrices.

        Returns:
            (B, n) sorted eigenvalues and (B, n, n) eigenvector columns.
        """
        B, n, _ = A.shape; device = A.device
        # Frobenius-norm scaling keeps the characteristic polynomial conditioned.
        scale = (torch.linalg.norm(A.reshape(B, -1), dim=-1) / math.sqrt(n)).clamp(min=1e-12)
        As = A / scale[:, None, None]
        c, Mstore = _fl_coefficients(As, B, n, device)
        roots = _laguerre_polish(c, As, B, n, device)
        # fp64 eigvecs: broadcast Horner + max-col
        lam = roots
        R = Mstore[1].unsqueeze(1).expand(-1, n, -1, -1).clone()
        for k in range(2, n + 1):
            R = R * lam[:, :, None, None] + Mstore[k].unsqueeze(1)
        # Pick, per eigenvalue, the column with the largest norm as candidate.
        cnorms = R.norm(dim=-2)
        best = cnorms.argmax(dim=-1)
        idx = best.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, n, 1)
        vec = R.gather(-1, idx).squeeze(-1)
        vec = vec / (vec.norm(dim=-1, keepdim=True) + 1e-30)
        V = vec.float().transpose(-2, -1)
        # NS: two coupled Newton-Schulz steps toward V (V^T V)^{-1/2}.
        eye_f = torch.eye(n, device=device, dtype=torch.float32).unsqueeze(0).expand(B, -1, -1)
        Y = torch.bmm(V.mT, V); X = eye_f.clone()
        for _ in range(2):
            T = 3.0 * eye_f - Y; X = 0.5 * torch.bmm(X, T); Y = 0.5 * torch.bmm(T, Y)
        V = torch.bmm(V, X)
        # Rayleigh quotient against the ORIGINAL A undoes the scaling exactly.
        AV = torch.bmm(A, V); evals = (V * AV).sum(dim=-2)
        se, perm = evals.sort(dim=-1)
        return se, V.gather(-1, perm.unsqueeze(-2).expand_as(V))
# ═══════════════════════════════════════════════════════════════════════
# Benchmark
# ═══════════════════════════════════════════════════════════════════════
def sync():
    """Block the host until all queued CUDA work has finished."""
    torch.cuda.synchronize()
def gpu_t(fn, w=20, r=200):
    """Average GPU wall time of `fn`: `w` warmup calls, then `r` timed calls.

    Synchronizes the device before starting and before stopping the clock,
    so queued-but-unfinished kernels are fully accounted for.

    Returns:
        Mean seconds per call over the `r` timed repetitions.
    """
    for _ in range(w):
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(r):
        fn()
    sync()
    elapsed = time.perf_counter() - start
    return elapsed / r
def fmt(s):
    """Format a duration `s` (seconds) with an appropriate unit.

    Returns microseconds below 1 ms, milliseconds below 1 s, else seconds.
    Fix: the original emitted a mojibake unit suffix ("Β΅s"); restored to "µs".
    """
    if s < 1e-3: return f"{s*1e6:.1f}µs"
    if s < 1.0: return f"{s*1e3:.2f}ms"
    return f"{s:.3f}s"
def make(B, n, dev):
    """Draw a random batch of (B, n, n) symmetric matrices on device `dev`."""
    G = torch.randn(B, n, n, device=dev)
    # Symmetrize the Gaussian draw: (G + G^T) / 2.
    return (G + G.mT) / 2
def main():
    """Run the full FL-Eigh vs cuSOLVER benchmark suite (requires CUDA).

    Sections: compile both modes, primary timing at n=6/B=4096, batch
    scaling, size scaling, peak-memory comparison, and throughput at scale.
    Prints a formatted report to stdout; exits with status 1 without a GPU.
    """
    if not torch.cuda.is_available(): sys.exit(1)
    dev = torch.device('cuda')
    p = torch.cuda.get_device_properties(0)
    print("=" * 72)
    print(" FL Eigh β€” Performance Benchmark")
    print("=" * 72)
    print(f" {p.name}")
    print(f" PyTorch {torch.__version__}")
    print(f" VRAM: {p.total_memory / 1024**3:.1f} GB")
    N = 6; B = 4096  # primary benchmark configuration
    A = make(B, N, dev)
    fast = FLEighFast()
    precise = FLEighPrecise()
    # ── Compile all three ──
    print("\n Compiling Fast (fullgraph=True)...", end=" ", flush=True)
    c_fast = torch.compile(fast, fullgraph=True)
    # Warmup calls trigger compilation before any timing happens.
    for _ in range(3): c_fast(A); sync()
    print("done.")
    print(" Compiling Precise (fullgraph=True)...", end=" ", flush=True)
    c_precise = torch.compile(precise, fullgraph=True)
    for _ in range(3): c_precise(A); sync()
    print("done.")
    # ── Primary config ──
    print(f"\n" + "=" * 72)
    print(f" PRIMARY: n={N} B={B}")
    print("=" * 72)
    t_ref = gpu_t(lambda: torch.linalg.eigh(A))
    t_fast = gpu_t(lambda: c_fast(A))
    t_prec = gpu_t(lambda: c_precise(A))
    t_fast_e = gpu_t(lambda: fast(A))    # eager (uncompiled) baselines
    t_prec_e = gpu_t(lambda: precise(A))
    print(f"\n {'Implementation':<28} {'Eager':>10} {'Compiled':>10} {'vs cuSOLVER':>12}")
    print(f" {'─'*28} {'─'*10} {'─'*10} {'─'*12}")
    print(f" {'cuSOLVER':<28} {'β€”':>10} {fmt(t_ref):>10} {'1.00Γ—':>12}")
    print(f" {'FL Fast (fp32 eigvec)':<28} {fmt(t_fast_e):>10} {fmt(t_fast):>10} {t_ref/t_fast:>11.2f}Γ—")
    print(f" {'FL Precise (fp64+Rayleigh)':<28} {fmt(t_prec_e):>10} {fmt(t_prec):>10} {t_ref/t_prec:>11.2f}Γ—")
    # ── Batch scaling ──
    print(f"\n" + "=" * 72)
    print(f" BATCH SCALING (n={N}, compiled, dynamic=True)")
    print("=" * 72)
    # dynamic=True: one compiled graph reused across batch sizes.
    cd_fast = torch.compile(FLEighFast(), fullgraph=True, dynamic=True)
    cd_prec = torch.compile(FLEighPrecise(), fullgraph=True, dynamic=True)
    # Warmup dynamic compilation
    cd_fast(A); cd_prec(A); sync()
    print(f"\n {'B':>6} {'cuSOLVER':>10} {'FL Fast':>10} {'F/cuS':>7} {'FL Prec':>10} {'P/cuS':>7}")
    print(f" {'─'*6} {'─'*10} {'─'*10} {'─'*7} {'─'*10} {'─'*7}")
    for Bx in [256, 512, 1024, 2048, 4096, 8192, 16384]:
        try:
            Ax = make(Bx, N, dev)
            # Warm each size
            cd_fast(Ax); cd_prec(Ax); sync()
            tr = gpu_t(lambda: torch.linalg.eigh(Ax), 10, 100)
            tf = gpu_t(lambda: cd_fast(Ax), 10, 100)
            tp = gpu_t(lambda: cd_prec(Ax), 10, 100)
            print(f" {Bx:>6} {fmt(tr):>10} {fmt(tf):>10} {tr/tf:>6.2f}Γ— {fmt(tp):>10} {tr/tp:>6.2f}Γ—")
            del Ax
        except RuntimeError:
            # Most likely CUDA OOM at large batch sizes.
            print(f" {Bx:>6} OOM"); torch.cuda.empty_cache()
    # ── Size scaling ──
    print(f"\n" + "=" * 72)
    print(f" SIZE SCALING (B={B}, compiled per size)")
    print("=" * 72)
    print(f"\n {'n':>3} {'cuSOLVER':>10} {'FL Fast':>10} {'F/cuS':>7} {'FL Prec':>10} {'P/cuS':>7}")
    print(f" {'─'*3} {'─'*10} {'─'*10} {'─'*7} {'─'*10} {'─'*7}")
    for nx in [3, 4, 5, 6, 8, 10, 12, 16]:
        try:
            Ax = make(B, nx, dev)
            # Fresh compile per size: n is baked into the unrolled loops.
            sf = torch.compile(FLEighFast(), fullgraph=True)
            sp = torch.compile(FLEighPrecise(), fullgraph=True)
            for _ in range(3): sf(Ax); sp(Ax); sync()
            tr = gpu_t(lambda: torch.linalg.eigh(Ax), 10, 100)
            tf = gpu_t(lambda: sf(Ax), 10, 100)
            tp = gpu_t(lambda: sp(Ax), 10, 100)
            print(f" {nx:>3} {fmt(tr):>10} {fmt(tf):>10} {tr/tf:>6.2f}Γ— {fmt(tp):>10} {tr/tp:>6.2f}Γ—")
            del Ax, sf, sp
        except Exception as e:
            print(f" {nx:>3} ERR: {str(e)[:40]}")
            torch.cuda.empty_cache()
    # ── Memory ──
    print(f"\n" + "=" * 72)
    print(" MEMORY (n=6 B=4096)")
    print("=" * 72)
    for label, fn in [("cuSOLVER", lambda: torch.linalg.eigh(A)),
                      ("FL Fast", lambda: fast(A)),
                      ("FL Precise", lambda: precise(A))]:
        torch.cuda.empty_cache(); gc.collect()
        torch.cuda.reset_peak_memory_stats()
        base = torch.cuda.memory_allocated()
        fn(); sync()
        # Peak allocation attributable to this call, in MiB.
        delta = (torch.cuda.max_memory_allocated() - base) / 1024**2
        print(f" {label:<16} {delta:.1f} MB")
    # ── Throughput at scale ──
    print(f"\n" + "=" * 72)
    print(" THROUGHPUT AT SCALE (matrices/second)")
    print("=" * 72)
    for nx, Bx in [(5, 8192), (6, 8192), (6, 16384), (8, 4096)]:
        try:
            Ax = make(Bx, nx, dev)
            sf = torch.compile(FLEighFast(), fullgraph=True)
            for _ in range(3): sf(Ax); sync()
            tr = gpu_t(lambda: torch.linalg.eigh(Ax), 10, 100)
            tf = gpu_t(lambda: sf(Ax), 10, 100)
            thr_r = Bx / tr; thr_f = Bx / tf  # matrices per second
            print(f" n={nx} B={Bx:>5}: cuSOLVER {thr_r/1e6:.2f}M/s FL {thr_f/1e6:.2f}M/s ({tf/tr:.2f}Γ— time)")
            del Ax, sf
        except Exception as e:
            print(f" n={nx} B={Bx:>5}: {str(e)[:40]}")
            torch.cuda.empty_cache()
    print(f"\n" + "=" * 72)
    print(f" Fast compiled (n=6 B=4096): {fmt(t_fast)} ({t_ref/t_fast:.2f}Γ— vs cuSOLVER)")
    print(f" Precise compiled: {fmt(t_prec)} ({t_ref/t_prec:.2f}Γ— vs cuSOLVER)")
    print(f" cuSOLVER: {fmt(t_ref)}")
    print("=" * 72)
# Entry point: run the benchmark suite when executed as a script.
if __name__ == '__main__':
    main()