| """ |
| triton_svd_general.py β Generalized batched thin SVD for (B, M, N) matrices. |
| |
| Three strategies, auto-dispatched by N: |
| N=2: Fused Triton kernel β closed-form 2Γ2 eigensolve in registers |
| N=3: Fused Triton kernel β cyclic Jacobi in registers (from session start) |
| Nβ₯4: Gram-Eigh hybrid β Triton G=A^T A + torch.linalg.eigh + Triton U recovery |
| |
| All methods exploit the thin-matrix shortcut: decompose via the NΓN Gram |
| matrix G=A^T A rather than working on the full MΓN matrix directly. |
| |
| Mathematical lineage: |
| Eckart-Young (1936): G = A^T A β eigenvalues of G = ΟΒ² of A |
| Jacobi (1846): Cyclic Givens rotations for symmetric eigendecomposition |
| Golub-Reinsch (1970): U = A V S^{-1} recovery |
| Batcher (1968): Sorting network for eigenvalue ordering |
| |
| Author: AbstractPhil / Claude Opus 4.6 |
| """ |
|
|
| import triton |
| import triton.language as tl |
| import torch |
| import torch.nn.functional as F |
| import math |
| import time |
| import json |
|
|
|
|
| |
| |
| |
|
|
@triton.jit
def _svd2_kernel(
    A_ptr, U_ptr, S_ptr, Vh_ptr,
    M: tl.constexpr, BLOCK_M: tl.constexpr, EPS: tl.constexpr,
):
    """Fused SVD for (M, 2) matrices. One program per batch element.

    Pointer layout (contiguous, row-major; batched_svd2 passes such buffers):
        A_ptr, U_ptr : (B, M, 2)  element (b, r, c) at b*M*2 + r*2 + c
        S_ptr        : (B, 2)     singular values, written in descending order
        Vh_ptr       : (B, 2, 2)  Vh = V^T, row-major

    The 2×2 symmetric Gram matrix G = A^T A is diagonalized exactly by a
    single Jacobi rotation, using the numerically stable tangent form:
        tau = (g11 - g00) / (2 * g01)
        t   = sign(tau) / (|tau| + sqrt(1 + tau²))
        c   = 1 / sqrt(1 + t²),   s = t * c

    EPS is the "already diagonal" threshold on |g01|, the floor applied to
    the eigenvalues before the square root, and is added to the singular
    values in the U = A V S^{-1} denominator.
    """
    bid = tl.program_id(0)
    # NOTE(review): program_id is int32, so bid * M * 2 can overflow once
    # B * M * 2 reaches 2**31; consider bid.to(tl.int64) for huge batches.
    base = bid * M * 2

    # Pass 1: accumulate the three unique entries of G = A^T A in fp32.
    g00 = tl.zeros([], dtype=tl.float32)
    g01 = tl.zeros([], dtype=tl.float32)
    g11 = tl.zeros([], dtype=tl.float32)

    for block_start in range(0, M, BLOCK_M):
        offs = tl.arange(0, BLOCK_M)
        row_idx = block_start + offs
        mask = row_idx < M  # the last tile may run past row M-1
        a0 = tl.load(A_ptr + base + row_idx * 2 + 0, mask=mask, other=0.0).to(tl.float32)
        a1 = tl.load(A_ptr + base + row_idx * 2 + 1, mask=mask, other=0.0).to(tl.float32)
        g00 += tl.sum(a0 * a0)
        g01 += tl.sum(a0 * a1)
        g11 += tl.sum(a1 * a1)

    # Single Jacobi rotation. When |g01| <= EPS, t = 0 so c = 1, s = 0
    # (identity rotation); tl.where discards the 0/0 branch in that case.
    off_diag = g01
    diag_diff = g11 - g00
    abs_off = tl.abs(off_diag)
    tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)
    t = tl.where(abs_off > EPS,
                 tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)),
                 0.0)
    c = 1.0 / tl.sqrt(1.0 + t * t)
    s = t * c

    # Diagonal of the rotated Gram matrix = the two eigenvalues of G.
    eig0 = c * c * g00 - 2.0 * s * c * g01 + s * s * g11
    eig1 = s * s * g00 + 2.0 * s * c * g01 + c * c * g11

    # sigma = sqrt(eigenvalue); the EPS floor avoids sqrt of tiny negatives
    # produced by rounding and bounds sigma below by sqrt(EPS).
    s0 = tl.sqrt(tl.maximum(eig0, EPS))
    s1 = tl.sqrt(tl.maximum(eig1, EPS))

    # V = [[c, s], [-s, c]] with v_rc = row r, column c. Column 0 is the
    # eigenvector for eig0, column 1 the eigenvector for eig1.
    v00 = c; v01 = s
    v10 = -s; v11 = c

    # Sort descending: swap the singular values and the matching V columns.
    do_swap = s0 < s1
    s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)
    tv = v00; v00 = tl.where(do_swap, v01, v00); v01 = tl.where(do_swap, tv, v01)
    tv = v10; v10 = tl.where(do_swap, v11, v10); v11 = tl.where(do_swap, tv, v11)

    # Store S for this batch element.
    s_base = bid * 2
    tl.store(S_ptr + s_base + 0, s0)
    tl.store(S_ptr + s_base + 1, s1)

    # Store Vh = V^T in row-major order: Vh[i, j] = V[j, i].
    vh_base = bid * 4
    tl.store(Vh_ptr + vh_base + 0, v00); tl.store(Vh_ptr + vh_base + 1, v10)
    tl.store(Vh_ptr + vh_base + 2, v01); tl.store(Vh_ptr + vh_base + 3, v11)

    # Reciprocal singular values; EPS in the denominator guards division by 0.
    inv_s0 = 1.0 / (s0 + EPS)
    inv_s1 = 1.0 / (s1 + EPS)

    # Pass 2: re-read A tile by tile and write U = A V S^{-1}.
    for block_start in range(0, M, BLOCK_M):
        offs = tl.arange(0, BLOCK_M)
        row_idx = block_start + offs
        mask = row_idx < M
        a0 = tl.load(A_ptr + base + row_idx * 2 + 0, mask=mask, other=0.0).to(tl.float32)
        a1 = tl.load(A_ptr + base + row_idx * 2 + 1, mask=mask, other=0.0).to(tl.float32)
        u0 = (a0 * v00 + a1 * v10) * inv_s0
        u1 = (a0 * v01 + a1 * v11) * inv_s1
        u_base = bid * M * 2
        tl.store(U_ptr + u_base + row_idx * 2 + 0, u0, mask=mask)
        tl.store(U_ptr + u_base + row_idx * 2 + 1, u1, mask=mask)
|
|
|
|
def batched_svd2(A, block_m=128):
    """Run the fused Triton SVD kernel on a (B, M, 2) tensor.

    Args:
        A: (B, M, 2) tensor on a device Triton can launch on. It is made
            contiguous and cast to float32 before the launch.
        block_m: Number of rows each kernel tile processes.

    Returns:
        U (B, M, 2), S (B, 2) in descending order, Vh (B, 2, 2); all float32.
    """
    assert A.ndim == 3 and A.shape[2] == 2
    batch, rows, _ = A.shape
    src = A.contiguous().float()
    alloc = dict(dtype=torch.float32, device=A.device)
    U = torch.empty((batch, rows, 2), **alloc)
    S = torch.empty((batch, 2), **alloc)
    Vh = torch.empty((batch, 2, 2), **alloc)
    # One Triton program per batch element.
    _svd2_kernel[(batch,)](src, U, S, Vh, M=rows, BLOCK_M=block_m, EPS=1e-12)
    return U, S, Vh
|
|
|
|
| |
| |
| |
|
|
@triton.jit
def _svd3_kernel(
    A_ptr, U_ptr, S_ptr, Vh_ptr,
    M: tl.constexpr, BLOCK_M: tl.constexpr,
    JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,
):
    """Fused SVD for (M, 3) matrices. One program per batch element.

    Pointer layout (contiguous, row-major; batched_svd3 passes such buffers):
        A_ptr, U_ptr : (B, M, 3)  element (b, r, c) at b*M*3 + r*3 + c
        S_ptr        : (B, 3)     singular values, written in descending order
        Vh_ptr       : (B, 3, 3)  Vh = V^T, row-major

    Algorithm:
      1. Accumulate the 6 unique entries of G = A^T A over BLOCK_M-row tiles.
      2. Run JACOBI_ITERS cyclic Jacobi sweeps. Each sweep rotates the
         (0,1), (0,2) and (1,2) planes in turn; V accumulates the rotations.
      3. sigma_i = sqrt(max(G_ii, EPS)), sorted descending with a
         3-comparator network, swapping V's columns alongside.
      4. Re-read A and write U = A V S^{-1} (denominator S + EPS).
    """
    bid = tl.program_id(0)
    # Unique entries of the symmetric 3×3 Gram matrix, accumulated in fp32.
    g00 = tl.zeros([], dtype=tl.float32); g01 = tl.zeros([], dtype=tl.float32)
    g02 = tl.zeros([], dtype=tl.float32); g11 = tl.zeros([], dtype=tl.float32)
    g12 = tl.zeros([], dtype=tl.float32); g22 = tl.zeros([], dtype=tl.float32)
    # NOTE(review): bid is int32 (program_id); bid * M * 3 may overflow for
    # very large B * M — consider an int64 cast.
    base = bid * M * 3
    # Pass 1: stream A in BLOCK_M-row tiles; masked rows contribute 0.
    for block_start in range(0, M, BLOCK_M):
        offs = tl.arange(0, BLOCK_M); row_idx = block_start + offs; mask = row_idx < M
        a0 = tl.load(A_ptr + base + row_idx * 3 + 0, mask=mask, other=0.0).to(tl.float32)
        a1 = tl.load(A_ptr + base + row_idx * 3 + 1, mask=mask, other=0.0).to(tl.float32)
        a2 = tl.load(A_ptr + base + row_idx * 3 + 2, mask=mask, other=0.0).to(tl.float32)
        g00 += tl.sum(a0*a0); g01 += tl.sum(a0*a1); g02 += tl.sum(a0*a2)
        g11 += tl.sum(a1*a1); g12 += tl.sum(a1*a2); g22 += tl.sum(a2*a2)
    # V starts as the identity; each rotation is right-multiplied into it.
    v00=1.0;v01=0.0;v02=0.0;v10=0.0;v11=1.0;v12=0.0;v20=0.0;v21=0.0;v22=1.0
    for _ in range(JACOBI_ITERS):
        # --- (0,1) plane: zero g01; g02/g12 mix; update V columns 0 and 1 ---
        off_diag=g01;diag_diff=g11-g00;abs_off=tl.abs(off_diag)
        tau=tl.where(abs_off>EPS,diag_diff/(2.0*off_diag),0.0)
        t=tl.where(abs_off>EPS,tl.where(tau>=0,1.0,-1.0)/(tl.abs(tau)+tl.sqrt(1.0+tau*tau)),0.0)
        c=1.0/tl.sqrt(1.0+t*t);s=t*c
        ng00=c*c*g00-2.0*s*c*g01+s*s*g11;ng11=s*s*g00+2.0*s*c*g01+c*c*g11
        ng02=c*g02-s*g12;ng12=s*g02+c*g12
        g00=ng00;g11=ng11;g01=0.0;g02=ng02;g12=ng12
        nv00=c*v00-s*v01;nv01=s*v00+c*v01;nv10=c*v10-s*v11;nv11=s*v10+c*v11
        nv20=c*v20-s*v21;nv21=s*v20+c*v21
        v00=nv00;v01=nv01;v10=nv10;v11=nv11;v20=nv20;v21=nv21
        # --- (0,2) plane: zero g02; g01/g12 mix; update V columns 0 and 2 ---
        off_diag=g02;diag_diff=g22-g00;abs_off=tl.abs(off_diag)
        tau=tl.where(abs_off>EPS,diag_diff/(2.0*off_diag),0.0)
        t=tl.where(abs_off>EPS,tl.where(tau>=0,1.0,-1.0)/(tl.abs(tau)+tl.sqrt(1.0+tau*tau)),0.0)
        c=1.0/tl.sqrt(1.0+t*t);s=t*c
        ng00=c*c*g00-2.0*s*c*g02+s*s*g22;ng22=s*s*g00+2.0*s*c*g02+c*c*g22
        ng01=c*g01-s*g12;ng12b=s*g01+c*g12
        g00=ng00;g22=ng22;g02=0.0;g01=ng01;g12=ng12b
        nv00=c*v00-s*v02;nv02=s*v00+c*v02;nv10=c*v10-s*v12;nv12=s*v10+c*v12
        nv20=c*v20-s*v22;nv22=s*v20+c*v22
        v00=nv00;v02=nv02;v10=nv10;v12=nv12;v20=nv20;v22=nv22
        # --- (1,2) plane: zero g12; g01/g02 mix; update V columns 1 and 2 ---
        off_diag=g12;diag_diff=g22-g11;abs_off=tl.abs(off_diag)
        tau=tl.where(abs_off>EPS,diag_diff/(2.0*off_diag),0.0)
        t=tl.where(abs_off>EPS,tl.where(tau>=0,1.0,-1.0)/(tl.abs(tau)+tl.sqrt(1.0+tau*tau)),0.0)
        c=1.0/tl.sqrt(1.0+t*t);s=t*c
        ng11=c*c*g11-2.0*s*c*g12+s*s*g22;ng22=s*s*g11+2.0*s*c*g12+c*c*g22
        ng01=c*g01-s*g02;ng02b=s*g01+c*g02
        g11=ng11;g22=ng22;g12=0.0;g01=ng01;g02=ng02b
        nv01=c*v01-s*v02;nv02=s*v01+c*v02;nv11=c*v11-s*v12;nv12=s*v11+c*v12
        nv21=c*v21-s*v22;nv22=s*v21+c*v22
        v01=nv01;v02=nv02;v11=nv11;v12=nv12;v21=nv21;v22=nv22
    # Singular values from the (now near-diagonal) G, floored at sqrt(EPS).
    s0=tl.sqrt(tl.maximum(g00,EPS));s1=tl.sqrt(tl.maximum(g11,EPS));s2=tl.sqrt(tl.maximum(g22,EPS))
    # 3-element sorting network (compare-swap 0/1, 0/2, 1/2) → descending;
    # the matching columns of V are swapped alongside.
    do_swap=s0<s1
    s0,s1=tl.where(do_swap,s1,s0),tl.where(do_swap,s0,s1)
    tv=v00;v00=tl.where(do_swap,v01,v00);v01=tl.where(do_swap,tv,v01)
    tv=v10;v10=tl.where(do_swap,v11,v10);v11=tl.where(do_swap,tv,v11)
    tv=v20;v20=tl.where(do_swap,v21,v20);v21=tl.where(do_swap,tv,v21)
    do_swap=s0<s2
    s0,s2=tl.where(do_swap,s2,s0),tl.where(do_swap,s0,s2)
    tv=v00;v00=tl.where(do_swap,v02,v00);v02=tl.where(do_swap,tv,v02)
    tv=v10;v10=tl.where(do_swap,v12,v10);v12=tl.where(do_swap,tv,v12)
    tv=v20;v20=tl.where(do_swap,v22,v20);v22=tl.where(do_swap,tv,v22)
    do_swap=s1<s2
    s1,s2=tl.where(do_swap,s2,s1),tl.where(do_swap,s1,s2)
    tv=v01;v01=tl.where(do_swap,v02,v01);v02=tl.where(do_swap,tv,v02)
    tv=v11;v11=tl.where(do_swap,v12,v11);v12=tl.where(do_swap,tv,v12)
    tv=v21;v21=tl.where(do_swap,v22,v21);v22=tl.where(do_swap,tv,v22)
    # Store S, then Vh = V^T in row-major order (Vh[i, j] = V[j, i]).
    s_base=bid*3
    tl.store(S_ptr+s_base+0,s0);tl.store(S_ptr+s_base+1,s1);tl.store(S_ptr+s_base+2,s2)
    vh_base=bid*9
    tl.store(Vh_ptr+vh_base+0,v00);tl.store(Vh_ptr+vh_base+1,v10);tl.store(Vh_ptr+vh_base+2,v20)
    tl.store(Vh_ptr+vh_base+3,v01);tl.store(Vh_ptr+vh_base+4,v11);tl.store(Vh_ptr+vh_base+5,v21)
    tl.store(Vh_ptr+vh_base+6,v02);tl.store(Vh_ptr+vh_base+7,v12);tl.store(Vh_ptr+vh_base+8,v22)
    # Reciprocal singular values; EPS guards division by zero.
    inv_s0=1.0/(s0+EPS);inv_s1=1.0/(s1+EPS);inv_s2=1.0/(s2+EPS)
    # Pass 2: U = A V S^{-1}, recomputed tile by tile from A.
    for block_start in range(0, M, BLOCK_M):
        offs=tl.arange(0,BLOCK_M);row_idx=block_start+offs;mask=row_idx<M
        a0=tl.load(A_ptr+base+row_idx*3+0,mask=mask,other=0.0).to(tl.float32)
        a1=tl.load(A_ptr+base+row_idx*3+1,mask=mask,other=0.0).to(tl.float32)
        a2=tl.load(A_ptr+base+row_idx*3+2,mask=mask,other=0.0).to(tl.float32)
        u0=(a0*v00+a1*v10+a2*v20)*inv_s0
        u1=(a0*v01+a1*v11+a2*v21)*inv_s1
        u2=(a0*v02+a1*v12+a2*v22)*inv_s2
        u_base=bid*M*3
        tl.store(U_ptr+u_base+row_idx*3+0,u0,mask=mask)
        tl.store(U_ptr+u_base+row_idx*3+1,u1,mask=mask)
        tl.store(U_ptr+u_base+row_idx*3+2,u2,mask=mask)
|
|
|
|
def batched_svd3(A, block_m=128, jacobi_iters=6):
    """Run the fused Triton cyclic-Jacobi SVD kernel on a (B, M, 3) tensor.

    Args:
        A: (B, M, 3) tensor on a device Triton can launch on. It is made
            contiguous and cast to float32 before the launch.
        block_m: Number of rows each kernel tile processes.
        jacobi_iters: Number of full (0,1)/(0,2)/(1,2) Jacobi sweeps.

    Returns:
        U (B, M, 3), S (B, 3) in descending order, Vh (B, 3, 3); all float32.
    """
    assert A.ndim == 3 and A.shape[2] == 3
    batch, rows, _ = A.shape
    src = A.contiguous().float()
    alloc = dict(dtype=torch.float32, device=A.device)
    U = torch.empty((batch, rows, 3), **alloc)
    S = torch.empty((batch, 3), **alloc)
    Vh = torch.empty((batch, 3, 3), **alloc)
    # One Triton program per batch element.
    _svd3_kernel[(batch,)](
        src, U, S, Vh,
        M=rows, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12,
    )
    return U, S, Vh
|
|
|
|
| |
| |
| |
| |
|
|
def gram_eigh_svd(A):
    """Thin SVD via Gram-matrix eigendecomposition. Works for any N.

    Steps:
      1. G = A^T A                → (B, N, N) symmetric PSD, via bmm
      2. eigenvalues, V = eigh(G) → ascending, then flipped to descending
      3. S = sqrt(eigenvalues)    → singular values (eigenvalues clamped at 1e-12)
      4. U = A @ V / S            → left singular vectors

    Exact in exact arithmetic. Forming G squares the condition number, so the
    smallest singular values lose relative accuracy in float32.

    Args:
        A: (B, M, N) floating tensor on any device; computed in float32.

    Returns:
        U (B, M, N), S (B, N) in descending order, Vh (B, N, N); all float32.

    Raises:
        ValueError: if A is not 3-D.
    """
    if A.ndim != 3:
        raise ValueError(f"Expected (B, M, N), got shape {tuple(A.shape)}")
    # Under autocast bmm would run in fp16/bf16, and eigh does not support
    # half precision. Disable autocast for the device A actually lives on;
    # a hard-coded 'cuda' guard leaves CPU autocast active.
    autocast_device = 'cuda' if A.is_cuda else 'cpu'
    with torch.amp.autocast(autocast_device, enabled=False):
        A_f = A.float()
        G = torch.bmm(A_f.transpose(1, 2), A_f)
        eigenvalues, V = torch.linalg.eigh(G)
        # eigh returns ascending eigenvalues; SVD convention is descending.
        eigenvalues = eigenvalues.flip(-1)
        V = V.flip(-1)
        S = torch.sqrt(eigenvalues.clamp(min=1e-12))
        U = torch.bmm(A_f, V) / S.unsqueeze(1)
        Vh = V.transpose(-2, -1).contiguous()
    return U, S, Vh
|
|
|
|
| |
| |
| |
| |
|
|
def newton_svd(A, schulz_iters=10):
    """Thin SVD via the Gram matrix and eigh (same algorithm as gram_eigh_svd).

    Despite the name, no Newton-Schulz iteration is performed. The body forms
    G = A^T A, eigendecomposes it with torch.linalg.eigh, takes
    S = sqrt(eigenvalues) and recovers U = A V / S. For a pure-bmm inverse
    square root (e.g. Procrustes whitening), call newton_schulz_invsqrt
    directly.

    Args:
        A: (B, M, N) floating tensor; computed in float32.
        schulz_iters: Unused. Kept so existing callers keep working.

    Returns:
        U (B, M, N), S (B, N) in descending order, Vh (B, N, N); all float32.

    Raises:
        ValueError: if A is not 3-D.
    """
    if A.ndim != 3:
        raise ValueError(f"Expected (B, M, N), got shape {tuple(A.shape)}")
    # Under autocast bmm would produce fp16/bf16 and eigh rejects half
    # precision, so compute in float32 with autocast disabled on A's device.
    autocast_device = 'cuda' if A.is_cuda else 'cpu'
    with torch.amp.autocast(autocast_device, enabled=False):
        A_f = A.float()
        G = torch.bmm(A_f.transpose(1, 2), A_f)
        eigenvalues, V = torch.linalg.eigh(G)
        # eigh is ascending; flip to descending singular-value order.
        eigenvalues = eigenvalues.flip(-1)
        V = V.flip(-1)
        S = torch.sqrt(eigenvalues.clamp(min=1e-12))
        U = torch.bmm(A_f, V) / S.unsqueeze(1)
        Vh = V.transpose(-2, -1).contiguous()
    return U, S, Vh
|
|
|
|
def newton_schulz_invsqrt(G, iters=10):
    """Batched inverse square root G^{-1/2} via the coupled Newton-Schulz iteration.

    Uses only batched matrix multiplies (no eigensolver). This makes it
    suitable for Procrustes whitening: W = X @ newton_schulz_invsqrt(X^T X).

    G is first divided by its trace so its spectrum lies in (0, 1]. Then
        T = 1.5 I - 0.5 Z Y,   Y <- Y T,   Z <- T Z
    drives Y toward (G/tr)^{1/2} and Z toward (G/tr)^{-1/2}. The result is
    finally rescaled by 1/sqrt(tr). Eigenvalues that are very small relative
    to the trace converge slowly, so ill-conditioned G may need more iterations.

    Args:
        G: (B, N, N) symmetric PSD matrices.
        iters: Number of iterations (10 is conservative, 7 usually sufficient).

    Returns:
        (B, N, N) approximation of G^{-1/2}, same dtype/device as G.
    """
    batch, n, _ = G.shape
    # Per-matrix trace, shaped (B, 1, 1) for broadcasting; clamped so a zero
    # matrix does not divide by zero.
    scale = G.diagonal(dim1=-2, dim2=-1).sum(-1, keepdim=True).unsqueeze(-1)
    scale = scale.clamp(min=1e-8)

    eye = torch.eye(n, device=G.device, dtype=G.dtype).unsqueeze(0).expand(batch, -1, -1)
    y = (G / scale).clone()
    z = eye.clone()

    for _ in range(iters):
        step = 1.5 * eye - 0.5 * torch.bmm(z, y)
        # Both updates use the previous y / z.
        y, z = torch.bmm(y, step), torch.bmm(step, z)

    # Undo the trace normalization: (G/tr)^{-1/2} / sqrt(tr) = G^{-1/2}.
    return z / scale.sqrt()
|
|
|
|
| |
| |
| |
| |
|
|
def batched_procrustes(source, target, rank=24, whiten=True, schulz_iters=10):
    """Batched Procrustes alignment with rank-k subspace-preserving rotation.

    For N ≤ 32 (or rank ≥ N): full N-d orthogonal Procrustes, with the
    rotation taken from torch.linalg.svd of the N×N cross-covariance.
    For N > 32 and rank < N: draw a random orthonormal (N, k) basis P, solve
    Procrustes in that k-d subspace, and lift back. Only the in-subspace
    component is rotated; the orthogonal complement of P is kept exactly.

    The original author reports 1.000 NN agreement with full Procrustes
    across their tested configurations (N=32-128, k=8-64).

    Args:
        source: (B, n_samples, N) or (n_samples, N) — source embeddings.
        target: (B, n_samples, N) or (n_samples, N) — target embeddings.
        rank: Projection rank for large N. Ignored if N ≤ 32.
        whiten: If True, whiten each cloud with a Newton-Schulz inverse
            square root of its covariance and L2-normalize rows before
            solving. The result is mapped back through pinv of the target
            whitening matrix.
        schulz_iters: Newton-Schulz iterations for whitening (if enabled).

    Returns:
        aligned: same shape as source — source aligned to target.
        info: dict with the rotation (full: 'rotation'; subspace:
            'rotation_k' and 'projection') and cosine diagnostics computed
            on at most the first 1000 samples.
    """
    unbatched = source.ndim == 2
    if unbatched:
        source = source.unsqueeze(0)
        target = target.unsqueeze(0)

    B, n_samples, N = source.shape
    device = source.device
    source_f = source.float()
    target_f = target.float()
    n_eval = min(1000, n_samples)  # samples used for the cosine diagnostics

    # Center both clouds; the target mean is added back at the end.
    src_mean = source_f.mean(1, keepdim=True)
    tgt_mean = target_f.mean(1, keepdim=True)
    src_c = source_f - src_mean
    tgt_c = target_f - tgt_mean

    if whiten:
        denom = max(n_samples - 1, 1)
        src_cov = torch.bmm(src_c.transpose(1, 2), src_c) / denom
        tgt_cov = torch.bmm(tgt_c.transpose(1, 2), tgt_c) / denom
        src_W = newton_schulz_invsqrt(src_cov, iters=schulz_iters)
        tgt_W = newton_schulz_invsqrt(tgt_cov, iters=schulz_iters)
        src_w = F.normalize(torch.bmm(src_c, src_W), dim=-1)
        tgt_w = F.normalize(torch.bmm(tgt_c, tgt_W), dim=-1)
    else:
        src_w = src_c
        tgt_w = tgt_c

    use_projection = N > 32 and rank < N

    if not use_projection:
        # Orthogonal Procrustes: R = U Vh where src^T tgt = U Σ Vh.
        C = torch.bmm(src_w.transpose(1, 2), tgt_w)
        U, _, Vh = torch.linalg.svd(C)
        R = torch.bmm(U, Vh)
        aligned_w = torch.bmm(src_w, R)

        cos_after = F.cosine_similarity(
            aligned_w[:, :n_eval], tgt_w[:, :n_eval], dim=-1).mean().item()

        info = {
            'method': 'full',
            'N': N, 'rank': N,
            'rotation': R,
            'cos_after': cos_after,
        }
    else:
        k = min(rank, N - 1)

        # Random orthonormal basis of the k-d working subspace.
        P_raw = torch.randn(B, N, k, device=device, dtype=torch.float32)
        P = torch.linalg.qr(P_raw).Q
        P_T = P.transpose(1, 2)

        src_proj = torch.bmm(src_w, P)
        tgt_proj = torch.bmm(tgt_w, P)

        C_k = torch.bmm(src_proj.transpose(1, 2), tgt_proj)
        U_k, _, Vh_k = torch.linalg.svd(C_k)
        R_k = torch.bmm(U_k, Vh_k)

        # Split source into its span(P) component (src_proj, in k-d
        # coordinates) and the orthogonal complement. Only the former rotates.
        src_perp = src_w - torch.bmm(src_proj, P_T)
        src_rotated_k = torch.bmm(src_proj, R_k)
        aligned_w = torch.bmm(src_rotated_k, P_T) + src_perp

        cos_after_full = F.cosine_similarity(
            aligned_w[:, :n_eval], tgt_w[:, :n_eval], dim=-1).mean().item()
        cos_after_k = F.cosine_similarity(
            src_rotated_k[:, :n_eval], tgt_proj[:, :n_eval], dim=-1).mean().item()

        info = {
            'method': 'subspace',
            'N': N, 'rank': k,
            'rotation_k': R_k,
            'projection': P,
            'cos_after': cos_after_full,
            'cos_after_k': cos_after_k,
        }

    # Map back into target coordinates: undo target whitening, restore mean.
    if whiten:
        aligned = torch.bmm(aligned_w, torch.linalg.pinv(tgt_W)) + tgt_mean
    else:
        aligned = aligned_w + tgt_mean

    if unbatched:
        aligned = aligned.squeeze(0)

    return aligned, info
|
|
|
|
def batched_procrustes_align_pair(source, target, rank=24, whiten=True,
                                  schulz_iters=10, n_align=10000):
    """Convenience wrapper: fit alignment on a subset, apply it to all samples.

    Fits batched_procrustes on the first n samples (n = min(n_align, len(source),
    len(target))). It then rebuilds the same centering and whitening from that
    subset and applies the fitted rotation to every source row.

    Args:
        source: (n_samples, N) source embeddings.
        target: (n_samples, N) target embeddings.
        rank: Projection rank for N > 32.
        whiten: Apply Newton-Schulz whitening.
        schulz_iters: Newton-Schulz iterations, used both when fitting and
            when re-applying the whitening (so the two stay consistent).
        n_align: Number of samples to compute the alignment from.

    Returns:
        aligned: (n_samples, N) aligned source (float32).
        info: alignment diagnostics from batched_procrustes.
    """
    n = min(n_align, source.shape[0], target.shape[0])

    _, info = batched_procrustes(
        source[:n].unsqueeze(0), target[:n].unsqueeze(0),
        rank=rank, whiten=whiten, schulz_iters=schulz_iters)

    # Statistics come from the fitting subset, exactly as inside the fit.
    src_fit = source[:n].float()
    tgt_fit = target[:n].float()
    src_mean = src_fit.mean(0, keepdim=True)
    tgt_mean = tgt_fit.mean(0, keepdim=True)
    src_c = source.float() - src_mean

    if whiten:
        # Same covariance scaling and iteration count as batched_procrustes.
        denom = max(n - 1, 1)
        src_fc = src_fit - src_mean
        tgt_fc = tgt_fit - tgt_mean
        src_cov = src_fc.T @ src_fc / denom
        tgt_cov = tgt_fc.T @ tgt_fc / denom
        src_W = newton_schulz_invsqrt(src_cov.unsqueeze(0), iters=schulz_iters).squeeze(0)
        tgt_W = newton_schulz_invsqrt(tgt_cov.unsqueeze(0), iters=schulz_iters).squeeze(0)
        tgt_unW = torch.linalg.pinv(tgt_W)
        src_w = F.normalize(src_c @ src_W, dim=-1)
    else:
        src_w = src_c

    if info['method'] == 'full':
        rotated = src_w @ info['rotation'].squeeze(0)
    else:
        P = info['projection'].squeeze(0)
        R_k = info['rotation_k'].squeeze(0)
        # Rotate the span(P) component; keep the orthogonal complement.
        src_in = src_w @ P
        src_perp = src_w - src_in @ P.T
        rotated = src_in @ R_k @ P.T + src_perp

    if whiten:
        aligned = rotated @ tgt_unW + tgt_mean
    else:
        aligned = rotated + tgt_mean

    return aligned, info
|
|
def projected_svd(A, target_rank=24, oversampling=8):
    """Rank-projected (randomized) thin SVD for (B, M, N) with large N.

    Basic randomized SVD of Halko, Martinsson & Tropp (2011):

      1. Omega = randn(N, k), k = target_rank + oversampling
      2. Y = A @ Omega, Q = qr(Y).Q     → (B, M, k) orthonormal range basis
      3. Bk = Q^T @ A                   → (B, k, N), so A ≈ Q @ Bk
      4. Bk @ Bk^T = W diag(S²) W^T     → cheap k×k eigh
      5. U = Q @ W,  Vh = diag(1/S) @ W^T @ Bk

    U diag(S) Vh is the exact SVD of Q Q^T A. U has orthonormal columns and
    Vh has orthonormal rows (for S above the 1e-12 clamp). The result is
    exact when rank(A) ≤ k. Otherwise the accuracy of the top target_rank
    triplets depends on how fast the spectrum decays beyond k.

    If k ≥ N the projection would not shrink the problem, so the exact
    gram_eigh_svd path is used instead.

    Args:
        A: (B, M, N) input tensor (computed in float32).
        target_rank: Number of singular values/vectors to return.
        oversampling: Extra sketch columns for accuracy (default 8).

    Returns:
        U: (B, M, r) left singular vectors
        S: (B, r) top singular values, descending
        Vh: (B, r, N) right singular vectors
        with r = min(target_rank, N) for M ≥ N inputs.
    """
    B, M, N = A.shape
    k = min(target_rank + oversampling, N)

    if k >= N:
        # Projection would not reduce the problem: use the exact path.
        U_full, S_full, Vh_full = gram_eigh_svd(A)
        tr = min(target_rank, N)
        return U_full[:, :, :tr], S_full[:, :tr], Vh_full[:, :tr, :]

    # fp32 throughout; eigh does not accept half precision under autocast.
    autocast_device = 'cuda' if A.is_cuda else 'cpu'
    with torch.amp.autocast(autocast_device, enabled=False):
        A_f = A.float()

        # Randomized range finder. Orthonormalizing the sketch Y (rather than
        # the lifted right vectors) keeps every triplet paired with its value.
        omega = torch.randn(N, k, device=A.device, dtype=torch.float32)
        Q = torch.linalg.qr(torch.matmul(A_f, omega)).Q        # (B, M, k)

        # Restrict A to the captured range: A ≈ Q @ Bk.
        Bk = torch.bmm(Q.transpose(-2, -1), A_f)              # (B, k, N)

        # Exact SVD of the small wide matrix Bk via its k×k Gram matrix.
        eigenvalues, W = torch.linalg.eigh(torch.bmm(Bk, Bk.transpose(-2, -1)))
        eigenvalues = eigenvalues.flip(-1)  # eigh is ascending
        W = W.flip(-1)
        S_k = torch.sqrt(eigenvalues.clamp(min=1e-12))
        U_k = torch.bmm(Q, W)                                 # (B, M, k)
        Vh_k = torch.bmm(W.transpose(-2, -1), Bk) / S_k.unsqueeze(-1)

    return U_k[:, :, :target_rank], S_k[:, :target_rank], Vh_k[:, :target_rank, :]
|
|
|
|
def projected_svd_quality(A, target_rank=24):
    """Compare projected_svd against torch.linalg.svd on the same input.

    Returns:
        dict with keys
          energy_ratio – mean fraction of squared singular-value mass in the
                         top target_rank values
          recon_proj   – RMS error of the projected reconstruction
          recon_full   – RMS error of the full reference reconstruction
          recon_trunc  – RMS error of the truncated reference (optimal rank-k)
          s_err        – mean |S_proj - S_ref[:k]|
          s_rel_err    – s_err / mean |S_ref[:k]| (0.0 if that mean ≤ 1e-8)
          subspace_cos – mean singular value of Vh_proj @ V_ref_k
                         (1.0 = identical right-singular subspaces)
    """
    n_batch, n_rows, n_cols = A.shape  # enforces a 3-D input
    dense = A.float()

    def _rms(delta):
        return delta.pow(2).mean().sqrt().item()

    def _rebuild(u, s, vh):
        return torch.bmm(u * s.unsqueeze(1), vh)

    # Reference decomposition.
    U_ref, S_ref, Vh_ref = torch.linalg.svd(dense, full_matrices=False)
    S_top = S_ref[:, :target_rank]

    total = S_ref.pow(2).sum(dim=-1).clamp(min=1e-12)
    energy_ratio = (S_top.pow(2).sum(dim=-1) / total).mean().item()

    U_proj, S_proj, Vh_proj = projected_svd(A, target_rank=target_rank)

    recon_err = _rms(dense - _rebuild(U_proj, S_proj, Vh_proj))
    recon_ref = _rms(dense - _rebuild(U_ref, S_ref, Vh_ref))
    recon_trunc_err = _rms(
        dense - _rebuild(U_ref[:, :, :target_rank], S_top, Vh_ref[:, :target_rank, :]))

    s_err = (S_proj - S_top).abs().mean().item()
    ref_scale = S_top.abs().mean().item()
    s_rel_err = s_err / ref_scale if ref_scale > 1e-8 else 0.0

    # Principal-angle cosines between the two right-singular subspaces.
    overlap = torch.bmm(Vh_proj, Vh_ref[:, :target_rank, :].transpose(-2, -1))
    subspace_cos = torch.linalg.svdvals(overlap).mean().item()

    return {
        'energy_ratio': energy_ratio,
        'recon_proj': recon_err,
        'recon_full': recon_ref,
        'recon_trunc': recon_trunc_err,
        's_err': s_err,
        's_rel_err': s_rel_err,
        'subspace_cos': subspace_cos,
    }
|
|
|
|
| def procrustes_alignment_quality(N=48, k=24, n_samples=5000): |
| """Compare 5 methods of applying rank-k Procrustes back to N-d. |
| |
| Methods: |
| 1. full: Full N-d Procrustes (ceiling) |
| 2. pinv: P @ R_k @ pinv(P) β naive lift (broken baseline) |
| 3. lerp: (1-Ξ±)I + Ξ±*(P @ R_k @ pinv(P)) β blend with identity |
| 4. slerp: matrix_exp(Ξ± * matrix_log(R_lifted)) β geodesic on SO(N) |
| 5. subspace: Rotate in-subspace component, preserve orthogonal complement |
| 6. stay_k: Don't lift β compare in k-d (reference for k-d quality) |
| """ |
| device = 'cuda' |
|
|
| |
| shared_rank = min(N // 2, 32) |
| shared_basis = torch.randn(shared_rank, N, device=device) |
| shared_basis = torch.linalg.qr(shared_basis.T).Q.T |
|
|
| coeffs_src = torch.randn(n_samples, shared_rank, device=device) |
| coeffs_tgt = torch.randn(n_samples, shared_rank, device=device) * 0.8 + coeffs_src * 0.5 |
| noise_scale = 0.3 |
|
|
| source = coeffs_src @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device) |
| target = coeffs_tgt @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device) |
|
|
| source = source - source.mean(0, keepdim=True) |
| target = target - target.mean(0, keepdim=True) |
|
|
| |
| C_full = source.T @ target |
| U_f, _, Vh_f = torch.linalg.svd(C_full) |
| R_full = U_f @ Vh_f |
| aligned_full = source @ R_full |
| cos_full = F.cosine_similarity(aligned_full, target, dim=-1).mean().item() |
|
|
| |
| P = torch.randn(N, k, device=device) / math.sqrt(k) |
| |
| P = torch.linalg.qr(P).Q |
|
|
| src_proj = source @ P |
| tgt_proj = target @ P |
|
|
| C_proj = src_proj.T @ tgt_proj |
| U_p, _, Vh_p = torch.linalg.svd(C_proj) |
| R_k = U_p @ Vh_p |
|
|
| |
| P_pinv = torch.linalg.pinv(P) |
| R_pinv = P @ R_k @ P_pinv |
| aligned_pinv = source @ R_pinv |
| cos_pinv = F.cosine_similarity(aligned_pinv, target, dim=-1).mean().item() |
|
|
| |
| |
| I_N = torch.eye(N, device=device) |
| best_lerp_cos = -1.0 |
| best_lerp_alpha = 0.0 |
| lerp_results = {} |
| for alpha in [0.3, 0.5, 0.7, 0.9, 1.0]: |
| R_lerp = (1.0 - alpha) * I_N + alpha * R_pinv |
| aligned_lerp = source @ R_lerp |
| c = F.cosine_similarity(aligned_lerp, target, dim=-1).mean().item() |
| lerp_results[alpha] = c |
| if c > best_lerp_cos: |
| best_lerp_cos = c |
| best_lerp_alpha = alpha |
| |
| R_lerp_best = (1.0 - best_lerp_alpha) * I_N + best_lerp_alpha * R_pinv |
| aligned_lerp_best = source @ R_lerp_best |
|
|
| |
| |
| U_clean, _, Vh_clean = torch.linalg.svd(R_pinv) |
| R_ortho = U_clean @ Vh_clean |
|
|
| best_slerp_cos = -1.0 |
| best_slerp_alpha = 0.0 |
| try: |
| log_R = torch.linalg.matrix_log(R_ortho.to(torch.complex64)).real |
| slerp_works = True |
| except Exception: |
| slerp_works = False |
| log_R = None |
|
|
| if slerp_works: |
| for alpha in [0.3, 0.5, 0.7, 0.9, 1.0]: |
| R_slerp = torch.matrix_exp(alpha * log_R) |
| aligned_slerp = source @ R_slerp |
| c = F.cosine_similarity(aligned_slerp, target, dim=-1).mean().item() |
| if c > best_slerp_cos: |
| best_slerp_cos = c |
| best_slerp_alpha = alpha |
| R_slerp_best = torch.matrix_exp(best_slerp_alpha * log_R) |
| aligned_slerp_best = source @ R_slerp_best |
| else: |
| best_slerp_cos = cos_pinv |
| best_slerp_alpha = -1.0 |
| aligned_slerp_best = aligned_pinv |
|
|
| |
| |
| |
| src_in = source @ P |
| src_perp = source - src_in @ P.T |
|
|
| |
| src_in_rotated = src_in @ R_k |
| aligned_subspace = src_in_rotated @ P.T + src_perp |
| cos_subspace = F.cosine_similarity(aligned_subspace, target, dim=-1).mean().item() |
|
|
| |
| aligned_k = src_proj @ R_k |
| cos_stay_k = F.cosine_similarity(aligned_k, tgt_proj, dim=-1).mean().item() |
|
|
| |
| n_anchor = min(100, n_samples // 2) |
|
|
| def _nn_agree(aligned_a, aligned_b): |
| anc_a, anc_b = aligned_a[:n_anchor], aligned_b[:n_anchor] |
| q_a, q_b = aligned_a[n_anchor:], aligned_b[n_anchor:] |
| nn_a = (q_a @ anc_a.T).argmax(-1) |
| nn_b = (q_b @ anc_b.T).argmax(-1) |
| return (nn_a == nn_b).float().mean().item() |
|
|
| nn_pinv = _nn_agree(aligned_full, aligned_pinv) |
| nn_lerp = _nn_agree(aligned_full, aligned_lerp_best) |
| nn_slerp = _nn_agree(aligned_full, aligned_slerp_best) |
| nn_subspace = _nn_agree(aligned_full, aligned_subspace) |
|
|
| return { |
| 'N': N, 'k': k, |
| 'cos_full': cos_full, |
| 'cos_pinv': cos_pinv, |
| 'cos_lerp': best_lerp_cos, 'lerp_alpha': best_lerp_alpha, |
| 'cos_slerp': best_slerp_cos, 'slerp_alpha': best_slerp_alpha, |
| 'cos_subspace': cos_subspace, |
| 'cos_stay_k': cos_stay_k, |
| 'nn_pinv': nn_pinv, 'nn_lerp': nn_lerp, |
| 'nn_slerp': nn_slerp, 'nn_subspace': nn_subspace, |
| 'lerp_all': lerp_results, |
| } |
|
|
|
|
def profile_procrustes_quality():
    """Print a comparison table of the Procrustes lift-back methods.

    Runs procrustes_alignment_quality over a fixed grid of (N, k) pairs
    (skipping k ≥ N). It prints one row of cosine and NN-agreement scores per
    pair, then a per-config winner summary.

    Returns:
        List of the result dicts, in print order.
    """
    rule = '─' * 105  # horizontal rule for the table sections

    print(f"\n{'=' * 120}")
    print("  PROCRUSTES ALIGNMENT: 5 methods of applying rank-k rotation to N-d space")
    print("  cos = mean cosine similarity after alignment (higher = better, full = ceiling)")
    print("  NN  = nearest-neighbor agreement with full Procrustes (1.0 = identical downstream)")
    print(f"{'=' * 120}")

    configs = [
        (32, [8, 16, 24]),
        (48, [8, 16, 24, 32]),
        (64, [8, 16, 24, 32]),
        (96, [16, 24, 32, 48]),
        (128, [16, 24, 32, 48, 64]),
    ]

    all_results = []

    for N, ranks in configs:
        print(f"\n  N={N}:")
        print(f"  {'k':>5} {'full':>7} {'pinv':>7} {'lerp':>7} {'(α)':>4}"
              f" {'slerp':>7} {'(α)':>4} {'subspc':>7} {'stay_k':>7}"
              f"  │ {'nn_pv':>6} {'nn_lr':>6} {'nn_sl':>6} {'nn_ss':>6}")
        print(f"  {rule}")

        for k in ranks:
            if k >= N:
                continue
            q = procrustes_alignment_quality(N=N, k=k)

            # slerp_alpha < 0 signals that the matrix logarithm failed.
            sl_alpha = f"{q['slerp_alpha']:.1f}" if q['slerp_alpha'] >= 0 else " err"

            print(f"  {k:>5} {q['cos_full']:>7.4f} {q['cos_pinv']:>7.4f}"
                  f" {q['cos_lerp']:>7.4f} {q['lerp_alpha']:>3.1f}"
                  f" {q['cos_slerp']:>7.4f} {sl_alpha:>4}"
                  f" {q['cos_subspace']:>7.4f} {q['cos_stay_k']:>7.4f}"
                  f"  │ {q['nn_pinv']:>6.3f} {q['nn_lerp']:>6.3f}"
                  f" {q['nn_slerp']:>6.3f} {q['nn_subspace']:>6.3f}")
            all_results.append(q)

    print(f"\n  {rule}")
    print("  WINNER PER CONFIG (closest cos to full, highest NN agreement):")
    print(f"  {rule}")
    for q in all_results:
        methods = {
            'pinv': q['cos_pinv'], 'lerp': q['cos_lerp'],
            'slerp': q['cos_slerp'], 'subspace': q['cos_subspace'],
        }
        best_method = max(methods, key=methods.get)
        best_cos = methods[best_method]
        gap = q['cos_full'] - best_cos
        nn_methods = {
            'pinv': q['nn_pinv'], 'lerp': q['nn_lerp'],
            'slerp': q['nn_slerp'], 'subspace': q['nn_subspace'],
        }
        best_nn_method = max(nn_methods, key=nn_methods.get)
        print(f"  N={q['N']:>3} k={q['k']:>3}: best_cos={best_method:>8} ({best_cos:.4f}, gap={gap:.4f})"
              f"  best_nn={best_nn_method:>8} ({nn_methods[best_nn_method]:.3f})")

    return all_results
|
|
|
|
def batched_svd(A, method='auto', block_m=128, newton=False, target_rank=None):
    """Batched thin SVD for (B, M, N) CUDA tensors with M ≥ N.

    Args:
        A: (B, M, N) CUDA tensor.
        method: 'auto', 'triton', 'gram_eigh', 'newton', 'projected' or 'torch'.
        block_m: Tile size for the fused Triton kernels (N=2, 3).
        newton: With method='auto' and N ≥ 48 (and no target_rank), use newton_svd.
        target_rank: Rank for 'projected'. With method='auto' and N ≥ 48 it
            selects the projected (approximate) SVD. For 'projected' without
            it, min(N // 2, 32) is used.

    Dispatch for method='auto':
        N = 2                   → fused Triton kernel (closed form)
        N = 3                   → fused Triton kernel (cyclic Jacobi)
        N ≥ 48, target_rank set → projected SVD
        N ≥ 48, newton=True     → newton_svd
        otherwise               → Gram + eigh

    Returns:
        U, S, Vh with singular values descending.
          - Full methods: U (B, M, N), S (B, N), Vh (B, N, N)
          - Projected:    U (B, M, k), S (B, k), Vh (B, k, N), k = target rank

    Raises:
        AssertionError: wrong rank, non-CUDA input, or M < N.
        ValueError: unknown method, or 'triton' requested for N not in {2, 3}.
    """
    assert A.ndim == 3, f"Expected (B, M, N), got shape {A.shape}"
    assert A.is_cuda, "Input must be on CUDA"
    _, M, N = A.shape
    assert M >= N, f"Thin SVD requires M >= N, got M={M}, N={N}"

    fused = {2: batched_svd2, 3: batched_svd3}

    if method == 'auto':
        if N in fused:
            return fused[N](A, block_m)
        if N >= 48 and target_rank is not None:
            return projected_svd(A, target_rank=target_rank)
        if N >= 48 and newton:
            return newton_svd(A)
        return gram_eigh_svd(A)

    if method == 'triton':
        if N not in fused:
            raise ValueError(f"Fused Triton kernel only available for N=2,3, got N={N}")
        return fused[N](A, block_m)

    if method == 'gram_eigh':
        return gram_eigh_svd(A)

    if method == 'newton':
        return newton_svd(A)

    if method == 'projected':
        return projected_svd(A, target_rank=target_rank or min(N // 2, 32))

    if method == 'torch':
        return torch.linalg.svd(A.float(), full_matrices=False)

    raise ValueError(f"Unknown method '{method}'. Use: auto, triton, gram_eigh, newton, projected, torch")
|
|
|
|
| |
| |
| |
|
|
def validate_svd(A, U, S, Vh, label=""):
    """Check an SVD result against its input and a torch reference SVD.

    Acceptance checks (all must hold for PASS):
      * reconstruction: max |A - U diag(S) Vh| is below
        ``max(3 * ref_err, 1e-3)``, where ``ref_err`` is the same error for
        ``torch.linalg.svd`` on ``A.float()``;
      * orthogonality: max |U^T U - I| < 1e-2;
      * singular values: min(S) >= -1e-6 and S is non-increasing along its
        last axis (1e-6 slack for near-ties).

    The max abs difference between ``S`` and the leading reference singular
    values (``S_err``) is printed for information only.

    Both full (k == N) and rank-k results are accepted: the identity used for
    the orthogonality check and the reference values compared with ``S`` are
    sized from ``k = S.shape[-1]``.

    Args:
        A: (B, M, N) input batch.
        U: (B, M, k) left singular vectors.
        S: (B, k) singular values.
        Vh: (B, k, N) right singular vectors, transposed.
        label: Optional tag shown in the printed summary line.

    Returns:
        True if every acceptance check passes, else False. A one-line summary
        is printed either way.
    """
    B, M, N = A.shape
    A_f = A.float()
    k = S.shape[-1]

    # U * S[:, None, :] scales column j of U by S[:, j], i.e. U @ diag(S).
    recon = torch.bmm(U * S.unsqueeze(1), Vh)
    recon_err = (A_f - recon).abs().max().item()

    # Orthonormal columns: U^T U should be the k x k identity.
    UtU = torch.bmm(U.transpose(1, 2), U)
    eye = torch.eye(k, device=A.device).expand(B, -1, -1)
    orth_err = (UtU - eye).abs().max().item()

    s_min = S.min().item()
    # Descending order is part of the SVD contract; allow tiny float slack.
    s_sorted = (S[:, :-1] >= S[:, 1:] - 1e-6).all().item()

    U_ref, S_ref, Vh_ref = torch.linalg.svd(A_f, full_matrices=False)
    # Compare with the leading k reference values (k == N for full results).
    s_err = (S - S_ref[:, :k]).abs().max().item()
    recon_ref = (A_f - torch.bmm(U_ref * S_ref.unsqueeze(1), Vh_ref)).abs().max().item()

    tag = f"[{label}] " if label else ""
    passed = (recon_err < max(recon_ref * 3, 1e-3)
              and orth_err < 1e-2
              and s_min >= -1e-6
              and s_sorted)
    status = "PASS" if passed else "FAIL"

    print(f" {tag}N={N:>3}: S_err={s_err:.2e} recon={recon_err:.2e} (ref={recon_ref:.2e})"
          f" orth={orth_err:.2e} desc={s_sorted} [{status}]")
    return passed
|
|
|
|
def run_validation(B=64, M=1024):
    """Run correctness checks for every SVD method, then Procrustes smoke tests.

    Part 1: for each N in a fixed list (skipping N > M), a random (B, M, N)
    float32 CUDA batch is decomposed by every applicable method and checked
    with ``validate_svd``:
      * ``auto`` dispatch (always),
      * fused Triton (N <= 3) or Gram + eigh (N > 3),
      * Newton SVD (N >= 8).
    Part 2: ``batched_procrustes`` is run on two noisy views of a shared random
    signal and the mean cosine similarity before and after alignment is
    printed. The unbatched and ``batched_procrustes_align_pair`` entry points
    are also exercised, and their output shapes are asserted.

    Returns:
        True only if every ``validate_svd`` call in part 1 passed. The
        Procrustes results in part 2 are printed but do not affect this value.
    """
    rule = '=' * 70
    print(f"\n{rule}")
    print(f" CORRECTNESS VALIDATION (B={B}, M={M})")
    print(rule)

    verdicts = []
    for N in (2, 3, 4, 5, 6, 8, 10, 16, 32, 48, 64, 96, 128):
        if N > M:
            continue
        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)

        # (label, solver) pairs for this N, in the order they are executed.
        solvers = [("auto", lambda: batched_svd(A, method='auto'))]
        if N <= 3:
            solvers.append(("triton", lambda: batched_svd(A, method='triton')))
        else:
            solvers.append(("gram", lambda: batched_svd(A, method='gram_eigh')))
        if N >= 8:
            solvers.append(("newton", lambda: newton_svd(A)))

        for label, solve in solvers:
            U, S, Vh = solve()
            verdicts.append(validate_svd(A, U, S, Vh, label=label))

    all_pass = all(verdicts)
    print(f"\n {'ALL PASSED' if all_pass else 'SOME FAILURES'}")

    print(f"\n{rule}")
    print(" PROCRUSTES ALIGNMENT VALIDATION")
    print(rule)

    n_samp = 2000
    for N in (16, 32, 48, 64, 128):
        # Two noisy observations of the same latent points.
        shared = torch.randn(n_samp, N, device='cuda')
        source = shared + 0.3 * torch.randn(n_samp, N, device='cuda')
        target = shared + 0.3 * torch.randn(n_samp, N, device='cuda')

        rank = min(24, N - 1)
        aligned, info = batched_procrustes(
            source.unsqueeze(0), target.unsqueeze(0), rank=rank, whiten=True)
        aligned = aligned.squeeze(0)

        cos_before = F.cosine_similarity(source, target, dim=-1).mean().item()
        cos_after = F.cosine_similarity(aligned, target, dim=-1).mean().item()
        verdict = 'IMPROVED' if cos_after > cos_before else 'WORSE'

        print(f" N={N:>3} rank={rank:>3} method={info['method']:>8}:"
              f" cos {cos_before:.4f} β {cos_after:.4f}"
              f" {verdict}")

    # Unbatched (2-D) inputs must come back with the input's shape.
    source_ub = torch.randn(1000, 48, device='cuda')
    target_ub = torch.randn(1000, 48, device='cuda') * 0.5 + source_ub * 0.5
    aligned_ub, info_ub = batched_procrustes(source_ub, target_ub, rank=24)
    assert aligned_ub.shape == source_ub.shape, f"Shape mismatch: {aligned_ub.shape} vs {source_ub.shape}"
    print(f" Unbatched API: shape {aligned_ub.shape} β method={info_ub['method']}")

    aligned_pair, info_pair = batched_procrustes_align_pair(
        source_ub, target_ub, rank=24, n_align=500)
    assert aligned_pair.shape == source_ub.shape
    cos_pair = F.cosine_similarity(aligned_pair, target_ub, dim=-1).mean().item()
    print(f" Align-pair API: cos={cos_pair:.4f} method={info_pair['method']}")

    print(" PROCRUSTES VALIDATION COMPLETE")
    return all_pass
|
|
|
|
| |
| |
| |
|
|
| def _cuda_timer(fn, warmup=20, iters=80): |
| """CUDA-event-timed benchmark. Returns (mean_ms, std_ms, median_ms).""" |
| for _ in range(warmup): |
| fn() |
| torch.cuda.synchronize() |
|
|
| starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] |
| ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] |
| for i in range(iters): |
| starts[i].record(); fn(); ends[i].record() |
| torch.cuda.synchronize() |
|
|
| times = torch.tensor([starts[i].elapsed_time(ends[i]) for i in range(iters)]) |
| return times.mean().item(), times.std().item(), times.median().item() |
|
|
|
|
def profile_n_sweep(B=512, M=1024):
    """Time every applicable SVD method for N from 2 to 128 at fixed (B, M).

    For each N (skipping N > M), a fresh random (B, M, N) float32 CUDA batch
    is timed with ``_cuda_timer`` for:
      * fused Triton via ``batched_svd(method='triton')`` (N <= 3),
      * ``torch.linalg.svd`` (always),
      * ``gram_eigh_svd`` (always),
      * ``newton_svd`` (N >= 8),
      * ``projected_svd`` at rank min(24, N-1) (N >= 32) and rank
        min(16, N-1) (N >= 24).
    One table row per N is printed, showing the fastest timed method and its
    speedup over ``torch.linalg.svd``.

    Returns:
        List of per-N dicts with keys N, B, M, torch_ms, gram_ms, best and
        speedup_vs_torch, plus triton_ms / newton_ms / proj24_ms / proj16_ms
        for the methods that were timed for that N.
    """
    nan = float('nan')

    def _fmt(ms):
        # NaN marks "method not run for this N"; show a placeholder cell.
        if math.isnan(ms):
            return f"{'β':>10}"
        return f"{ms:>8.3f}ms"

    def _mean_ms(fn):
        # Only the mean of the (mean, std, median) timing triple is used here.
        return _cuda_timer(fn)[0]

    device_name = torch.cuda.get_device_name(0)
    wide_rule = '=' * 110
    print(f"\n{wide_rule}")
    print(f" N-DIMENSION SWEEP β {device_name}")
    print(f" B={B}, M={M}")
    print(wide_rule)
    print(f" {'N':>4} {'Triton':>10} {'Gram':>10} {'Newton':>10}"
          f" {'Projβ24':>10} {'Projβ16':>10} {'Torch':>10} {'Best':>8} {'Speedup':>8}")
    print(f" {'β'*106}")

    results = []
    for N in (2, 3, 4, 5, 6, 7, 8, 10, 12, 16, 20, 24, 32, 48, 64, 96, 128):
        if N > M:
            continue
        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)

        # Methods are timed in a fixed order: triton, torch, gram, newton,
        # proj24, proj16.
        triton_ms = _mean_ms(lambda: batched_svd(A, method='triton')) if N <= 3 else nan
        torch_ms = _mean_ms(lambda: torch.linalg.svd(A, full_matrices=False))
        gram_ms = _mean_ms(lambda: gram_eigh_svd(A))
        newton_ms = _mean_ms(lambda: newton_svd(A)) if N >= 8 else nan
        proj24_ms = (_mean_ms(lambda: projected_svd(A, target_rank=min(24, N - 1)))
                     if N >= 32 else nan)
        proj16_ms = (_mean_ms(lambda: projected_svd(A, target_rank=min(16, N - 1)))
                     if N >= 24 else nan)

        # Insertion order (torch, gram, then optional methods) decides ties in min().
        candidates = {'torch': torch_ms, 'gram': gram_ms}
        optional = (
            ('triton', triton_ms, N <= 3),
            ('newton', newton_ms, N >= 8),
            ('proj24', proj24_ms, N >= 32),
            ('proj16', proj16_ms, N >= 24),
        )
        candidates.update((name, ms) for name, ms, was_run in optional if was_run)
        best = min(candidates, key=candidates.get)
        speedup = torch_ms / (candidates[best] + 1e-9)

        print(f" {N:>4} {_fmt(triton_ms)} {_fmt(gram_ms)} {_fmt(newton_ms)}"
              f" {_fmt(proj24_ms)} {_fmt(proj16_ms)} {_fmt(torch_ms)}"
              f" {best:>8} {speedup:>7.1f}x")

        row = {'N': N, 'B': B, 'M': M, 'torch_ms': round(torch_ms, 4),
               'gram_ms': round(gram_ms, 4), 'best': best,
               'speedup_vs_torch': round(speedup, 3)}
        timed = (('triton_ms', triton_ms), ('newton_ms', newton_ms),
                 ('proj24_ms', proj24_ms), ('proj16_ms', proj16_ms))
        row.update((key, round(ms, 4)) for key, ms in timed if not math.isnan(ms))
        results.append(row)

        del A
        torch.cuda.empty_cache()

    return results
|
|
|
|
def profile_projection_quality(B=256, M=1024):
    """Compare rank-k projected SVD against the full Gram-eigh SVD.

    For each N in {32, 48, 64, 96, 128} (skipping N > M), a random (B, M, N)
    float32 CUDA batch is built and ``gram_eigh_svd`` is timed once. Then, for
    every candidate rank k < N, the metrics returned by
    ``projected_svd_quality(A, target_rank=k)`` (energy_ratio, recon_proj,
    recon_trunc, s_rel_err, subspace_cos) are printed with the
    ``projected_svd`` time and its speedup over the full SVD.

    A closing summary picks, per N, the smallest k with
    energy_ratio >= 0.99 and subspace_cos >= 0.99. When no k qualifies, it
    reports the k with the highest energy_ratio instead.

    Returns:
        List of per-(N, k) dicts with the metrics above (rounded) plus
        N, k, B, M, proj_ms and full_ms.
    """
    rank_grid = (
        (32, (8, 12, 16, 24)),
        (48, (8, 12, 16, 24, 32)),
        (64, (8, 12, 16, 24, 32, 48)),
        (96, (8, 16, 24, 32, 48, 64)),
        (128, (8, 16, 24, 32, 48, 64, 96)),
    )

    rule = '=' * 100
    print(f"\n{rule}")
    print(f" PROJECTION QUALITY ANALYSIS β B={B}, M={M}")
    print(" Question: can rank-k SVD approximate rank-N SVD?")
    print(rule)

    all_results = []
    rows_by_n = {}  # N -> rows for that N, in measurement order

    for N, ranks in rank_grid:
        if N > M:
            continue

        print(f"\n N={N}:")
        print(f" {'k':>5} {'Energy%':>8} {'Recon_proj':>11} {'Recon_trunc':>12}"
              f" {'S_rel_err':>10} {'Subspace':>9} {'Proj ms':>10} {'Full ms':>10} {'Speedup':>8}")
        print(f" {'β'*96}")

        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)
        full_ms = _cuda_timer(lambda: gram_eigh_svd(A), warmup=10, iters=40)[0]

        for k in (r for r in ranks if r < N):
            q = projected_svd_quality(A, target_rank=k)
            proj_ms = _cuda_timer(
                lambda: projected_svd(A, target_rank=k), warmup=10, iters=40)[0]
            speedup = full_ms / (proj_ms + 1e-9)

            print(f" {k:>5} {q['energy_ratio']*100:>7.2f}% {q['recon_proj']:>11.2e}"
                  f" {q['recon_trunc']:>12.2e} {q['s_rel_err']:>10.4f}"
                  f" {q['subspace_cos']:>9.4f} {proj_ms:>8.3f}ms {full_ms:>8.3f}ms"
                  f" {speedup:>7.1f}x")

            row = {
                'N': N, 'k': k, 'B': B, 'M': M,
                'energy_ratio': round(q['energy_ratio'], 6),
                'recon_proj': round(q['recon_proj'], 8),
                'recon_trunc': round(q['recon_trunc'], 8),
                's_rel_err': round(q['s_rel_err'], 6),
                'subspace_cos': round(q['subspace_cos'], 6),
                'proj_ms': round(proj_ms, 4),
                'full_ms': round(full_ms, 4),
            }
            all_results.append(row)
            rows_by_n.setdefault(N, []).append(row)

        del A
        torch.cuda.empty_cache()

    thin_rule = 'β' * 70
    print(f"\n {thin_rule}")
    print(" SUMMARY: Recommended target_rank per N")
    print(" (β₯99% energy, β₯0.99 subspace cos, best speedup)")
    print(f" {thin_rule}")

    for N, _ in rank_grid:
        rows = rows_by_n.get(N, [])
        qualifying = [r for r in rows
                      if r['energy_ratio'] >= 0.99 and r['subspace_cos'] >= 0.99]
        if qualifying:
            pick = min(qualifying, key=lambda r: r['k'])
            print(f" N={N:>3}: k={pick['k']:>3} β {pick['energy_ratio']*100:.1f}% energy,"
                  f" subspace={pick['subspace_cos']:.4f},"
                  f" {pick['full_ms']/pick['proj_ms']:.1f}x speedup")
        elif rows:
            # Nothing met both thresholds: report the highest-energy rank tried.
            pick = max(rows, key=lambda r: r['energy_ratio'])
            print(f" N={N:>3}: best k={pick['k']:>3} β {pick['energy_ratio']*100:.1f}% energy,"
                  f" subspace={pick['subspace_cos']:.4f} (below 99% threshold)")

    return all_results
|
|
|
|
def profile_batch_sweep(N=3, M=1024):
    """Time ``batched_svd(method='auto')`` against ``torch.linalg.svd`` over batch sizes.

    Batch sizes 32, 64, ..., 8192 are tried in order on a random (B, M, N)
    float32 CUDA input. If allocating the input raises ``RuntimeError``
    (e.g. out of memory), an ``OOM`` row is printed and the sweep stops.

    Args:
        N: Number of columns per matrix.
        M: Number of rows per matrix. Thin SVD requires ``M >= N``. When that
            does not hold, nothing is timed and an empty list is returned,
            matching the ``N > M`` skip used by the other sweeps.

    Returns:
        List of dicts with keys B, N, M, auto_ms, torch_ms and speedup
        (torch time / auto time). The printed ``img/s`` column is matrices
        processed per second by the auto path.
    """
    rule = '=' * 70
    print(f"\n{rule}")
    print(f" BATCH SWEEP β N={N}, M={M}")
    print(rule)
    print(f" {'B':>6} {'Auto ms':>10} {'Torch ms':>10} {'Speedup':>8} {'img/s':>12}")
    print(f" {'β'*52}")

    results = []
    if N > M:
        # batched_svd asserts M >= N, so every batch size would fail.
        print(f" skipped: thin SVD requires M >= N (got M={M}, N={N})")
        return results

    for B in (32, 64, 128, 256, 512, 1024, 2048, 4096, 8192):
        try:
            A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)
        except RuntimeError:
            print(f" {B:>6} OOM")
            break

        auto_mean, _, _ = _cuda_timer(lambda: batched_svd(A, method='auto'))
        torch_mean, _, _ = _cuda_timer(
            lambda: torch.linalg.svd(A, full_matrices=False))

        speedup = torch_mean / (auto_mean + 1e-9)
        ips = B / (auto_mean / 1000)  # matrices per second (times are in ms)

        print(f" {B:>6} {auto_mean:>8.3f}ms {torch_mean:>8.3f}ms {speedup:>7.2f}x {ips:>11,.0f}")
        results.append({'B': B, 'N': N, 'M': M,
                        'auto_ms': round(auto_mean, 4), 'torch_ms': round(torch_mean, 4),
                        'speedup': round(speedup, 3)})
        del A
        torch.cuda.empty_cache()

    return results
|
|
|
|
def profile_spatial_sweep(N=3, B=512):
    """Time ``batched_svd(method='auto')`` against ``torch.linalg.svd`` over row counts M.

    M takes the values 16, 64, 256, ..., 16384 on a random (B, M, N) float32
    CUDA input. Values with M < N are skipped because thin SVD needs
    ``M >= N`` (the same rule the other sweeps apply). When M is a perfect
    square, it is also shown as ``h×h`` in the ``~HxW`` column.

    Returns:
        List of dicts with keys M, N, B, auto_ms, torch_ms and speedup
        (torch time / auto time), one per M that was timed.
    """
    rule = '=' * 70
    print(f"\n{rule}")
    print(f" SPATIAL SWEEP β N={N}, B={B}")
    print(rule)
    print(f" {'M':>6} {'~HxW':>8} {'Auto ms':>10} {'Torch ms':>10} {'Speedup':>8}")
    print(f" {'β'*48}")

    results = []
    for M in (16, 64, 256, 512, 1024, 2048, 4096, 8192, 16384):
        if N > M:
            continue
        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)
        # Exact integer square root, so the perfect-square test is never
        # thrown off by float rounding.
        hw = math.isqrt(M)
        tag = f"{hw}Γ{hw}" if hw * hw == M else f"{M}"

        auto_mean, _, _ = _cuda_timer(lambda: batched_svd(A, method='auto'))
        torch_mean, _, _ = _cuda_timer(
            lambda: torch.linalg.svd(A, full_matrices=False))

        speedup = torch_mean / (auto_mean + 1e-9)
        print(f" {M:>6} {tag:>8} {auto_mean:>8.3f}ms {torch_mean:>8.3f}ms {speedup:>7.2f}x")
        results.append({'M': M, 'N': N, 'B': B,
                        'auto_ms': round(auto_mean, 4), 'torch_ms': round(torch_mean, 4),
                        'speedup': round(speedup, 3)})
        del A
        torch.cuda.empty_cache()

    return results
|
|
|
|
def profile_crossover_detail(M=1024, B=512):
    """Time Gram-eigh SVD against ``torch.linalg.svd`` for every N from 2 to 64.

    The sweep stops at the first N > M, because thin SVD needs M >= N. For
    each N, a random (B, M, N) float32 CUDA batch is timed (10 warmup / 40
    timed calls per method), and a row with the faster method and the relative
    margin is printed.

    Returns:
        List of per-N dicts with keys N, B, M, gram_ms, torch_ms, winner and
        margin_pct. ``margin_pct`` is |gram - torch| as a percentage of the
        faster time. This matches the other ``profile_*`` helpers, which also
        return their measurements.
    """
    rule = '=' * 70
    print(f"\n{rule}")
    print(f" CROSSOVER DETAIL β B={B}, M={M}")
    print(rule)
    print(f" {'N':>4} {'Gram ms':>10} {'Torch ms':>10} {'Winner':>8} {'Margin':>8}")
    print(f" {'β'*46}")

    results = []
    for N in range(2, 65):
        if N > M:
            break
        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)
        gram_mean, _, _ = _cuda_timer(lambda: gram_eigh_svd(A), warmup=10, iters=40)
        torch_mean, _, _ = _cuda_timer(
            lambda: torch.linalg.svd(A, full_matrices=False), warmup=10, iters=40)

        winner = "gram" if gram_mean < torch_mean else "torch"
        margin = abs(gram_mean - torch_mean) / min(gram_mean, torch_mean) * 100

        print(f" {N:>4} {gram_mean:>8.3f}ms {torch_mean:>8.3f}ms {winner:>8} {margin:>6.1f}%")
        results.append({'N': N, 'B': B, 'M': M,
                        'gram_ms': round(gram_mean, 4), 'torch_ms': round(torch_mean, 4),
                        'winner': winner, 'margin_pct': round(margin, 2)})
        del A
        torch.cuda.empty_cache()

    return results
|
|
|
|
| |
| |
| |
|
|
def main():
    """Run the correctness and profiling suite and save a JSON report.

    Steps, in order:
      1. ``run_validation(B=64, M=1024)``: correctness checks. The boolean
         result is stored in the report as ``validation_passed``.
      2. ``profile_procrustes_quality()``.
      3. ``profile_projection_quality(B=256, M=1024)``.
      4. ``profile_n_sweep(B=512, M=1024)``.

    The batch and spatial sweeps are not run here, so their report sections are
    written as empty dicts. A strategy summary is printed, and all results are
    written to ``svd_general_profile.json`` in the current working directory.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    # Explicit check rather than ``assert``, so it still fires under ``python -O``.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required")

    device_name = torch.cuda.get_device_name(0)
    rule = '=' * 80
    print(rule)
    print(" Generalized Batched Thin SVD β Profiling Suite")
    print(f" Device: {device_name}")
    print(rule)

    validation_ok = run_validation(B=64, M=1024)

    procrustes_results = profile_procrustes_quality()

    proj_results = profile_projection_quality(B=256, M=1024)

    n_results = profile_n_sweep(B=512, M=1024)

    # Batch/spatial sweeps are currently disabled; keep their report keys.
    batch_results = {}
    spatial_results = {}

    print(f"\n{rule}")
    print(" SUMMARY")
    print(rule)
    print("\n Strategy by N:")
    print(" N=2: Fused Triton (closed-form Jacobi rotation)")
    print(" N=3: Fused Triton (cyclic Jacobi in registers)")
    print(" N=4-32: Gram + eigh (bmm + cuSOLVER eigh) β sub-ms")
    print(" N=48+: Projected SVD (Nβk, cheap SVD, lift back) β check quality table")
    print()
    print(" Standalone utilities:")
    print(" newton_schulz_invsqrt(G) β batched G^{-1/2} via pure bmm")
    print(" projected_svd(A, target_rank=k) β rank-k approximate SVD")
    print(" projected_svd_quality(A, target_rank) β measure approximation quality")
    print()
    print(" Key question answered: energy_ratio and subspace_cos in quality table")

    report = {
        'device': device_name,
        # Lets readers tell timings of a correct build from a broken one.
        'validation_passed': bool(validation_ok),
        'procrustes_quality': procrustes_results,
        'projection_quality': proj_results,
        'n_sweep': n_results,
        'batch_sweeps': {str(k): v for k, v in batch_results.items()},
        'spatial_sweeps': {str(k): v for k, v in spatial_results.items()},
    }
    with open('svd_general_profile.json', 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)
    print("\n Results saved to svd_general_profile.json")
    print(rule)


if __name__ == "__main__":
    main()