| """ |
fl_eigh_cuda.py — CUDA FL Hybrid Eigh via CuPy RawKernel.
| |
| Compiles at runtime using NVRTC (part of CUDA toolkit, already installed). |
| No ninja, no C++ compiler, no build system. Just pip install cupy-cuda12x. |
| |
| PyTorch <-> CuPy via DLPack (zero-copy). |
| |
| Usage: |
| from fl_eigh_cuda import fl_eigh_cuda |
| evals, evecs = fl_eigh_cuda(A) # A is [B, 6, 6] PyTorch CUDA tensor |
| """ |
|
|
| import math, time, gc, sys |
| import torch |
| from torch import Tensor |
| from typing import Tuple |
|
|
# Force full-precision float32 matmuls globally. The accuracy comparison
# against cuSOLVER below relies on torch.bmm / torch.linalg results, which
# would be silently degraded if TF32 were allowed on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
torch.set_float32_matmul_precision('highest')
|
|
|
|
# CUDA C source for the batched 6x6 symmetric eigensolver, compiled at
# runtime by NVRTC through cupy.RawKernel (see _get_kernel / fl_eigh_cuda).
#
# One thread processes one matrix (threads with tid >= B return early).
# I/O is float32; the internal pipeline, as labeled in the kernel body, is:
#   - pre-scale A by a Frobenius-norm-derived factor (undone at Phase 5),
#   - Phase 1: characteristic-polynomial coefficients c[0..6] via the
#     Faddeev-LeVerrier recurrence in fp64,
#   - Phase 2: roots by Laguerre iteration with deflation, then Newton polish
#     against the undeflated polynomial, in fp64,
#   - Phase 3: eigenvectors from the largest-norm column of the Horner-
#     accumulated resolvent-style matrix R(lambda), in fp64,
#   - Phase 4: two Newton-Schulz iterations to re-orthonormalize, fp32,
#   - Phase 5: Rayleigh-quotient eigenvalues (rescaled by `scale`), fp32,
#   - final ascending sort of eigenvalues with matching column permutation.
#
# NOTE: this is a runtime string constant — any edit changes what NVRTC
# compiles, so the kernel text below is left untouched.
_KERNEL_SRC = r"""
extern "C" __global__
void fl_eigh_kernel(
    const float* __restrict__ A_in,
    float* __restrict__ evals_out,
    float* __restrict__ evecs_out,
    int B
) {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= B) return;

    const int NN = 6;
    const int N2 = 36;

    // Load A and pre-scale
    double a[36];
    double frob_sq = 0.0;
    for (int i = 0; i < N2; i++) {
        a[i] = (double)A_in[tid * N2 + i];
        frob_sq += a[i] * a[i];
    }
    double scale = sqrt(frob_sq / 6.0);
    if (scale < 1e-12) scale = 1e-12;
    double inv_s = 1.0 / scale;
    for (int i = 0; i < N2; i++) a[i] *= inv_s;

    // Phase 1: FL coefficients (fp64)
    double c[7];
    for (int i = 0; i < 7; i++) c[i] = 0.0;
    c[6] = 1.0;

    double m[36];
    for (int i = 0; i < N2; i++) m[i] = 0.0;

    for (int k = 1; k <= NN; k++) {
        double mn[36];
        for (int i = 0; i < NN; i++) {
            for (int j = 0; j < NN; j++) {
                double acc = 0.0;
                for (int l = 0; l < NN; l++)
                    acc += a[i*NN+l] * m[l*NN+j];
                if (i == j) acc += c[NN-k+1];
                mn[i*NN+j] = acc;
            }
        }
        double tr = 0.0;
        for (int i = 0; i < NN; i++)
            for (int l = 0; l < NN; l++)
                tr += a[i*NN+l] * mn[l*NN+i];
        c[NN-k] = -tr / (double)k;
        for (int i = 0; i < N2; i++) m[i] = mn[i];
    }

    // Phase 2: Laguerre + deflation + polish (fp64)
    double diag[6];
    for (int i = 0; i < NN; i++) diag[i] = a[i*NN+i];
    for (int pass = 0; pass < NN-1; pass++)
        for (int j = 0; j < NN-1; j++)
            if (diag[j] > diag[j+1]) {
                double tmp = diag[j]; diag[j] = diag[j+1]; diag[j+1] = tmp;
            }
    for (int i = 0; i < NN; i++)
        diag[i] += -1e-4 + 2e-4 * (double)i / 5.0;

    double cl[7];
    for (int i = 0; i < 7; i++) cl[i] = c[i];

    double roots[6];

    for (int ri = 0; ri < NN; ri++) {
        int deg = NN - ri;
        double z = diag[ri];
        for (int lag = 0; lag < 5; lag++) {
            double pv = cl[deg], dp = 0.0, d2 = 0.0;
            for (int j = deg - 1; j >= 0; j--) {
                d2 = d2 * z + dp;
                dp = dp * z + pv;
                pv = pv * z + cl[j];
            }
            if (fabs(pv) > 1e-30) {
                double G = dp / pv;
                double H = G * G - 2.0 * d2 / pv;
                double disc = ((double)(deg-1)) * ((double)deg * H - G * G);
                if (disc < 0.0) disc = 0.0;
                double sq = sqrt(disc);
                double gp = G + sq, gm = G - sq;
                double den = (fabs(gp) >= fabs(gm)) ? gp : gm;
                if (fabs(den) > 1e-20)
                    z -= (double)deg / den;
            }
        }
        roots[ri] = z;
        if (deg > 1) {
            double b = cl[deg];
            for (int j = deg - 1; j > 0; j--) {
                double bn = cl[j] + z * b;
                cl[j] = b;
                b = bn;
            }
            cl[0] = b;
        }
    }

    // Newton polish
    for (int pol = 0; pol < 3; pol++)
        for (int ri = 0; ri < NN; ri++) {
            double pv = c[NN], dp = 0.0;
            for (int j = NN - 1; j >= 0; j--) {
                dp = dp * roots[ri] + pv;
                pv = pv * roots[ri] + c[j];
            }
            if (fabs(dp) > 1e-30)
                roots[ri] -= pv / dp;
        }

    // Phase 3: Eigenvectors via interleaved FL+Horner (fp64)
    float evecs[36];

    for (int ei = 0; ei < NN; ei++) {
        double lam = roots[ei];
        double m_loc[36], r_loc[36];
        for (int i = 0; i < N2; i++) m_loc[i] = 0.0;

        for (int k = 1; k <= NN; k++) {
            double mn_loc[36];
            for (int i = 0; i < NN; i++)
                for (int j = 0; j < NN; j++) {
                    double acc = 0.0;
                    for (int l = 0; l < NN; l++)
                        acc += a[i*NN+l] * m_loc[l*NN+j];
                    if (i == j) acc += c[NN-k+1];
                    mn_loc[i*NN+j] = acc;
                }
            if (k == 1)
                for (int i = 0; i < N2; i++) r_loc[i] = mn_loc[i];
            else
                for (int i = 0; i < N2; i++) r_loc[i] = r_loc[i] * lam + mn_loc[i];
            for (int i = 0; i < N2; i++) m_loc[i] = mn_loc[i];
        }

        int best_j = 0;
        double best_norm = -1.0;
        for (int j = 0; j < NN; j++) {
            double col_sq = 0.0;
            for (int i = 0; i < NN; i++)
                col_sq += r_loc[i*NN+j] * r_loc[i*NN+j];
            if (col_sq > best_norm) { best_norm = col_sq; best_j = j; }
        }
        double vnorm = 0.0;
        double vec[6];
        for (int i = 0; i < NN; i++) {
            vec[i] = r_loc[i*NN + best_j];
            vnorm += vec[i] * vec[i];
        }
        vnorm = sqrt(vnorm) + 1e-30;
        for (int i = 0; i < NN; i++)
            evecs[i*NN + ei] = (float)(vec[i] / vnorm);
    }

    // Phase 4: Newton-Schulz (fp32, 2 iters)
    for (int ns = 0; ns < 2; ns++) {
        float y[36], t_m[36], vn[36];
        for (int i = 0; i < NN; i++)
            for (int j = 0; j < NN; j++) {
                float acc = 0.0f;
                for (int l = 0; l < NN; l++)
                    acc += evecs[l*NN+i] * evecs[l*NN+j];
                y[i*NN+j] = acc;
            }
        for (int i = 0; i < NN; i++)
            for (int j = 0; j < NN; j++)
                t_m[i*NN+j] = ((i==j) ? 3.0f : 0.0f) - y[i*NN+j];
        for (int i = 0; i < NN; i++)
            for (int j = 0; j < NN; j++) {
                float acc = 0.0f;
                for (int l = 0; l < NN; l++)
                    acc += evecs[i*NN+l] * t_m[l*NN+j];
                vn[i*NN+j] = 0.5f * acc;
            }
        for (int i = 0; i < N2; i++) evecs[i] = vn[i];
    }

    // Phase 5: Rayleigh quotient (fp32)
    float af[36];
    for (int i = 0; i < N2; i++) af[i] = (float)a[i];

    float evals_local[6];
    for (int ei = 0; ei < NN; ei++) {
        float lam_f = 0.0f;
        for (int l = 0; l < NN; l++) {
            float av = 0.0f;
            for (int mm = 0; mm < NN; mm++)
                av += af[l*NN+mm] * evecs[mm*NN+ei];
            lam_f += evecs[l*NN+ei] * av;
        }
        evals_local[ei] = lam_f * (float)scale;
    }

    // Sort ascending + permute
    int perm[6];
    for (int i = 0; i < NN; i++) perm[i] = i;
    for (int pass = 0; pass < NN-1; pass++)
        for (int j = 0; j < NN-1; j++)
            if (evals_local[j] > evals_local[j+1]) {
                float tmp = evals_local[j]; evals_local[j] = evals_local[j+1]; evals_local[j+1] = tmp;
                int ptmp = perm[j]; perm[j] = perm[j+1]; perm[j+1] = ptmp;
            }

    for (int i = 0; i < NN; i++)
        evals_out[tid * NN + i] = evals_local[i];
    for (int j_out = 0; j_out < NN; j_out++) {
        int j_src = perm[j_out];
        for (int i = 0; i < NN; i++)
            evecs_out[tid * N2 + i*NN + j_out] = evecs[i*NN + j_src];
    }
}
"""
|
|
| |
| |
| |
|
|
# Module-level cache so NVRTC compilation happens at most once per process.
_kernel = None


def _get_kernel():
    """Return the compiled ``fl_eigh_kernel`` RawKernel, compiling lazily.

    cupy is imported here (not at module top) so the module can be imported
    on machines without CuPy installed.
    """
    global _kernel
    if _kernel is None:
        import cupy
        print(" Compiling via NVRTC...", end=" ", flush=True)
        kern = cupy.RawKernel(_KERNEL_SRC, 'fl_eigh_kernel')
        # Compile eagerly so the NVRTC cost is paid here, not on first launch.
        kern.compile()
        print("done.")
        _kernel = kern
    return _kernel
|
|
|
|
def fl_eigh_cuda(A: Tensor) -> Tuple[Tensor, Tensor]:
    """CUDA FL Hybrid eigendecomposition for [B, 6, 6] symmetric matrices.

    Launches the NVRTC-compiled RawKernel on PyTorch's current CUDA stream.
    Interop is zero-copy: the kernel reads/writes the tensors' device memory
    directly through CuPy ``UnownedMemory`` wrappers around ``data_ptr()``.

    Args:
        A: CUDA tensor of shape [B, 6, 6]. Assumed symmetric per batch
           element (not checked). Converted to contiguous float32.

    Returns:
        ``(evals, evecs)``: float32 tensors of shape [B, 6] and [B, 6, 6].

    Raises:
        ValueError: if ``A`` is not a 3-D CUDA tensor of shape [B, 6, 6].
    """
    # Explicit raise instead of `assert`: validation must survive `python -O`,
    # and the old shape assert let a 2-D [6, 6] input through, after which
    # B picked up the value 6 and the kernel read garbage.
    if not (A.is_cuda and A.dim() == 3 and A.shape[-2:] == (6, 6)):
        raise ValueError(f"Need CUDA [B,6,6], got {A.shape}")
    B = A.shape[0]

    A_contig = A.contiguous().float()
    evals = torch.empty(B, 6, device=A.device, dtype=torch.float32)
    evecs = torch.empty(B, 6, 6, device=A.device, dtype=torch.float32)
    # Empty batch: a zero-block grid is an invalid CUDA launch, so skip it.
    if B == 0:
        return evals, evecs

    kernel = _get_kernel()

    import cupy

    def _wrap(t: Tensor) -> "cupy.cuda.MemoryPointer":
        # Non-owning view of a PyTorch allocation (4 bytes per float32
        # element). NOTE(review): lifetime stays with PyTorch's caching
        # allocator — presumed safe because the kernel runs on the same
        # torch stream; confirm if launches ever move to a side stream.
        mem = cupy.cuda.UnownedMemory(t.data_ptr(), t.nelement() * 4, None)
        return cupy.cuda.MemoryPointer(mem, 0)

    a_ptr = _wrap(A_contig)
    ev_ptr = _wrap(evals)
    vc_ptr = _wrap(evecs)

    threads = 128
    blocks = (B + threads - 1) // threads  # one thread per matrix

    # Launch on PyTorch's current stream so ordering with surrounding torch
    # ops needs no explicit synchronization.
    stream = cupy.cuda.ExternalStream(torch.cuda.current_stream().cuda_stream)
    with stream:
        kernel((blocks,), (threads,),
               (a_ptr, ev_ptr, vc_ptr, B))

    return evals, evecs
|
|
|
|
| |
| |
| |
|
|
def math_test(A, vals, vecs):
    """Numerical quality metrics for a batched eigendecomposition.

    Everything is evaluated in float64. Residual and reconstruction errors
    are normalized by each matrix's Frobenius norm; determinant error is
    relative to |det(A)|.

    Args:
        A: [B, n, n] batch of matrices.
        vals: [B, n] eigenvalues.
        vecs: [B, n, n] eigenvectors (columns).

    Returns:
        dict with max/mean residual, max/mean orthogonality error, max/mean
        reconstruction error, max trace error, and max relative det error.
    """
    batch, n, _ = A.shape
    Ad = A.double()
    lam = vals.double()
    V = vecs.double()

    a_norm = Ad.reshape(batch, -1).norm(dim=-1, keepdim=True).clamp(min=1e-30)

    # ||A v - lambda v|| per eigenpair, relative to ||A||_F.
    residual = (torch.bmm(Ad, V) - V * lam.unsqueeze(-2)).norm(dim=-2) / a_norm

    # ||V^T V - I||_F: deviation from orthonormality.
    eye = torch.eye(n, device=A.device, dtype=torch.float64).unsqueeze(0)
    orth = (torch.bmm(V.mT, V) - eye).reshape(batch, -1).norm(dim=-1)

    # ||A - V diag(lambda) V^T||_F relative to ||A||_F.
    rebuilt = torch.bmm(V * lam.unsqueeze(-2), V.mT)
    recon_err = (Ad - rebuilt).reshape(batch, -1).norm(dim=-1) / a_norm.squeeze(-1)

    # Invariants: trace = sum of eigenvalues, det = product of eigenvalues.
    tr_err = (Ad.diagonal(dim1=-2, dim2=-1).sum(-1) - lam.sum(-1)).abs()
    det_A = torch.linalg.det(Ad)
    det_err = (det_A - lam.prod(-1)).abs() / det_A.abs().clamp(min=1e-30)

    return {
        'res_max': residual.max().item(),
        'res_mean': residual.mean().item(),
        'orth_max': orth.max().item(),
        'orth_mean': orth.mean().item(),
        'recon_max': recon_err.max().item(),
        'recon_mean': recon_err.mean().item(),
        'tr_max': tr_err.max().item(),
        'det_max': det_err.max().item(),
    }
|
|
|
|
| |
| |
| |
|
|
def sync():
    """Barrier: wait for all queued CUDA work to finish (for timing)."""
    torch.cuda.synchronize()
def gt(fn, w=20, r=200):
    """Mean wall-clock seconds per call of ``fn`` on the GPU.

    Runs ``w`` untimed warmup calls, synchronizes, then times ``r`` calls
    between two device synchronizations so async launches are included.
    """
    for _ in range(w):
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(r):
        fn()
    sync()
    return (time.perf_counter() - start) / r
def fmt(s):
    """Render a duration ``s`` (seconds) as ``…s``, ``…ms``, or ``…us``."""
    if s >= 1:
        return f"{s:.3f}s"
    if s >= 1e-3:
        return f"{s*1e3:.2f}ms"
    return f"{s*1e6:.1f}us"
|
|
|
|
def main():
    """Benchmark harness: CUDA FL kernel vs torch.linalg.eigh (cuSOLVER)."""
    # Requires a GPU — everything below launches CUDA work.
    if not torch.cuda.is_available(): sys.exit(1)
    dev=torch.device('cuda')
    p=torch.cuda.get_device_properties(0)
    print("="*72)
    print(" FL Eigh CUDA Kernel (CuPy/NVRTC)")
    print("="*72)
    print(f" {p.name}")
    print(f" PyTorch {torch.__version__}")

    # Random symmetric test batch: symmetrize R via (R + R^T)/2.
    N=6; B=4096
    A=(lambda R:(R+R.mT)/2)(torch.randn(B,N,N,device=dev))
    rv,rV=torch.linalg.eigh(A)

    # Trigger NVRTC compilation up front so it doesn't pollute timings.
    _get_kernel()

    # Accuracy: max abs eigenvalue error, and worst column alignment
    # |<reference evec, FL evec>| (1.0 = perfectly aligned up to sign).
    print(f"\n ACCURACY (n={N} B={B})")
    cv,cV=fl_eigh_cuda(A)
    ve=(cv-rv).abs().max().item()
    dots=torch.bmm(rV.double().mT,cV.double()).abs().max(dim=-1).values.min().item()
    print(f" CUDA FL: val={ve:.1e} align={dots:.6f}")

    # Head-to-head numerical quality on the metrics from math_test.
    mc=math_test(A,rv,rV); mf=math_test(A,cv,cV)
    wins=0
    print(f"\n MATH PURITY: CUDA FL vs cuSOLVER")
    print(f" {'Property':<28} {'cuSOLVER':>10} {'CUDA FL':>10} {'Win':>6}")
    for key in ['res_max','res_mean','orth_max','orth_mean','recon_max','recon_mean','tr_max','det_max']:
        vc=mc[key]; vf=mf[key]; w='FL' if vf<vc else 'cuS'
        if vf<vc: wins+=1
        print(f" {key:<28} {vc:>10.1e} {vf:>10.1e} {w:>6}")
    print(f"\n CUDA FL wins {wins}/8")

    # Throughput at the fixed batch size above.
    print(f"\n THROUGHPUT (n={N} B={B})")
    tr=gt(lambda:torch.linalg.eigh(A))
    tc=gt(lambda:fl_eigh_cuda(A))
    print(f" cuSOLVER: {fmt(tr)}")
    print(f" CUDA FL: {fmt(tc)} ({tr/tc:.2f}x)")

    # Batch-size sweep; OOM at large B is reported rather than fatal.
    print(f"\n BATCH SCALING (n={N})")
    print(f" {'B':>6} {'cuSOLVER':>10} {'CUDA FL':>10} {'ratio':>7}")
    for Bx in [256,512,1024,2048,4096,8192,16384,32768]:
        try:
            Ax=(lambda R:(R+R.mT)/2)(torch.randn(Bx,N,N,device=dev))
            t1=gt(lambda:torch.linalg.eigh(Ax),10,100)
            t2=gt(lambda:fl_eigh_cuda(Ax),10,100)
            print(f" {Bx:>6} {fmt(t1):>10} {fmt(t2):>10} {t1/t2:>6.2f}x")
            del Ax
        except RuntimeError:
            print(f" {Bx:>6} OOM"); torch.cuda.empty_cache()

    # Peak extra device memory per solver call (allocator delta in MB).
    print(f"\n MEMORY (n={N} B={B})")
    for lbl,fn in [("cuSOLVER",lambda:torch.linalg.eigh(A)),("CUDA FL",lambda:fl_eigh_cuda(A))]:
        torch.cuda.empty_cache(); gc.collect(); torch.cuda.reset_peak_memory_stats()
        base=torch.cuda.memory_allocated(); fn(); sync()
        print(f" {lbl:<12} {(torch.cuda.max_memory_allocated()-base)/1024**2:.1f}MB")

    print("="*72)




if __name__=='__main__': main()