Fused Batched Thin SVD, Part II: Extending the Jacobi Pipeline to N=6 with Configurable Convergence

Community Article Published May 1, 2026

A continuation of Fused Batched Thin SVD: Engineering a 5000× Speedup with Triton Kernels, now shipped in AbstractEyes/geolip-core.

The first article documented the N=2 and N=3 fused kernels. The Gram-Eigh hybrid handled N=4 through N=32, with a sub-millisecond floor that was "good enough" for most uses but still ~10× slower than the fused path. This follow-up walks through the N=4, N=5, and N=6 fused kernels that close that gap, the configurable Jacobi convergence parameter that makes them numerically safe in fp32, the new fp32/fp64 dtype contract, and the backend toggle infrastructure that lets you pin a specific dispatch path for testing or fallback validation.


Recap

Thin SVD on (B, M, N) matrices with M ≫ N reduces to:

  1. Form the Gram matrix G = AᵀA of shape (N, N),
  2. Eigendecompose G to get S² = eigenvalues, V = eigenvectors,
  3. Recover U = AV / S.

For small N this is bandwidth-bound on the A reads, with an N×N eigensolve that fits in scalar registers. cuSOLVER doesn't know this, so it dispatches a full bidiagonalization pipeline and pays heavy fixed overhead per batch element (about 230µs each in the measurement that follows: 117.5ms / 512). The fused Triton kernel for N=3 ran the entire pipeline (Gram accumulation, cyclic Jacobi sweep, sort, U recovery) in a single program instance per batch element, hitting 0.022ms at B=512, M=1024 versus 117.5ms for torch.linalg.svd. That's the 5,488× number from Part I.

The catch was the upper bound. At N=4 the pipeline fell back to torch.bmm + torch.linalg.eigh, which works fine but launches three kernels and serializes through cuSOLVER's eigh. At N=4, M=1024, B=512 that's roughly 250µs: perfectly usable, but not in the same league as the fused path.

This article is about pushing the fused boundary out to N=6, why N=6 specifically required revisiting the Jacobi convergence schedule, and the dtype-aware backend rebuild that came with it.


1. The Pair-Count Cliff at N=4 and Beyond

The cyclic Jacobi sweep zeros each off-diagonal g_pq of the Gram matrix in turn. For an N×N symmetric matrix, one full sweep applies rotations to all N(N-1)/2 upper-triangle pairs. After 6 sweeps the matrix is diagonal to machine precision for N=3 (3 pairs/sweep, 18 total rotations).

Scaling that up:

N   Pairs/sweep   Rotations at 6 sweeps   Rotations at 12 sweeps
2   1             6 (1 needed)            -
3   3             18                      -
4   6             36                      72
5   10            60                      120
6   15            90                      180
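
The counts in the table follow directly from the pair formula N(N-1)/2. A few lines of Python reproduce them; the helper names `pairs_per_sweep` and `rotations` are illustrative and do not come from geolip-core.

```python
def pairs_per_sweep(n):
    """Number of upper-triangle (p, q) pairs in an n x n symmetric matrix."""
    return n * (n - 1) // 2

def rotations(n, sweeps):
    """Total Jacobi rotations applied over `sweeps` full cyclic sweeps."""
    return pairs_per_sweep(n) * sweeps

# Columns: N, pairs per sweep, rotations at 6 sweeps, rotations at 12 sweeps.
for n in range(2, 7):
    print(n, pairs_per_sweep(n), rotations(n, 6), rotations(n, 12))
```

The rotation count grows quadratically in N and linearly in the sweep count, which is why doubling the sweeps at N=6 is the expensive corner of the table.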

Each rotation is six FMAs on the diagonal, two on each off-diagonal that the rotation touches, and a column-pair update on V. The arithmetic itself is cheap — every operand is in a scalar register. The problem is two-fold: register pressure for V and accumulated round-off in fp32.

Register pressure

The N=3 kernel holds 6 unique Gram entries plus 9 V entries — 15 scalar slots. N=6 needs 21 unique Gram entries plus 36 V entries — 57 scalar slots, plus the per-tile load registers and Jacobi temporaries. That's well within Blackwell's per-thread register file (255 scalars), but it's no longer trivial. The N=6 kernel was the first one where I had to think about register coloring across the rotation pattern to avoid spilling.
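
As a quick check on the slot counts, the sketch below tallies the unique Gram entries and the V entries for a given N. It counts only those two groups; tile loads and Jacobi temporaries are not modeled, and `scalar_slots` is a hypothetical helper name.

```python
def scalar_slots(n):
    """Return (gram_entries, v_entries, total) for an n-column fused kernel."""
    gram = n * (n + 1) // 2   # unique entries of the symmetric Gram matrix
    v = n * n                 # full eigenvector matrix V
    return gram, v, gram + v

for n in (3, 6, 7):
    print(n, scalar_slots(n))
```

The V term dominates as N grows, so the register budget is driven more by the eigenvectors than by the Gram matrix itself.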

Accumulated round-off

This is the subtle one. Each Givens rotation is exactly orthogonal in real arithmetic, but in fp32 it is orthogonal only to ~6 decimal digits. After 90 rotations the accumulated error in VᵀV − I drifts to ~4×10⁻³ on Blackwell, above the 1e-3 orthogonality tolerance the validation suite checks for.

For N=2,3,4,5 with 6 sweeps this never showed up. At N=6 it became a hard floor. The fix wasn't "use a looser tolerance"; the fix was to give N=6 more sweeps so the post-convergence matrix is closer to truly diagonal, which means the accumulated rotation is closer to truly orthogonal. Twelve sweeps (180 rotations) put the orth error well under 1e-3 with negligible runtime cost relative to the eigh fallback.

The convergence sweep test built into kernel.py's validation harness measures reconstruction error on a controlled-spectrum fixture (A = U_o · diag(S) · V_oᵀ with S ∈ [0.5, 2.0], condition number ≤ 4). On Blackwell:

[convergence sweep — backend.resolve_svd_nN]
  fp32 N=3: it=2:9.66e-09  it=4:9.82e-09  it=6:9.82e-09  it=8:9.82e-09  it=12:9.82e-09
  fp32 N=4: it=2:1.57e-08  it=4:1.86e-08  it=6:1.86e-08  it=8:1.86e-08  it=12:1.86e-08
  fp32 N=5: it=2:1.66e-08  it=4:2.40e-08  it=6:2.40e-08  it=8:2.40e-08  it=12:2.40e-08
  fp32 N=6: it=2:1.98e-08  it=4:2.94e-08  it=6:2.95e-08  it=8:2.95e-08  it=12:2.95e-08
  fp64 N=6: it=2:3.13e-14  it=4:3.12e-14  it=6:3.12e-14  it=8:3.12e-14  it=12:3.12e-14

Reconstruction error is already at its noise floor at it=2 and is flat from it=4 onward across the board on this fixture: fp32 sits at ~10⁻⁸ and fp64 at ~10⁻¹⁴, and extra sweeps do not lower either. So why ship N=6 with jacobi_iters=12?

Because the convergence sweep doesn't tell the whole story. The 4×10⁻³ orthogonality drift shows up on the random Gaussian inputs that occur in real training loops, not on the controlled-spectrum fixture. A randn(B, M, 6) matrix has condition numbers that occasionally land in the 50–500 range, and at fp32 the accumulated round-off across 90 rotations (15 pairs × 6 sweeps) drifts ‖VᵀV − I‖ past the 1×10⁻³ tolerance the validation harness requires. Twelve sweeps (180 rotations) do not reduce per-rotation error; they drive the off-diagonals further toward zero so the terminal V is closer to truly orthogonal even with the same per-step round-off.

The convergence sweep validates "well-conditioned inputs converge fast." The 12-iter default protects against the ill-conditioned case the sweep deliberately avoids.

For N=2 through N=5 the rotation count stays low enough that 6 sweeps stay safely under 1×10⁻³ orth even on random inputs. The default for those kernels stays at 6. If you know your inputs are well-conditioned (Gram precomputed, bounded spectrum, etc.) you can drop N=6 to jacobi_iters=6 and reclaim ~40% of the kernel time.


2. The N=4, N=5, N=6 Kernel Structure

Each kernel follows the same three-stage pattern as N=3, with one wrinkle for the U recovery step at higher N.

Stage 1 — Tiled Gram accumulation

for block_start in range(0, M, BLOCK_M):
    a0 = tl.load(A_ptr + base + row_idx*N + 0, ...)
    a1 = tl.load(A_ptr + base + row_idx*N + 1, ...)
    # ... a2..a5
    g00 += tl.sum(a0*a0); g01 += tl.sum(a0*a1); ...

N(N+1)/2 scalar accumulators. For N=6 that's 21 scalars; the entire upper triangle stays in registers across all tiles.
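
A minimal pure-Python sketch of the same tiled accumulation, assuming A is a list of rows and using a dict keyed by (i, j) in place of the named scalars g00, g01, and so on. `gram_upper_tiled` and `block_m` are illustrative names; `block_m` plays the role of BLOCK_M.

```python
def gram_upper_tiled(A, block_m):
    """Accumulate the upper triangle of A^T A one row-tile at a time.

    A is a list of M rows, each a list of N floats.
    Returns {(i, j): g_ij} for i <= j, i.e. N(N+1)/2 accumulators.
    """
    n = len(A[0])
    g = {(i, j): 0.0 for i in range(n) for j in range(i, n)}
    for start in range(0, len(A), block_m):
        tile = A[start:start + block_m]       # the last tile may be short
        for key in g:                          # values change, keys do not
            i, j = key
            g[key] += sum(row[i] * row[j] for row in tile)
    return g

# Small integer-valued test matrix, M=10 rows and N=6 columns.
A = [[float((3 * r + 5 * c) % 7) for c in range(6)] for r in range(10)]
g = gram_upper_tiled(A, block_m=4)
print(len(g))  # 21 accumulators for N=6
```

Because each accumulator is a plain running sum, the tile height only changes how the rows are grouped, not the result (up to floating-point summation order).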

Stage 2 — Cyclic Jacobi over all pairs

For each Jacobi sweep, walk pairs (0,1), (0,2), ..., (N-2, N-1) in lexicographic order:

# pair (p,q) — for example (0,1) at N=6:
off_diag = g01; diag_diff = g11 - g00
tau = diag_diff / (2 * off_diag)   # guarded against off_diag == 0
t   = sign(tau) / (|tau| + sqrt(1 + tau*tau))
c   = 1 / sqrt(1 + t*t); s = t*c

# Update G — only the row/col touched by (p,q) changes:
ng00 = c*c*g00 - 2*s*c*g01 + s*s*g11
ng11 = s*s*g00 + 2*s*c*g01 + c*c*g11
# Plus the cross-row updates: ng02, ng12, ng03, ng13, ng04, ng14, ng05, ng15
g01 = 0  # by construction

# Update V — columns p and q rotate together:
nv00 = c*v00 - s*v01;  nv01 = s*v00 + c*v01
nv10 = c*v10 - s*v11;  nv11 = s*v10 + c*v11
# ... rows 2 through 5

The total per-pair work for N=6 is roughly: 4 trig-replacement ops to compute (c, s), 8 FMAs to update the two diagonal entries, 4 FMAs × 4 cross-rows = 16 FMAs to update the off-diagonals the rotation touches, and 2 FMAs × 6 rows = 12 FMAs to update V. About 40 FMAs per pair × 15 pairs × 12 sweeps = ~7,200 FMAs per batch element. On a Blackwell SM that's a couple of microseconds — the wall-clock is still dominated by the A read.
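
The rotation formulas above can be checked against a slow reference. This is a minimal pure-Python sketch, assuming full (not upper-triangle) storage and Python floats, which are fp64, so it will not reproduce the fp32 drift from Section 1. `jacobi_eigh` and `orth_error` are illustrative names, not functions from kernel.py, and the max-abs metric in `orth_error` may differ from the norm the validation harness uses.

```python
import math

def jacobi_eigh(G, sweeps):
    """Cyclic Jacobi on a symmetric matrix G (list of lists).

    Returns (eigenvalues, V) with G ~= V diag(eigenvalues) V^T.
    """
    n = len(G)
    g = [row[:] for row in G]
    v = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for _ in range(sweeps):
        for p in range(n - 1):
            for q in range(p + 1, n):          # lexicographic pair order
                if g[p][q] == 0.0:             # guard: nothing to rotate
                    continue
                tau = (g[q][q] - g[p][p]) / (2.0 * g[p][q])
                # copysign treats tau == 0 as positive, giving a 45-degree rotation.
                t = math.copysign(1.0, tau) / (abs(tau) + math.sqrt(1.0 + tau * tau))
                c = 1.0 / math.sqrt(1.0 + t * t)
                s = t * c
                for k in range(n):             # rotate columns p, q of G and V
                    gkp, gkq = g[k][p], g[k][q]
                    g[k][p], g[k][q] = c * gkp - s * gkq, s * gkp + c * gkq
                    vkp, vkq = v[k][p], v[k][q]
                    v[k][p], v[k][q] = c * vkp - s * vkq, s * vkp + c * vkq
                for k in range(n):             # rotate rows p, q of G
                    gpk, gqk = g[p][k], g[q][k]
                    g[p][k], g[q][k] = c * gpk - s * gqk, s * gpk + c * gqk
                g[p][q] = g[q][p] = 0.0        # zero by construction
    return [g[i][i] for i in range(n)], v

def orth_error(v):
    """max |(V^T V - I)_ij|, one possible orthogonality metric."""
    n = len(v)
    return max(
        abs(sum(v[k][i] * v[k][j] for k in range(n)) - (1.0 if i == j else 0.0))
        for i in range(n) for j in range(n)
    )

G = [[4.0, 1.0, 2.0], [1.0, 3.0, 0.0], [2.0, 0.0, 5.0]]
evals, V = jacobi_eigh(G, sweeps=6)
print(sorted(evals, reverse=True), orth_error(V))
```

The fused kernel unrolls these loops into named scalars; the loop form here trades speed for readability and is only meant for checking the algebra.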

Stage 2b — Selection sort with V column swaps

For N=6 there are C(6,2) = 15 compare-and-swap operations to fully sort. Each swap permutes a V column (6 scalar swaps). Done with tl.where to keep the kernel branchless.
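
A sketch of the compare-and-swap pattern in plain Python, with a conditional expression standing in for tl.where. It assumes descending order (matching the descending-S contract) and carries the V columns along with each swap; `sort_descending` is a hypothetical name.

```python
def sort_descending(evals, V):
    """Sort eigenvalues descending with C(n, 2) compare-and-swap steps.

    Columns of V are permuted the same way. Returns (evals, V, steps).
    """
    n = len(evals)
    e = list(evals)
    v = [row[:] for row in V]
    steps = 0
    for i in range(n - 1):
        for j in range(i + 1, n):
            swap = e[j] > e[i]
            # Select-style update: both slots are always written.
            e[i], e[j] = (e[j], e[i]) if swap else (e[i], e[j])
            for r in range(n):
                v[r][i], v[r][j] = (v[r][j], v[r][i]) if swap else (v[r][i], v[r][j])
            steps += 1
    return e, v, steps

identity = [[1.0 if r == c else 0.0 for c in range(6)] for r in range(6)]
e, v, steps = sort_descending([1.0, 4.0, 2.0, 6.0, 3.0, 5.0], identity)
print(e, steps)  # [6.0, 5.0, 4.0, 3.0, 2.0, 1.0] 15
```

The fixed sequence of 15 steps does not depend on the data, which is what lets the kernel version stay free of divergent branches.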

Stage 3 — U recovery

for block_start in range(0, M, BLOCK_M):
    a0..a5 = tl.load(...)
    u0 = (a0*v00 + a1*v10 + a2*v20 + a3*v30 + a4*v40 + a5*v50) * inv_s0
    # ... u1..u5
    tl.store(...)

Same tiling as Stage 1. V entries are in registers from Stage 2. For N=6 each output row is 6 dot products of length 6 — 36 FMAs per output tile element, with the V matrix entirely in registers. No shared memory at any point in the kernel.
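
The recovery step U = AV / S can be sketched row by row in plain Python. The `eps` guard against tiny singular values is an assumption added here for safety; the source does not say how the kernel computes inv_s0 for near-zero S. `recover_u` is an illustrative name.

```python
def recover_u(A, V, S, eps=1e-12):
    """Compute U = A V / S for a list-of-rows A (M x N), V (N x N), S (N)."""
    n = len(S)
    inv_s = [1.0 / max(s, eps) for s in S]    # eps guard is an assumption
    return [
        [sum(row[k] * V[k][j] for k in range(n)) * inv_s[j] for j in range(n)]
        for row in A
    ]

# A = U_o diag(S) V^T with U_o = first two columns of the 3x3 identity.
V = [[0.6, -0.8], [0.8, 0.6]]
S = [2.0, 1.0]
A = [[1.2, 1.6], [-0.8, 0.6], [0.0, 0.0]]
U = recover_u(A, V, S)
print(U)
```

Each output row needs only the matching input row plus V and 1/S, which is why this stage tiles over M exactly like the Gram accumulation.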


3. The fp32 / fp64 Contract

The Part I kernels were fp32 only. The new code paths take a DTYPE: tl.constexpr so the same kernel source compiles for both fp32 and fp64. The Python wrapper picks the dtype from the input tensor:

def _triton_dtype(A):
    if A.dtype == torch.float32: return tl.float32
    if A.dtype == torch.float64: return tl.float64
    raise TypeError(f"Triton SVD kernels support fp32/fp64 only, got {A.dtype}")

For unsupported dtypes (fp16, bf16, complex), the wrapper routes to a _torch_svd_fallback that upcasts to fp32, runs torch.linalg.svd, and casts the results back to preserve the output-dtype-matches-input contract:

def batched_svd6(A, block_m=128, jacobi_iters=12):
    if not HAS_TRITON or not A.is_cuda or A.dtype not in (torch.float32, torch.float64):
        return _torch_svd_fallback(A)
    # ... fused path

Why bother with fp64 in a kernel that's bandwidth-bound? Two reasons:

  1. The orthogonality drift problem disappears. Where fp32 N=6 needs 12 sweeps, fp64 N=6 hits machine precision at 6. That said, jacobi_iters=12 is set as the default for safety; if you're confident in your input conditioning you can drop it.
  2. Downstream pipelines that need fp64 throughout (Cayley-Menger volume validation, characteristic polynomial coefficients, Gram-CV diagnostics) can now run the SVD step at the same precision instead of casting around it.

The dtype contract is verified in the validation harness:

[N=4 triton fp64 dtype]
  PASS: U.dtype == torch.float64 and S.dtype == torch.float64 and Vh.dtype == torch.float64
[N=4 fp16 fallback dtype]
  PASS: U.dtype == torch.float16 and S.dtype == torch.float16 and Vh.dtype == torch.float16
[N=4 fp16 fallback recon]
  PASS: err=2.18e-04

4. Backend Toggles for Test and Dispatch Control

geolip-core exposes a small backend module that lets you pin which path the dispatcher takes:

from geolip_core.linalg._backend import backend

backend.use_triton = False   # Force the Gram-Eigh path even when Triton is available
backend.use_fl_eigh = False  # Force cuSOLVER eigh even when FL is available

backend.status()
# geolip_core.linalg backend:
#   CUDA:       yes
#   Triton:     3.6.0 (disabled)
#   FL eigh:    disabled
#   GPU:        NVIDIA RTX PRO 6000 Blackwell Server Edition

This was added because the original dispatcher made it hard to verify the fallback path was correct. With both use_triton and use_fl_eigh off, every SVD call routes through gram_eigh_svd → torch.linalg.eigh → cuSOLVER, which is the path you want to validate when shipping to a machine without Triton. The test harness exercises exactly this:

[backend toggle — Triton OFF, FL OFF]
  PASS: off/fp32 4x4 recon  err=1.91e-06
  PASS: off/fp32 5x5 recon  err=2.31e-06
  PASS: off/fp64 6x6 recon  err=8.43e-15
  ...

The toggle is also useful when you suspect a numerical regression in the fused kernel — flip Triton off, rerun your training loop, and bisect against the cuSOLVER result.
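
When bisecting, it can help to scope the toggle so it is always restored. The sketch below is one way to do that with a context manager. `pinned` is a hypothetical helper, not part of geolip-core, and a `SimpleNamespace` stands in for the real `backend` object so the example runs without the library; it assumes the real flags are ordinary settable attributes, as the assignments above suggest.

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Stand-in for geolip_core.linalg._backend.backend (assumption for the demo).
backend = SimpleNamespace(use_triton=True, use_fl_eigh=True)

@contextmanager
def pinned(obj, **flags):
    """Temporarily set attributes on `obj`, restoring the old values on exit."""
    saved = {name: getattr(obj, name) for name in flags}
    try:
        for name, value in flags.items():
            setattr(obj, name, value)
        yield obj
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)

with pinned(backend, use_triton=False, use_fl_eigh=False):
    inside = (backend.use_triton, backend.use_fl_eigh)   # both False here
after = (backend.use_triton, backend.use_fl_eigh)        # restored on exit
print(inside, after)
```

The `finally` clause restores the flags even if the code inside the block raises, so a failing comparison run cannot leave the fused path disabled for the rest of the process.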


5. Updated Dispatch Map

The auto-dispatcher now routes:

N == 2          → batched_svd2   (fused, single Jacobi rotation)
N == 3          → batched_svd3   (fused, 6 Jacobi sweeps)
N == 4          → batched_svd4   (fused, 6 Jacobi sweeps)
N == 5          → batched_svd5   (fused, 6 Jacobi sweeps)
N == 6          → batched_svd6   (fused, 12 Jacobi sweeps)
7 ≤ N ≤ 32      → gram_eigh_svd  (Gram + cuSOLVER eigh)
N ≥ 33          → gram_eigh_svd  (still works; large-N users should consider
                                  rank-projected SVD from Part I)
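
The map above can be condensed into a small routing function. This is a sketch of the documented behavior only, not the dispatcher's actual code: `select_path` is a hypothetical name, it returns kernel names as strings, and it ignores dtype and device checks.

```python
def select_path(n, use_triton=True):
    """Return the name of the kernel the dispatch map assigns to width n."""
    if use_triton and 2 <= n <= 6:
        return f"batched_svd{n}"      # fused Triton kernels
    return "gram_eigh_svd"            # Gram + cuSOLVER eigh

for n in (2, 6, 7, 33):
    print(n, select_path(n))
print(4, select_path(4, use_triton=False))
```

With `use_triton=False` every width falls through to the Gram path, mirroring what the backend toggle from Section 4 does.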

The fused boundary is N=6 because N=7 would need 21 pairs/sweep and 49 V entries plus 28 Gram entries — that's 77 scalar slots before counting tile loads and rotation temporaries. It fits but starts crowding the 255-scalar register file once the tile prefetch and rotation intermediates are accounted for, and the per-pair cost overtakes the cuBLAS bmm of the Gram path. Empirically, N=7 was no longer a clear win on Blackwell, so the cutoff stayed at 6.


6. Benchmarks

Measured with python -m geolip_core.utils.kernel on an NVIDIA RTX PRO 6000 Blackwell, B=512, M=1024, after 437/437 validation checks pass.

fp32

Shape    torch.linalg.svd   batched_svd (auto)   Triton fused   Speedup
1024×2   79.83ms            16.0µs               15.8µs         5,038×
1024×3   121.90ms           16.1µs               16.1µs         7,587×
1024×4   128.28ms           22.6µs               22.6µs         5,675×
1024×5   147.76ms           30.8µs               30.8µs         4,796×
1024×6   156.36ms           57.4µs               57.4µs         2,724×

The auto and Triton fused columns are within noise of each other, confirming the dispatcher correctly routes N=2..6 through the fused path. cuSOLVER's time grows steadily with N, roughly doubling from N=2 to N=6, and scales linearly with B (see batch scaling below).

fp64

The fp64 numbers are new — Part I was fp32 only. Note that the RTX PRO 6000 Blackwell has consumer-class fp64 throughput (~1/64 of fp32), so the absolute timings climb fast as the kernel becomes more compute-bound.

Shape    torch.linalg.svd   Triton fused   Speedup
1024×2   292.32ms           37.0µs         7,898×
1024×3   506.73ms           190.6µs        2,660×
1024×4   546.93ms           376.3µs        1,453×
1024×5   631.52ms           647.4µs        975×
1024×6   693.55ms           1.94ms         357×

Three observations:

  1. The fused kernel is faster than torch.linalg.svd at every size in fp64, but the margin shrinks as N grows. At N=6 the kernel is still 357× faster, but the absolute ratio of fp64-to-fp32 cost is ~34× — much steeper than fp32→fp64's 2× memory cost would predict. That extra factor is the per-element FMA throttling on consumer Blackwell.
  2. fp64 N=2 is interesting: 37µs versus 16µs in fp32 — only a 2.3× slowdown. At N=2 the kernel is purely bandwidth-bound and fp64 just doubles the bytes per element.
  3. fp64 N=6 hits 1.94ms — at this point cuSOLVER's 693ms is still 357× slower, but if your downstream needs fp64 and you have many N=6 calls per forward pass, you should profile carefully. For most uses, fp32 + a verification call in fp64 every K steps is the right tradeoff.

Why does fp32 N=6 cost ~1.86× N=5?

N=5 with 6 sweeps does 60 rotations. N=6 with 12 sweeps does 180 rotations — a 3× increase in rotation work. But the wall time only grows 1.86×. The remaining cost is the bandwidth-bound A passes, which scale only ~1.2× from N=5 to N=6 (one extra column to load and recover). The kernel sits in a regime where compute and bandwidth are roughly balanced at N=6 — neither dominates.

If you call batched_svd6(A, jacobi_iters=6) you'd save roughly 45% of the kernel time and land near 32µs (down from 57.4µs), putting N=6 essentially at parity with N=5. That's the version to use when you control your input conditioning.

Batch scaling at N=6 (fp32)

The fused kernel is sublinear in batch size; cuSOLVER is almost exactly linear. Even at B=256, the smallest batch measured, the fused path is already about 1,700× faster.

B      cuSOLVER   Triton fused   Speedup
256    78.14ms    46.2µs         1,693×
512    155.61ms   57.7µs         2,695×
1024   311.46ms   101.0µs        3,084×
2048   624.57ms   166.2µs        3,757×
4096   1.248s     307.3µs        4,061×
8192   2.502s     567.0µs        4,414×

The fused kernel time grows ~12× from B=256 to B=8192 (a 32× batch increase) — well below linear, indicating the SMs are still finding parallelism to exploit. cuSOLVER's 2.502s at B=8192 is the kind of number that turns a single SVD call into a training-loop bottleneck. The fused kernel turns the same call into 567µs.
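
The scaling claim can be read straight off the table. The snippet below recomputes the growth factor and the per-element cost from the fused timings listed above; the variable names are illustrative.

```python
# Fused-kernel timings from the batch-scaling table, in microseconds.
timings_us = {256: 46.2, 512: 57.7, 1024: 101.0, 2048: 166.2, 4096: 307.3, 8192: 567.0}

growth = timings_us[8192] / timings_us[256]     # wall-time growth factor
batch_ratio = 8192 / 256                        # batch-size growth factor
per_elem_ns = {b: 1000.0 * t / b for b, t in timings_us.items()}

print(round(growth, 1), batch_ratio)
for b, ns in per_elem_ns.items():
    print(b, round(ns, 1))
```

Per-element cost keeps falling as B grows, which is the same sublinear behavior seen from the other side: larger batches amortize launch overhead and fill more of the GPU.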


7. API Additions

Backwards-compatible. Code written against the Part I API still works.

from geolip_core.utils.kernel import (
    batched_svd,                                              # auto-dispatcher
    batched_svd2, batched_svd3, batched_svd4,                 # N=2,3,4 fused
    batched_svd5, batched_svd6,                               # N=5,6 fused
    gram_eigh_svd,                                            # N≥7 path
    newton_schulz_invsqrt,
    batched_procrustes,
    HAS_TRITON,
)

# Auto-dispatch — same as before, just covers more N
U, S, Vh = batched_svd(A)

# Force a specific path (useful for testing)
U, S, Vh = batched_svd(A, method='triton')      # N=2..6 only
U, S, Vh = batched_svd(A, method='gram_eigh')   # any N
U, S, Vh = batched_svd(A, method='torch')       # cuSOLVER baseline

# Tune the convergence parameter for the higher-N kernels
U, S, Vh = batched_svd6(A, jacobi_iters=8)      # faster, lower precision
U, S, Vh = batched_svd6(A, jacobi_iters=20)     # slower, extra margin on orthogonality

# Backend pin (test the fallback)
from geolip_core.linalg._backend import backend
backend.use_triton = False
U, S, Vh = batched_svd(A)   # routes to Gram+eigh

8. Reproducing

pip install "git+https://github.com/AbstractEyes/geolip-core.git"
python -m geolip_core.utils.kernel

Requirements: PyTorch ≥ 2.0, Triton ≥ 2.1, CUDA GPU. The __main__ block runs 437 checks across:

  1. Correctness — auto-dispatch + each kernel directly (batched_svd2..6, gram_eigh_svd, newton_schulz_invsqrt, batched_procrustes), on fp32 and fp64, with fp16 routed to the upcast fallback
  2. Size sweep — every (M, N) shape with M ∈ {N..6, 1024} and N ∈ {2..6}, validating dtype, recon, orth, and descending-S contracts
  3. Convergence sweep — reconstruction error vs jacobi_iters ∈ {2, 4, 6, 8, 12} at each N for fp32 and fp64; pins monotone improvement
  4. Backend toggle — flips both use_triton and use_fl_eigh off, exercises the cuSOLVER fallback path on the same shape grid
  5. Throughput — fp32 and fp64 size sweep + N=6 batch scaling, against torch.linalg.svd

A passing run ends with 437/437 passed — all clear. Any failure prints the specific check name and the numerical detail so you can bisect.


What's Next

Three threads remaining on the SVD side:

  • N=7 and beyond as a fused kernel. The register-file argument suggests this is the soft ceiling on Blackwell, but Hopper/Ada have different occupancy tradeoffs and the answer might differ.
  • Half-precision support without upcast. bf16 has enough mantissa precision for the Gram accumulation if we keep the eigensolve in fp32; the kernel structure supports this with a mixed-precision DTYPE pair, but I haven't measured it yet.
  • Backward pass. The current path is forward-only; gradients flow through the cuSOLVER fallback when needed. A custom autograd that uses the same Jacobi structure would close the last factor-of-N latency gap during training.

Each of those is a follow-up post in its own right.


Citation

@software{abstractphil2026svd_part2,
  title  = {Fused Batched Thin SVD, Part II: Extending the Jacobi Pipeline to N=6},
  author = {AbstractPhil and Claude},
  year   = {2026},
  url    = {https://github.com/AbstractEyes/geolip-core},
  note   = {Triton kernels, geolip\_core/utils/kernel.py}
}

License

Apache 2.0.


Part of the GEOLIP ecosystem. Repo: AbstractEyes/geolip-core. The kernel file: geolip_core/utils/kernel.py.
