# svd-triton / kernel_profiler.py
# AbstractPhil's picture
# Create kernel_profiler.py
# cc81ca6 verified
"""
triton_svd_general.py β€” Generalized batched thin SVD for (B, M, N) matrices.
Three strategies, auto-dispatched by N:
N=2: Fused Triton kernel β€” closed-form 2Γ—2 eigensolve in registers
N=3: Fused Triton kernel β€” cyclic Jacobi in registers (from session start)
Nβ‰₯4: Gram-Eigh hybrid β€” Triton G=A^T A + torch.linalg.eigh + Triton U recovery
All methods exploit the thin-matrix shortcut: decompose via the NΓ—N Gram
matrix G=A^T A rather than working on the full MΓ—N matrix directly.
Mathematical lineage:
Eckart-Young (1936): G = A^T A β†’ eigenvalues of G = σ² of A
Jacobi (1846): Cyclic Givens rotations for symmetric eigendecomposition
Golub-Reinsch (1970): U = A V S^{-1} recovery
Batcher (1968): Sorting network for eigenvalue ordering
Author: AbstractPhil / Claude Opus 4.6
"""
import triton
import triton.language as tl
import torch
import torch.nn.functional as F
import math
import time
import json
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ KERNEL 1: Fused SVD for (B, M, 2) — closed-form 2×2 eigensolve            ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
@triton.jit
def _svd2_kernel(
    A_ptr, U_ptr, S_ptr, Vh_ptr,
    M: tl.constexpr, BLOCK_M: tl.constexpr, EPS: tl.constexpr,
):
    """Fused thin SVD of one contiguous (M, 2) matrix per program.

    Program ``bid`` handles batch element ``bid``:
      1. Accumulate the 2x2 Gram matrix G = A^T A (g00, g01, g11) over
         BLOCK_M-row tiles, in float32.
      2. Diagonalise G with a single Jacobi rotation. A single rotation
         exactly diagonalises a symmetric 2x2 matrix. It uses the tangent
         form, so no trig is needed:
             tau = (g11 - g00) / (2 g01)
             t   = sign(tau) / (|tau| + sqrt(1 + tau^2))
             c   = 1 / sqrt(1 + t^2),  s = t c
         When |g01| <= EPS the rotation is the identity.
      3. S = sqrt(max(eigenvalue, EPS)), sorted descending, with V's
         columns swapped to match.
      4. U = A V diag(1 / (S + EPS)), streamed over rows again.

    Memory layout (row-major; the host wrapper passes contiguous float32):
      A, U: (B, M, 2)   S: (B, 2)   Vh: (B, 2, 2), stores V^T.
    BLOCK_M must be a power of two (``tl.arange`` requirement).
    """
    # Promote the program id to int64 before forming offsets: bid * M * 2
    # overflows int32 once the batch holds 2**31 or more elements.
    bid = tl.program_id(0).to(tl.int64)
    base = bid * M * 2
    # Stage 1: G = A^T A (3 accumulators: g00, g01, g11)
    g00 = tl.zeros([], dtype=tl.float32)
    g01 = tl.zeros([], dtype=tl.float32)
    g11 = tl.zeros([], dtype=tl.float32)
    for block_start in range(0, M, BLOCK_M):
        offs = tl.arange(0, BLOCK_M)
        row_idx = block_start + offs
        mask = row_idx < M
        a0 = tl.load(A_ptr + base + row_idx * 2 + 0, mask=mask, other=0.0).to(tl.float32)
        a1 = tl.load(A_ptr + base + row_idx * 2 + 1, mask=mask, other=0.0).to(tl.float32)
        g00 += tl.sum(a0 * a0)
        g01 += tl.sum(a0 * a1)
        g11 += tl.sum(a1 * a1)
    # Stage 2: 2x2 eigendecomposition via a single Jacobi rotation
    # (same formula as the 3x3 kernel — no trig needed). The inner divides
    # may produce inf/nan when g01 == 0, but tl.where discards those lanes.
    off_diag = g01
    diag_diff = g11 - g00
    abs_off = tl.abs(off_diag)
    tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)
    t = tl.where(abs_off > EPS,
                 tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)),
                 0.0)
    c = 1.0 / tl.sqrt(1.0 + t * t)
    s = t * c
    # Eigenvalues: diagonal of J^T G J with J = [[c, s], [-s, c]]
    eig0 = c * c * g00 - 2.0 * s * c * g01 + s * s * g11
    eig1 = s * s * g00 + 2.0 * s * c * g01 + c * c * g11
    # Singular values; clamp so round-off negatives stay inside sqrt's domain
    s0 = tl.sqrt(tl.maximum(eig0, EPS))
    s1 = tl.sqrt(tl.maximum(eig1, EPS))
    # V = J (identity with the Jacobi rotation applied)
    v00 = c
    v01 = s
    v10 = -s
    v11 = c
    # Sort descending, swapping V's columns together with S
    do_swap = s0 < s1
    s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)
    tv = v00
    v00 = tl.where(do_swap, v01, v00)
    v01 = tl.where(do_swap, tv, v01)
    tv = v10
    v10 = tl.where(do_swap, v11, v10)
    v11 = tl.where(do_swap, tv, v11)
    # Write S
    s_base = bid * 2
    tl.store(S_ptr + s_base + 0, s0)
    tl.store(S_ptr + s_base + 1, s1)
    # Write Vh = V^T (row-major: Vh[i, j] = V[j, i])
    vh_base = bid * 4
    tl.store(Vh_ptr + vh_base + 0, v00)
    tl.store(Vh_ptr + vh_base + 1, v10)
    tl.store(Vh_ptr + vh_base + 2, v01)
    tl.store(Vh_ptr + vh_base + 3, v11)
    # Stage 3: U = A @ V @ diag(1/S)
    inv_s0 = 1.0 / (s0 + EPS)
    inv_s1 = 1.0 / (s1 + EPS)
    for block_start in range(0, M, BLOCK_M):
        offs = tl.arange(0, BLOCK_M)
        row_idx = block_start + offs
        mask = row_idx < M
        a0 = tl.load(A_ptr + base + row_idx * 2 + 0, mask=mask, other=0.0).to(tl.float32)
        a1 = tl.load(A_ptr + base + row_idx * 2 + 1, mask=mask, other=0.0).to(tl.float32)
        u0 = (a0 * v00 + a1 * v10) * inv_s0
        u1 = (a0 * v01 + a1 * v11) * inv_s1
        u_base = bid * M * 2
        tl.store(U_ptr + u_base + row_idx * 2 + 0, u0, mask=mask)
        tl.store(U_ptr + u_base + row_idx * 2 + 1, u1, mask=mask)
def batched_svd2(A, block_m=128):
    """Fused Triton thin SVD for (B, M, 2) tensors.

    Args:
        A: (B, M, 2) tensor on a CUDA device (the Triton kernel runs on the
            GPU). It is made contiguous and converted to float32 first.
        block_m: Row tile size for the kernel. Must be a positive power of
            two, because the kernel builds its tile with ``tl.arange``.

    Returns:
        U (B, M, 2), S (B, 2) sorted descending, Vh (B, 2, 2); all float32.

    Raises:
        AssertionError: if ``A`` is not 3-D with a trailing dimension of 2.
        ValueError: if ``block_m`` is not a positive power of two.
    """
    assert A.ndim == 3 and A.shape[2] == 2
    # Fail early with a clear message instead of a Triton compile error.
    if block_m <= 0 or (block_m & (block_m - 1)) != 0:
        raise ValueError(f"block_m must be a positive power of two, got {block_m}")
    B, M, _ = A.shape
    A_f32 = A.contiguous().float()
    U = torch.empty((B, M, 2), dtype=torch.float32, device=A.device)
    S = torch.empty((B, 2), dtype=torch.float32, device=A.device)
    Vh = torch.empty((B, 2, 2), dtype=torch.float32, device=A.device)
    _svd2_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, EPS=1e-12)
    return U, S, Vh
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ KERNEL 2: Fused SVD for (B, M, 3) — cyclic Jacobi (original kernel)       ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
@triton.jit
def _svd3_kernel(
    A_ptr, U_ptr, S_ptr, Vh_ptr,
    M: tl.constexpr, BLOCK_M: tl.constexpr,
    JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,
):
    """Fused thin SVD of one contiguous (M, 3) matrix per program.

    Program ``bid`` handles batch element ``bid``:
      1. Accumulate the symmetric 3x3 Gram matrix G = A^T A in float32.
      2. Run JACOBI_ITERS cyclic Jacobi sweeps. Each sweep zeroes the
         (0,1), (0,2) and (1,2) off-diagonal entries in turn. The same
         rotations are accumulated into V (V <- V J).
      3. S = sqrt(max(diag(G), EPS)), sorted descending by a 3-comparator
         sorting network, with V's columns swapped alongside.
      4. U = A V diag(1 / (S + EPS)), streamed over rows again.

    Memory layout (row-major; the host wrapper passes contiguous float32):
      A, U: (B, M, 3)   S: (B, 3)   Vh: (B, 3, 3), stores V^T.
    BLOCK_M must be a power of two (``tl.arange`` requirement).
    """
    # Promote the program id to int64 before forming offsets: bid * M * 3
    # overflows int32 once the batch holds 2**31 or more elements.
    bid = tl.program_id(0).to(tl.int64)
    g00 = tl.zeros([], dtype=tl.float32); g01 = tl.zeros([], dtype=tl.float32)
    g02 = tl.zeros([], dtype=tl.float32); g11 = tl.zeros([], dtype=tl.float32)
    g12 = tl.zeros([], dtype=tl.float32); g22 = tl.zeros([], dtype=tl.float32)
    base = bid * M * 3
    # Stage 1: upper triangle of G = A^T A
    for block_start in range(0, M, BLOCK_M):
        offs = tl.arange(0, BLOCK_M); row_idx = block_start + offs; mask = row_idx < M
        a0 = tl.load(A_ptr + base + row_idx * 3 + 0, mask=mask, other=0.0).to(tl.float32)
        a1 = tl.load(A_ptr + base + row_idx * 3 + 1, mask=mask, other=0.0).to(tl.float32)
        a2 = tl.load(A_ptr + base + row_idx * 3 + 2, mask=mask, other=0.0).to(tl.float32)
        g00 += tl.sum(a0*a0); g01 += tl.sum(a0*a1); g02 += tl.sum(a0*a2)
        g11 += tl.sum(a1*a1); g12 += tl.sum(a1*a2); g22 += tl.sum(a2*a2)
    # Stage 2: cyclic Jacobi; V starts as the identity
    v00=1.0;v01=0.0;v02=0.0;v10=0.0;v11=1.0;v12=0.0;v20=0.0;v21=0.0;v22=1.0
    for _ in range(JACOBI_ITERS):
        # Rotation in the (0, 1) plane: zero g01
        off_diag=g01;diag_diff=g11-g00;abs_off=tl.abs(off_diag)
        tau=tl.where(abs_off>EPS,diag_diff/(2.0*off_diag),0.0)
        t=tl.where(abs_off>EPS,tl.where(tau>=0,1.0,-1.0)/(tl.abs(tau)+tl.sqrt(1.0+tau*tau)),0.0)
        c=1.0/tl.sqrt(1.0+t*t);s=t*c
        ng00=c*c*g00-2.0*s*c*g01+s*s*g11;ng11=s*s*g00+2.0*s*c*g01+c*c*g11
        ng02=c*g02-s*g12;ng12=s*g02+c*g12
        g00=ng00;g11=ng11;g01=0.0;g02=ng02;g12=ng12
        nv00=c*v00-s*v01;nv01=s*v00+c*v01;nv10=c*v10-s*v11;nv11=s*v10+c*v11
        nv20=c*v20-s*v21;nv21=s*v20+c*v21
        v00=nv00;v01=nv01;v10=nv10;v11=nv11;v20=nv20;v21=nv21
        # Rotation in the (0, 2) plane: zero g02
        off_diag=g02;diag_diff=g22-g00;abs_off=tl.abs(off_diag)
        tau=tl.where(abs_off>EPS,diag_diff/(2.0*off_diag),0.0)
        t=tl.where(abs_off>EPS,tl.where(tau>=0,1.0,-1.0)/(tl.abs(tau)+tl.sqrt(1.0+tau*tau)),0.0)
        c=1.0/tl.sqrt(1.0+t*t);s=t*c
        ng00=c*c*g00-2.0*s*c*g02+s*s*g22;ng22=s*s*g00+2.0*s*c*g02+c*c*g22
        ng01=c*g01-s*g12;ng12b=s*g01+c*g12
        g00=ng00;g22=ng22;g02=0.0;g01=ng01;g12=ng12b
        nv00=c*v00-s*v02;nv02=s*v00+c*v02;nv10=c*v10-s*v12;nv12=s*v10+c*v12
        nv20=c*v20-s*v22;nv22=s*v20+c*v22
        v00=nv00;v02=nv02;v10=nv10;v12=nv12;v20=nv20;v22=nv22
        # Rotation in the (1, 2) plane: zero g12
        off_diag=g12;diag_diff=g22-g11;abs_off=tl.abs(off_diag)
        tau=tl.where(abs_off>EPS,diag_diff/(2.0*off_diag),0.0)
        t=tl.where(abs_off>EPS,tl.where(tau>=0,1.0,-1.0)/(tl.abs(tau)+tl.sqrt(1.0+tau*tau)),0.0)
        c=1.0/tl.sqrt(1.0+t*t);s=t*c
        ng11=c*c*g11-2.0*s*c*g12+s*s*g22;ng22=s*s*g11+2.0*s*c*g12+c*c*g22
        ng01=c*g01-s*g02;ng02b=s*g01+c*g02
        g11=ng11;g22=ng22;g12=0.0;g01=ng01;g02=ng02b
        nv01=c*v01-s*v02;nv02=s*v01+c*v02;nv11=c*v11-s*v12;nv12=s*v11+c*v12
        nv21=c*v21-s*v22;nv22=s*v21+c*v22
        v01=nv01;v02=nv02;v11=nv11;v12=nv12;v21=nv21;v22=nv22
    # Stage 3: singular values (clamped for sqrt), then a 3-element sorting
    # network (compare-swap (0,1), (0,2), (1,2)) for descending order; each
    # swap also exchanges the matching columns of V.
    s0=tl.sqrt(tl.maximum(g00,EPS));s1=tl.sqrt(tl.maximum(g11,EPS));s2=tl.sqrt(tl.maximum(g22,EPS))
    do_swap=s0<s1
    s0,s1=tl.where(do_swap,s1,s0),tl.where(do_swap,s0,s1)
    tv=v00;v00=tl.where(do_swap,v01,v00);v01=tl.where(do_swap,tv,v01)
    tv=v10;v10=tl.where(do_swap,v11,v10);v11=tl.where(do_swap,tv,v11)
    tv=v20;v20=tl.where(do_swap,v21,v20);v21=tl.where(do_swap,tv,v21)
    do_swap=s0<s2
    s0,s2=tl.where(do_swap,s2,s0),tl.where(do_swap,s0,s2)
    tv=v00;v00=tl.where(do_swap,v02,v00);v02=tl.where(do_swap,tv,v02)
    tv=v10;v10=tl.where(do_swap,v12,v10);v12=tl.where(do_swap,tv,v12)
    tv=v20;v20=tl.where(do_swap,v22,v20);v22=tl.where(do_swap,tv,v22)
    do_swap=s1<s2
    s1,s2=tl.where(do_swap,s2,s1),tl.where(do_swap,s1,s2)
    tv=v01;v01=tl.where(do_swap,v02,v01);v02=tl.where(do_swap,tv,v02)
    tv=v11;v11=tl.where(do_swap,v12,v11);v12=tl.where(do_swap,tv,v12)
    tv=v21;v21=tl.where(do_swap,v22,v21);v22=tl.where(do_swap,tv,v22)
    s_base=bid*3
    tl.store(S_ptr+s_base+0,s0);tl.store(S_ptr+s_base+1,s1);tl.store(S_ptr+s_base+2,s2)
    # Vh = V^T, row-major: Vh[i, j] = V[j, i]
    vh_base=bid*9
    tl.store(Vh_ptr+vh_base+0,v00);tl.store(Vh_ptr+vh_base+1,v10);tl.store(Vh_ptr+vh_base+2,v20)
    tl.store(Vh_ptr+vh_base+3,v01);tl.store(Vh_ptr+vh_base+4,v11);tl.store(Vh_ptr+vh_base+5,v21)
    tl.store(Vh_ptr+vh_base+6,v02);tl.store(Vh_ptr+vh_base+7,v12);tl.store(Vh_ptr+vh_base+8,v22)
    # Stage 4: U = A @ V @ diag(1/S)
    inv_s0=1.0/(s0+EPS);inv_s1=1.0/(s1+EPS);inv_s2=1.0/(s2+EPS)
    for block_start in range(0, M, BLOCK_M):
        offs=tl.arange(0,BLOCK_M);row_idx=block_start+offs;mask=row_idx<M
        a0=tl.load(A_ptr+base+row_idx*3+0,mask=mask,other=0.0).to(tl.float32)
        a1=tl.load(A_ptr+base+row_idx*3+1,mask=mask,other=0.0).to(tl.float32)
        a2=tl.load(A_ptr+base+row_idx*3+2,mask=mask,other=0.0).to(tl.float32)
        u0=(a0*v00+a1*v10+a2*v20)*inv_s0
        u1=(a0*v01+a1*v11+a2*v21)*inv_s1
        u2=(a0*v02+a1*v12+a2*v22)*inv_s2
        u_base=bid*M*3
        tl.store(U_ptr+u_base+row_idx*3+0,u0,mask=mask)
        tl.store(U_ptr+u_base+row_idx*3+1,u1,mask=mask)
        tl.store(U_ptr+u_base+row_idx*3+2,u2,mask=mask)
def batched_svd3(A, block_m=128, jacobi_iters=6):
    """Fused Triton thin SVD for (B, M, 3) tensors.

    Args:
        A: (B, M, 3) tensor on a CUDA device (the Triton kernel runs on the
            GPU). It is made contiguous and converted to float32 first.
        block_m: Row tile size for the kernel. Must be a positive power of
            two, because the kernel builds its tile with ``tl.arange``.
        jacobi_iters: Number of cyclic Jacobi sweeps over the 3x3 Gram matrix.

    Returns:
        U (B, M, 3), S (B, 3) sorted descending, Vh (B, 3, 3); all float32.

    Raises:
        AssertionError: if ``A`` is not 3-D with a trailing dimension of 3.
        ValueError: if ``block_m`` is not a positive power of two.
    """
    assert A.ndim == 3 and A.shape[2] == 3
    # Fail early with a clear message instead of a Triton compile error.
    if block_m <= 0 or (block_m & (block_m - 1)) != 0:
        raise ValueError(f"block_m must be a positive power of two, got {block_m}")
    B, M, _ = A.shape
    A_f32 = A.contiguous().float()
    U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)
    S = torch.empty((B, 3), dtype=torch.float32, device=A.device)
    Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)
    _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m,
                       JACOBI_ITERS=jacobi_iters, EPS=1e-12)
    return U, S, Vh
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ METHOD 3: Gram-Eigh hybrid for general N                                  ║
# ║ G = A^T A (bmm) → eigh(G) → U = A V / S                                   ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
def gram_eigh_svd(A):
    """Thin SVD of a (B, M, N) batch via eigendecomposition of A^T A.

    Works for any N. Everything runs in float32 with CUDA autocast disabled,
    so the result is unaffected by any surrounding mixed-precision region.

    Procedure:
      * gram = A^T A, a (B, N, N) symmetric PSD matrix (one bmm);
      * eigh(gram) gives eigenpairs in ascending order, which are reversed;
      * S = sqrt(max(eigenvalue, 1e-12)), so S >= 1e-6;
      * V = eigenvectors, U = A V / S, Vh = V^T.

    Exact in exact arithmetic. Forming A^T A squares the condition number,
    so singular values far below the largest lose relative accuracy in
    float32.

    Returns:
        U (B, M, N), S (B, N) descending, Vh (B, N, N); all float32.
    """
    B, M, N = A.shape  # unpacking also rejects non-3-D input
    with torch.amp.autocast('cuda', enabled=False):
        a32 = A.float()
        gram = torch.bmm(a32.transpose(1, 2), a32)
        evals_ascending, evecs_ascending = torch.linalg.eigh(gram)
        # Reverse the eigen-order so singular values come out descending.
        right_vecs = torch.flip(evecs_ascending, dims=(-1,))
        sing_vals = torch.flip(evals_ascending, dims=(-1,)).clamp(min=1e-12).sqrt()
        left_vecs = torch.bmm(a32, right_vecs) / sing_vals.unsqueeze(1)
        right_vecs_t = right_vecs.transpose(-2, -1).contiguous()
    return left_vecs, sing_vals, right_vecs_t
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ METHOD 4: "Newton" SVD for large N (48+) and Newton-Schulz G^{-1/2}       ║
# ║ newton_svd still calls eigh; newton_schulz_invsqrt is pure bmm (no eigh)  ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
def newton_svd(A, schulz_iters=10):
    """Thin SVD of a (B, M, N) batch via the Gram matrix and ``torch.linalg.eigh``.

    Despite the name, no Newton-Schulz iterations are performed here. The
    eigendecomposition of G = A^T A is the bottleneck for SVD and
    Newton-Schulz cannot replace it. The Newton-Schulz routine in this file
    (``newton_schulz_invsqrt``) is a separate utility for Procrustes
    whitening.

    Steps:
      1. G = A^T A (bmm)
      2. eigh(G), reversed to descending order
      3. S = sqrt(max(eigenvalue, 1e-12))
      4. U = A V / S,  Vh = V^T

    Args:
        A: (B, M, N) tensor; computation is done in float32.
        schulz_iters: Unused; kept for backward compatibility.

    Returns:
        U (B, M, N), S (B, N) descending, Vh (B, N, N); all float32.
    """
    B, M, N = A.shape
    # Disable CUDA autocast: inside an autocast region the bmm would run in
    # half precision, and torch.linalg.eigh has no half-precision kernel.
    with torch.amp.autocast('cuda', enabled=False):
        A_f = A.float()
        # Phase 1: Gram matrix
        G = torch.bmm(A_f.transpose(1, 2), A_f)  # (B, N, N)
        # Phase 2: eigendecomposition (ascending) flipped to descending
        eigenvalues, V = torch.linalg.eigh(G)
        eigenvalues = eigenvalues.flip(-1)
        V = V.flip(-1)
        S = torch.sqrt(eigenvalues.clamp(min=1e-12))
        # Phase 3: U recovery
        U = torch.bmm(A_f, V) / S.unsqueeze(1)
        Vh = V.transpose(-2, -1).contiguous()
    return U, S, Vh
def newton_schulz_invsqrt(G, iters=10):
    """Coupled Newton-Schulz iteration for G^{-1/2} of batched symmetric PSD matrices.

    Pure bmm, no eigensolver, quadratic convergence once the scaled
    eigenvalues are near 1. Intended for Procrustes whitening:
    ``W = X @ newton_schulz_invsqrt(X^T X)``.

    G is first divided by its Frobenius norm c. The Frobenius norm bounds the
    spectral radius, so every eigenvalue of G/c lies in (0, 1] and
    ||I - G/c||_2 < 1, the standard sufficient convergence condition. Unlike
    the trace, it shrinks the spectrum by at most sqrt(N) rather than N. That
    saves roughly log_{2.25}(sqrt(N)) iterations spent growing small
    eigenvalues toward 1.

    Args:
        G: (B, N, N) symmetric PSD matrices.
        iters: Number of iterations (10 is conservative for well-conditioned G;
            strongly ill-conditioned G needs more).

    Returns:
        (B, N, N) approximation of G^{-1/2}, same dtype/device as G.
    """
    B, N, _ = G.shape
    scale = torch.linalg.matrix_norm(G, keepdim=True).clamp(min=1e-8)  # (B, 1, 1)
    G_norm = G / scale
    eye = torch.eye(N, device=G.device, dtype=G.dtype).unsqueeze(0).expand(B, -1, -1)
    Y = G_norm.clone()
    Z = eye.clone()
    # Coupled iteration: Y -> (G/c)^{1/2}, Z -> (G/c)^{-1/2}
    for _ in range(iters):
        factor = 1.5 * eye - 0.5 * torch.bmm(Z, Y)
        Y = torch.bmm(Y, factor)
        Z = torch.bmm(factor, Z)
    # Z ~= (G/c)^{-1/2}, so G^{-1/2} = Z * c^{-1/2}
    return Z / scale.sqrt()
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ BATCHED PROCRUSTES ALIGNMENT                                              ║
# ║ Subspace-preserving: rotate in k-d, leave orthogonal complement alone     ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
def batched_procrustes(source, target, rank=24, whiten=True, schulz_iters=10):
    """Batched Procrustes alignment with an optional rank-k subspace-preserving rotation.

    Both inputs are centered (and optionally whitened with Newton-Schulz
    covariance inverse square roots, then row-normalized). An orthogonal map
    is fitted from source to target via the SVD of the cross-covariance:
      * N <= 32, or rank >= N: full N-d Procrustes (``torch.linalg.svd`` of
        the N x N cross-covariance).
      * otherwise: project onto a random orthonormal k-d basis P, fit a k x k
        rotation there, rotate only the in-subspace component, and add the
        orthogonal complement back unchanged.
    Per the author's experiments, the subspace path matched full Procrustes
    nearest neighbours exactly (N=32-128, k=8-64).

    Args:
        source: (B, n_samples, N) or (n_samples, N) source embeddings.
        target: same shape as ``source``.
        rank: Projection rank for N > 32. Ignored if N <= 32.
        whiten: If True, whiten with Newton-Schulz before rotating and map the
            result back with pinv of the target whitening matrix.
        schulz_iters: Newton-Schulz iterations for whitening.

    Returns:
        aligned: float32 tensor shaped like ``source``, plus the target mean.
        info: dict with 'method' ('full' or 'subspace'), 'N', 'rank', the
            rotation ('rotation', or 'rotation_k' and 'projection'),
            'cos_after' (mean cosine in the aligned/whitened space over the
            first <= 1000 samples) and, for 'subspace', 'cos_after_k'.
            Tensors in ``info`` keep the batch dimension.
    """
    unbatched = source.ndim == 2
    if unbatched:
        source = source.unsqueeze(0)
        target = target.unsqueeze(0)
    B, n_samples, N = source.shape
    device = source.device
    source_f = source.float()
    target_f = target.float()
    # Center; the target mean is added back to the aligned output.
    src_mean = source_f.mean(1, keepdim=True)
    tgt_mean = target_f.mean(1, keepdim=True)
    src_c = source_f - src_mean
    tgt_c = target_f - tgt_mean
    # Whiten if requested (Newton-Schulz, pure bmm), then normalize rows.
    if whiten:
        denom = max(n_samples - 1, 1)
        src_cov = torch.bmm(src_c.transpose(1, 2), src_c) / denom
        tgt_cov = torch.bmm(tgt_c.transpose(1, 2), tgt_c) / denom
        src_W = newton_schulz_invsqrt(src_cov, iters=schulz_iters)  # (B, N, N)
        tgt_W = newton_schulz_invsqrt(tgt_cov, iters=schulz_iters)
        src_w = F.normalize(torch.bmm(src_c, src_W), dim=-1)
        tgt_w = F.normalize(torch.bmm(tgt_c, tgt_W), dim=-1)
    else:
        src_w = src_c
        tgt_w = tgt_c
    n_diag = min(1000, n_samples)  # rows used for the cosine diagnostics
    use_projection = N > 32 and rank < N
    if not use_projection:
        # Full N-d Procrustes
        C = torch.bmm(src_w.transpose(1, 2), tgt_w)  # (B, N, N)
        U, _, Vh = torch.linalg.svd(C)
        R = torch.bmm(U, Vh)  # (B, N, N)
        aligned_w = torch.bmm(src_w, R)
        info = {
            'method': 'full',
            'N': N, 'rank': N,
            'rotation': R,
        }
    else:
        # Subspace-preserving rank-k Procrustes
        k = min(rank, N - 1)
        # Random orthonormal projection basis via QR
        P_raw = torch.randn(B, N, k, device=device, dtype=torch.float32)
        P = torch.linalg.qr(P_raw).Q  # (B, N, k) orthonormal columns
        P_T = P.transpose(1, 2)  # (B, k, N)
        # Subspace coordinates (these are also the in-subspace coefficients)
        src_proj = torch.bmm(src_w, P)  # (B, n_samples, k)
        tgt_proj = torch.bmm(tgt_w, P)
        # Procrustes in k-d (cheap k x k SVD)
        C_k = torch.bmm(src_proj.transpose(1, 2), tgt_proj)  # (B, k, k)
        U_k, _, Vh_k = torch.linalg.svd(C_k)
        R_k = torch.bmm(U_k, Vh_k)  # (B, k, k)
        # Split source into in-subspace and orthogonal-complement parts,
        # rotate only the former, and recombine.
        src_perp = src_w - torch.bmm(src_proj, P_T)
        src_rotated_k = torch.bmm(src_proj, R_k)
        aligned_w = torch.bmm(src_rotated_k, P_T) + src_perp
        cos_after_k = F.cosine_similarity(
            src_rotated_k[:, :n_diag], tgt_proj[:, :n_diag], dim=-1).mean().item()
        info = {
            'method': 'subspace',
            'N': N, 'rank': k,
            'rotation_k': R_k,
            'projection': P,
        }
    # Unwhiten back to target space (pinv(tgt_W) approximates tgt_cov^{1/2})
    if whiten:
        aligned = torch.bmm(aligned_w, torch.linalg.pinv(tgt_W)) + tgt_mean
    else:
        aligned = aligned_w + tgt_mean
    info['cos_after'] = F.cosine_similarity(
        aligned_w[:, :n_diag], tgt_w[:, :n_diag], dim=-1).mean().item()
    if use_projection:
        info['cos_after_k'] = cos_after_k
    if unbatched:
        aligned = aligned.squeeze(0)
    return aligned, info
def batched_procrustes_align_pair(source, target, rank=24, whiten=True,
                                  schulz_iters=10, n_align=10000):
    """Convenience wrapper: fit alignment on a subset, then apply it to all of ``source``.

    The rotation is fitted with ``batched_procrustes`` on the first
    ``n = min(n_align, len(source), len(target))`` rows. It is then applied to
    every source row using the same centering statistics and (if ``whiten``)
    the same whitening. Whitening is recomputed from the subset with
    ``schulz_iters`` Newton-Schulz iterations, so the apply-time whitening
    matches the one the rotation was fitted in.

    Args:
        source: (n_samples, N) source embeddings.
        target: (n_samples, N) target embeddings.
        rank: Projection rank for N > 32.
        whiten: Apply Newton-Schulz whitening.
        schulz_iters: Newton-Schulz iterations, used for both fit and apply.
        n_align: Number of leading samples to fit the alignment on.

    Returns:
        aligned: (n_samples, N) float32 aligned source.
        info: alignment diagnostics from ``batched_procrustes``.
    """
    n = min(n_align, source.shape[0], target.shape[0])
    # Fit the alignment on the subset
    _, info = batched_procrustes(
        source[:n].unsqueeze(0), target[:n].unsqueeze(0),
        rank=rank, whiten=whiten, schulz_iters=schulz_iters)
    # Subset statistics, reused so the full data sees the same transform
    src_fit = source[:n].float()
    tgt_fit = target[:n].float()
    src_mean = src_fit.mean(0, keepdim=True)
    tgt_mean = tgt_fit.mean(0, keepdim=True)
    src_c = source.float() - src_mean
    if whiten:
        denom = max(n - 1, 1)
        src_fit_c = src_fit - src_mean
        tgt_fit_c = tgt_fit - tgt_mean
        src_cov = src_fit_c.T @ src_fit_c / denom
        tgt_cov = tgt_fit_c.T @ tgt_fit_c / denom
        src_W = newton_schulz_invsqrt(src_cov.unsqueeze(0), iters=schulz_iters).squeeze(0)
        tgt_W = newton_schulz_invsqrt(tgt_cov.unsqueeze(0), iters=schulz_iters).squeeze(0)
        tgt_unW = torch.linalg.pinv(tgt_W)
        src_w = F.normalize(src_c @ src_W, dim=-1)
    else:
        src_w = src_c
    if info['method'] == 'full':
        rotated = src_w @ info['rotation'].squeeze(0)  # (N, N) rotation
    else:
        P = info['projection'].squeeze(0)    # (N, k)
        R_k = info['rotation_k'].squeeze(0)  # (k, k)
        src_in = src_w @ P                   # (n_all, k) subspace coefficients
        src_perp = src_w - src_in @ P.T      # orthogonal complement, kept as-is
        rotated = src_in @ R_k @ P.T + src_perp
    if whiten:
        aligned = rotated @ tgt_unW + tgt_mean
    else:
        aligned = rotated + tgt_mean
    return aligned, info
def projected_svd(A, target_rank=24, oversampling=8):
    """Rank-k randomized thin SVD for (B, M, N) with large N (Halko-Martinsson-Tropp).

    With k = target_rank + oversampling:
      1. Y = A @ Omega, where Omega is an (N, k) Gaussian test matrix; this
         samples the column space of A.
      2. Q = qr(Y).Q, an orthonormal (B, M, k) basis of that sample.
      3. C = Q^T A, a small (B, k, N) matrix with A ~= Q C.
      4. SVD of C through its k x k Gram matrix: C C^T = W diag(S^2) W^T,
         Vh = diag(1/S) W^T C.
      5. U = Q W.
    The result is a genuine factorisation: U has orthonormal columns, Vh has
    orthonormal rows, and S are (approximations of) the top singular values
    of A. It is exact when rank(A) <= k. Components beyond rank k are lost.

    If k >= N there is nothing to gain from projecting, so the exact
    ``gram_eigh_svd`` result is trimmed instead.

    Omega is drawn from torch's global RNG on every call; seed torch for
    reproducible results.

    Args:
        A: (B, M, N) input tensor (computed in float32).
        target_rank: Number of singular values/vectors to return.
        oversampling: Extra sample dimensions for accuracy (default 8).

    Returns:
        U: (B, M, r), S: (B, r) descending, Vh: (B, r, N), where
        r = min(target_rank, k, M) (r = min(target_rank, N) when k >= N).
    """
    B, M, N = A.shape
    k = min(target_rank + oversampling, N)
    if k >= N:
        # No point projecting — use gram_eigh but still trim to target_rank
        U_full, S_full, Vh_full = gram_eigh_svd(A)
        tr = min(target_rank, N)
        return U_full[:, :, :tr], S_full[:, :tr], Vh_full[:, :tr, :]
    # Keep the whole computation in float32 even under CUDA autocast
    # (torch.linalg.eigh has no half-precision kernel).
    with torch.amp.autocast('cuda', enabled=False):
        A_f = A.float()
        # Phase 1: sample the column space of A
        omega = torch.randn(N, k, device=A.device, dtype=torch.float32) / math.sqrt(k)
        Y = torch.matmul(A_f, omega)                # (B, M, k)
        Q = torch.linalg.qr(Y).Q                    # (B, M, k'), k' = min(M, k)
        # Phase 2: compress A into the sampled basis
        C = torch.bmm(Q.transpose(1, 2), A_f)       # (B, k', N)
        # Phase 3: SVD of the small matrix via its k' x k' Gram matrix
        evals, W = torch.linalg.eigh(torch.bmm(C, C.transpose(1, 2)))
        evals = evals.flip(-1)                      # descending
        W = W.flip(-1)
        S = torch.sqrt(evals.clamp(min=1e-12))      # (B, k')
        Vh = torch.bmm(W.transpose(1, 2), C) / S.unsqueeze(-1)  # (B, k', N)
        # Phase 4: lift the left singular vectors back to M rows
        U = torch.bmm(Q, W)                         # (B, M, k')
    # Trim to target_rank (drop the oversampling dimensions)
    return U[:, :, :target_rank], S[:, :target_rank], Vh[:, :target_rank, :]
def projected_svd_quality(A, target_rank=24):
    """Score ``projected_svd`` against a full ``torch.linalg.svd`` reference.

    Returns a dict with:
      energy_ratio: mean fraction of squared singular-value mass in the top k;
      recon_proj / recon_full / recon_trunc: RMS reconstruction error of the
          projected, full and best rank-k (truncated reference) factorisations;
      s_err / s_rel_err: mean absolute (and relative) top-k singular value error;
      subspace_cos: mean cosine of the principal angles between projected and
          reference top-k right singular subspaces.
    """
    B, M, N = A.shape
    A_f = A.float()

    def _rms(residual):
        return residual.pow(2).mean().sqrt().item()

    # Reference decomposition and its top-k slices
    U_ref, S_ref, Vh_ref = torch.linalg.svd(A_f, full_matrices=False)
    S_top = S_ref[:, :target_rank]
    Vh_top = Vh_ref[:, :target_rank, :]

    total_energy = S_ref.pow(2).sum(dim=-1)
    energy_ratio = (S_top.pow(2).sum(dim=-1) / total_energy.clamp(min=1e-12)).mean().item()

    U_proj, S_proj, Vh_proj = projected_svd(A, target_rank=target_rank)

    recon_err = _rms(A_f - torch.bmm(U_proj * S_proj.unsqueeze(1), Vh_proj))
    recon_ref = _rms(A_f - torch.bmm(U_ref * S_ref.unsqueeze(1), Vh_ref))
    recon_trunc_err = _rms(
        A_f - torch.bmm(U_ref[:, :, :target_rank] * S_top.unsqueeze(1), Vh_top))

    s_err = (S_proj - S_top).abs().mean().item()
    s_scale = S_top.abs().mean().item()
    if s_scale > 1e-8:
        s_rel_err = s_err / s_scale
    else:
        s_rel_err = 0.0

    # Singular values of V_proj^T V_ref are the cosines of the principal angles.
    cross = torch.bmm(Vh_proj, Vh_top.transpose(-2, -1))
    subspace_cos = torch.linalg.svdvals(cross).mean().item()

    return {
        'energy_ratio': energy_ratio,
        'recon_proj': recon_err,
        'recon_full': recon_ref,
        'recon_trunc': recon_trunc_err,
        's_err': s_err,
        's_rel_err': s_rel_err,
        'subspace_cos': subspace_cos,
    }
def procrustes_alignment_quality(N=48, k=24, n_samples=5000, device='cuda'):
    """Compare ways of applying a rank-k Procrustes rotation back to N-d.

    Two synthetic, centered embedding clouds sharing a low-rank structure
    plus noise are aligned with:
      1. full:     full N-d Procrustes (ceiling)
      2. pinv:     P @ R_k @ pinv(P) — naive lift (baseline)
      3. lerp:     (1 - alpha) I + alpha (P @ R_k @ pinv(P)), best alpha
      4. slerp:    matrix_exp(alpha * log(R_ortho)), geodesic on O(N),
                   where R_ortho is the closest orthogonal matrix to the
                   pinv lift and the log is taken via eigendecomposition
      5. subspace: rotate the in-subspace component, keep the complement
      6. stay_k:   no lift — cosine measured in k-d (reference)

    Args:
        N: Ambient dimension.
        k: Projection rank.
        n_samples: Number of synthetic samples.
        device: Device for all tensors (default 'cuda').

    Returns:
        dict of cosine scores, chosen alphas (slerp_alpha = -1.0 when the
        matrix log could not be computed), nearest-neighbour agreement with
        the full solution, and all lerp scores under 'lerp_all'.
    """
    # Two embedding spaces with shared low-rank structure + noise
    shared_rank = min(N // 2, 32)
    shared_basis = torch.randn(shared_rank, N, device=device)
    shared_basis = torch.linalg.qr(shared_basis.T).Q.T
    coeffs_src = torch.randn(n_samples, shared_rank, device=device)
    coeffs_tgt = torch.randn(n_samples, shared_rank, device=device) * 0.8 + coeffs_src * 0.5
    noise_scale = 0.3
    source = coeffs_src @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device)
    target = coeffs_tgt @ shared_basis + noise_scale * torch.randn(n_samples, N, device=device)
    source = source - source.mean(0, keepdim=True)
    target = target - target.mean(0, keepdim=True)

    def _mean_cos(a, b):
        return F.cosine_similarity(a, b, dim=-1).mean().item()

    # Full N-d Procrustes (ceiling)
    U_f, _, Vh_f = torch.linalg.svd(source.T @ target)
    R_full = U_f @ Vh_f
    aligned_full = source @ R_full
    cos_full = _mean_cos(aligned_full, target)

    # Projected k-d Procrustes on an orthonormal random basis
    P = torch.randn(N, k, device=device) / math.sqrt(k)
    P = torch.linalg.qr(P).Q  # (N, k) orthonormal columns
    src_proj = source @ P
    tgt_proj = target @ P
    U_p, _, Vh_p = torch.linalg.svd(src_proj.T @ tgt_proj)
    R_k = U_p @ Vh_p  # (k, k) optimal rotation in k-d

    # Method: naive pinv lift (baseline)
    R_pinv = P @ R_k @ torch.linalg.pinv(P)
    aligned_pinv = source @ R_pinv
    cos_pinv = _mean_cos(aligned_pinv, target)

    alphas = (0.3, 0.5, 0.7, 0.9, 1.0)

    # Method: LERP — blend the lifted rotation with the identity
    I_N = torch.eye(N, device=device)
    best_lerp_cos = -1.0
    best_lerp_alpha = 0.0
    lerp_results = {}
    for alpha in alphas:
        c = _mean_cos(source @ ((1.0 - alpha) * I_N + alpha * R_pinv), target)
        lerp_results[alpha] = c
        if c > best_lerp_cos:
            best_lerp_cos = c
            best_lerp_alpha = alpha
    aligned_lerp_best = source @ ((1.0 - best_lerp_alpha) * I_N + best_lerp_alpha * R_pinv)

    # Method: SLERP — geodesic interpolation on the orthogonal group.
    # R_pinv is rank-deficient, so first take its closest orthogonal matrix.
    U_clean, _, Vh_clean = torch.linalg.svd(R_pinv)
    R_ortho = U_clean @ Vh_clean
    best_slerp_cos = -1.0
    best_slerp_alpha = 0.0
    log_R = None
    try:
        # torch has no matrix logarithm: with R = V diag(w) V^-1,
        # log R = V diag(log w) V^-1. For real R the eigenvalues come in
        # conjugate pairs, so the result is real up to round-off. An
        # eigenvalue of -1 (reflection) has no real log, so the real part is
        # then only an approximation.
        w, V = torch.linalg.eig(R_ortho)
        log_R = (V @ torch.diag_embed(torch.log(w)) @ torch.linalg.inv(V)).real
        if not torch.isfinite(log_R).all():
            log_R = None
    except RuntimeError:  # linalg failures (e.g. singular eigenvector matrix)
        log_R = None
    if log_R is not None:
        for alpha in alphas:
            c = _mean_cos(source @ torch.matrix_exp(alpha * log_R), target)
            if c > best_slerp_cos:
                best_slerp_cos = c
                best_slerp_alpha = alpha
        aligned_slerp_best = source @ torch.matrix_exp(best_slerp_alpha * log_R)
    else:
        best_slerp_cos = cos_pinv
        best_slerp_alpha = -1.0
        aligned_slerp_best = aligned_pinv

    # Method: subspace-preserving rotation (P has orthonormal columns,
    # so P @ P^T projects onto the k-d subspace)
    src_in = source @ P                 # (n, k) coefficients in subspace
    src_perp = source - src_in @ P.T    # (n, N) orthogonal complement
    aligned_subspace = (src_in @ R_k) @ P.T + src_perp
    cos_subspace = _mean_cos(aligned_subspace, target)

    # Method: stay in k-d (reference, no lift)
    cos_stay_k = _mean_cos(src_proj @ R_k, tgt_proj)

    # Nearest-neighbour agreement with the full solution
    n_anchor = min(100, n_samples // 2)

    def _nn_agree(aligned_a, aligned_b):
        anc_a, anc_b = aligned_a[:n_anchor], aligned_b[:n_anchor]
        q_a, q_b = aligned_a[n_anchor:], aligned_b[n_anchor:]
        nn_a = (q_a @ anc_a.T).argmax(-1)
        nn_b = (q_b @ anc_b.T).argmax(-1)
        return (nn_a == nn_b).float().mean().item()

    return {
        'N': N, 'k': k,
        'cos_full': cos_full,
        'cos_pinv': cos_pinv,
        'cos_lerp': best_lerp_cos, 'lerp_alpha': best_lerp_alpha,
        'cos_slerp': best_slerp_cos, 'slerp_alpha': best_slerp_alpha,
        'cos_subspace': cos_subspace,
        'cos_stay_k': cos_stay_k,
        'nn_pinv': _nn_agree(aligned_full, aligned_pinv),
        'nn_lerp': _nn_agree(aligned_full, aligned_lerp_best),
        'nn_slerp': _nn_agree(aligned_full, aligned_slerp_best),
        'nn_subspace': _nn_agree(aligned_full, aligned_subspace),
        'lerp_all': lerp_results,
    }
def profile_procrustes_quality():
    """Print a comparison table of Procrustes lift-back methods over several (N, k).

    For each configured N, ``procrustes_alignment_quality`` is run for every
    rank k < N and one row of cosine / nearest-neighbour scores is printed.
    A per-configuration summary follows, naming the best method by cosine
    (and its gap to the full-Procrustes ceiling) and by NN agreement.

    Returns:
        list of result dicts, in printing order.
    """
    banner = '=' * 120
    print("\n" + banner)
    print(" PROCRUSTES ALIGNMENT: 5 methods of applying rank-k rotation to N-d space")
    print(" cos = mean cosine similarity after alignment (higher = better, full = ceiling)")
    print(" NN = nearest-neighbor agreement with full Procrustes (1.0 = identical downstream)")
    print(banner)
    configs = (
        (32, (8, 16, 24)),
        (48, (8, 16, 24, 32)),
        (64, (8, 16, 24, 32)),
        (96, (16, 24, 32, 48)),
        (128, (16, 24, 32, 48, 64)),
    )
    column_header = (
        f" {'k':>5} {'full':>7} {'pinv':>7} {'lerp':>7} {'(Ξ±)':>4}"
        f" {'slerp':>7} {'(Ξ±)':>4} {'subspc':>7} {'stay_k':>7}"
        f" β”‚ {'nn_pv':>6} {'nn_lr':>6} {'nn_sl':>6} {'nn_ss':>6}"
    )
    thin_rule = " " + 'β”€' * 105
    results = []
    for N, ranks in configs:
        print(f"\n N={N}:")
        print(column_header)
        print(thin_rule)
        for k in (r for r in ranks if r < N):
            q = procrustes_alignment_quality(N=N, k=k)
            if q['slerp_alpha'] >= 0:
                sl_alpha = f"{q['slerp_alpha']:.1f}"
            else:
                sl_alpha = " err"
            print(f" {k:>5} {q['cos_full']:>7.4f} {q['cos_pinv']:>7.4f}"
                  f" {q['cos_lerp']:>7.4f} {q['lerp_alpha']:>3.1f}"
                  f" {q['cos_slerp']:>7.4f} {sl_alpha:>4}"
                  f" {q['cos_subspace']:>7.4f} {q['cos_stay_k']:>7.4f}"
                  f" β”‚ {q['nn_pinv']:>6.3f} {q['nn_lerp']:>6.3f}"
                  f" {q['nn_slerp']:>6.3f} {q['nn_subspace']:>6.3f}")
            results.append(q)
    # Winner summary
    thick_rule = " " + 'β•' * 105
    print("\n" + thick_rule)
    print(" WINNER PER CONFIG (closest cos to full, highest NN agreement):")
    print(thick_rule)
    method_names = ('pinv', 'lerp', 'slerp', 'subspace')
    for q in results:
        cos_by_method = {name: q[f'cos_{name}'] for name in method_names}
        nn_by_method = {name: q[f'nn_{name}'] for name in method_names}
        best_method = max(cos_by_method, key=cos_by_method.get)
        best_cos = cos_by_method[best_method]
        gap = q['cos_full'] - best_cos
        best_nn_method = max(nn_by_method, key=nn_by_method.get)
        print(f" N={q['N']:>3} k={q['k']:>3}: best_cos={best_method:>8} ({best_cos:.4f}, gap={gap:.4f})"
              f" best_nn={best_nn_method:>8} ({nn_by_method[best_nn_method]:.3f})")
    return results
def batched_svd(A, method='auto', block_m=128, newton=False, target_rank=None):
    """Batched thin SVD of a (B, M, N) CUDA tensor with M >= N.

    Args:
        A: (B, M, N) CUDA tensor.
        method: one of 'auto', 'triton', 'gram_eigh', 'newton', 'projected',
            'torch'.
        block_m: tile size forwarded to the fused Triton kernels (N = 2 or 3).
        newton: with method='auto' and N >= 48, route to ``newton_svd``
            (only consulted when ``target_rank`` is None).
        target_rank: rank used by method='projected' (falls back to
            ``min(N // 2, 32)`` when falsy). With method='auto' and N >= 48,
            a non-None value selects ``projected_svd`` and takes priority over
            ``newton``.

    Dispatch for method='auto':
        N == 2                      -> batched_svd2 (fused Triton, closed form)
        N == 3                      -> batched_svd3 (fused Triton, cyclic Jacobi)
        N >= 48, target_rank given  -> projected_svd (approximate)
        N >= 48, newton=True        -> newton_svd
        anything else               -> gram_eigh_svd

    Returns:
        (U, S, Vh) with singular values descending. Full methods return
        U (B, M, N), S (B, N), Vh (B, N, N); the projected method returns
        U (B, M, k), S (B, k), Vh (B, k, N) where k is the target rank.

    Raises:
        AssertionError: A is not 3-D, is not on CUDA, or has M < N.
        ValueError: method='triton' with N not in {2, 3}, or unknown method.
    """
    assert A.ndim == 3, f"Expected (B, M, N), got shape {A.shape}"
    assert A.is_cuda, "Input must be on CUDA"
    _, M, N = A.shape
    assert M >= N, f"Thin SVD requires M >= N, got M={M}, N={N}"

    # 'auto' and 'triton' share the fused register kernels for N = 2 and 3.
    if method in ('auto', 'triton') and N in (2, 3):
        fused_kernel = batched_svd2 if N == 2 else batched_svd3
        return fused_kernel(A, block_m)

    if method == 'auto':
        if N >= 48 and target_rank is not None:
            return projected_svd(A, target_rank=target_rank)
        if N >= 48 and newton:
            return newton_svd(A)
        return gram_eigh_svd(A)

    if method == 'triton':
        raise ValueError(f"Fused Triton kernel only available for N=2,3, got N={N}")
    if method == 'gram_eigh':
        return gram_eigh_svd(A)
    if method == 'newton':
        return newton_svd(A)
    if method == 'projected':
        return projected_svd(A, target_rank=target_rank or min(N // 2, 32))
    if method == 'torch':
        return torch.linalg.svd(A.float(), full_matrices=False)
    raise ValueError(f"Unknown method '{method}'. Use: auto, triton, gram_eigh, newton, projected, torch")
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ CORRECTNESS VALIDATION ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
def validate_svd(A, U, S, Vh, label=""):
    """Check a batched SVD-style factorization of ``A`` and print a one-line report.

    Checks performed:
      * reconstruction: max |A - U diag(S) Vh|, compared against the same
        error for a ``torch.linalg.svd`` reference;
      * orthonormality of U's columns: max |U^T U - I|;
      * singular values: non-negative and in descending order.

    Accepts thin factorizations (U: (B, M, N), S: (B, N), Vh: (B, N, N)) and
    rank-k factorizations (U: (B, M, k), S: (B, k), Vh: (B, k, N)). For rank-k
    input, S is compared with the top-k reference singular values. A rank-k
    factorization of a full-rank A will normally fail the reconstruction
    criterion; that is expected.

    Args:
        A: (B, M, N) input batch. It is compared in float32 (``A.float()``).
        U, S, Vh: factorization to check.
        label: optional tag printed in front of the report line.

    Returns:
        bool: True when all of the following hold:
          * reconstruction error < max(3 * reference error, 1e-3);
          * orthogonality error < 1e-2;
          * min(S) >= -1e-6;
          * S is descending within 1e-6.
        ``S_err`` is printed but does not affect the result.
    """
    B, M, N = A.shape
    A_f = A.float()
    k = S.shape[-1]  # number of singular triplets supplied (N for a thin SVD)

    # Reconstruction: A ~= U @ diag(S) @ Vh (scaling U's columns == U @ diag(S)).
    recon = torch.bmm(U * S.unsqueeze(1), Vh)
    recon_err = (A_f - recon).abs().max().item()

    # Orthogonality: U^T U ~= I_k. The identity is sized from the supplied rank
    # rather than N, so rank-k factorizations don't hit a shape mismatch.
    UtU = torch.bmm(U.transpose(1, 2), U)
    eye = torch.eye(k, device=A.device).expand(B, -1, -1)
    orth_err = (UtU - eye).abs().max().item()

    # Singular values should be non-negative and descending.
    s_min = S.min().item()
    s_sorted = bool((S[:, :-1] >= S[:, 1:] - 1e-6).all().item())

    # Reference comparison (top-k reference values when a rank-k S is given).
    U_ref, S_ref, Vh_ref = torch.linalg.svd(A_f, full_matrices=False)
    s_err = (S - S_ref[:, :k]).abs().max().item()
    recon_ref = (A_f - torch.bmm(U_ref * S_ref.unsqueeze(1), Vh_ref)).abs().max().item()

    tag = f"[{label}] " if label else ""
    # Descending order is part of the SVD convention checked above, so it
    # must count toward pass/fail rather than only being printed.
    passed = (recon_err < max(recon_ref * 3, 1e-3)
              and orth_err < 1e-2
              and s_min >= -1e-6
              and s_sorted)
    status = "PASS" if passed else "FAIL"
    print(f" {tag}N={N:>3}: S_err={s_err:.2e} recon={recon_err:.2e} (ref={recon_ref:.2e})"
          f" orth={orth_err:.2e} desc={s_sorted} [{status}]")
    return passed
def run_validation(B=64, M=1024):
    """Validate every SVD path, then smoke-test the Procrustes helpers.

    SVD section: for each N in a fixed list (skipping N > M), draws a random
    (B, M, N) float32 CUDA tensor and runs ``validate_svd`` on:
      * the 'auto' dispatch, for every N;
      * the explicit 'triton' path, for N <= 3;
      * 'gram_eigh', for N > 3;
      * ``newton_svd``, for N >= 8.

    Procrustes section: prints mean cosine similarity before and after
    ``batched_procrustes`` on correlated random source/target sets. It then
    checks the unbatched call and ``batched_procrustes_align_pair`` with
    ``assert`` on output shape.

    Returns:
        bool: True iff every ``validate_svd`` call in the SVD section passed.
        The Procrustes section does not affect the return value.
    """
    rule = '=' * 70
    print(f"\n{rule}")
    print(f" CORRECTNESS VALIDATION (B={B}, M={M})")
    print(rule)

    outcomes = []
    for width in (2, 3, 4, 5, 6, 8, 10, 16, 32, 48, 64, 96, 128):
        if width > M:
            continue
        X = torch.randn(B, M, width, device="cuda", dtype=torch.float32)

        # (label, thunk) pairs, executed and validated in this order.
        checks = [("auto", lambda: batched_svd(X, method='auto'))]
        if width <= 3:
            checks.append(("triton", lambda: batched_svd(X, method='triton')))
        else:
            checks.append(("gram", lambda: batched_svd(X, method='gram_eigh')))
        if width >= 8:
            checks.append(("newton", lambda: newton_svd(X)))

        for tag, run in checks:
            U, S, Vh = run()
            outcomes.append(validate_svd(X, U, S, Vh, label=tag))

    all_pass = all(outcomes)
    print(f"\n {'ALL PASSED' if all_pass else 'SOME FAILURES'}")

    # ── Procrustes alignment validation ──
    print(f"\n{rule}")
    print(f" PROCRUSTES ALIGNMENT VALIDATION")
    print(rule)
    for n_feat in (16, 32, 48, 64, 128):
        n_samp = 2000
        # Source and target share a common component plus independent noise.
        shared = torch.randn(n_samp, n_feat, device='cuda')
        src = shared + 0.3 * torch.randn(n_samp, n_feat, device='cuda')
        tgt = shared + 0.3 * torch.randn(n_samp, n_feat, device='cuda')
        r = min(24, n_feat - 1)
        out, meta = batched_procrustes(
            src.unsqueeze(0), tgt.unsqueeze(0),
            rank=r, whiten=True)
        out = out.squeeze(0)
        before = F.cosine_similarity(src, tgt, dim=-1).mean().item()
        after = F.cosine_similarity(out, tgt, dim=-1).mean().item()
        verdict = 'IMPROVED' if after > before else 'WORSE'
        print(f" N={n_feat:>3} rank={r:>3} method={meta['method']:>8}:"
              f" cos {before:.4f} β†’ {after:.4f}"
              f" {verdict}")

    # Unbatched (2-D) interface.
    source_ub = torch.randn(1000, 48, device='cuda')
    target_ub = torch.randn(1000, 48, device='cuda') * 0.5 + source_ub * 0.5
    aligned_ub, info_ub = batched_procrustes(source_ub, target_ub, rank=24)
    assert aligned_ub.shape == source_ub.shape, f"Shape mismatch: {aligned_ub.shape} vs {source_ub.shape}"
    print(f" Unbatched API: shape {aligned_ub.shape} βœ“ method={info_ub['method']}")

    # batched_procrustes_align_pair.
    aligned_pair, info_pair = batched_procrustes_align_pair(
        source_ub, target_ub, rank=24, n_align=500)
    assert aligned_pair.shape == source_ub.shape
    cos_pair = F.cosine_similarity(aligned_pair, target_ub, dim=-1).mean().item()
    print(f" Align-pair API: cos={cos_pair:.4f} method={info_pair['method']}")
    print(f" PROCRUSTES VALIDATION COMPLETE")
    return all_pass
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ BENCHMARKING ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
def _cuda_timer(fn, warmup=20, iters=80):
    """Benchmark ``fn`` on the GPU using CUDA events.

    ``fn`` is called ``warmup`` times untimed, and the device is synchronized.
    Each of the next ``iters`` calls is then bracketed by a start/end timing
    event pair. After a final synchronize, the per-call elapsed times (in
    milliseconds) are aggregated.

    Returns:
        tuple[float, float, float]: (mean_ms, std_ms, median_ms).
    """
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()

    event_pairs = [
        (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
        for _ in range(iters)
    ]
    for begin, finish in event_pairs:
        begin.record()
        fn()
        finish.record()
    # All recorded events must have completed before elapsed_time is queried.
    torch.cuda.synchronize()

    samples = torch.tensor([begin.elapsed_time(finish) for begin, finish in event_pairs])
    return samples.mean().item(), samples.std().item(), samples.median().item()
def profile_n_sweep(B=512, M=1024):
    """Time every applicable SVD path for N from 2 to 128 at fixed B and M.

    Paths and the N for which each runs:
        triton  batched_svd(method='triton')            N <= 3
        torch   torch.linalg.svd                        always
        gram    gram_eigh_svd                           always
        newton  newton_svd                              N >= 8
        proj24  projected_svd, rank min(24, N - 1)      N >= 32
        proj16  projected_svd, rank min(16, N - 1)      N >= 24
    Paths that were not run print as a dash. Each row also shows the fastest
    measured path and its speedup over torch.linalg.svd.

    Returns:
        list[dict]: one row per N with 'N', 'B', 'M', 'torch_ms', 'gram_ms',
        'best' and 'speedup_vs_torch', plus '<path>_ms' for each optional
        path that ran.
    """
    device_name = torch.cuda.get_device_name(0)
    line = '=' * 110
    print(f"\n{line}")
    print(f" N-DIMENSION SWEEP β€” {device_name}")
    print(f" B={B}, M={M}")
    print(line)
    print(f" {'N':>4} {'Triton':>10} {'Gram':>10} {'Newton':>10}"
          f" {'Proj→24':>10} {'Proj→16':>10} {'Torch':>10} {'Best':>8} {'Speedup':>8}")
    print(f" {'─'*106}")

    not_run = float('nan')

    def _fmt(ms):
        # NaN is the "path not run for this N" sentinel.
        if math.isnan(ms):
            return f"{'β€”':>10}"
        return f"{ms:>8.3f}ms"

    def _mean_ms(enabled, thunk):
        # Mean time in ms when the path applies to this N, else the NaN sentinel.
        return _cuda_timer(thunk)[0] if enabled else not_run

    sizes = (2, 3, 4, 5, 6, 7, 8, 10, 12, 16, 20, 24, 32, 48, 64, 96, 128)
    results = []
    for N in (n for n in sizes if n <= M):
        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)

        triton_ms = _mean_ms(N <= 3, lambda: batched_svd(A, method='triton'))
        torch_ms = _mean_ms(True, lambda: torch.linalg.svd(A, full_matrices=False))
        gram_ms = _mean_ms(True, lambda: gram_eigh_svd(A))
        newton_ms = _mean_ms(N >= 8, lambda: newton_svd(A))
        proj24_ms = _mean_ms(N >= 32, lambda: projected_svd(A, target_rank=min(24, N-1)))
        proj16_ms = _mean_ms(N >= 24, lambda: projected_svd(A, target_rank=min(16, N-1)))

        optional = {'triton': triton_ms, 'newton': newton_ms,
                    'proj24': proj24_ms, 'proj16': proj16_ms}
        measured = {name: ms for name, ms in optional.items() if not math.isnan(ms)}

        # Candidate order (torch, gram, then optional paths) decides ties in min().
        candidates = {'torch': torch_ms, 'gram': gram_ms, **measured}
        best = min(candidates, key=candidates.get)
        speedup = torch_ms / (candidates[best] + 1e-9)

        print(f" {N:>4} {_fmt(triton_ms)} {_fmt(gram_ms)} {_fmt(newton_ms)}"
              f" {_fmt(proj24_ms)} {_fmt(proj16_ms)} {_fmt(torch_ms)}"
              f" {best:>8} {speedup:>7.1f}x")

        row = {'N': N, 'B': B, 'M': M, 'torch_ms': round(torch_ms, 4),
               'gram_ms': round(gram_ms, 4), 'best': best,
               'speedup_vs_torch': round(speedup, 3)}
        row.update({f"{name}_ms": round(ms, 4) for name, ms in measured.items()})
        results.append(row)

        del A
        torch.cuda.empty_cache()
    return results
def profile_projection_quality(B=256, M=1024):
    """Report how well rank-k projected SVD approximates the full SVD.

    For each (N, ranks) config with N <= M:
      * draws one random (B, M, N) float32 CUDA tensor;
      * times ``gram_eigh_svd`` once (10 warmup / 40 timed iterations);
      * for each k < N in ``ranks``, prints the metrics returned by
        ``projected_svd_quality`` ('energy_ratio', 'recon_proj',
        'recon_trunc', 's_rel_err', 'subspace_cos'), plus the timing of
        ``projected_svd`` and its speedup over the full decomposition.

    The closing summary picks, per N, the smallest k with
    energy_ratio >= 0.99 and subspace_cos >= 0.99. If no k qualifies, it
    reports the k with the highest energy_ratio.

    Returns:
        list[dict]: one row per (N, k) with 'N', 'k', 'B', 'M', the rounded
        quality metrics, 'proj_ms' and 'full_ms'.
    """
    bar = '=' * 100
    print(f"\n{bar}")
    print(f" PROJECTION QUALITY ANALYSIS β€” B={B}, M={M}")
    print(f" Question: can rank-k SVD approximate rank-N SVD?")
    print(bar)
    configs = [
        # (N, [target_ranks to test])
        (32, [8, 12, 16, 24]),
        (48, [8, 12, 16, 24, 32]),
        (64, [8, 12, 16, 24, 32, 48]),
        (96, [8, 16, 24, 32, 48, 64]),
        (128, [8, 16, 24, 32, 48, 64, 96]),
    ]
    # Metric key -> number of decimals kept in the returned rows.
    precision = (('energy_ratio', 6), ('recon_proj', 8), ('recon_trunc', 8),
                 ('s_rel_err', 6), ('subspace_cos', 6))
    all_results = []
    for N, ranks in configs:
        if N > M:
            continue
        print(f"\n N={N}:")
        print(f" {'k':>5} {'Energy%':>8} {'Recon_proj':>11} {'Recon_trunc':>12}"
              f" {'S_rel_err':>10} {'Subspace':>9} {'Proj ms':>10} {'Full ms':>10} {'Speedup':>8}")
        print(f" {'─'*96}")
        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)
        # The full decomposition is timed once per N and reused for every k.
        full_ms = _cuda_timer(lambda: gram_eigh_svd(A), warmup=10, iters=40)[0]
        for k in (r for r in ranks if r < N):
            q = projected_svd_quality(A, target_rank=k)
            proj_ms = _cuda_timer(
                lambda: projected_svd(A, target_rank=k), warmup=10, iters=40)[0]
            speedup = full_ms / (proj_ms + 1e-9)
            print(f" {k:>5} {q['energy_ratio']*100:>7.2f}% {q['recon_proj']:>11.2e}"
                  f" {q['recon_trunc']:>12.2e} {q['s_rel_err']:>10.4f}"
                  f" {q['subspace_cos']:>9.4f} {proj_ms:>8.3f}ms {full_ms:>8.3f}ms"
                  f" {speedup:>7.1f}x")
            all_results.append({
                'N': N, 'k': k, 'B': B, 'M': M,
                **{key: round(q[key], digits) for key, digits in precision},
                'proj_ms': round(proj_ms, 4),
                'full_ms': round(full_ms, 4),
            })
        del A
        torch.cuda.empty_cache()

    # Summary table
    print(f"\n {'─'*70}")
    print(f" SUMMARY: Recommended target_rank per N")
    print(f" (β‰₯99% energy, β‰₯0.99 subspace cos, best speedup)")
    print(f" {'─'*70}")
    for N, _ in configs:
        rows = [r for r in all_results if r['N'] == N]
        if not rows:
            continue
        good = [r for r in rows
                if r['energy_ratio'] >= 0.99 and r['subspace_cos'] >= 0.99]
        if good:
            best = min(good, key=lambda r: r['k'])
            print(f" N={N:>3}: k={best['k']:>3} β†’ {best['energy_ratio']*100:.1f}% energy,"
                  f" subspace={best['subspace_cos']:.4f},"
                  f" {best['full_ms']/best['proj_ms']:.1f}x speedup")
        else:
            best = max(rows, key=lambda r: r['energy_ratio'])
            print(f" N={N:>3}: best k={best['k']:>3} β†’ {best['energy_ratio']*100:.1f}% energy,"
                  f" subspace={best['subspace_cos']:.4f} (below 99% threshold)")
    return all_results
def profile_batch_sweep(N=3, M=1024):
    """Time the auto SVD path against torch.linalg.svd over increasing batch sizes.

    For each B in a fixed list, builds a random (B, M, N) float32 CUDA tensor
    and times ``batched_svd(..., method='auto')`` and ``torch.linalg.svd``
    with ``_cuda_timer`` defaults. Each row prints the mean times, the
    speedup and matrices per second for the auto path.

    The sweep stops at the first batch size that runs out of GPU memory,
    whether that happens while allocating the input or while timing. Other
    RuntimeErrors propagate.

    Returns:
        list[dict]: one row per completed batch size with keys
        'B', 'N', 'M', 'auto_ms', 'torch_ms', 'speedup'.
    """
    print(f"\n{'='*70}")
    print(f" BATCH SWEEP β€” N={N}, M={M}")
    print(f"{'='*70}")
    print(f" {'B':>6} {'Auto ms':>10} {'Torch ms':>10} {'Speedup':>8} {'img/s':>12}")
    print(f" {'─'*52}")
    batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
    results = []
    for B in batch_sizes:
        A = None
        out_of_memory = False
        try:
            A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)
            auto_mean, _, _ = _cuda_timer(lambda: batched_svd(A, method='auto'))
            torch_mean, _, _ = _cuda_timer(
                lambda: torch.linalg.svd(A, full_matrices=False))
        except RuntimeError as err:
            # torch's CUDA OOM error subclasses RuntimeError and says
            # "out of memory"; anything else is a genuine failure.
            if 'out of memory' not in str(err).lower():
                raise
            out_of_memory = True
        if out_of_memory:
            # Cleanup happens outside the except block so the exception's
            # traceback no longer pins intermediate tensors.
            print(f" {B:>6} OOM")
            del A
            torch.cuda.empty_cache()
            break
        speedup = torch_mean / (auto_mean + 1e-9)
        ips = B / (auto_mean / 1000)
        # img/s is printed 12 wide to line up with its header.
        print(f" {B:>6} {auto_mean:>8.3f}ms {torch_mean:>8.3f}ms {speedup:>7.2f}x {ips:>12,.0f}")
        results.append({'B': B, 'N': N, 'M': M,
                        'auto_ms': round(auto_mean, 4), 'torch_ms': round(torch_mean, 4),
                        'speedup': round(speedup, 3)})
        del A
        torch.cuda.empty_cache()
    return results
def profile_spatial_sweep(N=3, B=512):
    """Time the auto SVD path against torch.linalg.svd while varying M.

    For each M in a fixed list, builds a random (B, M, N) float32 CUDA tensor
    and times ``batched_svd(..., method='auto')`` and ``torch.linalg.svd``
    with ``_cuda_timer`` defaults, printing one row per M. When M is a
    perfect square, the row also shows it as an HxW size.

    Returns:
        list[dict]: one row per M with keys 'M', 'N', 'B', 'auto_ms',
        'torch_ms', 'speedup'.
    """
    divider = '=' * 70
    print(f"\n{divider}")
    print(f" SPATIAL SWEEP β€” N={N}, B={B}")
    print(divider)
    print(f" {'M':>6} {'~HxW':>8} {'Auto ms':>10} {'Torch ms':>10} {'Speedup':>8}")
    print(f" {'─'*48}")
    rows = []
    for M in (16, 64, 256, 512, 1024, 2048, 4096, 8192, 16384):
        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)
        side = int(M ** 0.5)
        shape_tag = f"{side}Γ—{side}" if side * side == M else f"{M}"
        t_auto = _cuda_timer(lambda: batched_svd(A, method='auto'))[0]
        t_ref = _cuda_timer(lambda: torch.linalg.svd(A, full_matrices=False))[0]
        ratio = t_ref / (t_auto + 1e-9)
        print(f" {M:>6} {shape_tag:>8} {t_auto:>8.3f}ms {t_ref:>8.3f}ms {ratio:>7.2f}x")
        rows.append(dict(M=M, N=N, B=B,
                         auto_ms=round(t_auto, 4), torch_ms=round(t_ref, 4),
                         speedup=round(ratio, 3)))
        del A
        torch.cuda.empty_cache()
    return rows
def profile_crossover_detail(M=1024, B=512):
    """Fine-grained N sweep comparing gram_eigh_svd with torch.linalg.svd.

    For N = 2..64 (stopping once N exceeds M), times both on a random
    (B, M, N) float32 CUDA tensor (10 warmup / 40 timed iterations). Each
    row prints the faster method and the margin as a percentage of the
    faster time.

    Returns:
        list[dict]: one row per N with keys 'N', 'B', 'M', 'gram_ms',
        'torch_ms', 'winner', 'margin_pct'. Like the other profilers, the
        measurements are returned so they can be saved.
    """
    print(f"\n{'='*70}")
    print(f" CROSSOVER DETAIL β€” B={B}, M={M}")
    print(f"{'='*70}")
    print(f" {'N':>4} {'Gram ms':>10} {'Torch ms':>10} {'Winner':>8} {'Margin':>8}")
    print(f" {'─'*46}")
    results = []
    for N in range(2, 65):
        if N > M:
            break
        A = torch.randn(B, M, N, device="cuda", dtype=torch.float32)
        gram_mean, _, _ = _cuda_timer(lambda: gram_eigh_svd(A), warmup=10, iters=40)
        torch_mean, _, _ = _cuda_timer(
            lambda: torch.linalg.svd(A, full_matrices=False), warmup=10, iters=40)
        winner = "gram" if gram_mean < torch_mean else "torch"
        # Margin is relative to the faster of the two.
        margin = abs(gram_mean - torch_mean) / min(gram_mean, torch_mean) * 100
        print(f" {N:>4} {gram_mean:>8.3f}ms {torch_mean:>8.3f}ms {winner:>8} {margin:>6.1f}%")
        results.append({'N': N, 'B': B, 'M': M,
                        'gram_ms': round(gram_mean, 4), 'torch_ms': round(torch_mean, 4),
                        'winner': winner, 'margin_pct': round(margin, 3)})
        del A
        torch.cuda.empty_cache()
    return results
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ MAIN ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
def main():
    """Run the default profiling suite and save the results as JSON.

    Runs, in order:
      1. correctness validation;
      2. the Procrustes alignment quality table;
      3. the projection quality table;
      4. the N-dimension timing sweep.

    The batch, spatial and crossover sweeps are disabled by default (see the
    commented-out calls below). The collected results are written to
    ``svd_general_profile.json`` in the current working directory.

    Raises:
        RuntimeError: CUDA is not available.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would otherwise surface as an obscure failure in get_device_name.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required")
    device_name = torch.cuda.get_device_name(0)
    print(f"{'='*80}")
    print(f" Generalized Batched Thin SVD β€” Profiling Suite")
    print(f" Device: {device_name}")
    print(f"{'='*80}")
    # Correctness first
    run_validation(B=64, M=1024)
    # Procrustes alignment quality: does rank-k Procrustes produce the same
    # rotation as rank-N?
    procrustes_results = profile_procrustes_quality()
    # Projection quality analysis (energy / reconstruction perspective)
    proj_results = profile_projection_quality(B=256, M=1024)
    # N dimension sweep (timing comparison)
    n_results = profile_n_sweep(B=512, M=1024)
    # Batch/spatial/crossover sweeps are skipped by default; uncomment if needed.
    batch_results = {}
    spatial_results = {}
    # for N in [3, 8, 32, 64]:
    #     batch_results[N] = profile_batch_sweep(N=N, M=1024)
    # for N in [3, 16, 48]:
    #     spatial_results[N] = profile_spatial_sweep(N=N, B=512)
    # profile_crossover_detail(M=1024, B=512)
    # Summary
    print(f"\n{'='*80}")
    print(f" SUMMARY")
    print(f"{'='*80}")
    print(f"\n Strategy by N:")
    print(f" N=2: Fused Triton (closed-form Jacobi rotation)")
    print(f" N=3: Fused Triton (cyclic Jacobi in registers)")
    print(f" N=4-32: Gram + eigh (bmm + cuSOLVER eigh) β€” sub-ms")
    print(f" N=48+: Projected SVD (N→k, cheap SVD, lift back) — check quality table")
    print(f"")
    print(f" Standalone utilities:")
    print(f" newton_schulz_invsqrt(G) β€” batched G^{{-1/2}} via pure bmm")
    print(f" projected_svd(A, target_rank=k) β€” rank-k approximate SVD")
    print(f" projected_svd_quality(A, target_rank) β€” measure approximation quality")
    print(f"")
    print(f" Key question answered: energy_ratio and subspace_cos in quality table")
    # Save results
    report = {
        'device': device_name,
        'procrustes_quality': procrustes_results,
        'projection_quality': proj_results,
        'n_sweep': n_results,
        'batch_sweeps': {str(k): v for k, v in batch_results.items()},
        'spatial_sweeps': {str(k): v for k, v in spatial_results.items()},
    }
    with open('svd_general_profile.json', 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)
    print(f"\n Results saved to svd_general_profile.json")
    print(f"{'='*80}")
if __name__ == "__main__":
main()