""" fl_eigh_cuda.py — CUDA FL Hybrid Eigh via CuPy RawKernel. Compiles at runtime using NVRTC (part of CUDA toolkit, already installed). No ninja, no C++ compiler, no build system. Just pip install cupy-cuda12x. PyTorch <-> CuPy via DLPack (zero-copy). Usage: from fl_eigh_cuda import fl_eigh_cuda evals, evecs = fl_eigh_cuda(A) # A is [B, 6, 6] PyTorch CUDA tensor """ import math, time, gc, sys import torch from torch import Tensor from typing import Tuple torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.set_float32_matmul_precision('highest') _KERNEL_SRC = r""" extern "C" __global__ void fl_eigh_kernel( const float* __restrict__ A_in, float* __restrict__ evals_out, float* __restrict__ evecs_out, int B ) { int tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid >= B) return; const int NN = 6; const int N2 = 36; // Load A and pre-scale double a[36]; double frob_sq = 0.0; for (int i = 0; i < N2; i++) { a[i] = (double)A_in[tid * N2 + i]; frob_sq += a[i] * a[i]; } double scale = sqrt(frob_sq / 6.0); if (scale < 1e-12) scale = 1e-12; double inv_s = 1.0 / scale; for (int i = 0; i < N2; i++) a[i] *= inv_s; // Phase 1: FL coefficients (fp64) double c[7]; for (int i = 0; i < 7; i++) c[i] = 0.0; c[6] = 1.0; double m[36]; for (int i = 0; i < N2; i++) m[i] = 0.0; for (int k = 1; k <= NN; k++) { double mn[36]; for (int i = 0; i < NN; i++) { for (int j = 0; j < NN; j++) { double acc = 0.0; for (int l = 0; l < NN; l++) acc += a[i*NN+l] * m[l*NN+j]; if (i == j) acc += c[NN-k+1]; mn[i*NN+j] = acc; } } double tr = 0.0; for (int i = 0; i < NN; i++) for (int l = 0; l < NN; l++) tr += a[i*NN+l] * mn[l*NN+i]; c[NN-k] = -tr / (double)k; for (int i = 0; i < N2; i++) m[i] = mn[i]; } // Phase 2: Laguerre + deflation + polish (fp64) double diag[6]; for (int i = 0; i < NN; i++) diag[i] = a[i*NN+i]; for (int pass = 0; pass < NN-1; pass++) for (int j = 0; j < NN-1; j++) if (diag[j] > diag[j+1]) { double tmp = diag[j]; diag[j] = diag[j+1]; diag[j+1] = tmp; } for (int i = 0; i < NN; i++) diag[i] += -1e-4 + 2e-4 * (double)i / 5.0; double cl[7]; for (int i = 0; i < 7; i++) cl[i] = c[i]; double roots[6]; for (int ri = 0; ri < NN; ri++) { int deg = NN - ri; double z = diag[ri]; for (int lag = 0; lag < 5; lag++) { double pv = cl[deg], dp = 0.0, d2 = 0.0; for (int j = deg - 1; j >= 0; j--) { d2 = d2 * z + dp; dp = dp * z + pv; pv = pv * z + cl[j]; } if (fabs(pv) > 1e-30) { double G = dp / pv; double H = G * G - 2.0 * d2 / pv; double disc = ((double)(deg-1)) * ((double)deg * H - G * G); if (disc < 0.0) disc = 0.0; double sq = sqrt(disc); double gp = G + sq, gm = G - sq; double den = (fabs(gp) >= fabs(gm)) ? gp : gm; if (fabs(den) > 1e-20) z -= (double)deg / den; } } roots[ri] = z; if (deg > 1) { double b = cl[deg]; for (int j = deg - 1; j > 0; j--) { double bn = cl[j] + z * b; cl[j] = b; b = bn; } cl[0] = b; } } // Newton polish for (int pol = 0; pol < 3; pol++) for (int ri = 0; ri < NN; ri++) { double pv = c[NN], dp = 0.0; for (int j = NN - 1; j >= 0; j--) { dp = dp * roots[ri] + pv; pv = pv * roots[ri] + c[j]; } if (fabs(dp) > 1e-30) roots[ri] -= pv / dp; } // Phase 3: Eigenvectors via interleaved FL+Horner (fp64) float evecs[36]; for (int ei = 0; ei < NN; ei++) { double lam = roots[ei]; double m_loc[36], r_loc[36]; for (int i = 0; i < N2; i++) m_loc[i] = 0.0; for (int k = 1; k <= NN; k++) { double mn_loc[36]; for (int i = 0; i < NN; i++) for (int j = 0; j < NN; j++) { double acc = 0.0; for (int l = 0; l < NN; l++) acc += a[i*NN+l] * m_loc[l*NN+j]; if (i == j) acc += c[NN-k+1]; mn_loc[i*NN+j] = acc; } if (k == 1) for (int i = 0; i < N2; i++) r_loc[i] = mn_loc[i]; else for (int i = 0; i < N2; i++) r_loc[i] = r_loc[i] * lam + mn_loc[i]; for (int i = 0; i < N2; i++) m_loc[i] = mn_loc[i]; } int best_j = 0; double best_norm = -1.0; for (int j = 0; j < NN; j++) { double col_sq = 0.0; for (int i = 0; i < NN; i++) col_sq += r_loc[i*NN+j] * r_loc[i*NN+j]; if (col_sq > best_norm) { best_norm = col_sq; best_j = j; } } double vnorm = 0.0; double vec[6]; for (int i = 0; i < NN; i++) { vec[i] = r_loc[i*NN + best_j]; vnorm += vec[i] * vec[i]; } vnorm = sqrt(vnorm) + 1e-30; for (int i = 0; i < NN; i++) evecs[i*NN + ei] = (float)(vec[i] / vnorm); } // Phase 4: Newton-Schulz (fp32, 2 iters) for (int ns = 0; ns < 2; ns++) { float y[36], t_m[36], vn[36]; for (int i = 0; i < NN; i++) for (int j = 0; j < NN; j++) { float acc = 0.0f; for (int l = 0; l < NN; l++) acc += evecs[l*NN+i] * evecs[l*NN+j]; y[i*NN+j] = acc; } for (int i = 0; i < NN; i++) for (int j = 0; j < NN; j++) t_m[i*NN+j] = ((i==j) ? 3.0f : 0.0f) - y[i*NN+j]; for (int i = 0; i < NN; i++) for (int j = 0; j < NN; j++) { float acc = 0.0f; for (int l = 0; l < NN; l++) acc += evecs[i*NN+l] * t_m[l*NN+j]; vn[i*NN+j] = 0.5f * acc; } for (int i = 0; i < N2; i++) evecs[i] = vn[i]; } // Phase 5: Rayleigh quotient (fp32) float af[36]; for (int i = 0; i < N2; i++) af[i] = (float)a[i]; float evals_local[6]; for (int ei = 0; ei < NN; ei++) { float lam_f = 0.0f; for (int l = 0; l < NN; l++) { float av = 0.0f; for (int mm = 0; mm < NN; mm++) av += af[l*NN+mm] * evecs[mm*NN+ei]; lam_f += evecs[l*NN+ei] * av; } evals_local[ei] = lam_f * (float)scale; } // Sort ascending + permute int perm[6]; for (int i = 0; i < NN; i++) perm[i] = i; for (int pass = 0; pass < NN-1; pass++) for (int j = 0; j < NN-1; j++) if (evals_local[j] > evals_local[j+1]) { float tmp = evals_local[j]; evals_local[j] = evals_local[j+1]; evals_local[j+1] = tmp; int ptmp = perm[j]; perm[j] = perm[j+1]; perm[j+1] = ptmp; } for (int i = 0; i < NN; i++) evals_out[tid * NN + i] = evals_local[i]; for (int j_out = 0; j_out < NN; j_out++) { int j_src = perm[j_out]; for (int i = 0; i < NN; i++) evecs_out[tid * N2 + i*NN + j_out] = evecs[i*NN + j_src]; } } """ # ═══════════════════════════════════════════════════════════════════════ # CuPy compilation + PyTorch wrapper # ═══════════════════════════════════════════════════════════════════════ _kernel = None def _get_kernel(): global _kernel if _kernel is not None: return _kernel import cupy print(" Compiling via NVRTC...", end=" ", flush=True) _kernel = cupy.RawKernel(_KERNEL_SRC, 'fl_eigh_kernel') # Force compilation now (not on first launch) _kernel.compile() print("done.") return _kernel def fl_eigh_cuda(A: Tensor) -> Tuple[Tensor, Tensor]: """CUDA FL Hybrid Eigendecomposition for [B, 6, 6] symmetric matrices. Uses CuPy RawKernel (NVRTC). Zero-copy PyTorch interop via data_ptr. """ assert A.is_cuda and A.shape[-2:] == (6, 6), f"Need CUDA [B,6,6], got {A.shape}" B = A.shape[0] kernel = _get_kernel() A_contig = A.contiguous().float() evals = torch.empty(B, 6, device=A.device, dtype=torch.float32) evecs = torch.empty(B, 6, 6, device=A.device, dtype=torch.float32) import cupy # Raw pointers — zero copy, no DLPack needed a_ptr = cupy.cuda.MemoryPointer( cupy.cuda.UnownedMemory(A_contig.data_ptr(), A_contig.nelement() * 4, None), 0) ev_ptr = cupy.cuda.MemoryPointer( cupy.cuda.UnownedMemory(evals.data_ptr(), evals.nelement() * 4, None), 0) vc_ptr = cupy.cuda.MemoryPointer( cupy.cuda.UnownedMemory(evecs.data_ptr(), evecs.nelement() * 4, None), 0) threads = 128 blocks = (B + threads - 1) // threads # Launch on PyTorch's current CUDA stream stream = cupy.cuda.ExternalStream(torch.cuda.current_stream().cuda_stream) with stream: kernel((blocks,), (threads,), (a_ptr, ev_ptr, vc_ptr, B)) return evals, evecs # ═══════════════════════════════════════════════════════════════════════ # Math purity test # ═══════════════════════════════════════════════════════════════════════ def math_test(A, vals, vecs): B,n,_=A.shape; dev=A.device Ad=A.double(); vd=vals.double(); Vd=vecs.double() AV=torch.bmm(Ad,Vd); VL=Vd*vd.unsqueeze(-2) An=Ad.reshape(B,-1).norm(dim=-1,keepdim=True).clamp(min=1e-30) res=(AV-VL).norm(dim=-2)/An VtV=torch.bmm(Vd.mT,Vd); I=torch.eye(n,device=dev,dtype=torch.float64).unsqueeze(0) orth=(VtV-I).reshape(B,-1).norm(dim=-1) recon=torch.bmm(Vd*vd.unsqueeze(-2),Vd.mT) recon_err=(Ad-recon).reshape(B,-1).norm(dim=-1)/An.squeeze(-1) tr_err=(Ad.diagonal(dim1=-2,dim2=-1).sum(-1)-vd.sum(-1)).abs() det_A=torch.linalg.det(Ad); det_err=(det_A-vd.prod(-1)).abs()/det_A.abs().clamp(min=1e-30) return dict(res_max=res.max().item(), res_mean=res.mean().item(), orth_max=orth.max().item(), orth_mean=orth.mean().item(), recon_max=recon_err.max().item(), recon_mean=recon_err.mean().item(), tr_max=tr_err.max().item(), det_max=det_err.max().item()) # ═══════════════════════════════════════════════════════════════════════ # Benchmark # ═══════════════════════════════════════════════════════════════════════ def sync(): torch.cuda.synchronize() def gt(fn,w=20,r=200): for _ in range(w): fn() sync(); t=time.perf_counter() for _ in range(r): fn() sync(); return (time.perf_counter()-t)/r def fmt(s): if s<1e-3: return f"{s*1e6:.1f}us" if s<1: return f"{s*1e3:.2f}ms" return f"{s:.3f}s" def main(): if not torch.cuda.is_available(): sys.exit(1) dev=torch.device('cuda') p=torch.cuda.get_device_properties(0) print("="*72) print(" FL Eigh CUDA Kernel (CuPy/NVRTC)") print("="*72) print(f" {p.name}") print(f" PyTorch {torch.__version__}") N=6; B=4096 A=(lambda R:(R+R.mT)/2)(torch.randn(B,N,N,device=dev)) rv,rV=torch.linalg.eigh(A) _get_kernel() # Accuracy print(f"\n ACCURACY (n={N} B={B})") cv,cV=fl_eigh_cuda(A) ve=(cv-rv).abs().max().item() dots=torch.bmm(rV.double().mT,cV.double()).abs().max(dim=-1).values.min().item() print(f" CUDA FL: val={ve:.1e} align={dots:.6f}") # Math purity mc=math_test(A,rv,rV); mf=math_test(A,cv,cV) wins=0 print(f"\n MATH PURITY: CUDA FL vs cuSOLVER") print(f" {'Property':<28} {'cuSOLVER':>10} {'CUDA FL':>10} {'Win':>6}") for key in ['res_max','res_mean','orth_max','orth_mean','recon_max','recon_mean','tr_max','det_max']: vc=mc[key]; vf=mf[key]; w='FL' if vf10.1e} {vf:>10.1e} {w:>6}") print(f"\n CUDA FL wins {wins}/8") # Throughput print(f"\n THROUGHPUT (n={N} B={B})") tr=gt(lambda:torch.linalg.eigh(A)) tc=gt(lambda:fl_eigh_cuda(A)) print(f" cuSOLVER: {fmt(tr)}") print(f" CUDA FL: {fmt(tc)} ({tr/tc:.2f}x)") # Batch scaling print(f"\n BATCH SCALING (n={N})") print(f" {'B':>6} {'cuSOLVER':>10} {'CUDA FL':>10} {'ratio':>7}") for Bx in [256,512,1024,2048,4096,8192,16384,32768]: try: Ax=(lambda R:(R+R.mT)/2)(torch.randn(Bx,N,N,device=dev)) t1=gt(lambda:torch.linalg.eigh(Ax),10,100) t2=gt(lambda:fl_eigh_cuda(Ax),10,100) print(f" {Bx:>6} {fmt(t1):>10} {fmt(t2):>10} {t1/t2:>6.2f}x") del Ax except RuntimeError: print(f" {Bx:>6} OOM"); torch.cuda.empty_cache() # Memory print(f"\n MEMORY (n={N} B={B})") for lbl,fn in [("cuSOLVER",lambda:torch.linalg.eigh(A)),("CUDA FL",lambda:fl_eigh_cuda(A))]: torch.cuda.empty_cache(); gc.collect(); torch.cuda.reset_peak_memory_stats() base=torch.cuda.memory_allocated(); fn(); sync() print(f" {lbl:<12} {(torch.cuda.max_memory_allocated()-base)/1024**2:.1f}MB") print("="*72) if __name__=='__main__': main()