Fused Batched Thin SVD, Part II: Extending the Jacobi Pipeline to N=6 with Configurable Convergence
A continuation of Fused Batched Thin SVD: Engineering a 5000× Speedup with Triton Kernels, now shipped inAbstractEyes/geolip-core.The first article documented the N=2 and N=3 fused kernels. The Gram-Eigh hybrid handled N=4 through N=32, with a sub-millisecond floor that was "good enough" for most uses but still ~10× slower than the fused path. This follow-up walks through the N=4, N=5, and N=6 fused kernels that close that gap, the configurable Jacobi convergence parameter that makes them numerically safe in fp32, the new fp32/fp64 dtype contract, and the backend toggle infrastructure that lets you pin a specific dispatch path for testing or fallback validation.
Recap
Thin SVD on (B, M, N) matrices with M ≫ N reduces to:
- Form the Gram matrix
G = AᵀAof shape(N, N), - Eigendecompose
Gto getS² = eigenvalues,V = eigenvectors, - Recover
U = AV / S.
For small N this is bandwidth-bound on the A reads, with an N×N eigensolve that fits in scalar registers. cuSOLVER doesn't know this, so it dispatches a full bidiagonalization pipeline and pays five-figure overhead per batch element. The fused Triton kernel for N=3 ran the entire pipeline — Gram accumulation, cyclic Jacobi sweep, sort, U recovery — in a single program instance per batch, hitting 0.022ms at B=512, M=1024 versus 117.5ms for torch.linalg.svd. That's the 5,488× number from Part I.
The catch was the upper bound. At N=4 the pipeline fell back to torch.bmm + torch.linalg.eigh, which works fine but launches three kernels and serializes through cuSOLVER's eigh. At N=4, M=1024, B=512 that's roughly ~250µs — perfectly usable, but not in the same league as the fused path.
This article is about pushing the fused boundary out to N=6, why N=6 specifically required revisiting the Jacobi convergence schedule, and the dtype-aware backend rebuild that came with it.
1. The Pair-Count Cliff at N=4 and Beyond
The cyclic Jacobi sweep zeros each off-diagonal g_pq of the Gram matrix in turn. For an N×N symmetric matrix, one full sweep applies rotations to all N(N-1)/2 upper-triangle pairs. After 6 sweeps the matrix is diagonal to machine precision for N=3 (3 pairs/sweep, 18 total rotations).
Scaling that up:
| N | Pairs/sweep | Rotations at 6 sweeps | Rotations at 12 sweeps |
|---|---|---|---|
| 2 | 1 | 6 (1 needed) | — |
| 3 | 3 | 18 | — |
| 4 | 6 | 36 | 72 |
| 5 | 10 | 60 | 120 |
| 6 | 15 | 90 | 180 |
Each rotation is six FMAs on the diagonal, two on each off-diagonal that the rotation touches, and a column-pair update on V. The arithmetic itself is cheap — every operand is in a scalar register. The problem is two-fold: register pressure for V and accumulated round-off in fp32.
Register pressure
The N=3 kernel holds 6 unique Gram entries plus 9 V entries — 15 scalar slots. N=6 needs 21 unique Gram entries plus 36 V entries — 57 scalar slots, plus the per-tile load registers and Jacobi temporaries. That's well within Blackwell's per-thread register file (255 scalars), but it's no longer trivial. The N=6 kernel was the first one where I had to think about register coloring across the rotation pattern to avoid spilling.
Accumulated round-off
This is the subtle one. Each Givens rotation is exactly orthogonal in real arithmetic, but in fp32 it's exactly orthogonal to ~6 decimal digits. After 90 rotations the accumulated error in VᵀV − I drifts to ~4×10⁻³ on a Blackwell — above the 1e-3 orthogonality tolerance the validation suite checks for.
For N=2,3,4,5 with 6 sweeps this never showed up. At N=6 it became a hard floor. The fix wasn't "use a smaller tolerance" — the fix was to give N=6 more sweeps so the post-convergence matrix is closer to truly diagonal, which means the accumulated rotation closer to truly orthogonal. Twelve sweeps (180 rotations) puts orth error well under 1e-3 with negligible runtime cost relative to the eigh fallback.
The convergence sweep test built into kernel.py's validation harness measures reconstruction error on a controlled-spectrum fixture (A = U_o · diag(S) · V_oᵀ with S ∈ [0.5, 2.0], condition number ≤ 4). On Blackwell:
[convergence sweep — backend.resolve_svd_nN]
fp32 N=3: it=2:9.66e-09 it=4:9.82e-09 it=6:9.82e-09 it=8:9.82e-09 it=12:9.82e-09
fp32 N=4: it=2:1.57e-08 it=4:1.86e-08 it=6:1.86e-08 it=8:1.86e-08 it=12:1.86e-08
fp32 N=5: it=2:1.66e-08 it=4:2.40e-08 it=6:2.40e-08 it=8:2.40e-08 it=12:2.40e-08
fp32 N=6: it=2:1.98e-08 it=4:2.94e-08 it=6:2.95e-08 it=8:2.95e-08 it=12:2.95e-08
fp64 N=6: it=2:3.13e-14 it=4:3.12e-14 it=6:3.12e-14 it=8:3.12e-14 it=12:3.12e-14
Reconstruction error plateaus at iter=2 across the board on this fixture — fp32 hits its noise floor (10⁻⁸) and fp64 hits its noise floor (10⁻¹⁴) immediately. So why ship N=6 with jacobi_iters=12?
Because the convergence sweep doesn't tell the whole story. The 4×10⁻³ orthogonality drift shows up on the random Gaussian inputs that occur in real training loops, not on the controlled-spectrum fixture. A randn(B, M, 6) matrix has condition numbers that occasionally land in the 50–500 range, and at fp32 the accumulated round-off across 90 rotations (15 pairs × 6 sweeps) drifts ‖VᵀV − I‖ past the 1×10⁻³ tolerance the validation harness requires. Twelve sweeps (180 rotations) doesn't reduce per-rotation error — it drives the off-diagonals further toward zero so the terminal V is closer to truly orthogonal even with the same per-step round-off.
The convergence sweep validates "well-conditioned inputs converge fast." The 12-iter default protects against the ill-conditioned case the sweep deliberately avoids.
For N=2 through N=5 the rotation count stays low enough that 6 sweeps stay safely under 1×10⁻³ orth even on random inputs. The default for those kernels stays at 6. If you know your inputs are well-conditioned (Gram precomputed, bounded spectrum, etc.) you can drop N=6 to jacobi_iters=6 and reclaim ~40% of the kernel time.
2. The N=4, N=5, N=6 Kernel Structure
Each kernel follows the same three-stage pattern as N=3, with one wrinkle for the U recovery step at higher N.
Stage 1 — Tiled Gram accumulation
for block_start in range(0, M, BLOCK_M):
a0 = tl.load(A_ptr + base + row_idx*N + 0, ...)
a1 = tl.load(A_ptr + base + row_idx*N + 1, ...)
# ... a2..a5
g00 += tl.sum(a0*a0); g01 += tl.sum(a0*a1); ...
N(N+1)/2 scalar accumulators. For N=6 that's 21 scalars; the entire upper triangle stays in registers across all tiles.
Stage 2 — Cyclic Jacobi over all pairs
For each Jacobi sweep, walk pairs (0,1), (0,2), ..., (N-2, N-1) in lexicographic order:
# pair (p,q) — for example (0,1) at N=6:
off_diag = g01; diag_diff = g11 - g00
tau = diag_diff / (2 * off_diag) # guarded against off_diag == 0
t = sign(tau) / (|tau| + sqrt(1 + tau*tau))
c = 1 / sqrt(1 + t*t); s = t*c
# Update G — only the row/col touched by (p,q) changes:
ng00 = c*c*g00 - 2*s*c*g01 + s*s*g11
ng11 = s*s*g00 + 2*s*c*g01 + c*c*g11
# Plus the cross-row updates: ng02, ng12, ng03, ng13, ng04, ng14, ng05, ng15
g01 = 0 # by construction
# Update V — columns p and q rotate together:
nv00 = c*v00 - s*v01; nv01 = s*v00 + c*v01
nv10 = c*v10 - s*v11; nv11 = s*v10 + c*v11
# ... rows 2 through 5
The total per-pair work for N=6 is roughly: 4 trig-replacement ops to compute (c, s), 8 FMAs to update the two diagonal entries, 4 FMAs × 4 cross-rows = 16 FMAs to update the off-diagonals the rotation touches, and 2 FMAs × 6 rows = 12 FMAs to update V. About 40 FMAs per pair × 15 pairs × 12 sweeps = ~7,200 FMAs per batch element. On a Blackwell SM that's a couple of microseconds — the wall-clock is still dominated by the A read.
Stage 2b — Selection sort with V column swaps
For N=6 there are C(6,2) = 15 compare-and-swap operations to fully sort. Each swap permutes a V column (6 scalar swaps). Done with tl.where to keep the kernel branchless.
Stage 3 — U recovery
for block_start in range(0, M, BLOCK_M):
a0..a5 = tl.load(...)
u0 = (a0*v00 + a1*v10 + a2*v20 + a3*v30 + a4*v40 + a5*v50) * inv_s0
# ... u1..u5
tl.store(...)
Same tiling as Stage 1. V entries are in registers from Stage 2. For N=6 each output row is 6 dot products of length 6 — 36 FMAs per output tile element, with the V matrix entirely in registers. No shared memory at any point in the kernel.
3. The fp32 / fp64 Contract
The Part I kernels were fp32 only. The new code paths take a DTYPE: tl.constexpr so the same kernel source compiles for both fp32 and fp64. The Python wrapper picks the dtype from the input tensor:
def _triton_dtype(A):
if A.dtype == torch.float32: return tl.float32
if A.dtype == torch.float64: return tl.float64
raise TypeError(f"Triton SVD kernels support fp32/fp64 only, got {A.dtype}")
For unsupported dtypes (fp16, bf16, complex), the wrapper routes to a _torch_svd_fallback that upcasts to fp32, runs torch.linalg.svd, and casts the results back to preserve the output-dtype-matches-input contract:
def batched_svd6(A, block_m=128, jacobi_iters=12):
if not HAS_TRITON or not A.is_cuda or A.dtype not in (torch.float32, torch.float64):
return _torch_svd_fallback(A)
# ... fused path
Why bother with fp64 in a kernel that's bandwidth-bound? Two reasons:
- The orthogonality drift problem disappears. Where fp32 N=6 needs 12 sweeps, fp64 N=6 hits machine precision at 6. That said,
jacobi_iters=12is set as the default for safety; if you're confident in your input conditioning you can drop it. - Downstream pipelines that need fp64 throughout (Cayley-Menger volume validation, characteristic polynomial coefficients, Gram-CV diagnostics) can now run the SVD step at the same precision instead of casting around it.
The dtype contract is verified in the validation harness:
[N=4 triton fp64 dtype]
PASS: U.dtype == torch.float64 and S.dtype == torch.float64 and Vh.dtype == torch.float64
[N=4 fp16 fallback dtype]
PASS: U.dtype == torch.float16 and S.dtype == torch.float16 and Vh.dtype == torch.float16
[N=4 fp16 fallback recon]
PASS: err=2.18e-04
4. Backend Toggles for Test and Dispatch Control
geolip-core exposes a small backend module that lets you pin which path the dispatcher takes:
from geolip_core.linalg._backend import backend
backend.use_triton = False # Force the Gram-Eigh path even when Triton is available
backend.use_fl_eigh = False # Force cuSOLVER eigh even when FL is available
backend.status()
# geolip_core.linalg backend:
# CUDA: yes
# Triton: 3.6.0 (disabled)
# FL eigh: disabled
# GPU: NVIDIA RTX PRO 6000 Blackwell Server Edition
This was added because the original dispatcher made it hard to verify the fallback path was correct. With both use_triton and use_fl_eigh off, every SVD call routes through gram_eigh_svd → torch.linalg.eigh → cuSOLVER, which is the path you want to validate when shipping to a machine without Triton. The test harness exercises exactly this:
[backend toggle — Triton OFF, FL OFF]
PASS: off/fp32 4x4 recon err=1.91e-06
PASS: off/fp32 5x5 recon err=2.31e-06
PASS: off/fp64 6x6 recon err=8.43e-15
...
The toggle is also useful when you suspect a numerical regression in the fused kernel — flip Triton off, rerun your training loop, and bisect against the cuSOLVER result.
5. Updated Dispatch Map
The auto-dispatcher now routes:
N == 2 → batched_svd2 (fused, single Jacobi rotation)
N == 3 → batched_svd3 (fused, 6 Jacobi sweeps)
N == 4 → batched_svd4 (fused, 6 Jacobi sweeps)
N == 5 → batched_svd5 (fused, 6 Jacobi sweeps)
N == 6 → batched_svd6 (fused, 12 Jacobi sweeps)
7 ≤ N ≤ 32 → gram_eigh_svd (Gram + cuSOLVER eigh)
N ≥ 33 → gram_eigh_svd (still works; large-N users should consider
rank-projected SVD from Part I)
The fused boundary is N=6 because N=7 would need 21 pairs/sweep and 49 V entries plus 28 Gram entries — that's 77 scalar slots before counting tile loads and rotation temporaries. It fits but starts crowding the 255-scalar register file once the tile prefetch and rotation intermediates are accounted for, and the per-pair cost overtakes the cuBLAS bmm of the Gram path. Empirically, N=7 was no longer a clear win on Blackwell, so the cutoff stayed at 6.
6. Benchmarks
Measured with python -m geolip_core.utils.kernel on an NVIDIA RTX PRO 6000 Blackwell, B=512, M=1024, after 437/437 validation checks pass.
fp32
| Shape | torch.linalg.svd |
batched_svd (auto) |
Triton fused | Speedup |
|---|---|---|---|---|
| 1024×2 | 79.83ms | 16.0µs | 15.8µs | 5,038× |
| 1024×3 | 121.90ms | 16.1µs | 16.1µs | 7,587× |
| 1024×4 | 128.28ms | 22.6µs | 22.6µs | 5,675× |
| 1024×5 | 147.76ms | 30.8µs | 30.8µs | 4,796× |
| 1024×6 | 156.36ms | 57.4µs | 57.4µs | 2,724× |
The auto and triton columns are within noise of each other — confirming the dispatcher correctly routes N=2..6 through the fused path. cuSOLVER's per-element overhead grows linearly with N (and with B; see batch scaling below).
fp64
The fp64 numbers are new — Part I was fp32 only. Note that the RTX PRO 6000 Blackwell has consumer-class fp64 throughput (~1/64 of fp32), so the absolute timings climb fast as the kernel becomes more compute-bound.
| Shape | torch.linalg.svd |
Triton fused | Speedup |
|---|---|---|---|
| 1024×2 | 292.32ms | 37.0µs | 7,898× |
| 1024×3 | 506.73ms | 190.6µs | 2,660× |
| 1024×4 | 546.93ms | 376.3µs | 1,453× |
| 1024×5 | 631.52ms | 647.4µs | 975× |
| 1024×6 | 693.55ms | 1.94ms | 357× |
Three observations:
- The fused kernel is faster than
torch.linalg.svdat every size in fp64, but the margin shrinks as N grows. At N=6 the kernel is still 357× faster, but the absolute ratio of fp64-to-fp32 cost is ~34× — much steeper than fp32→fp64's 2× memory cost would predict. That extra factor is the per-element FMA throttling on consumer Blackwell. - fp64 N=2 is interesting: 37µs versus 16µs in fp32 — only a 2.3× slowdown. At N=2 the kernel is purely bandwidth-bound and fp64 just doubles the bytes per element.
- fp64 N=6 hits 1.94ms — at this point cuSOLVER's 693ms is still 357× slower, but if your downstream needs fp64 and you have many N=6 calls per forward pass, you should profile carefully. For most uses, fp32 + a verification call in fp64 every K steps is the right tradeoff.
Why does fp32 N=6 cost ~1.86× N=5?
N=5 with 6 sweeps does 60 rotations. N=6 with 12 sweeps does 180 rotations — a 3× increase in rotation work. But the wall time only grows 1.86×. The remaining cost is the bandwidth-bound A passes, which scale only ~1.2× from N=5 to N=6 (one extra column to load and recover). The kernel sits in a regime where compute and bandwidth are roughly balanced at N=6 — neither dominates.
If you change batched_svd6(A, jacobi_iters=6) you'd save roughly 60% of the kernel time and land near ~32µs, putting N=6 essentially at parity with N=5. That's the version to use when you control your input conditioning.
Batch scaling at N=6 (fp32)
The fused kernel is sublinear in batch size; cuSOLVER is exactly linear. The crossover where the fused path becomes "free" is well below B=256.
| B | cuSOLVER | Triton fused | Speedup |
|---|---|---|---|
| 256 | 78.14ms | 46.2µs | 1,693× |
| 512 | 155.61ms | 57.7µs | 2,695× |
| 1024 | 311.46ms | 101.0µs | 3,084× |
| 2048 | 624.57ms | 166.2µs | 3,757× |
| 4096 | 1.248s | 307.3µs | 4,061× |
| 8192 | 2.502s | 567.0µs | 4,414× |
The fused kernel time grows ~12× from B=256 to B=8192 (a 32× batch increase) — well below linear, indicating the SMs are still finding parallelism to exploit. cuSOLVER's 2.502s at B=8192 is the kind of number that turns a single SVD call into a training-loop bottleneck. The fused kernel turns the same call into 567µs.
7. API Additions
Backwards-compatible. Code written against the Part I API still works.
from geolip_core.utils.kernel import (
batched_svd, # auto-dispatcher
batched_svd2, batched_svd3, batched_svd4, # N=2,3,4 fused
batched_svd5, batched_svd6, # N=5,6 fused
gram_eigh_svd, # N≥7 path
newton_schulz_invsqrt,
batched_procrustes,
HAS_TRITON,
)
# Auto-dispatch — same as before, just covers more N
U, S, Vh = batched_svd(A)
# Force a specific path (useful for testing)
U, S, Vh = batched_svd(A, method='triton') # N=2..6 only
U, S, Vh = batched_svd(A, method='gram_eigh') # any N
U, S, Vh = batched_svd(A, method='torch') # cuSOLVER baseline
# Tune the convergence parameter for the higher-N kernels
U, S, Vh = batched_svd6(A, jacobi_iters=8) # faster, lower precision
U, S, Vh = batched_svd6(A, jacobi_iters=20) # slower, fp64-tier precision
# Backend pin (test the fallback)
from geolip_core.linalg._backend import backend
backend.use_triton = False
U, S, Vh = batched_svd(A) # routes to Gram+eigh
8. Reproducing
pip install "git+https://github.com/AbstractEyes/geolip-core.git"
python -m geolip_core.utils.kernel
Requirements: PyTorch ≥ 2.0, Triton ≥ 2.1, CUDA GPU. The __main__ block runs 437 checks across:
- Correctness — auto-dispatch + each kernel directly (
batched_svd2..6,gram_eigh_svd,newton_schulz_invsqrt,batched_procrustes), on fp32 and fp64, with fp16 routed to the upcast fallback - Size sweep — every (M, N) shape with M ∈ {N..6, 1024} and N ∈ {2..6}, validating dtype, recon, orth, and descending-S contracts
- Convergence sweep — reconstruction error vs
jacobi_iters∈ {2, 4, 6, 8, 12} at each N for fp32 and fp64; pins monotone improvement - Backend toggle — flips both
use_tritonanduse_fl_eighoff, exercises the cuSOLVER fallback path on the same shape grid - Throughput — fp32 and fp64 size sweep + N=6 batch scaling, against
torch.linalg.svd
A passing run ends with 437/437 passed — all clear. Any failure prints the specific check name and the numerical detail so you can bisect.
What's Next
Three threads remaining on the SVD side:
- N=7 and beyond as a fused kernel. The register-file argument suggests this is the soft ceiling on Blackwell, but Hopper/Ada have different occupancy tradeoffs and the answer might differ.
- Half-precision support without upcast. bf16 has enough mantissa precision for the Gram accumulation if we keep the eigensolve in fp32; the kernel structure supports this with a mixed-precision DTYPE pair, but I haven't measured it yet.
- Backward pass. The current path is forward-only; gradients flow through the cuSOLVER fallback when needed. A custom autograd that uses the same Jacobi structure would close the last factor-of-N latency gap during training.
Each of those is a follow-up post in its own right.
Citation
@software{abstractphil2026svd_part2,
title = {Fused Batched Thin SVD, Part II: Extending the Jacobi Pipeline to N=6},
author = {AbstractPhil and Claude},
year = {2026},
url = {https://github.com/AbstractEyes/geolip-core},
note = {Triton kernels, geolip\_core/utils/kernel.py}
}
License
Apache 2.0.
Part of the GEOLIP ecosystem. Repo: AbstractEyes/geolip-core. The kernel file: geolip_core/utils/kernel.py.

